use serde::Deserialize;
use std::path::Path;
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct Config {
#[serde(default = "default_listen_addr")]
pub listen_addr: String,
pub tls: Option<TlsConfig>,
#[serde(default)]
pub logging: LoggingConfig,
#[serde(default)]
pub server: ServerConfig,
#[serde(default = "default_catalog_name")]
pub catalog_name: String,
#[serde(default = "default_schema_name")]
pub default_schema: String,
#[serde(default)]
pub enable_test_fixtures: bool,
}
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct TlsConfig {
pub cert_path: String,
pub key_path: String,
pub ca_path: Option<String>,
#[serde(default)]
pub client_auth_optional: bool,
}
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct LoggingConfig {
#[serde(default = "default_log_level")]
pub level: String,
#[serde(default = "default_log_format")]
pub format: String,
}
impl Default for LoggingConfig {
fn default() -> Self {
Self {
level: default_log_level(),
format: default_log_format(),
}
}
}
fn default_log_level() -> String {
"info".to_string()
}
fn default_log_format() -> String {
"pretty".to_string()
}
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ServerConfig {
#[serde(default = "default_shutdown_timeout_secs")]
pub shutdown_timeout_secs: u64,
#[serde(default = "default_tcp_keepalive_secs")]
pub tcp_keepalive_secs: u64,
#[serde(default)]
pub max_connections: usize,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
shutdown_timeout_secs: default_shutdown_timeout_secs(),
tcp_keepalive_secs: default_tcp_keepalive_secs(),
max_connections: 0,
}
}
}
fn default_shutdown_timeout_secs() -> u64 {
30
}
fn default_tcp_keepalive_secs() -> u64 {
60
}
fn default_listen_addr() -> String {
"0.0.0.0:50051".to_string()
}
fn default_catalog_name() -> String {
"default".to_string()
}
fn default_schema_name() -> String {
"public".to_string()
}
impl Default for Config {
fn default() -> Self {
Self {
listen_addr: default_listen_addr(),
tls: None,
logging: LoggingConfig::default(),
server: ServerConfig::default(),
catalog_name: default_catalog_name(),
default_schema: default_schema_name(),
enable_test_fixtures: false,
}
}
}
impl Config {
pub fn load() -> anyhow::Result<Self> {
if let Ok(path) = std::env::var("SAGITTA_CONFIG")
&& Path::new(&path).exists()
{
return Self::from_file(&path);
}
let default_path = "sagitta.toml";
if Path::new(default_path).exists() {
return Self::from_file(default_path);
}
Ok(Self::default())
}
pub fn from_file(path: &str) -> anyhow::Result<Self> {
let content = std::fs::read_to_string(path)?;
let config: Config = toml::from_str(&content)?;
Ok(config)
}
pub fn parse_toml(content: &str) -> anyhow::Result<Self> {
let config: Config = toml::from_str(content)?;
Ok(config)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_default_config() {
let config = Config::default();
assert_eq!(config.listen_addr, "0.0.0.0:50051");
assert!(config.tls.is_none());
assert_eq!(config.logging.level, "info");
assert_eq!(config.logging.format, "pretty");
assert_eq!(config.server.shutdown_timeout_secs, 30);
assert_eq!(config.server.tcp_keepalive_secs, 60);
}
#[test]
fn test_parse_minimal_config() {
let toml = r#"
listen_addr = "127.0.0.1:8080"
"#;
let config = Config::parse_toml(toml).unwrap();
assert_eq!(config.listen_addr, "127.0.0.1:8080");
assert!(config.tls.is_none());
}
#[test]
fn test_parse_tls_config() {
let toml = r#"
listen_addr = "0.0.0.0:50051"
[tls]
cert_path = "/path/to/cert.pem"
key_path = "/path/to/key.pem"
"#;
let config = Config::parse_toml(toml).unwrap();
assert!(config.tls.is_some());
let tls = config.tls.unwrap();
assert_eq!(tls.cert_path, "/path/to/cert.pem");
assert_eq!(tls.key_path, "/path/to/key.pem");
assert!(tls.ca_path.is_none());
}
#[test]
fn test_parse_mtls_config() {
let toml = r#"
listen_addr = "0.0.0.0:50051"
[tls]
cert_path = "/path/to/cert.pem"
key_path = "/path/to/key.pem"
ca_path = "/path/to/ca.pem"
"#;
let config = Config::parse_toml(toml).unwrap();
assert!(config.tls.is_some());
let tls = config.tls.unwrap();
assert_eq!(tls.cert_path, "/path/to/cert.pem");
assert_eq!(tls.key_path, "/path/to/key.pem");
assert_eq!(tls.ca_path, Some("/path/to/ca.pem".to_string()));
}
#[test]
fn test_default_listen_addr_when_omitted() {
let toml = r#"
[tls]
cert_path = "/path/to/cert.pem"
key_path = "/path/to/key.pem"
"#;
let config = Config::parse_toml(toml).unwrap();
assert_eq!(config.listen_addr, "0.0.0.0:50051");
}
#[test]
fn test_parse_logging_config() {
let toml = r#"
[logging]
level = "debug"
format = "json"
"#;
let config = Config::parse_toml(toml).unwrap();
assert_eq!(config.logging.level, "debug");
assert_eq!(config.logging.format, "json");
}
#[test]
fn test_parse_server_config() {
let toml = r#"
[server]
shutdown_timeout_secs = 60
tcp_keepalive_secs = 120
max_connections = 1000
"#;
let config = Config::parse_toml(toml).unwrap();
assert_eq!(config.server.shutdown_timeout_secs, 60);
assert_eq!(config.server.tcp_keepalive_secs, 120);
assert_eq!(config.server.max_connections, 1000);
}
#[test]
fn test_parse_full_config() {
let toml = r#"
listen_addr = "127.0.0.1:9999"
[logging]
level = "trace"
format = "json"
[server]
shutdown_timeout_secs = 10
tcp_keepalive_secs = 30
max_connections = 500
[tls]
cert_path = "/certs/server.pem"
key_path = "/certs/key.pem"
ca_path = "/certs/ca.pem"
client_auth_optional = true
"#;
let config = Config::parse_toml(toml).unwrap();
assert_eq!(config.listen_addr, "127.0.0.1:9999");
assert_eq!(config.logging.level, "trace");
assert_eq!(config.logging.format, "json");
assert_eq!(config.server.shutdown_timeout_secs, 10);
assert_eq!(config.server.tcp_keepalive_secs, 30);
assert_eq!(config.server.max_connections, 500);
let tls = config.tls.unwrap();
assert_eq!(tls.cert_path, "/certs/server.pem");
assert!(tls.client_auth_optional);
}
}