use serde::{Deserialize, Serialize};
/// Authentication settings that can be serialized and deserialized with serde.
///
/// `Debug` is implemented by hand rather than derived so that `secret_key`
/// is never written verbatim into logs or panic messages.
#[derive(Clone, Serialize, Deserialize)]
pub struct AuthConfig {
    /// Whether authentication is enabled.
    pub enable_auth: bool,
    /// Secret key material; kept out of the `Debug` representation.
    pub secret_key: String,
    /// Token lifetime, expressed in hours.
    pub token_lifetime_hours: u64,
    /// Allowed network entries, stored as strings (e.g. `192.168.1.0/24`).
    pub allowed_ips: Vec<String>,
}

impl std::fmt::Debug for AuthConfig {
    /// Formats every field except `secret_key`, which is shown as `<redacted>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthConfig")
            .field("enable_auth", &self.enable_auth)
            .field("secret_key", &"<redacted>")
            .field("token_lifetime_hours", &self.token_lifetime_hours)
            .field("allowed_ips", &self.allowed_ips)
            .finish()
    }
}
impl Default for AuthConfig {
fn default() -> Self {
Self {
enable_auth: false,
secret_key: generate_random_key(),
token_lifetime_hours: 24,
allowed_ips: vec![],
}
}
}
impl AuthConfig {
    /// Creates a configuration populated with the [`Default`] values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the configuration with authentication switched on.
    #[must_use = "builder methods return the updated config; dropping it discards the change"]
    pub fn enable_auth(mut self) -> Self {
        self.enable_auth = true;
        self
    }

    /// Returns the configuration with `secret_key` replaced by `key`.
    #[must_use = "builder methods return the updated config; dropping it discards the change"]
    pub fn with_secret_key(mut self, key: String) -> Self {
        self.secret_key = key;
        self
    }

    /// Returns the configuration with the token lifetime set to `hours`.
    #[must_use = "builder methods return the updated config; dropping it discards the change"]
    pub fn with_token_lifetime(mut self, hours: u64) -> Self {
        self.token_lifetime_hours = hours;
        self
    }

    /// Returns the configuration with `ip` appended to the allow-list.
    ///
    /// The entry is stored as given; it is only checked later by
    /// [`AuthConfig::validate`], which expects CIDR notation such as
    /// `192.168.1.0/24`.
    #[must_use = "builder methods return the updated config; dropping it discards the change"]
    pub fn add_allowed_ip(mut self, ip: String) -> Self {
        self.allowed_ips.push(ip);
        self
    }

    /// Parses a configuration from a TOML document.
    ///
    /// # Errors
    ///
    /// Returns the `toml` deserialization error when `toml_str` cannot be
    /// turned into an `AuthConfig`.
    pub fn from_toml(toml_str: &str) -> Result<Self, toml::de::Error> {
        toml::from_str(toml_str)
    }

    /// Renders the configuration as a pretty-printed TOML document.
    ///
    /// # Errors
    ///
    /// Returns the `toml` serialization error if rendering fails.
    pub fn to_toml(&self) -> Result<String, toml::ser::Error> {
        toml::to_string_pretty(self)
    }

