1pub use drizzle_types::{Casing, EnvOr, EnvOrError};
10use schemars::JsonSchema;
11use serde::Deserialize;
12use std::collections::HashMap;
13use std::path::{Path, PathBuf};
14
15pub const CONFIG_FILE: &str = "drizzle.config.toml";
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Deserialize, JsonSchema)]
19pub enum IntrospectCasing {
20 #[default]
22 #[serde(rename = "camel")]
23 Camel,
24 #[serde(rename = "preserve")]
26 Preserve,
27}
28
29impl IntrospectCasing {
30 #[must_use]
31 pub const fn as_str(self) -> &'static str {
32 match self {
33 Self::Camel => "camel",
34 Self::Preserve => "preserve",
35 }
36 }
37}
38
39impl std::fmt::Display for IntrospectCasing {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 f.write_str(self.as_str())
42 }
43}
44
45impl std::str::FromStr for IntrospectCasing {
46 type Err = String;
47
48 fn from_str(s: &str) -> Result<Self, Self::Err> {
49 match s {
50 "camel" | "camelCase" => Ok(Self::Camel),
51 "preserve" => Ok(Self::Preserve),
52 _ => Err(format!(
53 "invalid introspect casing '{s}', expected 'camel' or 'preserve'"
54 )),
55 }
56 }
57}
58
/// Options from the `[introspect]` table.
#[derive(Debug, Clone, Default, Deserialize, JsonSchema)]
pub struct IntrospectConfig {
    /// Identifier casing; falls back to [`IntrospectCasing::default`] when omitted.
    #[serde(default)]
    pub casing: IntrospectCasing,
}
66
/// Selects which database roles are included (the `entities.roles` setting).
///
/// Accepts either a boolean (`roles = true`) or a table (`[entities.roles]`) with the
/// optional keys `provider`, `include` and `exclude`. Because the enum is untagged and
/// every `Config` field has a default, any table deserializes as `Config`.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
#[serde(untagged)]
pub enum RolesFilter {
    /// `true` includes every role, `false` includes none.
    Bool(bool),
    /// Fine-grained filter; see [`RolesFilter::should_include`] for how the keys combine.
    Config {
        /// Provider whose built-in roles are skipped; `"supabase"` and `"neon"` are
        /// recognized, any other value skips nothing.
        #[serde(default)]
        provider: Option<String>,
        /// When set, only role names in this list are included.
        #[serde(default)]
        include: Option<Vec<String>>,
        /// Role names that are always excluded.
        #[serde(default)]
        exclude: Option<Vec<String>>,
    },
}
93
94impl Default for RolesFilter {
95 fn default() -> Self {
96 Self::Bool(false)
97 }
98}
99
100impl RolesFilter {
101 #[must_use]
103 pub const fn is_enabled(&self) -> bool {
104 match self {
105 Self::Bool(b) => *b,
106 Self::Config { .. } => true,
107 }
108 }
109
110 #[must_use]
112 pub fn should_include(&self, role_name: &str) -> bool {
113 match self {
114 Self::Bool(b) => *b,
115 Self::Config {
116 provider,
117 include,
118 exclude,
119 } => {
120 if let Some(p) = provider
122 && is_provider_role(p, role_name)
123 {
124 return false;
125 }
126 if let Some(excl) = exclude
128 && excl.iter().any(|e| e == role_name)
129 {
130 return false;
131 }
132 if let Some(incl) = include {
134 return incl.iter().any(|i| i == role_name);
135 }
136 true
137 }
138 }
139 }
140}
141
/// Returns `true` when `role_name` is one of the built-in roles of `provider`.
///
/// Known providers are `"supabase"` and `"neon"` (exact, case-sensitive match); any
/// other provider has no built-in roles.
fn is_provider_role(provider: &str, role_name: &str) -> bool {
    const SUPABASE_ROLES: &[&str] = &[
        "anon",
        "authenticated",
        "service_role",
        "supabase_admin",
        "supabase_auth_admin",
        "supabase_storage_admin",
        "dashboard_user",
        "supabase_replication_admin",
        "supabase_read_only_user",
        "supabase_realtime_admin",
        "supabase_functions_admin",
        "postgres",
        "pgbouncer",
        "pgsodium_keyholder",
        "pgsodium_keyiduser",
        "pgsodium_keymaker",
    ];
    const NEON_ROLES: &[&str] = &["neon_superuser", "cloud_admin", "authenticated", "anonymous"];

    let reserved: &[&str] = match provider {
        "supabase" => SUPABASE_ROLES,
        "neon" => NEON_ROLES,
        _ => &[],
    };
    reserved.contains(&role_name)
}
171
/// Entity filters from the `[entities]` table.
///
/// Config loading rejects this table unless `dialect = "postgresql"`.
#[derive(Debug, Clone, Default, Deserialize, JsonSchema)]
pub struct EntitiesFilter {
    /// Role filter; disabled (`false`) when omitted.
    #[serde(default)]
    pub roles: RolesFilter,
}
181
/// Extension names accepted in `extensionsFilters` (lowercase in TOML).
///
/// Config loading rejects `extensionsFilters` unless `dialect = "postgresql"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum Extension {
    /// `"postgis"`.
    Postgis,
}
193
194impl Extension {
195 #[must_use]
196 pub const fn as_str(self) -> &'static str {
197 match self {
198 Self::Postgis => "postgis",
199 }
200 }
201}
202
203impl std::fmt::Display for Extension {
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 f.write_str(self.as_str())
206 }
207}
208
/// Database dialect selected by the `dialect` key (lowercase in TOML).
///
/// `"postgres"` is accepted as an alias for `"postgresql"`. Defaults to SQLite.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, Default, serde::Serialize, Deserialize, JsonSchema,
)]
#[serde(rename_all = "lowercase")]
pub enum Dialect {
    /// `"sqlite"` (the default).
    #[default]
    Sqlite,
    /// `"postgresql"` (or `"postgres"`).
    #[serde(alias = "postgres")]
    Postgresql,
    /// `"turso"`; maps to the SQLite base dialect (see `Dialect::to_base`).
    Turso,
}
225
226impl Dialect {
227 pub const ALL: &'static [&'static str] = &["sqlite", "postgresql", "turso"];
228
229 #[inline]
230 #[must_use]
231 pub const fn as_str(self) -> &'static str {
232 match self {
233 Self::Sqlite => "sqlite",
234 Self::Postgresql => "postgresql",
235 Self::Turso => "turso",
236 }
237 }
238
239 #[inline]
240 #[must_use]
241 pub const fn to_base(self) -> drizzle_types::Dialect {
242 match self {
243 Self::Sqlite | Self::Turso => drizzle_types::Dialect::SQLite,
244 Self::Postgresql => drizzle_types::Dialect::PostgreSQL,
245 }
246 }
247}
248
249impl std::fmt::Display for Dialect {
250 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251 f.write_str(self.as_str())
252 }
253}
254
255impl std::str::FromStr for Dialect {
256 type Err = String;
257
258 fn from_str(s: &str) -> Result<Self, Self::Err> {
259 match s {
260 "sqlite" => Ok(Self::Sqlite),
261 "postgresql" | "postgres" => Ok(Self::Postgresql),
262 "turso" => Ok(Self::Turso),
263 _ => Err(format!(
264 "invalid dialect '{}', expected one of: {}",
265 s,
266 Self::ALL.join(", ")
267 )),
268 }
269 }
270}
271
272impl From<Dialect> for drizzle_types::Dialect {
273 #[inline]
274 fn from(d: Dialect) -> Self {
275 d.to_base()
276 }
277}
278
/// Database driver selected by the optional `driver` key (kebab-case in TOML).
///
/// Each driver is valid only for specific dialects; see `Driver::valid_for`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, JsonSchema)]
#[serde(rename_all = "kebab-case")]
pub enum Driver {
    /// `"rusqlite"` (dialect `sqlite`).
    Rusqlite,
    /// `"libsql"` (dialect `turso`).
    Libsql,
    /// `"turso"` (dialect `turso`).
    Turso,
    /// `"postgres-sync"` (dialect `postgresql`).
    PostgresSync,
    /// `"tokio-postgres"` (dialect `postgresql`).
    TokioPostgres,
    /// `"d1-http"` (dialect `sqlite`); requires D1 `dbCredentials`.
    D1Http,
    /// `"durable-sqlite"` (dialect `sqlite`); codegen only, and enables bundling by default.
    DurableSqlite,
    /// `"aws-data-api"` (dialect `postgresql`); requires AWS Data API `dbCredentials`.
    AwsDataApi,
}
328
329impl Driver {
330 pub const ALL: &'static [&'static str] = &[
331 "rusqlite",
332 "libsql",
333 "turso",
334 "postgres-sync",
335 "tokio-postgres",
336 "d1-http",
337 "durable-sqlite",
338 "aws-data-api",
339 ];
340
341 #[inline]
342 #[must_use]
343 pub const fn as_str(self) -> &'static str {
344 match self {
345 Self::Rusqlite => "rusqlite",
346 Self::Libsql => "libsql",
347 Self::Turso => "turso",
348 Self::PostgresSync => "postgres-sync",
349 Self::TokioPostgres => "tokio-postgres",
350 Self::D1Http => "d1-http",
351 Self::DurableSqlite => "durable-sqlite",
352 Self::AwsDataApi => "aws-data-api",
353 }
354 }
355
356 #[must_use]
357 pub const fn valid_for(dialect: Dialect) -> &'static [Self] {
358 match dialect {
359 Dialect::Sqlite => &[Self::Rusqlite, Self::D1Http, Self::DurableSqlite],
363 Dialect::Turso => &[Self::Libsql, Self::Turso],
364 Dialect::Postgresql => &[Self::PostgresSync, Self::TokioPostgres, Self::AwsDataApi],
365 }
366 }
367
368 #[inline]
369 #[must_use]
370 pub const fn is_valid_for(self, dialect: Dialect) -> bool {
371 matches!(
372 (self, dialect),
373 (
374 Self::Rusqlite | Self::D1Http | Self::DurableSqlite,
375 Dialect::Sqlite
376 ) | (Self::Libsql | Self::Turso, Dialect::Turso)
377 | (
378 Self::PostgresSync | Self::TokioPostgres | Self::AwsDataApi,
379 Dialect::Postgresql
380 )
381 )
382 }
383
384 #[inline]
388 #[must_use]
389 pub const fn is_codegen_only(self) -> bool {
390 matches!(self, Self::DurableSqlite)
391 }
392}
393
394impl std::fmt::Display for Driver {
395 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
396 f.write_str(self.as_str())
397 }
398}
399
400impl std::str::FromStr for Driver {
401 type Err = String;
402
403 fn from_str(s: &str) -> Result<Self, Self::Err> {
404 match s {
405 "rusqlite" => Ok(Self::Rusqlite),
406 "libsql" => Ok(Self::Libsql),
407 "turso" => Ok(Self::Turso),
408 "postgres-sync" => Ok(Self::PostgresSync),
409 "tokio-postgres" => Ok(Self::TokioPostgres),
410 "d1-http" => Ok(Self::D1Http),
411 "durable-sqlite" => Ok(Self::DurableSqlite),
412 "aws-data-api" => Ok(Self::AwsDataApi),
413 _ => Err(format!(
414 "invalid driver '{}', expected one of: {}",
415 s,
416 Self::ALL.join(", ")
417 )),
418 }
419 }
420}
421
422impl std::str::FromStr for Extension {
423 type Err = String;
424
425 fn from_str(s: &str) -> Result<Self, Self::Err> {
426 match s {
427 "postgis" => Ok(Self::Postgis),
428 _ => Err(format!(
429 "invalid extension filter '{s}', expected 'postgis'"
430 )),
431 }
432 }
433}
434
/// Resolved connection credentials, produced by `DatabaseConfig::credentials`.
///
/// Environment-variable references have already been read into plain values.
#[derive(Debug, Clone)]
pub enum Credentials {
    /// SQLite: the `dbCredentials.url` value, used as a local file path.
    Sqlite { path: Box<str> },

