use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use crate::utils::SMirrorsError;
/// Top-level smirrors configuration, read from and written to TOML.
///
/// Any field missing from the document is taken from `Config::default()`, so
/// an empty document yields the default configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct Config {
    /// `[general]`: update interval, auto-update flag, test concurrency,
    /// timeout and retries.
    pub general: GeneralConfig,
    /// `[testing]`: mirror scoring weights and selection limits.
    pub testing: TestingConfig,
    /// `[distro]`: distribution detection and backup settings.
    pub distro: DistroConfig,
    /// `[static_mirrors]`: free-form string-to-string table (presumably
    /// mirror name -> URL; confirm against callers). `Config::merge` adds to
    /// this table rather than replacing it.
    pub static_mirrors: HashMap<String, String>,
    /// `[logging]`: log level, format and optional file.
    pub logging: LoggingConfig,
    /// `[notifications]`: notification toggles.
    pub notifications: NotificationConfig,
}
/// `[general]` section of the configuration.
///
/// Fields omitted from a present `[general]` table are filled in from
/// `GeneralConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct GeneralConfig {
    /// Interval string checked with `crate::utils::parse_duration`;
    /// defaults to `"1h"`.
    pub update_interval: String,
    /// Defaults to `true`.
    pub auto_update: bool,
    /// Concurrency for mirror tests; `Config::validate` requires 1..=100.
    /// Defaults to 10.
    pub concurrent_tests: usize,
    /// Timeout in seconds; `Config::validate` requires 1..=300.
    /// Defaults to 10.
    pub timeout: u64,
    /// Retry count (scope of a "retry" is decided by callers — confirm).
    /// Defaults to 3.
    pub retries: u32,
}
/// `[testing]` section: mirror scoring weights and selection limits.
///
/// Fields omitted from a present `[testing]` table are filled in from
/// `TestingConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct TestingConfig {
    /// Speed weight for scoring. `Config::validate` requires it to be
    /// non-negative and to sum with `latency_weight` to 1.0 (within 0.01).
    /// Defaults to 0.7.
    pub speed_weight: f64,
    /// Latency weight for scoring; see `speed_weight`. Defaults to 0.3.
    pub latency_weight: f64,
    /// Size string checked with `crate::utils::parse_size`; defaults to `"1MB"`.
    pub test_file_size: String,
    /// `Config::validate` requires 1..=50. Defaults to 5.
    pub max_mirrors: usize,
    /// Preferred countries (presumably country codes — confirm); empty by default.
    pub country_preference: Vec<String>,
    /// `Config::validate` requires 0.0..=1.0. Defaults to 0.3.
    pub min_score: f64,
}
/// `[distro]` section: distribution detection and backup settings.
///
/// The serde defaults below deliberately match `DistroConfig::default()`.
/// That way a partially written `[distro]` table behaves exactly like a
/// missing one. Previously `auto_detect` and `preserve_comments` deserialized
/// to `false` when omitted, even though the defaults are `true`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistroConfig {
    /// Defaults to `true`.
    #[serde(default = "default_true")]
    pub auto_detect: bool,
    /// Optional explicit distribution name; `None` when not set.
    pub override_distro: Option<String>,
    /// Defaults to `true`.
    #[serde(default = "default_true")]
    pub preserve_comments: bool,
    /// Defaults to `true`.
    #[serde(default = "default_true")]
    pub create_backup: bool,
    /// Defaults to 5.
    #[serde(default = "default_backup_count")]
    pub backup_count: usize,
}
/// `[logging]` section.
///
/// Fields omitted from a present `[logging]` table are filled in from
/// `LoggingConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct LoggingConfig {
    /// Log level string; defaults to `"info"`.
    pub level: String,
    /// Log format string; defaults to `"pretty"`.
    pub format: String,
    /// Optional log file path; `None` by default.
    pub file: Option<PathBuf>,
}
/// `[notifications]` section.
///
/// Fields omitted from a present `[notifications]` table are filled in from
/// `NotificationConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct NotificationConfig {
    /// Defaults to `false`.
    pub enabled: bool,
    /// Defaults to `false`.
    pub on_success: bool,
    /// Defaults to `true`.
    pub on_failure: bool,
}
/// Default for `general.update_interval`.
fn default_interval() -> String {
    String::from("1h")
}

