1use clap::{Parser, Subcommand};
2use std::collections::HashMap;
3use std::path::PathBuf;
4
// Top-level command-line interface for `sqlx-gen`.
//
// Plain `//` comments are used here instead of `///`: clap's derive turns doc comments
// into help text, so a `///` comment would change the `--help` output.
#[derive(Parser, Debug)]
#[command(name = "sqlx-gen", about = "Generate Rust structs from database schema")]
pub struct Cli {
    // The subcommand to run (see `Command`).
    #[command(subcommand)]
    pub command: Command,
}
11
// Top-level subcommands of `sqlx-gen`.
// (`//` rather than `///`: clap would turn doc comments into help text.)
#[derive(Subcommand, Debug)]
pub enum Command {
    // `generate` (clap's default kebab-case name for this variant); it only groups the
    // nested code-generation subcommands in `GenerateCommand`.
    Generate {
        #[command(subcommand)]
        subcommand: GenerateCommand,
    },
}
20
// Code-generation subcommands: `generate entities` and `generate crud`.
// (`//` rather than `///`: clap would turn doc comments into help text.)
#[derive(Subcommand, Debug)]
pub enum GenerateCommand {
    // Generate entity structs from a database schema; takes connection settings via
    // the flattened `DatabaseArgs` inside `EntitiesArgs`.
    Entities(EntitiesArgs),
    // Generate CRUD code from an existing entity source file (`CrudArgs::entity_file`);
    // presumably no database connection is needed here — confirm in the generator.
    Crud(CrudArgs),
}
28
// Database connection settings, flattened into `EntitiesArgs`.
// (`//` rather than `///`: clap would turn doc comments into help text.)
#[derive(Parser, Debug)]
pub struct DatabaseArgs {
    // `-u` / `--database-url`; falls back to the `DATABASE_URL` environment variable.
    // The backend is inferred from the URL scheme by `DatabaseArgs::database_kind`.
    #[arg(short = 'u', long, env = "DATABASE_URL")]
    pub database_url: String,

    // `-s` / `--schemas`: comma-separated list of schemas, defaulting to `public`.
    // NOTE(review): `public` is PostgreSQL's default schema; MySQL and SQLite have no
    // schema by that name — confirm how the introspector treats this for those backends.
    #[arg(short = 's', long, value_delimiter = ',', default_value = "public")]
    pub schemas: Vec<String>,
}
39
40impl DatabaseArgs {
41 pub fn database_kind(&self) -> crate::error::Result<DatabaseKind> {
42 let url = &self.database_url;
43 if url.starts_with("postgres://") || url.starts_with("postgresql://") {
44 Ok(DatabaseKind::Postgres)
45 } else if url.starts_with("mysql://") {
46 Ok(DatabaseKind::Mysql)
47 } else if url.starts_with("sqlite://") || url.starts_with("sqlite:") {
48 Ok(DatabaseKind::Sqlite)
49 } else {
50 Err(crate::error::Error::Config(
51 "Cannot detect database type from URL. Expected postgres://, mysql://, or sqlite:// prefix.".to_string(),
52 ))
53 }
54 }
55}
56
// Arguments for `generate entities`.
// (`//` rather than `///`: clap would turn doc comments into help text.)
#[derive(Parser, Debug)]
pub struct EntitiesArgs {
    // Connection settings (`-u`, `-s`), flattened from `DatabaseArgs`.
    #[command(flatten)]
    pub db: DatabaseArgs,

    // `-o` / `--output-dir`, default `src/models`.
    #[arg(short = 'o', long, default_value = "src/models")]
    pub output_dir: PathBuf,

    // `-D` / `--derives`: comma-separated extra derive names (presumably appended to the
    // generated structs' derive lists — confirm in the generator).
    #[arg(short = 'D', long, value_delimiter = ',')]
    pub derives: Vec<String>,

    // `-T` / `--type-overrides`: comma-separated `key=value` entries, turned into a map
    // by `parse_type_overrides`.
    // NOTE(review): because of `value_delimiter = ','`, a value that itself contains a
    // comma (e.g. `HashMap<String,Value>`) is split by clap into two entries; the half
    // without `=` is then silently dropped by `parse_type_overrides`.
    #[arg(short = 'T', long, value_delimiter = ',')]
    pub type_overrides: Vec<String>,

