use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::UdpSocket;
use tokio::sync::Mutex;
use zeroize::Zeroize;
use ccm::{
Ccm,
aead::{Aead, KeyInit, generic_array::GenericArray},
consts::{U13, U16},
};
use aes::Aes128;
use super::message::MatterMessage;
use crate::homeauto::matter::error::{MatterError, MatterResult};
type Aes128Ccm = Ccm<Aes128, U16, U13>;
#[derive(Clone, Zeroize)]
pub struct SessionKeys {
pub encrypt_key: [u8; 16],
pub decrypt_key: [u8; 16],
}
pub type SessionMap = Arc<Mutex<HashMap<u16, SessionKeys>>>;
/// Builds the 13-byte AES-CCM nonce for a message.
///
/// Layout:
/// - bytes `0..2`: `session_id`, little-endian
/// - bytes `2..6`: `message_counter`, little-endian
/// - bytes `6..13`: zero
///
/// NOTE(review): the Matter spec nonce is security flags || message counter ||
/// source node id. This layout differs, so confirm every peer uses the same scheme.
fn build_nonce(session_id: u16, message_counter: u32) -> [u8; 13] {
    let (sid, ctr) = (session_id.to_le_bytes(), message_counter.to_le_bytes());
    let mut out = [0u8; 13];
    let (head, tail) = out.split_at_mut(sid.len());
    head.copy_from_slice(&sid);
    tail[..ctr.len()].copy_from_slice(&ctr);
    out
}
/// Returns the encoded bytes of `msg` that precede its payload. The caller uses
/// them as AEAD associated data.
///
/// Assumes `MatterMessage::encode` places the payload bytes last — TODO confirm.
/// If the encoding is not strictly longer than the payload, the whole encoding
/// is returned unchanged.
fn header_bytes(msg: &MatterMessage) -> Vec<u8> {
    let mut encoded = msg.encode();
    let header_len = encoded
        .len()
        .checked_sub(msg.payload.len())
        .filter(|&len| len > 0);
    if let Some(len) = header_len {
        // Drop the trailing payload bytes in place instead of copying the prefix.
        encoded.truncate(len);
    }
    encoded
}
/// UDP transport for [`MatterMessage`]s.
///
/// Messages with session id 0 travel as-is. For any other session id the
/// payload is sealed or opened with AES-128-CCM, using the keys registered in
/// `sessions`.
pub struct UdpTransport {
    /// Bound socket, wrapped in `Arc` so it can be shared cheaply.
    socket: Arc<UdpSocket>,
    /// Session id -> key pair table consulted by `send` and `recv`.
    pub sessions: SessionMap,
}

impl UdpTransport {
    /// Size of the receive buffer used by `recv`. Per the Tokio `recv_from`
    /// docs, bytes of a longer datagram may be discarded.
    const MAX_DATAGRAM_LEN: usize = 1280;

    /// Binds a transport on `0.0.0.0:{port}` (all IPv4 interfaces).
    ///
    /// NOTE(review): this is an IPv4-only wildcard. If peers are reached over
    /// IPv6, use `bind_addr` with an IPv6 address instead — confirm intended.
    pub async fn new(port: u16) -> MatterResult<Self> {
        Self::bind_addr(&format!("0.0.0.0:{port}")).await
    }

