Skip to main content

slim_session/
session_controller.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4// Standard library imports
5use std::{collections::HashMap, time::Duration};
6
7use display_error_chain::ErrorChainExt;
8use parking_lot::Mutex;
9use tokio::sync::{self, oneshot};
10// Third-party crates
11use tokio_util::sync::CancellationToken;
12use tracing::{Instrument, debug};
13
14use slim_auth::traits::{TokenProvider, Verifier};
15use slim_datapath::{
16    api::{
17        CommandPayload, Content, ProtoMessage as Message, ProtoSessionMessageType,
18        ProtoSessionType, SlimHeader,
19    },
20    messages::{Name, utils::SlimHeaderFlags},
21};
22
23// Local crate
24use crate::{
25    MessageDirection, SessionError, Transmitter,
26    common::SessionMessage,
27    completion_handle::CompletionHandle,
28    controller_sender::{ControllerSender, PING_INTERVAL},
29    session_builder::{ForController, SessionBuilder},
30    session_config::SessionConfig,
31    session_settings::SessionSettings,
32    traits::{MessageHandler, ProcessingState},
33};
34
35pub struct SessionController {
36    /// session id
37    pub(crate) id: u32,
38
39    /// local name
40    pub(crate) source: Name,
41
42    /// group or remote endpoint name
43    pub(crate) destination: Name,
44
45    /// session config
46    pub(crate) config: SessionConfig,
47
48    /// channel to send messages to the processing loop
49    tx_controller: sync::mpsc::Sender<SessionMessage>,
50
51    /// use in drop implementation to gracefully close the processing loop
52    pub(crate) cancellation_token: CancellationToken,
53
54    /// handle for the processing loop
55    handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
56}
57
58impl SessionController {
59    /// Returns a new SessionBuilder for constructing a SessionController
60    pub fn builder<P, V>() -> SessionBuilder<P, V, ForController>
61    where
62        P: TokenProvider + Send + Sync + Clone + 'static,
63        V: Verifier + Send + Sync + Clone + 'static,
64    {
65        SessionBuilder::for_controller()
66    }
67
68    /// Internal constructor for the builder to use
69    #[allow(clippy::too_many_arguments)]
70    pub(crate) fn from_parts<I, P, V, M>(
71        id: u32,
72        source: Name,
73        destination: Name,
74        config: SessionConfig,
75        settings: SessionSettings<P, V, M>,
76        tx: sync::mpsc::Sender<SessionMessage>,
77        rx: sync::mpsc::Receiver<SessionMessage>,
78        inner: I,
79    ) -> Self
80    where
81        I: MessageHandler + Send + Sync + 'static,
82        P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static,
83        V: slim_auth::traits::Verifier + Send + Sync + Clone + 'static,
84        M: crate::subscription_manager::SubscriptionOps,
85    {
86        // Spawn the processing loop
87        let cancellation_token = CancellationToken::new();
88
89        // setup tracing context
90        let span = tracing::debug_span!(
91            parent: None,
92            "session_controller_processing_loop",
93            session_id = id,
94            service_id = %settings.service_id,
95            source = %source,
96            destination = %destination,
97            session_type = ?config.session_type
98        );
99
100        let handle = tokio::spawn(
101            Self::processing_loop(inner, rx, cancellation_token.clone(), settings).instrument(span),
102        );
103
104        Self {
105            id,
106            source,
107            destination,
108            config,
109            tx_controller: tx,
110            cancellation_token,
111            handle: Mutex::new(Some(handle)),
112        }
113    }
114
115    /// Internal processing loop that handles messages with mutable access
116    fn enter_draining_state<P, V, M>(
117        shutdown_deadline: &mut std::pin::Pin<&mut tokio::time::Sleep>,
118        settings: &SessionSettings<P, V, M>,
119    ) where
120        P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static,
121        V: slim_auth::traits::Verifier + Send + Sync + Clone + 'static,
122        M: crate::subscription_manager::SubscriptionOps,
123    {
124        let shutdown_timeout = settings
125            .graceful_shutdown_timeout
126            .unwrap_or(Duration::from_secs(60));
127        shutdown_deadline
128            .as_mut()
129            .reset(tokio::time::Instant::now() + shutdown_timeout);
130    }
131
132    async fn processing_loop<P, V, M>(
133        mut inner: impl MessageHandler + 'static,
134        mut rx: sync::mpsc::Receiver<SessionMessage>,
135        cancellation_token: CancellationToken,
136        settings: SessionSettings<P, V, M>,
137    ) where
138        P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static,
139        V: slim_auth::traits::Verifier + Send + Sync + Clone + 'static,
140        M: crate::subscription_manager::SubscriptionOps,
141    {
142        // Start with an infinite timeout (will be updated on graceful shutdown)
143        let mut shutdown_deadline = std::pin::pin!(tokio::time::sleep(Duration::MAX));
144
145        // Init the inner components
146        if let Err(e) = inner.init().await {
147            tracing::error!(error = %e.chain(), "error during initialization of session");
148        }
149
150        loop {
151            tokio::select! {
152                _ = cancellation_token.cancelled(), if inner.processing_state() == ProcessingState::Active => {
153                    // Update the timeout to the configured grace period
154                    let shutdown_timeout = settings.graceful_shutdown_timeout
155                        .unwrap_or(Duration::from_secs(60)); // Default 60 seconds if not configured
156
157                    // Finish any ongoing processing before starting drain
158                    debug!("consuming pending messages before entering draining state");
159                    while let Ok(msg) = rx.try_recv() {
160                        if let Err(e) = inner.on_message(msg).await {
161                            tracing::error!(error = %e.chain(), "error processing message during draining - close immediately.");
162                            break;
163                        }
164                    }
165
166                    // Send drain to message to the inner to notify the beginning of the drain
167                    if let Err(e) = inner.on_message(SessionMessage::StartDrain {
168                        grace_period: shutdown_timeout
169                    }).await {
170                        tracing::error!(error = %e.chain(),  "error during start drain");
171                        break;
172                    }
173
174                    Self::enter_draining_state(&mut shutdown_deadline, &settings);
175
176                    debug!("cancellation requested, entering draining state");
177                }
178                _ = &mut shutdown_deadline => {
179                    debug!("graceful shutdown timeout reached, forcing exit");
180                    break;
181                }
182                msg = rx.recv() => {
183                    match msg {
184                        Some(session_message) => {
185                            // Handle GetParticipantsList query immediately without going through the handler
186                            if let SessionMessage::GetParticipantsList { tx } = session_message {
187                                let participants_list = inner.participants_list();
188                                let _ = tx.send(participants_list);
189                                continue;
190                            }
191
192                            let draining = inner.processing_state() == ProcessingState::Draining;
193
194                            // if draining and message is sent by the application, reject it
195                            if draining && matches!(session_message, SessionMessage::OnMessage { direction: MessageDirection::South, .. }) {
196                                tracing::debug!("session is draining, rejecting new messages from application");
197                                if let SessionMessage::OnMessage { ack_tx: Some(ack_tx), .. } = session_message {
198                                    let _ = ack_tx.send(Err(SessionError::SessionDrainingDrop));
199                                }
200                                continue;
201                            }
202
203                            if let Err(e) = inner.on_message(session_message).await {
204                                debug!(
205                                    error=%e,
206                                    "Error processing message{}",
207                                    if draining { " during graceful shutdown" } else { "" }
208                                );
209                                if draining {
210                                    debug!("Exiting processing loop due to error while draining");
211                                    break;
212                                }
213                            } else {
214                                // If we were active before processing and the handler switched to draining,
215                                // start (or reset) the graceful shutdown deadline just like on cancellation.
216                                if !draining && inner.processing_state() == ProcessingState::Draining {
217                                    debug!("internal component requested draining, entering draining state");
218                                    Self::enter_draining_state(&mut shutdown_deadline, &settings);
219                                }
220                            }
221                        }
222                        None => {
223                            debug!("Session channel closed, no more messages can arrive - exiting processing loop");
224                            break;
225                        }
226                    }
227                }
228            }
229
230            // If we are in draining state and the inner component does not require drain, exit
231            if inner.processing_state() == ProcessingState::Draining && !inner.needs_drain() {
232                debug!("draining complete, exiting processing loop");
233                break;
234            }
235        }
236
237        // Perform final shutdown
238        if let Err(e) = inner.on_shutdown().await {
239            tracing::error!(error = %e.chain(), "error during shutdown of session");
240        }
241    }
242
243    /// getters
244    pub fn id(&self) -> u32 {
245        self.id
246    }
247
248    pub fn source(&self) -> &Name {
249        &self.source
250    }
251
252    pub fn dst(&self) -> &Name {
253        &self.destination
254    }
255
256    pub fn session_type(&self) -> ProtoSessionType {
257        self.config.session_type
258    }
259
260    pub fn metadata(&self) -> HashMap<String, String> {
261        self.config.metadata.clone()
262    }
263
264    pub fn session_config(&self) -> SessionConfig {
265        self.config.clone()
266    }
267
268    pub fn is_initiator(&self) -> bool {
269        self.config.initiator
270    }
271
272    pub async fn participants_list(&self) -> Result<Vec<Name>, SessionError> {
273        let (tx, rx) = oneshot::channel();
274
275        // Send query to the processing loop
276        self.tx_controller
277            .send(SessionMessage::GetParticipantsList { tx })
278            .await
279            .map_err(|_| SessionError::ParticipantsListQueryFailed)?;
280
281        // Wait for response
282        rx.await
283            .map_err(|_| SessionError::ParticipantsListQueryFailed)
284    }
285
286    async fn on_message(
287        &self,
288        message: Message,
289        direction: MessageDirection,
290        ack_tx: Option<oneshot::Sender<Result<(), SessionError>>>,
291    ) -> Result<(), SessionError> {
292        self.tx_controller
293            .send(SessionMessage::OnMessage {
294                message,
295                direction,
296                ack_tx,
297            })
298            .await
299            .map_err(|_e| SessionError::SessionControllerSendFailed)
300    }
301
302    /// Send a message to the controller for processing
303    pub async fn on_message_from_app(
304        &self,
305        message: Message,
306    ) -> Result<CompletionHandle, SessionError> {
307        let (ack_tx, ack_rx) = oneshot::channel();
308        self.on_message(message, MessageDirection::South, Some(ack_tx))
309            .await?;
310
311        let ret = CompletionHandle::from_oneshot_receiver(ack_rx);
312
313        Ok(ret)
314    }
315
316    /// Send a message to the controller for processing
317    pub async fn on_message_from_slim(&self, message: Message) -> Result<(), SessionError> {
318        self.on_message(message, MessageDirection::North, None)
319            .await
320    }
321
322    /// Send an error message to the controller for processing
323    pub async fn on_error_message_from_slim(
324        &self,
325        error: SessionError,
326    ) -> Result<(), SessionError> {
327        self.tx_controller
328            .send(SessionMessage::MessageError { error })
329            .await
330            .map_err(|_e| SessionError::SessionControllerSendFailed)
331    }
332
333    pub fn close(&self) -> Result<tokio::task::JoinHandle<()>, SessionError> {
334        self.cancellation_token.cancel();
335
336        self.handle
337            .lock()
338            .take()
339            .ok_or(SessionError::SessionAlreadyClosed)
340    }
341
342    pub async fn publish_message(
343        &self,
344        message: Message,
345    ) -> Result<CompletionHandle, SessionError> {
346        self.on_message_from_app(message).await
347    }
348
349    /// Publish a message to a specific connection (forward_to)
350    pub async fn publish_to(
351        &self,
352        name: &Name,
353        forward_to: u64,
354        blob: Vec<u8>,
355        payload_type: Option<String>,
356        metadata: Option<HashMap<String, String>>,
357    ) -> Result<CompletionHandle, SessionError> {
358        self.publish_with_flags(
359            name,
360            SlimHeaderFlags::default().with_forward_to(forward_to),
361            blob,
362            payload_type,
363            metadata,
364        )
365        .await
366    }
367
368    /// Publish a message to a specific app name
369    pub async fn publish(
370        &self,
371        name: &Name,
372        blob: Vec<u8>,
373        payload_type: Option<String>,
374        metadata: Option<HashMap<String, String>>,
375    ) -> Result<CompletionHandle, SessionError> {
376        self.publish_with_flags(
377            name,
378            SlimHeaderFlags::default(),
379            blob,
380            payload_type,
381            metadata,
382        )
383        .await
384    }
385
386    /// Publish a message with specific flags
387    pub async fn publish_with_flags(
388        &self,
389        name: &Name,
390        flags: SlimHeaderFlags,
391        blob: Vec<u8>,
392        payload_type: Option<String>,
393        metadata: Option<HashMap<String, String>>,
394    ) -> Result<CompletionHandle, SessionError> {
395        let ct = payload_type.unwrap_or_else(|| "msg".to_string());
396
397        let mut msg = Message::builder()
398            .source(self.source().clone())
399            .destination(name.clone())
400            .identity("")
401            .flags(flags)
402            .session_type(self.session_type())
403            .session_message_type(ProtoSessionMessageType::Msg)
404            .session_id(self.id())
405            .message_id(rand::random::<u32>()) // this will be changed by the session itself
406            .application_payload(&ct, blob)
407            .build_publish()?;
408        if let Some(map) = metadata
409            && !map.is_empty()
410        {
411            msg.set_metadata_map(map);
412        }
413
414        // southbound=true means towards slim
415        self.publish_message(msg).await
416    }
417
418    /// Creates a discovery request message with minimum required information
419    fn create_discovery_request(&self, destination: &Name) -> Result<Message, SessionError> {
420        let payload = CommandPayload::builder()
421            .discovery_request(None)
422            .as_content();
423
424        let msg = Message::builder()
425            .source(self.source().clone())
426            .destination(destination.clone())
427            .identity("")
428            .session_type(self.session_type())
429            .session_message_type(ProtoSessionMessageType::DiscoveryRequest)
430            .session_id(self.id())
431            .message_id(rand::random::<u32>())
432            .payload(payload)
433            .build_publish()?;
434
435        Ok(msg)
436    }
437
438    pub(crate) async fn invite_participant_internal(
439        &self,
440        destination: &Name,
441    ) -> Result<CompletionHandle, SessionError> {
442        let msg = self.create_discovery_request(destination)?;
443        self.publish_message(msg).await
444    }
445
446    pub async fn invite_participant(
447        &self,
448        destination: &Name,
449    ) -> Result<CompletionHandle, SessionError> {
450        match self.session_type() {
451            ProtoSessionType::PointToPoint => Err(SessionError::CannotInviteToP2P),
452            ProtoSessionType::Multicast => {
453                if !self.is_initiator() {
454                    return Err(SessionError::NotInitiator);
455                }
456                self.invite_participant_internal(destination).await
457            }
458            _ => Err(SessionError::SessionTypeUnknown(self.session_type())),
459        }
460    }
461
462    pub async fn remove_participant(
463        &self,
464        destination: &Name,
465    ) -> Result<CompletionHandle, SessionError> {
466        match self.session_type() {
467            ProtoSessionType::PointToPoint => Err(SessionError::CannotRemoveFromP2P),
468            ProtoSessionType::Multicast => {
469                if !self.is_initiator() {
470                    return Err(SessionError::NotInitiator);
471                }
472                let msg = Message::builder()
473                    .source(self.source().clone())
474                    .destination(destination.clone().with_id(Name::NULL_COMPONENT))
475                    .identity("")
476                    .session_type(ProtoSessionType::Multicast)
477                    .session_message_type(ProtoSessionMessageType::LeaveRequest)
478                    .session_id(self.id())
479                    .message_id(rand::random::<u32>())
480                    .payload(CommandPayload::builder().leave_request(None).as_content())
481                    .build_publish()?;
482                self.publish_message(msg).await
483            }
484            _ => Err(SessionError::SessionTypeUnknown(self.session_type())),
485        }
486    }
487}
488
489impl Drop for SessionController {
490    fn drop(&mut self) {
491        self.cancellation_token.cancel();
492    }
493}
494
495pub fn handle_channel_discovery_message(
496    message: &Message,
497    app_name: &Name,
498    session_id: u32,
499    session_type: ProtoSessionType,
500) -> Result<Message, SessionError> {
501    let destination = message.get_source();
502
503    // the destination of the discovery message may be different from the name of
504    // application itself. This can happen if the application subscribes to multiple
505    // service names. So we can reply using as a source the destination name of
506    // the discovery message but setting the application id
507
508    let mut source = message.get_dst();
509    source.set_id(app_name.id());
510    let msg_id = message.get_id();
511
512    let slim_header = SlimHeader::new(
513        &source,
514        &destination,
515        "", // the identity will be added by the identity interceptor
516        Some(SlimHeaderFlags::default().with_forward_to(message.get_incoming_conn())),
517    );
518
519    let msg = Message::builder()
520        .with_slim_header(slim_header)
521        .session_type(session_type)
522        .session_message_type(ProtoSessionMessageType::DiscoveryReply)
523        .session_id(session_id)
524        .message_id(msg_id)
525        .payload(CommandPayload::builder().discovery_reply().as_content())
526        .build_publish()?;
527
528    Ok(msg)
529}
530
531pub(crate) struct SessionControllerCommon<
532    P,
533    V,
534    M = crate::subscription_manager::SubscriptionManager,
535> where
536    P: TokenProvider + Send + Sync + Clone + 'static,
537    V: Verifier + Send + Sync + Clone + 'static,
538    M: crate::subscription_manager::SubscriptionOps,
539{
540    /// common session fields
541    pub(crate) settings: SessionSettings<P, V, M>,
542
543    /// sender for command messages
544    pub(crate) sender: ControllerSender,
545
546    /// processing state
547    pub(crate) processing_state: ProcessingState,
548
549    /// Maps (kind, name, conn) → subscription_id for route/subscription tracking.
550    subscription_ids: HashMap<(SubscriptionKind, Name, u64), u64>,
551}
552
/// Tag stored in the key of `SessionControllerCommon::subscription_ids`.
///
/// A route and a subscription may be registered for the very same `(Name, conn)`
/// pair; tagging each entry with its kind keeps one from overwriting the other.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum SubscriptionKind {
    /// recv_from route, installed by `set_route` and removed by `remove_route`.
    Route,
    /// forward_to subscription, installed by `subscribe` and removed by `unsubscribe`.
    Subscription,
}
562
563impl<P, V, M> SessionControllerCommon<P, V, M>
564where
565    P: TokenProvider + Send + Sync + Clone + 'static,
566    V: Verifier + Send + Sync + Clone + 'static,
567    M: crate::subscription_manager::SubscriptionOps,
568{
569    pub(crate) fn new(settings: SessionSettings<P, V, M>) -> Self {
570        // Create the controller sender.
571        let controller_sender = ControllerSender::new(
572            settings.config.get_timer_settings(),
573            settings.source.clone(),
574            settings.config.session_type,
575            settings.id,
576            Some(PING_INTERVAL),
577            settings.config.initiator,
578            // send messages to slim/app
579            settings.tx.clone(),
580            // send signal to the controller
581            settings.tx_session.clone(),
582        );
583
584        SessionControllerCommon {
585            settings,
586            sender: controller_sender,
587            processing_state: ProcessingState::Active,
588            subscription_ids: HashMap::new(),
589        }
590    }
591
592    /// internal and helper functions
593    pub(crate) async fn send_to_slim(&self, message: Message) -> Result<(), SessionError> {
594        self.settings.tx.send_to_slim(Ok(message)).await
595    }
596
597    /// Send error message to the application
598    pub(crate) async fn send_to_app(&self, error: SessionError) -> Result<(), SessionError> {
599        self.settings.tx.send_to_app(Err(error)).await
600    }
601
602    /// Send control message without creating ack channel (for internal use by moderator)
603    pub(crate) async fn send_with_timer(&mut self, message: Message) -> Result<(), SessionError> {
604        self.sender.on_message(&message).await
605    }
606
607    async fn await_subscription_ack(
608        rx: tokio::sync::oneshot::Receiver<
609            Result<(), crate::subscription_manager::SubscriptionAckError>,
610        >,
611    ) -> Result<(), SessionError> {
612        crate::subscription_manager::SubscriptionManager::await_ack(rx)
613            .await
614            .map_err(SessionError::SubscriptionAckFailed)
615    }
616
617    pub(crate) async fn add_route(&mut self, name: Name, conn: u64) -> Result<(), SessionError> {
618        if name == self.settings.source {
619            // We never add a route for ourselves
620            return Ok(());
621        }
622
623        let (subscription_id, rx) = self
624            .settings
625            .subscription_manager
626            .set_route(&self.settings.source, &name, conn)
627            .await
628            .map_err(SessionError::SubscriptionAckFailed)?;
629        Self::await_subscription_ack(rx).await?;
630
631        debug!(%name, %conn, %subscription_id, source = %self.settings.source, "route added");
632
633        self.subscription_ids
634            .insert((SubscriptionKind::Route, name, conn), subscription_id);
635
636        Ok(())
637    }
638
639    pub(crate) async fn delete_route(&mut self, name: Name, conn: u64) -> Result<(), SessionError> {
640        if name == self.settings.source {
641            // We never remove a route for ourselves
642            return Ok(());
643        }
644
645        let key = (SubscriptionKind::Route, name, conn);
646        let subscription_id = self.subscription_ids.remove(&key);
647        let (_, name, conn) = key;
648        match subscription_id {
649            Some(subscription_id) => {
650                let rx = self
651                    .settings
652                    .subscription_manager
653                    .remove_route(&self.settings.source, &name, subscription_id, conn)
654                    .await
655                    .map_err(SessionError::SubscriptionAckFailed)?;
656
657                Self::await_subscription_ack(rx).await?;
658                tracing::debug!(%name, %conn, %subscription_id, "route deleted");
659            }
660            None => {
661                tracing::warn!(
662                    %name, %conn, io = %self.settings.source,
663                    "no subscription_id found for route, skipping delete"
664                );
665            }
666        }
667
668        Ok(())
669    }
670
671    pub(crate) async fn add_subscription(
672        &mut self,
673        name: Name,
674        conn: u64,
675    ) -> Result<(), SessionError> {
676        let (subscription_id, rx) = self
677            .settings
678            .subscription_manager
679            .subscribe(&self.settings.source, &name, Some(conn))
680            .await
681            .map_err(SessionError::SubscriptionAckFailed)?;
682
683        Self::await_subscription_ack(rx).await?;
684
685        debug!(%name, %conn, %subscription_id, "subscription added");
686
687        self.subscription_ids.insert(
688            (SubscriptionKind::Subscription, name, conn),
689            subscription_id,
690        );
691
692        Ok(())
693    }
694
695    pub(crate) async fn delete_subscription(
696        &mut self,
697        name: Name,
698        conn: u64,
699    ) -> Result<(), SessionError> {
700        let key = (SubscriptionKind::Subscription, name, conn);
701        let subscription_id = self.subscription_ids.remove(&key);
702        let (_, name, conn) = key;
703        match subscription_id {
704            Some(subscription_id) => {
705                let rx = self
706                    .settings
707                    .subscription_manager
708                    .unsubscribe(&self.settings.source, &name, subscription_id, Some(conn))
709                    .await
710                    .map_err(SessionError::SubscriptionAckFailed)?;
711
712                Self::await_subscription_ack(rx).await?;
713                debug!(%name, %conn, %subscription_id, "subscription deleted");
714            }
715            None => {
716                tracing::warn!(
717                    %name, %conn,
718                    "no subscription_id found for subscription, skipping delete"
719                );
720            }
721        }
722
723        Ok(())
724    }
725
726    pub(crate) fn create_control_message(
727        &mut self,
728        dst: &Name,
729        message_type: ProtoSessionMessageType,
730        message_id: u32,
731        payload: Content,
732        broadcast: bool,
733    ) -> Result<Message, SessionError> {
734        let mut builder = Message::builder()
735            .source(self.settings.source.clone())
736            .destination(dst.clone())
737            .identity("")
738            .session_type(self.settings.config.session_type)
739            .session_message_type(message_type)
740            .session_id(self.settings.id)
741            .message_id(message_id)
742            .payload(payload);
743
744        if broadcast {
745            builder = builder.fanout(256);
746        }
747
748        let ret = builder.build_publish()?;
749
750        Ok(ret)
751    }
752
753    /// Send control message without creating ack channel (for internal use by moderator)
754    pub(crate) async fn send_control_message(
755        &mut self,
756        dst: &Name,
757        message_type: ProtoSessionMessageType,
758        message_id: u32,
759        payload: Content,
760        metadata: Option<HashMap<String, String>>,
761        broadcast: bool,
762    ) -> Result<(), SessionError> {
763        let mut msg =
764            self.create_control_message(dst, message_type, message_id, payload, broadcast)?;
765        if let Some(m) = metadata {
766            msg.set_metadata_map(m);
767        }
768        self.send_with_timer(msg).await
769    }
770}
771
772#[cfg(test)]
773mod tests {
774    use super::*;
775
776    // Test: internal draining transition triggered by a leave request.
777    // This test sends a LeaveRequest into a multicast participant session and then
778    // verifies (indirectly) that subsequent messages are still accepted while the
779    // session is transitioning, indicating that graceful draining has begun.
780    // Removed broken test_internal_draining_via_leave_request (incompatible mock trait implementation)
781
782    use crate::subscription_manager::{SpySubscriptionManager, SubscriptionCall};
783    use crate::transmitter::SessionTransmitter;
784    use slim_auth::shared_secret::SharedSecret;
785
786    use std::sync::Arc;
787    use std::sync::atomic::AtomicBool;
788    use std::time::Duration;
789    use tokio::time::timeout;
790    use tracing_test::traced_test;
791
792    const SHARED_SECRET: &str = "kjandjansdiasb8udaijdniasdaindasndasndasndasndasndasndasndas";
793
794    /// Test helper to create a SessionController with common setup
795    struct SessionControllerTestBuilder {
796        session_id: u32,
797        source: Name,
798        destination: Name,
799        session_type: ProtoSessionType,
800        mls_enabled: bool,
801        initiator: bool,
802        max_retries: Option<u32>,
803        interval: Option<Duration>,
804        metadata: HashMap<String, String>,
805        graceful_shutdown_timeout: Option<Duration>,
806    }
807
808    impl SessionControllerTestBuilder {
809        #[allow(dead_code)]
810        fn new() -> Self {
811            Self {
812                session_id: 10,
813                source: Name::from_strings(["org", "ns", "source"]).with_id(1),
814                destination: Name::from_strings(["org", "ns", "dest"]).with_id(2),
815                session_type: ProtoSessionType::PointToPoint,
816                mls_enabled: false,
817                initiator: true,
818                max_retries: Some(5),
819                interval: Some(Duration::from_millis(200)),
820                metadata: HashMap::new(),
821                graceful_shutdown_timeout: Some(Duration::from_secs(10)),
822            }
823        }
824
825        fn with_session_id(mut self, id: u32) -> Self {
826            self.session_id = id;
827            self
828        }
829
830        #[allow(dead_code)]
831        fn with_source(mut self, source: Name) -> Self {
832            self.source = source;
833            self
834        }
835
836        #[allow(dead_code)]
837        fn with_destination(mut self, destination: Name) -> Self {
838            self.destination = destination;
839            self
840        }
841
842        fn with_session_type(mut self, session_type: ProtoSessionType) -> Self {
843            self.session_type = session_type;
844            self
845        }
846
847        fn with_mls_enabled(mut self, enabled: bool) -> Self {
848            self.mls_enabled = enabled;
849            self
850        }
851
852        fn with_initiator(mut self, initiator: bool) -> Self {
853            self.initiator = initiator;
854            self
855        }
856
857        fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
858            self.metadata = metadata;
859            self
860        }
861
862        fn with_graceful_shutdown_timeout(mut self, timeout: Duration) -> Self {
863            self.graceful_shutdown_timeout = Some(timeout);
864            self
865        }
866
867        fn build(
868            self,
869        ) -> (
870            SessionController,
871            tokio::sync::mpsc::Receiver<Result<Message, slim_datapath::Status>>,
872            tokio::sync::mpsc::UnboundedReceiver<Result<Message, SessionError>>,
873        ) {
874            let config = SessionConfig {
875                session_type: self.session_type,
876                max_retries: self.max_retries,
877                interval: self.interval,
878                mls_enabled: self.mls_enabled,
879                initiator: self.initiator,
880                metadata: self.metadata,
881            };
882
883            let (tx_slim, rx_slim) = tokio::sync::mpsc::channel(10);
884            let (tx_app, rx_app) = tokio::sync::mpsc::unbounded_channel();
885            let (tx_session_layer, _rx_session_layer) = tokio::sync::mpsc::channel(10);
886
887            let tx = SessionTransmitter::new(tx_slim, tx_app);
888
889            let controller = SessionController::builder()
890                .with_id(self.session_id)
891                .with_source(self.source.clone())
892                .with_destination(self.destination.clone())
893                .with_config(config)
894                .with_identity_provider(SharedSecret::new("test", SHARED_SECRET).unwrap())
895                .with_identity_verifier(SharedSecret::new("test", SHARED_SECRET).unwrap())
896                .with_tx(tx)
897                .with_tx_to_session_layer(tx_session_layer)
898                .ready()
899                .expect("failed to validate builder")
900                .build()
901                .expect("failed to build controller");
902
903            (controller, rx_slim, rx_app)
904        }
905    }
906
907    #[tokio::test]
908    async fn test_session_controller_getters() {
909        let mut metadata = HashMap::new();
910        metadata.insert("key1".to_string(), "value1".to_string());
911
912        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
913            .with_session_id(42)
914            .with_session_type(ProtoSessionType::Multicast)
915            .with_mls_enabled(true)
916            .with_metadata(metadata)
917            .build();
918
919        assert_eq!(controller.id(), 42);
920        assert_eq!(
921            controller.source(),
922            &Name::from_strings(["org", "ns", "source"]).with_id(1)
923        );
924        assert_eq!(
925            controller.dst(),
926            &Name::from_strings(["org", "ns", "dest"]).with_id(2)
927        );
928        assert_eq!(controller.session_type(), ProtoSessionType::Multicast);
929        assert!(controller.is_initiator());
930        assert_eq!(
931            controller.metadata().get("key1"),
932            Some(&"value1".to_string())
933        );
934
935        let retrieved_config = controller.session_config();
936        assert_eq!(retrieved_config.session_type, ProtoSessionType::Multicast);
937        assert_eq!(retrieved_config.max_retries, Some(5));
938    }
939
940    #[tokio::test]
941    async fn test_publish_basic() {
942        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
943
944        let target_name = Name::from_strings(["org", "ns", "target"]);
945        let payload = b"Hello World".to_vec();
946
947        controller
948            .publish(
949                &target_name,
950                payload.clone(),
951                Some("test-type".to_string()),
952                None,
953            )
954            .await
955            .expect("publish should succeed");
956
957        tokio::time::sleep(Duration::from_millis(50)).await;
958    }
959
960    #[tokio::test]
961    async fn test_publish_to_specific_connection() {
962        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
963            .with_session_type(ProtoSessionType::Multicast)
964            .build();
965
966        let target_name = Name::from_strings(["org", "ns", "target"]);
967        let payload = b"Hello to specific connection".to_vec();
968        let connection_id = 123u64;
969
970        controller
971            .publish_to(
972                &target_name,
973                connection_id,
974                payload.clone(),
975                Some("test-type".to_string()),
976                None,
977            )
978            .await
979            .expect("publish_to should succeed");
980
981        tokio::time::sleep(Duration::from_millis(50)).await;
982    }
983
984    #[tokio::test]
985    async fn test_publish_with_metadata() {
986        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
987            .with_session_type(ProtoSessionType::Multicast)
988            .build();
989
990        let target_name = Name::from_strings(["org", "ns", "target"]);
991        let payload = b"Hello with metadata".to_vec();
992
993        let mut metadata = HashMap::new();
994        metadata.insert("custom_key".to_string(), "custom_value".to_string());
995
996        controller
997            .publish(
998                &target_name,
999                payload.clone(),
1000                Some("test-type".to_string()),
1001                Some(metadata),
1002            )
1003            .await
1004            .expect("publish with metadata should succeed");
1005
1006        tokio::time::sleep(Duration::from_millis(50)).await;
1007    }
1008
1009    #[tokio::test]
1010    async fn test_invite_participant_in_multicast() {
1011        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1012            .with_session_type(ProtoSessionType::Multicast)
1013            .build();
1014
1015        let participant = Name::from_strings(["org", "ns", "participant"]);
1016
1017        controller
1018            .invite_participant(&participant)
1019            .await
1020            .expect("invite should succeed");
1021
1022        tokio::time::sleep(Duration::from_millis(50)).await;
1023    }
1024
1025    #[tokio::test]
1026    async fn test_invite_participant_not_initiator_error() {
1027        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1028            .with_session_type(ProtoSessionType::Multicast)
1029            .with_initiator(false)
1030            .build();
1031
1032        let participant = Name::from_strings(["org", "ns", "new_participant"]);
1033
1034        let result = controller.invite_participant(&participant).await;
1035        assert!(result.is_err_and(|e| matches!(e, SessionError::NotInitiator)));
1036    }
1037
1038    #[tokio::test]
1039    async fn test_invite_participant_p2p_error() {
1040        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1041            .with_session_type(ProtoSessionType::PointToPoint)
1042            .build();
1043
1044        let participant = Name::from_strings(["org", "ns", "participant"]);
1045
1046        let result = controller.invite_participant(&participant).await;
1047        assert!(result.is_err_and(|e| matches!(e, SessionError::CannotInviteToP2P)));
1048    }
1049
1050    #[tokio::test]
1051    async fn test_remove_participant_in_multicast() {
1052        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1053            .with_session_type(ProtoSessionType::Multicast)
1054            .build();
1055
1056        let participant = Name::from_strings(["org", "ns", "participant"]);
1057
1058        controller
1059            .remove_participant(&participant)
1060            .await
1061            .expect("remove should succeed");
1062
1063        tokio::time::sleep(Duration::from_millis(50)).await;
1064    }
1065
1066    #[tokio::test]
1067    async fn test_remove_participant_not_initiator_error() {
1068        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1069            .with_session_type(ProtoSessionType::Multicast)
1070            .with_initiator(false)
1071            .build();
1072
1073        let participant = Name::from_strings(["org", "ns", "participant"]);
1074
1075        let result = controller.remove_participant(&participant).await;
1076        assert!(result.is_err_and(|e| matches!(e, SessionError::NotInitiator)));
1077    }
1078
1079    #[tokio::test]
1080    async fn test_remove_participant_p2p_error() {
1081        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1082            .with_session_type(ProtoSessionType::PointToPoint)
1083            .build();
1084
1085        let participant = Name::from_strings(["org", "ns", "participant"]);
1086
1087        let result = controller.remove_participant(&participant).await;
1088        assert!(result.is_err_and(|e| matches!(e, SessionError::CannotRemoveFromP2P)));
1089    }
1090
1091    #[test]
1092    fn test_handle_channel_discovery_message() {
1093        let app_name = Name::from_strings(["org", "ns", "app"]).with_id(100);
1094        let session_id = 42;
1095
1096        let discovery_request = Message::builder()
1097            .source(Name::from_strings(["org", "ns", "requester"]).with_id(1))
1098            .destination(Name::from_strings(["org", "ns", "service"]))
1099            .identity("")
1100            .incoming_conn(999)
1101            .session_type(ProtoSessionType::Multicast)
1102            .session_message_type(ProtoSessionMessageType::DiscoveryRequest)
1103            .session_id(session_id)
1104            .message_id(123)
1105            .payload(
1106                CommandPayload::builder()
1107                    .discovery_request(None)
1108                    .as_content(),
1109            )
1110            .build_publish()
1111            .unwrap();
1112
1113        let response = handle_channel_discovery_message(
1114            &discovery_request,
1115            &app_name,
1116            session_id,
1117            ProtoSessionType::Multicast,
1118        )
1119        .expect("should create discovery response");
1120
1121        assert_eq!(
1122            response.get_session_message_type(),
1123            ProtoSessionMessageType::DiscoveryReply
1124        );
1125        assert_eq!(response.get_session_header().get_session_id(), session_id);
1126        assert_eq!(response.get_id(), 123);
1127        assert_eq!(
1128            response.get_dst(),
1129            Name::from_strings(["org", "ns", "requester"]).with_id(1)
1130        );
1131        assert_eq!(response.get_slim_header().get_forward_to(), Some(999));
1132    }
1133
1134    #[tokio::test]
1135    async fn test_controller_drop_cancels_processing() {
1136        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
1137
1138        let token = controller.cancellation_token.clone();
1139        assert!(!token.is_cancelled());
1140
1141        drop(controller);
1142
1143        tokio::time::sleep(Duration::from_millis(100)).await;
1144        assert!(token.is_cancelled());
1145    }
1146
1147    #[tokio::test]
1148    async fn test_close_success() {
1149        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1150            .with_graceful_shutdown_timeout(std::time::Duration::from_secs(2))
1151            .build();
1152
1153        let token = controller.cancellation_token.clone();
1154        assert!(!token.is_cancelled());
1155
1156        let handle = controller.close();
1157        assert!(handle.is_ok(), "got error {}", handle.unwrap_err());
1158        assert!(token.is_cancelled());
1159
1160        // Wait for the handle to complete
1161        handle
1162            .unwrap()
1163            .await
1164            .expect("processing task should complete");
1165    }
1166
1167    #[tokio::test]
1168    async fn test_close_already_closed() {
1169        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
1170
1171        // Close once - should succeed
1172        let handle = controller.close();
1173        assert!(handle.is_ok());
1174        handle
1175            .unwrap()
1176            .await
1177            .expect("processing task should complete");
1178
1179        // Close again - should fail with appropriate error
1180        let result = controller.close();
1181        assert!(result.is_err());
1182        match result {
1183            Err(SessionError::SessionAlreadyClosed) => {
1184                // expected
1185            }
1186            _ => panic!("Expected SessionError::SessionAlreadyClosed"),
1187        }
1188    }
1189
1190    #[tokio::test]
1191    async fn test_close_cancels_token_immediately() {
1192        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
1193
1194        let token = controller.cancellation_token.clone();
1195
1196        // Verify token is not cancelled before close
1197        assert!(!token.is_cancelled());
1198
1199        // Close returns immediately after cancelling token
1200        let handle = controller.close();
1201        assert!(handle.is_ok());
1202
1203        // Token should be cancelled immediately
1204        assert!(token.is_cancelled());
1205
1206        // Wait for processing to complete
1207        handle.unwrap().await.expect("processing should complete");
1208    }
1209
1210    #[tokio::test]
1211    async fn test_on_message_direction_north() {
1212        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
1213
1214        let test_message = Message::builder()
1215            .source(controller.dst().clone())
1216            .destination(controller.source().clone())
1217            .identity("")
1218            .session_type(ProtoSessionType::PointToPoint)
1219            .session_message_type(ProtoSessionMessageType::Msg)
1220            .session_id(controller.id())
1221            .message_id(1)
1222            .application_payload("test", b"test data".to_vec())
1223            .build_publish()
1224            .unwrap();
1225
1226        let result = controller.on_message_from_slim(test_message).await;
1227        assert!(result.is_ok());
1228    }
1229
1230    #[tokio::test]
1231    async fn test_create_discovery_request() {
1232        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1233            .with_session_type(ProtoSessionType::Multicast)
1234            .build();
1235
1236        let target = Name::from_strings(["org", "ns", "target"]);
1237        let discovery_msg = controller
1238            .create_discovery_request(&target)
1239            .expect("should create discovery request");
1240
1241        assert_eq!(discovery_msg.get_source(), *controller.source());
1242        assert_eq!(discovery_msg.get_dst(), target);
1243        assert_eq!(
1244            discovery_msg.get_session_message_type(),
1245            ProtoSessionMessageType::DiscoveryRequest
1246        );
1247        assert_eq!(
1248            discovery_msg.get_session_header().get_session_id(),
1249            controller.id()
1250        );
1251        assert_eq!(
1252            discovery_msg.get_session_type(),
1253            ProtoSessionType::Multicast
1254        );
1255    }
1256
1257    #[tokio::test]
1258    #[traced_test]
1259    async fn test_end_to_end_p2p() {
1260        let session_id = 10;
1261        let moderator_name = Name::from_strings(["org", "ns", "moderator"]).with_id(1);
1262        let participant_name = Name::from_strings(["org", "ns", "participant"]);
1263        let participant_name_id = Name::from_strings(["org", "ns", "participant"]).with_id(1);
1264        // create a SessionModerator
1265        let (tx_slim_moderator, mut rx_slim_moderator) = tokio::sync::mpsc::channel(10);
1266        let (tx_app_moderator, _rx_app_moderator) = tokio::sync::mpsc::unbounded_channel();
1267        let (tx_session_layer_moderator, _rx_session_layer_moderator) =
1268            tokio::sync::mpsc::channel(10);
1269
1270        let tx_moderator =
1271            SessionTransmitter::new(tx_slim_moderator.clone(), tx_app_moderator.clone());
1272
1273        let moderator_config = SessionConfig {
1274            session_type: slim_datapath::api::ProtoSessionType::PointToPoint,
1275            max_retries: Some(5),
1276            interval: Some(Duration::from_millis(1000)),
1277            mls_enabled: true,
1278            initiator: true,
1279            metadata: std::collections::HashMap::new(),
1280        };
1281
1282        let (spy_moderator_mgr, mut rx_spy_moderator) = SpySubscriptionManager::new();
1283        let moderator = SessionController::builder()
1284            .with_id(session_id)
1285            .with_source(moderator_name.clone())
1286            .with_destination(participant_name.clone())
1287            .with_config(moderator_config)
1288            .with_identity_provider(SharedSecret::new("moderator", SHARED_SECRET).unwrap())
1289            .with_identity_verifier(SharedSecret::new("moderator", SHARED_SECRET).unwrap())
1290            .with_tx(tx_moderator.clone())
1291            .with_tx_to_session_layer(tx_session_layer_moderator)
1292            .with_subscription_manager(spy_moderator_mgr)
1293            .ready()
1294            .expect("failed to validate builder")
1295            .build()
1296            .unwrap();
1297
1298        // create a SessionParticipant
1299        let (tx_slim_participant, mut rx_slim_participant) = tokio::sync::mpsc::channel(10);
1300        let (tx_app_participant, mut rx_app_participant) = tokio::sync::mpsc::unbounded_channel();
1301        let (tx_session_layer_participant, _rx_session_layer_participant) =
1302            tokio::sync::mpsc::channel(10);
1303
1304        let tx_participant =
1305            SessionTransmitter::new(tx_slim_participant.clone(), tx_app_participant.clone());
1306
1307        let participant_config = SessionConfig {
1308            session_type: slim_datapath::api::ProtoSessionType::PointToPoint,
1309            max_retries: Some(5),
1310            interval: Some(Duration::from_millis(200)),
1311            mls_enabled: true,
1312            initiator: false,
1313            metadata: std::collections::HashMap::new(),
1314        };
1315
1316        let (spy_participant_mgr, mut rx_spy_participant) = SpySubscriptionManager::new();
1317        let participant = SessionController::builder()
1318            .with_id(session_id)
1319            .with_source(participant_name_id.clone())
1320            .with_destination(moderator_name.clone())
1321            .with_config(participant_config)
1322            .with_identity_provider(SharedSecret::new("participant", SHARED_SECRET).unwrap())
1323            .with_identity_verifier(SharedSecret::new("participant", SHARED_SECRET).unwrap())
1324            .with_tx(tx_participant.clone())
1325            .with_tx_to_session_layer(tx_session_layer_participant)
1326            .with_subscription_manager(spy_participant_mgr)
1327            .ready()
1328            .expect("failed to validate builder")
1329            .build()
1330            .unwrap();
1331
1332        let completion_handle = moderator
1333            .invite_participant_internal(&participant_name)
1334            .await
1335            .expect("error inviting participant");
1336
1337        let received_discovery_request =
1338            timeout(Duration::from_millis(100), rx_slim_moderator.recv())
1339                .await
1340                .expect("timeout waiting for discovery request on moderator slim channel")
1341                .expect("channel closed")
1342                .expect("error in discovery request");
1343
1344        assert_eq!(
1345            received_discovery_request.get_session_message_type(),
1346            slim_datapath::api::ProtoSessionMessageType::DiscoveryRequest
1347        );
1348
1349        let discovery_msg_id = received_discovery_request.get_id();
1350
1351        // create a discovery reply and call the on message on the moderator with the reply (direction north)
1352        let mut discovery_reply = Message::builder()
1353            .source(participant_name_id.clone())
1354            .destination(moderator_name.clone())
1355            .identity("")
1356            .forward_to(1)
1357            .session_type(slim_datapath::api::ProtoSessionType::PointToPoint)
1358            .session_message_type(slim_datapath::api::ProtoSessionMessageType::DiscoveryReply)
1359            .session_id(session_id)
1360            .message_id(discovery_msg_id)
1361            .payload(CommandPayload::builder().discovery_reply().as_content())
1362            .build_publish()
1363            .unwrap();
1364        discovery_reply
1365            .get_slim_header_mut()
1366            .set_incoming_conn(Some(1));
1367
1368        moderator
1369            .on_message_from_slim(discovery_reply)
1370            .await
1371            .expect("error processing discovery reply on moderator");
1372
1373        // moderator sets route for participant after discovery reply
1374        assert_eq!(
1375            rx_spy_moderator.recv().await,
1376            Some(SubscriptionCall::SetRoute),
1377            "moderator should set route after discovery reply"
1378        );
1379
1380        // check that a join request is received by slim
1381        let join_request = timeout(Duration::from_millis(100), rx_slim_moderator.recv())
1382            .await
1383            .expect("timeout waiting for join request on moderator slim channel")
1384            .expect("channel closed")
1385            .expect("error in join request");
1386
1387        assert_eq!(
1388            join_request.get_session_message_type(),
1389            slim_datapath::api::ProtoSessionMessageType::JoinRequest
1390        );
1391        assert_eq!(join_request.get_dst(), participant_name_id);
1392
1393        // call the on message on the participant side with the join request (direction north)
1394        let mut join_request_to_participant = join_request.clone();
1395        join_request_to_participant
1396            .get_slim_header_mut()
1397            .set_incoming_conn(Some(1));
1398
1399        participant
1400            .on_message_from_slim(join_request_to_participant)
1401            .await
1402            .expect("error processing join request on participant");
1403
1404        // participant sets route for moderator after join request
1405        assert_eq!(
1406            rx_spy_participant.recv().await,
1407            Some(SubscriptionCall::SetRoute),
1408            "participant should set route after join request"
1409        );
1410
1411        // check that a join reply is received by slim on the participant
1412        let join_reply = timeout(Duration::from_millis(100), rx_slim_participant.recv())
1413            .await
1414            .expect("timeout waiting for join reply on participant slim channel")
1415            .expect("channel closed")
1416            .expect("error in join reply");
1417
1418        assert_eq!(
1419            join_reply.get_session_message_type(),
1420            slim_datapath::api::ProtoSessionMessageType::JoinReply
1421        );
1422        assert_eq!(join_reply.get_dst(), moderator_name);
1423
1424        // call the on message on the moderator with the reply (direction north)
1425        let mut join_reply_to_moderator = join_reply.clone();
1426        join_reply_to_moderator
1427            .get_slim_header_mut()
1428            .set_incoming_conn(Some(1));
1429
1430        moderator
1431            .on_message_from_slim(join_reply_to_moderator)
1432            .await
1433            .expect("error processing join reply on moderator");
1434
1435        // check that a welcome message is received by slim on the moderator
1436        let welcome_message = timeout(Duration::from_millis(100), rx_slim_moderator.recv())
1437            .await
1438            .expect("timeout waiting for welcome message on moderator slim channel")
1439            .expect("channel closed")
1440            .expect("error in welcome message");
1441
1442        assert_eq!(
1443            welcome_message.get_session_message_type(),
1444            slim_datapath::api::ProtoSessionMessageType::GroupWelcome
1445        );
1446        assert_eq!(welcome_message.get_dst(), participant_name_id);
1447
1448        // call the on message on the participant side with the welcome message (direction north)
1449        let mut welcome_to_participant = welcome_message.clone();
1450        welcome_to_participant
1451            .get_slim_header_mut()
1452            .set_incoming_conn(Some(1));
1453
1454        participant
1455            .on_message_from_slim(welcome_to_participant)
1456            .await
1457            .expect("error processing welcome message on participant");
1458
1459        // check that an ack group is received by slim on the participant
1460        let ack_group = timeout(Duration::from_millis(100), rx_slim_participant.recv())
1461            .await
1462            .expect("timeout waiting for ack group on participant slim channel")
1463            .expect("channel closed")
1464            .expect("error in ack group");
1465
1466        assert_eq!(
1467            ack_group.get_session_message_type(),
1468            slim_datapath::api::ProtoSessionMessageType::GroupAck
1469        );
1470        assert_eq!(ack_group.get_dst(), moderator_name);
1471
1472        // call the on message on the moderator with the ack (direction north)
1473        let mut ack_to_moderator = ack_group.clone();
1474        ack_to_moderator
1475            .get_slim_header_mut()
1476            .set_incoming_conn(Some(1));
1477
1478        moderator
1479            .on_message_from_slim(ack_to_moderator)
1480            .await
1481            .expect("error processing ack group on moderator");
1482
1483        // no other message should be sent
1484        let no_more_moderator = timeout(Duration::from_millis(100), rx_slim_moderator.recv()).await;
1485        assert!(
1486            no_more_moderator.is_err(),
1487            "Expected no more messages on moderator slim channel, received {:?}",
1488            no_more_moderator
1489                .ok()
1490                .and_then(|opt| opt)
1491                .and_then(|res| res.ok())
1492        );
1493
1494        let no_more_participant =
1495            timeout(Duration::from_millis(100), rx_slim_participant.recv()).await;
1496        assert!(
1497            no_more_participant.is_err(),
1498            "Expected no more messages on participant slim channel"
1499        );
1500
1501        // the completion handler should now be complete
1502        completion_handle.await.expect("error in completion handle");
1503
1504        // create an application message using the participant name
1505        let app_data = b"Hello from moderator to participant".to_vec();
1506        let app_message = Message::builder()
1507            .source(moderator_name.clone())
1508            .destination(participant_name.clone())
1509            .identity("")
1510            .session_type(slim_datapath::api::ProtoSessionType::PointToPoint)
1511            .session_message_type(slim_datapath::api::ProtoSessionMessageType::Msg)
1512            .session_id(session_id)
1513            .message_id(1)
1514            .application_payload("test-app-data", app_data.clone())
1515            .build_publish()
1516            .unwrap();
1517
1518        // call on message on the moderator (direction south)
1519        moderator
1520            .on_message_from_app(app_message)
1521            .await
1522            .expect("error sending application message from moderator");
1523
1524        // check that message is received from slim with destination equal to participant name id
1525        let app_msg_to_slim = timeout(Duration::from_millis(100), rx_slim_moderator.recv())
1526            .await
1527            .expect("timeout waiting for application message on moderator slim channel")
1528            .expect("channel closed")
1529            .expect("error in application message");
1530
1531        assert_eq!(app_msg_to_slim.get_dst(), participant_name_id);
1532        assert!(
1533            app_msg_to_slim.is_publish(),
1534            "message should be a publish message"
1535        );
1536
1537        let app_msg_id = app_msg_to_slim.get_id();
1538
1539        // call the on message on the participant (direction north)
1540        let mut app_msg_to_participant = app_msg_to_slim.clone();
1541        app_msg_to_participant
1542            .get_slim_header_mut()
1543            .set_incoming_conn(Some(1));
1544
1545        participant
1546            .on_message_from_slim(app_msg_to_participant)
1547            .await
1548            .expect("error processing application message on participant");
1549
1550        // check that the message is received by the application
1551        let app_msg_received = timeout(Duration::from_millis(100), rx_app_participant.recv())
1552            .await
1553            .expect("timeout waiting for application message on participant app channel")
1554            .expect("channel closed")
1555            .expect("error in application message to app");
1556
1557        assert_eq!(app_msg_received.get_source(), moderator_name);
1558        assert!(
1559            app_msg_received.is_publish(),
1560            "message should be a publish message"
1561        );
1562
1563        let content = app_msg_received
1564            .get_payload()
1565            .unwrap()
1566            .as_application_payload()
1567            .unwrap()
1568            .blob
1569            .clone();
1570        assert_eq!(content, app_data);
1571
1572        // check that an ack is sent to slim
1573        let ack_msg = timeout(Duration::from_millis(100), rx_slim_participant.recv())
1574            .await
1575            .expect("timeout waiting for ack on participant slim channel")
1576            .expect("channel closed")
1577            .expect("error in ack");
1578
1579        assert_eq!(
1580            ack_msg.get_session_message_type(),
1581            slim_datapath::api::ProtoSessionMessageType::MsgAck,
1582            "message should be an ack"
1583        );
1584        assert_eq!(ack_msg.get_dst(), moderator_name);
1585        assert_eq!(ack_msg.get_id(), app_msg_id);
1586
1587        // call the on message with the ack on the moderator
1588        let mut ack_to_moderator = ack_msg.clone();
1589        ack_to_moderator
1590            .get_slim_header_mut()
1591            .set_incoming_conn(Some(1));
1592
1593        moderator
1594            .on_message_from_slim(ack_to_moderator)
1595            .await
1596            .expect("error processing ack on moderator");
1597
1598        // check that no other message is generated
1599        let no_more_moderator_after_ack =
1600            timeout(Duration::from_millis(100), rx_slim_moderator.recv()).await;
1601        assert!(
1602            no_more_moderator_after_ack.is_err(),
1603            "Expected no more messages on moderator slim channel after ack"
1604        );
1605
1606        let no_more_participant_after_ack =
1607            timeout(Duration::from_millis(100), rx_slim_participant.recv()).await;
1608        assert!(
1609            no_more_participant_after_ack.is_err(),
1610            "Expected no more messages on participant slim channel after ack"
1611        );
1612
1613        // create a leave request and send to moderator on message (direction south)
1614        let leave_request = Message::builder()
1615            .source(moderator_name.clone())
1616            .destination(participant_name.clone())
1617            .identity("")
1618            .session_type(slim_datapath::api::ProtoSessionType::PointToPoint)
1619            .session_message_type(slim_datapath::api::ProtoSessionMessageType::LeaveRequest)
1620            .session_id(session_id)
1621            .message_id(rand::random::<u32>())
1622            .payload(CommandPayload::builder().leave_request(None).as_content())
1623            .build_publish()
1624            .unwrap();
1625
1626        moderator
1627            .on_message_from_app(leave_request)
1628            .await
1629            .expect("error sending leave request");
1630
1631        // check that the request is received by slim on the moderator
1632        let received_leave_request = timeout(Duration::from_millis(100), rx_slim_moderator.recv())
1633            .await
1634            .expect("timeout waiting for leave request on moderator slim channel")
1635            .expect("channel closed")
1636            .expect("error in leave request");
1637
1638        assert_eq!(
1639            received_leave_request.get_session_message_type(),
1640            slim_datapath::api::ProtoSessionMessageType::LeaveRequest
1641        );
1642        assert_eq!(received_leave_request.get_dst(), participant_name_id);
1643
1644        // send the request to the participant (direction north)
1645        let mut leave_request_to_participant = received_leave_request.clone();
1646        leave_request_to_participant
1647            .get_slim_header_mut()
1648            .set_incoming_conn(Some(1));
1649
1650        participant
1651            .on_message_from_slim(leave_request_to_participant)
1652            .await
1653            .expect("error processing leave request on participant");
1654
1655        // get the leave reply on the participant slim
1656        let leave_reply = timeout(Duration::from_millis(100), rx_slim_participant.recv())
1657            .await
1658            .expect("timeout waiting for leave reply on participant slim channel")
1659            .expect("channel closed")
1660            .expect("error in leave reply");
1661
1662        assert_eq!(
1663            leave_reply.get_session_message_type(),
1664            slim_datapath::api::ProtoSessionMessageType::LeaveReply
1665        );
1666        assert_eq!(leave_reply.get_dst(), moderator_name);
1667
1668        // participant removes route after processing leave request
1669        assert_eq!(
1670            rx_spy_participant.recv().await,
1671            Some(SubscriptionCall::RemoveRoute),
1672            "participant should remove route after leave request"
1673        );
1674
1675        // send the leave reply to the moderator on message (direction north)
1676        let mut leave_reply_to_moderator = leave_reply.clone();
1677        leave_reply_to_moderator
1678            .get_slim_header_mut()
1679            .set_incoming_conn(Some(1));
1680
1681        moderator
1682            .on_message_from_slim(leave_reply_to_moderator)
1683            .await
1684            .expect("error processing leave reply on moderator");
1685
1686        // moderator removes route after processing leave reply
1687        assert_eq!(
1688            rx_spy_moderator.recv().await,
1689            Some(SubscriptionCall::RemoveRoute),
1690            "moderator should remove route after leave reply"
1691        );
1692
1693        // check that no other messages are generated by the moderator
1694        let no_more_moderator_final =
1695            timeout(Duration::from_millis(100), rx_slim_moderator.recv()).await;
1696
1697        assert!(
1698            no_more_moderator_final.is_err(),
1699            "Expected no more messages on moderator slim channel after leave"
1700        );
1701
1702        let no_more_participant_final =
1703            timeout(Duration::from_millis(100), rx_slim_participant.recv()).await;
1704        assert!(
1705            no_more_participant_final.is_err(),
1706            "Expected no more messages on participant slim channel after leave"
1707        );
1708    }
1709
1710    // ============================================================================
1711    // Draining Tests
1712    #[traced_test]
1713    #[tokio::test]
1714    async fn test_internal_draining_via_processing_state_switch() {
1715        use super::*;
1716        use tokio::sync::mpsc;
1717        use tracing::debug;
1718
1719        // Custom handler that flips processing_state to Draining after first normal message
1720        struct InternalDrainHandler {
1721            state: ProcessingState,
1722            messages: Vec<SessionMessage>,
1723            needs_drain: Arc<AtomicBool>,
1724        }
1725
1726        impl InternalDrainHandler {
1727            fn new(needs_drain: Arc<AtomicBool>) -> Self {
1728                Self {
1729                    state: ProcessingState::Active,
1730                    messages: vec![],
1731                    needs_drain,
1732                }
1733            }
1734        }
1735
1736        #[async_trait::async_trait]
1737        impl MessageHandler for InternalDrainHandler {
1738            async fn init(&mut self) -> Result<(), SessionError> {
1739                Ok(())
1740            }
1741
1742            async fn on_message(&mut self, message: SessionMessage) -> Result<(), SessionError> {
1743                debug!(?self.state, "internal-drain-handler received message");
1744                self.messages.push(message);
1745
1746                // when we receive 2 messages, transition to draining state
1747                if self.messages.len() == 2 {
1748                    debug!("internal-drain-handler transitioning to draining");
1749                    self.state = ProcessingState::Draining;
1750                }
1751
1752                Ok(())
1753            }
1754
1755            fn needs_drain(&self) -> bool {
1756                self.needs_drain.load(std::sync::atomic::Ordering::SeqCst)
1757            }
1758
1759            fn processing_state(&self) -> ProcessingState {
1760                self.state
1761            }
1762
1763            async fn on_shutdown(&mut self) -> Result<(), SessionError> {
1764                debug!("shutdown called on handler");
1765                Ok(())
1766            }
1767        }
1768
1769        // Build minimal SessionSettings
1770        let (tx_slim, _rx_slim) = mpsc::channel(8);
1771        let (tx_app, _rx_app) = mpsc::unbounded_channel();
1772        let (tx_session, rx_session) = mpsc::channel(32);
1773        let (tx_session_layer, _rx_session_layer) = mpsc::channel(8);
1774
1775        let subscription_manager =
1776            crate::subscription_manager::SubscriptionManager::new(tx_slim.clone());
1777        let settings = SessionSettings {
1778            id: 999,
1779            source: Name::from_strings(["org", "ns", "source"]).with_id(1),
1780            destination: Name::from_strings(["org", "ns", "dest"]).with_id(2),
1781            config: SessionConfig {
1782                session_type: ProtoSessionType::PointToPoint,
1783                max_retries: Some(3),
1784                interval: Some(Duration::from_millis(150)),
1785                mls_enabled: false,
1786                initiator: true,
1787                metadata: HashMap::new(),
1788            },
1789            tx: SessionTransmitter::new(tx_slim, tx_app),
1790            tx_session: tx_session.clone(),
1791            tx_to_session_layer: tx_session_layer,
1792            identity_provider: SharedSecret::new("src", SHARED_SECRET).unwrap(),
1793            identity_verifier: SharedSecret::new("src", SHARED_SECRET).unwrap(),
1794            graceful_shutdown_timeout: Some(Duration::from_secs(10)),
1795            subscription_manager,
1796            service_id: String::new(),
1797        };
1798
1799        let needs_drain = Arc::new(AtomicBool::new(true));
1800        let handler = InternalDrainHandler::new(needs_drain.clone());
1801        let cancellation_token = CancellationToken::new();
1802        let cancellation_token_clone = cancellation_token.clone();
1803
1804        // Spawn processing loop without unnecessary cloning
1805        let processing_handle = tokio::spawn(async move {
1806            SessionController::processing_loop(
1807                handler,
1808                rx_session,
1809                cancellation_token_clone,
1810                settings,
1811            )
1812            .await
1813        });
1814
1815        // Send first regular message
1816        tx_session
1817            .send(create_test_message(1, b"first".to_vec()))
1818            .await
1819            .expect("failed to send first message");
1820
1821        tokio::time::sleep(Duration::from_millis(100)).await;
1822
1823        assert!(logs_contain("internal-drain-handler received message"));
1824
1825        // Send second message; this causes internal handler move to draining (active -> draining)
1826        tx_session
1827            .send(create_test_message(2, b"second".to_vec()))
1828            .await
1829            .expect("failed to send second message");
1830
1831        tokio::time::sleep(Duration::from_millis(100)).await;
1832
1833        assert!(logs_contain("internal-drain-handler received message"));
1834
1835        assert!(logs_contain(
1836            "internal-drain-handler transitioning to draining"
1837        ));
1838
1839        // Send a third message that should not be processed, as draining is active
1840        tx_session
1841            .send(create_test_message(3, b"third".to_vec()))
1842            .await
1843            .expect("failed to send third message");
1844
1845        tokio::time::sleep(Duration::from_millis(100)).await;
1846
1847        assert!(logs_contain(
1848            "session is draining, rejecting new messages from application"
1849        ));
1850
1851        // set needs drain to false to allow shutdown to complete
1852        needs_drain.store(false, std::sync::atomic::Ordering::SeqCst);
1853
1854        // trigger cancellation to exit processing loop
1855        cancellation_token.cancel();
1856
1857        tokio::time::sleep(Duration::from_millis(100)).await;
1858
1859        // Send a session message to trigger the shutdown process
1860        tx_session
1861            .send(SessionMessage::StartDrain {
1862                grace_period: std::time::Duration::from_millis(100),
1863            })
1864            .await
1865            .expect("failed to send timeout message");
1866
1867        // Wait for processing loop to complete
1868        processing_handle.await.expect("processing loop panicked");
1869    }
1870    // ============================================================================
1871
1872    /// Mock handler that tracks draining behavior
1873    struct DrainableHandler {
1874        messages_received: Arc<tokio::sync::Mutex<Vec<SessionMessage>>>,
1875        needs_drain: Arc<AtomicBool>,
1876        shutdown_called: Arc<tokio::sync::Mutex<bool>>,
1877        drain_delay: Option<Duration>,
1878    }
1879
1880    impl DrainableHandler {
1881        fn new() -> Self {
1882            Self {
1883                messages_received: Arc::new(tokio::sync::Mutex::new(Vec::new())),
1884                needs_drain: Arc::new(AtomicBool::new(false)),
1885                shutdown_called: Arc::new(tokio::sync::Mutex::new(false)),
1886                drain_delay: None,
1887            }
1888        }
1889
1890        fn with_needs_drain(self, needs_drain: bool) -> Self {
1891            self.needs_drain
1892                .store(needs_drain, std::sync::atomic::Ordering::SeqCst);
1893            self
1894        }
1895
1896        #[allow(dead_code)]
1897        fn with_drain_delay(mut self, delay: Duration) -> Self {
1898            self.drain_delay = Some(delay);
1899            self
1900        }
1901
1902        #[allow(dead_code)]
1903        async fn get_messages_count(&self) -> usize {
1904            self.messages_received.lock().await.len()
1905        }
1906
1907        #[allow(dead_code)]
1908        async fn was_shutdown_called(&self) -> bool {
1909            *self.shutdown_called.lock().await
1910        }
1911    }
1912
1913    #[async_trait::async_trait]
1914    impl MessageHandler for DrainableHandler {
1915        async fn init(&mut self) -> Result<(), SessionError> {
1916            Ok(())
1917        }
1918
1919        async fn on_message(&mut self, message: SessionMessage) -> Result<(), SessionError> {
1920            self.messages_received.lock().await.push(message);
1921            Ok(())
1922        }
1923
1924        fn needs_drain(&self) -> bool {
1925            self.needs_drain.load(std::sync::atomic::Ordering::SeqCst)
1926        }
1927
1928        async fn on_shutdown(&mut self) -> Result<(), SessionError> {
1929            if let Some(delay) = self.drain_delay {
1930                tokio::time::sleep(delay).await;
1931            }
1932            *self.shutdown_called.lock().await = true;
1933            Ok(())
1934        }
1935    }
1936
1937    /// Helper to create test SessionSettings
1938    fn create_test_settings(
1939        graceful_shutdown_timeout: Option<Duration>,
1940    ) -> SessionSettings<SharedSecret, SharedSecret> {
1941        let (tx_slim, _rx_slim) = tokio::sync::mpsc::channel(10);
1942        let (tx_app, _rx_app) = tokio::sync::mpsc::unbounded_channel();
1943        let (tx_session, _rx_session) = tokio::sync::mpsc::channel(10);
1944        let (tx_session_layer, _rx_session_layer) = tokio::sync::mpsc::channel(10);
1945
1946        let subscription_manager =
1947            crate::subscription_manager::SubscriptionManager::new(tx_slim.clone());
1948        SessionSettings {
1949            id: 1,
1950            source: Name::from_strings(["org", "ns", "test"]).with_id(1),
1951            destination: Name::from_strings(["org", "ns", "test"]).with_id(2),
1952            config: SessionConfig {
1953                session_type: ProtoSessionType::PointToPoint,
1954                max_retries: Some(5),
1955                interval: Some(Duration::from_millis(200)),
1956                mls_enabled: false,
1957                initiator: true,
1958                metadata: HashMap::new(),
1959            },
1960            tx: SessionTransmitter::new(tx_slim, tx_app),
1961            tx_session,
1962            tx_to_session_layer: tx_session_layer,
1963            identity_provider: SharedSecret::new("test", SHARED_SECRET).unwrap(),
1964            identity_verifier: SharedSecret::new("test", SHARED_SECRET).unwrap(),
1965            graceful_shutdown_timeout,
1966            subscription_manager,
1967            service_id: String::new(),
1968        }
1969    }
1970
1971    /// Helper to create a test message
1972    fn create_test_message(message_id: u32, payload: Vec<u8>) -> SessionMessage {
1973        SessionMessage::OnMessage {
1974            message: Message::builder()
1975                .source(Name::from_strings(["org", "ns", "test"]).with_id(1))
1976                .destination(Name::from_strings(["org", "ns", "test"]).with_id(2))
1977                .identity("")
1978                .forward_to(1)
1979                .session_type(ProtoSessionType::PointToPoint)
1980                .session_message_type(ProtoSessionMessageType::Msg)
1981                .session_id(1)
1982                .message_id(message_id)
1983                .application_payload("test", payload)
1984                .build_publish()
1985                .unwrap(),
1986            direction: MessageDirection::South,
1987            ack_tx: None,
1988        }
1989    }
1990
1991    async fn count_on_messages(messages: &Arc<tokio::sync::Mutex<Vec<SessionMessage>>>) -> usize {
1992        let messages = messages.lock().await;
1993        messages
1994            .iter()
1995            .filter(|msg| matches!(msg, SessionMessage::OnMessage { .. }))
1996            .count()
1997    }
1998
1999    /// Helper to spawn a processing loop and return the task handle
2000    fn spawn_processing_loop(
2001        handler: DrainableHandler,
2002        rx: tokio::sync::mpsc::Receiver<SessionMessage>,
2003        cancellation_token: CancellationToken,
2004        settings: SessionSettings<SharedSecret, SharedSecret>,
2005    ) -> tokio::task::JoinHandle<()> {
2006        tokio::spawn(async move {
2007            SessionController::processing_loop(handler, rx, cancellation_token, settings).await;
2008        })
2009    }
2010
2011    #[tokio::test]
2012    async fn test_draining_processes_queued_messages() {
2013        let handler = DrainableHandler::new();
2014        let messages_received = handler.messages_received.clone();
2015        let shutdown_called = handler.shutdown_called.clone();
2016
2017        let (tx, rx) = tokio::sync::mpsc::channel(10);
2018        let cancellation_token = CancellationToken::new();
2019        let token_clone = cancellation_token.clone();
2020
2021        let settings = create_test_settings(Some(Duration::from_secs(2)));
2022        let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2023
2024        // Send multiple messages before cancellation
2025        tx.send(create_test_message(1, vec![1, 2, 3]))
2026            .await
2027            .unwrap();
2028        tx.send(create_test_message(2, vec![4, 5, 6]))
2029            .await
2030            .unwrap();
2031        tx.send(create_test_message(3, vec![7, 8, 9]))
2032            .await
2033            .unwrap();
2034
2035        // Give some time for messages to be queued
2036        tokio::time::sleep(Duration::from_millis(50)).await;
2037
2038        // Trigger cancellation
2039        token_clone.cancel();
2040
2041        // Close the channel to signal no more messages
2042        drop(tx);
2043
2044        // Wait for processing to complete
2045        timeout(Duration::from_secs(3), processing_task)
2046            .await
2047            .expect("timeout waiting for processing loop")
2048            .expect("processing loop panicked");
2049
2050        // Verify all messages were processed
2051        let processed_messages = count_on_messages(&messages_received).await;
2052        assert_eq!(
2053            processed_messages, 3,
2054            "All queued messages should be processed during draining"
2055        );
2056        assert!(
2057            *shutdown_called.lock().await,
2058            "Shutdown should have been called"
2059        );
2060    }
2061
2062    #[tokio::test]
2063    async fn test_draining_with_needs_drain_true() {
2064        let handler = DrainableHandler::new().with_needs_drain(true);
2065        let messages_received = handler.messages_received.clone();
2066        let shutdown_called = handler.shutdown_called.clone();
2067
2068        let (tx, rx) = tokio::sync::mpsc::channel(10);
2069        let cancellation_token = CancellationToken::new();
2070        let token_clone = cancellation_token.clone();
2071
2072        let settings = create_test_settings(Some(Duration::from_secs(2)));
2073        let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2074
2075        // Send a message
2076        tx.send(create_test_message(1, vec![1, 2, 3]))
2077            .await
2078            .unwrap();
2079        tokio::time::sleep(Duration::from_millis(50)).await;
2080
2081        // Trigger cancellation and close channel
2082        token_clone.cancel();
2083        drop(tx);
2084
2085        // Wait for processing to complete (should wait for drain timeout)
2086        timeout(Duration::from_secs(3), processing_task)
2087            .await
2088            .expect("timeout waiting for processing loop")
2089            .expect("processing loop panicked");
2090
2091        // Verify message was processed and shutdown was called
2092        let processed_messages = count_on_messages(&messages_received).await;
2093        assert_eq!(processed_messages, 1, "Message should be processed");
2094        assert!(
2095            *shutdown_called.lock().await,
2096            "Shutdown should have been called after draining"
2097        );
2098    }
2099
2100    #[tokio::test]
2101    async fn test_draining_with_needs_drain_false() {
2102        let handler = DrainableHandler::new().with_needs_drain(false);
2103        let messages_received = handler.messages_received.clone();
2104        let shutdown_called = handler.shutdown_called.clone();
2105
2106        let (tx, rx) = tokio::sync::mpsc::channel(10);
2107        let cancellation_token = CancellationToken::new();
2108        let token_clone = cancellation_token.clone();
2109
2110        let settings = create_test_settings(Some(Duration::from_secs(2)));
2111
2112        let start_time = tokio::time::Instant::now();
2113        let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2114
2115        // Send a message
2116        tx.send(create_test_message(1, vec![1, 2, 3]))
2117            .await
2118            .unwrap();
2119        tokio::time::sleep(Duration::from_millis(50)).await;
2120
2121        // Trigger cancellation and close channel
2122        token_clone.cancel();
2123        drop(tx);
2124
2125        // Wait for processing to complete (should exit quickly)
2126        timeout(Duration::from_secs(1), processing_task)
2127            .await
2128            .expect("timeout waiting for processing loop")
2129            .expect("processing loop panicked");
2130
2131        let elapsed = start_time.elapsed();
2132
2133        // Verify message was processed and shutdown was called quickly
2134        let processed_messages = count_on_messages(&messages_received).await;
2135        assert_eq!(processed_messages, 1, "Message should be processed");
2136        assert!(
2137            *shutdown_called.lock().await,
2138            "Shutdown should have been called"
2139        );
2140        assert!(
2141            elapsed < Duration::from_millis(500),
2142            "Should exit quickly when no draining needed"
2143        );
2144    }
2145
2146    #[tokio::test]
2147    async fn test_draining_timeout_enforced() {
2148        // Test that the timeout fires when draining takes too long with needs_drain=true
2149        let handler = DrainableHandler::new().with_needs_drain(true);
2150
2151        let (tx, rx) = tokio::sync::mpsc::channel(10);
2152        let cancellation_token = CancellationToken::new();
2153        let token_clone = cancellation_token.clone();
2154
2155        let settings = create_test_settings(Some(Duration::from_millis(500)));
2156        let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2157
2158        // Give the processing loop a moment to start
2159        tokio::time::sleep(Duration::from_millis(50)).await;
2160
2161        let start_time = tokio::time::Instant::now();
2162
2163        // Trigger cancellation - this starts draining
2164        token_clone.cancel();
2165
2166        // Keep sending messages to prevent channel from closing
2167        // This simulates a scenario where messages keep arriving during drain period
2168        let send_task = tokio::spawn(async move {
2169            for i in 0..10 {
2170                tokio::time::sleep(Duration::from_millis(100)).await;
2171                if tx
2172                    .send(create_test_message(i, vec![i as u8]))
2173                    .await
2174                    .is_err()
2175                {
2176                    break;
2177                }
2178            }
2179        });
2180
2181        // Wait for processing to complete - should timeout after 500ms
2182        timeout(Duration::from_secs(2), processing_task)
2183            .await
2184            .expect("timeout waiting for processing loop")
2185            .expect("processing loop panicked");
2186
2187        let elapsed = start_time.elapsed();
2188
2189        // Verify timeout was enforced (should be around 500ms)
2190        assert!(
2191            elapsed >= Duration::from_millis(400),
2192            "Should wait at least close to the timeout period"
2193        );
2194        assert!(
2195            elapsed < Duration::from_secs(2),
2196            "Should respect the timeout and exit, not wait forever"
2197        );
2198
2199        // Clean up the send task
2200        send_task.abort();
2201    }
2202
2203    #[tokio::test]
2204    async fn test_draining_no_messages_in_queue() {
2205        let handler = DrainableHandler::new().with_needs_drain(true);
2206        let messages_received = handler.messages_received.clone();
2207        let shutdown_called = handler.shutdown_called.clone();
2208
2209        let (tx, rx) = tokio::sync::mpsc::channel(10);
2210        let cancellation_token = CancellationToken::new();
2211        let token_clone = cancellation_token.clone();
2212
2213        let settings = create_test_settings(Some(Duration::from_secs(1)));
2214        let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2215
2216        // Trigger cancellation immediately without sending messages
2217        token_clone.cancel();
2218        drop(tx);
2219
2220        // Wait for processing to complete
2221        timeout(Duration::from_secs(2), processing_task)
2222            .await
2223            .expect("timeout waiting for processing loop")
2224            .expect("processing loop panicked");
2225
2226        // Verify no messages were processed but shutdown was called
2227        let processed_messages = count_on_messages(&messages_received).await;
2228        assert_eq!(processed_messages, 0, "No messages should be processed");
2229        assert!(
2230            *shutdown_called.lock().await,
2231            "Shutdown should still be called"
2232        );
2233    }
2234
2235    #[tokio::test]
2236    async fn test_draining_messages_after_cancellation_processed() {
2237        let handler = DrainableHandler::new();
2238        let messages_received = handler.messages_received.clone();
2239
2240        let (tx, rx) = tokio::sync::mpsc::channel(10);
2241        let cancellation_token = CancellationToken::new();
2242        let token_clone = cancellation_token.clone();
2243
2244        let settings = create_test_settings(Some(Duration::from_secs(2)));
2245
2246        // Send messages before cancellation
2247        tx.send(create_test_message(1, vec![1, 2, 3]))
2248            .await
2249            .unwrap();
2250        tx.send(create_test_message(2, vec![4, 5, 6]))
2251            .await
2252            .unwrap();
2253
2254        // Spawn the processing loop after messages are queued
2255        let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2256
2257        // Give a moment for processing to start
2258        tokio::time::sleep(Duration::from_millis(50)).await;
2259
2260        // Trigger cancellation while messages are in queue
2261        token_clone.cancel();
2262
2263        // Close channel
2264        drop(tx);
2265
2266        // Wait for processing to complete
2267        timeout(Duration::from_secs(3), processing_task)
2268            .await
2269            .expect("timeout waiting for processing loop")
2270            .expect("processing loop panicked");
2271
2272        // Verify messages in queue when cancellation happened were still processed
2273        let processed_messages = count_on_messages(&messages_received).await;
2274        assert_eq!(
2275            processed_messages, 2,
2276            "Messages in queue during cancellation should be processed"
2277        );
2278    }
2279}