1use std::{error::Error, fmt, net::SocketAddr, time::Duration};
2
3use crate::{
4 packet::Packet, replay_protection::ReplayProtection, token::ConnectToken, NetcodeError, NETCODE_CHALLENGE_TOKEN_BYTES,
5 NETCODE_KEY_BYTES, NETCODE_MAX_PACKET_BYTES, NETCODE_MAX_PAYLOAD_BYTES, NETCODE_SEND_RATE, NETCODE_USER_DATA_BYTES,
6};
7
/// Why a client ended up in the disconnected state.
///
/// Obtained from `NetcodeClient::disconnect_reason` once the client is disconnected.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DisconnectReason {
    /// The handshake did not finish before the connect token's lifetime elapsed.
    ConnectTokenExpired,
    /// An established connection received nothing from the server within the token timeout.
    ConnectionTimedOut,
    /// The server stopped answering while the client was sending challenge responses.
    ConnectionResponseTimedOut,
    /// The server stopped answering while the client was sending connection requests.
    ConnectionRequestTimedOut,
    /// The server answered the handshake with a connection-denied packet.
    ConnectionDenied,
    /// The local application called `NetcodeClient::disconnect`.
    DisconnectedByClient,
    /// The server sent a disconnect packet while the client was connected.
    DisconnectedByServer,
}
19
20#[derive(Debug, PartialEq, Eq)]
21enum ClientState {
22 Disconnected(DisconnectReason),
23 SendingConnectionRequest,
24 SendingConnectionResponse,
25 Connected,
26}
27
28#[derive(Debug, Clone)]
30#[allow(clippy::large_enum_variant)]
31pub enum ClientAuthentication {
32 Secure { connect_token: ConnectToken },
36 Unsecure {
40 protocol_id: u64,
41 client_id: u64,
42 server_addr: SocketAddr,
43 user_data: Option<[u8; NETCODE_USER_DATA_BYTES]>,
44 },
45}
46
47#[derive(Debug)]
52pub struct NetcodeClient {
53 state: ClientState,
54 client_id: u64,
55 connect_start_time: Duration,
56 last_packet_send_time: Option<Duration>,
57 last_packet_received_time: Duration,
58 current_time: Duration,
59 sequence: u64,
60 server_addr: SocketAddr,
61 server_addr_index: usize,
62 connect_token: ConnectToken,
63 challenge_token_sequence: u64,
64 challenge_token_data: [u8; NETCODE_CHALLENGE_TOKEN_BYTES],
65 max_clients: u32,
66 client_index: u32,
67 send_rate: Duration,
68 replay_protection: ReplayProtection,
69 out: [u8; NETCODE_MAX_PACKET_BYTES],
70}
71
72impl fmt::Display for DisconnectReason {
73 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74 use DisconnectReason::*;
75
76 match *self {
77 ConnectTokenExpired => write!(f, "connection token has expired"),
78 ConnectionTimedOut => write!(f, "connection timed out"),
79 ConnectionResponseTimedOut => write!(f, "connection timed out during response step"),
80 ConnectionRequestTimedOut => write!(f, "connection timed out during request step"),
81 ConnectionDenied => write!(f, "server denied connection"),
82 DisconnectedByClient => write!(f, "connection terminated by client"),
83 DisconnectedByServer => write!(f, "connection terminated by server"),
84 }
85 }
86}
87
88impl Error for DisconnectReason {}
89
90impl NetcodeClient {
91 pub fn new(current_time: Duration, authentication: ClientAuthentication) -> Result<Self, NetcodeError> {
92 let connect_token: ConnectToken = match authentication {
93 ClientAuthentication::Unsecure {
94 server_addr,
95 protocol_id,
96 client_id,
97 user_data,
98 } => ConnectToken::generate(
99 current_time,
100 protocol_id,
101 300,
102 client_id,
103 15,
104 vec![server_addr],
105 user_data.as_ref(),
106 &[0; NETCODE_KEY_BYTES],
107 )?,
108 ClientAuthentication::Secure { connect_token } => connect_token,
109 };
110
111 let server_addr = connect_token.server_addresses[0].expect("cannot create or deserialize a ConnectToken without a server address");
112
113 Ok(Self {
114 sequence: 0,
115 client_id: connect_token.client_id,
116 server_addr,
117 server_addr_index: 0,
118 challenge_token_sequence: 0,
119 state: ClientState::SendingConnectionRequest,
120 connect_start_time: current_time,
121 last_packet_send_time: None,
122 last_packet_received_time: current_time,
123 current_time,
124 max_clients: 0,
125 client_index: 0,
126 send_rate: NETCODE_SEND_RATE,
127 challenge_token_data: [0u8; NETCODE_CHALLENGE_TOKEN_BYTES],
128 connect_token,
129 replay_protection: ReplayProtection::new(),
130 out: [0u8; NETCODE_MAX_PACKET_BYTES],
131 })
132 }
133
134 pub fn is_connecting(&self) -> bool {
135 matches!(
136 self.state,
137 ClientState::SendingConnectionRequest | ClientState::SendingConnectionResponse
138 )
139 }
140
141 pub fn is_connected(&self) -> bool {
142 self.state == ClientState::Connected
143 }
144
145 pub fn is_disconnected(&self) -> bool {
146 matches!(self.state, ClientState::Disconnected(_))
147 }
148
149 pub fn current_time(&self) -> Duration {
150 self.current_time
151 }
152
153 pub fn client_id(&self) -> u64 {
154 self.client_id
155 }
156
157 pub fn time_since_last_received_packet(&self) -> Duration {
160 self.current_time - self.last_packet_received_time
161 }
162
163 pub fn disconnect_reason(&self) -> Option<DisconnectReason> {
165 if let ClientState::Disconnected(reason) = &self.state {
166 return Some(*reason);
167 }
168 None
169 }
170
171 pub fn server_addr(&self) -> SocketAddr {
173 self.server_addr
174 }
175
176 pub fn disconnect(&mut self) -> Result<(SocketAddr, &mut [u8]), NetcodeError> {
179 self.state = ClientState::Disconnected(DisconnectReason::DisconnectedByClient);
180 let packet = Packet::Disconnect;
181 let len = packet.encode(
182 &mut self.out,
183 self.connect_token.protocol_id,
184 Some((self.sequence, &self.connect_token.client_to_server_key)),
185 )?;
186
187 Ok((self.server_addr, &mut self.out[..len]))
188 }
189
190 pub fn process_packet<'a>(&mut self, buffer: &'a mut [u8]) -> Option<&'a [u8]> {
194 let packet = match Packet::decode(
195 buffer,
196 self.connect_token.protocol_id,
197 Some(&self.connect_token.server_to_client_key),
198 Some(&mut self.replay_protection),
199 ) {
200 Ok((_, packet)) => packet,
201 Err(e) => {
202 log::error!("Failed to decode packet: {}", e);
203 return None;
204 }
205 };
206 log::trace!("Received packet from server: {:?}", packet.packet_type());
207
208 match (packet, &self.state) {
209 (Packet::ConnectionDenied, ClientState::SendingConnectionRequest | ClientState::SendingConnectionResponse) => {
210 self.state = ClientState::Disconnected(DisconnectReason::ConnectionDenied);
211 self.last_packet_received_time = self.current_time;
212 }
213 (
214 Packet::Challenge {
215 token_data,
216 token_sequence,
217 },
218 ClientState::SendingConnectionRequest,
219 ) => {
220 self.challenge_token_sequence = token_sequence;
221 self.last_packet_received_time = self.current_time;
222 self.last_packet_send_time = None;
223 self.challenge_token_data = token_data;
224 self.state = ClientState::SendingConnectionResponse;
225 }
226 (Packet::KeepAlive { .. }, ClientState::Connected) => {
227 self.last_packet_received_time = self.current_time;
228 }
229 (Packet::KeepAlive { client_index, max_clients }, ClientState::SendingConnectionResponse) => {
230 self.last_packet_received_time = self.current_time;
231 self.max_clients = max_clients;
232 self.client_index = client_index;
233 self.state = ClientState::Connected;
234 }
235 (Packet::Payload(p), ClientState::Connected) => {
236 self.last_packet_received_time = self.current_time;
237 return Some(p);
238 }
239 (Packet::Disconnect, ClientState::Connected) => {
240 self.state = ClientState::Disconnected(DisconnectReason::DisconnectedByServer);
241 self.last_packet_received_time = self.current_time;
242 }
243 _ => {}
244 }
245
246 None
247 }
248
249 pub fn generate_payload_packet(&mut self, payload: &[u8]) -> Result<(SocketAddr, &mut [u8]), NetcodeError> {
251 if payload.len() > NETCODE_MAX_PAYLOAD_BYTES {
252 return Err(NetcodeError::PayloadAboveLimit);
253 }
254
255 if self.state != ClientState::Connected {
256 return Err(NetcodeError::ClientNotConnected);
257 }
258
259 let packet = Packet::Payload(payload);
260 let len = packet.encode(
261 &mut self.out,
262 self.connect_token.protocol_id,
263 Some((self.sequence, &self.connect_token.client_to_server_key)),
264 )?;
265 self.sequence += 1;
266 self.last_packet_send_time = Some(self.current_time);
267
268 Ok((self.server_addr, &mut self.out[..len]))
269 }
270
271 pub fn update(&mut self, duration: Duration) -> Option<(&mut [u8], SocketAddr)> {
274 if let Err(e) = self.update_internal_state(duration) {
275 log::error!("Failed to update client: {}", e);
276 return None;
277 }
278
279 self.generate_packet()
281 }
282
283 fn update_internal_state(&mut self, duration: Duration) -> Result<(), NetcodeError> {
284 self.current_time += duration;
285 let connection_timed_out = self.connect_token.timeout_seconds > 0
286 && (self.last_packet_received_time + Duration::from_secs(self.connect_token.timeout_seconds as u64) < self.current_time);
287
288 match self.state {
289 ClientState::SendingConnectionRequest | ClientState::SendingConnectionResponse => {
290 let expire_seconds = self.connect_token.expire_timestamp - self.connect_token.create_timestamp;
291 let connection_expired = (self.current_time - self.connect_start_time).as_secs() >= expire_seconds;
292 if connection_expired {
293 self.state = ClientState::Disconnected(DisconnectReason::ConnectTokenExpired);
294 return Err(NetcodeError::Expired);
295 }
296 if connection_timed_out {
297 let reason = if self.state == ClientState::SendingConnectionResponse {
298 DisconnectReason::ConnectionResponseTimedOut
299 } else {
300 DisconnectReason::ConnectionRequestTimedOut
301 };
302 self.state = ClientState::Disconnected(reason);
303 self.server_addr_index += 1;
305 if self.server_addr_index >= 32 {
306 return Err(NetcodeError::NoMoreServers);
307 }
308 match self.connect_token.server_addresses[self.server_addr_index] {
309 None => return Err(NetcodeError::NoMoreServers),
310 Some(server_address) => {
311 self.state = ClientState::SendingConnectionRequest;
312 self.server_addr = server_address;
313 self.connect_start_time = self.current_time;
314 self.last_packet_send_time = None;
315 self.last_packet_received_time = self.current_time;
316 self.challenge_token_sequence = 0;
317
318 return Ok(());
319 }
320 }
321 }
322 Ok(())
323 }
324 ClientState::Connected => {
325 if connection_timed_out {
326 self.state = ClientState::Disconnected(DisconnectReason::ConnectionTimedOut);
327 return Err(NetcodeError::Disconnected(DisconnectReason::ConnectionTimedOut));
328 }
329
330 Ok(())
331 }
332 ClientState::Disconnected(reason) => Err(NetcodeError::Disconnected(reason)),
333 }
334 }
335
336 fn generate_packet(&mut self) -> Option<(&mut [u8], SocketAddr)> {
337 if let Some(last_packet_send_time) = self.last_packet_send_time {
338 if self.current_time - last_packet_send_time < self.send_rate {
339 return None;
340 }
341 }
342
343 if matches!(
344 self.state,
345 ClientState::Connected | ClientState::SendingConnectionRequest | ClientState::SendingConnectionResponse
346 ) {
347 self.last_packet_send_time = Some(self.current_time);
348 }
349 let packet = match self.state {
350 ClientState::SendingConnectionRequest => Packet::connection_request_from_token(&self.connect_token),
351 ClientState::SendingConnectionResponse => Packet::Response {
352 token_sequence: self.challenge_token_sequence,
353 token_data: self.challenge_token_data,
354 },
355 ClientState::Connected => Packet::KeepAlive {
356 client_index: 0,
357 max_clients: 0,
358 },
359 _ => return None,
360 };
361
362 let result = packet.encode(
363 &mut self.out,
364 self.connect_token.protocol_id,
365 Some((self.sequence, &self.connect_token.client_to_server_key)),
366 );
367 match result {
368 Err(_) => None,
369 Ok(encoded) => {
370 self.sequence += 1;
371 Some((&mut self.out[..encoded], self.server_addr))
372 }
373 }
374 }
375}
376
#[cfg(test)]
mod tests {
    use crate::{crypto::generate_random_bytes, NETCODE_MAX_PACKET_BYTES};

