renetcode/
client.rs

1use std::{error::Error, fmt, net::SocketAddr, time::Duration};
2
3use crate::{
4    packet::Packet, replay_protection::ReplayProtection, token::ConnectToken, NetcodeError, NETCODE_CHALLENGE_TOKEN_BYTES,
5    NETCODE_KEY_BYTES, NETCODE_MAX_PACKET_BYTES, NETCODE_MAX_PAYLOAD_BYTES, NETCODE_SEND_RATE, NETCODE_USER_DATA_BYTES,
6};
7
/// Why a client ended up in the disconnected (error) state.
///
/// The value is reported by `NetcodeClient::disconnect_reason` once the client is no longer
/// connecting or connected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DisconnectReason {
    /// The connect token lifetime elapsed before the handshake completed.
    ConnectTokenExpired,
    /// No packet arrived within the token timeout while the client was connected.
    ConnectionTimedOut,
    /// No packet arrived within the token timeout while the client was sending its
    /// connection response.
    ConnectionResponseTimedOut,
    /// No packet arrived within the token timeout while the client was sending its
    /// connection request.
    ConnectionRequestTimedOut,
    /// The server answered the handshake with a connection-denied packet.
    ConnectionDenied,
    /// The client itself asked to disconnect.
    DisconnectedByClient,
    /// The server sent a disconnect packet to a connected client.
    DisconnectedByServer,
}
19
20#[derive(Debug, PartialEq, Eq)]
21enum ClientState {
22    Disconnected(DisconnectReason),
23    SendingConnectionRequest,
24    SendingConnectionResponse,
25    Connected,
26}
27
28/// Configuration to establish a secure or unsecure connection with the server.
29#[derive(Debug, Clone)]
30#[allow(clippy::large_enum_variant)]
31pub enum ClientAuthentication {
32    /// Establishes a safe connection with the server using the [crate::ConnectToken].
33    ///
34    /// See also [crate::ServerAuthentication::Secure]
35    Secure { connect_token: ConnectToken },
36    /// Establishes an unsafe connection with the server, useful for testing and prototyping.
37    ///
38    /// See also [crate::ServerAuthentication::Unsecure]
39    Unsecure {
40        protocol_id: u64,
41        client_id: u64,
42        server_addr: SocketAddr,
43        user_data: Option<[u8; NETCODE_USER_DATA_BYTES]>,
44    },
45}
46
47/// A client that can generate encrypted packets that be sent to the connected server, or consume
48/// encrypted packets from the server.
49/// The client is agnostic from the transport layer, only consuming and generating bytes
50/// that can be transported in any way desired.
51#[derive(Debug)]
52pub struct NetcodeClient {
53    state: ClientState,
54    client_id: u64,
55    connect_start_time: Duration,
56    last_packet_send_time: Option<Duration>,
57    last_packet_received_time: Duration,
58    current_time: Duration,
59    sequence: u64,
60    server_addr: SocketAddr,
61    server_addr_index: usize,
62    connect_token: ConnectToken,
63    challenge_token_sequence: u64,
64    challenge_token_data: [u8; NETCODE_CHALLENGE_TOKEN_BYTES],
65    max_clients: u32,
66    client_index: u32,
67    send_rate: Duration,
68    replay_protection: ReplayProtection,
69    out: [u8; NETCODE_MAX_PACKET_BYTES],
70}
71
72impl fmt::Display for DisconnectReason {
73    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74        use DisconnectReason::*;
75
76        match *self {
77            ConnectTokenExpired => write!(f, "connection token has expired"),
78            ConnectionTimedOut => write!(f, "connection timed out"),
79            ConnectionResponseTimedOut => write!(f, "connection timed out during response step"),
80            ConnectionRequestTimedOut => write!(f, "connection timed out during request step"),
81            ConnectionDenied => write!(f, "server denied connection"),
82            DisconnectedByClient => write!(f, "connection terminated by client"),
83            DisconnectedByServer => write!(f, "connection terminated by server"),
84        }
85    }
86}
87
88impl Error for DisconnectReason {}
89
90impl NetcodeClient {
91    pub fn new(current_time: Duration, authentication: ClientAuthentication) -> Result<Self, NetcodeError> {
92        let connect_token: ConnectToken = match authentication {
93            ClientAuthentication::Unsecure {
94                server_addr,
95                protocol_id,
96                client_id,
97                user_data,
98            } => ConnectToken::generate(
99                current_time,
100                protocol_id,
101                300,
102                client_id,
103                15,
104                vec![server_addr],
105                user_data.as_ref(),
106                &[0; NETCODE_KEY_BYTES],
107            )?,
108            ClientAuthentication::Secure { connect_token } => connect_token,
109        };
110
111        let server_addr = connect_token.server_addresses[0].expect("cannot create or deserialize a ConnectToken without a server address");
112
113        Ok(Self {
114            sequence: 0,
115            client_id: connect_token.client_id,
116            server_addr,
117            server_addr_index: 0,
118            challenge_token_sequence: 0,
119            state: ClientState::SendingConnectionRequest,
120            connect_start_time: current_time,
121            last_packet_send_time: None,
122            last_packet_received_time: current_time,
123            current_time,
124            max_clients: 0,
125            client_index: 0,
126            send_rate: NETCODE_SEND_RATE,
127            challenge_token_data: [0u8; NETCODE_CHALLENGE_TOKEN_BYTES],
128            connect_token,
129            replay_protection: ReplayProtection::new(),
130            out: [0u8; NETCODE_MAX_PACKET_BYTES],
131        })
132    }
133
134    pub fn is_connecting(&self) -> bool {
135        matches!(
136            self.state,
137            ClientState::SendingConnectionRequest | ClientState::SendingConnectionResponse
138        )
139    }
140
141    pub fn is_connected(&self) -> bool {
142        self.state == ClientState::Connected
143    }
144
145    pub fn is_disconnected(&self) -> bool {
146        matches!(self.state, ClientState::Disconnected(_))
147    }
148
149    pub fn current_time(&self) -> Duration {
150        self.current_time
151    }
152
153    pub fn client_id(&self) -> u64 {
154        self.client_id
155    }
156
157    /// Returns the duration since the client last received a packet.
158    /// Usefull to detect timeouts.
159    pub fn time_since_last_received_packet(&self) -> Duration {
160        self.current_time - self.last_packet_received_time
161    }
162
163    /// Returns the reason that the client was disconnected for.
164    pub fn disconnect_reason(&self) -> Option<DisconnectReason> {
165        if let ClientState::Disconnected(reason) = &self.state {
166            return Some(*reason);
167        }
168        None
169    }
170
171    /// Returns the current server address the client is connected or trying to connect.
172    pub fn server_addr(&self) -> SocketAddr {
173        self.server_addr
174    }
175
176    /// Disconnect the client from the server.
177    /// Returns a disconnect packet that should be sent to the server.
178    pub fn disconnect(&mut self) -> Result<(SocketAddr, &mut [u8]), NetcodeError> {
179        self.state = ClientState::Disconnected(DisconnectReason::DisconnectedByClient);
180        let packet = Packet::Disconnect;
181        let len = packet.encode(
182            &mut self.out,
183            self.connect_token.protocol_id,
184            Some((self.sequence, &self.connect_token.client_to_server_key)),
185        )?;
186
187        Ok((self.server_addr, &mut self.out[..len]))
188    }
189
190    /// Process any packet received from the server. This function might return a payload sent from the
191    /// server. If nothing is returned, it was a packet used for the internal protocol or an
192    /// invalid packet.
193    pub fn process_packet<'a>(&mut self, buffer: &'a mut [u8]) -> Option<&'a [u8]> {
194        let packet = match Packet::decode(
195            buffer,
196            self.connect_token.protocol_id,
197            Some(&self.connect_token.server_to_client_key),
198            Some(&mut self.replay_protection),
199        ) {
200            Ok((_, packet)) => packet,
201            Err(e) => {
202                log::error!("Failed to decode packet: {}", e);
203                return None;
204            }
205        };
206        log::trace!("Received packet from server: {:?}", packet.packet_type());
207
208        match (packet, &self.state) {
209            (Packet::ConnectionDenied, ClientState::SendingConnectionRequest | ClientState::SendingConnectionResponse) => {
210                self.state = ClientState::Disconnected(DisconnectReason::ConnectionDenied);
211                self.last_packet_received_time = self.current_time;
212            }
213            (
214                Packet::Challenge {
215                    token_data,
216                    token_sequence,
217                },
218                ClientState::SendingConnectionRequest,
219            ) => {
220                self.challenge_token_sequence = token_sequence;
221                self.last_packet_received_time = self.current_time;
222                self.last_packet_send_time = None;
223                self.challenge_token_data = token_data;
224                self.state = ClientState::SendingConnectionResponse;
225            }
226            (Packet::KeepAlive { .. }, ClientState::Connected) => {
227                self.last_packet_received_time = self.current_time;
228            }
229            (Packet::KeepAlive { client_index, max_clients }, ClientState::SendingConnectionResponse) => {
230                self.last_packet_received_time = self.current_time;
231                self.max_clients = max_clients;
232                self.client_index = client_index;
233                self.state = ClientState::Connected;
234            }
235            (Packet::Payload(p), ClientState::Connected) => {
236                self.last_packet_received_time = self.current_time;
237                return Some(p);
238            }
239            (Packet::Disconnect, ClientState::Connected) => {
240                self.state = ClientState::Disconnected(DisconnectReason::DisconnectedByServer);
241                self.last_packet_received_time = self.current_time;
242            }
243            _ => {}
244        }
245
246        None
247    }
248
249    /// Returns the server address and an encrypted payload packet that can be sent to the server.
250    pub fn generate_payload_packet(&mut self, payload: &[u8]) -> Result<(SocketAddr, &mut [u8]), NetcodeError> {
251        if payload.len() > NETCODE_MAX_PAYLOAD_BYTES {
252            return Err(NetcodeError::PayloadAboveLimit);
253        }
254
255        if self.state != ClientState::Connected {
256            return Err(NetcodeError::ClientNotConnected);
257        }
258
259        let packet = Packet::Payload(payload);
260        let len = packet.encode(
261            &mut self.out,
262            self.connect_token.protocol_id,
263            Some((self.sequence, &self.connect_token.client_to_server_key)),
264        )?;
265        self.sequence += 1;
266        self.last_packet_send_time = Some(self.current_time);
267
268        Ok((self.server_addr, &mut self.out[..len]))
269    }
270
271    /// Update the internal state of the client, receives the duration since last updated.
272    /// Might return the serve address and a protocol packet to be sent to the server.
273    pub fn update(&mut self, duration: Duration) -> Option<(&mut [u8], SocketAddr)> {
274        if let Err(e) = self.update_internal_state(duration) {
275            log::error!("Failed to update client: {}", e);
276            return None;
277        }
278
279        // Generate packet for the current state
280        self.generate_packet()
281    }
282
283    fn update_internal_state(&mut self, duration: Duration) -> Result<(), NetcodeError> {
284        self.current_time += duration;
285        let connection_timed_out = self.connect_token.timeout_seconds > 0
286            && (self.last_packet_received_time + Duration::from_secs(self.connect_token.timeout_seconds as u64) < self.current_time);
287
288        match self.state {
289            ClientState::SendingConnectionRequest | ClientState::SendingConnectionResponse => {
290                let expire_seconds = self.connect_token.expire_timestamp - self.connect_token.create_timestamp;
291                let connection_expired = (self.current_time - self.connect_start_time).as_secs() >= expire_seconds;
292                if connection_expired {
293                    self.state = ClientState::Disconnected(DisconnectReason::ConnectTokenExpired);
294                    return Err(NetcodeError::Expired);
295                }
296                if connection_timed_out {
297                    let reason = if self.state == ClientState::SendingConnectionResponse {
298                        DisconnectReason::ConnectionResponseTimedOut
299                    } else {
300                        DisconnectReason::ConnectionRequestTimedOut
301                    };
302                    self.state = ClientState::Disconnected(reason);
303                    // Try to connect to the next server address
304                    self.server_addr_index += 1;
305                    if self.server_addr_index >= 32 {
306                        return Err(NetcodeError::NoMoreServers);
307                    }
308                    match self.connect_token.server_addresses[self.server_addr_index] {
309                        None => return Err(NetcodeError::NoMoreServers),
310                        Some(server_address) => {
311                            self.state = ClientState::SendingConnectionRequest;
312                            self.server_addr = server_address;
313                            self.connect_start_time = self.current_time;
314                            self.last_packet_send_time = None;
315                            self.last_packet_received_time = self.current_time;
316                            self.challenge_token_sequence = 0;
317
318                            return Ok(());
319                        }
320                    }
321                }
322                Ok(())
323            }
324            ClientState::Connected => {
325                if connection_timed_out {
326                    self.state = ClientState::Disconnected(DisconnectReason::ConnectionTimedOut);
327                    return Err(NetcodeError::Disconnected(DisconnectReason::ConnectionTimedOut));
328                }
329
330                Ok(())
331            }
332            ClientState::Disconnected(reason) => Err(NetcodeError::Disconnected(reason)),
333        }
334    }
335
336    fn generate_packet(&mut self) -> Option<(&mut [u8], SocketAddr)> {
337        if let Some(last_packet_send_time) = self.last_packet_send_time {
338            if self.current_time - last_packet_send_time < self.send_rate {
339                return None;
340            }
341        }
342
343        if matches!(
344            self.state,
345            ClientState::Connected | ClientState::SendingConnectionRequest | ClientState::SendingConnectionResponse
346        ) {
347            self.last_packet_send_time = Some(self.current_time);
348        }
349        let packet = match self.state {
350            ClientState::SendingConnectionRequest => Packet::connection_request_from_token(&self.connect_token),
351            ClientState::SendingConnectionResponse => Packet::Response {
352                token_sequence: self.challenge_token_sequence,
353                token_data: self.challenge_token_data,
354            },
355            ClientState::Connected => Packet::KeepAlive {
356                client_index: 0,
357                max_clients: 0,
358            },
359            _ => return None,
360        };
361
362        let result = packet.encode(
363            &mut self.out,
364            self.connect_token.protocol_id,
365            Some((self.sequence, &self.connect_token.client_to_server_key)),
366        );
367        match result {
368            Err(_) => None,
369            Ok(encoded) => {
370                self.sequence += 1;
371                Some((&mut self.out[..encoded], self.server_addr))
372            }
373        }
374    }
375}
376
#[cfg(test)]
mod tests {
    use crate::{crypto::generate_random_bytes, NETCODE_MAX_PACKET_BYTES};

