1use crate::error::ServerError;
4use crate::handler::{MessageHandler, Responder, SendError};
5use crate::session::SessionManager;
6use bytes::BytesMut;
7use futures::SinkExt;
8use ironsbe_channel::mpsc::{MpscChannel, MpscReceiver, MpscSender};
9use ironsbe_core::header::MessageHeader;
10use std::net::SocketAddr;
11use std::sync::Arc;
12use tokio::net::TcpStream;
13use tokio::sync::mpsc as tokio_mpsc;
14use tokio_stream::StreamExt;
15use tokio_util::codec::{Decoder, Encoder, Framed};
16
/// Fluent configuration for a [`Server`] and its paired [`ServerHandle`].
pub struct ServerBuilder<H> {
    /// Address the TCP listener binds to.
    bind_addr: SocketAddr,
    /// Message handler; must be set before `build` is called.
    handler: Option<H>,
    /// Limit on the number of sessions tracked at once; connections beyond it
    /// are rejected.
    max_connections: usize,
    /// Largest inbound frame payload, in bytes, that the decoder accepts.
    max_frame_size: usize,
    /// Capacity used for both the command channel and the event channel.
    channel_capacity: usize,
}
25
26impl<H: MessageHandler> ServerBuilder<H> {
27 #[must_use]
29 pub fn new() -> Self {
30 Self {
31 bind_addr: "0.0.0.0:9000".parse().unwrap(),
32 handler: None,
33 max_connections: 1000,
34 max_frame_size: 64 * 1024,
35 channel_capacity: 4096,
36 }
37 }
38
39 #[must_use]
41 pub fn bind(mut self, addr: SocketAddr) -> Self {
42 self.bind_addr = addr;
43 self
44 }
45
46 #[must_use]
48 pub fn handler(mut self, handler: H) -> Self {
49 self.handler = Some(handler);
50 self
51 }
52
53 #[must_use]
55 pub fn max_connections(mut self, max: usize) -> Self {
56 self.max_connections = max;
57 self
58 }
59
60 #[must_use]
62 pub fn max_frame_size(mut self, size: usize) -> Self {
63 self.max_frame_size = size;
64 self
65 }
66
67 #[must_use]
69 pub fn channel_capacity(mut self, capacity: usize) -> Self {
70 self.channel_capacity = capacity;
71 self
72 }
73
74 #[must_use]
79 pub fn build(self) -> (Server<H>, ServerHandle) {
80 let handler = self.handler.expect("Handler required");
81 let (cmd_tx, cmd_rx) = MpscChannel::bounded(self.channel_capacity);
82 let (event_tx, event_rx) = MpscChannel::bounded(self.channel_capacity);
83
84 let server = Server {
85 bind_addr: self.bind_addr,
86 handler: Arc::new(handler),
87 max_connections: self.max_connections,
88 max_frame_size: self.max_frame_size,
89 cmd_rx,
90 event_tx,
91 sessions: SessionManager::new(),
92 };
93
94 let handle = ServerHandle { cmd_tx, event_rx };
95
96 (server, handle)
97 }
98}
99
100impl<H: MessageHandler> Default for ServerBuilder<H> {
101 fn default() -> Self {
102 Self::new()
103 }
104}
105
106#[allow(dead_code)]
108pub struct Server<H> {
109 bind_addr: SocketAddr,
110 handler: Arc<H>,
111 max_connections: usize,
112 max_frame_size: usize,
113 cmd_rx: MpscReceiver<ServerCommand>,
114 event_tx: MpscSender<ServerEvent>,
115 sessions: SessionManager,
116}
117
118impl<H: MessageHandler + Send + Sync + 'static> Server<H> {
119 pub async fn run(&mut self) -> Result<(), ServerError> {
124 let listener = tokio::net::TcpListener::bind(self.bind_addr).await?;
125 tracing::info!("Server listening on {}", self.bind_addr);
126
127 loop {
128 tokio::select! {
129 result = listener.accept() => {
130 match result {
131 Ok((stream, addr)) => {
132 self.handle_connection(stream, addr).await;
133 }
134 Err(e) => {
135 tracing::error!("Accept error: {}", e);
136 }
137 }
138 }
139
140 cmd = async { self.cmd_rx.try_recv() } => {
141 if let Some(cmd) = cmd && self.handle_command(cmd).await {
142 return Ok(());
143 }
144 }
145 }
146 }
147 }
148
149 async fn handle_connection(&mut self, stream: TcpStream, addr: SocketAddr) {
150 if self.sessions.count() >= self.max_connections {
151 tracing::warn!("Max connections reached, rejecting {}", addr);
152 return;
153 }
154
155 let session_id = self.sessions.create_session(addr);
156 let handler = Arc::clone(&self.handler);
157 let event_tx = self.event_tx.clone();
158 let max_frame_size = self.max_frame_size;
159
160 handler.on_session_start(session_id);
161 let _ = event_tx.try_send(ServerEvent::SessionCreated(session_id, addr));
162
163 tokio::spawn(async move {
165 tracing::info!("Session {} connected from {}", session_id, addr);
166
167 if let Err(e) =
168 handle_session(session_id, stream, handler.as_ref(), max_frame_size).await
169 {
170 tracing::error!("Session {} error: {:?}", session_id, e);
171 }
172
173 handler.on_session_end(session_id);
175 let _ = event_tx.try_send(ServerEvent::SessionClosed(session_id));
176 });
177 }
178
179 async fn handle_command(&mut self, cmd: ServerCommand) -> bool {
180 match cmd {
181 ServerCommand::Shutdown => {
182 tracing::info!("Server shutdown requested");
183 true
184 }
185 ServerCommand::CloseSession(session_id) => {
186 self.sessions.close_session(session_id);
187 false
188 }
189 ServerCommand::Broadcast(_message) => {
190 false
192 }
193 }
194 }
195}
196
197pub struct ServerHandle {
199 cmd_tx: MpscSender<ServerCommand>,
200 event_rx: MpscReceiver<ServerEvent>,
201}
202
203impl ServerHandle {
204 pub fn shutdown(&self) {
206 let _ = self.cmd_tx.try_send(ServerCommand::Shutdown);
207 }
208
209 pub fn close_session(&self, session_id: u64) {
211 let _ = self
212 .cmd_tx
213 .try_send(ServerCommand::CloseSession(session_id));
214 }
215
216 pub fn broadcast(&self, message: Vec<u8>) {
218 let _ = self.cmd_tx.try_send(ServerCommand::Broadcast(message));
219 }
220
221 pub fn poll_events(&self) -> impl Iterator<Item = ServerEvent> + '_ {
223 std::iter::from_fn(|| self.event_rx.try_recv())
224 }
225}
226
/// Commands sent from a [`ServerHandle`] to the running [`Server`].
///
/// `Clone`, `PartialEq` and `Eq` are derived so callers can duplicate and
/// compare commands (for example, in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServerCommand {
    /// Stop the accept loop; `Server::run` then returns `Ok(())`.
    Shutdown,
    /// Ask the server to close the session with this id.
    CloseSession(u64),
    /// Payload intended for every session.
    /// NOTE(review): the server does not currently deliver it.
    Broadcast(Vec<u8>),
}
237
/// Notifications the [`Server`] publishes; drain them with
/// [`ServerHandle::poll_events`].
///
/// `PartialEq` and `Eq` are derived so consumers can compare events directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServerEvent {
    /// A connection was accepted as a session: session id and peer address.
    SessionCreated(u64, SocketAddr),
    /// The session with this id has ended.
    SessionClosed(u64),
    /// A textual error description.
    Error(String),
}
248
/// Length-prefixed framing: each frame is a little-endian `u32` byte count
/// followed by that many payload bytes.
struct SbeFrameCodec {
    /// Largest declared payload length, in bytes, that the decoder accepts.
    max_frame_size: usize,
}

