clia_rustls_mod/msgs/
handshake.rs

1#![allow(non_camel_case_types)]
2
3use alloc::collections::BTreeSet;
4#[cfg(feature = "logging")]
5use alloc::string::String;
6use alloc::vec;
7use alloc::vec::Vec;
8use core::fmt;
9use core::ops::Deref;
10
11use pki_types::{CertificateDer, DnsName};
12
13#[cfg(feature = "tls12")]
14use crate::crypto::ActiveKeyExchange;
15use crate::crypto::SecureRandom;
16use crate::enums::{CipherSuite, HandshakeType, ProtocolVersion, SignatureScheme};
17use crate::error::InvalidMessage;
18#[cfg(feature = "tls12")]
19use crate::ffdhe_groups::FfdheGroup;
20#[cfg(feature = "logging")]
21use crate::log::warn;
22use crate::msgs::base::{Payload, PayloadU16, PayloadU24, PayloadU8};
23use crate::msgs::codec::{self, Codec, LengthPrefixedBuffer, ListLength, Reader, TlsListElement};
24use crate::msgs::enums::{
25    CertificateStatusType, ClientCertificateType, Compression, ECCurveType, ECPointFormat,
26    EchVersion, ExtensionType, HpkeAead, HpkeKdf, HpkeKem, KeyUpdateRequest, NamedGroup,
27    PSKKeyExchangeMode, ServerNameType,
28};
29use crate::rand;
30use crate::verify::DigitallySignedStruct;
31use crate::x509::wrap_in_sequence;
32
/// Declares a newtype around a length-prefixed payload type such as `PayloadU8` or
/// `PayloadU16`.
///
/// The generated type derives `Clone` and `Debug`, can be built from a `Vec<u8>`, exposes its
/// raw bytes through `AsRef<[u8]>`, and encodes/decodes exactly like the wrapped payload type.
/// Intended for message fields where nothing beyond access to the underlying bytes is needed.
macro_rules! wrapped_payload {
    ($(#[$attr:meta])* $vis:vis struct $name:ident, $inner:ident,) => {
        $(#[$attr])*
        #[derive(Clone, Debug)]
        $vis struct $name($inner);

        impl Codec<'_> for $name {
            fn encode(&self, bytes: &mut Vec<u8>) {
                self.0.encode(bytes);
            }

            fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
                $inner::read(r).map(Self)
            }
        }

        impl AsRef<[u8]> for $name {
            fn as_ref(&self) -> &[u8] {
                let inner: &$inner = &self.0;
                &inner.0
            }
        }

        impl From<Vec<u8>> for $name {
            fn from(v: Vec<u8>) -> Self {
                Self($inner::new(v))
            }
        }
    };
}
67
68#[derive(Clone, Copy, Eq, PartialEq)]
69pub struct Random(pub(crate) [u8; 32]);
70
71impl fmt::Debug for Random {
72    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73        super::base::hex(f, &self.0)
74    }
75}
76
77static HELLO_RETRY_REQUEST_RANDOM: Random = Random([
78    0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11, 0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91,
79    0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e, 0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c,
80]);
81
82static ZERO_RANDOM: Random = Random([0u8; 32]);
83
84impl Codec<'_> for Random {
85    fn encode(&self, bytes: &mut Vec<u8>) {
86        bytes.extend_from_slice(&self.0);
87    }
88
89    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
90        let bytes = match r.take(32) {
91            Some(bytes) => bytes,
92            None => return Err(InvalidMessage::MissingData("Random")),
93        };
94
95        let mut opaque = [0; 32];
96        opaque.clone_from_slice(bytes);
97        Ok(Self(opaque))
98    }
99}
100
101impl Random {
102    pub(crate) fn new(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
103        let mut data = [0u8; 32];
104        secure_random.fill(&mut data)?;
105        Ok(Self(data))
106    }
107}
108
109impl From<[u8; 32]> for Random {
110    #[inline]
111    fn from(bytes: [u8; 32]) -> Self {
112        Self(bytes)
113    }
114}
115
116#[derive(Copy, Clone)]
117pub struct SessionId {
118    len: usize,
119    data: [u8; 32],
120}
121
122impl fmt::Debug for SessionId {
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        super::base::hex(f, &self.data[..self.len])
125    }
126}
127
128impl PartialEq for SessionId {
129    fn eq(&self, other: &Self) -> bool {
130        if self.len != other.len {
131            return false;
132        }
133
134        let mut diff = 0u8;
135        for i in 0..self.len {
136            diff |= self.data[i] ^ other.data[i];
137        }
138
139        diff == 0u8
140    }
141}
142
143impl Codec<'_> for SessionId {
144    fn encode(&self, bytes: &mut Vec<u8>) {
145        debug_assert!(self.len <= 32);
146        bytes.push(self.len as u8);
147        bytes.extend_from_slice(&self.data[..self.len]);
148    }
149
150    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
151        let len = u8::read(r)? as usize;
152        if len > 32 {
153            return Err(InvalidMessage::TrailingData("SessionID"));
154        }
155
156        let bytes = match r.take(len) {
157            Some(bytes) => bytes,
158            None => return Err(InvalidMessage::MissingData("SessionID")),
159        };
160
161        let mut out = [0u8; 32];
162        out[..len].clone_from_slice(&bytes[..len]);
163        Ok(Self { data: out, len })
164    }
165}
166
167impl SessionId {
168    pub fn random(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
169        let mut data = [0u8; 32];
170        secure_random.fill(&mut data)?;
171        Ok(Self { data, len: 32 })
172    }
173
174    pub(crate) fn empty() -> Self {
175        Self {
176            data: [0u8; 32],
177            len: 0,
178        }
179    }
180
181    #[cfg(feature = "tls12")]
182    pub(crate) fn is_empty(&self) -> bool {
183        self.len == 0
184    }
185}
186
187#[derive(Clone, Debug, PartialEq)]
188pub struct UnknownExtension {
189    pub(crate) typ: ExtensionType,
190    pub(crate) payload: Payload<'static>,
191}
192
193impl UnknownExtension {
194    fn encode(&self, bytes: &mut Vec<u8>) {
195        self.payload.encode(bytes);
196    }
197
198    fn read(typ: ExtensionType, r: &mut Reader) -> Self {
199        let payload = Payload::read(r).into_owned();
200        Self { typ, payload }
201    }
202}
203
204impl TlsListElement for ECPointFormat {
205    const SIZE_LEN: ListLength = ListLength::U8;
206}
207
208impl TlsListElement for NamedGroup {
209    const SIZE_LEN: ListLength = ListLength::U16;
210}
211
212impl TlsListElement for SignatureScheme {
213    const SIZE_LEN: ListLength = ListLength::U16;
214}
215
216#[derive(Clone, Debug)]
217pub(crate) enum ServerNamePayload {
218    HostName(DnsName<'static>),
219    Unknown(Payload<'static>),
220}
221
222impl ServerNamePayload {
223    pub(crate) fn new_hostname(hostname: DnsName<'static>) -> Self {
224        Self::HostName(hostname)
225    }
226
227    fn read_hostname(r: &mut Reader) -> Result<Self, InvalidMessage> {
228        let raw = PayloadU16::read(r)?;
229
230        match DnsName::try_from(raw.0.as_slice()) {
231            Ok(dns_name) => Ok(Self::HostName(dns_name.to_owned())),
232            Err(_) => {
233                warn!(
234                    "Illegal SNI hostname received {:?}",
235                    String::from_utf8_lossy(&raw.0)
236                );
237                Err(InvalidMessage::InvalidServerName)
238            }
239        }
240    }
241
242    fn encode(&self, bytes: &mut Vec<u8>) {
243        match *self {
244            Self::HostName(ref name) => {
245                (name.as_ref().len() as u16).encode(bytes);
246                bytes.extend_from_slice(name.as_ref().as_bytes());
247            }
248            Self::Unknown(ref r) => r.encode(bytes),
249        }
250    }
251}
252
253#[derive(Clone, Debug)]
254pub struct ServerName {
255    pub(crate) typ: ServerNameType,
256    pub(crate) payload: ServerNamePayload,
257}
258
259impl Codec<'_> for ServerName {
260    fn encode(&self, bytes: &mut Vec<u8>) {
261        self.typ.encode(bytes);
262        self.payload.encode(bytes);
263    }
264
265    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
266        let typ = ServerNameType::read(r)?;
267
268        let payload = match typ {
269            ServerNameType::HostName => ServerNamePayload::read_hostname(r)?,
270            _ => ServerNamePayload::Unknown(Payload::read(r).into_owned()),
271        };
272
273        Ok(Self { typ, payload })
274    }
275}
276
277impl TlsListElement for ServerName {
278    const SIZE_LEN: ListLength = ListLength::U16;
279}
280
281pub(crate) trait ConvertServerNameList {
282    fn has_duplicate_names_for_type(&self) -> bool;
283    fn single_hostname(&self) -> Option<DnsName<'_>>;
284}
285
286impl ConvertServerNameList for [ServerName] {
287    /// RFC6066: "The ServerNameList MUST NOT contain more than one name of the same name_type."
288    fn has_duplicate_names_for_type(&self) -> bool {
289        has_duplicates::<_, _, u8>(self.iter().map(|name| name.typ))
290    }
291
292    fn single_hostname(&self) -> Option<DnsName<'_>> {
293        fn only_dns_hostnames(name: &ServerName) -> Option<DnsName<'_>> {
294            if let ServerNamePayload::HostName(ref dns) = name.payload {
295                Some(dns.borrow())
296            } else {
297                None
298            }
299        }
300
301        self.iter()
302            .filter_map(only_dns_hostnames)
303            .next()
304    }
305}
306
307wrapped_payload!(pub struct ProtocolName, PayloadU8,);
308
309impl TlsListElement for ProtocolName {
310    const SIZE_LEN: ListLength = ListLength::U16;
311}
312
313pub(crate) trait ConvertProtocolNameList {
314    fn from_slices(names: &[&[u8]]) -> Self;
315    fn to_slices(&self) -> Vec<&[u8]>;
316    fn as_single_slice(&self) -> Option<&[u8]>;
317}
318
319impl ConvertProtocolNameList for Vec<ProtocolName> {
320    fn from_slices(names: &[&[u8]]) -> Self {
321        let mut ret = Self::new();
322
323        for name in names {
324            ret.push(ProtocolName::from(name.to_vec()));
325        }
326
327        ret
328    }
329
330    fn to_slices(&self) -> Vec<&[u8]> {
331        self.iter()
332            .map(|proto| proto.as_ref())
333            .collect::<Vec<&[u8]>>()
334    }
335
336    fn as_single_slice(&self) -> Option<&[u8]> {
337        if self.len() == 1 {
338            Some(self[0].as_ref())
339        } else {
340            None
341        }
342    }
343}
344
345// --- TLS 1.3 Key shares ---
346#[derive(Clone, Debug)]
347pub struct KeyShareEntry {
348    pub(crate) group: NamedGroup,
349    pub(crate) payload: PayloadU16,
350}
351
352impl KeyShareEntry {
353    pub fn new(group: NamedGroup, payload: impl Into<Vec<u8>>) -> Self {
354        Self {
355            group,
356            payload: PayloadU16::new(payload.into()),
357        }
358    }
359
360    pub fn group(&self) -> NamedGroup {
361        self.group
362    }
363}
364
365impl Codec<'_> for KeyShareEntry {
366    fn encode(&self, bytes: &mut Vec<u8>) {
367        self.group.encode(bytes);
368        self.payload.encode(bytes);
369    }
370
371    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
372        let group = NamedGroup::read(r)?;
373        let payload = PayloadU16::read(r)?;
374
375        Ok(Self { group, payload })
376    }
377}
378
379// --- TLS 1.3 PresharedKey offers ---
380#[derive(Clone, Debug)]
381pub(crate) struct PresharedKeyIdentity {
382    pub(crate) identity: PayloadU16,
383    pub(crate) obfuscated_ticket_age: u32,
384}
385
386impl PresharedKeyIdentity {
387    pub(crate) fn new(id: Vec<u8>, age: u32) -> Self {
388        Self {
389            identity: PayloadU16::new(id),
390            obfuscated_ticket_age: age,
391        }
392    }
393}
394
395impl Codec<'_> for PresharedKeyIdentity {
396    fn encode(&self, bytes: &mut Vec<u8>) {
397        self.identity.encode(bytes);
398        self.obfuscated_ticket_age.encode(bytes);
399    }
400
401    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
402        Ok(Self {
403            identity: PayloadU16::read(r)?,
404            obfuscated_ticket_age: u32::read(r)?,
405        })
406    }
407}
408
409impl TlsListElement for PresharedKeyIdentity {
410    const SIZE_LEN: ListLength = ListLength::U16;
411}
412
413wrapped_payload!(pub(crate) struct PresharedKeyBinder, PayloadU8,);
414
415impl TlsListElement for PresharedKeyBinder {
416    const SIZE_LEN: ListLength = ListLength::U16;
417}
418
419#[derive(Clone, Debug)]
420pub struct PresharedKeyOffer {
421    pub(crate) identities: Vec<PresharedKeyIdentity>,
422    pub(crate) binders: Vec<PresharedKeyBinder>,
423}
424
425impl PresharedKeyOffer {
426    /// Make a new one with one entry.
427    pub(crate) fn new(id: PresharedKeyIdentity, binder: Vec<u8>) -> Self {
428        Self {
429            identities: vec![id],
430            binders: vec![PresharedKeyBinder::from(binder)],
431        }
432    }
433}
434
435impl Codec<'_> for PresharedKeyOffer {
436    fn encode(&self, bytes: &mut Vec<u8>) {
437        self.identities.encode(bytes);
438        self.binders.encode(bytes);
439    }
440
441    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
442        Ok(Self {
443            identities: Vec::read(r)?,
444            binders: Vec::read(r)?,
445        })
446    }
447}
448
449// --- RFC6066 certificate status request ---
450wrapped_payload!(pub(crate) struct ResponderId, PayloadU16,);
451
452impl TlsListElement for ResponderId {
453    const SIZE_LEN: ListLength = ListLength::U16;
454}
455
456#[derive(Clone, Debug)]
457pub struct OcspCertificateStatusRequest {
458    pub(crate) responder_ids: Vec<ResponderId>,
459    pub(crate) extensions: PayloadU16,
460}
461
462impl Codec<'_> for OcspCertificateStatusRequest {
463    fn encode(&self, bytes: &mut Vec<u8>) {
464        CertificateStatusType::OCSP.encode(bytes);
465        self.responder_ids.encode(bytes);
466        self.extensions.encode(bytes);
467    }
468
469    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
470        Ok(Self {
471            responder_ids: Vec::read(r)?,
472            extensions: PayloadU16::read(r)?,
473        })
474    }
475}
476
477#[derive(Clone, Debug)]
478pub enum CertificateStatusRequest {
479    Ocsp(OcspCertificateStatusRequest),
480    Unknown((CertificateStatusType, Payload<'static>)),
481}
482
483impl Codec<'_> for CertificateStatusRequest {
484    fn encode(&self, bytes: &mut Vec<u8>) {
485        match self {
486            Self::Ocsp(ref r) => r.encode(bytes),
487            Self::Unknown((typ, payload)) => {
488                typ.encode(bytes);
489                payload.encode(bytes);
490            }
491        }
492    }
493
494    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
495        let typ = CertificateStatusType::read(r)?;
496
497        match typ {
498            CertificateStatusType::OCSP => {
499                let ocsp_req = OcspCertificateStatusRequest::read(r)?;
500                Ok(Self::Ocsp(ocsp_req))
501            }
502            _ => {
503                let data = Payload::read(r).into_owned();
504                Ok(Self::Unknown((typ, data)))
505            }
506        }
507    }
508}
509
510impl CertificateStatusRequest {
511    pub(crate) fn build_ocsp() -> Self {
512        let ocsp = OcspCertificateStatusRequest {
513            responder_ids: Vec::new(),
514            extensions: PayloadU16::empty(),
515        };
516        Self::Ocsp(ocsp)
517    }
518}
519
520// ---
521
522impl TlsListElement for PSKKeyExchangeMode {
523    const SIZE_LEN: ListLength = ListLength::U8;
524}
525
526impl TlsListElement for KeyShareEntry {
527    const SIZE_LEN: ListLength = ListLength::U16;
528}
529
530impl TlsListElement for ProtocolVersion {
531    const SIZE_LEN: ListLength = ListLength::U8;
532}
533
534#[derive(Clone, Debug)]
535pub enum ClientExtension {
536    EcPointFormats(Vec<ECPointFormat>),
537    NamedGroups(Vec<NamedGroup>),
538    SignatureAlgorithms(Vec<SignatureScheme>),
539    ServerName(Vec<ServerName>),
540    SessionTicket(ClientSessionTicket),
541    Protocols(Vec<ProtocolName>),
542    SupportedVersions(Vec<ProtocolVersion>),
543    KeyShare(Vec<KeyShareEntry>),
544    PresharedKeyModes(Vec<PSKKeyExchangeMode>),
545    PresharedKey(PresharedKeyOffer),
546    Cookie(PayloadU16),
547    ExtendedMasterSecretRequest,
548    CertificateStatusRequest(CertificateStatusRequest),
549    TransportParameters(Vec<u8>),
550    TransportParametersDraft(Vec<u8>),
551    EarlyData,
552    Unknown(UnknownExtension),
553}
554
555impl ClientExtension {
556    pub(crate) fn ext_type(&self) -> ExtensionType {
557        match *self {
558            Self::EcPointFormats(_) => ExtensionType::ECPointFormats,
559            Self::NamedGroups(_) => ExtensionType::EllipticCurves,
560            Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms,
561            Self::ServerName(_) => ExtensionType::ServerName,
562            Self::SessionTicket(_) => ExtensionType::SessionTicket,
563            Self::Protocols(_) => ExtensionType::ALProtocolNegotiation,
564            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
565            Self::KeyShare(_) => ExtensionType::KeyShare,
566            Self::PresharedKeyModes(_) => ExtensionType::PSKKeyExchangeModes,
567            Self::PresharedKey(_) => ExtensionType::PreSharedKey,
568            Self::Cookie(_) => ExtensionType::Cookie,
569            Self::ExtendedMasterSecretRequest => ExtensionType::ExtendedMasterSecret,
570            Self::CertificateStatusRequest(_) => ExtensionType::StatusRequest,
571            Self::TransportParameters(_) => ExtensionType::TransportParameters,
572            Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft,
573            Self::EarlyData => ExtensionType::EarlyData,
574            Self::Unknown(ref r) => r.typ,
575        }
576    }
577}
578
579impl Codec<'_> for ClientExtension {
580    fn encode(&self, bytes: &mut Vec<u8>) {
581        self.ext_type().encode(bytes);
582
583        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
584        match *self {
585            Self::EcPointFormats(ref r) => r.encode(nested.buf),
586            Self::NamedGroups(ref r) => r.encode(nested.buf),
587            Self::SignatureAlgorithms(ref r) => r.encode(nested.buf),
588            Self::ServerName(ref r) => r.encode(nested.buf),
589            Self::SessionTicket(ClientSessionTicket::Request)
590            | Self::ExtendedMasterSecretRequest
591            | Self::EarlyData => {}
592            Self::SessionTicket(ClientSessionTicket::Offer(ref r)) => r.encode(nested.buf),
593            Self::Protocols(ref r) => r.encode(nested.buf),
594            Self::SupportedVersions(ref r) => r.encode(nested.buf),
595            Self::KeyShare(ref r) => r.encode(nested.buf),
596            Self::PresharedKeyModes(ref r) => r.encode(nested.buf),
597            Self::PresharedKey(ref r) => r.encode(nested.buf),
598            Self::Cookie(ref r) => r.encode(nested.buf),
599            Self::CertificateStatusRequest(ref r) => r.encode(nested.buf),
600            Self::TransportParameters(ref r) | Self::TransportParametersDraft(ref r) => {
601                nested.buf.extend_from_slice(r);
602            }
603            Self::Unknown(ref r) => r.encode(nested.buf),
604        }
605    }
606
607    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
608        let typ = ExtensionType::read(r)?;
609        let len = u16::read(r)? as usize;
610        let mut sub = r.sub(len)?;
611
612        let ext = match typ {
613            ExtensionType::ECPointFormats => Self::EcPointFormats(Vec::read(&mut sub)?),
614            ExtensionType::EllipticCurves => Self::NamedGroups(Vec::read(&mut sub)?),
615            ExtensionType::SignatureAlgorithms => Self::SignatureAlgorithms(Vec::read(&mut sub)?),
616            ExtensionType::ServerName => Self::ServerName(Vec::read(&mut sub)?),
617            ExtensionType::SessionTicket => {
618                if sub.any_left() {
619                    let contents = Payload::read(&mut sub).into_owned();
620                    Self::SessionTicket(ClientSessionTicket::Offer(contents))
621                } else {
622                    Self::SessionTicket(ClientSessionTicket::Request)
623                }
624            }
625            ExtensionType::ALProtocolNegotiation => Self::Protocols(Vec::read(&mut sub)?),
626            ExtensionType::SupportedVersions => Self::SupportedVersions(Vec::read(&mut sub)?),
627            ExtensionType::KeyShare => Self::KeyShare(Vec::read(&mut sub)?),
628            ExtensionType::PSKKeyExchangeModes => Self::PresharedKeyModes(Vec::read(&mut sub)?),
629            ExtensionType::PreSharedKey => Self::PresharedKey(PresharedKeyOffer::read(&mut sub)?),
630            ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?),
631            ExtensionType::ExtendedMasterSecret if !sub.any_left() => {
632                Self::ExtendedMasterSecretRequest
633            }
634            ExtensionType::StatusRequest => {
635                let csr = CertificateStatusRequest::read(&mut sub)?;
636                Self::CertificateStatusRequest(csr)
637            }
638            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
639            ExtensionType::TransportParametersDraft => {
640                Self::TransportParametersDraft(sub.rest().to_vec())
641            }
642            ExtensionType::EarlyData if !sub.any_left() => Self::EarlyData,
643            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
644        };
645
646        sub.expect_empty("ClientExtension")
647            .map(|_| ext)
648    }
649}
650
651fn trim_hostname_trailing_dot_for_sni(dns_name: &DnsName<'_>) -> DnsName<'static> {
652    let dns_name_str = dns_name.as_ref();
653
654    // RFC6066: "The hostname is represented as a byte string using
655    // ASCII encoding without a trailing dot"
656    if dns_name_str.ends_with('.') {
657        let trimmed = &dns_name_str[0..dns_name_str.len() - 1];
658        DnsName::try_from(trimmed)
659            .unwrap()
660            .to_owned()
661    } else {
662        dns_name.to_owned()
663    }
664}
665
666impl ClientExtension {
667    /// Make a basic SNI ServerNameRequest quoting `hostname`.
668    pub(crate) fn make_sni(dns_name: &DnsName<'_>) -> Self {
669        let name = ServerName {
670            typ: ServerNameType::HostName,
671            payload: ServerNamePayload::new_hostname(trim_hostname_trailing_dot_for_sni(dns_name)),
672        };
673
674        Self::ServerName(vec![name])
675    }
676}
677
678#[derive(Clone, Debug)]
679pub enum ClientSessionTicket {
680    Request,
681    Offer(Payload<'static>),
682}
683
684#[derive(Clone, Debug)]
685pub enum ServerExtension {
686    EcPointFormats(Vec<ECPointFormat>),
687    ServerNameAck,
688    SessionTicketAck,
689    RenegotiationInfo(PayloadU8),
690    Protocols(Vec<ProtocolName>),
691    KeyShare(KeyShareEntry),
692    PresharedKey(u16),
693    ExtendedMasterSecretAck,
694    CertificateStatusAck,
695    SupportedVersions(ProtocolVersion),
696    TransportParameters(Vec<u8>),
697    TransportParametersDraft(Vec<u8>),
698    EarlyData,
699    Unknown(UnknownExtension),
700}
701
702impl ServerExtension {
703    pub(crate) fn ext_type(&self) -> ExtensionType {
704        match *self {
705            Self::EcPointFormats(_) => ExtensionType::ECPointFormats,
706            Self::ServerNameAck => ExtensionType::ServerName,
707            Self::SessionTicketAck => ExtensionType::SessionTicket,
708            Self::RenegotiationInfo(_) => ExtensionType::RenegotiationInfo,
709            Self::Protocols(_) => ExtensionType::ALProtocolNegotiation,
710            Self::KeyShare(_) => ExtensionType::KeyShare,
711            Self::PresharedKey(_) => ExtensionType::PreSharedKey,
712            Self::ExtendedMasterSecretAck => ExtensionType::ExtendedMasterSecret,
713            Self::CertificateStatusAck => ExtensionType::StatusRequest,
714            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
715            Self::TransportParameters(_) => ExtensionType::TransportParameters,
716            Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft,
717            Self::EarlyData => ExtensionType::EarlyData,
718            Self::Unknown(ref r) => r.typ,
719        }
720    }
721}
722
723impl Codec<'_> for ServerExtension {
724    fn encode(&self, bytes: &mut Vec<u8>) {
725        self.ext_type().encode(bytes);
726
727        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
728        match *self {
729            Self::EcPointFormats(ref r) => r.encode(nested.buf),
730            Self::ServerNameAck
731            | Self::SessionTicketAck
732            | Self::ExtendedMasterSecretAck
733            | Self::CertificateStatusAck
734            | Self::EarlyData => {}
735            Self::RenegotiationInfo(ref r) => r.encode(nested.buf),
736            Self::Protocols(ref r) => r.encode(nested.buf),
737            Self::KeyShare(ref r) => r.encode(nested.buf),
738            Self::PresharedKey(r) => r.encode(nested.buf),
739            Self::SupportedVersions(ref r) => r.encode(nested.buf),
740            Self::TransportParameters(ref r) | Self::TransportParametersDraft(ref r) => {
741                nested.buf.extend_from_slice(r);
742            }
743            Self::Unknown(ref r) => r.encode(nested.buf),
744        }
745    }
746
747    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
748        let typ = ExtensionType::read(r)?;
749        let len = u16::read(r)? as usize;
750        let mut sub = r.sub(len)?;
751
752        let ext = match typ {
753            ExtensionType::ECPointFormats => Self::EcPointFormats(Vec::read(&mut sub)?),
754            ExtensionType::ServerName => Self::ServerNameAck,
755            ExtensionType::SessionTicket => Self::SessionTicketAck,
756            ExtensionType::StatusRequest => Self::CertificateStatusAck,
757            ExtensionType::RenegotiationInfo => Self::RenegotiationInfo(PayloadU8::read(&mut sub)?),
758            ExtensionType::ALProtocolNegotiation => Self::Protocols(Vec::read(&mut sub)?),
759            ExtensionType::KeyShare => Self::KeyShare(KeyShareEntry::read(&mut sub)?),
760            ExtensionType::PreSharedKey => Self::PresharedKey(u16::read(&mut sub)?),
761            ExtensionType::ExtendedMasterSecret => Self::ExtendedMasterSecretAck,
762            ExtensionType::SupportedVersions => {
763                Self::SupportedVersions(ProtocolVersion::read(&mut sub)?)
764            }
765            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
766            ExtensionType::TransportParametersDraft => {
767                Self::TransportParametersDraft(sub.rest().to_vec())
768            }
769            ExtensionType::EarlyData => Self::EarlyData,
770            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
771        };
772
773        sub.expect_empty("ServerExtension")
774            .map(|_| ext)
775    }
776}
777
778impl ServerExtension {
779    pub(crate) fn make_alpn(proto: &[&[u8]]) -> Self {
780        Self::Protocols(Vec::from_slices(proto))
781    }
782
783    #[cfg(feature = "tls12")]
784    pub(crate) fn make_empty_renegotiation_info() -> Self {
785        let empty = Vec::new();
786        Self::RenegotiationInfo(PayloadU8::new(empty))
787    }
788}
789
790#[derive(Clone, Debug)]
791pub struct ClientHelloPayload {
792    pub client_version: ProtocolVersion,
793    pub random: Random,
794    pub session_id: SessionId,
795    pub cipher_suites: Vec<CipherSuite>,
796    pub compression_methods: Vec<Compression>,
797    pub extensions: Vec<ClientExtension>,
798}
799
800impl Codec<'_> for ClientHelloPayload {
801    fn encode(&self, bytes: &mut Vec<u8>) {
802        self.client_version.encode(bytes);
803        self.random.encode(bytes);
804        self.session_id.encode(bytes);
805        self.cipher_suites.encode(bytes);
806        self.compression_methods.encode(bytes);
807
808        if !self.extensions.is_empty() {
809            self.extensions.encode(bytes);
810        }
811    }
812
813    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
814        let mut ret = Self {
815            client_version: ProtocolVersion::read(r)?,
816            random: Random::read(r)?,
817            session_id: SessionId::read(r)?,
818            cipher_suites: Vec::read(r)?,
819            compression_methods: Vec::read(r)?,
820            extensions: Vec::new(),
821        };
822
823        if r.any_left() {
824            ret.extensions = Vec::read(r)?;
825        }
826
827        match (r.any_left(), ret.extensions.is_empty()) {
828            (true, _) => Err(InvalidMessage::TrailingData("ClientHelloPayload")),
829            (_, true) => Err(InvalidMessage::MissingData("ClientHelloPayload")),
830            _ => Ok(ret),
831        }
832    }
833}
834
835impl TlsListElement for CipherSuite {
836    const SIZE_LEN: ListLength = ListLength::U16;
837}
838
839impl TlsListElement for Compression {
840    const SIZE_LEN: ListLength = ListLength::U8;
841}
842
843impl TlsListElement for ClientExtension {
844    const SIZE_LEN: ListLength = ListLength::U16;
845}
846
847impl ClientHelloPayload {
848    /// Returns true if there is more than one extension of a given
849    /// type.
850    pub(crate) fn has_duplicate_extension(&self) -> bool {
851        has_duplicates::<_, _, u16>(
852            self.extensions
853                .iter()
854                .map(|ext| ext.ext_type()),
855        )
856    }
857
858    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&ClientExtension> {
859        self.extensions
860            .iter()
861            .find(|x| x.ext_type() == ext)
862    }
863
864    pub(crate) fn sni_extension(&self) -> Option<&[ServerName]> {
865        let ext = self.find_extension(ExtensionType::ServerName)?;
866        match *ext {
867            ClientExtension::ServerName(ref req) => Some(req),
868            _ => None,
869        }
870    }
871
872    pub fn sigalgs_extension(&self) -> Option<&[SignatureScheme]> {
873        let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?;
874        match *ext {
875            ClientExtension::SignatureAlgorithms(ref req) => Some(req),
876            _ => None,
877        }
878    }
879
880    pub(crate) fn namedgroups_extension(&self) -> Option<&[NamedGroup]> {
881        let ext = self.find_extension(ExtensionType::EllipticCurves)?;
882        match *ext {
883            ClientExtension::NamedGroups(ref req) => Some(req),
884            _ => None,
885        }
886    }
887
888    #[cfg(feature = "tls12")]
889    pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
890        let ext = self.find_extension(ExtensionType::ECPointFormats)?;
891        match *ext {
892            ClientExtension::EcPointFormats(ref req) => Some(req),
893            _ => None,
894        }
895    }
896
897    pub(crate) fn alpn_extension(&self) -> Option<&Vec<ProtocolName>> {
898        let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?;
899        match *ext {
900            ClientExtension::Protocols(ref req) => Some(req),
901            _ => None,
902        }
903    }
904
905    pub(crate) fn quic_params_extension(&self) -> Option<Vec<u8>> {
906        let ext = self
907            .find_extension(ExtensionType::TransportParameters)
908            .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?;
909        match *ext {
910            ClientExtension::TransportParameters(ref bytes)
911            | ClientExtension::TransportParametersDraft(ref bytes) => Some(bytes.to_vec()),
912            _ => None,
913        }
914    }
915
916    #[cfg(feature = "tls12")]
917    pub(crate) fn ticket_extension(&self) -> Option<&ClientExtension> {
918        self.find_extension(ExtensionType::SessionTicket)
919    }
920
921    pub(crate) fn versions_extension(&self) -> Option<&[ProtocolVersion]> {
922        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
923        match *ext {
924            ClientExtension::SupportedVersions(ref vers) => Some(vers),
925            _ => None,
926        }
927    }
928
929    pub fn keyshare_extension(&self) -> Option<&[KeyShareEntry]> {
930        let ext = self.find_extension(ExtensionType::KeyShare)?;
931        match *ext {
932            ClientExtension::KeyShare(ref shares) => Some(shares),
933            _ => None,
934        }
935    }
936
937    pub(crate) fn has_keyshare_extension_with_duplicates(&self) -> bool {
938        if let Some(entries) = self.keyshare_extension() {
939            let mut seen = BTreeSet::new();
940
941            for kse in entries {
942                let grp = u16::from(kse.group);
943
944                if !seen.insert(grp) {
945                    return true;
946                }
947            }
948        }
949
950        false
951    }
952
953    pub(crate) fn psk(&self) -> Option<&PresharedKeyOffer> {
954        let ext = self.find_extension(ExtensionType::PreSharedKey)?;
955        match *ext {
956            ClientExtension::PresharedKey(ref psk) => Some(psk),
957            _ => None,
958        }
959    }
960
961    pub(crate) fn check_psk_ext_is_last(&self) -> bool {
962        self.extensions
963            .last()
964            .map_or(false, |ext| ext.ext_type() == ExtensionType::PreSharedKey)
965    }
966
967    pub(crate) fn psk_modes(&self) -> Option<&[PSKKeyExchangeMode]> {
968        let ext = self.find_extension(ExtensionType::PSKKeyExchangeModes)?;
969        match *ext {
970            ClientExtension::PresharedKeyModes(ref psk_modes) => Some(psk_modes),
971            _ => None,
972        }
973    }
974
975    pub(crate) fn psk_mode_offered(&self, mode: PSKKeyExchangeMode) -> bool {
976        self.psk_modes()
977            .map(|modes| modes.contains(&mode))
978            .unwrap_or(false)
979    }
980
981    pub(crate) fn set_psk_binder(&mut self, binder: impl Into<Vec<u8>>) {
982        let last_extension = self.extensions.last_mut();
983        if let Some(ClientExtension::PresharedKey(ref mut offer)) = last_extension {
984            offer.binders[0] = PresharedKeyBinder::from(binder.into());
985        }
986    }
987
988    #[cfg(feature = "tls12")]
989    pub(crate) fn ems_support_offered(&self) -> bool {
990        self.find_extension(ExtensionType::ExtendedMasterSecret)
991            .is_some()
992    }
993
994    pub(crate) fn early_data_extension_offered(&self) -> bool {
995        self.find_extension(ExtensionType::EarlyData)
996            .is_some()
997    }
998}
999
/// An extension that may appear in a HelloRetryRequest.
///
/// `KeyShare`, `Cookie` and `SupportedVersions` are decoded into typed values;
/// any other extension type is kept as an [`UnknownExtension`].
#[derive(Clone, Debug)]
pub(crate) enum HelloRetryExtension {
    /// The group the server asks the client to send a key share for.
    KeyShare(NamedGroup),
    /// Opaque cookie bytes held in a `PayloadU16`.
    Cookie(PayloadU16),
    /// The single protocol version carried by the extension.
    SupportedVersions(ProtocolVersion),
    /// Any extension type not listed above.
    Unknown(UnknownExtension),
}
1007
1008impl HelloRetryExtension {
1009    pub(crate) fn ext_type(&self) -> ExtensionType {
1010        match *self {
1011            Self::KeyShare(_) => ExtensionType::KeyShare,
1012            Self::Cookie(_) => ExtensionType::Cookie,
1013            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
1014            Self::Unknown(ref r) => r.typ,
1015        }
1016    }
1017}
1018
1019impl Codec<'_> for HelloRetryExtension {
1020    fn encode(&self, bytes: &mut Vec<u8>) {
1021        self.ext_type().encode(bytes);
1022
1023        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1024        match *self {
1025            Self::KeyShare(ref r) => r.encode(nested.buf),
1026            Self::Cookie(ref r) => r.encode(nested.buf),
1027            Self::SupportedVersions(ref r) => r.encode(nested.buf),
1028            Self::Unknown(ref r) => r.encode(nested.buf),
1029        }
1030    }
1031
1032    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1033        let typ = ExtensionType::read(r)?;
1034        let len = u16::read(r)? as usize;
1035        let mut sub = r.sub(len)?;
1036
1037        let ext = match typ {
1038            ExtensionType::KeyShare => Self::KeyShare(NamedGroup::read(&mut sub)?),
1039            ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?),
1040            ExtensionType::SupportedVersions => {
1041                Self::SupportedVersions(ProtocolVersion::read(&mut sub)?)
1042            }
1043            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1044        };
1045
1046        sub.expect_empty("HelloRetryExtension")
1047            .map(|_| ext)
1048    }
1049}
1050
// Lists of HelloRetryRequest extensions use a u16 length prefix.
impl TlsListElement for HelloRetryExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
1054
/// A HelloRetryRequest message.
///
/// It is encoded with the same field layout as a ServerHello. The random is
/// always `HELLO_RETRY_REQUEST_RANDOM` and the compression method is always null.
#[derive(Debug)]
pub struct HelloRetryRequest {
    /// Set to the placeholder `ProtocolVersion::Unknown(0)` by `Codec::read`,
    /// which does not parse this field.
    pub(crate) legacy_version: ProtocolVersion,
    /// Session id echoed in the message.
    pub session_id: SessionId,
    pub(crate) cipher_suite: CipherSuite,
    pub(crate) extensions: Vec<HelloRetryExtension>,
}
1062
1063impl Codec<'_> for HelloRetryRequest {
1064    fn encode(&self, bytes: &mut Vec<u8>) {
1065        self.legacy_version.encode(bytes);
1066        HELLO_RETRY_REQUEST_RANDOM.encode(bytes);
1067        self.session_id.encode(bytes);
1068        self.cipher_suite.encode(bytes);
1069        Compression::Null.encode(bytes);
1070        self.extensions.encode(bytes);
1071    }
1072
1073    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1074        let session_id = SessionId::read(r)?;
1075        let cipher_suite = CipherSuite::read(r)?;
1076        let compression = Compression::read(r)?;
1077
1078        if compression != Compression::Null {
1079            return Err(InvalidMessage::UnsupportedCompression);
1080        }
1081
1082        Ok(Self {
1083            legacy_version: ProtocolVersion::Unknown(0),
1084            session_id,
1085            cipher_suite,
1086            extensions: Vec::read(r)?,
1087        })
1088    }
1089}
1090
1091impl HelloRetryRequest {
1092    /// Returns true if there is more than one extension of a given
1093    /// type.
1094    pub(crate) fn has_duplicate_extension(&self) -> bool {
1095        has_duplicates::<_, _, u16>(
1096            self.extensions
1097                .iter()
1098                .map(|ext| ext.ext_type()),
1099        )
1100    }
1101
1102    pub(crate) fn has_unknown_extension(&self) -> bool {
1103        self.extensions.iter().any(|ext| {
1104            ext.ext_type() != ExtensionType::KeyShare
1105                && ext.ext_type() != ExtensionType::SupportedVersions
1106                && ext.ext_type() != ExtensionType::Cookie
1107        })
1108    }
1109
1110    fn find_extension(&self, ext: ExtensionType) -> Option<&HelloRetryExtension> {
1111        self.extensions
1112            .iter()
1113            .find(|x| x.ext_type() == ext)
1114    }
1115
1116    pub fn requested_key_share_group(&self) -> Option<NamedGroup> {
1117        let ext = self.find_extension(ExtensionType::KeyShare)?;
1118        match *ext {
1119            HelloRetryExtension::KeyShare(grp) => Some(grp),
1120            _ => None,
1121        }
1122    }
1123
1124    pub(crate) fn cookie(&self) -> Option<&PayloadU16> {
1125        let ext = self.find_extension(ExtensionType::Cookie)?;
1126        match *ext {
1127            HelloRetryExtension::Cookie(ref ck) => Some(ck),
1128            _ => None,
1129        }
1130    }
1131
1132    pub(crate) fn supported_versions(&self) -> Option<ProtocolVersion> {
1133        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1134        match *ext {
1135            HelloRetryExtension::SupportedVersions(ver) => Some(ver),
1136            _ => None,
1137        }
1138    }
1139}
1140
/// The body of a ServerHello message.
#[derive(Clone, Debug)]
pub struct ServerHelloPayload {
    /// Placeholder `ProtocolVersion::Unknown(0)` after `Codec::read`, which does not parse it.
    pub(crate) legacy_version: ProtocolVersion,
    /// `ZERO_RANDOM` after `Codec::read`, which does not parse it.
    pub(crate) random: Random,
    pub(crate) session_id: SessionId,
    pub(crate) cipher_suite: CipherSuite,
    pub(crate) compression_method: Compression,
    /// Omitted from the encoding entirely when empty.
    pub(crate) extensions: Vec<ServerExtension>,
}
1150
1151impl Codec<'_> for ServerHelloPayload {
1152    fn encode(&self, bytes: &mut Vec<u8>) {
1153        self.legacy_version.encode(bytes);
1154        self.random.encode(bytes);
1155
1156        self.session_id.encode(bytes);
1157        self.cipher_suite.encode(bytes);
1158        self.compression_method.encode(bytes);
1159
1160        if !self.extensions.is_empty() {
1161            self.extensions.encode(bytes);
1162        }
1163    }
1164
1165    // minus version and random, which have already been read.
1166    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1167        let session_id = SessionId::read(r)?;
1168        let suite = CipherSuite::read(r)?;
1169        let compression = Compression::read(r)?;
1170
1171        // RFC5246:
1172        // "The presence of extensions can be detected by determining whether
1173        //  there are bytes following the compression_method field at the end of
1174        //  the ServerHello."
1175        let extensions = if r.any_left() { Vec::read(r)? } else { vec![] };
1176
1177        let ret = Self {
1178            legacy_version: ProtocolVersion::Unknown(0),
1179            random: ZERO_RANDOM,
1180            session_id,
1181            cipher_suite: suite,
1182            compression_method: compression,
1183            extensions,
1184        };
1185
1186        r.expect_empty("ServerHelloPayload")
1187            .map(|_| ret)
1188    }
1189}
1190
1191impl HasServerExtensions for ServerHelloPayload {
1192    fn extensions(&self) -> &[ServerExtension] {
1193        &self.extensions
1194    }
1195}
1196
1197impl ServerHelloPayload {
1198    pub(crate) fn key_share(&self) -> Option<&KeyShareEntry> {
1199        let ext = self.find_extension(ExtensionType::KeyShare)?;
1200        match *ext {
1201            ServerExtension::KeyShare(ref share) => Some(share),
1202            _ => None,
1203        }
1204    }
1205
1206    pub(crate) fn psk_index(&self) -> Option<u16> {
1207        let ext = self.find_extension(ExtensionType::PreSharedKey)?;
1208        match *ext {
1209            ServerExtension::PresharedKey(ref index) => Some(*index),
1210            _ => None,
1211        }
1212    }
1213
1214    pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
1215        let ext = self.find_extension(ExtensionType::ECPointFormats)?;
1216        match *ext {
1217            ServerExtension::EcPointFormats(ref fmts) => Some(fmts),
1218            _ => None,
1219        }
1220    }
1221
1222    #[cfg(feature = "tls12")]
1223    pub(crate) fn ems_support_acked(&self) -> bool {
1224        self.find_extension(ExtensionType::ExtendedMasterSecret)
1225            .is_some()
1226    }
1227
1228    pub(crate) fn supported_versions(&self) -> Option<ProtocolVersion> {
1229        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1230        match *ext {
1231            ServerExtension::SupportedVersions(vers) => Some(vers),
1232            _ => None,
1233        }
1234    }
1235}
1236
/// A chain of certificates, encoded as a TLS list of [`CertificateDer`] values.
///
/// Dereferences to a slice of the certificates.
#[derive(Clone, Default, Debug)]
pub struct CertificateChain<'a>(pub Vec<CertificateDer<'a>>);
1239
1240impl CertificateChain<'_> {
1241    pub(crate) fn into_owned(self) -> CertificateChain<'static> {
1242        CertificateChain(
1243            self.0
1244                .into_iter()
1245                .map(|c| c.into_owned())
1246                .collect(),
1247        )
1248    }
1249}
1250
1251impl<'a> Codec<'a> for CertificateChain<'a> {
1252    fn encode(&self, bytes: &mut Vec<u8>) {
1253        Vec::encode(&self.0, bytes)
1254    }
1255
1256    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1257        Vec::read(r).map(Self)
1258    }
1259}
1260
1261impl<'a> Deref for CertificateChain<'a> {
1262    type Target = [CertificateDer<'a>];
1263
1264    fn deref(&self) -> &[CertificateDer<'a>] {
1265        &self.0
1266    }
1267}
1268
// Certificate lists use a u24 length prefix. NOTE(review): `max` presumably
// bounds the accepted encoded size when decoding; confirm against `ListLength`.
impl TlsListElement for CertificateDer<'_> {
    const SIZE_LEN: ListLength = ListLength::U24 { max: 0x1_0000 };
}
1272
// TLS 1.3 changes the Certificate payload encoding: the certificate list is
// preceded by a request context, and each certificate carries its own
// extensions. As a result, parsing is no longer context-free.
1276
/// An extension attached to one entry of a TLS 1.3 Certificate message.
#[derive(Debug)]
pub(crate) enum CertificateExtension {
    /// Decoded from the `StatusRequest` extension type; carries an OCSP response.
    CertificateStatus(CertificateStatus),
    /// Any other extension type.
    Unknown(UnknownExtension),
}
1282
1283impl CertificateExtension {
1284    pub(crate) fn ext_type(&self) -> ExtensionType {
1285        match *self {
1286            Self::CertificateStatus(_) => ExtensionType::StatusRequest,
1287            Self::Unknown(ref r) => r.typ,
1288        }
1289    }
1290
1291    pub(crate) fn cert_status(&self) -> Option<&Vec<u8>> {
1292        match *self {
1293            Self::CertificateStatus(ref cs) => Some(&cs.ocsp_response.0),
1294            _ => None,
1295        }
1296    }
1297}
1298
1299impl Codec<'_> for CertificateExtension {
1300    fn encode(&self, bytes: &mut Vec<u8>) {
1301        self.ext_type().encode(bytes);
1302
1303        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1304        match *self {
1305            Self::CertificateStatus(ref r) => r.encode(nested.buf),
1306            Self::Unknown(ref r) => r.encode(nested.buf),
1307        }
1308    }
1309
1310    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1311        let typ = ExtensionType::read(r)?;
1312        let len = u16::read(r)? as usize;
1313        let mut sub = r.sub(len)?;
1314
1315        let ext = match typ {
1316            ExtensionType::StatusRequest => {
1317                let st = CertificateStatus::read(&mut sub)?;
1318                Self::CertificateStatus(st)
1319            }
1320            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1321        };
1322
1323        sub.expect_empty("CertificateExtension")
1324            .map(|_| ext)
1325    }
1326}
1327
// Per-certificate extension lists use a u16 length prefix.
impl TlsListElement for CertificateExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
1331
/// One certificate in a TLS 1.3 Certificate message, with its own extensions.
#[derive(Debug)]
pub(crate) struct CertificateEntry {
    /// The certificate, owned (`'static`) so the entry does not borrow the input.
    pub(crate) cert: CertificateDer<'static>,
    /// Extensions attached to this certificate.
    pub(crate) exts: Vec<CertificateExtension>,
}
1337
1338impl Codec<'_> for CertificateEntry {
1339    fn encode(&self, bytes: &mut Vec<u8>) {
1340        self.cert.encode(bytes);
1341        self.exts.encode(bytes);
1342    }
1343
1344    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1345        Ok(Self {
1346            cert: CertificateDer::read(r)?.into_owned(),
1347            exts: Vec::read(r)?,
1348        })
1349    }
1350}
1351
1352impl CertificateEntry {
1353    pub(crate) fn new(cert: CertificateDer<'static>) -> Self {
1354        Self {
1355            cert,
1356            exts: Vec::new(),
1357        }
1358    }
1359
1360    pub(crate) fn has_duplicate_extension(&self) -> bool {
1361        has_duplicates::<_, _, u16>(
1362            self.exts
1363                .iter()
1364                .map(|ext| ext.ext_type()),
1365        )
1366    }
1367
1368    pub(crate) fn has_unknown_extension(&self) -> bool {
1369        self.exts
1370            .iter()
1371            .any(|ext| ext.ext_type() != ExtensionType::StatusRequest)
1372    }
1373
1374    pub(crate) fn ocsp_response(&self) -> Option<&Vec<u8>> {
1375        self.exts
1376            .iter()
1377            .find(|ext| ext.ext_type() == ExtensionType::StatusRequest)
1378            .and_then(CertificateExtension::cert_status)
1379    }
1380}
1381
// TLS 1.3 certificate entry lists use a u24 length prefix.
impl TlsListElement for CertificateEntry {
    const SIZE_LEN: ListLength = ListLength::U24 { max: 0x1_0000 };
}
1385
/// The body of a TLS 1.3 Certificate message.
#[derive(Debug)]
pub struct CertificatePayloadTls13 {
    /// Certificate request context; encoded before the entries.
    pub(crate) context: PayloadU8,
    /// Certificate entries; the first one is treated as the end-entity certificate.
    pub(crate) entries: Vec<CertificateEntry>,
}
1391
1392impl Codec<'_> for CertificatePayloadTls13 {
1393    fn encode(&self, bytes: &mut Vec<u8>) {
1394        self.context.encode(bytes);
1395        self.entries.encode(bytes);
1396    }
1397
1398    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1399        Ok(Self {
1400            context: PayloadU8::read(r)?,
1401            entries: Vec::read(r)?,
1402        })
1403    }
1404}
1405
1406impl CertificatePayloadTls13 {
1407    pub(crate) fn new(entries: Vec<CertificateEntry>) -> Self {
1408        Self {
1409            context: PayloadU8::empty(),
1410            entries,
1411        }
1412    }
1413
1414    pub(crate) fn any_entry_has_duplicate_extension(&self) -> bool {
1415        for entry in &self.entries {
1416            if entry.has_duplicate_extension() {
1417                return true;
1418            }
1419        }
1420
1421        false
1422    }
1423
1424    pub(crate) fn any_entry_has_unknown_extension(&self) -> bool {
1425        for entry in &self.entries {
1426            if entry.has_unknown_extension() {
1427                return true;
1428            }
1429        }
1430
1431        false
1432    }
1433
1434    pub(crate) fn any_entry_has_extension(&self) -> bool {
1435        for entry in &self.entries {
1436            if !entry.exts.is_empty() {
1437                return true;
1438            }
1439        }
1440
1441        false
1442    }
1443
1444    pub(crate) fn end_entity_ocsp(&self) -> Vec<u8> {
1445        self.entries
1446            .first()
1447            .and_then(CertificateEntry::ocsp_response)
1448            .cloned()
1449            .unwrap_or_default()
1450    }
1451
1452    pub(crate) fn convert(self) -> CertificateChain<'static> {
1453        CertificateChain(
1454            self.entries
1455                .into_iter()
1456                .map(|e| e.cert)
1457                .collect(),
1458        )
1459    }
1460}
1461
/// Describes supported key exchange mechanisms.
#[derive(Clone, Copy, Debug, PartialEq)]
#[non_exhaustive]
pub enum KeyExchangeAlgorithm {
    /// Finite-field Diffie-Hellman key exchange, using only the known
    /// parameters defined in [RFC 7919].
    ///
    /// [RFC 7919]: https://datatracker.ietf.org/doc/html/rfc7919
    DHE,
    /// Key exchange performed via elliptic curve Diffie-Hellman.
    ECDHE,
}
1473
/// Every [`KeyExchangeAlgorithm`] variant, with ECDHE listed first.
pub(crate) static ALL_KEY_EXCHANGE_ALGORITHMS: &[KeyExchangeAlgorithm] =
    &[KeyExchangeAlgorithm::ECDHE, KeyExchangeAlgorithm::DHE];
