1#[cfg(test)]
2mod conn_test;
3
4use crate::alert::*;
5use crate::application_data::*;
6use crate::content::*;
7use crate::curve::named_curve::NamedCurve;
8use crate::extension::extension_use_srtp::*;
9use crate::flight::flight0::*;
10use crate::flight::flight1::*;
11use crate::flight::flight5::*;
12use crate::flight::flight6::*;
13use crate::flight::*;
14use crate::fragment_buffer::*;
15use crate::handshake::handshake_cache::*;
16use crate::handshake::handshake_header::HandshakeHeader;
17use crate::handshake::*;
18use crate::handshaker::*;
19use crate::record_layer::record_layer_header::*;
20use crate::record_layer::*;
21use crate::state::*;
22use std::collections::VecDeque;
23
24use shared::{error::*, replay_detector::*};
25
26use crate::config::HandshakeConfig;
27use bytes::BytesMut;
28use log::*;
29use std::io::{BufReader, BufWriter};
30use std::sync::Arc;
31use std::time::{Duration, Instant};
32
/// Initial interval of the handshake ticker.
/// NOTE(review): presumably the first retransmission timeout used by the handshaker
/// (not referenced in this part of the file) — confirm.
pub(crate) const INITIAL_TICKER_INTERVAL: Duration = Duration::from_secs(1);
/// Cookie length in bytes.
/// NOTE(review): presumably the HelloVerifyRequest cookie size — confirm at the use site.
pub(crate) const COOKIE_LENGTH: usize = 20;
/// Named curve used when none is configured (assumed from the name — confirm at use site).
pub(crate) const DEFAULT_NAMED_CURVE: NamedCurve = NamedCurve::X25519;
/// Size of the inbound buffer in bytes (not referenced in this part of the file).
pub(crate) const INBOUND_BUFFER_SIZE: usize = 8192;
/// Default replay-protection window size.
/// NOTE(review): `DTLSConn` itself reads the window from `HandshakeConfig`; this is
/// presumably the config default — confirm.
pub(crate) const DEFAULT_REPLAY_PROTECTION_WINDOW: usize = 64;

/// Labels that must not be used as keying-material export labels.
/// NOTE(review): these are the PRF labels TLS 1.2 uses internally (cf. RFC 5705),
/// presumably rejected by the exporter elsewhere — confirm.
pub(crate) static INVALID_KEYING_LABELS: &[&str] = &[
    "client finished",
    "server finished",
    "master secret",
    "key expansion",
];
46
/// Sans-IO DTLS connection: callers feed received datagrams to `read`, submit
/// application data with `write`, and poll `outgoing_raw_packet` /
/// `incoming_application_data` for output.
pub struct DTLSConn {
    /// `true` for the client role, `false` for the server role.
    is_client: bool,
    /// Copied from the handshake config; limits handshake fragment body size and
    /// datagram packing in `handle_outgoing_packets`.
    maximum_transmission_unit: usize,
    /// Copied from the handshake config.
    /// NOTE(review): presumably read by the handshaker's retransmission logic — confirm.
    pub(crate) maximum_retransmit_number: usize,
    /// Window size for each per-epoch `SlidingWindowDetector`.
    replay_protection_window: usize,
    /// One replay detector per remote epoch, created lazily on receive.
    replay_detector: Vec<Box<dyn ReplayDetector>>,
    /// Decrypted application data waiting for `incoming_application_data`.
    incoming_decrypted_packets: VecDeque<BytesMut>,
    /// Raw received records that could not be processed yet (next epoch, or cipher
    /// suite not initialized); replayed by `handle_incoming_queued_packets`.
    incoming_encrypted_packets: VecDeque<Vec<u8>>,
    /// Receives handshake records and yields messages to push into `cache`.
    fragment_buffer: FragmentBuffer,
    /// Cache of handshake messages sent and received.
    pub(crate) cache: HandshakeCache,
    /// Records waiting to be serialized by `handle_outgoing_packets`.
    pub(crate) outgoing_packets: VecDeque<Packet>,
    /// Application data written before the handshake completed.
    outgoing_queued_packets: VecDeque<Packet>,
    /// Serialized datagrams ready to be handed out by `outgoing_raw_packet`.
    outgoing_compacted_raw_packets: VecDeque<BytesMut>,

