// nklave_core/replication/tls.rs

//! TLS configuration for the replication protocol.
//!
//! Provides mutual TLS (mTLS) support for communication between the primary
//! and passive nodes.

5use std::fs::File;
6use std::io::BufReader;
7use std::path::PathBuf;
8use std::sync::Arc;
9use tokio::net::TcpStream;
10use tokio_rustls::rustls::{
11    pki_types::{CertificateDer, PrivateKeyDer, ServerName},
12    ClientConfig, RootCertStore, ServerConfig,
13};
14use tokio_rustls::{TlsAcceptor, TlsConnector, TlsStream};
15use tracing::{debug, info};

/// Settings that secure the replication channel between primary and passive nodes.
///
/// Every path names a PEM-encoded file. Nothing is read from disk when this
/// value is constructed; the files are opened by the acceptor/connector builders.
#[derive(Debug, Clone)]
pub struct ReplicationTlsConfig {
    /// PEM file containing this node's certificate chain.
    pub cert_path: PathBuf,

    /// PEM file containing this node's private key.
    pub key_path: PathBuf,

    /// PEM file containing the CA certificate(s) trusted to sign the peer's certificate.
    pub ca_cert_path: PathBuf,

    /// When `true`, the primary (acceptor side) insists on and verifies a client
    /// certificate, i.e. mutual TLS. The connector side does not consult this flag.
    pub require_client_cert: bool,
}

