1use sqlx::Error;
2
3fn validate_table_name(table_name: &str) -> Result<(), Error> {
6 if !table_name.chars().all(|c| c.is_alphanumeric() || c == '_' || c == '-') {
7 return Err(Error::Protocol(format!(
8 "Invalid table name '{}': only alphanumeric characters, underscores, and hyphens are allowed",
9 table_name
10 )));
11 }
12 if table_name.is_empty() {
13 return Err(Error::Protocol("Table name cannot be empty".to_string()));
14 }
15 Ok(())
16}
17
/// Description of a single table column.
///
/// Instances are plain data; the chained setters below mutate the column in
/// place and return `&mut Self` so calls can be strung together.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Column {
    /// Column name.
    pub name: String,
    /// SQL type keyword for the column (for example `TEXT` or `INTEGER`).
    pub col_type: String,
    /// Whether the column may hold `NULL`.
    pub is_nullable: bool,
    /// Whether the column is the primary key.
    pub is_primary_key: bool,
    /// Whether the column is auto-incrementing.
    pub is_auto_increment: bool,
    /// Optional default expression, stored exactly as supplied.
    pub default_value: Option<String>,
}

impl Column {
    /// Creates a column with the given name and type.
    ///
    /// The new column is nullable, is not a primary key, does not
    /// auto-increment, and has no default value.
    pub fn new(name: &str, col_type: &str) -> Self {
        Self {
            name: name.to_string(),
            col_type: col_type.to_string(),
            is_nullable: true,
            is_primary_key: false,
            is_auto_increment: false,
            default_value: None,
        }
    }

    /// Marks the column as not nullable.
    pub fn not_null(&mut self) -> &mut Self {
        self.is_nullable = false;
        self
    }

    /// Marks the column as nullable.
    pub fn nullable(&mut self) -> &mut Self {
        self.is_nullable = true;
        self
    }

    /// Sets the default expression to `val`, replacing any previous default.
    pub fn default(&mut self, val: &str) -> &mut Self {
        self.default_value = Some(val.to_string());
        self
    }

