prax_cli/commands/
seed.rs

1//! Database seeding implementation.
2//!
3//! Supports multiple seed file types:
4//! - `.rs` - Rust seed scripts (compiled and executed)
5//! - `.sql` - Raw SQL files (executed directly)
6//! - `.json` - JSON data files (declarative seeding)
7//! - `.toml` - TOML data files (declarative seeding)
8
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::process::Command;
12
13use serde::{Deserialize, Serialize};
14
15use crate::config::Config;
16use crate::error::{CliError, CliResult};
17use crate::output;
18
/// Format of a seed file, derived from its file extension.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SeedFileType {
    /// Rust seed script (`.rs`), compiled and executed.
    Rust,
    /// Raw SQL (`.sql`), executed directly by the database client.
    Sql,
    /// Declarative JSON seed data (`.json`).
    Json,
    /// Declarative TOML seed data (`.toml`).
    Toml,
}

impl SeedFileType {
    /// Detect the seed file type from `path`'s extension.
    ///
    /// Matching is ASCII case-insensitive, so `seed.SQL` is treated like
    /// `seed.sql`. Returns `None` when the path has no extension, when the
    /// extension is not valid UTF-8, or when it is not one of `rs`, `sql`,
    /// `json` or `toml`.
    pub fn from_path(path: &Path) -> Option<Self> {
        let ext = path.extension()?.to_str()?.to_ascii_lowercase();
        match ext.as_str() {
            "rs" => Some(Self::Rust),
            "sql" => Some(Self::Sql),
            "json" => Some(Self::Json),
            "toml" => Some(Self::Toml),
            _ => None,
        }
    }
}
44
45/// Seed runner configuration
46#[derive(Debug, Clone)]
47pub struct SeedRunner {
48    /// Path to the seed file
49    pub seed_path: PathBuf,
50    /// Seed file type
51    pub file_type: SeedFileType,
52    /// Database URL for execution
53    pub database_url: String,
54    /// Database provider (postgresql, mysql, sqlite)
55    pub provider: String,
56    /// Current working directory
57    pub cwd: PathBuf,
58    /// Environment name (development, staging, production)
59    pub environment: String,
60    /// Whether to reset database before seeding
61    pub reset_before_seed: bool,
62}
63
64impl SeedRunner {
65    /// Create a new seed runner
66    pub fn new(
67        seed_path: PathBuf,
68        database_url: String,
69        provider: String,
70        cwd: PathBuf,
71    ) -> CliResult<Self> {
72        let file_type = SeedFileType::from_path(&seed_path).ok_or_else(|| {
73            CliError::Config(format!(
74                "Unsupported seed file type: {}. Supported: .rs, .sql, .json, .toml",
75                seed_path.display()
76            ))
77        })?;
78
79        Ok(Self {
80            seed_path,
81            file_type,
82            database_url,
83            provider,
84            cwd,
85            environment: std::env::var("PRAX_ENV").unwrap_or_else(|_| "development".to_string()),
86            reset_before_seed: false,
87        })
88    }
89
90    /// Set environment
91    pub fn with_environment(mut self, env: impl Into<String>) -> Self {
92        self.environment = env.into();
93        self
94    }
95
96    /// Set reset before seed
97    pub fn with_reset(mut self, reset: bool) -> Self {
98        self.reset_before_seed = reset;
99        self
100    }
101
102    /// Run the seed
103    pub async fn run(&self) -> CliResult<SeedResult> {
104        match self.file_type {
105            SeedFileType::Rust => self.run_rust_seed().await,
106            SeedFileType::Sql => self.run_sql_seed().await,
107            SeedFileType::Json => self.run_json_seed().await,
108            SeedFileType::Toml => self.run_toml_seed().await,
109        }
110    }
111
112    /// Run a Rust seed script
113    async fn run_rust_seed(&self) -> CliResult<SeedResult> {
114        output::step(1, 4, "Compiling seed script...");
115
116        // Check if we're in a Cargo workspace
117        let cargo_toml = self.cwd.join("Cargo.toml");
118        if !cargo_toml.exists() {
119            return Err(CliError::Config(
120                "No Cargo.toml found. Rust seed scripts require a Rust project.".to_string(),
121            ));
122        }
123
124        // Create a temporary bin target or use cargo run
125        let seed_name = self
126            .seed_path
127            .file_stem()
128            .and_then(|s| s.to_str())
129            .unwrap_or("seed");
130
131        // Check if there's a [[bin]] entry for the seed, or we need to compile manually
132        let has_bin_target = self.check_bin_target(seed_name)?;
133
134        let mut records_affected = 0u64;
135
136        if has_bin_target {
137            // Use cargo run directly
138            output::step(2, 4, &format!("Building seed binary '{}'...", seed_name));
139
140            let build_status = Command::new("cargo")
141                .args(["build", "--bin", seed_name, "--release"])
142                .current_dir(&self.cwd)
143                .env("DATABASE_URL", &self.database_url)
144                .env("PRAX_ENV", &self.environment)
145                .status()?;
146
147            if !build_status.success() {
148                return Err(CliError::Command("Failed to build seed binary".to_string()));
149            }
150
151            output::step(3, 4, "Running seed...");
152
153            let run_output = Command::new("cargo")
154                .args(["run", "--bin", seed_name, "--release"])
155                .current_dir(&self.cwd)
156                .env("DATABASE_URL", &self.database_url)
157                .env("PRAX_ENV", &self.environment)
158                .output()?;
159
160            if !run_output.status.success() {
161                let stderr = String::from_utf8_lossy(&run_output.stderr);
162                return Err(CliError::Command(format!("Seed failed: {}", stderr)));
163            }
164
165            // Parse output for record count if available
166            let stdout = String::from_utf8_lossy(&run_output.stdout);
167            for line in stdout.lines() {
168                output::list_item(line);
169                // Try to parse seed output for counts
170                if let Some(count) = parse_seed_output(line) {
171                    records_affected += count;
172                }
173            }
174
175            output::step(4, 4, "Verifying seed data...");
176        } else {
177            // Compile and run as a standalone script using rustc
178            output::step(2, 4, "Compiling standalone seed script...");
179
180            // Create temp directory for compiled seed
181            let temp_dir = std::env::temp_dir().join("prax_seed");
182            std::fs::create_dir_all(&temp_dir)?;
183
184            let output_binary = temp_dir.join(seed_name);
185
186            // Try to compile with cargo if it looks like a full Rust file
187            let seed_content = std::fs::read_to_string(&self.seed_path)?;
188
189            if seed_content.contains("use prax") || seed_content.contains("#[tokio::main]") {
190                // This is a standalone Rust file - we'll create a temporary Cargo project
191                output::list_item("Creating temporary build environment...");
192
193                let temp_project = temp_dir.join("seed_project");
194                std::fs::create_dir_all(temp_project.join("src"))?;
195
196                // Copy seed file
197                std::fs::copy(&self.seed_path, temp_project.join("src/main.rs"))?;
198
199                // Create Cargo.toml for the seed
200                let seed_cargo = create_seed_cargo_toml(&self.cwd)?;
201                std::fs::write(temp_project.join("Cargo.toml"), seed_cargo)?;
202
203                // Build
204                let build_status = Command::new("cargo")
205                    .args(["build", "--release"])
206                    .current_dir(&temp_project)
207                    .env("DATABASE_URL", &self.database_url)
208                    .env("PRAX_ENV", &self.environment)
209                    .status()?;
210
211                if !build_status.success() {
212                    return Err(CliError::Command(
213                        "Failed to compile seed script".to_string(),
214                    ));
215                }
216
217                // Copy binary
218                let built_binary = temp_project.join("target/release/seed");
219                if built_binary.exists() {
220                    std::fs::copy(&built_binary, &output_binary)?;
221                }
222            } else {
223                return Err(CliError::Config(
224                    "Seed script must be a valid Rust file with a main function".to_string(),
225                ));
226            }
227
228            output::step(3, 4, "Running seed...");
229
230            let run_output = Command::new(&output_binary)
231                .current_dir(&self.cwd)
232                .env("DATABASE_URL", &self.database_url)
233                .env("PRAX_ENV", &self.environment)
234                .output()?;
235
236            if !run_output.status.success() {
237                let stderr = String::from_utf8_lossy(&run_output.stderr);
238                return Err(CliError::Command(format!("Seed failed: {}", stderr)));
239            }
240
241            let stdout = String::from_utf8_lossy(&run_output.stdout);
242            for line in stdout.lines() {
243                output::list_item(line);
244                if let Some(count) = parse_seed_output(line) {
245                    records_affected += count;
246                }
247            }
248
249            output::step(4, 4, "Verifying seed data...");
250        }
251
252        Ok(SeedResult {
253            file_type: self.file_type,
254            records_affected,
255            tables_seeded: Vec::new(),
256            duration: std::time::Duration::from_secs(0),
257        })
258    }
259
260    /// Run a SQL seed file
261    async fn run_sql_seed(&self) -> CliResult<SeedResult> {
262        output::step(1, 3, "Reading SQL seed file...");
263
264        let sql_content = std::fs::read_to_string(&self.seed_path)?;
265
266        // Count statements for progress
267        let statements: Vec<&str> = sql_content
268            .split(';')
269            .map(|s| s.trim())
270            .filter(|s| !s.is_empty() && !s.starts_with("--"))
271            .collect();
272
273        output::list_item(&format!("Found {} SQL statements", statements.len()));
274
275        output::step(2, 3, "Executing SQL...");
276
277        // Execute SQL based on provider
278        let records = self.execute_sql(&sql_content).await?;
279
280        output::step(3, 3, "Verifying seed data...");
281
282        Ok(SeedResult {
283            file_type: self.file_type,
284            records_affected: records,
285            tables_seeded: Vec::new(),
286            duration: std::time::Duration::from_secs(0),
287        })
288    }
289
290    /// Run a JSON seed file (declarative)
291    async fn run_json_seed(&self) -> CliResult<SeedResult> {
292        output::step(1, 4, "Reading JSON seed file...");
293
294        let json_content = std::fs::read_to_string(&self.seed_path)?;
295        let seed_data: SeedData =
296            serde_json::from_str(&json_content).map_err(|e| CliError::Config(e.to_string()))?;
297
298        output::step(2, 4, "Validating seed data...");
299        output::list_item(&format!("Found {} tables to seed", seed_data.tables.len()));
300
301        output::step(3, 4, "Inserting seed data...");
302
303        let mut total_records = 0u64;
304        let mut tables_seeded = Vec::new();
305
306        for (table_name, records) in &seed_data.tables {
307            let sql = self.generate_insert_sql(table_name, records)?;
308            let count = self.execute_sql(&sql).await?;
309            output::list_item(&format!("  {} - {} records", table_name, records.len()));
310            total_records += count;
311            tables_seeded.push(table_name.clone());
312        }
313
314        output::step(4, 4, "Verifying seed data...");
315
316        Ok(SeedResult {
317            file_type: self.file_type,
318            records_affected: total_records,
319            tables_seeded,
320            duration: std::time::Duration::from_secs(0),
321        })
322    }
323
324    /// Run a TOML seed file (declarative)
325    async fn run_toml_seed(&self) -> CliResult<SeedResult> {
326        output::step(1, 4, "Reading TOML seed file...");
327
328        let toml_content = std::fs::read_to_string(&self.seed_path)?;
329        let seed_data: SeedData =
330            toml::from_str(&toml_content).map_err(|e| CliError::Config(e.to_string()))?;
331
332        output::step(2, 4, "Validating seed data...");
333        output::list_item(&format!("Found {} tables to seed", seed_data.tables.len()));
334
335        output::step(3, 4, "Inserting seed data...");
336
337        let mut total_records = 0u64;
338        let mut tables_seeded = Vec::new();
339
340        for (table_name, records) in &seed_data.tables {
341            let sql = self.generate_insert_sql(table_name, records)?;
342            let count = self.execute_sql(&sql).await?;
343            output::list_item(&format!("  {} - {} records", table_name, records.len()));
344            total_records += count;
345            tables_seeded.push(table_name.clone());
346        }
347
348        output::step(4, 4, "Verifying seed data...");
349
350        Ok(SeedResult {
351            file_type: self.file_type,
352            records_affected: total_records,
353            tables_seeded,
354            duration: std::time::Duration::from_secs(0),
355        })
356    }
357
358    /// Check if there's a bin target in Cargo.toml
359    fn check_bin_target(&self, name: &str) -> CliResult<bool> {
360        let cargo_toml = self.cwd.join("Cargo.toml");
361        let content = std::fs::read_to_string(&cargo_toml)?;
362
363        // Simple check - look for [[bin]] with our name
364        Ok(content.contains(&format!("name = \"{}\"", name))
365            || content.contains(&format!("name = '{}'", name)))
366    }
367
368    /// Generate INSERT SQL from seed records
369    fn generate_insert_sql(
370        &self,
371        table: &str,
372        records: &[HashMap<String, serde_json::Value>],
373    ) -> CliResult<String> {
374        if records.is_empty() {
375            return Ok(String::new());
376        }
377
378        let mut sql = String::new();
379
380        // Get columns from first record
381        let columns: Vec<&String> = records[0].keys().collect();
382        let column_names = columns
383            .iter()
384            .map(|c| format!("\"{}\"", c))
385            .collect::<Vec<_>>()
386            .join(", ");
387
388        for record in records {
389            let values = columns
390                .iter()
391                .map(|col| {
392                    record
393                        .get(*col)
394                        .map(|v| self.value_to_sql(v))
395                        .unwrap_or_else(|| "NULL".to_string())
396                })
397                .collect::<Vec<_>>()
398                .join(", ");
399
400            sql.push_str(&format!(
401                "INSERT INTO \"{}\" ({}) VALUES ({});\n",
402                table, column_names, values
403            ));
404        }
405
406        Ok(sql)
407    }
408
409    /// Convert JSON value to SQL literal
410    fn value_to_sql(&self, value: &serde_json::Value) -> String {
411        match value {
412            serde_json::Value::Null => "NULL".to_string(),
413            serde_json::Value::Bool(b) => {
414                if *b {
415                    "TRUE".to_string()
416                } else {
417                    "FALSE".to_string()
418                }
419            }
420            serde_json::Value::Number(n) => n.to_string(),
421            serde_json::Value::String(s) => {
422                // Check for special functions
423                match s.as_str() {
424                    "now()" | "NOW()" => match self.provider.as_str() {
425                        "postgresql" => "CURRENT_TIMESTAMP".to_string(),
426                        "mysql" => "NOW()".to_string(),
427                        "sqlite" => "datetime('now')".to_string(),
428                        _ => "CURRENT_TIMESTAMP".to_string(),
429                    },
430                    "uuid()" | "UUID()" => match self.provider.as_str() {
431                        "postgresql" => "gen_random_uuid()".to_string(),
432                        "mysql" => "UUID()".to_string(),
433                        "sqlite" => format!("'{}'", uuid::Uuid::new_v4()),
434                        _ => "gen_random_uuid()".to_string(),
435                    },
436                    _ => format!("'{}'", s.replace('\'', "''")),
437                }
438            }
439            serde_json::Value::Array(arr) => {
440                // PostgreSQL array literal
441                let items = arr
442                    .iter()
443                    .map(|v| self.value_to_sql(v))
444                    .collect::<Vec<_>>()
445                    .join(", ");
446                format!("ARRAY[{}]", items)
447            }
448            serde_json::Value::Object(_) => {
449                // JSON/JSONB
450                format!("'{}'", value)
451            }
452        }
453    }
454
455    /// Execute SQL against the database
456    async fn execute_sql(&self, sql: &str) -> CliResult<u64> {
457        // Use command-line tools based on provider
458        match self.provider.as_str() {
459            "postgresql" | "postgres" => self.execute_postgres_sql(sql).await,
460            "mysql" => self.execute_mysql_sql(sql).await,
461            "sqlite" => self.execute_sqlite_sql(sql).await,
462            _ => Err(CliError::Database(format!(
463                "Unsupported database provider: {}",
464                self.provider
465            ))),
466        }
467    }
468
469    /// Execute SQL using psql
470    async fn execute_postgres_sql(&self, sql: &str) -> CliResult<u64> {
471        // First try using psql
472        let psql_result = Command::new("psql")
473            .args(["-d", &self.database_url, "-c", sql])
474            .output();
475
476        match psql_result {
477            Ok(output) if output.status.success() => {
478                // Try to parse affected rows from output
479                let stdout = String::from_utf8_lossy(&output.stdout);
480                Ok(parse_affected_rows(&stdout).unwrap_or(0))
481            }
482            Ok(output) => {
483                let stderr = String::from_utf8_lossy(&output.stderr);
484                // If psql not found, suggest alternative
485                if stderr.contains("not found") || stderr.contains("No such file") {
486                    Err(CliError::Command(
487                        "psql not found. Install PostgreSQL client tools or use a Rust seed script.".to_string()
488                    ))
489                } else {
490                    Err(CliError::Database(format!(
491                        "SQL execution failed: {}",
492                        stderr
493                    )))
494                }
495            }
496            Err(e) => {
497                // psql not found - try using sqlx-cli if available
498                let sqlx_result = Command::new("sqlx")
499                    .args(["database", "seed"])
500                    .env("DATABASE_URL", &self.database_url)
501                    .stdin(std::process::Stdio::piped())
502                    .output();
503
504                match sqlx_result {
505                    Ok(output) if output.status.success() => Ok(0),
506                    _ => Err(CliError::Command(format!(
507                        "Failed to execute SQL. Install psql or use a Rust seed script: {}",
508                        e
509                    ))),
510                }
511            }
512        }
513    }
514
515    /// Execute SQL using mysql client
516    async fn execute_mysql_sql(&self, sql: &str) -> CliResult<u64> {
517        // Parse MySQL URL to extract components
518        let url = url::Url::parse(&self.database_url)
519            .map_err(|e| CliError::Config(format!("Invalid MySQL URL: {}", e)))?;
520
521        let host = url.host_str().unwrap_or("localhost");
522        let port = url.port().unwrap_or(3306);
523        let user = url.username();
524        let password = url.password().unwrap_or("");
525        let database = url.path().trim_start_matches('/');
526
527        let mut cmd = Command::new("mysql");
528        cmd.args(["-h", host, "-P", &port.to_string(), "-u", user]);
529
530        if !password.is_empty() {
531            cmd.arg(format!("-p{}", password));
532        }
533
534        cmd.args(["-D", database, "-e", sql]);
535
536        let output = cmd.output()?;
537
538        if output.status.success() {
539            let stdout = String::from_utf8_lossy(&output.stdout);
540            Ok(parse_affected_rows(&stdout).unwrap_or(0))
541        } else {
542            let stderr = String::from_utf8_lossy(&output.stderr);
543            if stderr.contains("not found") || stderr.contains("No such file") {
544                Err(CliError::Command(
545                    "mysql client not found. Install MySQL client tools or use a Rust seed script."
546                        .to_string(),
547                ))
548            } else {
549                Err(CliError::Database(format!(
550                    "SQL execution failed: {}",
551                    stderr
552                )))
553            }
554        }
555    }
556
557    /// Execute SQL using sqlite3
558    async fn execute_sqlite_sql(&self, sql: &str) -> CliResult<u64> {
559        // Extract database path from URL
560        let db_path = self
561            .database_url
562            .strip_prefix("sqlite://")
563            .or_else(|| self.database_url.strip_prefix("sqlite:"))
564            .unwrap_or(&self.database_url);
565
566        let output = Command::new("sqlite3").args([db_path, sql]).output()?;
567
568        if output.status.success() {
569            let stdout = String::from_utf8_lossy(&output.stdout);
570            Ok(parse_affected_rows(&stdout).unwrap_or(0))
571        } else {
572            let stderr = String::from_utf8_lossy(&output.stderr);
573            if stderr.contains("not found") || stderr.contains("No such file") {
574                Err(CliError::Command(
575                    "sqlite3 not found. Install SQLite tools or use a Rust seed script."
576                        .to_string(),
577                ))
578            } else {
579                Err(CliError::Database(format!(
580                    "SQL execution failed: {}",
581                    stderr
582                )))
583            }
584        }
585    }
586}
587
588/// Seed execution result
589#[derive(Debug)]
590pub struct SeedResult {
591    /// Type of seed file that was executed
592    pub file_type: SeedFileType,
593    /// Number of records affected
594    pub records_affected: u64,
595    /// Tables that were seeded
596    pub tables_seeded: Vec<String>,
597    /// Execution duration
598    pub duration: std::time::Duration,
599}
600
601/// Declarative seed data structure
602#[derive(Debug, Clone, Deserialize, Serialize)]
603pub struct SeedData {
604    /// Tables to seed, keyed by table name
605    #[serde(default)]
606    pub tables: HashMap<String, Vec<HashMap<String, serde_json::Value>>>,
607
608    /// Seed order (optional - tables will be seeded in this order)
609    #[serde(default)]
610    pub order: Vec<String>,
611
612    /// Truncate tables before seeding
613    #[serde(default)]
614    pub truncate: bool,
615
616    /// Disable foreign key checks during seeding
617    #[serde(default)]
618    pub disable_fk_checks: bool,
619}
620
621// =============================================================================
622// Helper Functions
623// =============================================================================
624
625/// Find seed file in common locations
626pub fn find_seed_file(cwd: &Path, config: &Config) -> Option<PathBuf> {
627    // Check config first
628    if let Some(ref seed_path) = config.database.seed_path {
629        if seed_path.exists() {
630            return Some(seed_path.clone());
631        }
632    }
633
634    // Common locations
635    let candidates = [
636        cwd.join("seed.rs"),
637        cwd.join("seed.sql"),
638        cwd.join("seed.json"),
639        cwd.join("seed.toml"),
640        cwd.join("prax/seed.rs"),
641        cwd.join("prax/seed.sql"),
642        cwd.join("prax/seed.json"),
643        cwd.join("prax/seed.toml"),
644        cwd.join("prisma/seed.rs"),
645        cwd.join("prisma/seed.ts"), // Note: .ts not supported yet
646        cwd.join("src/seed.rs"),
647        cwd.join("seeds/seed.rs"),
648        cwd.join("seeds/seed.sql"),
649    ];
650
651    candidates.into_iter().find(|p| p.exists())
652}
653
654/// Get database URL from config or environment
655pub fn get_database_url(config: &Config) -> CliResult<String> {
656    // Try config first
657    if let Some(ref url) = config.database.url {
658        // Expand environment variables
659        let expanded = expand_env_var(url);
660        if !expanded.is_empty() && !expanded.contains("${") {
661            return Ok(expanded);
662        }
663    }
664
665    // Try environment variable
666    std::env::var("DATABASE_URL").map_err(|_| {
667        CliError::Config(
668            "Database URL not found. Set DATABASE_URL environment variable or configure in prax.toml"
669                .to_string(),
670        )
671    })
672}
673
/// Expand `${VAR}` and `$VAR` references in `s` from the process environment.
///
/// `$VAR` names follow the shell convention `[A-Z_][A-Z0-9_]*`. `${...}` accepts
/// any non-empty name up to the first closing brace. References to unset
/// variables are left untouched, so callers can still detect them (e.g. by
/// looking for a remaining `${`).
///
/// Expansion is a single left-to-right pass. Text that comes *from* a
/// variable's value (say a password containing `$`) is never expanded again,
/// and `$USER` never rewrites part of `$USERNAME`.
fn expand_env_var(s: &str) -> String {
    expand_vars_with(s, |name| std::env::var(name).ok())
}

