1use std::fs::File;
6use std::io::BufReader;
7use std::path::PathBuf;
8use std::sync::Arc;
9use tokio::net::TcpStream;
10use tokio_rustls::rustls::{
11 pki_types::{CertificateDer, PrivateKeyDer, ServerName},
12 ClientConfig, RootCertStore, ServerConfig,
13};
14use tokio_rustls::{TlsAcceptor, TlsConnector, TlsStream};
15use tracing::{debug, info};
16
/// File locations and policy used to build the TLS endpoints of a replication link.
///
/// The same configuration can produce either side of the link:
/// an acceptor (server side) or a connector (client side).
#[derive(Debug, Clone)]
pub struct ReplicationTlsConfig {
    /// PEM file holding this node's certificate chain.
    pub cert_path: PathBuf,

    /// PEM file holding the private key that matches `cert_path`.
    pub key_path: PathBuf,

    /// PEM file holding the CA certificate(s) used to verify the remote peer.
    ///
    /// The connector uses it as its root store. The acceptor uses it to verify
    /// client certificates when `require_client_cert` is set.
    pub ca_cert_path: PathBuf,

    /// When `true`, the acceptor demands and verifies a client certificate (mutual TLS).
    pub require_client_cert: bool,
}
32
33#[derive(Debug, thiserror::Error)]
35pub enum TlsError {
36 #[error("Failed to read certificate file: {0}")]
37 CertificateRead(String),
38
39 #[error("Failed to read private key file: {0}")]
40 PrivateKeyRead(String),
41
42 #[error("Failed to read CA certificate file: {0}")]
43 CaCertificateRead(String),
44
45 #[error("No certificates found in file")]
46 NoCertificates,
47
48 #[error("No private key found in file")]
49 NoPrivateKey,
50
51 #[error("TLS configuration error: {0}")]
52 Configuration(String),
53
54 #[error("TLS handshake failed: {0}")]
55 Handshake(String),
56
57 #[error("IO error: {0}")]
58 Io(#[from] std::io::Error),
59}
60
61impl ReplicationTlsConfig {
62 pub fn new(cert_path: PathBuf, key_path: PathBuf, ca_cert_path: PathBuf) -> Self {
64 Self {
65 cert_path,
66 key_path,
67 ca_cert_path,
68 require_client_cert: true, }
70 }
71
72 fn load_certs(&self) -> Result<Vec<CertificateDer<'static>>, TlsError> {
74 let file = File::open(&self.cert_path).map_err(|e| {
75 TlsError::CertificateRead(format!("{}: {}", self.cert_path.display(), e))
76 })?;
77 let mut reader = BufReader::new(file);
78
79 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
80 .filter_map(|r| r.ok())
81 .collect();
82
83 if certs.is_empty() {
84 return Err(TlsError::NoCertificates);
85 }
86
87 debug!(count = certs.len(), path = %self.cert_path.display(), "Loaded certificates");
88 Ok(certs)
89 }
90
91 fn load_private_key(&self) -> Result<PrivateKeyDer<'static>, TlsError> {
93 let file = File::open(&self.key_path).map_err(|e| {
94 TlsError::PrivateKeyRead(format!("{}: {}", self.key_path.display(), e))
95 })?;
96 let mut reader = BufReader::new(file);
97
98 let key = rustls_pemfile::private_key(&mut reader)
99 .map_err(|e| TlsError::PrivateKeyRead(e.to_string()))?
100 .ok_or(TlsError::NoPrivateKey)?;
101
102 debug!(path = %self.key_path.display(), "Loaded private key");
103 Ok(key)
104 }
105
106 fn load_ca_certs(&self) -> Result<RootCertStore, TlsError> {
108 let file = File::open(&self.ca_cert_path).map_err(|e| {
109 TlsError::CaCertificateRead(format!("{}: {}", self.ca_cert_path.display(), e))
110 })?;
111 let mut reader = BufReader::new(file);
112
113 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
114 .filter_map(|r| r.ok())
115 .collect();
116
117 if certs.is_empty() {
118 return Err(TlsError::NoCertificates);
119 }
120
121 let mut root_store = RootCertStore::empty();
122 for cert in certs {
123 root_store.add(cert).map_err(|e| {
124 TlsError::Configuration(format!("Failed to add CA certificate: {}", e))
125 })?;
126 }
127
128 debug!(path = %self.ca_cert_path.display(), "Loaded CA certificates");
129 Ok(root_store)
130 }
131
132 pub fn build_acceptor(&self) -> Result<TlsAcceptor, TlsError> {
134 let certs = self.load_certs()?;
135 let key = self.load_private_key()?;
136 let ca_certs = self.load_ca_certs()?;
137
138 let mut config = if self.require_client_cert {
139 let client_cert_verifier = tokio_rustls::rustls::server::WebPkiClientVerifier::builder(
141 Arc::new(ca_certs),
142 )
143 .build()
144 .map_err(|e| TlsError::Configuration(format!("Failed to build client verifier: {}", e)))?;
145
146 ServerConfig::builder()
147 .with_client_cert_verifier(client_cert_verifier)
148 .with_single_cert(certs, key)
149 .map_err(|e| TlsError::Configuration(e.to_string()))?
150 } else {
151 ServerConfig::builder()
153 .with_no_client_auth()
154 .with_single_cert(certs, key)
155 .map_err(|e| TlsError::Configuration(e.to_string()))?
156 };
157
158 config.alpn_protocols = vec![b"nklave-repl".to_vec()];
160
161 info!(
162 require_client_cert = self.require_client_cert,
163 "TLS acceptor configured for primary"
164 );
165
166 Ok(TlsAcceptor::from(Arc::new(config)))
167 }
168
169 pub fn build_connector(&self) -> Result<TlsConnector, TlsError> {
171 let certs = self.load_certs()?;
172 let key = self.load_private_key()?;
173 let ca_certs = self.load_ca_certs()?;
174
175 let mut config = ClientConfig::builder()
176 .with_root_certificates(ca_certs)
177 .with_client_auth_cert(certs, key)
178 .map_err(|e| TlsError::Configuration(e.to_string()))?;
179
180 config.alpn_protocols = vec![b"nklave-repl".to_vec()];
181
182 info!("TLS connector configured for passive");
183
184 Ok(TlsConnector::from(Arc::new(config)))
185 }
186}
187
188pub enum ReplicationTlsStream {
190 Server(TlsStream<TcpStream>),
192 Client(TlsStream<TcpStream>),
194 Plain(TcpStream),
196}
197
198impl ReplicationTlsStream {
199 pub async fn accept(
201 acceptor: &TlsAcceptor,
202 stream: TcpStream,
203 ) -> Result<Self, TlsError> {
204 let tls_stream = acceptor.accept(stream).await.map_err(|e| {
205 TlsError::Handshake(format!("Server TLS handshake failed: {}", e))
206 })?;
207 Ok(ReplicationTlsStream::Server(TlsStream::Server(tls_stream)))
208 }
209
210 pub async fn connect(
212 connector: &TlsConnector,
213 server_name: &str,
214 stream: TcpStream,
215 ) -> Result<Self, TlsError> {
216 let server_name = ServerName::try_from(server_name.to_string())
217 .map_err(|e| TlsError::Configuration(format!("Invalid server name: {}", e)))?;
218
219 let tls_stream = connector.connect(server_name, stream).await.map_err(|e| {
220 TlsError::Handshake(format!("Client TLS handshake failed: {}", e))
221 })?;
222 Ok(ReplicationTlsStream::Client(TlsStream::Client(tls_stream)))
223 }
224
225 pub fn plain(stream: TcpStream) -> Self {
227 ReplicationTlsStream::Plain(stream)
228 }
229}
230
231impl tokio::io::AsyncRead for ReplicationTlsStream {
233 fn poll_read(
234 self: std::pin::Pin<&mut Self>,
235 cx: &mut std::task::Context<'_>,
236 buf: &mut tokio::io::ReadBuf<'_>,
237 ) -> std::task::Poll<std::io::Result<()>> {
238 match self.get_mut() {
239 ReplicationTlsStream::Server(s) => std::pin::Pin::new(s).poll_read(cx, buf),
240 ReplicationTlsStream::Client(s) => std::pin::Pin::new(s).poll_read(cx, buf),
241 ReplicationTlsStream::Plain(s) => std::pin::Pin::new(s).poll_read(cx, buf),
242 }
243 }
244}
245
246impl tokio::io::AsyncWrite for ReplicationTlsStream {
247 fn poll_write(
248 self: std::pin::Pin<&mut Self>,
249 cx: &mut std::task::Context<'_>,
250 buf: &[u8],
251 ) -> std::task::Poll<Result<usize, std::io::Error>> {
252 match self.get_mut() {
253 ReplicationTlsStream::Server(s) => std::pin::Pin::new(s).poll_write(cx, buf),
254 ReplicationTlsStream::Client(s) => std::pin::Pin::new(s).poll_write(cx, buf),
255 ReplicationTlsStream::Plain(s) => std::pin::Pin::new(s).poll_write(cx, buf),
256 }
257 }
258
259 fn poll_flush(
260 self: std::pin::Pin<&mut Self>,
261 cx: &mut std::task::Context<'_>,
262 ) -> std::task::Poll<Result<(), std::io::Error>> {
263 match self.get_mut() {
264 ReplicationTlsStream::Server(s) => std::pin::Pin::new(s).poll_flush(cx),
265 ReplicationTlsStream::Client(s) => std::pin::Pin::new(s).poll_flush(cx),
266 ReplicationTlsStream::Plain(s) => std::pin::Pin::new(s).poll_flush(cx),
267 }
268 }
269
270 fn poll_shutdown(
271 self: std::pin::Pin<&mut Self>,
272 cx: &mut std::task::Context<'_>,
273 ) -> std::task::Poll<Result<(), std::io::Error>> {
274 match self.get_mut() {
275 ReplicationTlsStream::Server(s) => std::pin::Pin::new(s).poll_shutdown(cx),
276 ReplicationTlsStream::Client(s) => std::pin::Pin::new(s).poll_shutdown(cx),
277 ReplicationTlsStream::Plain(s) => std::pin::Pin::new(s).poll_shutdown(cx),
278 }
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Path of a certificate file under a directory that is not expected to exist.
    const MISSING_CERT: &str = "/nonexistent/nklave-tls-test/cert.pem";

    /// Builds a config whose PEM paths all point at missing files.
    /// Any attempt to load them therefore fails at `File::open`.
    fn missing_files_config() -> ReplicationTlsConfig {
        ReplicationTlsConfig::new(
            PathBuf::from(MISSING_CERT),
            PathBuf::from("/nonexistent/nklave-tls-test/key.pem"),
            PathBuf::from("/nonexistent/nklave-tls-test/ca.pem"),
        )
    }

    #[test]
    fn test_tls_config_creation() {
        let config = ReplicationTlsConfig::new(
            PathBuf::from("/path/to/cert.pem"),
            PathBuf::from("/path/to/key.pem"),
            PathBuf::from("/path/to/ca.pem"),
        );

        assert!(config.require_client_cert);
        assert_eq!(config.cert_path, PathBuf::from("/path/to/cert.pem"));
    }

    #[test]
    fn test_tls_config_without_client_cert() {
        let mut config = ReplicationTlsConfig::new(
            PathBuf::from("/path/to/cert.pem"),
            PathBuf::from("/path/to/key.pem"),
            PathBuf::from("/path/to/ca.pem"),
        );
        config.require_client_cert = false;

        assert!(!config.require_client_cert);
    }

    #[test]
    fn test_build_acceptor_reports_missing_cert_path() {
        let err = missing_files_config()
            .build_acceptor()
            .err()
            .expect("building an acceptor from missing files must fail");
        match err {
            TlsError::CertificateRead(msg) => {
                assert!(msg.contains(MISSING_CERT), "unexpected message: {}", msg)
            }
            other => panic!("expected CertificateRead, got: {}", other),
        }
    }

    #[test]
    fn test_build_connector_reports_missing_cert_path() {
        let err = missing_files_config()
            .build_connector()
            .err()
            .expect("building a connector from missing files must fail");
        match err {
            TlsError::CertificateRead(msg) => {
                assert!(msg.contains(MISSING_CERT), "unexpected message: {}", msg)
            }
            other => panic!("expected CertificateRead, got: {}", other),
        }
    }
}