Skip to main content

ironsbe_server/
local_builder.rs

1//! Single-threaded server builder for thread-per-core / `!Send` backends.
2//!
3//! Mirrors [`crate::ServerBuilder`] but is generic over [`LocalTransport`]
4//! instead of [`Transport`](ironsbe_transport::Transport).  Use this when
5//! the chosen backend (e.g. `tokio-uring` via the `tcp-uring` feature) is
6//! single-threaded by construction and its handle types are `!Send`.
7//!
8//! # Runtime requirements
9//!
10//! [`LocalServer::run`] must be polled inside a single-threaded reactor
11//! that owns a Tokio `LocalSet` (or anything with the same semantics).
12//! `tokio_uring::start` provides one for free; for plain `tokio` you can
13//! use `tokio::task::LocalSet::run_until`.
14
15use crate::error::ServerError;
16use crate::handler::{MessageHandler, Responder, SendError};
17use crate::session::SessionManager;
18use ironsbe_channel::mpsc::{MpscChannel, MpscReceiver, MpscSender};
19use ironsbe_core::header::MessageHeader;
20use ironsbe_transport::traits::{LocalConnection, LocalListener, LocalTransport};
21use std::marker::PhantomData;
22use std::net::SocketAddr;
23use std::rc::Rc;
24use std::sync::Arc;
25use tokio::sync::{Notify, mpsc as tokio_mpsc};
26
27use crate::builder::{ServerCommand, ServerEvent, ServerHandle};
28
29/// Builder for [`LocalServer`].
30///
31/// Single-threaded counterpart of [`crate::ServerBuilder`]; the type
32/// parameter `T` selects a [`LocalTransport`] backend rather than the
33/// multi-threaded [`Transport`](ironsbe_transport::Transport) family.
34pub struct LocalServerBuilder<H, T: LocalTransport> {
35    bind_addr: SocketAddr,
36    bind_config: Option<T::BindConfig>,
37    handler: Option<H>,
38    max_connections: usize,
39    channel_capacity: usize,
40    _transport: PhantomData<T>,
41}
42
43impl<H: MessageHandler, T: LocalTransport> LocalServerBuilder<H, T> {
44    /// Creates a new local server builder with default settings.
45    #[must_use]
46    pub fn new() -> Self {
47        Self {
48            bind_addr: "0.0.0.0:9000"
49                .parse()
50                .expect("hardcoded default bind addr is valid"),
51            bind_config: None,
52            handler: None,
53            max_connections: 1000,
54            channel_capacity: 4096,
55            _transport: PhantomData,
56        }
57    }
58
59    /// Sets the bind address.  Clears any previously-supplied
60    /// [`bind_config`](Self::bind_config) since the address is now
61    /// ambiguous.
62    #[must_use]
63    pub fn bind(mut self, addr: SocketAddr) -> Self {
64        self.bind_addr = addr;
65        self.bind_config = None;
66        self
67    }
68
69    /// Supplies a backend-specific bind configuration.
70    #[must_use]
71    pub fn bind_config(mut self, config: T::BindConfig) -> Self {
72        self.bind_config = Some(config);
73        self
74    }
75
76    /// Sets the message handler.
77    #[must_use]
78    pub fn handler(mut self, handler: H) -> Self {
79        self.handler = Some(handler);
80        self
81    }
82
83    /// Sets the maximum number of concurrent sessions.
84    #[must_use]
85    pub fn max_connections(mut self, max: usize) -> Self {
86        self.max_connections = max;
87        self
88    }
89
90    /// Sets the cmd/event channel capacity.
91    #[must_use]
92    pub fn channel_capacity(mut self, capacity: usize) -> Self {
93        self.channel_capacity = capacity;
94        self
95    }
96
97    /// Builds the server and its external handle.
98    ///
99    /// # Panics
100    /// Panics if no [`handler`](Self::handler) was set.
101    #[must_use]
102    pub fn build(self) -> (LocalServer<H, T>, ServerHandle) {
103        let handler = self.handler.expect("Handler required");
104        let (cmd_tx, cmd_rx) = MpscChannel::bounded(self.channel_capacity);
105        let (event_tx, event_rx) = MpscChannel::bounded(self.channel_capacity);
106        let cmd_notify = Arc::new(Notify::new());
107
108        let server = LocalServer {
109            bind_addr: self.bind_addr,
110            bind_config: Some(
111                self.bind_config
112                    .unwrap_or_else(|| T::BindConfig::from(self.bind_addr)),
113            ),
114            handler: Rc::new(handler),
115            max_connections: self.max_connections,
116            cmd_tx: cmd_tx.clone(),
117            cmd_rx,
118            event_tx,
119            sessions: SessionManager::new(),
120            cmd_notify: Arc::clone(&cmd_notify),
121            _transport: PhantomData,
122        };
123
124        let handle = ServerHandle::new(cmd_tx, event_rx, cmd_notify);
125        (server, handle)
126    }
127}
128
129impl<H: MessageHandler, T: LocalTransport> Default for LocalServerBuilder<H, T> {
130    fn default() -> Self {
131        Self::new()
132    }
133}
134
135/// Single-threaded server instance for [`LocalTransport`] backends.
136///
137/// `LocalServer::run` **must** be polled inside a Tokio `LocalSet` (e.g.
138/// from inside `tokio_uring::start(async { server.run().await })`).
139/// Polling it from a context without a `LocalSet` will fail at the first
140/// `spawn_local` call.
141#[allow(dead_code)]
142pub struct LocalServer<H, T: LocalTransport> {
143    bind_addr: SocketAddr,
144    bind_config: Option<T::BindConfig>,
145    handler: Rc<H>,
146    max_connections: usize,
147    /// Cloned and handed to per-session tasks so they can fire
148    /// `ServerCommand::CloseSession` when the session ends, freeing the
149    /// `SessionManager` slot back in the run loop.
150    cmd_tx: MpscSender<ServerCommand>,
151    cmd_rx: MpscReceiver<ServerCommand>,
152    event_tx: MpscSender<ServerEvent>,
153    sessions: SessionManager,
154    cmd_notify: Arc<Notify>,
155    _transport: PhantomData<T>,
156}
157
158impl<H, T> LocalServer<H, T>
159where
160    H: MessageHandler + 'static,
161    T: LocalTransport,
162    T::Connection: 'static,
163{
164    /// Runs the server, accepting connections and processing messages.
165    ///
166    /// # Errors
167    /// Returns [`ServerError`] if the listener fails to bind or the
168    /// accept loop encounters an unrecoverable error.
169    ///
170    /// # Panics
171    /// Panics indirectly via `tokio::task::spawn_local` if called outside
172    /// a `LocalSet` context.  See the type-level docs.
173    pub async fn run(&mut self) -> Result<(), ServerError> {
174        let bind_config = self
175            .bind_config
176            .take()
177            .unwrap_or_else(|| T::BindConfig::from(self.bind_addr));
178        let mut listener = T::bind_with(bind_config)
179            .await
180            .map_err(|e| ServerError::Io(std::io::Error::other(e.to_string())))?;
181        let effective_addr = listener.local_addr().unwrap_or(self.bind_addr);
182        tracing::info!("Local server listening on {}", effective_addr);
183        // Notify any external observer (test harness, supervisor) of
184        // the effective bound address.  Best-effort: dropping the event
185        // is fine if no one is listening.
186        let _ = self
187            .event_tx
188            .try_send(ServerEvent::Listening(effective_addr));
189
190        loop {
191            tokio::select! {
192                result = listener.accept() => {
193                    match result {
194                        Ok(conn) => {
195                            let addr = conn
196                                .peer_addr()
197                                .unwrap_or_else(|_| "0.0.0.0:0".parse().expect("placeholder"));
198                            self.handle_connection(conn, addr);
199                        }
200                        Err(e) => {
201                            tracing::error!("Local accept error: {}", e);
202                        }
203                    }
204                }
205
206                _ = self.cmd_notify.notified() => {
207                    while let Some(cmd) = self.cmd_rx.try_recv() {
208                        if self.handle_command(cmd).await {
209                            return Ok(());
210                        }
211                    }
212                }
213            }
214        }
215    }
216
217    fn handle_connection(&mut self, conn: T::Connection, addr: SocketAddr) {
218        if self.sessions.count() >= self.max_connections {
219            tracing::warn!("Max connections reached, rejecting {}", addr);
220            return;
221        }
222
223        let session_id = self.sessions.create_session(addr);
224        let handler = Rc::clone(&self.handler);
225        let event_tx = self.event_tx.clone();
226        // Cloned cmd_tx so the spawned task can fire CloseSession back
227        // to the run loop on disconnect, releasing the SessionManager
228        // slot.  Without this the slot leaks and `max_connections`
229        // eventually rejects every new connection.
230        let cmd_tx = self.cmd_tx.clone();
231        let cmd_notify = Arc::clone(&self.cmd_notify);
232
233        handler.on_session_start(session_id);
234        let _ = event_tx.try_send(ServerEvent::SessionCreated(session_id, addr));
235
236        // `spawn_local` keeps the future on the current single-threaded
237        // runtime, satisfying the `!Send` connection bound.  The span
238        // gives every log line inside the session the
239        // `sbe_session{session_id=N}:` prefix.
240        let span = tracing::info_span!("sbe_session", session_id, %addr);
241        tokio::task::spawn_local(async move {
242            let _guard = span.enter();
243            tracing::info!("connected");
244            if let Err(e) = handle_local_session(session_id, conn, handler.as_ref()).await {
245                tracing::error!(error = %e, "session error");
246            }
247            tracing::info!("disconnected");
248            handler.on_session_end(session_id);
249            let _ = event_tx.try_send(ServerEvent::SessionClosed(session_id));
250            let _ = cmd_tx.try_send(ServerCommand::CloseSession(session_id));
251            cmd_notify.notify_one();
252        });
253    }
254
255    async fn handle_command(&mut self, cmd: ServerCommand) -> bool {
256        match cmd {
257            ServerCommand::Shutdown => {
258                tracing::info!("Local server shutdown requested");
259                true
260            }
261            ServerCommand::CloseSession(session_id) => {
262                self.sessions.close_session(session_id);
263                false
264            }
265            ServerCommand::Broadcast(_message) => false,
266        }
267    }
268}
269
270/// Per-session responder that ferries handler outputs back to the
271/// connection writer over an unbounded local channel.  Mirrors the
272/// equivalent type in [`crate::builder`].
273struct LocalSessionResponder {
274    tx: tokio_mpsc::UnboundedSender<Vec<u8>>,
275}
276
277impl Responder for LocalSessionResponder {
278    fn send(&self, message: &[u8]) -> Result<(), SendError> {
279        self.tx.send(message.to_vec()).map_err(|_| SendError {
280            message: "channel closed".to_string(),
281        })
282    }
283
284    fn send_to(&self, _session_id: u64, message: &[u8]) -> Result<(), SendError> {
285        self.send(message)
286    }
287}
288
289/// Drives one [`LocalConnection`] end-to-end: read framed SBE messages,
290/// dispatch to the handler, and write any responses produced by the
291/// handler back over the same connection.
292///
293/// Mirrors the [`Connection`](ironsbe_transport::traits::Connection)
294/// version in [`crate::builder`].
295async fn handle_local_session<H, C>(
296    session_id: u64,
297    mut conn: C,
298    handler: &H,
299) -> Result<(), std::io::Error>
300where
301    H: MessageHandler,
302    C: LocalConnection,
303{
304    let (tx, mut rx) = tokio_mpsc::unbounded_channel::<Vec<u8>>();
305    let responder = LocalSessionResponder { tx };
306
307    loop {
308        tokio::select! {
309            result = conn.recv() => {
310                match result {
311                    Ok(Some(data)) => {
312                        if data.len() >= MessageHeader::ENCODED_LENGTH {
313                            let header = MessageHeader::wrap(data.as_ref(), 0);
314                            handler.on_message(session_id, &header, data.as_ref(), &responder);
315                        } else {
316                            handler.on_error(session_id, "Message too short for header");
317                        }
318                    }
319                    Ok(None) => {
320                        return Ok(());
321                    }
322                    Err(e) => {
323                        tracing::error!(error = %e, "read error");
324                        return Err(std::io::Error::other(e.to_string()));
325                    }
326                }
327            }
328
329            Some(msg) = rx.recv() => {
330                if let Err(e) = conn.send(&msg).await {
331                    tracing::error!(error = %e, "write error");
332                    return Err(std::io::Error::other(e.to_string()));
333                }
334            }
335        }
336    }
337}
338
#[cfg(all(test, feature = "tcp-uring", target_os = "linux"))]
mod tests {
    use super::*;
    use crate::handler::Responder;
    use ironsbe_transport::tcp_uring::UringTcpTransport;