1476
// Only named curves are supported. Accepting arbitrary explicit curve
// parameters would be a bad idea and needless attack surface.
/// Identifies the elliptic curve used in an ECDHE exchange.
#[derive(Debug)]
pub(crate) struct EcParameters {
    /// Always `ECCurveType::NamedCurve` for values produced by `Codec::read`.
    pub(crate) curve_type: ECCurveType,
    /// The named group (curve) in use.
    pub(crate) named_group: NamedGroup,
}
1485
1486impl Codec<'_> for EcParameters {
1487    fn encode(&self, bytes: &mut Vec<u8>) {
1488        self.curve_type.encode(bytes);
1489        self.named_group.encode(bytes);
1490    }
1491
1492    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1493        let ct = ECCurveType::read(r)?;
1494        if ct != ECCurveType::NamedCurve {
1495            return Err(InvalidMessage::UnsupportedCurveType);
1496        }
1497
1498        let grp = NamedGroup::read(r)?;
1499
1500        Ok(Self {
1501            curve_type: ct,
1502            named_group: grp,
1503        })
1504    }
1505}
1506
/// Key exchange messages whose wire format depends on the negotiated
/// [`KeyExchangeAlgorithm`].
#[cfg(feature = "tls12")]
pub(crate) trait KxDecode<'a>: fmt::Debug + Sized {
    /// Decodes a value from `r`, choosing the encoding according to `algo`.
    fn decode(r: &mut Reader<'a>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage>;
}
1512
/// The client's key exchange parameters: ECDHE or finite-field DHE.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) enum ClientKeyExchangeParams {
    /// Decoded when the algorithm is `KeyExchangeAlgorithm::ECDHE`.
    Ecdh(ClientEcdhParams),
    /// Decoded when the algorithm is `KeyExchangeAlgorithm::DHE`.
    Dh(ClientDhParams),
}
1519
#[cfg(feature = "tls12")]
impl ClientKeyExchangeParams {
    /// Returns the client's public value bytes, whichever variant is in use.
    pub(crate) fn pub_key(&self) -> &[u8] {
        let public: &[u8] = match self {
            Self::Ecdh(ecdh) => &ecdh.public.0,
            Self::Dh(dh) => &dh.public.0,
        };
        public
    }

