1use std::sync::Arc;
7
8use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
9use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
10use rustls::{DigitallySignedStruct, SignatureScheme};
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio_rustls::client::TlsStream as ClientTlsStream;
13use tokio_rustls::server::TlsStream as ServerTlsStream;
14use tokio_rustls::{TlsAcceptor, TlsConnector};
15
16use irontide_core::Id20;
17
18use crate::error::{Error, Result};
19
/// PEM-encoded TLS material for one side of an SSL-torrent connection.
///
/// The same value is accepted by both the client and the server config builders:
/// `ca_cert_pem` supplies the trust roots used to verify the remote peer, while
/// `our_cert_pem` / `our_key_pem` are the credentials presented to it.
#[derive(Clone)]
pub struct SslConfig {
    /// One or more PEM certificates trusted as roots when verifying the remote peer.
    pub ca_cert_pem: Vec<u8>,
    /// PEM certificate chain presented to the remote peer.
    pub our_cert_pem: Vec<u8>,
    /// PEM private key for `our_cert_pem`. Redacted by the `Debug` impl.
    pub our_key_pem: Vec<u8>,
}

/// Prints only the byte lengths of the certificate fields and never the private key,
/// so a config can be logged without leaking secret material.
impl std::fmt::Debug for SslConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SslConfig")
            .field(
                "ca_cert_pem",
                &format_args!("<{} bytes>", self.ca_cert_pem.len()),
            )
            .field(
                "our_cert_pem",
                &format_args!("<{} bytes>", self.our_cert_pem.len()),
            )
            .field("our_key_pem", &format_args!("<redacted>"))
            .finish()
    }
}
30
31#[derive(Debug)]
39struct SslTorrentServerVerifier {
40 inner: Arc<rustls::client::WebPkiServerVerifier>,
41}
42
43impl SslTorrentServerVerifier {
44 fn new(root_store: Arc<rustls::RootCertStore>) -> Result<Self> {
45 let inner = rustls::client::WebPkiServerVerifier::builder_with_provider(
46 root_store,
47 Arc::new(rustls::crypto::ring::default_provider()),
48 )
49 .build()
50 .map_err(|e| Error::Ssl(format!("server verifier error: {e}")))?;
51 Ok(Self { inner })
52 }
53}
54
55impl ServerCertVerifier for SslTorrentServerVerifier {
56 fn verify_server_cert(
57 &self,
58 end_entity: &CertificateDer<'_>,
59 intermediates: &[CertificateDer<'_>],
60 _server_name: &ServerName<'_>,
61 _ocsp_response: &[u8],
62 now: UnixTime,
63 ) -> std::result::Result<ServerCertVerified, rustls::Error> {
64 match self.inner.verify_server_cert(
72 end_entity,
73 intermediates,
74 _server_name,
75 _ocsp_response,
76 now,
77 ) {
78 Ok(verified) => Ok(verified),
79 Err(rustls::Error::InvalidCertificate(ref cert_err))
80 if matches!(
81 cert_err,
82 rustls::CertificateError::NotValidForName
83 | rustls::CertificateError::NotValidForNameContext { .. }
84 ) =>
85 {
86 Ok(ServerCertVerified::assertion())
90 }
91 Err(e) => Err(e),
92 }
93 }
94
95 fn verify_tls12_signature(
96 &self,
97 message: &[u8],
98 cert: &CertificateDer<'_>,
99 dss: &DigitallySignedStruct,
100 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
101 self.inner.verify_tls12_signature(message, cert, dss)
102 }
103
104 fn verify_tls13_signature(
105 &self,
106 message: &[u8],
107 cert: &CertificateDer<'_>,
108 dss: &DigitallySignedStruct,
109 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
110 self.inner.verify_tls13_signature(message, cert, dss)
111 }
112
113 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
114 self.inner.supported_verify_schemes()
115 }
116}
117
118pub fn build_client_config(config: &SslConfig) -> Result<Arc<rustls::ClientConfig>> {
124 let ca_certs = parse_pem_certs(&config.ca_cert_pem)?;
125 let our_certs = parse_pem_certs(&config.our_cert_pem)?;
126 let our_key = parse_pem_key(&config.our_key_pem)?;
127
128 let mut root_store = rustls::RootCertStore::empty();
129 for cert in &ca_certs {
130 root_store
131 .add(cert.clone())
132 .map_err(|e| Error::Ssl(format!("failed to add CA cert: {e}")))?;
133 }
134
135 let verifier = SslTorrentServerVerifier::new(Arc::new(root_store))?;
136
137 let provider = rustls::crypto::ring::default_provider();
138 let client_config = rustls::ClientConfig::builder_with_provider(Arc::new(provider))
139 .with_safe_default_protocol_versions()
140 .map_err(|e| Error::Ssl(format!("protocol version error: {e}")))?
141 .dangerous()
142 .with_custom_certificate_verifier(Arc::new(verifier))
143 .with_client_auth_cert(our_certs, our_key)
144 .map_err(|e| Error::Ssl(format!("client config error: {e}")))?;
145
146 Ok(Arc::new(client_config))
147}
148
149pub fn build_server_config(config: &SslConfig) -> Result<Arc<rustls::ServerConfig>> {
151 let ca_certs = parse_pem_certs(&config.ca_cert_pem)?;
152 let our_certs = parse_pem_certs(&config.our_cert_pem)?;
153 let our_key = parse_pem_key(&config.our_key_pem)?;
154
155 let mut root_store = rustls::RootCertStore::empty();
157 for cert in &ca_certs {
158 root_store
159 .add(cert.clone())
160 .map_err(|e| Error::Ssl(format!("failed to add CA cert: {e}")))?;
161 }
162
163 let client_verifier = rustls::server::WebPkiClientVerifier::builder_with_provider(
164 Arc::new(root_store),
165 Arc::new(rustls::crypto::ring::default_provider()),
166 )
167 .build()
168 .map_err(|e| Error::Ssl(format!("client verifier error: {e}")))?;
169
170 let provider = rustls::crypto::ring::default_provider();
171 let server_config = rustls::ServerConfig::builder_with_provider(Arc::new(provider))
172 .with_safe_default_protocol_versions()
173 .map_err(|e| Error::Ssl(format!("protocol version error: {e}")))?
174 .with_client_cert_verifier(client_verifier)
175 .with_single_cert(our_certs, our_key)
176 .map_err(|e| Error::Ssl(format!("server config error: {e}")))?;
177
178 Ok(Arc::new(server_config))
179}
180
181pub async fn connect_tls<S: AsyncRead + AsyncWrite + Unpin>(
183 stream: S,
184 info_hash: Id20,
185 client_config: Arc<rustls::ClientConfig>,
186) -> Result<ClientTlsStream<S>> {
187 let sni = info_hash.to_hex();
188 let server_name =
189 ServerName::try_from(sni.as_str()).map_err(|e| Error::Ssl(format!("invalid SNI: {e}")))?;
190
191 let connector = TlsConnector::from(client_config);
192 connector
193 .connect(server_name.to_owned(), stream)
194 .await
195 .map_err(|e| Error::Ssl(format!("TLS handshake failed: {e}")))
196}
197
198pub async fn accept_tls<S: AsyncRead + AsyncWrite + Unpin>(
200 stream: S,
201 server_config: Arc<rustls::ServerConfig>,
202) -> Result<ServerTlsStream<S>> {
203 let acceptor = TlsAcceptor::from(server_config);
204 acceptor
205 .accept(stream)
206 .await
207 .map_err(|e| Error::Ssl(format!("TLS accept failed: {e}")))
208}
209
210pub fn generate_self_signed_cert() -> Result<(Vec<u8>, Vec<u8>)> {
214 use rcgen::{CertificateParams, KeyPair};
215
216 let key_pair =
217 KeyPair::generate().map_err(|e| Error::Ssl(format!("key generation failed: {e}")))?;
218
219 let mut params = CertificateParams::new(vec!["torrent-peer".to_string()])
220 .map_err(|e| Error::Ssl(format!("cert params error: {e}")))?;
221 params.distinguished_name.push(
222 rcgen::DnType::CommonName,
223 rcgen::DnValue::Utf8String("torrent-peer".into()),
224 );
225
226 let cert = params
227 .self_signed(&key_pair)
228 .map_err(|e| Error::Ssl(format!("self-signed cert generation failed: {e}")))?;
229
230 let cert_pem = cert.pem().into_bytes();
231 let key_pem = key_pair.serialize_pem().into_bytes();
232
233 Ok((cert_pem, key_pem))
234}
235
236fn parse_pem_certs(pem: &[u8]) -> Result<Vec<CertificateDer<'static>>> {
238 let mut reader = std::io::BufReader::new(pem);
239 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
240 .collect::<std::result::Result<Vec<_>, _>>()
241 .map_err(|e| Error::Ssl(format!("failed to parse PEM certs: {e}")))?;
242
243 if certs.is_empty() {
244 return Err(Error::Ssl("no certificates found in PEM data".into()));
245 }
246
247 Ok(certs)
248}
249
250fn parse_pem_key(pem: &[u8]) -> Result<PrivateKeyDer<'static>> {
252 let mut reader = std::io::BufReader::new(pem);
253 rustls_pemfile::private_key(&mut reader)
254 .map_err(|e| Error::Ssl(format!("failed to parse PEM key: {e}")))?
255 .ok_or_else(|| Error::Ssl("no private key found in PEM data".into()))
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Info-hash whose hex form is used as the SNI value in the handshake tests.
    const INFO_HASH_HEX: &str = "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d";

    /// Builds an `SslConfig` from its three PEM parts.
    fn ssl_config(
        ca_cert_pem: Vec<u8>,
        our_cert_pem: Vec<u8>,
        our_key_pem: Vec<u8>,
    ) -> SslConfig {
        SslConfig {
            ca_cert_pem,
            our_cert_pem,
            our_key_pem,
        }
    }

    #[test]
    fn generate_self_signed_cert_produces_valid_pem() {
        let (cert_pem, key_pem) = generate_self_signed_cert().unwrap();
        assert!(cert_pem.starts_with(b"-----BEGIN CERTIFICATE-----"));
        assert!(
            key_pem.starts_with(b"-----BEGIN PRIVATE KEY-----")
                || key_pem.starts_with(b"-----BEGIN RSA PRIVATE KEY-----")
                || key_pem.starts_with(b"-----BEGIN EC PRIVATE KEY-----")
        );

        let certs = parse_pem_certs(&cert_pem).unwrap();
        assert_eq!(certs.len(), 1);
        parse_pem_key(&key_pem).unwrap();
    }

    #[test]
    fn parse_pem_certs_rejects_empty() {
        assert!(parse_pem_certs(b"").is_err());
        assert!(parse_pem_certs(b"not a cert").is_err());
    }

    #[test]
    fn parse_pem_key_rejects_empty() {
        assert!(parse_pem_key(b"").is_err());
        assert!(parse_pem_key(b"not a key").is_err());
    }

    #[test]
    fn parse_pem_key_rejects_certificate_only_pem() {
        // A PEM bundle holding only a certificate contains no key.
        let (cert_pem, _key_pem) = generate_self_signed_cert().unwrap();
        assert!(parse_pem_key(&cert_pem).is_err());
    }

    #[test]
    fn build_client_config_with_self_signed() {
        let (cert_pem, key_pem) = generate_self_signed_cert().unwrap();
        let config = ssl_config(cert_pem.clone(), cert_pem, key_pem);
        build_client_config(&config).expect("self-signed material should build a client config");
    }

    #[test]
    fn build_server_config_with_self_signed() {
        let (cert_pem, key_pem) = generate_self_signed_cert().unwrap();
        let config = ssl_config(cert_pem.clone(), cert_pem, key_pem);
        build_server_config(&config).expect("self-signed material should build a server config");
    }

    #[test]
    fn build_configs_reject_missing_material() {
        let (cert_pem, key_pem) = generate_self_signed_cert().unwrap();

        let no_ca = ssl_config(Vec::new(), cert_pem.clone(), key_pem.clone());
        assert!(build_client_config(&no_ca).is_err());
        assert!(build_server_config(&no_ca).is_err());

        let no_key = ssl_config(cert_pem.clone(), cert_pem, Vec::new());
        assert!(build_client_config(&no_key).is_err());
        assert!(build_server_config(&no_key).is_err());
    }

    /// Returns `(ca_pem, leaf_cert_pem, leaf_key_pem)` for a fresh CA and a leaf
    /// certificate (SAN and CN `torrent-peer`) signed by it.
    fn generate_ca_and_leaf() -> (Vec<u8>, Vec<u8>, Vec<u8>) {
        use rcgen::{BasicConstraints, CertificateParams, IsCa, KeyPair};

        let ca_key = KeyPair::generate().unwrap();
        let mut ca_params = CertificateParams::new(vec![]).unwrap();
        ca_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
        ca_params.distinguished_name.push(
            rcgen::DnType::CommonName,
            rcgen::DnValue::Utf8String("Test CA".into()),
        );
        let ca_cert = ca_params.self_signed(&ca_key).unwrap();

        let leaf_key = KeyPair::generate().unwrap();
        let mut leaf_params = CertificateParams::new(vec!["torrent-peer".to_string()]).unwrap();
        leaf_params.distinguished_name.push(
            rcgen::DnType::CommonName,
            rcgen::DnValue::Utf8String("torrent-peer".into()),
        );
        let leaf_cert = leaf_params.signed_by(&leaf_key, &ca_cert, &ca_key).unwrap();

        (
            ca_cert.pem().into_bytes(),
            leaf_cert.pem().into_bytes(),
            leaf_key.serialize_pem().into_bytes(),
        )
    }

    #[tokio::test]
    async fn tls_handshake_client_server_round_trip() {
        // The leaf is named "torrent-peer" while the SNI is the info-hash hex, so
        // this handshake also exercises the verifier's name-mismatch tolerance.
        let (ca_pem, leaf_cert_pem, leaf_key_pem) = generate_ca_and_leaf();
        let config = ssl_config(ca_pem, leaf_cert_pem, leaf_key_pem);

        let client_tls_config = build_client_config(&config).unwrap();
        let server_tls_config = build_server_config(&config).unwrap();

        let info_hash = Id20::from_hex(INFO_HASH_HEX).unwrap();
        let (client_raw, server_raw) = tokio::io::duplex(16384);

        let server_handle = tokio::spawn(async move {
            let mut tls_stream = accept_tls(server_raw, server_tls_config).await.unwrap();
            let mut buf = [0u8; 11];
            tls_stream.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"hello world");
            tls_stream.write_all(b"hello back").await.unwrap();
            tls_stream.flush().await.unwrap();
        });

        let mut client_stream = connect_tls(client_raw, info_hash, client_tls_config)
            .await
            .unwrap();
        client_stream.write_all(b"hello world").await.unwrap();
        client_stream.flush().await.unwrap();

        let mut buf = [0u8; 10];
        client_stream.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"hello back");

        server_handle.await.unwrap();
    }

    #[tokio::test]
    async fn tls_handshake_rejects_untrusted_cert() {
        // Each side trusts only its own CA, so neither accepts the other's leaf.
        let (ca1_pem, leaf1_cert_pem, leaf1_key_pem) = generate_ca_and_leaf();
        let (ca2_pem, leaf2_cert_pem, leaf2_key_pem) = generate_ca_and_leaf();

        let client_config_data = ssl_config(ca1_pem, leaf1_cert_pem, leaf1_key_pem);
        let server_config_data = ssl_config(ca2_pem, leaf2_cert_pem, leaf2_key_pem);

        let client_tls_config = build_client_config(&client_config_data).unwrap();
        let server_tls_config = build_server_config(&server_config_data).unwrap();

        let info_hash = Id20::from_hex(INFO_HASH_HEX).unwrap();
        let (client_raw, server_raw) = tokio::io::duplex(16384);

        let server_handle =
            tokio::spawn(async move { accept_tls(server_raw, server_tls_config).await.is_err() });

        let result = connect_tls(client_raw, info_hash, client_tls_config).await;
        assert!(result.is_err());

        // The client aborted the handshake, so the server must not report success.
        let server_failed = server_handle.await.unwrap();
        assert!(server_failed);
    }
}