1use std::{any::Any, io, str, sync::Arc};
2
3#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
4use aws_lc_rs::aead;
5use bytes::BytesMut;
6#[cfg(feature = "ring")]
7use ring::aead;
8pub use rustls::Error;
9use rustls::{
10 self, CipherSuite,
11 client::danger::ServerCertVerifier,
12 pki_types::{CertificateDer, PrivateKeyDer, ServerName},
13 quic::{Connection, HeaderProtectionKey, KeyChange, PacketKey, Secrets, Suite, Version},
14};
15#[cfg(feature = "platform-verifier")]
16use rustls_platform_verifier::BuilderVerifierExt;
17
18use crate::{
19 ConnectError, ConnectionId, Side, TransportError, TransportErrorCode,
20 crypto::{
21 self, CryptoError, ExportKeyingMaterialError, HeaderKey, KeyPair, Keys, UnsupportedVersion,
22 tls_extension_simulation::{ExtensionAwareTlsSession, SimulatedExtensionContext, TlsExtensionHooks},
23 },
24 transport_parameters::TransportParameters,
25};
26
27impl From<Side> for rustls::Side {
28 fn from(s: Side) -> Self {
29 match s {
30 Side::Client => Self::Client,
31 Side::Server => Self::Server,
32 }
33 }
34}
35
/// A QUIC TLS session backed by a rustls QUIC connection (client or server).
pub struct TlsSession {
    /// QUIC version this session was started with (as mapped by `interpret_version`).
    version: Version,
    /// Becomes `true` once `read_handshake` has observed ALPN, a server name, or the
    /// end of the handshake; `handshake_data` returns `None` until then.
    got_handshake_data: bool,
    /// Secrets stored at the 1-RTT key change; used by `next_1rtt_keys` to derive
    /// subsequent 1-RTT packet keys.
    next_secrets: Option<Secrets>,
    /// The underlying rustls QUIC connection.
    inner: Connection,
    /// Cipher suite used to derive Initial-packet keys.
    suite: Suite,
}
44
45impl TlsSession {
46 fn side(&self) -> Side {
47 match self.inner {
48 Connection::Client(_) => Side::Client,
49 Connection::Server(_) => Side::Server,
50 }
51 }
52}
53
54impl crypto::Session for TlsSession {
55 fn initial_keys(&self, dst_cid: &ConnectionId, side: Side) -> Keys {
56 initial_keys(self.version, *dst_cid, side, &self.suite)
57 }
58
59 fn handshake_data(&self) -> Option<Box<dyn Any>> {
60 if !self.got_handshake_data {
61 return None;
62 }
63 Some(Box::new(HandshakeData {
64 protocol: self.inner.alpn_protocol().map(|x| x.into()),
65 server_name: match self.inner {
66 Connection::Client(_) => None,
67 Connection::Server(ref session) => session.server_name().map(|x| x.into()),
68 },
69 }))
70 }
71
72 fn peer_identity(&self) -> Option<Box<dyn Any>> {
74 self.inner.peer_certificates().map(|v| -> Box<dyn Any> {
75 Box::new(
76 v.iter()
77 .map(|v| v.clone().into_owned())
78 .collect::<Vec<CertificateDer<'static>>>(),
79 )
80 })
81 }
82
83 fn early_crypto(&self) -> Option<(Box<dyn HeaderKey>, Box<dyn crypto::PacketKey>)> {
84 let keys = self.inner.zero_rtt_keys()?;
85 Some((Box::new(keys.header), Box::new(keys.packet)))
86 }
87
88 fn early_data_accepted(&self) -> Option<bool> {
89 match self.inner {
90 Connection::Client(ref session) => Some(session.is_early_data_accepted()),
91 _ => None,
92 }
93 }
94
95 fn is_handshaking(&self) -> bool {
96 self.inner.is_handshaking()
97 }
98
99 fn read_handshake(&mut self, buf: &[u8]) -> Result<bool, TransportError> {
100 self.inner.read_hs(buf).map_err(|e| {
101 if let Some(alert) = self.inner.alert() {
102 TransportError {
103 code: TransportErrorCode::crypto(alert.into()),
104 frame: None,
105 reason: e.to_string(),
106 }
107 } else {
108 TransportError::PROTOCOL_VIOLATION(format!("TLS error: {e}"))
109 }
110 })?;
111 if !self.got_handshake_data {
112 let have_server_name = match self.inner {
116 Connection::Client(_) => false,
117 Connection::Server(ref session) => session.server_name().is_some(),
118 };
119 if self.inner.alpn_protocol().is_some() || have_server_name || !self.is_handshaking() {
120 self.got_handshake_data = true;
121 return Ok(true);
122 }
123 }
124 Ok(false)
125 }
126
127 fn transport_parameters(&self) -> Result<Option<TransportParameters>, TransportError> {
128 match self.inner.quic_transport_parameters() {
129 None => Ok(None),
130 Some(buf) => match TransportParameters::read(self.side(), &mut io::Cursor::new(buf)) {
131 Ok(params) => Ok(Some(params)),
132 Err(e) => Err(e.into()),
133 },
134 }
135 }
136
137 fn write_handshake(&mut self, buf: &mut Vec<u8>) -> Option<Keys> {
138 let keys = match self.inner.write_hs(buf)? {
139 KeyChange::Handshake { keys } => keys,
140 KeyChange::OneRtt { keys, next } => {
141 self.next_secrets = Some(next);
142 keys
143 }
144 };
145
146 Some(Keys {
147 header: KeyPair {
148 local: Box::new(keys.local.header),
149 remote: Box::new(keys.remote.header),
150 },
151 packet: KeyPair {
152 local: Box::new(keys.local.packet),
153 remote: Box::new(keys.remote.packet),
154 },
155 })
156 }
157
158 fn next_1rtt_keys(&mut self) -> Option<KeyPair<Box<dyn crypto::PacketKey>>> {
159 let secrets = self.next_secrets.as_mut()?;
160 let keys = secrets.next_packet_keys();
161 Some(KeyPair {
162 local: Box::new(keys.local),
163 remote: Box::new(keys.remote),
164 })
165 }
166
167 fn is_valid_retry(&self, orig_dst_cid: &ConnectionId, header: &[u8], payload: &[u8]) -> bool {
168 let tag_start = match payload.len().checked_sub(16) {
169 Some(x) => x,
170 None => return false,
171 };
172
173 let mut pseudo_packet =
174 Vec::with_capacity(header.len() + payload.len() + orig_dst_cid.len() + 1);
175 pseudo_packet.push(orig_dst_cid.len() as u8);
176 pseudo_packet.extend_from_slice(orig_dst_cid);
177 pseudo_packet.extend_from_slice(header);
178 let tag_start = tag_start + pseudo_packet.len();
179 pseudo_packet.extend_from_slice(payload);
180
181 let (nonce, key) = match self.version {
182 Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1),
183 Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT),
184 _ => unreachable!(),
185 };
186
187 let nonce = aead::Nonce::assume_unique_for_key(nonce);
188 let key = aead::LessSafeKey::new(aead::UnboundKey::new(&aead::AES_128_GCM, &key).unwrap());
189
190 let (aad, tag) = pseudo_packet.split_at_mut(tag_start);
191 key.open_in_place(nonce, aead::Aad::from(aad), tag).is_ok()
192 }
193
194 fn export_keying_material(
195 &self,
196 output: &mut [u8],
197 label: &[u8],
198 context: &[u8],
199 ) -> Result<(), ExportKeyingMaterialError> {
200 self.inner
201 .export_keying_material(output, label, Some(context))
202 .map_err(|_| ExportKeyingMaterialError)?;
203 Ok(())
204 }
205}
206
// AES-128-GCM key/nonce for the Retry Integrity Tag of the draft versions that
// `interpret_version` maps to `Version::V1Draft`.
const RETRY_INTEGRITY_KEY_DRAFT: [u8; 16] =
    0xccce187e_d09a09d0_5728155a_6cb96be1_u128.to_be_bytes();
