Skip to main content

rullst_orm/
schema.rs

1use sqlx::Error;
2
3/// Validates a table name to prevent SQL injection
4/// Only allows alphanumeric characters, underscores, and hyphens
5fn validate_table_name(table_name: &str) -> Result<(), Error> {
6    if !table_name
7        .chars()
8        .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
9    {
10        return Err(Error::Protocol(format!(
11            "Invalid table name '{}': only alphanumeric characters, underscores, and hyphens are allowed",
12            table_name
13        )));
14    }
15    if table_name.is_empty() {
16        return Err(Error::Protocol("Table name cannot be empty".to_string()));
17    }
18    Ok(())
19}
20
/// A single column definition collected by a [`Blueprint`].
///
/// Fields are public so a definition can be inspected or adjusted directly; the
/// chainable setters cover the common cases.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Column {
    /// Column name, emitted verbatim into the DDL (not quoted or validated).
    pub name: String,
    /// SQL type name, emitted verbatim (e.g. `TEXT`, `INTEGER`, `REAL`).
    pub col_type: String,
    /// Whether NULL is allowed; `true` for columns created with [`Column::new`].
    pub is_nullable: bool,
    /// Whether the column is declared `PRIMARY KEY`.
    pub is_primary_key: bool,
    /// Whether the column is declared with the `AUTOINCREMENT` keyword (SQLite syntax).
    pub is_auto_increment: bool,
    /// Raw SQL for the `DEFAULT` clause, inserted without quoting (so
    /// `CURRENT_TIMESTAMP` or `0` work, but string literals must carry their own quotes).
    pub default_value: Option<String>,
}

impl Column {
    /// Creates a nullable, non-key column with no default.
    pub fn new(name: &str, col_type: &str) -> Self {
        Self {
            name: name.to_string(),
            col_type: col_type.to_string(),
            is_nullable: true,
            is_primary_key: false,
            is_auto_increment: false,
            default_value: None,
        }
    }

    /// Marks the column `NOT NULL`.
    pub fn not_null(&mut self) -> &mut Self {
        self.is_nullable = false;
        self
    }

    /// Allows NULL values (the default for [`Column::new`]).
    pub fn nullable(&mut self) -> &mut Self {
        self.is_nullable = true;
        self
    }

    /// Sets the raw SQL `DEFAULT` expression (inserted verbatim).
    pub fn default(&mut self, val: &str) -> &mut Self {
        self.default_value = Some(val.to_string());
        self
    }

    /// Marks the column as the primary key.
    pub fn primary(&mut self) -> &mut Self {
        self.is_primary_key = true;
        self
    }
}

/// Collects column definitions for a `CREATE TABLE` statement.
///
/// The generated types and keywords (`TEXT`, `REAL`, `AUTOINCREMENT`) follow
/// SQLite syntax.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Blueprint {
    /// Columns in declaration order; [`Blueprint::build`] preserves this order.
    pub columns: Vec<Column>,
}

impl Blueprint {
    /// Creates an empty blueprint.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `column` and returns a mutable handle to it for further chaining.
    fn push_column(&mut self, column: Column) -> &mut Column {
        let index = self.columns.len();
        self.columns.push(column);
        &mut self.columns[index]
    }

    /// Adds an `id INTEGER PRIMARY KEY AUTOINCREMENT` column.
    pub fn id(&mut self) -> &mut Column {
        self.push_column(Column {
            name: "id".to_string(),
            col_type: "INTEGER".to_string(),
            is_nullable: false,
            is_primary_key: true,
            is_auto_increment: true,
            default_value: None,
        })
    }

    /// Adds a nullable `TEXT` column.
    pub fn string(&mut self, name: &str) -> &mut Column {
        self.push_column(Column::new(name, "TEXT"))
    }

    /// Adds a nullable `INTEGER` column.
    pub fn integer(&mut self, name: &str) -> &mut Column {
        self.push_column(Column::new(name, "INTEGER"))
    }

    /// Adds a nullable `REAL` column.
    pub fn float(&mut self, name: &str) -> &mut Column {
        self.push_column(Column::new(name, "REAL"))
    }

    /// Adds a nullable `INTEGER` column used to store booleans.
    pub fn boolean(&mut self, name: &str) -> &mut Column {
        self.push_column(Column::new(name, "INTEGER"))
    }

