1use crate::Error;
2
/// Comparison operators that `JoinClause::on` will place between two column identifiers.
///
/// Any operator outside this list makes `JoinClause::on` panic, so the operator slot can
/// never carry arbitrary SQL.
const ALLOWED_OPERATORS: &[&str] = &[
    "=", "!=", "<>", // equality / inequality
    "<", ">", "<=", ">=", // ordering
];
5
6pub fn validate_identifier(name: &str) -> Result<(), Error> {
10 if name.is_empty() {
11 return Err(Error::Internal(
12 "SQL identifier cannot be empty".to_string(),
13 ));
14 }
15 let dot_count = name.chars().filter(|&c| c == '.').count();
17 if dot_count > 1 {
18 return Err(Error::Internal(format!(
19 "Invalid SQL identifier '{}': at most one dot is allowed",
20 name
21 )));
22 }
23 if !name
24 .chars()
25 .all(|c| c.is_alphanumeric() || c == '_' || c == '-' || c == '.')
26 {
27 return Err(Error::Internal(format!(
28 "Invalid SQL identifier '{}': only alphanumeric characters, underscores, hyphens and dots are allowed",
29 name
30 )));
31 }
32 Ok(())
33}
34
35fn validate_table_name(table_name: &str) -> Result<(), Error> {
38 if table_name.contains('.') {
39 return Err(Error::Internal(format!(
40 "Invalid table name '{}': dots are not allowed in table names",
41 table_name
42 )));
43 }
44 validate_identifier(table_name)
45}
46
/// One column definition collected by a `Blueprint` and rendered by `Blueprint::build`.
pub struct Column {
    /// Column name; `Blueprint::build` writes it into the SQL verbatim.
    pub name: String,
    /// SQL type keyword, e.g. `TEXT`, `INTEGER` or `REAL`.
    pub col_type: String,
    /// When `false`, `Blueprint::build` emits `NOT NULL` (unless the column is the primary key).
    pub is_nullable: bool,
    /// When `true`, `Blueprint::build` emits `PRIMARY KEY`.
    pub is_primary_key: bool,
    /// When `true`, `Blueprint::build` emits `AUTOINCREMENT`.
    pub is_auto_increment: bool,
    /// Raw SQL written after `DEFAULT`; it is neither quoted nor escaped.
    pub default_value: Option<String>,
}

impl Column {
    /// Creates a nullable, non-key, non-auto-increment column without a default.
    pub fn new(name: &str, col_type: &str) -> Self {
        Self {
            default_value: None,
            is_auto_increment: false,
            is_primary_key: false,
            is_nullable: true,
            col_type: col_type.to_owned(),
            name: name.to_owned(),
        }
    }

    /// Marks the column as `NOT NULL`; returns `self` for chaining.
    pub fn not_null(&mut self) -> &mut Self {
        self.is_nullable = false;
        self
    }

    /// Marks the column as nullable again; returns `self` for chaining.
    pub fn nullable(&mut self) -> &mut Self {
        self.is_nullable = true;
        self
    }

    /// Sets the raw SQL default expression (e.g. `CURRENT_TIMESTAMP` or `'text'`).
    pub fn default(&mut self, val: &str) -> &mut Self {
        self.default_value = Some(String::from(val));
        self
    }