/// Default for boolean options that are enabled unless turned off.
fn default_true() -> bool {
    true
}

/// Default for `general.concurrent_tests`.
fn default_concurrent() -> usize {
    10
}

/// Default for `general.timeout`, in seconds.
fn default_timeout() -> u64 {
    10
}

/// Default for `general.retries`.
fn default_retries() -> u32 {
    3
}

/// Default for `testing.speed_weight`; pairs with `default_latency_weight` to sum to 1.0.
fn default_speed_weight() -> f64 {
    0.7
}

/// Default for `testing.latency_weight`; pairs with `default_speed_weight` to sum to 1.0.
fn default_latency_weight() -> f64 {
    0.3
}

/// Default for `testing.test_file_size`.
fn default_test_size() -> String {
    String::from("1MB")
}

/// Default for `testing.max_mirrors`.
fn default_max_mirrors() -> usize {
    5
}

/// Default for `testing.min_score`.
fn default_min_score() -> f64 {
    0.3
}

/// Default for `distro.backup_count`.
fn default_backup_count() -> usize {
    5
}

/// Default for `logging.level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Default for `logging.format`.
fn default_log_format() -> String {
    String::from("pretty")
}
impl Default for GeneralConfig {
fn default() -> Self {
Self {
update_interval: default_interval(),
auto_update: default_true(),
concurrent_tests: default_concurrent(),
timeout: default_timeout(),
retries: default_retries(),
}
}
}
impl Default for TestingConfig {
fn default() -> Self {
Self {
speed_weight: default_speed_weight(),
latency_weight: default_latency_weight(),
test_file_size: default_test_size(),
max_mirrors: default_max_mirrors(),
country_preference: Vec::new(),
min_score: default_min_score(),
}
}
}
impl Default for DistroConfig {
fn default() -> Self {
Self {
auto_detect: true,
override_distro: None,
preserve_comments: true,
create_backup: default_true(),
backup_count: default_backup_count(),
}
}
}
impl Default for LoggingConfig {
fn default() -> Self {
Self {
level: default_log_level(),
format: default_log_format(),
file: None,
}
}
}
impl Default for NotificationConfig {
fn default() -> Self {
Self {
enabled: false,
on_success: false,
on_failure: default_true(),
}
}
}
impl Default for Config {
fn default() -> Self {
Self {
general: GeneralConfig::default(),
testing: TestingConfig::default(),
distro: DistroConfig::default(),
static_mirrors: HashMap::new(),
logging: LoggingConfig::default(),
notifications: NotificationConfig::default(),
}
}
}
impl Config {
    /// Loads the configuration from [`Config::config_path`].
    ///
    /// If no file exists there yet, the built-in defaults are written to that
    /// location (creating parent directories as needed) and returned.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - the config location cannot be determined;
    /// - an existing file cannot be read or parsed, or fails [`Config::validate`];
    /// - on first run, the default file cannot be written.
    pub fn load() -> Result<Self> {
        let config_path = Self::config_path()?;
        if config_path.exists() {
            Self::read_validated(&config_path)
        } else {
            let config = Self::default();
            config.save_to(&config_path)?;
            Ok(config)
        }
    }

    /// Loads and validates the configuration file at `path`.
    ///
    /// Unlike [`Config::load`], a missing file is an error rather than a
    /// trigger to write defaults.
    ///
    /// # Errors
    ///
    /// Returns `SMirrorsError::ConfigNotFound` if `path` does not exist.
    /// Propagates read, parse and validation failures.
    pub fn load_from(path: &PathBuf) -> Result<Self> {
        if !path.exists() {
            return Err(SMirrorsError::ConfigNotFound(path.display().to_string()).into());
        }
        Self::read_validated(path)
    }