    /// Encodes the inner parameters into `buf`.
    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
        match self {
            Self::Dh(dh) => dh.encode(buf),
            Self::Ecdh(ecdh) => ecdh.encode(buf),
        }
    }
}
1536
#[cfg(feature = "tls12")]
impl KxDecode<'_> for ClientKeyExchangeParams {
    /// Decodes the ECDHE or DHE form of the parameters, as selected by `algo`.
    fn decode(r: &mut Reader, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
        match algo {
            KeyExchangeAlgorithm::ECDHE => {
                ClientEcdhParams::read(r).map(ClientKeyExchangeParams::Ecdh)
            }
            KeyExchangeAlgorithm::DHE => ClientDhParams::read(r).map(ClientKeyExchangeParams::Dh),
        }
    }
}
1547
/// The client's ECDHE public value.
#[derive(Debug)]
pub(crate) struct ClientEcdhParams {
    /// Public value, held in a `PayloadU8`.
    pub(crate) public: PayloadU8,
}
1552
1553impl Codec<'_> for ClientEcdhParams {
1554    fn encode(&self, bytes: &mut Vec<u8>) {
1555        self.public.encode(bytes);
1556    }
1557
1558    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1559        let pb = PayloadU8::read(r)?;
1560        Ok(Self { public: pb })
1561    }
1562}
1563
/// The client's finite-field DHE public value.
#[derive(Debug)]
pub(crate) struct ClientDhParams {
    /// Public value, held in a `PayloadU16`.
    pub(crate) public: PayloadU16,
}
1568
1569impl Codec<'_> for ClientDhParams {
1570    fn encode(&self, bytes: &mut Vec<u8>) {
1571        self.public.encode(bytes);
1572    }
1573
1574    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1575        Ok(Self {
1576            public: PayloadU16::read(r)?,
1577        })
1578    }
1579}
1580
/// The server's ECDHE parameters: the curve followed by the public value.
#[derive(Debug)]
pub(crate) struct ServerEcdhParams {
    /// The (named) curve in use.
    pub(crate) curve_params: EcParameters,
    /// Public value, held in a `PayloadU8`.
    pub(crate) public: PayloadU8,
}
1586
1587impl ServerEcdhParams {
1588    #[cfg(feature = "tls12")]
1589    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
1590        Self {
1591            curve_params: EcParameters {
1592                curve_type: ECCurveType::NamedCurve,
1593                named_group: kx.group(),
1594            },
1595            public: PayloadU8::new(kx.pub_key().to_vec()),
1596        }
1597    }
1598}
1599
1600impl Codec<'_> for ServerEcdhParams {
1601    fn encode(&self, bytes: &mut Vec<u8>) {
1602        self.curve_params.encode(bytes);
1603        self.public.encode(bytes);
1604    }
1605
1606    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1607        let cp = EcParameters::read(r)?;
1608        let pb = PayloadU8::read(r)?;
1609
1610        Ok(Self {
1611            curve_params: cp,
1612            public: pb,
1613        })
1614    }
1615}
1616
/// The server's finite-field Diffie-Hellman parameters.
#[derive(Debug)]
// Permits the mixed-case field name `dh_Ys`.
#[allow(non_snake_case)]
pub(crate) struct ServerDhParams {
    /// The prime modulus.
    pub(crate) dh_p: PayloadU16,
    /// The generator.
    pub(crate) dh_g: PayloadU16,
    /// The server's public value.
    pub(crate) dh_Ys: PayloadU16,
}
1624
1625impl ServerDhParams {
1626    #[cfg(feature = "tls12")]
1627    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
1628        let params = match FfdheGroup::from_named_group(kx.group()) {
1629            Some(params) => params,
1630            None => panic!("invalid NamedGroup for DHE key exchange: {:?}", kx.group()),
1631        };
1632
1633        Self {
1634            dh_p: PayloadU16::new(params.p.to_vec()),
1635            dh_g: PayloadU16::new(params.g.to_vec()),
1636            dh_Ys: PayloadU16::new(kx.pub_key().to_vec()),
1637        }
1638    }
1639
1640    #[cfg(feature = "tls12")]
1641    fn named_group(&self) -> Option<NamedGroup> {
1642        FfdheGroup::from_params_trimming_leading_zeros(&self.dh_p.0, &self.dh_g.0).named_group()
1643    }
1644}
1645
1646impl Codec<'_> for ServerDhParams {
1647    fn encode(&self, bytes: &mut Vec<u8>) {
1648        self.dh_p.encode(bytes);
1649        self.dh_g.encode(bytes);
1650        self.dh_Ys.encode(bytes);
1651    }
1652
1653    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1654        Ok(Self {
1655            dh_p: PayloadU16::read(r)?,
1656            dh_g: PayloadU16::read(r)?,
1657            dh_Ys: PayloadU16::read(r)?,
1658        })
1659    }
1660}
1661
/// The server's key exchange parameters: ECDHE or finite-field DHE.
// NOTE(review): `dead_code` is presumably allowed because the constructors and
// decoder are only compiled with the `tls12` feature; confirm.
#[allow(dead_code)]
#[derive(Debug)]
pub(crate) enum ServerKeyExchangeParams {
    /// Elliptic-curve parameters.
    Ecdh(ServerEcdhParams),
    /// Finite-field parameters.
    Dh(ServerDhParams),
}
1668
1669impl ServerKeyExchangeParams {
1670    #[cfg(feature = "tls12")]
1671    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
1672        match kx.group().key_exchange_algorithm() {
1673            KeyExchangeAlgorithm::DHE => Self::Dh(ServerDhParams::new(kx)),
1674            KeyExchangeAlgorithm::ECDHE => Self::Ecdh(ServerEcdhParams::new(kx)),
1675        }
1676    }
1677
1678    #[cfg(feature = "tls12")]
1679    pub(crate) fn pub_key(&self) -> &[u8] {
1680        match self {
1681            Self::Ecdh(ecdh) => &ecdh.public.0,
1682            Self::Dh(dh) => &dh.dh_Ys.0,
1683        }
1684    }
1685
1686    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
1687        match self {
1688            Self::Ecdh(ecdh) => ecdh.encode(buf),
1689            Self::Dh(dh) => dh.encode(buf),
1690        }
1691    }
1692
1693    #[cfg(feature = "tls12")]
1694    pub(crate) fn named_group(&self) -> Option<NamedGroup> {
1695        match self {
1696            Self::Ecdh(ecdh) => Some(ecdh.curve_params.named_group),
1697            Self::Dh(dh) => dh.named_group(),
1698        }
1699    }
1700}
1701
#[cfg(feature = "tls12")]
impl KxDecode<'_> for ServerKeyExchangeParams {
    /// Decodes the ECDHE or DHE form of the parameters, as selected by `algo`.
    fn decode(r: &mut Reader, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
        match algo {
            KeyExchangeAlgorithm::ECDHE => {
                ServerEcdhParams::read(r).map(ServerKeyExchangeParams::Ecdh)
            }
            KeyExchangeAlgorithm::DHE => ServerDhParams::read(r).map(ServerKeyExchangeParams::Dh),
        }
    }
}
1712
/// A decoded ServerKeyExchange: key exchange parameters followed by a signature.
#[derive(Debug)]
pub struct ServerKeyExchange {
    pub(crate) params: ServerKeyExchangeParams,
    pub(crate) dss: DigitallySignedStruct,
}
1718
1719impl ServerKeyExchange {
1720    pub fn encode(&self, buf: &mut Vec<u8>) {
1721        self.params.encode(buf);
1722        self.dss.encode(buf);
1723    }
1724}
1725
/// A ServerKeyExchange body. It can only be fully parsed once the key
/// exchange algorithm is known.
#[derive(Debug)]
pub enum ServerKeyExchangePayload {
    /// A fully decoded message.
    Known(ServerKeyExchange),
    /// Raw bytes, as produced by `Codec::read`.
    Unknown(Payload<'static>),
}
1731
1732impl From<ServerKeyExchange> for ServerKeyExchangePayload {
1733    fn from(value: ServerKeyExchange) -> Self {
1734        Self::Known(value)
1735    }
1736}
1737
1738impl Codec<'_> for ServerKeyExchangePayload {
1739    fn encode(&self, bytes: &mut Vec<u8>) {
1740        match *self {
1741            Self::Known(ref x) => x.encode(bytes),
1742            Self::Unknown(ref x) => x.encode(bytes),
1743        }
1744    }
1745
1746    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1747        // read as Unknown, fully parse when we know the
1748        // KeyExchangeAlgorithm
1749        Ok(Self::Unknown(Payload::read(r).into_owned()))
1750    }
1751}
1752
1753impl ServerKeyExchangePayload {
1754    #[cfg(feature = "tls12")]
1755    pub(crate) fn unwrap_given_kxa(&self, kxa: KeyExchangeAlgorithm) -> Option<ServerKeyExchange> {
1756        if let Self::Unknown(ref unk) = *self {
1757            let mut rd = Reader::init(unk.bytes());
1758
1759            let result = ServerKeyExchange {
1760                params: ServerKeyExchangeParams::decode(&mut rd, kxa).ok()?,
1761                dss: DigitallySignedStruct::read(&mut rd).ok()?,
1762            };
1763
1764            if !rd.any_left() {
1765                return Some(result);
1766            };
1767        }
1768
1769        None
1770    }
1771}
1772
1773// -- EncryptedExtensions (TLS1.3 only) --
1774
// Lists of server extensions use a u16 length prefix.
impl TlsListElement for ServerExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
1778
1779pub(crate) trait HasServerExtensions {
1780    fn extensions(&self) -> &[ServerExtension];
1781
1782    /// Returns true if there is more than one extension of a given
1783    /// type.
1784    fn has_duplicate_extension(&self) -> bool {
1785        has_duplicates::<_, _, u16>(
1786            self.extensions()
1787                .iter()
1788                .map(|ext| ext.ext_type()),
1789        )
1790    }
1791
1792    fn find_extension(&self, ext: ExtensionType) -> Option<&ServerExtension> {
1793        self.extensions()
1794            .iter()
1795            .find(|x| x.ext_type() == ext)
1796    }
1797
1798    fn alpn_protocol(&self) -> Option<&[u8]> {
1799        let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?;
1800        match *ext {
1801            ServerExtension::Protocols(ref protos) => protos.as_single_slice(),
1802            _ => None,
1803        }
1804    }
1805
1806    fn quic_params_extension(&self) -> Option<Vec<u8>> {
1807        let ext = self
1808            .find_extension(ExtensionType::TransportParameters)
1809            .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?;
1810        match *ext {
1811            ServerExtension::TransportParameters(ref bytes)
1812            | ServerExtension::TransportParametersDraft(ref bytes) => Some(bytes.to_vec()),
1813            _ => None,
1814        }
1815    }
1816
1817    fn early_data_extension_offered(&self) -> bool {
1818        self.find_extension(ExtensionType::EarlyData)
1819            .is_some()
1820    }
1821}
1822
1823impl HasServerExtensions for Vec<ServerExtension> {
1824    fn extensions(&self) -> &[ServerExtension] {
1825        self
1826    }
1827}
1828
// Client certificate type lists use a u8 length prefix.
impl TlsListElement for ClientCertificateType {
    const SIZE_LEN: ListLength = ListLength::U8;
}
1832
wrapped_payload!(
    /// A `DistinguishedName` is a `Vec<u8>` wrapped in internal types.
    ///
    /// It contains the DER or BER encoded [`Subject` field from RFC 5280](https://datatracker.ietf.org/doc/html/rfc5280#section-4.1.2.6)
    /// for a single certificate. The Subject field is [encoded as an RFC 5280 `Name`](https://datatracker.ietf.org/doc/html/rfc5280#page-116).
    /// The raw bytes are available through `AsRef<[u8]>`, and
    /// can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
    ///
    /// ```ignore
    /// for name in distinguished_names {
    ///     use x509_parser::prelude::FromDer;
    ///     println!("{}", x509_parser::x509::X509Name::from_der(name.as_ref())?.1);
    /// }
    /// ```
    pub struct DistinguishedName,
    PayloadU16,
);
1849
1850impl DistinguishedName {
1851    /// Create a [`DistinguishedName`] after prepending its outer SEQUENCE encoding.
1852    ///
1853    /// This can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
1854    ///
1855    /// ```ignore
1856    /// use x509_parser::prelude::FromDer;
1857    /// println!("{}", x509_parser::x509::X509Name::from_der(dn.as_ref())?.1);
1858    /// ```
1859    pub fn in_sequence(bytes: &[u8]) -> Self {
1860        Self(PayloadU16::new(wrap_in_sequence(bytes)))
1861    }
1862}
1863
// Distinguished name lists use a u16 length prefix.
impl TlsListElement for DistinguishedName {
    const SIZE_LEN: ListLength = ListLength::U16;
}
1867
/// A CertificateRequest body in the pre-TLS 1.3 layout
/// (TLS 1.3 uses `CertificateRequestPayloadTls13`).
#[derive(Debug)]
pub struct CertificateRequestPayload {
    pub(crate) certtypes: Vec<ClientCertificateType>,
    /// Decoding fails if this list is empty.
    pub(crate) sigschemes: Vec<SignatureScheme>,
    /// Acceptable certificate authority names.
    pub(crate) canames: Vec<DistinguishedName>,
}
1874
1875impl Codec<'_> for CertificateRequestPayload {
1876    fn encode(&self, bytes: &mut Vec<u8>) {
1877        self.certtypes.encode(bytes);
1878        self.sigschemes.encode(bytes);
1879        self.canames.encode(bytes);
1880    }
1881
1882    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1883        let certtypes = Vec::read(r)?;
1884        let sigschemes = Vec::read(r)?;
1885        let canames = Vec::read(r)?;
1886
1887        if sigschemes.is_empty() {
1888            warn!("meaningless CertificateRequest message");
1889            Err(InvalidMessage::NoSignatureSchemes)
1890        } else {
1891            Ok(Self {
1892                certtypes,
1893                sigschemes,
1894                canames,
1895            })
1896        }
1897    }
1898}
1899
/// An extension in a certificate request's extension list.
#[derive(Debug)]
pub(crate) enum CertReqExtension {
    /// Never empty when produced by `Codec::read`.
    SignatureAlgorithms(Vec<SignatureScheme>),
    /// Decoded from the `CertificateAuthorities` extension type.
    AuthorityNames(Vec<DistinguishedName>),
    /// Any other extension type.
    Unknown(UnknownExtension),
}
1906
1907impl CertReqExtension {
1908    pub(crate) fn ext_type(&self) -> ExtensionType {
1909        match *self {
1910            Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms,
1911            Self::AuthorityNames(_) => ExtensionType::CertificateAuthorities,
1912            Self::Unknown(ref r) => r.typ,
1913        }
1914    }
1915}
1916
1917impl Codec<'_> for CertReqExtension {
1918    fn encode(&self, bytes: &mut Vec<u8>) {
1919        self.ext_type().encode(bytes);
1920
1921        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1922        match *self {
1923            Self::SignatureAlgorithms(ref r) => r.encode(nested.buf),
1924            Self::AuthorityNames(ref r) => r.encode(nested.buf),
1925            Self::Unknown(ref r) => r.encode(nested.buf),
1926        }
1927    }
1928
1929    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1930        let typ = ExtensionType::read(r)?;
1931        let len = u16::read(r)? as usize;
1932        let mut sub = r.sub(len)?;
1933
1934        let ext = match typ {
1935            ExtensionType::SignatureAlgorithms => {
1936                let schemes = Vec::read(&mut sub)?;
1937                if schemes.is_empty() {
1938                    return Err(InvalidMessage::NoSignatureSchemes);
1939                }
1940                Self::SignatureAlgorithms(schemes)
1941            }
1942            ExtensionType::CertificateAuthorities => {
1943                let cas = Vec::read(&mut sub)?;
1944                Self::AuthorityNames(cas)
1945            }
1946            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1947        };
1948
1949        sub.expect_empty("CertReqExtension")
1950            .map(|_| ext)
1951    }
1952}
1953
/// Lists of [`CertReqExtension`]s use a `ListLength::U16` length prefix.
impl TlsListElement for CertReqExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
1957
/// Body of a TLS 1.3 `CertificateRequest` handshake message.
#[derive(Debug)]
pub struct CertificateRequestPayloadTls13 {
    /// Request context bytes, encoded as a `PayloadU8`.
    pub(crate) context: PayloadU8,
    /// Extensions in wire order; query them with `find_extension` and friends.
    pub(crate) extensions: Vec<CertReqExtension>,
}
1963
1964impl Codec<'_> for CertificateRequestPayloadTls13 {
1965    fn encode(&self, bytes: &mut Vec<u8>) {
1966        self.context.encode(bytes);
1967        self.extensions.encode(bytes);
1968    }
1969
1970    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1971        let context = PayloadU8::read(r)?;
1972        let extensions = Vec::read(r)?;
1973
1974        Ok(Self {
1975            context,
1976            extensions,
1977        })
1978    }
1979}
1980
1981impl CertificateRequestPayloadTls13 {
1982    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&CertReqExtension> {
1983        self.extensions
1984            .iter()
1985            .find(|x| x.ext_type() == ext)
1986    }
1987
1988    pub(crate) fn sigalgs_extension(&self) -> Option<&[SignatureScheme]> {
1989        let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?;
1990        match *ext {
1991            CertReqExtension::SignatureAlgorithms(ref sa) => Some(sa),
1992            _ => None,
1993        }
1994    }
1995
1996    pub(crate) fn authorities_extension(&self) -> Option<&[DistinguishedName]> {
1997        let ext = self.find_extension(ExtensionType::CertificateAuthorities)?;
1998        match *ext {
1999            CertReqExtension::AuthorityNames(ref an) => Some(an),
2000            _ => None,
2001        }
2002    }
2003}
2004
// -- NewSessionTicket --

