1#![no_std]
2
3use core::{borrow::BorrowMut, marker::PhantomData, mem, net::SocketAddr, ops::Range};
4
5use crypto::TrafficSecret;
6use handshake::{
7 handle_handshake_message_client, handle_handshake_message_server, ClientState,
8 CryptoInformation, HandshakeContext, HandshakeInformation, ServerState,
9};
10use log::{debug, info, trace};
11use parsing::{
12 encode_alert, encode_hello_retry, parse_alert, parse_client_hello_first_pass,
13 parse_client_hello_second_pass, ClientHelloResult, EncodeAck, EncodeHandshakeMessage,
14 HandshakeType, HelloRetryCookie, ParseHandshakeMessage,
15};
16
17pub use crypto::{HashFunction, Psk};
18
19use parsing_utility::ParseBuffer;
20use record_parsing::{
21 parse_plaintext_record, parse_record, EncodeCiphertextRecord, EncodePlaintextRecord,
22 RecordContentType,
23};
24
25mod fmt;
26
27#[cfg(feature = "async")]
28mod asynchronous;
29#[cfg(feature = "async")]
30pub use asynchronous::{DtlsStackAsync, Event};
31
32mod sync;
33pub use sync::DtlsStack;
34
35mod buffer_record_queue;
36mod crypto;
37mod handshake;
38mod parsing;
39mod parsing_utility;
40mod record_parsing;
41
42type Epoch = u64;
43type EpochShort = u8;
44
45type TimeStampMs = u64;
46
47type HandshakeSeqNum = u16;
48
49type RecordSeqNum = u64;
50type RecordSeqNumShort = u8;
51
52type Connections<'a> = [Option<DtlsConnection<'a>>];
53type RecordQueue<'a> = buffer_record_queue::BufferMessageQueue<'a>;
54
/// Alert descriptions understood by this stack.
///
/// The explicit discriminants are the alert codes used on the wire. Any other received code
/// decodes to [`AlertDescription::Unknown`].
#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum AlertDescription {
    CloseNotify = 0,
    UnexpectedMessage = 10,
    IllegalParameter = 47,
    DecodeError = 50,
    DecryptionError = 51,
    MissingExtension = 109,
    UnsupportedExtension = 110,
    /// Catch-all for unrecognised codes.
    /// NOTE(review): as the last variant it implicitly gets discriminant 111.
    Unknown,
}

impl From<u8> for AlertDescription {
    /// Decodes a wire alert code; codes not listed in the enum become `Unknown`.
    fn from(code: u8) -> Self {
        match code {
            0 => Self::CloseNotify,
            10 => Self::UnexpectedMessage,
            47 => Self::IllegalParameter,
            50 => Self::DecodeError,
            51 => Self::DecryptionError,
            109 => Self::MissingExtension,
            110 => Self::UnsupportedExtension,
            _ => Self::Unknown,
        }
    }
}

impl AlertDescription {
    /// Level at which this alert is sent.
    ///
    /// Every description, including `CloseNotify`, currently maps to [`AlertLevel::Fatal`].
    /// The match is kept exhaustive so that adding a variant forces a decision here.
    /// NOTE(review): RFC 8446 only mandates `fatal` for error alerts; `close_notify` is a
    /// closure alert — confirm the intended level.
    pub fn alert_level(&self) -> AlertLevel {
        match self {
            // Closure alert.
            AlertDescription::CloseNotify => AlertLevel::Fatal,
            // Error alerts (plus the unrecognised catch-all).
            AlertDescription::UnexpectedMessage
            | AlertDescription::IllegalParameter
            | AlertDescription::DecodeError
            | AlertDescription::DecryptionError
            | AlertDescription::MissingExtension
            | AlertDescription::UnsupportedExtension
            | AlertDescription::Unknown => AlertLevel::Fatal,
        }
    }
}

/// Severity level carried in an alert record; the discriminants are the wire values.
#[derive(Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum AlertLevel {
    Warning = 1,
    Fatal = 2,
}