    /// Binds a transport on `addr` (any string `UdpSocket::bind` accepts) with
    /// an empty session table.
    ///
    /// # Errors
    /// Propagates the socket bind error.
    pub async fn bind_addr(addr: &str) -> MatterResult<Self> {
        let socket = UdpSocket::bind(addr).await?;
        Ok(Self {
            socket: Arc::new(socket),
            sessions: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    /// Encodes `msg` and sends it to `peer`.
    ///
    /// Session id 0 is sent unencrypted. For any other session id the payload
    /// is encrypted with that session's `encrypt_key`. The encoded header is
    /// used as associated data, and the nonce is
    /// `build_nonce(session_id, message_counter)`.
    ///
    /// # Errors
    /// Returns `MatterError::Transport` if:
    /// - the session id is non-zero but no keys are registered for it (the
    ///   message is never sent in plaintext in that case),
    /// - encryption fails, or
    /// - the socket send fails.
    pub async fn send(&self, msg: &MatterMessage, peer: SocketAddr) -> MatterResult<()> {
        let session_id = msg.header.session_id;
        let wire = if session_id == 0 {
            msg.encode()
        } else {
            // Build the cipher while the guard is held, borrowing the key in
            // place, so no extra unzeroized copy of the raw key lands in a local.
            let cipher = {
                let sessions = self.sessions.lock().await;
                let keys = sessions.get(&session_id).ok_or_else(|| {
                    MatterError::Transport(format!(
                        "no keys registered for session {session_id}; refusing to send plaintext"
                    ))
                })?;
                Aes128Ccm::new(GenericArray::from_slice(&keys.encrypt_key))
            };
            let aad = header_bytes(msg);
            let nonce = build_nonce(session_id, msg.header.message_counter);
            let ciphertext = cipher
                .encrypt(
                    GenericArray::from_slice(&nonce),
                    ccm::aead::Payload {
                        msg: &msg.payload,
                        aad: &aad,
                    },
                )
                .map_err(|_| MatterError::Transport("AES-CCM encrypt failed".into()))?;
            let mut sealed = msg.clone();
            sealed.payload = ciphertext;
            sealed.encode()
        };
        self.socket
            .send_to(&wire, peer)
            .await
            .map_err(|e| MatterError::Transport(format!("send_to failed: {e}")))?;
        Ok(())
    }

    /// Receives one datagram, decodes it, and returns it with the sender address.
    ///
    /// Session id 0 messages are returned as decoded. For any other session id
    /// the payload is decrypted and authenticated with that session's
    /// `decrypt_key` before being returned. No message-counter or replay
    /// checking is done here.
    ///
    /// # Errors
    /// Returns an error if:
    /// - the socket read fails,
    /// - `MatterMessage::decode` fails,
    /// - the session id is non-zero but has no registered keys (so an
    ///   unauthenticated payload is never handed back as plaintext), or
    /// - decryption or verification fails.
    pub async fn recv(&self) -> MatterResult<(MatterMessage, SocketAddr)> {
        let mut buf = vec![0u8; Self::MAX_DATAGRAM_LEN];
        let (n, peer) = self
            .socket
            .recv_from(&mut buf)
            .await
            .map_err(|e| MatterError::Transport(format!("recv_from failed: {e}")))?;
        buf.truncate(n);
        let mut msg = MatterMessage::decode(&buf)?;
        let session_id = msg.header.session_id;
        if session_id == 0 {
            return Ok((msg, peer));
        }
        // Same in-place key borrow as in `send`.
        let cipher = {
            let sessions = self.sessions.lock().await;
            let keys = sessions.get(&session_id).ok_or_else(|| {
                MatterError::Transport(format!("no keys registered for session {session_id}"))
            })?;
            Aes128Ccm::new(GenericArray::from_slice(&keys.decrypt_key))
        };
        let aad = header_bytes(&msg);
        let nonce = build_nonce(session_id, msg.header.message_counter);
        let plaintext = cipher
            .decrypt(
                GenericArray::from_slice(&nonce),
                ccm::aead::Payload {
                    msg: &msg.payload,
                    aad: &aad,
                },
            )
            .map_err(|_| MatterError::Transport("AES-CCM decrypt/verify failed".into()))?;
        msg.payload = plaintext;
        Ok((msg, peer))
    }

    /// Registers (or replaces) the keys for session `id`. Replaced keys are
    /// zeroized.
    ///
    /// If the session table is unlocked, the insert completes before this
    /// returns. Otherwise it is deferred to a spawned Tokio task, and a
    /// `send`/`recv` issued right afterwards may not see it yet.
    ///
    /// # Panics
    /// The deferred path calls `tokio::spawn`, which panics outside a Tokio runtime.
    pub fn add_session(&self, id: u16, keys: SessionKeys) {
        self.update_sessions(move |sessions| {
            if let Some(mut old) = sessions.insert(id, keys) {
                old.zeroize();
            }
        });
    }

    /// Removes the keys for session `id`, if present, and zeroizes them.
    ///
    /// Applied synchronously when the table is unlocked. Otherwise it is
    /// deferred to a spawned Tokio task, like `add_session`.
    ///
    /// # Panics
    /// The deferred path calls `tokio::spawn`, which panics outside a Tokio runtime.
    pub fn remove_session(&self, id: u16) {
        self.update_sessions(move |sessions| {
            if let Some(mut keys) = sessions.remove(&id) {
                keys.zeroize();
            }
        });
    }

    /// Applies `update` to the session table.
    ///
    /// It runs immediately if the lock is free. Otherwise it runs on a spawned
    /// task once the lock is acquired.
    fn update_sessions<F>(&self, update: F)
    where
        F: FnOnce(&mut HashMap<u16, SessionKeys>) + Send + 'static,
    {
        match self.sessions.try_lock() {
            Ok(mut guard) => update(&mut *guard),
            Err(_) => {
                let sessions = Arc::clone(&self.sessions);
                tokio::spawn(async move {
                    let mut guard = sessions.lock().await;
                    update(&mut *guard);
                });
            }
        }
    }
}