use std::{
collections::{HashMap, HashSet},
net::SocketAddr,
sync::Arc,
time::{Duration, SystemTime, UNIX_EPOCH},
};
use anyhow::Result;
use aws_lc_rs::{
aead::{Aad, LessSafeKey, NONCE_LEN, Nonce},
hmac::{Key, sign},
rand,
};
use bincode_next::{config::standard, encode_to_vec};
use bon::Builder;
use getset::MutGetters;
use tokio::{
net::UdpSocket,
select,
sync::{mpsc::Receiver, oneshot},
time::{Instant as TokioInstant, sleep, sleep_until},
};
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use super::DiffMode;
use crate::EncryptedFrame;
/// Current wall-clock time in microseconds since the Unix epoch.
///
/// Yields `0` when the system clock reads earlier than the epoch, and also
/// when the microsecond count does not fit in a `u64`.
fn now_micros() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| {
            u64::try_from(elapsed.as_micros()).unwrap_or_default()
        })
}
/// Data packets whose sequence number is within this distance of the newest
/// buffered one are kept for retransmission.
pub(crate) const RETRANSMIT_WINDOW: u64 = 512;
/// UDP payload size limit in bytes (not referenced by `UdpSender` itself).
pub const MAX_UDP_PAYLOAD: usize = 1200;
/// Encrypting UDP sender: seals frames taken from its channels and writes them
/// to the current peer, keeping recent data packets for retransmission when
/// running in `DiffMode::Reliable`.
///
/// Constructed through the generated builder (`UdpSender::builder()`); fields
/// marked `#[builder(default)]` and the `Option` fields may be left unset.
// NOTE(review): no `#[getset(...)]` attributes are visible on the struct or its
// fields, so the `MutGetters` derive appears to generate no accessors — confirm
// it is still needed.
#[derive(Builder, Debug, MutGetters)]
pub struct UdpSender {
    /// Sender identity; its bytes prefix every plaintext before sealing.
    id: Uuid,
    /// AEAD key used to seal each frame (random nonce, sequence number as AAD).
    rnk: LessSafeKey,
    /// HMAC key that authenticates `seq || ciphertext` on every packet.
    hmac: Key,
    /// Shared socket; `send_to` is used when a peer address is known,
    /// otherwise `send` (the socket's connected destination).
    socket: Arc<UdpSocket>,
    /// Control frames; `Keepalive` timestamps are refreshed at send time and
    /// these frames are never buffered for retransmission.
    control_rx: Receiver<EncryptedFrame>,
    /// Data frames; closing this channel ends `frame_loop`.
    rx: Receiver<EncryptedFrame>,
    /// Batches of sequence numbers requested for retransmission (acted on only
    /// in `DiffMode::Reliable`).
    retransmit_rx: Receiver<Vec<u64>>,
    /// Sequence number assigned to the next outgoing frame, control or data.
    #[builder(default)]
    send_seq: u64,
    /// Encrypted data packets keyed by sequence number, trimmed to
    /// `RETRANSMIT_WINDOW` below the newest entry (Reliable mode only).
    #[builder(default)]
    retransmit_buffer: HashMap<u64, Vec<u8>>,
    /// Requested sequence numbers waiting for the retransmission timer.
    #[builder(default)]
    pending_retransmit: HashSet<u64>,
    /// Optional one-shot delivering the initial peer address; taken and
    /// awaited once at the start of `frame_loop`.
    peer_discovered_rx: Option<oneshot::Receiver<SocketAddr>>,
    /// Optional stream of updated peer addresses (roaming); the latest value
    /// received before each send is used.
    peer_addr_rx: Option<Receiver<SocketAddr>>,
    /// Optional pause after peer discovery, before the main send loop starts.
    warmup_delay: Option<Duration>,
    /// Delivery mode; retransmission happens only in `DiffMode::Reliable`.
    #[builder(default)]
    diff_mode: DiffMode,
}
impl UdpSender {
    /// Runs the sender until `token` is cancelled or the data channel closes.
    ///
    /// Start-up: if a peer-discovery receiver was supplied, it is awaited for
    /// the initial peer address. A dropped discovery sender leaves the peer
    /// unset, in which case packets go to the socket's connected address. Then
    /// the optional warm-up delay is slept. Both waits also watch `token`, so a
    /// cancellation during start-up returns `Ok(())` immediately.
    ///
    /// The main loop polls its sources in a fixed priority order (`biased`):
    /// 1. cancellation — returns `Ok(())`;
    /// 2. control frames (while that channel is open) — `Keepalive` timestamps
    ///    are refreshed with the current time; the frame takes a sequence
    ///    number but is not kept for retransmission;
    /// 3. retransmission requests (while that channel is open) — in
    ///    `DiffMode::Reliable` the sequence numbers are queued, and the timer is
    ///    set to fire 20 ms after the queue goes from empty to non-empty;
    /// 4. the retransmission timer — resends every queued packet still buffered;
    /// 5. data frames — in `DiffMode::Reliable` the encrypted packet is buffered
    ///    for retransmission; a closed data channel ends the loop.
    ///
    /// Pending peer-address (roaming) updates are applied before every send.
    ///
    /// # Errors
    /// Returns the first encoding, encryption, or socket error; any such error
    /// ends the loop.
    pub async fn frame_loop(&mut self, token: CancellationToken) -> Result<()> {
        let mut current_peer: Option<SocketAddr> = None;
        // Both start-up waits race against cancellation. Otherwise a discovery
        // sender that stays alive but never fires would make this task ignore
        // `token` forever.
        if let Some(discovered) = self.peer_discovered_rx.take() {
            select! {
                biased;
                () = token.cancelled() => return Ok(()),
                addr = discovered => {
                    // A dropped discovery sender leaves the peer unset.
                    current_peer = addr.ok();
                },
            }
        }
        if let Some(delay) = self.warmup_delay {
            select! {
                biased;
                () = token.cancelled() => return Ok(()),
                () = sleep(delay) => {},
            }
        }

        let mut retransmit_active = true;
        let mut control_active = true;
        // Far-future deadline that keeps the timer branch idle while nothing is
        // queued for retransmission.
        let retransmit_park = Duration::from_hours(24);
        let mut retransmit_deadline = TokioInstant::now() + retransmit_park;
        loop {
            select! {
                biased;
                () = token.cancelled() => break,
                frame_opt = self.control_rx.recv(), if control_active => {
                    match frame_opt {
                        Some(frame) => {
                            // The timestamp should reflect when the frame is
                            // sent, not the value it was queued with.
                            let frame = match frame {
                                EncryptedFrame::Keepalive(_) => {
                                    EncryptedFrame::Keepalive(now_micros())
                                }
                                other => other,
                            };
                            // Control frames consume a sequence number but are
                            // not stored for retransmission.
                            let (_seq, wire) = self.seal_next(&frame)?;
                            self.drain_roam_updates(&mut current_peer);
                            self.send_wire(&wire, current_peer).await?;
                        }
                        None => control_active = false,
                    }
                },
                seqs = self.retransmit_rx.recv(), if retransmit_active => {
                    match seqs {
                        Some(seqs) => {
                            if self.diff_mode == DiffMode::Reliable {
                                let was_empty = self.pending_retransmit.is_empty();
                                self.pending_retransmit.extend(seqs);
                                // Arm the timer only on the empty -> non-empty
                                // transition so that later requests join the
                                // batch already scheduled.
                                if was_empty && !self.pending_retransmit.is_empty() {
                                    retransmit_deadline =
                                        TokioInstant::now() + Duration::from_millis(20);
                                }
                            }
                        }
                        None => retransmit_active = false,
                    }
                },
                () = sleep_until(retransmit_deadline) => {
                    self.drain_roam_updates(&mut current_peer);
                    self.flush_retransmits(current_peer).await?;
                    retransmit_deadline = TokioInstant::now() + retransmit_park;
                },
                frame_opt = self.rx.recv() => {
                    match frame_opt {
                        Some(frame) => {
                            let (seq, wire) = self.seal_next(&frame)?;
                            if self.diff_mode == DiffMode::Reliable {
                                let _prev = self.retransmit_buffer.insert(seq, wire.clone());
                                let cutoff = seq.saturating_sub(RETRANSMIT_WINDOW);
                                self.retransmit_buffer.retain(|&s, _| s >= cutoff);
                            }
                            self.drain_roam_updates(&mut current_peer);
                            self.send_wire(&wire, current_peer).await?;
                        }
                        None => break,
                    }
                },
            }
        }
        Ok(())
    }