    use super::*;

    /// Walks a client through the whole handshake (request, challenge, response,
    /// keep-alive), then exchanges one payload in each direction.
    #[test]
    fn client_connection() {
        let mut scratch = [0u8; NETCODE_MAX_PACKET_BYTES];

        // Connect token parameters.
        let protocol_id = 2;
        let expire_seconds = 3;
        let client_id = 4;
        let timeout_seconds = 5;
        let private_key = b"an example very very secret key.";
        let servers: Vec<SocketAddr> = vec!["127.0.0.1:8080".parse().unwrap(), "127.0.0.2:3000".parse().unwrap()];
        let token_user_data = generate_random_bytes();

        let connect_token = ConnectToken::generate(
            Duration::ZERO,
            protocol_id,
            expire_seconds,
            client_id,
            timeout_seconds,
            servers,
            Some(&token_user_data),
            private_key,
        )
        .unwrap();
        let server_to_client = connect_token.server_to_client_key;
        let client_to_server = connect_token.client_to_server_key;
        let mut client = NetcodeClient::new(Duration::ZERO, ClientAuthentication::Secure { connect_token }).unwrap();

        // Step 1: the first update emits a connection request with sequence 0 (decoded without a key).
        {
            let (request, _) = client.update(Duration::ZERO).unwrap();
            let (sequence, decoded) = Packet::decode(request, protocol_id, None, None).unwrap();
            assert_eq!(0, sequence);
            assert!(matches!(decoded, Packet::ConnectionRequest { .. }));
        }

        // Step 2: a server challenge moves the client into the response phase.
        let challenge_sequence = 7;
        let challenge_user_data = generate_random_bytes();
        let challenge_key = generate_random_bytes();
        let challenge = Packet::generate_challenge(client_id, &challenge_user_data, challenge_sequence, &challenge_key).unwrap();
        let len = challenge.encode(&mut scratch, protocol_id, Some((0, &server_to_client))).unwrap();
        client.process_packet(&mut scratch[..len]);
        assert_eq!(ClientState::SendingConnectionResponse, client.state);

        // Step 3: the next update emits an encrypted response packet.
        {
            let (response, _) = client.update(Duration::ZERO).unwrap();
            let (_, decoded) = Packet::decode(response, protocol_id, Some(&client_to_server), None).unwrap();
            assert!(matches!(decoded, Packet::Response { .. }));
        }

        // Step 4: a keep-alive from the server completes the handshake.
        let keep_alive = Packet::KeepAlive {
            max_clients: 4,
            client_index: 2,
        };
        let len = keep_alive.encode(&mut scratch, protocol_id, Some((1, &server_to_client))).unwrap();
        client.process_packet(&mut scratch[..len]);
        assert_eq!(client.state, ClientState::Connected);

        // Step 5: server -> client payload.
        let incoming = vec![7u8; 500];
        let len = Packet::Payload(&incoming[..])
            .encode(&mut scratch, protocol_id, Some((2, &server_to_client)))
            .unwrap();
        let received = client.process_packet(&mut scratch[..len]).unwrap();
        assert_eq!(incoming, received);

        // Step 6: client -> server payload.
        let outgoing = vec![5u8; 1000];
        let (_, encrypted) = client.generate_payload_packet(&outgoing).unwrap();
        let (_, decoded) = Packet::decode(encrypted, protocol_id, Some(&client_to_server), None).unwrap();
        if let Packet::Payload(body) = decoded {
            assert_eq!(outgoing, body);
        } else {
            unreachable!();
        }
    }
}