use serde::Deserialize;
/// Top-level configuration, deserialized from TOML by
/// [`RullstConfig::load_from_file`] or built in code via [`RullstConfig::new`].
///
/// Every section is optional in the input: a table that is absent from the
/// document is filled from this struct's `Default` value for that section.
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(default)]
#[non_exhaustive]
pub struct RullstConfig {
    /// The `[app]` table.
    pub app: AppConfig,
    /// The `[database]` table.
    pub database: DatabaseConfig,
    /// The `[security]` table.
    pub security: SecurityConfig,
    /// The `[storage]` table.
    pub storage: StorageConfig,
}
#[derive(Debug, Clone, Deserialize, Default)]
#[non_exhaustive]
pub struct StorageConfig {
pub root: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Default)]
#[non_exhaustive]
pub struct AppConfig {
pub env: Option<String>,
pub port: Option<u16>,
}
#[derive(Debug, Clone, Deserialize, Default)]
#[non_exhaustive]
pub struct DatabaseConfig {
pub url: Option<String>,
}
/// The `[security]` table of [`RullstConfig`].
///
/// Every field carries its own serde default, so any subset of keys may be
/// supplied; the fallback values are the same ones `SecurityConfig::default()`
/// produces.
#[derive(Clone, Debug, Deserialize)]
#[non_exhaustive]
pub struct SecurityConfig {
    /// SameSite value, default `"Lax"`. Stored as given, with no validation here.
    /// NOTE(review): presumably applied to the CSRF cookie — confirm with the consumer.
    #[serde(default = "default_same_site")]
    pub csrf_same_site: String,

    /// CORS origin allow-list; empty when not configured.
    #[serde(default)]
    pub cors_allow_origins: Vec<String>,

    /// Content-Security-Policy string; defaults to `default_csp()`.
    #[serde(default = "default_csp")]
    pub csp: String,

    /// User-agent identifiers to block; defaults to `default_user_agent_blocklist()`.
    #[serde(default = "default_user_agent_blocklist")]
    pub user_agent_blocklist: Vec<String>,

    /// Whether PII masking is turned on; disabled unless set to `true`.
    #[serde(default = "default_false")]
    pub enable_pii_masking: bool,
}
/// Default Content-Security-Policy used when `[security].csp` is not configured.
///
/// NOTE(review): `script-src` permits `'unsafe-inline'` and `'unsafe-eval'`,
/// which weakens XSS protection; consider tightening if the frontend allows it.
fn default_csp() -> String {
    // One directive per piece; `concat!` joins them at compile time into the
    // exact same policy string.
    concat!(
        "default-src 'self'; ",
        "img-src 'self' data:; ",
        "style-src 'self' 'unsafe-inline'; ",
        "script-src 'self' 'unsafe-inline' 'unsafe-eval';",
    )
    .to_owned()
}
/// Default user-agent blocklist used when `[security].user_agent_blocklist`
/// is not configured.
///
/// Contains common command-line / HTTP-library clients plus several bot and
/// crawler tokens, all lowercase.
/// NOTE(review): how entries are matched (substring, case handling) is decided
/// by whatever consumes this list — confirm there before adding entries.
fn default_user_agent_blocklist() -> Vec<String> {
    const BLOCKED_AGENTS: [&str; 12] = [
        "curl",
        "wget",
        "python-requests",
        "go-http-client",
        "gptbot",
        "chatgpt-user",
        "google-extended",
        "anthropic-ai",
        "claude-web",
        "cohere-ai",
        "bytespider",
        "mj12bot",
    ];
    BLOCKED_AGENTS.iter().map(|&agent| agent.to_owned()).collect()
}
/// Default `csrf_same_site` value used when the key is not configured.
fn default_same_site() -> String {
    String::from("Lax")
}
/// Serde default for boolean flags that stay off unless explicitly enabled.
fn default_false() -> bool {
    // `bool`'s `Default` value is `false`.
    bool::default()
}
impl Default for SecurityConfig {
fn default() -> Self {
Self {
csrf_same_site: default_same_site(),
cors_allow_origins: vec![],
csp: default_csp(),
user_agent_blocklist: default_user_agent_blocklist(),
enable_pii_masking: false,
}
}
}
/// Process-wide configuration slot behind [`RullstConfig::global`] and
/// [`RullstConfig::set_global`]; it can be written at most once.
static GLOBAL_CONFIG: std::sync::OnceLock<RullstConfig> = std::sync::OnceLock::new();

