prax_cli/commands/
migrate.rs

1//! `prax migrate` commands - Database migration management.
2
3use std::path::PathBuf;
4
5use crate::cli::MigrateArgs;
6use crate::commands::seed::{SeedRunner, find_seed_file, get_database_url};
7use crate::config::{CONFIG_FILE_NAME, Config, MIGRATIONS_DIR, SCHEMA_FILE_NAME};
8use crate::error::{CliError, CliResult};
9use crate::output::{self, success, warn};
10
11/// Run the migrate command
12pub async fn run(args: MigrateArgs) -> CliResult<()> {
13    match args.command {
14        crate::cli::MigrateSubcommand::Dev(dev_args) => run_dev(dev_args).await,
15        crate::cli::MigrateSubcommand::Deploy => run_deploy().await,
16        crate::cli::MigrateSubcommand::Reset(reset_args) => run_reset(reset_args).await,
17        crate::cli::MigrateSubcommand::Status => run_status().await,
18        crate::cli::MigrateSubcommand::Resolve(resolve_args) => run_resolve(resolve_args).await,
19        crate::cli::MigrateSubcommand::Diff(diff_args) => run_diff(diff_args).await,
20    }
21}
22
23/// Run `prax migrate dev` - development migration workflow
24async fn run_dev(args: crate::cli::MigrateDevArgs) -> CliResult<()> {
25    output::header("Migrate Dev");
26
27    let cwd = std::env::current_dir()?;
28    let config = load_config(&cwd)?;
29
30    let schema_path = args
31        .schema
32        .clone()
33        .unwrap_or_else(|| cwd.join(SCHEMA_FILE_NAME));
34    let migrations_dir = cwd.join(MIGRATIONS_DIR);
35
36    output::kv("Schema", &schema_path.display().to_string());
37    output::kv("Migrations", &migrations_dir.display().to_string());
38    output::newline();
39
40    // Determine total steps (5 or 6 depending on seed)
41    let total_steps = if args.skip_seed { 5 } else { 6 };
42
43    // 1. Parse and validate schema
44    output::step(1, total_steps, "Parsing schema...");
45    let schema_content = std::fs::read_to_string(&schema_path)?;
46    let schema = parse_schema(&schema_content)?;
47
48    // 2. Check for pending migrations
49    output::step(2, total_steps, "Checking migration status...");
50    let pending = check_pending_migrations(&migrations_dir)?;
51
52    if !pending.is_empty() {
53        output::list(&format!("{} pending migrations found:", pending.len()));
54        for migration in &pending {
55            output::list_item(&migration.display().to_string());
56        }
57        output::newline();
58    }
59
60    // 3. Diff schema against database
61    output::step(3, total_steps, "Comparing schema to database...");
62    let migration_name = args
63        .name
64        .unwrap_or_else(|| format!("migration_{}", chrono::Utc::now().format("%Y%m%d%H%M%S")));
65
66    // 4. Generate migration
67    output::step(4, total_steps, "Generating migration...");
68    let migration_path = create_migration(&migrations_dir, &migration_name, &schema)?;
69
70    // 5. Apply migration (if not --create-only)
71    if !args.create_only {
72        output::step(5, total_steps, "Applying migration...");
73        apply_migration(&migration_path, &config).await?;
74    } else {
75        output::step(5, total_steps, "Skipping apply (--create-only)...");
76    }
77
78    // 6. Run seed (if not --skip-seed)
79    if !args.skip_seed && !args.create_only {
80        output::step(6, total_steps, "Running seed...");
81
82        if let Some(seed_path) = find_seed_file(&cwd, &config) {
83            let database_url = get_database_url(&config)?;
84            let runner = SeedRunner::new(
85                seed_path,
86                database_url,
87                config.database.provider.clone(),
88                cwd.clone(),
89            )?;
90
91            match runner.run().await {
92                Ok(result) => {
93                    output::list_item(&format!("Seeded {} records", result.records_affected));
94                }
95                Err(e) => {
96                    output::warn(&format!("Seed failed: {}. Continuing...", e));
97                }
98            }
99        } else {
100            output::list_item("No seed file found, skipping");
101        }
102    }
103
104    output::newline();
105    success(&format!("Migration '{}' created", migration_name));
106
107    output::newline();
108    output::section("Next steps");
109    output::list_item("Review the generated migration SQL");
110    output::list_item("Run `prax generate` to update your client");
111
112    Ok(())
113}
114
115/// Run `prax migrate deploy` - production deployment
116async fn run_deploy() -> CliResult<()> {
117    output::header("Migrate Deploy");
118
119    let cwd = std::env::current_dir()?;
120    let config = load_config(&cwd)?;
121    let migrations_dir = cwd.join(MIGRATIONS_DIR);
122
123    output::kv("Migrations", &migrations_dir.display().to_string());
124    output::newline();
125
126    // Check for pending migrations
127    output::step(1, 3, "Checking for pending migrations...");
128    let pending = check_pending_migrations(&migrations_dir)?;
129
130    if pending.is_empty() {
131        output::newline();
132        success("No pending migrations to apply.");
133        return Ok(());
134    }
135
136    output::list(&format!("{} pending migrations:", pending.len()));
137    for migration in &pending {
138        output::list_item(&migration.file_name().unwrap().to_string_lossy());
139    }
140    output::newline();
141
142    // Apply migrations
143    output::step(2, 3, "Applying migrations...");
144    for migration in &pending {
145        output::list_item(&format!(
146            "Applying {}",
147            migration.file_name().unwrap().to_string_lossy()
148        ));
149        apply_migration(migration, &config).await?;
150    }
151
152    // Verify
153    output::step(3, 3, "Verifying migrations...");
154
155    output::newline();
156    success(&format!(
157        "Applied {} migrations successfully!",
158        pending.len()
159    ));
160
161    Ok(())
162}
163
164/// Run `prax migrate reset` - reset database
165async fn run_reset(args: crate::cli::MigrateResetArgs) -> CliResult<()> {
166    output::header("Migrate Reset");
167
168    let cwd = std::env::current_dir()?;
169    let config = load_config(&cwd)?;
170
171    if !args.force {
172        warn("This will delete all data in the database!");
173        output::newline();
174        if !output::confirm("Are you sure you want to reset the database?") {
175            output::newline();
176            output::info("Reset cancelled.");
177            return Ok(());
178        }
179    }
180
181    output::newline();
182    output::step(1, 4, "Dropping database...");
183    // TODO: Implement database drop
184
185    output::step(2, 4, "Creating database...");
186    // TODO: Implement database create
187
188    output::step(3, 4, "Applying migrations...");
189    let migrations_dir = cwd.join(MIGRATIONS_DIR);
190    let migrations = check_pending_migrations(&migrations_dir)?;
191
192    for migration in &migrations {
193        apply_migration(migration, &config).await?;
194    }
195
196    // Run seed if requested
197    if args.seed {
198        output::step(4, 4, "Running seed...");
199
200        // Find and run seed file
201        if let Some(seed_path) = find_seed_file(&cwd, &config) {
202            let database_url = get_database_url(&config)?;
203            let runner = SeedRunner::new(
204                seed_path,
205                database_url,
206                config.database.provider.clone(),
207                cwd,
208            )?;
209
210            let result = runner.run().await?;
211            output::list_item(&format!("Seeded {} records", result.records_affected));
212        } else {
213            output::list_item("No seed file found, skipping seed");
214        }
215    } else {
216        output::step(4, 4, "Skipping seed...");
217    }
218
219    output::newline();
220    success("Database reset complete!");
221
222    Ok(())
223}
224
225/// Run `prax migrate status` - show migration status
226async fn run_status() -> CliResult<()> {
227    output::header("Migration Status");
228
229    let cwd = std::env::current_dir()?;
230    let _config = load_config(&cwd)?;
231    let migrations_dir = cwd.join(MIGRATIONS_DIR);
232
233    // List all migrations
234    let mut migrations = Vec::new();
235    if migrations_dir.exists() {
236        for entry in std::fs::read_dir(&migrations_dir)? {
237            let entry = entry?;
238            let path = entry.path();
239            if path.is_dir() {
240                migrations.push(path);
241            }
242        }
243    }
244    migrations.sort();
245
246    if migrations.is_empty() {
247        output::info("No migrations found.");
248        output::newline();
249        output::section("Getting started");
250        output::list_item("Run `prax migrate dev` to create your first migration");
251        return Ok(());
252    }
253
254    output::section("Migrations");
255
256    for (i, migration) in migrations.iter().enumerate() {
257        let name = migration.file_name().unwrap().to_string_lossy();
258        let applied = is_migration_applied(migration)?;
259
260        let status = if applied {
261            output::style_success("✓ Applied")
262        } else {
263            output::style_pending("○ Pending")
264        };
265
266        output::numbered_item(i + 1, &format!("{} - {}", name, status));
267    }
268
269    output::newline();
270
271    let applied_count = migrations
272        .iter()
273        .filter(|m| is_migration_applied(m).unwrap_or(false))
274        .count();
275    let pending_count = migrations.len() - applied_count;
276
277    output::kv("Total", &migrations.len().to_string());
278    output::kv("Applied", &applied_count.to_string());
279    output::kv("Pending", &pending_count.to_string());
280
281    Ok(())
282}
283
284/// Run `prax migrate resolve` - resolve migration issues
285async fn run_resolve(args: crate::cli::MigrateResolveArgs) -> CliResult<()> {
286    output::header("Migrate Resolve");
287
288    if args.rolled_back {
289        output::step(1, 2, "Marking migration as rolled back...");
290        // TODO: Mark migration as rolled back in history table
291
292        output::step(2, 2, "Updating migration history...");
293
294        output::newline();
295        success(&format!(
296            "Migration '{}' marked as rolled back",
297            args.migration
298        ));
299    } else if args.applied {
300        output::step(1, 2, "Marking migration as applied...");
301        // TODO: Mark migration as applied in history table
302
303        output::step(2, 2, "Updating migration history...");
304
305        output::newline();
306        success(&format!("Migration '{}' marked as applied", args.migration));
307    } else {
308        return Err(
309            CliError::Command("Must specify --applied or --rolled-back".to_string()).into(),
310        );
311    }
312
313    Ok(())
314}
315
316/// Run `prax migrate diff` - generate migration diff without applying
317async fn run_diff(args: crate::cli::MigrateDiffArgs) -> CliResult<()> {
318    output::header("Migrate Diff");
319
320    let cwd = std::env::current_dir()?;
321    let schema_path = args.schema.unwrap_or_else(|| cwd.join(SCHEMA_FILE_NAME));
322
323    // Parse schema
324    output::step(1, 3, "Parsing schema...");
325    let schema_content = std::fs::read_to_string(&schema_path)?;
326    let schema = parse_schema(&schema_content)?;
327
328    // Get current database state
329    output::step(2, 3, "Introspecting database...");
330    // TODO: Implement database introspection
331
332    // Generate diff
333    output::step(3, 3, "Generating diff...");
334    let diff_sql = generate_schema_diff(&schema)?;
335
336    output::newline();
337
338    if diff_sql.is_empty() {
339        success("Schema is in sync with database - no changes needed");
340    } else {
341        output::section("Generated SQL");
342        output::code(&diff_sql, "sql");
343
344        if let Some(output_path) = args.output {
345            std::fs::write(&output_path, &diff_sql)?;
346            output::newline();
347            success(&format!("Diff written to {}", output_path.display()));
348        }
349    }
350
351    Ok(())
352}
353
354// =============================================================================
355// Helper Functions
356// =============================================================================
357
358fn load_config(cwd: &PathBuf) -> CliResult<Config> {
359    let config_path = cwd.join(CONFIG_FILE_NAME);
360    if config_path.exists() {
361        Config::load(&config_path)
362    } else {
363        Ok(Config::default())
364    }
365}
366
367fn parse_schema(content: &str) -> CliResult<prax_schema::Schema> {
368    prax_schema::parse_schema(content)
369        .map_err(|e| CliError::Schema(format!("Failed to parse schema: {}", e)))
370}
371
372fn check_pending_migrations(migrations_dir: &PathBuf) -> CliResult<Vec<PathBuf>> {
373    let mut pending = Vec::new();
374
375    if !migrations_dir.exists() {
376        return Ok(pending);
377    }
378
379    for entry in std::fs::read_dir(migrations_dir)? {
380        let entry = entry?;
381        let path = entry.path();
382        if path.is_dir() {
383            if !is_migration_applied(&path)? {
384                pending.push(path);
385            }
386        }
387    }
388
389    pending.sort();
390    Ok(pending)
391}
392
393fn is_migration_applied(migration_path: &PathBuf) -> CliResult<bool> {
394    // Check for a marker file indicating the migration has been applied
395    // In production, this would check the migration history table
396    let marker = migration_path.join(".applied");
397    Ok(marker.exists())
398}
399
400fn create_migration(
401    migrations_dir: &PathBuf,
402    name: &str,
403    schema: &prax_schema::ast::Schema,
404) -> CliResult<PathBuf> {
405    // Create migration directory
406    let timestamp = chrono::Utc::now().format("%Y%m%d%H%M%S");
407    let migration_name = format!("{}_{}", timestamp, name);
408    let migration_path = migrations_dir.join(&migration_name);
409
410    std::fs::create_dir_all(&migration_path)?;
411
412    // Generate migration SQL
413    let sql = generate_schema_diff(schema)?;
414
415    // Write migration.sql
416    let sql_path = migration_path.join("migration.sql");
417    std::fs::write(&sql_path, &sql)?;
418
419    Ok(migration_path)
420}
421
422fn generate_schema_diff(schema: &prax_schema::ast::Schema) -> CliResult<String> {
423    use prax_schema::ast::{FieldType, ScalarType};
424
425    let mut sql = String::new();
426
427    sql.push_str("-- Migration generated by Prax\n\n");
428
429    // Generate CREATE TABLE statements for each model
430    for model in schema.models.values() {
431        let table_name = model.table_name();
432
433        sql.push_str(&format!(
434            "CREATE TABLE IF NOT EXISTS \"{}\" (\n",
435            table_name
436        ));
437
438        let mut columns = Vec::new();
439        let mut primary_keys = Vec::new();
440
441        for field in model.fields.values() {
442            if field.is_relation() {
443                continue;
444            }
445
446            let column_name = field
447                .get_attribute("map")
448                .and_then(|a| a.first_arg())
449                .and_then(|v| v.as_string())
450                .map(|s| s.to_string())
451                .unwrap_or_else(|| to_snake_case(field.name()));
452
453            let sql_type = field_type_to_sql(&field.field_type);
454            let mut column_def = format!("    \"{}\" {}", column_name, sql_type);
455
456            // Add constraints
457            if field.is_id() {
458                primary_keys.push(column_name.clone());
459            }
460
461            if field.has_attribute("auto") || field.has_attribute("autoincrement") {
462                // PostgreSQL uses SERIAL types
463                column_def = format!("    \"{}\" SERIAL", column_name);
464            }
465
466            if field.has_attribute("unique") {
467                column_def.push_str(" UNIQUE");
468            }
469
470            if !field.is_optional() && !field.is_id() {
471                column_def.push_str(" NOT NULL");
472            }
473
474            // Default values
475            if let Some(default_attr) = field.get_attribute("default") {
476                if let Some(value) = default_attr.first_arg() {
477                    let value_str = format_attribute_value(value);
478                    column_def.push_str(&format!(" DEFAULT {}", sql_default_value(&value_str)));
479                }
480            }
481
482            columns.push(column_def);
483        }
484
485        sql.push_str(&columns.join(",\n"));
486
487        if !primary_keys.is_empty() {
488            sql.push_str(",\n");
489            sql.push_str(&format!(
490                "    PRIMARY KEY (\"{}\")",
491                primary_keys.join("\", \"")
492            ));
493        }
494
495        sql.push_str("\n);\n\n");
496        sql.push_str("\n");
497    }
498
499    // Generate enums
500    for enum_def in schema.enums.values() {
501        let enum_name = enum_def
502            .attributes
503            .iter()
504            .find(|a| a.is("map"))
505            .and_then(|a: &prax_schema::ast::Attribute| a.first_arg())
506            .and_then(|v: &prax_schema::ast::AttributeValue| v.as_string())
507            .map(|s| s.to_string())
508            .unwrap_or_else(|| to_snake_case(enum_def.name()));
509
510        sql.push_str(&format!(
511            "DO $$ BEGIN\n    CREATE TYPE \"{}\" AS ENUM (",
512            enum_name
513        ));
514
515        let variants: Vec<String> = enum_def
516            .variants
517            .iter()
518            .map(|v| format!("'{}'", v.name()))
519            .collect();
520
521        sql.push_str(&variants.join(", "));
522        sql.push_str(");\nEXCEPTION\n    WHEN duplicate_object THEN null;\nEND $$;\n\n");
523    }
524
525    return Ok(sql);
526
527    fn field_type_to_sql(field_type: &FieldType) -> String {
528        match field_type {
529            FieldType::Scalar(scalar) => match scalar {
530                ScalarType::Int => "INTEGER".to_string(),
531                ScalarType::BigInt => "BIGINT".to_string(),
532                ScalarType::Float => "DOUBLE PRECISION".to_string(),
533                ScalarType::String => "TEXT".to_string(),
534                ScalarType::Boolean => "BOOLEAN".to_string(),
535                ScalarType::DateTime => "TIMESTAMP WITH TIME ZONE".to_string(),
536                ScalarType::Date => "DATE".to_string(),
537                ScalarType::Time => "TIME".to_string(),
538                ScalarType::Json => "JSONB".to_string(),
539                ScalarType::Bytes => "BYTEA".to_string(),
540                ScalarType::Decimal => "DECIMAL".to_string(),
541                ScalarType::Uuid => "UUID".to_string(),
542                ScalarType::Cuid | ScalarType::Cuid2 | ScalarType::NanoId | ScalarType::Ulid => {
543                    "TEXT".to_string()
544                }
545                ScalarType::Vector(dim) => match dim {
546                    Some(d) => format!("vector({})", d),
547                    None => "vector".to_string(),
548                },
549                ScalarType::HalfVector(dim) => match dim {
550                    Some(d) => format!("halfvec({})", d),
551                    None => "halfvec".to_string(),
552                },
553                ScalarType::SparseVector(dim) => match dim {
554                    Some(d) => format!("sparsevec({})", d),
555                    None => "sparsevec".to_string(),
556                },
557                ScalarType::Bit(dim) => match dim {
558                    Some(d) => format!("bit({})", d),
559                    None => "bit".to_string(),
560                },
561            },
562            FieldType::Enum(name) => format!("\"{}\"", to_snake_case(name)),
563            _ => "TEXT".to_string(),
564        }
565    }
566}
567
568async fn apply_migration(migration_path: &PathBuf, _config: &Config) -> CliResult<()> {
569    let sql_path = migration_path.join("migration.sql");
570
571    if !sql_path.exists() {
572        return Err(CliError::Migration(format!(
573            "Migration file not found: {}",
574            sql_path.display()
575        )));
576    }
577
578    let _sql = std::fs::read_to_string(&sql_path)?;
579
580    // TODO: Execute SQL against database
581    // This would use the database URL from config
582
583    // Mark as applied
584    let marker = migration_path.join(".applied");
585    std::fs::write(&marker, chrono::Utc::now().to_rfc3339())?;
586
587    Ok(())
588}
589
/// Translates a rendered schema default (such as `now()` or `true`) into PostgreSQL
/// syntax.
///
/// Matching is case-insensitive. Unrecognized values are returned unchanged, keeping
/// their original casing.
///
/// `cuid()`, `cuid2()`, `nanoid()` and `ulid()` become an empty string literal. Their
/// values have to be generated by the application, not the database.
fn sql_default_value(value: &str) -> String {
    let translated = match value.to_lowercase().as_str() {
        "now()" => Some("CURRENT_TIMESTAMP"),
        "uuid()" => Some("gen_random_uuid()"),
        "cuid()" | "cuid2()" | "nanoid()" | "ulid()" => Some("''"),
        "true" => Some("TRUE"),
        "false" => Some("FALSE"),
        _ => None,
    };
    translated.map_or_else(|| value.to_owned(), String::from)
}
603
/// Converts a camel-case / Pascal-case identifier to snake_case.
///
/// An underscore is inserted before every uppercase character except the first, and
/// that character is lowercased. Consecutive capitals therefore each get their own
/// segment: `"HTTPServer"` becomes `"h_t_t_p_server"`. All other characters are kept
/// as-is.
fn to_snake_case(name: &str) -> String {
    let mut result = String::with_capacity(name.len());
    for (i, c) in name.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // `to_lowercase` can yield more than one char (e.g. 'İ' -> "i\u{307}");
            // keep all of them rather than just the first.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
618
619fn format_attribute_value(value: &prax_schema::ast::AttributeValue) -> String {
620    use prax_schema::ast::AttributeValue;
621
622    match value {
623        AttributeValue::String(s) => format!("\"{}\"", s),
624        AttributeValue::Int(i) => i.to_string(),
625        AttributeValue::Float(f) => f.to_string(),
626        AttributeValue::Boolean(b) => b.to_string(),
627        AttributeValue::Ident(id) => id.to_string(),
628        AttributeValue::Function(name, args) => {
629            if args.is_empty() {
630                format!("{}()", name)
631            } else {
632                let arg_strs: Vec<String> = args.iter().map(format_attribute_value).collect();
633                format!("{}({})", name, arg_strs.join(", "))
634            }
635        }
636        AttributeValue::Array(items) => {
637            let item_strs: Vec<String> = items.iter().map(format_attribute_value).collect();
638            format!("[{}]", item_strs.join(", "))
639        }
640        AttributeValue::FieldRef(field) => field.to_string(),
641        AttributeValue::FieldRefList(fields) => {
642            format!(
643                "[{}]",
644                fields
645                    .iter()
646                    .map(|f| f.to_string())
647                    .collect::<Vec<_>>()
648                    .join(", ")
649            )
650        }
651    }
652}