use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::PathBuf;
use crate::observability::ObservabilityConfig;
use crate::subscription::SubscriptionConfig;
/// Top-level server configuration, (de)serialized with serde.
///
/// `server`, `auth` and `logging` carry no `#[serde(default)]`, so they must
/// be present in the input when deserializing. `http`, `observability` and
/// `subscriptions` fall back to their `Default` implementations when the
/// corresponding section is missing.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Config {
    /// Settings for the main server listener.
    pub server: ServerConfig,
    /// Authentication settings for the main server.
    pub auth: AuthConfig,
    /// Logging settings.
    pub logging: LoggingConfig,
    /// HTTP endpoint settings; defaults to `HttpConfig::default()` when absent.
    #[serde(default)]
    pub http: HttpConfig,
    /// Observability settings; defaults to `ObservabilityConfig::default()` when absent.
    #[serde(default)]
    pub observability: ObservabilityConfig,
    /// Subscription settings; defaults to `SubscriptionConfig::default()` when absent.
    #[serde(default)]
    pub subscriptions: SubscriptionConfig,
}
/// Settings for the main server listener.
///
/// None of the fields has a serde default, so `host`, `port`,
/// `max_connections` and `ssl_enabled` must all be supplied when this section
/// is deserialized; the two `Option` paths may be left out.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ServerConfig {
    /// Address to bind to (`Config::default()` uses `"0.0.0.0"`).
    pub host: String,
    /// TCP port to bind to (`Config::default()` uses `5432`).
    pub port: u16,
    /// Connection limit (`Config::default()` uses `100`).
    pub max_connections: usize,
    /// Whether SSL is turned on.
    pub ssl_enabled: bool,
    /// Path to the SSL certificate.
    // NOTE(review): presumably required when `ssl_enabled` is true — nothing in
    // this file enforces that; confirm against the listener setup code.
    pub ssl_cert: Option<PathBuf>,
    /// Path to the SSL private key (same caveat as `ssl_cert`).
    pub ssl_key: Option<PathBuf>,
}
/// Authentication settings for the main server.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AuthConfig {
    /// Name of the authentication method. `Config::default()` uses `"trust"`;
    /// the set of accepted names is not defined in this file.
    pub method: String,
    /// Optional path to a password file.
    pub password_file: Option<PathBuf>,
}
/// Logging settings.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct LoggingConfig {
    /// Log level name (`Config::default()` uses `"info"`).
    pub level: String,
    /// Optional log file path.
    // NOTE(review): presumably logs go to the console when this is `None` —
    // confirm where the logger is initialised.
    pub file: Option<PathBuf>,
}
/// Settings for the HTTP endpoint.
///
/// Only `auth` has a field-level serde default. When an `http` section is
/// present in the input, `enabled`, `host` and `port` must all be given;
/// when the whole section is absent, `Config` uses `HttpConfig::default()`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct HttpConfig {
    /// Whether the HTTP endpoint is turned on.
    pub enabled: bool,
    /// Address the HTTP endpoint binds to.
    pub host: String,
    /// Port the HTTP endpoint binds to.
    pub port: u16,
    /// Authentication for the HTTP endpoint; defaults to
    /// `HttpAuthConfig::default()` when omitted.
    #[serde(default)]
    pub auth: HttpAuthConfig,
}
impl Default for HttpConfig {
    /// Builds the HTTP settings used when no `http` section is configured:
    /// enabled, bound to `0.0.0.0:8080`, with the default auth settings.
    fn default() -> Self {
        let host = String::from("0.0.0.0");
        let auth = HttpAuthConfig::default();
        Self { enabled: true, host, port: 8080, auth }
    }
}
/// Authentication settings for the HTTP endpoint.
///
/// `enabled` and `methods` have no serde default and are therefore required
/// whenever this section appears in the input; `api_keys` and `jwt` fall back
/// to their `Default` implementations.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct HttpAuthConfig {
    /// Whether HTTP authentication is turned on.
    pub enabled: bool,
    /// Authentication methods listed for the HTTP endpoint.
    pub methods: Vec<HttpAuthMethod>,
    /// API-key settings; defaults to `ApiKeyConfig::default()` when omitted.
    #[serde(default)]
    pub api_keys: ApiKeyConfig,
    /// JWT settings; defaults to `JwtConfig::default()` when omitted.
    #[serde(default)]
    pub jwt: JwtConfig,
}
impl Default for HttpAuthConfig {
    /// Builds the default HTTP auth settings: authentication disabled, with
    /// the API-key and basic methods listed (JWT is not in the default list).
    fn default() -> Self {
        let methods = vec![HttpAuthMethod::ApiKey, HttpAuthMethod::Basic];
        Self {
            enabled: false,
            methods,
            api_keys: Default::default(),
            jwt: Default::default(),
        }
    }
}
/// Authentication methods that can be listed in `HttpAuthConfig::methods`.
///
/// Because of `rename_all = "snake_case"`, the variants are written as
/// `api_key`, `basic` and `jwt` in serialized form.
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum HttpAuthMethod {
    /// Serialized as `api_key`.
    ApiKey,
    /// Serialized as `basic`.
    Basic,
    /// Serialized as `jwt`.
    Jwt,
}
/// API-key settings for HTTP authentication.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct ApiKeyConfig {
    /// Configured API keys; an empty list when omitted.
    #[serde(default)]
    pub keys: Vec<String>,
}
/// JWT settings for HTTP authentication.
///
/// Every field has a *field-level* serde default, which is the default of the
/// field's own type. That differs from `JwtConfig::default()` for `issuer`
/// and `audience`: a `jwt` section that is present but omits them
/// deserializes to `None`, whereas a missing `jwt` section (which uses
/// `JwtConfig::default()`) yields `Some("vibesql")` / `Some("vibesql-api")`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct JwtConfig {
    /// Signing secret; an empty string when omitted.
    // NOTE(review): nothing in this file rejects an empty secret — confirm
    // that the code consuming this config validates it.
    #[serde(default)]
    pub secret: String,
    /// Token issuer; `None` when omitted from a present `jwt` section.
    #[serde(default)]
    pub issuer: Option<String>,
    /// Token audience; `None` when omitted from a present `jwt` section.
    #[serde(default)]
    pub audience: Option<String>,
    /// Token lifetime in seconds; `default_jwt_expiration()` when omitted.
    #[serde(default = "default_jwt_expiration")]
    pub expiration_secs: u64,
}
/// Returns the default JWT lifetime in seconds (one hour).
///
/// Referenced by name from the `#[serde(default = "...")]` attribute on
/// `JwtConfig::expiration_secs`, so the function name must stay as it is.
fn default_jwt_expiration() -> u64 {
    const SECS_PER_MINUTE: u64 = 60;
    const MINUTES_PER_HOUR: u64 = 60;
    SECS_PER_MINUTE * MINUTES_PER_HOUR
}
impl Default for JwtConfig {
    /// Builds the JWT settings used when no `jwt` section is configured:
    /// empty secret, issuer `"vibesql"`, audience `"vibesql-api"` and the
    /// lifetime returned by `default_jwt_expiration()`.
    fn default() -> Self {
        let issuer = Some(String::from("vibesql"));
        let audience = Some(String::from("vibesql-api"));
        Self {
            secret: String::new(),
            issuer,
            audience,
            expiration_secs: default_jwt_expiration(),
        }
    }
}
impl Default for Config {
    /// Builds the built-in configuration: a listener on `0.0.0.0:5432` with a
    /// limit of 100 connections and SSL off, the `"trust"` auth method,
    /// `"info"` log level with no log file, and default settings for the
    /// HTTP, observability and subscription sections.
    fn default() -> Self {
        let server = ServerConfig {
            host: String::from("0.0.0.0"),
            port: 5432,
            max_connections: 100,
            ssl_enabled: false,
            ssl_cert: None,
            ssl_key: None,
        };
        let auth = AuthConfig {
            method: String::from("trust"),
            password_file: None,
        };
        let logging = LoggingConfig {
            level: String::from("info"),
            file: None,
        };

        Self {
            server,
            auth,
            logging,
            http: Default::default(),
            observability: Default::default(),
            subscriptions: Default::default(),
        }
    }
}
impl Config {
pub fn load() -> Result<Self> {
let config_paths = vec![
PathBuf::from("vibesql-server.toml"),
dirs::config_dir()
.map(|p| p.join("vibesql").join("vibesql-server.toml"))
.unwrap_or_default(),
PathBuf::from("/etc/vibesql/vibesql-server.toml"),
];
for path in config_paths {
if path.exists() {
let contents = fs::read_to_string(&path)?;
let config: Config = toml::from_str(&contents)?;
return Ok(config);
}
}
Err(anyhow::anyhow!("No configuration file found"))
}
#[allow(dead_code)]
pub fn load_from(path: &PathBuf) -> Result<Self> {
let contents = fs::read_to_string(path)?;
let config: Config = toml::from_str(&contents)?;
Ok(config)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The built-in defaults expose the expected listener and auth settings.
    #[test]
    fn test_default_config() {
        let Config { server, auth, .. } = Config::default();

        assert_eq!(server.host, "0.0.0.0");
        assert_eq!(server.port, 5432);
        assert_eq!(server.max_connections, 100);
        assert!(!server.ssl_enabled);
        assert_eq!(auth.method, "trust");
    }

    /// A default config survives a TOML serialize/deserialize round trip.
    #[test]
    fn test_config_serialization() {
        let original = Config::default();

        let encoded = toml::to_string(&original).unwrap();
        let decoded: Config = toml::from_str(&encoded).unwrap();

        assert_eq!(original.server.port, decoded.server.port);
    }
}