/// Body of a `NewSessionTicket` message for protocol versions other than
/// TLS 1.3 (TLS 1.3 uses [`NewSessionTicketPayloadTls13`]).
#[derive(Debug)]
pub struct NewSessionTicketPayload {
    /// Ticket lifetime hint, as carried on the wire.
    pub(crate) lifetime_hint: u32,
    /// Opaque ticket bytes.
    pub(crate) ticket: PayloadU16,
}
2011
2012impl NewSessionTicketPayload {
2013    #[cfg(feature = "tls12")]
2014    pub(crate) fn new(lifetime_hint: u32, ticket: Vec<u8>) -> Self {
2015        Self {
2016            lifetime_hint,
2017            ticket: PayloadU16::new(ticket),
2018        }
2019    }
2020}
2021
2022impl Codec<'_> for NewSessionTicketPayload {
2023    fn encode(&self, bytes: &mut Vec<u8>) {
2024        self.lifetime_hint.encode(bytes);
2025        self.ticket.encode(bytes);
2026    }
2027
2028    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
2029        let lifetime = u32::read(r)?;
2030        let ticket = PayloadU16::read(r)?;
2031
2032        Ok(Self {
2033            lifetime_hint: lifetime,
2034            ticket,
2035        })
2036    }
2037}
2038
// -- NewSessionTicket (TLS 1.3) --