    /// Marks the column as the primary key; returns `self` for chaining.
    pub fn primary(&mut self) -> &mut Self {
        self.is_primary_key = true;
        self
    }
}
88
89pub struct Blueprint {
90 pub columns: Vec<Column>,
91}
92
93impl Default for Blueprint {
94 fn default() -> Self {
95 Self::new()
96 }
97}
98
99impl Blueprint {
100 pub fn new() -> Self {
101 Self { columns: vec![] }
102 }
103
104 pub fn id(&mut self) -> &mut Column {
105 self.columns.push(Column {
106 name: "id".to_string(),
107 col_type: "INTEGER".to_string(),
108 is_nullable: false,
109 is_primary_key: true,
110 is_auto_increment: true,
111 default_value: None,
112 });
113 self.columns
114 .last_mut()
115 .expect("BUG: columns is empty after push")
116 }
117
118 pub fn string(&mut self, name: &str) -> &mut Column {
119 let col = Column::new(name, "TEXT");
120 self.columns.push(col);
121 self.columns
122 .last_mut()
123 .expect("BUG: columns is empty after push")
124 }
125
126 pub fn integer(&mut self, name: &str) -> &mut Column {
127 let col = Column::new(name, "INTEGER");
128 self.columns.push(col);
129 self.columns
130 .last_mut()
131 .expect("BUG: columns is empty after push")
132 }
133
134 pub fn float(&mut self, name: &str) -> &mut Column {
135 let col = Column::new(name, "REAL");
136 self.columns.push(col);
137 self.columns
138 .last_mut()
139 .expect("BUG: columns is empty after push")
140 }
141
142 pub fn boolean(&mut self, name: &str) -> &mut Column {
143 let col = Column::new(name, "INTEGER");
144 self.columns.push(col);
145 self.columns
146 .last_mut()
147 .expect("BUG: columns is empty after push")
148 }
149
150 pub fn timestamps(&mut self) {
151 let mut created = Column::new("created_at", "TEXT");
152 created.default("CURRENT_TIMESTAMP");
153 self.columns.push(created);
154
155 let mut updated = Column::new("updated_at", "TEXT");
156 updated.default("CURRENT_TIMESTAMP");
157 self.columns.push(updated);
158 }
159
160 pub fn soft_deletes(&mut self) {
161 let col = Column::new("deleted_at", "TEXT");
162 self.columns.push(col);
163 self.columns
164 .last_mut()
165 .expect("BUG: columns is empty after push")
166 .nullable();
167 }
168
169 pub fn build(&self) -> String {
170 let mut defs = vec![];
171 for col in &self.columns {
172 let mut def = format!("{} {}", col.name, col.col_type);
173 if col.is_primary_key {
174 def.push_str(" PRIMARY KEY");
175 }
176 if col.is_auto_increment {
177 def.push_str(" AUTOINCREMENT");
178 }
179 if !col.is_nullable && !col.is_primary_key {
180 def.push_str(" NOT NULL");
181 }
182 if let Some(val) = &col.default_value {
183 def.push_str(&format!(" DEFAULT {}", val));
184 }
185 defs.push(def);
186 }
187 defs.join(",\n ")
188 }
189}
190
191pub struct Schema;
192
193impl Schema {
194 pub async fn create<F>(table_name: &str, callback: F) -> Result<(), Error>
195 where
196 F: FnOnce(&mut Blueprint),
197 {
198 validate_table_name(table_name)?;
199
200 let mut blueprint = Blueprint::new();
201 callback(&mut blueprint);
202
203 let columns_sql = blueprint.build();
204 let sql = format!(
205 "CREATE TABLE IF NOT EXISTS {} (\n {}\n);",
206 table_name, columns_sql
207 );
208
209 let pool = crate::Orm::pool();
210 let mut query_builder = sqlx::query_builder::QueryBuilder::new("");
211 query_builder.push(&sql);
212 query_builder.build().execute(pool).await?;
213
214 Ok(())
215 }
216
217 pub async fn drop_if_exists(table_name: &str) -> Result<(), Error> {
218 validate_table_name(table_name)?;
219
220 let sql = format!("DROP TABLE IF EXISTS {};", table_name);
221 let pool = crate::Orm::pool();
222 let mut query_builder = sqlx::query_builder::QueryBuilder::new("");
223 query_builder.push(&sql);
224 query_builder.build().execute(pool).await?;
225 Ok(())
226 }
227}
228
/// A single schema change that the Artisan commands can apply (`migrate`) or revert
/// (`migrate:rollback`).
///
/// Implementations are passed to `run_artisan` / `run_artisan_with_args` as
/// `Vec<Box<dyn Migration>>`.
#[async_trait::async_trait]
pub trait Migration: Send + Sync {
    /// Unique name stored in the `migrations` table after `up` succeeds.
    ///
    /// `run_migrations` skips migrations whose name is already stored there, and
    /// `rollback_migrations` maps stored rows back to implementations by this name.
    fn name(&self) -> &'static str;
    /// Applies the migration; called by `migrate` for migrations not yet recorded.
    async fn up(&self) -> Result<(), Error>;
    /// Reverts the migration; called by `migrate:rollback` for migrations in the most
    /// recent batch, after which its row is deleted from `migrations`.
    async fn down(&self) -> Result<(), Error>;
}
235
236pub async fn run_artisan_with_args(
237 args: &[String],
238 migrations: Vec<Box<dyn Migration>>,
239 seeders: Vec<Box<dyn crate::Seeder>>,
240) -> Result<(), Error> {
241 if args.len() < 2 {
242 println!("Rullst ORM Artisan CLI");
243 println!("Usage:");
244 println!(" make:migration <name> Generate a new migration");
245 println!(" migrate Run all pending migrations");
246 println!(" migrate:rollback Rollback the last batch of migrations");
247 println!(" status Show migrations status");
248 println!(" db:seed Populate the database with seeders");
249 return Ok(());
250 }
251
252 let command = &args[1];
253 match command.as_str() {
254 "make:migration" => {
255 if args.len() < 3 {
256 println!("Error: migration name is required.");
257 return Ok(());
258 }
259 let name = &args[2];
260 create_migration_files(name)?;
261 }
262 "migrate" | "db:migrate" => {
263 run_migrations(migrations).await?;
264 }
265 "migrate:rollback" | "db:rollback" => {
266 rollback_migrations(migrations).await?;
267 }
268 "status" | "db:status" => {
269 status_migrations(migrations).await?;
270 }
271 "db:seed" => {
272 println!("Seeding database...");
273 crate::Orm::seed(seeders).await?;
274 println!("Database seeded successfully!");
275 }
276 _ => {
277 println!("Unknown command: {}", command);
278 }
279 }
280 Ok(())
281}
282
283pub async fn run_artisan(
284 migrations: Vec<Box<dyn Migration>>,
285 seeders: Vec<Box<dyn crate::Seeder>>,
286) -> Result<(), Error> {
287 let args: Vec<String> = std::env::args().collect();
288 run_artisan_with_args(&args, migrations, seeders).await
289}
290
291async fn status_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
292 let pool = crate::Orm::pool();
293 let driver = crate::Orm::driver();
294
295 let table_exists = match driver {
296 "postgres" | "mysql" => {
297 let query_str =
298 "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'";
299 let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
300 row.0 > 0
301 }
302 _ => {
303 let query_str =
304 "SELECT COUNT(*) FROM sqlite_schema WHERE type='table' AND name='migrations'";
305 let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
306 row.0 > 0
307 }
308 };
309
310 let executed_set = if table_exists {
311 let executed: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations")
312 .fetch_all(pool)
313 .await?;
314 executed
315 .into_iter()
316 .map(|(m,)| m)
317 .collect::<std::collections::HashSet<String>>()
318 } else {
319 std::collections::HashSet::new()
320 };
321
322 let name_header = "Migration Name";
323 let status_header = "Status";
324 println!("{name_header:<40} | {status_header}");
325 println!("{}", "-".repeat(55));
326 for m in migrations {
327 let name = m.name();
328 let status = if executed_set.contains(name) {
329 "Applied"
330 } else {
331 "Pending"
332 };
333 println!("{:<40} | {}", name, status);
334 }
335
336 Ok(())
337}
338
339fn create_migration_files(name: &str) -> Result<(), Error> {
340 validate_table_name(name)?;
341 use std::fs;
342
343 let now = std::time::SystemTime::now()
344 .duration_since(std::time::UNIX_EPOCH)
345 .expect("System time went backwards")
346 .as_secs()
347 .to_string();
348 let snake_name = name.to_lowercase().replace("-", "_");
349 let file_name = format!("m{}_{}", now, snake_name);
350
351 fs::create_dir_all("src/migrations")
352 .map_err(|e| Error::Internal(format!("Failed to create migrations directory: {}", e)))?;
353
354 let new_file_path = format!("src/migrations/{}.rs", file_name);
355 let migration_code = format!(
356 r#"use rullst_orm::schema::{{Schema, Blueprint, Migration}};
357use rullst_orm::async_trait;
358
359pub struct MigrationImpl;
360
361#[async_trait]
362impl Migration for MigrationImpl {{
363 fn name(&self) -> &'static str {{
364 "m{timestamp}_{name}"
365 }}
366
367 async fn up(&self) -> Result<(), crate::Error> {{
368 Schema::create("{name}", |table| {{
369 table.id();
370 table.timestamps();
371 }}).await
372 }}
373
374 async fn down(&self) -> Result<(), crate::Error> {{
375 Schema::drop_if_exists("{name}").await
376 }}
377}}
378"#,
379 timestamp = now,
380 name = snake_name
381 );
382
383 fs::write(&new_file_path, migration_code)
384 .map_err(|e| Error::Internal(format!("Failed to write migration file: {}", e)))?;
385 println!("Created migration file: {}", new_file_path);
386
387 regenerate_migrations_mod()?;
388
389 Ok(())
390}
391
392fn regenerate_migrations_mod() -> Result<(), Error> {
393 use std::fs;
394 let paths = fs::read_dir("src/migrations")
395 .map_err(|e| Error::Internal(format!("Failed to read migrations dir: {}", e)))?;
396
397 let mut modules = vec![];
398 for path in paths {
399 let path = path.map_err(|e| Error::Internal(e.to_string()))?.path();
400 if let Some(ext) = path.extension()
401 && ext == "rs"
402 && let Some(stem) = path.file_stem()
403 {
404 let stem_str = stem.to_string_lossy().to_string();
405 if stem_str != "mod" && stem_str.starts_with('m') {
406 modules.push(stem_str);
407 }
408 }
409 }
410 modules.sort();
411
412 let mut mod_content = String::new();
413 mod_content.push_str("// Generated by Rullst ORM Artisan. Do not edit manually.\n\n");
414 for m in &modules {
415 mod_content.push_str(&format!("pub mod {};\n", m));
416 }
417 mod_content
418 .push_str("\npub fn get_migrations() -> Vec<Box<dyn rullst_orm::schema::Migration>> {\n");
419 mod_content.push_str(" vec![\n");
420 for m in &modules {
421 mod_content.push_str(&format!(" Box::new({}::MigrationImpl),\n", m));
422 }
423 mod_content.push_str(" ]\n");
424 mod_content.push_str("}\n");
425
426 fs::write("src/migrations/mod.rs", mod_content)
427 .map_err(|e| Error::Internal(format!("Failed to write mod.rs: {}", e)))?;
428 println!("Regenerated src/migrations/mod.rs");
429
430 Ok(())
431}
432
433async fn run_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
434 let pool = crate::Orm::pool();
435 let driver = crate::Orm::driver();
436
437 let query_str = match driver {
438 "postgres" => {
439 "CREATE TABLE IF NOT EXISTS migrations (
440 id SERIAL PRIMARY KEY,
441 migration VARCHAR(255) NOT NULL,
442 batch INTEGER NOT NULL
443 )"
444 }
445 "mysql" => {
446 "CREATE TABLE IF NOT EXISTS migrations (
447 id INT AUTO_INCREMENT PRIMARY KEY,
448 migration VARCHAR(255) NOT NULL,
449 batch INT NOT NULL
450 )"
451 }
452 _ => {
453 "CREATE TABLE IF NOT EXISTS migrations (
454 id INTEGER PRIMARY KEY AUTOINCREMENT,
455 migration TEXT NOT NULL,
456 batch INTEGER NOT NULL
457 )"
458 }
459 };
460
461 sqlx::query(query_str).execute(pool).await?;
462
463 let executed: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations")
464 .fetch_all(pool)
465 .await?;
466 let executed_set: std::collections::HashSet<String> =
467 executed.into_iter().map(|(m,)| m).collect();
468
469 let batch_row: (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
470 .fetch_one(pool)
471 .await?;
472 let next_batch = batch_row.0.unwrap_or(0) + 1;
473
474 let mut count = 0;
475 for m in migrations {
476 let name = m.name();
477 if !executed_set.contains(name) {
478 println!("Migrating: {}", name);
479 m.up().await?;
480 sqlx::query("INSERT INTO migrations (migration, batch) VALUES (?, ?)")
481 .bind(name)
482 .bind(next_batch)
483 .execute(pool)
484 .await?;
485 println!("Migrated: {}", name);
486 count += 1;
487 }
488 }
489
490 if count == 0 {
491 println!("Nothing to migrate.");
492 }
493
494 Ok(())
495}
496
497async fn rollback_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
498 let pool = crate::Orm::pool();
499 let driver = crate::Orm::driver();
500
501 let table_exists = match driver {
502 "postgres" | "mysql" => {
503 let query_str =
504 "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'";
505 let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
506 row.0 > 0
507 }
508 _ => {
509 let query_str =
510 "SELECT COUNT(*) FROM sqlite_schema WHERE type='table' AND name='migrations'";
511 let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
512 row.0 > 0
513 }
514 };
515
516 if !table_exists {
517 println!("Nothing to rollback.");
518 return Ok(());
519 }
520
521 let batch_row: (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
522 .fetch_one(pool)
523 .await?;
524
525 let last_batch = match batch_row.0 {
526 Some(b) if b > 0 => b,
527 _ => {
528 println!("Nothing to rollback.");
529 return Ok(());
530 }
531 };
532
533 let to_rollback: Vec<(String,)> =
534 sqlx::query_as("SELECT migration FROM migrations WHERE batch = ? ORDER BY id DESC")
535 .bind(last_batch)
536 .fetch_all(pool)
537 .await?;
538
539 let mut rollback_map = std::collections::HashMap::new();
540 for m in migrations {
541 rollback_map.insert(m.name().to_string(), m);
542 }
543
544 for (name,) in to_rollback {
545 if let Some(m) = rollback_map.get(&name) {
546 println!("Rolling back: {}", name);
547 m.down().await?;
548 sqlx::query("DELETE FROM migrations WHERE migration = ?")
549 .bind(&name)
550 .execute(pool)
551 .await?;
552 println!("Rolled back: {}", name);
553 } else {
554 println!(
555 "Warning: migration {} found in database but not in compiled binary.",
556 name
557 );
558 }
559 }
560
561 Ok(())
562}
563
564pub struct JoinClause {
565 pub table: String,
566 pub conditions: Vec<String>,
567 pub bindings: Vec<crate::RullstValue>,
568}
569
570impl JoinClause {
571 pub fn new(table: &str) -> Self {
572 Self {
573 table: table.to_string(),
574 conditions: vec![],
575 bindings: vec![],
576 }
577 }
578
579 pub fn on(&mut self, first: &str, operator: &str, second: &str) -> &mut Self {
588 validate_identifier(first)
589 .unwrap_or_else(|e| panic!("JoinClause::on — invalid identifier for `first`: {}", e));
590 validate_identifier(second)
591 .unwrap_or_else(|e| panic!("JoinClause::on — invalid identifier for `second`: {}", e));
592 if !ALLOWED_OPERATORS.contains(&operator) {
593 panic!(
594 "JoinClause::on — invalid operator '{}'. Allowed: {:?}",
595 operator, ALLOWED_OPERATORS
596 );
597 }
598 self.conditions
599 .push(format!("{} {} {}", first, operator, second));
600 self
601 }
602
603 pub fn on_eq<T: Into<crate::RullstValue>>(&mut self, column: &str, value: T) -> &mut Self {
604 self.conditions.push(format!("{} = ?", column));
605 self.bindings.push(value.into());
606 self
607 }
608
609 pub fn to_sql(&self) -> String {
610 self.conditions.join(" AND ")
611 }
612}
613
/// A query fragment that can be embedded in another query as a subquery.
///
/// NOTE(review): no implementors or call sites are visible in this file. Presumably the
/// outer query splices `to_sql()` into its own SQL and appends `bindings()` to its
/// own bindings. Confirm against the implementations.
pub trait SubqueryBuilder {
    /// SQL text of the subquery.
    fn to_sql(&self) -> String;
    /// Values for the placeholders in the SQL returned by `to_sql` (presumably in
    /// placeholder order; confirm with implementors).
    fn bindings(&self) -> &Vec<crate::RullstValue>;
}
618
/// Process-wide query-logging switch; starts out disabled.
///
/// Prefer the `enable_query_log` / `disable_query_log` / `is_query_log_enabled`
/// helpers, which all access it with `SeqCst` ordering.
/// NOTE(review): the code that reads this flag to log queries is not in this file.
pub static QUERY_LOGGING: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(false);

