use crate::errors::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
/// Shared interface for configuration types.
///
/// Implementors only have to provide [`ServerConfig::config_name`]; the other
/// methods come with permissive provided implementations that may be overridden.
pub trait ServerConfig {
    /// Returns the static name of this configuration.
    fn config_name(&self) -> &'static str;

    /// Checks the configuration.
    ///
    /// The provided implementation performs no checks and always returns `Ok(())`.
    fn validate(&self) -> Result<()> {
        Ok(())
    }

    /// Reports whether the configuration is enabled.
    ///
    /// The provided implementation always returns `true`.
    fn is_enabled(&self) -> bool {
        true
    }
}
/// Groups the connect, read and write timeout durations.
///
/// The `Default` implementation sets all three to 30 seconds. This type only
/// stores the values; the code that applies them is outside this file.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TimeoutConfig {
    /// Timeout for the connect phase.
    pub connect_timeout: Duration,
    /// Timeout for reads.
    pub read_timeout: Duration,
    /// Timeout for writes.
    pub write_timeout: Duration,
}
impl Default for TimeoutConfig {
    /// Builds a configuration in which every timeout is 30 seconds.
    fn default() -> Self {
        // `Duration` is `Copy`, so one value can seed all three fields.
        let thirty_seconds = Duration::from_secs(30);
        Self {
            connect_timeout: thirty_seconds,
            read_timeout: thirty_seconds,
            write_timeout: thirty_seconds,
        }
    }
}
/// TLS and certificate-handling settings.
///
/// This type only carries the settings; nothing in this file enforces them.
///
/// NOTE(review): `cert_validation`, `verify_certificates` and
/// `accept_invalid_certs` appear to overlap — confirm with the consuming code
/// which one takes precedence.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SecurityConfig {
    /// TLS switch; `true` in the `Default` implementation.
    pub enable_tls: bool,
    /// Minimum TLS version kept as text; `"1.2"` in the `Default` implementation.
    pub min_tls_version: String,
    /// Cipher suite names; the `Default` implementation lists three suites.
    pub cipher_suites: Vec<String>,
    /// Certificate validation mode; `CertificateValidation::Full` by default.
    pub cert_validation: CertificateValidation,
    /// Certificate verification flag; `true` in the `Default` implementation.
    pub verify_certificates: bool,
    /// Flag for accepting invalid certificates; `false` by default.
    ///
    /// Because of `#[serde(default)]` the field may be omitted from serialized
    /// input, in which case it deserializes as `false`.
    #[serde(default)]
    pub accept_invalid_certs: bool,
}
impl Default for SecurityConfig {
    /// Builds the strict default: TLS on, minimum version `"1.2"`, three
    /// cipher suites, full certificate validation, verification on and
    /// invalid certificates rejected.
    fn default() -> Self {
        let cipher_suites: Vec<String> = [
            "TLS_AES_256_GCM_SHA384",
            "TLS_CHACHA20_POLY1305_SHA256",
            "TLS_AES_128_GCM_SHA256",
        ]
        .iter()
        .map(|suite| String::from(*suite))
        .collect();

        Self {
            enable_tls: true,
            min_tls_version: String::from("1.2"),
            cipher_suites,
            cert_validation: CertificateValidation::Full,
            verify_certificates: true,
            accept_invalid_certs: false,
        }
    }
}
/// Certificate validation modes selectable in [`SecurityConfig`].
///
/// NOTE(review): the descriptions below are inferred from the variant names;
/// the code that interprets them is not part of this file.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum CertificateValidation {
    /// Presumably full validation; used by `SecurityConfig::default()`.
    Full,
    /// Presumably validation without the hostname check.
    SkipHostname,
    /// Presumably no validation at all.
    None,
}
/// Settings describing a single endpoint.
///
/// Build one with [`EndpointConfig::new`] and refine it with the `with_*`
/// builder methods.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EndpointConfig {
    /// Base URL, stored exactly as supplied to the constructor.
    pub base_url: String,
    /// Optional API version; `None` until `with_api_version` is called.
    pub api_version: Option<String>,
    /// Header name/value pairs collected through `with_header`.
    pub headers: HashMap<String, String>,
    /// Timeout settings; `EndpointConfig::new` starts from the defaults.
    pub timeout: TimeoutConfig,
    /// Security settings; `EndpointConfig::new` starts from the defaults.
    pub security: SecurityConfig,
}
impl EndpointConfig {
    /// Creates a configuration for `base_url` with no API version, no headers
    /// and default timeout and security settings.
    ///
    /// The URL is stored as given; this constructor does not validate it.
    pub fn new(base_url: impl Into<String>) -> Self {
        let base_url = base_url.into();
        Self {
            base_url,
            api_version: None,
            headers: HashMap::new(),
            timeout: TimeoutConfig::default(),
            security: SecurityConfig::default(),
        }
    }

