use crate::error::ServerError;
use crate::handler::{MessageHandler, Responder, SendError};
use crate::session::SessionManager;
use ironsbe_channel::mpsc::{MpscChannel, MpscReceiver, MpscSender};
use ironsbe_core::header::MessageHeader;
use ironsbe_transport::traits::{Connection, Listener, Transport};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::{Notify, mpsc as tokio_mpsc};
use tokio_util::sync::CancellationToken;
use tracing::Instrument;
16
/// Shared registry mapping a session id to the unbounded channel on which
/// outbound frames for that session are queued; used for broadcast and
/// cross-session `send_to` routing.
type SessionSenderMap = Arc<RwLock<HashMap<u64, tokio_mpsc::UnboundedSender<Vec<u8>>>>>;
23
/// Builder for a [`Server`]/[`ServerHandle`] pair.
///
/// With the `tcp-tokio` feature enabled, the transport parameter defaults to
/// `ironsbe_transport::DefaultTransport`.
#[cfg(feature = "tcp-tokio")]
pub struct ServerBuilder<H, T: Transport = ironsbe_transport::DefaultTransport> {
    /// Address the listener binds to; defaults to `0.0.0.0:9000` in `new()`.
    bind_addr: SocketAddr,
    /// Explicit transport bind configuration; when `None`, one is derived
    /// from `bind_addr` at build/run time.
    bind_config: Option<T::BindConfig>,
    /// Message handler; must be set before `build()` (which panics otherwise).
    handler: Option<H>,
    /// Connection cap enforced by the accept loop.
    max_connections: usize,
    /// Capacity of the command and event channels created in `build()`.
    channel_capacity: usize,
    // Marker only — the builder stores no transport value.
    _transport: PhantomData<T>,
}
40
/// Builder for a [`Server`]/[`ServerHandle`] pair.
///
/// Without the `tcp-tokio` feature there is no default transport, so the
/// transport type parameter must be supplied explicitly. Field meanings
/// mirror the feature-gated variant above.
#[cfg(not(feature = "tcp-tokio"))]
pub struct ServerBuilder<H, T: Transport> {
    /// Address the listener binds to; defaults to `0.0.0.0:9000` in `new()`.
    bind_addr: SocketAddr,
    /// Explicit transport bind configuration; when `None`, one is derived
    /// from `bind_addr` at build/run time.
    bind_config: Option<T::BindConfig>,
    /// Message handler; must be set before `build()` (which panics otherwise).
    handler: Option<H>,
    /// Connection cap enforced by the accept loop.
    max_connections: usize,
    /// Capacity of the command and event channels created in `build()`.
    channel_capacity: usize,
    // Marker only — the builder stores no transport value.
    _transport: PhantomData<T>,
}
54
55impl<H: MessageHandler, T: Transport> ServerBuilder<H, T> {
56 #[must_use]
58 pub fn new() -> Self {
59 Self {
60 bind_addr: "0.0.0.0:9000".parse().unwrap(),
61 bind_config: None,
62 handler: None,
63 max_connections: 1000,
64 channel_capacity: 4096,
65 _transport: PhantomData,
66 }
67 }
68
69 #[must_use]
75 pub fn bind(mut self, addr: SocketAddr) -> Self {
76 self.bind_addr = addr;
77 self.bind_config = None;
78 self
79 }
80
81 #[must_use]
87 pub fn bind_config(mut self, config: T::BindConfig) -> Self {
88 self.bind_config = Some(config);
89 self
90 }
91
92 #[must_use]
94 pub fn handler(mut self, handler: H) -> Self {
95 self.handler = Some(handler);
96 self
97 }
98
99 #[must_use]
101 pub fn max_connections(mut self, max: usize) -> Self {
102 self.max_connections = max;
103 self
104 }
105
106 #[must_use]
108 pub fn channel_capacity(mut self, capacity: usize) -> Self {
109 self.channel_capacity = capacity;
110 self
111 }
112
113 #[must_use]
118 pub fn build(self) -> (Server<H, T>, ServerHandle) {
119 let handler = self.handler.expect("Handler required");
120 let (cmd_tx, cmd_rx) = MpscChannel::bounded(self.channel_capacity);
121 let (event_tx, event_rx) = MpscChannel::bounded(self.channel_capacity);
122
123 let cmd_notify = Arc::new(Notify::new());
124
125 let server = Server {
126 bind_addr: self.bind_addr,
127 bind_config: Some(
128 self.bind_config
129 .unwrap_or_else(|| T::BindConfig::from(self.bind_addr)),
130 ),
131 handler: Arc::new(handler),
132 max_connections: self.max_connections,
133 cmd_tx: cmd_tx.clone(),
134 cmd_rx,
135 event_tx,
136 sessions: SessionManager::new(),
137 cmd_notify: Arc::clone(&cmd_notify),
138 shutdown_token: CancellationToken::new(),
139 session_tokens: HashMap::new(),
140 session_senders: Arc::new(RwLock::new(HashMap::new())),
141 _transport: PhantomData,
142 };
143
144 let handle = ServerHandle {
145 cmd_tx,
146 event_rx,
147 cmd_notify,
148 };
149
150 (server, handle)
151 }
152}
153
impl<H: MessageHandler, T: Transport> Default for ServerBuilder<H, T> {
    /// Equivalent to [`ServerBuilder::new`]; same defaults.
    fn default() -> Self {
        Self::new()
    }
}
159
#[cfg(feature = "tcp-tokio")]
impl<H: MessageHandler> ServerBuilder<H> {
    /// Convenience constructor selecting the crate's default TCP transport.
    #[must_use]
    pub fn with_default_transport() -> Self {
        Self::default()
    }

