1use clap::{Parser, Subcommand};
2use std::collections::HashMap;
3use std::path::PathBuf;
4
5#[derive(Parser, Debug)]
6#[command(
7 name = "sqlx-gen",
8 about = "Generate Rust structs from database schema"
9)]
10pub struct Cli {
11 #[command(subcommand)]
12 pub command: Command,
13}
14
15#[derive(Subcommand, Debug)]
16pub enum Command {
17 Generate {
19 #[command(subcommand)]
20 subcommand: GenerateCommand,
21 },
22}
23
24#[derive(Subcommand, Debug)]
25pub enum GenerateCommand {
26 Entities(EntitiesArgs),
28 Crud(CrudArgs),
30}
31
32#[derive(Parser, Debug)]
33pub struct DatabaseArgs {
34 #[arg(short = 'u', long, env = "DATABASE_URL")]
36 pub database_url: String,
37
38 #[arg(short = 's', long, value_delimiter = ',', default_value = "public")]
40 pub schemas: Vec<String>,
41}
42
43impl DatabaseArgs {
44 pub fn database_kind(&self) -> crate::error::Result<DatabaseKind> {
45 let url = &self.database_url;
46 if url.starts_with("postgres://") || url.starts_with("postgresql://") {
47 Ok(DatabaseKind::Postgres)
48 } else if url.starts_with("mysql://") {
49 Ok(DatabaseKind::Mysql)
50 } else if url.starts_with("sqlite://") || url.starts_with("sqlite:") {
51 Ok(DatabaseKind::Sqlite)
52 } else {
53 Err(crate::error::Error::Config(
54 "Cannot detect database type from URL. Expected postgres://, mysql://, or sqlite:// prefix.".to_string(),
55 ))
56 }
57 }
58}
59
60#[derive(Parser, Debug)]
61pub struct EntitiesArgs {
62 #[command(flatten)]
63 pub db: DatabaseArgs,
64
65 #[arg(short = 'o', long, default_value = "src/models")]
67 pub output_dir: PathBuf,
68
69 #[arg(short = 'D', long, value_delimiter = ',')]
71 pub derives: Vec<String>,
72
73 #[arg(short = 'T', long, value_delimiter = ',')]
75 pub type_overrides: Vec<String>,
76
77 #[arg(short = 'S', long)]
79 pub single_file: bool,
80
81 #[arg(short = 't', long, value_delimiter = ',')]
83 pub tables: Option<Vec<String>>,
84
85 #[arg(short = 'x', long, value_delimiter = ',')]
87 pub exclude_tables: Option<Vec<String>>,
88
89 #[arg(short = 'v', long)]
91 pub views: bool,
92
93 #[arg(long, default_value = "chrono")]
95 pub time_crate: TimeCrate,
96
97 #[arg(long, default_value = "alias")]
100 pub domain_style: DomainStyle,
101
102 #[arg(short = 'n', long)]
104 pub dry_run: bool,
105}
106
107impl EntitiesArgs {
108 pub fn parse_type_overrides(&self) -> HashMap<String, String> {
109 self.type_overrides
110 .iter()
111 .filter_map(|s| {
112 let (k, v) = s.split_once('=')?;
113 Some((k.to_string(), v.to_string()))
114 })
115 .collect()
116 }
117
118 pub fn parse_type_overrides_checked(&self) -> crate::error::Result<HashMap<String, String>> {
122 let mut map = HashMap::new();
123 for s in &self.type_overrides {
124 let (k, v) = s.split_once('=').ok_or_else(|| {
125 crate::error::Error::Config(format!(
126 "Invalid --type-overrides entry '{}'. Expected format: sql_type=RustType",
127 s
128 ))
129 })?;
130 if k.is_empty() {
131 return Err(crate::error::Error::Config(format!(
132 "Empty SQL type key in --type-overrides entry '{}'",
133 s
134 )));
135 }
136 if v.trim().is_empty() {
137 return Err(crate::error::Error::Config(format!(
138 "Empty Rust type value in --type-overrides entry '{}'",
139 s
140 )));
141 }
142 syn::parse_str::<syn::Type>(v).map_err(|e| {
143 crate::error::Error::Config(format!(
144 "Invalid Rust type in --type-overrides value '{}': {}",
145 v, e
146 ))
147 })?;
148 map.insert(k.to_string(), v.to_string());
149 }
150 Ok(map)
151 }
152}
153
154#[derive(Parser, Debug)]
155pub struct CrudArgs {
156 #[arg(short = 'f', long)]
158 pub entity_file: PathBuf,
159
160 #[arg(short = 'd', long)]
162 pub db_kind: String,
163
164 #[arg(short = 'e', long)]
167 pub entities_module: Option<String>,
168
169 #[arg(short = 'o', long, default_value = "src/crud")]
171 pub output_dir: PathBuf,
172
173 #[arg(short = 'm', long, value_delimiter = ',')]
175 pub methods: Vec<String>,
176
177 #[arg(short = 'q', long)]
179 pub query_macro: bool,
180
181 #[arg(short = 'p', long, default_value = "private")]
183 pub pool_visibility: PoolVisibility,
184
185 #[arg(short = 'n', long)]
187 pub dry_run: bool,
188}
189
190impl CrudArgs {
191 pub fn database_kind(&self) -> crate::error::Result<DatabaseKind> {
192 match self.db_kind.to_lowercase().as_str() {
193 "postgres" | "postgresql" | "pg" => Ok(DatabaseKind::Postgres),
194 "mysql" => Ok(DatabaseKind::Mysql),
195 "sqlite" => Ok(DatabaseKind::Sqlite),
196 other => Err(crate::error::Error::Config(format!(
197 "Unknown database kind '{}'. Expected: postgres, mysql, sqlite",
198 other
199 ))),
200 }
201 }
202
203 pub fn resolve_entities_module(&self) -> crate::error::Result<String> {
206 match &self.entities_module {
207 Some(m) => Ok(m.clone()),
208 None => module_path_from_file(&self.entity_file),
209 }
210 }
211}
212
213fn module_path_from_file(path: &std::path::Path) -> crate::error::Result<String> {
217 let path_str = path.to_string_lossy().replace('\\', "/");
218
219 let after_src = match path_str.rfind("/src/") {
220 Some(pos) => &path_str[pos + 5..],
221 None if path_str.starts_with("src/") => &path_str[4..],
222 _ => {
223 return Err(crate::error::Error::Config(format!(
224 "Cannot derive module path from '{}': no 'src/' found. Use --entities-module explicitly.",
225 path.display()
226 )));
227 }
228 };
229
230 let without_ext = after_src.strip_suffix(".rs").unwrap_or(after_src);
231 let module = without_ext.strip_suffix("/mod").unwrap_or(without_ext);
232
233 let module_path = format!("crate::{}", module.replace('/', "::"));
234 Ok(module_path)
235}
236
/// Database backend targeted by the generator.
///
/// It is produced by `DatabaseArgs::database_kind` (from the URL scheme) and
/// by `CrudArgs::database_kind` (from the `--db-kind` name).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DatabaseKind {
    /// Selected by `postgres://` / `postgresql://` URLs, or by `postgres` / `postgresql` / `pg`.
    Postgres,
    /// Selected by `mysql://` URLs, or by `mysql`.
    Mysql,
    /// Selected by `sqlite:` URLs, or by `sqlite`.
    Sqlite,
}
243
/// Rendering style for SQL domain types, chosen with `--domain-style`.
///
/// The variants correspond to the CLI spellings `alias` and `newtype`. The
/// exact code emitted for each is decided by the generator; confirm there.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum DomainStyle {
    /// CLI spelling `alias`; the default.
    #[default]
    Alias,
    /// CLI spelling `newtype`.
    Newtype,
}

