Skip to main content

ironsbe_server/
builder.rs

1//! Server builder and main server implementation.
2
3use crate::error::ServerError;
4use crate::handler::{MessageHandler, Responder, SendError};
5use crate::session::SessionManager;
6use bytes::BytesMut;
7use futures::SinkExt;
8use ironsbe_channel::mpsc::{MpscChannel, MpscReceiver, MpscSender};
9use ironsbe_core::header::MessageHeader;
10use std::net::SocketAddr;
11use std::sync::Arc;
12use tokio::net::TcpStream;
13use tokio::sync::{Notify, mpsc as tokio_mpsc};
14use tokio_stream::StreamExt;
15use tokio_util::codec::{Decoder, Encoder, Framed};
16
17/// Builder for configuring and creating a server.
18pub struct ServerBuilder<H> {
19    bind_addr: SocketAddr,
20    handler: Option<H>,
21    max_connections: usize,
22    max_frame_size: usize,
23    channel_capacity: usize,
24}
25
26impl<H: MessageHandler> ServerBuilder<H> {
27    /// Creates a new server builder with default settings.
28    #[must_use]
29    pub fn new() -> Self {
30        Self {
31            bind_addr: "0.0.0.0:9000".parse().unwrap(),
32            handler: None,
33            max_connections: 1000,
34            max_frame_size: 64 * 1024,
35            channel_capacity: 4096,
36        }
37    }
38
39    /// Sets the bind address.
40    #[must_use]
41    pub fn bind(mut self, addr: SocketAddr) -> Self {
42        self.bind_addr = addr;
43        self
44    }
45
46    /// Sets the message handler.
47    #[must_use]
48    pub fn handler(mut self, handler: H) -> Self {
49        self.handler = Some(handler);
50        self
51    }
52
53    /// Sets the maximum number of connections.
54    #[must_use]
55    pub fn max_connections(mut self, max: usize) -> Self {
56        self.max_connections = max;
57        self
58    }
59
60    /// Sets the maximum frame size.
61    #[must_use]
62    pub fn max_frame_size(mut self, size: usize) -> Self {
63        self.max_frame_size = size;
64        self
65    }
66
67    /// Sets the channel capacity.
68    #[must_use]
69    pub fn channel_capacity(mut self, capacity: usize) -> Self {
70        self.channel_capacity = capacity;
71        self
72    }
73
74    /// Builds the server and handle.
75    ///
76    /// # Panics
77    /// Panics if no handler was set.
78    #[must_use]
79    pub fn build(self) -> (Server<H>, ServerHandle) {
80        let handler = self.handler.expect("Handler required");
81        let (cmd_tx, cmd_rx) = MpscChannel::bounded(self.channel_capacity);
82        let (event_tx, event_rx) = MpscChannel::bounded(self.channel_capacity);
83
84        let cmd_notify = Arc::new(Notify::new());
85
86        let server = Server {
87            bind_addr: self.bind_addr,
88            handler: Arc::new(handler),
89            max_connections: self.max_connections,
90            max_frame_size: self.max_frame_size,
91            cmd_rx,
92            event_tx,
93            sessions: SessionManager::new(),
94            cmd_notify: Arc::clone(&cmd_notify),
95        };
96
97        let handle = ServerHandle {
98            cmd_tx,
99            event_rx,
100            cmd_notify,
101        };
102
103        (server, handle)
104    }
105}
106
107impl<H: MessageHandler> Default for ServerBuilder<H> {
108    fn default() -> Self {
109        Self::new()
110    }
111}
112
113/// The main server instance.
114#[allow(dead_code)]
115pub struct Server<H> {
116    bind_addr: SocketAddr,
117    handler: Arc<H>,
118    max_connections: usize,
119    max_frame_size: usize,
120    cmd_rx: MpscReceiver<ServerCommand>,
121    event_tx: MpscSender<ServerEvent>,
122    sessions: SessionManager,
123    cmd_notify: Arc<Notify>,
124}
125
126impl<H: MessageHandler + Send + Sync + 'static> Server<H> {
127    /// Runs the server, accepting connections and processing messages.
128    ///
129    /// # Errors
130    /// Returns `ServerError` if the server fails to start or encounters an error.
131    pub async fn run(&mut self) -> Result<(), ServerError> {
132        let listener = tokio::net::TcpListener::bind(self.bind_addr).await?;
133        tracing::info!("Server listening on {}", self.bind_addr);
134
135        loop {
136            tokio::select! {
137                result = listener.accept() => {
138                    match result {
139                        Ok((stream, addr)) => {
140                            self.handle_connection(stream, addr).await;
141                        }
142                        Err(e) => {
143                            tracing::error!("Accept error: {}", e);
144                        }
145                    }
146                }
147
148                _ = self.cmd_notify.notified() => {
149                    while let Some(cmd) = self.cmd_rx.try_recv() {
150                        if self.handle_command(cmd).await {
151                            return Ok(());
152                        }
153                    }
154                }
155            }
156        }
157    }
158
159    async fn handle_connection(&mut self, stream: TcpStream, addr: SocketAddr) {
160        if self.sessions.count() >= self.max_connections {
161            tracing::warn!("Max connections reached, rejecting {}", addr);
162            return;
163        }
164
165        let session_id = self.sessions.create_session(addr);
166        let handler = Arc::clone(&self.handler);
167        let event_tx = self.event_tx.clone();
168        let max_frame_size = self.max_frame_size;
169
170        handler.on_session_start(session_id);
171        let _ = event_tx.try_send(ServerEvent::SessionCreated(session_id, addr));
172
173        // Spawn connection handler task
174        tokio::spawn(async move {
175            tracing::info!("Session {} connected from {}", session_id, addr);
176
177            if let Err(e) =
178                handle_session(session_id, stream, handler.as_ref(), max_frame_size).await
179            {
180                tracing::error!("Session {} error: {:?}", session_id, e);
181            }
182
183            // When done, notify
184            handler.on_session_end(session_id);
185            let _ = event_tx.try_send(ServerEvent::SessionClosed(session_id));
186        });
187    }
188
189    async fn handle_command(&mut self, cmd: ServerCommand) -> bool {
190        match cmd {
191            ServerCommand::Shutdown => {
192                tracing::info!("Server shutdown requested");
193                true
194            }
195            ServerCommand::CloseSession(session_id) => {
196                self.sessions.close_session(session_id);
197                false
198            }
199            ServerCommand::Broadcast(_message) => {
200                // Broadcast to all sessions
201                false
202            }
203        }
204    }
205}
206
207/// Handle for controlling the server from outside.
208pub struct ServerHandle {
209    cmd_tx: MpscSender<ServerCommand>,
210    event_rx: MpscReceiver<ServerEvent>,
211    cmd_notify: Arc<Notify>,
212}
213
214impl ServerHandle {
215    /// Requests server shutdown.
216    pub fn shutdown(&self) {
217        let _ = self.cmd_tx.try_send(ServerCommand::Shutdown);
218        self.cmd_notify.notify_one();
219    }
220
221    /// Closes a specific session.
222    pub fn close_session(&self, session_id: u64) {
223        let _ = self
224            .cmd_tx
225            .try_send(ServerCommand::CloseSession(session_id));
226        self.cmd_notify.notify_one();
227    }
228
229    /// Broadcasts a message to all sessions.
230    pub fn broadcast(&self, message: Vec<u8>) {
231        let _ = self.cmd_tx.try_send(ServerCommand::Broadcast(message));
232        self.cmd_notify.notify_one();
233    }
234
235    /// Polls for server events.
236    pub fn poll_events(&self) -> impl Iterator<Item = ServerEvent> + '_ {
237        std::iter::from_fn(|| self.event_rx.try_recv())
238    }
239}
240
/// Commands that can be sent to the server.
///
/// Derives `Clone`, `PartialEq` and `Eq` in addition to `Debug` so commands
/// can be duplicated and compared, e.g. in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServerCommand {
    /// Shutdown the server.
    Shutdown,
    /// Close a specific session; the payload is the session id.
    CloseSession(u64),
    /// Broadcast a message to all sessions; the payload is the raw message bytes.
    Broadcast(Vec<u8>),
}
251
/// Events emitted by the server.
///
/// Derives `PartialEq` and `Eq` in addition to `Debug` and `Clone` so polled
/// events can be compared directly, e.g. in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServerEvent {
    /// A new session was created; carries the session id and the peer address.
    SessionCreated(u64, SocketAddr),
    /// A session was closed; carries the session id.
    SessionClosed(u64),
    /// An error occurred; carries a human-readable description.
    Error(String),
}
262
/// Codec that frames SBE messages with a length prefix.
struct SbeFrameCodec {
    /// Largest payload length, in bytes, the decoder will accept.
    max_frame_size: usize,
}

