Skip to main content

prax_cli/commands/
migrate.rs

1//! `prax migrate` commands - Database migration management.
2
3use std::path::PathBuf;
4
5use crate::cli::MigrateArgs;
6use crate::commands::seed::{SeedRunner, find_seed_file, get_database_url};
7use crate::config::{CONFIG_FILE_NAME, Config, MIGRATIONS_DIR, SCHEMA_FILE_PATH};
8use crate::error::{CliError, CliResult};
9use crate::output::{self, success, warn};
10
11/// Run the migrate command
12pub async fn run(args: MigrateArgs) -> CliResult<()> {
13    match args.command {
14        crate::cli::MigrateSubcommand::Dev(dev_args) => run_dev(dev_args).await,
15        crate::cli::MigrateSubcommand::Deploy => run_deploy().await,
16        crate::cli::MigrateSubcommand::Reset(reset_args) => run_reset(reset_args).await,
17        crate::cli::MigrateSubcommand::Status => run_status().await,
18        crate::cli::MigrateSubcommand::Resolve(resolve_args) => run_resolve(resolve_args).await,
19        crate::cli::MigrateSubcommand::Diff(diff_args) => run_diff(diff_args).await,
20    }
21}
22
23/// Run `prax migrate dev` - development migration workflow
24async fn run_dev(args: crate::cli::MigrateDevArgs) -> CliResult<()> {
25    output::header("Migrate Dev");
26
27    let cwd = std::env::current_dir()?;
28    let config = load_config(&cwd)?;
29
30    let schema_path = args
31        .schema
32        .clone()
33        .unwrap_or_else(|| cwd.join(SCHEMA_FILE_PATH));
34    let migrations_dir = cwd.join(MIGRATIONS_DIR);
35
36    output::kv("Schema", &schema_path.display().to_string());
37    output::kv("Migrations", &migrations_dir.display().to_string());
38    output::newline();
39
40    // Determine total steps (5 or 6 depending on seed)
41    let total_steps = if args.skip_seed { 5 } else { 6 };
42
43    // 1. Parse and validate schema
44    output::step(1, total_steps, "Parsing schema...");
45    let schema_content = std::fs::read_to_string(&schema_path)?;
46    let schema = parse_schema(&schema_content)?;
47
48    // 2. Check for pending migrations
49    output::step(2, total_steps, "Checking migration status...");
50    let pending = check_pending_migrations(&migrations_dir)?;
51
52    if !pending.is_empty() {
53        output::list(&format!("{} pending migrations found:", pending.len()));
54        for migration in &pending {
55            output::list_item(&migration.display().to_string());
56        }
57        output::newline();
58    }
59
60    // 3. Diff schema against database
61    output::step(3, total_steps, "Comparing schema to database...");
62    let migration_name = args
63        .name
64        .unwrap_or_else(|| format!("migration_{}", chrono::Utc::now().format("%Y%m%d%H%M%S")));
65
66    // 4. Generate migration
67    output::step(4, total_steps, "Generating migration...");
68    let migration_path = create_migration(&migrations_dir, &migration_name, &schema)?;
69
70    // 5. Apply migration (if not --create-only)
71    if !args.create_only {
72        output::step(5, total_steps, "Applying migration...");
73        apply_migration(&migration_path, &config).await?;
74    } else {
75        output::step(5, total_steps, "Skipping apply (--create-only)...");
76    }
77
78    // 6. Run seed (if not --skip-seed)
79    if !args.skip_seed && !args.create_only {
80        output::step(6, total_steps, "Running seed...");
81
82        if let Some(seed_path) = find_seed_file(&cwd, &config) {
83            let database_url = get_database_url(&config)?;
84            let runner = SeedRunner::new(
85                seed_path,
86                database_url,
87                config.database.provider.clone(),
88                cwd.clone(),
89            )?;
90
91            match runner.run().await {
92                Ok(result) => {
93                    output::list_item(&format!("Seeded {} records", result.records_affected));
94                }
95                Err(e) => {
96                    output::warn(&format!("Seed failed: {}. Continuing...", e));
97                }
98            }
99        } else {
100            output::list_item("No seed file found, skipping");
101        }
102    }
103
104    output::newline();
105    success(&format!("Migration '{}' created", migration_name));
106
107    output::newline();
108    output::section("Next steps");
109    output::list_item("Review the generated migration SQL");
110    output::list_item("Run `prax generate` to update your client");
111
112    Ok(())
113}
114
115/// Run `prax migrate deploy` - production deployment
116async fn run_deploy() -> CliResult<()> {
117    output::header("Migrate Deploy");
118
119    let cwd = std::env::current_dir()?;
120    let config = load_config(&cwd)?;
121    let migrations_dir = cwd.join(MIGRATIONS_DIR);
122
123    output::kv("Migrations", &migrations_dir.display().to_string());
124    output::newline();
125
126    // Check for pending migrations
127    output::step(1, 3, "Checking for pending migrations...");
128    let pending = check_pending_migrations(&migrations_dir)?;
129
130    if pending.is_empty() {
131        output::newline();
132        success("No pending migrations to apply.");
133        return Ok(());
134    }
135
136    output::list(&format!("{} pending migrations:", pending.len()));
137    for migration in &pending {
138        output::list_item(&migration.file_name().unwrap().to_string_lossy());
139    }
140    output::newline();
141
142    // Apply migrations
143    output::step(2, 3, "Applying migrations...");
144    for migration in &pending {
145        output::list_item(&format!(
146            "Applying {}",
147            migration.file_name().unwrap().to_string_lossy()
148        ));
149        apply_migration(migration, &config).await?;
150    }
151
152    // Verify
153    output::step(3, 3, "Verifying migrations...");
154
155    output::newline();
156    success(&format!(
157        "Applied {} migrations successfully!",
158        pending.len()
159    ));
160
161    Ok(())
162}
163
164/// Run `prax migrate reset` - reset database
165async fn run_reset(args: crate::cli::MigrateResetArgs) -> CliResult<()> {
166    output::header("Migrate Reset");
167
168    let cwd = std::env::current_dir()?;
169    let config = load_config(&cwd)?;
170
171    if !args.force {
172        warn("This will delete all data in the database!");
173        output::newline();
174        if !output::confirm("Are you sure you want to reset the database?") {
175            output::newline();
176            output::info("Reset cancelled.");
177            return Ok(());
178        }
179    }
180
181    output::newline();
182    output::step(1, 4, "Dropping database...");
183    // TODO: Implement database drop
184
185    output::step(2, 4, "Creating database...");
186    // TODO: Implement database create
187
188    output::step(3, 4, "Applying migrations...");
189    let migrations_dir = cwd.join(MIGRATIONS_DIR);
190    let migrations = check_pending_migrations(&migrations_dir)?;
191
192    for migration in &migrations {
193        apply_migration(migration, &config).await?;
194    }
195
196    // Run seed if requested
197    if args.seed {
198        output::step(4, 4, "Running seed...");
199
200        // Find and run seed file
201        if let Some(seed_path) = find_seed_file(&cwd, &config) {
202            let database_url = get_database_url(&config)?;
203            let runner = SeedRunner::new(
204                seed_path,
205                database_url,
206                config.database.provider.clone(),
207                cwd,
208            )?;
209
210            let result = runner.run().await?;
211            output::list_item(&format!("Seeded {} records", result.records_affected));
212        } else {
213            output::list_item("No seed file found, skipping seed");
214        }
215    } else {
216        output::step(4, 4, "Skipping seed...");
217    }
218
219    output::newline();
220    success("Database reset complete!");
221
222    Ok(())
223}
224
225/// Run `prax migrate status` - show migration status
226async fn run_status() -> CliResult<()> {
227    output::header("Migration Status");
228
229    let cwd = std::env::current_dir()?;
230    let _config = load_config(&cwd)?;
231    let migrations_dir = cwd.join(MIGRATIONS_DIR);
232
233    // List all migrations
234    let mut migrations = Vec::new();
235    if migrations_dir.exists() {
236        for entry in std::fs::read_dir(&migrations_dir)? {
237            let entry = entry?;
238            let path = entry.path();
239            if path.is_dir() {
240                migrations.push(path);
241            }
242        }
243    }
244    migrations.sort();
245
246    if migrations.is_empty() {
247        output::info("No migrations found.");
248        output::newline();
249        output::section("Getting started");
250        output::list_item("Run `prax migrate dev` to create your first migration");
251        return Ok(());
252    }
253
254    output::section("Migrations");
255
256    for (i, migration) in migrations.iter().enumerate() {
257        let name = migration.file_name().unwrap().to_string_lossy();
258        let applied = is_migration_applied(migration)?;
259
260        let status = if applied {
261            output::style_success("✓ Applied")
262        } else {
263            output::style_pending("○ Pending")
264        };
265
266        output::numbered_item(i + 1, &format!("{} - {}", name, status));
267    }
268
269    output::newline();
270
271    let applied_count = migrations
272        .iter()
273        .filter(|m| is_migration_applied(m).unwrap_or(false))
274        .count();
275    let pending_count = migrations.len() - applied_count;
276
277    output::kv("Total", &migrations.len().to_string());
278    output::kv("Applied", &applied_count.to_string());
279    output::kv("Pending", &pending_count.to_string());
280
281    Ok(())
282}
283
284/// Run `prax migrate resolve` - resolve migration issues
285async fn run_resolve(args: crate::cli::MigrateResolveArgs) -> CliResult<()> {
286    output::header("Migrate Resolve");
287
288    if args.rolled_back {
289        output::step(1, 2, "Marking migration as rolled back...");
290        // TODO: Mark migration as rolled back in history table
291
292        output::step(2, 2, "Updating migration history...");
293
294        output::newline();
295        success(&format!(
296            "Migration '{}' marked as rolled back",
297            args.migration
298        ));
299    } else if args.applied {
300        output::step(1, 2, "Marking migration as applied...");
301        // TODO: Mark migration as applied in history table
302
303        output::step(2, 2, "Updating migration history...");
304
305        output::newline();
306        success(&format!("Migration '{}' marked as applied", args.migration));
307    } else {
308        return Err(
309            CliError::Command("Must specify --applied or --rolled-back".to_string()).into(),
310        );
311    }
312
313    Ok(())
314}
315
316/// Run `prax migrate diff` - generate migration diff without applying
317async fn run_diff(args: crate::cli::MigrateDiffArgs) -> CliResult<()> {
318    output::header("Migrate Diff");
319
320    let cwd = std::env::current_dir()?;
321    let schema_path = args.schema.unwrap_or_else(|| cwd.join(SCHEMA_FILE_PATH));
322
323    // Parse schema
324    output::step(1, 3, "Parsing schema...");
325    let schema_content = std::fs::read_to_string(&schema_path)?;
326    let schema = parse_schema(&schema_content)?;
327
328    // Get current database state
329    output::step(2, 3, "Introspecting database...");
330    // TODO: Implement database introspection
331
332    // Generate diff
333    output::step(3, 3, "Generating diff...");
334    let diff_sql = generate_schema_diff(&schema)?;
335
336    output::newline();
337
338    if diff_sql.is_empty() {
339        success("Schema is in sync with database - no changes needed");
340    } else {
341        output::section("Generated SQL");
342        output::code(&diff_sql, "sql");
343
344        if let Some(output_path) = args.output {
345            std::fs::write(&output_path, &diff_sql)?;
346            output::newline();
347            success(&format!("Diff written to {}", output_path.display()));
348        }
349    }
350
351    Ok(())
352}
353
354// =============================================================================
355// Helper Functions
356// =============================================================================
357
358fn load_config(cwd: &PathBuf) -> CliResult<Config> {
359    let config_path = cwd.join(CONFIG_FILE_NAME);
360    if config_path.exists() {
361        Config::load(&config_path)
362    } else {
363        Ok(Config::default())
364    }
365}
366
367fn parse_schema(content: &str) -> CliResult<prax_schema::Schema> {
368    // Use validate_schema to ensure field types are properly resolved
369    // (e.g., FieldType::Model -> FieldType::Enum for enum references)
370    prax_schema::validate_schema(content)
371        .map_err(|e| CliError::Schema(format!("Failed to parse/validate schema: {}", e)))
372}
373
374fn check_pending_migrations(migrations_dir: &PathBuf) -> CliResult<Vec<PathBuf>> {
375    let mut pending = Vec::new();
376
377    if !migrations_dir.exists() {
378        return Ok(pending);
379    }
380
381    for entry in std::fs::read_dir(migrations_dir)? {
382        let entry = entry?;
383        let path = entry.path();
384        if path.is_dir() {
385            if !is_migration_applied(&path)? {
386                pending.push(path);
387            }
388        }
389    }
390
391    pending.sort();
392    Ok(pending)
393}
394
395fn is_migration_applied(migration_path: &PathBuf) -> CliResult<bool> {
396    // Check for a marker file indicating the migration has been applied
397    // In production, this would check the migration history table
398    let marker = migration_path.join(".applied");
399    Ok(marker.exists())
400}
401
402fn create_migration(
403    migrations_dir: &PathBuf,
404    name: &str,
405    schema: &prax_schema::ast::Schema,
406) -> CliResult<PathBuf> {
407    // Create migration directory
408    let timestamp = chrono::Utc::now().format("%Y%m%d%H%M%S");
409    let migration_name = format!("{}_{}", timestamp, name);
410    let migration_path = migrations_dir.join(&migration_name);
411
412    std::fs::create_dir_all(&migration_path)?;
413
414    // Generate migration SQL
415    let sql = generate_schema_diff(schema)?;
416
417    // Write migration.sql
418    let sql_path = migration_path.join("migration.sql");
419    std::fs::write(&sql_path, &sql)?;
420
421    Ok(migration_path)
422}
423
424fn generate_schema_diff(schema: &prax_schema::ast::Schema) -> CliResult<String> {
425    use prax_schema::ast::{FieldType, ScalarType};
426
427    let mut sql = String::new();
428
429    sql.push_str("-- Migration generated by Prax\n\n");
430
431    // Generate enums FIRST (before tables that reference them)
432    if !schema.enums.is_empty() {
433        sql.push_str("-- Enum types\n");
434        for enum_def in schema.enums.values() {
435            let enum_name = enum_def
436                .attributes
437                .iter()
438                .find(|a| a.is("map"))
439                .and_then(|a: &prax_schema::ast::Attribute| a.first_arg())
440                .and_then(|v: &prax_schema::ast::AttributeValue| v.as_string())
441                .map(|s| s.to_string())
442                .unwrap_or_else(|| to_snake_case(enum_def.name()));
443
444            sql.push_str(&format!(
445                "DO $$ BEGIN\n    CREATE TYPE \"{}\" AS ENUM (",
446                enum_name
447            ));
448
449            let variants: Vec<String> = enum_def
450                .variants
451                .iter()
452                .map(|v| format!("'{}'", v.name()))
453                .collect();
454
455            sql.push_str(&variants.join(", "));
456            sql.push_str(");\nEXCEPTION\n    WHEN duplicate_object THEN null;\nEND $$;\n\n");
457        }
458        sql.push_str("\n");
459    }
460
461    // Generate CREATE TABLE statements for each model
462    sql.push_str("-- Tables\n");
463    for model in schema.models.values() {
464        let table_name = model.table_name();
465
466        sql.push_str(&format!(
467            "CREATE TABLE IF NOT EXISTS \"{}\" (\n",
468            table_name
469        ));
470
471        let mut columns = Vec::new();
472        let mut primary_keys = Vec::new();
473
474        for field in model.fields.values() {
475            if field.is_relation() {
476                continue;
477            }
478
479            let column_name = field
480                .get_attribute("map")
481                .and_then(|a| a.first_arg())
482                .and_then(|v| v.as_string())
483                .map(|s| s.to_string())
484                .unwrap_or_else(|| to_snake_case(field.name()));
485
486            let sql_type = field_type_to_sql(&field.field_type);
487            let mut column_def = format!("    \"{}\" {}", column_name, sql_type);
488
489            // Add constraints
490            if field.is_id() {
491                primary_keys.push(column_name.clone());
492            }
493
494            if field.has_attribute("auto") || field.has_attribute("autoincrement") {
495                // PostgreSQL uses SERIAL types
496                column_def = format!("    \"{}\" SERIAL", column_name);
497            }
498
499            if field.has_attribute("unique") {
500                column_def.push_str(" UNIQUE");
501            }
502
503            if !field.is_optional() && !field.is_id() {
504                column_def.push_str(" NOT NULL");
505            }
506
507            // Default values
508            if let Some(default_attr) = field.get_attribute("default") {
509                if let Some(value) = default_attr.first_arg() {
510                    let value_str = format_attribute_value(value);
511                    column_def.push_str(&format!(
512                        " DEFAULT {}",
513                        sql_default_value(&value_str, &field.field_type)
514                    ));
515                }
516            }
517
518            columns.push(column_def);
519        }
520
521        sql.push_str(&columns.join(",\n"));
522
523        if !primary_keys.is_empty() {
524            sql.push_str(",\n");
525            sql.push_str(&format!(
526                "    PRIMARY KEY (\"{}\")",
527                primary_keys.join("\", \"")
528            ));
529        }
530
531        sql.push_str("\n);\n\n");
532    }
533
534    return Ok(sql);
535
536    fn field_type_to_sql(field_type: &FieldType) -> String {
537        match field_type {
538            FieldType::Scalar(scalar) => match scalar {
539                ScalarType::Int => "INTEGER".to_string(),
540                ScalarType::BigInt => "BIGINT".to_string(),
541                ScalarType::Float => "DOUBLE PRECISION".to_string(),
542                ScalarType::String => "TEXT".to_string(),
543                ScalarType::Boolean => "BOOLEAN".to_string(),
544                ScalarType::DateTime => "TIMESTAMP WITH TIME ZONE".to_string(),
545                ScalarType::Date => "DATE".to_string(),
546                ScalarType::Time => "TIME".to_string(),
547                ScalarType::Json => "JSONB".to_string(),
548                ScalarType::Bytes => "BYTEA".to_string(),
549                ScalarType::Decimal => "DECIMAL".to_string(),
550                ScalarType::Uuid => "UUID".to_string(),
551                ScalarType::Cuid | ScalarType::Cuid2 | ScalarType::NanoId | ScalarType::Ulid => {
552                    "TEXT".to_string()
553                }
554                ScalarType::Vector(dim) => match dim {
555                    Some(d) => format!("vector({})", d),
556                    None => "vector".to_string(),
557                },
558                ScalarType::HalfVector(dim) => match dim {
559                    Some(d) => format!("halfvec({})", d),
560                    None => "halfvec".to_string(),
561                },
562                ScalarType::SparseVector(dim) => match dim {
563                    Some(d) => format!("sparsevec({})", d),
564                    None => "sparsevec".to_string(),
565                },
566                ScalarType::Bit(dim) => match dim {
567                    Some(d) => format!("bit({})", d),
568                    None => "bit".to_string(),
569                },
570            },
571            FieldType::Enum(name) => format!("\"{}\"", to_snake_case(name)),
572            _ => "TEXT".to_string(),
573        }
574    }
575}
576
577async fn apply_migration(migration_path: &PathBuf, _config: &Config) -> CliResult<()> {
578    let sql_path = migration_path.join("migration.sql");
579
580    if !sql_path.exists() {
581        return Err(CliError::Migration(format!(
582            "Migration file not found: {}",
583            sql_path.display()
584        )));
585    }
586
587    let _sql = std::fs::read_to_string(&sql_path)?;
588
589    // TODO: Execute SQL against database
590    // This would use the database URL from config
591
592    // Mark as applied
593    let marker = migration_path.join(".applied");
594    std::fs::write(&marker, chrono::Utc::now().to_rfc3339())?;
595
596    Ok(())
597}
598
599fn sql_default_value(value: &str, field_type: &prax_schema::ast::FieldType) -> String {
600    // Handle enum defaults - need to be quoted as strings
601    if matches!(field_type, prax_schema::ast::FieldType::Enum(_)) {
602        return format!("'{}'", value);
603    }
604
605    match value.to_lowercase().as_str() {
606        "now()" => "CURRENT_TIMESTAMP".to_string(),
607        "uuid()" => "gen_random_uuid()".to_string(),
608        "cuid()" | "cuid2()" | "nanoid()" | "ulid()" => {
609            // These need application-level generation
610            "''".to_string()
611        }
612        "true" => "TRUE".to_string(),
613        "false" => "FALSE".to_string(),
614        _ => value.to_string(),
615    }
616}
617
/// Convert a camelCase or PascalCase identifier to snake_case.
///
/// Every uppercase character is lowercased and, unless it is the first
/// character, preceded by an underscore. Consecutive capitals are each split
/// (`"HTTPServer"` becomes `"h_t_t_p_server"`). The full lowercase mapping of
/// a character is kept, so characters whose lowercase form spans several
/// code points are not truncated.
fn to_snake_case(name: &str) -> String {
    // A little headroom for the underscores that get inserted.
    let mut result = String::with_capacity(name.len() + 4);
    for (i, c) in name.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // `to_lowercase` may yield more than one char; keep all of them.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
632
633fn format_attribute_value(value: &prax_schema::ast::AttributeValue) -> String {
634    use prax_schema::ast::AttributeValue;
635
636    match value {
637        AttributeValue::String(s) => format!("\"{}\"", s),
638        AttributeValue::Int(i) => i.to_string(),
639        AttributeValue::Float(f) => f.to_string(),
640        AttributeValue::Boolean(b) => b.to_string(),
641        AttributeValue::Ident(id) => id.to_string(),
642        AttributeValue::Function(name, args) => {
643            if args.is_empty() {
644                format!("{}()", name)
645            } else {
646                let arg_strs: Vec<String> = args.iter().map(format_attribute_value).collect();
647                format!("{}({})", name, arg_strs.join(", "))
648            }
649        }
650        AttributeValue::Array(items) => {
651            let item_strs: Vec<String> = items.iter().map(format_attribute_value).collect();
652            format!("[{}]", item_strs.join(", "))
653        }
654        AttributeValue::FieldRef(field) => field.to_string(),
655        AttributeValue::FieldRefList(fields) => {
656            format!(
657                "[{}]",
658                fields
659                    .iter()
660                    .map(|f| f.to_string())
661                    .collect::<Vec<_>>()
662                    .join(", ")
663            )
664        }
665    }
666}