ant_quic/crypto/
rustls.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7
8use std::{any::Any, io, str, sync::Arc};
9
10use aws_lc_rs::aead;
11use bytes::BytesMut;
12pub use rustls::Error;
13use rustls::{
14    self, CipherSuite,
15    client::danger::ServerCertVerifier,
16    pki_types::{CertificateDer, PrivateKeyDer, ServerName},
17    quic::{Connection, HeaderProtectionKey, KeyChange, PacketKey, Secrets, Suite, Version},
18};
19#[cfg(feature = "platform-verifier")]
20use rustls_platform_verifier::BuilderVerifierExt;
21
22use std::sync::atomic::{AtomicBool, Ordering};
23
/// Internal debug flag indicating whether the build/runtime is configured
/// to prefer ML-KEM-only key exchange groups. This is a diagnostic aid used
/// by tests; it does not by itself enforce KEM selection.
///
/// Starts out `false`. It is set to `true` (relaxed ordering) every time
/// `configured_provider` or `configured_provider_with_pqc` runs, and is read
/// back through `debug_kem_only_enabled`. Nothing in this file resets it.
static DEBUG_KEM_ONLY: AtomicBool = AtomicBool::new(false);
28
29use crate::{
30    ConnectError, ConnectionId, Side, TransportError, TransportErrorCode,
31    crypto::{
32        self, CryptoError, ExportKeyingMaterialError, HeaderKey, KeyPair, Keys, UnsupportedVersion,
33        tls_extension_simulation::{
34            ExtensionAwareTlsSession, SimulatedExtensionContext, TlsExtensionHooks,
35        },
36    },
37    transport_parameters::TransportParameters,
38};
39
40use crate::crypto::pqc::PqcConfig;
41
42impl From<Side> for rustls::Side {
43    fn from(s: Side) -> Self {
44        match s {
45            Side::Client => Self::Client,
46            Side::Server => Self::Server,
47        }
48    }
49}
50
/// A rustls TLS session
///
/// Wraps a rustls QUIC [`Connection`] together with the bookkeeping that the
/// `crypto::Session` implementation in this file needs.
pub struct TlsSession {
    /// QUIC version the session was started with; used when deriving initial keys and when
    /// selecting the retry integrity key/nonce.
    version: Version,
    /// Latched to `true` by `read_handshake` the first time handshake data is considered
    /// available; `handshake_data` returns `None` until then.
    got_handshake_data: bool,
    /// Secrets captured from `KeyChange::OneRtt` in `write_handshake`; `next_1rtt_keys`
    /// derives subsequent packet keys from them.
    next_secrets: Option<Secrets>,
    /// The underlying rustls QUIC connection (client or server variant).
    inner: Connection,
    /// Cipher suite used to derive the initial keys.
    suite: Suite,
}
59
60impl TlsSession {
61    fn side(&self) -> Side {
62        match self.inner {
63            Connection::Client(_) => Side::Client,
64            Connection::Server(_) => Side::Server,
65        }
66    }
67}
68
69impl crypto::Session for TlsSession {
70    fn initial_keys(&self, dst_cid: &ConnectionId, side: Side) -> Keys {
71        initial_keys(self.version, *dst_cid, side, &self.suite)
72    }
73
74    fn handshake_data(&self) -> Option<Box<dyn Any>> {
75        if !self.got_handshake_data {
76            return None;
77        }
78        Some(Box::new(HandshakeData {
79            protocol: self.inner.alpn_protocol().map(|x| x.into()),
80            server_name: match self.inner {
81                Connection::Client(_) => None,
82                Connection::Server(ref session) => session.server_name().map(|x| x.into()),
83            },
84        }))
85    }
86
87    /// For the rustls `TlsSession`, the `Any` type is `Vec<rustls::pki_types::CertificateDer>`
88    fn peer_identity(&self) -> Option<Box<dyn Any>> {
89        self.inner.peer_certificates().map(|v| -> Box<dyn Any> {
90            Box::new(
91                v.iter()
92                    .map(|v| v.clone().into_owned())
93                    .collect::<Vec<CertificateDer<'static>>>(),
94            )
95        })
96    }
97
98    fn early_crypto(&self) -> Option<(Box<dyn HeaderKey>, Box<dyn crypto::PacketKey>)> {
99        let keys = self.inner.zero_rtt_keys()?;
100        Some((Box::new(keys.header), Box::new(keys.packet)))
101    }
102
103    fn early_data_accepted(&self) -> Option<bool> {
104        match self.inner {
105            Connection::Client(ref session) => Some(session.is_early_data_accepted()),
106            _ => None,
107        }
108    }
109
110    fn is_handshaking(&self) -> bool {
111        self.inner.is_handshaking()
112    }
113
114    fn read_handshake(&mut self, buf: &[u8]) -> Result<bool, TransportError> {
115        self.inner.read_hs(buf).map_err(|e| {
116            if let Some(alert) = self.inner.alert() {
117                TransportError {
118                    code: TransportErrorCode::crypto(alert.into()),
119                    frame: None,
120                    reason: e.to_string(),
121                }
122            } else {
123                TransportError::PROTOCOL_VIOLATION(format!("TLS error: {e}"))
124            }
125        })?;
126        if !self.got_handshake_data {
127            // Hack around the lack of an explicit signal from rustls to reflect ClientHello being
128            // ready on incoming connections, or ALPN negotiation completing on outgoing
129            // connections.
130            let have_server_name = match self.inner {
131                Connection::Client(_) => false,
132                Connection::Server(ref session) => session.server_name().is_some(),
133            };
134            if self.inner.alpn_protocol().is_some() || have_server_name || !self.is_handshaking() {
135                self.got_handshake_data = true;
136                return Ok(true);
137            }
138        }
139        Ok(false)
140    }
141
142    fn transport_parameters(&self) -> Result<Option<TransportParameters>, TransportError> {
143        match self.inner.quic_transport_parameters() {
144            None => Ok(None),
145            Some(buf) => match TransportParameters::read(self.side(), &mut io::Cursor::new(buf)) {
146                Ok(params) => Ok(Some(params)),
147                Err(e) => Err(e.into()),
148            },
149        }
150    }
151
152    fn write_handshake(&mut self, buf: &mut Vec<u8>) -> Option<Keys> {
153        let keys = match self.inner.write_hs(buf)? {
154            KeyChange::Handshake { keys } => keys,
155            KeyChange::OneRtt { keys, next } => {
156                self.next_secrets = Some(next);
157                keys
158            }
159        };
160
161        Some(Keys {
162            header: KeyPair {
163                local: Box::new(keys.local.header),
164                remote: Box::new(keys.remote.header),
165            },
166            packet: KeyPair {
167                local: Box::new(keys.local.packet),
168                remote: Box::new(keys.remote.packet),
169            },
170        })
171    }
172
173    fn next_1rtt_keys(&mut self) -> Option<KeyPair<Box<dyn crypto::PacketKey>>> {
174        let secrets = self.next_secrets.as_mut()?;
175        let keys = secrets.next_packet_keys();
176        Some(KeyPair {
177            local: Box::new(keys.local),
178            remote: Box::new(keys.remote),
179        })
180    }
181
182    fn is_valid_retry(&self, orig_dst_cid: &ConnectionId, header: &[u8], payload: &[u8]) -> bool {
183        let tag_start = match payload.len().checked_sub(16) {
184            Some(x) => x,
185            None => return false,
186        };
187
188        let mut pseudo_packet =
189            Vec::with_capacity(header.len() + payload.len() + orig_dst_cid.len() + 1);
190        pseudo_packet.push(orig_dst_cid.len() as u8);
191        pseudo_packet.extend_from_slice(orig_dst_cid);
192        pseudo_packet.extend_from_slice(header);
193        let tag_start = tag_start + pseudo_packet.len();
194        pseudo_packet.extend_from_slice(payload);
195
196        let (nonce, key) = match self.version {
197            Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1),
198            Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT),
199            _ => unreachable!(),
200        };
201
202        let nonce = aead::Nonce::assume_unique_for_key(nonce);
203        let key = match aead::UnboundKey::new(&aead::AES_128_GCM, &key) {
204            Ok(unbound_key) => aead::LessSafeKey::new(unbound_key),
205            Err(_) => {
206                // This should never happen with our hardcoded keys
207                debug_assert!(false, "Failed to create AEAD key for retry integrity");
208                return false;
209            }
210        };
211
212        let (aad, tag) = pseudo_packet.split_at_mut(tag_start);
213        key.open_in_place(nonce, aead::Aad::from(aad), tag).is_ok()
214    }
215
216    fn export_keying_material(
217        &self,
218        output: &mut [u8],
219        label: &[u8],
220        context: &[u8],
221    ) -> Result<(), ExportKeyingMaterialError> {
222        self.inner
223            .export_keying_material(output, label, Some(context))
224            .map_err(|_| ExportKeyingMaterialError)?;
225        Ok(())
226    }
227}
228
// Retry integrity constants: fixed AES-128-GCM key/nonce pairs used by `is_valid_retry` and
// `retry_tag` to verify and produce the 16-byte Retry packet tag, one pair per supported
// `Version`.
// NOTE(review): the values presumably come from the QUIC-TLS specification (RFC 9001 and its
// drafts) — confirm against the spec before touching them.

