1use std::{collections::HashMap, time::Duration};
6
7use display_error_chain::ErrorChainExt;
8use parking_lot::Mutex;
9use tokio::sync::{self, oneshot};
10use tokio_util::sync::CancellationToken;
12use tracing::{Instrument, debug};
13
14use slim_auth::traits::{TokenProvider, Verifier};
15use slim_datapath::{
16 api::{
17 CommandPayload, Content, ProtoMessage as Message, ProtoSessionMessageType,
18 ProtoSessionType, SlimHeader,
19 },
20 messages::{Name, utils::SlimHeaderFlags},
21};
22
23use crate::{
25 MessageDirection, SessionError, Transmitter,
26 common::SessionMessage,
27 completion_handle::CompletionHandle,
28 controller_sender::{ControllerSender, PING_INTERVAL},
29 session_builder::{ForController, SessionBuilder},
30 session_config::SessionConfig,
31 session_settings::SessionSettings,
32 traits::{MessageHandler, ProcessingState},
33};
34
/// Handle to a running session.
///
/// Holds the sending side of the channel that feeds the session's processing
/// loop (spawned in `from_parts`) together with the token and join handle used
/// to stop and await that loop.
pub struct SessionController {
    /// Session identifier, stamped as `session_id` on messages built by this controller.
    pub(crate) id: u32,

    /// Local name used as the source of outgoing messages.
    pub(crate) source: Name,

    /// Remote name this session was created for (exposed through `dst()`).
    pub(crate) destination: Name,

    /// Session configuration (session type, initiator flag, metadata, ...).
    pub(crate) config: SessionConfig,

    /// Sender feeding `SessionMessage`s into the processing loop.
    tx_controller: sync::mpsc::Sender<SessionMessage>,

    /// Cancelled by `close()` and on drop to ask the processing loop to drain and exit.
    pub(crate) cancellation_token: CancellationToken,

    /// Join handle of the processing task; the first `close()` takes it, leaving `None`.
    handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
}
57
58impl SessionController {
59 pub fn builder<P, V>() -> SessionBuilder<P, V, ForController>
61 where
62 P: TokenProvider + Send + Sync + Clone + 'static,
63 V: Verifier + Send + Sync + Clone + 'static,
64 {
65 SessionBuilder::for_controller()
66 }
67
68 #[allow(clippy::too_many_arguments)]
70 pub(crate) fn from_parts<I, P, V, M>(
71 id: u32,
72 source: Name,
73 destination: Name,
74 config: SessionConfig,
75 settings: SessionSettings<P, V, M>,
76 tx: sync::mpsc::Sender<SessionMessage>,
77 rx: sync::mpsc::Receiver<SessionMessage>,
78 inner: I,
79 ) -> Self
80 where
81 I: MessageHandler + Send + Sync + 'static,
82 P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static,
83 V: slim_auth::traits::Verifier + Send + Sync + Clone + 'static,
84 M: crate::subscription_manager::SubscriptionOps,
85 {
86 let cancellation_token = CancellationToken::new();
88
89 let span = tracing::debug_span!(
91 parent: None,
92 "session_controller_processing_loop",
93 session_id = id,
94 service_id = %settings.service_id,
95 source = %source,
96 destination = %destination,
97 session_type = ?config.session_type
98 );
99
100 let handle = tokio::spawn(
101 Self::processing_loop(inner, rx, cancellation_token.clone(), settings).instrument(span),
102 );
103
104 Self {
105 id,
106 source,
107 destination,
108 config,
109 tx_controller: tx,
110 cancellation_token,
111 handle: Mutex::new(Some(handle)),
112 }
113 }
114
115 fn enter_draining_state<P, V, M>(
117 shutdown_deadline: &mut std::pin::Pin<&mut tokio::time::Sleep>,
118 settings: &SessionSettings<P, V, M>,
119 ) where
120 P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static,
121 V: slim_auth::traits::Verifier + Send + Sync + Clone + 'static,
122 M: crate::subscription_manager::SubscriptionOps,
123 {
124 let shutdown_timeout = settings
125 .graceful_shutdown_timeout
126 .unwrap_or(Duration::from_secs(60));
127 shutdown_deadline
128 .as_mut()
129 .reset(tokio::time::Instant::now() + shutdown_timeout);
130 }
131
132 async fn processing_loop<P, V, M>(
133 mut inner: impl MessageHandler + 'static,
134 mut rx: sync::mpsc::Receiver<SessionMessage>,
135 cancellation_token: CancellationToken,
136 settings: SessionSettings<P, V, M>,
137 ) where
138 P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static,
139 V: slim_auth::traits::Verifier + Send + Sync + Clone + 'static,
140 M: crate::subscription_manager::SubscriptionOps,
141 {
142 let mut shutdown_deadline = std::pin::pin!(tokio::time::sleep(Duration::MAX));
144
145 if let Err(e) = inner.init().await {
147 tracing::error!(error = %e.chain(), "error during initialization of session");
148 }
149
150 loop {
151 tokio::select! {
152 _ = cancellation_token.cancelled(), if inner.processing_state() == ProcessingState::Active => {
153 let shutdown_timeout = settings.graceful_shutdown_timeout
155 .unwrap_or(Duration::from_secs(60)); debug!("consuming pending messages before entering draining state");
159 while let Ok(msg) = rx.try_recv() {
160 if let Err(e) = inner.on_message(msg).await {
161 tracing::error!(error = %e.chain(), "error processing message during draining - close immediately.");
162 break;
163 }
164 }
165
166 if let Err(e) = inner.on_message(SessionMessage::StartDrain {
168 grace_period: shutdown_timeout
169 }).await {
170 tracing::error!(error = %e.chain(), "error during start drain");
171 break;
172 }
173
174 Self::enter_draining_state(&mut shutdown_deadline, &settings);
175
176 debug!("cancellation requested, entering draining state");
177 }
178 _ = &mut shutdown_deadline => {
179 debug!("graceful shutdown timeout reached, forcing exit");
180 break;
181 }
182 msg = rx.recv() => {
183 match msg {
184 Some(session_message) => {
185 if let SessionMessage::GetParticipantsList { tx } = session_message {
187 let participants_list = inner.participants_list();
188 let _ = tx.send(participants_list);
189 continue;
190 }
191
192 let draining = inner.processing_state() == ProcessingState::Draining;
193
194 if draining && matches!(session_message, SessionMessage::OnMessage { direction: MessageDirection::South, .. }) {
196 tracing::debug!("session is draining, rejecting new messages from application");
197 if let SessionMessage::OnMessage { ack_tx: Some(ack_tx), .. } = session_message {
198 let _ = ack_tx.send(Err(SessionError::SessionDrainingDrop));
199 }
200 continue;
201 }
202
203 if let Err(e) = inner.on_message(session_message).await {
204 debug!(
205 error=%e,
206 "Error processing message{}",
207 if draining { " during graceful shutdown" } else { "" }
208 );
209 if draining {
210 debug!("Exiting processing loop due to error while draining");
211 break;
212 }
213 } else {
214 if !draining && inner.processing_state() == ProcessingState::Draining {
217 debug!("internal component requested draining, entering draining state");
218 Self::enter_draining_state(&mut shutdown_deadline, &settings);
219 }
220 }
221 }
222 None => {
223 debug!("Session channel closed, no more messages can arrive - exiting processing loop");
224 break;
225 }
226 }
227 }
228 }
229
230 if inner.processing_state() == ProcessingState::Draining && !inner.needs_drain() {
232 debug!("draining complete, exiting processing loop");
233 break;
234 }
235 }
236
237 if let Err(e) = inner.on_shutdown().await {
239 tracing::error!(error = %e.chain(), "error during shutdown of session");
240 }
241 }
242
243 pub fn id(&self) -> u32 {
245 self.id
246 }
247
248 pub fn source(&self) -> &Name {
249 &self.source
250 }
251
252 pub fn dst(&self) -> &Name {
253 &self.destination
254 }
255
256 pub fn session_type(&self) -> ProtoSessionType {
257 self.config.session_type
258 }
259
260 pub fn metadata(&self) -> HashMap<String, String> {
261 self.config.metadata.clone()
262 }
263
264 pub fn session_config(&self) -> SessionConfig {
265 self.config.clone()
266 }
267
268 pub fn is_initiator(&self) -> bool {
269 self.config.initiator
270 }
271
272 pub async fn participants_list(&self) -> Result<Vec<Name>, SessionError> {
273 let (tx, rx) = oneshot::channel();
274
275 self.tx_controller
277 .send(SessionMessage::GetParticipantsList { tx })
278 .await
279 .map_err(|_| SessionError::ParticipantsListQueryFailed)?;
280
281 rx.await
283 .map_err(|_| SessionError::ParticipantsListQueryFailed)
284 }
285
286 async fn on_message(
287 &self,
288 message: Message,
289 direction: MessageDirection,
290 ack_tx: Option<oneshot::Sender<Result<(), SessionError>>>,
291 ) -> Result<(), SessionError> {
292 self.tx_controller
293 .send(SessionMessage::OnMessage {
294 message,
295 direction,
296 ack_tx,
297 })
298 .await
299 .map_err(|_e| SessionError::SessionControllerSendFailed)
300 }
301
302 pub async fn on_message_from_app(
304 &self,
305 message: Message,
306 ) -> Result<CompletionHandle, SessionError> {
307 let (ack_tx, ack_rx) = oneshot::channel();
308 self.on_message(message, MessageDirection::South, Some(ack_tx))
309 .await?;
310
311 let ret = CompletionHandle::from_oneshot_receiver(ack_rx);
312
313 Ok(ret)
314 }
315
316 pub async fn on_message_from_slim(&self, message: Message) -> Result<(), SessionError> {
318 self.on_message(message, MessageDirection::North, None)
319 .await
320 }
321
322 pub async fn on_error_message_from_slim(
324 &self,
325 error: SessionError,
326 ) -> Result<(), SessionError> {
327 self.tx_controller
328 .send(SessionMessage::MessageError { error })
329 .await
330 .map_err(|_e| SessionError::SessionControllerSendFailed)
331 }
332
333 pub fn close(&self) -> Result<tokio::task::JoinHandle<()>, SessionError> {
334 self.cancellation_token.cancel();
335
336 self.handle
337 .lock()
338 .take()
339 .ok_or(SessionError::SessionAlreadyClosed)
340 }
341
342 pub async fn publish_message(
343 &self,
344 message: Message,
345 ) -> Result<CompletionHandle, SessionError> {
346 self.on_message_from_app(message).await
347 }
348
349 pub async fn publish_to(
351 &self,
352 name: &Name,
353 forward_to: u64,
354 blob: Vec<u8>,
355 payload_type: Option<String>,
356 metadata: Option<HashMap<String, String>>,
357 ) -> Result<CompletionHandle, SessionError> {
358 self.publish_with_flags(
359 name,
360 SlimHeaderFlags::default().with_forward_to(forward_to),
361 blob,
362 payload_type,
363 metadata,
364 )
365 .await
366 }
367
368 pub async fn publish(
370 &self,
371 name: &Name,
372 blob: Vec<u8>,
373 payload_type: Option<String>,
374 metadata: Option<HashMap<String, String>>,
375 ) -> Result<CompletionHandle, SessionError> {
376 self.publish_with_flags(
377 name,
378 SlimHeaderFlags::default(),
379 blob,
380 payload_type,
381 metadata,
382 )
383 .await
384 }
385
386 pub async fn publish_with_flags(
388 &self,
389 name: &Name,
390 flags: SlimHeaderFlags,
391 blob: Vec<u8>,
392 payload_type: Option<String>,
393 metadata: Option<HashMap<String, String>>,
394 ) -> Result<CompletionHandle, SessionError> {
395 let ct = payload_type.unwrap_or_else(|| "msg".to_string());
396
397 let mut msg = Message::builder()
398 .source(self.source().clone())
399 .destination(name.clone())
400 .identity("")
401 .flags(flags)
402 .session_type(self.session_type())
403 .session_message_type(ProtoSessionMessageType::Msg)
404 .session_id(self.id())
405 .message_id(rand::random::<u32>()) .application_payload(&ct, blob)
407 .build_publish()?;
408 if let Some(map) = metadata
409 && !map.is_empty()
410 {
411 msg.set_metadata_map(map);
412 }
413
414 self.publish_message(msg).await
416 }
417
418 fn create_discovery_request(&self, destination: &Name) -> Result<Message, SessionError> {
420 let payload = CommandPayload::builder()
421 .discovery_request(None)
422 .as_content();
423
424 let msg = Message::builder()
425 .source(self.source().clone())
426 .destination(destination.clone())
427 .identity("")
428 .session_type(self.session_type())
429 .session_message_type(ProtoSessionMessageType::DiscoveryRequest)
430 .session_id(self.id())
431 .message_id(rand::random::<u32>())
432 .payload(payload)
433 .build_publish()?;
434
435 Ok(msg)
436 }
437
438 pub(crate) async fn invite_participant_internal(
439 &self,
440 destination: &Name,
441 ) -> Result<CompletionHandle, SessionError> {
442 let msg = self.create_discovery_request(destination)?;
443 self.publish_message(msg).await
444 }
445
446 pub async fn invite_participant(
447 &self,
448 destination: &Name,
449 ) -> Result<CompletionHandle, SessionError> {
450 match self.session_type() {
451 ProtoSessionType::PointToPoint => Err(SessionError::CannotInviteToP2P),
452 ProtoSessionType::Multicast => {
453 if !self.is_initiator() {
454 return Err(SessionError::NotInitiator);
455 }
456 self.invite_participant_internal(destination).await
457 }
458 _ => Err(SessionError::SessionTypeUnknown(self.session_type())),
459 }
460 }
461
462 pub async fn remove_participant(
463 &self,
464 destination: &Name,
465 ) -> Result<CompletionHandle, SessionError> {
466 match self.session_type() {
467 ProtoSessionType::PointToPoint => Err(SessionError::CannotRemoveFromP2P),
468 ProtoSessionType::Multicast => {
469 if !self.is_initiator() {
470 return Err(SessionError::NotInitiator);
471 }
472 let msg = Message::builder()
473 .source(self.source().clone())
474 .destination(destination.clone().with_id(Name::NULL_COMPONENT))
475 .identity("")
476 .session_type(ProtoSessionType::Multicast)
477 .session_message_type(ProtoSessionMessageType::LeaveRequest)
478 .session_id(self.id())
479 .message_id(rand::random::<u32>())
480 .payload(CommandPayload::builder().leave_request(None).as_content())
481 .build_publish()?;
482 self.publish_message(msg).await
483 }
484 _ => Err(SessionError::SessionTypeUnknown(self.session_type())),
485 }
486 }
487}
488
489impl Drop for SessionController {
490 fn drop(&mut self) {
491 self.cancellation_token.cancel();
492 }
493}
494
495pub fn handle_channel_discovery_message(
496 message: &Message,
497 app_name: &Name,
498 session_id: u32,
499 session_type: ProtoSessionType,
500) -> Result<Message, SessionError> {
501 let destination = message.get_source();
502
503 let mut source = message.get_dst();
509 source.set_id(app_name.id());
510 let msg_id = message.get_id();
511
512 let slim_header = SlimHeader::new(
513 &source,
514 &destination,
515 "", Some(SlimHeaderFlags::default().with_forward_to(message.get_incoming_conn())),
517 );
518
519 let msg = Message::builder()
520 .with_slim_header(slim_header)
521 .session_type(session_type)
522 .session_message_type(ProtoSessionMessageType::DiscoveryReply)
523 .session_id(session_id)
524 .message_id(msg_id)
525 .payload(CommandPayload::builder().discovery_reply().as_content())
526 .build_publish()?;
527
528 Ok(msg)
529}
530
/// State and helpers shared by session controller implementations.
///
/// `M` defaults to the concrete `SubscriptionManager`; any other
/// `SubscriptionOps` implementation can be substituted.
pub(crate) struct SessionControllerCommon<
    P,
    V,
    M = crate::subscription_manager::SubscriptionManager,
