1use std::{any::Any, io, str, sync::Arc};
2
3#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
4use aws_lc_rs::aead;
5use bytes::BytesMut;
6#[cfg(feature = "ring")]
7use ring::aead;
8pub use rustls::Error;
9use rustls::{
10 self, CipherSuite,
11 client::danger::ServerCertVerifier,
12 pki_types::{CertificateDer, PrivateKeyDer, ServerName},
13 quic::{Connection, HeaderProtectionKey, KeyChange, PacketKey, Secrets, Suite, Version},
14};
15#[cfg(feature = "platform-verifier")]
16use rustls_platform_verifier::BuilderVerifierExt;
17
18use crate::{
19 ConnectError, ConnectionId, Side, TransportError, TransportErrorCode,
20 crypto::{
21 self, CryptoError, ExportKeyingMaterialError, HeaderKey, KeyPair, Keys, UnsupportedVersion,
22 tls_extension_simulation::{
23 ExtensionAwareTlsSession, SimulatedExtensionContext, TlsExtensionHooks,
24 },
25 },
26 transport_parameters::TransportParameters,
27};
28
29impl From<Side> for rustls::Side {
30 fn from(s: Side) -> Self {
31 match s {
32 Side::Client => Self::Client,
33 Side::Server => Self::Server,
34 }
35 }
36}
37
38pub struct TlsSession {
40 version: Version,
41 got_handshake_data: bool,
42 next_secrets: Option<Secrets>,
43 inner: Connection,
44 suite: Suite,
45}
46
47impl TlsSession {
48 fn side(&self) -> Side {
49 match self.inner {
50 Connection::Client(_) => Side::Client,
51 Connection::Server(_) => Side::Server,
52 }
53 }
54}
55
56impl crypto::Session for TlsSession {
57 fn initial_keys(&self, dst_cid: &ConnectionId, side: Side) -> Keys {
58 initial_keys(self.version, *dst_cid, side, &self.suite)
59 }
60
61 fn handshake_data(&self) -> Option<Box<dyn Any>> {
62 if !self.got_handshake_data {
63 return None;
64 }
65 Some(Box::new(HandshakeData {
66 protocol: self.inner.alpn_protocol().map(|x| x.into()),
67 server_name: match self.inner {
68 Connection::Client(_) => None,
69 Connection::Server(ref session) => session.server_name().map(|x| x.into()),
70 },
71 }))
72 }
73
74 fn peer_identity(&self) -> Option<Box<dyn Any>> {
76 self.inner.peer_certificates().map(|v| -> Box<dyn Any> {
77 Box::new(
78 v.iter()
79 .map(|v| v.clone().into_owned())
80 .collect::<Vec<CertificateDer<'static>>>(),
81 )
82 })
83 }
84
85 fn early_crypto(&self) -> Option<(Box<dyn HeaderKey>, Box<dyn crypto::PacketKey>)> {
86 let keys = self.inner.zero_rtt_keys()?;
87 Some((Box::new(keys.header), Box::new(keys.packet)))
88 }
89
90 fn early_data_accepted(&self) -> Option<bool> {
91 match self.inner {
92 Connection::Client(ref session) => Some(session.is_early_data_accepted()),
93 _ => None,
94 }
95 }
96
97 fn is_handshaking(&self) -> bool {
98 self.inner.is_handshaking()
99 }
100
101 fn read_handshake(&mut self, buf: &[u8]) -> Result<bool, TransportError> {
102 self.inner.read_hs(buf).map_err(|e| {
103 if let Some(alert) = self.inner.alert() {
104 TransportError {
105 code: TransportErrorCode::crypto(alert.into()),
106 frame: None,
107 reason: e.to_string(),
108 }
109 } else {
110 TransportError::PROTOCOL_VIOLATION(format!("TLS error: {e}"))
111 }
112 })?;
113 if !self.got_handshake_data {
114 let have_server_name = match self.inner {
118 Connection::Client(_) => false,
119 Connection::Server(ref session) => session.server_name().is_some(),
120 };
121 if self.inner.alpn_protocol().is_some() || have_server_name || !self.is_handshaking() {
122 self.got_handshake_data = true;
123 return Ok(true);
124 }
125 }
126 Ok(false)
127 }
128
129 fn transport_parameters(&self) -> Result<Option<TransportParameters>, TransportError> {
130 match self.inner.quic_transport_parameters() {
131 None => Ok(None),
132 Some(buf) => match TransportParameters::read(self.side(), &mut io::Cursor::new(buf)) {
133 Ok(params) => Ok(Some(params)),
134 Err(e) => Err(e.into()),
135 },
136 }
137 }
138
139 fn write_handshake(&mut self, buf: &mut Vec<u8>) -> Option<Keys> {
140 let keys = match self.inner.write_hs(buf)? {
141 KeyChange::Handshake { keys } => keys,
142 KeyChange::OneRtt { keys, next } => {
143 self.next_secrets = Some(next);
144 keys
145 }
146 };
147
148 Some(Keys {
149 header: KeyPair {
150 local: Box::new(keys.local.header),
151 remote: Box::new(keys.remote.header),
152 },
153 packet: KeyPair {
154 local: Box::new(keys.local.packet),
155 remote: Box::new(keys.remote.packet),
156 },
157 })
158 }
159
160 fn next_1rtt_keys(&mut self) -> Option<KeyPair<Box<dyn crypto::PacketKey>>> {
161 let secrets = self.next_secrets.as_mut()?;
162 let keys = secrets.next_packet_keys();
163 Some(KeyPair {
164 local: Box::new(keys.local),
165 remote: Box::new(keys.remote),
166 })
167 }
168
169 fn is_valid_retry(&self, orig_dst_cid: &ConnectionId, header: &[u8], payload: &[u8]) -> bool {
170 let tag_start = match payload.len().checked_sub(16) {
171 Some(x) => x,
172 None => return false,
173 };
174
175 let mut pseudo_packet =
176 Vec::with_capacity(header.len() + payload.len() + orig_dst_cid.len() + 1);
177 pseudo_packet.push(orig_dst_cid.len() as u8);
178 pseudo_packet.extend_from_slice(orig_dst_cid);
179 pseudo_packet.extend_from_slice(header);
180 let tag_start = tag_start + pseudo_packet.len();
181 pseudo_packet.extend_from_slice(payload);
182
183 let (nonce, key) = match self.version {
184 Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1),
185 Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT),
186 _ => unreachable!(),
187 };
188
189 let nonce = aead::Nonce::assume_unique_for_key(nonce);
190 let key = aead::LessSafeKey::new(aead::UnboundKey::new(&aead::AES_128_GCM, &key).unwrap());
191
192 let (aad, tag) = pseudo_packet.split_at_mut(tag_start);
193 key.open_in_place(nonce, aead::Aad::from(aad), tag).is_ok()
194 }
195
196 fn export_keying_material(
197 &self,
198 output: &mut [u8],
199 label: &[u8],
200 context: &[u8],
201 ) -> Result<(), ExportKeyingMaterialError> {
202 self.inner
203 .export_keying_material(output, label, Some(context))
204 .map_err(|_| ExportKeyingMaterialError)?;
205 Ok(())
206 }
207}
208
// Fixed AES-128-GCM key and nonce used for the Retry integrity tag when the session version is
// `Version::V1Draft` (selected in `is_valid_retry` and `retry_tag`).
// NOTE(review): believed to be the draft-29 QUIC-TLS retry constants — verify against the draft.
const RETRY_INTEGRITY_KEY_DRAFT: [u8; 16] = [
    0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1,
];
const RETRY_INTEGRITY_NONCE_DRAFT: [u8; 12] = [
    0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c,
];