    /// Caps the maximum SBE frame size accepted by the TCP transport,
    /// layering the setting onto any previously supplied bind config.
    #[must_use]
    pub fn max_frame_size(mut self, size: usize) -> Self {
        let base = match self.bind_config.take() {
            Some(existing) => existing,
            None => ironsbe_transport::tcp::TcpServerConfig::new(self.bind_addr),
        };
        self.bind_config = Some(base.max_frame_size(size));
        self
    }
}
188
/// SBE server: accepts transport connections, spawns one task per session,
/// and drives a command/event loop.
///
/// With the `tcp-tokio` feature enabled, the transport parameter defaults to
/// `ironsbe_transport::DefaultTransport`.
#[cfg(feature = "tcp-tokio")]
#[allow(dead_code)]
pub struct Server<H, T: Transport = ironsbe_transport::DefaultTransport> {
    /// Requested bind address (the OS may assign a different effective one).
    bind_addr: SocketAddr,
    /// Bind configuration consumed (taken) by `run()`.
    bind_config: Option<T::BindConfig>,
    /// Shared handler; cloned into every session task.
    handler: Arc<H>,
    /// Accept loop rejects new connections at or beyond this session count.
    max_connections: usize,
    /// Command sender; cloned into session tasks so they can request their
    /// own teardown via `CloseSession`.
    cmd_tx: MpscSender<ServerCommand>,
    /// Command receiver drained by the run loop when `cmd_notify` fires.
    cmd_rx: MpscReceiver<ServerCommand>,
    /// Event sink observed through `ServerHandle::poll_events`.
    event_tx: MpscSender<ServerEvent>,
    /// Session id allocation and bookkeeping.
    sessions: SessionManager,
    /// Wakeup shared with `ServerHandle` and session tasks.
    cmd_notify: Arc<Notify>,
    /// Parent cancellation token; cancelling it cancels every child
    /// session token.
    shutdown_token: CancellationToken,
    /// Per-session child tokens, keyed by session id.
    session_tokens: HashMap<u64, CancellationToken>,
    /// Per-session outbound byte channels (broadcast / cross-session sends).
    session_senders: SessionSenderMap,
    // Marker only — the server stores no transport value.
    _transport: PhantomData<T>,
}
225
/// SBE server: accepts transport connections, spawns one task per session,
/// and drives a command/event loop.
///
/// Without the `tcp-tokio` feature there is no default transport; field
/// meanings mirror the feature-gated variant above.
#[cfg(not(feature = "tcp-tokio"))]
#[allow(dead_code)]
pub struct Server<H, T: Transport> {
    /// Requested bind address (the OS may assign a different effective one).
    bind_addr: SocketAddr,
    /// Bind configuration consumed (taken) by `run()`.
    bind_config: Option<T::BindConfig>,
    /// Shared handler; cloned into every session task.
    handler: Arc<H>,
    /// Accept loop rejects new connections at or beyond this session count.
    max_connections: usize,
    /// Command sender; cloned into session tasks for self-teardown.
    cmd_tx: MpscSender<ServerCommand>,
    /// Command receiver drained by the run loop when `cmd_notify` fires.
    cmd_rx: MpscReceiver<ServerCommand>,
    /// Event sink observed through `ServerHandle::poll_events`.
    event_tx: MpscSender<ServerEvent>,
    /// Session id allocation and bookkeeping.
    sessions: SessionManager,
    /// Wakeup shared with `ServerHandle` and session tasks.
    cmd_notify: Arc<Notify>,
    /// Parent cancellation token; cancelling it cancels every child
    /// session token.
    shutdown_token: CancellationToken,
    /// Per-session child tokens, keyed by session id.
    session_tokens: HashMap<u64, CancellationToken>,
    /// Per-session outbound byte channels (broadcast / cross-session sends).
    session_senders: SessionSenderMap,
    // Marker only — the server stores no transport value.
    _transport: PhantomData<T>,
}
250
251impl<H, T> Server<H, T>
252where
253 H: MessageHandler + Send + Sync + 'static,
254 T: Transport,
255 T::Connection: Send + 'static,
256{
257 pub async fn run(&mut self) -> Result<(), ServerError> {
264 let bind_config = self
265 .bind_config
266 .take()
267 .unwrap_or_else(|| T::BindConfig::from(self.bind_addr));
268 let mut listener = T::bind_with(bind_config)
269 .await
270 .map_err(|e| ServerError::Io(std::io::Error::other(e)))?;
271 let effective_addr = listener.local_addr().unwrap_or(self.bind_addr);
272 tracing::info!("Server listening on {}", effective_addr);
273 let _ = self
276 .event_tx
277 .try_send(ServerEvent::Listening(effective_addr));
278
279 loop {
280 tokio::select! {
281 result = listener.accept() => {
282 match result {
283 Ok(conn) => {
284 let addr = conn.peer_addr().unwrap_or_else(
285 |_| "0.0.0.0:0".parse().unwrap()
286 );
287 self.handle_connection(conn, addr).await;
288 }
289 Err(e) => {
290 tracing::error!("Accept error: {}", e);
291 }
292 }
293 }
294
295 _ = self.cmd_notify.notified() => {
296 while let Some(cmd) = self.cmd_rx.try_recv() {
297 if self.handle_command(cmd).await {
298 return Ok(());
299 }
300 }
301 }
302 }
303 }
304 }
305
306 async fn handle_connection(&mut self, conn: T::Connection, addr: SocketAddr) {
307 if self.sessions.count() >= self.max_connections {
308 tracing::warn!("Max connections reached, rejecting {}", addr);
309 return;
310 }
311
312 let session_id = self.sessions.create_session(addr);
313 let handler = Arc::clone(&self.handler);
314 let event_tx = self.event_tx.clone();
315 let cmd_tx = self.cmd_tx.clone();
320 let cmd_notify = Arc::clone(&self.cmd_notify);
321
322 let session_token = self.shutdown_token.child_token();
326 self.session_tokens
327 .insert(session_id, session_token.clone());
328
329 let (out_tx, out_rx) = tokio_mpsc::unbounded_channel::<Vec<u8>>();
335 self.session_senders
336 .write()
337 .insert(session_id, out_tx.clone());
338 let senders = Arc::clone(&self.session_senders);
339
340 handler.on_session_start(session_id);
341 let _ = event_tx.try_send(ServerEvent::SessionCreated(session_id, addr));
342
343 let span = tracing::info_span!("sbe_session", session_id, %addr);
347 tokio::spawn(async move {
348 let _guard = span.enter();
349 tracing::info!("connected");
350
351 if let Err(e) = handle_session(
352 session_id,
353 conn,
354 handler.as_ref(),
355 session_token,
356 out_tx,
357 out_rx,
358 senders,
359 )
360 .await
361 {
362 tracing::error!(error = %e, "session error");
363 }
364
365 tracing::info!("disconnected");
366 handler.on_session_end(session_id);
367 let _ = event_tx.try_send(ServerEvent::SessionClosed(session_id));
368 let _ = cmd_tx.try_send(ServerCommand::CloseSession(session_id));
369 cmd_notify.notify_one();
370 });
371 }
372
373 async fn handle_command(&mut self, cmd: ServerCommand) -> bool {
374 match cmd {
375 ServerCommand::Shutdown => {
376 tracing::info!("Server shutdown requested");
377 self.shutdown_token.cancel();
382 self.session_tokens.clear();
383 self.session_senders.write().clear();
384 true
385 }
386 ServerCommand::CloseSession(session_id) => {
387 if let Some(token) = self.session_tokens.remove(&session_id) {
393 token.cancel();
394 }
395 self.session_senders.write().remove(&session_id);
396 self.sessions.close_session(session_id);
397 false
398 }
399 ServerCommand::Broadcast(message) => {
400 self.session_senders
406 .write()
407 .retain(|_, sender| sender.send(message.clone()).is_ok());
408 false
409 }
410 }
411 }
412}
413
/// Control handle paired with a [`Server`]: queues commands, wakes the
/// server's run loop, and drains server events.
pub struct ServerHandle {
    /// Command channel into the server's run loop.
    cmd_tx: MpscSender<ServerCommand>,
    /// Event channel out of the server, drained via `poll_events`.
    event_rx: MpscReceiver<ServerEvent>,
    /// Wakeup shared with the server's select loop.
    cmd_notify: Arc<Notify>,
}
420
421impl ServerHandle {
422 pub(crate) fn new(
428 cmd_tx: MpscSender<ServerCommand>,
429 event_rx: MpscReceiver<ServerEvent>,
430 cmd_notify: Arc<Notify>,
431 ) -> Self {
432 Self {
433 cmd_tx,
434 event_rx,
435 cmd_notify,
436 }
437 }
438
439 pub fn shutdown(&self) {
441 let _ = self.cmd_tx.try_send(ServerCommand::Shutdown);
442 self.cmd_notify.notify_one();
443 }
444
445 pub fn close_session(&self, session_id: u64) {
447 let _ = self
448 .cmd_tx
449 .try_send(ServerCommand::CloseSession(session_id));
450 self.cmd_notify.notify_one();
451 }
452
453 pub fn broadcast(&self, message: Vec<u8>) {
455 let _ = self.cmd_tx.try_send(ServerCommand::Broadcast(message));
456 self.cmd_notify.notify_one();
457 }
458
459 pub fn poll_events(&self) -> impl Iterator<Item = ServerEvent> + '_ {
461 std::iter::from_fn(|| self.event_rx.try_recv())
462 }
463}
464
/// Control messages consumed by the server's run loop.
#[derive(Debug, Clone)]
pub enum ServerCommand {
    /// Stop the run loop, cancelling every session.
    Shutdown,
    /// Cancel and deregister the session with the given id.
    CloseSession(u64),
    /// Queue the payload for delivery to every live session.
    Broadcast(Vec<u8>),
}
475
/// Notifications emitted by the server, drained via
/// `ServerHandle::poll_events`.
#[derive(Debug, Clone)]
pub enum ServerEvent {
    /// The listener is bound; carries the effective local address.
    Listening(SocketAddr),
    /// A session was created for the given peer address.
    SessionCreated(u64, SocketAddr),
    /// The session with the given id ended.
    SessionClosed(u64),
    /// A human-readable error report.
    Error(String),
}
490
/// Responder handed to `MessageHandler::on_message`: queues replies on the
/// owning session's outbound channel, or routes to another session by id.
struct SessionResponder {
    /// Outbound queue of the session this responder belongs to.
    tx: tokio_mpsc::UnboundedSender<Vec<u8>>,
    /// Shared registry used by `send_to` for cross-session routing.
    senders: SessionSenderMap,
    /// Id of the owning session (used in error messages).
    session_id: u64,
}
506
507impl Responder for SessionResponder {
508 fn send(&self, message: &[u8]) -> Result<(), SendError> {
509 self.tx.send(message.to_vec()).map_err(|_| SendError {
510 message: format!("session {} channel closed", self.session_id),
511 })
512 }
513
514 fn send_to(&self, session_id: u64, message: &[u8]) -> Result<(), SendError> {
515 let senders = self.senders.read();
516 match senders.get(&session_id) {
517 Some(sender) => sender.send(message.to_vec()).map_err(|_| SendError {
518 message: format!("session {session_id} channel closed"),
519 }),
520 None => Err(SendError {
521 message: format!("unknown session {session_id}"),
522 }),
523 }
524 }
525}
526
/// Per-session I/O loop: reads frames from `conn`, dispatches them to the
/// handler, and writes queued outbound frames, until the peer disconnects,
/// an I/O error occurs, or `session_token` is cancelled.
///
/// Returns `Ok(())` on clean disconnect or cancellation; `Err` on a
/// transport read/write failure.
async fn handle_session<H, C>(
    session_id: u64,
    mut conn: C,
    handler: &H,
    session_token: CancellationToken,
    out_tx: tokio_mpsc::UnboundedSender<Vec<u8>>,
    mut out_rx: tokio_mpsc::UnboundedReceiver<Vec<u8>>,
    senders: SessionSenderMap,
) -> Result<(), std::io::Error>
where
    H: MessageHandler,
    C: Connection,
{
    // The responder keeps `out_tx` alive for the whole loop, so the
    // `out_rx.recv()` arm below can never see a closed channel.
    let responder = SessionResponder {
        tx: out_tx,
        senders,
        session_id,
    };

    loop {
        tokio::select! {
            result = conn.recv() => {
                match result {
                    Ok(Some(data)) => {
                        // Only dispatch frames large enough to contain a
                        // complete SBE message header.
                        if data.len() >= MessageHeader::ENCODED_LENGTH {
                            let header = MessageHeader::wrap(data.as_ref(), 0);
                            handler.on_message(session_id, &header, data.as_ref(), &responder);
                        } else {
                            handler.on_error(session_id, "Message too short for header");
                        }
                    }
                    Ok(None) => {
                        // Peer closed the connection cleanly.
                        return Ok(());
                    }
                    Err(e) => {
                        tracing::error!(error = %e, "read error");
                        return Err(std::io::Error::other(e));
                    }
                }
            }

            Some(msg) = out_rx.recv() => {
                // Inner select so a cancellation arriving while a write is
                // in flight still terminates the session promptly.
                // NOTE(review): if cancellation wins this race the in-flight
                // message is dropped — presumably acceptable on teardown.
                tokio::select! {
                    send_result = conn.send(&msg) => {
                        if let Err(e) = send_result {
                            tracing::error!(error = %e, "write error");
                            return Err(std::io::Error::other(e));
                        }
                    }
                    _ = session_token.cancelled() => {
                        tracing::debug!("session cancelled mid-send");
                        return Ok(());
                    }
                }
            }

            _ = session_token.cancelled() => {
                tracing::debug!("session cancelled");
                return Ok(());
            }
        }
    }
}
617
#[cfg(all(test, feature = "tcp-tokio"))]
mod tests {
    use super::*;

