1use crate::error::ServerError;
16use crate::handler::{MessageHandler, Responder, SendError};
17use crate::session::SessionManager;
18use ironsbe_channel::mpsc::{MpscChannel, MpscReceiver, MpscSender};
19use ironsbe_core::header::MessageHeader;
20use ironsbe_transport::traits::{LocalConnection, LocalListener, LocalTransport};
21use std::marker::PhantomData;
22use std::net::SocketAddr;
23use std::rc::Rc;
24use std::sync::Arc;
25use tokio::sync::{Notify, mpsc as tokio_mpsc};
26
27use crate::builder::{ServerCommand, ServerEvent, ServerHandle};
28
29pub struct LocalServerBuilder<H, T: LocalTransport> {
35 bind_addr: SocketAddr,
36 bind_config: Option<T::BindConfig>,
37 handler: Option<H>,
38 max_connections: usize,
39 channel_capacity: usize,
40 _transport: PhantomData<T>,
41}
42
43impl<H: MessageHandler, T: LocalTransport> LocalServerBuilder<H, T> {
44 #[must_use]
46 pub fn new() -> Self {
47 Self {
48 bind_addr: "0.0.0.0:9000"
49 .parse()
50 .expect("hardcoded default bind addr is valid"),
51 bind_config: None,
52 handler: None,
53 max_connections: 1000,
54 channel_capacity: 4096,
55 _transport: PhantomData,
56 }
57 }
58
59 #[must_use]
63 pub fn bind(mut self, addr: SocketAddr) -> Self {
64 self.bind_addr = addr;
65 self.bind_config = None;
66 self
67 }
68
69 #[must_use]
71 pub fn bind_config(mut self, config: T::BindConfig) -> Self {
72 self.bind_config = Some(config);
73 self
74 }
75
76 #[must_use]
78 pub fn handler(mut self, handler: H) -> Self {
79 self.handler = Some(handler);
80 self
81 }
82
83 #[must_use]
85 pub fn max_connections(mut self, max: usize) -> Self {
86 self.max_connections = max;
87 self
88 }
89
90 #[must_use]
92 pub fn channel_capacity(mut self, capacity: usize) -> Self {
93 self.channel_capacity = capacity;
94 self
95 }
96
97 #[must_use]
102 pub fn build(self) -> (LocalServer<H, T>, ServerHandle) {
103 let handler = self.handler.expect("Handler required");
104 let (cmd_tx, cmd_rx) = MpscChannel::bounded(self.channel_capacity);
105 let (event_tx, event_rx) = MpscChannel::bounded(self.channel_capacity);
106 let cmd_notify = Arc::new(Notify::new());
107
108 let server = LocalServer {
109 bind_addr: self.bind_addr,
110 bind_config: Some(
111 self.bind_config
112 .unwrap_or_else(|| T::BindConfig::from(self.bind_addr)),
113 ),
114 handler: Rc::new(handler),
115 max_connections: self.max_connections,
116 cmd_tx: cmd_tx.clone(),
117 cmd_rx,
118 event_tx,
119 sessions: SessionManager::new(),
120 cmd_notify: Arc::clone(&cmd_notify),
121 _transport: PhantomData,
122 };
123
124 let handle = ServerHandle::new(cmd_tx, event_rx, cmd_notify);
125 (server, handle)
126 }
127}
128
129impl<H: MessageHandler, T: LocalTransport> Default for LocalServerBuilder<H, T> {
130 fn default() -> Self {
131 Self::new()
132 }
133}
134
135#[allow(dead_code)]
142pub struct LocalServer<H, T: LocalTransport> {
143 bind_addr: SocketAddr,
144 bind_config: Option<T::BindConfig>,
145 handler: Rc<H>,
146 max_connections: usize,
147 cmd_tx: MpscSender<ServerCommand>,
151 cmd_rx: MpscReceiver<ServerCommand>,
152 event_tx: MpscSender<ServerEvent>,
153 sessions: SessionManager,
154 cmd_notify: Arc<Notify>,
155 _transport: PhantomData<T>,
156}
157
158impl<H, T> LocalServer<H, T>
159where
160 H: MessageHandler + 'static,
161 T: LocalTransport,
162 T::Connection: 'static,
163{
164 pub async fn run(&mut self) -> Result<(), ServerError> {
174 let bind_config = self
175 .bind_config
176 .take()
177 .unwrap_or_else(|| T::BindConfig::from(self.bind_addr));
178 let mut listener = T::bind_with(bind_config)
179 .await
180 .map_err(|e| ServerError::Io(std::io::Error::other(e.to_string())))?;
181 let effective_addr = listener.local_addr().unwrap_or(self.bind_addr);
182 tracing::info!("Local server listening on {}", effective_addr);
183 let _ = self
187 .event_tx
188 .try_send(ServerEvent::Listening(effective_addr));
189
190 loop {
191 tokio::select! {
192 result = listener.accept() => {
193 match result {
194 Ok(conn) => {
195 let addr = conn
196 .peer_addr()
197 .unwrap_or_else(|_| "0.0.0.0:0".parse().expect("placeholder"));
198 self.handle_connection(conn, addr);
199 }
200 Err(e) => {
201 tracing::error!("Local accept error: {}", e);
202 }
203 }
204 }
205
206 _ = self.cmd_notify.notified() => {
207 while let Some(cmd) = self.cmd_rx.try_recv() {
208 if self.handle_command(cmd).await {
209 return Ok(());
210 }
211 }
212 }
213 }
214 }
215 }
216
217 fn handle_connection(&mut self, conn: T::Connection, addr: SocketAddr) {
218 if self.sessions.count() >= self.max_connections {
219 tracing::warn!("Max connections reached, rejecting {}", addr);
220 return;
221 }
222
223 let session_id = self.sessions.create_session(addr);
224 let handler = Rc::clone(&self.handler);
225 let event_tx = self.event_tx.clone();
226 let cmd_tx = self.cmd_tx.clone();
231 let cmd_notify = Arc::clone(&self.cmd_notify);
232
233 handler.on_session_start(session_id);
234 let _ = event_tx.try_send(ServerEvent::SessionCreated(session_id, addr));
235
236 let span = tracing::info_span!("sbe_session", session_id, %addr);
241 tokio::task::spawn_local(async move {
242 let _guard = span.enter();
243 tracing::info!("connected");
244 if let Err(e) = handle_local_session(session_id, conn, handler.as_ref()).await {
245 tracing::error!(error = %e, "session error");
246 }
247 tracing::info!("disconnected");
248 handler.on_session_end(session_id);
249 let _ = event_tx.try_send(ServerEvent::SessionClosed(session_id));
250 let _ = cmd_tx.try_send(ServerCommand::CloseSession(session_id));
251 cmd_notify.notify_one();
252 });
253 }
254
255 async fn handle_command(&mut self, cmd: ServerCommand) -> bool {
256 match cmd {
257 ServerCommand::Shutdown => {
258 tracing::info!("Local server shutdown requested");
259 true
260 }
261 ServerCommand::CloseSession(session_id) => {
262 self.sessions.close_session(session_id);
263 false
264 }
265 ServerCommand::Broadcast(_message) => false,
266 }
267 }
268}
269
270struct LocalSessionResponder {
274 tx: tokio_mpsc::UnboundedSender<Vec<u8>>,
275}
276
277impl Responder for LocalSessionResponder {
278 fn send(&self, message: &[u8]) -> Result<(), SendError> {
279 self.tx.send(message.to_vec()).map_err(|_| SendError {
280 message: "channel closed".to_string(),
281 })
282 }
283
284 fn send_to(&self, _session_id: u64, message: &[u8]) -> Result<(), SendError> {
285 self.send(message)
286 }
287}
288
289async fn handle_local_session<H, C>(
296 session_id: u64,
297 mut conn: C,
298 handler: &H,
299) -> Result<(), std::io::Error>
300where
301 H: MessageHandler,
302 C: LocalConnection,
303{
304 let (tx, mut rx) = tokio_mpsc::unbounded_channel::<Vec<u8>>();
305 let responder = LocalSessionResponder { tx };
306
307 loop {
308 tokio::select! {
309 result = conn.recv() => {
310 match result {
311 Ok(Some(data)) => {
312 if data.len() >= MessageHeader::ENCODED_LENGTH {
313 let header = MessageHeader::wrap(data.as_ref(), 0);
314 handler.on_message(session_id, &header, data.as_ref(), &responder);
315 } else {
316 handler.on_error(session_id, "Message too short for header");
317 }
318 }
319 Ok(None) => {
320 return Ok(());
321 }
322 Err(e) => {
323 tracing::error!(error = %e, "read error");
324 return Err(std::io::Error::other(e.to_string()));
325 }
326 }
327 }
328
329 Some(msg) = rx.recv() => {
330 if let Err(e) = conn.send(&msg).await {
331 tracing::error!(error = %e, "write error");
332 return Err(std::io::Error::other(e.to_string()));
333 }
334 }
335 }
336 }
337}
338
#[cfg(all(test, feature = "tcp-uring", target_os = "linux"))]
mod tests {
    use super::*;
    use crate::handler::Responder;
    use ironsbe_transport::tcp_uring::UringTcpTransport;

