#![cfg(feature = "config")]
use config::{Config as Cfg, Environment, File, FileFormat};
use regex::Regex;
use serde::de::DeserializeOwned;
use std::path::Path;
/// Errors produced while locating, reading, building, or deserializing configuration.
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
/// An I/O operation failed; converted automatically from `std::io::Error` via `#[from]`.
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
/// Building the configuration from its sources failed; holds the underlying error text.
#[error("Config parsing error: {0}")]
Parse(String),
/// The file has no usable extension, or the extension is not a recognized format.
#[error("Unsupported format: {0}")]
UnsupportedFormat(String),
/// Converting the built configuration into the requested type failed; holds the error text.
#[error("Serialization error: {0}")]
Serialization(String),
}
/// Shorthand for a `Result` whose error type is [`ConfigError`].
pub type ConfigResult<T> = Result<T, ConfigError>;
/// Determines the configuration format from the extension of `path`.
///
/// The extension is compared case-insensitively against `yaml`/`yml`, `toml`,
/// `json`, `ini`, `ron` and `json5`.
///
/// # Errors
///
/// Returns [`ConfigError::UnsupportedFormat`] when `path` has no extension that
/// can be read as UTF-8, or when the extension is not one of the names above
/// (the error then carries the extension exactly as written in `path`).
pub fn detect_format(path: &str) -> ConfigResult<FileFormat> {
    let extension = match Path::new(path).extension().and_then(|raw| raw.to_str()) {
        Some(text) => text,
        None => {
            return Err(ConfigError::UnsupportedFormat(
                "No file extension found".to_string(),
            ))
        }
    };

    // Match on a lowercased copy, but report the extension as the caller wrote it.
    let normalized = extension.to_lowercase();
    let format = match normalized.as_str() {
        "yaml" | "yml" => FileFormat::Yaml,
        "toml" => FileFormat::Toml,
        "json" => FileFormat::Json,
        "ini" => FileFormat::Ini,
        "ron" => FileFormat::Ron,
        "json5" => FileFormat::Json5,
        _ => return Err(ConfigError::UnsupportedFormat(extension.to_string())),
    };
    Ok(format)
}
pub fn substitute_env_vars(content: &str) -> String {
let mut result = content.to_string();
let re_braced = Regex::new(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}").unwrap();
result = re_braced
.replace_all(&result, |caps: ®ex::Captures| {
let var_name = &caps[1];
std::env::var(var_name).unwrap_or_else(|_| caps[0].to_string())
})
.to_string();
let re_simple = Regex::new(r"\$([A-Za-z_][A-Za-z0-9_]*)\b").unwrap();
result = re_simple
.replace_all(&result, |caps: ®ex::Captures| {
let var_name = &caps[1];
std::env::var(var_name).unwrap_or_else(|_| caps[0].to_string())
})
.to_string();
result
}
pub fn load_config<T>(path: &str) -> ConfigResult<T>
where
T: DeserializeOwned,
{
let format = detect_format(path)?;
let content = std::fs::read_to_string(path)?;
let substituted_content = substitute_env_vars(&content);
let config = Cfg::builder()
.add_source(File::from_str(&substituted_content, format))
.build()
.map_err(|e| ConfigError::Parse(e.to_string()))?;
config
.try_deserialize()
.map_err(|e| ConfigError::Serialization(e.to_string()))
}
/// Deserializes `T` from in-memory configuration text in the given `format`.
///
/// Environment references in `content` are expanded with
/// [`substitute_env_vars`] before the text is parsed.
///
/// # Errors
///
/// * [`ConfigError::Parse`] if the configuration cannot be built from the text.
/// * [`ConfigError::Serialization`] if the result cannot be converted into `T`.
pub fn from_str<T>(content: &str, format: FileFormat) -> ConfigResult<T>
where
    T: DeserializeOwned,
{
    let expanded = substitute_env_vars(content);
    let source = File::from_str(&expanded, format);

    match Cfg::builder().add_source(source).build() {
        Ok(built) => built
            .try_deserialize()
            .map_err(|err| ConfigError::Serialization(err.to_string())),
        Err(err) => Err(ConfigError::Parse(err.to_string())),
    }
}
/// Layers several in-memory configuration texts and deserializes the result into `T`.
///
/// Each `(content, format)` pair has its environment references expanded with
/// [`substitute_env_vars`] and is added to the builder in slice order; the
/// accompanying unit test relies on a later source overriding keys of an
/// earlier one.
///
/// # Errors
///
/// * [`ConfigError::Parse`] if the configuration cannot be built from the sources.
/// * [`ConfigError::Serialization`] if the result cannot be converted into `T`.
pub fn merge_configs<T>(sources: &[(&str, FileFormat)]) -> ConfigResult<T>
where
    T: DeserializeOwned,
{
    let layered = sources.iter().fold(Cfg::builder(), |acc, (text, kind)| {
        let expanded = substitute_env_vars(text);
        acc.add_source(File::from_str(&expanded, *kind))
    });

    let merged = layered
        .build()
        .map_err(|err| ConfigError::Parse(err.to_string()))?;

    merged
        .try_deserialize()
        .map_err(|err| ConfigError::Serialization(err.to_string()))
}
/// Reads several configuration files, layers them, and deserializes the result into `T`.
///
/// For each path, in order: the format is detected from the extension, the
/// file is read, environment references are expanded, and the text is added
/// as a source. Processing stops at the first path that fails.
///
/// # Errors
///
/// * [`ConfigError::UnsupportedFormat`] if a path has a missing or unknown extension.
/// * [`ConfigError::Io`] if a file cannot be read.
/// * [`ConfigError::Parse`] if the configuration cannot be built from the sources.
/// * [`ConfigError::Serialization`] if the result cannot be converted into `T`.
pub fn load_merged<T>(paths: &[&str]) -> ConfigResult<T>
where
    T: DeserializeOwned,
{
    let layered = paths
        .iter()
        .try_fold(Cfg::builder(), |acc, path| -> ConfigResult<_> {
            let kind = detect_format(path)?;
            let raw = std::fs::read_to_string(path)?;
            let expanded = substitute_env_vars(&raw);
            Ok(acc.add_source(File::from_str(&expanded, kind)))
        })?;

    let merged = layered
        .build()
        .map_err(|err| ConfigError::Parse(err.to_string()))?;

    merged
        .try_deserialize()
        .map_err(|err| ConfigError::Serialization(err.to_string()))
}
/// Loads the file at `path` and layers an environment-variable source on top of it.
///
/// The file is read and its environment references expanded exactly as in
/// [`load_config`]. An `Environment` source configured with `env_prefix` and a
/// `"__"` separator is then added after the file source.
///
/// NOTE(review): how the prefix itself is separated from the key (`_` vs `__`)
/// depends on the `config` crate version in use — confirm against its docs.
///
/// # Errors
///
/// * [`ConfigError::UnsupportedFormat`] if the extension is missing or unknown.
/// * [`ConfigError::Io`] if the file cannot be read.
/// * [`ConfigError::Parse`] if the configuration cannot be built from the sources.
/// * [`ConfigError::Serialization`] if the result cannot be converted into `T`.
pub fn load_with_env<T>(path: &str, env_prefix: &str) -> ConfigResult<T>
where
    T: DeserializeOwned,
{
    let kind = detect_format(path)?;
    let raw = std::fs::read_to_string(path)?;

    let file_source = File::from_str(&substitute_env_vars(&raw), kind);
    let env_source = Environment::with_prefix(env_prefix).separator("__");

    let built = Cfg::builder()
        .add_source(file_source)
        .add_source(env_source)
        .build()
        .map_err(|err| ConfigError::Parse(err.to_string()))?;

    built
        .try_deserialize()
        .map_err(|err| ConfigError::Serialization(err.to_string()))
}
#[cfg(all(test, feature = "config"))]
mod unit_tests {
    use super::*;

