use std::io;
use std::net::SocketAddr;
use std::time::Instant;
use bytes::Bytes;
use tokio::net::UdpSocket;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio::time::{Duration, Interval, MissedTickBehavior, interval};
use crate::protocol::DisconnectReason;
use crate::socket::{SocketConfig, SocketEvent, SoeMultiplexer, SoeSocket};
/// Size in bytes of the buffer handed to every `recv_from` call.
///
/// NOTE(review): Tokio's `UdpSocket::recv_from` docs say bytes of a datagram that do not fit
/// in the supplied buffer may be discarded — confirm the protocol's largest packet fits.
const RECV_BUFFER_SIZE: usize = 2048;
/// A UDP socket paired with a [`SoeMultiplexer`], driven directly by the caller.
///
/// Nothing runs in the background: protocol work only happens while the caller awaits
/// `TokioSoeSocket::step`.
#[derive(Debug)]
pub struct TokioSoeSocket {
    /// Protocol state for all remotes, keyed by their `SocketAddr`.
    mux: SoeMultiplexer<SocketAddr>,
    /// The bound UDP socket used for both receiving and sending.
    socket: UdpSocket,
    /// Wake-up timer so `step` returns even when no datagram arrives.
    tick: Interval,
    /// Reusable receive buffer of `RECV_BUFFER_SIZE` bytes.
    buf: Box<[u8]>,
}
impl TokioSoeSocket {
    /// Binds a UDP socket to `local` and pairs it with a fresh [`SoeMultiplexer`].
    ///
    /// `tick_period` is how often [`step`](Self::step) wakes up when no datagram arrives.
    /// Missed ticks use [`MissedTickBehavior::Delay`], so a slow caller is not handed a
    /// burst of catch-up ticks.
    ///
    /// # Errors
    ///
    /// Returns the error from [`UdpSocket::bind`] if `local` cannot be bound.
    ///
    /// # Panics
    ///
    /// Panics if `tick_period` is zero ([`tokio::time::interval`] rejects a zero period).
    pub async fn bind(
        local: SocketAddr,
        config: SocketConfig,
        tick_period: Duration,
    ) -> io::Result<Self> {
        let socket = UdpSocket::bind(local).await?;
        let mut tick = interval(tick_period);
        tick.set_missed_tick_behavior(MissedTickBehavior::Delay);
        Ok(Self {
            mux: SoeMultiplexer::new(config),
            socket,
            tick,
            buf: vec![0u8; RECV_BUFFER_SIZE].into_boxed_slice(),
        })
    }

    /// Returns the local address the UDP socket is bound to.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.socket.local_addr()
    }

    /// Advances the socket by one wake-up and returns the events produced.
    ///
    /// Waits for whichever comes first: an incoming datagram, which is fed to the
    /// multiplexer, or the next tick. Either way, the multiplexer's tick then runs, every
    /// queued outgoing datagram is sent, and the pending events are returned.
    ///
    /// # Errors
    ///
    /// - A receive error is returned immediately. Nothing is ticked or sent on that call.
    /// - A send error does not stop the flush. Every queued datagram is still attempted,
    ///   and the first send error is returned afterwards. Events are not taken on that
    ///   call, so they stay in the multiplexer.
    ///
    /// # Cancellation
    ///
    /// Dropping this future while it waits on the socket or timer loses nothing. Dropping
    /// it mid-flush loses the datagrams already taken from the multiplexer but not yet sent
    /// (assuming `take_outgoing` drains its queue).
    pub async fn step(&mut self) -> io::Result<Vec<SocketEvent<SocketAddr>>> {
        tokio::select! {
            result = self.socket.recv_from(&mut self.buf) => {
                let (len, from) = result?;
                let datagram = Bytes::copy_from_slice(&self.buf[..len]);
                self.mux.process_incoming(from, datagram, Instant::now());
            }
            _ = self.tick.tick() => {}
        }
        self.mux.run_tick(Instant::now());

        // Keep sending after a failure. The datagrams were already taken out of the
        // multiplexer, so bailing early would silently drop traffic queued for every other
        // peer just because one destination failed.
        let mut first_error: Option<io::Error> = None;
        for (addr, datagram) in self.mux.take_outgoing() {
            if let Err(err) = self.socket.send_to(&datagram, addr).await {
                first_error = first_error.or(Some(err));
            }
        }
        if let Some(err) = first_error {
            return Err(err);
        }
        Ok(self.mux.take_events())
    }
}
impl SoeSocket for TokioSoeSocket {
    /// Reports the address of the underlying UDP socket.
    fn local_addr(&self) -> io::Result<SocketAddr> {
        let socket = &self.socket;
        socket.local_addr()
    }

    /// Number of sessions currently tracked by the multiplexer.
    fn session_count(&self) -> usize {
        let mux = &self.mux;
        mux.session_count()
    }

    /// Forwards to the multiplexer's `connect`, stamped with the current time.
    /// Any resulting datagrams go out on the next `step`.
    fn connect(&mut self, remote: SocketAddr) {
        let now = Instant::now();
        self.mux.connect(remote, now);
    }

