1use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::process::Command;
12
13use serde::{Deserialize, Serialize};
14
15use crate::config::Config;
16use crate::error::{CliError, CliResult};
17use crate::output;
18
/// Format of a seed file, as determined by its file extension.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SeedFileType {
    /// Rust program (`.rs`), compiled with Cargo and executed.
    Rust,
    /// Raw SQL script (`.sql`).
    Sql,
    /// Declarative table data in JSON (`.json`).
    Json,
    /// Declarative table data in TOML (`.toml`).
    Toml,
}

impl SeedFileType {
    /// Detects the seed format from `path`'s extension, ignoring ASCII case
    /// (so `seed.SQL` is treated like `seed.sql`).
    ///
    /// Returns `None` when the path has no extension, the extension is not valid
    /// UTF-8, or it is not one of `rs`, `sql`, `json`, `toml`.
    pub fn from_path(path: &Path) -> Option<Self> {
        let extension = path.extension()?.to_str()?.to_ascii_lowercase();
        match extension.as_str() {
            "rs" => Some(Self::Rust),
            "sql" => Some(Self::Sql),
            "json" => Some(Self::Json),
            "toml" => Some(Self::Toml),
            _ => None,
        }
    }
}
44
45#[derive(Debug, Clone)]
47pub struct SeedRunner {
48 pub seed_path: PathBuf,
50 pub file_type: SeedFileType,
52 pub database_url: String,
54 pub provider: String,
56 pub cwd: PathBuf,
58 pub environment: String,
60 pub reset_before_seed: bool,
62}
63
64impl SeedRunner {
65 pub fn new(
67 seed_path: PathBuf,
68 database_url: String,
69 provider: String,
70 cwd: PathBuf,
71 ) -> CliResult<Self> {
72 let file_type = SeedFileType::from_path(&seed_path).ok_or_else(|| {
73 CliError::Config(format!(
74 "Unsupported seed file type: {}. Supported: .rs, .sql, .json, .toml",
75 seed_path.display()
76 ))
77 })?;
78
79 Ok(Self {
80 seed_path,
81 file_type,
82 database_url,
83 provider,
84 cwd,
85 environment: std::env::var("PRAX_ENV").unwrap_or_else(|_| "development".to_string()),
86 reset_before_seed: false,
87 })
88 }
89
90 pub fn with_environment(mut self, env: impl Into<String>) -> Self {
92 self.environment = env.into();
93 self
94 }
95
96 pub fn with_reset(mut self, reset: bool) -> Self {
98 self.reset_before_seed = reset;
99 self
100 }
101
102 pub async fn run(&self) -> CliResult<SeedResult> {
104 match self.file_type {
105 SeedFileType::Rust => self.run_rust_seed().await,
106 SeedFileType::Sql => self.run_sql_seed().await,
107 SeedFileType::Json => self.run_json_seed().await,
108 SeedFileType::Toml => self.run_toml_seed().await,
109 }
110 }
111
112 async fn run_rust_seed(&self) -> CliResult<SeedResult> {
114 output::step(1, 4, "Compiling seed script...");
115
116 let cargo_toml = self.cwd.join("Cargo.toml");
118 if !cargo_toml.exists() {
119 return Err(CliError::Config(
120 "No Cargo.toml found. Rust seed scripts require a Rust project.".to_string(),
121 ));
122 }
123
124 let seed_name = self
126 .seed_path
127 .file_stem()
128 .and_then(|s| s.to_str())
129 .unwrap_or("seed");
130
131 let has_bin_target = self.check_bin_target(seed_name)?;
133
134 let mut records_affected = 0u64;
135
136 if has_bin_target {
137 output::step(2, 4, &format!("Building seed binary '{}'...", seed_name));
139
140 let build_status = Command::new("cargo")
141 .args(["build", "--bin", seed_name, "--release"])
142 .current_dir(&self.cwd)
143 .env("DATABASE_URL", &self.database_url)
144 .env("PRAX_ENV", &self.environment)
145 .status()?;
146
147 if !build_status.success() {
148 return Err(CliError::Command("Failed to build seed binary".to_string()));
149 }
150
151 output::step(3, 4, "Running seed...");
152
153 let run_output = Command::new("cargo")
154 .args(["run", "--bin", seed_name, "--release"])
155 .current_dir(&self.cwd)
156 .env("DATABASE_URL", &self.database_url)
157 .env("PRAX_ENV", &self.environment)
158 .output()?;
159
160 if !run_output.status.success() {
161 let stderr = String::from_utf8_lossy(&run_output.stderr);
162 return Err(CliError::Command(format!("Seed failed: {}", stderr)));
163 }
164
165 let stdout = String::from_utf8_lossy(&run_output.stdout);
167 for line in stdout.lines() {
168 output::list_item(line);
169 if let Some(count) = parse_seed_output(line) {
171 records_affected += count;
172 }
173 }
174
175 output::step(4, 4, "Verifying seed data...");
176 } else {
177 output::step(2, 4, "Compiling standalone seed script...");
179
180 let temp_dir = std::env::temp_dir().join("prax_seed");
182 std::fs::create_dir_all(&temp_dir)?;
183
184 let output_binary = temp_dir.join(seed_name);
185
186 let seed_content = std::fs::read_to_string(&self.seed_path)?;
188
189 if seed_content.contains("use prax") || seed_content.contains("#[tokio::main]") {
190 output::list_item("Creating temporary build environment...");
192
193 let temp_project = temp_dir.join("seed_project");
194 std::fs::create_dir_all(temp_project.join("src"))?;
195
196 std::fs::copy(&self.seed_path, temp_project.join("src/main.rs"))?;
198
199 let seed_cargo = create_seed_cargo_toml(&self.cwd)?;
201 std::fs::write(temp_project.join("Cargo.toml"), seed_cargo)?;
202
203 let build_status = Command::new("cargo")
205 .args(["build", "--release"])
206 .current_dir(&temp_project)
207 .env("DATABASE_URL", &self.database_url)
208 .env("PRAX_ENV", &self.environment)
209 .status()?;
210
211 if !build_status.success() {
212 return Err(CliError::Command(
213 "Failed to compile seed script".to_string(),
214 ));
215 }
216
217 let built_binary = temp_project.join("target/release/seed");
219 if built_binary.exists() {
220 std::fs::copy(&built_binary, &output_binary)?;
221 }
222 } else {
223 return Err(CliError::Config(
224 "Seed script must be a valid Rust file with a main function".to_string(),
225 ));
226 }
227
228 output::step(3, 4, "Running seed...");
229
230 let run_output = Command::new(&output_binary)
231 .current_dir(&self.cwd)
232 .env("DATABASE_URL", &self.database_url)
233 .env("PRAX_ENV", &self.environment)
234 .output()?;
235
236 if !run_output.status.success() {
237 let stderr = String::from_utf8_lossy(&run_output.stderr);
238 return Err(CliError::Command(format!("Seed failed: {}", stderr)));
239 }
240
241 let stdout = String::from_utf8_lossy(&run_output.stdout);
242 for line in stdout.lines() {
243 output::list_item(line);
244 if let Some(count) = parse_seed_output(line) {
245 records_affected += count;
246 }
247 }
248
249 output::step(4, 4, "Verifying seed data...");
250 }
251
252 Ok(SeedResult {
253 file_type: self.file_type,
254 records_affected,
255 tables_seeded: Vec::new(),
256 duration: std::time::Duration::from_secs(0),
257 })
258 }
259
260 async fn run_sql_seed(&self) -> CliResult<SeedResult> {
262 output::step(1, 3, "Reading SQL seed file...");
263
264 let sql_content = std::fs::read_to_string(&self.seed_path)?;
265
266 let statements: Vec<&str> = sql_content
268 .split(';')
269 .map(|s| s.trim())
270 .filter(|s| !s.is_empty() && !s.starts_with("--"))
271 .collect();
272
273 output::list_item(&format!("Found {} SQL statements", statements.len()));
274
275 output::step(2, 3, "Executing SQL...");
276
277 let records = self.execute_sql(&sql_content).await?;
279
280 output::step(3, 3, "Verifying seed data...");
281
282 Ok(SeedResult {
283 file_type: self.file_type,
284 records_affected: records,
285 tables_seeded: Vec::new(),
286 duration: std::time::Duration::from_secs(0),
287 })
288 }
289
290 async fn run_json_seed(&self) -> CliResult<SeedResult> {
292 output::step(1, 4, "Reading JSON seed file...");
293
294 let json_content = std::fs::read_to_string(&self.seed_path)?;
295 let seed_data: SeedData =
296 serde_json::from_str(&json_content).map_err(|e| CliError::Config(e.to_string()))?;
297
298 output::step(2, 4, "Validating seed data...");
299 output::list_item(&format!("Found {} tables to seed", seed_data.tables.len()));
300
301 output::step(3, 4, "Inserting seed data...");
302
303 let mut total_records = 0u64;
304 let mut tables_seeded = Vec::new();
305
306 for (table_name, records) in &seed_data.tables {
307 let sql = self.generate_insert_sql(table_name, records)?;
308 let count = self.execute_sql(&sql).await?;
309 output::list_item(&format!(" {} - {} records", table_name, records.len()));
310 total_records += count;
311 tables_seeded.push(table_name.clone());
312 }
313
314 output::step(4, 4, "Verifying seed data...");
315
316 Ok(SeedResult {
317 file_type: self.file_type,
318 records_affected: total_records,
319 tables_seeded,
320 duration: std::time::Duration::from_secs(0),
321 })
322 }
323
324 async fn run_toml_seed(&self) -> CliResult<SeedResult> {
326 output::step(1, 4, "Reading TOML seed file...");
327
328 let toml_content = std::fs::read_to_string(&self.seed_path)?;
329 let seed_data: SeedData =
330 toml::from_str(&toml_content).map_err(|e| CliError::Config(e.to_string()))?;
331
332 output::step(2, 4, "Validating seed data...");
333 output::list_item(&format!("Found {} tables to seed", seed_data.tables.len()));
334
335 output::step(3, 4, "Inserting seed data...");
336
337 let mut total_records = 0u64;
338 let mut tables_seeded = Vec::new();
339
340 for (table_name, records) in &seed_data.tables {
341 let sql = self.generate_insert_sql(table_name, records)?;
342 let count = self.execute_sql(&sql).await?;
343 output::list_item(&format!(" {} - {} records", table_name, records.len()));
344 total_records += count;
345 tables_seeded.push(table_name.clone());
346 }
347
348 output::step(4, 4, "Verifying seed data...");
349
350 Ok(SeedResult {
351 file_type: self.file_type,
352 records_affected: total_records,
353 tables_seeded,
354 duration: std::time::Duration::from_secs(0),
355 })
356 }
357
358 fn check_bin_target(&self, name: &str) -> CliResult<bool> {
360 let cargo_toml = self.cwd.join("Cargo.toml");
361 let content = std::fs::read_to_string(&cargo_toml)?;
362
363 Ok(content.contains(&format!("name = \"{}\"", name))
365 || content.contains(&format!("name = '{}'", name)))
366 }
367
368 fn generate_insert_sql(
370 &self,
371 table: &str,
372 records: &[HashMap<String, serde_json::Value>],
373 ) -> CliResult<String> {
374 if records.is_empty() {
375 return Ok(String::new());
376 }
377
378 let mut sql = String::new();
379
380 let columns: Vec<&String> = records[0].keys().collect();
382 let column_names = columns
383 .iter()
384 .map(|c| format!("\"{}\"", c))
385 .collect::<Vec<_>>()
386 .join(", ");
387
388 for record in records {
389 let values = columns
390 .iter()
391 .map(|col| {
392 record
393 .get(*col)
394 .map(|v| self.value_to_sql(v))
395 .unwrap_or_else(|| "NULL".to_string())
396 })
397 .collect::<Vec<_>>()
398 .join(", ");
399
400 sql.push_str(&format!(
401 "INSERT INTO \"{}\" ({}) VALUES ({});\n",
402 table, column_names, values
403 ));
404 }
405
406 Ok(sql)
407 }
408
409 fn value_to_sql(&self, value: &serde_json::Value) -> String {
411 match value {
412 serde_json::Value::Null => "NULL".to_string(),
413 serde_json::Value::Bool(b) => {
414 if *b {
415 "TRUE".to_string()
416 } else {
417 "FALSE".to_string()
418 }
419 }
420 serde_json::Value::Number(n) => n.to_string(),
421 serde_json::Value::String(s) => {
422 match s.as_str() {
424 "now()" | "NOW()" => match self.provider.as_str() {
425 "postgresql" => "CURRENT_TIMESTAMP".to_string(),
426 "mysql" => "NOW()".to_string(),
427 "sqlite" => "datetime('now')".to_string(),
428 _ => "CURRENT_TIMESTAMP".to_string(),
429 },
430 "uuid()" | "UUID()" => match self.provider.as_str() {
431 "postgresql" => "gen_random_uuid()".to_string(),
432 "mysql" => "UUID()".to_string(),
433 "sqlite" => format!("'{}'", uuid::Uuid::new_v4()),
434 _ => "gen_random_uuid()".to_string(),
435 },
436 _ => format!("'{}'", s.replace('\'', "''")),
437 }
438 }
439 serde_json::Value::Array(arr) => {
440 let items = arr
442 .iter()
443 .map(|v| self.value_to_sql(v))
444 .collect::<Vec<_>>()
445 .join(", ");
446 format!("ARRAY[{}]", items)
447 }
448 serde_json::Value::Object(_) => {
449 format!("'{}'", value)
451 }
452 }
453 }
454
455 async fn execute_sql(&self, sql: &str) -> CliResult<u64> {
457 match self.provider.as_str() {
459 "postgresql" | "postgres" => self.execute_postgres_sql(sql).await,
460 "mysql" => self.execute_mysql_sql(sql).await,
461 "sqlite" => self.execute_sqlite_sql(sql).await,
462 _ => Err(CliError::Database(format!(
463 "Unsupported database provider: {}",
464 self.provider
465 ))),
466 }
467 }
468
469 async fn execute_postgres_sql(&self, sql: &str) -> CliResult<u64> {
471 let psql_result = Command::new("psql")
473 .args(["-d", &self.database_url, "-c", sql])
474 .output();
475
476 match psql_result {
477 Ok(output) if output.status.success() => {
478 let stdout = String::from_utf8_lossy(&output.stdout);
480 Ok(parse_affected_rows(&stdout).unwrap_or(0))
481 }
482 Ok(output) => {
483 let stderr = String::from_utf8_lossy(&output.stderr);
484 if stderr.contains("not found") || stderr.contains("No such file") {
486 Err(CliError::Command(
487 "psql not found. Install PostgreSQL client tools or use a Rust seed script.".to_string()
488 ))
489 } else {
490 Err(CliError::Database(format!(
491 "SQL execution failed: {}",
492 stderr
493 )))
494 }
495 }
496 Err(e) => {
497 let sqlx_result = Command::new("sqlx")
499 .args(["database", "seed"])
500 .env("DATABASE_URL", &self.database_url)
501 .stdin(std::process::Stdio::piped())
502 .output();
503
504 match sqlx_result {
505 Ok(output) if output.status.success() => Ok(0),
506 _ => Err(CliError::Command(format!(
507 "Failed to execute SQL. Install psql or use a Rust seed script: {}",
508 e
509 ))),
510 }
511 }
512 }
513 }
514
515 async fn execute_mysql_sql(&self, sql: &str) -> CliResult<u64> {
517 let url = url::Url::parse(&self.database_url)
519 .map_err(|e| CliError::Config(format!("Invalid MySQL URL: {}", e)))?;
520
521 let host = url.host_str().unwrap_or("localhost");
522 let port = url.port().unwrap_or(3306);
523 let user = url.username();
524 let password = url.password().unwrap_or("");
525 let database = url.path().trim_start_matches('/');
526
527 let mut cmd = Command::new("mysql");
528 cmd.args(["-h", host, "-P", &port.to_string(), "-u", user]);
529
530 if !password.is_empty() {
531 cmd.arg(format!("-p{}", password));
532 }
533
534 cmd.args(["-D", database, "-e", sql]);
535
536 let output = cmd.output()?;
537
538 if output.status.success() {
539 let stdout = String::from_utf8_lossy(&output.stdout);
540 Ok(parse_affected_rows(&stdout).unwrap_or(0))
541 } else {
542 let stderr = String::from_utf8_lossy(&output.stderr);
543 if stderr.contains("not found") || stderr.contains("No such file") {
544 Err(CliError::Command(
545 "mysql client not found. Install MySQL client tools or use a Rust seed script."
546 .to_string(),
547 ))
548 } else {
549 Err(CliError::Database(format!(
550 "SQL execution failed: {}",
551 stderr
552 )))
553 }
554 }
555 }
556
557 async fn execute_sqlite_sql(&self, sql: &str) -> CliResult<u64> {
559 let db_path = self
561 .database_url
562 .strip_prefix("sqlite://")
563 .or_else(|| self.database_url.strip_prefix("sqlite:"))
564 .unwrap_or(&self.database_url);
565
566 let output = Command::new("sqlite3").args([db_path, sql]).output()?;
567
568 if output.status.success() {
569 let stdout = String::from_utf8_lossy(&output.stdout);
570 Ok(parse_affected_rows(&stdout).unwrap_or(0))
571 } else {
572 let stderr = String::from_utf8_lossy(&output.stderr);
573 if stderr.contains("not found") || stderr.contains("No such file") {
574 Err(CliError::Command(
575 "sqlite3 not found. Install SQLite tools or use a Rust seed script."
576 .to_string(),
577 ))
578 } else {
579 Err(CliError::Database(format!(
580 "SQL execution failed: {}",
581 stderr
582 )))
583 }
584 }
585 }
586}
587
588#[derive(Debug)]
590pub struct SeedResult {
591 pub file_type: SeedFileType,
593 pub records_affected: u64,
595 pub tables_seeded: Vec<String>,
597 pub duration: std::time::Duration,
599}
600
601#[derive(Debug, Clone, Deserialize, Serialize)]
603pub struct SeedData {
604 #[serde(default)]
606 pub tables: HashMap<String, Vec<HashMap<String, serde_json::Value>>>,
607
608 #[serde(default)]
610 pub order: Vec<String>,
611
612 #[serde(default)]
614 pub truncate: bool,
615
616 #[serde(default)]
618 pub disable_fk_checks: bool,
619}
620
621pub fn find_seed_file(cwd: &Path, config: &Config) -> Option<PathBuf> {
627 if let Some(ref seed_path) = config.database.seed_path {
629 if seed_path.exists() {
630 return Some(seed_path.clone());
631 }
632 }
633
634 let candidates = [
636 cwd.join("seed.rs"),
637 cwd.join("seed.sql"),
638 cwd.join("seed.json"),
639 cwd.join("seed.toml"),
640 cwd.join("prax/seed.rs"),
641 cwd.join("prax/seed.sql"),
642 cwd.join("prax/seed.json"),
643 cwd.join("prax/seed.toml"),
644 cwd.join("prisma/seed.rs"),
645 cwd.join("prisma/seed.ts"), cwd.join("src/seed.rs"),
647 cwd.join("seeds/seed.rs"),
648 cwd.join("seeds/seed.sql"),
649 ];
650
651 candidates.into_iter().find(|p| p.exists())
652}
653
654pub fn get_database_url(config: &Config) -> CliResult<String> {
656 if let Some(ref url) = config.database.url {
658 let expanded = expand_env_var(url);
660 if !expanded.is_empty() && !expanded.contains("${") {
661 return Ok(expanded);
662 }
663 }
664
665 std::env::var("DATABASE_URL").map_err(|_| {
667 CliError::Config(
668 "Database URL not found. Set DATABASE_URL environment variable or configure in prax.toml"
669 .to_string(),
670 )
671 })
672}
673
/// Expands `${NAME}` and `$NAME` environment-variable references in `s`.
///
/// `${NAME}` accepts any non-empty name up to the next `}`. The bare `$NAME` form
/// matches the longest run of `A-Z`, `0-9` and `_` that does not start with a digit.
/// References to unset variables are left untouched, as is any `$` that does not
/// start a reference.
///
/// Expansion is a single left-to-right pass:
/// - substituted values are never re-scanned, so a value containing `$` is inserted
///   verbatim;
/// - each reference is replaced in place, so `$DB` cannot clobber the start of `$DB_USER`.
fn expand_env_var(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut rest = s;

    while let Some(dollar) = rest.find('$') {
        out.push_str(&rest[..dollar]);
        let after = &rest[dollar + 1..];

        if let Some(inner) = after.strip_prefix('{') {
            // `${NAME}` form: requires a non-empty name and a closing brace.
            if let Some(close) = inner.find('}').filter(|&close| close > 0) {
                let name = &inner[..close];
                match std::env::var(name) {
                    Ok(value) => out.push_str(&value),
                    Err(_) => {
                        out.push_str("${");
                        out.push_str(name);
                        out.push('}');
                    }
                }
                rest = &inner[close + 1..];
                continue;
            }
        } else {
            // `$NAME` form: `[A-Z_][A-Z0-9_]*`, all ASCII, so the byte count is a char boundary.
            let name_len = after
                .bytes()
                .enumerate()
                .take_while(|&(i, b)| {
                    b == b'_' || b.is_ascii_uppercase() || (i > 0 && b.is_ascii_digit())
                })
                .count();
            if name_len > 0 {
                let name = &after[..name_len];
                match std::env::var(name) {
                    Ok(value) => out.push_str(&value),
                    Err(_) => {
                        out.push('$');
                        out.push_str(name);
                    }
                }
                rest = &after[name_len..];
                continue;
            }
        }

        // Not a reference: keep the `$` literally and keep scanning after it.
        out.push('$');
        rest = after;
    }

    out.push_str(rest);
    out
}
698
699fn parse_seed_output(line: &str) -> Option<u64> {
701 let patterns = [
706 r"(?i)created\s+(\d+)",
707 r"(?i)seeded\s+(\d+)",
708 r"(?i)inserted[:\s]+(\d+)",
709 r"(?i)(\d+)\s+records?",
710 r"(?i)(\d+)\s+rows?",
711 ];
712
713 for pattern in patterns {
714 if let Ok(re) = regex_lite::Regex::new(pattern) {
715 if let Some(caps) = re.captures(line) {
716 if let Some(m) = caps.get(1) {
717 if let Ok(n) = m.as_str().parse() {
718 return Some(n);
719 }
720 }
721 }
722 }
723 }
724
725 None
726}
727
728fn parse_affected_rows(output: &str) -> Option<u64> {
730 let patterns = [
735 r"INSERT\s+\d+\s+(\d+)",
736 r"UPDATE\s+(\d+)",
737 r"DELETE\s+(\d+)",
738 r"(\d+)\s+rows?\s+affected",
739 ];
740
741 let mut total = 0u64;
742
743 for pattern in patterns {
744 if let Ok(re) = regex_lite::Regex::new(pattern) {
745 for caps in re.captures_iter(output) {
746 if let Some(m) = caps.get(1) {
747 if let Ok(n) = m.as_str().parse::<u64>() {
748 total += n;
749 }
750 }
751 }
752 }
753 }
754
755 if total > 0 { Some(total) } else { None }
756}
757
758fn create_seed_cargo_toml(project_root: &Path) -> CliResult<String> {
760 let workspace_cargo = project_root.join("Cargo.toml");
762 let prax_version = if workspace_cargo.exists() {
763 let content = std::fs::read_to_string(&workspace_cargo)?;
764 extract_prax_version(&content).unwrap_or_else(|| "0.2".to_string())
766 } else {
767 "0.2".to_string()
768 };
769
770 Ok(format!(
771 r#"[package]
772name = "seed"
773version = "0.1.0"
774edition = "2024"
775
776[dependencies]
777prax-orm = "{}"
778tokio = {{ version = "1", features = ["full"] }}
779"#,
780 prax_version
781 ))
782}
783
784fn extract_prax_version(content: &str) -> Option<String> {
786 let simple_re = regex_lite::Regex::new(r#"prax-orm\s*=\s*"([^"]+)""#).ok()?;
788 if let Some(caps) = simple_re.captures(content) {
789 return Some(caps.get(1)?.as_str().to_string());
790 }
791
792 let complex_re = regex_lite::Regex::new(r#"prax-orm\s*=\s*\{[^}]*version\s*=\s*"([^"]+)""#).ok()?;
793 if let Some(caps) = complex_re.captures(content) {
794 return Some(caps.get(1)?.as_str().to_string());
795 }
796
797 None
798}
799
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_seed_file_type_detection() {
        let cases = [
            ("seed.rs", Some(SeedFileType::Rust)),
            ("seed.sql", Some(SeedFileType::Sql)),
            ("data.json", Some(SeedFileType::Json)),
            ("data.toml", Some(SeedFileType::Toml)),
            ("seed.txt", None),
        ];
        for (file_name, expected) in cases {
            assert_eq!(
                SeedFileType::from_path(Path::new(file_name)),
                expected,
                "{}",
                file_name
            );
        }
    }

    #[test]
    fn test_parse_seed_output() {
        let cases = [
            ("Created 10 users", Some(10)),
            ("Seeded 100 records", Some(100)),
            ("Inserted: 50", Some(50)),
            ("5 rows affected", Some(5)),
            ("no numbers here", None),
        ];
        for (line, expected) in cases {
            assert_eq!(parse_seed_output(line), expected, "{}", line);
        }
    }

    #[test]
    fn test_parse_affected_rows() {
        let cases = [
            ("INSERT 0 5", 5),
            ("UPDATE 3", 3),
            ("Query OK, 10 rows affected", 10),
        ];
        for (client_output, expected) in cases {
            assert_eq!(
                parse_affected_rows(client_output),
                Some(expected),
                "{}",
                client_output
            );
        }
    }

    #[test]
    fn test_expand_env_var() {
        // SAFETY: no other test in this module reads or writes the process environment.
        unsafe {
            std::env::set_var("TEST_VAR", "test_value");
        }

        let cases = [
            ("${TEST_VAR}", "test_value"),
            ("$TEST_VAR", "test_value"),
            (
                "postgres://${TEST_VAR}@localhost",
                "postgres://test_value@localhost",
            ),
        ];
        for (input, expected) in cases {
            assert_eq!(expand_env_var(input), expected, "{}", input);
        }

        // SAFETY: see above.
        unsafe {
            std::env::remove_var("TEST_VAR");
        }
    }
}