    // Every recognized extension maps to its format; an unknown one is rejected.
    #[test]
    fn test_detect_format() {
        let recognized = [
            ("config.yaml", FileFormat::Yaml),
            ("config.yml", FileFormat::Yaml),
            ("config.toml", FileFormat::Toml),
            ("config.json", FileFormat::Json),
            ("config.ini", FileFormat::Ini),
            ("config.ron", FileFormat::Ron),
            ("config.json5", FileFormat::Json5),
        ];
        for (path, expected) in recognized.iter() {
            assert_eq!(detect_format(path).unwrap(), *expected);
        }

        assert!(detect_format("config.txt").is_err());
    }

    // TOML text with three string keys deserializes into a flat struct.
    #[test]
    fn test_from_str_toml() {
        #[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq)]
        struct TestConfig {
            id: String,
            name: String,
            model: String,
        }

        let source = r#"
id = "test-agent"
name = "Test Agent"
model = "gpt-4"
"#;

        let parsed = from_str::<TestConfig>(source, FileFormat::Toml).unwrap();
        let expected = TestConfig {
            id: "test-agent".to_string(),
            name: "Test Agent".to_string(),
            model: "gpt-4".to_string(),
        };
        assert_eq!(parsed, expected);
    }

    // A JSON object deserializes into a flat struct.
    #[test]
    fn test_from_str_json() {
        #[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq)]
        struct TestConfig {
            id: String,
            name: String,
        }

        let source = r#"
{
"id": "test-agent",
"name": "Test Agent"
}
"#;

        let parsed = from_str::<TestConfig>(source, FileFormat::Json).unwrap();
        let expected = TestConfig {
            id: "test-agent".to_string(),
            name: "Test Agent".to_string(),
        };
        assert_eq!(parsed, expected);
    }

