renetcode/
server.rs

1use std::{collections::HashMap, net::SocketAddr, time::Duration};
2
3use crate::{
4    crypto::generate_random_bytes,
5    packet::{ChallengeToken, Packet},
6    replay_protection::ReplayProtection,
7    token::PrivateConnectToken,
8    NetcodeError, NETCODE_CONNECT_TOKEN_PRIVATE_BYTES, NETCODE_CONNECT_TOKEN_XNONCE_BYTES, NETCODE_KEY_BYTES, NETCODE_MAC_BYTES,
9    NETCODE_MAX_CLIENTS, NETCODE_MAX_PACKET_BYTES, NETCODE_MAX_PAYLOAD_BYTES, NETCODE_MAX_PENDING_CLIENTS, NETCODE_SEND_RATE,
10    NETCODE_USER_DATA_BYTES, NETCODE_VERSION_INFO,
11};
12
/// Lifecycle state tracked for every [`Connection`] held by the server.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ConnectionState {
    /// The connection is finished. Pending connections in this state are dropped by
    /// `NetcodeServer::update`; connected ones are removed by `NetcodeServer::update_client`.
    Disconnected,
    /// A challenge was sent and the server is waiting for the client's response.
    PendingResponse,
    /// The client answered the challenge and was placed in a client slot.
    Connected,
}
19
/// Per-client bookkeeping, used both for pending clients and for clients that occupy a slot.
#[derive(Debug, Clone)]
struct Connection {
    /// Set once a payload or keep-alive packet is received from the client while connected.
    confirmed: bool,
    /// Client id taken from the connect token.
    client_id: u64,
    state: ConnectionState,
    /// Key used to encode packets sent to this client (the token's server-to-client key).
    send_key: [u8; NETCODE_KEY_BYTES],
    /// Key used to decode packets received from this client (the token's client-to-server key).
    receive_key: [u8; NETCODE_KEY_BYTES],
    /// User data carried by the connect token (overwritten with the challenge token's copy on connect).
    user_data: [u8; NETCODE_USER_DATA_BYTES],
    addr: SocketAddr,
    /// Server time at which the last packet from this client was processed.
    last_packet_received_time: Duration,
    /// Server time at which the last packet for this client was generated.
    last_packet_send_time: Duration,
    /// Timeout taken from the connect token; only values greater than zero enable the timeout check.
    timeout_seconds: i32,
    /// Sequence number used for the next packet encoded for this client.
    sequence: u64,
    /// Expiration of the connect token, compared against the server time in whole seconds.
    expire_timestamp: u64,
    /// Replay protection state handed to `Packet::decode` for packets from this client.
    replay_protection: ReplayProtection,
}
36
/// Record of a connect token that was already presented to the server.
#[derive(Debug, Copy, Clone)]
struct ConnectTokenEntry {
    /// Server time at which the entry was recorded; the oldest entry is replaced when the table is full.
    time: Duration,
    /// Address that presented the token.
    address: SocketAddr,
    /// Trailing MAC bytes of the private connect token data, used to identify the token.
    mac: [u8; NETCODE_MAC_BYTES],
}
43
/// A server that generates encrypted packets for connected clients and processes incoming
/// encrypted packets from clients. The server is agnostic of the transport layer: it only
/// consumes and produces bytes, which can be transported in any way desired.
#[derive(Debug)]
pub struct NetcodeServer {
    /// Client slots; the number of slots is fixed when the server is created.
    clients: Box<[Option<Connection>]>,
    /// Clients that requested a connection but have not completed the challenge, keyed by address.
    pending_clients: HashMap<SocketAddr, Connection>,
    /// Connect tokens already seen, used to reject a token reused from a different address.
    connect_token_entries: Box<[Option<ConnectTokenEntry>; NETCODE_MAX_CLIENTS * 2]>,
    protocol_id: u64,
    /// Key used to decode private connect tokens (all zeros when running unsecure).
    connect_key: [u8; NETCODE_KEY_BYTES],
    max_clients: usize,
    /// Incremented for every challenge packet generated.
    challenge_sequence: u64,
    /// Randomly generated at construction; used to encode and decode challenge tokens.
    challenge_key: [u8; NETCODE_KEY_BYTES],
    public_addresses: Vec<SocketAddr>,
    current_time: Duration,
    /// Sequence number for packets sent to clients that are not connected yet
    /// (challenge and connection denied packets).
    global_sequence: u64,
    /// Whether the connect token host list is checked against `public_addresses`.
    secure: bool,
    /// Scratch buffer in which outgoing packets are encoded; returned payload slices borrow from it.
    out: [u8; NETCODE_MAX_PACKET_BYTES],
}
63
/// Result of processing a packet or updating a client in the server.
///
/// `'a` is the lifetime of the buffer handed to [`NetcodeServer::process_packet`] (received
/// payloads borrow from it); `'s` is the lifetime of the borrow of the server itself
/// (outgoing packets borrow from the server's internal buffer).
#[derive(Debug, PartialEq, Eq)]
pub enum ServerResult<'a, 's> {
    /// Nothing needs to be done.
    None,
    /// A packet to be sent to `addr`.
    PacketToSend { addr: SocketAddr, payload: &'s mut [u8] },
    /// A payload received from the client.
    Payload { client_id: u64, payload: &'a [u8] },
    /// A new client has connected; `payload` is a packet that must be sent to `addr`.
    ClientConnected {
        client_id: u64,
        addr: SocketAddr,
        user_data: Box<[u8; NETCODE_USER_DATA_BYTES]>,
        payload: &'s mut [u8],
    },
    /// The client connection has been terminated. When present, `payload` is a disconnect
    /// packet to be sent to `addr`.
    ClientDisconnected {
        client_id: u64,
        addr: SocketAddr,
        payload: Option<&'s mut [u8]>,
    },
}
87
/// Configuration to establish a secure or unsecure connection with the server.
pub enum ServerAuthentication {
    /// Establishes a safe connection using a private key for encryption. The private key must
    /// not be shared with the client. Connections are established using
    /// [crate::token::ConnectToken].
    ///
    /// See also [ClientAuthentication::Secure][crate::ClientAuthentication::Secure]
    Secure { private_key: [u8; NETCODE_KEY_BYTES] },
    /// Establishes unsafe connections with clients, useful for testing and prototyping.
    /// The server uses an all-zero key and skips the host list check in this mode.
    ///
    /// See also [ClientAuthentication::Unsecure][crate::ClientAuthentication::Unsecure]
    Unsecure,
}
100
/// Settings used to create a [`NetcodeServer`].
pub struct ServerConfig {
    /// Initial value of the server clock; it is advanced by [`NetcodeServer::update`].
    pub current_time: Duration,
    /// Maximum number of clients that can be connected at a time.
    /// [`NetcodeServer::new`] panics if this is greater than `NETCODE_MAX_CLIENTS`.
    pub max_clients: usize,
    /// Unique identifier for this particular game/application.
    /// You can use a hash function with the current version of the game to generate this value
    /// so that older versions cannot connect to newer versions.
    pub protocol_id: u64,
    /// Publicly available addresses to which clients will attempt to connect.
    pub public_addresses: Vec<SocketAddr>,
    /// Authentication configuration for the server
    pub authentication: ServerAuthentication,
}
114
115impl NetcodeServer {
116    pub fn new(config: ServerConfig) -> Self {
117        if config.max_clients > NETCODE_MAX_CLIENTS {
118            // TODO: do we really need to set a max?
119            //       only using for token entries
120            panic!("The max clients allowed is {}", NETCODE_MAX_CLIENTS);
121        }
122        let challenge_key = generate_random_bytes();
123        let clients = vec![None; config.max_clients].into_boxed_slice();
124
125        let connect_key = match config.authentication {
126            ServerAuthentication::Unsecure => [0; NETCODE_KEY_BYTES],
127            ServerAuthentication::Secure { private_key } => private_key,
128        };
129
130        let secure = match config.authentication {
131            ServerAuthentication::Unsecure => false,
132            ServerAuthentication::Secure { .. } => true,
133        };
134
135        Self {
136            clients,
137            connect_token_entries: Box::new([None; NETCODE_MAX_CLIENTS * 2]),
138            pending_clients: HashMap::new(),
139            protocol_id: config.protocol_id,
140            connect_key,
141            max_clients: config.max_clients,
142            challenge_sequence: 0,
143            global_sequence: 0,
144            challenge_key,
145            public_addresses: config.public_addresses,
146            current_time: config.current_time,
147            secure,
148            out: [0u8; NETCODE_MAX_PACKET_BYTES],
149        }
150    }
151
152    #[doc(hidden)]
153    pub fn __test() -> Self {
154        let config = ServerConfig {
155            current_time: Duration::ZERO,
156            max_clients: 32,
157            protocol_id: 0,
158            public_addresses: vec!["127.0.0.1:0".parse().unwrap()],
159            authentication: ServerAuthentication::Unsecure,
160        };
161        Self::new(config)
162    }
163
164    pub fn addresses(&self) -> Vec<SocketAddr> {
165        self.public_addresses.clone()
166    }
167
    /// Returns the server clock: the initial time from the configuration plus every
    /// duration passed to [`NetcodeServer::update`].
    pub fn current_time(&self) -> Duration {
        self.current_time
    }
171
172    fn find_or_add_connect_token_entry(&mut self, new_entry: ConnectTokenEntry) -> bool {
173        let mut min = Duration::MAX;
174        let mut oldest_entry = 0;
175        let mut empty_entry = false;
176        let mut matching_entry = None;
177        for (i, entry) in self.connect_token_entries.iter().enumerate() {
178            match entry {
179                Some(e) => {
180                    if e.mac == new_entry.mac {
181                        matching_entry = Some(e);
182                    }
183                    if !empty_entry && e.time < min {
184                        oldest_entry = i;
185                        min = e.time;
186                    }
187                }
188                None => {
189                    if !empty_entry {
190                        empty_entry = true;
191                        oldest_entry = i;
192                    }
193                }
194            }
195        }
196
197        if let Some(entry) = matching_entry {
198            return entry.address == new_entry.address;
199        }
200
201        self.connect_token_entries[oldest_entry] = Some(new_entry);
202
203        true
204    }
205
206    /// Returns the user data from the connected client.
207    pub fn user_data(&self, client_id: u64) -> Option<[u8; NETCODE_USER_DATA_BYTES]> {
208        if let Some(client) = find_client_by_id(&self.clients, client_id) {
209            return Some(client.user_data);
210        }
211
212        None
213    }
214
215    /// Returns the duration since the connected client last received a packet.
216    /// Usefull to detect users that are timing out.
217    pub fn time_since_last_received_packet(&self, client_id: u64) -> Option<Duration> {
218        if let Some(client) = find_client_by_id(&self.clients, client_id) {
219            let time = self.current_time - client.last_packet_received_time;
220            return Some(time);
221        }
222
223        None
224    }
225
226    /// Returns the client address if connected.
227    pub fn client_addr(&self, client_id: u64) -> Option<SocketAddr> {
228        if let Some(client) = find_client_by_id(&self.clients, client_id) {
229            return Some(client.addr);
230        }
231
232        None
233    }
234
    /// Validates a connection request from `addr` and, when it is accepted, registers the
    /// sender as a pending client and encodes a challenge packet for it.
    ///
    /// Returns:
    /// - `Err(..)` when the version, protocol id, expiration, token decoding or (in secure
    ///   mode) host list check fails, or when a packet cannot be generated or encoded.
    /// - `Ok(ServerResult::None)` when the request is ignored: the client id or address is
    ///   already connected, the pending list is full, or the token was already used from a
    ///   different address.
    /// - `Ok(ServerResult::PacketToSend { .. })` with a connection denied packet when all
    ///   client slots are taken, or with the challenge packet otherwise.
    fn handle_connection_request<'a>(
        &mut self,
        addr: SocketAddr,
        version_info: [u8; 13],
        protocol_id: u64,
        expire_timestamp: u64,
        xnonce: [u8; NETCODE_CONNECT_TOKEN_XNONCE_BYTES],
        data: [u8; NETCODE_CONNECT_TOKEN_PRIVATE_BYTES],
    ) -> Result<ServerResult<'a, '_>, NetcodeError> {
        if version_info != *NETCODE_VERSION_INFO {
            return Err(NetcodeError::InvalidVersion);
        }

        if protocol_id != self.protocol_id {
            return Err(NetcodeError::InvalidProtocolID);
        }

        // The token expiration is compared against the server clock in whole seconds.
        if self.current_time.as_secs() >= expire_timestamp {
            return Err(NetcodeError::Expired);
        }

        let connect_token = PrivateConnectToken::decode(&data, self.protocol_id, expire_timestamp, &xnonce, &self.connect_key)?;

        // Skip host list check when unsecure
        if self.secure {
            // At least one address in the token must be one of this server's public addresses.
            let in_host_list = connect_token
                .server_addresses
                .iter()
                .filter_map(|host| *host)
                .any(|addr| self.public_addresses.contains(&addr));

            if !in_host_list {
                return Err(NetcodeError::NotInHostList);
            }
        }

        let addr_already_connected = find_client_mut_by_addr(&mut self.clients, addr).is_some();
        let id_already_connected = find_client_mut_by_id(&mut self.clients, connect_token.client_id).is_some();
        if id_already_connected || addr_already_connected {
            log::debug!(
                "Connection request denied: client {} already connected (address: {}).",
                connect_token.client_id,
                addr
            );
            return Ok(ServerResult::None);
        }

        // A repeated request from an address that is already pending is still allowed
        // when the pending list is full.
        if !self.pending_clients.contains_key(&addr) && self.pending_clients.len() >= NETCODE_MAX_PENDING_CLIENTS {
            log::warn!(
                "Connection request denied: reached max amount allowed of pending clients ({}).",
                NETCODE_MAX_PENDING_CLIENTS
            );
            return Ok(ServerResult::None);
        }

        // The trailing MAC bytes of the private token data identify the token.
        let mut mac = [0u8; NETCODE_MAC_BYTES];
        mac.copy_from_slice(&data[NETCODE_CONNECT_TOKEN_PRIVATE_BYTES - NETCODE_MAC_BYTES..]);
        let connect_token_entry = ConnectTokenEntry {
            address: addr,
            time: self.current_time,
            mac,
        };

        // Fails when the same token was already presented from a different address.
        if !self.find_or_add_connect_token_entry(connect_token_entry) {
            log::warn!("Connection request denied: unable to add connect token entry");
            return Ok(ServerResult::None);
        }

        // Server full: forget any pending state for this address and reply with a
        // connection denied packet encoded with the token's server-to-client key.
        if self.clients.iter().flatten().count() >= self.max_clients {
            self.pending_clients.remove(&addr);
            let packet = Packet::ConnectionDenied;
            let len = packet.encode(
                &mut self.out,
                self.protocol_id,
                Some((self.global_sequence, &connect_token.server_to_client_key)),
            )?;
            self.global_sequence += 1;
            return Ok(ServerResult::PacketToSend {
                addr,
                payload: &mut self.out[..len],
            });
        }

        self.challenge_sequence += 1;
        let packet = Packet::generate_challenge(
            connect_token.client_id,
            &connect_token.user_data,
            self.challenge_sequence,
            &self.challenge_key,
        )?;

        let len = packet.encode(
            &mut self.out,
            self.protocol_id,
            Some((self.global_sequence, &connect_token.server_to_client_key)),
        )?;
        self.global_sequence += 1;

        log::trace!("Connection request from Client {}", connect_token.client_id);

        // An existing pending entry for this address is kept as is; only its
        // timestamps are refreshed below.
        let pending = self.pending_clients.entry(addr).or_insert_with(|| Connection {
            confirmed: false,
            sequence: 0,
            client_id: connect_token.client_id,
            last_packet_received_time: self.current_time,
            last_packet_send_time: self.current_time,
            addr,
            state: ConnectionState::PendingResponse,
            send_key: connect_token.server_to_client_key,
            receive_key: connect_token.client_to_server_key,
            timeout_seconds: connect_token.timeout_seconds,
            expire_timestamp,
            user_data: connect_token.user_data,
            replay_protection: ReplayProtection::new(),
        });
        pending.last_packet_received_time = self.current_time;
        pending.last_packet_send_time = self.current_time;

        Ok(ServerResult::PacketToSend {
            addr,
            payload: &mut self.out[..len],
        })
    }
358
359    /// Returns an encoded packet payload to be sent to the client
360    pub fn generate_payload_packet<'s>(&'s mut self, client_id: u64, payload: &[u8]) -> Result<(SocketAddr, &'s mut [u8]), NetcodeError> {
361        if payload.len() > NETCODE_MAX_PAYLOAD_BYTES {
362            return Err(NetcodeError::PayloadAboveLimit);
363        }
364
365        if let Some(client) = find_client_mut_by_id(&mut self.clients, client_id) {
366            let packet = Packet::Payload(payload);
367            let len = packet.encode(&mut self.out, self.protocol_id, Some((client.sequence, &client.send_key)))?;
368            client.sequence += 1;
369            client.last_packet_send_time = self.current_time;
370
371            return Ok((client.addr, &mut self.out[..len]));
372        }
373
374        Err(NetcodeError::ClientNotFound)
375    }
376
377    /// Process an packet from the especifed address. Returns a server result, check out
378    /// [ServerResult].
379    pub fn process_packet<'a, 's>(&'s mut self, addr: SocketAddr, buffer: &'a mut [u8]) -> ServerResult<'a, 's> {
380        match self.process_packet_internal(addr, buffer) {
381            Err(e) => {
382                log::error!("Failed to process packet: {}", e);
383                ServerResult::None
384            }
385            Ok(r) => r,
386        }
387    }
388
    /// Decodes `buffer` and dispatches it according to who sent it.
    ///
    /// The sender is looked up in three steps:
    /// 1. a connected client with this address: the packet is decoded with the client's key
    ///    and handled as a disconnect, payload or keep-alive;
    /// 2. a pending client with this address: a repeated connection request is handled again,
    ///    and a challenge response promotes the pending client to a connected client;
    /// 3. an unknown address: only connection requests can be decoded.
    ///
    /// Returns an error when the buffer is too small or when decoding/encoding fails.
    fn process_packet_internal<'a, 's>(&'s mut self, addr: SocketAddr, buffer: &'a mut [u8]) -> Result<ServerResult<'a, 's>, NetcodeError> {
        if buffer.len() < 2 + NETCODE_MAC_BYTES {
            return Err(NetcodeError::PacketTooSmall);
        }

        // Handle connected client
        if let Some((slot, client)) = find_client_mut_by_addr(&mut self.clients, addr) {
            let (_, packet) = Packet::decode(
                buffer,
                self.protocol_id,
                Some(&client.receive_key),
                Some(&mut client.replay_protection),
            )?;
            log::trace!(
                "Received packet from connected client ({}): {:?}",
                client.client_id,
                packet.packet_type()
            );

            // Any successfully decoded packet counts as activity for the timeout check.
            client.last_packet_received_time = self.current_time;
            match client.state {
                ConnectionState::Connected => match packet {
                    Packet::Disconnect => {
                        client.state = ConnectionState::Disconnected;
                        let client_id = client.client_id;
                        // Free the slot right away; no packet is sent back to the client.
                        self.clients[slot] = None;
                        log::trace!("Client {} requested to disconnect", client_id);
                        return Ok(ServerResult::ClientDisconnected {
                            client_id,
                            addr,
                            payload: None,
                        });
                    }
                    Packet::Payload(payload) => {
                        if !client.confirmed {
                            log::trace!("Confirmed connection for Client {}", client.client_id);
                            client.confirmed = true;
                        }
                        return Ok(ServerResult::Payload {
                            client_id: client.client_id,
                            payload,
                        });
                    }
                    Packet::KeepAlive { .. } => {
                        if !client.confirmed {
                            log::trace!("Confirmed connection for Client {}", client.client_id);
                            client.confirmed = true;
                        }
                        return Ok(ServerResult::None);
                    }
                    // Other packet types are ignored for connected clients.
                    _ => return Ok(ServerResult::None),
                },
                _ => return Ok(ServerResult::None),
            }
        }

        // Handle pending client
        if let Some(pending) = self.pending_clients.get_mut(&addr) {
            let (_, packet) = Packet::decode(
                buffer,
                self.protocol_id,
                Some(&pending.receive_key),
                Some(&mut pending.replay_protection),
            )?;
            pending.last_packet_received_time = self.current_time;
            log::trace!("Received packet from pending client ({}): {:?}", addr, packet.packet_type());
            match packet {
                // The client repeated its connection request: run the request handling again.
                Packet::ConnectionRequest {
                    protocol_id,
                    expire_timestamp,
                    data,
                    xnonce,
                    version_info,
                } => {
                    return self.handle_connection_request(addr, version_info, protocol_id, expire_timestamp, xnonce, data);
                }
                Packet::Response {
                    token_data,
                    token_sequence,
                } => {
                    let challenge_token = ChallengeToken::decode(token_data, token_sequence, &self.challenge_key)?;
                    // The entry was found by `get_mut` above, so the removal cannot fail.
                    // From here on the pending entry is consumed whatever the outcome.
                    let mut pending = self.pending_clients.remove(&addr).unwrap();
                    if find_client_slot_by_id(&self.clients, challenge_token.client_id).is_some() {
                        log::debug!(
                            "Ignored connection response for Client {}, already connected.",
                            challenge_token.client_id
                        );
                        return Ok(ServerResult::None);
                    }
                    match self.clients.iter().position(|c| c.is_none()) {
                        // No free slot: deny the connection.
                        None => {
                            let packet = Packet::ConnectionDenied;
                            let len = packet.encode(&mut self.out, self.protocol_id, Some((self.global_sequence, &pending.send_key)))?;
                            // NOTE(review): `pending` is dropped at the end of this arm, so these
                            // field updates do not appear to be observable.
                            pending.state = ConnectionState::Disconnected;
                            self.global_sequence += 1;
                            pending.last_packet_send_time = self.current_time;
                            return Ok(ServerResult::PacketToSend {
                                addr,
                                payload: &mut self.out[..len],
                            });
                        }
                        // Free slot found: promote the pending client and answer with a keep-alive.
                        Some(client_index) => {
                            pending.state = ConnectionState::Connected;
                            pending.user_data = challenge_token.user_data;
                            pending.last_packet_send_time = self.current_time;

                            let packet = Packet::KeepAlive {
                                max_clients: self.max_clients as u32,
                                client_index: client_index as u32,
                            };
                            let len = packet.encode(&mut self.out, self.protocol_id, Some((pending.sequence, &pending.send_key)))?;
                            pending.sequence += 1;

                            let client_id: u64 = pending.client_id;
                            let user_data: [u8; NETCODE_USER_DATA_BYTES] = pending.user_data;
                            self.clients[client_index] = Some(pending);

                            return Ok(ServerResult::ClientConnected {
                                client_id,
                                addr,
                                user_data: Box::new(user_data),
                                payload: &mut self.out[..len],
                            });
                        }
                    }
                }
                _ => return Ok(ServerResult::None),
            }
        }

        // Handle new client
        // Without a key, decoding can only produce connection request packets.
        let (_, packet) = Packet::decode(buffer, self.protocol_id, None, None)?;
        match packet {
            Packet::ConnectionRequest {
                data,
                protocol_id,
                expire_timestamp,
                xnonce,
                version_info,
            } => self.handle_connection_request(addr, version_info, protocol_id, expire_timestamp, xnonce, data),
            _ => unreachable!("Decoding packet without key can only return ConnectionRequest packets"),
        }
    }
532
533    pub fn clients_slot(&self) -> Vec<usize> {
534        self.clients
535            .iter()
536            .enumerate()
537            .filter_map(|(index, slot)| if slot.is_some() { Some(index) } else { None })
538            .collect()
539    }
540
541    /// Returns the ids from the connected clients (iterator).
542    pub fn clients_id_iter(&self) -> impl Iterator<Item = u64> + '_ {
543        self.clients.iter().filter_map(|slot| slot.as_ref().map(|client| client.client_id))
544    }
545
    /// Returns the ids of the connected clients, collected into a `Vec`.
    ///
    /// See [`NetcodeServer::clients_id_iter`] for a version that does not allocate.
    pub fn clients_id(&self) -> Vec<u64> {
        self.clients_id_iter().collect()
    }
