use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use crate::protocol::TimestampFormat;
/// Top-level configuration, deserialized from TOML (see [`Config::load`]).
///
/// Every section is optional. The container-level `#[serde(default)]` fills
/// any key that is absent from the input with the corresponding field of
/// `Config::default()`, so an empty document deserializes to the default
/// configuration.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct Config {
    /// Values read from the `client` key.
    pub client: ClientDefaults,
    /// Values read from the `server` key.
    pub server: ServerDefaults,
    /// Entries read from the `presets` key, in file order; see
    /// [`Config::get_preset`] for lookup by name.
    pub presets: Vec<ServerPreset>,
}
/// Optional client-side settings stored in the `client` section of [`Config`].
///
/// Every field is an `Option`; a key that is absent from the input
/// deserializes to `None`, and `ClientDefaults::default()` has every field
/// unset.
///
/// NOTE(review): this file only declares the fields. Their exact meaning and
/// accepted values are decided by whatever code consumes them, so the
/// per-field notes below are inferred from names and types and should be
/// confirmed there.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct ClientDefaults {
    // --- Test parameters (presumably; confirm against the consumer) ---
    /// Duration in seconds, going by the field name.
    pub duration_secs: Option<u64>,
    /// Number of parallel streams, going by the field name (at most 255, as a `u8`).
    pub parallel_streams: Option<u8>,
    /// Boolean toggle; presumably maps to the TCP_NODELAY socket option.
    pub tcp_nodelay: Option<bool>,
    /// Kept as free-form text here; any parsing happens elsewhere.
    pub window_size: Option<String>,

    // --- Output / presentation ---
    pub json_output: Option<bool>,
    pub no_tui: Option<bool>,
    /// Uses the crate's `TimestampFormat` type. The explicit
    /// `#[serde(default)]` makes a missing key fall back to `None`.
    #[serde(default)]
    pub timestamp_format: Option<TimestampFormat>,
    pub theme: Option<String>,

    // --- Logging ---
    pub log_file: Option<String>,
    pub log_level: Option<String>,

    // --- Connection ---
    /// Stored as plain text in the configuration file.
    pub psk: Option<String>,
    /// Kept as free-form text here; valid values are not defined in this file.
    pub address_family: Option<String>,
}
/// Optional server-side settings stored in the `server` section of [`Config`].
///
/// Every field is an `Option`; a key that is absent from the input
/// deserializes to `None`, and `ServerDefaults::default()` has every field
/// unset.
///
/// NOTE(review): this file only declares the fields. Their exact meaning,
/// units and accepted values are decided by whatever code consumes them, so
/// the grouping comments below are inferred from names and should be
/// confirmed there.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct ServerDefaults {
    // --- Listener ---
    pub port: Option<u16>,
    pub one_off: Option<bool>,
    /// Kept as free-form text here; valid values are not defined in this file.
    pub address_family: Option<String>,

    // --- Metrics (presumably Prometheus-related; confirm against the consumer) ---
    pub prometheus_port: Option<u16>,
    pub push_gateway: Option<String>,

    // --- Logging ---
    pub log_file: Option<String>,
    pub log_level: Option<String>,

    // --- Access control ---
    /// Stored as plain text in the configuration file.
    pub psk: Option<String>,
    pub rate_limit: Option<u32>,
    /// Unit is not stated in this file — TODO confirm.
    pub rate_limit_window: Option<u64>,
    /// List of strings; the entry format is not defined in this file.
    pub allow: Option<Vec<String>>,
    /// List of strings; the entry format is not defined in this file.
    pub deny: Option<Vec<String>>,
    pub acl_file: Option<String>,
}
/// A named entry of the `presets` array in [`Config`].
///
/// `name` is the only required key: it is a plain `String` with no serde
/// default, so an entry without it fails to deserialize. The remaining fields
/// are optional and become `None` when absent. Unlike the other configuration
/// structs, this type does not derive `Default`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ServerPreset {
    /// Identifier matched exactly by [`Config::get_preset`].
    pub name: String,
    /// Kept as free-form text here (the tests in this file use values such as
    /// `"100M"` and `"1G"`); any parsing happens elsewhere.
    pub bandwidth_limit: Option<String>,
    /// List of strings (the tests in this file use `"192.168.1.0/24"`); the
    /// accepted entry format is not defined in this file.
    pub allowed_clients: Option<Vec<String>>,
    /// Duration in seconds, going by the field name — confirm against the consumer.
    pub max_duration_secs: Option<u64>,
}
impl Config {
pub fn load() -> anyhow::Result<Self> {
let config_path = Self::config_path();
if config_path.exists() {
let contents = std::fs::read_to_string(&config_path)?;
Ok(toml::from_str(&contents)?)
} else {
Ok(Self::default())
}
}
pub fn config_path() -> PathBuf {
dirs::config_dir()
.unwrap_or_else(|| PathBuf::from("."))
.join("xfr")
.join("config.toml")
}
pub fn get_preset(&self, name: &str) -> Option<&ServerPreset> {
self.presets.iter().find(|p| p.name == name)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A default-constructed configuration has no presets and no client
    /// duration set.
    #[test]
    fn test_default_config() {
        let defaults = Config::default();
        assert!(defaults.presets.is_empty());
        assert!(defaults.client.duration_secs.is_none());
    }

    /// A document with client, server and preset sections deserializes into
    /// the matching fields.
    #[test]
    fn test_parse_config() {
        let input = r#"
[client]
duration_secs = 30
parallel_streams = 4
tcp_nodelay = true
[server]
port = 9000
prometheus_port = 9090
[[presets]]
name = "limited"
bandwidth_limit = "100M"
max_duration_secs = 60
[[presets]]
name = "internal"
allowed_clients = ["192.168.1.0/24"]
"#;
        let parsed = toml::from_str::<Config>(input).unwrap();

        assert_eq!(parsed.client.duration_secs, Some(30));
        assert_eq!(parsed.client.parallel_streams, Some(4));
        assert_eq!(parsed.server.port, Some(9000));

        assert_eq!(parsed.presets.len(), 2);
        let first = &parsed.presets[0];
        assert_eq!(first.name, "limited");
        assert_eq!(first.bandwidth_limit.as_deref(), Some("100M"));
    }

    /// `get_preset` finds entries by name and returns `None` for unknown names.
    #[test]
    fn test_get_preset() {
        let input = r#"
[[presets]]
name = "fast"
bandwidth_limit = "1G"
[[presets]]
name = "slow"
bandwidth_limit = "10M"
"#;
        let parsed = toml::from_str::<Config>(input).unwrap();

        let fast = parsed.get_preset("fast").unwrap();
        assert_eq!(fast.bandwidth_limit.as_deref(), Some("1G"));

        let slow = parsed.get_preset("slow").unwrap();
        assert_eq!(slow.bandwidth_limit.as_deref(), Some("10M"));

        assert!(parsed.get_preset("nonexistent").is_none());
    }
}