impl std::str::FromStr for DomainStyle {
    type Err = String;

    /// Parses the exact, case-sensitive CLI spelling.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == "alias" {
            return Ok(Self::Alias);
        }
        if s == "newtype" {
            return Ok(Self::Newtype);
        }
        Err(format!(
            "Unknown domain style '{}'. Expected: alias, newtype",
            s
        ))
    }
}

impl std::fmt::Display for DomainStyle {
    /// Writes the CLI spelling, so `to_string()` round-trips through `FromStr`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let spelling = match self {
            Self::Alias => "alias",
            Self::Newtype => "newtype",
        };
        f.write_str(spelling)
    }
}
277
/// Crate used for date/time column types, chosen with `--time-crate`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TimeCrate {
    /// CLI spelling `chrono`; the default.
    #[default]
    Chrono,
    /// CLI spelling `time`.
    Time,
}

impl std::str::FromStr for TimeCrate {
    type Err = String;

    /// Parses the exact, case-sensitive CLI spelling.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        [("chrono", Self::Chrono), ("time", Self::Time)]
            .iter()
            .find(|(spelling, _)| *spelling == s)
            .map(|&(_, krate)| krate)
            .ok_or_else(|| format!("Unknown time crate '{}'. Expected: chrono, time", s))
    }
}

