use serde::Deserialize;
use serde::de::{self, Deserializer, MapAccess, Visitor};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
/// File name that [`Config::load`] passes to [`Config::load_from`] as a relative path.
pub const CONFIG_FILE: &str = "drizzle.config.toml";
/// Casing option stored in the `casing` field of [`DatabaseConfig`].
///
/// Deserializes from the strings `"camelCase"` and `"snake_case"`; the default
/// variant is [`Casing::CamelCase`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Deserialize)]
pub enum Casing {
/// Written as `"camelCase"` in the config file.
#[default]
#[serde(rename = "camelCase")]
CamelCase,
/// Written as `"snake_case"` in the config file.
#[serde(rename = "snake_case")]
SnakeCase,
}
impl Casing {
pub const fn as_str(self) -> &'static str {
match self {
Self::CamelCase => "camelCase",
Self::SnakeCase => "snake_case",
}
}
}
impl std::fmt::Display for Casing {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}
impl std::str::FromStr for Casing {
    type Err = String;

    /// Parses a casing name.
    ///
    /// Accepts `"camelCase"` / `"camel"` and `"snake_case"` / `"snake"`; any other
    /// input yields a message naming the offending value.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if matches!(s, "camelCase" | "camel") {
            return Ok(Self::CamelCase);
        }
        if matches!(s, "snake_case" | "snake") {
            return Ok(Self::SnakeCase);
        }
        Err(format!(
            "invalid casing '{s}', expected 'camelCase' or 'snake_case'"
        ))
    }
}
/// Casing option stored in the `casing` field of [`IntrospectConfig`].
///
/// Deserializes from the strings `"camel"` and `"preserve"`; the default variant
/// is [`IntrospectCasing::Camel`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Deserialize)]
pub enum IntrospectCasing {
/// Written as `"camel"` in the config file.
#[default]
#[serde(rename = "camel")]
Camel,
/// Written as `"preserve"` in the config file.
#[serde(rename = "preserve")]
Preserve,
}
impl IntrospectCasing {
pub const fn as_str(self) -> &'static str {
match self {
Self::Camel => "camel",
Self::Preserve => "preserve",
}
}
}
impl std::fmt::Display for IntrospectCasing {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}
impl std::str::FromStr for IntrospectCasing {
    type Err = String;

