1use std::collections::HashMap;
2
3use crate::query::{AlterTable, Update};
4use crate::{AlterAction, Constraint, Dialect, DropTable, Index, Schema, Table, ToSql};
5use anyhow::Result;
6use topo_sort::{SortResults, TopoSort};
7
/// Settings that control how [`migrate`] turns one schema into another.
///
/// `Default` yields every flag set to `false`.
#[derive(Clone, Debug, Default)]
pub struct MigrationOptions {
    /// Debug switch carried with the options; `migrate` itself does not consult it.
    pub debug: bool,
    /// When `true`, tables absent from the desired schema are dropped; when `false`
    /// the drop is skipped and reported as a debug result instead.
    pub allow_destructive: bool,
}
13
14pub fn migrate(current: Schema, desired: Schema, options: &MigrationOptions) -> Result<Migration> {
15 let current_tables = current
16 .tables
17 .iter()
18 .map(|t| (&t.name, t))
19 .collect::<HashMap<_, _>>();
20 let desired_tables = desired
21 .tables
22 .iter()
23 .map(|t| (&t.name, t))
24 .collect::<HashMap<_, _>>();
25
26 let mut debug_results = vec![];
27 let mut statements = Vec::new();
28 for (_name, &table) in desired_tables
30 .iter()
31 .filter(|(name, _)| !current_tables.contains_key(*name))
32 {
33 let statement = Statement::CreateTable(table.clone());
34 statements.push(statement);
35 }
36
37 for (name, desired_table) in desired_tables
39 .iter()
40 .filter(|(name, _)| current_tables.contains_key(*name))
41 {
42 let current_table = current_tables[name];
43 let current_columns = current_table
44 .columns
45 .iter()
46 .map(|c| (&c.name, c))
47 .collect::<HashMap<_, _>>();
48 let mut actions = vec![];
50 for desired_column in desired_table.columns.iter() {
51 if let Some(current) = current_columns.get(&desired_column.name) {
52 if current.nullable != desired_column.nullable {
53 actions.push(AlterAction::set_nullable(
54 desired_column.name.clone(),
55 desired_column.nullable,
56 ));
57 }
58 if !desired_column.typ.lossy_eq(¤t.typ) {
59 actions.push(AlterAction::set_type(
60 desired_column.name.clone(),
61 desired_column.typ.clone(),
62 ));
63 };
64 if desired_column.constraint.is_some() && current.constraint.is_none() {
65 if let Some(c) = &desired_column.constraint {
66 let name = desired_column.name.clone();
67 actions.push(AlterAction::add_constraint(
68 &desired_table.name,
69 name,
70 c.clone(),
71 ));
72 }
73 }
74 } else {
75 if desired_column.nullable {
77 actions.push(AlterAction::AddColumn {
78 column: desired_column.clone(),
79 });
80 } else {
81 let mut nullable = desired_column.clone();
82 nullable.nullable = true;
83 statements.push(Statement::AlterTable(AlterTable {
84 schema: desired_table.schema.clone(),
85 name: desired_table.name.clone(),
86 actions: vec![AlterAction::AddColumn { column: nullable }],
87 }));
88 statements.push(Statement::Update(
89 Update::new(name)
90 .set(
91 &desired_column.name,
92 "/* TODO set a value before setting the column to null */",
93 )
94 .where_(crate::query::Where::raw("true")),
95 ));
96 statements.push(Statement::AlterTable(AlterTable {
97 schema: desired_table.schema.clone(),
98 name: desired_table.name.clone(),
99 actions: vec![AlterAction::AlterColumn {
100 name: desired_column.name.clone(),
101 action: crate::query::AlterColumnAction::SetNullable(false),
102 }],
103 }));
104 }
105 }
106 }
107 if actions.is_empty() {
108 debug_results.push(DebugResults::TablesIdentical(name.to_string()));
109 } else {
110 statements.push(Statement::AlterTable(AlterTable {
111 schema: desired_table.schema.clone(),
112 name: desired_table.name.clone(),
113 actions,
114 }));
115 }
116 }
117
118 for (_name, current_table) in current_tables
119 .iter()
120 .filter(|(name, _)| !desired_tables.contains_key(*name))
121 {
122 if options.allow_destructive {
123 statements.push(Statement::DropTable(DropTable {
124 schema: current_table.schema.clone(),
125 name: current_table.name.clone(),
126 }));
127 } else {
128 debug_results.push(DebugResults::SkippedDropTable(current_table.name.clone()));
129 }
130 }
131
132 let sorted_statements = topologically_sort_statements(&statements, &desired_tables);
134
135 Ok(Migration {
136 statements: sorted_statements,
137 debug_results,
138 })
139}
140
141fn topologically_sort_statements(
143 statements: &[Statement],
144 tables: &HashMap<&String, &crate::schema::Table>,
145) -> Vec<Statement> {
146 let create_statements: Vec<_> = statements
148 .iter()
149 .filter(|s| matches!(s, Statement::CreateTable(_)))
150 .collect();
151
152 if create_statements.is_empty() {
153 return statements.to_vec();
155 }
156
157 let mut table_to_index = HashMap::new();
159 for (i, stmt) in create_statements.iter().enumerate() {
160 if let Statement::CreateTable(create) = stmt {
161 table_to_index.insert(create.name.clone(), i);
162 }
163 }
164
165 let mut topo_sort = TopoSort::new();
167
168 for stmt in &create_statements {
170 if let Statement::CreateTable(create) = stmt {
171 let table_name = &create.name;
172 let mut dependencies = Vec::new();
173
174 if let Some(table) = tables.values().find(|t| &t.name == table_name) {
176 for column in &table.columns {
178 if let Some(Constraint::ForeignKey(fk)) = &column.constraint {
179 dependencies.push(fk.table.clone());
180 }
181 }
182 }
183
184 dbg!(table_name, &dependencies);
186 topo_sort.insert(table_name.clone(), dependencies);
187 }
188 }
189
190 let table_order = match topo_sort.into_vec_nodes() {
192 SortResults::Full(nodes) => nodes,
193 SortResults::Partial(nodes) => {
194 nodes
196 }
197 };
198
199 let mut sorted_statements = Vec::new();
201 for table_name in &table_order {
202 if let Some(&idx) = table_to_index.get(table_name) {
203 sorted_statements.push(create_statements[idx].clone());
204 }
205 }
206
207 for stmt in statements {
209 if !matches!(stmt, Statement::CreateTable(_)) {
210 sorted_statements.push(stmt.clone());
211 }
212 }
213
214 sorted_statements
215}
216
217#[derive(Debug)]
218pub struct Migration {
219 pub statements: Vec<Statement>,
220 pub debug_results: Vec<DebugResults>,
221}
222
223impl Migration {
224 pub fn is_empty(&self) -> bool {
225 self.statements.is_empty()
226 }
227
228 pub fn set_schema(&mut self, schema_name: &str) {
229 for statement in &mut self.statements {
230 statement.set_schema(schema_name);
231 }
232 }
233}
234
235#[derive(Debug, Clone, PartialEq, Eq)]
236pub enum Statement {
237 CreateTable(Table),
238 CreateIndex(Index),
239 AlterTable(AlterTable),
240 DropTable(DropTable),
241 Update(Update),
242}
243
244impl Statement {
245 pub fn set_schema(&mut self, schema_name: &str) {
246 match self {
247 Statement::CreateTable(s) => {
248 s.schema = Some(schema_name.to_string());
249 }
250 Statement::AlterTable(s) => {
251 s.schema = Some(schema_name.to_string());
252 }
253 Statement::DropTable(s) => {
254 s.schema = Some(schema_name.to_string());
255 }
256 Statement::CreateIndex(s) => {
257 s.schema = Some(schema_name.to_string());
258 }
259 Statement::Update(s) => {
260 s.schema = Some(schema_name.to_string());
261 }
262 }
263 }
264
265 pub fn table_name(&self) -> &str {
266 match self {
267 Statement::CreateTable(s) => &s.name,
268 Statement::AlterTable(s) => &s.name,
269 Statement::DropTable(s) => &s.name,
270 Statement::CreateIndex(s) => &s.table,
271 Statement::Update(s) => &s.table,
272 }
273 }
274}
275
276impl ToSql for Statement {
277 fn write_sql(&self, buf: &mut String, dialect: Dialect) {
278 use Statement::*;
279 match self {
280 CreateTable(c) => c.write_sql(buf, dialect),
281 CreateIndex(c) => c.write_sql(buf, dialect),
282 AlterTable(a) => a.write_sql(buf, dialect),
283 DropTable(d) => d.write_sql(buf, dialect),
284 Update(u) => u.write_sql(buf, dialect),
285 }
286 }
287}
288
/// Informational notes produced by [`migrate`], each naming a table.
#[derive(Debug)]
pub enum DebugResults {
    /// A table present in both schemas that the diff reported as unchanged.
    TablesIdentical(String),
    /// A table absent from the desired schema whose drop was skipped because
    /// destructive operations were not allowed.
    SkippedDropTable(String),
}

