codex_memory/security/
tls.rs

1use crate::security::{Result, SecurityError, TlsConfig};
2use axum_server::tls_rustls::RustlsConfig;
3use tracing::{debug, info};
4
5/// TLS configuration and certificate management
6pub struct TlsManager {
7    config: TlsConfig,
8}
9
10impl TlsManager {
11    pub fn new(config: TlsConfig) -> Result<Self> {
12        let manager = Self { config };
13
14        if manager.config.enabled {
15            manager.validate_config()?;
16        }
17
18        Ok(manager)
19    }
20
21    fn validate_config(&self) -> Result<()> {
22        if !self.config.cert_path.exists() {
23            return Err(SecurityError::TlsError {
24                message: format!("TLS certificate not found: {:?}", self.config.cert_path),
25            });
26        }
27
28        if !self.config.key_path.exists() {
29            return Err(SecurityError::TlsError {
30                message: format!("TLS private key not found: {:?}", self.config.key_path),
31            });
32        }
33
34        if let Some(ca_path) = &self.config.client_ca_path {
35            if !ca_path.exists() {
36                return Err(SecurityError::TlsError {
37                    message: format!("Client CA certificate not found: {ca_path:?}"),
38                });
39            }
40        }
41
42        Ok(())
43    }
44
45    /// Create Rustls config for Axum server
46    pub async fn create_rustls_config(&self) -> Result<RustlsConfig> {
47        if !self.config.enabled {
48            return Err(SecurityError::TlsError {
49                message: "TLS is not enabled".to_string(),
50            });
51        }
52
53        info!(
54            "Creating TLS configuration from cert: {:?}, key: {:?}",
55            self.config.cert_path, self.config.key_path
56        );
57
58        RustlsConfig::from_pem_file(&self.config.cert_path, &self.config.key_path)
59            .await
60            .map_err(|e| SecurityError::TlsError {
61                message: format!("Failed to create TLS config: {e}"),
62            })
63    }
64
65    /// Check if TLS is enabled
66    pub fn is_enabled(&self) -> bool {
67        self.config.enabled
68    }
69
70    /// Check if mTLS (mutual TLS) is required
71    pub fn requires_client_cert(&self) -> bool {
72        self.config.enabled && self.config.require_client_cert
73    }
74
75    /// Get the TLS port
76    pub fn get_port(&self) -> u16 {
77        self.config.port
78    }
79
80    /// Check if this configuration supports mTLS
81    pub fn supports_mtls(&self) -> bool {
82        self.config.enabled && self.config.client_ca_path.is_some()
83    }
84
85    /// Get certificate information (for monitoring/diagnostics)
86    pub fn get_cert_info(&self) -> TlsCertInfo {
87        TlsCertInfo {
88            cert_path: self.config.cert_path.clone(),
89            key_path: self.config.key_path.clone(),
90            client_ca_path: self.config.client_ca_path.clone(),
91            enabled: self.config.enabled,
92            mtls_enabled: self.config.require_client_cert,
93            port: self.config.port,
94        }
95    }
96
97    /// Validate that certificates exist and are readable
98    pub fn validate_certificates(&self) -> Result<()> {
99        if !self.config.enabled {
100            debug!("TLS is disabled, skipping certificate validation");
101            return Ok(());
102        }
103
104        self.validate_config()?;
105
106        // Basic file readability check
107        std::fs::read(&self.config.cert_path).map_err(|e| SecurityError::TlsError {
108            message: format!("Cannot read certificate file: {e}"),
109        })?;
110
111        std::fs::read(&self.config.key_path).map_err(|e| SecurityError::TlsError {
112            message: format!("Cannot read private key file: {e}"),
113        })?;
114
115        if let Some(ca_path) = &self.config.client_ca_path {
116            std::fs::read(ca_path).map_err(|e| SecurityError::TlsError {
117                message: format!("Cannot read client CA file: {e}"),
118            })?;
119        }
120
121        info!("TLS certificates validated successfully");
122        Ok(())
123    }
124}
125
/// Snapshot of TLS settings for monitoring and diagnostics.
///
/// Holds file paths and flags only, never certificate or key contents, so it
/// is safe to log.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TlsCertInfo {
    /// Path to the server certificate file.
    pub cert_path: std::path::PathBuf,
    /// Path to the server private key file.
    pub key_path: std::path::PathBuf,
    /// Path to the CA certificate used to verify client certificates, if any.
    pub client_ca_path: Option<std::path::PathBuf>,
    /// Whether TLS is enabled.
    pub enabled: bool,
    /// Whether client certificate verification (mutual TLS) is reported as enabled.
    pub mtls_enabled: bool,
    /// The configured TLS port.
    pub port: u16,
}
136
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Disabled TLS skips validation, so nonexistent paths are accepted.
    #[test]
    fn test_tls_manager_disabled() {
        let config = TlsConfig {
            enabled: false,
            cert_path: PathBuf::from("/nonexistent"),
            key_path: PathBuf::from("/nonexistent"),
            port: 8443,
            require_client_cert: false,
            client_ca_path: None,
        };

        let manager = TlsManager::new(config).unwrap();
        assert!(!manager.is_enabled());
        assert!(!manager.requires_client_cert());
        assert_eq!(manager.get_port(), 8443);
    }

    /// With TLS disabled, mTLS flags must not be reported as active.
    #[test]
    fn test_tls_manager_disabled_ignores_mtls_flags() {
        let config = TlsConfig {
            enabled: false,
            cert_path: PathBuf::from("/nonexistent/cert.pem"),
            key_path: PathBuf::from("/nonexistent/key.pem"),
            port: 8443,
            require_client_cert: true,
            client_ca_path: Some(PathBuf::from("/nonexistent/ca.pem")),
        };

        let manager = TlsManager::new(config).unwrap();
        assert!(!manager.requires_client_cert());
        assert!(!manager.supports_mtls());
        assert!(manager.validate_certificates().is_ok());
    }

    /// Enabled TLS with a missing certificate fails with a TlsError naming it.
    #[test]
    fn test_tls_manager_invalid_cert_path() {
        let config = TlsConfig {
            enabled: true,
            cert_path: PathBuf::from("/nonexistent/cert.pem"),
            key_path: PathBuf::from("/nonexistent/key.pem"),
            port: 8443,
            require_client_cert: false,
            client_ca_path: None,
        };

        match TlsManager::new(config) {
            Err(SecurityError::TlsError { message }) => assert!(
                message.contains("certificate not found"),
                "unexpected message: {message}"
            ),
            _ => panic!("expected SecurityError::TlsError for a missing certificate"),
        }
    }

    #[test]
    fn test_tls_cert_info() {
        let config = TlsConfig {
            enabled: true,
            cert_path: PathBuf::from("/test/cert.pem"),
            key_path: PathBuf::from("/test/key.pem"),
            port: 8443,
            require_client_cert: true,
            client_ca_path: Some(PathBuf::from("/test/ca.pem")),
        };

        // Constructed directly: `TlsManager::new` would reject these
        // nonexistent paths, but cert info does not touch the filesystem.
        let manager = TlsManager {
            config: config.clone(),
        };
        let cert_info = manager.get_cert_info();

        assert_eq!(cert_info.cert_path, config.cert_path);
        assert_eq!(cert_info.key_path, config.key_path);
        assert_eq!(cert_info.client_ca_path, config.client_ca_path);
        assert_eq!(cert_info.enabled, config.enabled);
        assert_eq!(cert_info.mtls_enabled, config.require_client_cert);
        assert_eq!(cert_info.port, config.port);
    }

    /// Building a Rustls config is refused when TLS is disabled.
    #[tokio::test]
    async fn test_create_rustls_config_disabled() {
        let config = TlsConfig {
            enabled: false,
            cert_path: PathBuf::from("/nonexistent"),
            key_path: PathBuf::from("/nonexistent"),
            port: 8443,
            require_client_cert: false,
            client_ca_path: None,
        };

        let manager = TlsManager::new(config).unwrap();
        match manager.create_rustls_config().await {
            Err(SecurityError::TlsError { message }) => assert!(
                message.contains("TLS is not enabled"),
                "unexpected message: {message}"
            ),
            _ => panic!("expected SecurityError::TlsError when TLS is disabled"),
        }
    }
}