1use crate::error::ServerError;
4use crate::handler::{MessageHandler, Responder, SendError};
5use crate::session::SessionManager;
6use bytes::BytesMut;
7use futures::SinkExt;
8use ironsbe_channel::mpsc::{MpscChannel, MpscReceiver, MpscSender};
9use ironsbe_core::header::MessageHeader;
10use std::net::SocketAddr;
11use std::sync::Arc;
12use tokio::net::TcpStream;
13use tokio::sync::{Notify, mpsc as tokio_mpsc};
14use tokio_stream::StreamExt;
15use tokio_util::codec::{Decoder, Encoder, Framed};
16
/// Fluent configuration for a [`Server`]; call [`ServerBuilder::build`] to
/// obtain the server together with its controlling [`ServerHandle`].
pub struct ServerBuilder<H> {
    /// Address the TCP listener will bind to.
    bind_addr: SocketAddr,
    /// Message handler; `build` requires it to have been set.
    handler: Option<H>,
    /// Sessions tracked at once before further connections are refused.
    max_connections: usize,
    /// Largest inbound frame payload, in bytes, the decoder accepts.
    max_frame_size: usize,
    /// Capacity used for both the command and the event channel.
    channel_capacity: usize,
}
25
26impl<H: MessageHandler> ServerBuilder<H> {
27 #[must_use]
29 pub fn new() -> Self {
30 Self {
31 bind_addr: "0.0.0.0:9000".parse().unwrap(),
32 handler: None,
33 max_connections: 1000,
34 max_frame_size: 64 * 1024,
35 channel_capacity: 4096,
36 }
37 }
38
39 #[must_use]
41 pub fn bind(mut self, addr: SocketAddr) -> Self {
42 self.bind_addr = addr;
43 self
44 }
45
46 #[must_use]
48 pub fn handler(mut self, handler: H) -> Self {
49 self.handler = Some(handler);
50 self
51 }
52
53 #[must_use]
55 pub fn max_connections(mut self, max: usize) -> Self {
56 self.max_connections = max;
57 self
58 }
59
60 #[must_use]
62 pub fn max_frame_size(mut self, size: usize) -> Self {
63 self.max_frame_size = size;
64 self
65 }
66
67 #[must_use]
69 pub fn channel_capacity(mut self, capacity: usize) -> Self {
70 self.channel_capacity = capacity;
71 self
72 }
73
74 #[must_use]
79 pub fn build(self) -> (Server<H>, ServerHandle) {
80 let handler = self.handler.expect("Handler required");
81 let (cmd_tx, cmd_rx) = MpscChannel::bounded(self.channel_capacity);
82 let (event_tx, event_rx) = MpscChannel::bounded(self.channel_capacity);
83
84 let cmd_notify = Arc::new(Notify::new());
85
86 let server = Server {
87 bind_addr: self.bind_addr,
88 handler: Arc::new(handler),
89 max_connections: self.max_connections,
90 max_frame_size: self.max_frame_size,
91 cmd_rx,
92 event_tx,
93 sessions: SessionManager::new(),
94 cmd_notify: Arc::clone(&cmd_notify),
95 };
96
97 let handle = ServerHandle {
98 cmd_tx,
99 event_rx,
100 cmd_notify,
101 };
102
103 (server, handle)
104 }
105}
106
107impl<H: MessageHandler> Default for ServerBuilder<H> {
108 fn default() -> Self {
109 Self::new()
110 }
111}
112
113#[allow(dead_code)]
115pub struct Server<H> {
116 bind_addr: SocketAddr,
117 handler: Arc<H>,
118 max_connections: usize,
119 max_frame_size: usize,
120 cmd_rx: MpscReceiver<ServerCommand>,
121 event_tx: MpscSender<ServerEvent>,
122 sessions: SessionManager,
123 cmd_notify: Arc<Notify>,
124}
125
126impl<H: MessageHandler + Send + Sync + 'static> Server<H> {
127 pub async fn run(&mut self) -> Result<(), ServerError> {
132 let listener = tokio::net::TcpListener::bind(self.bind_addr).await?;
133 tracing::info!("Server listening on {}", self.bind_addr);
134
135 loop {
136 tokio::select! {
137 result = listener.accept() => {
138 match result {
139 Ok((stream, addr)) => {
140 self.handle_connection(stream, addr).await;
141 }
142 Err(e) => {
143 tracing::error!("Accept error: {}", e);
144 }
145 }
146 }
147
148 _ = self.cmd_notify.notified() => {
149 while let Some(cmd) = self.cmd_rx.try_recv() {
150 if self.handle_command(cmd).await {
151 return Ok(());
152 }
153 }
154 }
155 }
156 }
157 }
158
159 async fn handle_connection(&mut self, stream: TcpStream, addr: SocketAddr) {
160 if self.sessions.count() >= self.max_connections {
161 tracing::warn!("Max connections reached, rejecting {}", addr);
162 return;
163 }
164
165 let session_id = self.sessions.create_session(addr);
166 let handler = Arc::clone(&self.handler);
167 let event_tx = self.event_tx.clone();
168 let max_frame_size = self.max_frame_size;
169
170 handler.on_session_start(session_id);
171 let _ = event_tx.try_send(ServerEvent::SessionCreated(session_id, addr));
172
173 tokio::spawn(async move {
175 tracing::info!("Session {} connected from {}", session_id, addr);
176
177 if let Err(e) =
178 handle_session(session_id, stream, handler.as_ref(), max_frame_size).await
179 {
180 tracing::error!("Session {} error: {:?}", session_id, e);
181 }
182
183 handler.on_session_end(session_id);
185 let _ = event_tx.try_send(ServerEvent::SessionClosed(session_id));
186 });
187 }
188
189 async fn handle_command(&mut self, cmd: ServerCommand) -> bool {
190 match cmd {
191 ServerCommand::Shutdown => {
192 tracing::info!("Server shutdown requested");
193 true
194 }
195 ServerCommand::CloseSession(session_id) => {
196 self.sessions.close_session(session_id);
197 false
198 }
199 ServerCommand::Broadcast(_message) => {
200 false
202 }
203 }
204 }
205}
206
207pub struct ServerHandle {
209 cmd_tx: MpscSender<ServerCommand>,
210 event_rx: MpscReceiver<ServerEvent>,
211 cmd_notify: Arc<Notify>,
212}
213
214impl ServerHandle {
215 pub fn shutdown(&self) {
217 let _ = self.cmd_tx.try_send(ServerCommand::Shutdown);
218 self.cmd_notify.notify_one();
219 }
220
221 pub fn close_session(&self, session_id: u64) {
223 let _ = self
224 .cmd_tx
225 .try_send(ServerCommand::CloseSession(session_id));
226 self.cmd_notify.notify_one();
227 }
228
229 pub fn broadcast(&self, message: Vec<u8>) {
231 let _ = self.cmd_tx.try_send(ServerCommand::Broadcast(message));
232 self.cmd_notify.notify_one();
233 }
234
235 pub fn poll_events(&self) -> impl Iterator<Item = ServerEvent> + '_ {
237 std::iter::from_fn(|| self.event_rx.try_recv())
238 }
239}
240
/// Control messages sent from a `ServerHandle` to the running `Server`.
#[derive(Debug)]
pub enum ServerCommand {
    /// Stop the accept loop and return from `Server::run`.
    Shutdown,
    /// Close the session with this id.
    CloseSession(u64),
    /// Payload requested for broadcast to connected sessions (not yet acted
    /// on by the server).
    Broadcast(Vec<u8>),
}
251
/// Lifecycle notifications the server publishes for `ServerHandle::poll_events`.
#[derive(Debug, Clone)]
pub enum ServerEvent {
    /// A connection was accepted and registered under this session id.
    SessionCreated(u64, SocketAddr),
    /// The task serving this session has finished.
    SessionClosed(u64),
    /// Free-form error description; nothing in this module currently emits it.
    Error(String),
}
262
/// Length-prefixed framing: every frame is a 4-byte little-endian payload
/// length followed by that many payload bytes.
struct SbeFrameCodec {
    /// The decoder rejects announced payload lengths above this many bytes.
    max_frame_size: usize,
}
267
268impl SbeFrameCodec {
269 fn new(max_frame_size: usize) -> Self {
270 Self { max_frame_size }
271 }
272}
273
274impl Decoder for SbeFrameCodec {
275 type Item = BytesMut;
276 type Error = std::io::Error;
277
278 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
279 if src.len() < 4 {
280 return Ok(None);
281 }
282
283 let length = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
284
285 if length > self.max_frame_size {
286 return Err(std::io::Error::new(
287 std::io::ErrorKind::InvalidData,
288 format!("Frame too large: {} > {}", length, self.max_frame_size),
289 ));
290 }
291
292 if src.len() < 4 + length {
293 src.reserve(4 + length - src.len());
294 return Ok(None);
295 }
296
297 let _ = src.split_to(4);
298 Ok(Some(src.split_to(length)))
299 }
300}
301
302impl<T: AsRef<[u8]>> Encoder<T> for SbeFrameCodec {
303 type Error = std::io::Error;
304
305 fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
306 let data = item.as_ref();
307 let length = data.len() as u32;
308 dst.reserve(4 + data.len());
309 dst.extend_from_slice(&length.to_le_bytes());
310 dst.extend_from_slice(data);
311 Ok(())
312 }
313}
314
315struct SessionResponder {
317 tx: tokio_mpsc::UnboundedSender<Vec<u8>>,
318}
319
320impl Responder for SessionResponder {
321 fn send(&self, message: &[u8]) -> Result<(), SendError> {
322 self.tx.send(message.to_vec()).map_err(|_| SendError {
323 message: "channel closed".to_string(),
324 })
325 }
326
327 fn send_to(&self, _session_id: u64, message: &[u8]) -> Result<(), SendError> {
328 self.send(message)
330 }
331}
332
333async fn handle_session<H: MessageHandler>(
335 session_id: u64,
336 stream: TcpStream,
337 handler: &H,
338 max_frame_size: usize,
339) -> Result<(), std::io::Error> {
340 let codec = SbeFrameCodec::new(max_frame_size);
341 let mut framed = Framed::new(stream, codec);
342
343 let (tx, mut rx) = tokio_mpsc::unbounded_channel::<Vec<u8>>();
345 let responder = SessionResponder { tx };
346
347 loop {
348 tokio::select! {
349 result = framed.next() => {
351 match result {
352 Some(Ok(data)) => {
353 if data.len() >= MessageHeader::ENCODED_LENGTH {
355 let header = MessageHeader::wrap(data.as_ref(), 0);
356 handler.on_message(session_id, &header, data.as_ref(), &responder);
357 } else {
358 handler.on_error(session_id, "Message too short for header");
359 }
360 }
361 Some(Err(e)) => {
362 tracing::error!("Session {} read error: {}", session_id, e);
363 return Err(e);
364 }
365 None => {
366 tracing::info!("Session {} disconnected", session_id);
367 return Ok(());
368 }
369 }
370 }
371
372 Some(msg) = rx.recv() => {
374 if let Err(e) = framed.send(msg).await {
375 tracing::error!("Session {} write error: {}", session_id, e);
376 return Err(e);
377 }
378 }
379 }
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal handler that ignores every message.
    struct TestHandler;

    impl MessageHandler for TestHandler {
        fn on_message(
            &self,
            _session_id: u64,
            _header: &MessageHeader,
            _data: &[u8],
            _responder: &dyn Responder,
        ) {
        }
    }

    /// Parses a socket-address literal used as test input.
    fn parse_addr(text: &str) -> SocketAddr {
        text.parse().expect("test address literal must be valid")
    }

    /// Wraps `payload` in the codec's wire format (little-endian u32 length prefix).
    fn frame(payload: &[u8]) -> BytesMut {
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&(payload.len() as u32).to_le_bytes());
        buf.extend_from_slice(payload);
        buf
    }

    #[test]
    fn test_server_builder_new() {
        let builder = ServerBuilder::<TestHandler>::new();
        assert_eq!(builder.bind_addr, parse_addr("0.0.0.0:9000"));
        assert!(builder.handler.is_none());
        assert_eq!(builder.max_connections, 1000);
        assert_eq!(builder.max_frame_size, 64 * 1024);
        assert_eq!(builder.channel_capacity, 4096);
    }

    #[test]
    fn test_server_builder_default() {
        let defaulted = ServerBuilder::<TestHandler>::default();
        let fresh = ServerBuilder::<TestHandler>::new();
        assert_eq!(defaulted.bind_addr, fresh.bind_addr);
        assert!(defaulted.handler.is_none());
        assert_eq!(defaulted.max_connections, fresh.max_connections);
        assert_eq!(defaulted.max_frame_size, fresh.max_frame_size);
        assert_eq!(defaulted.channel_capacity, fresh.channel_capacity);
    }

    #[test]
    fn test_server_builder_bind() {
        let target = parse_addr("127.0.0.1:8080");
        let builder = ServerBuilder::<TestHandler>::new().bind(target);
        assert_eq!(builder.bind_addr, target);
    }

    #[test]
    fn test_server_builder_handler() {
        let builder = ServerBuilder::new().handler(TestHandler);
        assert!(builder.handler.is_some());
    }

    #[test]
    fn test_server_builder_max_connections() {
        let builder = ServerBuilder::<TestHandler>::new().max_connections(500);
        assert_eq!(builder.max_connections, 500);
    }

    #[test]
    fn test_server_builder_max_frame_size() {
        let builder = ServerBuilder::<TestHandler>::new().max_frame_size(128 * 1024);
        assert_eq!(builder.max_frame_size, 128 * 1024);
    }

    #[test]
    fn test_server_builder_channel_capacity() {
        let builder = ServerBuilder::<TestHandler>::new().channel_capacity(8192);
        assert_eq!(builder.channel_capacity, 8192);
    }

    #[test]
    fn test_server_builder_build() {
        let (server, _handle) = ServerBuilder::new()
            .handler(TestHandler)
            .max_connections(7)
            .max_frame_size(256)
            .build();
        assert_eq!(server.max_connections, 7);
        assert_eq!(server.max_frame_size, 256);
    }

    #[test]
    #[should_panic(expected = "Handler required")]
    fn test_server_builder_build_without_handler_panics() {
        let _ = ServerBuilder::<TestHandler>::new().build();
    }

    #[test]
    fn test_server_command_debug() {
        assert!(format!("{:?}", ServerCommand::Shutdown).contains("Shutdown"));
        assert!(format!("{:?}", ServerCommand::CloseSession(42)).contains("CloseSession"));
        assert!(format!("{:?}", ServerCommand::Broadcast(vec![1, 2, 3])).contains("Broadcast"));
    }

    #[test]
    fn test_server_event_clone_debug() {
        let addr = parse_addr("127.0.0.1:9000");
        let event = ServerEvent::SessionCreated(1, addr);
        let cloned = event.clone();
        assert_eq!(format!("{:?}", cloned), format!("{:?}", event));
        assert!(format!("{:?}", event).contains("SessionCreated"));

        assert!(format!("{:?}", ServerEvent::SessionClosed(1)).contains("SessionClosed"));

        let error = ServerEvent::Error("test error".to_string());
        assert!(format!("{:?}", error).contains("Error"));
    }

    #[test]
    fn test_server_handle_shutdown() {
        let (server, handle) = ServerBuilder::new().handler(TestHandler).build();
        handle.shutdown();
        assert!(matches!(server.cmd_rx.try_recv(), Some(ServerCommand::Shutdown)));
    }

    #[test]
    fn test_server_handle_close_session() {
        let (server, handle) = ServerBuilder::new().handler(TestHandler).build();
        handle.close_session(1);
        assert!(matches!(server.cmd_rx.try_recv(), Some(ServerCommand::CloseSession(1))));
    }

    #[test]
    fn test_server_handle_broadcast() {
        let (server, handle) = ServerBuilder::new().handler(TestHandler).build();
        handle.broadcast(vec![1, 2, 3]);
        match server.cmd_rx.try_recv() {
            Some(ServerCommand::Broadcast(payload)) => assert_eq!(payload, vec![1, 2, 3]),
            other => panic!("expected Broadcast, got {:?}", other),
        }
    }

    #[test]
    fn test_server_handle_poll_events_initially_empty() {
        let (_server, handle) = ServerBuilder::new().handler(TestHandler).build();
        assert_eq!(handle.poll_events().count(), 0);
    }

    #[test]
    fn test_sbe_frame_codec_new() {
        let codec = SbeFrameCodec::new(64 * 1024);
        assert_eq!(codec.max_frame_size, 64 * 1024);
    }

    #[test]
    fn test_sbe_frame_codec_decode_incomplete() {
        let mut codec = SbeFrameCodec::new(1024);
        let mut buf = BytesMut::from(&[0u8, 0, 0][..]);
        assert!(codec.decode(&mut buf).unwrap().is_none());
    }

    #[test]
    fn test_sbe_frame_codec_decode_partial_payload() {
        let mut codec = SbeFrameCodec::new(1024);
        let mut buf = frame(b"hello");
        // Prefix plus only 2 of the 5 payload bytes.
        buf.truncate(6);
        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert_eq!(buf.len(), 6, "a partial frame must stay buffered");
    }

    #[test]
    fn test_sbe_frame_codec_decode_complete() {
        let mut codec = SbeFrameCodec::new(1024);
        let mut buf = frame(b"hello");
        let result = codec.decode(&mut buf).unwrap();
        assert_eq!(result.expect("frame is complete").as_ref(), b"hello");
        assert!(buf.is_empty());
    }

    #[test]
    fn test_sbe_frame_codec_decode_back_to_back_frames() {
        let mut codec = SbeFrameCodec::new(1024);
        let mut buf = frame(b"first");
        buf.extend_from_slice(&frame(b"second"));
        assert_eq!(codec.decode(&mut buf).unwrap().unwrap().as_ref(), b"first");
        assert_eq!(codec.decode(&mut buf).unwrap().unwrap().as_ref(), b"second");
        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert!(buf.is_empty());
    }

    #[test]
    fn test_sbe_frame_codec_decode_empty_frame() {
        let mut codec = SbeFrameCodec::new(1024);
        let mut buf = frame(b"");
        let decoded = codec.decode(&mut buf).unwrap().expect("empty frame is complete");
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_sbe_frame_codec_decode_at_limit() {
        let mut codec = SbeFrameCodec::new(4);
        let mut buf = frame(b"abcd");
        assert_eq!(codec.decode(&mut buf).unwrap().unwrap().as_ref(), b"abcd");
    }

    #[test]
    fn test_sbe_frame_codec_decode_too_large() {
        let mut codec = SbeFrameCodec::new(10);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&100u32.to_le_bytes());

        let err = codec.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_sbe_frame_codec_encode() {
        let mut codec = SbeFrameCodec::new(1024);
        let mut buf = BytesMut::new();
        codec.encode(b"hello", &mut buf).unwrap();

        assert_eq!(buf.len(), 9);
        assert_eq!(&buf[0..4], &5u32.to_le_bytes());
        assert_eq!(&buf[4..9], b"hello");
    }

    #[test]
    fn test_sbe_frame_codec_round_trip() {
        let mut codec = SbeFrameCodec::new(1024);
        let mut buf = BytesMut::new();
        codec.encode(vec![9u8, 8, 7], &mut buf).unwrap();
        let decoded = codec.decode(&mut buf).unwrap().expect("frame is complete");
        assert_eq!(decoded.as_ref(), &[9u8, 8, 7]);
        assert!(buf.is_empty());
    }
}