extern crate alloc;
use alloc::vec::Vec;
/// Failure modes surfaced by a [`DtlsLayer`].
///
/// Marked `#[non_exhaustive]`, so code matching on it from another crate
/// needs a wildcard arm.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum DtlsError {
    /// The handshake failed.
    HandshakeFailed {
        /// Static, human-readable description of the failure.
        reason: &'static str,
    },
    /// A send operation failed.
    SendFailed {
        /// Static, human-readable description of the failure.
        reason: &'static str,
    },
    /// A receive operation failed.
    RecvFailed {
        /// Static, human-readable description of the failure.
        reason: &'static str,
    },
    /// The stream has been closed.
    Closed,
}
/// User-facing rendering: a fixed `dtls …` prefix per variant, followed by the
/// variant's `reason` where it carries one.
impl core::fmt::Display for DtlsError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            // The only variant without a payload: emit the fixed text directly.
            Self::Closed => f.write_str("dtls stream closed"),
            Self::HandshakeFailed { reason } => {
                f.write_fmt(format_args!("dtls handshake failed: {reason}"))
            }
            Self::SendFailed { reason } => {
                f.write_fmt(format_args!("dtls send failed: {reason}"))
            }
            Self::RecvFailed { reason } => {
                f.write_fmt(format_args!("dtls recv failed: {reason}"))
            }
        }
    }
}
// Only compiled when the `std` feature is enabled. The impl body is empty, so
// `DtlsError` uses the trait's default methods and reports no underlying source.
#[cfg(feature = "std")]
impl std::error::Error for DtlsError {}
/// Interface of a DTLS layer: handshake, plaintext send/receive, and close.
///
/// Every fallible operation reports failure as a [`DtlsError`]; which variant
/// is used for which condition is up to the implementation.
pub trait DtlsLayer {
    /// Performs the handshake.
    ///
    /// # Errors
    ///
    /// Returns a [`DtlsError`] if the implementation cannot complete it.
    fn handshake(&mut self) -> Result<(), DtlsError>;

    /// Sends `plaintext` through the layer.
    ///
    /// # Errors
    ///
    /// Returns a [`DtlsError`] if the implementation cannot accept the payload.
    fn send(&mut self, plaintext: &[u8]) -> Result<(), DtlsError>;

    /// Receives one payload, returned as an owned buffer.
    ///
    /// # Errors
    ///
    /// Returns a [`DtlsError`] if the implementation has nothing to deliver
    /// or cannot receive.
    fn recv(&mut self) -> Result<Vec<u8>, DtlsError>;

    /// Closes the layer.
    ///
    /// # Errors
    ///
    /// Returns a [`DtlsError`] if the implementation cannot close cleanly.
    fn close(&mut self) -> Result<(), DtlsError>;

    /// Reports whether the handshake has completed.
    fn is_handshake_complete(&self) -> bool;
}
/// In-memory [`DtlsLayer`] implementation: state is two flags plus a FIFO
/// queue of plaintext payloads.
///
/// The derived `Default` starts with both flags `false` and an empty queue.
#[derive(Debug, Default)]
pub struct DummyDtls {
    /// Set once `handshake` has been called.
    handshake_done: bool,
    /// Payloads waiting to be handed out by `recv`, oldest first.
    inbox: alloc::collections::VecDeque<Vec<u8>>,
    /// Set once `close` has been called.
    closed: bool,
}
impl DummyDtls {
    /// Creates a layer in its default state: handshake not done, not closed,
    /// nothing queued.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Appends `plaintext` to the back of the inbox so a later `recv` can
    /// return it.
    ///
    /// No handshake or closed-state check is applied here.
    pub fn inject(&mut self, plaintext: Vec<u8>) {
        let Self { inbox, .. } = self;
        inbox.push_back(plaintext);
    }

    /// Number of payloads currently waiting in the inbox.
    #[must_use]
    pub fn inbox_len(&self) -> usize {
        let pending = &self.inbox;
        pending.len()
    }
}
impl DtlsLayer for DummyDtls {
    /// Marks the handshake as complete. Always succeeds.
    fn handshake(&mut self) -> Result<(), DtlsError> {
        self.handshake_done = true;
        Ok(())
    }

    /// Copies `plaintext` onto the back of this layer's own inbox, so a later
    /// `recv` on the same value returns it.
    ///
    /// # Errors
    ///
    /// * [`DtlsError::Closed`] if `close` has been called.
    /// * [`DtlsError::SendFailed`] if the handshake has not been performed.
    fn send(&mut self, plaintext: &[u8]) -> Result<(), DtlsError> {
        match (self.closed, self.handshake_done) {
            // The closed state takes precedence over a missing handshake.
            (true, _) => Err(DtlsError::Closed),
            (false, false) => Err(DtlsError::SendFailed {
                reason: "handshake not complete",
            }),
            (false, true) => {
                self.inbox.push_back(Vec::from(plaintext));
                Ok(())
            }
        }
    }

