// rullst_orm/schema.rs

1use crate::Error;
2
/// Comparison/join operators that raw clause builders may splice into SQL.
///
/// An operator argument is only accepted when it matches one of these entries
/// exactly.
const ALLOWED_OPERATORS: &[&str] = &[
    "=", "!=", "<>", // equality and inequality
    "<", ">", "<=", ">=", // ordering
];
5
6/// Validates a SQL identifier (column or table name) to prevent SQL injection.
7/// Allows alphanumeric characters, underscores, hyphens and a single dot
8/// for qualified names like `table.column`.
9pub fn validate_identifier(name: &str) -> Result<(), Error> {
10    if name.is_empty() {
11        return Err(Error::Internal(
12            "SQL identifier cannot be empty".to_string(),
13        ));
14    }
15    // At most one dot is allowed (for `table.column` notation)
16    let dot_count = name.chars().filter(|&c| c == '.').count();
17    if dot_count > 1 {
18        return Err(Error::Internal(format!(
19            "Invalid SQL identifier '{}': at most one dot is allowed",
20            name
21        )));
22    }
23    if !name
24        .chars()
25        .all(|c| c.is_alphanumeric() || c == '_' || c == '-' || c == '.')
26    {
27        return Err(Error::Internal(format!(
28            "Invalid SQL identifier '{}': only alphanumeric characters, underscores, hyphens and dots are allowed",
29            name
30        )));
31    }
32    Ok(())
33}
34
35/// Validates a table name to prevent SQL injection.
36/// Wraps `validate_identifier` but disallows dots (table names have no qualifier).
37fn validate_table_name(table_name: &str) -> Result<(), Error> {
38    if table_name.contains('.') {
39        return Err(Error::Internal(format!(
40            "Invalid table name '{}': dots are not allowed in table names",
41            table_name
42        )));
43    }
44    validate_identifier(table_name)
45}
46
/// Definition of a single table column.
///
/// All setters take `&mut self` and return `&mut Self` so they can be chained
/// on the reference handed out by the code that created the column.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Column {
    /// Column name.
    pub name: String,
    /// SQL type text, e.g. `"TEXT"` or `"INTEGER"`.
    pub col_type: String,
    /// Whether the column accepts `NULL`.
    pub is_nullable: bool,
    /// Whether the column is the primary key.
    pub is_primary_key: bool,
    /// Whether the column is auto-incremented.
    pub is_auto_increment: bool,
    /// Optional default value text.
    pub default_value: Option<String>,
}

impl Column {
    /// Creates a nullable, non-key column with no default value.
    pub fn new(name: &str, col_type: &str) -> Self {
        Self {
            name: name.to_string(),
            col_type: col_type.to_string(),
            is_nullable: true,
            is_primary_key: false,
            is_auto_increment: false,
            default_value: None,
        }
    }

    /// Marks the column as not accepting `NULL`.
    pub fn not_null(&mut self) -> &mut Self {
        self.is_nullable = false;
        self
    }

    /// Marks the column as accepting `NULL`.
    pub fn nullable(&mut self) -> &mut Self {
        self.is_nullable = true;
        self
    }

    /// Stores `val` as the column's default value text, replacing any
    /// previous default.
    pub fn default(&mut self, val: &str) -> &mut Self {
        self.default_value = Some(val.to_string());
        self
    }

