Skip to main content

drizzle_cli/commands/
overrides.rs

1use std::path::PathBuf;
2
3use crate::config::{
4    Credentials, DatabaseConfig, Dialect, Driver, Extension, Filter, PostgresCreds,
5};
6use crate::error::CliError;
7
8/// CLI overrides for database connection credentials.
9///
10/// Each field maps to a top-level `--{name}` flag. Embedded into commands
11/// that talk to a live database via `#[command(flatten)]` so the connection
12/// surface is defined once and reused across `push`, `introspect`/`pull`.
13#[derive(clap::Args, Debug, Clone, Default)]
14pub struct ConnectionOverrides {
15    /// Database connection URL
16    #[arg(long)]
17    pub url: Option<String>,
18
19    /// Database host
20    #[arg(long)]
21    pub host: Option<String>,
22
23    /// Database port
24    #[arg(long)]
25    pub port: Option<u16>,
26
27    /// Database user
28    #[arg(long)]
29    pub user: Option<String>,
30
31    /// Database password
32    #[arg(long)]
33    pub password: Option<String>,
34
35    /// Database name
36    #[arg(long)]
37    pub database: Option<String>,
38
39    /// SSL mode (true/false or require/prefer/verify-full/disable)
40    #[arg(long)]
41    pub ssl: Option<String>,
42
43    /// Turso auth token
44    #[arg(long = "authToken", alias = "auth-token")]
45    pub auth_token: Option<String>,
46}
47
/// CLI overrides for snapshot filters (tables/schemas/extensions).
///
/// Embedded into commands that read or push a snapshot via
/// `#[command(flatten)]`. Single source of truth for the filter surface.
// Flag names are camelCase (`--tablesFilter`, `--schemaFilters`,
// `--extensionsFilters`). Every list flag is split on ',' via
// `value_delimiter`; each field is `Option<Vec<_>>`, so "flag absent" and
// "flag present" can be told apart by the resolvers.
#[derive(clap::Args, Debug, Clone, Default)]
pub struct FilterArgs {
    /// Table name filters
    #[arg(long = "tablesFilter", value_delimiter = ',')]
    pub tables_filter: Option<Vec<String>>,

    /// Schema name filters
    // The singular spelling `--schemaFilter` is accepted as an alias.
    #[arg(long = "schemaFilters", alias = "schemaFilter", value_delimiter = ',')]
    pub schema_filters: Option<Vec<String>>,