/// An extension carried in a TLS 1.3 `NewSessionTicket` message.
#[derive(Debug)]
pub(crate) enum NewSessionTicketExtension {
    /// `EarlyData`: the value surfaced by `max_early_data_size`.
    EarlyData(u32),
    /// Any other extension type, kept as an [`UnknownExtension`].
    Unknown(UnknownExtension),
}
2045
2046impl NewSessionTicketExtension {
2047    pub(crate) fn ext_type(&self) -> ExtensionType {
2048        match *self {
2049            Self::EarlyData(_) => ExtensionType::EarlyData,
2050            Self::Unknown(ref r) => r.typ,
2051        }
2052    }
2053}
2054
2055impl Codec<'_> for NewSessionTicketExtension {
2056    fn encode(&self, bytes: &mut Vec<u8>) {
2057        self.ext_type().encode(bytes);
2058
2059        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2060        match *self {
2061            Self::EarlyData(r) => r.encode(nested.buf),
2062            Self::Unknown(ref r) => r.encode(nested.buf),
2063        }
2064    }
2065
2066    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
2067        let typ = ExtensionType::read(r)?;
2068        let len = u16::read(r)? as usize;
2069        let mut sub = r.sub(len)?;
2070
2071        let ext = match typ {
2072            ExtensionType::EarlyData => Self::EarlyData(u32::read(&mut sub)?),
2073            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
2074        };
2075
2076        sub.expect_empty("NewSessionTicketExtension")
2077            .map(|_| ext)
2078    }
2079}
2080
/// Lists of [`NewSessionTicketExtension`]s use a `ListLength::U16` length prefix.
impl TlsListElement for NewSessionTicketExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
2084
/// Body of a TLS 1.3 `NewSessionTicket` handshake message.
#[derive(Debug)]
pub struct NewSessionTicketPayloadTls13 {
    /// Ticket lifetime, as carried on the wire.
    pub(crate) lifetime: u32,
    /// Ticket age-add value, as carried on the wire.
    pub(crate) age_add: u32,
    /// Per-ticket nonce bytes.
    pub(crate) nonce: PayloadU8,
    /// Opaque ticket bytes.
    pub(crate) ticket: PayloadU16,
    /// Ticket extensions in wire order (empty when built with `new`).
    pub(crate) exts: Vec<NewSessionTicketExtension>,
}
2093
2094impl NewSessionTicketPayloadTls13 {
2095    pub(crate) fn new(lifetime: u32, age_add: u32, nonce: Vec<u8>, ticket: Vec<u8>) -> Self {
2096        Self {
2097            lifetime,
2098            age_add,
2099            nonce: PayloadU8::new(nonce),
2100            ticket: PayloadU16::new(ticket),
2101            exts: vec![],
2102        }
2103    }
2104
2105    pub(crate) fn has_duplicate_extension(&self) -> bool {
2106        has_duplicates::<_, _, u16>(
2107            self.exts
2108                .iter()
2109                .map(|ext| ext.ext_type()),
2110        )
2111    }
2112
2113    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&NewSessionTicketExtension> {
2114        self.exts
2115            .iter()
2116            .find(|x| x.ext_type() == ext)
2117    }
2118
2119    pub(crate) fn max_early_data_size(&self) -> Option<u32> {
2120        let ext = self.find_extension(ExtensionType::EarlyData)?;
2121        match *ext {
2122            NewSessionTicketExtension::EarlyData(ref sz) => Some(*sz),
2123            _ => None,
2124        }
2125    }
2126}
2127
2128impl Codec<'_> for NewSessionTicketPayloadTls13 {
2129    fn encode(&self, bytes: &mut Vec<u8>) {
2130        self.lifetime.encode(bytes);
2131        self.age_add.encode(bytes);
2132        self.nonce.encode(bytes);
2133        self.ticket.encode(bytes);
2134        self.exts.encode(bytes);
2135    }
2136
2137    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
2138        let lifetime = u32::read(r)?;
2139        let age_add = u32::read(r)?;
2140        let nonce = PayloadU8::read(r)?;
2141        let ticket = PayloadU16::read(r)?;
2142        let exts = Vec::read(r)?;
2143
2144        Ok(Self {
2145            lifetime,
2146            age_add,
2147            nonce,
2148            ticket,
2149            exts,
2150        })
2151    }
2152}
2153
// -- RFC6066 certificate status types

