1use crate::Error;
2
/// Comparison operators accepted by [`JoinClause::on`]; any other operator
/// string makes `on` record a validation error.
const ALLOWED_OPERATORS: &[&str] = &["=", "!=", "<>", "<", ">", "<=", ">="];
5
6pub fn validate_identifier(name: &str) -> Result<(), Error> {
10 if name.is_empty() {
11 return Err(Error::Internal(
12 "SQL identifier cannot be empty".to_string(),
13 ));
14 }
15
16 let bytes = name.as_bytes();
17 if bytes[0] == b'.' || bytes[bytes.len() - 1] == b'.' {
18 return Err(Error::Internal(format!(
19 "Invalid SQL identifier '{}': must not start or end with a dot",
20 name
21 )));
22 }
23
24 let mut dot_count = 0;
25 for &b in bytes {
26 if b == b'.' {
27 dot_count += 1;
28 if dot_count > 1 {
29 return Err(Error::Internal(format!(
30 "Invalid SQL identifier '{}': at most one dot is allowed",
31 name
32 )));
33 }
34 } else if !b.is_ascii_alphanumeric() && b != b'_' && b != b'-' {
35 return Err(Error::Internal(format!(
36 "Invalid SQL identifier '{}': only alphanumeric characters, underscores, hyphens and dots are allowed",
37 name
38 )));
39 }
40 }
41
42 Ok(())
43}
44
45pub fn validate_table_name(table_name: &str) -> Result<(), Error> {
47 if table_name.contains('.') {
48 return Err(Error::Internal(format!(
49 "Invalid table name '{}': dots are not allowed in table names",
50 table_name
51 )));
52 }
53 validate_identifier(table_name)
54}
55
/// A default value for a column, rendered into the `DEFAULT ...` part of a
/// column definition by [`ColumnDefault::to_sql`].
#[derive(Debug, Clone, PartialEq)]
pub enum ColumnDefault {
    /// The SQL keyword `CURRENT_TIMESTAMP`.
    CurrentTimestamp,
    /// The SQL keyword `NULL`.
    Null,
    /// An integer literal.
    Integer(i64),
    /// A floating-point literal, rendered with Rust's `Display` formatting.
    Float(f64),
    /// A string literal; embedded single quotes are doubled when rendered.
    Text(String),
}

