1use std::collections::HashMap;
2
3use crate::query::{AlterTable, Update};
4use crate::{AlterAction, Constraint, Dialect, DropTable, Index, Schema, Table, ToSql};
5use anyhow::Result;
6use topo_sort::{SortResults, TopoSort};
7
/// Settings that adjust how [`migrate`] builds a [`Migration`].
#[derive(Clone, Debug, Default)]
pub struct MigrationOptions {
    /// Debug flag. `migrate` in this module does not read it.
    pub debug: bool,
    /// When `true`, `migrate` emits a `DROP TABLE` for every table missing
    /// from the desired schema. When `false` (the default), it records
    /// [`DebugResults::SkippedDropTable`] for each such table instead.
    pub allow_destructive: bool,
}
13
14pub fn migrate(current: Schema, desired: Schema, options: &MigrationOptions) -> Result<Migration> {
15 let current_tables = current
16 .tables
17 .iter()
18 .map(|t| (&t.name, t))
19 .collect::<HashMap<_, _>>();
20 let desired_tables = desired
21 .tables
22 .iter()
23 .map(|t| (&t.name, t))
24 .collect::<HashMap<_, _>>();
25
26 let mut debug_results = vec![];
27 let mut statements = Vec::new();
28 for (_name, &table) in desired_tables
30 .iter()
31 .filter(|(name, _)| !current_tables.contains_key(*name))
32 {
33 let statement = Statement::CreateTable(table.clone());
34 statements.push(statement);
35 }
36
37 for (name, desired_table) in desired_tables
39 .iter()
40 .filter(|(name, _)| current_tables.contains_key(*name))
41 {
42 let current_table = current_tables[name];
43 let current_columns = current_table
44 .columns
45 .iter()
46 .map(|c| (&c.name, c))
47 .collect::<HashMap<_, _>>();
48 let mut actions = vec![];
50 for desired_column in desired_table.columns.iter() {
51 if let Some(current) = current_columns.get(&desired_column.name) {
52 if current.nullable != desired_column.nullable {
53 actions.push(AlterAction::set_nullable(
54 desired_column.name.clone(),
55 desired_column.nullable,
56 ));
57 }
58 if !desired_column.typ.lossy_eq(¤t.typ) {
59 actions.push(AlterAction::set_type(
60 desired_column.name.clone(),
61 desired_column.typ.clone(),
62 ));
63 };
64 if desired_column.constraint.is_some() && current.constraint.is_none() {
65 if let Some(c) = &desired_column.constraint {
66 let name = desired_column.name.clone();
67 actions.push(AlterAction::add_constraint(
68 &desired_table.name,
69 name,
70 c.clone(),
71 ));
72 }
73 }
74 } else {
75 if desired_column.nullable {
77 actions.push(AlterAction::AddColumn {
78 column: desired_column.clone(),
79 });
80 } else {
81 let mut nullable = desired_column.clone();
82 nullable.nullable = true;
83 statements.push(Statement::AlterTable(AlterTable {
84 schema: desired_table.schema.clone(),
85 name: desired_table.name.clone(),
86 actions: vec![AlterAction::AddColumn { column: nullable }],
87 }));
88 statements.push(Statement::Update(
89 Update::new(name)
90 .set(
91 &desired_column.name,
92 "/* TODO set a value before setting the column to null */",
93 )
94 .where_(crate::query::Where::raw("true")),
95 ));
96 statements.push(Statement::AlterTable(AlterTable {
97 schema: desired_table.schema.clone(),
98 name: desired_table.name.clone(),
99 actions: vec![AlterAction::AlterColumn {
100 name: desired_column.name.clone(),
101 action: crate::query::AlterColumnAction::SetNullable(false),
102 }],
103 }));
104 }
105 }
106 }
107 if actions.is_empty() {
108 debug_results.push(DebugResults::TablesIdentical(name.to_string()));
109 } else {
110 statements.push(Statement::AlterTable(AlterTable {
111 schema: desired_table.schema.clone(),
112 name: desired_table.name.clone(),
113 actions,
114 }));
115 }
116 }
117
118 for (_name, current_table) in current_tables
119 .iter()
120 .filter(|(name, _)| !desired_tables.contains_key(*name))
121 {
122 if options.allow_destructive {
123 statements.push(Statement::DropTable(DropTable {
124 schema: current_table.schema.clone(),
125 name: current_table.name.clone(),
126 }));
127 } else {
128 debug_results.push(DebugResults::SkippedDropTable(current_table.name.clone()));
129 }
130 }
131
132 let sorted_statements = topologically_sort_statements(&statements, &desired_tables);
134
135 Ok(Migration {
136 statements: sorted_statements,
137 debug_results,
138 })
139}
140
141fn topologically_sort_statements(
143 statements: &[Statement],
144 tables: &HashMap<&String, &crate::schema::Table>,
145) -> Vec<Statement> {
146 let create_statements: Vec<_> = statements
148 .iter()
149 .filter(|s| matches!(s, Statement::CreateTable(_)))
150 .collect();
151
152 if create_statements.is_empty() {
153 return statements.to_vec();
155 }
156
157 let mut table_to_index = HashMap::new();
159 for (i, stmt) in create_statements.iter().enumerate() {
160 if let Statement::CreateTable(create) = stmt {
161 table_to_index.insert(create.name.clone(), i);
162 }
163 }
164
165 let mut topo_sort = TopoSort::new();
167
168 for stmt in &create_statements {
170 if let Statement::CreateTable(create) = stmt {
171 let table_name = &create.name;
172 let mut dependencies = Vec::new();
173
174 if let Some(table) = tables.values().find(|t| &t.name == table_name) {
176 for column in &table.columns {
178 if let Some(Constraint::ForeignKey(fk)) = &column.constraint {
179 dependencies.push(fk.table.clone());
180 }
181 }
182 }
183
184 topo_sort.insert(table_name.clone(), dependencies);
186 }
187 }
188
189 let table_order = match topo_sort.into_vec_nodes() {
191 SortResults::Full(nodes) => nodes,
192 SortResults::Partial(nodes) => {
193 nodes
195 }
196 };
197
198 let mut sorted_statements = Vec::new();
200 for table_name in &table_order {
201 if let Some(&idx) = table_to_index.get(table_name) {
202 sorted_statements.push(create_statements[idx].clone());
203 }
204 }
205
206 for stmt in statements {
208 if !matches!(stmt, Statement::CreateTable(_)) {
209 sorted_statements.push(stmt.clone());
210 }
211 }
212
213 sorted_statements
214}
215
216#[derive(Debug)]
217pub struct Migration {
218 pub statements: Vec<Statement>,
219 pub debug_results: Vec<DebugResults>,
220}
221
222impl Migration {
223 pub fn is_empty(&self) -> bool {
224 self.statements.is_empty()
225 }
226
227 pub fn set_schema(&mut self, schema_name: &str) {
228 for statement in &mut self.statements {
229 statement.set_schema(schema_name);
230 }
231 }
232}
233
234#[derive(Debug, Clone, PartialEq, Eq)]
235pub enum Statement {
236 CreateTable(Table),
237 CreateIndex(Index),
238 AlterTable(AlterTable),
239 DropTable(DropTable),
240 Update(Update),
241}
242
243impl Statement {
244 pub fn set_schema(&mut self, schema_name: &str) {
245 match self {
246 Statement::CreateTable(s) => {
247 s.schema = Some(schema_name.to_string());
248 }
249 Statement::AlterTable(s) => {
250 s.schema = Some(schema_name.to_string());
251 }
252 Statement::DropTable(s) => {
253 s.schema = Some(schema_name.to_string());
254 }
255 Statement::CreateIndex(s) => {
256 s.schema = Some(schema_name.to_string());
257 }
258 Statement::Update(s) => {
259 s.schema = Some(schema_name.to_string());
260 }
261 }
262 }
263
264 pub fn table_name(&self) -> &str {
265 match self {
266 Statement::CreateTable(s) => &s.name,
267 Statement::AlterTable(s) => &s.name,
268 Statement::DropTable(s) => &s.name,
269 Statement::CreateIndex(s) => &s.table,
270 Statement::Update(s) => &s.table,
271 }
272 }
273}
274
275impl ToSql for Statement {
276 fn write_sql(&self, buf: &mut String, dialect: Dialect) {
277 use Statement::*;
278 match self {
279 CreateTable(c) => c.write_sql(buf, dialect),
280 CreateIndex(c) => c.write_sql(buf, dialect),
281 AlterTable(a) => a.write_sql(buf, dialect),
282 DropTable(d) => d.write_sql(buf, dialect),
283 Update(u) => u.write_sql(buf, dialect),
284 }
285 }
286}
287
/// Per-table diagnostic recorded by [`migrate`].
#[derive(Debug)]
pub enum DebugResults {
    /// The table exists in both schemas, and its column diff produced no
    /// `ALTER TABLE` actions.
    TablesIdentical(String),
    /// The table is absent from the desired schema, but destructive
    /// operations were not allowed, so no `DROP TABLE` was emitted.
    SkippedDropTable(String),
}

