1use std::{collections::HashMap, time::Duration};
6
7use display_error_chain::ErrorChainExt;
8use parking_lot::Mutex;
9use tokio::sync::{self, oneshot};
10use tokio_util::sync::CancellationToken;
12use tracing::{Instrument, debug};
13
14use slim_auth::traits::{TokenProvider, Verifier};
15use slim_datapath::{
16 api::{
17 CommandPayload, Content, NameId, ProtoMessage as Message, ProtoName,
18 ProtoSessionMessageType, ProtoSessionType, SlimHeader,
19 },
20 messages::utils::SlimHeaderFlags,
21};
22
23use crate::{
25 MessageDirection, SessionError,
26 common::{OutboundMessage, SessionMessage, SessionOutput},
27 completion_handle::CompletionHandle,
28 controller_sender::{ControllerSender, PING_INTERVAL},
29 session_builder::{ForController, SessionBuilder},
30 session_config::SessionConfig,
31 session_settings::SessionSettings,
32 traits::{MessageHandler, ProcessingState},
33};
34
35pub(crate) async fn verify_identity<V>(msg: &Message, verifier: &V) -> Result<(), SessionError>
36where
37 V: Verifier + Send + Sync,
38{
39 let identity = msg.get_slim_header().get_identity();
40 if verifier.try_verify(&identity).is_err() {
41 verifier.verify(&identity).await?;
42 }
43 Ok(())
44}
45
/// Handle to a running session.
///
/// Holds the session's identity and configuration together with the sending half of the
/// command channel consumed by the background processing task, the token used to ask that
/// task to shut down, and the task's join handle.
pub struct SessionController {
    /// Session identifier (used as the session id of every message this controller builds).
    pub(crate) id: u32,

    /// Local name, used as the source of messages published through this controller.
    pub(crate) source: ProtoName,

    /// Remote name of the session, as returned by `dst()`.
    pub(crate) destination: ProtoName,

    /// Session configuration (type, initiator flag, metadata, retry settings, ...).
    pub(crate) config: SessionConfig,

    /// Sender feeding `SessionMessage`s to the background processing task.
    tx_controller: sync::mpsc::Sender<SessionMessage>,

    /// Cancelled by `close()` and on drop to start the processing task's shutdown.
    pub(crate) cancellation_token: CancellationToken,