    /// Connection state: epochs, per-epoch sequence numbers, cipher suite, SRTP profile.
    pub(crate) state: State,
    /// Set by `set_handshake_completed`; gates application data in both directions.
    handshake_completed: bool,
    /// NOTE(review): only initialized in this part of the file; usage is elsewhere.
    connection_closed_by_user: bool,
    /// Set by `close`; `write` fails once it is set.
    closed: bool,
    /// Handshake FSM state (`Preparing` for a fresh handshake, `Finished` when
    /// constructed from an existing `State`).
    pub(crate) current_handshake_state: HandshakeState,
    /// NOTE(review): presumably the pending retransmission deadline — confirm in the handshaker.
    pub(crate) current_retransmit_timer: Option<Instant>,
    /// NOTE(review): presumably the number of retransmissions so far — confirm in the handshaker.
    pub(crate) current_retransmit_count: usize,

    /// Flight the handshake FSM is currently on (initial value chosen by `new`).
    pub(crate) current_flight: Box<dyn Flight>,
    /// NOTE(review): presumably the packets of the current flight — confirm in the handshaker.
    pub(crate) flights: Option<Vec<Packet>>,
    /// Handshake configuration this connection was created with.
    pub(crate) handshake_config: Arc<HandshakeConfig>,
    /// NOTE(review): presumably a "retransmit the current flight" flag — confirm.
    pub(crate) retransmit: bool,
    /// Set to `Some(())` by `read` when a handshake record was received.
    /// NOTE(review): presumably taken by the handshaker as a wake-up signal — confirm.
    pub(crate) handshake_rx: Option<()>,
}
90
91impl DTLSConn {
92 pub fn new(
93 handshake_config: Arc<HandshakeConfig>,
94 is_client: bool,
95 initial_state: Option<State>,
96 ) -> Self {
97 let (state, flight, initial_fsm_state) = if let Some(state) = initial_state {
98 let flight = if is_client {
99 Box::new(Flight5 {}) as Box<dyn Flight>
100 } else {
101 Box::new(Flight6 {}) as Box<dyn Flight>
102 };
103
104 (state, flight, HandshakeState::Finished)
105 } else {
106 let flight = if is_client {
107 Box::new(Flight1 {}) as Box<dyn Flight>
108 } else {
109 Box::new(Flight0 {}) as Box<dyn Flight>
110 };
111
112 (
113 State {
114 is_client,
115 ..Default::default()
116 },
117 flight,
118 HandshakeState::Preparing,
119 )
120 };
121
122 Self {
123 is_client,
124 maximum_transmission_unit: handshake_config.maximum_transmission_unit,
125 maximum_retransmit_number: handshake_config.maximum_retransmit_number,
126 replay_protection_window: handshake_config.replay_protection_window,
127 replay_detector: vec![],
128 incoming_decrypted_packets: VecDeque::new(),
129 incoming_encrypted_packets: VecDeque::new(),
130 fragment_buffer: FragmentBuffer::new(),
131 outgoing_packets: VecDeque::new(),
132 outgoing_queued_packets: VecDeque::new(),
133 outgoing_compacted_raw_packets: VecDeque::new(),
134
135 cache: HandshakeCache::new(),
136 state,
137 handshake_completed: false,
138 connection_closed_by_user: false,
139 closed: false,
140
141 current_handshake_state: initial_fsm_state,
142 current_retransmit_timer: None,
143 current_retransmit_count: 0,
144
145 current_flight: flight,
146 flights: None,
147 handshake_config,
148 retransmit: false,
149 handshake_rx: None,
150 }
151 }
152
153 pub fn incoming_application_data(&mut self) -> Option<BytesMut> {
155 if !self.is_handshake_completed() {
156 None
157 } else {
158 self.incoming_decrypted_packets.pop_front()
159 }
160 }
161
162 pub fn outgoing_raw_packet(&mut self) -> Option<BytesMut> {
163 if let Err(err) = self.handle_outgoing_packets() {
164 warn!(
165 "handle_outgoing_packets [{}] with error {}",
166 srv_cli_str(self.is_client),
167 err
168 );
169 }
170 self.outgoing_compacted_raw_packets.pop_front()
171 }
172
173 pub fn write(&mut self, p: &[u8]) -> Result<()> {
175 if self.is_connection_closed() {
176 return Err(Error::ErrConnClosed);
177 }
178
179 let pkt = Packet {
180 record: RecordLayer::new(
181 PROTOCOL_VERSION1_2,
182 self.get_local_epoch(),
183 Content::ApplicationData(ApplicationData {
184 data: BytesMut::from(p),
185 }),
186 ),
187 should_encrypt: true,
188 reset_local_sequence_number: false,
189 };
190
191 if self.is_handshake_completed() {
192 self.write_packets(vec![pkt]);
193 } else {
194 self.outgoing_queued_packets.push_back(pkt);
195 }
196
197 Ok(())
198 }
199
200 pub fn close(&mut self) {
202 if !self.closed {
203 self.closed = true;
204
205 self.notify(AlertLevel::Warning, AlertDescription::CloseNotify);
208 }
209 }
210
211 pub fn connection_state(&self) -> &State {
214 &self.state
215 }
216
217 pub(crate) fn selected_srtp_protection_profile(&self) -> SrtpProtectionProfile {
219 self.state.srtp_protection_profile
220 }
221
222 pub(crate) fn notify(&mut self, level: AlertLevel, desc: AlertDescription) {
223 self.write_packets(vec![Packet {
224 record: RecordLayer::new(
225 PROTOCOL_VERSION1_2,
226 self.get_local_epoch(),
227 Content::Alert(Alert {
228 alert_level: level,
229 alert_description: desc,
230 }),
231 ),
232 should_encrypt: self.is_handshake_completed(),
233 reset_local_sequence_number: false,
234 }]);
235 }
236
237 pub(crate) fn write_packets(&mut self, pkts: Vec<Packet>) {
238 for pkt in pkts {
239 self.outgoing_packets.push_back(pkt);
240 }
241 }
242
243 fn handle_outgoing_packets(&mut self) -> Result<()> {
244 if self.is_handshake_completed() {
245 while let Some(mut pkt) = self.outgoing_queued_packets.pop_front() {
246 pkt.record.record_layer_header.epoch = self.get_local_epoch();
247 self.write_packets(vec![pkt]);
248 }
249 }
250
251 let mut raw_packets = vec![];
252 while let Some(p) = self.outgoing_packets.pop_front() {
253 if let Content::Handshake(h) = &p.record.content {
254 let mut handshake_raw = vec![];
255 {
256 let mut writer = BufWriter::<&mut Vec<u8>>::new(handshake_raw.as_mut());
257 p.record.marshal(&mut writer)?;
258 }
259 debug!(
260 "Send [handshake:{}] -> {} (epoch: {}, seq: {})",
261 srv_cli_str(self.is_client),
262 h.handshake_header.handshake_type,
263 p.record.record_layer_header.epoch,
264 h.handshake_header.message_sequence
265 );
266 self.cache.push(
267 handshake_raw[RECORD_LAYER_HEADER_SIZE..].to_vec(),
268 p.record.record_layer_header.epoch,
269 h.handshake_header.message_sequence,
270 h.handshake_header.handshake_type,
271 self.is_client,
272 );
273
274 let raw_handshake_packets = self.process_handshake_packet(&p, h)?;
275 raw_packets.extend_from_slice(&raw_handshake_packets);
276 } else {
277 let raw_packet = self.process_packet(p)?;
284 raw_packets.push(raw_packet);
285 }
286 }
287
288 if !raw_packets.is_empty() {
289 let compacted_raw_packets =
290 compact_raw_packets(&raw_packets, self.maximum_transmission_unit);
291
292 for compacted_raw_packets in compacted_raw_packets {
293 self.outgoing_compacted_raw_packets
294 .push_back(compacted_raw_packets);
295 }
296 }
297
298 Ok(())
299 }
300
301 fn process_packet(&mut self, mut p: Packet) -> Result<Vec<u8>> {
302 let epoch = p.record.record_layer_header.epoch as usize;
303 let seq = {
304 while self.state.local_sequence_number.len() <= epoch {
305 self.state.local_sequence_number.push(0);
306 }
307
308 self.state.local_sequence_number[epoch] += 1;
309 self.state.local_sequence_number[epoch] - 1
310 };
311 if seq > MAX_SEQUENCE_NUMBER {
314 return Err(Error::ErrSequenceNumberOverflow);
318 }
319 p.record.record_layer_header.sequence_number = seq;
320
321 let mut raw_packet = vec![];
322 {
323 let mut writer = BufWriter::<&mut Vec<u8>>::new(raw_packet.as_mut());
324 p.record.marshal(&mut writer)?;
325 }
326
327 if p.should_encrypt
328 && let Some(cipher_suite) = &self.state.cipher_suite
329 {
330 raw_packet = cipher_suite.encrypt(&p.record.record_layer_header, &raw_packet)?;
331 }
332
333 Ok(raw_packet)
334 }
335
336 fn process_handshake_packet(&mut self, p: &Packet, h: &Handshake) -> Result<Vec<Vec<u8>>> {
337 let mut raw_packets = vec![];
338
339 let handshake_fragments = DTLSConn::fragment_handshake(self.maximum_transmission_unit, h)?;
340
341 let epoch = p.record.record_layer_header.epoch as usize;
342
343 while self.state.local_sequence_number.len() <= epoch {
344 self.state.local_sequence_number.push(0);
345 }
346
347 for handshake_fragment in &handshake_fragments {
348 let seq = {
349 self.state.local_sequence_number[epoch] += 1;
350 self.state.local_sequence_number[epoch] - 1
351 };
352 if seq > MAX_SEQUENCE_NUMBER {
354 return Err(Error::ErrSequenceNumberOverflow);
355 }
356
357 let record_layer_header = RecordLayerHeader {
358 protocol_version: p.record.record_layer_header.protocol_version,
359 content_type: p.record.record_layer_header.content_type,
360 content_len: handshake_fragment.len() as u16,
361 epoch: p.record.record_layer_header.epoch,
362 sequence_number: seq,
363 };
364
365 let mut record_layer_header_bytes = vec![];
366 {
367 let mut writer = BufWriter::<&mut Vec<u8>>::new(record_layer_header_bytes.as_mut());
368 record_layer_header.marshal(&mut writer)?;
369 }
370
371 let mut raw_packet = vec![];
374 raw_packet.extend_from_slice(&record_layer_header_bytes);
375 raw_packet.extend_from_slice(handshake_fragment);
376 if p.should_encrypt
377 && let Some(cipher_suite) = &self.state.cipher_suite
378 {
379 raw_packet = cipher_suite.encrypt(&record_layer_header, &raw_packet)?;
380 }
381
382 raw_packets.push(raw_packet);
383 }
384
385 Ok(raw_packets)
386 }
387
388 fn fragment_handshake(maximum_transmission_unit: usize, h: &Handshake) -> Result<Vec<Vec<u8>>> {
389 let mut content = vec![];
390 {
391 let mut writer = BufWriter::<&mut Vec<u8>>::new(content.as_mut());
392 h.handshake_message.marshal(&mut writer)?;
393 }
394
395 let mut fragmented_handshakes = vec![];
396
397 let mut content_fragments = split_bytes(&content, maximum_transmission_unit);
398 if content_fragments.is_empty() {
399 content_fragments = vec![vec![]];
400 }
401
402 let mut offset = 0;
403 for content_fragment in &content_fragments {
404 let content_fragment_len = content_fragment.len();
405
406 let handshake_header_fragment = HandshakeHeader {
407 handshake_type: h.handshake_header.handshake_type,
408 length: h.handshake_header.length,
409 message_sequence: h.handshake_header.message_sequence,
410 fragment_offset: offset as u32,
411 fragment_length: content_fragment_len as u32,
412 };
413
414 offset += content_fragment_len;
415
416 let mut handshake_header_fragment_raw = vec![];
417 {
418 let mut writer =
419 BufWriter::<&mut Vec<u8>>::new(handshake_header_fragment_raw.as_mut());
420 handshake_header_fragment.marshal(&mut writer)?;
421 }
422
423 let mut fragmented_handshake = vec![];
424 fragmented_handshake.extend_from_slice(&handshake_header_fragment_raw);
425 fragmented_handshake.extend_from_slice(content_fragment);
426
427 fragmented_handshakes.push(fragmented_handshake);
428 }
429
430 Ok(fragmented_handshakes)
431 }
432
433 pub(crate) fn set_handshake_completed(&mut self) {
434 self.handshake_completed = true;
435 }
436
437 pub(crate) fn is_handshake_completed(&self) -> bool {
438 self.handshake_completed
439 }
440
441 pub fn read(&mut self, buf: &[u8]) -> Result<()> {
442 for pkt in unpack_datagram(buf)? {
443 let (hs, alert, err) = self.handle_incoming_packet(pkt, true);
444 if let Some(alert) = alert {
445 self.outgoing_packets.push_back(Packet {
446 record: RecordLayer::new(
447 PROTOCOL_VERSION1_2,
448 self.state.local_epoch,
449 Content::Alert(Alert {
450 alert_level: alert.alert_level,
451 alert_description: alert.alert_description,
452 }),
453 ),
454 should_encrypt: self.is_handshake_completed(),
455 reset_local_sequence_number: false,
456 });
457
458 if alert.alert_level == AlertLevel::Fatal
459 || alert.alert_description == AlertDescription::CloseNotify
460 {
461 return Err(Error::ErrAlertFatalOrClose);
462 }
463 }
464
465 if let Some(err) = err {
466 return Err(err);
467 }
468
469 if hs {
470 self.handshake_rx = Some(());
471 }
472 }
473
474 Ok(())
475 }
476
477 pub(crate) fn handle_incoming_queued_packets(&mut self) -> Result<()> {
478 if self.is_handshake_completed() {
479 while let Some(p) = self.incoming_encrypted_packets.pop_front() {
480 let (_, alert, err) = self.handle_incoming_packet(p, false); if let Some(alert) = alert {
482 self.outgoing_packets.push_back(Packet {
483 record: RecordLayer::new(
484 PROTOCOL_VERSION1_2,
485 self.state.local_epoch,
486 Content::Alert(Alert {
487 alert_level: alert.alert_level,
488 alert_description: alert.alert_description,
489 }),
490 ),
491 should_encrypt: self.is_handshake_completed(),
492 reset_local_sequence_number: false,
493 });
494
495 if alert.alert_level == AlertLevel::Fatal
496 || alert.alert_description == AlertDescription::CloseNotify
497 {
498 return Err(Error::ErrAlertFatalOrClose);
499 }
500 }
501
502 if let Some(err) = err {
503 return Err(err);
504 }
505 }
506 }
507
508 Ok(())
509 }
510
    /// Processes one raw record received from the peer.
    ///
    /// Steps, in order:
    /// 1. Parse the record-layer header; unparsable records are dropped.
    /// 2. Records from a future epoch are dropped, except records for exactly
    ///    `remote_epoch + 1`, which are parked in `incoming_encrypted_packets`
    ///    when `enqueue` is true.
    /// 3. Per-epoch replay check with a `SlidingWindowDetector` (created on
    ///    demand); rejected records are dropped.
    /// 4. Records with a non-zero epoch are decrypted. If the cipher suite is
    ///    missing or not initialized they are parked instead (when `enqueue`). A
    ///    decryption failure is dropped silently, except with a PSK cipher suite,
    ///    where a fatal `UnknownPskIdentity` alert is returned.
    /// 5. If `fragment_buffer.push` reports handshake content, each message
    ///    popped from the fragment buffer is parsed and pushed to the cache.
    /// 6. Otherwise the record is parsed and dispatched on its content type.
    ///
    /// The replay window only `accept()`s the sequence number on the paths that
    /// consume the record: handshake, alert, a change_cipher_spec for the current
    /// remote epoch, and non-zero-epoch application data.
    ///
    /// Returns `(is_handshake, alert, error)`:
    /// * `is_handshake` is true when the record was handshake content;
    /// * `alert` is an alert the caller queues for the peer;
    /// * `error` is an error the caller surfaces.
    ///
    /// NOTE(review): a *received* alert is returned in the `alert` slot, so `read`
    /// sends it back to the peer. A close_notify is first normalized to a
    /// warning-level close_notify. Confirm that echoing other received alerts is
    /// intended.
    fn handle_incoming_packet(
        &mut self,
        mut pkt: Vec<u8>,
        enqueue: bool,
    ) -> (bool, Option<Alert>, Option<Error>) {
        let mut reader = BufReader::new(pkt.as_slice());
        let h = match RecordLayerHeader::unmarshal(&mut reader) {
            Ok(h) => h,
            Err(err) => {
                // Header cannot be parsed: drop the record without alerting the peer.
                debug!(
                    "{}: discarded broken packet: {}",
                    srv_cli_str(self.is_client),
                    err
                );
                return (false, None, None);
            }
        };

        // `epoch` is the current remote epoch. It is compared again in the
        // ChangeCipherSpec arm below.
        let epoch = self.state.remote_epoch;
        if h.epoch > epoch {
            if h.epoch > epoch + 1 {
                debug!(
                    "{}: discarded future packet (epoch: {}, seq: {})",
                    srv_cli_str(self.is_client),
                    h.epoch,
                    h.sequence_number,
                );
                return (false, None, None);
            }
            // Exactly one epoch ahead: keep the raw record for later, unless we are
            // already replaying the queue.
            if enqueue {
                debug!(
                    "{}: received packet of next epoch, queuing packet",
                    srv_cli_str(self.is_client)
                );
                self.incoming_encrypted_packets.push_back(pkt);
            }
            return (false, None, None);
        }

        // Lazily create one replay detector per epoch up to `h.epoch`.
        while self.replay_detector.len() <= h.epoch as usize {
            self.replay_detector
                .push(Box::new(SlidingWindowDetector::new(
                    self.replay_protection_window,
                    MAX_SEQUENCE_NUMBER,
                )));
        }

        // Only a check here. The number is marked as seen (`accept`) further down,
        // once the record has actually been consumed.
        let ok = self.replay_detector[h.epoch as usize].check(h.sequence_number);
        if !ok {
            debug!(
                "{}: discarded duplicated packet (epoch: {}, seq: {})",
                srv_cli_str(self.is_client),
                h.epoch,
                h.sequence_number,
            );
            return (false, None, None);
        }

        // Epoch 0 records are plaintext; later epochs must be decrypted first.
        if h.epoch != 0 {
            let invalid_cipher_suite = {
                if let Some(cipher_suite) = &self.state.cipher_suite {
                    !cipher_suite.is_initialized()
                } else {
                    true
                }
            };
            if invalid_cipher_suite {
                if enqueue {
                    debug!(
                        "{}: handshake not finished, queuing packet",
                        srv_cli_str(self.is_client)
                    );
                    self.incoming_encrypted_packets.push_back(pkt);
                }
                return (false, None, None);
            }

            if let Some(cipher_suite) = &self.state.cipher_suite {
                pkt = match cipher_suite.decrypt(&pkt) {
                    Ok(pkt) => pkt,
                    Err(err) => {
                        debug!("{}: decrypt failed: {}", srv_cli_str(self.is_client), err);

                        // With a PSK suite a decryption failure is reported to the peer
                        // as a fatal UnknownPskIdentity alert; otherwise it is dropped.
                        if cipher_suite.is_psk() {
                            return (
                                false,
                                Some(Alert {
                                    alert_level: AlertLevel::Fatal,
                                    alert_description: AlertDescription::UnknownPskIdentity,
                                }),
                                None,
                            );
                        } else {
                            return (false, None, None);
                        }
                    }
                };
            }
        }

        let is_handshake = match self.fragment_buffer.push(&pkt) {
            Ok(is_handshake) => is_handshake,
            Err(err) => {
                // Malformed handshake fragment: drop silently.
                debug!(
                    "{}: defragment failed: {}",
                    srv_cli_str(self.is_client),
                    err
                );
                return (false, None, None);
            }
        };
        if is_handshake {
            self.replay_detector[h.epoch as usize].accept();
            // `epoch` here shadows the remote epoch: it is the epoch the fragment
            // buffer returns with each popped message. The debug line logs
            // `h.epoch` (this record's epoch), but the cache receives the popped one.
            while let Ok((out, epoch)) = self.fragment_buffer.pop() {
                let mut reader = BufReader::new(out.as_slice());
                let raw_handshake = match Handshake::unmarshal(&mut reader) {
                    Ok(rh) => {
                        debug!(
                            "Recv [handshake:{}] -> {} (epoch: {}, seq: {})",
                            srv_cli_str(self.is_client),
                            rh.handshake_header.handshake_type,
                            h.epoch,
                            rh.handshake_header.message_sequence
                        );
                        rh
                    }
                    Err(err) => {
                        debug!(
                            "{}: handshake parse failed: {}",
                            srv_cli_str(self.is_client),
                            err
                        );
                        continue;
                    }
                };

                // The last argument is `!self.is_client`: the cache entry is
                // attributed to the peer's role, since the message came from the peer.
                self.cache.push(
                    out,
                    epoch,
                    raw_handshake.handshake_header.message_sequence,
                    raw_handshake.handshake_header.handshake_type,
                    !self.is_client,
                );
            }

            return (true, None, None);
        }

        // Not handshake content: parse the full record (plaintext at this point).
        let mut reader = BufReader::new(pkt.as_slice());
        let r = match RecordLayer::unmarshal(&mut reader) {
            Ok(r) => r,
            Err(err) => {
                return (
                    false,
                    Some(Alert {
                        alert_level: AlertLevel::Fatal,
                        alert_description: AlertDescription::DecodeError,
                    }),
                    Some(err),
                );
            }
        };

        match r.content {
            Content::Alert(mut a) => {
                debug!("{}: <- {}", srv_cli_str(self.is_client), a);
                if a.alert_description == AlertDescription::CloseNotify {
                    // Answer a close_notify with a warning-level close_notify (RFC 5246 §7.2.1).
                    a = Alert {
                        alert_level: AlertLevel::Warning,
                        alert_description: AlertDescription::CloseNotify,
                    };
                }
                self.replay_detector[h.epoch as usize].accept();
                // Every received alert is also reported as an error to the caller.
                return (
                    false,
                    Some(a),
                    Some(Error::Other(format!("Error of Alert {a}"))),
                );
            }
            Content::ChangeCipherSpec(_) => {
                let invalid_cipher_suite = {
                    if let Some(cipher_suite) = &self.state.cipher_suite {
                        !cipher_suite.is_initialized()
                    } else {
                        true
                    }
                };

                // Cannot switch epochs before the cipher suite is ready: park the record.
                if invalid_cipher_suite {
                    if enqueue {
                        debug!(
                            "{}: CipherSuite not initialized, queuing packet",
                            srv_cli_str(self.is_client)
                        );
                        self.incoming_encrypted_packets.push_back(pkt);
                    }
                    return (false, None, None);
                }

                let new_remote_epoch = h.epoch + 1;
                debug!(
                    "{}: <- ChangeCipherSpec (epoch: {})",
                    srv_cli_str(self.is_client),
                    new_remote_epoch
                );

                // Only advance when the CCS was sent in the current remote epoch; a
                // stale CCS (older epoch) is ignored and not accepted into the window.
                if epoch + 1 == new_remote_epoch {
                    self.state.remote_epoch = new_remote_epoch;
                    self.replay_detector[h.epoch as usize].accept();
                }
            }
            Content::ApplicationData(a) => {
                // Application data is never valid in the unencrypted epoch 0.
                if h.epoch == 0 {
                    warn!(
                        "{}: <- Unexpected ApplicationData Message",
                        srv_cli_str(self.is_client),
                    );
                    return (
                        false,
                        Some(Alert {
                            alert_level: AlertLevel::Fatal,
                            alert_description: AlertDescription::UnexpectedMessage,
                        }),
                        Some(Error::ErrApplicationDataEpochZero),
                    );
                }

                self.replay_detector[h.epoch as usize].accept();

                self.incoming_decrypted_packets.push_back(a.data);
            }
            _ => {
                // Any other content type here (handshake content has already been
                // routed through the fragment buffer above) is treated as unexpected.
                warn!(
                    "{}: <- Unexpected Handshake Message",
                    srv_cli_str(self.is_client),
                );
                return (
                    false,
                    Some(Alert {
                        alert_level: AlertLevel::Fatal,
                        alert_description: AlertDescription::UnexpectedMessage,
                    }),
                    Some(Error::ErrUnhandledContextType),
                );
            }
        };

        (false, None, None)
    }
770
771 fn is_connection_closed(&self) -> bool {
772 self.closed
773 }
774
775 pub(crate) fn set_local_epoch(&mut self, epoch: u16) {
776 self.state.local_epoch = epoch;
777 }
778
779 pub(crate) fn get_local_epoch(&self) -> u16 {
780 self.state.local_epoch
781 }
782}
783
784fn compact_raw_packets(raw_packets: &[Vec<u8>], maximum_transmission_unit: usize) -> Vec<BytesMut> {
785 let mut combined_raw_packets = vec![];
786 let mut current_combined_raw_packet = BytesMut::new();
787
788 for raw_packet in raw_packets {
789 if !current_combined_raw_packet.is_empty()
790 && current_combined_raw_packet.len() + raw_packet.len() >= maximum_transmission_unit
791 {
792 combined_raw_packets.push(current_combined_raw_packet);
793 current_combined_raw_packet = BytesMut::new();
794 }
795 current_combined_raw_packet.extend_from_slice(raw_packet);
796 }
797
798 if !current_combined_raw_packet.is_empty() {
799 combined_raw_packets.push(current_combined_raw_packet);
800 }
801
802 combined_raw_packets
803}
804
/// Splits `bytes` into consecutive owned chunks of at most `split_len` bytes.
///
/// The last chunk may be shorter, and an empty input yields no chunks. A
/// `split_len` of zero means "no limit": the whole (non-empty) input becomes a
/// single chunk. The previous `step_by(0)` implementation panicked in that case,
/// even for empty input.
fn split_bytes(bytes: &[u8], split_len: usize) -> Vec<Vec<u8>> {
    if split_len == 0 {
        // `slice::chunks` (like `Iterator::step_by`) panics on a zero size.
        return if bytes.is_empty() {
            Vec::new()
        } else {
            vec![bytes.to_vec()]
        };
    }

    bytes.chunks(split_len).map(<[u8]>::to_vec).collect()
}