1use serde::Deserialize;
10use serde::de::{self, Deserializer, MapAccess, Visitor};
11use std::collections::HashMap;
12use std::path::{Path, PathBuf};
13
/// Default config file name. [`Config::load`] opens it as a relative path,
/// i.e. relative to the process's current working directory.
pub const CONFIG_FILE: &str = "drizzle.config.toml";
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Deserialize)]
22pub enum Casing {
23 #[default]
25 #[serde(rename = "camelCase")]
26 CamelCase,
27 #[serde(rename = "snake_case")]
29 SnakeCase,
30}
31
32impl Casing {
33 pub const fn as_str(self) -> &'static str {
34 match self {
35 Self::CamelCase => "camelCase",
36 Self::SnakeCase => "snake_case",
37 }
38 }
39}
40
41impl std::fmt::Display for Casing {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 f.write_str(self.as_str())
44 }
45}
46
47impl std::str::FromStr for Casing {
48 type Err = String;
49
50 fn from_str(s: &str) -> Result<Self, Self::Err> {
51 match s {
52 "camelCase" | "camel" => Ok(Self::CamelCase),
53 "snake_case" | "snake" => Ok(Self::SnakeCase),
54 _ => Err(format!(
55 "invalid casing '{}', expected 'camelCase' or 'snake_case'",
56 s
57 )),
58 }
59 }
60}
61
62#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Deserialize)]
64pub enum IntrospectCasing {
65 #[default]
67 #[serde(rename = "camel")]
68 Camel,
69 #[serde(rename = "preserve")]
71 Preserve,
72}
73
74impl IntrospectCasing {
75 pub const fn as_str(self) -> &'static str {
76 match self {
77 Self::Camel => "camel",
78 Self::Preserve => "preserve",
79 }
80 }
81}
82
83impl std::fmt::Display for IntrospectCasing {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 f.write_str(self.as_str())
86 }
87}
88
89impl std::str::FromStr for IntrospectCasing {
90 type Err = String;
91
92 fn from_str(s: &str) -> Result<Self, Self::Err> {
93 match s {
94 "camel" | "camelCase" => Ok(Self::Camel),
95 "preserve" => Ok(Self::Preserve),
96 _ => Err(format!(
97 "invalid introspect casing '{}', expected 'camel' or 'preserve'",
98 s
99 )),
100 }
101 }
102}
103
/// Options from the `[introspect]` config table.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct IntrospectConfig {
    /// Introspection casing mode. Falls back to [`IntrospectCasing::Camel`]
    /// (the enum's `#[default]`) when the key is omitted.
    #[serde(default)]
    pub casing: IntrospectCasing,
}
111
/// Role filter from the `roles` key of the `[entities]` table.
///
/// Accepts either a boolean (`roles = true`) or a table with optional
/// `provider`, `include` and `exclude` keys (`[entities.roles]`).
/// Because the enum is `untagged`, serde tries `Bool` first, then `Config`.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum RolesFilter {
    /// `true` includes every role; `false` includes none.
    Bool(bool),
    /// Fine-grained filter. See [`RolesFilter::should_include`] for how the keys combine.
    Config {
        /// Provider whose built-in roles are always excluded
        /// (`"supabase"` and `"neon"` are recognized; other values exclude nothing).
        #[serde(default)]
        provider: Option<String>,
        /// When set, only these role names are included.
        #[serde(default)]
        include: Option<Vec<String>>,
        /// Role names that are always excluded.
        #[serde(default)]
        exclude: Option<Vec<String>>,
    },
}
138
139impl Default for RolesFilter {
140 fn default() -> Self {
141 Self::Bool(false)
142 }
143}
144
145impl RolesFilter {
146 pub fn is_enabled(&self) -> bool {
148 match self {
149 Self::Bool(b) => *b,
150 Self::Config { .. } => true,
151 }
152 }
153
154 pub fn should_include(&self, role_name: &str) -> bool {
156 match self {
157 Self::Bool(b) => *b,
158 Self::Config {
159 provider,
160 include,
161 exclude,
162 } => {
163 if let Some(p) = provider
165 && is_provider_role(p, role_name)
166 {
167 return false;
168 }
169 if let Some(excl) = exclude
171 && excl.iter().any(|e| e == role_name)
172 {
173 return false;
174 }
175 if let Some(incl) = include {
177 return incl.iter().any(|i| i == role_name);
178 }
179 true
180 }
181 }
182 }
183}
184
/// Returns `true` if `role_name` is a built-in role of the hosting `provider`.
///
/// Only `"supabase"` and `"neon"` are recognized. Any other provider has no
/// built-in roles.
fn is_provider_role(provider: &str, role_name: &str) -> bool {
    const SUPABASE_ROLES: &[&str] = &[
        "anon",
        "authenticated",
        "service_role",
        "supabase_admin",
        "supabase_auth_admin",
        "supabase_storage_admin",
        "dashboard_user",
        "supabase_replication_admin",
        "supabase_read_only_user",
        "supabase_realtime_admin",
        "supabase_functions_admin",
        "postgres",
        "pgbouncer",
        "pgsodium_keyholder",
        "pgsodium_keyiduser",
        "pgsodium_keymaker",
    ];
    const NEON_ROLES: &[&str] = &["neon_superuser", "cloud_admin", "authenticated", "anonymous"];

    let builtin: &[&str] = match provider {
        "supabase" => SUPABASE_ROLES,
        "neon" => NEON_ROLES,
        _ => &[],
    };
    builtin.contains(&role_name)
}
214
/// Options from the `[entities]` config table.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct EntitiesFilter {
    /// Role filter. Defaults to `RolesFilter::Bool(false)` (no roles) when omitted.
    #[serde(default)]
    pub roles: RolesFilter,
}
224
/// Database extension named in the `extensionsFilters` config list.
/// Deserialized from its lowercase name (e.g. `"postgis"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Extension {
    /// PostGIS.
    Postgis,
}
236
237impl Extension {
238 pub const fn as_str(self) -> &'static str {
239 match self {
240 Self::Postgis => "postgis",
241 }
242 }
243}
244
245impl std::fmt::Display for Extension {
246 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247 f.write_str(self.as_str())
248 }
249}
250
/// A config value given either literally or as a reference to an environment variable.
///
/// In TOML: `url = "postgres://..."` or `url = { env = "DATABASE_URL" }`.
#[derive(Debug, Clone)]
pub enum EnvOr {
    /// A literal value, used as-is.
    Value(String),
    /// The name of an environment variable, read when the value is resolved.
    Env(String),
}
269
270impl EnvOr {
271 pub fn resolve(&self) -> Result<String, Error> {
273 match self {
274 Self::Value(v) => Ok(v.clone()),
275 Self::Env(var) => std::env::var(var).map_err(|_| Error::EnvNotFound(var.clone())),
276 }
277 }
278
279 pub fn resolve_optional(&self) -> Result<Option<String>, Error> {
281 match self {
282 Self::Value(v) => Ok(Some(v.clone())),
283 Self::Env(var) => match std::env::var(var) {
284 Ok(v) => Ok(Some(v)),
285 Err(std::env::VarError::NotPresent) => Ok(None),
286 Err(std::env::VarError::NotUnicode(_)) => Err(Error::EnvInvalid(
287 var.clone(),
288 "contains invalid unicode".into(),
289 )),
290 },
291 }
292 }
293}
294
295impl<'de> Deserialize<'de> for EnvOr {
296 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
297 where
298 D: Deserializer<'de>,
299 {
300 struct EnvOrVisitor;
301
302 impl<'de> Visitor<'de> for EnvOrVisitor {
303 type Value = EnvOr;
304
305 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
306 formatter.write_str("a string or { env = \"VAR_NAME\" }")
307 }
308
309 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
310 where
311 E: de::Error,
312 {
313 Ok(EnvOr::Value(value.to_string()))
314 }
315
316 fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
317 where
318 M: MapAccess<'de>,
319 {
320 let mut env_var: Option<String> = None;
321
322 while let Some(key) = map.next_key::<String>()? {
323 if key == "env" {
324 env_var = Some(map.next_value()?);
325 } else {
326 return Err(de::Error::unknown_field(&key, &["env"]));
327 }
328 }
329
330 env_var
331 .map(EnvOr::Env)
332 .ok_or_else(|| de::Error::missing_field("env"))
333 }
334 }
335
336 deserializer.deserialize_any(EnvOrVisitor)
337 }
338}
339
/// SQL dialect selected by the `dialect` config key.
///
/// Deserialized from lowercase names: `"sqlite"` (the default),
/// `"postgresql"` (alias `"postgres"`) and `"turso"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Dialect {
    /// Local SQLite file.
    #[default]
    Sqlite,
    /// PostgreSQL.
    #[serde(alias = "postgres")]
    Postgresql,
    /// Turso / libSQL. Maps to the SQLite base dialect (see `to_base`).
    Turso,
}
354
355impl Dialect {
356 pub const ALL: &'static [&'static str] = &["sqlite", "postgresql", "turso"];
357
358 #[inline]
359 pub const fn as_str(self) -> &'static str {
360 match self {
361 Self::Sqlite => "sqlite",
362 Self::Postgresql => "postgresql",
363 Self::Turso => "turso",
364 }
365 }
366
367 #[inline]
368 pub const fn to_base(self) -> drizzle_types::Dialect {
369 match self {
370 Self::Sqlite | Self::Turso => drizzle_types::Dialect::SQLite,
371 Self::Postgresql => drizzle_types::Dialect::PostgreSQL,
372 }
373 }
374}
375
376impl std::fmt::Display for Dialect {
377 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
378 f.write_str(self.as_str())
379 }
380}
381
382impl From<Dialect> for drizzle_types::Dialect {
383 #[inline]
384 fn from(d: Dialect) -> Self {
385 d.to_base()
386 }
387}
388
/// Client driver selected by the optional `driver` config key.
/// Deserialized from kebab-case names (e.g. `"tokio-postgres"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum Driver {
    /// Valid for `dialect = "sqlite"`.
    Rusqlite,
    /// Valid for `dialect = "turso"`.
    Libsql,
    /// Valid for `dialect = "turso"`.
    Turso,
    /// Valid for `dialect = "postgresql"`.
    PostgresSync,
    /// Valid for `dialect = "postgresql"`.
    TokioPostgres,
}
408
409impl Driver {
410 pub const ALL: &'static [&'static str] = &[
411 "rusqlite",
412 "libsql",
413 "turso",
414 "postgres-sync",
415 "tokio-postgres",
416 ];
417
418 #[inline]
419 pub const fn as_str(self) -> &'static str {
420 match self {
421 Self::Rusqlite => "rusqlite",
422 Self::Libsql => "libsql",
423 Self::Turso => "turso",
424 Self::PostgresSync => "postgres-sync",
425 Self::TokioPostgres => "tokio-postgres",
426 }
427 }
428
429 pub const fn valid_for(dialect: Dialect) -> &'static [Driver] {
430 match dialect {
431 Dialect::Sqlite => &[Self::Rusqlite],
432 Dialect::Turso => &[Self::Libsql, Self::Turso],
433 Dialect::Postgresql => &[Self::PostgresSync, Self::TokioPostgres],
434 }
435 }
436
437 #[inline]
438 pub const fn is_valid_for(self, dialect: Dialect) -> bool {
439 matches!(
440 (self, dialect),
441 (Self::Rusqlite, Dialect::Sqlite)
442 | (Self::Libsql | Self::Turso, Dialect::Turso)
443 | (
444 Self::PostgresSync | Self::TokioPostgres,
445 Dialect::Postgresql
446 )
447 )
448 }
449}
450
451impl std::fmt::Display for Driver {
452 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
453 f.write_str(self.as_str())
454 }
455}
456
/// Fully resolved connection credentials, produced by `DatabaseConfig::credentials`
/// after all `{ env = ... }` references have been read.
#[derive(Debug, Clone)]
pub enum Credentials {
    /// SQLite: the resolved `dbCredentials.url` value, used as a file path.
    Sqlite { path: Box<str> },

