1use std::path::PathBuf;
4
5use crate::cli::MigrateArgs;
6use crate::commands::seed::{find_seed_file, get_database_url, SeedRunner};
7use crate::config::{Config, CONFIG_FILE_NAME, MIGRATIONS_DIR, SCHEMA_FILE_NAME};
8use crate::error::{CliError, CliResult};
9use crate::output::{self, success, warn};
10
11pub async fn run(args: MigrateArgs) -> CliResult<()> {
13 match args.command {
14 crate::cli::MigrateSubcommand::Dev(dev_args) => run_dev(dev_args).await,
15 crate::cli::MigrateSubcommand::Deploy => run_deploy().await,
16 crate::cli::MigrateSubcommand::Reset(reset_args) => run_reset(reset_args).await,
17 crate::cli::MigrateSubcommand::Status => run_status().await,
18 crate::cli::MigrateSubcommand::Resolve(resolve_args) => run_resolve(resolve_args).await,
19 crate::cli::MigrateSubcommand::Diff(diff_args) => run_diff(diff_args).await,
20 }
21}
22
23async fn run_dev(args: crate::cli::MigrateDevArgs) -> CliResult<()> {
25 output::header("Migrate Dev");
26
27 let cwd = std::env::current_dir()?;
28 let config = load_config(&cwd)?;
29
30 let schema_path = args.schema.clone().unwrap_or_else(|| cwd.join(SCHEMA_FILE_NAME));
31 let migrations_dir = cwd.join(MIGRATIONS_DIR);
32
33 output::kv("Schema", &schema_path.display().to_string());
34 output::kv("Migrations", &migrations_dir.display().to_string());
35 output::newline();
36
37 let total_steps = if args.skip_seed { 5 } else { 6 };
39
40 output::step(1, total_steps, "Parsing schema...");
42 let schema_content = std::fs::read_to_string(&schema_path)?;
43 let schema = parse_schema(&schema_content)?;
44
45 output::step(2, total_steps, "Checking migration status...");
47 let pending = check_pending_migrations(&migrations_dir)?;
48
49 if !pending.is_empty() {
50 output::list(&format!("{} pending migrations found:", pending.len()));
51 for migration in &pending {
52 output::list_item(&migration.display().to_string());
53 }
54 output::newline();
55 }
56
57 output::step(3, total_steps, "Comparing schema to database...");
59 let migration_name = args.name.unwrap_or_else(|| {
60 format!(
61 "migration_{}",
62 chrono::Utc::now().format("%Y%m%d%H%M%S")
63 )
64 });
65
66 output::step(4, total_steps, "Generating migration...");
68 let migration_path = create_migration(&migrations_dir, &migration_name, &schema)?;
69
70 if !args.create_only {
72 output::step(5, total_steps, "Applying migration...");
73 apply_migration(&migration_path, &config).await?;
74 } else {
75 output::step(5, total_steps, "Skipping apply (--create-only)...");
76 }
77
78 if !args.skip_seed && !args.create_only {
80 output::step(6, total_steps, "Running seed...");
81
82 if let Some(seed_path) = find_seed_file(&cwd, &config) {
83 let database_url = get_database_url(&config)?;
84 let runner = SeedRunner::new(
85 seed_path,
86 database_url,
87 config.database.provider.clone(),
88 cwd.clone(),
89 )?;
90
91 match runner.run().await {
92 Ok(result) => {
93 output::list_item(&format!("Seeded {} records", result.records_affected));
94 }
95 Err(e) => {
96 output::warn(&format!("Seed failed: {}. Continuing...", e));
97 }
98 }
99 } else {
100 output::list_item("No seed file found, skipping");
101 }
102 }
103
104 output::newline();
105 success(&format!("Migration '{}' created", migration_name));
106
107 output::newline();
108 output::section("Next steps");
109 output::list_item("Review the generated migration SQL");
110 output::list_item("Run `prax generate` to update your client");
111
112 Ok(())
113}
114
115async fn run_deploy() -> CliResult<()> {
117 output::header("Migrate Deploy");
118
119 let cwd = std::env::current_dir()?;
120 let config = load_config(&cwd)?;
121 let migrations_dir = cwd.join(MIGRATIONS_DIR);
122
123 output::kv("Migrations", &migrations_dir.display().to_string());
124 output::newline();
125
126 output::step(1, 3, "Checking for pending migrations...");
128 let pending = check_pending_migrations(&migrations_dir)?;
129
130 if pending.is_empty() {
131 output::newline();
132 success("No pending migrations to apply.");
133 return Ok(());
134 }
135
136 output::list(&format!("{} pending migrations:", pending.len()));
137 for migration in &pending {
138 output::list_item(&migration.file_name().unwrap().to_string_lossy());
139 }
140 output::newline();
141
142 output::step(2, 3, "Applying migrations...");
144 for migration in &pending {
145 output::list_item(&format!("Applying {}", migration.file_name().unwrap().to_string_lossy()));
146 apply_migration(migration, &config).await?;
147 }
148
149 output::step(3, 3, "Verifying migrations...");
151
152 output::newline();
153 success(&format!(
154 "Applied {} migrations successfully!",
155 pending.len()
156 ));
157
158 Ok(())
159}
160
161async fn run_reset(args: crate::cli::MigrateResetArgs) -> CliResult<()> {
163 output::header("Migrate Reset");
164
165 let cwd = std::env::current_dir()?;
166 let config = load_config(&cwd)?;
167
168 if !args.force {
169 warn("This will delete all data in the database!");
170 output::newline();
171 if !output::confirm("Are you sure you want to reset the database?") {
172 output::newline();
173 output::info("Reset cancelled.");
174 return Ok(());
175 }
176 }
177
178 output::newline();
179 output::step(1, 4, "Dropping database...");
180 output::step(2, 4, "Creating database...");
183 output::step(3, 4, "Applying migrations...");
186 let migrations_dir = cwd.join(MIGRATIONS_DIR);
187 let migrations = check_pending_migrations(&migrations_dir)?;
188
189 for migration in &migrations {
190 apply_migration(migration, &config).await?;
191 }
192
193 if args.seed {
195 output::step(4, 4, "Running seed...");
196
197 if let Some(seed_path) = find_seed_file(&cwd, &config) {
199 let database_url = get_database_url(&config)?;
200 let runner = SeedRunner::new(
201 seed_path,
202 database_url,
203 config.database.provider.clone(),
204 cwd,
205 )?;
206
207 let result = runner.run().await?;
208 output::list_item(&format!("Seeded {} records", result.records_affected));
209 } else {
210 output::list_item("No seed file found, skipping seed");
211 }
212 } else {
213 output::step(4, 4, "Skipping seed...");
214 }
215
216 output::newline();
217 success("Database reset complete!");
218
219 Ok(())
220}
221
222async fn run_status() -> CliResult<()> {
224 output::header("Migration Status");
225
226 let cwd = std::env::current_dir()?;
227 let _config = load_config(&cwd)?;
228 let migrations_dir = cwd.join(MIGRATIONS_DIR);
229
230 let mut migrations = Vec::new();
232 if migrations_dir.exists() {
233 for entry in std::fs::read_dir(&migrations_dir)? {
234 let entry = entry?;
235 let path = entry.path();
236 if path.is_dir() {
237 migrations.push(path);
238 }
239 }
240 }
241 migrations.sort();
242
243 if migrations.is_empty() {
244 output::info("No migrations found.");
245 output::newline();
246 output::section("Getting started");
247 output::list_item("Run `prax migrate dev` to create your first migration");
248 return Ok(());
249 }
250
251 output::section("Migrations");
252
253 for (i, migration) in migrations.iter().enumerate() {
254 let name = migration.file_name().unwrap().to_string_lossy();
255 let applied = is_migration_applied(migration)?;
256
257 let status = if applied {
258 output::style_success("✓ Applied")
259 } else {
260 output::style_pending("○ Pending")
261 };
262
263 output::numbered_item(i + 1, &format!("{} - {}", name, status));
264 }
265
266 output::newline();
267
268 let applied_count = migrations.iter().filter(|m| is_migration_applied(m).unwrap_or(false)).count();
269 let pending_count = migrations.len() - applied_count;
270
271 output::kv("Total", &migrations.len().to_string());
272 output::kv("Applied", &applied_count.to_string());
273 output::kv("Pending", &pending_count.to_string());
274
275 Ok(())
276}
277
278async fn run_resolve(args: crate::cli::MigrateResolveArgs) -> CliResult<()> {
280 output::header("Migrate Resolve");
281
282 if args.rolled_back {
283 output::step(1, 2, "Marking migration as rolled back...");
284 output::step(2, 2, "Updating migration history...");
287
288 output::newline();
289 success(&format!(
290 "Migration '{}' marked as rolled back",
291 args.migration
292 ));
293 } else if args.applied {
294 output::step(1, 2, "Marking migration as applied...");
295 output::step(2, 2, "Updating migration history...");
298
299 output::newline();
300 success(&format!(
301 "Migration '{}' marked as applied",
302 args.migration
303 ));
304 } else {
305 return Err(CliError::Command(
306 "Must specify --applied or --rolled-back".to_string()
307 ).into());
308 }
309
310 Ok(())
311}
312
313async fn run_diff(args: crate::cli::MigrateDiffArgs) -> CliResult<()> {
315 output::header("Migrate Diff");
316
317 let cwd = std::env::current_dir()?;
318 let schema_path = args.schema.unwrap_or_else(|| cwd.join(SCHEMA_FILE_NAME));
319
320 output::step(1, 3, "Parsing schema...");
322 let schema_content = std::fs::read_to_string(&schema_path)?;
323 let schema = parse_schema(&schema_content)?;
324
325 output::step(2, 3, "Introspecting database...");
327 output::step(3, 3, "Generating diff...");
331 let diff_sql = generate_schema_diff(&schema)?;
332
333 output::newline();
334
335 if diff_sql.is_empty() {
336 success("Schema is in sync with database - no changes needed");
337 } else {
338 output::section("Generated SQL");
339 output::code(&diff_sql, "sql");
340
341 if let Some(output_path) = args.output {
342 std::fs::write(&output_path, &diff_sql)?;
343 output::newline();
344 success(&format!("Diff written to {}", output_path.display()));
345 }
346 }
347
348 Ok(())
349}
350
351fn load_config(cwd: &PathBuf) -> CliResult<Config> {
356 let config_path = cwd.join(CONFIG_FILE_NAME);
357 if config_path.exists() {
358 Config::load(&config_path)
359 } else {
360 Ok(Config::default())
361 }
362}
363
364fn parse_schema(content: &str) -> CliResult<prax_schema::Schema> {
365 prax_schema::parse_schema(content)
366 .map_err(|e| CliError::Schema(format!("Failed to parse schema: {}", e)))
367}
368
369fn check_pending_migrations(migrations_dir: &PathBuf) -> CliResult<Vec<PathBuf>> {
370 let mut pending = Vec::new();
371
372 if !migrations_dir.exists() {
373 return Ok(pending);
374 }
375
376 for entry in std::fs::read_dir(migrations_dir)? {
377 let entry = entry?;
378 let path = entry.path();
379 if path.is_dir() {
380 if !is_migration_applied(&path)? {
381 pending.push(path);
382 }
383 }
384 }
385
386 pending.sort();
387 Ok(pending)
388}
389
390fn is_migration_applied(migration_path: &PathBuf) -> CliResult<bool> {
391 let marker = migration_path.join(".applied");
394 Ok(marker.exists())
395}
396
397fn create_migration(
398 migrations_dir: &PathBuf,
399 name: &str,
400 schema: &prax_schema::ast::Schema,
401) -> CliResult<PathBuf> {
402 let timestamp = chrono::Utc::now().format("%Y%m%d%H%M%S");
404 let migration_name = format!("{}_{}", timestamp, name);
405 let migration_path = migrations_dir.join(&migration_name);
406
407 std::fs::create_dir_all(&migration_path)?;
408
409 let sql = generate_schema_diff(schema)?;
411
412 let sql_path = migration_path.join("migration.sql");
414 std::fs::write(&sql_path, &sql)?;
415
416 Ok(migration_path)
417}
418
419fn generate_schema_diff(schema: &prax_schema::ast::Schema) -> CliResult<String> {
420 use prax_schema::ast::{FieldType, ScalarType};
421
422 let mut sql = String::new();
423
424 sql.push_str("-- Migration generated by Prax\n\n");
425
426 for model in schema.models.values() {
428 let table_name = model.table_name();
429
430 sql.push_str(&format!("CREATE TABLE IF NOT EXISTS \"{}\" (\n", table_name));
431
432 let mut columns = Vec::new();
433 let mut primary_keys = Vec::new();
434
435 for field in model.fields.values() {
436 if field.is_relation() {
437 continue;
438 }
439
440 let column_name = field
441 .get_attribute("map")
442 .and_then(|a| a.first_arg())
443 .and_then(|v| v.as_string())
444 .map(|s| s.to_string())
445 .unwrap_or_else(|| to_snake_case(field.name()));
446
447 let sql_type = field_type_to_sql(&field.field_type);
448 let mut column_def = format!(" \"{}\" {}", column_name, sql_type);
449
450 if field.is_id() {
452 primary_keys.push(column_name.clone());
453 }
454
455 if field.has_attribute("auto") || field.has_attribute("autoincrement") {
456 column_def = format!(
458 " \"{}\" SERIAL",
459 column_name
460 );
461 }
462
463 if field.has_attribute("unique") {
464 column_def.push_str(" UNIQUE");
465 }
466
467 if !field.is_optional() && !field.is_id() {
468 column_def.push_str(" NOT NULL");
469 }
470
471 if let Some(default_attr) = field.get_attribute("default") {
473 if let Some(value) = default_attr.first_arg() {
474 let value_str = format_attribute_value(value);
475 column_def.push_str(&format!(
476 " DEFAULT {}",
477 sql_default_value(&value_str)
478 ));
479 }
480 }
481
482 columns.push(column_def);
483 }
484
485 sql.push_str(&columns.join(",\n"));
486
487 if !primary_keys.is_empty() {
488 sql.push_str(",\n");
489 sql.push_str(&format!(
490 " PRIMARY KEY (\"{}\")",
491 primary_keys.join("\", \"")
492 ));
493 }
494
495 sql.push_str("\n);\n\n");
496 sql.push_str("\n");
497 }
498
499 for enum_def in schema.enums.values() {
501 let enum_name = enum_def
502 .attributes
503 .iter()
504 .find(|a| a.is("map"))
505 .and_then(|a: &prax_schema::ast::Attribute| a.first_arg())
506 .and_then(|v: &prax_schema::ast::AttributeValue| v.as_string())
507 .map(|s| s.to_string())
508 .unwrap_or_else(|| to_snake_case(enum_def.name()));
509
510 sql.push_str(&format!(
511 "DO $$ BEGIN\n CREATE TYPE \"{}\" AS ENUM (",
512 enum_name
513 ));
514
515 let variants: Vec<String> = enum_def
516 .variants
517 .iter()
518 .map(|v| format!("'{}'", v.name()))
519 .collect();
520
521 sql.push_str(&variants.join(", "));
522 sql.push_str(");\nEXCEPTION\n WHEN duplicate_object THEN null;\nEND $$;\n\n");
523 }
524
525 return Ok(sql);
526
527 fn field_type_to_sql(field_type: &FieldType) -> String {
528 match field_type {
529 FieldType::Scalar(scalar) => match scalar {
530 ScalarType::Int => "INTEGER".to_string(),
531 ScalarType::BigInt => "BIGINT".to_string(),
532 ScalarType::Float => "DOUBLE PRECISION".to_string(),
533 ScalarType::String => "TEXT".to_string(),
534 ScalarType::Boolean => "BOOLEAN".to_string(),
535 ScalarType::DateTime => "TIMESTAMP WITH TIME ZONE".to_string(),
536 ScalarType::Date => "DATE".to_string(),
537 ScalarType::Time => "TIME".to_string(),
538 ScalarType::Json => "JSONB".to_string(),
539 ScalarType::Bytes => "BYTEA".to_string(),
540 ScalarType::Decimal => "DECIMAL".to_string(),
541 ScalarType::Uuid => "UUID".to_string(),
542 ScalarType::Cuid | ScalarType::Cuid2 | ScalarType::NanoId | ScalarType::Ulid => {
543 "TEXT".to_string()
544 }
545 },
546 FieldType::Enum(name) => format!("\"{}\"", to_snake_case(name)),
547 _ => "TEXT".to_string(),
548 }
549 }
550}
551
552async fn apply_migration(migration_path: &PathBuf, _config: &Config) -> CliResult<()> {
553 let sql_path = migration_path.join("migration.sql");
554
555 if !sql_path.exists() {
556 return Err(CliError::Migration(format!(
557 "Migration file not found: {}",
558 sql_path.display()
559 )));
560 }
561
562 let _sql = std::fs::read_to_string(&sql_path)?;
563
564 let marker = migration_path.join(".applied");
569 std::fs::write(&marker, chrono::Utc::now().to_rfc3339())?;
570
571 Ok(())
572}
573
/// Translates a rendered schema default into its PostgreSQL spelling.
///
/// Matching is case-insensitive. `now()` maps to `CURRENT_TIMESTAMP` and
/// `uuid()` to `gen_random_uuid()`. Client-side id generators (`cuid()`,
/// `cuid2()`, `nanoid()`, `ulid()`) map to an empty-string placeholder, and
/// booleans are upper-cased. Anything else is returned unchanged, with its
/// original casing.
fn sql_default_value(value: &str) -> String {
    let normalized = value.to_lowercase();
    let translated = match normalized.as_str() {
        "now()" => "CURRENT_TIMESTAMP",
        "uuid()" => "gen_random_uuid()",
        "cuid()" | "cuid2()" | "nanoid()" | "ulid()" => "''",
        "true" => "TRUE",
        "false" => "FALSE",
        _ => return value.to_string(),
    };
    translated.to_string()
}
587
/// Converts a camelCase / PascalCase identifier to snake_case.
///
/// A word boundary (`_`) is inserted before an uppercase letter when:
/// - the previous character is lowercase or a digit (`createdAt` becomes `created_at`), or
/// - it ends an uppercase run and starts a new word (`HTTPServer` becomes `http_server`).
///
/// Acronyms therefore stay together (`userID` becomes `user_id`), and no
/// double underscore is produced after an existing `_` (`my_Field` becomes
/// `my_field`). Full Unicode lowercase mappings are kept, even when they are
/// more than one character.
fn to_snake_case(name: &str) -> String {
    let chars: Vec<char> = name.chars().collect();
    let mut result = String::with_capacity(name.len() + 4);

    for (i, &c) in chars.iter().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                let prev = chars[i - 1];
                let next_is_lower = matches!(chars.get(i + 1), Some(n) if n.is_lowercase());
                let boundary = prev.is_lowercase()
                    || prev.is_numeric()
                    || (prev.is_uppercase() && next_is_lower);
                if boundary && !result.ends_with('_') {
                    result.push('_');
                }
            }
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }

    result
}
602
603fn format_attribute_value(value: &prax_schema::ast::AttributeValue) -> String {
604 use prax_schema::ast::AttributeValue;
605
606 match value {
607 AttributeValue::String(s) => format!("\"{}\"", s),
608 AttributeValue::Int(i) => i.to_string(),
609 AttributeValue::Float(f) => f.to_string(),
610 AttributeValue::Boolean(b) => b.to_string(),
611 AttributeValue::Ident(id) => id.to_string(),
612 AttributeValue::Function(name, args) => {
613 if args.is_empty() {
614 format!("{}()", name)
615 } else {
616 let arg_strs: Vec<String> = args.iter().map(format_attribute_value).collect();
617 format!("{}({})", name, arg_strs.join(", "))
618 }
619 }
620 AttributeValue::Array(items) => {
621 let item_strs: Vec<String> = items.iter().map(format_attribute_value).collect();
622 format!("[{}]", item_strs.join(", "))
623 }
624 AttributeValue::FieldRef(field) => field.to_string(),
625 AttributeValue::FieldRefList(fields) => {
626 format!(
627 "[{}]",
628 fields
629 .iter()
630 .map(|f| f.to_string())
631 .collect::<Vec<_>>()
632 .join(", ")
633 )
634 }
635 }
636}