    /// Marks the column as the primary key.
    pub fn primary(&mut self) -> &mut Self {
        self.is_primary_key = true;
        self
    }
}
59
60pub struct Blueprint {
61 pub columns: Vec<Column>,
62}
63
64impl Default for Blueprint {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70impl Blueprint {
71 pub fn new() -> Self {
72 Self { columns: vec![] }
73 }
74
75 pub fn id(&mut self) -> &mut Column {
76 self.columns.push(Column {
77 name: "id".to_string(),
78 col_type: "INTEGER".to_string(),
79 is_nullable: false,
80 is_primary_key: true,
81 is_auto_increment: true,
82 default_value: None,
83 });
84 self.columns.last_mut().expect("BUG: columns is empty after push")
85 }
86
87 pub fn string(&mut self, name: &str) -> &mut Column {
88 let col = Column::new(name, "TEXT");
89 self.columns.push(col);
90 self.columns.last_mut().expect("BUG: columns is empty after push")
91 }
92
93 pub fn integer(&mut self, name: &str) -> &mut Column {
94 let col = Column::new(name, "INTEGER");
95 self.columns.push(col);
96 self.columns.last_mut().expect("BUG: columns is empty after push")
97 }
98
99 pub fn float(&mut self, name: &str) -> &mut Column {
100 let col = Column::new(name, "REAL");
101 self.columns.push(col);
102 self.columns.last_mut().expect("BUG: columns is empty after push")
103 }
104
105 pub fn boolean(&mut self, name: &str) -> &mut Column {
106 let col = Column::new(name, "INTEGER");
107 self.columns.push(col);
108 self.columns.last_mut().expect("BUG: columns is empty after push")
109 }
110
111 pub fn timestamps(&mut self) {
112 let mut created = Column::new("created_at", "TEXT");
113 created.default("CURRENT_TIMESTAMP");
114 self.columns.push(created);
115
116 let mut updated = Column::new("updated_at", "TEXT");
117 updated.default("CURRENT_TIMESTAMP");
118 self.columns.push(updated);
119 }
120
121 pub fn soft_deletes(&mut self) {
122 let col = Column::new("deleted_at", "TEXT");
123 self.columns.push(col);
124 self.columns.last_mut().expect("BUG: columns is empty after push").nullable();
125 }
126
127 pub fn build(&self) -> String {
128 let mut defs = vec![];
129 for col in &self.columns {
130 let mut def = format!("{} {}", col.name, col.col_type);
131 if col.is_primary_key {
132 def.push_str(" PRIMARY KEY");
133 }
134 if col.is_auto_increment {
135 def.push_str(" AUTOINCREMENT");
136 }
137 if !col.is_nullable && !col.is_primary_key {
138 def.push_str(" NOT NULL");
139 }
140 if let Some(val) = &col.default_value {
141 def.push_str(&format!(" DEFAULT {}", val));
142 }
143 defs.push(def);
144 }
145 defs.join(",\n ")
146 }
147}
148
149pub struct Schema;
150
151impl Schema {
152 pub async fn create<F>(table_name: &str, callback: F) -> Result<(), Error>
153 where
154 F: FnOnce(&mut Blueprint),
155 {
156 validate_table_name(table_name)?;
157
158 let mut blueprint = Blueprint::new();
159 callback(&mut blueprint);
160
161 let columns_sql = blueprint.build();
162 let sql = format!("CREATE TABLE IF NOT EXISTS {} (\n {}\n);", table_name, columns_sql);
163
164 let pool = crate::Eloquent::pool();
165 let mut query_builder = sqlx::query_builder::QueryBuilder::new("");
166 query_builder.push(&sql);
167 query_builder.build().execute(pool).await?;
168
169 Ok(())
170 }
171
172 pub async fn drop_if_exists(table_name: &str) -> Result<(), Error> {
173 validate_table_name(table_name)?;
174
175 let sql = format!("DROP TABLE IF EXISTS {};", table_name);
176 let pool = crate::Eloquent::pool();
177 let mut query_builder = sqlx::query_builder::QueryBuilder::new("");
178 query_builder.push(&sql);
179 query_builder.build().execute(pool).await?;
180 Ok(())
181 }
182}
183
/// A named, reversible schema change.
///
/// The migration runner in this file calls `up` when applying a migration and
/// `down` when rolling it back, and records `name` in the `migrations` table
/// to track which migrations have run.
#[async_trait::async_trait]
pub trait Migration: Send + Sync {
    /// Identifier stored in, and looked up from, the `migrations` table.
    fn name(&self) -> &'static str;
    /// Applies the migration.
    async fn up(&self) -> Result<(), Error>;
    /// Reverts the migration.
    async fn down(&self) -> Result<(), Error>;
}
190
191pub async fn run_artisan_with_args(
192 args: &[String],
193 migrations: Vec<Box<dyn Migration>>,
194 seeders: Vec<Box<dyn crate::Seeder>>
195) -> Result<(), Error> {
196 if args.len() < 2 {
197 println!("Rust Eloquent Artisan CLI");
198 println!("Usage:");
199 println!(" make:migration <name> Generate a new migration");
200 println!(" migrate Run all pending migrations");
201 println!(" migrate:rollback Rollback the last batch of migrations");
202 println!(" status Show migrations status");
203 println!(" db:seed Populate the database with seeders");
204 return Ok(());
205 }
206
207 let command = &args[1];
208 match command.as_str() {
209 "make:migration" => {
210 if args.len() < 3 {
211 println!("Error: migration name is required.");
212 return Ok(());
213 }
214 let name = &args[2];
215 create_migration_files(name)?;
216 }
217 "migrate" | "db:migrate" => {
218 run_migrations(migrations).await?;
219 }
220 "migrate:rollback" | "db:rollback" => {
221 rollback_migrations(migrations).await?;
222 }
223 "status" | "db:status" => {
224 status_migrations(migrations).await?;
225 }
226 "db:seed" => {
227 println!("Seeding database...");
228 crate::Eloquent::seed(seeders).await?;
229 println!("Database seeded successfully!");
230 }
231 _ => {
232 println!("Unknown command: {}", command);
233 }
234 }
235 Ok(())
236}
237
238pub async fn run_artisan(
239 migrations: Vec<Box<dyn Migration>>,
240 seeders: Vec<Box<dyn crate::Seeder>>
241) -> Result<(), Error> {
242 let args: Vec<String> = std::env::args().collect();
243 run_artisan_with_args(&args, migrations, seeders).await
244}
245
246async fn status_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
247 let pool = crate::Eloquent::pool();
248 let driver = crate::Eloquent::driver();
249
250 let table_exists = match driver {
251 "postgres" | "mysql" => {
252 let query_str = "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'";
253 let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
254 row.0 > 0
255 }
256 _ => {
257 let query_str = "SELECT COUNT(*) FROM sqlite_schema WHERE type='table' AND name='migrations'";
258 let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
259 row.0 > 0
260 }
261 };
262
263 let executed_set = if table_exists {
264 let executed: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations")
265 .fetch_all(pool)
266 .await?;
267 executed.into_iter().map(|(m,)| m).collect::<std::collections::HashSet<String>>()
268 } else {
269 std::collections::HashSet::new()
270 };
271
272 let name_header = "Migration Name";
273 let status_header = "Status";
274 println!("{name_header:<40} | {status_header}");
275 println!("{}", "-".repeat(55));
276 for m in migrations {
277 let name = m.name();
278 let status = if executed_set.contains(name) {
279 "Applied"
280 } else {
281 "Pending"
282 };
283 println!("{:<40} | {}", name, status);
284 }
285
286 Ok(())
287}
288
289fn create_migration_files(name: &str) -> Result<(), Error> {
290 validate_table_name(name)?;
291 use std::fs;
292
293 let now = std::time::SystemTime::now()
294 .duration_since(std::time::UNIX_EPOCH)
295 .expect("System time went backwards")
296 .as_secs()
297 .to_string();
298 let snake_name = name.to_lowercase().replace("-", "_");
299 let file_name = format!("m{}_{}", now, snake_name);
300
301 fs::create_dir_all("src/migrations").map_err(|e| {
302 Error::Protocol(format!("Failed to create migrations directory: {}", e))
303 })?;
304
305 let new_file_path = format!("src/migrations/{}.rs", file_name);
306 let migration_code = format!(
307r#"use rullst_orm::schema::{{Schema, Blueprint, Migration}};
308use rullst_orm::async_trait;
309
310pub struct MigrationImpl;
311
312#[async_trait]
313impl Migration for MigrationImpl {{
314 fn name(&self) -> &'static str {{
315 "m{timestamp}_{name}"
316 }}
317
318 async fn up(&self) -> Result<(), rullst_orm::sqlx::Error> {{
319 Schema::create("{name}", |table| {{
320 table.id();
321 table.timestamps();
322 }}).await
323 }}
324
325 async fn down(&self) -> Result<(), rullst_orm::sqlx::Error> {{
326 Schema::drop_if_exists("{name}").await
327 }}
328}}
329"#,
330 timestamp = now,
331 name = snake_name
332 );
333
334 fs::write(&new_file_path, migration_code).map_err(|e| {
335 Error::Protocol(format!("Failed to write migration file: {}", e))
336 })?;
337 println!("Created migration file: {}", new_file_path);
338
339 regenerate_migrations_mod()?;
340
341 Ok(())
342}
343
344fn regenerate_migrations_mod() -> Result<(), Error> {
345 use std::fs;
346 let paths = fs::read_dir("src/migrations").map_err(|e| {
347 Error::Protocol(format!("Failed to read migrations dir: {}", e))
348 })?;
349
350 let mut modules = vec![];
351 for path in paths {
352 let path = path.map_err(|e| Error::Protocol(e.to_string()))?.path();
353 if let Some(ext) = path.extension()
354 && ext == "rs"
355 && let Some(stem) = path.file_stem() {
356 let stem_str = stem.to_string_lossy().to_string();
357 if stem_str != "mod" && stem_str.starts_with('m') {
358 modules.push(stem_str);
359 }
360 }
361 }
362 modules.sort();
363
364 let mut mod_content = String::new();
365 mod_content.push_str("// Generated by Rust Eloquent Artisan. Do not edit manually.\n\n");
366 for m in &modules {
367 mod_content.push_str(&format!("pub mod {};\n", m));
368 }
369 mod_content.push_str("\npub fn get_migrations() -> Vec<Box<dyn rullst_orm::schema::Migration>> {\n");
370 mod_content.push_str(" vec![\n");
371 for m in &modules {
372 mod_content.push_str(&format!(" Box::new({}::MigrationImpl),\n", m));
373 }
374 mod_content.push_str(" ]\n");
375 mod_content.push_str("}\n");
376
377 fs::write("src/migrations/mod.rs", mod_content).map_err(|e| {
378 Error::Protocol(format!("Failed to write mod.rs: {}", e))
379 })?;
380 println!("Regenerated src/migrations/mod.rs");
381
382 Ok(())
383}
384
385async fn run_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
386 let pool = crate::Eloquent::pool();
387 let driver = crate::Eloquent::driver();
388
389 let query_str = match driver {
390 "postgres" => {
391 "CREATE TABLE IF NOT EXISTS migrations (
392 id SERIAL PRIMARY KEY,
393 migration VARCHAR(255) NOT NULL,
394 batch INTEGER NOT NULL
395 )"
396 }
397 "mysql" => {
398 "CREATE TABLE IF NOT EXISTS migrations (
399 id INT AUTO_INCREMENT PRIMARY KEY,
400 migration VARCHAR(255) NOT NULL,
401 batch INT NOT NULL
402 )"
403 }
404 _ => {
405 "CREATE TABLE IF NOT EXISTS migrations (
406 id INTEGER PRIMARY KEY AUTOINCREMENT,
407 migration TEXT NOT NULL,
408 batch INTEGER NOT NULL
409 )"
410 }
411 };
412
413 sqlx::query(query_str).execute(pool).await?;
414
415 let executed: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations")
416 .fetch_all(pool)
417 .await?;
418 let executed_set: std::collections::HashSet<String> = executed.into_iter().map(|(m,)| m).collect();
419
420 let batch_row: (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
421 .fetch_one(pool)
422 .await?;
423 let next_batch = batch_row.0.unwrap_or(0) + 1;
424
425 let mut count = 0;
426 for m in migrations {
427 let name = m.name();
428 if !executed_set.contains(name) {
429 println!("Migrating: {}", name);
430 m.up().await?;
431 sqlx::query("INSERT INTO migrations (migration, batch) VALUES (?, ?)")
432 .bind(name)
433 .bind(next_batch)
434 .execute(pool)
435 .await?;
436 println!("Migrated: {}", name);
437 count += 1;
438 }
439 }
440
441 if count == 0 {
442 println!("Nothing to migrate.");
443 }
444
445 Ok(())
446}
447
448async fn rollback_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
449 let pool = crate::Eloquent::pool();
450 let driver = crate::Eloquent::driver();
451
452 let table_exists = match driver {
453 "postgres" | "mysql" => {
454 let query_str = "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'";
455 let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
456 row.0 > 0
457 }
458 _ => {
459 let query_str = "SELECT COUNT(*) FROM sqlite_schema WHERE type='table' AND name='migrations'";
460 let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
461 row.0 > 0
462 }
463 };
464
465 if !table_exists {
466 println!("Nothing to rollback.");
467 return Ok(());
468 }
469
470 let batch_row: (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
471 .fetch_one(pool)
472 .await?;
473
474 let last_batch = match batch_row.0 {
475 Some(b) if b > 0 => b,
476 _ => {
477 println!("Nothing to rollback.");
478 return Ok(());
479 }
480 };
481
482 let to_rollback: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations WHERE batch = ? ORDER BY id DESC")
483 .bind(last_batch)
484 .fetch_all(pool)
485 .await?;
486
487 let mut rollback_map = std::collections::HashMap::new();
488 for m in migrations {
489 rollback_map.insert(m.name().to_string(), m);
490 }
491
492 for (name,) in to_rollback {
493 if let Some(m) = rollback_map.get(&name) {
494 println!("Rolling back: {}", name);
495 m.down().await?;
496 sqlx::query("DELETE FROM migrations WHERE migration = ?")
497 .bind(&name)
498 .execute(pool)
499 .await?;
500 println!("Rolled back: {}", name);
501 } else {
502 println!("Warning: migration {} found in database but not in compiled binary.", name);
503 }
504 }
505
506 Ok(())
507}
508
509pub struct JoinClause {
510 pub table: String,
511 pub conditions: Vec<String>,
512 pub bindings: Vec<crate::EloquentValue>,
513}
514
515impl JoinClause {
516 pub fn new(table: &str) -> Self {
517 Self {
518 table: table.to_string(),
519 conditions: vec![],
520 bindings: vec![],
521 }
522 }
523
524 pub fn on(&mut self, first: &str, operator: &str, second: &str) -> &mut Self {
526 self.conditions.push(format!("{} {} {}", first, operator, second));
527 self
528 }
529
530 pub fn on_eq<T: Into<crate::EloquentValue>>(&mut self, column: &str, value: T) -> &mut Self {
531 self.conditions.push(format!("{} = ?", column));
532 self.bindings.push(value.into());
533 self
534 }
535
536 pub fn to_sql(&self) -> String {
537 self.conditions.join(" AND ")
538 }
539}
540
/// Interface for a value that can be rendered as SQL text together with a
/// list of bound values.
///
/// NOTE(review): no implementor is visible in this file; presumably the
/// entries of `bindings()` correspond, in order, to placeholders in the
/// string returned by `to_sql()` — confirm against the implementors.
pub trait SubqueryBuilder {
    /// Returns the SQL text.
    fn to_sql(&self) -> String;
    /// Returns the bound values.
    fn bindings(&self) -> &Vec<crate::EloquentValue>;
}
545
/// Process-wide flag recording whether query logging is switched on.
/// Starts out `false`.
pub static QUERY_LOGGING: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);