> where
    P: TokenProvider + Send + Sync + Clone + 'static,
    V: Verifier + Send + Sync + Clone + 'static,
    M: crate::subscription_manager::SubscriptionOps,
{
    /// Session settings: id, source name, config, transmitters and subscription manager.
    pub(crate) settings: SessionSettings<P, V, M>,

    /// Sender used by `send_with_timer` for outgoing control/data messages.
    pub(crate) sender: ControllerSender,

    /// Current processing state; `new` starts it as `ProcessingState::Active`.
    /// NOTE(review): presumably surfaced through `MessageHandler::processing_state` — confirm.
    pub(crate) processing_state: ProcessingState,

    /// Ids returned by the subscription manager, keyed by (kind, name, connection),
    /// kept so the matching delete can release them.
    subscription_ids: HashMap<(SubscriptionKind, Name, u64), u64>,
}
552
/// Distinguishes the two kinds of entries tracked in
/// `SessionControllerCommon::subscription_ids`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum SubscriptionKind {
    /// Entry created by `add_route` (subscription manager `set_route`).
    Route,
    /// Entry created by `add_subscription` (subscription manager `subscribe`).
    Subscription,
}
562
563impl<P, V, M> SessionControllerCommon<P, V, M>
564where
565 P: TokenProvider + Send + Sync + Clone + 'static,
566 V: Verifier + Send + Sync + Clone + 'static,
567 M: crate::subscription_manager::SubscriptionOps,
568{
569 pub(crate) fn new(settings: SessionSettings<P, V, M>) -> Self {
570 let controller_sender = ControllerSender::new(
572 settings.config.get_timer_settings(),
573 settings.source.clone(),
574 settings.config.session_type,
575 settings.id,
576 Some(PING_INTERVAL),
577 settings.config.initiator,
578 settings.tx.clone(),
580 settings.tx_session.clone(),
582 );
583
584 SessionControllerCommon {
585 settings,
586 sender: controller_sender,
587 processing_state: ProcessingState::Active,
588 subscription_ids: HashMap::new(),
589 }
590 }
591
592 pub(crate) async fn send_to_slim(&self, message: Message) -> Result<(), SessionError> {
594 self.settings.tx.send_to_slim(Ok(message)).await
595 }
596
597 pub(crate) async fn send_to_app(&self, error: SessionError) -> Result<(), SessionError> {
599 self.settings.tx.send_to_app(Err(error)).await
600 }
601
602 pub(crate) async fn send_with_timer(&mut self, message: Message) -> Result<(), SessionError> {
604 self.sender.on_message(&message).await
605 }
606
607 async fn await_subscription_ack(
608 rx: tokio::sync::oneshot::Receiver<
609 Result<(), crate::subscription_manager::SubscriptionAckError>,
610 >,
611 ) -> Result<(), SessionError> {
612 crate::subscription_manager::SubscriptionManager::await_ack(rx)
613 .await
614 .map_err(SessionError::SubscriptionAckFailed)
615 }
616
617 pub(crate) async fn add_route(&mut self, name: Name, conn: u64) -> Result<(), SessionError> {
618 if name == self.settings.source {
619 return Ok(());
621 }
622
623 let (subscription_id, rx) = self
624 .settings
625 .subscription_manager
626 .set_route(&self.settings.source, &name, conn)
627 .await
628 .map_err(SessionError::SubscriptionAckFailed)?;
629 Self::await_subscription_ack(rx).await?;
630
631 debug!(%name, %conn, %subscription_id, source = %self.settings.source, "route added");
632
633 self.subscription_ids
634 .insert((SubscriptionKind::Route, name, conn), subscription_id);
635
636 Ok(())
637 }
638
639 pub(crate) async fn delete_route(&mut self, name: Name, conn: u64) -> Result<(), SessionError> {
640 if name == self.settings.source {
641 return Ok(());
643 }
644
645 let key = (SubscriptionKind::Route, name, conn);
646 let subscription_id = self.subscription_ids.remove(&key);
647 let (_, name, conn) = key;
648 match subscription_id {
649 Some(subscription_id) => {
650 let rx = self
651 .settings
652 .subscription_manager
653 .remove_route(&self.settings.source, &name, subscription_id, conn)
654 .await
655 .map_err(SessionError::SubscriptionAckFailed)?;
656
657 Self::await_subscription_ack(rx).await?;
658 tracing::debug!(%name, %conn, %subscription_id, "route deleted");
659 }
660 None => {
661 tracing::warn!(
662 %name, %conn, io = %self.settings.source,
663 "no subscription_id found for route, skipping delete"
664 );
665 }
666 }
667
668 Ok(())
669 }
670
671 pub(crate) async fn add_subscription(
672 &mut self,
673 name: Name,
674 conn: u64,
675 ) -> Result<(), SessionError> {
676 let (subscription_id, rx) = self
677 .settings
678 .subscription_manager
679 .subscribe(&self.settings.source, &name, Some(conn))
680 .await
681 .map_err(SessionError::SubscriptionAckFailed)?;
682
683 Self::await_subscription_ack(rx).await?;
684
685 debug!(%name, %conn, %subscription_id, "subscription added");
686
687 self.subscription_ids.insert(
688 (SubscriptionKind::Subscription, name, conn),
689 subscription_id,
690 );
691
692 Ok(())
693 }
694
695 pub(crate) async fn delete_subscription(
696 &mut self,
697 name: Name,
698 conn: u64,
699 ) -> Result<(), SessionError> {
700 let key = (SubscriptionKind::Subscription, name, conn);
701 let subscription_id = self.subscription_ids.remove(&key);
702 let (_, name, conn) = key;
703 match subscription_id {
704 Some(subscription_id) => {
705 let rx = self
706 .settings
707 .subscription_manager
708 .unsubscribe(&self.settings.source, &name, subscription_id, Some(conn))
709 .await
710 .map_err(SessionError::SubscriptionAckFailed)?;
711
712 Self::await_subscription_ack(rx).await?;
713 debug!(%name, %conn, %subscription_id, "subscription deleted");
714 }
715 None => {
716 tracing::warn!(
717 %name, %conn,
718 "no subscription_id found for subscription, skipping delete"
719 );
720 }
721 }
722
723 Ok(())
724 }
725
726 pub(crate) fn create_control_message(
727 &mut self,
728 dst: &Name,
729 message_type: ProtoSessionMessageType,
730 message_id: u32,
731 payload: Content,
732 broadcast: bool,
733 ) -> Result<Message, SessionError> {
734 let mut builder = Message::builder()
735 .source(self.settings.source.clone())
736 .destination(dst.clone())
737 .identity("")
738 .session_type(self.settings.config.session_type)
739 .session_message_type(message_type)
740 .session_id(self.settings.id)
741 .message_id(message_id)
742 .payload(payload);
743
744 if broadcast {
745 builder = builder.fanout(256);
746 }
747
748 let ret = builder.build_publish()?;
749
750 Ok(ret)
751 }
752
753 pub(crate) async fn send_control_message(
755 &mut self,
756 dst: &Name,
757 message_type: ProtoSessionMessageType,
758 message_id: u32,
759 payload: Content,
760 metadata: Option<HashMap<String, String>>,
761 broadcast: bool,
762 ) -> Result<(), SessionError> {
763 let mut msg =
764 self.create_control_message(dst, message_type, message_id, payload, broadcast)?;
765 if let Some(m) = metadata {
766 msg.set_metadata_map(m);
767 }
768 self.send_with_timer(msg).await
769 }
770}
771
772#[cfg(test)]
773mod tests {
774 use super::*;
775
776 use crate::subscription_manager::{SpySubscriptionManager, SubscriptionCall};
783 use crate::transmitter::SessionTransmitter;
784 use slim_auth::shared_secret::SharedSecret;
785
786 use std::sync::Arc;
787 use std::sync::atomic::AtomicBool;
788 use std::time::Duration;
789 use tokio::time::timeout;
790 use tracing_test::traced_test;
791
792 const SHARED_SECRET: &str = "kjandjansdiasb8udaijdniasdaindasndasndasndasndasndasndasndas";
793
794 struct SessionControllerTestBuilder {
796 session_id: u32,
797 source: Name,
798 destination: Name,
799 session_type: ProtoSessionType,
800 mls_enabled: bool,
801 initiator: bool,
802 max_retries: Option<u32>,
803 interval: Option<Duration>,
804 metadata: HashMap<String, String>,
805 graceful_shutdown_timeout: Option<Duration>,
806 }
807
808 impl SessionControllerTestBuilder {
809 #[allow(dead_code)]
810 fn new() -> Self {
811 Self {
812 session_id: 10,
813 source: Name::from_strings(["org", "ns", "source"]).with_id(1),
814 destination: Name::from_strings(["org", "ns", "dest"]).with_id(2),
815 session_type: ProtoSessionType::PointToPoint,
816 mls_enabled: false,
817 initiator: true,
818 max_retries: Some(5),
819 interval: Some(Duration::from_millis(200)),
820 metadata: HashMap::new(),
821 graceful_shutdown_timeout: Some(Duration::from_secs(10)),
822 }
823 }
824
825 fn with_session_id(mut self, id: u32) -> Self {
826 self.session_id = id;
827 self
828 }
829
830 #[allow(dead_code)]
831 fn with_source(mut self, source: Name) -> Self {
832 self.source = source;
833 self
834 }
835
836 #[allow(dead_code)]
837 fn with_destination(mut self, destination: Name) -> Self {
838 self.destination = destination;
839 self
840 }
841
842 fn with_session_type(mut self, session_type: ProtoSessionType) -> Self {
843 self.session_type = session_type;
844 self
845 }
846
847 fn with_mls_enabled(mut self, enabled: bool) -> Self {
848 self.mls_enabled = enabled;
849 self
850 }
851
852 fn with_initiator(mut self, initiator: bool) -> Self {
853 self.initiator = initiator;
854 self
855 }
856
857 fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
858 self.metadata = metadata;
859 self
860 }
861
862 fn with_graceful_shutdown_timeout(mut self, timeout: Duration) -> Self {
863 self.graceful_shutdown_timeout = Some(timeout);
864 self
865 }
866
867 fn build(
868 self,
869 ) -> (
870 SessionController,
871 tokio::sync::mpsc::Receiver<Result<Message, slim_datapath::Status>>,
872 tokio::sync::mpsc::UnboundedReceiver<Result<Message, SessionError>>,
873 ) {
874 let config = SessionConfig {
875 session_type: self.session_type,
876 max_retries: self.max_retries,
877 interval: self.interval,
878 mls_enabled: self.mls_enabled,
879 initiator: self.initiator,
880 metadata: self.metadata,
881 };
882
883 let (tx_slim, rx_slim) = tokio::sync::mpsc::channel(10);
884 let (tx_app, rx_app) = tokio::sync::mpsc::unbounded_channel();
885 let (tx_session_layer, _rx_session_layer) = tokio::sync::mpsc::channel(10);
886
887 let tx = SessionTransmitter::new(tx_slim, tx_app);
888
889 let controller = SessionController::builder()
890 .with_id(self.session_id)
891 .with_source(self.source.clone())
892 .with_destination(self.destination.clone())
893 .with_config(config)
894 .with_identity_provider(SharedSecret::new("test", SHARED_SECRET).unwrap())
895 .with_identity_verifier(SharedSecret::new("test", SHARED_SECRET).unwrap())
896 .with_tx(tx)
897 .with_tx_to_session_layer(tx_session_layer)
898 .ready()
899 .expect("failed to validate builder")
900 .build()
901 .expect("failed to build controller");
902
903 (controller, rx_slim, rx_app)
904 }
905 }
906
907 #[tokio::test]
908 async fn test_session_controller_getters() {
909 let mut metadata = HashMap::new();
910 metadata.insert("key1".to_string(), "value1".to_string());
911
912 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
913 .with_session_id(42)
914 .with_session_type(ProtoSessionType::Multicast)
915 .with_mls_enabled(true)
916 .with_metadata(metadata)
917 .build();
918
919 assert_eq!(controller.id(), 42);
920 assert_eq!(
921 controller.source(),
922 &Name::from_strings(["org", "ns", "source"]).with_id(1)
923 );
924 assert_eq!(
925 controller.dst(),
926 &Name::from_strings(["org", "ns", "dest"]).with_id(2)
927 );
928 assert_eq!(controller.session_type(), ProtoSessionType::Multicast);
929 assert!(controller.is_initiator());
930 assert_eq!(
931 controller.metadata().get("key1"),
932 Some(&"value1".to_string())
933 );
934
935 let retrieved_config = controller.session_config();
936 assert_eq!(retrieved_config.session_type, ProtoSessionType::Multicast);
937 assert_eq!(retrieved_config.max_retries, Some(5));
938 }
939
940 #[tokio::test]
941 async fn test_publish_basic() {
942 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
943
944 let target_name = Name::from_strings(["org", "ns", "target"]);
945 let payload = b"Hello World".to_vec();
946
947 controller
948 .publish(
949 &target_name,
950 payload.clone(),
951 Some("test-type".to_string()),
952 None,
953 )
954 .await
955 .expect("publish should succeed");
956
957 tokio::time::sleep(Duration::from_millis(50)).await;
958 }
959
960 #[tokio::test]
961 async fn test_publish_to_specific_connection() {
962 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
963 .with_session_type(ProtoSessionType::Multicast)
964 .build();
965
966 let target_name = Name::from_strings(["org", "ns", "target"]);
967 let payload = b"Hello to specific connection".to_vec();
968 let connection_id = 123u64;
969
970 controller
971 .publish_to(
972 &target_name,
973 connection_id,
974 payload.clone(),
975 Some("test-type".to_string()),
976 None,
977 )
978 .await
979 .expect("publish_to should succeed");
980
981 tokio::time::sleep(Duration::from_millis(50)).await;
982 }
983
984 #[tokio::test]
985 async fn test_publish_with_metadata() {
986 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
987 .with_session_type(ProtoSessionType::Multicast)
988 .build();
989
990 let target_name = Name::from_strings(["org", "ns", "target"]);
991 let payload = b"Hello with metadata".to_vec();
992
993 let mut metadata = HashMap::new();
994 metadata.insert("custom_key".to_string(), "custom_value".to_string());
995
996 controller
997 .publish(
998 &target_name,
999 payload.clone(),
1000 Some("test-type".to_string()),
1001 Some(metadata),
1002 )
1003 .await
1004 .expect("publish with metadata should succeed");
1005
1006 tokio::time::sleep(Duration::from_millis(50)).await;
1007 }
1008
1009 #[tokio::test]
1010 async fn test_invite_participant_in_multicast() {
1011 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1012 .with_session_type(ProtoSessionType::Multicast)
1013 .build();
1014
1015 let participant = Name::from_strings(["org", "ns", "participant"]);
1016
1017 controller
1018 .invite_participant(&participant)
1019 .await
1020 .expect("invite should succeed");
1021
1022 tokio::time::sleep(Duration::from_millis(50)).await;
1023 }
1024
1025 #[tokio::test]
1026 async fn test_invite_participant_not_initiator_error() {
1027 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1028 .with_session_type(ProtoSessionType::Multicast)
1029 .with_initiator(false)
1030 .build();
1031
1032 let participant = Name::from_strings(["org", "ns", "new_participant"]);
1033
1034 let result = controller.invite_participant(&participant).await;
1035 assert!(result.is_err_and(|e| matches!(e, SessionError::NotInitiator)));
1036 }
1037
1038 #[tokio::test]
1039 async fn test_invite_participant_p2p_error() {
1040 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1041 .with_session_type(ProtoSessionType::PointToPoint)
1042 .build();
1043
1044 let participant = Name::from_strings(["org", "ns", "participant"]);
1045
1046 let result = controller.invite_participant(&participant).await;
1047 assert!(result.is_err_and(|e| matches!(e, SessionError::CannotInviteToP2P)));
1048 }
1049
1050 #[tokio::test]
1051 async fn test_remove_participant_in_multicast() {
1052 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1053 .with_session_type(ProtoSessionType::Multicast)
1054 .build();
1055
1056 let participant = Name::from_strings(["org", "ns", "participant"]);
1057
1058 controller
1059 .remove_participant(&participant)
1060 .await
1061 .expect("remove should succeed");
1062
1063 tokio::time::sleep(Duration::from_millis(50)).await;
1064 }
1065
1066 #[tokio::test]
1067 async fn test_remove_participant_not_initiator_error() {
1068 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1069 .with_session_type(ProtoSessionType::Multicast)
1070 .with_initiator(false)
1071 .build();
1072
1073 let participant = Name::from_strings(["org", "ns", "participant"]);
1074
1075 let result = controller.remove_participant(&participant).await;
1076 assert!(result.is_err_and(|e| matches!(e, SessionError::NotInitiator)));
1077 }
1078
1079 #[tokio::test]
1080 async fn test_remove_participant_p2p_error() {
1081 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1082 .with_session_type(ProtoSessionType::PointToPoint)
1083 .build();
1084
1085 let participant = Name::from_strings(["org", "ns", "participant"]);
1086
1087 let result = controller.remove_participant(&participant).await;
1088 assert!(result.is_err_and(|e| matches!(e, SessionError::CannotRemoveFromP2P)));
1089 }
1090
1091 #[test]
1092 fn test_handle_channel_discovery_message() {
1093 let app_name = Name::from_strings(["org", "ns", "app"]).with_id(100);
1094 let session_id = 42;
1095
1096 let discovery_request = Message::builder()
1097 .source(Name::from_strings(["org", "ns", "requester"]).with_id(1))
1098 .destination(Name::from_strings(["org", "ns", "service"]))
1099 .identity("")
1100 .incoming_conn(999)
1101 .session_type(ProtoSessionType::Multicast)
1102 .session_message_type(ProtoSessionMessageType::DiscoveryRequest)
1103 .session_id(session_id)
1104 .message_id(123)
1105 .payload(
1106 CommandPayload::builder()
1107 .discovery_request(None)
1108 .as_content(),
1109 )
1110 .build_publish()
1111 .unwrap();
1112
1113 let response = handle_channel_discovery_message(
1114 &discovery_request,
1115 &app_name,
1116 session_id,
1117 ProtoSessionType::Multicast,
1118 )
1119 .expect("should create discovery response");
1120
1121 assert_eq!(
1122 response.get_session_message_type(),
1123 ProtoSessionMessageType::DiscoveryReply
1124 );
1125 assert_eq!(response.get_session_header().get_session_id(), session_id);
1126 assert_eq!(response.get_id(), 123);
1127 assert_eq!(
1128 response.get_dst(),
1129 Name::from_strings(["org", "ns", "requester"]).with_id(1)
1130 );
1131 assert_eq!(response.get_slim_header().get_forward_to(), Some(999));
1132 }
1133
1134 #[tokio::test]
1135 async fn test_controller_drop_cancels_processing() {
1136 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
1137
1138 let token = controller.cancellation_token.clone();
1139 assert!(!token.is_cancelled());
1140
1141 drop(controller);
1142
1143 tokio::time::sleep(Duration::from_millis(100)).await;
1144 assert!(token.is_cancelled());
1145 }
1146
1147 #[tokio::test]
1148 async fn test_close_success() {
1149 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1150 .with_graceful_shutdown_timeout(std::time::Duration::from_secs(2))
1151 .build();
1152
1153 let token = controller.cancellation_token.clone();
1154 assert!(!token.is_cancelled());
1155
1156 let handle = controller.close();
1157 assert!(handle.is_ok(), "got error {}", handle.unwrap_err());
1158 assert!(token.is_cancelled());
1159
1160 handle
1162 .unwrap()
1163 .await
1164 .expect("processing task should complete");
1165 }
1166
1167 #[tokio::test]
1168 async fn test_close_already_closed() {
1169 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
1170
1171 let handle = controller.close();
1173 assert!(handle.is_ok());
1174 handle
1175 .unwrap()
1176 .await
1177 .expect("processing task should complete");
1178
1179 let result = controller.close();
1181 assert!(result.is_err());
1182 match result {
1183 Err(SessionError::SessionAlreadyClosed) => {
1184 }
1186 _ => panic!("Expected SessionError::SessionAlreadyClosed"),
1187 }
1188 }
1189
1190 #[tokio::test]
1191 async fn test_close_cancels_token_immediately() {
1192 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
1193
1194 let token = controller.cancellation_token.clone();
1195
1196 assert!(!token.is_cancelled());
1198
1199 let handle = controller.close();
1201 assert!(handle.is_ok());
1202
1203 assert!(token.is_cancelled());
1205
1206 handle.unwrap().await.expect("processing should complete");
1208 }
1209
1210 #[tokio::test]
1211 async fn test_on_message_direction_north() {
1212 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
1213
1214 let test_message = Message::builder()
1215 .source(controller.dst().clone())
1216 .destination(controller.source().clone())
1217 .identity("")
1218 .session_type(ProtoSessionType::PointToPoint)
1219 .session_message_type(ProtoSessionMessageType::Msg)
1220 .session_id(controller.id())
1221 .message_id(1)
1222 .application_payload("test", b"test data".to_vec())
1223 .build_publish()
1224 .unwrap();
1225
1226 let result = controller.on_message_from_slim(test_message).await;
1227 assert!(result.is_ok());
1228 }
1229
1230 #[tokio::test]
1231 async fn test_create_discovery_request() {
1232 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1233 .with_session_type(ProtoSessionType::Multicast)
1234 .build();
1235
1236 let target = Name::from_strings(["org", "ns", "target"]);
1237 let discovery_msg = controller
1238 .create_discovery_request(&target)
1239 .expect("should create discovery request");
1240
1241 assert_eq!(discovery_msg.get_source(), *controller.source());
1242 assert_eq!(discovery_msg.get_dst(), target);
1243 assert_eq!(
1244 discovery_msg.get_session_message_type(),
1245 ProtoSessionMessageType::DiscoveryRequest
1246 );
1247 assert_eq!(
1248 discovery_msg.get_session_header().get_session_id(),
1249 controller.id()
1250 );
1251 assert_eq!(
1252 discovery_msg.get_session_type(),
1253 ProtoSessionType::Multicast
1254 );
1255 }
1256
/// End-to-end point-to-point flow between two real controllers, a moderator
/// (initiator) and a participant, both with MLS enabled. No network is
/// involved: the test reads each message a controller writes to its SLIM
/// channel, sets an incoming connection on it and hands it to the other
/// controller.
///
/// Phases: discovery -> join -> MLS welcome/ack -> one application message
/// and its ack -> leave. After the handshake, after the data exchange and
/// after the leave, both SLIM channels must stay silent for 100 ms.
#[tokio::test]
#[traced_test]
async fn test_end_to_end_p2p() {
    let session_id = 10;
    let moderator_name = Name::from_strings(["org", "ns", "moderator"]).with_id(1);
    // The moderator invites the participant by its id-less name. The
    // participant runs with id 1, and that is where traffic is addressed.
    let participant_name = Name::from_strings(["org", "ns", "participant"]);
    let participant_name_id = Name::from_strings(["org", "ns", "participant"]).with_id(1);

    // Moderator plumbing. Only the SLIM receiver is inspected by this test.
    let (tx_slim_moderator, mut rx_slim_moderator) = tokio::sync::mpsc::channel(10);
    let (tx_app_moderator, _rx_app_moderator) = tokio::sync::mpsc::unbounded_channel();
    let (tx_session_layer_moderator, _rx_session_layer_moderator) =
        tokio::sync::mpsc::channel(10);

    let tx_moderator =
        SessionTransmitter::new(tx_slim_moderator.clone(), tx_app_moderator.clone());

    // NOTE(review): the moderator's 1 s retry interval is much longer than the
    // 100 ms receive windows below. Presumably this keeps retransmissions from
    // showing up as unexpected messages — confirm.
    let moderator_config = SessionConfig {
        session_type: slim_datapath::api::ProtoSessionType::PointToPoint,
        max_retries: Some(5),
        interval: Some(Duration::from_millis(1000)),
        mls_enabled: true,
        initiator: true,
        metadata: std::collections::HashMap::new(),
    };

    // Spy subscription manager: the moderator's route changes are observed on
    // rx_spy_moderator.
    let (spy_moderator_mgr, mut rx_spy_moderator) = SpySubscriptionManager::new();
    let moderator = SessionController::builder()
        .with_id(session_id)
        .with_source(moderator_name.clone())
        .with_destination(participant_name.clone())
        .with_config(moderator_config)
        .with_identity_provider(SharedSecret::new("moderator", SHARED_SECRET).unwrap())
        .with_identity_verifier(SharedSecret::new("moderator", SHARED_SECRET).unwrap())
        .with_tx(tx_moderator.clone())
        .with_tx_to_session_layer(tx_session_layer_moderator)
        .with_subscription_manager(spy_moderator_mgr)
        .ready()
        .expect("failed to validate builder")
        .build()
        .unwrap();

    // Participant plumbing. The app receiver is inspected here too, to check
    // delivery of the application message.
    let (tx_slim_participant, mut rx_slim_participant) = tokio::sync::mpsc::channel(10);
    let (tx_app_participant, mut rx_app_participant) = tokio::sync::mpsc::unbounded_channel();
    let (tx_session_layer_participant, _rx_session_layer_participant) =
        tokio::sync::mpsc::channel(10);

    let tx_participant =
        SessionTransmitter::new(tx_slim_participant.clone(), tx_app_participant.clone());

    let participant_config = SessionConfig {
        session_type: slim_datapath::api::ProtoSessionType::PointToPoint,
        max_retries: Some(5),
        interval: Some(Duration::from_millis(200)),
        mls_enabled: true,
        initiator: false,
        metadata: std::collections::HashMap::new(),
    };

    let (spy_participant_mgr, mut rx_spy_participant) = SpySubscriptionManager::new();
    let participant = SessionController::builder()
        .with_id(session_id)
        .with_source(participant_name_id.clone())
        .with_destination(moderator_name.clone())
        .with_config(participant_config)
        .with_identity_provider(SharedSecret::new("participant", SHARED_SECRET).unwrap())
        .with_identity_verifier(SharedSecret::new("participant", SHARED_SECRET).unwrap())
        .with_tx(tx_participant.clone())
        .with_tx_to_session_layer(tx_session_layer_participant)
        .with_subscription_manager(spy_participant_mgr)
        .ready()
        .expect("failed to validate builder")
        .build()
        .unwrap();

    // ---- Discovery ----
    // Start the invite. Its completion handle is awaited only after the MLS
    // handshake below.
    let completion_handle = moderator
        .invite_participant_internal(&participant_name)
        .await
        .expect("error inviting participant");

    let received_discovery_request =
        timeout(Duration::from_millis(100), rx_slim_moderator.recv())
            .await
            .expect("timeout waiting for discovery request on moderator slim channel")
            .expect("channel closed")
            .expect("error in discovery request");

    assert_eq!(
        received_discovery_request.get_session_message_type(),
        slim_datapath::api::ProtoSessionMessageType::DiscoveryRequest
    );

    let discovery_msg_id = received_discovery_request.get_id();

    // The participant controller never sees the discovery request. The reply
    // is forged by hand from its full name and echoes the request's message id.
    let mut discovery_reply = Message::builder()
        .source(participant_name_id.clone())
        .destination(moderator_name.clone())
        .identity("")
        .forward_to(1)
        .session_type(slim_datapath::api::ProtoSessionType::PointToPoint)
        .session_message_type(slim_datapath::api::ProtoSessionMessageType::DiscoveryReply)
        .session_id(session_id)
        .message_id(discovery_msg_id)
        .payload(CommandPayload::builder().discovery_reply().as_content())
        .build_publish()
        .unwrap();
    // Every message handed across sides below is marked as arriving on connection 1.
    discovery_reply
        .get_slim_header_mut()
        .set_incoming_conn(Some(1));

    moderator
        .on_message_from_slim(discovery_reply)
        .await
        .expect("error processing discovery reply on moderator");

    // NOTE(review): the spy receives below have no timeout. If the expected
    // call never arrives, the test presumably hangs instead of failing.
    assert_eq!(
        rx_spy_moderator.recv().await,
        Some(SubscriptionCall::SetRoute),
        "moderator should set route after discovery reply"
    );

    // ---- Join ----
    // From here on the moderator addresses the participant by its full name.
    let join_request = timeout(Duration::from_millis(100), rx_slim_moderator.recv())
        .await
        .expect("timeout waiting for join request on moderator slim channel")
        .expect("channel closed")
        .expect("error in join request");

    assert_eq!(
        join_request.get_session_message_type(),
        slim_datapath::api::ProtoSessionMessageType::JoinRequest
    );
    assert_eq!(join_request.get_dst(), participant_name_id);

    let mut join_request_to_participant = join_request.clone();
    join_request_to_participant
        .get_slim_header_mut()
        .set_incoming_conn(Some(1));

    participant
        .on_message_from_slim(join_request_to_participant)
        .await
        .expect("error processing join request on participant");

    assert_eq!(
        rx_spy_participant.recv().await,
        Some(SubscriptionCall::SetRoute),
        "participant should set route after join request"
    );

    let join_reply = timeout(Duration::from_millis(100), rx_slim_participant.recv())
        .await
        .expect("timeout waiting for join reply on participant slim channel")
        .expect("channel closed")
        .expect("error in join reply");

    assert_eq!(
        join_reply.get_session_message_type(),
        slim_datapath::api::ProtoSessionMessageType::JoinReply
    );
    assert_eq!(join_reply.get_dst(), moderator_name);

    let mut join_reply_to_moderator = join_reply.clone();
    join_reply_to_moderator
        .get_slim_header_mut()
        .set_incoming_conn(Some(1));

    moderator
        .on_message_from_slim(join_reply_to_moderator)
        .await
        .expect("error processing join reply on moderator");

    // ---- MLS welcome / ack ----
    let welcome_message = timeout(Duration::from_millis(100), rx_slim_moderator.recv())
        .await
        .expect("timeout waiting for welcome message on moderator slim channel")
        .expect("channel closed")
        .expect("error in welcome message");

    assert_eq!(
        welcome_message.get_session_message_type(),
        slim_datapath::api::ProtoSessionMessageType::GroupWelcome
    );
    assert_eq!(welcome_message.get_dst(), participant_name_id);

    let mut welcome_to_participant = welcome_message.clone();
    welcome_to_participant
        .get_slim_header_mut()
        .set_incoming_conn(Some(1));

    participant
        .on_message_from_slim(welcome_to_participant)
        .await
        .expect("error processing welcome message on participant");

    let ack_group = timeout(Duration::from_millis(100), rx_slim_participant.recv())
        .await
        .expect("timeout waiting for ack group on participant slim channel")
        .expect("channel closed")
        .expect("error in ack group");

    assert_eq!(
        ack_group.get_session_message_type(),
        slim_datapath::api::ProtoSessionMessageType::GroupAck
    );
    assert_eq!(ack_group.get_dst(), moderator_name);

    let mut ack_to_moderator = ack_group.clone();
    ack_to_moderator
        .get_slim_header_mut()
        .set_incoming_conn(Some(1));

    moderator
        .on_message_from_slim(ack_to_moderator)
        .await
        .expect("error processing ack group on moderator");

    // Handshake done: neither side should emit anything else within 100 ms.
    // A timeout (Err) is the expected outcome.
    let no_more_moderator = timeout(Duration::from_millis(100), rx_slim_moderator.recv()).await;
    assert!(
        no_more_moderator.is_err(),
        "Expected no more messages on moderator slim channel, received {:?}",
        no_more_moderator
            .ok()
            .and_then(|opt| opt)
            .and_then(|res| res.ok())
    );

    let no_more_participant =
        timeout(Duration::from_millis(100), rx_slim_participant.recv()).await;
    assert!(
        no_more_participant.is_err(),
        "Expected no more messages on participant slim channel"
    );

    // The invite started at the beginning of the test must now be complete.
    completion_handle.await.expect("error in completion handle");

    // ---- Application data ----
    // The app gives the moderator a message addressed to the id-less
    // participant name. On the wire it must target participant_name_id.
    let app_data = b"Hello from moderator to participant".to_vec();
    let app_message = Message::builder()
        .source(moderator_name.clone())
        .destination(participant_name.clone())
        .identity("")
        .session_type(slim_datapath::api::ProtoSessionType::PointToPoint)
        .session_message_type(slim_datapath::api::ProtoSessionMessageType::Msg)
        .session_id(session_id)
        .message_id(1)
        .application_payload("test-app-data", app_data.clone())
        .build_publish()
        .unwrap();

    moderator
        .on_message_from_app(app_message)
        .await
        .expect("error sending application message from moderator");

    let app_msg_to_slim = timeout(Duration::from_millis(100), rx_slim_moderator.recv())
        .await
        .expect("timeout waiting for application message on moderator slim channel")
        .expect("channel closed")
        .expect("error in application message");

    assert_eq!(app_msg_to_slim.get_dst(), participant_name_id);
    assert!(
        app_msg_to_slim.is_publish(),
        "message should be a publish message"
    );

    // Remember the wire message id; the participant's ack must carry it.
    let app_msg_id = app_msg_to_slim.get_id();

    let mut app_msg_to_participant = app_msg_to_slim.clone();
    app_msg_to_participant
        .get_slim_header_mut()
        .set_incoming_conn(Some(1));

    participant
        .on_message_from_slim(app_msg_to_participant)
        .await
        .expect("error processing application message on participant");

    // The participant delivers the message to its app channel with the original payload.
    let app_msg_received = timeout(Duration::from_millis(100), rx_app_participant.recv())
        .await
        .expect("timeout waiting for application message on participant app channel")
        .expect("channel closed")
        .expect("error in application message to app");

    assert_eq!(app_msg_received.get_source(), moderator_name);
    assert!(
        app_msg_received.is_publish(),
        "message should be a publish message"
    );

    let content = app_msg_received
        .get_payload()
        .unwrap()
        .as_application_payload()
        .unwrap()
        .blob
        .clone();
    assert_eq!(content, app_data);

    // It also acknowledges the message back to the moderator over SLIM.
    let ack_msg = timeout(Duration::from_millis(100), rx_slim_participant.recv())
        .await
        .expect("timeout waiting for ack on participant slim channel")
        .expect("channel closed")
        .expect("error in ack");

    assert_eq!(
        ack_msg.get_session_message_type(),
        slim_datapath::api::ProtoSessionMessageType::MsgAck,
        "message should be an ack"
    );
    assert_eq!(ack_msg.get_dst(), moderator_name);
    assert_eq!(ack_msg.get_id(), app_msg_id);

    let mut ack_to_moderator = ack_msg.clone();
    ack_to_moderator
        .get_slim_header_mut()
        .set_incoming_conn(Some(1));

    moderator
        .on_message_from_slim(ack_to_moderator)
        .await
        .expect("error processing ack on moderator");

    // Data exchange done: both SLIM channels must stay quiet.
    let no_more_moderator_after_ack =
        timeout(Duration::from_millis(100), rx_slim_moderator.recv()).await;
    assert!(
        no_more_moderator_after_ack.is_err(),
        "Expected no more messages on moderator slim channel after ack"
    );

    let no_more_participant_after_ack =
        timeout(Duration::from_millis(100), rx_slim_participant.recv()).await;
    assert!(
        no_more_participant_after_ack.is_err(),
        "Expected no more messages on participant slim channel after ack"
    );

    // ---- Leave ----
    // The app asks the moderator to remove the participant. The moderator
    // must forward the LeaveRequest to participant_name_id.
    let leave_request = Message::builder()
        .source(moderator_name.clone())
        .destination(participant_name.clone())
        .identity("")
        .session_type(slim_datapath::api::ProtoSessionType::PointToPoint)
        .session_message_type(slim_datapath::api::ProtoSessionMessageType::LeaveRequest)
        .session_id(session_id)
        .message_id(rand::random::<u32>())
        .payload(CommandPayload::builder().leave_request(None).as_content())
        .build_publish()
        .unwrap();

    moderator
        .on_message_from_app(leave_request)
        .await
        .expect("error sending leave request");

    let received_leave_request = timeout(Duration::from_millis(100), rx_slim_moderator.recv())
        .await
        .expect("timeout waiting for leave request on moderator slim channel")
        .expect("channel closed")
        .expect("error in leave request");

    assert_eq!(
        received_leave_request.get_session_message_type(),
        slim_datapath::api::ProtoSessionMessageType::LeaveRequest
    );
    assert_eq!(received_leave_request.get_dst(), participant_name_id);

    let mut leave_request_to_participant = received_leave_request.clone();
    leave_request_to_participant
        .get_slim_header_mut()
        .set_incoming_conn(Some(1));

    participant
        .on_message_from_slim(leave_request_to_participant)
        .await
        .expect("error processing leave request on participant");

    let leave_reply = timeout(Duration::from_millis(100), rx_slim_participant.recv())
        .await
        .expect("timeout waiting for leave reply on participant slim channel")
        .expect("channel closed")
        .expect("error in leave reply");

    assert_eq!(
        leave_reply.get_session_message_type(),
        slim_datapath::api::ProtoSessionMessageType::LeaveReply
    );
    assert_eq!(leave_reply.get_dst(), moderator_name);

    // Both sides tear down the route they set up during the join phase.
    assert_eq!(
        rx_spy_participant.recv().await,
        Some(SubscriptionCall::RemoveRoute),
        "participant should remove route after leave request"
    );

    let mut leave_reply_to_moderator = leave_reply.clone();
    leave_reply_to_moderator
        .get_slim_header_mut()
        .set_incoming_conn(Some(1));

    moderator
        .on_message_from_slim(leave_reply_to_moderator)
        .await
        .expect("error processing leave reply on moderator");

    assert_eq!(
        rx_spy_moderator.recv().await,
        Some(SubscriptionCall::RemoveRoute),
        "moderator should remove route after leave reply"
    );

    // After the leave, nothing else may be sent by either side.
    let no_more_moderator_final =
        timeout(Duration::from_millis(100), rx_slim_moderator.recv()).await;

    assert!(
        no_more_moderator_final.is_err(),
        "Expected no more messages on moderator slim channel after leave"
    );

    let no_more_participant_final =
        timeout(Duration::from_millis(100), rx_slim_participant.recv()).await;
    assert!(
        no_more_participant_final.is_err(),
        "Expected no more messages on participant slim channel after leave"
    );
}
1709
1710 #[traced_test]
1713 #[tokio::test]
1714 async fn test_internal_draining_via_processing_state_switch() {
1715 use super::*;
1716 use tokio::sync::mpsc;
1717 use tracing::debug;
1718
1719 struct InternalDrainHandler {
1721 state: ProcessingState,
1722 messages: Vec<SessionMessage>,
1723 needs_drain: Arc<AtomicBool>,
1724 }
1725
1726 impl InternalDrainHandler {
1727 fn new(needs_drain: Arc<AtomicBool>) -> Self {
1728 Self {
1729 state: ProcessingState::Active,
1730 messages: vec![],
1731 needs_drain,
1732 }
1733 }
1734 }
1735
1736 #[async_trait::async_trait]
1737 impl MessageHandler for InternalDrainHandler {
1738 async fn init(&mut self) -> Result<(), SessionError> {
1739 Ok(())
1740 }
1741
1742 async fn on_message(&mut self, message: SessionMessage) -> Result<(), SessionError> {
1743 debug!(?self.state, "internal-drain-handler received message");
1744 self.messages.push(message);
1745
1746 if self.messages.len() == 2 {
1748 debug!("internal-drain-handler transitioning to draining");
1749 self.state = ProcessingState::Draining;
1750 }
1751
1752 Ok(())
1753 }
1754
1755 fn needs_drain(&self) -> bool {
1756 self.needs_drain.load(std::sync::atomic::Ordering::SeqCst)
1757 }
1758
1759 fn processing_state(&self) -> ProcessingState {
1760 self.state
1761 }
1762
1763 async fn on_shutdown(&mut self) -> Result<(), SessionError> {
1764 debug!("shutdown called on handler");
1765 Ok(())
1766 }
1767 }
1768
1769 let (tx_slim, _rx_slim) = mpsc::channel(8);
1771 let (tx_app, _rx_app) = mpsc::unbounded_channel();
1772 let (tx_session, rx_session) = mpsc::channel(32);
1773 let (tx_session_layer, _rx_session_layer) = mpsc::channel(8);
1774
1775 let subscription_manager =
1776 crate::subscription_manager::SubscriptionManager::new(tx_slim.clone());
1777 let settings = SessionSettings {
1778 id: 999,
1779 source: Name::from_strings(["org", "ns", "source"]).with_id(1),
1780 destination: Name::from_strings(["org", "ns", "dest"]).with_id(2),
1781 config: SessionConfig {
1782 session_type: ProtoSessionType::PointToPoint,
1783 max_retries: Some(3),
1784 interval: Some(Duration::from_millis(150)),
1785 mls_enabled: false,
1786 initiator: true,
1787 metadata: HashMap::new(),
1788 },
1789 tx: SessionTransmitter::new(tx_slim, tx_app),
1790 tx_session: tx_session.clone(),
1791 tx_to_session_layer: tx_session_layer,
1792 identity_provider: SharedSecret::new("src", SHARED_SECRET).unwrap(),
1793 identity_verifier: SharedSecret::new("src", SHARED_SECRET).unwrap(),
1794 graceful_shutdown_timeout: Some(Duration::from_secs(10)),
1795 subscription_manager,
1796 service_id: String::new(),
1797 };
1798
1799 let needs_drain = Arc::new(AtomicBool::new(true));
1800 let handler = InternalDrainHandler::new(needs_drain.clone());
1801 let cancellation_token = CancellationToken::new();
1802 let cancellation_token_clone = cancellation_token.clone();
1803
1804 let processing_handle = tokio::spawn(async move {
1806 SessionController::processing_loop(
1807 handler,
1808 rx_session,
1809 cancellation_token_clone,
1810 settings,
1811 )
1812 .await
1813 });
1814
1815 tx_session
1817 .send(create_test_message(1, b"first".to_vec()))
1818 .await
1819 .expect("failed to send first message");
1820
1821 tokio::time::sleep(Duration::from_millis(100)).await;
1822
1823 assert!(logs_contain("internal-drain-handler received message"));
1824
1825 tx_session
1827 .send(create_test_message(2, b"second".to_vec()))
1828 .await
1829 .expect("failed to send second message");
1830
1831 tokio::time::sleep(Duration::from_millis(100)).await;
1832
1833 assert!(logs_contain("internal-drain-handler received message"));
1834
1835 assert!(logs_contain(
1836 "internal-drain-handler transitioning to draining"
1837 ));
1838
1839 tx_session
1841 .send(create_test_message(3, b"third".to_vec()))
1842 .await
1843 .expect("failed to send third message");
1844
1845 tokio::time::sleep(Duration::from_millis(100)).await;
1846
1847 assert!(logs_contain(
1848 "session is draining, rejecting new messages from application"
1849 ));
1850
1851 needs_drain.store(false, std::sync::atomic::Ordering::SeqCst);
1853
1854 cancellation_token.cancel();
1856
1857 tokio::time::sleep(Duration::from_millis(100)).await;
1858
1859 tx_session
1861 .send(SessionMessage::StartDrain {
1862 grace_period: std::time::Duration::from_millis(100),
1863 })
1864 .await
1865 .expect("failed to send timeout message");
1866
1867 processing_handle.await.expect("processing loop panicked");
1869 }
1870 struct DrainableHandler {
1874 messages_received: Arc<tokio::sync::Mutex<Vec<SessionMessage>>>,
1875 needs_drain: Arc<AtomicBool>,
1876 shutdown_called: Arc<tokio::sync::Mutex<bool>>,
1877 drain_delay: Option<Duration>,
1878 }
1879
1880 impl DrainableHandler {
1881 fn new() -> Self {
1882 Self {
1883 messages_received: Arc::new(tokio::sync::Mutex::new(Vec::new())),
1884 needs_drain: Arc::new(AtomicBool::new(false)),
1885 shutdown_called: Arc::new(tokio::sync::Mutex::new(false)),
1886 drain_delay: None,
1887 }
1888 }
1889
1890 fn with_needs_drain(self, needs_drain: bool) -> Self {
1891 self.needs_drain
1892 .store(needs_drain, std::sync::atomic::Ordering::SeqCst);
1893 self
1894 }
1895
1896 #[allow(dead_code)]
1897 fn with_drain_delay(mut self, delay: Duration) -> Self {
1898 self.drain_delay = Some(delay);
1899 self
1900 }
1901
1902 #[allow(dead_code)]
1903 async fn get_messages_count(&self) -> usize {
1904 self.messages_received.lock().await.len()
1905 }
1906
1907 #[allow(dead_code)]
1908 async fn was_shutdown_called(&self) -> bool {
1909 *self.shutdown_called.lock().await
1910 }
1911 }
1912
1913 #[async_trait::async_trait]
1914 impl MessageHandler for DrainableHandler {
1915 async fn init(&mut self) -> Result<(), SessionError> {
1916 Ok(())
1917 }
1918
1919 async fn on_message(&mut self, message: SessionMessage) -> Result<(), SessionError> {
1920 self.messages_received.lock().await.push(message);
1921 Ok(())
1922 }
1923
1924 fn needs_drain(&self) -> bool {
1925 self.needs_drain.load(std::sync::atomic::Ordering::SeqCst)
1926 }
1927
1928 async fn on_shutdown(&mut self) -> Result<(), SessionError> {
1929 if let Some(delay) = self.drain_delay {
1930 tokio::time::sleep(delay).await;
1931 }
1932 *self.shutdown_called.lock().await = true;
1933 Ok(())
1934 }
1935 }
1936
1937 fn create_test_settings(
1939 graceful_shutdown_timeout: Option<Duration>,
1940 ) -> SessionSettings<SharedSecret, SharedSecret> {
1941 let (tx_slim, _rx_slim) = tokio::sync::mpsc::channel(10);
1942 let (tx_app, _rx_app) = tokio::sync::mpsc::unbounded_channel();
1943 let (tx_session, _rx_session) = tokio::sync::mpsc::channel(10);
1944 let (tx_session_layer, _rx_session_layer) = tokio::sync::mpsc::channel(10);
1945
1946 let subscription_manager =
1947 crate::subscription_manager::SubscriptionManager::new(tx_slim.clone());
1948 SessionSettings {
1949 id: 1,
1950 source: Name::from_strings(["org", "ns", "test"]).with_id(1),
1951 destination: Name::from_strings(["org", "ns", "test"]).with_id(2),
1952 config: SessionConfig {
1953 session_type: ProtoSessionType::PointToPoint,
1954 max_retries: Some(5),
1955 interval: Some(Duration::from_millis(200)),
1956 mls_enabled: false,
1957 initiator: true,
1958 metadata: HashMap::new(),
1959 },
1960 tx: SessionTransmitter::new(tx_slim, tx_app),
1961 tx_session,
1962 tx_to_session_layer: tx_session_layer,
1963 identity_provider: SharedSecret::new("test", SHARED_SECRET).unwrap(),
1964 identity_verifier: SharedSecret::new("test", SHARED_SECRET).unwrap(),
1965 graceful_shutdown_timeout,
1966 subscription_manager,
1967 service_id: String::new(),
1968 }
1969 }
1970
1971 fn create_test_message(message_id: u32, payload: Vec<u8>) -> SessionMessage {
1973 SessionMessage::OnMessage {
1974 message: Message::builder()
1975 .source(Name::from_strings(["org", "ns", "test"]).with_id(1))
1976 .destination(Name::from_strings(["org", "ns", "test"]).with_id(2))
1977 .identity("")
1978 .forward_to(1)
1979 .session_type(ProtoSessionType::PointToPoint)
1980 .session_message_type(ProtoSessionMessageType::Msg)
1981 .session_id(1)
1982 .message_id(message_id)
1983 .application_payload("test", payload)
1984 .build_publish()
1985 .unwrap(),
1986 direction: MessageDirection::South,
1987 ack_tx: None,
1988 }
1989 }
1990
1991 async fn count_on_messages(messages: &Arc<tokio::sync::Mutex<Vec<SessionMessage>>>) -> usize {
1992 let messages = messages.lock().await;
1993 messages
1994 .iter()
1995 .filter(|msg| matches!(msg, SessionMessage::OnMessage { .. }))
1996 .count()
1997 }
1998
1999 fn spawn_processing_loop(
2001 handler: DrainableHandler,
2002 rx: tokio::sync::mpsc::Receiver<SessionMessage>,
2003 cancellation_token: CancellationToken,
2004 settings: SessionSettings<SharedSecret, SharedSecret>,
2005 ) -> tokio::task::JoinHandle<()> {
2006 tokio::spawn(async move {
2007 SessionController::processing_loop(handler, rx, cancellation_token, settings).await;
2008 })
2009 }
2010
2011 #[tokio::test]
2012 async fn test_draining_processes_queued_messages() {
2013 let handler = DrainableHandler::new();
2014 let messages_received = handler.messages_received.clone();
2015 let shutdown_called = handler.shutdown_called.clone();
2016
2017 let (tx, rx) = tokio::sync::mpsc::channel(10);
2018 let cancellation_token = CancellationToken::new();
2019 let token_clone = cancellation_token.clone();
2020
2021 let settings = create_test_settings(Some(Duration::from_secs(2)));
2022 let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2023
2024 tx.send(create_test_message(1, vec![1, 2, 3]))
2026 .await
2027 .unwrap();
2028 tx.send(create_test_message(2, vec![4, 5, 6]))
2029 .await
2030 .unwrap();
2031 tx.send(create_test_message(3, vec![7, 8, 9]))
2032 .await
2033 .unwrap();
2034
2035 tokio::time::sleep(Duration::from_millis(50)).await;
2037
2038 token_clone.cancel();
2040
2041 drop(tx);
2043
2044 timeout(Duration::from_secs(3), processing_task)
2046 .await
2047 .expect("timeout waiting for processing loop")
2048 .expect("processing loop panicked");
2049
2050 let processed_messages = count_on_messages(&messages_received).await;
2052 assert_eq!(
2053 processed_messages, 3,
2054 "All queued messages should be processed during draining"
2055 );
2056 assert!(
2057 *shutdown_called.lock().await,
2058 "Shutdown should have been called"
2059 );
2060 }
2061
2062 #[tokio::test]
2063 async fn test_draining_with_needs_drain_true() {
2064 let handler = DrainableHandler::new().with_needs_drain(true);
2065 let messages_received = handler.messages_received.clone();
2066 let shutdown_called = handler.shutdown_called.clone();
2067
2068 let (tx, rx) = tokio::sync::mpsc::channel(10);
2069 let cancellation_token = CancellationToken::new();
2070 let token_clone = cancellation_token.clone();
2071
2072 let settings = create_test_settings(Some(Duration::from_secs(2)));
2073 let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2074
2075 tx.send(create_test_message(1, vec![1, 2, 3]))
2077 .await
2078 .unwrap();
2079 tokio::time::sleep(Duration::from_millis(50)).await;
2080
2081 token_clone.cancel();
2083 drop(tx);
2084
2085 timeout(Duration::from_secs(3), processing_task)
2087 .await
2088 .expect("timeout waiting for processing loop")
2089 .expect("processing loop panicked");
2090
2091 let processed_messages = count_on_messages(&messages_received).await;
2093 assert_eq!(processed_messages, 1, "Message should be processed");
2094 assert!(
2095 *shutdown_called.lock().await,
2096 "Shutdown should have been called after draining"
2097 );
2098 }
2099
2100 #[tokio::test]
2101 async fn test_draining_with_needs_drain_false() {
2102 let handler = DrainableHandler::new().with_needs_drain(false);
2103 let messages_received = handler.messages_received.clone();
2104 let shutdown_called = handler.shutdown_called.clone();
2105
2106 let (tx, rx) = tokio::sync::mpsc::channel(10);
2107 let cancellation_token = CancellationToken::new();
2108 let token_clone = cancellation_token.clone();
2109
2110 let settings = create_test_settings(Some(Duration::from_secs(2)));
2111
2112 let start_time = tokio::time::Instant::now();
2113 let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2114
2115 tx.send(create_test_message(1, vec![1, 2, 3]))
2117 .await
2118 .unwrap();
2119 tokio::time::sleep(Duration::from_millis(50)).await;
2120
2121 token_clone.cancel();
2123 drop(tx);
2124
2125 timeout(Duration::from_secs(1), processing_task)
2127 .await
2128 .expect("timeout waiting for processing loop")
2129 .expect("processing loop panicked");
2130
2131 let elapsed = start_time.elapsed();
2132
2133 let processed_messages = count_on_messages(&messages_received).await;
2135 assert_eq!(processed_messages, 1, "Message should be processed");
2136 assert!(
2137 *shutdown_called.lock().await,
2138 "Shutdown should have been called"
2139 );
2140 assert!(
2141 elapsed < Duration::from_millis(500),
2142 "Should exit quickly when no draining needed"
2143 );
2144 }
2145
2146 #[tokio::test]
2147 async fn test_draining_timeout_enforced() {
2148 let handler = DrainableHandler::new().with_needs_drain(true);
2150
2151 let (tx, rx) = tokio::sync::mpsc::channel(10);
2152 let cancellation_token = CancellationToken::new();
2153 let token_clone = cancellation_token.clone();
2154
2155 let settings = create_test_settings(Some(Duration::from_millis(500)));
2156 let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2157
2158 tokio::time::sleep(Duration::from_millis(50)).await;
2160
2161 let start_time = tokio::time::Instant::now();
2162
2163 token_clone.cancel();
2165
2166 let send_task = tokio::spawn(async move {
2169 for i in 0..10 {
2170 tokio::time::sleep(Duration::from_millis(100)).await;
2171 if tx
2172 .send(create_test_message(i, vec![i as u8]))
2173 .await
2174 .is_err()
2175 {
2176 break;
2177 }
2178 }
2179 });
2180
2181 timeout(Duration::from_secs(2), processing_task)
2183 .await
2184 .expect("timeout waiting for processing loop")
2185 .expect("processing loop panicked");
2186
2187 let elapsed = start_time.elapsed();
2188
2189 assert!(
2191 elapsed >= Duration::from_millis(400),
2192 "Should wait at least close to the timeout period"
2193 );
2194 assert!(
2195 elapsed < Duration::from_secs(2),
2196 "Should respect the timeout and exit, not wait forever"
2197 );
2198
2199 send_task.abort();
2201 }
2202
2203 #[tokio::test]
2204 async fn test_draining_no_messages_in_queue() {
2205 let handler = DrainableHandler::new().with_needs_drain(true);
2206 let messages_received = handler.messages_received.clone();
2207 let shutdown_called = handler.shutdown_called.clone();
2208
2209 let (tx, rx) = tokio::sync::mpsc::channel(10);
2210 let cancellation_token = CancellationToken::new();
2211 let token_clone = cancellation_token.clone();
2212
2213 let settings = create_test_settings(Some(Duration::from_secs(1)));
2214 let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2215
2216 token_clone.cancel();
2218 drop(tx);
2219
2220 timeout(Duration::from_secs(2), processing_task)
2222 .await
2223 .expect("timeout waiting for processing loop")
2224 .expect("processing loop panicked");
2225
2226 let processed_messages = count_on_messages(&messages_received).await;
2228 assert_eq!(processed_messages, 0, "No messages should be processed");
2229 assert!(
2230 *shutdown_called.lock().await,
2231 "Shutdown should still be called"
2232 );
2233 }
2234
2235 #[tokio::test]
2236 async fn test_draining_messages_after_cancellation_processed() {
2237 let handler = DrainableHandler::new();
2238 let messages_received = handler.messages_received.clone();
2239
2240 let (tx, rx) = tokio::sync::mpsc::channel(10);
2241 let cancellation_token = CancellationToken::new();
2242 let token_clone = cancellation_token.clone();
2243
2244 let settings = create_test_settings(Some(Duration::from_secs(2)));
2245
2246 tx.send(create_test_message(1, vec![1, 2, 3]))
2248 .await
2249 .unwrap();
2250 tx.send(create_test_message(2, vec![4, 5, 6]))
2251 .await
2252 .unwrap();
2253
2254 let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2256
2257 tokio::time::sleep(Duration::from_millis(50)).await;
2259
2260 token_clone.cancel();
2262
2263 drop(tx);
2265
2266 timeout(Duration::from_secs(3), processing_task)
2268 .await
2269 .expect("timeout waiting for processing loop")
2270 .expect("processing loop panicked");
2271
2272 let processed_messages = count_on_messages(&messages_received).await;
2274 assert_eq!(
2275 processed_messages, 2,
2276 "Messages in queue during cancellation should be processed"
2277 );
2278 }
2279}