    /// Turso: remote URL plus optional `authToken`.
    Turso {
        url: Box<str>,
        auth_token: Option<Box<str>>,
    },

    /// PostgreSQL: either a URL or host-based parameters.
    Postgres(PostgresCreds),

    /// D1 (`driver = "d1-http"`, `dialect = "sqlite"`): `accountId`, `databaseId`, `token`.
    D1 {
        account_id: Box<str>,
        database_id: Box<str>,
        token: Box<str>,
    },

    /// AWS Data API (`driver = "aws-data-api"`, `dialect = "postgresql"`):
    /// `database`, `secretArn`, `resourceArn`.
    AwsDataApi {
        database: Box<str>,
        secret_arn: Box<str>,
        resource_arn: Box<str>,
    },
}
476
/// PostgreSQL credentials: a full connection URL or individual host parameters.
#[derive(Debug, Clone)]
pub enum PostgresCreds {
    /// Connection URL, used verbatim.
    Url(Box<str>),
    /// Host-based parameters (`port` defaults to 5432 when loaded from config).
    Host {
        host: Box<str>,
        port: u16,
        user: Option<Box<str>>,
        password: Option<Box<str>>,
        database: Box<str>,
        /// Whether SSL was requested; not encoded into [`PostgresCreds::connection_url`].
        ssl: bool,
    },
}

impl PostgresCreds {
    /// Builds a `postgres://` connection URL.
    ///
    /// `Url` credentials are returned unchanged. For `Host` credentials:
    /// - user, password and database name are percent-encoded, so reserved characters
    ///   such as `@`, `:`, `/`, `#` or `?` in a password cannot corrupt the URL;
    /// - IPv6 literal hosts are wrapped in `[...]`;
    /// - hosts starting with `/` (unix-socket directories) are percent-encoded;
    /// - a password without a user is omitted.
    ///
    /// The `ssl` flag is not reflected in the returned URL.
    #[must_use]
    pub fn connection_url(&self) -> String {
        match self {
            Self::Url(url) => url.to_string(),
            Self::Host {
                host,
                port,
                user,
                password,
                database,
                ..
            } => {
                let auth = match (user, password) {
                    (Some(u), Some(p)) => {
                        format!("{}:{}@", pg_url_encode(u), pg_url_encode(p))
                    }
                    (Some(u), None) => format!("{}@", pg_url_encode(u)),
                    (None, _) => String::new(),
                };
                let host = pg_url_host(host);
                let database = pg_url_encode(database);
                format!("postgres://{auth}{host}:{port}/{database}")
            }
        }
    }
}

/// Percent-encodes every byte outside the RFC 3986 "unreserved" set
/// (`A-Z a-z 0-9 - . _ ~`), making `raw` safe to embed in a URL component.
fn pg_url_encode(raw: &str) -> String {
    use std::fmt::Write as _;

    let mut out = String::with_capacity(raw.len());
    for byte in raw.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            out.push(char::from(byte));
        } else {
            // Writing into a `String` cannot fail.
            let _ = write!(out, "%{byte:02X}");
        }
    }
    out
}

/// Formats `host` for the authority part of a URL.
fn pg_url_host(host: &str) -> String {
    if host.starts_with('/') {
        // Unix-socket directory: must be encoded so its slashes don't start the path.
        pg_url_encode(host)
    } else if host.contains(':') && !host.starts_with('[') {
        // Bare IPv6 literal: brackets separate it from the `:port` suffix.
        format!("[{host}]")
    } else {
        host.to_owned()
    }
}
515
516#[derive(Debug, Clone, Deserialize, JsonSchema)]
522#[serde(untagged)]
523pub enum Schema {
524 One(String),
525 Many(Vec<String>),
526}
527
528impl Default for Schema {
529 fn default() -> Self {
530 Self::One("src/schema.rs".into())
531 }
532}
533
534impl Schema {
535 pub fn iter(&self) -> impl Iterator<Item = &str> {
536 match self {
537 Self::One(s) => std::slice::from_ref(s).iter().map(String::as_str),
538 Self::Many(v) => v.iter().map(String::as_str),
539 }
540 }
541}
542
543#[derive(Debug, Clone, Deserialize, JsonSchema)]
545#[serde(untagged)]
546pub enum Filter {
547 One(String),
548 Many(Vec<String>),
549}
550
551impl Filter {
552 pub fn iter(&self) -> impl Iterator<Item = &str> {
553 match self {
554 Self::One(s) => std::slice::from_ref(s).iter().map(String::as_str),
555 Self::Many(v) => v.iter().map(String::as_str),
556 }
557 }
558}
559
/// Options from the `[migrations]` table.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
pub struct MigrationsOpts {
    /// Migrations bookkeeping table; `migrations_table()` falls back to `__drizzle_migrations`.
    pub table: Option<String>,
    /// Schema of the migrations table; `migrations_schema()` falls back to `drizzle`.
    pub schema: Option<String>,
    /// Migration file-name prefix style (NOTE(review): semantics are defined by the
    /// generator, not in this module — confirm there).
    pub prefix: Option<MigrationPrefix>,
    /// Explicit bundling toggle; when unset, `bundle_enabled()` is `true` only for
    /// `driver = "durable-sqlite"`.
    #[serde(default)]
    pub bundle: Option<bool>,
}

