/// Error returned by `Connection::receive` when the connection's message
/// channel has been closed and no further data will arrive.
///
/// Callers can detect it with `anyhow::Error::downcast_ref::<ConnectionClosed>()`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ConnectionClosed;

impl std::fmt::Display for ConnectionClosed {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("connection closed")
    }
}

impl std::error::Error for ConnectionClosed {}
use anyhow::{Context, Result};
use std::{
collections::HashMap,
net::{IpAddr, SocketAddr},
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::Duration,
};
use tokio::{net::UdpSocket, sync::Mutex};
/// Rewrites `remote` into the textual form that `recv_from` on `socket` will
/// report for datagrams coming from that peer, so it can be used as a map key.
///
/// - A string that does not parse as a `SocketAddr` is returned unchanged.
///   NOTE(review): hostnames are therefore not resolved, so such a key likely
///   never equals a numeric source address.
/// - If the socket's local address cannot be read, the parsed address is
///   returned in canonical `Display` form.
/// - On an IPv6 socket an IPv4 remote becomes its IPv4-mapped IPv6 form.
/// - On an IPv4 socket an IPv4-mapped IPv6 remote is unwrapped to IPv4.
fn normalize_remote_for_socket(socket: &UdpSocket, remote: &str) -> String {
    let addr: SocketAddr = match remote.parse() {
        Ok(a) => a,
        Err(_) => return remote.to_owned(),
    };
    let local_is_v6 = match socket.local_addr() {
        Ok(local) => local.is_ipv6(),
        Err(_) => return addr.to_string(),
    };
    let adjusted = match addr {
        SocketAddr::V4(v4) if local_is_v6 => {
            SocketAddr::new(IpAddr::V6(v4.ip().to_ipv6_mapped()), v4.port())
        }
        SocketAddr::V6(v6) if !local_is_v6 => v6
            .ip()
            .to_ipv4_mapped()
            .map_or(addr, |ip4| SocketAddr::new(IpAddr::V4(ip4), v6.port())),
        unchanged => unchanged,
    };
    adjusted.to_string()
}
/// A message-oriented connection to a single remote peer.
#[async_trait::async_trait]
pub trait ConnectionTrait: Send + Sync {
    /// Sends `data` to the peer.
    async fn send(&self, data: &[u8]) -> Result<()>;

    /// Waits at most `timeout` for the next message from the peer.
    async fn receive(&self, timeout: Duration) -> Result<Vec<u8>>;