    /// Adds nullable `created_at` and `updated_at` `TEXT` columns, both defaulting
    /// to `CURRENT_TIMESTAMP`.
    pub fn timestamps(&mut self) {
        for name in ["created_at", "updated_at"] {
            self.push_column(Column::new(name, "TEXT"))
                .default("CURRENT_TIMESTAMP");
        }
    }

    /// Adds a nullable `deleted_at` `TEXT` column.
    pub fn soft_deletes(&mut self) {
        self.push_column(Column::new("deleted_at", "TEXT")).nullable();
    }

    /// Renders the column definitions, joined by `",\n    "`, for use inside the
    /// parentheses of a `CREATE TABLE` statement. Returns `""` when there are no columns.
    pub fn build(&self) -> String {
        self.columns
            .iter()
            .map(Self::column_definition)
            .collect::<Vec<_>>()
            .join(",\n    ")
    }

    /// Renders a single column definition such as `name TEXT NOT NULL DEFAULT 'x'`.
    fn column_definition(col: &Column) -> String {
        let mut def = format!("{} {}", col.name, col.col_type);
        if col.is_primary_key {
            def.push_str(" PRIMARY KEY");
        }
        if col.is_auto_increment {
            def.push_str(" AUTOINCREMENT");
        }
        // SQLite only forbids NULL in a PRIMARY KEY column when it is an INTEGER
        // PRIMARY KEY (rowid alias) or explicitly declared NOT NULL, so non-rowid
        // primary keys must spell NOT NULL out. Auto-increment keys are rowid
        // aliases and keep their historical rendering.
        let needs_not_null = if col.is_primary_key {
            !col.is_auto_increment
        } else {
            !col.is_nullable
        };
        if needs_not_null {
            def.push_str(" NOT NULL");
        }
        if let Some(val) = &col.default_value {
            def.push_str(" DEFAULT ");
            def.push_str(val);
        }
        def
    }
}
164
165pub struct Schema;
166
167impl Schema {
168    pub async fn create<F>(table_name: &str, callback: F) -> Result<(), Error>
169    where
170        F: FnOnce(&mut Blueprint),
171    {
172        validate_table_name(table_name)?;
173
174        let mut blueprint = Blueprint::new();
175        callback(&mut blueprint);
176
177        let columns_sql = blueprint.build();
178        let sql = format!(
179            "CREATE TABLE IF NOT EXISTS {} (\n    {}\n);",
180            table_name, columns_sql
181        );
182
183        let pool = crate::Orm::pool();
184        let mut query_builder = sqlx::query_builder::QueryBuilder::new("");
185        query_builder.push(&sql);
186        query_builder.build().execute(pool).await?;
187
188        Ok(())
189    }
190
191    pub async fn drop_if_exists(table_name: &str) -> Result<(), Error> {
192        validate_table_name(table_name)?;
193
194        let sql = format!("DROP TABLE IF EXISTS {};", table_name);
195        let pool = crate::Orm::pool();
196        let mut query_builder = sqlx::query_builder::QueryBuilder::new("");
197        query_builder.push(&sql);
198        query_builder.build().execute(pool).await?;
199        Ok(())
200    }
201}
202
203#[async_trait::async_trait]
204pub trait Migration: Send + Sync {
205    fn name(&self) -> &'static str;
206    async fn up(&self) -> Result<(), Error>;
207    async fn down(&self) -> Result<(), Error>;
208}
209
210pub async fn run_artisan_with_args(
211    args: &[String],
212    migrations: Vec<Box<dyn Migration>>,
213    seeders: Vec<Box<dyn crate::Seeder>>,
214) -> Result<(), Error> {
215    if args.len() < 2 {
216        println!("Rullst ORM Artisan CLI");
217        println!("Usage:");
218        println!("  make:migration <name>   Generate a new migration");
219        println!("  migrate                  Run all pending migrations");
220        println!("  migrate:rollback         Rollback the last batch of migrations");
221        println!("  status                   Show migrations status");
222        println!("  db:seed                  Populate the database with seeders");
223        return Ok(());
224    }
225
226    let command = &args[1];
227    match command.as_str() {
228        "make:migration" => {
229            if args.len() < 3 {
230                println!("Error: migration name is required.");
231                return Ok(());
232            }
233            let name = &args[2];
234            create_migration_files(name)?;
235        }
236        "migrate" | "db:migrate" => {
237            run_migrations(migrations).await?;
238        }
239        "migrate:rollback" | "db:rollback" => {
240            rollback_migrations(migrations).await?;
241        }
242        "status" | "db:status" => {
243            status_migrations(migrations).await?;
244        }
245        "db:seed" => {
246            println!("Seeding database...");
247            crate::Orm::seed(seeders).await?;
248            println!("Database seeded successfully!");
249        }
250        _ => {
251            println!("Unknown command: {}", command);
252        }
253    }
254    Ok(())
255}
256
257pub async fn run_artisan(
258    migrations: Vec<Box<dyn Migration>>,
259    seeders: Vec<Box<dyn crate::Seeder>>,
260) -> Result<(), Error> {
261    let args: Vec<String> = std::env::args().collect();
262    run_artisan_with_args(&args, migrations, seeders).await
263}
264
265async fn status_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
266    let pool = crate::Orm::pool();
267    let driver = crate::Orm::driver();
268
269    let table_exists = match driver {
270        "postgres" | "mysql" => {
271            let query_str =
272                "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'";
273            let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
274            row.0 > 0
275        }
276        _ => {
277            let query_str =
278                "SELECT COUNT(*) FROM sqlite_schema WHERE type='table' AND name='migrations'";
279            let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
280            row.0 > 0
281        }
282    };
283
284    let executed_set = if table_exists {
285        let executed: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations")
286            .fetch_all(pool)
287            .await?;
288        executed
289            .into_iter()
290            .map(|(m,)| m)
291            .collect::<std::collections::HashSet<String>>()
292    } else {
293        std::collections::HashSet::new()
294    };
295
296    let name_header = "Migration Name";
297    let status_header = "Status";
298    println!("{name_header:<40} | {status_header}");
299    println!("{}", "-".repeat(55));
300    for m in migrations {
301        let name = m.name();
302        let status = if executed_set.contains(name) {
303            "Applied"
304        } else {
305            "Pending"
306        };
307        println!("{:<40} | {}", name, status);
308    }
309
310    Ok(())
311}
312
313fn create_migration_files(name: &str) -> Result<(), Error> {
314    validate_table_name(name)?;
315    use std::fs;
316
317    let now = std::time::SystemTime::now()
318        .duration_since(std::time::UNIX_EPOCH)
319        .expect("System time went backwards")
320        .as_secs()
321        .to_string();
322    let snake_name = name.to_lowercase().replace("-", "_");
323    let file_name = format!("m{}_{}", now, snake_name);
324
325    fs::create_dir_all("src/migrations")
326        .map_err(|e| Error::Protocol(format!("Failed to create migrations directory: {}", e)))?;
327
328    let new_file_path = format!("src/migrations/{}.rs", file_name);
329    let migration_code = format!(
330        r#"use rullst_orm::schema::{{Schema, Blueprint, Migration}};
331use rullst_orm::async_trait;
332
333pub struct MigrationImpl;
334
335#[async_trait]
336impl Migration for MigrationImpl {{
337    fn name(&self) -> &'static str {{
338        "m{timestamp}_{name}"
339    }}
340
341    async fn up(&self) -> Result<(), rullst_orm::sqlx::Error> {{
342        Schema::create("{name}", |table| {{
343            table.id();
344            table.timestamps();
345        }}).await
346    }}
347
348    async fn down(&self) -> Result<(), rullst_orm::sqlx::Error> {{
349        Schema::drop_if_exists("{name}").await
350    }}
351}}
352"#,
353        timestamp = now,
354        name = snake_name
355    );
356
357    fs::write(&new_file_path, migration_code)
358        .map_err(|e| Error::Protocol(format!("Failed to write migration file: {}", e)))?;
359    println!("Created migration file: {}", new_file_path);
360
361    regenerate_migrations_mod()?;
362
363    Ok(())
364}
365
366fn regenerate_migrations_mod() -> Result<(), Error> {
367    use std::fs;
368    let paths = fs::read_dir("src/migrations")
369        .map_err(|e| Error::Protocol(format!("Failed to read migrations dir: {}", e)))?;
370
371    let mut modules = vec![];
372    for path in paths {
373        let path = path.map_err(|e| Error::Protocol(e.to_string()))?.path();
374        if let Some(ext) = path.extension()
375            && ext == "rs"
376            && let Some(stem) = path.file_stem()
377        {
378            let stem_str = stem.to_string_lossy().to_string();
379            if stem_str != "mod" && stem_str.starts_with('m') {
380                modules.push(stem_str);
381            }
382        }
383    }
384    modules.sort();
385
386    let mut mod_content = String::new();
387    mod_content.push_str("// Generated by Rullst ORM Artisan. Do not edit manually.\n\n");
388    for m in &modules {
389        mod_content.push_str(&format!("pub mod {};\n", m));
390    }
391    mod_content
392        .push_str("\npub fn get_migrations() -> Vec<Box<dyn rullst_orm::schema::Migration>> {\n");
393    mod_content.push_str("    vec![\n");
394    for m in &modules {
395        mod_content.push_str(&format!("        Box::new({}::MigrationImpl),\n", m));
396    }
397    mod_content.push_str("    ]\n");
398    mod_content.push_str("}\n");
399
400    fs::write("src/migrations/mod.rs", mod_content)
401        .map_err(|e| Error::Protocol(format!("Failed to write mod.rs: {}", e)))?;
402    println!("Regenerated src/migrations/mod.rs");
403
404    Ok(())
405}
406
407async fn run_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
408    let pool = crate::Orm::pool();
409    let driver = crate::Orm::driver();
410
411    let query_str = match driver {
412        "postgres" => {
413            "CREATE TABLE IF NOT EXISTS migrations (
414                id SERIAL PRIMARY KEY,
415                migration VARCHAR(255) NOT NULL,
416                batch INTEGER NOT NULL
417            )"
418        }
419        "mysql" => {
420            "CREATE TABLE IF NOT EXISTS migrations (
421                id INT AUTO_INCREMENT PRIMARY KEY,
422                migration VARCHAR(255) NOT NULL,
423                batch INT NOT NULL
424            )"
425        }
426        _ => {
427            "CREATE TABLE IF NOT EXISTS migrations (
428                id INTEGER PRIMARY KEY AUTOINCREMENT,
429                migration TEXT NOT NULL,
430                batch INTEGER NOT NULL
431            )"
432        }
433    };
434
435    sqlx::query(query_str).execute(pool).await?;
436
437    let executed: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations")
438        .fetch_all(pool)
439        .await?;
440    let executed_set: std::collections::HashSet<String> =
441        executed.into_iter().map(|(m,)| m).collect();
442
443    let batch_row: (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
444        .fetch_one(pool)
445        .await?;
446    let next_batch = batch_row.0.unwrap_or(0) + 1;
447
448    let mut count = 0;
449    for m in migrations {
450        let name = m.name();
451        if !executed_set.contains(name) {
452            println!("Migrating: {}", name);
453            m.up().await?;
454            sqlx::query("INSERT INTO migrations (migration, batch) VALUES (?, ?)")
455                .bind(name)
456                .bind(next_batch)
457                .execute(pool)
458                .await?;
459            println!("Migrated:  {}", name);
460            count += 1;
461        }
462    }
463
464    if count == 0 {
465        println!("Nothing to migrate.");
466    }
467
468    Ok(())
469}
470
471async fn rollback_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
472    let pool = crate::Orm::pool();
473    let driver = crate::Orm::driver();
474
475    let table_exists = match driver {
476        "postgres" | "mysql" => {
477            let query_str =
478                "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'";
479            let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
480            row.0 > 0
481        }
482        _ => {
483            let query_str =
484                "SELECT COUNT(*) FROM sqlite_schema WHERE type='table' AND name='migrations'";
485            let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
486            row.0 > 0
487        }
488    };
489
490    if !table_exists {
491        println!("Nothing to rollback.");
492        return Ok(());
493    }
494
495    let batch_row: (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
496        .fetch_one(pool)
497        .await?;
498
499    let last_batch = match batch_row.0 {
500        Some(b) if b > 0 => b,
501        _ => {
502            println!("Nothing to rollback.");
503            return Ok(());
504        }
505    };
506
507    let to_rollback: Vec<(String,)> =
508        sqlx::query_as("SELECT migration FROM migrations WHERE batch = ? ORDER BY id DESC")
509            .bind(last_batch)
510            .fetch_all(pool)
511            .await?;
512
513    let mut rollback_map = std::collections::HashMap::new();
514    for m in migrations {
515        rollback_map.insert(m.name().to_string(), m);
516    }
517
518    for (name,) in to_rollback {
519        if let Some(m) = rollback_map.get(&name) {
520            println!("Rolling back: {}", name);
521            m.down().await?;
522            sqlx::query("DELETE FROM migrations WHERE migration = ?")
523                .bind(&name)
524                .execute(pool)
525                .await?;
526            println!("Rolled back:  {}", name);
527        } else {
528            println!(
529                "Warning: migration {} found in database but not in compiled binary.",
530                name
531            );
532        }
533    }
534
535    Ok(())
536}
537
538pub struct JoinClause {
539    pub table: String,
540    pub conditions: Vec<String>,
541    pub bindings: Vec<crate::RullstValue>,
542}
543
544impl JoinClause {
545    pub fn new(table: &str) -> Self {
546        Self {
547            table: table.to_string(),
548            conditions: vec![],
549            bindings: vec![],
550        }
551    }
552
553    /// WARNING: Ensure `first` and `second` do not contain user input to prevent SQL Injection.
554    pub fn on(&mut self, first: &str, operator: &str, second: &str) -> &mut Self {
555        self.conditions
556            .push(format!("{} {} {}", first, operator, second));
557        self
558    }
559
560    pub fn on_eq<T: Into<crate::RullstValue>>(&mut self, column: &str, value: T) -> &mut Self {
561        self.conditions.push(format!("{} = ?", column));
562        self.bindings.push(value.into());
563        self
564    }
565
566    pub fn to_sql(&self) -> String {
567        self.conditions.join(" AND ")
568    }
569}
570
571pub trait SubqueryBuilder {
572    fn to_sql(&self) -> String;
573    fn bindings(&self) -> &Vec<crate::RullstValue>;
574}
575
/// Global query-logging switch; `false` until [`enable_query_log`] is called.
/// Every access in this module uses `SeqCst` ordering.
pub static QUERY_LOGGING: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);