impl std::fmt::Display for TimeCrate {
    /// Writes the CLI spelling, so `to_string()` round-trips through `FromStr`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let spelling = match self {
            Self::Chrono => "chrono",
            Self::Time => "time",
        };
        f.write_str(spelling)
    }
}
308
/// Visibility chosen with `--pool-visibility`.
///
/// NOTE(review): presumably applied to the pool field of generated CRUD
/// types; confirm in the generator.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum PoolVisibility {
    /// CLI spelling `private`; the default.
    #[default]
    Private,
    /// CLI spelling `pub`.
    Pub,
    /// CLI spelling `pub(crate)`.
    PubCrate,
}

impl std::str::FromStr for PoolVisibility {
    type Err = String;

    /// Parses the exact, case-sensitive CLI spelling.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let parsed = match s {
            "private" => Self::Private,
            "pub" => Self::Pub,
            "pub(crate)" => Self::PubCrate,
            unknown => {
                return Err(format!(
                    "Unknown pool visibility '{}'. Expected: private, pub, pub(crate)",
                    unknown
                ));
            }
        };
        Ok(parsed)
    }
}
332
/// Selection of CRUD methods to generate; each flag enables one method.
///
/// `Default` selects nothing; [`Methods::all`] selects everything.
#[derive(Debug, Clone, Default)]
pub struct Methods {
    pub get_all: bool,
    pub paginate: bool,
    pub get: bool,
    pub insert: bool,
    pub insert_many: bool,
    pub update: bool,
    pub overwrite: bool,
    pub delete: bool,
}

/// Every method name accepted by [`Methods::from_list`], apart from the `*`
/// wildcard. It is used to build the "valid values" error message.
const ALL_METHODS: &[&str] = &[
    "get_all",
    "paginate",
    "get",
    "insert",
    "insert_many",
    "update",
    "overwrite",
    "delete",
];

impl Methods {
    /// Builds a selection from a list of method names.
    ///
    /// - Each name is trimmed of surrounding whitespace, so `get, insert`
    ///   split on `,` works.
    /// - `*` selects every method.
    /// - Every entry is validated, including entries after `*`, so a typo is
    ///   never silently ignored. Previously `*` returned early, which made
    ///   `["*", "typo"]` succeed while `["typo", "*"]` failed.
    /// - An empty list selects nothing.
    ///
    /// # Errors
    /// Returns a message naming the first unknown entry and listing the valid values.
    pub fn from_list(names: &[String]) -> Result<Self, String> {
        let mut m = Self::default();
        for name in names {
            match name.trim() {
                "*" => m = Self::all(),
                "get_all" => m.get_all = true,
                "paginate" => m.paginate = true,
                "get" => m.get = true,
                "insert" => m.insert = true,
                "insert_many" => m.insert_many = true,
                "update" => m.update = true,
                "overwrite" => m.overwrite = true,
                "delete" => m.delete = true,
                other => {
                    return Err(format!(
                        "Unknown method '{}'. Valid values: *, {}",
                        other,
                        ALL_METHODS.join(", ")
                    ));
                }
            }
        }
        Ok(m)
    }

