nomad_protocol/server/
server.rs1use std::collections::HashMap;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::time::Duration;
10
11use thiserror::Error;
12use tokio::net::UdpSocket;
13use tokio::sync::{mpsc, oneshot, RwLock};
14
15use super::session::{ServerSession, ServerSessionId};
16use crate::core::SyncState;
17
18#[derive(Debug, Error)]
20pub enum ServerError {
21 #[error("bind failed: {0}")]
23 BindFailed(String),
24
25 #[error("session error: {0}")]
27 SessionError(String),
28
29 #[error("I/O error: {0}")]
31 Io(#[from] std::io::Error),
32
33 #[error("server shut down")]
35 Shutdown,
36
37 #[error("invalid handshake: {0}")]
39 InvalidHandshake(String),
40}
41
/// Runtime configuration for a [`NomadServer`].
///
/// Usually assembled with [`NomadServerBuilder`]; `Default` supplies the
/// baseline values noted on each field.
///
/// `Debug` is implemented by hand so that `private_key` never ends up in logs.
#[derive(Clone)]
pub struct ServerConfig {
    /// Address the UDP socket binds to (default `0.0.0.0:19999`).
    pub bind_addr: SocketAddr,

    /// The server's 32-byte private key; redacted from `Debug` output.
    ///
    /// NOTE(review): the default is all zeroes, which is only a placeholder;
    /// a real key should be supplied before serving clients.
    pub private_key: [u8; 32],

    /// Maximum number of concurrent sessions (default 1000).
    ///
    /// NOTE(review): not enforced by the server code in this module.
    pub max_sessions: usize,

    /// Session timeout (default 300 s).
    pub session_timeout: Duration,

    /// Whether compression is enabled (default `true`).
    pub enable_compression: bool,
}

impl std::fmt::Debug for ServerConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ServerConfig")
            .field("bind_addr", &self.bind_addr)
            // Never print key material; `format_args!` renders without quotes.
            .field("private_key", &format_args!("<redacted>"))
            .field("max_sessions", &self.max_sessions)
            .field("session_timeout", &self.session_timeout)
            .field("enable_compression", &self.enable_compression)
            .finish()
    }
}

