use crate::{
error::{Error, Result},
expand::expand_env_vars,
merge::{deep_merge, env_str_to_value, set_dotted},
validation::Validate,
};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::fs;
use std::path::{Path, PathBuf};
use toml::Value;
/// One configuration source; `Loader` applies these in insertion order.
enum ConfigLayer {
    /// Inline TOML text, plus a label that identifies it in parse errors.
    Str(String, &'static str),
    /// A TOML file that must exist.
    File(PathBuf),
    /// A TOML file that is skipped when it does not exist.
    FileIfExists(PathBuf),
    /// A file name searched for from the current directory upward; the layer
    /// is skipped when no match is found.
    FindFile(String),
}
/// A deserialized configuration together with the path it was loaded from.
///
/// Dereferences to the inner config, so its fields can be accessed directly.
#[derive(Debug, Clone)]
pub struct ConfigFile<T> {
    /// The deserialized configuration value.
    pub config: T,
    /// Where the configuration came from. `Loader::load_file` sets this to the
    /// last file layer it loaded. If no file layer was loaded, it is the
    /// current working directory (or `.` if that cannot be determined).
    pub path: PathBuf,
}

impl<T> ConfigFile<T> {
    /// Resolves `relative` against the directory that holds this config.
    ///
    /// Absolute paths are returned unchanged. Normally the base is the parent
    /// of `path`. If `path` is itself an existing directory, it is used as the
    /// base directly. `Loader::load_file` records a directory when no file
    /// layer was loaded.
    #[must_use]
    pub fn resolve(&self, relative: impl AsRef<Path>) -> PathBuf {
        let rel = relative.as_ref();
        if rel.is_absolute() {
            return rel.to_path_buf();
        }
        // Taking `parent()` of a directory path would resolve one level too
        // high (against the cwd's parent when no file was loaded).
        let base = if self.path.is_dir() {
            self.path.as_path()
        } else {
            self.path.parent().unwrap_or_else(|| Path::new("."))
        };
        base.join(rel)
    }
}

/// Read access to the inner config.
impl<T> std::ops::Deref for ConfigFile<T> {
    type Target = T;