    /// Extension filters (e.g. postgis)
    // Each comma-separated item is converted by `parse_extension_arg`; a value
    // that fails to parse is reported by clap as an invalid argument.
    #[arg(long = "extensionsFilters", value_delimiter = ',', value_parser = parse_extension_arg)]
    pub extensions_filters: Option<Vec<Extension>>,
}
66
67fn parse_extension_arg(s: &str) -> Result<Extension, String> {
68    s.parse()
69}
70
71impl ConnectionOverrides {
72    #[must_use]
73    pub const fn has_any(&self) -> bool {
74        self.url.is_some()
75            || self.host.is_some()
76            || self.port.is_some()
77            || self.user.is_some()
78            || self.password.is_some()
79            || self.database.is_some()
80            || self.ssl.is_some()
81            || self.auth_token.is_some()
82    }
83}
84
85#[must_use]
86pub fn resolve_dialect(db: &DatabaseConfig, override_dialect: Option<Dialect>) -> Dialect {
87    override_dialect.unwrap_or(db.dialect)
88}
89
90/// Resolve the effective driver by applying CLI overrides on top of the
91/// config's value, validating that the chosen driver is compatible with the
92/// resolved dialect.
93///
94/// # Errors
95///
96/// Returns [`CliError`] if the override driver is set but is not valid for the
97/// given `dialect` (e.g. `rusqlite` selected with `postgresql`).
98pub fn resolve_driver(
99    db: &DatabaseConfig,
100    dialect: Dialect,
101    driver_override: Option<Driver>,
102) -> Result<Option<Driver>, CliError> {
103    let driver = driver_override.or(db.driver);
104    if let Some(driver) = driver
105        && !driver.is_valid_for(dialect)
106    {
107        return Err(CliError::Other(format!(
108            "driver '{driver}' invalid for {dialect} dialect"
109        )));
110    }
111    Ok(driver)
112}
113
114/// Resolve database credentials from CLI overrides, falling back to the
115/// configured value when no override is set.
116///
117/// # Errors
118///
119/// Returns [`CliError`] if an override is provided but is incompatible with
120/// the resolved dialect, or if resolving a config-provided credentials block
121/// fails (e.g. missing environment variables).
122pub fn resolve_credentials(
123    db: &DatabaseConfig,
124    dialect: Dialect,
125    overrides: &ConnectionOverrides,
126) -> Result<Option<Credentials>, CliError> {
127    if !overrides.has_any() {
128        if dialect != db.dialect {
129            return Err(CliError::Other(format!(
130                "--dialect={dialect} requires matching credential flags (--url/--host/--database/etc)"
131            )));
132        }
133        return db.credentials().map_err(Into::into);
134    }
135
136    let creds = match dialect {
137        Dialect::Sqlite => {
138            if overrides.host.is_some()
139                || overrides.port.is_some()
140                || overrides.user.is_some()
141                || overrides.password.is_some()
142                || overrides.database.is_some()
143                || overrides.ssl.is_some()
144                || overrides.auth_token.is_some()
145            {
146                return Err(CliError::Other(
147                    "sqlite credentials only support --url for local database path".into(),
148                ));
149            }
150
151            let path = overrides
152                .url
153                .clone()
154                .ok_or_else(|| CliError::Other("sqlite requires --url".into()))?;
155
156            Credentials::Sqlite {
157                path: path.into_boxed_str(),
158            }
159        }
160        Dialect::Turso => {
161            if overrides.host.is_some()
162                || overrides.port.is_some()
163                || overrides.user.is_some()
164                || overrides.password.is_some()
165                || overrides.database.is_some()
166                || overrides.ssl.is_some()
167            {
168                return Err(CliError::Other(
169                    "turso credentials support --url and optional --authToken".into(),
170                ));
171            }
172
173            let url = overrides
174                .url
175                .clone()
176                .ok_or_else(|| CliError::Other("turso requires --url".into()))?;
177
178            Credentials::Turso {
179                url: url.into_boxed_str(),
180                auth_token: overrides.auth_token.clone().map(String::into_boxed_str),
181            }
182        }
183        Dialect::Postgresql => {
184            if overrides.auth_token.is_some() {
185                return Err(CliError::Other(
186                    "postgresql does not support --authToken (use --password or --url)".into(),
187                ));
188            }
189
190            if let Some(url) = overrides.url.clone() {
191                if overrides.host.is_some()
192                    || overrides.port.is_some()
193                    || overrides.user.is_some()
194                    || overrides.password.is_some()
195                    || overrides.database.is_some()
196                    || overrides.ssl.is_some()
197                {
198                    return Err(CliError::Other(
199                        "postgresql credentials: use either --url OR --host/--database[/--port/...], not both"
200                            .into(),
201                    ));
202                }
203
204                Credentials::Postgres(PostgresCreds::Url(url.into_boxed_str()))
205            } else {
206                let host = overrides.host.clone().ok_or_else(|| {
207                    CliError::Other("postgresql host credentials require --host".into())
208                })?;
209                let database = overrides.database.clone().ok_or_else(|| {
210                    CliError::Other("postgresql host credentials require --database".into())
211                })?;
212
213                Credentials::Postgres(PostgresCreds::Host {
214                    host: host.into_boxed_str(),
215                    port: overrides.port.unwrap_or(5432),
216                    user: overrides.user.clone().map(String::into_boxed_str),
217                    password: overrides.password.clone().map(String::into_boxed_str),
218                    database: database.into_boxed_str(),
219                    ssl: parse_ssl_override(overrides.ssl.as_deref())?.unwrap_or(false),
220                })
221            }
222        }
223    };
224
225    Ok(Some(creds))
226}
227
228fn parse_ssl_override(ssl: Option<&str>) -> Result<Option<bool>, CliError> {
229    let Some(raw) = ssl else {
230        return Ok(None);
231    };
232
233    let value = raw.trim().to_ascii_lowercase();
234    let enabled = match value.as_str() {
235        "true" | "1" | "yes" | "on" | "require" | "allow" | "prefer" | "verify-full"
236        | "verify-ca" => true,
237        "false" | "0" | "no" | "off" | "disable" => false,
238        _ => {
239            return Err(CliError::Other(format!(
240                "invalid --ssl value '{raw}'; expected one of: true,false,require,allow,prefer,verify-full,verify-ca,disable"
241            )));
242        }
243    };
244
245    Ok(Some(enabled))
246}
247
248#[must_use]
249pub fn resolve_filter_list(cli: Option<&[String]>, config: Option<&Filter>) -> Option<Vec<String>> {
250    if let Some(values) = cli {
251        if values.is_empty() {
252            return None;
253        }
254        return Some(values.to_vec());
255    }
256
257    config.map(|f| f.iter().map(ToOwned::to_owned).collect())
258}
259
260#[must_use]
261pub fn resolve_schema_filters(
262    dialect: Dialect,
263    cli: Option<&[String]>,
264    config: Option<&Filter>,
265) -> Option<Vec<String>> {
266    let resolved = resolve_filter_list(cli, config);
267    if resolved.is_some() {
268        return resolved;
269    }
270
271    if matches!(dialect, Dialect::Postgresql) {
272        Some(vec!["public".to_string()])
273    } else {
274        None
275    }
276}
277
278#[must_use]
279pub fn resolve_extensions_filter(
280    cli: Option<&[Extension]>,
281    config: Option<&[Extension]>,
282) -> Option<Vec<Extension>> {
283    if let Some(values) = cli {
284        if values.is_empty() {
285            return None;
286        }
287        return Some(values.to_vec());
288    }
289
290    config.map(<[Extension]>::to_vec)
291}
292
293#[must_use]
294pub fn resolve_schema_display(db: &DatabaseConfig, schema_override: Option<&[String]>) -> String {
295    match schema_override {
296        Some(v) if !v.is_empty() => v.join(", "),
297        _ => db.schema_display(),
298    }
299}
300
301/// Resolve the schema file paths the current command will operate on, using
302/// the CLI override if provided or the configured value otherwise.
303///
304/// # Errors
305///
306/// Returns [`CliError`] if a non-empty override resolves to zero files, if a
307/// glob pattern is invalid, or if expanding the configured schema patterns
308/// fails.
309pub fn resolve_schema_files(
310    db: &DatabaseConfig,
311    schema_override: Option<&[String]>,
312) -> Result<Vec<PathBuf>, CliError> {
313    let Some(schema_patterns) = schema_override else {
314        return db.schema_files().map_err(Into::into);
315    };
316
317    if schema_patterns.is_empty() {
318        return Err(CliError::NoSchemaFiles("(empty schema override)".into()));
319    }
320
321    let mut files = Vec::new();
322
323    for pattern in schema_patterns {
324        let pat = pattern.trim();
325        let is_glob = pat.contains('*') || pat.contains('?') || pat.contains('[');
326
327        if !is_glob {
328            let p = PathBuf::from(pat);
329            if p.exists() {
330                files.push(p);
331                continue;
332            }
333        }
334
335        let pat_norm = pat.replace('\\', "/");
336        let paths = glob::glob(&pat_norm)
337            .map_err(|e| CliError::Other(format!("invalid glob '{pat}': {e}")))?;
338        let matched: Vec<_> = paths.filter_map(Result::ok).collect();
339
340        if matched.is_empty() && !is_glob {
341            let p = PathBuf::from(&pat_norm);
342            if p.exists() {
343                files.push(p);
344            }
345        } else {
346            files.extend(matched);
347        }
348    }
349
350    files.retain(|p| p.is_file());
351    files.sort();
352    files.dedup();
353
354    if files.is_empty() {
355        return Err(CliError::NoSchemaFiles(
356            schema_patterns
357                .iter()
358                .map(std::string::String::as_str)
359                .collect::<Vec<_>>()
360                .join(", "),
361        ));
362    }
363
364    Ok(files)
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use std::path::PathBuf;
    use tempfile::TempDir;

    /// Write `config_toml` to `drizzle.config.toml` in a fresh temp dir and
    /// return the default database entry. The `TempDir` is returned so the
    /// caller keeps the directory alive for the duration of the test.
    fn load_db(config_toml: &str) -> (TempDir, DatabaseConfig) {
        let dir = TempDir::new().expect("temp dir");
        let path = dir.path().join("drizzle.config.toml");
        std::fs::write(&path, config_toml).expect("write config");
        let config = Config::load_from(&path).expect("load config");
        let db = config.default_database().expect("default db").clone();
        (dir, db)
    }

    #[test]
    fn resolve_filter_list_prefers_cli_values() {
        let config = Filter::Many(vec!["from_config".to_string()]);
        let cli = vec!["from_cli".to_string()];

        let resolved = resolve_filter_list(Some(&cli), Some(&config));
        assert_eq!(resolved, Some(vec!["from_cli".to_string()]));
    }

    #[test]
    fn resolve_filter_list_uses_config_when_cli_missing() {
        let config = Filter::Many(vec!["public".to_string(), "dev".to_string()]);
        let resolved = resolve_filter_list(None, Some(&config));
        assert_eq!(
            resolved,
            Some(vec!["public".to_string(), "dev".to_string()])
        );
    }

    #[test]
    fn resolve_filter_list_empty_cli_disables_config_filter() {
        let config = Filter::Many(vec!["from_config".to_string()]);
        let cli: Vec<String> = Vec::new();

        let resolved = resolve_filter_list(Some(&cli), Some(&config));
        assert_eq!(resolved, None);
    }

    #[test]
    fn resolve_schema_filters_defaults_to_public_for_postgres() {
        let resolved = resolve_schema_filters(Dialect::Postgresql, None, None);
        assert_eq!(resolved, Some(vec!["public".to_string()]));
    }

    #[test]
    fn resolve_schema_filters_does_not_default_for_sqlite() {
        let resolved = resolve_schema_filters(Dialect::Sqlite, None, None);
        assert_eq!(resolved, None);
    }

    #[test]
    fn resolve_extensions_filter_prefers_cli_values() {
        let cli = vec![Extension::Postgis];
        let config = vec![];

        let resolved = resolve_extensions_filter(Some(&cli), Some(&config));
        assert_eq!(resolved, Some(vec![Extension::Postgis]));
    }

    #[test]
    fn resolve_extensions_filter_falls_back_to_config() {
        let config = vec![Extension::Postgis];

        let resolved = resolve_extensions_filter(None, Some(&config));
        assert_eq!(resolved, Some(vec![Extension::Postgis]));
    }

    #[test]
    fn resolve_extensions_filter_empty_cli_disables_config() {
        let cli: Vec<Extension> = Vec::new();
        let config = vec![Extension::Postgis];

        let resolved = resolve_extensions_filter(Some(&cli), Some(&config));
        assert_eq!(resolved, None);
    }

    #[test]
    fn resolve_dialect_prefers_override() {
        let (_dir, db) = load_db(
            r#"
dialect = "sqlite"
"#,
        );

        assert!(matches!(resolve_dialect(&db, None), Dialect::Sqlite));
        assert!(matches!(
            resolve_dialect(&db, Some(Dialect::Postgresql)),
            Dialect::Postgresql
        ));
    }

    #[test]
    fn resolve_driver_rejects_invalid_override() {
        let (_dir, db) = load_db(
            r#"
dialect = "sqlite"
schema = "src/schema.rs"
"#,
        );

        let err = resolve_driver(&db, Dialect::Sqlite, Some(Driver::TokioPostgres))
            .expect_err("driver should be rejected");
        assert_eq!(
            err.to_string(),
            "driver 'tokio-postgres' invalid for sqlite dialect"
        );
    }

    #[test]
    fn resolve_credentials_requires_overrides_for_dialect_switch() {
        let (_dir, db) = load_db(
            r#"
dialect = "sqlite"
[dbCredentials]
url = "./dev.db"
"#,
        );

        let err = resolve_credentials(&db, Dialect::Postgresql, &ConnectionOverrides::default())
            .expect_err("dialect switch should require explicit credentials");
        assert_eq!(
            err.to_string(),
            "--dialect=postgresql requires matching credential flags (--url/--host/--database/etc)"
        );
    }

    #[test]
    fn resolve_credentials_sqlite_rejects_host_fields() {
        let (_dir, db) = load_db(
            r#"
dialect = "sqlite"
"#,
        );

        let overrides = ConnectionOverrides {
            host: Some("localhost".to_string()),
            ..Default::default()
        };

        let err = resolve_credentials(&db, Dialect::Sqlite, &overrides)
            .expect_err("sqlite should reject host-style credentials");
        assert_eq!(
            err.to_string(),
            "sqlite credentials only support --url for local database path"
        );
    }

    #[test]
    fn resolve_credentials_postgres_rejects_mixed_url_and_host_fields() {
        let (_dir, db) = load_db(
            r#"
dialect = "postgresql"
"#,
        );

        let overrides = ConnectionOverrides {
            url: Some("postgres://u:p@localhost:5432/db".to_string()),
            host: Some("localhost".to_string()),
            database: Some("db".to_string()),
            ..Default::default()
        };

        let err = resolve_credentials(&db, Dialect::Postgresql, &overrides)
            .expect_err("postgres should reject mixed credentials");
        assert_eq!(
            err.to_string(),
            "postgresql credentials: use either --url OR --host/--database[/--port/...], not both"
        );
    }

    #[test]
    fn resolve_credentials_postgres_rejects_auth_token() {
        let (_dir, db) = load_db(
            r#"
dialect = "postgresql"
"#,
        );

        let overrides = ConnectionOverrides {
            url: Some("postgres://localhost/db".to_string()),
            auth_token: Some("secret".to_string()),
            ..Default::default()
        };

        let err = resolve_credentials(&db, Dialect::Postgresql, &overrides)
            .expect_err("postgres should reject --authToken");
        assert_eq!(
            err.to_string(),
            "postgresql does not support --authToken (use --password or --url)"
        );
    }

    #[test]
    fn resolve_credentials_postgres_requires_database_for_host_mode() {
        let (_dir, db) = load_db(
            r#"
dialect = "postgresql"
"#,
        );

        let overrides = ConnectionOverrides {
            host: Some("localhost".to_string()),
            ..Default::default()
        };

        let err = resolve_credentials(&db, Dialect::Postgresql, &overrides)
            .expect_err("postgres host credentials require database");
        assert_eq!(
            err.to_string(),
            "postgresql host credentials require --database"
        );
    }

    #[test]
    fn resolve_credentials_postgres_host_mode_defaults_port_and_ssl() {
        let (_dir, db) = load_db(
            r#"
dialect = "postgresql"
"#,
        );

        let overrides = ConnectionOverrides {
            host: Some("localhost".to_string()),
            database: Some("db".to_string()),
            ..Default::default()
        };

        let creds = resolve_credentials(&db, Dialect::Postgresql, &overrides)
            .expect("resolve")
            .expect("creds");
        match creds {
            Credentials::Postgres(PostgresCreds::Host { port, ssl, .. }) => {
                assert_eq!(port, 5432);
                assert!(!ssl);
            }
            _ => panic!("expected postgres host creds"),
        }
    }

    #[test]
    fn resolve_credentials_turso_accepts_url_with_optional_token() {
        let (_dir, db) = load_db(
            r#"
dialect = "turso"
"#,
        );

        let overrides = ConnectionOverrides {
            url: Some("libsql://example.turso.io".to_string()),
            auth_token: Some("secret".to_string()),
            ..Default::default()
        };

        let creds = resolve_credentials(&db, Dialect::Turso, &overrides)
            .expect("resolve creds")
            .expect("some creds");

        match creds {
            Credentials::Turso { url, auth_token } => {
                assert_eq!(url.as_ref(), "libsql://example.turso.io");
                assert_eq!(auth_token.as_deref(), Some("secret"));
            }
            _ => panic!("expected turso credentials"),
        }
    }

    #[test]
    fn resolve_credentials_turso_requires_url() {
        let (_dir, db) = load_db(
            r#"
dialect = "turso"
"#,
        );

        let overrides = ConnectionOverrides {
            auth_token: Some("secret".to_string()),
            ..Default::default()
        };

        let err = resolve_credentials(&db, Dialect::Turso, &overrides)
            .expect_err("turso should require --url");
        assert_eq!(err.to_string(), "turso requires --url");
    }

    #[test]
    fn resolve_credentials_turso_rejects_host_fields() {
        let (_dir, db) = load_db(
            r#"
dialect = "turso"
"#,
        );

        let overrides = ConnectionOverrides {
            url: Some("libsql://example.turso.io".to_string()),
            host: Some("localhost".to_string()),
            ..Default::default()
        };

        let err = resolve_credentials(&db, Dialect::Turso, &overrides)
            .expect_err("turso should reject host-style credentials");
        assert_eq!(
            err.to_string(),
            "turso credentials support --url and optional --authToken"
        );
    }

    #[test]
    fn resolve_credentials_postgres_host_mode_accepts_ssl_modes() {
        let (_dir, db) = load_db(
            r#"
dialect = "postgresql"
"#,
        );

        let require_ssl = ConnectionOverrides {
            host: Some("localhost".to_string()),
            database: Some("db".to_string()),
            ssl: Some("require".to_string()),
            ..Default::default()
        };
        let creds = resolve_credentials(&db, Dialect::Postgresql, &require_ssl)
            .expect("resolve")
            .expect("creds");
        match creds {
            Credentials::Postgres(PostgresCreds::Host { ssl, .. }) => assert!(ssl),
            _ => panic!("expected postgres host creds"),
        }

        let disable_ssl = ConnectionOverrides {
            host: Some("localhost".to_string()),
            database: Some("db".to_string()),
            ssl: Some("disable".to_string()),
            ..Default::default()
        };
        let creds = resolve_credentials(&db, Dialect::Postgresql, &disable_ssl)
            .expect("resolve")
            .expect("creds");
        match creds {
            Credentials::Postgres(PostgresCreds::Host { ssl, .. }) => assert!(!ssl),
            _ => panic!("expected postgres host creds"),
        }
    }

    #[test]
    fn resolve_credentials_postgres_host_mode_rejects_invalid_ssl_value() {
        let (_dir, db) = load_db(
            r#"
dialect = "postgresql"
"#,
        );

        let overrides = ConnectionOverrides {
            host: Some("localhost".to_string()),
            database: Some("db".to_string()),
            ssl: Some("maybe".to_string()),
            ..Default::default()
        };

        let err = resolve_credentials(&db, Dialect::Postgresql, &overrides)
            .expect_err("invalid ssl should fail");
        assert_eq!(
            err.to_string(),
            "invalid --ssl value 'maybe'; expected one of: true,false,require,allow,prefer,verify-full,verify-ca,disable"
        );
    }

    #[test]
    fn resolve_schema_filters_defaults_to_public_in_multi_db_postgres() {
        let dir = TempDir::new().expect("temp dir");
        let path = dir.path().join("drizzle.config.toml");
        std::fs::write(
            &path,
            r#"
[databases.pg]
dialect = "postgresql"

[databases.pg.dbCredentials]
url = "postgres://localhost/db"

[databases.sqlite]
dialect = "sqlite"

[databases.sqlite.dbCredentials]
url = "./dev.db"
"#,
        )
        .expect("write config");

        let config = Config::load_from(&path).expect("load config");
        let db = config.database(Some("pg")).expect("pg db");

        let resolved = resolve_schema_filters(Dialect::Postgresql, None, db.schema_filter.as_ref());
        assert_eq!(resolved, Some(vec!["public".to_string()]));
    }

    #[test]
    fn resolve_schema_display_prefers_non_empty_override() {
        let (_dir, db) = load_db(
            r#"
dialect = "sqlite"
schema = "src/schema.rs"
"#,
        );

        let override_patterns = vec!["a.rs".to_string(), "b.rs".to_string()];
        assert_eq!(
            resolve_schema_display(&db, Some(&override_patterns)),
            "a.rs, b.rs"
        );
    }

    #[test]
    fn resolve_schema_files_rejects_empty_override() {
        let (_dir, db) = load_db(
            r#"
dialect = "sqlite"
schema = "src/schema.rs"
"#,
        );

        let empty: Vec<String> = Vec::new();
        let err = resolve_schema_files(&db, Some(&empty)).expect_err("empty override should fail");
        assert!(matches!(err, CliError::NoSchemaFiles(_)));
    }

    #[test]
    fn resolve_schema_files_uses_override_glob() {
        let (dir, db) = load_db(
            r#"
dialect = "sqlite"
schema = "src/schema.rs"
"#,
        );

        let a = dir.path().join("a.schema.rs");
        let b = dir.path().join("b.schema.rs");
        std::fs::write(&a, "pub struct A;").expect("write a");
        std::fs::write(&b, "pub struct B;").expect("write b");

        let pattern = format!("{}/*.schema.rs", dir.path().display()).replace('\\', "/");
        let override_patterns = vec![pattern];
        let files = resolve_schema_files(&db, Some(&override_patterns)).expect("resolve files");

        let paths: Vec<PathBuf> = files;
        assert_eq!(paths.len(), 2);
        assert!(paths.iter().any(|p| p.ends_with("a.schema.rs")));
        assert!(paths.iter().any(|p| p.ends_with("b.schema.rs")));
    }
}
662}