use std::{collections::BTreeSet, path::PathBuf, sync::Arc};
use anyhow::{Context as _, Result};
use config::{Config, Environment, File, FileFormat, Source};
use dirs2::config_dir;
use serde::Deserialize;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::{
KexMode,
error::Error,
kex::negotiate::{AlgorithmList, supported_algorithms},
session::SessionRegistry,
to_path_buf,
udp::DiffMode,
};
pub(crate) mod mps;
pub(crate) mod tracing;
/// Supplies the names and paths used to locate a binary's configuration file
/// (and, presumably, its tracing output).
///
/// The configuration-related methods are consumed by [`load`] in this module. The
/// tracing-related methods are not used here. NOTE(review): they are presumably
/// consumed by the `tracing` submodule; confirm there.
pub trait PathDefaults {
    /// Prefix for environment-variable overrides.
    ///
    /// [`load`] passes this to `config::Environment::with_prefix`, using `_` as the key
    /// separator and with value parsing enabled.
    fn env_prefix(&self) -> String;
    /// Explicit configuration file path.
    ///
    /// When `Some`, it is converted with `to_path_buf` and used instead of the default
    /// location. When `None`, the path is built from
    /// [`default_file_path`](Self::default_file_path) and
    /// [`default_file_name`](Self::default_file_name).
    fn config_absolute_path(&self) -> Option<String>;
    /// Path component appended to the platform configuration directory
    /// (`dirs2::config_dir`) when building the default configuration file path.
    fn default_file_path(&self) -> String;
    /// File name of the default configuration file.
    ///
    /// Its extension is set to (or replaced with) `toml` when the default path is built.
    fn default_file_name(&self) -> String;
    /// Explicit path for tracing output, if any.
    ///
    /// NOTE(review): not used in this module; assumed to mirror
    /// `config_absolute_path`. Confirm in the `tracing` module.
    fn tracing_absolute_path(&self) -> Option<String>;
    /// Default path component for tracing output.
    ///
    /// NOTE(review): not used in this module; semantics assumed.
    fn default_tracing_path(&self) -> String;
    /// Default file name for tracing output.
    ///
    /// NOTE(review): not used in this module; semantics assumed.
    fn default_tracing_file_name(&self) -> String;
}
/// Settings consumed by the key-exchange (`kex`) machinery.
///
/// Implementors must provide [`mode`](Self::mode), [`port_pool`](Self::port_pool),
/// [`key_pair_paths`](Self::key_pair_paths) and [`user`](Self::user). The other
/// methods have defaults:
/// - `None` for the session registry, resume UUID and server id;
/// - `DiffMode::Reliable` for the diff mode;
/// - `supported_algorithms()` for the algorithm preferences.
pub trait KexConfig {
    /// Which side of the exchange this configuration describes (e.g. `KexMode::Client`).
    fn mode(&self) -> KexMode;
    /// Shared set of `u16` values guarded by an async (`tokio`) mutex.
    ///
    /// NOTE(review): presumably the port numbers available for allocation; confirm
    /// with implementors.
    fn port_pool(&self) -> Option<Arc<Mutex<BTreeSet<u16>>>>;
    /// Paths to the key-pair files.
    ///
    /// NOTE(review): tuple order is presumably `(public, private)`, matching this
    /// module's test fixture; confirm against callers.
    ///
    /// # Errors
    ///
    /// Implementation-defined, e.g. when the paths cannot be resolved.
    fn key_pair_paths(&self) -> Result<(PathBuf, PathBuf)>;
    /// Optional user name supplied to the exchange.
    fn user(&self) -> Option<String>;
    /// Session registry to use, if any. Defaults to `None`.
    fn session_registry(&self) -> Option<SessionRegistry> {
        None
    }
    /// UUID of a session to resume, if any. Defaults to `None`.
    fn resume_session_uuid(&self) -> Option<Uuid> {
        None
    }
    /// Server identifier, if any. Defaults to `None`.
    fn server_id(&self) -> Option<String> {
        None
    }
    /// UDP diff mode. Defaults to `DiffMode::Reliable`.
    fn diff_mode(&self) -> DiffMode {
        DiffMode::Reliable
    }
    /// Algorithms to offer during negotiation. Defaults to every entry from
    /// `supported_algorithms()`.
    fn preferred_algorithms(&self) -> AlgorithmList {
        supported_algorithms()
    }
}
/// Builds the layered configuration and deserializes it into `T`.
///
/// Sources are registered with the `config` crate in this order:
/// 1. environment variables prefixed with [`PathDefaults::env_prefix`], using `_` as
///    the nesting separator and parsing values into typed scalars where possible;
/// 2. the command-line source `cli`;
/// 3. the TOML file resolved by `config_file_path`.
///
/// `config` merges sources in registration order, with later sources overriding
/// earlier ones. A key set in the file therefore wins over the CLI and the environment.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - the configuration file path cannot be resolved;
/// - the sources cannot be collected (for example, the file is missing or is not
///   valid TOML);
/// - the merged result cannot be deserialized into `T`.
pub fn load<'a, S, T, D>(cli: &S, defaults: &D) -> Result<T>
where
    T: Deserialize<'a>,
    S: Source + Clone + Send + Sync + 'static,
    D: PathDefaults,
{
    // Resolve the file location up front so a bad path fails before anything is built.
    let toml_path = config_file_path(defaults)?;
    let prefix = defaults.env_prefix();

    let env_source = Environment::with_prefix(&prefix)
        .separator("_")
        .try_parsing(true);
    let file_source = File::from(toml_path).format(FileFormat::Toml);

    // Registration order is precedence order: each later source overrides earlier ones.
    let merged = Config::builder()
        .add_source(env_source)
        .add_source(cli.clone())
        .add_source(file_source)
        .build()
        .with_context(|| Error::ConfigBuild)?;

    merged
        .try_deserialize::<T>()
        .with_context(|| Error::ConfigDeserialize)
}
fn config_file_path<D>(defaults: &D) -> Result<PathBuf>
where
D: PathDefaults,
{
let default_fn = || -> Result<PathBuf> { default_config_file_path(defaults) };
defaults
.config_absolute_path()
.as_ref()
.map_or_else(default_fn, to_path_buf)
}
fn default_config_file_path<D>(defaults: &D) -> Result<PathBuf>
where
D: PathDefaults,
{
let mut config_file_path = config_dir().ok_or(Error::ConfigDir)?;
config_file_path.push(defaults.default_file_path());
config_file_path.push(defaults.default_file_name());
let _ = config_file_path.set_extension("toml");
Ok(config_file_path)
}
#[cfg(test)]
mod tests {
    use std::{collections::BTreeSet, path::PathBuf, sync::Arc};

