drizzle_cli/
config.rs

1//! Configuration for Drizzle CLI
2//!
3//! Handles loading `drizzle.config.toml` with type-safe credentials.
4//! Supports both single-database (legacy) and multi-database configurations.
5//!
6//! This configuration format is designed to be compatible with drizzle-kit
7//! so TypeScript users can use the same config expectations.
8
9use serde::Deserialize;
10use serde::de::{self, Deserializer, MapAccess, Visitor};
11use std::collections::HashMap;
12use std::path::{Path, PathBuf};
13
/// Default configuration file name (`drizzle.config.toml`).
pub const CONFIG_FILE: &str = "drizzle.config.toml";
15
16// ============================================================================
17// Casing Options (matching drizzle-kit)
18// ============================================================================
19
20/// Casing mode for generated code and SQL identifiers
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Deserialize)]
22pub enum Casing {
23    /// camelCase - e.g., "userId", "createdAt"
24    #[default]
25    #[serde(rename = "camelCase")]
26    CamelCase,
27    /// snake_case - e.g., "user_id", "created_at"
28    #[serde(rename = "snake_case")]
29    SnakeCase,
30}
31
32impl Casing {
33    pub const fn as_str(self) -> &'static str {
34        match self {
35            Self::CamelCase => "camelCase",
36            Self::SnakeCase => "snake_case",
37        }
38    }
39}
40
41impl std::fmt::Display for Casing {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        f.write_str(self.as_str())
44    }
45}
46
47impl std::str::FromStr for Casing {
48    type Err = String;
49
50    fn from_str(s: &str) -> Result<Self, Self::Err> {
51        match s {
52            "camelCase" | "camel" => Ok(Self::CamelCase),
53            "snake_case" | "snake" => Ok(Self::SnakeCase),
54            _ => Err(format!(
55                "invalid casing '{}', expected 'camelCase' or 'snake_case'",
56                s
57            )),
58        }
59    }
60}
61
62/// Casing mode for introspection (pull command)
63#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Deserialize)]
64pub enum IntrospectCasing {
65    /// Convert database names to camelCase
66    #[default]
67    #[serde(rename = "camel")]
68    Camel,
69    /// Preserve original database names
70    #[serde(rename = "preserve")]
71    Preserve,
72}
73
74impl IntrospectCasing {
75    pub const fn as_str(self) -> &'static str {
76        match self {
77            Self::Camel => "camel",
78            Self::Preserve => "preserve",
79        }
80    }
81}
82
83impl std::fmt::Display for IntrospectCasing {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        f.write_str(self.as_str())
86    }
87}
88
89impl std::str::FromStr for IntrospectCasing {
90    type Err = String;
91
92    fn from_str(s: &str) -> Result<Self, Self::Err> {
93        match s {
94            "camel" | "camelCase" => Ok(Self::Camel),
95            "preserve" => Ok(Self::Preserve),
96            _ => Err(format!(
97                "invalid introspect casing '{}', expected 'camel' or 'preserve'",
98                s
99            )),
100        }
101    }
102}
103
104/// Introspection configuration
105#[derive(Debug, Clone, Default, Deserialize)]
106pub struct IntrospectConfig {
107    /// Casing mode for introspected identifiers
108    #[serde(default)]
109    pub casing: IntrospectCasing,
110}
111
112// ============================================================================
113// Entities Filter (matching drizzle-kit)
114// ============================================================================
115
116/// Roles filter configuration
117///
118/// Can be either a boolean (true = include all, false = exclude all)
119/// or a detailed configuration with provider/include/exclude lists.
120#[derive(Debug, Clone, Deserialize)]
121#[serde(untagged)]
122pub enum RolesFilter {
123    /// Simple boolean: true = include all user roles, false = exclude all
124    Bool(bool),
125    /// Detailed configuration
126    Config {
127        /// Provider preset (e.g., "supabase", "neon") - excludes provider-specific roles
128        #[serde(default)]
129        provider: Option<String>,
130        /// Explicit list of role names to include
131        #[serde(default)]
132        include: Option<Vec<String>>,
133        /// Explicit list of role names to exclude
134        #[serde(default)]
135        exclude: Option<Vec<String>>,
136    },
137}
138
139impl Default for RolesFilter {
140    fn default() -> Self {
141        Self::Bool(false)
142    }
143}
144
145impl RolesFilter {
146    /// Check if roles should be included at all
147    pub fn is_enabled(&self) -> bool {
148        match self {
149            Self::Bool(b) => *b,
150            Self::Config { .. } => true,
151        }
152    }
153
154    /// Check if a specific role should be included
155    pub fn should_include(&self, role_name: &str) -> bool {
156        match self {
157            Self::Bool(b) => *b,
158            Self::Config {
159                provider,
160                include,
161                exclude,
162            } => {
163                // Check provider exclusions
164                if let Some(p) = provider
165                    && is_provider_role(p, role_name)
166                {
167                    return false;
168                }
169                // Check explicit exclude list
170                if let Some(excl) = exclude
171                    && excl.iter().any(|e| e == role_name)
172                {
173                    return false;
174                }
175                // Check explicit include list (if specified, only include those)
176                if let Some(incl) = include {
177                    return incl.iter().any(|i| i == role_name);
178                }
179                true
180            }
181        }
182    }
183}
184
/// Returns `true` when `role_name` is one of the built-in roles of the hosting
/// `provider` (`"supabase"` or `"neon"`); any other provider has no built-ins.
fn is_provider_role(provider: &str, role_name: &str) -> bool {
    const SUPABASE: &[&str] = &[
        "anon",
        "authenticated",
        "service_role",
        "supabase_admin",
        "supabase_auth_admin",
        "supabase_storage_admin",
        "dashboard_user",
        "supabase_replication_admin",
        "supabase_read_only_user",
        "supabase_realtime_admin",
        "supabase_functions_admin",
        "postgres",
        "pgbouncer",
        "pgsodium_keyholder",
        "pgsodium_keyiduser",
        "pgsodium_keymaker",
    ];
    const NEON: &[&str] = &["neon_superuser", "cloud_admin", "authenticated", "anonymous"];

    let builtin: &[&str] = match provider {
        "supabase" => SUPABASE,
        "neon" => NEON,
        _ => &[],
    };
    builtin.contains(&role_name)
}
214
215/// Entities filter configuration
216///
217/// Controls which database entities are included in push/pull operations.
218#[derive(Debug, Clone, Default, Deserialize)]
219pub struct EntitiesFilter {
220    /// Roles filter (PostgreSQL only)
221    #[serde(default)]
222    pub roles: RolesFilter,
223}
224
225// ============================================================================
226// Extensions Filter (PostgreSQL only)
227// ============================================================================
228
229/// Known PostgreSQL extensions that can be filtered
230#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize)]
231#[serde(rename_all = "lowercase")]
232pub enum Extension {
233    /// PostGIS spatial extension
234    Postgis,
235}
236
237impl Extension {
238    pub const fn as_str(self) -> &'static str {
239        match self {
240            Self::Postgis => "postgis",
241        }
242    }
243}
244
245impl std::fmt::Display for Extension {
246    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247        f.write_str(self.as_str())
248    }
249}
250
251// ============================================================================
252// EnvOr - Environment variable or direct value
253// ============================================================================
254
255/// A value that can be either a direct string or an environment variable reference.
256///
257/// In TOML config, users can write:
258/// ```toml
259/// url = "postgres://localhost/db"           # Direct value
260/// url = { env = "DATABASE_URL" }            # Environment variable
261/// ```
262#[derive(Debug, Clone)]
263pub enum EnvOr {
264    /// Direct string value
265    Value(String),
266    /// Environment variable name to resolve
267    Env(String),
268}
269
270impl EnvOr {
271    /// Resolve the value, looking up environment variable if needed
272    pub fn resolve(&self) -> Result<String, Error> {
273        match self {
274            Self::Value(v) => Ok(v.clone()),
275            Self::Env(var) => std::env::var(var).map_err(|_| Error::EnvNotFound(var.clone())),
276        }
277    }
278
279    /// Resolve to an optional value (returns None for missing env vars)
280    pub fn resolve_optional(&self) -> Result<Option<String>, Error> {
281        match self {
282            Self::Value(v) => Ok(Some(v.clone())),
283            Self::Env(var) => match std::env::var(var) {
284                Ok(v) => Ok(Some(v)),
285                Err(std::env::VarError::NotPresent) => Ok(None),
286                Err(std::env::VarError::NotUnicode(_)) => Err(Error::EnvInvalid(
287                    var.clone(),
288                    "contains invalid unicode".into(),
289                )),
290            },
291        }
292    }
293}
294
295impl<'de> Deserialize<'de> for EnvOr {
296    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
297    where
298        D: Deserializer<'de>,
299    {
300        struct EnvOrVisitor;
301
302        impl<'de> Visitor<'de> for EnvOrVisitor {
303            type Value = EnvOr;
304
305            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
306                formatter.write_str("a string or { env = \"VAR_NAME\" }")
307            }
308
309            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
310            where
311                E: de::Error,
312            {
313                Ok(EnvOr::Value(value.to_string()))
314            }
315
316            fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
317            where
318                M: MapAccess<'de>,
319            {
320                let mut env_var: Option<String> = None;
321
322                while let Some(key) = map.next_key::<String>()? {
323                    if key == "env" {
324                        env_var = Some(map.next_value()?);
325                    } else {
326                        return Err(de::Error::unknown_field(&key, &["env"]));
327                    }
328                }
329
330                env_var
331                    .map(EnvOr::Env)
332                    .ok_or_else(|| de::Error::missing_field("env"))
333            }
334        }
335
336        deserializer.deserialize_any(EnvOrVisitor)
337    }
338}
339
340// ============================================================================
341// Dialect
342// ============================================================================
343
344/// Database dialect
345#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Deserialize)]
346#[serde(rename_all = "lowercase")]
347pub enum Dialect {
348    #[default]
349    Sqlite,
350    #[serde(alias = "postgres")]
351    Postgresql,
352    Turso,
353}
354
355impl Dialect {
356    pub const ALL: &'static [&'static str] = &["sqlite", "postgresql", "turso"];
357
358    #[inline]
359    pub const fn as_str(self) -> &'static str {
360        match self {
361            Self::Sqlite => "sqlite",
362            Self::Postgresql => "postgresql",
363            Self::Turso => "turso",
364        }
365    }
366
367    #[inline]
368    pub const fn to_base(self) -> drizzle_types::Dialect {
369        match self {
370            Self::Sqlite | Self::Turso => drizzle_types::Dialect::SQLite,
371            Self::Postgresql => drizzle_types::Dialect::PostgreSQL,
372        }
373    }
374}
375
376impl std::fmt::Display for Dialect {
377    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
378        f.write_str(self.as_str())
379    }
380}
381
382impl From<Dialect> for drizzle_types::Dialect {
383    #[inline]
384    fn from(d: Dialect) -> Self {
385        d.to_base()
386    }
387}
388
389// ============================================================================
390// Driver
391// ============================================================================
392
393/// Database driver for Rust database connections
394#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize)]
395#[serde(rename_all = "kebab-case")]
396pub enum Driver {
397    /// rusqlite - synchronous SQLite driver
398    Rusqlite,
399    /// libsql - LibSQL driver (local embedded)
400    Libsql,
401    /// turso - Turso cloud driver (remote)
402    Turso,
403    /// postgres-sync - synchronous PostgreSQL driver
404    PostgresSync,
405    /// tokio-postgres - async PostgreSQL driver
406    TokioPostgres,
407}
408
409impl Driver {
410    pub const ALL: &'static [&'static str] = &[
411        "rusqlite",
412        "libsql",
413        "turso",
414        "postgres-sync",
415        "tokio-postgres",
416    ];
417
418    #[inline]
419    pub const fn as_str(self) -> &'static str {
420        match self {
421            Self::Rusqlite => "rusqlite",
422            Self::Libsql => "libsql",
423            Self::Turso => "turso",
424            Self::PostgresSync => "postgres-sync",
425            Self::TokioPostgres => "tokio-postgres",
426        }
427    }
428
429    pub const fn valid_for(dialect: Dialect) -> &'static [Driver] {
430        match dialect {
431            Dialect::Sqlite => &[Self::Rusqlite],
432            Dialect::Turso => &[Self::Libsql, Self::Turso],
433            Dialect::Postgresql => &[Self::PostgresSync, Self::TokioPostgres],
434        }
435    }
436
437    #[inline]
438    pub const fn is_valid_for(self, dialect: Dialect) -> bool {
439        matches!(
440            (self, dialect),
441            (Self::Rusqlite, Dialect::Sqlite)
442                | (Self::Libsql | Self::Turso, Dialect::Turso)
443                | (
444                    Self::PostgresSync | Self::TokioPostgres,
445                    Dialect::Postgresql
446                )
447        )
448    }
449}
450
451impl std::fmt::Display for Driver {
452    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
453        f.write_str(self.as_str())
454    }
455}
456
457// ============================================================================
458// Credentials
459// ============================================================================
460
461/// Database credentials - validated and typed
462#[derive(Debug, Clone)]
463pub enum Credentials {
464    /// Local SQLite file
465    Sqlite { path: Box<str> },
466
467    /// Turso/LibSQL
468    Turso {
469        url: Box<str>,
470        auth_token: Option<Box<str>>,
471    },
472
473    /// PostgreSQL
474    Postgres(PostgresCreds),
475}
476
/// PostgreSQL credentials: either a complete connection URL or discrete host fields.
#[derive(Debug, Clone)]
pub enum PostgresCreds {
    /// A complete connection URL, returned verbatim by `connection_url`.
    Url(Box<str>),
    /// Discrete connection parameters that `connection_url` assembles into a URL.
    Host {
        host: Box<str>,
        port: u16,
        user: Option<Box<str>>,
        password: Option<Box<str>>,
        database: Box<str>,
        ssl: bool,
    },
}

