1use drizzle_migrations::parser::{ParseResult, ParsedField, ParsedIndex};
7use drizzle_migrations::postgres::PostgresSnapshot;
8use drizzle_migrations::schema::Snapshot;
9use drizzle_migrations::sqlite::SQLiteSnapshot;
10use drizzle_types::Dialect;
11use heck::ToSnakeCase;
12use std::borrow::Cow;
13
14pub fn parse_result_to_snapshot(result: &ParseResult, dialect: Dialect) -> Snapshot {
19 match dialect {
20 Dialect::SQLite => Snapshot::Sqlite(build_sqlite_snapshot(result)),
21 Dialect::PostgreSQL => Snapshot::Postgres(build_postgres_snapshot(result)),
22 _ => unreachable!("Unsupported dialect for drizzle-cli snapshot generation: {dialect:?}"),
23 }
24}
25
26fn build_sqlite_snapshot(result: &ParseResult) -> SQLiteSnapshot {
28 use drizzle_migrations::sqlite::{PrimaryKey, SqliteEntity, Table, UniqueConstraint};
29
30 let mut snapshot = SQLiteSnapshot::new();
31
32 for table in result
34 .tables
35 .values()
36 .filter(|t| t.dialect == Dialect::SQLite)
37 {
38 let table_name = table.name.to_snake_case();
39
40 snapshot.add_entity(SqliteEntity::Table(Table::new(table_name.clone())));
42
43 let mut pk_columns = Vec::new();
45
46 for field in &table.fields {
47 let col = build_sqlite_column(&table_name, field);
48 snapshot.add_entity(SqliteEntity::Column(col));
49
50 if field.is_primary_key() {
52 pk_columns.push(field.name.to_snake_case());
53 }
54
55 if field.is_unique() && !field.is_primary_key() {
57 let col_name = field.name.to_snake_case();
58 let constraint_name = format!("{}_{}_unique", table_name, col_name);
59 snapshot.add_entity(SqliteEntity::UniqueConstraint(
60 UniqueConstraint::from_strings(
61 table_name.clone(),
62 constraint_name,
63 vec![col_name],
64 ),
65 ));
66 }
67
68 if let Some(ref_target) = field.references()
70 && let Some(fk) = build_sqlite_foreign_key(&table_name, field, &ref_target)
71 {
72 snapshot.add_entity(SqliteEntity::ForeignKey(fk));
73 }
74 }
75
76 if !pk_columns.is_empty() {
78 let pk_name = format!("{}_pk", table_name);
79 snapshot.add_entity(SqliteEntity::PrimaryKey(PrimaryKey::from_strings(
80 table_name, pk_name, pk_columns,
81 )));
82 }
83 }
84
85 for index in result
87 .indexes
88 .values()
89 .filter(|i| i.dialect == Dialect::SQLite)
90 {
91 let idx = build_sqlite_index(index);
92 snapshot.add_entity(SqliteEntity::Index(idx));
93 }
94
95 snapshot
96}
97
98fn build_postgres_snapshot(result: &ParseResult) -> PostgresSnapshot {
100 use drizzle_migrations::postgres::{
101 PostgresEntity, PrimaryKey, Schema as PgSchema, Table, UniqueConstraint,
102 };
103
104 let mut snapshot = PostgresSnapshot::new();
105
106 snapshot.add_entity(PostgresEntity::Schema(PgSchema::new("public")));
108
109 for table in result
111 .tables
112 .values()
113 .filter(|t| t.dialect == Dialect::PostgreSQL)
114 {
115 let table_name = table.name.to_snake_case();
116
117 snapshot.add_entity(PostgresEntity::Table(Table {
119 schema: "public".into(),
120 name: table_name.clone().into(),
121 is_rls_enabled: None,
122 }));
123
124 let mut pk_columns = Vec::new();
126
127 for field in &table.fields {
128 let col = build_postgres_column(&table_name, field);
129 snapshot.add_entity(PostgresEntity::Column(col));
130
131 if field.is_primary_key() {
133 pk_columns.push(field.name.to_snake_case());
134 }
135
136 if field.is_unique() && !field.is_primary_key() {
138 let col_name = field.name.to_snake_case();
139 snapshot.add_entity(PostgresEntity::UniqueConstraint(
140 UniqueConstraint::from_strings(
141 "public".to_string(),
142 table_name.clone(),
143 format!("{}_{}_key", table_name, col_name),
144 vec![col_name],
145 ),
146 ));
147 }
148
149 if let Some(ref_target) = field.references()
151 && let Some(fk) = build_postgres_foreign_key(&table_name, field, &ref_target)
152 {
153 snapshot.add_entity(PostgresEntity::ForeignKey(fk));
154 }
155 }
156
157 if !pk_columns.is_empty() {
159 snapshot.add_entity(PostgresEntity::PrimaryKey(PrimaryKey::from_strings(
160 "public".to_string(),
161 table_name.clone(),
162 format!("{}_pkey", table_name),
163 pk_columns,
164 )));
165 }
166 }
167
168 for index in result
170 .indexes
171 .values()
172 .filter(|i| i.dialect == Dialect::PostgreSQL)
173 {
174 let idx = build_postgres_index(index);
175 snapshot.add_entity(PostgresEntity::Index(idx));
176 }
177
178 snapshot
179}
180
181fn build_sqlite_column(
183 table_name: &str,
184 field: &ParsedField,
185) -> drizzle_migrations::sqlite::Column {
186 use drizzle_migrations::sqlite::Column;
187
188 let col_name = field.name.to_snake_case();
189 let col_type = infer_sqlite_type(&field.ty);
190
191 let mut col = Column::new(table_name.to_string(), col_name, col_type);
192
193 if !field.is_nullable() {
194 col = col.not_null();
195 }
196
197 if field.is_autoincrement() {
198 col = col.autoincrement();
199 }
200
201 if let Some(default) = field.default_value() {
202 col = col.default_value(default);
203 }
204
205 col
206}
207
208fn build_postgres_column(
210 table_name: &str,
211 field: &ParsedField,
212) -> drizzle_migrations::postgres::Column {
213 use drizzle_migrations::postgres::ddl::IdentityType;
214 use drizzle_migrations::postgres::{Column, Identity};
215
216 let col_name = field.name.to_snake_case();
217 let col_type = infer_postgres_type(&field.ty);
218 let is_serial = field.has_attr("serial") || field.has_attr("bigserial");
219 let is_identity = field.has_attr("generated") || field.has_attr("identity");
220
221 Column {
222 schema: "public".into(),
223 table: table_name.to_string().into(),
224 name: col_name.clone().into(),
225 sql_type: col_type.into(),
226 type_schema: None,
227 not_null: !field.is_nullable(),
228 default: field.default_value().map(Cow::Owned),
229 generated: None,
230 identity: if is_serial || is_identity {
231 Some(Identity {
232 name: format!("{}_{}_seq", table_name, col_name).into(),
233 schema: Some("public".into()),
234 type_: if is_identity {
235 IdentityType::Always
236 } else {
237 IdentityType::ByDefault
238 },
239 increment: None,
240 min_value: None,
241 max_value: None,
242 start_with: None,
243 cache: None,
244 cycle: None,
245 })
246 } else {
247 None
248 },
249 dimensions: None,
250 ordinal_position: None,
251 }
252}
253
254fn build_sqlite_foreign_key(
256 table_name: &str,
257 field: &ParsedField,
258 ref_target: &str,
259) -> Option<drizzle_migrations::sqlite::ForeignKey> {
260 use drizzle_migrations::sqlite::ForeignKey;
261
262 let parts: Vec<&str> = ref_target.split("::").collect();
264 if parts.len() != 2 {
265 return None;
266 }
267
268 let ref_table = parts[0].to_snake_case();
269 let ref_column = parts[1].to_snake_case();
270 let col_name = field.name.to_snake_case();
271 let fk_name = format!(
272 "{}_{}_{}_{}_fk",
273 table_name, col_name, ref_table, ref_column
274 );
275
276 let mut fk = ForeignKey::from_strings(
277 table_name.to_string(),
278 fk_name,
279 vec![col_name],
280 ref_table,
281 vec![ref_column],
282 );
283
284 fk.on_delete = field.on_delete().map(Cow::Owned);
285 fk.on_update = field.on_update().map(Cow::Owned);
286
287 Some(fk)
288}
289
290fn build_postgres_foreign_key(
292 table_name: &str,
293 field: &ParsedField,
294 ref_target: &str,
295) -> Option<drizzle_migrations::postgres::ForeignKey> {
296 use drizzle_migrations::postgres::ForeignKey;
297
298 let parts: Vec<&str> = ref_target.split("::").collect();
300 if parts.len() != 2 {
301 return None;
302 }
303
304 let ref_table = parts[0].to_snake_case();
305 let ref_column = parts[1].to_snake_case();
306 let col_name = field.name.to_snake_case();
307 let fk_name = format!(
308 "{}_{}_{}_{}_fk",
309 table_name, col_name, ref_table, ref_column
310 );
311
312 Some(ForeignKey {
313 schema: "public".into(),
314 table: table_name.to_string().into(),
315 name: fk_name.into(),
316 name_explicit: false,
317 columns: Cow::Owned(vec![Cow::Owned(col_name)]),
318 schema_to: "public".into(),
319 table_to: ref_table.into(),
320 columns_to: Cow::Owned(vec![Cow::Owned(ref_column)]),
321 on_update: field.on_update().map(Cow::Owned),
322 on_delete: field.on_delete().map(Cow::Owned),
323 })
324}
325
326fn build_sqlite_index(index: &ParsedIndex) -> drizzle_migrations::sqlite::Index {
328 use drizzle_migrations::sqlite::{Index, IndexColumn, IndexOrigin};
329
330 let table_name = index
331 .table_name()
332 .map(str::to_snake_case)
333 .unwrap_or_default();
334 let index_name = index.name.to_snake_case();
335
336 let columns: Vec<IndexColumn> = index
337 .columns
338 .iter()
339 .filter_map(|c| {
340 c.split("::")
342 .last()
343 .map(|s| IndexColumn::new(s.to_snake_case()))
344 })
345 .collect();
346
347 Index {
348 table: table_name.into(),
349 name: index_name.into(),
350 columns,
351 is_unique: index.is_unique(),
352 where_clause: None,
353 origin: IndexOrigin::Manual,
354 }
355}
356
357fn build_postgres_index(index: &ParsedIndex) -> drizzle_migrations::postgres::Index {
359 use drizzle_migrations::postgres::{Index, IndexColumn};
360
361 let table_name = index
362 .table_name()
363 .map(str::to_snake_case)
364 .unwrap_or_default();
365 let index_name = index.name.to_snake_case();
366
367 let columns: Vec<IndexColumn> = index
368 .columns
369 .iter()
370 .filter_map(|c| {
371 c.split("::")
372 .last()
373 .map(|s| IndexColumn::new(s.to_snake_case()))
374 })
375 .collect();
376
377 Index {
378 schema: "public".into(),
379 table: table_name.into(),
380 name: index_name.into(),
381 name_explicit: false,
382 columns,
383 is_unique: index.is_unique(),
384 where_clause: None,
385 method: None,
386 with: None,
387 concurrently: false,
388 }
389}
390
/// Maps a Rust type name to the SQLite column type used in snapshots.
///
/// Surrounding whitespace is ignored and a single outer `Option<...>` wrapper is
/// unwrapped before matching. Integer types and `bool` map to `integer`, floats
/// to `real`, string types to `text` and byte buffers to `blob`. Any type whose
/// name contains `Uuid`, `DateTime` or `NaiveDate` is stored as `text`;
/// everything else falls back to `any`.
fn infer_sqlite_type(rust_type: &str) -> String {
    // Type-name fragments that are stored as text.
    const TEXT_MARKERS: [&str; 3] = ["Uuid", "DateTime", "NaiveDate"];

    let trimmed = rust_type.trim();
    // Unwrap `Option<T>` only when both the prefix and the closing `>` are present.
    let base = trimmed
        .strip_prefix("Option<")
        .and_then(|rest| rest.strip_suffix('>'))
        .map_or(trimmed, str::trim);

    let sql_type = match base {
        "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "isize" | "usize"
        | "bool" => "integer",
        "f32" | "f64" => "real",
        "String" | "&str" | "str" => "text",
        "Vec<u8>" | "[u8]" => "blob",
        _ if TEXT_MARKERS.iter().any(|&marker| base.contains(marker)) => "text",
        _ => "any",
    };

    sql_type.to_string()
}
412
/// Maps a Rust type name to the PostgreSQL column type used in snapshots.
///
/// Surrounding whitespace is ignored and a single outer `Option<...>` wrapper is
/// unwrapped before matching. Exact primitive names are matched first; other
/// types are recognised by a substring of their name (`Uuid`, `NaiveDateTime`,
/// `DateTime`, `NaiveDate`, `NaiveTime`, `IpAddr`, `MacAddr`, `Point`,
/// `Decimal`). Anything unrecognised falls back to `text`.
fn infer_postgres_type(rust_type: &str) -> String {
    let base_type = rust_type
        .trim()
        .strip_prefix("Option<")
        .and_then(|s| s.strip_suffix('>'))
        .unwrap_or(rust_type)
        .trim();

    // NOTE(review): `i8`, `isize` and `usize` are not listed and therefore fall
    // through to `text`; confirm whether that is intended.
    let sql_type = match base_type {
        "i16" => "smallint",
        "i32" => "integer",
        "i64" => "bigint",
        "u8" | "u16" | "u32" => "integer",
        "u64" => "bigint",
        "f32" => "real",
        "f64" => "double precision",
        "bool" => "boolean",
        "String" | "&str" | "str" => "text",
        "Vec<u8>" | "[u8]" => "bytea",
        _ if base_type.contains("Uuid") => "uuid",
        // "NaiveDateTime" contains "DateTime", so it must be tested before the
        // generic `DateTime` guard; otherwise naive timestamps would be
        // reported as `timestamptz` and this arm would never be reached.
        _ if base_type.contains("NaiveDateTime") => "timestamp",
        _ if base_type.contains("DateTime") => "timestamptz",
        _ if base_type.contains("NaiveDate") => "date",
        _ if base_type.contains("NaiveTime") => "time",
        _ if base_type.contains("IpAddr") => "inet",
        _ if base_type.contains("MacAddr") => "macaddr",
        _ if base_type.contains("Point") => "point",
        _ if base_type.contains("Decimal") => "numeric",
        _ => "text",
    };

    sql_type.to_string()
}
445
// Unit tests for the type-inference helpers, plus end-to-end checks that a
// nullability change in a parsed SQLite schema produces migration SQL.
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks the Rust-to-SQLite type mapping, including `Option<T>` unwrapping.
    #[test]
    fn test_infer_sqlite_type() {
        assert_eq!(infer_sqlite_type("i32"), "integer");
        assert_eq!(infer_sqlite_type("i64"), "integer");
        assert_eq!(infer_sqlite_type("f64"), "real");
        assert_eq!(infer_sqlite_type("String"), "text");
        assert_eq!(infer_sqlite_type("Option<String>"), "text");
        assert_eq!(infer_sqlite_type("Vec<u8>"), "blob");
    }

    /// Spot-checks the Rust-to-PostgreSQL type mapping.
    #[test]
    fn test_infer_postgres_type() {
        assert_eq!(infer_postgres_type("i32"), "integer");
        assert_eq!(infer_postgres_type("i64"), "bigint");
        assert_eq!(infer_postgres_type("bool"), "boolean");
        assert_eq!(infer_postgres_type("String"), "text");
        assert_eq!(infer_postgres_type("Vec<u8>"), "bytea");
        assert_eq!(infer_postgres_type("Uuid"), "uuid");
    }

    /// Changing `email` from `Option<String>` to `String` must flip the column's
    /// `not_null` flag in the snapshot and yield migration SQL that recreates the
    /// table through a temporary `__new_user` table.
    #[test]
    fn test_nullable_to_not_null_generates_migration() {
        use drizzle_migrations::parser::SchemaParser;
        use drizzle_migrations::sqlite::collection::SQLiteDDL;
        use drizzle_migrations::sqlite::diff::compute_migration;

        // Previous schema: `email` is optional.
        let prev_code = r#"
#[SQLiteTable]
pub struct User {
    #[column(primary)]
    pub id: i64,
    pub name: String,
    pub email: Option<String>,
}
"#;

        // Current schema: `email` is required.
        let cur_code = r#"
#[SQLiteTable]
pub struct User {
    #[column(primary)]
    pub id: i64,
    pub name: String,
    pub email: String,
}
"#;

        let prev_result = SchemaParser::parse(prev_code);
        let cur_result = SchemaParser::parse(cur_code);

        let prev_snapshot = parse_result_to_snapshot(&prev_result, Dialect::SQLite);
        let cur_snapshot = parse_result_to_snapshot(&cur_result, Dialect::SQLite);

        // Turn both snapshots into DDL collections that the diff engine accepts.
        let (prev_ddl, cur_ddl) = match (&prev_snapshot, &cur_snapshot) {
            (Snapshot::Sqlite(p), Snapshot::Sqlite(c)) => (
                SQLiteDDL::from_entities(p.ddl.clone()),
                SQLiteDDL::from_entities(c.ddl.clone()),
            ),
            _ => panic!("Expected SQLite snapshots"),
        };

        // Sanity-check the snapshots before diffing: only the current column is NOT NULL.
        let prev_email = prev_ddl
            .columns
            .one("user", "email")
            .expect("email column in prev");
        let cur_email = cur_ddl
            .columns
            .one("user", "email")
            .expect("email column in cur");
        assert!(!prev_email.not_null, "Previous email should be nullable");
        assert!(cur_email.not_null, "Current email should be NOT NULL");

        let migration = compute_migration(&prev_ddl, &cur_ddl);

        assert!(
            !migration.sql_statements.is_empty(),
            "Should generate migration SQL for nullable change"
        );

        // The statements are joined so each expected fragment can be searched for once.
        let combined = migration.sql_statements.join("\n");
        assert!(
            combined.contains("PRAGMA foreign_keys=OFF"),
            "Should contain PRAGMA foreign_keys=OFF for table recreation"
        );
        assert!(
            combined.contains("__new_user"),
            "Should create temporary table __new_user"
        );
        assert!(
            combined.contains("NOT NULL"),
            "New table should have NOT NULL on email column"
        );
        assert!(combined.contains("DROP TABLE"), "Should drop old table");
        assert!(
            combined.contains("RENAME TO"),
            "Should rename temp table to original"
        );
    }

    /// The reverse change (`String` to `Option<String>`) must also yield migration
    /// SQL that goes through the temporary `__new_user` table.
    #[test]
    fn test_not_null_to_nullable_generates_migration() {
        use drizzle_migrations::parser::SchemaParser;
        use drizzle_migrations::sqlite::collection::SQLiteDDL;
        use drizzle_migrations::sqlite::diff::compute_migration;

        // Previous schema: `email` is required.
        let prev_code = r#"
#[SQLiteTable]
pub struct User {
    #[column(primary)]
    pub id: i64,
    pub email: String,
}
"#;

        // Current schema: `email` is optional.
        let cur_code = r#"
#[SQLiteTable]
pub struct User {
    #[column(primary)]
    pub id: i64,
    pub email: Option<String>,
}
"#;

        let prev_result = SchemaParser::parse(prev_code);
        let cur_result = SchemaParser::parse(cur_code);

        let prev_snapshot = parse_result_to_snapshot(&prev_result, Dialect::SQLite);
        let cur_snapshot = parse_result_to_snapshot(&cur_result, Dialect::SQLite);

        // Turn both snapshots into DDL collections that the diff engine accepts.
        let (prev_ddl, cur_ddl) = match (&prev_snapshot, &cur_snapshot) {
            (Snapshot::Sqlite(p), Snapshot::Sqlite(c)) => (
                SQLiteDDL::from_entities(p.ddl.clone()),
                SQLiteDDL::from_entities(c.ddl.clone()),
            ),
            _ => panic!("Expected SQLite snapshots"),
        };

        let migration = compute_migration(&prev_ddl, &cur_ddl);

        assert!(
            !migration.sql_statements.is_empty(),
            "Should generate migration SQL for nullable change"
        );

        let combined = migration.sql_statements.join("\n");
        assert!(
            combined.contains("__new_user"),
            "Should create temporary table for recreation"
        );
    }
}