#![allow(dead_code)]
#![allow(clippy::too_many_arguments)]
use crate::error::{NetError, NetResult};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use sha2::{Digest, Sha256};
use std::sync::Arc;
use tokio::net::UdpSocket;
/// Which side of the DTLS handshake the local endpoint plays.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DtlsRole {
    /// Client side; corresponds to the SDP setup value `active`.
    Client,
    /// Server side; corresponds to the SDP setup value `passive`.
    Server,
}
impl DtlsRole {
    /// Maps an SDP `a=setup` value to a role.
    ///
    /// `"active"` maps to [`DtlsRole::Client`]. Both `"passive"` and `"actpass"` map to
    /// [`DtlsRole::Server`]. Matching is exact and case-sensitive, and any other input
    /// yields `None`.
    #[must_use]
    pub fn from_setup(setup: &str) -> Option<Self> {
        let role = match setup {
            "active" => Self::Client,
            "passive" | "actpass" => Self::Server,
            _ => return None,
        };
        Some(role)
    }
    /// Returns the SDP `a=setup` value that advertises this role.
    ///
    /// The result is always a concrete role (`"active"` or `"passive"`), never `"actpass"`.
    #[must_use]
    pub const fn to_setup(&self) -> &'static str {
        match *self {
            Self::Server => "passive",
            Self::Client => "active",
        }
    }
}
/// A certificate fingerprint in the form used by an SDP `a=fingerprint` attribute.
#[derive(Debug, Clone)]
pub struct DtlsFingerprint {
    /// Hash function label, e.g. `"sha-256"`.
    pub algorithm: String,
    /// Digest as uppercase two-digit hex bytes separated by `:`.
    pub value: String,
}
impl DtlsFingerprint {
    /// Computes the SHA-256 fingerprint of a DER-encoded certificate.
    ///
    /// Each digest byte is rendered as two uppercase hex digits, and the bytes are
    /// joined with `:`. The algorithm label is always `"sha-256"`.
    #[must_use]
    pub fn from_certificate(cert: &CertificateDer) -> Self {
        let digest = Sha256::digest(cert.as_ref());
        let mut value = String::with_capacity(digest.len() * 3);
        for byte in digest.iter() {
            // A separator goes before every byte except the first.
            if !value.is_empty() {
                value.push(':');
            }
            value.push_str(&format!("{byte:02X}"));
        }
        Self {
            algorithm: String::from("sha-256"),
            value,
        }
    }
    /// Formats the fingerprint as `"<algorithm> <value>"` for use in SDP.
    #[must_use]
    pub fn to_sdp(&self) -> String {
        let mut line = self.algorithm.clone();
        line.push(' ');
        line.push_str(&self.value);
        line
    }
}
/// Local credentials and role used to build a DTLS endpoint.
pub struct DtlsConfig {
    /// Certificate chain; the first entry is the one that gets fingerprinted.
    pub certificates: Vec<CertificateDer<'static>>,
    /// Private key presented alongside `certificates`.
    pub private_key: PrivateKeyDer<'static>,
    /// Handshake role for this endpoint.
    pub role: DtlsRole,
}
impl DtlsConfig {
    /// Builds a config holding one freshly generated self-signed certificate and its key.
    ///
    /// # Errors
    /// Propagates any error from `generate_self_signed_cert`.
    pub fn new_self_signed(role: DtlsRole) -> NetResult<Self> {
        let (certificate, private_key) = generate_self_signed_cert()?;
        let certificates = vec![certificate];
        Ok(Self {
            certificates,
            private_key,
            role,
        })
    }
    /// Returns the SHA-256 fingerprint of the first certificate in `certificates`.
    ///
    /// # Panics
    /// Panics if `certificates` is empty (the field is public, so callers can empty it).
    #[must_use]
    pub fn fingerprint(&self) -> DtlsFingerprint {
        let leaf = &self.certificates[0];
        DtlsFingerprint::from_certificate(leaf)
    }
}
/// A DTLS endpoint: a configuration plus a shared UDP socket.
pub struct DtlsEndpoint {
    config: DtlsConfig,
    socket: Arc<UdpSocket>,
}
impl DtlsEndpoint {
    /// Creates an endpoint from its configuration and socket.
    #[must_use]
    pub fn new(config: DtlsConfig, socket: Arc<UdpSocket>) -> Self {
        Self { config, socket }
    }
    /// Produces a [`DtlsConnection`] that shares this endpoint's socket.
    ///
    /// # Security
    /// This is a placeholder. No DTLS messages are exchanged, the peer is never
    /// authenticated, and the configured certificate and key are not used. The
    /// returned SRTP key (16 bytes) and salt (14 bytes) are all zeros. Anything
    /// protected with them is NOT confidential.
    ///
    /// # Errors
    /// The current implementation never returns an error.
    pub async fn handshake(&self) -> NetResult<DtlsConnection> {
        // NOTE(review): sizes look like AES-128-CM SRTP master key/salt lengths — confirm.
        const SRTP_KEY_LEN: usize = 16;
        const SRTP_SALT_LEN: usize = 14;
        let connection = DtlsConnection {
            socket: Arc::clone(&self.socket),
            srtp_key: vec![0u8; SRTP_KEY_LEN],
            srtp_salt: vec![0u8; SRTP_SALT_LEN],
        };
        Ok(connection)
    }
    /// Returns the fingerprint of the configured certificate.
    ///
    /// # Panics
    /// Panics if the configuration holds no certificates.
    #[must_use]
    pub fn fingerprint(&self) -> DtlsFingerprint {
        self.config.fingerprint()
    }
}
/// A connection produced by [`DtlsEndpoint::handshake`], carrying SRTP keying material.
pub struct DtlsConnection {
    socket: Arc<UdpSocket>,
    srtp_key: Vec<u8>,
    srtp_salt: Vec<u8>,
}
impl DtlsConnection {
    /// Sends `data` as one datagram on the underlying socket.
    ///
    /// No record protection is applied, so `data` is written as-is. Tokio's
    /// `UdpSocket::send` targets the connected peer, so the socket must already be
    /// `connect`ed.
    ///
    /// # Errors
    /// Returns a connection error if the socket send fails.
    pub async fn send(&self, data: &[u8]) -> NetResult<()> {
        match self.socket.send(data).await {
            Ok(_) => Ok(()),
            Err(e) => Err(NetError::connection(format!("Failed to send: {e}"))),
        }
    }
    /// Receives one datagram into `buf` and returns the number of bytes read.
    ///
    /// # Errors
    /// Returns a connection error if the socket receive fails.
    pub async fn recv(&self, buf: &mut [u8]) -> NetResult<usize> {
        let outcome = self.socket.recv(buf).await;
        outcome.map_err(|e| NetError::connection(format!("Failed to recv: {e}")))
    }
    /// Returns `(srtp_key, srtp_salt)` as borrowed slices.
    #[must_use]
    pub fn srtp_keying_material(&self) -> (&[u8], &[u8]) {
        (self.srtp_key.as_slice(), self.srtp_salt.as_slice())
    }
    /// Returns the shared UDP socket.
    #[must_use]
    pub fn socket(&self) -> &Arc<UdpSocket> {
        &self.socket
    }
}
fn generate_self_signed_cert() -> NetResult<(CertificateDer<'static>, PrivateKeyDer<'static>)> {
use ed25519_dalek::pkcs8::EncodePrivateKey;
use ed25519_dalek::SigningKey;
use rand::Rng;
let mut secret = [0u8; 32];
rand::rng().fill_bytes(&mut secret);
let signing_key = SigningKey::from_bytes(&secret);
let pkcs8_doc = signing_key
.to_pkcs8_der()
.map_err(|e| NetError::protocol(format!("Failed to encode key as PKCS8: {e}")))?;
let cert_der = create_dummy_cert();
let key_der = PrivateKeyDer::Pkcs8(pkcs8_doc.as_bytes().to_vec().into());
Ok((cert_der, key_der))
}
fn create_dummy_cert() -> CertificateDer<'static> {
let cert_bytes = vec![
0x30, 0x82, 0x01,
0x00, ];
CertificateDer::from(cert_bytes)
}
/// Cipher suites known to this module, each with a fixed symmetric key length.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DtlsCipherSuite {
    /// ECDH/ECDSA with AES-128-GCM and SHA-256.
    TlsEcdhEcdsaWithAes128GcmSha256,
    /// ECDH/ECDSA with AES-256-GCM and SHA-384.
    TlsEcdhEcdsaWithAes256GcmSha384,
    /// ECDHE/RSA with AES-128-GCM and SHA-256.
    TlsEcdheRsaWithAes128GcmSha256,
}
impl DtlsCipherSuite {
    /// Returns the symmetric key length in bytes.
    ///
    /// This is 16 for the two AES-128 suites and 32 for the AES-256 suite.
    #[must_use]
    pub const fn key_length_bytes(self) -> usize {
        match self {
            Self::TlsEcdhEcdsaWithAes256GcmSha384 => 32,
            Self::TlsEcdhEcdsaWithAes128GcmSha256 | Self::TlsEcdheRsaWithAes128GcmSha256 => 16,
        }
    }
}
/// Lifecycle states of a [`DtlsSession`] handshake.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DtlsHandshakeState {
    /// Freshly created; no handshake attempted yet.
    New,
    /// Handshake started but not finished.
    Connecting,
    /// Handshake finished.
    Connected,
    /// Session closed. Not currently set by [`DtlsSession`].
    Closed,
    /// Handshake or session failed. Not currently set by [`DtlsSession`].
    Failed,
}
/// Hash algorithms that can label a certificate fingerprint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FingerprintAlgorithm {
    /// SHA-256.
    Sha256,
    /// SHA-384.
    Sha384,
    /// SHA-512.
    Sha512,
}
impl FingerprintAlgorithm {
    /// Returns the lowercase, hyphenated label used in SDP (e.g. `"sha-256"`).
    #[must_use]
    pub const fn name(self) -> &'static str {
        match self {
            Self::Sha512 => "sha-512",
            Self::Sha384 => "sha-384",
            Self::Sha256 => "sha-256",
        }
    }
}
/// A fingerprint whose algorithm is a typed [`FingerprintAlgorithm`] instead of a string.
#[derive(Debug, Clone)]
pub struct DtlsFingerprintTyped {
    /// Hash algorithm that produced `value`.
    pub algorithm: FingerprintAlgorithm,
    /// Fingerprint text as stored by the caller.
    pub value: String,
}
impl DtlsFingerprintTyped {
    /// Formats the fingerprint as `"<algorithm name> <value>"` for use in SDP.
    #[must_use]
    pub fn to_sdp(&self) -> String {
        [self.algorithm.name(), self.value.as_str()].join(" ")
    }
}
/// In-memory model of a DTLS session's role, state, fingerprints and key.
///
/// # Security
/// This type performs no real cryptography:
/// - The local fingerprint is derived from the wall-clock time, not from a certificate.
/// - The key installed by [`DtlsSession::connect`] is the fixed sequence `0, 1, 2, …`.
/// - [`DtlsSession::protect`] is a repeating-key XOR.
///
/// It is only suitable for tests or simulation.
#[derive(Debug)]
pub struct DtlsSession {
    /// Role this side plays in the handshake.
    pub role: DtlsRole,
    /// Current handshake state.
    pub state: DtlsHandshakeState,
    /// Fingerprint advertised for this side.
    pub local_fingerprint: DtlsFingerprintTyped,
    /// Peer fingerprint, if known. Starts as `None` and is never set by this type.
    pub remote_fingerprint: Option<DtlsFingerprintTyped>,
    /// Selected suite. `None` until [`DtlsSession::connect`] runs.
    pub cipher_suite: Option<DtlsCipherSuite>,
    session_key: Vec<u8>,
}
impl DtlsSession {
    /// Creates a session in [`DtlsHandshakeState::New`].
    ///
    /// The local fingerprint is labelled SHA-256 and holds 32 colon-separated uppercase
    /// hex bytes. Byte `i` is `(millis >> (i % 8)) as u8 + i + offset`, using wrapping
    /// arithmetic, where:
    /// - `millis` is the current Unix time in milliseconds, or `0xDEAD_BEEF` if the
    ///   clock is before the epoch;
    /// - `offset` is `0xAA` for a client and `0x55` for a server.
    ///
    /// The session key starts as 16 zero bytes.
    #[must_use]
    pub fn new(role: DtlsRole) -> Self {
        use std::time::{SystemTime, UNIX_EPOCH};
        let millis: u64 = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_millis() as u64,
            Err(_) => 0xDEAD_BEEF,
        };
        let role_offset: u8 = match role {
            DtlsRole::Client => 0xAA,
            DtlsRole::Server => 0x55,
        };
        let value = (0..32u8)
            .map(|idx| {
                let shifted = (millis >> (idx % 8)) as u8;
                let byte = shifted.wrapping_add(idx).wrapping_add(role_offset);
                format!("{byte:02X}")
            })
            .collect::<Vec<_>>()
            .join(":");
        let local_fingerprint = DtlsFingerprintTyped {
            algorithm: FingerprintAlgorithm::Sha256,
            value,
        };
        Self {
            role,
            state: DtlsHandshakeState::New,
            local_fingerprint,
            remote_fingerprint: None,
            cipher_suite: None,
            session_key: vec![0u8; 16],
        }
    }
    /// Simulates a handshake and always returns `true`.
    ///
    /// It selects `TlsEcdhEcdsaWithAes128GcmSha256`, installs the key
    /// `0, 1, …, key_len - 1` (where `key_len` is that suite's key length), and ends
    /// in [`DtlsHandshakeState::Connected`].
    pub fn connect(&mut self) -> bool {
        self.state = DtlsHandshakeState::Connecting;
        let suite = DtlsCipherSuite::TlsEcdhEcdsaWithAes128GcmSha256;
        self.cipher_suite = Some(suite);
        let key_len = suite.key_length_bytes();
        self.session_key = (0..key_len as u8).collect();
        self.state = DtlsHandshakeState::Connected;
        true
    }
    /// XORs each byte of `data` with the session key, repeating the key as needed.
    ///
    /// If the key is empty, `data` is returned unchanged (as a new `Vec`).
    #[must_use]
    pub fn protect(&self, data: &[u8]) -> Vec<u8> {
        let key = &self.session_key;
        if key.is_empty() {
            return data.to_vec();
        }
        data.iter()
            .zip(key.iter().cycle())
            .map(|(&byte, &k)| byte ^ k)
            .collect()
    }
    /// Reverses [`DtlsSession::protect`]. XOR with the same key is its own inverse.
    #[must_use]
    pub fn unprotect(&self, data: &[u8]) -> Vec<u8> {
        self.protect(data)
    }
}
/// Unit tests for the DTLS helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_dtls_role() {
        assert_eq!(DtlsRole::from_setup("active"), Some(DtlsRole::Client));
        assert_eq!(DtlsRole::from_setup("passive"), Some(DtlsRole::Server));
        // An "actpass" offer leaves the choice open; this module resolves it to Server.
        assert_eq!(DtlsRole::from_setup("actpass"), Some(DtlsRole::Server));
        assert_eq!(DtlsRole::from_setup("bogus"), None);
        assert_eq!(DtlsRole::Client.to_setup(), "active");
        assert_eq!(DtlsRole::Server.to_setup(), "passive");
    }
    #[test]
    fn test_fingerprint_from_cert() {
        let (cert, _) = generate_self_signed_cert().expect("should succeed in test");
        let fp = DtlsFingerprint::from_certificate(&cert);
        assert_eq!(fp.algorithm, "sha-256");
        assert!(fp.value.contains(':'));
    }
    #[test]
    fn test_config_fingerprint_sdp() {
        let config = DtlsConfig::new_self_signed(DtlsRole::Server).expect("should succeed in test");
        let sdp = config.fingerprint().to_sdp();
        assert!(sdp.starts_with("sha-256 "));
        // SHA-256 gives 32 bytes: 32 two-digit groups joined by 31 colons.
        assert_eq!(sdp.len(), "sha-256 ".len() + 32 * 3 - 1);
    }
    #[test]
    fn test_fingerprint_sdp() {
        let fp = DtlsFingerprint {
            algorithm: "sha-256".to_string(),
            value: "AA:BB:CC:DD".to_string(),
        };
        assert_eq!(fp.to_sdp(), "sha-256 AA:BB:CC:DD");
    }
    #[test]
    fn test_cipher_suite_key_length() {
        assert_eq!(
            DtlsCipherSuite::TlsEcdhEcdsaWithAes128GcmSha256.key_length_bytes(),
            16
        );
        assert_eq!(
            DtlsCipherSuite::TlsEcdhEcdsaWithAes256GcmSha384.key_length_bytes(),
            32
        );
        assert_eq!(
            DtlsCipherSuite::TlsEcdheRsaWithAes128GcmSha256.key_length_bytes(),
            16
        );
    }
    #[test]
    fn test_fingerprint_algorithm_name() {
        assert_eq!(FingerprintAlgorithm::Sha256.name(), "sha-256");
        assert_eq!(FingerprintAlgorithm::Sha384.name(), "sha-384");
        assert_eq!(FingerprintAlgorithm::Sha512.name(), "sha-512");
    }
    #[test]
    fn test_dtls_fingerprint_typed_sdp() {
        let fp = DtlsFingerprintTyped {
            algorithm: FingerprintAlgorithm::Sha256,
            value: "AA:BB:CC".to_string(),
        };
        assert_eq!(fp.to_sdp(), "sha-256 AA:BB:CC");
    }
    #[test]
    fn test_dtls_session_new_client() {
        let session = DtlsSession::new(DtlsRole::Client);
        assert_eq!(session.role, DtlsRole::Client);
        assert_eq!(session.state, DtlsHandshakeState::New);
        assert!(session.cipher_suite.is_none());
        assert!(session.remote_fingerprint.is_none());
        // 32 hex bytes joined by colons.
        assert_eq!(session.local_fingerprint.value.len(), 32 * 3 - 1);
    }
    #[test]
    fn test_dtls_session_new_server() {
        let session = DtlsSession::new(DtlsRole::Server);
        assert_eq!(session.role, DtlsRole::Server);
        assert_eq!(session.state, DtlsHandshakeState::New);
    }
    #[test]
    fn test_dtls_session_connect() {
        let mut session = DtlsSession::new(DtlsRole::Client);
        let ok = session.connect();
        assert!(ok);
        assert_eq!(session.state, DtlsHandshakeState::Connected);
        assert!(session.cipher_suite.is_some());
    }
    #[test]
    fn test_dtls_session_protect_unprotect() {
        let mut session = DtlsSession::new(DtlsRole::Client);
        session.connect();
        // 17 bytes: longer than the 16-byte key, so the key wraps around.
        let original = b"Hello DTLS world!";
        let protected = session.protect(original);
        assert_ne!(protected, original.to_vec());
        let unprotected = session.unprotect(&protected);
        assert_eq!(unprotected, original.to_vec());
    }
    #[test]
    fn test_dtls_session_protect_empty() {
        let mut session = DtlsSession::new(DtlsRole::Server);
        session.connect();
        let protected = session.protect(b"");
        assert!(protected.is_empty());
    }
    #[test]
    fn test_dtls_handshake_state_variants() {
        let states = [
            DtlsHandshakeState::New,
            DtlsHandshakeState::Connecting,
            DtlsHandshakeState::Connected,
            DtlsHandshakeState::Closed,
            DtlsHandshakeState::Failed,
        ];
        // Each variant equals itself and differs from every other variant.
        for (i, a) in states.iter().enumerate() {
            for (j, b) in states.iter().enumerate() {
                assert_eq!(i == j, a == b, "{a:?} vs {b:?}");
            }
        }
    }
}