use std::{
fs,
path::{Path, PathBuf},
time::Instant,
};
use anyhow::{Context, Result};
use serde::de::DeserializeOwned;
/// Serialization formats a configuration file may be written in.
///
/// The type is a plain, field-less enum, so it is cheap to copy and can be
/// used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConfigFileFormat {
    /// JSON documents.
    Json,
    /// YAML documents.
    Yaml,
    /// TOML documents.
    Toml,
}
impl ConfigFileFormat {
    /// Returns the file extensions (without a leading dot) recognised for this format.
    pub fn extensions(&self) -> &'static [&'static str] {
        match *self {
            Self::Json => &["json"],
            Self::Yaml => &["yaml", "yml"],
            Self::Toml => &["toml"],
        }
    }

    /// Deserializes `content` into `T` using the parser that matches this format.
    ///
    /// `path` is only used to label error and log messages; the file is not
    /// read here. The time spent parsing is reported at `debug` level.
    ///
    /// # Errors
    ///
    /// Returns the underlying parser error, wrapped with a message naming the
    /// format and `path`, when `content` cannot be deserialized into `T`.
    pub fn parse<T: DeserializeOwned>(&self, content: &str, path: &Path) -> Result<T> {
        let timer = Instant::now();
        let outcome = match *self {
            Self::Json => serde_json::from_str(content)
                .with_context(|| format!("Failed to parse JSON config file: {}", path.display())),
            Self::Yaml => serde_yaml_ng::from_str(content)
                .with_context(|| format!("Failed to parse YAML config file: {}", path.display())),
            Self::Toml => toml::from_str(content)
                .with_context(|| format!("Failed to parse TOML config file: {}", path.display())),
        };
        let elapsed = timer.elapsed();
        // The timing is logged whether or not parsing succeeded.
        tracing::debug!(
            "⚡ Parsed {} config in {:?}: {}",
            format!("{self:?}").to_lowercase(),
            elapsed,
            path.display()
        );
        outcome
    }

    /// Maps a file extension (case-insensitive, no leading dot) to its format.
    ///
    /// Returns `None` when the extension does not belong to any known format.
    pub fn from_extension(ext: &str) -> Option<Self> {
        let normalized = ext.to_lowercase();
        // Reuse `extensions()` so the two lookups cannot drift apart.
        [Self::Json, Self::Yaml, Self::Toml]
            .into_iter()
            .find(|candidate| {
                candidate
                    .extensions()
                    .iter()
                    .any(|known| *known == normalized)
            })
    }
}
/// Finds and loads configuration files that share a common base name.
#[derive(Debug, Clone)]
pub struct ConfigFileLoader {
    /// Base file name of the config files to look for. Candidates are built
    /// as `.<base_name>.<ext>` and `<base_name>.<ext>`.
    base_name: String,
}
impl ConfigFileLoader {
    /// Creates a loader that looks for config files named after `base_name`.
    ///
    /// For a base name of `app`, candidates are `.app.<ext>` and `app.<ext>`
    /// for every extension of every supported [`ConfigFileFormat`].
    pub fn new(base_name: impl Into<String>) -> Self {
        Self {
            base_name: base_name.into(),
        }
    }

    /// Discovers config files starting at `dir` and walking up its ancestors.
    ///
    /// Because an explicit starting directory is given, the user and system
    /// config directories are not consulted.
    #[doc(hidden)]
    #[allow(dead_code)]
    pub fn discover_all_config_files_in(&self, dir: &Path) -> Vec<(PathBuf, ConfigFileFormat)> {
        self.discover_all_config_files_from(Some(dir))
    }