    fn deref(&self) -> &T {
        &self.config
    }
}

/// Mutable access to the inner config.
impl<T> std::ops::DerefMut for ConfigFile<T> {
    fn deref_mut(&mut self) -> &mut T {
        &mut self.config
    }
}
/// Builder that merges configuration layers in the order they are added.
///
/// Later layers override earlier ones. Environment-variable overrides
/// (enabled via `Loader::env_prefix`) are applied after every layer.
pub struct Loader {
    /// Prefix that selects override variables; `None` disables env overrides.
    env_prefix: Option<String>,
    /// Configuration layers, in insertion order.
    layers: Vec<ConfigLayer>,
}
impl Default for Loader {
fn default() -> Self {
Self::new()
}
}
impl Loader {
#[must_use]
pub fn new() -> Self {
Self {
layers: Vec::new(),
env_prefix: None,
}
}
#[must_use]
pub fn layer_str(mut self, content: impl Into<String>, label: &'static str) -> Self {
self.layers.push(ConfigLayer::Str(content.into(), label));
self
}
#[must_use]
pub fn layer_file(mut self, path: impl Into<PathBuf>) -> Self {
self.layers.push(ConfigLayer::File(path.into()));
self
}
#[must_use]
pub fn layer_file_if_exists(mut self, path: impl Into<PathBuf>) -> Self {
self.layers.push(ConfigLayer::FileIfExists(path.into()));
self
}
#[must_use]
pub fn find_file(mut self, file_name: impl Into<String>) -> Self {
self.layers.push(ConfigLayer::FindFile(file_name.into()));
self
}
#[must_use]
pub fn env_prefix(mut self, prefix: impl Into<String>) -> Self {
self.env_prefix = Some(prefix.into());
self
}
pub fn load<T: DeserializeOwned>(self) -> Result<T> {
let (merged, _) = self.merge_layers()?;
deserialize_value(merged, "merged config")
}
pub fn load_file<T: DeserializeOwned>(self) -> Result<ConfigFile<T>> {
let (merged, last_path) = self.merge_layers()?;
let config = deserialize_value(merged, "merged config")?;
Ok(ConfigFile {
config,
path: last_path,
})
}
pub fn load_validated<T: DeserializeOwned + Validate>(self) -> Result<T> {
let cfg: T = self.load()?;
cfg.check()?;
Ok(cfg)
}
fn merge_layers(self) -> Result<(Value, PathBuf)> {
let mut merged = Value::Table(toml::map::Map::new());
let mut last_file_path = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
for layer in self.layers {
match layer {
ConfigLayer::Str(content, label) => {
let expanded = expand_env_vars(&content);
let val = parse_str(&expanded, label)?;
deep_merge(&mut merged, val);
}
ConfigLayer::File(path) => {
if !path.exists() {
return Err(Error::FileNotFound(path));
}
let val = load_file_as_value(&path)?;
last_file_path = path;
deep_merge(&mut merged, val);
}
ConfigLayer::FileIfExists(path) => {
if path.exists() {
let val = load_file_as_value(&path)?;
last_file_path = path;
deep_merge(&mut merged, val);
}
}
ConfigLayer::FindFile(file_name) => {
if let Some(path) = find_config_file_from_cwd(&file_name) {
let val = load_file_as_value(&path)?;
last_file_path = path;
deep_merge(&mut merged, val);
}
}
}
}
if let Some(prefix) = self.env_prefix {
let prefix_upper = prefix.to_ascii_uppercase();
for (key, val) in std::env::vars() {
let key_upper = key.to_ascii_uppercase();
if let Some(suffix) = key_upper.strip_prefix(&prefix_upper) {
let toml_key = suffix.replace("__", ".").to_ascii_lowercase();
let toml_val = env_str_to_value(&val);
set_dotted(&mut merged, &toml_key, toml_val);
}
}
}
Ok((merged, last_file_path))
}
}
/// Searches `start` and then each of its ancestors for a regular file named
/// `file_name`, returning the first match.
///
/// If `start` is itself a file, the search begins in its parent directory.
/// A relative `start` is anchored at the current working directory, so the
/// walk can continue above it; the returned path is then absolute. If the
/// cwd cannot be read, the relative path is used as-is. Entries named
/// `file_name` that are not regular files (e.g. directories) are skipped.
#[must_use]
pub fn find_config_file(file_name: &str, start: impl AsRef<Path>) -> Option<PathBuf> {
    let start = start.as_ref();
    let mut dir = if start.is_relative() {
        // Without an absolute anchor, `PathBuf::pop` runs out at the relative
        // root: starting from "." the search never looked above the cwd.
        std::env::current_dir()
            .map(|cwd| cwd.join(start))
            .unwrap_or_else(|_| start.to_path_buf())
    } else {
        start.to_path_buf()
    };
    if dir.is_file() {
        dir.pop();
    }
    loop {
        let candidate = dir.join(file_name);
        // `is_file` rather than `exists`: a directory that happens to carry
        // the config's name cannot be loaded and must not end the search.
        if candidate.is_file() {
            return Some(candidate);
        }
        if !dir.pop() {
            return None;
        }
    }
}
pub fn find_and_load<T: DeserializeOwned>(
file_name: &str, start: impl AsRef<Path>,
) -> Result<(PathBuf, T)> {
let path = find_config_file(file_name, start)
.ok_or_else(|| Error::FileNotFound(PathBuf::from(file_name)))?;
let cfg = load_file(&path)?;
Ok((path, cfg))
}
/// Parses TOML `content` into `T`, after running it through `expand_env_vars`.
///
/// # Errors
/// Returns a parse error labelled `"inline string"` if the expanded text is
/// not valid TOML for `T`.
pub fn from_str<T: DeserializeOwned>(content: &str) -> Result<T> {
    parse_str(&expand_env_vars(content), "inline string")
}
/// Reads the TOML file at `path`, expands environment references with
/// `expand_env_vars`, and deserializes it into `T`.
///
/// # Errors
/// Returns read errors from `read_file`. Returns a parse error labelled
/// with the path's display form if deserialization fails.
pub fn load_file<T: DeserializeOwned>(path: impl AsRef<Path>) -> Result<T> {
    let file_path = path.as_ref();
    let raw = read_file(file_path)?;
    let label = file_path.display().to_string();
    parse_str(&expand_env_vars(&raw), &label)
}
/// Serializes `value` as pretty-printed TOML.
///
/// # Errors
/// Returns the serializer's error, converted into this crate's `Error`.
pub fn to_string<T: Serialize>(value: &T) -> Result<String> {
    let pretty = toml::to_string_pretty(value)?;
    Ok(pretty)
}
pub fn save_file<T: Serialize>(value: &T, path: impl AsRef<Path>) -> Result<()> {
let path = path.as_ref();
let toml = to_string(value)?;
if let Some(parent) = path.parent() {
if !parent.as_os_str().is_empty() {
fs::create_dir_all(parent).map_err(|e| Error::io(parent, e))?;
}
}
fs::write(path, toml).map_err(|e| Error::io(path, e))
}
fn read_file(path: &Path) -> Result<String> {
if !path.exists() {
return Err(Error::FileNotFound(path.to_path_buf()));
}
fs::read_to_string(path).map_err(|e| Error::io(path, e))
}
fn load_file_as_value(path: &Path) -> Result<Value> {
let content = read_file(path)?;
let expanded = expand_env_vars(&content);
parse_str(&expanded, &path.display().to_string())
}
fn parse_str<T: DeserializeOwned>(content: &str, label: &str) -> Result<T> {
toml::from_str(content).map_err(|e| Error::parse(label, e))
}
fn deserialize_value<T: DeserializeOwned>(value: Value, label: &str) -> Result<T> {
T::deserialize(value).map_err(|e| Error::parse(label, e))
}
/// Runs [`find_config_file`] starting at the current working directory.
///
/// Returns `None` if the cwd cannot be determined or nothing is found.
fn find_config_file_from_cwd(file_name: &str) -> Option<PathBuf> {
    std::env::current_dir()
        .ok()
        .and_then(|cwd| find_config_file(file_name, cwd))
}
// Unit tests for `Loader` and the free-function helpers. Filesystem tests
// work inside `tempfile` directories. The env-var tests use distinct prefixes
// (`STTOML_` vs `STTOML2_`), so neither test matches the other's variable.
#[cfg(test)]
// `unwrap()` is acceptable here: a panic is the intended test failure mode.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use std::io::Write;
    use tempfile::{NamedTempFile, TempDir};