/// Value of `migrations.prefix` (lowercase in TOML).
#[derive(Debug, Clone, Copy, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum MigrationPrefix {
    /// `"index"`.
    Index,
    /// `"timestamp"`.
    Timestamp,
    /// `"supabase"`.
    Supabase,
    /// `"unix"`.
    Unix,
    /// `"none"`.
    None,
}
585
/// Unresolved `dbCredentials` table as written in the config file.
///
/// The enum is untagged: serde tries the variants in declaration order and takes the
/// first one whose required keys are present. The more specific D1 and AWS Data API
/// shapes therefore have to stay before the generic `Url` and `Host` shapes.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
#[serde(untagged)]
enum RawCreds {
    /// `accountId` / `databaseId` / `token` (D1 via `driver = "d1-http"`).
    D1 {
        #[serde(rename = "accountId")]
        account_id: EnvOr,
        #[serde(rename = "databaseId")]
        database_id: EnvOr,
        token: EnvOr,
    },
    /// `database` / `secretArn` / `resourceArn` (`driver = "aws-data-api"`).
    AwsDataApi {
        database: EnvOr,
        #[serde(rename = "secretArn")]
        secret_arn: EnvOr,
        #[serde(rename = "resourceArn")]
        resource_arn: EnvOr,
    },
    /// `url` plus optional `authToken`; for SQLite the URL is a local file path.
    Url {
        url: EnvOr,
        #[serde(default, rename = "authToken")]
        auth_token: Option<EnvOr>,
    },
    /// Host-based PostgreSQL parameters; `port` falls back to 5432 on resolution.
    Host {
        host: EnvOr,
        #[serde(default)]
        port: Option<u16>,
        #[serde(default)]
        user: Option<EnvOr>,
        #[serde(default)]
        password: Option<EnvOr>,
        database: EnvOr,
        #[serde(default)]
        ssl: Option<SslVal>,
    },
}
637
638#[derive(Debug, Clone, Deserialize, JsonSchema)]
639#[serde(untagged)]
640enum SslVal {
641 Bool(bool),
642 Str(String),
643}
644
645impl SslVal {
646 fn enabled(&self) -> bool {
647 match self {
648 Self::Bool(b) => *b,
649 Self::Str(s) => !matches!(s.as_str(), "disable" | "false" | "no" | "off"),
650 }
651 }
652}
653
/// Configuration of one database.
///
/// This is either the top level of a single-database config file or one
/// `[databases.<name>]` table. TOML keys are camelCase (e.g. `dbCredentials`,
/// `tablesFilter`).
#[derive(Debug, Clone, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct DatabaseConfig {
    /// Database dialect (required).
    pub dialect: Dialect,

    /// Schema path(s) or glob(s); defaults to `src/schema.rs`. Relative entries are
    /// prefixed with the config file's directory when loaded.
    #[serde(default)]
    pub schema: Schema,

    /// Migrations output directory; defaults to `./drizzle`. A relative path is joined
    /// onto the config file's directory when loaded.
    #[serde(default = "default_out")]
    pub out: PathBuf,

    /// `breakpoints` flag, `true` when omitted (NOTE(review): consumed outside this
    /// module — confirm its exact effect there).
    #[serde(default = "yes")]
    pub breakpoints: bool,

    /// Optional driver; must be valid for `dialect` (checked when loaded).
    #[serde(default)]
    pub driver: Option<Driver>,

    /// Raw `dbCredentials`; resolved on demand by `DatabaseConfig::credentials`.
    #[serde(default)]
    db_credentials: Option<RawCreds>,

    /// Optional `tablesFilter` (one entry or a list).
    #[serde(default)]
    pub tables_filter: Option<Filter>,

    /// Optional `schemaFilter`; only accepted for PostgreSQL.
    #[serde(default)]
    pub schema_filter: Option<Filter>,

    /// Optional `extensionsFilters`; only accepted for PostgreSQL.
    #[serde(default)]
    pub extensions_filters: Option<Vec<Extension>>,

    /// Optional `[entities]` filters; only accepted for PostgreSQL.
    #[serde(default)]
    pub entities: Option<EntitiesFilter>,

    /// Identifier casing; `effective_casing()` falls back to `Casing::default()`.
    #[serde(default)]
    pub casing: Option<Casing>,

    /// `[introspect]` options; `effective_introspect_casing()` falls back to camel.
    #[serde(default)]
    pub introspect: Option<IntrospectConfig>,

    /// `verbose` flag, `false` when omitted.
    #[serde(default)]
    pub verbose: bool,

    /// `[migrations]` options.
    #[serde(default)]
    pub migrations: Option<MigrationsOpts>,
}
719
/// Serde default for `out`: the `./drizzle` directory.
fn default_out() -> PathBuf {
    "./drizzle".into()
}

