1use std::sync::Arc;
7
8use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
9use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
10use rustls::{DigitallySignedStruct, SignatureScheme};
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio_rustls::client::TlsStream as ClientTlsStream;
13use tokio_rustls::server::TlsStream as ServerTlsStream;
14use tokio_rustls::{TlsAcceptor, TlsConnector};
15
16use irontide_core::Id20;
17
18use crate::error::{Error, Result};
19
/// PEM-encoded key material used to set up SSL-torrent connections.
///
/// All fields hold raw PEM bytes; nothing is parsed until the config is passed
/// to `build_client_config` / `build_server_config`.
///
/// `Debug` is implemented by hand so that logging a config never prints the
/// private key: only byte lengths are shown for the certificates, and the key
/// is redacted.
#[derive(Clone)]
pub struct SslConfig {
    /// CA certificate(s) acting as the trust anchors for peer certificates.
    pub ca_cert_pem: Vec<u8>,
    /// Our certificate chain, presented to peers (rustls expects the
    /// end-entity certificate first).
    pub our_cert_pem: Vec<u8>,
    /// Private key belonging to the end-entity certificate in `our_cert_pem`.
    pub our_key_pem: Vec<u8>,
}

impl std::fmt::Debug for SslConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SslConfig")
            .field(
                "ca_cert_pem",
                &format_args!("<{} bytes>", self.ca_cert_pem.len()),
            )
            .field(
                "our_cert_pem",
                &format_args!("<{} bytes>", self.our_cert_pem.len()),
            )
            // Never expose private key material in logs or panic messages.
            .field("our_key_pem", &format_args!("<redacted>"))
            .finish()
    }
}
30
31#[derive(Debug)]
39struct SslTorrentServerVerifier {
40 inner: Arc<rustls::client::WebPkiServerVerifier>,
41}
42
43impl SslTorrentServerVerifier {
44 fn new(root_store: Arc<rustls::RootCertStore>) -> Result<Self> {
45 let inner = rustls::client::WebPkiServerVerifier::builder_with_provider(
46 root_store,
47 Arc::new(rustls::crypto::ring::default_provider()),
48 )
49 .build()
50 .map_err(|e| Error::Ssl(format!("server verifier error: {e}")))?;
51 Ok(Self { inner })
52 }
53}
54
55impl ServerCertVerifier for SslTorrentServerVerifier {
56 fn verify_server_cert(
57 &self,
58 end_entity: &CertificateDer<'_>,
59 intermediates: &[CertificateDer<'_>],
60 server_name: &ServerName<'_>,
61 ocsp_response: &[u8],
62 now: UnixTime,
63 ) -> std::result::Result<ServerCertVerified, rustls::Error> {
64 match self.inner.verify_server_cert(
72 end_entity,
73 intermediates,
74 server_name,
75 ocsp_response,
76 now,
77 ) {
78 Ok(verified) => Ok(verified),
79 Err(rustls::Error::InvalidCertificate(ref cert_err))
80 if matches!(
81 cert_err,
82 rustls::CertificateError::NotValidForName
83 | rustls::CertificateError::NotValidForNameContext { .. }
84 ) =>
85 {
86 Ok(ServerCertVerified::assertion())
90 }
91 Err(e) => Err(e),
92 }
93 }
94
95 fn verify_tls12_signature(
96 &self,
97 message: &[u8],
98 cert: &CertificateDer<'_>,
99 dss: &DigitallySignedStruct,
100 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
101 self.inner.verify_tls12_signature(message, cert, dss)
102 }
103
104 fn verify_tls13_signature(
105 &self,
106 message: &[u8],
107 cert: &CertificateDer<'_>,
108 dss: &DigitallySignedStruct,
109 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
110 self.inner.verify_tls13_signature(message, cert, dss)
111 }
112
113 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
114 self.inner.supported_verify_schemes()
115 }
116}
117
118pub fn build_client_config(config: &SslConfig) -> Result<Arc<rustls::ClientConfig>> {
128 let ca_certs = parse_pem_certs(&config.ca_cert_pem)?;
129 let our_certs = parse_pem_certs(&config.our_cert_pem)?;
130 let our_key = parse_pem_key(&config.our_key_pem)?;
131
132 let mut root_store = rustls::RootCertStore::empty();
133 for cert in &ca_certs {
134 root_store
135 .add(cert.clone())
136 .map_err(|e| Error::Ssl(format!("failed to add CA cert: {e}")))?;
137 }
138
139 let verifier = SslTorrentServerVerifier::new(Arc::new(root_store))?;
140
141 let provider = rustls::crypto::ring::default_provider();
142 let client_config = rustls::ClientConfig::builder_with_provider(Arc::new(provider))
143 .with_safe_default_protocol_versions()
144 .map_err(|e| Error::Ssl(format!("protocol version error: {e}")))?
145 .dangerous()
146 .with_custom_certificate_verifier(Arc::new(verifier))
147 .with_client_auth_cert(our_certs, our_key)
148 .map_err(|e| Error::Ssl(format!("client config error: {e}")))?;
149
150 Ok(Arc::new(client_config))
151}
152
153pub fn build_server_config(config: &SslConfig) -> Result<Arc<rustls::ServerConfig>> {
159 let ca_certs = parse_pem_certs(&config.ca_cert_pem)?;
160 let our_certs = parse_pem_certs(&config.our_cert_pem)?;
161 let our_key = parse_pem_key(&config.our_key_pem)?;
162
163 let mut root_store = rustls::RootCertStore::empty();
165 for cert in &ca_certs {
166 root_store
167 .add(cert.clone())
168 .map_err(|e| Error::Ssl(format!("failed to add CA cert: {e}")))?;
169 }
170
171 let client_verifier = rustls::server::WebPkiClientVerifier::builder_with_provider(
172 Arc::new(root_store),
173 Arc::new(rustls::crypto::ring::default_provider()),
174 )
175 .build()
176 .map_err(|e| Error::Ssl(format!("client verifier error: {e}")))?;
177
178 let provider = rustls::crypto::ring::default_provider();
179 let server_config = rustls::ServerConfig::builder_with_provider(Arc::new(provider))
180 .with_safe_default_protocol_versions()
181 .map_err(|e| Error::Ssl(format!("protocol version error: {e}")))?
182 .with_client_cert_verifier(client_verifier)
183 .with_single_cert(our_certs, our_key)
184 .map_err(|e| Error::Ssl(format!("server config error: {e}")))?;
185
186 Ok(Arc::new(server_config))
187}
188
189pub async fn connect_tls<S: AsyncRead + AsyncWrite + Unpin>(
195 stream: S,
196 info_hash: Id20,
197 client_config: Arc<rustls::ClientConfig>,
198) -> Result<ClientTlsStream<S>> {
199 let sni = info_hash.to_hex();
200 let server_name =
201 ServerName::try_from(sni.as_str()).map_err(|e| Error::Ssl(format!("invalid SNI: {e}")))?;
202
203 let connector = TlsConnector::from(client_config);
204 connector
205 .connect(server_name.to_owned(), stream)
206 .await
207 .map_err(|e| Error::Ssl(format!("TLS handshake failed: {e}")))
208}
209
210pub async fn accept_tls<S: AsyncRead + AsyncWrite + Unpin>(
216 stream: S,
217 server_config: Arc<rustls::ServerConfig>,
218) -> Result<ServerTlsStream<S>> {
219 let acceptor = TlsAcceptor::from(server_config);
220 acceptor
221 .accept(stream)
222 .await
223 .map_err(|e| Error::Ssl(format!("TLS accept failed: {e}")))
224}
225
226pub fn generate_self_signed_cert() -> Result<(Vec<u8>, Vec<u8>)> {
234 use rcgen::{CertificateParams, KeyPair};
235
236 let key_pair =
237 KeyPair::generate().map_err(|e| Error::Ssl(format!("key generation failed: {e}")))?;
238
239 let mut params = CertificateParams::new(vec!["torrent-peer".to_string()])
240 .map_err(|e| Error::Ssl(format!("cert params error: {e}")))?;
241 params.distinguished_name.push(
242 rcgen::DnType::CommonName,
243 rcgen::DnValue::Utf8String("torrent-peer".into()),
244 );
245
246 let cert = params
247 .self_signed(&key_pair)
248 .map_err(|e| Error::Ssl(format!("self-signed cert generation failed: {e}")))?;
249
250 let cert_pem = cert.pem().into_bytes();
251 let key_pem = key_pair.serialize_pem().into_bytes();
252
253 Ok((cert_pem, key_pem))
254}
255
256fn parse_pem_certs(pem: &[u8]) -> Result<Vec<CertificateDer<'static>>> {
258 let mut reader = std::io::BufReader::new(pem);
259 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
260 .collect::<std::result::Result<Vec<_>, _>>()
261 .map_err(|e| Error::Ssl(format!("failed to parse PEM certs: {e}")))?;
262
263 if certs.is_empty() {
264 return Err(Error::Ssl("no certificates found in PEM data".into()));
265 }
266
267 Ok(certs)
268}
269
270fn parse_pem_key(pem: &[u8]) -> Result<PrivateKeyDer<'static>> {
272 let mut reader = std::io::BufReader::new(pem);
273 rustls_pemfile::private_key(&mut reader)
274 .map_err(|e| Error::Ssl(format!("failed to parse PEM key: {e}")))?
275 .ok_or_else(|| Error::Ssl("no private key found in PEM data".into()))
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Info-hash whose hex form is used as the SNI in the handshake tests.
    const TEST_INFO_HASH_HEX: &str = "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d";

    /// Builds an `SslConfig` whose self-signed certificate is also its own
    /// trust anchor.
    fn self_signed_config() -> SslConfig {
        let (cert_pem, key_pem) = generate_self_signed_cert().unwrap();
        SslConfig {
            ca_cert_pem: cert_pem.clone(),
            our_cert_pem: cert_pem,
            our_key_pem: key_pem,
        }
    }

    #[test]
    fn generate_self_signed_cert_produces_valid_pem() {
        let (cert_pem, key_pem) = generate_self_signed_cert().unwrap();
        assert!(cert_pem.starts_with(b"-----BEGIN CERTIFICATE-----"));
        assert!(
            key_pem.starts_with(b"-----BEGIN PRIVATE KEY-----")
                || key_pem.starts_with(b"-----BEGIN RSA PRIVATE KEY-----")
                || key_pem.starts_with(b"-----BEGIN EC PRIVATE KEY-----")
        );

        let certs = parse_pem_certs(&cert_pem).unwrap();
        assert_eq!(certs.len(), 1);
        let _key = parse_pem_key(&key_pem).unwrap();
    }

    #[test]
    fn parse_pem_certs_rejects_empty() {
        assert!(parse_pem_certs(b"").is_err());
        assert!(parse_pem_certs(b"not a cert").is_err());
    }

    #[test]
    fn parse_pem_key_rejects_empty() {
        assert!(parse_pem_key(b"").is_err());
        assert!(parse_pem_key(b"not a key").is_err());
    }

    #[test]
    fn build_client_config_with_self_signed() {
        let client_config = build_client_config(&self_signed_config()).unwrap();
        // `connect_tls` relies on SNI carrying the info-hash.
        assert!(client_config.enable_sni);
        // Our certificate must be available for client authentication.
        assert!(client_config.client_auth_cert_resolver.has_certs());
    }

    #[test]
    fn build_server_config_with_self_signed() {
        // Success of the build is the property under test.
        let _server_config = build_server_config(&self_signed_config()).unwrap();
    }

    #[test]
    fn build_configs_reject_missing_key() {
        let config = SslConfig {
            our_key_pem: Vec::new(),
            ..self_signed_config()
        };
        assert!(build_client_config(&config).is_err());
        assert!(build_server_config(&config).is_err());
    }

    #[test]
    fn build_configs_reject_missing_ca() {
        let config = SslConfig {
            ca_cert_pem: Vec::new(),
            ..self_signed_config()
        };
        assert!(build_client_config(&config).is_err());
        assert!(build_server_config(&config).is_err());
    }

    /// Returns `(ca_pem, leaf_cert_pem, leaf_key_pem)`.
    ///
    /// The leaf is named `torrent-peer` and is signed by a freshly generated CA.
    fn generate_ca_and_leaf() -> (Vec<u8>, Vec<u8>, Vec<u8>) {
        use rcgen::{BasicConstraints, CertificateParams, IsCa, KeyPair};

        let ca_key = KeyPair::generate().unwrap();
        let mut ca_params = CertificateParams::new(vec![]).unwrap();
        ca_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
        ca_params.distinguished_name.push(
            rcgen::DnType::CommonName,
            rcgen::DnValue::Utf8String("Test CA".into()),
        );
        let ca_cert = ca_params.self_signed(&ca_key).unwrap();

        let leaf_key = KeyPair::generate().unwrap();
        let mut leaf_params = CertificateParams::new(vec!["torrent-peer".to_string()]).unwrap();
        leaf_params.distinguished_name.push(
            rcgen::DnType::CommonName,
            rcgen::DnValue::Utf8String("torrent-peer".into()),
        );
        let leaf_cert = leaf_params.signed_by(&leaf_key, &ca_cert, &ca_key).unwrap();

        (
            ca_cert.pem().into_bytes(),
            leaf_cert.pem().into_bytes(),
            leaf_key.serialize_pem().into_bytes(),
        )
    }

    /// Full mutual-TLS round trip.
    ///
    /// The leaf is named `torrent-peer` while the SNI is the hex info-hash, so
    /// this also exercises the name-mismatch tolerance of
    /// `SslTorrentServerVerifier`.
    #[tokio::test]
    async fn tls_handshake_client_server_round_trip() {
        let (ca_pem, leaf_cert_pem, leaf_key_pem) = generate_ca_and_leaf();

        let ssl_config = SslConfig {
            ca_cert_pem: ca_pem,
            our_cert_pem: leaf_cert_pem,
            our_key_pem: leaf_key_pem,
        };

        let client_tls_config = build_client_config(&ssl_config).unwrap();
        let server_tls_config = build_server_config(&ssl_config).unwrap();

        let info_hash = Id20::from_hex(TEST_INFO_HASH_HEX).unwrap();
        let (client_raw, server_raw) = tokio::io::duplex(16384);

        let server_handle = tokio::spawn(async move {
            let mut tls_stream = accept_tls(server_raw, server_tls_config).await.unwrap();
            let mut buf = [0u8; 11];
            tls_stream.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"hello world");
            tls_stream.write_all(b"hello back").await.unwrap();
            tls_stream.flush().await.unwrap();
        });

        let mut client_stream = connect_tls(client_raw, info_hash, client_tls_config)
            .await
            .unwrap();
        client_stream.write_all(b"hello world").await.unwrap();
        client_stream.flush().await.unwrap();

        let mut buf = [0u8; 10];
        client_stream.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"hello back");

        server_handle.await.unwrap();
    }

    /// Each side trusts a different CA.
    ///
    /// The client must reject the server's certificate, and the server's
    /// accept must fail as a result.
    #[tokio::test]
    async fn tls_handshake_rejects_untrusted_cert() {
        let (ca1_pem, leaf1_cert_pem, leaf1_key_pem) = generate_ca_and_leaf();
        let (ca2_pem, leaf2_cert_pem, leaf2_key_pem) = generate_ca_and_leaf();

        let client_config_data = SslConfig {
            ca_cert_pem: ca1_pem,
            our_cert_pem: leaf1_cert_pem,
            our_key_pem: leaf1_key_pem,
        };
        let server_config_data = SslConfig {
            ca_cert_pem: ca2_pem,
            our_cert_pem: leaf2_cert_pem,
            our_key_pem: leaf2_key_pem,
        };

        let client_tls_config = build_client_config(&client_config_data).unwrap();
        let server_tls_config = build_server_config(&server_config_data).unwrap();

        let info_hash = Id20::from_hex(TEST_INFO_HASH_HEX).unwrap();
        let (client_raw, server_raw) = tokio::io::duplex(16384);

        // Return only a bool so the spawned future does not need the crate
        // error type to be `Send`.
        let server_handle = tokio::spawn(async move {
            accept_tls(server_raw, server_tls_config).await.is_err()
        });

        let result = connect_tls(client_raw, info_hash, client_tls_config).await;
        assert!(result.is_err());
        // Drop the client side explicitly so the server cannot wait on it.
        drop(result);

        let server_failed = server_handle.await.unwrap();
        assert!(server_failed);
    }
}