impl Default for ServerConfig {
    fn default() -> Self {
        Self {
            // Constructed directly rather than parsed from a string, so no `unwrap`.
            bind_addr: SocketAddr::from(([0, 0, 0, 0], 19999)),
            private_key: [0u8; 32],
            max_sessions: 1000,
            session_timeout: Duration::from_secs(300),
            enable_compression: true,
        }
    }
}
72
73#[derive(Debug)]
75pub struct NomadServerBuilder {
76 config: ServerConfig,
77}
78
79impl NomadServerBuilder {
80 pub fn new() -> Self {
82 Self {
83 config: ServerConfig::default(),
84 }
85 }
86
87 pub fn bind_addr(mut self, addr: SocketAddr) -> Self {
89 self.config.bind_addr = addr;
90 self
91 }
92
93 pub fn private_key(mut self, key: [u8; 32]) -> Self {
95 self.config.private_key = key;
96 self
97 }
98
99 pub fn max_sessions(mut self, max: usize) -> Self {
101 self.config.max_sessions = max;
102 self
103 }
104
105 pub fn session_timeout(mut self, timeout: Duration) -> Self {
107 self.config.session_timeout = timeout;
108 self
109 }
110
111 pub fn compression(mut self, enabled: bool) -> Self {
113 self.config.enable_compression = enabled;
114 self
115 }
116
117 pub fn build(self) -> ServerConfig {
119 self.config
120 }
121}
122
123impl Default for NomadServerBuilder {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
129#[derive(Debug)]
131pub enum ServerEvent<S: SyncState> {
132 ClientConnected {
134 session_id: ServerSessionId,
136 client_public_key: [u8; 32],
138 },
139
140 StateUpdated {
142 session_id: ServerSessionId,
144 state: S,
146 },
147
148 ClientDisconnected {
150 session_id: ServerSessionId,
152 },
153}
154
155pub struct SessionSender<S: SyncState> {
157 session_id: ServerSessionId,
158 tx: mpsc::Sender<(ServerSessionId, S)>,
159}
160
161impl<S: SyncState> SessionSender<S> {
162 pub async fn send(&self, state: S) -> Result<(), ServerError> {
164 self.tx
165 .send((self.session_id, state))
166 .await
167 .map_err(|_| ServerError::Shutdown)
168 }
169
170 pub fn session_id(&self) -> ServerSessionId {
172 self.session_id
173 }
174}
175
176impl<S: SyncState> Clone for SessionSender<S> {
177 fn clone(&self) -> Self {
178 Self {
179 session_id: self.session_id,
180 tx: self.tx.clone(),
181 }
182 }
183}
184
185pub struct NomadServer<S: SyncState> {
218 config: ServerConfig,
220
221 sessions: Arc<RwLock<HashMap<ServerSessionId, ServerSession<S>>>>,
223
224 state_tx: mpsc::Sender<(ServerSessionId, S)>,
226
227 shutdown_tx: Option<oneshot::Sender<()>>,
229
230 local_addr: SocketAddr,
232}
233
234impl<S: SyncState> NomadServer<S> {
235 pub async fn bind<F>(
239 config: ServerConfig,
240 _state_factory: F,
241 ) -> Result<(Self, mpsc::Receiver<ServerEvent<S>>), ServerError>
242 where
243 F: Fn() -> S + Send + Sync + 'static,
244 {
245 let socket = UdpSocket::bind(config.bind_addr)
247 .await
248 .map_err(|e| ServerError::BindFailed(e.to_string()))?;
249
250 let local_addr = socket.local_addr()?;
251
252 let (state_tx, _state_rx) = mpsc::channel::<(ServerSessionId, S)>(256);
254 let (event_tx, event_rx) = mpsc::channel::<ServerEvent<S>>(256);
255 let (shutdown_tx, _shutdown_rx) = oneshot::channel();
256
257 let sessions: Arc<RwLock<HashMap<ServerSessionId, ServerSession<S>>>> =
258 Arc::new(RwLock::new(HashMap::new()));
259
260 let _sessions_clone = sessions.clone();
262 let _config_clone = config.clone();
263 let _event_tx = event_tx;
264
265 tokio::spawn(async move {
266 let mut buf = [0u8; 65535];
271 while let Ok((_len, _addr)) = socket.recv_from(&mut buf).await {
272 }
275 });
276
277 let server = Self {
278 config,
279 sessions,
280 state_tx,
281 shutdown_tx: Some(shutdown_tx),
282 local_addr,
283 };
284
285 Ok((server, event_rx))
286 }
287
288 pub fn local_addr(&self) -> SocketAddr {
290 self.local_addr
291 }
292
293 pub async fn session_count(&self) -> usize {
295 self.sessions.read().await.len()
296 }
297
298 pub async fn send_to(&self, session_id: ServerSessionId, state: S) -> Result<(), ServerError> {
300 self.state_tx
301 .send((session_id, state))
302 .await
303 .map_err(|_| ServerError::Shutdown)
304 }
305
306 pub async fn broadcast(&self, state: S) -> Result<(), ServerError> {
308 let sessions = self.sessions.read().await;
309 for session_id in sessions.keys() {
310 self.state_tx
311 .send((*session_id, state.clone()))
312 .await
313 .map_err(|_| ServerError::Shutdown)?;
314 }
315 Ok(())
316 }
317
318 pub fn session_sender(&self, session_id: ServerSessionId) -> SessionSender<S> {
320 SessionSender {
321 session_id,
322 tx: self.state_tx.clone(),
323 }
324 }
325
326 pub async fn disconnect(&self, session_id: ServerSessionId) -> Result<(), ServerError> {
328 let mut sessions = self.sessions.write().await;
329 if sessions.remove(&session_id).is_some() {
330 Ok(())
332 } else {
333 Err(ServerError::SessionError(format!(
334 "session not found: {:?}",
335 session_id
336 )))
337 }
338 }
339
340 pub async fn shutdown(mut self) -> Result<(), ServerError> {
342 if let Some(tx) = self.shutdown_tx.take() {
344 let _ = tx.send(());
345 }
346
347 self.sessions.write().await.clear();
351
352 Ok(())
353 }
354
355 pub fn config(&self) -> &ServerConfig {
357 &self.config
358 }
359}
360
361impl<S: SyncState> Drop for NomadServer<S> {
362 fn drop(&mut self) {
363 if let Some(tx) = self.shutdown_tx.take() {
365 let _ = tx.send(());
366 }
367 }
368}
369
#[cfg(test)]
mod tests {
    //! Unit tests for the runtime-independent parts of the server API.

    use super::*;

    #[test]
    fn default_config_has_expected_values() {
        let cfg = ServerConfig::default();
        assert_eq!(cfg.bind_addr, "0.0.0.0:19999".parse::<SocketAddr>().unwrap());
        assert_eq!(cfg.private_key, [0u8; 32]);
        assert_eq!(cfg.max_sessions, 1000);
        assert_eq!(cfg.session_timeout, Duration::from_secs(300));
        assert!(cfg.enable_compression);
    }

    #[test]
    fn builder_overrides_every_field() {
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let cfg = NomadServerBuilder::new()
            .bind_addr(addr)
            .private_key([7u8; 32])
            .max_sessions(8)
            .session_timeout(Duration::from_secs(5))
            .compression(false)
            .build();
        assert_eq!(cfg.bind_addr, addr);
        assert_eq!(cfg.private_key, [7u8; 32]);
        assert_eq!(cfg.max_sessions, 8);
        assert_eq!(cfg.session_timeout, Duration::from_secs(5));
        assert!(!cfg.enable_compression);
    }

    #[test]
    fn builder_default_matches_new() {
        let from_default = NomadServerBuilder::default().build();
        let from_new = NomadServerBuilder::new().build();
        assert_eq!(from_default.bind_addr, from_new.bind_addr);
        assert_eq!(from_default.private_key, from_new.private_key);
        assert_eq!(from_default.max_sessions, from_new.max_sessions);
        assert_eq!(from_default.session_timeout, from_new.session_timeout);
        assert_eq!(from_default.enable_compression, from_new.enable_compression);
    }
}