use std::path::{Path, PathBuf};
use anyhow::{Context, Result};
use config::{Config, Environment, File, FileFormat, Source};
use getset::{CopyGetters, Getters};
use serde::{Deserialize, Serialize};
use tracing_subscriber_init::TracingConfig;
use crate::{error::Error, utils::to_path_buf};
/// Per-application naming and path hooks consumed by the configuration loader.
///
/// Implementors identify the application (its name and environment-variable prefix) and may
/// supply absolute-path overrides that take the place of computed default locations.
pub(crate) trait PathDefaults {
    /// Application name. When no explicit config path is given, the default config file
    /// location is derived from this name.
    fn app_name(&self) -> String;

    /// Prefix for environment-variable overrides (e.g. `SALUSAGENT` makes
    /// `SALUSAGENT_VERBOSE` set the `verbose` key).
    fn env_prefix(&self) -> String;

    /// Explicit config file path. `None` means "use the default location".
    fn config_absolute_path(&self) -> Option<String>;

    /// Explicit tracing-related path override, if any.
    /// NOTE(review): not consumed by the loader in this file — confirm how callers use it.
    fn tracing_absolute_path(&self) -> Option<String>;
}
/// Value used for `passphrase_cache_timeout` when no config source sets it.
/// NOTE(review): the unit is not visible in this file — presumably seconds (3600 = one hour);
/// confirm against the code that consumes the timeout.
const DEFAULT_PASSPHRASE_CACHE_TIMEOUT: u64 = 3600;
/// Top-level configuration for the Salus agent.
///
/// `#[serde(default)]` fills every field that no source provides with the value from
/// `ConfigSalusAgent::default()`, so a partial or even empty configuration still deserializes.
#[derive(Clone, CopyGetters, Debug, Deserialize, Eq, Getters, PartialEq, Serialize)]
#[serde(default)]
pub(crate) struct ConfigSalusAgent {
    /// Verbosity level, reported through `TracingConfig::verbose` (0 by default).
    #[getset(get_copy = "pub(crate)")]
    verbose: u8,
    /// Quiet level, reported through `TracingConfig::quiet` (0 by default).
    #[getset(get_copy = "pub(crate)")]
    quiet: u8,
    /// Standard-output toggle (`false` by default).
    /// NOTE(review): how this flag is consumed is not visible in this file.
    #[getset(get_copy = "pub(crate)")]
    enable_std_output: bool,
    /// Passphrase cache timeout; defaults to `DEFAULT_PASSPHRASE_CACHE_TIMEOUT`.
    /// NOTE(review): unit not visible here — presumably seconds; confirm with the consumer.
    #[getset(get_copy = "pub(crate)")]
    passphrase_cache_timeout: u64,
    /// Optional socket path; `None` when no source sets it.
    #[getset(get = "pub(crate)")]
    socket_path: Option<String>,
    /// Tracing formatter options, read from the nested `tracing` key
    /// (e.g. `{PREFIX}_TRACING__WITH_TARGET` from the environment).
    #[getset(get = "pub(crate)")]
    tracing: Tracing,
}
impl Default for ConfigSalusAgent {
    /// Baseline configuration used for any field that no file, environment variable, or CLI
    /// value overrides: verbose/quiet levels of 0, standard output disabled, no socket path,
    /// the default passphrase cache timeout, and default tracing options.
    fn default() -> Self {
        let tracing = Tracing::default();
        Self {
            tracing,
            socket_path: None,
            passphrase_cache_timeout: DEFAULT_PASSPHRASE_CACHE_TIMEOUT,
            enable_std_output: false,
            quiet: 0,
            verbose: 0,
        }
    }
}
// Bridges the agent configuration to `tracing_subscriber_init`: the verbosity levels come from
// the top-level fields and the formatter toggles from the nested `tracing` table.
// NOTE(review): `Tracing::directives` is not forwarded here — confirm whether the trait offers
// a hook for it that should be overridden.
impl TracingConfig for ConfigSalusAgent {
    /// Verbosity level from the top-level `verbose` field.
    fn verbose(&self) -> u8 {
        self.verbose
    }

    /// Quiet level from the top-level `quiet` field.
    fn quiet(&self) -> u8 {
        self.quiet
    }

    /// `tracing.with_level` toggle.
    fn with_level(&self) -> bool {
        self.tracing.with_level()
    }

    /// `tracing.with_line_number` toggle.
    fn with_line_number(&self) -> bool {
        self.tracing.with_line_number()
    }

    /// `tracing.with_target` toggle.
    fn with_target(&self) -> bool {
        self.tracing.with_target()
    }

    /// `tracing.with_thread_ids` toggle.
    fn with_thread_ids(&self) -> bool {
        self.tracing.with_thread_ids()
    }