    /// Adds a header and returns the updated configuration.
    ///
    /// A header already stored under `key` is replaced.
    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (name, contents) = (key.into(), value.into());
        self.headers.insert(name, contents);
        self
    }

    /// Sets the API version and returns the updated configuration.
    pub fn with_api_version(self, version: impl Into<String>) -> Self {
        Self {
            api_version: Some(version.into()),
            ..self
        }
    }
}
/// Retry and backoff parameters.
///
/// This type only stores the values; the retry loop that would consume them
/// is not part of this file.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RetryConfig {
    /// Attempt limit; `3` in the `Default` implementation.
    pub max_attempts: u32,
    /// Initial delay; 100 milliseconds in the `Default` implementation.
    pub initial_delay: Duration,
    /// Maximum delay; 30 seconds in the `Default` implementation.
    pub max_delay: Duration,
    /// Backoff multiplier; `2.0` in the `Default` implementation.
    pub backoff_multiplier: f64,
    /// Jitter factor; `0.1` in the `Default` implementation.
    pub jitter_factor: f64,
}
impl Default for RetryConfig {
    /// Builds the default policy: 3 attempts, delays from 100 ms up to 30 s,
    /// a multiplier of 2.0 and a jitter factor of 0.1.
    fn default() -> Self {
        let initial_delay = Duration::from_millis(100);
        let max_delay = Duration::from_secs(30);
        Self {
            max_attempts: 3,
            initial_delay,
            max_delay,
            backoff_multiplier: 2.0,
            jitter_factor: 0.1,
        }
    }
}
/// Logging switches and limits.
///
/// Every switch is `false` in the `Default` implementation. How the values
/// are applied is decided by code outside this file.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct LoggingConfig {
    /// Debug switch; off by default.
    pub debug: bool,
    /// Switch for logging bodies; off by default.
    pub log_bodies: bool,
    /// Switch for logging sensitive data; off by default.
    pub log_sensitive: bool,
    /// Maximum log size; `4096` by default (unit not specified in this file).
    pub max_log_size: usize,
}
impl Default for LoggingConfig {
    /// Builds the quiet default: every switch off and a maximum log size of 4096.
    fn default() -> Self {
        Self {
            max_log_size: 4 * 1024,
            debug: false,
            log_bodies: false,
            log_sensitive: false,
        }
    }
}
pub mod validation {
use super::*;
pub fn validate_url(url: &str) -> Result<()> {
if url.is_empty() {
return Err(crate::errors::AuthError::config(
"URL cannot be empty".to_string(),
));
}
if !url.starts_with("http://") && !url.starts_with("https://") {
return Err(crate::errors::AuthError::config(format!(
"Invalid URL format: {}",
url
)));
}
Ok(())
}
pub fn validate_positive_duration(duration: &Duration, field_name: &str) -> Result<()> {
if duration.is_zero() {
return Err(crate::errors::AuthError::config(format!(
"{} must be greater than zero",
field_name
)));
}
Ok(())
}
pub fn validate_port(port: u16) -> Result<()> {
if port == 0 {
return Err(crate::errors::AuthError::config(
"Port cannot be zero".to_string(),
));
}
Ok(())
}
pub fn validate_required_field(value: &str, field_name: &str) -> Result<()> {
if value.trim().is_empty() {
return Err(crate::errors::AuthError::config(format!(
"{} is required and cannot be empty",
field_name
)));
}
Ok(())
}
}
/// Unit tests for the configuration defaults, builders and validation helpers.
#[cfg(test)]
mod tests {
    use super::*;

    /// All three default timeouts are 30 seconds.
    #[test]
    fn test_timeout_config_default() {
        let tc = TimeoutConfig::default();
        let expected = Duration::from_secs(30);
        assert_eq!(tc.connect_timeout, expected);
        assert_eq!(tc.read_timeout, expected);
        assert_eq!(tc.write_timeout, expected);
    }

    /// The default security settings are the strict ones.
    #[test]
    fn test_security_config_default() {
        let sc = SecurityConfig::default();
        assert!(sc.enable_tls);
        assert!(sc.verify_certificates);
        assert_eq!(sc.min_tls_version, "1.2");
        assert!(!sc.cipher_suites.is_empty());
        assert!(matches!(sc.cert_validation, CertificateValidation::Full));
    }