    use super::*;

    /// Walks a client through the whole handshake (request, challenge, response, keep-alive)
    /// and then exchanges one payload in each direction.
    #[test]
    fn client_connection() {
        // Scratch buffer used to encode every packet coming "from the server".
        let mut server_out = [0u8; NETCODE_MAX_PACKET_BYTES];

        let protocol_id = 2;
        let expire_seconds = 3;
        let client_id = 4;
        let timeout_seconds = 5;
        let private_key = b"an example very very secret key."; // 32-bytes
        let token_user_data = generate_random_bytes();
        let server_addresses: Vec<SocketAddr> = vec!["127.0.0.1:8080".parse().unwrap(), "127.0.0.2:3000".parse().unwrap()];
        let connect_token = ConnectToken::generate(
            Duration::ZERO,
            protocol_id,
            expire_seconds,
            client_id,
            timeout_seconds,
            server_addresses,
            Some(&token_user_data),
            private_key,
        )
        .unwrap();
        let server_key = connect_token.server_to_client_key;
        let client_key = connect_token.client_to_server_key;

        let mut client = NetcodeClient::new(Duration::ZERO, ClientAuthentication::Secure { connect_token }).unwrap();

        // Step 1: the first update must produce a connection request with sequence 0.
        let (request_bytes, _) = client.update(Duration::ZERO).unwrap();
        let (request_sequence, request) = Packet::decode(request_bytes, protocol_id, None, None).unwrap();
        assert_eq!(0, request_sequence);
        assert!(matches!(request, Packet::ConnectionRequest { .. }));

        // Step 2: a challenge from the server moves the client to the response step.
        let challenge_sequence = 7;
        let challenge_user_data = generate_random_bytes();
        let challenge_key = generate_random_bytes();
        let challenge_packet = Packet::generate_challenge(client_id, &challenge_user_data, challenge_sequence, &challenge_key).unwrap();
        let challenge_len = challenge_packet.encode(&mut server_out, protocol_id, Some((0, &server_key))).unwrap();
        client.process_packet(&mut server_out[..challenge_len]);
        assert_eq!(ClientState::SendingConnectionResponse, client.state);

        // Step 3: the next update must produce the challenge response.
        let (response_bytes, _) = client.update(Duration::ZERO).unwrap();
        let (_, response) = Packet::decode(response_bytes, protocol_id, Some(&client_key), None).unwrap();
        assert!(matches!(response, Packet::Response { .. }));

        // Step 4: a keep-alive from the server completes the handshake.
        let max_clients = 4;
        let client_index = 2;
        let keep_alive_packet = Packet::KeepAlive { max_clients, client_index };
        let keep_alive_len = keep_alive_packet.encode(&mut server_out, protocol_id, Some((1, &server_key))).unwrap();
        client.process_packet(&mut server_out[..keep_alive_len]);

        assert_eq!(client.state, ClientState::Connected);

        // Server -> client payload.
        let payload = vec![7u8; 500];
        let payload_packet = Packet::Payload(&payload[..]);
        let payload_len = payload_packet.encode(&mut server_out, protocol_id, Some((2, &server_key))).unwrap();

        let payload_client = client.process_packet(&mut server_out[..payload_len]).unwrap();
        assert_eq!(payload, payload_client);

        // Client -> server payload.
        let to_send_payload = vec![5u8; 1000];
        let (_, outgoing_bytes) = client.generate_payload_packet(&to_send_payload).unwrap();
        let (_, decoded) = Packet::decode(outgoing_bytes, protocol_id, Some(&client_key), None).unwrap();
        match decoded {
            Packet::Payload(received) => assert_eq!(to_send_payload, received),
            _ => unreachable!(),
        }
    }
}