impl PostgresCreds {
    /// Build a `postgres://` connection URL.
    ///
    /// For `Url` the stored URL is returned unchanged. For `Host` the URL is
    /// assembled from its parts:
    /// - user, password and database are percent-encoded (RFC 3986 unreserved
    ///   characters are kept as-is), so values containing `@`, `:`, `/`, `?`, `#`
    ///   or `%` cannot corrupt the URL structure;
    /// - an IPv6 literal host (contains `:`) is wrapped in `[...]` unless it already is;
    /// - a password without a user is omitted;
    /// - the `ssl` flag is not reflected in the returned URL.
    pub fn connection_url(&self) -> String {
        match self {
            Self::Url(url) => url.to_string(),
            Self::Host {
                host,
                port,
                user,
                password,
                database,
                ..
            } => {
                let auth = match (user, password) {
                    (Some(u), Some(p)) => format!(
                        "{}:{}@",
                        Self::encode_component(u),
                        Self::encode_component(p)
                    ),
                    (Some(u), None) => format!("{}@", Self::encode_component(u)),
                    _ => String::new(),
                };
                // Without brackets, the colons of an IPv6 address would be parsed as
                // the host/port separator.
                let host = if host.contains(':') && !host.starts_with('[') {
                    format!("[{host}]")
                } else {
                    host.to_string()
                };
                let database = Self::encode_component(database);
                format!("postgres://{auth}{host}:{port}/{database}")
            }
        }
    }