    /// Whether delivery is guaranteed by this connection; `false` unless an
    /// implementation overrides it.
    fn is_reliable(&self) -> bool {
        false
    }
}
/// Registration entry stored in `Transport::connections` for one remote.
#[derive(Clone, Debug)]
struct ConnectionInfo {
    /// Channel through which the receive loop hands datagrams to the owning `Connection`.
    sender: tokio::sync::mpsc::Sender<Vec<u8>>,
    /// Generation of the `Connection` owning this entry; a removal request
    /// carrying a different generation leaves the entry in place.
    generation: u64,
}
/// A UDP socket shared by several logical connections; incoming datagrams are
/// routed by the string form of their source address.
pub struct Transport {
    /// The bound socket; a clone of this `Arc` is owned by the receive task.
    socket: Arc<UdpSocket>,
    /// Registered connections keyed by normalised remote address string.
    connections: Mutex<HashMap<String, ConnectionInfo>>,
    /// Queue of `(remote_address, generation)` removal requests consumed by the cleanup task.
    remove_channel_sender: tokio::sync::mpsc::UnboundedSender<(String, u64)>,
    /// Source of per-connection generation numbers (initialised to 1 in `new`).
    next_generation: AtomicU64,
    /// Cancelled in `Drop` to stop the receive task.
    stop_receive_token: tokio_util::sync::CancellationToken,
}
/// One logical peer on a shared `Transport`.
pub struct Connection {
    /// Keeps the shared transport (and therefore its socket) alive while this connection exists.
    transport: Arc<Transport>,
    /// Normalised remote address: the send target and this connection's key in the transport's map.
    remote_address: String,
    /// Datagrams routed to this connection by the transport's receive loop.
    receiver: Mutex<tokio::sync::mpsc::Receiver<Vec<u8>>>,
    /// Generation of this connection's registration, sent back with the removal request on drop.
    generation: u64,
}
impl Transport {
    /// Receives datagrams from `socket` and forwards each one to the connection
    /// registered for its source address.
    ///
    /// Returns `Ok(())` once `stop_receive_token` is cancelled, or an error if
    /// the owning `Transport` can no longer be upgraded from `self_weak`.
    /// Receive errors are logged at debug level and skipped. Datagrams from
    /// addresses with no registered connection are discarded.
    async fn read_from_socket_loop(
        socket: Arc<UdpSocket>,
        stop_receive_token: tokio_util::sync::CancellationToken,
        self_weak: std::sync::Weak<Transport>,
    ) -> Result<()> {
        loop {
            // A fresh buffer per datagram, because ownership is handed to the
            // connection's channel. Per `recv_from`'s contract, bytes beyond the
            // buffer length may be discarded.
            let mut buf = vec![0u8; 2048];
            let recv_result = tokio::select! {
                recv_resp = socket.recv_from(&mut buf) => recv_resp,
                _ = stop_receive_token.cancelled() => break,
            };
            let (n, addr) = match recv_result {
                Ok(r) => r,
                Err(e) => {
                    log::debug!("transport recv error (ignored): {:?}", e);
                    continue;
                }
            };
            buf.truncate(n);

            // Look up the destination under the lock, but release the lock (and
            // the strong transport reference) before awaiting on the bounded
            // channel. Otherwise a connection whose queue is full would block
            // `create_connection` and the cleanup loop for as long as it stays full.
            let sender = {
                let transport = self_weak
                    .upgrade()
                    .context("weakpointer to self is gone - just stop")?;
                let connections = transport.connections.lock().await;
                connections.get(&addr.to_string()).map(|c| c.sender.clone())
            };
            if let Some(sender) = sender {
                // Fails only when the connection's receiver is gone; the datagram is dropped.
                _ = sender.send(buf).await;
            }
        }
        Ok(())
    }

    /// Removes registry entries whose `Connection` has been dropped.
    ///
    /// Each request is `(remote_address, generation)`. The entry is removed only
    /// if it still carries that generation, so a stale `Connection` being
    /// dropped cannot evict a newer one registered for the same address.
    /// The loop ends in three cases:
    /// - the channel closes;
    /// - the shutdown marker `("", 0)` arrives;
    /// - the transport can no longer be upgraded (returns an error).
    async fn read_from_delete_queue_loop(
        mut remove_channel_receiver: tokio::sync::mpsc::UnboundedReceiver<(String, u64)>,
        self_weak: std::sync::Weak<Transport>,
    ) -> Result<()> {
        while let Some((addr, generation)) = remove_channel_receiver.recv().await {
            // Generation 0 is never allocated (`next_generation` starts at 1), so
            // ("", 0) is unambiguous. Checking the empty address alone would also
            // stop this loop when a connection created for remote "" is dropped,
            // leaking every entry registered afterwards.
            if generation == 0 && addr.is_empty() {
                break;
            }
            let transport = self_weak
                .upgrade()
                .context("weak to self is gone - just stop")?;
            let mut connections = transport.connections.lock().await;
            if connections.get(&addr).map(|c| c.generation) == Some(generation) {
                connections.remove(&addr);
            }
        }
        Ok(())
    }

    /// Binds a UDP socket to `local` and starts the background receive and
    /// cleanup tasks.
    ///
    /// The tasks are started with `tokio::spawn`, so this must run inside a
    /// Tokio runtime. They hold only weak references to the returned transport.
    ///
    /// # Errors
    /// Returns an error if the socket cannot be bound.
    pub async fn new(local: &str) -> Result<Arc<Self>> {
        let socket = UdpSocket::bind(local)
            .await
            .with_context(|| format!("failed to bind UDP socket to {local}"))?;
        let (remove_channel_sender, remove_channel_receiver) =
            tokio::sync::mpsc::unbounded_channel();
        let stop_receive_token = tokio_util::sync::CancellationToken::new();
        // Cancelling the parent token in `Drop` also cancels this child.
        let stop_receive_token_child = stop_receive_token.child_token();
        let transport = Arc::new(Self {
            socket: Arc::new(socket),
            connections: Mutex::new(HashMap::new()),
            remove_channel_sender,
            // Generation 0 is reserved for the shutdown marker sent in `Drop`.
            next_generation: AtomicU64::new(1),
            stop_receive_token,
        });

        let self_weak = Arc::downgrade(&transport);
        let socket = Arc::clone(&transport.socket);
        tokio::spawn(async move {
            if let Err(e) =
                Self::read_from_socket_loop(socket, stop_receive_token_child, self_weak).await
            {
                log::debug!("transport receive loop stopped: {:?}", e);
            }
        });

        let self_weak = Arc::downgrade(&transport);
        tokio::spawn(async move {
            if let Err(e) =
                Self::read_from_delete_queue_loop(remove_channel_receiver, self_weak).await
            {
                log::debug!("transport cleanup loop stopped: {:?}", e);
            }
        });

        Ok(transport)
    }