impl DebugResults {
    /// Name of the table this diagnostic is about.
    pub fn table_name(&self) -> &str {
        match self {
            Self::TablesIdentical(table) | Self::SkippedDropTable(table) => table.as_str(),
        }
    }
}
302
#[cfg(test)]
mod tests {
    use super::*;

    use crate::Table;
    use crate::Type;
    use crate::schema::{Column, Constraint, ForeignKey};

    /// Builds a non-nullable `I32` column with no default or generated value.
    fn int_column(name: &str, primary_key: bool, constraint: Option<Constraint>) -> Column {
        Column {
            name: name.to_string(),
            typ: Type::I32,
            nullable: false,
            primary_key,
            default: None,
            constraint,
            generated: None,
        }
    }

    /// Index of the `CREATE TABLE` statement for `table`, if the migration has one.
    fn create_position(migration: &Migration, table: &str) -> Option<usize> {
        migration
            .statements
            .iter()
            .position(|s| matches!(s, Statement::CreateTable(create) if create.name == table))
    }

    #[test]
    fn test_drop_table() {
        let empty_schema = Schema::default();
        let mut single_table_schema = Schema::default();
        let t = Table::new("new_table");
        single_table_schema.tables.push(t.clone());
        let allow_destructive_options = MigrationOptions {
            allow_destructive: true,
            ..MigrationOptions::default()
        };

        let migration = migrate(
            single_table_schema,
            empty_schema,
            &allow_destructive_options,
        )
        .unwrap();

        // The drop must be the only statement, not merely the last one.
        assert_eq!(
            migration.statements,
            vec![Statement::DropTable(DropTable {
                schema: t.schema,
                name: t.name,
            })]
        );
    }

