aranya_internal_rustls/msgs/handshake.rs

1use alloc::collections::BTreeSet;
2#[cfg(feature = "logging")]
3use alloc::string::String;
4use alloc::vec;
5use alloc::vec::Vec;
6use core::ops::Deref;
7use core::{fmt, iter};
8
9use pki_types::{CertificateDer, DnsName};
10
11#[cfg(feature = "tls12")]
12use crate::crypto::ActiveKeyExchange;
13use crate::crypto::SecureRandom;
14use crate::enums::{
15    CertificateCompressionAlgorithm, CipherSuite, EchClientHelloType, HandshakeType,
16    ProtocolVersion, SignatureScheme,
17};
18use crate::error::InvalidMessage;
19#[cfg(feature = "tls12")]
20use crate::ffdhe_groups::FfdheGroup;
21use crate::log::warn;
22use crate::msgs::base::{MaybeEmpty, NonEmpty, Payload, PayloadU8, PayloadU16, PayloadU24};
23use crate::msgs::codec::{self, Codec, LengthPrefixedBuffer, ListLength, Reader, TlsListElement};
24use crate::msgs::enums::{
25    CertificateStatusType, CertificateType, ClientCertificateType, Compression, ECCurveType,
26    ECPointFormat, EchVersion, ExtensionType, HpkeAead, HpkeKdf, HpkeKem, KeyUpdateRequest,
27    NamedGroup, PskKeyExchangeMode, ServerNameType,
28};
29use crate::rand;
30use crate::sync::Arc;
31use crate::verify::DigitallySignedStruct;
32use crate::x509::wrap_in_sequence;
33
/// Define a newtype around a byte-payload container such as `PayloadU8` or `PayloadU16`.
///
/// The generated type derives `Clone` and `Debug`, can be built from a `Vec<u8>`,
/// exposes its bytes through `AsRef<[u8]>`, and implements `Codec` by delegating to
/// the wrapped payload. It suits TLS fields where only the underlying bytes matter.
macro_rules! wrapped_payload {
    (
        $(#[$attr:meta])*
        $visibility:vis struct $wrapper:ident, $payload:ident $(<$payload_param:ty>)?,
    ) => {
        $(#[$attr])*
        #[derive(Clone, Debug)]
        $visibility struct $wrapper($payload $(<$payload_param>)?);

        impl Codec<'_> for $wrapper {
            fn encode(&self, bytes: &mut Vec<u8>) {
                self.0.encode(bytes);
            }

            fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
                $payload::read(r).map(Self)
            }
        }

        impl AsRef<[u8]> for $wrapper {
            fn as_ref(&self) -> &[u8] {
                let Self(payload) = self;
                payload.0.as_slice()
            }
        }

        impl From<Vec<u8>> for $wrapper {
            fn from(v: Vec<u8>) -> Self {
                Self($payload::new(v))
            }
        }
    };
}
68
/// A 32-byte random value, as carried in the hello messages.
#[derive(PartialEq, Eq, Clone, Copy)]
pub struct Random(pub(crate) [u8; 32]);
71
72impl fmt::Debug for Random {
73    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74        super::base::hex(f, &self.0)
75    }
76}
77
78static HELLO_RETRY_REQUEST_RANDOM: Random = Random([
79    0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11, 0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91,
80    0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e, 0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c,
81]);
82
83static ZERO_RANDOM: Random = Random([0u8; 32]);
84
85impl Codec<'_> for Random {
86    fn encode(&self, bytes: &mut Vec<u8>) {
87        bytes.extend_from_slice(&self.0);
88    }
89
90    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
91        let Some(bytes) = r.take(32) else {
92            return Err(InvalidMessage::MissingData("Random"));
93        };
94
95        let mut opaque = [0; 32];
96        opaque.clone_from_slice(bytes);
97        Ok(Self(opaque))
98    }
99}
100
101impl Random {
102    pub(crate) fn new(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
103        let mut data = [0u8; 32];
104        secure_random.fill(&mut data)?;
105        Ok(Self(data))
106    }
107}
108
109impl From<[u8; 32]> for Random {
110    #[inline]
111    fn from(bytes: [u8; 32]) -> Self {
112        Self(bytes)
113    }
114}
115
/// A TLS session id: up to 32 bytes stored inline, of which the first `len` are in use.
#[derive(Copy, Clone)]
pub struct SessionId {
    /// Number of meaningful bytes at the front of `data`.
    len: usize,
    /// Fixed-size backing storage; bytes past `len` are not part of the id.
    data: [u8; 32],
}
121
122impl fmt::Debug for SessionId {
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        super::base::hex(f, &self.data[..self.len])
125    }
126}
127
128impl PartialEq for SessionId {
129    fn eq(&self, other: &Self) -> bool {
130        if self.len != other.len {
131            return false;
132        }
133
134        let mut diff = 0u8;
135        for i in 0..self.len {
136            diff |= self.data[i] ^ other.data[i];
137        }
138
139        diff == 0u8
140    }
141}
142
143impl Codec<'_> for SessionId {
144    fn encode(&self, bytes: &mut Vec<u8>) {
145        debug_assert!(self.len <= 32);
146        bytes.push(self.len as u8);
147        bytes.extend_from_slice(self.as_ref());
148    }
149
150    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
151        let len = u8::read(r)? as usize;
152        if len > 32 {
153            return Err(InvalidMessage::TrailingData("SessionID"));
154        }
155
156        let Some(bytes) = r.take(len) else {
157            return Err(InvalidMessage::MissingData("SessionID"));
158        };
159
160        let mut out = [0u8; 32];
161        out[..len].clone_from_slice(&bytes[..len]);
162        Ok(Self { data: out, len })
163    }
164}
165
166impl SessionId {
167    pub fn random(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
168        let mut data = [0u8; 32];
169        secure_random.fill(&mut data)?;
170        Ok(Self { data, len: 32 })
171    }
172
173    pub(crate) fn empty() -> Self {
174        Self {
175            data: [0u8; 32],
176            len: 0,
177        }
178    }
179
180    #[cfg(feature = "tls12")]
181    pub(crate) fn is_empty(&self) -> bool {
182        self.len == 0
183    }
184}
185
186impl AsRef<[u8]> for SessionId {
187    fn as_ref(&self) -> &[u8] {
188        &self.data[..self.len]
189    }
190}
191
192#[derive(Clone, Debug, PartialEq)]
193pub struct UnknownExtension {
194    pub(crate) typ: ExtensionType,
195    pub(crate) payload: Payload<'static>,
196}
197
198impl UnknownExtension {
199    fn encode(&self, bytes: &mut Vec<u8>) {
200        self.payload.encode(bytes);
201    }
202
203    fn read(typ: ExtensionType, r: &mut Reader<'_>) -> Self {
204        let payload = Payload::read(r).into_owned();
205        Self { typ, payload }
206    }
207}
208
209/// RFC8422: `ECPointFormat ec_point_format_list<1..2^8-1>`
210impl TlsListElement for ECPointFormat {
211    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
212        empty_error: InvalidMessage::IllegalEmptyList("ECPointFormats"),
213    };
214}
215
216/// RFC8422: `NamedCurve named_curve_list<2..2^16-1>`
217impl TlsListElement for NamedGroup {
218    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
219        empty_error: InvalidMessage::IllegalEmptyList("NamedGroups"),
220    };
221}
222
223/// RFC8446: `SignatureScheme supported_signature_algorithms<2..2^16-2>;`
224impl TlsListElement for SignatureScheme {
225    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
226        empty_error: InvalidMessage::NoSignatureSchemes,
227    };
228}
229
230#[derive(Clone, Debug)]
231pub enum ServerNamePayload<'a> {
232    /// A successfully decoded value:
233    SingleDnsName(DnsName<'a>),
234
235    /// A DNS name which was actually an IP address
236    IpAddress,
237
238    /// A successfully decoded, but syntactically-invalid value.
239    Invalid,
240}
241
242impl ServerNamePayload<'_> {
243    fn into_owned(self) -> ServerNamePayload<'static> {
244        match self {
245            Self::SingleDnsName(d) => ServerNamePayload::SingleDnsName(d.to_owned()),
246            Self::IpAddress => ServerNamePayload::IpAddress,
247            Self::Invalid => ServerNamePayload::Invalid,
248        }
249    }
250
251    /// RFC6066: `ServerName server_name_list<1..2^16-1>`
252    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
253        empty_error: InvalidMessage::IllegalEmptyList("ServerNames"),
254    };
255}
256
257/// Simplified encoding/decoding for a `ServerName` extension payload to/from `DnsName`
258///
259/// This is possible because:
260///
261/// - the spec (RFC6066) disallows multiple names for a given name type
262/// - name types other than ServerNameType::HostName are not defined, and they and
263///   any data that follows them cannot be skipped over.
264impl<'a> Codec<'a> for ServerNamePayload<'a> {
265    fn encode(&self, bytes: &mut Vec<u8>) {
266        let server_name_list = LengthPrefixedBuffer::new(Self::SIZE_LEN, bytes);
267
268        let ServerNamePayload::SingleDnsName(dns_name) = self else {
269            return;
270        };
271
272        ServerNameType::HostName.encode(server_name_list.buf);
273        let name_slice = dns_name.as_ref().as_bytes();
274        (name_slice.len() as u16).encode(server_name_list.buf);
275        server_name_list
276            .buf
277            .extend_from_slice(name_slice);
278    }
279
280    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
281        let mut found = None;
282
283        let len = Self::SIZE_LEN.read(r)?;
284        let mut sub = r.sub(len)?;
285
286        while sub.any_left() {
287            let typ = ServerNameType::read(&mut sub)?;
288
289            let payload = match typ {
290                ServerNameType::HostName => HostNamePayload::read(&mut sub)?,
291                _ => {
292                    // Consume remainder of extension bytes.  Since the length of the item
293                    // is an unknown encoding, we cannot continue.
294                    sub.rest();
295                    break;
296                }
297            };
298
299            // "The ServerNameList MUST NOT contain more than one name of
300            // the same name_type." - RFC6066
301            if found.is_some() {
302                warn!("Illegal SNI extension: duplicate host_name received");
303                return Err(InvalidMessage::InvalidServerName);
304            }
305
306            found = match payload {
307                HostNamePayload::HostName(dns_name) => {
308                    Some(Self::SingleDnsName(dns_name.to_owned()))
309                }
310
311                HostNamePayload::IpAddress(_invalid) => {
312                    warn!(
313                        "Illegal SNI extension: ignoring IP address presented as hostname ({:?})",
314                        _invalid
315                    );
316                    Some(Self::IpAddress)
317                }
318
319                HostNamePayload::Invalid(_invalid) => {
320                    warn!(
321                        "Illegal SNI hostname received {:?}",
322                        String::from_utf8_lossy(&_invalid.0)
323                    );
324                    Some(Self::Invalid)
325                }
326            };
327        }
328
329        Ok(found.unwrap_or(Self::Invalid))
330    }
331}
332
333impl<'a> From<&DnsName<'a>> for ServerNamePayload<'static> {
334    fn from(value: &DnsName<'a>) -> Self {
335        Self::SingleDnsName(trim_hostname_trailing_dot_for_sni(value))
336    }
337}
338
339#[derive(Clone, Debug)]
340pub(crate) enum HostNamePayload {
341    HostName(DnsName<'static>),
342    IpAddress(PayloadU16<NonEmpty>),
343    Invalid(PayloadU16<NonEmpty>),
344}
345
346impl HostNamePayload {
347    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
348        use pki_types::ServerName;
349        let raw = PayloadU16::<NonEmpty>::read(r)?;
350
351        match ServerName::try_from(raw.0.as_slice()) {
352            Ok(ServerName::DnsName(d)) => Ok(Self::HostName(d.to_owned())),
353            Ok(ServerName::IpAddress(_)) => Ok(Self::IpAddress(raw)),
354            Ok(_) | Err(_) => Ok(Self::Invalid(raw)),
355        }
356    }
357}
358
359wrapped_payload!(
360    /// RFC7301: `opaque ProtocolName<1..2^8-1>;`
361    pub struct ProtocolName, PayloadU8<NonEmpty>,
362);
363
364/// RFC7301: `ProtocolName protocol_name_list<2..2^16-1>`
365impl TlsListElement for ProtocolName {
366    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
367        empty_error: InvalidMessage::IllegalEmptyList("ProtocolNames"),
368    };
369}
370
371/// RFC7301 encodes a single protocol name as `Vec<ProtocolName>`
372#[derive(Clone, Debug)]
373pub struct SingleProtocolName(ProtocolName);
374
375impl SingleProtocolName {
376    pub(crate) fn new(bytes: Vec<u8>) -> Self {
377        Self(ProtocolName::from(bytes))
378    }
379
380    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
381        empty_error: InvalidMessage::IllegalEmptyList("ProtocolNames"),
382    };
383}
384
385impl Codec<'_> for SingleProtocolName {
386    fn encode(&self, bytes: &mut Vec<u8>) {
387        let body = LengthPrefixedBuffer::new(Self::SIZE_LEN, bytes);
388        self.0.encode(body.buf);
389    }
390
391    fn read(reader: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
392        let len = Self::SIZE_LEN.read(reader)?;
393        let mut sub = reader.sub(len)?;
394
395        let item = ProtocolName::read(&mut sub)?;
396
397        if sub.any_left() {
398            Err(InvalidMessage::TrailingData("SingleProtocolName"))
399        } else {
400            Ok(Self(item))
401        }
402    }
403}
404
405impl AsRef<[u8]> for SingleProtocolName {
406    fn as_ref(&self) -> &[u8] {
407        self.0.as_ref()
408    }
409}
410
411// --- TLS 1.3 Key shares ---
412#[derive(Clone, Debug)]
413pub struct KeyShareEntry {
414    pub(crate) group: NamedGroup,
415    /// RFC8446: `opaque key_exchange<1..2^16-1>;`
416    pub(crate) payload: PayloadU16<NonEmpty>,
417}
418
419impl KeyShareEntry {
420    pub fn new(group: NamedGroup, payload: impl Into<Vec<u8>>) -> Self {
421        Self {
422            group,
423            payload: PayloadU16::new(payload.into()),
424        }
425    }
426
427    pub fn group(&self) -> NamedGroup {
428        self.group
429    }
430}
431
432impl Codec<'_> for KeyShareEntry {
433    fn encode(&self, bytes: &mut Vec<u8>) {
434        self.group.encode(bytes);
435        self.payload.encode(bytes);
436    }
437
438    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
439        let group = NamedGroup::read(r)?;
440        let payload = PayloadU16::read(r)?;
441
442        Ok(Self { group, payload })
443    }
444}
445
446// --- TLS 1.3 PresharedKey offers ---
447#[derive(Clone, Debug)]
448pub(crate) struct PresharedKeyIdentity {
449    /// RFC8446: `opaque identity<1..2^16-1>;`
450    pub(crate) identity: PayloadU16<NonEmpty>,
451    pub(crate) obfuscated_ticket_age: u32,
452}
453
454impl PresharedKeyIdentity {
455    pub(crate) fn new(id: Vec<u8>, age: u32) -> Self {
456        Self {
457            identity: PayloadU16::new(id),
458            obfuscated_ticket_age: age,
459        }
460    }
461
462    pub(crate) fn external(id: Vec<u8>) -> Self {
463        // See 4.2.11 of RFC 8446: "For identities established
464        // externally, an obfuscated_ticket_age of 0 SHOULD be
465        // used..."
466        Self::new(id, 0)
467    }
468}
469
470impl Codec<'_> for PresharedKeyIdentity {
471    fn encode(&self, bytes: &mut Vec<u8>) {
472        self.identity.encode(bytes);
473        self.obfuscated_ticket_age.encode(bytes);
474    }
475
476    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
477        Ok(Self {
478            identity: PayloadU16::read(r)?,
479            obfuscated_ticket_age: u32::read(r)?,
480        })
481    }
482}
483
484/// RFC8446: `PskIdentity identities<7..2^16-1>;`
485impl TlsListElement for PresharedKeyIdentity {
486    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
487        empty_error: InvalidMessage::IllegalEmptyList("PskIdentities"),
488    };
489}
490
491wrapped_payload!(
492    /// RFC8446: `opaque PskBinderEntry<32..255>;`
493    pub(crate) struct PresharedKeyBinder, PayloadU8<NonEmpty>,
494);
495
496/// RFC8446: `PskBinderEntry binders<33..2^16-1>;`
497impl TlsListElement for PresharedKeyBinder {
498    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
499        empty_error: InvalidMessage::IllegalEmptyList("PskBinders"),
500    };
501}
502
503#[derive(Clone, Debug)]
504pub struct PresharedKeyOffer {
505    pub(crate) identities: Vec<PresharedKeyIdentity>,
506    pub(crate) binders: Vec<PresharedKeyBinder>,
507}
508
509impl PresharedKeyOffer {
510    /// Make a new one with one entry.
511    pub(crate) fn new(id: PresharedKeyIdentity, binder: Vec<u8>) -> Self {
512        Self {
513            identities: vec![id],
514            binders: vec![PresharedKeyBinder::from(binder)],
515        }
516    }
517}
518
519impl FromIterator<(PresharedKeyIdentity, PresharedKeyBinder)> for PresharedKeyOffer {
520    fn from_iter<I>(iter: I) -> Self
521    where
522        I: IntoIterator<Item = (PresharedKeyIdentity, PresharedKeyBinder)>,
523    {
524        let (identities, binders) = iter.into_iter().unzip();
525        Self {
526            identities,
527            binders,
528        }
529    }
530}
531
532impl Codec<'_> for PresharedKeyOffer {
533    fn encode(&self, bytes: &mut Vec<u8>) {
534        self.identities.encode(bytes);
535        self.binders.encode(bytes);
536    }
537
538    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
539        Ok(Self {
540            identities: Vec::read(r)?,
541            binders: Vec::read(r)?,
542        })
543    }
544}
545
546// --- RFC6066 certificate status request ---
547wrapped_payload!(pub(crate) struct ResponderId, PayloadU16,);
548
549/// RFC6066: `ResponderID responder_id_list<0..2^16-1>;`
550impl TlsListElement for ResponderId {
551    const SIZE_LEN: ListLength = ListLength::U16;
552}
553
554#[derive(Clone, Debug)]
555pub struct OcspCertificateStatusRequest {
556    pub(crate) responder_ids: Vec<ResponderId>,
557    pub(crate) extensions: PayloadU16,
558}
559
560impl Codec<'_> for OcspCertificateStatusRequest {
561    fn encode(&self, bytes: &mut Vec<u8>) {
562        CertificateStatusType::OCSP.encode(bytes);
563        self.responder_ids.encode(bytes);
564        self.extensions.encode(bytes);
565    }
566
567    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
568        Ok(Self {
569            responder_ids: Vec::read(r)?,
570            extensions: PayloadU16::read(r)?,
571        })
572    }
573}
574
575#[derive(Clone, Debug)]
576pub enum CertificateStatusRequest {
577    Ocsp(OcspCertificateStatusRequest),
578    Unknown((CertificateStatusType, Payload<'static>)),
579}
580
581impl Codec<'_> for CertificateStatusRequest {
582    fn encode(&self, bytes: &mut Vec<u8>) {
583        match self {
584            Self::Ocsp(r) => r.encode(bytes),
585            Self::Unknown((typ, payload)) => {
586                typ.encode(bytes);
587                payload.encode(bytes);
588            }
589        }
590    }
591
592    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
593        let typ = CertificateStatusType::read(r)?;
594
595        match typ {
596            CertificateStatusType::OCSP => {
597                let ocsp_req = OcspCertificateStatusRequest::read(r)?;
598                Ok(Self::Ocsp(ocsp_req))
599            }
600            _ => {
601                let data = Payload::read(r).into_owned();
602                Ok(Self::Unknown((typ, data)))
603            }
604        }
605    }
606}
607
608impl CertificateStatusRequest {
609    pub(crate) fn build_ocsp() -> Self {
610        let ocsp = OcspCertificateStatusRequest {
611            responder_ids: Vec::new(),
612            extensions: PayloadU16::empty(),
613        };
614        Self::Ocsp(ocsp)
615    }
616}
617
618// ---
619
620/// RFC8446: `PskKeyExchangeMode ke_modes<1..255>;`
621impl TlsListElement for PskKeyExchangeMode {
622    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
623        empty_error: InvalidMessage::IllegalEmptyList("PskKeyExchangeModes"),
624    };
625}
626
627/// RFC8446: `KeyShareEntry client_shares<0..2^16-1>;`
628impl TlsListElement for KeyShareEntry {
629    const SIZE_LEN: ListLength = ListLength::U16;
630}
631
/// The body of the `SupportedVersions` extension as sent in a `ClientHello`.
///
/// RFC8446 defines this as a vector in preference order, but only the presence of
/// TLS1.3 and TLS1.2 is recorded here; we (as a server) ignore the client's preference.
///
/// RFC8446: `ProtocolVersion versions<2..254>;`
#[derive(Clone, Copy, Debug, Default)]
pub struct SupportedProtocolVersions {
    /// Whether TLS1.3 is included.
    pub(crate) tls13: bool,
    /// Whether TLS1.2 is included.
    pub(crate) tls12: bool,
}
644
645impl SupportedProtocolVersions {
646    /// Return true if `filter` returns true for any enabled version.
647    pub(crate) fn any(&self, filter: impl Fn(ProtocolVersion) -> bool) -> bool {
648        if self.tls13 && filter(ProtocolVersion::TLSv1_3) {
649            return true;
650        }
651        if self.tls12 && filter(ProtocolVersion::TLSv1_2) {
652            return true;
653        }
654        false
655    }
656
657    const LIST_LENGTH: ListLength = ListLength::NonZeroU8 {
658        empty_error: InvalidMessage::IllegalEmptyList("ProtocolVersions"),
659    };
660}
661
662impl Codec<'_> for SupportedProtocolVersions {
663    fn encode(&self, bytes: &mut Vec<u8>) {
664        let inner = LengthPrefixedBuffer::new(Self::LIST_LENGTH, bytes);
665        if self.tls13 {
666            ProtocolVersion::TLSv1_3.encode(inner.buf);
667        }
668        if self.tls12 {
669            ProtocolVersion::TLSv1_2.encode(inner.buf);
670        }
671    }
672
673    fn read(reader: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
674        let len = Self::LIST_LENGTH.read(reader)?;
675        let mut sub = reader.sub(len)?;
676
677        let mut tls12 = false;
678        let mut tls13 = false;
679
680        while sub.any_left() {
681            match ProtocolVersion::read(&mut sub)? {
682                ProtocolVersion::TLSv1_3 => tls13 = true,
683                ProtocolVersion::TLSv1_2 => tls12 = true,
684                _ => continue,
685            };
686        }
687
688        Ok(Self { tls13, tls12 })
689    }
690}
691
692impl TlsListElement for ProtocolVersion {
693    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
694        empty_error: InvalidMessage::IllegalEmptyList("ProtocolVersions"),
695    };
696}
697
698/// RFC7250: `CertificateType client_certificate_types<1..2^8-1>;`
699///
700/// Ditto `CertificateType server_certificate_types<1..2^8-1>;`
701impl TlsListElement for CertificateType {
702    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
703        empty_error: InvalidMessage::IllegalEmptyList("CertificateTypes"),
704    };
705}
706
707/// RFC8879: `CertificateCompressionAlgorithm algorithms<2..2^8-2>;`
708impl TlsListElement for CertificateCompressionAlgorithm {
709    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
710        empty_error: InvalidMessage::IllegalEmptyList("CertificateCompressionAlgorithms"),
711    };
712}
713
714#[derive(Clone, Debug)]
715pub enum ClientExtension {
716    EcPointFormats(Vec<ECPointFormat>),
717    NamedGroups(Vec<NamedGroup>),
718    SignatureAlgorithms(Vec<SignatureScheme>),
719    ServerName(ServerNamePayload<'static>),
720    SessionTicket(ClientSessionTicket),
721    Protocols(Vec<ProtocolName>),
722    SupportedVersions(SupportedProtocolVersions),
723    KeyShare(Vec<KeyShareEntry>),
724    PresharedKeyModes(Vec<PskKeyExchangeMode>),
725    PresharedKey(PresharedKeyOffer),
726    Cookie(PayloadU16<NonEmpty>),
727    ExtendedMasterSecretRequest,
728    CertificateStatusRequest(CertificateStatusRequest),
729    ServerCertTypes(Vec<CertificateType>),
730    ClientCertTypes(Vec<CertificateType>),
731    TransportParameters(Vec<u8>),
732    TransportParametersDraft(Vec<u8>),
733    EarlyData,
734    CertificateCompressionAlgorithms(Vec<CertificateCompressionAlgorithm>),
735    EncryptedClientHello(EncryptedClientHello),
736    EncryptedClientHelloOuterExtensions(Vec<ExtensionType>),
737    AuthorityNames(Vec<DistinguishedName>),
738    Unknown(UnknownExtension),
739}
740
741impl ClientExtension {
742    pub(crate) fn ext_type(&self) -> ExtensionType {
743        match self {
744            Self::EcPointFormats(_) => ExtensionType::ECPointFormats,
745            Self::NamedGroups(_) => ExtensionType::EllipticCurves,
746            Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms,
747            Self::ServerName(_) => ExtensionType::ServerName,
748            Self::SessionTicket(_) => ExtensionType::SessionTicket,
749            Self::Protocols(_) => ExtensionType::ALProtocolNegotiation,
750            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
751            Self::KeyShare(_) => ExtensionType::KeyShare,
752            Self::PresharedKeyModes(_) => ExtensionType::PSKKeyExchangeModes,
753            Self::PresharedKey(_) => ExtensionType::PreSharedKey,
754            Self::Cookie(_) => ExtensionType::Cookie,
755            Self::ExtendedMasterSecretRequest => ExtensionType::ExtendedMasterSecret,
756            Self::CertificateStatusRequest(_) => ExtensionType::StatusRequest,
757            Self::ClientCertTypes(_) => ExtensionType::ClientCertificateType,
758            Self::ServerCertTypes(_) => ExtensionType::ServerCertificateType,
759            Self::TransportParameters(_) => ExtensionType::TransportParameters,
760            Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft,
761            Self::EarlyData => ExtensionType::EarlyData,
762            Self::CertificateCompressionAlgorithms(_) => ExtensionType::CompressCertificate,
763            Self::EncryptedClientHello(_) => ExtensionType::EncryptedClientHello,
764            Self::EncryptedClientHelloOuterExtensions(_) => {
765                ExtensionType::EncryptedClientHelloOuterExtensions
766            }
767            Self::AuthorityNames(_) => ExtensionType::CertificateAuthorities,
768            Self::Unknown(r) => r.typ,
769        }
770    }
771}
772
773impl Codec<'_> for ClientExtension {
774    fn encode(&self, bytes: &mut Vec<u8>) {
775        self.ext_type().encode(bytes);
776
777        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
778        match self {
779            Self::EcPointFormats(r) => r.encode(nested.buf),
780            Self::NamedGroups(r) => r.encode(nested.buf),
781            Self::SignatureAlgorithms(r) => r.encode(nested.buf),
782            Self::ServerName(r) => r.encode(nested.buf),
783            Self::SessionTicket(ClientSessionTicket::Request)
784            | Self::ExtendedMasterSecretRequest
785            | Self::EarlyData => {}
786            Self::SessionTicket(ClientSessionTicket::Offer(r)) => r.encode(nested.buf),
787            Self::Protocols(r) => r.encode(nested.buf),
788            Self::SupportedVersions(r) => r.encode(nested.buf),
789            Self::KeyShare(r) => r.encode(nested.buf),
790            Self::PresharedKeyModes(r) => r.encode(nested.buf),
791            Self::PresharedKey(r) => r.encode(nested.buf),
792            Self::Cookie(r) => r.encode(nested.buf),
793            Self::CertificateStatusRequest(r) => r.encode(nested.buf),
794            Self::ClientCertTypes(r) => r.encode(nested.buf),
795            Self::ServerCertTypes(r) => r.encode(nested.buf),
796            Self::TransportParameters(r) | Self::TransportParametersDraft(r) => {
797                nested.buf.extend_from_slice(r);
798            }
799            Self::CertificateCompressionAlgorithms(r) => r.encode(nested.buf),
800            Self::EncryptedClientHello(r) => r.encode(nested.buf),
801            Self::EncryptedClientHelloOuterExtensions(r) => r.encode(nested.buf),
802            Self::AuthorityNames(r) => r.encode(nested.buf),
803            Self::Unknown(r) => r.encode(nested.buf),
804        }
805    }
806
807    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
808        let typ = ExtensionType::read(r)?;
809        let len = u16::read(r)? as usize;
810        let mut sub = r.sub(len)?;
811
812        let ext = match typ {
813            ExtensionType::ECPointFormats => Self::EcPointFormats(Vec::read(&mut sub)?),
814            ExtensionType::EllipticCurves => Self::NamedGroups(Vec::read(&mut sub)?),
815            ExtensionType::SignatureAlgorithms => Self::SignatureAlgorithms(Vec::read(&mut sub)?),
816            ExtensionType::ServerName => {
817                Self::ServerName(ServerNamePayload::read(&mut sub)?.into_owned())
818            }
819            ExtensionType::SessionTicket => {
820                if sub.any_left() {
821                    let contents = Payload::read(&mut sub).into_owned();
822                    Self::SessionTicket(ClientSessionTicket::Offer(contents))
823                } else {
824                    Self::SessionTicket(ClientSessionTicket::Request)
825                }
826            }
827            ExtensionType::ALProtocolNegotiation => Self::Protocols(Vec::read(&mut sub)?),
828            ExtensionType::SupportedVersions => {
829                Self::SupportedVersions(SupportedProtocolVersions::read(&mut sub)?)
830            }
831            ExtensionType::KeyShare => Self::KeyShare(Vec::read(&mut sub)?),
832            ExtensionType::PSKKeyExchangeModes => Self::PresharedKeyModes(Vec::read(&mut sub)?),
833            ExtensionType::PreSharedKey => Self::PresharedKey(PresharedKeyOffer::read(&mut sub)?),
834            ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?),
835            ExtensionType::ExtendedMasterSecret if !sub.any_left() => {
836                Self::ExtendedMasterSecretRequest
837            }
838            ExtensionType::ClientCertificateType => Self::ClientCertTypes(Vec::read(&mut sub)?),
839            ExtensionType::ServerCertificateType => Self::ServerCertTypes(Vec::read(&mut sub)?),
840            ExtensionType::StatusRequest => {
841                let csr = CertificateStatusRequest::read(&mut sub)?;
842                Self::CertificateStatusRequest(csr)
843            }
844            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
845            ExtensionType::TransportParametersDraft => {
846                Self::TransportParametersDraft(sub.rest().to_vec())
847            }
848            ExtensionType::EarlyData if !sub.any_left() => Self::EarlyData,
849            ExtensionType::CompressCertificate => {
850                Self::CertificateCompressionAlgorithms(Vec::read(&mut sub)?)
851            }
852            ExtensionType::EncryptedClientHelloOuterExtensions => {
853                Self::EncryptedClientHelloOuterExtensions(Vec::read(&mut sub)?)
854            }
855            ExtensionType::CertificateAuthorities => Self::AuthorityNames({
856                let items = Vec::read(&mut sub)?;
857                if items.is_empty() {
858                    return Err(InvalidMessage::IllegalEmptyList("DistinguishedNames"));
859                }
860                items
861            }),
862            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
863        };
864
865        sub.expect_empty("ClientExtension")
866            .map(|_| ext)
867    }
868}
869
870fn trim_hostname_trailing_dot_for_sni(dns_name: &DnsName<'_>) -> DnsName<'static> {
871    let dns_name_str = dns_name.as_ref();
872
873    // RFC6066: "The hostname is represented as a byte string using
874    // ASCII encoding without a trailing dot"
875    if dns_name_str.ends_with('.') {
876        let trimmed = &dns_name_str[0..dns_name_str.len() - 1];
877        DnsName::try_from(trimmed)
878            .unwrap()
879            .to_owned()
880    } else {
881        dns_name.to_owned()
882    }
883}
884
885#[derive(Clone, Debug)]
886pub enum ClientSessionTicket {
887    Request,
888    Offer(Payload<'static>),
889}
890
891#[derive(Clone, Debug)]
892pub enum ServerExtension {
893    EcPointFormats(Vec<ECPointFormat>),
894    ServerNameAck,
895    SessionTicketAck,
896    RenegotiationInfo(PayloadU8),
897    Protocols(SingleProtocolName),
898    KeyShare(KeyShareEntry),
899    PresharedKey(u16),
900    ExtendedMasterSecretAck,
901    CertificateStatusAck,
902    ServerCertType(CertificateType),
903    ClientCertType(CertificateType),
904    SupportedVersions(ProtocolVersion),
905    TransportParameters(Vec<u8>),
906    TransportParametersDraft(Vec<u8>),
907    EarlyData,
908    EncryptedClientHello(ServerEncryptedClientHello),
909    Unknown(UnknownExtension),
910}
911
912impl ServerExtension {
913    pub(crate) fn ext_type(&self) -> ExtensionType {
914        match self {
915            Self::EcPointFormats(_) => ExtensionType::ECPointFormats,
916            Self::ServerNameAck => ExtensionType::ServerName,
917            Self::SessionTicketAck => ExtensionType::SessionTicket,
918            Self::RenegotiationInfo(_) => ExtensionType::RenegotiationInfo,
919            Self::Protocols(_) => ExtensionType::ALProtocolNegotiation,
920            Self::KeyShare(_) => ExtensionType::KeyShare,
921            Self::PresharedKey(_) => ExtensionType::PreSharedKey,
922            Self::ClientCertType(_) => ExtensionType::ClientCertificateType,
923            Self::ServerCertType(_) => ExtensionType::ServerCertificateType,
924            Self::ExtendedMasterSecretAck => ExtensionType::ExtendedMasterSecret,
925            Self::CertificateStatusAck => ExtensionType::StatusRequest,
926            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
927            Self::TransportParameters(_) => ExtensionType::TransportParameters,
928            Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft,
929            Self::EarlyData => ExtensionType::EarlyData,
930            Self::EncryptedClientHello(_) => ExtensionType::EncryptedClientHello,
931            Self::Unknown(r) => r.typ,
932        }
933    }
934}
935
936impl Codec<'_> for ServerExtension {
937    fn encode(&self, bytes: &mut Vec<u8>) {
938        self.ext_type().encode(bytes);
939
940        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
941        match self {
942            Self::EcPointFormats(r) => r.encode(nested.buf),
943            Self::ServerNameAck
944            | Self::SessionTicketAck
945            | Self::ExtendedMasterSecretAck
946            | Self::CertificateStatusAck
947            | Self::EarlyData => {}
948            Self::RenegotiationInfo(r) => r.encode(nested.buf),
949            Self::Protocols(r) => r.encode(nested.buf),
950            Self::KeyShare(r) => r.encode(nested.buf),
951            Self::PresharedKey(r) => r.encode(nested.buf),
952            Self::ClientCertType(r) => r.encode(nested.buf),
953            Self::ServerCertType(r) => r.encode(nested.buf),
954            Self::SupportedVersions(r) => r.encode(nested.buf),
955            Self::TransportParameters(r) | Self::TransportParametersDraft(r) => {
956                nested.buf.extend_from_slice(r);
957            }
958            Self::EncryptedClientHello(r) => r.encode(nested.buf),
959            Self::Unknown(r) => r.encode(nested.buf),
960        }
961    }
962
963    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
964        let typ = ExtensionType::read(r)?;
965        let len = u16::read(r)? as usize;
966        let mut sub = r.sub(len)?;
967
968        let ext = match typ {
969            ExtensionType::ECPointFormats => Self::EcPointFormats(Vec::read(&mut sub)?),
970            ExtensionType::ServerName => Self::ServerNameAck,
971            ExtensionType::SessionTicket => Self::SessionTicketAck,
972            ExtensionType::StatusRequest => Self::CertificateStatusAck,
973            ExtensionType::RenegotiationInfo => Self::RenegotiationInfo(PayloadU8::read(&mut sub)?),
974            ExtensionType::ALProtocolNegotiation => {
975                Self::Protocols(SingleProtocolName::read(&mut sub)?)
976            }
977            ExtensionType::ClientCertificateType => {
978                Self::ClientCertType(CertificateType::read(&mut sub)?)
979            }
980            ExtensionType::ServerCertificateType => {
981                Self::ServerCertType(CertificateType::read(&mut sub)?)
982            }
983            ExtensionType::KeyShare => Self::KeyShare(KeyShareEntry::read(&mut sub)?),
984            ExtensionType::PreSharedKey => Self::PresharedKey(u16::read(&mut sub)?),
985            ExtensionType::ExtendedMasterSecret => Self::ExtendedMasterSecretAck,
986            ExtensionType::SupportedVersions => {
987                Self::SupportedVersions(ProtocolVersion::read(&mut sub)?)
988            }
989            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
990            ExtensionType::TransportParametersDraft => {
991                Self::TransportParametersDraft(sub.rest().to_vec())
992            }
993            ExtensionType::EarlyData => Self::EarlyData,
994            ExtensionType::EncryptedClientHello => {
995                Self::EncryptedClientHello(ServerEncryptedClientHello::read(&mut sub)?)
996            }
997            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
998        };
999
1000        sub.expect_empty("ServerExtension")
1001            .map(|_| ext)
1002    }
1003}
1004
1005impl ServerExtension {
1006    #[cfg(feature = "tls12")]
1007    pub(crate) fn make_empty_renegotiation_info() -> Self {
1008        let empty = Vec::new();
1009        Self::RenegotiationInfo(PayloadU8::new(empty))
1010    }
1011}
1012
1013#[derive(Clone, Debug)]
1014pub struct ClientHelloPayload {
1015    pub client_version: ProtocolVersion,
1016    pub random: Random,
1017    pub session_id: SessionId,
1018    pub cipher_suites: Vec<CipherSuite>,
1019    pub compression_methods: Vec<Compression>,
1020    pub extensions: Vec<ClientExtension>,
1021}
1022
1023impl Codec<'_> for ClientHelloPayload {
1024    fn encode(&self, bytes: &mut Vec<u8>) {
1025        self.payload_encode(bytes, Encoding::Standard)
1026    }
1027
1028    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1029        let mut ret = Self {
1030            client_version: ProtocolVersion::read(r)?,
1031            random: Random::read(r)?,
1032            session_id: SessionId::read(r)?,
1033            cipher_suites: Vec::read(r)?,
1034            compression_methods: Vec::read(r)?,
1035            extensions: Vec::new(),
1036        };
1037
1038        if r.any_left() {
1039            ret.extensions = Vec::read(r)?;
1040        }
1041
1042        match (r.any_left(), ret.extensions.is_empty()) {
1043            (true, _) => Err(InvalidMessage::TrailingData("ClientHelloPayload")),
1044            (_, true) => Err(InvalidMessage::MissingData("ClientHelloPayload")),
1045            _ => Ok(ret),
1046        }
1047    }
1048}
1049
1050/// RFC8446: `CipherSuite cipher_suites<2..2^16-2>;`
1051impl TlsListElement for CipherSuite {
1052    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
1053        empty_error: InvalidMessage::IllegalEmptyList("CipherSuites"),
1054    };
1055}
1056
1057/// RFC5246: `CompressionMethod compression_methods<1..2^8-1>;`
1058impl TlsListElement for Compression {
1059    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
1060        empty_error: InvalidMessage::IllegalEmptyList("Compressions"),
1061    };
1062}
1063
1064impl TlsListElement for ClientExtension {
1065    const SIZE_LEN: ListLength = ListLength::U16;
1066}
1067
1068/// draft-ietf-tls-esni-17: `ExtensionType OuterExtensions<2..254>;`
1069impl TlsListElement for ExtensionType {
1070    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
1071        empty_error: InvalidMessage::IllegalEmptyList("ExtensionTypes"),
1072    };
1073}
1074
1075impl ClientHelloPayload {
1076    pub(crate) fn ech_inner_encoding(&self, to_compress: Vec<ExtensionType>) -> Vec<u8> {
1077        let mut bytes = Vec::new();
1078        self.payload_encode(&mut bytes, Encoding::EchInnerHello { to_compress });
1079        bytes
1080    }
1081
1082    pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) {
1083        self.client_version.encode(bytes);
1084        self.random.encode(bytes);
1085
1086        match purpose {
1087            // SessionID is required to be empty in the encoded inner client hello.
1088            Encoding::EchInnerHello { .. } => SessionId::empty().encode(bytes),
1089            _ => self.session_id.encode(bytes),
1090        }
1091
1092        self.cipher_suites.encode(bytes);
1093        self.compression_methods.encode(bytes);
1094
1095        let to_compress = match purpose {
1096            // Compressed extensions must be replaced in the encoded inner client hello.
1097            Encoding::EchInnerHello { to_compress } if !to_compress.is_empty() => to_compress,
1098            _ => {
1099                if !self.extensions.is_empty() {
1100                    self.extensions.encode(bytes);
1101                }
1102                return;
1103            }
1104        };
1105
1106        // Safety: not empty check in match guard.
1107        let first_compressed_type = *to_compress.first().unwrap();
1108
1109        // Compressed extensions are in a contiguous range and must be replaced
1110        // with a marker extension.
1111        let compressed_start_idx = self
1112            .extensions
1113            .iter()
1114            .position(|ext| ext.ext_type() == first_compressed_type);
1115        let compressed_end_idx = compressed_start_idx.map(|start| start + to_compress.len());
1116        let marker_ext = ClientExtension::EncryptedClientHelloOuterExtensions(to_compress);
1117
1118        let exts = self
1119            .extensions
1120            .iter()
1121            .enumerate()
1122            .filter_map(|(i, ext)| {
1123                if Some(i) == compressed_start_idx {
1124                    Some(&marker_ext)
1125                } else if Some(i) > compressed_start_idx && Some(i) < compressed_end_idx {
1126                    None
1127                } else {
1128                    Some(ext)
1129                }
1130            });
1131
1132        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1133        for ext in exts {
1134            ext.encode(nested.buf);
1135        }
1136    }
1137
1138    /// Returns true if there is more than one extension of a given
1139    /// type.
1140    pub(crate) fn has_duplicate_extension(&self) -> bool {
1141        has_duplicates::<_, _, u16>(
1142            self.extensions
1143                .iter()
1144                .map(|ext| ext.ext_type()),
1145        )
1146    }
1147
1148    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&ClientExtension> {
1149        self.extensions
1150            .iter()
1151            .find(|x| x.ext_type() == ext)
1152    }
1153
1154    pub(crate) fn sni_extension(&self) -> Option<&ServerNamePayload<'_>> {
1155        let ext = self.find_extension(ExtensionType::ServerName)?;
1156        match ext {
1157            ClientExtension::ServerName(req) => Some(req),
1158            _ => None,
1159        }
1160    }
1161
1162    pub fn sigalgs_extension(&self) -> Option<&[SignatureScheme]> {
1163        let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?;
1164        match ext {
1165            ClientExtension::SignatureAlgorithms(req) => Some(req),
1166            _ => None,
1167        }
1168    }
1169
1170    pub(crate) fn namedgroups_extension(&self) -> Option<&[NamedGroup]> {
1171        let ext = self.find_extension(ExtensionType::EllipticCurves)?;
1172        match ext {
1173            ClientExtension::NamedGroups(req) => Some(req),
1174            _ => None,
1175        }
1176    }
1177
1178    #[cfg(feature = "tls12")]
1179    pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
1180        let ext = self.find_extension(ExtensionType::ECPointFormats)?;
1181        match ext {
1182            ClientExtension::EcPointFormats(req) => Some(req),
1183            _ => None,
1184        }
1185    }
1186
1187    pub(crate) fn server_certificate_extension(&self) -> Option<&[CertificateType]> {
1188        let ext = self.find_extension(ExtensionType::ServerCertificateType)?;
1189        match ext {
1190            ClientExtension::ServerCertTypes(req) => Some(req),
1191            _ => None,
1192        }
1193    }
1194
1195    pub(crate) fn client_certificate_extension(&self) -> Option<&[CertificateType]> {
1196        let ext = self.find_extension(ExtensionType::ClientCertificateType)?;
1197        match ext {
1198            ClientExtension::ClientCertTypes(req) => Some(req),
1199            _ => None,
1200        }
1201    }
1202
1203    pub(crate) fn alpn_extension(&self) -> Option<&Vec<ProtocolName>> {
1204        let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?;
1205        match ext {
1206            ClientExtension::Protocols(req) => Some(req),
1207            _ => None,
1208        }
1209    }
1210
1211    pub(crate) fn quic_params_extension(&self) -> Option<Vec<u8>> {
1212        let ext = self
1213            .find_extension(ExtensionType::TransportParameters)
1214            .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?;
1215        match ext {
1216            ClientExtension::TransportParameters(bytes)
1217            | ClientExtension::TransportParametersDraft(bytes) => Some(bytes.to_vec()),
1218            _ => None,
1219        }
1220    }
1221
1222    #[cfg(feature = "tls12")]
1223    pub(crate) fn ticket_extension(&self) -> Option<&ClientExtension> {
1224        self.find_extension(ExtensionType::SessionTicket)
1225    }
1226
1227    pub(crate) fn versions_extension(&self) -> Option<SupportedProtocolVersions> {
1228        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1229        match ext {
1230            ClientExtension::SupportedVersions(vers) => Some(*vers),
1231            _ => None,
1232        }
1233    }
1234
1235    pub fn keyshare_extension(&self) -> Option<&[KeyShareEntry]> {
1236        let ext = self.find_extension(ExtensionType::KeyShare)?;
1237        match ext {
1238            ClientExtension::KeyShare(shares) => Some(shares),
1239            _ => None,
1240        }
1241    }
1242
1243    pub(crate) fn has_keyshare_extension_with_duplicates(&self) -> bool {
1244        self.keyshare_extension()
1245            .map(|entries| {
1246                has_duplicates::<_, _, u16>(
1247                    entries
1248                        .iter()
1249                        .map(|kse| u16::from(kse.group)),
1250                )
1251            })
1252            .unwrap_or_default()
1253    }
1254
1255    pub(crate) fn psk(&self) -> Option<&PresharedKeyOffer> {
1256        let ext = self.find_extension(ExtensionType::PreSharedKey)?;
1257        match ext {
1258            ClientExtension::PresharedKey(psk) => Some(psk),
1259            _ => None,
1260        }
1261    }
1262
1263    pub(crate) fn check_psk_ext_is_last(&self) -> bool {
1264        self.extensions
1265            .last()
1266            .is_some_and(|ext| ext.ext_type() == ExtensionType::PreSharedKey)
1267    }
1268
1269    pub(crate) fn psk_modes(&self) -> Option<&[PskKeyExchangeMode]> {
1270        let ext = self.find_extension(ExtensionType::PSKKeyExchangeModes)?;
1271        match ext {
1272            ClientExtension::PresharedKeyModes(psk_modes) => Some(psk_modes),
1273            _ => None,
1274        }
1275    }
1276
1277    pub(crate) fn set_psk_binder(&mut self, binder: impl Into<Vec<u8>>) {
1278        let last_extension = self.extensions.last_mut();
1279        if let Some(ClientExtension::PresharedKey(offer)) = last_extension {
1280            offer.binders[0] = PresharedKeyBinder::from(binder.into());
1281        }
1282    }
1283
1284    /// Returns the PSK binders.
1285    ///
1286    /// Only useful when "filling in" the binders for an external
1287    /// PSK.
1288    pub(crate) fn psk_binders_mut(&mut self) -> &mut [PresharedKeyBinder] {
1289        // The "pre_shared_key" extension is always last.
1290        let last_extension = self.extensions.last_mut();
1291        if let Some(ClientExtension::PresharedKey(offer)) = last_extension {
1292            &mut offer.binders
1293        } else {
1294            &mut []
1295        }
1296    }
1297
1298    #[cfg(feature = "tls12")]
1299    pub(crate) fn ems_support_offered(&self) -> bool {
1300        self.find_extension(ExtensionType::ExtendedMasterSecret)
1301            .is_some()
1302    }
1303
1304    pub(crate) fn early_data_extension_offered(&self) -> bool {
1305        self.find_extension(ExtensionType::EarlyData)
1306            .is_some()
1307    }
1308
1309    pub(crate) fn certificate_compression_extension(
1310        &self,
1311    ) -> Option<&[CertificateCompressionAlgorithm]> {
1312        let ext = self.find_extension(ExtensionType::CompressCertificate)?;
1313        match ext {
1314            ClientExtension::CertificateCompressionAlgorithms(algs) => Some(algs),
1315            _ => None,
1316        }
1317    }
1318
1319    pub(crate) fn has_certificate_compression_extension_with_duplicates(&self) -> bool {
1320        if let Some(algs) = self.certificate_compression_extension() {
1321            has_duplicates::<_, _, u16>(algs.iter().cloned())
1322        } else {
1323            false
1324        }
1325    }
1326
1327    pub(crate) fn certificate_authorities_extension(&self) -> Option<&[DistinguishedName]> {
1328        match self.find_extension(ExtensionType::CertificateAuthorities)? {
1329            ClientExtension::AuthorityNames(ext) => Some(ext),
1330            _ => unreachable!("extension type checked"),
1331        }
1332    }
1333}
1334
1335#[derive(Clone, Debug)]
1336pub(crate) enum HelloRetryExtension {
1337    KeyShare(NamedGroup),
1338    Cookie(PayloadU16<NonEmpty>),
1339    SupportedVersions(ProtocolVersion),
1340    EchHelloRetryRequest(Vec<u8>),
1341    Unknown(UnknownExtension),
1342}
1343
1344impl HelloRetryExtension {
1345    pub(crate) fn ext_type(&self) -> ExtensionType {
1346        match self {
1347            Self::KeyShare(_) => ExtensionType::KeyShare,
1348            Self::Cookie(_) => ExtensionType::Cookie,
1349            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
1350            Self::EchHelloRetryRequest(_) => ExtensionType::EncryptedClientHello,
1351            Self::Unknown(r) => r.typ,
1352        }
1353    }
1354}
1355
1356impl Codec<'_> for HelloRetryExtension {
1357    fn encode(&self, bytes: &mut Vec<u8>) {
1358        self.ext_type().encode(bytes);
1359
1360        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1361        match self {
1362            Self::KeyShare(r) => r.encode(nested.buf),
1363            Self::Cookie(r) => r.encode(nested.buf),
1364            Self::SupportedVersions(r) => r.encode(nested.buf),
1365            Self::EchHelloRetryRequest(r) => {
1366                nested.buf.extend_from_slice(r);
1367            }
1368            Self::Unknown(r) => r.encode(nested.buf),
1369        }
1370    }
1371
1372    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1373        let typ = ExtensionType::read(r)?;
1374        let len = u16::read(r)? as usize;
1375        let mut sub = r.sub(len)?;
1376
1377        let ext = match typ {
1378            ExtensionType::KeyShare => Self::KeyShare(NamedGroup::read(&mut sub)?),
1379            ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?),
1380            ExtensionType::SupportedVersions => {
1381                Self::SupportedVersions(ProtocolVersion::read(&mut sub)?)
1382            }
1383            ExtensionType::EncryptedClientHello => Self::EchHelloRetryRequest(sub.rest().to_vec()),
1384            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1385        };
1386
1387        sub.expect_empty("HelloRetryExtension")
1388            .map(|_| ext)
1389    }
1390}
1391
1392impl TlsListElement for HelloRetryExtension {
1393    const SIZE_LEN: ListLength = ListLength::U16;
1394}
1395
1396#[derive(Clone, Debug)]
1397pub struct HelloRetryRequest {
1398    pub(crate) legacy_version: ProtocolVersion,
1399    pub session_id: SessionId,
1400    pub(crate) cipher_suite: CipherSuite,
1401    pub(crate) extensions: Vec<HelloRetryExtension>,
1402}
1403
1404impl Codec<'_> for HelloRetryRequest {
1405    fn encode(&self, bytes: &mut Vec<u8>) {
1406        self.payload_encode(bytes, Encoding::Standard)
1407    }
1408
1409    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1410        let session_id = SessionId::read(r)?;
1411        let cipher_suite = CipherSuite::read(r)?;
1412        let compression = Compression::read(r)?;
1413
1414        if compression != Compression::Null {
1415            return Err(InvalidMessage::UnsupportedCompression);
1416        }
1417
1418        Ok(Self {
1419            legacy_version: ProtocolVersion::Unknown(0),
1420            session_id,
1421            cipher_suite,
1422            extensions: Vec::read(r)?,
1423        })
1424    }
1425}
1426
1427impl HelloRetryRequest {
1428    /// Returns true if there is more than one extension of a given
1429    /// type.
1430    pub(crate) fn has_duplicate_extension(&self) -> bool {
1431        has_duplicates::<_, _, u16>(
1432            self.extensions
1433                .iter()
1434                .map(|ext| ext.ext_type()),
1435        )
1436    }
1437
1438    pub(crate) fn has_unknown_extension(&self) -> bool {
1439        self.extensions.iter().any(|ext| {
1440            ext.ext_type() != ExtensionType::KeyShare
1441                && ext.ext_type() != ExtensionType::SupportedVersions
1442                && ext.ext_type() != ExtensionType::Cookie
1443                && ext.ext_type() != ExtensionType::EncryptedClientHello
1444        })
1445    }
1446
1447    fn find_extension(&self, ext: ExtensionType) -> Option<&HelloRetryExtension> {
1448        self.extensions
1449            .iter()
1450            .find(|x| x.ext_type() == ext)
1451    }
1452
1453    pub fn requested_key_share_group(&self) -> Option<NamedGroup> {
1454        let ext = self.find_extension(ExtensionType::KeyShare)?;
1455        match ext {
1456            HelloRetryExtension::KeyShare(grp) => Some(*grp),
1457            _ => None,
1458        }
1459    }
1460
1461    pub(crate) fn cookie(&self) -> Option<&PayloadU16<NonEmpty>> {
1462        let ext = self.find_extension(ExtensionType::Cookie)?;
1463        match ext {
1464            HelloRetryExtension::Cookie(ck) => Some(ck),
1465            _ => None,
1466        }
1467    }
1468
1469    pub(crate) fn supported_versions(&self) -> Option<ProtocolVersion> {
1470        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1471        match ext {
1472            HelloRetryExtension::SupportedVersions(ver) => Some(*ver),
1473            _ => None,
1474        }
1475    }
1476
1477    pub(crate) fn ech(&self) -> Option<&Vec<u8>> {
1478        let ext = self.find_extension(ExtensionType::EncryptedClientHello)?;
1479        match ext {
1480            HelloRetryExtension::EchHelloRetryRequest(ech) => Some(ech),
1481            _ => None,
1482        }
1483    }
1484
1485    fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) {
1486        self.legacy_version.encode(bytes);
1487        HELLO_RETRY_REQUEST_RANDOM.encode(bytes);
1488        self.session_id.encode(bytes);
1489        self.cipher_suite.encode(bytes);
1490        Compression::Null.encode(bytes);
1491
1492        match purpose {
1493            // For the purpose of ECH confirmation, the Encrypted Client Hello extension
1494            // must have its payload replaced by 8 zero bytes.
1495            //
1496            // See draft-ietf-tls-esni-18 7.2.1:
1497            // <https://datatracker.ietf.org/doc/html/draft-ietf-tls-esni-18#name-sending-helloretryrequest-2>
1498            Encoding::EchConfirmation => {
1499                let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1500                for ext in &self.extensions {
1501                    match ext.ext_type() {
1502                        ExtensionType::EncryptedClientHello => {
1503                            HelloRetryExtension::EchHelloRetryRequest(vec![0u8; 8])
1504                                .encode(extensions.buf);
1505                        }
1506                        _ => {
1507                            ext.encode(extensions.buf);
1508                        }
1509                    }
1510                }
1511            }
1512            _ => {
1513                self.extensions.encode(bytes);
1514            }
1515        }
1516    }
1517}
1518
1519#[derive(Clone, Debug)]
1520pub struct ServerHelloPayload {
1521    pub extensions: Vec<ServerExtension>,
1522    pub(crate) legacy_version: ProtocolVersion,
1523    pub(crate) random: Random,
1524    pub(crate) session_id: SessionId,
1525    pub(crate) cipher_suite: CipherSuite,
1526    pub(crate) compression_method: Compression,
1527}
1528
1529impl Codec<'_> for ServerHelloPayload {
1530    fn encode(&self, bytes: &mut Vec<u8>) {
1531        self.payload_encode(bytes, Encoding::Standard)
1532    }
1533
1534    // minus version and random, which have already been read.
1535    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1536        let session_id = SessionId::read(r)?;
1537        let suite = CipherSuite::read(r)?;
1538        let compression = Compression::read(r)?;
1539
1540        // RFC5246:
1541        // "The presence of extensions can be detected by determining whether
1542        //  there are bytes following the compression_method field at the end of
1543        //  the ServerHello."
1544        let extensions = if r.any_left() { Vec::read(r)? } else { vec![] };
1545
1546        let ret = Self {
1547            legacy_version: ProtocolVersion::Unknown(0),
1548            random: ZERO_RANDOM,
1549            session_id,
1550            cipher_suite: suite,
1551            compression_method: compression,
1552            extensions,
1553        };
1554
1555        r.expect_empty("ServerHelloPayload")
1556            .map(|_| ret)
1557    }
1558}
1559
1560impl HasServerExtensions for ServerHelloPayload {
1561    fn extensions(&self) -> &[ServerExtension] {
1562        &self.extensions
1563    }
1564}
1565
1566impl ServerHelloPayload {
1567    pub(crate) fn key_share(&self) -> Option<&KeyShareEntry> {
1568        let ext = self.find_extension(ExtensionType::KeyShare)?;
1569        match ext {
1570            ServerExtension::KeyShare(share) => Some(share),
1571            _ => None,
1572        }
1573    }
1574
1575    pub(crate) fn psk_index(&self) -> Option<u16> {
1576        let ext = self.find_extension(ExtensionType::PreSharedKey)?;
1577        match ext {
1578            ServerExtension::PresharedKey(index) => Some(*index),
1579            _ => None,
1580        }
1581    }
1582
1583    pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
1584        let ext = self.find_extension(ExtensionType::ECPointFormats)?;
1585        match ext {
1586            ServerExtension::EcPointFormats(fmts) => Some(fmts),
1587            _ => None,
1588        }
1589    }
1590
1591    #[cfg(feature = "tls12")]
1592    pub(crate) fn ems_support_acked(&self) -> bool {
1593        self.find_extension(ExtensionType::ExtendedMasterSecret)
1594            .is_some()
1595    }
1596
1597    pub(crate) fn supported_versions(&self) -> Option<ProtocolVersion> {
1598        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1599        match ext {
1600            ServerExtension::SupportedVersions(vers) => Some(*vers),
1601            _ => None,
1602        }
1603    }
1604
1605    fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) {
1606        self.legacy_version.encode(bytes);
1607
1608        match encoding {
1609            // When encoding a ServerHello for ECH confirmation, the random value
1610            // has the last 8 bytes zeroed out.
1611            Encoding::EchConfirmation => {
1612                // Indexing safety: self.random is 32 bytes long by definition.
1613                let rand_vec = self.random.get_encoding();
1614                bytes.extend_from_slice(&rand_vec.as_slice()[..24]);
1615                bytes.extend_from_slice(&[0u8; 8]);
1616            }
1617            _ => self.random.encode(bytes),
1618        }
1619
1620        self.session_id.encode(bytes);
1621        self.cipher_suite.encode(bytes);
1622        self.compression_method.encode(bytes);
1623
1624        if !self.extensions.is_empty() {
1625            self.extensions.encode(bytes);
1626        }
1627    }
1628}
1629
1630#[derive(Clone, Default, Debug)]
1631pub struct CertificateChain<'a>(pub Vec<CertificateDer<'a>>);
1632
1633impl CertificateChain<'_> {
1634    pub(crate) fn into_owned(self) -> CertificateChain<'static> {
1635        CertificateChain(
1636            self.0
1637                .into_iter()
1638                .map(|c| c.into_owned())
1639                .collect(),
1640        )
1641    }
1642}
1643
1644impl<'a> Codec<'a> for CertificateChain<'a> {
1645    fn encode(&self, bytes: &mut Vec<u8>) {
1646        Vec::encode(&self.0, bytes)
1647    }
1648
1649    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1650        Vec::read(r).map(Self)
1651    }
1652}
1653
1654impl<'a> Deref for CertificateChain<'a> {
1655    type Target = [CertificateDer<'a>];
1656
1657    fn deref(&self) -> &[CertificateDer<'a>] {
1658        &self.0
1659    }
1660}
1661
1662impl TlsListElement for CertificateDer<'_> {
1663    const SIZE_LEN: ListLength = ListLength::U24 {
1664        max: CERTIFICATE_MAX_SIZE_LIMIT,
1665        error: InvalidMessage::CertificatePayloadTooLarge,
1666    };
1667}
1668
/// Upper bound (64 KiB) used for u24-prefixed certificate lists.
///
/// TLS itself allows up to 16MB for a handshake message and for a single
/// certificate; this is contracted to 64KB to limit how much memory
/// allocation a peer can directly control.
pub(crate) const CERTIFICATE_MAX_SIZE_LIMIT: usize = 64 * 1024;
1675
1676#[derive(Debug)]
1677pub(crate) enum CertificateExtension<'a> {
1678    CertificateStatus(CertificateStatus<'a>),
1679    Unknown(UnknownExtension),
1680}
1681
1682impl CertificateExtension<'_> {
1683    pub(crate) fn ext_type(&self) -> ExtensionType {
1684        match self {
1685            Self::CertificateStatus(_) => ExtensionType::StatusRequest,
1686            Self::Unknown(r) => r.typ,
1687        }
1688    }
1689
1690    pub(crate) fn cert_status(&self) -> Option<&[u8]> {
1691        match self {
1692            Self::CertificateStatus(cs) => Some(cs.ocsp_response.0.bytes()),
1693            _ => None,
1694        }
1695    }
1696
1697    pub(crate) fn into_owned(self) -> CertificateExtension<'static> {
1698        match self {
1699            Self::CertificateStatus(st) => CertificateExtension::CertificateStatus(st.into_owned()),
1700            Self::Unknown(unk) => CertificateExtension::Unknown(unk),
1701        }
1702    }
1703}
1704
1705impl<'a> Codec<'a> for CertificateExtension<'a> {
1706    fn encode(&self, bytes: &mut Vec<u8>) {
1707        self.ext_type().encode(bytes);
1708
1709        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1710        match self {
1711            Self::CertificateStatus(r) => r.encode(nested.buf),
1712            Self::Unknown(r) => r.encode(nested.buf),
1713        }
1714    }
1715
1716    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1717        let typ = ExtensionType::read(r)?;
1718        let len = u16::read(r)? as usize;
1719        let mut sub = r.sub(len)?;
1720
1721        let ext = match typ {
1722            ExtensionType::StatusRequest => {
1723                let st = CertificateStatus::read(&mut sub)?;
1724                Self::CertificateStatus(st)
1725            }
1726            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1727        };
1728
1729        sub.expect_empty("CertificateExtension")
1730            .map(|_| ext)
1731    }
1732}
1733
1734impl TlsListElement for CertificateExtension<'_> {
1735    const SIZE_LEN: ListLength = ListLength::U16;
1736}
1737
1738#[derive(Debug)]
1739pub(crate) struct CertificateEntry<'a> {
1740    pub(crate) cert: CertificateDer<'a>,
1741    pub(crate) exts: Vec<CertificateExtension<'a>>,
1742}
1743
1744impl<'a> Codec<'a> for CertificateEntry<'a> {
1745    fn encode(&self, bytes: &mut Vec<u8>) {
1746        self.cert.encode(bytes);
1747        self.exts.encode(bytes);
1748    }
1749
1750    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1751        Ok(Self {
1752            cert: CertificateDer::read(r)?,
1753            exts: Vec::read(r)?,
1754        })
1755    }
1756}
1757
1758impl<'a> CertificateEntry<'a> {
1759    pub(crate) fn new(cert: CertificateDer<'a>) -> Self {
1760        Self {
1761            cert,
1762            exts: Vec::new(),
1763        }
1764    }
1765
1766    pub(crate) fn into_owned(self) -> CertificateEntry<'static> {
1767        CertificateEntry {
1768            cert: self.cert.into_owned(),
1769            exts: self
1770                .exts
1771                .into_iter()
1772                .map(CertificateExtension::into_owned)
1773                .collect(),
1774        }
1775    }
1776
1777    pub(crate) fn has_duplicate_extension(&self) -> bool {
1778        has_duplicates::<_, _, u16>(
1779            self.exts
1780                .iter()
1781                .map(|ext| ext.ext_type()),
1782        )
1783    }
1784
1785    pub(crate) fn has_unknown_extension(&self) -> bool {
1786        self.exts
1787            .iter()
1788            .any(|ext| ext.ext_type() != ExtensionType::StatusRequest)
1789    }
1790
1791    pub(crate) fn ocsp_response(&self) -> Option<&[u8]> {
1792        self.exts
1793            .iter()
1794            .find(|ext| ext.ext_type() == ExtensionType::StatusRequest)
1795            .and_then(CertificateExtension::cert_status)
1796    }
1797}
1798
1799impl TlsListElement for CertificateEntry<'_> {
1800    const SIZE_LEN: ListLength = ListLength::U24 {
1801        max: CERTIFICATE_MAX_SIZE_LIMIT,
1802        error: InvalidMessage::CertificatePayloadTooLarge,
1803    };
1804}
1805
1806#[derive(Debug)]
1807pub struct CertificatePayloadTls13<'a> {
1808    pub(crate) context: PayloadU8,
1809    pub(crate) entries: Vec<CertificateEntry<'a>>,
1810}
1811
1812impl<'a> Codec<'a> for CertificatePayloadTls13<'a> {
1813    fn encode(&self, bytes: &mut Vec<u8>) {
1814        self.context.encode(bytes);
1815        self.entries.encode(bytes);
1816    }
1817
1818    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1819        Ok(Self {
1820            context: PayloadU8::read(r)?,
1821            entries: Vec::read(r)?,
1822        })
1823    }
1824}
1825
1826impl<'a> CertificatePayloadTls13<'a> {
1827    pub(crate) fn new(
1828        certs: impl Iterator<Item = &'a CertificateDer<'a>>,
1829        ocsp_response: Option<&'a [u8]>,
1830    ) -> Self {
1831        Self {
1832            context: PayloadU8::empty(),
1833            entries: certs
1834                // zip certificate iterator with `ocsp_response` followed by
1835                // an infinite-length iterator of `None`.
1836                .zip(
1837                    ocsp_response
1838                        .into_iter()
1839                        .map(Some)
1840                        .chain(iter::repeat(None)),
1841                )
1842                .map(|(cert, ocsp)| {
1843                    let mut e = CertificateEntry::new(cert.clone());
1844                    if let Some(ocsp) = ocsp {
1845                        e.exts
1846                            .push(CertificateExtension::CertificateStatus(
1847                                CertificateStatus::new(ocsp),
1848                            ));
1849                    }
1850                    e
1851                })
1852                .collect(),
1853        }
1854    }
1855
1856    pub(crate) fn into_owned(self) -> CertificatePayloadTls13<'static> {
1857        CertificatePayloadTls13 {
1858            context: self.context,
1859            entries: self
1860                .entries
1861                .into_iter()
1862                .map(CertificateEntry::into_owned)
1863                .collect(),
1864        }
1865    }
1866
1867    pub(crate) fn any_entry_has_duplicate_extension(&self) -> bool {
1868        for entry in &self.entries {
1869            if entry.has_duplicate_extension() {
1870                return true;
1871            }
1872        }
1873
1874        false
1875    }
1876
1877    pub(crate) fn any_entry_has_unknown_extension(&self) -> bool {
1878        for entry in &self.entries {
1879            if entry.has_unknown_extension() {
1880                return true;
1881            }
1882        }
1883
1884        false
1885    }
1886
1887    pub(crate) fn any_entry_has_extension(&self) -> bool {
1888        for entry in &self.entries {
1889            if !entry.exts.is_empty() {
1890                return true;
1891            }
1892        }
1893
1894        false
1895    }
1896
1897    pub(crate) fn end_entity_ocsp(&self) -> &[u8] {
1898        self.entries
1899            .first()
1900            .and_then(CertificateEntry::ocsp_response)
1901            .unwrap_or_default()
1902    }
1903
1904    pub(crate) fn into_certificate_chain(self) -> CertificateChain<'a> {
1905        CertificateChain(
1906            self.entries
1907                .into_iter()
1908                .map(|e| e.cert)
1909                .collect(),
1910        )
1911    }
1912}
1913
/// The key exchange mechanisms this crate knows how to perform.
#[derive(Debug, PartialEq, Clone, Copy)]
#[non_exhaustive]
pub enum KeyExchangeAlgorithm {
    /// Finite-field Diffie-Hellman, restricted to the well-known groups of [RFC 7919].
    ///
    /// [RFC 7919]: https://datatracker.ietf.org/doc/html/rfc7919
    DHE,
    /// Elliptic-curve Diffie-Hellman.
    ECDHE,
}

