use std::{
collections::{HashMap, HashSet},
net::SocketAddr,
sync::Arc,
time::{Duration, SystemTime, UNIX_EPOCH},
};
use anyhow::Result;
use aws_lc_rs::{
aead::{Aad, LessSafeKey, NONCE_LEN, Nonce},
hmac::{Key, sign},
rand,
};
use bincode_next::{config::standard, encode_to_vec};
use bon::Builder;
use getset::MutGetters;
use tokio::{
net::UdpSocket,
select,
sync::{mpsc::Receiver, oneshot},
time::{interval, sleep},
};
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use super::DiffMode;
use crate::EncryptedFrame;
/// Current wall-clock time in microseconds since the Unix epoch.
///
/// Yields `0` when the system clock reads earlier than the epoch, and also `0` if the
/// microsecond count does not fit in a `u64`.
fn now_micros() -> u64 {
    let micros = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_micros(),
        Err(_) => 0,
    };
    u64::try_from(micros).unwrap_or(0)
}
/// How many sequence numbers behind the newest data frame a buffered packet may lag before
/// it is evicted from the retransmit buffer.
pub(crate) const RETRANSMIT_WINDOW: u64 = 1 << 9;
/// Intended upper bound, in bytes, for a single UDP datagram payload.
/// NOTE(review): not enforced by `UdpSender` itself — confirm callers size frames to fit.
pub const MAX_UDP_PAYLOAD: usize = 1_200;
/// Encrypts outgoing frames and writes them to a UDP socket.
///
/// Built through the `bon`-generated builder and driven by [`UdpSender::frame_loop`].
/// NOTE(review): no `#[getset(...)]` attributes appear on the fields, so `MutGetters`
/// presumably generates no accessors here — confirm the derive is still needed.
#[derive(Builder, Debug, MutGetters)]
pub struct UdpSender {
    /// Sender identity; its 16 raw bytes prefix every plaintext before encryption.
    id: Uuid,
    /// AEAD key that seals each frame, with the big-endian sequence number as AAD.
    rnk: LessSafeKey,
    /// HMAC key used to sign `sequence number || ciphertext` for every packet.
    hmac: Key,
    /// Socket every packet is written to: `send_to` when a peer address is known,
    /// otherwise `send` (which relies on the socket being connected).
    socket: Arc<UdpSocket>,
    /// Control frames. `Keepalive` frames are re-stamped with the current time when sent,
    /// and control frames are never buffered for retransmission. If this channel closes,
    /// the send loop keeps running.
    control_rx: Receiver<EncryptedFrame>,
    /// Data frames. In `DiffMode::Reliable` their encrypted packets are buffered for
    /// retransmission; when this channel closes, the send loop ends.
    rx: Receiver<EncryptedFrame>,
    /// Batches of sequence numbers to resend; ignored unless `diff_mode` is
    /// `DiffMode::Reliable`.
    retransmit_rx: Receiver<Vec<u64>>,
    /// Next sequence number to assign; control and data frames share this counter.
    #[builder(default)]
    send_seq: u64,
    /// Encrypted packets of recent data frames keyed by sequence number; entries more than
    /// `RETRANSMIT_WINDOW` behind the newest data frame are evicted.
    #[builder(default)]
    retransmit_buffer: HashMap<u64, Vec<u8>>,
    /// Sequence numbers awaiting retransmission, flushed on the retransmit tick.
    #[builder(default)]
    pending_retransmit: HashSet<u64>,
    /// Optional one-shot yielding the initial peer address, awaited when the send loop
    /// starts. If its sender is dropped, no peer address is set.
    peer_discovered_rx: Option<oneshot::Receiver<SocketAddr>>,
    /// Optional channel of updated peer addresses; before each send all queued updates are
    /// drained and the most recent one wins.
    peer_addr_rx: Option<Receiver<SocketAddr>>,
    /// Optional pause taken after peer discovery and before the send loop starts.
    warmup_delay: Option<Duration>,
    /// Selects whether data frames are buffered and retransmit requests honored
    /// (`DiffMode::Reliable`); uses `DiffMode::default()` when not set.
    #[builder(default)]
    diff_mode: DiffMode,
}
impl UdpSender {
pub async fn frame_loop(&mut self, token: CancellationToken) -> Result<()> {
let mut current_peer: Option<SocketAddr> = None;
if let Some(rx) = self.peer_discovered_rx.take() {
current_peer = rx.await.ok();
}
if let Some(delay) = self.warmup_delay {
sleep(delay).await;
}
let mut retransmit_active = true;
let mut control_active = true;
let mut retransmit_tick = interval(Duration::from_millis(20));
loop {
select! {
biased;
() = token.cancelled() => break,
frame_opt = self.control_rx.recv(), if control_active => {
match frame_opt {
Some(frame) => {
let frame = match frame {
EncryptedFrame::Keepalive(_) => {
EncryptedFrame::Keepalive(now_micros())
}
other => other,
};
let seq = self.send_seq;
self.send_seq += 1;
let wire = self.encrypt(&frame, seq)?;
self.drain_roam_updates(&mut current_peer);
self.send_wire(&wire, current_peer).await?;
}
None => control_active = false,
}
},
seqs = self.retransmit_rx.recv(), if retransmit_active => {
match seqs {
Some(seqs) => {
if self.diff_mode == DiffMode::Reliable {
self.pending_retransmit.extend(seqs);
}
}
None => retransmit_active = false,
}
},
_ = retransmit_tick.tick(), if !self.pending_retransmit.is_empty() => {
self.drain_roam_updates(&mut current_peer);
let pending: Vec<u64> = self.pending_retransmit.drain().collect();
for seq in pending {
if let Some(wire) = self.retransmit_buffer.get(&seq) {
let wire = wire.clone();
self.send_wire(&wire, current_peer).await?;
}
}
},
frame_opt = self.rx.recv() => {
match frame_opt {
Some(frame) => {
let seq = self.send_seq;
self.send_seq += 1;
let wire = self.encrypt(&frame, seq)?;
if self.diff_mode == DiffMode::Reliable {
let _prev = self.retransmit_buffer.insert(seq, wire.clone());
let cutoff = seq.saturating_sub(RETRANSMIT_WINDOW);
self.retransmit_buffer.retain(|&s, _| s >= cutoff);
}
self.drain_roam_updates(&mut current_peer);
self.send_wire(&wire, current_peer).await?;
}
None => break,
}
},
}
}
Ok(())
}
fn drain_roam_updates(&mut self, peer: &mut Option<SocketAddr>) {
if let Some(ref mut rx) = self.peer_addr_rx {
while let Ok(addr) = rx.try_recv() {
*peer = Some(addr);
}
}
}
async fn send_wire(&self, wire: &[u8], peer: Option<SocketAddr>) -> Result<()> {
if let Some(addr) = peer {
let _n = self.socket.send_to(wire, addr).await?;
} else {
let _n = self.socket.send(wire).await?;
}
Ok(())
}
fn encrypt(&self, frame: &EncryptedFrame, seq: u64) -> Result<Vec<u8>> {
let data = encode_to_vec(frame, standard())?;
let aad = Aad::from(seq.to_be_bytes());
let mut encrypted_part = self.id.as_bytes().to_vec();
encrypted_part.extend_from_slice(&data);
let mut nonce_bytes = [0u8; NONCE_LEN];
rand::fill(&mut nonce_bytes)?;
let nonce = Nonce::try_assume_unique_for_key(&nonce_bytes)?;
self.rnk
.seal_in_place_append_tag(nonce, aad, &mut encrypted_part)?;
let seq_bytes = seq.to_be_bytes();
let mut to_sign = seq_bytes.to_vec();
to_sign.extend_from_slice(&encrypted_part);
let tag = sign(&self.hmac, &to_sign);
let tag_bytes = tag.as_ref();
let len = encrypted_part.len().to_be_bytes();
let mut packet = nonce_bytes.to_vec();
packet.extend_from_slice(&seq_bytes);
packet.extend_from_slice(tag_bytes);
packet.extend_from_slice(&len);
packet.extend_from_slice(&encrypted_part);
Ok(packet)
}
}
#[cfg(test)]
mod tests {
    use aws_lc_rs::{
        aead::{AES_256_GCM_SIV, UnboundKey},
        hmac::HMAC_SHA512,
    };
    use tokio::sync::mpsc::channel;