impl From<u8> for AlertLevel {
    /// Decodes a wire alert level.
    ///
    /// Only `1` is a warning. Every other value, including malformed ones, is treated as fatal.
    fn from(level: u8) -> Self {
        if level == 1 {
            AlertLevel::Warning
        } else {
            AlertLevel::Fatal
        }
    }
}
113
114#[derive(Debug)]
115pub enum DtlsError {
116 MaximumConnectionsReached,
117 UnknownConnection,
118 MaximumRetransmissionsReached,
119 HandshakeAlreadyRunning,
120 OutOfMemory,
121 IllegalInnerState,
123 IoError,
124 RngError,
125 ParseError,
126 CryptoError,
127 NoMatchingEpoch,
128 RejectedSequenceNumber,
129 Alert(AlertDescription),
130 MultipleRecordsPerPacketNotSupported,
131}
132
/// Opaque handle for a connection; it wraps the index of the connection's slot.
#[derive(PartialEq, Eq, Clone, Copy)]
pub struct ConnectionId(usize);

impl core::fmt::Debug for ConnectionId {
    /// Formats as the bare slot index (e.g. `3`), without the type name.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        core::fmt::Debug::fmt(&self.0, f)
    }
}
141
142struct DtlsConnection<'a> {
143 epochs: heapless::Vec<EpochState, 4>,
144 current_epoch: Epoch,
145 pub addr: SocketAddr,
146 handshake_finished: bool,
147 p: PhantomData<&'a ()>,
148}
149
/// What the caller should do after polling the stack.
#[derive(Debug, PartialEq, Eq)]
pub enum DtlsPoll {
    /// Wait for input, but at most this many milliseconds.
    WaitTimeoutMs(u32),
    /// Wait for input without a deadline.
    Wait,
    /// A handshake has finished.
    FinishedHandshake,
}

