1use sqlx::Error;
2
3fn validate_table_name(table_name: &str) -> Result<(), Error> {
6 if !table_name
7 .chars()
8 .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
9 {
10 return Err(Error::Protocol(format!(
11 "Invalid table name '{}': only alphanumeric characters, underscores, and hyphens are allowed",
12 table_name
13 )));
14 }
15 if table_name.is_empty() {
16 return Err(Error::Protocol("Table name cannot be empty".to_string()));
17 }
18 Ok(())
19}
20
/// One column definition collected by a `Blueprint`.
///
/// The modifier methods (`not_null`, `nullable`, `default`, `primary`) mutate
/// the column in place and return `&mut Self` so they can be chained.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Column {
    /// Column identifier; rendered verbatim into the generated DDL.
    pub name: String,
    /// SQL type name such as `TEXT`, `INTEGER` or `REAL`.
    pub col_type: String,
    /// Whether NULL is accepted; `true` for columns created by [`Column::new`].
    pub is_nullable: bool,
    /// Whether the column is rendered as the table's primary key.
    pub is_primary_key: bool,
    /// Whether the column is rendered with an auto-increment clause.
    pub is_auto_increment: bool,
    /// Raw SQL default expression (e.g. `CURRENT_TIMESTAMP`), emitted unquoted.
    pub default_value: Option<String>,
}

impl Column {
    /// Creates a nullable, non-key, non-auto-increment column without a default.
    #[must_use]
    pub fn new(name: &str, col_type: &str) -> Self {
        Self {
            name: name.to_string(),
            col_type: col_type.to_string(),
            is_nullable: true,
            is_primary_key: false,
            is_auto_increment: false,
            default_value: None,
        }
    }

    /// Marks the column as `NOT NULL`.
    pub fn not_null(&mut self) -> &mut Self {
        self.is_nullable = false;
        self
    }

    /// Marks the column as accepting NULL (the default for [`Column::new`]).
    pub fn nullable(&mut self) -> &mut Self {
        self.is_nullable = true;
        self
    }

    /// Sets the raw SQL default expression.
    ///
    /// `val` is stored as-is and later emitted without quoting, so string
    /// literals must carry their own quotes (e.g. `"'pending'"`).
    pub fn default(&mut self, val: &str) -> &mut Self {
        self.default_value = Some(val.to_string());
        self
    }