/// Turns query logging on for the whole process.
pub fn enable_query_log() {
    let flag = &QUERY_LOGGING;
    flag.store(true, std::sync::atomic::Ordering::SeqCst)
}

/// Turns query logging off for the whole process.
pub fn disable_query_log() {
    let flag = &QUERY_LOGGING;
    flag.store(false, std::sync::atomic::Ordering::SeqCst)
}

/// Reports whether query logging is currently on.
pub fn is_query_log_enabled() -> bool {
    let flag = &QUERY_LOGGING;
    flag.load(std::sync::atomic::Ordering::SeqCst)
}
632
// Unit tests for identifier validation, join-clause rendering, blueprint column helpers
// and the global query-logging switch. None of them touch the database.
#[cfg(test)]
mod tests {
    use super::*;

    // QUERY_LOGGING is a process-wide static; this test leaves it disabled on exit.
    #[test]
    fn test_enable_disable_query_log() {
        disable_query_log();
        assert!(!is_query_log_enabled());
        enable_query_log();
        assert!(is_query_log_enabled());
        disable_query_log();
        assert!(!is_query_log_enabled());
    }

    #[test]
    fn test_join_clause() {
        let mut jc = JoinClause::new("users");
        jc.on("users.id", "=", "posts.user_id");
        assert_eq!(jc.to_sql(), "users.id = posts.user_id");
    }

