use std::{
collections::VecDeque,
net::{Ipv4Addr, SocketAddr},
time::{SystemTime, UNIX_EPOCH},
};
use crate::{
bytes::Bytes,
error::{Error, Result},
packet::{
DisconnectPacket, KeepAlivePacket, Packet, PayloadPacket, RequestPacket, ResponsePacket,
},
replay::ReplayProtection,
socket::NetcodeSocket,
token::{ChallengeToken, ConnectToken},
transceiver::Transceiver,
MAX_PACKET_SIZE, MAX_PKT_BUF_SIZE, PACKET_SEND_RATE_SEC,
};
/// Value (256 * 1024) passed as the third argument of `NetcodeSocket::new` when the client
/// creates its own socket. NOTE(review): presumably the socket receive buffer size in bytes.
const RECV_BUF_SIZE: usize = 256 * 1024;
/// Value (256 * 1024) passed as the second argument of `NetcodeSocket::new` when the client
/// creates its own socket. NOTE(review): presumably the socket send buffer size in bytes.
const SEND_BUF_SIZE: usize = 256 * 1024;
/// Boxed state-change callback. It is called with the state being left, the state being
/// entered, and a mutable reference to the user context stored in the configuration.
type Callback<Ctx> = Box<dyn FnMut(ClientState, ClientState, &mut Ctx) + Send + Sync + 'static>;
/// Builder-style configuration for a `Client`.
///
/// `Ctx` is the type of a user-supplied context value that is handed to the state-change
/// callback by mutable reference.
pub struct ClientConfig<Ctx> {
/// How many disconnect packets `Client::disconnect` sends before resetting the client.
num_disconnect_packets: usize,
/// Minimum spacing between the periodic packets sent from `Client::try_update`, measured
/// in the same units as the `time` argument of `update` (the setter names it `rate_seconds`).
packet_send_rate: f64,
/// User context passed to `on_state_change` on every state transition.
context: Ctx,
/// Optional callback invoked right before the client's state is overwritten.
on_state_change: Option<Callback<Ctx>>,
}
impl Default for ClientConfig<()> {
fn default() -> Self {
Self {
num_disconnect_packets: 10,
packet_send_rate: PACKET_SEND_RATE_SEC,
context: (),
on_state_change: None,
}
}
}
impl<Ctx> ClientConfig<Ctx> {
    /// Returns the default configuration, which carries no user context.
    ///
    /// The result is always `ClientConfig<()>`, regardless of the `Ctx` this is called through.
    pub fn new() -> ClientConfig<()> {
        <ClientConfig<()> as Default>::default()
    }

    /// Creates a configuration with default settings that owns `ctx` as its user context.
    ///
    /// The context is later passed by mutable reference to the state-change callback.
    pub fn with_context(ctx: Ctx) -> Self {
        Self {
            context: ctx,
            on_state_change: None,
            num_disconnect_packets: 10,
            packet_send_rate: PACKET_SEND_RATE_SEC,
        }
    }

    /// Sets how many disconnect packets are sent by `Client::disconnect`.
    pub fn num_disconnect_packets(self, num_disconnect_packets: usize) -> Self {
        Self {
            num_disconnect_packets,
            ..self
        }
    }

    /// Sets the minimum spacing between periodic packets sent while updating the client.
    pub fn packet_send_rate(self, rate_seconds: f64) -> Self {
        Self {
            packet_send_rate: rate_seconds,
            ..self
        }
    }

