use std::{
collections::VecDeque,
net::SocketAddr,
time::{SystemTime, UNIX_EPOCH},
};
use tracing::{debug, error, info, trace};
use crate::serialize::reader::ReadBuffer;
use crate::serialize::wordbuffer::reader::ReadWordBuffer;
use crate::transport::io::Io;
use crate::transport::{PacketReceiver, PacketSender};
use super::{
bytes::Bytes,
error::{Error, Result},
packet::{
DisconnectPacket, KeepAlivePacket, Packet, PayloadPacket, RequestPacket, ResponsePacket,
},
replay::ReplayProtection,
token::{ChallengeToken, ConnectToken},
MAX_PACKET_SIZE, MAX_PKT_BUF_SIZE, PACKET_SEND_RATE_SEC,
};
/// Boxed state-change hook: called with the previous state, the new state and a
/// mutable reference to the user context stored in [`ClientConfig`].
type Callback<Ctx> = Box<dyn FnMut(ClientState, ClientState, &mut Ctx) + Send + Sync + 'static>;
/// Configuration for a [`Client`], built with the builder-style methods on
/// `ClientConfig`.
///
/// `Ctx` is a user-supplied value that is handed mutably to the
/// `on_state_change` callback on every state transition.
pub struct ClientConfig<Ctx> {
    /// Number of disconnect packets sent to the server when the client disconnects.
    num_disconnect_packets: usize,
    /// Minimum interval, in seconds, between the client's periodic packets
    /// (connection requests, challenge responses, keep-alives).
    packet_send_rate: f64,
    /// User context passed to `on_state_change`.
    context: Ctx,
    /// Optional callback invoked as `(old_state, new_state, &mut context)`
    /// whenever the client changes state.
    on_state_change: Option<Callback<Ctx>>,
}
impl Default for ClientConfig<()> {
fn default() -> Self {
Self {
num_disconnect_packets: 10,
packet_send_rate: PACKET_SEND_RATE_SEC,
context: (),
on_state_change: None,
}
}
}
impl ClientConfig<()> {
    /// Creates the default, context-free configuration (same as
    /// [`ClientConfig::default`]).
    ///
    /// Defined on `ClientConfig<()>` rather than on the generic
    /// `ClientConfig<Ctx>`: in the generic impl the return type never mentions
    /// `Ctx`, so a plain `ClientConfig::new()` could not infer the type
    /// parameter and failed to compile without a turbofish.
    pub fn new() -> Self {
        Self::default()
    }
}

impl<Ctx> ClientConfig<Ctx> {
    /// Creates a configuration with default settings that carries `ctx` as the
    /// user context handed to the `on_state_change` callback.
    pub fn with_context(ctx: Ctx) -> Self {
        Self {
            num_disconnect_packets: 10,
            packet_send_rate: PACKET_SEND_RATE_SEC,
            context: ctx,
            on_state_change: None,
        }
    }

    /// Sets how many disconnect packets the client sends to the server when it
    /// disconnects.
    pub fn num_disconnect_packets(mut self, num_disconnect_packets: usize) -> Self {
        self.num_disconnect_packets = num_disconnect_packets;
        self
    }

    /// Sets the minimum interval, in seconds, between the client's periodic
    /// packets (connection requests, challenge responses and keep-alives).
    pub fn packet_send_rate(mut self, rate_seconds: f64) -> Self {
        self.packet_send_rate = rate_seconds;
        self
    }

