1use std::{
2 collections::VecDeque,
3 net::{Ipv4Addr, SocketAddr},
4 time::{SystemTime, UNIX_EPOCH},
5};
6
7use crate::{
8 bytes::Bytes,
9 error::{Error, Result},
10 packet::{
11 DisconnectPacket, KeepAlivePacket, Packet, PayloadPacket, RequestPacket, ResponsePacket,
12 },
13 replay::ReplayProtection,
14 socket::NetcodeSocket,
15 token::{ChallengeToken, ConnectToken},
16 transceiver::Transceiver,
17 MAX_PACKET_SIZE, MAX_PKT_BUF_SIZE, PACKET_SEND_RATE_SEC,
18};
19
/// Receive-buffer size, in bytes, passed to `NetcodeSocket::new` for clients
/// created through `Client::new` / `Client::with_config` (256 KiB).
const RECV_BUF_SIZE: usize = 1024 * 256;
/// Send-buffer size, in bytes, passed to `NetcodeSocket::new` for clients
/// created through `Client::new` / `Client::with_config` (256 KiB).
const SEND_BUF_SIZE: usize = 1024 * 256;
22
23type Callback<Ctx> = Box<dyn FnMut(ClientState, ClientState, &mut Ctx) + Send + Sync + 'static>;
24pub struct ClientConfig<Ctx> {
52 num_disconnect_packets: usize,
53 packet_send_rate: f64,
54 context: Ctx,
55 on_state_change: Option<Callback<Ctx>>,
56}
57
58impl Default for ClientConfig<()> {
59 fn default() -> Self {
60 Self {
61 num_disconnect_packets: 10,
62 packet_send_rate: PACKET_SEND_RATE_SEC,
63 context: (),
64 on_state_change: None,
65 }
66 }
67}
68
69impl<Ctx> ClientConfig<Ctx> {
70 pub fn new() -> ClientConfig<()> {
72 ClientConfig::<()>::default()
73 }
74 pub fn with_context(ctx: Ctx) -> Self {
76 Self {
77 num_disconnect_packets: 10,
78 packet_send_rate: PACKET_SEND_RATE_SEC,
79 context: ctx,
80 on_state_change: None,
81 }
82 }
83 pub fn num_disconnect_packets(mut self, num_disconnect_packets: usize) -> Self {
86 self.num_disconnect_packets = num_disconnect_packets;
87 self
88 }
89 pub fn packet_send_rate(mut self, rate_seconds: f64) -> Self {
92 self.packet_send_rate = rate_seconds;
93 self
94 }
95 pub fn on_state_change<F>(mut self, cb: F) -> Self
97 where
98 F: FnMut(ClientState, ClientState, &mut Ctx) + Send + Sync + 'static,
99 {
100 self.on_state_change = Some(Box::new(cb));
101 self
102 }
103}
104
/// Connection state of a [`Client`].
///
/// Declaration order matters: the derived `Ord` places every error state
/// before `Disconnected`, and `Client::is_error` relies on that ordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ClientState {
    /// The connect token expired while the handshake was still in progress.
    ConnectTokenExpired,
    /// The client was connected, but nothing arrived from the server within
    /// the token's timeout.
    ConnectionTimedOut,
    /// No reply to the connection request arrived within the timeout, and no
    /// further server address was left to try.
    ConnectionRequestTimedOut,
    /// No reply to the challenge response arrived within the timeout, and no
    /// further server address was left to try.
    ChallengeResponseTimedOut,
    /// The server answered the handshake with a denied packet.
    ConnectionDenied,
    /// Idle: neither connected nor trying to connect.
    Disconnected,
    /// Sending connection requests while waiting for a challenge.
    SendingConnectionRequest,
    /// Sending challenge responses while waiting for the first keep-alive.
    SendingChallengeResponse,
    /// Connected; keep-alives and payloads are being exchanged.
    Connected,
}
151
152pub struct Client<T: Transceiver, Ctx = ()> {
185 transceiver: T,
186 state: ClientState,
187 time: f64,
188 start_time: f64,
189 last_send_time: f64,
190 last_receive_time: f64,
191 server_addr_idx: usize,
192 sequence: u64,
193 challenge_token_sequence: u64,
194 challenge_token_data: [u8; ChallengeToken::SIZE],
195 client_index: i32,
196 max_clients: i32,
197 token: ConnectToken,
198 replay_protection: ReplayProtection,
199 should_disconnect: bool,
200 should_disconnect_state: ClientState,
201 packet_queue: VecDeque<Vec<u8>>,
202 cfg: ClientConfig<Ctx>,
203}
204
205impl<Trx: Transceiver, Ctx> Client<Trx, Ctx> {
206 fn from_token(token_bytes: &[u8], cfg: ClientConfig<Ctx>, trx: Trx) -> Result<Self> {
207 if token_bytes.len() != ConnectToken::SIZE {
208 return Err(Error::SizeMismatch(ConnectToken::SIZE, token_bytes.len()));
209 }
210 let mut buf = [0u8; ConnectToken::SIZE];
211 buf.copy_from_slice(token_bytes);
212 let mut cursor = std::io::Cursor::new(&mut buf[..]);
213 let token = match ConnectToken::read_from(&mut cursor) {
214 Ok(token) => token,
215 Err(err) => {
216 log::error!("invalid connect token: {err}");
217 return Err(Error::InvalidToken(err));
218 }
219 };
220 log::info!("client started on {}", trx.addr());
221 Ok(Self {
222 transceiver: trx,
223 state: ClientState::Disconnected,
224 time: 0.0,
225 start_time: 0.0,
226 last_send_time: f64::NEG_INFINITY,
227 last_receive_time: f64::NEG_INFINITY,
228 server_addr_idx: 0,
229 sequence: 0,
230 challenge_token_sequence: 0,
231 challenge_token_data: [0u8; ChallengeToken::SIZE],
232 client_index: 0,
233 max_clients: 0,
234 token,
235 replay_protection: ReplayProtection::new(),
236 should_disconnect: false,
237 should_disconnect_state: ClientState::Disconnected,
238 packet_queue: VecDeque::new(),
239 cfg,
240 })
241 }
242}
243
244impl Client<NetcodeSocket> {
245 pub fn new(token_bytes: &[u8]) -> Result<Self> {
261 let netcode_sock =
262 NetcodeSocket::new((Ipv4Addr::UNSPECIFIED, 0), SEND_BUF_SIZE, RECV_BUF_SIZE)?;
263 Client::from_token(token_bytes, ClientConfig::default(), netcode_sock)
264 }
265}
266
267impl<Ctx> Client<NetcodeSocket, Ctx> {
268 pub fn with_config(token_bytes: &[u8], cfg: ClientConfig<Ctx>) -> Result<Self> {
290 let netcode_sock =
291 NetcodeSocket::new((Ipv4Addr::UNSPECIFIED, 0), SEND_BUF_SIZE, RECV_BUF_SIZE)?;
292 Client::from_token(token_bytes, cfg, netcode_sock)
293 }
294}
295
296impl<T: Transceiver, Ctx> Client<T, Ctx> {
297 const ALLOWED_PACKETS: u8 = 1 << Packet::DENIED
298 | 1 << Packet::CHALLENGE
299 | 1 << Packet::KEEP_ALIVE
300 | 1 << Packet::PAYLOAD
301 | 1 << Packet::DISCONNECT;
302
303 fn set_state(&mut self, state: ClientState) {
304 log::debug!("client state changing from {:?} to {:?}", self.state, state);
305 if let Some(ref mut cb) = self.cfg.on_state_change {
306 cb(self.state, state, &mut self.cfg.context)
307 }
308 self.state = state;
309 }
310 fn reset_connection(&mut self) {
311 self.start_time = self.time;
312 self.last_send_time = self.time - 1.0; self.last_receive_time = self.time;
314 self.should_disconnect = false;
315 self.should_disconnect_state = ClientState::Disconnected;
316 self.challenge_token_sequence = 0;
317 self.replay_protection = ReplayProtection::new();
318 }
319 fn reset(&mut self, new_state: ClientState) {
320 self.sequence = 0;
321 self.client_index = 0;
322 self.max_clients = 0;
323 self.start_time = 0.0;
324 self.server_addr_idx = 0;
325 self.set_state(new_state);
326 self.reset_connection();
327 log::debug!("client disconnected");
328 }
329 fn send_packets(&mut self) -> Result<()> {
330 if self.last_send_time + self.cfg.packet_send_rate >= self.time {
331 return Ok(());
332 }
333 let packet = match self.state {
334 ClientState::SendingConnectionRequest => {
335 log::debug!("client sending connection request packet to server");
336 RequestPacket::create(
337 self.token.protocol_id,
338 self.token.expire_timestamp,
339 self.token.nonce,
340 self.token.private_data,
341 )
342 }
343 ClientState::SendingChallengeResponse => {
344 log::debug!("client sending connection response packet to server");
345 ResponsePacket::create(self.challenge_token_sequence, self.challenge_token_data)
346 }
347 ClientState::Connected => {
348 log::trace!("client sending connection keep-alive packet to server");
349 KeepAlivePacket::create(0, 0)
350 }
351 _ => return Ok(()),
352 };
353 self.send_packet(packet)
354 }
355 fn connect_to_next_server(&mut self) -> std::result::Result<(), ()> {
356 if self.server_addr_idx + 1 >= self.token.server_addresses.len() {
357 log::debug!("no more servers to connect to");
358 return Err(());
359 }
360 self.server_addr_idx += 1;
361 self.connect();
362 Ok(())
363 }
364 fn send_packet(&mut self, packet: Packet) -> Result<()> {
365 let mut buf = [0u8; MAX_PKT_BUF_SIZE];
366 let size = packet.write(
367 &mut buf,
368 self.sequence,
369 &self.token.client_to_server_key,
370 self.token.protocol_id,
371 )?;
372 let server_addr = self.token.server_addresses[self.server_addr_idx];
373 self.transceiver
374 .send(&buf[..size], server_addr)
375 .map_err(|e| e.into())?;
376 self.last_send_time = self.time;
377 self.sequence += 1;
378 Ok(())
379 }
380 fn process_packet(&mut self, addr: SocketAddr, packet: Packet) -> Result<()> {
381 if addr != self.token.server_addresses[self.server_addr_idx] {
382 return Ok(());
383 }
384 match (packet, self.state) {
385 (
386 Packet::Denied(_),
387 ClientState::SendingConnectionRequest | ClientState::SendingChallengeResponse,
388 ) => {
389 self.should_disconnect = true;
390 self.should_disconnect_state = ClientState::ConnectionDenied;
391 }
392 (Packet::Challenge(pkt), ClientState::SendingConnectionRequest) => {
393 log::debug!("client received connection challenge packet from server");
394 self.challenge_token_sequence = pkt.sequence;
395 self.challenge_token_data = pkt.token;
396 self.set_state(ClientState::SendingChallengeResponse);
397 }
398 (Packet::KeepAlive(_), ClientState::Connected) => {
399 log::trace!("client received connection keep-alive packet from server");
400 }
401 (Packet::KeepAlive(pkt), ClientState::SendingChallengeResponse) => {
402 log::debug!("client received connection keep-alive packet from server");
403 self.client_index = pkt.client_index;
404 self.max_clients = pkt.max_clients;
405 self.set_state(ClientState::Connected);
406 log::info!("client connected to server");
407 }
408 (Packet::Payload(pkt), ClientState::Connected) => {
409 log::debug!("client received payload packet from server");
410 self.packet_queue.push_back(pkt.buf.to_vec());
411 }
412 (Packet::Disconnect(_), ClientState::Connected) => {
413 log::debug!("client received disconnect packet from server");
414 self.should_disconnect = true;
415 self.should_disconnect_state = ClientState::Disconnected;
416 }
417 _ => return Ok(()),
418 }
419 self.last_receive_time = self.time;
420 Ok(())
421 }
422 fn update_state(&mut self) {
423 let is_token_expired = self.time - self.start_time
424 >= self.token.expire_timestamp as f64 - self.token.create_timestamp as f64;
425 let is_connection_timed_out = self.token.timeout_seconds.is_positive()
426 && (self.last_receive_time + (self.token.timeout_seconds as f64) < self.time);
427 let new_state = match self.state {
428 ClientState::SendingConnectionRequest | ClientState::SendingChallengeResponse
429 if is_token_expired =>
430 {
431 log::info!("client connect failed. connect token expired");
432 ClientState::ConnectTokenExpired
433 }
434 _ if self.should_disconnect => {
435 log::debug!(
436 "client should disconnect -> {:?}",
437 self.should_disconnect_state
438 );
439 if self.connect_to_next_server().is_ok() {
440 return;
441 };
442 self.should_disconnect_state
443 }
444 ClientState::SendingConnectionRequest if is_connection_timed_out => {
445 log::info!("client connect failed. connection request timed out");
446 if self.connect_to_next_server().is_ok() {
447 return;
448 };
449 ClientState::ConnectionRequestTimedOut
450 }
451 ClientState::SendingChallengeResponse if is_connection_timed_out => {
452 log::info!("client connect failed. connection response timed out");
453 if self.connect_to_next_server().is_ok() {
454 return;
455 };
456 ClientState::ChallengeResponseTimedOut
457 }
458 ClientState::Connected if is_connection_timed_out => {
459 log::info!("client connection timed out");
460 ClientState::ConnectionTimedOut
461 }
462 _ => return,
463 };
464 self.reset(new_state);
465 }
466 fn recv_packet(&mut self, buf: &mut [u8], now: u64, addr: SocketAddr) -> Result<()> {
467 if buf.len() <= 1 {
468 return Ok(());
470 }
471 let packet = match Packet::read(
472 buf,
473 self.token.protocol_id,
474 now,
475 self.token.server_to_client_key,
476 Some(&mut self.replay_protection),
477 Self::ALLOWED_PACKETS,
478 ) {
479 Ok(packet) => packet,
480 Err(Error::Crypto(_)) => {
481 log::debug!("client ignored packet because it failed to decrypt");
482 return Ok(());
483 }
484 Err(e) => {
485 log::error!("client ignored packet: {e}");
486 return Ok(());
487 }
488 };
489 self.process_packet(addr, packet)
490 }
491 fn recv_packets(&mut self) -> Result<()> {
492 let mut buf = [0u8; MAX_PACKET_SIZE];
493 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
494 while let Some((size, addr)) = self.transceiver.recv(&mut buf).map_err(|e| e.into())? {
495 self.recv_packet(&mut buf[..size], now, addr)?;
496 }
497 Ok(())
498 }
499 pub fn with_config_and_transceiver(
533 token_bytes: &[u8],
534 cfg: ClientConfig<Ctx>,
535 trx: T,
536 ) -> Result<Self> {
537 Client::from_token(token_bytes, cfg, trx)
538 }
539 pub fn connect(&mut self) {
543 self.reset_connection();
544 self.set_state(ClientState::SendingConnectionRequest);
545 log::info!(
546 "client connecting to server {} [{}/{}]",
547 self.token.server_addresses[self.server_addr_idx],
548 self.server_addr_idx + 1,
549 self.token.server_addresses.len()
550 );
551 }
552 pub fn update(&mut self, time: f64) {
565 self.try_update(time)
566 .expect("send/recv error while updating client")
567 }
568 pub fn try_update(&mut self, time: f64) -> Result<()> {
572 self.time = time;
573 self.recv_packets()?;
574 self.send_packets()?;
575 self.update_state();
576 Ok(())
577 }
578 pub fn recv(&mut self) -> Option<Vec<u8>> {
606 self.packet_queue.pop_front()
607 }
608 pub fn send(&mut self, buf: &[u8]) -> Result<()> {
612 if self.state != ClientState::Connected {
613 return Ok(());
614 }
615 if buf.len() > MAX_PACKET_SIZE {
616 return Err(Error::SizeMismatch(MAX_PACKET_SIZE, buf.len()));
617 }
618 self.send_packet(PayloadPacket::create(buf))?;
619 Ok(())
620 }
621 pub fn disconnect(&mut self) -> Result<()> {
625 log::debug!(
626 "client sending {} disconnect packets to server",
627 self.cfg.num_disconnect_packets
628 );
629 for _ in 0..self.cfg.num_disconnect_packets {
630 self.send_packet(DisconnectPacket::create())?;
631 }
632 self.reset(ClientState::Disconnected);
633 Ok(())
634 }
635 pub fn addr(&self) -> SocketAddr {
637 self.transceiver.addr()
638 }
639 pub fn state(&self) -> ClientState {
641 self.state
642 }
643 pub fn is_error(&self) -> bool {
645 self.state < ClientState::Disconnected
646 }
647 pub fn is_pending(&self) -> bool {
649 self.state == ClientState::SendingConnectionRequest
650 || self.state == ClientState::SendingChallengeResponse
651 }
652 pub fn is_connected(&self) -> bool {
654 self.state == ClientState::Connected
655 }
656 pub fn is_disconnected(&self) -> bool {
658 self.state == ClientState::Disconnected
659 }
660}
661
#[cfg(test)]
mod tests {
    use byteorder::{LittleEndian, WriteBytesExt};
    use chacha20poly1305::XNonce;