// Fixed AES-128-GCM key and nonce used for the Retry integrity tag when the session version is
// `Version::V1` (selected in `is_valid_retry` and `retry_tag`).
// NOTE(review): believed to match RFC 9001 §5.8 — verify before changing.
const RETRY_INTEGRITY_KEY_V1: [u8; 16] = [
    0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e,
];
const RETRY_INTEGRITY_NONCE_V1: [u8; 12] = [
    0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb,
];
222
223impl crypto::HeaderKey for Box<dyn HeaderProtectionKey> {
224 fn decrypt(&self, pn_offset: usize, packet: &mut [u8]) {
225 let (header, sample) = packet.split_at_mut(pn_offset + 4);
226 let (first, rest) = header.split_at_mut(1);
227 let pn_end = Ord::min(pn_offset + 3, rest.len());
228 self.decrypt_in_place(
229 &sample[..self.sample_size()],
230 &mut first[0],
231 &mut rest[pn_offset - 1..pn_end],
232 )
233 .unwrap();
234 }
235
236 fn encrypt(&self, pn_offset: usize, packet: &mut [u8]) {
237 let (header, sample) = packet.split_at_mut(pn_offset + 4);
238 let (first, rest) = header.split_at_mut(1);
239 let pn_end = Ord::min(pn_offset + 3, rest.len());
240 self.encrypt_in_place(
241 &sample[..self.sample_size()],
242 &mut first[0],
243 &mut rest[pn_offset - 1..pn_end],
244 )
245 .unwrap();
246 }
247
248 fn sample_size(&self) -> usize {
249 self.sample_len()
250 }
251}
252
/// Handshake information exposed by a rustls-backed TLS session.
///
/// The session's `handshake_data` returns this boxed as `dyn Any`; downcast to recover it.
/// `Debug`, `Clone`, `PartialEq` and `Eq` are derived so callers can log, copy and compare it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandshakeData {
    /// The application protocol negotiated via ALPN, if any.
    pub protocol: Option<Vec<u8>>,
    /// The server name the client requested, as reported by rustls on the server side.
    ///
    /// Always `None` for client-side (outgoing) sessions.
    pub server_name: Option<String>,
}
264
265pub struct QuicClientConfig {
284 pub(crate) inner: Arc<rustls::ClientConfig>,
285 initial: Suite,
286 pub(crate) extension_context: Option<Arc<SimulatedExtensionContext>>,
288}
289
290impl QuicClientConfig {
291 #[cfg(feature = "platform-verifier")]
292 pub(crate) fn with_platform_verifier() -> Result<Self, Error> {
293 let mut inner = rustls::ClientConfig::builder_with_provider(configured_provider())
295 .with_protocol_versions(&[&rustls::version::TLS13])
296 .unwrap() .with_platform_verifier()?
298 .with_no_client_auth();
299
300 inner.enable_early_data = true;
301 Ok(Self {
302 initial: initial_suite_from_provider(inner.crypto_provider())
304 .expect("no initial cipher suite found"),
305 inner: Arc::new(inner),
306 extension_context: None,
307 })
308 }
309
310 pub(crate) fn new(verifier: Arc<dyn ServerCertVerifier>) -> Self {
315 let inner = Self::inner(verifier);
316 Self {
317 initial: initial_suite_from_provider(inner.crypto_provider())
319 .expect("no initial cipher suite found"),
320 inner: Arc::new(inner),
321 extension_context: None,
322 }
323 }
324
325 pub fn with_initial(
329 inner: Arc<rustls::ClientConfig>,
330 initial: Suite,
331 ) -> Result<Self, NoInitialCipherSuite> {
332 match initial.suite.common.suite {
333 CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self {
334 inner,
335 initial,
336 extension_context: None,
337 }),
338 _ => Err(NoInitialCipherSuite { specific: true }),
339 }
340 }
341
342 pub fn with_extension_context(mut self, context: Arc<SimulatedExtensionContext>) -> Self {
344 self.extension_context = Some(context);
345 self
346 }
347
348 pub(crate) fn inner(verifier: Arc<dyn ServerCertVerifier>) -> rustls::ClientConfig {
349 let mut config = rustls::ClientConfig::builder_with_provider(configured_provider())
351 .with_protocol_versions(&[&rustls::version::TLS13])
352 .unwrap() .dangerous()
354 .with_custom_certificate_verifier(verifier)
355 .with_no_client_auth();
356
357 config.enable_early_data = true;
358 config
359 }
360}
361
362impl crypto::ClientConfig for QuicClientConfig {
363 fn start_session(
364 self: Arc<Self>,
365 version: u32,
366 server_name: &str,
367 params: &TransportParameters,
368 ) -> Result<Box<dyn crypto::Session>, ConnectError> {
369 let version = interpret_version(version)?;
370 let inner_session = Box::new(TlsSession {
371 version,
372 got_handshake_data: false,
373 next_secrets: None,
374 inner: rustls::quic::Connection::Client(
375 rustls::quic::ClientConnection::new(
376 self.inner.clone(),
377 version,
378 ServerName::try_from(server_name)
379 .map_err(|_| ConnectError::InvalidServerName(server_name.into()))?
380 .to_owned(),
381 to_vec(params),
382 )
383 .unwrap(),
384 ),
385 suite: self.initial,
386 });
387
388 if let Some(extension_context) = &self.extension_context {
390 let conn_id = format!(
391 "client-{}-{}",
392 server_name,
393 std::time::SystemTime::now()
394 .duration_since(std::time::UNIX_EPOCH)
395 .unwrap()
396 .as_nanos()
397 );
398 Ok(Box::new(ExtensionAwareTlsSession::new(
399 inner_session,
400 extension_context.clone() as Arc<dyn TlsExtensionHooks>,
401 conn_id,
402 true, )))
404 } else {
405 Ok(inner_session)
406 }
407 }
408}
409
410impl TryFrom<rustls::ClientConfig> for QuicClientConfig {
411 type Error = NoInitialCipherSuite;
412
413 fn try_from(inner: rustls::ClientConfig) -> Result<Self, Self::Error> {
414 Arc::new(inner).try_into()
415 }
416}
417
418impl TryFrom<Arc<rustls::ClientConfig>> for QuicClientConfig {
419 type Error = NoInitialCipherSuite;
420
421 fn try_from(inner: Arc<rustls::ClientConfig>) -> Result<Self, Self::Error> {
422 Ok(Self {
423 initial: initial_suite_from_provider(inner.crypto_provider())
424 .ok_or(NoInitialCipherSuite { specific: false })?,
425 inner,
426 extension_context: None,
427 })
428 }
429}
430
/// Error returned when a rustls configuration cannot supply the cipher suite QUIC needs for
/// Initial packets.
#[derive(Clone, Debug)]
pub struct NoInitialCipherSuite {
    /// `true` when a caller-specified initial suite was rejected; `false` when the crypto
    /// provider offered no suitable suite.
    specific: bool,
}