    // Pin the transport parameter so tests read uniformly.
    type DefaultBuilder<H> = ServerBuilder<H, ironsbe_transport::DefaultTransport>;

    // Minimal no-op handler: only the trait's required method is provided.
    struct TestHandler;

    impl MessageHandler for TestHandler {
        fn on_message(
            &self,
            _session_id: u64,
            _header: &MessageHeader,
            _data: &[u8],
            _responder: &dyn Responder,
        ) {
        }
    }

    // --- Builder construction smoke tests (compile/construct only) ---

    #[test]
    fn test_server_builder_new() {
        let builder = DefaultBuilder::<TestHandler>::new();
        let _ = builder;
    }

    #[test]
    fn test_server_builder_default() {
        let builder = DefaultBuilder::<TestHandler>::default();
        let _ = builder;
    }

    #[test]
    fn test_server_builder_bind() {
        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let builder = DefaultBuilder::<TestHandler>::new().bind(addr);
        let _ = builder;
    }

    #[test]
    fn test_server_builder_handler() {
        let builder = DefaultBuilder::<TestHandler>::new().handler(TestHandler);
        let _ = builder;
    }

    #[test]
    fn test_server_builder_max_connections() {
        let builder = DefaultBuilder::<TestHandler>::new().max_connections(500);
        let _ = builder;
    }

