use std::{
cell::RefCell,
collections::{hash_map::Entry, HashMap},
net::SocketAddr,
rc::Rc,
};
use anyhow::Context;
use bytes::Bytes;
use tokio_uring::net::UdpSocket;
use crate::{
utp_packet::{Packet, PacketHeader, PacketType, HEADER_SIZE},
utp_stream::{UtpStream, WeakUtpStream},
};
/// A uTP endpoint multiplexing several [`UtpStream`]s over one UDP socket.
///
/// Created with [`UtpSocket::bind`], which also spawns the background network
/// loop that receives datagrams and routes them to the matching stream.
pub struct UtpSocket {
/// The underlying UDP socket, shared with the network loop task.
socket: Rc<UdpSocket>,
/// Sender half of the oneshot channel that tells the network loop to stop.
/// Wrapped in an `Option` so it can be moved out of `&mut self` on drop.
shutdown_signal: Option<tokio::sync::oneshot::Sender<()>>,
/// Slot for the sender of a pending `accept` call. The network loop takes
/// the sender out of the slot when an unknown peer sends a SYN packet.
accept_chan: Rc<RefCell<Option<tokio::sync::oneshot::Sender<UtpStream>>>>,
/// Registered streams, stored as weak handles and keyed by connection id
/// and peer address. Shared with the network loop task.
streams: Rc<RefCell<HashMap<StreamKey, WeakUtpStream>>>,
}
/// Lookup key used to route an incoming packet to its stream.
#[derive(PartialEq, Eq, Debug, Copy, Clone, Hash)]
struct StreamKey {
/// Connection id expected in the header of packets received for the stream:
/// the randomly chosen id for outgoing connections, and the SYN's
/// connection id plus one for accepted connections.
conn_id: u16,
/// Address of the remote peer.
addr: SocketAddr,
}
impl UtpSocket {
    /// Binds a UDP socket to `bind_addr` and spawns the background network loop.
    ///
    /// The loop runs as a `tokio_uring` task that repeatedly calls
    /// `process_incomming` until the shutdown signal fires (which happens when
    /// the returned `UtpSocket` is dropped).
    ///
    /// # Errors
    /// Returns an error if the UDP socket cannot be bound.
    pub async fn bind(bind_addr: SocketAddr) -> anyhow::Result<Self> {
        let socket = Rc::new(UdpSocket::bind(bind_addr).await?);
        let net_loop_socket = socket.clone();
        let (shutdown_signal, mut shutdown_receiver) = tokio::sync::oneshot::channel();
        let utp_socket = UtpSocket {
            socket,
            shutdown_signal: Some(shutdown_signal),
            accept_chan: Default::default(),
            streams: Default::default(),
        };
        // The network loop shares the stream table and the accept slot with
        // the handle returned to the caller.
        let streams_clone = utp_socket.streams.clone();
        let accept_chan = utp_socket.accept_chan.clone();
        tokio_uring::spawn(async move {
            // The receive buffer is handed to the socket by value on every
            // receive and handed back afterwards, so one allocation is reused.
            let mut recv_buf = vec![0; 1024 * 1024];
            loop {
                tokio::select! {
                    buf = process_incomming(&net_loop_socket, &streams_clone, &accept_chan, std::mem::take(&mut recv_buf)) => {
                        recv_buf = buf;
                    }
                    _ = &mut shutdown_receiver => {
                        log::info!("Shutting down network loop");
                        break;
                    }
                }
            }
        });
        Ok(utp_socket)
    }

