Skip to main content

rullst_orm/
schema.rs

1use crate::Error;
2
/// Comparison operators that [`JoinClause::on`] accepts between two column identifiers.
///
/// Any operator outside this list is reported as a validation error, so the
/// operator slot can never be used to smuggle arbitrary SQL into a JOIN.
const ALLOWED_OPERATORS: &[&str] = &[
    "=", "!=", "<>", "<", ">", "<=", ">=",
];
5
6/// Validates a SQL identifier (column or table name) to prevent SQL injection.
7/// Allows alphanumeric characters, underscores, hyphens and a single dot
8/// for qualified names like `table.column`.
9pub fn validate_identifier(name: &str) -> Result<(), Error> {
10    if name.is_empty() {
11        return Err(Error::Internal(
12            "SQL identifier cannot be empty".to_string(),
13        ));
14    }
15
16    let bytes = name.as_bytes();
17    if bytes[0] == b'.' || bytes[bytes.len() - 1] == b'.' {
18        return Err(Error::Internal(format!(
19            "Invalid SQL identifier '{}': must not start or end with a dot",
20            name
21        )));
22    }
23
24    let mut dot_count = 0;
25    for &b in bytes {
26        if b == b'.' {
27            dot_count += 1;
28            if dot_count > 1 {
29                return Err(Error::Internal(format!(
30                    "Invalid SQL identifier '{}': at most one dot is allowed",
31                    name
32                )));
33            }
34        } else if !b.is_ascii_alphanumeric() && b != b'_' && b != b'-' {
35            return Err(Error::Internal(format!(
36                "Invalid SQL identifier '{}': only alphanumeric characters, underscores, hyphens and dots are allowed",
37                name
38            )));
39        }
40    }
41
42    Ok(())
43}
44
45/// Validates a table name to prevent SQL injection.
46pub fn validate_table_name(table_name: &str) -> Result<(), Error> {
47    if table_name.contains('.') {
48        return Err(Error::Internal(format!(
49            "Invalid table name '{}': dots are not allowed in table names",
50            table_name
51        )));
52    }
53    validate_identifier(table_name)
54}
55
56/// Safe values allowed for a column DEFAULT clause.
57///
58/// Accepting a raw `&str` would allow DDL injection through the DEFAULT
59/// position. This enum restricts callers to known-safe literals.
60#[derive(Debug, Clone, PartialEq)]
61pub enum ColumnDefault {
62    /// `CURRENT_TIMESTAMP` — standard SQL timestamp literal.
63    CurrentTimestamp,
64    /// `NULL` — explicit SQL null default.
65    Null,
66    /// A non-negative integer literal (e.g. `0`, `1`).
67    Integer(i64),
68    /// A non-negative real literal (e.g. `0.0`).
69    Float(f64),
70    /// A string literal that will be single-quoted and escaped.
71    /// Only printable ASCII excluding `'` and `\` is accepted.
72    Text(String),
73}
74
75impl ColumnDefault {
76    /// Renders the default value as a safe SQL fragment.
77    pub fn to_sql(&self) -> String {
78        match self {
79            ColumnDefault::CurrentTimestamp => "CURRENT_TIMESTAMP".to_string(),
80            ColumnDefault::Null => "NULL".to_string(),
81            ColumnDefault::Integer(n) => n.to_string(),
82            ColumnDefault::Float(f) => format!("{f}"),
83            // Single-quote the string and escape any embedded single-quotes
84            // via SQL standard doubling (''), which is safe on every driver.
85            ColumnDefault::Text(s) => format!("'{}'", s.replace('\'', "''")),
86        }
87    }
88}
89
90pub struct Column {
91    pub name: String,
92    pub col_type: String,
93    pub is_nullable: bool,
94    pub is_primary_key: bool,
95    pub is_auto_increment: bool,
96    pub default_value: Option<ColumnDefault>,
97}
98
99impl Column {
100    /// Creates a new column, validating `name` against SQL identifier rules.
101    ///
102    /// # Panics
103    /// Panics if `name` fails identifier validation. Column names are always
104    /// developer-supplied compile-time literals — an invalid name is a bug,
105    /// not a runtime condition.
106    pub fn new(name: &str, col_type: &str) -> Self {
107        validate_identifier(name)
108            .unwrap_or_else(|e| panic!("Invalid column name {:?}: {}", name, e));
109        Self {
110            name: name.to_string(),
111            col_type: col_type.to_string(),
112            is_nullable: true,
113            is_primary_key: false,
114            is_auto_increment: false,
115            default_value: None,
116        }
117    }
118
119    pub fn not_null(&mut self) -> &mut Self {
120        self.is_nullable = false;
121        self
122    }
123
124    pub fn nullable(&mut self) -> &mut Self {
125        self.is_nullable = true;
126        self
127    }
128
129    /// Sets a safe DEFAULT value using the [`ColumnDefault`] enum.
130    ///
131    /// The old `&str` overload has been removed to prevent DDL injection
132    /// through unescaped DEFAULT clauses.
133    pub fn default(&mut self, val: ColumnDefault) -> &mut Self {
134        self.default_value = Some(val);
135        self
136    }
137
138    pub fn primary(&mut self) -> &mut Self {
139        self.is_primary_key = true;
140        self
141    }
142}
143
144pub struct Blueprint {
145    pub columns: Vec<Column>,
146}
147
148impl Default for Blueprint {
149    fn default() -> Self {
150        Self::new()
151    }
152}
153
154impl Blueprint {
155    pub fn new() -> Self {
156        Self { columns: vec![] }
157    }
158
159    pub fn id(&mut self) -> &mut Column {
160        self.columns.push(Column {
161            name: "id".to_string(),
162            col_type: "INTEGER".to_string(),
163            is_nullable: false,
164            is_primary_key: true,
165            is_auto_increment: true,
166            default_value: None,
167        });
168        self.columns
169            .last_mut()
170            .expect("BUG: columns is empty after push")
171    }
172
173    fn add_column(&mut self, name: &str, col_type: &str) -> &mut Column {
174        let col = Column::new(name, col_type);
175        self.columns.push(col);
176        self.columns
177            .last_mut()
178            .expect("BUG: columns is empty after push")
179    }
180
181    pub fn string(&mut self, name: &str) -> &mut Column {
182        self.add_column(name, "TEXT")
183    }
184
185    pub fn integer(&mut self, name: &str) -> &mut Column {
186        self.add_column(name, "INTEGER")
187    }
188
189    pub fn float(&mut self, name: &str) -> &mut Column {
190        self.add_column(name, "REAL")
191    }
192
193    pub fn boolean(&mut self, name: &str) -> &mut Column {
194        self.add_column(name, "INTEGER")
195    }
196
197    pub fn timestamps(&mut self) {
198        let mut created = Column::new("created_at", "TEXT");
199        created.default(ColumnDefault::CurrentTimestamp);
200        self.columns.push(created);
201
202        let mut updated = Column::new("updated_at", "TEXT");
203        updated.default(ColumnDefault::CurrentTimestamp);
204        self.columns.push(updated);
205    }
206
207    pub fn soft_deletes(&mut self) {
208        let col = Column::new("deleted_at", "TEXT");
209        self.columns.push(col);
210        self.columns
211            .last_mut()
212            .expect("BUG: columns is empty after push")
213            .nullable();
214    }
215
216    #[cfg_attr(test, mutants::skip)]
217    pub fn build(&self) -> Result<String, Error> {
218        let driver = crate::DB_DRIVER
219            .get()
220            .map(|s| s.as_str())
221            .unwrap_or("sqlite");
222        let mut defs = vec![];
223        for col in &self.columns {
224            // Defensive re-validation: column names must always be safe
225            // identifiers regardless of how the Column was constructed.
226            validate_identifier(&col.name)?;
227
228            let mut col_type_str = col.col_type.clone();
229            if driver == "postgres" && col.is_auto_increment {
230                if col.col_type == "INTEGER" || col.col_type == "INT" {
231                    col_type_str = "SERIAL".to_string();
232                } else if col.col_type == "BIGINT" {
233                    col_type_str = "BIGSERIAL".to_string();
234                }
235            }
236
237            let mut def = format!("{} {}", col.name, col_type_str);
238            if col.is_primary_key {
239                def.push_str(" PRIMARY KEY");
240            }
241            if col.is_auto_increment {
242                if driver == "sqlite" {
243                    def.push_str(" AUTOINCREMENT");
244                } else if driver == "mysql" {
245                    def.push_str(" AUTO_INCREMENT");
246                }
247            }
248            if !col.is_nullable && !col.is_primary_key {
249                def.push_str(" NOT NULL");
250            }
251            if let Some(default) = &col.default_value {
252                use std::fmt::Write;
253                write!(def, " DEFAULT {}", default.to_sql()).unwrap();
254            }
255            defs.push(def);
256        }
257        Ok(defs.join(",\n    "))
258    }
259}
260
261pub struct Schema;
262
263impl Schema {
264    pub async fn create<F>(table_name: &str, callback: F) -> Result<(), Error>
265    where
266        F: FnOnce(&mut Blueprint),
267    {
268        validate_table_name(table_name)?;
269
270        let mut blueprint = Blueprint::new();
271        callback(&mut blueprint);
272
273        // build() now returns Result so any column-name or default issues
274        // surface as errors rather than producing malformed SQL.
275        let columns_sql = blueprint.build()?;
276        let sql = format!(
277            "CREATE TABLE IF NOT EXISTS {} (\n    {}\n);",
278            table_name, columns_sql
279        );
280
281        let pool = crate::Orm::pool();
282        let mut query_builder = sqlx::query_builder::QueryBuilder::new("");
283        query_builder.push(&sql);
284        query_builder.build().execute(pool).await?;
285
286        Ok(())
287    }
288
289    pub async fn drop_if_exists(table_name: &str) -> Result<(), Error> {
290        validate_table_name(table_name)?;
291
292        let sql = format!("DROP TABLE IF EXISTS {};", table_name);
293        let pool = crate::Orm::pool();
294        let mut query_builder = sqlx::query_builder::QueryBuilder::new("");
295        query_builder.push(&sql);
296        query_builder.build().execute(pool).await?;
297        Ok(())
298    }
299}
300
301#[async_trait::async_trait]
302pub trait Migration: Send + Sync {
303    fn name(&self) -> &'static str;
304    async fn up(&self) -> Result<(), Error>;
305    async fn down(&self) -> Result<(), Error>;
306}
307
308#[cfg_attr(test, mutants::skip)]
309pub async fn run_artisan_with_args(
310    args: &[String],
311    migrations: Vec<Box<dyn Migration>>,
312    seeders: Vec<Box<dyn crate::Seeder>>,
313) -> Result<(), Error> {
314    if args.len() < 2 {
315        println!("Rullst ORM Artisan CLI");
316        println!("Usage:");
317        println!("  make:migration <name>   Generate a new migration");
318        println!("  migrate                  Run all pending migrations");
319        println!("  migrate:rollback         Rollback the last batch of migrations");
320        println!("  status                   Show migrations status");
321        println!("  db:seed                  Populate the database with seeders");
322        return Ok(());
323    }
324
325    let command = &args[1];
326    match command.as_str() {
327        "make:migration" => {
328            if args.len() < 3 {
329                println!("Error: migration name is required.");
330                return Ok(());
331            }
332            let name = &args[2];
333            create_migration_files(name)?;
334        }
335        "migrate" | "db:migrate" => {
336            run_migrations(migrations).await?;
337        }
338        "migrate:rollback" | "db:rollback" => {
339            rollback_migrations(migrations).await?;
340        }
341        "status" | "db:status" => {
342            status_migrations(migrations).await?;
343        }
344        "db:seed" => {
345            println!("Seeding database...");
346            crate::Orm::seed(seeders).await?;
347            println!("Database seeded successfully!");
348        }
349        _ => {
350            println!("Unknown command: {}", command);
351        }
352    }
353    Ok(())
354}
355
356#[cfg_attr(test, mutants::skip)]
357pub async fn run_artisan(
358    migrations: Vec<Box<dyn Migration>>,
359    seeders: Vec<Box<dyn crate::Seeder>>,
360) -> Result<(), Error> {
361    let args: Vec<String> = std::env::args().collect();
362    run_artisan_with_args(&args, migrations, seeders).await
363}
364
365#[cfg_attr(test, mutants::skip)]
366async fn status_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
367    let pool = crate::Orm::pool();
368    let driver = crate::Orm::driver();
369
370    let table_exists = match driver {
371        "postgres" | "mysql" => {
372            let query_str =
373                "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'";
374            let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
375            row.0 > 0
376        }
377        _ => {
378            let query_str =
379                "SELECT COUNT(*) FROM sqlite_schema WHERE type='table' AND name='migrations'";
380            let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
381            row.0 > 0
382        }
383    };
384
385    let executed_set = if table_exists {
386        let executed: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations")
387            .fetch_all(pool)
388            .await?;
389        executed
390            .into_iter()
391            .map(|(m,)| m)
392            .collect::<std::collections::HashSet<String>>()
393    } else {
394        std::collections::HashSet::new()
395    };
396
397    let name_header = "Migration Name";
398    let status_header = "Status";
399    println!("{name_header:<40} | {status_header}");
400    println!("{}", "-".repeat(55));
401    for m in migrations {
402        let name = m.name();
403        let status = if executed_set.contains(name) {
404            "Applied"
405        } else {
406            "Pending"
407        };
408        println!("{:<40} | {}", name, status);
409    }
410
411    Ok(())
412}
413
414#[cfg_attr(test, mutants::skip)]
415fn create_migration_files(name: &str) -> Result<(), Error> {
416    validate_table_name(name)?;
417    use std::fs;
418
419    let now = std::time::SystemTime::now()
420        .duration_since(std::time::UNIX_EPOCH)
421        .expect("System time went backwards")
422        .as_secs()
423        .to_string();
424    let snake_name = name.to_lowercase().replace("-", "_");
425    let file_name = format!("m{}_{}", now, snake_name);
426
427    fs::create_dir_all("src/migrations")
428        .map_err(|e| Error::Internal(format!("Failed to create migrations directory: {}", e)))?;
429
430    let new_file_path = format!("src/migrations/{}.rs", file_name);
431    let template = include_str!("migration_template.rs.txt");
432    let migration_code = template
433        .replace("{timestamp}", &now)
434        .replace("{name}", &snake_name);
435
436    fs::write(&new_file_path, migration_code)
437        .map_err(|e| Error::Internal(format!("Failed to write migration file: {}", e)))?;
438    println!("Created migration file: {}", new_file_path);
439
440    regenerate_migrations_mod()?;
441
442    Ok(())
443}
444
445#[cfg_attr(test, mutants::skip)]
446fn regenerate_migrations_mod() -> Result<(), Error> {
447    use std::fs;
448    let paths = fs::read_dir("src/migrations")
449        .map_err(|e| Error::Internal(format!("Failed to read migrations dir: {}", e)))?;
450
451    let mut modules = vec![];
452    for path in paths {
453        let path = path.map_err(|e| Error::Internal(e.to_string()))?.path();
454        if let Some(ext) = path.extension()
455            && ext == "rs"
456            && let Some(stem) = path.file_stem()
457        {
458            let stem_str = stem.to_string_lossy().to_string();
459            if stem_str != "mod" && stem_str.starts_with('m') {
460                modules.push(stem_str);
461            }
462        }
463    }
464    modules.sort();
465
466    use std::fmt::Write;
467    let mut mod_content = String::new();
468    mod_content.push_str("// Generated by Rullst ORM Artisan. Do not edit manually.\n\n");
469    for m in &modules {
470        writeln!(mod_content, "pub mod {};", m).unwrap();
471    }
472    mod_content
473        .push_str("\npub fn get_migrations() -> Vec<Box<dyn rullst_orm::schema::Migration>> {\n");
474    mod_content.push_str("    vec![\n");
475    for m in &modules {
476        writeln!(mod_content, "        Box::new({}::MigrationImpl),", m).unwrap();
477    }
478    mod_content.push_str("    ]\n");
479    mod_content.push_str("}\n");
480
481    fs::write("src/migrations/mod.rs", mod_content)
482        .map_err(|e| Error::Internal(format!("Failed to write mod.rs: {}", e)))?;
483    println!("Regenerated src/migrations/mod.rs");
484
485    Ok(())
486}
487
488#[cfg_attr(test, mutants::skip)]
489async fn run_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
490    let pool = crate::Orm::pool();
491    let driver = crate::Orm::driver();
492
493    let query_str = match driver {
494        "postgres" => {
495            "CREATE TABLE IF NOT EXISTS migrations (
496                id SERIAL PRIMARY KEY,
497                migration VARCHAR(255) NOT NULL,
498                batch INTEGER NOT NULL
499            )"
500        }
501        "mysql" => {
502            "CREATE TABLE IF NOT EXISTS migrations (
503                id INT AUTO_INCREMENT PRIMARY KEY,
504                migration VARCHAR(255) NOT NULL,
505                batch INT NOT NULL
506            )"
507        }
508        _ => {
509            "CREATE TABLE IF NOT EXISTS migrations (
510                id INTEGER PRIMARY KEY AUTOINCREMENT,
511                migration TEXT NOT NULL,
512                batch INTEGER NOT NULL
513            )"
514        }
515    };
516
517    sqlx::query(query_str).execute(pool).await?;
518
519    let executed: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations")
520        .fetch_all(pool)
521        .await?;
522    let executed_set: std::collections::HashSet<String> =
523        executed.into_iter().map(|(m,)| m).collect();
524
525    let batch_row: (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
526        .fetch_one(pool)
527        .await?;
528    let next_batch = batch_row.0.unwrap_or(0) + 1;
529
530    let mut count = 0;
531    let mut successful_migrations = vec![];
532    for m in migrations {
533        let name = m.name();
534        if !executed_set.contains(name) {
535            println!("Migrating: {}", name);
536            m.up().await?;
537            successful_migrations.push(name);
538            println!("Migrated:  {}", name);
539            count += 1;
540        }
541    }
542
543    if count > 0 {
544        let mut query_builder =
545            sqlx::query_builder::QueryBuilder::new("INSERT INTO migrations (migration, batch) ");
546        query_builder.push_values(successful_migrations, |mut b, name| {
547            b.push_bind(name).push_bind(next_batch);
548        });
549        query_builder.build().execute(pool).await?;
550    } else {
551        println!("Nothing to migrate.");
552    }
553
554    Ok(())
555}
556
557#[cfg_attr(test, mutants::skip)]
558async fn rollback_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
559    let pool = crate::Orm::pool();
560    let driver = crate::Orm::driver();
561
562    let table_exists = match driver {
563        "postgres" | "mysql" => {
564            let query_str =
565                "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'";
566            let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
567            row.0 > 0
568        }
569        _ => {
570            let query_str =
571                "SELECT COUNT(*) FROM sqlite_schema WHERE type='table' AND name='migrations'";
572            let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
573            row.0 > 0
574        }
575    };
576
577    if !table_exists {
578        println!("Nothing to rollback.");
579        return Ok(());
580    }
581
582    let batch_row: (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
583        .fetch_one(pool)
584        .await?;
585
586    let last_batch = match batch_row.0 {
587        Some(b) if b > 0 => b,
588        _ => {
589            println!("Nothing to rollback.");
590            return Ok(());
591        }
592    };
593
594    let to_rollback: Vec<(String,)> =
595        sqlx::query_as("SELECT migration FROM migrations WHERE batch = ? ORDER BY id DESC")
596            .bind(last_batch)
597            .fetch_all(pool)
598            .await?;
599
600    let mut rollback_map = std::collections::HashMap::with_capacity(migrations.len());
601    for m in migrations {
602        rollback_map.insert(m.name().to_string(), m);
603    }
604
605    let mut rolled_back = Vec::with_capacity(to_rollback.len());
606    for (name,) in to_rollback {
607        if let Some(m) = rollback_map.get(&name) {
608            println!("Rolling back: {}", name);
609            m.down().await?;
610            println!("Rolled back:  {}", name);
611            rolled_back.push(name);
612        } else {
613            println!(
614                "Warning: migration {} found in database but not in compiled binary.",
615                name
616            );
617        }
618    }
619
620    if !rolled_back.is_empty() {
621        let mut query_builder =
622            sqlx::query_builder::QueryBuilder::new("DELETE FROM migrations WHERE migration IN (");
623        let mut separated = query_builder.separated(", ");
624        for name in rolled_back {
625            separated.push_bind(name);
626        }
627        separated.push_unseparated(")");
628        query_builder.build().execute(pool).await?;
629    }
630
631    Ok(())
632}
633
634pub struct JoinClause {
635    pub table: String,
636    pub conditions: Vec<String>,
637    pub bindings: Vec<crate::RullstValue>,
638    pub errors: Vec<crate::Error>,
639}
640
641impl JoinClause {
642    pub fn new(table: &str) -> Self {
643        Self {
644            table: table.to_string(),
645            conditions: vec![],
646            bindings: vec![],
647            errors: vec![],
648        }
649    }
650
651    /// Adds a column-to-column JOIN condition.
652    ///
653    /// This prevents SQL injection — column names should always be hardcoded, never
654    /// derived from user input. Returns errors internally rather than panicking.
655    pub fn on(&mut self, first: &str, operator: &str, second: &str) -> &mut Self {
656        if let Err(e) = validate_identifier(first) {
657            self.errors.push(crate::Error::Validation(format!(
658                "JoinClause::on — invalid identifier for `first`: {:?}",
659                e
660            )));
661        }
662        if let Err(e) = validate_identifier(second) {
663            self.errors.push(crate::Error::Validation(format!(
664                "JoinClause::on — invalid identifier for `second`: {:?}",
665                e
666            )));
667        }
668        if !ALLOWED_OPERATORS.contains(&operator) {
669            self.errors.push(crate::Error::Validation(format!(
670                "JoinClause::on — invalid operator '{}'. Allowed: {:?}",
671                operator, ALLOWED_OPERATORS
672            )));
673        }
674        self.conditions
675            .push(format!("{} {} {}", first, operator, second));
676        self
677    }
678
679    pub fn on_eq<T: Into<crate::RullstValue>>(&mut self, column: &str, value: T) -> &mut Self {
680        if let Err(e) = validate_identifier(column) {
681            self.errors.push(crate::Error::Validation(format!(
682                "JoinClause::on_eq — invalid identifier for `column`: {:?}",
683                e
684            )));
685        }
686        self.conditions.push(format!("{} = ?", column));
687        self.bindings.push(value.into());
688        self
689    }
690
691    pub fn to_sql(&self) -> String {
692        self.conditions.join(" AND ")
693    }
694}
695
696pub trait SubqueryBuilder {
697    fn to_sql(&self) -> String;
698    fn bindings(&self) -> &Vec<crate::RullstValue>;
699}
700
/// Process-wide query-logging switch (off by default).
pub static QUERY_LOGGING: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
/// Process-wide maximum query limit (default `1000`); a stored `0` is
/// reported as "no limit" by [`get_max_query_limit`].
pub static MAX_QUERY_LIMIT: std::sync::atomic::AtomicUsize =
    std::sync::atomic::AtomicUsize::new(1000);
