use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use crate::error::{SchemaError, SchemaResult};
/// Root of the Prax configuration file, deserialized from TOML by [`PraxConfig::from_str`]
/// and [`PraxConfig::from_file`].
///
/// Every section is optional and falls back to its `Default` when absent. Unknown top-level
/// keys are rejected (`deny_unknown_fields`), so a misspelled section name is a parse error
/// instead of being silently ignored.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct PraxConfig {
    /// `[database]`: provider, connection URL and pool settings.
    #[serde(default)]
    pub database: DatabaseConfig,
    /// `[schema]`: location of the schema file.
    #[serde(default)]
    pub schema: SchemaConfig,
    /// `[generator]`: code-generation settings.
    #[serde(default)]
    pub generator: GeneratorConfig,
    /// `[migrations]`: migration directory and bookkeeping table.
    #[serde(default)]
    pub migrations: MigrationConfig,
    /// `[seed]`: seed script and auto-seed settings.
    #[serde(default)]
    pub seed: SeedConfig,
    /// `[debug]`: query logging options.
    #[serde(default)]
    pub debug: DebugConfig,
    /// `[environments.<name>]`: named overrides, applied by
    /// [`PraxConfig::with_environment`].
    #[serde(default)]
    pub environments: HashMap<String, EnvironmentOverride>,
}
impl PraxConfig {
pub fn from_file(path: impl AsRef<Path>) -> SchemaResult<Self> {
let path = path.as_ref();
let content = std::fs::read_to_string(path).map_err(|e| SchemaError::IoError {
path: path.display().to_string(),
source: e,
})?;
Self::from_str(&content)
}
#[allow(clippy::should_implement_trait)]
pub fn from_str(content: &str) -> SchemaResult<Self> {
let expanded = expand_env_vars(content);
toml::from_str(&expanded).map_err(|e| SchemaError::TomlError { source: e })
}
pub fn database_url(&self) -> Option<&str> {
self.database.url.as_deref()
}
pub fn with_environment(mut self, env: &str) -> Self {
if let Some(overrides) = self.environments.remove(env) {
if let Some(db) = overrides.database {
if let Some(url) = db.url {
self.database.url = Some(url);
}
if let Some(pool) = db.pool {
self.database.pool = pool;
}
}
if let Some(debug) = overrides.debug {
if let Some(log_queries) = debug.log_queries {
self.debug.log_queries = log_queries;
}
if let Some(pretty_sql) = debug.pretty_sql {
self.debug.pretty_sql = pretty_sql;
}
if let Some(threshold) = debug.slow_query_threshold {
self.debug.slow_query_threshold = threshold;
}
}
}
self
}
}
/// `[database]` section of the configuration file.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct DatabaseConfig {
    /// Database backend. Defaults to PostgreSQL when the key is omitted.
    #[serde(default = "default_provider")]
    pub provider: DatabaseProvider,
    /// Connection URL, or `None` when the key is absent. Like every value in the file, it may
    /// use `${VAR}` placeholders, which `PraxConfig::from_str` substitutes before parsing.
    pub url: Option<String>,
    /// Connection-pool settings. Each omitted key takes its own default.
    #[serde(default)]
    pub pool: PoolConfig,
}
impl Default for DatabaseConfig {
fn default() -> Self {
Self {
provider: DatabaseProvider::PostgreSql,
url: None,
pool: PoolConfig::default(),
}
}
}
fn default_provider() -> DatabaseProvider {
DatabaseProvider::PostgreSql
}
/// Supported database backends.
///
/// Serialized as lowercase names (`"postgresql"`, `"mysql"`, `"sqlite"`, `"mongodb"`).
/// Deserialization also accepts the aliases `"postgres"`, `"sqlite3"` and `"mongo"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum DatabaseProvider {
    /// `"postgresql"` (alias `"postgres"`).
    #[serde(alias = "postgres")]
    PostgreSql,
    /// `"mysql"`.
    MySql,
    /// `"sqlite"` (alias `"sqlite3"`).
    #[serde(alias = "sqlite3")]
    Sqlite,
    /// `"mongodb"` (alias `"mongo"`).
    #[serde(alias = "mongo")]
    MongoDb,
}
impl DatabaseProvider {
pub fn as_str(&self) -> &'static str {
match self {
Self::PostgreSql => "postgresql",
Self::MySql => "mysql",
Self::Sqlite => "sqlite",
Self::MongoDb => "mongodb",
}
}
}
/// `[database.pool]` connection-pool settings.
///
/// The timeout values are kept as the raw strings from the file (e.g. `"30s"`, `"10m"`). They
/// are not parsed or validated here, and neither is `min_connections <= max_connections`.
/// NOTE(review): presumably the pool consumer parses these into durations; confirm.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct PoolConfig {
    /// Minimum number of connections (default `2`).
    #[serde(default = "default_min_connections")]
    pub min_connections: u32,
    /// Maximum number of connections (default `10`).
    #[serde(default = "default_max_connections")]
    pub max_connections: u32,
    /// Connect timeout as a raw string (default `"30s"`).
    #[serde(default = "default_connect_timeout")]
    pub connect_timeout: String,
    /// Idle timeout as a raw string (default `"10m"`).
    #[serde(default = "default_idle_timeout")]
    pub idle_timeout: String,
    /// Maximum connection lifetime as a raw string (default `"30m"`).
    #[serde(default = "default_max_lifetime")]
    pub max_lifetime: String,
}
impl Default for PoolConfig {
fn default() -> Self {
Self {
min_connections: default_min_connections(),
max_connections: default_max_connections(),
connect_timeout: default_connect_timeout(),
idle_timeout: default_idle_timeout(),
max_lifetime: default_max_lifetime(),
}
}
}
/// Serde default for `PoolConfig::min_connections`.
fn default_min_connections() -> u32 {
    2
}