    /// Collects at most one config file per directory, closest directory first.
    ///
    /// The search starts at `starting_dir` (or the current working directory
    /// when it is `None`) and walks up through the ancestors. It stops after
    /// the first directory that contains a `.git` entry, including the
    /// starting directory itself, so the search never leaves the repository.
    ///
    /// Only when `starting_dir` is `None`, the directory reported by
    /// `dirs::config_dir()` and `/etc` are checked afterwards.
    ///
    /// Returns an empty list when the current directory cannot be determined.
    fn discover_all_config_files_from(
        &self,
        starting_dir: Option<&Path>,
    ) -> Vec<(PathBuf, ConfigFileFormat)> {
        let start = Instant::now();
        let mut found_files = Vec::new();
        let mut checked_paths = Vec::new();

        let mut dir = match starting_dir {
            Some(dir) => dir.to_path_buf(),
            None => match std::env::current_dir() {
                Ok(dir) => dir,
                Err(e) => {
                    tracing::warn!("Failed to get current directory: {}", e);
                    return found_files;
                }
            },
        };

        if let Some(found) = self.check_directory(&dir, &mut checked_paths) {
            found_files.push(found);
            tracing::trace!("Found config in starting dir: {}", dir.display());
        }

        if Self::is_git_root(&dir) {
            // The starting directory is already the repository root: do not
            // climb out of the repository into unrelated ancestors.
            tracing::trace!("Reached git root: {}", dir.display());
        } else {
            while dir.pop() {
                let is_git_root = Self::is_git_root(&dir);
                if let Some(found) = self.check_directory(&dir, &mut checked_paths) {
                    found_files.push(found);
                    tracing::trace!("Found config in parent dir: {}", dir.display());
                }
                // The git root itself is still checked above before stopping.
                if is_git_root {
                    tracing::trace!("Reached git root: {}", dir.display());
                    break;
                }
            }
        }

        if starting_dir.is_none() {
            if let Some(config_dir) = dirs::config_dir()
                && let Some(found) = self.check_directory(&config_dir, &mut checked_paths)
            {
                found_files.push(found);
                tracing::trace!("Found config in user dir: {}", config_dir.display());
            }
            if let Some(found) = self.check_directory("/etc", &mut checked_paths) {
                found_files.push(found);
                tracing::trace!("Found config in system dir: /etc");
            }
        }

        let duration = start.elapsed();
        tracing::debug!(
            "⚡ Discovered {} config files in {:?} (checked {} paths)",
            found_files.len(),
            duration,
            checked_paths.len()
        );
        found_files
    }

    /// Reports whether `dir` contains a `.git` entry that is a directory or a file.
    fn is_git_root(dir: &Path) -> bool {
        let git_entry = dir.join(".git");
        git_entry.is_dir() || git_entry.is_file()
    }

    /// Looks for the first matching config file directly inside `dir`.
    ///
    /// Formats are tried in the order JSON, YAML, TOML. For each extension the
    /// dotted name (`.<base>.<ext>`) is tried before the plain name
    /// (`<base>.<ext>`). Every probed path is appended to `checked_paths`,
    /// including the one that matched.
    fn check_directory(
        &self,
        dir: impl AsRef<Path>,
        checked_paths: &mut Vec<PathBuf>,
    ) -> Option<(PathBuf, ConfigFileFormat)> {
        let dir = dir.as_ref();
        let formats = [
            ConfigFileFormat::Json,
            ConfigFileFormat::Yaml,
            ConfigFileFormat::Toml,
        ];
        for format in formats {
            for ext in format.extensions() {
                // Dotted (hidden) name first, then the plain name.
                for prefix in [".", ""] {
                    let candidate = dir.join(format!("{}{}.{}", prefix, self.base_name, ext));
                    // `is_file` is already false for paths that do not exist.
                    if candidate.is_file() {
                        checked_paths.push(candidate.clone());
                        tracing::trace!("Found config file: {}", candidate.display());
                        return Some((candidate, format));
                    }
                    checked_paths.push(candidate);
                }
            }
        }
        None
    }

    /// Reads the file at `path` and deserializes it as `format` into `T`.
    ///
    /// Read and total load times are reported at `debug` level.
    ///
    /// # Errors
    ///
    /// Returns an error when the file cannot be read as UTF-8 text or when its
    /// content cannot be parsed into `T`.
    pub fn load_config_file<T: DeserializeOwned>(
        &self,
        path: &Path,
        format: ConfigFileFormat,
    ) -> Result<T> {
        let start = Instant::now();
        let content = fs::read_to_string(path)
            .with_context(|| format!("Failed to read config file: {}", path.display()))?;
        let read_duration = start.elapsed();
        tracing::debug!(
            "⚡ Read config file in {:?}: {} ({} bytes)",
            read_duration,
            path.display(),
            content.len()
        );
        let config = format.parse(&content, path)?;
        let total_duration = start.elapsed();
        tracing::debug!("⚡ Total config loading time: {:?}", total_duration);
        Ok(config)
    }