    /// Marks the column as the primary key.
    pub fn primary(&mut self) -> &mut Self {
        self.is_primary_key = true;
        self
    }
}
62
63pub struct Blueprint {
64 pub columns: Vec<Column>,
65}
66
67impl Default for Blueprint {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73impl Blueprint {
74 pub fn new() -> Self {
75 Self { columns: vec![] }
76 }
77
78 pub fn id(&mut self) -> &mut Column {
79 self.columns.push(Column {
80 name: "id".to_string(),
81 col_type: "INTEGER".to_string(),
82 is_nullable: false,
83 is_primary_key: true,
84 is_auto_increment: true,
85 default_value: None,
86 });
87 self.columns
88 .last_mut()
89 .expect("BUG: columns is empty after push")
90 }
91
92 pub fn string(&mut self, name: &str) -> &mut Column {
93 let col = Column::new(name, "TEXT");
94 self.columns.push(col);
95 self.columns
96 .last_mut()
97 .expect("BUG: columns is empty after push")
98 }
99
100 pub fn integer(&mut self, name: &str) -> &mut Column {
101 let col = Column::new(name, "INTEGER");
102 self.columns.push(col);
103 self.columns
104 .last_mut()
105 .expect("BUG: columns is empty after push")
106 }
107
108 pub fn float(&mut self, name: &str) -> &mut Column {
109 let col = Column::new(name, "REAL");
110 self.columns.push(col);
111 self.columns
112 .last_mut()
113 .expect("BUG: columns is empty after push")
114 }
115
116 pub fn boolean(&mut self, name: &str) -> &mut Column {
117 let col = Column::new(name, "INTEGER");
118 self.columns.push(col);
119 self.columns
120 .last_mut()
121 .expect("BUG: columns is empty after push")
122 }
123
124 pub fn timestamps(&mut self) {
125 let mut created = Column::new("created_at", "TEXT");
126 created.default("CURRENT_TIMESTAMP");
127 self.columns.push(created);
128
129 let mut updated = Column::new("updated_at", "TEXT");
130 updated.default("CURRENT_TIMESTAMP");
131 self.columns.push(updated);
132 }
133
134 pub fn soft_deletes(&mut self) {
135 let col = Column::new("deleted_at", "TEXT");
136 self.columns.push(col);
137 self.columns
138 .last_mut()
139 .expect("BUG: columns is empty after push")
140 .nullable();
141 }
142
143 pub fn build(&self) -> String {
144 let mut defs = vec![];
145 for col in &self.columns {
146 let mut def = format!("{} {}", col.name, col.col_type);
147 if col.is_primary_key {
148 def.push_str(" PRIMARY KEY");
149 }
150 if col.is_auto_increment {
151 def.push_str(" AUTOINCREMENT");
152 }
153 if !col.is_nullable && !col.is_primary_key {
154 def.push_str(" NOT NULL");
155 }
156 if let Some(val) = &col.default_value {
157 def.push_str(&format!(" DEFAULT {}", val));
158 }
159 defs.push(def);
160 }
161 defs.join(",\n ")
162 }
163}
164
165pub struct Schema;
166
167impl Schema {
168 pub async fn create<F>(table_name: &str, callback: F) -> Result<(), Error>
169 where
170 F: FnOnce(&mut Blueprint),
171 {
172 validate_table_name(table_name)?;
173
174 let mut blueprint = Blueprint::new();
175 callback(&mut blueprint);
176
177 let columns_sql = blueprint.build();
178 let sql = format!(
179 "CREATE TABLE IF NOT EXISTS {} (\n {}\n);",
180 table_name, columns_sql
181 );
182
183 let pool = crate::Orm::pool();
184 let mut query_builder = sqlx::query_builder::QueryBuilder::new("");
185 query_builder.push(&sql);
186 query_builder.build().execute(pool).await?;
187
188 Ok(())
189 }
190
191 pub async fn drop_if_exists(table_name: &str) -> Result<(), Error> {
192 validate_table_name(table_name)?;
193
194 let sql = format!("DROP TABLE IF EXISTS {};", table_name);
195 let pool = crate::Orm::pool();
196 let mut query_builder = sqlx::query_builder::QueryBuilder::new("");
197 query_builder.push(&sql);
198 query_builder.build().execute(pool).await?;
199 Ok(())
200 }
201}
202
203#[async_trait::async_trait]
204pub trait Migration: Send + Sync {
205 fn name(&self) -> &'static str;
206 async fn up(&self) -> Result<(), Error>;
207 async fn down(&self) -> Result<(), Error>;
208}
209
210pub async fn run_artisan_with_args(
211 args: &[String],
212 migrations: Vec<Box<dyn Migration>>,
213 seeders: Vec<Box<dyn crate::Seeder>>,
214) -> Result<(), Error> {
215 if args.len() < 2 {
216 println!("Rullst ORM Artisan CLI");
217 println!("Usage:");
218 println!(" make:migration <name> Generate a new migration");
219 println!(" migrate Run all pending migrations");
220 println!(" migrate:rollback Rollback the last batch of migrations");
221 println!(" status Show migrations status");
222 println!(" db:seed Populate the database with seeders");
223 return Ok(());
224 }
225
226 let command = &args[1];
227 match command.as_str() {
228 "make:migration" => {
229 if args.len() < 3 {
230 println!("Error: migration name is required.");
231 return Ok(());
232 }
233 let name = &args[2];
234 create_migration_files(name)?;
235 }
236 "migrate" | "db:migrate" => {
237 run_migrations(migrations).await?;
238 }
239 "migrate:rollback" | "db:rollback" => {
240 rollback_migrations(migrations).await?;
241 }
242 "status" | "db:status" => {
243 status_migrations(migrations).await?;
244 }
245 "db:seed" => {
246 println!("Seeding database...");
247 crate::Orm::seed(seeders).await?;
248 println!("Database seeded successfully!");
249 }
250 _ => {
251 println!("Unknown command: {}", command);
252 }
253 }
254 Ok(())
255}
256
257pub async fn run_artisan(
258 migrations: Vec<Box<dyn Migration>>,
259 seeders: Vec<Box<dyn crate::Seeder>>,
260) -> Result<(), Error> {
261 let args: Vec<String> = std::env::args().collect();
262 run_artisan_with_args(&args, migrations, seeders).await
263}
264
265async fn status_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
266 let pool = crate::Orm::pool();
267 let driver = crate::Orm::driver();
268
269 let table_exists = match driver {
270 "postgres" | "mysql" => {
271 let query_str =
272 "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'";
273 let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
274 row.0 > 0
275 }
276 _ => {
277 let query_str =
278 "SELECT COUNT(*) FROM sqlite_schema WHERE type='table' AND name='migrations'";
279 let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
280 row.0 > 0
281 }
282 };
283
284 let executed_set = if table_exists {
285 let executed: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations")
286 .fetch_all(pool)
287 .await?;
288 executed
289 .into_iter()
290 .map(|(m,)| m)
291 .collect::<std::collections::HashSet<String>>()
292 } else {
293 std::collections::HashSet::new()
294 };
295
296 let name_header = "Migration Name";
297 let status_header = "Status";
298 println!("{name_header:<40} | {status_header}");
299 println!("{}", "-".repeat(55));
300 for m in migrations {
301 let name = m.name();
302 let status = if executed_set.contains(name) {
303 "Applied"
304 } else {
305 "Pending"
306 };
307 println!("{:<40} | {}", name, status);
308 }
309
310 Ok(())
311}
312
313fn create_migration_files(name: &str) -> Result<(), Error> {
314 validate_table_name(name)?;
315 use std::fs;
316
317 let now = std::time::SystemTime::now()
318 .duration_since(std::time::UNIX_EPOCH)
319 .expect("System time went backwards")
320 .as_secs()
321 .to_string();
322 let snake_name = name.to_lowercase().replace("-", "_");
323 let file_name = format!("m{}_{}", now, snake_name);
324
325 fs::create_dir_all("src/migrations")
326 .map_err(|e| Error::Protocol(format!("Failed to create migrations directory: {}", e)))?;
327
328 let new_file_path = format!("src/migrations/{}.rs", file_name);
329 let migration_code = format!(
330 r#"use rullst_orm::schema::{{Schema, Blueprint, Migration}};
331use rullst_orm::async_trait;
332
333pub struct MigrationImpl;
334
335#[async_trait]
336impl Migration for MigrationImpl {{
337 fn name(&self) -> &'static str {{
338 "m{timestamp}_{name}"
339 }}
340
341 async fn up(&self) -> Result<(), rullst_orm::sqlx::Error> {{
342 Schema::create("{name}", |table| {{
343 table.id();
344 table.timestamps();
345 }}).await
346 }}
347
348 async fn down(&self) -> Result<(), rullst_orm::sqlx::Error> {{
349 Schema::drop_if_exists("{name}").await
350 }}
351}}
352"#,
353 timestamp = now,
354 name = snake_name
355 );
356
357 fs::write(&new_file_path, migration_code)
358 .map_err(|e| Error::Protocol(format!("Failed to write migration file: {}", e)))?;
359 println!("Created migration file: {}", new_file_path);
360
361 regenerate_migrations_mod()?;
362
363 Ok(())
364}
365
366fn regenerate_migrations_mod() -> Result<(), Error> {
367 use std::fs;
368 let paths = fs::read_dir("src/migrations")
369 .map_err(|e| Error::Protocol(format!("Failed to read migrations dir: {}", e)))?;
370
371 let mut modules = vec![];
372 for path in paths {
373 let path = path.map_err(|e| Error::Protocol(e.to_string()))?.path();
374 if let Some(ext) = path.extension()
375 && ext == "rs"
376 && let Some(stem) = path.file_stem()
377 {
378 let stem_str = stem.to_string_lossy().to_string();
379 if stem_str != "mod" && stem_str.starts_with('m') {
380 modules.push(stem_str);
381 }
382 }
383 }
384 modules.sort();
385
386 let mut mod_content = String::new();
387 mod_content.push_str("// Generated by Rullst ORM Artisan. Do not edit manually.\n\n");
388 for m in &modules {
389 mod_content.push_str(&format!("pub mod {};\n", m));
390 }
391 mod_content
392 .push_str("\npub fn get_migrations() -> Vec<Box<dyn rullst_orm::schema::Migration>> {\n");
393 mod_content.push_str(" vec![\n");
394 for m in &modules {
395 mod_content.push_str(&format!(" Box::new({}::MigrationImpl),\n", m));
396 }
397 mod_content.push_str(" ]\n");
398 mod_content.push_str("}\n");
399
400 fs::write("src/migrations/mod.rs", mod_content)
401 .map_err(|e| Error::Protocol(format!("Failed to write mod.rs: {}", e)))?;
402 println!("Regenerated src/migrations/mod.rs");
403
404 Ok(())
405}
406
407async fn run_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
408 let pool = crate::Orm::pool();
409 let driver = crate::Orm::driver();
410
411 let query_str = match driver {
412 "postgres" => {
413 "CREATE TABLE IF NOT EXISTS migrations (
414 id SERIAL PRIMARY KEY,
415 migration VARCHAR(255) NOT NULL,
416 batch INTEGER NOT NULL
417 )"
418 }
419 "mysql" => {
420 "CREATE TABLE IF NOT EXISTS migrations (
421 id INT AUTO_INCREMENT PRIMARY KEY,
422 migration VARCHAR(255) NOT NULL,
423 batch INT NOT NULL
424 )"
425 }
426 _ => {
427 "CREATE TABLE IF NOT EXISTS migrations (
428 id INTEGER PRIMARY KEY AUTOINCREMENT,
429 migration TEXT NOT NULL,
430 batch INTEGER NOT NULL
431 )"
432 }
433 };
434
435 sqlx::query(query_str).execute(pool).await?;
436
437 let executed: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations")
438 .fetch_all(pool)
439 .await?;
440 let executed_set: std::collections::HashSet<String> =
441 executed.into_iter().map(|(m,)| m).collect();
442
443 let batch_row: (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
444 .fetch_one(pool)
445 .await?;
446 let next_batch = batch_row.0.unwrap_or(0) + 1;
447
448 let mut count = 0;
449 for m in migrations {
450 let name = m.name();
451 if !executed_set.contains(name) {
452 println!("Migrating: {}", name);
453 m.up().await?;
454 sqlx::query("INSERT INTO migrations (migration, batch) VALUES (?, ?)")
455 .bind(name)
456 .bind(next_batch)
457 .execute(pool)
458 .await?;
459 println!("Migrated: {}", name);
460 count += 1;
461 }
462 }
463
464 if count == 0 {
465 println!("Nothing to migrate.");
466 }
467
468 Ok(())
469}
470
471async fn rollback_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
472 let pool = crate::Orm::pool();
473 let driver = crate::Orm::driver();
474
475 let table_exists = match driver {
476 "postgres" | "mysql" => {
477 let query_str =
478 "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'";
479 let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
480 row.0 > 0
481 }
482 _ => {
483 let query_str =
484 "SELECT COUNT(*) FROM sqlite_schema WHERE type='table' AND name='migrations'";
485 let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
486 row.0 > 0
487 }
488 };
489
490 if !table_exists {
491 println!("Nothing to rollback.");
492 return Ok(());
493 }
494
495 let batch_row: (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
496 .fetch_one(pool)
497 .await?;
498
499 let last_batch = match batch_row.0 {
500 Some(b) if b > 0 => b,
501 _ => {
502 println!("Nothing to rollback.");
503 return Ok(());
504 }
505 };
506
507 let to_rollback: Vec<(String,)> =
508 sqlx::query_as("SELECT migration FROM migrations WHERE batch = ? ORDER BY id DESC")
509 .bind(last_batch)
510 .fetch_all(pool)
511 .await?;
512
513 let mut rollback_map = std::collections::HashMap::new();
514 for m in migrations {
515 rollback_map.insert(m.name().to_string(), m);
516 }
517
518 for (name,) in to_rollback {
519 if let Some(m) = rollback_map.get(&name) {
520 println!("Rolling back: {}", name);
521 m.down().await?;
522 sqlx::query("DELETE FROM migrations WHERE migration = ?")
523 .bind(&name)
524 .execute(pool)
525 .await?;
526 println!("Rolled back: {}", name);
527 } else {
528 println!(
529 "Warning: migration {} found in database but not in compiled binary.",
530 name
531 );
532 }
533 }
534
535 Ok(())
536}
537
538pub struct JoinClause {
539 pub table: String,
540 pub conditions: Vec<String>,
541 pub bindings: Vec<crate::RullstValue>,
542}
543
544impl JoinClause {
545 pub fn new(table: &str) -> Self {
546 Self {
547 table: table.to_string(),
548 conditions: vec![],
549 bindings: vec![],
550 }
551 }
552
553 pub fn on(&mut self, first: &str, operator: &str, second: &str) -> &mut Self {
555 self.conditions
556 .push(format!("{} {} {}", first, operator, second));
557 self
558 }
559
560 pub fn on_eq<T: Into<crate::RullstValue>>(&mut self, column: &str, value: T) -> &mut Self {
561 self.conditions.push(format!("{} = ?", column));
562 self.bindings.push(value.into());
563 self
564 }
565
566 pub fn to_sql(&self) -> String {
567 self.conditions.join(" AND ")
568 }
569}
570
571pub trait SubqueryBuilder {
572 fn to_sql(&self) -> String;
573 fn bindings(&self) -> &Vec<crate::RullstValue>;
574}
575
/// Process-wide query-logging switch; starts disabled.
pub static QUERY_LOGGING: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);