    /// Installs a callback invoked on every state change with the old state, the new state
    /// and the user context. Any previously installed callback is replaced.
    pub fn on_state_change<F>(self, cb: F) -> Self
    where
        F: FnMut(ClientState, ClientState, &mut Ctx) + Send + Sync + 'static,
    {
        Self {
            on_state_change: Some(Box::new(cb)),
            ..self
        }
    }
}
/// Connection state of a `Client`.
///
/// The declaration order is significant: the derived `Ord` is used by `Client::is_error`,
/// which treats every variant declared before `Disconnected` as an error state.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ClientState {
/// The connect token's lifetime elapsed while the client was still connecting.
ConnectTokenExpired,
/// No packet was accepted within the token's timeout while connected.
ConnectionTimedOut,
/// No packet was accepted within the token's timeout while sending connection requests,
/// and there was no further server address to try.
ConnectionRequestTimedOut,
/// No packet was accepted within the token's timeout while sending challenge responses,
/// and there was no further server address to try.
ChallengeResponseTimedOut,
/// A denied packet was received while connecting and there was no further server
/// address to try.
ConnectionDenied,
/// Not connected; this is the state of a freshly created client.
Disconnected,
/// Connection request packets are being sent to the current server address.
SendingConnectionRequest,
/// A challenge was received and challenge response packets are being sent.
SendingChallengeResponse,
/// The handshake completed; payloads can be sent and received.
Connected,
}
/// Client side of the connection protocol, generic over the transport `T` and a user
/// context type `Ctx` (defaults to `()`).
pub struct Client<T: Transceiver, Ctx = ()> {
/// Transport used to send and receive raw datagrams.
transceiver: T,
/// Current connection state.
state: ClientState,
/// Most recent time passed to `try_update`; compared against the token's
/// `timeout_seconds`, so it is in seconds.
time: f64,
/// Time at which the current connection attempt started; used for token expiry.
start_time: f64,
/// Time of the last successfully sent packet.
last_send_time: f64,
/// Time of the last packet that was accepted in the current state.
last_receive_time: f64,
/// Index into the token's `server_addresses` of the server currently being used.
server_addr_idx: usize,
/// Sequence number written into the next outgoing packet; incremented after each send.
sequence: u64,
/// Sequence taken from the last challenge packet, echoed back in the challenge response.
challenge_token_sequence: u64,
/// Token bytes taken from the last challenge packet, echoed back in the challenge response.
challenge_token_data: [u8; ChallengeToken::SIZE],
/// Client index reported by the keep-alive packet that completed the handshake.
client_index: i32,
/// Maximum client count reported by the keep-alive packet that completed the handshake.
max_clients: i32,
/// Parsed connect token: server list, keys, protocol id, timestamps and timeout.
token: ConnectToken,
/// Replay protection state handed to `Packet::read` for every incoming packet.
replay_protection: ReplayProtection,
/// Set while processing packets to request a disconnect on the next state update.
should_disconnect: bool,
/// State to end up in when `should_disconnect` is acted upon.
should_disconnect_state: ClientState,
/// Received payloads waiting to be taken with `recv`.
packet_queue: VecDeque<Vec<u8>>,
/// User configuration (send rate, disconnect packet count, callback and context).
cfg: ClientConfig<Ctx>,
}
impl<Trx: Transceiver, Ctx> Client<Trx, Ctx> {
fn from_token(token_bytes: &[u8], cfg: ClientConfig<Ctx>, trx: Trx) -> Result<Self> {
if token_bytes.len() != ConnectToken::SIZE {
return Err(Error::SizeMismatch(ConnectToken::SIZE, token_bytes.len()));
}
let mut buf = [0u8; ConnectToken::SIZE];
buf.copy_from_slice(token_bytes);
let mut cursor = std::io::Cursor::new(&mut buf[..]);
let token = match ConnectToken::read_from(&mut cursor) {
Ok(token) => token,
Err(err) => {
log::error!("invalid connect token: {err}");
return Err(Error::InvalidToken(err));
}
};
log::info!("client started on {}", trx.addr());
Ok(Self {
transceiver: trx,
state: ClientState::Disconnected,
time: 0.0,
start_time: 0.0,
last_send_time: f64::NEG_INFINITY,
last_receive_time: f64::NEG_INFINITY,
server_addr_idx: 0,
sequence: 0,
challenge_token_sequence: 0,
challenge_token_data: [0u8; ChallengeToken::SIZE],
client_index: 0,
max_clients: 0,
token,
replay_protection: ReplayProtection::new(),
should_disconnect: false,
should_disconnect_state: ClientState::Disconnected,
packet_queue: VecDeque::new(),
cfg,
})
}
}
impl Client<NetcodeSocket> {
    /// Creates a client with the default configuration and its own socket bound to the
    /// unspecified IPv4 address on port 0.
    ///
    /// # Errors
    ///
    /// Fails if the socket cannot be created or if `token_bytes` is not a valid connect token.
    pub fn new(token_bytes: &[u8]) -> Result<Self> {
        let socket =
            NetcodeSocket::new((Ipv4Addr::UNSPECIFIED, 0), SEND_BUF_SIZE, RECV_BUF_SIZE)?;
        let default_cfg = ClientConfig::default();
        Self::from_token(token_bytes, default_cfg, socket)
    }
}
impl<Ctx> Client<NetcodeSocket, Ctx> {
    /// Creates a client with a caller-supplied configuration and its own socket bound to
    /// the unspecified IPv4 address on port 0.
    ///
    /// # Errors
    ///
    /// Fails if the socket cannot be created or if `token_bytes` is not a valid connect token.
    pub fn with_config(token_bytes: &[u8], cfg: ClientConfig<Ctx>) -> Result<Self> {
        let socket =
            NetcodeSocket::new((Ipv4Addr::UNSPECIFIED, 0), SEND_BUF_SIZE, RECV_BUF_SIZE)?;
        Self::from_token(token_bytes, cfg, socket)
    }
}
impl<T: Transceiver, Ctx> Client<T, Ctx> {
/// Bitmask handed to `Packet::read` for every incoming datagram: one bit per packet kind,
/// set for denied, challenge, keep-alive, payload and disconnect packets.
const ALLOWED_PACKETS: u8 = (1 << Packet::DENIED)
    | (1 << Packet::CHALLENGE)
    | (1 << Packet::KEEP_ALIVE)
    | (1 << Packet::PAYLOAD)
    | (1 << Packet::DISCONNECT);
/// Switches the client to `state`, notifying the configured callback (if any) first.
///
/// The callback receives the previous state, the new state and the user context, and runs
/// while `self.state` still holds the previous state.
fn set_state(&mut self, state: ClientState) {
    let previous = self.state;
    log::debug!("client state changing from {:?} to {:?}", previous, state);
    if let Some(cb) = self.cfg.on_state_change.as_mut() {
        cb(previous, state, &mut self.cfg.context);
    }
    self.state = state;
}
/// Clears all per-attempt connection data so a fresh connection attempt can start at the
/// current time.
fn reset_connection(&mut self) {
    let now = self.time;

    // Timing: the attempt starts now, and the last send is back-dated by one time unit so
    // the next `send_packets` call is not throttled (for send rates below 1.0).
    self.start_time = now;
    self.last_receive_time = now;
    self.last_send_time = now - 1.0;

    // Forget any pending disconnect request and handshake progress.
    self.should_disconnect_state = ClientState::Disconnected;
    self.should_disconnect = false;
    self.challenge_token_sequence = 0;
    self.replay_protection = ReplayProtection::new();
}
/// Tears the connection down completely and leaves the client in `new_state`.
///
/// Resets the packet sequence, the server-assigned client index / max clients and the
/// server address index, switches to `new_state` (firing the state-change callback), and
/// then clears the per-attempt connection data.
fn reset(&mut self, new_state: ClientState) {
    self.sequence = 0;
    self.client_index = 0;
    self.max_clients = 0;
    self.server_addr_idx = 0;
    self.set_state(new_state);
    // `start_time` is deliberately not touched here: `reset_connection` assigns it (to the
    // current time), so a separate store before this call would be dead.
    self.reset_connection();
    log::debug!("client disconnected");
}
/// Sends the periodic packet that belongs to the current state, rate-limited by the
/// configured packet send rate.
///
/// Connection requests, challenge responses and keep-alives are sent in the respective
/// states; in every other state nothing is sent.
///
/// # Errors
///
/// Propagates any error from encoding or sending the packet.
fn send_packets(&mut self) -> Result<()> {
    // Throttle: do nothing until strictly more than `packet_send_rate` has elapsed.
    let earliest_next_send = self.last_send_time + self.cfg.packet_send_rate;
    if earliest_next_send >= self.time {
        return Ok(());
    }

    let outgoing = match self.state {
        ClientState::SendingConnectionRequest => {
            log::debug!("client sending connection request packet to server");
            RequestPacket::create(
                self.token.protocol_id,
                self.token.expire_timestamp,
                self.token.nonce,
                self.token.private_data,
            )
        }
        ClientState::SendingChallengeResponse => {
            log::debug!("client sending connection response packet to server");
            ResponsePacket::create(self.challenge_token_sequence, self.challenge_token_data)
        }
        ClientState::Connected => {
            log::trace!("client sending connection keep-alive packet to server");
            KeepAlivePacket::create(0, 0)
        }
        _ => return Ok(()),
    };

    self.send_packet(outgoing)
}
/// Advances to the next server address in the connect token and starts connecting to it.
///
/// Returns `Err(())` without changing anything when the current address is the last one.
fn connect_to_next_server(&mut self) -> std::result::Result<(), ()> {
    let next_idx = self.server_addr_idx + 1;
    if next_idx >= self.token.server_addresses.len() {
        log::debug!("no more servers to connect to");
        return Err(());
    }
    self.server_addr_idx = next_idx;
    self.connect();
    Ok(())
}
/// Encodes `packet` with the current sequence number and client-to-server key and sends
/// it to the currently selected server address.
///
/// On success the sequence number is incremented and the last-send time is updated.
///
/// # Errors
///
/// Propagates encoding errors and transport send errors; in either case the sequence
/// number and last-send time are left untouched.
fn send_packet(&mut self, packet: Packet) -> Result<()> {
    let mut scratch = [0u8; MAX_PKT_BUF_SIZE];
    let written = packet.write(
        &mut scratch,
        self.sequence,
        &self.token.client_to_server_key,
        self.token.protocol_id,
    )?;

    let destination = self.token.server_addresses[self.server_addr_idx];
    self.transceiver
        .send(&scratch[..written], destination)
        .map_err(|e| e.into())?;

    self.sequence += 1;
    self.last_send_time = self.time;
    Ok(())
}
/// Applies a decoded packet to the client's state machine.
///
/// Packets that do not come from the currently selected server address, or that are not
/// expected in the current state, are ignored. Every packet that is acted upon refreshes
/// the last-receive time used for timeout detection.
fn process_packet(&mut self, addr: SocketAddr, packet: Packet) -> Result<()> {
    let expected_addr = self.token.server_addresses[self.server_addr_idx];
    if addr != expected_addr {
        return Ok(());
    }

    // `true` when the (packet, state) combination was handled.
    let accepted = match (packet, self.state) {
        (
            Packet::Denied(_),
            ClientState::SendingConnectionRequest | ClientState::SendingChallengeResponse,
        ) => {
            // Acted upon in `update_state`, which may first try the next server.
            self.should_disconnect = true;
            self.should_disconnect_state = ClientState::ConnectionDenied;
            true
        }
        (Packet::Challenge(challenge), ClientState::SendingConnectionRequest) => {
            log::debug!("client received connection challenge packet from server");
            self.challenge_token_sequence = challenge.sequence;
            self.challenge_token_data = challenge.token;
            self.set_state(ClientState::SendingChallengeResponse);
            true
        }
        (Packet::KeepAlive(_), ClientState::Connected) => {
            log::trace!("client received connection keep-alive packet from server");
            true
        }
        (Packet::KeepAlive(keep_alive), ClientState::SendingChallengeResponse) => {
            // The first keep-alive after the challenge response completes the handshake.
            log::debug!("client received connection keep-alive packet from server");
            self.client_index = keep_alive.client_index;
            self.max_clients = keep_alive.max_clients;
            self.set_state(ClientState::Connected);
            log::info!("client connected to server");
            true
        }
        (Packet::Payload(payload), ClientState::Connected) => {
            log::debug!("client received payload packet from server");
            self.packet_queue.push_back(payload.buf.to_vec());
            true
        }
        (Packet::Disconnect(_), ClientState::Connected) => {
            log::debug!("client received disconnect packet from server");
            self.should_disconnect = true;
            self.should_disconnect_state = ClientState::Disconnected;
            true
        }
        _ => false,
    };

    if accepted {
        self.last_receive_time = self.time;
    }
    Ok(())
}
/// Checks for token expiry, pending disconnect requests and timeouts, and moves the client
/// to the matching terminal state when one applies.
///
/// Checks run in this order: token expiry while connecting, pending disconnect request,
/// timeout. For a disconnect request or a timeout while connecting, the next server address
/// from the token is tried first; the client is only reset when no address is left.
fn update_state(&mut self) {
    let token_lifetime =
        self.token.expire_timestamp as f64 - self.token.create_timestamp as f64;
    let token_expired = self.time - self.start_time >= token_lifetime;
    // A non-positive timeout disables timeout detection.
    let timed_out = self.token.timeout_seconds.is_positive()
        && (self.last_receive_time + (self.token.timeout_seconds as f64) < self.time);
    let connecting = matches!(
        self.state,
        ClientState::SendingConnectionRequest | ClientState::SendingChallengeResponse
    );

    let terminal_state = if connecting && token_expired {
        log::info!("client connect failed. connect token expired");
        ClientState::ConnectTokenExpired
    } else if self.should_disconnect {
        log::debug!(
            "client should disconnect -> {:?}",
            self.should_disconnect_state
        );
        if self.connect_to_next_server().is_ok() {
            return;
        }
        self.should_disconnect_state
    } else if !timed_out {
        return;
    } else {
        match self.state {
            ClientState::SendingConnectionRequest => {
                log::info!("client connect failed. connection request timed out");
                if self.connect_to_next_server().is_ok() {
                    return;
                }
                ClientState::ConnectionRequestTimedOut
            }
            ClientState::SendingChallengeResponse => {
                log::info!("client connect failed. connection response timed out");
                if self.connect_to_next_server().is_ok() {
                    return;
                }
                ClientState::ChallengeResponseTimedOut
            }
            ClientState::Connected => {
                log::info!("client connection timed out");
                ClientState::ConnectionTimedOut
            }
            _ => return,
        }
    };

    self.reset(terminal_state);
}
/// Decodes one received datagram and feeds the resulting packet to `process_packet`.
///
/// Datagrams of at most one byte and datagrams that fail to decode are dropped; decode
/// failures are logged and never returned as errors.
fn recv_packet(&mut self, buf: &mut [u8], now: u64, addr: SocketAddr) -> Result<()> {
    if buf.len() < 2 {
        return Ok(());
    }

    let decoded = Packet::read(
        buf,
        self.token.protocol_id,
        now,
        self.token.server_to_client_key,
        Some(&mut self.replay_protection),
        Self::ALLOWED_PACKETS,
    );

    match decoded {
        Ok(packet) => self.process_packet(addr, packet),
        Err(Error::Crypto(_)) => {
            log::debug!("client ignored packet because it failed to decrypt");
            Ok(())
        }
        Err(e) => {
            log::error!("client ignored packet: {e}");
            Ok(())
        }
    }
}
fn recv_packets(&mut self) -> Result<()> {
let mut buf = [0u8; MAX_PACKET_SIZE];
let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
while let Some((size, addr)) = self.transceiver.recv(&mut buf).map_err(|e| e.into())? {
self.recv_packet(&mut buf[..size], now, addr)?;
}
Ok(())
}
pub fn with_config_and_transceiver(
token_bytes: &[u8],
cfg: ClientConfig<Ctx>,
trx: T,
) -> Result<Self> {
Client::from_token(token_bytes, cfg, trx)
}
/// Starts a connection attempt against the currently selected server address.
///
/// Per-attempt connection data is cleared and the client enters
/// `ClientState::SendingConnectionRequest`; packets are sent by later `update` calls.
pub fn connect(&mut self) {
    self.reset_connection();
    self.set_state(ClientState::SendingConnectionRequest);

    let idx = self.server_addr_idx;
    let servers = &self.token.server_addresses;
    log::info!(
        "client connecting to server {} [{}/{}]",
        servers[idx],
        idx + 1,
        servers.len()
    );
}
/// Advances the client to `time`: receives, sends and updates the connection state.
///
/// # Panics
///
/// Panics if `try_update` returns an error. Use `try_update` to handle errors instead.
pub fn update(&mut self, time: f64) {
    let outcome = self.try_update(time);
    outcome.expect("send/recv error while updating client")
}
/// Advances the client to `time`, returning errors instead of panicking.
///
/// Steps, in order: record `time`, process all pending incoming packets, send the periodic
/// packet for the current state, then evaluate expiry / disconnect / timeout conditions.
///
/// # Errors
///
/// Returns the first receive or send error. When an error occurs the remaining steps are
/// skipped.
pub fn try_update(&mut self, time: f64) -> Result<()> {
    self.time = time;
    self.recv_packets()
        .and_then(|()| self.send_packets())
        .map(|()| self.update_state())
}
/// Takes the oldest queued payload received from the server, or returns `None` when the
/// queue is empty. Payloads are queued while packets are processed during `update`.
pub fn recv(&mut self) -> Option<Vec<u8>> {
self.packet_queue.pop_front()
}
pub fn send(&mut self, buf: &[u8]) -> Result<()> {
if self.state != ClientState::Connected {
return Ok(());
}
if buf.len() > MAX_PACKET_SIZE {
return Err(Error::SizeMismatch(MAX_PACKET_SIZE, buf.len()));
}
self.send_packet(PayloadPacket::create(buf))?;
Ok(())
}
/// Sends the configured number of disconnect packets to the current server address and
/// resets the client to `ClientState::Disconnected`.
///
/// # Errors
///
/// Returns the first send error. In that case the remaining packets are not sent and the
/// client is not reset.
pub fn disconnect(&mut self) -> Result<()> {
    let count = self.cfg.num_disconnect_packets;
    log::debug!("client sending {} disconnect packets to server", count);
    (0..count).try_for_each(|_| self.send_packet(DisconnectPacket::create()))?;
    self.reset(ClientState::Disconnected);
    Ok(())
}
/// Returns the address reported by the underlying transceiver.
pub fn addr(&self) -> SocketAddr {
self.transceiver.addr()
}
/// Returns the client's current connection state.
pub fn state(&self) -> ClientState {
self.state
}
/// Returns `true` when the client is in one of the error states, i.e. any `ClientState`
/// variant declared before `Disconnected`.
pub fn is_error(&self) -> bool {
    ClientState::Disconnected > self.state
}
/// Returns `true` while a connection attempt is in progress (sending connection requests
/// or challenge responses).
pub fn is_pending(&self) -> bool {
    matches!(
        self.state,
        ClientState::SendingConnectionRequest | ClientState::SendingChallengeResponse
    )
}
/// Returns `true` when the handshake has completed and the client is connected.
pub fn is_connected(&self) -> bool {
    matches!(self.state, ClientState::Connected)
}
/// Returns `true` only for the plain `Disconnected` state; error states report `false`.
pub fn is_disconnected(&self) -> bool {
    matches!(self.state, ClientState::Disconnected)
}
}
#[cfg(test)]
mod tests {
use byteorder::{LittleEndian, WriteBytesExt};
use chacha20poly1305::XNonce;
use super::*;
use crate::simulator::NetworkSimulator;
use crate::token::ConnectTokenPrivate;
use crate::{InvalidTokenError, NETCODE_VERSION};
use std::io::Write;
use std::mem::size_of;
impl Client<NetworkSimulator> {
    /// Test helper: builds a default-configured client that talks through the given
    /// network simulator instead of a real socket.
    pub(crate) fn with_simulator(token: ConnectToken, sim: NetworkSimulator) -> Result<Self> {
        let token_bytes = token.try_into_bytes()?;
        Self::with_config_and_transceiver(&token_bytes, ClientConfig::default(), sim)
    }
}
/// Malformed connect tokens must be rejected with the matching `InvalidTokenError`.
///
/// NOTE(review): the field names in the comments below are inferred from the write order
/// and the asserted errors; confirm against `ConnectToken::read_from`.
#[test]
fn invalid_connect_token() {
    /// Returns a zero-filled token buffer after letting `fill` write a prefix into it.
    fn token_with(
        fill: impl FnOnce(&mut std::io::Cursor<&mut [u8]>),
    ) -> [u8; ConnectToken::SIZE] {
        let mut bytes = [0u8; ConnectToken::SIZE];
        fill(&mut std::io::Cursor::new(&mut bytes[..]));
        bytes
    }

    // Wrong version string.
    let token_bytes = token_with(|cursor| {
        cursor.write_all(b"NETCODE VERSION 1.00\0").unwrap();
    });
    let res = Client::new(&token_bytes);
    assert!(matches!(
        res,
        Err(Error::InvalidToken(InvalidTokenError::InvalidVersion))
    ));

    // Correct version, everything else zero: the address list is empty.
    let token_bytes = token_with(|cursor| {
        cursor.write_all(NETCODE_VERSION).unwrap();
    });
    let res = Client::new(&token_bytes);
    assert!(matches!(
        res,
        Err(Error::InvalidToken(InvalidTokenError::AddressListLength(0)))
    ));

    // Second timestamp smaller than the first.
    let token_bytes = token_with(|cursor| {
        cursor.write_all(NETCODE_VERSION).unwrap();
        cursor.write_u64::<LittleEndian>(0).unwrap(); // presumably protocol id
        cursor.write_u64::<LittleEndian>(2).unwrap(); // presumably create timestamp
        cursor.write_u64::<LittleEndian>(1).unwrap(); // presumably expire timestamp
    });
    let res = Client::new(&token_bytes);
    assert!(matches!(
        res,
        Err(Error::InvalidToken(InvalidTokenError::InvalidTimestamp))
    ));

    // One server address whose type byte is 3.
    let token_bytes = token_with(|cursor| {
        cursor.write_all(NETCODE_VERSION).unwrap();
        cursor.write_u64::<LittleEndian>(0).unwrap(); // presumably protocol id
        cursor.write_u64::<LittleEndian>(0).unwrap(); // presumably create timestamp
        cursor.write_u64::<LittleEndian>(0).unwrap(); // presumably expire timestamp
        cursor.write_all(&[0; size_of::<XNonce>()]).unwrap(); // nonce-sized block
        cursor.write_all(&[0; ConnectTokenPrivate::SIZE]).unwrap(); // private-data-sized block
        cursor.write_i32::<LittleEndian>(0).unwrap(); // presumably timeout seconds
        cursor.write_u32::<LittleEndian>(1).unwrap(); // presumably number of server addresses
        cursor.write_u8(3).unwrap(); // address type reported back by the error
    });
    let res = Client::new(&token_bytes);
    assert!(matches!(
        res,
        Err(Error::InvalidToken(
            InvalidTokenError::InvalidIpAddressType(3)
        ))
    ));
}
}