Skip to main content

drizzle_cli/commands/
new.rs

1//! Interactive schema builder (`drizzle new`)
2//!
3//! Walks the user through an interactive wizard to define tables, columns,
4//! indexes, and foreign keys, then generates Rust schema code using the
5//! existing codegen pipeline (the same one `drizzle introspect` uses).
6//!
7//! Supports JSON import/export for CI-friendly, reproducible schema generation:
8//! - `drizzle new --json` reads a JSON schema definition from stdin
9//! - `drizzle new --json --from file.json` reads from a file
10//! - `drizzle new --export-json out.json` exports the schema as JSON
11//! - `drizzle new --schema-help` prints the expected JSON shape
12
13use std::borrow::Cow;
14use std::collections::HashSet;
15use std::path::PathBuf;
16
17use inquire::validator::Validation;
18use inquire::{Confirm, MultiSelect, Select, Text};
19use schemars::JsonSchema;
20use serde::{Deserialize, Serialize};
21
22use crate::config::{Config, Dialect};
23use crate::error::CliError;
24use crate::output;
25
26// ── Public API ──────────────────────────────────────────────────────────────
27
// Command-line flags for `drizzle new`.
// NOTE: the `///` field docs below are rendered by clap as `--help` text (user-facing output),
// so maintainer notes in this struct use plain `//` comments instead.
#[derive(clap::Args, Debug)]
pub struct NewOptions {
    /// Override dialect (sqlite, postgresql)
    // The interactive flow prefers this over the config file's dialect (see `resolve_dialect`).
    #[arg(long)]
    pub dialect: Option<Dialect>,

    /// Override schema output path
    // Takes precedence over both the config file's schema path and a JSON `output_path`.
    #[arg(long)]
    pub schema: Option<String>,

    /// Read schema definition from JSON (stdin by default)
    // When set, `run` loads the definition with `load_json` and shows no interactive prompts.
    #[arg(long)]
    pub json: bool,

    /// Read JSON from a file instead of stdin (requires --json)
    // `requires = "json"` makes clap reject `--from` unless `--json` is also given.
    #[arg(long, requires = "json", value_name = "PATH")]
    pub from: Option<PathBuf>,

    /// Export the schema definition as JSON after building
    // Written only after the definition passes validation.
    #[arg(long = "export-json", value_name = "PATH")]
    pub export_json: Option<PathBuf>,

    /// Print the expected JSON schema shape and exit
    // Checked first in `run`; all other flags are ignored when this is set.
    #[arg(long = "schema-help")]
    pub schema_help: bool,
}
54
55/// Run the `new` command to scaffold a schema definition.
56///
57/// # Errors
58///
59/// Returns [`CliError`] if loading or validating the JSON schema definition
60/// fails, if interactive prompts are cancelled, or if writing the generated
61/// Rust schema files fails.
62pub fn run(config: Option<&Config>, options: &NewOptions) -> Result<(), CliError> {
63    // --schema-help: print annotated example and exit
64    if options.schema_help {
65        print_json_schema();
66        return Ok(());
67    }
68
69    // Build the schema definition from either JSON input or interactive prompts
70    let def = if options.json {
71        load_json(options.from.as_deref())?
72    } else {
73        collect_interactively(config, options)?
74    };
75
76    // Validate the schema definition
77    validate_schema(&def)?;
78
79    // Export JSON if requested
80    if let Some(ref export_path) = options.export_json {
81        export_to_json(&def, export_path)?;
82    }
83
84    // Determine output path (JSON definition's output_path, or CLI override)
85    let output_path = if let Some(ref s) = options.schema {
86        s.clone()
87    } else {
88        def.output_path.clone()
89    };
90
91    // Generate code
92    let code = match def.dialect {
93        Dialect::Sqlite | Dialect::Turso => generate_sqlite(
94            &def.tables,
95            &def.indexes,
96            &def.foreign_keys,
97            &def.schema_name,
98            def.casing,
99        ),
100        Dialect::Postgresql => generate_postgres(
101            &def.tables,
102            &def.indexes,
103            &def.foreign_keys,
104            &def.enums,
105            &def.schema_name,
106            def.casing,
107        ),
108    };
109
110    // Write output
111    let path = PathBuf::from(&output_path);
112    if let Some(parent) = path.parent() {
113        std::fs::create_dir_all(parent)
114            .map_err(|e| CliError::IoError(format!("Failed to create directory: {e}")))?;
115    }
116    std::fs::write(&path, &code)
117        .map_err(|e| CliError::IoError(format!("Failed to write schema file: {e}")))?;
118
119    // Print summary
120    println!();
121    println!("{}", output::success("Schema generated successfully!"));
122    println!();
123    println!(
124        "  Tables: {}",
125        def.tables
126            .iter()
127            .map(|t| t.name.as_str())
128            .collect::<Vec<_>>()
129            .join(", ")
130    );
131    if !def.indexes.is_empty() {
132        println!(
133            "  Indexes: {}",
134            def.indexes
135                .iter()
136                .map(|i| i.name.as_str())
137                .collect::<Vec<_>>()
138                .join(", ")
139        );
140    }
141    if !def.foreign_keys.is_empty() {
142        println!("  Foreign keys: {}", def.foreign_keys.len());
143    }
144    println!("  Output: {output_path}");
145    if let Some(ref export_path) = options.export_json {
146        println!("  JSON export: {}", export_path.display());
147    }
148    println!();
149    println!("Next steps:");
150    println!(
151        "  Run {} to generate your first migration",
152        output::heading("drizzle generate")
153    );
154
155    Ok(())
156}
157
158// ── Schema definition (top-level JSON document) ─────────────────────────────
159
160#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
161pub struct SchemaDefinition {
162    pub dialect: Dialect,
163    #[serde(default = "default_casing")]
164    pub casing: FieldCasing,
165    #[serde(default = "default_schema_name")]
166    pub schema_name: String,
167    #[serde(default = "default_output_path")]
168    pub output_path: String,
169    #[serde(default)]
170    pub enums: Vec<EnumDef>,
171    pub tables: Vec<TableDef>,
172    #[serde(default)]
173    pub indexes: Vec<IndexDef>,
174    #[serde(default)]
175    pub foreign_keys: Vec<ForeignKeyDef>,
176}
177
178const fn default_casing() -> FieldCasing {
179    FieldCasing::Snake
180}
181
/// Serde default for [`SchemaDefinition::schema_name`].
fn default_schema_name() -> String {
    String::from("AppSchema")
}

/// Serde default for [`SchemaDefinition::output_path`].
fn default_output_path() -> String {
    String::from("src/schema.rs")
}

