use std::{collections::HashSet, env, path::PathBuf};
use figment::{
Figment,
error::Kind,
providers::{Env, Format, Toml},
};
use serde::Deserialize;
use crate::error::{Result, TypedError};
use super::{Config, ConfigParseError};
/// Locates a TOML configuration file and exposes its content for section parsing.
///
/// The type parameter tracks the builder state: `TomlConfigProvider<()>` only knows
/// where to look for the file, while `TomlConfigProvider<String>` holds the file
/// content (possibly empty) that sections are parsed from.
#[derive(Debug)]
pub struct TomlConfigProvider<T> {
    /// Name of the file looked up in every search path (e.g. `app.toml`).
    filename: String,
    /// Directories searched for `filename`, in priority order: the first match wins.
    /// A `Vec` (not a `HashSet`) so that the priority is deterministic across runs.
    search_paths: Vec<PathBuf>,
    /// `()` before `build`, the raw TOML text afterwards.
    content: T,
}

impl TomlConfigProvider<()> {
    /// Creates a provider that looks for `filename` in the current working directory
    /// first and then in `CARGO_MANIFEST_DIR` (when that variable is set).
    ///
    /// A directory reachable through both sources is searched only once.
    #[must_use]
    pub fn new(filename: &str) -> Self {
        let mut search_paths: Vec<PathBuf> = env::current_dir()
            .ok()
            .into_iter()
            // `PathBuf::from(OsString)` keeps non-UTF-8 paths intact; a lossy
            // round-trip through `String` would silently corrupt them.
            .chain(env::var_os("CARGO_MANIFEST_DIR").map(PathBuf::from))
            .collect();
        // Order-preserving de-duplication: keep only the first occurrence of each path.
        let mut seen = HashSet::new();
        search_paths.retain(|path| seen.insert(path.clone()));
        TomlConfigProvider {
            filename: filename.to_string(),
            search_paths,
            content: (),
        }
    }

    /// Test-only constructor that bypasses the filesystem and uses `content`
    /// directly as the TOML text.
    #[cfg(test)]
    fn from(content: &str) -> TomlConfigProvider<String> {
        TomlConfigProvider {
            filename: "app.toml".to_string(),
            search_paths: Vec::new(),
            content: content.to_string(),
        }
    }

    /// Loads the configuration file from the first search path holding a readable copy.
    ///
    /// A candidate that exists but cannot be read (e.g. permission denied or invalid
    /// UTF-8) is skipped in favour of the next search path instead of producing an
    /// empty configuration. If no candidate can be read, the content is empty.
    #[must_use]
    pub fn build(self) -> TomlConfigProvider<String> {
        let content = self
            .search_paths
            .iter()
            .map(|dir| dir.join(&self.filename))
            .filter(|path| path.is_file())
            .find_map(|path| std::fs::read_to_string(path).ok())
            .unwrap_or_default();
        TomlConfigProvider {
            filename: self.filename,
            search_paths: self.search_paths,
            content,
        }
    }
}
impl TomlConfigProvider<String> {
pub fn parse<T: Config + Deserialize<'static>>(&self) -> Result<T> {
let mut figment = Figment::from(Toml::string(&self.content));
if let Some(prefix) = T::env_prefix() {
figment = figment.merge(Env::prefixed(prefix));
}
figment.extract_inner::<T>(T::name()).or_else(|e| {
match &e.kind {
Kind::MissingField(_) => {
Ok(T::default())
}
_ => Err(
ConfigParseError::error(format!("Failed to parse config section '{}'", T::name()))
.with_source(e),
),
}
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ConfigParseError;
    use crate::error::ErrorCode;

    /// Newtype whose `Default` (30) differs from `u32::default()`. This lets tests
    /// tell whether a `#[serde(default)]` field fell back or was overridden.
    #[derive(Deserialize, Debug, PartialEq)]
    struct Timeout(u32);

    impl Default for Timeout {
        fn default() -> Self {
            Self(30)
        }
    }

    /// Fixture bound to the `[server]` section, with an environment prefix.
    #[derive(Deserialize, Debug, Default)]
    struct ServerConfig {
        host: String,
        port: u16,
        #[serde(default)]
        timeout: Timeout,
    }

    impl Config for ServerConfig {
        fn name() -> &'static str {
            "server"
        }

        fn env_prefix() -> Option<&'static str> {
            Some("APP_SERVER_")
        }
    }

    /// Fixture bound to the `[storage]` section; keeps the trait's own `env_prefix`.
    #[derive(Deserialize, Debug, Default)]
    struct StorageConfig {
        layer: String,
    }

    impl Config for StorageConfig {
        fn name() -> &'static str {
            "storage"
        }
    }

    /// Parses `toml` as a `ServerConfig` through an in-memory provider.
    fn parse_server(toml: &str) -> Result<ServerConfig> {
        TomlConfigProvider::from(toml).parse::<ServerConfig>()
    }

    /// Asserts that `toml` fails to parse as a `ServerConfig` with an internal
    /// `ConfigParseError` whose message mentions the failed config parse.
    fn assert_server_parse_fails(toml: &str) {
        let outcome = parse_server(toml);
        assert!(outcome.is_err());
        let failure = outcome.unwrap_err();
        assert!(failure.is_code(ErrorCode::Internal));
        assert!(failure.is(ConfigParseError));
        assert!(failure.message().contains("Failed to parse config"));
    }

    #[test]
    fn test_config_provider() {
        let provider = TomlConfigProvider::from(
            r#"
[server]
host = "127.0.0.1"
port = 8080
[storage]
layer = "memory"
"#,
        );

        let server = provider.parse::<ServerConfig>().unwrap();
        assert_eq!(server.host, "127.0.0.1");
        assert_eq!(server.port, 8080);

        let storage = provider.parse::<StorageConfig>().unwrap();
        assert_eq!(storage.layer, "memory");
    }

    #[test]
    fn test_override_default() {
        let server = parse_server(
            r#"
[server]
host = "127.0.0.1"
port = 8080
timeout = 120
"#,
        )
        .unwrap();
        assert_eq!(server.timeout, Timeout(120));
    }

    #[test]
    fn test_missing_section() {
        let parsed = parse_server("").unwrap();
        let expected = ServerConfig::default();
        assert_eq!(parsed.port, expected.port);
        assert_eq!(parsed.host, expected.host);
        assert_eq!(parsed.timeout, expected.timeout);
    }

    #[test]
    fn test_parse_type_mismatch() {
        assert_server_parse_fails(
            r#"
[server]
host = "127.0.0.1"
port = "not a number"
"#,
        );
    }

    #[test]
    fn test_parse_invalid_toml() {
        assert_server_parse_fails(
            r#"
[server
host = "127.0.0.1"
"#,
        );
    }

    #[test]
    fn test_parse_wrong_type_in_nested_field() {
        assert_server_parse_fails(
            r#"
[server]
host = "127.0.0.1"
port = 8080
timeout = "not a number"
"#,
        );
    }

    #[test]
    fn test_parse_error_preserves_source() {
        let outcome = parse_server(
            r#"
[server]
port = "invalid"
"#,
        );
        assert!(outcome.is_err());
        let failure = outcome.unwrap_err();
        assert!(failure.is(ConfigParseError));
        assert!(failure.source().is_some());
        let cause = failure.source().unwrap();
        assert!(!cause.to_string().is_empty());
    }
}