/// Body of a `CertificateStatus` message.
///
/// Only supports OCSP: `read` rejects every other status type.
#[derive(Debug)]
pub struct CertificateStatus {
    /// Raw OCSP response bytes (a `PayloadU24`).
    pub(crate) ocsp_response: PayloadU24,
}
2161
2162impl Codec<'_> for CertificateStatus {
2163    fn encode(&self, bytes: &mut Vec<u8>) {
2164        CertificateStatusType::OCSP.encode(bytes);
2165        self.ocsp_response.encode(bytes);
2166    }
2167
2168    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
2169        let typ = CertificateStatusType::read(r)?;
2170
2171        match typ {
2172            CertificateStatusType::OCSP => Ok(Self {
2173                ocsp_response: PayloadU24::read(r)?,
2174            }),
2175            _ => Err(InvalidMessage::InvalidCertificateStatusType),
2176        }
2177    }
2178}
2179
2180impl CertificateStatus {
2181    pub(crate) fn new(ocsp: Vec<u8>) -> Self {
2182        Self {
2183            ocsp_response: PayloadU24::new(ocsp),
2184        }
2185    }
2186
2187    #[cfg(feature = "tls12")]
2188    pub(crate) fn into_inner(self) -> Vec<u8> {
2189        self.ocsp_response.0
2190    }
2191}
2192
/// The decoded body of a handshake message, one variant per message kind.
///
/// Where the body layout depends on protocol version, separate `...Tls13`
/// variants exist; `HandshakeMessagePayload::read_version` picks between them.
#[derive(Debug)]
pub enum HandshakePayload<'a> {
    /// `HelloRequest` with an empty body.
    HelloRequest,
    ClientHello(ClientHelloPayload),
    ServerHello(ServerHelloPayload),
    /// A `ServerHello` whose random equals `HELLO_RETRY_REQUEST_RANDOM`.
    HelloRetryRequest(HelloRetryRequest),
    /// `Certificate` for versions other than TLS 1.3.
    Certificate(CertificateChain<'a>),
    /// `Certificate` for TLS 1.3.
    CertificateTls13(CertificatePayloadTls13),
    ServerKeyExchange(ServerKeyExchangePayload),
    /// `CertificateRequest` for versions other than TLS 1.3.
    CertificateRequest(CertificateRequestPayload),
    /// `CertificateRequest` for TLS 1.3.
    CertificateRequestTls13(CertificateRequestPayloadTls13),
    CertificateVerify(DigitallySignedStruct),
    /// Empty-bodied `ServerHelloDone`.
    ServerHelloDone,
    /// Empty-bodied `EndOfEarlyData`.
    EndOfEarlyData,
    /// `ClientKeyExchange`, kept as the raw body.
    ClientKeyExchange(Payload<'a>),
    /// `NewSessionTicket` for versions other than TLS 1.3.
    NewSessionTicket(NewSessionTicketPayload),
    /// `NewSessionTicket` for TLS 1.3.
    NewSessionTicketTls13(NewSessionTicketPayloadTls13),
    EncryptedExtensions(Vec<ServerExtension>),
    KeyUpdate(KeyUpdateRequest),
    /// `Finished`, kept as the raw body.
    Finished(Payload<'a>),
    CertificateStatus(CertificateStatus),
    /// Synthetic message holding a hash; built locally, rejected if seen on the wire.
    MessageHash(Payload<'a>),
    /// Raw body of an unrecognised message type, or of a non-empty `HelloRequest`.
    Unknown(Payload<'a>),
}
2217
impl HandshakePayload<'_> {
    /// Appends the message body (without the handshake header) to `bytes`.
    ///
    /// `HelloRequest`, `ServerHelloDone` and `EndOfEarlyData` have empty bodies,
    /// so nothing is written for them.
    fn encode(&self, bytes: &mut Vec<u8>) {
        use self::HandshakePayload::*;
        match *self {
            HelloRequest | ServerHelloDone | EndOfEarlyData => {}
            ClientHello(ref x) => x.encode(bytes),
            ServerHello(ref x) => x.encode(bytes),
            HelloRetryRequest(ref x) => x.encode(bytes),
            Certificate(ref x) => x.encode(bytes),
            CertificateTls13(ref x) => x.encode(bytes),
            ServerKeyExchange(ref x) => x.encode(bytes),
            ClientKeyExchange(ref x) => x.encode(bytes),
            CertificateRequest(ref x) => x.encode(bytes),
            CertificateRequestTls13(ref x) => x.encode(bytes),
            CertificateVerify(ref x) => x.encode(bytes),
            NewSessionTicket(ref x) => x.encode(bytes),
            NewSessionTicketTls13(ref x) => x.encode(bytes),
            EncryptedExtensions(ref x) => x.encode(bytes),
            KeyUpdate(ref x) => x.encode(bytes),
            Finished(ref x) => x.encode(bytes),
            CertificateStatus(ref x) => x.encode(bytes),
            MessageHash(ref x) => x.encode(bytes),
            Unknown(ref x) => x.encode(bytes),
        }
    }

    /// Converts this payload into a `'static` one that owns all of its data.
    ///
    /// Only the variants holding lifetime-bearing data (`Certificate`,
    /// `ClientKeyExchange`, `Finished`, `MessageHash`, `Unknown`) need an
    /// `into_owned` call; every other variant is moved across unchanged.
    fn into_owned(self) -> HandshakePayload<'static> {
        use HandshakePayload::*;

        match self {
            HelloRequest => HelloRequest,
            ClientHello(x) => ClientHello(x),
            ServerHello(x) => ServerHello(x),
            HelloRetryRequest(x) => HelloRetryRequest(x),
            Certificate(x) => Certificate(x.into_owned()),
            CertificateTls13(x) => CertificateTls13(x),
            ServerKeyExchange(x) => ServerKeyExchange(x),
            CertificateRequest(x) => CertificateRequest(x),
            CertificateRequestTls13(x) => CertificateRequestTls13(x),
            CertificateVerify(x) => CertificateVerify(x),
            ServerHelloDone => ServerHelloDone,
            EndOfEarlyData => EndOfEarlyData,
            ClientKeyExchange(x) => ClientKeyExchange(x.into_owned()),
            NewSessionTicket(x) => NewSessionTicket(x),
            NewSessionTicketTls13(x) => NewSessionTicketTls13(x),
            EncryptedExtensions(x) => EncryptedExtensions(x),
            KeyUpdate(x) => KeyUpdate(x),
            Finished(x) => Finished(x.into_owned()),
            CertificateStatus(x) => CertificateStatus(x),
            MessageHash(x) => MessageHash(x.into_owned()),
            Unknown(x) => Unknown(x.into_owned()),
        }
    }
}
2272
/// A complete handshake message: its type plus the decoded body.
///
/// `typ` may be `HelloRetryRequest`, which is written on the wire with the
/// `ServerHello` type (see the `Codec` impl).
#[derive(Debug)]
pub struct HandshakeMessagePayload<'a> {
    pub typ: HandshakeType,
    pub payload: HandshakePayload<'a>,
}
2278
2279impl<'a> Codec<'a> for HandshakeMessagePayload<'a> {
2280    fn encode(&self, bytes: &mut Vec<u8>) {
2281        // output type, length, and encoded payload
2282        match self.typ {
2283            HandshakeType::HelloRetryRequest => HandshakeType::ServerHello,
2284            _ => self.typ,
2285        }
2286        .encode(bytes);
2287
2288        let nested = LengthPrefixedBuffer::new(ListLength::U24 { max: usize::MAX }, bytes);
2289        self.payload.encode(nested.buf);
2290    }
2291
2292    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2293        Self::read_version(r, ProtocolVersion::TLSv1_2)
2294    }
2295}
2296
impl<'a> HandshakeMessagePayload<'a> {
    /// Decodes one handshake message (type, 24-bit length, body) from `r`.
    ///
    /// `vers` selects the TLS 1.3 or the earlier body layout for
    /// `Certificate`, `CertificateRequest` and `NewSessionTicket`. A
    /// `ServerHello` carrying `HELLO_RETRY_REQUEST_RANDOM` is returned with
    /// `typ == HelloRetryRequest`.
    ///
    /// Errors:
    /// - the body must be consumed exactly (`expect_empty`);
    /// - `MessageHash` and `HelloRetryRequest` type bytes on the wire are
    ///   rejected with `InvalidMessage::UnexpectedMessage`.
    pub(crate) fn read_version(
        r: &mut Reader<'a>,
        vers: ProtocolVersion,
    ) -> Result<Self, InvalidMessage> {
        let mut typ = HandshakeType::read(r)?;
        let len = codec::u24::read(r)?.0 as usize;
        let mut sub = r.sub(len)?;

        let payload = match typ {
            // A HelloRequest with a non-empty body is not matched here and
            // falls through to the `Unknown` arm at the bottom.
            HandshakeType::HelloRequest if sub.left() == 0 => HandshakePayload::HelloRequest,
            HandshakeType::ClientHello => {
                HandshakePayload::ClientHello(ClientHelloPayload::read(&mut sub)?)
            }
            HandshakeType::ServerHello => {
                let version = ProtocolVersion::read(&mut sub)?;
                let random = Random::read(&mut sub)?;

                // HelloRetryRequest shares the ServerHello wire type and is
                // recognised only by this special random value.
                if random == HELLO_RETRY_REQUEST_RANDOM {
                    let mut hrr = HelloRetryRequest::read(&mut sub)?;
                    hrr.legacy_version = version;
                    typ = HandshakeType::HelloRetryRequest;
                    HandshakePayload::HelloRetryRequest(hrr)
                } else {
                    let mut shp = ServerHelloPayload::read(&mut sub)?;
                    shp.legacy_version = version;
                    shp.random = random;
                    HandshakePayload::ServerHello(shp)
                }
            }
            HandshakeType::Certificate if vers == ProtocolVersion::TLSv1_3 => {
                let p = CertificatePayloadTls13::read(&mut sub)?;
                HandshakePayload::CertificateTls13(p)
            }
            HandshakeType::Certificate => {
                HandshakePayload::Certificate(CertificateChain::read(&mut sub)?)
            }
            HandshakeType::ServerKeyExchange => {
                let p = ServerKeyExchangePayload::read(&mut sub)?;
                HandshakePayload::ServerKeyExchange(p)
            }
            HandshakeType::ServerHelloDone => {
                sub.expect_empty("ServerHelloDone")?;
                HandshakePayload::ServerHelloDone
            }
            HandshakeType::ClientKeyExchange => {
                HandshakePayload::ClientKeyExchange(Payload::read(&mut sub))
            }
            HandshakeType::CertificateRequest if vers == ProtocolVersion::TLSv1_3 => {
                let p = CertificateRequestPayloadTls13::read(&mut sub)?;
                HandshakePayload::CertificateRequestTls13(p)
            }
            HandshakeType::CertificateRequest => {
                let p = CertificateRequestPayload::read(&mut sub)?;
                HandshakePayload::CertificateRequest(p)
            }
            HandshakeType::CertificateVerify => {
                HandshakePayload::CertificateVerify(DigitallySignedStruct::read(&mut sub)?)
            }
            HandshakeType::NewSessionTicket if vers == ProtocolVersion::TLSv1_3 => {
                let p = NewSessionTicketPayloadTls13::read(&mut sub)?;
                HandshakePayload::NewSessionTicketTls13(p)
            }
            HandshakeType::NewSessionTicket => {
                let p = NewSessionTicketPayload::read(&mut sub)?;
                HandshakePayload::NewSessionTicket(p)
            }
            HandshakeType::EncryptedExtensions => {
                HandshakePayload::EncryptedExtensions(Vec::read(&mut sub)?)
            }
            HandshakeType::KeyUpdate => {
                HandshakePayload::KeyUpdate(KeyUpdateRequest::read(&mut sub)?)
            }
            HandshakeType::EndOfEarlyData => {
                sub.expect_empty("EndOfEarlyData")?;
                HandshakePayload::EndOfEarlyData
            }
            HandshakeType::Finished => HandshakePayload::Finished(Payload::read(&mut sub)),
            HandshakeType::CertificateStatus => {
                HandshakePayload::CertificateStatus(CertificateStatus::read(&mut sub)?)
            }
            HandshakeType::MessageHash => {
                // does not appear on the wire
                return Err(InvalidMessage::UnexpectedMessage("MessageHash"));
            }
            HandshakeType::HelloRetryRequest => {
                // not legal on wire
                return Err(InvalidMessage::UnexpectedMessage("HelloRetryRequest"));
            }
            _ => HandshakePayload::Unknown(Payload::read(&mut sub)),
        };

        sub.expect_empty("HandshakeMessagePayload")
            .map(|_| Self { typ, payload })
    }