    /// Reads, parses and validates the TOML file at `path`.
    ///
    /// Shared by `load` and `load_from` so both report errors the same way.
    fn read_validated(path: &std::path::Path) -> Result<Self> {
        let content = std::fs::read_to_string(path)
            .with_context(|| format!("Failed to read config file at {:?}", path))?;
        let config: Config = toml::from_str(&content)
            .with_context(|| format!("Failed to parse configuration file at {:?}", path))?;
        config.validate()?;
        Ok(config)
    }

    /// Validates and writes the configuration to [`Config::config_path`].
    ///
    /// The file is first written to a sibling `.tmp` file, then renamed over
    /// the target.
    ///
    /// # Errors
    ///
    /// Fails on validation, path resolution, serialization or I/O errors.
    pub fn save(&self) -> Result<()> {
        self.validate()?;
        let config_path = Self::config_path()?;
        self.write_via_temp_file(&config_path)
    }

    /// Validates and writes the configuration to `path`.
    ///
    /// Uses the same temp-file-then-rename strategy as [`Config::save`], so an
    /// interrupted write does not truncate an existing file.
    ///
    /// # Errors
    ///
    /// Fails on validation, serialization or I/O errors.
    pub fn save_to(&self, path: &PathBuf) -> Result<()> {
        self.validate()?;
        self.write_via_temp_file(path)
    }