    /// Marks the column as the primary key.
    pub fn primary(&mut self) -> &mut Self {
        self.is_primary_key = true;
        self
    }
}
88
89pub struct Blueprint {
90    pub columns: Vec<Column>,
91}
92
93impl Default for Blueprint {
94    fn default() -> Self {
95        Self::new()
96    }
97}
98
99impl Blueprint {
100    pub fn new() -> Self {
101        Self { columns: vec![] }
102    }
103
104    pub fn id(&mut self) -> &mut Column {
105        self.columns.push(Column {
106            name: "id".to_string(),
107            col_type: "INTEGER".to_string(),
108            is_nullable: false,
109            is_primary_key: true,
110            is_auto_increment: true,
111            default_value: None,
112        });
113        self.columns
114            .last_mut()
115            .expect("BUG: columns is empty after push")
116    }
117
118    pub fn string(&mut self, name: &str) -> &mut Column {
119        let col = Column::new(name, "TEXT");
120        self.columns.push(col);
121        self.columns
122            .last_mut()
123            .expect("BUG: columns is empty after push")
124    }
125
126    pub fn integer(&mut self, name: &str) -> &mut Column {
127        let col = Column::new(name, "INTEGER");
128        self.columns.push(col);
129        self.columns
130            .last_mut()
131            .expect("BUG: columns is empty after push")
132    }
133
134    pub fn float(&mut self, name: &str) -> &mut Column {
135        let col = Column::new(name, "REAL");
136        self.columns.push(col);
137        self.columns
138            .last_mut()
139            .expect("BUG: columns is empty after push")
140    }
141
142    pub fn boolean(&mut self, name: &str) -> &mut Column {
143        let col = Column::new(name, "INTEGER");
144        self.columns.push(col);
145        self.columns
146            .last_mut()
147            .expect("BUG: columns is empty after push")
148    }
149
150    pub fn timestamps(&mut self) {
151        let mut created = Column::new("created_at", "TEXT");
152        created.default("CURRENT_TIMESTAMP");
153        self.columns.push(created);
154
155        let mut updated = Column::new("updated_at", "TEXT");
156        updated.default("CURRENT_TIMESTAMP");
157        self.columns.push(updated);
158    }
159
160    pub fn soft_deletes(&mut self) {
161        let col = Column::new("deleted_at", "TEXT");
162        self.columns.push(col);
163        self.columns
164            .last_mut()
165            .expect("BUG: columns is empty after push")
166            .nullable();
167    }
168
169    pub fn build(&self) -> String {
170        let mut defs = vec![];
171        for col in &self.columns {
172            let mut def = format!("{} {}", col.name, col.col_type);
173            if col.is_primary_key {
174                def.push_str(" PRIMARY KEY");
175            }
176            if col.is_auto_increment {
177                def.push_str(" AUTOINCREMENT");
178            }
179            if !col.is_nullable && !col.is_primary_key {
180                def.push_str(" NOT NULL");
181            }
182            if let Some(val) = &col.default_value {
183                def.push_str(&format!(" DEFAULT {}", val));
184            }
185            defs.push(def);
186        }
187        defs.join(",\n    ")
188    }
189}
190
191pub struct Schema;
192
193impl Schema {
194    pub async fn create<F>(table_name: &str, callback: F) -> Result<(), Error>
195    where
196        F: FnOnce(&mut Blueprint),
197    {
198        validate_table_name(table_name)?;
199
200        let mut blueprint = Blueprint::new();
201        callback(&mut blueprint);
202
203        let columns_sql = blueprint.build();
204        let sql = format!(
205            "CREATE TABLE IF NOT EXISTS {} (\n    {}\n);",
206            table_name, columns_sql
207        );
208
209        let pool = crate::Orm::pool();
210        let mut query_builder = sqlx::query_builder::QueryBuilder::new("");
211        query_builder.push(&sql);
212        query_builder.build().execute(pool).await?;
213
214        Ok(())
215    }
216
217    pub async fn drop_if_exists(table_name: &str) -> Result<(), Error> {
218        validate_table_name(table_name)?;
219
220        let sql = format!("DROP TABLE IF EXISTS {};", table_name);
221        let pool = crate::Orm::pool();
222        let mut query_builder = sqlx::query_builder::QueryBuilder::new("");
223        query_builder.push(&sql);
224        query_builder.build().execute(pool).await?;
225        Ok(())
226    }
227}
228
229#[async_trait::async_trait]
230pub trait Migration: Send + Sync {
231    fn name(&self) -> &'static str;
232    async fn up(&self) -> Result<(), Error>;
233    async fn down(&self) -> Result<(), Error>;
234}
235
236pub async fn run_artisan_with_args(
237    args: &[String],
238    migrations: Vec<Box<dyn Migration>>,
239    seeders: Vec<Box<dyn crate::Seeder>>,
240) -> Result<(), Error> {
241    if args.len() < 2 {
242        println!("Rullst ORM Artisan CLI");
243        println!("Usage:");
244        println!("  make:migration <name>   Generate a new migration");
245        println!("  migrate                  Run all pending migrations");
246        println!("  migrate:rollback         Rollback the last batch of migrations");
247        println!("  status                   Show migrations status");
248        println!("  db:seed                  Populate the database with seeders");
249        return Ok(());
250    }
251
252    let command = &args[1];
253    match command.as_str() {
254        "make:migration" => {
255            if args.len() < 3 {
256                println!("Error: migration name is required.");
257                return Ok(());
258            }
259            let name = &args[2];
260            create_migration_files(name)?;
261        }
262        "migrate" | "db:migrate" => {
263            run_migrations(migrations).await?;
264        }
265        "migrate:rollback" | "db:rollback" => {
266            rollback_migrations(migrations).await?;
267        }
268        "status" | "db:status" => {
269            status_migrations(migrations).await?;
270        }
271        "db:seed" => {
272            println!("Seeding database...");
273            crate::Orm::seed(seeders).await?;
274            println!("Database seeded successfully!");
275        }
276        _ => {
277            println!("Unknown command: {}", command);
278        }
279    }
280    Ok(())
281}
282
283pub async fn run_artisan(
284    migrations: Vec<Box<dyn Migration>>,
285    seeders: Vec<Box<dyn crate::Seeder>>,
286) -> Result<(), Error> {
287    let args: Vec<String> = std::env::args().collect();
288    run_artisan_with_args(&args, migrations, seeders).await
289}
290
291async fn status_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
292    let pool = crate::Orm::pool();
293    let driver = crate::Orm::driver();
294
295    let table_exists = match driver {
296        "postgres" | "mysql" => {
297            let query_str =
298                "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'";
299            let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
300            row.0 > 0
301        }
302        _ => {
303            let query_str =
304                "SELECT COUNT(*) FROM sqlite_schema WHERE type='table' AND name='migrations'";
305            let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
306            row.0 > 0
307        }
308    };
309
310    let executed_set = if table_exists {
311        let executed: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations")
312            .fetch_all(pool)
313            .await?;
314        executed
315            .into_iter()
316            .map(|(m,)| m)
317            .collect::<std::collections::HashSet<String>>()
318    } else {
319        std::collections::HashSet::new()
320    };
321
322    let name_header = "Migration Name";
323    let status_header = "Status";
324    println!("{name_header:<40} | {status_header}");
325    println!("{}", "-".repeat(55));
326    for m in migrations {
327        let name = m.name();
328        let status = if executed_set.contains(name) {
329            "Applied"
330        } else {
331            "Pending"
332        };
333        println!("{:<40} | {}", name, status);
334    }
335
336    Ok(())
337}
338
339fn create_migration_files(name: &str) -> Result<(), Error> {
340    validate_table_name(name)?;
341    use std::fs;
342
343    let now = std::time::SystemTime::now()
344        .duration_since(std::time::UNIX_EPOCH)
345        .expect("System time went backwards")
346        .as_secs()
347        .to_string();
348    let snake_name = name.to_lowercase().replace("-", "_");
349    let file_name = format!("m{}_{}", now, snake_name);
350
351    fs::create_dir_all("src/migrations")
352        .map_err(|e| Error::Internal(format!("Failed to create migrations directory: {}", e)))?;
353
354    let new_file_path = format!("src/migrations/{}.rs", file_name);
355    let migration_code = format!(
356        r#"use rullst_orm::schema::{{Schema, Blueprint, Migration}};
357use rullst_orm::async_trait;
358
359pub struct MigrationImpl;
360
361#[async_trait]
362impl Migration for MigrationImpl {{
363    fn name(&self) -> &'static str {{
364        "m{timestamp}_{name}"
365    }}
366
367    async fn up(&self) -> Result<(), crate::Error> {{
368        Schema::create("{name}", |table| {{
369            table.id();
370            table.timestamps();
371        }}).await
372    }}
373
374    async fn down(&self) -> Result<(), crate::Error> {{
375        Schema::drop_if_exists("{name}").await
376    }}
377}}
378"#,
379        timestamp = now,
380        name = snake_name
381    );
382
383    fs::write(&new_file_path, migration_code)
384        .map_err(|e| Error::Internal(format!("Failed to write migration file: {}", e)))?;
385    println!("Created migration file: {}", new_file_path);
386
387    regenerate_migrations_mod()?;
388
389    Ok(())
390}
391
392fn regenerate_migrations_mod() -> Result<(), Error> {
393    use std::fs;
394    let paths = fs::read_dir("src/migrations")
395        .map_err(|e| Error::Internal(format!("Failed to read migrations dir: {}", e)))?;
396
397    let mut modules = vec![];
398    for path in paths {
399        let path = path.map_err(|e| Error::Internal(e.to_string()))?.path();
400        if let Some(ext) = path.extension()
401            && ext == "rs"
402            && let Some(stem) = path.file_stem()
403        {
404            let stem_str = stem.to_string_lossy().to_string();
405            if stem_str != "mod" && stem_str.starts_with('m') {
406                modules.push(stem_str);
407            }
408        }
409    }
410    modules.sort();
411
412    let mut mod_content = String::new();
413    mod_content.push_str("// Generated by Rullst ORM Artisan. Do not edit manually.\n\n");
414    for m in &modules {
415        mod_content.push_str(&format!("pub mod {};\n", m));
416    }
417    mod_content
418        .push_str("\npub fn get_migrations() -> Vec<Box<dyn rullst_orm::schema::Migration>> {\n");
419    mod_content.push_str("    vec![\n");
420    for m in &modules {
421        mod_content.push_str(&format!("        Box::new({}::MigrationImpl),\n", m));
422    }
423    mod_content.push_str("    ]\n");
424    mod_content.push_str("}\n");
425
426    fs::write("src/migrations/mod.rs", mod_content)
427        .map_err(|e| Error::Internal(format!("Failed to write mod.rs: {}", e)))?;
428    println!("Regenerated src/migrations/mod.rs");
429
430    Ok(())
431}
432
433async fn run_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
434    let pool = crate::Orm::pool();
435    let driver = crate::Orm::driver();
436
437    let query_str = match driver {
438        "postgres" => {
439            "CREATE TABLE IF NOT EXISTS migrations (
440                id SERIAL PRIMARY KEY,
441                migration VARCHAR(255) NOT NULL,
442                batch INTEGER NOT NULL
443            )"
444        }
445        "mysql" => {
446            "CREATE TABLE IF NOT EXISTS migrations (
447                id INT AUTO_INCREMENT PRIMARY KEY,
448                migration VARCHAR(255) NOT NULL,
449                batch INT NOT NULL
450            )"
451        }
452        _ => {
453            "CREATE TABLE IF NOT EXISTS migrations (
454                id INTEGER PRIMARY KEY AUTOINCREMENT,
455                migration TEXT NOT NULL,
456                batch INTEGER NOT NULL
457            )"
458        }
459    };
460
461    sqlx::query(query_str).execute(pool).await?;
462
463    let executed: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations")
464        .fetch_all(pool)
465        .await?;
466    let executed_set: std::collections::HashSet<String> =
467        executed.into_iter().map(|(m,)| m).collect();
468
469    let batch_row: (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
470        .fetch_one(pool)
471        .await?;
472    let next_batch = batch_row.0.unwrap_or(0) + 1;
473
474    let mut count = 0;
475    for m in migrations {
476        let name = m.name();
477        if !executed_set.contains(name) {
478            println!("Migrating: {}", name);
479            m.up().await?;
480            sqlx::query("INSERT INTO migrations (migration, batch) VALUES (?, ?)")
481                .bind(name)
482                .bind(next_batch)
483                .execute(pool)
484                .await?;
485            println!("Migrated:  {}", name);
486            count += 1;
487        }
488    }
489
490    if count == 0 {
491        println!("Nothing to migrate.");
492    }
493
494    Ok(())
495}
496
497async fn rollback_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
498    let pool = crate::Orm::pool();
499    let driver = crate::Orm::driver();
500
501    let table_exists = match driver {
502        "postgres" | "mysql" => {
503            let query_str =
504                "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'";
505            let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
506            row.0 > 0
507        }
508        _ => {
509            let query_str =
510                "SELECT COUNT(*) FROM sqlite_schema WHERE type='table' AND name='migrations'";
511            let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
512            row.0 > 0
513        }
514    };
515
516    if !table_exists {
517        println!("Nothing to rollback.");
518        return Ok(());
519    }
520
521    let batch_row: (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
522        .fetch_one(pool)
523        .await?;
524
525    let last_batch = match batch_row.0 {
526        Some(b) if b > 0 => b,
527        _ => {
528            println!("Nothing to rollback.");
529            return Ok(());
530        }
531    };
532
533    let to_rollback: Vec<(String,)> =
534        sqlx::query_as("SELECT migration FROM migrations WHERE batch = ? ORDER BY id DESC")
535            .bind(last_batch)
536            .fetch_all(pool)
537            .await?;
538
539    let mut rollback_map = std::collections::HashMap::new();
540    for m in migrations {
541        rollback_map.insert(m.name().to_string(), m);
542    }
543
544    for (name,) in to_rollback {
545        if let Some(m) = rollback_map.get(&name) {
546            println!("Rolling back: {}", name);
547            m.down().await?;
548            sqlx::query("DELETE FROM migrations WHERE migration = ?")
549                .bind(&name)
550                .execute(pool)
551                .await?;
552            println!("Rolled back:  {}", name);
553        } else {
554            println!(
555                "Warning: migration {} found in database but not in compiled binary.",
556                name
557            );
558        }
559    }
560
561    Ok(())
562}
563
564pub struct JoinClause {
565    pub table: String,
566    pub conditions: Vec<String>,
567    pub bindings: Vec<crate::RullstValue>,
568}
569
570impl JoinClause {
571    pub fn new(table: &str) -> Self {
572        Self {
573            table: table.to_string(),
574            conditions: vec![],
575            bindings: vec![],
576        }
577    }
578
579    /// Adds a column-to-column JOIN condition.
580    ///
581    /// # Panics
582    /// Panics if `first` or `second` are not valid SQL identifiers (alphanumeric,
583    /// underscores, hyphens, or a single qualifying dot), or if `operator` is not
584    /// one of: `=`, `!=`, `<>`, `<`, `>`, `<=`, `>=`.
585    /// This prevents SQL injection — column names should always be hardcoded, never
586    /// derived from user input.
587    pub fn on(&mut self, first: &str, operator: &str, second: &str) -> &mut Self {
588        validate_identifier(first)
589            .unwrap_or_else(|e| panic!("JoinClause::on — invalid identifier for `first`: {}", e));
590        validate_identifier(second)
591            .unwrap_or_else(|e| panic!("JoinClause::on — invalid identifier for `second`: {}", e));
592        if !ALLOWED_OPERATORS.contains(&operator) {
593            panic!(
594                "JoinClause::on — invalid operator '{}'. Allowed: {:?}",
595                operator, ALLOWED_OPERATORS
596            );
597        }
598        self.conditions
599            .push(format!("{} {} {}", first, operator, second));
600        self
601    }
602
603    pub fn on_eq<T: Into<crate::RullstValue>>(&mut self, column: &str, value: T) -> &mut Self {
604        self.conditions.push(format!("{} = ?", column));
605        self.bindings.push(value.into());
606        self
607    }
608
609    pub fn to_sql(&self) -> String {
610        self.conditions.join(" AND ")
611    }
612}
613
614pub trait SubqueryBuilder {
615    fn to_sql(&self) -> String;
616    fn bindings(&self) -> &Vec<crate::RullstValue>;
617}
618
/// Process-wide switch for query logging; off until enabled.
pub static QUERY_LOGGING: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(false);

