// sql/migrate.rs

1use std::collections::HashMap;
2
3use crate::query::{AlterTable, Update};
4use crate::{AlterAction, Constraint, Dialect, DropTable, Index, Schema, Table, ToSql};
5use anyhow::Result;
6use topo_sort::{SortResults, TopoSort};
7
/// Settings that control how [`migrate`] builds a [`Migration`].
#[derive(Default, Clone, Debug)]
pub struct MigrationOptions {
    /// Debug flag carried with the options; not consulted by [`migrate`] in this module.
    pub debug: bool,
    /// When `true`, tables absent from the desired schema are dropped; when `false`
    /// (the default) each such drop is skipped and reported in the migration's debug results.
    pub allow_destructive: bool,
}
13
14pub fn migrate(current: Schema, desired: Schema, options: &MigrationOptions) -> Result<Migration> {
15    let current_tables = current
16        .tables
17        .iter()
18        .map(|t| (&t.name, t))
19        .collect::<HashMap<_, _>>();
20    let desired_tables = desired
21        .tables
22        .iter()
23        .map(|t| (&t.name, t))
24        .collect::<HashMap<_, _>>();
25
26    let mut debug_results = vec![];
27    let mut statements = Vec::new();
28    // new tables
29    for (_name, &table) in desired_tables
30        .iter()
31        .filter(|(name, _)| !current_tables.contains_key(*name))
32    {
33        let statement = Statement::CreateTable(table.clone());
34        statements.push(statement);
35    }
36
37    // alter existing tables
38    for (name, desired_table) in desired_tables
39        .iter()
40        .filter(|(name, _)| current_tables.contains_key(*name))
41    {
42        let current_table = current_tables[name];
43        let current_columns = current_table
44            .columns
45            .iter()
46            .map(|c| (&c.name, c))
47            .collect::<HashMap<_, _>>();
48        // add columns
49        let mut actions = vec![];
50        for desired_column in desired_table.columns.iter() {
51            if let Some(current) = current_columns.get(&desired_column.name) {
52                if current.nullable != desired_column.nullable {
53                    actions.push(AlterAction::set_nullable(
54                        desired_column.name.clone(),
55                        desired_column.nullable,
56                    ));
57                }
58                if !desired_column.typ.lossy_eq(&current.typ) {
59                    actions.push(AlterAction::set_type(
60                        desired_column.name.clone(),
61                        desired_column.typ.clone(),
62                    ));
63                };
64                if desired_column.constraint.is_some() && current.constraint.is_none() {
65                    if let Some(c) = &desired_column.constraint {
66                        let name = desired_column.name.clone();
67                        actions.push(AlterAction::add_constraint(
68                            &desired_table.name,
69                            name,
70                            c.clone(),
71                        ));
72                    }
73                }
74            } else {
75                // add the column can be in 1 step if the column is nullable
76                if desired_column.nullable {
77                    actions.push(AlterAction::AddColumn {
78                        column: desired_column.clone(),
79                    });
80                } else {
81                    let mut nullable = desired_column.clone();
82                    nullable.nullable = true;
83                    statements.push(Statement::AlterTable(AlterTable {
84                        schema: desired_table.schema.clone(),
85                        name: desired_table.name.clone(),
86                        actions: vec![AlterAction::AddColumn { column: nullable }],
87                    }));
88                    statements.push(Statement::Update(
89                        Update::new(name)
90                            .set(
91                                &desired_column.name,
92                                "/* TODO set a value before setting the column to null */",
93                            )
94                            .where_(crate::query::Where::raw("true")),
95                    ));
96                    statements.push(Statement::AlterTable(AlterTable {
97                        schema: desired_table.schema.clone(),
98                        name: desired_table.name.clone(),
99                        actions: vec![AlterAction::AlterColumn {
100                            name: desired_column.name.clone(),
101                            action: crate::query::AlterColumnAction::SetNullable(false),
102                        }],
103                    }));
104                }
105            }
106        }
107        if actions.is_empty() {
108            debug_results.push(DebugResults::TablesIdentical(name.to_string()));
109        } else {
110            statements.push(Statement::AlterTable(AlterTable {
111                schema: desired_table.schema.clone(),
112                name: desired_table.name.clone(),
113                actions,
114            }));
115        }
116    }
117
118    for (_name, current_table) in current_tables
119        .iter()
120        .filter(|(name, _)| !desired_tables.contains_key(*name))
121    {
122        if options.allow_destructive {
123            statements.push(Statement::DropTable(DropTable {
124                schema: current_table.schema.clone(),
125                name: current_table.name.clone(),
126            }));
127        } else {
128            debug_results.push(DebugResults::SkippedDropTable(current_table.name.clone()));
129        }
130    }
131
132    // Sort statements topologically based on foreign key dependencies
133    let sorted_statements = topologically_sort_statements(&statements, &desired_tables);
134
135    Ok(Migration {
136        statements: sorted_statements,
137        debug_results,
138    })
139}
140
141/// Topologically sorts the migration statements based on foreign key dependencies
142fn topologically_sort_statements(
143    statements: &[Statement],
144    tables: &HashMap<&String, &crate::schema::Table>,
145) -> Vec<Statement> {
146    // First, extract create table statements
147    let create_statements: Vec<_> = statements
148        .iter()
149        .filter(|s| matches!(s, Statement::CreateTable(_)))
150        .collect();
151
152    if create_statements.is_empty() {
153        // If there are no create statements, just return the original
154        return statements.to_vec();
155    }
156
157    // Build a map of table name to index in the statements array
158    let mut table_to_index = HashMap::new();
159    for (i, stmt) in create_statements.iter().enumerate() {
160        if let Statement::CreateTable(create) = stmt {
161            table_to_index.insert(create.name.clone(), i);
162        }
163    }
164
165    // Set up topological sort
166    let mut topo_sort = TopoSort::new();
167
168    // Find table dependencies and add them to topo_sort
169    for stmt in &create_statements {
170        if let Statement::CreateTable(create) = stmt {
171            let table_name = &create.name;
172            let mut dependencies = Vec::new();
173
174            // Get the actual table from the tables map
175            if let Some(table) = tables.values().find(|t| &t.name == table_name) {
176                // Check all columns for foreign key constraints
177                for column in &table.columns {
178                    if let Some(Constraint::ForeignKey(fk)) = &column.constraint {
179                        dependencies.push(fk.table.clone());
180                    }
181                }
182            }
183
184            // Add this table and its dependencies to the topo_sort
185            dbg!(table_name, &dependencies);
186            topo_sort.insert(table_name.clone(), dependencies);
187        }
188    }
189
190    // Perform the sort
191    let table_order = match topo_sort.into_vec_nodes() {
192        SortResults::Full(nodes) => nodes,
193        SortResults::Partial(nodes) => {
194            // Return partial results even if there's a cycle
195            nodes
196        }
197    };
198
199    // First create a sorted list of CREATE TABLE statements
200    let mut sorted_statements = Vec::new();
201    for table_name in &table_order {
202        if let Some(&idx) = table_to_index.get(table_name) {
203            sorted_statements.push(create_statements[idx].clone());
204        }
205    }
206
207    // Add remaining statements (non-create-table) in their original order
208    for stmt in statements {
209        if !matches!(stmt, Statement::CreateTable(_)) {
210            sorted_statements.push(stmt.clone());
211        }
212    }
213
214    sorted_statements
215}
216
217#[derive(Debug)]
218pub struct Migration {
219    pub statements: Vec<Statement>,
220    pub debug_results: Vec<DebugResults>,
221}
222
223impl Migration {
224    pub fn is_empty(&self) -> bool {
225        self.statements.is_empty()
226    }
227
228    pub fn set_schema(&mut self, schema_name: &str) {
229        for statement in &mut self.statements {
230            statement.set_schema(schema_name);
231        }
232    }
233}
234
235#[derive(Debug, Clone, PartialEq, Eq)]
236pub enum Statement {
237    CreateTable(Table),
238    CreateIndex(Index),
239    AlterTable(AlterTable),
240    DropTable(DropTable),
241    Update(Update),
242}
243
244impl Statement {
245    pub fn set_schema(&mut self, schema_name: &str) {
246        match self {
247            Statement::CreateTable(s) => {
248                s.schema = Some(schema_name.to_string());
249            }
250            Statement::AlterTable(s) => {
251                s.schema = Some(schema_name.to_string());
252            }
253            Statement::DropTable(s) => {
254                s.schema = Some(schema_name.to_string());
255            }
256            Statement::CreateIndex(s) => {
257                s.schema = Some(schema_name.to_string());
258            }
259            Statement::Update(s) => {
260                s.schema = Some(schema_name.to_string());
261            }
262        }
263    }
264
265    pub fn table_name(&self) -> &str {
266        match self {
267            Statement::CreateTable(s) => &s.name,
268            Statement::AlterTable(s) => &s.name,
269            Statement::DropTable(s) => &s.name,
270            Statement::CreateIndex(s) => &s.table,
271            Statement::Update(s) => &s.table,
272        }
273    }
274}
275
276impl ToSql for Statement {
277    fn write_sql(&self, buf: &mut String, dialect: Dialect) {
278        use Statement::*;
279        match self {
280            CreateTable(c) => c.write_sql(buf, dialect),
281            CreateIndex(c) => c.write_sql(buf, dialect),
282            AlterTable(a) => a.write_sql(buf, dialect),
283            DropTable(d) => d.write_sql(buf, dialect),
284            Update(u) => u.write_sql(buf, dialect),
285        }
286    }
287}
288
/// Informational notes gathered while computing a migration; they produce no SQL.
#[derive(Debug)]
pub enum DebugResults {
    /// A table present in both schemas that the migration considered unchanged.
    TablesIdentical(String),
    /// A table missing from the desired schema whose drop was skipped because destructive
    /// changes were not allowed.
    SkippedDropTable(String),
}

