Skip to main content

sqlx_gen/
cli.rs

1use clap::{Parser, Subcommand};
2use std::collections::HashMap;
3use std::path::PathBuf;
4
5#[derive(Parser, Debug)]
6#[command(name = "sqlx-gen", about = "Generate Rust structs from database schema")]
7pub struct Cli {
8    #[command(subcommand)]
9    pub command: Command,
10}
11
12#[derive(Subcommand, Debug)]
13pub enum Command {
14    /// Generate code from database schema
15    Generate {
16        #[command(subcommand)]
17        subcommand: GenerateCommand,
18    },
19}
20
21#[derive(Subcommand, Debug)]
22pub enum GenerateCommand {
23    /// Generate entity structs, enums, composites, and domains
24    Entities(EntitiesArgs),
25    /// Generate CRUD repository for a table or view
26    Crud(CrudArgs),
27}
28
29#[derive(Parser, Debug)]
30pub struct DatabaseArgs {
31    /// Database connection URL
32    #[arg(short = 'u', long, env = "DATABASE_URL")]
33    pub database_url: String,
34
35    /// Schemas to introspect (comma-separated, PG default: public)
36    #[arg(short = 's', long, value_delimiter = ',', default_value = "public")]
37    pub schemas: Vec<String>,
38}
39
40impl DatabaseArgs {
41    pub fn database_kind(&self) -> crate::error::Result<DatabaseKind> {
42        let url = &self.database_url;
43        if url.starts_with("postgres://") || url.starts_with("postgresql://") {
44            Ok(DatabaseKind::Postgres)
45        } else if url.starts_with("mysql://") {
46            Ok(DatabaseKind::Mysql)
47        } else if url.starts_with("sqlite://") || url.starts_with("sqlite:") {
48            Ok(DatabaseKind::Sqlite)
49        } else {
50            Err(crate::error::Error::Config(
51                "Cannot detect database type from URL. Expected postgres://, mysql://, or sqlite:// prefix.".to_string(),
52            ))
53        }
54    }
55}
56
57#[derive(Parser, Debug)]
58pub struct EntitiesArgs {
59    #[command(flatten)]
60    pub db: DatabaseArgs,
61
62    /// Output directory for generated files
63    #[arg(short = 'o', long, default_value = "src/models")]
64    pub output_dir: PathBuf,
65
66    /// Additional derives (e.g. Serialize,Deserialize,PartialEq)
67    #[arg(short = 'D', long, value_delimiter = ',')]
68    pub derives: Vec<String>,
69
70    /// Type overrides (e.g. jsonb=MyJsonType,uuid=MyUuid)
71    #[arg(short = 'T', long, value_delimiter = ',')]
72    pub type_overrides: Vec<String>,
73
74    /// Generate everything into a single file instead of one file per table
75    #[arg(short = 'S', long)]
76    pub single_file: bool,
77
78    /// Only generate for these tables (comma-separated)
79    #[arg(short = 't', long, value_delimiter = ',')]
80    pub tables: Option<Vec<String>>,
81
82    /// Exclude these tables/views from generation (comma-separated)
83    #[arg(short = 'x', long, value_delimiter = ',')]
84    pub exclude_tables: Option<Vec<String>>,
85
86    /// Also generate structs for SQL views
87    #[arg(short = 'v', long)]
88    pub views: bool,
89
90    /// Print to stdout without writing files
91    #[arg(short = 'n', long)]
92    pub dry_run: bool,
93}
94
95impl EntitiesArgs {
96    pub fn parse_type_overrides(&self) -> HashMap<String, String> {
97        self.type_overrides
98            .iter()
99            .filter_map(|s| {
100                let (k, v) = s.split_once('=')?;
101                Some((k.to_string(), v.to_string()))
102            })
103            .collect()
104    }
105}
106
107#[derive(Parser, Debug)]
108pub struct CrudArgs {
109    /// Path to the generated entity .rs file
110    #[arg(short = 'f', long)]
111    pub entity_file: PathBuf,
112
113    /// Database kind (postgres, mysql, sqlite)
114    #[arg(short = 'd', long)]
115    pub db_kind: String,
116
117    /// Module path of generated entities (e.g. "crate::models::users").
118    /// If omitted, derived from --entity-file by finding `src/` and converting the path.
119    #[arg(short = 'e', long)]
120    pub entities_module: Option<String>,
121
122    /// Output directory for generated repository files
123    #[arg(short = 'o', long, default_value = "src/crud")]
124    pub output_dir: PathBuf,
125
126    /// Methods to generate (comma-separated): *, get_all, paginate, get, insert, update, delete
127    #[arg(short = 'm', long, value_delimiter = ',')]
128    pub methods: Vec<String>,
129
130
131    /// Use sqlx::query_as!() compile-time checked macros instead of query_as::<_, T>() functions
132    #[arg(short = 'q', long)]
133    pub query_macro: bool,
134
135    /// Visibility of the pool field in generated repository structs: private, pub, pub(crate)
136    #[arg(short = 'p', long, default_value = "private")]
137    pub pool_visibility: PoolVisibility,
138
139    /// Print to stdout without writing files
140    #[arg(short = 'n', long)]
141    pub dry_run: bool,
142}
143
144impl CrudArgs {
145    pub fn database_kind(&self) -> crate::error::Result<DatabaseKind> {
146        match self.db_kind.to_lowercase().as_str() {
147            "postgres" | "postgresql" | "pg" => Ok(DatabaseKind::Postgres),
148            "mysql" => Ok(DatabaseKind::Mysql),
149            "sqlite" => Ok(DatabaseKind::Sqlite),
150            other => Err(crate::error::Error::Config(format!(
151                "Unknown database kind '{}'. Expected: postgres, mysql, sqlite",
152                other
153            ))),
154        }
155    }
156
157    /// Resolve the entities module path: use the explicit value if provided,
158    /// otherwise derive it from the entity file path.
159    pub fn resolve_entities_module(&self) -> crate::error::Result<String> {
160        match &self.entities_module {
161            Some(m) => Ok(m.clone()),
162            None => module_path_from_file(&self.entity_file),
163        }
164    }
165}
166
167/// Derive a Rust module path from a file path by finding `src/` and converting.
168/// e.g. `some/project/src/models/users.rs` → `crate::models::users`
169/// e.g. `src/db/entities/mod.rs` → `crate::db::entities`
170fn module_path_from_file(path: &std::path::Path) -> crate::error::Result<String> {
171    let path_str = path.to_string_lossy().replace('\\', "/");
172
173    let after_src = match path_str.rfind("/src/") {
174        Some(pos) => &path_str[pos + 5..],
175        None if path_str.starts_with("src/") => &path_str[4..],
176        _ => {
177            return Err(crate::error::Error::Config(format!(
178                "Cannot derive module path from '{}': no 'src/' found. Use --entities-module explicitly.",
179                path.display()
180            )));
181        }
182    };
183
184    let without_ext = after_src.strip_suffix(".rs").unwrap_or(after_src);
185    let module = without_ext.strip_suffix("/mod").unwrap_or(without_ext);
186
187    let module_path = format!("crate::{}", module.replace('/', "::"));
188    Ok(module_path)
189}
190
/// Database backend targeted by code generation.
///
/// Detected from a connection URL (`DatabaseArgs::database_kind`) or from the
/// `--db-kind` flag (`CrudArgs::database_kind`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DatabaseKind {
    /// PostgreSQL (`postgres://` / `postgresql://`).
    Postgres,
    /// MySQL (`mysql://`).
    Mysql,
    /// SQLite (`sqlite:` URLs).
    Sqlite,
}
197
/// Visibility modifier for the `pool` field of generated repository structs.
///
/// Defaults to [`PoolVisibility::Private`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum PoolVisibility {
    /// No visibility modifier.
    #[default]
    Private,
    /// `pub`
    Pub,
    /// `pub(crate)`
    PubCrate,
}