    // `-S` / `--single-file` flag.
    #[arg(short = 'S', long)]
    pub single_file: bool,

    // `-t` / `--tables`: comma-separated table names; `None` when the flag is absent.
    #[arg(short = 't', long, value_delimiter = ',')]
    pub tables: Option<Vec<String>>,

    // `-x` / `--exclude-tables`: comma-separated table names; `None` when absent.
    #[arg(short = 'x', long, value_delimiter = ',')]
    pub exclude_tables: Option<Vec<String>>,

    // `-v` / `--views` flag.
    #[arg(short = 'v', long)]
    pub views: bool,

    // `-n` / `--dry-run` flag.
    #[arg(short = 'n', long)]
    pub dry_run: bool,
}
94
95impl EntitiesArgs {
96 pub fn parse_type_overrides(&self) -> HashMap<String, String> {
97 self.type_overrides
98 .iter()
99 .filter_map(|s| {
100 let (k, v) = s.split_once('=')?;
101 Some((k.to_string(), v.to_string()))
102 })
103 .collect()
104 }
105}
106
// Arguments for `generate crud`.
// (`//` rather than `///`: clap would turn doc comments into help text.)
#[derive(Parser, Debug)]
pub struct CrudArgs {
    // `-f` / `--entity-file` (required): path to the entity source file.
    #[arg(short = 'f', long)]
    pub entity_file: PathBuf,

    // `-d` / `--db-kind` (required): free-form string validated later by
    // `CrudArgs::database_kind` rather than by clap.
    #[arg(short = 'd', long)]
    pub db_kind: String,

    // `-e` / `--entities-module`: explicit module path for the entities. When absent,
    // `resolve_entities_module` derives it from `entity_file`.
    #[arg(short = 'e', long)]
    pub entities_module: Option<String>,

    // `-o` / `--output-dir`, default `src/crud`.
    #[arg(short = 'o', long, default_value = "src/crud")]
    pub output_dir: PathBuf,

    // `-m` / `--methods`: comma-separated method names, validated by `Methods::from_list`.
    // No default: the vector is empty when the flag is omitted.
    // NOTE(review): `Methods::from_list(&[])` selects nothing — confirm the caller handles
    // an empty list as intended.
    #[arg(short = 'm', long, value_delimiter = ',')]
    pub methods: Vec<String>,

    // `-q` / `--query-macro` flag.
    #[arg(short = 'q', long)]
    pub query_macro: bool,

    // `-p` / `--pool-visibility`, default `private`; parsed via `PoolVisibility`'s `FromStr`.
    // `pub(crate)` contains shell metacharacters and must be quoted on the command line.
    #[arg(short = 'p', long, default_value = "private")]
    pub pool_visibility: PoolVisibility,