    /// `tracing.with_thread_names` toggle.
    fn with_thread_names(&self) -> bool {
        self.tracing.with_thread_names()
    }
}
/// Tracing formatter options, deserialized from the nested `tracing` configuration key.
///
/// With `Default`, every flag is `false` and `directives` is `None`; `#[serde(default)]` applies
/// those values to any key missing from the configuration.
// Each bool is an independent on/off toggle for one formatter option (mirroring the `with_*`
// methods of `TracingConfig`), so clippy's advice to fold them into an enum does not apply.
#[allow(clippy::struct_excessive_bools)]
#[derive(Clone, CopyGetters, Debug, Default, Deserialize, Eq, Getters, PartialEq, Serialize)]
#[serde(default)]
pub(crate) struct Tracing {
    /// Forwarded as `TracingConfig::with_target`.
    #[getset(get_copy = "pub(crate)")]
    with_target: bool,
    /// Forwarded as `TracingConfig::with_thread_ids`.
    #[getset(get_copy = "pub(crate)")]
    with_thread_ids: bool,
    /// Forwarded as `TracingConfig::with_thread_names`.
    #[getset(get_copy = "pub(crate)")]
    with_thread_names: bool,
    /// Forwarded as `TracingConfig::with_line_number`.
    #[getset(get_copy = "pub(crate)")]
    with_line_number: bool,
    /// Forwarded as `TracingConfig::with_level`.
    #[getset(get_copy = "pub(crate)")]
    with_level: bool,
    /// Optional filter directives string.
    /// NOTE(review): not forwarded by the `TracingConfig` impl in this file — confirm where it
    /// is applied.
    #[getset(get = "pub(crate)")]
    directives: Option<String>,
}
/// Loads a configuration of type `T` by layering three sources, lowest precedence first:
///
/// 1. the TOML config file from `config_file_path` (optional — a missing file is not an error),
/// 2. environment variables under `defaults.env_prefix()` (see [`env_source`]),
/// 3. the CLI source `cli`.
///
/// A value from a later source overrides the same key from an earlier one.
///
/// # Errors
///
/// Returns an error if the config file path cannot be resolved, if building the layered
/// configuration fails (`Error::ConfigBuild`), or if the merged values cannot be
/// deserialized into `T` (`Error::ConfigDeserialize`).
pub(crate) fn load<S, T, D>(cli: &S, defaults: &D) -> Result<T>
where
    // The merged `Config` is local to this function, so `T` can never borrow from it. Requiring
    // `Deserialize<'de>` for every `'de` (serde's `DeserializeOwned`) states that explicitly,
    // instead of an unconstrained caller-chosen lifetime that nothing could satisfy.
    T: for<'de> Deserialize<'de>,
    S: Source + Clone + Send + Sync + 'static,
    D: PathDefaults,
{
    let config_file_path = config_file_path(defaults)?;
    let config = Config::builder()
        .add_source(
            File::from(config_file_path)
                .format(FileFormat::Toml)
                .required(false),
        )
        .add_source(env_source(&defaults.env_prefix()))
        .add_source(cli.clone())
        .build()
        .with_context(|| Error::ConfigBuild)?;
    config
        .try_deserialize::<T>()
        .with_context(|| Error::ConfigDeserialize)
}
/// Builds the environment-variable source for `prefix`.
///
/// Variables take the form `{PREFIX}_{KEY}`, and `__` separates nested keys, so
/// `SALUSAGENT_TRACING__WITH_TARGET` maps to `tracing.with_target`. Values are parsed into
/// numbers/booleans where possible instead of being kept as plain strings.
pub(crate) fn env_source(prefix: &str) -> Environment {
    let source = Environment::with_prefix(prefix);
    let source = source.prefix_separator("_");
    let source = source.separator("__");
    source.try_parsing(true)
}
fn config_file_path<D>(defaults: &D) -> Result<PathBuf>
where
D: PathDefaults,
{
let default_fn = || -> Result<PathBuf> { default_config_file_path(defaults) };
defaults
.config_absolute_path()
.as_ref()
.map_or_else(default_fn, to_path_buf)
}
fn default_config_file_path<D>(defaults: &D) -> Result<PathBuf>
where
D: PathDefaults,
{
let base = dirs2::config_dir().ok_or(Error::ConfigDir)?;
Ok(config_file_in(&base, &defaults.app_name()))
}
/// Builds `<base>/<app>/<app>.toml`.
///
/// The `.toml` suffix is appended to the full app name rather than applied with
/// `Path::with_extension`. `with_extension` replaces everything after the last dot, so an app
/// name such as `salus.agent` would otherwise become `salus.agent/salus.toml`.
fn config_file_in(base: &Path, app: &str) -> PathBuf {
    base.join(app).join(format!("{app}.toml"))
}
// Unit tests for path composition, defaults, and source layering in the config loader.
#[cfg(test)]
mod test {
    use std::path::Path;
    use anyhow::Result;
    use config::{Config, ConfigError, Map, Source, Value, ValueKind};
    use super::{
        ConfigSalusAgent, DEFAULT_PASSPHRASE_CACHE_TIMEOUT, PathDefaults, Tracing, config_file_in,
        env_source, load,
    };