    use super::*;

    /// Builds a sender with all-zero AEAD/HMAC keys and no discovery, roaming or warm-up.
    fn make_sender(
        socket: Arc<UdpSocket>,
        control_rx: Receiver<EncryptedFrame>,
        rx: Receiver<EncryptedFrame>,
        retransmit_rx: Receiver<Vec<u64>>,
    ) -> UdpSender {
        UdpSender::builder()
            .id(Uuid::new_v4())
            .rnk(LessSafeKey::new(
                UnboundKey::new(&AES_256_GCM_SIV, &[0u8; 32]).unwrap(),
            ))
            .hmac(Key::new(HMAC_SHA512, &[0u8; 64]))
            .socket(socket)
            .control_rx(control_rx)
            .rx(rx)
            .retransmit_rx(retransmit_rx)
            .build()
    }

    /// Reverses `UdpSender::encrypt` for a packet sealed with the all-zero test keys and
    /// returns the bincode-encoded frame, with the 16-byte sender id stripped.
    ///
    /// Panics if the packet is malformed or fails AEAD authentication.
    fn open_packet(packet: &[u8]) -> Vec<u8> {
        // Derive the HMAC-SHA512 tag length instead of hard-coding it.
        let tag_len = sign(&Key::new(HMAC_SHA512, &[0u8; 64]), b"").as_ref().len();
        let seq_end = NONCE_LEN + 8;
        let len_start = seq_end + tag_len;
        let header_len = len_start + std::mem::size_of::<usize>();

        let nonce = Nonce::try_assume_unique_for_key(&packet[..NONCE_LEN]).unwrap();
        let seq_bytes: [u8; 8] = packet[NONCE_LEN..seq_end].try_into().unwrap();
        let len_field = usize::from_be_bytes(packet[len_start..header_len].try_into().unwrap());
        assert_eq!(len_field, packet.len() - header_len, "length field mismatch");

        let mut sealed = packet[header_len..].to_vec();
        let key = LessSafeKey::new(UnboundKey::new(&AES_256_GCM_SIV, &[0u8; 32]).unwrap());
        let plain = key
            .open_in_place(nonce, Aad::from(seq_bytes), &mut sealed)
            .unwrap();
        plain[16..].to_vec()
    }