    // `-n` / `--dry-run` flag.
    #[arg(short = 'n', long)]
    pub dry_run: bool,
}
143
144impl CrudArgs {
145 pub fn database_kind(&self) -> crate::error::Result<DatabaseKind> {
146 match self.db_kind.to_lowercase().as_str() {
147 "postgres" | "postgresql" | "pg" => Ok(DatabaseKind::Postgres),
148 "mysql" => Ok(DatabaseKind::Mysql),
149 "sqlite" => Ok(DatabaseKind::Sqlite),
150 other => Err(crate::error::Error::Config(format!(
151 "Unknown database kind '{}'. Expected: postgres, mysql, sqlite",
152 other
153 ))),
154 }
155 }
156
157 pub fn resolve_entities_module(&self) -> crate::error::Result<String> {
160 match &self.entities_module {
161 Some(m) => Ok(m.clone()),
162 None => module_path_from_file(&self.entity_file),
163 }
164 }
165}
166
167fn module_path_from_file(path: &std::path::Path) -> crate::error::Result<String> {
171 let path_str = path.to_string_lossy().replace('\\', "/");
172
173 let after_src = match path_str.rfind("/src/") {
174 Some(pos) => &path_str[pos + 5..],
175 None if path_str.starts_with("src/") => &path_str[4..],
176 _ => {
177 return Err(crate::error::Error::Config(format!(
178 "Cannot derive module path from '{}': no 'src/' found. Use --entities-module explicitly.",
179 path.display()
180 )));
181 }
182 };
183
184 let without_ext = after_src.strip_suffix(".rs").unwrap_or(after_src);
185 let module = without_ext.strip_suffix("/mod").unwrap_or(without_ext);
186
187 let module_path = format!("crate::{}", module.replace('/', "::"));
188 Ok(module_path)
189}
190
/// Database backend targeted by code generation.
///
/// Produced by `DatabaseArgs::database_kind` (from the URL scheme) and by
/// `CrudArgs::database_kind` (from `--db-kind`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DatabaseKind {
    /// PostgreSQL: `postgres://` / `postgresql://` URLs, or `--db-kind postgres|postgresql|pg`.
    Postgres,
    /// MySQL: `mysql://` URLs, or `--db-kind mysql`.
    Mysql,
    /// SQLite: `sqlite:` URLs (including `sqlite://`), or `--db-kind sqlite`.
    Sqlite,
}
197
/// Visibility selected with `--pool-visibility` for generated CRUD code.
///
/// NOTE(review): presumably this sets the visibility of the pool handle in the generated
/// code — confirm against the CRUD generator.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PoolVisibility {
    /// Selected by `private`. This is the default.
    #[default]
    Private,
    /// Selected by `pub`.
    Pub,
    /// Selected by `pub(crate)`, or the shell-friendly aliases `pub_crate` / `pub-crate`.
    PubCrate,
}

impl std::str::FromStr for PoolVisibility {
    type Err = String;

    /// Parses a command-line visibility value.
    ///
    /// Accepts `private`, `pub` and `pub(crate)`. Parentheses are shell metacharacters,
    /// so an unquoted `-p pub(crate)` is a syntax error in sh/bash. The aliases
    /// `pub_crate` and `pub-crate` are therefore accepted as well. Matching is exact:
    /// case-sensitive, with no trimming.
    ///
    /// # Errors
    ///
    /// Returns a message listing the accepted values for anything else.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "private" => Ok(Self::Private),
            "pub" => Ok(Self::Pub),
            "pub(crate)" | "pub_crate" | "pub-crate" => Ok(Self::PubCrate),
            other => Err(format!(
                "Unknown pool visibility '{}'. Expected: private, pub, pub(crate) (or pub_crate)",
                other
            )),
        }
    }
}
221
/// Set of CRUD methods selected with `--methods`.
///
/// Each flag is `true` when the corresponding method name was requested. The
/// `Default` value has every flag `false`.
#[derive(Debug, Clone, Default)]
pub struct Methods {
    /// `get_all` was requested.
    pub get_all: bool,
    /// `paginate` was requested.
    pub paginate: bool,
    /// `get` was requested.
    pub get: bool,
    /// `insert` was requested.
    pub insert: bool,
    /// `update` was requested.
    pub update: bool,
    /// `delete` was requested.
    pub delete: bool,
}

/// Every method name accepted by [`Methods::from_list`] besides the `*` wildcard, in the
/// order shown in error messages.
const ALL_METHODS: &[&str] = &["get_all", "paginate", "get", "insert", "update", "delete"];

impl Methods {
    /// Builds a [`Methods`] from a list of method names.
    ///
    /// Names are matched after trimming surrounding whitespace, so `-m "get, insert"`
    /// works. `*` selects every method. Every entry is validated, including those after
    /// `*`, so the result does not depend on where `*` appears in the list. An empty list
    /// selects nothing.
    ///
    /// # Errors
    ///
    /// Returns a message naming the first unrecognised entry and listing the valid values.
    pub fn from_list(names: &[String]) -> Result<Self, String> {
        let mut selected = Self::default();
        for raw in names {
            match raw.trim() {
                // Keep looping after `*` so later entries are still validated.
                "*" => selected = Self::all(),
                "get_all" => selected.get_all = true,
                "paginate" => selected.paginate = true,
                "get" => selected.get = true,
                "insert" => selected.insert = true,
                "update" => selected.update = true,
                "delete" => selected.delete = true,
                other => {
                    return Err(format!(
                        "Unknown method '{}'. Valid values: *, {}",
                        other,
                        ALL_METHODS.join(", ")
                    ));
                }
            }
        }
        Ok(selected)
    }

