nautilus_sockets/server/mod.rs

1pub mod config;
2use std::{
3    collections::{HashMap, VecDeque},
4    marker::PhantomData,
5    net::{SocketAddr, ToSocketAddrs, UdpSocket},
6    time::{Duration, Instant},
7};
8
9use anyhow::anyhow;
10use byteorder::{ByteOrder, LittleEndian};
11use config::ServerConfig;
12
13use crate::{
14    acknowledgement::{manager::AcknowledgementManager, packet::AckNumber},
15    client::ConnectionId,
16    connection::EstablishedConnection,
17    events::EventEmitter,
18    packet::{IntoPacketDelivery, PacketDelivery},
19    persistent::storage::PersistentStorage,
20    sequence::SequenceNumber,
21    socket::{events::SocketEvent, NautSocket, SocketType},
22};
23
24// Incremental Id
25pub struct NautServer {
26    max_connections: u8,
27
28    connection_addr_to_id: HashMap<SocketAddr, ConnectionId>,
29    connection_id_to_addr: HashMap<ConnectionId, SocketAddr>,
30    connections: HashMap<ConnectionId, EstablishedConnection>,
31
32    time_outs: HashMap<ConnectionId, Instant>,
33
34    next_id: ConnectionId,
35    freed_ids: VecDeque<ConnectionId>,
36
37    idle_connection_timeout: Duration,
38
39    server_events: VecDeque<ServerEvent>,
40}
41
42impl NautServer {
43    pub fn new(config: ServerConfig) -> Self {
44        Self {
45            max_connections: config.max_connections,
46            idle_connection_timeout: config.idle_connection_time,
47            ..Default::default()
48        }
49    }
50
51    /// Gets the [client's id](ConnectionId) from an [address](SocketAddr)
52    pub fn get_client_addr(&self, id: &ConnectionId) -> Option<&SocketAddr> {
53        self.connection_id_to_addr.get(id)
54    }
55
56    /// Gets the [client's address](SocketAddr) from an [id](ConnectionId)
57    pub fn get_client_id(&self, addr: &SocketAddr) -> Option<&ConnectionId> {
58        self.connection_addr_to_id.get(addr)
59    }
60
61    /// Gets an iterator to all [server events](ServerEvent) in the queue, this will not remove any from queue
62    pub fn iter_server_events(&self) -> std::collections::vec_deque::Iter<'_, ServerEvent> {
63        self.server_events.iter()
64    }
65
66    /// Gets the max amount of connections the server can handle
67    pub fn get_max_connections(&self) -> u8 {
68        self.max_connections
69    }
70
71    /// Gets the current amount of established connections
72    pub fn get_current_connections(&self) -> u8 {
73        self.connections.len() as u8
74    }
75
76    /// Checks if a client has not sent a packet for the (idle time)[Self::idle_connection_timeout]
77    pub(crate) fn any_client_needs_freeing(&self) -> Option<Vec<ConnectionId>> {
78        let mut ids = Vec::new();
79        for (id, time) in self.time_outs.iter() {
80            if Instant::now().duration_since(*time) < self.idle_connection_timeout {
81                continue;
82            }
83            ids.push(*id);
84        }
85
86        if ids.is_empty() {
87            return None;
88        }
89
90        Some(ids)
91    }
92
93    /// Frees a client up to the server
94    pub(crate) fn free_client(&mut self, id: ConnectionId) {
95        self.freed_ids.push_back(id);
96        let Some(addr) = self.connection_id_to_addr.remove(&id) else {
97            println!("Failed to find address of idle'd client with id: {id}");
98            return;
99        };
100
101        self.connection_addr_to_id.remove(&addr);
102        self.time_outs.remove(&id);
103        self.connections.remove(&id);
104    }
105
106    /// Closes a connection with a client and pushes a [client disconnected event](ServerEvent::OnClientDisconnected)
107    /// to the server events queue
108    pub fn close_connection_with_client(&mut self, id: ConnectionId) {
109        self.free_client(id);
110        self.server_events
111            .push_back(ServerEvent::OnClientDisconnected(id));
112    }
113
114    /// Establishes a new connection to a new [socket address](SocketAddr) and pushes a
115    /// [client connected event](ServerEvent::OnClientConnected) to the server events queue
116    pub(crate) fn establish_new_connection(&mut self, addr: SocketAddr) {
117        // Gets a new client id
118        let client_id = {
119            if let Some(client_id) = self.freed_ids.pop_front() {
120                client_id
121            } else {
122                let client_id = self.next_id;
123                self.next_id += 1;
124                client_id
125            }
126        };
127
128        self.connection_addr_to_id.insert(addr, client_id);
129        self.connection_id_to_addr.insert(client_id, addr);
130        self.connections
131            .insert(client_id, EstablishedConnection::new(addr));
132
133        self.server_events
134            .push_back(ServerEvent::OnClientConnected(client_id));
135    }
136}
137
138impl Default for NautServer {
139    fn default() -> Self {
140        Self {
141            max_connections: 128,
142            connections: Default::default(),
143            connection_addr_to_id: Default::default(),
144            connection_id_to_addr: Default::default(),
145            time_outs: Default::default(),
146            next_id: Default::default(),
147            freed_ids: VecDeque::new(),
148            idle_connection_timeout: Duration::from_secs(20),
149            server_events: VecDeque::new(),
150        }
151    }
152}
153
154impl<'socket> NautSocket<'socket, NautServer> {
155    /// Creates a new [event listening socket](crate::socket::NautSocket) with a
156    /// [server](NautServer) type
157    pub fn new<A>(addr: A, config: ServerConfig) -> anyhow::Result<Self>
158    where
159        A: ToSocketAddrs,
160    {
161        let socket = UdpSocket::bind(addr)?;
162        socket.set_nonblocking(true)?;
163
164        let server = NautServer::new(config);
165        let event_emitter = EventEmitter::new();
166        Ok(Self {
167            socket,
168            packet_queue: VecDeque::new(),
169            inner: server,
170            event_emitter,
171            ack_manager: AcknowledgementManager::new(),
172            phantom: PhantomData,
173            socket_events: Vec::new(),
174            persistent: PersistentStorage::new(),
175        })
176    }
177
178    /// Gets a reference to the [server](NautServer)
179    pub fn server(&self) -> &NautServer {
180        &self.inner
181    }
182
183    /// Gets a mutable reference to the [server](NautServer)
184    pub fn server_mut(&mut self) -> &mut NautServer {
185        &mut self.inner
186    }
187
188    /// Gets the packets from the packet queue and will handle returning
189    /// [ack packets](crate::acknowledgement::packet::AckPacket), resolving sequenced packets, emitting
190    /// listening events, establishing new connections and disconnecting idling clients
191    pub fn run_events(&mut self) {
192        // Disconnect idle clients
193        if let Some(ids_to_free) = self.inner.any_client_needs_freeing() {
194            for id in ids_to_free.iter() {
195                self.inner.free_client(*id);
196
197                self.inner
198                    .server_events
199                    .push_back(ServerEvent::OnClientTimeout(*id));
200            }
201        }
202
203        let event_emitter = std::mem::take(&mut self.event_emitter);
204        let event_emitter_ref = &event_emitter;
205        while let Some((addr, packet)) = self.oldest_packet_in_queue() {
206            let Some(delivery_type) = Self::get_delivery_type_from_packet(&packet) else {
207                self.socket_events
208                    .push(SocketEvent::ReadPacketFail("No delivery type".to_string()));
209                continue;
210            };
211
212            let Ok(delivery_type) =
213                <PacketDelivery as IntoPacketDelivery<u16>>::into_packet_delivery(delivery_type)
214            else {
215                self.socket_events
216                    .push(SocketEvent::ReadPacketFail(String::from(
217                        "Failed to read packet due to invalid delivery type",
218                    )));
219                continue;
220            };
221
222            // We must check if the packet is of ack delivery first because ack packets do not have
223            // the same byte size as a normal packet
224            if delivery_type == PacketDelivery::ack_delivery() {
225                let ack_num = AckNumber::new(LittleEndian::read_u32(&packet[2..6]));
226                self.ack_manager.packets_waiting_on_ack.remove(&ack_num);
227
228                continue;
229            }
230
231            // Check size here instead of in poll as ack packets do not fit into padding
232            if packet.len() < Self::PACKET_PADDING {
233                continue;
234            }
235
236            // Send a packet  to acknowledge the sender we have recieved their packet
237            if delivery_type.is_reliable() {
238                if let Err(e) = self.send_ack_packet(addr, &packet) {
239                    self.socket_events
240                        .push(SocketEvent::SendPacketFail(e.to_string()))
241                }
242            }
243
244            let Ok(event) = Self::get_event_from_packet(&packet) else {
245                continue;
246            };
247
248            if delivery_type.is_sequenced() {
249                let Some(seq_num) = Self::get_seq_from_packet(&packet) else {
250                    self.socket_events.push(SocketEvent::ReadPacketFail(
251                        "No sequence number in sequenced packet".to_string(),
252                    ));
253                    continue;
254                };
255
256                if let Some(last_recv_seq_num) =
257                    self.inner.last_recv_seq_num_for_event(&addr, &event)
258                {
259                    // Discard packet
260                    if seq_num < *last_recv_seq_num {
261                        println!(
262                            "Discarding {event} packet, last recv: {:?} recv: {:?}",
263                            *last_recv_seq_num, seq_num
264                        );
265                        continue;
266                    }
267
268                    *last_recv_seq_num = seq_num;
269                };
270            }
271
272            // Just ignore the packet and dont establish connection as its maxed out
273            if self.inner.connections.len() >= self.inner.max_connections as usize {
274                continue;
275            }
276
277            // Establishes a connection with a client if not already established
278            if !self.inner.connection_addr_to_id.contains_key(&addr) {
279                self.inner.establish_new_connection(addr);
280            }
281
282            let Some(client) = self.inner.connection_addr_to_id.get(&addr) else {
283                continue;
284            };
285
286            let client = *client;
287            self.inner.time_outs.insert(client, Instant::now());
288
289            let bytes = Self::get_packet_bytes(&packet).unwrap_or(Default::default());
290            event_emitter_ref.emit_event(&event, self, (addr, &bytes));
291        }
292
293        // Emit all polled events
294        event_emitter.emit_polled_events(self);
295
296        // Clear server events this time around
297        self.inner.server_events.clear();
298        self.socket_events.clear();
299
300        // Retry ack packets
301        self.retry_ack_packets();
302
303        self.event_emitter = event_emitter;
304    }
305
306    /// Sends an event message to all [established connections](EstablishedConnection)
307    pub fn broadcast(
308        &mut self,
309        event: &str,
310        buf: &[u8],
311        delivery: PacketDelivery,
312    ) -> anyhow::Result<()> {
313        let connection_ids: Vec<ConnectionId> =
314            { self.inner.connection_id_to_addr.keys().cloned().collect() };
315
316        for id in connection_ids {
317            let _ = self.send(event, buf, delivery, id);
318        }
319
320        Ok(())
321    }
322
323    /// Sends an event message to the [server](crate::server::NautServer) we are connected to
324    pub fn send(
325        &mut self,
326        event: &str,
327        buf: &[u8],
328        delivery: PacketDelivery,
329        client: ConnectionId,
330    ) -> anyhow::Result<()> {
331        let addr = {
332            *self
333                .inner
334                .connection_id_to_addr
335                .get(&client)
336                .ok_or(anyhow!(
337                    "There is no associated address with this client id"
338                ))?
339        };
340
341        let _ = self.send_by_addr(event, buf, delivery, addr.to_string());
342
343        Ok(())
344    }
345}
346
347impl<'socket> SocketType<'socket> for NautServer {
348    fn last_recv_seq_num_for_event(
349        &'socket mut self,
350        addr: &std::net::SocketAddr,
351        event: &str,
352    ) -> Option<&'socket mut SequenceNumber> {
353        let client_id = self.connection_addr_to_id.get(addr)?;
354        let connection = self.connections.get_mut(client_id)?;
355
356        if !connection.last_seq_num_recv.contains_key(event) {
357            connection
358                .last_seq_num_recv
359                .insert(event.to_string(), SequenceNumber::new(0));
360        }
361
362        let seq = connection.last_seq_num_recv.get_mut(event)?;
363
364        Some(seq)
365    }
366
367    fn update_current_send_seq_num_for_event(
368        &mut self,
369        addr: &SocketAddr,
370        event: &str,
371    ) -> Option<SequenceNumber> {
372        let client_id = self.connection_addr_to_id.get(addr)?;
373
374        let connection = self.connections.get_mut(client_id)?;
375
376        let Some(seq) = connection.current_send_seq_num.get_mut(event) else {
377            connection
378                .current_send_seq_num
379                .insert(event.to_owned(), SequenceNumber::new(0));
380            return Some(SequenceNumber::new(0));
381        };
382
383        *seq += SequenceNumber::new(1);
384
385        Some(*seq)
386    }
387}
388
389#[derive(Clone, Copy, Debug)]
390pub enum ServerEvent {
391    /// Pushed to the server event queue when a client connects
392    OnClientConnected(ConnectionId),
393    /// Pushed to the server event queue when a client times out
394    OnClientTimeout(ConnectionId),
395    /// Pushes to the server event queue when a client is disconnected
396    OnClientDisconnected(ConnectionId),
397}