impl std::str::FromStr for PoolVisibility {
    type Err = String;

    /// Parse one of the exact (case-sensitive) spellings `private`, `pub`, `pub(crate)`.
    ///
    /// Any other input yields a human-readable error listing the accepted values.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let visibility = match s {
            "private" => Self::Private,
            "pub" => Self::Pub,
            "pub(crate)" => Self::PubCrate,
            unknown => {
                return Err(format!(
                    "Unknown pool visibility '{}'. Expected: private, pub, pub(crate)",
                    unknown
                ));
            }
        };
        Ok(visibility)
    }
}
221
/// Which CRUD methods to generate. All fields default to `false`.
/// Use `Methods::from_list` to parse from CLI input.
#[derive(Debug, Clone, Default)]
pub struct Methods {
    /// Fetch every row.
    pub get_all: bool,
    /// Fetch a page of rows.
    pub paginate: bool,
    /// Fetch a single row.
    pub get: bool,
    /// Insert a row.
    pub insert: bool,
    /// Update a row.
    pub update: bool,
    /// Delete a row.
    pub delete: bool,
}

/// Every concrete method name accepted by [`Methods::from_list`], besides the
/// `*` wildcard; used to build the "valid values" hint in error messages.
const ALL_METHODS: &[&str] = &["get_all", "paginate", "get", "insert", "update", "delete"];