    /// Handler that ignores every message; it only satisfies the `H` parameter.
    struct TestHandler;
    impl MessageHandler for TestHandler {
        fn on_message(
            &self,
            _session_id: u64,
            _header: &MessageHeader,
            _data: &[u8],
            _responder: &dyn Responder,
        ) {
        }
    }

    type TestBuilder = LocalServerBuilder<TestHandler, UringTcpTransport>;

    /// Address the builder is expected to use when `bind` is never called.
    fn default_addr() -> SocketAddr {
        "0.0.0.0:9000".parse().expect("test addr")
    }

    #[test]
    fn test_local_server_builder_new() {
        let builder = TestBuilder::new();
        assert_eq!(builder.bind_addr, default_addr());
        assert!(builder.bind_config.is_none());
        assert!(builder.handler.is_none());
        assert_eq!(builder.max_connections, 1000);
        assert_eq!(builder.channel_capacity, 4096);
    }

    #[test]
    fn test_local_server_builder_default() {
        let from_default = TestBuilder::default();
        let from_new = TestBuilder::new();
        assert_eq!(from_default.bind_addr, from_new.bind_addr);
        assert_eq!(
            from_default.bind_config.is_none(),
            from_new.bind_config.is_none()
        );
        assert_eq!(from_default.handler.is_none(), from_new.handler.is_none());
        assert_eq!(from_default.max_connections, from_new.max_connections);
        assert_eq!(from_default.channel_capacity, from_new.channel_capacity);
    }