    /// Serializes `self` to `<path>.tmp`, then renames that file onto `path`.
    ///
    /// Parent directories are created first. If writing or renaming fails,
    /// the temporary file is removed on a best-effort basis.
    fn write_via_temp_file(&self, path: &std::path::Path) -> Result<()> {
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).context("Failed to create config directory")?;
        }
        let content =
            toml::to_string_pretty(self).context("Failed to serialize configuration")?;
        let temp_path = Self::temp_path_for(path)?;
        let result = std::fs::write(&temp_path, content)
            .context("Failed to write temporary config file")
            .and_then(|()| {
                std::fs::rename(&temp_path, path).context("Failed to save configuration file")
            });
        if result.is_err() {
            // Cleanup is best-effort; the write/rename error is what the caller needs.
            let _ = std::fs::remove_file(&temp_path);
        }
        result
    }

    /// Returns `path` with `.tmp` appended to its file name
    /// (e.g. `config.toml` -> `config.toml.tmp`).
    ///
    /// The temp file therefore stays in the same directory as the target.
    fn temp_path_for(path: &std::path::Path) -> Result<PathBuf> {
        let file_name = path
            .file_name()
            .with_context(|| format!("Config path {:?} has no file name", path))?;
        let mut temp_name = file_name.to_os_string();
        temp_name.push(".tmp");
        Ok(path.with_file_name(temp_name))
    }

    /// Location of the configuration file.
    ///
    /// This is `/etc/smirrors/config.toml` when running as root (effective
    /// UID); otherwise it is `config.toml` in the per-user project config
    /// directory.
    pub fn config_path() -> Result<PathBuf> {
        if nix::unistd::geteuid().is_root() {
            return Ok(PathBuf::from("/etc/smirrors/config.toml"));
        }
        Ok(Self::project_dirs("config")?.config_dir().join("config.toml"))
    }

    /// Data directory.
    ///
    /// This is `/var/lib/smirrors` as root; otherwise it is the per-user
    /// local data directory.
    pub fn data_dir() -> Result<PathBuf> {
        if nix::unistd::geteuid().is_root() {
            return Ok(PathBuf::from("/var/lib/smirrors"));
        }
        Ok(Self::project_dirs("data")?.data_local_dir().to_path_buf())
    }

    /// Cache directory.
    ///
    /// This is `/var/cache/smirrors` as root; otherwise it is the per-user
    /// cache directory.
    pub fn cache_dir() -> Result<PathBuf> {
        if nix::unistd::geteuid().is_root() {
            return Ok(PathBuf::from("/var/cache/smirrors"));
        }
        Ok(Self::project_dirs("cache")?.cache_dir().to_path_buf())
    }

    /// Resolves the per-user smirrors project directories.
    ///
    /// `purpose` is only used in the error message.
    fn project_dirs(purpose: &str) -> Result<directories::ProjectDirs> {
        directories::ProjectDirs::from("com", "smirrors", "smirrors")
            .with_context(|| format!("Could not determine {} directory", purpose))
    }

    /// Checks cross-field invariants and value ranges.
    ///
    /// # Errors
    ///
    /// Returns `SMirrorsError::ConfigError` in any of these cases:
    /// - either weight is NaN, infinite or negative;
    /// - the weights do not sum to 1.0 (within 0.01);
    /// - `min_score` is outside 0.0..=1.0 (NaN included);
    /// - `concurrent_tests`, `max_mirrors` or `timeout` is out of range.
    ///
    /// Returns a contextualized error when the update interval or test file
    /// size does not parse.
    pub fn validate(&self) -> Result<()> {
        let testing = &self.testing;
        let general = &self.general;

        // Rejecting non-finite values first matters: every comparison against
        // NaN is false, so NaN would otherwise pass all the range checks below.
        let weights_ok = [testing.speed_weight, testing.latency_weight]
            .iter()
            .all(|w| w.is_finite() && *w >= 0.0);
        if !weights_ok {
            return Self::config_error("Weights must be finite and non-negative");
        }
        if (testing.speed_weight + testing.latency_weight - 1.0).abs() > 0.01 {
            return Self::config_error("Speed and latency weights should sum to 1.0");
        }
        // `RangeInclusive::contains` is false for NaN, unlike `< 0.0 || > 1.0`.
        if !(0.0..=1.0).contains(&testing.min_score) {
            return Self::config_error("Min score must be between 0.0 and 1.0");
        }
        if !(1..=100).contains(&general.concurrent_tests) {
            return Self::config_error("Concurrent tests must be between 1 and 100");
        }
        if !(1..=50).contains(&testing.max_mirrors) {
            return Self::config_error("Max mirrors must be between 1 and 50");
        }
        if !(1..=300).contains(&general.timeout) {
            return Self::config_error("Timeout must be between 1 and 300 seconds");
        }
        crate::utils::parse_duration(&general.update_interval)
            .context("Invalid update interval format")?;
        crate::utils::parse_size(&testing.test_file_size)
            .context("Invalid test file size format")?;
        Ok(())
    }

    /// Builds an `Err` carrying `SMirrorsError::ConfigError(message)`.
    fn config_error<T>(message: impl Into<String>) -> Result<T> {
        Err(SMirrorsError::ConfigError(message.into()).into())
    }

    /// Splits a `section.key` string.
    ///
    /// Returns `None` unless it contains exactly one `.`.
    fn split_key(key: &str) -> Option<(&str, &str)> {
        let mut parts = key.split('.');
        match (parts.next(), parts.next(), parts.next()) {
            (Some(section), Some(field), None) => Some((section, field)),
            _ => None,
        }
    }

    /// Sets the option named by `key` (`section.key`) from its string form,
    /// then re-validates the whole configuration.
    ///
    /// The change is applied atomically: if parsing or validation fails,
    /// `self` is left exactly as it was.
    ///
    /// Validation covers the whole configuration. As a result, changing
    /// `testing.speed_weight` or `testing.latency_weight` alone is rejected
    /// unless the two still sum to 1.0 (within 0.01).
    ///
    /// # Errors
    ///
    /// Fails for a malformed or unknown key, an unparsable value, or a value
    /// that fails [`Config::validate`].
    pub fn set(&mut self, key: &str, value: &str) -> Result<()> {
        let (section, field) = match Self::split_key(key) {
            Some(parts) => parts,
            None => return Self::config_error("Key must be in format 'section.key'"),
        };
        // Work on a copy and commit only after validation succeeds, so a
        // rejected value never leaves `self` in an invalid state.
        let mut updated = self.clone();
        match (section, field) {
            ("general", "update_interval") => {
                crate::utils::parse_duration(value).context("Invalid update interval format")?;
                updated.general.update_interval = value.to_string();
            }
            ("general", "auto_update") => {
                updated.general.auto_update =
                    value.parse().context("Value must be true or false")?;
            }
            ("general", "concurrent_tests") => {
                updated.general.concurrent_tests =
                    value.parse().context("Value must be a number")?;
            }
            ("general", "timeout") => {
                updated.general.timeout = value.parse().context("Value must be a number")?;
            }
            ("general", "retries") => {
                updated.general.retries = value.parse().context("Value must be a number")?;
            }
            ("testing", "speed_weight") => {
                updated.testing.speed_weight = value.parse().context("Value must be a number")?;
            }
            ("testing", "latency_weight") => {
                updated.testing.latency_weight =
                    value.parse().context("Value must be a number")?;
            }
            ("testing", "test_file_size") => {
                crate::utils::parse_size(value).context("Invalid test file size format")?;
                updated.testing.test_file_size = value.to_string();
            }
            ("testing", "max_mirrors") => {
                updated.testing.max_mirrors = value.parse().context("Value must be a number")?;
            }
            ("testing", "min_score") => {
                updated.testing.min_score = value.parse().context("Value must be a number")?;
            }
            ("logging", "level") => {
                updated.logging.level = value.to_string();
            }
            ("logging", "format") => {
                updated.logging.format = value.to_string();
            }
            _ => return Self::config_error(format!("Unknown configuration key: {}", key)),
        }
        updated.validate()?;
        *self = updated;
        Ok(())
    }

    /// Returns the string form of the option named by `key` (`section.key`).
    ///
    /// Returns `None` for malformed or unknown keys. The keys mirror those
    /// accepted by [`Config::set`].
    pub fn get(&self, key: &str) -> Option<String> {
        let (section, field) = Self::split_key(key)?;
        let value = match (section, field) {
            ("general", "update_interval") => self.general.update_interval.clone(),
            ("general", "auto_update") => self.general.auto_update.to_string(),
            ("general", "concurrent_tests") => self.general.concurrent_tests.to_string(),
            ("general", "timeout") => self.general.timeout.to_string(),
            ("general", "retries") => self.general.retries.to_string(),
            ("testing", "speed_weight") => self.testing.speed_weight.to_string(),
            ("testing", "latency_weight") => self.testing.latency_weight.to_string(),
            ("testing", "test_file_size") => self.testing.test_file_size.clone(),
            ("testing", "max_mirrors") => self.testing.max_mirrors.to_string(),
            ("testing", "min_score") => self.testing.min_score.to_string(),
            ("logging", "level") => self.logging.level.clone(),
            ("logging", "format") => self.logging.format.clone(),
            _ => return None,
        };
        Some(value)
    }

    /// Overlays `other` onto `self`.
    ///
    /// Every section is replaced wholesale by `other`'s copy; there is no
    /// per-field merging. `static_mirrors` entries are combined: `other` wins
    /// on key collisions, and entries only present in `self` are kept.
    pub fn merge(&mut self, other: &Config) {
        self.general.clone_from(&other.general);
        self.testing.clone_from(&other.testing);
        self.distro.clone_from(&other.distro);
        self.logging.clone_from(&other.logging);
        self.notifications.clone_from(&other.notifications);
        self.static_mirrors.extend(
            other
                .static_mirrors
                .iter()
                .map(|(key, value)| (key.clone(), value.clone())),
        );
    }
}