    /// Parses an introspect casing name.
    ///
    /// Accepts `"camel"` / `"camelCase"` and `"preserve"`; any other input yields
    /// a message naming the offending value.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if matches!(s, "camel" | "camelCase") {
            return Ok(Self::Camel);
        }
        if s == "preserve" {
            return Ok(Self::Preserve);
        }
        Err(format!(
            "invalid introspect casing '{s}', expected 'camel' or 'preserve'"
        ))
    }
}
/// Contents of the `introspect` field of [`DatabaseConfig`].
#[derive(Debug, Clone, Default, Deserialize)]
pub struct IntrospectConfig {
/// Introspect casing; falls back to `IntrospectCasing::default()` when the key is absent.
#[serde(default)]
pub casing: IntrospectCasing,
}
/// Role filter, given either as a plain boolean or as a table of options.
///
/// The enum is `#[serde(untagged)]`, so the input shape (boolean vs. table)
/// selects the variant.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum RolesFilter {
/// `true` includes every role, `false` includes none (see `should_include`).
Bool(bool),
/// Table form; every key is optional.
Config {
/// Provider name; roles that `is_provider_role` recognises for it are excluded.
#[serde(default)]
provider: Option<String>,
/// When present, only roles named here are included.
#[serde(default)]
include: Option<Vec<String>>,
/// Roles named here are always excluded.
#[serde(default)]
exclude: Option<Vec<String>>,
},
}
impl Default for RolesFilter {
fn default() -> Self {
Self::Bool(false)
}
}
impl RolesFilter {
pub fn is_enabled(&self) -> bool {
match self {
Self::Bool(b) => *b,
Self::Config { .. } => true,
}
}
pub fn should_include(&self, role_name: &str) -> bool {
match self {
Self::Bool(b) => *b,
Self::Config {
provider,
include,
exclude,
} => {
if let Some(p) = provider
&& is_provider_role(p, role_name)
{
return false;
}
if let Some(excl) = exclude
&& excl.iter().any(|e| e == role_name)
{
return false;
}
if let Some(incl) = include {
return incl.iter().any(|i| i == role_name);
}
true
}
}
}
}
/// Returns `true` when `role_name` is one of the roles listed for `provider`.
///
/// Only the providers `"supabase"` and `"neon"` have role lists; any other
/// provider name yields `false`. Matching is exact and case-sensitive.
fn is_provider_role(provider: &str, role_name: &str) -> bool {
    const SUPABASE_ROLES: &[&str] = &[
        "anon",
        "authenticated",
        "service_role",
        "supabase_admin",
        "supabase_auth_admin",
        "supabase_storage_admin",
        "dashboard_user",
        "supabase_replication_admin",
        "supabase_read_only_user",
        "supabase_realtime_admin",
        "supabase_functions_admin",
        "postgres",
        "pgbouncer",
        "pgsodium_keyholder",
        "pgsodium_keyiduser",
        "pgsodium_keymaker",
    ];
    const NEON_ROLES: &[&str] = &["neon_superuser", "cloud_admin", "authenticated", "anonymous"];

    let known_roles = match provider {
        "supabase" => SUPABASE_ROLES,
        "neon" => NEON_ROLES,
        _ => return false,
    };
    known_roles.iter().any(|known| *known == role_name)
}
/// Contents of the `entities` field of [`DatabaseConfig`].
#[derive(Debug, Clone, Default, Deserialize)]
pub struct EntitiesFilter {
/// Role filter; falls back to `RolesFilter::default()` when the key is absent.
#[serde(default)]
pub roles: RolesFilter,
}
/// Extension names accepted in the `extensions_filters` field of [`DatabaseConfig`].
///
/// Variants deserialize from their lowercase names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Extension {
/// Written as `"postgis"` in the config file.
Postgis,
}
impl Extension {
pub const fn as_str(self) -> &'static str {
match self {
Self::Postgis => "postgis",
}
}
}
impl std::fmt::Display for Extension {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}
/// A config value given either literally or as the name of an environment variable.
#[derive(Debug, Clone)]
pub enum EnvOr {
/// The literal value from the config file.
Value(String),
/// Name of the environment variable to read when the value is resolved.
Env(String),
}
impl EnvOr {
pub fn resolve(&self) -> Result<String, Error> {
match self {
Self::Value(v) => Ok(v.clone()),
Self::Env(var) => std::env::var(var).map_err(|_| Error::EnvNotFound(var.clone())),
}
}
pub fn resolve_optional(&self) -> Result<Option<String>, Error> {
match self {
Self::Value(v) => Ok(Some(v.clone())),
Self::Env(var) => match std::env::var(var) {
Ok(v) => Ok(Some(v)),
Err(std::env::VarError::NotPresent) => Ok(None),
Err(std::env::VarError::NotUnicode(_)) => Err(Error::EnvInvalid(
var.clone(),
"contains invalid unicode".into(),
)),
},
}
}
}
impl<'de> Deserialize<'de> for EnvOr {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct EnvOrVisitor;
impl<'de> Visitor<'de> for EnvOrVisitor {
type Value = EnvOr;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a string or { env = \"VAR_NAME\" }")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(EnvOr::Value(value.to_string()))
}
fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
where
M: MapAccess<'de>,
{
let mut env_var: Option<String> = None;
while let Some(key) = map.next_key::<String>()? {
if key == "env" {
env_var = Some(map.next_value()?);
} else {
return Err(de::Error::unknown_field(&key, &["env"]));
}
}
env_var
.map(EnvOr::Env)
.ok_or_else(|| de::Error::missing_field("env"))
}
}
deserializer.deserialize_any(EnvOrVisitor)
}
}
/// Database dialect named by the `dialect` config key.
///
/// Variants deserialize from their lowercase names; the default variant is
/// [`Dialect::Sqlite`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Dialect {
/// Written as `"sqlite"`.
#[default]
Sqlite,
/// Written as `"postgresql"`; `"postgres"` is accepted as an alias.
#[serde(alias = "postgres")]
Postgresql,
/// Written as `"turso"`.
Turso,
}
impl Dialect {
pub const ALL: &'static [&'static str] = &["sqlite", "postgresql", "turso"];
#[inline]
pub const fn as_str(self) -> &'static str {
match self {
Self::Sqlite => "sqlite",
Self::Postgresql => "postgresql",
Self::Turso => "turso",
}
}
#[inline]
pub const fn to_base(self) -> drizzle_types::Dialect {
match self {
Self::Sqlite | Self::Turso => drizzle_types::Dialect::SQLite,
Self::Postgresql => drizzle_types::Dialect::PostgreSQL,
}
}
}
impl std::fmt::Display for Dialect {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}
impl From<Dialect> for drizzle_types::Dialect {
#[inline]
fn from(d: Dialect) -> Self {
d.to_base()
}
}
/// Driver named by the optional `driver` config key.
///
/// Variants deserialize from their kebab-case names (e.g. `"tokio-postgres"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum Driver {
/// Written as `"rusqlite"`.
Rusqlite,
/// Written as `"libsql"`.
Libsql,
/// Written as `"turso"`.
Turso,
/// Written as `"postgres-sync"`.
PostgresSync,
/// Written as `"tokio-postgres"`.
TokioPostgres,
}
impl Driver {
pub const ALL: &'static [&'static str] = &[
"rusqlite",
"libsql",
"turso",
"postgres-sync",
"tokio-postgres",
];
#[inline]
pub const fn as_str(self) -> &'static str {
match self {
Self::Rusqlite => "rusqlite",
Self::Libsql => "libsql",
Self::Turso => "turso",
Self::PostgresSync => "postgres-sync",
Self::TokioPostgres => "tokio-postgres",
}
}
pub const fn valid_for(dialect: Dialect) -> &'static [Driver] {
match dialect {
Dialect::Sqlite => &[Self::Rusqlite],
Dialect::Turso => &[Self::Libsql, Self::Turso],
Dialect::Postgresql => &[Self::PostgresSync, Self::TokioPostgres],
}
}
#[inline]
pub const fn is_valid_for(self, dialect: Dialect) -> bool {
matches!(
(self, dialect),
(Self::Rusqlite, Dialect::Sqlite)
| (Self::Libsql | Self::Turso, Dialect::Turso)
| (
Self::PostgresSync | Self::TokioPostgres,
Dialect::Postgresql
)
)
}
}
impl std::fmt::Display for Driver {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}
/// Fully resolved connection credentials, as returned by `DatabaseConfig::credentials`.
#[derive(Debug, Clone)]
pub enum Credentials {
/// Produced for the SQLite dialect; `path` is the resolved `url` value.
Sqlite { path: Box<str> },
/// Produced for the Turso dialect.
Turso {
url: Box<str>,
auth_token: Option<Box<str>>,
},
/// Produced for the PostgreSQL dialect.
Postgres(PostgresCreds),
}
/// Resolved PostgreSQL connection settings.
#[derive(Debug, Clone)]
pub enum PostgresCreds {
    /// A complete connection URL, used verbatim.
    Url(Box<str>),
    /// Discrete connection parameters.
    Host {
        host: Box<str>,
        port: u16,
        user: Option<Box<str>>,
        password: Option<Box<str>>,
        database: Box<str>,
        ssl: bool,
    },
}