    /// Checks the configuration and reports the first problem found.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when, in this order:
    /// - authentication is enabled and the secret key is empty;
    /// - the secret key is shorter than 16 bytes (checked regardless of
    ///   `enable_auth`; `String::len` counts bytes, not characters);
    /// - the token lifetime is `0`;
    /// - any allow-list entry is rejected by `validate_cidr`.
    pub fn validate(&self) -> Result<(), String> {
        if self.enable_auth && self.secret_key.is_empty() {
            return Err("Secret key cannot be empty when auth is enabled".to_string());
        }
        if self.secret_key.len() < 16 {
            return Err("Secret key should be at least 16 characters".to_string());
        }
        if self.token_lifetime_hours == 0 {
            return Err("Token lifetime must be greater than 0".to_string());
        }
        // Stops at the first invalid entry and propagates its message.
        self.allowed_ips
            .iter()
            .try_for_each(|cidr| validate_cidr(cidr))
    }
}
fn generate_random_key() -> String {
use uuid::Uuid;
format!("{}-{}", Uuid::new_v4(), Uuid::new_v4())
}
/// Validates that `cidr` is an IPv4 or IPv6 network in `address/prefix` form.
///
/// The address part is treated as IPv6 when it contains a `:` and as IPv4
/// otherwise. The prefix must consist solely of ASCII digits and lie within
/// `0-32` (IPv4) or `0-128` (IPv6).
///
/// # Errors
///
/// Returns a message naming the offending input when the `/` separator is
/// missing, the prefix is not a plain decimal number, the address does not
/// parse, or the prefix is out of range for the address family.
fn validate_cidr(cidr: &str) -> Result<(), String> {
    let (ip_str, prefix_str) = cidr
        .split_once('/')
        .ok_or_else(|| format!("Invalid CIDR '{}': missing prefix length (e.g. /24)", cidr))?;

    // Integer `parse` accepts a leading `+`, which is not valid CIDR syntax,
    // so insist on digits only before parsing.
    if prefix_str.is_empty() || !prefix_str.bytes().all(|b| b.is_ascii_digit()) {
        return Err(format!(
            "Invalid CIDR '{}': prefix length is not a number",
            cidr
        ));
    }
    // The text is all digits, so a parse failure can only mean overflow:
    // the value is a number, just far too large for either address family.
    let prefix_len: Option<u8> = prefix_str.parse().ok();

    let (family, max_prefix) = if ip_str.contains(':') {
        ip_str
            .parse::<std::net::Ipv6Addr>()
            .map_err(|_| format!("Invalid CIDR '{}': malformed IPv6 address", cidr))?;
        ("IPv6", 128u8)
    } else {
        ip_str
            .parse::<std::net::Ipv4Addr>()
            .map_err(|_| format!("Invalid CIDR '{}': malformed IPv4 address", cidr))?;
        ("IPv4", 32u8)
    };

    match prefix_len {
        Some(len) if len <= max_prefix => Ok(()),
        _ => Err(format!(
            "Invalid CIDR '{}': {} prefix length must be 0-{}",
            cidr, family, max_prefix
        )),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration has auth off, a non-empty key and 24-hour tokens.
    #[test]
    fn test_default_config() {
        let AuthConfig {
            enable_auth,
            secret_key,
            token_lifetime_hours,
            ..
        } = AuthConfig::default();

        assert!(!enable_auth);
        assert!(!secret_key.is_empty());
        assert_eq!(token_lifetime_hours, 24);
    }

    /// Each builder method updates exactly the field it is named after.
    #[test]
    fn test_config_builder() {
        let built = AuthConfig::new()
            .enable_auth()
            .with_secret_key(String::from("my-secret-key-12345"))
            .with_token_lifetime(48)
            .add_allowed_ip(String::from("192.168.1.0/24"));

        assert!(built.enable_auth);
        assert_eq!(built.secret_key, "my-secret-key-12345");
        assert_eq!(built.token_lifetime_hours, 48);
        assert_eq!(built.allowed_ips.len(), 1);
    }

    /// A too-short key fails validation; a sufficiently long one passes.
    #[test]
    fn test_config_validation() {
        let with_key =
            |key: &str| AuthConfig::new().enable_auth().with_secret_key(key.to_string());

        assert!(with_key("short").validate().is_err());
        assert!(with_key("valid-secret-key-16-chars").validate().is_ok());
    }

    /// A config survives a TOML round trip with its key and auth flag intact.
    #[test]
    fn test_toml_serialization() {
        let original = AuthConfig::new()
            .enable_auth()
            .with_secret_key(String::from("test-secret-key-12345"))
            .with_token_lifetime(48);

        let rendered = original.to_toml().unwrap();
        assert!(rendered.contains("enable_auth = true"));
        assert!(rendered.contains("test-secret-key-12345"));

        let reloaded = AuthConfig::from_toml(&rendered).unwrap();
        assert_eq!(reloaded.enable_auth, original.enable_auth);
        assert_eq!(reloaded.secret_key, original.secret_key);
    }
}