1use std::path::PathBuf;
4
5use crate::cli::MigrateArgs;
6use crate::commands::seed::{SeedRunner, find_seed_file, get_database_url};
7use crate::config::{CONFIG_FILE_NAME, Config, MIGRATIONS_DIR, SCHEMA_FILE_NAME};
8use crate::error::{CliError, CliResult};
9use crate::output::{self, success, warn};
10
11pub async fn run(args: MigrateArgs) -> CliResult<()> {
13 match args.command {
14 crate::cli::MigrateSubcommand::Dev(dev_args) => run_dev(dev_args).await,
15 crate::cli::MigrateSubcommand::Deploy => run_deploy().await,
16 crate::cli::MigrateSubcommand::Reset(reset_args) => run_reset(reset_args).await,
17 crate::cli::MigrateSubcommand::Status => run_status().await,
18 crate::cli::MigrateSubcommand::Resolve(resolve_args) => run_resolve(resolve_args).await,
19 crate::cli::MigrateSubcommand::Diff(diff_args) => run_diff(diff_args).await,
20 }
21}
22
23async fn run_dev(args: crate::cli::MigrateDevArgs) -> CliResult<()> {
25 output::header("Migrate Dev");
26
27 let cwd = std::env::current_dir()?;
28 let config = load_config(&cwd)?;
29
30 let schema_path = args
31 .schema
32 .clone()
33 .unwrap_or_else(|| cwd.join(SCHEMA_FILE_NAME));
34 let migrations_dir = cwd.join(MIGRATIONS_DIR);
35
36 output::kv("Schema", &schema_path.display().to_string());
37 output::kv("Migrations", &migrations_dir.display().to_string());
38 output::newline();
39
40 let total_steps = if args.skip_seed { 5 } else { 6 };
42
43 output::step(1, total_steps, "Parsing schema...");
45 let schema_content = std::fs::read_to_string(&schema_path)?;
46 let schema = parse_schema(&schema_content)?;
47
48 output::step(2, total_steps, "Checking migration status...");
50 let pending = check_pending_migrations(&migrations_dir)?;
51
52 if !pending.is_empty() {
53 output::list(&format!("{} pending migrations found:", pending.len()));
54 for migration in &pending {
55 output::list_item(&migration.display().to_string());
56 }
57 output::newline();
58 }
59
60 output::step(3, total_steps, "Comparing schema to database...");
62 let migration_name = args
63 .name
64 .unwrap_or_else(|| format!("migration_{}", chrono::Utc::now().format("%Y%m%d%H%M%S")));
65
66 output::step(4, total_steps, "Generating migration...");
68 let migration_path = create_migration(&migrations_dir, &migration_name, &schema)?;
69
70 if !args.create_only {
72 output::step(5, total_steps, "Applying migration...");
73 apply_migration(&migration_path, &config).await?;
74 } else {
75 output::step(5, total_steps, "Skipping apply (--create-only)...");
76 }
77
78 if !args.skip_seed && !args.create_only {
80 output::step(6, total_steps, "Running seed...");
81
82 if let Some(seed_path) = find_seed_file(&cwd, &config) {
83 let database_url = get_database_url(&config)?;
84 let runner = SeedRunner::new(
85 seed_path,
86 database_url,
87 config.database.provider.clone(),
88 cwd.clone(),
89 )?;
90
91 match runner.run().await {
92 Ok(result) => {
93 output::list_item(&format!("Seeded {} records", result.records_affected));
94 }
95 Err(e) => {
96 output::warn(&format!("Seed failed: {}. Continuing...", e));
97 }
98 }
99 } else {
100 output::list_item("No seed file found, skipping");
101 }
102 }
103
104 output::newline();
105 success(&format!("Migration '{}' created", migration_name));
106
107 output::newline();
108 output::section("Next steps");
109 output::list_item("Review the generated migration SQL");
110 output::list_item("Run `prax generate` to update your client");
111
112 Ok(())
113}
114
115async fn run_deploy() -> CliResult<()> {
117 output::header("Migrate Deploy");
118
119 let cwd = std::env::current_dir()?;
120 let config = load_config(&cwd)?;
121 let migrations_dir = cwd.join(MIGRATIONS_DIR);
122
123 output::kv("Migrations", &migrations_dir.display().to_string());
124 output::newline();
125
126 output::step(1, 3, "Checking for pending migrations...");
128 let pending = check_pending_migrations(&migrations_dir)?;
129
130 if pending.is_empty() {
131 output::newline();
132 success("No pending migrations to apply.");
133 return Ok(());
134 }
135
136 output::list(&format!("{} pending migrations:", pending.len()));
137 for migration in &pending {
138 output::list_item(&migration.file_name().unwrap().to_string_lossy());
139 }
140 output::newline();
141
142 output::step(2, 3, "Applying migrations...");
144 for migration in &pending {
145 output::list_item(&format!(
146 "Applying {}",
147 migration.file_name().unwrap().to_string_lossy()
148 ));
149 apply_migration(migration, &config).await?;
150 }
151
152 output::step(3, 3, "Verifying migrations...");
154
155 output::newline();
156 success(&format!(
157 "Applied {} migrations successfully!",
158 pending.len()
159 ));
160
161 Ok(())
162}
163
164async fn run_reset(args: crate::cli::MigrateResetArgs) -> CliResult<()> {
166 output::header("Migrate Reset");
167
168 let cwd = std::env::current_dir()?;
169 let config = load_config(&cwd)?;
170
171 if !args.force {
172 warn("This will delete all data in the database!");
173 output::newline();
174 if !output::confirm("Are you sure you want to reset the database?") {
175 output::newline();
176 output::info("Reset cancelled.");
177 return Ok(());
178 }
179 }
180
181 output::newline();
182 output::step(1, 4, "Dropping database...");
183 output::step(2, 4, "Creating database...");
186 output::step(3, 4, "Applying migrations...");
189 let migrations_dir = cwd.join(MIGRATIONS_DIR);
190 let migrations = check_pending_migrations(&migrations_dir)?;
191
192 for migration in &migrations {
193 apply_migration(migration, &config).await?;
194 }
195
196 if args.seed {
198 output::step(4, 4, "Running seed...");
199
200 if let Some(seed_path) = find_seed_file(&cwd, &config) {
202 let database_url = get_database_url(&config)?;
203 let runner = SeedRunner::new(
204 seed_path,
205 database_url,
206 config.database.provider.clone(),
207 cwd,
208 )?;
209
210 let result = runner.run().await?;
211 output::list_item(&format!("Seeded {} records", result.records_affected));
212 } else {
213 output::list_item("No seed file found, skipping seed");
214 }
215 } else {
216 output::step(4, 4, "Skipping seed...");
217 }
218
219 output::newline();
220 success("Database reset complete!");
221
222 Ok(())
223}
224
225async fn run_status() -> CliResult<()> {
227 output::header("Migration Status");
228
229 let cwd = std::env::current_dir()?;
230 let _config = load_config(&cwd)?;
231 let migrations_dir = cwd.join(MIGRATIONS_DIR);
232
233 let mut migrations = Vec::new();
235 if migrations_dir.exists() {
236 for entry in std::fs::read_dir(&migrations_dir)? {
237 let entry = entry?;
238 let path = entry.path();
239 if path.is_dir() {
240 migrations.push(path);
241 }
242 }
243 }
244 migrations.sort();
245
246 if migrations.is_empty() {
247 output::info("No migrations found.");
248 output::newline();
249 output::section("Getting started");
250 output::list_item("Run `prax migrate dev` to create your first migration");
251 return Ok(());
252 }
253
254 output::section("Migrations");
255
256 for (i, migration) in migrations.iter().enumerate() {
257 let name = migration.file_name().unwrap().to_string_lossy();
258 let applied = is_migration_applied(migration)?;
259
260 let status = if applied {
261 output::style_success("✓ Applied")
262 } else {
263 output::style_pending("○ Pending")
264 };
265
266 output::numbered_item(i + 1, &format!("{} - {}", name, status));
267 }
268
269 output::newline();
270
271 let applied_count = migrations
272 .iter()
273 .filter(|m| is_migration_applied(m).unwrap_or(false))
274 .count();
275 let pending_count = migrations.len() - applied_count;
276
277 output::kv("Total", &migrations.len().to_string());
278 output::kv("Applied", &applied_count.to_string());
279 output::kv("Pending", &pending_count.to_string());
280
281 Ok(())
282}
283
284async fn run_resolve(args: crate::cli::MigrateResolveArgs) -> CliResult<()> {
286 output::header("Migrate Resolve");
287
288 if args.rolled_back {
289 output::step(1, 2, "Marking migration as rolled back...");
290 output::step(2, 2, "Updating migration history...");
293
294 output::newline();
295 success(&format!(
296 "Migration '{}' marked as rolled back",
297 args.migration
298 ));
299 } else if args.applied {
300 output::step(1, 2, "Marking migration as applied...");
301 output::step(2, 2, "Updating migration history...");
304
305 output::newline();
306 success(&format!("Migration '{}' marked as applied", args.migration));
307 } else {
308 return Err(
309 CliError::Command("Must specify --applied or --rolled-back".to_string()).into(),
310 );
311 }
312
313 Ok(())
314}
315
316async fn run_diff(args: crate::cli::MigrateDiffArgs) -> CliResult<()> {
318 output::header("Migrate Diff");
319
320 let cwd = std::env::current_dir()?;
321 let schema_path = args.schema.unwrap_or_else(|| cwd.join(SCHEMA_FILE_NAME));
322
323 output::step(1, 3, "Parsing schema...");
325 let schema_content = std::fs::read_to_string(&schema_path)?;
326 let schema = parse_schema(&schema_content)?;
327
328 output::step(2, 3, "Introspecting database...");
330 output::step(3, 3, "Generating diff...");
334 let diff_sql = generate_schema_diff(&schema)?;
335
336 output::newline();
337
338 if diff_sql.is_empty() {
339 success("Schema is in sync with database - no changes needed");
340 } else {
341 output::section("Generated SQL");
342 output::code(&diff_sql, "sql");
343
344 if let Some(output_path) = args.output {
345 std::fs::write(&output_path, &diff_sql)?;
346 output::newline();
347 success(&format!("Diff written to {}", output_path.display()));
348 }
349 }
350
351 Ok(())
352}
353
354fn load_config(cwd: &PathBuf) -> CliResult<Config> {
359 let config_path = cwd.join(CONFIG_FILE_NAME);
360 if config_path.exists() {
361 Config::load(&config_path)
362 } else {
363 Ok(Config::default())
364 }
365}
366
367fn parse_schema(content: &str) -> CliResult<prax_schema::Schema> {
368 prax_schema::parse_schema(content)
369 .map_err(|e| CliError::Schema(format!("Failed to parse schema: {}", e)))
370}
371
372fn check_pending_migrations(migrations_dir: &PathBuf) -> CliResult<Vec<PathBuf>> {
373 let mut pending = Vec::new();
374
375 if !migrations_dir.exists() {
376 return Ok(pending);
377 }
378
379 for entry in std::fs::read_dir(migrations_dir)? {
380 let entry = entry?;
381 let path = entry.path();
382 if path.is_dir() {
383 if !is_migration_applied(&path)? {
384 pending.push(path);
385 }
386 }
387 }
388
389 pending.sort();
390 Ok(pending)
391}
392
393fn is_migration_applied(migration_path: &PathBuf) -> CliResult<bool> {
394 let marker = migration_path.join(".applied");
397 Ok(marker.exists())
398}
399
400fn create_migration(
401 migrations_dir: &PathBuf,
402 name: &str,
403 schema: &prax_schema::ast::Schema,
404) -> CliResult<PathBuf> {
405 let timestamp = chrono::Utc::now().format("%Y%m%d%H%M%S");
407 let migration_name = format!("{}_{}", timestamp, name);
408 let migration_path = migrations_dir.join(&migration_name);
409
410 std::fs::create_dir_all(&migration_path)?;
411
412 let sql = generate_schema_diff(schema)?;
414
415 let sql_path = migration_path.join("migration.sql");
417 std::fs::write(&sql_path, &sql)?;
418
419 Ok(migration_path)
420}
421
422fn generate_schema_diff(schema: &prax_schema::ast::Schema) -> CliResult<String> {
423 use prax_schema::ast::{FieldType, ScalarType};
424
425 let mut sql = String::new();
426
427 sql.push_str("-- Migration generated by Prax\n\n");
428
429 for model in schema.models.values() {
431 let table_name = model.table_name();
432
433 sql.push_str(&format!(
434 "CREATE TABLE IF NOT EXISTS \"{}\" (\n",
435 table_name
436 ));
437
438 let mut columns = Vec::new();
439 let mut primary_keys = Vec::new();
440
441 for field in model.fields.values() {
442 if field.is_relation() {
443 continue;
444 }
445
446 let column_name = field
447 .get_attribute("map")
448 .and_then(|a| a.first_arg())
449 .and_then(|v| v.as_string())
450 .map(|s| s.to_string())
451 .unwrap_or_else(|| to_snake_case(field.name()));
452
453 let sql_type = field_type_to_sql(&field.field_type);
454 let mut column_def = format!(" \"{}\" {}", column_name, sql_type);
455
456 if field.is_id() {
458 primary_keys.push(column_name.clone());
459 }
460
461 if field.has_attribute("auto") || field.has_attribute("autoincrement") {
462 column_def = format!(" \"{}\" SERIAL", column_name);
464 }
465
466 if field.has_attribute("unique") {
467 column_def.push_str(" UNIQUE");
468 }
469
470 if !field.is_optional() && !field.is_id() {
471 column_def.push_str(" NOT NULL");
472 }
473
474 if let Some(default_attr) = field.get_attribute("default") {
476 if let Some(value) = default_attr.first_arg() {
477 let value_str = format_attribute_value(value);
478 column_def.push_str(&format!(" DEFAULT {}", sql_default_value(&value_str)));
479 }
480 }
481
482 columns.push(column_def);
483 }
484
485 sql.push_str(&columns.join(",\n"));
486
487 if !primary_keys.is_empty() {
488 sql.push_str(",\n");
489 sql.push_str(&format!(
490 " PRIMARY KEY (\"{}\")",
491 primary_keys.join("\", \"")
492 ));
493 }
494
495 sql.push_str("\n);\n\n");
496 sql.push_str("\n");
497 }
498
499 for enum_def in schema.enums.values() {
501 let enum_name = enum_def
502 .attributes
503 .iter()
504 .find(|a| a.is("map"))
505 .and_then(|a: &prax_schema::ast::Attribute| a.first_arg())
506 .and_then(|v: &prax_schema::ast::AttributeValue| v.as_string())
507 .map(|s| s.to_string())
508 .unwrap_or_else(|| to_snake_case(enum_def.name()));
509
510 sql.push_str(&format!(
511 "DO $$ BEGIN\n CREATE TYPE \"{}\" AS ENUM (",
512 enum_name
513 ));
514
515 let variants: Vec<String> = enum_def
516 .variants
517 .iter()
518 .map(|v| format!("'{}'", v.name()))
519 .collect();
520
521 sql.push_str(&variants.join(", "));
522 sql.push_str(");\nEXCEPTION\n WHEN duplicate_object THEN null;\nEND $$;\n\n");
523 }
524
525 return Ok(sql);
526
527 fn field_type_to_sql(field_type: &FieldType) -> String {
528 match field_type {
529 FieldType::Scalar(scalar) => match scalar {
530 ScalarType::Int => "INTEGER".to_string(),
531 ScalarType::BigInt => "BIGINT".to_string(),
532 ScalarType::Float => "DOUBLE PRECISION".to_string(),
533 ScalarType::String => "TEXT".to_string(),
534 ScalarType::Boolean => "BOOLEAN".to_string(),
535 ScalarType::DateTime => "TIMESTAMP WITH TIME ZONE".to_string(),
536 ScalarType::Date => "DATE".to_string(),
537 ScalarType::Time => "TIME".to_string(),
538 ScalarType::Json => "JSONB".to_string(),
539 ScalarType::Bytes => "BYTEA".to_string(),
540 ScalarType::Decimal => "DECIMAL".to_string(),
541 ScalarType::Uuid => "UUID".to_string(),
542 ScalarType::Cuid | ScalarType::Cuid2 | ScalarType::NanoId | ScalarType::Ulid => {
543 "TEXT".to_string()
544 }
545 ScalarType::Vector(dim) => match dim {
546 Some(d) => format!("vector({})", d),
547 None => "vector".to_string(),
548 },
549 ScalarType::HalfVector(dim) => match dim {
550 Some(d) => format!("halfvec({})", d),
551 None => "halfvec".to_string(),
552 },
553 ScalarType::SparseVector(dim) => match dim {
554 Some(d) => format!("sparsevec({})", d),
555 None => "sparsevec".to_string(),
556 },
557 ScalarType::Bit(dim) => match dim {
558 Some(d) => format!("bit({})", d),
559 None => "bit".to_string(),
560 },
561 },
562 FieldType::Enum(name) => format!("\"{}\"", to_snake_case(name)),
563 _ => "TEXT".to_string(),
564 }
565 }
566}
567
568async fn apply_migration(migration_path: &PathBuf, _config: &Config) -> CliResult<()> {
569 let sql_path = migration_path.join("migration.sql");
570
571 if !sql_path.exists() {
572 return Err(CliError::Migration(format!(
573 "Migration file not found: {}",
574 sql_path.display()
575 )));
576 }
577
578 let _sql = std::fs::read_to_string(&sql_path)?;
579
580 let marker = migration_path.join(".applied");
585 std::fs::write(&marker, chrono::Utc::now().to_rfc3339())?;
586
587 Ok(())
588}
589
/// Translates a schema `@default(...)` value into a PostgreSQL DEFAULT expression.
///
/// `value` is in the textual form produced by `format_attribute_value`:
/// - String values arrive wrapped in double quotes. They are re-emitted as
///   single-quoted SQL string literals with embedded `'` doubled. In SQL, a
///   double-quoted token is an identifier, not a string.
/// - `now()` and `uuid()` map to PostgreSQL built-ins. Matching of these and of
///   `true`/`false` ignores case.
/// - Anything else passes through unchanged.
fn sql_default_value(value: &str) -> String {
    if let Some(inner) = value.strip_prefix('"').and_then(|v| v.strip_suffix('"')) {
        return format!("'{}'", inner.replace('\'', "''"));
    }

    match value.to_lowercase().as_str() {
        "now()" => "CURRENT_TIMESTAMP".to_string(),
        "uuid()" => "gen_random_uuid()".to_string(),
        // No PostgreSQL generator is used for these id kinds; an empty-string
        // default is emitted. NOTE(review): presumably the application supplies
        // the real value — confirm.
        "cuid()" | "cuid2()" | "nanoid()" | "ulid()" => "''".to_string(),
        "true" => "TRUE".to_string(),
        "false" => "FALSE".to_string(),
        _ => value.to_string(),
    }
}
603
/// Converts a `PascalCase` or `camelCase` identifier to `snake_case`.
///
/// An underscore is inserted before every uppercase character except the first
/// character, and that character is lowercased. Every uppercase letter starts a
/// new word, so `"HTTPServer"` becomes `"h_t_t_p_server"`. All characters of a
/// multi-character lowercase mapping are kept (e.g. `'İ'` becomes `"i\u{307}"`).
fn to_snake_case(name: &str) -> String {
    let mut result = String::with_capacity(name.len() + name.len() / 2);
    for (i, c) in name.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // `to_lowercase` may yield several chars; taking only the first one
            // silently dropped the rest.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
618
619fn format_attribute_value(value: &prax_schema::ast::AttributeValue) -> String {
620 use prax_schema::ast::AttributeValue;
621
622 match value {
623 AttributeValue::String(s) => format!("\"{}\"", s),
624 AttributeValue::Int(i) => i.to_string(),
625 AttributeValue::Float(f) => f.to_string(),
626 AttributeValue::Boolean(b) => b.to_string(),
627 AttributeValue::Ident(id) => id.to_string(),
628 AttributeValue::Function(name, args) => {
629 if args.is_empty() {
630 format!("{}()", name)
631 } else {
632 let arg_strs: Vec<String> = args.iter().map(format_attribute_value).collect();
633 format!("{}({})", name, arg_strs.join(", "))
634 }
635 }
636 AttributeValue::Array(items) => {
637 let item_strs: Vec<String> = items.iter().map(format_attribute_value).collect();
638 format!("[{}]", item_strs.join(", "))
639 }
640 AttributeValue::FieldRef(field) => field.to_string(),
641 AttributeValue::FieldRefList(fields) => {
642 format!(
643 "[{}]",
644 fields
645 .iter()
646 .map(|f| f.to_string())
647 .collect::<Vec<_>>()
648 .join(", ")
649 )
650 }
651 }
652}