/// Every `KeyExchangeAlgorithm` variant, with ECDHE listed first.
pub(crate) static ALL_KEY_EXCHANGE_ALGORITHMS: &[KeyExchangeAlgorithm] = &[
    KeyExchangeAlgorithm::ECDHE,
    KeyExchangeAlgorithm::DHE,
];
1928
// Only named curves are supported. Explicit (arbitrary) curve parameters are
// rejected on purpose: accepting them would add substantial, unnecessary
// attack surface.
1932#[derive(Debug)]
1933pub(crate) struct EcParameters {
1934    pub(crate) curve_type: ECCurveType,
1935    pub(crate) named_group: NamedGroup,
1936}
1937
1938impl Codec<'_> for EcParameters {
1939    fn encode(&self, bytes: &mut Vec<u8>) {
1940        self.curve_type.encode(bytes);
1941        self.named_group.encode(bytes);
1942    }
1943
1944    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1945        let ct = ECCurveType::read(r)?;
1946        if ct != ECCurveType::NamedCurve {
1947            return Err(InvalidMessage::UnsupportedCurveType);
1948        }
1949
1950        let grp = NamedGroup::read(r)?;
1951
1952        Ok(Self {
1953            curve_type: ct,
1954            named_group: grp,
1955        })
1956    }
1957}
1958
/// Decoding for key exchange messages whose layout depends on the negotiated
/// `KeyExchangeAlgorithm`.
#[cfg(feature = "tls12")]
pub(crate) trait KxDecode<'a>: Sized + fmt::Debug {
    /// Parses one message from `r`, interpreting it according to `algo`.
    fn decode(r: &mut Reader<'a>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage>;
}
1964
/// The client's public key share in a TLS 1.2 `ClientKeyExchange`, shaped by
/// the negotiated key exchange algorithm.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) enum ClientKeyExchangeParams {
    Ecdh(ClientEcdhParams),
    Dh(ClientDhParams),
}