/// Stores `enabled` into [`QUERY_LOGGING`].
fn set_query_log(enabled: bool) {
    QUERY_LOGGING.store(enabled, std::sync::atomic::Ordering::SeqCst);
}

/// Turns query logging on.
pub fn enable_query_log() {
    set_query_log(true);
}

/// Turns query logging off.
pub fn disable_query_log() {
    set_query_log(false);
}

/// Reports whether query logging is currently on.
pub fn is_query_log_enabled() -> bool {
    QUERY_LOGGING.load(std::sync::atomic::Ordering::SeqCst)
}
589
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_enable_disable_query_log() {
        disable_query_log();
        assert!(!is_query_log_enabled());
        enable_query_log();
        assert!(is_query_log_enabled());
        disable_query_log();
        assert!(!is_query_log_enabled());
    }

    #[test]
    fn test_join_clause() {
        let mut jc = JoinClause::new("users");
        jc.on("users.id", "=", "posts.user_id");
        assert_eq!(jc.to_sql(), "users.id = posts.user_id");
    }

    #[test]
    fn test_join_clause_multiple_conditions() {
        let mut jc = JoinClause::new("users");
        jc.on("users.id", "=", "posts.user_id")
            .on("users.team_id", "=", "posts.team_id");
        assert_eq!(
            jc.to_sql(),
            "users.id = posts.user_id AND users.team_id = posts.team_id"
        );
        // Raw `on` conditions never add bindings.
        assert!(jc.bindings.is_empty());
    }

    #[test]
    fn test_validate_table_name() {
        assert!(validate_table_name("users").is_ok());
        assert!(validate_table_name("user_posts").is_ok());
        assert!(validate_table_name("DROP TABLE users").is_err());
        assert!(validate_table_name("../../../etc/shadow").is_err());
        // Edge cases: empty names and common injection characters.
        assert!(validate_table_name("").is_err());
        assert!(validate_table_name("users;").is_err());
        assert!(validate_table_name("users'").is_err());
    }

    #[test]
    fn test_blueprint_build() {
        let mut table = Blueprint::new();
        table.id();
        table.string("title").not_null();
        table.integer("views").default("0");
        table.timestamps();
        assert_eq!(
            table.build(),
            "id INTEGER PRIMARY KEY AUTOINCREMENT,\n    title TEXT NOT NULL,\n    views INTEGER DEFAULT 0,\n    created_at TEXT DEFAULT CURRENT_TIMESTAMP,\n    updated_at TEXT DEFAULT CURRENT_TIMESTAMP"
        );
    }

    #[test]
    fn test_blueprint_empty() {
        assert_eq!(Blueprint::default().build(), "");
    }
}