use crate::hash_with_indifferent_access::HashWithIndifferentAccess;
use serde::de::DeserializeOwned;
use serde_json::{Map, Number, Value};
use std::env;
use std::fs;
use std::path::Path;
/// Errors reported while loading, parsing, or extracting configuration.
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
/// The configuration file does not exist; the payload is the displayed path.
#[error("file not found: {0}")]
FileNotFound(String),
/// The input could not be turned into a configuration. The payload is the underlying
/// error text. Besides TOML syntax errors this variant is also used for file read
/// failures other than "not found" and for deserialization errors that are not a
/// missing field.
#[error("parse error: {0}")]
ParseError(String),
/// Typed extraction needed a field that is absent; the payload is the field name
/// taken from the deserializer's "missing field" message.
#[error("missing key: {0}")]
MissingKey(String),
}
/// In-memory configuration tree addressed by dotted paths such as `database.host`.
///
/// Values are stored as `serde_json::Value`s inside a `HashWithIndifferentAccess`
/// (defined elsewhere in this crate).
#[derive(Debug, Clone, Default)]
pub struct Config {
// Root table of the configuration; nested tables are JSON objects inside it.
values: HashWithIndifferentAccess,
}
impl Config {
#[must_use]
pub fn new() -> Self {
Self {
values: HashWithIndifferentAccess::new(),
}
}
pub fn from_toml(content: &str) -> Result<Self, ConfigError> {
let parsed: toml::Value =
toml::from_str(content).map_err(|error| ConfigError::ParseError(error.to_string()))?;
let json = serde_json::to_value(parsed)
.map_err(|error| ConfigError::ParseError(error.to_string()))?;
match json {
Value::Object(map) => Ok(Self {
values: HashWithIndifferentAccess::from(map),
}),
_ => Err(ConfigError::ParseError(
"top-level TOML value must be a table".to_owned(),
)),
}
}
pub fn from_file(path: &Path) -> Result<Self, ConfigError> {
let content = fs::read_to_string(path).map_err(|error| {
if error.kind() == std::io::ErrorKind::NotFound {
ConfigError::FileNotFound(path.display().to_string())
} else {
ConfigError::ParseError(error.to_string())
}
})?;
Self::from_toml(&content)
}
#[must_use]
pub fn get(&self, key: &str) -> Option<&Value> {
get_path(self.values.as_index_map(), key)
}
#[must_use]
pub fn get_str(&self, key: &str) -> Option<&str> {
self.get(key).and_then(Value::as_str)
}
#[must_use]
pub fn get_i64(&self, key: &str) -> Option<i64> {
self.get(key).and_then(Value::as_i64)
}
#[must_use]
pub fn get_bool(&self, key: &str) -> Option<bool> {
self.get(key).and_then(Value::as_bool)
}
pub fn set(&mut self, key: impl Into<String>, value: Value) {
set_path(self.values.as_index_map_mut(), &key.into(), value);
}
pub fn apply_env_overrides(&mut self, prefix: &str) {
let prefix = format!("{prefix}_");
for (env_key, env_value) in env::vars() {
if !env_key.starts_with(&prefix) {
continue;
}
let suffix = &env_key[prefix.len()..];
if suffix.is_empty() {
continue;
}
let dotted_key = suffix
.split('_')
.filter(|segment| !segment.is_empty())
.map(str::to_ascii_lowercase)
.collect::<Vec<_>>()
.join(".");
if dotted_key.is_empty() {
continue;
}
self.set(dotted_key, parse_env_value(&env_value));
}
}
pub fn extract<T: DeserializeOwned>(&self) -> Result<T, ConfigError> {
let root = Value::Object(
self.values
.as_index_map()
.iter()
.map(|(key, value)| (key.clone(), value.clone()))
.collect::<Map<String, Value>>(),
);
serde_json::from_value(root).map_err(|error| {
match parse_missing_field(&error.to_string()) {
Some(field) => ConfigError::MissingKey(field),
None => ConfigError::ParseError(error.to_string()),
}
})
}
}
/// Resolves the dotted path `key` against `root`.
///
/// The first segment is looked up in `root`; every following segment is looked up in
/// the object found so far. Yields `None` as soon as a segment is absent or the current
/// value is not an object. Segments are taken verbatim from `key.split('.')`, so empty
/// segments are looked up as empty keys rather than skipped.
fn get_path<'a>(root: &'a indexmap::IndexMap<String, Value>, key: &str) -> Option<&'a Value> {
    let mut names = key.split('.');
    let top = root.get(names.next()?)?;
    names.try_fold(top, |node, name| node.as_object()?.get(name))
}
/// Writes `value` at the dotted path `key` inside `root`.
///
/// Empty segments in `key` are ignored; a key with no non-empty segment is a no-op.
/// Missing intermediate objects are created on the way down. A single-segment key
/// always inserts (or replaces) the top-level entry.
///
/// If any existing value on the path (other than the final leaf) is not an object,
/// the write is dropped and the existing value is left untouched. Previously only a
/// scalar directly above the leaf was preserved, while a scalar higher up the path
/// was silently overwritten with an empty object; both cases now behave the same.
fn set_path(root: &mut indexmap::IndexMap<String, Value>, key: &str, value: Value) {
    let parts: Vec<&str> = key
        .split('.')
        .filter(|segment| !segment.is_empty())
        .collect();
    let Some((leaf, parents)) = parts.split_last() else {
        return;
    };
    let Some((first, rest)) = parents.split_first() else {
        // Only one segment: plain top-level insert.
        root.insert((*leaf).to_owned(), value);
        return;
    };
    let mut current = root
        .entry((*first).to_owned())
        .or_insert_with(|| Value::Object(Map::new()));
    for segment in rest {
        // Never clobber an existing non-object value that sits on the path.
        let Value::Object(map) = current else {
            return;
        };
        current = map
            .entry((*segment).to_owned())
            .or_insert_with(|| Value::Object(Map::new()));
    }
    if let Value::Object(map) = current {
        map.insert((*leaf).to_owned(), value);
    }
}
/// Converts the text of an environment variable into a JSON value.
///
/// Candidates are tried in this order, first match wins:
/// 1. `true` / `false` in any ASCII letter case become booleans;
/// 2. text that parses as `i64`, then `u64`, then a finite `f64` becomes a number;
/// 3. text wrapped in `[...]` or `{...}` that is valid JSON becomes that JSON value;
/// 4. anything else is kept as a string.
fn parse_env_value(value: &str) -> Value {
    let is_true = value.eq_ignore_ascii_case("true");
    if is_true || value.eq_ignore_ascii_case("false") {
        return Value::Bool(is_true);
    }

    let number = value
        .parse::<i64>()
        .map(Number::from)
        .ok()
        .or_else(|| value.parse::<u64>().map(Number::from).ok())
        // `Number::from_f64` rejects non-finite floats, which then fall through.
        .or_else(|| value.parse::<f64>().ok().and_then(Number::from_f64));
    if let Some(number) = number {
        return Value::Number(number);
    }

    let looks_like_array = value.starts_with('[') && value.ends_with(']');
    let looks_like_object = value.starts_with('{') && value.ends_with('}');
    (looks_like_array || looks_like_object)
        .then(|| serde_json::from_str::<Value>(value).ok())
        .flatten()
        .unwrap_or_else(|| Value::String(value.to_owned()))
}
/// Extracts the field name from a message of the form ``missing field `name` ...``.
///
/// Returns `None` when the message does not start with that prefix or when the
/// closing backtick is absent.
fn parse_missing_field(message: &str) -> Option<String> {
    let after_prefix = message.strip_prefix("missing field `")?;
    let (field, _trailing) = after_prefix.split_once('`')?;
    Some(field.to_owned())
}
// Unit tests for `Config`: TOML loading, dotted-path get/set, environment overrides,
// and typed extraction via serde.
//
// NOTE(review): the environment-override tests mutate process-wide environment
// variables inside `unsafe` blocks, and the file tests call `std::env::temp_dir()`.
// If the test harness runs these tests on parallel threads, those accesses can overlap
// — confirm the suite is run single-threaded or serialize the environment access.
// Every variable set below starts with `RUSTRAILS_`, so the two tests that use the bare
// `RUSTRAILS` prefix may also pick up variables set by the other tests (as keys under
// `support.cfg...`), which their assertions do not inspect.
#[cfg(test)]
mod tests {
use super::{Config, ConfigError};
use serde::Deserialize;
use serde_json::json;
use std::fs;
use std::time::{SystemTime, UNIX_EPOCH};
// Typed target used by the `extract` tests.
#[derive(Debug, Deserialize, PartialEq)]
struct AppConfig {
app_name: String,
port: i64,
debug: bool,
database: DatabaseConfig,
}
// Nested section of `AppConfig`.
#[derive(Debug, Deserialize, PartialEq)]
struct DatabaseConfig {
host: String,
pool: i64,
}
// Top-level keys and one nested table are readable through dotted paths.
#[test]
fn config_from_toml_reads_top_level_and_nested_values() {
let config = Config::from_toml(
r#"
app_name = "rustrails"
port = 3000
debug = true
[database]
host = "localhost"
pool = 5
"#,
)
.unwrap();
assert_eq!(config.get_str("app_name"), Some("rustrails"));
assert_eq!(config.get_i64("port"), Some(3000));
assert_eq!(config.get_bool("debug"), Some(true));
assert_eq!(config.get_str("database.host"), Some("localhost"));
assert_eq!(config.get_i64("database.pool"), Some(5));
}
// `set` on an empty config creates the intermediate object.
#[test]
fn config_set_creates_nested_paths() {
let mut config = Config::new();
config.set("service.name", json!("api"));
config.set("service.enabled", json!(true));
assert_eq!(config.get_str("service.name"), Some("api"));
assert_eq!(config.get_bool("service.enabled"), Some(true));
}
// Round-trips a TOML file written to the temp directory; the nanosecond timestamp
// keeps the file name unique per run.
#[test]
fn config_from_file_reads_toml() {
let unique = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos();
let path = std::env::temp_dir().join(format!("rustrails-support-config-{unique}.toml"));
fs::write(&path, "name = \"support\"\n").unwrap();
let config = Config::from_file(&path).unwrap();
fs::remove_file(&path).unwrap();
assert_eq!(config.get_str("name"), Some("support"));
}
// A path that is never created must map to `FileNotFound`.
#[test]
fn config_from_file_returns_not_found_error_for_missing_file() {
let path = std::env::temp_dir().join("rustrails-support-missing.toml");
let error = Config::from_file(&path).unwrap_err();
assert!(matches!(error, ConfigError::FileNotFound(_)));
}
// Environment values replace existing keys, reach into nested tables, and add new keys.
#[test]
fn config_apply_env_overrides_replaces_existing_values() {
let mut config = Config::from_toml(
r#"
app_name = "rustrails"
port = 3000
[database]
host = "localhost"
pool = 5
"#,
)
.unwrap();
unsafe {
std::env::set_var("RUSTRAILS_PORT", "4000");
std::env::set_var("RUSTRAILS_DATABASE_HOST", "db.internal");
std::env::set_var("RUSTRAILS_DEBUG", "true");
}
config.apply_env_overrides("RUSTRAILS");
unsafe {
std::env::remove_var("RUSTRAILS_PORT");
std::env::remove_var("RUSTRAILS_DATABASE_HOST");
std::env::remove_var("RUSTRAILS_DEBUG");
}
assert_eq!(config.get_i64("port"), Some(4000));
assert_eq!(config.get_str("database.host"), Some("db.internal"));
assert_eq!(config.get_bool("debug"), Some(true));
}
// A bracketed value that is valid JSON is stored as a JSON array.
#[test]
fn config_apply_env_overrides_parses_json_values() {
let mut config = Config::new();
unsafe {
std::env::set_var("RUSTRAILS_FEATURES", "[\"cache\",\"jobs\"]");
}
config.apply_env_overrides("RUSTRAILS");
unsafe {
std::env::remove_var("RUSTRAILS_FEATURES");
}
assert_eq!(config.get("features"), Some(&json!(["cache", "jobs"])));
}
// A complete document deserializes into the typed struct.
#[test]
fn config_extract_deserializes_into_typed_struct() {
let config = Config::from_toml(
r#"
app_name = "rustrails"
port = 3000
debug = false
[database]
host = "db.internal"
pool = 7
"#,
)
.unwrap();
let extracted: AppConfig = config.extract().unwrap();
assert_eq!(
extracted,
AppConfig {
app_name: String::from("rustrails"),
port: 3000,
debug: false,
database: DatabaseConfig {
host: String::from("db.internal"),
pool: 7,
},
}
);
}
// The first absent field (`port`) is reported through `MissingKey`.
#[test]
fn config_extract_reports_missing_keys() {
let config = Config::from_toml("app_name = \"rustrails\"\n").unwrap();
let error = config.extract::<AppConfig>().unwrap_err();
assert!(matches!(error, ConfigError::MissingKey(key) if key == "port"));
}
// Lookups on an empty config yield `None` for both raw and typed getters.
#[test]
fn config_get_returns_none_for_missing_key() {
let config = Config::new();
assert_eq!(config.get("missing.key"), None);
assert_eq!(config.get_str("missing.key"), None);
}
// `Default` and `new` agree on a missing key.
#[test]
fn config_default_matches_new_for_missing_values() {
let config = Config::default();
assert_eq!(config.get("missing"), Config::new().get("missing"));
}
// TOML arrays come through as JSON arrays.
#[test]
fn config_from_toml_reads_array_values() {
let config = Config::from_toml(
r#"
features = ["cache", "jobs"]
"#,
)
.unwrap();
assert_eq!(config.get("features"), Some(&json!(["cache", "jobs"])));
}
// Dotted lookup works through three levels of tables.
#[test]
fn config_from_toml_reads_three_level_nested_values() {
let config = Config::from_toml(
r#"
[database]
[database.primary]
[database.primary.credentials]
user = "postgres"
"#,
)
.unwrap();
assert_eq!(
config.get_str("database.primary.credentials.user"),
Some("postgres")
);
}
// An unterminated table header is a parse error.
#[test]
fn config_from_toml_reports_parse_error_for_invalid_syntax() {
let error = Config::from_toml("[database\nhost = \"localhost\"").unwrap_err();
assert!(matches!(error, ConfigError::ParseError(_)));
}
// A bare string is not a valid document root.
#[test]
fn config_from_toml_reports_parse_error_for_non_table_root() {
let error = Config::from_toml("\"support\"").unwrap_err();
assert!(matches!(error, ConfigError::ParseError(_)));
}
// Invalid TOML read from disk surfaces as `ParseError`, not `FileNotFound`.
#[test]
fn config_from_file_reports_parse_error_for_invalid_toml() {
let unique = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos();
let path =
std::env::temp_dir().join(format!("rustrails-support-config-invalid-{unique}.toml"));
fs::write(&path, "[database\nhost = \"localhost\"").unwrap();
let error = Config::from_file(&path).unwrap_err();
fs::remove_file(&path).unwrap();
assert!(matches!(error, ConfigError::ParseError(_)));
}
// Nested tables survive the file round trip.
#[test]
fn config_from_file_reads_nested_values() {
let unique = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos();
let path =
std::env::temp_dir().join(format!("rustrails-support-config-nested-{unique}.toml"));
fs::write(
&path,
"[database]\nhost = \"localhost\"\n[database.credentials]\nuser = \"postgres\"\n",
)
.unwrap();
let config = Config::from_file(&path).unwrap();
fs::remove_file(&path).unwrap();
assert_eq!(config.get_str("database.host"), Some("localhost"));
assert_eq!(
config.get_str("database.credentials.user"),
Some("postgres")
);
}
// A second `set` under the same parent keeps the first sibling.
#[test]
fn config_set_preserves_existing_nested_siblings() {
let mut config = Config::new();
config.set("database.host", json!("localhost"));
config.set("database.pool", json!(5));
assert_eq!(config.get_str("database.host"), Some("localhost"));
assert_eq!(config.get_i64("database.pool"), Some(5));
}
// Setting a child below an existing scalar leaves the scalar in place and stores nothing.
#[test]
fn config_set_deeper_path_does_not_replace_existing_scalar() {
let mut config = Config::new();
config.set("database", json!("localhost"));
config.set("database.host", json!("db.internal"));
assert_eq!(config.get_str("database"), Some("localhost"));
assert_eq!(config.get_str("database.host"), None);
}
// An empty key stores nothing.
#[test]
fn config_set_ignores_empty_path() {
let mut config = Config::new();
config.set("", json!("ignored"));
assert_eq!(config.get(""), None);
assert!(config.get("anything").is_none());
}
// A path cannot descend through a non-object value.
#[test]
fn config_get_returns_none_when_descending_into_scalar() {
let config = Config::from_toml("port = 3000\n").unwrap();
assert_eq!(config.get("port.value"), None);
}
// Typed getters return `None` on a type mismatch rather than converting.
#[test]
fn config_get_str_returns_none_for_non_string_value() {
let config = Config::from_toml("port = 3000\n").unwrap();
assert_eq!(config.get_str("port"), None);
}
#[test]
fn config_get_i64_returns_none_for_non_integer_value() {
let config = Config::from_toml("debug = true\n").unwrap();
assert_eq!(config.get_i64("debug"), None);
}
#[test]
fn config_get_bool_returns_none_for_non_boolean_value() {
let config = Config::from_toml("app_name = \"rustrails\"\n").unwrap();
assert_eq!(config.get_bool("app_name"), None);
}
// The tests below each use their own `RUSTRAILS_SUPPORT_CFG_*` prefix.
// An override creates the nested table when it does not exist yet.
#[test]
fn config_apply_env_overrides_creates_nested_values() {
let mut config = Config::new();
unsafe {
std::env::set_var("RUSTRAILS_SUPPORT_CFG_CREATE_DATABASE_HOST", "db.internal");
}
config.apply_env_overrides("RUSTRAILS_SUPPORT_CFG_CREATE");
unsafe {
std::env::remove_var("RUSTRAILS_SUPPORT_CFG_CREATE_DATABASE_HOST");
}
assert_eq!(config.get_str("database.host"), Some("db.internal"));
}
// An override adds a sibling without disturbing existing keys in the same table.
#[test]
fn config_apply_env_overrides_preserves_unrelated_values() {
let mut config = Config::new();
config.set("database.pool", json!(5));
unsafe {
std::env::set_var("RUSTRAILS_SUPPORT_CFG_KEEP_DATABASE_HOST", "db.internal");
}
config.apply_env_overrides("RUSTRAILS_SUPPORT_CFG_KEEP");
unsafe {
std::env::remove_var("RUSTRAILS_SUPPORT_CFG_KEEP_DATABASE_HOST");
}
assert_eq!(config.get_i64("database.pool"), Some(5));
assert_eq!(config.get_str("database.host"), Some("db.internal"));
}
// The text `false` becomes a boolean.
#[test]
fn config_apply_env_overrides_parses_false_values() {
let mut config = Config::new();
unsafe {
std::env::set_var("RUSTRAILS_SUPPORT_CFG_FALSE_DEBUG", "false");
}
config.apply_env_overrides("RUSTRAILS_SUPPORT_CFG_FALSE");
unsafe {
std::env::remove_var("RUSTRAILS_SUPPORT_CFG_FALSE_DEBUG");
}
assert_eq!(config.get_bool("debug"), Some(false));
}
// `u64::MAX` does not fit in `i64` and must be kept as an unsigned number.
#[test]
fn config_apply_env_overrides_parses_unsigned_integer_values() {
let mut config = Config::new();
unsafe {
std::env::set_var(
"RUSTRAILS_SUPPORT_CFG_UNSIGNED_LIMIT",
"18446744073709551615",
);
}
config.apply_env_overrides("RUSTRAILS_SUPPORT_CFG_UNSIGNED");
unsafe {
std::env::remove_var("RUSTRAILS_SUPPORT_CFG_UNSIGNED_LIMIT");
}
assert_eq!(config.get("limit"), Some(&json!(18446744073709551615_u64)));
}
// Decimal text becomes a floating-point number.
#[test]
fn config_apply_env_overrides_parses_float_values() {
let mut config = Config::new();
unsafe {
std::env::set_var("RUSTRAILS_SUPPORT_CFG_FLOAT_RATIO", "3.5");
}
config.apply_env_overrides("RUSTRAILS_SUPPORT_CFG_FLOAT");
unsafe {
std::env::remove_var("RUSTRAILS_SUPPORT_CFG_FLOAT_RATIO");
}
assert_eq!(config.get("ratio"), Some(&json!(3.5)));
}
// A braced JSON object becomes a nested table reachable by dotted path.
#[test]
fn config_apply_env_overrides_parses_object_json() {
let mut config = Config::new();
unsafe {
std::env::set_var(
"RUSTRAILS_SUPPORT_CFG_OBJECT_DATABASE",
r#"{"host":"db.internal"}"#,
);
}
config.apply_env_overrides("RUSTRAILS_SUPPORT_CFG_OBJECT");
unsafe {
std::env::remove_var("RUSTRAILS_SUPPORT_CFG_OBJECT_DATABASE");
}
assert_eq!(config.get_str("database.host"), Some("db.internal"));
}
// Braced text that is not valid JSON is kept verbatim as a string.
#[test]
fn config_apply_env_overrides_leaves_invalid_json_as_string() {
let mut config = Config::new();
unsafe {
std::env::set_var("RUSTRAILS_SUPPORT_CFG_INVALID_PAYLOAD", "{not-json}");
}
config.apply_env_overrides("RUSTRAILS_SUPPORT_CFG_INVALID");
unsafe {
std::env::remove_var("RUSTRAILS_SUPPORT_CFG_INVALID_PAYLOAD");
}
assert_eq!(config.get_str("payload"), Some("{not-json}"));
}
// Doubled underscores produce empty segments, which are dropped from the key.
#[test]
fn config_apply_env_overrides_ignores_empty_suffix_segments() {
let mut config = Config::new();
unsafe {
std::env::set_var(
"RUSTRAILS_SUPPORT_CFG_SEGMENTS__DATABASE__HOST",
"db.internal",
);
}
config.apply_env_overrides("RUSTRAILS_SUPPORT_CFG_SEGMENTS");
unsafe {
std::env::remove_var("RUSTRAILS_SUPPORT_CFG_SEGMENTS__DATABASE__HOST");
}
assert_eq!(config.get_str("database.host"), Some("db.internal"));
}
// Mutating the original after `clone` must not affect the clone.
#[test]
fn config_clone_is_independent_of_original() {
let mut config = Config::new();
config.set("database.host", json!("localhost"));
let cloned = config.clone();
config.set("database.host", json!("db.internal"));
assert_eq!(cloned.get_str("database.host"), Some("localhost"));
assert_eq!(config.get_str("database.host"), Some("db.internal"));
}
// Arrays deserialize into `Vec` fields.
#[test]
fn config_extract_supports_array_fields() {
#[derive(Debug, Deserialize, PartialEq)]
struct FeatureConfig {
features: Vec<String>,
}
let config = Config::from_toml(
r#"
features = ["cache", "jobs"]
"#,
)
.unwrap();
let extracted: FeatureConfig = config.extract().unwrap();
assert_eq!(
extracted,
FeatureConfig {
features: vec![String::from("cache"), String::from("jobs")],
}
);
}
// Extraction follows a dotted table header through three nested structs.
#[test]
fn config_extract_supports_three_level_nested_values() {
#[derive(Debug, Deserialize, PartialEq)]
struct RootConfig {
database: LevelOne,
}
#[derive(Debug, Deserialize, PartialEq)]
struct LevelOne {
primary: LevelTwo,
}
#[derive(Debug, Deserialize, PartialEq)]
struct LevelTwo {
credentials: Credentials,
}
#[derive(Debug, Deserialize, PartialEq)]
struct Credentials {
username: String,
}
let config = Config::from_toml(
r#"
[database.primary.credentials]
username = "postgres"
"#,
)
.unwrap();
let extracted: RootConfig = config.extract().unwrap();
assert_eq!(
extracted,
RootConfig {
database: LevelOne {
primary: LevelTwo {
credentials: Credentials {
username: String::from("postgres"),
},
},
},
}
);
}
// A value of the wrong type is a `ParseError`, not a `MissingKey`.
#[test]
fn config_extract_reports_parse_error_for_type_mismatch() {
let config = Config::from_toml(
r#"
app_name = "rustrails"
port = "not-a-number"
debug = false
[database]
host = "db.internal"
pool = 7
"#,
)
.unwrap();
let error = config.extract::<AppConfig>().unwrap_err();
assert!(matches!(error, ConfigError::ParseError(_)));
}
// A field missing inside a nested table is reported by its bare name, without the path.
#[test]
fn config_extract_reports_missing_nested_field() {
#[derive(Debug, Deserialize, PartialEq)]
struct NestedConfig {
database: NestedDatabase,
}
#[derive(Debug, Deserialize, PartialEq)]
struct NestedDatabase {
credentials: NestedCredentials,
}
#[derive(Debug, Deserialize, PartialEq)]
struct NestedCredentials {
username: String,
}
let config = Config::from_toml(
r#"
[database.credentials]
"#,
)
.unwrap();
let error = config.extract::<NestedConfig>().unwrap_err();
assert!(matches!(error, ConfigError::MissingKey(key) if key == "username"));
}
}