use crate::types::{AppConfig, AppError, Result};
use std::fs;
use std::path::{Path, PathBuf};
use directories::ProjectDirs;
use log::{info, warn};
/// System-wide configuration file, consulted first when no override is given.
const DEFAULT_CONFIG_PATH: &str = "/etc/rust-network-mgr/config.yaml";
/// Relative location of the packaged default configuration (see
/// `get_pkg_default_config_path` for what it is resolved against).
const PKG_DEFAULT_CONFIG_PATH_FALLBACK: &str = "pkg-files/config/default.yaml";

/// Path to the packaged default configuration file.
///
/// Debug builds anchor `PKG_DEFAULT_CONFIG_PATH_FALLBACK` at the crate's
/// `CARGO_MANIFEST_DIR` (captured at compile time), so it resolves inside the
/// source tree. Release builds return the relative path as-is, so it is
/// resolved against the process's current working directory.
// NOTE(review): in release builds this depends on the daemon's working
// directory — confirm that matches the installed package layout.
fn get_pkg_default_config_path() -> PathBuf {
    let base = if cfg!(debug_assertions) {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
    } else {
        // Joining onto an empty PathBuf yields the relative path unchanged.
        PathBuf::new()
    };
    base.join(PKG_DEFAULT_CONFIG_PATH_FALLBACK)
}
/// Resolve which configuration file to load.
///
/// Search order:
/// 1. `config_path_override`, if given. It must exist; there is no fallback.
/// 2. The system config at `DEFAULT_CONFIG_PATH`.
/// 3. The packaged default from `get_pkg_default_config_path` (logs a warning).
/// 4. `config.yaml` in the per-user config directory from
///    `directories::ProjectDirs` (logs a warning).
///
/// # Errors
///
/// Returns `AppError::ConfigIo` if the override path does not exist, or if
/// none of the fallback locations exist. In the latter case the message lists
/// every location that was actually checked.
fn get_config_path(config_path_override: Option<&str>) -> Result<PathBuf> {
    if let Some(path_str) = config_path_override {
        let path = PathBuf::from(path_str);
        // An explicit path is authoritative: report it missing rather than
        // silently loading some other file.
        return if path.exists() {
            Ok(path)
        } else {
            Err(AppError::ConfigIo(format!(
                "Specified config file not found: {}",
                path_str
            )))
        };
    }

    let default_path = Path::new(DEFAULT_CONFIG_PATH);
    if default_path.exists() {
        return Ok(default_path.to_path_buf());
    }

    // NOTE(review): the packaged default is checked before the per-user
    // config, so a user config is never used while the packaged default
    // exists — confirm this precedence is intended.
    let pkg_default_path = get_pkg_default_config_path();
    if pkg_default_path.exists() {
        warn!(
            "System config not found at {}, using packaged default: {}",
            DEFAULT_CONFIG_PATH,
            pkg_default_path.display()
        );
        return Ok(pkg_default_path);
    }

    // `ProjectDirs::from` yields `None` when it cannot determine a home
    // directory, in which case there is no user config location to check.
    let user_config_path = ProjectDirs::from("", "", "RustNetworkManager")
        .map(|proj_dirs| proj_dirs.config_dir().join("config.yaml"));

    match user_config_path {
        Some(path) if path.exists() => {
            warn!(
                "System config not found, using user config: {}",
                path.display()
            );
            Ok(path)
        }
        // The override is necessarily `None` here (a `Some` returned above),
        // so only the locations that were really searched are reported.
        other => {
            let user_desc = other.map_or_else(
                || "user config dir (unavailable)".to_string(),
                |p| p.display().to_string(),
            );
            Err(AppError::ConfigIo(format!(
                "Configuration file not found. Looked in: {}, {}, {}.",
                DEFAULT_CONFIG_PATH,
                pkg_default_path.display(),
                user_desc
            )))
        }
    }
}
pub fn load_config(config_path_override: Option<&str>) -> Result<AppConfig> {
let config_path = get_config_path(config_path_override)?;
info!("Loading configuration from: {}", config_path.display());
match std::fs::read_to_string(&config_path) {
Ok(content) => {
let config: AppConfig = serde_yaml::from_str(&content)
.map_err(AppError::ConfigParse)?;
validate_config(&config)?;
Ok(config)
}
Err(e) => {
Err(AppError::ConfigIo(format!(
"Failed to read configuration file '{}': {}",
config_path.display(),
e
)))
}
}
}
/// Check semantic constraints on a parsed configuration.
///
/// # Errors
///
/// Returns `AppError::ConfigValidation` when the configuration lists no
/// interfaces, or when any listed interface has an empty name. The first rule
/// that fails determines the message.
pub(crate) fn validate_config(config: &AppConfig) -> Result<()> {
    let interfaces = &config.interfaces;
    let failure = if interfaces.is_empty() {
        Some("Configuration must include at least one interface.")
    } else if interfaces.iter().any(|iface| iface.name.is_empty()) {
        Some("Interface name cannot be empty")
    } else {
        None
    };
    match failure {
        None => Ok(()),
        Some(reason) => Err(AppError::ConfigValidation(reason.to_string())),
    }
}
#[cfg(test)]
pub mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Write `yaml` to a fresh temporary file and return its handle.
    ///
    /// The file is deleted when the handle is dropped, so callers must keep it
    /// alive for as long as the path is used.
    fn write_temp_config(yaml: &str) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(file, "{}", yaml).unwrap();
        file
    }

    /// Load the configuration stored in `file` via the explicit-override path.
    fn load_from(file: &NamedTempFile) -> Result<AppConfig> {
        load_config(Some(file.path().to_str().unwrap()))
    }

    #[test]
    fn test_load_valid_config() {
        let yaml = r#"
interfaces:
  - name: eth0
    dhcp: true
    nftables_zone: wan
  - name: eth1
    address: 192.168.1.1/24
    nftables_zone: lan
socket_path: /tmp/test.sock
"#;
        let file = write_temp_config(yaml);
        let config = load_from(&file).unwrap();
        assert_eq!(config.interfaces.len(), 2);
        assert_eq!(config.interfaces[0].name, "eth0");
        assert_eq!(config.interfaces[0].dhcp, Some(true));
        assert_eq!(config.interfaces[0].nftables_zone, Some("wan".to_string()));
        assert_eq!(config.interfaces[1].name, "eth1");
        assert_eq!(config.interfaces[1].address, Some("192.168.1.1/24".to_string()));
        assert_eq!(config.interfaces[1].nftables_zone, Some("lan".to_string()));
        assert_eq!(config.socket_path, Some("/tmp/test.sock".to_string()));
    }

    #[test]
    fn test_load_fallback_config() {
        // Use a private temp file. Writing to `get_pkg_default_config_path()`
        // would overwrite the real packaged default in the source tree (and the
        // cleanup would then delete it), and would race with parallel tests.
        // This loads a fallback-shaped config through the override route; the
        // search order itself depends on whether the system config exists on
        // the host and is not exercised here.
        let fallback_yaml = r#"
interfaces:
  - name: "fallback0"
    dhcp: true
socket_path: "/tmp/fallback.sock"
"#;
        let file = write_temp_config(fallback_yaml);
        let config = load_from(&file).unwrap();
        assert_eq!(config.interfaces.len(), 1);
        assert_eq!(config.interfaces[0].name, "fallback0");
        assert_eq!(config.socket_path, Some("/tmp/fallback.sock".to_string()));
    }

    #[test]
    fn test_load_missing_override() {
        let result = load_config(Some("/nonexistent/rust-network-mgr-test/config.yaml"));
        match result {
            Err(AppError::ConfigIo(msg)) => {
                assert!(msg.contains("Specified config file not found"));
            }
            other => panic!("Expected ConfigIo error, got {:?}", other),
        }
    }

    #[test]
    fn test_load_invalid_yaml() {
        let yaml = "interfaces:\n - name: eth0\n invalid_indent: true";
        let file = write_temp_config(yaml);
        let result = load_from(&file);
        assert!(result.is_err());
        match result {
            Err(AppError::ConfigParse(_)) => {}
            _ => panic!("Expected ConfigParse error, got {:?}", result),
        }
    }

    #[test]
    fn test_validate_empty_interfaces() {
        let config = AppConfig {
            interfaces: vec![],
            socket_path: None,
            nftables_rules_path: None,
        };
        let result = validate_config(&config);
        assert!(result.is_err());
        match result {
            Err(AppError::ConfigValidation(msg)) => {
                assert!(msg.contains("at least one interface"));
            }
            _ => panic!("Expected ConfigValidation error, got {:?}", result),
        }
    }

    #[test]
    fn test_validate_empty_interface_name() {
        let file = write_temp_config("interfaces:\n  - name: \"\"\n");
        match load_from(&file) {
            Err(AppError::ConfigValidation(msg)) => {
                assert!(msg.contains("cannot be empty"));
            }
            other => panic!("Expected ConfigValidation error, got {:?}", other),
        }
    }
}