/// Serde default helper that always yields `true` (used for `breakpoints`).
const fn yes() -> bool {
    let enabled = true;
    enabled
}
727
728impl DatabaseConfig {
729 fn normalize_paths(&mut self, base_dir: &Path) {
730 if self.out.is_relative() {
733 self.out = base_dir.join(&self.out);
734 }
735
736 let base = base_dir.to_string_lossy().replace('\\', "/");
740 let base = base.trim_end_matches('/').to_string();
741
742 let normalize_one = |p: &str| -> String {
743 let p_trim = p.trim();
744 let is_abs = Path::new(p_trim).is_absolute() || p_trim.starts_with("\\\\");
745 let joined = if is_abs || base.is_empty() || base == "." {
746 p_trim.to_string()
747 } else {
748 format!("{base}/{p_trim}")
749 };
750 joined.replace('\\', "/")
751 };
752
753 match &mut self.schema {
754 Schema::One(p) => *p = normalize_one(p),
755 Schema::Many(v) => {
756 for p in v.iter_mut() {
757 *p = normalize_one(p);
758 }
759 }
760 }
761 }
762
763 fn validate(&self, name: &str) -> Result<(), Error> {
764 if let Some(d) = self.driver
766 && !d.is_valid_for(self.dialect)
767 {
768 return Err(Error::InvalidDriver {
769 driver: d,
770 dialect: self.dialect,
771 });
772 }
773
774 if let Some(ref raw) = self.db_credentials {
776 self.validate_creds(raw, name)?;
777 }
778
779 if self.dialect != Dialect::Postgresql {
781 if self.schema_filter.is_some() {
782 return Err(Error::InvalidConfig(
783 "schemaFilter is only supported for dialect = \"postgresql\"".into(),
784 ));
785 }
786 if self.extensions_filters.is_some() {
787 return Err(Error::InvalidConfig(
788 "extensionsFilters is only supported for dialect = \"postgresql\"".into(),
789 ));
790 }
791 if self.entities.is_some() {
792 return Err(Error::InvalidConfig(
793 "entities filter is only supported for dialect = \"postgresql\"".into(),
794 ));
795 }
796 }
797
798 Ok(())
799 }
800
801 fn validate_creds(&self, raw: &RawCreds, _name: &str) -> Result<(), Error> {
802 let err = |msg: &str| Error::InvalidCredentials(msg.into());
803
804 match (self.dialect, raw) {
807 (Dialect::Postgresql, RawCreds::Host { .. } | RawCreds::Url { .. }) => {}
808 (_, RawCreds::Host { .. }) => {
809 return Err(err(
810 "host-based dbCredentials are only supported for dialect = \"postgresql\"",
811 ));
812 }
813 _ => {}
814 }
815
816 if let RawCreds::D1 { .. } = raw {
820 if self.dialect != Dialect::Sqlite {
821 return Err(err(
822 "D1 dbCredentials (accountId/databaseId/token) require dialect = \"sqlite\"",
823 ));
824 }
825 if self.driver != Some(Driver::D1Http) {
826 return Err(err(
827 "D1 dbCredentials (accountId/databaseId/token) require driver = \"d1-http\"",
828 ));
829 }
830 }
831
832 if self.driver == Some(Driver::D1Http) && !matches!(raw, RawCreds::D1 { .. }) {
836 return Err(err(
837 "driver = \"d1-http\" requires dbCredentials with accountId, databaseId, and token",
838 ));
839 }
840
841 if let RawCreds::AwsDataApi { .. } = raw {
845 if self.dialect != Dialect::Postgresql {
846 return Err(err(
847 "AWS Data API dbCredentials (database/secretArn/resourceArn) require dialect = \"postgresql\"",
848 ));
849 }
850 if self.driver != Some(Driver::AwsDataApi) {
851 return Err(err(
852 "AWS Data API dbCredentials (database/secretArn/resourceArn) require driver = \"aws-data-api\"",
853 ));
854 }
855 }
856
857 if self.driver == Some(Driver::AwsDataApi) && !matches!(raw, RawCreds::AwsDataApi { .. }) {
859 return Err(err(
860 "driver = \"aws-data-api\" requires dbCredentials with database, secretArn, and resourceArn",
861 ));
862 }
863
864 match (self.dialect, raw) {
866 (
867 Dialect::Sqlite,
868 RawCreds::Url {
869 auth_token: Some(_),
870 ..
871 },
872 ) => Err(err(
873 "SQLite doesn't support authToken (use dialect = \"turso\")",
874 )),
875 (
876 Dialect::Sqlite,
877 RawCreds::Url {
878 url: EnvOr::Value(url),
879 ..
880 },
881 ) if url.starts_with("libsql://") => Err(err(
882 "libsql:// URLs require dialect = \"turso\" (for local SQLite files, use ./path.db)",
883 )),
884 (
885 Dialect::Sqlite,
886 RawCreds::Url {
887 url: EnvOr::Value(url),
888 ..
889 },
890 ) if url.starts_with("http://")
891 || url.starts_with("https://")
892 || url.starts_with("postgres://")
893 || url.starts_with("postgresql://") =>
894 {
895 Err(err(
896 "SQLite dbCredentials.url must be a local file path (not an http(s)/postgres URL)",
897 ))
898 }
899 (
900 Dialect::Turso,
901 RawCreds::Url {
902 url: EnvOr::Value(url),
903 ..
904 },
905 ) if !url.starts_with("libsql://") && !url.starts_with("http") => {
906 Err(err("Turso URL must start with libsql:// or http(s)://"))
907 }
908 (
909 Dialect::Postgresql,
910 RawCreds::Url {
911 url: EnvOr::Value(url),
912 ..
913 },
914 ) if !url.starts_with("postgres") => {
915 Err(err("PostgreSQL URL must start with postgres://"))
916 }
917 _ => Ok(()),
918 }
919 }
920
921 pub fn credentials(&self) -> Result<Option<Credentials>, Error> {
929 let Some(raw) = self.db_credentials.as_ref() else {
930 return Ok(None);
931 };
932
933 let resolve_opt = |opt: &Option<EnvOr>| -> Result<Option<Box<str>>, Error> {
935 match opt.as_ref() {
936 None => Ok(None),
937 Some(e) => Ok(Some(e.resolve()?.into_boxed_str())),
938 }
939 };
940
941 let creds = match (self.dialect, raw) {
942 (
946 Dialect::Sqlite,
947 RawCreds::D1 {
948 account_id,
949 database_id,
950 token,
951 },
952 ) => Credentials::D1 {
953 account_id: account_id.resolve()?.into_boxed_str(),
954 database_id: database_id.resolve()?.into_boxed_str(),
955 token: token.resolve()?.into_boxed_str(),
956 },
957 (
960 Dialect::Postgresql,
961 RawCreds::AwsDataApi {
962 database,
963 secret_arn,
964 resource_arn,
965 },
966 ) => Credentials::AwsDataApi {
967 database: database.resolve()?.into_boxed_str(),
968 secret_arn: secret_arn.resolve()?.into_boxed_str(),
969 resource_arn: resource_arn.resolve()?.into_boxed_str(),
970 },
971 (Dialect::Sqlite, RawCreds::Url { url, .. }) => Credentials::Sqlite {
973 path: url.resolve()?.into_boxed_str(),
974 },
975 (Dialect::Turso, RawCreds::Url { url, auth_token }) => Credentials::Turso {
977 url: url.resolve()?.into_boxed_str(),
978 auth_token: resolve_opt(auth_token)?,
979 },
980 (Dialect::Postgresql, RawCreds::Url { url, .. }) => {
982 Credentials::Postgres(PostgresCreds::Url(url.resolve()?.into_boxed_str()))
983 }
984 (
986 Dialect::Postgresql,
987 RawCreds::Host {
988 host,
989 port,
990 user,
991 password,
992 database,
993 ssl,
994 },
995 ) => Credentials::Postgres(PostgresCreds::Host {
996 host: host.resolve()?.into_boxed_str(),
997 port: port.unwrap_or(5432),
998 user: resolve_opt(user)?,
999 password: resolve_opt(password)?,
1000 database: database.resolve()?.into_boxed_str(),
1001 ssl: ssl.as_ref().is_some_and(SslVal::enabled),
1002 }),
1003 _ => return Ok(None),
1004 };
1005
1006 Ok(Some(creds))
1007 }
1008
1009 #[inline]
1011 #[must_use]
1012 pub fn migrations_dir(&self) -> &Path {
1013 &self.out
1014 }
1015
1016 #[inline]
1018 #[must_use]
1019 pub fn meta_dir(&self) -> PathBuf {
1020 self.out.join("meta")
1021 }
1022
1023 #[inline]
1025 #[must_use]
1026 pub fn journal_path(&self) -> PathBuf {
1027 self.meta_dir().join("_journal.json")
1028 }
1029
1030 #[must_use]
1032 pub fn schema_display(&self) -> String {
1033 match &self.schema {
1034 Schema::One(s) => s.clone(),
1035 Schema::Many(v) => v.join(", "),
1036 }
1037 }
1038
1039 pub fn schema_files(&self) -> Result<Vec<PathBuf>, Error> {
1046 let mut files = Vec::new();
1047
1048 for pattern in self.schema.iter() {
1049 let pat = pattern.trim();
1050
1051 let is_glob = pat.contains('*') || pat.contains('?') || pat.contains('[');
1053 if !is_glob {
1054 let p = PathBuf::from(pat);
1055 if p.exists() {
1056 files.push(p);
1057 continue;
1058 }
1059 }
1060
1061 let pat_norm = pat.replace('\\', "/");
1063 match glob::glob(&pat_norm) {
1064 Ok(paths) => {
1065 let matched: Vec<_> = paths.filter_map(Result::ok).collect();
1066 if matched.is_empty() && !is_glob {
1067 let p = PathBuf::from(&pat_norm);
1068 if p.exists() {
1069 files.push(p);
1070 }
1071 } else {
1072 files.extend(matched);
1073 }
1074 }
1075 Err(e) => return Err(Error::Glob(pat.into(), e)),
1076 }
1077 }
1078
1079 files.retain(|p| p.is_file());
1081 files.sort();
1082 files.dedup();
1083
1084 if files.is_empty() {
1085 return Err(Error::NoSchemaFiles(self.schema_display()));
1086 }
1087
1088 Ok(files)
1089 }
1090
1091 #[inline]
1093 #[must_use]
1094 pub fn effective_casing(&self) -> Casing {
1095 self.casing.unwrap_or_default()
1096 }
1097
1098 #[inline]
1100 #[must_use]
1101 pub fn effective_introspect_casing(&self) -> IntrospectCasing {
1102 self.introspect
1103 .as_ref()
1104 .map(|i| i.casing)
1105 .unwrap_or_default()
1106 }
1107
1108 #[inline]
1110 #[must_use]
1111 pub fn effective_entities(&self) -> EntitiesFilter {
1112 self.entities.clone().unwrap_or_default()
1113 }
1114
1115 #[must_use]
1117 pub fn should_include_role(&self, role_name: &str) -> bool {
1118 self.entities
1119 .as_ref()
1120 .is_some_and(|e| e.roles.should_include(role_name))
1121 }
1122
1123 #[must_use]
1125 pub fn roles_enabled(&self) -> bool {
1126 self.entities.as_ref().is_some_and(|e| e.roles.is_enabled())
1127 }
1128
1129 #[must_use]
1131 pub fn extensions(&self) -> &[Extension] {
1132 self.extensions_filters.as_deref().unwrap_or(&[])
1133 }
1134
1135 #[must_use]
1137 pub fn has_extension(&self, ext: Extension) -> bool {
1138 self.extensions_filters
1139 .as_ref()
1140 .is_some_and(|v| v.contains(&ext))
1141 }
1142
1143 #[must_use]
1145 pub fn migrations_table(&self) -> &str {
1146 self.migrations
1147 .as_ref()
1148 .and_then(|m| m.table.as_deref())
1149 .unwrap_or("__drizzle_migrations")
1150 }
1151
1152 #[must_use]
1154 pub fn migrations_schema(&self) -> &str {
1155 self.migrations
1156 .as_ref()
1157 .and_then(|m| m.schema.as_deref())
1158 .unwrap_or("drizzle")
1159 }
1160
1161 #[must_use]
1170 pub fn bundle_enabled(&self) -> bool {
1171 if let Some(explicit) = self.migrations.as_ref().and_then(|m| m.bundle) {
1172 return explicit;
1173 }
1174 matches!(self.driver, Some(Driver::DurableSqlite))
1175 }
1176}
1177
/// On-disk shape of a multi-database config: one `[databases.<name>]` table per database.
#[derive(Debug, Clone, Deserialize)]
struct MultiDbConfig {
    databases: HashMap<String, DatabaseConfig>,
}

