use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::UdpSocket;
use tokio::sync::mpsc;
use tokio::time::{Instant as TokioInstant, sleep_until};
use crate::{PooledBuf, SlabPool, Tunn, TunnResult};
/// Tokio driver that moves packets between a [`Tunn`] and a UDP socket.
///
/// Construct it with [`AsyncTunn::new`] or [`AsyncTunn::channels`], then drive it
/// with [`AsyncTunn::run`], which consumes the driver.
#[derive(Debug)]
pub struct AsyncTunn {
// Tunnel state machine; `run` calls its `decapsulate`, `encapsulate`,
// `update_timers` and `next_wake`.
tunn: Tunn,
// Socket used both to receive datagrams and to send the tunnel's output.
socket: Arc<UdpSocket>,
// Buffer pool cloned from `tunn.pool()` in `new`, so it is shared with `tunn`.
pool: Arc<SlabPool>,
// Destination for outbound datagrams. `run` overwrites it with the source
// address of any inbound datagram for which `decapsulate` produces output.
peer_addr: Option<SocketAddr>,
}
/// Caller-side ends of the two bounded channels created by [`AsyncTunn::channels`].
#[derive(Debug)]
pub struct TunnChannels {
/// Packets sent here are passed to `Tunn::encapsulate` by the driver, and the
/// resulting datagram is sent to the current peer. Dropping every sender lets
/// `run` return `Ok(())` once the channel is drained.
pub to_net: mpsc::Sender<PooledBuf>,
/// Receives the packets the driver obtains from `Tunn::decapsulate`
/// (`WriteToTunnel` results). Dropping this receiver makes `run` return `Ok(())`
/// the next time it tries to deliver a packet.
pub from_net: mpsc::Receiver<PooledBuf>,
}
impl AsyncTunn {
#[must_use]
pub fn new(tunn: Tunn, socket: Arc<UdpSocket>, peer_addr: Option<SocketAddr>) -> Self {
let pool = Arc::clone(tunn.pool());
Self {
tunn,
socket,
pool,
peer_addr,
}
}
#[must_use]
pub fn channels(
tunn: Tunn,
socket: Arc<UdpSocket>,
peer_addr: Option<SocketAddr>,
depth: usize,
) -> (
Self,
TunnChannels,
mpsc::Receiver<PooledBuf>,
mpsc::Sender<PooledBuf>,
) {
let (to_net_tx, to_net_rx) = mpsc::channel(depth);
let (from_net_tx, from_net_rx) = mpsc::channel(depth);
let driver = Self::new(tunn, socket, peer_addr);
(
driver,
TunnChannels {
to_net: to_net_tx,
from_net: from_net_rx,
},
to_net_rx,
from_net_tx,
)
}
#[must_use]
pub fn pool(&self) -> &Arc<SlabPool> {
&self.pool
}
#[must_use]
pub fn peer_addr(&self) -> Option<SocketAddr> {
self.peer_addr
}
#[must_use]
pub fn tunn(&self) -> &Tunn {
&self.tunn
}
pub async fn run(
mut self,
mut to_net: mpsc::Receiver<PooledBuf>,
from_net: mpsc::Sender<PooledBuf>,
) -> std::io::Result<()> {
let mut net_out = self.pool.get();
let mut tun_out = self.pool.get();
let mut rx = self.pool.get();
loop {
let wake = self
.tunn
.next_wake()
.map(TokioInstant::from_std)
.unwrap_or_else(far_future);
tokio::select! {
biased;
r = self.socket.recv_from(rx.spare_mut()) => {
let (n, src) = r?;
rx.set_len(n);
let datagram = &*rx;
match self.tunn.decapsulate(Some(src), datagram, tun_out.spare_mut()) {
TunnResult::WriteToTunnel(d) => {
self.peer_addr = Some(src);
let n = d.len();
tun_out.set_len(n);
let delivered = std::mem::replace(&mut tun_out, self.pool.get());
if from_net.send(delivered).await.is_err() {
return Ok(()); }
}
TunnResult::WriteToNetwork(w) => {
self.peer_addr = Some(src);
let first_len = w.len();
self.send_slice(tun_out.spare_mut(), first_len, src).await?;
self.drain_queue(&mut net_out, src).await?;
}
TunnResult::Done => {}
TunnResult::Err(_) => {
}
}
rx.set_len(0); }
pkt = to_net.recv() => {
let Some(pkt) = pkt else { return Ok(()); }; let Some(dst) = self.peer_addr else {
let _ = self.tunn.encapsulate(&pkt, net_out.spare_mut());
continue;
};
match self.tunn.encapsulate(&pkt, net_out.spare_mut()) {
TunnResult::WriteToNetwork(w) => {
let n = w.len();
self.send_slice(net_out.spare_mut(), n, dst).await?;
}
TunnResult::Done => {} TunnResult::Err(_) => {}
TunnResult::WriteToTunnel(_) => {} }
}
() = sleep_until(wake) => {
let Some(dst) = self.peer_addr else { continue; };
loop {
match self.tunn.update_timers(net_out.spare_mut()) {
TunnResult::WriteToNetwork(w) => {
let n = w.len();
self.send_slice(net_out.spare_mut(), n, dst).await?;
}
TunnResult::Done => break,
TunnResult::Err(_) => break,
TunnResult::WriteToTunnel(_) => break,
}
}
}
}
}
}
async fn send_slice(&self, slab: &mut [u8], n: usize, dst: SocketAddr) -> std::io::Result<()> {
if let Some(bytes) = slab.get(..n) {
self.socket.send_to(bytes, dst).await?;
}
Ok(())
}
async fn drain_queue(
&mut self,
net_out: &mut PooledBuf,
dst: SocketAddr,
) -> std::io::Result<()> {
loop {
match self.tunn.decapsulate(Some(dst), &[], net_out.spare_mut()) {
TunnResult::WriteToNetwork(w) => {
let n = w.len();
self.send_slice(net_out.spare_mut(), n, dst).await?;
}
_ => return Ok(()),
}
}
}
}
/// Fallback deadline used by `run` when the tunnel reports no pending timer.
///
/// Returns one day from now, or the current instant if that addition would
/// overflow.
fn far_future() -> TokioInstant {
    const ONE_DAY: std::time::Duration = std::time::Duration::from_secs(86_400);
    match TokioInstant::now().checked_add(ONE_DAY) {
        Some(deadline) => deadline,
        None => TokioInstant::now(),
    }
}
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::indexing_slicing,
        clippy::arithmetic_side_effects
    )]
    use super::*;
    use crate::RateLimiter;
    use std::time::Duration;
    use wireguard_sans_io::{PublicKey, StaticSecret};

    /// Binds a UDP socket to an ephemeral port on the IPv4 loopback address.
    async fn bound_socket() -> Arc<UdpSocket> {
        Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap())
    }

    /// Generates a static secret from `crate::OsEntropy` and derives its public key.
    fn keypair() -> (StaticSecret, PublicKey) {
        use wireguard_sans_io::EntropySource;
        let sk = StaticSecret::from_bytes(crate::OsEntropy.gen32().unwrap());
        let pk = sk.public_key();
        (sk, pk)
    }

    /// Returns a zeroed 60-byte packet whose leading bytes look like an IPv4
    /// header (version 4 / IHL 5, total length 60).
    fn ipv4_like_packet() -> [u8; 60] {
        let mut p = [0u8; 60];
        p[0] = 0x45;
        p[2..4].copy_from_slice(&60u16.to_be_bytes());
        p
    }

    /// Sends packets from driver A to driver B over loopback and checks the
    /// received bytes and that few pooled buffers stay outstanding. It also
    /// checks that both drivers shut down once their caller-side channels are
    /// dropped.
    #[tokio::test]
    async fn async_roundtrip_via_channels() {
        let pool = SlabPool::for_wireguard();
        pool.prefill(32);
        let rl = Arc::new(RateLimiter::new(1_000_000));
        let (a_sk, a_pk) = keypair();
        let (b_sk, b_pk) = keypair();
        let sock_a = bound_socket().await;
        let sock_b = bound_socket().await;
        let addr_a = sock_a.local_addr().unwrap();
        let addr_b = sock_b.local_addr().unwrap();
        let tunn_a =
            Tunn::with_pool(a_sk, b_pk, None, None, Some(rl.clone()), pool.clone()).unwrap();
        let tunn_b = Tunn::with_pool(b_sk, a_pk, None, None, Some(rl), pool.clone()).unwrap();
        let (drv_a, chans_a, rx_a, tx_a) = AsyncTunn::channels(tunn_a, sock_a, Some(addr_b), 64);
        let (drv_b, mut chans_b, rx_b, tx_b) =
            AsyncTunn::channels(tunn_b, sock_b, Some(addr_a), 64);
        let ha = tokio::spawn(drv_a.run(rx_a, tx_a));
        let hb = tokio::spawn(drv_b.run(rx_b, tx_b));

        let mut first = ipv4_like_packet();
        first[40..].fill(0xab);
        let pkt = PooledBuf::copy_from(&pool, &first);
        let baseline_idle = pool.idle();
        chans_a.to_net.send(pkt).await.unwrap();
        // Longer timeout: the first packet presumably waits for the handshake.
        let got = tokio::time::timeout(Duration::from_secs(5), chans_b.from_net.recv())
            .await
            .expect("timeout waiting for first packet")
            .expect("channel closed");
        assert_eq!(got.len(), 60);
        assert_eq!(got[40], 0xab);
        drop(got);

        for i in 0..100u8 {
            let mut p = ipv4_like_packet();
            p[59] = i;
            chans_a
                .to_net
                .send(PooledBuf::copy_from(&pool, &p))
                .await
                .unwrap();
            let g = tokio::time::timeout(Duration::from_secs(2), chans_b.from_net.recv())
                .await
                .unwrap()
                .unwrap();
            assert_eq!(g[59], i);
        }

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(
            pool.outstanding() <= 16,
            "outstanding={}",
            pool.outstanding()
        );
        assert!(pool.idle() >= baseline_idle.saturating_sub(16));

        // Dropping the caller-side channels closes each driver's `to_net`, so
        // both `run` futures must finish. The io::Result itself is not asserted:
        // once one driver's socket closes, some platforms report that to the
        // other as a receive error.
        drop(chans_a);
        drop(chans_b);
        for (name, handle) in [("a", ha), ("b", hb)] {
            let _ = tokio::time::timeout(Duration::from_secs(2), handle)
                .await
                .unwrap_or_else(|_| panic!("driver {name} did not shut down"))
                .unwrap_or_else(|e| panic!("driver {name} panicked: {e}"));
        }
    }
}