use std::collections::{HashMap, VecDeque};
use std::net::SocketAddr;
use std::time::{SystemTime, UNIX_EPOCH};
use tracing::{debug, error, trace};
use crate::serialize::reader::ReadBuffer;
use crate::serialize::wordbuffer::reader::ReadWordBuffer;
use crate::transport::io::Io;
use crate::transport::{PacketReceiver, PacketSender, Transport};
use super::{
bytes::Bytes,
crypto::{self, Key},
error::{Error, Result},
packet::{
ChallengePacket, DeniedPacket, DisconnectPacket, KeepAlivePacket, Packet, PayloadPacket,
RequestPacket, ResponsePacket,
},
replay::ReplayProtection,
token::{ChallengeToken, ConnectToken, ConnectTokenBuilder, ConnectTokenPrivate},
MAC_BYTES, MAX_PACKET_SIZE, MAX_PKT_BUF_SIZE, PACKET_SEND_RATE_SEC,
};
/// Maximum number of simultaneously connected clients; connection requests and
/// responses are denied once this many clients are connected.
pub const MAX_CLIENTS: usize = 256;

/// Number of recently used connect tokens the server remembers for replay detection.
/// Bounded so the history has a fixed upper size; once full, the oldest entry is evicted.
const MAX_TOKEN_ENTRIES: usize = MAX_CLIENTS * 4;

/// A connect token the server has already seen.
#[derive(Clone, Copy)]
struct TokenEntry {
    /// Server time at which the token was first recorded.
    time: f64,
    /// Trailing MAC bytes of the private connect token data; used as the token's identity.
    mac: [u8; 16],
    /// Address that first presented the token.
    addr: SocketAddr,
}

/// Bounded history of recently used connect tokens.
///
/// Used to reject a connect token that is replayed from a different address than the
/// one that first used it.
struct TokenEntries {
    inner: Vec<TokenEntry>,
    /// Maximum number of entries kept before the oldest one is evicted (always >= 1).
    capacity: usize,
}

impl TokenEntries {
    /// Creates an empty history holding up to [`MAX_TOKEN_ENTRIES`] tokens.
    fn new() -> Self {
        Self::with_capacity(MAX_TOKEN_ENTRIES)
    }

    /// Creates an empty history holding up to `capacity` tokens (a capacity of 0 is
    /// treated as 1 so eviction always has a slot to reuse).
    fn with_capacity(capacity: usize) -> Self {
        Self {
            inner: Vec::new(),
            capacity: capacity.max(1),
        }
    }

    /// Records `entry` unless its token was already seen.
    ///
    /// Returns `true` if the token may be used: either it is new (and is now recorded),
    /// or it was already recorded for the same address. Returns `false` if the token was
    /// already used from a different address.
    ///
    /// When the history is full, the entry with the smallest `time` is replaced.
    fn find_or_insert(&mut self, entry: TokenEntry) -> bool {
        // A MAC is only ever inserted once, so there is at most one match.
        if let Some(existing) = self.inner.iter().find(|saved| saved.mac == entry.mac) {
            // Allow reusing a token only from the address that first presented it.
            return existing.addr == entry.addr;
        }

        if self.inner.len() < self.capacity {
            self.inner.push(entry);
            return true;
        }

        // History is full: evict the oldest remembered token.
        let oldest = self
            .inner
            .iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| a.time.total_cmp(&b.time))
            .map(|(idx, _)| idx)
            .expect("token history is full, so it holds at least one entry");
        self.inner[oldest] = entry;
        true
    }
}
/// Server-side state for one client, whether its handshake is still pending or complete.
#[derive(Debug, Clone, Copy)]
struct Connection {
    /// Set by `confirm` the first time the server hears from the client after registering it.
    confirmed: bool,
    /// Set by `connect` once the connection response has been accepted.
    connected: bool,
    client_id: ClientId,
    addr: SocketAddr,
    /// Timeout in seconds from the connect token; non-positive values disable timing out.
    timeout: i32,
    last_access_time: f64,
    last_send_time: f64,
    last_receive_time: f64,
    /// Key used to encrypt packets sent to this client.
    send_key: Key,
    /// Key used to decrypt packets received from this client.
    receive_key: Key,
    /// Sequence number of the next packet sent to this client.
    sequence: u64,
}

impl Connection {
    /// Returns whether the client has been confirmed.
    fn is_confirmed(&self) -> bool {
        self.confirmed
    }

    /// Returns whether the client has completed the connection handshake.
    fn is_connected(&self) -> bool {
        self.connected
    }