    #[test]
    fn test_server_builder_channel_capacity() {
        let builder = DefaultBuilder::<TestHandler>::new().channel_capacity(8192);
        let _ = builder;
    }

    #[test]
    fn test_server_builder_build() {
        let (_server, _handle) = DefaultBuilder::<TestHandler>::new()
            .handler(TestHandler)
            .build();
    }

    // --- Derive coverage for command/event types ---

    #[test]
    fn test_server_command_debug() {
        let cmd = ServerCommand::Shutdown;
        let debug_str = format!("{:?}", cmd);
        assert!(debug_str.contains("Shutdown"));

        let cmd2 = ServerCommand::CloseSession(42);
        let debug_str2 = format!("{:?}", cmd2);
        assert!(debug_str2.contains("CloseSession"));

        let cmd3 = ServerCommand::Broadcast(vec![1, 2, 3]);
        let debug_str3 = format!("{:?}", cmd3);
        assert!(debug_str3.contains("Broadcast"));
    }

    #[test]
    fn test_server_event_clone_debug() {
        let addr: SocketAddr = "127.0.0.1:9000".parse().unwrap();
        let event = ServerEvent::SessionCreated(1, addr);
        let cloned = event.clone();
        let _ = cloned;

        let debug_str = format!("{:?}", event);
        assert!(debug_str.contains("SessionCreated"));

        let event2 = ServerEvent::SessionClosed(1);
        let debug_str2 = format!("{:?}", event2);
        assert!(debug_str2.contains("SessionClosed"));

        let event3 = ServerEvent::Error("test error".to_string());
        let debug_str3 = format!("{:?}", event3);
        assert!(debug_str3.contains("Error"));
    }

