use std::collections::HashMap;
use std::io::ErrorKind;
use std::net::{Shutdown, SocketAddr, TcpListener, ToSocketAddrs};
use std::time::{Duration};
use anyhow::Result;
use serde::de::DeserializeOwned;
use serde::Serialize;
use socket2::{Domain, Protocol, Socket, Type};
use crate::*;
/// Options applied to the listening socket when a `Server` is bound.
///
/// The default configuration leaves every option unset.
#[derive(Clone, Debug, Default)]
pub struct ServerConfig {
    /// Value handed to the socket's keep-alive setting by `Server::bind`.
    /// Defaults to `None`.
    pub keep_alive_timeout: Option<Duration>,
}
/// TCP server that owns a listening socket and the connections accepted from it.
#[derive(Debug)]
pub struct Server {
/// Accepted connections, keyed by the peer address returned when they were accepted.
connections: HashMap<SocketAddr, Connection>,
/// Listening socket from which new connections are accepted.
pub listener: TcpListener,
}
impl Server {
pub fn accept(&mut self) -> Result<Option<SocketAddr>> {
self.listener.set_nonblocking(true)?;
match self.listener.accept() {
Ok((stream, addr)) => {
let connection = Connection { stream, addr };
self.connections.insert(addr, connection);
Ok(Some(addr))
}
Err(err) => {
if err.kind() == ErrorKind::WouldBlock {
Ok(None)
} else {
Err(anyhow::Error::from(err))
}
}
}
}
/// Accepts every connection that is currently pending, without blocking.
///
/// Calls `accept` repeatedly until it reports that nothing is pending.
///
/// # Returns
/// The peer addresses accepted during this call, in acceptance order.
///
/// # Errors
/// The first error returned by `accept`; addresses collected before the
/// error are not returned.
pub fn accept_all(&mut self) -> Result<Vec<SocketAddr>> {
    let mut accepted = Vec::new();
    while let Some(addr) = self.accept()? {
        accepted.push(addr);
    }
    Ok(accepted)
}
/// Waits until a connection arrives and accepts it.
///
/// The listener is switched to blocking mode first. The accepted connection
/// is stored under its peer address.
///
/// # Returns
/// The peer address of the accepted connection.
///
/// # Errors
/// Any I/O error from configuring the listener or accepting.
pub fn accept_blocking(&mut self) -> Result<SocketAddr> {
    self.listener.set_nonblocking(false)?;
    let (stream, peer) = self.listener.accept()?;
    self.connections.insert(peer, Connection { stream, addr: peer });
    Ok(peer)
}
pub fn bind<B: ToSocketAddrs>(addr: B, config: &ServerConfig) -> Result<Server> {
let socket_addr: SocketAddr = addr.to_socket_addrs()?.nth(0).unwrap();
let domain = if socket_addr.is_ipv4() {
Domain::ipv4()
} else {
Domain::ipv6()
};
let socket_type = Type::stream();
let socket = Socket::new(domain, socket_type,
Some(Protocol::tcp()))?;
socket.set_keepalive(config.keep_alive_timeout)?;
socket.bind(&socket_addr.into())?;
socket.listen(128)?;
let listener = socket.into_tcp_listener();
Ok(Server { connections: HashMap::new(), listener })
}
pub fn block_until_receive_from(&mut self, addr: SocketAddr, timeout: Duration) -> Result<PacketReceiveStatus> {
self.with_connection(addr, |conn| {
crate::block_until_receive(&mut conn.stream, timeout)
})
}
/// Returns references to all currently registered connections.
///
/// The order follows the internal hash map and is therefore unspecified.
pub fn connections(&self) -> Vec<&Connection> {
    let mut open = Vec::with_capacity(self.connections.len());
    open.extend(self.connections.values());
    open
}
pub fn receive<A: Serialize + DeserializeOwned>(&mut self) -> Result<Option<(SocketAddr, A)>> {
for conn in self.connections.values_mut() {
let result: Option<A> = read_packet::<A>(&mut conn.stream, false)?;
if result.is_some() {
return Ok(Some((conn.addr, result.unwrap())))
}
}
Ok(None)
}
/// Collects packets by calling `receive` until it reports none are left.
///
/// # Returns
/// Every `(addr, packet)` pair obtained during this call, in the order
/// `receive` produced them.
///
/// # Errors
/// The first error returned by `receive`; packets collected before the
/// error are not returned.
pub fn receive_all<A: Serialize + DeserializeOwned>(&mut self) -> Result<Vec<(SocketAddr, A)>> {
    let mut packets = Vec::new();
    while let Some(packet) = self.receive::<A>()? {
        packets.push(packet);
    }
    Ok(packets)
}
pub fn receive_from<A: Serialize + DeserializeOwned>(&mut self, addr: SocketAddr) -> Result<Option<A>> {
self.with_connection(addr, |conn| {
crate::read_packet::<A>(&mut conn.stream, false)
})
}
pub fn send<A: Serialize + DeserializeOwned>(&mut self, addr: SocketAddr, packet: &A) -> Result<()> {
self.with_connection(addr, |conn| {
write_packet(&mut conn.stream, packet)?;
Ok(())
})
}
pub fn send_global<A: Serialize + DeserializeOwned>(&mut self, packet: &A) -> Result<()> {
for conn in self.connections.values_mut() {
write_packet(&mut conn.stream, packet)?;
}
Ok(())
}
/// Consumes the server and shuts down both directions of every registered
/// connection.
///
/// # Errors
/// The first I/O error returned by a stream shutdown; connections after the
/// failing one are not explicitly shut down by this method.
pub fn shutdown(mut self) -> Result<()> {
    self.connections
        .values_mut()
        .try_for_each(|conn| conn.stream.shutdown(Shutdown::Both))?;
    Ok(())
}
pub fn with_connection<A, F>(&mut self, addr: SocketAddr, f: F) -> Result<A> where
F: FnOnce(&mut Connection) -> Result<A> {
if let Some(conn) = self.connections.get_mut(&addr) {
match f(conn) {
Ok(res) => Ok(res),
Err(err) => {
if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
let kind = io_err.kind();
if kind == ErrorKind::ConnectionAborted {
self.connections.remove(&addr);
}
}
Err(err)
}
}
} else {
let kind = ErrorKind::NotConnected;
Err(anyhow::Error::from(std::io::Error::new(kind, "Connection not found.")))
}
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use super::*;
use crate::client::*;
/// Broadcasts one packet to a single connected client and checks that the
/// client receives the same value.
#[test]
pub fn test_server_send_packet() -> Result<()> {
    let addr = "localhost:60003";
    let mut server = Server::bind(addr, &ServerConfig::default())?;
    let mut client = Client::connect(addr, &ClientConfig::default())?;
    // Only the side effect of registering the connection is needed here.
    let _client_addr = server.accept_blocking()?;
    let packet = 42;
    // Propagate a failed broadcast so the test fails at the cause instead of
    // at the later receive.
    server.send_global(&packet)?;
    client.block_until_receive(Duration::from_millis(2000));
    let received_packet = client.receive()?
        .expect("Unable to read packet.");
    assert_eq!(packet, received_packet);
    Ok(())
}
/// A non-blocking accept on a server nobody connected to yields `None`.
#[test]
pub fn test_server_accept_non_blocking() -> Result<()> {
    // No client needs to know the port in this test, so let the OS pick a
    // free one (port 0). The previous fixed port 60003 is also bound by
    // `test_server_send_packet`, so the two tests could fail with
    // "address in use".
    let addr = "localhost:0";
    let mut server = Server::bind(addr, &ServerConfig::default())?;
    let client_addr = server.accept()?;
    assert_eq!(None, client_addr);
    Ok(())
}
/// Waiting on a silent client must report `TimedOut` after roughly the
/// requested timeout (within 5 ms).
#[test]
pub fn test_server_block_until_receive_from_timeout() -> Result<()> {
    let addr = "localhost:60002";
    let mut server = Server::bind(addr, &ServerConfig::default())?;
    // Bound to a named variable so it is not dropped before the test ends.
    let _client = Client::connect(addr, &ClientConfig::default())?;
    let peer = server.accept_blocking()?;
    let timeout_ms = 200;
    let started = Instant::now();
    let status = server.block_until_receive_from(peer, Duration::from_millis(timeout_ms))?;
    let elapsed_ms = started.elapsed().as_millis();
    assert_eq!(PacketReceiveStatus::TimedOut, status);
    let drift = ((elapsed_ms as i64) - (timeout_ms as i64)).abs();
    assert!(drift < 5);
    Ok(())
}
/// After the client shuts down, sending to it must fail and the server must
/// drop the connection from its registry.
#[test]
pub fn test_server_connection_lost() -> Result<()> {
    let addr = "localhost:60007";
    let mut server = Server::bind(addr, &ServerConfig::default())?;
    let mut client = Client::connect(addr, &ClientConfig::default())?;
    let peer = server.accept_blocking()?;
    assert_eq!(1, server.connections().len());

    client.shutdown();
    // Give the shutdown time to reach the server side.
    std::thread::sleep(Duration::from_secs(1));

    let send_result = server.send(peer, &42);
    send_result.expect_err("Expected Err here.");
    assert_eq!(0, server.connections().len());
    Ok(())
}
/// `receive` returns `None` when there is nothing to read.
///
/// Note: the server never accepts the client here, so `receive` runs with no
/// registered connections.
#[test]
pub fn test_server_non_blocking_receive_fail() -> Result<()> {
    let addr = "localhost:60006";
    let mut server = Server::bind(addr, &ServerConfig::default())?;
    let _client = Client::connect(addr, &ClientConfig::default())?;
    let polled = server.receive::<Vec<u8>>()?;
    assert_eq!(None, polled);
    Ok(())
}
/// After the server shuts down, a client send must fail and the client must
/// report itself as disconnected.
#[test]
pub fn test_server_shutdown() -> Result<()> {
    let addr = "localhost:60004";
    let mut server = Server::bind(addr, &ServerConfig::default())?;
    let mut client = Client::connect(addr, &ClientConfig::default())?;
    // Only the side effect of registering the connection is needed here.
    let _client_addr = server.accept_blocking()?;
    assert_eq!(1, server.connections().len());
    // Propagate a failed shutdown; otherwise the assertions below would run
    // against a server that never closed its connections.
    server.shutdown()?;
    // Give the shutdown time to reach the client side.
    std::thread::sleep(Duration::from_secs(1));
    client.send(&42)
        .expect_err("Expected Err here.");
    assert!(!client.is_connected());
    Ok(())
}
}