    // A YAML mapping deserializes into a flat struct.
    #[test]
    fn test_from_str_yaml() {
        #[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq)]
        struct TestConfig {
            id: String,
            name: String,
        }

        let source = r#"
id: test-agent
name: Test Agent
"#;

        let parsed = from_str::<TestConfig>(source, FileFormat::Yaml).unwrap();
        let expected = TestConfig {
            id: "test-agent".to_string(),
            name: "Test Agent".to_string(),
        };
        assert_eq!(parsed, expected);
    }

    // INI keys written as `default.<key>` deserialize into a nested `default` section.
    #[test]
    fn test_from_str_ini() {
        #[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq)]
        struct IniSection {
            id: String,
            name: String,
        }
        #[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq)]
        struct IniConfig {
            default: IniSection,
        }

        let source = r#"
default.id = "test-agent"
default.name = "Test Agent"
"#;

        let parsed = from_str::<IniConfig>(source, FileFormat::Ini).unwrap();
        let expected = IniConfig {
            default: IniSection {
                id: "test-agent".to_string(),
                name: "Test Agent".to_string(),
            },
        };
        assert_eq!(parsed, expected);
    }

    // A RON struct literal deserializes into a flat struct.
    #[test]
    fn test_from_str_ron() {
        #[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq)]
        struct TestConfig {
            id: String,
            name: String,
        }

        let source = r#"
(
id: "test-agent",
name: "Test Agent",
)
"#;

        let parsed = from_str::<TestConfig>(source, FileFormat::Ron).unwrap();
        let expected = TestConfig {
            id: "test-agent".to_string(),
            name: "Test Agent".to_string(),
        };
        assert_eq!(parsed, expected);
    }

    // JSON5 text with a comment, unquoted keys and trailing commas is accepted.
    #[test]
    fn test_from_str_json5() {
        #[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq)]
        struct TestConfig {
            id: String,
            name: String,
        }

        let source = r#"
{
// JSON5 comment
id: "test-agent",
name: "Test Agent",
}
"#;

        let parsed = from_str::<TestConfig>(source, FileFormat::Json5).unwrap();
        let expected = TestConfig {
            id: "test-agent".to_string(),
            name: "Test Agent".to_string(),
        };
        assert_eq!(parsed, expected);
    }

    // A later source overrides the keys it defines and leaves the others untouched.
    #[test]
    fn test_merge_configs() {
        #[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq)]
        struct TestConfig {
            id: String,
            name: String,
            model: String,
        }

        let base_layer = r#"
{
"id": "base-agent",
"name": "Base Name",
"model": "gpt-3.5"
}
"#;
        let override_layer = r#"
{
"model": "gpt-4"
}
"#;

        let layers = [
            (base_layer, FileFormat::Json),
            (override_layer, FileFormat::Json),
        ];
        let merged = merge_configs::<TestConfig>(&layers).unwrap();
        let expected = TestConfig {
            id: "base-agent".to_string(),
            name: "Base Name".to_string(),
            model: "gpt-4".to_string(),
        };
        assert_eq!(merged, expected);
    }
}
// Out-of-line `tests` module (defined in a separate file), compiled only for test builds.
#[cfg(test)]
mod tests;