    /// Returns a [`Methods`] with every method enabled (what `*` selects).
    pub fn all() -> Self {
        Self {
            get_all: true,
            paginate: true,
            get: true,
            insert: true,
            update: true,
            delete: true,
        }
    }
}
272
// Unit tests for the argument helpers in this module: URL-based backend detection,
// `--type-overrides` parsing, `--methods` parsing and module-path derivation.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a `DatabaseArgs` with the given URL and the single schema `public`.
    fn make_db_args(url: &str) -> DatabaseArgs {
        DatabaseArgs {
            database_url: url.to_string(),
            schemas: vec!["public".into()],
        }
    }

    // Builds an `EntitiesArgs` with a fixed Postgres URL, every flag off, and the given
    // raw `--type-overrides` entries.
    fn make_entities_args_with_overrides(overrides: Vec<&str>) -> EntitiesArgs {
        EntitiesArgs {
            db: make_db_args("postgres://localhost/db"),
            output_dir: PathBuf::from("out"),
            derives: vec![],
            type_overrides: overrides.into_iter().map(|s| s.to_string()).collect(),
            single_file: false,
            tables: None,
            exclude_tables: None,
            views: false,
            dry_run: false,
        }
    }

    // --- DatabaseArgs::database_kind: backend detection from the URL scheme ---

    #[test]
    fn test_postgres_url() {
        let args = make_db_args("postgres://localhost/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_postgresql_url() {
        let args = make_db_args("postgresql://localhost/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_postgres_full_url() {
        let args = make_db_args("postgres://user:pass@host:5432/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_mysql_url() {
        let args = make_db_args("mysql://localhost/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Mysql);
    }

    #[test]
    fn test_mysql_full_url() {
        let args = make_db_args("mysql://user:pass@host:3306/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Mysql);
    }

    #[test]
    fn test_sqlite_url() {
        let args = make_db_args("sqlite://path.db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Sqlite);
    }

    // `sqlite:` without `//` is also accepted.
    #[test]
    fn test_sqlite_colon() {
        let args = make_db_args("sqlite:path.db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_sqlite_memory() {
        let args = make_db_args("sqlite::memory:");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_http_url_fails() {
        let args = make_db_args("http://example.com");
        assert!(args.database_kind().is_err());
    }

    #[test]
    fn test_empty_url_fails() {
        let args = make_db_args("");
        assert!(args.database_kind().is_err());
    }

    #[test]
    fn test_mongo_url_fails() {
        let args = make_db_args("mongo://localhost");
        assert!(args.database_kind().is_err());
    }

    // Scheme matching is case-sensitive.
    #[test]
    fn test_uppercase_postgres_fails() {
        let args = make_db_args("POSTGRES://localhost");
        assert!(args.database_kind().is_err());
    }

    // --- EntitiesArgs::parse_type_overrides: `key=value` entries into a map ---

    #[test]
    fn test_overrides_empty() {
        let args = make_entities_args_with_overrides(vec![]);
        assert!(args.parse_type_overrides().is_empty());
    }

    #[test]
    fn test_overrides_single() {
        let args = make_entities_args_with_overrides(vec!["jsonb=MyJson"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("jsonb").unwrap(), "MyJson");
    }

    #[test]
    fn test_overrides_multiple() {
        let args = make_entities_args_with_overrides(vec!["jsonb=MyJson", "uuid=MyUuid"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.len(), 2);
        assert_eq!(map.get("jsonb").unwrap(), "MyJson");
        assert_eq!(map.get("uuid").unwrap(), "MyUuid");
    }

    // Entries without `=` are silently skipped rather than reported.
    #[test]
    fn test_overrides_malformed_skipped() {
        let args = make_entities_args_with_overrides(vec!["noequals"]);
        assert!(args.parse_type_overrides().is_empty());
    }

    #[test]
    fn test_overrides_mixed_valid_invalid() {
        let args = make_entities_args_with_overrides(vec!["good=val", "bad"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.len(), 1);
        assert_eq!(map.get("good").unwrap(), "val");
    }

    // Only the first `=` splits key from value.
    #[test]
    fn test_overrides_equals_in_value() {
        let args = make_entities_args_with_overrides(vec!["key=val=ue"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("key").unwrap(), "val=ue");
    }

    // Empty keys and values are kept as-is.
    #[test]
    fn test_overrides_empty_key() {
        let args = make_entities_args_with_overrides(vec!["=value"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("").unwrap(), "value");
    }

    #[test]
    fn test_overrides_empty_value() {
        let args = make_entities_args_with_overrides(vec!["key="]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("key").unwrap(), "");
    }

    // --- exclude_tables field (these only exercise the struct field, not any filtering) ---

    #[test]
    fn test_exclude_tables_default_none() {
        let args = make_entities_args_with_overrides(vec![]);
        assert!(args.exclude_tables.is_none());
    }

    #[test]
    fn test_exclude_tables_set() {
        let mut args = make_entities_args_with_overrides(vec![]);
        args.exclude_tables = Some(vec!["_migrations".to_string(), "schema_versions".to_string()]);
        assert_eq!(args.exclude_tables.as_ref().unwrap().len(), 2);
        assert!(args.exclude_tables.as_ref().unwrap().contains(&"_migrations".to_string()));
    }

    // --- Methods: `--methods` selection ---

    #[test]
    fn test_methods_default_all_false() {
        let m = Methods::default();
        assert!(!m.get_all);
        assert!(!m.paginate);
        assert!(!m.get);
        assert!(!m.insert);
        assert!(!m.update);
        assert!(!m.delete);
    }

    #[test]
    fn test_methods_star() {
        let m = Methods::from_list(&["*".to_string()]).unwrap();
        assert!(m.get_all);
        assert!(m.paginate);
        assert!(m.get);
        assert!(m.insert);
        assert!(m.update);
        assert!(m.delete);
    }

    #[test]
    fn test_methods_single() {
        let m = Methods::from_list(&["get".to_string()]).unwrap();
        assert!(m.get);
        assert!(!m.get_all);
        assert!(!m.insert);
    }

    #[test]
    fn test_methods_multiple() {
        let m = Methods::from_list(&["get_all".to_string(), "delete".to_string()]).unwrap();
        assert!(m.get_all);
        assert!(m.delete);
        assert!(!m.insert);
        assert!(!m.paginate);
    }

    #[test]
    fn test_methods_unknown_fails() {
        let result = Methods::from_list(&["unknown".to_string()]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Unknown method"));
    }

    #[test]
    fn test_methods_all() {
        let m = Methods::all();
        assert!(m.get_all);
        assert!(m.paginate);
        assert!(m.get);
        assert!(m.insert);
        assert!(m.update);
        assert!(m.delete);
    }

    // --- module_path_from_file: file path to `crate::...` module path ---

    #[test]
    fn test_module_path_simple() {
        let p = PathBuf::from("src/models/users.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::models::users");
    }

    // A trailing `mod.rs` maps to its parent directory's module.
    #[test]
    fn test_module_path_mod_rs() {
        let p = PathBuf::from("src/models/mod.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::models");
    }

    #[test]
    fn test_module_path_nested() {
        let p = PathBuf::from("src/db/entities/agent.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::db::entities::agent");
    }

    #[test]
    fn test_module_path_absolute_with_src() {
        let p = PathBuf::from("/home/user/project/src/models/users.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::models::users");
    }

    #[test]
    fn test_module_path_relative_with_src() {
        let p = PathBuf::from("../other_project/src/models/users.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::models::users");
    }

    #[test]
    fn test_module_path_no_src_fails() {
        let p = PathBuf::from("models/users.rs");
        assert!(module_path_from_file(&p).is_err());
    }

    #[test]
    fn test_module_path_deeply_nested_mod() {
        let p = PathBuf::from("src/a/b/c/mod.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::a::b::c");
    }

    // Pins current behaviour: crate-root files are not special-cased.
    // NOTE(review): `src/lib.rs` is the crate root itself, so `crate::lib` is not a path
    // that generated code can `use` — confirm whether `crate` is the intended result.
    #[test]
    fn test_module_path_src_root_file() {
        let p = PathBuf::from("src/lib.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::lib");
    }
}