const RETRY_INTEGRITY_NONCE_DRAFT: [u8; 12] = [
    0xe5, 0x49, 0x30, 0xf9, //
    0x7f, 0x21, 0x36, 0xf0, //
    0x53, 0x0a, 0x8c, 0x1c,
];

// AES-128-GCM key/nonce for the Retry Integrity Tag of QUIC v1 (RFC 9001, Section 5.8).
const RETRY_INTEGRITY_KEY_V1: [u8; 16] =
    0xbe0c690b_9f66575a_1d766b54_e368c84e_u128.to_be_bytes();
const RETRY_INTEGRITY_NONCE_V1: [u8; 12] = [
    0x46, 0x15, 0x99, 0xd3, //
    0x5d, 0x63, 0x2b, 0xf2, //
    0x23, 0x98, 0x25, 0xbb,
];
220
221impl crypto::HeaderKey for Box<dyn HeaderProtectionKey> {
222 fn decrypt(&self, pn_offset: usize, packet: &mut [u8]) {
223 let (header, sample) = packet.split_at_mut(pn_offset + 4);
224 let (first, rest) = header.split_at_mut(1);
225 let pn_end = Ord::min(pn_offset + 3, rest.len());
226 self.decrypt_in_place(
227 &sample[..self.sample_size()],
228 &mut first[0],
229 &mut rest[pn_offset - 1..pn_end],
230 )
231 .unwrap();
232 }
233
234 fn encrypt(&self, pn_offset: usize, packet: &mut [u8]) {
235 let (header, sample) = packet.split_at_mut(pn_offset + 4);
236 let (first, rest) = header.split_at_mut(1);
237 let pn_end = Ord::min(pn_offset + 3, rest.len());
238 self.encrypt_in_place(
239 &sample[..self.sample_size()],
240 &mut first[0],
241 &mut rest[pn_offset - 1..pn_end],
242 )
243 .unwrap();
244 }
245
246 fn sample_size(&self) -> usize {
247 self.sample_len()
248 }
249}
250
/// Handshake information gathered from a rustls QUIC session.
///
/// Returned boxed as `dyn Any` from the session's `handshake_data`.
pub struct HandshakeData {
    /// The negotiated ALPN protocol, if any.
    pub protocol: Option<Vec<u8>>,
    /// The server name sent by the client. Only populated on the server side;
    /// always `None` for client sessions.
    pub server_name: Option<String>,
}
262
/// A QUIC-compatible TLS client configuration.
///
/// Pairs a rustls `ClientConfig` with the cipher suite used for QUIC Initial packets,
/// and optionally a simulated TLS-extension context applied to every new session.
pub struct QuicClientConfig {
    /// The wrapped rustls client configuration.
    pub(crate) inner: Arc<rustls::ClientConfig>,
    /// Cipher suite used to derive Initial-packet keys.
    initial: Suite,
    /// When set, sessions started from this config are wrapped in an
    /// `ExtensionAwareTlsSession` using this context.
    pub(crate) extension_context: Option<Arc<SimulatedExtensionContext>>,
}
287
288impl QuicClientConfig {
289 #[cfg(feature = "platform-verifier")]
290 pub(crate) fn with_platform_verifier() -> Result<Self, Error> {
291 let mut inner = rustls::ClientConfig::builder_with_provider(configured_provider())
293 .with_protocol_versions(&[&rustls::version::TLS13])
294 .unwrap() .with_platform_verifier()?
296 .with_no_client_auth();
297
298 inner.enable_early_data = true;
299 Ok(Self {
300 initial: initial_suite_from_provider(inner.crypto_provider())
302 .expect("no initial cipher suite found"),
303 inner: Arc::new(inner),
304 extension_context: None,
305 })
306 }
307
308 pub(crate) fn new(verifier: Arc<dyn ServerCertVerifier>) -> Self {
313 let inner = Self::inner(verifier);
314 Self {
315 initial: initial_suite_from_provider(inner.crypto_provider())
317 .expect("no initial cipher suite found"),
318 inner: Arc::new(inner),
319 extension_context: None,
320 }
321 }
322
323 pub fn with_initial(
327 inner: Arc<rustls::ClientConfig>,
328 initial: Suite,
329 ) -> Result<Self, NoInitialCipherSuite> {
330 match initial.suite.common.suite {
331 CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self { inner, initial, extension_context: None }),
332 _ => Err(NoInitialCipherSuite { specific: true }),
333 }
334 }
335
336 pub fn with_extension_context(mut self, context: Arc<SimulatedExtensionContext>) -> Self {
338 self.extension_context = Some(context);
339 self
340 }
341
342 pub(crate) fn inner(verifier: Arc<dyn ServerCertVerifier>) -> rustls::ClientConfig {
343 let mut config = rustls::ClientConfig::builder_with_provider(configured_provider())
345 .with_protocol_versions(&[&rustls::version::TLS13])
346 .unwrap() .dangerous()
348 .with_custom_certificate_verifier(verifier)
349 .with_no_client_auth();
350
351 config.enable_early_data = true;
352 config
353 }
354}
355
356impl crypto::ClientConfig for QuicClientConfig {
357 fn start_session(
358 self: Arc<Self>,
359 version: u32,
360 server_name: &str,
361 params: &TransportParameters,
362 ) -> Result<Box<dyn crypto::Session>, ConnectError> {
363 let version = interpret_version(version)?;
364 let inner_session = Box::new(TlsSession {
365 version,
366 got_handshake_data: false,
367 next_secrets: None,
368 inner: rustls::quic::Connection::Client(
369 rustls::quic::ClientConnection::new(
370 self.inner.clone(),
371 version,
372 ServerName::try_from(server_name)
373 .map_err(|_| ConnectError::InvalidServerName(server_name.into()))?
374 .to_owned(),
375 to_vec(params),
376 )
377 .unwrap(),
378 ),
379 suite: self.initial,
380 });
381
382 if let Some(extension_context) = &self.extension_context {
384 let conn_id = format!("client-{}-{}", server_name,
385 std::time::SystemTime::now()
386 .duration_since(std::time::UNIX_EPOCH)
387 .unwrap()
388 .as_nanos());
389 Ok(Box::new(ExtensionAwareTlsSession::new(
390 inner_session,
391 extension_context.clone() as Arc<dyn TlsExtensionHooks>,
392 conn_id,
393 true, )))
395 } else {
396 Ok(inner_session)
397 }
398 }
399}
400
401impl TryFrom<rustls::ClientConfig> for QuicClientConfig {
402 type Error = NoInitialCipherSuite;
403
404 fn try_from(inner: rustls::ClientConfig) -> Result<Self, Self::Error> {
405 Arc::new(inner).try_into()
406 }
407}
408
409impl TryFrom<Arc<rustls::ClientConfig>> for QuicClientConfig {
410 type Error = NoInitialCipherSuite;
411
412 fn try_from(inner: Arc<rustls::ClientConfig>) -> Result<Self, Self::Error> {
413 Ok(Self {
414 initial: initial_suite_from_provider(inner.crypto_provider())
415 .ok_or(NoInitialCipherSuite { specific: false })?,
416 inner,
417 extension_context: None,
418 })
419 }
420}
421
/// Error returned when a rustls configuration cannot supply the cipher suite QUIC uses
/// for Initial packets (`TLS13_AES_128_GCM_SHA256`).
#[derive(Clone, Debug)]
pub struct NoInitialCipherSuite {
    /// `true` when an explicitly supplied suite was rejected, `false` when the crypto
    /// provider offered no usable suite at all.
    specific: bool,
}