/// [`expand_env_var`] with a caller-supplied variable lookup.
fn expand_vars_with(s: &str, lookup: impl Fn(&str) -> Option<String>) -> String {
    let mut out = String::with_capacity(s.len());
    let mut rest = s;

    while let Some(dollar) = rest.find('$') {
        out.push_str(&rest[..dollar]);
        let after = &rest[dollar + 1..];

        // `${NAME}`: any non-empty name up to the first closing brace.
        if let Some(inner) = after.strip_prefix('{') {
            if let Some(close) = inner.find('}').filter(|&close| close > 0) {
                let name = &inner[..close];
                match lookup(name) {
                    Some(value) => out.push_str(&value),
                    None => {
                        out.push_str("${");
                        out.push_str(name);
                        out.push('}');
                    }
                }
                rest = &inner[close + 1..];
                continue;
            }
        }

        // `$NAME`: an uppercase ASCII letter or `_`, then uppercase letters, digits or `_`.
        let name_len = after
            .bytes()
            .enumerate()
            .take_while(|&(i, b)| b == b'_' || b.is_ascii_uppercase() || (i > 0 && b.is_ascii_digit()))
            .count();
        if name_len > 0 {
            let name = &after[..name_len];
            match lookup(name) {
                Some(value) => out.push_str(&value),
                None => {
                    out.push('$');
                    out.push_str(name);
                }
            }
            rest = &after[name_len..];
            continue;
        }

        // A `$` that starts no reference is literal text.
        out.push('$');
        rest = after;
    }

    out.push_str(rest);
    out
}
698
699/// Parse seed output for record counts
700fn parse_seed_output(line: &str) -> Option<u64> {
701    // Common patterns:
702    // "Created 10 users"
703    // "Seeded 100 records"
704    // "Inserted: 50"
705    let patterns = [
706        r"(?i)created\s+(\d+)",
707        r"(?i)seeded\s+(\d+)",
708        r"(?i)inserted[:\s]+(\d+)",
709        r"(?i)(\d+)\s+records?",
710        r"(?i)(\d+)\s+rows?",
711    ];
712
713    for pattern in patterns {
714        if let Ok(re) = regex_lite::Regex::new(pattern) {
715            if let Some(caps) = re.captures(line) {
716                if let Some(m) = caps.get(1) {
717                    if let Ok(n) = m.as_str().parse() {
718                        return Some(n);
719                    }
720                }
721            }
722        }
723    }
724
725    None
726}
727
728/// Parse affected rows from database output
729fn parse_affected_rows(output: &str) -> Option<u64> {
730    // PostgreSQL: "INSERT 0 5" or "UPDATE 3"
731    // MySQL: "Query OK, 5 rows affected"
732    // SQLite: no standard format
733
734    let patterns = [
735        r"INSERT\s+\d+\s+(\d+)",
736        r"UPDATE\s+(\d+)",
737        r"DELETE\s+(\d+)",
738        r"(\d+)\s+rows?\s+affected",
739    ];
740
741    let mut total = 0u64;
742
743    for pattern in patterns {
744        if let Ok(re) = regex_lite::Regex::new(pattern) {
745            for caps in re.captures_iter(output) {
746                if let Some(m) = caps.get(1) {
747                    if let Ok(n) = m.as_str().parse::<u64>() {
748                        total += n;
749                    }
750                }
751            }
752        }
753    }
754
755    if total > 0 { Some(total) } else { None }
756}
757
758/// Create a Cargo.toml for standalone seed script
759fn create_seed_cargo_toml(project_root: &Path) -> CliResult<String> {
760    // Try to read the workspace Cargo.toml to get prax version
761    let workspace_cargo = project_root.join("Cargo.toml");
762    let prax_version = if workspace_cargo.exists() {
763        let content = std::fs::read_to_string(&workspace_cargo)?;
764        // Try to extract prax version from dependencies
765        extract_prax_version(&content).unwrap_or_else(|| "0.2".to_string())
766    } else {
767        "0.2".to_string()
768    };
769
770    Ok(format!(
771        r#"[package]
772name = "seed"
773version = "0.1.0"
774edition = "2024"
775
776[dependencies]
777prax-orm = "{}"
778tokio = {{ version = "1", features = ["full"] }}
779"#,
780        prax_version
781    ))
782}
783
784/// Extract prax-orm version from Cargo.toml
785fn extract_prax_version(content: &str) -> Option<String> {
786    // Look for prax-orm = "x.y.z" or prax-orm = { version = "x.y.z" }
787    let simple_re = regex_lite::Regex::new(r#"prax-orm\s*=\s*"([^"]+)""#).ok()?;
788    if let Some(caps) = simple_re.captures(content) {
789        return Some(caps.get(1)?.as_str().to_string());
790    }
791
792    let complex_re = regex_lite::Regex::new(r#"prax-orm\s*=\s*\{[^}]*version\s*=\s*"([^"]+)""#).ok()?;
793    if let Some(caps) = complex_re.captures(content) {
794        return Some(caps.get(1)?.as_str().to_string());
795    }
796
797    None
798}
799
800// =============================================================================
801// Tests
802// =============================================================================
803
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_seed_file_type_detection() {
        let cases = [
            ("seed.rs", Some(SeedFileType::Rust)),
            ("seed.sql", Some(SeedFileType::Sql)),
            ("data.json", Some(SeedFileType::Json)),
            ("data.toml", Some(SeedFileType::Toml)),
            ("seed.txt", None),
        ];
        for (file, expected) in cases {
            assert_eq!(SeedFileType::from_path(Path::new(file)), expected, "file: {}", file);
        }
    }

    #[test]
    fn test_parse_seed_output() {
        let cases = [
            ("Created 10 users", Some(10)),
            ("Seeded 100 records", Some(100)),
            ("Inserted: 50", Some(50)),
            ("5 rows affected", Some(5)),
            ("no numbers here", None),
        ];
        for (line, expected) in cases {
            assert_eq!(parse_seed_output(line), expected, "line: {}", line);
        }
    }

    #[test]
    fn test_parse_affected_rows() {
        let cases = [
            ("INSERT 0 5", 5),
            ("UPDATE 3", 3),
            ("Query OK, 10 rows affected", 10),
        ];
        for (text, expected) in cases {
            assert_eq!(parse_affected_rows(text), Some(expected), "output: {}", text);
        }
    }

    #[test]
    fn test_expand_env_var() {
        const VAR: &str = "TEST_VAR";

        // SAFETY: no other test in this module reads or writes `TEST_VAR`.
        // NOTE(review): the test harness runs tests on several threads in one
        // process; confirm no test elsewhere in the crate touches the
        // environment concurrently.
        unsafe {
            std::env::set_var(VAR, "test_value");
        }

        let cases = [
            ("${TEST_VAR}", "test_value"),
            ("$TEST_VAR", "test_value"),
            ("postgres://${TEST_VAR}@localhost", "postgres://test_value@localhost"),
        ];
        for (input, expected) in cases {
            assert_eq!(expand_env_var(input), expected, "input: {}", input);
        }

        // SAFETY: same reasoning as the `set_var` call above.
        unsafe {
            std::env::remove_var(VAR);
        }
    }
}