/// Serde default for `PoolConfig::max_connections`.
fn default_max_connections() -> u32 {
    10
}

/// Serde default for `PoolConfig::connect_timeout`.
fn default_connect_timeout() -> String {
    String::from("30s")
}

/// Serde default for `PoolConfig::idle_timeout`.
fn default_idle_timeout() -> String {
    String::from("10m")
}

/// Serde default for `PoolConfig::max_lifetime`.
fn default_max_lifetime() -> String {
    String::from("30m")
}
/// `[schema]` section: where the schema file lives.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct SchemaConfig {
    /// Path to the schema file (default `"schema.prax"`).
    /// NOTE(review): not resolved here; presumably relative to the working directory or the
    /// config file, so confirm with the caller.
    #[serde(default = "default_schema_path")]
    pub path: String,
}
impl Default for SchemaConfig {
fn default() -> Self {
Self {
path: default_schema_path(),
}
}
}
/// Serde default for `SchemaConfig::path`.
fn default_schema_path() -> String {
    String::from("schema.prax")
}
/// `[generator]` section. Currently it holds only the `[generator.client]` table.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct GeneratorConfig {
    /// `[generator.client]`: client code-generation options.
    #[serde(default)]
    pub client: ClientGeneratorConfig,
}
/// Style of the generated model types.
///
/// Serialized as `"standard"` or `"graphql"`. `"async-graphql"` is also accepted for
/// `GraphQL` when deserializing. Defaults to `Standard`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ModelStyle {
    /// `"standard"`: the default style.
    #[default]
    Standard,
    /// `"graphql"` (alias `"async-graphql"`).
    #[serde(alias = "async-graphql")]
    GraphQL,
}
impl ModelStyle {
    /// Returns `true` only for [`ModelStyle::GraphQL`].
    pub fn is_graphql(&self) -> bool {
        // Exhaustive so that adding a variant forces a decision here.
        match self {
            Self::GraphQL => true,
            Self::Standard => false,
        }
    }
}
/// `[generator.client]` options for client code generation.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct ClientGeneratorConfig {
    /// Output directory for generated code (default `"./src/generated"`).
    #[serde(default = "default_output")]
    pub output: String,
    /// Whether to generate an async client (default `true`).
    #[serde(default = "default_true")]
    pub async_client: bool,
    /// Whether to enable tracing in generated code (default `false`).
    #[serde(default)]
    pub tracing: bool,
    /// Opt-in preview feature names (default empty). Names are not validated here.
    #[serde(default)]
    pub preview_features: Vec<String>,
    /// Style of generated model types (default [`ModelStyle::Standard`]).
    #[serde(default)]
    pub model_style: ModelStyle,
}
impl Default for ClientGeneratorConfig {
fn default() -> Self {
Self {
output: default_output(),
async_client: true,
tracing: false,
preview_features: vec![],
model_style: ModelStyle::default(),
}
}
}
/// Serde default for `ClientGeneratorConfig::output`.
fn default_output() -> String {
    String::from("./src/generated")
}

