Skip to main content

oxide_sql_core/migrations/
codegen.rs

1//! Migration code generation.
2//!
3//! Generates Rust source code implementing the [`Migration`] trait
4//! from a [`SchemaDiff`], enabling `makemigrations`-style tooling.
5
6use super::column_builder::DefaultValue;
7use super::diff::SchemaDiff;
8use super::operation::{AlterColumnChange, CreateTableOp, Operation};
9use crate::ast::DataType;
10
11/// Generates a Rust source string containing a `Migration` impl
12/// for the given diff.
13///
14/// # Arguments
15///
16/// * `id` — The migration ID (e.g. `"0002_add_email"`).
17/// * `diff` — The schema diff to translate into code.
18///
19/// # Returns
20///
21/// A Rust source string that, when compiled, produces a struct
22/// implementing the `Migration` trait with `up()` and `down()`
23/// methods.
24#[must_use]
25pub fn generate_migration_code(id: &str, diff: &SchemaDiff) -> String {
26    let struct_name = id_to_struct_name(id);
27    let up_body = render_operations(&diff.operations);
28    let down_body = render_down(&diff.operations);
29
30    format!(
31        "use oxide_sql_core::migrations::{{\n\
32         \x20   Migration, Operation, CreateTableBuilder,\n\
33         \x20   bigint, varchar, text, integer, smallint,\n\
34         \x20   boolean, timestamp, datetime, date, time,\n\
35         \x20   real, double, decimal, numeric, blob, binary,\n\
36         \x20   varbinary, char,\n\
37         }};\n\
38         \n\
39         pub struct {struct_name};\n\
40         \n\
41         impl Migration for {struct_name} {{\n\
42         \x20   const ID: &'static str = \"{id}\";\n\
43         \n\
44         \x20   fn up() -> Vec<Operation> {{\n\
45         \x20       vec![\n\
46         {up_body}\
47         \x20       ]\n\
48         \x20   }}\n\
49         \n\
50         \x20   fn down() -> Vec<Operation> {{\n\
51         \x20       vec![\n\
52         {down_body}\
53         \x20       ]\n\
54         \x20   }}\n\
55         }}\n"
56    )
57}
58
59// ================================================================
60// Internal helpers
61// ================================================================
62
/// Converts a migration ID like "0002_add_email" into a struct
/// name like "Migration0002AddEmail".
///
/// Any non-alphanumeric character (`_`, `-`, `.`, space, …) is treated as a
/// word separator: it is dropped and the next character is upper-cased.
/// This keeps punctuation from leaking into the generated identifier, so an
/// ID such as "0002-add-email" also yields "Migration0002AddEmail".
fn id_to_struct_name(id: &str) -> String {
    const PREFIX: &str = "Migration";
    let mut result = String::with_capacity(PREFIX.len() + id.len());
    result.push_str(PREFIX);
    let mut capitalize_next = true;
    for ch in id.chars() {
        if !ch.is_alphanumeric() {
            capitalize_next = true;
        } else if capitalize_next {
            result.push(ch.to_ascii_uppercase());
            capitalize_next = false;
        } else {
            result.push(ch);
        }
    }
    result
}
80
81/// Renders a list of operations as Rust expressions.
82fn render_operations(ops: &[Operation]) -> String {
83    let mut out = String::new();
84    for op in ops {
85        out.push_str(&format!("            {},\n", render_operation(op)));
86    }
87    out
88}
89
90/// Renders the `down()` body from the up operations.
91fn render_down(ops: &[Operation]) -> String {
92    let mut out = String::new();
93    for op in ops.iter().rev() {
94        match op.reverse() {
95            Some(rev) => {
96                out.push_str(&format!("            {},\n", render_operation(&rev)));
97            }
98            None => {
99                out.push_str(&format!(
100                    "            // TODO: cannot auto-reverse: \
101                     {:?}\n",
102                    op_summary(op)
103                ));
104            }
105        }
106    }
107    out
108}
109
110/// Short human-readable summary for comments.
111fn op_summary(op: &Operation) -> String {
112    match op {
113        Operation::CreateTable(ct) => {
114            format!("CreateTable({})", ct.name)
115        }
116        Operation::DropTable(dt) => {
117            format!("DropTable({})", dt.name)
118        }
119        Operation::RenameTable(rt) => {
120            format!("RenameTable({} -> {})", rt.old_name, rt.new_name)
121        }
122        Operation::AddColumn(ac) => {
123            format!("AddColumn({}.{})", ac.table, ac.column.name)
124        }
125        Operation::DropColumn(dc) => {
126            format!("DropColumn({}.{})", dc.table, dc.column)
127        }
128        Operation::AlterColumn(ac) => {
129            format!("AlterColumn({}.{})", ac.table, ac.column)
130        }
131        Operation::RenameColumn(rc) => {
132            format!(
133                "RenameColumn({}.{} -> {})",
134                rc.table, rc.old_name, rc.new_name
135            )
136        }
137        Operation::CreateIndex(ci) => {
138            format!("CreateIndex({})", ci.name)
139        }
140        Operation::DropIndex(di) => {
141            format!("DropIndex({})", di.name)
142        }
143        Operation::AddForeignKey(fk) => {
144            format!("AddForeignKey({} -> {})", fk.table, fk.references_table)
145        }
146        Operation::DropForeignKey(fk) => {
147            format!("DropForeignKey({}.{})", fk.table, fk.name)
148        }
149        Operation::RunSql(_) => "RunSql(...)".to_string(),
150    }
151}
152
153/// Renders a single Operation as a Rust expression string.
154fn render_operation(op: &Operation) -> String {
155    match op {
156        Operation::CreateTable(ct) => render_create_table(ct),
157        Operation::DropTable(dt) => {
158            format!("Operation::drop_table(\"{}\")", dt.name)
159        }
160        Operation::RenameTable(rt) => {
161            format!(
162                "Operation::rename_table(\"{}\", \"{}\")",
163                rt.old_name, rt.new_name
164            )
165        }
166        Operation::AddColumn(ac) => {
167            format!(
168                "Operation::add_column(\"{}\", {})",
169                ac.table,
170                render_column_builder(&ac.column.name, &ac.column)
171            )
172        }
173        Operation::DropColumn(dc) => {
174            format!(
175                "Operation::drop_column(\"{}\", \"{}\")",
176                dc.table, dc.column
177            )
178        }
179        Operation::RenameColumn(rc) => {
180            format!(
181                "Operation::rename_column(\"{}\", \"{}\", \"{}\")",
182                rc.table, rc.old_name, rc.new_name
183            )
184        }
185        Operation::AlterColumn(ac) => render_alter_column(ac),
186        Operation::CreateIndex(ci) => {
187            format!(
188                "Operation::CreateIndex(CreateIndexOp {{ \
189                 name: \"{}\".into(), \
190                 table: \"{}\".into(), \
191                 columns: vec![{}], \
192                 unique: {}, \
193                 index_type: IndexType::BTree, \
194                 if_not_exists: false, \
195                 condition: None \
196                 }})",
197                ci.name,
198                ci.table,
199                ci.columns
200                    .iter()
201                    .map(|c| format!("\"{c}\".into()"))
202                    .collect::<Vec<_>>()
203                    .join(", "),
204                ci.unique,
205            )
206        }
207        Operation::DropIndex(di) => {
208            format!(
209                "Operation::DropIndex(DropIndexOp {{ \
210                 name: \"{}\".into(), table: None, \
211                 if_exists: false }})",
212                di.name
213            )
214        }
215        Operation::AddForeignKey(_) | Operation::DropForeignKey(_) => {
216            format!("// TODO: manually write FK operation: {:?}", op_summary(op))
217        }
218        Operation::RunSql(rs) => {
219            if let Some(ref down) = rs.down_sql {
220                format!(
221                    "Operation::run_sql_reversible(\"{}\", \"{}\")",
222                    escape_str(&rs.up_sql),
223                    escape_str(down)
224                )
225            } else {
226                format!("Operation::run_sql(\"{}\")", escape_str(&rs.up_sql))
227            }
228        }
229    }
230}
231
232/// Renders a `CreateTableBuilder` chain.
233fn render_create_table(ct: &CreateTableOp) -> String {
234    let mut s = String::from("CreateTableBuilder::new()\n");
235    s.push_str(&format!("                .name(\"{}\")\n", ct.name));
236    for col in &ct.columns {
237        s.push_str(&format!(
238            "                .column({})\n",
239            render_column_builder(&col.name, col)
240        ));
241    }
242    if ct.if_not_exists {
243        s.push_str("                .if_not_exists()\n");
244    }
245    s.push_str("                .build()\n");
246    s.push_str("                .into()");
247    s
248}
249
250/// Renders a column builder expression.
251fn render_column_builder(_name: &str, col: &super::column_builder::ColumnDefinition) -> String {
252    let type_fn = match &col.data_type {
253        DataType::Bigint => {
254            format!("bigint(\"{}\")", col.name)
255        }
256        DataType::Integer => {
257            format!("integer(\"{}\")", col.name)
258        }
259        DataType::Smallint => {
260            format!("smallint(\"{}\")", col.name)
261        }
262        DataType::Text => {
263            format!("text(\"{}\")", col.name)
264        }
265        DataType::Varchar(Some(len)) => {
266            format!("varchar(\"{}\", {len})", col.name)
267        }
268        DataType::Varchar(None) => {
269            format!("text(\"{}\")", col.name)
270        }
271        DataType::Boolean => {
272            format!("boolean(\"{}\")", col.name)
273        }
274        DataType::Timestamp => {
275            format!("timestamp(\"{}\")", col.name)
276        }
277        DataType::Datetime => {
278            format!("datetime(\"{}\")", col.name)
279        }
280        DataType::Date => {
281            format!("date(\"{}\")", col.name)
282        }
283        DataType::Time => {
284            format!("time(\"{}\")", col.name)
285        }
286        DataType::Real => {
287            format!("real(\"{}\")", col.name)
288        }
289        DataType::Double => {
290            format!("double(\"{}\")", col.name)
291        }
292        DataType::Blob => {
293            format!("blob(\"{}\")", col.name)
294        }
295        DataType::Decimal {
296            precision: Some(p),
297            scale: Some(s),
298        } => {
299            format!("decimal(\"{}\", {p}, {s})", col.name)
300        }
301        DataType::Numeric {
302            precision: Some(p),
303            scale: Some(s),
304        } => {
305            format!("numeric(\"{}\", {p}, {s})", col.name)
306        }
307        DataType::Char(Some(len)) => {
308            format!("char(\"{}\", {len})", col.name)
309        }
310        _ => format!("text(\"{}\")", col.name),
311    };
312
313    let mut chain = type_fn;
314    if col.primary_key {
315        chain.push_str(".primary_key()");
316    }
317    if col.autoincrement {
318        chain.push_str(".autoincrement()");
319    }
320    if !col.nullable && !col.primary_key {
321        chain.push_str(".not_null()");
322    }
323    if col.unique {
324        chain.push_str(".unique()");
325    }
326    if let Some(ref default) = col.default {
327        match default {
328            DefaultValue::Boolean(b) => {
329                chain.push_str(&format!(".default_bool({b})"));
330            }
331            DefaultValue::Integer(i) => {
332                chain.push_str(&format!(".default_int({i})"));
333            }
334            DefaultValue::Float(f) => {
335                chain.push_str(&format!(".default_float({f})"));
336            }
337            DefaultValue::String(s) => {
338                chain.push_str(&format!(".default_str(\"{}\")", escape_str(s)));
339            }
340            DefaultValue::Null => {
341                chain.push_str(".default_null()");
342            }
343            DefaultValue::Expression(expr) => {
344                chain.push_str(&format!(".default_expr(\"{}\")", escape_str(expr)));
345            }
346        }
347    }
348    chain.push_str(".build()");
349    chain
350}
351
352/// Renders an AlterColumn operation as Rust code.
353fn render_alter_column(ac: &super::operation::AlterColumnOp) -> String {
354    let change = match &ac.change {
355        AlterColumnChange::SetDataType(dt) => {
356            format!("AlterColumnChange::SetDataType(DataType::{:?})", dt)
357        }
358        AlterColumnChange::SetNullable(n) => {
359            format!("AlterColumnChange::SetNullable({n})")
360        }
361        AlterColumnChange::SetDefault(d) => {
362            format!("AlterColumnChange::SetDefault({})", render_default_value(d))
363        }
364        AlterColumnChange::DropDefault => "AlterColumnChange::DropDefault".to_string(),
365        AlterColumnChange::SetUnique(u) => {
366            format!("AlterColumnChange::SetUnique({u})")
367        }
368        AlterColumnChange::SetAutoincrement(a) => {
369            format!("AlterColumnChange::SetAutoincrement({a})")
370        }
371    };
372    format!(
373        "Operation::AlterColumn(AlterColumnOp {{ \
374         table: \"{}\".into(), \
375         column: \"{}\".into(), \
376         change: {} }})",
377        ac.table, ac.column, change
378    )
379}
380
381/// Renders a DefaultValue as Rust code.
382fn render_default_value(dv: &DefaultValue) -> String {
383    match dv {
384        DefaultValue::Null => "DefaultValue::Null".to_string(),
385        DefaultValue::Boolean(b) => {
386            format!("DefaultValue::Boolean({b})")
387        }
388        DefaultValue::Integer(i) => {
389            format!("DefaultValue::Integer({i})")
390        }
391        DefaultValue::Float(f) => {
392            format!("DefaultValue::Float({f})")
393        }
394        DefaultValue::String(s) => {
395            format!("DefaultValue::String(\"{}\".into())", escape_str(s))
396        }
397        DefaultValue::Expression(e) => {
398            format!("DefaultValue::Expression(\"{}\".into())", escape_str(e))
399        }
400    }
401}
402
/// Escapes a string for inclusion between double quotes in generated Rust
/// source.
///
/// Backslashes and double quotes are backslash-escaped. A carriage return
/// is written as `\r`, because a bare CR is rejected in a Rust string
/// literal and a CRLF pair would be normalized to LF, silently changing the
/// value. Other control characters become `\u{..}` escapes. Newlines and
/// tabs are kept literally so multi-line SQL stays readable; both are legal
/// inside a Rust string literal.
fn escape_str(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\\' => out.push_str("\\\\"),
            '"' => out.push_str("\\\""),
            '\r' => out.push_str("\\r"),
            '\n' | '\t' => out.push(ch),
            c if c.is_control() => out.push_str(&format!("\\u{{{:x}}}", c as u32)),
            c => out.push(c),
        }
    }
    out
}
407
#[cfg(test)]
mod tests {
    use super::*;
    use crate::migrations::column_builder::varchar;
    use crate::migrations::diff::SchemaDiff;
    use crate::migrations::operation::Operation;
    use crate::migrations::table_builder::CreateTableBuilder;

