use std::{collections::VecDeque, net::SocketAddr, sync::Arc, time::Duration};
use ana_gotatun::{
noise::TunnResult,
packet::{Packet, WgKind},
};
use bytes::BytesMut;
use snap_tun::server::{SnapTunAuthorization, SnapTunServer};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
/// Size in bytes of the buffer handed to `recv_from` for each incoming UDP datagram.
const UDP_BUFFER_SIZE: usize = 65535;
/// Couples a [`SnapTunServer`] with a UDP socket and a pair of in-memory channels.
///
/// `run` moves packets between the socket, the server, and the channels;
/// `send_to_tunnel` and `recv_from_tunnel` are the channel endpoints exposed to callers.
pub struct ServerHarness<T: SnapTunAuthorization> {
/// The wrapped server, locked by `run` around every call into it.
server: Arc<tokio::sync::Mutex<SnapTunServer<T>>>,
/// UDP socket bound in `new`; `run` receives datagrams from it and sends replies through it.
network_socket: Arc<tokio::net::UdpSocket>,
/// Sending half used by `run` to publish `TunnResult::WriteToTunnel` payloads together with
/// the address the originating datagram came from.
tunnel_from_server_tx: mpsc::UnboundedSender<(BytesMut, SocketAddr)>,
/// Receiving half of the same channel, read by `recv_from_tunnel`.
tunnel_from_server_rx: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<(BytesMut, SocketAddr)>>>,
/// Receiving half drained by `run`; each item is passed to the server's
/// `handle_outgoing_packet` with its target address.
tunnel_to_server_rx: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<(BytesMut, SocketAddr)>>>,
/// Sending half of the same channel, fed by `send_to_tunnel`.
tunnel_to_server_tx: mpsc::UnboundedSender<(BytesMut, SocketAddr)>,
/// Token whose cancellation makes `run` leave its loop.
cancel_token: CancellationToken,
}
impl<T: SnapTunAuthorization + 'static> ServerHarness<T> {
pub async fn new(server: SnapTunServer<T>, bind_addr: SocketAddr) -> std::io::Result<Self> {
let network_socket = Arc::new(tokio::net::UdpSocket::bind(bind_addr).await?);
#[allow(clippy::disallowed_methods)]
let (from_server_tx, from_server_rx) = mpsc::unbounded_channel();
#[allow(clippy::disallowed_methods)]
let (to_server_tx, to_server_rx) = mpsc::unbounded_channel();
Ok(Self {
server: Arc::new(tokio::sync::Mutex::new(server)),
network_socket,
tunnel_from_server_tx: from_server_tx,
tunnel_from_server_rx: Arc::new(tokio::sync::Mutex::new(from_server_rx)),
tunnel_to_server_rx: Arc::new(tokio::sync::Mutex::new(to_server_rx)),
tunnel_to_server_tx: to_server_tx,
cancel_token: CancellationToken::new(),
})
}
pub async fn run(&self) {
let mut send_queue = VecDeque::<WgKind>::new();
let mut timer = tokio::time::interval(Duration::from_millis(250));
timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
let mut to_server_rx = self.tunnel_to_server_rx.lock().await;
loop {
let mut buf = BytesMut::zeroed(UDP_BUFFER_SIZE);
tokio::select! {
_ = self.cancel_token.cancelled() => {
tracing::debug!("Server harness shutting down");
break;
}
result = self.network_socket.recv_from(&mut buf) => {
match result {
Ok((n, from)) => {
buf.truncate(n);
let packet = Packet::from_bytes(buf);
let mut server = self.server.lock().await;
let result = server.handle_incoming_packet(packet, from, &mut send_queue);
if let TunnResult::WriteToTunnel(mut p) = result {
let buf = p.buf_mut().to_owned();
if !buf.is_empty() {
let _ = self.tunnel_from_server_tx.send((buf, from));
}
}
use zerocopy::IntoBytes as _;
while let Some(wg_packet) = send_queue.pop_front() {
let bytes = match wg_packet {
WgKind::HandshakeInit(p) => p.into_bytes(),
WgKind::HandshakeResp(p) => p.into_bytes(),
WgKind::CookieReply(p) => p.into_bytes(),
WgKind::Data(p) => p.into_bytes(),
};
let _ = self.network_socket.send_to(bytes.as_bytes(), from).await;
}
}
Err(e) => {
tracing::error!("Error receiving from network socket: {}", e);
}
}
}
Some((packet, target_addr)) = to_server_rx.recv() => {
use zerocopy::IntoBytes as _;
let packet = Packet::from_bytes(packet);
let mut server = self.server.lock().await;
if let Some(wg_packet) = server.handle_outgoing_packet(packet, target_addr) {
let bytes = match wg_packet {
WgKind::HandshakeInit(p) => p.into_bytes(),
WgKind::HandshakeResp(p) => p.into_bytes(),
WgKind::CookieReply(p) => p.into_bytes(),
WgKind::Data(p) => p.into_bytes(),
};
let _ = self.network_socket.send_to(bytes.as_bytes(), target_addr).await;
}
}
_ = timer.tick() => {
use zerocopy::IntoBytes as _;
let mut server = self.server.lock().await;
let packets = server.update_timers();
for (target_addr, wg_packet) in packets {
let bytes = match wg_packet {
WgKind::HandshakeInit(p) => p.into_bytes(),
WgKind::HandshakeResp(p) => p.into_bytes(),
WgKind::CookieReply(p) => p.into_bytes(),
WgKind::Data(p) => p.into_bytes(),
};
let _ = self.network_socket.send_to(bytes.as_bytes(), target_addr).await;
}
}
}
}
}
pub fn cancel_token(&self) -> CancellationToken {
self.cancel_token.clone()
}
pub fn socket_addr(&self) -> SocketAddr {
self.network_socket.local_addr().unwrap()
}
pub fn send_to_tunnel(&self, packet: BytesMut, target_addr: SocketAddr) {
let _ = self.tunnel_to_server_tx.send((packet, target_addr));
}
pub async fn recv_from_tunnel(&self, timeout: Duration) -> Option<(BytesMut, SocketAddr)> {
tokio::time::timeout(timeout, async {
self.tunnel_from_server_rx.lock().await.recv().await
})
.await
.ok()
.flatten()
}
}