/// Writes the query-log flag.
fn set_query_log(enabled: bool) {
    QUERY_LOGGING.store(enabled, std::sync::atomic::Ordering::SeqCst);
}

/// Turns query logging on.
pub fn enable_query_log() {
    set_query_log(true);
}

/// Turns query logging off.
pub fn disable_query_log() {
    set_query_log(false);
}

/// Reports whether query logging is currently on.
pub fn is_query_log_enabled() -> bool {
    let enabled = QUERY_LOGGING.load(std::sync::atomic::Ordering::SeqCst);
    enabled
}
632
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a blueprint, lets `fill` populate it, and returns it.
    fn blueprint_with(fill: impl FnOnce(&mut Blueprint)) -> Blueprint {
        let mut blueprint = Blueprint::new();
        fill(&mut blueprint);
        blueprint
    }

    #[test]
    fn test_enable_disable_query_log() {
        disable_query_log();
        assert_eq!(is_query_log_enabled(), false);
        enable_query_log();
        assert_eq!(is_query_log_enabled(), true);
        disable_query_log();
        assert_eq!(is_query_log_enabled(), false);
    }

    #[test]
    fn test_join_clause() {
        let mut clause = JoinClause::new("users");
        clause.on("users.id", "=", "posts.user_id");
        assert_eq!(clause.to_sql(), "users.id = posts.user_id");
    }

    #[test]
    fn test_validate_table_name() {
        for accepted in ["users", "user_posts"] {
            assert!(validate_table_name(accepted).is_ok(), "{accepted}");
        }
        // The last entry checks that dots are not allowed in table names.
        for rejected in ["DROP TABLE users", "../../../etc/shadow", "users.id"] {
            assert!(validate_table_name(rejected).is_err(), "{rejected}");
        }
    }

    #[test]
    fn test_validate_identifier() {
        for accepted in ["users", "users.id", "user_posts"] {
            assert!(validate_identifier(accepted).is_ok(), "{accepted}");
        }
        let rejected = [
            "",
            "users.posts.id", // two dots
            "DROP TABLE users",
            "id; DROP TABLE users--",
        ];
        for name in rejected {
            assert!(validate_identifier(name).is_err(), "{name}");
        }
    }

    #[test]
    #[should_panic(expected = "invalid operator")]
    fn test_join_clause_on_invalid_operator() {
        JoinClause::new("posts").on("posts.user_id", "OR 1=1 --", "users.id");
    }

    #[test]
    #[should_panic(expected = "invalid identifier")]
    fn test_join_clause_on_invalid_column() {
        JoinClause::new("posts").on("users.id; DROP TABLE users--", "=", "posts.user_id");
    }

    #[test]
    fn test_timestamps_adds_columns() {
        let bp = blueprint_with(|b| b.timestamps());
        let names: Vec<&str> = bp.columns.iter().map(|c| c.name.as_str()).collect();
        assert_eq!(names, ["created_at", "updated_at"]);
        assert!(bp.columns.iter().all(|c| c.default_value.is_some()));
    }

    #[test]
    fn test_soft_deletes_adds_nullable_column() {
        let bp = blueprint_with(|b| b.soft_deletes());
        assert_eq!(bp.columns.len(), 1);
        let column = &bp.columns[0];
        assert_eq!(column.name, "deleted_at");
        assert!(column.is_nullable);
    }

    #[test]
    fn test_blueprint_build_produces_valid_sql() {
        let sql = blueprint_with(|b| {
            b.id();
            b.string("name").not_null();
            b.integer("age");
        })
        .build();
        for fragment in ["id INTEGER PRIMARY KEY", "name TEXT NOT NULL", "age INTEGER"] {
            assert!(sql.contains(fragment), "missing `{fragment}` in `{sql}`");
        }
    }

    #[test]
    fn test_join_clause_on_eq_binds_value() {
        let mut clause = JoinClause::new("orders");
        clause.on_eq("orders.user_id", 42i32);
        assert_eq!(clause.to_sql(), "orders.user_id = ?");
        assert_eq!(clause.bindings.len(), 1);
    }

    #[test]
    fn test_join_clause_multiple_conditions() {
        let mut clause = JoinClause::new("posts");
        clause
            .on("posts.user_id", "=", "users.id")
            .on("posts.status", ">", "users.min_status");
        assert_eq!(
            clause.to_sql(),
            "posts.user_id = users.id AND posts.status > users.min_status"
        );
    }
}
735            jc.to_sql(),
736            "posts.user_id = users.id AND posts.status > users.min_status"
737        );
738    }
739}