    // --- ServerHandle smoke tests (commands are queued, never executed
    //     because the server's run loop is not running) ---

    #[test]
    fn test_server_handle_shutdown() {
        let (_server, handle) = DefaultBuilder::<TestHandler>::new()
            .handler(TestHandler)
            .build();
        handle.shutdown();
    }

    #[test]
    fn test_server_handle_close_session() {
        let (_server, handle) = DefaultBuilder::<TestHandler>::new()
            .handler(TestHandler)
            .build();
        handle.close_session(1);
    }

    #[test]
    fn test_server_handle_broadcast() {
        let (_server, handle) = DefaultBuilder::<TestHandler>::new()
            .handler(TestHandler)
            .build();
        handle.broadcast(vec![1, 2, 3]);
    }

    // --- Cancellation-token lifecycle ---

    #[test]
    fn test_server_starts_with_uncancelled_shutdown_token() {
        let (server, _handle) = DefaultBuilder::<TestHandler>::new()
            .handler(TestHandler)
            .build();

        assert!(
            !server.shutdown_token.is_cancelled(),
            "fresh server should have an uncancelled shutdown_token"
        );
        assert!(
            server.session_tokens.is_empty(),
            "fresh server should have an empty session_tokens registry"
        );
    }

    #[tokio::test]
    async fn test_shutdown_handler_cancels_every_child_token() {
        let (mut server, _handle) = DefaultBuilder::<TestHandler>::new()
            .handler(TestHandler)
            .build();

        // Simulate two live sessions registered with the server.
        let child_a = server.shutdown_token.child_token();
        let child_b = server.shutdown_token.child_token();
        server.session_tokens.insert(1, child_a.clone());
        server.session_tokens.insert(2, child_b.clone());

        let exited = server.handle_command(ServerCommand::Shutdown).await;

        assert!(exited, "Shutdown must signal the run loop to exit");
        assert!(
            server.shutdown_token.is_cancelled(),
            "parent token must be cancelled after Shutdown"
        );
        assert!(
            child_a.is_cancelled() && child_b.is_cancelled(),
            "every child token must be cancelled by the parent"
        );
        assert!(
            server.session_tokens.is_empty(),
            "session_tokens registry must be drained on Shutdown"
        );
    }