    /// Flags the client as confirmed.
    fn confirm(&mut self) {
        self.confirmed = true
    }

    /// Flags the client as connected.
    fn connect(&mut self) {
        self.connected = true
    }
}
/// Identifier of a client, taken from its connect token.
pub type ClientId = u64;

/// Every connection known to the server, pending or connected, plus the per-client
/// state indexed by client id.
struct ConnectionCache {
    /// Connections keyed by client id.
    clients: HashMap<ClientId, Connection>,
    /// Reverse lookup from a client's address to its id.
    client_id_map: HashMap<SocketAddr, ClientId>,
    /// Per-client replay protection state used when reading that client's packets.
    replay_protection: HashMap<ClientId, ReplayProtection>,
    /// Received payloads waiting to be drained by the server's `recv`.
    packet_queue: VecDeque<(ReadWordBuffer, ClientId)>,
    /// Cache-local clock, advanced by `update`.
    time: f64,
}

impl ConnectionCache {
    /// Creates an empty cache whose clock starts at `server_time`.
    fn new(server_time: f64) -> Self {
        Self {
            clients: HashMap::with_capacity(MAX_CLIENTS),
            client_id_map: HashMap::with_capacity(MAX_CLIENTS),
            replay_protection: HashMap::with_capacity(MAX_CLIENTS),
            packet_queue: VecDeque::with_capacity(MAX_CLIENTS * 2),
            time: server_time,
        }
    }

    /// Registers (or refreshes) a pending connection for `client_id` at `addr`.
    ///
    /// If a connection already exists at `addr`, it is updated in place with the new
    /// id, timeout, and keys, and is re-keyed under `client_id`. Its replay state moves
    /// with it. Otherwise a fresh, unconfirmed and unconnected connection is inserted.
    ///
    /// Callers are expected to have rejected ids and addresses that are already
    /// *connected*; this method does not check that.
    fn add(
        &mut self,
        client_id: ClientId,
        addr: SocketAddr,
        timeout: i32,
        send_key: Key,
        receive_key: Key,
    ) {
        // If this id was registered from another address, that address must no longer
        // resolve to `client_id` once the connection below replaces it.
        if let Some(previous) = self.clients.get(&client_id) {
            if previous.addr != addr {
                self.client_id_map.remove(&previous.addr);
            }
        }

        // Refresh the connection already registered at this address. It must be
        // mutated in the map itself; `find_by_addr` only returns a copy.
        if let Some(old_id) = self.client_id_map.get(&addr).copied() {
            if let Some(mut existing) = self.clients.remove(&old_id) {
                existing.client_id = client_id;
                existing.timeout = timeout;
                existing.send_key = send_key;
                existing.receive_key = receive_key;
                existing.last_access_time = self.time;
                let replay = self
                    .replay_protection
                    .remove(&old_id)
                    .unwrap_or_else(ReplayProtection::new);
                self.clients.insert(client_id, existing);
                self.replay_protection.insert(client_id, replay);
                self.client_id_map.insert(addr, client_id);
                return;
            }
        }

        let conn = Connection {
            confirmed: false,
            connected: false,
            client_id,
            addr,
            timeout,
            last_access_time: self.time,
            last_send_time: f64::NEG_INFINITY,
            last_receive_time: f64::NEG_INFINITY,
            send_key,
            receive_key,
            sequence: 0,
        };
        self.clients.insert(client_id, conn);
        self.replay_protection
            .insert(client_id, ReplayProtection::new());
        self.client_id_map.insert(addr, client_id);
    }

    /// Removes a *connected* client together with its address mapping and replay state.
    ///
    /// Unknown ids and pending (not yet connected) connections are left untouched.
    fn remove(&mut self, client_id: ClientId) {
        let Some(conn) = self.clients.get(&client_id) else {
            return;
        };
        if !conn.is_connected() {
            return;
        }
        self.client_id_map.remove(&conn.addr);
        self.replay_protection.remove(&client_id);
        self.clients.remove(&client_id);
    }

    /// Snapshot of all known client ids (pending and connected).
    fn ids(&self) -> Vec<ClientId> {
        self.clients.keys().copied().collect()
    }

    /// Looks up the connection registered at `addr`, returning its id and a copy of it.
    fn find_by_addr(&self, addr: &SocketAddr) -> Option<(ClientId, Connection)> {
        self.client_id_map
            .get(addr)
            .and_then(|id| self.clients.get(id).map(|conn| (*id, *conn)))
    }