    /// Registers `cb`, called as `cb(old_state, new_state, &mut context)` on every
    /// client state change. Replaces any previously registered callback.
    pub fn on_state_change<F>(mut self, cb: F) -> Self
    where
        F: FnMut(ClientState, ClientState, &mut Ctx) + Send + Sync + 'static,
    {
        self.on_state_change = Some(Box::new(cb));
        self
    }
}
/// Connection state of a [`Client`].
///
/// The declaration order is significant: `Ord` is derived, and `is_error`
/// treats every variant declared before `Disconnected` as an error state.
/// Do not reorder variants without updating those comparisons.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ClientState {
    /// The connect token's lifetime elapsed while the handshake was still pending.
    ConnectTokenExpired,
    /// No packet was received from the server within the token's timeout while connected.
    ConnectionTimedOut,
    /// The server never answered the connection request (all servers tried).
    ConnectionRequestTimedOut,
    /// The server never answered the challenge response (all servers tried).
    ChallengeResponseTimedOut,
    /// The server rejected the connection with a denied packet.
    ConnectionDenied,
    /// Idle: the initial state, and the state after an explicit or server-initiated disconnect.
    Disconnected,
    /// Handshake step 1: periodically sending connection request packets.
    SendingConnectionRequest,
    /// Handshake step 2: a challenge was received; periodically sending challenge responses.
    SendingChallengeResponse,
    /// Handshake complete; payload and keep-alive packets flow.
    Connected,
}
/// Client side of the connection protocol: holds a decoded connect token, runs
/// the handshake state machine, and queues payloads received from the server.
///
/// `Ctx` is the user context type carried by the [`ClientConfig`].
pub struct Client<Ctx = ()> {
    /// Current state of the connection state machine.
    state: ClientState,
    /// Client clock, advanced by the delta passed to `update`/`try_update`.
    time: f64,
    /// Clock value when the current connection attempt started; used for token expiry.
    start_time: f64,
    /// Clock value of the last successfully sent packet; used for rate limiting.
    last_send_time: f64,
    /// Clock value of the last accepted packet from the server; used for timeouts.
    last_receive_time: f64,
    /// Index into `token.server_addresses` of the server currently targeted.
    server_addr_idx: usize,
    /// Sequence number written into the next outgoing packet.
    sequence: u64,
    /// Sequence number of the challenge received from the server, echoed back in responses.
    challenge_token_sequence: u64,
    /// Opaque challenge token received from the server, echoed back in responses.
    challenge_token_data: [u8; ChallengeToken::SIZE],
    /// Decoded connect token: server list, keys, timestamps and timeout.
    token: ConnectToken,
    /// Replay-protection state passed to packet decoding.
    replay_protection: ReplayProtection,
    /// Set by packet processing to request a disconnect on the next state update.
    should_disconnect: bool,
    /// State to enter when the deferred disconnect is applied.
    should_disconnect_state: ClientState,
    /// Payloads received while connected, drained in FIFO order by `recv`.
    packet_queue: VecDeque<ReadWordBuffer>,
    /// Client configuration (send rate, disconnect packet count, callback, context).
    cfg: ClientConfig<Ctx>,
}
impl<Ctx> Client<Ctx> {
fn from_token(token_bytes: &[u8], cfg: ClientConfig<Ctx>) -> Result<Self> {
if token_bytes.len() != ConnectToken::SIZE {
return Err(Error::SizeMismatch(ConnectToken::SIZE, token_bytes.len()));
}
let mut buf = [0u8; ConnectToken::SIZE];
buf.copy_from_slice(token_bytes);
let mut cursor = std::io::Cursor::new(&mut buf[..]);
let token = match ConnectToken::read_from(&mut cursor) {
Ok(token) => token,
Err(err) => {
error!("invalid connect token: {err}");
return Err(Error::InvalidToken(err));
}
};
Ok(Self {
state: ClientState::Disconnected,
time: 0.0,
start_time: 0.0,
last_send_time: f64::NEG_INFINITY,
last_receive_time: f64::NEG_INFINITY,
server_addr_idx: 0,
sequence: 0,
challenge_token_sequence: 0,
challenge_token_data: [0u8; ChallengeToken::SIZE],
token,
replay_protection: ReplayProtection::new(),
should_disconnect: false,
should_disconnect_state: ClientState::Disconnected,
packet_queue: VecDeque::new(),
cfg,
})
}
}
impl Client {
    /// Builds a client from a serialized connect token using the default,
    /// context-free [`ClientConfig`].
    ///
    /// # Errors
    /// Fails if `token_bytes` has the wrong size or cannot be decoded as a
    /// connect token.
    pub fn new(token_bytes: &[u8]) -> Result<Self> {
        Self::from_token(token_bytes, ClientConfig::default())
    }
}
impl<Ctx> Client<Ctx> {
    /// Builds a client from a serialized connect token using a caller-supplied
    /// configuration (and therefore a caller-chosen context type `Ctx`).
    ///
    /// # Errors
    /// Fails if `token_bytes` has the wrong size or cannot be decoded as a
    /// connect token.
    pub fn with_config(token_bytes: &[u8], cfg: ClientConfig<Ctx>) -> Result<Self> {
        Self::from_token(token_bytes, cfg)
    }
}
impl<Ctx> Client<Ctx> {
    /// Bitmask of packet types this client accepts (bit `n` set means packet
    /// type `n` is allowed); passed to [`Packet::read`].
    const ALLOWED_PACKETS: u8 = 1 << Packet::DENIED
        | 1 << Packet::CHALLENGE
        | 1 << Packet::KEEP_ALIVE
        | 1 << Packet::PAYLOAD
        | 1 << Packet::DISCONNECT;