impl PostgresCreds {
    /// Builds a `postgres://` connection URL.
    ///
    /// The `Url` variant is returned unchanged. For the `Host` variant the user,
    /// password and database name are percent-encoded, so values containing URL
    /// delimiters such as `@`, `:` or `/` cannot corrupt the URL structure.
    ///
    /// A password without a user is omitted, and the `ssl` flag is not encoded
    /// into the URL.
    pub fn connection_url(&self) -> String {
        match self {
            Self::Url(url) => url.to_string(),
            Self::Host {
                host,
                port,
                user,
                password,
                database,
                ..
            } => {
                let auth = match (user, password) {
                    (Some(u), Some(p)) => format!(
                        "{}:{}@",
                        Self::encode_component(u),
                        Self::encode_component(p)
                    ),
                    (Some(u), None) => format!("{}@", Self::encode_component(u)),
                    _ => String::new(),
                };
                let database = Self::encode_component(database);
                format!("postgres://{auth}{host}:{port}/{database}")
            }
        }
    }

    /// Percent-encodes every byte of `raw` that is not an RFC 3986 unreserved
    /// character (`A-Z a-z 0-9 - . _ ~`).
    fn encode_component(raw: &str) -> String {
        const HEX: &[u8; 16] = b"0123456789ABCDEF";
        let mut out = String::with_capacity(raw.len());
        for byte in raw.bytes() {
            match byte {
                b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' => {
                    out.push(char::from(byte));
                }
                _ => {
                    out.push('%');
                    out.push(char::from(HEX[usize::from(byte >> 4)]));
                    out.push(char::from(HEX[usize::from(byte & 0x0F)]));
                }
            }
        }
        out
    }
}
/// Value of the `schema` config key: a single string or a list of strings.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum Schema {
/// A single entry.
One(String),
/// Several entries.
Many(Vec<String>),
}
impl Default for Schema {
fn default() -> Self {
Self::One("src/schema.rs".into())
}
}
impl Schema {
    /// Iterates over the entries as string slices, whichever form was used.
    pub fn iter(&self) -> impl Iterator<Item = &str> {
        let entries: &[String] = match self {
            Self::Many(list) => list.as_slice(),
            Self::One(single) => std::slice::from_ref(single),
        };
        entries.iter().map(String::as_str)
    }
}
/// A filter value given as a single string or a list of strings.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum Filter {
/// A single entry.
One(String),
/// Several entries.
Many(Vec<String>),
}
impl Filter {
    /// Iterates over the entries as string slices, whichever form was used.
    pub fn iter(&self) -> impl Iterator<Item = &str> {
        let entries: &[String] = match self {
            Self::Many(list) => list.as_slice(),
            Self::One(single) => std::slice::from_ref(single),
        };
        entries.iter().map(String::as_str)
    }
}
/// Contents of the `migrations` field of [`DatabaseConfig`]; every key is optional.
#[derive(Debug, Clone, Deserialize)]
pub struct MigrationsOpts {
/// Overrides the value returned by `DatabaseConfig::migrations_table`.
pub table: Option<String>,
/// Overrides the value returned by `DatabaseConfig::migrations_schema`.
pub schema: Option<String>,
/// Optional migration prefix setting.
pub prefix: Option<MigrationPrefix>,
}
/// Value of the `prefix` field of [`MigrationsOpts`].
///
/// Variants deserialize from their lowercase names (`"index"`, `"timestamp"`,
/// `"supabase"`, `"unix"`, `"none"`).
#[derive(Debug, Clone, Copy, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MigrationPrefix {
Index,
Timestamp,
Supabase,
Unix,
None,
}
/// Unresolved `dbCredentials` table exactly as written in the config file.
///
/// The enum is `#[serde(untagged)]`, so the keys present in the table select
/// the variant; `Url` is declared first.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
enum RawCreds {
/// URL form: a required `url` plus an optional `authToken`.
Url {
url: EnvOr,
#[serde(default, rename = "authToken")]
auth_token: Option<EnvOr>,
},
/// Host form: `host` and `database` are required, everything else is optional.
Host {
host: EnvOr,
#[serde(default)]
port: Option<u16>,
#[serde(default)]
user: Option<EnvOr>,
#[serde(default)]
password: Option<EnvOr>,
database: EnvOr,
#[serde(default)]
ssl: Option<SslVal>,
},
}
/// Value of the `ssl` credential key: either a boolean or a string.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
enum SslVal {
/// Boolean form.
Bool(bool),
/// String form; interpreted by `SslVal::enabled`.
Str(String),
}
impl SslVal {
    /// Reports whether SSL is requested.
    ///
    /// A boolean is used as-is. A string enables SSL unless it is exactly one of
    /// `"disable"`, `"false"`, `"no"` or `"off"`.
    fn enabled(&self) -> bool {
        const DISABLING: [&str; 4] = ["disable", "false", "no", "off"];
        match self {
            Self::Str(mode) => !DISABLING.iter().any(|off| *off == mode.as_str()),
            Self::Bool(flag) => *flag,
        }
    }
}
/// Configuration of a single database.
///
/// Keys are camelCase in the config file (e.g. `dbCredentials`, `tablesFilter`).
/// Only `dialect` is required; every other key has a default.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DatabaseConfig {
/// Database dialect (required).
pub dialect: Dialect,
/// Schema entries; defaults to `Schema::default()`.
#[serde(default)]
pub schema: Schema,
/// Output directory; defaults to the value of `default_out()`.
#[serde(default = "default_out")]
pub out: PathBuf,
/// Defaults to `true`.
#[serde(default = "yes")]
pub breakpoints: bool,
/// Optional driver; checked against `dialect` during validation.
#[serde(default)]
pub driver: Option<Driver>,
/// Raw credentials; resolved on demand by `credentials()`.
#[serde(default)]
db_credentials: Option<RawCreds>,
#[serde(default)]
pub tables_filter: Option<Filter>,
#[serde(default)]
pub schema_filter: Option<Filter>,
#[serde(default)]
pub extensions_filters: Option<Vec<Extension>>,
#[serde(default)]
pub entities: Option<EntitiesFilter>,
#[serde(default)]
pub casing: Option<Casing>,
#[serde(default)]
pub introspect: Option<IntrospectConfig>,
/// Defaults to `false`.
#[serde(default)]
pub verbose: bool,
/// Defaults to `false`.
#[serde(default)]
pub strict: bool,
#[serde(default)]
pub migrations: Option<MigrationsOpts>,
}
/// Serde default for `DatabaseConfig::out`: the relative directory `./drizzle`.
fn default_out() -> PathBuf {
    "./drizzle".into()
}
/// Serde default for boolean fields that are on unless disabled; always `true`.
fn yes() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
impl DatabaseConfig {
fn normalize_paths(&mut self, base_dir: &Path) {
if self.out.is_relative() {
self.out = base_dir.join(&self.out);
}
let base = base_dir.to_string_lossy().replace('\\', "/");
let base = base.trim_end_matches('/').to_string();
let normalize_one = |p: &str| -> String {
let p_trim = p.trim();
let is_abs = Path::new(p_trim).is_absolute() || p_trim.starts_with("\\\\");
let joined = if is_abs || base.is_empty() || base == "." {
p_trim.to_string()
} else {
format!("{base}/{p_trim}")
};
joined.replace('\\', "/")
};
match &mut self.schema {
Schema::One(p) => *p = normalize_one(p),
Schema::Many(v) => {
for p in v.iter_mut() {
*p = normalize_one(p);
}
}
}
}
fn validate(&self, name: &str) -> Result<(), Error> {
if let Some(d) = self.driver
&& !d.is_valid_for(self.dialect)
{
return Err(Error::InvalidDriver {
driver: d,
dialect: self.dialect,
});
}
if let Some(ref raw) = self.db_credentials {
self.validate_creds(raw, name)?;
}
Ok(())
}
fn validate_creds(&self, raw: &RawCreds, _name: &str) -> Result<(), Error> {
let err = |msg: &str| Error::InvalidCredentials(msg.into());
match (self.dialect, raw) {
(Dialect::Postgresql, RawCreds::Host { .. }) => {}
(Dialect::Postgresql, RawCreds::Url { .. }) => {}
(_, RawCreds::Host { .. }) => {
return Err(err(
"host-based dbCredentials are only supported for dialect = \"postgresql\"",
));
}
_ => {}
}
match (self.dialect, raw) {
(
Dialect::Sqlite,
RawCreds::Url {
auth_token: Some(_),
..
},
) => Err(err(
"SQLite doesn't support authToken (use dialect = \"turso\")",
)),
(
Dialect::Sqlite,
RawCreds::Url {
url: EnvOr::Value(url),
..
},
) if url.starts_with("libsql://") => Err(err(
"libsql:// URLs require dialect = \"turso\" (for local SQLite files, use ./path.db)",
)),
(
Dialect::Sqlite,
RawCreds::Url {
url: EnvOr::Value(url),
..
},
) if url.starts_with("http://")
|| url.starts_with("https://")
|| url.starts_with("postgres://")
|| url.starts_with("postgresql://") =>
{
Err(err(
"SQLite dbCredentials.url must be a local file path (not an http(s)/postgres URL)",
))
}
(
Dialect::Turso,
RawCreds::Url {
url: EnvOr::Value(url),
..
},
) if !url.starts_with("libsql://") && !url.starts_with("http") => {
Err(err("Turso URL must start with libsql:// or http(s)://"))
}
(
Dialect::Postgresql,
RawCreds::Url {
url: EnvOr::Value(url),
..
},
) if !url.starts_with("postgres") => {
Err(err("PostgreSQL URL must start with postgres://"))
}
_ => Ok(()),
}
}
pub fn credentials(&self) -> Result<Option<Credentials>, Error> {
let raw = match self.db_credentials.as_ref() {
Some(r) => r,
None => return Ok(None),
};
let resolve_opt = |opt: &Option<EnvOr>| -> Result<Option<Box<str>>, Error> {
match opt {
Some(e) => e.resolve().map(|s| Some(s.into_boxed_str())),
None => Ok(None),
}
};
let creds = match (self.dialect, raw) {
(Dialect::Sqlite, RawCreds::Url { url, .. }) => Credentials::Sqlite {
path: url.resolve()?.into_boxed_str(),
},
(Dialect::Turso, RawCreds::Url { url, auth_token }) => Credentials::Turso {
url: url.resolve()?.into_boxed_str(),
auth_token: resolve_opt(auth_token)?,
},
(Dialect::Postgresql, RawCreds::Url { url, .. }) => {
Credentials::Postgres(PostgresCreds::Url(url.resolve()?.into_boxed_str()))
}
(
Dialect::Postgresql,
RawCreds::Host {
host,
port,
user,
password,
database,
ssl,
},
) => Credentials::Postgres(PostgresCreds::Host {
host: host.resolve()?.into_boxed_str(),
port: port.unwrap_or(5432),
user: resolve_opt(user)?,
password: resolve_opt(password)?,
database: database.resolve()?.into_boxed_str(),
ssl: ssl.as_ref().map(|s| s.enabled()).unwrap_or(false),
}),
_ => return Ok(None),
};
Ok(Some(creds))
}
#[inline]
pub fn migrations_dir(&self) -> &Path {
&self.out
}
#[inline]
pub fn meta_dir(&self) -> PathBuf {
self.out.join("meta")
}
#[inline]
pub fn journal_path(&self) -> PathBuf {
self.meta_dir().join("_journal.json")
}
pub fn schema_display(&self) -> String {
match &self.schema {
Schema::One(s) => s.clone(),
Schema::Many(v) => v.join(", "),
}
}
pub fn schema_files(&self) -> Result<Vec<PathBuf>, Error> {
let mut files = Vec::new();
for pattern in self.schema.iter() {
let pat = pattern.trim();
let is_glob = pat.contains('*') || pat.contains('?') || pat.contains('[');
if !is_glob {
let p = PathBuf::from(pat);
if p.exists() {
files.push(p);
continue;
}
}
let pat_norm = pat.replace('\\', "/");
match glob::glob(&pat_norm) {
Ok(paths) => {
let matched: Vec<_> = paths.filter_map(Result::ok).collect();
if matched.is_empty() && !is_glob {
let p = PathBuf::from(&pat_norm);
if p.exists() {
files.push(p);
}
} else {
files.extend(matched);
}
}
Err(e) => return Err(Error::Glob(pat.into(), e)),
}
}
files.retain(|p| p.is_file());
files.sort();
files.dedup();
if files.is_empty() {
return Err(Error::NoSchemaFiles(self.schema_display()));
}
Ok(files)
}
#[inline]
pub fn effective_casing(&self) -> Casing {
self.casing.unwrap_or_default()
}
#[inline]
pub fn effective_introspect_casing(&self) -> IntrospectCasing {
self.introspect
.as_ref()
.map(|i| i.casing)
.unwrap_or_default()
}
#[inline]
pub fn effective_entities(&self) -> EntitiesFilter {
self.entities.clone().unwrap_or_default()
}
pub fn should_include_role(&self, role_name: &str) -> bool {
self.entities
.as_ref()
.map(|e| e.roles.should_include(role_name))
.unwrap_or(false)
}
pub fn roles_enabled(&self) -> bool {
self.entities
.as_ref()
.map(|e| e.roles.is_enabled())
.unwrap_or(false)
}
pub fn extensions(&self) -> &[Extension] {
self.extensions_filters.as_deref().unwrap_or(&[])
}
pub fn has_extension(&self, ext: Extension) -> bool {
self.extensions_filters
.as_ref()
.map(|v| v.contains(&ext))
.unwrap_or(false)
}
pub fn migrations_table(&self) -> &str {
self.migrations
.as_ref()
.and_then(|m| m.table.as_deref())
.unwrap_or("__drizzle_migrations")
}
pub fn migrations_schema(&self) -> &str {
self.migrations
.as_ref()
.and_then(|m| m.schema.as_deref())
.unwrap_or("drizzle")
}
}
/// File layout with several databases: a top-level `databases` table keyed by name.
#[derive(Debug, Clone, Deserialize)]
struct MultiDbConfig {
/// Database configs keyed by name.
databases: HashMap<String, DatabaseConfig>,
}
/// Loaded configuration holding one or more named databases.
#[derive(Debug, Clone)]
pub struct Config {
/// Database configs keyed by name.
databases: HashMap<String, DatabaseConfig>,
/// `true` when the file used the single-database layout.
is_single: bool,
}
/// Name under which a single-database config is stored in [`Config`].
pub const DEFAULT_DB: &str = "default";
impl Config {
pub fn load() -> Result<Self, Error> {
Self::load_from(Path::new(CONFIG_FILE))
}
pub fn load_from(path: &Path) -> Result<Self, Error> {
let content = std::fs::read_to_string(path).map_err(|e| {
if e.kind() == std::io::ErrorKind::NotFound {
Error::NotFound(path.into())
} else {
Error::Io(path.into(), e)
}
})?;
Self::load_from_str(&content, path)
}
fn load_from_str(content: &str, path: &Path) -> Result<Self, Error> {
let base_dir = path.parent().unwrap_or_else(|| Path::new("."));
if let Ok(multi) = toml::from_str::<MultiDbConfig>(content)
&& !multi.databases.is_empty()
{
let mut config = Self {
databases: multi.databases,
is_single: false,
};
for db in config.databases.values_mut() {
db.normalize_paths(base_dir);
}
config.validate()?;
return Ok(config);
}
let db_config: DatabaseConfig =
toml::from_str(content).map_err(|e| Error::Parse(path.into(), e))?;
let mut databases = HashMap::new();
databases.insert(DEFAULT_DB.to_string(), db_config);
let mut config = Self {
databases,
is_single: true,
};
for db in config.databases.values_mut() {
db.normalize_paths(base_dir);
}
config.validate()?;
Ok(config)
}
fn validate(&mut self) -> Result<(), Error> {
for (name, db) in &self.databases {
db.validate(name)?;
}
Ok(())
}
pub fn is_single_database(&self) -> bool {
self.is_single
}
pub fn database_names(&self) -> impl Iterator<Item = &str> {
self.databases.keys().map(String::as_str)
}
pub fn database(&self, name: Option<&str>) -> Result<&DatabaseConfig, Error> {
match name {
None => {
if self.is_single {
self.databases.get(DEFAULT_DB).ok_or(Error::NoDatabases)
} else if self.databases.len() == 1 {
self.databases.values().next().ok_or(Error::NoDatabases)
} else {
Err(Error::DatabaseRequired(
self.databases.keys().cloned().collect(),
))
}
}
Some(name) => {
if self.is_single {
self.databases.get(DEFAULT_DB).ok_or(Error::NoDatabases)
} else {
self.databases
.get(name)
.ok_or_else(|| Error::DatabaseNotFound(name.to_string()))
}
}
}
}
pub fn default_database(&self) -> Result<&DatabaseConfig, Error> {
self.database(None)
}
pub fn dialect(&self) -> Dialect {
self.default_database()
.map(|d| d.dialect)
.unwrap_or_default()
}
pub fn credentials(&self) -> Result<Option<Credentials>, Error> {
self.default_database()?.credentials()
}
pub fn migrations_dir(&self) -> &Path {
self.default_database()
.map(|d| d.migrations_dir())
.unwrap_or(Path::new("./drizzle"))
}
pub fn journal_path(&self) -> PathBuf {
self.default_database()
.map(|d| d.journal_path())
.unwrap_or_else(|_| PathBuf::from("./drizzle/meta/_journal.json"))
}
pub fn schema_display(&self) -> String {
self.default_database()
.map(|d| d.schema_display())
.unwrap_or_else(|_| "src/schema.rs".into())
}
pub fn schema_files(&self) -> Result<Vec<PathBuf>, Error> {
self.default_database()?.schema_files()
}
pub fn base_dialect(&self) -> drizzle_types::Dialect {
self.dialect().to_base()
}
}
/// Errors produced while loading, validating or querying the configuration.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// The config file does not exist.
#[error("config not found: {}", .0.display())]
NotFound(PathBuf),
/// The config file could not be read for a reason other than not existing.
#[error("failed to read {}: {}", .0.display(), .1)]
Io(PathBuf, #[source] std::io::Error),
/// The config file is not valid TOML or does not match the expected layout.
#[error("failed to parse {}: {}", .0.display(), .1)]
Parse(PathBuf, #[source] toml::de::Error),
/// The configured driver cannot be combined with the configured dialect.
#[error("driver '{driver}' invalid for {dialect} dialect")]
InvalidDriver { driver: Driver, dialect: Dialect },
/// The `dbCredentials` table does not fit the configured dialect.
#[error("invalid credentials: {0}")]
InvalidCredentials(String),
/// A schema entry is not a valid glob pattern.
#[error("invalid glob '{0}': {1}")]
Glob(String, #[source] glob::PatternError),
/// No schema entry matched an existing file.
#[error("no schema files found: {0}")]
NoSchemaFiles(String),
/// A referenced environment variable is not set.
#[error("environment variable '{0}' not found")]
EnvNotFound(String),
/// A referenced environment variable holds an unusable value.
#[error("environment variable '{0}' invalid: {1}")]
EnvInvalid(String, String),
/// The configuration contains no database.
#[error("no databases configured")]
NoDatabases,
/// No database with the requested name is configured.
#[error("database '{0}' not found")]
DatabaseNotFound(String),
/// Several databases are configured and none was selected.
#[error("multiple databases configured, use --db to specify: {}", .0.join(", "))]
DatabaseRequired(Vec<String>),
}
// Unit tests: each one parses an inline TOML document with `Config::load_from_str`.
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::TempDir;
// A single-database SQLite config resolves to `Credentials::Sqlite`.
#[test]
fn sqlite() {
let cfg = Config::load_from_str(
r#"
dialect = "sqlite"
[dbCredentials]
url = "./dev.db"
"#,
Path::new("test.toml"),
)
.unwrap();
assert!(cfg.is_single_database());
assert!(matches!(
cfg.credentials().unwrap(),
Some(Credentials::Sqlite { .. })
));
}
// A PostgreSQL URL resolves to `PostgresCreds::Url`.
#[test]
fn postgres_url() {
let cfg = Config::load_from_str(
r#"
dialect = "postgresql"
[dbCredentials]
url = "postgres://localhost/db"
"#,
Path::new("test.toml"),
)
.unwrap();
assert!(matches!(
cfg.credentials().unwrap(),
Some(Credentials::Postgres(PostgresCreds::Url(_)))
));
}
// The `[databases.<name>]` layout yields several named databases.
#[test]
fn multi_database() {
let cfg = Config::load_from_str(
r#"
[databases.dev]
dialect = "sqlite"
out = "./drizzle/sqlite"
[databases.dev.dbCredentials]
url = "./dev.db"
[databases.prod]
dialect = "postgresql"
out = "./drizzle/postgres"
[databases.prod.dbCredentials]
url = "postgres://localhost/db"
"#,
Path::new("test.toml"),
)
.unwrap();
assert!(!cfg.is_single_database());
let names: Vec<_> = cfg.database_names().collect();
assert!(names.contains(&"dev"));
assert!(names.contains(&"prod"));
let dev = cfg.database(Some("dev")).unwrap();
assert_eq!(dev.dialect, Dialect::Sqlite);
let prod = cfg.database(Some("prod")).unwrap();
assert_eq!(prod.dialect, Dialect::Postgresql);
}
// With two databases configured, `database(None)` is an error.
#[test]
fn multi_database_requires_selection() {
let cfg = Config::load_from_str(
r#"
[databases.a]
dialect = "sqlite"
[databases.b]
dialect = "postgresql"
"#,
Path::new("test.toml"),
)
.unwrap();
assert!(cfg.database(None).is_err());
}
// The `{ env = "..." }` form parses without the variable being read.
#[test]
fn env_var_syntax() {
let cfg = Config::load_from_str(
r#"
dialect = "postgresql"
[dbCredentials]
url = { env = "DATABASE_URL" }
"#,
Path::new("test.toml"),
)
.unwrap();
assert!(cfg.is_single_database());
}
// `casing` is honoured when set and defaults to camelCase otherwise.
#[test]
fn casing_options() {
let cfg = Config::load_from_str(
r#"
dialect = "postgresql"
casing = "snake_case"
[dbCredentials]
url = "postgres://localhost/db"
"#,
Path::new("test.toml"),
)
.unwrap();
let db = cfg.default_database().unwrap();
assert_eq!(db.effective_casing(), Casing::SnakeCase);
let cfg2 = Config::load_from_str(
r#"
dialect = "postgresql"
[dbCredentials]
url = "postgres://localhost/db"
"#,
Path::new("test.toml"),
)
.unwrap();
let db2 = cfg2.default_database().unwrap();
assert_eq!(db2.effective_casing(), Casing::CamelCase);
}
// `[introspect] casing` is read into `effective_introspect_casing`.
#[test]
fn introspect_casing() {
let cfg = Config::load_from_str(
r#"
dialect = "postgresql"
[introspect]
casing = "preserve"
[dbCredentials]
url = "postgres://localhost/db"
"#,
Path::new("test.toml"),
)
.unwrap();
let db = cfg.default_database().unwrap();
assert_eq!(db.effective_introspect_casing(), IntrospectCasing::Preserve);
}
// Covers both the boolean and the table form of `entities.roles`.
#[test]
fn entities_roles_filter() {
let cfg = Config::load_from_str(
r#"
dialect = "postgresql"
[entities]
roles = true
[dbCredentials]
url = "postgres://localhost/db"
"#,
Path::new("test.toml"),
)
.unwrap();
let db = cfg.default_database().unwrap();
assert!(db.roles_enabled());
assert!(db.should_include_role("my_role"));
let cfg2 = Config::load_from_str(
r#"
dialect = "postgresql"
[entities.roles]
provider = "supabase"
[dbCredentials]
url = "postgres://localhost/db"
"#,
Path::new("test.toml"),
)
.unwrap();
let db2 = cfg2.default_database().unwrap();
assert!(db2.roles_enabled());
// Provider roles are filtered out; custom roles pass.
assert!(!db2.should_include_role("anon")); assert!(db2.should_include_role("my_custom_role"));
}
// `extensionsFilters` entries are visible through `has_extension`.
#[test]
fn extensions_filter() {
let cfg = Config::load_from_str(
r#"
dialect = "postgresql"
extensionsFilters = ["postgis"]
[dbCredentials]
url = "postgres://localhost/db"
"#,
Path::new("test.toml"),
)
.unwrap();
let db = cfg.default_database().unwrap();
assert!(db.has_extension(Extension::Postgis));
}
// `[migrations]` overrides the table/schema names; defaults apply otherwise.
#[test]
fn migrations_config() {
let cfg = Config::load_from_str(
r#"
dialect = "postgresql"
[migrations]
table = "custom_migrations"
schema = "custom_schema"
[dbCredentials]
url = "postgres://localhost/db"
"#,
Path::new("test.toml"),
)
.unwrap();
let db = cfg.default_database().unwrap();
assert_eq!(db.migrations_table(), "custom_migrations");
assert_eq!(db.migrations_schema(), "custom_schema");
let cfg2 = Config::load_from_str(
r#"
dialect = "postgresql"
[dbCredentials]
url = "postgres://localhost/db"
"#,
Path::new("test.toml"),
)
.unwrap();
let db2 = cfg2.default_database().unwrap();
assert_eq!(db2.migrations_table(), "__drizzle_migrations");
assert_eq!(db2.migrations_schema(), "drizzle");
}
// Relative `schema` and `out` values are rebased onto the config file's directory.
#[test]
fn resolves_paths_relative_to_config_dir() {
let tmp = TempDir::new().unwrap();
let cfg_dir = tmp.path().join("cfg");
fs::create_dir_all(&cfg_dir).unwrap();
let schema_path = cfg_dir.join("schema.rs");
fs::write(&schema_path, "#[allow(dead_code)]\npub struct X;").unwrap();
let cfg_path = cfg_dir.join("drizzle.config.toml");
let cfg = Config::load_from_str(
r#"
dialect = "sqlite"
schema = "schema.rs"
out = "./drizzle"
[dbCredentials]
url = "./dev.db"
"#,
&cfg_path,
)
.unwrap();
let db = cfg.default_database().unwrap();
assert_eq!(db.migrations_dir(), cfg_dir.join("./drizzle").as_path());
let files = db.schema_files().unwrap();
assert_eq!(files.len(), 1);
assert_eq!(files[0], schema_path);
}
// Host-based credentials combined with the SQLite dialect fail validation.
#[test]
fn rejects_host_credentials_for_sqlite() {
let err = Config::load_from_str(
r#"
dialect = "sqlite"
[dbCredentials]
host = "localhost"
database = "db"
"#,
Path::new("test.toml"),
)
.unwrap_err();
let msg = err.to_string();
assert!(
msg.contains("host-based dbCredentials are only supported"),
"unexpected error: {msg}"
);
}
// Windows only: an absolute schema path written with backslashes is found.
#[cfg(windows)]
#[test]
fn schema_files_accept_backslash_paths() {
let tmp = TempDir::new().unwrap();
let cfg_dir = tmp.path().join("cfg");
fs::create_dir_all(&cfg_dir).unwrap();
let schema_path = cfg_dir.join("src").join("schema.rs");
fs::create_dir_all(schema_path.parent().unwrap()).unwrap();
fs::write(&schema_path, "#[allow(dead_code)]\npub struct X;").unwrap();
let schema_str = schema_path.to_string_lossy().replace('/', "\\");
// Backslashes are doubled so they survive TOML basic-string escaping.
let schema_toml = schema_str.replace('\\', "\\\\");
let cfg_path = cfg_dir.join("drizzle.config.toml");
let cfg = Config::load_from_str(
&format!(
r#"
dialect = "sqlite"
schema = "{}"
"#,
schema_toml
),
&cfg_path,
)
.unwrap();
let db = cfg.default_database().unwrap();
let files = db.schema_files().unwrap();
assert_eq!(files, vec![schema_path]);
}
}