    /// Returns a selection with every method enabled.
    pub fn all() -> Self {
        Self {
            get_all: true,
            paginate: true,
            get: true,
            insert: true,
            insert_many: true,
            update: true,
            overwrite: true,
            delete: true,
        }
    }
}
398
// Unit tests for the CLI argument helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `DatabaseArgs` with the given URL and the single schema `public`.
    fn make_db_args(url: &str) -> DatabaseArgs {
        DatabaseArgs {
            database_url: url.to_string(),
            schemas: vec!["public".into()],
        }
    }

    /// Builds `EntitiesArgs` with a Postgres URL and the given raw
    /// `--type-overrides` entries. All other options are fixed to simple values.
    fn make_entities_args_with_overrides(overrides: Vec<&str>) -> EntitiesArgs {
        EntitiesArgs {
            db: make_db_args("postgres://localhost/db"),
            output_dir: PathBuf::from("out"),
            derives: vec![],
            type_overrides: overrides.into_iter().map(|s| s.to_string()).collect(),
            single_file: false,
            tables: None,
            exclude_tables: None,
            views: false,
            time_crate: TimeCrate::Chrono,
            domain_style: DomainStyle::Alias,
            dry_run: false,
        }
    }

    // ---- DatabaseArgs::database_kind: backend detection from the URL scheme ----

    #[test]
    fn test_postgres_url() {
        let args = make_db_args("postgres://localhost/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_postgresql_url() {
        let args = make_db_args("postgresql://localhost/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_postgres_full_url() {
        let args = make_db_args("postgres://user:pass@host:5432/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_mysql_url() {
        let args = make_db_args("mysql://localhost/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Mysql);
    }

    #[test]
    fn test_mysql_full_url() {
        let args = make_db_args("mysql://user:pass@host:3306/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Mysql);
    }

    #[test]
    fn test_sqlite_url() {
        let args = make_db_args("sqlite://path.db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Sqlite);
    }

    // The `sqlite:` form without `//` must also be recognised.
    #[test]
    fn test_sqlite_colon() {
        let args = make_db_args("sqlite:path.db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_sqlite_memory() {
        let args = make_db_args("sqlite::memory:");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_http_url_fails() {
        let args = make_db_args("http://example.com");
        assert!(args.database_kind().is_err());
    }

    #[test]
    fn test_empty_url_fails() {
        let args = make_db_args("");
        assert!(args.database_kind().is_err());
    }

    #[test]
    fn test_mongo_url_fails() {
        let args = make_db_args("mongo://localhost");
        assert!(args.database_kind().is_err());
    }

    // Scheme matching is case-sensitive; this test pins that behaviour.
    #[test]
    fn test_uppercase_postgres_fails() {
        let args = make_db_args("POSTGRES://localhost");
        assert!(args.database_kind().is_err());
    }

    // ---- EntitiesArgs::parse_type_overrides_checked: strict validation ----

    #[test]
    fn test_overrides_empty() {
        let args = make_entities_args_with_overrides(vec![]);
        assert!(args.parse_type_overrides().is_empty());
    }

    #[test]
    fn test_overrides_checked_empty_ok() {
        let args = make_entities_args_with_overrides(vec![]);
        assert!(args.parse_type_overrides_checked().unwrap().is_empty());
    }

    #[test]
    fn test_overrides_checked_simple_type() {
        let args = make_entities_args_with_overrides(vec!["jsonb=MyJson"]);
        let map = args.parse_type_overrides_checked().unwrap();
        assert_eq!(map.get("jsonb").unwrap(), "MyJson");
    }

    #[test]
    fn test_overrides_checked_path_type() {
        let args = make_entities_args_with_overrides(vec!["jsonb=crate::types::MyJson"]);
        let map = args.parse_type_overrides_checked().unwrap();
        assert_eq!(map.get("jsonb").unwrap(), "crate::types::MyJson");
    }

    #[test]
    fn test_overrides_checked_generic_type() {
        let args = make_entities_args_with_overrides(vec!["bytea=Vec<u8>"]);
        assert!(args.parse_type_overrides_checked().is_ok());
    }

    // The value is pasted into generated code, so anything beyond a single
    // Rust type (here a trailing item definition) must be refused.
    #[test]
    fn test_overrides_checked_rejects_injection() {
        let args = make_entities_args_with_overrides(vec!["jsonb=Vec<u8>; fn pwned() {}"]);
        let result = args.parse_type_overrides_checked();
        assert!(
            result.is_err(),
            "must reject value that isn't a single Rust type"
        );
    }

    #[test]
    fn test_overrides_checked_rejects_no_equals() {
        let args = make_entities_args_with_overrides(vec!["noequals"]);
        assert!(args.parse_type_overrides_checked().is_err());
    }

    #[test]
    fn test_overrides_checked_rejects_empty_value() {
        let args = make_entities_args_with_overrides(vec!["jsonb="]);
        assert!(args.parse_type_overrides_checked().is_err());
    }

    #[test]
    fn test_overrides_checked_rejects_empty_key() {
        let args = make_entities_args_with_overrides(vec!["=Foo"]);
        assert!(args.parse_type_overrides_checked().is_err());
    }

    // ---- EntitiesArgs::parse_type_overrides: lenient parsing ----

    #[test]
    fn test_overrides_single() {
        let args = make_entities_args_with_overrides(vec!["jsonb=MyJson"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("jsonb").unwrap(), "MyJson");
    }

    #[test]
    fn test_overrides_multiple() {
        let args = make_entities_args_with_overrides(vec!["jsonb=MyJson", "uuid=MyUuid"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.len(), 2);
        assert_eq!(map.get("jsonb").unwrap(), "MyJson");
        assert_eq!(map.get("uuid").unwrap(), "MyUuid");
    }

    // Unlike the checked variant, malformed entries are dropped rather than reported.
    #[test]
    fn test_overrides_malformed_skipped() {
        let args = make_entities_args_with_overrides(vec!["noequals"]);
        assert!(args.parse_type_overrides().is_empty());
    }

    #[test]
    fn test_overrides_mixed_valid_invalid() {
        let args = make_entities_args_with_overrides(vec!["good=val", "bad"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.len(), 1);
        assert_eq!(map.get("good").unwrap(), "val");
    }

    // Only the first `=` separates key from value.
    #[test]
    fn test_overrides_equals_in_value() {
        let args = make_entities_args_with_overrides(vec!["key=val=ue"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("key").unwrap(), "val=ue");
    }

    // The lenient parser keeps empty keys and values verbatim.
    #[test]
    fn test_overrides_empty_key() {
        let args = make_entities_args_with_overrides(vec!["=value"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("").unwrap(), "value");
    }

    #[test]
    fn test_overrides_empty_value() {
        let args = make_entities_args_with_overrides(vec!["key="]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("key").unwrap(), "");
    }

    // ---- EntitiesArgs::exclude_tables ----

    #[test]
    fn test_exclude_tables_default_none() {
        let args = make_entities_args_with_overrides(vec![]);
        assert!(args.exclude_tables.is_none());
    }

    #[test]
    fn test_exclude_tables_set() {
        let mut args = make_entities_args_with_overrides(vec![]);
        args.exclude_tables = Some(vec![
            "_migrations".to_string(),
            "schema_versions".to_string(),
        ]);
        assert_eq!(args.exclude_tables.as_ref().unwrap().len(), 2);
        assert!(args
            .exclude_tables
            .as_ref()
            .unwrap()
            .contains(&"_migrations".to_string()));
    }

    // ---- Methods::from_list / Methods::all ----

    #[test]
    fn test_methods_default_all_false() {
        let m = Methods::default();
        assert!(!m.get_all);
        assert!(!m.paginate);
        assert!(!m.get);
        assert!(!m.insert);
        assert!(!m.insert_many);
        assert!(!m.update);
        assert!(!m.overwrite);
        assert!(!m.delete);
    }

    #[test]
    fn test_methods_star() {
        let m = Methods::from_list(&["*".to_string()]).unwrap();
        assert!(m.get_all);
        assert!(m.paginate);
        assert!(m.get);
        assert!(m.insert);
        assert!(m.insert_many);
        assert!(m.update);
        assert!(m.overwrite);
        assert!(m.delete);
    }

    #[test]
    fn test_methods_single() {
        let m = Methods::from_list(&["get".to_string()]).unwrap();
        assert!(m.get);
        assert!(!m.get_all);
        assert!(!m.insert);
    }

    #[test]
    fn test_methods_multiple() {
        let m = Methods::from_list(&["get_all".to_string(), "delete".to_string()]).unwrap();
        assert!(m.get_all);
        assert!(m.delete);
        assert!(!m.insert);
        assert!(!m.paginate);
    }

    #[test]
    fn test_methods_unknown_fails() {
        let result = Methods::from_list(&["unknown".to_string()]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Unknown method"));
    }

    #[test]
    fn test_methods_all() {
        let m = Methods::all();
        assert!(m.get_all);
        assert!(m.paginate);
        assert!(m.get);
        assert!(m.insert);
        assert!(m.insert_many);
        assert!(m.update);
        assert!(m.overwrite);
        assert!(m.delete);
    }

    // `overwrite` and `update` are independent flags.
    #[test]
    fn test_parse_overwrite_method() {
        let m = Methods::from_list(&["overwrite".to_string()]).unwrap();
        assert!(m.overwrite);
        assert!(!m.update);
    }

    // `insert_many` does not imply `insert`.
    #[test]
    fn test_parse_insert_many_method() {
        let m = Methods::from_list(&["insert_many".to_string()]).unwrap();
        assert!(m.insert_many);
        assert!(!m.insert);
        assert!(!m.get);
    }

    // ---- module_path_from_file: deriving `crate::…` paths ----

    #[test]
    fn test_module_path_simple() {
        let p = PathBuf::from("src/models/users.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::models::users");
    }

    #[test]
    fn test_module_path_mod_rs() {
        let p = PathBuf::from("src/models/mod.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::models");
    }

    #[test]
    fn test_module_path_nested() {
        let p = PathBuf::from("src/db/entities/agent.rs");
        assert_eq!(
            module_path_from_file(&p).unwrap(),
            "crate::db::entities::agent"
        );
    }

    #[test]
    fn test_module_path_absolute_with_src() {
        let p = PathBuf::from("/home/user/project/src/models/users.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::models::users");
    }

    #[test]
    fn test_module_path_relative_with_src() {
        let p = PathBuf::from("../other_project/src/models/users.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::models::users");
    }

    #[test]
    fn test_module_path_no_src_fails() {
        let p = PathBuf::from("models/users.rs");
        assert!(module_path_from_file(&p).is_err());
    }

    #[test]
    fn test_module_path_deeply_nested_mod() {
        let p = PathBuf::from("src/a/b/c/mod.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::a::b::c");
    }

    // Pins current behaviour: `lib.rs` is not special-cased as the crate root.
    #[test]
    fn test_module_path_src_root_file() {
        let p = PathBuf::from("src/lib.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::lib");
    }
}