    use tokio::sync::Mutex;
    use uuid::Uuid;

    use super::{
        DiffMode, KexConfig, KexMode, PathDefaults, SessionRegistry, config_file_path,
        default_config_file_path,
    };

    /// Implements only the required `KexConfig` methods so the trait defaults can be
    /// observed.
    struct TestKexConfig;

    impl KexConfig for TestKexConfig {
        fn mode(&self) -> KexMode {
            KexMode::Client
        }

        fn port_pool(&self) -> Option<Arc<Mutex<BTreeSet<u16>>>> {
            None
        }

        fn key_pair_paths(&self) -> anyhow::Result<(PathBuf, PathBuf)> {
            Ok((PathBuf::from("/tmp/pub"), PathBuf::from("/tmp/priv")))
        }

        fn user(&self) -> Option<String> {
            Some("testuser".to_string())
        }
    }

    #[test]
    fn kex_config_session_registry_default_is_none() {
        let cfg = TestKexConfig;
        let reg: Option<SessionRegistry> = cfg.session_registry();
        assert!(reg.is_none());
    }

    #[test]
    fn kex_config_resume_session_uuid_default_is_none() {
        let cfg = TestKexConfig;
        let uuid: Option<Uuid> = cfg.resume_session_uuid();
        assert!(uuid.is_none());
    }

    #[test]
    fn kex_config_server_id_default_is_none() {
        let cfg = TestKexConfig;
        let sid: Option<String> = cfg.server_id();
        assert!(sid.is_none());
    }

    #[test]
    fn kex_config_diff_mode_default_is_reliable() {
        let cfg = TestKexConfig;
        // `matches!` avoids requiring `PartialEq`/`Debug` derives on `DiffMode`.
        assert!(matches!(cfg.diff_mode(), DiffMode::Reliable));
    }

    /// Has no explicit config path, so the platform default location is used.
    struct TestPathDefaults;

    impl PathDefaults for TestPathDefaults {
        fn env_prefix(&self) -> String {
            "TEST".to_string()
        }

        fn config_absolute_path(&self) -> Option<String> {
            None
        }

        fn default_file_path(&self) -> String {
            "moshpit-test".to_string()
        }

        fn default_file_name(&self) -> String {
            "config".to_string()
        }

        fn tracing_absolute_path(&self) -> Option<String> {
            None
        }

        fn default_tracing_path(&self) -> String {
            "moshpit-test".to_string()
        }

        fn default_tracing_file_name(&self) -> String {
            "moshpits".to_string()
        }
    }

    #[test]
    fn default_config_file_path_ends_with_toml() {
        let defaults = TestPathDefaults;
        // The platform config directory may be unavailable (e.g. no home directory in
        // a sandboxed CI run); there is nothing to check in that case.
        let Ok(path) = default_config_file_path(&defaults) else {
            return;
        };
        assert_eq!(path.extension().and_then(|e| e.to_str()), Some("toml"));
        // `Path::ends_with` compares whole components, so this pins both the directory
        // component and the file name with its extension applied.
        assert!(
            path.ends_with(PathBuf::from("moshpit-test").join("config.toml")),
            "path must end with <default_file_path>/<default_file_name>.toml, got {}",
            path.display()
        );
    }

    /// Supplies an explicit config path; every default-path method is unused.
    struct AbsolutePathDefaults;

    impl PathDefaults for AbsolutePathDefaults {
        fn env_prefix(&self) -> String {
            "TEST".to_string()
        }

        fn config_absolute_path(&self) -> Option<String> {
            Some("/tmp/my-moshpit-config.toml".to_string())
        }

        fn default_file_path(&self) -> String {
            "unused".to_string()
        }

        fn default_file_name(&self) -> String {
            "unused".to_string()
        }

        fn tracing_absolute_path(&self) -> Option<String> {
            None
        }

        fn default_tracing_path(&self) -> String {
            "unused".to_string()
        }

        fn default_tracing_file_name(&self) -> String {
            "unused".to_string()
        }
    }

    #[test]
    fn config_file_path_uses_absolute_path_when_provided() {
        let defaults = AbsolutePathDefaults;
        let path =
            config_file_path(&defaults).expect("config_file_path must succeed with absolute path");
        assert_eq!(
            path,
            PathBuf::from("/tmp/my-moshpit-config.toml"),
            "config_file_path must return the exact absolute path from config_absolute_path()"
        );
    }
}