/// Process-wide query timeout in seconds (default `30`); a stored `0` is
/// reported as "no timeout" by [`get_query_timeout`].
pub static QUERY_TIMEOUT_SECS: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(30);

/// Turns query logging on.
pub fn enable_query_log() {
    QUERY_LOGGING.store(true, std::sync::atomic::Ordering::SeqCst);
}

/// Turns query logging off.
pub fn disable_query_log() {
    QUERY_LOGGING.store(false, std::sync::atomic::Ordering::SeqCst);
}

/// Reports whether query logging is currently on.
pub fn is_query_log_enabled() -> bool {
    QUERY_LOGGING.load(std::sync::atomic::Ordering::SeqCst)
}

/// Sets the maximum query limit; `0` means "no limit".
pub fn set_max_query_limit(limit: usize) {
    MAX_QUERY_LIMIT.store(limit, std::sync::atomic::Ordering::SeqCst);
}

/// Returns the maximum query limit, or `None` when it is set to `0`.
pub fn get_max_query_limit() -> Option<usize> {
    std::num::NonZeroUsize::new(MAX_QUERY_LIMIT.load(std::sync::atomic::Ordering::SeqCst))
        .map(std::num::NonZeroUsize::get)
}

/// Sets the query timeout in seconds; `0` disables it.
pub fn set_query_timeout(secs: u64) {
    QUERY_TIMEOUT_SECS.store(secs, std::sync::atomic::Ordering::SeqCst);
}