impl Methods {
    /// Parse a list of method names. `"*"` enables all methods.
    ///
    /// Surrounding whitespace in each name is ignored, so a comma-split input like
    /// `"get, insert"` works. An empty list enables nothing.
    ///
    /// # Errors
    ///
    /// Returns a message naming the first unrecognised entry. Every entry is
    /// validated, including entries after a `*` wildcard, so a typo is reported
    /// regardless of where it appears in the list.
    pub fn from_list(names: &[String]) -> Result<Self, String> {
        let mut m = Self::default();
        for name in names {
            match name.trim() {
                // Keep scanning after the wildcard so later typos are still caught.
                "*" => m = Self::all(),
                "get_all" => m.get_all = true,
                "paginate" => m.paginate = true,
                "get" => m.get = true,
                "insert" => m.insert = true,
                "update" => m.update = true,
                "delete" => m.delete = true,
                other => {
                    return Err(format!(
                        "Unknown method '{}'. Valid values: *, {}",
                        other,
                        ALL_METHODS.join(", ")
                    ))
                }
            }
        }
        Ok(m)
    }

    /// A `Methods` with every method enabled (equivalent to `from_list(&["*"])`).
    pub fn all() -> Self {
        Self {
            get_all: true,
            paginate: true,
            get: true,
            insert: true,
            update: true,
            delete: true,
        }
    }
}
272
#[cfg(test)]
mod tests {
    use super::*;

    // ---------- fixtures ----------

    /// `DatabaseArgs` for `url` with the single schema `public`.
    fn make_db_args(url: &str) -> DatabaseArgs {
        DatabaseArgs {
            schemas: vec![String::from("public")],
            database_url: String::from(url),
        }
    }

    /// `EntitiesArgs` with every option off except the given type overrides.
    fn make_entities_args_with_overrides(overrides: &[&str]) -> EntitiesArgs {
        EntitiesArgs {
            db: make_db_args("postgres://localhost/db"),
            output_dir: PathBuf::from("out"),
            derives: Vec::new(),
            type_overrides: overrides.iter().map(|s| s.to_string()).collect(),
            single_file: false,
            tables: None,
            exclude_tables: None,
            views: false,
            dry_run: false,
        }
    }

    /// Detected backend for `url`, or `None` when detection fails.
    fn kind_of(url: &str) -> Option<DatabaseKind> {
        make_db_args(url).database_kind().ok()
    }

    /// Parsed override map for the given raw `--type-overrides` entries.
    fn overrides_of(list: &[&str]) -> std::collections::HashMap<String, String> {
        make_entities_args_with_overrides(list).parse_type_overrides()
    }

    /// Owned copies of string literals, as clap would hand them over.
    fn owned(list: &[&str]) -> Vec<String> {
        list.iter().map(|s| s.to_string()).collect()
    }

    /// Method flags in declaration order: get_all, paginate, get, insert, update, delete.
    fn flags(m: &Methods) -> [bool; 6] {
        [m.get_all, m.paginate, m.get, m.insert, m.update, m.delete]
    }

    /// Module path derived from a path given as a string.
    fn module_of(path: &str) -> crate::error::Result<String> {
        module_path_from_file(std::path::Path::new(path))
    }

    // ---------- database_kind ----------

    #[test]
    fn test_postgres_url() {
        assert_eq!(kind_of("postgres://localhost/db"), Some(DatabaseKind::Postgres));
    }

    #[test]
    fn test_postgresql_url() {
        assert_eq!(kind_of("postgresql://localhost/db"), Some(DatabaseKind::Postgres));
    }

    #[test]
    fn test_postgres_full_url() {
        assert_eq!(
            kind_of("postgres://user:pass@host:5432/db"),
            Some(DatabaseKind::Postgres)
        );
    }

    #[test]
    fn test_mysql_url() {
        assert_eq!(kind_of("mysql://localhost/db"), Some(DatabaseKind::Mysql));
    }

    #[test]
    fn test_mysql_full_url() {
        assert_eq!(kind_of("mysql://user:pass@host:3306/db"), Some(DatabaseKind::Mysql));
    }

    #[test]
    fn test_sqlite_url() {
        assert_eq!(kind_of("sqlite://path.db"), Some(DatabaseKind::Sqlite));
    }

    #[test]
    fn test_sqlite_colon() {
        assert_eq!(kind_of("sqlite:path.db"), Some(DatabaseKind::Sqlite));
    }

    #[test]
    fn test_sqlite_memory() {
        assert_eq!(kind_of("sqlite::memory:"), Some(DatabaseKind::Sqlite));
    }

    #[test]
    fn test_http_url_fails() {
        assert_eq!(kind_of("http://example.com"), None);
    }

    #[test]
    fn test_empty_url_fails() {
        assert_eq!(kind_of(""), None);
    }

    #[test]
    fn test_mongo_url_fails() {
        assert_eq!(kind_of("mongo://localhost"), None);
    }

    #[test]
    fn test_uppercase_postgres_fails() {
        // Scheme matching is case-sensitive.
        assert_eq!(kind_of("POSTGRES://localhost"), None);
    }

    // ---------- parse_type_overrides ----------