impl SbeFrameCodec {
    /// Builds a codec that rejects inbound frames longer than `max_frame_size` bytes.
    fn new(max_frame_size: usize) -> Self {
        SbeFrameCodec { max_frame_size }
    }
}
273
274impl Decoder for SbeFrameCodec {
275    type Item = BytesMut;
276    type Error = std::io::Error;
277
278    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
279        if src.len() < 4 {
280            return Ok(None);
281        }
282
283        let length = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
284
285        if length > self.max_frame_size {
286            return Err(std::io::Error::new(
287                std::io::ErrorKind::InvalidData,
288                format!("Frame too large: {} > {}", length, self.max_frame_size),
289            ));
290        }
291
292        if src.len() < 4 + length {
293            src.reserve(4 + length - src.len());
294            return Ok(None);
295        }
296
297        let _ = src.split_to(4);
298        Ok(Some(src.split_to(length)))
299    }
300}
301
302impl<T: AsRef<[u8]>> Encoder<T> for SbeFrameCodec {
303    type Error = std::io::Error;
304
305    fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
306        let data = item.as_ref();
307        let length = data.len() as u32;
308        dst.reserve(4 + data.len());
309        dst.extend_from_slice(&length.to_le_bytes());
310        dst.extend_from_slice(data);
311        Ok(())
312    }
313}
314
315/// Session responder that sends messages back to the client.
316struct SessionResponder {
317    tx: tokio_mpsc::UnboundedSender<Vec<u8>>,
318}
319
320impl Responder for SessionResponder {
321    fn send(&self, message: &[u8]) -> Result<(), SendError> {
322        self.tx.send(message.to_vec()).map_err(|_| SendError {
323            message: "channel closed".to_string(),
324        })
325    }
326
327    fn send_to(&self, _session_id: u64, message: &[u8]) -> Result<(), SendError> {
328        // For now, just send to current session
329        self.send(message)
330    }
331}
332
333/// Handles a single client session.
334async fn handle_session<H: MessageHandler>(
335    session_id: u64,
336    stream: TcpStream,
337    handler: &H,
338    max_frame_size: usize,
339) -> Result<(), std::io::Error> {
340    let codec = SbeFrameCodec::new(max_frame_size);
341    let mut framed = Framed::new(stream, codec);
342
343    // Channel for sending responses
344    let (tx, mut rx) = tokio_mpsc::unbounded_channel::<Vec<u8>>();
345    let responder = SessionResponder { tx };
346
347    loop {
348        tokio::select! {
349            // Read incoming messages
350            result = framed.next() => {
351                match result {
352                    Some(Ok(data)) => {
353                        // Decode header and dispatch to handler
354                        if data.len() >= MessageHeader::ENCODED_LENGTH {
355                            let header = MessageHeader::wrap(data.as_ref(), 0);
356                            handler.on_message(session_id, &header, data.as_ref(), &responder);
357                        } else {
358                            handler.on_error(session_id, "Message too short for header");
359                        }
360                    }
361                    Some(Err(e)) => {
362                        tracing::error!("Session {} read error: {}", session_id, e);
363                        return Err(e);
364                    }
365                    None => {
366                        tracing::info!("Session {} disconnected", session_id);
367                        return Ok(());
368                    }
369                }
370            }
371
372            // Send outgoing messages
373            Some(msg) = rx.recv() => {
374                if let Err(e) = framed.send(msg).await {
375                    tracing::error!("Session {} write error: {}", session_id, e);
376                    return Err(e);
377                }
378            }
379        }
380    }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler that ignores every message; enough to instantiate the builder.
    struct TestHandler;

    impl MessageHandler for TestHandler {
        fn on_message(
            &self,
            _session_id: u64,
            _header: &MessageHeader,
            _data: &[u8],
            _responder: &dyn Responder,
        ) {
        }
    }

    #[test]
    fn test_server_builder_new() {
        let builder = ServerBuilder::<TestHandler>::new();
        let expected_addr: SocketAddr = "0.0.0.0:9000".parse().unwrap();

        assert_eq!(builder.bind_addr, expected_addr);
        assert!(builder.handler.is_none());
        assert_eq!(builder.max_connections, 1000);
        assert_eq!(builder.max_frame_size, 64 * 1024);
        assert_eq!(builder.channel_capacity, 4096);
    }

    #[test]
    fn test_server_builder_default() {
        let from_default = ServerBuilder::<TestHandler>::default();
        let from_new = ServerBuilder::<TestHandler>::new();

        assert_eq!(from_default.bind_addr, from_new.bind_addr);
        assert!(from_default.handler.is_none());
        assert_eq!(from_default.max_connections, from_new.max_connections);
        assert_eq!(from_default.max_frame_size, from_new.max_frame_size);
        assert_eq!(from_default.channel_capacity, from_new.channel_capacity);
    }

    #[test]
    fn test_server_builder_bind() {
        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let builder = ServerBuilder::<TestHandler>::new().bind(addr);
        assert_eq!(builder.bind_addr, addr);
    }

    #[test]
    fn test_server_builder_handler() {
        let builder = ServerBuilder::new().handler(TestHandler);
        assert!(builder.handler.is_some());
    }

    #[test]
    fn test_server_builder_max_connections() {
        let builder = ServerBuilder::<TestHandler>::new().max_connections(500);
        assert_eq!(builder.max_connections, 500);
    }

    #[test]
    fn test_server_builder_max_frame_size() {
        let builder = ServerBuilder::<TestHandler>::new().max_frame_size(128 * 1024);
        assert_eq!(builder.max_frame_size, 128 * 1024);
    }

    #[test]
    fn test_server_builder_channel_capacity() {
        let builder = ServerBuilder::<TestHandler>::new().channel_capacity(8192);
        assert_eq!(builder.channel_capacity, 8192);
    }

    #[test]
    fn test_server_builder_build() {
        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let (server, _handle) = ServerBuilder::new()
            .bind(addr)
            .max_connections(7)
            .max_frame_size(2048)
            .handler(TestHandler)
            .build();

        // The configured values must be carried over into the server.
        assert_eq!(server.bind_addr, addr);
        assert_eq!(server.max_connections, 7);
        assert_eq!(server.max_frame_size, 2048);
    }

    #[test]
    #[should_panic(expected = "Handler required")]
    fn test_server_builder_build_without_handler_panics() {
        let _ = ServerBuilder::<TestHandler>::new().build();
    }

    #[test]
    fn test_server_command_debug() {
        let cmd = ServerCommand::Shutdown;
        let debug_str = format!("{:?}", cmd);
        assert!(debug_str.contains("Shutdown"));

        let cmd2 = ServerCommand::CloseSession(42);
        let debug_str2 = format!("{:?}", cmd2);
        assert!(debug_str2.contains("CloseSession"));

        let cmd3 = ServerCommand::Broadcast(vec![1, 2, 3]);
        let debug_str3 = format!("{:?}", cmd3);
        assert!(debug_str3.contains("Broadcast"));
    }

    #[test]
    fn test_server_event_clone_debug() {
        let addr: SocketAddr = "127.0.0.1:9000".parse().unwrap();
        let event = ServerEvent::SessionCreated(1, addr);
        let cloned = event.clone();
        assert!(matches!(cloned, ServerEvent::SessionCreated(1, a) if a == addr));

        let debug_str = format!("{:?}", event);
        assert!(debug_str.contains("SessionCreated"));

        let event2 = ServerEvent::SessionClosed(1);
        let debug_str2 = format!("{:?}", event2);
        assert!(debug_str2.contains("SessionClosed"));

        let event3 = ServerEvent::Error("test error".to_string());
        let debug_str3 = format!("{:?}", event3);
        assert!(debug_str3.contains("Error"));
    }

    #[test]
    fn test_server_handle_shutdown() {
        let (server, handle) = ServerBuilder::new().handler(TestHandler).build();
        handle.shutdown();

        // The command must be queued for the server, and only once.
        assert!(matches!(
            server.cmd_rx.try_recv(),
            Some(ServerCommand::Shutdown)
        ));
        assert!(server.cmd_rx.try_recv().is_none());
    }

    #[test]
    fn test_server_handle_close_session() {
        let (server, handle) = ServerBuilder::new().handler(TestHandler).build();
        handle.close_session(1);

        assert!(matches!(
            server.cmd_rx.try_recv(),
            Some(ServerCommand::CloseSession(1))
        ));
        assert!(server.cmd_rx.try_recv().is_none());
    }

    #[test]
    fn test_server_handle_broadcast() {
        let (server, handle) = ServerBuilder::new().handler(TestHandler).build();
        handle.broadcast(vec![1, 2, 3]);

        match server.cmd_rx.try_recv() {
            Some(ServerCommand::Broadcast(bytes)) => assert_eq!(bytes, vec![1, 2, 3]),
            other => panic!("expected Broadcast, got {:?}", other),
        }
        assert!(server.cmd_rx.try_recv().is_none());
    }

    #[test]
    fn test_server_handle_poll_events_empty() {
        let (_server, handle) = ServerBuilder::new().handler(TestHandler).build();
        assert_eq!(handle.poll_events().count(), 0);
    }

    #[test]
    fn test_sbe_frame_codec_new() {
        let codec = SbeFrameCodec::new(64 * 1024);
        assert_eq!(codec.max_frame_size, 64 * 1024);
    }

    #[test]
    fn test_sbe_frame_codec_decode_incomplete() {
        let mut codec = SbeFrameCodec::new(1024);
        let mut buf = BytesMut::from(&[0u8, 0, 0][..]);
        assert!(codec.decode(&mut buf).unwrap().is_none());
        // Nothing may be consumed while the prefix is incomplete.
        assert_eq!(buf.len(), 3);
    }

    #[test]
    fn test_sbe_frame_codec_decode_partial_body() {
        let mut codec = SbeFrameCodec::new(1024);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&5u32.to_le_bytes());
        buf.extend_from_slice(b"he");

        assert!(codec.decode(&mut buf).unwrap().is_none());
        // Prefix and partial payload stay buffered until the rest arrives.
        assert_eq!(buf.len(), 6);

        buf.extend_from_slice(b"llo");
        let frame = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(frame.as_ref(), b"hello");
        assert!(buf.is_empty());
    }

    #[test]
    fn test_sbe_frame_codec_decode_complete() {
        let mut codec = SbeFrameCodec::new(1024);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&5u32.to_le_bytes());
        buf.extend_from_slice(b"hello");

        let result = codec.decode(&mut buf).unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().as_ref(), b"hello");
        assert!(buf.is_empty());
    }

    #[test]
    fn test_sbe_frame_codec_decode_back_to_back_frames() {
        let mut codec = SbeFrameCodec::new(1024);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&2u32.to_le_bytes());
        buf.extend_from_slice(b"ab");
        buf.extend_from_slice(&1u32.to_le_bytes());
        buf.extend_from_slice(b"c");

        assert_eq!(codec.decode(&mut buf).unwrap().unwrap().as_ref(), b"ab");
        assert_eq!(codec.decode(&mut buf).unwrap().unwrap().as_ref(), b"c");
        assert!(codec.decode(&mut buf).unwrap().is_none());
    }

    #[test]
    fn test_sbe_frame_codec_decode_too_large() {
        let mut codec = SbeFrameCodec::new(10);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&100u32.to_le_bytes());

        let result = codec.decode(&mut buf);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().kind(),
            std::io::ErrorKind::InvalidData
        );
    }

    #[test]
    fn test_sbe_frame_codec_encode() {
        let mut codec = SbeFrameCodec::new(1024);
        let mut buf = BytesMut::new();
        codec.encode(b"hello", &mut buf).unwrap();

        assert_eq!(buf.len(), 9);
        assert_eq!(&buf[0..4], &5u32.to_le_bytes());
        assert_eq!(&buf[4..9], b"hello");
    }

    #[test]
    fn test_sbe_frame_codec_round_trip() {
        let mut codec = SbeFrameCodec::new(1024);
        let mut buf = BytesMut::new();
        codec.encode(b"ping", &mut buf).unwrap();

        let frame = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(frame.as_ref(), b"ping");
        assert!(buf.is_empty());
    }
}