    /// Assigns the next sequence number to `frame` and seals it.
    ///
    /// Returns the sequence number used together with the wire packet.
    fn seal_next(&mut self, frame: &EncryptedFrame) -> Result<(u64, Vec<u8>)> {
        let seq = self.send_seq;
        self.send_seq += 1;
        let wire = self.encrypt(frame, seq)?;
        Ok((seq, wire))
    }

    /// Resends every queued sequence number that is still buffered, then
    /// leaves the queue empty.
    ///
    /// Requested numbers with no buffered packet are skipped silently. This
    /// covers numbers that fell out of the window or were never stored, such
    /// as control frames.
    async fn flush_retransmits(&mut self, peer: Option<SocketAddr>) -> Result<()> {
        let pending: Vec<u64> = self.pending_retransmit.drain().collect();
        for seq in pending {
            // `send_wire` only needs `&self`, so the buffered packet can be
            // sent by reference without cloning it.
            if let Some(wire) = self.retransmit_buffer.get(&seq) {
                self.send_wire(wire, peer).await?;
            }
        }
        Ok(())
    }

    /// Applies all peer-address updates already queued, without waiting; the
    /// most recent one wins. Does nothing when no roaming channel was supplied.
    fn drain_roam_updates(&mut self, peer: &mut Option<SocketAddr>) {
        if let Some(rx) = self.peer_addr_rx.as_mut() {
            while let Ok(addr) = rx.try_recv() {
                *peer = Some(addr);
            }
        }
    }