    /// Pops the oldest queued payload.
    ///
    /// # Errors
    ///
    /// * [`DtlsError::Closed`] if the layer is closed and the inbox is empty.
    /// * [`DtlsError::RecvFailed`] if the handshake has not been performed, or
    ///   if the inbox is empty while the layer is still open.
    fn recv(&mut self) -> Result<Vec<u8>, DtlsError> {
        // A closed layer still hands out whatever is queued; `Closed` is
        // reported only once nothing is left.
        let drained = self.inbox.is_empty();
        if drained && self.closed {
            return Err(DtlsError::Closed);
        }
        if !self.handshake_done {
            return Err(DtlsError::RecvFailed {
                reason: "handshake not complete",
            });
        }
        match self.inbox.pop_front() {
            Some(payload) => Ok(payload),
            None => Err(DtlsError::RecvFailed {
                reason: "inbox empty",
            }),
        }
    }

    /// Marks the layer as closed. Always succeeds; queued payloads are kept.
    fn close(&mut self) -> Result<(), DtlsError> {
        self.closed = true;
        Ok(())
    }

    /// Returns `true` once `handshake` has been called.
    fn is_handshake_complete(&self) -> bool {
        self.handshake_done
    }
}
#[cfg(test)]
mod tests {
    // Unwrapping is fine in tests: a panic is exactly the failure signal wanted.
    #![allow(clippy::expect_used, clippy::unwrap_used)]
    use super::*;

    /// Happy path: handshake, loop one payload through `send`, read it back.
    #[test]
    fn dummy_dtls_handshake_then_send_then_recv() {
        let mut layer = DummyDtls::new();
        assert!(!layer.is_handshake_complete());
        layer.handshake().unwrap();
        assert!(layer.is_handshake_complete());
        layer.send(&[1, 2, 3, 4]).unwrap();
        let received = layer.recv().unwrap();
        assert_eq!(received, alloc::vec![1, 2, 3, 4]);
    }

    /// `send` is rejected until the handshake has run.
    #[test]
    fn dummy_dtls_send_before_handshake_fails() {
        let mut layer = DummyDtls::new();
        assert_eq!(
            layer.send(&[1, 2, 3]),
            Err(DtlsError::SendFailed {
                reason: "handshake not complete"
            })
        );
    }

    /// `recv` is rejected until the handshake has run.
    #[test]
    fn dummy_dtls_recv_before_handshake_fails() {
        let mut layer = DummyDtls::new();
        assert_eq!(
            layer.recv(),
            Err(DtlsError::RecvFailed {
                reason: "handshake not complete"
            })
        );
    }

    /// After `close`, `send` reports `Closed`.
    #[test]
    fn dummy_dtls_close_returns_closed_on_subsequent_send() {
        let mut layer = DummyDtls::new();
        layer.handshake().unwrap();
        layer.close().unwrap();
        assert_eq!(layer.send(&[1]), Err(DtlsError::Closed));
    }

    /// Payloads queued before `close` remain readable; `Closed` follows once
    /// the inbox is empty.
    #[test]
    fn dummy_dtls_close_drains_inbox_then_returns_closed() {
        let mut layer = DummyDtls::new();
        layer.handshake().unwrap();
        layer.send(&[1]).unwrap();
        layer.close().unwrap();
        assert_eq!(layer.recv().unwrap(), alloc::vec![1]);
        assert_eq!(layer.recv(), Err(DtlsError::Closed));
    }

    /// `inject` queues a payload that `recv` then returns.
    #[test]
    fn dummy_dtls_inject_makes_recv_yield_payload() {
        let mut layer = DummyDtls::new();
        layer.handshake().unwrap();
        layer.inject(alloc::vec![9, 8, 7]);
        assert_eq!(layer.inbox_len(), 1);
        let received = layer.recv().unwrap();
        assert_eq!(received, alloc::vec![9, 8, 7]);
        assert_eq!(layer.inbox_len(), 0);
    }

    /// The handshake error's `Display` output includes its reason.
    #[test]
    fn dtls_error_display_formats_handshake() {
        let rendered = alloc::format!("{}", DtlsError::HandshakeFailed { reason: "bad cert" });
        assert!(rendered.contains("bad cert"));
    }

    /// The closed error's `Display` output mentions "closed".
    #[test]
    fn dtls_error_display_formats_closed() {
        let err = DtlsError::Closed;
        let rendered = alloc::format!("{err}");
        assert!(rendered.contains("closed"));
    }

    /// `Default` is implemented and usable.
    #[test]
    fn dummy_dtls_default_is_constructible() {
        let _ = <DummyDtls as Default>::default();
    }
}