oxide_sql_core/migrations/
codegen.rs1use super::column_builder::DefaultValue;
7use super::diff::SchemaDiff;
8use super::operation::{AlterColumnChange, CreateTableOp, Operation};
9use crate::ast::DataType;
10
11#[must_use]
25pub fn generate_migration_code(id: &str, diff: &SchemaDiff) -> String {
26 let struct_name = id_to_struct_name(id);
27 let up_body = render_operations(&diff.operations);
28 let down_body = render_down(&diff.operations);
29
30 format!(
31 "use oxide_sql_core::migrations::{{\n\
32 \x20 Migration, Operation, CreateTableBuilder,\n\
33 \x20 bigint, varchar, text, integer, smallint,\n\
34 \x20 boolean, timestamp, datetime, date, time,\n\
35 \x20 real, double, decimal, numeric, blob, binary,\n\
36 \x20 varbinary, char,\n\
37 }};\n\
38 \n\
39 pub struct {struct_name};\n\
40 \n\
41 impl Migration for {struct_name} {{\n\
42 \x20 const ID: &'static str = \"{id}\";\n\
43 \n\
44 \x20 fn up() -> Vec<Operation> {{\n\
45 \x20 vec![\n\
46 {up_body}\
47 \x20 ]\n\
48 \x20 }}\n\
49 \n\
50 \x20 fn down() -> Vec<Operation> {{\n\
51 \x20 vec![\n\
52 {down_body}\
53 \x20 ]\n\
54 \x20 }}\n\
55 }}\n"
56 )
57}
58
/// Turns a migration id such as `0001_create_users` into a struct name such as
/// `Migration0001CreateUsers`.
///
/// The result always starts with `Migration`, so ids beginning with a digit still
/// form a valid identifier. Every non-alphanumeric character (`_`, `-`, `.`,
/// whitespace, ...) is dropped and treated as a word break. The character after
/// a break, and the first character, is ASCII-uppercased. Other characters are
/// copied unchanged.
fn id_to_struct_name(id: &str) -> String {
    let mut result = String::from("Migration");
    let mut capitalize_next = true;
    for ch in id.chars() {
        if !ch.is_alphanumeric() {
            // Characters that cannot appear in an identifier only separate words.
            capitalize_next = true;
        } else if capitalize_next {
            result.push(ch.to_ascii_uppercase());
            capitalize_next = false;
        } else {
            result.push(ch);
        }
    }
    result
}
80
81fn render_operations(ops: &[Operation]) -> String {
83 let mut out = String::new();
84 for op in ops {
85 out.push_str(&format!(" {},\n", render_operation(op)));
86 }
87 out
88}
89
90fn render_down(ops: &[Operation]) -> String {
92 let mut out = String::new();
93 for op in ops.iter().rev() {
94 match op.reverse() {
95 Some(rev) => {
96 out.push_str(&format!(" {},\n", render_operation(&rev)));
97 }
98 None => {
99 out.push_str(&format!(
100 " // TODO: cannot auto-reverse: \
101 {:?}\n",
102 op_summary(op)
103 ));
104 }
105 }
106 }
107 out
108}
109
110fn op_summary(op: &Operation) -> String {
112 match op {
113 Operation::CreateTable(ct) => {
114 format!("CreateTable({})", ct.name)
115 }
116 Operation::DropTable(dt) => {
117 format!("DropTable({})", dt.name)
118 }
119 Operation::RenameTable(rt) => {
120 format!("RenameTable({} -> {})", rt.old_name, rt.new_name)
121 }
122 Operation::AddColumn(ac) => {
123 format!("AddColumn({}.{})", ac.table, ac.column.name)
124 }
125 Operation::DropColumn(dc) => {
126 format!("DropColumn({}.{})", dc.table, dc.column)
127 }
128 Operation::AlterColumn(ac) => {
129 format!("AlterColumn({}.{})", ac.table, ac.column)
130 }
131 Operation::RenameColumn(rc) => {
132 format!(
133 "RenameColumn({}.{} -> {})",
134 rc.table, rc.old_name, rc.new_name
135 )
136 }
137 Operation::CreateIndex(ci) => {
138 format!("CreateIndex({})", ci.name)
139 }
140 Operation::DropIndex(di) => {
141 format!("DropIndex({})", di.name)
142 }
143 Operation::AddForeignKey(fk) => {
144 format!("AddForeignKey({} -> {})", fk.table, fk.references_table)
145 }
146 Operation::DropForeignKey(fk) => {
147 format!("DropForeignKey({}.{})", fk.table, fk.name)
148 }
149 Operation::RunSql(_) => "RunSql(...)".to_string(),
150 }
151}
152
153fn render_operation(op: &Operation) -> String {
155 match op {
156 Operation::CreateTable(ct) => render_create_table(ct),
157 Operation::DropTable(dt) => {
158 format!("Operation::drop_table(\"{}\")", dt.name)
159 }
160 Operation::RenameTable(rt) => {
161 format!(
162 "Operation::rename_table(\"{}\", \"{}\")",
163 rt.old_name, rt.new_name
164 )
165 }
166 Operation::AddColumn(ac) => {
167 format!(
168 "Operation::add_column(\"{}\", {})",
169 ac.table,
170 render_column_builder(&ac.column.name, &ac.column)
171 )
172 }
173 Operation::DropColumn(dc) => {
174 format!(
175 "Operation::drop_column(\"{}\", \"{}\")",
176 dc.table, dc.column
177 )
178 }
179 Operation::RenameColumn(rc) => {
180 format!(
181 "Operation::rename_column(\"{}\", \"{}\", \"{}\")",
182 rc.table, rc.old_name, rc.new_name
183 )
184 }
185 Operation::AlterColumn(ac) => render_alter_column(ac),
186 Operation::CreateIndex(ci) => {
187 format!(
188 "Operation::CreateIndex(CreateIndexOp {{ \
189 name: \"{}\".into(), \
190 table: \"{}\".into(), \
191 columns: vec![{}], \
192 unique: {}, \
193 index_type: IndexType::BTree, \
194 if_not_exists: false, \
195 condition: None \
196 }})",
197 ci.name,
198 ci.table,
199 ci.columns
200 .iter()
201 .map(|c| format!("\"{c}\".into()"))
202 .collect::<Vec<_>>()
203 .join(", "),
204 ci.unique,
205 )
206 }
207 Operation::DropIndex(di) => {
208 format!(
209 "Operation::DropIndex(DropIndexOp {{ \
210 name: \"{}\".into(), table: None, \
211 if_exists: false }})",
212 di.name
213 )
214 }
215 Operation::AddForeignKey(_) | Operation::DropForeignKey(_) => {
216 format!("// TODO: manually write FK operation: {:?}", op_summary(op))
217 }
218 Operation::RunSql(rs) => {
219 if let Some(ref down) = rs.down_sql {
220 format!(
221 "Operation::run_sql_reversible(\"{}\", \"{}\")",
222 escape_str(&rs.up_sql),
223 escape_str(down)
224 )
225 } else {
226 format!("Operation::run_sql(\"{}\")", escape_str(&rs.up_sql))
227 }
228 }
229 }
230}
231
232fn render_create_table(ct: &CreateTableOp) -> String {
234 let mut s = String::from("CreateTableBuilder::new()\n");
235 s.push_str(&format!(" .name(\"{}\")\n", ct.name));
236 for col in &ct.columns {
237 s.push_str(&format!(
238 " .column({})\n",
239 render_column_builder(&col.name, col)
240 ));
241 }
242 if ct.if_not_exists {
243 s.push_str(" .if_not_exists()\n");
244 }
245 s.push_str(" .build()\n");
246 s.push_str(" .into()");
247 s
248}
249
250fn render_column_builder(_name: &str, col: &super::column_builder::ColumnDefinition) -> String {
252 let type_fn = match &col.data_type {
253 DataType::Bigint => {
254 format!("bigint(\"{}\")", col.name)
255 }
256 DataType::Integer => {
257 format!("integer(\"{}\")", col.name)
258 }
259 DataType::Smallint => {
260 format!("smallint(\"{}\")", col.name)
261 }
262 DataType::Text => {
263 format!("text(\"{}\")", col.name)
264 }
265 DataType::Varchar(Some(len)) => {
266 format!("varchar(\"{}\", {len})", col.name)
267 }
268 DataType::Varchar(None) => {
269 format!("text(\"{}\")", col.name)
270 }
271 DataType::Boolean => {
272 format!("boolean(\"{}\")", col.name)
273 }
274 DataType::Timestamp => {
275 format!("timestamp(\"{}\")", col.name)
276 }
277 DataType::Datetime => {
278 format!("datetime(\"{}\")", col.name)
279 }
280 DataType::Date => {
281 format!("date(\"{}\")", col.name)
282 }
283 DataType::Time => {
284 format!("time(\"{}\")", col.name)
285 }
286 DataType::Real => {
287 format!("real(\"{}\")", col.name)
288 }
289 DataType::Double => {
290 format!("double(\"{}\")", col.name)
291 }
292 DataType::Blob => {
293 format!("blob(\"{}\")", col.name)
294 }
295 DataType::Decimal {
296 precision: Some(p),
297 scale: Some(s),
298 } => {
299 format!("decimal(\"{}\", {p}, {s})", col.name)
300 }
301 DataType::Numeric {
302 precision: Some(p),
303 scale: Some(s),
304 } => {
305 format!("numeric(\"{}\", {p}, {s})", col.name)
306 }
307 DataType::Char(Some(len)) => {
308 format!("char(\"{}\", {len})", col.name)
309 }
310 _ => format!("text(\"{}\")", col.name),
311 };
312
313 let mut chain = type_fn;
314 if col.primary_key {
315 chain.push_str(".primary_key()");
316 }
317 if col.autoincrement {
318 chain.push_str(".autoincrement()");
319 }
320 if !col.nullable && !col.primary_key {
321 chain.push_str(".not_null()");
322 }
323 if col.unique {
324 chain.push_str(".unique()");
325 }
326 if let Some(ref default) = col.default {
327 match default {
328 DefaultValue::Boolean(b) => {
329 chain.push_str(&format!(".default_bool({b})"));
330 }
331 DefaultValue::Integer(i) => {
332 chain.push_str(&format!(".default_int({i})"));
333 }
334 DefaultValue::Float(f) => {
335 chain.push_str(&format!(".default_float({f})"));
336 }
337 DefaultValue::String(s) => {
338 chain.push_str(&format!(".default_str(\"{}\")", escape_str(s)));
339 }
340 DefaultValue::Null => {
341 chain.push_str(".default_null()");
342 }
343 DefaultValue::Expression(expr) => {
344 chain.push_str(&format!(".default_expr(\"{}\")", escape_str(expr)));
345 }
346 }
347 }
348 chain.push_str(".build()");
349 chain
350}
351
352fn render_alter_column(ac: &super::operation::AlterColumnOp) -> String {
354 let change = match &ac.change {
355 AlterColumnChange::SetDataType(dt) => {
356 format!("AlterColumnChange::SetDataType(DataType::{:?})", dt)
357 }
358 AlterColumnChange::SetNullable(n) => {
359 format!("AlterColumnChange::SetNullable({n})")
360 }
361 AlterColumnChange::SetDefault(d) => {
362 format!("AlterColumnChange::SetDefault({})", render_default_value(d))
363 }
364 AlterColumnChange::DropDefault => "AlterColumnChange::DropDefault".to_string(),
365 AlterColumnChange::SetUnique(u) => {
366 format!("AlterColumnChange::SetUnique({u})")
367 }
368 AlterColumnChange::SetAutoincrement(a) => {
369 format!("AlterColumnChange::SetAutoincrement({a})")
370 }
371 };
372 format!(
373 "Operation::AlterColumn(AlterColumnOp {{ \
374 table: \"{}\".into(), \
375 column: \"{}\".into(), \
376 change: {} }})",
377 ac.table, ac.column, change
378 )
379}
380
381fn render_default_value(dv: &DefaultValue) -> String {
383 match dv {
384 DefaultValue::Null => "DefaultValue::Null".to_string(),
385 DefaultValue::Boolean(b) => {
386 format!("DefaultValue::Boolean({b})")
387 }
388 DefaultValue::Integer(i) => {
389 format!("DefaultValue::Integer({i})")
390 }
391 DefaultValue::Float(f) => {
392 format!("DefaultValue::Float({f})")
393 }
394 DefaultValue::String(s) => {
395 format!("DefaultValue::String(\"{}\".into())", escape_str(s))
396 }
397 DefaultValue::Expression(e) => {
398 format!("DefaultValue::Expression(\"{}\".into())", escape_str(e))
399 }
400 }
401}
402
/// Escapes `s` so it can be placed between `"` quotes in generated Rust source.
///
/// Backslashes and double quotes are escaped, and carriage returns are written
/// as `\r`. rustc rejects a bare CR inside a string literal and normalizes CRLF
/// to LF, so a raw CR would either fail to compile or silently change the value.
/// Newlines and all other characters are copied through unchanged, so multi-line
/// SQL stays readable in the generated file.
fn escape_str(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\r' => escaped.push_str("\\r"),
            other => escaped.push(other),
        }
    }
    escaped
}
407
#[cfg(test)]
mod tests {
    use super::*;
    use crate::migrations::column_builder::{bigint, varchar};
    use crate::migrations::diff::SchemaDiff;
    use crate::migrations::operation::Operation;
    use crate::migrations::table_builder::CreateTableBuilder;

