sql/
migrate.rs

1use std::collections::HashMap;
2
3use crate::query::{AlterTable, Update};
4use crate::{AlterAction, Constraint, Dialect, DropTable, Index, Schema, Table, ToSql};
5use anyhow::Result;
6use topo_sort::{SortResults, TopoSort};
7
/// Knobs controlling how [`migrate`] builds a migration.
#[derive(Clone, Debug, Default)]
pub struct MigrationOptions {
    /// Debug flag. `migrate` in this module does not read it
    /// (NOTE(review): confirm which callers consume it).
    pub debug: bool,
    /// When `true`, tables absent from the desired schema are dropped; when `false`
    /// (the default) the drop is skipped and reported as `DebugResults::SkippedDropTable`.
    pub allow_destructive: bool,
}
13
14pub fn migrate(current: Schema, desired: Schema, options: &MigrationOptions) -> Result<Migration> {
15    let current_tables = current
16        .tables
17        .iter()
18        .map(|t| (&t.name, t))
19        .collect::<HashMap<_, _>>();
20    let desired_tables = desired
21        .tables
22        .iter()
23        .map(|t| (&t.name, t))
24        .collect::<HashMap<_, _>>();
25
26    let mut debug_results = vec![];
27    let mut statements = Vec::new();
28    // new tables
29    for (_name, &table) in desired_tables
30        .iter()
31        .filter(|(name, _)| !current_tables.contains_key(*name))
32    {
33        let statement = Statement::CreateTable(table.clone());
34        statements.push(statement);
35    }
36
37    // alter existing tables
38    for (name, desired_table) in desired_tables
39        .iter()
40        .filter(|(name, _)| current_tables.contains_key(*name))
41    {
42        let current_table = current_tables[name];
43        let current_columns = current_table
44            .columns
45            .iter()
46            .map(|c| (&c.name, c))
47            .collect::<HashMap<_, _>>();
48        // add columns
49        let mut actions = vec![];
50        for desired_column in desired_table.columns.iter() {
51            if let Some(current) = current_columns.get(&desired_column.name) {
52                if current.nullable != desired_column.nullable {
53                    actions.push(AlterAction::set_nullable(
54                        desired_column.name.clone(),
55                        desired_column.nullable,
56                    ));
57                }
58                if !desired_column.typ.lossy_eq(&current.typ) {
59                    actions.push(AlterAction::set_type(
60                        desired_column.name.clone(),
61                        desired_column.typ.clone(),
62                    ));
63                };
64                if desired_column.constraint.is_some() && current.constraint.is_none() {
65                    if let Some(c) = &desired_column.constraint {
66                        let name = desired_column.name.clone();
67                        actions.push(AlterAction::add_constraint(
68                            &desired_table.name,
69                            name,
70                            c.clone(),
71                        ));
72                    }
73                }
74            } else {
75                // add the column can be in 1 step if the column is nullable
76                if desired_column.nullable {
77                    actions.push(AlterAction::AddColumn {
78                        column: desired_column.clone(),
79                    });
80                } else {
81                    let mut nullable = desired_column.clone();
82                    nullable.nullable = true;
83                    statements.push(Statement::AlterTable(AlterTable {
84                        schema: desired_table.schema.clone(),
85                        name: desired_table.name.clone(),
86                        actions: vec![AlterAction::AddColumn { column: nullable }],
87                    }));
88                    statements.push(Statement::Update(
89                        Update::new(name)
90                            .set(
91                                &desired_column.name,
92                                "/* TODO set a value before setting the column to null */",
93                            )
94                            .where_(crate::query::Where::raw("true")),
95                    ));
96                    statements.push(Statement::AlterTable(AlterTable {
97                        schema: desired_table.schema.clone(),
98                        name: desired_table.name.clone(),
99                        actions: vec![AlterAction::AlterColumn {
100                            name: desired_column.name.clone(),
101                            action: crate::query::AlterColumnAction::SetNullable(false),
102                        }],
103                    }));
104                }
105            }
106        }
107        if actions.is_empty() {
108            debug_results.push(DebugResults::TablesIdentical(name.to_string()));
109        } else {
110            statements.push(Statement::AlterTable(AlterTable {
111                schema: desired_table.schema.clone(),
112                name: desired_table.name.clone(),
113                actions,
114            }));
115        }
116    }
117
118    for (_name, current_table) in current_tables
119        .iter()
120        .filter(|(name, _)| !desired_tables.contains_key(*name))
121    {
122        if options.allow_destructive {
123            statements.push(Statement::DropTable(DropTable {
124                schema: current_table.schema.clone(),
125                name: current_table.name.clone(),
126            }));
127        } else {
128            debug_results.push(DebugResults::SkippedDropTable(current_table.name.clone()));
129        }
130    }
131
132    // Sort statements topologically based on foreign key dependencies
133    let sorted_statements = topologically_sort_statements(&statements, &desired_tables);
134
135    Ok(Migration {
136        statements: sorted_statements,
137        debug_results,
138    })
139}
140
141/// Topologically sorts the migration statements based on foreign key dependencies
142fn topologically_sort_statements(
143    statements: &[Statement],
144    tables: &HashMap<&String, &crate::schema::Table>,
145) -> Vec<Statement> {
146    // First, extract create table statements
147    let create_statements: Vec<_> = statements
148        .iter()
149        .filter(|s| matches!(s, Statement::CreateTable(_)))
150        .collect();
151
152    if create_statements.is_empty() {
153        // If there are no create statements, just return the original
154        return statements.to_vec();
155    }
156
157    // Build a map of table name to index in the statements array
158    let mut table_to_index = HashMap::new();
159    for (i, stmt) in create_statements.iter().enumerate() {
160        if let Statement::CreateTable(create) = stmt {
161            table_to_index.insert(create.name.clone(), i);
162        }
163    }
164
165    // Set up topological sort
166    let mut topo_sort = TopoSort::new();
167
168    // Find table dependencies and add them to topo_sort
169    for stmt in &create_statements {
170        if let Statement::CreateTable(create) = stmt {
171            let table_name = &create.name;
172            let mut dependencies = Vec::new();
173
174            // Get the actual table from the tables map
175            if let Some(table) = tables.values().find(|t| &t.name == table_name) {
176                // Check all columns for foreign key constraints
177                for column in &table.columns {
178                    if let Some(Constraint::ForeignKey(fk)) = &column.constraint {
179                        dependencies.push(fk.table.clone());
180                    }
181                }
182            }
183
184            // Add this table and its dependencies to the topo_sort
185            topo_sort.insert(table_name.clone(), dependencies);
186        }
187    }
188
189    // Perform the sort
190    let table_order = match topo_sort.into_vec_nodes() {
191        SortResults::Full(nodes) => nodes,
192        SortResults::Partial(nodes) => {
193            // Return partial results even if there's a cycle
194            nodes
195        }
196    };
197
198    // First create a sorted list of CREATE TABLE statements
199    let mut sorted_statements = Vec::new();
200    for table_name in &table_order {
201        if let Some(&idx) = table_to_index.get(table_name) {
202            sorted_statements.push(create_statements[idx].clone());
203        }
204    }
205
206    // Add remaining statements (non-create-table) in their original order
207    for stmt in statements {
208        if !matches!(stmt, Statement::CreateTable(_)) {
209            sorted_statements.push(stmt.clone());
210        }
211    }
212
213    sorted_statements
214}
215
216#[derive(Debug)]
217pub struct Migration {
218    pub statements: Vec<Statement>,
219    pub debug_results: Vec<DebugResults>,
220}
221
222impl Migration {
223    pub fn is_empty(&self) -> bool {
224        self.statements.is_empty()
225    }
226
227    pub fn set_schema(&mut self, schema_name: &str) {
228        for statement in &mut self.statements {
229            statement.set_schema(schema_name);
230        }
231    }
232}
233
234#[derive(Debug, Clone, PartialEq, Eq)]
235pub enum Statement {
236    CreateTable(Table),
237    CreateIndex(Index),
238    AlterTable(AlterTable),
239    DropTable(DropTable),
240    Update(Update),
241}
242
243impl Statement {
244    pub fn set_schema(&mut self, schema_name: &str) {
245        match self {
246            Statement::CreateTable(s) => {
247                s.schema = Some(schema_name.to_string());
248            }
249            Statement::AlterTable(s) => {
250                s.schema = Some(schema_name.to_string());
251            }
252            Statement::DropTable(s) => {
253                s.schema = Some(schema_name.to_string());
254            }
255            Statement::CreateIndex(s) => {
256                s.schema = Some(schema_name.to_string());
257            }
258            Statement::Update(s) => {
259                s.schema = Some(schema_name.to_string());
260            }
261        }
262    }
263
264    pub fn table_name(&self) -> &str {
265        match self {
266            Statement::CreateTable(s) => &s.name,
267            Statement::AlterTable(s) => &s.name,
268            Statement::DropTable(s) => &s.name,
269            Statement::CreateIndex(s) => &s.table,
270            Statement::Update(s) => &s.table,
271        }
272    }
273}
274
275impl ToSql for Statement {
276    fn write_sql(&self, buf: &mut String, dialect: Dialect) {
277        use Statement::*;
278        match self {
279            CreateTable(c) => c.write_sql(buf, dialect),
280            CreateIndex(c) => c.write_sql(buf, dialect),
281            AlterTable(a) => a.write_sql(buf, dialect),
282            DropTable(d) => d.write_sql(buf, dialect),
283            Update(u) => u.write_sql(buf, dialect),
284        }
285    }
286}
287
/// Informational notes that [`migrate`] reports alongside the generated statements.
///
/// Derives `Clone`, `PartialEq` and `Eq`, matching [`Statement`], so results can be copied
/// and compared directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DebugResults {
    /// The named table exists in both schemas and was considered unchanged.
    TablesIdentical(String),
    /// The named table is absent from the desired schema, but dropping it was skipped
    /// because destructive operations were not allowed.
    SkippedDropTable(String),
}
293
294impl DebugResults {
295    pub fn table_name(&self) -> &str {
296        match self {
297            DebugResults::TablesIdentical(name) => name,
298            DebugResults::SkippedDropTable(name) => name,
299        }
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    use crate::schema::{Column, Constraint, ForeignKey};
    use crate::Table;
    use crate::Type;

    /// Builds a non-nullable `I32` column with no default or generated value.
    fn i32_column(name: &str, primary_key: bool, constraint: Option<Constraint>) -> Column {
        Column {
            name: name.to_string(),
            typ: Type::I32,
            nullable: false,
            primary_key,
            default: None,
            constraint,
            generated: None,
        }
    }

    /// Index of the `CreateTable` statement for `table` in `migration`; panics if absent.
    fn create_position(migration: &Migration, table: &str) -> usize {
        migration
            .statements
            .iter()
            .position(|statement| {
                matches!(statement, Statement::CreateTable(created) if created.name == table)
            })
            .unwrap()
    }

    #[test]
    fn test_drop_table() {
        let table = Table::new("new_table");
        let mut current = Schema::default();
        current.tables.push(table.clone());
        let options = MigrationOptions {
            allow_destructive: true,
            ..MigrationOptions::default()
        };

        let mut migration = migrate(current, Schema::default(), &options).unwrap();

        let expected = Statement::DropTable(DropTable {
            schema: table.schema,
            name: table.name,
        });
        assert_eq!(migration.statements.pop().unwrap(), expected);
    }

    #[test]
    fn test_drop_table_without_destructive_operations() {
        let mut current = Schema::default();
        current.tables.push(Table::new("new_table"));

        let migration =
            migrate(current, Schema::default(), &MigrationOptions::default()).unwrap();

        assert!(migration.statements.is_empty());
    }

    #[test]
    fn test_topological_sort_statements() {
        // `user.team_id` references `team.id`, so `team` must be created first even though
        // `user` is declared first.
        let team = Table::new("team").column(i32_column("id", true, None));
        let team_fk = Constraint::ForeignKey(ForeignKey {
            table: "team".to_string(),
            columns: vec!["id".to_string()],
        });
        let user = Table::new("user")
            .column(i32_column("id", true, None))
            .column(i32_column("team_id", false, Some(team_fk)));

        let mut desired = Schema::default();
        desired.tables.push(user);
        desired.tables.push(team);

        let migration =
            migrate(Schema::default(), desired, &MigrationOptions::default()).unwrap();

        assert!(
            create_position(&migration, "team") < create_position(&migration, "user"),
            "Team table should be created before User table"
        );
    }
}
425}