use std::io::{Error, ErrorKind};
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::mpsc::TryRecvError;
use std::collections::{HashMap, VecDeque};
use shared::stats::{Stats, StatsCollector};
use shared::ticker::Ticker;
use super::{
Config,
ConnectionID, Connection, ConnectionEvent,
RateLimiter, PacketModifier, Socket
};
/// Events handed to the application by `Server::accept_receive`.
///
/// Each variant is produced from the matching `ConnectionEvent` of a single
/// connection and carries the `ConnectionID` of that connection.
#[derive(Debug, PartialEq)]
pub enum ServerEvent {
/// Produced from `ConnectionEvent::Connected`.
Connection(ConnectionID),
/// Produced from `ConnectionEvent::Lost`; the flag is that event's
/// `by_remote` value (presumably `true` when the remote side caused the loss
/// -- confirm against `Connection`).
ConnectionLost(ConnectionID, bool),
/// Produced from `ConnectionEvent::Closed`; the flag is that event's
/// `by_remote` value (presumably `true` when the remote side closed the
/// connection -- confirm against `Connection`).
ConnectionClosed(ConnectionID, bool),
/// Produced from `ConnectionEvent::Message`; carries the message payload.
Message(ConnectionID, Vec<u8>),
/// Produced from `ConnectionEvent::CongestionStateChanged`; the flag is
/// forwarded unchanged (NOTE(review): meaning of `true` is defined by
/// `Connection`, not visible here).
ConnectionCongestionStateChanged(ConnectionID, bool),
/// Produced from `ConnectionEvent::PacketLost`; carries the payload reported
/// with that event.
PacketLost(ConnectionID, Vec<u8>)
}
/// A server that owns one socket and one `Connection` per `ConnectionID`
/// seen in incoming packets.
///
/// `S` is the socket implementation; `R` and `M` are the rate limiter and
/// packet modifier instantiated for every new connection.
#[derive(Debug)]
pub struct Server<S: Socket, R: RateLimiter, M: PacketModifier> {
// Active configuration; also handed to the ticker, the stats collector and
// every connection.
config: Config,
// `None` until `listen` succeeds and again after `shutdown`.
socket: Option<S>,
// All known connections, keyed by the id extracted from their packets.
connections: HashMap<ConnectionID, Connection<R, M>>,
// Address that `send` passes to `send_packet` for each connection; updated
// when a packet for a known connection arrives from a different address.
addresses: HashMap<ConnectionID, SocketAddr>,
// Ids of connections found to be no longer open during `send`; they are
// removed at the start of the following `send` call.
dropped: Vec<ConnectionID>,
// Tick timing: `begin_tick` runs in `accept_receive`, `end_tick` in `send`
// when `auto_tick` is set.
ticker: Ticker,
// First address resolved from the argument of `listen`; `None` while not
// listening.
local_address: Option<SocketAddr>,
// Events waiting to be returned one at a time by `accept_receive`.
events: VecDeque<ServerEvent>,
// Set by `listen` and `send`; `accept_receive` only reads from the socket
// while this is set and clears it afterwards.
should_receive: bool,
// Receives the per-tick byte counts from `accept_receive` and `send`.
stats_collector: StatsCollector,
// Latest value of `stats_collector.average()`, refreshed by `send`.
stats: Stats
}
impl<S: Socket, R: RateLimiter, M: PacketModifier> Server<S, R, M> {
pub fn new(config: Config) -> Server<S, R, M> {
Server {
config: config,
socket: None,
connections: HashMap::new(),
addresses: HashMap::new(),
dropped: Vec::new(),
ticker: Ticker::new(config),
local_address: None,
events: VecDeque::new(),
should_receive: false,
stats_collector: StatsCollector::new(config),
stats: Stats {
bytes_sent: 0,
bytes_received: 0
}
}
}
pub fn bytes_sent(&self) -> u32 {
self.stats.bytes_sent
}
pub fn bytes_received(&self) -> u32 {
self.stats.bytes_received
}
pub fn local_addr(&self) -> Result<SocketAddr, Error> {
self.local_address.ok_or_else(|| Error::new(ErrorKind::AddrNotAvailable, ""))
}
pub fn connection(&mut self, id: &ConnectionID) -> Result<&mut Connection<R, M>, Error> {
if self.socket.is_some() {
if let Some(conn) = self.connections.get_mut(id) {
Ok(conn)
} else {
Err(Error::new(ErrorKind::NotFound, ""))
}
} else {
Err(Error::new(ErrorKind::NotConnected, ""))
}
}
pub fn connections(&mut self) -> &mut HashMap<ConnectionID, Connection<R, M>> {
&mut self.connections
}
pub fn socket(&mut self) -> Result<&mut S, Error> {
if let Some(socket) = self.socket.as_mut() {
Ok(socket)
} else {
Err(Error::new(ErrorKind::NotConnected, ""))
}
}
pub fn config(&self) -> Config {
self.config
}
pub fn set_config(&mut self, config: Config) {
self.config = config;
self.ticker.set_config(config);
self.stats_collector.set_config(config);
for conn in self.connections.values_mut() {
conn.set_config(config);
}
}
pub fn listen<A: ToSocketAddrs>(&mut self, addr: A) -> Result<(), Error> {
if self.socket.is_none() {
let local_addr = try!(addr.to_socket_addrs()).nth(0).unwrap();
let socket = try!(S::new(
local_addr,
self.config.packet_max_size
));
self.socket = Some(socket);
self.local_address = Some(local_addr);
self.should_receive = true;
Ok(())
} else {
Err(Error::new(ErrorKind::AlreadyExists, ""))
}
}
pub fn accept_receive(&mut self) -> Result<ServerEvent, TryRecvError> {
if self.socket.is_none() {
Err(TryRecvError::Disconnected)
} else {
if self.should_receive {
self.ticker.begin_tick();
let mut bytes_received = 0;
while let Ok((addr, packet)) = self.socket.as_mut().unwrap().try_recv() {
if let Some(id) = Connection::<R, M>::id_from_packet(&self.config, &packet) {
bytes_received += self.receive_connection_packet(id, addr, packet);
}
}
self.stats_collector.set_bytes_received(bytes_received as u32);
self.should_receive = false;
}
if let Some(event) = self.events.pop_front() {
Ok(event)
} else {
Err(TryRecvError::Empty)
}
}
}
pub fn send(&mut self, auto_tick: bool) -> Result<(), Error> {
if self.socket.is_some() {
for id in self.dropped.drain(0..) {
self.connections.remove(&id).unwrap().reset();
self.addresses.remove(&id);
}
let mut bytes_sent = 0;
for (id, connection) in &mut self.connections {
let addr = &self.addresses[id];
bytes_sent += connection.send_packet(
self.socket.as_mut().unwrap(),
addr
);
if !connection.open() {
map_connection_events(&mut self.events, connection);
self.dropped.push(*id);
}
}
self.stats_collector.set_bytes_sent(bytes_sent);
self.stats_collector.tick();
self.stats = self.stats_collector.average();
self.should_receive = true;
if auto_tick {
self.ticker.end_tick();
}
Ok(())
} else {
Err(Error::new(ErrorKind::NotConnected, ""))
}
}
pub fn shutdown(&mut self) -> Result<(), Error> {
if self.socket.is_some() {
self.should_receive = false;
self.stats_collector.reset();
self.stats.reset();
self.events.clear();
self.connections.clear();
self.addresses.clear();
self.dropped.clear();
self.ticker.reset();
self.local_address = None;
self.socket = None;
Ok(())
} else {
Err(Error::new(ErrorKind::NotConnected, ""))
}
}
fn receive_connection_packet(
&mut self,
id: ConnectionID,
addr: SocketAddr,
packet: Vec<u8>
) -> usize {
let packet_length = packet.len();
if self.connections.contains_key(&id) {
let connection = self.connections.get_mut(&id).unwrap();
if connection.receive_packet(packet) && addr != connection.peer_addr() {
connection.set_peer_addr(addr);
self.addresses.remove(&id);
self.addresses.insert(id, addr);
}
map_connection_events(&mut self.events, connection);
packet_length
} else {
let mut conn = Connection::new(
self.config,
self.local_address.unwrap(),
addr,
R::new(self.config),
M::new(self.config)
);
conn.set_id(id);
self.connections.insert(id, conn);
self.addresses.insert(id, addr);
let connection = self.connections.get_mut(&id).unwrap();
if connection.receive_packet(packet) {
map_connection_events(&mut self.events, connection);
}
packet_length
}
}
}
/// Appends the events returned by `connection.events()` to `server_events`,
/// translating each `ConnectionEvent` into the `ServerEvent` of the same
/// meaning tagged with the connection's id.
///
/// # Panics
///
/// Panics on `ConnectionEvent::FailedToConnect`, which is treated as
/// unreachable here.
fn map_connection_events<R: RateLimiter, M: PacketModifier>(
    server_events: &mut VecDeque<ServerEvent>,
    connection: &mut Connection<R, M>
) {
    let connection_id = connection.id();

    let translated = connection.events().into_iter().map(|connection_event| {
        match connection_event {
            ConnectionEvent::Connected => {
                ServerEvent::Connection(connection_id)
            },
            ConnectionEvent::Lost(by_remote) => {
                ServerEvent::ConnectionLost(connection_id, by_remote)
            },
            ConnectionEvent::FailedToConnect => unreachable!(),
            ConnectionEvent::Closed(by_remote) => {
                ServerEvent::ConnectionClosed(connection_id, by_remote)
            },
            ConnectionEvent::Message(payload) => {
                ServerEvent::Message(connection_id, payload)
            },
            ConnectionEvent::CongestionStateChanged(congested) => {
                ServerEvent::ConnectionCongestionStateChanged(connection_id, congested)
            },
            ConnectionEvent::PacketLost(payload) => {
                ServerEvent::PacketLost(connection_id, payload)
            },
        }
    });

    server_events.extend(translated);
}