    #[tokio::test]
    async fn test_close_session_handler_cancels_only_that_token() {
        let (mut server, _handle) = DefaultBuilder::<TestHandler>::new()
            .handler(TestHandler)
            .build();

        let child_a = server.shutdown_token.child_token();
        let child_b = server.shutdown_token.child_token();
        server.session_tokens.insert(1, child_a.clone());
        server.session_tokens.insert(2, child_b.clone());

        let exited = server.handle_command(ServerCommand::CloseSession(1)).await;

        assert!(!exited, "CloseSession must not stop the run loop");
        assert!(
            child_a.is_cancelled(),
            "the targeted child token must be cancelled"
        );
        assert!(
            !child_b.is_cancelled(),
            "untargeted siblings must remain live"
        );
        assert!(
            !server.session_tokens.contains_key(&1),
            "the closed session entry must be removed from the registry"
        );
        assert!(
            server.session_tokens.contains_key(&2),
            "untargeted entries must stay in the registry"
        );
    }

    #[tokio::test]
    async fn test_close_session_handler_unknown_id_is_noop() {
        let (mut server, _handle) = DefaultBuilder::<TestHandler>::new()
            .handler(TestHandler)
            .build();

        let exited = server
            .handle_command(ServerCommand::CloseSession(999))
            .await;

        assert!(!exited);
        assert!(server.session_tokens.is_empty());
    }