    // Table names reject whitespace, path-like input and any dot.
    #[test]
    fn test_validate_table_name() {
        assert!(validate_table_name("users").is_ok());
        assert!(validate_table_name("user_posts").is_ok());
        assert!(validate_table_name("DROP TABLE users").is_err());
        assert!(validate_table_name("../../../etc/shadow").is_err());
        assert!(validate_table_name("users.id").is_err());
    }

    // Identifiers allow a single `table.column` qualifier but nothing SQL-like.
    #[test]
    fn test_validate_identifier() {
        assert!(validate_identifier("users").is_ok());
        assert!(validate_identifier("users.id").is_ok());
        assert!(validate_identifier("user_posts").is_ok());
        assert!(validate_identifier("").is_err());
        assert!(validate_identifier("users.posts.id").is_err());
        assert!(validate_identifier("DROP TABLE users").is_err());
        assert!(validate_identifier("id; DROP TABLE users--").is_err());
    }

    // `on` panics (rather than returning an error) on a disallowed operator.
    #[test]
    #[should_panic(expected = "invalid operator")]
    fn test_join_clause_on_invalid_operator() {
        let mut jc = JoinClause::new("posts");
        jc.on("posts.user_id", "OR 1=1 --", "users.id");
    }

    #[test]
    #[should_panic(expected = "invalid identifier")]
    fn test_join_clause_on_invalid_column() {
        let mut jc = JoinClause::new("posts");
        jc.on("users.id; DROP TABLE users--", "=", "posts.user_id");
    }

