Skip to main content

amaters_server/
tls_config.rs

1//! TLS configuration builder for gRPC server
2//!
3//! This module provides functionality to build TLS/mTLS configurations
4//! for the AmateRS gRPC server from server configuration.
5
6use crate::config::{NetworkSettings, ServerConfig};
7use crate::server::{ServerError, ServerResult};
8use arc_swap::ArcSwap;
9use std::fs;
10use std::path::Path;
11use std::sync::Arc;
12use tonic::transport::{Certificate, Identity, ServerTlsConfig};
13use tracing::{debug, info};
14
15/// TLS server configuration builder
16pub struct TlsServerBuilder;
17
18impl TlsServerBuilder {
19    /// Build TLS configuration from server config
20    ///
21    /// Returns None if TLS is not enabled.
22    /// Returns ServerError if TLS is enabled but configuration is invalid.
23    pub fn build(config: &ServerConfig) -> ServerResult<Option<ServerTlsConfig>> {
24        let network = &config.network;
25
26        if !network.tls_enabled {
27            debug!("TLS is not enabled");
28            return Ok(None);
29        }
30
31        info!("Building TLS configuration");
32
33        // Load server certificate and key
34        let cert_path = network.tls_cert.as_ref().ok_or_else(|| {
35            ServerError::ConfigValidation(
36                "TLS cert path is required when TLS is enabled".to_string(),
37            )
38        })?;
39
40        let key_path = network.tls_key.as_ref().ok_or_else(|| {
41            ServerError::ConfigValidation(
42                "TLS key path is required when TLS is enabled".to_string(),
43            )
44        })?;
45
46        let cert_pem = Self::load_file(cert_path)
47            .map_err(|e| ServerError::TlsSetup(format!("Failed to load certificate: {}", e)))?;
48
49        let key_pem = Self::load_file(key_path)
50            .map_err(|e| ServerError::TlsSetup(format!("Failed to load private key: {}", e)))?;
51
52        let identity = Identity::from_pem(&cert_pem, &key_pem);
53
54        let mut tls_config = ServerTlsConfig::new().identity(identity);
55
56        // Configure mTLS if client certificates are required
57        if network.require_client_cert {
58            info!("mTLS enabled - requiring client certificates");
59
60            let ca_path = network.tls_ca.as_ref().ok_or_else(|| {
61                ServerError::ConfigValidation(
62                    "TLS CA path is required when client certificates are required".to_string(),
63                )
64            })?;
65
66            let ca_pem = Self::load_file(ca_path).map_err(|e| {
67                ServerError::TlsSetup(format!("Failed to load CA certificate: {}", e))
68            })?;
69
70            let ca_cert = Certificate::from_pem(&ca_pem);
71
72            tls_config = tls_config
73                .client_ca_root(ca_cert)
74                .client_auth_optional(false); // Require client certificates
75        }
76
77        info!("TLS configuration built successfully");
78        Ok(Some(tls_config))
79    }
80
81    /// Load a file as a byte vector
82    fn load_file(path: &Path) -> ServerResult<Vec<u8>> {
83        fs::read(path)
84            .map_err(|e| ServerError::TlsSetup(format!("Failed to read file {:?}: {}", path, e)))
85    }
86}
87
88/// Builder for client TLS configuration (for connecting to other nodes)
89pub struct TlsClientBuilder;
90
91impl TlsClientBuilder {
92    /// Build client TLS configuration from network settings
93    ///
94    /// This is used when the server needs to connect to other nodes in a cluster.
95    pub fn build(
96        network: &NetworkSettings,
97    ) -> ServerResult<Option<tonic::transport::ClientTlsConfig>> {
98        if !network.tls_enabled {
99            debug!("TLS is not enabled for client connections");
100            return Ok(None);
101        }
102
103        info!("Building client TLS configuration");
104
105        let mut tls_config = tonic::transport::ClientTlsConfig::new();
106
107        // If mTLS is enabled, load client certificate and key
108        if network.require_client_cert {
109            let cert_path = network.tls_cert.as_ref().ok_or_else(|| {
110                ServerError::ConfigValidation(
111                    "TLS cert path is required for client mTLS".to_string(),
112                )
113            })?;
114
115            let key_path = network.tls_key.as_ref().ok_or_else(|| {
116                ServerError::ConfigValidation(
117                    "TLS key path is required for client mTLS".to_string(),
118                )
119            })?;
120
121            let cert_pem = TlsServerBuilder::load_file(cert_path).map_err(|e| {
122                ServerError::TlsSetup(format!("Failed to load client certificate: {}", e))
123            })?;
124
125            let key_pem = TlsServerBuilder::load_file(key_path).map_err(|e| {
126                ServerError::TlsSetup(format!("Failed to load client private key: {}", e))
127            })?;
128
129            let identity = Identity::from_pem(&cert_pem, &key_pem);
130            tls_config = tls_config.identity(identity);
131        }
132
133        // Load CA certificate if provided
134        if let Some(ca_path) = &network.tls_ca {
135            let ca_pem = TlsServerBuilder::load_file(ca_path).map_err(|e| {
136                ServerError::TlsSetup(format!("Failed to load CA certificate for client: {}", e))
137            })?;
138
139            let ca_cert = Certificate::from_pem(&ca_pem);
140            tls_config = tls_config.ca_certificate(ca_cert);
141        }
142
143        info!("Client TLS configuration built successfully");
144        Ok(Some(tls_config))
145    }
146}
147
148// ─── LiveTlsAcceptor ─────────────────────────────────────────────────────────
149
150/// A TLS acceptor that can swap its certificate store without restarting.
151///
152/// Shares the same `ArcSwap<rustls::ServerConfig>` used by the cert-rotation
153/// path in `main.rs`, ensuring a single source of truth for the TLS
154/// configuration.  This is the *server-side config shim* — for the TCP-level
155/// acceptor (which wraps a `TcpListener`) see
156/// `amaters_net::tls_acceptor::LiveTlsAcceptor`.
157pub struct LiveTlsAcceptor {
158    config_store: Arc<ArcSwap<rustls::ServerConfig>>,
159}
160
161impl LiveTlsAcceptor {
162    /// Create a new acceptor backed by a shared config store.
163    pub fn new(config_store: Arc<ArcSwap<rustls::ServerConfig>>) -> Self {
164        Self { config_store }
165    }
166
167    /// Load the current TLS configuration.
168    pub fn current_config(&self) -> Arc<rustls::ServerConfig> {
169        self.config_store.load_full()
170    }
171
172    /// Update the TLS configuration (e.g., after certificate rotation).
173    ///
174    /// All *new* connections that call [`Self::make_acceptor`] after this
175    /// point will use the new configuration.  In-flight connections are
176    /// unaffected.
177    pub fn rotate(&self, new_config: Arc<rustls::ServerConfig>) {
178        self.config_store.store(new_config);
179    }
180
181    /// Build a `tokio_rustls::TlsAcceptor` from the current config.
182    ///
183    /// The returned acceptor captures a snapshot of the current config.
184    /// Call `make_acceptor()` again after [`Self::rotate`] to pick up the
185    /// new certificate.
186    pub fn make_acceptor(&self) -> tokio_rustls::TlsAcceptor {
187        tokio_rustls::TlsAcceptor::from(self.current_config())
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{
        AuthSettings, AuthorizationSettings, LoggingSettings, MetricsSettings, ServerSettings,
        StorageSettings,
    };
    use std::fs;
    use std::path::PathBuf;

    /// PEM-shaped placeholder certificate. It is NOT a valid certificate.
    /// The builders under test only read files and never parse them, so
    /// this is enough.
    const DUMMY_CERT_PEM: &[u8] =
        b"-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----\n";

    /// PEM-shaped placeholder private key (NOT a valid key).
    const DUMMY_KEY_PEM: &[u8] =
        b"-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE KEY-----\n";

    /// Build a `ServerConfig` rooted in a fresh temporary directory.
    ///
    /// When `tls_enabled` is true, dummy `server.crt`, `server.key` and
    /// `ca.crt` files are written, and all three TLS paths point at them.
    /// Keep the returned `TempDir` alive while the config is in use.
    fn create_test_config(
        tls_enabled: bool,
        require_client_cert: bool,
    ) -> (ServerConfig, tempfile::TempDir) {
        let temp_dir = tempfile::tempdir().expect("Failed to create temp dir");
        let data_dir = temp_dir.path().join("data");
        fs::create_dir(&data_dir).expect("Failed to create data dir");

        let (tls_cert, tls_key, tls_ca) = if tls_enabled {
            let cert_path = temp_dir.path().join("server.crt");
            let key_path = temp_dir.path().join("server.key");
            let ca_path = temp_dir.path().join("ca.crt");

            fs::write(&cert_path, DUMMY_CERT_PEM).expect("Failed to write cert");
            fs::write(&key_path, DUMMY_KEY_PEM).expect("Failed to write key");
            fs::write(&ca_path, DUMMY_CERT_PEM).expect("Failed to write CA");

            (Some(cert_path), Some(key_path), Some(ca_path))
        } else {
            (None, None, None)
        };

        let config = ServerConfig {
            server: ServerSettings {
                bind_address: "127.0.0.1:50051".to_string(),
                data_dir,
                pid_file: temp_dir.path().join("server.pid"),
                max_connections: 1000,
                shutdown_timeout_secs: 30,
            },
            storage: StorageSettings {
                engine: "memory".to_string(),
                wal: crate::config::WalSettings {
                    enabled: true,
                    dir: PathBuf::from("wal"),
                    segment_size_mb: 64,
                    sync_mode: "interval".to_string(),
                },
                memtable_size_mb: 64,
                block_cache_size_mb: 256,
                compaction: crate::config::CompactionSettings {
                    strategy: "leveled".to_string(),
                    num_levels: 7,
                    level_multiplier: 10,
                    max_concurrent: 2,
                },
            },
            network: NetworkSettings {
                tls_enabled,
                tls_cert,
                tls_key,
                tls_ca,
                require_client_cert,
                connection_timeout_secs: 30,
                keepalive_interval_secs: 60,
            },
            cluster: None,
            logging: LoggingSettings {
                level: "info".to_string(),
                format: "json".to_string(),
                file_enabled: false,
                file_path: None,
                rotation: crate::config::LogRotationSettings::default(),
            },
            metrics: MetricsSettings {
                enabled: true,
                bind_address: "127.0.0.1:9090".to_string(),
                export_interval_secs: 60,
            },
            auth: AuthSettings::default(),
            authz: AuthorizationSettings {
                enabled: false,
                default_role: "user".to_string(),
                roles_file: None,
                policies_file: None,
                collection_permissions: true,
                default_mode: "deny-by-default".to_string(),
                audit_enabled: false,
                audit_log_path: None,
            },
            resource_limits: Default::default(),
            circuit_cache: Default::default(),
            timeouts: Default::default(),
        };

        (config, temp_dir)
    }

    // ── Server builder ───────────────────────────────────────────────────────

    #[test]
    fn test_tls_disabled() {
        let (config, _temp_dir) = create_test_config(false, false);

        let result = TlsServerBuilder::build(&config);
        assert!(result.is_ok());
        assert!(result.ok().and_then(|x| x).is_none());
    }

    #[test]
    fn test_tls_enabled_basic() {
        let (config, _temp_dir) = create_test_config(true, false);

        // `build` reads the PEM files but does not parse them, so the dummy
        // data must still yield a configuration.
        let result = TlsServerBuilder::build(&config);
        assert!(matches!(result, Ok(Some(_))));
    }

    #[test]
    fn test_tls_missing_cert_path() {
        let (mut config, _temp_dir) = create_test_config(true, false);
        config.network.tls_cert = None;

        let result = TlsServerBuilder::build(&config);
        assert!(matches!(result, Err(ServerError::ConfigValidation(_))));
    }

    #[test]
    fn test_tls_missing_key_path() {
        let (mut config, _temp_dir) = create_test_config(true, false);
        config.network.tls_key = None;

        let result = TlsServerBuilder::build(&config);
        assert!(matches!(result, Err(ServerError::ConfigValidation(_))));
    }

    #[test]
    fn test_tls_unreadable_cert_file() {
        let (mut config, temp_dir) = create_test_config(true, false);
        config.network.tls_cert = Some(temp_dir.path().join("missing.crt"));

        let result = TlsServerBuilder::build(&config);
        assert!(matches!(result, Err(ServerError::TlsSetup(_))));
    }

    #[test]
    fn test_mtls_enabled() {
        let (config, _temp_dir) = create_test_config(true, true);

        let result = TlsServerBuilder::build(&config);
        assert!(matches!(result, Ok(Some(_))));
    }

    #[test]
    fn test_mtls_missing_ca_path() {
        let (mut config, _temp_dir) = create_test_config(true, true);
        config.network.tls_ca = None;

        let result = TlsServerBuilder::build(&config);
        assert!(matches!(result, Err(ServerError::ConfigValidation(_))));
    }

    #[test]
    fn test_load_file() {
        let temp_dir = tempfile::tempdir().expect("Failed to create temp dir");
        let file_path = temp_dir.path().join("test.txt");
        fs::write(&file_path, b"test content").expect("Failed to write file");

        let content = TlsServerBuilder::load_file(&file_path);
        assert!(content.is_ok());
        assert_eq!(content.ok(), Some(b"test content".to_vec()));
    }

    #[test]
    fn test_load_file_not_found() {
        let result = TlsServerBuilder::load_file(Path::new("/nonexistent/file.txt"));
        assert!(matches!(result, Err(ServerError::TlsSetup(_))));
    }

    // ── Client builder ───────────────────────────────────────────────────────

    #[test]
    fn test_client_tls_disabled() {
        let (config, _temp_dir) = create_test_config(false, false);

        let result = TlsClientBuilder::build(&config.network);
        assert!(result.is_ok());
        assert!(result.ok().and_then(|x| x).is_none());
    }

    #[test]
    fn test_client_tls_enabled() {
        let (config, _temp_dir) = create_test_config(true, false);

        let result = TlsClientBuilder::build(&config.network);
        assert!(matches!(result, Ok(Some(_))));
    }

    #[test]
    fn test_client_mtls_enabled() {
        let (config, _temp_dir) = create_test_config(true, true);

        let result = TlsClientBuilder::build(&config.network);
        assert!(matches!(result, Ok(Some(_))));
    }

    #[test]
    fn test_client_mtls_missing_key_path() {
        let (mut config, _temp_dir) = create_test_config(true, true);
        config.network.tls_key = None;

        let result = TlsClientBuilder::build(&config.network);
        assert!(matches!(result, Err(ServerError::ConfigValidation(_))));
    }

    #[test]
    fn test_client_tls_unreadable_ca_file() {
        let (mut config, temp_dir) = create_test_config(true, false);
        config.network.tls_ca = Some(temp_dir.path().join("missing-ca.crt"));

        let result = TlsClientBuilder::build(&config.network);
        assert!(matches!(result, Err(ServerError::TlsSetup(_))));
    }
}