ant_quic/crypto/
rustls.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7
8use std::{any::Any, io, str, sync::Arc};
9
10#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
11use aws_lc_rs::aead;
12use bytes::BytesMut;
13#[cfg(feature = "ring")]
14use ring::aead;
15pub use rustls::Error;
16use rustls::{
17    self, CipherSuite,
18    client::danger::ServerCertVerifier,
19    pki_types::{CertificateDer, PrivateKeyDer, ServerName},
20    quic::{Connection, HeaderProtectionKey, KeyChange, PacketKey, Secrets, Suite, Version},
21};
22#[cfg(feature = "platform-verifier")]
23use rustls_platform_verifier::BuilderVerifierExt;
24
25use std::sync::atomic::{AtomicBool, Ordering};
26
27/// Internal debug flag indicating whether the build/runtime is configured
28/// to prefer ML‑KEM‑only key exchange groups. This is a diagnostic aid used
29/// by tests; it does not by itself enforce KEM selection.
30static DEBUG_KEM_ONLY: AtomicBool = AtomicBool::new(false);
31
32use crate::{
33    ConnectError, ConnectionId, Side, TransportError, TransportErrorCode,
34    crypto::{
35        self, CryptoError, ExportKeyingMaterialError, HeaderKey, KeyPair, Keys, UnsupportedVersion,
36        tls_extension_simulation::{
37            ExtensionAwareTlsSession, SimulatedExtensionContext, TlsExtensionHooks,
38        },
39    },
40    transport_parameters::TransportParameters,
41};
42
43impl From<Side> for rustls::Side {
44    fn from(s: Side) -> Self {
45        match s {
46            Side::Client => Self::Client,
47            Side::Server => Self::Server,
48        }
49    }
50}
51
52/// A rustls TLS session
53pub struct TlsSession {
54    version: Version,
55    got_handshake_data: bool,
56    next_secrets: Option<Secrets>,
57    inner: Connection,
58    suite: Suite,
59}
60
61impl TlsSession {
62    fn side(&self) -> Side {
63        match self.inner {
64            Connection::Client(_) => Side::Client,
65            Connection::Server(_) => Side::Server,
66        }
67    }
68}
69
70impl crypto::Session for TlsSession {
71    fn initial_keys(&self, dst_cid: &ConnectionId, side: Side) -> Keys {
72        initial_keys(self.version, *dst_cid, side, &self.suite)
73    }
74
75    fn handshake_data(&self) -> Option<Box<dyn Any>> {
76        if !self.got_handshake_data {
77            return None;
78        }
79        Some(Box::new(HandshakeData {
80            protocol: self.inner.alpn_protocol().map(|x| x.into()),
81            server_name: match self.inner {
82                Connection::Client(_) => None,
83                Connection::Server(ref session) => session.server_name().map(|x| x.into()),
84            },
85        }))
86    }
87
88    /// For the rustls `TlsSession`, the `Any` type is `Vec<rustls::pki_types::CertificateDer>`
89    fn peer_identity(&self) -> Option<Box<dyn Any>> {
90        self.inner.peer_certificates().map(|v| -> Box<dyn Any> {
91            Box::new(
92                v.iter()
93                    .map(|v| v.clone().into_owned())
94                    .collect::<Vec<CertificateDer<'static>>>(),
95            )
96        })
97    }
98
99    fn early_crypto(&self) -> Option<(Box<dyn HeaderKey>, Box<dyn crypto::PacketKey>)> {
100        let keys = self.inner.zero_rtt_keys()?;
101        Some((Box::new(keys.header), Box::new(keys.packet)))
102    }
103
104    fn early_data_accepted(&self) -> Option<bool> {
105        match self.inner {
106            Connection::Client(ref session) => Some(session.is_early_data_accepted()),
107            _ => None,
108        }
109    }
110
111    fn is_handshaking(&self) -> bool {
112        self.inner.is_handshaking()
113    }
114
115    fn read_handshake(&mut self, buf: &[u8]) -> Result<bool, TransportError> {
116        self.inner.read_hs(buf).map_err(|e| {
117            if let Some(alert) = self.inner.alert() {
118                TransportError {
119                    code: TransportErrorCode::crypto(alert.into()),
120                    frame: None,
121                    reason: e.to_string(),
122                }
123            } else {
124                TransportError::PROTOCOL_VIOLATION(format!("TLS error: {e}"))
125            }
126        })?;
127        if !self.got_handshake_data {
128            // Hack around the lack of an explicit signal from rustls to reflect ClientHello being
129            // ready on incoming connections, or ALPN negotiation completing on outgoing
130            // connections.
131            let have_server_name = match self.inner {
132                Connection::Client(_) => false,
133                Connection::Server(ref session) => session.server_name().is_some(),
134            };
135            if self.inner.alpn_protocol().is_some() || have_server_name || !self.is_handshaking() {
136                self.got_handshake_data = true;
137                return Ok(true);
138            }
139        }
140        Ok(false)
141    }
142
143    fn transport_parameters(&self) -> Result<Option<TransportParameters>, TransportError> {
144        match self.inner.quic_transport_parameters() {
145            None => Ok(None),
146            Some(buf) => match TransportParameters::read(self.side(), &mut io::Cursor::new(buf)) {
147                Ok(params) => Ok(Some(params)),
148                Err(e) => Err(e.into()),
149            },
150        }
151    }
152
153    fn write_handshake(&mut self, buf: &mut Vec<u8>) -> Option<Keys> {
154        let keys = match self.inner.write_hs(buf)? {
155            KeyChange::Handshake { keys } => keys,
156            KeyChange::OneRtt { keys, next } => {
157                self.next_secrets = Some(next);
158                keys
159            }
160        };
161
162        Some(Keys {
163            header: KeyPair {
164                local: Box::new(keys.local.header),
165                remote: Box::new(keys.remote.header),
166            },
167            packet: KeyPair {
168                local: Box::new(keys.local.packet),
169                remote: Box::new(keys.remote.packet),
170            },
171        })
172    }
173
174    fn next_1rtt_keys(&mut self) -> Option<KeyPair<Box<dyn crypto::PacketKey>>> {
175        let secrets = self.next_secrets.as_mut()?;
176        let keys = secrets.next_packet_keys();
177        Some(KeyPair {
178            local: Box::new(keys.local),
179            remote: Box::new(keys.remote),
180        })
181    }
182
183    fn is_valid_retry(&self, orig_dst_cid: &ConnectionId, header: &[u8], payload: &[u8]) -> bool {
184        let tag_start = match payload.len().checked_sub(16) {
185            Some(x) => x,
186            None => return false,
187        };
188
189        let mut pseudo_packet =
190            Vec::with_capacity(header.len() + payload.len() + orig_dst_cid.len() + 1);
191        pseudo_packet.push(orig_dst_cid.len() as u8);
192        pseudo_packet.extend_from_slice(orig_dst_cid);
193        pseudo_packet.extend_from_slice(header);
194        let tag_start = tag_start + pseudo_packet.len();
195        pseudo_packet.extend_from_slice(payload);
196
197        let (nonce, key) = match self.version {
198            Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1),
199            Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT),
200            _ => unreachable!(),
201        };
202
203        let nonce = aead::Nonce::assume_unique_for_key(nonce);
204        let key = match aead::UnboundKey::new(&aead::AES_128_GCM, &key) {
205            Ok(unbound_key) => aead::LessSafeKey::new(unbound_key),
206            Err(_) => {
207                // This should never happen with our hardcoded keys
208                debug_assert!(false, "Failed to create AEAD key for retry integrity");
209                return false;
210            }
211        };
212
213        let (aad, tag) = pseudo_packet.split_at_mut(tag_start);
214        key.open_in_place(nonce, aead::Aad::from(aad), tag).is_ok()
215    }
216
217    fn export_keying_material(
218        &self,
219        output: &mut [u8],
220        label: &[u8],
221        context: &[u8],
222    ) -> Result<(), ExportKeyingMaterialError> {
223        self.inner
224            .export_keying_material(output, label, Some(context))
225            .map_err(|_| ExportKeyingMaterialError)?;
226        Ok(())
227    }
228}
229
230const RETRY_INTEGRITY_KEY_DRAFT: [u8; 16] = [
231    0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1,
232];
233const RETRY_INTEGRITY_NONCE_DRAFT: [u8; 12] = [
234    0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c,
235];
236
237const RETRY_INTEGRITY_KEY_V1: [u8; 16] = [
238    0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e,
239];
240const RETRY_INTEGRITY_NONCE_V1: [u8; 12] = [
241    0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb,
242];
243
244impl crypto::HeaderKey for Box<dyn HeaderProtectionKey> {
245    fn decrypt(&self, pn_offset: usize, packet: &mut [u8]) {
246        let (header, sample) = packet.split_at_mut(pn_offset + 4);
247        let (first, rest) = header.split_at_mut(1);
248        let pn_end = Ord::min(pn_offset + 3, rest.len());
249        if let Err(e) = self.decrypt_in_place(
250            &sample[..self.sample_size()],
251            &mut first[0],
252            &mut rest[pn_offset - 1..pn_end],
253        ) {
254            debug_assert!(false, "Header protection decrypt failed: {:?}", e);
255        }
256    }
257
258    fn encrypt(&self, pn_offset: usize, packet: &mut [u8]) {
259        let (header, sample) = packet.split_at_mut(pn_offset + 4);
260        let (first, rest) = header.split_at_mut(1);
261        let pn_end = Ord::min(pn_offset + 3, rest.len());
262        if let Err(e) = self.encrypt_in_place(
263            &sample[..self.sample_size()],
264            &mut first[0],
265            &mut rest[pn_offset - 1..pn_end],
266        ) {
267            debug_assert!(false, "Header protection encrypt failed: {:?}", e);
268        }
269    }
270
271    fn sample_size(&self) -> usize {
272        self.sample_len()
273    }
274}
275
276/// Authentication data for (rustls) TLS session
277pub struct HandshakeData {
278    /// The negotiated application protocol, if ALPN is in use
279    ///
280    /// Guaranteed to be set if a nonempty list of protocols was specified for this connection.
281    pub protocol: Option<Vec<u8>>,
282    /// The server name specified by the client, if any
283    ///
284    /// Always `None` for outgoing connections
285    pub server_name: Option<String>,
286}
287
288/// A QUIC-compatible TLS client configuration
289///
290/// Quinn implicitly constructs a `QuicClientConfig` with reasonable defaults within
291/// [`ClientConfig::with_root_certificates()`][root_certs] and [`ClientConfig::with_platform_verifier()`][platform].
292/// Alternatively, `QuicClientConfig`'s [`TryFrom`] implementation can be used to wrap around a
293/// custom [`rustls::ClientConfig`], in which case care should be taken around certain points:
294///
295/// - If `enable_early_data` is not set to true, then sending 0-RTT data will not be possible on
296///   outgoing connections.
297/// - The [`rustls::ClientConfig`] must have TLS 1.3 support enabled for conversion to succeed.
298///
299/// The object in the `resumption` field of the inner [`rustls::ClientConfig`] determines whether
300/// calling `into_0rtt` on outgoing connections returns `Ok` or `Err`. It typically allows
301/// `into_0rtt` to proceed if it recognizes the server name, and defaults to an in-memory cache of
302/// 256 server names.
303///
304/// [root_certs]: crate::config::ClientConfig::with_root_certificates()
305/// [platform]: crate::config::ClientConfig::with_platform_verifier()
306pub struct QuicClientConfig {
307    pub(crate) inner: Arc<rustls::ClientConfig>,
308    initial: Suite,
309    /// Optional RFC 7250 extension context for certificate type negotiation
310    pub(crate) extension_context: Option<Arc<SimulatedExtensionContext>>,
311}
312
313impl QuicClientConfig {
314    #[cfg(feature = "platform-verifier")]
315    #[allow(clippy::panic)]
316    pub(crate) fn with_platform_verifier() -> Result<Self, Error> {
317        // Keep in sync with `inner()` below
318        let mut inner = rustls::ClientConfig::builder_with_provider(configured_provider())
319            .with_protocol_versions(&[&rustls::version::TLS13])
320            .unwrap_or_else(|_| panic!("default providers should support TLS 1.3"))
321            .with_platform_verifier()?
322            .with_no_client_auth();
323
324        inner.enable_early_data = true;
325        Ok(Self {
326            // We're confident that the *ring* default provider contains TLS13_AES_128_GCM_SHA256
327            initial: initial_suite_from_provider(inner.crypto_provider())
328                .unwrap_or_else(|| panic!("no initial cipher suite found")),
329            inner: Arc::new(inner),
330            extension_context: None,
331        })
332    }
333
334    /// Initialize a sane QUIC-compatible TLS client configuration
335    ///
336    /// QUIC requires that TLS 1.3 be enabled. Advanced users can use any [`rustls::ClientConfig`] that
337    /// satisfies this requirement.
338    pub(crate) fn new(verifier: Arc<dyn ServerCertVerifier>) -> Self {
339        let inner = Self::inner(verifier);
340        Self {
341            // We're confident that the *ring* default provider contains TLS13_AES_128_GCM_SHA256
342            initial: initial_suite_from_provider(inner.crypto_provider())
343                .unwrap_or_else(|| panic!("no initial cipher suite found")),
344            inner: Arc::new(inner),
345            extension_context: None,
346        }
347    }
348
349    /// Initialize a QUIC-compatible TLS client configuration with a separate initial cipher suite
350    ///
351    /// This is useful if you want to avoid the initial cipher suite for traffic encryption.
352    pub fn with_initial(
353        inner: Arc<rustls::ClientConfig>,
354        initial: Suite,
355    ) -> Result<Self, NoInitialCipherSuite> {
356        match initial.suite.common.suite {
357            CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self {
358                inner,
359                initial,
360                extension_context: None,
361            }),
362            _ => Err(NoInitialCipherSuite { specific: true }),
363        }
364    }
365
366    /// Set the certificate type extension context for RFC 7250 support
367    pub fn with_extension_context(mut self, context: Arc<SimulatedExtensionContext>) -> Self {
368        self.extension_context = Some(context);
369        self
370    }
371
372    #[allow(clippy::panic)]
373    pub(crate) fn inner(verifier: Arc<dyn ServerCertVerifier>) -> rustls::ClientConfig {
374        // Keep in sync with `with_platform_verifier()` above
375        let mut config = rustls::ClientConfig::builder_with_provider(configured_provider())
376            .with_protocol_versions(&[&rustls::version::TLS13])
377            .unwrap_or_else(|_| panic!("The default providers support TLS 1.3"))
378            .dangerous()
379            .with_custom_certificate_verifier(verifier)
380            .with_no_client_auth();
381
382        config.enable_early_data = true;
383        config
384    }
385}
386
387impl crypto::ClientConfig for QuicClientConfig {
388    fn start_session(
389        self: Arc<Self>,
390        version: u32,
391        server_name: &str,
392        params: &TransportParameters,
393    ) -> Result<Box<dyn crypto::Session>, ConnectError> {
394        let version = interpret_version(version)?;
395        let inner_session = Box::new(TlsSession {
396            version,
397            got_handshake_data: false,
398            next_secrets: None,
399            inner: rustls::quic::Connection::Client(rustls::quic::ClientConnection::new(
400                self.inner.clone(),
401                version,
402                ServerName::try_from(server_name)
403                    .map_err(|_| ConnectError::InvalidServerName(server_name.into()))?
404                    .to_owned(),
405                to_vec(params),
406            )?),
407            suite: self.initial,
408        });
409
410        // Wrap with extension awareness if RFC 7250 support is enabled
411        if let Some(extension_context) = &self.extension_context {
412            let conn_id = format!(
413                "client-{}-{}",
414                server_name,
415                std::time::SystemTime::now()
416                    .duration_since(std::time::UNIX_EPOCH)
417                    .unwrap_or_else(|_| std::time::Duration::from_secs(0))
418                    .as_nanos()
419            );
420            Ok(Box::new(ExtensionAwareTlsSession::new(
421                inner_session,
422                extension_context.clone() as Arc<dyn TlsExtensionHooks>,
423                conn_id,
424                true, // is_client
425            )))
426        } else {
427            Ok(inner_session)
428        }
429    }
430}
431
432impl TryFrom<rustls::ClientConfig> for QuicClientConfig {
433    type Error = NoInitialCipherSuite;
434
435    fn try_from(inner: rustls::ClientConfig) -> Result<Self, Self::Error> {
436        Arc::new(inner).try_into()
437    }
438}
439
440impl TryFrom<Arc<rustls::ClientConfig>> for QuicClientConfig {
441    type Error = NoInitialCipherSuite;
442
443    fn try_from(inner: Arc<rustls::ClientConfig>) -> Result<Self, Self::Error> {
444        Ok(Self {
445            initial: initial_suite_from_provider(inner.crypto_provider())
446                .ok_or(NoInitialCipherSuite { specific: false })?,
447            inner,
448            extension_context: None,
449        })
450    }
451}
452
453/// The initial cipher suite (AES-128-GCM-SHA256) is not available
454///
455/// When the cipher suite is supplied `with_initial()`, it must be
456/// [`CipherSuite::TLS13_AES_128_GCM_SHA256`]. When the cipher suite is derived from a config's
457/// [`CryptoProvider`][provider], that provider must reference a cipher suite with the same ID.
458///
459/// [provider]: rustls::crypto::CryptoProvider
460#[derive(Clone, Debug)]
461pub struct NoInitialCipherSuite {
462    /// Whether the initial cipher suite was supplied by the caller
463    specific: bool,
464}
465
466impl std::fmt::Display for NoInitialCipherSuite {
467    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
468        f.write_str(match self.specific {
469            true => "invalid cipher suite specified",
470            false => "no initial cipher suite found",
471        })
472    }
473}
474
475impl std::error::Error for NoInitialCipherSuite {}
476
477/// A QUIC-compatible TLS server configuration
478///
479/// Quinn implicitly constructs a `QuicServerConfig` with reasonable defaults within
480/// [`ServerConfig::with_single_cert()`][single]. Alternatively, `QuicServerConfig`'s [`TryFrom`]
481/// implementation or `with_initial` method can be used to wrap around a custom
482/// [`rustls::ServerConfig`], in which case care should be taken around certain points:
483///
484/// - If `max_early_data_size` is not set to `u32::MAX`, the server will not be able to accept
485///   incoming 0-RTT data. QUIC prohibits `max_early_data_size` values other than 0 or `u32::MAX`.
486/// - The `rustls::ServerConfig` must have TLS 1.3 support enabled for conversion to succeed.
487///
488/// [single]: crate::config::ServerConfig::with_single_cert()
489pub struct QuicServerConfig {
490    inner: Arc<rustls::ServerConfig>,
491    initial: Suite,
492    /// Optional RFC 7250 extension context for certificate type negotiation
493    pub(crate) extension_context: Option<Arc<SimulatedExtensionContext>>,
494}
495
496impl QuicServerConfig {
497    pub(crate) fn new(
498        cert_chain: Vec<CertificateDer<'static>>,
499        key: PrivateKeyDer<'static>,
500    ) -> Result<Self, rustls::Error> {
501        let inner = Self::inner(cert_chain, key)?;
502        Ok(Self {
503            // We're confident that the *ring* default provider contains TLS13_AES_128_GCM_SHA256
504            initial: initial_suite_from_provider(inner.crypto_provider())
505                .ok_or_else(|| rustls::Error::General("no initial cipher suite found".into()))?,
506            inner: Arc::new(inner),
507            extension_context: None,
508        })
509    }
510
511    /// Set the certificate type extension context for RFC 7250 support
512    pub fn with_extension_context(mut self, context: Arc<SimulatedExtensionContext>) -> Self {
513        self.extension_context = Some(context);
514        self
515    }
516
517    /// Initialize a QUIC-compatible TLS client configuration with a separate initial cipher suite
518    ///
519    /// This is useful if you want to avoid the initial cipher suite for traffic encryption.
520    pub fn with_initial(
521        inner: Arc<rustls::ServerConfig>,
522        initial: Suite,
523    ) -> Result<Self, NoInitialCipherSuite> {
524        match initial.suite.common.suite {
525            CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self {
526                inner,
527                initial,
528                extension_context: None,
529            }),
530            _ => Err(NoInitialCipherSuite { specific: true }),
531        }
532    }
533
534    /// Initialize a sane QUIC-compatible TLS server configuration
535    ///
536    /// QUIC requires that TLS 1.3 be enabled, and that the maximum early data size is either 0 or
537    /// `u32::MAX`. Advanced users can use any [`rustls::ServerConfig`] that satisfies these
538    /// requirements.
539    pub(crate) fn inner(
540        cert_chain: Vec<CertificateDer<'static>>,
541        key: PrivateKeyDer<'static>,
542    ) -> Result<rustls::ServerConfig, rustls::Error> {
543        let mut inner = rustls::ServerConfig::builder_with_provider(configured_provider())
544            .with_protocol_versions(&[&rustls::version::TLS13])
545            .map_err(|_| rustls::Error::General("TLS 1.3 not supported".into()))? // The *ring* default provider supports TLS 1.3
546            .with_no_client_auth()
547            .with_single_cert(cert_chain, key)?;
548
549        inner.max_early_data_size = u32::MAX;
550        Ok(inner)
551    }
552}
553
554impl TryFrom<rustls::ServerConfig> for QuicServerConfig {
555    type Error = NoInitialCipherSuite;
556
557    fn try_from(inner: rustls::ServerConfig) -> Result<Self, Self::Error> {
558        Arc::new(inner).try_into()
559    }
560}
561
562impl TryFrom<Arc<rustls::ServerConfig>> for QuicServerConfig {
563    type Error = NoInitialCipherSuite;
564
565    fn try_from(inner: Arc<rustls::ServerConfig>) -> Result<Self, Self::Error> {
566        Ok(Self {
567            initial: initial_suite_from_provider(inner.crypto_provider())
568                .ok_or(NoInitialCipherSuite { specific: false })?,
569            inner,
570            extension_context: None,
571        })
572    }
573}
574
575impl crypto::ServerConfig for QuicServerConfig {
576    #[allow(clippy::expect_used)]
577    fn start_session(
578        self: Arc<Self>,
579        version: u32,
580        params: &TransportParameters,
581    ) -> Box<dyn crypto::Session> {
582        // Safe: `start_session()` is never called if `initial_keys()` rejected `version`
583        let version = interpret_version(version).map_err(|_| {
584            rustls::Error::General("Invalid QUIC version for server connection".into())
585        }).expect("Version should be valid at this point - start_session() is never called if initial_keys() rejected version");
586        let inner_session = Box::new(TlsSession {
587            version,
588            got_handshake_data: false,
589            next_secrets: None,
590            inner: rustls::quic::Connection::Server(
591                rustls::quic::ServerConnection::new(self.inner.clone(), version, to_vec(params))
592                    .map_err(|_| {
593                        rustls::Error::General("Failed to create server connection".into())
594                    })
595                    .expect("Server connection creation should not fail with valid parameters"),
596            ),
597            suite: self.initial,
598        });
599
600        // Wrap with extension awareness if RFC 7250 support is enabled
601        if let Some(extension_context) = &self.extension_context {
602            let conn_id = format!(
603                "server-{}",
604                std::time::SystemTime::now()
605                    .duration_since(std::time::UNIX_EPOCH)
606                    .unwrap_or_else(|_| std::time::Duration::from_secs(0))
607                    .as_nanos()
608            );
609            Box::new(ExtensionAwareTlsSession::new(
610                inner_session,
611                extension_context.clone() as Arc<dyn TlsExtensionHooks>,
612                conn_id,
613                false, // is_client = false for server
614            ))
615        } else {
616            inner_session
617        }
618    }
619
620    fn initial_keys(
621        &self,
622        version: u32,
623        dst_cid: &ConnectionId,
624    ) -> Result<Keys, UnsupportedVersion> {
625        let version = interpret_version(version)?;
626        Ok(initial_keys(version, *dst_cid, Side::Server, &self.initial))
627    }
628
629    #[allow(clippy::expect_used)]
630    fn retry_tag(&self, version: u32, orig_dst_cid: &ConnectionId, packet: &[u8]) -> [u8; 16] {
631        // Safe: `start_session()` is never called if `initial_keys()` rejected `version`
632        let version = interpret_version(version).map_err(|_| {
633            rustls::Error::General("Invalid QUIC version for retry tag".into())
634        }).expect("Version should be valid at this point - retry_tag() is never called if initial_keys() rejected version");
635        let (nonce, key) = match version {
636            Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1),
637            Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT),
638            _ => unreachable!(),
639        };
640
641        let mut pseudo_packet = Vec::with_capacity(packet.len() + orig_dst_cid.len() + 1);
642        pseudo_packet.push(orig_dst_cid.len() as u8);
643        pseudo_packet.extend_from_slice(orig_dst_cid);
644        pseudo_packet.extend_from_slice(packet);
645
646        let nonce = aead::Nonce::assume_unique_for_key(nonce);
647        let key = match aead::UnboundKey::new(&aead::AES_128_GCM, &key) {
648            Ok(unbound_key) => aead::LessSafeKey::new(unbound_key),
649            Err(_) => {
650                // This should never happen with our hardcoded keys
651                debug_assert!(false, "Failed to create AEAD key for retry integrity");
652                return [0; 16];
653            }
654        };
655
656        let tag =
657            match key.seal_in_place_separate_tag(nonce, aead::Aad::from(pseudo_packet), &mut []) {
658                Ok(tag) => tag,
659                Err(_) => {
660                    debug_assert!(false, "Failed to seal retry integrity tag");
661                    return [0; 16];
662                }
663            };
664        let mut result = [0; 16];
665        result.copy_from_slice(tag.as_ref());
666        result
667    }
668}
669
670pub(crate) fn initial_suite_from_provider(
671    provider: &Arc<rustls::crypto::CryptoProvider>,
672) -> Option<Suite> {
673    provider
674        .cipher_suites
675        .iter()
676        .find_map(|cs| match (cs.suite(), cs.tls13()) {
677            (rustls::CipherSuite::TLS13_AES_128_GCM_SHA256, Some(suite)) => {
678                Some(suite.quic_suite())
679            }
680            _ => None,
681        })
682        .flatten()
683}
684
685pub(crate) fn configured_provider() -> Arc<rustls::crypto::CryptoProvider> {
686    #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
687    let provider = {
688        // Mark KEM-only intent for tests; group restriction wiring follows.
689        DEBUG_KEM_ONLY.store(true, Ordering::Relaxed);
690        rustls::crypto::aws_lc_rs::default_provider()
691    };
692    #[cfg(feature = "rustls-ring")]
693    let provider = rustls::crypto::ring::default_provider();
694    Arc::new(provider)
695}
696
697/// Returns true if the runtime was configured to run in a KEM-only
698/// (ML‑KEM) handshake mode. This is a best-effort diagnostic used in
699/// tests and may return false when the provider does not expose PQ KEM.
700pub fn debug_kem_only_enabled() -> bool {
701    DEBUG_KEM_ONLY.load(Ordering::Relaxed)
702}
703
704fn to_vec(params: &TransportParameters) -> Vec<u8> {
705    let mut bytes = Vec::new();
706    params.write(&mut bytes);
707    bytes
708}
709
710pub(crate) fn initial_keys(
711    version: Version,
712    dst_cid: ConnectionId,
713    side: Side,
714    suite: &Suite,
715) -> Keys {
716    let keys = suite.keys(&dst_cid, side.into(), version);
717    Keys {
718        header: KeyPair {
719            local: Box::new(keys.local.header),
720            remote: Box::new(keys.remote.header),
721        },
722        packet: KeyPair {
723            local: Box::new(keys.local.packet),
724            remote: Box::new(keys.remote.packet),
725        },
726    }
727}
728
729impl crypto::PacketKey for Box<dyn PacketKey> {
730    #[allow(clippy::expect_used)]
731    fn encrypt(&self, packet: u64, buf: &mut [u8], header_len: usize) {
732        let (header, payload_tag) = buf.split_at_mut(header_len);
733        let (payload, tag_storage) = payload_tag.split_at_mut(payload_tag.len() - self.tag_len());
734        let tag = self
735            .encrypt_in_place(packet, &*header, payload)
736            .map_err(|_| rustls::Error::General("Packet encryption failed".into()))
737            .expect("Packet encryption should not fail with valid parameters");
738        tag_storage.copy_from_slice(tag.as_ref());
739    }
740
741    fn decrypt(
742        &self,
743        packet: u64,
744        header: &[u8],
745        payload: &mut BytesMut,
746    ) -> Result<(), CryptoError> {
747        let plain = self
748            .decrypt_in_place(packet, header, payload.as_mut())
749            .map_err(|_| CryptoError)?;
750        let plain_len = plain.len();
751        payload.truncate(plain_len);
752        Ok(())
753    }
754
755    fn tag_len(&self) -> usize {
756        (**self).tag_len()
757    }
758
759    fn confidentiality_limit(&self) -> u64 {
760        (**self).confidentiality_limit()
761    }
762
763    fn integrity_limit(&self) -> u64 {
764        (**self).integrity_limit()
765    }
766}
767
768fn interpret_version(version: u32) -> Result<Version, UnsupportedVersion> {
769    match version {
770        0xff00_001d..=0xff00_0020 => Ok(Version::V1Draft),
771        0x0000_0001 | 0xff00_0021..=0xff00_0022 => Ok(Version::V1),
772        _ => Err(UnsupportedVersion),
773    }
774}