/// Stores `enabled` into [`QUERY_LOGGING`] using `SeqCst` ordering.
fn set_query_log(enabled: bool) {
    QUERY_LOGGING.store(enabled, std::sync::atomic::Ordering::SeqCst);
}

/// Turns query logging on.
pub fn enable_query_log() {
    set_query_log(true);
}

/// Turns query logging off.
pub fn disable_query_log() {
    set_query_log(false);
}

/// Reports whether query logging is currently on.
pub fn is_query_log_enabled() -> bool {
    QUERY_LOGGING.load(std::sync::atomic::Ordering::SeqCst)
}
588}
589
590#[cfg(test)]
591mod tests {
592 use super::*;
593
594 #[test]
595 fn test_enable_disable_query_log() {
596 disable_query_log();
597 assert!(!is_query_log_enabled());
598 enable_query_log();
599 assert!(is_query_log_enabled());
600 disable_query_log();
601 assert!(!is_query_log_enabled());
602 }
603
604 #[test]
605 fn test_join_clause() {
606 let mut jc = JoinClause::new("users");
607 jc.on("users.id", "=", "posts.user_id");
608 assert_eq!(jc.to_sql(), "users.id = posts.user_id");
609 }
610
611 #[test]
612 fn test_validate_table_name() {
613 assert!(validate_table_name("users").is_ok());
614 assert!(validate_table_name("user_posts").is_ok());
615 assert!(validate_table_name("DROP TABLE users").is_err());
616 assert!(validate_table_name("../../../etc/shadow").is_err());
617 }
618}