1use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::process::Command;
12
13use serde::{Deserialize, Serialize};
14
15use crate::config::Config;
16use crate::error::{CliError, CliResult};
17use crate::output;
18
/// The on-disk format of a seed file, selected from the file's extension.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum SeedFileType {
    /// A Rust program (`.rs`) that is compiled with Cargo and then executed.
    Rust,
    /// Raw SQL (`.sql`) executed through the provider's command-line client.
    Sql,
    /// Table-to-rows data (`.json`) that is turned into `INSERT` statements.
    Json,
    /// Table-to-rows data (`.toml`) that is turned into `INSERT` statements.
    Toml,
}
31
32impl SeedFileType {
33 pub fn from_path(path: &Path) -> Option<Self> {
35 match path.extension()?.to_str()? {
36 "rs" => Some(Self::Rust),
37 "sql" => Some(Self::Sql),
38 "json" => Some(Self::Json),
39 "toml" => Some(Self::Toml),
40 _ => None,
41 }
42 }
43}
44
45#[derive(Debug, Clone)]
47pub struct SeedRunner {
48 pub seed_path: PathBuf,
50 pub file_type: SeedFileType,
52 pub database_url: String,
54 pub provider: String,
56 pub cwd: PathBuf,
58 pub environment: String,
60 pub reset_before_seed: bool,
62}
63
64impl SeedRunner {
65 pub fn new(
67 seed_path: PathBuf,
68 database_url: String,
69 provider: String,
70 cwd: PathBuf,
71 ) -> CliResult<Self> {
72 let file_type = SeedFileType::from_path(&seed_path).ok_or_else(|| {
73 CliError::Config(format!(
74 "Unsupported seed file type: {}. Supported: .rs, .sql, .json, .toml",
75 seed_path.display()
76 ))
77 })?;
78
79 Ok(Self {
80 seed_path,
81 file_type,
82 database_url,
83 provider,
84 cwd,
85 environment: std::env::var("PRAX_ENV").unwrap_or_else(|_| "development".to_string()),
86 reset_before_seed: false,
87 })
88 }
89
90 pub fn with_environment(mut self, env: impl Into<String>) -> Self {
92 self.environment = env.into();
93 self
94 }
95
96 pub fn with_reset(mut self, reset: bool) -> Self {
98 self.reset_before_seed = reset;
99 self
100 }
101
102 pub async fn run(&self) -> CliResult<SeedResult> {
104 match self.file_type {
105 SeedFileType::Rust => self.run_rust_seed().await,
106 SeedFileType::Sql => self.run_sql_seed().await,
107 SeedFileType::Json => self.run_json_seed().await,
108 SeedFileType::Toml => self.run_toml_seed().await,
109 }
110 }
111
112 async fn run_rust_seed(&self) -> CliResult<SeedResult> {
114 output::step(1, 4, "Compiling seed script...");
115
116 let cargo_toml = self.cwd.join("Cargo.toml");
118 if !cargo_toml.exists() {
119 return Err(CliError::Config(
120 "No Cargo.toml found. Rust seed scripts require a Rust project.".to_string(),
121 ));
122 }
123
124 let seed_name = self
126 .seed_path
127 .file_stem()
128 .and_then(|s| s.to_str())
129 .unwrap_or("seed");
130
131 let has_bin_target = self.check_bin_target(seed_name)?;
133
134 let mut records_affected = 0u64;
135
136 if has_bin_target {
137 output::step(2, 4, &format!("Building seed binary '{}'...", seed_name));
139
140 let build_status = Command::new("cargo")
141 .args(["build", "--bin", seed_name, "--release"])
142 .current_dir(&self.cwd)
143 .env("DATABASE_URL", &self.database_url)
144 .env("PRAX_ENV", &self.environment)
145 .status()?;
146
147 if !build_status.success() {
148 return Err(CliError::Command("Failed to build seed binary".to_string()));
149 }
150
151 output::step(3, 4, "Running seed...");
152
153 let run_output = Command::new("cargo")
154 .args(["run", "--bin", seed_name, "--release"])
155 .current_dir(&self.cwd)
156 .env("DATABASE_URL", &self.database_url)
157 .env("PRAX_ENV", &self.environment)
158 .output()?;
159
160 if !run_output.status.success() {
161 let stderr = String::from_utf8_lossy(&run_output.stderr);
162 return Err(CliError::Command(format!("Seed failed: {}", stderr)));
163 }
164
165 let stdout = String::from_utf8_lossy(&run_output.stdout);
167 for line in stdout.lines() {
168 output::list_item(line);
169 if let Some(count) = parse_seed_output(line) {
171 records_affected += count;
172 }
173 }
174
175 output::step(4, 4, "Verifying seed data...");
176 } else {
177 output::step(2, 4, "Compiling standalone seed script...");
179
180 let temp_dir = std::env::temp_dir().join("prax_seed");
182 std::fs::create_dir_all(&temp_dir)?;
183
184 let output_binary = temp_dir.join(seed_name);
185
186 let seed_content = std::fs::read_to_string(&self.seed_path)?;
188
189 if seed_content.contains("use prax") || seed_content.contains("#[tokio::main]") {
190 output::list_item("Creating temporary build environment...");
192
193 let temp_project = temp_dir.join("seed_project");
194 std::fs::create_dir_all(temp_project.join("src"))?;
195
196 std::fs::copy(&self.seed_path, temp_project.join("src/main.rs"))?;
198
199 let seed_cargo = create_seed_cargo_toml(&self.cwd)?;
201 std::fs::write(temp_project.join("Cargo.toml"), seed_cargo)?;
202
203 let build_status = Command::new("cargo")
205 .args(["build", "--release"])
206 .current_dir(&temp_project)
207 .env("DATABASE_URL", &self.database_url)
208 .env("PRAX_ENV", &self.environment)
209 .status()?;
210
211 if !build_status.success() {
212 return Err(CliError::Command("Failed to compile seed script".to_string()));
213 }
214
215 let built_binary = temp_project.join("target/release/seed");
217 if built_binary.exists() {
218 std::fs::copy(&built_binary, &output_binary)?;
219 }
220 } else {
221 return Err(CliError::Config(
222 "Seed script must be a valid Rust file with a main function".to_string(),
223 ));
224 }
225
226 output::step(3, 4, "Running seed...");
227
228 let run_output = Command::new(&output_binary)
229 .current_dir(&self.cwd)
230 .env("DATABASE_URL", &self.database_url)
231 .env("PRAX_ENV", &self.environment)
232 .output()?;
233
234 if !run_output.status.success() {
235 let stderr = String::from_utf8_lossy(&run_output.stderr);
236 return Err(CliError::Command(format!("Seed failed: {}", stderr)));
237 }
238
239 let stdout = String::from_utf8_lossy(&run_output.stdout);
240 for line in stdout.lines() {
241 output::list_item(line);
242 if let Some(count) = parse_seed_output(line) {
243 records_affected += count;
244 }
245 }
246
247 output::step(4, 4, "Verifying seed data...");
248 }
249
250 Ok(SeedResult {
251 file_type: self.file_type,
252 records_affected,
253 tables_seeded: Vec::new(),
254 duration: std::time::Duration::from_secs(0),
255 })
256 }
257
258 async fn run_sql_seed(&self) -> CliResult<SeedResult> {
260 output::step(1, 3, "Reading SQL seed file...");
261
262 let sql_content = std::fs::read_to_string(&self.seed_path)?;
263
264 let statements: Vec<&str> = sql_content
266 .split(';')
267 .map(|s| s.trim())
268 .filter(|s| !s.is_empty() && !s.starts_with("--"))
269 .collect();
270
271 output::list_item(&format!("Found {} SQL statements", statements.len()));
272
273 output::step(2, 3, "Executing SQL...");
274
275 let records = self.execute_sql(&sql_content).await?;
277
278 output::step(3, 3, "Verifying seed data...");
279
280 Ok(SeedResult {
281 file_type: self.file_type,
282 records_affected: records,
283 tables_seeded: Vec::new(),
284 duration: std::time::Duration::from_secs(0),
285 })
286 }
287
288 async fn run_json_seed(&self) -> CliResult<SeedResult> {
290 output::step(1, 4, "Reading JSON seed file...");
291
292 let json_content = std::fs::read_to_string(&self.seed_path)?;
293 let seed_data: SeedData =
294 serde_json::from_str(&json_content).map_err(|e| CliError::Config(e.to_string()))?;
295
296 output::step(2, 4, "Validating seed data...");
297 output::list_item(&format!("Found {} tables to seed", seed_data.tables.len()));
298
299 output::step(3, 4, "Inserting seed data...");
300
301 let mut total_records = 0u64;
302 let mut tables_seeded = Vec::new();
303
304 for (table_name, records) in &seed_data.tables {
305 let sql = self.generate_insert_sql(table_name, records)?;
306 let count = self.execute_sql(&sql).await?;
307 output::list_item(&format!(" {} - {} records", table_name, records.len()));
308 total_records += count;
309 tables_seeded.push(table_name.clone());
310 }
311
312 output::step(4, 4, "Verifying seed data...");
313
314 Ok(SeedResult {
315 file_type: self.file_type,
316 records_affected: total_records,
317 tables_seeded,
318 duration: std::time::Duration::from_secs(0),
319 })
320 }
321
322 async fn run_toml_seed(&self) -> CliResult<SeedResult> {
324 output::step(1, 4, "Reading TOML seed file...");
325
326 let toml_content = std::fs::read_to_string(&self.seed_path)?;
327 let seed_data: SeedData =
328 toml::from_str(&toml_content).map_err(|e| CliError::Config(e.to_string()))?;
329
330 output::step(2, 4, "Validating seed data...");
331 output::list_item(&format!("Found {} tables to seed", seed_data.tables.len()));
332
333 output::step(3, 4, "Inserting seed data...");
334
335 let mut total_records = 0u64;
336 let mut tables_seeded = Vec::new();
337
338 for (table_name, records) in &seed_data.tables {
339 let sql = self.generate_insert_sql(table_name, records)?;
340 let count = self.execute_sql(&sql).await?;
341 output::list_item(&format!(" {} - {} records", table_name, records.len()));
342 total_records += count;
343 tables_seeded.push(table_name.clone());
344 }
345
346 output::step(4, 4, "Verifying seed data...");
347
348 Ok(SeedResult {
349 file_type: self.file_type,
350 records_affected: total_records,
351 tables_seeded,
352 duration: std::time::Duration::from_secs(0),
353 })
354 }
355
356 fn check_bin_target(&self, name: &str) -> CliResult<bool> {
358 let cargo_toml = self.cwd.join("Cargo.toml");
359 let content = std::fs::read_to_string(&cargo_toml)?;
360
361 Ok(content.contains(&format!("name = \"{}\"", name))
363 || content.contains(&format!("name = '{}'", name)))
364 }
365
366 fn generate_insert_sql(
368 &self,
369 table: &str,
370 records: &[HashMap<String, serde_json::Value>],
371 ) -> CliResult<String> {
372 if records.is_empty() {
373 return Ok(String::new());
374 }
375
376 let mut sql = String::new();
377
378 let columns: Vec<&String> = records[0].keys().collect();
380 let column_names = columns
381 .iter()
382 .map(|c| format!("\"{}\"", c))
383 .collect::<Vec<_>>()
384 .join(", ");
385
386 for record in records {
387 let values = columns
388 .iter()
389 .map(|col| {
390 record
391 .get(*col)
392 .map(|v| self.value_to_sql(v))
393 .unwrap_or_else(|| "NULL".to_string())
394 })
395 .collect::<Vec<_>>()
396 .join(", ");
397
398 sql.push_str(&format!(
399 "INSERT INTO \"{}\" ({}) VALUES ({});\n",
400 table, column_names, values
401 ));
402 }
403
404 Ok(sql)
405 }
406
407 fn value_to_sql(&self, value: &serde_json::Value) -> String {
409 match value {
410 serde_json::Value::Null => "NULL".to_string(),
411 serde_json::Value::Bool(b) => {
412 if *b {
413 "TRUE".to_string()
414 } else {
415 "FALSE".to_string()
416 }
417 }
418 serde_json::Value::Number(n) => n.to_string(),
419 serde_json::Value::String(s) => {
420 match s.as_str() {
422 "now()" | "NOW()" => match self.provider.as_str() {
423 "postgresql" => "CURRENT_TIMESTAMP".to_string(),
424 "mysql" => "NOW()".to_string(),
425 "sqlite" => "datetime('now')".to_string(),
426 _ => "CURRENT_TIMESTAMP".to_string(),
427 },
428 "uuid()" | "UUID()" => match self.provider.as_str() {
429 "postgresql" => "gen_random_uuid()".to_string(),
430 "mysql" => "UUID()".to_string(),
431 "sqlite" => format!("'{}'", uuid::Uuid::new_v4()),
432 _ => "gen_random_uuid()".to_string(),
433 },
434 _ => format!("'{}'", s.replace('\'', "''")),
435 }
436 }
437 serde_json::Value::Array(arr) => {
438 let items = arr
440 .iter()
441 .map(|v| self.value_to_sql(v))
442 .collect::<Vec<_>>()
443 .join(", ");
444 format!("ARRAY[{}]", items)
445 }
446 serde_json::Value::Object(_) => {
447 format!("'{}'", value)
449 }
450 }
451 }
452
453 async fn execute_sql(&self, sql: &str) -> CliResult<u64> {
455 match self.provider.as_str() {
457 "postgresql" | "postgres" => self.execute_postgres_sql(sql).await,
458 "mysql" => self.execute_mysql_sql(sql).await,
459 "sqlite" => self.execute_sqlite_sql(sql).await,
460 _ => Err(CliError::Database(format!(
461 "Unsupported database provider: {}",
462 self.provider
463 ))),
464 }
465 }
466
467 async fn execute_postgres_sql(&self, sql: &str) -> CliResult<u64> {
469 let psql_result = Command::new("psql")
471 .args(["-d", &self.database_url, "-c", sql])
472 .output();
473
474 match psql_result {
475 Ok(output) if output.status.success() => {
476 let stdout = String::from_utf8_lossy(&output.stdout);
478 Ok(parse_affected_rows(&stdout).unwrap_or(0))
479 }
480 Ok(output) => {
481 let stderr = String::from_utf8_lossy(&output.stderr);
482 if stderr.contains("not found") || stderr.contains("No such file") {
484 Err(CliError::Command(
485 "psql not found. Install PostgreSQL client tools or use a Rust seed script.".to_string()
486 ))
487 } else {
488 Err(CliError::Database(format!("SQL execution failed: {}", stderr)))
489 }
490 }
491 Err(e) => {
492 let sqlx_result = Command::new("sqlx")
494 .args(["database", "seed"])
495 .env("DATABASE_URL", &self.database_url)
496 .stdin(std::process::Stdio::piped())
497 .output();
498
499 match sqlx_result {
500 Ok(output) if output.status.success() => Ok(0),
501 _ => Err(CliError::Command(format!(
502 "Failed to execute SQL. Install psql or use a Rust seed script: {}",
503 e
504 ))),
505 }
506 }
507 }
508 }
509
510 async fn execute_mysql_sql(&self, sql: &str) -> CliResult<u64> {
512 let url = url::Url::parse(&self.database_url)
514 .map_err(|e| CliError::Config(format!("Invalid MySQL URL: {}", e)))?;
515
516 let host = url.host_str().unwrap_or("localhost");
517 let port = url.port().unwrap_or(3306);
518 let user = url.username();
519 let password = url.password().unwrap_or("");
520 let database = url.path().trim_start_matches('/');
521
522 let mut cmd = Command::new("mysql");
523 cmd.args(["-h", host, "-P", &port.to_string(), "-u", user]);
524
525 if !password.is_empty() {
526 cmd.arg(format!("-p{}", password));
527 }
528
529 cmd.args(["-D", database, "-e", sql]);
530
531 let output = cmd.output()?;
532
533 if output.status.success() {
534 let stdout = String::from_utf8_lossy(&output.stdout);
535 Ok(parse_affected_rows(&stdout).unwrap_or(0))
536 } else {
537 let stderr = String::from_utf8_lossy(&output.stderr);
538 if stderr.contains("not found") || stderr.contains("No such file") {
539 Err(CliError::Command(
540 "mysql client not found. Install MySQL client tools or use a Rust seed script."
541 .to_string(),
542 ))
543 } else {
544 Err(CliError::Database(format!("SQL execution failed: {}", stderr)))
545 }
546 }
547 }
548
549 async fn execute_sqlite_sql(&self, sql: &str) -> CliResult<u64> {
551 let db_path = self
553 .database_url
554 .strip_prefix("sqlite://")
555 .or_else(|| self.database_url.strip_prefix("sqlite:"))
556 .unwrap_or(&self.database_url);
557
558 let output = Command::new("sqlite3")
559 .args([db_path, sql])
560 .output()?;
561
562 if output.status.success() {
563 let stdout = String::from_utf8_lossy(&output.stdout);
564 Ok(parse_affected_rows(&stdout).unwrap_or(0))
565 } else {
566 let stderr = String::from_utf8_lossy(&output.stderr);
567 if stderr.contains("not found") || stderr.contains("No such file") {
568 Err(CliError::Command(
569 "sqlite3 not found. Install SQLite tools or use a Rust seed script."
570 .to_string(),
571 ))
572 } else {
573 Err(CliError::Database(format!("SQL execution failed: {}", stderr)))
574 }
575 }
576 }
577}
578
579#[derive(Debug)]
581pub struct SeedResult {
582 pub file_type: SeedFileType,
584 pub records_affected: u64,
586 pub tables_seeded: Vec<String>,
588 pub duration: std::time::Duration,
590}
591
592#[derive(Debug, Clone, Deserialize, Serialize)]
594pub struct SeedData {
595 #[serde(default)]
597 pub tables: HashMap<String, Vec<HashMap<String, serde_json::Value>>>,
598
599 #[serde(default)]
601 pub order: Vec<String>,
602
603 #[serde(default)]
605 pub truncate: bool,
606
607 #[serde(default)]
609 pub disable_fk_checks: bool,
610}
611
612pub fn find_seed_file(cwd: &Path, config: &Config) -> Option<PathBuf> {
618 if let Some(ref seed_path) = config.database.seed_path {
620 if seed_path.exists() {
621 return Some(seed_path.clone());
622 }
623 }
624
625 let candidates = [
627 cwd.join("seed.rs"),
628 cwd.join("seed.sql"),
629 cwd.join("seed.json"),
630 cwd.join("seed.toml"),
631 cwd.join("prax/seed.rs"),
632 cwd.join("prax/seed.sql"),
633 cwd.join("prax/seed.json"),
634 cwd.join("prax/seed.toml"),
635 cwd.join("prisma/seed.rs"),
636 cwd.join("prisma/seed.ts"), cwd.join("src/seed.rs"),
638 cwd.join("seeds/seed.rs"),
639 cwd.join("seeds/seed.sql"),
640 ];
641
642 candidates.into_iter().find(|p| p.exists())
643}
644
645pub fn get_database_url(config: &Config) -> CliResult<String> {
647 if let Some(ref url) = config.database.url {
649 let expanded = expand_env_var(url);
651 if !expanded.is_empty() && !expanded.contains("${") {
652 return Ok(expanded);
653 }
654 }
655
656 std::env::var("DATABASE_URL").map_err(|_| {
658 CliError::Config(
659 "Database URL not found. Set DATABASE_URL environment variable or configure in prax.toml"
660 .to_string(),
661 )
662 })
663}
664
665fn expand_env_var(s: &str) -> String {
667 let mut result = s.to_string();
668
669 let re = regex_lite::Regex::new(r"\$\{([^}]+)\}").unwrap();
671 for cap in re.captures_iter(s) {
672 let var_name = &cap[1];
673 if let Ok(value) = std::env::var(var_name) {
674 result = result.replace(&cap[0], &value);
675 }
676 }
677
678 let re2 = regex_lite::Regex::new(r"\$([A-Z_][A-Z0-9_]*)").unwrap();
680 for cap in re2.captures_iter(&result.clone()) {
681 let var_name = &cap[1];
682 if let Ok(value) = std::env::var(var_name) {
683 result = result.replace(&cap[0], &value);
684 }
685 }
686
687 result
688}
689
690fn parse_seed_output(line: &str) -> Option<u64> {
692 let patterns = [
697 r"(?i)created\s+(\d+)",
698 r"(?i)seeded\s+(\d+)",
699 r"(?i)inserted[:\s]+(\d+)",
700 r"(?i)(\d+)\s+records?",
701 r"(?i)(\d+)\s+rows?",
702 ];
703
704 for pattern in patterns {
705 if let Ok(re) = regex_lite::Regex::new(pattern) {
706 if let Some(caps) = re.captures(line) {
707 if let Some(m) = caps.get(1) {
708 if let Ok(n) = m.as_str().parse() {
709 return Some(n);
710 }
711 }
712 }
713 }
714 }
715
716 None
717}
718
719fn parse_affected_rows(output: &str) -> Option<u64> {
721 let patterns = [
726 r"INSERT\s+\d+\s+(\d+)",
727 r"UPDATE\s+(\d+)",
728 r"DELETE\s+(\d+)",
729 r"(\d+)\s+rows?\s+affected",
730 ];
731
732 let mut total = 0u64;
733
734 for pattern in patterns {
735 if let Ok(re) = regex_lite::Regex::new(pattern) {
736 for caps in re.captures_iter(output) {
737 if let Some(m) = caps.get(1) {
738 if let Ok(n) = m.as_str().parse::<u64>() {
739 total += n;
740 }
741 }
742 }
743 }
744 }
745
746 if total > 0 {
747 Some(total)
748 } else {
749 None
750 }
751}
752
753fn create_seed_cargo_toml(project_root: &Path) -> CliResult<String> {
755 let workspace_cargo = project_root.join("Cargo.toml");
757 let prax_version = if workspace_cargo.exists() {
758 let content = std::fs::read_to_string(&workspace_cargo)?;
759 extract_prax_version(&content).unwrap_or_else(|| "0.2".to_string())
761 } else {
762 "0.2".to_string()
763 };
764
765 Ok(format!(
766 r#"[package]
767name = "seed"
768version = "0.1.0"
769edition = "2024"
770
771[dependencies]
772prax = "{}"
773tokio = {{ version = "1", features = ["full"] }}
774"#,
775 prax_version
776 ))
777}
778
779fn extract_prax_version(content: &str) -> Option<String> {
781 let simple_re = regex_lite::Regex::new(r#"prax\s*=\s*"([^"]+)""#).ok()?;
783 if let Some(caps) = simple_re.captures(content) {
784 return Some(caps.get(1)?.as_str().to_string());
785 }
786
787 let complex_re = regex_lite::Regex::new(r#"prax\s*=\s*\{[^}]*version\s*=\s*"([^"]+)""#).ok()?;
788 if let Some(caps) = complex_re.captures(content) {
789 return Some(caps.get(1)?.as_str().to_string());
790 }
791
792 None
793}
794
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_seed_file_type_detection() {
        let cases = [
            ("seed.rs", Some(SeedFileType::Rust)),
            ("seed.sql", Some(SeedFileType::Sql)),
            ("data.json", Some(SeedFileType::Json)),
            ("data.toml", Some(SeedFileType::Toml)),
            ("seed.txt", None),
        ];
        for (file, expected) in cases {
            assert_eq!(SeedFileType::from_path(Path::new(file)), expected, "file: {file}");
        }
    }

    #[test]
    fn test_parse_seed_output() {
        assert_eq!(parse_seed_output("Created 10 users"), Some(10));
        assert_eq!(parse_seed_output("Seeded 100 records"), Some(100));
        assert_eq!(parse_seed_output("Inserted: 50"), Some(50));
        assert_eq!(parse_seed_output("5 rows affected"), Some(5));
        assert_eq!(parse_seed_output("no numbers here"), None);
    }

    #[test]
    fn test_parse_affected_rows() {
        assert_eq!(parse_affected_rows("INSERT 0 5"), Some(5));
        assert_eq!(parse_affected_rows("UPDATE 3"), Some(3));
        assert_eq!(parse_affected_rows("Query OK, 10 rows affected"), Some(10));
    }

    #[test]
    fn test_expand_env_var() {
        // A crate-specific name, so this test neither clobbers nor is clobbered by a
        // generic `TEST_VAR` from the developer's shell or from another test.
        const VAR: &str = "PRAX_SEED_TEST_EXPAND_VAR";

        // SAFETY: modifying the environment is only sound while no other thread reads
        // or writes it. No other test in this module touches the environment, and the
        // unique name keeps this test independent of variables other tests may use.
        unsafe {
            std::env::set_var(VAR, "test_value");
        }

        assert_eq!(expand_env_var("${PRAX_SEED_TEST_EXPAND_VAR}"), "test_value");
        assert_eq!(expand_env_var("$PRAX_SEED_TEST_EXPAND_VAR"), "test_value");
        assert_eq!(
            expand_env_var("postgres://${PRAX_SEED_TEST_EXPAND_VAR}@localhost"),
            "postgres://test_value@localhost"
        );

        // SAFETY: same reasoning as for `set_var` above.
        unsafe {
            std::env::remove_var(VAR);
        }
    }
}
861