#[cfg(feature = "tls12")]
impl ClientKeyExchangeParams {
    /// The client's public value (EC point or DH `Yc`) as raw bytes.
    pub(crate) fn pub_key(&self) -> &[u8] {
        let key: &[u8] = match self {
            Self::Dh(dh) => &dh.public.0,
            Self::Ecdh(ecdh) => &ecdh.public.0,
        };
        key
    }

    /// Appends the wire encoding of the inner parameters to `buf`.
    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
        match self {
            Self::Dh(dh) => dh.encode(buf),
            Self::Ecdh(ecdh) => ecdh.encode(buf),
        }
    }
}

#[cfg(feature = "tls12")]
impl KxDecode<'_> for ClientKeyExchangeParams {
    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
        match algo {
            KeyExchangeAlgorithm::ECDHE => ClientEcdhParams::read(r).map(Self::Ecdh),
            KeyExchangeAlgorithm::DHE => ClientDhParams::read(r).map(Self::Dh),
        }
    }
}
1999
/// The client's ephemeral EC point in a TLS 1.2 ECDHE `ClientKeyExchange`.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) struct ClientEcdhParams {
    /// RFC4492: `opaque point <1..2^8-1>;`
    pub(crate) public: PayloadU8<NonEmpty>,
}

#[cfg(feature = "tls12")]
impl Codec<'_> for ClientEcdhParams {
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.public.encode(bytes);
    }

    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        PayloadU8::read(r).map(|public| Self { public })
    }
}