    /// Discovers config files from the current working directory and merges them.
    ///
    /// Returns `Ok(None)` when no config file was discovered.
    ///
    /// # Errors
    ///
    /// Returns an error when merging the loaded configs fails.
    pub fn discover_and_load_merged<T>(&self) -> Result<Option<T>>
    where
        T: DeserializeOwned + Default + Clone + serde::Serialize,
    {
        self.discover_and_load_merged_from(None)
    }

    /// Same as [`Self::discover_and_load_merged`], but starts the search at `dir`.
    #[doc(hidden)]
    #[allow(dead_code)]
    pub fn discover_and_load_merged_in<T>(&self, dir: &Path) -> Result<Option<T>>
    where
        T: DeserializeOwned + Default + Clone + serde::Serialize,
    {
        self.discover_and_load_merged_from(Some(dir))
    }

    /// Loads every discovered config file and folds them into one value.
    ///
    /// Files are merged on top of `T::default()` in reverse discovery order,
    /// so files found closest to the starting directory are applied last and
    /// take precedence. A file that cannot be read or parsed is skipped with a
    /// warning rather than aborting the whole load.
    ///
    /// Returns `Ok(None)` when no config file was discovered.
    fn discover_and_load_merged_from<T>(&self, starting_dir: Option<&Path>) -> Result<Option<T>>
    where
        T: DeserializeOwned + Default + Clone + serde::Serialize,
    {
        let files = self.discover_all_config_files_from(starting_dir);
        if files.is_empty() {
            return Ok(None);
        }
        let start = Instant::now();
        let mut merged_config = T::default();
        for (path, format) in files.into_iter().rev() {
            match self.load_config_file::<T>(&path, format) {
                Ok(file_config) => {
                    merged_config = self.merge_configs(merged_config, file_config)?;
                    tracing::debug!("✅ Merged config from: {}", path.display());
                }
                Err(e) => {
                    // `{:#}` prints the whole anyhow context chain, so the
                    // underlying IO/parse cause is visible, not only the
                    // outermost "Failed to ..." message.
                    tracing::warn!("⚠️ Failed to load config file {}: {:#}", path.display(), e);
                }
            }
        }
        let duration = start.elapsed();
        tracing::debug!("⚡ Total config merging time: {:?}", duration);
        Ok(Some(merged_config))
    }

    /// Merges `override_config` on top of `base_config`.
    ///
    /// Both values are converted to JSON, merged with
    /// [`Self::merge_json_values`], and converted back into `T`.
    ///
    /// # Errors
    ///
    /// Returns an error when either value cannot be serialized to JSON or the
    /// merged JSON cannot be deserialized into `T`.
    fn merge_configs<T>(&self, base_config: T, override_config: T) -> Result<T>
    where
        T: DeserializeOwned + serde::Serialize,
    {
        let mut base_value = serde_json::to_value(base_config)?;
        let override_value = serde_json::to_value(override_config)?;
        Self::merge_json_values(&mut base_value, override_value);
        let merged_config = serde_json::from_value(base_value)?;
        Ok(merged_config)
    }

    /// Recursively merges `override_val` into `base`.
    ///
    /// * Two objects are merged key by key.
    /// * A `null` override leaves `base` untouched, so a field that is unset
    ///   (`None`) in a higher-priority config does not erase a value supplied
    ///   by a lower-priority one.
    /// * Any other override (scalars, arrays, mismatched kinds) replaces `base`.
    fn merge_json_values(base: &mut serde_json::Value, override_val: serde_json::Value) {
        match (&mut *base, override_val) {
            (_, serde_json::Value::Null) => {}
            (serde_json::Value::Object(base_map), serde_json::Value::Object(override_map)) => {
                for (key, value) in override_map {
                    match base_map.get_mut(&key) {
                        Some(base_value) => {
                            Self::merge_json_values(base_value, value);
                        }
                        None => {
                            base_map.insert(key, value);
                        }
                    }
                }
            }
            (base_ref, override_val) => {
                *base_ref = override_val;
            }
        }
    }
}