1use alloc::vec::Vec;
4use nom::bytes::streaming::take;
5use nom::combinator::{complete, cond, map, map_parser, opt, verify};
6use nom::error::{make_error, ErrorKind};
7use nom::multi::{length_data, many1};
8use nom::number::streaming::{be_u16, be_u24, be_u64, be_u8};
9use nom::{Err, IResult};
10use nom_derive::Parse;
11
12use crate::tls_handshake::*;
13use crate::tls_message::*;
14use crate::tls_record::{TlsRecordType, MAX_RECORD_LEN};
15use crate::TlsMessageAlert;
16
17#[derive(Debug, PartialEq)]
19pub struct DTLSRecordHeader {
20    pub content_type: TlsRecordType,
21    pub version: TlsVersion,
22    pub epoch: u16,
24    pub sequence_number: u64, pub length: u16,
27}
28
29#[derive(Debug, PartialEq)]
35pub struct DTLSPlaintext<'a> {
36    pub header: DTLSRecordHeader,
37    pub messages: Vec<DTLSMessage<'a>>,
38}
39
40#[derive(Debug, PartialEq)]
41pub struct DTLSRawRecord<'a> {
42    pub header: DTLSRecordHeader,
43    pub fragment: &'a [u8],
44}
45
46#[derive(Debug, PartialEq)]
47pub struct DTLSClientHello<'a> {
48    pub version: TlsVersion,
49    pub random: &'a [u8],
50    pub session_id: Option<&'a [u8]>,
51    pub cookie: &'a [u8],
52    pub ciphers: Vec<TlsCipherSuiteID>,
54    pub comp: Vec<TlsCompressionID>,
56    pub ext: Option<&'a [u8]>,
57}
58
59impl<'a> ClientHello<'a> for DTLSClientHello<'a> {
60    fn version(&self) -> TlsVersion {
61        self.version
62    }
63
64    fn random(&self) -> &'a [u8] {
65        self.random
66    }
67
68    fn session_id(&self) -> Option<&'a [u8]> {
69        self.session_id
70    }
71
72    fn ciphers(&self) -> &Vec<TlsCipherSuiteID> {
73        &self.ciphers
74    }
75
76    fn comp(&self) -> &Vec<TlsCompressionID> {
77        &self.comp
78    }
79
80    fn ext(&self) -> Option<&'a [u8]> {
81        self.ext
82    }
83}
84
85#[derive(Debug, PartialEq)]
86pub struct DTLSHelloVerifyRequest<'a> {
87    pub server_version: TlsVersion,
88    pub cookie: &'a [u8],
89}
90
91#[derive(Debug, PartialEq)]
93pub struct DTLSMessageHandshake<'a> {
94    pub msg_type: TlsHandshakeType,
95    pub length: u32,
96    pub message_seq: u16,
97    pub fragment_offset: u32,
98    pub fragment_length: u32,
99    pub body: DTLSMessageHandshakeBody<'a>,
100}
101
102#[derive(Debug, PartialEq)]
104pub enum DTLSMessageHandshakeBody<'a> {
105    HelloRequest,
106    ClientHello(DTLSClientHello<'a>),
107    HelloVerifyRequest(DTLSHelloVerifyRequest<'a>),
108    ServerHello(TlsServerHelloContents<'a>),
109    NewSessionTicket(TlsNewSessionTicketContent<'a>),
110    HelloRetryRequest(TlsHelloRetryRequestContents<'a>),
111    Certificate(TlsCertificateContents<'a>),
112    ServerKeyExchange(TlsServerKeyExchangeContents<'a>),
113    CertificateRequest(TlsCertificateRequestContents<'a>),
114    ServerDone(&'a [u8]),
115    CertificateVerify(&'a [u8]),
116    ClientKeyExchange(TlsClientKeyExchangeContents<'a>),
117    Finished(&'a [u8]),
118    CertificateStatus(TlsCertificateStatusContents<'a>),
119    NextProtocol(TlsNextProtocolContent<'a>),
120    Fragment(&'a [u8]),
121}
122
123#[derive(Debug, PartialEq)]
127pub enum DTLSMessage<'a> {
128    Handshake(DTLSMessageHandshake<'a>),
129    ChangeCipherSpec,
130    Alert(TlsMessageAlert),
131    ApplicationData(TlsMessageApplicationData<'a>),
132    Heartbeat(TlsMessageHeartbeat<'a>),
133}
134
135impl<'a> DTLSMessage<'a> {
136    pub fn is_fragment(&self) -> bool {
139        match self {
140            DTLSMessage::Handshake(h) => matches!(h.body, DTLSMessageHandshakeBody::Fragment(_)),
141            _ => false,
142        }
143    }
144}
145
146pub fn parse_dtls_record_header(i: &[u8]) -> IResult<&[u8], DTLSRecordHeader> {
151    let (i, content_type) = TlsRecordType::parse(i)?;
152    let (i, version) = TlsVersion::parse(i)?;
153    let (i, int0) = be_u64(i)?;
154    let epoch = (int0 >> 48) as u16;
155    let sequence_number = int0 & 0xffff_ffff_ffff;
156    let (i, length) = be_u16(i)?;
157    let record = DTLSRecordHeader {
158        content_type,
159        version,
160        epoch,
161        sequence_number,
162        length,
163    };
164    Ok((i, record))
165}
166
167fn parse_dtls_fragment(i: &[u8]) -> IResult<&[u8], DTLSMessageHandshakeBody> {
169    Ok((&[], DTLSMessageHandshakeBody::Fragment(i)))
170}
171
172fn parse_dtls_client_hello(i: &[u8]) -> IResult<&[u8], DTLSMessageHandshakeBody> {
175    let (i, version) = TlsVersion::parse(i)?;
176    let (i, random) = take(32usize)(i)?;
177    let (i, sidlen) = verify(be_u8, |&n| n <= 32)(i)?;
178    let (i, session_id) = cond(sidlen > 0, take(sidlen as usize))(i)?;
179    let (i, cookie) = length_data(be_u8)(i)?;
180    let (i, ciphers_len) = be_u16(i)?;
181    let (i, ciphers) = parse_cipher_suites(i, ciphers_len as usize)?;
182    let (i, comp_len) = be_u8(i)?;
183    let (i, comp) = parse_compressions_algs(i, comp_len as usize)?;
184    let (i, ext) = opt(complete(length_data(be_u16)))(i)?;
185    let content = DTLSClientHello {
186        version,
187        random,
188        session_id,
189        cookie,
190        ciphers,
191        comp,
192        ext,
193    };
194    Ok((i, DTLSMessageHandshakeBody::ClientHello(content)))
195}
196
197fn parse_dtls_hello_verify_request(i: &[u8]) -> IResult<&[u8], DTLSMessageHandshakeBody> {
200    let (i, server_version) = TlsVersion::parse(i)?;
201    let (i, cookie) = length_data(be_u8)(i)?;
202    let content = DTLSHelloVerifyRequest {
203        server_version,
204        cookie,
205    };
206    Ok((i, DTLSMessageHandshakeBody::HelloVerifyRequest(content)))
207}
208
209fn parse_dtls_handshake_msg_server_hello_tlsv12(
210    i: &[u8],
211) -> IResult<&[u8], DTLSMessageHandshakeBody> {
212    map(
213        parse_tls_server_hello_tlsv12::<true>,
214        DTLSMessageHandshakeBody::ServerHello,
215    )(i)
216}
217
218fn parse_dtls_handshake_msg_serverdone(
219    i: &[u8],
220    len: usize,
221) -> IResult<&[u8], DTLSMessageHandshakeBody> {
222    map(take(len), DTLSMessageHandshakeBody::ServerDone)(i)
223}
224
225fn parse_dtls_handshake_msg_clientkeyexchange(
226    i: &[u8],
227    len: usize,
228) -> IResult<&[u8], DTLSMessageHandshakeBody> {
229    map(
230        parse_tls_clientkeyexchange(len),
231        DTLSMessageHandshakeBody::ClientKeyExchange,
232    )(i)
233}
234
235fn parse_dtls_handshake_msg_certificate(i: &[u8]) -> IResult<&[u8], DTLSMessageHandshakeBody> {
236    map(parse_tls_certificate, DTLSMessageHandshakeBody::Certificate)(i)
237}
238
239pub fn parse_dtls_message_handshake(i: &[u8]) -> IResult<&[u8], DTLSMessage> {
241    let (i, msg_type) = map(be_u8, TlsHandshakeType)(i)?;
242    let (i, length) = be_u24(i)?;
243    let (i, message_seq) = be_u16(i)?;
244    let (i, fragment_offset) = be_u24(i)?;
245    let (i, fragment_length) = be_u24(i)?;
246    let (i, raw_msg) = take(fragment_length)(i)?;
248
249    let is_fragment = fragment_offset > 0 || fragment_length < length;
253
254    let (_, body) = match msg_type {
255        _ if is_fragment => parse_dtls_fragment(raw_msg),
256        TlsHandshakeType::ClientHello => parse_dtls_client_hello(raw_msg),
257        TlsHandshakeType::HelloVerifyRequest => parse_dtls_hello_verify_request(raw_msg),
258        TlsHandshakeType::ServerHello => parse_dtls_handshake_msg_server_hello_tlsv12(raw_msg),
259        TlsHandshakeType::ServerDone => {
260            parse_dtls_handshake_msg_serverdone(raw_msg, length as usize)
261        }
262        TlsHandshakeType::ClientKeyExchange => {
263            parse_dtls_handshake_msg_clientkeyexchange(raw_msg, length as usize)
264        }
265        TlsHandshakeType::Certificate => parse_dtls_handshake_msg_certificate(raw_msg),
266        _ => {
267            Err(Err::Error(make_error(i, ErrorKind::Switch)))
269        }
270    }?;
271    let msg = DTLSMessageHandshake {
272        msg_type,
273        length,
274        message_seq,
275        fragment_offset,
276        fragment_length,
277        body,
278    };
279    Ok((i, DTLSMessage::Handshake(msg)))
280}
281
282pub fn parse_dtls_message_changecipherspec(i: &[u8]) -> IResult<&[u8], DTLSMessage> {
285    let (i, _) = verify(be_u8, |&tag| tag == 0x01)(i)?;
286    Ok((i, DTLSMessage::ChangeCipherSpec))
287}
288
289pub fn parse_dtls_message_alert(i: &[u8]) -> IResult<&[u8], DTLSMessage> {
292    let (i, alert) = TlsMessageAlert::parse(i)?;
293    Ok((i, DTLSMessage::Alert(alert)))
294}
295
296pub fn parse_dtls_record_with_header<'i>(
297    i: &'i [u8],
298    hdr: &DTLSRecordHeader,
299) -> IResult<&'i [u8], Vec<DTLSMessage<'i>>> {
300    match hdr.content_type {
301        TlsRecordType::ChangeCipherSpec => many1(complete(parse_dtls_message_changecipherspec))(i),
302        TlsRecordType::Alert => many1(complete(parse_dtls_message_alert))(i),
303        TlsRecordType::Handshake => many1(complete(parse_dtls_message_handshake))(i),
304        _ => {
307            Err(Err::Error(make_error(i, ErrorKind::Switch)))
309        }
310    }
311}
312
313pub fn parse_dtls_plaintext_record(i: &[u8]) -> IResult<&[u8], DTLSPlaintext> {
316    let (i, header) = parse_dtls_record_header(i)?;
317    if header.length > MAX_RECORD_LEN {
319        return Err(Err::Error(make_error(i, ErrorKind::TooLarge)));
320    }
321    let (i, messages) = map_parser(take(header.length as usize), |i| {
322        parse_dtls_record_with_header(i, &header)
323    })(i)?;
324    Ok((i, DTLSPlaintext { header, messages }))
325}
326
327pub fn parse_dtls_plaintext_records(i: &[u8]) -> IResult<&[u8], Vec<DTLSPlaintext>> {
330    many1(complete(parse_dtls_plaintext_record))(i)
331}