impl DtlsPoll {
    /// Combines two poll results into the most urgent one.
    ///
    /// `FinishedHandshake` wins over everything. Two timeouts merge into the shorter one. A
    /// timeout wins over a plain `Wait`.
    pub fn merge(self, other: Self) -> Self {
        match (self, other) {
            (Self::FinishedHandshake, _) | (_, Self::FinishedHandshake) => Self::FinishedHandshake,
            (Self::WaitTimeoutMs(a), Self::WaitTimeoutMs(b)) => {
                Self::WaitTimeoutMs(if a <= b { a } else { b })
            }
            (Self::WaitTimeoutMs(ms), Self::Wait) | (Self::Wait, Self::WaitTimeoutMs(ms)) => {
                Self::WaitTimeoutMs(ms)
            }
            (Self::Wait, Self::Wait) => Self::Wait,
        }
    }
}
178
179enum DeferredAction<'a> {
180 None,
181 Send(&'a [u8]),
182 AppData(ConnectionId, Range<usize>),
183 Unhandled,
184}
185
186fn try_pass_packet_to_connection<'a>(
187 staging_buffer: &'a mut [u8],
188 connections: &mut [Option<DtlsConnection>],
189 addr: &SocketAddr,
190 packet_len: usize,
191) -> Result<DeferredAction<'a>, DtlsError> {
192 for i in 0..connections.len() {
193 let connection = match connections[i].as_mut() {
194 Some(c) if &c.addr == addr && c.handshake_finished => c,
195 _ => continue,
196 };
197 let mut packet_buffer = ParseBuffer::init(&mut staging_buffer[..packet_len]);
198 let res = parse_record(&mut packet_buffer, &mut connection.epochs);
199 let action = match res {
200 Ok(RecordContentType::ApplicationData) => DeferredAction::AppData(
201 ConnectionId(i),
202 packet_buffer.offset()..packet_buffer.capacity(),
203 ),
204 Ok(RecordContentType::DtlsHandshake) => {
205 trace!("Received handshake message on existing connection");
206 match ParseHandshakeMessage::retrieve_content_type(&mut packet_buffer) {
207 Ok(
209 HandshakeType::ServerHello
210 | HandshakeType::EncryptedExtension
211 | HandshakeType::Finished,
212 ) if connection.current_epoch < 6 => {
213 debug!("Found retransmitted handshake message. Resending ack.");
214 DeferredAction::Send(stage_ack(
215 staging_buffer,
216 &mut connection.epochs,
217 &3,
218 &2,
219 )?)
220 }
221 _ => close_connection(ConnectionId(i), staging_buffer, connections),
222 }
223 }
224 Ok(RecordContentType::Ack) => DeferredAction::None,
225 Ok(_) => close_connection(ConnectionId(i), staging_buffer, connections),
226 Err(err) => {
227 trace!("Received broken record: {:?}", err);
228 if let DtlsError::Alert(alert) = err {
229 DeferredAction::Send(stage_alert(
230 staging_buffer,
231 &mut connection.epochs,
232 &connection.current_epoch,
233 alert,
234 )?)
235 } else {
236 DeferredAction::None
237 }
238 }
239 };
240 return Ok(action);
241 }
242 Ok(DeferredAction::Unhandled)
243}
244
245fn close_connection<'a>(
246 connection_id: ConnectionId,
247 staging_buffer: &'a mut [u8],
248 connections: &mut [Option<DtlsConnection>],
249) -> DeferredAction<'a> {
250 debug_assert!(connection_id.0 < connections.len());
251 let mut action = DeferredAction::None;
252 if connection_id.0 < connections.len() {
253 if let Some(c) = connections[connection_id.0].as_mut() {
254 if !c.handshake_finished {
255 return action;
256 }
257 if let Ok(buf) = stage_alert(
258 staging_buffer,
259 &mut c.epochs,
260 &c.current_epoch,
261 AlertDescription::CloseNotify,
262 ) {
263 action = DeferredAction::Send(buf);
264 }
265 }
266 connections[connection_id.0] = None;
267 }
268 action
269}
270
271fn try_pass_packet_to_handshake<'a>(
272 staging_buffer: &'a mut [u8],
273 connections: &mut [Option<DtlsConnection>],
274 handshakes: &mut [HandshakeSlot],
275 addr: &SocketAddr,
276 packet_len: usize,
277) -> Result<DeferredAction<'a>, DtlsError> {
278 for handshake in handshakes {
279 if let HandshakeSlotState::Running {
280 state,
281 handshake: ctx,
282 } = &mut handshake.state
283 {
284 let connection = ctx.connection(connections);
285 if &connection.addr != addr {
286 continue;
287 }
288 let Some((content_type, mut packet)) =
289 try_unpack_record(&mut staging_buffer[..packet_len], &mut connection.epochs)?
290 else {
291 return Ok(DeferredAction::None);
292 };
293 let mut new_state = *state;
294 if content_type == RecordContentType::Alert {
295 let (level, desc) = parse_alert(&mut packet)?;
296 info!("Received alert: {:?}, {:?}", level, desc);
297 handshake.close(connections);
298 continue;
299 }
300 else if content_type == RecordContentType::ApplicationData
302 && matches!(state, HandshakeState::Client(ClientState::WaitServerAck))
303 {
304 debug!("[Client] Acked client finished through app data");
305 let start = packet.offset();
306 let end = packet.capacity();
307 let id = ctx.conn_id;
308 handshake.finish_handshake(connection);
309 return Ok(DeferredAction::AppData(ConnectionId(id), start..end));
310 } else {
311 let res = match &mut new_state {
312 HandshakeState::Client(state) => handle_handshake_message_client(
313 state,
314 ctx,
315 &mut handshake.rt_queue,
316 connection,
317 content_type,
318 packet,
319 ),
320 HandshakeState::Server(state) => handle_handshake_message_server(
321 state,
322 ctx,
323 &mut handshake.rt_queue,
324 connection,
325 content_type,
326 packet,
327 ),
328 };
329 if let Err(DtlsError::Alert(alert)) = res {
330 let buf = stage_alert(
331 staging_buffer,
332 &mut connection.epochs,
333 &connection.current_epoch,
334 alert,
335 )?;
336 return Ok(DeferredAction::Send(buf));
337 }
338 res?;
339 *state = new_state;
340 if matches!(
342 state,
343 HandshakeState::Server(ServerState::FinishedHandshake)
344 ) {
345 debug!("[Server] Send ACK for client finish");
346 return Ok(DeferredAction::Send(stage_ack(
347 staging_buffer,
348 &mut connection.epochs,
349 &3,
350 &2,
351 )?));
352 }
353 }
354 return Ok(DeferredAction::None);
355 }
356 }
357 Ok(DeferredAction::Unhandled)
358}
359
360fn create_handshake_connection<'a, 'b>(
361 connections: &'a mut [Option<DtlsConnection<'b>>],
362 addr: &SocketAddr,
363) -> Result<(usize, &'a mut DtlsConnection<'b>), DtlsError> {
364 let slot = find_empty_connection_slot(connections);
365 if let Some(slot) = slot {
366 connections[slot] = Some(DtlsConnection {
367 epochs: heapless::Vec::new(),
368 current_epoch: 0,
369 addr: *addr,
370 handshake_finished: false,
371 p: PhantomData,
372 });
373 let _ = connections[slot]
374 .as_mut()
375 .unwrap()
376 .epochs
377 .push(EpochState::empty());
378 Ok((slot, connections[slot].as_mut().unwrap()))
379 } else {
380 Err(DtlsError::MaximumConnectionsReached)
381 }
382}
383
384fn open_connection(
385 connections: &mut Connections,
386 slot: &mut HandshakeSlot,
387 addr: &SocketAddr,
388) -> bool {
389 let Ok((conn_id, _)) = create_handshake_connection(connections, addr) else {
390 return false;
391 };
392
393 let HandshakeSlotState::Empty = slot.state else {
394 return false;
395 };
396 slot.state = HandshakeSlotState::Running {
397 state: HandshakeState::Client(ClientState::default()),
398 handshake: HandshakeContext {
399 recv_handshake_seq_num: 0,
400 send_handshake_seq_num: 0,
401 conn_id,
402 info: HandshakeInformation {
403 available_psks: slot.psks,
404 selected_psk: None,
405 crypto: CryptoInformation::new(),
406 selected_cipher_suite: None,
407 received_hello_retry_request: false,
408 },
409 },
410 };
411
412 true
413}
414
415fn find_empty_connection_slot(connections: &mut [Option<DtlsConnection>]) -> Option<usize> {
416 for (i, c) in connections.iter().enumerate() {
417 if c.is_none() {
418 return Some(i);
419 }
420 }
421 None
422}
423
424fn try_open_new_handshake<'a>(
425 staging_buffer: &'a mut [u8],
426 require_cookie: bool,
427 cookie_key: &[u8],
428 handshakes: &mut [HandshakeSlot],
429 connections: &mut [Option<DtlsConnection>],
430 addr: &SocketAddr,
431 packet_len: usize,
432) -> Result<Option<&'a [u8]>, DtlsError> {
433 let mut packet_buffer = ParseBuffer::init(&mut staging_buffer[..packet_len]);
434 let mut epoch_states = [EpochState::empty()];
435 let res = parse_plaintext_record(&mut packet_buffer, &mut epoch_states);
436 let Ok(RecordContentType::DtlsHandshake) = res else {
437 return Ok(None);
438 };
439 let mut send_buf = None;
440 for handshake_slot in handshakes {
441 if !matches!(handshake_slot.state, HandshakeSlotState::Empty) {
442 continue;
443 }
444 let Ok((conn_id, conn)) = create_handshake_connection(connections, addr) else {
445 return Ok(None);
446 };
447 conn.epochs[0] = epoch_states.into_iter().next().unwrap();
448
449 handshake_slot.fill(conn_id);
450 let HandshakeSlotState::Running {
451 state: _,
452 handshake: ctx,
453 } = &mut handshake_slot.state
454 else {
455 unreachable!()
456 };
457 let Ok((mut client_hello, HandshakeType::ClientHello, client_hello_seq_num @ (0 | 1))) =
458 ParseHandshakeMessage::new(packet_buffer)
459 else {
460 break;
461 };
462 let client_hello_start = client_hello.payload_buffer().offset();
463 match parse_client_hello_first_pass(
464 client_hello.payload_buffer(),
465 require_cookie,
466 cookie_key,
467 addr,
468 ctx,
469 &mut handshake_slot.rt_queue,
470 ) {
471 Ok(ClientHelloResult::MissingCookie) => {
472 debug!("[Server] Didn't find valid cookie. Sending hello_retry");
473 client_hello.add_to_transcript_hash(&mut ctx.info.crypto);
474 send_buf = Some(stage_hello_retry_message(
475 staging_buffer,
476 cookie_key,
477 addr,
478 &mut ctx.info,
479 )?);
480 }
481 Ok(ClientHelloResult::Ok) => {
482 if require_cookie {
483 debug!("[Server] Found valid cookie opening handshake");
484 }
485 parse_client_hello_second_pass(
486 client_hello.payload_buffer(),
487 &mut ctx.info,
488 client_hello_start,
489 )?;
490 client_hello.add_to_transcript_hash(&mut ctx.info.crypto);
491 conn.epochs[0].send_record_seq_num = client_hello_seq_num as u64;
492 ctx.send_handshake_seq_num = client_hello_seq_num as u8;
493 ctx.recv_handshake_seq_num = client_hello_seq_num as u8 + 1;
494 break;
495 }
496 Err(err) => {
497 debug!("[Server] Error parsing client_hello: {err:?}");
498 }
499 }
500 handshake_slot.close(connections);
501 connections[conn_id] = None;
502 break;
503 }
504 Ok(send_buf)
505}
506
507fn stage_hello_retry_message<'a>(
508 staging_buffer: &'a mut [u8],
509 cookie_key: &[u8],
510 addr: &SocketAddr,
511 info: &mut HandshakeInformation,
512) -> Result<&'a [u8], DtlsError> {
513 let mut buffer = ParseBuffer::init(staging_buffer.borrow_mut());
514 let mut record = EncodePlaintextRecord::new(&mut buffer, RecordContentType::DtlsHandshake, 0)?;
515 let mut handshake =
516 EncodeHandshakeMessage::new(record.payload_buffer(), HandshakeType::ServerHello, 0)?;
517 encode_hello_retry(
518 handshake.payload_buffer(),
519 &[],
520 info.selected_cipher_suite
521 .ok_or(DtlsError::IllegalInnerState)?,
522 HelloRetryCookie::calculate(info.crypto.psk_hash_mut()?, cookie_key, addr),
523 )?;
524 handshake.finish(&mut info.crypto);
525 record.finish();
526 let offset = buffer.offset();
527 Ok(&buffer.release_buffer()[..offset])
528}
529
530fn stage_ack<'a>(
531 staging_buffer: &'a mut [u8],
532 epoch_states: &mut [EpochState],
533 epoch: &u64,
534 ack_epoch: &u64,
535) -> Result<&'a [u8], DtlsError> {
536 let mut buffer = ParseBuffer::init(staging_buffer);
537 let send_epoch_index = *epoch as usize & 3;
538 let ack_epoch_index = *ack_epoch as usize & 3;
539 let max_entries = (buffer.capacity() as u64 - 2) / 16;
540 let mut record =
541 EncodeCiphertextRecord::new(&mut buffer, &epoch_states[send_epoch_index], epoch)?;
542 let mut ack = EncodeAck::new(record.payload_buffer())?;
543 let w = &epoch_states[ack_epoch_index].sliding_window;
544 let r = &epoch_states[ack_epoch_index].receive_record_seq_num;
545 let mut index = 1;
546 for i in 0..64.min(max_entries) {
547 if w & index > 0 {
548 let s = r - i;
549 ack.add_entry(ack_epoch, &s)?;
550 }
551 index <<= 1;
552 }
553 ack.finish();
554 record.finish(&mut epoch_states[send_epoch_index], RecordContentType::Ack)?;
555 let offset = buffer.offset();
556 Ok(&buffer.release_buffer()[..offset])
557}
558
559fn stage_alert<'a>(
560 staging_buffer: &'a mut [u8],
561 epoch_states: &mut [EpochState],
562 epoch: &u64,
563 alert: AlertDescription,
564) -> Result<&'a [u8], DtlsError> {
565 info!("Sending alert: {:?}", alert);
566 let epoch_index = *epoch as usize & 3;
567 let mut buffer = ParseBuffer::init(staging_buffer);
568 if epoch < &2 {
569 let mut record = EncodePlaintextRecord::new(
570 &mut buffer,
571 RecordContentType::Alert,
572 epoch_states[epoch_index].send_record_seq_num,
573 )?;
574 encode_alert(record.payload_buffer(), alert, alert.alert_level())?;
575 record.finish();
576 } else {
577 let mut record =
578 EncodeCiphertextRecord::new(&mut buffer, &epoch_states[epoch_index], epoch)?;
579 encode_alert(record.payload_buffer(), alert, alert.alert_level())?;
580 record.finish(&mut epoch_states[epoch_index], RecordContentType::Alert)?;
581 }
582 let offset = buffer.offset();
583 Ok(&buffer.release_buffer()[..offset])
584}
585
586pub struct HandshakeSlot<'a> {
587 rt_queue: RecordQueue<'a>,
588 psks: &'a [Psk<'a>],
589 state: HandshakeSlotState<'a>,
590}
591
592#[derive(Default)]
593pub enum HandshakeSlotState<'a> {
594 Running {
595 state: HandshakeState,
596 handshake: HandshakeContext<'a>,
597 },
598 #[default]
599 Empty,
600 Finished(ConnectionId),
601}
602
603#[derive(Clone, Copy)]
604pub enum HandshakeState {
605 Client(ClientState),
606 Server(ServerState),
607}
608
609impl<'a> HandshakeSlot<'a> {
610 pub fn new(available_psks: &'a [Psk<'a>], buffer: &'a mut [u8]) -> Self {
611 HandshakeSlot {
612 rt_queue: RecordQueue::new(buffer),
613 psks: available_psks,
614 state: HandshakeSlotState::Empty,
615 }
616 }
617
618 fn fill(&mut self, conn_id: usize) {
619 if let HandshakeSlotState::Empty = self.state {
620 self.state = HandshakeSlotState::Running {
621 state: HandshakeState::Server(ServerState::default()),
622 handshake: HandshakeContext {
623 recv_handshake_seq_num: 0,
624 send_handshake_seq_num: 0,
625 conn_id,
626 info: HandshakeInformation {
627 received_hello_retry_request: false,
628 available_psks: self.psks,
629 selected_psk: None,
630 crypto: CryptoInformation::new(),
631 selected_cipher_suite: None,
632 },
633 },
634 }
635 }
636 }
637
638 pub fn try_take_connection_id(&mut self) -> Option<ConnectionId> {
639 if let HandshakeSlotState::Finished(cid) = self.state {
640 self.state = HandshakeSlotState::Empty;
641 Some(cid)
642 } else {
643 None
644 }
645 }
646
647 fn finish_handshake(&mut self, conn: &mut DtlsConnection) {
648 if let HandshakeSlotState::Running {
649 state: _,
650 handshake: ctx,
651 } = mem::take(&mut self.state)
652 {
653 conn.handshake_finished = true;
654 self.rt_queue.reset();
655 let id = ctx.conn_id;
656 self.state = HandshakeSlotState::Finished(ConnectionId(id));
657 }
658 }
659
660 fn close(&mut self, connections: &mut [Option<DtlsConnection>]) {
661 debug!("Closing handshake prematurely");
662 match mem::take(&mut self.state) {
663 HandshakeSlotState::Running {
664 state: _,
665 handshake: c,
666 } => {
667 connections[c.conn_id] = None;
668 }
669 HandshakeSlotState::Empty | HandshakeSlotState::Finished(_) => {}
670 }
671 self.rt_queue.reset();
672 }
673}
674
675fn try_unpack_record<'a>(
676 packet: &'a mut [u8],
677 viable_epochs: &mut [EpochState],
678) -> Result<Option<(RecordContentType, ParseBuffer<'a>)>, DtlsError> {
679 let mut packet_buffer = ParseBuffer::init(packet);
680 let res = parse_record(&mut packet_buffer, viable_epochs);
681
682 match res {
683 Err(DtlsError::NoMatchingEpoch) => {
684 trace!("Rejected record because no cipher state was present for its epoch");
685 Ok(None)
686 }
687 Err(DtlsError::RejectedSequenceNumber) => {
688 trace!("Rejected record because it was already received");
689 Ok(None)
690 }
691 Err(DtlsError::ParseError | DtlsError::CryptoError) => {
692 trace!("Rejected record because it was broken");
693 Ok(None)
694 }
695 Err(err) => Err(err),
696 Ok(content_type) => Ok(Some((content_type, packet_buffer))),
697 }
698}
699
700struct EpochState {
701 send_record_seq_num: RecordSeqNum,
702 receive_record_seq_num: RecordSeqNum,
703 read_traffic_secret: TrafficSecret,
704 write_traffic_secret: TrafficSecret,
705 sliding_window: u64,
706}
707
708impl EpochState {
709 pub const fn new(
710 read_traffic_secret: TrafficSecret,
711 write_traffic_secret: TrafficSecret,
712 ) -> Self {
713 Self {
714 send_record_seq_num: 0,
715 receive_record_seq_num: 0,
716 read_traffic_secret,
717 write_traffic_secret,
718 sliding_window: 0,
719 }
720 }
721
722 pub const fn empty() -> Self {
723 Self::new(TrafficSecret::None, TrafficSecret::None)
724 }
725
726 pub(crate) fn check_seq_num(&self, seq_num: &u64) -> Result<(), DtlsError> {
727 const WINDOW_MAX_SHIFT_BITS: u64 = 64 - 1;
728 let highest_seq_num = self.receive_record_seq_num;
729
730 if highest_seq_num > *seq_num {
731 let diff = highest_seq_num - seq_num;
732 if diff > WINDOW_MAX_SHIFT_BITS {
733 return Err(DtlsError::RejectedSequenceNumber);
734 }
735 let window_index = 1u64 << diff;
736 if self.sliding_window & window_index > 0 {
737 return Err(DtlsError::RejectedSequenceNumber);
739 }
740 } else {
741 let shift = seq_num - highest_seq_num;
742 if shift == 0 && self.sliding_window & 1 == 1 {
743 return Err(DtlsError::RejectedSequenceNumber);
745 }
746 }
747 Ok(())
748 }
749
750 pub(crate) fn mark_received(&mut self, seq_num: &u64) {
751 let highest_seq_num = &self.receive_record_seq_num;
752 if highest_seq_num > seq_num {
753 let diff = highest_seq_num - seq_num;
754 let window_index = 1u64 << diff;
755 debug_assert!(self.sliding_window & window_index == 0);
756 self.sliding_window |= window_index;
757 } else {
758 let shift = seq_num - highest_seq_num;
759 if shift >= 64 {
760 self.sliding_window = 0;
761 } else {
762 self.sliding_window <<= shift;
763 }
764 self.receive_record_seq_num = *seq_num;
765 debug_assert!(self.sliding_window & 1 == 0);
766 self.sliding_window |= 1;
767 }
768 }
769}
770
#[cfg(test)]
mod tests {
    use crate::{crypto::TrafficSecret, DtlsError, EpochState};

    /// A fresh epoch state without secrets and with an empty replay window.
    fn fresh_state() -> EpochState {
        EpochState::new(TrafficSecret::None, TrafficSecret::None)
    }

    /// True when the replay check rejected the record.
    fn is_rejected(result: Result<(), DtlsError>) -> bool {
        matches!(result, Err(DtlsError::RejectedSequenceNumber))
    }

    #[test]
    pub fn reject_double_receive() {
        let mut state = fresh_state();
        state.check_seq_num(&2).unwrap();
        state.mark_received(&2);
        assert!(is_rejected(state.check_seq_num(&2)));
    }

    #[test]
    pub fn reject_too_old_receive() {
        let mut state = fresh_state();
        state.mark_received(&64);
        assert!(is_rejected(state.check_seq_num(&0)));
    }

    #[test]
    pub fn correctly_check_after_shift() {
        let mut state = fresh_state();
        state.mark_received(&20);
        state.mark_received(&64);
        assert!(is_rejected(state.check_seq_num(&20)));
    }
}