    /// Builds a `KeyUpdate` message carrying `KeyUpdateRequest::UpdateNotRequested`.
    pub(crate) fn build_key_update_notify() -> Self {
        Self {
            typ: HandshakeType::KeyUpdate,
            payload: HandshakePayload::KeyUpdate(KeyUpdateRequest::UpdateNotRequested),
        }
    }

    /// Returns this message's encoding with the PSK binders cut off the end.
    ///
    /// Bytes are only removed when the payload is a `ClientHello` whose last
    /// extension is `PresharedKey`. In that case the length of the encoded
    /// `binders` list is trimmed from the tail, which relies on those binders
    /// being the final bytes of the encoding. Otherwise the full encoding is
    /// returned.
    pub(crate) fn encoding_for_binder_signing(&self) -> Vec<u8> {
        let mut ret = self.get_encoding();

        let binder_len = match self.payload {
            HandshakePayload::ClientHello(ref ch) => match ch.extensions.last() {
                Some(ClientExtension::PresharedKey(ref offer)) => {
                    let mut binders_encoding = Vec::new();
                    offer
                        .binders
                        .encode(&mut binders_encoding);
                    binders_encoding.len()
                }
                _ => 0,
            },
            _ => 0,
        };

        let ret_len = ret.len() - binder_len;
        ret.truncate(ret_len);
        ret
    }