    // --- Broadcast command behavior ---

    #[tokio::test]
    async fn test_broadcast_handler_with_no_sessions_is_noop() {
        let (mut server, _handle) = DefaultBuilder::<TestHandler>::new()
            .handler(TestHandler)
            .build();

        let exited = server
            .handle_command(ServerCommand::Broadcast(b"anything".to_vec()))
            .await;

        assert!(!exited);
        assert!(server.session_senders.read().is_empty());
    }

    #[tokio::test]
    async fn test_broadcast_handler_pushes_to_every_session() {
        let (mut server, _handle) = DefaultBuilder::<TestHandler>::new()
            .handler(TestHandler)
            .build();

        // Register two fake sessions by inserting raw channel senders.
        let (tx1, mut rx1) = tokio_mpsc::unbounded_channel::<Vec<u8>>();
        let (tx2, mut rx2) = tokio_mpsc::unbounded_channel::<Vec<u8>>();
        {
            let mut senders = server.session_senders.write();
            senders.insert(1, tx1);
            senders.insert(2, tx2);
        }

        let payload = b"hello-broadcast".to_vec();
        let exited = server
            .handle_command(ServerCommand::Broadcast(payload.clone()))
            .await;

        assert!(!exited);
        match rx1.try_recv() {
            Ok(bytes) => assert_eq!(bytes, payload),
            other => panic!("session 1 did not receive broadcast: {other:?}"),
        }
        match rx2.try_recv() {
            Ok(bytes) => assert_eq!(bytes, payload),
            other => panic!("session 2 did not receive broadcast: {other:?}"),
        }
        assert_eq!(server.session_senders.read().len(), 2);
    }

