1use crate::Error;
2
/// Comparison operators accepted by `JoinClause::on`.
///
/// Any other operator string passed to `JoinClause::on` is recorded as a validation
/// error in `JoinClause::errors`; this list is also printed in that error message.
const ALLOWED_OPERATORS: &[&str] = &["=", "!=", "<>", "<", ">", "<=", ">="];
5
6pub fn validate_identifier(name: &str) -> Result<(), Error> {
10 if name.is_empty() {
11 return Err(Error::Internal(
12 "SQL identifier cannot be empty".to_string(),
13 ));
14 }
15
16 let bytes = name.as_bytes();
17 if bytes[0] == b'.' || bytes[bytes.len() - 1] == b'.' {
18 return Err(Error::Internal(format!(
19 "Invalid SQL identifier '{}': must not start or end with a dot",
20 name
21 )));
22 }
23
24 let mut dot_count = 0;
25 for &b in bytes {
26 if b == b'.' {
27 dot_count += 1;
28 if dot_count > 1 {
29 return Err(Error::Internal(format!(
30 "Invalid SQL identifier '{}': at most one dot is allowed",
31 name
32 )));
33 }
34 } else if !b.is_ascii_alphanumeric() && b != b'_' && b != b'-' {
35 return Err(Error::Internal(format!(
36 "Invalid SQL identifier '{}': only alphanumeric characters, underscores, hyphens and dots are allowed",
37 name
38 )));
39 }
40 }
41
42 Ok(())
43}
44
45pub fn validate_table_name(table_name: &str) -> Result<(), Error> {
47 if table_name.contains('.') {
48 return Err(Error::Internal(format!(
49 "Invalid table name '{}': dots are not allowed in table names",
50 table_name
51 )));
52 }
53 validate_identifier(table_name)
54}
55
/// A value or keyword usable in a column's `DEFAULT` clause.
#[derive(Debug, Clone, PartialEq)]
pub enum ColumnDefault {
    /// The `CURRENT_TIMESTAMP` keyword.
    CurrentTimestamp,
    /// SQL `NULL`.
    Null,
    /// An integer literal.
    Integer(i64),
    /// A floating-point literal, rendered with Rust's `Display` formatting.
    Float(f64),
    /// A string literal, single-quoted with embedded `'` doubled.
    Text(String),
}