    /// Registers a new logical connection to `remote` and returns it.
    ///
    /// `remote` is normalised with `normalize_remote_for_socket` so that it
    /// matches the source addresses seen by the receive loop. Any existing
    /// registration for the same address is replaced. The older `Connection`
    /// stops receiving new datagrams, and its `receive` eventually reports
    /// `ConnectionClosed`.
    pub async fn create_connection(self: &Arc<Self>, remote: &str) -> Arc<dyn ConnectionTrait> {
        let remote = normalize_remote_for_socket(&self.socket, remote);
        let mut connections = self.connections.lock().await;
        let generation = self.next_generation.fetch_add(1, Ordering::Relaxed);
        let (sender, receiver) = tokio::sync::mpsc::channel(32);
        connections.insert(remote.clone(), ConnectionInfo { sender, generation });
        Arc::new(Connection {
            transport: Arc::clone(self),
            remote_address: remote,
            receiver: Mutex::new(receiver),
            generation,
        })
    }
}
impl Connection {
    /// Sends `data` as one datagram to this connection's remote address over
    /// the transport's shared socket.
    ///
    /// # Errors
    /// Propagates any socket send error.
    pub async fn send(&self, data: &[u8]) -> Result<()> {
        let socket = &self.transport.socket;
        let _sent = socket.send_to(data, &self.remote_address).await?;
        Ok(())
    }

    /// Returns the next datagram routed to this connection.
    ///
    /// Concurrent callers are serialised by the receiver mutex. The `timeout`
    /// starts only after that lock has been acquired.
    ///
    /// # Errors
    /// - `"receive timeout"` if nothing arrives within `timeout`.
    /// - `ConnectionClosed` if every sender for this connection's channel has
    ///   been dropped and the queue is empty. One way this happens is another
    ///   connection to the same remote replacing this registration.
    pub async fn receive(&self, timeout: Duration) -> Result<Vec<u8>> {
        let mut receiver = self.receiver.lock().await;
        match tokio::time::timeout(timeout, receiver.recv()).await {
            Ok(Some(datagram)) => Ok(datagram),
            Ok(None) => Err(anyhow::Error::new(ConnectionClosed)),
            Err(_elapsed) => Err(anyhow::anyhow!("receive timeout")),
        }
    }
}
impl Drop for Transport {
    /// Asks both background tasks to stop. It queues the shutdown marker for
    /// the cleanup loop, then cancels the receive loop's token.
    fn drop(&mut self) {
        // A failed send (cleanup task already gone) is irrelevant during teardown.
        let shutdown_marker = (String::new(), 0);
        let _ = self.remove_channel_sender.send(shutdown_marker);
        self.stop_receive_token.cancel();
    }
}
#[async_trait::async_trait]
impl ConnectionTrait for Connection {
async fn send(&self, data: &[u8]) -> Result<()> {
self.send(data).await
}
async fn receive(&self, timeout: Duration) -> Result<Vec<u8>> {
self.receive(timeout).await
}
}
impl Drop for Connection {
    /// Queues a request for the transport's cleanup task to unregister this
    /// connection's `(remote_address, generation)` entry.
    fn drop(&mut self) {
        // The field is about to be dropped, so move the address out instead of cloning it.
        let address = std::mem::take(&mut self.remote_address);
        let request = (address, self.generation);
        let _ = self.transport.remove_channel_sender.send(request);
    }
}