/// Retry integrity key selected for `Version::V1Draft`.
const RETRY_INTEGRITY_KEY_DRAFT: [u8; 16] = [
    0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1,
];
/// Retry integrity nonce selected for `Version::V1Draft`.
const RETRY_INTEGRITY_NONCE_DRAFT: [u8; 12] = [
    0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c,
];

/// Retry integrity key selected for `Version::V1`.
const RETRY_INTEGRITY_KEY_V1: [u8; 16] = [
    0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e,
];
/// Retry integrity nonce selected for `Version::V1`.
const RETRY_INTEGRITY_NONCE_V1: [u8; 12] = [
    0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb,
];
242
243impl crypto::HeaderKey for Box<dyn HeaderProtectionKey> {
244    fn decrypt(&self, pn_offset: usize, packet: &mut [u8]) {
245        let (header, sample) = packet.split_at_mut(pn_offset + 4);
246        let (first, rest) = header.split_at_mut(1);
247        let pn_end = Ord::min(pn_offset + 3, rest.len());
248        if let Err(e) = self.decrypt_in_place(
249            &sample[..self.sample_size()],
250            &mut first[0],
251            &mut rest[pn_offset - 1..pn_end],
252        ) {
253            debug_assert!(false, "Header protection decrypt failed: {:?}", e);
254        }
255    }
256
257    fn encrypt(&self, pn_offset: usize, packet: &mut [u8]) {
258        let (header, sample) = packet.split_at_mut(pn_offset + 4);
259        let (first, rest) = header.split_at_mut(1);
260        let pn_end = Ord::min(pn_offset + 3, rest.len());
261        if let Err(e) = self.encrypt_in_place(
262            &sample[..self.sample_size()],
263            &mut first[0],
264            &mut rest[pn_offset - 1..pn_end],
265        ) {
266            debug_assert!(false, "Header protection encrypt failed: {:?}", e);
267        }
268    }
269
270    fn sample_size(&self) -> usize {
271        self.sample_len()
272    }
273}
274
/// Authentication data for (rustls) TLS session
///
/// This is the concrete type that `TlsSession::handshake_data` boxes as `dyn Any`.
pub struct HandshakeData {
    /// The negotiated application protocol, if ALPN is in use
    ///
    /// Guaranteed to be set if a nonempty list of protocols was specified for this connection.
    pub protocol: Option<Vec<u8>>,
    /// The server name specified by the client, if any
    ///
    /// Always `None` for outgoing connections
    pub server_name: Option<String>,
}
286
/// A QUIC-compatible TLS client configuration
///
/// Quinn implicitly constructs a `QuicClientConfig` with reasonable defaults within
/// [`ClientConfig::with_root_certificates()`][root_certs] and [`ClientConfig::with_platform_verifier()`][platform].
/// Alternatively, `QuicClientConfig`'s [`TryFrom`] implementation can be used to wrap around a
/// custom [`rustls::ClientConfig`], in which case care should be taken around certain points:
///
/// - If `enable_early_data` is not set to true, then sending 0-RTT data will not be possible on
///   outgoing connections.
/// - The [`rustls::ClientConfig`] must have TLS 1.3 support enabled for conversion to succeed.
///
/// The object in the `resumption` field of the inner [`rustls::ClientConfig`] determines whether
/// calling `into_0rtt` on outgoing connections returns `Ok` or `Err`. It typically allows
/// `into_0rtt` to proceed if it recognizes the server name, and defaults to an in-memory cache of
/// 256 server names.
///
/// [root_certs]: crate::config::ClientConfig::with_root_certificates()
/// [platform]: crate::config::ClientConfig::with_platform_verifier()
pub struct QuicClientConfig {
    /// The wrapped rustls client configuration, shared with every session started from it.
    pub(crate) inner: Arc<rustls::ClientConfig>,
    /// Suite used to derive initial keys; the constructors in this file only accept or look up
    /// `TLS13_AES_128_GCM_SHA256`.
    initial: Suite,
    /// Optional RFC 7250 extension context for certificate type negotiation
    pub(crate) extension_context: Option<Arc<SimulatedExtensionContext>>,
}
311
312impl QuicClientConfig {
313    #[cfg(feature = "platform-verifier")]
314    pub(crate) fn with_platform_verifier() -> Result<Self, Error> {
315        // Keep in sync with `inner()` below
316        let mut inner = rustls::ClientConfig::builder_with_provider(configured_provider())
317            .with_protocol_versions(&[&rustls::version::TLS13])
318            .map_err(|_| Error::General("default providers should support TLS 1.3".into()))?
319            .with_platform_verifier()?
320            .with_no_client_auth();
321
322        inner.enable_early_data = true;
323        Ok(Self {
324            // We're confident that the *ring* default provider contains TLS13_AES_128_GCM_SHA256
325            initial: initial_suite_from_provider(inner.crypto_provider())
326                .ok_or_else(|| Error::General("no initial cipher suite found".into()))?,
327            inner: Arc::new(inner),
328            extension_context: None,
329        })
330    }
331
332    /// Initialize a sane QUIC-compatible TLS client configuration
333    ///
334    /// QUIC requires that TLS 1.3 be enabled. Advanced users can use any [`rustls::ClientConfig`] that
335    /// satisfies this requirement.
336    pub(crate) fn new(verifier: Arc<dyn ServerCertVerifier>) -> Result<Self, Error> {
337        let inner = Self::inner(verifier)?;
338        Ok(Self {
339            // We're confident that the *ring* default provider contains TLS13_AES_128_GCM_SHA256
340            initial: initial_suite_from_provider(inner.crypto_provider())
341                .ok_or_else(|| Error::General("no initial cipher suite found".into()))?,
342            inner: Arc::new(inner),
343            extension_context: None,
344        })
345    }
346
347    /// Initialize a QUIC-compatible TLS client configuration with a separate initial cipher suite
348    ///
349    /// This is useful if you want to avoid the initial cipher suite for traffic encryption.
350    pub fn with_initial(
351        inner: Arc<rustls::ClientConfig>,
352        initial: Suite,
353    ) -> Result<Self, NoInitialCipherSuite> {
354        match initial.suite.common.suite {
355            CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self {
356                inner,
357                initial,
358                extension_context: None,
359            }),
360            _ => Err(NoInitialCipherSuite { specific: true }),
361        }
362    }
363
364    /// Set the certificate type extension context for RFC 7250 support
365    pub fn with_extension_context(mut self, context: Arc<SimulatedExtensionContext>) -> Self {
366        self.extension_context = Some(context);
367        self
368    }
369
370    pub(crate) fn inner(
371        verifier: Arc<dyn ServerCertVerifier>,
372    ) -> Result<rustls::ClientConfig, Error> {
373        // Keep in sync with `with_platform_verifier()` above
374        let mut config = rustls::ClientConfig::builder_with_provider(configured_provider())
375            .with_protocol_versions(&[&rustls::version::TLS13])
376            .map_err(|_| Error::General("default providers should support TLS 1.3".into()))?
377            .dangerous()
378            .with_custom_certificate_verifier(verifier)
379            .with_no_client_auth();
380
381        config.enable_early_data = true;
382        Ok(config)
383    }
384}
385
386impl crypto::ClientConfig for QuicClientConfig {
387    fn start_session(
388        self: Arc<Self>,
389        version: u32,
390        server_name: &str,
391        params: &TransportParameters,
392    ) -> Result<Box<dyn crypto::Session>, ConnectError> {
393        let version = interpret_version(version)?;
394        let inner_session = Box::new(TlsSession {
395            version,
396            got_handshake_data: false,
397            next_secrets: None,
398            inner: rustls::quic::Connection::Client(rustls::quic::ClientConnection::new(
399                self.inner.clone(),
400                version,
401                ServerName::try_from(server_name)
402                    .map_err(|_| ConnectError::InvalidServerName(server_name.into()))?
403                    .to_owned(),
404                to_vec(params),
405            )?),
406            suite: self.initial,
407        });
408
409        // Wrap with extension awareness if RFC 7250 support is enabled
410        if let Some(extension_context) = &self.extension_context {
411            let conn_id = format!(
412                "client-{}-{}",
413                server_name,
414                std::time::SystemTime::now()
415                    .duration_since(std::time::UNIX_EPOCH)
416                    .unwrap_or_else(|_| std::time::Duration::from_secs(0))
417                    .as_nanos()
418            );
419            Ok(Box::new(ExtensionAwareTlsSession::new(
420                inner_session,
421                extension_context.clone() as Arc<dyn TlsExtensionHooks>,
422                conn_id,
423                true, // is_client
424            )))
425        } else {
426            Ok(inner_session)
427        }
428    }
429}
430
431impl TryFrom<rustls::ClientConfig> for QuicClientConfig {
432    type Error = NoInitialCipherSuite;
433
434    fn try_from(inner: rustls::ClientConfig) -> Result<Self, Self::Error> {
435        Arc::new(inner).try_into()
436    }
437}
438
439impl TryFrom<Arc<rustls::ClientConfig>> for QuicClientConfig {
440    type Error = NoInitialCipherSuite;
441
442    fn try_from(inner: Arc<rustls::ClientConfig>) -> Result<Self, Self::Error> {
443        Ok(Self {
444            initial: initial_suite_from_provider(inner.crypto_provider())
445                .ok_or(NoInitialCipherSuite { specific: false })?,
446            inner,
447            extension_context: None,
448        })
449    }
450}
451
/// The initial cipher suite (AES-128-GCM-SHA256) is not available
///
/// When the cipher suite is supplied `with_initial()`, it must be
/// [`CipherSuite::TLS13_AES_128_GCM_SHA256`]. When the cipher suite is derived from a config's
/// [`CryptoProvider`][provider], that provider must reference a cipher suite with the same ID.
///
/// [provider]: rustls::crypto::CryptoProvider
#[derive(Clone, Debug)]
pub struct NoInitialCipherSuite {
    /// Whether the initial cipher suite was supplied by the caller
    ///
    /// `true` when produced by a `with_initial()` constructor, `false` when produced by the
    /// `TryFrom` conversions; it selects the message printed by `Display`.
    specific: bool,
}
464
465impl std::fmt::Display for NoInitialCipherSuite {
466    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
467        f.write_str(match self.specific {
468            true => "invalid cipher suite specified",
469            false => "no initial cipher suite found",
470        })
471    }
472}
473
// Marker impl with no overridden methods: the type's derived `Debug` and its `Display` impl
// supply everything `std::error::Error` requires.
impl std::error::Error for NoInitialCipherSuite {}
475
/// A QUIC-compatible TLS server configuration
///
/// Quinn implicitly constructs a `QuicServerConfig` with reasonable defaults within
/// [`ServerConfig::with_single_cert()`][single]. Alternatively, `QuicServerConfig`'s [`TryFrom`]
/// implementation or `with_initial` method can be used to wrap around a custom
/// [`rustls::ServerConfig`], in which case care should be taken around certain points:
///
/// - If `max_early_data_size` is not set to `u32::MAX`, the server will not be able to accept
///   incoming 0-RTT data. QUIC prohibits `max_early_data_size` values other than 0 or `u32::MAX`.
/// - The `rustls::ServerConfig` must have TLS 1.3 support enabled for conversion to succeed.
///
/// [single]: crate::config::ServerConfig::with_single_cert()
pub struct QuicServerConfig {
    /// The wrapped rustls server configuration, shared with every session started from it.
    inner: Arc<rustls::ServerConfig>,
    /// Suite used to derive initial keys; the constructors in this file only accept or look up
    /// `TLS13_AES_128_GCM_SHA256`.
    initial: Suite,
    /// Optional RFC 7250 extension context for certificate type negotiation
    pub(crate) extension_context: Option<Arc<SimulatedExtensionContext>>,
}
494
495impl QuicServerConfig {
496    pub(crate) fn new(
497        cert_chain: Vec<CertificateDer<'static>>,
498        key: PrivateKeyDer<'static>,
499    ) -> Result<Self, rustls::Error> {
500        let inner = Self::inner(cert_chain, key)?;
501        Ok(Self {
502            // We're confident that the *ring* default provider contains TLS13_AES_128_GCM_SHA256
503            initial: initial_suite_from_provider(inner.crypto_provider())
504                .ok_or_else(|| rustls::Error::General("no initial cipher suite found".into()))?,
505            inner: Arc::new(inner),
506            extension_context: None,
507        })
508    }
509
510    /// Set the certificate type extension context for RFC 7250 support
511    pub fn with_extension_context(mut self, context: Arc<SimulatedExtensionContext>) -> Self {
512        self.extension_context = Some(context);
513        self
514    }
515
516    /// Initialize a QUIC-compatible TLS client configuration with a separate initial cipher suite
517    ///
518    /// This is useful if you want to avoid the initial cipher suite for traffic encryption.
519    pub fn with_initial(
520        inner: Arc<rustls::ServerConfig>,
521        initial: Suite,
522    ) -> Result<Self, NoInitialCipherSuite> {
523        match initial.suite.common.suite {
524            CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self {
525                inner,
526                initial,
527                extension_context: None,
528            }),
529            _ => Err(NoInitialCipherSuite { specific: true }),
530        }
531    }
532
533    /// Initialize a sane QUIC-compatible TLS server configuration
534    ///
535    /// QUIC requires that TLS 1.3 be enabled, and that the maximum early data size is either 0 or
536    /// `u32::MAX`. Advanced users can use any [`rustls::ServerConfig`] that satisfies these
537    /// requirements.
538    pub(crate) fn inner(
539        cert_chain: Vec<CertificateDer<'static>>,
540        key: PrivateKeyDer<'static>,
541    ) -> Result<rustls::ServerConfig, rustls::Error> {
542        let mut inner = rustls::ServerConfig::builder_with_provider(configured_provider())
543            .with_protocol_versions(&[&rustls::version::TLS13])
544            .map_err(|_| rustls::Error::General("TLS 1.3 not supported".into()))? // The *ring* default provider supports TLS 1.3
545            .with_no_client_auth()
546            .with_single_cert(cert_chain, key)?;
547
548        inner.max_early_data_size = u32::MAX;
549        Ok(inner)
550    }
551}
552
553impl TryFrom<rustls::ServerConfig> for QuicServerConfig {
554    type Error = NoInitialCipherSuite;
555
556    fn try_from(inner: rustls::ServerConfig) -> Result<Self, Self::Error> {
557        Arc::new(inner).try_into()
558    }
559}
560
561impl TryFrom<Arc<rustls::ServerConfig>> for QuicServerConfig {
562    type Error = NoInitialCipherSuite;
563
564    fn try_from(inner: Arc<rustls::ServerConfig>) -> Result<Self, Self::Error> {
565        Ok(Self {
566            initial: initial_suite_from_provider(inner.crypto_provider())
567                .ok_or(NoInitialCipherSuite { specific: false })?,
568            inner,
569            extension_context: None,
570        })
571    }
572}
573
574impl crypto::ServerConfig for QuicServerConfig {
575    #[allow(clippy::expect_used)]
576    fn start_session(
577        self: Arc<Self>,
578        version: u32,
579        params: &TransportParameters,
580    ) -> Box<dyn crypto::Session> {
581        // Safe: `start_session()` is never called if `initial_keys()` rejected `version`
582        let version = interpret_version(version).map_err(|_| {
583            rustls::Error::General("Invalid QUIC version for server connection".into())
584        }).expect("Version should be valid at this point - start_session() is never called if initial_keys() rejected version");
585        let inner_session = Box::new(TlsSession {
586            version,
587            got_handshake_data: false,
588            next_secrets: None,
589            inner: rustls::quic::Connection::Server(
590                rustls::quic::ServerConnection::new(self.inner.clone(), version, to_vec(params))
591                    .map_err(|_| {
592                        rustls::Error::General("Failed to create server connection".into())
593                    })
594                    .expect("Server connection creation should not fail with valid parameters"),
595            ),
596            suite: self.initial,
597        });
598
599        // Wrap with extension awareness if RFC 7250 support is enabled
600        if let Some(extension_context) = &self.extension_context {
601            let conn_id = format!(
602                "server-{}",
603                std::time::SystemTime::now()
604                    .duration_since(std::time::UNIX_EPOCH)
605                    .unwrap_or_else(|_| std::time::Duration::from_secs(0))
606                    .as_nanos()
607            );
608            Box::new(ExtensionAwareTlsSession::new(
609                inner_session,
610                extension_context.clone() as Arc<dyn TlsExtensionHooks>,
611                conn_id,
612                false, // is_client = false for server
613            ))
614        } else {
615            inner_session
616        }
617    }
618
619    fn initial_keys(
620        &self,
621        version: u32,
622        dst_cid: &ConnectionId,
623    ) -> Result<Keys, UnsupportedVersion> {
624        let version = interpret_version(version)?;
625        Ok(initial_keys(version, *dst_cid, Side::Server, &self.initial))
626    }
627
628    #[allow(clippy::expect_used)]
629    fn retry_tag(&self, version: u32, orig_dst_cid: &ConnectionId, packet: &[u8]) -> [u8; 16] {
630        // Safe: `start_session()` is never called if `initial_keys()` rejected `version`
631        let version = interpret_version(version).map_err(|_| {
632            rustls::Error::General("Invalid QUIC version for retry tag".into())
633        }).expect("Version should be valid at this point - retry_tag() is never called if initial_keys() rejected version");
634        let (nonce, key) = match version {
635            Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1),
636            Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT),
637            _ => unreachable!(),
638        };
639
640        let mut pseudo_packet = Vec::with_capacity(packet.len() + orig_dst_cid.len() + 1);
641        pseudo_packet.push(orig_dst_cid.len() as u8);
642        pseudo_packet.extend_from_slice(orig_dst_cid);
643        pseudo_packet.extend_from_slice(packet);
644
645        let nonce = aead::Nonce::assume_unique_for_key(nonce);
646        let key = match aead::UnboundKey::new(&aead::AES_128_GCM, &key) {
647            Ok(unbound_key) => aead::LessSafeKey::new(unbound_key),
648            Err(_) => {
649                // This should never happen with our hardcoded keys
650                debug_assert!(false, "Failed to create AEAD key for retry integrity");
651                return [0; 16];
652            }
653        };
654
655        let tag =
656            match key.seal_in_place_separate_tag(nonce, aead::Aad::from(pseudo_packet), &mut []) {
657                Ok(tag) => tag,
658                Err(_) => {
659                    debug_assert!(false, "Failed to seal retry integrity tag");
660                    return [0; 16];
661                }
662            };
663        let mut result = [0; 16];
664        result.copy_from_slice(tag.as_ref());
665        result
666    }
667}
668
669pub(crate) fn initial_suite_from_provider(
670    provider: &Arc<rustls::crypto::CryptoProvider>,
671) -> Option<Suite> {
672    provider
673        .cipher_suites
674        .iter()
675        .find_map(|cs| match (cs.suite(), cs.tls13()) {
676            (rustls::CipherSuite::TLS13_AES_128_GCM_SHA256, Some(suite)) => {
677                Some(suite.quic_suite())
678            }
679            _ => None,
680        })
681        .flatten()
682}
683
684pub(crate) fn configured_provider() -> Arc<rustls::crypto::CryptoProvider> {
685    // Mark KEM-only intent for tests; group restriction wiring follows.
686    DEBUG_KEM_ONLY.store(true, Ordering::Relaxed);
687    let provider = rustls::crypto::aws_lc_rs::default_provider();
688    Arc::new(provider)
689}
690
691/// Create a CryptoProvider with PQC support
692///
693/// v0.13.0+: PQC is always enabled. When a `PqcConfig` is provided, this function
694/// creates a provider that uses only PQC key exchange groups (ML-KEM-768 or hybrid
695/// groups containing ML-KEM). Classical algorithms like X25519 are excluded.
696pub fn configured_provider_with_pqc(
697    config: Option<&PqcConfig>,
698) -> Arc<rustls::crypto::CryptoProvider> {
699    // v0.13.0+: Always use PQC
700    DEBUG_KEM_ONLY.store(true, Ordering::Relaxed);
701
702    match config {
703        Some(pqc_config) => {
704            // Use the PQC crypto provider factory
705            match crate::crypto::pqc::create_crypto_provider(pqc_config) {
706                Ok(provider) => provider,
707                Err(e) => {
708                    tracing::warn!(
709                        "Failed to create PQC provider: {:?}, falling back to default",
710                        e
711                    );
712                    configured_provider()
713                }
714            }
715        }
716        None => configured_provider(),
717    }
718}
719
720/// Validate that a connection used PQC algorithms
721///
722/// v0.13.0+: After a TLS handshake completes, this validates that the
723/// negotiated key exchange group is a PQC group. All connections must
724/// use PQC in v0.13.0+.
725pub fn validate_pqc_connection(
726    negotiated_group: rustls::NamedGroup,
727) -> Result<(), crate::crypto::pqc::PqcError> {
728    crate::crypto::pqc::validate_negotiated_group(negotiated_group)
729}
730
731/// Returns true if the runtime was configured to run in a KEM-only
732/// (ML‑KEM) handshake mode. This is a best-effort diagnostic used in
733/// tests and may return false when the provider does not expose PQ KEM.
734pub fn debug_kem_only_enabled() -> bool {
735    DEBUG_KEM_ONLY.load(Ordering::Relaxed)
736}
737
738fn to_vec(params: &TransportParameters) -> Vec<u8> {
739    let mut bytes = Vec::new();
740    params.write(&mut bytes);
741    bytes
742}
743
744pub(crate) fn initial_keys(
745    version: Version,
746    dst_cid: ConnectionId,
747    side: Side,
748    suite: &Suite,
749) -> Keys {
750    let keys = suite.keys(&dst_cid, side.into(), version);
751    Keys {
752        header: KeyPair {
753            local: Box::new(keys.local.header),
754            remote: Box::new(keys.remote.header),
755        },
756        packet: KeyPair {
757            local: Box::new(keys.local.packet),
758            remote: Box::new(keys.remote.packet),
759        },
760    }
761}
762
763impl crypto::PacketKey for Box<dyn PacketKey> {
764    #[allow(clippy::expect_used)]
765    fn encrypt(&self, packet: u64, buf: &mut [u8], header_len: usize) {
766        let (header, payload_tag) = buf.split_at_mut(header_len);
767        let (payload, tag_storage) = payload_tag.split_at_mut(payload_tag.len() - self.tag_len());
768        let tag = self
769            .encrypt_in_place(packet, &*header, payload)
770            .map_err(|_| rustls::Error::General("Packet encryption failed".into()))
771            .expect("Packet encryption should not fail with valid parameters");
772        tag_storage.copy_from_slice(tag.as_ref());
773    }
774
775    fn decrypt(
776        &self,
777        packet: u64,
778        header: &[u8],
779        payload: &mut BytesMut,
780    ) -> Result<(), CryptoError> {
781        let plain = self
782            .decrypt_in_place(packet, header, payload.as_mut())
783            .map_err(|_| CryptoError)?;
784        let plain_len = plain.len();
785        payload.truncate(plain_len);
786        Ok(())
787    }
788
789    fn tag_len(&self) -> usize {
790        (**self).tag_len()
791    }
792
793    fn confidentiality_limit(&self) -> u64 {
794        (**self).confidentiality_limit()
795    }
796
797    fn integrity_limit(&self) -> u64 {
798        (**self).integrity_limit()
799    }
800}
801
802fn interpret_version(version: u32) -> Result<Version, UnsupportedVersion> {
803    match version {
804        0xff00_001d..=0xff00_0020 => Ok(Version::V1Draft),
805        0x0000_0001 | 0xff00_0021..=0xff00_0022 => Ok(Version::V1),
806        _ => Err(UnsupportedVersion),
807    }
808}