33/// Errors from TLS operations
34#[derive(Debug, thiserror::Error)]
35pub enum TlsError {
36    #[error("Failed to read certificate file: {0}")]
37    CertificateRead(String),
38
39    #[error("Failed to read private key file: {0}")]
40    PrivateKeyRead(String),
41
42    #[error("Failed to read CA certificate file: {0}")]
43    CaCertificateRead(String),
44
45    #[error("No certificates found in file")]
46    NoCertificates,
47
48    #[error("No private key found in file")]
49    NoPrivateKey,
50
51    #[error("TLS configuration error: {0}")]
52    Configuration(String),
53
54    #[error("TLS handshake failed: {0}")]
55    Handshake(String),
56
57    #[error("IO error: {0}")]
58    Io(#[from] std::io::Error),
59}
60
61impl ReplicationTlsConfig {
62    /// Create a new TLS configuration
63    pub fn new(cert_path: PathBuf, key_path: PathBuf, ca_cert_path: PathBuf) -> Self {
64        Self {
65            cert_path,
66            key_path,
67            ca_cert_path,
68            require_client_cert: true, // Default to mTLS
69        }
70    }
71
72    /// Load certificates from the certificate file
73    fn load_certs(&self) -> Result<Vec<CertificateDer<'static>>, TlsError> {
74        let file = File::open(&self.cert_path).map_err(|e| {
75            TlsError::CertificateRead(format!("{}: {}", self.cert_path.display(), e))
76        })?;
77        let mut reader = BufReader::new(file);
78
79        let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
80            .filter_map(|r| r.ok())
81            .collect();
82
83        if certs.is_empty() {
84            return Err(TlsError::NoCertificates);
85        }
86
87        debug!(count = certs.len(), path = %self.cert_path.display(), "Loaded certificates");
88        Ok(certs)
89    }
90
91    /// Load the private key from the key file
92    fn load_private_key(&self) -> Result<PrivateKeyDer<'static>, TlsError> {
93        let file = File::open(&self.key_path).map_err(|e| {
94            TlsError::PrivateKeyRead(format!("{}: {}", self.key_path.display(), e))
95        })?;
96        let mut reader = BufReader::new(file);
97
98        let key = rustls_pemfile::private_key(&mut reader)
99            .map_err(|e| TlsError::PrivateKeyRead(e.to_string()))?
100            .ok_or(TlsError::NoPrivateKey)?;
101
102        debug!(path = %self.key_path.display(), "Loaded private key");
103        Ok(key)
104    }
105
106    /// Load CA certificates for peer verification
107    fn load_ca_certs(&self) -> Result<RootCertStore, TlsError> {
108        let file = File::open(&self.ca_cert_path).map_err(|e| {
109            TlsError::CaCertificateRead(format!("{}: {}", self.ca_cert_path.display(), e))
110        })?;
111        let mut reader = BufReader::new(file);
112
113        let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
114            .filter_map(|r| r.ok())
115            .collect();
116
117        if certs.is_empty() {
118            return Err(TlsError::NoCertificates);
119        }
120
121        let mut root_store = RootCertStore::empty();
122        for cert in certs {
123            root_store.add(cert).map_err(|e| {
124                TlsError::Configuration(format!("Failed to add CA certificate: {}", e))
125            })?;
126        }
127
128        debug!(path = %self.ca_cert_path.display(), "Loaded CA certificates");
129        Ok(root_store)
130    }
131
132    /// Build a TLS acceptor for the primary node (server)
133    pub fn build_acceptor(&self) -> Result<TlsAcceptor, TlsError> {
134        let certs = self.load_certs()?;
135        let key = self.load_private_key()?;
136        let ca_certs = self.load_ca_certs()?;
137
138        let mut config = if self.require_client_cert {
139            // mTLS: require and verify client certificates
140            let client_cert_verifier = tokio_rustls::rustls::server::WebPkiClientVerifier::builder(
141                Arc::new(ca_certs),
142            )
143            .build()
144            .map_err(|e| TlsError::Configuration(format!("Failed to build client verifier: {}", e)))?;
145
146            ServerConfig::builder()
147                .with_client_cert_verifier(client_cert_verifier)
148                .with_single_cert(certs, key)
149                .map_err(|e| TlsError::Configuration(e.to_string()))?
150        } else {
151            // Server-only TLS, no client cert required
152            ServerConfig::builder()
153                .with_no_client_auth()
154                .with_single_cert(certs, key)
155                .map_err(|e| TlsError::Configuration(e.to_string()))?
156        };
157
158        // Disable TLS 1.2 to only allow TLS 1.3
159        config.alpn_protocols = vec![b"nklave-repl".to_vec()];
160
161        info!(
162            require_client_cert = self.require_client_cert,
163            "TLS acceptor configured for primary"
164        );
165
166        Ok(TlsAcceptor::from(Arc::new(config)))
167    }
168
169    /// Build a TLS connector for the passive node (client)
170    pub fn build_connector(&self) -> Result<TlsConnector, TlsError> {
171        let certs = self.load_certs()?;
172        let key = self.load_private_key()?;
173        let ca_certs = self.load_ca_certs()?;
174
175        let mut config = ClientConfig::builder()
176            .with_root_certificates(ca_certs)
177            .with_client_auth_cert(certs, key)
178            .map_err(|e| TlsError::Configuration(e.to_string()))?;
179
180        config.alpn_protocols = vec![b"nklave-repl".to_vec()];
181
182        info!("TLS connector configured for passive");
183
184        Ok(TlsConnector::from(Arc::new(config)))
185    }
186}
187
188/// Wrapper for TLS stream that works with both client and server connections
189pub enum ReplicationTlsStream {
190    /// Server-side TLS stream (primary accepting passive connections)
191    Server(TlsStream<TcpStream>),
192    /// Client-side TLS stream (passive connecting to primary)
193    Client(TlsStream<TcpStream>),
194    /// Plain TCP stream (when TLS is disabled)
195    Plain(TcpStream),
196}
197
198impl ReplicationTlsStream {
199    /// Accept a TLS connection on the server side
200    pub async fn accept(
201        acceptor: &TlsAcceptor,
202        stream: TcpStream,
203    ) -> Result<Self, TlsError> {
204        let tls_stream = acceptor.accept(stream).await.map_err(|e| {
205            TlsError::Handshake(format!("Server TLS handshake failed: {}", e))
206        })?;
207        Ok(ReplicationTlsStream::Server(TlsStream::Server(tls_stream)))
208    }
209
210    /// Connect with TLS on the client side
211    pub async fn connect(
212        connector: &TlsConnector,
213        server_name: &str,
214        stream: TcpStream,
215    ) -> Result<Self, TlsError> {
216        let server_name = ServerName::try_from(server_name.to_string())
217            .map_err(|e| TlsError::Configuration(format!("Invalid server name: {}", e)))?;
218
219        let tls_stream = connector.connect(server_name, stream).await.map_err(|e| {
220            TlsError::Handshake(format!("Client TLS handshake failed: {}", e))
221        })?;
222        Ok(ReplicationTlsStream::Client(TlsStream::Client(tls_stream)))
223    }
224
225    /// Wrap a plain TCP stream (when TLS is disabled)
226    pub fn plain(stream: TcpStream) -> Self {
227        ReplicationTlsStream::Plain(stream)
228    }
229}
230
231// Implement AsyncRead and AsyncWrite for ReplicationTlsStream
232impl tokio::io::AsyncRead for ReplicationTlsStream {
233    fn poll_read(
234        self: std::pin::Pin<&mut Self>,
235        cx: &mut std::task::Context<'_>,
236        buf: &mut tokio::io::ReadBuf<'_>,
237    ) -> std::task::Poll<std::io::Result<()>> {
238        match self.get_mut() {
239            ReplicationTlsStream::Server(s) => std::pin::Pin::new(s).poll_read(cx, buf),
240            ReplicationTlsStream::Client(s) => std::pin::Pin::new(s).poll_read(cx, buf),
241            ReplicationTlsStream::Plain(s) => std::pin::Pin::new(s).poll_read(cx, buf),
242        }
243    }
244}
245
246impl tokio::io::AsyncWrite for ReplicationTlsStream {
247    fn poll_write(
248        self: std::pin::Pin<&mut Self>,
249        cx: &mut std::task::Context<'_>,
250        buf: &[u8],
251    ) -> std::task::Poll<Result<usize, std::io::Error>> {
252        match self.get_mut() {
253            ReplicationTlsStream::Server(s) => std::pin::Pin::new(s).poll_write(cx, buf),
254            ReplicationTlsStream::Client(s) => std::pin::Pin::new(s).poll_write(cx, buf),
255            ReplicationTlsStream::Plain(s) => std::pin::Pin::new(s).poll_write(cx, buf),
256        }
257    }
258
259    fn poll_flush(
260        self: std::pin::Pin<&mut Self>,
261        cx: &mut std::task::Context<'_>,
262    ) -> std::task::Poll<Result<(), std::io::Error>> {
263        match self.get_mut() {
264            ReplicationTlsStream::Server(s) => std::pin::Pin::new(s).poll_flush(cx),
265            ReplicationTlsStream::Client(s) => std::pin::Pin::new(s).poll_flush(cx),
266            ReplicationTlsStream::Plain(s) => std::pin::Pin::new(s).poll_flush(cx),
267        }
268    }
269
270    fn poll_shutdown(
271        self: std::pin::Pin<&mut Self>,
272        cx: &mut std::task::Context<'_>,
273    ) -> std::task::Poll<Result<(), std::io::Error>> {
274        match self.get_mut() {
275            ReplicationTlsStream::Server(s) => std::pin::Pin::new(s).poll_shutdown(cx),
276            ReplicationTlsStream::Client(s) => std::pin::Pin::new(s).poll_shutdown(cx),
277            ReplicationTlsStream::Plain(s) => std::pin::Pin::new(s).poll_shutdown(cx),
278        }
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Config built from placeholder paths; nothing is read from disk.
    fn sample_config() -> ReplicationTlsConfig {
        ReplicationTlsConfig::new(
            PathBuf::from("/path/to/cert.pem"),
            PathBuf::from("/path/to/key.pem"),
            PathBuf::from("/path/to/ca.pem"),
        )
    }

    #[test]
    fn test_tls_config_creation() {
        let cfg = sample_config();

        assert_eq!(cfg.cert_path, PathBuf::from("/path/to/cert.pem"));
        assert!(cfg.require_client_cert, "mTLS should be enabled by default");
    }

    #[test]
    fn test_tls_config_without_client_cert() {
        let cfg = ReplicationTlsConfig {
            require_client_cert: false,
            ..sample_config()
        };

        assert!(!cfg.require_client_cert);
    }
}