    /// Transitions to `state`, first invoking the configured `on_state_change`
    /// callback (if any) with the old state, the new state and the user context.
    fn set_state(&mut self, state: ClientState) {
        debug!("client state changing from {:?} to {:?}", self.state, state);
        if let Some(ref mut cb) = self.cfg.on_state_change {
            cb(self.state, state, &mut self.cfg.context)
        }
        self.state = state;
    }

    /// Clears per-connection bookkeeping before a (re)connect attempt.
    ///
    /// Restarts the token-expiry clock and the receive timeout at the current
    /// time, drops any pending deferred disconnect, forgets the challenge from
    /// a previous server and resets replay protection.
    fn reset_connection(&mut self) {
        self.start_time = self.time;
        // -inf makes the next `send_packets` call send immediately whatever
        // `packet_send_rate` is; `time - 1.0` only did so for rates below 1s.
        self.last_send_time = f64::NEG_INFINITY;
        self.last_receive_time = self.time;
        self.should_disconnect = false;
        self.should_disconnect_state = ClientState::Disconnected;
        self.challenge_token_sequence = 0;
        self.challenge_token_data = [0u8; ChallengeToken::SIZE];
        self.replay_protection = ReplayProtection::new();
    }

    /// Fully resets the client into `new_state`: the outgoing sequence restarts
    /// at zero, the server list is rewound to the first address, and the
    /// per-connection data is cleared.
    fn reset(&mut self, new_state: ClientState) {
        self.sequence = 0;
        self.server_addr_idx = 0;
        self.set_state(new_state);
        // Also resets `start_time` (to the current time).
        self.reset_connection();
        debug!("client disconnected");
    }

    /// Sends at most one packet per `packet_send_rate` interval, chosen by the
    /// current state: a connection request, a challenge response, or a
    /// keep-alive. States other than those three send nothing.
    fn send_packets(&mut self, io: &mut Io) -> Result<()> {
        if self.last_send_time + self.cfg.packet_send_rate >= self.time {
            return Ok(());
        }
        let packet = match self.state {
            ClientState::SendingConnectionRequest => {
                debug!("client sending connection request packet to server");
                RequestPacket::create(
                    self.token.protocol_id,
                    self.token.expire_timestamp,
                    self.token.nonce,
                    self.token.private_data,
                )
            }
            ClientState::SendingChallengeResponse => {
                debug!("client sending connection response packet to server");
                ResponsePacket::create(self.challenge_token_sequence, self.challenge_token_data)
            }
            ClientState::Connected => {
                trace!("client sending connection keep-alive packet to server");
                KeepAlivePacket::create(0)
            }
            _ => return Ok(()),
        };
        self.send_packet(packet, io)
    }

    /// Advances to the next server address in the token and restarts the
    /// handshake there. Returns `Err(())` when the current server is the last one.
    fn connect_to_next_server(&mut self) -> std::result::Result<(), ()> {
        if self.server_addr_idx + 1 >= self.token.server_addresses.len() {
            debug!("no more servers to connect to");
            return Err(());
        }
        self.server_addr_idx += 1;
        self.connect();
        Ok(())
    }

    /// Serializes `packet` with the client-to-server key and sends it to the
    /// current server. `last_send_time` and `sequence` are updated only when
    /// both the write and the send succeed.
    fn send_packet(&mut self, packet: Packet, io: &mut Io) -> Result<()> {
        let mut buf = [0u8; MAX_PKT_BUF_SIZE];
        let size = packet.write(
            &mut buf,
            self.sequence,
            &self.token.client_to_server_key,
            self.token.protocol_id,
        )?;
        io.send(&buf[..size], &self.server_addr())
            .map_err(Error::from)?;
        self.last_send_time = self.time;
        self.sequence += 1;
        Ok(())
    }

    /// Address of the server currently being connected to.
    ///
    /// # Panics
    /// Panics if the connect token lists no server addresses.
    pub fn server_addr(&self) -> SocketAddr {
        self.token.server_addresses[self.server_addr_idx]
    }