    /// Minimal CLI stand-in. It implements both `Source` and `PathDefaults`, so one value can be
    /// passed as both arguments of `load`.
    #[derive(Clone, Debug)]
    struct TestCli {
        // Returned from `PathDefaults::config_absolute_path`; not emitted as a config key.
        config_path: Option<String>,
        // Emitted as the `socket_path` config key when set.
        socket_path: Option<String>,
    }

    impl Source for TestCli {
        fn clone_into_box(&self) -> Box<dyn Source + Send + Sync> {
            Box::new(self.clone())
        }

        // Only `socket_path` is surfaced as a config value; an unset field contributes nothing.
        fn collect(&self) -> Result<Map<String, Value>, ConfigError> {
            let mut map = Map::new();
            if let Some(socket_path) = &self.socket_path {
                let _old = map.insert(
                    "socket_path".to_string(),
                    Value::new(None, ValueKind::String(socket_path.clone())),
                );
            }
            Ok(map)
        }
    }

    impl PathDefaults for TestCli {
        // Test-specific prefix — presumably chosen so real agent env vars do not affect tests.
        fn env_prefix(&self) -> String {
            "SALUSAGENTTEST".to_string()
        }
        fn app_name(&self) -> String {
            "salus-agent-test".to_string()
        }
        fn config_absolute_path(&self) -> Option<String> {
            self.config_path.clone()
        }
        fn tracing_absolute_path(&self) -> Option<String> {
            None
        }
    }

    #[test]
    fn config_file_in_composes_app_dir_and_extension() {
        let path = config_file_in(Path::new("/base"), "salus-agent");
        assert_eq!(path, Path::new("/base/salus-agent/salus-agent.toml"));
    }

    #[test]
    fn defaults_match_documented_values() {
        let cfg = ConfigSalusAgent::default();
        assert_eq!(cfg.verbose(), 0);
        assert_eq!(cfg.quiet(), 0);
        assert!(!cfg.enable_std_output());
        assert_eq!(
            cfg.passphrase_cache_timeout(),
            DEFAULT_PASSPHRASE_CACHE_TIMEOUT
        );
        assert!(cfg.socket_path().is_none());
        let tracing = Tracing::default();
        assert!(!tracing.with_target());
        assert!(!tracing.with_level());
        assert!(tracing.directives().is_none());
    }

    // The config path points at a nonexistent file. `load` marks the file source as not
    // required, so this checks that a missing file is tolerated and that the CLI value wins.
    // NOTE(review): no environment variables are injected here, so despite the name the env
    // layer is not actually exercised by this test.
    #[test]
    fn load_layers_file_env_and_cli() -> Result<()> {
        let cli = TestCli {
            config_path: Some("/nonexistent/salus-agent-test.toml".to_string()),
            socket_path: Some("/tmp/agent-test.sock".to_string()),
        };
        let cfg: ConfigSalusAgent = load(&cli, &cli)?;
        assert_eq!(cfg.socket_path().as_deref(), Some("/tmp/agent-test.sock"));
        assert_eq!(
            cfg.passphrase_cache_timeout(),
            DEFAULT_PASSPHRASE_CACHE_TIMEOUT
        );
        Ok(())
    }

    // An empty `Config` must deserialize entirely from the `#[serde(default)]` fallbacks.
    #[test]
    fn missing_fields_fall_back_to_defaults() -> Result<()> {
        let config = Config::builder().build()?;
        let cfg: ConfigSalusAgent = config.try_deserialize()?;
        assert_eq!(
            cfg.passphrase_cache_timeout(),
            DEFAULT_PASSPHRASE_CACHE_TIMEOUT
        );
        assert_eq!(cfg.verbose(), 0);
        assert!(!cfg.enable_std_output());
        assert!(cfg.socket_path().is_none());
        Ok(())
    }

    // `Environment::source` injects a fixed map instead of the process environment, keeping
    // the test hermetic. It checks that `_` splits the prefix, `__` splits nested keys, and
    // string values are parsed into numbers/booleans.
    #[test]
    fn env_separators_map_flat_and_nested_fields() -> Result<()> {
        let mut map = Map::new();
        let _old = map.insert(
            "SALUSAGENT_PASSPHRASE_CACHE_TIMEOUT".to_string(),
            "99".to_string(),
        );
        let _old = map.insert(
            "SALUSAGENT_TRACING__WITH_TARGET".to_string(),
            "true".to_string(),
        );
        let config = Config::builder()
            .add_source(env_source("SALUSAGENT").source(Some(map)))
            .build()?;
        let cfg: ConfigSalusAgent = config.try_deserialize()?;
        assert_eq!(cfg.passphrase_cache_timeout(), 99);
        assert!(cfg.tracing().with_target());
        Ok(())
    }
}