// sqlx_gen/cli.rs

1use clap::Parser;
2use std::collections::HashMap;
3use std::path::PathBuf;
4
// Command-line arguments for `sqlx-gen`, parsed with clap's derive API.
//
// NOTE: the `///` comments on the fields below are consumed by clap as the
// `--help` text for each flag, so editing them changes CLI output. Explanatory
// notes for maintainers are kept in plain `//` comments instead.
#[derive(Parser, Debug)]
#[command(name = "sqlx-gen", about = "Generate Rust structs from database schema")]
pub struct Args {
    // Falls back to the DATABASE_URL environment variable when `-u` /
    // `--database-url` is not given. The URL's scheme prefix selects the
    // backend; see `Args::database_kind`.
    /// Database connection URL
    #[arg(short = 'u', long, env = "DATABASE_URL")]
    pub database_url: String,

    // Relative default; presumably resolved against the current working
    // directory by the caller — TODO confirm.
    /// Output directory for generated files
    #[arg(short = 'o', long, default_value = "src/models")]
    pub output_dir: PathBuf,

    // NOTE(review): the `public` default is applied for every backend, not
    // only Postgres as the help text implies. Confirm that the MySQL/SQLite
    // code paths ignore or override it.
    /// Schemas to introspect (comma-separated, PG default: public)
    #[arg(short = 's', long, value_delimiter = ',', default_value = "public")]
    pub schemas: Vec<String>,

    // Empty when the flag is omitted. Items are split on ',' without trimming.
    /// Additional derives (e.g. Serialize,Deserialize,PartialEq)
    #[arg(long, value_delimiter = ',')]
    pub derives: Vec<String>,

    // Raw `KEY=VALUE` strings; turned into a map by `Args::parse_type_overrides`.
    /// Type overrides (e.g. jsonb=MyJsonType,uuid=MyUuid)
    #[arg(long, value_delimiter = ',')]
    pub type_overrides: Vec<String>,

    /// Generate everything into a single file instead of one file per table
    #[arg(long)]
    pub single_file: bool,

    // `None` when the flag is absent — presumably meaning "all tables";
    // verify against the generator.
    /// Only generate for these tables (comma-separated)
    #[arg(long, value_delimiter = ',')]
    pub tables: Option<Vec<String>>,

    /// Also generate structs for SQL views
    #[arg(long)]
    pub views: bool,