    /// Forwards `data` to the multiplexer for `remote` and returns its verdict unchanged.
    fn enqueue_data(&mut self, remote: &SocketAddr, data: &[u8]) -> bool {
        let accepted = self.mux.enqueue_data(remote, data);
        accepted
    }

    /// Forwards to the multiplexer's `terminate` with `reason`, stamped with the current time.
    fn terminate(&mut self, remote: &SocketAddr, reason: DisconnectReason) {
        let now = Instant::now();
        self.mux.terminate(remote, reason, now);
    }
}
/// A request queued by a `SoeHandle` and applied to the multiplexer by the driver task.
enum Command {
    /// Forwarded to the multiplexer's `connect` for this remote.
    Connect(SocketAddr),
    /// Forwarded to the multiplexer's `enqueue_data`; the driver discards its `bool` result.
    EnqueueData {
        remote: SocketAddr,
        data: Bytes,
    },
    /// Forwarded to the multiplexer's `terminate` with the given reason.
    Terminate {
        remote: SocketAddr,
        reason: DisconnectReason,
    },
}
/// Cloneable handle for sending commands to a `TokioSoeServer`'s driver task.
///
/// Sending never blocks; the underlying channel is unbounded.
#[derive(Clone, Debug)]
pub struct SoeHandle {
    /// Sending side of the command channel; the receiver lives in the driver task.
    commands: mpsc::UnboundedSender<Command>,
}
impl SoeHandle {
pub fn connect(&self, remote: SocketAddr) -> bool {
self.commands.send(Command::Connect(remote)).is_ok()
}
pub fn enqueue_data(&self, remote: SocketAddr, data: impl Into<Bytes>) -> bool {
self.commands
.send(Command::EnqueueData {
remote,
data: data.into(),
})
.is_ok()
}
pub fn terminate(&self, remote: SocketAddr, reason: DisconnectReason) -> bool {
self.commands
.send(Command::Terminate { remote, reason })
.is_ok()
}
}
/// A UDP SOE endpoint whose multiplexer runs on a spawned Tokio task.
///
/// Commands go in through `SoeHandle`s; events come out through `recv_event`.
#[derive(Debug)]
pub struct TokioSoeServer {
    /// Handle owned by the server itself. It keeps the command channel open while the
    /// server exists; `handle()` hands out clones of it.
    handle: SoeHandle,
    /// Receiving side of the driver task's event channel.
    events: mpsc::UnboundedReceiver<SocketEvent<SocketAddr>>,
    /// Bound address, captured before the socket was moved into the driver task.
    local_addr: SocketAddr,
    /// Join handle of the driver task; used by `abort()`.
    driver: JoinHandle<()>,
}
impl TokioSoeServer {
pub async fn bind(
local: SocketAddr,
config: SocketConfig,
tick_period: Duration,
) -> io::Result<Self> {
let socket = UdpSocket::bind(local).await?;
let local_addr = socket.local_addr()?;
let (command_tx, command_rx) = mpsc::unbounded_channel();
let (event_tx, event_rx) = mpsc::unbounded_channel();
let driver = tokio::spawn(drive_loop(
socket,
config,
tick_period,
command_rx,
event_tx,
));
Ok(Self {
handle: SoeHandle {
commands: command_tx,
},
events: event_rx,
local_addr,
driver,
})
}
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
pub fn handle(&self) -> SoeHandle {
self.handle.clone()
}
pub async fn recv_event(&mut self) -> Option<SocketEvent<SocketAddr>> {
self.events.recv().await
}
pub fn abort(&self) {
self.driver.abort();
}
}
async fn drive_loop(
socket: UdpSocket,
config: SocketConfig,
tick_period: Duration,
mut commands: mpsc::UnboundedReceiver<Command>,
events: mpsc::UnboundedSender<SocketEvent<SocketAddr>>,
) {
let mut mux = SoeMultiplexer::new(config);
let mut tick = interval(tick_period);
tick.set_missed_tick_behavior(MissedTickBehavior::Delay);
let mut buf = vec![0u8; RECV_BUFFER_SIZE].into_boxed_slice();
loop {
tokio::select! {
result = socket.recv_from(&mut buf) => {
match result {
Ok((len, from)) => {
let datagram = Bytes::copy_from_slice(&buf[..len]);
mux.process_incoming(from, datagram, Instant::now());
}
Err(_) => continue,
}
}
_ = tick.tick() => {
mux.run_tick(Instant::now());
}
command = commands.recv() => {
match command {
Some(Command::Connect(remote)) => mux.connect(remote, Instant::now()),
Some(Command::EnqueueData { remote, data }) => {
let _ = mux.enqueue_data(&remote, &data);
}
Some(Command::Terminate { remote, reason }) => {
mux.terminate(&remote, reason, Instant::now());
}
None => break,
}
}
}
for (addr, datagram) in mux.take_outgoing() {
let _ = socket.send_to(&datagram, addr).await;
}
for event in mux.take_events() {
if events.send(event).is_err() {
return;
}
}
}
}