    /// Builder under test, pinned to the io_uring TCP backend.
    type Builder = LocalServerBuilder<TestHandler, UringTcpTransport>;

    /// No-op handler; these tests only exercise builder configuration.
    struct TestHandler;
    impl MessageHandler for TestHandler {
        fn on_message(
            &self,
            _session_id: u64,
            _header: &MessageHeader,
            _data: &[u8],
            _responder: &dyn Responder,
        ) {
        }
    }

    /// The documented default bind address.
    fn default_addr() -> SocketAddr {
        "0.0.0.0:9000".parse().expect("test addr")
    }

    #[test]
    fn test_local_server_builder_new() {
        let builder = Builder::new();
        assert_eq!(builder.bind_addr, default_addr());
        assert!(builder.bind_config.is_none());
        assert!(builder.handler.is_none());
        assert_eq!(builder.max_connections, 1000);
        assert_eq!(builder.channel_capacity, 4096);
    }

    #[test]
    fn test_local_server_builder_default() {
        let builder = Builder::default();
        assert_eq!(builder.bind_addr, default_addr());
        assert!(builder.bind_config.is_none());
        assert!(builder.handler.is_none());
        assert_eq!(builder.max_connections, 1000);
        assert_eq!(builder.channel_capacity, 4096);
    }