    /// Wraps `operations` in a `SchemaDiff` with no ambiguities or warnings.
    fn diff_of(operations: Vec<Operation>) -> SchemaDiff {
        SchemaDiff {
            operations,
            ambiguous: vec![],
            warnings: vec![],
        }
    }

    /// Asserts that `code` contains every entry of `needles`.
    fn assert_contains_all(code: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(
                code.contains(needle),
                "generated code is missing `{needle}`:\n{code}"
            );
        }
    }

    #[test]
    fn id_to_struct_name_works() {
        let cases = [
            ("0001_create_users", "Migration0001CreateUsers"),
            ("0002_add_email", "Migration0002AddEmail"),
        ];
        for (id, expected) in cases {
            assert_eq!(id_to_struct_name(id), expected);
        }
    }

    #[test]
    fn generate_simple_migration() {
        let add_email = Operation::add_column("users", varchar("email", 255).not_null().build());
        let code = generate_migration_code("0002_add_email", &diff_of(vec![add_email]));
        assert_contains_all(
            &code,
            &[
                "struct Migration0002AddEmail",
                "fn up()",
                "fn down()",
                "add_column",
                "varchar",
                "drop_column",
            ],
        );
    }

    #[test]
    fn generate_create_table_migration() {
        let create_users: Operation = CreateTableBuilder::new()
            .name("users")
            .column(bigint("id").primary_key().autoincrement().build())
            .column(varchar("name", 255).not_null().unique().build())
            .build()
            .into();

        let code = generate_migration_code("0001_create_users", &diff_of(vec![create_users]));
        assert_contains_all(
            &code,
            &[
                "CreateTableBuilder::new()",
                ".primary_key()",
                ".autoincrement()",
                ".unique()",
                "drop_table",
            ],
        );
    }
}