    /// Writes one packet: to `peer` via `send_to` when it is known, otherwise
    /// via `send`, which relies on the socket having been `connect`ed.
    async fn send_wire(&self, wire: &[u8], peer: Option<SocketAddr>) -> Result<()> {
        if let Some(addr) = peer {
            let _n = self.socket.send_to(wire, addr).await?;
        } else {
            let _n = self.socket.send(wire).await?;
        }
        Ok(())
    }

    /// Encodes `frame` and seals it into a wire packet for sequence number `seq`.
    ///
    /// Packet layout:
    /// `nonce (NONCE_LEN) | seq (8, BE) | HMAC tag | len (8, BE) | sealed`.
    /// - `sealed` is `id (16) || bincode(frame)`, encrypted in place with the
    ///   AEAD key under a fresh random nonce, with `seq` as AAD and the AEAD
    ///   tag appended.
    /// - The HMAC covers `seq || sealed`.
    /// - `len` is `sealed.len()`.
    ///
    /// No size check against `MAX_UDP_PAYLOAD` is done here.
    fn encrypt(&self, frame: &EncryptedFrame, seq: u64) -> Result<Vec<u8>> {
        // Expected AEAD tag size (16 bytes for AES-GCM(-SIV)). It is only a
        // capacity hint; correctness does not depend on it.
        const AEAD_TAG_HINT: usize = 16;

        let data = encode_to_vec(frame, standard())?;
        let seq_bytes = seq.to_be_bytes();

        let id = self.id.as_bytes();
        let mut sealed = Vec::with_capacity(id.len() + data.len() + AEAD_TAG_HINT);
        sealed.extend_from_slice(id);
        sealed.extend_from_slice(&data);

        let mut nonce_bytes = [0u8; NONCE_LEN];
        rand::fill(&mut nonce_bytes)?;
        let nonce = Nonce::try_assume_unique_for_key(&nonce_bytes)?;
        self.rnk
            .seal_in_place_append_tag(nonce, Aad::from(seq_bytes), &mut sealed)?;

        let mut to_sign = Vec::with_capacity(seq_bytes.len() + sealed.len());
        to_sign.extend_from_slice(&seq_bytes);
        to_sign.extend_from_slice(&sealed);
        let tag = sign(&self.hmac, &to_sign);
        let tag_bytes = tag.as_ref();

        // Fixed 8-byte width. `usize::to_be_bytes` would make the wire format
        // depend on the sender's pointer width (4 bytes on 32-bit targets);
        // on 64-bit targets the bytes are identical.
        let len = u64::try_from(sealed.len())?.to_be_bytes();

        let mut packet = Vec::with_capacity(
            nonce_bytes.len() + seq_bytes.len() + tag_bytes.len() + len.len() + sealed.len(),
        );
        packet.extend_from_slice(&nonce_bytes);
        packet.extend_from_slice(&seq_bytes);
        packet.extend_from_slice(tag_bytes);
        packet.extend_from_slice(&len);
        packet.extend_from_slice(&sealed);
        Ok(packet)
    }
}
#[cfg(test)]
mod tests {
    use aws_lc_rs::{
        aead::{AES_256_GCM_SIV, Aad, NONCE_LEN, Nonce, UnboundKey},
        hmac::HMAC_SHA512,
    };
    use bincode_next::{config::standard, encode_to_vec};
    use tokio::sync::mpsc::channel;
    use std::{net::SocketAddr, sync::Arc, time::Duration};
    use anyhow::{Result, ensure};
    use aws_lc_rs::aead::LessSafeKey;
    use aws_lc_rs::hmac::Key;
    use tokio::net::UdpSocket;
    use tokio::spawn;
    use tokio::sync::{mpsc::Receiver, oneshot};
    use tokio::time::{sleep, timeout};
    use tokio_util::sync::CancellationToken;
    use uuid::Uuid;
    use super::{EncryptedFrame, UdpSender, now_micros};