impl RullstConfig {
    /// Returns the process-wide configuration.
    ///
    /// If [`RullstConfig::set_global`] has not succeeded yet, this call installs
    /// `RullstConfig::default()` as the global value. Because the slot can only be
    /// written once, every later `set_global` call will then fail — install a
    /// loaded configuration before the first call to `global()`.
    pub fn global() -> &'static RullstConfig {
        GLOBAL_CONFIG.get_or_init(Self::default)
    }

    /// Installs `config` as the process-wide configuration.
    ///
    /// # Errors
    ///
    /// Returns `Err(config)` — handing the rejected value back — if the global
    /// configuration was already set, either by an earlier `set_global` or
    /// implicitly by [`RullstConfig::global`].
    // The `Err` payload is the rejected config itself (mirroring `OnceLock::set`)
    // so callers get their value back; its size is intentional.
    #[allow(clippy::result_large_err)]
    pub fn set_global(config: Self) -> Result<(), Self> {
        GLOBAL_CONFIG.set(config)
    }

    /// Creates a configuration with every section at its default values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Reads the file at `path` and parses it as TOML.
    ///
    /// Top-level sections omitted from the file fall back to their `Default`
    /// values.
    ///
    /// # Errors
    ///
    /// - The file cannot be read: a [`std::io::Error`] with the original
    ///   [`std::io::ErrorKind`] and a message that names `path`.
    /// - The contents are not valid TOML for this schema: the TOML parser's error.
    pub async fn load_from_file(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
        let content = tokio::fs::read_to_string(path).await.map_err(|err| {
            // `io::Error` does not record which file failed; add the path but keep
            // the same error type and kind so callers can still downcast/match.
            std::io::Error::new(
                err.kind(),
                format!("failed to read config file `{path}`: {err}"),
            )
        })?;
        let config: RullstConfig = toml::from_str(&content)?;
        Ok(config)
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Per-test scratch directory under the OS temp dir, removed on drop so a
    /// failing assertion does not leave files behind (and nothing is written
    /// into the current working directory).
    struct ScratchDir(std::path::PathBuf);

    impl ScratchDir {
        fn new(tag: &str) -> Self {
            // Process id keeps concurrent test binaries apart; `tag` keeps tests
            // within this binary apart.
            let dir = std::env::temp_dir()
                .join(format!("rullst_config_{tag}_{}", std::process::id()));
            std::fs::create_dir_all(&dir).expect("failed to create scratch directory");
            Self(dir)
        }

        /// Path of `name` inside the scratch dir, as the `&str`-friendly form
        /// `load_from_file` accepts.
        fn file(&self, name: &str) -> String {
            self.0
                .join(name)
                .to_str()
                .expect("scratch path should be valid UTF-8")
                .to_owned()
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best-effort cleanup; a removal failure must not mask the test result.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_global_config_access() {
        let config1 = RullstConfig::global();
        let config2 = RullstConfig::global();
        assert!(
            std::ptr::eq(config1, config2),
            "global() should return the same instance"
        );
        assert_eq!(config1.security.csrf_same_site, "Lax");
    }

    #[tokio::test]
    async fn test_load_config_from_file() {
        let scratch = ScratchDir::new("load");
        let path = scratch.file("Rullst.toml");
        let toml_content = r#"
[app]
env = "production"
port = 8080
[database]
url = "sqlite::memory:"
[security]
csrf_same_site = "Strict"
cors_allow_origins = ["https://example.com"]
"#;
        tokio::fs::write(&path, toml_content).await.unwrap();
        let config = RullstConfig::load_from_file(&path).await.unwrap();
        assert_eq!(config.app.env.unwrap(), "production");
        assert_eq!(config.app.port.unwrap(), 8080);
        assert_eq!(config.database.url.unwrap(), "sqlite::memory:");
        assert_eq!(config.security.csrf_same_site, "Strict");
        assert_eq!(config.security.cors_allow_origins.len(), 1);
        assert_eq!(config.security.cors_allow_origins[0], "https://example.com");
        // Keys absent from `[security]` still get their defaults.
        assert!(config.security.csp.contains("default-src"));
        // The `[storage]` table is absent entirely.
        assert!(config.storage.root.is_none());
    }

    #[tokio::test]
    async fn test_load_config_missing_file_is_error() {
        let scratch = ScratchDir::new("missing");
        let path = scratch.file("does_not_exist.toml");
        assert!(RullstConfig::load_from_file(&path).await.is_err());
    }

    #[tokio::test]
    async fn test_load_config_invalid_toml_is_error() {
        let scratch = ScratchDir::new("invalid");
        let path = scratch.file("Rullst.toml");
        tokio::fs::write(&path, "[app]\nport = \"not a number\"\n")
            .await
            .unwrap();
        assert!(RullstConfig::load_from_file(&path).await.is_err());
    }

    #[test]
    fn test_default_security_config() {
        let config = SecurityConfig::default();
        assert_eq!(config.csrf_same_site, "Lax");
        assert!(config.csp.contains("default-src"));
        assert!(config.user_agent_blocklist.contains(&"curl".to_string()));
    }

    #[test]
    fn test_set_global_config() {
        let mut config = RullstConfig::new();
        config.app.env = Some("test_env".to_string());
        // Test execution order is not fixed: another test may already have
        // initialized the global via `global()`, in which case `set_global`
        // must hand the rejected config back unchanged.
        match RullstConfig::set_global(config) {
            Ok(()) => assert_eq!(RullstConfig::global().app.env.as_deref(), Some("test_env")),
            Err(c) => assert_eq!(c.app.env.as_deref(), Some("test_env")),
        }
    }

    #[test]
    fn test_deserialize_security_config_defaults() {
        let parsed: SecurityConfig = toml::from_str("").unwrap();
        let expected = SecurityConfig::default();
        assert!(!parsed.enable_pii_masking);
        assert_eq!(parsed.csrf_same_site, expected.csrf_same_site);
        assert_eq!(parsed.cors_allow_origins, expected.cors_allow_origins);
        assert_eq!(parsed.csp, expected.csp);
        assert_eq!(parsed.user_agent_blocklist, expected.user_agent_blocklist);
        assert_eq!(parsed.enable_pii_masking, expected.enable_pii_masking);
    }

    #[test]
    fn test_deserialize_empty_document_uses_defaults() {
        let config: RullstConfig = toml::from_str("").unwrap();
        assert!(config.app.env.is_none());
        assert!(config.app.port.is_none());
        assert!(config.database.url.is_none());
        assert!(config.storage.root.is_none());
        assert_eq!(config.security.csrf_same_site, "Lax");
    }
}