    /// Applies one decoded packet to the state machine.
    ///
    /// Packets whose source is not the current server, and packet/state
    /// combinations not listed below, are ignored. Any packet that is acted on
    /// refreshes `last_receive_time`, which feeds the timeout checks.
    /// Always returns `Ok(())`.
    fn process_packet(&mut self, addr: SocketAddr, packet: Packet) -> Result<()> {
        if addr != self.server_addr() {
            return Ok(());
        }
        match (packet, self.state) {
            (
                Packet::Denied(_),
                ClientState::SendingConnectionRequest | ClientState::SendingChallengeResponse,
            ) => {
                // Deferred: `update_state` applies it (possibly failing over).
                self.should_disconnect = true;
                self.should_disconnect_state = ClientState::ConnectionDenied;
            }
            (Packet::Challenge(pkt), ClientState::SendingConnectionRequest) => {
                debug!("client received connection challenge packet from server");
                self.challenge_token_sequence = pkt.sequence;
                self.challenge_token_data = pkt.token;
                self.set_state(ClientState::SendingChallengeResponse);
            }
            (Packet::KeepAlive(_), ClientState::Connected) => {
                trace!("client received connection keep-alive packet from server");
            }
            (Packet::KeepAlive(_), ClientState::SendingChallengeResponse) => {
                // The first keep-alive after our challenge response completes the handshake.
                debug!("client received connection keep-alive packet from server");
                self.set_state(ClientState::Connected);
                info!("client connected to server");
            }
            (Packet::Payload(pkt), ClientState::Connected) => {
                trace!("client received payload packet from server");
                let reader = ReadWordBuffer::start_read(pkt.buf);
                self.packet_queue.push_back(reader);
            }
            (Packet::Disconnect(_), ClientState::Connected) => {
                debug!("client received disconnect packet from server");
                self.should_disconnect = true;
                self.should_disconnect_state = ClientState::Disconnected;
            }
            _ => return Ok(()),
        }
        self.last_receive_time = self.time;
        Ok(())
    }

    /// Applies time-based transitions and deferred disconnects.
    ///
    /// Checked in order: connect-token expiry (pending states only), a deferred
    /// disconnect, then the receive timeout for the current state. Failures
    /// while pending or after a deferred disconnect first try the next server
    /// in the token; only when none is left is the client reset into the
    /// corresponding terminal state.
    fn update_state(&mut self) {
        // Token lifetime is measured from the start of the current attempt.
        let is_token_expired = self.time - self.start_time
            >= self.token.expire_timestamp as f64 - self.token.create_timestamp as f64;
        // A non-positive timeout disables the receive timeout.
        let is_connection_timed_out = self.token.timeout_seconds.is_positive()
            && (self.last_receive_time + (self.token.timeout_seconds as f64) < self.time);
        let new_state = match self.state {
            ClientState::SendingConnectionRequest | ClientState::SendingChallengeResponse
                if is_token_expired =>
            {
                info!("client connect failed. connect token expired");
                ClientState::ConnectTokenExpired
            }
            _ if self.should_disconnect => {
                debug!(
                    "client should disconnect -> {:?}",
                    self.should_disconnect_state
                );
                if self.connect_to_next_server().is_ok() {
                    return;
                };
                self.should_disconnect_state
            }
            ClientState::SendingConnectionRequest if is_connection_timed_out => {
                info!("client connect failed. connection request timed out");
                if self.connect_to_next_server().is_ok() {
                    return;
                };
                ClientState::ConnectionRequestTimedOut
            }
            ClientState::SendingChallengeResponse if is_connection_timed_out => {
                info!("client connect failed. connection response timed out");
                if self.connect_to_next_server().is_ok() {
                    return;
                };
                ClientState::ChallengeResponseTimedOut
            }
            // An established connection does not fail over to another server.
            ClientState::Connected if is_connection_timed_out => {
                info!("client connection timed out");
                ClientState::ConnectionTimedOut
            }
            _ => return,
        };
        self.reset(new_state);
    }

    /// Decodes one received datagram and feeds it to `process_packet`.
    ///
    /// Datagrams of at most one byte, ones that fail to decrypt, and ones that
    /// fail to decode for any other reason are dropped silently (logged at debug).
    fn recv_packet(&mut self, buf: &mut [u8], now: u64, addr: SocketAddr) -> Result<()> {
        if buf.len() <= 1 {
            return Ok(());
        }
        let packet = match Packet::read(
            buf,
            self.token.protocol_id,
            now,
            self.token.server_to_client_key,
            Some(&mut self.replay_protection),
            Self::ALLOWED_PACKETS,
        ) {
            Ok(packet) => packet,
            Err(Error::Crypto(_)) => {
                debug!("client ignored packet because it failed to decrypt");
                return Ok(());
            }
            // Datagrams arrive from the network, so a malformed or unexpected
            // one is routine, not a client fault; logging it at error level
            // would let any sender flood the error log.
            Err(e) => {
                debug!("client ignored packet: {e}");
                return Ok(());
            }
        };
        self.process_packet(addr, packet)
    }