    /// Width of the big-endian sequence-number field in each packet header.
    const SEQ_LEN: usize = 8;
    /// Length of the HMAC-SHA-512 tag in each packet header.
    const HMAC_TAG_LEN: usize = 64;
    /// Width of the big-endian sealed-length field (8 bytes on 64-bit targets).
    const LEN_FIELD_LEN: usize = 8;
    /// Length of the sender id (UUID) prefixed to every plaintext.
    const ID_LEN: usize = 16;

    /// All-zero AES-256-GCM-SIV key shared by every sender built in these tests.
    fn test_aead_key() -> LessSafeKey {
        LessSafeKey::new(
            UnboundKey::new(&AES_256_GCM_SIV, &[0u8; 32])
                .expect("test AES-256-GCM-SIV key setup"),
        )
    }

    /// Builds a sender with all-zero test keys and no peer discovery or roaming.
    fn make_sender(
        socket: Arc<UdpSocket>,
        control_rx: Receiver<EncryptedFrame>,
        rx: Receiver<EncryptedFrame>,
        retransmit_rx: Receiver<Vec<u64>>,
    ) -> UdpSender {
        UdpSender::builder()
            .id(Uuid::new_v4())
            .rnk(test_aead_key())
            .hmac(Key::new(HMAC_SHA512, &[0u8; 64]))
            .socket(socket)
            .control_rx(control_rx)
            .rx(rx)
            .retransmit_rx(retransmit_rx)
            .build()
    }