    /// Returns a copy of the connection for `client_id`, if any.
    fn find_by_id(&self, client_id: ClientId) -> Option<Connection> {
        self.clients.get(&client_id).copied()
    }

    /// Advances the cache clock by `delta_ms`.
    ///
    /// NOTE(review): despite the name, the server compares this clock against values in
    /// seconds (timeouts, keep-alive rate), so callers presumably pass seconds.
    fn update(&mut self, delta_ms: f64) {
        self.time += delta_ms;
    }
}
/// User callback invoked with a client id and mutable access to the server's context.
pub type Callback<Ctx> = Box<dyn FnMut(ClientId, &mut Ctx) + Send + Sync + 'static>;

/// Configuration for a [`Server`], built with chained setter methods.
pub struct ServerConfig<Ctx> {
    /// How many disconnect packets are sent when the server disconnects a client.
    num_disconnect_packets: usize,
    /// Minimum time, in seconds, between keep-alive packets sent to a connected client.
    keep_alive_send_rate: f64,
    /// User state handed to the connect/disconnect callbacks.
    context: Ctx,
    /// Called after a client completes the connection handshake.
    on_connect: Option<Callback<Ctx>>,
    /// Called when a connected client disconnects or times out.
    on_disconnect: Option<Callback<Ctx>>,
}

impl Default for ServerConfig<()> {
    fn default() -> Self {
        Self::with_context(())
    }
}

impl ServerConfig<()> {
    /// Creates a configuration with default settings and no user context.
    ///
    /// This lives on `ServerConfig<()>` so that `ServerConfig::new()` can infer its type
    /// parameter; in a generic impl the unused `Ctx` could not be inferred.
    pub fn new() -> ServerConfig<()> {
        ServerConfig::<()>::default()
    }
}

impl<Ctx> ServerConfig<Ctx> {
    /// Creates a configuration with default settings and the given user context.
    ///
    /// Defaults: 10 disconnect packets, keep-alive rate `PACKET_SEND_RATE_SEC`, no callbacks.
    pub fn with_context(ctx: Ctx) -> Self {
        Self {
            num_disconnect_packets: 10,
            keep_alive_send_rate: PACKET_SEND_RATE_SEC,
            context: ctx,
            on_connect: None,
            on_disconnect: None,
        }
    }

    /// Sets how many disconnect packets are sent when disconnecting a client.
    pub fn num_disconnect_packets(mut self, num: usize) -> Self {
        self.num_disconnect_packets = num;
        self
    }

    /// Sets the minimum interval, in seconds, between keep-alive packets.
    pub fn keep_alive_send_rate(mut self, rate_seconds: f64) -> Self {
        self.keep_alive_send_rate = rate_seconds;
        self
    }

    /// Sets the callback invoked when a client connects.
    pub fn on_connect<F>(mut self, cb: F) -> Self
    where
        F: FnMut(ClientId, &mut Ctx) + Send + Sync + 'static,
    {
        self.on_connect = Some(Box::new(cb));
        self
    }

    /// Sets the callback invoked when a connected client disconnects or times out.
    pub fn on_disconnect<F>(mut self, cb: F) -> Self
    where
        F: FnMut(ClientId, &mut Ctx) + Send + Sync + 'static,
    {
        self.on_disconnect = Some(Box::new(cb));
        self
    }
}
/// A connection-oriented server that authenticates clients with connect tokens and
/// exchanges encrypted packets with them.
pub struct Server<Ctx = ()> {
    /// Server clock, advanced by `try_update`.
    time: f64,
    /// Key used to read connection request packets and to build connect tokens.
    private_key: Key,
    /// Sequence number for packets sent directly to an address (challenge/denied packets).
    sequence: u64,
    /// Incremented each time a connect token builder is handed out by `token`.
    token_sequence: u64,
    /// Sequence used when encrypting challenge tokens; incremented per challenge sent.
    challenge_sequence: u64,
    /// Key generated at startup for encrypting and decrypting challenge tokens.
    challenge_key: Key,
    protocol_id: u64,
    conn_cache: ConnectionCache,
    /// Recently used connect tokens, for replay detection.
    token_entries: TokenEntries,
    cfg: ServerConfig<Ctx>,
}

impl Server {
    /// Creates a server with the default [`ServerConfig`] (no user context).
    pub fn new(protocol_id: u64, private_key: Key) -> Result<Self> {
        Self::with_config(protocol_id, private_key, ServerConfig::default())
    }
}

