use config::{Config, ConfigError, Environment, File};
use serde::Deserialize;
use std::collections::HashMap;
/// Top-level AI configuration, as returned by [`load_config`] / [`AiConfig::load`].
///
/// Every field except `providers` has a serde default, so `providers` is the only
/// key that must be present in the loaded configuration.
#[derive(Debug, Deserialize, Clone)]
pub struct AiConfig {
/// Name of the provider to use by default (presumably a key of `providers`);
/// defaults to `"openai"`.
#[serde(default = "default_provider")]
pub default_provider: String,
/// Per-provider settings keyed by provider name (e.g. `"openai"`). Required: there is
/// no serde default, so deserialization fails when this key is absent.
pub providers: HashMap<String, ProviderConfig>,
/// Fallback/retry settings; `FallbackConfig::default()` when the section is absent.
#[serde(default)]
pub fallback: FallbackConfig,
/// Extractor settings; `ExtractorsConfig::default()` when the section is absent.
#[serde(default)]
pub extractors: ExtractorsConfig,
/// Converter settings; `ConvertersConfig::default()` when the section is absent.
#[serde(default)]
pub converters: ConvertersConfig,
/// Timeout value; defaults to 30.
/// NOTE(review): the unit is not stated in this file — presumably seconds; confirm at
/// the call site that consumes it.
#[serde(default = "default_timeout")]
pub timeout: u64,
}
/// Settings for a single provider entry in [`AiConfig::providers`].
///
/// Only `enabled` and `model` are required. `temperature` and `max_tokens` have serde
/// defaults, and every `Option` field deserializes to `None` when omitted.
#[derive(Debug, Deserialize, Clone)]
pub struct ProviderConfig {
/// Whether this provider entry is enabled.
pub enabled: bool,
/// Model name (e.g. `"gpt-4.1-mini"`).
pub model: String,
/// Sampling temperature; defaults to `0.7`.
#[serde(default = "default_temperature")]
pub temperature: f32,
/// Token limit; defaults to `2000`.
#[serde(default = "default_max_tokens")]
pub max_tokens: u32,
// The remaining fields are optional, provider-specific settings. Which of them a given
// provider actually reads is decided by code outside this file.
/// API key, if supplied via configuration.
pub api_key: Option<String>,
/// Base URL override.
pub base_url: Option<String>,
/// Endpoint URL.
pub endpoint: Option<String>,
/// Deployment name.
pub deployment_name: Option<String>,
/// API version string.
pub api_version: Option<String>,
/// Project identifier.
pub project_id: Option<String>,
}
/// Fallback and retry settings.
///
/// Every field has a serde default, so any subset of the section may be supplied.
#[derive(Debug, Deserialize, Clone)]
pub struct FallbackConfig {
/// Whether fallback is enabled; defaults to `false`.
#[serde(default)]
pub enabled: bool,
/// Fallback order (presumably provider names); defaults to empty.
#[serde(default)]
pub order: Vec<String>,
/// Number of retry attempts; defaults to `3`.
#[serde(default = "default_retry_attempts")]
pub retry_attempts: u32,
/// Retry delay in milliseconds; defaults to `1000`.
#[serde(default = "default_retry_delay_ms")]
pub retry_delay_ms: u64,
}
impl Default for FallbackConfig {
fn default() -> Self {
Self {
enabled: false,
order: Vec::new(),
retry_attempts: default_retry_attempts(),
retry_delay_ms: default_retry_delay_ms(),
}
}
}
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ExtractorsConfig {
#[serde(default = "default_extractors")]
pub enabled: Vec<String>,
#[serde(default = "default_extractors")]
pub order: Vec<String>,
}
/// Converter selection settings.
///
/// Every field is empty when omitted, and the derived [`Default`] produces the same
/// empty values.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ConvertersConfig {
/// Enabled converters (presumably by name); empty by default.
#[serde(default)]
pub enabled: Vec<String>,
/// Converter order; empty by default.
#[serde(default)]
pub order: Vec<String>,
/// Default converter; an empty string when unset.
/// NOTE(review): `""` acts as a "none configured" sentinel — confirm consumers treat it so.
#[serde(default)]
pub default: String,
}
/// Serde default for [`AiConfig::default_provider`].
fn default_provider() -> String {
    String::from("openai")
}
/// Serde default for [`ProviderConfig::temperature`].
fn default_temperature() -> f32 {
    0.7_f32
}
/// Serde default for [`ProviderConfig::max_tokens`].
fn default_max_tokens() -> u32 {
    2_000
}
/// Serde and [`Default`] value for [`FallbackConfig::retry_attempts`].
fn default_retry_attempts() -> u32 {
    3_u32
}
/// Serde and [`Default`] value for [`FallbackConfig::retry_delay_ms`] (milliseconds).
fn default_retry_delay_ms() -> u64 {
    1_000
}
/// Serde default for both [`ExtractorsConfig`] lists: every built-in extractor name,
/// in this order.
fn default_extractors() -> Vec<String> {
    ["json_ld", "microdata", "html_class"]
        .iter()
        .map(|name| name.to_string())
        .collect()
}
/// Serde default for [`AiConfig::timeout`].
fn default_timeout() -> u64 {
    30_u64
}
impl AiConfig {
    /// Loads the configuration; a convenience alias for [`load_config`].
    ///
    /// # Errors
    ///
    /// Returns whatever [`ConfigError`] [`load_config`] produces.
    pub fn load() -> Result<Self, ConfigError> {
        let config = load_config()?;
        Ok(config)
    }
}
/// Builds an [`AiConfig`] from two layered sources:
///
/// 1. an optional file named `config` (`required(false)`, so a missing file is not an
///    error; config-rs picks the format from the file extension it finds);
/// 2. environment variables prefixed with `COOKLANG`, using `__` to separate nested
///    keys, with values parsed into numbers/booleans where possible (`try_parsing`).
///
/// The environment source is added after the file, so its values take precedence
/// under config-rs layering. NOTE(review): the tests in this file expect variables of
/// the form `COOKLANG__...`; confirm the prefix separator for the config-rs version in use.
///
/// # Errors
///
/// Returns a [`ConfigError`] if a source cannot be read or parsed, or if the merged
/// values do not deserialize into [`AiConfig`] (e.g. `providers` is missing).
pub fn load_config() -> Result<AiConfig, ConfigError> {
    let file_source = File::with_name("config").required(false);
    let env_source = Environment::with_prefix("COOKLANG")
        .separator("__")
        .try_parsing(true);

    Config::builder()
        .add_source(file_source)
        .add_source(env_source)
        .build()?
        .try_deserialize()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes an [`AiConfig`] from in-memory overrides only, so the result does
    /// not depend on the working directory (a stray `config.*` file) or on the process
    /// environment.
    fn config_from_overrides(
        overrides: &[(&str, config::Value)],
    ) -> Result<AiConfig, ConfigError> {
        let builder = overrides
            .iter()
            .try_fold(Config::builder(), |builder, (key, value)| {
                builder.set_override(*key, value.clone())
            })?;
        builder.build()?.try_deserialize()
    }

    #[test]
    fn test_default_values() {
        assert_eq!(default_provider(), "openai");
        assert_eq!(default_temperature(), 0.7);
        assert_eq!(default_max_tokens(), 2000);
        assert_eq!(default_retry_attempts(), 3);
        assert_eq!(default_retry_delay_ms(), 1000);
        assert_eq!(default_timeout(), 30);
        assert_eq!(
            default_extractors(),
            vec!["json_ld", "microdata", "html_class"]
        );
    }

    #[test]
    fn test_fallback_config_default() {
        let fallback = FallbackConfig::default();
        assert!(!fallback.enabled);
        assert!(fallback.order.is_empty());
        assert_eq!(fallback.retry_attempts, 3);
        assert_eq!(fallback.retry_delay_ms, 1000);
    }

    #[test]
    fn test_provider_config_has_optional_fields() {
        let config = ProviderConfig {
            enabled: true,
            model: "gpt-4.1-mini".to_string(),
            temperature: 0.7,
            max_tokens: 2000,
            api_key: None,
            base_url: None,
            endpoint: None,
            deployment_name: None,
            api_version: None,
            project_id: None,
        };
        assert!(config.api_key.is_none());
        assert!(config.base_url.is_none());
    }

    // Replaces a `load_config` smoke test that asserted `is_ok() || is_err()`, which
    // can never fail. It also called `env::remove_var`, which races with tests running
    // in parallel. This version asserts the real contract: `providers` has no serde
    // default, so an empty configuration must be rejected.
    #[test]
    fn test_missing_providers_is_rejected() {
        let result = config_from_overrides(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_deserialize_applies_defaults() -> Result<(), ConfigError> {
        let config = config_from_overrides(&[
            ("providers.openai.enabled", true.into()),
            ("providers.openai.model", "gpt-4.1-mini".into()),
        ])?;

        assert_eq!(config.default_provider, "openai");
        assert_eq!(config.timeout, 30);
        assert!(!config.fallback.enabled);
        assert!(config.fallback.order.is_empty());
        assert_eq!(config.fallback.retry_attempts, 3);
        assert_eq!(config.fallback.retry_delay_ms, 1000);
        assert!(config.converters.enabled.is_empty());
        assert!(config.converters.default.is_empty());

        let openai = config
            .providers
            .get("openai")
            .expect("override created the `openai` provider entry");
        assert!(openai.enabled);
        assert_eq!(openai.model, "gpt-4.1-mini");
        assert_eq!(openai.temperature, 0.7);
        assert_eq!(openai.max_tokens, 2000);
        assert!(openai.api_key.is_none());
        Ok(())
    }

    #[test]
    fn test_partial_fallback_section_keeps_field_defaults() -> Result<(), ConfigError> {
        let config = config_from_overrides(&[
            ("providers.openai.enabled", false.into()),
            ("providers.openai.model", "m".into()),
            ("fallback.enabled", true.into()),
        ])?;

        assert!(config.fallback.enabled);
        assert!(config.fallback.order.is_empty());
        assert_eq!(config.fallback.retry_attempts, 3);
        assert_eq!(config.fallback.retry_delay_ms, 1000);
        Ok(())
    }

    #[test]
    fn test_ai_config_structure() {
        let mut providers = HashMap::new();
        providers.insert(
            "openai".to_string(),
            ProviderConfig {
                enabled: true,
                model: "gpt-4.1-mini".to_string(),
                temperature: 0.7,
                max_tokens: 2000,
                api_key: Some("test-key".to_string()),
                base_url: None,
                endpoint: None,
                deployment_name: None,
                api_version: None,
                project_id: None,
            },
        );
        let config = AiConfig {
            default_provider: "openai".to_string(),
            providers,
            fallback: FallbackConfig::default(),
            extractors: ExtractorsConfig::default(),
            converters: ConvertersConfig::default(),
            timeout: default_timeout(),
        };
        assert_eq!(config.default_provider, "openai");
        assert_eq!(config.providers.len(), 1);
        assert!(config.providers.contains_key("openai"));
    }
}