    /// Wraps `operations` in a diff that has no ambiguities or warnings.
    fn diff_with(operations: Vec<Operation>) -> SchemaDiff {
        SchemaDiff {
            operations,
            ambiguous: vec![],
            warnings: vec![],
        }
    }

    /// Asserts that every needle appears somewhere in the generated code.
    fn assert_contains_all(code: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                code.contains(needle),
                "generated code is missing `{needle}`:\n{code}"
            );
        }
    }

    #[test]
    fn id_to_struct_name_works() {
        let cases = [
            ("0001_create_users", "Migration0001CreateUsers"),
            ("0002_add_email", "Migration0002AddEmail"),
        ];
        for &(id, expected) in &cases {
            assert_eq!(id_to_struct_name(id), expected);
        }
    }

    #[test]
    fn generate_simple_migration() {
        let add_email =
            Operation::add_column("users", varchar("email", 255).not_null().build());
        let code = generate_migration_code("0002_add_email", &diff_with(vec![add_email]));
        assert_contains_all(
            &code,
            &[
                "struct Migration0002AddEmail",
                "fn up()",
                "fn down()",
                "add_column",
                "varchar",
                "drop_column",
            ],
        );
    }

    #[test]
    fn generate_create_table_migration() {
        let id_column = crate::migrations::column_builder::bigint("id")
            .primary_key()
            .autoincrement()
            .build();
        let name_column = varchar("name", 255).not_null().unique().build();
        let create_users: Operation = CreateTableBuilder::new()
            .name("users")
            .column(id_column)
            .column(name_column)
            .build()
            .into();

        let code = generate_migration_code("0001_create_users", &diff_with(vec![create_users]));
        // The reverse of a CREATE TABLE in down() is a drop_table call.
        assert_contains_all(
            &code,
            &[
                "CreateTableBuilder::new()",
                ".primary_key()",
                ".autoincrement()",
                ".unique()",
                "drop_table",
            ],
        );
    }
}