Skip to main content

rullst_orm/
schema.rs

1use sqlx::Error;
2
3/// Validates a table name to prevent SQL injection
4/// Only allows alphanumeric characters, underscores, and hyphens
5fn validate_table_name(table_name: &str) -> Result<(), Error> {
6    if !table_name.chars().all(|c| c.is_alphanumeric() || c == '_' || c == '-') {
7        return Err(Error::Protocol(format!(
8            "Invalid table name '{}': only alphanumeric characters, underscores, and hyphens are allowed",
9            table_name
10        )));
11    }
12    if table_name.is_empty() {
13        return Err(Error::Protocol("Table name cannot be empty".to_string()));
14    }
15    Ok(())
16}
17
/// A single column definition collected by a `Blueprint` and rendered by `Blueprint::build`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Column {
    /// Column name; emitted into the DDL verbatim (not quoted or validated here).
    pub name: String,
    /// SQL type keyword, e.g. `TEXT`, `INTEGER`, `REAL`.
    pub col_type: String,
    /// Whether the column may hold NULL (`Column::new` starts with `true`).
    pub is_nullable: bool,
    /// Whether the column is declared `PRIMARY KEY`.
    pub is_primary_key: bool,
    /// Whether `AUTOINCREMENT` is emitted for the column.
    pub is_auto_increment: bool,
    /// Raw SQL expression emitted after `DEFAULT`, unquoted (e.g. `CURRENT_TIMESTAMP`).
    pub default_value: Option<String>,
}
26
27impl Column {
28    pub fn new(name: &str, col_type: &str) -> Self {
29        Self {
30            name: name.to_string(),
31            col_type: col_type.to_string(),
32            is_nullable: true,
33            is_primary_key: false,
34            is_auto_increment: false,
35            default_value: None,
36        }
37    }
38
39    pub fn not_null(&mut self) -> &mut Self {
40        self.is_nullable = false;
41        self
42    }
43
44    pub fn nullable(&mut self) -> &mut Self {
45        self.is_nullable = true;
46        self
47    }
48
49    pub fn default(&mut self, val: &str) -> &mut Self {
50        self.default_value = Some(val.to_string());
51        self
52    }
53
54    pub fn primary(&mut self) -> &mut Self {
55        self.is_primary_key = true;
56        self
57    }
58}
59
60pub struct Blueprint {
61    pub columns: Vec<Column>,
62}
63
64impl Default for Blueprint {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl Blueprint {
71    pub fn new() -> Self {
72        Self { columns: vec![] }
73    }
74
75    pub fn id(&mut self) -> &mut Column {
76        self.columns.push(Column {
77            name: "id".to_string(),
78            col_type: "INTEGER".to_string(),
79            is_nullable: false,
80            is_primary_key: true,
81            is_auto_increment: true,
82            default_value: None,
83        });
84        self.columns.last_mut().expect("BUG: columns is empty after push")
85    }
86
87    pub fn string(&mut self, name: &str) -> &mut Column {
88        let col = Column::new(name, "TEXT");
89        self.columns.push(col);
90        self.columns.last_mut().expect("BUG: columns is empty after push")
91    }
92
93    pub fn integer(&mut self, name: &str) -> &mut Column {
94        let col = Column::new(name, "INTEGER");
95        self.columns.push(col);
96        self.columns.last_mut().expect("BUG: columns is empty after push")
97    }
98
99    pub fn float(&mut self, name: &str) -> &mut Column {
100        let col = Column::new(name, "REAL");
101        self.columns.push(col);
102        self.columns.last_mut().expect("BUG: columns is empty after push")
103    }
104
105    pub fn boolean(&mut self, name: &str) -> &mut Column {
106        let col = Column::new(name, "INTEGER");
107        self.columns.push(col);
108        self.columns.last_mut().expect("BUG: columns is empty after push")
109    }
110
111    pub fn timestamps(&mut self) {
112        let mut created = Column::new("created_at", "TEXT");
113        created.default("CURRENT_TIMESTAMP");
114        self.columns.push(created);
115        
116        let mut updated = Column::new("updated_at", "TEXT");
117        updated.default("CURRENT_TIMESTAMP");
118        self.columns.push(updated);
119    }
120
121    pub fn soft_deletes(&mut self) {
122        let col = Column::new("deleted_at", "TEXT");
123        self.columns.push(col);
124        self.columns.last_mut().expect("BUG: columns is empty after push").nullable();
125    }
126    
127    pub fn build(&self) -> String {
128        let mut defs = vec![];
129        for col in &self.columns {
130            let mut def = format!("{} {}", col.name, col.col_type);
131            if col.is_primary_key {
132                def.push_str(" PRIMARY KEY");
133            }
134            if col.is_auto_increment {
135                def.push_str(" AUTOINCREMENT");
136            }
137            if !col.is_nullable && !col.is_primary_key {
138                def.push_str(" NOT NULL");
139            }
140            if let Some(val) = &col.default_value {
141                def.push_str(&format!(" DEFAULT {}", val));
142            }
143            defs.push(def);
144        }
145        defs.join(",\n    ")
146    }
147}
148
149pub struct Schema;
150
151impl Schema {
152    pub async fn create<F>(table_name: &str, callback: F) -> Result<(), Error>
153    where
154        F: FnOnce(&mut Blueprint),
155    {
156        validate_table_name(table_name)?;
157        
158        let mut blueprint = Blueprint::new();
159        callback(&mut blueprint);
160        
161        let columns_sql = blueprint.build();
162        let sql = format!("CREATE TABLE IF NOT EXISTS {} (\n    {}\n);", table_name, columns_sql);
163        
164        let pool = crate::Orm::pool();
165        let mut query_builder = sqlx::query_builder::QueryBuilder::new("");
166        query_builder.push(&sql);
167        query_builder.build().execute(pool).await?;
168        
169        Ok(())
170    }
171    
172    pub async fn drop_if_exists(table_name: &str) -> Result<(), Error> {
173        validate_table_name(table_name)?;
174        
175        let sql = format!("DROP TABLE IF EXISTS {};", table_name);
176        let pool = crate::Orm::pool();
177        let mut query_builder = sqlx::query_builder::QueryBuilder::new("");
178        query_builder.push(&sql);
179        query_builder.build().execute(pool).await?;
180        Ok(())
181    }
182}
183
/// A schema migration that can be applied (`up`) and reverted (`down`).
///
/// Boxed implementations are passed to the artisan commands in this module
/// (`migrate`, `migrate:rollback`, `status`).
#[async_trait::async_trait]
pub trait Migration: Send + Sync {
    /// Identifier stored in the `migrations` table after `up` succeeds; `migrate` and `status`
    /// compare it against the recorded rows, so it must be unique and stable.
    fn name(&self) -> &'static str;
    /// Applies the migration.
    async fn up(&self) -> Result<(), Error>;
    /// Reverts the migration; `migrate:rollback` calls it for entries in the latest batch.
    async fn down(&self) -> Result<(), Error>;
}
190
191pub async fn run_artisan_with_args(
192    args: &[String],
193    migrations: Vec<Box<dyn Migration>>,
194    seeders: Vec<Box<dyn crate::Seeder>>
195) -> Result<(), Error> {
196    if args.len() < 2 {
197        println!("Rullst ORM Artisan CLI");
198        println!("Usage:");
199        println!("  make:migration <name>   Generate a new migration");
200        println!("  migrate                  Run all pending migrations");
201        println!("  migrate:rollback         Rollback the last batch of migrations");
202        println!("  status                   Show migrations status");
203        println!("  db:seed                  Populate the database with seeders");
204        return Ok(());
205    }
206
207    let command = &args[1];
208    match command.as_str() {
209        "make:migration" => {
210            if args.len() < 3 {
211                println!("Error: migration name is required.");
212                return Ok(());
213            }
214            let name = &args[2];
215            create_migration_files(name)?;
216        }
217        "migrate" | "db:migrate" => {
218            run_migrations(migrations).await?;
219        }
220        "migrate:rollback" | "db:rollback" => {
221            rollback_migrations(migrations).await?;
222        }
223        "status" | "db:status" => {
224            status_migrations(migrations).await?;
225        }
226        "db:seed" => {
227            println!("Seeding database...");
228            crate::Orm::seed(seeders).await?;
229            println!("Database seeded successfully!");
230        }
231        _ => {
232            println!("Unknown command: {}", command);
233        }
234    }
235    Ok(())
236}
237
238pub async fn run_artisan(
239    migrations: Vec<Box<dyn Migration>>,
240    seeders: Vec<Box<dyn crate::Seeder>>
241) -> Result<(), Error> {
242    let args: Vec<String> = std::env::args().collect();
243    run_artisan_with_args(&args, migrations, seeders).await
244}
245
246async fn status_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
247    let pool = crate::Orm::pool();
248    let driver = crate::Orm::driver();
249
250    let table_exists = match driver {
251        "postgres" | "mysql" => {
252            let query_str = "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'";
253            let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
254            row.0 > 0
255        }
256        _ => {
257            let query_str = "SELECT COUNT(*) FROM sqlite_schema WHERE type='table' AND name='migrations'";
258            let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
259            row.0 > 0
260        }
261    };
262
263    let executed_set = if table_exists {
264        let executed: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations")
265            .fetch_all(pool)
266            .await?;
267        executed.into_iter().map(|(m,)| m).collect::<std::collections::HashSet<String>>()
268    } else {
269        std::collections::HashSet::new()
270    };
271
272    let name_header = "Migration Name";
273    let status_header = "Status";
274    println!("{name_header:<40} | {status_header}");
275    println!("{}", "-".repeat(55));
276    for m in migrations {
277        let name = m.name();
278        let status = if executed_set.contains(name) {
279            "Applied"
280        } else {
281            "Pending"
282        };
283        println!("{:<40} | {}", name, status);
284    }
285
286    Ok(())
287}
288
289fn create_migration_files(name: &str) -> Result<(), Error> {
290    validate_table_name(name)?;
291    use std::fs;
292    
293    let now = std::time::SystemTime::now()
294        .duration_since(std::time::UNIX_EPOCH)
295        .expect("System time went backwards")
296        .as_secs()
297        .to_string();
298    let snake_name = name.to_lowercase().replace("-", "_");
299    let file_name = format!("m{}_{}", now, snake_name);
300    
301    fs::create_dir_all("src/migrations").map_err(|e| {
302        Error::Protocol(format!("Failed to create migrations directory: {}", e))
303    })?;
304
305    let new_file_path = format!("src/migrations/{}.rs", file_name);
306    let migration_code = format!(
307r#"use rullst_orm::schema::{{Schema, Blueprint, Migration}};
308use rullst_orm::async_trait;
309
310pub struct MigrationImpl;
311
312#[async_trait]
313impl Migration for MigrationImpl {{
314    fn name(&self) -> &'static str {{
315        "m{timestamp}_{name}"
316    }}
317
318    async fn up(&self) -> Result<(), rullst_orm::sqlx::Error> {{
319        Schema::create("{name}", |table| {{
320            table.id();
321            table.timestamps();
322        }}).await
323    }}
324
325    async fn down(&self) -> Result<(), rullst_orm::sqlx::Error> {{
326        Schema::drop_if_exists("{name}").await
327    }}
328}}
329"#,
330        timestamp = now,
331        name = snake_name
332    );
333
334    fs::write(&new_file_path, migration_code).map_err(|e| {
335        Error::Protocol(format!("Failed to write migration file: {}", e))
336    })?;
337    println!("Created migration file: {}", new_file_path);
338
339    regenerate_migrations_mod()?;
340
341    Ok(())
342}
343
344fn regenerate_migrations_mod() -> Result<(), Error> {
345    use std::fs;
346    let paths = fs::read_dir("src/migrations").map_err(|e| {
347        Error::Protocol(format!("Failed to read migrations dir: {}", e))
348    })?;
349
350    let mut modules = vec![];
351    for path in paths {
352        let path = path.map_err(|e| Error::Protocol(e.to_string()))?.path();
353        if let Some(ext) = path.extension()
354            && ext == "rs"
355                && let Some(stem) = path.file_stem() {
356                    let stem_str = stem.to_string_lossy().to_string();
357                    if stem_str != "mod" && stem_str.starts_with('m') {
358                        modules.push(stem_str);
359                    }
360                }
361    }
362    modules.sort();
363
364    let mut mod_content = String::new();
365    mod_content.push_str("// Generated by Rullst ORM Artisan. Do not edit manually.\n\n");
366    for m in &modules {
367        mod_content.push_str(&format!("pub mod {};\n", m));
368    }
369    mod_content.push_str("\npub fn get_migrations() -> Vec<Box<dyn rullst_orm::schema::Migration>> {\n");
370    mod_content.push_str("    vec![\n");
371    for m in &modules {
372        mod_content.push_str(&format!("        Box::new({}::MigrationImpl),\n", m));
373    }
374    mod_content.push_str("    ]\n");
375    mod_content.push_str("}\n");
376
377    fs::write("src/migrations/mod.rs", mod_content).map_err(|e| {
378        Error::Protocol(format!("Failed to write mod.rs: {}", e))
379    })?;
380    println!("Regenerated src/migrations/mod.rs");
381
382    Ok(())
383}
384
385async fn run_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
386    let pool = crate::Orm::pool();
387    let driver = crate::Orm::driver();
388
389    let query_str = match driver {
390        "postgres" => {
391            "CREATE TABLE IF NOT EXISTS migrations (
392                id SERIAL PRIMARY KEY,
393                migration VARCHAR(255) NOT NULL,
394                batch INTEGER NOT NULL
395            )"
396        }
397        "mysql" => {
398            "CREATE TABLE IF NOT EXISTS migrations (
399                id INT AUTO_INCREMENT PRIMARY KEY,
400                migration VARCHAR(255) NOT NULL,
401                batch INT NOT NULL
402            )"
403        }
404        _ => {
405            "CREATE TABLE IF NOT EXISTS migrations (
406                id INTEGER PRIMARY KEY AUTOINCREMENT,
407                migration TEXT NOT NULL,
408                batch INTEGER NOT NULL
409            )"
410        }
411    };
412
413    sqlx::query(query_str).execute(pool).await?;
414
415    let executed: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations")
416        .fetch_all(pool)
417        .await?;
418    let executed_set: std::collections::HashSet<String> = executed.into_iter().map(|(m,)| m).collect();
419
420    let batch_row: (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
421        .fetch_one(pool)
422        .await?;
423    let next_batch = batch_row.0.unwrap_or(0) + 1;
424
425    let mut count = 0;
426    for m in migrations {
427        let name = m.name();
428        if !executed_set.contains(name) {
429            println!("Migrating: {}", name);
430            m.up().await?;
431            sqlx::query("INSERT INTO migrations (migration, batch) VALUES (?, ?)")
432                .bind(name)
433                .bind(next_batch)
434                .execute(pool)
435                .await?;
436            println!("Migrated:  {}", name);
437            count += 1;
438        }
439    }
440
441    if count == 0 {
442        println!("Nothing to migrate.");
443    }
444
445    Ok(())
446}
447
448async fn rollback_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
449    let pool = crate::Orm::pool();
450    let driver = crate::Orm::driver();
451
452    let table_exists = match driver {
453        "postgres" | "mysql" => {
454            let query_str = "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'";
455            let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
456            row.0 > 0
457        }
458        _ => {
459            let query_str = "SELECT COUNT(*) FROM sqlite_schema WHERE type='table' AND name='migrations'";
460            let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
461            row.0 > 0
462        }
463    };
464
465    if !table_exists {
466        println!("Nothing to rollback.");
467        return Ok(());
468    }
469
470    let batch_row: (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
471        .fetch_one(pool)
472        .await?;
473    
474    let last_batch = match batch_row.0 {
475        Some(b) if b > 0 => b,
476        _ => {
477            println!("Nothing to rollback.");
478            return Ok(());
479        }
480    };
481
482    let to_rollback: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations WHERE batch = ? ORDER BY id DESC")
483        .bind(last_batch)
484        .fetch_all(pool)
485        .await?;
486
487    let mut rollback_map = std::collections::HashMap::new();
488    for m in migrations {
489        rollback_map.insert(m.name().to_string(), m);
490    }
491
492    for (name,) in to_rollback {
493        if let Some(m) = rollback_map.get(&name) {
494            println!("Rolling back: {}", name);
495            m.down().await?;
496            sqlx::query("DELETE FROM migrations WHERE migration = ?")
497                .bind(&name)
498                .execute(pool)
499                .await?;
500            println!("Rolled back:  {}", name);
501        } else {
502            println!("Warning: migration {} found in database but not in compiled binary.", name);
503        }
504    }
505
506    Ok(())
507}
508
509pub struct JoinClause {
510    pub table: String,
511    pub conditions: Vec<String>,
512    pub bindings: Vec<crate::EloquentValue>,
513}
514
515impl JoinClause {
516    pub fn new(table: &str) -> Self {
517        Self {
518            table: table.to_string(),
519            conditions: vec![],
520            bindings: vec![],
521        }
522    }
523
524    /// WARNING: Ensure `first` and `second` do not contain user input to prevent SQL Injection.
525    pub fn on(&mut self, first: &str, operator: &str, second: &str) -> &mut Self {
526        self.conditions.push(format!("{} {} {}", first, operator, second));
527        self
528    }
529
530    pub fn on_eq<T: Into<crate::EloquentValue>>(&mut self, column: &str, value: T) -> &mut Self {
531        self.conditions.push(format!("{} = ?", column));
532        self.bindings.push(value.into());
533        self
534    }
535
536    pub fn to_sql(&self) -> String {
537        self.conditions.join(" AND ")
538    }
539}
540
541pub trait SubqueryBuilder {
542    fn to_sql(&self) -> String;
543    fn bindings(&self) -> &Vec<crate::EloquentValue>;
544}
545
/// Process-wide query-logging switch, off by default. All accesses use `SeqCst` ordering.
pub static QUERY_LOGGING: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);

