pub use drizzle_types::{Casing, EnvOr, EnvOrError};
use schemars::JsonSchema;
use serde::Deserialize;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
/// File name that [`Config::load`] reads, relative to the current working directory.
pub const CONFIG_FILE: &str = "drizzle.config.toml";
/// Casing mode selected by `[introspect] casing = "..."` in the config file.
///
/// Deserializes from `"camel"` or `"preserve"`. `"camelCase"` is also accepted for `Camel`,
/// matching what the `FromStr` implementation already accepts, so the config file and
/// string parsing agree.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Deserialize, JsonSchema)]
pub enum IntrospectCasing {
    /// `"camel"` (alias `"camelCase"`); the default when no casing is configured.
    #[default]
    #[serde(rename = "camel", alias = "camelCase")]
    Camel,
    /// `"preserve"`.
    #[serde(rename = "preserve")]
    Preserve,
}
impl IntrospectCasing {
    /// Returns the canonical config spelling: `"camel"` or `"preserve"`.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Preserve => "preserve",
            Self::Camel => "camel",
        }
    }
}
impl std::fmt::Display for IntrospectCasing {
    /// Writes the canonical spelling returned by [`IntrospectCasing::as_str`].
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = self.as_str();
        out.write_str(text)
    }
}
impl std::str::FromStr for IntrospectCasing {
    type Err = String;
    /// Parses `"camel"`, `"camelCase"` or `"preserve"`; anything else yields a descriptive error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if matches!(s, "camel" | "camelCase") {
            Ok(Self::Camel)
        } else if s == "preserve" {
            Ok(Self::Preserve)
        } else {
            Err(format!(
                "invalid introspect casing '{s}', expected 'camel' or 'preserve'"
            ))
        }
    }
}
/// The `[introspect]` config section.
#[derive(Debug, Clone, Default, Deserialize, JsonSchema)]
pub struct IntrospectConfig {
/// Casing mode; defaults to [`IntrospectCasing::Camel`] when omitted.
#[serde(default)]
pub casing: IntrospectCasing,
}
/// The `entities.roles` setting: either a plain boolean or a table with filters.
///
/// Being `untagged`, `roles = true/false` deserializes as `Bool`, while a
/// `[entities.roles]` table deserializes as `Config`.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
#[serde(untagged)]
pub enum RolesFilter {
/// `roles = true` includes every role; `false` includes none.
Bool(bool),
/// Table form. Any table form counts as "enabled" (see `is_enabled`).
Config {
/// Provider name (`"supabase"` or `"neon"` are recognized); its built-in roles are excluded.
#[serde(default)]
provider: Option<String>,
/// If set, only these role names are included.
#[serde(default)]
include: Option<Vec<String>>,
/// Role names that are always excluded.
#[serde(default)]
exclude: Option<Vec<String>>,
},
}
/// Roles are disabled unless configured.
impl Default for RolesFilter {
fn default() -> Self {
Self::Bool(false)
}
}
impl RolesFilter {
#[must_use]
pub const fn is_enabled(&self) -> bool {
match self {
Self::Bool(b) => *b,
Self::Config { .. } => true,
}
}
#[must_use]
pub fn should_include(&self, role_name: &str) -> bool {
match self {
Self::Bool(b) => *b,
Self::Config {
provider,
include,
exclude,
} => {
if let Some(p) = provider
&& is_provider_role(p, role_name)
{
return false;
}
if let Some(excl) = exclude
&& excl.iter().any(|e| e == role_name)
{
return false;
}
if let Some(incl) = include {
return incl.iter().any(|i| i == role_name);
}
true
}
}
}
}
/// Returns `true` when `role_name` is one of the built-in roles managed by `provider`.
///
/// Only `"supabase"` and `"neon"` are recognized; any other provider has no reserved roles.
fn is_provider_role(provider: &str, role_name: &str) -> bool {
    const SUPABASE_ROLES: &[&str] = &[
        "anon",
        "authenticated",
        "service_role",
        "supabase_admin",
        "supabase_auth_admin",
        "supabase_storage_admin",
        "dashboard_user",
        "supabase_replication_admin",
        "supabase_read_only_user",
        "supabase_realtime_admin",
        "supabase_functions_admin",
        "postgres",
        "pgbouncer",
        "pgsodium_keyholder",
        "pgsodium_keyiduser",
        "pgsodium_keymaker",
    ];
    const NEON_ROLES: &[&str] = &["neon_superuser", "cloud_admin", "authenticated", "anonymous"];
    let reserved: &[&str] = match provider {
        "supabase" => SUPABASE_ROLES,
        "neon" => NEON_ROLES,
        _ => &[],
    };
    reserved.contains(&role_name)
}
/// The `[entities]` config section (only accepted for the PostgreSQL dialect).
#[derive(Debug, Clone, Default, Deserialize, JsonSchema)]
pub struct EntitiesFilter {
/// Role filter; defaults to `RolesFilter::Bool(false)` (disabled).
#[serde(default)]
pub roles: RolesFilter,
}
/// An extension name accepted in `extensionsFilters` (PostgreSQL only).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum Extension {
    /// `"postgis"`.
    Postgis,
}
impl Extension {
    /// Returns the lowercase name used in config files.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        let name = match self {
            Self::Postgis => "postgis",
        };
        name
    }
}
impl std::fmt::Display for Extension {
    /// Writes the config-file name of the extension.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.write_str(Self::as_str(*self))
    }
}
/// The configured SQL dialect (`dialect = "..."`).
///
/// `"postgres"` is accepted as an alias for `"postgresql"`, both by serde and by `FromStr`.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, Default, serde::Serialize, Deserialize, JsonSchema,
)]
#[serde(rename_all = "lowercase")]
pub enum Dialect {
    #[default]
    Sqlite,
    #[serde(alias = "postgres")]
    Postgresql,
    Turso,
}
impl Dialect {
    /// Canonical names of every dialect, used in error messages.
    pub const ALL: &'static [&'static str] = &["sqlite", "postgresql", "turso"];
    /// Returns the canonical lowercase name of the dialect.
    #[inline]
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Turso => "turso",
            Self::Postgresql => "postgresql",
            Self::Sqlite => "sqlite",
        }
    }
    /// Maps to the base dialect in `drizzle_types`; Turso shares SQLite's base dialect.
    #[inline]
    #[must_use]
    pub const fn to_base(self) -> drizzle_types::Dialect {
        match self {
            Self::Postgresql => drizzle_types::Dialect::PostgreSQL,
            Self::Turso | Self::Sqlite => drizzle_types::Dialect::SQLite,
        }
    }
}
impl std::fmt::Display for Dialect {
    /// Writes the canonical lowercase name.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.write_str(Self::as_str(*self))
    }
}
impl std::str::FromStr for Dialect {
    type Err = String;
    /// Parses a canonical dialect name or the `"postgres"` alias.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let dialect = match s {
            "postgresql" | "postgres" => Self::Postgresql,
            "sqlite" => Self::Sqlite,
            "turso" => Self::Turso,
            other => {
                return Err(format!(
                    "invalid dialect '{}', expected one of: {}",
                    other,
                    Self::ALL.join(", ")
                ));
            }
        };
        Ok(dialect)
    }
}
impl From<Dialect> for drizzle_types::Dialect {
    /// Equivalent to [`Dialect::to_base`].
    #[inline]
    fn from(dialect: Dialect) -> Self {
        Dialect::to_base(dialect)
    }
}
/// The configured database driver (`driver = "..."`, kebab-case in config files).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, JsonSchema)]
#[serde(rename_all = "kebab-case")]
pub enum Driver {
    Rusqlite,
    Libsql,
    Turso,
    PostgresSync,
    TokioPostgres,
    D1Http,
    DurableSqlite,
    AwsDataApi,
}
impl Driver {
    /// Config-file names of every driver, used in error messages.
    pub const ALL: &'static [&'static str] = &[
        "rusqlite",
        "libsql",
        "turso",
        "postgres-sync",
        "tokio-postgres",
        "d1-http",
        "durable-sqlite",
        "aws-data-api",
    ];
    /// Returns the kebab-case name used in config files.
    #[inline]
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::AwsDataApi => "aws-data-api",
            Self::DurableSqlite => "durable-sqlite",
            Self::D1Http => "d1-http",
            Self::TokioPostgres => "tokio-postgres",
            Self::PostgresSync => "postgres-sync",
            Self::Turso => "turso",
            Self::Libsql => "libsql",
            Self::Rusqlite => "rusqlite",
        }
    }
    /// Lists the drivers that may be combined with `dialect`.
    #[must_use]
    pub const fn valid_for(dialect: Dialect) -> &'static [Self] {
        match dialect {
            Dialect::Postgresql => &[Self::PostgresSync, Self::TokioPostgres, Self::AwsDataApi],
            Dialect::Turso => &[Self::Libsql, Self::Turso],
            Dialect::Sqlite => &[Self::Rusqlite, Self::D1Http, Self::DurableSqlite],
        }
    }
    /// Returns `true` when this driver may be combined with `dialect`.
    ///
    /// Agrees with [`Driver::valid_for`].
    #[inline]
    #[must_use]
    pub const fn is_valid_for(self, dialect: Dialect) -> bool {
        match dialect {
            Dialect::Sqlite => {
                matches!(self, Self::Rusqlite | Self::D1Http | Self::DurableSqlite)
            }
            Dialect::Turso => matches!(self, Self::Libsql | Self::Turso),
            Dialect::Postgresql => {
                matches!(self, Self::PostgresSync | Self::TokioPostgres | Self::AwsDataApi)
            }
        }
    }
    /// Returns `true` for drivers that are only used for code generation (`durable-sqlite`).
    #[inline]
    #[must_use]
    pub const fn is_codegen_only(self) -> bool {
        matches!(self, Self::DurableSqlite)
    }
}
impl std::fmt::Display for Driver {
    /// Writes the kebab-case config name.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.write_str(Self::as_str(*self))
    }
}
impl std::str::FromStr for Driver {
    type Err = String;
    /// Parses a kebab-case driver name; the accepted spellings are exactly those of `as_str`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const EVERY: [Driver; 8] = [
            Driver::Rusqlite,
            Driver::Libsql,
            Driver::Turso,
            Driver::PostgresSync,
            Driver::TokioPostgres,
            Driver::D1Http,
            Driver::DurableSqlite,
            Driver::AwsDataApi,
        ];
        EVERY
            .into_iter()
            .find(|candidate| candidate.as_str() == s)
            .ok_or_else(|| {
                format!(
                    "invalid driver '{}', expected one of: {}",
                    s,
                    Self::ALL.join(", ")
                )
            })
    }
}
impl std::str::FromStr for Extension {
    type Err = String;
    /// Parses `"postgis"`; any other name is rejected with a descriptive error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == "postgis" {
            return Ok(Self::Postgis);
        }
        Err(format!(
            "invalid extension filter '{s}', expected 'postgis'"
        ))
    }
}
/// Fully resolved connection credentials, produced by `DatabaseConfig::credentials`
/// after `{ env = "..." }` references have been read.
///
/// NOTE(review): the derived `Debug` prints tokens and passwords verbatim; avoid logging
/// values of this type with `{:?}`.
#[derive(Debug, Clone)]
pub enum Credentials {
/// SQLite: `dbCredentials.url` used as a file path.
Sqlite { path: Box<str> },
/// Turso: URL plus optional `authToken`.
Turso {
url: Box<str>,
auth_token: Option<Box<str>>,
},
/// PostgreSQL: URL or host-based parameters.
Postgres(PostgresCreds),
/// Cloudflare D1 over HTTP (`accountId`, `databaseId`, `token`).
D1 {
account_id: Box<str>,
database_id: Box<str>,
token: Box<str>,
},
/// AWS Data API (`database`, `secretArn`, `resourceArn`).
AwsDataApi {
database: Box<str>,
secret_arn: Box<str>,
resource_arn: Box<str>,
},
}
/// PostgreSQL connection settings: either a complete URL or discrete host parameters.
#[derive(Debug, Clone)]
pub enum PostgresCreds {
    /// A complete connection URL, returned verbatim by [`PostgresCreds::connection_url`].
    Url(Box<str>),
    /// Discrete parameters assembled into a URL by [`PostgresCreds::connection_url`].
    Host {
        host: Box<str>,
        port: u16,
        user: Option<Box<str>>,
        password: Option<Box<str>>,
        database: Box<str>,
        /// Whether SSL was requested; not encoded into `connection_url`.
        ssl: bool,
    },
}
impl PostgresCreds {
    /// Returns a `postgres://` connection URL for these credentials.
    ///
    /// `Url` credentials are returned unchanged. For `Host` credentials:
    /// - The user, password and database name are percent-encoded. Without this, reserved
    ///   characters (`@`, `:`, `/`, `?`, `#`, `%`, …) in, for example, a generated password
    ///   would corrupt the URL structure.
    /// - An IPv6 literal host is wrapped in `[...]`.
    /// - A host starting with `/` (a Unix-socket directory) is percent-encoded.
    /// - A password supplied without a user is omitted.
    /// - The `ssl` flag is not reflected in the URL.
    #[must_use]
    pub fn connection_url(&self) -> String {
        match self {
            Self::Url(url) => url.to_string(),
            Self::Host {
                host,
                port,
                user,
                password,
                database,
                ..
            } => {
                let auth = match (user, password) {
                    (Some(u), Some(p)) => format!(
                        "{}:{}@",
                        encode_url_component(u),
                        encode_url_component(p)
                    ),
                    (Some(u), None) => format!("{}@", encode_url_component(u)),
                    _ => String::new(),
                };
                let host = format_url_host(host);
                let database = encode_url_component(database);
                format!("postgres://{auth}{host}:{port}/{database}")
            }
        }
    }
}
/// Formats a host for the authority part of a URL.
///
/// Socket directories are percent-encoded, unbracketed IPv6 literals are bracketed, and
/// anything else is used as-is.
fn format_url_host(host: &str) -> String {
    if host.starts_with('/') {
        encode_url_component(host)
    } else if host.contains(':') && !host.starts_with('[') {
        format!("[{host}]")
    } else {
        host.to_string()
    }
}
/// Percent-encodes every byte of `raw` except RFC 3986 unreserved characters
/// (`A-Z a-z 0-9 - . _ ~`). Non-ASCII text is encoded as its UTF-8 bytes.
fn encode_url_component(raw: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut out = String::with_capacity(raw.len());
    for byte in raw.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    out
}
/// The `schema` setting: one path/glob or a list of them.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
#[serde(untagged)]
pub enum Schema {
    One(String),
    Many(Vec<String>),
}
/// Defaults to the single path `src/schema.rs`.
impl Default for Schema {
    fn default() -> Self {
        Self::One(String::from("src/schema.rs"))
    }
}
impl Schema {
    /// Iterates over the configured entries in order, whichever form was used.
    pub fn iter(&self) -> impl Iterator<Item = &str> {
        let entries: &[String] = match self {
            Self::One(single) => std::slice::from_ref(single),
            Self::Many(list) => list.as_slice(),
        };
        entries.iter().map(String::as_str)
    }
}
/// A filter setting (`tablesFilter`, `schemaFilter`): one pattern or a list of them.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
#[serde(untagged)]
pub enum Filter {
    One(String),
    Many(Vec<String>),
}
impl Filter {
    /// Iterates over the configured patterns in order, whichever form was used.
    pub fn iter(&self) -> impl Iterator<Item = &str> {
        let patterns: &[String] = match self {
            Self::One(single) => std::slice::from_ref(single),
            Self::Many(list) => list.as_slice(),
        };
        patterns.iter().map(String::as_str)
    }
}
/// The `[migrations]` config section.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
pub struct MigrationsOpts {
/// Migrations table name; `migrations_table()` falls back to `__drizzle_migrations`.
pub table: Option<String>,
/// Migrations schema; `migrations_schema()` falls back to `drizzle`.
pub schema: Option<String>,
/// Migration naming prefix style. NOTE(review): consumed outside this view.
pub prefix: Option<MigrationPrefix>,
/// Explicit bundle switch; when absent `bundle_enabled()` enables it only for `durable-sqlite`.
#[serde(default)]
pub bundle: Option<bool>,
}
/// Accepted values for `migrations.prefix` (lowercase in config files).
#[derive(Debug, Clone, Copy, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum MigrationPrefix {
Index,
Timestamp,
Supabase,
Unix,
None,
}
/// Unresolved `dbCredentials` table as written in the config file.
///
/// `untagged`: serde tries the variants in declaration order and keeps the first whose
/// required fields are present. Every value may be a literal or an `{ env = "..." }` reference.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
#[serde(untagged)]
enum RawCreds {
/// Cloudflare D1: `accountId`, `databaseId`, `token`.
D1 {
#[serde(rename = "accountId")]
account_id: EnvOr,
#[serde(rename = "databaseId")]
database_id: EnvOr,
token: EnvOr,
},
/// AWS Data API: `database`, `secretArn`, `resourceArn`.
AwsDataApi {
database: EnvOr,
#[serde(rename = "secretArn")]
secret_arn: EnvOr,
#[serde(rename = "resourceArn")]
resource_arn: EnvOr,
},
/// A `url` (file path for SQLite), with optional `authToken` for Turso.
Url {
url: EnvOr,
#[serde(default, rename = "authToken")]
auth_token: Option<EnvOr>,
},
/// Host-based PostgreSQL parameters; `port` falls back to 5432 in `credentials()`.
Host {
host: EnvOr,
#[serde(default)]
port: Option<u16>,
#[serde(default)]
user: Option<EnvOr>,
#[serde(default)]
password: Option<EnvOr>,
database: EnvOr,
#[serde(default)]
ssl: Option<SslVal>,
},
}
/// Raw `ssl` value from `dbCredentials`: a boolean or a string such as `"require"` or `"disable"`.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
#[serde(untagged)]
enum SslVal {
    Bool(bool),
    Str(String),
}
impl SslVal {
    /// Returns whether SSL should be enabled.
    ///
    /// Strings are trimmed and compared case-insensitively, so `"FALSE"` or `"Disable"` are
    /// not mistaken for "on". `"disable"`, `"false"`, `"no"`, `"off"` and `"0"` disable SSL;
    /// every other string (e.g. `"require"`) enables it.
    fn enabled(&self) -> bool {
        match self {
            Self::Bool(b) => *b,
            Self::Str(s) => {
                let value = s.trim().to_ascii_lowercase();
                !matches!(value.as_str(), "disable" | "false" | "no" | "off" | "0")
            }
        }
    }
}
/// Settings for one database. This is either the whole config file (single-database form)
/// or one `[databases.<name>]` table. Keys are camelCase in the file.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct DatabaseConfig {
/// Required SQL dialect.
pub dialect: Dialect,
/// Schema source path(s)/glob(s); defaults to `src/schema.rs`, made relative to the config dir.
#[serde(default)]
pub schema: Schema,
/// Migrations output directory; defaults to `./drizzle`, made relative to the config dir.
#[serde(default = "default_out")]
pub out: PathBuf,
/// Defaults to `true`.
#[serde(default = "yes")]
pub breakpoints: bool,
/// Optional driver; must be valid for `dialect` (checked in `validate`).
#[serde(default)]
pub driver: Option<Driver>,
/// `dbCredentials`; read through `credentials()`.
#[serde(default)]
db_credentials: Option<RawCreds>,
/// `tablesFilter`.
#[serde(default)]
pub tables_filter: Option<Filter>,
/// `schemaFilter` (PostgreSQL only).
#[serde(default)]
pub schema_filter: Option<Filter>,
/// `extensionsFilters` (PostgreSQL only).
#[serde(default)]
pub extensions_filters: Option<Vec<Extension>>,
/// `[entities]` (PostgreSQL only).
#[serde(default)]
pub entities: Option<EntitiesFilter>,
/// Casing setting; `effective_casing()` falls back to `Casing::default()`.
#[serde(default)]
pub casing: Option<Casing>,
/// `[introspect]` section.
#[serde(default)]
pub introspect: Option<IntrospectConfig>,
#[serde(default)]
pub verbose: bool,
/// `[migrations]` section.
#[serde(default)]
pub migrations: Option<MigrationsOpts>,
}
/// Serde default for `out`: the relative directory `./drizzle`.
fn default_out() -> PathBuf {
    "./drizzle".into()
}
/// Serde default helper that always yields `true` (used for `breakpoints`).
const fn yes() -> bool {
    let enabled = true;
    enabled
}
impl DatabaseConfig {
/// Makes `out` and the `schema` entries relative to `base_dir` (the config file's directory).
///
/// `out` is joined with `base_dir` when it is relative. Each schema entry is trimmed. It is
/// then prefixed with `base_dir`, unless it is absolute, a UNC path (`\\...`), or `base_dir`
/// is empty or `.`. Finally backslashes are turned into forward slashes so entries work as
/// glob patterns.
fn normalize_paths(&mut self, base_dir: &Path) {
    if self.out.is_relative() {
        self.out = base_dir.join(&self.out);
    }
    let prefix = base_dir.to_string_lossy().replace('\\', "/");
    let prefix = prefix.trim_end_matches('/');
    let keep_unprefixed = prefix.is_empty() || prefix == ".";
    let rebase = |entry: &str| -> String {
        let entry = entry.trim();
        let absolute = Path::new(entry).is_absolute() || entry.starts_with("\\\\");
        let combined = if absolute || keep_unprefixed {
            entry.to_owned()
        } else {
            format!("{prefix}/{entry}")
        };
        combined.replace('\\', "/")
    };
    match &mut self.schema {
        Schema::One(single) => *single = rebase(single),
        Schema::Many(list) => list.iter_mut().for_each(|item| *item = rebase(item)),
    }
}
fn validate(&self, name: &str) -> Result<(), Error> {
if let Some(d) = self.driver
&& !d.is_valid_for(self.dialect)
{
return Err(Error::InvalidDriver {
driver: d,
dialect: self.dialect,
});
}
if let Some(ref raw) = self.db_credentials {
self.validate_creds(raw, name)?;
}
if self.dialect != Dialect::Postgresql {
if self.schema_filter.is_some() {
return Err(Error::InvalidConfig(
"schemaFilter is only supported for dialect = \"postgresql\"".into(),
));
}
if self.extensions_filters.is_some() {
return Err(Error::InvalidConfig(
"extensionsFilters is only supported for dialect = \"postgresql\"".into(),
));
}
if self.entities.is_some() {
return Err(Error::InvalidConfig(
"entities filter is only supported for dialect = \"postgresql\"".into(),
));
}
}
Ok(())
}
/// Checks that the shape of `dbCredentials` is consistent with the dialect and driver.
///
/// `_name` is the database's key in the config (currently unused). Returns
/// [`Error::InvalidCredentials`] for the first mismatch found. Only literal URL values are
/// prefix-checked; `{ env = "..." }` URLs are not inspected here. URL prefixes are matched on
/// the full scheme (e.g. `https://`, `postgresql://`), so the checks agree with their error
/// messages.
fn validate_creds(&self, raw: &RawCreds, _name: &str) -> Result<(), Error> {
    /// Returns `true` when `value` begins with any of `prefixes`.
    fn starts_with_any(value: &str, prefixes: &[&str]) -> bool {
        prefixes.iter().any(|prefix| value.starts_with(*prefix))
    }
    let err = |msg: &str| Error::InvalidCredentials(msg.into());
    // Host-style credentials only exist for PostgreSQL.
    if self.dialect != Dialect::Postgresql && matches!(raw, RawCreds::Host { .. }) {
        return Err(err(
            "host-based dbCredentials are only supported for dialect = \"postgresql\"",
        ));
    }
    // D1 credentials and the d1-http driver must be used together, and only with SQLite.
    if let RawCreds::D1 { .. } = raw {
        if self.dialect != Dialect::Sqlite {
            return Err(err(
                "D1 dbCredentials (accountId/databaseId/token) require dialect = \"sqlite\"",
            ));
        }
        if self.driver != Some(Driver::D1Http) {
            return Err(err(
                "D1 dbCredentials (accountId/databaseId/token) require driver = \"d1-http\"",
            ));
        }
    }
    if self.driver == Some(Driver::D1Http) && !matches!(raw, RawCreds::D1 { .. }) {
        return Err(err(
            "driver = \"d1-http\" requires dbCredentials with accountId, databaseId, and token",
        ));
    }
    // AWS Data API credentials and the aws-data-api driver must be used together, PostgreSQL only.
    if let RawCreds::AwsDataApi { .. } = raw {
        if self.dialect != Dialect::Postgresql {
            return Err(err(
                "AWS Data API dbCredentials (database/secretArn/resourceArn) require dialect = \"postgresql\"",
            ));
        }
        if self.driver != Some(Driver::AwsDataApi) {
            return Err(err(
                "AWS Data API dbCredentials (database/secretArn/resourceArn) require driver = \"aws-data-api\"",
            ));
        }
    }
    if self.driver == Some(Driver::AwsDataApi) && !matches!(raw, RawCreds::AwsDataApi { .. }) {
        return Err(err(
            "driver = \"aws-data-api\" requires dbCredentials with database, secretArn, and resourceArn",
        ));
    }
    // URL sanity checks per dialect (literal values only).
    match (self.dialect, raw) {
        (
            Dialect::Sqlite,
            RawCreds::Url {
                auth_token: Some(_),
                ..
            },
        ) => Err(err(
            "SQLite doesn't support authToken (use dialect = \"turso\")",
        )),
        (
            Dialect::Sqlite,
            RawCreds::Url {
                url: EnvOr::Value(url),
                ..
            },
        ) if url.starts_with("libsql://") => Err(err(
            "libsql:// URLs require dialect = \"turso\" (for local SQLite files, use ./path.db)",
        )),
        (
            Dialect::Sqlite,
            RawCreds::Url {
                url: EnvOr::Value(url),
                ..
            },
        ) if starts_with_any(url, &["http://", "https://", "postgres://", "postgresql://"]) => {
            Err(err(
                "SQLite dbCredentials.url must be a local file path (not an http(s)/postgres URL)",
            ))
        }
        (
            Dialect::Turso,
            RawCreds::Url {
                url: EnvOr::Value(url),
                ..
            },
        ) if !starts_with_any(url, &["libsql://", "http://", "https://"]) => {
            Err(err("Turso URL must start with libsql:// or http(s)://"))
        }
        (
            Dialect::Postgresql,
            RawCreds::Url {
                url: EnvOr::Value(url),
                ..
            },
        ) if !starts_with_any(url, &["postgres://", "postgresql://"]) => {
            Err(err("PostgreSQL URL must start with postgres://"))
        }
        _ => Ok(()),
    }
}
pub fn credentials(&self) -> Result<Option<Credentials>, Error> {
let Some(raw) = self.db_credentials.as_ref() else {
return Ok(None);
};
let resolve_opt = |opt: &Option<EnvOr>| -> Result<Option<Box<str>>, Error> {
match opt.as_ref() {
None => Ok(None),
Some(e) => Ok(Some(e.resolve()?.into_boxed_str())),
}
};
let creds = match (self.dialect, raw) {
(
Dialect::Sqlite,
RawCreds::D1 {
account_id,
database_id,
token,
},
) => Credentials::D1 {
account_id: account_id.resolve()?.into_boxed_str(),
database_id: database_id.resolve()?.into_boxed_str(),
token: token.resolve()?.into_boxed_str(),
},
(
Dialect::Postgresql,
RawCreds::AwsDataApi {
database,
secret_arn,
resource_arn,
},
) => Credentials::AwsDataApi {
database: database.resolve()?.into_boxed_str(),
secret_arn: secret_arn.resolve()?.into_boxed_str(),
resource_arn: resource_arn.resolve()?.into_boxed_str(),
},
(Dialect::Sqlite, RawCreds::Url { url, .. }) => Credentials::Sqlite {
path: url.resolve()?.into_boxed_str(),
},
(Dialect::Turso, RawCreds::Url { url, auth_token }) => Credentials::Turso {
url: url.resolve()?.into_boxed_str(),
auth_token: resolve_opt(auth_token)?,
},
(Dialect::Postgresql, RawCreds::Url { url, .. }) => {
Credentials::Postgres(PostgresCreds::Url(url.resolve()?.into_boxed_str()))
}
(
Dialect::Postgresql,
RawCreds::Host {
host,
port,
user,
password,
database,
ssl,
},
) => Credentials::Postgres(PostgresCreds::Host {
host: host.resolve()?.into_boxed_str(),
port: port.unwrap_or(5432),
user: resolve_opt(user)?,
password: resolve_opt(password)?,
database: database.resolve()?.into_boxed_str(),
ssl: ssl.as_ref().is_some_and(SslVal::enabled),
}),
_ => return Ok(None),
};
Ok(Some(creds))
}
/// Directory migrations are written to (the normalized `out` setting).
#[inline]
#[must_use]
pub fn migrations_dir(&self) -> &Path {
    self.out.as_path()
}
/// The `meta` subdirectory of the migrations directory.
#[inline]
#[must_use]
pub fn meta_dir(&self) -> PathBuf {
    self.migrations_dir().join("meta")
}
/// Path of `_journal.json` inside the meta directory.
#[inline]
#[must_use]
pub fn journal_path(&self) -> PathBuf {
    let mut journal = self.meta_dir();
    journal.push("_journal.json");
    journal
}
/// Human-readable schema setting: the single entry, or all entries joined with `", "`.
#[must_use]
pub fn schema_display(&self) -> String {
    match &self.schema {
        Schema::Many(paths) => paths.join(", "),
        Schema::One(path) => path.to_owned(),
    }
}
pub fn schema_files(&self) -> Result<Vec<PathBuf>, Error> {
let mut files = Vec::new();
for pattern in self.schema.iter() {
let pat = pattern.trim();
let is_glob = pat.contains('*') || pat.contains('?') || pat.contains('[');
if !is_glob {
let p = PathBuf::from(pat);
if p.exists() {
files.push(p);
continue;
}
}
let pat_norm = pat.replace('\\', "/");
match glob::glob(&pat_norm) {
Ok(paths) => {
let matched: Vec<_> = paths.filter_map(Result::ok).collect();
if matched.is_empty() && !is_glob {
let p = PathBuf::from(&pat_norm);
if p.exists() {
files.push(p);
}
} else {
files.extend(matched);
}
}
Err(e) => return Err(Error::Glob(pat.into(), e)),
}
}
files.retain(|p| p.is_file());
files.sort();
files.dedup();
if files.is_empty() {
return Err(Error::NoSchemaFiles(self.schema_display()));
}
Ok(files)
}
#[inline]
#[must_use]
pub fn effective_casing(&self) -> Casing {
self.casing.unwrap_or_default()
}
#[inline]
#[must_use]
pub fn effective_introspect_casing(&self) -> IntrospectCasing {
self.introspect
.as_ref()
.map(|i| i.casing)
.unwrap_or_default()
}
#[inline]
#[must_use]
pub fn effective_entities(&self) -> EntitiesFilter {
self.entities.clone().unwrap_or_default()
}
#[must_use]
pub fn should_include_role(&self, role_name: &str) -> bool {
self.entities
.as_ref()
.is_some_and(|e| e.roles.should_include(role_name))
}
#[must_use]
pub fn roles_enabled(&self) -> bool {
self.entities.as_ref().is_some_and(|e| e.roles.is_enabled())
}
#[must_use]
pub fn extensions(&self) -> &[Extension] {
self.extensions_filters.as_deref().unwrap_or(&[])
}
/// Whether `ext` is listed in `extensionsFilters`.
#[must_use]
pub fn has_extension(&self, ext: Extension) -> bool {
    self.extensions().contains(&ext)
}
/// Migrations table name; defaults to `__drizzle_migrations`.
#[must_use]
pub fn migrations_table(&self) -> &str {
    match self.migrations.as_ref().and_then(|opts| opts.table.as_deref()) {
        Some(table) => table,
        None => "__drizzle_migrations",
    }
}
/// Migrations schema name; defaults to `drizzle`.
#[must_use]
pub fn migrations_schema(&self) -> &str {
    match self.migrations.as_ref().and_then(|opts| opts.schema.as_deref()) {
        Some(schema) => schema,
        None => "drizzle",
    }
}
/// Whether migrations should be bundled.
///
/// An explicit `migrations.bundle` wins; otherwise bundling is on only for `durable-sqlite`.
#[must_use]
pub fn bundle_enabled(&self) -> bool {
    self.migrations
        .as_ref()
        .and_then(|opts| opts.bundle)
        .unwrap_or_else(|| self.driver == Some(Driver::DurableSqlite))
}
}
/// Multi-database file shape: one `[databases.<name>]` table per database.
#[derive(Debug, Clone, Deserialize)]
struct MultiDbConfig {
databases: HashMap<String, DatabaseConfig>,
}
/// A loaded configuration holding one or more named database configs.
///
/// A single-database file is stored under [`DEFAULT_DB`] with `is_single` set to `true`.
#[derive(Debug, Clone)]
pub struct Config {
databases: HashMap<String, DatabaseConfig>,
is_single: bool,
}
/// Key under which a single-database config is stored.
pub const DEFAULT_DB: &str = "default";
impl Config {
pub fn load() -> Result<Self, Error> {
Self::load_from(Path::new(CONFIG_FILE))
}
pub fn load_from(path: &Path) -> Result<Self, Error> {
let content = std::fs::read_to_string(path).map_err(|e| {
if e.kind() == std::io::ErrorKind::NotFound {
Error::NotFound(path.into())
} else {
Error::Io(path.into(), e)
}
})?;
Self::load_from_str(&content, path)
}
fn load_from_str(content: &str, path: &Path) -> Result<Self, Error> {
let base_dir = path.parent().unwrap_or_else(|| Path::new("."));
if let Ok(multi) = toml::from_str::<MultiDbConfig>(content)
&& !multi.databases.is_empty()
{
let mut config = Self {
databases: multi.databases,
is_single: false,
};
for db in config.databases.values_mut() {
db.normalize_paths(base_dir);
}
config.validate()?;
return Ok(config);
}
let db_config: DatabaseConfig =
toml::from_str(content).map_err(|e| Error::Parse(path.into(), e))?;
let mut databases = HashMap::new();
databases.insert(DEFAULT_DB.to_string(), db_config);
let mut config = Self {
databases,
is_single: true,
};
for db in config.databases.values_mut() {
db.normalize_paths(base_dir);
}
config.validate()?;
Ok(config)
}
/// Validates every configured database, stopping at the first error.
fn validate(&self) -> Result<(), Error> {
    self.databases
        .iter()
        .try_for_each(|(name, db)| db.validate(name))
}
/// Whether the config file used the single-database form.
#[must_use]
pub const fn is_single_database(&self) -> bool {
    self.is_single
}
/// Names of all configured databases (in `HashMap` order).
pub fn database_names(&self) -> impl Iterator<Item = &str> {
    self.databases.keys().map(|key| key.as_str())
}
pub fn database(&self, name: Option<&str>) -> Result<&DatabaseConfig, Error> {
name.map_or_else(
|| {
if self.is_single {
self.databases.get(DEFAULT_DB).ok_or(Error::NoDatabases)
} else if self.databases.len() == 1 {
self.databases.values().next().ok_or(Error::NoDatabases)
} else {
Err(Error::DatabaseRequired(
self.databases.keys().cloned().collect(),
))
}
},
|name| {
if self.is_single {
self.databases.get(DEFAULT_DB).ok_or(Error::NoDatabases)
} else {
self.databases
.get(name)
.ok_or_else(|| Error::DatabaseNotFound(name.to_string()))
}
},
)
}
pub fn default_database(&self) -> Result<&DatabaseConfig, Error> {
self.database(None)
}
#[must_use]
pub fn dialect(&self) -> Dialect {
self.default_database()
.map(|d| d.dialect)
.unwrap_or_default()
}
pub fn credentials(&self) -> Result<Option<Credentials>, Error> {
self.default_database()?.credentials()
}
#[must_use]
pub fn migrations_dir(&self) -> &Path {
self.default_database()
.map_or_else(|_| Path::new("./drizzle"), |d| d.migrations_dir())
}
#[must_use]
pub fn journal_path(&self) -> PathBuf {
self.default_database().map_or_else(
|_| PathBuf::from("./drizzle/meta/_journal.json"),
DatabaseConfig::journal_path,
)
}
#[must_use]
pub fn schema_display(&self) -> String {
self.default_database()
.map_or_else(|_| "src/schema.rs".into(), DatabaseConfig::schema_display)
}
pub fn schema_files(&self) -> Result<Vec<PathBuf>, Error> {
self.default_database()?.schema_files()
}
#[must_use]
pub fn base_dialect(&self) -> drizzle_types::Dialect {
self.dialect().to_base()
}
}
/// Errors from loading, validating and resolving the drizzle config.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// The config file does not exist.
#[error("config not found: {}", .0.display())]
NotFound(PathBuf),
/// The config file exists but could not be read.
#[error("failed to read {}: {}", .0.display(), .1)]
Io(PathBuf, #[source] std::io::Error),
/// The config file is not valid TOML or does not match the expected shape.
#[error("failed to parse {}: {}", .0.display(), .1)]
Parse(PathBuf, #[source] toml::de::Error),
/// The configured driver is not valid for the configured dialect.
#[error("driver '{driver}' invalid for {dialect} dialect")]
InvalidDriver { driver: Driver, dialect: Dialect },
/// `dbCredentials` does not fit the dialect/driver combination.
#[error("invalid credentials: {0}")]
InvalidCredentials(String),
/// Some other setting is invalid for the configured dialect.
#[error("invalid config: {0}")]
InvalidConfig(String),
/// A `schema` entry is not a valid glob pattern.
#[error("invalid glob '{0}': {1}")]
Glob(String, #[source] glob::PatternError),
/// No schema files matched the `schema` setting.
#[error("no schema files found: {0}")]
NoSchemaFiles(String),
/// A referenced environment variable is not set.
#[error("environment variable '{0}' not found")]
EnvNotFound(String),
/// A referenced environment variable has an unusable value.
#[error("environment variable '{0}' invalid: {1}")]
EnvInvalid(String, String),
/// No database could be selected.
#[error("no databases configured")]
NoDatabases,
/// The requested database name is not in the config.
#[error("database '{0}' not found")]
DatabaseNotFound(String),
/// Several databases are configured and none was named.
#[error("multiple databases configured, use --db to specify: {}", .0.join(", "))]
DatabaseRequired(Vec<String>),
}
/// Maps environment-variable lookup failures from `EnvOr::resolve` into config errors.
impl From<EnvOrError> for Error {
    fn from(err: EnvOrError) -> Self {
        match err {
            EnvOrError::NotUnicode(name) => {
                let reason = String::from("contains invalid unicode");
                Self::EnvInvalid(name, reason)
            }
            EnvOrError::NotPresent(name) => Self::EnvNotFound(name),
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::TempDir;
#[test]
fn sqlite() {
let cfg = Config::load_from_str(
r#"
dialect = "sqlite"
[dbCredentials]
url = "./dev.db"
"#,
Path::new("test.toml"),
)
.unwrap();
assert!(cfg.is_single_database());
assert!(matches!(
cfg.credentials().unwrap(),
Some(Credentials::Sqlite { .. })
));
}
#[test]
fn postgres_url() {
let cfg = Config::load_from_str(
r#"
dialect = "postgresql"
[dbCredentials]
url = "postgres://localhost/db"
"#,
Path::new("test.toml"),
)
.unwrap();
assert!(matches!(
cfg.credentials().unwrap(),
Some(Credentials::Postgres(PostgresCreds::Url(_)))
));
}
#[test]
fn multi_database() {
let cfg = Config::load_from_str(
r#"
[databases.dev]
dialect = "sqlite"
out = "./drizzle/sqlite"
[databases.dev.dbCredentials]
url = "./dev.db"
[databases.prod]
dialect = "postgresql"
out = "./drizzle/postgres"
[databases.prod.dbCredentials]
url = "postgres://localhost/db"
"#,
Path::new("test.toml"),
)
.unwrap();
assert!(!cfg.is_single_database());
let names: Vec<_> = cfg.database_names().collect();
assert!(names.contains(&"dev"));
assert!(names.contains(&"prod"));
let dev = cfg.database(Some("dev")).unwrap();
assert_eq!(dev.dialect, Dialect::Sqlite);
let prod = cfg.database(Some("prod")).unwrap();
assert_eq!(prod.dialect, Dialect::Postgresql);
}
#[test]
fn multi_database_requires_selection() {
let cfg = Config::load_from_str(
r#"
[databases.a]
dialect = "sqlite"
[databases.b]
dialect = "postgresql"
"#,
Path::new("test.toml"),
)
.unwrap();
assert!(cfg.database(None).is_err());
}
#[test]
fn env_var_syntax() {
let cfg = Config::load_from_str(
r#"
dialect = "postgresql"
[dbCredentials]
url = { env = "DATABASE_URL" }
"#,
Path::new("test.toml"),
)
.unwrap();
assert!(cfg.is_single_database());
}
#[test]
fn casing_options() {
let cfg = Config::load_from_str(
r#"
dialect = "postgresql"
casing = "snake_case"
[dbCredentials]
url = "postgres://localhost/db"
"#,
Path::new("test.toml"),
)
.unwrap();
let db = cfg.default_database().unwrap();
assert_eq!(db.effective_casing(), Casing::SnakeCase);
let cfg2 = Config::load_from_str(
r#"
dialect = "postgresql"
[dbCredentials]
url = "postgres://localhost/db"
"#,
Path::new("test.toml"),
)
.unwrap();
let db2 = cfg2.default_database().unwrap();
assert_eq!(db2.effective_casing(), Casing::CamelCase);
}
#[test]
fn introspect_casing() {
let cfg = Config::load_from_str(
r#"
dialect = "postgresql"
[introspect]
casing = "preserve"
[dbCredentials]
url = "postgres://localhost/db"
"#,
Path::new("test.toml"),
)
.unwrap();
let db = cfg.default_database().unwrap();
assert_eq!(db.effective_introspect_casing(), IntrospectCasing::Preserve);
}
#[test]
fn entities_roles_filter() {
let cfg = Config::load_from_str(
r#"
dialect = "postgresql"
[entities]
roles = true
[dbCredentials]
url = "postgres://localhost/db"
"#,
Path::new("test.toml"),
)
.unwrap();
let db = cfg.default_database().unwrap();
assert!(db.roles_enabled());
assert!(db.should_include_role("my_role"));
let cfg2 = Config::load_from_str(
r#"
dialect = "postgresql"
[entities.roles]
provider = "supabase"
[dbCredentials]
url = "postgres://localhost/db"
"#,
Path::new("test.toml"),
)
.unwrap();
let db2 = cfg2.default_database().unwrap();
assert!(db2.roles_enabled());
assert!(!db2.should_include_role("anon")); assert!(db2.should_include_role("my_custom_role"));
}
#[test]
fn extensions_filter() {
let cfg = Config::load_from_str(
r#"
dialect = "postgresql"
extensionsFilters = ["postgis"]
[dbCredentials]
url = "postgres://localhost/db"
"#,
Path::new("test.toml"),
)
.unwrap();
let db = cfg.default_database().unwrap();
assert!(db.has_extension(Extension::Postgis));
}
#[test]
fn rejects_postgres_only_filters_for_sqlite() {
let err = Config::load_from_str(
r#"
dialect = "sqlite"
schemaFilter = ["public"]
[dbCredentials]
url = "./dev.db"
"#,
Path::new("test.toml"),
)
.expect_err("sqlite should reject schemaFilter");
assert_eq!(
err.to_string(),
"invalid config: schemaFilter is only supported for dialect = \"postgresql\""
);
let err = Config::load_from_str(
r#"
dialect = "sqlite"
extensionsFilters = ["postgis"]
[dbCredentials]
url = "./dev.db"
"#,
Path::new("test.toml"),
)
.expect_err("sqlite should reject extensionsFilters");
assert_eq!(
err.to_string(),
"invalid config: extensionsFilters is only supported for dialect = \"postgresql\""
);
}
#[test]
fn rejects_entities_filter_for_turso() {
let err = Config::load_from_str(
r#"
dialect = "turso"
[entities]
roles = true
[dbCredentials]
url = "libsql://example.turso.io"
"#,
Path::new("test.toml"),
)
.expect_err("turso should reject entities filter");
assert_eq!(
err.to_string(),
"invalid config: entities filter is only supported for dialect = \"postgresql\""
);
}
#[test]
fn migrations_config() {
let cfg = Config::load_from_str(
r#"
dialect = "postgresql"
[migrations]
table = "custom_migrations"
schema = "custom_schema"
[dbCredentials]
url = "postgres://localhost/db"
"#,
Path::new("test.toml"),
)
.unwrap();
let db = cfg.default_database().unwrap();
assert_eq!(db.migrations_table(), "custom_migrations");
assert_eq!(db.migrations_schema(), "custom_schema");
let cfg2 = Config::load_from_str(
r#"
dialect = "postgresql"
[dbCredentials]
url = "postgres://localhost/db"
"#,
Path::new("test.toml"),
)
.unwrap();
let db2 = cfg2.default_database().unwrap();
assert_eq!(db2.migrations_table(), "__drizzle_migrations");
assert_eq!(db2.migrations_schema(), "drizzle");
}
#[test]
fn resolves_paths_relative_to_config_dir() {
let tmp = TempDir::new().unwrap();
let cfg_dir = tmp.path().join("cfg");
fs::create_dir_all(&cfg_dir).unwrap();
let schema_path = cfg_dir.join("schema.rs");
fs::write(&schema_path, "#[allow(dead_code)]\npub struct X;").unwrap();
let cfg_path = cfg_dir.join("drizzle.config.toml");
let cfg = Config::load_from_str(
r#"
dialect = "sqlite"
schema = "schema.rs"
out = "./drizzle"
[dbCredentials]
url = "./dev.db"
"#,
&cfg_path,
)
.unwrap();
let db = cfg.default_database().unwrap();
assert_eq!(db.migrations_dir(), cfg_dir.join("./drizzle").as_path());
let files = db.schema_files().unwrap();
assert_eq!(files.len(), 1);
assert_eq!(files[0], schema_path);
}
#[test]
fn rejects_host_credentials_for_sqlite() {
let err = Config::load_from_str(
r#"
dialect = "sqlite"
[dbCredentials]
host = "localhost"
database = "db"
"#,
Path::new("test.toml"),
)
.unwrap_err();
assert_eq!(
err.to_string(),
"invalid credentials: host-based dbCredentials are only supported for dialect = \"postgresql\""
);
}
#[test]
fn d1_http_credentials_parse() {
let cfg = Config::load_from_str(
r#"
dialect = "sqlite"
driver = "d1-http"
[dbCredentials]
accountId = "acc_abc"
databaseId = "db_xyz"
token = "tok_123"
"#,
Path::new("test.toml"),
)
.unwrap();
let db = cfg.default_database().unwrap();
assert_eq!(db.driver, Some(Driver::D1Http));
match db.credentials().unwrap() {
Some(Credentials::D1 {
account_id,
database_id,
token,
}) => {
assert_eq!(&*account_id, "acc_abc");
assert_eq!(&*database_id, "db_xyz");
assert_eq!(&*token, "tok_123");
}
other => panic!("expected Credentials::D1, got {other:?}"),
}
}
#[test]
fn d1_http_credentials_resolve_from_env() {
unsafe {
std::env::set_var("TEST_D1_ACCT", "env_acct");
std::env::set_var("TEST_D1_DB", "env_db");
std::env::set_var("TEST_D1_TOKEN", "env_token");
}
let cfg = Config::load_from_str(
r#"
dialect = "sqlite"
driver = "d1-http"
[dbCredentials]
accountId = { env = "TEST_D1_ACCT" }
databaseId = { env = "TEST_D1_DB" }
token = { env = "TEST_D1_TOKEN" }
"#,
Path::new("test.toml"),
)
.unwrap();
match cfg.default_database().unwrap().credentials().unwrap() {
Some(Credentials::D1 {
account_id,
database_id,
token,
}) => {
assert_eq!(&*account_id, "env_acct");
assert_eq!(&*database_id, "env_db");
assert_eq!(&*token, "env_token");
}
other => panic!("expected Credentials::D1, got {other:?}"),
}
}
#[test]
fn d1_credentials_require_sqlite_dialect() {
let err = Config::load_from_str(
r#"
dialect = "postgresql"
[dbCredentials]
accountId = "acc"
databaseId = "db"
token = "tok"
"#,
Path::new("test.toml"),
)
.unwrap_err();
assert!(
err.to_string().contains("D1 dbCredentials"),
"expected D1-specific error, got: {err}"
);
}
#[test]
fn d1_credentials_require_d1_http_driver() {
let err = Config::load_from_str(
r#"
dialect = "sqlite"
driver = "rusqlite"
[dbCredentials]
accountId = "acc"
databaseId = "db"
token = "tok"
"#,
Path::new("test.toml"),
)
.unwrap_err();
assert!(
err.to_string().contains("driver = \"d1-http\""),
"expected d1-http driver error, got: {err}"
);
}
#[test]
fn d1_http_driver_requires_d1_credentials() {
let err = Config::load_from_str(
r#"
dialect = "sqlite"
driver = "d1-http"
[dbCredentials]
url = "./dev.db"
"#,
Path::new("test.toml"),
)
.unwrap_err();
assert!(
err.to_string().contains("accountId, databaseId, and token"),
"expected d1-http creds-shape error, got: {err}"
);
}
#[test]
fn durable_sqlite_no_credentials_ok() {
    // The durable-sqlite driver loads without any `[dbCredentials]` table and
    // turns bundling on by default.
    const CONFIG: &str = r#"
dialect = "sqlite"
driver = "durable-sqlite"
"#;
    let cfg = Config::load_from_str(CONFIG, Path::new("test.toml")).unwrap();
    let db = cfg.default_database().unwrap();
    assert_eq!(db.driver, Some(Driver::DurableSqlite));
    let creds = db.credentials().unwrap();
    assert!(creds.is_none());
    let bundled = db.bundle_enabled();
    assert!(bundled, "durable-sqlite should auto-enable bundle");
}
#[test]
fn durable_sqlite_explicit_bundle_false_respected() {
    let toml = r#"
dialect = "sqlite"
driver = "durable-sqlite"
[migrations]
bundle = false
"#;
    let cfg = Config::load_from_str(toml, Path::new("test.toml")).unwrap();
    let db = cfg.default_database().unwrap();
    // An explicit `bundle = false` must win over the durable-sqlite default.
    assert!(!db.bundle_enabled());
}
#[test]
fn durable_sqlite_rejects_non_sqlite_dialect() {
    // durable-sqlite is a SQLite-only driver; pairing it with postgresql fails.
    const CONFIG: &str = r#"
dialect = "postgresql"
driver = "durable-sqlite"
[dbCredentials]
url = "postgres://localhost/db"
"#;
    let err = Config::load_from_str(CONFIG, Path::new("test.toml")).unwrap_err();
    let msg = err.to_string();
    assert!(
        msg.contains("invalid for postgresql"),
        "expected dialect/driver mismatch error, got: {err}"
    );
}
#[test]
fn driver_valid_for_sqlite_includes_cloudflare() {
    let sqlite_drivers = Driver::valid_for(Dialect::Sqlite);
    for expected in [Driver::Rusqlite, Driver::D1Http, Driver::DurableSqlite] {
        assert!(
            sqlite_drivers.contains(&expected),
            "expected {expected:?} to be valid for sqlite"
        );
    }
    // The Cloudflare drivers must not leak into the other dialects.
    for cloudflare in [Driver::D1Http, Driver::DurableSqlite] {
        for (name, dialect) in [("postgresql", Dialect::Postgresql), ("turso", Dialect::Turso)] {
            assert!(
                !cloudflare.is_valid_for(dialect),
                "{cloudflare:?} must not be valid for {name}"
            );
        }
    }
}
#[test]
fn driver_is_codegen_only_flag() {
    // (label, driver, expected `is_codegen_only()` result)
    let cases = [
        ("DurableSqlite", Driver::DurableSqlite, true),
        ("D1Http", Driver::D1Http, false),
        ("Rusqlite", Driver::Rusqlite, false),
        ("AwsDataApi", Driver::AwsDataApi, false),
    ];
    for (label, driver, expected) in cases {
        assert_eq!(
            driver.is_codegen_only(),
            expected,
            "is_codegen_only() for {label}"
        );
    }
}
#[test]
fn aws_data_api_credentials_parse() {
    // Literal AWS Data API credentials map onto `Credentials::AwsDataApi`.
    let toml = r#"
dialect = "postgresql"
driver = "aws-data-api"
[dbCredentials]
database = "mydb"
secretArn = "arn:aws:secretsmanager:us-east-1:123:secret:db-xyz"
resourceArn = "arn:aws:rds:us-east-1:123:cluster:my-aurora"
"#;
    let cfg = Config::load_from_str(toml, Path::new("test.toml")).unwrap();
    let db = cfg.default_database().unwrap();
    assert_eq!(db.driver, Some(Driver::AwsDataApi));
    let creds = db.credentials().unwrap();
    match creds {
        Some(Credentials::AwsDataApi {
            database: db_name,
            secret_arn: secret,
            resource_arn: resource,
        }) => {
            assert_eq!(&*db_name, "mydb");
            assert!(secret.starts_with("arn:aws:secretsmanager"));
            assert!(resource.starts_with("arn:aws:rds"));
        }
        other => panic!("expected Credentials::AwsDataApi, got {other:?}"),
    }
}
#[test]
fn aws_data_api_credentials_resolve_from_env() {
    // SAFETY: `set_var` is unsafe because mutating the process environment
    // races with any concurrent environment read or write on another thread.
    // NOTE(review): libtest runs tests on parallel threads, so soundness here
    // relies on no other test touching the environment at the same moment;
    // consider serializing env-mutating tests behind a shared lock.
    unsafe {
        std::env::set_var("TEST_AWS_DB", "envdb");
        std::env::set_var("TEST_AWS_SECRET", "arn:env:secret");
        std::env::set_var("TEST_AWS_RESOURCE", "arn:env:resource");
    }
    let cfg = Config::load_from_str(
        r#"
dialect = "postgresql"
driver = "aws-data-api"
[dbCredentials]
database = { env = "TEST_AWS_DB" }
secretArn = { env = "TEST_AWS_SECRET" }
resourceArn = { env = "TEST_AWS_RESOURCE" }
"#,
        Path::new("test.toml"),
    )
    .unwrap();
    // Each `{ env = ... }` entry must be replaced by the variable's value.
    match cfg.default_database().unwrap().credentials().unwrap() {
        Some(Credentials::AwsDataApi {
            database,
            secret_arn,
            resource_arn,
        }) => {
            assert_eq!(&*database, "envdb");
            assert_eq!(&*secret_arn, "arn:env:secret");
            assert_eq!(&*resource_arn, "arn:env:resource");
        }
        other => panic!("expected Credentials::AwsDataApi, got {other:?}"),
    }
    // Remove the variables again so they do not leak into later tests that
    // share this process.
    // SAFETY: same concurrency caveat as the `set_var` block above.
    unsafe {
        std::env::remove_var("TEST_AWS_DB");
        std::env::remove_var("TEST_AWS_SECRET");
        std::env::remove_var("TEST_AWS_RESOURCE");
    }
}
#[test]
fn aws_data_api_requires_postgres_dialect() {
    // AWS Data API credential fields under a SQLite dialect must be rejected.
    const CONFIG: &str = r#"
dialect = "sqlite"
[dbCredentials]
database = "mydb"
secretArn = "arn:aws:secretsmanager:..."
resourceArn = "arn:aws:rds:..."
"#;
    let err = Config::load_from_str(CONFIG, Path::new("test.toml")).unwrap_err();
    let description = err.to_string();
    assert!(
        description.contains("AWS Data API dbCredentials"),
        "expected AWS-specific error, got: {err}"
    );
}
#[test]
fn aws_data_api_requires_aws_data_api_driver() {
    // AWS credentials paired with tokio-postgres must point the user at
    // `driver = "aws-data-api"`.
    let toml = r#"
dialect = "postgresql"
driver = "tokio-postgres"
[dbCredentials]
database = "mydb"
secretArn = "arn:aws:secretsmanager:..."
resourceArn = "arn:aws:rds:..."
"#;
    let origin = Path::new("test.toml");
    let err = Config::load_from_str(toml, origin).unwrap_err();
    let rendered = err.to_string();
    assert!(
        rendered.contains("driver = \"aws-data-api\""),
        "expected aws-data-api driver error, got: {err}"
    );
}
#[test]
fn aws_data_api_driver_requires_aws_credentials() {
    // Selecting aws-data-api with a plain `url` credential is a shape error.
    const CONFIG: &str = r#"
dialect = "postgresql"
driver = "aws-data-api"
[dbCredentials]
url = "postgres://localhost/db"
"#;
    let outcome = Config::load_from_str(CONFIG, Path::new("test.toml"));
    let err = outcome.unwrap_err();
    let text = err.to_string();
    assert!(
        text.contains("database, secretArn, and resourceArn"),
        "expected aws-data-api creds-shape error, got: {err}"
    );
}
#[test]
fn aws_data_api_rejected_for_non_postgres_dialect() {
    // The aws-data-api driver only pairs with postgresql; sqlite is a mismatch.
    let toml = r#"
dialect = "sqlite"
driver = "aws-data-api"
"#;
    let err = Config::load_from_str(toml, Path::new("test.toml")).unwrap_err();
    let msg = err.to_string();
    assert!(
        msg.contains("invalid for sqlite"),
        "expected dialect/driver mismatch error, got: {err}"
    );
}
#[test]
fn driver_valid_for_postgres_includes_aws_data_api() {
    let pg_drivers = Driver::valid_for(Dialect::Postgresql);
    for expected in [
        Driver::PostgresSync,
        Driver::TokioPostgres,
        Driver::AwsDataApi,
    ] {
        assert!(
            pg_drivers.contains(&expected),
            "expected {expected:?} to be valid for postgresql"
        );
    }
    // aws-data-api must stay scoped to postgresql.
    for (name, dialect) in [("sqlite", Dialect::Sqlite), ("turso", Dialect::Turso)] {
        assert!(
            !Driver::AwsDataApi.is_valid_for(dialect),
            "aws-data-api must not be valid for {name}"
        );
    }
}
#[cfg(windows)]
#[test]
fn schema_files_accept_backslash_paths() {
    // Lay out `<tmp>/cfg/src/schema.rs` on disk so `schema_files()` has a real
    // file to resolve.
    let tmp = TempDir::new().unwrap();
    let config_dir = tmp.path().join("cfg");
    fs::create_dir_all(&config_dir).unwrap();
    let schema_file = config_dir.join("src").join("schema.rs");
    let src_dir = schema_file.parent().unwrap();
    fs::create_dir_all(src_dir).unwrap();
    fs::write(&schema_file, "#[allow(dead_code)]\npub struct X;").unwrap();

    // Force Windows-style separators, then double every backslash so the path
    // survives as a TOML basic (escaped) string.
    let backslashed = schema_file.to_string_lossy().replace('/', "\\");
    let escaped = backslashed.replace('\\', "\\\\");
    let config_path = config_dir.join("drizzle.config.toml");
    let toml = format!(
        r#"
dialect = "sqlite"
schema = "{escaped}"
"#
    );
    let cfg = Config::load_from_str(&toml, &config_path).unwrap();
    let db = cfg.default_database().unwrap();
    let files = db.schema_files().unwrap();
    assert_eq!(files, vec![schema_file]);
}
}