    #[test]
    fn test_timestamps_adds_columns() {
        let mut bp = Blueprint::new();
        bp.timestamps();
        assert_eq!(bp.columns.len(), 2);
        assert_eq!(bp.columns[0].name, "created_at");
        assert_eq!(bp.columns[1].name, "updated_at");
        assert!(bp.columns[0].default_value.is_some());
        assert!(bp.columns[1].default_value.is_some());
    }

    #[test]
    fn test_soft_deletes_adds_nullable_column() {
        let mut bp = Blueprint::new();
        bp.soft_deletes();
        assert_eq!(bp.columns.len(), 1);
        assert_eq!(bp.columns[0].name, "deleted_at");
        assert!(bp.columns[0].is_nullable);
    }

    #[test]
    fn test_blueprint_build_produces_valid_sql() {
        let mut bp = Blueprint::new();
        bp.id();
        bp.string("name").not_null();
        bp.integer("age");
        let sql = bp.build();
        assert!(sql.contains("id INTEGER PRIMARY KEY"));
        assert!(sql.contains("name TEXT NOT NULL"));
        assert!(sql.contains("age INTEGER"));
    }

    // `on_eq` emits a `?` placeholder and stores the value as a binding.
    #[test]
    fn test_join_clause_on_eq_binds_value() {
        let mut jc = JoinClause::new("orders");
        jc.on_eq("orders.user_id", 42i32);
        assert_eq!(jc.to_sql(), "orders.user_id = ?");
        assert_eq!(jc.bindings.len(), 1);
    }

    #[test]
    fn test_join_clause_multiple_conditions() {
        let mut jc = JoinClause::new("posts");
        jc.on("posts.user_id", "=", "users.id");
        jc.on("posts.status", ">", "users.min_status");
        assert_eq!(
            jc.to_sql(),
            "posts.user_id = users.id AND posts.status > users.min_status"
        );
    }
}