    /// Opens an outgoing connection to `addr`.
    ///
    /// A random connection id that is not already in use for `addr` is
    /// chosen, the new stream is registered so the network loop can route
    /// replies to it, and then the stream's handshake is awaited.
    ///
    /// # Errors
    /// Returns the error produced by the stream's `connect`. In that case the
    /// stream is unregistered again so no stale entry is left behind.
    pub async fn connect(&self, addr: SocketAddr) -> anyhow::Result<UtpStream> {
        let stream_key = loop {
            let candidate = StreamKey {
                conn_id: rand::random(),
                addr,
            };
            if !self.streams.borrow().contains_key(&candidate) {
                break candidate;
            }
            log::debug!("Stream with same conn_id and addr already exists, regenerating conn_id");
        };
        let stream = UtpStream::new(stream_key.conn_id, addr, Rc::downgrade(&self.socket));
        // Register before connecting so the reply to our SYN can be routed.
        self.streams
            .borrow_mut()
            .insert(stream_key, stream.clone().into());
        let connected = stream.connect().await;
        if connected.is_err() {
            // Don't leak a dead entry in the stream table on a failed handshake.
            self.streams.borrow_mut().remove(&stream_key);
        }
        connected?;
        Ok(stream)
    }

    /// Waits for the next incoming connection.
    ///
    /// Stores a oneshot sender in the accept slot; the network loop takes it
    /// when a SYN from an unknown peer arrives and sends the new stream back.
    /// The slot holds a single sender, so a later call replaces the sender
    /// stored by an earlier, still pending call.
    ///
    /// # Errors
    /// Returns an error if the sender is dropped without a stream having been
    /// sent (for example because the network loop has exited).
    pub async fn accept(&self) -> anyhow::Result<UtpStream> {
        let (tx, rx) = tokio::sync::oneshot::channel();
        // Assigning through the temporary guard releases the borrow before awaiting.
        *self.accept_chan.borrow_mut() = Some(tx);
        rx.await.context("Net loop exited")
    }
}
impl Drop for UtpSocket {
fn drop(&mut self) {
println!("dropping");
self.shutdown_signal.take().unwrap().send(()).unwrap();
}
}
async fn process_incomming(
socket: &Rc<UdpSocket>,
connections: &Rc<RefCell<HashMap<StreamKey, WeakUtpStream>>>,
accept_chan: &Rc<RefCell<Option<tokio::sync::oneshot::Sender<UtpStream>>>>,
recv_buf: Vec<u8>,
) -> Vec<u8> {
let (result, buf) = socket.recv_from(recv_buf).await;
match result {
Ok((recv, addr)) => {
log::info!("Received {recv} from {addr}");
match PacketHeader::try_from(&buf[..recv]) {
Ok(packet_header) => {
let key = StreamKey {
conn_id: packet_header.conn_id,
addr,
};
let packet = Packet {
header: packet_header,
data: Bytes::copy_from_slice(&buf[HEADER_SIZE as usize..recv]),
};
let maybe_stream = { connections.borrow_mut().remove(&key) };
if let Some(weak_stream) = maybe_stream {
if let Some(stream) = weak_stream.try_upgrade() {
match stream.process_incoming(packet).await {
Ok(()) => {
connections.borrow_mut().insert(key, stream.into());
}
Err(err) => {
log::error!("Error: Failed processing incoming packet: {err}");
}
}
}
} else if packet_header.packet_type == PacketType::Syn {
let maybe_chan = { accept_chan.borrow_mut().take() };
if let Some(chan) = maybe_chan {
let stream = UtpStream::new_incoming(
packet_header.seq_nr,
packet_header.conn_id,
addr,
Rc::downgrade(socket),
);
let stream_key = StreamKey {
conn_id: packet_header.conn_id + 1,
addr,
};
{
let mut connections = connections.borrow_mut();
let entry = connections.entry(stream_key);
match entry {
Entry::Occupied(_) => {
log::warn!("Connection with id: {} already exists. Dropping connection",
packet_header.conn_id + 1
);
return buf;
}
Entry::Vacant(entry) => {
log::info!("New incoming connection!");
entry.insert(stream.clone().into());
}
}
}
if let Err(err) = stream.process_incoming(packet).await {
log::error!("Error accepting connection: {err}");
*accept_chan.borrow_mut() = Some(chan);
connections.borrow_mut().remove(&stream_key);
} else {
chan.send(stream).unwrap();
}
}
} else {
log::warn!("Connection not established prior");
}
}
Err(err) => log::error!("Error parsing packet: {err}"),
}
}
Err(err) => log::error!("Failed to receive on utp socket: {err}"),
}
buf
}