1pub mod config;
2use std::{
3 collections::{HashMap, VecDeque},
4 marker::PhantomData,
5 net::{SocketAddr, ToSocketAddrs, UdpSocket},
6 time::{Duration, Instant},
7};
8
9use anyhow::anyhow;
10use byteorder::{ByteOrder, LittleEndian};
11use config::ServerConfig;
12
13use crate::{
14 acknowledgement::{manager::AcknowledgementManager, packet::AckNumber},
15 client::ConnectionId,
16 connection::EstablishedConnection,
17 events::EventEmitter,
18 packet::{IntoPacketDelivery, PacketDelivery},
19 persistent::storage::PersistentStorage,
20 sequence::SequenceNumber,
21 socket::{events::SocketEvent, NautSocket, SocketType},
22};
23
24pub struct NautServer {
26 max_connections: u8,
27
28 connection_addr_to_id: HashMap<SocketAddr, ConnectionId>,
29 connection_id_to_addr: HashMap<ConnectionId, SocketAddr>,
30 connections: HashMap<ConnectionId, EstablishedConnection>,
31
32 time_outs: HashMap<ConnectionId, Instant>,
33
34 next_id: ConnectionId,
35 freed_ids: VecDeque<ConnectionId>,
36
37 idle_connection_timeout: Duration,
38
39 server_events: VecDeque<ServerEvent>,
40}
41
42impl NautServer {
43 pub fn new(config: ServerConfig) -> Self {
44 Self {
45 max_connections: config.max_connections,
46 idle_connection_timeout: config.idle_connection_time,
47 ..Default::default()
48 }
49 }
50
51 pub fn get_client_addr(&self, id: &ConnectionId) -> Option<&SocketAddr> {
53 self.connection_id_to_addr.get(id)
54 }
55
56 pub fn get_client_id(&self, addr: &SocketAddr) -> Option<&ConnectionId> {
58 self.connection_addr_to_id.get(addr)
59 }
60
61 pub fn iter_server_events(&self) -> std::collections::vec_deque::Iter<'_, ServerEvent> {
63 self.server_events.iter()
64 }
65
66 pub fn get_max_connections(&self) -> u8 {
68 self.max_connections
69 }
70
71 pub fn get_current_connections(&self) -> u8 {
73 self.connections.len() as u8
74 }
75
76 pub(crate) fn any_client_needs_freeing(&self) -> Option<Vec<ConnectionId>> {
78 let mut ids = Vec::new();
79 for (id, time) in self.time_outs.iter() {
80 if Instant::now().duration_since(*time) < self.idle_connection_timeout {
81 continue;
82 }
83 ids.push(*id);
84 }
85
86 if ids.is_empty() {
87 return None;
88 }
89
90 Some(ids)
91 }
92
93 pub(crate) fn free_client(&mut self, id: ConnectionId) {
95 self.freed_ids.push_back(id);
96 let Some(addr) = self.connection_id_to_addr.remove(&id) else {
97 println!("Failed to find address of idle'd client with id: {id}");
98 return;
99 };
100
101 self.connection_addr_to_id.remove(&addr);
102 self.time_outs.remove(&id);
103 self.connections.remove(&id);
104 }
105
106 pub fn close_connection_with_client(&mut self, id: ConnectionId) {
109 self.free_client(id);
110 self.server_events
111 .push_back(ServerEvent::OnClientDisconnected(id));
112 }
113
114 pub(crate) fn establish_new_connection(&mut self, addr: SocketAddr) {
117 let client_id = {
119 if let Some(client_id) = self.freed_ids.pop_front() {
120 client_id
121 } else {
122 let client_id = self.next_id;
123 self.next_id += 1;
124 client_id
125 }
126 };
127
128 self.connection_addr_to_id.insert(addr, client_id);
129 self.connection_id_to_addr.insert(client_id, addr);
130 self.connections
131 .insert(client_id, EstablishedConnection::new(addr));
132
133 self.server_events
134 .push_back(ServerEvent::OnClientConnected(client_id));
135 }
136}
137
138impl Default for NautServer {
139 fn default() -> Self {
140 Self {
141 max_connections: 128,
142 connections: Default::default(),
143 connection_addr_to_id: Default::default(),
144 connection_id_to_addr: Default::default(),
145 time_outs: Default::default(),
146 next_id: Default::default(),
147 freed_ids: VecDeque::new(),
148 idle_connection_timeout: Duration::from_secs(20),
149 server_events: VecDeque::new(),
150 }
151 }
152}
153
154impl<'socket> NautSocket<'socket, NautServer> {
155 pub fn new<A>(addr: A, config: ServerConfig) -> anyhow::Result<Self>
158 where
159 A: ToSocketAddrs,
160 {
161 let socket = UdpSocket::bind(addr)?;
162 socket.set_nonblocking(true)?;
163
164 let server = NautServer::new(config);
165 let event_emitter = EventEmitter::new();
166 Ok(Self {
167 socket,
168 packet_queue: VecDeque::new(),
169 inner: server,
170 event_emitter,
171 ack_manager: AcknowledgementManager::new(),
172 phantom: PhantomData,
173 socket_events: Vec::new(),
174 persistent: PersistentStorage::new(),
175 })
176 }
177
178 pub fn server(&self) -> &NautServer {
180 &self.inner
181 }
182
183 pub fn server_mut(&mut self) -> &mut NautServer {
185 &mut self.inner
186 }
187
188 pub fn run_events(&mut self) {
192 if let Some(ids_to_free) = self.inner.any_client_needs_freeing() {
194 for id in ids_to_free.iter() {
195 self.inner.free_client(*id);
196
197 self.inner
198 .server_events
199 .push_back(ServerEvent::OnClientTimeout(*id));
200 }
201 }
202
203 let event_emitter = std::mem::take(&mut self.event_emitter);
204 let event_emitter_ref = &event_emitter;
205 while let Some((addr, packet)) = self.oldest_packet_in_queue() {
206 let Some(delivery_type) = Self::get_delivery_type_from_packet(&packet) else {
207 self.socket_events
208 .push(SocketEvent::ReadPacketFail("No delivery type".to_string()));
209 continue;
210 };
211
212 let Ok(delivery_type) =
213 <PacketDelivery as IntoPacketDelivery<u16>>::into_packet_delivery(delivery_type)
214 else {
215 self.socket_events
216 .push(SocketEvent::ReadPacketFail(String::from(
217 "Failed to read packet due to invalid delivery type",
218 )));
219 continue;
220 };
221
222 if delivery_type == PacketDelivery::ack_delivery() {
225 let ack_num = AckNumber::new(LittleEndian::read_u32(&packet[2..6]));
226 self.ack_manager.packets_waiting_on_ack.remove(&ack_num);
227
228 continue;
229 }
230
231 if packet.len() < Self::PACKET_PADDING {
233 continue;
234 }
235
236 if delivery_type.is_reliable() {
238 if let Err(e) = self.send_ack_packet(addr, &packet) {
239 self.socket_events
240 .push(SocketEvent::SendPacketFail(e.to_string()))
241 }
242 }
243
244 let Ok(event) = Self::get_event_from_packet(&packet) else {
245 continue;
246 };
247
248 if delivery_type.is_sequenced() {
249 let Some(seq_num) = Self::get_seq_from_packet(&packet) else {
250 self.socket_events.push(SocketEvent::ReadPacketFail(
251 "No sequence number in sequenced packet".to_string(),
252 ));
253 continue;
254 };
255
256 if let Some(last_recv_seq_num) =
257 self.inner.last_recv_seq_num_for_event(&addr, &event)
258 {
259 if seq_num < *last_recv_seq_num {
261 println!(
262 "Discarding {event} packet, last recv: {:?} recv: {:?}",
263 *last_recv_seq_num, seq_num
264 );
265 continue;
266 }
267
268 *last_recv_seq_num = seq_num;
269 };
270 }
271
272 if self.inner.connections.len() >= self.inner.max_connections as usize {
274 continue;
275 }
276
277 if !self.inner.connection_addr_to_id.contains_key(&addr) {
279 self.inner.establish_new_connection(addr);
280 }
281
282 let Some(client) = self.inner.connection_addr_to_id.get(&addr) else {
283 continue;
284 };
285
286 let client = *client;
287 self.inner.time_outs.insert(client, Instant::now());
288
289 let bytes = Self::get_packet_bytes(&packet).unwrap_or(Default::default());
290 event_emitter_ref.emit_event(&event, self, (addr, &bytes));
291 }
292
293 event_emitter.emit_polled_events(self);
295
296 self.inner.server_events.clear();
298 self.socket_events.clear();
299
300 self.retry_ack_packets();
302
303 self.event_emitter = event_emitter;
304 }
305
306 pub fn broadcast(
308 &mut self,
309 event: &str,
310 buf: &[u8],
311 delivery: PacketDelivery,
312 ) -> anyhow::Result<()> {
313 let connection_ids: Vec<ConnectionId> =
314 { self.inner.connection_id_to_addr.keys().cloned().collect() };
315
316 for id in connection_ids {
317 let _ = self.send(event, buf, delivery, id);
318 }
319
320 Ok(())
321 }
322
323 pub fn send(
325 &mut self,
326 event: &str,
327 buf: &[u8],
328 delivery: PacketDelivery,
329 client: ConnectionId,
330 ) -> anyhow::Result<()> {
331 let addr = {
332 *self
333 .inner
334 .connection_id_to_addr
335 .get(&client)
336 .ok_or(anyhow!(
337 "There is no associated address with this client id"
338 ))?
339 };
340
341 let _ = self.send_by_addr(event, buf, delivery, addr.to_string());
342
343 Ok(())
344 }
345}
346
347impl<'socket> SocketType<'socket> for NautServer {
348 fn last_recv_seq_num_for_event(
349 &'socket mut self,
350 addr: &std::net::SocketAddr,
351 event: &str,
352 ) -> Option<&'socket mut SequenceNumber> {
353 let client_id = self.connection_addr_to_id.get(addr)?;
354 let connection = self.connections.get_mut(client_id)?;
355
356 if !connection.last_seq_num_recv.contains_key(event) {
357 connection
358 .last_seq_num_recv
359 .insert(event.to_string(), SequenceNumber::new(0));
360 }
361
362 let seq = connection.last_seq_num_recv.get_mut(event)?;
363
364 Some(seq)
365 }
366
367 fn update_current_send_seq_num_for_event(
368 &mut self,
369 addr: &SocketAddr,
370 event: &str,
371 ) -> Option<SequenceNumber> {
372 let client_id = self.connection_addr_to_id.get(addr)?;
373
374 let connection = self.connections.get_mut(client_id)?;
375
376 let Some(seq) = connection.current_send_seq_num.get_mut(event) else {
377 connection
378 .current_send_seq_num
379 .insert(event.to_owned(), SequenceNumber::new(0));
380 return Some(SequenceNumber::new(0));
381 };
382
383 *seq += SequenceNumber::new(1);
384
385 Some(*seq)
386 }
387}
388
389#[derive(Clone, Copy, Debug)]
390pub enum ServerEvent {
391 OnClientConnected(ConnectionId),
393 OnClientTimeout(ConnectionId),
395 OnClientDisconnected(ConnectionId),
397}