    #[tokio::test]
    async fn test_broadcast_handler_drops_closed_senders() {
        let (mut server, _handle) = DefaultBuilder::<TestHandler>::new()
            .handler(TestHandler)
            .build();

        // One live session, one whose receiver has already been dropped.
        let (tx_live, mut rx_live) = tokio_mpsc::unbounded_channel::<Vec<u8>>();
        let (tx_dead, rx_dead) = tokio_mpsc::unbounded_channel::<Vec<u8>>();
        drop(rx_dead); {
            let mut senders = server.session_senders.write();
            senders.insert(1, tx_live);
            senders.insert(2, tx_dead);
        }

        let _ = server
            .handle_command(ServerCommand::Broadcast(b"ping".to_vec()))
            .await;

        match rx_live.try_recv() {
            Ok(bytes) => assert_eq!(bytes, b"ping"),
            other => panic!("live session did not receive broadcast: {other:?}"),
        }
        // The dead sender must have been retained out of the registry.
        let senders = server.session_senders.read();
        assert_eq!(senders.len(), 1);
        assert!(senders.contains_key(&1));
        assert!(!senders.contains_key(&2));
    }

    #[tokio::test]
    async fn test_close_session_handler_removes_session_sender() {
        let (mut server, _handle) = DefaultBuilder::<TestHandler>::new()
            .handler(TestHandler)
            .build();

        let (tx1, _rx1) = tokio_mpsc::unbounded_channel::<Vec<u8>>();
        let (tx2, _rx2) = tokio_mpsc::unbounded_channel::<Vec<u8>>();
        {
            let mut senders = server.session_senders.write();
            senders.insert(1, tx1);
            senders.insert(2, tx2);
        }

        let _ = server.handle_command(ServerCommand::CloseSession(1)).await;

        let senders = server.session_senders.read();
        assert!(!senders.contains_key(&1));
        assert!(senders.contains_key(&2));
    }

    // --- SessionResponder routing ---

    #[test]
    fn test_session_responder_send_to_unknown_session_returns_err() {
        let senders: SessionSenderMap = Arc::new(RwLock::new(HashMap::new()));
        let (tx, _rx) = tokio_mpsc::unbounded_channel::<Vec<u8>>();
        let responder = SessionResponder {
            tx,
            senders,
            session_id: 1,
        };

        let result = responder.send_to(99, b"payload");
        match result {
            Err(err) => assert!(
                err.message.contains("unknown session 99"),
                "unexpected error: {err}"
            ),
            Ok(()) => panic!("send_to on unknown session must fail"),
        }
    }

    #[test]
    fn test_session_responder_send_to_routes_to_target() {
        let senders: SessionSenderMap = Arc::new(RwLock::new(HashMap::new()));
        let (tx_self, mut rx_self) = tokio_mpsc::unbounded_channel::<Vec<u8>>();
        let (tx_other, mut rx_other) = tokio_mpsc::unbounded_channel::<Vec<u8>>();
        senders.write().insert(1, tx_self.clone());
        senders.write().insert(2, tx_other);

        let responder = SessionResponder {
            tx: tx_self,
            senders,
            session_id: 1,
        };

        let result = responder.send_to(2, b"cross-routed");
        assert!(result.is_ok(), "send_to should succeed for a live target");

        match rx_other.try_recv() {
            Ok(bytes) => assert_eq!(bytes, b"cross-routed"),
            other => panic!("target session did not receive payload: {other:?}"),
        }
        assert!(
            rx_self.try_recv().is_err(),
            "send_to must not fall through to the sender's own session"
        );
    }

    #[test]
    fn test_session_responder_send_to_closed_channel_returns_err() {
        let senders: SessionSenderMap = Arc::new(RwLock::new(HashMap::new()));
        let (tx_self, _rx_self) = tokio_mpsc::unbounded_channel::<Vec<u8>>();
        let (tx_dead, rx_dead) = tokio_mpsc::unbounded_channel::<Vec<u8>>();
        drop(rx_dead);
        senders.write().insert(1, tx_self.clone());
        senders.write().insert(2, tx_dead);

        let responder = SessionResponder {
            tx: tx_self,
            senders,
            session_id: 1,
        };

        let result = responder.send_to(2, b"lost");
        match result {
            Err(err) => assert!(
                err.message.contains("channel closed"),
                "unexpected error: {err}"
            ),
            Ok(()) => panic!("send_to on closed channel must fail"),
        }
    }
}