    /// Join handle of the processing task; taken by the first `close()` call, after which
    /// it is `None`.
    handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
}
68
69impl SessionController {
70 pub fn builder<P, V>() -> SessionBuilder<P, V, ForController>
72 where
73 P: TokenProvider + Send + Sync + Clone + 'static,
74 V: Verifier + Send + Sync + Clone + 'static,
75 {
76 SessionBuilder::for_controller()
77 }
78
79 #[allow(clippy::too_many_arguments)]
81 pub(crate) fn from_parts<I, P, V, M>(
82 id: u32,
83 source: ProtoName,
84 destination: ProtoName,
85 config: SessionConfig,
86 settings: SessionSettings<P, V, M>,
87 tx: sync::mpsc::Sender<SessionMessage>,
88 rx: sync::mpsc::Receiver<SessionMessage>,
89 inner: I,
90 ) -> Self
91 where
92 I: MessageHandler + Send + Sync + 'static,
93 P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static,
94 V: slim_auth::traits::Verifier + Send + Sync + Clone + 'static,
95 M: crate::subscription_manager::SubscriptionOps,
96 {
97 let cancellation_token = CancellationToken::new();
99
100 let span = tracing::debug_span!(
102 parent: None,
103 "session_controller_processing_loop",
104 session_id = id,
105 service_id = %settings.service_id,
106 source = %source,
107 destination = %destination,
108 session_type = ?config.session_type
109 );
110
111 let handle = crate::runtime::spawn(
112 Self::processing_loop(inner, rx, cancellation_token.clone(), settings).instrument(span),
113 );
114
115 Self {
116 id,
117 source,
118 destination,
119 config,
120 tx_controller: tx,
121 cancellation_token,
122 handle: Mutex::new(Some(handle)),
123 }
124 }
125
126 fn enter_draining_state<P, V, M>(
128 shutdown_deadline: &mut std::pin::Pin<&mut tokio::time::Sleep>,
129 settings: &SessionSettings<P, V, M>,
130 ) where
131 P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static,
132 V: slim_auth::traits::Verifier + Send + Sync + Clone + 'static,
133 M: crate::subscription_manager::SubscriptionOps,
134 {
135 let shutdown_timeout = settings
136 .graceful_shutdown_timeout
137 .unwrap_or(Duration::from_secs(60));
138 shutdown_deadline
139 .as_mut()
140 .reset(tokio::time::Instant::now() + shutdown_timeout);
141 }
142
143 pub(crate) fn apply_identity_to_slim_output<P>(
146 output: &mut SessionOutput,
147 identity_provider: &P,
148 ) -> Result<(), SessionError>
149 where
150 P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static,
151 {
152 let identity = identity_provider.get_token()?;
153 for msg in &mut output.messages {
154 if let OutboundMessage::ToSlim(m) = msg {
155 m.get_slim_header_mut().set_identity(identity.clone());
156 }
157 }
158 Ok(())
159 }
160
161 async fn dispatch_output<P, V, M>(output: SessionOutput, settings: &SessionSettings<P, V, M>)
164 where
165 P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static,
166 V: slim_auth::traits::Verifier + Send + Sync + Clone + 'static,
167 M: crate::subscription_manager::SubscriptionOps,
168 {
169 for msg in output.messages {
170 match msg {
171 OutboundMessage::ToSlim(message) => {
172 if let Err(e) = settings.slim_tx.send(Ok(message)).await {
173 tracing::error!(error = %e, "failed to send message to SLIM");
174 }
175 }
176 OutboundMessage::ToApp(result) => {
177 if let Err(e) = settings.app_tx.send(result) {
178 tracing::error!(error = %e, "failed to send message to application");
179 }
180 }
181 }
182 }
183 }
184
185 async fn processing_loop<P, V, M>(
186 mut inner: impl MessageHandler + 'static,
187 mut rx: sync::mpsc::Receiver<SessionMessage>,
188 cancellation_token: CancellationToken,
189 settings: SessionSettings<P, V, M>,
190 ) where
191 P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static,
192 V: slim_auth::traits::Verifier + Send + Sync + Clone + 'static,
193 M: crate::subscription_manager::SubscriptionOps,
194 {
195 let mut shutdown_deadline = std::pin::pin!(tokio::time::sleep(Duration::MAX));
197
198 if let Err(e) = inner.init().await {
200 tracing::error!(error = %e.chain(), "error during initialization of session");
201 }
202
203 loop {
204 tokio::select! {
205 _ = cancellation_token.cancelled(), if inner.processing_state() == ProcessingState::Active => {
206 let shutdown_timeout = settings.graceful_shutdown_timeout
208 .unwrap_or(Duration::from_secs(60)); debug!("consuming pending messages before entering draining state");
212 while let Ok(msg) = rx.try_recv() {
213 if let SessionMessage::OnMessage { message, direction: MessageDirection::North, .. } = &msg
214 && let Err(e) = crate::session_controller::verify_identity(message, &settings.identity_verifier).await {
215 debug!(error = %e.chain(), "dropping inbound message during drain: identity verification failed");
216 continue;
217 }
218 match inner.on_message(msg).await {
219 Ok(output) => Self::dispatch_output(output, &settings).await,
220 Err(e) => {
221 tracing::error!(error = %e.chain(), "error processing message during draining - close immediately.");
222 break;
223 }
224 }
225 }
226
227 match inner.on_message(SessionMessage::StartDrain {
229 grace_period: shutdown_timeout
230 }).await {
231 Ok(output) => Self::dispatch_output(output, &settings).await,
232 Err(e) => {
233 tracing::error!(error = %e.chain(), "error during start drain");
234 break;
235 }
236 }
237
238 Self::enter_draining_state(&mut shutdown_deadline, &settings);
239
240 debug!("cancellation requested, entering draining state");
241 }
242 _ = &mut shutdown_deadline => {
243 debug!("graceful shutdown timeout reached, forcing exit");
244 break;
245 }
246 msg = rx.recv() => {
247 match msg {
248 Some(session_message) => {
249 if let SessionMessage::GetParticipantsList { tx } = session_message {
251 let participants_list = inner.participants_list();
252 let _ = tx.send(participants_list);
253 continue;
254 }
255
256 if let SessionMessage::OnMessage { message, direction: MessageDirection::North, .. } = &session_message
257 && let Err(e) = crate::session_controller::verify_identity(message, &settings.identity_verifier).await {
258 debug!(
259 error = %e.chain(),
260 msg_type = %message.get_session_message_type().as_str_name(),
261 msg_id = %message.get_id(),
262 "dropping inbound message: identity verification failed",
263 );
264 continue;
265 }
266
267 let draining = inner.processing_state() == ProcessingState::Draining;
268
269 if draining && matches!(session_message, SessionMessage::OnMessage { direction: MessageDirection::South, .. }) {
271 tracing::debug!("session is draining, rejecting new messages from application");
272 if let SessionMessage::OnMessage { ack_tx: Some(ack_tx), .. } = session_message {
273 let _ = ack_tx.send(Err(SessionError::SessionDrainingDrop));
274 }
275 continue;
276 }
277
278 match inner.on_message(session_message).await {
279 Ok(output) => {
280 Self::dispatch_output(output, &settings).await;
281 if !draining && inner.processing_state() == ProcessingState::Draining {
284 debug!("internal component requested draining, entering draining state");
285 Self::enter_draining_state(&mut shutdown_deadline, &settings);
286 }
287 }
288 Err(e) => {
289 debug!(
290 error=%e,
291 "Error processing message{}",
292 if draining { " during graceful shutdown" } else { "" }
293 );
294 if draining {
295 debug!("Exiting processing loop due to error while draining");
296 break;
297 }
298 }
299 }
300 }
301 None => {
302 debug!("Session channel closed, no more messages can arrive - exiting processing loop");
303 break;
304 }
305 }
306 }
307 }
308
309 if inner.processing_state() == ProcessingState::Draining && !inner.needs_drain() {
311 debug!("draining complete, exiting processing loop");
312 break;
313 }
314 }
315
316 if let Err(e) = inner.on_shutdown().await {
318 tracing::error!(error = %e.chain(), "error during shutdown of session");
319 }
320 }
321
322 pub fn id(&self) -> u32 {
324 self.id
325 }
326
327 pub fn source(&self) -> &ProtoName {
328 &self.source
329 }
330
331 pub fn dst(&self) -> &ProtoName {
332 &self.destination
333 }
334
335 pub fn session_type(&self) -> ProtoSessionType {
336 self.config.session_type
337 }
338
339 pub fn metadata(&self) -> HashMap<String, String> {
340 self.config.metadata.clone()
341 }
342
343 pub fn session_config(&self) -> SessionConfig {
344 self.config.clone()
345 }
346
347 pub fn is_initiator(&self) -> bool {
348 self.config.initiator
349 }
350
351 pub async fn participants_list(&self) -> Result<Vec<ProtoName>, SessionError> {
352 let (tx, rx) = oneshot::channel();
353
354 self.tx_controller
356 .send(SessionMessage::GetParticipantsList { tx })
357 .await
358 .map_err(|_| SessionError::ParticipantsListQueryFailed)?;
359
360 rx.await
362 .map_err(|_| SessionError::ParticipantsListQueryFailed)
363 }
364
365 async fn on_message(
366 &self,
367 message: Message,
368 direction: MessageDirection,
369 ack_tx: Option<oneshot::Sender<Result<(), SessionError>>>,
370 ) -> Result<(), SessionError> {
371 self.tx_controller
372 .send(SessionMessage::OnMessage {
373 message,
374 direction,
375 ack_tx,
376 })
377 .await
378 .map_err(|_e| SessionError::SessionControllerSendFailed)
379 }
380
381 pub async fn on_message_from_app(
383 &self,
384 message: Message,
385 ) -> Result<CompletionHandle, SessionError> {
386 let (ack_tx, ack_rx) = oneshot::channel();
387 self.on_message(message, MessageDirection::South, Some(ack_tx))
388 .await?;
389
390 let ret = CompletionHandle::from_oneshot_receiver(ack_rx);
391
392 Ok(ret)
393 }
394
395 pub async fn on_message_from_slim(&self, message: Message) -> Result<(), SessionError> {
397 self.on_message(message, MessageDirection::North, None)
398 .await
399 }
400
401 pub async fn on_error_message_from_slim(
403 &self,
404 error: SessionError,
405 ) -> Result<(), SessionError> {
406 self.tx_controller
407 .send(SessionMessage::MessageError { error })
408 .await
409 .map_err(|_e| SessionError::SessionControllerSendFailed)
410 }
411
412 pub fn close(&self) -> Result<tokio::task::JoinHandle<()>, SessionError> {
413 self.cancellation_token.cancel();
414
415 self.handle
416 .lock()
417 .take()
418 .ok_or(SessionError::SessionAlreadyClosed)
419 }
420
421 pub async fn publish_message(
422 &self,
423 message: Message,
424 ) -> Result<CompletionHandle, SessionError> {
425 self.on_message_from_app(message).await
426 }
427
428 pub async fn publish_to(
430 &self,
431 name: &ProtoName,
432 forward_to: u64,
433 blob: Vec<u8>,
434 payload_type: Option<String>,
435 metadata: Option<HashMap<String, String>>,
436 ) -> Result<CompletionHandle, SessionError> {
437 self.publish_with_flags(
438 name,
439 SlimHeaderFlags::default().with_forward_to(forward_to),
440 blob,
441 payload_type,
442 metadata,
443 )
444 .await
445 }
446
447 pub async fn publish(
449 &self,
450 name: &ProtoName,
451 blob: Vec<u8>,
452 payload_type: Option<String>,
453 metadata: Option<HashMap<String, String>>,
454 ) -> Result<CompletionHandle, SessionError> {
455 self.publish_with_flags(
456 name,
457 SlimHeaderFlags::default(),
458 blob,
459 payload_type,
460 metadata,
461 )
462 .await
463 }
464
465 pub async fn publish_with_flags(
467 &self,
468 name: &ProtoName,
469 flags: SlimHeaderFlags,
470 blob: Vec<u8>,
471 payload_type: Option<String>,
472 metadata: Option<HashMap<String, String>>,
473 ) -> Result<CompletionHandle, SessionError> {
474 let ct = payload_type.unwrap_or_else(|| "msg".to_string());
475
476 let mut msg = Message::builder()
477 .source(self.source().clone())
478 .destination(name.clone())
479 .identity("")
480 .flags(flags)
481 .session_type(self.session_type())
482 .session_message_type(ProtoSessionMessageType::Msg)
483 .session_id(self.id())
484 .message_id(rand::random::<u32>()) .application_payload(&ct, blob)
486 .build_publish()?;
487 if let Some(map) = metadata
488 && !map.is_empty()
489 {
490 msg.set_metadata_map(map);
491 }
492
493 self.publish_message(msg).await
495 }
496
497 fn create_discovery_request(&self, destination: &ProtoName) -> Result<Message, SessionError> {
499 let payload = CommandPayload::builder().discovery_request().as_content();
500
501 let msg = Message::builder()
502 .source(self.source().clone())
503 .destination(destination.clone())
504 .identity("")
505 .session_type(self.session_type())
506 .session_message_type(ProtoSessionMessageType::DiscoveryRequest)
507 .session_id(self.id())
508 .message_id(rand::random::<u32>())
509 .payload(payload)
510 .build_publish()?;
511
512 Ok(msg)
513 }
514
515 pub(crate) async fn invite_participant_internal(
516 &self,
517 destination: &ProtoName,
518 ) -> Result<CompletionHandle, SessionError> {
519 let msg = self.create_discovery_request(destination)?;
520 self.publish_message(msg).await
521 }
522
523 pub async fn invite_participant(
524 &self,
525 destination: &ProtoName,
526 ) -> Result<CompletionHandle, SessionError> {
527 match self.session_type() {
528 ProtoSessionType::PointToPoint => Err(SessionError::CannotInviteToP2P),
529 ProtoSessionType::Multicast => {
530 if !self.is_initiator() {
531 return Err(SessionError::NotInitiator);
532 }
533 self.invite_participant_internal(destination).await
534 }
535 _ => Err(SessionError::SessionTypeUnknown(self.session_type())),
536 }
537 }
538
539 pub async fn remove_participant(
540 &self,
541 destination: &ProtoName,
542 ) -> Result<CompletionHandle, SessionError> {
543 match self.session_type() {
544 ProtoSessionType::PointToPoint => Err(SessionError::CannotRemoveFromP2P),
545 ProtoSessionType::Multicast => {
546 if !self.is_initiator() {
547 return Err(SessionError::NotInitiator);
548 }
549 let msg = Message::builder()
550 .source(self.source().clone())
551 .destination(destination.clone().with_id(NameId::NULL_COMPONENT))
552 .identity("")
553 .session_type(ProtoSessionType::Multicast)
554 .session_message_type(ProtoSessionMessageType::LeaveRequest)
555 .session_id(self.id())
556 .message_id(rand::random::<u32>())
557 .payload(CommandPayload::builder().leave_request().as_content())
558 .build_publish()?;
559 self.publish_message(msg).await
560 }
561 _ => Err(SessionError::SessionTypeUnknown(self.session_type())),
562 }
563 }
564}
565
566impl Drop for SessionController {
567 fn drop(&mut self) {
568 self.cancellation_token.cancel();
569 }
570}
571
572pub fn handle_channel_discovery_message(
573 message: &Message,
574 app_name: &ProtoName,
575 session_id: u32,
576 session_type: ProtoSessionType,
577) -> Result<Message, SessionError> {
578 let destination = message.get_slim_header().source.clone().unwrap();
579
580 let mut source = message.get_slim_header().destination.clone().unwrap();
585 source.set_id(app_name.id());
586 let msg_id = message.get_id();
587
588 let slim_header = SlimHeader::new(
589 source,
590 destination,
591 "",
592 Some(SlimHeaderFlags::default().with_forward_to(message.get_incoming_conn())),
593 );
594
595 let msg = Message::builder()
596 .with_slim_header(slim_header)
597 .session_type(session_type)
598 .session_message_type(ProtoSessionMessageType::DiscoveryReply)
599 .session_id(session_id)
600 .message_id(msg_id)
601 .payload(CommandPayload::builder().discovery_reply().as_content())
602 .build_publish()?;
603
604 Ok(msg)
605}
606
/// State and helpers shared by session controller implementations.
pub(crate) struct SessionControllerCommon<
    P,
    V,
    M = crate::subscription_manager::SubscriptionManager,