impl ColumnDefault {
    /// Renders this default as an SQL keyword or literal.
    ///
    /// `Text` is wrapped in single quotes with every `'` doubled
    /// (`O'Reilly` becomes `'O''Reilly'`); no other escaping is applied.
    /// NOTE(review): non-finite `Float` values render as `NaN` / `inf`, which
    /// are not numeric SQL literals — avoid them as column defaults.
    pub fn to_sql(&self) -> String {
        match self {
            Self::CurrentTimestamp => String::from("CURRENT_TIMESTAMP"),
            Self::Null => String::from("NULL"),
            Self::Integer(value) => value.to_string(),
            Self::Float(value) => value.to_string(),
            Self::Text(text) => {
                let escaped = text.replace('\'', "''");
                format!("'{escaped}'")
            }
        }
    }
}
89
90pub struct Column {
91 pub name: String,
92 pub col_type: String,
93 pub is_nullable: bool,
94 pub is_primary_key: bool,
95 pub is_auto_increment: bool,
96 pub default_value: Option<ColumnDefault>,
97}
98
99impl Column {
100 pub fn new(name: &str, col_type: &str) -> Self {
107 validate_identifier(name)
108 .unwrap_or_else(|e| panic!("Invalid column name {:?}: {}", name, e));
109 Self {
110 name: name.to_string(),
111 col_type: col_type.to_string(),
112 is_nullable: true,
113 is_primary_key: false,
114 is_auto_increment: false,
115 default_value: None,
116 }
117 }
118
119 pub fn not_null(&mut self) -> &mut Self {
120 self.is_nullable = false;
121 self
122 }
123
124 pub fn nullable(&mut self) -> &mut Self {
125 self.is_nullable = true;
126 self
127 }
128
129 pub fn default(&mut self, val: ColumnDefault) -> &mut Self {
134 self.default_value = Some(val);
135 self
136 }
137
138 pub fn primary(&mut self) -> &mut Self {
139 self.is_primary_key = true;
140 self
141 }
142}
143
144pub struct Blueprint {
145 pub columns: Vec<Column>,
146}
147
148impl Default for Blueprint {
149 fn default() -> Self {
150 Self::new()
151 }
152}
153
154impl Blueprint {
155 pub fn new() -> Self {
156 Self { columns: vec![] }
157 }
158
159 pub fn id(&mut self) -> &mut Column {
160 self.columns.push(Column {
161 name: "id".to_string(),
162 col_type: "INTEGER".to_string(),
163 is_nullable: false,
164 is_primary_key: true,
165 is_auto_increment: true,
166 default_value: None,
167 });
168 self.columns
169 .last_mut()
170 .expect("BUG: columns is empty after push")
171 }
172
173 fn add_column(&mut self, name: &str, col_type: &str) -> &mut Column {
174 let col = Column::new(name, col_type);
175 self.columns.push(col);
176 self.columns
177 .last_mut()
178 .expect("BUG: columns is empty after push")
179 }
180
181 pub fn string(&mut self, name: &str) -> &mut Column {
182 self.add_column(name, "TEXT")
183 }
184
185 pub fn integer(&mut self, name: &str) -> &mut Column {
186 self.add_column(name, "INTEGER")
187 }
188
189 pub fn float(&mut self, name: &str) -> &mut Column {
190 self.add_column(name, "REAL")
191 }
192
193 pub fn boolean(&mut self, name: &str) -> &mut Column {
194 self.add_column(name, "INTEGER")
195 }
196
197 pub fn timestamps(&mut self) {
198 let mut created = Column::new("created_at", "TEXT");
199 created.default(ColumnDefault::CurrentTimestamp);
200 self.columns.push(created);
201
202 let mut updated = Column::new("updated_at", "TEXT");
203 updated.default(ColumnDefault::CurrentTimestamp);
204 self.columns.push(updated);
205 }
206
207 pub fn soft_deletes(&mut self) {
208 let col = Column::new("deleted_at", "TEXT");
209 self.columns.push(col);
210 self.columns
211 .last_mut()
212 .expect("BUG: columns is empty after push")
213 .nullable();
214 }
215
216 pub fn build(&self) -> Result<String, Error> {
217 let mut defs = vec![];
218 for col in &self.columns {
219 validate_identifier(&col.name)?;
222 let mut def = format!("{} {}", col.name, col.col_type);
223 if col.is_primary_key {
224 def.push_str(" PRIMARY KEY");
225 }
226 if col.is_auto_increment {
227 def.push_str(" AUTOINCREMENT");
228 }
229 if !col.is_nullable && !col.is_primary_key {
230 def.push_str(" NOT NULL");
231 }
232 if let Some(default) = &col.default_value {
233 use std::fmt::Write;
234 write!(def, " DEFAULT {}", default.to_sql()).unwrap();
235 }
236 defs.push(def);
237 }
238 Ok(defs.join(",\n "))
239 }
240}
241
242pub struct Schema;
243
244impl Schema {
245 pub async fn create<F>(table_name: &str, callback: F) -> Result<(), Error>
246 where
247 F: FnOnce(&mut Blueprint),
248 {
249 validate_table_name(table_name)?;
250
251 let mut blueprint = Blueprint::new();
252 callback(&mut blueprint);
253
254 let columns_sql = blueprint.build()?;
257 let sql = format!(
258 "CREATE TABLE IF NOT EXISTS {} (\n {}\n);",
259 table_name, columns_sql
260 );
261
262 let pool = crate::Orm::pool();
263 let mut query_builder = sqlx::query_builder::QueryBuilder::new("");
264 query_builder.push(&sql);
265 query_builder.build().execute(pool).await?;
266
267 Ok(())
268 }
269
270 pub async fn drop_if_exists(table_name: &str) -> Result<(), Error> {
271 validate_table_name(table_name)?;
272
273 let sql = format!("DROP TABLE IF EXISTS {};", table_name);
274 let pool = crate::Orm::pool();
275 let mut query_builder = sqlx::query_builder::QueryBuilder::new("");
276 query_builder.push(&sql);
277 query_builder.build().execute(pool).await?;
278 Ok(())
279 }
280}
281
/// A schema migration that the Artisan CLI can apply (`up`) and revert (`down`).
#[async_trait::async_trait]
pub trait Migration: Send + Sync {
    /// Unique name of the migration. It is compared against, and stored in,
    /// the `migration` column of the `migrations` table to track whether the
    /// migration has run.
    fn name(&self) -> &'static str;
    /// Applies the migration; invoked by `migrate` for names not yet recorded.
    async fn up(&self) -> Result<(), Error>;
    /// Reverts the migration; invoked by `migrate:rollback` for the last batch.
    async fn down(&self) -> Result<(), Error>;
}
288
289pub async fn run_artisan_with_args(
290 args: &[String],
291 migrations: Vec<Box<dyn Migration>>,
292 seeders: Vec<Box<dyn crate::Seeder>>,
293) -> Result<(), Error> {
294 if args.len() < 2 {
295 println!("Rullst ORM Artisan CLI");
296 println!("Usage:");
297 println!(" make:migration <name> Generate a new migration");
298 println!(" migrate Run all pending migrations");
299 println!(" migrate:rollback Rollback the last batch of migrations");
300 println!(" status Show migrations status");
301 println!(" db:seed Populate the database with seeders");
302 return Ok(());
303 }
304
305 let command = &args[1];
306 match command.as_str() {
307 "make:migration" => {
308 if args.len() < 3 {
309 println!("Error: migration name is required.");
310 return Ok(());
311 }
312 let name = &args[2];
313 create_migration_files(name)?;
314 }
315 "migrate" | "db:migrate" => {
316 run_migrations(migrations).await?;
317 }
318 "migrate:rollback" | "db:rollback" => {
319 rollback_migrations(migrations).await?;
320 }
321 "status" | "db:status" => {
322 status_migrations(migrations).await?;
323 }
324 "db:seed" => {
325 println!("Seeding database...");
326 crate::Orm::seed(seeders).await?;
327 println!("Database seeded successfully!");
328 }
329 _ => {
330 println!("Unknown command: {}", command);
331 }
332 }
333 Ok(())
334}
335
336pub async fn run_artisan(
337 migrations: Vec<Box<dyn Migration>>,
338 seeders: Vec<Box<dyn crate::Seeder>>,
339) -> Result<(), Error> {
340 let args: Vec<String> = std::env::args().collect();
341 run_artisan_with_args(&args, migrations, seeders).await
342}
343
344async fn status_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
345 let pool = crate::Orm::pool();
346 let driver = crate::Orm::driver();
347
348 let table_exists = match driver {
349 "postgres" | "mysql" => {
350 let query_str =
351 "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'";
352 let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
353 row.0 > 0
354 }
355 _ => {
356 let query_str =
357 "SELECT COUNT(*) FROM sqlite_schema WHERE type='table' AND name='migrations'";
358 let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
359 row.0 > 0
360 }
361 };
362
363 let executed_set = if table_exists {
364 let executed: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations")
365 .fetch_all(pool)
366 .await?;
367 executed
368 .into_iter()
369 .map(|(m,)| m)
370 .collect::<std::collections::HashSet<String>>()
371 } else {
372 std::collections::HashSet::new()
373 };
374
375 let name_header = "Migration Name";
376 let status_header = "Status";
377 println!("{name_header:<40} | {status_header}");
378 println!("{}", "-".repeat(55));
379 for m in migrations {
380 let name = m.name();
381 let status = if executed_set.contains(name) {
382 "Applied"
383 } else {
384 "Pending"
385 };
386 println!("{:<40} | {}", name, status);
387 }
388
389 Ok(())
390}
391
392fn create_migration_files(name: &str) -> Result<(), Error> {
393 validate_table_name(name)?;
394 use std::fs;
395
396 let now = std::time::SystemTime::now()
397 .duration_since(std::time::UNIX_EPOCH)
398 .expect("System time went backwards")
399 .as_secs()
400 .to_string();
401 let snake_name = name.to_lowercase().replace("-", "_");
402 let file_name = format!("m{}_{}", now, snake_name);
403
404 fs::create_dir_all("src/migrations")
405 .map_err(|e| Error::Internal(format!("Failed to create migrations directory: {}", e)))?;
406
407 let new_file_path = format!("src/migrations/{}.rs", file_name);
408 let migration_code = format!(
409 r#"use rullst_orm::schema::{{Schema, Blueprint, Migration}};
410use rullst_orm::async_trait;
411
412pub struct MigrationImpl;
413
414#[async_trait]
415impl Migration for MigrationImpl {{
416 fn name(&self) -> &'static str {{
417 "m{timestamp}_{name}"
418 }}
419
420 async fn up(&self) -> Result<(), crate::Error> {{
421 Schema::create("{name}", |table| {{
422 table.id();
423 table.timestamps();
424 }}).await
425 }}
426
427 async fn down(&self) -> Result<(), crate::Error> {{
428 Schema::drop_if_exists("{name}").await
429 }}
430}}
431"#,
432 timestamp = now,
433 name = snake_name
434 );
435
436 fs::write(&new_file_path, migration_code)
437 .map_err(|e| Error::Internal(format!("Failed to write migration file: {}", e)))?;
438 println!("Created migration file: {}", new_file_path);
439
440 regenerate_migrations_mod()?;
441
442 Ok(())
443}
444
445fn regenerate_migrations_mod() -> Result<(), Error> {
446 use std::fs;
447 let paths = fs::read_dir("src/migrations")
448 .map_err(|e| Error::Internal(format!("Failed to read migrations dir: {}", e)))?;
449
450 let mut modules = vec![];
451 for path in paths {
452 let path = path.map_err(|e| Error::Internal(e.to_string()))?.path();
453 if let Some(ext) = path.extension()
454 && ext == "rs"
455 && let Some(stem) = path.file_stem()
456 {
457 let stem_str = stem.to_string_lossy().to_string();
458 if stem_str != "mod" && stem_str.starts_with('m') {
459 modules.push(stem_str);
460 }
461 }
462 }
463 modules.sort();
464
465 use std::fmt::Write;
466 let mut mod_content = String::new();
467 mod_content.push_str("// Generated by Rullst ORM Artisan. Do not edit manually.\n\n");
468 for m in &modules {
469 writeln!(mod_content, "pub mod {};", m).unwrap();
470 }
471 mod_content
472 .push_str("\npub fn get_migrations() -> Vec<Box<dyn rullst_orm::schema::Migration>> {\n");
473 mod_content.push_str(" vec![\n");
474 for m in &modules {
475 writeln!(mod_content, " Box::new({}::MigrationImpl),", m).unwrap();
476 }
477 mod_content.push_str(" ]\n");
478 mod_content.push_str("}\n");
479
480 fs::write("src/migrations/mod.rs", mod_content)
481 .map_err(|e| Error::Internal(format!("Failed to write mod.rs: {}", e)))?;
482 println!("Regenerated src/migrations/mod.rs");
483
484 Ok(())
485}
486
487async fn run_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
488 let pool = crate::Orm::pool();
489 let driver = crate::Orm::driver();
490
491 let query_str = match driver {
492 "postgres" => {
493 "CREATE TABLE IF NOT EXISTS migrations (
494 id SERIAL PRIMARY KEY,
495 migration VARCHAR(255) NOT NULL,
496 batch INTEGER NOT NULL
497 )"
498 }
499 "mysql" => {
500 "CREATE TABLE IF NOT EXISTS migrations (
501 id INT AUTO_INCREMENT PRIMARY KEY,
502 migration VARCHAR(255) NOT NULL,
503 batch INT NOT NULL
504 )"
505 }
506 _ => {
507 "CREATE TABLE IF NOT EXISTS migrations (
508 id INTEGER PRIMARY KEY AUTOINCREMENT,
509 migration TEXT NOT NULL,
510 batch INTEGER NOT NULL
511 )"
512 }
513 };
514
515 sqlx::query(query_str).execute(pool).await?;
516
517 let executed: Vec<(String,)> = sqlx::query_as("SELECT migration FROM migrations")
518 .fetch_all(pool)
519 .await?;
520 let executed_set: std::collections::HashSet<String> =
521 executed.into_iter().map(|(m,)| m).collect();
522
523 let batch_row: (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
524 .fetch_one(pool)
525 .await?;
526 let next_batch = batch_row.0.unwrap_or(0) + 1;
527
528 let mut count = 0;
529 let mut successful_migrations = vec![];
530 for m in migrations {
531 let name = m.name();
532 if !executed_set.contains(name) {
533 println!("Migrating: {}", name);
534 m.up().await?;
535 successful_migrations.push(name);
536 println!("Migrated: {}", name);
537 count += 1;
538 }
539 }
540
541 if count > 0 {
542 let mut query_builder =
543 sqlx::query_builder::QueryBuilder::new("INSERT INTO migrations (migration, batch) ");
544 query_builder.push_values(successful_migrations, |mut b, name| {
545 b.push_bind(name).push_bind(next_batch);
546 });
547 query_builder.build().execute(pool).await?;
548 } else {
549 println!("Nothing to migrate.");
550 }
551
552 Ok(())
553}
554
555async fn rollback_migrations(migrations: Vec<Box<dyn Migration>>) -> Result<(), Error> {
556 let pool = crate::Orm::pool();
557 let driver = crate::Orm::driver();
558
559 let table_exists = match driver {
560 "postgres" | "mysql" => {
561 let query_str =
562 "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'";
563 let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
564 row.0 > 0
565 }
566 _ => {
567 let query_str =
568 "SELECT COUNT(*) FROM sqlite_schema WHERE type='table' AND name='migrations'";
569 let row: (i64,) = sqlx::query_as(query_str).fetch_one(pool).await?;
570 row.0 > 0
571 }
572 };
573
574 if !table_exists {
575 println!("Nothing to rollback.");
576 return Ok(());
577 }
578
579 let batch_row: (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
580 .fetch_one(pool)
581 .await?;
582
583 let last_batch = match batch_row.0 {
584 Some(b) if b > 0 => b,
585 _ => {
586 println!("Nothing to rollback.");
587 return Ok(());
588 }
589 };
590
591 let to_rollback: Vec<(String,)> =
592 sqlx::query_as("SELECT migration FROM migrations WHERE batch = ? ORDER BY id DESC")
593 .bind(last_batch)
594 .fetch_all(pool)
595 .await?;
596
597 let mut rollback_map = std::collections::HashMap::with_capacity(migrations.len());
598 for m in migrations {
599 rollback_map.insert(m.name().to_string(), m);
600 }
601
602 for (name,) in to_rollback {
603 if let Some(m) = rollback_map.get(&name) {
604 println!("Rolling back: {}", name);
605 m.down().await?;
606 sqlx::query("DELETE FROM migrations WHERE migration = ?")
607 .bind(&name)
608 .execute(pool)
609 .await?;
610 println!("Rolled back: {}", name);
611 } else {
612 println!(
613 "Warning: migration {} found in database but not in compiled binary.",
614 name
615 );
616 }
617 }
618
619 Ok(())
620}
621
622pub struct JoinClause {
623 pub table: String,
624 pub conditions: Vec<String>,
625 pub bindings: Vec<crate::RullstValue>,
626 pub errors: Vec<crate::Error>,
627}
628
629impl JoinClause {
630 pub fn new(table: &str) -> Self {
631 Self {
632 table: table.to_string(),
633 conditions: vec![],
634 bindings: vec![],
635 errors: vec![],
636 }
637 }
638
639 pub fn on(&mut self, first: &str, operator: &str, second: &str) -> &mut Self {
644 if let Err(e) = validate_identifier(first) {
645 self.errors.push(crate::Error::Validation(format!(
646 "JoinClause::on — invalid identifier for `first`: {:?}",
647 e
648 )));
649 }
650 if let Err(e) = validate_identifier(second) {
651 self.errors.push(crate::Error::Validation(format!(
652 "JoinClause::on — invalid identifier for `second`: {:?}",
653 e
654 )));
655 }
656 if !ALLOWED_OPERATORS.contains(&operator) {
657 self.errors.push(crate::Error::Validation(format!(
658 "JoinClause::on — invalid operator '{}'. Allowed: {:?}",
659 operator, ALLOWED_OPERATORS
660 )));
661 }
662 self.conditions
663 .push(format!("{} {} {}", first, operator, second));
664 self
665 }
666
667 pub fn on_eq<T: Into<crate::RullstValue>>(&mut self, column: &str, value: T) -> &mut Self {
668 if let Err(e) = validate_identifier(column) {
669 self.errors.push(crate::Error::Validation(format!(
670 "JoinClause::on_eq — invalid identifier for `column`: {:?}",
671 e
672 )));
673 }
674 self.conditions.push(format!("{} = ?", column));
675 self.bindings.push(value.into());
676 self
677 }
678
679 pub fn to_sql(&self) -> String {
680 self.conditions.join(" AND ")
681 }
682}
683
/// Something that can render itself as an SQL subquery together with the
/// values bound to it.
pub trait SubqueryBuilder {
    /// Returns the subquery's SQL text.
    /// NOTE(review): presumably contains placeholders that correspond to
    /// `bindings()` — confirm against implementors.
    fn to_sql(&self) -> String;
    /// Returns the values bound to the subquery.
    /// NOTE(review): order is assumed to match placeholder order in `to_sql`.
    fn bindings(&self) -> &Vec<crate::RullstValue>;
}
688
/// Global query-logging switch; off by default.
pub static QUERY_LOGGING: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
/// Global maximum query limit (default 1000); `0` means "no limit", in which
/// case [`get_max_query_limit`] returns `None`.
pub static MAX_QUERY_LIMIT: std::sync::atomic::AtomicUsize =
    std::sync::atomic::AtomicUsize::new(1000);
/// Global query timeout in seconds (default 30); `0` disables it, in which
/// case [`get_query_timeout`] returns `None`.
pub static QUERY_TIMEOUT_SECS: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(30);

/// Shared writer behind `enable_query_log` / `disable_query_log`.
fn store_query_log_flag(enabled: bool) {
    QUERY_LOGGING.store(enabled, std::sync::atomic::Ordering::SeqCst);
}

/// Turns query logging on.
pub fn enable_query_log() {
    store_query_log_flag(true);
}

/// Turns query logging off.
pub fn disable_query_log() {
    store_query_log_flag(false);
}

/// Reports whether query logging is currently on.
pub fn is_query_log_enabled() -> bool {
    QUERY_LOGGING.load(std::sync::atomic::Ordering::SeqCst)
}

/// Sets the maximum query limit; `0` removes the limit.
pub fn set_max_query_limit(limit: usize) {
    MAX_QUERY_LIMIT.store(limit, std::sync::atomic::Ordering::SeqCst);
}

/// Returns the maximum query limit, or `None` when it is set to `0`.
pub fn get_max_query_limit() -> Option<usize> {
    Some(MAX_QUERY_LIMIT.load(std::sync::atomic::Ordering::SeqCst)).filter(|&limit| limit != 0)
}

/// Sets the query timeout in seconds; `0` disables the timeout.
pub fn set_query_timeout(secs: u64) {
    QUERY_TIMEOUT_SECS.store(secs, std::sync::atomic::Ordering::SeqCst);
}

/// Returns the query timeout, or `None` when it is set to `0`.
pub fn get_query_timeout() -> Option<std::time::Duration> {
    match QUERY_TIMEOUT_SECS.load(std::sync::atomic::Ordering::SeqCst) {
        0 => None,
        secs => Some(std::time::Duration::from_secs(secs)),
    }
}
727
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_enable_disable_query_log() {
        let steps: [(fn(), bool); 3] = [
            (disable_query_log, false),
            (enable_query_log, true),
            (disable_query_log, false),
        ];
        for (toggle, expected) in steps {
            toggle();
            assert_eq!(is_query_log_enabled(), expected);
        }
    }

    #[test]
    fn test_join_clause() {
        let mut join = JoinClause::new("users");
        join.on("users.id", "=", "posts.user_id");
        assert_eq!(join.to_sql(), "users.id = posts.user_id");
    }

    #[test]
    fn test_validate_table_name() {
        for accepted in ["users", "user_posts"] {
            assert!(
                validate_table_name(accepted).is_ok(),
                "expected {accepted:?} to be accepted"
            );
        }
        for rejected in ["DROP TABLE users", "../../../etc/shadow", "users.id", ""] {
            assert!(
                validate_table_name(rejected).is_err(),
                "expected {rejected:?} to be rejected"
            );
        }
    }

    #[test]
    fn test_validate_identifier() {
        for accepted in ["users", "users.id", "user_posts"] {
            assert!(
                validate_identifier(accepted).is_ok(),
                "expected {accepted:?} to be accepted"
            );
        }
        let rejected = [
            "",
            "users.posts.id",
            "DROP TABLE users",
            "id; DROP TABLE users--",
            ".",
            ".users",
            "users.",
            "user name",
            "admin'--",
            "users()",
            "a*b",
            "SELECT * FROM users",
            "users\nWHERE",
            "users\t",
            "\\",
        ];
        for candidate in rejected {
            assert!(
                validate_identifier(candidate).is_err(),
                "expected {candidate:?} to be rejected"
            );
        }
    }

    #[test]
    fn test_join_clause_on_invalid_operator() {
        let mut join = JoinClause::new("posts");
        join.on("posts.user_id", "OR 1=1 --", "users.id");
        let first_error = join.errors.first().expect("an error should be recorded");
        assert!(first_error.to_string().contains("invalid operator"));
    }

    #[test]
    fn test_join_clause_on_invalid_column() {
        let mut join = JoinClause::new("posts");
        join.on("users.id; DROP TABLE users--", "=", "posts.user_id");
        let first_error = join.errors.first().expect("an error should be recorded");
        assert!(first_error.to_string().contains("invalid identifier"));
    }

    #[test]
    fn test_timestamps_adds_columns() {
        let mut bp = Blueprint::new();
        bp.timestamps();
        let names: Vec<&str> = bp.columns.iter().map(|c| c.name.as_str()).collect();
        assert_eq!(names, ["created_at", "updated_at"]);
        for column in &bp.columns {
            assert_eq!(
                column.default_value,
                Some(ColumnDefault::CurrentTimestamp)
            );
        }
    }

    #[test]
    fn test_soft_deletes_adds_nullable_column() {
        let mut bp = Blueprint::new();
        bp.soft_deletes();
        let [column] = bp.columns.as_slice() else {
            panic!("expected exactly one column, got {}", bp.columns.len());
        };
        assert_eq!(column.name, "deleted_at");
        assert!(column.is_nullable);
    }

    #[test]
    fn test_blueprint_build_produces_valid_sql() {
        let mut bp = Blueprint::new();
        bp.id();
        bp.string("name").not_null();
        bp.integer("age");
        let sql = bp.build().expect("build should succeed for valid columns");
        for fragment in ["id INTEGER PRIMARY KEY", "name TEXT NOT NULL", "age INTEGER"] {
            assert!(sql.contains(fragment), "missing {fragment:?} in {sql:?}");
        }
    }

    #[test]
    fn test_column_default_to_sql_escaping() {
        let quoted = ColumnDefault::Text("O'Reilly".to_string());
        assert_eq!(quoted.to_sql(), "'O''Reilly'");
    }

    #[test]
    fn test_validate_identifier_multiple_dots() {
        assert!(validate_identifier("table.column").is_ok());
        assert!(validate_identifier("schema.table.column").is_err());
    }

    #[test]
    fn test_column_default_sql_rendering() {
        let cases = [
            (ColumnDefault::CurrentTimestamp, "CURRENT_TIMESTAMP"),
            (ColumnDefault::Null, "NULL"),
            (ColumnDefault::Integer(42), "42"),
            (ColumnDefault::Float(1.23), "1.23"),
            (ColumnDefault::Text("hello".to_string()), "'hello'"),
            (ColumnDefault::Text("it's".to_string()), "'it''s'"),
        ];
        for (value, expected) in cases {
            assert_eq!(value.to_sql(), expected);
        }
    }

    #[test]
    fn test_join_clause_on_eq_binds_value() {
        let mut join = JoinClause::new("orders");
        join.on_eq("orders.user_id", 42i32);
        assert_eq!(join.to_sql(), "orders.user_id = ?");
        assert_eq!(join.bindings.len(), 1);
    }

    #[test]
    fn test_join_clause_multiple_conditions() {
        let mut join = JoinClause::new("posts");
        join.on("posts.user_id", "=", "users.id");
        join.on("posts.status", ">", "users.min_status");
        assert_eq!(
            join.to_sql(),
            "posts.user_id = users.id AND posts.status > users.min_status"
        );
    }

    #[test]
    fn test_column_builder_methods() {
        let mut col = Column::new("age", "INTEGER");
        assert_eq!(col.name, "age");
        assert_eq!(col.col_type, "INTEGER");
        assert!(col.is_nullable);
        assert!(!col.is_primary_key);
        assert!(!col.is_auto_increment);
        assert_eq!(col.default_value, None);

        col.not_null();
        assert!(!col.is_nullable);

        col.nullable();
        assert!(col.is_nullable);

        col.primary();
        assert!(col.is_primary_key);

        col.default(ColumnDefault::Integer(18));
        assert_eq!(col.default_value, Some(ColumnDefault::Integer(18)));
    }

    #[tokio::test]
    async fn test_db_migration_error_state_invalid_blueprint() {
        let outcome = Schema::create("invalid; DROP TABLE users", |bp| {
            bp.id();
        })
        .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_drop_if_exists_invalid_table() {
        let outcome = Schema::drop_if_exists("invalid; name").await;
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(crate::Error::Internal(_))));
    }
}