use llmtrace_core::{ProxyConfig, TenantId};
use std::path::Path;
/// Human-readable name for the catch-all tenant, i.e. the tenant that
/// tenant-less traffic is attributed to.
///
/// NOTE(review): not referenced in this file; presumably paired by callers with
/// the id returned from `resolve_catch_all_tenant_id` when registering the
/// tenant — confirm against call sites.
pub const CATCH_ALL_TENANT_NAME: &str = "catch-all";
/// Reads the file at `path` and deserializes it as a YAML [`ProxyConfig`].
///
/// No environment overrides or validation are applied here; see
/// `apply_env_overrides` and `validate_config` for those steps.
///
/// # Errors
///
/// Returns an error when the file cannot be read or when its contents do not
/// deserialize into a [`ProxyConfig`]. Both messages name `path` so the
/// offending file can be identified when several configs are in play.
pub fn load_config(path: &Path) -> anyhow::Result<ProxyConfig> {
    let contents = std::fs::read_to_string(path)
        .map_err(|e| anyhow::anyhow!("Failed to read config file {}: {}", path.display(), e))?;
    // The parse error previously omitted the path, unlike the read error above.
    serde_yaml::from_str(&contents).map_err(|e| {
        anyhow::anyhow!(
            "Failed to parse config YAML from {}: {}",
            path.display(),
            e
        )
    })
}
pub fn apply_env_overrides(config: &mut ProxyConfig) {
if let Ok(val) = std::env::var("LLMTRACE_LISTEN_ADDR") {
config.listen_addr = val;
}
if let Ok(val) = std::env::var("LLMTRACE_UPSTREAM_URL") {
config.upstream_url = val;
}
if let Ok(val) = std::env::var("LLMTRACE_STORAGE_PROFILE") {
config.storage.profile = val;
}
if let Ok(val) = std::env::var("LLMTRACE_STORAGE_DATABASE_PATH") {
config.storage.database_path = val;
}
if let Ok(val) = std::env::var("LLMTRACE_CLICKHOUSE_URL") {
config.storage.clickhouse_url = Some(val);
}
if let Ok(val) = std::env::var("LLMTRACE_CLICKHOUSE_DATABASE") {
config.storage.clickhouse_database = Some(val);
}
if let Ok(val) = std::env::var("LLMTRACE_POSTGRES_URL") {
config.storage.postgres_url = Some(val);
}
if let Ok(val) = std::env::var("LLMTRACE_REDIS_URL") {
config.storage.redis_url = Some(val);
}
if let Ok(val) = std::env::var("LLMTRACE_AUTH_ENABLED") {
let val = val.to_lowercase();
config.auth.enabled = val == "1" || val == "true" || val == "yes";
}
if let Ok(val) = std::env::var("LLMTRACE_AUTH_ADMIN_KEY") {
config.auth.admin_key = Some(val);
}
if let Some(rps) = parse_positive_u32("LLMTRACE_RATE_LIMIT_RPS") {
config.rate_limiting.requests_per_second = rps;
}
if let Some(burst) = parse_positive_u32("LLMTRACE_RATE_LIMIT_BURST") {
config.rate_limiting.burst_size = burst;
}
if let Ok(val) = std::env::var("LLMTRACE_DATAMARKING_ENABLED") {
let val = val.to_lowercase();
config.boundary_defense.datamarking.enabled = val == "1" || val == "true" || val == "yes";
}
if let Ok(val) = std::env::var("LLMTRACE_DATAMARKING_SHADOW_MODE") {
let val = val.to_lowercase();
config.boundary_defense.datamarking.shadow_mode =
val == "1" || val == "true" || val == "yes";
}
if let Ok(val) = std::env::var("LLMTRACE_ZONE_DETECTION_ENABLED") {
let val = val.to_lowercase();
config.security_analysis.zone_detection.enabled =
val == "1" || val == "true" || val == "yes";
}
if let Ok(val) = std::env::var("LLMTRACE_DEFAULT_TENANT_ID") {
match uuid::Uuid::parse_str(val.trim()) {
Ok(uuid) => config.default_tenant_id = Some(llmtrace_core::TenantId(uuid)),
Err(_) => tracing::warn!(
value = %val,
"LLMTRACE_DEFAULT_TENANT_ID must be a valid UUID; ignoring"
),
}
}
if let Ok(val) = std::env::var("LLMTRACE_ML_MAX_CONCURRENT") {
match val.parse::<usize>() {
Ok(n) if n > 0 => config.ml_pipeline.max_concurrent_requests = n,
_ => tracing::warn!(
value = %val,
"LLMTRACE_ML_MAX_CONCURRENT must be a positive integer; keeping current value"
),
}
}
}
/// Determines the tenant that tenant-less traffic is attributed to.
///
/// A configured `default_tenant_id` is returned unchanged. Otherwise a fresh
/// random (v4) UUID is generated, written back into `config.default_tenant_id`
/// so later readers observe the same id, and returned. Either choice is logged
/// at `info` level.
pub fn resolve_catch_all_tenant_id(config: &mut ProxyConfig) -> TenantId {
    if let Some(configured) = config.default_tenant_id {
        tracing::info!(
            tenant_id = %configured.0,
            "Using configured LLMTRACE_DEFAULT_TENANT_ID as catch-all tenant for tenant-less traffic"
        );
        return configured;
    }

    let fresh = TenantId(uuid::Uuid::new_v4());
    // Stamp the id back first so the config and the return value always agree.
    config.default_tenant_id = Some(fresh);
    tracing::info!(
        tenant_id = %fresh.0,
        "No LLMTRACE_DEFAULT_TENANT_ID provided; generated catch-all tenant {} for tenant-less traffic",
        fresh.0
    );
    fresh
}
/// Reads environment variable `name` as a strictly positive `u32`.
///
/// Leading/trailing whitespace is ignored. Returns `None` when the variable is
/// unset or not valid Unicode, when it does not parse as a `u32`, or when it
/// parses to zero.
fn parse_positive_u32(name: &str) -> Option<u32> {
    match std::env::var(name) {
        Ok(text) => match text.trim().parse::<u32>() {
            Ok(0) | Err(_) => None,
            Ok(value) => Some(value),
        },
        Err(_) => None,
    }
}
/// Checks `config` for invalid or inconsistent settings.
///
/// Every check runs, so a single call reports all problems at once rather than
/// stopping at the first.
///
/// # Errors
///
/// Returns one error whose message starts with `Configuration errors:` followed
/// by one ` - `-prefixed line per problem, in the order the checks run.
pub fn validate_config(config: &ProxyConfig) -> anyhow::Result<()> {
    /// Records `message` in `problems` unless `ok` holds.
    fn require(problems: &mut Vec<String>, ok: bool, message: &str) {
        if !ok {
            problems.push(message.to_owned());
        }
    }

    let mut problems = Vec::<String>::new();

    // --- addressing -------------------------------------------------------
    require(
        &mut problems,
        !config.listen_addr.is_empty(),
        "listen_addr must not be empty",
    );
    let upstream = config.upstream_url.as_str();
    if upstream.is_empty() {
        problems.push("upstream_url must not be empty".to_owned());
    } else {
        let http_scheme = upstream.starts_with("http://") || upstream.starts_with("https://");
        require(
            &mut problems,
            http_scheme,
            "upstream_url must start with http:// or https://",
        );
    }

    // --- enumerated string settings ---------------------------------------
    let profile = config.storage.profile.as_str();
    if !matches!(profile, "lite" | "sqlite" | "memory" | "production") {
        problems.push(format!(
            "storage.profile must be 'lite', 'sqlite', 'memory', or 'production', got '{profile}'"
        ));
    }
    let level = config.logging.level.as_str();
    if !matches!(level, "trace" | "debug" | "info" | "warn" | "error") {
        problems.push(format!(
            "logging.level must be trace/debug/info/warn/error, got '{level}'"
        ));
    }
    let log_format = config.logging.format.as_str();
    if !matches!(log_format, "text" | "json") {
        problems.push(format!(
            "logging.format must be 'text' or 'json', got '{log_format}'"
        ));
    }

    // --- timeouts and TLS ---------------------------------------------------
    require(
        &mut problems,
        config.timeout_ms != 0,
        "timeout_ms must be greater than 0",
    );
    require(
        &mut problems,
        config.connection_timeout_ms != 0,
        "connection_timeout_ms must be greater than 0",
    );
    if config.enable_tls {
        require(
            &mut problems,
            config.tls_cert_file.is_some(),
            "tls_cert_file is required when enable_tls is true",
        );
        require(
            &mut problems,
            config.tls_key_file.is_some(),
            "tls_key_file is required when enable_tls is true",
        );
    }

    // --- boundary defense ---------------------------------------------------
    let boundary = &config.boundary_defense;
    if boundary.enabled {
        require(
            &mut problems,
            !boundary.delimiter.is_empty(),
            "boundary_defense.delimiter must not be empty when enabled",
        );
        require(
            &mut problems,
            !boundary.wrap_roles.is_empty(),
            "boundary_defense.wrap_roles must not be empty when enabled",
        );
    }

    // --- enforcement --------------------------------------------------------
    let enforcement = &config.enforcement;
    // `contains` (rather than two comparisons) also rejects NaN.
    if !(0.0..=1.0).contains(&enforcement.min_confidence) {
        problems.push(format!(
            "enforcement.min_confidence must be between 0.0 and 1.0, got {}",
            enforcement.min_confidence
        ));
    }
    require(
        &mut problems,
        enforcement.timeout_ms != 0,
        "enforcement.timeout_ms must be greater than 0",
    );
    require(
        &mut problems,
        config.max_response_size_bytes != 0,
        "max_response_size_bytes must be > 0",
    );

    // --- action router (only checked when enabled) --------------------------
    let router = &config.action_router;
    if router.enabled {
        require(
            &mut problems,
            router.ip_block.ttl_seconds != 0,
            "action_router.ip_block.ttl_seconds must be greater than 0",
        );
        require(
            &mut problems,
            router.ip_block.max_offenses != 0,
            "action_router.ip_block.max_offenses must be greater than 0",
        );
        require(
            &mut problems,
            router.webhook.timeout_ms != 0,
            "action_router.webhook.timeout_ms must be greater than 0",
        );
        require(
            &mut problems,
            router.judge_route.inline_timeout_ms != 0,
            "action_router.judge_route.inline_timeout_ms must be greater than 0",
        );
        // A webhook URL is only mandatory when some action list actually uses it.
        let webhook_in_use = router.default_actions.iter().any(|a| a == "webhook")
            || router
                .rules
                .iter()
                .flat_map(|rule| rule.actions.iter())
                .any(|action| action == "webhook");
        if webhook_in_use && router.webhook.url.is_empty() {
            problems.push(
                "action_router.webhook.url must not be empty if webhook action is enabled"
                    .to_owned(),
            );
        }
    }

    // --- analysis limits ----------------------------------------------------
    require(
        &mut problems,
        config.security_analysis.max_analysis_text_bytes != 0,
        "security_analysis.max_analysis_text_bytes must be > 0",
    );
    require(
        &mut problems,
        config.ml_pipeline.max_concurrent_requests != 0,
        "ml_pipeline.max_concurrent_requests must be >= 1",
    );

    if problems.is_empty() {
        return Ok(());
    }
    anyhow::bail!("Configuration errors:\n - {}", problems.join("\n - "))
}
#[cfg(test)]
mod tests {
    use super::*;
    use llmtrace_core::LoggingConfig;
    use std::io::Write;
    use std::sync::{Mutex, MutexGuard};
    use tempfile::NamedTempFile;