/// Serde default for a foreign key's `on_delete` / `on_update` action.
///
/// The returned value is one of the entries of `VALID_FK_ACTIONS`.
fn default_fk_action() -> String {
    String::from("No Action")
}
193
194// ── Intermediate structs ────────────────────────────────────────────────────
195
196#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
197pub struct EnumDef {
198    pub name: String,
199    pub variants: Vec<String>,
200}
201
202#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
203pub struct TableDef {
204    pub name: String,
205    pub columns: Vec<ColumnDef>,
206    /// `SQLite` only
207    #[serde(default)]
208    pub strict: bool,
209    /// `SQLite` only
210    #[serde(default)]
211    pub without_rowid: bool,
212    /// `PostgreSQL` only
213    #[serde(default = "default_pg_schema")]
214    pub pg_schema: String,
215}
216
/// Serde default for [`TableDef::pg_schema`]: the `public` schema.
fn default_pg_schema() -> String {
    String::from("public")
}
220
221/// Auto-generation strategy for a column value.
222///
223/// `autoincrement` is `SQLite`-specific (`INTEGER PRIMARY KEY AUTOINCREMENT`) and
224/// `identity` is `PostgreSQL`-specific (`GENERATED ALWAYS AS IDENTITY`). They are
225/// mutually exclusive dialect variants, so they live in a single optional enum
226/// rather than two parallel booleans.
227#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
228#[serde(rename_all = "lowercase")]
229pub enum AutoGenKind {
230    /// `SQLite` `INTEGER PRIMARY KEY AUTOINCREMENT`.
231    Autoincrement,
232    /// `PostgreSQL` `GENERATED ALWAYS AS IDENTITY`.
233    Identity,
234}
235
236#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
237pub struct ColumnDef {
238    pub name: String,
239    /// The SQL type string the codegen expects
240    pub sql_type: String,
241    #[serde(default)]
242    pub not_null: bool,
243    #[serde(default)]
244    pub primary_key: bool,
245    #[serde(default)]
246    pub unique: bool,
247    #[serde(default, skip_serializing_if = "Option::is_none")]
248    pub default: Option<String>,
249    /// Auto-generation strategy (`SQLite` autoincrement / `PostgreSQL` identity).
250    #[serde(default, skip_serializing_if = "Option::is_none")]
251    pub auto_gen: Option<AutoGenKind>,
252    /// For PG enum columns: the enum name
253    #[serde(default, skip_serializing_if = "Option::is_none")]
254    pub enum_name: Option<String>,
255}
256
impl ColumnDef {
    /// Returns `true` if the column uses `SQLite` `AUTOINCREMENT`,
    /// i.e. `auto_gen` is `Some(AutoGenKind::Autoincrement)`.
    #[must_use]
    pub const fn is_autoincrement(&self) -> bool {
        matches!(self.auto_gen, Some(AutoGenKind::Autoincrement))
    }