    /// Turso / libSQL: remote URL plus optional `authToken`.
    Turso {
        url: Box<str>,
        auth_token: Option<Box<str>>,
    },

    /// PostgreSQL: either a URL or discrete host fields.
    Postgres(PostgresCreds),
}
476
/// PostgreSQL connection parameters, either a ready-made URL or discrete fields.
#[derive(Debug, Clone)]
pub enum PostgresCreds {
    /// A complete connection URL, used verbatim.
    Url(Box<str>),
    /// Discrete fields from a host-based `dbCredentials` table.
    Host {
        host: Box<str>,
        port: u16,
        user: Option<Box<str>>,
        password: Option<Box<str>>,
        database: Box<str>,
        ssl: bool,
    },
}

impl PostgresCreds {
    /// Builds a `postgres://` connection URL.
    ///
    /// [`PostgresCreds::Url`] is returned unchanged. For [`PostgresCreds::Host`],
    /// the user, password and database are percent-encoded, so characters such
    /// as `@`, `:`, `/` or `#` in a password cannot corrupt the URL. IPv6 host
    /// literals are wrapped in `[...]`. A password without a user is dropped.
    /// NOTE(review): `ssl` is not reflected in the URL; confirm callers
    /// configure TLS from the field directly.
    pub fn connection_url(&self) -> String {
        match self {
            Self::Url(url) => url.to_string(),
            Self::Host {
                host,
                port,
                user,
                password,
                database,
                ..
            } => {
                let auth = match (user, password) {
                    (Some(u), Some(p)) => {
                        format!("{}:{}@", encode_url_component(u), encode_url_component(p))
                    }
                    (Some(u), None) => format!("{}@", encode_url_component(u)),
                    _ => String::new(),
                };
                let host = format_url_host(host);
                let database = encode_url_component(database);
                format!("postgres://{auth}{host}:{port}/{database}")
            }
        }
    }
}

