drizzle_cli/snapshot.rs

//! Schema snapshot builder from parsed schema files
//!
//! This module converts `ParseResult` from the schema parser into
//! `Snapshot` types that can be used for migration diffing.

use drizzle_migrations::parser::{ParseResult, ParsedField, ParsedIndex};
use drizzle_migrations::postgres::PostgresSnapshot;
use drizzle_migrations::schema::Snapshot;
use drizzle_migrations::sqlite::SQLiteSnapshot;
use drizzle_types::Dialect;
use heck::ToSnakeCase;
use std::borrow::Cow;

/// Convert a `ParseResult` into a `Snapshot` for migration diffing
///
/// Uses the provided `dialect` from config rather than the parser-detected dialect,
/// allowing users to have multi-dialect schema files and select which to use via config.
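///
/// # Example
///
/// A minimal usage sketch (marked `ignore`); the schema source below is illustrative,
/// not taken from a real project.
///
/// ```ignore
/// use drizzle_migrations::parser::SchemaParser;
/// use drizzle_types::Dialect;
///
/// // Illustrative schema source; any table matching the chosen dialect works.
/// let source = r#"
///     #[SQLiteTable]
///     pub struct User {
///         #[column(primary)]
///         pub id: i64,
///         pub name: String,
///     }
/// "#;
///
/// let result = SchemaParser::parse(source);
/// let snapshot = parse_result_to_snapshot(&result, Dialect::SQLite);
/// ```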
pub fn parse_result_to_snapshot(result: &ParseResult, dialect: Dialect) -> Snapshot {
    match dialect {
        Dialect::SQLite => Snapshot::Sqlite(build_sqlite_snapshot(result)),
        Dialect::PostgreSQL => Snapshot::Postgres(build_postgres_snapshot(result)),
        _ => unreachable!("Unsupported dialect for drizzle-cli snapshot generation: {dialect:?}"),
    }
}

/// Build an SQLite snapshot from parsed schema
fn build_sqlite_snapshot(result: &ParseResult) -> SQLiteSnapshot {
    use drizzle_migrations::sqlite::{PrimaryKey, SqliteEntity, Table, UniqueConstraint};

    let mut snapshot = SQLiteSnapshot::new();

    // Process tables (only those matching SQLite dialect)
    for table in result
        .tables
        .values()
        .filter(|t| t.dialect == Dialect::SQLite)
    {
        let table_name = table.name.to_snake_case();

        // Add table entity
        snapshot.add_entity(SqliteEntity::Table(Table::new(table_name.clone())));

        // Process columns
        let mut pk_columns = Vec::new();

        for field in &table.fields {
            let col = build_sqlite_column(&table_name, field);
            snapshot.add_entity(SqliteEntity::Column(col));

            // Track primary key columns
            if field.is_primary_key() {
                pk_columns.push(field.name.to_snake_case());
            }

            // Add unique constraint if column is unique (not primary)
            if field.is_unique() && !field.is_primary_key() {
                let col_name = field.name.to_snake_case();
                let constraint_name = format!("{}_{}_unique", table_name, col_name);
                snapshot.add_entity(SqliteEntity::UniqueConstraint(
                    UniqueConstraint::from_strings(
                        table_name.clone(),
                        constraint_name,
                        vec![col_name],
                    ),
                ));
            }

            // Add foreign key if references exist
            if let Some(ref_target) = field.references()
                && let Some(fk) = build_sqlite_foreign_key(&table_name, field, &ref_target)
            {
                snapshot.add_entity(SqliteEntity::ForeignKey(fk));
            }
        }

        // Add primary key entity
        if !pk_columns.is_empty() {
            let pk_name = format!("{}_pk", table_name);
            snapshot.add_entity(SqliteEntity::PrimaryKey(PrimaryKey::from_strings(
                table_name, pk_name, pk_columns,
            )));
        }
    }

    // Process indexes (only those matching SQLite dialect)
    for index in result
        .indexes
        .values()
        .filter(|i| i.dialect == Dialect::SQLite)
    {
        let idx = build_sqlite_index(index);
        snapshot.add_entity(SqliteEntity::Index(idx));
    }

    snapshot
}

/// Build a PostgreSQL snapshot from parsed schema
fn build_postgres_snapshot(result: &ParseResult) -> PostgresSnapshot {
    use drizzle_migrations::postgres::{
        PostgresEntity, PrimaryKey, Schema as PgSchema, Table, UniqueConstraint,
    };

    let mut snapshot = PostgresSnapshot::new();

    // Add public schema
    snapshot.add_entity(PostgresEntity::Schema(PgSchema::new("public")));

    // Process tables (only those matching PostgreSQL dialect)
    for table in result
        .tables
        .values()
        .filter(|t| t.dialect == Dialect::PostgreSQL)
    {
        let table_name = table.name.to_snake_case();

        // Add table entity
        snapshot.add_entity(PostgresEntity::Table(Table {
            schema: "public".into(),
            name: table_name.clone().into(),
            is_rls_enabled: None,
        }));

        // Process columns
        let mut pk_columns = Vec::new();

        for field in &table.fields {
            let col = build_postgres_column(&table_name, field);
            snapshot.add_entity(PostgresEntity::Column(col));

            // Track primary key columns
            if field.is_primary_key() {
                pk_columns.push(field.name.to_snake_case());
            }

            // Add unique constraint if column is unique (not primary)
            if field.is_unique() && !field.is_primary_key() {
                let col_name = field.name.to_snake_case();
                snapshot.add_entity(PostgresEntity::UniqueConstraint(
                    UniqueConstraint::from_strings(
                        "public".to_string(),
                        table_name.clone(),
                        format!("{}_{}_key", table_name, col_name),
                        vec![col_name],
                    ),
                ));
            }

            // Add foreign key if references exist
            if let Some(ref_target) = field.references()
                && let Some(fk) = build_postgres_foreign_key(&table_name, field, &ref_target)
            {
                snapshot.add_entity(PostgresEntity::ForeignKey(fk));
            }
        }

        // Add primary key entity
        if !pk_columns.is_empty() {
            snapshot.add_entity(PostgresEntity::PrimaryKey(PrimaryKey::from_strings(
                "public".to_string(),
                table_name.clone(),
                format!("{}_pkey", table_name),
                pk_columns,
            )));
        }
    }

    // Process indexes (only those matching PostgreSQL dialect)
    for index in result
        .indexes
        .values()
        .filter(|i| i.dialect == Dialect::PostgreSQL)
    {
        let idx = build_postgres_index(index);
        snapshot.add_entity(PostgresEntity::Index(idx));
    }

    snapshot
}

/// Build an SQLite column from a parsed field
fn build_sqlite_column(
    table_name: &str,
    field: &ParsedField,
) -> drizzle_migrations::sqlite::Column {
    use drizzle_migrations::sqlite::Column;

    let col_name = field.name.to_snake_case();
    let col_type = infer_sqlite_type(&field.ty);

    let mut col = Column::new(table_name.to_string(), col_name, col_type);

    if !field.is_nullable() {
        col = col.not_null();
    }

    if field.is_autoincrement() {
        col = col.autoincrement();
    }

    if let Some(default) = field.default_value() {
        col = col.default_value(default);
    }

    col
}

/// Build a PostgreSQL column from a parsed field
fn build_postgres_column(
    table_name: &str,
    field: &ParsedField,
) -> drizzle_migrations::postgres::Column {
    use drizzle_migrations::postgres::ddl::IdentityType;
    use drizzle_migrations::postgres::{Column, Identity};

    let col_name = field.name.to_snake_case();
    let col_type = infer_postgres_type(&field.ty);
    let is_serial = field.has_attr("serial") || field.has_attr("bigserial");
    let is_identity = field.has_attr("generated") || field.has_attr("identity");

    Column {
        schema: "public".into(),
        table: table_name.to_string().into(),
        name: col_name.clone().into(),
        sql_type: col_type.into(),
        type_schema: None,
        not_null: !field.is_nullable(),
        default: field.default_value().map(Cow::Owned),
        generated: None,
        identity: if is_serial || is_identity {
            Some(Identity {
                name: format!("{}_{}_seq", table_name, col_name).into(),
                schema: Some("public".into()),
                type_: if is_identity {
                    IdentityType::Always
                } else {
                    IdentityType::ByDefault
                },
                increment: None,
                min_value: None,
                max_value: None,
                start_with: None,
                cache: None,
                cycle: None,
            })
        } else {
            None
        },
        dimensions: None,
        ordinal_position: None,
    }
}

/// Build an SQLite foreign key from a parsed field
fn build_sqlite_foreign_key(
    table_name: &str,
    field: &ParsedField,
    ref_target: &str,
) -> Option<drizzle_migrations::sqlite::ForeignKey> {
    use drizzle_migrations::sqlite::ForeignKey;

    // Parse "Table::column" reference
    let parts: Vec<&str> = ref_target.split("::").collect();
    if parts.len() != 2 {
        return None;
    }

    let ref_table = parts[0].to_snake_case();
    let ref_column = parts[1].to_snake_case();
    let col_name = field.name.to_snake_case();
    let fk_name = format!(
        "{}_{}_{}_{}_fk",
        table_name, col_name, ref_table, ref_column
    );

    let mut fk = ForeignKey::from_strings(
        table_name.to_string(),
        fk_name,
        vec![col_name],
        ref_table,
        vec![ref_column],
    );

    fk.on_delete = field.on_delete().map(Cow::Owned);
    fk.on_update = field.on_update().map(Cow::Owned);

    Some(fk)
}

/// Build a PostgreSQL foreign key from a parsed field
fn build_postgres_foreign_key(
    table_name: &str,
    field: &ParsedField,
    ref_target: &str,
) -> Option<drizzle_migrations::postgres::ForeignKey> {
    use drizzle_migrations::postgres::ForeignKey;

    // Parse "Table::column" reference
    let parts: Vec<&str> = ref_target.split("::").collect();
    if parts.len() != 2 {
        return None;
    }

    let ref_table = parts[0].to_snake_case();
    let ref_column = parts[1].to_snake_case();
    let col_name = field.name.to_snake_case();
    let fk_name = format!(
        "{}_{}_{}_{}_fk",
        table_name, col_name, ref_table, ref_column
    );

    Some(ForeignKey {
        schema: "public".into(),
        table: table_name.to_string().into(),
        name: fk_name.into(),
        name_explicit: false,
        columns: Cow::Owned(vec![Cow::Owned(col_name)]),
        schema_to: "public".into(),
        table_to: ref_table.into(),
        columns_to: Cow::Owned(vec![Cow::Owned(ref_column)]),
        on_update: field.on_update().map(Cow::Owned),
        on_delete: field.on_delete().map(Cow::Owned),
    })
}

/// Build an SQLite index from a parsed index
fn build_sqlite_index(index: &ParsedIndex) -> drizzle_migrations::sqlite::Index {
    use drizzle_migrations::sqlite::{Index, IndexColumn, IndexOrigin};

    let table_name = index
        .table_name()
        .map(str::to_snake_case)
        .unwrap_or_default();
    let index_name = index.name.to_snake_case();

    let columns: Vec<IndexColumn> = index
        .columns
        .iter()
        .filter_map(|c| {
            // Parse "Table::column" and extract just the column
            c.split("::")
                .last()
                .map(|s| IndexColumn::new(s.to_snake_case()))
        })
        .collect();

    Index {
        table: table_name.into(),
        name: index_name.into(),
        columns,
        is_unique: index.is_unique(),
        where_clause: None,
        origin: IndexOrigin::Manual,
    }
}

/// Build a PostgreSQL index from a parsed index
fn build_postgres_index(index: &ParsedIndex) -> drizzle_migrations::postgres::Index {
    use drizzle_migrations::postgres::{Index, IndexColumn};

    let table_name = index
        .table_name()
        .map(str::to_snake_case)
        .unwrap_or_default();
    let index_name = index.name.to_snake_case();

    let columns: Vec<IndexColumn> = index
        .columns
        .iter()
        .filter_map(|c| {
            c.split("::")
                .last()
                .map(|s| IndexColumn::new(s.to_snake_case()))
        })
        .collect();

    Index {
        schema: "public".into(),
        table: table_name.into(),
        name: index_name.into(),
        name_explicit: false,
        columns,
        is_unique: index.is_unique(),
        where_clause: None,
        method: None,
        with: None,
        concurrently: false,
    }
}

/// Infer SQLite type from Rust type string
fn infer_sqlite_type(rust_type: &str) -> String {
    let base_type = rust_type
        .trim()
        .strip_prefix("Option<")
        .and_then(|s| s.strip_suffix(">"))
        .unwrap_or(rust_type)
        .trim();

    match base_type {
        "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "isize" | "usize"
        | "bool" => "integer".to_string(),
        "f32" | "f64" => "real".to_string(),
        "String" | "&str" | "str" => "text".to_string(),
        "Vec<u8>" | "[u8]" => "blob".to_string(),
        _ if base_type.contains("Uuid") => "text".to_string(),
        _ if base_type.contains("DateTime") => "text".to_string(),
        _ if base_type.contains("NaiveDate") => "text".to_string(),
        _ => "any".to_string(),
    }
}

/// Infer PostgreSQL type from Rust type string
fn infer_postgres_type(rust_type: &str) -> String {
    let base_type = rust_type
        .trim()
        .strip_prefix("Option<")
        .and_then(|s| s.strip_suffix(">"))
        .unwrap_or(rust_type)
        .trim();

    match base_type {
        "i16" => "smallint".to_string(),
        "i32" => "integer".to_string(),
        "i64" => "bigint".to_string(),
        "u8" | "u16" | "u32" => "integer".to_string(),
        "u64" => "bigint".to_string(),
        "f32" => "real".to_string(),
        "f64" => "double precision".to_string(),
        "bool" => "boolean".to_string(),
        "String" | "&str" | "str" => "text".to_string(),
        "Vec<u8>" | "[u8]" => "bytea".to_string(),
        _ if base_type.contains("Uuid") => "uuid".to_string(),
        // `NaiveDateTime` contains the substring `DateTime`, so it must be matched
        // before the broader `DateTime` arm to map to `timestamp` instead of `timestamptz`.
        _ if base_type.contains("NaiveDateTime") => "timestamp".to_string(),
        _ if base_type.contains("DateTime") => "timestamptz".to_string(),
        _ if base_type.contains("NaiveDate") => "date".to_string(),
        _ if base_type.contains("NaiveTime") => "time".to_string(),
        _ if base_type.contains("IpAddr") => "inet".to_string(),
        _ if base_type.contains("MacAddr") => "macaddr".to_string(),
        _ if base_type.contains("Point") => "point".to_string(),
        _ if base_type.contains("Decimal") => "numeric".to_string(),
        _ => "text".to_string(),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_infer_sqlite_type() {
        assert_eq!(infer_sqlite_type("i32"), "integer");
        assert_eq!(infer_sqlite_type("i64"), "integer");
        assert_eq!(infer_sqlite_type("f64"), "real");
        assert_eq!(infer_sqlite_type("String"), "text");
        assert_eq!(infer_sqlite_type("Option<String>"), "text");
        assert_eq!(infer_sqlite_type("Vec<u8>"), "blob");
    }
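
    /// A sketch exercising the fallback arms of `infer_sqlite_type`; the `Uuid` and
    /// `DateTime<Utc>` spellings are assumed examples of type names the parser may report.
    #[test]
    fn test_infer_sqlite_fallback_types() {
        // Uuid and chrono-style types fall back to TEXT; unknown types map to ANY.
        assert_eq!(infer_sqlite_type("Uuid"), "text");
        assert_eq!(infer_sqlite_type("DateTime<Utc>"), "text");
        assert_eq!(infer_sqlite_type("MyCustomType"), "any");
    }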

    #[test]
    fn test_infer_postgres_type() {
        assert_eq!(infer_postgres_type("i32"), "integer");
        assert_eq!(infer_postgres_type("i64"), "bigint");
        assert_eq!(infer_postgres_type("bool"), "boolean");
        assert_eq!(infer_postgres_type("String"), "text");
        assert_eq!(infer_postgres_type("Vec<u8>"), "bytea");
        assert_eq!(infer_postgres_type("Uuid"), "uuid");
    }
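
    /// A sketch of temporal type inference for PostgreSQL; the chrono-style names
    /// (`NaiveDateTime`, `DateTime<Utc>`, `NaiveDate`, `NaiveTime`) are assumed
    /// examples of field types the parser may report.
    #[test]
    fn test_infer_postgres_temporal_types() {
        // NaiveDateTime must be checked before DateTime in the inference order.
        assert_eq!(infer_postgres_type("NaiveDateTime"), "timestamp");
        assert_eq!(infer_postgres_type("DateTime<Utc>"), "timestamptz");
        assert_eq!(infer_postgres_type("NaiveDate"), "date");
        assert_eq!(infer_postgres_type("Option<NaiveTime>"), "time");
    }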

    /// Test that changing a column from Option<String> to String generates table recreation
    #[test]
    fn test_nullable_to_not_null_generates_migration() {
        use drizzle_migrations::parser::SchemaParser;
        use drizzle_migrations::sqlite::collection::SQLiteDDL;
        use drizzle_migrations::sqlite::diff::compute_migration;

        // Previous schema: email is nullable (Option<String>)
        let prev_code = r#"
#[SQLiteTable]
pub struct User {
    #[column(primary)]
    pub id: i64,
    pub name: String,
    pub email: Option<String>,
}
"#;

        // Current schema: email is NOT nullable (String)
        let cur_code = r#"
#[SQLiteTable]
pub struct User {
    #[column(primary)]
    pub id: i64,
    pub name: String,
    pub email: String,
}
"#;

        let prev_result = SchemaParser::parse(prev_code);
        let cur_result = SchemaParser::parse(cur_code);

        let prev_snapshot = parse_result_to_snapshot(&prev_result, Dialect::SQLite);
        let cur_snapshot = parse_result_to_snapshot(&cur_result, Dialect::SQLite);

        // Extract DDL from snapshots
        let (prev_ddl, cur_ddl) = match (&prev_snapshot, &cur_snapshot) {
            (Snapshot::Sqlite(p), Snapshot::Sqlite(c)) => (
                SQLiteDDL::from_entities(p.ddl.clone()),
                SQLiteDDL::from_entities(c.ddl.clone()),
            ),
            _ => panic!("Expected SQLite snapshots"),
        };

        // Check that previous email column is nullable and current is not
        let prev_email = prev_ddl
            .columns
            .one("user", "email")
            .expect("email column in prev");
        let cur_email = cur_ddl
            .columns
            .one("user", "email")
            .expect("email column in cur");
        assert!(!prev_email.not_null, "Previous email should be nullable");
        assert!(cur_email.not_null, "Current email should be NOT NULL");

        // Compute migration
        let migration = compute_migration(&prev_ddl, &cur_ddl);

        // Should have SQL statements for table recreation
        assert!(
            !migration.sql_statements.is_empty(),
            "Should generate migration SQL for nullable change"
        );

        let combined = migration.sql_statements.join("\n");
        assert!(
            combined.contains("PRAGMA foreign_keys=OFF"),
            "Should contain PRAGMA foreign_keys=OFF for table recreation"
        );
        assert!(
            combined.contains("__new_user"),
            "Should create temporary table __new_user"
        );
        assert!(
            combined.contains("NOT NULL"),
            "New table should have NOT NULL on email column"
        );
        assert!(combined.contains("DROP TABLE"), "Should drop old table");
        assert!(
            combined.contains("RENAME TO"),
            "Should rename temp table to original"
        );
    }

    /// Test that changing a column from String to Option<String> generates table recreation
    #[test]
    fn test_not_null_to_nullable_generates_migration() {
        use drizzle_migrations::parser::SchemaParser;
        use drizzle_migrations::sqlite::collection::SQLiteDDL;
        use drizzle_migrations::sqlite::diff::compute_migration;

        // Previous schema: email is NOT nullable (String)
        let prev_code = r#"
#[SQLiteTable]
pub struct User {
    #[column(primary)]
    pub id: i64,
    pub email: String,
}
"#;

        // Current schema: email is nullable (Option<String>)
        let cur_code = r#"
#[SQLiteTable]
pub struct User {
    #[column(primary)]
    pub id: i64,
    pub email: Option<String>,
}
"#;

        let prev_result = SchemaParser::parse(prev_code);
        let cur_result = SchemaParser::parse(cur_code);

        let prev_snapshot = parse_result_to_snapshot(&prev_result, Dialect::SQLite);
        let cur_snapshot = parse_result_to_snapshot(&cur_result, Dialect::SQLite);

        // Extract DDL from snapshots
        let (prev_ddl, cur_ddl) = match (&prev_snapshot, &cur_snapshot) {
            (Snapshot::Sqlite(p), Snapshot::Sqlite(c)) => (
                SQLiteDDL::from_entities(p.ddl.clone()),
                SQLiteDDL::from_entities(c.ddl.clone()),
            ),
            _ => panic!("Expected SQLite snapshots"),
        };

        // Compute migration
        let migration = compute_migration(&prev_ddl, &cur_ddl);

        // Should have SQL statements for table recreation
        assert!(
            !migration.sql_statements.is_empty(),
            "Should generate migration SQL for nullable change"
        );

        let combined = migration.sql_statements.join("\n");
        assert!(
            combined.contains("__new_user"),
            "Should create temporary table for recreation"
        );
    }
}