    #[tokio::test]
    async fn keepalive_is_restamped_at_send_time() {
        let server = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        let server_addr = server.local_addr().unwrap();
        let (ctrl_tx, ctrl_rx) = channel::<EncryptedFrame>(4);
        // Keep the data channel open so the send loop keeps running.
        let (_frame_tx, frame_rx) = channel::<EncryptedFrame>(4);
        let (_retransmit_tx, retransmit_rx) = channel::<Vec<u64>>(4);
        let token = CancellationToken::new();
        // Only the control path re-stamps keepalives, so queue a stale one there.
        let stale_ts = now_micros().saturating_sub(5_000_000);
        ctrl_tx
            .send(EncryptedFrame::Keepalive(stale_ts))
            .await
            .unwrap();
        let send_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        send_socket.connect(server_addr).await.unwrap();
        server
            .connect(send_socket.local_addr().unwrap())
            .await
            .unwrap();
        let mut sender = make_sender(send_socket, ctrl_rx, frame_rx, retransmit_rx);
        let token2 = token.clone();
        let handle = tokio::spawn(async move {
            drop(sender.frame_loop(token2).await);
        });
        let mut buf = vec![0u8; 65535];
        let n = tokio::time::timeout(Duration::from_millis(500), server.recv(&mut buf))
            .await
            .expect("no packet received within 500 ms")
            .unwrap();
        token.cancel();
        drop(handle.await);

        let payload = open_packet(&buf[..n]);
        let stale_encoding =
            encode_to_vec(&EncryptedFrame::Keepalive(stale_ts), standard()).unwrap();
        assert_ne!(
            payload, stale_encoding,
            "keepalive was sent with its original (stale) timestamp"
        );
    }