impl ColumnDefault {
    /// Renders this default as the SQL text that follows `DEFAULT`.
    ///
    /// Text is wrapped in single quotes, and every embedded `'` is written as `''`.
    ///
    /// NOTE(review): backslashes are passed through unchanged. MySQL without
    /// `NO_BACKSLASH_ESCAPES` treats `\` as an escape inside string literals, so confirm
    /// this is acceptable for MySQL targets. Non-finite floats render as `NaN` / `inf`,
    /// which no supported database accepts as a literal.
    pub fn to_sql(&self) -> String {
        match self {
            Self::CurrentTimestamp => String::from("CURRENT_TIMESTAMP"),
            Self::Null => String::from("NULL"),
            Self::Integer(value) => value.to_string(),
            Self::Float(value) => value.to_string(),
            Self::Text(text) => {
                let mut quoted = String::with_capacity(text.len() + 2);
                quoted.push('\'');
                for ch in text.chars() {
                    if ch == '\'' {
                        quoted.push_str("''");
                    } else {
                        quoted.push(ch);
                    }
                }
                quoted.push('\'');
                quoted
            }
        }
    }
}
89
90pub struct Column {
91 pub name: String,
92 pub col_type: String,
93 pub is_nullable: bool,
94 pub is_primary_key: bool,
95 pub is_auto_increment: bool,
96 pub default_value: Option<ColumnDefault>,
97}
98
99impl Column {
100 pub fn new(name: &str, col_type: &str) -> Self {
107 validate_identifier(name)
108 .unwrap_or_else(|e| panic!("Invalid column name {:?}: {}", name, e));
109 Self {
110 name: name.to_string(),
111 col_type: col_type.to_string(),
112 is_nullable: true,
113 is_primary_key: false,
114 is_auto_increment: false,
115 default_value: None,
116 }
117 }
118
119 pub fn not_null(&mut self) -> &mut Self {
120 self.is_nullable = false;
121 self
122 }
123
124 pub fn nullable(&mut self) -> &mut Self {
125 self.is_nullable = true;
126 self
127 }
128
129 pub fn default(&mut self, val: ColumnDefault) -> &mut Self {
134 self.default_value = Some(val);
135 self
136 }
137
138 pub fn primary(&mut self) -> &mut Self {
139 self.is_primary_key = true;
140 self
141 }
142}
143
144pub struct Blueprint {
145 pub columns: Vec<Column>,
146}
147
148impl Default for Blueprint {
149 fn default() -> Self {
150 Self::new()
151 }
152}
153
154impl Blueprint {
155 pub fn new() -> Self {
156 Self { columns: vec![] }
157 }
158
159 pub fn id(&mut self) -> &mut Column {
160 self.columns.push(Column {
161 name: "id".to_string(),
162 col_type: "INTEGER".to_string(),
163 is_nullable: false,
164 is_primary_key: true,
165 is_auto_increment: true,
166 default_value: None,
167 });
168 self.columns
169 .last_mut()
170 .expect("BUG: columns is empty after push")
171 }
172
173 fn add_column(&mut self, name: &str, col_type: &str) -> &mut Column {
174 let col = Column::new(name, col_type);
175 self.columns.push(col);
176 self.columns
177 .last_mut()
178 .expect("BUG: columns is empty after push")
179 }
180
181 pub fn string(&mut self, name: &str) -> &mut Column {
182 self.add_column(name, "TEXT")
183 }
184
185 pub fn integer(&mut self, name: &str) -> &mut Column {
186 self.add_column(name, "INTEGER")
187 }
188
189 pub fn float(&mut self, name: &str) -> &mut Column {
190 self.add_column(name, "REAL")
191 }
192
193 pub fn boolean(&mut self, name: &str) -> &mut Column {
194 self.add_column(name, "INTEGER")
195 }
196
197 pub fn timestamps(&mut self) {
198 let mut created = Column::new("created_at", "TEXT");
199 created.default(ColumnDefault::CurrentTimestamp);
200 self.columns.push(created);
201
202 let mut updated = Column::new("updated_at", "TEXT");
203 updated.default(ColumnDefault::CurrentTimestamp);
204 self.columns.push(updated);
205 }
206
207 pub fn soft_deletes(&mut self) {
208 let col = Column::new("deleted_at", "TEXT");
209 self.columns.push(col);
210 self.columns
211 .last_mut()
212 .expect("BUG: columns is empty after push")
213 .nullable();
214 }
215
216 #[cfg_attr(test, mutants::skip)]
217 pub fn build(&self) -> Result<String, Error> {
218 let driver = crate::DB_DRIVER
219 .get()
220 .map(|s| s.as_str())
221 .unwrap_or("sqlite");
222 let mut defs = vec![];
223 for col in &self.columns {
224 validate_identifier(&col.name)?;
227
228 let mut col_type_str = col.col_type.clone();
229 if driver == "postgres" && col.is_auto_increment {
230 if col.col_type == "INTEGER" || col.col_type == "INT" {
231 col_type_str = "SERIAL".to_string();
232 } else if col.col_type == "BIGINT" {
233 col_type_str = "BIGSERIAL".to_string();
234 }
235 }
236
237 let mut def = format!("{} {}", col.name, col_type_str);
238 if col.is_primary_key {
239 def.push_str(" PRIMARY KEY");
240 }
241 if col.is_auto_increment {
242 if driver == "sqlite" {
243 def.push_str(" AUTOINCREMENT");
244 } else if driver == "mysql" {
245 def.push_str(" AUTO_INCREMENT");
246 }
247 }
248 if !col.is_nullable && !col.is_primary_key {
249 def.push_str(" NOT NULL");
250 }
251 if let Some(default) = &col.default_value {
252 use std::fmt::Write;
253 write!(def, " DEFAULT {}", default.to_sql()).unwrap();
254 }
255 defs.push(def);
256 }
257 Ok(defs.join(",\n "))
258 }
259}
260
261pub struct Schema;
262
263impl Schema {
264 pub async fn create<F>(table_name: &str, callback: F) -> Result<(), Error>
265 where
266 F: FnOnce(&mut Blueprint),
267 {
268 validate_table_name(table_name)?;
269
270 let mut blueprint = Blueprint::new();
271 callback(&mut blueprint);
272
273 let columns_sql = blueprint.build()?;
276 let sql = format!(
277 "CREATE TABLE IF NOT EXISTS {} (\n {}\n);",
278 table_name, columns_sql
279 );
280
281 let pool = crate::Orm::pool();
282 let mut query_builder = sqlx::query_builder::QueryBuilder::new("");
283 query_builder.push(&sql);
284 query_builder.build().execute(pool).await?;
285
286 Ok(())
287 }
288
289 pub async fn drop_if_exists(table_name: &str) -> Result<(), Error> {
290 validate_table_name(table_name)?;
291
292 let sql = format!("DROP TABLE IF EXISTS {};", table_name);
293 let pool = crate::Orm::pool();
294 let mut query_builder = sqlx::query_builder::QueryBuilder::new("");
295 query_builder.push(&sql);
296 query_builder.build().execute(pool).await?;
297 Ok(())
298 }
299}
300
/// A schema change that the artisan CLI can apply (`up`) and revert (`down`).
///
/// The string returned by [`Migration::name`] is what the `migrate` command stores in the
/// `migration` column of the `migrations` table, and what `status` and `migrate:rollback`
/// match against. It should therefore be unique among the migrations handed to the CLI
/// and must not change once the migration has run.
#[async_trait::async_trait]
pub trait Migration: Send + Sync {
    /// Unique, stable identifier recorded in the `migrations` table.
    fn name(&self) -> &'static str;
    /// Applies the migration; run by `migrate` for names not yet recorded.
    async fn up(&self) -> Result<(), Error>;
    /// Reverts the migration; run by `migrate:rollback` for names in the latest batch.
    async fn down(&self) -> Result<(), Error>;
}
307
308#[cfg_attr(test, mutants::skip)]
309pub async fn run_artisan_with_args(
310 args: &[String],
311 migrations: Vec<Box<dyn Migration>>,
312 seeders: Vec<Box<dyn crate::Seeder>>,
313) -> Result<(), Error> {
314 if args.len() < 2 {
315 println!("Rullst ORM Artisan CLI");
316 println!("Usage:");
317 println!(" make:migration <name> Generate a new migration");
318 println!(" migrate Run all pending migrations");
319 println!(" migrate:rollback Rollback the last batch of migrations");
320 println!(" status Show migrations status");
321 println!(" db:seed Populate the database with seeders");
322 return Ok(());
323 }
324
325 let command = &args[1];
326 match command.as_str() {
327 "make:migration" => {
328 if args.len() < 3 {
329 println!("Error: migration name is required.");
330 return Ok(());
331 }
332 let name = &args[2];
333 create_migration_files(name)?;
334 }
335 "migrate" | "db:migrate" => {
336 run_migrations(migrations).await?;
337 }
338 "migrate:rollback" | "db:rollback" => {
339 rollback_migrations(migrations).await?;
340 }
341 "status" | "db:status" => {
342 status_migrations(migrations).await?;
343 }
344 "db:seed" => {
345 println!("Seeding database...");
346 crate::Orm::seed(seeders).await?;
347 println!("Database seeded successfully!");
348 }
349 _ => {
350 println!("Unknown command: {}", command);
351 }
352 }
353 Ok(())
354}
355
356#[cfg_attr(test, mutants::skip)]
357pub async fn run_artisan(
358 migrations: Vec<Box<dyn Migration>>,
359 seeders: Vec<Box<dyn crate::Seeder>>,
360) -> Result<(), Error> {
361 let args: Vec<String> = std::env::args().collect();
362 run_artisan_with_args(&args, migrations, seeders).await
363}
364
365#[cfg_attr(test, mutants::skip)]
366async fn status_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
367 let pool = crate::Orm::pool();
368 let driver = crate::Orm::driver();
369
370 let table_exists = match driver {
371 "postgres" | "mysql" => {
372 let query_str =
373 "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'";
374 let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
375 row.0 > 0
376 }
377 _ => {
378 let query_str =
379 "SELECT COUNT(*) FROM sqlite_schema WHERE type='table' AND name='migrations'";
380 let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
381 row.0 > 0
382 }
383 };
384
385 let executed_set = if table_exists {
386 let executed: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations")
387 .fetch_all(pool)
388 .await?;
389 executed
390 .into_iter()
391 .map(|(m,)| m)
392 .collect::<std::collections::HashSet<String>>()
393 } else {
394 std::collections::HashSet::new()
395 };
396
397 let name_header = "Migration Name";
398 let status_header = "Status";
399 println!("{name_header:<40} | {status_header}");
400 println!("{}", "-".repeat(55));
401 for m in migrations {
402 let name = m.name();
403 let status = if executed_set.contains(name) {
404 "Applied"
405 } else {
406 "Pending"
407 };
408 println!("{:<40} | {}", name, status);
409 }
410
411 Ok(())
412}
413
414#[cfg_attr(test, mutants::skip)]
415fn create_migration_files(name: &str) -> Result<(), Error> {
416 validate_table_name(name)?;
417 use std::fs;
418
419 let now = std::time::SystemTime::now()
420 .duration_since(std::time::UNIX_EPOCH)
421 .expect("System time went backwards")
422 .as_secs()
423 .to_string();
424 let snake_name = name.to_lowercase().replace("-", "_");
425 let file_name = format!("m{}_{}", now, snake_name);
426
427 fs::create_dir_all("src/migrations")
428 .map_err(|e| Error::Internal(format!("Failed to create migrations directory: {}", e)))?;
429
430 let new_file_path = format!("src/migrations/{}.rs", file_name);
431 let template = include_str!("migration_template.rs.txt");
432 let migration_code = template
433 .replace("{timestamp}", &now)
434 .replace("{name}", &snake_name);
435
436 fs::write(&new_file_path, migration_code)
437 .map_err(|e| Error::Internal(format!("Failed to write migration file: {}", e)))?;
438 println!("Created migration file: {}", new_file_path);
439
440 regenerate_migrations_mod()?;
441
442 Ok(())
443}
444
445#[cfg_attr(test, mutants::skip)]
446fn regenerate_migrations_mod() -> Result<(), Error> {
447 use std::fs;
448 let paths = fs::read_dir("src/migrations")
449 .map_err(|e| Error::Internal(format!("Failed to read migrations dir: {}", e)))?;
450
451 let mut modules = vec![];
452 for path in paths {
453 let path = path.map_err(|e| Error::Internal(e.to_string()))?.path();
454 if let Some(ext) = path.extension()
455 && ext == "rs"
456 && let Some(stem) = path.file_stem()
457 {
458 let stem_str = stem.to_string_lossy().to_string();
459 if stem_str != "mod" && stem_str.starts_with('m') {
460 modules.push(stem_str);
461 }
462 }
463 }
464 modules.sort();
465
466 use std::fmt::Write;
467 let mut mod_content = String::new();
468 mod_content.push_str("// Generated by Rullst ORM Artisan. Do not edit manually.\n\n");
469 for m in &modules {
470 writeln!(mod_content, "pub mod {};", m).unwrap();
471 }
472 mod_content
473 .push_str("\npub fn get_migrations() -> Vec<Box<dyn rullst_orm::schema::Migration>> {\n");
474 mod_content.push_str(" vec![\n");
475 for m in &modules {
476 writeln!(mod_content, " Box::new({}::MigrationImpl),", m).unwrap();
477 }
478 mod_content.push_str(" ]\n");
479 mod_content.push_str("}\n");
480
481 fs::write("src/migrations/mod.rs", mod_content)
482 .map_err(|e| Error::Internal(format!("Failed to write mod.rs: {}", e)))?;
483 println!("Regenerated src/migrations/mod.rs");
484
485 Ok(())
486}
487
488#[cfg_attr(test, mutants::skip)]
489async fn run_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
490 let pool = crate::Orm::pool();
491 let driver = crate::Orm::driver();
492
493 let query_str = match driver {
494 "postgres" => {
495 "CREATE TABLE IF NOT EXISTS migrations (
496 id SERIAL PRIMARY KEY,
497 migration VARCHAR(255) NOT NULL,
498 batch INTEGER NOT NULL
499 )"
500 }
501 "mysql" => {
502 "CREATE TABLE IF NOT EXISTS migrations (
503 id INT AUTO_INCREMENT PRIMARY KEY,
504 migration VARCHAR(255) NOT NULL,
505 batch INT NOT NULL
506 )"
507 }
508 _ => {
509 "CREATE TABLE IF NOT EXISTS migrations (
510 id INTEGER PRIMARY KEY AUTOINCREMENT,
511 migration TEXT NOT NULL,
512 batch INTEGER NOT NULL
513 )"
514 }
515 };
516
517 sqlx::query(query_str).execute(pool).await?;
518
519 let executed: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations")
520 .fetch_all(pool)
521 .await?;
522 let executed_set: std::collections::HashSet<String> =
523 executed.into_iter().map(|(m,)| m).collect();
524
525 let batch_row: (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
526 .fetch_one(pool)
527 .await?;
528 let next_batch = batch_row.0.unwrap_or(0) + 1;
529
530 let mut count = 0;
531 let mut successful_migrations = vec![];
532 for m in migrations {
533 let name = m.name();
534 if !executed_set.contains(name) {
535 println!("Migrating: {}", name);
536 m.up().await?;
537 successful_migrations.push(name);
538 println!("Migrated: {}", name);
539 count += 1;
540 }
541 }
542
543 if count > 0 {
544 let mut query_builder =
545 sqlx::query_builder::QueryBuilder::new("INSERT INTO migrations (migration, batch) ");
546 query_builder.push_values(successful_migrations, |mut b, name| {
547 b.push_bind(name).push_bind(next_batch);
548 });
549 query_builder.build().execute(pool).await?;
550 } else {
551 println!("Nothing to migrate.");
552 }
553
554 Ok(())
555}
556
557#[cfg_attr(test, mutants::skip)]
558async fn rollback_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
559 let pool = crate::Orm::pool();
560 let driver = crate::Orm::driver();
561
562 let table_exists = match driver {
563 "postgres" | "mysql" => {
564 let query_str =
565 "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'";
566 let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
567 row.0 > 0
568 }
569 _ => {
570 let query_str =
571 "SELECT COUNT(*) FROM sqlite_schema WHERE type='table' AND name='migrations'";
572 let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
573 row.0 > 0
574 }
575 };
576
577 if !table_exists {
578 println!("Nothing to rollback.");
579 return Ok(());
580 }
581
582 let batch_row: (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
583 .fetch_one(pool)
584 .await?;
585
586 let last_batch = match batch_row.0 {
587 Some(b) if b > 0 => b,
588 _ => {
589 println!("Nothing to rollback.");
590 return Ok(());
591 }
592 };
593
594 let to_rollback: Vec<(String,)> =
595 sqlx::query_as("SELECT migration FROM migrations WHERE batch = ? ORDER BY id DESC")
596 .bind(last_batch)
597 .fetch_all(pool)
598 .await?;
599
600 let mut rollback_map = std::collections::HashMap::with_capacity(migrations.len());
601 for m in migrations {
602 rollback_map.insert(m.name().to_string(), m);
603 }
604
605 let mut rolled_back = Vec::with_capacity(to_rollback.len());
606 for (name,) in to_rollback {
607 if let Some(m) = rollback_map.get(&name) {
608 println!("Rolling back: {}", name);
609 m.down().await?;
610 println!("Rolled back: {}", name);
611 rolled_back.push(name);
612 } else {
613 println!(
614 "Warning: migration {} found in database but not in compiled binary.",
615 name
616 );
617 }
618 }
619
620 if !rolled_back.is_empty() {
621 let mut query_builder =
622 sqlx::query_builder::QueryBuilder::new("DELETE FROM migrations WHERE migration IN (");
623 let mut separated = query_builder.separated(", ");
624 for name in rolled_back {
625 separated.push_bind(name);
626 }
627 separated.push_unseparated(")");
628 query_builder.build().execute(pool).await?;
629 }
630
631 Ok(())
632}
633
634pub struct JoinClause {
635 pub table: String,
636 pub conditions: Vec<String>,
637 pub bindings: Vec<crate::RullstValue>,
638 pub errors: Vec<crate::Error>,
639}
640
641impl JoinClause {
642 pub fn new(table: &str) -> Self {
643 Self {
644 table: table.to_string(),
645 conditions: vec![],
646 bindings: vec![],
647 errors: vec![],
648 }
649 }
650
651 pub fn on(&mut self, first: &str, operator: &str, second: &str) -> &mut Self {
656 if let Err(e) = validate_identifier(first) {
657 self.errors.push(crate::Error::Validation(format!(
658 "JoinClause::on — invalid identifier for `first`: {:?}",
659 e
660 )));
661 }
662 if let Err(e) = validate_identifier(second) {
663 self.errors.push(crate::Error::Validation(format!(
664 "JoinClause::on — invalid identifier for `second`: {:?}",
665 e
666 )));
667 }
668 if !ALLOWED_OPERATORS.contains(&operator) {
669 self.errors.push(crate::Error::Validation(format!(
670 "JoinClause::on — invalid operator '{}'. Allowed: {:?}",
671 operator, ALLOWED_OPERATORS
672 )));
673 }
674 self.conditions
675 .push(format!("{} {} {}", first, operator, second));
676 self
677 }
678
679 pub fn on_eq<T: Into<crate::RullstValue>>(&mut self, column: &str, value: T) -> &mut Self {
680 if let Err(e) = validate_identifier(column) {
681 self.errors.push(crate::Error::Validation(format!(
682 "JoinClause::on_eq — invalid identifier for `column`: {:?}",
683 e
684 )));
685 }
686 self.conditions.push(format!("{} = ?", column));
687 self.bindings.push(value.into());
688 self
689 }
690
691 pub fn to_sql(&self) -> String {
692 self.conditions.join(" AND ")
693 }
694}
695
/// A query fragment that exposes its SQL text together with the values to bind.
///
/// NOTE(review): judging by the name, this is presumably used to embed one query inside
/// another, with placeholders in `to_sql` matching `bindings()` positionally (the pattern
/// `JoinClause` follows). Confirm against implementors and call sites.
pub trait SubqueryBuilder {
    /// Returns the SQL text of the fragment.
    fn to_sql(&self) -> String;
    /// Returns the values to bind for the placeholders in [`SubqueryBuilder::to_sql`].
    fn bindings(&self) -> &Vec<crate::RullstValue>;
}
700
/// Process-wide query-logging switch.
///
/// Defaults to off; toggled by [`enable_query_log`] / [`disable_query_log`] and read by
/// [`is_query_log_enabled`].
pub static QUERY_LOGGING: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
/// Process-wide maximum query limit (default `1000`).
///
/// A stored `0` is reported as `None` by [`get_max_query_limit`].
/// NOTE(review): presumably applied as a row cap by the query builder; confirm.
pub static MAX_QUERY_LIMIT: std::sync::atomic::AtomicUsize =
    std::sync::atomic::AtomicUsize::new(1000);