    /// Minimal config shape shared by most tests. `port` defaults to `None`
    /// when absent and is omitted from serialized output when `None`.
    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct Simple {
        name: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        port: Option<u16>,
    }

    /// Writes `content` to `dir/name` and returns the full path.
    fn write_toml(dir: &TempDir, name: &str, content: &str) -> PathBuf {
        let path = dir.path().join(name);
        std::fs::write(&path, content).unwrap();
        path
    }

    /// `from_str` parses both a string field and an optional integer field.
    #[test]
    fn from_str_parses() {
        let cfg: Simple = from_str("name = \"hello\"\nport = 8080\n").unwrap();
        assert_eq!(cfg.name, "hello");
        assert_eq!(cfg.port, Some(8080));
    }

    /// `load_file` reads and parses an on-disk TOML file.
    #[test]
    fn load_file_ok() {
        let mut f = NamedTempFile::new().unwrap();
        writeln!(f, "name = \"test\"").unwrap();
        let cfg: Simple = load_file(f.path()).unwrap();
        assert_eq!(cfg.name, "test");
    }

    /// A missing file surfaces as `Error::FileNotFound`, not a generic I/O error.
    #[test]
    fn load_file_not_found() {
        let result = load_file::<Simple>("/nonexistent/x.toml");
        assert!(matches!(result, Err(Error::FileNotFound(_))));
    }

    /// A file layer added after a string layer overrides both of its keys.
    #[test]
    fn loader_layer_str_and_file() {
        let dir = TempDir::new().unwrap();
        write_toml(&dir, "a.toml", "name = \"from-file\"\nport = 9090\n");
        let cfg: Simple = Loader::new()
            .layer_str("name = \"default\"\nport = 8080\n", "defaults")
            .layer_file(dir.path().join("a.toml"))
            .load()
            .unwrap();
        assert_eq!(cfg.name, "from-file"); assert_eq!(cfg.port, Some(9090));
    }

    /// An optional layer pointing at a missing file is skipped without error.
    #[test]
    fn loader_file_if_exists_skips_missing() {
        let cfg: Simple = Loader::new()
            .layer_str("name = \"default\"", "defaults")
            .layer_file_if_exists("/nonexistent/optional.toml")
            .load()
            .unwrap();
        assert_eq!(cfg.name, "default");
    }