impl std::fmt::Display for NoInitialCipherSuite {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = if self.specific {
            "invalid cipher suite specified"
        } else {
            "no initial cipher suite found"
        };
        f.write_str(message)
    }
}

impl std::error::Error for NoInitialCipherSuite {}
445
/// A QUIC-compatible TLS server configuration.
///
/// Pairs a rustls `ServerConfig` with the cipher suite used for QUIC Initial packets,
/// and optionally a simulated TLS-extension context applied to every new session.
pub struct QuicServerConfig {
    /// The wrapped rustls server configuration.
    inner: Arc<rustls::ServerConfig>,
    /// Cipher suite used to derive Initial-packet keys.
    initial: Suite,
    /// When set, sessions started from this config are wrapped in an
    /// `ExtensionAwareTlsSession` using this context.
    pub(crate) extension_context: Option<Arc<SimulatedExtensionContext>>,
}
464
465impl QuicServerConfig {
466 pub(crate) fn new(
467 cert_chain: Vec<CertificateDer<'static>>,
468 key: PrivateKeyDer<'static>,
469 ) -> Result<Self, rustls::Error> {
470 let inner = Self::inner(cert_chain, key)?;
471 Ok(Self {
472 initial: initial_suite_from_provider(inner.crypto_provider())
474 .expect("no initial cipher suite found"),
475 inner: Arc::new(inner),
476 extension_context: None,
477 })
478 }
479
480 pub fn with_extension_context(mut self, context: Arc<SimulatedExtensionContext>) -> Self {
482 self.extension_context = Some(context);
483 self
484 }
485
486 pub fn with_initial(
490 inner: Arc<rustls::ServerConfig>,
491 initial: Suite,
492 ) -> Result<Self, NoInitialCipherSuite> {
493 match initial.suite.common.suite {
494 CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self { inner, initial, extension_context: None }),
495 _ => Err(NoInitialCipherSuite { specific: true }),
496 }
497 }
498
499 pub(crate) fn inner(
505 cert_chain: Vec<CertificateDer<'static>>,
506 key: PrivateKeyDer<'static>,
507 ) -> Result<rustls::ServerConfig, rustls::Error> {
508 let mut inner = rustls::ServerConfig::builder_with_provider(configured_provider())
509 .with_protocol_versions(&[&rustls::version::TLS13])
510 .unwrap() .with_no_client_auth()
512 .with_single_cert(cert_chain, key)?;
513
514 inner.max_early_data_size = u32::MAX;
515 Ok(inner)
516 }
517}
518
519impl TryFrom<rustls::ServerConfig> for QuicServerConfig {
520 type Error = NoInitialCipherSuite;
521
522 fn try_from(inner: rustls::ServerConfig) -> Result<Self, Self::Error> {
523 Arc::new(inner).try_into()
524 }
525}
526
527impl TryFrom<Arc<rustls::ServerConfig>> for QuicServerConfig {
528 type Error = NoInitialCipherSuite;
529
530 fn try_from(inner: Arc<rustls::ServerConfig>) -> Result<Self, Self::Error> {
531 Ok(Self {
532 initial: initial_suite_from_provider(inner.crypto_provider())
533 .ok_or(NoInitialCipherSuite { specific: false })?,
534 inner,
535 extension_context: None,
536 })
537 }
538}
539
540impl crypto::ServerConfig for QuicServerConfig {
541 fn start_session(
542 self: Arc<Self>,
543 version: u32,
544 params: &TransportParameters,
545 ) -> Box<dyn crypto::Session> {
546 let version = interpret_version(version).unwrap();
548 let inner_session = Box::new(TlsSession {
549 version,
550 got_handshake_data: false,
551 next_secrets: None,
552 inner: rustls::quic::Connection::Server(
553 rustls::quic::ServerConnection::new(self.inner.clone(), version, to_vec(params))
554 .unwrap(),
555 ),
556 suite: self.initial,
557 });
558
559 if let Some(extension_context) = &self.extension_context {
561 let conn_id = format!("server-{}",
562 std::time::SystemTime::now()
563 .duration_since(std::time::UNIX_EPOCH)
564 .unwrap()
565 .as_nanos());
566 Box::new(ExtensionAwareTlsSession::new(
567 inner_session,
568 extension_context.clone() as Arc<dyn TlsExtensionHooks>,
569 conn_id,
570 false, ))
572 } else {
573 inner_session
574 }
575 }
576
577 fn initial_keys(
578 &self,
579 version: u32,
580 dst_cid: &ConnectionId,
581 ) -> Result<Keys, UnsupportedVersion> {
582 let version = interpret_version(version)?;
583 Ok(initial_keys(version, *dst_cid, Side::Server, &self.initial))
584 }
585
586 fn retry_tag(&self, version: u32, orig_dst_cid: &ConnectionId, packet: &[u8]) -> [u8; 16] {
587 let version = interpret_version(version).unwrap();
589 let (nonce, key) = match version {
590 Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1),
591 Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT),
592 _ => unreachable!(),
593 };
594
595 let mut pseudo_packet = Vec::with_capacity(packet.len() + orig_dst_cid.len() + 1);
596 pseudo_packet.push(orig_dst_cid.len() as u8);
597 pseudo_packet.extend_from_slice(orig_dst_cid);
598 pseudo_packet.extend_from_slice(packet);
599
600 let nonce = aead::Nonce::assume_unique_for_key(nonce);
601 let key = aead::LessSafeKey::new(aead::UnboundKey::new(&aead::AES_128_GCM, &key).unwrap());
602
603 let tag = key
604 .seal_in_place_separate_tag(nonce, aead::Aad::from(pseudo_packet), &mut [])
605 .unwrap();
606 let mut result = [0; 16];
607 result.copy_from_slice(tag.as_ref());
608 result
609 }
610}
611
612pub(crate) fn initial_suite_from_provider(
613 provider: &Arc<rustls::crypto::CryptoProvider>,
614) -> Option<Suite> {
615 provider
616 .cipher_suites
617 .iter()
618 .find_map(|cs| match (cs.suite(), cs.tls13()) {
619 (rustls::CipherSuite::TLS13_AES_128_GCM_SHA256, Some(suite)) => {
620 Some(suite.quic_suite())
621 }
622 _ => None,
623 })
624 .flatten()
625}
626
627pub(crate) fn configured_provider() -> Arc<rustls::crypto::CryptoProvider> {
628 #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
629 let provider = rustls::crypto::aws_lc_rs::default_provider();
630 #[cfg(feature = "rustls-ring")]
631 let provider = rustls::crypto::ring::default_provider();
632 Arc::new(provider)
633}
634
635fn to_vec(params: &TransportParameters) -> Vec<u8> {
636 let mut bytes = Vec::new();
637 params.write(&mut bytes);
638 bytes
639}
640
641pub(crate) fn initial_keys(
642 version: Version,
643 dst_cid: ConnectionId,
644 side: Side,
645 suite: &Suite,
646) -> Keys {
647 let keys = suite.keys(&dst_cid, side.into(), version);
648 Keys {
649 header: KeyPair {
650 local: Box::new(keys.local.header),
651 remote: Box::new(keys.remote.header),
652 },
653 packet: KeyPair {
654 local: Box::new(keys.local.packet),
655 remote: Box::new(keys.remote.packet),
656 },
657 }
658}
659
660impl crypto::PacketKey for Box<dyn PacketKey> {
661 fn encrypt(&self, packet: u64, buf: &mut [u8], header_len: usize) {
662 let (header, payload_tag) = buf.split_at_mut(header_len);
663 let (payload, tag_storage) = payload_tag.split_at_mut(payload_tag.len() - self.tag_len());
664 let tag = self.encrypt_in_place(packet, &*header, payload).unwrap();
665 tag_storage.copy_from_slice(tag.as_ref());
666 }
667
668 fn decrypt(
669 &self,
670 packet: u64,
671 header: &[u8],
672 payload: &mut BytesMut,
673 ) -> Result<(), CryptoError> {
674 let plain = self
675 .decrypt_in_place(packet, header, payload.as_mut())
676 .map_err(|_| CryptoError)?;
677 let plain_len = plain.len();
678 payload.truncate(plain_len);
679 Ok(())
680 }
681
682 fn tag_len(&self) -> usize {
683 (**self).tag_len()
684 }
685
686 fn confidentiality_limit(&self) -> u64 {
687 (**self).confidentiality_limit()
688 }
689
690 fn integrity_limit(&self) -> u64 {
691 (**self).integrity_limit()
692 }
693}
694
695fn interpret_version(version: u32) -> Result<Version, UnsupportedVersion> {
696 match version {
697 0xff00_001d..=0xff00_0020 => Ok(Version::V1Draft),
698 0x0000_0001 | 0xff00_0021..=0xff00_0022 => Ok(Version::V1),
699 _ => Err(UnsupportedVersion),
700 }
701}