impl std::fmt::Display for NoInitialCipherSuite {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = if self.specific {
            "invalid cipher suite specified"
        } else {
            "no initial cipher suite found"
        };
        f.write_str(message)
    }
}

impl std::error::Error for NoInitialCipherSuite {}
454
455pub struct QuicServerConfig {
468 inner: Arc<rustls::ServerConfig>,
469 initial: Suite,
470 pub(crate) extension_context: Option<Arc<SimulatedExtensionContext>>,
472}
473
474impl QuicServerConfig {
475 pub(crate) fn new(
476 cert_chain: Vec<CertificateDer<'static>>,
477 key: PrivateKeyDer<'static>,
478 ) -> Result<Self, rustls::Error> {
479 let inner = Self::inner(cert_chain, key)?;
480 Ok(Self {
481 initial: initial_suite_from_provider(inner.crypto_provider())
483 .expect("no initial cipher suite found"),
484 inner: Arc::new(inner),
485 extension_context: None,
486 })
487 }
488
489 pub fn with_extension_context(mut self, context: Arc<SimulatedExtensionContext>) -> Self {
491 self.extension_context = Some(context);
492 self
493 }
494
495 pub fn with_initial(
499 inner: Arc<rustls::ServerConfig>,
500 initial: Suite,
501 ) -> Result<Self, NoInitialCipherSuite> {
502 match initial.suite.common.suite {
503 CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self {
504 inner,
505 initial,
506 extension_context: None,
507 }),
508 _ => Err(NoInitialCipherSuite { specific: true }),
509 }
510 }
511
512 pub(crate) fn inner(
518 cert_chain: Vec<CertificateDer<'static>>,
519 key: PrivateKeyDer<'static>,
520 ) -> Result<rustls::ServerConfig, rustls::Error> {
521 let mut inner = rustls::ServerConfig::builder_with_provider(configured_provider())
522 .with_protocol_versions(&[&rustls::version::TLS13])
523 .unwrap() .with_no_client_auth()
525 .with_single_cert(cert_chain, key)?;
526
527 inner.max_early_data_size = u32::MAX;
528 Ok(inner)
529 }
530}
531
532impl TryFrom<rustls::ServerConfig> for QuicServerConfig {
533 type Error = NoInitialCipherSuite;
534
535 fn try_from(inner: rustls::ServerConfig) -> Result<Self, Self::Error> {
536 Arc::new(inner).try_into()
537 }
538}
539
540impl TryFrom<Arc<rustls::ServerConfig>> for QuicServerConfig {
541 type Error = NoInitialCipherSuite;
542
543 fn try_from(inner: Arc<rustls::ServerConfig>) -> Result<Self, Self::Error> {
544 Ok(Self {
545 initial: initial_suite_from_provider(inner.crypto_provider())
546 .ok_or(NoInitialCipherSuite { specific: false })?,
547 inner,
548 extension_context: None,
549 })
550 }
551}
552
553impl crypto::ServerConfig for QuicServerConfig {
554 fn start_session(
555 self: Arc<Self>,
556 version: u32,
557 params: &TransportParameters,
558 ) -> Box<dyn crypto::Session> {
559 let version = interpret_version(version).unwrap();
561 let inner_session = Box::new(TlsSession {
562 version,
563 got_handshake_data: false,
564 next_secrets: None,
565 inner: rustls::quic::Connection::Server(
566 rustls::quic::ServerConnection::new(self.inner.clone(), version, to_vec(params))
567 .unwrap(),
568 ),
569 suite: self.initial,
570 });
571
572 if let Some(extension_context) = &self.extension_context {
574 let conn_id = format!(
575 "server-{}",
576 std::time::SystemTime::now()
577 .duration_since(std::time::UNIX_EPOCH)
578 .unwrap()
579 .as_nanos()
580 );
581 Box::new(ExtensionAwareTlsSession::new(
582 inner_session,
583 extension_context.clone() as Arc<dyn TlsExtensionHooks>,
584 conn_id,
585 false, ))
587 } else {
588 inner_session
589 }
590 }
591
592 fn initial_keys(
593 &self,
594 version: u32,
595 dst_cid: &ConnectionId,
596 ) -> Result<Keys, UnsupportedVersion> {
597 let version = interpret_version(version)?;
598 Ok(initial_keys(version, *dst_cid, Side::Server, &self.initial))
599 }
600
601 fn retry_tag(&self, version: u32, orig_dst_cid: &ConnectionId, packet: &[u8]) -> [u8; 16] {
602 let version = interpret_version(version).unwrap();
604 let (nonce, key) = match version {
605 Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1),
606 Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT),
607 _ => unreachable!(),
608 };
609
610 let mut pseudo_packet = Vec::with_capacity(packet.len() + orig_dst_cid.len() + 1);
611 pseudo_packet.push(orig_dst_cid.len() as u8);
612 pseudo_packet.extend_from_slice(orig_dst_cid);
613 pseudo_packet.extend_from_slice(packet);
614
615 let nonce = aead::Nonce::assume_unique_for_key(nonce);
616 let key = aead::LessSafeKey::new(aead::UnboundKey::new(&aead::AES_128_GCM, &key).unwrap());
617
618 let tag = key
619 .seal_in_place_separate_tag(nonce, aead::Aad::from(pseudo_packet), &mut [])
620 .unwrap();
621 let mut result = [0; 16];
622 result.copy_from_slice(tag.as_ref());
623 result
624 }
625}
626
627pub(crate) fn initial_suite_from_provider(
628 provider: &Arc<rustls::crypto::CryptoProvider>,
629) -> Option<Suite> {
630 provider
631 .cipher_suites
632 .iter()
633 .find_map(|cs| match (cs.suite(), cs.tls13()) {
634 (rustls::CipherSuite::TLS13_AES_128_GCM_SHA256, Some(suite)) => {
635 Some(suite.quic_suite())
636 }
637 _ => None,
638 })
639 .flatten()
640}
641
642pub(crate) fn configured_provider() -> Arc<rustls::crypto::CryptoProvider> {
643 #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
644 let provider = rustls::crypto::aws_lc_rs::default_provider();
645 #[cfg(feature = "rustls-ring")]
646 let provider = rustls::crypto::ring::default_provider();
647 Arc::new(provider)
648}
649
650fn to_vec(params: &TransportParameters) -> Vec<u8> {
651 let mut bytes = Vec::new();
652 params.write(&mut bytes);
653 bytes
654}
655
656pub(crate) fn initial_keys(
657 version: Version,
658 dst_cid: ConnectionId,
659 side: Side,
660 suite: &Suite,
661) -> Keys {
662 let keys = suite.keys(&dst_cid, side.into(), version);
663 Keys {
664 header: KeyPair {
665 local: Box::new(keys.local.header),
666 remote: Box::new(keys.remote.header),
667 },
668 packet: KeyPair {
669 local: Box::new(keys.local.packet),
670 remote: Box::new(keys.remote.packet),
671 },
672 }
673}
674
675impl crypto::PacketKey for Box<dyn PacketKey> {
676 fn encrypt(&self, packet: u64, buf: &mut [u8], header_len: usize) {
677 let (header, payload_tag) = buf.split_at_mut(header_len);
678 let (payload, tag_storage) = payload_tag.split_at_mut(payload_tag.len() - self.tag_len());
679 let tag = self.encrypt_in_place(packet, &*header, payload).unwrap();
680 tag_storage.copy_from_slice(tag.as_ref());
681 }
682
683 fn decrypt(
684 &self,
685 packet: u64,
686 header: &[u8],
687 payload: &mut BytesMut,
688 ) -> Result<(), CryptoError> {
689 let plain = self
690 .decrypt_in_place(packet, header, payload.as_mut())
691 .map_err(|_| CryptoError)?;
692 let plain_len = plain.len();
693 payload.truncate(plain_len);
694 Ok(())
695 }
696
697 fn tag_len(&self) -> usize {
698 (**self).tag_len()
699 }
700
701 fn confidentiality_limit(&self) -> u64 {
702 (**self).confidentiality_limit()
703 }
704
705 fn integrity_limit(&self) -> u64 {
706 (**self).integrity_limit()
707 }
708}
709
710fn interpret_version(version: u32) -> Result<Version, UnsupportedVersion> {
711 match version {
712 0xff00_001d..=0xff00_0020 => Ok(Version::V1Draft),
713 0x0000_0001 | 0xff00_0021..=0xff00_0022 => Ok(Version::V1),
714 _ => Err(UnsupportedVersion),
715 }
716}