frag_datagram/
server.rs

1/*
2 * Copyright (c) Peter Bjorklund. All rights reserved. https://github.com/piot/frag-datagram
3 * Licensed under the MIT License. See LICENSE in the project root for license information.
4 */
5use crate::{
6    FLAG_CONNECT_REQUEST, FLAG_CONNECT_RESPONSE, Receiver, Sender, read_connect_request,
7    read_datagram, write_connect_response,
8};
9use std::collections::{HashMap, HashSet};
10use std::time::Duration;
11use tracing::trace;
12
13/// A complete bidirectional connection that combines Sender + Receiver
14pub struct ServerConnection {
15    connection_id: u16,
16    address_hash: u64,
17    sender: Sender,
18    receiver: Receiver,
19}
20
21impl ServerConnection {
22    /// Create a new server connection
23    #[must_use]
24    pub fn new(connection_id: u16, address_hash: u64, max_fragments: usize) -> Self {
25        Self {
26            connection_id,
27            address_hash,
28            sender: Sender::new(connection_id),
29            receiver: Receiver::new(max_fragments, address_hash),
30        }
31    }
32
33    /// Send data to this connection's client
34    pub fn send(&mut self, payload: &[u8]) -> Vec<Vec<u8>> {
35        self.sender.send(payload)
36    }
37
38    /// Process incoming datagram for this connection
39    pub fn receive(
40        &mut self,
41        datagram: &[u8],
42        sender_address_hash: u64,
43    ) -> Option<(Vec<u8>, bool)> {
44        self.receiver.receive(datagram, sender_address_hash)
45    }
46
47    /// Check if this connection is inactive
48    #[must_use]
49    pub fn is_inactive(&self, timeout: Duration) -> bool {
50        self.receiver.is_inactive(timeout)
51    }
52
53    /// Get connection ID
54    #[must_use]
55    pub const fn connection_id(&self) -> u16 {
56        self.connection_id
57    }
58
59    /// Get address hash
60    #[must_use]
61    pub const fn address_hash(&self) -> u64 {
62        self.address_hash
63    }
64}
65
66/// Server-side hub that manages complete bidirectional connections
67pub struct ServerHub {
68    connections: HashMap<u16, ServerConnection>,
69    seen_connection_requests: HashSet<(u64, u32)>,
70    max_fragments_per_connection: usize,
71    max_connections: usize,
72    connection_timeout: Duration,
73}
74
75impl ServerHub {
76    /// Create a new server hub
77    #[must_use]
78    pub fn new(
79        max_fragments_per_connection: usize,
80        max_connections: usize,
81        connection_timeout_seconds: u64,
82    ) -> Self {
83        Self {
84            connections: HashMap::new(),
85            seen_connection_requests: HashSet::new(),
86            max_fragments_per_connection,
87            max_connections,
88            connection_timeout: Duration::from_secs(connection_timeout_seconds),
89        }
90    }
91
92    /// Process incoming datagram
93    /// Returns:
94    /// - `Some((connection_id`, payload, `response_datagrams`, `is_new_connection`)) for complete messages
95    /// - None for incomplete/invalid datagrams
96    pub fn receive(
97        &mut self,
98        datagram: &[u8],
99        address_hash: u64,
100    ) -> Option<(u16, Vec<u8>, Vec<Vec<u8>>, bool)> {
101        // Parse header to get connection_id and flags
102        let (header, _) = read_datagram(datagram)?;
103
104        // Handle connection requests directly
105        if header.flags == FLAG_CONNECT_REQUEST {
106            trace!(len = datagram.len(), "got connect request");
107
108            let (_, connect_req, user_payload) = read_connect_request(datagram)?;
109            let request_key = (address_hash, connect_req.request_id);
110            let is_new_request = !self.seen_connection_requests.contains(&request_key);
111
112            if is_new_request {
113                self.seen_connection_requests.insert(request_key);
114            }
115
116            let (connection_id, is_new) = if is_new_request {
117                (self.create_connection(address_hash)?, true)
118            } else {
119                let existing_conn_id = self
120                    .connections
121                    .iter()
122                    .find(|(_, connection)| connection.address_hash() == address_hash)
123                    .map(|(&id, _)| id)?;
124                (existing_conn_id, false)
125            };
126
127            if let Some(response) =
128                self.create_connection_response(connect_req.request_id, connection_id, &[])
129            {
130                // Return the user payload from this fragment so the caller can reconstruct
131                return Some((connection_id, user_payload.to_vec(), vec![response], is_new));
132            }
133
134            return None;
135        }
136
137        if header.flags == FLAG_CONNECT_RESPONSE {
138            return None;
139        }
140
141        let connection_id = header.connection_id;
142        if let Some(connection) = self.connections.get_mut(&connection_id)
143            && let Some((payload, _is_new_conn)) = connection.receive(datagram, address_hash)
144        {
145            return Some((connection_id, payload, vec![], false));
146        }
147
148        None
149    }
150
151    /// Create a new connection after accepting a connection request
152    pub fn create_connection(&mut self, address_hash: u64) -> Option<u16> {
153        if self.connections.len() >= self.max_connections {
154            return None;
155        }
156
157        // Generate a unique connection ID
158        let mut connection_id = 1u16;
159        while self.connections.contains_key(&connection_id) {
160            connection_id = connection_id.wrapping_add(1);
161            if connection_id == 0 {
162                connection_id = 1; // Skip 0, reserved for connection requests
163            }
164        }
165
166        let connection = ServerConnection::new(
167            connection_id,
168            address_hash,
169            self.max_fragments_per_connection,
170        );
171        self.connections.insert(connection_id, connection);
172
173        Some(connection_id)
174    }
175
176    /// Send data to a specific connection
177    pub fn send_to(&mut self, connection_id: u16, payload: &[u8]) -> Option<Vec<Vec<u8>>> {
178        self.connections
179            .get_mut(&connection_id)
180            .map(|conn| conn.send(payload))
181    }
182
183    /// Create a connection response datagram
184    #[must_use]
185    pub fn create_connection_response(
186        &self,
187        request_id: u32,
188        connection_id: u16,
189        payload: &[u8],
190    ) -> Option<Vec<u8>> {
191        let mut buf = vec![0u8; 1024];
192        if let Some(len) =
193            write_connect_response(&mut buf, request_id, connection_id, 0, 0, payload)
194        {
195            buf.truncate(len);
196            Some(buf)
197        } else {
198            None
199        }
200    }
201
202    /// Remove inactive connections
203    pub fn cleanup_inactive_connections(&mut self) -> Vec<u16> {
204        let timeout = self.connection_timeout;
205        let mut inactive_connections = Vec::new();
206
207        for (&connection_id, connection) in &self.connections {
208            if connection.is_inactive(timeout) {
209                inactive_connections.push(connection_id);
210            }
211        }
212
213        for connection_id in &inactive_connections {
214            self.connections.remove(connection_id);
215        }
216
217        inactive_connections
218    }
219
220    /// Get connection count
221    #[must_use]
222    pub fn connection_count(&self) -> usize {
223        self.connections.len()
224    }
225
226    /// Check if connection exists
227    #[must_use]
228    pub fn has_connection(&self, connection_id: u16) -> bool {
229        self.connections.contains_key(&connection_id)
230    }
231
232    /// Remove a specific connection
233    pub fn remove_connection(&mut self, connection_id: u16) -> bool {
234        self.connections.remove(&connection_id).is_some()
235    }
236}
237
/// Utility functions for generating address hashes from network addresses
pub mod address_hash {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    /// Feed `value` into a fresh [`DefaultHasher`] and return the resulting digest.
    fn digest<T: Hash + ?Sized>(value: &T) -> u64 {
        let mut state = DefaultHasher::new();
        value.hash(&mut state);
        state.finish()
    }

    /// Hash an IPv4 address (as a `u32`) together with its port.
    #[must_use]
    pub fn hash_ipv4(ip: u32, port: u16) -> u64 {
        // A tuple hashes its fields in order, matching `ip.hash(..); port.hash(..)`.
        digest(&(ip, port))
    }

    /// Hash an IPv6 address (as 16 octets) together with its port.
    #[must_use]
    pub fn hash_ipv6(ip: [u8; 16], port: u16) -> u64 {
        digest(&(ip, port))
    }

    /// Hash the textual form of an address, e.g. `"127.0.0.1:9000"`.
    #[must_use]
    pub fn hash_from_string(address_string: &str) -> u64 {
        digest(address_string)
    }
}