    /// Drains every datagram currently available from `io`.
    ///
    /// `now` (whole seconds since the Unix epoch) is sampled once per call and
    /// passed to packet decoding.
    ///
    /// # Errors
    /// Propagates system-clock errors and the first I/O error from `io.recv`.
    fn recv_packets(&mut self, io: &mut Io) -> Result<()> {
        let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
        while let Some((buf, addr)) = io.recv().map_err(Error::from)? {
            self.recv_packet(buf, now, addr)?;
        }
        Ok(())
    }

    /// Starts (or restarts) the handshake with the current server address,
    /// which is the first address in the token after a reset.
    pub fn connect(&mut self) {
        self.reset_connection();
        self.set_state(ClientState::SendingConnectionRequest);
        info!(
            "client connecting to server {} [{}/{}]",
            self.server_addr(),
            self.server_addr_idx + 1,
            self.token.server_addresses.len()
        );
    }

    /// Same as [`Client::try_update`], but panics on error.
    ///
    /// # Panics
    /// Panics if receiving or sending packets fails.
    pub fn update(&mut self, delta_ms: f64, io: &mut Io) {
        self.try_update(delta_ms, io)
            .expect("send/recv error while updating client")
    }

    /// Advances the client clock by `delta_ms`, drains incoming packets, sends
    /// any due handshake/keep-alive packet, then applies timeouts and state
    /// transitions.
    ///
    /// NOTE(review): the clock is compared against second-based values
    /// (`packet_send_rate`, `token.timeout_seconds`, token timestamps), so the
    /// delta presumably has to be in seconds despite the `_ms` name — confirm
    /// with callers.
    ///
    /// # Errors
    /// Propagates system-clock errors and I/O or serialization errors from
    /// receiving and sending.
    pub fn try_update(&mut self, delta_ms: f64, io: &mut Io) -> Result<()> {
        self.time += delta_ms;
        self.recv_packets(io)?;
        self.send_packets(io)?;
        self.update_state();
        Ok(())
    }

    /// Pops the oldest payload received from the server, if any (FIFO order).
    pub fn recv(&mut self) -> Option<ReadWordBuffer> {
        self.packet_queue.pop_front()
    }

    /// Sends `buf` to the server as a payload packet.
    ///
    /// Does nothing and returns `Ok(())` unless the client is connected.
    ///
    /// # Errors
    /// [`Error::SizeMismatch`] if `buf` exceeds `MAX_PACKET_SIZE`. Serialization
    /// and I/O errors from sending are also propagated.
    pub fn send(&mut self, buf: &[u8], io: &mut Io) -> Result<()> {
        if self.state != ClientState::Connected {
            return Ok(());
        }
        if buf.len() > MAX_PACKET_SIZE {
            return Err(Error::SizeMismatch(MAX_PACKET_SIZE, buf.len()));
        }
        self.send_packet(PayloadPacket::create(buf), io)?;
        Ok(())
    }

    /// Disconnects from the server and resets into [`ClientState::Disconnected`].
    ///
    /// When the client is pending or connected, the configured number of
    /// disconnect packets is sent first. From `Disconnected` or an error state
    /// there is no server to notify, so the client is only reset.
    ///
    /// # Errors
    /// Returns the first send error. The client is reset even in that case, so
    /// it never stays connected after an explicit disconnect.
    pub fn disconnect(&mut self, io: &mut Io) -> Result<()> {
        let sent = if self.state > ClientState::Disconnected {
            debug!(
                "client sending {} disconnect packets to server",
                self.cfg.num_disconnect_packets
            );
            (0..self.cfg.num_disconnect_packets)
                .try_for_each(|_| self.send_packet(DisconnectPacket::create(), io))
        } else {
            Ok(())
        };
        self.reset(ClientState::Disconnected);
        sent
    }

    /// Current connection state.
    pub fn state(&self) -> ClientState {
        self.state
    }

    /// True for the terminal failure states (every variant declared before
    /// `Disconnected`).
    pub fn is_error(&self) -> bool {
        self.state < ClientState::Disconnected
    }

    /// True while the handshake is in progress.
    pub fn is_pending(&self) -> bool {
        matches!(
            self.state,
            ClientState::SendingConnectionRequest | ClientState::SendingChallengeResponse
        )
    }

    /// True once the handshake has completed.
    pub fn is_connected(&self) -> bool {
        self.state == ClientState::Connected
    }

    /// True when idle (not connecting, not connected, and not in an error state).
    pub fn is_disconnected(&self) -> bool {
        self.state == ClientState::Disconnected
    }
}