    /// Returns `true` if the column is a `PostgreSQL` identity column,
    /// i.e. `auto_gen` is `Some(AutoGenKind::Identity)`.
    #[must_use]
    pub const fn is_identity(&self) -> bool {
        matches!(self.auto_gen, Some(AutoGenKind::Identity))
    }
}
268
269#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
270pub struct IndexDef {
271    pub name: String,
272    pub table: String,
273    pub columns: Vec<String>,
274    #[serde(default)]
275    pub unique: bool,
276    /// PG schema
277    #[serde(default)]
278    pub pg_schema: String,
279}
280
281#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
282pub struct ForeignKeyDef {
283    pub name: String,
284    pub table: String,
285    pub columns: Vec<String>,
286    pub table_to: String,
287    pub columns_to: Vec<String>,
288    #[serde(default = "default_fk_action")]
289    pub on_delete: String,
290    #[serde(default = "default_fk_action")]
291    pub on_update: String,
292    /// PG schema
293    #[serde(default)]
294    pub pg_schema: String,
295    #[serde(default)]
296    pub pg_schema_to: String,
297}
298
299#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
300pub enum FieldCasing {
301    #[serde(rename = "snake_case")]
302    Snake,
303    #[serde(rename = "camelCase")]
304    Camel,
305}
306
307// ── JSON import/export ──────────────────────────────────────────────────────
308
309fn load_json(from: Option<&std::path::Path>) -> Result<SchemaDefinition, CliError> {
310    let content = if let Some(path) = from {
311        std::fs::read_to_string(path)
312            .map_err(|e| CliError::IoError(format!("Failed to read {}: {e}", path.display())))?
313    } else {
314        use std::io::Read;
315        let mut buf = String::new();
316        std::io::stdin()
317            .read_to_string(&mut buf)
318            .map_err(|e| CliError::IoError(format!("Failed to read stdin: {e}")))?;
319        buf
320    };
321    serde_json::from_str(&content)
322        .map_err(|e| CliError::Other(format!("Invalid JSON schema definition: {e}")))
323}
324
325fn export_to_json(def: &SchemaDefinition, path: &std::path::Path) -> Result<(), CliError> {
326    let json = serde_json::to_string_pretty(def)
327        .map_err(|e| CliError::Other(format!("Failed to serialize schema: {e}")))?;
328    if let Some(parent) = path.parent() {
329        std::fs::create_dir_all(parent)
330            .map_err(|e| CliError::IoError(format!("Failed to create directory: {e}")))?;
331    }
332    std::fs::write(path, json)
333        .map_err(|e| CliError::IoError(format!("Failed to write JSON: {e}")))?;
334    Ok(())
335}
336
337fn print_json_schema() {
338    let schema = schemars::schema_for!(SchemaDefinition);
339    let json = serde_json::to_string_pretty(&schema).expect("schema serialization cannot fail");
340    println!("{json}");
341    println!();
342    println!(
343        "Valid on_delete/on_update actions: \"No Action\", \"Cascade\", \"Set Null\", \"Set Default\", \"Restrict\""
344    );
345    println!();
346    println!("Tip: Run `drizzle new --export-json schema.json` to export an interactive");
347    println!(
348        "session as valid JSON, then edit and replay with `drizzle new --json --from schema.json`."
349    );
350}
351
352// ── Validation ──────────────────────────────────────────────────────────────
353
/// Foreign-key `on_delete` / `on_update` actions accepted by validation.
///
/// Values are compared exactly (case-sensitive), so they must be spelled as listed here.
const VALID_FK_ACTIONS: &[&str] = &[
    "No Action",
    "Cascade",
    "Set Null",
    "Set Default",
    "Restrict",
];
361
362fn validate_schema(def: &SchemaDefinition) -> Result<(), CliError> {
363    // Must have at least one table
364    if def.tables.is_empty() {
365        return Err(CliError::Other(
366            "Schema must have at least one table".into(),
367        ));
368    }
369
370    // Check table names are valid and unique, and per-table column/dialect rules
371    let mut table_names = HashSet::new();
372    for table in &def.tables {
373        if !is_valid_identifier(&table.name) {
374            return Err(CliError::Other(format!(
375                "Invalid table name: '{}'",
376                table.name
377            )));
378        }
379        if !table_names.insert(&table.name) {
380            return Err(CliError::Other(format!(
381                "Duplicate table name: '{}'",
382                table.name
383            )));
384        }
385        validate_table(table, def.dialect)?;
386    }
387
388    let enum_names = validate_enums(def)?;
389
390    // Validate enum references in columns
391    for table in &def.tables {
392        for col in &table.columns {
393            if let Some(ref en) = col.enum_name
394                && !enum_names.contains(en.as_str())
395            {
396                return Err(CliError::Other(format!(
397                    "Column '{}.{}' references unknown enum '{}'",
398                    table.name, col.name, en
399                )));
400            }
401        }
402    }
403
404    validate_indexes(def)?;
405    validate_foreign_keys(def)?;
406
407    Ok(())
408}
409
410/// Validate a single table: column names, dialect-specific column rules, and
411/// that the table itself has at least one column.
412fn validate_table(table: &TableDef, dialect: Dialect) -> Result<(), CliError> {
413    if table.columns.is_empty() {
414        return Err(CliError::Other(format!(
415            "Table '{}' must have at least one column",
416            table.name
417        )));
418    }
419
420    let mut col_names = HashSet::new();
421    for col in &table.columns {
422        if !is_valid_identifier(&col.name) {
423            return Err(CliError::Other(format!(
424                "Invalid column name '{}' in table '{}'",
425                col.name, table.name
426            )));
427        }
428        if !col_names.insert(&col.name) {
429            return Err(CliError::Other(format!(
430                "Duplicate column name '{}' in table '{}'",
431                col.name, table.name
432            )));
433        }
434    }
435
436    match dialect {
437        Dialect::Sqlite | Dialect::Turso => {
438            for col in &table.columns {
439                if col.is_identity() {
440                    return Err(CliError::Other(format!(
441                        "Column '{}.{}': 'identity' is only supported for PostgreSQL",
442                        table.name, col.name
443                    )));
444                }
445                if col.enum_name.is_some() {
446                    return Err(CliError::Other(format!(
447                        "Column '{}.{}': 'enum_name' is only supported for PostgreSQL",
448                        table.name, col.name
449                    )));
450                }
451            }
452        }
453        Dialect::Postgresql => {
454            if table.strict {
455                return Err(CliError::Other(format!(
456                    "Table '{}': 'strict' is only supported for SQLite",
457                    table.name
458                )));
459            }
460            if table.without_rowid {
461                return Err(CliError::Other(format!(
462                    "Table '{}': 'without_rowid' is only supported for SQLite",
463                    table.name
464                )));
465            }
466            for col in &table.columns {
467                if col.is_autoincrement() {
468                    return Err(CliError::Other(format!(
469                        "Column '{}.{}': 'autoincrement' is only supported for SQLite (use 'identity' for PostgreSQL)",
470                        table.name, col.name
471                    )));
472                }
473            }
474        }
475    }
476
477    Ok(())
478}
479
480/// Validate enum definitions and return the set of declared enum names.
481fn validate_enums(def: &SchemaDefinition) -> Result<HashSet<&str>, CliError> {
482    if def.dialect != Dialect::Postgresql && !def.enums.is_empty() {
483        return Err(CliError::Other(
484            "Enums are only supported for PostgreSQL".into(),
485        ));
486    }
487    let mut enum_names = HashSet::new();
488    for e in &def.enums {
489        if !is_valid_identifier(&e.name) {
490            return Err(CliError::Other(format!("Invalid enum name: '{}'", e.name)));
491        }
492        if !enum_names.insert(e.name.as_str()) {
493            return Err(CliError::Other(format!(
494                "Duplicate enum name: '{}'",
495                e.name
496            )));
497        }
498        if e.variants.is_empty() {
499            return Err(CliError::Other(format!(
500                "Enum '{}' must have at least one variant",
501                e.name
502            )));
503        }
504    }
505    Ok(enum_names)
506}
507
508/// Validate that each index references a real table and real columns.
509fn validate_indexes(def: &SchemaDefinition) -> Result<(), CliError> {
510    for idx in &def.indexes {
511        let table = def.tables.iter().find(|t| t.name == idx.table);
512        let Some(table) = table else {
513            return Err(CliError::Other(format!(
514                "Index '{}' references unknown table '{}'",
515                idx.name, idx.table
516            )));
517        };
518        for col_name in &idx.columns {
519            if !table.columns.iter().any(|c| &c.name == col_name) {
520                return Err(CliError::Other(format!(
521                    "Index '{}' references unknown column '{}.{}'",
522                    idx.name, idx.table, col_name
523                )));
524            }
525        }
526    }
527    Ok(())
528}
529
530/// Validate that each foreign key references real tables/columns and uses a
531/// recognized on-delete / on-update action.
532fn validate_foreign_keys(def: &SchemaDefinition) -> Result<(), CliError> {
533    for fk in &def.foreign_keys {
534        // Source table
535        let src = def.tables.iter().find(|t| t.name == fk.table);
536        let Some(src) = src else {
537            return Err(CliError::Other(format!(
538                "Foreign key '{}' references unknown source table '{}'",
539                fk.name, fk.table
540            )));
541        };
542        for col_name in &fk.columns {
543            if !src.columns.iter().any(|c| &c.name == col_name) {
544                return Err(CliError::Other(format!(
545                    "Foreign key '{}' references unknown source column '{}.{}'",
546                    fk.name, fk.table, col_name
547                )));
548            }
549        }
550
551        // Target table
552        let tgt = def.tables.iter().find(|t| t.name == fk.table_to);
553        let Some(tgt) = tgt else {
554            return Err(CliError::Other(format!(
555                "Foreign key '{}' references unknown target table '{}'",
556                fk.name, fk.table_to
557            )));
558        };
559        for col_name in &fk.columns_to {
560            if !tgt.columns.iter().any(|c| &c.name == col_name) {
561                return Err(CliError::Other(format!(
562                    "Foreign key '{}' references unknown target column '{}.{}'",
563                    fk.name, fk.table_to, col_name
564                )));
565            }
566        }
567
568        // Validate FK actions
569        if !VALID_FK_ACTIONS.contains(&fk.on_delete.as_str()) {
570            return Err(CliError::Other(format!(
571                "Foreign key '{}': invalid on_delete action '{}'. Valid: {}",
572                fk.name,
573                fk.on_delete,
574                VALID_FK_ACTIONS.join(", ")
575            )));
576        }
577        if !VALID_FK_ACTIONS.contains(&fk.on_update.as_str()) {
578            return Err(CliError::Other(format!(
579                "Foreign key '{}': invalid on_update action '{}'. Valid: {}",
580                fk.name,
581                fk.on_update,
582                VALID_FK_ACTIONS.join(", ")
583            )));
584        }
585    }
586    Ok(())
587}
588
589// ── Interactive collection ──────────────────────────────────────────────────
590
591fn collect_interactively(
592    config: Option<&Config>,
593    options: &NewOptions,
594) -> Result<SchemaDefinition, CliError> {
595    // Phase 1: Setup
596    let dialect = resolve_dialect(config, options.dialect)?;
597    let casing = prompt_casing()?;
598    let output_path = resolve_output_path(config, options.schema.clone())?;
599    let schema_name = prompt_schema_name()?;
600
601    // Phase 2: Enums (PostgreSQL only)
602    let enums: Vec<EnumDef> = if dialect == Dialect::Postgresql {
603        prompt_enums()?
604    } else {
605        Vec::new()
606    };
607
608    // Phase 3 & 4: Tables + Columns
609    let mut tables: Vec<TableDef> = Vec::new();
610    loop {
611        let table = prompt_table(dialect, &enums)?;
612        tables.push(table);
613        if !confirm("Add another table?", false)? {
614            break;
615        }
616    }
617
618    // Phase 5: Indexes
619    let indexes: Vec<IndexDef> = if confirm("Add indexes?", false)? {
620        prompt_indexes(&tables)?
621    } else {
622        Vec::new()
623    };
624
625    // Phase 6: Foreign Keys
626    let foreign_keys: Vec<ForeignKeyDef> =
627        if tables.len() > 1 && confirm("Add foreign keys?", false)? {
628            prompt_foreign_keys(&tables, dialect)?
629        } else {
630            Vec::new()
631        };
632
633    Ok(SchemaDefinition {
634        dialect,
635        casing,
636        schema_name,
637        output_path,
638        enums,
639        tables,
640        indexes,
641        foreign_keys,
642    })
643}
644
645// ── Phase 1: Setup prompts ──────────────────────────────────────────────────
646
647fn resolve_dialect(
648    config: Option<&Config>,
649    cli_dialect: Option<Dialect>,
650) -> Result<Dialect, CliError> {
651    if let Some(d) = cli_dialect {
652        return Ok(d);
653    }
654    if let Some(c) = config {
655        return Ok(c.dialect());
656    }
657    let options = vec!["SQLite", "PostgreSQL"];
658    let answer = Select::new("Select database dialect:", options)
659        .prompt()
660        .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))?;
661    match answer {
662        "SQLite" => Ok(Dialect::Sqlite),
663        "PostgreSQL" => Ok(Dialect::Postgresql),
664        _ => unreachable!(),
665    }
666}
667
668fn prompt_casing() -> Result<FieldCasing, CliError> {
669    let options = vec!["snake_case (default)", "camelCase"];
670    let answer = Select::new("Select field casing:", options)
671        .prompt()
672        .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))?;
673    match answer {
674        s if s.starts_with("snake") => Ok(FieldCasing::Snake),
675        s if s.starts_with("camel") => Ok(FieldCasing::Camel),
676        _ => Ok(FieldCasing::Snake),
677    }
678}
679
680fn resolve_output_path(
681    config: Option<&Config>,
682    cli_schema: Option<String>,
683) -> Result<String, CliError> {
684    if let Some(s) = cli_schema {
685        return Ok(s);
686    }
687    let default = config.map_or_else(
688        || "src/schema.rs".to_string(),
689        super::super::config::Config::schema_display,
690    );
691    Text::new("Schema output path:")
692        .with_default(&default)
693        .prompt()
694        .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))
695}
696
697fn prompt_schema_name() -> Result<String, CliError> {
698    Text::new("Schema struct name:")
699        .with_default("AppSchema")
700        .prompt()
701        .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))
702}
703
704// ── Phase 2: Enums (PostgreSQL only) ────────────────────────────────────────
705
706fn prompt_enums() -> Result<Vec<EnumDef>, CliError> {
707    let mut enums = Vec::new();
708    if !confirm("Define any enums?", false)? {
709        return Ok(enums);
710    }
711    loop {
712        let name = Text::new("Enum name:")
713            .with_validator(|s: &str| {
714                if s.is_empty() {
715                    Ok(Validation::Invalid("Name cannot be empty".into()))
716                } else if !is_valid_identifier(s) {
717                    Ok(Validation::Invalid(
718                        "Must be a valid Rust identifier".into(),
719                    ))
720                } else {
721                    Ok(Validation::Valid)
722                }
723            })
724            .prompt()
725            .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))?;
726
727        let mut variants = Vec::new();
728        loop {
729            let variant = Text::new("  Enum variant (empty to finish):")
730                .prompt()
731                .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))?;
732            if variant.is_empty() {
733                break;
734            }
735            variants.push(variant);
736        }
737        if variants.is_empty() {
738            println!("  Skipping enum with no variants.");
739        } else {
740            enums.push(EnumDef { name, variants });
741        }
742        if !confirm("Add another enum?", false)? {
743            break;
744        }
745    }
746    Ok(enums)
747}
748
749// ── Phase 3 & 4: Tables + Columns ──────────────────────────────────────────
750
751fn prompt_table(dialect: Dialect, enums: &[EnumDef]) -> Result<TableDef, CliError> {
752    let name = Text::new("Table name:")
753        .with_validator(|s: &str| {
754            if s.is_empty() {
755                Ok(Validation::Invalid("Name cannot be empty".into()))
756            } else if !is_valid_identifier(s) {
757                Ok(Validation::Invalid(
758                    "Must be a valid Rust identifier (letters, digits, underscores)".into(),
759                ))
760            } else {
761                Ok(Validation::Valid)
762            }
763        })
764        .prompt()
765        .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))?;
766
767    let mut strict = false;
768    let mut without_rowid = false;
769    let mut pg_schema = "public".to_string();
770
771    match dialect {
772        Dialect::Sqlite | Dialect::Turso => {
773            let table_opts = vec!["strict", "without_rowid"];
774            let selected = MultiSelect::new("Table options (space to toggle):", table_opts)
775                .prompt()
776                .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))?;
777            strict = selected.contains(&"strict");
778            without_rowid = selected.contains(&"without_rowid");
779        }
780        Dialect::Postgresql => {
781            pg_schema = Text::new("PostgreSQL schema:")
782                .with_default("public")
783                .prompt()
784                .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))?;
785        }
786    }
787
788    // Columns
789    let mut columns = Vec::new();
790    println!();
791    println!("  Define columns for '{name}':");
792    loop {
793        let col = prompt_column(dialect, enums)?;
794        columns.push(col);
795        if !confirm("  Add another column?", true)? {
796            break;
797        }
798    }
799
800    Ok(TableDef {
801        name,
802        columns,
803        strict,
804        without_rowid,
805        pg_schema,
806    })
807}
808
809fn prompt_column(dialect: Dialect, enums: &[EnumDef]) -> Result<ColumnDef, CliError> {
810    let col_name = Text::new("  Column name:")
811        .with_validator(|s: &str| {
812            if s.is_empty() {
813                Ok(Validation::Invalid("Name cannot be empty".into()))
814            } else if !is_valid_identifier(s) {
815                Ok(Validation::Invalid("Must be a valid identifier".into()))
816            } else {
817                Ok(Validation::Valid)
818            }
819        })
820        .prompt()
821        .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))?;
822
823    let (sql_type, enum_name) = prompt_type(dialect, enums)?;
824
825    let nullable = confirm("  Nullable (Option<T>)?", false)?;
826
827    let constraint_opts = match dialect {
828        Dialect::Sqlite | Dialect::Turso => {
829            vec!["Primary Key", "Autoincrement", "Unique", "Default value"]
830        }
831        Dialect::Postgresql => {
832            vec![
833                "Primary Key",
834                "Identity (auto-increment)",
835                "Unique",
836                "Default value",
837            ]
838        }
839    };
840    let selected = MultiSelect::new("  Column constraints (space to toggle):", constraint_opts)
841        .prompt()
842        .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))?;
843
844    let primary_key = selected.iter().any(|s| s.starts_with("Primary"));
845    let autoincrement = selected.iter().any(|s| s.starts_with("Autoincrement"));
846    let identity = selected.iter().any(|s| s.starts_with("Identity"));
847    let unique = selected.iter().any(|s| s.starts_with("Unique"));
848    let has_default = selected.iter().any(|s| s.starts_with("Default"));
849
850    let default = if has_default {
851        let val = Text::new("  Default value:")
852            .prompt()
853            .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))?;
854        Some(val)
855    } else {
856        None
857    };
858
859    let auto_gen = if autoincrement {
860        Some(AutoGenKind::Autoincrement)
861    } else if identity {
862        Some(AutoGenKind::Identity)
863    } else {
864        None
865    };
866
867    Ok(ColumnDef {
868        name: col_name,
869        sql_type,
870        not_null: !nullable,
871        primary_key,
872        unique,
873        default,
874        auto_gen,
875        enum_name,
876    })
877}
878
879fn prompt_type(dialect: Dialect, enums: &[EnumDef]) -> Result<(String, Option<String>), CliError> {
880    let mut options: Vec<String> = match dialect {
881        Dialect::Sqlite | Dialect::Turso => {
882            vec![
883                "i32".into(),
884                "i64".into(),
885                "f64".into(),
886                "String".into(),
887                "bool".into(),
888                "Vec<u8>".into(),
889            ]
890        }
891        Dialect::Postgresql => {
892            vec![
893                "i16".into(),
894                "i32".into(),
895                "i64".into(),
896                "f32".into(),
897                "f64".into(),
898                "String".into(),
899                "bool".into(),
900                "Vec<u8>".into(),
901                "uuid::Uuid".into(),
902                "chrono::NaiveDate".into(),
903                "chrono::NaiveDateTime".into(),
904                "chrono::DateTime<chrono::Utc>".into(),
905                "serde_json::Value".into(),
906            ]
907        }
908    };
909
910    // Append user-defined enums as type choices
911    for e in enums {
912        options.push(format!("enum:{}", e.name));
913    }
914
915    let refs: Vec<&str> = options.iter().map(std::string::String::as_str).collect();
916    let chosen = Select::new("  Rust type:", refs)
917        .prompt()
918        .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))?;
919
920    // Map user-friendly Rust type -> SQL type string for codegen
921    if let Some(enum_name) = chosen.strip_prefix("enum:") {
922        // For enum columns, the sql_type is the enum name itself
923        return Ok((enum_name.to_string(), Some(enum_name.to_string())));
924    }
925
926    let sql_type = match dialect {
927        Dialect::Sqlite | Dialect::Turso => match chosen {
928            "i32" | "i64" => "integer",
929            "f64" => "real",
930            "bool" => "boolean",
931            "Vec<u8>" => "blob",
932            _ => "text",
933        },
934        Dialect::Postgresql => match chosen {
935            "i16" => "int2",
936            "i32" => "int4",
937            "i64" => "int8",
938            "f32" => "float4",
939            "f64" => "float8",
940            "bool" => "bool",
941            "Vec<u8>" => "bytea",
942            "uuid::Uuid" => "uuid",
943            "chrono::NaiveDate" => "date",
944            "chrono::NaiveDateTime" => "timestamp",
945            "chrono::DateTime<chrono::Utc>" => "timestamptz",
946            "serde_json::Value" => "jsonb",
947            _ => "text",
948        },
949    };
950
951    Ok((sql_type.to_string(), None))
952}
953
954// ── Phase 5: Indexes ────────────────────────────────────────────────────────
955
956fn prompt_indexes(tables: &[TableDef]) -> Result<Vec<IndexDef>, CliError> {
957    let mut indexes = Vec::new();
958    loop {
959        let table_names: Vec<&str> = tables.iter().map(|t| t.name.as_str()).collect();
960        let table_name = Select::new("Index on which table?", table_names)
961            .prompt()
962            .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))?;
963
964        let table = tables.iter().find(|t| t.name == table_name).unwrap();
965        let col_names: Vec<&str> = table.columns.iter().map(|c| c.name.as_str()).collect();
966
967        if col_names.is_empty() {
968            println!("  Table has no columns, skipping.");
969            if !confirm("Add another index?", false)? {
970                break;
971            }
972            continue;
973        }
974
975        let selected_cols = MultiSelect::new("Select columns for index:", col_names)
976            .prompt()
977            .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))?;
978
979        if selected_cols.is_empty() {
980            println!("  No columns selected, skipping.");
981            if !confirm("Add another index?", false)? {
982                break;
983            }
984            continue;
985        }
986
987        let is_unique = confirm("  Unique index?", false)?;
988
989        let suggested_name = format!("{}_{}_idx", table_name, selected_cols.join("_"));
990        let idx_name = Text::new("  Index name:")
991            .with_default(&suggested_name)
992            .prompt()
993            .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))?;
994
995        indexes.push(IndexDef {
996            name: idx_name,
997            table: table_name.to_string(),
998            columns: selected_cols
999                .into_iter()
1000                .map(std::string::ToString::to_string)
1001                .collect(),
1002            unique: is_unique,
1003            pg_schema: table.pg_schema.clone(),
1004        });
1005
1006        if !confirm("Add another index?", false)? {
1007            break;
1008        }
1009    }
1010    Ok(indexes)
1011}
1012
1013// ── Phase 6: Foreign Keys ───────────────────────────────────────────────────
1014
1015fn prompt_foreign_keys(
1016    tables: &[TableDef],
1017    dialect: Dialect,
1018) -> Result<Vec<ForeignKeyDef>, CliError> {
1019    let mut fks = Vec::new();
1020    let action_options = vec![
1021        "No Action",
1022        "Cascade",
1023        "Set Null",
1024        "Set Default",
1025        "Restrict",
1026    ];
1027    loop {
1028        let table_names: Vec<&str> = tables.iter().map(|t| t.name.as_str()).collect();
1029
1030        let src_table_name = Select::new("Source table:", table_names.clone())
1031            .prompt()
1032            .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))?;
1033        let src_table = tables.iter().find(|t| t.name == src_table_name).unwrap();
1034        let src_col_names: Vec<&str> = src_table.columns.iter().map(|c| c.name.as_str()).collect();
1035
1036        let src_cols = MultiSelect::new("Source column(s):", src_col_names)
1037            .prompt()
1038            .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))?;
1039
1040        let tgt_table_name = Select::new("Target (referenced) table:", table_names)
1041            .prompt()
1042            .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))?;
1043        let tgt_table = tables.iter().find(|t| t.name == tgt_table_name).unwrap();
1044        let tgt_col_names: Vec<&str> = tgt_table.columns.iter().map(|c| c.name.as_str()).collect();
1045
1046        let tgt_cols = MultiSelect::new("Target column(s):", tgt_col_names)
1047            .prompt()
1048            .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))?;
1049
1050        let on_delete = Select::new("ON DELETE action:", action_options.clone())
1051            .prompt()
1052            .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))?;
1053
1054        let on_update = Select::new("ON UPDATE action:", action_options.clone())
1055            .prompt()
1056            .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))?;
1057
1058        let fk_name = format!("{}_{}_fk", src_table_name, src_cols.join("_"));
1059
1060        let pg_schema_to = match dialect {
1061            Dialect::Postgresql => tgt_table.pg_schema.clone(),
1062            _ => String::new(),
1063        };
1064
1065        fks.push(ForeignKeyDef {
1066            name: fk_name,
1067            table: src_table_name.to_string(),
1068            columns: src_cols
1069                .into_iter()
1070                .map(std::string::ToString::to_string)
1071                .collect(),
1072            table_to: tgt_table_name.to_string(),
1073            columns_to: tgt_cols
1074                .into_iter()
1075                .map(std::string::ToString::to_string)
1076                .collect(),
1077            on_delete: on_delete.to_string(),
1078            on_update: on_update.to_string(),
1079            pg_schema: src_table.pg_schema.clone(),
1080            pg_schema_to,
1081        });
1082
1083        if !confirm("Add another foreign key?", false)? {
1084            break;
1085        }
1086    }
1087    Ok(fks)
1088}
1089
1090// ── Phase 7: Code generation ────────────────────────────────────────────────
1091
1092fn generate_sqlite(
1093    tables: &[TableDef],
1094    indexes: &[IndexDef],
1095    fks: &[ForeignKeyDef],
1096    schema_name: &str,
1097    casing: FieldCasing,
1098) -> String {
1099    use drizzle_migrations::sqlite::codegen;
1100    use drizzle_migrations::sqlite::collection::SQLiteDDL;
1101    use drizzle_types::sqlite::ddl::{
1102        Column, ForeignKey, Index, IndexColumn, PrimaryKey, Table, UniqueConstraint,
1103    };
1104
1105    let mut ddl = SQLiteDDL::new();
1106
1107    for (table_idx, table) in tables.iter().enumerate() {
1108        let mut t = Table::new(table.name.clone());
1109        if table.strict {
1110            t = t.strict();
1111        }
1112        if table.without_rowid {
1113            t = t.without_rowid();
1114        }
1115        ddl.tables.push(t);
1116
1117        let mut pk_cols: Vec<String> = Vec::new();
1118        let mut unique_cols: Vec<String> = Vec::new();
1119
1120        for (col_idx, col) in table.columns.iter().enumerate() {
1121            let mut column =
1122                Column::new(table.name.clone(), col.name.clone(), col.sql_type.clone());
1123            if col.not_null {
1124                column = column.not_null();
1125            }
1126            if col.is_autoincrement() {
1127                column = column.autoincrement();
1128            }
1129            if let Some(ref default) = col.default {
1130                column = column.default_value(default.clone());
1131            }
1132            // Set ordinal position to preserve order
1133            column.ordinal_position = Some(
1134                i32::try_from(col_idx)
1135                    .ok()
1136                    .and_then(|i| i.checked_add(1))
1137                    .unwrap_or(i32::MAX),
1138            );
1139            ddl.columns.push(column);
1140
1141            if col.primary_key {
1142                pk_cols.push(col.name.clone());
1143            }
1144            if col.unique {
1145                unique_cols.push(col.name.clone());
1146            }
1147        }
1148
1149        if !pk_cols.is_empty() {
1150            ddl.pks.push(PrimaryKey::from_strings(
1151                table.name.clone(),
1152                format!("{}_pk", table.name),
1153                pk_cols,
1154            ));
1155        }
1156        for uc in unique_cols {
1157            ddl.uniques.push(UniqueConstraint::from_strings(
1158                table.name.clone(),
1159                format!("{}_{}_unique", table.name, uc),
1160                vec![uc],
1161            ));
1162        }
1163
1164        // Drop the table_idx binding explicitly
1165        let _ = table_idx;
1166    }
1167
1168    // Add indexes
1169    for idx in indexes {
1170        let columns: Vec<IndexColumn> = idx
1171            .columns
1172            .iter()
1173            .map(|c| IndexColumn::new(c.clone()))
1174            .collect();
1175        let mut index = Index::new(idx.table.clone(), idx.name.clone(), columns);
1176        if idx.unique {
1177            index = index.unique();
1178        }
1179        ddl.indexes.push(index);
1180    }
1181
1182    // Add foreign keys
1183    for fk in fks {
1184        let mut foreign_key = ForeignKey::from_strings(
1185            fk.table.clone(),
1186            fk.name.clone(),
1187            fk.columns.clone(),
1188            fk.table_to.clone(),
1189            fk.columns_to.clone(),
1190        );
1191        if fk.on_delete != "No Action" {
1192            foreign_key = foreign_key.on_delete(fk.on_delete.to_uppercase());
1193        }
1194        if fk.on_update != "No Action" {
1195            foreign_key = foreign_key.on_update(fk.on_update.to_uppercase());
1196        }
1197        ddl.fks.push(foreign_key);
1198    }
1199
1200    let field_casing = match casing {
1201        FieldCasing::Snake => codegen::FieldCasing::Snake,
1202        FieldCasing::Camel => codegen::FieldCasing::Camel,
1203    };
1204
1205    let options = codegen::CodegenOptions {
1206        module_doc: Some("Generated by `drizzle new`".to_string()),
1207        include_schema: true,
1208        schema_name: schema_name.to_string(),
1209        use_pub: true,
1210        field_casing,
1211    };
1212
1213    codegen::generate_rust_schema(&ddl, &options).code
1214}
1215
1216fn generate_postgres(
1217    tables: &[TableDef],
1218    indexes: &[IndexDef],
1219    fks: &[ForeignKeyDef],
1220    enums: &[EnumDef],
1221    schema_name: &str,
1222    casing: FieldCasing,
1223) -> String {
1224    use drizzle_migrations::postgres::codegen;
1225    use drizzle_migrations::postgres::collection::PostgresDDL;
1226    use drizzle_types::postgres::ddl::{Enum, Table};
1227
1228    let mut ddl = PostgresDDL::new();
1229
1230    // Add enums
1231    for e in enums {
1232        let values: Vec<Cow<'static, str>> =
1233            e.variants.iter().map(|v| Cow::Owned(v.clone())).collect();
1234        ddl.enums.push(Enum::new(
1235            "public",
1236            Cow::<str>::Owned(e.name.clone()),
1237            Cow::<[Cow<'static, str>]>::Owned(values),
1238        ));
1239    }
1240
1241    // Add tables + columns
1242    for table in tables {
1243        ddl.tables
1244            .push(Table::new(table.pg_schema.clone(), table.name.clone()));
1245        add_postgres_table_columns(&mut ddl, table);
1246    }
1247
1248    add_postgres_indexes(&mut ddl, indexes);
1249    add_postgres_foreign_keys(&mut ddl, fks);
1250
1251    let field_casing = match casing {
1252        FieldCasing::Snake => codegen::FieldCasing::Snake,
1253        FieldCasing::Camel => codegen::FieldCasing::Camel,
1254    };
1255
1256    let options = codegen::CodegenOptions {
1257        module_doc: Some("Generated by `drizzle new`".to_string()),
1258        include_schema: true,
1259        schema_name: schema_name.to_string(),
1260        use_pub: true,
1261        field_casing,
1262    };
1263
1264    codegen::generate_rust_schema(&ddl, &options).code
1265}
1266
1267/// Populate `ddl.columns`, `ddl.pks`, `ddl.uniques` for a single postgres table.
1268fn add_postgres_table_columns(
1269    ddl: &mut drizzle_migrations::postgres::collection::PostgresDDL,
1270    table: &TableDef,
1271) {
1272    use drizzle_types::postgres::ddl::{Column, PrimaryKey, UniqueConstraint};
1273
1274    let mut pk_cols: Vec<String> = Vec::new();
1275    let mut unique_cols: Vec<String> = Vec::new();
1276
1277    for (col_idx, col) in table.columns.iter().enumerate() {
1278        let mut column = Column::new(
1279            table.pg_schema.clone(),
1280            table.name.clone(),
1281            col.name.clone(),
1282            col.sql_type.clone(),
1283        );
1284        if col.not_null {
1285            column = column.not_null();
1286        }
1287        if let Some(ref default) = col.default {
1288            column = column.default_value(default.clone());
1289        }
1290        if col.is_identity() {
1291            use drizzle_types::postgres::ddl::Identity;
1292            let seq_name = format!("{}_{}_seq", table.name, col.name);
1293            column.identity = Some(Identity::always(seq_name));
1294        }
1295        if col.enum_name.is_some() {
1296            // Set type_schema so codegen can find it in the enum_map
1297            column.type_schema = Some(Cow::Owned(table.pg_schema.clone()));
1298        }
1299        column.ordinal_position = Some(
1300            i32::try_from(col_idx)
1301                .ok()
1302                .and_then(|i| i.checked_add(1))
1303                .unwrap_or(i32::MAX),
1304        );
1305        ddl.columns.push(column);
1306
1307        if col.primary_key {
1308            pk_cols.push(col.name.clone());
1309        }
1310        if col.unique {
1311            unique_cols.push(col.name.clone());
1312        }
1313    }
1314
1315    if !pk_cols.is_empty() {
1316        ddl.pks.push(PrimaryKey::from_strings(
1317            table.pg_schema.clone(),
1318            table.name.clone(),
1319            format!("{}_pk", table.name),
1320            pk_cols,
1321        ));
1322    }
1323    for uc in unique_cols {
1324        ddl.uniques.push(UniqueConstraint::from_strings(
1325            table.pg_schema.clone(),
1326            table.name.clone(),
1327            format!("{}_{}_unique", table.name, uc),
1328            vec![uc],
1329        ));
1330    }
1331}
1332
1333/// Append index definitions to the postgres DDL collection.
1334fn add_postgres_indexes(
1335    ddl: &mut drizzle_migrations::postgres::collection::PostgresDDL,
1336    indexes: &[IndexDef],
1337) {
1338    use drizzle_types::postgres::ddl::{Index, IndexColumn};
1339    for idx in indexes {
1340        let columns: Vec<IndexColumn> = idx
1341            .columns
1342            .iter()
1343            .map(|c| IndexColumn::new(c.clone()))
1344            .collect();
1345        let mut index = Index::new(
1346            idx.pg_schema.clone(),
1347            idx.table.clone(),
1348            idx.name.clone(),
1349            columns,
1350        );
1351        if idx.unique {
1352            index = index.unique();
1353        }
1354        ddl.indexes.push(index);
1355    }
1356}
1357
1358/// Append foreign-key definitions to the postgres DDL collection.
1359fn add_postgres_foreign_keys(
1360    ddl: &mut drizzle_migrations::postgres::collection::PostgresDDL,
1361    fks: &[ForeignKeyDef],
1362) {
1363    use drizzle_types::postgres::ddl::ForeignKey;
1364    for fk in fks {
1365        let mut foreign_key = ForeignKey::from_strings(
1366            fk.pg_schema.clone(),
1367            fk.table.clone(),
1368            fk.name.clone(),
1369            fk.columns.clone(),
1370            fk.pg_schema_to.clone(),
1371            fk.table_to.clone(),
1372            fk.columns_to.clone(),
1373        );
1374        if fk.on_delete != "No Action" {
1375            foreign_key = foreign_key.on_delete(fk.on_delete.to_uppercase());
1376        }
1377        if fk.on_update != "No Action" {
1378            foreign_key = foreign_key.on_update(fk.on_update.to_uppercase());
1379        }
1380        ddl.fks.push(foreign_key);
1381    }
1382}
1383
1384// ── Utility helpers ─────────────────────────────────────────────────────────
1385
1386fn confirm(message: &str, default: bool) -> Result<bool, CliError> {
1387    Confirm::new(message)
1388        .with_default(default)
1389        .prompt()
1390        .map_err(|e| CliError::Other(format!("Prompt cancelled: {e}")))
1391}
1392
/// Return `true` when `s` is a plain ASCII identifier.
///
/// The first character must be an ASCII letter or `_`. Every following
/// character must be an ASCII letter, digit or `_`. The empty string is
/// rejected.
fn is_valid_identifier(s: &str) -> bool {
    let mut chars = s.chars();
    match chars.next() {
        Some(head) if head == '_' || head.is_ascii_alphabetic() => {
            chars.all(|c| c == '_' || c.is_ascii_alphanumeric())
        }
        _ => false,
    }
}
1404
1405// ── Tests ───────────────────────────────────────────────────────────────────
1406
#[cfg(test)]
mod tests {
    use super::*;

    // Smallest valid schema: one SQLite table with a single integer PK column.
    fn minimal_sqlite_def() -> SchemaDefinition {
        SchemaDefinition {
            dialect: Dialect::Sqlite,
            casing: FieldCasing::Snake,
            schema_name: "TestSchema".into(),
            output_path: "src/schema.rs".into(),
            enums: vec![],
            tables: vec![TableDef {
                name: "users".into(),
                columns: vec![ColumnDef {
                    name: "id".into(),
                    sql_type: "integer".into(),
                    not_null: true,
                    primary_key: true,
                    unique: false,
                    default: None,
                    auto_gen: None,
                    enum_name: None,
                }],
                strict: false,
                without_rowid: false,
                pg_schema: String::new(),
            }],
            indexes: vec![],
            foreign_keys: vec![],
        }
    }

    #[test]
    fn validate_minimal_schema() {
        let def = minimal_sqlite_def();
        assert!(validate_schema(&def).is_ok());
    }

    #[test]
    fn validate_rejects_empty_tables() {
        let mut def = minimal_sqlite_def();
        def.tables.clear();
        let err = validate_schema(&def).unwrap_err();
        assert!(err.to_string().contains("at least one table"));
    }

    #[test]
    fn validate_rejects_duplicate_table_names() {
        let mut def = minimal_sqlite_def();
        def.tables.push(def.tables[0].clone());
        let err = validate_schema(&def).unwrap_err();
        assert!(err.to_string().contains("Duplicate table name"));
    }

    #[test]
    fn validate_rejects_empty_columns() {
        let mut def = minimal_sqlite_def();
        def.tables[0].columns.clear();
        let err = validate_schema(&def).unwrap_err();
        assert!(err.to_string().contains("at least one column"));
    }

    #[test]
    fn validate_rejects_duplicate_column_names() {
        let mut def = minimal_sqlite_def();
        let dup = def.tables[0].columns[0].clone();
        def.tables[0].columns.push(dup);
        let err = validate_schema(&def).unwrap_err();
        assert!(err.to_string().contains("Duplicate column name"));
    }

    #[test]
    fn validate_rejects_identity_on_sqlite() {
        let mut def = minimal_sqlite_def();
        def.tables[0].columns[0].auto_gen = Some(AutoGenKind::Identity);
        let err = validate_schema(&def).unwrap_err();
        assert!(err.to_string().contains("identity"));
        assert!(err.to_string().contains("PostgreSQL"));
    }

    #[test]
    fn validate_rejects_autoincrement_on_postgres() {
        let mut def = minimal_sqlite_def();
        def.dialect = Dialect::Postgresql;
        def.tables[0].columns[0].auto_gen = Some(AutoGenKind::Autoincrement);
        let err = validate_schema(&def).unwrap_err();
        assert!(err.to_string().contains("autoincrement"));
        assert!(err.to_string().contains("SQLite"));
    }

    #[test]
    fn validate_rejects_strict_on_postgres() {
        let mut def = minimal_sqlite_def();
        def.dialect = Dialect::Postgresql;
        def.tables[0].strict = true;
        let err = validate_schema(&def).unwrap_err();
        assert!(err.to_string().contains("strict"));
        assert!(err.to_string().contains("SQLite"));
    }

    #[test]
    fn validate_rejects_enums_on_sqlite() {
        let mut def = minimal_sqlite_def();
        def.enums.push(EnumDef {
            name: "status".into(),
            variants: vec!["active".into()],
        });
        let err = validate_schema(&def).unwrap_err();
        assert!(err.to_string().contains("Enums"));
        assert!(err.to_string().contains("PostgreSQL"));
    }

    #[test]
    fn validate_rejects_unknown_enum_reference() {
        let mut def = minimal_sqlite_def();
        def.dialect = Dialect::Postgresql;
        def.tables[0].columns[0].enum_name = Some("nonexistent".into());
        let err = validate_schema(&def).unwrap_err();
        assert!(err.to_string().contains("unknown enum"));
    }

    #[test]
    fn validate_rejects_bad_fk_table_ref() {
        let mut def = minimal_sqlite_def();
        def.foreign_keys.push(ForeignKeyDef {
            name: "test_fk".into(),
            table: "nonexistent".into(),
            columns: vec!["id".into()],
            table_to: "users".into(),
            columns_to: vec!["id".into()],
            on_delete: "No Action".into(),
            on_update: "No Action".into(),
            pg_schema: String::new(),
            pg_schema_to: String::new(),
        });
        let err = validate_schema(&def).unwrap_err();
        assert!(err.to_string().contains("unknown source table"));
    }

    #[test]
    fn validate_rejects_bad_fk_action() {
        let mut def = minimal_sqlite_def();
        def.tables.push(TableDef {
            name: "posts".into(),
            columns: vec![ColumnDef {
                name: "user_id".into(),
                sql_type: "integer".into(),
                not_null: true,
                primary_key: false,
                unique: false,
                default: None,
                auto_gen: None,
                enum_name: None,
            }],
            strict: false,
            without_rowid: false,
            pg_schema: String::new(),
        });
        def.foreign_keys.push(ForeignKeyDef {
            name: "posts_user_id_fk".into(),
            table: "posts".into(),
            columns: vec!["user_id".into()],
            table_to: "users".into(),
            columns_to: vec!["id".into()],
            on_delete: "INVALID".into(),
            on_update: "No Action".into(),
            pg_schema: String::new(),
            pg_schema_to: String::new(),
        });
        let err = validate_schema(&def).unwrap_err();
        assert!(err.to_string().contains("invalid on_delete"));
    }

    #[test]
    fn validate_rejects_bad_index_column_ref() {
        let mut def = minimal_sqlite_def();
        def.indexes.push(IndexDef {
            name: "test_idx".into(),
            table: "users".into(),
            columns: vec!["nonexistent".into()],
            unique: false,
            pg_schema: String::new(),
        });
        let err = validate_schema(&def).unwrap_err();
        assert!(err.to_string().contains("unknown column"));
    }

    #[test]
    fn json_round_trip() {
        let def = minimal_sqlite_def();
        let json = serde_json::to_string_pretty(&def).unwrap();
        let parsed: SchemaDefinition = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.dialect, def.dialect);
        assert_eq!(parsed.tables.len(), 1);
        assert_eq!(parsed.tables[0].name, "users");
        assert_eq!(parsed.tables[0].columns[0].name, "id");
    }

    #[test]
    fn json_defaults_applied() {
        let json = r#"{
            "dialect": "sqlite",
            "tables": [{
                "name": "items",
                "columns": [{"name": "id", "sql_type": "integer"}]
            }]
        }"#;
        let def: SchemaDefinition = serde_json::from_str(json).unwrap();
        assert_eq!(def.schema_name, "AppSchema");
        assert_eq!(def.output_path, "src/schema.rs");
        assert!(def.enums.is_empty());
        assert!(def.indexes.is_empty());
        assert!(def.foreign_keys.is_empty());
        assert!(!def.tables[0].columns[0].not_null);
        assert!(!def.tables[0].columns[0].primary_key);
    }

    #[test]
    fn json_fk_action_defaults() {
        let json = r#"{
            "name": "test_fk",
            "table": "a",
            "columns": ["x"],
            "table_to": "b",
            "columns_to": ["y"]
        }"#;
        let fk: ForeignKeyDef = serde_json::from_str(json).unwrap();
        assert_eq!(fk.on_delete, "No Action");
        assert_eq!(fk.on_update, "No Action");
    }

    // `is_valid_identifier` accepts ASCII `[A-Za-z_][A-Za-z0-9_]*` only.
    #[test]
    fn identifier_validation_accepts_and_rejects() {
        assert!(is_valid_identifier("users"));
        assert!(is_valid_identifier("_private"));
        assert!(is_valid_identifier("user_id2"));
        assert!(!is_valid_identifier(""));
        assert!(!is_valid_identifier("2users"));
        assert!(!is_valid_identifier("user-id"));
        assert!(!is_valid_identifier("user id"));
        assert!(!is_valid_identifier("caf\u{e9}"));
    }
}