1use crate::error::ServerError;
4use crate::handler::{MessageHandler, Responder, SendError};
5use crate::session::SessionManager;
6use bytes::BytesMut;
7use futures::SinkExt;
8use ironsbe_channel::mpsc::{MpscChannel, MpscReceiver, MpscSender};
9use ironsbe_core::header::MessageHeader;
10use std::net::SocketAddr;
11use std::sync::Arc;
12use tokio::net::TcpStream;
13use tokio::sync::mpsc as tokio_mpsc;
14use tokio_stream::StreamExt;
15use tokio_util::codec::{Decoder, Encoder, Framed};
16
/// Fluent builder that collects the settings for a [`Server`] before it is built.
///
/// `H` is the message handler type. No bound is placed on it here; the bounds
/// live on the `impl` blocks that need them.
pub struct ServerBuilder<H> {
    /// Socket address the server's listener binds to.
    bind_addr: SocketAddr,
    /// Application handler; stays `None` until one is supplied via `handler`.
    handler: Option<H>,
    /// Connection limit copied into the built server.
    max_connections: usize,
    /// Frame-size limit copied into the built server.
    max_frame_size: usize,
    /// Capacity used for both the command channel and the event channel.
    channel_capacity: usize,
}
25
26impl<H: MessageHandler> ServerBuilder<H> {
27 #[must_use]
29 pub fn new() -> Self {
30 Self {
31 bind_addr: "0.0.0.0:9000".parse().unwrap(),
32 handler: None,
33 max_connections: 1000,
34 max_frame_size: 64 * 1024,
35 channel_capacity: 4096,
36 }
37 }
38
39 #[must_use]
41 pub fn bind(mut self, addr: SocketAddr) -> Self {
42 self.bind_addr = addr;
43 self
44 }
45
46 #[must_use]
48 pub fn handler(mut self, handler: H) -> Self {
49 self.handler = Some(handler);
50 self
51 }
52
53 #[must_use]
55 pub fn max_connections(mut self, max: usize) -> Self {
56 self.max_connections = max;
57 self
58 }
59
60 #[must_use]
62 pub fn max_frame_size(mut self, size: usize) -> Self {
63 self.max_frame_size = size;
64 self
65 }
66
67 #[must_use]
69 pub fn channel_capacity(mut self, capacity: usize) -> Self {
70 self.channel_capacity = capacity;
71 self
72 }
73
74 #[must_use]
79 pub fn build(self) -> (Server<H>, ServerHandle) {
80 let handler = self.handler.expect("Handler required");
81 let (cmd_tx, cmd_rx) = MpscChannel::bounded(self.channel_capacity);
82 let (event_tx, event_rx) = MpscChannel::bounded(self.channel_capacity);
83
84 let server = Server {
85 bind_addr: self.bind_addr,
86 handler: Arc::new(handler),
87 max_connections: self.max_connections,
88 max_frame_size: self.max_frame_size,
89 cmd_rx,
90 event_tx,
91 sessions: SessionManager::new(),
92 };
93
94 let handle = ServerHandle { cmd_tx, event_rx };
95
96 (server, handle)
97 }
98}
99
100impl<H: MessageHandler> Default for ServerBuilder<H> {
101 fn default() -> Self {
102 Self::new()
103 }
104}
105
106#[allow(dead_code)]
108pub struct Server<H> {
109 bind_addr: SocketAddr,
110 handler: Arc<H>,
111 max_connections: usize,
112 max_frame_size: usize,
113 cmd_rx: MpscReceiver<ServerCommand>,
114 event_tx: MpscSender<ServerEvent>,
115 sessions: SessionManager,
116}
117
118impl<H: MessageHandler + Send + Sync + 'static> Server<H> {
119 pub async fn run(&mut self) -> Result<(), ServerError> {
124 let listener = tokio::net::TcpListener::bind(self.bind_addr).await?;
125 tracing::info!("Server listening on {}", self.bind_addr);
126
127 loop {
128 tokio::select! {
129 result = listener.accept() => {
130 match result {
131 Ok((stream, addr)) => {
132 self.handle_connection(stream, addr).await;
133 }
134 Err(e) => {
135 tracing::error!("Accept error: {}", e);
136 }
137 }
138 }
139
140 cmd = async { self.cmd_rx.try_recv() } => {
141 if let Some(cmd) = cmd && self.handle_command(cmd).await {
142 return Ok(());
143 }
144 }
145 }
146 }
147 }
148
149 async fn handle_connection(&mut self, stream: TcpStream, addr: SocketAddr) {
150 if self.sessions.count() >= self.max_connections {
151 tracing::warn!("Max connections reached, rejecting {}", addr);
152 return;
153 }
154
155 let session_id = self.sessions.create_session(addr);
156 let handler = Arc::clone(&self.handler);
157 let event_tx = self.event_tx.clone();
158 let max_frame_size = self.max_frame_size;
159
160 handler.on_session_start(session_id);
161 let _ = event_tx.try_send(ServerEvent::SessionCreated(session_id, addr));
162
163 tokio::spawn(async move {
165 tracing::info!("Session {} connected from {}", session_id, addr);
166
167 if let Err(e) =
168 handle_session(session_id, stream, handler.as_ref(), max_frame_size).await
169 {
170 tracing::error!("Session {} error: {:?}", session_id, e);
171 }
172
173 handler.on_session_end(session_id);
175 let _ = event_tx.try_send(ServerEvent::SessionClosed(session_id));
176 });
177 }
178
179 async fn handle_command(&mut self, cmd: ServerCommand) -> bool {
180 match cmd {
181 ServerCommand::Shutdown => {
182 tracing::info!("Server shutdown requested");
183 true
184 }
185 ServerCommand::CloseSession(session_id) => {
186 self.sessions.close_session(session_id);
187 false
188 }
189 ServerCommand::Broadcast(_message) => {
190 false
192 }
193 }
194 }
195}
196
197pub struct ServerHandle {
199 cmd_tx: MpscSender<ServerCommand>,
200 event_rx: MpscReceiver<ServerEvent>,
201}
202
203impl ServerHandle {
204 pub fn shutdown(&self) {
206 let _ = self.cmd_tx.try_send(ServerCommand::Shutdown);
207 }
208
209 pub fn close_session(&self, session_id: u64) {
211 let _ = self
212 .cmd_tx
213 .try_send(ServerCommand::CloseSession(session_id));
214 }
215
216 pub fn broadcast(&self, message: Vec<u8>) {
218 let _ = self.cmd_tx.try_send(ServerCommand::Broadcast(message));
219 }
220
221 pub fn poll_events(&self) -> impl Iterator<Item = ServerEvent> + '_ {
223 std::iter::from_fn(|| self.event_rx.try_recv())
224 }
225}
226
/// Commands a `ServerHandle` can send to a running server.
#[derive(Debug)]
pub enum ServerCommand {
    /// Stop the server's run loop.
    Shutdown,
    /// Close the session with the given id.
    CloseSession(u64),
    /// Broadcast the given bytes.
    Broadcast(Vec<u8>),
}
237
/// Events a server reports back through its `ServerHandle`.
#[derive(Debug, Clone)]
pub enum ServerEvent {
    /// A session was created: its id and the peer address.
    SessionCreated(u64, SocketAddr),
    /// The session with the given id has ended.
    SessionClosed(u64),
    /// An error described by a message.
    Error(String),
}
248
/// Codec state for length-prefixed frames: each frame is a 4-byte
/// little-endian length followed by that many payload bytes.
struct SbeFrameCodec {
    /// Largest payload length, in bytes, accepted when decoding.
    max_frame_size: usize,
}