    #[test]
    fn test_overrides_empty() {
        assert!(overrides_of(&[]).is_empty());
    }

    #[test]
    fn test_overrides_single() {
        let map = overrides_of(&["jsonb=MyJson"]);
        assert_eq!(map.get("jsonb").map(String::as_str), Some("MyJson"));
    }

    #[test]
    fn test_overrides_multiple() {
        let map = overrides_of(&["jsonb=MyJson", "uuid=MyUuid"]);
        assert_eq!(map.len(), 2);
        assert_eq!(map.get("jsonb").map(String::as_str), Some("MyJson"));
        assert_eq!(map.get("uuid").map(String::as_str), Some("MyUuid"));
    }

    #[test]
    fn test_overrides_malformed_skipped() {
        assert!(overrides_of(&["noequals"]).is_empty());
    }

    #[test]
    fn test_overrides_mixed_valid_invalid() {
        let map = overrides_of(&["good=val", "bad"]);
        assert_eq!(map.len(), 1);
        assert_eq!(map.get("good").map(String::as_str), Some("val"));
    }

    #[test]
    fn test_overrides_equals_in_value() {
        // Only the first `=` separates key from value.
        let map = overrides_of(&["key=val=ue"]);
        assert_eq!(map.get("key").map(String::as_str), Some("val=ue"));
    }

    #[test]
    fn test_overrides_empty_key() {
        let map = overrides_of(&["=value"]);
        assert_eq!(map.get("").map(String::as_str), Some("value"));
    }

    #[test]
    fn test_overrides_empty_value() {
        let map = overrides_of(&["key="]);
        assert_eq!(map.get("key").map(String::as_str), Some(""));
    }

    // ---------- exclude_tables ----------

    #[test]
    fn test_exclude_tables_default_none() {
        let args = make_entities_args_with_overrides(&[]);
        assert!(args.exclude_tables.is_none());
    }

    #[test]
    fn test_exclude_tables_set() {
        let mut args = make_entities_args_with_overrides(&[]);
        args.exclude_tables = Some(owned(&["_migrations", "schema_versions"]));
        let excluded = args
            .exclude_tables
            .as_deref()
            .expect("exclude_tables was just set");
        assert_eq!(excluded.len(), 2);
        assert!(excluded.iter().any(|t| t == "_migrations"));
    }

    // ---------- methods ----------

    #[test]
    fn test_methods_default_all_false() {
        assert_eq!(flags(&Methods::default()), [false; 6]);
    }

    #[test]
    fn test_methods_star() {
        let m = Methods::from_list(&owned(&["*"])).unwrap();
        assert_eq!(flags(&m), [true; 6]);
    }

    #[test]
    fn test_methods_single() {
        let m = Methods::from_list(&owned(&["get"])).unwrap();
        assert!(m.get && !m.get_all && !m.insert);
    }

    #[test]
    fn test_methods_multiple() {
        let m = Methods::from_list(&owned(&["get_all", "delete"])).unwrap();
        assert!(m.get_all && m.delete);
        assert!(!m.insert && !m.paginate);
    }

    #[test]
    fn test_methods_unknown_fails() {
        let err = Methods::from_list(&owned(&["unknown"])).expect_err("unknown name must be rejected");
        assert!(err.contains("Unknown method"));
    }

    #[test]
    fn test_methods_all() {
        assert_eq!(flags(&Methods::all()), [true; 6]);
    }

    // ---------- module_path_from_file ----------

    #[test]
    fn test_module_path_simple() {
        assert_eq!(module_of("src/models/users.rs").unwrap(), "crate::models::users");
    }

    #[test]
    fn test_module_path_mod_rs() {
        assert_eq!(module_of("src/models/mod.rs").unwrap(), "crate::models");
    }

    #[test]
    fn test_module_path_nested() {
        assert_eq!(
            module_of("src/db/entities/agent.rs").unwrap(),
            "crate::db::entities::agent"
        );
    }

    #[test]
    fn test_module_path_absolute_with_src() {
        assert_eq!(
            module_of("/home/user/project/src/models/users.rs").unwrap(),
            "crate::models::users"
        );
    }

    #[test]
    fn test_module_path_relative_with_src() {
        assert_eq!(
            module_of("../other_project/src/models/users.rs").unwrap(),
            "crate::models::users"
        );
    }

    #[test]
    fn test_module_path_no_src_fails() {
        assert!(module_of("models/users.rs").is_err());
    }

    #[test]
    fn test_module_path_deeply_nested_mod() {
        assert_eq!(module_of("src/a/b/c/mod.rs").unwrap(), "crate::a::b::c");
    }

    #[test]
    fn test_module_path_src_root_file() {
        // Crate-root files are not special-cased.
        assert_eq!(module_of("src/lib.rs").unwrap(), "crate::lib");
    }
}