Skip to main content

rust_eloquent/
schema.rs

1use sqlx::Error;
2
/// A single column definition collected by a [`Blueprint`].
///
/// Columns start out nullable, not part of the primary key, not
/// auto-incrementing and without a default; the chainable setters below
/// adjust those flags before the blueprint renders them to SQL.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Column {
    /// Column name, emitted verbatim (unquoted) into the generated SQL.
    pub name: String,
    /// SQL type name, emitted verbatim (e.g. `TEXT`, `INTEGER`, `REAL`).
    pub col_type: String,
    /// Whether the column accepts `NULL`.
    pub is_nullable: bool,
    /// Whether the column is declared `PRIMARY KEY`.
    pub is_primary_key: bool,
    /// Whether the column is declared `AUTOINCREMENT`.
    pub is_auto_increment: bool,
    /// Raw SQL default expression. It is emitted unquoted, so string
    /// literals must carry their own quotes (e.g. `"'draft'"`).
    pub default_value: Option<String>,
}

impl Column {
    /// Creates a nullable, non-key column with no default value.
    pub fn new(name: &str, col_type: &str) -> Self {
        Self {
            name: name.to_string(),
            col_type: col_type.to_string(),
            is_nullable: true,
            is_primary_key: false,
            is_auto_increment: false,
            default_value: None,
        }
    }

    /// Marks the column as `NOT NULL`.
    pub fn not_null(&mut self) -> &mut Self {
        self.is_nullable = false;
        self
    }

    /// Marks the column as accepting `NULL` (the default for new columns).
    pub fn nullable(&mut self) -> &mut Self {
        self.is_nullable = true;
        self
    }

    /// Sets the raw SQL default expression (emitted unquoted, e.g.
    /// `CURRENT_TIMESTAMP`, `0`, or `'text'` with explicit quotes).
    pub fn default(&mut self, val: &str) -> &mut Self {
        self.default_value = Some(val.to_string());
        self
    }

    /// Marks the column as the table's primary key.
    pub fn primary(&mut self) -> &mut Self {
        self.is_primary_key = true;
        self
    }