    /// Serializes every test that sets, removes, or reads `LLMTRACE_*`
    /// environment variables. The process environment is global and the test
    /// harness runs tests on parallel threads, so without this lock one test's
    /// `set_var` can leak into another test's `apply_env_overrides` call (e.g.
    /// the two `LLMTRACE_DEFAULT_TENANT_ID` tests race each other).
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], recovering from poisoning so that one failing
    /// env test does not turn every other env test into a `PoisonError` panic.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Writes `yaml` to a fresh temporary file; the file lives as long as the
    /// returned handle.
    fn write_yaml(yaml: &str) -> NamedTempFile {
        let mut f = NamedTempFile::new().unwrap();
        f.write_all(yaml.as_bytes()).unwrap();
        f
    }

    #[test]
    fn test_load_config_minimal() {
        let yaml = r#"
listen_addr: "127.0.0.1:9090"
upstream_url: "http://localhost:11434"
timeout_ms: 60000
connection_timeout_ms: 5000
max_connections: 500
enable_tls: false
enable_security_analysis: true
enable_trace_storage: true
enable_streaming: true
max_request_size_bytes: 52428800
security_analysis_timeout_ms: 5000
trace_storage_timeout_ms: 10000
rate_limiting:
  enabled: true
  requests_per_second: 100
  burst_size: 200
  window_seconds: 60
circuit_breaker:
  enabled: true
  failure_threshold: 10
  recovery_timeout_ms: 30000
  half_open_max_calls: 3
health_check:
  enabled: true
  path: "/health"
  interval_seconds: 10
  timeout_ms: 5000
  retries: 3
"#;
        let f = write_yaml(yaml);
        let config = load_config(f.path()).unwrap();
        assert_eq!(config.listen_addr, "127.0.0.1:9090");
        assert_eq!(config.upstream_url, "http://localhost:11434");
        assert_eq!(config.timeout_ms, 60000);
        assert_eq!(config.logging.level, "info");
        assert_eq!(config.logging.format, "text");
    }

    #[test]
    fn test_load_config_with_logging() {
        let yaml = r#"
listen_addr: "0.0.0.0:8080"
upstream_url: "https://api.openai.com"
timeout_ms: 30000
connection_timeout_ms: 5000
max_connections: 1000
enable_tls: false
enable_security_analysis: true
enable_trace_storage: true
enable_streaming: true
max_request_size_bytes: 52428800
security_analysis_timeout_ms: 5000
trace_storage_timeout_ms: 10000
logging:
  level: "debug"
  format: "json"
rate_limiting:
  enabled: true
  requests_per_second: 100
  burst_size: 200
  window_seconds: 60
circuit_breaker:
  enabled: true
  failure_threshold: 10
  recovery_timeout_ms: 30000
  half_open_max_calls: 3
health_check:
  enabled: true
  path: "/health"
  interval_seconds: 10
  timeout_ms: 5000
  retries: 3
"#;
        let f = write_yaml(yaml);
        let config = load_config(f.path()).unwrap();
        assert_eq!(config.logging.level, "debug");
        assert_eq!(config.logging.format, "json");
    }

    #[test]
    fn test_load_config_missing_file() {
        let err = load_config(Path::new("/nonexistent/config.yaml")).unwrap_err();
        // The read error must name the file so operators can find the problem.
        assert!(err.to_string().contains("/nonexistent/config.yaml"));
    }

    #[test]
    fn test_load_config_invalid_yaml() {
        let f = write_yaml("not: [valid: yaml: {{{}}}");
        let result = load_config(f.path());
        assert!(result.is_err());
    }

    #[test]
    fn test_apply_env_overrides_listen_addr() {
        let _guard = env_lock();
        let mut config = ProxyConfig::default();
        std::env::set_var("LLMTRACE_LISTEN_ADDR", "127.0.0.1:3000");
        apply_env_overrides(&mut config);
        std::env::remove_var("LLMTRACE_LISTEN_ADDR");
        assert_eq!(config.listen_addr, "127.0.0.1:3000");
    }

    #[test]
    fn test_apply_env_overrides_upstream_url() {
        let _guard = env_lock();
        let mut config = ProxyConfig::default();
        std::env::set_var("LLMTRACE_UPSTREAM_URL", "http://my-llm:8000");
        apply_env_overrides(&mut config);
        std::env::remove_var("LLMTRACE_UPSTREAM_URL");
        assert_eq!(config.upstream_url, "http://my-llm:8000");
    }

    #[test]
    fn test_apply_env_overrides_storage() {
        let _guard = env_lock();
        let mut config = ProxyConfig::default();
        std::env::set_var("LLMTRACE_STORAGE_PROFILE", "memory");
        std::env::set_var("LLMTRACE_STORAGE_DATABASE_PATH", "/tmp/test.db");
        apply_env_overrides(&mut config);
        std::env::remove_var("LLMTRACE_STORAGE_PROFILE");
        std::env::remove_var("LLMTRACE_STORAGE_DATABASE_PATH");
        assert_eq!(config.storage.profile, "memory");
        assert_eq!(config.storage.database_path, "/tmp/test.db");
    }

    #[test]
    fn test_apply_env_overrides_auth() {
        let _guard = env_lock();
        let mut config = ProxyConfig::default();
        std::env::set_var("LLMTRACE_AUTH_ENABLED", "YES");
        std::env::set_var("LLMTRACE_AUTH_ADMIN_KEY", "admin-secret");
        apply_env_overrides(&mut config);
        let enabled_after_yes = config.auth.enabled;
        std::env::set_var("LLMTRACE_AUTH_ENABLED", "0");
        apply_env_overrides(&mut config);
        std::env::remove_var("LLMTRACE_AUTH_ENABLED");
        std::env::remove_var("LLMTRACE_AUTH_ADMIN_KEY");
        assert!(enabled_after_yes);
        assert!(!config.auth.enabled);
        assert_eq!(config.auth.admin_key.as_deref(), Some("admin-secret"));
    }

    #[test]
    fn test_apply_env_overrides_rate_limit_rps_and_burst() {
        let _guard = env_lock();
        let mut config = ProxyConfig::default();
        std::env::set_var("LLMTRACE_RATE_LIMIT_RPS", "250");
        std::env::set_var("LLMTRACE_RATE_LIMIT_BURST", "500");
        apply_env_overrides(&mut config);
        std::env::remove_var("LLMTRACE_RATE_LIMIT_RPS");
        std::env::remove_var("LLMTRACE_RATE_LIMIT_BURST");
        assert_eq!(config.rate_limiting.requests_per_second, 250);
        assert_eq!(config.rate_limiting.burst_size, 500);
    }

    #[test]
    fn test_apply_env_overrides_rate_limit_ignores_invalid() {
        let _guard = env_lock();
        let baseline = ProxyConfig::default();
        let mut config = ProxyConfig::default();
        std::env::set_var("LLMTRACE_RATE_LIMIT_RPS", "not-a-number");
        std::env::set_var("LLMTRACE_RATE_LIMIT_BURST", "0");
        apply_env_overrides(&mut config);
        std::env::remove_var("LLMTRACE_RATE_LIMIT_RPS");
        std::env::remove_var("LLMTRACE_RATE_LIMIT_BURST");
        assert_eq!(
            config.rate_limiting.requests_per_second,
            baseline.rate_limiting.requests_per_second
        );
        assert_eq!(
            config.rate_limiting.burst_size,
            baseline.rate_limiting.burst_size
        );
    }

    #[test]
    fn test_apply_env_overrides_datamarking_and_zone() {
        let _guard = env_lock();
        let mut config = ProxyConfig::default();
        assert!(!config.boundary_defense.datamarking.enabled);
        assert!(config.boundary_defense.datamarking.shadow_mode);
        assert!(!config.security_analysis.zone_detection.enabled);
        std::env::set_var("LLMTRACE_DATAMARKING_ENABLED", "true");
        std::env::set_var("LLMTRACE_DATAMARKING_SHADOW_MODE", "false");
        std::env::set_var("LLMTRACE_ZONE_DETECTION_ENABLED", "1");
        apply_env_overrides(&mut config);
        std::env::remove_var("LLMTRACE_DATAMARKING_ENABLED");
        std::env::remove_var("LLMTRACE_DATAMARKING_SHADOW_MODE");
        std::env::remove_var("LLMTRACE_ZONE_DETECTION_ENABLED");
        assert!(config.boundary_defense.datamarking.enabled);
        assert!(!config.boundary_defense.datamarking.shadow_mode);
        assert!(config.security_analysis.zone_detection.enabled);
    }

    #[test]
    fn test_apply_env_overrides_default_tenant_id() {
        let _guard = env_lock();
        let mut config = ProxyConfig::default();
        assert!(config.default_tenant_id.is_none());
        let id = uuid::Uuid::new_v4();
        std::env::set_var("LLMTRACE_DEFAULT_TENANT_ID", id.to_string());
        apply_env_overrides(&mut config);
        std::env::remove_var("LLMTRACE_DEFAULT_TENANT_ID");
        assert_eq!(config.default_tenant_id.map(|t| t.0), Some(id));
    }

    #[test]
    fn test_apply_env_overrides_default_tenant_id_invalid_ignored() {
        let _guard = env_lock();
        let mut config = ProxyConfig::default();
        std::env::set_var("LLMTRACE_DEFAULT_TENANT_ID", "not-a-uuid");
        apply_env_overrides(&mut config);
        std::env::remove_var("LLMTRACE_DEFAULT_TENANT_ID");
        assert!(config.default_tenant_id.is_none());
    }

    #[test]
    fn test_apply_env_overrides_ml_max_concurrent() {
        let _guard = env_lock();
        let baseline = ProxyConfig::default();
        let mut config = ProxyConfig::default();
        std::env::set_var("LLMTRACE_ML_MAX_CONCURRENT", "0");
        apply_env_overrides(&mut config);
        let after_zero = config.ml_pipeline.max_concurrent_requests;
        std::env::set_var("LLMTRACE_ML_MAX_CONCURRENT", "8");
        apply_env_overrides(&mut config);
        std::env::remove_var("LLMTRACE_ML_MAX_CONCURRENT");
        assert_eq!(after_zero, baseline.ml_pipeline.max_concurrent_requests);
        assert_eq!(config.ml_pipeline.max_concurrent_requests, 8);
    }

    #[test]
    fn test_resolve_catch_all_uses_configured_id() {
        let configured = TenantId(uuid::Uuid::new_v4());
        let mut config = ProxyConfig {
            default_tenant_id: Some(configured),
            ..ProxyConfig::default()
        };
        let resolved = resolve_catch_all_tenant_id(&mut config);
        assert_eq!(resolved, configured);
        assert_eq!(config.default_tenant_id, Some(configured));
    }

    #[test]
    fn test_resolve_catch_all_generates_fresh_id_when_unset() {
        let mut config = ProxyConfig::default();
        assert!(config.default_tenant_id.is_none());
        let resolved = resolve_catch_all_tenant_id(&mut config);
        assert!(
            !resolved.0.is_nil(),
            "generated catch-all id must not be the nil UUID"
        );
        assert_eq!(
            config.default_tenant_id,
            Some(resolved),
            "resolved id must be stamped back onto the config"
        );
    }

    #[test]
    fn test_resolve_catch_all_generates_distinct_ids_across_builds() {
        let mut a = ProxyConfig::default();
        let mut b = ProxyConfig::default();
        let id_a = resolve_catch_all_tenant_id(&mut a);
        let id_b = resolve_catch_all_tenant_id(&mut b);
        assert_ne!(
            id_a, id_b,
            "fresh resolutions must produce distinct ids (proves runtime generation, not a hardcoded literal)"
        );
        assert!(!id_a.0.is_nil());
        assert!(!id_b.0.is_nil());
    }

    #[test]
    fn test_resolve_catch_all_never_nil_or_anonymous_sentinel() {
        let anonymous = uuid::Uuid::new_v5(&uuid::Uuid::NAMESPACE_OID, b"llmtrace-anonymous");
        let mut config = ProxyConfig::default();
        let resolved = resolve_catch_all_tenant_id(&mut config);
        assert!(!resolved.0.is_nil());
        assert_ne!(
            resolved.0, anonymous,
            "catch-all must be a fresh tenant, not the anonymous sentinel"
        );
    }

    #[test]
    fn test_validate_config_valid() {
        let config = ProxyConfig::default();
        assert!(validate_config(&config).is_ok());
    }

    #[test]
    fn test_validate_config_empty_listen_addr() {
        let config = ProxyConfig {
            listen_addr: String::new(),
            ..ProxyConfig::default()
        };
        let err = validate_config(&config).unwrap_err();
        assert!(err.to_string().contains("listen_addr must not be empty"));
    }

    #[test]
    fn test_validate_config_empty_upstream_url() {
        let config = ProxyConfig {
            upstream_url: String::new(),
            ..ProxyConfig::default()
        };
        let err = validate_config(&config).unwrap_err();
        assert!(err.to_string().contains("upstream_url must not be empty"));
    }

    #[test]
    fn test_validate_config_bad_upstream_url_scheme() {
        let config = ProxyConfig {
            upstream_url: "ftp://example.com".to_string(),
            ..ProxyConfig::default()
        };
        let err = validate_config(&config).unwrap_err();
        assert!(err
            .to_string()
            .contains("upstream_url must start with http://"));
    }

    #[test]
    fn test_validate_config_invalid_storage_profile() {
        let config = ProxyConfig {
            storage: llmtrace_core::StorageConfig {
                profile: "postgres".to_string(),
                database_path: String::new(),
                ..llmtrace_core::StorageConfig::default()
            },
            ..ProxyConfig::default()
        };
        let err = validate_config(&config).unwrap_err();
        assert!(err.to_string().contains("storage.profile"));
    }

    #[test]
    fn test_validate_config_accepts_sqlite_alias() {
        for profile in ["lite", "sqlite", "memory"] {
            let config = ProxyConfig {
                storage: llmtrace_core::StorageConfig {
                    profile: profile.to_string(),
                    ..llmtrace_core::StorageConfig::default()
                },
                ..ProxyConfig::default()
            };
            validate_config(&config)
                .unwrap_or_else(|e| panic!("profile {profile} must validate cleanly: {e}"));
        }
    }

    #[test]
    fn test_validate_config_invalid_log_level() {
        let config = ProxyConfig {
            logging: LoggingConfig {
                level: "verbose".to_string(),
                format: "text".to_string(),
            },
            ..ProxyConfig::default()
        };
        let err = validate_config(&config).unwrap_err();
        assert!(err.to_string().contains("logging.level"));
    }

    #[test]
    fn test_validate_config_invalid_log_format() {
        let config = ProxyConfig {
            logging: LoggingConfig {
                level: "info".to_string(),
                format: "xml".to_string(),
            },
            ..ProxyConfig::default()
        };
        let err = validate_config(&config).unwrap_err();
        assert!(err.to_string().contains("logging.format"));
    }

    #[test]
    fn test_validate_config_zero_timeout() {
        let config = ProxyConfig {
            timeout_ms: 0,
            ..ProxyConfig::default()
        };
        let err = validate_config(&config).unwrap_err();
        assert!(err.to_string().contains("timeout_ms"));
    }

    #[test]
    fn test_validate_config_action_router_invalid_values() {
        let mut config = ProxyConfig::default();
        config.listen_addr = "127.0.0.1:8080".to_string();
        config.upstream_url = "http://localhost:11434".to_string();
        config.timeout_ms = 1000;
        config.connection_timeout_ms = 1000;
        config.max_response_size_bytes = 1024;
        config.security_analysis.max_analysis_text_bytes = 1024;
        config.action_router.enabled = true;
        config.action_router.ip_block.ttl_seconds = 0;
        config.action_router.ip_block.max_offenses = 0;
        config.action_router.webhook.timeout_ms = 0;
        config.action_router.judge_route.inline_timeout_ms = 0;
        config.action_router.rules = vec![llmtrace_core::ActionRuleConfig {
            finding_type: Some("prompt_injection".to_string()),
            min_severity: llmtrace_core::SecuritySeverity::High,
            min_confidence: 0.8,
            actions: vec!["webhook".to_string()],
        }];
        let err = validate_config(&config).unwrap_err().to_string();
        assert!(err.contains("action_router.ip_block.ttl_seconds"));
        assert!(err.contains("action_router.ip_block.max_offenses"));
        assert!(err.contains("action_router.webhook.timeout_ms"));
        assert!(err.contains("action_router.judge_route.inline_timeout_ms"));
        assert!(err.contains("action_router.webhook.url"));
    }

    #[test]
    fn test_validate_config_tls_without_cert() {
        let config = ProxyConfig {
            enable_tls: true,
            tls_cert_file: None,
            tls_key_file: None,
            ..ProxyConfig::default()
        };
        let err = validate_config(&config).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("tls_cert_file"));
        assert!(msg.contains("tls_key_file"));
    }

    #[test]
    fn test_validate_config_enforcement_bad_confidence() {
        let config = ProxyConfig {
            enforcement: llmtrace_core::EnforcementConfig {
                min_confidence: 1.5,
                ..llmtrace_core::EnforcementConfig::default()
            },
            ..ProxyConfig::default()
        };
        let err = validate_config(&config).unwrap_err();
        assert!(err.to_string().contains("enforcement.min_confidence"));
    }

    #[test]
    fn test_validate_config_enforcement_negative_confidence() {
        let config = ProxyConfig {
            enforcement: llmtrace_core::EnforcementConfig {
                min_confidence: -0.1,
                ..llmtrace_core::EnforcementConfig::default()
            },
            ..ProxyConfig::default()
        };
        let err = validate_config(&config).unwrap_err();
        assert!(err.to_string().contains("enforcement.min_confidence"));
    }

    #[test]
    fn test_validate_config_enforcement_zero_timeout() {
        let config = ProxyConfig {
            enforcement: llmtrace_core::EnforcementConfig {
                timeout_ms: 0,
                ..llmtrace_core::EnforcementConfig::default()
            },
            ..ProxyConfig::default()
        };
        let err = validate_config(&config).unwrap_err();
        assert!(err.to_string().contains("enforcement.timeout_ms"));
    }

    #[test]
    fn test_validate_config_multiple_errors() {
        let config = ProxyConfig {
            listen_addr: String::new(),
            upstream_url: String::new(),
            timeout_ms: 0,
            ..ProxyConfig::default()
        };
        let err = validate_config(&config).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("listen_addr"));
        assert!(msg.contains("upstream_url"));
        assert!(msg.contains("timeout_ms"));
    }

    #[test]
    fn test_validate_config_zero_max_response_size() {
        let config = ProxyConfig {
            max_response_size_bytes: 0,
            ..ProxyConfig::default()
        };
        let err = validate_config(&config).unwrap_err();
        assert!(err
            .to_string()
            .contains("max_response_size_bytes must be > 0"));
    }

    #[test]
    fn test_validate_config_zero_max_analysis_text() {
        let config = ProxyConfig {
            security_analysis: llmtrace_core::SecurityAnalysisConfig {
                max_analysis_text_bytes: 0,
                ..llmtrace_core::SecurityAnalysisConfig::default()
            },
            ..ProxyConfig::default()
        };
        let err = validate_config(&config).unwrap_err();
        assert!(err
            .to_string()
            .contains("security_analysis.max_analysis_text_bytes must be > 0"));
    }
}