impl SbeFrameCodec {
    /// Creates a codec that rejects incoming frames longer than `max_frame_size`.
    fn new(max_frame_size: usize) -> Self {
        SbeFrameCodec { max_frame_size }
    }
}
259
260impl Decoder for SbeFrameCodec {
261 type Item = BytesMut;
262 type Error = std::io::Error;
263
264 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
265 if src.len() < 4 {
266 return Ok(None);
267 }
268
269 let length = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
270
271 if length > self.max_frame_size {
272 return Err(std::io::Error::new(
273 std::io::ErrorKind::InvalidData,
274 format!("Frame too large: {} > {}", length, self.max_frame_size),
275 ));
276 }
277
278 if src.len() < 4 + length {
279 src.reserve(4 + length - src.len());
280 return Ok(None);
281 }
282
283 let _ = src.split_to(4);
284 Ok(Some(src.split_to(length)))
285 }
286}
287
288impl<T: AsRef<[u8]>> Encoder<T> for SbeFrameCodec {
289 type Error = std::io::Error;
290
291 fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
292 let data = item.as_ref();
293 let length = data.len() as u32;
294 dst.reserve(4 + data.len());
295 dst.extend_from_slice(&length.to_le_bytes());
296 dst.extend_from_slice(data);
297 Ok(())
298 }
299}
300
301struct SessionResponder {
303 tx: tokio_mpsc::UnboundedSender<Vec<u8>>,
304}
305
306impl Responder for SessionResponder {
307 fn send(&self, message: &[u8]) -> Result<(), SendError> {
308 self.tx.send(message.to_vec()).map_err(|_| SendError {
309 message: "channel closed".to_string(),
310 })
311 }
312
313 fn send_to(&self, _session_id: u64, message: &[u8]) -> Result<(), SendError> {
314 self.send(message)
316 }
317}
318
319async fn handle_session<H: MessageHandler>(
321 session_id: u64,
322 stream: TcpStream,
323 handler: &H,
324 max_frame_size: usize,
325) -> Result<(), std::io::Error> {
326 let codec = SbeFrameCodec::new(max_frame_size);
327 let mut framed = Framed::new(stream, codec);
328
329 let (tx, mut rx) = tokio_mpsc::unbounded_channel::<Vec<u8>>();
331 let responder = SessionResponder { tx };
332
333 loop {
334 tokio::select! {
335 result = framed.next() => {
337 match result {
338 Some(Ok(data)) => {
339 if data.len() >= MessageHeader::ENCODED_LENGTH {
341 let header = MessageHeader::wrap(data.as_ref(), 0);
342 handler.on_message(session_id, &header, data.as_ref(), &responder);
343 } else {
344 handler.on_error(session_id, "Message too short for header");
345 }
346 }
347 Some(Err(e)) => {
348 tracing::error!("Session {} read error: {}", session_id, e);
349 return Err(e);
350 }
351 None => {
352 tracing::info!("Session {} disconnected", session_id);
353 return Ok(());
354 }
355 }
356 }
357
358 Some(msg) = rx.recv() => {
360 if let Err(e) = framed.send(msg).await {
361 tracing::error!("Session {} write error: {}", session_id, e);
362 return Err(e);
363 }
364 }
365 }
366 }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler that ignores every message; it only satisfies the
    /// `MessageHandler` bound required by the builder.
    struct TestHandler;

    impl MessageHandler for TestHandler {
        fn on_message(
            &self,
            _session_id: u64,
            _header: &MessageHeader,
            _data: &[u8],
            _responder: &dyn Responder,
        ) {
        }
    }

    #[test]
    fn test_server_builder_new() {
        let builder = ServerBuilder::<TestHandler>::new();
        let expected_addr: SocketAddr = "0.0.0.0:9000".parse().unwrap();

        assert_eq!(builder.bind_addr, expected_addr);
        assert!(builder.handler.is_none());
        assert_eq!(builder.max_connections, 1000);
        assert_eq!(builder.max_frame_size, 64 * 1024);
        assert_eq!(builder.channel_capacity, 4096);
    }

    #[test]
    fn test_server_builder_default() {
        let from_default = ServerBuilder::<TestHandler>::default();
        let from_new = ServerBuilder::<TestHandler>::new();

        assert_eq!(from_default.bind_addr, from_new.bind_addr);
        assert_eq!(from_default.handler.is_none(), from_new.handler.is_none());
        assert_eq!(from_default.max_connections, from_new.max_connections);
        assert_eq!(from_default.max_frame_size, from_new.max_frame_size);
        assert_eq!(from_default.channel_capacity, from_new.channel_capacity);
    }

    #[test]
    fn test_server_builder_bind() {
        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let builder = ServerBuilder::<TestHandler>::new().bind(addr);
        assert_eq!(builder.bind_addr, addr);
    }

    #[test]
    fn test_server_builder_handler() {
        let builder = ServerBuilder::new().handler(TestHandler);
        assert!(builder.handler.is_some());
    }

    #[test]
    fn test_server_builder_max_connections() {
        let builder = ServerBuilder::<TestHandler>::new().max_connections(500);
        assert_eq!(builder.max_connections, 500);
    }

    #[test]
    fn test_server_builder_max_frame_size() {
        let builder = ServerBuilder::<TestHandler>::new().max_frame_size(128 * 1024);
        assert_eq!(builder.max_frame_size, 128 * 1024);
    }

    #[test]
    fn test_server_builder_channel_capacity() {
        let builder = ServerBuilder::<TestHandler>::new().channel_capacity(8192);
        assert_eq!(builder.channel_capacity, 8192);
    }

    #[test]
    fn test_server_builder_build() {
        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let (server, _handle) = ServerBuilder::new()
            .handler(TestHandler)
            .bind(addr)
            .max_connections(5)
            .max_frame_size(2048)
            .build();

        assert_eq!(server.bind_addr, addr);
        assert_eq!(server.max_connections, 5);
        assert_eq!(server.max_frame_size, 2048);
    }

    #[test]
    #[should_panic(expected = "Handler required")]
    fn test_server_builder_build_without_handler_panics() {
        let _ = ServerBuilder::<TestHandler>::new().build();
    }

    #[test]
    fn test_server_command_debug() {
        let cmd = ServerCommand::Shutdown;
        let debug_str = format!("{:?}", cmd);
        assert!(debug_str.contains("Shutdown"));

        let cmd2 = ServerCommand::CloseSession(42);
        let debug_str2 = format!("{:?}", cmd2);
        assert!(debug_str2.contains("CloseSession"));
        assert!(debug_str2.contains("42"));

        let cmd3 = ServerCommand::Broadcast(vec![1, 2, 3]);
        let debug_str3 = format!("{:?}", cmd3);
        assert!(debug_str3.contains("Broadcast"));
    }

    #[test]
    fn test_server_event_clone_debug() {
        let addr: SocketAddr = "127.0.0.1:9000".parse().unwrap();
        let event = ServerEvent::SessionCreated(1, addr);
        match event.clone() {
            ServerEvent::SessionCreated(id, peer) => {
                assert_eq!(id, 1);
                assert_eq!(peer, addr);
            }
            other => panic!("clone changed the variant: {:?}", other),
        }

        let debug_str = format!("{:?}", event);
        assert!(debug_str.contains("SessionCreated"));

        let event2 = ServerEvent::SessionClosed(1);
        let debug_str2 = format!("{:?}", event2);
        assert!(debug_str2.contains("SessionClosed"));

        let event3 = ServerEvent::Error("test error".to_string());
        let debug_str3 = format!("{:?}", event3);
        assert!(debug_str3.contains("Error"));
        assert!(debug_str3.contains("test error"));
    }

    #[test]
    fn test_server_handle_shutdown() {
        let (server, handle) = ServerBuilder::new().handler(TestHandler).build();
        handle.shutdown();

        assert!(matches!(
            server.cmd_rx.try_recv(),
            Some(ServerCommand::Shutdown)
        ));
        assert!(server.cmd_rx.try_recv().is_none());
    }

    #[test]
    fn test_server_handle_close_session() {
        let (server, handle) = ServerBuilder::new().handler(TestHandler).build();
        handle.close_session(1);

        assert!(matches!(
            server.cmd_rx.try_recv(),
            Some(ServerCommand::CloseSession(1))
        ));
    }

    #[test]
    fn test_server_handle_broadcast() {
        let (server, handle) = ServerBuilder::new().handler(TestHandler).build();
        handle.broadcast(vec![1, 2, 3]);

        match server.cmd_rx.try_recv() {
            Some(ServerCommand::Broadcast(bytes)) => assert_eq!(bytes, vec![1, 2, 3]),
            other => panic!("expected Broadcast, got {:?}", other),
        }
    }

    #[test]
    fn test_server_handle_poll_events() {
        let (server, handle) = ServerBuilder::new().handler(TestHandler).build();
        assert_eq!(handle.poll_events().count(), 0);

        let _ = server.event_tx.try_send(ServerEvent::SessionClosed(7));
        let events: Vec<ServerEvent> = handle.poll_events().collect();
        assert_eq!(events.len(), 1);
        assert!(matches!(events[0], ServerEvent::SessionClosed(7)));
    }

    #[test]
    fn test_sbe_frame_codec_new() {
        let codec = SbeFrameCodec::new(64 * 1024);
        assert_eq!(codec.max_frame_size, 64 * 1024);
    }

    #[test]
    fn test_sbe_frame_codec_decode_incomplete() {
        let mut codec = SbeFrameCodec::new(1024);
        let mut buf = BytesMut::from(&[0u8, 0, 0][..]);
        assert!(codec.decode(&mut buf).unwrap().is_none());
        // Nothing may be consumed while the prefix is incomplete.
        assert_eq!(buf.len(), 3);
    }

    #[test]
    fn test_sbe_frame_codec_decode_partial_payload() {
        let mut codec = SbeFrameCodec::new(1024);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&5u32.to_le_bytes());
        buf.extend_from_slice(b"hel");

        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert_eq!(buf.len(), 7);

        buf.extend_from_slice(b"lo");
        let frame = codec.decode(&mut buf).unwrap().expect("frame is complete");
        assert_eq!(frame.as_ref(), b"hello");
        assert!(buf.is_empty());
    }

    #[test]
    fn test_sbe_frame_codec_decode_complete() {
        let mut codec = SbeFrameCodec::new(1024);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&5u32.to_le_bytes());
        buf.extend_from_slice(b"hello");

        let result = codec.decode(&mut buf).unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().as_ref(), b"hello");
        assert!(buf.is_empty());
    }

    #[test]
    fn test_sbe_frame_codec_decode_back_to_back_frames() {
        let mut codec = SbeFrameCodec::new(1024);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&2u32.to_le_bytes());
        buf.extend_from_slice(b"ab");
        buf.extend_from_slice(&3u32.to_le_bytes());
        buf.extend_from_slice(b"cde");

        let first = codec.decode(&mut buf).unwrap().expect("first frame");
        assert_eq!(first.as_ref(), b"ab");
        let second = codec.decode(&mut buf).unwrap().expect("second frame");
        assert_eq!(second.as_ref(), b"cde");
        assert!(codec.decode(&mut buf).unwrap().is_none());
    }

    #[test]
    fn test_sbe_frame_codec_decode_at_limit() {
        let mut codec = SbeFrameCodec::new(5);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&5u32.to_le_bytes());
        buf.extend_from_slice(b"hello");

        let frame = codec.decode(&mut buf).unwrap().expect("frame at the limit");
        assert_eq!(frame.as_ref(), b"hello");
    }

    #[test]
    fn test_sbe_frame_codec_decode_too_large() {
        let mut codec = SbeFrameCodec::new(10);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&100u32.to_le_bytes());

        let err = codec.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_sbe_frame_codec_encode() {
        let mut codec = SbeFrameCodec::new(1024);
        let mut buf = BytesMut::new();
        codec.encode(b"hello", &mut buf).unwrap();

        assert_eq!(buf.len(), 9);
        assert_eq!(&buf[0..4], &5u32.to_le_bytes());
        assert_eq!(&buf[4..9], b"hello");
    }

    #[test]
    fn test_sbe_frame_codec_encode_empty() {
        let mut codec = SbeFrameCodec::new(1024);
        let mut buf = BytesMut::new();
        codec.encode(b"", &mut buf).unwrap();

        assert_eq!(buf.as_ref(), &0u32.to_le_bytes());
    }

    #[test]
    fn test_sbe_frame_codec_roundtrip() {
        let mut codec = SbeFrameCodec::new(1024);
        let mut buf = BytesMut::new();
        codec.encode(b"roundtrip", &mut buf).unwrap();

        let frame = codec.decode(&mut buf).unwrap().expect("encoded frame decodes");
        assert_eq!(frame.as_ref(), b"roundtrip");
        assert!(buf.is_empty());
    }
}