    #[test]
    fn test_local_server_builder_bind() {
        let addr: SocketAddr = "127.0.0.1:8080".parse().expect("test addr");
        let builder = TestBuilder::new().bind(addr);
        assert_eq!(builder.bind_addr, addr);
        assert!(builder.bind_config.is_none());
    }

    #[test]
    fn test_local_server_builder_bind_clears_bind_config() {
        let addr: SocketAddr = "127.0.0.1:8080".parse().expect("test addr");
        let config = <UringTcpTransport as LocalTransport>::BindConfig::from(addr);

        let configured = TestBuilder::new().bind_config(config);
        assert!(configured.bind_config.is_some());

        // `bind` must discard an explicit configuration set earlier.
        let rebound = configured.bind(addr);
        assert!(rebound.bind_config.is_none());
        assert_eq!(rebound.bind_addr, addr);
    }

    #[test]
    fn test_local_server_builder_max_connections() {
        let builder = TestBuilder::new().max_connections(500);
        assert_eq!(builder.max_connections, 500);
    }

    #[test]
    fn test_local_server_builder_channel_capacity() {
        let builder = TestBuilder::new().channel_capacity(8192);
        assert_eq!(builder.channel_capacity, 8192);
    }

    #[test]
    fn test_local_server_builder_build() {
        let (server, _handle) = TestBuilder::new()
            .handler(TestHandler)
            .max_connections(7)
            .build();
        assert_eq!(server.bind_addr, default_addr());
        // `build` always resolves a bind configuration for the server.
        assert!(server.bind_config.is_some());
        assert_eq!(server.max_connections, 7);
    }

    #[test]
    #[should_panic(expected = "Handler required")]
    fn test_local_server_builder_build_without_handler_panics() {
        let _ = TestBuilder::new().build();
    }
}