impl SbeFrameCodec {
    /// Builds a codec whose decoder rejects frames declaring more than
    /// `max_frame_size` payload bytes.
    fn new(max_frame_size: usize) -> Self {
        SbeFrameCodec { max_frame_size }
    }
}
259
260impl Decoder for SbeFrameCodec {
261 type Item = BytesMut;
262 type Error = std::io::Error;
263
264 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
265 if src.len() < 4 {
266 return Ok(None);
267 }
268
269 let length = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
270
271 if length > self.max_frame_size {
272 return Err(std::io::Error::new(
273 std::io::ErrorKind::InvalidData,
274 format!("Frame too large: {} > {}", length, self.max_frame_size),
275 ));
276 }
277
278 if src.len() < 4 + length {
279 src.reserve(4 + length - src.len());
280 return Ok(None);
281 }
282
283 let _ = src.split_to(4);
284 Ok(Some(src.split_to(length)))
285 }
286}
287
288impl<T: AsRef<[u8]>> Encoder<T> for SbeFrameCodec {
289 type Error = std::io::Error;
290
291 fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
292 let data = item.as_ref();
293 let length = data.len() as u32;
294 dst.reserve(4 + data.len());
295 dst.extend_from_slice(&length.to_le_bytes());
296 dst.extend_from_slice(data);
297 Ok(())
298 }
299}
300
301struct SessionResponder {
303 tx: tokio_mpsc::UnboundedSender<Vec<u8>>,
304}
305
306impl Responder for SessionResponder {
307 fn send(&self, message: &[u8]) -> Result<(), SendError> {
308 self.tx.send(message.to_vec()).map_err(|_| SendError {
309 message: "channel closed".to_string(),
310 })
311 }
312
313 fn send_to(&self, _session_id: u64, message: &[u8]) -> Result<(), SendError> {
314 self.send(message)
316 }
317}
318
319async fn handle_session<H: MessageHandler>(
321 session_id: u64,
322 stream: TcpStream,
323 handler: &H,
324 max_frame_size: usize,
325) -> Result<(), std::io::Error> {
326 let codec = SbeFrameCodec::new(max_frame_size);
327 let mut framed = Framed::new(stream, codec);
328
329 let (tx, mut rx) = tokio_mpsc::unbounded_channel::<Vec<u8>>();
331 let responder = SessionResponder { tx };
332
333 loop {
334 tokio::select! {
335 result = framed.next() => {
337 match result {
338 Some(Ok(data)) => {
339 if data.len() >= MessageHeader::ENCODED_LENGTH {
341 let header = MessageHeader::wrap(data.as_ref(), 0);
342 handler.on_message(session_id, &header, data.as_ref(), &responder);
343 } else {
344 handler.on_error(session_id, "Message too short for header");
345 }
346 }
347 Some(Err(e)) => {
348 tracing::error!("Session {} read error: {}", session_id, e);
349 return Err(e);
350 }
351 None => {
352 tracing::info!("Session {} disconnected", session_id);
353 return Ok(());
354 }
355 }
356 }
357
358 Some(msg) = rx.recv() => {
360 if let Err(e) = framed.send(msg).await {
361 tracing::error!("Session {} write error: {}", session_id, e);
362 return Err(e);
363 }
364 }
365 }
366 }
367}