/// Formats `host` for the authority part of a URL.
///
/// IPv6 literals are bracketed (unless already bracketed). Any other host is
/// percent-encoded, which leaves ordinary hostnames and IPv4 addresses untouched.
fn format_url_host(host: &str) -> String {
    if host.starts_with('[') {
        host.to_owned()
    } else if host.contains(':') {
        format!("[{host}]")
    } else {
        encode_url_component(host)
    }
}

/// Percent-encodes every byte outside the RFC 3986 "unreserved" set
/// (`A-Z a-z 0-9 - . _ ~`), so the result is safe in any URL component.
fn encode_url_component(raw: &str) -> String {
    use std::fmt::Write as _;

    let mut out = String::with_capacity(raw.len());
    for byte in raw.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            out.push(char::from(byte));
        } else {
            // Writing into a String cannot fail.
            let _ = write!(out, "%{byte:02X}");
        }
    }
    out
}
514
/// The `schema` config key: a single path/glob or a list of them.
/// Because the enum is `untagged`, a TOML string maps to `One` and an array to `Many`.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum Schema {
    One(String),
    Many(Vec<String>),
}
526
527impl Default for Schema {
528 fn default() -> Self {
529 Self::One("src/schema.rs".into())
530 }
531}
532
533impl Schema {
534 pub fn iter(&self) -> impl Iterator<Item = &str> {
535 match self {
536 Self::One(s) => std::slice::from_ref(s).iter().map(String::as_str),
537 Self::Many(v) => v.iter().map(String::as_str),
538 }
539 }
540}
541
/// A string-or-list config value (used by `tablesFilter` and `schemaFilter`).
/// Because the enum is `untagged`, a TOML string maps to `One` and an array to `Many`.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum Filter {
    One(String),
    Many(Vec<String>),
}
549
550impl Filter {
551 pub fn iter(&self) -> impl Iterator<Item = &str> {
552 match self {
553 Self::One(s) => std::slice::from_ref(s).iter().map(String::as_str),
554 Self::Many(v) => v.iter().map(String::as_str),
555 }
556 }
557}
558
/// Options from the `[migrations]` config table. All keys are optional.
#[derive(Debug, Clone, Deserialize)]
pub struct MigrationsOpts {
    /// Migrations bookkeeping table; `migrations_table()` falls back to `"__drizzle_migrations"`.
    pub table: Option<String>,
    /// Schema holding that table; `migrations_schema()` falls back to `"drizzle"`.
    pub schema: Option<String>,
    /// Migration file prefix style.
    /// NOTE(review): parsed but not read anywhere in this file; confirm a consumer exists.
    pub prefix: Option<MigrationPrefix>,
}
566
/// Value of `[migrations] prefix`, deserialized from its lowercase name.
/// NOTE(review): the naming scheme each variant implies is not defined in this file.
#[derive(Debug, Clone, Copy, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MigrationPrefix {
    Index,
    Timestamp,
    Supabase,
    Unix,
    None,
}
576
/// The `dbCredentials` table exactly as written, before env resolution.
///
/// The enum is `untagged`, so serde tries `Url` first (requires `url`) and
/// falls back to `Host` (requires `host` and `database`).
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
enum RawCreds {
    /// URL form: `url` plus an optional `authToken`.
    Url {
        url: EnvOr,
        #[serde(default, rename = "authToken")]
        auth_token: Option<EnvOr>,
    },
    /// Host form. Only valid for PostgreSQL (enforced in `validate_creds`).
    Host {
        host: EnvOr,
        // Defaults to 5432 when resolved in `credentials()`.
        #[serde(default)]
        port: Option<u16>,
        #[serde(default)]
        user: Option<EnvOr>,
        #[serde(default)]
        password: Option<EnvOr>,
        database: EnvOr,
        #[serde(default)]
        ssl: Option<SslVal>,
    },
}
602
/// The `ssl` key of host-based credentials: a boolean or a string such as
/// `"require"` / `"disable"` (interpreted by `SslVal::enabled`).
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
enum SslVal {
    Bool(bool),
    Str(String),
}
609
610impl SslVal {
611 fn enabled(&self) -> bool {
612 match self {
613 Self::Bool(b) => *b,
614 Self::Str(s) => !matches!(s.as_str(), "disable" | "false" | "no" | "off"),
615 }
616 }
617}
618
/// Configuration for a single database: either the whole config file, or one
/// `[databases.<name>]` table. Keys are camelCase in TOML (`rename_all`).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DatabaseConfig {
    /// Required SQL dialect.
    pub dialect: Dialect,

    /// Schema file path(s) or glob(s); defaults to `src/schema.rs`.
    /// Relative entries are re-rooted on the config file's directory at load time.
    #[serde(default)]
    pub schema: Schema,

    /// Migrations output directory; defaults to `./drizzle`.
    /// Relative paths are re-rooted on the config file's directory at load time.
    #[serde(default = "default_out")]
    pub out: PathBuf,

    /// Defaults to `true`.
    #[serde(default = "yes")]
    pub breakpoints: bool,

    /// Optional driver. Must be compatible with `dialect` (checked in `validate`).
    #[serde(default)]
    pub driver: Option<Driver>,

    /// Raw `dbCredentials` table. Use `credentials()` to get resolved values.
    #[serde(default)]
    db_credentials: Option<RawCreds>,

    #[serde(default)]
    pub tables_filter: Option<Filter>,

    #[serde(default)]
    pub schema_filter: Option<Filter>,

    /// Read via `extensions()` / `has_extension()`.
    #[serde(default)]
    pub extensions_filters: Option<Vec<Extension>>,

    /// Read via `effective_entities()`, `should_include_role()`, `roles_enabled()`.
    #[serde(default)]
    pub entities: Option<EntitiesFilter>,

    /// Read via `effective_casing()`, which falls back to camelCase.
    #[serde(default)]
    pub casing: Option<Casing>,

    /// Read via `effective_introspect_casing()`.
    #[serde(default)]
    pub introspect: Option<IntrospectConfig>,

    #[serde(default)]
    pub verbose: bool,

    #[serde(default)]
    pub strict: bool,

    /// Read via `migrations_table()` / `migrations_schema()`.
    #[serde(default)]
    pub migrations: Option<MigrationsOpts>,
}
688
/// Serde default for `out`: the `./drizzle` directory.
fn default_out() -> PathBuf {
    let dir: &str = "./drizzle";
    PathBuf::from(dir)
}
692
/// Serde default for boolean keys that are on unless disabled.
const fn yes() -> bool {
    true
}
696
697impl DatabaseConfig {
698 fn normalize_paths(&mut self, base_dir: &Path) {
699 if self.out.is_relative() {
702 self.out = base_dir.join(&self.out);
703 }
704
705 let base = base_dir.to_string_lossy().replace('\\', "/");
709 let base = base.trim_end_matches('/').to_string();
710
711 let normalize_one = |p: &str| -> String {
712 let p_trim = p.trim();
713 let is_abs = Path::new(p_trim).is_absolute() || p_trim.starts_with("\\\\");
714 let joined = if is_abs || base.is_empty() || base == "." {
715 p_trim.to_string()
716 } else {
717 format!("{base}/{p_trim}")
718 };
719 joined.replace('\\', "/")
720 };
721
722 match &mut self.schema {
723 Schema::One(p) => *p = normalize_one(p),
724 Schema::Many(v) => {
725 for p in v.iter_mut() {
726 *p = normalize_one(p);
727 }
728 }
729 }
730 }
731
732 fn validate(&self, name: &str) -> Result<(), Error> {
733 if let Some(d) = self.driver
735 && !d.is_valid_for(self.dialect)
736 {
737 return Err(Error::InvalidDriver {
738 driver: d,
739 dialect: self.dialect,
740 });
741 }
742
743 if let Some(ref raw) = self.db_credentials {
745 self.validate_creds(raw, name)?;
746 }
747
748 Ok(())
749 }
750
751 fn validate_creds(&self, raw: &RawCreds, _name: &str) -> Result<(), Error> {
752 let err = |msg: &str| Error::InvalidCredentials(msg.into());
753
754 match (self.dialect, raw) {
757 (Dialect::Postgresql, RawCreds::Host { .. }) => {}
758 (Dialect::Postgresql, RawCreds::Url { .. }) => {}
759 (_, RawCreds::Host { .. }) => {
760 return Err(err(
761 "host-based dbCredentials are only supported for dialect = \"postgresql\"",
762 ));
763 }
764 _ => {}
765 }
766
767 match (self.dialect, raw) {
769 (
770 Dialect::Sqlite,
771 RawCreds::Url {
772 auth_token: Some(_),
773 ..
774 },
775 ) => Err(err(
776 "SQLite doesn't support authToken (use dialect = \"turso\")",
777 )),
778 (
779 Dialect::Sqlite,
780 RawCreds::Url {
781 url: EnvOr::Value(url),
782 ..
783 },
784 ) if url.starts_with("libsql://") => Err(err(
785 "libsql:// URLs require dialect = \"turso\" (for local SQLite files, use ./path.db)",
786 )),
787 (
788 Dialect::Sqlite,
789 RawCreds::Url {
790 url: EnvOr::Value(url),
791 ..
792 },
793 ) if url.starts_with("http://")
794 || url.starts_with("https://")
795 || url.starts_with("postgres://")
796 || url.starts_with("postgresql://") =>
797 {
798 Err(err(
799 "SQLite dbCredentials.url must be a local file path (not an http(s)/postgres URL)",
800 ))
801 }
802 (
803 Dialect::Turso,
804 RawCreds::Url {
805 url: EnvOr::Value(url),
806 ..
807 },
808 ) if !url.starts_with("libsql://") && !url.starts_with("http") => {
809 Err(err("Turso URL must start with libsql:// or http(s)://"))
810 }
811 (
812 Dialect::Postgresql,
813 RawCreds::Url {
814 url: EnvOr::Value(url),
815 ..
816 },
817 ) if !url.starts_with("postgres") => {
818 Err(err("PostgreSQL URL must start with postgres://"))
819 }
820 _ => Ok(()),
821 }
822 }
823
824 pub fn credentials(&self) -> Result<Option<Credentials>, Error> {
826 let raw = match self.db_credentials.as_ref() {
827 Some(r) => r,
828 None => return Ok(None),
829 };
830
831 let resolve_opt = |opt: &Option<EnvOr>| -> Result<Option<Box<str>>, Error> {
833 match opt {
834 Some(e) => e.resolve().map(|s| Some(s.into_boxed_str())),
835 None => Ok(None),
836 }
837 };
838
839 let creds = match (self.dialect, raw) {
840 (Dialect::Sqlite, RawCreds::Url { url, .. }) => Credentials::Sqlite {
842 path: url.resolve()?.into_boxed_str(),
843 },
844 (Dialect::Turso, RawCreds::Url { url, auth_token }) => Credentials::Turso {
846 url: url.resolve()?.into_boxed_str(),
847 auth_token: resolve_opt(auth_token)?,
848 },
849 (Dialect::Postgresql, RawCreds::Url { url, .. }) => {
851 Credentials::Postgres(PostgresCreds::Url(url.resolve()?.into_boxed_str()))
852 }
853 (
855 Dialect::Postgresql,
856 RawCreds::Host {
857 host,
858 port,
859 user,
860 password,
861 database,
862 ssl,
863 },
864 ) => Credentials::Postgres(PostgresCreds::Host {
865 host: host.resolve()?.into_boxed_str(),
866 port: port.unwrap_or(5432),
867 user: resolve_opt(user)?,
868 password: resolve_opt(password)?,
869 database: database.resolve()?.into_boxed_str(),
870 ssl: ssl.as_ref().map(|s| s.enabled()).unwrap_or(false),
871 }),
872 _ => return Ok(None),
873 };
874
875 Ok(Some(creds))
876 }
877
878 #[inline]
880 pub fn migrations_dir(&self) -> &Path {
881 &self.out
882 }
883
884 #[inline]
886 pub fn meta_dir(&self) -> PathBuf {
887 self.out.join("meta")
888 }
889
890 #[inline]
892 pub fn journal_path(&self) -> PathBuf {
893 self.meta_dir().join("_journal.json")
894 }
895
896 pub fn schema_display(&self) -> String {
898 match &self.schema {
899 Schema::One(s) => s.clone(),
900 Schema::Many(v) => v.join(", "),
901 }
902 }
903
904 pub fn schema_files(&self) -> Result<Vec<PathBuf>, Error> {
906 let mut files = Vec::new();
907
908 for pattern in self.schema.iter() {
909 let pat = pattern.trim();
910
911 let is_glob = pat.contains('*') || pat.contains('?') || pat.contains('[');
913 if !is_glob {
914 let p = PathBuf::from(pat);
915 if p.exists() {
916 files.push(p);
917 continue;
918 }
919 }
920
921 let pat_norm = pat.replace('\\', "/");
923 match glob::glob(&pat_norm) {
924 Ok(paths) => {
925 let matched: Vec<_> = paths.filter_map(Result::ok).collect();
926 if matched.is_empty() && !is_glob {
927 let p = PathBuf::from(&pat_norm);
928 if p.exists() {
929 files.push(p);
930 }
931 } else {
932 files.extend(matched);
933 }
934 }
935 Err(e) => return Err(Error::Glob(pat.into(), e)),
936 }
937 }
938
939 files.retain(|p| p.is_file());
941 files.sort();
942 files.dedup();
943
944 if files.is_empty() {
945 return Err(Error::NoSchemaFiles(self.schema_display()));
946 }
947
948 Ok(files)
949 }
950
951 #[inline]
953 pub fn effective_casing(&self) -> Casing {
954 self.casing.unwrap_or_default()
955 }
956
957 #[inline]
959 pub fn effective_introspect_casing(&self) -> IntrospectCasing {
960 self.introspect
961 .as_ref()
962 .map(|i| i.casing)
963 .unwrap_or_default()
964 }
965
966 #[inline]
968 pub fn effective_entities(&self) -> EntitiesFilter {
969 self.entities.clone().unwrap_or_default()
970 }
971
972 pub fn should_include_role(&self, role_name: &str) -> bool {
974 self.entities
975 .as_ref()
976 .map(|e| e.roles.should_include(role_name))
977 .unwrap_or(false)
978 }
979
980 pub fn roles_enabled(&self) -> bool {
982 self.entities
983 .as_ref()
984 .map(|e| e.roles.is_enabled())
985 .unwrap_or(false)
986 }
987
988 pub fn extensions(&self) -> &[Extension] {
990 self.extensions_filters.as_deref().unwrap_or(&[])
991 }
992
993 pub fn has_extension(&self, ext: Extension) -> bool {
995 self.extensions_filters
996 .as_ref()
997 .map(|v| v.contains(&ext))
998 .unwrap_or(false)
999 }
1000
1001 pub fn migrations_table(&self) -> &str {
1003 self.migrations
1004 .as_ref()
1005 .and_then(|m| m.table.as_deref())
1006 .unwrap_or("__drizzle_migrations")
1007 }
1008
1009 pub fn migrations_schema(&self) -> &str {
1011 self.migrations
1012 .as_ref()
1013 .and_then(|m| m.schema.as_deref())
1014 .unwrap_or("drizzle")
1015 }
1016}
1017
/// Multi-database layout: one `[databases.<name>]` table per database.
#[derive(Debug, Clone, Deserialize)]
struct MultiDbConfig {
    databases: HashMap<String, DatabaseConfig>,
}
1027
/// A loaded, path-normalized and validated config file.
#[derive(Debug, Clone)]
pub struct Config {
    /// Databases by name. A single-database file is stored under [`DEFAULT_DB`].
    databases: HashMap<String, DatabaseConfig>,
    /// `true` when the file used the single-database (top-level) layout.
    is_single: bool,
}
1058
/// Name under which a single-database config is stored in [`Config`].
pub const DEFAULT_DB: &str = "default";
1061
1062impl Config {
1063 pub fn load() -> Result<Self, Error> {
1065 Self::load_from(Path::new(CONFIG_FILE))
1066 }
1067
1068 pub fn load_from(path: &Path) -> Result<Self, Error> {
1070 let content = std::fs::read_to_string(path).map_err(|e| {
1071 if e.kind() == std::io::ErrorKind::NotFound {
1072 Error::NotFound(path.into())
1073 } else {
1074 Error::Io(path.into(), e)
1075 }
1076 })?;
1077
1078 Self::load_from_str(&content, path)
1079 }
1080
1081 fn load_from_str(content: &str, path: &Path) -> Result<Self, Error> {
1083 let base_dir = path.parent().unwrap_or_else(|| Path::new("."));
1084
1085 if let Ok(multi) = toml::from_str::<MultiDbConfig>(content)
1087 && !multi.databases.is_empty()
1088 {
1089 let mut config = Self {
1090 databases: multi.databases,
1091 is_single: false,
1092 };
1093 for db in config.databases.values_mut() {
1094 db.normalize_paths(base_dir);
1095 }
1096 config.validate()?;
1097 return Ok(config);
1098 }
1099
1100 let db_config: DatabaseConfig =
1102 toml::from_str(content).map_err(|e| Error::Parse(path.into(), e))?;
1103
1104 let mut databases = HashMap::new();
1105 databases.insert(DEFAULT_DB.to_string(), db_config);
1106
1107 let mut config = Self {
1108 databases,
1109 is_single: true,
1110 };
1111 for db in config.databases.values_mut() {
1112 db.normalize_paths(base_dir);
1113 }
1114 config.validate()?;
1115 Ok(config)
1116 }
1117
1118 fn validate(&mut self) -> Result<(), Error> {
1119 for (name, db) in &self.databases {
1120 db.validate(name)?;
1121 }
1122 Ok(())
1123 }
1124
1125 pub fn is_single_database(&self) -> bool {
1127 self.is_single
1128 }
1129
1130 pub fn database_names(&self) -> impl Iterator<Item = &str> {
1132 self.databases.keys().map(String::as_str)
1133 }
1134
1135 pub fn database(&self, name: Option<&str>) -> Result<&DatabaseConfig, Error> {
1140 match name {
1141 None => {
1142 if self.is_single {
1144 self.databases.get(DEFAULT_DB).ok_or(Error::NoDatabases)
1145 } else if self.databases.len() == 1 {
1146 self.databases.values().next().ok_or(Error::NoDatabases)
1147 } else {
1148 Err(Error::DatabaseRequired(
1149 self.databases.keys().cloned().collect(),
1150 ))
1151 }
1152 }
1153 Some(name) => {
1154 if self.is_single {
1155 self.databases.get(DEFAULT_DB).ok_or(Error::NoDatabases)
1157 } else {
1158 self.databases
1159 .get(name)
1160 .ok_or_else(|| Error::DatabaseNotFound(name.to_string()))
1161 }
1162 }
1163 }
1164 }
1165
1166 pub fn default_database(&self) -> Result<&DatabaseConfig, Error> {
1168 self.database(None)
1169 }
1170
1171 pub fn dialect(&self) -> Dialect {
1177 self.default_database()
1178 .map(|d| d.dialect)
1179 .unwrap_or_default()
1180 }
1181
1182 pub fn credentials(&self) -> Result<Option<Credentials>, Error> {
1184 self.default_database()?.credentials()
1185 }
1186
1187 pub fn migrations_dir(&self) -> &Path {
1189 self.default_database()
1190 .map(|d| d.migrations_dir())
1191 .unwrap_or(Path::new("./drizzle"))
1192 }
1193
1194 pub fn journal_path(&self) -> PathBuf {
1196 self.default_database()
1197 .map(|d| d.journal_path())
1198 .unwrap_or_else(|_| PathBuf::from("./drizzle/meta/_journal.json"))
1199 }
1200
1201 pub fn schema_display(&self) -> String {
1203 self.default_database()
1204 .map(|d| d.schema_display())
1205 .unwrap_or_else(|_| "src/schema.rs".into())
1206 }
1207
1208 pub fn schema_files(&self) -> Result<Vec<PathBuf>, Error> {
1210 self.default_database()?.schema_files()
1211 }
1212
1213 pub fn base_dialect(&self) -> drizzle_types::Dialect {
1215 self.dialect().to_base()
1216 }
1217}
1218
/// Alias for [`Config`].
pub type DrizzleConfig = Config;
1221
/// Errors from loading, validating or resolving the config.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// The config file does not exist.
    #[error("config not found: {}", .0.display())]
    NotFound(PathBuf),

    /// Reading the config file failed for a reason other than "not found".
    #[error("failed to read {}: {}", .0.display(), .1)]
    Io(PathBuf, #[source] std::io::Error),

    /// The TOML could not be parsed into the expected shape.
    #[error("failed to parse {}: {}", .0.display(), .1)]
    Parse(PathBuf, #[source] toml::de::Error),

    /// `driver` is not usable with `dialect`.
    #[error("driver '{driver}' invalid for {dialect} dialect")]
    InvalidDriver { driver: Driver, dialect: Dialect },

    /// `dbCredentials` do not fit the dialect.
    #[error("invalid credentials: {0}")]
    InvalidCredentials(String),

    /// A schema glob pattern is malformed.
    #[error("invalid glob '{0}': {1}")]
    Glob(String, #[source] glob::PatternError),

    /// No schema files matched; holds the schema entries for display.
    #[error("no schema files found: {0}")]
    NoSchemaFiles(String),

    /// A referenced environment variable is unset.
    #[error("environment variable '{0}' not found")]
    EnvNotFound(String),

    /// A referenced environment variable is unusable; holds the name and reason.
    #[error("environment variable '{0}' invalid: {1}")]
    EnvInvalid(String, String),

    #[error("no databases configured")]
    NoDatabases,

    /// The requested database name does not exist.
    #[error("database '{0}' not found")]
    DatabaseNotFound(String),

    /// Several databases exist and none was chosen; holds the available names.
    #[error("multiple databases configured, use --db to specify: {}", .0.join(", "))]
    DatabaseRequired(Vec<String>),
}
1264
/// Alias for [`Error`].
pub type ConfigError = Error;
1266
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    // Minimal single-database SQLite config yields SQLite credentials.
    #[test]
    fn sqlite() {
        let cfg = Config::load_from_str(
            r#"
            dialect = "sqlite"
            [dbCredentials]
            url = "./dev.db"
            "#,
            Path::new("test.toml"),
        )
        .unwrap();
        assert!(cfg.is_single_database());
        assert!(matches!(
            cfg.credentials().unwrap(),
            Some(Credentials::Sqlite { .. })
        ));
    }

    // A PostgreSQL URL is kept in URL form.
    #[test]
    fn postgres_url() {
        let cfg = Config::load_from_str(
            r#"
            dialect = "postgresql"
            [dbCredentials]
            url = "postgres://localhost/db"
            "#,
            Path::new("test.toml"),
        )
        .unwrap();
        assert!(matches!(
            cfg.credentials().unwrap(),
            Some(Credentials::Postgres(PostgresCreds::Url(_)))
        ));
    }

    // `[databases.<name>]` tables produce a multi-database config addressable by name.
    #[test]
    fn multi_database() {
        let cfg = Config::load_from_str(
            r#"
            [databases.dev]
            dialect = "sqlite"
            out = "./drizzle/sqlite"
            [databases.dev.dbCredentials]
            url = "./dev.db"

            [databases.prod]
            dialect = "postgresql"
            out = "./drizzle/postgres"
            [databases.prod.dbCredentials]
            url = "postgres://localhost/db"
            "#,
            Path::new("test.toml"),
        )
        .unwrap();

        assert!(!cfg.is_single_database());
        let names: Vec<_> = cfg.database_names().collect();
        assert!(names.contains(&"dev"));
        assert!(names.contains(&"prod"));

        let dev = cfg.database(Some("dev")).unwrap();
        assert_eq!(dev.dialect, Dialect::Sqlite);

        let prod = cfg.database(Some("prod")).unwrap();
        assert_eq!(prod.dialect, Dialect::Postgresql);
    }

    // With several databases, `database(None)` must fail instead of guessing.
    #[test]
    fn multi_database_requires_selection() {
        let cfg = Config::load_from_str(
            r#"
            [databases.a]
            dialect = "sqlite"
            [databases.b]
            dialect = "postgresql"
            "#,
            Path::new("test.toml"),
        )
        .unwrap();

        assert!(cfg.database(None).is_err());
    }

    // `{ env = "..." }` parses; the variable itself is only read when credentials are resolved.
    #[test]
    fn env_var_syntax() {
        let cfg = Config::load_from_str(
            r#"
            dialect = "postgresql"
            [dbCredentials]
            url = { env = "DATABASE_URL" }
            "#,
            Path::new("test.toml"),
        )
        .unwrap();
        assert!(cfg.is_single_database());
    }

    // Explicit `casing` is honored; an absent key falls back to camelCase.
    #[test]
    fn casing_options() {
        let cfg = Config::load_from_str(
            r#"
            dialect = "postgresql"
            casing = "snake_case"
            [dbCredentials]
            url = "postgres://localhost/db"
            "#,
            Path::new("test.toml"),
        )
        .unwrap();
        let db = cfg.default_database().unwrap();
        assert_eq!(db.effective_casing(), Casing::SnakeCase);

        let cfg2 = Config::load_from_str(
            r#"
            dialect = "postgresql"
            [dbCredentials]
            url = "postgres://localhost/db"
            "#,
            Path::new("test.toml"),
        )
        .unwrap();
        let db2 = cfg2.default_database().unwrap();
        assert_eq!(db2.effective_casing(), Casing::CamelCase);
    }

    // `[introspect] casing` is read through `effective_introspect_casing`.
    #[test]
    fn introspect_casing() {
        let cfg = Config::load_from_str(
            r#"
            dialect = "postgresql"
            [introspect]
            casing = "preserve"
            [dbCredentials]
            url = "postgres://localhost/db"
            "#,
            Path::new("test.toml"),
        )
        .unwrap();
        let db = cfg.default_database().unwrap();
        assert_eq!(db.effective_introspect_casing(), IntrospectCasing::Preserve);
    }

    // Both the boolean and the table form of `roles`, including provider exclusions.
    #[test]
    fn entities_roles_filter() {
        let cfg = Config::load_from_str(
            r#"
            dialect = "postgresql"
            [entities]
            roles = true
            [dbCredentials]
            url = "postgres://localhost/db"
            "#,
            Path::new("test.toml"),
        )
        .unwrap();
        let db = cfg.default_database().unwrap();
        assert!(db.roles_enabled());
        assert!(db.should_include_role("my_role"));

        let cfg2 = Config::load_from_str(
            r#"
            dialect = "postgresql"
            [entities.roles]
            provider = "supabase"
            [dbCredentials]
            url = "postgres://localhost/db"
            "#,
            Path::new("test.toml"),
        )
        .unwrap();
        let db2 = cfg2.default_database().unwrap();
        assert!(db2.roles_enabled());
        // Supabase built-in roles are excluded; custom roles are kept.
        assert!(!db2.should_include_role("anon")); assert!(db2.should_include_role("my_custom_role"));
    }

    // `extensionsFilters` (camelCase key) maps to `Extension` values.
    #[test]
    fn extensions_filter() {
        let cfg = Config::load_from_str(
            r#"
            dialect = "postgresql"
            extensionsFilters = ["postgis"]
            [dbCredentials]
            url = "postgres://localhost/db"
            "#,
            Path::new("test.toml"),
        )
        .unwrap();
        let db = cfg.default_database().unwrap();
        assert!(db.has_extension(Extension::Postgis));
    }

    // Custom `[migrations]` values override the defaults; without the table the defaults apply.
    #[test]
    fn migrations_config() {
        let cfg = Config::load_from_str(
            r#"
            dialect = "postgresql"
            [migrations]
            table = "custom_migrations"
            schema = "custom_schema"
            [dbCredentials]
            url = "postgres://localhost/db"
            "#,
            Path::new("test.toml"),
        )
        .unwrap();
        let db = cfg.default_database().unwrap();
        assert_eq!(db.migrations_table(), "custom_migrations");
        assert_eq!(db.migrations_schema(), "custom_schema");

        let cfg2 = Config::load_from_str(
            r#"
            dialect = "postgresql"
            [dbCredentials]
            url = "postgres://localhost/db"
            "#,
            Path::new("test.toml"),
        )
        .unwrap();
        let db2 = cfg2.default_database().unwrap();
        assert_eq!(db2.migrations_table(), "__drizzle_migrations");
        assert_eq!(db2.migrations_schema(), "drizzle");
    }

    // Relative `out` and `schema` are re-rooted on the config file's directory.
    #[test]
    fn resolves_paths_relative_to_config_dir() {
        let tmp = TempDir::new().unwrap();
        let cfg_dir = tmp.path().join("cfg");
        fs::create_dir_all(&cfg_dir).unwrap();

        let schema_path = cfg_dir.join("schema.rs");
        fs::write(&schema_path, "#[allow(dead_code)]\npub struct X;").unwrap();

        let cfg_path = cfg_dir.join("drizzle.config.toml");
        let cfg = Config::load_from_str(
            r#"
            dialect = "sqlite"
            schema = "schema.rs"
            out = "./drizzle"
            [dbCredentials]
            url = "./dev.db"
            "#,
            &cfg_path,
        )
        .unwrap();

        let db = cfg.default_database().unwrap();
        assert_eq!(db.migrations_dir(), cfg_dir.join("./drizzle").as_path());

        let files = db.schema_files().unwrap();
        assert_eq!(files.len(), 1);
        assert_eq!(files[0], schema_path);
    }

    // Host-based credentials are rejected at load time for non-Postgres dialects.
    #[test]
    fn rejects_host_credentials_for_sqlite() {
        let err = Config::load_from_str(
            r#"
            dialect = "sqlite"
            [dbCredentials]
            host = "localhost"
            database = "db"
            "#,
            Path::new("test.toml"),
        )
        .unwrap_err();

        let msg = err.to_string();
        assert!(
            msg.contains("host-based dbCredentials are only supported"),
            "unexpected error: {msg}"
        );
    }

    // Windows-only: an absolute schema path written with backslashes still resolves.
    #[cfg(windows)]
    #[test]
    fn schema_files_accept_backslash_paths() {
        let tmp = TempDir::new().unwrap();
        let cfg_dir = tmp.path().join("cfg");
        fs::create_dir_all(&cfg_dir).unwrap();

        let schema_path = cfg_dir.join("src").join("schema.rs");
        fs::create_dir_all(schema_path.parent().unwrap()).unwrap();
        fs::write(&schema_path, "#[allow(dead_code)]\npub struct X;").unwrap();

        let schema_str = schema_path.to_string_lossy().replace('/', "\\");
        // Double each backslash so the TOML basic string keeps them literally.
        let schema_toml = schema_str.replace('\\', "\\\\");
        let cfg_path = cfg_dir.join("drizzle.config.toml");
        let cfg = Config::load_from_str(
            &format!(
                r#"
                dialect = "sqlite"
                schema = "{}"
                "#,
                schema_toml
            ),
            &cfg_path,
        )
        .unwrap();

        let db = cfg.default_database().unwrap();
        let files = db.schema_files().unwrap();
        assert_eq!(files, vec![schema_path]);
    }
}