    /// Marks the column as auto-incrementing.
    ///
    /// SQLite only accepts `AUTOINCREMENT` on an `INTEGER PRIMARY KEY`
    /// column, so this is normally combined with [`Column::primary`].
    pub fn auto_increment(&mut self) -> &mut Self {
        self.is_auto_increment = true;
        self
    }
}
44
45pub struct Blueprint {
46    pub columns: Vec<Column>,
47}
48
49impl Blueprint {
50    pub fn new() -> Self {
51        Self { columns: vec![] }
52    }
53
54    pub fn id(&mut self) -> &mut Column {
55        self.columns.push(Column {
56            name: "id".to_string(),
57            col_type: "INTEGER".to_string(),
58            is_nullable: false,
59            is_primary_key: true,
60            is_auto_increment: true,
61            default_value: None,
62        });
63        self.columns.last_mut().unwrap()
64    }
65
66    pub fn string(&mut self, name: &str) -> &mut Column {
67        let col = Column::new(name, "TEXT");
68        self.columns.push(col);
69        self.columns.last_mut().unwrap()
70    }
71
72    pub fn integer(&mut self, name: &str) -> &mut Column {
73        let col = Column::new(name, "INTEGER");
74        self.columns.push(col);
75        self.columns.last_mut().unwrap()
76    }
77
78    pub fn float(&mut self, name: &str) -> &mut Column {
79        let col = Column::new(name, "REAL");
80        self.columns.push(col);
81        self.columns.last_mut().unwrap()
82    }
83
84    pub fn boolean(&mut self, name: &str) -> &mut Column {
85        let col = Column::new(name, "INTEGER");
86        self.columns.push(col);
87        self.columns.last_mut().unwrap()
88    }
89
90    pub fn timestamps(&mut self) {
91        let mut created = Column::new("created_at", "TEXT");
92        created.default("CURRENT_TIMESTAMP");
93        self.columns.push(created);
94        
95        let mut updated = Column::new("updated_at", "TEXT");
96        updated.default("CURRENT_TIMESTAMP");
97        self.columns.push(updated);
98    }
99
100    pub fn soft_deletes(&mut self) {
101        let col = Column::new("deleted_at", "TEXT");
102        self.columns.push(col);
103        self.columns.last_mut().unwrap().nullable();
104    }
105    
106    pub fn build(&self) -> String {
107        let mut defs = vec![];
108        for col in &self.columns {
109            let mut def = format!("{} {}", col.name, col.col_type);
110            if col.is_primary_key {
111                def.push_str(" PRIMARY KEY");
112            }
113            if col.is_auto_increment {
114                def.push_str(" AUTOINCREMENT");
115            }
116            if !col.is_nullable && !col.is_primary_key {
117                def.push_str(" NOT NULL");
118            }
119            if let Some(val) = &col.default_value {
120                def.push_str(&format!(" DEFAULT {}", val));
121            }
122            defs.push(def);
123        }
124        defs.join(",\n    ")
125    }
126}
127
128pub struct Schema;
129
130impl Schema {
131    pub async fn create<F>(table_name: &str, callback: F) -> Result<(), Error>
132    where
133        F: FnOnce(&mut Blueprint),
134    {
135        let mut blueprint = Blueprint::new();
136        callback(&mut blueprint);
137        
138        let columns_sql = blueprint.build();
139        let sql = format!("CREATE TABLE IF NOT EXISTS {} (\n    {}\n);", table_name, columns_sql);
140        
141        let pool = crate::Eloquent::pool();
142        sqlx::query(&sql).execute(pool).await?;
143        
144        Ok(())
145    }
146    
147    pub async fn drop_if_exists(table_name: &str) -> Result<(), Error> {
148        let sql = format!("DROP TABLE IF EXISTS {};", table_name);
149        let pool = crate::Eloquent::pool();
150        sqlx::query(&sql).execute(pool).await?;
151        Ok(())
152    }
153}
154
155#[async_trait::async_trait]
156pub trait Migration: Send + Sync {
157    fn name(&self) -> &'static str;
158    async fn up(&self) -> Result<(), Error>;
159    async fn down(&self) -> Result<(), Error>;
160}
161
162pub async fn run_artisan_with_args(
163    args: &[String],
164    migrations: Vec<Box<dyn Migration>>,
165    seeders: Vec<Box<dyn crate::Seeder>>
166) -> Result<(), Error> {
167    if args.len() < 2 {
168        println!("Rust Eloquent Artisan CLI");
169        println!("Usage:");
170        println!("  make:migration <name>   Generate a new migration");
171        println!("  migrate                  Run all pending migrations");
172        println!("  migrate:rollback         Rollback the last batch of migrations");
173        println!("  status                   Show migrations status");
174        println!("  db:seed                  Populate the database with seeders");
175        return Ok(());
176    }
177
178    let command = &args[1];
179    match command.as_str() {
180        "make:migration" => {
181            if args.len() < 3 {
182                println!("Error: migration name is required.");
183                return Ok(());
184            }
185            let name = &args[2];
186            create_migration_files(name)?;
187        }
188        "migrate" | "db:migrate" => {
189            run_migrations(migrations).await?;
190        }
191        "migrate:rollback" | "db:rollback" => {
192            rollback_migrations(migrations).await?;
193        }
194        "status" | "db:status" => {
195            status_migrations(migrations).await?;
196        }
197        "db:seed" => {
198            println!("Seeding database...");
199            crate::Eloquent::seed(seeders).await?;
200            println!("Database seeded successfully!");
201        }
202        _ => {
203            println!("Unknown command: {}", command);
204        }
205    }
206    Ok(())
207}
208
209pub async fn run_artisan(
210    migrations: Vec<Box<dyn Migration>>,
211    seeders: Vec<Box<dyn crate::Seeder>>
212) -> Result<(), Error> {
213    let args: Vec<String> = std::env::args().collect();
214    run_artisan_with_args(&args, migrations, seeders).await
215}
216
217async fn status_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
218    let pool = crate::Eloquent::pool();
219    let driver = crate::Eloquent::driver();
220
221    let table_exists = match driver {
222        "postgres" | "mysql" => {
223            let query_str = "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'";
224            let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await.unwrap_or((0,));
225            row.0 > 0
226        }
227        _ => {
228            let query_str = "SELECT COUNT(*) FROM sqlite_schema WHERE type='table' AND name='migrations'";
229            let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await.unwrap_or((0,));
230            row.0 > 0
231        }
232    };
233
234    let executed_set = if table_exists {
235        let executed: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations")
236            .fetch_all(pool)
237            .await?;
238        executed.into_iter().map(|(m,)| m).collect::<std::collections::HashSet<String>>()
239    } else {
240        std::collections::HashSet::new()
241    };
242
243    println!("{:<40} | {}", "Migration Name", "Status");
244    println!("{}", "-".repeat(55));
245    for m in migrations {
246        let name = m.name();
247        let status = if executed_set.contains(name) {
248            "Applied"
249        } else {
250            "Pending"
251        };
252        println!("{:<40} | {}", name, status);
253    }
254
255    Ok(())
256}
257
258fn create_migration_files(name: &str) -> Result<(), Error> {
259    use std::fs;
260    
261    let now = std::time::SystemTime::now()
262        .duration_since(std::time::UNIX_EPOCH)
263        .unwrap()
264        .as_secs()
265        .to_string();
266    let snake_name = name.to_lowercase().replace("-", "_");
267    let file_name = format!("m{}_{}", now, snake_name);
268    
269    fs::create_dir_all("src/migrations").map_err(|e| {
270        Error::Protocol(format!("Failed to create migrations directory: {}", e))
271    })?;
272
273    let new_file_path = format!("src/migrations/{}.rs", file_name);
274    let migration_code = format!(
275r#"use rust_eloquent::schema::{{Schema, Blueprint, Migration}};
276use rust_eloquent::async_trait;
277
278pub struct MigrationImpl;
279
280#[async_trait]
281impl Migration for MigrationImpl {{
282    fn name(&self) -> &'static str {{
283        "m{timestamp}_{name}"
284    }}
285
286    async fn up(&self) -> Result<(), rust_eloquent::sqlx::Error> {{
287        Schema::create("{name}", |table| {{
288            table.id();
289            // table.string("column_name");
290            table.timestamps();
291        }}).await
292    }}
293
294    async fn down(&self) -> Result<(), rust_eloquent::sqlx::Error> {{
295        Schema::drop_if_exists("{name}").await
296    }}
297}}
298"#,
299        timestamp = now,
300        name = snake_name
301    );
302
303    fs::write(&new_file_path, migration_code).map_err(|e| {
304        Error::Protocol(format!("Failed to write migration file: {}", e))
305    })?;
306    println!("Created migration file: {}", new_file_path);
307
308    regenerate_migrations_mod()?;
309
310    Ok(())
311}
312
313fn regenerate_migrations_mod() -> Result<(), Error> {
314    use std::fs;
315    let paths = fs::read_dir("src/migrations").map_err(|e| {
316        Error::Protocol(format!("Failed to read migrations dir: {}", e))
317    })?;
318
319    let mut modules = vec![];
320    for path in paths {
321        let path = path.map_err(|e| Error::Protocol(e.to_string()))?.path();
322        if let Some(ext) = path.extension() {
323            if ext == "rs" {
324                if let Some(stem) = path.file_stem() {
325                    let stem_str = stem.to_string_lossy().to_string();
326                    if stem_str != "mod" && stem_str.starts_with('m') {
327                        modules.push(stem_str);
328                    }
329                }
330            }
331        }
332    }
333    modules.sort();
334
335    let mut mod_content = String::new();
336    mod_content.push_str("// Generated by Rust Eloquent Artisan. Do not edit manually.\n\n");
337    for m in &modules {
338        mod_content.push_str(&format!("pub mod {};\n", m));
339    }
340    mod_content.push_str("\npub fn get_migrations() -> Vec<Box<dyn rust_eloquent::schema::Migration>> {\n");
341    mod_content.push_str("    vec![\n");
342    for m in &modules {
343        mod_content.push_str(&format!("        Box::new({}::MigrationImpl),\n", m));
344    }
345    mod_content.push_str("    ]\n");
346    mod_content.push_str("}\n");
347
348    fs::write("src/migrations/mod.rs", mod_content).map_err(|e| {
349        Error::Protocol(format!("Failed to write mod.rs: {}", e))
350    })?;
351    println!("Regenerated src/migrations/mod.rs");
352
353    Ok(())
354}
355
356async fn run_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
357    let pool = crate::Eloquent::pool();
358    let driver = crate::Eloquent::driver();
359
360    let query_str = match driver {
361        "postgres" => {
362            "CREATE TABLE IF NOT EXISTS migrations (
363                id SERIAL PRIMARY KEY,
364                migration VARCHAR(255) NOT NULL,
365                batch INTEGER NOT NULL
366            )"
367        }
368        "mysql" => {
369            "CREATE TABLE IF NOT EXISTS migrations (
370                id INT AUTO_INCREMENT PRIMARY KEY,
371                migration VARCHAR(255) NOT NULL,
372                batch INT NOT NULL
373            )"
374        }
375        _ => {
376            "CREATE TABLE IF NOT EXISTS migrations (
377                id INTEGER PRIMARY KEY AUTOINCREMENT,
378                migration TEXT NOT NULL,
379                batch INTEGER NOT NULL
380            )"
381        }
382    };
383
384    sqlx::query(query_str).execute(pool).await?;
385
386    let executed: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations")
387        .fetch_all(pool)
388        .await?;
389    let executed_set: std::collections::HashSet<String> = executed.into_iter().map(|(m,)| m).collect();
390
391    let batch_row: (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
392        .fetch_one(pool)
393        .await?;
394    let next_batch = batch_row.0.unwrap_or(0) + 1;
395
396    let mut count = 0;
397    for m in migrations {
398        let name = m.name();
399        if !executed_set.contains(name) {
400            println!("Migrating: {}", name);
401            m.up().await?;
402            sqlx::query("INSERT INTO migrations (migration, batch) VALUES (?, ?)")
403                .bind(name)
404                .bind(next_batch)
405                .execute(pool)
406                .await?;
407            println!("Migrated:  {}", name);
408            count += 1;
409        }
410    }
411
412    if count == 0 {
413        println!("Nothing to migrate.");
414    }
415
416    Ok(())
417}
418
419async fn rollback_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
420    let pool = crate::Eloquent::pool();
421    let driver = crate::Eloquent::driver();
422
423    let table_exists = match driver {
424        "postgres" | "mysql" => {
425            let query_str = "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'";
426            let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await.unwrap_or((0,));
427            row.0 > 0
428        }
429        _ => {
430            let query_str = "SELECT COUNT(*) FROM sqlite_schema WHERE type='table' AND name='migrations'";
431            let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await.unwrap_or((0,));
432            row.0 > 0
433        }
434    };
435
436    if !table_exists {
437        println!("Nothing to rollback.");
438        return Ok(());
439    }
440
441    let batch_row: (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
442        .fetch_one(pool)
443        .await?;
444    
445    let last_batch = match batch_row.0 {
446        Some(b) if b > 0 => b,
447        _ => {
448            println!("Nothing to rollback.");
449            return Ok(());
450        }
451    };
452
453    let to_rollback: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations WHERE batch = ? ORDER BY id DESC")
454        .bind(last_batch)
455        .fetch_all(pool)
456        .await?;
457
458    let mut rollback_map = std::collections::HashMap::new();
459    for m in migrations {
460        rollback_map.insert(m.name().to_string(), m);
461    }
462
463    for (name,) in to_rollback {
464        if let Some(m) = rollback_map.get(&name) {
465            println!("Rolling back: {}", name);
466            m.down().await?;
467            sqlx::query("DELETE FROM migrations WHERE migration = ?")
468                .bind(&name)
469                .execute(pool)
470                .await?;
471            println!("Rolled back:  {}", name);
472        } else {
473            println!("Warning: migration {} found in database but not in compiled binary.", name);
474        }
475    }
476
477    Ok(())
478}
479
480pub struct JoinClause {
481    pub table: String,
482    pub conditions: Vec<String>,
483    pub bindings: Vec<crate::EloquentValue>,
484}
485
486impl JoinClause {
487    pub fn new(table: &str) -> Self {
488        Self {
489            table: table.to_string(),
490            conditions: vec![],
491            bindings: vec![],
492        }
493    }
494
495    pub fn on(&mut self, first: &str, operator: &str, second: &str) -> &mut Self {
496        self.conditions.push(format!("{} {} {}", first, operator, second));
497        self
498    }
499
500    pub fn on_eq<T: Into<crate::EloquentValue>>(&mut self, column: &str, value: T) -> &mut Self {
501        self.conditions.push(format!("{} = ?", column));
502        self.bindings.push(value.into());
503        self
504    }
505}
506
507pub trait SubqueryBuilder {
508    fn to_sql(&self) -> String;
509    fn bindings(&self) -> &Vec<crate::EloquentValue>;
510}
511
/// Process-wide query-logging flag; starts disabled.
pub static QUERY_LOGGING: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);

/// Sets the global query-logging flag to `true`.
pub fn enable_query_log() {
    set_query_logging(true);
}

/// Sets the global query-logging flag to `false`.
pub fn disable_query_log() {
    set_query_logging(false);
}

/// Returns the current value of the global query-logging flag.
pub fn is_query_log_enabled() -> bool {
    QUERY_LOGGING.load(std::sync::atomic::Ordering::SeqCst)
}

/// Stores `enabled` into [`QUERY_LOGGING`] with sequentially-consistent ordering.
fn set_query_logging(enabled: bool) {
    QUERY_LOGGING.store(enabled, std::sync::atomic::Ordering::SeqCst);
}