/// Returns the query timeout, or `None` when it is set to `0`.
pub fn get_query_timeout() -> Option<std::time::Duration> {
    match QUERY_TIMEOUT_SECS.load(std::sync::atomic::Ordering::SeqCst) {
        0 => None,
        secs => Some(std::time::Duration::from_secs(secs)),
    }
}
739
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_enable_disable_query_log() {
        disable_query_log();
        assert!(!is_query_log_enabled());
        enable_query_log();
        assert!(is_query_log_enabled());
        disable_query_log();
        assert!(!is_query_log_enabled());
    }

    #[test]
    fn test_join_clause() {
        let mut jc = JoinClause::new("users");
        jc.on("users.id", "=", "posts.user_id");
        assert_eq!(jc.to_sql(), "users.id = posts.user_id");
    }

    #[test]
    fn test_validate_table_name() {
        assert!(validate_table_name("users").is_ok());
        assert!(validate_table_name("user_posts").is_ok());
        assert!(validate_table_name("DROP TABLE users").is_err());
        assert!(validate_table_name("../../../etc/shadow").is_err());
        // dots not allowed in table names
        assert!(validate_table_name("users.id").is_err());
        assert!(validate_table_name("").is_err()); // Empty table name
    }

    #[test]
    fn test_validate_identifier() {
        assert!(validate_identifier("users").is_ok());
        assert!(validate_identifier("users.id").is_ok());
        assert!(validate_identifier("user_posts").is_ok());
        assert!(validate_identifier("").is_err());
        assert!(validate_identifier("users.posts.id").is_err()); // two dots
        assert!(validate_identifier("DROP TABLE users").is_err());
        assert!(validate_identifier("id; DROP TABLE users--").is_err());
        // Leading/trailing dot edge cases — all rejected
        assert!(validate_identifier(".").is_err()); // bare dot: starts AND ends with dot
        assert!(validate_identifier(".users").is_err()); // leading dot
        assert!(validate_identifier("users.").is_err()); // trailing dot
        assert!(validate_identifier("user name").is_err()); // Spaces not allowed
        assert!(validate_identifier("admin'--").is_err()); // Quotes not allowed
        assert!(validate_identifier("users()").is_err()); // Parentheses not allowed
        assert!(validate_identifier("a*b").is_err()); // Asterisk not allowed

        // Extensive error tests
        assert!(validate_identifier("SELECT * FROM users").is_err());
        assert!(validate_identifier("users\nWHERE").is_err());
        assert!(validate_identifier("users\t").is_err());
        assert!(validate_identifier("\\").is_err());
    }

    #[test]
    fn test_join_clause_on_invalid_operator() {
        let mut jc = JoinClause::new("posts");
        jc.on("posts.user_id", "OR 1=1 --", "users.id");
        assert!(!jc.errors.is_empty());
        assert!(jc.errors[0].to_string().contains("invalid operator"));
    }

    #[test]
    fn test_join_clause_on_invalid_column() {
        let mut jc = JoinClause::new("posts");
        jc.on("users.id; DROP TABLE users--", "=", "posts.user_id");
        assert!(!jc.errors.is_empty());
        assert!(jc.errors[0].to_string().contains("invalid identifier"));
    }

    #[test]
    fn test_timestamps_adds_columns() {
        let mut bp = Blueprint::new();
        bp.timestamps();
        assert_eq!(bp.columns.len(), 2);
        assert_eq!(bp.columns[0].name, "created_at");
        assert_eq!(bp.columns[1].name, "updated_at");
        assert_eq!(
            bp.columns[0].default_value,
            Some(ColumnDefault::CurrentTimestamp)
        );
        assert_eq!(
            bp.columns[1].default_value,
            Some(ColumnDefault::CurrentTimestamp)
        );
    }

    #[test]
    fn test_soft_deletes_adds_nullable_column() {
        let mut bp = Blueprint::new();
        bp.soft_deletes();
        assert_eq!(bp.columns.len(), 1);
        assert_eq!(bp.columns[0].name, "deleted_at");
        assert!(bp.columns[0].is_nullable);
    }

    #[test]
    fn test_blueprint_build_produces_valid_sql() {
        let mut bp = Blueprint::new();
        bp.id();
        bp.string("name").not_null();
        bp.integer("age");
        let sql = bp.build().expect("build should succeed for valid columns");
        assert!(sql.contains("id INTEGER PRIMARY KEY"));
        assert!(sql.contains("name TEXT NOT NULL"));
        assert!(sql.contains("age INTEGER"));
    }

    #[test]
    fn test_column_default_to_sql_escaping() {
        let default_text = ColumnDefault::Text("O'Reilly".to_string());
        assert_eq!(default_text.to_sql(), "'O''Reilly'");
    }

    #[test]
    fn test_validate_identifier_multiple_dots() {
        assert!(validate_identifier("table.column").is_ok()); // one dot
        assert!(validate_identifier("schema.table.column").is_err()); // multiple dots
    }

    #[test]
    fn test_column_default_sql_rendering() {
        assert_eq!(
            ColumnDefault::CurrentTimestamp.to_sql(),
            "CURRENT_TIMESTAMP"
        );
        assert_eq!(ColumnDefault::Null.to_sql(), "NULL");
        assert_eq!(ColumnDefault::Integer(42).to_sql(), "42");
        assert_eq!(ColumnDefault::Float(1.23).to_sql(), "1.23");
        assert_eq!(ColumnDefault::Text("hello".to_string()).to_sql(), "'hello'");
        // SQL injection via embedded quote must be escaped
        assert_eq!(ColumnDefault::Text("it's".to_string()).to_sql(), "'it''s'");
    }

    #[test]
    fn test_join_clause_on_eq_binds_value() {
        let mut jc = JoinClause::new("orders");
        jc.on_eq("orders.user_id", 42i32);
        assert_eq!(jc.to_sql(), "orders.user_id = ?");
        assert_eq!(jc.bindings.len(), 1);
    }

    #[test]
    fn test_join_clause_multiple_conditions() {
        let mut jc = JoinClause::new("posts");
        jc.on("posts.user_id", "=", "users.id");
        jc.on("posts.status", ">", "users.min_status");
        assert_eq!(
            jc.to_sql(),
            "posts.user_id = users.id AND posts.status > users.min_status"
        );
    }

    #[test]
    fn test_column_builder_methods() {
        let mut col = Column::new("age", "INTEGER");
        assert_eq!(col.name, "age");
        assert_eq!(col.col_type, "INTEGER");
        assert!(col.is_nullable); // default is true
        assert!(!col.is_primary_key);
        assert!(!col.is_auto_increment);
        assert_eq!(col.default_value, None);

        col.not_null();
        assert!(!col.is_nullable);

        col.nullable();
        assert!(col.is_nullable);

        col.primary();
        assert!(col.is_primary_key);

        col.default(ColumnDefault::Integer(18));
        assert_eq!(col.default_value, Some(ColumnDefault::Integer(18)));
    }

    #[tokio::test]
    async fn test_db_migration_error_state_invalid_blueprint() {
        let result = Schema::create("invalid; DROP TABLE users", |bp| {
            bp.id();
        })
        .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_drop_if_exists_invalid_table() {
        let result = Schema::drop_if_exists("invalid; name").await;
        assert!(result.is_err());
        assert!(matches!(result, Err(crate::Error::Internal(_))));
    }

    #[test]
    fn test_max_query_limit_and_timeout_globals() {
        // These settings are process-wide and shared with every other test in
        // this binary, so remember the current values and restore them.
        let previous_limit = get_max_query_limit().unwrap_or(0);
        let previous_timeout = get_query_timeout().map_or(0, |d| d.as_secs());

        // Test limit
        set_max_query_limit(50);
        assert_eq!(get_max_query_limit(), Some(50));
        set_max_query_limit(0);
        assert_eq!(get_max_query_limit(), None);

        // Test timeout
        set_query_timeout(10);
        assert_eq!(
            get_query_timeout(),
            Some(std::time::Duration::from_secs(10))
        );
        set_query_timeout(0);
        assert_eq!(get_query_timeout(), None);

        set_max_query_limit(previous_limit);
        set_query_timeout(previous_timeout);
    }

    #[tokio::test]
    async fn test_run_artisan_entrypoint() {
        // Drive the dispatcher with explicit argv rather than std::env::args().
        // A test filter such as `cargo test status` or `cargo test migrate`
        // would otherwise be parsed as an Artisan command and try to reach a
        // database that unit tests never configure.
        let help_only = vec!["rullst".to_string()];
        assert!(run_artisan_with_args(&help_only, vec![], vec![]).await.is_ok());

        let unknown = vec!["rullst".to_string(), "no-such-command".to_string()];
        assert!(run_artisan_with_args(&unknown, vec![], vec![]).await.is_ok());

        let missing_name = vec!["rullst".to_string(), "make:migration".to_string()];
        assert!(run_artisan_with_args(&missing_name, vec![], vec![]).await.is_ok());
    }
}