use super::Packet;
use std::{
io::{Error as IoError, ErrorKind},
net::{SocketAddr, UdpSocket},
sync::{
Mutex,
mpsc::{self, Receiver, Sender},
},
time::Duration,
};
/// Receive size passed to `recv_with_size` / `recv_from_with_size` by the default
/// `Socket::recv` and `Socket::recv_from` methods.
// NOTE(review): the `UdpSocket` impl allocates `size + 4` bytes, so this presumably
// counts payload bytes only, excluding a 4-byte packet header — confirm.
const MAX_REQUEST_PACKET_SIZE: usize = 512;
/// Packet-level socket abstraction.
///
/// Implementors move whole [`Packet`] values instead of raw byte buffers. Sending and
/// receiving take `&self`; changing timeouts or blocking mode requires `&mut self`.
pub trait Socket: Send + Sync + 'static {
    /// Sends `packet` to the socket's remote peer.
    fn send(&self, packet: &Packet) -> anyhow::Result<()>;

    /// Sends `packet` to the explicit address `to`.
    fn send_to(&self, packet: &Packet, to: &SocketAddr) -> anyhow::Result<()>;

    /// Receives one packet using the default size `MAX_REQUEST_PACKET_SIZE`.
    fn recv(&self) -> anyhow::Result<Packet> {
        Self::recv_with_size(self, MAX_REQUEST_PACKET_SIZE)
    }

    /// Receives one packet; `size` is the receive size hint handed to the implementation.
    fn recv_with_size(&self, size: usize) -> anyhow::Result<Packet>;

    /// Receives one packet together with the address it is attributed to, using the
    /// default size `MAX_REQUEST_PACKET_SIZE`.
    fn recv_from(&self) -> anyhow::Result<(Packet, SocketAddr)> {
        Self::recv_from_with_size(self, MAX_REQUEST_PACKET_SIZE)
    }

    /// Receives one packet together with the address it is attributed to; `size` is the
    /// receive size hint handed to the implementation.
    fn recv_from_with_size(&self, size: usize) -> anyhow::Result<(Packet, SocketAddr)>;

    /// Returns the address of the remote peer.
    fn remote_addr(&self) -> anyhow::Result<SocketAddr>;

    /// Sets how long a receive may wait before failing.
    fn set_read_timeout(&mut self, dur: Duration) -> anyhow::Result<()>;

    /// Sets how long a send may wait before failing.
    fn set_write_timeout(&mut self, dur: Duration) -> anyhow::Result<()>;

    /// Switches the socket between blocking and non-blocking mode.
    fn set_nonblocking(&mut self, nonblocking: bool) -> anyhow::Result<()>;
}
impl Socket for UdpSocket {
fn send(&self, packet: &Packet) -> anyhow::Result<()> {
self.send(&packet.serialize()?)?;
Ok(())
}
fn send_to(&self, packet: &Packet, to: &SocketAddr) -> anyhow::Result<()> {
self.send_to(&packet.serialize()?, to)?;
Ok(())
}
fn recv_with_size(&self, size: usize) -> anyhow::Result<Packet> {
let mut buf = vec![0; size + 4];
let amt = self.recv(&mut buf)?;
let packet = Packet::deserialize(&buf[..amt])?;
Ok(packet)
}
fn recv_from_with_size(&self, size: usize) -> anyhow::Result<(Packet, SocketAddr)> {
let mut buf = vec![0; size + 4];
let (amt, addr) = self.recv_from(&mut buf)?;
let packet = Packet::deserialize(&buf[..amt])?;
Ok((packet, addr))
}
fn remote_addr(&self) -> anyhow::Result<SocketAddr> {
Ok(self.peer_addr()?)
}
fn set_read_timeout(&mut self, dur: Duration) -> anyhow::Result<()> {
UdpSocket::set_read_timeout(self, Some(dur))?;
Ok(())
}
fn set_write_timeout(&mut self, dur: Duration) -> anyhow::Result<()> {
UdpSocket::set_write_timeout(self, Some(dur))?;
Ok(())
}
fn set_nonblocking(&mut self, nonblocking: bool) -> anyhow::Result<()> {
UdpSocket::set_nonblocking(self, nonblocking)?;
Ok(())
}
}
/// Socket that sends over a UDP socket to a fixed remote address but receives its
/// packets from an in-process channel instead of reading the UDP socket itself.
pub struct ServerSocket {
/// UDP socket used for all outgoing datagrams.
socket: UdpSocket,
/// Peer address: destination of `send` and the address reported by `remote_addr`
/// and `recv_from`.
remote: SocketAddr,
/// Producing half of the packet channel; clones are handed out by `sender()`.
// NOTE(review): presumably wrapped in a `Mutex` so the struct satisfies the
// `Sync` bound of `Socket` — confirm.
sender: Mutex<Sender<Packet>>,
/// Consuming half of the packet channel; every receive call reads from it.
receiver: Mutex<Receiver<Packet>>,
/// Maximum time a blocking receive waits for a packet on the channel.
timeout: Duration,
/// When `true`, receive calls poll the channel and return immediately.
nonblocking: bool,
}
/// [`Socket`] implementation for [`ServerSocket`]: outgoing packets are written to the
/// UDP socket, incoming packets are taken from the in-process channel.
impl Socket for ServerSocket {
    /// Sends `packet` to the stored remote address.
    fn send(&self, packet: &Packet) -> anyhow::Result<()> {
        self.send_to(packet, &self.remote)
    }