    #[test]
    fn test_local_server_builder_bind() {
        let addr: SocketAddr = "127.0.0.1:8080".parse().expect("test addr");
        let builder = Builder::new().bind(addr);
        assert_eq!(builder.bind_addr, addr);
        assert!(builder.bind_config.is_none());
    }

    #[test]
    fn test_local_server_builder_bind_clears_bind_config() {
        let addr: SocketAddr = "127.0.0.1:8080".parse().expect("test addr");
        let config = <UringTcpTransport as LocalTransport>::BindConfig::from(addr);
        let with_config = Builder::new().bind_config(config);
        assert!(with_config.bind_config.is_some());
        let rebound = with_config.bind(addr);
        assert!(rebound.bind_config.is_none());
    }

    #[test]
    fn test_local_server_builder_max_connections() {
        let builder = Builder::new().max_connections(500);
        assert_eq!(builder.max_connections, 500);
    }

    #[test]
    fn test_local_server_builder_channel_capacity() {
        let builder = Builder::new().channel_capacity(8192);
        assert_eq!(builder.channel_capacity, 8192);
    }

    #[test]
    fn test_local_server_builder_build() {
        let (server, _handle) = Builder::new()
            .max_connections(7)
            .handler(TestHandler)
            .build();
        assert_eq!(server.bind_addr, default_addr());
        assert!(server.bind_config.is_some());
        assert_eq!(server.max_connections, 7);
    }

    #[test]
    #[should_panic(expected = "Handler required")]
    fn test_local_server_builder_build_without_handler_panics() {
        let _ = Builder::new().build();
    }
}