    /// Decrypts a packet produced by a sender keyed with [`test_aead_key`].
    ///
    /// Returns the plaintext that follows the sender id, i.e. the bincode
    /// encoding of the frame. The HMAC is not verified.
    fn open_packet(packet: &[u8]) -> Result<Vec<u8>> {
        let header_len = NONCE_LEN + SEQ_LEN + HMAC_TAG_LEN + LEN_FIELD_LEN;
        ensure!(
            packet.len() > header_len,
            "packet too short: {} bytes",
            packet.len()
        );
        let (nonce_bytes, rest) = packet.split_at(NONCE_LEN);
        let seq = u64::from_be_bytes(rest[..SEQ_LEN].try_into()?);
        let nonce = Nonce::try_assume_unique_for_key(nonce_bytes)?;
        let mut in_out = packet[header_len..].to_vec();
        let plain =
            test_aead_key().open_in_place(nonce, Aad::from(seq.to_be_bytes()), &mut in_out)?;
        ensure!(plain.len() >= ID_LEN, "plaintext shorter than the sender id");
        Ok(plain[ID_LEN..].to_vec())
    }

    #[tokio::test]
    async fn keepalive_is_restamped_at_send_time() -> Result<()> {
        let server = Arc::new(UdpSocket::bind("127.0.0.1:0").await?);
        let server_addr = server.local_addr()?;
        let (ctrl_tx, ctrl_rx) = channel::<EncryptedFrame>(4);
        // Kept alive so that the loop keeps running until cancelled.
        let (_frame_tx, frame_rx) = channel::<EncryptedFrame>(4);
        let (_retransmit_tx, retransmit_rx) = channel::<Vec<u64>>(4);
        let token = CancellationToken::new();
        // Five seconds stale: if the sender forwarded this keepalive unchanged,
        // the decrypted payload would be exactly this frame's encoding.
        let stale_ts = now_micros().saturating_sub(5_000_000);
        // Restamping only applies to the control channel.
        ctrl_tx
            .send(EncryptedFrame::Keepalive(stale_ts))
            .await
            .expect("test channel send");
        let send_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await?);
        send_socket.connect(server_addr).await?;
        server.connect(send_socket.local_addr()?).await?;
        let mut sender = make_sender(send_socket, ctrl_rx, frame_rx, retransmit_rx);
        let token2 = token.clone();
        let handle = spawn(async move {
            drop(sender.frame_loop(token2).await);
        });
        let mut buf = vec![0u8; 65535];
        let received = timeout(Duration::from_millis(500), server.recv(&mut buf)).await;
        token.cancel();
        drop(handle.await);
        let n = received.expect("keepalive must reach the server within 500 ms")?;
        let payload = open_packet(&buf[..n])?;
        let stale_encoding = encode_to_vec(&EncryptedFrame::Keepalive(stale_ts), standard())?;
        assert_ne!(
            payload, stale_encoding,
            "keepalive was sent with its stale timestamp instead of being restamped"
        );
        Ok(())
    }

    #[tokio::test]
    async fn control_channel_close_does_not_break_loop() -> Result<()> {
        let server = Arc::new(UdpSocket::bind("127.0.0.1:0").await?);
        let server_addr = server.local_addr()?;
        let (ctrl_tx, ctrl_rx) = channel::<EncryptedFrame>(4);
        let (frame_tx, frame_rx) = channel::<EncryptedFrame>(4);
        let (_retransmit_tx, retransmit_rx) = channel::<Vec<u64>>(4);
        let token = CancellationToken::new();
        let send_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await?);
        send_socket.connect(server_addr).await?;
        server.connect(send_socket.local_addr()?).await?;
        let ts = now_micros();
        ctrl_tx
            .send(EncryptedFrame::Keepalive(ts))
            .await
            .expect("test channel send");
        drop(ctrl_tx);
        frame_tx
            .send(EncryptedFrame::Shutdown)
            .await
            .expect("test channel send");
        drop(frame_tx);
        let mut sender = make_sender(send_socket, ctrl_rx, frame_rx, retransmit_rx);
        let token2 = token.clone();
        let handle = spawn(async move {
            drop(sender.frame_loop(token2).await);
        });
        let mut count = 0usize;
        let mut buf = vec![0u8; 65535];
        while let Ok(Ok(_)) = timeout(Duration::from_millis(200), server.recv(&mut buf)).await {
            count += 1;
        }
        token.cancel();
        drop(handle.await);
        assert_eq!(count, 2, "expected exactly 2 wire packets");
        Ok(())
    }

    #[tokio::test]
    async fn sender_adopts_roamed_peer_addr() -> Result<()> {
        let old_peer = Arc::new(UdpSocket::bind("127.0.0.1:0").await?);
        let new_peer = Arc::new(UdpSocket::bind("127.0.0.1:0").await?);
        let old_addr = old_peer.local_addr()?;
        let new_addr = new_peer.local_addr()?;
        let (_ctrl_tx, ctrl_rx) = channel::<EncryptedFrame>(4);
        let (frame_tx, frame_rx) = channel::<EncryptedFrame>(4);
        let (_retransmit_tx, retransmit_rx) = channel::<Vec<u64>>(4);
        let (peer_disc_tx, peer_disc_rx) = oneshot::channel::<SocketAddr>();
        let (peer_addr_tx, peer_addr_rx) = channel::<SocketAddr>(4);
        let token = CancellationToken::new();
        let send_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await?);
        let mut sender = UdpSender::builder()
            .id(Uuid::new_v4())
            .rnk(test_aead_key())
            .hmac(Key::new(HMAC_SHA512, &[0u8; 64]))
            .socket(send_socket)
            .control_rx(ctrl_rx)
            .rx(frame_rx)
            .retransmit_rx(retransmit_rx)
            .peer_discovered_rx(peer_disc_rx)
            .peer_addr_rx(peer_addr_rx)
            .build();
        peer_disc_tx.send(old_addr).expect("test oneshot send");
        let token2 = token.clone();
        let handle = spawn(async move {
            drop(sender.frame_loop(token2).await);
        });
        frame_tx
            .send(EncryptedFrame::Keepalive(0))
            .await
            .expect("test channel send");
        let mut buf = vec![0u8; 65535];
        let got_old = timeout(Duration::from_millis(500), old_peer.recv(&mut buf))
            .await
            .is_ok();
        peer_addr_tx
            .send(new_addr)
            .await
            .expect("test channel send");
        sleep(Duration::from_millis(10)).await;
        frame_tx
            .send(EncryptedFrame::Keepalive(0))
            .await
            .expect("test channel send");
        let got_new = timeout(Duration::from_millis(500), new_peer.recv(&mut buf))
            .await
            .is_ok();
        token.cancel();
        drop(handle.await);
        assert!(got_old, "first frame did not reach original peer");
        assert!(got_new, "second frame did not reach roamed peer");
        Ok(())
    }

    #[tokio::test]
    async fn retransmit_deadline_fires_after_nak_request() -> Result<()> {
        let server = Arc::new(UdpSocket::bind("127.0.0.1:0").await?);
        let server_addr = server.local_addr()?;
        let send_socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await?);
        send_socket.connect(server_addr).await?;
        server.connect(send_socket.local_addr()?).await?;
        let (_ctrl_tx, ctrl_rx) = channel::<EncryptedFrame>(4);
        let (frame_tx, frame_rx) = channel::<EncryptedFrame>(4);
        let (retransmit_tx, retransmit_rx) = channel::<Vec<u64>>(4);
        let token = CancellationToken::new();
        let mut sender = make_sender(send_socket, ctrl_rx, frame_rx, retransmit_rx);
        let token2 = token.clone();
        let handle = spawn(async move { drop(sender.frame_loop(token2).await) });
        frame_tx
            .send(EncryptedFrame::Keepalive(0))
            .await
            .expect("test channel send");
        let mut buf = vec![0u8; 65535];
        let got_original = timeout(Duration::from_millis(200), server.recv(&mut buf))
            .await
            .is_ok();
        retransmit_tx
            .send(vec![0])
            .await
            .expect("test channel send");
        let got_retransmit = timeout(Duration::from_millis(100), server.recv(&mut buf))
            .await
            .is_ok();
        token.cancel();
        drop(handle.await);
        assert!(got_original, "original packet must reach server");
        assert!(
            got_retransmit,
            "retransmit must fire within 100ms of NAK request"
        );
        Ok(())
    }
}