1use std::path::PathBuf;
2
3use crate::config::{
4 Credentials, DatabaseConfig, Dialect, Driver, Extension, Filter, PostgresCreds,
5};
6use crate::error::CliError;
7
8#[derive(clap::Args, Debug, Clone, Default)]
14pub struct ConnectionOverrides {
15 #[arg(long)]
17 pub url: Option<String>,
18
19 #[arg(long)]
21 pub host: Option<String>,
22
23 #[arg(long)]
25 pub port: Option<u16>,
26
27 #[arg(long)]
29 pub user: Option<String>,
30
31 #[arg(long)]
33 pub password: Option<String>,
34
35 #[arg(long)]
37 pub database: Option<String>,
38
39 #[arg(long)]
41 pub ssl: Option<String>,
42
43 #[arg(long = "authToken", alias = "auth-token")]
45 pub auth_token: Option<String>,
46}
47
48#[derive(clap::Args, Debug, Clone, Default)]
53pub struct FilterArgs {
54 #[arg(long = "tablesFilter", value_delimiter = ',')]
56 pub tables_filter: Option<Vec<String>>,
57
58 #[arg(long = "schemaFilters", alias = "schemaFilter", value_delimiter = ',')]
60 pub schema_filters: Option<Vec<String>>,
61
62 #[arg(long = "extensionsFilters", value_delimiter = ',', value_parser = parse_extension_arg)]
64 pub extensions_filters: Option<Vec<Extension>>,
65}
66
67fn parse_extension_arg(s: &str) -> Result<Extension, String> {
68 s.parse()
69}
70
71impl ConnectionOverrides {
72 #[must_use]
73 pub const fn has_any(&self) -> bool {
74 self.url.is_some()
75 || self.host.is_some()
76 || self.port.is_some()
77 || self.user.is_some()
78 || self.password.is_some()
79 || self.database.is_some()
80 || self.ssl.is_some()
81 || self.auth_token.is_some()
82 }
83}
84
85#[must_use]
86pub fn resolve_dialect(db: &DatabaseConfig, override_dialect: Option<Dialect>) -> Dialect {
87 override_dialect.unwrap_or(db.dialect)
88}
89
90pub fn resolve_driver(
99 db: &DatabaseConfig,
100 dialect: Dialect,
101 driver_override: Option<Driver>,
102) -> Result<Option<Driver>, CliError> {
103 let driver = driver_override.or(db.driver);
104 if let Some(driver) = driver
105 && !driver.is_valid_for(dialect)
106 {
107 return Err(CliError::Other(format!(
108 "driver '{driver}' invalid for {dialect} dialect"
109 )));
110 }
111 Ok(driver)
112}
113
114pub fn resolve_credentials(
123 db: &DatabaseConfig,
124 dialect: Dialect,
125 overrides: &ConnectionOverrides,
126) -> Result<Option<Credentials>, CliError> {
127 if !overrides.has_any() {
128 if dialect != db.dialect {
129 return Err(CliError::Other(format!(
130 "--dialect={dialect} requires matching credential flags (--url/--host/--database/etc)"
131 )));
132 }
133 return db.credentials().map_err(Into::into);
134 }
135
136 let creds = match dialect {
137 Dialect::Sqlite => {
138 if overrides.host.is_some()
139 || overrides.port.is_some()
140 || overrides.user.is_some()
141 || overrides.password.is_some()
142 || overrides.database.is_some()
143 || overrides.ssl.is_some()
144 || overrides.auth_token.is_some()
145 {
146 return Err(CliError::Other(
147 "sqlite credentials only support --url for local database path".into(),
148 ));
149 }
150
151 let path = overrides
152 .url
153 .clone()
154 .ok_or_else(|| CliError::Other("sqlite requires --url".into()))?;
155
156 Credentials::Sqlite {
157 path: path.into_boxed_str(),
158 }
159 }
160 Dialect::Turso => {
161 if overrides.host.is_some()
162 || overrides.port.is_some()
163 || overrides.user.is_some()
164 || overrides.password.is_some()
165 || overrides.database.is_some()
166 || overrides.ssl.is_some()
167 {
168 return Err(CliError::Other(
169 "turso credentials support --url and optional --authToken".into(),
170 ));
171 }
172
173 let url = overrides
174 .url
175 .clone()
176 .ok_or_else(|| CliError::Other("turso requires --url".into()))?;
177
178 Credentials::Turso {
179 url: url.into_boxed_str(),
180 auth_token: overrides.auth_token.clone().map(String::into_boxed_str),
181 }
182 }
183 Dialect::Postgresql => {
184 if overrides.auth_token.is_some() {
185 return Err(CliError::Other(
186 "postgresql does not support --authToken (use --password or --url)".into(),
187 ));
188 }
189
190 if let Some(url) = overrides.url.clone() {
191 if overrides.host.is_some()
192 || overrides.port.is_some()
193 || overrides.user.is_some()
194 || overrides.password.is_some()
195 || overrides.database.is_some()
196 || overrides.ssl.is_some()
197 {
198 return Err(CliError::Other(
199 "postgresql credentials: use either --url OR --host/--database[/--port/...], not both"
200 .into(),
201 ));
202 }
203
204 Credentials::Postgres(PostgresCreds::Url(url.into_boxed_str()))
205 } else {
206 let host = overrides.host.clone().ok_or_else(|| {
207 CliError::Other("postgresql host credentials require --host".into())
208 })?;
209 let database = overrides.database.clone().ok_or_else(|| {
210 CliError::Other("postgresql host credentials require --database".into())
211 })?;
212
213 Credentials::Postgres(PostgresCreds::Host {
214 host: host.into_boxed_str(),
215 port: overrides.port.unwrap_or(5432),
216 user: overrides.user.clone().map(String::into_boxed_str),
217 password: overrides.password.clone().map(String::into_boxed_str),
218 database: database.into_boxed_str(),
219 ssl: parse_ssl_override(overrides.ssl.as_deref())?.unwrap_or(false),
220 })
221 }
222 }
223 };
224
225 Ok(Some(creds))
226}
227
228fn parse_ssl_override(ssl: Option<&str>) -> Result<Option<bool>, CliError> {
229 let Some(raw) = ssl else {
230 return Ok(None);
231 };
232
233 let value = raw.trim().to_ascii_lowercase();
234 let enabled = match value.as_str() {
235 "true" | "1" | "yes" | "on" | "require" | "allow" | "prefer" | "verify-full"
236 | "verify-ca" => true,
237 "false" | "0" | "no" | "off" | "disable" => false,
238 _ => {
239 return Err(CliError::Other(format!(
240 "invalid --ssl value '{raw}'; expected one of: true,false,require,allow,prefer,verify-full,verify-ca,disable"
241 )));
242 }
243 };
244
245 Ok(Some(enabled))
246}
247
248#[must_use]
249pub fn resolve_filter_list(cli: Option<&[String]>, config: Option<&Filter>) -> Option<Vec<String>> {
250 if let Some(values) = cli {
251 if values.is_empty() {
252 return None;
253 }
254 return Some(values.to_vec());
255 }
256
257 config.map(|f| f.iter().map(ToOwned::to_owned).collect())
258}
259
260#[must_use]
261pub fn resolve_schema_filters(
262 dialect: Dialect,
263 cli: Option<&[String]>,
264 config: Option<&Filter>,
265) -> Option<Vec<String>> {
266 let resolved = resolve_filter_list(cli, config);
267 if resolved.is_some() {
268 return resolved;
269 }
270
271 if matches!(dialect, Dialect::Postgresql) {
272 Some(vec!["public".to_string()])
273 } else {
274 None
275 }
276}
277
278#[must_use]
279pub fn resolve_extensions_filter(
280 cli: Option<&[Extension]>,
281 config: Option<&[Extension]>,
282) -> Option<Vec<Extension>> {
283 if let Some(values) = cli {
284 if values.is_empty() {
285 return None;
286 }
287 return Some(values.to_vec());
288 }
289
290 config.map(<[Extension]>::to_vec)
291}
292
293#[must_use]
294pub fn resolve_schema_display(db: &DatabaseConfig, schema_override: Option<&[String]>) -> String {
295 match schema_override {
296 Some(v) if !v.is_empty() => v.join(", "),
297 _ => db.schema_display(),
298 }
299}
300
301pub fn resolve_schema_files(
310 db: &DatabaseConfig,
311 schema_override: Option<&[String]>,
312) -> Result<Vec<PathBuf>, CliError> {
313 let Some(schema_patterns) = schema_override else {
314 return db.schema_files().map_err(Into::into);
315 };
316
317 if schema_patterns.is_empty() {
318 return Err(CliError::NoSchemaFiles("(empty schema override)".into()));
319 }
320
321 let mut files = Vec::new();
322
323 for pattern in schema_patterns {
324 let pat = pattern.trim();
325 let is_glob = pat.contains('*') || pat.contains('?') || pat.contains('[');
326
327 if !is_glob {
328 let p = PathBuf::from(pat);
329 if p.exists() {
330 files.push(p);
331 continue;
332 }
333 }
334
335 let pat_norm = pat.replace('\\', "/");
336 let paths = glob::glob(&pat_norm)
337 .map_err(|e| CliError::Other(format!("invalid glob '{pat}': {e}")))?;
338 let matched: Vec<_> = paths.filter_map(Result::ok).collect();
339
340 if matched.is_empty() && !is_glob {
341 let p = PathBuf::from(&pat_norm);
342 if p.exists() {
343 files.push(p);
344 }
345 } else {
346 files.extend(matched);
347 }
348 }
349
350 files.retain(|p| p.is_file());
351 files.sort();
352 files.dedup();
353
354 if files.is_empty() {
355 return Err(CliError::NoSchemaFiles(
356 schema_patterns
357 .iter()
358 .map(std::string::String::as_str)
359 .collect::<Vec<_>>()
360 .join(", "),
361 ));
362 }
363
364 Ok(files)
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use std::path::PathBuf;
    use tempfile::TempDir;

    // Minimal single-database PostgreSQL config shared by the credential tests.
    const POSTGRES_CONFIG: &str = r#"
dialect = "postgresql"
"#;

    // Writes `config_toml` as `drizzle.config.toml` in a fresh temp dir and
    // returns the default database. The `TempDir` is returned as well because
    // dropping it deletes the directory; callers keep it alive for the test.
    fn load_db(config_toml: &str) -> (TempDir, DatabaseConfig) {
        let dir = TempDir::new().expect("temp dir");
        let config_path = dir.path().join("drizzle.config.toml");
        std::fs::write(&config_path, config_toml).expect("write config");
        let loaded = Config::load_from(&config_path).expect("load config");
        let db = loaded.default_database().expect("default db").clone();
        (dir, db)
    }

    // Host-mode PostgreSQL overrides (`--host localhost --database db`) with an
    // optional `--ssl` value.
    fn postgres_host_overrides(ssl: Option<&str>) -> ConnectionOverrides {
        ConnectionOverrides {
            host: Some("localhost".to_string()),
            database: Some("db".to_string()),
            ssl: ssl.map(String::from),
            ..Default::default()
        }
    }

    // Resolves host-mode PostgreSQL credentials and extracts their `ssl` flag.
    fn resolved_postgres_ssl(db: &DatabaseConfig, overrides: &ConnectionOverrides) -> bool {
        let creds = resolve_credentials(db, Dialect::Postgresql, overrides)
            .expect("resolve")
            .expect("creds");
        match creds {
            Credentials::Postgres(PostgresCreds::Host { ssl, .. }) => ssl,
            _ => panic!("expected postgres host creds"),
        }
    }

    #[test]
    fn resolve_filter_list_prefers_cli_values() {
        let from_config = Filter::Many(vec!["from_config".to_string()]);
        let from_cli = vec!["from_cli".to_string()];

        assert_eq!(
            resolve_filter_list(Some(&from_cli), Some(&from_config)),
            Some(vec!["from_cli".to_string()])
        );
    }

    #[test]
    fn resolve_filter_list_uses_config_when_cli_missing() {
        let from_config = Filter::Many(vec!["public".to_string(), "dev".to_string()]);

        assert_eq!(
            resolve_filter_list(None, Some(&from_config)),
            Some(vec!["public".to_string(), "dev".to_string()])
        );
    }

    #[test]
    fn resolve_schema_filters_defaults_to_public_for_postgres() {
        assert_eq!(
            resolve_schema_filters(Dialect::Postgresql, None, None),
            Some(vec!["public".to_string()])
        );
    }

    #[test]
    fn resolve_schema_filters_does_not_default_for_sqlite() {
        assert_eq!(resolve_schema_filters(Dialect::Sqlite, None, None), None);
    }

    #[test]
    fn resolve_extensions_filter_prefers_cli_values() {
        let from_cli = vec![Extension::Postgis];
        let from_config: Vec<Extension> = Vec::new();

        assert_eq!(
            resolve_extensions_filter(Some(&from_cli), Some(&from_config)),
            Some(vec![Extension::Postgis])
        );
    }

    #[test]
    fn resolve_driver_rejects_invalid_override() {
        let (_dir, db) = load_db(
            r#"
dialect = "sqlite"
schema = "src/schema.rs"
"#,
        );

        let err = resolve_driver(&db, Dialect::Sqlite, Some(Driver::TokioPostgres))
            .expect_err("driver should be rejected");
        assert_eq!(
            err.to_string(),
            "driver 'tokio-postgres' invalid for sqlite dialect"
        );
    }

    #[test]
    fn resolve_credentials_requires_overrides_for_dialect_switch() {
        let (_dir, db) = load_db(
            r#"
dialect = "sqlite"
[dbCredentials]
url = "./dev.db"
"#,
        );

        let err = resolve_credentials(&db, Dialect::Postgresql, &ConnectionOverrides::default())
            .expect_err("dialect switch should require explicit credentials");
        assert_eq!(
            err.to_string(),
            "--dialect=postgresql requires matching credential flags (--url/--host/--database/etc)"
        );
    }

    #[test]
    fn resolve_credentials_sqlite_rejects_host_fields() {
        let (_dir, db) = load_db(
            r#"
dialect = "sqlite"
"#,
        );
        let host_only = ConnectionOverrides {
            host: Some("localhost".to_string()),
            ..Default::default()
        };

        let err = resolve_credentials(&db, Dialect::Sqlite, &host_only)
            .expect_err("sqlite should reject host-style credentials");
        assert_eq!(
            err.to_string(),
            "sqlite credentials only support --url for local database path"
        );
    }

    #[test]
    fn resolve_credentials_postgres_rejects_mixed_url_and_host_fields() {
        let (_dir, db) = load_db(POSTGRES_CONFIG);
        let mixed = ConnectionOverrides {
            url: Some("postgres://u:p@localhost:5432/db".to_string()),
            ..postgres_host_overrides(None)
        };

        let err = resolve_credentials(&db, Dialect::Postgresql, &mixed)
            .expect_err("postgres should reject mixed credentials");
        assert_eq!(
            err.to_string(),
            "postgresql credentials: use either --url OR --host/--database[/--port/...], not both"
        );
    }

    #[test]
    fn resolve_credentials_postgres_requires_database_for_host_mode() {
        let (_dir, db) = load_db(POSTGRES_CONFIG);
        let host_only = ConnectionOverrides {
            host: Some("localhost".to_string()),
            ..Default::default()
        };

        let err = resolve_credentials(&db, Dialect::Postgresql, &host_only)
            .expect_err("postgres host credentials require database");
        assert_eq!(
            err.to_string(),
            "postgresql host credentials require --database"
        );
    }

    #[test]
    fn resolve_credentials_turso_accepts_url_with_optional_token() {
        let (_dir, db) = load_db(
            r#"
dialect = "turso"
"#,
        );
        let overrides = ConnectionOverrides {
            url: Some("libsql://example.turso.io".to_string()),
            auth_token: Some("secret".to_string()),
            ..Default::default()
        };

        let creds = resolve_credentials(&db, Dialect::Turso, &overrides)
            .expect("resolve creds")
            .expect("some creds");
        let Credentials::Turso { url, auth_token } = creds else {
            panic!("expected turso credentials");
        };
        assert_eq!(url.as_ref(), "libsql://example.turso.io");
        assert_eq!(auth_token.as_deref(), Some("secret"));
    }

    #[test]
    fn resolve_credentials_postgres_host_mode_accepts_ssl_modes() {
        let (_dir, db) = load_db(POSTGRES_CONFIG);

        assert!(resolved_postgres_ssl(
            &db,
            &postgres_host_overrides(Some("require"))
        ));
        assert!(!resolved_postgres_ssl(
            &db,
            &postgres_host_overrides(Some("disable"))
        ));
    }

    #[test]
    fn resolve_credentials_postgres_host_mode_rejects_invalid_ssl_value() {
        let (_dir, db) = load_db(POSTGRES_CONFIG);

        let err = resolve_credentials(
            &db,
            Dialect::Postgresql,
            &postgres_host_overrides(Some("maybe")),
        )
        .expect_err("invalid ssl should fail");
        assert_eq!(
            err.to_string(),
            "invalid --ssl value 'maybe'; expected one of: true,false,require,allow,prefer,verify-full,verify-ca,disable"
        );
    }

    #[test]
    fn resolve_schema_filters_defaults_to_public_in_multi_db_postgres() {
        let multi_db_toml = r#"
[databases.pg]
dialect = "postgresql"

[databases.pg.dbCredentials]
url = "postgres://localhost/db"

[databases.sqlite]
dialect = "sqlite"

[databases.sqlite.dbCredentials]
url = "./dev.db"
"#;
        let dir = TempDir::new().expect("temp dir");
        let config_path = dir.path().join("drizzle.config.toml");
        std::fs::write(&config_path, multi_db_toml).expect("write config");

        let config = Config::load_from(&config_path).expect("load config");
        let pg = config.database(Some("pg")).expect("pg db");

        assert_eq!(
            resolve_schema_filters(Dialect::Postgresql, None, pg.schema_filter.as_ref()),
            Some(vec!["public".to_string()])
        );
    }

    #[test]
    fn resolve_schema_files_uses_override_glob() {
        let (dir, db) = load_db(
            r#"
dialect = "sqlite"
schema = "src/schema.rs"
"#,
        );

        for (name, contents) in [("a.schema.rs", "pub struct A;"), ("b.schema.rs", "pub struct B;")] {
            std::fs::write(dir.path().join(name), contents).expect("write schema file");
        }

        let pattern = format!("{}/*.schema.rs", dir.path().display()).replace('\\', "/");
        let override_patterns = vec![pattern];
        let files: Vec<PathBuf> =
            resolve_schema_files(&db, Some(&override_patterns)).expect("resolve files");

        assert_eq!(files.len(), 2);
        for expected in ["a.schema.rs", "b.schema.rs"] {
            assert!(
                files.iter().any(|p| p.ends_with(expected)),
                "missing {expected}"
            );
        }
    }
}