/// A loaded drizzle configuration holding one or more named databases.
///
/// A single-database file (top-level `dialect`, `dbCredentials`, ...) is stored under
/// [`DEFAULT_DB`]. A multi-database file keeps its `[databases.*]` names.
#[derive(Debug, Clone)]
pub struct Config {
    /// Database configurations keyed by name.
    databases: HashMap<String, DatabaseConfig>,
    /// `true` when loaded from a single-database file.
    is_single: bool,
}

/// Key under which a single-database config is stored.
pub const DEFAULT_DB: &str = "default";
1221
1222impl Config {
1223 pub fn load() -> Result<Self, Error> {
1230 Self::load_from(Path::new(CONFIG_FILE))
1231 }
1232
1233 pub fn load_from(path: &Path) -> Result<Self, Error> {
1241 let content = std::fs::read_to_string(path).map_err(|e| {
1242 if e.kind() == std::io::ErrorKind::NotFound {
1243 Error::NotFound(path.into())
1244 } else {
1245 Error::Io(path.into(), e)
1246 }
1247 })?;
1248
1249 Self::load_from_str(&content, path)
1250 }
1251
1252 fn load_from_str(content: &str, path: &Path) -> Result<Self, Error> {
1254 let base_dir = path.parent().unwrap_or_else(|| Path::new("."));
1255
1256 if let Ok(multi) = toml::from_str::<MultiDbConfig>(content)
1258 && !multi.databases.is_empty()
1259 {
1260 let mut config = Self {
1261 databases: multi.databases,
1262 is_single: false,
1263 };
1264 for db in config.databases.values_mut() {
1265 db.normalize_paths(base_dir);
1266 }
1267 config.validate()?;
1268 return Ok(config);
1269 }
1270
1271 let db_config: DatabaseConfig =
1273 toml::from_str(content).map_err(|e| Error::Parse(path.into(), e))?;
1274
1275 let mut databases = HashMap::new();
1276 databases.insert(DEFAULT_DB.to_string(), db_config);
1277
1278 let mut config = Self {
1279 databases,
1280 is_single: true,
1281 };
1282 for db in config.databases.values_mut() {
1283 db.normalize_paths(base_dir);
1284 }
1285 config.validate()?;
1286 Ok(config)
1287 }
1288
1289 fn validate(&self) -> Result<(), Error> {
1290 for (name, db) in &self.databases {
1291 db.validate(name)?;
1292 }
1293 Ok(())
1294 }
1295
1296 #[must_use]
1298 pub const fn is_single_database(&self) -> bool {
1299 self.is_single
1300 }
1301
1302 pub fn database_names(&self) -> impl Iterator<Item = &str> {
1304 self.databases.keys().map(String::as_str)
1305 }
1306
1307 pub fn database(&self, name: Option<&str>) -> Result<&DatabaseConfig, Error> {
1319 name.map_or_else(
1320 || {
1321 if self.is_single {
1323 self.databases.get(DEFAULT_DB).ok_or(Error::NoDatabases)
1324 } else if self.databases.len() == 1 {
1325 self.databases.values().next().ok_or(Error::NoDatabases)
1326 } else {
1327 Err(Error::DatabaseRequired(
1328 self.databases.keys().cloned().collect(),
1329 ))
1330 }
1331 },
1332 |name| {
1333 if self.is_single {
1334 self.databases.get(DEFAULT_DB).ok_or(Error::NoDatabases)
1336 } else {
1337 self.databases
1338 .get(name)
1339 .ok_or_else(|| Error::DatabaseNotFound(name.to_string()))
1340 }
1341 },
1342 )
1343 }
1344
1345 pub fn default_database(&self) -> Result<&DatabaseConfig, Error> {
1351 self.database(None)
1352 }
1353
1354 #[must_use]
1360 pub fn dialect(&self) -> Dialect {
1361 self.default_database()
1362 .map(|d| d.dialect)
1363 .unwrap_or_default()
1364 }
1365
1366 pub fn credentials(&self) -> Result<Option<Credentials>, Error> {
1373 self.default_database()?.credentials()
1374 }
1375
1376 #[must_use]
1378 pub fn migrations_dir(&self) -> &Path {
1379 self.default_database()
1380 .map_or_else(|_| Path::new("./drizzle"), |d| d.migrations_dir())
1381 }
1382
1383 #[must_use]
1385 pub fn journal_path(&self) -> PathBuf {
1386 self.default_database().map_or_else(
1387 |_| PathBuf::from("./drizzle/meta/_journal.json"),
1388 DatabaseConfig::journal_path,
1389 )
1390 }
1391
1392 #[must_use]
1394 pub fn schema_display(&self) -> String {
1395 self.default_database()
1396 .map_or_else(|_| "src/schema.rs".into(), DatabaseConfig::schema_display)
1397 }
1398
1399 pub fn schema_files(&self) -> Result<Vec<PathBuf>, Error> {
1406 self.default_database()?.schema_files()
1407 }
1408
1409 #[must_use]
1411 pub fn base_dialect(&self) -> drizzle_types::Dialect {
1412 self.dialect().to_base()
1413 }
1414}
1415
/// Errors produced while loading, validating or resolving configuration.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// The config file does not exist (from `Config::load_from`).
    #[error("config not found: {}", .0.display())]
    NotFound(PathBuf),

    /// Any other I/O failure while reading the config file.
    #[error("failed to read {}: {}", .0.display(), .1)]
    Io(PathBuf, #[source] std::io::Error),

    /// TOML syntax or shape error in the config file.
    #[error("failed to parse {}: {}", .0.display(), .1)]
    Parse(PathBuf, #[source] toml::de::Error),

    /// `driver` is not valid for the configured `dialect`.
    #[error("driver '{driver}' invalid for {dialect} dialect")]
    InvalidDriver { driver: Driver, dialect: Dialect },

    /// `dbCredentials` does not match the dialect/driver.
    #[error("invalid credentials: {0}")]
    InvalidCredentials(String),

    /// Other configuration inconsistency (e.g. a PostgreSQL-only option on SQLite).
    #[error("invalid config: {0}")]
    InvalidConfig(String),

    /// A schema entry is not a valid glob pattern.
    #[error("invalid glob '{0}': {1}")]
    Glob(String, #[source] glob::PatternError),

    /// No schema files matched; carries the schema display string.
    #[error("no schema files found: {0}")]
    NoSchemaFiles(String),

    /// A referenced environment variable is not set.
    #[error("environment variable '{0}' not found")]
    EnvNotFound(String),

    /// A referenced environment variable has an unusable value.
    #[error("environment variable '{0}' invalid: {1}")]
    EnvInvalid(String, String),

    /// No database could be selected.
    #[error("no databases configured")]
    NoDatabases,

    /// The requested database name is not configured.
    #[error("database '{0}' not found")]
    DatabaseNotFound(String),

    /// Several databases exist and none was named; carries the available names.
    #[error("multiple databases configured, use --db to specify: {}", .0.join(", "))]
    DatabaseRequired(Vec<String>),
}
1461
1462impl From<EnvOrError> for Error {
1463 fn from(err: EnvOrError) -> Self {
1464 match err {
1465 EnvOrError::NotPresent(var) => Self::EnvNotFound(var),
1466 EnvOrError::NotUnicode(var) => Self::EnvInvalid(var, "contains invalid unicode".into()),
1467 }
1468 }
1469}
1470
1471#[cfg(test)]
1476mod tests {
1477 use super::*;
1478 use std::fs;
1479 use tempfile::TempDir;
1480
1481 #[test]
1482 fn sqlite() {
1483 let cfg = Config::load_from_str(
1484 r#"
1485 dialect = "sqlite"
1486 [dbCredentials]
1487 url = "./dev.db"
1488 "#,
1489 Path::new("test.toml"),
1490 )
1491 .unwrap();
1492 assert!(cfg.is_single_database());
1493 assert!(matches!(
1494 cfg.credentials().unwrap(),
1495 Some(Credentials::Sqlite { .. })
1496 ));
1497 }
1498
1499 #[test]
1500 fn postgres_url() {
1501 let cfg = Config::load_from_str(
1502 r#"
1503 dialect = "postgresql"
1504 [dbCredentials]
1505 url = "postgres://localhost/db"
1506 "#,
1507 Path::new("test.toml"),
1508 )
1509 .unwrap();
1510 assert!(matches!(
1511 cfg.credentials().unwrap(),
1512 Some(Credentials::Postgres(PostgresCreds::Url(_)))
1513 ));
1514 }
1515
1516 #[test]
1517 fn multi_database() {
1518 let cfg = Config::load_from_str(
1519 r#"
1520 [databases.dev]
1521 dialect = "sqlite"
1522 out = "./drizzle/sqlite"
1523 [databases.dev.dbCredentials]
1524 url = "./dev.db"
1525
1526 [databases.prod]
1527 dialect = "postgresql"
1528 out = "./drizzle/postgres"
1529 [databases.prod.dbCredentials]
1530 url = "postgres://localhost/db"
1531 "#,
1532 Path::new("test.toml"),
1533 )
1534 .unwrap();
1535
1536 assert!(!cfg.is_single_database());
1537 let names: Vec<_> = cfg.database_names().collect();
1538 assert!(names.contains(&"dev"));
1539 assert!(names.contains(&"prod"));
1540
1541 let dev = cfg.database(Some("dev")).unwrap();
1542 assert_eq!(dev.dialect, Dialect::Sqlite);
1543
1544 let prod = cfg.database(Some("prod")).unwrap();
1545 assert_eq!(prod.dialect, Dialect::Postgresql);
1546 }
1547
1548 #[test]
1549 fn multi_database_requires_selection() {
1550 let cfg = Config::load_from_str(
1551 r#"
1552 [databases.a]
1553 dialect = "sqlite"
1554 [databases.b]
1555 dialect = "postgresql"
1556 "#,
1557 Path::new("test.toml"),
1558 )
1559 .unwrap();
1560
1561 assert!(cfg.database(None).is_err());
1563 }
1564
1565 #[test]
1566 fn env_var_syntax() {
1567 let cfg = Config::load_from_str(
1568 r#"
1569 dialect = "postgresql"
1570 [dbCredentials]
1571 url = { env = "DATABASE_URL" }
1572 "#,
1573 Path::new("test.toml"),
1574 )
1575 .unwrap();
1576 assert!(cfg.is_single_database());
1577 }
1578
1579 #[test]
1580 fn casing_options() {
1581 let cfg = Config::load_from_str(
1582 r#"
1583 dialect = "postgresql"
1584 casing = "snake_case"
1585 [dbCredentials]
1586 url = "postgres://localhost/db"
1587 "#,
1588 Path::new("test.toml"),
1589 )
1590 .unwrap();
1591 let db = cfg.default_database().unwrap();
1592 assert_eq!(db.effective_casing(), Casing::SnakeCase);
1593
1594 let cfg2 = Config::load_from_str(
1596 r#"
1597 dialect = "postgresql"
1598 [dbCredentials]
1599 url = "postgres://localhost/db"
1600 "#,
1601 Path::new("test.toml"),
1602 )
1603 .unwrap();
1604 let db2 = cfg2.default_database().unwrap();
1605 assert_eq!(db2.effective_casing(), Casing::CamelCase);
1606 }
1607
1608 #[test]
1609 fn introspect_casing() {
1610 let cfg = Config::load_from_str(
1611 r#"
1612 dialect = "postgresql"
1613 [introspect]
1614 casing = "preserve"
1615 [dbCredentials]
1616 url = "postgres://localhost/db"
1617 "#,
1618 Path::new("test.toml"),
1619 )
1620 .unwrap();
1621 let db = cfg.default_database().unwrap();
1622 assert_eq!(db.effective_introspect_casing(), IntrospectCasing::Preserve);
1623 }
1624
/// Covers both accepted shapes of `entities.roles`:
/// - `roles = true` enables roles and includes every role name.
/// - a `[entities.roles]` table with `provider = "supabase"` also enables
///   roles, but excludes `anon` (presumably one of that provider's built-in
///   roles) while still including custom role names.
#[test]
fn entities_roles_filter() {
    let cfg = Config::load_from_str(
        r#"
        dialect = "postgresql"
        [entities]
        roles = true
        [dbCredentials]
        url = "postgres://localhost/db"
        "#,
        Path::new("test.toml"),
    )
    .unwrap();
    let db = cfg.default_database().unwrap();
    assert!(db.roles_enabled());
    assert!(db.should_include_role("my_role"));

    // Table form: a provider enables roles implicitly and filters its roles.
    let cfg2 = Config::load_from_str(
        r#"
        dialect = "postgresql"
        [entities.roles]
        provider = "supabase"
        [dbCredentials]
        url = "postgres://localhost/db"
        "#,
        Path::new("test.toml"),
    )
    .unwrap();
    let db2 = cfg2.default_database().unwrap();
    assert!(db2.roles_enabled());
    // One assertion per line so a failure points at the exact expectation.
    assert!(!db2.should_include_role("anon"));
    assert!(db2.should_include_role("my_custom_role"));
}
1660
/// Listing an extension in `extensionsFilters` makes `has_extension` report it.
#[test]
fn extensions_filter() {
    let toml = r#"
        dialect = "postgresql"
        extensionsFilters = ["postgis"]
        [dbCredentials]
        url = "postgres://localhost/db"
        "#;
    let loaded = Config::load_from_str(toml, Path::new("test.toml")).unwrap();
    let database = loaded.default_database().unwrap();
    assert!(database.has_extension(Extension::Postgis));
}
1676
/// `schemaFilter` and `extensionsFilters` are PostgreSQL-only; a sqlite config
/// using either must fail to load with a dialect-specific message.
#[test]
fn rejects_postgres_only_filters_for_sqlite() {
    // (config text, expect_err reason, exact error message)
    let cases = [
        (
            r#"
            dialect = "sqlite"
            schemaFilter = ["public"]
            [dbCredentials]
            url = "./dev.db"
            "#,
            "sqlite should reject schemaFilter",
            "invalid config: schemaFilter is only supported for dialect = \"postgresql\"",
        ),
        (
            r#"
            dialect = "sqlite"
            extensionsFilters = ["postgis"]
            [dbCredentials]
            url = "./dev.db"
            "#,
            "sqlite should reject extensionsFilters",
            "invalid config: extensionsFilters is only supported for dialect = \"postgresql\"",
        ),
    ];

    for (toml, reason, expected) in cases {
        let err = Config::load_from_str(toml, Path::new("test.toml")).expect_err(reason);
        assert_eq!(err.to_string(), expected);
    }
}
1709
/// An `[entities]` section is PostgreSQL-only and must be rejected for turso.
#[test]
fn rejects_entities_filter_for_turso() {
    let toml = r#"
        dialect = "turso"
        [entities]
        roles = true
        [dbCredentials]
        url = "libsql://example.turso.io"
        "#;
    let message = Config::load_from_str(toml, Path::new("test.toml"))
        .expect_err("turso should reject entities filter")
        .to_string();
    assert_eq!(
        message,
        "invalid config: entities filter is only supported for dialect = \"postgresql\""
    );
}
1728
/// `[migrations] table`/`schema` override the migration bookkeeping location.
/// Without the section, the defaults are table `__drizzle_migrations` in
/// schema `drizzle`.
#[test]
fn migrations_config() {
    let custom_toml = r#"
        dialect = "postgresql"
        [migrations]
        table = "custom_migrations"
        schema = "custom_schema"
        [dbCredentials]
        url = "postgres://localhost/db"
        "#;
    let default_toml = r#"
        dialect = "postgresql"
        [dbCredentials]
        url = "postgres://localhost/db"
        "#;

    let custom = Config::load_from_str(custom_toml, Path::new("test.toml")).unwrap();
    let custom_db = custom.default_database().unwrap();
    assert_eq!(custom_db.migrations_table(), "custom_migrations");
    assert_eq!(custom_db.migrations_schema(), "custom_schema");

    let defaults = Config::load_from_str(default_toml, Path::new("test.toml")).unwrap();
    let default_db = defaults.default_database().unwrap();
    assert_eq!(default_db.migrations_table(), "__drizzle_migrations");
    assert_eq!(default_db.migrations_schema(), "drizzle");
}
1761
/// Relative `schema` and `out` paths resolve against the directory that holds
/// the config file, not the process working directory.
#[test]
fn resolves_paths_relative_to_config_dir() {
    let tmp = TempDir::new().unwrap();
    let config_dir = tmp.path().join("cfg");
    fs::create_dir_all(&config_dir).unwrap();

    // A real schema file on disk so `schema_files` has something to find.
    let schema_file = config_dir.join("schema.rs");
    fs::write(&schema_file, "#[allow(dead_code)]\npub struct X;").unwrap();

    let toml = r#"
        dialect = "sqlite"
        schema = "schema.rs"
        out = "./drizzle"
        [dbCredentials]
        url = "./dev.db"
        "#;
    let config_file = config_dir.join("drizzle.config.toml");
    let cfg = Config::load_from_str(toml, &config_file).unwrap();

    let database = cfg.default_database().unwrap();
    let expected_out = config_dir.join("./drizzle");
    assert_eq!(database.migrations_dir(), expected_out.as_path());

    let discovered = database.schema_files().unwrap();
    assert_eq!(discovered.len(), 1);
    assert_eq!(discovered[0], schema_file);
}
1792
/// Host/database style credentials are only valid for PostgreSQL; sqlite must
/// reject them with an "invalid credentials" error.
#[test]
fn rejects_host_credentials_for_sqlite() {
    let toml = r#"
        dialect = "sqlite"
        [dbCredentials]
        host = "localhost"
        database = "db"
        "#;
    let message = Config::load_from_str(toml, Path::new("test.toml"))
        .unwrap_err()
        .to_string();

    assert_eq!(
        message,
        "invalid credentials: host-based dbCredentials are only supported for dialect = \"postgresql\""
    );
}
1811
/// With `driver = "d1-http"`, literal `accountId`/`databaseId`/`token` values
/// are parsed into `Credentials::D1` unchanged.
#[test]
fn d1_http_credentials_parse() {
    let toml = r#"
        dialect = "sqlite"
        driver = "d1-http"
        [dbCredentials]
        accountId = "acc_abc"
        databaseId = "db_xyz"
        token = "tok_123"
        "#;
    let cfg = Config::load_from_str(toml, Path::new("test.toml")).unwrap();

    let database = cfg.default_database().unwrap();
    assert_eq!(database.driver, Some(Driver::D1Http));

    let resolved = database.credentials().unwrap();
    match resolved {
        Some(Credentials::D1 {
            account_id,
            database_id,
            token,
        }) => {
            assert_eq!(&*account_id, "acc_abc");
            assert_eq!(&*database_id, "db_xyz");
            assert_eq!(&*token, "tok_123");
        }
        other => panic!("expected Credentials::D1, got {other:?}"),
    }
}
1846
/// D1 credential fields written as `{ env = "..." }` are resolved from the
/// process environment.
#[test]
fn d1_http_credentials_resolve_from_env() {
    // SAFETY: mutating the environment is only unsound when another thread
    // reads it at the same time through APIs that bypass std's own env
    // synchronization (e.g. a C library calling `getenv`). This assumes no
    // test in this binary does that. The variable names are unique to this
    // test, so no other test depends on their values.
    unsafe {
        std::env::set_var("TEST_D1_ACCT", "env_acct");
        std::env::set_var("TEST_D1_DB", "env_db");
        std::env::set_var("TEST_D1_TOKEN", "env_token");
    }
    let cfg = Config::load_from_str(
        r#"
        dialect = "sqlite"
        driver = "d1-http"
        [dbCredentials]
        accountId = { env = "TEST_D1_ACCT" }
        databaseId = { env = "TEST_D1_DB" }
        token = { env = "TEST_D1_TOKEN" }
        "#,
        Path::new("test.toml"),
    )
    .unwrap();

    match cfg.default_database().unwrap().credentials().unwrap() {
        Some(Credentials::D1 {
            account_id,
            database_id,
            token,
        }) => {
            assert_eq!(&*account_id, "env_acct");
            assert_eq!(&*database_id, "env_db");
            assert_eq!(&*token, "env_token");
        }
        other => panic!("expected Credentials::D1, got {other:?}"),
    }

    // Remove the variables so they do not leak into the rest of the test run.
    // SAFETY: same reasoning as for the `set_var` calls at the top of this test.
    unsafe {
        std::env::remove_var("TEST_D1_ACCT");
        std::env::remove_var("TEST_D1_DB");
        std::env::remove_var("TEST_D1_TOKEN");
    }
}
1881
/// D1-shaped credentials under a PostgreSQL dialect must be rejected with a
/// D1-specific error.
#[test]
fn d1_credentials_require_sqlite_dialect() {
    let toml = r#"
        dialect = "postgresql"
        [dbCredentials]
        accountId = "acc"
        databaseId = "db"
        token = "tok"
        "#;
    let message = Config::load_from_str(toml, Path::new("test.toml"))
        .unwrap_err()
        .to_string();
    assert!(
        message.contains("D1 dbCredentials"),
        "expected D1-specific error, got: {message}"
    );
}
1900
/// D1-shaped credentials combined with a non-D1 sqlite driver must produce an
/// error that points the user at `driver = "d1-http"`.
#[test]
fn d1_credentials_require_d1_http_driver() {
    let toml = r#"
        dialect = "sqlite"
        driver = "rusqlite"
        [dbCredentials]
        accountId = "acc"
        databaseId = "db"
        token = "tok"
        "#;
    let message = Config::load_from_str(toml, Path::new("test.toml"))
        .unwrap_err()
        .to_string();
    assert!(
        message.contains("driver = \"d1-http\""),
        "expected d1-http driver error, got: {message}"
    );
}
1921
/// `driver = "d1-http"` with a plain `url` credential must fail, naming the
/// three D1 fields it expects.
#[test]
fn d1_http_driver_requires_d1_credentials() {
    let toml = r#"
        dialect = "sqlite"
        driver = "d1-http"
        [dbCredentials]
        url = "./dev.db"
        "#;
    let message = Config::load_from_str(toml, Path::new("test.toml"))
        .unwrap_err()
        .to_string();
    assert!(
        message.contains("accountId, databaseId, and token"),
        "expected d1-http creds-shape error, got: {message}"
    );
}
1940
/// `driver = "durable-sqlite"` loads without any `[dbCredentials]`, yields no
/// credentials, and turns migration bundling on by default.
#[test]
fn durable_sqlite_no_credentials_ok() {
    let toml = r#"
        dialect = "sqlite"
        driver = "durable-sqlite"
        "#;
    let cfg = Config::load_from_str(toml, Path::new("test.toml")).unwrap();

    let database = cfg.default_database().unwrap();
    assert_eq!(database.driver, Some(Driver::DurableSqlite));

    let resolved = database.credentials().unwrap();
    assert!(resolved.is_none());

    assert!(
        database.bundle_enabled(),
        "durable-sqlite should auto-enable bundle"
    );
}
1963
/// An explicit `[migrations] bundle = false` wins over durable-sqlite's
/// default of bundling.
#[test]
fn durable_sqlite_explicit_bundle_false_respected() {
    let toml = r#"
        dialect = "sqlite"
        driver = "durable-sqlite"
        [migrations]
        bundle = false
        "#;
    let cfg = Config::load_from_str(toml, Path::new("test.toml")).unwrap();
    let bundled = cfg.default_database().unwrap().bundle_enabled();
    assert!(!bundled);
}
1979
/// The durable-sqlite driver must be refused for a PostgreSQL dialect.
#[test]
fn durable_sqlite_rejects_non_sqlite_dialect() {
    let toml = r#"
        dialect = "postgresql"
        driver = "durable-sqlite"
        [dbCredentials]
        url = "postgres://localhost/db"
        "#;
    let message = Config::load_from_str(toml, Path::new("test.toml"))
        .unwrap_err()
        .to_string();
    assert!(
        message.contains("invalid for postgresql"),
        "expected dialect/driver mismatch error, got: {message}"
    );
}
1997
/// The sqlite dialect accepts rusqlite plus both Cloudflare drivers, and the
/// Cloudflare drivers are refused for PostgreSQL and Turso.
#[test]
fn driver_valid_for_sqlite_includes_cloudflare() {
    let sqlite_drivers = Driver::valid_for(Dialect::Sqlite);
    for expected in [Driver::Rusqlite, Driver::D1Http, Driver::DurableSqlite] {
        assert!(sqlite_drivers.contains(&expected));
    }

    for cloudflare in [Driver::D1Http, Driver::DurableSqlite] {
        for foreign in [Dialect::Postgresql, Dialect::Turso] {
            assert!(!cloudflare.is_valid_for(foreign));
        }
    }
}
2010
/// Only durable-sqlite is a codegen-only driver; the others are not.
#[test]
fn driver_is_codegen_only_flag() {
    assert!(Driver::DurableSqlite.is_codegen_only());
    for runtime_driver in [Driver::D1Http, Driver::Rusqlite, Driver::AwsDataApi] {
        assert!(!runtime_driver.is_codegen_only());
    }
}
2018
/// With `driver = "aws-data-api"`, literal `database`/`secretArn`/`resourceArn`
/// values are parsed into `Credentials::AwsDataApi` unchanged.
#[test]
fn aws_data_api_credentials_parse() {
    let cfg = Config::load_from_str(
        r#"
        dialect = "postgresql"
        driver = "aws-data-api"
        [dbCredentials]
        database = "mydb"
        secretArn = "arn:aws:secretsmanager:us-east-1:123:secret:db-xyz"
        resourceArn = "arn:aws:rds:us-east-1:123:cluster:my-aurora"
        "#,
        Path::new("test.toml"),
    )
    .unwrap();

    let db = cfg.default_database().unwrap();
    assert_eq!(db.driver, Some(Driver::AwsDataApi));
    match db.credentials().unwrap() {
        Some(Credentials::AwsDataApi {
            database,
            secret_arn,
            resource_arn,
        }) => {
            // Pin exact values (as the D1 and env-resolution tests do) rather
            // than a prefix, so truncation or mangling of the ARN is caught.
            assert_eq!(&*database, "mydb");
            assert_eq!(
                &*secret_arn,
                "arn:aws:secretsmanager:us-east-1:123:secret:db-xyz"
            );
            assert_eq!(
                &*resource_arn,
                "arn:aws:rds:us-east-1:123:cluster:my-aurora"
            );
        }
        other => panic!("expected Credentials::AwsDataApi, got {other:?}"),
    }
}
2053
/// AWS Data API credential fields written as `{ env = "..." }` are resolved
/// from the process environment.
#[test]
fn aws_data_api_credentials_resolve_from_env() {
    // SAFETY: mutating the environment is only unsound when another thread
    // reads it at the same time through APIs that bypass std's own env
    // synchronization (e.g. a C library calling `getenv`). This assumes no
    // test in this binary does that. The variable names are unique to this
    // test, so no other test depends on their values.
    unsafe {
        std::env::set_var("TEST_AWS_DB", "envdb");
        std::env::set_var("TEST_AWS_SECRET", "arn:env:secret");
        std::env::set_var("TEST_AWS_RESOURCE", "arn:env:resource");
    }
    let cfg = Config::load_from_str(
        r#"
        dialect = "postgresql"
        driver = "aws-data-api"
        [dbCredentials]
        database = { env = "TEST_AWS_DB" }
        secretArn = { env = "TEST_AWS_SECRET" }
        resourceArn = { env = "TEST_AWS_RESOURCE" }
        "#,
        Path::new("test.toml"),
    )
    .unwrap();

    match cfg.default_database().unwrap().credentials().unwrap() {
        Some(Credentials::AwsDataApi {
            database,
            secret_arn,
            resource_arn,
        }) => {
            assert_eq!(&*database, "envdb");
            assert_eq!(&*secret_arn, "arn:env:secret");
            assert_eq!(&*resource_arn, "arn:env:resource");
        }
        other => panic!("expected Credentials::AwsDataApi, got {other:?}"),
    }

    // Remove the variables so they do not leak into the rest of the test run.
    // SAFETY: same reasoning as for the `set_var` calls at the top of this test.
    unsafe {
        std::env::remove_var("TEST_AWS_DB");
        std::env::remove_var("TEST_AWS_SECRET");
        std::env::remove_var("TEST_AWS_RESOURCE");
    }
}
2087
/// AWS Data API shaped credentials under a sqlite dialect must be rejected
/// with an AWS-specific error.
#[test]
fn aws_data_api_requires_postgres_dialect() {
    let toml = r#"
        dialect = "sqlite"
        [dbCredentials]
        database = "mydb"
        secretArn = "arn:aws:secretsmanager:..."
        resourceArn = "arn:aws:rds:..."
        "#;
    let message = Config::load_from_str(toml, Path::new("test.toml"))
        .unwrap_err()
        .to_string();
    assert!(
        message.contains("AWS Data API dbCredentials"),
        "expected AWS-specific error, got: {message}"
    );
}
2106
/// AWS Data API shaped credentials with a different PostgreSQL driver must
/// produce an error that points at `driver = "aws-data-api"`.
#[test]
fn aws_data_api_requires_aws_data_api_driver() {
    let toml = r#"
        dialect = "postgresql"
        driver = "tokio-postgres"
        [dbCredentials]
        database = "mydb"
        secretArn = "arn:aws:secretsmanager:..."
        resourceArn = "arn:aws:rds:..."
        "#;
    let message = Config::load_from_str(toml, Path::new("test.toml"))
        .unwrap_err()
        .to_string();
    assert!(
        message.contains("driver = \"aws-data-api\""),
        "expected aws-data-api driver error, got: {message}"
    );
}
2127
/// `driver = "aws-data-api"` with a plain `url` credential must fail, naming
/// the three fields the driver needs.
#[test]
fn aws_data_api_driver_requires_aws_credentials() {
    let toml = r#"
        dialect = "postgresql"
        driver = "aws-data-api"
        [dbCredentials]
        url = "postgres://localhost/db"
        "#;
    let message = Config::load_from_str(toml, Path::new("test.toml"))
        .unwrap_err()
        .to_string();
    assert!(
        message.contains("database, secretArn, and resourceArn"),
        "expected aws-data-api creds-shape error, got: {message}"
    );
}
2147
/// The aws-data-api driver must be refused for a sqlite dialect, even
/// without any credentials present.
#[test]
fn aws_data_api_rejected_for_non_postgres_dialect() {
    let toml = r#"
        dialect = "sqlite"
        driver = "aws-data-api"
        "#;
    let message = Config::load_from_str(toml, Path::new("test.toml"))
        .unwrap_err()
        .to_string();
    assert!(
        message.contains("invalid for sqlite"),
        "expected dialect/driver mismatch error, got: {message}"
    );
}
2163
/// The PostgreSQL dialect accepts both native drivers plus aws-data-api, and
/// aws-data-api is refused for sqlite and Turso.
#[test]
fn driver_valid_for_postgres_includes_aws_data_api() {
    let postgres_drivers = Driver::valid_for(Dialect::Postgresql);
    for expected in [Driver::PostgresSync, Driver::TokioPostgres, Driver::AwsDataApi] {
        assert!(postgres_drivers.contains(&expected));
    }

    for foreign in [Dialect::Sqlite, Dialect::Turso] {
        assert!(!Driver::AwsDataApi.is_valid_for(foreign));
    }
}
2174
/// On Windows, an absolute `schema` path written with backslash separators
/// must still resolve to the schema file on disk.
#[cfg(windows)]
#[test]
fn schema_files_accept_backslash_paths() {
    let tmp = TempDir::new().unwrap();
    let config_dir = tmp.path().join("cfg");
    fs::create_dir_all(&config_dir).unwrap();

    let schema_file = config_dir.join("src").join("schema.rs");
    fs::create_dir_all(schema_file.parent().unwrap()).unwrap();
    fs::write(&schema_file, "#[allow(dead_code)]\npub struct X;").unwrap();

    // Force backslash separators, then double each backslash so the value
    // survives inside a TOML basic (double-quoted) string.
    let escaped = schema_file
        .to_string_lossy()
        .replace('/', "\\")
        .replace('\\', "\\\\");
    let toml = format!(
        r#"
        dialect = "sqlite"
        schema = "{escaped}"
        "#
    );

    let config_file = config_dir.join("drizzle.config.toml");
    let cfg = Config::load_from_str(&toml, &config_file).unwrap();

    let discovered = cfg.default_database().unwrap().schema_files().unwrap();
    assert_eq!(discovered, vec![schema_file]);
}
2207}