    #[tokio::test]
    async fn control_channel_close_does_not_break_loop() {
        let server = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        let server_addr = server.local_addr().unwrap();
        let (ctrl_tx, ctrl_rx) = channel::<EncryptedFrame>(4);
        let (frame_tx, frame_rx) = channel::<EncryptedFrame>(4);
        let (_retransmit_tx, retransmit_rx) = channel::<Vec<u64>>(4);
        let token = CancellationToken::new();
        let send_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        send_socket.connect(server_addr).await.unwrap();
        server
            .connect(send_socket.local_addr().unwrap())
            .await
            .unwrap();
        let ts = now_micros();
        ctrl_tx.send(EncryptedFrame::Keepalive(ts)).await.unwrap();
        drop(ctrl_tx);
        frame_tx.send(EncryptedFrame::Shutdown).await.unwrap();
        drop(frame_tx);
        let mut sender = make_sender(send_socket, ctrl_rx, frame_rx, retransmit_rx);
        let token2 = token.clone();
        let handle = tokio::spawn(async move {
            drop(sender.frame_loop(token2).await);
        });
        let mut count = 0usize;
        let mut buf = vec![0u8; 65535];
        while let Ok(Ok(_)) =
            tokio::time::timeout(Duration::from_millis(200), server.recv(&mut buf)).await
        {
            count += 1;
        }
        token.cancel();
        drop(handle.await);
        assert_eq!(count, 2, "expected exactly 2 wire packets");
    }

    #[tokio::test]
    async fn sender_adopts_roamed_peer_addr() {
        let old_peer = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        let new_peer = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        let old_addr = old_peer.local_addr().unwrap();
        let new_addr = new_peer.local_addr().unwrap();
        let (_ctrl_tx, ctrl_rx) = channel::<EncryptedFrame>(4);
        let (frame_tx, frame_rx) = channel::<EncryptedFrame>(4);
        let (_retransmit_tx, retransmit_rx) = channel::<Vec<u64>>(4);
        let (peer_disc_tx, peer_disc_rx) = oneshot::channel::<SocketAddr>();
        let (peer_addr_tx, peer_addr_rx) = channel::<SocketAddr>(4);
        let token = CancellationToken::new();
        let send_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
        let mut sender = UdpSender::builder()
            .id(Uuid::new_v4())
            .rnk(LessSafeKey::new(
                UnboundKey::new(&AES_256_GCM_SIV, &[0u8; 32]).unwrap(),
            ))
            .hmac(Key::new(HMAC_SHA512, &[0u8; 64]))
            .socket(send_socket)
            .control_rx(ctrl_rx)
            .rx(frame_rx)
            .retransmit_rx(retransmit_rx)
            .peer_discovered_rx(peer_disc_rx)
            .peer_addr_rx(peer_addr_rx)
            .build();
        peer_disc_tx.send(old_addr).unwrap();
        let token2 = token.clone();
        let handle = tokio::spawn(async move {
            drop(sender.frame_loop(token2).await);
        });
        frame_tx.send(EncryptedFrame::Keepalive(0)).await.unwrap();
        let mut buf = vec![0u8; 65535];
        let got_old = tokio::time::timeout(Duration::from_millis(500), old_peer.recv(&mut buf))
            .await
            .is_ok();
        peer_addr_tx.send(new_addr).await.unwrap();
        sleep(Duration::from_millis(10)).await;
        frame_tx.send(EncryptedFrame::Keepalive(0)).await.unwrap();
        let got_new = tokio::time::timeout(Duration::from_millis(500), new_peer.recv(&mut buf))
            .await
            .is_ok();
        token.cancel();
        drop(handle.await);
        assert!(got_old, "first frame did not reach original peer");
        assert!(got_new, "second frame did not reach roamed peer");
    }
}