// Shared writer behind `enable_query_log` / `disable_query_log`.
fn set_query_logging(enabled: bool) {
    use std::sync::atomic::Ordering;
    QUERY_LOGGING.store(enabled, Ordering::SeqCst);
}

/// Turns query logging on.
pub fn enable_query_log() {
    set_query_logging(true);
}

/// Turns query logging off.
pub fn disable_query_log() {
    set_query_logging(false);
}

/// Reports whether query logging is currently on.
pub fn is_query_log_enabled() -> bool {
    use std::sync::atomic::Ordering::SeqCst;
    QUERY_LOGGING.load(SeqCst)
}
559
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_enable_disable_query_log() {
        disable_query_log();
        assert!(!is_query_log_enabled());
        enable_query_log();
        assert!(is_query_log_enabled());
        disable_query_log();
        assert!(!is_query_log_enabled());
    }

    #[test]
    fn test_join_clause() {
        let mut jc = JoinClause::new("users");
        jc.on("users.id", "=", "posts.user_id");
        assert_eq!(jc.to_sql(), "users.id = posts.user_id");
    }

    #[test]
    fn test_validate_table_name() {
        assert!(validate_table_name("users").is_ok());
        assert!(validate_table_name("user_posts").is_ok());
        assert!(validate_table_name("DROP TABLE users").is_err());
        assert!(validate_table_name("../../../etc/shadow").is_err());
    }

    // The empty name and statement separators must never reach the generated DDL.
    #[test]
    fn test_validate_table_name_rejects_empty_and_separators() {
        assert!(validate_table_name("").is_err());
        assert!(validate_table_name("users;--").is_err());
        assert!(validate_table_name("users)").is_err());
    }

    // Pins the exact column-definition text `Schema::create` embeds in `CREATE TABLE`.
    #[test]
    fn test_blueprint_build() {
        let mut table = Blueprint::new();
        table.id();
        table.string("title").not_null();
        table.float("price").default("0");
        table.timestamps();
        table.soft_deletes();
        assert_eq!(
            table.build(),
            "id INTEGER PRIMARY KEY AUTOINCREMENT,\n    title TEXT NOT NULL,\n    price REAL DEFAULT 0,\n    created_at TEXT DEFAULT CURRENT_TIMESTAMP,\n    updated_at TEXT DEFAULT CURRENT_TIMESTAMP,\n    deleted_at TEXT"
        );
    }

    #[test]
    fn test_empty_blueprint_builds_empty_string() {
        assert_eq!(Blueprint::default().build(), "");
    }
}