/// Process-wide query timeout in whole seconds (default `30`).
///
/// A stored `0` is reported as `None` by [`get_query_timeout`].
pub static QUERY_TIMEOUT_SECS: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(30);

/// Turns query logging on.
pub fn enable_query_log() {
    set_query_log(true);
}

/// Turns query logging off.
pub fn disable_query_log() {
    set_query_log(false);
}

/// Stores the query-logging flag; shared by the enable/disable toggles.
fn set_query_log(enabled: bool) {
    QUERY_LOGGING.store(enabled, std::sync::atomic::Ordering::SeqCst);
}

/// Reports whether query logging is currently on.
pub fn is_query_log_enabled() -> bool {
    QUERY_LOGGING.load(std::sync::atomic::Ordering::SeqCst)
}

/// Sets the maximum query limit; `0` means "no limit".
pub fn set_max_query_limit(limit: usize) {
    MAX_QUERY_LIMIT.store(limit, std::sync::atomic::Ordering::SeqCst);
}

/// Returns the maximum query limit, or `None` when it is set to `0`.
pub fn get_max_query_limit() -> Option<usize> {
    match MAX_QUERY_LIMIT.load(std::sync::atomic::Ordering::SeqCst) {
        0 => None,
        limit => Some(limit),
    }
}

/// Sets the query timeout in seconds; `0` disables the timeout.
pub fn set_query_timeout(secs: u64) {
    QUERY_TIMEOUT_SECS.store(secs, std::sync::atomic::Ordering::SeqCst);
}