impl DebugResults {
    /// Returns the table name carried by this note.
    pub fn table_name(&self) -> &str {
        match self {
            Self::TablesIdentical(table) | Self::SkippedDropTable(table) => table,
        }
    }
}
303
#[cfg(test)]
mod tests {
    use super::*;

    use crate::schema::{Column, Constraint, ForeignKey};
    use crate::Table;
    use crate::Type;

    /// A non-nullable `I32` column with no default and no generation expression.
    fn i32_column(name: &str, primary_key: bool, constraint: Option<Constraint>) -> Column {
        Column {
            name: name.to_string(),
            typ: Type::I32,
            nullable: false,
            primary_key,
            default: None,
            constraint,
            generated: None,
        }
    }

    /// Index of the `CREATE TABLE` statement for `table`; panics if there is none.
    fn create_position(statements: &[Statement], table: &str) -> usize {
        statements
            .iter()
            .position(|s| matches!(s, Statement::CreateTable(t) if t.name == table))
            .unwrap_or_else(|| panic!("missing CREATE TABLE for {}", table))
    }

    #[test]
    fn test_drop_table() {
        let table = Table::new("new_table");
        let mut current = Schema::default();
        current.tables.push(table.clone());
        let options = MigrationOptions {
            allow_destructive: true,
            ..Default::default()
        };

        let migration = migrate(current, Schema::default(), &options).unwrap();

        let expected = Statement::DropTable(DropTable {
            schema: table.schema,
            name: table.name,
        });
        assert_eq!(migration.statements.last(), Some(&expected));
    }

    #[test]
    fn test_drop_table_without_destructive_operations() {
        let mut current = Schema::default();
        current.tables.push(Table::new("new_table"));

        let migration = migrate(current, Schema::default(), &MigrationOptions::default()).unwrap();

        assert!(migration.statements.is_empty());
    }

    #[test]
    fn test_topological_sort_statements() {
        let team_fk = Constraint::ForeignKey(ForeignKey {
            table: "team".to_string(),
            columns: vec!["id".to_string()],
        });
        let team_table = Table::new("team").column(i32_column("id", true, None));
        let user_table = Table::new("user")
            .column(i32_column("id", true, None))
            .column(i32_column("team_id", false, Some(team_fk)));

        // The dependent table is listed first in the desired schema.
        let mut desired = Schema::default();
        desired.tables.push(user_table);
        desired.tables.push(team_table);

        let migration =
            migrate(Schema::default(), desired, &MigrationOptions::default()).unwrap();

        let team_index = create_position(&migration.statements, "team");
        let user_index = create_position(&migration.statements, "user");
        assert!(
            team_index < user_index,
            "Team table should be created before User table"
        );
    }
}