/// Switches query logging on.
pub fn enable_query_log() {
    use std::sync::atomic::Ordering::SeqCst;

    QUERY_LOGGING.store(true, SeqCst);
}

/// Switches query logging off.
pub fn disable_query_log() {
    use std::sync::atomic::Ordering::SeqCst;

    QUERY_LOGGING.store(false, SeqCst);
}

/// Reports whether query logging is currently switched on.
pub fn is_query_log_enabled() -> bool {
    use std::sync::atomic::Ordering::SeqCst;

    QUERY_LOGGING.load(SeqCst)
}
559
#[cfg(test)]
mod tests {
    use super::*;

    /// The global flag follows enable/disable calls.
    #[test]
    fn test_enable_disable_query_log() {
        disable_query_log();
        assert!(!is_query_log_enabled());
        enable_query_log();
        assert!(is_query_log_enabled());
        disable_query_log();
        assert!(!is_query_log_enabled());
    }

    /// Conditions are rendered verbatim and joined with ` AND `.
    #[test]
    fn test_join_clause() {
        let mut jc = JoinClause::new("users");
        assert_eq!(jc.to_sql(), "");

        jc.on("users.id", "=", "posts.user_id");
        assert_eq!(jc.to_sql(), "users.id = posts.user_id");

        jc.on("posts.published", "=", "1");
        assert_eq!(
            jc.to_sql(),
            "users.id = posts.user_id AND posts.published = 1"
        );
        // `on` never records bindings; only `on_eq` does.
        assert!(jc.bindings.is_empty());
    }

    /// Valid names pass; empty names and names with other characters fail.
    #[test]
    fn test_validate_table_name() {
        assert!(validate_table_name("users").is_ok());
        assert!(validate_table_name("user_posts").is_ok());
        assert!(validate_table_name("user-posts").is_ok());
        assert!(validate_table_name("").is_err());
        assert!(validate_table_name("DROP TABLE users").is_err());
        assert!(validate_table_name("users; DROP TABLE users").is_err());
        assert!(validate_table_name("../../../etc/shadow").is_err());
    }

    /// Column definitions are rendered in insertion order with their constraints.
    #[test]
    fn test_blueprint_build() {
        let mut blueprint = Blueprint::new();
        assert_eq!(blueprint.build(), "");

        blueprint.id();
        blueprint.string("name").not_null();
        blueprint.integer("age").default("0");

        let sql = blueprint.build();
        // Compare per definition so the check does not depend on the exact
        // whitespace used between definitions.
        let defs: Vec<&str> = sql.split(',').map(str::trim).collect();
        assert_eq!(
            defs,
            [
                "id INTEGER PRIMARY KEY AUTOINCREMENT",
                "name TEXT NOT NULL",
                "age INTEGER DEFAULT 0",
            ]
        );
    }
}