Skip to main content

ironsbe_server/
builder.rs

1//! Server builder and main server implementation.
2
3use crate::error::ServerError;
4use crate::handler::{MessageHandler, Responder, SendError};
5use crate::session::SessionManager;
6use bytes::BytesMut;
7use futures::SinkExt;
8use ironsbe_channel::mpsc::{MpscChannel, MpscReceiver, MpscSender};
9use ironsbe_core::header::MessageHeader;
10use std::net::SocketAddr;
11use std::sync::Arc;
12use tokio::net::TcpStream;
13use tokio::sync::mpsc as tokio_mpsc;
14use tokio_stream::StreamExt;
15use tokio_util::codec::{Decoder, Encoder, Framed};
16
/// Fluent configuration for a [`Server`].
///
/// Start from [`ServerBuilder::new`] (or `Default`), chain the setters, then
/// call [`ServerBuilder::build`] to obtain the server and its control handle.
pub struct ServerBuilder<H> {
    /// Socket address the listener binds to.
    bind_addr: SocketAddr,
    /// Message handler; must be supplied before `build` is called.
    handler: Option<H>,
    /// Session count at which newly accepted connections are rejected.
    max_connections: usize,
    /// Largest accepted inbound frame payload, in bytes.
    max_frame_size: usize,
    /// Capacity used for both the command and the event channel.
    channel_capacity: usize,
}
25
26impl<H: MessageHandler> ServerBuilder<H> {
27    /// Creates a new server builder with default settings.
28    #[must_use]
29    pub fn new() -> Self {
30        Self {
31            bind_addr: "0.0.0.0:9000".parse().unwrap(),
32            handler: None,
33            max_connections: 1000,
34            max_frame_size: 64 * 1024,
35            channel_capacity: 4096,
36        }
37    }
38
39    /// Sets the bind address.
40    #[must_use]
41    pub fn bind(mut self, addr: SocketAddr) -> Self {
42        self.bind_addr = addr;
43        self
44    }
45
46    /// Sets the message handler.
47    #[must_use]
48    pub fn handler(mut self, handler: H) -> Self {
49        self.handler = Some(handler);
50        self
51    }
52
53    /// Sets the maximum number of connections.
54    #[must_use]
55    pub fn max_connections(mut self, max: usize) -> Self {
56        self.max_connections = max;
57        self
58    }
59
60    /// Sets the maximum frame size.
61    #[must_use]
62    pub fn max_frame_size(mut self, size: usize) -> Self {
63        self.max_frame_size = size;
64        self
65    }
66
67    /// Sets the channel capacity.
68    #[must_use]
69    pub fn channel_capacity(mut self, capacity: usize) -> Self {
70        self.channel_capacity = capacity;
71        self
72    }
73
74    /// Builds the server and handle.
75    ///
76    /// # Panics
77    /// Panics if no handler was set.
78    #[must_use]
79    pub fn build(self) -> (Server<H>, ServerHandle) {
80        let handler = self.handler.expect("Handler required");
81        let (cmd_tx, cmd_rx) = MpscChannel::bounded(self.channel_capacity);
82        let (event_tx, event_rx) = MpscChannel::bounded(self.channel_capacity);
83
84        let server = Server {
85            bind_addr: self.bind_addr,
86            handler: Arc::new(handler),
87            max_connections: self.max_connections,
88            max_frame_size: self.max_frame_size,
89            cmd_rx,
90            event_tx,
91            sessions: SessionManager::new(),
92        };
93
94        let handle = ServerHandle { cmd_tx, event_rx };
95
96        (server, handle)
97    }
98}
99
100impl<H: MessageHandler> Default for ServerBuilder<H> {
101    fn default() -> Self {
102        Self::new()
103    }
104}
105
106/// The main server instance.
107#[allow(dead_code)]
108pub struct Server<H> {
109    bind_addr: SocketAddr,
110    handler: Arc<H>,
111    max_connections: usize,
112    max_frame_size: usize,
113    cmd_rx: MpscReceiver<ServerCommand>,
114    event_tx: MpscSender<ServerEvent>,
115    sessions: SessionManager,
116}
117
118impl<H: MessageHandler + Send + Sync + 'static> Server<H> {
119    /// Runs the server, accepting connections and processing messages.
120    ///
121    /// # Errors
122    /// Returns `ServerError` if the server fails to start or encounters an error.
123    pub async fn run(&mut self) -> Result<(), ServerError> {
124        let listener = tokio::net::TcpListener::bind(self.bind_addr).await?;
125        tracing::info!("Server listening on {}", self.bind_addr);
126
127        loop {
128            tokio::select! {
129                result = listener.accept() => {
130                    match result {
131                        Ok((stream, addr)) => {
132                            self.handle_connection(stream, addr).await;
133                        }
134                        Err(e) => {
135                            tracing::error!("Accept error: {}", e);
136                        }
137                    }
138                }
139
140                cmd = async { self.cmd_rx.try_recv() } => {
141                    if let Some(cmd) = cmd && self.handle_command(cmd).await {
142                        return Ok(());
143                    }
144                }
145            }
146        }
147    }
148
149    async fn handle_connection(&mut self, stream: TcpStream, addr: SocketAddr) {
150        if self.sessions.count() >= self.max_connections {
151            tracing::warn!("Max connections reached, rejecting {}", addr);
152            return;
153        }
154
155        let session_id = self.sessions.create_session(addr);
156        let handler = Arc::clone(&self.handler);
157        let event_tx = self.event_tx.clone();
158        let max_frame_size = self.max_frame_size;
159
160        handler.on_session_start(session_id);
161        let _ = event_tx.try_send(ServerEvent::SessionCreated(session_id, addr));
162
163        // Spawn connection handler task
164        tokio::spawn(async move {
165            tracing::info!("Session {} connected from {}", session_id, addr);
166
167            if let Err(e) =
168                handle_session(session_id, stream, handler.as_ref(), max_frame_size).await
169            {
170                tracing::error!("Session {} error: {:?}", session_id, e);
171            }
172
173            // When done, notify
174            handler.on_session_end(session_id);
175            let _ = event_tx.try_send(ServerEvent::SessionClosed(session_id));
176        });
177    }
178
179    async fn handle_command(&mut self, cmd: ServerCommand) -> bool {
180        match cmd {
181            ServerCommand::Shutdown => {
182                tracing::info!("Server shutdown requested");
183                true
184            }
185            ServerCommand::CloseSession(session_id) => {
186                self.sessions.close_session(session_id);
187                false
188            }
189            ServerCommand::Broadcast(_message) => {
190                // Broadcast to all sessions
191                false
192            }
193        }
194    }
195}
196
197/// Handle for controlling the server from outside.
198pub struct ServerHandle {
199    cmd_tx: MpscSender<ServerCommand>,
200    event_rx: MpscReceiver<ServerEvent>,
201}
202
203impl ServerHandle {
204    /// Requests server shutdown.
205    pub fn shutdown(&self) {
206        let _ = self.cmd_tx.try_send(ServerCommand::Shutdown);
207    }
208
209    /// Closes a specific session.
210    pub fn close_session(&self, session_id: u64) {
211        let _ = self
212            .cmd_tx
213            .try_send(ServerCommand::CloseSession(session_id));
214    }
215
216    /// Broadcasts a message to all sessions.
217    pub fn broadcast(&self, message: Vec<u8>) {
218        let _ = self.cmd_tx.try_send(ServerCommand::Broadcast(message));
219    }
220
221    /// Polls for server events.
222    pub fn poll_events(&self) -> impl Iterator<Item = ServerEvent> + '_ {
223        std::iter::from_fn(|| self.event_rx.try_recv())
224    }
225}
226
/// Requests a [`ServerHandle`] can send to the running server.
#[derive(Debug)]
pub enum ServerCommand {
    /// Stop the server loop; the server's run method then returns `Ok(())`.
    Shutdown,
    /// Close the session with the given id.
    CloseSession(u64),
    /// Send the given bytes to every session.
    Broadcast(Vec<u8>),
}
237
/// Notifications the server publishes for a [`ServerHandle`] to poll.
#[derive(Debug, Clone)]
pub enum ServerEvent {
    /// A connection was accepted, identified by session id and peer address.
    SessionCreated(u64, SocketAddr),
    /// The session with this id has ended.
    SessionClosed(u64),
    /// Something went wrong, described by the message.
    Error(String),
}
248
/// Frame codec for SBE messages: every frame is a 4-byte little-endian `u32`
/// payload length followed by that many payload bytes.
struct SbeFrameCodec {
    /// Largest payload length (prefix excluded) the decoder accepts.
    max_frame_size: usize,
}
253
254impl SbeFrameCodec {
255    fn new(max_frame_size: usize) -> Self {
256        Self { max_frame_size }
257    }
258}
259
260impl Decoder for SbeFrameCodec {
261    type Item = BytesMut;
262    type Error = std::io::Error;
263
264    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
265        if src.len() < 4 {
266            return Ok(None);
267        }
268
269        let length = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
270
271        if length > self.max_frame_size {
272            return Err(std::io::Error::new(
273                std::io::ErrorKind::InvalidData,
274                format!("Frame too large: {} > {}", length, self.max_frame_size),
275            ));
276        }
277
278        if src.len() < 4 + length {
279            src.reserve(4 + length - src.len());
280            return Ok(None);
281        }
282
283        let _ = src.split_to(4);
284        Ok(Some(src.split_to(length)))
285    }
286}
287
288impl<T: AsRef<[u8]>> Encoder<T> for SbeFrameCodec {
289    type Error = std::io::Error;
290
291    fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
292        let data = item.as_ref();
293        let length = data.len() as u32;
294        dst.reserve(4 + data.len());
295        dst.extend_from_slice(&length.to_le_bytes());
296        dst.extend_from_slice(data);
297        Ok(())
298    }
299}
300
301/// Session responder that sends messages back to the client.
302struct SessionResponder {
303    tx: tokio_mpsc::UnboundedSender<Vec<u8>>,
304}
305
306impl Responder for SessionResponder {
307    fn send(&self, message: &[u8]) -> Result<(), SendError> {
308        self.tx.send(message.to_vec()).map_err(|_| SendError {
309            message: "channel closed".to_string(),
310        })
311    }
312
313    fn send_to(&self, _session_id: u64, message: &[u8]) -> Result<(), SendError> {
314        // For now, just send to current session
315        self.send(message)
316    }
317}
318
319/// Handles a single client session.
320async fn handle_session<H: MessageHandler>(
321    session_id: u64,
322    stream: TcpStream,
323    handler: &H,
324    max_frame_size: usize,
325) -> Result<(), std::io::Error> {
326    let codec = SbeFrameCodec::new(max_frame_size);
327    let mut framed = Framed::new(stream, codec);
328
329    // Channel for sending responses
330    let (tx, mut rx) = tokio_mpsc::unbounded_channel::<Vec<u8>>();
331    let responder = SessionResponder { tx };
332
333    loop {
334        tokio::select! {
335            // Read incoming messages
336            result = framed.next() => {
337                match result {
338                    Some(Ok(data)) => {
339                        // Decode header and dispatch to handler
340                        if data.len() >= MessageHeader::ENCODED_LENGTH {
341                            let header = MessageHeader::wrap(data.as_ref(), 0);
342                            handler.on_message(session_id, &header, data.as_ref(), &responder);
343                        } else {
344                            handler.on_error(session_id, "Message too short for header");
345                        }
346                    }
347                    Some(Err(e)) => {
348                        tracing::error!("Session {} read error: {}", session_id, e);
349                        return Err(e);
350                    }
351                    None => {
352                        tracing::info!("Session {} disconnected", session_id);
353                        return Ok(());
354                    }
355                }
356            }
357
358            // Send outgoing messages
359            Some(msg) = rx.recv() => {
360                if let Err(e) = framed.send(msg).await {
361                    tracing::error!("Session {} write error: {}", session_id, e);
362                    return Err(e);
363                }
364            }
365        }
366    }
367}