impl DebugResults {
    /// Name of the table this note refers to.
    pub fn table_name(&self) -> &str {
        match self {
            Self::TablesIdentical(table) | Self::SkippedDropTable(table) => table,
        }
    }
}
303
#[cfg(test)]
mod tests {
    use super::*;

    use crate::Table;
    use crate::Type;
    use crate::schema::{Column, Constraint, ForeignKey};

    /// Builds a non-nullable `I32` column with no default and no generated expression.
    fn int_column(name: &str, primary_key: bool, constraint: Option<Constraint>) -> Column {
        Column {
            name: name.to_string(),
            typ: Type::I32,
            nullable: false,
            primary_key,
            default: None,
            constraint,
            generated: None,
        }
    }

    /// Index of the `CreateTable` statement for `table` within `migration`, if present.
    fn create_position(migration: &Migration, table: &str) -> Option<usize> {
        migration
            .statements
            .iter()
            .position(|stmt| matches!(stmt, Statement::CreateTable(t) if t.name == table))
    }

    #[test]
    fn test_drop_table() {
        let table = Table::new("new_table");
        let mut current = Schema::default();
        current.tables.push(table.clone());
        let options = MigrationOptions {
            allow_destructive: true,
            ..MigrationOptions::default()
        };

        let mut migration = migrate(current, Schema::default(), &options).unwrap();

        let last = migration.statements.pop().unwrap();
        let expected = Statement::DropTable(DropTable {
            schema: table.schema,
            name: table.name,
        });
        assert_eq!(last, expected);
    }

    #[test]
    fn test_drop_table_without_destructive_operations() {
        let mut current = Schema::default();
        current.tables.push(Table::new("new_table"));

        let migration = migrate(current, Schema::default(), &MigrationOptions::default()).unwrap();

        assert!(migration.statements.is_empty());
    }

    #[test]
    fn test_topological_sort_statements() {
        // `user.team_id` references `team.id`, so `team` must be created first.
        let team = Table::new("team").column(int_column("id", true, None));
        let team_fk = Constraint::ForeignKey(ForeignKey {
            table: "team".to_string(),
            columns: vec!["id".to_string()],
        });
        let user = Table::new("user")
            .column(int_column("id", true, None))
            .column(int_column("team_id", false, Some(team_fk)));

        // Deliberately list the dependent table first.
        let mut desired = Schema::default();
        desired.tables.push(user);
        desired.tables.push(team);

        let migration = migrate(Schema::default(), desired, &MigrationOptions::default()).unwrap();

        let team_index = create_position(&migration, "team").unwrap();
        let user_index = create_position(&migration, "user").unwrap();
        assert!(
            team_index < user_index,
            "Team table should be created before User table"
        );
    }
}