    use super::*;
    use crate::simulator::NetworkSimulator;
    use crate::token::ConnectTokenPrivate;
    use crate::{InvalidTokenError, NETCODE_VERSION};
    use std::io::{Cursor, Write};
    use std::mem::size_of;

    impl Client<NetworkSimulator> {
        /// Builds a client with the default configuration on top of a
        /// simulated network.
        pub(crate) fn with_simulator(token: ConnectToken, sim: NetworkSimulator) -> Result<Self> {
            let token_bytes = token.try_into_bytes()?;
            Client::with_config_and_transceiver(&token_bytes, ClientConfig::default(), sim)
        }
    }

    /// Returns a zeroed, token-sized buffer after letting `fill` write a
    /// prefix into it through a cursor.
    fn token_bytes_with(fill: impl FnOnce(&mut Cursor<&mut [u8]>)) -> [u8; ConnectToken::SIZE] {
        let mut bytes = [0u8; ConnectToken::SIZE];
        fill(&mut Cursor::new(&mut bytes[..]));
        bytes
    }

    #[test]
    fn invalid_connect_token() {
        // A version string the parser does not accept.
        let token_bytes = token_bytes_with(|w| w.write_all(b"NETCODE VERSION 1.00\0").unwrap());
        let res = Client::new(&token_bytes);
        assert!(matches!(
            res,
            Err(Error::InvalidToken(InvalidTokenError::InvalidVersion))
        ));

        // Correct version; everything after it stays zeroed, so the server
        // address list is empty.
        let token_bytes = token_bytes_with(|w| w.write_all(NETCODE_VERSION).unwrap());
        let res = Client::new(&token_bytes);
        assert!(matches!(
            res,
            Err(Error::InvalidToken(InvalidTokenError::AddressListLength(0)))
        ));

        // Presumably protocol id, create timestamp (2) and expire timestamp
        // (1): a token that expires before it was created.
        let token_bytes = token_bytes_with(|w| {
            w.write_all(NETCODE_VERSION).unwrap();
            w.write_u64::<LittleEndian>(0).unwrap();
            w.write_u64::<LittleEndian>(2).unwrap();
            w.write_u64::<LittleEndian>(1).unwrap();
        });
        let res = Client::new(&token_bytes);
        assert!(matches!(
            res,
            Err(Error::InvalidToken(InvalidTokenError::InvalidTimestamp))
        ));

        // Zeroed header fields, nonce and private data, then one server
        // address whose type tag (3) is not recognised.
        let token_bytes = token_bytes_with(|w| {
            w.write_all(NETCODE_VERSION).unwrap();
            w.write_u64::<LittleEndian>(0).unwrap();
            w.write_u64::<LittleEndian>(0).unwrap();
            w.write_u64::<LittleEndian>(0).unwrap();
            w.write_all(&[0; size_of::<XNonce>()]).unwrap();
            w.write_all(&[0; ConnectTokenPrivate::SIZE]).unwrap();
            w.write_i32::<LittleEndian>(0).unwrap();
            w.write_u32::<LittleEndian>(1).unwrap();
            w.write_u8(3).unwrap();
        });
        let res = Client::new(&token_bytes);
        assert!(matches!(
            res,
            Err(Error::InvalidToken(
                InvalidTokenError::InvalidIpAddressType(3)
            ))
        ));
    }
}