    /// Serializes `packet` and sends it to `to` over the UDP socket.
    fn send_to(&self, packet: &Packet, to: &SocketAddr) -> anyhow::Result<()> {
        self.socket.send_to(&packet.serialize()?, to)?;
        Ok(())
    }

    /// Takes the next packet from the channel. The size hint is ignored because the
    /// channel carries already-parsed packets.
    ///
    /// # Errors
    ///
    /// * an `io::Error` of kind `WouldBlock` when in non-blocking mode and no packet
    ///   is available;
    /// * an `io::Error` of kind `TimedOut` when in blocking mode and no packet arrives
    ///   within the configured read timeout;
    /// * a plain error when the channel is disconnected or the receiver mutex is
    ///   poisoned.
    fn recv_with_size(&self, _size: usize) -> anyhow::Result<Packet> {
        let receiver = self
            .receiver
            .lock()
            .map_err(|_| anyhow::anyhow!("Failed to lock mutex"))?;

        if self.nonblocking {
            return match receiver.try_recv() {
                Ok(packet) => Ok(packet),
                Err(_) => Err(IoError::from(ErrorKind::WouldBlock).into()),
            };
        }

        match receiver.recv_timeout(self.timeout) {
            Ok(packet) => Ok(packet),
            // Report the timeout as a typed I/O error (same message as before) so that
            // callers can detect it by downcasting, just like a timeout on a real
            // `UdpSocket` and like the non-blocking branch above.
            Err(mpsc::RecvTimeoutError::Timeout) => {
                Err(IoError::new(ErrorKind::TimedOut, "Failed to receive").into())
            }
            Err(mpsc::RecvTimeoutError::Disconnected) => {
                Err(anyhow::anyhow!("Failed to receive"))
            }
        }
    }

    /// Takes the next packet from the channel and pairs it with the stored remote
    /// address. The size hint is ignored, see `recv_with_size`.
    fn recv_from_with_size(&self, size: usize) -> anyhow::Result<(Packet, SocketAddr)> {
        let packet = self.recv_with_size(size)?;
        Ok((packet, self.remote))
    }

    /// Returns the stored remote address; never fails.
    fn remote_addr(&self) -> anyhow::Result<SocketAddr> {
        Ok(self.remote)
    }

    /// Sets how long a blocking receive waits on the channel. The UDP socket's own
    /// read timeout is not touched because receives never read from it.
    fn set_read_timeout(&mut self, dur: Duration) -> anyhow::Result<()> {
        self.timeout = dur;
        Ok(())
    }

    /// Sets the write timeout of the underlying UDP socket.
    fn set_write_timeout(&mut self, dur: Duration) -> anyhow::Result<()> {
        self.socket.set_write_timeout(Some(dur))?;
        Ok(())
    }