> where
    P: TokenProvider + Send + Sync + Clone + 'static,
    V: Verifier + Send + Sync + Clone + 'static,
    M: crate::subscription_manager::SubscriptionOps,
{
    /// Settings the session was created with (names, config, channels, subscription manager).
    pub(crate) settings: SessionSettings<P, V, M>,

    /// Sender through which `send_with_timer` routes outgoing messages.
    pub(crate) sender: ControllerSender,

    /// Current processing state; starts as `Active`.
    pub(crate) processing_state: ProcessingState,

    /// Ids returned by the subscription manager, keyed by kind, name and connection, so the
    /// matching route/subscription can be removed later.
    subscription_ids: HashMap<(SubscriptionKind, ProtoName, u64), u64>,
}
628
/// Kind of entry tracked in `SessionControllerCommon::subscription_ids`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum SubscriptionKind {
    /// Entry created by `add_route` (subscription manager `set_route`).
    Route,
    /// Entry created by `add_subscription` (subscription manager `subscribe`).
    Subscription,
}
638
639impl<P, V, M> SessionControllerCommon<P, V, M>
640where
641 P: TokenProvider + Send + Sync + Clone + 'static,
642 V: Verifier + Send + Sync + Clone + 'static,
643 M: crate::subscription_manager::SubscriptionOps,
644{
645 pub(crate) fn new(settings: SessionSettings<P, V, M>) -> Self {
646 let controller_sender = ControllerSender::new(
648 settings.config.get_timer_settings(),
649 settings.source.clone(),
650 settings.config.session_type,
651 settings.id,
652 Some(PING_INTERVAL),
653 settings.config.initiator,
654 settings.tx_session.clone(),
655 );
656
657 SessionControllerCommon {
658 settings,
659 sender: controller_sender,
660 processing_state: ProcessingState::Active,
661 subscription_ids: HashMap::new(),
662 }
663 }
664
665 pub(crate) fn send_with_timer(
667 &mut self,
668 message: Message,
669 ) -> Result<SessionOutput, SessionError> {
670 self.sender.on_message(&message)
671 }
672
673 async fn await_subscription_ack(
674 rx: tokio::sync::oneshot::Receiver<
675 Result<(), crate::subscription_manager::SubscriptionAckError>,
676 >,
677 ) -> Result<(), SessionError> {
678 crate::subscription_manager::SubscriptionManager::await_ack(rx)
679 .await
680 .map_err(SessionError::SubscriptionAckFailed)
681 }
682
683 pub(crate) async fn add_route(
684 &mut self,
685 name: ProtoName,
686 conn: u64,
687 ) -> Result<(), SessionError> {
688 if name == self.settings.source.clone() {
689 return Ok(());
691 }
692
693 let source_proto = self.settings.source.clone();
694 let (subscription_id, rx) = self
695 .settings
696 .subscription_manager
697 .set_route(&source_proto, &name, conn)
698 .await
699 .map_err(SessionError::SubscriptionAckFailed)?;
700 Self::await_subscription_ack(rx).await?;
701
702 debug!(%name, %conn, %subscription_id, source = %self.settings.source, "route added");
703
704 self.subscription_ids
705 .insert((SubscriptionKind::Route, name, conn), subscription_id);
706
707 Ok(())
708 }
709
710 pub(crate) async fn delete_route(
711 &mut self,
712 name: ProtoName,
713 conn: u64,
714 ) -> Result<(), SessionError> {
715 if name == self.settings.source.clone() {
716 return Ok(());
718 }
719
720 let key = (SubscriptionKind::Route, name, conn);
721 let subscription_id = self.subscription_ids.remove(&key);
722 let (_, name, conn) = key;
723 match subscription_id {
724 Some(subscription_id) => {
725 let source_proto = self.settings.source.clone();
726 let rx = self
727 .settings
728 .subscription_manager
729 .remove_route(&source_proto, &name, subscription_id, conn)
730 .await
731 .map_err(SessionError::SubscriptionAckFailed)?;
732
733 Self::await_subscription_ack(rx).await?;
734 tracing::debug!(%name, %conn, %subscription_id, "route deleted");
735 }
736 None => {
737 tracing::warn!(
738 %name, %conn, io = %self.settings.source,
739 "no subscription_id found for route, skipping delete"
740 );
741 }
742 }
743
744 Ok(())
745 }
746
747 pub(crate) async fn add_subscription(
748 &mut self,
749 name: ProtoName,
750 conn: u64,
751 ) -> Result<(), SessionError> {
752 let source_proto = self.settings.source.clone();
753 let (subscription_id, rx) = self
754 .settings
755 .subscription_manager
756 .subscribe(&source_proto, &name, Some(conn))
757 .await
758 .map_err(SessionError::SubscriptionAckFailed)?;
759
760 Self::await_subscription_ack(rx).await?;
761
762 debug!(%name, %conn, %subscription_id, "subscription added");
763
764 self.subscription_ids.insert(
765 (SubscriptionKind::Subscription, name, conn),
766 subscription_id,
767 );
768
769 Ok(())
770 }
771
772 pub(crate) async fn delete_subscription(
773 &mut self,
774 name: ProtoName,
775 conn: u64,
776 ) -> Result<(), SessionError> {
777 let key = (SubscriptionKind::Subscription, name, conn);
778 let subscription_id = self.subscription_ids.remove(&key);
779 let (_, name, conn) = key;
780 match subscription_id {
781 Some(subscription_id) => {
782 let source_proto = self.settings.source.clone();
783 let rx = self
784 .settings
785 .subscription_manager
786 .unsubscribe(&source_proto, &name, subscription_id, Some(conn))
787 .await
788 .map_err(SessionError::SubscriptionAckFailed)?;
789
790 Self::await_subscription_ack(rx).await?;
791 debug!(%name, %conn, %subscription_id, "subscription deleted");
792 }
793 None => {
794 tracing::debug!(
795 %name, %conn,
796 "no subscription_id found for subscription, skipping delete"
797 );
798 }
799 }
800
801 Ok(())
802 }
803
804 pub(crate) fn create_control_message(
805 &mut self,
806 dst: &ProtoName,
807 message_type: ProtoSessionMessageType,
808 message_id: u32,
809 payload: Content,
810 broadcast: bool,
811 ) -> Result<Message, SessionError> {
812 let mut builder = Message::builder()
813 .source(self.settings.source.clone())
814 .destination(dst.clone())
815 .identity("")
816 .session_type(self.settings.config.session_type)
817 .session_message_type(message_type)
818 .session_id(self.settings.id)
819 .message_id(message_id)
820 .payload(payload);
821
822 if broadcast {
823 builder = builder.fanout(256);
824 }
825
826 let ret = builder.build_publish()?;
827
828 Ok(ret)
829 }
830
831 pub(crate) fn send_control_message(
833 &mut self,
834 dst: &ProtoName,
835 message_type: ProtoSessionMessageType,
836 message_id: u32,
837 payload: Content,
838 metadata: Option<HashMap<String, String>>,
839 broadcast: bool,
840 ) -> Result<SessionOutput, SessionError> {
841 let mut msg =
842 self.create_control_message(dst, message_type, message_id, payload, broadcast)?;
843 if let Some(m) = metadata {
844 msg.set_metadata_map(m);
845 }
846 self.send_with_timer(msg)
847 }
848}
849
850#[cfg(test)]
851mod tests {
852 use super::*;
853
854 use crate::Direction;
861 use crate::session_config::MlsSettings;
862 use crate::subscription_manager::{SpySubscriptionManager, SubscriptionCall};
863 use slim_auth::shared_secret::SharedSecret;
864
865 use std::sync::Arc;
866 use std::sync::atomic::AtomicBool;
867 use std::time::Duration;
868 use tokio::time::timeout;
869 use tracing_test::traced_test;
870
871 const SHARED_SECRET: &str = "kjandjansdiasb8udaijdniasdaindasndasndasndasndasndasndasndas";
872
873 fn test_identity() -> String {
874 SharedSecret::new("test", SHARED_SECRET)
875 .unwrap()
876 .get_token()
877 .unwrap()
878 }
879
880 struct SessionControllerTestBuilder {
882 session_id: u32,
883 source: ProtoName,
884 destination: ProtoName,
885 session_type: ProtoSessionType,
886 mls_settings: Option<MlsSettings>,
887 initiator: bool,
888 max_retries: Option<u32>,
889 interval: Option<Duration>,
890 metadata: HashMap<String, String>,
891 graceful_shutdown_timeout: Option<Duration>,
892 }
893
894 impl SessionControllerTestBuilder {
895 #[allow(dead_code)]
896 fn new() -> Self {
897 Self {
898 session_id: 10,
899 source: ProtoName::from_strings(["org", "ns", "source"]).with_id(1),
900 destination: ProtoName::from_strings(["org", "ns", "dest"]).with_id(2),
901 session_type: ProtoSessionType::PointToPoint,
902 mls_settings: None,
903 initiator: true,
904 max_retries: Some(5),
905 interval: Some(Duration::from_millis(200)),
906 metadata: HashMap::new(),
907 graceful_shutdown_timeout: Some(Duration::from_secs(10)),
908 }
909 }
910
911 fn with_session_id(mut self, id: u32) -> Self {
912 self.session_id = id;
913 self
914 }
915
916 #[allow(dead_code)]
917 fn with_source(mut self, source: ProtoName) -> Self {
918 self.source = source;
919 self
920 }
921
922 #[allow(dead_code)]
923 fn with_destination(mut self, destination: ProtoName) -> Self {
924 self.destination = destination;
925 self
926 }
927
928 fn with_session_type(mut self, session_type: ProtoSessionType) -> Self {
929 self.session_type = session_type;
930 self
931 }
932
933 fn with_mls_enabled(mut self, enabled: bool) -> Self {
934 self.mls_settings = if enabled {
935 Some(MlsSettings::default())
936 } else {
937 None
938 };
939 self
940 }
941
942 fn with_initiator(mut self, initiator: bool) -> Self {
943 self.initiator = initiator;
944 self
945 }
946
947 fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
948 self.metadata = metadata;
949 self
950 }
951
952 fn with_graceful_shutdown_timeout(mut self, timeout: Duration) -> Self {
953 self.graceful_shutdown_timeout = Some(timeout);
954 self
955 }
956
957 fn build(
958 self,
959 ) -> (
960 SessionController,
961 tokio::sync::mpsc::Receiver<Result<Message, slim_datapath::Status>>,
962 tokio::sync::mpsc::UnboundedReceiver<Result<Message, SessionError>>,
963 ) {
964 let config = SessionConfig {
965 session_type: self.session_type,
966 max_retries: self.max_retries,
967 interval: self.interval,
968 mls_settings: self.mls_settings,
969 initiator: self.initiator,
970 metadata: self.metadata,
971 };
972
973 let (tx_slim, rx_slim) = tokio::sync::mpsc::channel(10);
974 let (tx_app, rx_app) = tokio::sync::mpsc::unbounded_channel();
975 let (tx_session_layer, _rx_session_layer) = tokio::sync::mpsc::channel(10);
976
977 let controller = SessionController::builder()
978 .with_id(self.session_id)
979 .with_source(self.source.clone())
980 .with_destination(self.destination.clone())
981 .with_config(config)
982 .with_identity_provider(SharedSecret::new("test", SHARED_SECRET).unwrap())
983 .with_identity_verifier(SharedSecret::new("test", SHARED_SECRET).unwrap())
984 .with_slim_tx(tx_slim)
985 .with_app_tx(tx_app)
986 .with_tx_to_session_layer(tx_session_layer)
987 .ready()
988 .expect("failed to validate builder")
989 .build()
990 .expect("failed to build controller");
991
992 (controller, rx_slim, rx_app)
993 }
994 }
995
996 #[tokio::test]
997 async fn test_session_controller_getters() {
998 let mut metadata = HashMap::new();
999 metadata.insert("key1".to_string(), "value1".to_string());
1000
1001 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1002 .with_session_id(42)
1003 .with_session_type(ProtoSessionType::Multicast)
1004 .with_mls_enabled(true)
1005 .with_metadata(metadata)
1006 .build();
1007
1008 assert_eq!(controller.id(), 42);
1009 assert_eq!(
1010 controller.source(),
1011 &ProtoName::from_strings(["org", "ns", "source"]).with_id(1)
1012 );
1013 assert_eq!(
1015 controller.dst(),
1016 &ProtoName::from_strings(["org", "ns", "dest"]).with_id(NameId::DATA_CHANNEL_ID)
1017 );
1018 assert_eq!(controller.session_type(), ProtoSessionType::Multicast);
1019 assert!(controller.is_initiator());
1020 assert_eq!(
1021 controller.metadata().get("key1"),
1022 Some(&"value1".to_string())
1023 );
1024
1025 let retrieved_config = controller.session_config();
1026 assert_eq!(retrieved_config.session_type, ProtoSessionType::Multicast);
1027 assert_eq!(retrieved_config.max_retries, Some(5));
1028 }
1029
1030 #[tokio::test]
1031 async fn test_publish_basic() {
1032 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
1033
1034 let target_name = ProtoName::from_strings(["org", "ns", "target"]);
1035 let payload = b"Hello World".to_vec();
1036
1037 controller
1038 .publish(
1039 &target_name,
1040 payload.clone(),
1041 Some("test-type".to_string()),
1042 None,
1043 )
1044 .await
1045 .expect("publish should succeed");
1046
1047 tokio::time::sleep(Duration::from_millis(50)).await;
1048 }
1049
1050 #[tokio::test]
1051 async fn test_publish_to_specific_connection() {
1052 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1053 .with_session_type(ProtoSessionType::Multicast)
1054 .build();
1055
1056 let target_name = ProtoName::from_strings(["org", "ns", "target"]);
1057 let payload = b"Hello to specific connection".to_vec();
1058 let connection_id = 123u64;
1059
1060 controller
1061 .publish_to(
1062 &target_name,
1063 connection_id,
1064 payload.clone(),
1065 Some("test-type".to_string()),
1066 None,
1067 )
1068 .await
1069 .expect("publish_to should succeed");
1070
1071 tokio::time::sleep(Duration::from_millis(50)).await;
1072 }
1073
1074 #[tokio::test]
1075 async fn test_publish_with_metadata() {
1076 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1077 .with_session_type(ProtoSessionType::Multicast)
1078 .build();
1079
1080 let target_name = ProtoName::from_strings(["org", "ns", "target"]);
1081 let payload = b"Hello with metadata".to_vec();
1082
1083 let mut metadata = HashMap::new();
1084 metadata.insert("custom_key".to_string(), "custom_value".to_string());
1085
1086 controller
1087 .publish(
1088 &target_name,
1089 payload.clone(),
1090 Some("test-type".to_string()),
1091 Some(metadata),
1092 )
1093 .await
1094 .expect("publish with metadata should succeed");
1095
1096 tokio::time::sleep(Duration::from_millis(50)).await;
1097 }
1098
1099 #[tokio::test]
1100 async fn test_invite_participant_in_multicast() {
1101 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1102 .with_session_type(ProtoSessionType::Multicast)
1103 .build();
1104
1105 let participant = ProtoName::from_strings(["org", "ns", "participant"]);
1106
1107 controller
1108 .invite_participant(&participant)
1109 .await
1110 .expect("invite should succeed");
1111
1112 tokio::time::sleep(Duration::from_millis(50)).await;
1113 }
1114
1115 #[tokio::test]
1116 async fn test_invite_participant_not_initiator_error() {
1117 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1118 .with_session_type(ProtoSessionType::Multicast)
1119 .with_initiator(false)
1120 .build();
1121
1122 let participant = ProtoName::from_strings(["org", "ns", "new_participant"]);
1123
1124 let result = controller.invite_participant(&participant).await;
1125 assert!(result.is_err_and(|e| matches!(e, SessionError::NotInitiator)));
1126 }
1127
1128 #[tokio::test]
1129 async fn test_invite_participant_p2p_error() {
1130 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1131 .with_session_type(ProtoSessionType::PointToPoint)
1132 .build();
1133
1134 let participant = ProtoName::from_strings(["org", "ns", "participant"]);
1135
1136 let result = controller.invite_participant(&participant).await;
1137 assert!(result.is_err_and(|e| matches!(e, SessionError::CannotInviteToP2P)));
1138 }
1139
1140 #[tokio::test]
1141 async fn test_remove_participant_in_multicast() {
1142 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1143 .with_session_type(ProtoSessionType::Multicast)
1144 .build();
1145
1146 let participant = ProtoName::from_strings(["org", "ns", "participant"]);
1147
1148 controller
1149 .remove_participant(&participant)
1150 .await
1151 .expect("remove should succeed");
1152
1153 tokio::time::sleep(Duration::from_millis(50)).await;
1154 }
1155
1156 #[tokio::test]
1157 async fn test_remove_participant_not_initiator_error() {
1158 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1159 .with_session_type(ProtoSessionType::Multicast)
1160 .with_initiator(false)
1161 .build();
1162
1163 let participant = ProtoName::from_strings(["org", "ns", "participant"]);
1164
1165 let result = controller.remove_participant(&participant).await;
1166 assert!(result.is_err_and(|e| matches!(e, SessionError::NotInitiator)));
1167 }
1168
1169 #[tokio::test]
1170 async fn test_remove_participant_p2p_error() {
1171 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1172 .with_session_type(ProtoSessionType::PointToPoint)
1173 .build();
1174
1175 let participant = ProtoName::from_strings(["org", "ns", "participant"]);
1176
1177 let result = controller.remove_participant(&participant).await;
1178 assert!(result.is_err_and(|e| matches!(e, SessionError::CannotRemoveFromP2P)));
1179 }
1180
1181 #[test]
1182 fn test_handle_channel_discovery_message() {
1183 let app_name = ProtoName::from_strings(["org", "ns", "app"]).with_id(100);
1184 let session_id = 42;
1185
1186 let discovery_request = Message::builder()
1187 .source(ProtoName::from_strings(["org", "ns", "requester"]).with_id(1))
1188 .destination(ProtoName::from_strings(["org", "ns", "service"]))
1189 .identity(test_identity())
1190 .incoming_conn(999)
1191 .session_type(ProtoSessionType::Multicast)
1192 .session_message_type(ProtoSessionMessageType::DiscoveryRequest)
1193 .session_id(session_id)
1194 .message_id(123)
1195 .payload(CommandPayload::builder().discovery_request().as_content())
1196 .build_publish()
1197 .unwrap();
1198
1199 let response = handle_channel_discovery_message(
1200 &discovery_request,
1201 &app_name,
1202 session_id,
1203 ProtoSessionType::Multicast,
1204 )
1205 .expect("should create discovery response");
1206
1207 assert_eq!(
1208 response.get_session_message_type(),
1209 ProtoSessionMessageType::DiscoveryReply
1210 );
1211 assert_eq!(response.get_session_header().get_session_id(), session_id);
1212 assert_eq!(response.get_id(), 123);
1213 assert_eq!(
1214 response.get_dst(),
1215 ProtoName::from_strings(["org", "ns", "requester"]).with_id(1)
1216 );
1217 assert_eq!(response.get_slim_header().get_forward_to(), Some(999));
1218 }
1219
1220 #[tokio::test]
1221 async fn test_controller_drop_cancels_processing() {
1222 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
1223
1224 let token = controller.cancellation_token.clone();
1225 assert!(!token.is_cancelled());
1226
1227 drop(controller);
1228
1229 tokio::time::sleep(Duration::from_millis(100)).await;
1230 assert!(token.is_cancelled());
1231 }
1232
1233 #[tokio::test]
1234 async fn test_close_success() {
1235 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1236 .with_graceful_shutdown_timeout(std::time::Duration::from_secs(2))
1237 .build();
1238
1239 let token = controller.cancellation_token.clone();
1240 assert!(!token.is_cancelled());
1241
1242 let handle = controller.close();
1243 assert!(handle.is_ok(), "got error {}", handle.unwrap_err());
1244 assert!(token.is_cancelled());
1245
1246 handle
1248 .unwrap()
1249 .await
1250 .expect("processing task should complete");
1251 }
1252
1253 #[tokio::test]
1254 async fn test_close_already_closed() {
1255 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
1256
1257 let handle = controller.close();
1259 assert!(handle.is_ok());
1260 handle
1261 .unwrap()
1262 .await
1263 .expect("processing task should complete");
1264
1265 let result = controller.close();
1267 assert!(result.is_err());
1268 match result {
1269 Err(SessionError::SessionAlreadyClosed) => {
1270 }
1272 _ => panic!("Expected SessionError::SessionAlreadyClosed"),
1273 }
1274 }
1275
1276 #[tokio::test]
1277 async fn test_close_cancels_token_immediately() {
1278 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
1279
1280 let token = controller.cancellation_token.clone();
1281
1282 assert!(!token.is_cancelled());
1284
1285 let handle = controller.close();
1287 assert!(handle.is_ok());
1288
1289 assert!(token.is_cancelled());
1291
1292 handle.unwrap().await.expect("processing should complete");
1294 }
1295
1296 #[tokio::test]
1297 async fn test_on_message_direction_north() {
1298 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
1299
1300 let test_message = Message::builder()
1301 .source(controller.dst().clone())
1302 .destination(controller.source().clone())
1303 .identity(test_identity())
1304 .session_type(ProtoSessionType::PointToPoint)
1305 .session_message_type(ProtoSessionMessageType::Msg)
1306 .session_id(controller.id())
1307 .message_id(1)
1308 .application_payload("test", b"test data".to_vec())
1309 .build_publish()
1310 .unwrap();
1311
1312 let result = controller.on_message_from_slim(test_message).await;
1313 assert!(result.is_ok());
1314 }
1315
1316 #[tokio::test]
1317 async fn test_create_discovery_request() {
1318 let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1319 .with_session_type(ProtoSessionType::Multicast)
1320 .build();
1321
1322 let target = ProtoName::from_strings(["org", "ns", "target"]);
1323 let discovery_msg = controller
1324 .create_discovery_request(&target)
1325 .expect("should create discovery request");
1326
1327 assert_eq!(discovery_msg.get_source(), *controller.source());
1328 assert_eq!(discovery_msg.get_dst(), target);
1329 assert_eq!(
1330 discovery_msg.get_session_message_type(),
1331 ProtoSessionMessageType::DiscoveryRequest
1332 );
1333 assert_eq!(
1334 discovery_msg.get_session_header().get_session_id(),
1335 controller.id()
1336 );
1337 assert_eq!(
1338 discovery_msg.get_session_type(),
1339 ProtoSessionType::Multicast
1340 );
1341 }
1342
1343 #[tokio::test]
1344 #[traced_test]
1345 async fn test_end_to_end_p2p() {
1346 let session_id = 10;
1347 let moderator_name = ProtoName::from_strings(["org", "ns", "moderator"]).with_id(1);
1348 let participant_name = ProtoName::from_strings(["org", "ns", "participant"]);
1349 let participant_name_id = ProtoName::from_strings(["org", "ns", "participant"]).with_id(1);
1350 let (tx_slim_moderator, mut rx_slim_moderator) = tokio::sync::mpsc::channel(10);
1352 let (tx_app_moderator, _rx_app_moderator) = tokio::sync::mpsc::unbounded_channel();
1353 let (tx_session_layer_moderator, _rx_session_layer_moderator) =
1354 tokio::sync::mpsc::channel(10);
1355
1356 let moderator_config = SessionConfig {
1357 session_type: slim_datapath::api::ProtoSessionType::PointToPoint,
1358 max_retries: Some(5),
1359 interval: Some(Duration::from_millis(1000)),
1360 mls_settings: Some(MlsSettings::default()),
1361 initiator: true,
1362 metadata: std::collections::HashMap::new(),
1363 };
1364
1365 let (spy_moderator_mgr, mut rx_spy_moderator) = SpySubscriptionManager::new();
1366 let moderator = SessionController::builder()
1367 .with_id(session_id)
1368 .with_source(moderator_name.clone())
1369 .with_destination(participant_name.clone())
1370 .with_config(moderator_config)
1371 .with_identity_provider(SharedSecret::new("moderator", SHARED_SECRET).unwrap())
1372 .with_identity_verifier(SharedSecret::new("moderator", SHARED_SECRET).unwrap())
1373 .with_slim_tx(tx_slim_moderator.clone())
1374 .with_app_tx(tx_app_moderator.clone())
1375 .with_tx_to_session_layer(tx_session_layer_moderator)
1376 .with_subscription_manager(spy_moderator_mgr)
1377 .ready()
1378 .expect("failed to validate builder")
1379 .build()
1380 .unwrap();
1381
1382 let (tx_slim_participant, mut rx_slim_participant) = tokio::sync::mpsc::channel(10);
1384 let (tx_app_participant, mut rx_app_participant) = tokio::sync::mpsc::unbounded_channel();
1385 let (tx_session_layer_participant, _rx_session_layer_participant) =
1386 tokio::sync::mpsc::channel(10);
1387
1388 let participant_config = SessionConfig {
1389 session_type: slim_datapath::api::ProtoSessionType::PointToPoint,
1390 max_retries: Some(5),
1391 interval: Some(Duration::from_millis(200)),
1392 mls_settings: Some(MlsSettings::default()),
1393 initiator: false,
1394 metadata: std::collections::HashMap::new(),
1395 };
1396
1397 let (spy_participant_mgr, mut rx_spy_participant) = SpySubscriptionManager::new();
1398 let participant = SessionController::builder()
1399 .with_id(session_id)
1400 .with_source(participant_name_id.clone())
1401 .with_destination(moderator_name.clone())
1402 .with_config(participant_config)
1403 .with_identity_provider(SharedSecret::new("participant", SHARED_SECRET).unwrap())
1404 .with_identity_verifier(SharedSecret::new("participant", SHARED_SECRET).unwrap())
1405 .with_slim_tx(tx_slim_participant.clone())
1406 .with_app_tx(tx_app_participant.clone())
1407 .with_tx_to_session_layer(tx_session_layer_participant)
1408 .with_subscription_manager(spy_participant_mgr)
1409 .ready()
1410 .expect("failed to validate builder")
1411 .build()
1412 .unwrap();
1413
1414 let completion_handle = moderator
1415 .invite_participant_internal(&participant_name)
1416 .await
1417 .expect("error inviting participant");
1418
1419 let received_discovery_request =
1420 timeout(Duration::from_millis(100), rx_slim_moderator.recv())
1421 .await
1422 .expect("timeout waiting for discovery request on moderator slim channel")
1423 .expect("channel closed")
1424 .expect("error in discovery request");
1425
1426 assert_eq!(
1427 received_discovery_request.get_session_message_type(),
1428 slim_datapath::api::ProtoSessionMessageType::DiscoveryRequest
1429 );
1430
1431 let discovery_msg_id = received_discovery_request.get_id();
1432
1433 let mut discovery_reply = Message::builder()
1435 .source(participant_name_id.clone())
1436 .destination(moderator_name.clone())
1437 .identity(test_identity())
1438 .forward_to(1)
1439 .session_type(slim_datapath::api::ProtoSessionType::PointToPoint)
1440 .session_message_type(slim_datapath::api::ProtoSessionMessageType::DiscoveryReply)
1441 .session_id(session_id)
1442 .message_id(discovery_msg_id)
1443 .payload(CommandPayload::builder().discovery_reply().as_content())
1444 .build_publish()
1445 .unwrap();
1446 discovery_reply
1447 .get_slim_header_mut()
1448 .set_incoming_conn(Some(1));
1449
1450 moderator
1451 .on_message_from_slim(discovery_reply)
1452 .await
1453 .expect("error processing discovery reply on moderator");
1454
1455 assert_eq!(
1457 rx_spy_moderator.recv().await,
1458 Some(SubscriptionCall::SetRoute),
1459 "moderator should set route after discovery reply"
1460 );
1461
1462 let join_request = timeout(Duration::from_millis(100), rx_slim_moderator.recv())
1464 .await
1465 .expect("timeout waiting for join request on moderator slim channel")
1466 .expect("channel closed")
1467 .expect("error in join request");
1468
1469 assert_eq!(
1470 join_request.get_session_message_type(),
1471 slim_datapath::api::ProtoSessionMessageType::JoinRequest
1472 );
1473 assert_eq!(join_request.get_dst(), participant_name_id);
1474
1475 let mut join_request_to_participant = join_request.clone();
1477 join_request_to_participant
1478 .get_slim_header_mut()
1479 .set_incoming_conn(Some(1));
1480
1481 participant
1482 .on_message_from_slim(join_request_to_participant)
1483 .await
1484 .expect("error processing join request on participant");
1485
1486 assert_eq!(
1488 rx_spy_participant.recv().await,
1489 Some(SubscriptionCall::SetRoute),
1490 "participant should set route after join request"
1491 );
1492
1493 let join_reply = timeout(Duration::from_millis(100), rx_slim_participant.recv())
1495 .await
1496 .expect("timeout waiting for join reply on participant slim channel")
1497 .expect("channel closed")
1498 .expect("error in join reply");
1499
1500 assert_eq!(
1501 join_reply.get_session_message_type(),
1502 slim_datapath::api::ProtoSessionMessageType::JoinReply
1503 );
1504 assert_eq!(join_reply.get_dst(), moderator_name);
1505
1506 let mut join_reply_to_moderator = join_reply.clone();
1508 join_reply_to_moderator
1509 .get_slim_header_mut()
1510 .set_incoming_conn(Some(1));
1511
1512 moderator
1513 .on_message_from_slim(join_reply_to_moderator)
1514 .await
1515 .expect("error processing join reply on moderator");
1516
1517 let welcome_message = timeout(Duration::from_millis(100), rx_slim_moderator.recv())
1519 .await
1520 .expect("timeout waiting for welcome message on moderator slim channel")
1521 .expect("channel closed")
1522 .expect("error in welcome message");
1523
1524 assert_eq!(
1525 welcome_message.get_session_message_type(),
1526 slim_datapath::api::ProtoSessionMessageType::GroupWelcome
1527 );
1528 assert_eq!(welcome_message.get_dst(), participant_name_id);
1529
1530 let mut welcome_to_participant = welcome_message.clone();
1532 welcome_to_participant
1533 .get_slim_header_mut()
1534 .set_incoming_conn(Some(1));
1535
1536 participant
1537 .on_message_from_slim(welcome_to_participant)
1538 .await
1539 .expect("error processing welcome message on participant");
1540
1541 let ack_group = timeout(Duration::from_millis(100), rx_slim_participant.recv())
1543 .await
1544 .expect("timeout waiting for ack group on participant slim channel")
1545 .expect("channel closed")
1546 .expect("error in ack group");
1547
1548 assert_eq!(
1549 ack_group.get_session_message_type(),
1550 slim_datapath::api::ProtoSessionMessageType::GroupAck
1551 );
1552 assert_eq!(ack_group.get_dst(), moderator_name);
1553
1554 let mut ack_to_moderator = ack_group.clone();
1556 ack_to_moderator
1557 .get_slim_header_mut()
1558 .set_incoming_conn(Some(1));
1559
1560 moderator
1561 .on_message_from_slim(ack_to_moderator)
1562 .await
1563 .expect("error processing ack group on moderator");
1564
1565 let no_more_moderator = timeout(Duration::from_millis(100), rx_slim_moderator.recv()).await;
1567 assert!(
1568 no_more_moderator.is_err(),
1569 "Expected no more messages on moderator slim channel, received {:?}",
1570 no_more_moderator
1571 .ok()
1572 .and_then(|opt| opt)
1573 .and_then(|res| res.ok())
1574 );
1575
1576 let no_more_participant =
1577 timeout(Duration::from_millis(100), rx_slim_participant.recv()).await;
1578 assert!(
1579 no_more_participant.is_err(),
1580 "Expected no more messages on participant slim channel"
1581 );
1582
1583 completion_handle.await.expect("error in completion handle");
1585
1586 let app_data = b"Hello from moderator to participant".to_vec();
1588 let app_message = Message::builder()
1589 .source(moderator_name.clone())
1590 .destination(participant_name.clone())
1591 .identity(test_identity())
1592 .session_type(slim_datapath::api::ProtoSessionType::PointToPoint)
1593 .session_message_type(slim_datapath::api::ProtoSessionMessageType::Msg)
1594 .session_id(session_id)
1595 .message_id(1)
1596 .application_payload("test-app-data", app_data.clone())
1597 .build_publish()
1598 .unwrap();
1599
1600 moderator
1602 .on_message_from_app(app_message)
1603 .await
1604 .expect("error sending application message from moderator");
1605
1606 let app_msg_to_slim = timeout(Duration::from_millis(100), rx_slim_moderator.recv())
1608 .await
1609 .expect("timeout waiting for application message on moderator slim channel")
1610 .expect("channel closed")
1611 .expect("error in application message");
1612
1613 assert_eq!(app_msg_to_slim.get_dst(), participant_name_id);
1614 assert!(
1615 app_msg_to_slim.is_publish(),
1616 "message should be a publish message"
1617 );
1618
1619 let app_msg_id = app_msg_to_slim.get_id();
1620
1621 let mut app_msg_to_participant = app_msg_to_slim.clone();
1623 app_msg_to_participant
1624 .get_slim_header_mut()
1625 .set_incoming_conn(Some(1));
1626
1627 participant
1628 .on_message_from_slim(app_msg_to_participant)
1629 .await
1630 .expect("error processing application message on participant");
1631
1632 let app_msg_received = timeout(Duration::from_millis(100), rx_app_participant.recv())
1634 .await
1635 .expect("timeout waiting for application message on participant app channel")
1636 .expect("channel closed")
1637 .expect("error in application message to app");
1638
1639 assert_eq!(app_msg_received.get_source(), moderator_name);
1640 assert!(
1641 app_msg_received.is_publish(),
1642 "message should be a publish message"
1643 );
1644
1645 let content = app_msg_received
1646 .get_payload()
1647 .unwrap()
1648 .as_application_payload()
1649 .unwrap()
1650 .blob
1651 .clone();
1652 assert_eq!(content, app_data);
1653
1654 let ack_msg = timeout(Duration::from_millis(100), rx_slim_participant.recv())
1656 .await
1657 .expect("timeout waiting for ack on participant slim channel")
1658 .expect("channel closed")
1659 .expect("error in ack");
1660
1661 assert_eq!(
1662 ack_msg.get_session_message_type(),
1663 slim_datapath::api::ProtoSessionMessageType::MsgAck,
1664 "message should be an ack"
1665 );
1666 assert_eq!(ack_msg.get_dst(), moderator_name);
1667 assert_eq!(ack_msg.get_id(), app_msg_id);
1668
1669 let mut ack_to_moderator = ack_msg.clone();
1671 ack_to_moderator
1672 .get_slim_header_mut()
1673 .set_incoming_conn(Some(1));
1674
1675 moderator
1676 .on_message_from_slim(ack_to_moderator)
1677 .await
1678 .expect("error processing ack on moderator");
1679
1680 let no_more_moderator_after_ack =
1682 timeout(Duration::from_millis(100), rx_slim_moderator.recv()).await;
1683 assert!(
1684 no_more_moderator_after_ack.is_err(),
1685 "Expected no more messages on moderator slim channel after ack"
1686 );
1687
1688 let no_more_participant_after_ack =
1689 timeout(Duration::from_millis(100), rx_slim_participant.recv()).await;
1690 assert!(
1691 no_more_participant_after_ack.is_err(),
1692 "Expected no more messages on participant slim channel after ack"
1693 );
1694
1695 let leave_request = Message::builder()
1697 .source(moderator_name.clone())
1698 .destination(participant_name.clone())
1699 .identity(test_identity())
1700 .session_type(slim_datapath::api::ProtoSessionType::PointToPoint)
1701 .session_message_type(slim_datapath::api::ProtoSessionMessageType::LeaveRequest)
1702 .session_id(session_id)
1703 .message_id(rand::random::<u32>())
1704 .payload(CommandPayload::builder().leave_request().as_content())
1705 .build_publish()
1706 .unwrap();
1707
1708 moderator
1709 .on_message_from_app(leave_request)
1710 .await
1711 .expect("error sending leave request");
1712
1713 let received_leave_request = timeout(Duration::from_millis(100), rx_slim_moderator.recv())
1715 .await
1716 .expect("timeout waiting for leave request on moderator slim channel")
1717 .expect("channel closed")
1718 .expect("error in leave request");
1719
1720 assert_eq!(
1721 received_leave_request.get_session_message_type(),
1722 slim_datapath::api::ProtoSessionMessageType::LeaveRequest
1723 );
1724 assert_eq!(received_leave_request.get_dst(), participant_name_id);
1725
1726 let mut leave_request_to_participant = received_leave_request.clone();
1728 leave_request_to_participant
1729 .get_slim_header_mut()
1730 .set_incoming_conn(Some(1));
1731
1732 participant
1733 .on_message_from_slim(leave_request_to_participant)
1734 .await
1735 .expect("error processing leave request on participant");
1736
1737 let leave_reply = timeout(Duration::from_millis(100), rx_slim_participant.recv())
1739 .await
1740 .expect("timeout waiting for leave reply on participant slim channel")
1741 .expect("channel closed")
1742 .expect("error in leave reply");
1743
1744 assert_eq!(
1745 leave_reply.get_session_message_type(),
1746 slim_datapath::api::ProtoSessionMessageType::LeaveReply
1747 );
1748 assert_eq!(leave_reply.get_dst(), moderator_name);
1749
1750 assert_eq!(
1752 rx_spy_participant.recv().await,
1753 Some(SubscriptionCall::RemoveRoute),
1754 "participant should remove route after leave request"
1755 );
1756
1757 let mut leave_reply_to_moderator = leave_reply.clone();
1759 leave_reply_to_moderator
1760 .get_slim_header_mut()
1761 .set_incoming_conn(Some(1));
1762
1763 moderator
1764 .on_message_from_slim(leave_reply_to_moderator)
1765 .await
1766 .expect("error processing leave reply on moderator");
1767
1768 assert_eq!(
1770 rx_spy_moderator.recv().await,
1771 Some(SubscriptionCall::RemoveRoute),
1772 "moderator should remove route after leave reply"
1773 );
1774
1775 let no_more_moderator_final =
1777 timeout(Duration::from_millis(100), rx_slim_moderator.recv()).await;
1778
1779 assert!(
1780 no_more_moderator_final.is_err(),
1781 "Expected no more messages on moderator slim channel after leave"
1782 );
1783
1784 let no_more_participant_final =
1785 timeout(Duration::from_millis(100), rx_slim_participant.recv()).await;
1786 assert!(
1787 no_more_participant_final.is_err(),
1788 "Expected no more messages on participant slim channel after leave"
1789 );
1790 }
1791
1792 #[traced_test]
1795 #[tokio::test]
1796 async fn test_internal_draining_via_processing_state_switch() {
1797 use super::*;
1798 use tokio::sync::mpsc;
1799 use tracing::debug;
1800
1801 struct InternalDrainHandler {
1803 state: ProcessingState,
1804 messages: Vec<SessionMessage>,
1805 needs_drain: Arc<AtomicBool>,
1806 }
1807
1808 impl InternalDrainHandler {
1809 fn new(needs_drain: Arc<AtomicBool>) -> Self {
1810 Self {
1811 state: ProcessingState::Active,
1812 messages: vec![],
1813 needs_drain,
1814 }
1815 }
1816 }
1817
1818 impl MessageHandler for InternalDrainHandler {
1819 async fn init(&mut self) -> Result<(), SessionError> {
1820 Ok(())
1821 }
1822
1823 async fn on_message(
1824 &mut self,
1825 message: SessionMessage,
1826 ) -> Result<SessionOutput, SessionError> {
1827 debug!(?self.state, "internal-drain-handler received message");
1828 self.messages.push(message);
1829
1830 if self.messages.len() == 2 {
1832 debug!("internal-drain-handler transitioning to draining");
1833 self.state = ProcessingState::Draining;
1834 }
1835
1836 Ok(SessionOutput::new())
1837 }
1838
1839 fn needs_drain(&self) -> bool {
1840 self.needs_drain.load(std::sync::atomic::Ordering::SeqCst)
1841 }
1842
1843 fn processing_state(&self) -> ProcessingState {
1844 self.state
1845 }
1846
1847 async fn on_shutdown(&mut self) -> Result<(), SessionError> {
1848 debug!("shutdown called on handler");
1849 Ok(())
1850 }
1851 }
1852
1853 let (tx_slim, _rx_slim) = mpsc::channel(8);
1855 let (tx_app, _rx_app) = mpsc::unbounded_channel();
1856 let (tx_session, rx_session) = mpsc::channel(32);
1857 let (tx_session_layer, _rx_session_layer) = mpsc::channel(8);
1858
1859 let subscription_manager =
1860 crate::subscription_manager::SubscriptionManager::new(tx_slim.clone());
1861 let settings = SessionSettings {
1862 id: 999,
1863 source: ProtoName::from_strings(["org", "ns", "source"]).with_id(1),
1864 destination: ProtoName::from_strings(["org", "ns", "dest"]).with_id(2),
1865 control: ProtoName::from_strings(["org", "ns", "dest"]).with_id(2),
1866 config: SessionConfig {
1867 session_type: ProtoSessionType::PointToPoint,
1868 max_retries: Some(3),
1869 interval: Some(Duration::from_millis(150)),
1870 mls_settings: None,
1871 initiator: true,
1872 metadata: HashMap::new(),
1873 },
1874 direction: Direction::Bidirectional,
1875 slim_tx: tx_slim,
1876 app_tx: tx_app,
1877 tx_session: tx_session.clone(),
1878 tx_to_session_layer: tx_session_layer,
1879 identity_provider: SharedSecret::new("src", SHARED_SECRET).unwrap(),
1880 identity_verifier: SharedSecret::new("src", SHARED_SECRET).unwrap(),
1881 graceful_shutdown_timeout: Some(Duration::from_secs(10)),
1882 subscription_manager,
1883 service_id: String::new(),
1884 };
1885
1886 let needs_drain = Arc::new(AtomicBool::new(true));
1887 let handler = InternalDrainHandler::new(needs_drain.clone());
1888 let cancellation_token = CancellationToken::new();
1889 let cancellation_token_clone = cancellation_token.clone();
1890
1891 let processing_handle = tokio::spawn(async move {
1893 SessionController::processing_loop(
1894 handler,
1895 rx_session,
1896 cancellation_token_clone,
1897 settings,
1898 )
1899 .await
1900 });
1901
1902 tx_session
1904 .send(create_test_message(1, b"first".to_vec()))
1905 .await
1906 .expect("failed to send first message");
1907
1908 tokio::time::sleep(Duration::from_millis(100)).await;
1909
1910 assert!(logs_contain("internal-drain-handler received message"));
1911
1912 tx_session
1914 .send(create_test_message(2, b"second".to_vec()))
1915 .await
1916 .expect("failed to send second message");
1917
1918 tokio::time::sleep(Duration::from_millis(100)).await;
1919
1920 assert!(logs_contain("internal-drain-handler received message"));
1921
1922 assert!(logs_contain(
1923 "internal-drain-handler transitioning to draining"
1924 ));
1925
1926 tx_session
1928 .send(create_test_message(3, b"third".to_vec()))
1929 .await
1930 .expect("failed to send third message");
1931
1932 tokio::time::sleep(Duration::from_millis(100)).await;
1933
1934 assert!(logs_contain(
1935 "session is draining, rejecting new messages from application"
1936 ));
1937
1938 needs_drain.store(false, std::sync::atomic::Ordering::SeqCst);
1940
1941 cancellation_token.cancel();
1943
1944 tokio::time::sleep(Duration::from_millis(100)).await;
1945
1946 tx_session
1948 .send(SessionMessage::StartDrain {
1949 grace_period: std::time::Duration::from_millis(100),
1950 })
1951 .await
1952 .expect("failed to send timeout message");
1953
1954 processing_handle.await.expect("processing loop panicked");
1956 }
1957 struct DrainableHandler {
1961 messages_received: Arc<tokio::sync::Mutex<Vec<SessionMessage>>>,
1962 needs_drain: Arc<AtomicBool>,
1963 shutdown_called: Arc<tokio::sync::Mutex<bool>>,
1964 drain_delay: Option<Duration>,
1965 }
1966
1967 impl DrainableHandler {
1968 fn new() -> Self {
1969 Self {
1970 messages_received: Arc::new(tokio::sync::Mutex::new(Vec::new())),
1971 needs_drain: Arc::new(AtomicBool::new(false)),
1972 shutdown_called: Arc::new(tokio::sync::Mutex::new(false)),
1973 drain_delay: None,
1974 }
1975 }
1976
1977 fn with_needs_drain(self, needs_drain: bool) -> Self {
1978 self.needs_drain
1979 .store(needs_drain, std::sync::atomic::Ordering::SeqCst);
1980 self
1981 }
1982
1983 #[allow(dead_code)]
1984 fn with_drain_delay(mut self, delay: Duration) -> Self {
1985 self.drain_delay = Some(delay);
1986 self
1987 }
1988
1989 #[allow(dead_code)]
1990 async fn get_messages_count(&self) -> usize {
1991 self.messages_received.lock().await.len()
1992 }
1993
1994 #[allow(dead_code)]
1995 async fn was_shutdown_called(&self) -> bool {
1996 *self.shutdown_called.lock().await
1997 }
1998 }
1999
2000 impl MessageHandler for DrainableHandler {
2001 async fn init(&mut self) -> Result<(), SessionError> {
2002 Ok(())
2003 }
2004
2005 async fn on_message(
2006 &mut self,
2007 message: SessionMessage,
2008 ) -> Result<SessionOutput, SessionError> {
2009 self.messages_received.lock().await.push(message);
2010 Ok(SessionOutput::new())
2011 }
2012
2013 fn needs_drain(&self) -> bool {
2014 self.needs_drain.load(std::sync::atomic::Ordering::SeqCst)
2015 }
2016
2017 async fn on_shutdown(&mut self) -> Result<(), SessionError> {
2018 if let Some(delay) = self.drain_delay {
2019 tokio::time::sleep(delay).await;
2020 }
2021 *self.shutdown_called.lock().await = true;
2022 Ok(())
2023 }
2024 }
2025
2026 fn create_test_settings(
2028 graceful_shutdown_timeout: Option<Duration>,
2029 ) -> SessionSettings<SharedSecret, SharedSecret> {
2030 let (tx_slim, _rx_slim) = tokio::sync::mpsc::channel(10);
2031 let (tx_app, _rx_app) = tokio::sync::mpsc::unbounded_channel();
2032 let (tx_session, _rx_session) = tokio::sync::mpsc::channel(10);
2033 let (tx_session_layer, _rx_session_layer) = tokio::sync::mpsc::channel(10);
2034
2035 let subscription_manager =
2036 crate::subscription_manager::SubscriptionManager::new(tx_slim.clone());
2037 SessionSettings {
2038 id: 1,
2039 source: ProtoName::from_strings(["org", "ns", "test"]).with_id(1),
2040 destination: ProtoName::from_strings(["org", "ns", "test"]).with_id(2),
2041 control: ProtoName::from_strings(["org", "ns", "test"]).with_id(2),
2042 config: SessionConfig {
2043 session_type: ProtoSessionType::PointToPoint,
2044 max_retries: Some(5),
2045 interval: Some(Duration::from_millis(200)),
2046 mls_settings: None,
2047 initiator: true,
2048 metadata: HashMap::new(),
2049 },
2050 direction: Direction::Bidirectional,
2051 slim_tx: tx_slim,
2052 app_tx: tx_app,
2053 tx_session,
2054 tx_to_session_layer: tx_session_layer,
2055 identity_provider: SharedSecret::new("test", SHARED_SECRET).unwrap(),
2056 identity_verifier: SharedSecret::new("test", SHARED_SECRET).unwrap(),
2057 graceful_shutdown_timeout,
2058 subscription_manager,
2059 service_id: String::new(),
2060 }
2061 }
2062
2063 fn create_test_message(message_id: u32, payload: Vec<u8>) -> SessionMessage {
2065 SessionMessage::OnMessage {
2066 message: Message::builder()
2067 .source(ProtoName::from_strings(["org", "ns", "test"]).with_id(1))
2068 .destination(ProtoName::from_strings(["org", "ns", "test"]).with_id(2))
2069 .identity(test_identity())
2070 .forward_to(1)
2071 .session_type(ProtoSessionType::PointToPoint)
2072 .session_message_type(ProtoSessionMessageType::Msg)
2073 .session_id(1)
2074 .message_id(message_id)
2075 .application_payload("test", payload)
2076 .build_publish()
2077 .unwrap(),
2078 direction: MessageDirection::South,
2079 ack_tx: None,
2080 }
2081 }
2082
2083 async fn count_on_messages(messages: &Arc<tokio::sync::Mutex<Vec<SessionMessage>>>) -> usize {
2084 let messages = messages.lock().await;
2085 messages
2086 .iter()
2087 .filter(|msg| matches!(msg, SessionMessage::OnMessage { .. }))
2088 .count()
2089 }
2090
2091 fn spawn_processing_loop(
2093 handler: DrainableHandler,
2094 rx: tokio::sync::mpsc::Receiver<SessionMessage>,
2095 cancellation_token: CancellationToken,
2096 settings: SessionSettings<SharedSecret, SharedSecret>,
2097 ) -> tokio::task::JoinHandle<()> {
2098 tokio::spawn(async move {
2099 SessionController::processing_loop(handler, rx, cancellation_token, settings).await;
2100 })
2101 }
2102
2103 #[tokio::test]
2104 async fn test_draining_processes_queued_messages() {
2105 let handler = DrainableHandler::new();
2106 let messages_received = handler.messages_received.clone();
2107 let shutdown_called = handler.shutdown_called.clone();
2108
2109 let (tx, rx) = tokio::sync::mpsc::channel(10);
2110 let cancellation_token = CancellationToken::new();
2111 let token_clone = cancellation_token.clone();
2112
2113 let settings = create_test_settings(Some(Duration::from_secs(2)));
2114 let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2115
2116 tx.send(create_test_message(1, vec![1, 2, 3]))
2118 .await
2119 .unwrap();
2120 tx.send(create_test_message(2, vec![4, 5, 6]))
2121 .await
2122 .unwrap();
2123 tx.send(create_test_message(3, vec![7, 8, 9]))
2124 .await
2125 .unwrap();
2126
2127 tokio::time::sleep(Duration::from_millis(50)).await;
2129
2130 token_clone.cancel();
2132
2133 drop(tx);
2135
2136 timeout(Duration::from_secs(3), processing_task)
2138 .await
2139 .expect("timeout waiting for processing loop")
2140 .expect("processing loop panicked");
2141
2142 let processed_messages = count_on_messages(&messages_received).await;
2144 assert_eq!(
2145 processed_messages, 3,
2146 "All queued messages should be processed during draining"
2147 );
2148 assert!(
2149 *shutdown_called.lock().await,
2150 "Shutdown should have been called"
2151 );
2152 }
2153
2154 #[tokio::test]
2155 async fn test_draining_with_needs_drain_true() {
2156 let handler = DrainableHandler::new().with_needs_drain(true);
2157 let messages_received = handler.messages_received.clone();
2158 let shutdown_called = handler.shutdown_called.clone();
2159
2160 let (tx, rx) = tokio::sync::mpsc::channel(10);
2161 let cancellation_token = CancellationToken::new();
2162 let token_clone = cancellation_token.clone();
2163
2164 let settings = create_test_settings(Some(Duration::from_secs(2)));
2165 let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2166
2167 tx.send(create_test_message(1, vec![1, 2, 3]))
2169 .await
2170 .unwrap();
2171 tokio::time::sleep(Duration::from_millis(50)).await;
2172
2173 token_clone.cancel();
2175 drop(tx);
2176
2177 timeout(Duration::from_secs(3), processing_task)
2179 .await
2180 .expect("timeout waiting for processing loop")
2181 .expect("processing loop panicked");
2182
2183 let processed_messages = count_on_messages(&messages_received).await;
2185 assert_eq!(processed_messages, 1, "Message should be processed");
2186 assert!(
2187 *shutdown_called.lock().await,
2188 "Shutdown should have been called after draining"
2189 );
2190 }
2191
2192 #[tokio::test]
2193 async fn test_draining_with_needs_drain_false() {
2194 let handler = DrainableHandler::new().with_needs_drain(false);
2195 let messages_received = handler.messages_received.clone();
2196 let shutdown_called = handler.shutdown_called.clone();
2197
2198 let (tx, rx) = tokio::sync::mpsc::channel(10);
2199 let cancellation_token = CancellationToken::new();
2200 let token_clone = cancellation_token.clone();
2201
2202 let settings = create_test_settings(Some(Duration::from_secs(2)));
2203
2204 let start_time = tokio::time::Instant::now();
2205 let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2206
2207 tx.send(create_test_message(1, vec![1, 2, 3]))
2209 .await
2210 .unwrap();
2211 tokio::time::sleep(Duration::from_millis(50)).await;
2212
2213 token_clone.cancel();
2215 drop(tx);
2216
2217 timeout(Duration::from_secs(1), processing_task)
2219 .await
2220 .expect("timeout waiting for processing loop")
2221 .expect("processing loop panicked");
2222
2223 let elapsed = start_time.elapsed();
2224
2225 let processed_messages = count_on_messages(&messages_received).await;
2227 assert_eq!(processed_messages, 1, "Message should be processed");
2228 assert!(
2229 *shutdown_called.lock().await,
2230 "Shutdown should have been called"
2231 );
2232 assert!(
2233 elapsed < Duration::from_millis(500),
2234 "Should exit quickly when no draining needed"
2235 );
2236 }
2237
2238 #[tokio::test]
2239 async fn test_draining_timeout_enforced() {
2240 let handler = DrainableHandler::new().with_needs_drain(true);
2242
2243 let (tx, rx) = tokio::sync::mpsc::channel(10);
2244 let cancellation_token = CancellationToken::new();
2245 let token_clone = cancellation_token.clone();
2246
2247 let settings = create_test_settings(Some(Duration::from_millis(500)));
2248 let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2249
2250 tokio::time::sleep(Duration::from_millis(50)).await;
2252
2253 let start_time = tokio::time::Instant::now();
2254
2255 token_clone.cancel();
2257
2258 let send_task = tokio::spawn(async move {
2261 for i in 0..10 {
2262 tokio::time::sleep(Duration::from_millis(100)).await;
2263 if tx
2264 .send(create_test_message(i, vec![i as u8]))
2265 .await
2266 .is_err()
2267 {
2268 break;
2269 }
2270 }
2271 });
2272
2273 timeout(Duration::from_secs(2), processing_task)
2275 .await
2276 .expect("timeout waiting for processing loop")
2277 .expect("processing loop panicked");
2278
2279 let elapsed = start_time.elapsed();
2280
2281 assert!(
2283 elapsed >= Duration::from_millis(400),
2284 "Should wait at least close to the timeout period"
2285 );
2286 assert!(
2287 elapsed < Duration::from_secs(2),
2288 "Should respect the timeout and exit, not wait forever"
2289 );
2290
2291 send_task.abort();
2293 }
2294
2295 #[tokio::test]
2296 async fn test_draining_no_messages_in_queue() {
2297 let handler = DrainableHandler::new().with_needs_drain(true);
2298 let messages_received = handler.messages_received.clone();
2299 let shutdown_called = handler.shutdown_called.clone();
2300
2301 let (tx, rx) = tokio::sync::mpsc::channel(10);
2302 let cancellation_token = CancellationToken::new();
2303 let token_clone = cancellation_token.clone();
2304
2305 let settings = create_test_settings(Some(Duration::from_secs(1)));
2306 let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2307
2308 token_clone.cancel();
2310 drop(tx);
2311
2312 timeout(Duration::from_secs(2), processing_task)
2314 .await
2315 .expect("timeout waiting for processing loop")
2316 .expect("processing loop panicked");
2317
2318 let processed_messages = count_on_messages(&messages_received).await;
2320 assert_eq!(processed_messages, 0, "No messages should be processed");
2321 assert!(
2322 *shutdown_called.lock().await,
2323 "Shutdown should still be called"
2324 );
2325 }
2326
2327 #[tokio::test]
2328 async fn test_draining_messages_after_cancellation_processed() {
2329 let handler = DrainableHandler::new();
2330 let messages_received = handler.messages_received.clone();
2331
2332 let (tx, rx) = tokio::sync::mpsc::channel(10);
2333 let cancellation_token = CancellationToken::new();
2334 let token_clone = cancellation_token.clone();
2335
2336 let settings = create_test_settings(Some(Duration::from_secs(2)));
2337
2338 tx.send(create_test_message(1, vec![1, 2, 3]))
2340 .await
2341 .unwrap();
2342 tx.send(create_test_message(2, vec![4, 5, 6]))
2343 .await
2344 .unwrap();
2345
2346 let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2348
2349 tokio::time::sleep(Duration::from_millis(50)).await;
2351
2352 token_clone.cancel();
2354
2355 drop(tx);
2357
2358 timeout(Duration::from_secs(3), processing_task)
2360 .await
2361 .expect("timeout waiting for processing loop")
2362 .expect("processing loop panicked");
2363
2364 let processed_messages = count_on_messages(&messages_received).await;
2366 assert_eq!(
2367 processed_messages, 2,
2368 "Messages in queue during cancellation should be processed"
2369 );
2370 }
2371}