    /// Print to stdout without writing files
    #[arg(long)]
    pub dry_run: bool,
}
44
/// Database backend, as detected from the connection URL's scheme prefix.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DatabaseKind {
    /// Selected by `postgres://` or `postgresql://` URLs.
    Postgres,
    /// Selected by `mysql://` URLs.
    Mysql,
    /// Selected by `sqlite:` URLs (including `sqlite://…` and `sqlite::memory:`).
    Sqlite,
}
51
52impl Args {
53    pub fn database_kind(&self) -> anyhow::Result<DatabaseKind> {
54        let url = &self.database_url;
55        if url.starts_with("postgres://") || url.starts_with("postgresql://") {
56            Ok(DatabaseKind::Postgres)
57        } else if url.starts_with("mysql://") {
58            Ok(DatabaseKind::Mysql)
59        } else if url.starts_with("sqlite://") || url.starts_with("sqlite:") {
60            Ok(DatabaseKind::Sqlite)
61        } else {
62            anyhow::bail!(
63                "Cannot detect database type from URL. Expected postgres://, mysql://, or sqlite:// prefix."
64            )
65        }
66    }
67
68    pub fn parse_type_overrides(&self) -> HashMap<String, String> {
69        self.type_overrides
70            .iter()
71            .filter_map(|s| {
72                let (k, v) = s.split_once('=')?;
73                Some((k.to_string(), v.to_string()))
74            })
75            .collect()
76    }
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `Args` for `url` with every other option at a neutral value.
    fn make_args(url: &str) -> Args {
        Args {
            database_url: url.to_owned(),
            output_dir: PathBuf::from("out"),
            schemas: vec![String::from("public")],
            derives: Vec::new(),
            type_overrides: Vec::new(),
            single_file: false,
            tables: None,
            views: false,
            dry_run: false,
        }
    }

    /// Builds a Postgres `Args` that differs from the baseline only in `type_overrides`.
    fn make_args_with_overrides(overrides: Vec<&str>) -> Args {
        Args {
            type_overrides: overrides.into_iter().map(String::from).collect(),
            ..make_args("postgres://localhost/db")
        }
    }

    /// Asserts that `url` is accepted and detected as `expected`.
    fn assert_kind(url: &str, expected: DatabaseKind) {
        let detected = make_args(url)
            .database_kind()
            .expect("URL should be accepted");
        assert_eq!(detected, expected);
    }

    /// Asserts that `database_kind` rejects `url`.
    fn assert_rejected(url: &str) {
        let outcome = make_args(url).database_kind();
        assert!(outcome.is_err(), "expected {:?} to be rejected", url);
    }

    // ---------- database_kind ----------

    #[test]
    fn test_postgres_url() {
        assert_kind("postgres://localhost/db", DatabaseKind::Postgres);
    }

    #[test]
    fn test_postgresql_url() {
        assert_kind("postgresql://localhost/db", DatabaseKind::Postgres);
    }

    #[test]
    fn test_postgres_full_url() {
        assert_kind("postgres://user:pass@host:5432/db", DatabaseKind::Postgres);
    }

    #[test]
    fn test_mysql_url() {
        assert_kind("mysql://localhost/db", DatabaseKind::Mysql);
    }

    #[test]
    fn test_mysql_full_url() {
        assert_kind("mysql://user:pass@host:3306/db", DatabaseKind::Mysql);
    }

    #[test]
    fn test_sqlite_url() {
        assert_kind("sqlite://path.db", DatabaseKind::Sqlite);
    }

    #[test]
    fn test_sqlite_colon() {
        assert_kind("sqlite:path.db", DatabaseKind::Sqlite);
    }

    #[test]
    fn test_sqlite_memory() {
        assert_kind("sqlite::memory:", DatabaseKind::Sqlite);
    }

    #[test]
    fn test_http_url_fails() {
        assert_rejected("http://example.com");
    }

    #[test]
    fn test_empty_url_fails() {
        assert_rejected("");
    }

    #[test]
    fn test_mongo_url_fails() {
        assert_rejected("mongo://localhost");
    }

    #[test]
    fn test_uppercase_postgres_fails() {
        assert_rejected("POSTGRES://localhost");
    }

    // ---------- parse_type_overrides ----------

    #[test]
    fn test_overrides_empty() {
        let overrides = make_args_with_overrides(Vec::new()).parse_type_overrides();
        assert!(overrides.is_empty());
    }

    #[test]
    fn test_overrides_single() {
        let overrides = make_args_with_overrides(vec!["jsonb=MyJson"]).parse_type_overrides();
        assert_eq!(overrides.get("jsonb").map(String::as_str), Some("MyJson"));
    }

    #[test]
    fn test_overrides_multiple() {
        let overrides =
            make_args_with_overrides(vec!["jsonb=MyJson", "uuid=MyUuid"]).parse_type_overrides();
        assert_eq!(overrides.len(), 2);
        assert_eq!(overrides.get("jsonb").map(String::as_str), Some("MyJson"));
        assert_eq!(overrides.get("uuid").map(String::as_str), Some("MyUuid"));
    }

    #[test]
    fn test_overrides_malformed_skipped() {
        let overrides = make_args_with_overrides(vec!["noequals"]).parse_type_overrides();
        assert!(overrides.is_empty());
    }

    #[test]
    fn test_overrides_mixed_valid_invalid() {
        let overrides = make_args_with_overrides(vec!["good=val", "bad"]).parse_type_overrides();
        assert_eq!(overrides.len(), 1);
        assert_eq!(overrides.get("good").map(String::as_str), Some("val"));
    }

    #[test]
    fn test_overrides_equals_in_value() {
        let overrides = make_args_with_overrides(vec!["key=val=ue"]).parse_type_overrides();
        assert_eq!(overrides.get("key").map(String::as_str), Some("val=ue"));
    }

    #[test]
    fn test_overrides_empty_key() {
        let overrides = make_args_with_overrides(vec!["=value"]).parse_type_overrides();
        assert_eq!(overrides.get("").map(String::as_str), Some("value"));
    }

    #[test]
    fn test_overrides_empty_value() {
        let overrides = make_args_with_overrides(vec!["key="]).parse_type_overrides();
        assert_eq!(overrides.get("key").map(String::as_str), Some(""));
    }
}