    /// A `<PREFIX>NAME` variable overrides the `name` key from earlier layers.
    #[test]
    fn loader_env_prefix_override() {
        std::env::set_var("STTOML_NAME", "env-name");
        let result = Loader::new()
            .layer_str("name = \"original\"", "defaults")
            .env_prefix("STTOML_")
            .load::<Simple>();
        // Remove the variable before unwrapping so cleanup happens even if
        // loading failed.
        std::env::remove_var("STTOML_NAME");
        let cfg = result.unwrap();
        assert_eq!(cfg.name, "env-name");
    }

    /// `__` in a variable name addresses a nested key (`server.port`).
    #[test]
    fn loader_env_prefix_nested_double_underscore() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct Outer {
            server: Server,
        }
        #[derive(Deserialize, PartialEq, Debug)]
        struct Server {
            port: u16,
        }
        std::env::set_var("STTOML2_SERVER__PORT", "9999");
        let result = Loader::new()
            .layer_str("[server]\nport = 8080\n", "defaults")
            .env_prefix("STTOML2_")
            .load::<Outer>();
        // Cleanup before the assertion, for the same reason as above.
        std::env::remove_var("STTOML2_SERVER__PORT");
        assert_eq!(result.unwrap().server.port, 9999);
    }

    /// `ConfigFile::resolve` joins relative paths onto the loaded file's directory.
    #[test]
    fn config_file_resolves_relative_paths() {
        let dir = TempDir::new().unwrap();
        write_toml(&dir, "app.toml", "name = \"app\"\n");
        let path = dir.path().join("app.toml");
        let cf: ConfigFile<Simple> = Loader::new().layer_file(&path).load_file().unwrap();
        let resolved = cf.resolve("templates/foo.tera");
        assert_eq!(resolved, dir.path().join("templates/foo.tera"));
    }

    /// The search starts three levels down and finds the file at the temp root.
    #[test]
    fn find_config_file_walks_up() {
        let dir = TempDir::new().unwrap();
        let child = dir.path().join("a/b/c");
        std::fs::create_dir_all(&child).unwrap();
        let config = dir.path().join("myconfig.toml");
        std::fs::write(&config, "").unwrap();
        let found = find_config_file("myconfig.toml", &child);
        assert_eq!(found, Some(config));
    }

    /// No match anywhere up the tree yields `None`.
    #[test]
    fn find_config_file_none_when_absent() {
        let dir = TempDir::new().unwrap();
        assert!(find_config_file("missing.toml", dir.path()).is_none());
    }

    /// `find_and_load` returns the discovered path alongside the parsed config.
    #[test]
    fn find_and_load_returns_path_and_config() {
        let dir = TempDir::new().unwrap();
        let child = dir.path().join("sub");
        std::fs::create_dir_all(&child).unwrap();
        write_toml(&dir, "x.toml", "name = \"found\"\n");
        let (path, cfg): (PathBuf, Simple) = find_and_load("x.toml", &child).unwrap();
        assert_eq!(path, dir.path().join("x.toml"));
        assert_eq!(cfg.name, "found");
    }

    /// `save_file` creates missing parent directories, and its output
    /// round-trips through `load_file` unchanged.
    #[test]
    fn save_file_round_trips_with_load_file() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("nested/out.toml");
        let original = Simple {
            name: "round-trip".into(),
            port: Some(1234),
        };
        save_file(&original, &path).unwrap();
        assert!(path.exists());
        let reloaded: Simple = load_file(&path).unwrap();
        assert_eq!(reloaded, original);
    }

    /// Precedence across three layers is resolved key by key. `name` comes
    /// from the middle layer because the top layer omits it; `port` comes
    /// from the top layer.
    #[test]
    fn loader_three_layer_precedence() {
        let dir = TempDir::new().unwrap();
        write_toml(&dir, "mid.toml", "name = \"mid\"\nport = 2000\n");
        write_toml(&dir, "top.toml", "port = 3000\n");
        let cfg: Simple = Loader::new()
            .layer_str("name = \"base\"\nport = 1000\n", "base")
            .layer_file(dir.path().join("mid.toml"))
            .layer_file(dir.path().join("top.toml"))
            .load()
            .unwrap();
        assert_eq!(cfg.name, "mid");
        assert_eq!(cfg.port, Some(3000));
    }
}