    #[test]
    fn test_drop_table_without_destructive_operations() {
        let empty_schema = Schema::default();
        let mut single_table_schema = Schema::default();
        let t = Table::new("new_table");
        single_table_schema.tables.push(t.clone());
        let options = MigrationOptions::default();

        let migration = migrate(single_table_schema, empty_schema, &options).unwrap();

        assert!(migration.statements.is_empty());
        // The skipped drop must still be reported to the caller.
        assert!(migration.debug_results.iter().any(
            |r| matches!(r, DebugResults::SkippedDropTable(name) if name == "new_table")
        ));
    }

    #[test]
    fn test_topological_sort_statements() {
        let empty_schema = Schema::default();
        let mut schema_with_tables = Schema::default();

        let team_table = Table::new("team").column(int_column("id", true, None));

        let user_table = Table::new("user")
            .column(int_column("id", true, None))
            .column(int_column(
                "team_id",
                false,
                Some(Constraint::ForeignKey(ForeignKey {
                    table: "team".to_string(),
                    columns: vec!["id".to_string()],
                })),
            ));

        // Declare the dependent table first so declaration order alone would be wrong.
        schema_with_tables.tables.push(user_table);
        schema_with_tables.tables.push(team_table);

        let options = MigrationOptions::default();

        let migration = migrate(empty_schema, schema_with_tables, &options).unwrap();

        let team_index = create_position(&migration, "team").expect("team table is created");
        let user_index = create_position(&migration, "user").expect("user table is created");

        assert!(
            team_index < user_index,
            "Team table should be created before User table"
        );
    }
}