    /// Switches both the UDP socket and channel receives between blocking and
    /// non-blocking mode.
    fn set_nonblocking(&mut self, nonblocking: bool) -> anyhow::Result<()> {
        // Update the flag only after the socket call succeeded, so a failure cannot
        // leave the channel mode and the socket mode out of sync.
        self.socket.set_nonblocking(nonblocking)?;
        self.nonblocking = nonblocking;
        Ok(())
    }
}
impl ServerSocket {
    /// Builds a `ServerSocket` that sends through `socket` to `remote` and waits up to
    /// `timeout` for each incoming packet. A fresh packet channel is created and the
    /// socket starts in blocking mode.
    pub fn new(socket: UdpSocket, remote: SocketAddr, timeout: Duration) -> Self {
        let (tx, rx) = mpsc::channel::<Packet>();
        Self {
            socket,
            remote,
            sender: Mutex::new(tx),
            receiver: Mutex::new(rx),
            timeout,
            nonblocking: false,
        }
    }

    /// Returns a new handle to the sending half of the packet channel; packets sent
    /// through it are delivered by this socket's receive methods.
    ///
    /// # Panics
    ///
    /// Panics if the sender mutex is poisoned.
    pub fn sender(&self) -> Sender<Packet> {
        let guard = self.sender.lock().unwrap();
        Sender::<Packet>::clone(&guard)
    }
}
/// Forwards every [`Socket`] method to the boxed value, so a `Box<T>` (including a
/// boxed trait object) can be used wherever a `Socket` is expected.
impl<T: Socket + ?Sized> Socket for Box<T> {
    /// Delegates to the boxed socket's `send`.
    fn send(&self, packet: &Packet) -> anyhow::Result<()> {
        T::send(&**self, packet)
    }

    /// Delegates to the boxed socket's `send_to`.
    fn send_to(&self, packet: &Packet, to: &SocketAddr) -> anyhow::Result<()> {
        T::send_to(&**self, packet, to)
    }

    /// Delegates to the boxed socket's `recv_with_size`.
    fn recv_with_size(&self, size: usize) -> anyhow::Result<Packet> {
        T::recv_with_size(&**self, size)
    }

    /// Delegates to the boxed socket's `recv_from_with_size`.
    fn recv_from_with_size(&self, size: usize) -> anyhow::Result<(Packet, SocketAddr)> {
        T::recv_from_with_size(&**self, size)
    }

    /// Delegates to the boxed socket's `remote_addr`.
    fn remote_addr(&self) -> anyhow::Result<SocketAddr> {
        T::remote_addr(&**self)
    }

    /// Delegates to the boxed socket's `set_read_timeout`.
    fn set_read_timeout(&mut self, dur: Duration) -> anyhow::Result<()> {
        T::set_read_timeout(&mut **self, dur)
    }

    /// Delegates to the boxed socket's `set_write_timeout`.
    fn set_write_timeout(&mut self, dur: Duration) -> anyhow::Result<()> {
        T::set_write_timeout(&mut **self, dur)
    }

    /// Delegates to the boxed socket's `set_nonblocking`.
    fn set_nonblocking(&mut self, nonblocking: bool) -> anyhow::Result<()> {
        T::set_nonblocking(&mut **self, nonblocking)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Packets pushed into the channel come back out of `recv` unchanged and in order.
    #[test]
    fn test_recv() {
        let local = UdpSocket::bind("127.0.0.1:0").unwrap();
        let remote = SocketAddr::from_str("127.0.0.1:50000").unwrap();
        let socket = ServerSocket::new(local, remote, Duration::from_secs(3));
        let tx = socket.sender();

        tx.send(Packet::Ack(1)).unwrap();
        let received = socket.recv().unwrap();
        assert_eq!(received, Packet::Ack(1));

        tx.send(Packet::Data {
            block_num: 15,
            data: vec![0x01, 0x02, 0x03],
        })
        .unwrap();
        let received = socket.recv().unwrap();
        let expected = Packet::Data {
            block_num: 15,
            data: vec![0x01, 0x02, 0x03],
        };
        assert_eq!(received, expected);
    }
}