impl<Ctx> Server<Ctx> {
    /// Creates a server using the given configuration.
    ///
    /// A fresh challenge key is generated on every call.
    pub fn with_config(protocol_id: u64, private_key: Key, cfg: ServerConfig<Ctx>) -> Result<Self> {
        let server = Server {
            time: 0.0,
            private_key,
            protocol_id,
            // High bit set. NOTE(review): presumably keeps these sequences apart from
            // per-connection sequences, which start at 0 — confirm.
            sequence: 1 << 63,
            token_sequence: 0,
            challenge_sequence: 0,
            challenge_key: crypto::generate_key(),
            conn_cache: ConnectionCache::new(0.0),
            token_entries: TokenEntries::new(),
            cfg,
        };
        Ok(server)
    }
}
impl<Ctx> Server<Ctx> {
const ALLOWED_PACKETS: u8 = 1 << Packet::REQUEST
| 1 << Packet::RESPONSE
| 1 << Packet::KEEP_ALIVE
| 1 << Packet::PAYLOAD
| 1 << Packet::DISCONNECT;
fn on_connect(&mut self, client_id: ClientId) {
if let Some(cb) = self.cfg.on_connect.as_mut() {
cb(client_id, &mut self.cfg.context)
}
}
fn on_disconnect(&mut self, client_id: ClientId) {
if let Some(cb) = self.cfg.on_disconnect.as_mut() {
cb(client_id, &mut self.cfg.context)
}
}
fn touch_client(&mut self, client_id: Option<ClientId>) -> Result<()> {
let Some(id) = client_id else {
return Ok(());
};
let Some(conn) = self.conn_cache.clients.get_mut(&id) else {
return Ok(());
};
conn.last_receive_time = self.time;
if !conn.is_confirmed() {
debug!("server confirmed connection with client {id}");
conn.confirm();
}
Ok(())
}
fn process_packet(
&mut self,
addr: SocketAddr,
packet: Packet,
sender: &mut impl PacketSender,
) -> Result<()> {
let client_id = self.conn_cache.find_by_addr(&addr).map(|(id, _)| id);
trace!(
"server received {} from {}",
packet.to_string(),
client_id
.map(|idx| format!("client {idx}"))
.unwrap_or_else(|| addr.to_string())
);
match packet {
Packet::Request(packet) => self.process_connection_request(addr, packet, sender),
Packet::Response(packet) => self.process_connection_response(addr, packet, sender),
Packet::KeepAlive(_) => self.touch_client(client_id),
Packet::Payload(packet) => {
self.touch_client(client_id)?;
if let Some(idx) = client_id {
self.conn_cache
.packet_queue
.push_back((ReadWordBuffer::start_read(packet.buf), idx));
}
Ok(())
}
Packet::Disconnect(_) => {
if let Some(idx) = client_id {
debug!("server disconnected client {idx}");
self.on_disconnect(idx);
self.conn_cache.remove(idx);
}
Ok(())
}
_ => unreachable!("packet should have been filtered out by `ALLOWED_PACKETS`"),
}
}
fn send_to_addr(
&mut self,
packet: Packet,
addr: SocketAddr,
key: Key,
sender: &mut impl PacketSender,
) -> Result<()> {
let mut buf = [0u8; MAX_PKT_BUF_SIZE];
let size = packet.write(&mut buf, self.sequence, &key, self.protocol_id)?;
sender.send(&buf[..size], &addr).map_err(Error::from)?;
self.sequence += 1;
Ok(())
}
fn send_to_client(
&mut self,
packet: Packet,
id: ClientId,
sender: &mut impl PacketSender,
) -> Result<()> {
let mut buf = [0u8; MAX_PKT_BUF_SIZE];
let conn = &mut self
.conn_cache
.clients
.get_mut(&id)
.expect("invalid client id");
let size = packet.write(&mut buf, conn.sequence, &conn.send_key, self.protocol_id)?;
sender.send(&buf[..size], &conn.addr).map_err(Error::from)?;
conn.last_access_time = self.time;
conn.last_send_time = self.time;
conn.sequence += 1;
Ok(())
}
fn process_connection_request(
&mut self,
from_addr: SocketAddr,
mut packet: RequestPacket,
sender: &mut impl PacketSender,
) -> Result<()> {
let mut reader = std::io::Cursor::new(&mut packet.token_data[..]);
let Ok(token) = ConnectTokenPrivate::read_from(&mut reader) else {
debug!("server ignored connection request. failed to read connect token");
return Ok(());
};
if self
.conn_cache
.find_by_addr(&from_addr)
.is_some_and(|(_, conn)| conn.is_connected())
{
debug!("server ignored connection request. a client with this address is already connected");
return Ok(());
};
if self
.conn_cache
.find_by_id(token.client_id)
.is_some_and(|conn| conn.is_connected())
{
debug!("server ignored connection request. a client with this id is already connected");
return Ok(());
};
let entry = TokenEntry {
time: self.time,
addr: from_addr,
mac: packet.token_data
[ConnectTokenPrivate::SIZE - MAC_BYTES..ConnectTokenPrivate::SIZE]
.try_into()
.expect("valid MAC size"),
};
if !self.token_entries.find_or_insert(entry) {
debug!("server ignored connection request. connect token has already been used");
return Ok(());
};
if self.num_connected_clients() >= MAX_CLIENTS {
debug!("server denied connection request. server is full");
self.send_to_addr(
DeniedPacket::create(),
from_addr,
token.server_to_client_key,
sender,
)?;
return Ok(());
};
self.conn_cache.add(
token.client_id,
from_addr,
token.timeout_seconds,
token.server_to_client_key,
token.client_to_server_key,
);
let Ok(challenge_token_encrypted) = ChallengeToken {
client_id: token.client_id,
user_data: token.user_data,
}
.encrypt(self.challenge_sequence, &self.challenge_key) else {
debug!("server ignored connection request. failed to encrypt challenge token");
return Ok(());
};
self.send_to_addr(
ChallengePacket::create(self.challenge_sequence, challenge_token_encrypted),
from_addr,
token.server_to_client_key,
sender,
)?;
debug!("server sent connection challenge packet");
self.challenge_sequence += 1;
Ok(())
}
fn process_connection_response(
&mut self,
from_addr: SocketAddr,
mut packet: ResponsePacket,
sender: &mut impl PacketSender,
) -> Result<()> {
let Ok(challenge_token) =
ChallengeToken::decrypt(&mut packet.token, packet.sequence, &self.challenge_key)
else {
debug!("server ignored connection response. failed to decrypt challenge token");
return Ok(());
};
let id: ClientId = challenge_token.client_id;
let Some(conn) = self.conn_cache.find_by_id(id) else {
debug!("server ignored connection response. no packet send key");
return Ok(());
};
if conn.is_connected() {
debug!("server ignored connection request. a client with this id is already connected");
return Ok(());
};
if self.num_connected_clients() >= MAX_CLIENTS {
debug!("server denied connection response. server is full");
self.send_to_addr(
DeniedPacket::create(),
from_addr,
self.conn_cache
.clients
.get(&id)
.expect("invalid client id")
.send_key,
sender,
)?;
return Ok(());
};
let client = self
.conn_cache
.clients
.get_mut(&id)
.expect("invalid client id");
client.connect();
client.last_send_time = self.time;
client.last_receive_time = self.time;
debug!(
"server accepted client {} with id {}",
id, challenge_token.client_id
);
self.send_to_client(KeepAlivePacket::create(id), id, sender)?;
self.on_connect(id);
Ok(())
}
fn check_for_timeouts(&mut self) {
for id in self.conn_cache.ids() {
let Some(client) = self.conn_cache.clients.get_mut(&id) else {
continue;
};
if !client.is_connected() {
continue;
}
if client.timeout.is_positive()
&& client.last_receive_time + (client.timeout as f64) < self.time
{
debug!("server timed out client {id}");
self.on_disconnect(id);
self.conn_cache.remove(id);
}
}
}
fn send_packets(&mut self, io: &mut Io) -> Result<()> {
for id in self.conn_cache.ids() {
let Some(client) = self.conn_cache.clients.get_mut(&id) else {
continue;
};
if !client.is_connected() {
continue;
}
if client.last_send_time + self.cfg.keep_alive_send_rate >= self.time {
continue;
}
self.send_to_client(KeepAlivePacket::create(id), id, io)?;
trace!("server sent connection keep-alive packet to client {id}");
}
Ok(())
}
fn recv_packet(
&mut self,
buf: &mut [u8],
now: u64,
addr: SocketAddr,
sender: &mut impl PacketSender,
) -> Result<()> {
if buf.len() <= 1 {
return Ok(());
}
let (key, replay_protection) = match self.conn_cache.find_by_addr(&addr) {
_ if buf[0] == Packet::REQUEST => (self.private_key, None),
Some((client_id, _)) => (
self.conn_cache
.clients
.get(&client_id)
.expect("client id not found")
.receive_key,
self.conn_cache.replay_protection.get_mut(&client_id),
),
None => {
debug!("server ignored non-connection-request packet from unknown address {addr}");
return Ok(());
}
};
let packet = match Packet::read(
buf,
self.protocol_id,
now,
key,
replay_protection,
Self::ALLOWED_PACKETS,
) {
Ok(packet) => packet,
Err(Error::Crypto(e)) => {
debug!(error = ?e, "server ignored packet because it failed to decrypt.");
return Ok(());
}
Err(e) => {
error!("server ignored packet: {e}");
return Ok(());
}
};
self.process_packet(addr, packet, sender)
}
fn recv_packets(
&mut self,
sender: &mut impl PacketSender,
receiver: &mut impl PacketReceiver,
) -> Result<()> {
let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
while let Some((buf, addr)) = receiver.recv().map_err(Error::from)? {
self.recv_packet(buf, now, addr, sender)?;
}
Ok(())
}
pub fn update(&mut self, delta_ms: f64, io: &mut Io) {
self.try_update(delta_ms, io)
.expect("send/recv error while updating server")
}
pub fn try_update(&mut self, delta_ms: f64, io: &mut Io) -> Result<()> {
self.time += delta_ms;
self.conn_cache.update(delta_ms);
let (sender, receiver) = io.split();
self.recv_packets(sender, receiver)?;
self.send_packets(io)?;
self.check_for_timeouts();
Ok(())
}
pub fn recv(&mut self) -> Option<(ReadWordBuffer, ClientId)> {
self.conn_cache.packet_queue.pop_front()
}
pub fn send(&mut self, buf: &[u8], client_id: ClientId, io: &mut Io) -> Result<()> {
if buf.len() > MAX_PACKET_SIZE {
return Err(Error::SizeMismatch(MAX_PACKET_SIZE, buf.len()));
}
let Some(conn) = self.conn_cache.clients.get_mut(&client_id) else {
return Err(Error::ClientNotFound);
};
if !conn.is_connected() {
return Err(Error::ClientNotConnected);
}
if !conn.is_confirmed() {
self.send_to_client(KeepAlivePacket::create(client_id), client_id, io)?;
}
let packet = PayloadPacket::create(buf);
self.send_to_client(packet, client_id, io)
}
pub fn send_all(&mut self, buf: &[u8], io: &mut Io) -> Result<()> {
for id in self.conn_cache.ids() {
match self.send(buf, id, io) {
Ok(_) | Err(Error::ClientNotConnected) | Err(Error::ClientNotFound) => continue,
Err(e) => return Err(e),
}
}
Ok(())
}
pub fn token(
&mut self,
client_id: ClientId,
server_addr: SocketAddr,
) -> ConnectTokenBuilder<SocketAddr> {
let token_builder =
ConnectToken::build(server_addr, self.protocol_id, client_id, self.private_key);
self.token_sequence += 1;
token_builder
}
pub fn disconnect(&mut self, client_id: ClientId, io: &mut Io) -> Result<()> {
let Some(conn) = self.conn_cache.clients.get_mut(&client_id) else {
return Ok(());
};
if !conn.is_connected() {
return Ok(());
}
debug!("server disconnecting client {client_id}");
for _ in 0..self.cfg.num_disconnect_packets {
self.send_to_client(DisconnectPacket::create(), client_id, io)?;
}
self.on_disconnect(client_id);
self.conn_cache.remove(client_id);
Ok(())
}
pub fn disconnect_all(&mut self, io: &mut Io) -> Result<()> {
debug!("server disconnecting all clients");
for id in self.conn_cache.ids() {
let Some(conn) = self.conn_cache.clients.get_mut(&id) else {
continue;
};
if conn.is_connected() {
self.disconnect(id, io)?;
}
}
Ok(())
}
pub fn connected_client_ids(&self) -> Vec<ClientId> {
self.conn_cache
.clients
.iter()
.filter_map(|(id, c)| c.is_connected().then_some(id))
.cloned()
.collect()
}
pub fn client_ids(&self) -> impl Iterator<Item = ClientId> + '_ {
self.conn_cache.clients.keys().copied()
}
pub fn num_connected_clients(&self) -> usize {
self.conn_cache
.clients
.iter()
.filter(|(_, c)| c.is_connected())
.count()
}
pub fn client_addr(&self, client_id: ClientId) -> Option<SocketAddr> {
self.conn_cache.clients.get(&client_id).map(|c| c.addr)
}
}