/// Serde default for boolean settings that are on unless explicitly disabled
/// (`ClientGeneratorConfig::async_client`, `DebugConfig::pretty_sql`).
fn default_true() -> bool {
    true
}
/// `[migrations]` section.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct MigrationConfig {
    /// Directory holding migration files (default `"./migrations"`).
    #[serde(default = "default_migrations_dir")]
    pub directory: String,
    /// Whether migrations run automatically (default `false`).
    #[serde(default)]
    pub auto_migrate: bool,
    /// Name of the table that records applied migrations (default `"_prax_migrations"`).
    #[serde(default = "default_migrations_table")]
    pub table_name: String,
}
impl Default for MigrationConfig {
fn default() -> Self {
Self {
directory: default_migrations_dir(),
auto_migrate: false,
table_name: default_migrations_table(),
}
}
}
/// Serde default for `MigrationConfig::directory`.
fn default_migrations_dir() -> String {
    String::from("./migrations")
}

/// Serde default for `MigrationConfig::table_name`.
fn default_migrations_table() -> String {
    String::from("_prax_migrations")
}
/// `[seed]` section.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct SeedConfig {
    /// Path of the seed script, or `None` when not configured.
    pub script: Option<String>,
    /// Whether seeding runs automatically (default `false`).
    #[serde(default)]
    pub auto_seed: bool,
    /// `[seed.environments]`: environment name -> flag (default empty).
    /// NOTE(review): presumably "seed in this environment?"; confirm against the consumer.
    #[serde(default)]
    pub environments: HashMap<String, bool>,
}
/// `[debug]` section: query logging options.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct DebugConfig {
    /// Whether queries are logged (default `false`).
    #[serde(default)]
    pub log_queries: bool,
    /// Whether logged SQL is pretty-printed (default `true`).
    #[serde(default = "default_true")]
    pub pretty_sql: bool,
    /// Threshold above which a query counts as slow (default `1000`).
    /// NOTE(review): the unit is not stated here; presumably milliseconds, so confirm.
    #[serde(default = "default_slow_query_threshold")]
    pub slow_query_threshold: u64,
}
impl Default for DebugConfig {
fn default() -> Self {
Self {
log_queries: false,
pretty_sql: true,
slow_query_threshold: default_slow_query_threshold(),
}
}
}
/// Serde default for `DebugConfig::slow_query_threshold`.
fn default_slow_query_threshold() -> u64 {
    1_000
}
/// One `[environments.<name>]` table. `PraxConfig::with_environment` merges it over the
/// base configuration. Only the database and debug sections can be overridden.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct EnvironmentOverride {
    /// `[environments.<name>.database]` overrides; `None` leaves `[database]` untouched.
    pub database: Option<DatabaseOverride>,
    /// `[environments.<name>.debug]` overrides; `None` leaves `[debug]` untouched.
    pub debug: Option<DebugOverride>,
}
/// Database settings that an environment may override.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct DatabaseOverride {
    /// Replacement connection URL; `None` keeps the base URL.
    pub url: Option<String>,
    /// Replacement pool config. When present it replaces the base pool wholesale: pool keys
    /// omitted from the override take `PoolConfig`'s serde defaults, not the base values.
    pub pool: Option<PoolConfig>,
}
/// Debug settings that an environment may override. Each field is merged individually;
/// `None` keeps the base value.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct DebugOverride {
    /// Replacement for `DebugConfig::log_queries`.
    pub log_queries: Option<bool>,
    /// Replacement for `DebugConfig::pretty_sql`.
    pub pretty_sql: Option<bool>,
    /// Replacement for `DebugConfig::slow_query_threshold`.
    pub slow_query_threshold: Option<u64>,
}
/// Expands `${NAME}` placeholders in `content` with the value of environment variable `NAME`.
///
/// A placeholder is `${`, one or more characters other than `}`, then `}`; `${}` is left
/// alone. Placeholders whose variable is unset (or not valid Unicode) are kept verbatim.
///
/// Expansion is a single left-to-right pass: substituted values are copied as-is and are
/// never scanned for further placeholders. Values are inserted without TOML escaping, so a
/// value containing `"` or a newline can change how the surrounding TOML parses.
fn expand_env_vars(content: &str) -> String {
    let mut expanded = String::with_capacity(content.len());
    // Unscanned remainder of `content`; everything before it is already in `expanded`.
    let mut rest = content;

    while let Some(open) = rest.find("${") {
        let name_start = open + 2;
        // Without a `}` after this `${`, no later `${` can be closed either.
        let Some(name_len) = rest[name_start..].find('}') else {
            break;
        };
        if name_len == 0 {
            // `${}` is not a placeholder: keep the `$` and rescan from the `{`.
            expanded.push_str(&rest[..=open]);
            rest = &rest[open + 1..];
            continue;
        }
        let name_end = name_start + name_len;
        let placeholder_end = name_end + 1; // include the closing `}`
        expanded.push_str(&rest[..open]);
        match std::env::var(&rest[name_start..name_end]) {
            Ok(value) => expanded.push_str(&value),
            // Unset or non-Unicode variables keep their placeholder verbatim.
            Err(_) => expanded.push_str(&rest[open..placeholder_end]),
        }
        rest = &rest[placeholder_end..];
    }

    expanded.push_str(rest);
    expanded
}
// Unit tests for configuration parsing, defaults, environment overrides and `${VAR}`
// expansion.
//
// NOTE(review): several tests mutate the process environment with `std::env::set_var`.
// The test harness runs tests on multiple threads by default, and other tests read the
// environment concurrently (every `PraxConfig::from_str` call goes through `expand_env_vars`).
// The unique variable names keep tests from observing each other's values, but concurrent
// `set_var` alongside environment reads is only sound on some platforms (the reason
// `set_var` is `unsafe`). Consider serializing these tests or running them single-threaded.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = PraxConfig::default();
        assert_eq!(config.database.provider, DatabaseProvider::PostgreSql);
        assert_eq!(config.schema.path, "schema.prax");
        assert!(config.database.url.is_none());
        assert!(config.environments.is_empty());
    }

    #[test]
    fn test_parse_minimal_config() {
        let toml = r#"
[database]
provider = "postgresql"
url = "postgres://localhost/test"
"#;
        let config = PraxConfig::from_str(toml).unwrap();
        assert_eq!(
            config.database.url,
            Some("postgres://localhost/test".to_string())
        );
    }

    // Exercises every section at once.
    #[test]
    fn test_parse_full_config() {
        let toml = r#"
[database]
provider = "postgresql"
url = "postgres://user:pass@localhost:5432/db"
[database.pool]
min_connections = 5
max_connections = 20
connect_timeout = "60s"
idle_timeout = "5m"
max_lifetime = "1h"
[schema]
path = "prisma/schema.prax"
[generator.client]
output = "./src/db"
async_client = true
tracing = true
preview_features = ["json", "fulltext"]
[migrations]
directory = "./db/migrations"
auto_migrate = true
table_name = "_migrations"
[seed]
script = "./scripts/seed.sh"
auto_seed = true
[seed.environments]
development = true
test = true
production = false
[debug]
log_queries = true
pretty_sql = false
slow_query_threshold = 500
"#;
        let config = PraxConfig::from_str(toml).unwrap();
        assert_eq!(config.database.provider, DatabaseProvider::PostgreSql);
        assert!(config.database.url.is_some());
        assert_eq!(config.database.pool.min_connections, 5);
        assert_eq!(config.database.pool.max_connections, 20);
        assert_eq!(config.schema.path, "prisma/schema.prax");
        assert_eq!(config.generator.client.output, "./src/db");
        assert!(config.generator.client.async_client);
        assert!(config.generator.client.tracing);
        assert_eq!(config.generator.client.preview_features.len(), 2);
        assert_eq!(config.migrations.directory, "./db/migrations");
        assert!(config.migrations.auto_migrate);
        assert_eq!(config.migrations.table_name, "_migrations");
        assert_eq!(config.seed.script, Some("./scripts/seed.sh".to_string()));
        assert!(config.seed.auto_seed);
        assert!(
            config
                .seed
                .environments
                .get("development")
                .copied()
                .unwrap_or(false)
        );
        assert!(config.debug.log_queries);
        assert!(!config.debug.pretty_sql);
        assert_eq!(config.debug.slow_query_threshold, 500);
    }

    #[test]
    fn test_database_url_method() {
        let config = PraxConfig {
            database: DatabaseConfig {
                url: Some("postgres://localhost/test".to_string()),
                ..Default::default()
            },
            ..Default::default()
        };
        assert_eq!(config.database_url(), Some("postgres://localhost/test"));
    }

    #[test]
    fn test_database_url_method_none() {
        let config = PraxConfig::default();
        assert!(config.database_url().is_none());
    }

    #[test]
    fn test_with_environment_overrides() {
        let toml = r#"
[database]
url = "postgres://localhost/dev"
[debug]
log_queries = false
[environments.production]
[environments.production.database]
url = "postgres://prod.server/db"
[environments.production.debug]
log_queries = true
slow_query_threshold = 100
"#;
        let config = PraxConfig::from_str(toml)
            .unwrap()
            .with_environment("production");
        assert_eq!(
            config.database.url,
            Some("postgres://prod.server/db".to_string())
        );
        assert!(config.debug.log_queries);
        assert_eq!(config.debug.slow_query_threshold, 100);
    }

    #[test]
    fn test_with_environment_nonexistent() {
        let config = PraxConfig::default().with_environment("nonexistent");
        assert_eq!(config.database.provider, DatabaseProvider::PostgreSql);
    }

    #[test]
    fn test_parse_invalid_toml() {
        let toml = "this is not valid [[ toml";
        let result = PraxConfig::from_str(toml);
        assert!(result.is_err());
    }

    // Provider names and their accepted aliases.
    #[test]
    fn test_database_provider_postgresql() {
        let toml = r#"
[database]
provider = "postgresql"
"#;
        let config = PraxConfig::from_str(toml).unwrap();
        assert_eq!(config.database.provider, DatabaseProvider::PostgreSql);
        assert_eq!(config.database.provider.as_str(), "postgresql");
    }

    #[test]
    fn test_database_provider_postgres_alias() {
        let toml = r#"
[database]
provider = "postgres"
"#;
        let config = PraxConfig::from_str(toml).unwrap();
        assert_eq!(config.database.provider, DatabaseProvider::PostgreSql);
    }

    #[test]
    fn test_database_provider_mysql() {
        let toml = r#"
[database]
provider = "mysql"
"#;
        let config = PraxConfig::from_str(toml).unwrap();
        assert_eq!(config.database.provider, DatabaseProvider::MySql);
        assert_eq!(config.database.provider.as_str(), "mysql");
    }

    #[test]
    fn test_database_provider_sqlite() {
        let toml = r#"
[database]
provider = "sqlite"
"#;
        let config = PraxConfig::from_str(toml).unwrap();
        assert_eq!(config.database.provider, DatabaseProvider::Sqlite);
        assert_eq!(config.database.provider.as_str(), "sqlite");
    }

    #[test]
    fn test_database_provider_sqlite3_alias() {
        let toml = r#"
[database]
provider = "sqlite3"
"#;
        let config = PraxConfig::from_str(toml).unwrap();
        assert_eq!(config.database.provider, DatabaseProvider::Sqlite);
    }

    #[test]
    fn test_database_provider_mongodb() {
        let toml = r#"
[database]
provider = "mongodb"
"#;
        let config = PraxConfig::from_str(toml).unwrap();
        assert_eq!(config.database.provider, DatabaseProvider::MongoDb);
        assert_eq!(config.database.provider.as_str(), "mongodb");
    }

    #[test]
    fn test_database_provider_mongo_alias() {
        let toml = r#"
[database]
provider = "mongo"
"#;
        let config = PraxConfig::from_str(toml).unwrap();
        assert_eq!(config.database.provider, DatabaseProvider::MongoDb);
    }

    #[test]
    fn test_pool_config_defaults() {
        let config = PoolConfig::default();
        assert_eq!(config.min_connections, 2);
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.connect_timeout, "30s");
        assert_eq!(config.idle_timeout, "10m");
        assert_eq!(config.max_lifetime, "30m");
    }

    #[test]
    fn test_pool_config_custom() {
        let toml = r#"
[database]
provider = "postgresql"
[database.pool]
min_connections = 1
max_connections = 50
connect_timeout = "10s"
idle_timeout = "30m"
max_lifetime = "2h"
"#;
        let config = PraxConfig::from_str(toml).unwrap();
        assert_eq!(config.database.pool.min_connections, 1);
        assert_eq!(config.database.pool.max_connections, 50);
        assert_eq!(config.database.pool.connect_timeout, "10s");
    }

    #[test]
    fn test_schema_config_default() {
        let config = SchemaConfig::default();
        assert_eq!(config.path, "schema.prax");
    }

    #[test]
    fn test_schema_config_custom() {
        let toml = r#"
[schema]
path = "db/schema.prax"
"#;
        let config = PraxConfig::from_str(toml).unwrap();
        assert_eq!(config.schema.path, "db/schema.prax");
    }

    #[test]
    fn test_generator_config_default() {
        let config = GeneratorConfig::default();
        assert_eq!(config.client.output, "./src/generated");
        assert!(config.client.async_client);
        assert!(!config.client.tracing);
        assert!(config.client.preview_features.is_empty());
        assert_eq!(config.client.model_style, ModelStyle::Standard);
    }

    #[test]
    fn test_generator_config_custom() {
        let toml = r#"
[generator.client]
output = "./generated"
async_client = false
tracing = true
preview_features = ["feature1", "feature2"]
"#;
        let config = PraxConfig::from_str(toml).unwrap();
        assert_eq!(config.generator.client.output, "./generated");
        assert!(!config.generator.client.async_client);
        assert!(config.generator.client.tracing);
        assert_eq!(config.generator.client.preview_features.len(), 2);
    }

    #[test]
    fn test_generator_config_graphql_model_style() {
        let toml = r#"
[generator.client]
model_style = "graphql"
"#;
        let config = PraxConfig::from_str(toml).unwrap();
        assert_eq!(config.generator.client.model_style, ModelStyle::GraphQL);
        assert!(config.generator.client.model_style.is_graphql());
    }

    #[test]
    fn test_generator_config_graphql_model_style_alias() {
        let toml = r#"
[generator.client]
model_style = "async-graphql"
"#;
        let config = PraxConfig::from_str(toml).unwrap();
        assert_eq!(config.generator.client.model_style, ModelStyle::GraphQL);
    }

    #[test]
    fn test_model_style_standard_is_not_graphql() {
        assert!(!ModelStyle::Standard.is_graphql());
        assert!(ModelStyle::GraphQL.is_graphql());
    }

    #[test]
    fn test_migration_config_default() {
        let config = MigrationConfig::default();
        assert_eq!(config.directory, "./migrations");
        assert!(!config.auto_migrate);
        assert_eq!(config.table_name, "_prax_migrations");
    }

    #[test]
    fn test_migration_config_custom() {
        let toml = r#"
[migrations]
directory = "./db/migrate"
auto_migrate = true
table_name = "schema_migrations"
"#;
        let config = PraxConfig::from_str(toml).unwrap();
        assert_eq!(config.migrations.directory, "./db/migrate");
        assert!(config.migrations.auto_migrate);
        assert_eq!(config.migrations.table_name, "schema_migrations");
    }

    #[test]
    fn test_seed_config_default() {
        let config = SeedConfig::default();
        assert!(config.script.is_none());
        assert!(!config.auto_seed);
        assert!(config.environments.is_empty());
    }

    #[test]
    fn test_seed_config_custom() {
        let toml = r#"
[seed]
script = "seed.rs"
auto_seed = true
[seed.environments]
dev = true
prod = false
"#;
        let config = PraxConfig::from_str(toml).unwrap();
        assert_eq!(config.seed.script, Some("seed.rs".to_string()));
        assert!(config.seed.auto_seed);
        assert_eq!(config.seed.environments.get("dev"), Some(&true));
        assert_eq!(config.seed.environments.get("prod"), Some(&false));
    }

    #[test]
    fn test_debug_config_default() {
        let config = DebugConfig::default();
        assert!(!config.log_queries);
        assert!(config.pretty_sql);
        assert_eq!(config.slow_query_threshold, 1000);
    }

    #[test]
    fn test_debug_config_custom() {
        let toml = r#"
[debug]
log_queries = true
pretty_sql = false
slow_query_threshold = 200
"#;
        let config = PraxConfig::from_str(toml).unwrap();
        assert!(config.debug.log_queries);
        assert!(!config.debug.pretty_sql);
        assert_eq!(config.debug.slow_query_threshold, 200);
    }

    // `${VAR}` expansion. Each test uses variable names no other test touches; see the
    // module-level note about mutating the environment.
    #[test]
    fn test_env_var_expansion() {
        unsafe {
            std::env::set_var("TEST_DB_URL", "postgres://test");
        }
        let expanded = expand_env_vars("url = \"${TEST_DB_URL}\"");
        assert_eq!(expanded, "url = \"postgres://test\"");
        unsafe {
            std::env::remove_var("TEST_DB_URL");
        }
    }

    #[test]
    fn test_env_var_expansion_multiple() {
        unsafe {
            std::env::set_var("TEST_HOST", "localhost");
            std::env::set_var("TEST_PORT", "5432");
        }
        let content = "host = \"${TEST_HOST}\"\nport = \"${TEST_PORT}\"";
        let expanded = expand_env_vars(content);
        assert!(expanded.contains("localhost"));
        assert!(expanded.contains("5432"));
        unsafe {
            std::env::remove_var("TEST_HOST");
            std::env::remove_var("TEST_PORT");
        }
    }

    // Unset variables leave the placeholder untouched.
    #[test]
    fn test_env_var_expansion_missing_var() {
        let content = "url = \"${DEFINITELY_NOT_SET_VAR_12345}\"";
        let expanded = expand_env_vars(content);
        assert_eq!(expanded, content);
    }

    #[test]
    fn test_env_var_expansion_in_config() {
        unsafe {
            std::env::set_var("TEST_DATABASE_URL_2", "postgres://user:pass@localhost/db");
        }
        let toml = r#"
[database]
url = "${TEST_DATABASE_URL_2}"
"#;
        let config = PraxConfig::from_str(toml).unwrap();
        assert_eq!(
            config.database.url,
            Some("postgres://user:pass@localhost/db".to_string())
        );
        unsafe {
            std::env::remove_var("TEST_DATABASE_URL_2");
        }
    }

    // Environment overrides applied via `with_environment`.
    #[test]
    fn test_environment_override_database_url() {
        let toml = r#"
[database]
url = "postgres://localhost/dev"
[environments.test]
[environments.test.database]
url = "postgres://localhost/test_db"
"#;
        let config = PraxConfig::from_str(toml).unwrap().with_environment("test");
        assert_eq!(
            config.database.url,
            Some("postgres://localhost/test_db".to_string())
        );
    }

    #[test]
    fn test_environment_override_pool() {
        let toml = r#"
[database.pool]
max_connections = 10
[environments.production]
[environments.production.database.pool]
max_connections = 100
min_connections = 10
"#;
        let config = PraxConfig::from_str(toml)
            .unwrap()
            .with_environment("production");
        assert_eq!(config.database.pool.max_connections, 100);
        assert_eq!(config.database.pool.min_connections, 10);
    }

    #[test]
    fn test_environment_override_debug() {
        let toml = r#"
[debug]
log_queries = false
pretty_sql = true
[environments.development]
[environments.development.debug]
log_queries = true
pretty_sql = false
slow_query_threshold = 50
"#;
        let config = PraxConfig::from_str(toml)
            .unwrap()
            .with_environment("development");
        assert!(config.debug.log_queries);
        assert!(!config.debug.pretty_sql);
        assert_eq!(config.debug.slow_query_threshold, 50);
    }

    // Serialization and trait smoke tests.
    #[test]
    fn test_config_serialization() {
        let config = PraxConfig::default();
        let toml_str = toml::to_string(&config).unwrap();
        assert!(toml_str.contains("[database]"));
    }

    #[test]
    fn test_config_roundtrip() {
        let original = PraxConfig {
            database: DatabaseConfig {
                provider: DatabaseProvider::MySql,
                url: Some("mysql://localhost/test".to_string()),
                pool: PoolConfig::default(),
            },
            ..Default::default()
        };
        let toml_str = toml::to_string(&original).unwrap();
        let parsed: PraxConfig = toml::from_str(&toml_str).unwrap();
        assert_eq!(parsed.database.provider, original.database.provider);
        assert_eq!(parsed.database.url, original.database.url);
    }

    #[test]
    fn test_config_clone() {
        let config = PraxConfig::default();
        let cloned = config.clone();
        assert_eq!(config.database.provider, cloned.database.provider);
    }

    #[test]
    fn test_config_debug() {
        let config = PraxConfig::default();
        let debug_str = format!("{:?}", config);
        assert!(debug_str.contains("PraxConfig"));
    }

    #[test]
    fn test_provider_equality() {
        assert_eq!(DatabaseProvider::PostgreSql, DatabaseProvider::PostgreSql);
        assert_ne!(DatabaseProvider::PostgreSql, DatabaseProvider::MySql);
    }

    // Pins every serde default helper so an accidental change shows up here.
    #[test]
    fn test_default_functions() {
        assert_eq!(default_provider(), DatabaseProvider::PostgreSql);
        assert_eq!(default_min_connections(), 2);
        assert_eq!(default_max_connections(), 10);
        assert_eq!(default_connect_timeout(), "30s");
        assert_eq!(default_idle_timeout(), "10m");
        assert_eq!(default_max_lifetime(), "30m");
        assert_eq!(default_schema_path(), "schema.prax");
        assert_eq!(default_output(), "./src/generated");
        assert!(default_true());
        assert_eq!(default_migrations_dir(), "./migrations");
        assert_eq!(default_migrations_table(), "_prax_migrations");
        assert_eq!(default_slow_query_threshold(), 1000);
    }
}