/// The client's public value in a TLS 1.2 DHE `ClientKeyExchange`.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) struct ClientDhParams {
    /// RFC5246: `opaque dh_Yc<1..2^16-1>;`
    pub(crate) public: PayloadU16<NonEmpty>,
}

#[cfg(feature = "tls12")]
impl Codec<'_> for ClientDhParams {
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.public.encode(bytes);
    }

    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        PayloadU16::read(r).map(|public| Self { public })
    }
}
2038
2039#[derive(Debug)]
2040pub(crate) struct ServerEcdhParams {
2041    pub(crate) curve_params: EcParameters,
2042    /// RFC4492: `opaque point <1..2^8-1>;`
2043    pub(crate) public: PayloadU8<NonEmpty>,
2044}
2045
2046impl ServerEcdhParams {
2047    #[cfg(feature = "tls12")]
2048    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
2049        Self {
2050            curve_params: EcParameters {
2051                curve_type: ECCurveType::NamedCurve,
2052                named_group: kx.group(),
2053            },
2054            public: PayloadU8::new(kx.pub_key().to_vec()),
2055        }
2056    }
2057}
2058
2059impl Codec<'_> for ServerEcdhParams {
2060    fn encode(&self, bytes: &mut Vec<u8>) {
2061        self.curve_params.encode(bytes);
2062        self.public.encode(bytes);
2063    }
2064
2065    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2066        let cp = EcParameters::read(r)?;
2067        let pb = PayloadU8::read(r)?;
2068
2069        Ok(Self {
2070            curve_params: cp,
2071            public: pb,
2072        })
2073    }
2074}
2075
2076#[derive(Debug)]
2077#[allow(non_snake_case)]
2078pub(crate) struct ServerDhParams {
2079    /// RFC5246: `opaque dh_p<1..2^16-1>;`
2080    pub(crate) dh_p: PayloadU16<NonEmpty>,
2081    /// RFC5246: `opaque dh_g<1..2^16-1>;`
2082    pub(crate) dh_g: PayloadU16<NonEmpty>,
2083    /// RFC5246: `opaque dh_Ys<1..2^16-1>;`
2084    pub(crate) dh_Ys: PayloadU16<NonEmpty>,
2085}
2086
2087impl ServerDhParams {
2088    #[cfg(feature = "tls12")]
2089    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
2090        let Some(params) = kx.ffdhe_group() else {
2091            panic!("invalid NamedGroup for DHE key exchange: {:?}", kx.group());
2092        };
2093
2094        Self {
2095            dh_p: PayloadU16::new(params.p.to_vec()),
2096            dh_g: PayloadU16::new(params.g.to_vec()),
2097            dh_Ys: PayloadU16::new(kx.pub_key().to_vec()),
2098        }
2099    }
2100
2101    #[cfg(feature = "tls12")]
2102    pub(crate) fn as_ffdhe_group(&self) -> FfdheGroup<'_> {
2103        FfdheGroup::from_params_trimming_leading_zeros(&self.dh_p.0, &self.dh_g.0)
2104    }
2105}
2106
2107impl Codec<'_> for ServerDhParams {
2108    fn encode(&self, bytes: &mut Vec<u8>) {
2109        self.dh_p.encode(bytes);
2110        self.dh_g.encode(bytes);
2111        self.dh_Ys.encode(bytes);
2112    }
2113
2114    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2115        Ok(Self {
2116            dh_p: PayloadU16::read(r)?,
2117            dh_g: PayloadU16::read(r)?,
2118            dh_Ys: PayloadU16::read(r)?,
2119        })
2120    }
2121}
2122
2123#[allow(dead_code)]
2124#[derive(Debug)]
2125pub(crate) enum ServerKeyExchangeParams {
2126    Ecdh(ServerEcdhParams),
2127    Dh(ServerDhParams),
2128}
2129
2130impl ServerKeyExchangeParams {
2131    #[cfg(feature = "tls12")]
2132    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
2133        match kx.group().key_exchange_algorithm() {
2134            KeyExchangeAlgorithm::DHE => Self::Dh(ServerDhParams::new(kx)),
2135            KeyExchangeAlgorithm::ECDHE => Self::Ecdh(ServerEcdhParams::new(kx)),
2136        }
2137    }
2138
2139    #[cfg(feature = "tls12")]
2140    pub(crate) fn pub_key(&self) -> &[u8] {
2141        match self {
2142            Self::Ecdh(ecdh) => &ecdh.public.0,
2143            Self::Dh(dh) => &dh.dh_Ys.0,
2144        }
2145    }
2146
2147    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
2148        match self {
2149            Self::Ecdh(ecdh) => ecdh.encode(buf),
2150            Self::Dh(dh) => dh.encode(buf),
2151        }
2152    }
2153}
2154
2155#[cfg(feature = "tls12")]
2156impl KxDecode<'_> for ServerKeyExchangeParams {
2157    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
2158        use KeyExchangeAlgorithm::*;
2159        Ok(match algo {
2160            ECDHE => Self::Ecdh(ServerEcdhParams::read(r)?),
2161            DHE => Self::Dh(ServerDhParams::read(r)?),
2162        })
2163    }
2164}
2165
2166#[derive(Debug)]
2167pub struct ServerKeyExchange {
2168    pub(crate) params: ServerKeyExchangeParams,
2169    pub(crate) dss: DigitallySignedStruct,
2170}
2171
2172impl ServerKeyExchange {
2173    pub fn encode(&self, buf: &mut Vec<u8>) {
2174        self.params.encode(buf);
2175        self.dss.encode(buf);
2176    }
2177}
2178
2179#[derive(Debug)]
2180pub enum ServerKeyExchangePayload {
2181    Known(ServerKeyExchange),
2182    Unknown(Payload<'static>),
2183}
2184
2185impl From<ServerKeyExchange> for ServerKeyExchangePayload {
2186    fn from(value: ServerKeyExchange) -> Self {
2187        Self::Known(value)
2188    }
2189}
2190
2191impl Codec<'_> for ServerKeyExchangePayload {
2192    fn encode(&self, bytes: &mut Vec<u8>) {
2193        match self {
2194            Self::Known(x) => x.encode(bytes),
2195            Self::Unknown(x) => x.encode(bytes),
2196        }
2197    }
2198
2199    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2200        // read as Unknown, fully parse when we know the
2201        // KeyExchangeAlgorithm
2202        Ok(Self::Unknown(Payload::read(r).into_owned()))
2203    }
2204}
2205
2206impl ServerKeyExchangePayload {
2207    #[cfg(feature = "tls12")]
2208    pub(crate) fn unwrap_given_kxa(&self, kxa: KeyExchangeAlgorithm) -> Option<ServerKeyExchange> {
2209        if let Self::Unknown(unk) = self {
2210            let mut rd = Reader::init(unk.bytes());
2211
2212            let result = ServerKeyExchange {
2213                params: ServerKeyExchangeParams::decode(&mut rd, kxa).ok()?,
2214                dss: DigitallySignedStruct::read(&mut rd).ok()?,
2215            };
2216
2217            if !rd.any_left() {
2218                return Some(result);
2219            };
2220        }
2221
2222        None
2223    }
2224}
2225
2226// -- EncryptedExtensions (TLS1.3 only) --
2227
2228impl TlsListElement for ServerExtension {
2229    const SIZE_LEN: ListLength = ListLength::U16;
2230}
2231
2232pub(crate) trait HasServerExtensions {
2233    fn extensions(&self) -> &[ServerExtension];
2234
2235    /// Returns true if there is more than one extension of a given
2236    /// type.
2237    fn has_duplicate_extension(&self) -> bool {
2238        has_duplicates::<_, _, u16>(
2239            self.extensions()
2240                .iter()
2241                .map(|ext| ext.ext_type()),
2242        )
2243    }
2244
2245    fn find_extension(&self, ext: ExtensionType) -> Option<&ServerExtension> {
2246        self.extensions()
2247            .iter()
2248            .find(|x| x.ext_type() == ext)
2249    }
2250
2251    fn alpn_protocol(&self) -> Option<&[u8]> {
2252        let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?;
2253        match ext {
2254            ServerExtension::Protocols(protos) => Some(protos.as_ref()),
2255            _ => None,
2256        }
2257    }
2258
2259    fn server_cert_type(&self) -> Option<&CertificateType> {
2260        let ext = self.find_extension(ExtensionType::ServerCertificateType)?;
2261        match ext {
2262            ServerExtension::ServerCertType(req) => Some(req),
2263            _ => None,
2264        }
2265    }
2266
2267    fn client_cert_type(&self) -> Option<&CertificateType> {
2268        let ext = self.find_extension(ExtensionType::ClientCertificateType)?;
2269        match ext {
2270            ServerExtension::ClientCertType(req) => Some(req),
2271            _ => None,
2272        }
2273    }
2274
2275    fn quic_params_extension(&self) -> Option<Vec<u8>> {
2276        let ext = self
2277            .find_extension(ExtensionType::TransportParameters)
2278            .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?;
2279        match ext {
2280            ServerExtension::TransportParameters(bytes)
2281            | ServerExtension::TransportParametersDraft(bytes) => Some(bytes.to_vec()),
2282            _ => None,
2283        }
2284    }
2285
2286    fn server_ech_extension(&self) -> Option<ServerEncryptedClientHello> {
2287        let ext = self.find_extension(ExtensionType::EncryptedClientHello)?;
2288        match ext {
2289            ServerExtension::EncryptedClientHello(ech) => Some(ech.clone()),
2290            _ => None,
2291        }
2292    }
2293
2294    fn early_data_extension_offered(&self) -> bool {
2295        self.find_extension(ExtensionType::EarlyData)
2296            .is_some()
2297    }
2298}
2299
2300impl HasServerExtensions for Vec<ServerExtension> {
2301    fn extensions(&self) -> &[ServerExtension] {
2302        self
2303    }
2304}
2305
2306/// RFC5246: `ClientCertificateType certificate_types<1..2^8-1>;`
2307impl TlsListElement for ClientCertificateType {
2308    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
2309        empty_error: InvalidMessage::IllegalEmptyList("ClientCertificateTypes"),
2310    };
2311}
2312
2313wrapped_payload!(
2314    /// A `DistinguishedName` is a `Vec<u8>` wrapped in internal types.
2315    ///
2316    /// It contains the DER or BER encoded [`Subject` field from RFC 5280](https://datatracker.ietf.org/doc/html/rfc5280#section-4.1.2.6)
2317    /// for a single certificate. The Subject field is [encoded as an RFC 5280 `Name`](https://datatracker.ietf.org/doc/html/rfc5280#page-116).
2318    /// It can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
2319    ///
2320    /// ```ignore
2321    /// for name in distinguished_names {
2322    ///     use x509_parser::prelude::FromDer;
2323    ///     println!("{}", x509_parser::x509::X509Name::from_der(&name.0)?.1);
2324    /// }
2325    /// ```
2326    ///
2327    /// The TLS encoding is defined in RFC5246: `opaque DistinguishedName<1..2^16-1>;`
2328    pub struct DistinguishedName,
2329    PayloadU16<NonEmpty>,
2330);
2331
2332impl DistinguishedName {
2333    /// Create a [`DistinguishedName`] after prepending its outer SEQUENCE encoding.
2334    ///
2335    /// This can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
2336    ///
2337    /// ```ignore
2338    /// use x509_parser::prelude::FromDer;
2339    /// println!("{}", x509_parser::x509::X509Name::from_der(dn.as_ref())?.1);
2340    /// ```
2341    pub fn in_sequence(bytes: &[u8]) -> Self {
2342        Self(PayloadU16::new(wrap_in_sequence(bytes)))
2343    }
2344}
2345
2346/// RFC8446: `DistinguishedName authorities<3..2^16-1>;` however,
2347/// RFC5246: `DistinguishedName certificate_authorities<0..2^16-1>;`
2348impl TlsListElement for DistinguishedName {
2349    const SIZE_LEN: ListLength = ListLength::U16;
2350}
2351
2352#[derive(Debug)]
2353pub struct CertificateRequestPayload {
2354    pub(crate) certtypes: Vec<ClientCertificateType>,
2355    pub(crate) sigschemes: Vec<SignatureScheme>,
2356    pub(crate) canames: Vec<DistinguishedName>,
2357}
2358
2359impl Codec<'_> for CertificateRequestPayload {
2360    fn encode(&self, bytes: &mut Vec<u8>) {
2361        self.certtypes.encode(bytes);
2362        self.sigschemes.encode(bytes);
2363        self.canames.encode(bytes);
2364    }
2365
2366    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2367        let certtypes = Vec::read(r)?;
2368        let sigschemes = Vec::read(r)?;
2369        let canames = Vec::read(r)?;
2370
2371        if sigschemes.is_empty() {
2372            warn!("meaningless CertificateRequest message");
2373            Err(InvalidMessage::NoSignatureSchemes)
2374        } else {
2375            Ok(Self {
2376                certtypes,
2377                sigschemes,
2378                canames,
2379            })
2380        }
2381    }
2382}
2383
2384#[derive(Debug)]
2385pub(crate) enum CertReqExtension {
2386    SignatureAlgorithms(Vec<SignatureScheme>),
2387    AuthorityNames(Vec<DistinguishedName>),
2388    CertificateCompressionAlgorithms(Vec<CertificateCompressionAlgorithm>),
2389    Unknown(UnknownExtension),
2390}
2391
2392impl CertReqExtension {
2393    pub(crate) fn ext_type(&self) -> ExtensionType {
2394        match self {
2395            Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms,
2396            Self::AuthorityNames(_) => ExtensionType::CertificateAuthorities,
2397            Self::CertificateCompressionAlgorithms(_) => ExtensionType::CompressCertificate,
2398            Self::Unknown(r) => r.typ,
2399        }
2400    }
2401}
2402
2403impl Codec<'_> for CertReqExtension {
2404    fn encode(&self, bytes: &mut Vec<u8>) {
2405        self.ext_type().encode(bytes);
2406
2407        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2408        match self {
2409            Self::SignatureAlgorithms(r) => r.encode(nested.buf),
2410            Self::AuthorityNames(r) => r.encode(nested.buf),
2411            Self::CertificateCompressionAlgorithms(r) => r.encode(nested.buf),
2412            Self::Unknown(r) => r.encode(nested.buf),
2413        }
2414    }
2415
2416    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2417        let typ = ExtensionType::read(r)?;
2418        let len = u16::read(r)? as usize;
2419        let mut sub = r.sub(len)?;
2420
2421        let ext = match typ {
2422            ExtensionType::SignatureAlgorithms => {
2423                let schemes = Vec::read(&mut sub)?;
2424                if schemes.is_empty() {
2425                    return Err(InvalidMessage::NoSignatureSchemes);
2426                }
2427                Self::SignatureAlgorithms(schemes)
2428            }
2429            ExtensionType::CertificateAuthorities => {
2430                let cas = Vec::read(&mut sub)?;
2431                if cas.is_empty() {
2432                    return Err(InvalidMessage::IllegalEmptyList("DistinguishedNames"));
2433                }
2434                Self::AuthorityNames(cas)
2435            }
2436            ExtensionType::CompressCertificate => {
2437                Self::CertificateCompressionAlgorithms(Vec::read(&mut sub)?)
2438            }
2439            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
2440        };
2441
2442        sub.expect_empty("CertReqExtension")
2443            .map(|_| ext)
2444    }
2445}
2446
2447impl TlsListElement for CertReqExtension {
2448    const SIZE_LEN: ListLength = ListLength::U16;
2449}
2450
2451#[derive(Debug)]
2452pub struct CertificateRequestPayloadTls13 {
2453    pub(crate) context: PayloadU8,
2454    pub(crate) extensions: Vec<CertReqExtension>,
2455}
2456
2457impl Codec<'_> for CertificateRequestPayloadTls13 {
2458    fn encode(&self, bytes: &mut Vec<u8>) {
2459        self.context.encode(bytes);
2460        self.extensions.encode(bytes);
2461    }
2462
2463    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2464        let context = PayloadU8::read(r)?;
2465        let extensions = Vec::read(r)?;
2466
2467        Ok(Self {
2468            context,
2469            extensions,
2470        })
2471    }
2472}
2473
2474impl CertificateRequestPayloadTls13 {
2475    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&CertReqExtension> {
2476        self.extensions
2477            .iter()
2478            .find(|x| x.ext_type() == ext)
2479    }
2480
2481    pub(crate) fn sigalgs_extension(&self) -> Option<&[SignatureScheme]> {
2482        let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?;
2483        match ext {
2484            CertReqExtension::SignatureAlgorithms(sa) => Some(sa),
2485            _ => None,
2486        }
2487    }
2488
2489    pub(crate) fn authorities_extension(&self) -> Option<&[DistinguishedName]> {
2490        let ext = self.find_extension(ExtensionType::CertificateAuthorities)?;
2491        match ext {
2492            CertReqExtension::AuthorityNames(an) => Some(an),
2493            _ => None,
2494        }
2495    }
2496
2497    pub(crate) fn certificate_compression_extension(
2498        &self,
2499    ) -> Option<&[CertificateCompressionAlgorithm]> {
2500        let ext = self.find_extension(ExtensionType::CompressCertificate)?;
2501        match ext {
2502            CertReqExtension::CertificateCompressionAlgorithms(comps) => Some(comps),
2503            _ => None,
2504        }
2505    }
2506}
2507
2508// -- NewSessionTicket --
2509#[derive(Debug)]
2510pub struct NewSessionTicketPayload {
2511    pub(crate) lifetime_hint: u32,
2512    // Tickets can be large (KB), so we deserialise this straight
2513    // into an Arc, so it can be passed directly into the client's
2514    // session object without copying.
2515    pub(crate) ticket: Arc<PayloadU16>,
2516}
2517
2518impl NewSessionTicketPayload {
2519    #[cfg(feature = "tls12")]
2520    pub(crate) fn new(lifetime_hint: u32, ticket: Vec<u8>) -> Self {
2521        Self {
2522            lifetime_hint,
2523            ticket: Arc::new(PayloadU16::new(ticket)),
2524        }
2525    }
2526}
2527
2528impl Codec<'_> for NewSessionTicketPayload {
2529    fn encode(&self, bytes: &mut Vec<u8>) {
2530        self.lifetime_hint.encode(bytes);
2531        self.ticket.encode(bytes);
2532    }
2533
2534    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2535        let lifetime = u32::read(r)?;
2536        let ticket = Arc::new(PayloadU16::read(r)?);
2537
2538        Ok(Self {
2539            lifetime_hint: lifetime,
2540            ticket,
2541        })
2542    }
2543}
2544
// -- NewSessionTicket (TLS 1.3) --
2546#[derive(Debug)]
2547pub(crate) enum NewSessionTicketExtension {
2548    EarlyData(u32),
2549    Unknown(UnknownExtension),
2550}
2551
2552impl NewSessionTicketExtension {
2553    pub(crate) fn ext_type(&self) -> ExtensionType {
2554        match self {
2555            Self::EarlyData(_) => ExtensionType::EarlyData,
2556            Self::Unknown(r) => r.typ,
2557        }
2558    }
2559}
2560
2561impl Codec<'_> for NewSessionTicketExtension {
2562    fn encode(&self, bytes: &mut Vec<u8>) {
2563        self.ext_type().encode(bytes);
2564
2565        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2566        match self {
2567            Self::EarlyData(r) => r.encode(nested.buf),
2568            Self::Unknown(r) => r.encode(nested.buf),
2569        }
2570    }
2571
2572    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2573        let typ = ExtensionType::read(r)?;
2574        let len = u16::read(r)? as usize;
2575        let mut sub = r.sub(len)?;
2576
2577        let ext = match typ {
2578            ExtensionType::EarlyData => Self::EarlyData(u32::read(&mut sub)?),
2579            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
2580        };
2581
2582        sub.expect_empty("NewSessionTicketExtension")
2583            .map(|_| ext)
2584    }
2585}
2586
2587impl TlsListElement for NewSessionTicketExtension {
2588    const SIZE_LEN: ListLength = ListLength::U16;
2589}
2590
2591#[derive(Debug)]
2592pub struct NewSessionTicketPayloadTls13 {
2593    pub(crate) lifetime: u32,
2594    pub(crate) age_add: u32,
2595    pub(crate) nonce: PayloadU8,
2596    pub(crate) ticket: Arc<PayloadU16>,
2597    pub(crate) exts: Vec<NewSessionTicketExtension>,
2598}
2599
2600impl NewSessionTicketPayloadTls13 {
2601    pub(crate) fn new(lifetime: u32, age_add: u32, nonce: Vec<u8>, ticket: Vec<u8>) -> Self {
2602        Self {
2603            lifetime,
2604            age_add,
2605            nonce: PayloadU8::new(nonce),
2606            ticket: Arc::new(PayloadU16::new(ticket)),
2607            exts: vec![],
2608        }
2609    }
2610
2611    pub(crate) fn has_duplicate_extension(&self) -> bool {
2612        has_duplicates::<_, _, u16>(
2613            self.exts
2614                .iter()
2615                .map(|ext| ext.ext_type()),
2616        )
2617    }
2618
2619    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&NewSessionTicketExtension> {
2620        self.exts
2621            .iter()
2622            .find(|x| x.ext_type() == ext)
2623    }
2624
2625    pub(crate) fn max_early_data_size(&self) -> Option<u32> {
2626        let ext = self.find_extension(ExtensionType::EarlyData)?;
2627        match ext {
2628            NewSessionTicketExtension::EarlyData(sz) => Some(*sz),
2629            _ => None,
2630        }
2631    }
2632}
2633
2634impl Codec<'_> for NewSessionTicketPayloadTls13 {
2635    fn encode(&self, bytes: &mut Vec<u8>) {
2636        self.lifetime.encode(bytes);
2637        self.age_add.encode(bytes);
2638        self.nonce.encode(bytes);
2639        self.ticket.encode(bytes);
2640        self.exts.encode(bytes);
2641    }
2642
2643    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2644        let lifetime = u32::read(r)?;
2645        let age_add = u32::read(r)?;
2646        let nonce = PayloadU8::read(r)?;
2647        // nb. RFC8446: `opaque ticket<1..2^16-1>;`
2648        let ticket = Arc::new(match PayloadU16::<NonEmpty>::read(r) {
2649            Err(InvalidMessage::IllegalEmptyValue) => Err(InvalidMessage::EmptyTicketValue),
2650            Err(err) => Err(err),
2651            Ok(pl) => Ok(PayloadU16::new(pl.0)),
2652        }?);
2653        let exts = Vec::read(r)?;
2654
2655        Ok(Self {
2656            lifetime,
2657            age_add,
2658            nonce,
2659            ticket,
2660            exts,
2661        })
2662    }
2663}
2664
2665// -- RFC6066 certificate status types
2666
2667/// Only supports OCSP
2668#[derive(Debug)]
2669pub struct CertificateStatus<'a> {
2670    pub(crate) ocsp_response: PayloadU24<'a>,
2671}
2672
2673impl<'a> Codec<'a> for CertificateStatus<'a> {
2674    fn encode(&self, bytes: &mut Vec<u8>) {
2675        CertificateStatusType::OCSP.encode(bytes);
2676        self.ocsp_response.encode(bytes);
2677    }
2678
2679    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2680        let typ = CertificateStatusType::read(r)?;
2681
2682        match typ {
2683            CertificateStatusType::OCSP => Ok(Self {
2684                ocsp_response: PayloadU24::read(r)?,
2685            }),
2686            _ => Err(InvalidMessage::InvalidCertificateStatusType),
2687        }
2688    }
2689}
2690
2691impl<'a> CertificateStatus<'a> {
2692    pub(crate) fn new(ocsp: &'a [u8]) -> Self {
2693        CertificateStatus {
2694            ocsp_response: PayloadU24(Payload::Borrowed(ocsp)),
2695        }
2696    }
2697
2698    #[cfg(feature = "tls12")]
2699    pub(crate) fn into_inner(self) -> Vec<u8> {
2700        self.ocsp_response.0.into_vec()
2701    }
2702
2703    pub(crate) fn into_owned(self) -> CertificateStatus<'static> {
2704        CertificateStatus {
2705            ocsp_response: self.ocsp_response.into_owned(),
2706        }
2707    }
2708}
2709
2710// -- RFC8879 compressed certificates
2711
/// The body of a `CompressedCertificate` handshake message (RFC 8879).
#[derive(Debug)]
pub struct CompressedCertificatePayload<'a> {
    /// The algorithm used to compress the certificate message.
    pub(crate) alg: CertificateCompressionAlgorithm,
    /// The sender-declared length of the certificate message once decompressed.
    /// It is written to the wire as a 24-bit integer.
    pub(crate) uncompressed_len: u32,
    /// The compressed certificate message bytes.
    pub(crate) compressed: PayloadU24<'a>,
}
2718
2719impl<'a> Codec<'a> for CompressedCertificatePayload<'a> {
2720    fn encode(&self, bytes: &mut Vec<u8>) {
2721        self.alg.encode(bytes);
2722        codec::u24(self.uncompressed_len).encode(bytes);
2723        self.compressed.encode(bytes);
2724    }
2725
2726    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2727        Ok(Self {
2728            alg: CertificateCompressionAlgorithm::read(r)?,
2729            uncompressed_len: codec::u24::read(r)?.0,
2730            compressed: PayloadU24::read(r)?,
2731        })
2732    }
2733}
2734
2735impl CompressedCertificatePayload<'_> {
2736    fn into_owned(self) -> CompressedCertificatePayload<'static> {
2737        CompressedCertificatePayload {
2738            compressed: self.compressed.into_owned(),
2739            ..self
2740        }
2741    }
2742
2743    pub(crate) fn as_borrowed(&self) -> CompressedCertificatePayload<'_> {
2744        CompressedCertificatePayload {
2745            alg: self.alg,
2746            uncompressed_len: self.uncompressed_len,
2747            compressed: PayloadU24(Payload::Borrowed(self.compressed.0.bytes())),
2748        }
2749    }
2750}
2751
/// The decoded body of a single handshake message.
///
/// Handshake types whose body differs between protocol versions have separate
/// `…Tls13` variants; which one is produced depends on the version passed to
/// [`HandshakeMessagePayload::read_version`].
#[derive(Debug)]
pub enum HandshakePayload<'a> {
    /// `HelloRequest`; decoded only when the body is empty.
    HelloRequest,
    ClientHello(ClientHelloPayload),
    ServerHello(ServerHelloPayload),
    /// A ServerHello that carried the HelloRetryRequest sentinel random. It is
    /// written to the wire with the ServerHello type code.
    HelloRetryRequest(HelloRetryRequest),
    /// `Certificate` body used when not decoding as TLS 1.3.
    Certificate(CertificateChain<'a>),
    /// `Certificate` body used when decoding as TLS 1.3.
    CertificateTls13(CertificatePayloadTls13<'a>),
    CompressedCertificate(CompressedCertificatePayload<'a>),
    ServerKeyExchange(ServerKeyExchangePayload),
    /// `CertificateRequest` body used when not decoding as TLS 1.3.
    CertificateRequest(CertificateRequestPayload),
    /// `CertificateRequest` body used when decoding as TLS 1.3.
    CertificateRequestTls13(CertificateRequestPayloadTls13),
    CertificateVerify(DigitallySignedStruct),
    /// Empty-bodied `ServerHelloDone`.
    ServerHelloDone,
    /// Empty-bodied `EndOfEarlyData`.
    EndOfEarlyData,
    /// Kept as opaque bytes at this layer.
    ClientKeyExchange(Payload<'a>),
    /// `NewSessionTicket` body used when not decoding as TLS 1.3.
    NewSessionTicket(NewSessionTicketPayload),
    /// `NewSessionTicket` body used when decoding as TLS 1.3.
    NewSessionTicketTls13(NewSessionTicketPayloadTls13),
    EncryptedExtensions(Vec<ServerExtension>),
    KeyUpdate(KeyUpdateRequest),
    /// The raw verify-data bytes.
    Finished(Payload<'a>),
    CertificateStatus(CertificateStatus<'a>),
    /// A locally constructed message wrapping a hash; never accepted from the wire.
    MessageHash(Payload<'a>),
    /// Any handshake message without a dedicated decoder, kept as raw bytes.
    Unknown(Payload<'a>),
}
2777
2778impl HandshakePayload<'_> {
2779    fn encode(&self, bytes: &mut Vec<u8>) {
2780        use self::HandshakePayload::*;
2781        match self {
2782            HelloRequest | ServerHelloDone | EndOfEarlyData => {}
2783            ClientHello(x) => x.encode(bytes),
2784            ServerHello(x) => x.encode(bytes),
2785            HelloRetryRequest(x) => x.encode(bytes),
2786            Certificate(x) => x.encode(bytes),
2787            CertificateTls13(x) => x.encode(bytes),
2788            CompressedCertificate(x) => x.encode(bytes),
2789            ServerKeyExchange(x) => x.encode(bytes),
2790            ClientKeyExchange(x) => x.encode(bytes),
2791            CertificateRequest(x) => x.encode(bytes),
2792            CertificateRequestTls13(x) => x.encode(bytes),
2793            CertificateVerify(x) => x.encode(bytes),
2794            NewSessionTicket(x) => x.encode(bytes),
2795            NewSessionTicketTls13(x) => x.encode(bytes),
2796            EncryptedExtensions(x) => x.encode(bytes),
2797            KeyUpdate(x) => x.encode(bytes),
2798            Finished(x) => x.encode(bytes),
2799            CertificateStatus(x) => x.encode(bytes),
2800            MessageHash(x) => x.encode(bytes),
2801            Unknown(x) => x.encode(bytes),
2802        }
2803    }
2804
2805    fn into_owned(self) -> HandshakePayload<'static> {
2806        use HandshakePayload::*;
2807
2808        match self {
2809            HelloRequest => HelloRequest,
2810            ClientHello(x) => ClientHello(x),
2811            ServerHello(x) => ServerHello(x),
2812            HelloRetryRequest(x) => HelloRetryRequest(x),
2813            Certificate(x) => Certificate(x.into_owned()),
2814            CertificateTls13(x) => CertificateTls13(x.into_owned()),
2815            CompressedCertificate(x) => CompressedCertificate(x.into_owned()),
2816            ServerKeyExchange(x) => ServerKeyExchange(x),
2817            CertificateRequest(x) => CertificateRequest(x),
2818            CertificateRequestTls13(x) => CertificateRequestTls13(x),
2819            CertificateVerify(x) => CertificateVerify(x),
2820            ServerHelloDone => ServerHelloDone,
2821            EndOfEarlyData => EndOfEarlyData,
2822            ClientKeyExchange(x) => ClientKeyExchange(x.into_owned()),
2823            NewSessionTicket(x) => NewSessionTicket(x),
2824            NewSessionTicketTls13(x) => NewSessionTicketTls13(x),
2825            EncryptedExtensions(x) => EncryptedExtensions(x),
2826            KeyUpdate(x) => KeyUpdate(x),
2827            Finished(x) => Finished(x.into_owned()),
2828            CertificateStatus(x) => CertificateStatus(x.into_owned()),
2829            MessageHash(x) => MessageHash(x.into_owned()),
2830            Unknown(x) => Unknown(x.into_owned()),
2831        }
2832    }
2833}
2834
/// A complete handshake message: its type plus its decoded body.
#[derive(Debug)]
pub struct HandshakeMessagePayload<'a> {
    /// The handshake type.
    ///
    /// After decoding, this is `HelloRetryRequest` when a ServerHello carried the
    /// HelloRetryRequest random, even though the wire type code was ServerHello.
    pub typ: HandshakeType,
    /// The decoded body; expected to agree with `typ` (not enforced by this type).
    pub payload: HandshakePayload<'a>,
}
2840
2841impl<'a> Codec<'a> for HandshakeMessagePayload<'a> {
2842    fn encode(&self, bytes: &mut Vec<u8>) {
2843        self.payload_encode(bytes, Encoding::Standard);
2844    }
2845
2846    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2847        Self::read_version(r, ProtocolVersion::TLSv1_2)
2848    }
2849}
2850
impl<'a> HandshakeMessagePayload<'a> {
    /// Decodes one handshake message (type, 24-bit length, body) from `r`.
    ///
    /// When `vers` is TLS 1.3, `Certificate`, `CertificateRequest` and
    /// `NewSessionTicket` bodies are decoded in their TLS 1.3 form; otherwise
    /// the non-TLS 1.3 form is used. Other types decode the same way for any `vers`.
    ///
    /// A `ServerHello` whose random equals `HELLO_RETRY_REQUEST_RANDOM` is
    /// returned with `typ` rewritten to `HandshakeType::HelloRetryRequest`.
    ///
    /// # Errors
    ///
    /// Fails if the body is truncated or malformed. Also fails if the type is
    /// `MessageHash` or `HelloRetryRequest`, neither of which is legal on the
    /// wire, or if the body has bytes left over after decoding.
    pub(crate) fn read_version(
        r: &mut Reader<'a>,
        vers: ProtocolVersion,
    ) -> Result<Self, InvalidMessage> {
        let mut typ = HandshakeType::read(r)?;
        let len = codec::u24::read(r)?.0 as usize;
        // All body decoding below is confined to this length-bounded sub-reader.
        let mut sub = r.sub(len)?;

        // Arms guarded on `vers` must stay ahead of the unguarded arm for the
        // same handshake type.
        let payload = match typ {
            // A HelloRequest with a non-empty body falls through to `Unknown`.
            HandshakeType::HelloRequest if sub.left() == 0 => HandshakePayload::HelloRequest,
            HandshakeType::ClientHello => {
                HandshakePayload::ClientHello(ClientHelloPayload::read(&mut sub)?)
            }
            HandshakeType::ServerHello => {
                // The random is read up front because it is what distinguishes a
                // HelloRetryRequest from a real ServerHello (they share a type code).
                let version = ProtocolVersion::read(&mut sub)?;
                let random = Random::read(&mut sub)?;

                if random == HELLO_RETRY_REQUEST_RANDOM {
                    let mut hrr = HelloRetryRequest::read(&mut sub)?;
                    hrr.legacy_version = version;
                    typ = HandshakeType::HelloRetryRequest;
                    HandshakePayload::HelloRetryRequest(hrr)
                } else {
                    let mut shp = ServerHelloPayload::read(&mut sub)?;
                    shp.legacy_version = version;
                    shp.random = random;
                    HandshakePayload::ServerHello(shp)
                }
            }
            HandshakeType::Certificate if vers == ProtocolVersion::TLSv1_3 => {
                let p = CertificatePayloadTls13::read(&mut sub)?;
                HandshakePayload::CertificateTls13(p)
            }
            HandshakeType::Certificate => {
                HandshakePayload::Certificate(CertificateChain::read(&mut sub)?)
            }
            HandshakeType::ServerKeyExchange => {
                let p = ServerKeyExchangePayload::read(&mut sub)?;
                HandshakePayload::ServerKeyExchange(p)
            }
            HandshakeType::ServerHelloDone => {
                sub.expect_empty("ServerHelloDone")?;
                HandshakePayload::ServerHelloDone
            }
            HandshakeType::ClientKeyExchange => {
                // Kept as raw bytes; `Payload::read` cannot fail.
                HandshakePayload::ClientKeyExchange(Payload::read(&mut sub))
            }
            HandshakeType::CertificateRequest if vers == ProtocolVersion::TLSv1_3 => {
                let p = CertificateRequestPayloadTls13::read(&mut sub)?;
                HandshakePayload::CertificateRequestTls13(p)
            }
            HandshakeType::CertificateRequest => {
                let p = CertificateRequestPayload::read(&mut sub)?;
                HandshakePayload::CertificateRequest(p)
            }
            HandshakeType::CompressedCertificate => HandshakePayload::CompressedCertificate(
                CompressedCertificatePayload::read(&mut sub)?,
            ),
            HandshakeType::CertificateVerify => {
                HandshakePayload::CertificateVerify(DigitallySignedStruct::read(&mut sub)?)
            }
            HandshakeType::NewSessionTicket if vers == ProtocolVersion::TLSv1_3 => {
                let p = NewSessionTicketPayloadTls13::read(&mut sub)?;
                HandshakePayload::NewSessionTicketTls13(p)
            }
            HandshakeType::NewSessionTicket => {
                let p = NewSessionTicketPayload::read(&mut sub)?;
                HandshakePayload::NewSessionTicket(p)
            }
            HandshakeType::EncryptedExtensions => {
                HandshakePayload::EncryptedExtensions(Vec::read(&mut sub)?)
            }
            HandshakeType::KeyUpdate => {
                HandshakePayload::KeyUpdate(KeyUpdateRequest::read(&mut sub)?)
            }
            HandshakeType::EndOfEarlyData => {
                sub.expect_empty("EndOfEarlyData")?;
                HandshakePayload::EndOfEarlyData
            }
            HandshakeType::Finished => HandshakePayload::Finished(Payload::read(&mut sub)),
            HandshakeType::CertificateStatus => {
                HandshakePayload::CertificateStatus(CertificateStatus::read(&mut sub)?)
            }
            HandshakeType::MessageHash => {
                // does not appear on the wire
                return Err(InvalidMessage::UnexpectedMessage("MessageHash"));
            }
            HandshakeType::HelloRetryRequest => {
                // not legal on wire
                return Err(InvalidMessage::UnexpectedMessage("HelloRetryRequest"));
            }
            // Anything else is kept as raw bytes.
            _ => HandshakePayload::Unknown(Payload::read(&mut sub)),
        };

        // Reject bodies with trailing bytes that the decoder above left unread.
        sub.expect_empty("HandshakeMessagePayload")
            .map(|_| Self { typ, payload })
    }

    /// Returns the encoding of `self`, less the PSK binders,
    /// which are always the final bytes in the ClientHello.
    ///
    /// For anything other than a ClientHello whose last extension is a
    /// pre-shared-key offer, nothing is removed and this is the full encoding.
    pub(crate) fn encoding_for_binder_signing(&self) -> Vec<u8> {
        let mut ret = self.get_encoding();
        let ret_len = ret.len() - self.total_binder_length();
        ret.truncate(ret_len);
        ret
    }

    /// Returns the total encoded length of the PSK binders.
    ///
    /// This is non-zero only when the payload is a ClientHello whose *last*
    /// extension is `PresharedKey`. The value is the length of whatever
    /// `offer.binders.encode` writes (presumably including the list's own length
    /// prefix — confirm against the binders type).
    pub(crate) fn total_binder_length(&self) -> usize {
        match &self.payload {
            HandshakePayload::ClientHello(ch) => match ch.extensions.last() {
                Some(ClientExtension::PresharedKey(offer)) => {
                    // Encode into a scratch buffer purely to measure the length.
                    let mut binders_encoding = Vec::new();
                    offer
                        .binders
                        .encode(&mut binders_encoding);
                    binders_encoding.len()
                }
                _ => 0,
            },
            _ => 0,
        }
    }

    /// Writes the handshake header (type code and 24-bit body length) followed
    /// by the body.
    ///
    /// `encoding` only affects ServerHello and HelloRetryRequest bodies; all
    /// other bodies are written in their single standard form.
    pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) {
        // output type, length, and encoded payload
        // A HelloRetryRequest is sent with the ServerHello type code.
        match self.typ {
            HandshakeType::HelloRetryRequest => HandshakeType::ServerHello,
            _ => self.typ,
        }
        .encode(bytes);

        // NOTE(review): `max: usize::MAX` means no size cap is applied at this layer.
        let nested = LengthPrefixedBuffer::new(
            ListLength::U24 {
                max: usize::MAX,
                error: InvalidMessage::MessageTooLarge,
            },
            bytes,
        );

        match &self.payload {
            // for Server Hello and HelloRetryRequest payloads we need to encode the payload
            // differently based on the purpose of the encoding.
            HandshakePayload::ServerHello(payload) => payload.payload_encode(nested.buf, encoding),
            HandshakePayload::HelloRetryRequest(payload) => {
                payload.payload_encode(nested.buf, encoding)
            }

            // All other payload types are encoded the same regardless of purpose.
            _ => self.payload.encode(nested.buf),
        }
    }

    /// Builds a `MessageHash` message whose body is a copy of `hash`.
    ///
    /// This message type is only constructed locally; `read_version` rejects it
    /// if it is seen on the wire.
    pub(crate) fn build_handshake_hash(hash: &[u8]) -> Self {
        Self {
            typ: HandshakeType::MessageHash,
            payload: HandshakePayload::MessageHash(Payload::new(hash.to_vec())),
        }
    }

    /// Converts into a `'static` message, taking ownership of any borrowed body bytes.
    pub(crate) fn into_owned(self) -> HandshakeMessagePayload<'static> {
        let Self { typ, payload } = self;
        HandshakeMessagePayload {
            typ,
            payload: payload.into_owned(),
        }
    }
}
3020
/// An HPKE KDF/AEAD pair, as listed in an ECH key config's cipher suites.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct HpkeSymmetricCipherSuite {
    /// The HPKE key derivation function identifier.
    pub kdf_id: HpkeKdf,
    /// The HPKE AEAD algorithm identifier.
    pub aead_id: HpkeAead,
}
3026
3027impl Codec<'_> for HpkeSymmetricCipherSuite {
3028    fn encode(&self, bytes: &mut Vec<u8>) {
3029        self.kdf_id.encode(bytes);
3030        self.aead_id.encode(bytes);
3031    }
3032
3033    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3034        Ok(Self {
3035            kdf_id: HpkeKdf::read(r)?,
3036            aead_id: HpkeAead::read(r)?,
3037        })
3038    }
3039}
3040
/// draft-ietf-tls-esni-24: `HpkeSymmetricCipherSuite cipher_suites<4..2^16-4>;`
///
/// The list uses a 16-bit length prefix and may not be empty: decoding an empty
/// list fails with `IllegalEmptyList("HpkeSymmetricCipherSuites")`.
impl TlsListElement for HpkeSymmetricCipherSuite {
    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
        empty_error: InvalidMessage::IllegalEmptyList("HpkeSymmetricCipherSuites"),
    };
}
3047
/// The HPKE key configuration carried inside an ECH config.
#[derive(Clone, Debug, PartialEq)]
pub struct HpkeKeyConfig {
    /// Identifier for this configuration; clients echo it back in
    /// `EncryptedClientHelloOuter::config_id`.
    pub config_id: u8,
    /// The HPKE KEM that `public_key` belongs to.
    pub kem_id: HpkeKem,
    /// draft-ietf-tls-esni-24: `opaque HpkePublicKey<1..2^16-1>;`
    pub public_key: PayloadU16<NonEmpty>,
    /// The KDF/AEAD pairs offered with this key; decoding rejects an empty list.
    pub symmetric_cipher_suites: Vec<HpkeSymmetricCipherSuite>,
}
3056
3057impl Codec<'_> for HpkeKeyConfig {
3058    fn encode(&self, bytes: &mut Vec<u8>) {
3059        self.config_id.encode(bytes);
3060        self.kem_id.encode(bytes);
3061        self.public_key.encode(bytes);
3062        self.symmetric_cipher_suites
3063            .encode(bytes);
3064    }
3065
3066    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3067        Ok(Self {
3068            config_id: u8::read(r)?,
3069            kem_id: HpkeKem::read(r)?,
3070            public_key: PayloadU16::read(r)?,
3071            symmetric_cipher_suites: Vec::<HpkeSymmetricCipherSuite>::read(r)?,
3072        })
3073    }
3074}
3075
/// The contents of a version-18 ECH config (see `EchConfigPayload::V18`).
#[derive(Clone, Debug, PartialEq)]
pub struct EchConfigContents {
    /// The HPKE key configuration clients encrypt to.
    pub key_config: HpkeKeyConfig,
    /// Read and written as a single byte; presumably the ECH draft's
    /// `maximum_name_length` hint — confirm against its users.
    pub maximum_name_length: u8,
    /// The config's public name; decoding fails if it is not a valid DNS name.
    pub public_name: DnsName<'static>,
    /// Config extensions. None are interpreted; each is kept as
    /// `EchConfigExtension::Unknown`.
    pub extensions: Vec<EchConfigExtension>,
}
3083
3084impl EchConfigContents {
3085    /// Returns true if there is more than one extension of a given
3086    /// type.
3087    pub(crate) fn has_duplicate_extension(&self) -> bool {
3088        has_duplicates::<_, _, u16>(
3089            self.extensions
3090                .iter()
3091                .map(|ext| ext.ext_type()),
3092        )
3093    }
3094
3095    /// Returns true if there is at least one mandatory unsupported extension.
3096    pub(crate) fn has_unknown_mandatory_extension(&self) -> bool {
3097        self.extensions
3098            .iter()
3099            // An extension is considered mandatory if the high bit of its type is set.
3100            .any(|ext| {
3101                matches!(ext.ext_type(), ExtensionType::Unknown(_))
3102                    && u16::from(ext.ext_type()) & 0x8000 != 0
3103            })
3104    }
3105}
3106
3107impl Codec<'_> for EchConfigContents {
3108    fn encode(&self, bytes: &mut Vec<u8>) {
3109        self.key_config.encode(bytes);
3110        self.maximum_name_length.encode(bytes);
3111        let dns_name = &self.public_name.borrow();
3112        PayloadU8::<MaybeEmpty>::encode_slice(dns_name.as_ref().as_ref(), bytes);
3113        self.extensions.encode(bytes);
3114    }
3115
3116    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3117        Ok(Self {
3118            key_config: HpkeKeyConfig::read(r)?,
3119            maximum_name_length: u8::read(r)?,
3120            public_name: {
3121                DnsName::try_from(
3122                    PayloadU8::<MaybeEmpty>::read(r)?
3123                        .0
3124                        .as_slice(),
3125                )
3126                .map_err(|_| InvalidMessage::InvalidServerName)?
3127                .to_owned()
3128            },
3129            extensions: Vec::read(r)?,
3130        })
3131    }
3132}
3133
/// An encrypted client hello (ECH) config.
///
/// On the wire this is a version code followed by the version-specific contents
/// behind a 16-bit length prefix.
#[derive(Clone, Debug, PartialEq)]
pub enum EchConfigPayload {
    /// A recognized V18 ECH configuration.
    V18(EchConfigContents),
    /// An unknown version ECH configuration.
    ///
    /// The contents are kept opaque so the config can be re-encoded as received.
    Unknown {
        /// The unrecognized version code.
        version: EchVersion,
        /// The raw bytes that followed the length prefix.
        contents: PayloadU16,
    },
}
3145
/// A list of ECH configs is encoded with a 16-bit length prefix.
impl TlsListElement for EchConfigPayload {
    const SIZE_LEN: ListLength = ListLength::U16;
}
3149
3150impl Codec<'_> for EchConfigPayload {
3151    fn encode(&self, bytes: &mut Vec<u8>) {
3152        match self {
3153            Self::V18(c) => {
3154                // Write the version, the length, and the contents.
3155                EchVersion::V18.encode(bytes);
3156                let inner = LengthPrefixedBuffer::new(ListLength::U16, bytes);
3157                c.encode(inner.buf);
3158            }
3159            Self::Unknown { version, contents } => {
3160                // Unknown configuration versions are opaque.
3161                version.encode(bytes);
3162                contents.encode(bytes);
3163            }
3164        }
3165    }
3166
3167    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3168        let version = EchVersion::read(r)?;
3169        let length = u16::read(r)?;
3170        let mut contents = r.sub(length as usize)?;
3171
3172        Ok(match version {
3173            EchVersion::V18 => Self::V18(EchConfigContents::read(&mut contents)?),
3174            _ => {
3175                // Note: we don't PayloadU16::read() here because we've already read the length prefix.
3176                let data = PayloadU16::new(contents.rest().into());
3177                Self::Unknown {
3178                    version,
3179                    contents: data,
3180                }
3181            }
3182        })
3183    }
3184}
3185
/// An entry in an ECH config's extension list.
///
/// No ECH config extensions are interpreted: every one decodes to `Unknown`,
/// keeping its type code and raw body.
#[derive(Clone, Debug, PartialEq)]
pub enum EchConfigExtension {
    Unknown(UnknownExtension),
}
3190
3191impl EchConfigExtension {
3192    pub(crate) fn ext_type(&self) -> ExtensionType {
3193        match self {
3194            Self::Unknown(r) => r.typ,
3195        }
3196    }
3197}
3198
3199impl Codec<'_> for EchConfigExtension {
3200    fn encode(&self, bytes: &mut Vec<u8>) {
3201        self.ext_type().encode(bytes);
3202
3203        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
3204        match self {
3205            Self::Unknown(r) => r.encode(nested.buf),
3206        }
3207    }
3208
3209    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3210        let typ = ExtensionType::read(r)?;
3211        let len = u16::read(r)? as usize;
3212        let mut sub = r.sub(len)?;
3213
3214        #[allow(clippy::match_single_binding)] // Future-proofing.
3215        let ext = match typ {
3216            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
3217        };
3218
3219        sub.expect_empty("EchConfigExtension")
3220            .map(|_| ext)
3221    }
3222}
3223
/// An ECH config's extension list is encoded with a 16-bit length prefix.
impl TlsListElement for EchConfigExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
3227
/// Representation of the `ECHClientHello` client extension specified in
/// [draft-ietf-tls-esni Section 5].
///
/// Encoded as an [`EchClientHelloType`] tag. The tag is followed by the outer
/// payload for `Outer` and by nothing for `Inner`.
///
/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
#[derive(Clone, Debug)]
pub enum EncryptedClientHello {
    /// A `ECHClientHello` with type [EchClientHelloType::ClientHelloOuter].
    Outer(EncryptedClientHelloOuter),
    /// An empty `ECHClientHello` with type [EchClientHelloType::ClientHelloInner].
    ///
    /// This variant has no payload.
    Inner,
}
3241
3242impl Codec<'_> for EncryptedClientHello {
3243    fn encode(&self, bytes: &mut Vec<u8>) {
3244        match self {
3245            Self::Outer(payload) => {
3246                EchClientHelloType::ClientHelloOuter.encode(bytes);
3247                payload.encode(bytes);
3248            }
3249            Self::Inner => {
3250                EchClientHelloType::ClientHelloInner.encode(bytes);
3251                // Empty payload.
3252            }
3253        }
3254    }
3255
3256    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3257        match EchClientHelloType::read(r)? {
3258            EchClientHelloType::ClientHelloOuter => {
3259                Ok(Self::Outer(EncryptedClientHelloOuter::read(r)?))
3260            }
3261            EchClientHelloType::ClientHelloInner => Ok(Self::Inner),
3262            _ => Err(InvalidMessage::InvalidContentType),
3263        }
3264    }
3265}
3266
/// Representation of the ECHClientHello extension with type outer specified in
/// [draft-ietf-tls-esni Section 5].
///
/// Fields are encoded in declaration order.
///
/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
#[derive(Clone, Debug)]
pub struct EncryptedClientHelloOuter {
    /// The cipher suite used to encrypt ClientHelloInner. Must match a value from
    /// ECHConfigContents.cipher_suites list.
    pub cipher_suite: HpkeSymmetricCipherSuite,
    /// The ECHConfigContents.key_config.config_id for the chosen ECHConfig.
    pub config_id: u8,
    /// The HPKE encapsulated key, used by servers to decrypt the corresponding payload field.
    /// This field is empty in a ClientHelloOuter sent in response to a HelloRetryRequest.
    pub enc: PayloadU16,
    /// The serialized and encrypted ClientHelloInner structure, encrypted using HPKE.
    /// Decoding rejects an empty payload.
    pub payload: PayloadU16<NonEmpty>,
}
3284
3285impl Codec<'_> for EncryptedClientHelloOuter {
3286    fn encode(&self, bytes: &mut Vec<u8>) {
3287        self.cipher_suite.encode(bytes);
3288        self.config_id.encode(bytes);
3289        self.enc.encode(bytes);
3290        self.payload.encode(bytes);
3291    }
3292
3293    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3294        Ok(Self {
3295            cipher_suite: HpkeSymmetricCipherSuite::read(r)?,
3296            config_id: u8::read(r)?,
3297            enc: PayloadU16::read(r)?,
3298            payload: PayloadU16::read(r)?,
3299        })
3300    }
3301}
3302
/// Representation of the ECHEncryptedExtensions extension specified in
/// [draft-ietf-tls-esni Section 5].
///
/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
#[derive(Clone, Debug)]
pub struct ServerEncryptedClientHello {
    /// The `retry_configs` list carried by the extension; it is the extension's only field.
    pub(crate) retry_configs: Vec<EchConfigPayload>,
}
3311
3312impl Codec<'_> for ServerEncryptedClientHello {
3313    fn encode(&self, bytes: &mut Vec<u8>) {
3314        self.retry_configs.encode(bytes);
3315    }
3316
3317    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3318        Ok(Self {
3319            retry_configs: Vec::<EchConfigPayload>::read(r)?,
3320        })
3321    }
3322}
3323
/// The method of encoding to use for a handshake message.
///
/// In some cases a handshake message may be encoded differently depending on the purpose
/// the encoded message is being used for. For example, a [`ServerHelloPayload`] may be encoded
/// with the last 8 bytes of the random zeroed out when being encoded for ECH confirmation.
///
/// `HandshakeMessagePayload::payload_encode` only passes this value on to ServerHello and
/// HelloRetryRequest bodies; every other body is written in its standard form there.
pub(crate) enum Encoding {
    /// Standard RFC 8446 encoding.
    Standard,
    /// Encoding for ECH confirmation.
    EchConfirmation,
    /// Encoding for ECH inner client hello.
    ///
    /// NOTE(review): `to_compress` presumably lists the extension types to be
    /// compressed (referenced from the outer hello) — confirm against its users.
    EchInnerHello { to_compress: Vec<ExtensionType> },
}
3337
/// Returns `true` as soon as any two items of `iter` convert to equal `T` values.
///
/// Items are converted with `Into<T>` and tracked in an ordered set. Iteration
/// stops at the first repeat, so the rest of the input is not consumed.
fn has_duplicates<I: IntoIterator<Item = E>, E: Into<T>, T: Eq + Ord>(iter: I) -> bool {
    let mut seen = BTreeSet::new();
    // `insert` returns false when the value was already present.
    iter.into_iter()
        .any(|item| !seen.insert(item.into()))
}
3349
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal ECH config body with a single cipher suite and no extensions.
    fn config_template() -> EchConfigContents {
        let suite = HpkeSymmetricCipherSuite {
            kdf_id: HpkeKdf::HKDF_SHA256,
            aead_id: HpkeAead::AES_128_GCM,
        };
        let key_config = HpkeKeyConfig {
            config_id: 0,
            kem_id: HpkeKem::DHKEM_P256_HKDF_SHA256,
            public_key: PayloadU16::new(b"xxx".into()),
            symmetric_cipher_suites: vec![suite],
        };
        EchConfigContents {
            key_config,
            maximum_name_length: 0,
            public_name: DnsName::try_from("example.com").unwrap(),
            extensions: vec![],
        }
    }

    /// Builds an unknown config extension with the given type code and a one-byte body.
    fn unknown_ext(typ: u16) -> EchConfigExtension {
        EchConfigExtension::Unknown(UnknownExtension {
            typ: ExtensionType::Unknown(typ),
            payload: Payload::new(vec![0x42]),
        })
    }

    #[test]
    fn test_ech_config_dupe_exts() {
        let mut config = config_template();
        config
            .extensions
            .extend([unknown_ext(0x42), unknown_ext(0x42)]);

        assert!(config.has_duplicate_extension());
        assert!(!config.has_unknown_mandatory_extension());
    }

    #[test]
    fn test_ech_config_mandatory_exts() {
        let mut config = config_template();
        // The high bit marks an extension as mandatory.
        config
            .extensions
            .push(unknown_ext(0x42 | 0x8000));

        assert!(!config.has_duplicate_extension());
        assert!(config.has_unknown_mandatory_extension());
    }
}