550
    /// Returns the maximum number of clients that can be connected, as set at
    /// construction or by [`NetcodeServer::set_max_clients`].
    pub fn max_clients(&self) -> usize {
        self.max_clients
    }
555
556    /// Update the maximum numbers of clients that can be connected
557    ///
558    /// Changing the `max_clients` to a lower value than the current number of connect clients
559    /// does not disconnect clients. So [`NetcodeServer::connected_clients()`] can return a
560    /// higher value than [`NetcodeServer::max_clients()`].
561    pub fn set_max_clients(&mut self, max_clients: usize) {
562        let max_clients = max_clients.min(NETCODE_MAX_CLIENTS);
563        log::debug!("Netcode max_clients set to {}", max_clients);
564
565        self.max_clients = max_clients;
566    }
567
568    /// Returns current number of clients connected.
569    pub fn connected_clients(&self) -> usize {
570        self.clients.iter().filter(|slot| slot.is_some()).count()
571    }
572
573    /// Advance the server current time, and remove any pending connections that have expired.
574    pub fn update(&mut self, duration: Duration) {
575        self.current_time += duration;
576
577        for client in self.pending_clients.values_mut() {
578            if self.current_time.as_secs() > client.expire_timestamp {
579                log::debug!("Pending Client {} disconnected, connection token expired.", client.client_id);
580                client.state = ConnectionState::Disconnected;
581            }
582        }
583
584        self.pending_clients.retain(|_, c| c.state != ConnectionState::Disconnected);
585    }
586
587    /// Updates the client, returns a ServerResult.
588    ///
589    /// # Example
590    /// ```
591    /// # use renetcode::ServerResult;
592    /// # let mut server = renetcode::NetcodeServer::__test();
593    /// for client_id in server.clients_id().into_iter() {
594    ///     match server.update_client(client_id) {
595    ///         ServerResult::PacketToSend { payload, addr } => send_to(payload, addr),
596    ///         _ => { /* ... */ }
597    ///     }
598    /// }
599    /// # fn send_to(p: &[u8], addr: std::net::SocketAddr) {}
600    /// ```
601    pub fn update_client(&mut self, client_id: u64) -> ServerResult<'_, '_> {
602        let slot = match find_client_slot_by_id(&self.clients, client_id) {
603            None => return ServerResult::None,
604            Some(slot) => slot,
605        };
606
607        if let Some(client) = &mut self.clients[slot] {
608            let connection_timed_out = client.timeout_seconds > 0
609                && (client.last_packet_received_time + Duration::from_secs(client.timeout_seconds as u64) < self.current_time);
610            if connection_timed_out {
611                log::debug!("Client {} disconnected, connection timed out", client.client_id);
612                client.state = ConnectionState::Disconnected;
613            }
614
615            if client.state == ConnectionState::Disconnected {
616                let packet = Packet::Disconnect;
617                let sequence = client.sequence;
618                let send_key = client.send_key;
619                let addr = client.addr;
620                self.clients[slot] = None;
621
622                let len = match packet.encode(&mut self.out, self.protocol_id, Some((sequence, &send_key))) {
623                    Err(e) => {
624                        log::error!("Failed to encode disconnect packet: {}", e);
625                        return ServerResult::ClientDisconnected {
626                            client_id,
627                            addr,
628                            payload: None,
629                        };
630                    }
631                    Ok(len) => len,
632                };
633
634                return ServerResult::ClientDisconnected {
635                    client_id,
636                    addr,
637                    payload: Some(&mut self.out[..len]),
638                };
639            }
640
641            if client.last_packet_send_time + NETCODE_SEND_RATE <= self.current_time {
642                let packet = Packet::KeepAlive {
643                    client_index: slot as u32,
644                    max_clients: self.max_clients as u32,
645                };
646
647                let len = match packet.encode(&mut self.out, self.protocol_id, Some((client.sequence, &client.send_key))) {
648                    Err(e) => {
649                        log::error!("Failed to encode keep alive packet: {}", e);
650                        return ServerResult::None;
651                    }
652                    Ok(len) => len,
653                };
654                client.sequence += 1;
655                client.last_packet_send_time = self.current_time;
656                return ServerResult::PacketToSend {
657                    addr: client.addr,
658                    payload: &mut self.out[..len],
659                };
660            }
661        }
662
663        ServerResult::None
664    }
665
    /// Returns `true` if a client with the given id currently occupies a client slot.
    pub fn is_client_connected(&self, client_id: u64) -> bool {
        find_client_slot_by_id(&self.clients, client_id).is_some()
    }