/// Returns the query timeout, or `None` when it is set to `0`.
pub fn get_query_timeout() -> Option<std::time::Duration> {
    let secs = QUERY_TIMEOUT_SECS.load(std::sync::atomic::Ordering::SeqCst);
    (secs != 0).then(|| std::time::Duration::from_secs(secs))
}
739
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_enable_disable_query_log() {
        // QUERY_LOGGING is process-wide: restore the prior value so this test does not leak
        // state into other tests running in the same binary.
        let was_enabled = is_query_log_enabled();
        disable_query_log();
        assert!(!is_query_log_enabled());
        enable_query_log();
        assert!(is_query_log_enabled());
        disable_query_log();
        assert!(!is_query_log_enabled());
        if was_enabled {
            enable_query_log();
        }
    }

    #[test]
    fn test_join_clause() {
        let mut jc = JoinClause::new("users");
        jc.on("users.id", "=", "posts.user_id");
        assert_eq!(jc.to_sql(), "users.id = posts.user_id");
    }

    #[test]
    fn test_validate_table_name() {
        assert!(validate_table_name("users").is_ok());
        assert!(validate_table_name("user_posts").is_ok());
        assert!(validate_table_name("DROP TABLE users").is_err());
        assert!(validate_table_name("../../../etc/shadow").is_err());
        assert!(validate_table_name("users.id").is_err());
        assert!(validate_table_name("").is_err());
    }

    #[test]
    fn test_validate_identifier() {
        assert!(validate_identifier("users").is_ok());
        assert!(validate_identifier("users.id").is_ok());
        assert!(validate_identifier("user_posts").is_ok());
        assert!(validate_identifier("").is_err());
        assert!(validate_identifier("users.posts.id").is_err());
        assert!(validate_identifier("DROP TABLE users").is_err());
        assert!(validate_identifier("id; DROP TABLE users--").is_err());
        assert!(validate_identifier(".").is_err());
        assert!(validate_identifier(".users").is_err());
        assert!(validate_identifier("users.").is_err());
        assert!(validate_identifier("user name").is_err());
        assert!(validate_identifier("admin'--").is_err());
        assert!(validate_identifier("users()").is_err());
        assert!(validate_identifier("a*b").is_err());
        assert!(validate_identifier("SELECT * FROM users").is_err());
        assert!(validate_identifier("users\nWHERE").is_err());
        assert!(validate_identifier("users\t").is_err());
        assert!(validate_identifier("\\").is_err());
    }

    #[test]
    fn test_join_clause_on_invalid_operator() {
        let mut jc = JoinClause::new("posts");
        jc.on("posts.user_id", "OR 1=1 --", "users.id");
        assert!(!jc.errors.is_empty());
        assert!(jc.errors[0].to_string().contains("invalid operator"));
    }

    #[test]
    fn test_join_clause_on_invalid_column() {
        let mut jc = JoinClause::new("posts");
        jc.on("users.id; DROP TABLE users--", "=", "posts.user_id");
        assert!(!jc.errors.is_empty());
        assert!(jc.errors[0].to_string().contains("invalid identifier"));
    }

    #[test]
    fn test_timestamps_adds_columns() {
        let mut bp = Blueprint::new();
        bp.timestamps();
        assert_eq!(bp.columns.len(), 2);
        assert_eq!(bp.columns[0].name, "created_at");
        assert_eq!(bp.columns[1].name, "updated_at");
        assert_eq!(
            bp.columns[0].default_value,
            Some(ColumnDefault::CurrentTimestamp)
        );
        assert_eq!(
            bp.columns[1].default_value,
            Some(ColumnDefault::CurrentTimestamp)
        );
    }

    #[test]
    fn test_soft_deletes_adds_nullable_column() {
        let mut bp = Blueprint::new();
        bp.soft_deletes();
        assert_eq!(bp.columns.len(), 1);
        assert_eq!(bp.columns[0].name, "deleted_at");
        assert!(bp.columns[0].is_nullable);
    }

    #[test]
    fn test_blueprint_build_produces_valid_sql() {
        let mut bp = Blueprint::new();
        bp.id();
        bp.string("name").not_null();
        bp.integer("age");
        let sql = bp.build().expect("build should succeed for valid columns");
        assert!(sql.contains("id INTEGER PRIMARY KEY"));
        assert!(sql.contains("name TEXT NOT NULL"));
        assert!(sql.contains("age INTEGER"));
    }

    #[test]
    fn test_column_default_to_sql_escaping() {
        let default_text = ColumnDefault::Text("O'Reilly".to_string());
        assert_eq!(default_text.to_sql(), "'O''Reilly'");
    }

    #[test]
    fn test_validate_identifier_multiple_dots() {
        assert!(validate_identifier("table.column").is_ok());
        assert!(validate_identifier("schema.table.column").is_err());
    }

    #[test]
    fn test_column_default_sql_rendering() {
        assert_eq!(
            ColumnDefault::CurrentTimestamp.to_sql(),
            "CURRENT_TIMESTAMP"
        );
        assert_eq!(ColumnDefault::Null.to_sql(), "NULL");
        assert_eq!(ColumnDefault::Integer(42).to_sql(), "42");
        assert_eq!(ColumnDefault::Float(1.23).to_sql(), "1.23");
        assert_eq!(ColumnDefault::Text("hello".to_string()).to_sql(), "'hello'");
        assert_eq!(ColumnDefault::Text("it's".to_string()).to_sql(), "'it''s'");
    }

    #[test]
    fn test_join_clause_on_eq_binds_value() {
        let mut jc = JoinClause::new("orders");
        jc.on_eq("orders.user_id", 42i32);
        assert_eq!(jc.to_sql(), "orders.user_id = ?");
        assert_eq!(jc.bindings.len(), 1);
    }

    #[test]
    fn test_join_clause_multiple_conditions() {
        let mut jc = JoinClause::new("posts");
        jc.on("posts.user_id", "=", "users.id");
        jc.on("posts.status", ">", "users.min_status");
        assert_eq!(
            jc.to_sql(),
            "posts.user_id = users.id AND posts.status > users.min_status"
        );
    }

    #[test]
    fn test_column_builder_methods() {
        let mut col = Column::new("age", "INTEGER");
        assert_eq!(col.name, "age");
        assert_eq!(col.col_type, "INTEGER");
        assert!(col.is_nullable);
        assert!(!col.is_primary_key);
        assert!(!col.is_auto_increment);
        assert_eq!(col.default_value, None);

        col.not_null();
        assert!(!col.is_nullable);

        col.nullable();
        assert!(col.is_nullable);

        col.primary();
        assert!(col.is_primary_key);

        col.default(ColumnDefault::Integer(18));
        assert_eq!(col.default_value, Some(ColumnDefault::Integer(18)));
    }

    #[tokio::test]
    async fn test_db_migration_error_state_invalid_blueprint() {
        let result = Schema::create("invalid; DROP TABLE users", |bp| {
            bp.id();
        })
        .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_drop_if_exists_invalid_table() {
        let result = Schema::drop_if_exists("invalid; name").await;
        assert!(result.is_err());
        assert!(matches!(result, Err(crate::Error::Internal(_))));
    }

    #[test]
    fn test_max_query_limit_and_timeout_globals() {
        use std::sync::atomic::Ordering;

        // These settings are process-wide: remember and restore them so other tests keep
        // seeing the configured values instead of "no limit" / "no timeout".
        let previous_limit = MAX_QUERY_LIMIT.load(Ordering::SeqCst);
        let previous_timeout = QUERY_TIMEOUT_SECS.load(Ordering::SeqCst);

        set_max_query_limit(50);
        assert_eq!(get_max_query_limit(), Some(50));
        set_max_query_limit(0);
        assert_eq!(get_max_query_limit(), None);

        set_query_timeout(10);
        assert_eq!(
            get_query_timeout(),
            Some(std::time::Duration::from_secs(10))
        );
        set_query_timeout(0);
        assert_eq!(get_query_timeout(), None);

        set_max_query_limit(previous_limit);
        set_query_timeout(previous_timeout);
    }

    #[tokio::test]
    async fn test_run_artisan_entrypoint() {
        // Use fixed arguments instead of `run_artisan`, which reads the test binary's real
        // argv. A test filter such as `status` or `migrate` would otherwise be dispatched as
        // an artisan command and hit the database.
        let usage_only = vec!["rullst".to_string()];
        assert!(run_artisan_with_args(&usage_only, vec![], vec![])
            .await
            .is_ok());

        let unknown_command = vec!["rullst".to_string(), "no-such-command".to_string()];
        assert!(run_artisan_with_args(&unknown_command, vec![], vec![])
            .await
            .is_ok());
    }
}