    /// Percent-encode every byte of `component` except RFC 3986 unreserved
    /// characters (`A-Z a-z 0-9 - . _ ~`).
    fn encode_component(component: &str) -> String {
        const HEX: &[u8; 16] = b"0123456789ABCDEF";
        let mut out = String::with_capacity(component.len());
        for byte in component.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
                out.push(char::from(byte));
            } else {
                out.push('%');
                out.push(char::from(HEX[usize::from(byte >> 4)]));
                out.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
        out
    }
}
514
515// ============================================================================
516// Schema path(s)
517// ============================================================================
518
519/// Schema path(s)
520#[derive(Debug, Clone, Deserialize)]
521#[serde(untagged)]
522pub enum Schema {
523    One(String),
524    Many(Vec<String>),
525}
526
527impl Default for Schema {
528    fn default() -> Self {
529        Self::One("src/schema.rs".into())
530    }
531}
532
533impl Schema {
534    pub fn iter(&self) -> impl Iterator<Item = &str> {
535        match self {
536            Self::One(s) => std::slice::from_ref(s).iter().map(String::as_str),
537            Self::Many(v) => v.iter().map(String::as_str),
538        }
539    }
540}
541
542/// Filter (single or multiple values)
543#[derive(Debug, Clone, Deserialize)]
544#[serde(untagged)]
545pub enum Filter {
546    One(String),
547    Many(Vec<String>),
548}
549
550impl Filter {
551    pub fn iter(&self) -> impl Iterator<Item = &str> {
552        match self {
553            Self::One(s) => std::slice::from_ref(s).iter().map(String::as_str),
554            Self::Many(v) => v.iter().map(String::as_str),
555        }
556    }
557}
558
559/// Migration options
560#[derive(Debug, Clone, Deserialize)]
561pub struct MigrationsOpts {
562    pub table: Option<String>,
563    pub schema: Option<String>,
564    pub prefix: Option<MigrationPrefix>,
565}
566
567#[derive(Debug, Clone, Copy, Deserialize)]
568#[serde(rename_all = "lowercase")]
569pub enum MigrationPrefix {
570    Index,
571    Timestamp,
572    Supabase,
573    Unix,
574    None,
575}
576
577// ============================================================================
578// Raw credentials (serde parsing helper)
579// ============================================================================
580
581#[derive(Debug, Clone, Deserialize)]
582#[serde(untagged)]
583enum RawCreds {
584    Url {
585        url: EnvOr,
586        #[serde(default, rename = "authToken")]
587        auth_token: Option<EnvOr>,
588    },
589    Host {
590        host: EnvOr,
591        #[serde(default)]
592        port: Option<u16>,
593        #[serde(default)]
594        user: Option<EnvOr>,
595        #[serde(default)]
596        password: Option<EnvOr>,
597        database: EnvOr,
598        #[serde(default)]
599        ssl: Option<SslVal>,
600    },
601}
602
603#[derive(Debug, Clone, Deserialize)]
604#[serde(untagged)]
605enum SslVal {
606    Bool(bool),
607    Str(String),
608}
609
610impl SslVal {
611    fn enabled(&self) -> bool {
612        match self {
613            Self::Bool(b) => *b,
614            Self::Str(s) => !matches!(s.as_str(), "disable" | "false" | "no" | "off"),
615        }
616    }
617}
618
619// ============================================================================
620// DatabaseConfig - Per-database configuration
621// ============================================================================
622
623/// Configuration for a single database
624///
625/// This structure matches drizzle-kit's config format for compatibility.
626#[derive(Debug, Clone, Deserialize)]
627#[serde(rename_all = "camelCase")]
628pub struct DatabaseConfig {
629    /// Database dialect (required)
630    pub dialect: Dialect,
631
632    /// Path(s) to schema file(s) - supports glob patterns
633    #[serde(default)]
634    pub schema: Schema,
635
636    /// Output directory for migrations (default: "./drizzle")
637    #[serde(default = "default_out")]
638    pub out: PathBuf,
639
640    /// Whether to use SQL breakpoints in migrations (default: true)
641    #[serde(default = "yes")]
642    pub breakpoints: bool,
643
644    /// Database driver for Rust connections
645    #[serde(default)]
646    pub driver: Option<Driver>,
647
648    /// Database credentials
649    #[serde(default)]
650    db_credentials: Option<RawCreds>,
651
652    /// Table name filter (glob patterns supported)
653    #[serde(default)]
654    pub tables_filter: Option<Filter>,
655
656    /// Schema name filter (PostgreSQL only)
657    #[serde(default)]
658    pub schema_filter: Option<Filter>,
659
660    /// Extensions filter (PostgreSQL only, e.g., ["postgis"])
661    #[serde(default)]
662    pub extensions_filters: Option<Vec<Extension>>,
663
664    /// Entities filter (roles, etc.)
665    #[serde(default)]
666    pub entities: Option<EntitiesFilter>,
667
668    /// Casing mode for generated code
669    #[serde(default)]
670    pub casing: Option<Casing>,
671
672    /// Introspection configuration
673    #[serde(default)]
674    pub introspect: Option<IntrospectConfig>,
675
676    /// Verbose output
677    #[serde(default)]
678    pub verbose: bool,
679
680    /// Strict mode - deprecated, use explain flag instead
681    #[serde(default)]
682    pub strict: bool,
683
684    /// Migration table configuration
685    #[serde(default)]
686    pub migrations: Option<MigrationsOpts>,
687}
688
689fn default_out() -> PathBuf {
690    PathBuf::from("./drizzle")
691}
692
693fn yes() -> bool {
694    true
695}
696
697impl DatabaseConfig {
698    fn normalize_paths(&mut self, base_dir: &Path) {
699        // Resolve `out` relative to the config file directory for predictable behavior,
700        // especially when `--config` points at a file outside the current working directory.
701        if self.out.is_relative() {
702            self.out = base_dir.join(&self.out);
703        }
704
705        // Normalize schema patterns:
706        // - Resolve relative patterns relative to config dir
707        // - Use forward slashes to avoid glob escaping issues on Windows
708        let base = base_dir.to_string_lossy().replace('\\', "/");
709        let base = base.trim_end_matches('/').to_string();
710
711        let normalize_one = |p: &str| -> String {
712            let p_trim = p.trim();
713            let is_abs = Path::new(p_trim).is_absolute() || p_trim.starts_with("\\\\");
714            let joined = if is_abs || base.is_empty() || base == "." {
715                p_trim.to_string()
716            } else {
717                format!("{base}/{p_trim}")
718            };
719            joined.replace('\\', "/")
720        };
721
722        match &mut self.schema {
723            Schema::One(p) => *p = normalize_one(p),
724            Schema::Many(v) => {
725                for p in v.iter_mut() {
726                    *p = normalize_one(p);
727                }
728            }
729        }
730    }
731
732    fn validate(&self, name: &str) -> Result<(), Error> {
733        // Check driver compatibility
734        if let Some(d) = self.driver
735            && !d.is_valid_for(self.dialect)
736        {
737            return Err(Error::InvalidDriver {
738                driver: d,
739                dialect: self.dialect,
740            });
741        }
742
743        // Validate credentials if present
744        if let Some(ref raw) = self.db_credentials {
745            self.validate_creds(raw, name)?;
746        }
747
748        Ok(())
749    }
750
751    fn validate_creds(&self, raw: &RawCreds, _name: &str) -> Result<(), Error> {
752        let err = |msg: &str| Error::InvalidCredentials(msg.into());
753
754        // Enforce dialect/shape pairing. Without this, serde can parse a "host" form for
755        // any dialect, and later `credentials()` would silently return None.
756        match (self.dialect, raw) {
757            (Dialect::Postgresql, RawCreds::Host { .. }) => {}
758            (Dialect::Postgresql, RawCreds::Url { .. }) => {}
759            (_, RawCreds::Host { .. }) => {
760                return Err(err(
761                    "host-based dbCredentials are only supported for dialect = \"postgresql\"",
762                ));
763            }
764            _ => {}
765        }
766
767        // Dialect-specific checks (only for direct values, not env var references)
768        match (self.dialect, raw) {
769            (
770                Dialect::Sqlite,
771                RawCreds::Url {
772                    auth_token: Some(_),
773                    ..
774                },
775            ) => Err(err(
776                "SQLite doesn't support authToken (use dialect = \"turso\")",
777            )),
778            (
779                Dialect::Sqlite,
780                RawCreds::Url {
781                    url: EnvOr::Value(url),
782                    ..
783                },
784            ) if url.starts_with("libsql://") => Err(err(
785                "libsql:// URLs require dialect = \"turso\" (for local SQLite files, use ./path.db)",
786            )),
787            (
788                Dialect::Sqlite,
789                RawCreds::Url {
790                    url: EnvOr::Value(url),
791                    ..
792                },
793            ) if url.starts_with("http://")
794                || url.starts_with("https://")
795                || url.starts_with("postgres://")
796                || url.starts_with("postgresql://") =>
797            {
798                Err(err(
799                    "SQLite dbCredentials.url must be a local file path (not an http(s)/postgres URL)",
800                ))
801            }
802            (
803                Dialect::Turso,
804                RawCreds::Url {
805                    url: EnvOr::Value(url),
806                    ..
807                },
808            ) if !url.starts_with("libsql://") && !url.starts_with("http") => {
809                Err(err("Turso URL must start with libsql:// or http(s)://"))
810            }
811            (
812                Dialect::Postgresql,
813                RawCreds::Url {
814                    url: EnvOr::Value(url),
815                    ..
816                },
817            ) if !url.starts_with("postgres") => {
818                Err(err("PostgreSQL URL must start with postgres://"))
819            }
820            _ => Ok(()),
821        }
822    }
823
824    /// Get typed credentials, resolving any environment variable references
825    pub fn credentials(&self) -> Result<Option<Credentials>, Error> {
826        let raw = match self.db_credentials.as_ref() {
827            Some(r) => r,
828            None => return Ok(None),
829        };
830
831        // Helper to resolve an optional EnvOr
832        let resolve_opt = |opt: &Option<EnvOr>| -> Result<Option<Box<str>>, Error> {
833            match opt {
834                Some(e) => e.resolve().map(|s| Some(s.into_boxed_str())),
835                None => Ok(None),
836            }
837        };
838
839        let creds = match (self.dialect, raw) {
840            // SQLite
841            (Dialect::Sqlite, RawCreds::Url { url, .. }) => Credentials::Sqlite {
842                path: url.resolve()?.into_boxed_str(),
843            },
844            // Turso
845            (Dialect::Turso, RawCreds::Url { url, auth_token }) => Credentials::Turso {
846                url: url.resolve()?.into_boxed_str(),
847                auth_token: resolve_opt(auth_token)?,
848            },
849            // PostgreSQL URL
850            (Dialect::Postgresql, RawCreds::Url { url, .. }) => {
851                Credentials::Postgres(PostgresCreds::Url(url.resolve()?.into_boxed_str()))
852            }
853            // PostgreSQL Host
854            (
855                Dialect::Postgresql,
856                RawCreds::Host {
857                    host,
858                    port,
859                    user,
860                    password,
861                    database,
862                    ssl,
863                },
864            ) => Credentials::Postgres(PostgresCreds::Host {
865                host: host.resolve()?.into_boxed_str(),
866                port: port.unwrap_or(5432),
867                user: resolve_opt(user)?,
868                password: resolve_opt(password)?,
869                database: database.resolve()?.into_boxed_str(),
870                ssl: ssl.as_ref().map(|s| s.enabled()).unwrap_or(false),
871            }),
872            _ => return Ok(None),
873        };
874
875        Ok(Some(creds))
876    }
877
878    /// Migrations output directory
879    #[inline]
880    pub fn migrations_dir(&self) -> &Path {
881        &self.out
882    }
883
884    /// Meta directory (for journal)
885    #[inline]
886    pub fn meta_dir(&self) -> PathBuf {
887        self.out.join("meta")
888    }
889
890    /// Journal file path
891    #[inline]
892    pub fn journal_path(&self) -> PathBuf {
893        self.meta_dir().join("_journal.json")
894    }
895
896    /// Schema paths display string
897    pub fn schema_display(&self) -> String {
898        match &self.schema {
899            Schema::One(s) => s.clone(),
900            Schema::Many(v) => v.join(", "),
901        }
902    }
903
904    /// Resolve schema files (with glob support)
905    pub fn schema_files(&self) -> Result<Vec<PathBuf>, Error> {
906        let mut files = Vec::new();
907
908        for pattern in self.schema.iter() {
909            let pat = pattern.trim();
910
911            // If it's not a glob pattern, treat it as a direct path (better Windows behavior).
912            let is_glob = pat.contains('*') || pat.contains('?') || pat.contains('[');
913            if !is_glob {
914                let p = PathBuf::from(pat);
915                if p.exists() {
916                    files.push(p);
917                    continue;
918                }
919            }
920
921            // Glob patterns: normalize separators to avoid `\` being treated as an escape.
922            let pat_norm = pat.replace('\\', "/");
923            match glob::glob(&pat_norm) {
924                Ok(paths) => {
925                    let matched: Vec<_> = paths.filter_map(Result::ok).collect();
926                    if matched.is_empty() && !is_glob {
927                        let p = PathBuf::from(&pat_norm);
928                        if p.exists() {
929                            files.push(p);
930                        }
931                    } else {
932                        files.extend(matched);
933                    }
934                }
935                Err(e) => return Err(Error::Glob(pat.into(), e)),
936            }
937        }
938
939        // Keep only real files (glob can return directories).
940        files.retain(|p| p.is_file());
941        files.sort();
942        files.dedup();
943
944        if files.is_empty() {
945            return Err(Error::NoSchemaFiles(self.schema_display()));
946        }
947
948        Ok(files)
949    }
950
951    /// Get effective casing mode (default: camelCase)
952    #[inline]
953    pub fn effective_casing(&self) -> Casing {
954        self.casing.unwrap_or_default()
955    }
956
957    /// Get effective introspect casing mode (default: camel)
958    #[inline]
959    pub fn effective_introspect_casing(&self) -> IntrospectCasing {
960        self.introspect
961            .as_ref()
962            .map(|i| i.casing)
963            .unwrap_or_default()
964    }
965
966    /// Get entities filter (default: empty)
967    #[inline]
968    pub fn effective_entities(&self) -> EntitiesFilter {
969        self.entities.clone().unwrap_or_default()
970    }
971
972    /// Check if a role should be included based on entities filter
973    pub fn should_include_role(&self, role_name: &str) -> bool {
974        self.entities
975            .as_ref()
976            .map(|e| e.roles.should_include(role_name))
977            .unwrap_or(false)
978    }
979
980    /// Check if roles are enabled in entities filter
981    pub fn roles_enabled(&self) -> bool {
982        self.entities
983            .as_ref()
984            .map(|e| e.roles.is_enabled())
985            .unwrap_or(false)
986    }
987
988    /// Get extensions filters (PostgreSQL only)
989    pub fn extensions(&self) -> &[Extension] {
990        self.extensions_filters.as_deref().unwrap_or(&[])
991    }
992
993    /// Check if an extension is in the filter list
994    pub fn has_extension(&self, ext: Extension) -> bool {
995        self.extensions_filters
996            .as_ref()
997            .map(|v| v.contains(&ext))
998            .unwrap_or(false)
999    }
1000
1001    /// Get migration table name (default: __drizzle_migrations)
1002    pub fn migrations_table(&self) -> &str {
1003        self.migrations
1004            .as_ref()
1005            .and_then(|m| m.table.as_deref())
1006            .unwrap_or("__drizzle_migrations")
1007    }
1008
1009    /// Get migration schema (PostgreSQL only, default: drizzle)
1010    pub fn migrations_schema(&self) -> &str {
1011        self.migrations
1012            .as_ref()
1013            .and_then(|m| m.schema.as_deref())
1014            .unwrap_or("drizzle")
1015    }
1016}
1017
1018// ============================================================================
1019// Main Configuration - Wrapper for single/multi-database modes
1020// ============================================================================
1021
1022/// Internal format for multi-database config
1023#[derive(Debug, Clone, Deserialize)]
1024struct MultiDbConfig {
1025    databases: HashMap<String, DatabaseConfig>,
1026}
1027
1028/// Main configuration structure
1029///
1030/// Supports both single-database (legacy) and multi-database configurations:
1031///
1032/// Single database:
1033/// ```toml
1034/// dialect = "sqlite"
1035/// [dbCredentials]
1036/// url = "./dev.db"
1037/// ```
1038///
1039/// Multiple databases:
1040/// ```toml
1041/// [databases.dev]
1042/// dialect = "sqlite"
1043/// [databases.dev.dbCredentials]
1044/// url = "./dev.db"
1045///
1046/// [databases.prod]
1047/// dialect = "postgresql"
1048/// [databases.prod.dbCredentials]
1049/// url = { env = "DATABASE_URL" }
1050/// ```
1051#[derive(Debug, Clone)]
1052pub struct Config {
1053    /// Named database configurations
1054    databases: HashMap<String, DatabaseConfig>,
1055    /// Whether this is a single-database config (for backwards compat)
1056    is_single: bool,
1057}
1058
/// Key under which a single-database (legacy) config stores its only database entry.
pub const DEFAULT_DB: &str = "default";
1061
1062impl Config {
1063    /// Load from default config file
1064    pub fn load() -> Result<Self, Error> {
1065        Self::load_from(Path::new(CONFIG_FILE))
1066    }
1067
1068    /// Load from specific path
1069    pub fn load_from(path: &Path) -> Result<Self, Error> {
1070        let content = std::fs::read_to_string(path).map_err(|e| {
1071            if e.kind() == std::io::ErrorKind::NotFound {
1072                Error::NotFound(path.into())
1073            } else {
1074                Error::Io(path.into(), e)
1075            }
1076        })?;
1077
1078        Self::load_from_str(&content, path)
1079    }
1080
1081    /// Load from string content
1082    fn load_from_str(content: &str, path: &Path) -> Result<Self, Error> {
1083        let base_dir = path.parent().unwrap_or_else(|| Path::new("."));
1084
1085        // Try multi-database format first
1086        if let Ok(multi) = toml::from_str::<MultiDbConfig>(content)
1087            && !multi.databases.is_empty()
1088        {
1089            let mut config = Self {
1090                databases: multi.databases,
1091                is_single: false,
1092            };
1093            for db in config.databases.values_mut() {
1094                db.normalize_paths(base_dir);
1095            }
1096            config.validate()?;
1097            return Ok(config);
1098        }
1099
1100        // Fall back to single-database format
1101        let db_config: DatabaseConfig =
1102            toml::from_str(content).map_err(|e| Error::Parse(path.into(), e))?;
1103
1104        let mut databases = HashMap::new();
1105        databases.insert(DEFAULT_DB.to_string(), db_config);
1106
1107        let mut config = Self {
1108            databases,
1109            is_single: true,
1110        };
1111        for db in config.databases.values_mut() {
1112            db.normalize_paths(base_dir);
1113        }
1114        config.validate()?;
1115        Ok(config)
1116    }
1117
1118    fn validate(&mut self) -> Result<(), Error> {
1119        for (name, db) in &self.databases {
1120            db.validate(name)?;
1121        }
1122        Ok(())
1123    }
1124
1125    /// Check if this is a single-database config
1126    pub fn is_single_database(&self) -> bool {
1127        self.is_single
1128    }
1129
1130    /// Get all database names
1131    pub fn database_names(&self) -> impl Iterator<Item = &str> {
1132        self.databases.keys().map(String::as_str)
1133    }
1134
1135    /// Get a specific database config by name
1136    ///
1137    /// If name is None, returns the default/only database.
1138    /// For single-db configs, any name or None returns the single database.
1139    pub fn database(&self, name: Option<&str>) -> Result<&DatabaseConfig, Error> {
1140        match name {
1141            None => {
1142                // Get default
1143                if self.is_single {
1144                    self.databases.get(DEFAULT_DB).ok_or(Error::NoDatabases)
1145                } else if self.databases.len() == 1 {
1146                    self.databases.values().next().ok_or(Error::NoDatabases)
1147                } else {
1148                    Err(Error::DatabaseRequired(
1149                        self.databases.keys().cloned().collect(),
1150                    ))
1151                }
1152            }
1153            Some(name) => {
1154                if self.is_single {
1155                    // For single-db config, accept any name
1156                    self.databases.get(DEFAULT_DB).ok_or(Error::NoDatabases)
1157                } else {
1158                    self.databases
1159                        .get(name)
1160                        .ok_or_else(|| Error::DatabaseNotFound(name.to_string()))
1161                }
1162            }
1163        }
1164    }
1165
1166    /// Get the default database (for single-db mode or when only one db exists)
1167    pub fn default_database(&self) -> Result<&DatabaseConfig, Error> {
1168        self.database(None)
1169    }
1170
1171    // ========================================================================
1172    // Backwards compatibility - delegate to default database
1173    // ========================================================================
1174
1175    /// Get dialect (for single-db mode backwards compat)
1176    pub fn dialect(&self) -> Dialect {
1177        self.default_database()
1178            .map(|d| d.dialect)
1179            .unwrap_or_default()
1180    }
1181
1182    /// Get credentials (for single-db mode backwards compat)
1183    pub fn credentials(&self) -> Result<Option<Credentials>, Error> {
1184        self.default_database()?.credentials()
1185    }
1186
1187    /// Get migrations directory (for single-db mode backwards compat)
1188    pub fn migrations_dir(&self) -> &Path {
1189        self.default_database()
1190            .map(|d| d.migrations_dir())
1191            .unwrap_or(Path::new("./drizzle"))
1192    }
1193
1194    /// Get journal path (for single-db mode backwards compat)
1195    pub fn journal_path(&self) -> PathBuf {
1196        self.default_database()
1197            .map(|d| d.journal_path())
1198            .unwrap_or_else(|_| PathBuf::from("./drizzle/meta/_journal.json"))
1199    }
1200
1201    /// Get schema display (for single-db mode backwards compat)
1202    pub fn schema_display(&self) -> String {
1203        self.default_database()
1204            .map(|d| d.schema_display())
1205            .unwrap_or_else(|_| "src/schema.rs".into())
1206    }
1207
1208    /// Get schema files (for single-db mode backwards compat)
1209    pub fn schema_files(&self) -> Result<Vec<PathBuf>, Error> {
1210        self.default_database()?.schema_files()
1211    }
1212
1213    /// Base dialect for SQL generation (for single-db mode backwards compat)
1214    pub fn base_dialect(&self) -> drizzle_types::Dialect {
1215        self.dialect().to_base()
1216    }
1217}
1218
1219// Re-export as DrizzleConfig for compatibility
1220pub type DrizzleConfig = Config;
1221
1222// ============================================================================
1223// Errors
1224// ============================================================================
1225
1226#[derive(Debug, thiserror::Error)]
1227pub enum Error {
1228    #[error("config not found: {}", .0.display())]
1229    NotFound(PathBuf),
1230
1231    #[error("failed to read {}: {}", .0.display(), .1)]
1232    Io(PathBuf, #[source] std::io::Error),
1233
1234    #[error("failed to parse {}: {}", .0.display(), .1)]
1235    Parse(PathBuf, #[source] toml::de::Error),
1236
1237    #[error("driver '{driver}' invalid for {dialect} dialect")]
1238    InvalidDriver { driver: Driver, dialect: Dialect },
1239
1240    #[error("invalid credentials: {0}")]
1241    InvalidCredentials(String),
1242
1243    #[error("invalid glob '{0}': {1}")]
1244    Glob(String, #[source] glob::PatternError),
1245
1246    #[error("no schema files found: {0}")]
1247    NoSchemaFiles(String),
1248
1249    #[error("environment variable '{0}' not found")]
1250    EnvNotFound(String),
1251
1252    #[error("environment variable '{0}' invalid: {1}")]
1253    EnvInvalid(String, String),
1254
1255    #[error("no databases configured")]
1256    NoDatabases,
1257
1258    #[error("database '{0}' not found")]
1259    DatabaseNotFound(String),
1260
1261    #[error("multiple databases configured, use --db to specify: {}", .0.join(", "))]
1262    DatabaseRequired(Vec<String>),
1263}
1264
1265pub type ConfigError = Error;
1266
1267// ============================================================================
1268// Tests
1269// ============================================================================
1270
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    #[test]
    fn sqlite() {
        let cfg = Config::load_from_str(
            r#"
            dialect = "sqlite"
            [dbCredentials]
            url = "./dev.db"
        "#,
            Path::new("test.toml"),
        )
        .unwrap();
        assert!(cfg.is_single_database());
        assert!(matches!(
            cfg.credentials().unwrap(),
            Some(Credentials::Sqlite { .. })
        ));
    }

    #[test]
    fn postgres_url() {
        let cfg = Config::load_from_str(
            r#"
            dialect = "postgresql"
            [dbCredentials]
            url = "postgres://localhost/db"
        "#,
            Path::new("test.toml"),
        )
        .unwrap();
        assert!(matches!(
            cfg.credentials().unwrap(),
            Some(Credentials::Postgres(PostgresCreds::Url(_)))
        ));
    }

    #[test]
    fn multi_database() {
        let cfg = Config::load_from_str(
            r#"
            [databases.dev]
            dialect = "sqlite"
            out = "./drizzle/sqlite"
            [databases.dev.dbCredentials]
            url = "./dev.db"

            [databases.prod]
            dialect = "postgresql"
            out = "./drizzle/postgres"
            [databases.prod.dbCredentials]
            url = "postgres://localhost/db"
        "#,
            Path::new("test.toml"),
        )
        .unwrap();

        assert!(!cfg.is_single_database());
        let names: Vec<_> = cfg.database_names().collect();
        assert!(names.contains(&"dev"));
        assert!(names.contains(&"prod"));

        let dev = cfg.database(Some("dev")).unwrap();
        assert_eq!(dev.dialect, Dialect::Sqlite);

        let prod = cfg.database(Some("prod")).unwrap();
        assert_eq!(prod.dialect, Dialect::Postgresql);
    }

    #[test]
    fn multi_database_requires_selection() {
        let cfg = Config::load_from_str(
            r#"
            [databases.a]
            dialect = "sqlite"
            [databases.b]
            dialect = "postgresql"
        "#,
            Path::new("test.toml"),
        )
        .unwrap();

        // With several databases and no selection, the error must list every candidate.
        assert!(matches!(
            cfg.database(None),
            Err(Error::DatabaseRequired(names)) if names.len() == 2
        ));
    }

    #[test]
    fn multi_database_unknown_name_is_reported() {
        let cfg = Config::load_from_str(
            r#"
            [databases.dev]
            dialect = "sqlite"
        "#,
            Path::new("test.toml"),
        )
        .unwrap();

        assert!(matches!(
            cfg.database(Some("missing")),
            Err(Error::DatabaseNotFound(name)) if name == "missing"
        ));
    }

    #[test]
    fn multi_database_with_single_entry_is_default() {
        let cfg = Config::load_from_str(
            r#"
            [databases.only]
            dialect = "sqlite"
        "#,
            Path::new("test.toml"),
        )
        .unwrap();

        assert!(!cfg.is_single_database());
        assert_eq!(cfg.default_database().unwrap().dialect, Dialect::Sqlite);
    }

    #[test]
    fn single_database_accepts_any_name() {
        let cfg = Config::load_from_str(
            r#"
            dialect = "sqlite"
            [dbCredentials]
            url = "./dev.db"
        "#,
            Path::new("test.toml"),
        )
        .unwrap();

        assert_eq!(
            cfg.database(Some("anything")).unwrap().dialect,
            Dialect::Sqlite
        );
    }

    #[test]
    fn env_var_syntax() {
        let cfg = Config::load_from_str(
            r#"
            dialect = "postgresql"
            [dbCredentials]
            url = { env = "DATABASE_URL" }
        "#,
            Path::new("test.toml"),
        )
        .unwrap();
        assert!(cfg.is_single_database());
        assert_eq!(cfg.dialect(), Dialect::Postgresql);
    }

    #[test]
    fn casing_options() {
        let cfg = Config::load_from_str(
            r#"
            dialect = "postgresql"
            casing = "snake_case"
            [dbCredentials]
            url = "postgres://localhost/db"
        "#,
            Path::new("test.toml"),
        )
        .unwrap();
        let db = cfg.default_database().unwrap();
        assert_eq!(db.effective_casing(), Casing::SnakeCase);

        // Test default (camelCase)
        let cfg2 = Config::load_from_str(
            r#"
            dialect = "postgresql"
            [dbCredentials]
            url = "postgres://localhost/db"
        "#,
            Path::new("test.toml"),
        )
        .unwrap();
        let db2 = cfg2.default_database().unwrap();
        assert_eq!(db2.effective_casing(), Casing::CamelCase);
    }

    #[test]
    fn introspect_casing() {
        let cfg = Config::load_from_str(
            r#"
            dialect = "postgresql"
            [introspect]
            casing = "preserve"
            [dbCredentials]
            url = "postgres://localhost/db"
        "#,
            Path::new("test.toml"),
        )
        .unwrap();
        let db = cfg.default_database().unwrap();
        assert_eq!(db.effective_introspect_casing(), IntrospectCasing::Preserve);
    }

    #[test]
    fn entities_roles_filter() {
        // Test boolean roles filter
        let cfg = Config::load_from_str(
            r#"
            dialect = "postgresql"
            [entities]
            roles = true
            [dbCredentials]
            url = "postgres://localhost/db"
        "#,
            Path::new("test.toml"),
        )
        .unwrap();
        let db = cfg.default_database().unwrap();
        assert!(db.roles_enabled());
        assert!(db.should_include_role("my_role"));

        // Test roles filter with provider
        let cfg2 = Config::load_from_str(
            r#"
            dialect = "postgresql"
            [entities.roles]
            provider = "supabase"
            [dbCredentials]
            url = "postgres://localhost/db"
        "#,
            Path::new("test.toml"),
        )
        .unwrap();
        let db2 = cfg2.default_database().unwrap();
        assert!(db2.roles_enabled());
        assert!(!db2.should_include_role("anon")); // Supabase built-in
        assert!(db2.should_include_role("my_custom_role"));
    }

    #[test]
    fn extensions_filter() {
        let cfg = Config::load_from_str(
            r#"
            dialect = "postgresql"
            extensionsFilters = ["postgis"]
            [dbCredentials]
            url = "postgres://localhost/db"
        "#,
            Path::new("test.toml"),
        )
        .unwrap();
        let db = cfg.default_database().unwrap();
        assert!(db.has_extension(Extension::Postgis));
    }

    #[test]
    fn migrations_config() {
        let cfg = Config::load_from_str(
            r#"
            dialect = "postgresql"
            [migrations]
            table = "custom_migrations"
            schema = "custom_schema"
            [dbCredentials]
            url = "postgres://localhost/db"
        "#,
            Path::new("test.toml"),
        )
        .unwrap();
        let db = cfg.default_database().unwrap();
        assert_eq!(db.migrations_table(), "custom_migrations");
        assert_eq!(db.migrations_schema(), "custom_schema");

        // Test defaults
        let cfg2 = Config::load_from_str(
            r#"
            dialect = "postgresql"
            [dbCredentials]
            url = "postgres://localhost/db"
        "#,
            Path::new("test.toml"),
        )
        .unwrap();
        let db2 = cfg2.default_database().unwrap();
        assert_eq!(db2.migrations_table(), "__drizzle_migrations");
        assert_eq!(db2.migrations_schema(), "drizzle");
    }

    #[test]
    fn resolves_paths_relative_to_config_dir() {
        let tmp = TempDir::new().unwrap();
        let cfg_dir = tmp.path().join("cfg");
        fs::create_dir_all(&cfg_dir).unwrap();

        // Create schema file next to config file.
        let schema_path = cfg_dir.join("schema.rs");
        fs::write(&schema_path, "#[allow(dead_code)]\npub struct X;").unwrap();

        let cfg_path = cfg_dir.join("drizzle.config.toml");
        let cfg = Config::load_from_str(
            r#"
            dialect = "sqlite"
            schema = "schema.rs"
            out = "./drizzle"
            [dbCredentials]
            url = "./dev.db"
        "#,
            &cfg_path,
        )
        .unwrap();

        let db = cfg.default_database().unwrap();
        assert_eq!(db.migrations_dir(), cfg_dir.join("./drizzle").as_path());

        let files = db.schema_files().unwrap();
        assert_eq!(files.len(), 1);
        assert_eq!(files[0], schema_path);
    }

    #[test]
    fn rejects_host_credentials_for_sqlite() {
        let err = Config::load_from_str(
            r#"
            dialect = "sqlite"
            [dbCredentials]
            host = "localhost"
            database = "db"
        "#,
            Path::new("test.toml"),
        )
        .unwrap_err();

        let msg = err.to_string();
        assert!(
            msg.contains("host-based dbCredentials are only supported"),
            "unexpected error: {msg}"
        );
    }

    #[cfg(windows)]
    #[test]
    fn schema_files_accept_backslash_paths() {
        let tmp = TempDir::new().unwrap();
        let cfg_dir = tmp.path().join("cfg");
        fs::create_dir_all(&cfg_dir).unwrap();

        let schema_path = cfg_dir.join("src").join("schema.rs");
        fs::create_dir_all(schema_path.parent().unwrap()).unwrap();
        fs::write(&schema_path, "#[allow(dead_code)]\npub struct X;").unwrap();

        // Write schema path with backslashes (common on Windows).
        let schema_str = schema_path.to_string_lossy().replace('/', "\\");
        // TOML basic strings treat backslash as an escape; double-escape to embed a Windows path.
        let schema_toml = schema_str.replace('\\', "\\\\");
        let cfg_path = cfg_dir.join("drizzle.config.toml");
        let cfg = Config::load_from_str(
            &format!(
                r#"
                dialect = "sqlite"
                schema = "{}"
            "#,
                schema_toml
            ),
            &cfg_path,
        )
        .unwrap();

        let db = cfg.default_database().unwrap();
        let files = db.schema_files().unwrap();
        assert_eq!(files, vec![schema_path]);
    }
}