    /// The default retry policy is internally consistent.
    #[test]
    fn test_retry_config_default() {
        let rc = RetryConfig::default();
        assert!(rc.max_attempts > 0);
        assert!(rc.initial_delay <= rc.max_delay);
        assert!(rc.backoff_multiplier >= 1.0);
        assert!((0.0..=1.0).contains(&rc.jitter_factor));
    }

    /// Logging is quiet by default but keeps a non-zero size limit.
    #[test]
    fn test_logging_config_default() {
        let lc = LoggingConfig::default();
        assert!(!lc.debug);
        assert!(!lc.log_bodies);
        assert!(!lc.log_sensitive);
        assert!(lc.max_log_size > 0);
    }

    /// A fresh endpoint stores the URL and starts without version or headers.
    #[test]
    fn test_endpoint_config_new() {
        let ec = EndpointConfig::new("https://api.example.com");
        assert_eq!(ec.base_url, "https://api.example.com");
        assert!(ec.api_version.is_none());
        assert!(ec.headers.is_empty());
    }

    /// `with_header` stores the header and replaces an existing value.
    #[test]
    fn test_endpoint_config_with_header() {
        let ec = EndpointConfig::new("https://api.example.com")
            .with_header("Authorization", "Bearer xxx");
        assert_eq!(
            ec.headers.get("Authorization").map(String::as_str),
            Some("Bearer xxx")
        );

        let ec = ec.with_header("Authorization", "Bearer yyy");
        assert_eq!(ec.headers.len(), 1);
        assert_eq!(
            ec.headers.get("Authorization").map(String::as_str),
            Some("Bearer yyy")
        );
    }

    /// `with_api_version` stores the version.
    #[test]
    fn test_endpoint_config_with_api_version() {
        let ec = EndpointConfig::new("https://api.example.com").with_api_version("2024-01-01");
        assert_eq!(ec.api_version.as_deref(), Some("2024-01-01"));
    }

    /// Both supported schemes are accepted.
    #[test]
    fn test_validate_url_valid() {
        assert!(validation::validate_url("https://example.com").is_ok());
        assert!(validation::validate_url("http://example.com").is_ok());
        assert!(validation::validate_url("https://example.com:8443/v1?x=1").is_ok());
    }

    /// An empty URL is rejected.
    #[test]
    fn test_validate_url_empty() {
        assert!(validation::validate_url("").is_err());
    }

    /// URLs without an `http://` or `https://` prefix are rejected.
    #[test]
    fn test_validate_url_no_scheme() {
        assert!(validation::validate_url("example.com").is_err());
        assert!(validation::validate_url("ftp://example.com").is_err());
    }

    /// Non-zero durations are accepted.
    #[test]
    fn test_validate_positive_duration() {
        assert!(validation::validate_positive_duration(&Duration::from_secs(1), "timeout").is_ok());
        assert!(validation::validate_positive_duration(&Duration::from_nanos(1), "timeout").is_ok());
    }

    /// A zero duration is rejected.
    #[test]
    fn test_validate_zero_duration() {
        assert!(validation::validate_positive_duration(&Duration::ZERO, "timeout").is_err());
    }

    /// Only port zero is rejected.
    #[test]
    fn test_validate_port() {
        assert!(validation::validate_port(8080).is_ok());
        assert!(validation::validate_port(1).is_ok());
        assert!(validation::validate_port(u16::MAX).is_ok());
        assert!(validation::validate_port(0).is_err());
    }

    /// Empty and whitespace-only values are rejected.
    #[test]
    fn test_validate_required_field() {
        assert!(validation::validate_required_field("value", "name").is_ok());
        assert!(validation::validate_required_field("", "name").is_err());
        assert!(validation::validate_required_field(" ", "name").is_err());
        assert!(validation::validate_required_field("\t\n", "name").is_err());
    }

    /// Invalid certificates are never accepted by default.
    #[test]
    fn test_security_config_accept_invalid_certs_defaults_false() {
        let sc = SecurityConfig::default();
        assert!(
            !sc.accept_invalid_certs,
            "accept_invalid_certs must default to false"
        );
    }

    /// Omitting `accept_invalid_certs` from the input deserializes it as `false`.
    #[test]
    fn test_security_config_accept_invalid_certs_deserialization_default() {
        let json = r#"{"enable_tls":true,"min_tls_version":"1.2","cipher_suites":[],"cert_validation":"Full","verify_certificates":true}"#;
        let sc: SecurityConfig = serde_json::from_str(json).unwrap();
        assert!(!sc.accept_invalid_certs);
    }
}