669
670    /// Disconnect an client and returns its address and a disconnect packet to be sent to them.
671    // TODO: we can return Result<PacketToSend, NetcodeError>
672    //       but the library user would need to be aware that he has to run
673    //       the same code as Result::ClientDisconnected
674    pub fn disconnect(&mut self, client_id: u64) -> ServerResult<'_, '_> {
675        if let Some(slot) = find_client_slot_by_id(&self.clients, client_id) {
676            let client = self.clients[slot].take().unwrap();
677            let packet = Packet::Disconnect;
678
679            let len = match packet.encode(&mut self.out, self.protocol_id, Some((client.sequence, &client.send_key))) {
680                Err(e) => {
681                    log::error!("Failed to encode disconnect packet: {}", e);
682                    return ServerResult::ClientDisconnected {
683                        client_id,
684                        addr: client.addr,
685                        payload: None,
686                    };
687                }
688                Ok(len) => len,
689            };
690            return ServerResult::ClientDisconnected {
691                client_id,
692                addr: client.addr,
693                payload: Some(&mut self.out[..len]),
694            };
695        }
696
697        ServerResult::None
698    }
699}
700
701fn find_client_mut_by_id(clients: &mut [Option<Connection>], client_id: u64) -> Option<&mut Connection> {
702    clients.iter_mut().flatten().find(|c| c.client_id == client_id)
703}
704
705fn find_client_by_id(clients: &[Option<Connection>], client_id: u64) -> Option<&Connection> {
706    clients.iter().flatten().find(|c| c.client_id == client_id)
707}
708
709fn find_client_slot_by_id(clients: &[Option<Connection>], client_id: u64) -> Option<usize> {
710    clients.iter().enumerate().find_map(|(i, c)| match c {
711        Some(c) if c.client_id == client_id => Some(i),
712        _ => None,
713    })
714}
715
716fn find_client_mut_by_addr(clients: &mut [Option<Connection>], addr: SocketAddr) -> Option<(usize, &mut Connection)> {
717    clients.iter_mut().enumerate().find_map(|(i, c)| match c {
718        Some(c) if c.addr == addr => Some((i, c)),
719        _ => None,
720    })
721}
722
723#[cfg(test)]
724mod tests {
725    use crate::{client::NetcodeClient, token::ConnectToken, ClientAuthentication};
726
727    use super::*;
728
729    const TEST_KEY: &[u8; NETCODE_KEY_BYTES] = b"an example very very secret key."; // 32-bytes
730    const TEST_PROTOCOL_ID: u64 = 7;
731
732    fn new_server() -> NetcodeServer {
733        let config = ServerConfig {
734            current_time: Duration::ZERO,
735            max_clients: 16,
736            protocol_id: TEST_PROTOCOL_ID,
737            public_addresses: vec!["127.0.0.1:5000".parse().unwrap()],
738            authentication: ServerAuthentication::Secure { private_key: *TEST_KEY },
739        };
740        NetcodeServer::new(config)
741    }
742
    /// Walks a client through the whole handshake against the server, exchanges payloads in
    /// both directions, checks the keep-alive timing and finally disconnects the client.
    #[test]
    fn server_connection() {
        let mut server = new_server();
        let server_addresses: Vec<SocketAddr> = server.addresses();
        let user_data = generate_random_bytes();
        let expire_seconds = 3;
        let client_id = 4;
        let timeout_seconds = 5;
        let client_addr: SocketAddr = "127.0.0.1:3000".parse().unwrap();
        let connect_token = ConnectToken::generate(
            Duration::ZERO,
            TEST_PROTOCOL_ID,
            expire_seconds,
            client_id,
            timeout_seconds,
            server_addresses,
            Some(&user_data),
            TEST_KEY,
        )
        .unwrap();
        let client_auth = ClientAuthentication::Secure { connect_token };
        let mut client = NetcodeClient::new(Duration::ZERO, client_auth).unwrap();
        let (client_packet, _) = client.update(Duration::ZERO).unwrap();

        // First packet from the client is answered with a packet to send back.
        let result = server.process_packet(client_addr, client_packet);
        assert!(matches!(result, ServerResult::PacketToSend { .. }));
        match result {
            ServerResult::PacketToSend { payload, .. } => client.process_packet(payload),
            _ => unreachable!(),
        };

        // The client is not connected until the server accepts its next packet.
        assert!(!client.is_connected());
        let (client_packet, _) = client.update(Duration::ZERO).unwrap();
        let result = server.process_packet(client_addr, client_packet);

        match result {
            ServerResult::ClientConnected {
                client_id: r_id,
                user_data: r_data,
                payload,
                ..
            } => {
                assert_eq!(client_id, r_id);
                assert_eq!(user_data, *r_data);
                client.process_packet(payload)
            }
            _ => unreachable!(),
        };

        assert!(client.is_connected());

        // Server -> client payloads.
        for _ in 0..3 {
            let payload = [7u8; 300];
            let (_, packet) = server.generate_payload_packet(client_id, &payload).unwrap();
            let result_payload = client.process_packet(packet).unwrap();
            assert_eq!(payload, result_payload);
        }

        // No keep-alive is due right after sending; one is due after NETCODE_SEND_RATE.
        let result = server.update_client(client_id);
        assert_eq!(result, ServerResult::None);
        server.update(NETCODE_SEND_RATE);

        let result = server.update_client(client_id);
        match result {
            ServerResult::PacketToSend { payload, .. } => {
                assert!(client.process_packet(payload).is_none());
            }
            _ => unreachable!(),
        }

        // Client -> server payload.
        let client_payload = [2u8; 300];
        let (_, packet) = client.generate_payload_packet(&client_payload).unwrap();

        match server.process_packet(client_addr, packet) {
            ServerResult::Payload { client_id: id, payload } => {
                assert_eq!(id, client_id);
                assert_eq!(client_payload, payload);
            }
            _ => unreachable!(),
        }

        // Server-initiated disconnect: the client disconnects once it processes the packet.
        assert!(server.is_client_connected(client_id));
        let result = server.disconnect(client_id);
        match result {
            ServerResult::ClientDisconnected {
                payload: Some(payload), ..
            } => {
                assert!(client.is_connected());
                assert!(client.process_packet(payload).is_none());
                assert!(!client.is_connected());
            }
            _ => unreachable!(),
        }

        assert!(!server.is_client_connected(client_id));
    }
839
840    #[test]
841    fn connect_token_already_used() {
842        let mut server = new_server();
843
844        let client_addr: SocketAddr = "127.0.0.1:3000".parse().unwrap();
845        let mut connect_token = ConnectTokenEntry {
846            time: Duration::ZERO,
847            address: client_addr,
848            mac: generate_random_bytes(),
849        };
850        // Allow first entry
851        assert!(server.find_or_add_connect_token_entry(connect_token));
852        // Allow same token with the same address
853        assert!(server.find_or_add_connect_token_entry(connect_token));
854        connect_token.address = "127.0.0.1:3001".parse().unwrap();
855
856        // Don't allow same token with different address
857        assert!(!server.find_or_add_connect_token_entry(connect_token));
858    }
859}