prax_cli/commands/
seed.rs

1//! Database seeding implementation.
2//!
3//! Supports multiple seed file types:
4//! - `.rs` - Rust seed scripts (compiled and executed)
5//! - `.sql` - Raw SQL files (executed directly)
6//! - `.json` - JSON data files (declarative seeding)
7//! - `.toml` - TOML data files (declarative seeding)
8
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::process::Command;
12
13use serde::{Deserialize, Serialize};
14
15use crate::config::Config;
16use crate::error::{CliError, CliResult};
17use crate::output;
18
/// The kinds of seed file the runner knows how to execute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SeedFileType {
    /// A Rust program (`.rs`), built with Cargo and then executed.
    Rust,
    /// Raw SQL (`.sql`), handed to the provider's command-line client.
    Sql,
    /// Declarative rows in JSON (`.json`).
    Json,
    /// Declarative rows in TOML (`.toml`).
    Toml,
}

impl SeedFileType {
    /// Map a file path to its seed type by inspecting the extension.
    ///
    /// Matching is exact and case-sensitive, so `seed.SQL` is not recognised.
    /// Returns `None` when the path has no extension, the extension is not
    /// valid UTF-8, or it names an unsupported format.
    pub fn from_path(path: &Path) -> Option<Self> {
        let extension = path.extension().and_then(|ext| ext.to_str())?;
        let file_type = match extension {
            "rs" => Self::Rust,
            "sql" => Self::Sql,
            "json" => Self::Json,
            "toml" => Self::Toml,
            _ => return None,
        };
        Some(file_type)
    }
}
44
/// Configuration for running a single seed file.
///
/// Created with [`SeedRunner::new`], adjusted with the `with_*` builder
/// methods, and executed with [`SeedRunner::run`].
#[derive(Debug, Clone)]
pub struct SeedRunner {
    /// Path to the seed file to execute.
    pub seed_path: PathBuf,
    /// Seed file type, inferred from the extension of `seed_path`.
    pub file_type: SeedFileType,
    /// Database URL. Exported to seed programs as `DATABASE_URL` and used to
    /// connect the provider's command-line client.
    pub database_url: String,
    /// Database provider name (e.g. `postgresql`, `mysql`, `sqlite`); selects
    /// the command-line client and the SQL dialect used for generated SQL.
    pub provider: String,
    /// Project directory: where `Cargo.toml` is looked up and where seed
    /// programs are run.
    pub cwd: PathBuf,
    /// Environment name (e.g. `development`, `staging`, `production`),
    /// exported to seed programs as `PRAX_ENV`.
    pub environment: String,
    /// Whether the database should be reset before seeding.
    ///
    /// NOTE(review): none of `SeedRunner`'s methods read this flag, so any
    /// reset has to be performed by the caller — confirm against callers.
    pub reset_before_seed: bool,
}
63
64impl SeedRunner {
65    /// Create a new seed runner
66    pub fn new(
67        seed_path: PathBuf,
68        database_url: String,
69        provider: String,
70        cwd: PathBuf,
71    ) -> CliResult<Self> {
72        let file_type = SeedFileType::from_path(&seed_path).ok_or_else(|| {
73            CliError::Config(format!(
74                "Unsupported seed file type: {}. Supported: .rs, .sql, .json, .toml",
75                seed_path.display()
76            ))
77        })?;
78
79        Ok(Self {
80            seed_path,
81            file_type,
82            database_url,
83            provider,
84            cwd,
85            environment: std::env::var("PRAX_ENV").unwrap_or_else(|_| "development".to_string()),
86            reset_before_seed: false,
87        })
88    }
89
90    /// Set environment
91    pub fn with_environment(mut self, env: impl Into<String>) -> Self {
92        self.environment = env.into();
93        self
94    }
95
96    /// Set reset before seed
97    pub fn with_reset(mut self, reset: bool) -> Self {
98        self.reset_before_seed = reset;
99        self
100    }
101
102    /// Run the seed
103    pub async fn run(&self) -> CliResult<SeedResult> {
104        match self.file_type {
105            SeedFileType::Rust => self.run_rust_seed().await,
106            SeedFileType::Sql => self.run_sql_seed().await,
107            SeedFileType::Json => self.run_json_seed().await,
108            SeedFileType::Toml => self.run_toml_seed().await,
109        }
110    }
111
112    /// Run a Rust seed script
113    async fn run_rust_seed(&self) -> CliResult<SeedResult> {
114        output::step(1, 4, "Compiling seed script...");
115
116        // Check if we're in a Cargo workspace
117        let cargo_toml = self.cwd.join("Cargo.toml");
118        if !cargo_toml.exists() {
119            return Err(CliError::Config(
120                "No Cargo.toml found. Rust seed scripts require a Rust project.".to_string(),
121            ));
122        }
123
124        // Create a temporary bin target or use cargo run
125        let seed_name = self
126            .seed_path
127            .file_stem()
128            .and_then(|s| s.to_str())
129            .unwrap_or("seed");
130
131        // Check if there's a [[bin]] entry for the seed, or we need to compile manually
132        let has_bin_target = self.check_bin_target(seed_name)?;
133
134        let mut records_affected = 0u64;
135
136        if has_bin_target {
137            // Use cargo run directly
138            output::step(2, 4, &format!("Building seed binary '{}'...", seed_name));
139
140            let build_status = Command::new("cargo")
141                .args(["build", "--bin", seed_name, "--release"])
142                .current_dir(&self.cwd)
143                .env("DATABASE_URL", &self.database_url)
144                .env("PRAX_ENV", &self.environment)
145                .status()?;
146
147            if !build_status.success() {
148                return Err(CliError::Command("Failed to build seed binary".to_string()));
149            }
150
151            output::step(3, 4, "Running seed...");
152
153            let run_output = Command::new("cargo")
154                .args(["run", "--bin", seed_name, "--release"])
155                .current_dir(&self.cwd)
156                .env("DATABASE_URL", &self.database_url)
157                .env("PRAX_ENV", &self.environment)
158                .output()?;
159
160            if !run_output.status.success() {
161                let stderr = String::from_utf8_lossy(&run_output.stderr);
162                return Err(CliError::Command(format!("Seed failed: {}", stderr)));
163            }
164
165            // Parse output for record count if available
166            let stdout = String::from_utf8_lossy(&run_output.stdout);
167            for line in stdout.lines() {
168                output::list_item(line);
169                // Try to parse seed output for counts
170                if let Some(count) = parse_seed_output(line) {
171                    records_affected += count;
172                }
173            }
174
175            output::step(4, 4, "Verifying seed data...");
176        } else {
177            // Compile and run as a standalone script using rustc
178            output::step(2, 4, "Compiling standalone seed script...");
179
180            // Create temp directory for compiled seed
181            let temp_dir = std::env::temp_dir().join("prax_seed");
182            std::fs::create_dir_all(&temp_dir)?;
183
184            let output_binary = temp_dir.join(seed_name);
185
186            // Try to compile with cargo if it looks like a full Rust file
187            let seed_content = std::fs::read_to_string(&self.seed_path)?;
188
189            if seed_content.contains("use prax") || seed_content.contains("#[tokio::main]") {
190                // This is a standalone Rust file - we'll create a temporary Cargo project
191                output::list_item("Creating temporary build environment...");
192
193                let temp_project = temp_dir.join("seed_project");
194                std::fs::create_dir_all(temp_project.join("src"))?;
195
196                // Copy seed file
197                std::fs::copy(&self.seed_path, temp_project.join("src/main.rs"))?;
198
199                // Create Cargo.toml for the seed
200                let seed_cargo = create_seed_cargo_toml(&self.cwd)?;
201                std::fs::write(temp_project.join("Cargo.toml"), seed_cargo)?;
202
203                // Build
204                let build_status = Command::new("cargo")
205                    .args(["build", "--release"])
206                    .current_dir(&temp_project)
207                    .env("DATABASE_URL", &self.database_url)
208                    .env("PRAX_ENV", &self.environment)
209                    .status()?;
210
211                if !build_status.success() {
212                    return Err(CliError::Command("Failed to compile seed script".to_string()));
213                }
214
215                // Copy binary
216                let built_binary = temp_project.join("target/release/seed");
217                if built_binary.exists() {
218                    std::fs::copy(&built_binary, &output_binary)?;
219                }
220            } else {
221                return Err(CliError::Config(
222                    "Seed script must be a valid Rust file with a main function".to_string(),
223                ));
224            }
225
226            output::step(3, 4, "Running seed...");
227
228            let run_output = Command::new(&output_binary)
229                .current_dir(&self.cwd)
230                .env("DATABASE_URL", &self.database_url)
231                .env("PRAX_ENV", &self.environment)
232                .output()?;
233
234            if !run_output.status.success() {
235                let stderr = String::from_utf8_lossy(&run_output.stderr);
236                return Err(CliError::Command(format!("Seed failed: {}", stderr)));
237            }
238
239            let stdout = String::from_utf8_lossy(&run_output.stdout);
240            for line in stdout.lines() {
241                output::list_item(line);
242                if let Some(count) = parse_seed_output(line) {
243                    records_affected += count;
244                }
245            }
246
247            output::step(4, 4, "Verifying seed data...");
248        }
249
250        Ok(SeedResult {
251            file_type: self.file_type,
252            records_affected,
253            tables_seeded: Vec::new(),
254            duration: std::time::Duration::from_secs(0),
255        })
256    }
257
258    /// Run a SQL seed file
259    async fn run_sql_seed(&self) -> CliResult<SeedResult> {
260        output::step(1, 3, "Reading SQL seed file...");
261
262        let sql_content = std::fs::read_to_string(&self.seed_path)?;
263
264        // Count statements for progress
265        let statements: Vec<&str> = sql_content
266            .split(';')
267            .map(|s| s.trim())
268            .filter(|s| !s.is_empty() && !s.starts_with("--"))
269            .collect();
270
271        output::list_item(&format!("Found {} SQL statements", statements.len()));
272
273        output::step(2, 3, "Executing SQL...");
274
275        // Execute SQL based on provider
276        let records = self.execute_sql(&sql_content).await?;
277
278        output::step(3, 3, "Verifying seed data...");
279
280        Ok(SeedResult {
281            file_type: self.file_type,
282            records_affected: records,
283            tables_seeded: Vec::new(),
284            duration: std::time::Duration::from_secs(0),
285        })
286    }
287
288    /// Run a JSON seed file (declarative)
289    async fn run_json_seed(&self) -> CliResult<SeedResult> {
290        output::step(1, 4, "Reading JSON seed file...");
291
292        let json_content = std::fs::read_to_string(&self.seed_path)?;
293        let seed_data: SeedData =
294            serde_json::from_str(&json_content).map_err(|e| CliError::Config(e.to_string()))?;
295
296        output::step(2, 4, "Validating seed data...");
297        output::list_item(&format!("Found {} tables to seed", seed_data.tables.len()));
298
299        output::step(3, 4, "Inserting seed data...");
300
301        let mut total_records = 0u64;
302        let mut tables_seeded = Vec::new();
303
304        for (table_name, records) in &seed_data.tables {
305            let sql = self.generate_insert_sql(table_name, records)?;
306            let count = self.execute_sql(&sql).await?;
307            output::list_item(&format!("  {} - {} records", table_name, records.len()));
308            total_records += count;
309            tables_seeded.push(table_name.clone());
310        }
311
312        output::step(4, 4, "Verifying seed data...");
313
314        Ok(SeedResult {
315            file_type: self.file_type,
316            records_affected: total_records,
317            tables_seeded,
318            duration: std::time::Duration::from_secs(0),
319        })
320    }
321
322    /// Run a TOML seed file (declarative)
323    async fn run_toml_seed(&self) -> CliResult<SeedResult> {
324        output::step(1, 4, "Reading TOML seed file...");
325
326        let toml_content = std::fs::read_to_string(&self.seed_path)?;
327        let seed_data: SeedData =
328            toml::from_str(&toml_content).map_err(|e| CliError::Config(e.to_string()))?;
329
330        output::step(2, 4, "Validating seed data...");
331        output::list_item(&format!("Found {} tables to seed", seed_data.tables.len()));
332
333        output::step(3, 4, "Inserting seed data...");
334
335        let mut total_records = 0u64;
336        let mut tables_seeded = Vec::new();
337
338        for (table_name, records) in &seed_data.tables {
339            let sql = self.generate_insert_sql(table_name, records)?;
340            let count = self.execute_sql(&sql).await?;
341            output::list_item(&format!("  {} - {} records", table_name, records.len()));
342            total_records += count;
343            tables_seeded.push(table_name.clone());
344        }
345
346        output::step(4, 4, "Verifying seed data...");
347
348        Ok(SeedResult {
349            file_type: self.file_type,
350            records_affected: total_records,
351            tables_seeded,
352            duration: std::time::Duration::from_secs(0),
353        })
354    }
355
356    /// Check if there's a bin target in Cargo.toml
357    fn check_bin_target(&self, name: &str) -> CliResult<bool> {
358        let cargo_toml = self.cwd.join("Cargo.toml");
359        let content = std::fs::read_to_string(&cargo_toml)?;
360
361        // Simple check - look for [[bin]] with our name
362        Ok(content.contains(&format!("name = \"{}\"", name))
363            || content.contains(&format!("name = '{}'", name)))
364    }
365
366    /// Generate INSERT SQL from seed records
367    fn generate_insert_sql(
368        &self,
369        table: &str,
370        records: &[HashMap<String, serde_json::Value>],
371    ) -> CliResult<String> {
372        if records.is_empty() {
373            return Ok(String::new());
374        }
375
376        let mut sql = String::new();
377
378        // Get columns from first record
379        let columns: Vec<&String> = records[0].keys().collect();
380        let column_names = columns
381            .iter()
382            .map(|c| format!("\"{}\"", c))
383            .collect::<Vec<_>>()
384            .join(", ");
385
386        for record in records {
387            let values = columns
388                .iter()
389                .map(|col| {
390                    record
391                        .get(*col)
392                        .map(|v| self.value_to_sql(v))
393                        .unwrap_or_else(|| "NULL".to_string())
394                })
395                .collect::<Vec<_>>()
396                .join(", ");
397
398            sql.push_str(&format!(
399                "INSERT INTO \"{}\" ({}) VALUES ({});\n",
400                table, column_names, values
401            ));
402        }
403
404        Ok(sql)
405    }
406
407    /// Convert JSON value to SQL literal
408    fn value_to_sql(&self, value: &serde_json::Value) -> String {
409        match value {
410            serde_json::Value::Null => "NULL".to_string(),
411            serde_json::Value::Bool(b) => {
412                if *b {
413                    "TRUE".to_string()
414                } else {
415                    "FALSE".to_string()
416                }
417            }
418            serde_json::Value::Number(n) => n.to_string(),
419            serde_json::Value::String(s) => {
420                // Check for special functions
421                match s.as_str() {
422                    "now()" | "NOW()" => match self.provider.as_str() {
423                        "postgresql" => "CURRENT_TIMESTAMP".to_string(),
424                        "mysql" => "NOW()".to_string(),
425                        "sqlite" => "datetime('now')".to_string(),
426                        _ => "CURRENT_TIMESTAMP".to_string(),
427                    },
428                    "uuid()" | "UUID()" => match self.provider.as_str() {
429                        "postgresql" => "gen_random_uuid()".to_string(),
430                        "mysql" => "UUID()".to_string(),
431                        "sqlite" => format!("'{}'", uuid::Uuid::new_v4()),
432                        _ => "gen_random_uuid()".to_string(),
433                    },
434                    _ => format!("'{}'", s.replace('\'', "''")),
435                }
436            }
437            serde_json::Value::Array(arr) => {
438                // PostgreSQL array literal
439                let items = arr
440                    .iter()
441                    .map(|v| self.value_to_sql(v))
442                    .collect::<Vec<_>>()
443                    .join(", ");
444                format!("ARRAY[{}]", items)
445            }
446            serde_json::Value::Object(_) => {
447                // JSON/JSONB
448                format!("'{}'", value)
449            }
450        }
451    }
452
453    /// Execute SQL against the database
454    async fn execute_sql(&self, sql: &str) -> CliResult<u64> {
455        // Use command-line tools based on provider
456        match self.provider.as_str() {
457            "postgresql" | "postgres" => self.execute_postgres_sql(sql).await,
458            "mysql" => self.execute_mysql_sql(sql).await,
459            "sqlite" => self.execute_sqlite_sql(sql).await,
460            _ => Err(CliError::Database(format!(
461                "Unsupported database provider: {}",
462                self.provider
463            ))),
464        }
465    }
466
467    /// Execute SQL using psql
468    async fn execute_postgres_sql(&self, sql: &str) -> CliResult<u64> {
469        // First try using psql
470        let psql_result = Command::new("psql")
471            .args(["-d", &self.database_url, "-c", sql])
472            .output();
473
474        match psql_result {
475            Ok(output) if output.status.success() => {
476                // Try to parse affected rows from output
477                let stdout = String::from_utf8_lossy(&output.stdout);
478                Ok(parse_affected_rows(&stdout).unwrap_or(0))
479            }
480            Ok(output) => {
481                let stderr = String::from_utf8_lossy(&output.stderr);
482                // If psql not found, suggest alternative
483                if stderr.contains("not found") || stderr.contains("No such file") {
484                    Err(CliError::Command(
485                        "psql not found. Install PostgreSQL client tools or use a Rust seed script.".to_string()
486                    ))
487                } else {
488                    Err(CliError::Database(format!("SQL execution failed: {}", stderr)))
489                }
490            }
491            Err(e) => {
492                // psql not found - try using sqlx-cli if available
493                let sqlx_result = Command::new("sqlx")
494                    .args(["database", "seed"])
495                    .env("DATABASE_URL", &self.database_url)
496                    .stdin(std::process::Stdio::piped())
497                    .output();
498
499                match sqlx_result {
500                    Ok(output) if output.status.success() => Ok(0),
501                    _ => Err(CliError::Command(format!(
502                        "Failed to execute SQL. Install psql or use a Rust seed script: {}",
503                        e
504                    ))),
505                }
506            }
507        }
508    }
509
510    /// Execute SQL using mysql client
511    async fn execute_mysql_sql(&self, sql: &str) -> CliResult<u64> {
512        // Parse MySQL URL to extract components
513        let url = url::Url::parse(&self.database_url)
514            .map_err(|e| CliError::Config(format!("Invalid MySQL URL: {}", e)))?;
515
516        let host = url.host_str().unwrap_or("localhost");
517        let port = url.port().unwrap_or(3306);
518        let user = url.username();
519        let password = url.password().unwrap_or("");
520        let database = url.path().trim_start_matches('/');
521
522        let mut cmd = Command::new("mysql");
523        cmd.args(["-h", host, "-P", &port.to_string(), "-u", user]);
524
525        if !password.is_empty() {
526            cmd.arg(format!("-p{}", password));
527        }
528
529        cmd.args(["-D", database, "-e", sql]);
530
531        let output = cmd.output()?;
532
533        if output.status.success() {
534            let stdout = String::from_utf8_lossy(&output.stdout);
535            Ok(parse_affected_rows(&stdout).unwrap_or(0))
536        } else {
537            let stderr = String::from_utf8_lossy(&output.stderr);
538            if stderr.contains("not found") || stderr.contains("No such file") {
539                Err(CliError::Command(
540                    "mysql client not found. Install MySQL client tools or use a Rust seed script."
541                        .to_string(),
542                ))
543            } else {
544                Err(CliError::Database(format!("SQL execution failed: {}", stderr)))
545            }
546        }
547    }
548
549    /// Execute SQL using sqlite3
550    async fn execute_sqlite_sql(&self, sql: &str) -> CliResult<u64> {
551        // Extract database path from URL
552        let db_path = self
553            .database_url
554            .strip_prefix("sqlite://")
555            .or_else(|| self.database_url.strip_prefix("sqlite:"))
556            .unwrap_or(&self.database_url);
557
558        let output = Command::new("sqlite3")
559            .args([db_path, sql])
560            .output()?;
561
562        if output.status.success() {
563            let stdout = String::from_utf8_lossy(&output.stdout);
564            Ok(parse_affected_rows(&stdout).unwrap_or(0))
565        } else {
566            let stderr = String::from_utf8_lossy(&output.stderr);
567            if stderr.contains("not found") || stderr.contains("No such file") {
568                Err(CliError::Command(
569                    "sqlite3 not found. Install SQLite tools or use a Rust seed script."
570                        .to_string(),
571                ))
572            } else {
573                Err(CliError::Database(format!("SQL execution failed: {}", stderr)))
574            }
575        }
576    }
577}
578
/// Summary of a completed seed run.
#[derive(Debug)]
pub struct SeedResult {
    /// Type of seed file that was executed.
    pub file_type: SeedFileType,
    /// Number of records affected, as far as it could be read from the seed
    /// program's or database client's output (0 when nothing was reported).
    pub records_affected: u64,
    /// Tables that were seeded. Only populated for declarative (JSON/TOML)
    /// seeds; Rust and SQL seeds leave it empty.
    pub tables_seeded: Vec<String>,
    /// Execution duration.
    pub duration: std::time::Duration,
}
591
/// Declarative seed data, as read from a `.json` or `.toml` seed file.
///
/// Every field may be omitted from the file (`#[serde(default)]`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SeedData {
    /// Rows to insert, keyed by table name. Each row maps a column name to its value.
    #[serde(default)]
    pub tables: HashMap<String, Vec<HashMap<String, serde_json::Value>>>,

    /// Preferred order in which to seed tables (optional).
    #[serde(default)]
    pub order: Vec<String>,

    /// Truncate tables before seeding.
    ///
    /// NOTE(review): `SeedRunner` never reads this field — confirm whether
    /// another caller honours it.
    #[serde(default)]
    pub truncate: bool,

    /// Disable foreign key checks during seeding.
    ///
    /// NOTE(review): `SeedRunner` never reads this field — confirm whether
    /// another caller honours it.
    #[serde(default)]
    pub disable_fk_checks: bool,
}
611
612// =============================================================================
613// Helper Functions
614// =============================================================================
615
616/// Find seed file in common locations
617pub fn find_seed_file(cwd: &Path, config: &Config) -> Option<PathBuf> {
618    // Check config first
619    if let Some(ref seed_path) = config.database.seed_path {
620        if seed_path.exists() {
621            return Some(seed_path.clone());
622        }
623    }
624
625    // Common locations
626    let candidates = [
627        cwd.join("seed.rs"),
628        cwd.join("seed.sql"),
629        cwd.join("seed.json"),
630        cwd.join("seed.toml"),
631        cwd.join("prax/seed.rs"),
632        cwd.join("prax/seed.sql"),
633        cwd.join("prax/seed.json"),
634        cwd.join("prax/seed.toml"),
635        cwd.join("prisma/seed.rs"),
636        cwd.join("prisma/seed.ts"), // Note: .ts not supported yet
637        cwd.join("src/seed.rs"),
638        cwd.join("seeds/seed.rs"),
639        cwd.join("seeds/seed.sql"),
640    ];
641
642    candidates.into_iter().find(|p| p.exists())
643}
644
645/// Get database URL from config or environment
646pub fn get_database_url(config: &Config) -> CliResult<String> {
647    // Try config first
648    if let Some(ref url) = config.database.url {
649        // Expand environment variables
650        let expanded = expand_env_var(url);
651        if !expanded.is_empty() && !expanded.contains("${") {
652            return Ok(expanded);
653        }
654    }
655
656    // Try environment variable
657    std::env::var("DATABASE_URL").map_err(|_| {
658        CliError::Config(
659            "Database URL not found. Set DATABASE_URL environment variable or configure in prax.toml"
660                .to_string(),
661        )
662    })
663}
664
/// Expand `${VAR}` and `$VAR` environment-variable references in `s`.
///
/// `${VAR}` accepts any non-empty name up to the next `}`. The bare `$VAR`
/// form only recognises names that start with `A-Z` or `_` and continue with
/// `A-Z`, `0-9` or `_`. References to unset (or non-UTF-8) variables, and any
/// `$` that does not start a reference, are kept verbatim.
///
/// Expansion is a single left-to-right pass. Substituted values are never
/// re-scanned, and each reference is replaced only where it appears, so
/// expanding `$HOME` cannot corrupt a neighbouring `$HOMEDIR`.
fn expand_env_var(s: &str) -> String {
    let mut expanded = String::with_capacity(s.len());
    let mut rest = s;

    while let Some(dollar) = rest.find('$') {
        expanded.push_str(&rest[..dollar]);
        let after = &rest[dollar + 1..];

        // `(length of the reference text after '$', variable name)`, if `$` starts one.
        let reference = match after.strip_prefix('{') {
            Some(braced) => braced
                .find('}')
                .filter(|&end| end > 0)
                .map(|end| (end + 2, &braced[..end])),
            None => {
                let is_name_byte =
                    |b: u8| b.is_ascii_uppercase() || b.is_ascii_digit() || b == b'_';
                let starts_name = after
                    .bytes()
                    .next()
                    .is_some_and(|b| b.is_ascii_uppercase() || b == b'_');
                starts_name.then(|| {
                    // Name bytes are ASCII, so `len` is a valid char boundary.
                    let len = after.bytes().take_while(|&b| is_name_byte(b)).count();
                    (len, &after[..len])
                })
            }
        };

        match reference {
            Some((len, name)) => {
                match std::env::var(name) {
                    Ok(value) => expanded.push_str(&value),
                    Err(_) => expanded.push_str(&rest[dollar..=dollar + len]),
                }
                rest = &after[len..];
            }
            None => {
                expanded.push('$');
                rest = after;
            }
        }
    }

    expanded.push_str(rest);
    expanded
}
689
690/// Parse seed output for record counts
691fn parse_seed_output(line: &str) -> Option<u64> {
692    // Common patterns:
693    // "Created 10 users"
694    // "Seeded 100 records"
695    // "Inserted: 50"
696    let patterns = [
697        r"(?i)created\s+(\d+)",
698        r"(?i)seeded\s+(\d+)",
699        r"(?i)inserted[:\s]+(\d+)",
700        r"(?i)(\d+)\s+records?",
701        r"(?i)(\d+)\s+rows?",
702    ];
703
704    for pattern in patterns {
705        if let Ok(re) = regex_lite::Regex::new(pattern) {
706            if let Some(caps) = re.captures(line) {
707                if let Some(m) = caps.get(1) {
708                    if let Ok(n) = m.as_str().parse() {
709                        return Some(n);
710                    }
711                }
712            }
713        }
714    }
715
716    None
717}
718
719/// Parse affected rows from database output
720fn parse_affected_rows(output: &str) -> Option<u64> {
721    // PostgreSQL: "INSERT 0 5" or "UPDATE 3"
722    // MySQL: "Query OK, 5 rows affected"
723    // SQLite: no standard format
724
725    let patterns = [
726        r"INSERT\s+\d+\s+(\d+)",
727        r"UPDATE\s+(\d+)",
728        r"DELETE\s+(\d+)",
729        r"(\d+)\s+rows?\s+affected",
730    ];
731
732    let mut total = 0u64;
733
734    for pattern in patterns {
735        if let Ok(re) = regex_lite::Regex::new(pattern) {
736            for caps in re.captures_iter(output) {
737                if let Some(m) = caps.get(1) {
738                    if let Ok(n) = m.as_str().parse::<u64>() {
739                        total += n;
740                    }
741                }
742            }
743        }
744    }
745
746    if total > 0 {
747        Some(total)
748    } else {
749        None
750    }
751}
752
753/// Create a Cargo.toml for standalone seed script
754fn create_seed_cargo_toml(project_root: &Path) -> CliResult<String> {
755    // Try to read the workspace Cargo.toml to get prax version
756    let workspace_cargo = project_root.join("Cargo.toml");
757    let prax_version = if workspace_cargo.exists() {
758        let content = std::fs::read_to_string(&workspace_cargo)?;
759        // Try to extract prax version from dependencies
760        extract_prax_version(&content).unwrap_or_else(|| "0.2".to_string())
761    } else {
762        "0.2".to_string()
763    };
764
765    Ok(format!(
766        r#"[package]
767name = "seed"
768version = "0.1.0"
769edition = "2024"
770
771[dependencies]
772prax = "{}"
773tokio = {{ version = "1", features = ["full"] }}
774"#,
775        prax_version
776    ))
777}
778
779/// Extract prax version from Cargo.toml
780fn extract_prax_version(content: &str) -> Option<String> {
781    // Look for prax = "x.y.z" or prax = { version = "x.y.z" }
782    let simple_re = regex_lite::Regex::new(r#"prax\s*=\s*"([^"]+)""#).ok()?;
783    if let Some(caps) = simple_re.captures(content) {
784        return Some(caps.get(1)?.as_str().to_string());
785    }
786
787    let complex_re = regex_lite::Regex::new(r#"prax\s*=\s*\{[^}]*version\s*=\s*"([^"]+)""#).ok()?;
788    if let Some(caps) = complex_re.captures(content) {
789        return Some(caps.get(1)?.as_str().to_string());
790    }
791
792    None
793}
794
795// =============================================================================
796// Tests
797// =============================================================================
798
// Unit tests for the pure helper functions in this module; nothing here
// talks to a database or invokes cargo.
#[cfg(test)]
mod tests {
    use super::*;

    // Extension-based detection for each supported type, plus one unsupported extension.
    #[test]
    fn test_seed_file_type_detection() {
        assert_eq!(
            SeedFileType::from_path(Path::new("seed.rs")),
            Some(SeedFileType::Rust)
        );
        assert_eq!(
            SeedFileType::from_path(Path::new("seed.sql")),
            Some(SeedFileType::Sql)
        );
        assert_eq!(
            SeedFileType::from_path(Path::new("data.json")),
            Some(SeedFileType::Json)
        );
        assert_eq!(
            SeedFileType::from_path(Path::new("data.toml")),
            Some(SeedFileType::Toml)
        );
        assert_eq!(SeedFileType::from_path(Path::new("seed.txt")), None);
    }

    // Record counts scraped from the stdout of a Rust seed program.
    #[test]
    fn test_parse_seed_output() {
        assert_eq!(parse_seed_output("Created 10 users"), Some(10));
        assert_eq!(parse_seed_output("Seeded 100 records"), Some(100));
        assert_eq!(parse_seed_output("Inserted: 50"), Some(50));
        assert_eq!(parse_seed_output("5 rows affected"), Some(5));
        assert_eq!(parse_seed_output("no numbers here"), None);
    }

    // Affected-row counts from PostgreSQL command tags and MySQL "Query OK" lines.
    #[test]
    fn test_parse_affected_rows() {
        assert_eq!(parse_affected_rows("INSERT 0 5"), Some(5));
        assert_eq!(parse_affected_rows("UPDATE 3"), Some(3));
        assert_eq!(
            parse_affected_rows("Query OK, 10 rows affected"),
            Some(10)
        );
    }

    // `${VAR}` and `$VAR` expansion, on their own and embedded in a URL.
    #[test]
    fn test_expand_env_var() {
        // SAFETY: `set_var` is only sound while no other thread reads or writes
        // the process environment. The test harness runs tests on several
        // threads by default, so this relies on no concurrently running test
        // touching the environment. NOTE(review): nothing enforces that;
        // consider serialising env-mutating tests.
        unsafe {
            std::env::set_var("TEST_VAR", "test_value");
        }
        assert_eq!(expand_env_var("${TEST_VAR}"), "test_value");
        assert_eq!(expand_env_var("$TEST_VAR"), "test_value");
        assert_eq!(
            expand_env_var("postgres://${TEST_VAR}@localhost"),
            "postgres://test_value@localhost"
        );
        // SAFETY: same requirement as the `set_var` call above.
        unsafe {
            std::env::remove_var("TEST_VAR");
        }
    }
}
861