    /// Builds a synthetic `MessageHash` message whose body is a copy of `hash`.
    pub(crate) fn build_handshake_hash(hash: &[u8]) -> Self {
        Self {
            typ: HandshakeType::MessageHash,
            payload: HandshakePayload::MessageHash(Payload::new(hash.to_vec())),
        }
    }

    /// Converts this message into a `'static` one that owns all of its data.
    pub(crate) fn into_owned(self) -> HandshakeMessagePayload<'static> {
        let Self { typ, payload } = self;
        HandshakeMessagePayload {
            typ,
            payload: payload.into_owned(),
        }
    }
}
2437
/// An HPKE symmetric cipher suite: a KDF identifier paired with an AEAD identifier.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct HpkeSymmetricCipherSuite {
    /// HPKE key derivation function identifier.
    pub kdf_id: HpkeKdf,
    /// HPKE AEAD identifier.
    pub aead_id: HpkeAead,
}
2443
2444impl Codec<'_> for HpkeSymmetricCipherSuite {
2445    fn encode(&self, bytes: &mut Vec<u8>) {
2446        self.kdf_id.encode(bytes);
2447        self.aead_id.encode(bytes);
2448    }
2449
2450    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
2451        Ok(Self {
2452            kdf_id: HpkeKdf::read(r)?,
2453            aead_id: HpkeAead::read(r)?,
2454        })
2455    }
2456}
2457
/// Lists of [`HpkeSymmetricCipherSuite`]s use a `ListLength::U16` length prefix.
impl TlsListElement for HpkeSymmetricCipherSuite {
    const SIZE_LEN: ListLength = ListLength::U16;
}
2461
/// HPKE key configuration, as carried in [`EchConfigContents::key_config`].
#[derive(Clone, Debug, PartialEq)]
pub struct HpkeKeyConfig {
    /// Configuration identifier.
    pub config_id: u8,
    /// KEM identifier for `public_key`.
    pub kem_id: HpkeKem,
    /// Encoded public key bytes.
    pub public_key: PayloadU16,
    /// Symmetric cipher suites offered with this key.
    pub symmetric_cipher_suites: Vec<HpkeSymmetricCipherSuite>,
}
2469
2470impl Codec<'_> for HpkeKeyConfig {
2471    fn encode(&self, bytes: &mut Vec<u8>) {
2472        self.config_id.encode(bytes);
2473        self.kem_id.encode(bytes);
2474        self.public_key.encode(bytes);
2475        self.symmetric_cipher_suites
2476            .encode(bytes);
2477    }
2478
2479    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
2480        Ok(Self {
2481            config_id: u8::read(r)?,
2482            kem_id: HpkeKem::read(r)?,
2483            public_key: PayloadU16::read(r)?,
2484            symmetric_cipher_suites: Vec::<HpkeSymmetricCipherSuite>::read(r)?,
2485        })
2486    }
2487}
2488
/// The contents of an [`EchConfig`], following its version and length.
#[derive(Clone, Debug, PartialEq)]
pub struct EchConfigContents {
    /// HPKE key configuration.
    pub key_config: HpkeKeyConfig,
    /// `maximum_name_length` value as carried on the wire.
    pub maximum_name_length: u8,
    /// Public name; `read` rejects bytes that are not a valid `DnsName`.
    pub public_name: DnsName<'static>,
    /// Extensions block, kept as undecoded bytes.
    pub extensions: PayloadU16,
}
2496
2497impl Codec<'_> for EchConfigContents {
2498    fn encode(&self, bytes: &mut Vec<u8>) {
2499        self.key_config.encode(bytes);
2500        self.maximum_name_length.encode(bytes);
2501        let dns_name = &self.public_name.borrow();
2502        PayloadU8::encode_slice(dns_name.as_ref().as_ref(), bytes);
2503        self.extensions.encode(bytes);
2504    }
2505
2506    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
2507        Ok(Self {
2508            key_config: HpkeKeyConfig::read(r)?,
2509            maximum_name_length: u8::read(r)?,
2510            public_name: {
2511                DnsName::try_from(PayloadU8::read(r)?.0.as_slice())
2512                    .map_err(|_| InvalidMessage::InvalidServerName)?
2513                    .to_owned()
2514            },
2515            extensions: PayloadU16::read(r)?,
2516        })
2517    }
2518}
2519
/// A single ECH configuration: a version tag followed by its contents,
/// which are framed with a `u16` length on the wire.
#[derive(Clone, Debug, PartialEq)]
pub struct EchConfig {
    pub version: EchVersion,
    pub contents: EchConfigContents,
}
2525
2526impl Codec<'_> for EchConfig {
2527    fn encode(&self, bytes: &mut Vec<u8>) {
2528        self.version.encode(bytes);
2529        let mut contents = Vec::with_capacity(128);
2530        self.contents.encode(&mut contents);
2531        let length: &mut [u8; 2] = &mut [0, 0];
2532        codec::put_u16(contents.len() as u16, length);
2533        bytes.extend_from_slice(length);
2534        bytes.extend(contents);
2535    }
2536
2537    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
2538        let version = EchVersion::read(r)?;
2539        let length = u16::read(r)?;
2540        Ok(Self {
2541            version,
2542            contents: EchConfigContents::read(&mut r.sub(length as usize)?)?,
2543        })
2544    }
2545}
2546
/// Lists of [`EchConfig`]s use a `ListLength::U16` length prefix.
impl TlsListElement for EchConfig {
    const SIZE_LEN: ListLength = ListLength::U16;
}
2550
/// Reports whether `iter` yields two items that compare equal after conversion to `T`.
///
/// Converted keys are tracked in a `BTreeSet`; the scan stops at the first repeat.
fn has_duplicates<I: IntoIterator<Item = E>, E: Into<T>, T: Eq + Ord>(iter: I) -> bool {
    let mut seen = BTreeSet::<T>::new();
    iter.into_iter()
        .map(Into::<T>::into)
        .any(|key| !seen.insert(key))
}