1use crate::config::{NetworkSettings, ServerConfig};
7use crate::server::{ServerError, ServerResult};
8use arc_swap::ArcSwap;
9use std::fs;
10use std::path::Path;
11use std::sync::Arc;
12use tonic::transport::{Certificate, Identity, ServerTlsConfig};
13use tracing::{debug, info};
14
15pub struct TlsServerBuilder;
17
18impl TlsServerBuilder {
19 pub fn build(config: &ServerConfig) -> ServerResult<Option<ServerTlsConfig>> {
24 let network = &config.network;
25
26 if !network.tls_enabled {
27 debug!("TLS is not enabled");
28 return Ok(None);
29 }
30
31 info!("Building TLS configuration");
32
33 let cert_path = network.tls_cert.as_ref().ok_or_else(|| {
35 ServerError::ConfigValidation(
36 "TLS cert path is required when TLS is enabled".to_string(),
37 )
38 })?;
39
40 let key_path = network.tls_key.as_ref().ok_or_else(|| {
41 ServerError::ConfigValidation(
42 "TLS key path is required when TLS is enabled".to_string(),
43 )
44 })?;
45
46 let cert_pem = Self::load_file(cert_path)
47 .map_err(|e| ServerError::TlsSetup(format!("Failed to load certificate: {}", e)))?;
48
49 let key_pem = Self::load_file(key_path)
50 .map_err(|e| ServerError::TlsSetup(format!("Failed to load private key: {}", e)))?;
51
52 let identity = Identity::from_pem(&cert_pem, &key_pem);
53
54 let mut tls_config = ServerTlsConfig::new().identity(identity);
55
56 if network.require_client_cert {
58 info!("mTLS enabled - requiring client certificates");
59
60 let ca_path = network.tls_ca.as_ref().ok_or_else(|| {
61 ServerError::ConfigValidation(
62 "TLS CA path is required when client certificates are required".to_string(),
63 )
64 })?;
65
66 let ca_pem = Self::load_file(ca_path).map_err(|e| {
67 ServerError::TlsSetup(format!("Failed to load CA certificate: {}", e))
68 })?;
69
70 let ca_cert = Certificate::from_pem(&ca_pem);
71
72 tls_config = tls_config
73 .client_ca_root(ca_cert)
74 .client_auth_optional(false); }
76
77 info!("TLS configuration built successfully");
78 Ok(Some(tls_config))
79 }
80
81 fn load_file(path: &Path) -> ServerResult<Vec<u8>> {
83 fs::read(path)
84 .map_err(|e| ServerError::TlsSetup(format!("Failed to read file {:?}: {}", path, e)))
85 }
86}
87
88pub struct TlsClientBuilder;
90
91impl TlsClientBuilder {
92 pub fn build(
96 network: &NetworkSettings,
97 ) -> ServerResult<Option<tonic::transport::ClientTlsConfig>> {
98 if !network.tls_enabled {
99 debug!("TLS is not enabled for client connections");
100 return Ok(None);
101 }
102
103 info!("Building client TLS configuration");
104
105 let mut tls_config = tonic::transport::ClientTlsConfig::new();
106
107 if network.require_client_cert {
109 let cert_path = network.tls_cert.as_ref().ok_or_else(|| {
110 ServerError::ConfigValidation(
111 "TLS cert path is required for client mTLS".to_string(),
112 )
113 })?;
114
115 let key_path = network.tls_key.as_ref().ok_or_else(|| {
116 ServerError::ConfigValidation(
117 "TLS key path is required for client mTLS".to_string(),
118 )
119 })?;
120
121 let cert_pem = TlsServerBuilder::load_file(cert_path).map_err(|e| {
122 ServerError::TlsSetup(format!("Failed to load client certificate: {}", e))
123 })?;
124
125 let key_pem = TlsServerBuilder::load_file(key_path).map_err(|e| {
126 ServerError::TlsSetup(format!("Failed to load client private key: {}", e))
127 })?;
128
129 let identity = Identity::from_pem(&cert_pem, &key_pem);
130 tls_config = tls_config.identity(identity);
131 }
132
133 if let Some(ca_path) = &network.tls_ca {
135 let ca_pem = TlsServerBuilder::load_file(ca_path).map_err(|e| {
136 ServerError::TlsSetup(format!("Failed to load CA certificate for client: {}", e))
137 })?;
138
139 let ca_cert = Certificate::from_pem(&ca_pem);
140 tls_config = tls_config.ca_certificate(ca_cert);
141 }
142
143 info!("Client TLS configuration built successfully");
144 Ok(Some(tls_config))
145 }
146}
147
148pub struct LiveTlsAcceptor {
158 config_store: Arc<ArcSwap<rustls::ServerConfig>>,
159}
160
161impl LiveTlsAcceptor {
162 pub fn new(config_store: Arc<ArcSwap<rustls::ServerConfig>>) -> Self {
164 Self { config_store }
165 }
166
167 pub fn current_config(&self) -> Arc<rustls::ServerConfig> {
169 self.config_store.load_full()
170 }
171
172 pub fn rotate(&self, new_config: Arc<rustls::ServerConfig>) {
178 self.config_store.store(new_config);
179 }
180
181 pub fn make_acceptor(&self) -> tokio_rustls::TlsAcceptor {
187 tokio_rustls::TlsAcceptor::from(self.current_config())
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{
        AuthSettings, AuthorizationSettings, LoggingSettings, MetricsSettings, ServerSettings,
        StorageSettings,
    };
    use std::fs;
    use std::path::PathBuf;

    /// Builds a `ServerConfig` rooted in a fresh temp dir.
    ///
    /// When `tls_enabled` is true, placeholder PEM files (`server.crt`, `server.key`,
    /// `ca.crt`) are written into the temp dir and referenced from `network`. The
    /// returned `TempDir` must be kept alive while the config's paths are used,
    /// because dropping it removes the directory.
    fn create_test_config(
        tls_enabled: bool,
        require_client_cert: bool,
    ) -> (ServerConfig, tempfile::TempDir) {
        let temp_dir = tempfile::tempdir().expect("Failed to create temp dir");
        let data_dir = temp_dir.path().join("data");
        fs::create_dir(&data_dir).expect("Failed to create data dir");

        let (tls_cert, tls_key, tls_ca) = if tls_enabled {
            let cert_path = temp_dir.path().join("server.crt");
            let key_path = temp_dir.path().join("server.key");
            let ca_path = temp_dir.path().join("ca.crt");

            fs::write(
                &cert_path,
                b"-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----\n",
            )
            .expect("Failed to write cert");
            fs::write(
                &key_path,
                b"-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE KEY-----\n",
            )
            .expect("Failed to write key");
            fs::write(
                &ca_path,
                b"-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----\n",
            )
            .expect("Failed to write CA");

            (Some(cert_path), Some(key_path), Some(ca_path))
        } else {
            (None, None, None)
        };

        let config = ServerConfig {
            server: ServerSettings {
                bind_address: "127.0.0.1:50051".to_string(),
                data_dir,
                pid_file: temp_dir.path().join("server.pid"),
                max_connections: 1000,
                shutdown_timeout_secs: 30,
            },
            storage: StorageSettings {
                engine: "memory".to_string(),
                wal: crate::config::WalSettings {
                    enabled: true,
                    dir: PathBuf::from("wal"),
                    segment_size_mb: 64,
                    sync_mode: "interval".to_string(),
                },
                memtable_size_mb: 64,
                block_cache_size_mb: 256,
                compaction: crate::config::CompactionSettings {
                    strategy: "leveled".to_string(),
                    num_levels: 7,
                    level_multiplier: 10,
                    max_concurrent: 2,
                },
            },
            network: NetworkSettings {
                tls_enabled,
                tls_cert,
                tls_key,
                tls_ca,
                require_client_cert,
                connection_timeout_secs: 30,
                keepalive_interval_secs: 60,
            },
            cluster: None,
            logging: LoggingSettings {
                level: "info".to_string(),
                format: "json".to_string(),
                file_enabled: false,
                file_path: None,
                rotation: crate::config::LogRotationSettings::default(),
            },
            metrics: MetricsSettings {
                enabled: true,
                bind_address: "127.0.0.1:9090".to_string(),
                export_interval_secs: 60,
            },
            auth: AuthSettings::default(),
            authz: AuthorizationSettings {
                enabled: false,
                default_role: "user".to_string(),
                roles_file: None,
                policies_file: None,
                collection_permissions: true,
                default_mode: "deny-by-default".to_string(),
                audit_enabled: false,
                audit_log_path: None,
            },
            resource_limits: Default::default(),
            circuit_cache: Default::default(),
            timeouts: Default::default(),
        };

        (config, temp_dir)
    }

    #[test]
    fn test_tls_disabled() {
        let (config, _temp_dir) = create_test_config(false, false);

        let result = TlsServerBuilder::build(&config);
        assert!(matches!(result, Ok(None)));
    }

    #[test]
    fn test_tls_enabled_basic() {
        let (config, _temp_dir) = create_test_config(true, false);

        // The builder only reads the files; the placeholder PEM contents are not
        // validated at this stage, so a config is produced.
        let result = TlsServerBuilder::build(&config);
        assert!(matches!(result, Ok(Some(_))));
    }

    #[test]
    fn test_tls_enabled_missing_cert_path() {
        let (mut config, _temp_dir) = create_test_config(true, false);
        config.network.tls_cert = None;

        let result = TlsServerBuilder::build(&config);
        assert!(matches!(result, Err(ServerError::ConfigValidation(_))));
    }

    #[test]
    fn test_tls_enabled_missing_key_path() {
        let (mut config, _temp_dir) = create_test_config(true, false);
        config.network.tls_key = None;

        let result = TlsServerBuilder::build(&config);
        assert!(matches!(result, Err(ServerError::ConfigValidation(_))));
    }

    #[test]
    fn test_tls_enabled_unreadable_cert_file() {
        let (mut config, temp_dir) = create_test_config(true, false);
        config.network.tls_cert = Some(temp_dir.path().join("missing.crt"));

        let result = TlsServerBuilder::build(&config);
        assert!(matches!(result, Err(ServerError::TlsSetup(_))));
    }

    #[test]
    fn test_mtls_enabled() {
        let (config, _temp_dir) = create_test_config(true, true);

        let result = TlsServerBuilder::build(&config);
        assert!(matches!(result, Ok(Some(_))));
    }

    #[test]
    fn test_mtls_missing_ca_path() {
        let (mut config, _temp_dir) = create_test_config(true, true);
        config.network.tls_ca = None;

        let result = TlsServerBuilder::build(&config);
        assert!(matches!(result, Err(ServerError::ConfigValidation(_))));
    }

    #[test]
    fn test_load_file() {
        let temp_dir = tempfile::tempdir().expect("Failed to create temp dir");
        let file_path = temp_dir.path().join("test.txt");
        fs::write(&file_path, b"test content").expect("Failed to write file");

        let content = TlsServerBuilder::load_file(&file_path);
        assert!(content.is_ok());
        assert_eq!(content.ok(), Some(b"test content".to_vec()));
    }

    #[test]
    fn test_load_file_not_found() {
        let result = TlsServerBuilder::load_file(Path::new("/nonexistent/file.txt"));
        assert!(matches!(result, Err(ServerError::TlsSetup(_))));
    }

    #[test]
    fn test_client_tls_disabled() {
        let (config, _temp_dir) = create_test_config(false, false);

        let result = TlsClientBuilder::build(&config.network);
        assert!(matches!(result, Ok(None)));
    }

    #[test]
    fn test_client_tls_enabled() {
        let (config, _temp_dir) = create_test_config(true, false);

        let result = TlsClientBuilder::build(&config.network);
        assert!(matches!(result, Ok(Some(_))));
    }

    #[test]
    fn test_client_mtls_enabled() {
        let (config, _temp_dir) = create_test_config(true, true);

        let result = TlsClientBuilder::build(&config.network);
        assert!(matches!(result, Ok(Some(_))));
    }

    #[test]
    fn test_client_mtls_missing_key_path() {
        let (mut config, _temp_dir) = create_test_config(true, true);
        config.network.tls_key = None;

        let result = TlsClientBuilder::build(&config.network);
        assert!(matches!(result, Err(ServerError::ConfigValidation(_))));
    }

    #[test]
    fn test_client_unreadable_ca_file() {
        let (mut config, temp_dir) = create_test_config(true, false);
        config.network.tls_ca = Some(temp_dir.path().join("missing-ca.crt"));

        let result = TlsClientBuilder::build(&config.network);
        assert!(matches!(result, Err(ServerError::TlsSetup(_))));
    }
}