Skip to main content

slim_session/
session_controller.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4// Standard library imports
5use std::{collections::HashMap, time::Duration};
6
7use display_error_chain::ErrorChainExt;
8use parking_lot::Mutex;
9use tokio::sync::{self, oneshot};
10// Third-party crates
11use tokio_util::sync::CancellationToken;
12use tracing::{Instrument, debug};
13
14use slim_auth::traits::{TokenProvider, Verifier};
15use slim_datapath::{
16    api::{
17        CommandPayload, Content, NameId, ProtoMessage as Message, ProtoName,
18        ProtoSessionMessageType, ProtoSessionType, SlimHeader,
19    },
20    messages::utils::SlimHeaderFlags,
21};
22
23// Local crate
24use crate::{
25    MessageDirection, SessionError,
26    common::{OutboundMessage, SessionMessage, SessionOutput},
27    completion_handle::CompletionHandle,
28    controller_sender::{ControllerSender, PING_INTERVAL},
29    session_builder::{ForController, SessionBuilder},
30    session_config::SessionConfig,
31    session_settings::SessionSettings,
32    traits::{MessageHandler, ProcessingState},
33};
34
35pub(crate) async fn verify_identity<V>(msg: &Message, verifier: &V) -> Result<(), SessionError>
36where
37    V: Verifier + Send + Sync,
38{
39    let identity = msg.get_slim_header().get_identity();
40    if verifier.try_verify(&identity).is_err() {
41        verifier.verify(&identity).await?;
42    }
43    Ok(())
44}
45
/// Handle to a session whose messages are processed by a spawned task.
///
/// The controller keeps the sending half of the channel that feeds the
/// processing loop, the token used to request its shutdown, and the join
/// handle of the spawned task.
pub struct SessionController {
    /// session id
    pub(crate) id: u32,

    /// local name
    pub(crate) source: ProtoName,

    /// group or remote endpoint name
    pub(crate) destination: ProtoName,

    /// session config
    pub(crate) config: SessionConfig,

    /// channel to send messages to the processing loop
    tx_controller: sync::mpsc::Sender<SessionMessage>,

    /// cancelled by `close()` and by the `Drop` implementation to ask the
    /// processing loop to shut down gracefully
    pub(crate) cancellation_token: CancellationToken,

    /// handle for the processing loop; taken (left as `None`) by the first
    /// successful `close()` call
    handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
}
68
69impl SessionController {
70    /// Returns a new SessionBuilder for constructing a SessionController
71    pub fn builder<P, V>() -> SessionBuilder<P, V, ForController>
72    where
73        P: TokenProvider + Send + Sync + Clone + 'static,
74        V: Verifier + Send + Sync + Clone + 'static,
75    {
76        SessionBuilder::for_controller()
77    }
78
79    /// Internal constructor for the builder to use
80    #[allow(clippy::too_many_arguments)]
81    pub(crate) fn from_parts<I, P, V, M>(
82        id: u32,
83        source: ProtoName,
84        destination: ProtoName,
85        config: SessionConfig,
86        settings: SessionSettings<P, V, M>,
87        tx: sync::mpsc::Sender<SessionMessage>,
88        rx: sync::mpsc::Receiver<SessionMessage>,
89        inner: I,
90    ) -> Self
91    where
92        I: MessageHandler + Send + Sync + 'static,
93        P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static,
94        V: slim_auth::traits::Verifier + Send + Sync + Clone + 'static,
95        M: crate::subscription_manager::SubscriptionOps,
96    {
97        // Spawn the processing loop
98        let cancellation_token = CancellationToken::new();
99
100        // setup tracing context
101        let span = tracing::debug_span!(
102            parent: None,
103            "session_controller_processing_loop",
104            session_id = id,
105            service_id = %settings.service_id,
106            source = %source,
107            destination = %destination,
108            session_type = ?config.session_type
109        );
110
111        let handle = crate::runtime::spawn(
112            Self::processing_loop(inner, rx, cancellation_token.clone(), settings).instrument(span),
113        );
114
115        Self {
116            id,
117            source,
118            destination,
119            config,
120            tx_controller: tx,
121            cancellation_token,
122            handle: Mutex::new(Some(handle)),
123        }
124    }
125
126    /// Internal processing loop that handles messages with mutable access
127    fn enter_draining_state<P, V, M>(
128        shutdown_deadline: &mut std::pin::Pin<&mut tokio::time::Sleep>,
129        settings: &SessionSettings<P, V, M>,
130    ) where
131        P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static,
132        V: slim_auth::traits::Verifier + Send + Sync + Clone + 'static,
133        M: crate::subscription_manager::SubscriptionOps,
134    {
135        let shutdown_timeout = settings
136            .graceful_shutdown_timeout
137            .unwrap_or(Duration::from_secs(60));
138        shutdown_deadline
139            .as_mut()
140            .reset(tokio::time::Instant::now() + shutdown_timeout);
141    }
142
143    /// Apply the identity token to all outbound ToSlim messages.
144    /// Must run before MLS encryption so header-integrity AAD matches the on-wire header.
145    pub(crate) fn apply_identity_to_slim_output<P>(
146        output: &mut SessionOutput,
147        identity_provider: &P,
148    ) -> Result<(), SessionError>
149    where
150        P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static,
151    {
152        let identity = identity_provider.get_token()?;
153        for msg in &mut output.messages {
154            if let OutboundMessage::ToSlim(m) = msg {
155                m.get_slim_header_mut().set_identity(identity.clone());
156            }
157        }
158        Ok(())
159    }
160
161    /// Dispatch outbound messages from SessionOutput to actual channels.
162    /// Sends ToApp messages directly to the application channel.
163    async fn dispatch_output<P, V, M>(output: SessionOutput, settings: &SessionSettings<P, V, M>)
164    where
165        P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static,
166        V: slim_auth::traits::Verifier + Send + Sync + Clone + 'static,
167        M: crate::subscription_manager::SubscriptionOps,
168    {
169        for msg in output.messages {
170            match msg {
171                OutboundMessage::ToSlim(message) => {
172                    if let Err(e) = settings.slim_tx.send(Ok(message)).await {
173                        tracing::error!(error = %e, "failed to send message to SLIM");
174                    }
175                }
176                OutboundMessage::ToApp(result) => {
177                    if let Err(e) = settings.app_tx.send(result) {
178                        tracing::error!(error = %e, "failed to send message to application");
179                    }
180                }
181            }
182        }
183    }
184
    /// Main loop of the session task.
    ///
    /// After initializing `inner`, the loop multiplexes three event sources:
    /// - cancellation of `cancellation_token` (only while `inner` is `Active`):
    ///   pending messages are flushed, `StartDrain` is delivered to `inner`
    ///   and the graceful-shutdown deadline is armed;
    /// - expiry of the graceful-shutdown deadline: the loop exits;
    /// - a message from `rx`: it is verified (inbound only), filtered while
    ///   draining, and handed to `inner`.
    ///
    /// The loop also exits once `inner` is `Draining` and reports that it no
    /// longer needs to drain, or when `rx` is closed. `inner.on_shutdown()` is
    /// always called before returning.
    async fn processing_loop<P, V, M>(
        mut inner: impl MessageHandler + 'static,
        mut rx: sync::mpsc::Receiver<SessionMessage>,
        cancellation_token: CancellationToken,
        settings: SessionSettings<P, V, M>,
    ) where
        P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static,
        V: slim_auth::traits::Verifier + Send + Sync + Clone + 'static,
        M: crate::subscription_manager::SubscriptionOps,
    {
        // Start with an infinite timeout (will be updated on graceful shutdown)
        let mut shutdown_deadline = std::pin::pin!(tokio::time::sleep(Duration::MAX));

        // Init the inner components. A failure is only logged: the loop still starts.
        if let Err(e) = inner.init().await {
            tracing::error!(error = %e.chain(), "error during initialization of session");
        }

        loop {
            tokio::select! {
                // This branch is disabled once the handler has left the Active state,
                // so the drain sequence below is started at most once from here.
                _ = cancellation_token.cancelled(), if inner.processing_state() == ProcessingState::Active => {
                    // Update the timeout to the configured grace period
                    let shutdown_timeout = settings.graceful_shutdown_timeout
                        .unwrap_or(Duration::from_secs(60)); // Default 60 seconds if not configured

                    // Finish any ongoing processing before starting drain
                    debug!("consuming pending messages before entering draining state");
                    while let Ok(msg) = rx.try_recv() {
                        // Inbound (North) messages must pass identity verification;
                        // those that fail are dropped and the flush continues.
                        if let SessionMessage::OnMessage { message, direction: MessageDirection::North, .. } = &msg
                            && let Err(e) = crate::session_controller::verify_identity(message, &settings.identity_verifier).await {
                            debug!(error = %e.chain(), "dropping inbound message during drain: identity verification failed");
                            continue;
                        }
                        match inner.on_message(msg).await {
                            Ok(output) => Self::dispatch_output(output, &settings).await,
                            Err(e) => {
                                tracing::error!(error = %e.chain(), "error processing message during draining - close immediately.");
                                // NOTE(review): this `break` only leaves the `while let`
                                // flush loop; execution continues with StartDrain below.
                                // Confirm whether exiting the outer loop was intended.
                                break;
                            }
                        }
                    }

                    // Send the drain message to the inner handler to notify the beginning of the drain
                    match inner.on_message(SessionMessage::StartDrain {
                        grace_period: shutdown_timeout
                    }).await {
                        Ok(output) => Self::dispatch_output(output, &settings).await,
                        Err(e) => {
                            tracing::error!(error = %e.chain(),  "error during start drain");
                            // Leaves the outer loop: final shutdown runs right away.
                            break;
                        }
                    }

                    // Arm the graceful-shutdown deadline
                    Self::enter_draining_state(&mut shutdown_deadline, &settings);

                    debug!("cancellation requested, entering draining state");
                }
                _ = &mut shutdown_deadline => {
                    debug!("graceful shutdown timeout reached, forcing exit");
                    break;
                }
                msg = rx.recv() => {
                    match msg {
                        Some(session_message) => {
                            // Handle GetParticipantsList query immediately without going through the handler
                            if let SessionMessage::GetParticipantsList { tx } = session_message {
                                let participants_list = inner.participants_list();
                                // The requester may have gone away; a failed reply is ignored.
                                let _ = tx.send(participants_list);
                                continue;
                            }

                            // Inbound (North) messages are dropped when identity verification fails
                            if let SessionMessage::OnMessage { message, direction: MessageDirection::North, .. } = &session_message
                                && let Err(e) = crate::session_controller::verify_identity(message, &settings.identity_verifier).await {
                                debug!(
                                    error = %e.chain(),
                                    msg_type = %message.get_session_message_type().as_str_name(),
                                    msg_id = %message.get_id(),
                                    "dropping inbound message: identity verification failed",
                                );
                                continue;
                            }

                            // Snapshot the state before handling, to detect an Active -> Draining
                            // transition caused by this very message.
                            let draining = inner.processing_state() == ProcessingState::Draining;

                            // if draining and message is sent by the application, reject it
                            if draining && matches!(session_message, SessionMessage::OnMessage { direction: MessageDirection::South, .. }) {
                                tracing::debug!("session is draining, rejecting new messages from application");
                                // Report the rejection to the caller when it asked for an ack
                                if let SessionMessage::OnMessage { ack_tx: Some(ack_tx), .. } = session_message {
                                    let _ = ack_tx.send(Err(SessionError::SessionDrainingDrop));
                                }
                                continue;
                            }

                            match inner.on_message(session_message).await {
                                Ok(output) => {
                                    Self::dispatch_output(output, &settings).await;
                                    // If we were active before processing and the handler switched to draining,
                                    // start (or reset) the graceful shutdown deadline just like on cancellation.
                                    if !draining && inner.processing_state() == ProcessingState::Draining {
                                        debug!("internal component requested draining, entering draining state");
                                        Self::enter_draining_state(&mut shutdown_deadline, &settings);
                                    }
                                }
                                Err(e) => {
                                    debug!(
                                        error=%e,
                                        "Error processing message{}",
                                        if draining { " during graceful shutdown" } else { "" }
                                    );
                                    // Errors are tolerated while active, but end the loop while draining
                                    if draining {
                                        debug!("Exiting processing loop due to error while draining");
                                        break;
                                    }
                                }
                            }
                        }
                        None => {
                            debug!("Session channel closed, no more messages can arrive - exiting processing loop");
                            break;
                        }
                    }
                }
            }

            // If we are in draining state and the inner component does not require drain, exit
            if inner.processing_state() == ProcessingState::Draining && !inner.needs_drain() {
                debug!("draining complete, exiting processing loop");
                break;
            }
        }

        // Perform final shutdown
        if let Err(e) = inner.on_shutdown().await {
            tracing::error!(error = %e.chain(), "error during shutdown of session");
        }
    }
321
    // ---- getters ----

    /// Returns the session id.
    pub fn id(&self) -> u32 {
        self.id
    }
326
    /// Returns the local name of this session.
    pub fn source(&self) -> &ProtoName {
        &self.source
    }
330
    /// Returns the destination of this session (group or remote endpoint name).
    pub fn dst(&self) -> &ProtoName {
        &self.destination
    }
334
    /// Returns the session type stored in the session config.
    pub fn session_type(&self) -> ProtoSessionType {
        self.config.session_type
    }
338
    /// Returns a copy of the metadata map stored in the session config.
    pub fn metadata(&self) -> HashMap<String, String> {
        self.config.metadata.clone()
    }
342
    /// Returns a copy of the session config.
    pub fn session_config(&self) -> SessionConfig {
        self.config.clone()
    }
346
    /// Returns the `initiator` flag of the session config.
    pub fn is_initiator(&self) -> bool {
        self.config.initiator
    }
350
351    pub async fn participants_list(&self) -> Result<Vec<ProtoName>, SessionError> {
352        let (tx, rx) = oneshot::channel();
353
354        // Send query to the processing loop
355        self.tx_controller
356            .send(SessionMessage::GetParticipantsList { tx })
357            .await
358            .map_err(|_| SessionError::ParticipantsListQueryFailed)?;
359
360        // Wait for response
361        rx.await
362            .map_err(|_| SessionError::ParticipantsListQueryFailed)
363    }
364
365    async fn on_message(
366        &self,
367        message: Message,
368        direction: MessageDirection,
369        ack_tx: Option<oneshot::Sender<Result<(), SessionError>>>,
370    ) -> Result<(), SessionError> {
371        self.tx_controller
372            .send(SessionMessage::OnMessage {
373                message,
374                direction,
375                ack_tx,
376            })
377            .await
378            .map_err(|_e| SessionError::SessionControllerSendFailed)
379    }
380
381    /// Send a message to the controller for processing
382    pub async fn on_message_from_app(
383        &self,
384        message: Message,
385    ) -> Result<CompletionHandle, SessionError> {
386        let (ack_tx, ack_rx) = oneshot::channel();
387        self.on_message(message, MessageDirection::South, Some(ack_tx))
388            .await?;
389
390        let ret = CompletionHandle::from_oneshot_receiver(ack_rx);
391
392        Ok(ret)
393    }
394
395    /// Send a message to the controller for processing
396    pub async fn on_message_from_slim(&self, message: Message) -> Result<(), SessionError> {
397        self.on_message(message, MessageDirection::North, None)
398            .await
399    }
400
401    /// Send an error message to the controller for processing
402    pub async fn on_error_message_from_slim(
403        &self,
404        error: SessionError,
405    ) -> Result<(), SessionError> {
406        self.tx_controller
407            .send(SessionMessage::MessageError { error })
408            .await
409            .map_err(|_e| SessionError::SessionControllerSendFailed)
410    }
411
412    pub fn close(&self) -> Result<tokio::task::JoinHandle<()>, SessionError> {
413        self.cancellation_token.cancel();
414
415        self.handle
416            .lock()
417            .take()
418            .ok_or(SessionError::SessionAlreadyClosed)
419    }
420
    /// Publishes an already-built message on behalf of the application.
    ///
    /// Delegates to `on_message_from_app` and returns its result unchanged.
    pub async fn publish_message(
        &self,
        message: Message,
    ) -> Result<CompletionHandle, SessionError> {
        self.on_message_from_app(message).await
    }
427
428    /// Publish a message to a specific connection (forward_to)
429    pub async fn publish_to(
430        &self,
431        name: &ProtoName,
432        forward_to: u64,
433        blob: Vec<u8>,
434        payload_type: Option<String>,
435        metadata: Option<HashMap<String, String>>,
436    ) -> Result<CompletionHandle, SessionError> {
437        self.publish_with_flags(
438            name,
439            SlimHeaderFlags::default().with_forward_to(forward_to),
440            blob,
441            payload_type,
442            metadata,
443        )
444        .await
445    }
446
447    /// Publish a message to a specific app name
448    pub async fn publish(
449        &self,
450        name: &ProtoName,
451        blob: Vec<u8>,
452        payload_type: Option<String>,
453        metadata: Option<HashMap<String, String>>,
454    ) -> Result<CompletionHandle, SessionError> {
455        self.publish_with_flags(
456            name,
457            SlimHeaderFlags::default(),
458            blob,
459            payload_type,
460            metadata,
461        )
462        .await
463    }
464
465    /// Publish a message with specific flags
466    pub async fn publish_with_flags(
467        &self,
468        name: &ProtoName,
469        flags: SlimHeaderFlags,
470        blob: Vec<u8>,
471        payload_type: Option<String>,
472        metadata: Option<HashMap<String, String>>,
473    ) -> Result<CompletionHandle, SessionError> {
474        let ct = payload_type.unwrap_or_else(|| "msg".to_string());
475
476        let mut msg = Message::builder()
477            .source(self.source().clone())
478            .destination(name.clone())
479            .identity("")
480            .flags(flags)
481            .session_type(self.session_type())
482            .session_message_type(ProtoSessionMessageType::Msg)
483            .session_id(self.id())
484            .message_id(rand::random::<u32>()) // this will be changed by the session itself
485            .application_payload(&ct, blob)
486            .build_publish()?;
487        if let Some(map) = metadata
488            && !map.is_empty()
489        {
490            msg.set_metadata_map(map);
491        }
492
493        // southbound=true means towards slim
494        self.publish_message(msg).await
495    }
496
497    /// Creates a discovery request message with minimum required information
498    fn create_discovery_request(&self, destination: &ProtoName) -> Result<Message, SessionError> {
499        let payload = CommandPayload::builder().discovery_request().as_content();
500
501        let msg = Message::builder()
502            .source(self.source().clone())
503            .destination(destination.clone())
504            .identity("")
505            .session_type(self.session_type())
506            .session_message_type(ProtoSessionMessageType::DiscoveryRequest)
507            .session_id(self.id())
508            .message_id(rand::random::<u32>())
509            .payload(payload)
510            .build_publish()?;
511
512        Ok(msg)
513    }
514
515    pub(crate) async fn invite_participant_internal(
516        &self,
517        destination: &ProtoName,
518    ) -> Result<CompletionHandle, SessionError> {
519        let msg = self.create_discovery_request(destination)?;
520        self.publish_message(msg).await
521    }
522
523    pub async fn invite_participant(
524        &self,
525        destination: &ProtoName,
526    ) -> Result<CompletionHandle, SessionError> {
527        match self.session_type() {
528            ProtoSessionType::PointToPoint => Err(SessionError::CannotInviteToP2P),
529            ProtoSessionType::Multicast => {
530                if !self.is_initiator() {
531                    return Err(SessionError::NotInitiator);
532                }
533                self.invite_participant_internal(destination).await
534            }
535            _ => Err(SessionError::SessionTypeUnknown(self.session_type())),
536        }
537    }
538
539    pub async fn remove_participant(
540        &self,
541        destination: &ProtoName,
542    ) -> Result<CompletionHandle, SessionError> {
543        match self.session_type() {
544            ProtoSessionType::PointToPoint => Err(SessionError::CannotRemoveFromP2P),
545            ProtoSessionType::Multicast => {
546                if !self.is_initiator() {
547                    return Err(SessionError::NotInitiator);
548                }
549                let msg = Message::builder()
550                    .source(self.source().clone())
551                    .destination(destination.clone().with_id(NameId::NULL_COMPONENT))
552                    .identity("")
553                    .session_type(ProtoSessionType::Multicast)
554                    .session_message_type(ProtoSessionMessageType::LeaveRequest)
555                    .session_id(self.id())
556                    .message_id(rand::random::<u32>())
557                    .payload(CommandPayload::builder().leave_request().as_content())
558                    .build_publish()?;
559                self.publish_message(msg).await
560            }
561            _ => Err(SessionError::SessionTypeUnknown(self.session_type())),
562        }
563    }
564}
565
/// Dropping the controller cancels its token, which asks the processing loop
/// to start its graceful shutdown. The task itself is not awaited here.
impl Drop for SessionController {
    fn drop(&mut self) {
        self.cancellation_token.cancel();
    }
}
571
572pub fn handle_channel_discovery_message(
573    message: &Message,
574    app_name: &ProtoName,
575    session_id: u32,
576    session_type: ProtoSessionType,
577) -> Result<Message, SessionError> {
578    let destination = message.get_slim_header().source.clone().unwrap();
579
580    // the destination of the discovery message may be different from the name of
581    // application itself. This can happen if the application subscribes to multiple
582    // service names. So we can reply using as a source the destination name of
583    // the discovery message but setting the application id
584    let mut source = message.get_slim_header().destination.clone().unwrap();
585    source.set_id(app_name.id());
586    let msg_id = message.get_id();
587
588    let slim_header = SlimHeader::new(
589        source,
590        destination,
591        "",
592        Some(SlimHeaderFlags::default().with_forward_to(message.get_incoming_conn())),
593    );
594
595    let msg = Message::builder()
596        .with_slim_header(slim_header)
597        .session_type(session_type)
598        .session_message_type(ProtoSessionMessageType::DiscoveryReply)
599        .session_id(session_id)
600        .message_id(msg_id)
601        .payload(CommandPayload::builder().discovery_reply().as_content())
602        .build_publish()?;
603
604    Ok(msg)
605}
606
/// State shared by the session controller implementations: the session
/// settings, the sender used for control messages, the current processing
/// state, and the ids of the routes/subscriptions registered so far.
///
/// `M` is the subscription manager implementation and defaults to
/// `SubscriptionManager`.
pub(crate) struct SessionControllerCommon<
    P,
    V,
    M = crate::subscription_manager::SubscriptionManager,
> where
    P: TokenProvider + Send + Sync + Clone + 'static,
    V: Verifier + Send + Sync + Clone + 'static,
    M: crate::subscription_manager::SubscriptionOps,
{
    /// common session fields
    pub(crate) settings: SessionSettings<P, V, M>,

    /// sender for command messages
    pub(crate) sender: ControllerSender,

    /// processing state (starts as `Active`)
    pub(crate) processing_state: ProcessingState,

    /// Maps (kind, name, conn) → subscription_id for route/subscription tracking.
    /// Entries are inserted by `add_route` / `add_subscription` and removed by
    /// `delete_route` / `delete_subscription`.
    subscription_ids: HashMap<(SubscriptionKind, ProtoName, u64), u64>,
}
628
/// Distinguishes route entries from subscription entries in the subscription_ids map.
/// Both can share the same `(Name, conn)` pair, so this enum is part of the map
/// key to prevent key collisions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum SubscriptionKind {
    /// A recv_from route (`set_route` / `remove_route`).
    Route,
    /// A forward_to subscription (`subscribe` / `unsubscribe`).
    Subscription,
}
638
639impl<P, V, M> SessionControllerCommon<P, V, M>
640where
641    P: TokenProvider + Send + Sync + Clone + 'static,
642    V: Verifier + Send + Sync + Clone + 'static,
643    M: crate::subscription_manager::SubscriptionOps,
644{
645    pub(crate) fn new(settings: SessionSettings<P, V, M>) -> Self {
646        // Create the controller sender.
647        let controller_sender = ControllerSender::new(
648            settings.config.get_timer_settings(),
649            settings.source.clone(),
650            settings.config.session_type,
651            settings.id,
652            Some(PING_INTERVAL),
653            settings.config.initiator,
654            settings.tx_session.clone(),
655        );
656
657        SessionControllerCommon {
658            settings,
659            sender: controller_sender,
660            processing_state: ProcessingState::Active,
661            subscription_ids: HashMap::new(),
662        }
663    }
664
    /// Send control message through ControllerSender, returning the output.
    ///
    /// This is a direct delegation to `ControllerSender::on_message`; its
    /// result is returned unchanged.
    pub(crate) fn send_with_timer(
        &mut self,
        message: Message,
    ) -> Result<SessionOutput, SessionError> {
        self.sender.on_message(&message)
    }
672
673    async fn await_subscription_ack(
674        rx: tokio::sync::oneshot::Receiver<
675            Result<(), crate::subscription_manager::SubscriptionAckError>,
676        >,
677    ) -> Result<(), SessionError> {
678        crate::subscription_manager::SubscriptionManager::await_ack(rx)
679            .await
680            .map_err(SessionError::SubscriptionAckFailed)
681    }
682
683    pub(crate) async fn add_route(
684        &mut self,
685        name: ProtoName,
686        conn: u64,
687    ) -> Result<(), SessionError> {
688        if name == self.settings.source.clone() {
689            // We never add a route for ourselves
690            return Ok(());
691        }
692
693        let source_proto = self.settings.source.clone();
694        let (subscription_id, rx) = self
695            .settings
696            .subscription_manager
697            .set_route(&source_proto, &name, conn)
698            .await
699            .map_err(SessionError::SubscriptionAckFailed)?;
700        Self::await_subscription_ack(rx).await?;
701
702        debug!(%name, %conn, %subscription_id, source = %self.settings.source, "route added");
703
704        self.subscription_ids
705            .insert((SubscriptionKind::Route, name, conn), subscription_id);
706
707        Ok(())
708    }
709
710    pub(crate) async fn delete_route(
711        &mut self,
712        name: ProtoName,
713        conn: u64,
714    ) -> Result<(), SessionError> {
715        if name == self.settings.source.clone() {
716            // We never remove a route for ourselves
717            return Ok(());
718        }
719
720        let key = (SubscriptionKind::Route, name, conn);
721        let subscription_id = self.subscription_ids.remove(&key);
722        let (_, name, conn) = key;
723        match subscription_id {
724            Some(subscription_id) => {
725                let source_proto = self.settings.source.clone();
726                let rx = self
727                    .settings
728                    .subscription_manager
729                    .remove_route(&source_proto, &name, subscription_id, conn)
730                    .await
731                    .map_err(SessionError::SubscriptionAckFailed)?;
732
733                Self::await_subscription_ack(rx).await?;
734                tracing::debug!(%name, %conn, %subscription_id, "route deleted");
735            }
736            None => {
737                tracing::warn!(
738                    %name, %conn, io = %self.settings.source,
739                    "no subscription_id found for route, skipping delete"
740                );
741            }
742        }
743
744        Ok(())
745    }
746
747    pub(crate) async fn add_subscription(
748        &mut self,
749        name: ProtoName,
750        conn: u64,
751    ) -> Result<(), SessionError> {
752        let source_proto = self.settings.source.clone();
753        let (subscription_id, rx) = self
754            .settings
755            .subscription_manager
756            .subscribe(&source_proto, &name, Some(conn))
757            .await
758            .map_err(SessionError::SubscriptionAckFailed)?;
759
760        Self::await_subscription_ack(rx).await?;
761
762        debug!(%name, %conn, %subscription_id, "subscription added");
763
764        self.subscription_ids.insert(
765            (SubscriptionKind::Subscription, name, conn),
766            subscription_id,
767        );
768
769        Ok(())
770    }
771
772    pub(crate) async fn delete_subscription(
773        &mut self,
774        name: ProtoName,
775        conn: u64,
776    ) -> Result<(), SessionError> {
777        let key = (SubscriptionKind::Subscription, name, conn);
778        let subscription_id = self.subscription_ids.remove(&key);
779        let (_, name, conn) = key;
780        match subscription_id {
781            Some(subscription_id) => {
782                let source_proto = self.settings.source.clone();
783                let rx = self
784                    .settings
785                    .subscription_manager
786                    .unsubscribe(&source_proto, &name, subscription_id, Some(conn))
787                    .await
788                    .map_err(SessionError::SubscriptionAckFailed)?;
789
790                Self::await_subscription_ack(rx).await?;
791                debug!(%name, %conn, %subscription_id, "subscription deleted");
792            }
793            None => {
794                tracing::debug!(
795                    %name, %conn,
796                    "no subscription_id found for subscription, skipping delete"
797                );
798            }
799        }
800
801        Ok(())
802    }
803
804    pub(crate) fn create_control_message(
805        &mut self,
806        dst: &ProtoName,
807        message_type: ProtoSessionMessageType,
808        message_id: u32,
809        payload: Content,
810        broadcast: bool,
811    ) -> Result<Message, SessionError> {
812        let mut builder = Message::builder()
813            .source(self.settings.source.clone())
814            .destination(dst.clone())
815            .identity("")
816            .session_type(self.settings.config.session_type)
817            .session_message_type(message_type)
818            .session_id(self.settings.id)
819            .message_id(message_id)
820            .payload(payload);
821
822        if broadcast {
823            builder = builder.fanout(256);
824        }
825
826        let ret = builder.build_publish()?;
827
828        Ok(ret)
829    }
830
831    /// Send control message without creating ack channel (for internal use by moderator)
832    pub(crate) fn send_control_message(
833        &mut self,
834        dst: &ProtoName,
835        message_type: ProtoSessionMessageType,
836        message_id: u32,
837        payload: Content,
838        metadata: Option<HashMap<String, String>>,
839        broadcast: bool,
840    ) -> Result<SessionOutput, SessionError> {
841        let mut msg =
842            self.create_control_message(dst, message_type, message_id, payload, broadcast)?;
843        if let Some(m) = metadata {
844            msg.set_metadata_map(m);
845        }
846        self.send_with_timer(msg)
847    }
848}
849
850#[cfg(test)]
851mod tests {
852    use super::*;
853
    // NOTE: `test_internal_draining_via_leave_request` was removed because it was broken
    // (incompatible mock trait implementation).
    // It covered the internal draining transition triggered by a leave request: it sent a
    // LeaveRequest into a multicast participant session and then verified (indirectly) that
    // subsequent messages were still accepted while the session was transitioning,
    // indicating that graceful draining had begun.
859
860    use crate::Direction;
861    use crate::session_config::MlsSettings;
862    use crate::subscription_manager::{SpySubscriptionManager, SubscriptionCall};
863    use slim_auth::shared_secret::SharedSecret;
864
865    use std::sync::Arc;
866    use std::sync::atomic::AtomicBool;
867    use std::time::Duration;
868    use tokio::time::timeout;
869    use tracing_test::traced_test;
870
    /// Secret string passed to `SharedSecret::new` by the tests in this module.
    const SHARED_SECRET: &str = "kjandjansdiasb8udaijdniasdaindasndasndasndasndasndasndasndas";
872
873    fn test_identity() -> String {
874        SharedSecret::new("test", SHARED_SECRET)
875            .unwrap()
876            .get_token()
877            .unwrap()
878    }
879
    /// Test helper to create a SessionController with common setup
    ///
    /// Values are collected through the `with_*` setters and consumed by `build`.
    struct SessionControllerTestBuilder {
        /// Session id handed to the controller builder via `with_id`.
        session_id: u32,
        /// Local name handed to the controller builder via `with_source`.
        source: ProtoName,
        /// Remote name handed to the controller builder via `with_destination`.
        destination: ProtoName,
        /// Copied into `SessionConfig::session_type`.
        session_type: ProtoSessionType,
        /// Copied into `SessionConfig::mls_settings`; `None` leaves it unset.
        mls_settings: Option<MlsSettings>,
        /// Copied into `SessionConfig::initiator`.
        initiator: bool,
        /// Copied into `SessionConfig::max_retries`.
        max_retries: Option<u32>,
        /// Copied into `SessionConfig::interval`.
        interval: Option<Duration>,
        /// Copied into `SessionConfig::metadata`.
        metadata: HashMap<String, String>,
        /// Set by `new` and `with_graceful_shutdown_timeout`.
        /// NOTE(review): `build` does not appear to read this field — confirm whether it
        /// should be forwarded to the controller.
        graceful_shutdown_timeout: Option<Duration>,
    }
893
894    impl SessionControllerTestBuilder {
895        #[allow(dead_code)]
896        fn new() -> Self {
897            Self {
898                session_id: 10,
899                source: ProtoName::from_strings(["org", "ns", "source"]).with_id(1),
900                destination: ProtoName::from_strings(["org", "ns", "dest"]).with_id(2),
901                session_type: ProtoSessionType::PointToPoint,
902                mls_settings: None,
903                initiator: true,
904                max_retries: Some(5),
905                interval: Some(Duration::from_millis(200)),
906                metadata: HashMap::new(),
907                graceful_shutdown_timeout: Some(Duration::from_secs(10)),
908            }
909        }
910
911        fn with_session_id(mut self, id: u32) -> Self {
912            self.session_id = id;
913            self
914        }
915
916        #[allow(dead_code)]
917        fn with_source(mut self, source: ProtoName) -> Self {
918            self.source = source;
919            self
920        }
921
922        #[allow(dead_code)]
923        fn with_destination(mut self, destination: ProtoName) -> Self {
924            self.destination = destination;
925            self
926        }
927
928        fn with_session_type(mut self, session_type: ProtoSessionType) -> Self {
929            self.session_type = session_type;
930            self
931        }
932
933        fn with_mls_enabled(mut self, enabled: bool) -> Self {
934            self.mls_settings = if enabled {
935                Some(MlsSettings::default())
936            } else {
937                None
938            };
939            self
940        }
941
942        fn with_initiator(mut self, initiator: bool) -> Self {
943            self.initiator = initiator;
944            self
945        }
946
947        fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
948            self.metadata = metadata;
949            self
950        }
951
952        fn with_graceful_shutdown_timeout(mut self, timeout: Duration) -> Self {
953            self.graceful_shutdown_timeout = Some(timeout);
954            self
955        }
956
957        fn build(
958            self,
959        ) -> (
960            SessionController,
961            tokio::sync::mpsc::Receiver<Result<Message, slim_datapath::Status>>,
962            tokio::sync::mpsc::UnboundedReceiver<Result<Message, SessionError>>,
963        ) {
964            let config = SessionConfig {
965                session_type: self.session_type,
966                max_retries: self.max_retries,
967                interval: self.interval,
968                mls_settings: self.mls_settings,
969                initiator: self.initiator,
970                metadata: self.metadata,
971            };
972
973            let (tx_slim, rx_slim) = tokio::sync::mpsc::channel(10);
974            let (tx_app, rx_app) = tokio::sync::mpsc::unbounded_channel();
975            let (tx_session_layer, _rx_session_layer) = tokio::sync::mpsc::channel(10);
976
977            let controller = SessionController::builder()
978                .with_id(self.session_id)
979                .with_source(self.source.clone())
980                .with_destination(self.destination.clone())
981                .with_config(config)
982                .with_identity_provider(SharedSecret::new("test", SHARED_SECRET).unwrap())
983                .with_identity_verifier(SharedSecret::new("test", SHARED_SECRET).unwrap())
984                .with_slim_tx(tx_slim)
985                .with_app_tx(tx_app)
986                .with_tx_to_session_layer(tx_session_layer)
987                .ready()
988                .expect("failed to validate builder")
989                .build()
990                .expect("failed to build controller");
991
992            (controller, rx_slim, rx_app)
993        }
994    }
995
996    #[tokio::test]
997    async fn test_session_controller_getters() {
998        let mut metadata = HashMap::new();
999        metadata.insert("key1".to_string(), "value1".to_string());
1000
1001        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1002            .with_session_id(42)
1003            .with_session_type(ProtoSessionType::Multicast)
1004            .with_mls_enabled(true)
1005            .with_metadata(metadata)
1006            .build();
1007
1008        assert_eq!(controller.id(), 42);
1009        assert_eq!(
1010            controller.source(),
1011            &ProtoName::from_strings(["org", "ns", "source"]).with_id(1)
1012        );
1013        // For multicast sessions, destination uses DATA_CHANNEL_ID
1014        assert_eq!(
1015            controller.dst(),
1016            &ProtoName::from_strings(["org", "ns", "dest"]).with_id(NameId::DATA_CHANNEL_ID)
1017        );
1018        assert_eq!(controller.session_type(), ProtoSessionType::Multicast);
1019        assert!(controller.is_initiator());
1020        assert_eq!(
1021            controller.metadata().get("key1"),
1022            Some(&"value1".to_string())
1023        );
1024
1025        let retrieved_config = controller.session_config();
1026        assert_eq!(retrieved_config.session_type, ProtoSessionType::Multicast);
1027        assert_eq!(retrieved_config.max_retries, Some(5));
1028    }
1029
1030    #[tokio::test]
1031    async fn test_publish_basic() {
1032        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
1033
1034        let target_name = ProtoName::from_strings(["org", "ns", "target"]);
1035        let payload = b"Hello World".to_vec();
1036
1037        controller
1038            .publish(
1039                &target_name,
1040                payload.clone(),
1041                Some("test-type".to_string()),
1042                None,
1043            )
1044            .await
1045            .expect("publish should succeed");
1046
1047        tokio::time::sleep(Duration::from_millis(50)).await;
1048    }
1049
1050    #[tokio::test]
1051    async fn test_publish_to_specific_connection() {
1052        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1053            .with_session_type(ProtoSessionType::Multicast)
1054            .build();
1055
1056        let target_name = ProtoName::from_strings(["org", "ns", "target"]);
1057        let payload = b"Hello to specific connection".to_vec();
1058        let connection_id = 123u64;
1059
1060        controller
1061            .publish_to(
1062                &target_name,
1063                connection_id,
1064                payload.clone(),
1065                Some("test-type".to_string()),
1066                None,
1067            )
1068            .await
1069            .expect("publish_to should succeed");
1070
1071        tokio::time::sleep(Duration::from_millis(50)).await;
1072    }
1073
1074    #[tokio::test]
1075    async fn test_publish_with_metadata() {
1076        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1077            .with_session_type(ProtoSessionType::Multicast)
1078            .build();
1079
1080        let target_name = ProtoName::from_strings(["org", "ns", "target"]);
1081        let payload = b"Hello with metadata".to_vec();
1082
1083        let mut metadata = HashMap::new();
1084        metadata.insert("custom_key".to_string(), "custom_value".to_string());
1085
1086        controller
1087            .publish(
1088                &target_name,
1089                payload.clone(),
1090                Some("test-type".to_string()),
1091                Some(metadata),
1092            )
1093            .await
1094            .expect("publish with metadata should succeed");
1095
1096        tokio::time::sleep(Duration::from_millis(50)).await;
1097    }
1098
1099    #[tokio::test]
1100    async fn test_invite_participant_in_multicast() {
1101        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1102            .with_session_type(ProtoSessionType::Multicast)
1103            .build();
1104
1105        let participant = ProtoName::from_strings(["org", "ns", "participant"]);
1106
1107        controller
1108            .invite_participant(&participant)
1109            .await
1110            .expect("invite should succeed");
1111
1112        tokio::time::sleep(Duration::from_millis(50)).await;
1113    }
1114
1115    #[tokio::test]
1116    async fn test_invite_participant_not_initiator_error() {
1117        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1118            .with_session_type(ProtoSessionType::Multicast)
1119            .with_initiator(false)
1120            .build();
1121
1122        let participant = ProtoName::from_strings(["org", "ns", "new_participant"]);
1123
1124        let result = controller.invite_participant(&participant).await;
1125        assert!(result.is_err_and(|e| matches!(e, SessionError::NotInitiator)));
1126    }
1127
1128    #[tokio::test]
1129    async fn test_invite_participant_p2p_error() {
1130        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1131            .with_session_type(ProtoSessionType::PointToPoint)
1132            .build();
1133
1134        let participant = ProtoName::from_strings(["org", "ns", "participant"]);
1135
1136        let result = controller.invite_participant(&participant).await;
1137        assert!(result.is_err_and(|e| matches!(e, SessionError::CannotInviteToP2P)));
1138    }
1139
1140    #[tokio::test]
1141    async fn test_remove_participant_in_multicast() {
1142        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1143            .with_session_type(ProtoSessionType::Multicast)
1144            .build();
1145
1146        let participant = ProtoName::from_strings(["org", "ns", "participant"]);
1147
1148        controller
1149            .remove_participant(&participant)
1150            .await
1151            .expect("remove should succeed");
1152
1153        tokio::time::sleep(Duration::from_millis(50)).await;
1154    }
1155
1156    #[tokio::test]
1157    async fn test_remove_participant_not_initiator_error() {
1158        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1159            .with_session_type(ProtoSessionType::Multicast)
1160            .with_initiator(false)
1161            .build();
1162
1163        let participant = ProtoName::from_strings(["org", "ns", "participant"]);
1164
1165        let result = controller.remove_participant(&participant).await;
1166        assert!(result.is_err_and(|e| matches!(e, SessionError::NotInitiator)));
1167    }
1168
1169    #[tokio::test]
1170    async fn test_remove_participant_p2p_error() {
1171        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1172            .with_session_type(ProtoSessionType::PointToPoint)
1173            .build();
1174
1175        let participant = ProtoName::from_strings(["org", "ns", "participant"]);
1176
1177        let result = controller.remove_participant(&participant).await;
1178        assert!(result.is_err_and(|e| matches!(e, SessionError::CannotRemoveFromP2P)));
1179    }
1180
1181    #[test]
1182    fn test_handle_channel_discovery_message() {
1183        let app_name = ProtoName::from_strings(["org", "ns", "app"]).with_id(100);
1184        let session_id = 42;
1185
1186        let discovery_request = Message::builder()
1187            .source(ProtoName::from_strings(["org", "ns", "requester"]).with_id(1))
1188            .destination(ProtoName::from_strings(["org", "ns", "service"]))
1189            .identity(test_identity())
1190            .incoming_conn(999)
1191            .session_type(ProtoSessionType::Multicast)
1192            .session_message_type(ProtoSessionMessageType::DiscoveryRequest)
1193            .session_id(session_id)
1194            .message_id(123)
1195            .payload(CommandPayload::builder().discovery_request().as_content())
1196            .build_publish()
1197            .unwrap();
1198
1199        let response = handle_channel_discovery_message(
1200            &discovery_request,
1201            &app_name,
1202            session_id,
1203            ProtoSessionType::Multicast,
1204        )
1205        .expect("should create discovery response");
1206
1207        assert_eq!(
1208            response.get_session_message_type(),
1209            ProtoSessionMessageType::DiscoveryReply
1210        );
1211        assert_eq!(response.get_session_header().get_session_id(), session_id);
1212        assert_eq!(response.get_id(), 123);
1213        assert_eq!(
1214            response.get_dst(),
1215            ProtoName::from_strings(["org", "ns", "requester"]).with_id(1)
1216        );
1217        assert_eq!(response.get_slim_header().get_forward_to(), Some(999));
1218    }
1219
1220    #[tokio::test]
1221    async fn test_controller_drop_cancels_processing() {
1222        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
1223
1224        let token = controller.cancellation_token.clone();
1225        assert!(!token.is_cancelled());
1226
1227        drop(controller);
1228
1229        tokio::time::sleep(Duration::from_millis(100)).await;
1230        assert!(token.is_cancelled());
1231    }
1232
1233    #[tokio::test]
1234    async fn test_close_success() {
1235        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1236            .with_graceful_shutdown_timeout(std::time::Duration::from_secs(2))
1237            .build();
1238
1239        let token = controller.cancellation_token.clone();
1240        assert!(!token.is_cancelled());
1241
1242        let handle = controller.close();
1243        assert!(handle.is_ok(), "got error {}", handle.unwrap_err());
1244        assert!(token.is_cancelled());
1245
1246        // Wait for the handle to complete
1247        handle
1248            .unwrap()
1249            .await
1250            .expect("processing task should complete");
1251    }
1252
1253    #[tokio::test]
1254    async fn test_close_already_closed() {
1255        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
1256
1257        // Close once - should succeed
1258        let handle = controller.close();
1259        assert!(handle.is_ok());
1260        handle
1261            .unwrap()
1262            .await
1263            .expect("processing task should complete");
1264
1265        // Close again - should fail with appropriate error
1266        let result = controller.close();
1267        assert!(result.is_err());
1268        match result {
1269            Err(SessionError::SessionAlreadyClosed) => {
1270                // expected
1271            }
1272            _ => panic!("Expected SessionError::SessionAlreadyClosed"),
1273        }
1274    }
1275
1276    #[tokio::test]
1277    async fn test_close_cancels_token_immediately() {
1278        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
1279
1280        let token = controller.cancellation_token.clone();
1281
1282        // Verify token is not cancelled before close
1283        assert!(!token.is_cancelled());
1284
1285        // Close returns immediately after cancelling token
1286        let handle = controller.close();
1287        assert!(handle.is_ok());
1288
1289        // Token should be cancelled immediately
1290        assert!(token.is_cancelled());
1291
1292        // Wait for processing to complete
1293        handle.unwrap().await.expect("processing should complete");
1294    }
1295
1296    #[tokio::test]
1297    async fn test_on_message_direction_north() {
1298        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
1299
1300        let test_message = Message::builder()
1301            .source(controller.dst().clone())
1302            .destination(controller.source().clone())
1303            .identity(test_identity())
1304            .session_type(ProtoSessionType::PointToPoint)
1305            .session_message_type(ProtoSessionMessageType::Msg)
1306            .session_id(controller.id())
1307            .message_id(1)
1308            .application_payload("test", b"test data".to_vec())
1309            .build_publish()
1310            .unwrap();
1311
1312        let result = controller.on_message_from_slim(test_message).await;
1313        assert!(result.is_ok());
1314    }
1315
1316    #[tokio::test]
1317    async fn test_create_discovery_request() {
1318        let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
1319            .with_session_type(ProtoSessionType::Multicast)
1320            .build();
1321
1322        let target = ProtoName::from_strings(["org", "ns", "target"]);
1323        let discovery_msg = controller
1324            .create_discovery_request(&target)
1325            .expect("should create discovery request");
1326
1327        assert_eq!(discovery_msg.get_source(), *controller.source());
1328        assert_eq!(discovery_msg.get_dst(), target);
1329        assert_eq!(
1330            discovery_msg.get_session_message_type(),
1331            ProtoSessionMessageType::DiscoveryRequest
1332        );
1333        assert_eq!(
1334            discovery_msg.get_session_header().get_session_id(),
1335            controller.id()
1336        );
1337        assert_eq!(
1338            discovery_msg.get_session_type(),
1339            ProtoSessionType::Multicast
1340        );
1341    }
1342
1343    #[tokio::test]
1344    #[traced_test]
1345    async fn test_end_to_end_p2p() {
1346        let session_id = 10;
1347        let moderator_name = ProtoName::from_strings(["org", "ns", "moderator"]).with_id(1);
1348        let participant_name = ProtoName::from_strings(["org", "ns", "participant"]);
1349        let participant_name_id = ProtoName::from_strings(["org", "ns", "participant"]).with_id(1);
1350        // create a SessionModerator
1351        let (tx_slim_moderator, mut rx_slim_moderator) = tokio::sync::mpsc::channel(10);
1352        let (tx_app_moderator, _rx_app_moderator) = tokio::sync::mpsc::unbounded_channel();
1353        let (tx_session_layer_moderator, _rx_session_layer_moderator) =
1354            tokio::sync::mpsc::channel(10);
1355
1356        let moderator_config = SessionConfig {
1357            session_type: slim_datapath::api::ProtoSessionType::PointToPoint,
1358            max_retries: Some(5),
1359            interval: Some(Duration::from_millis(1000)),
1360            mls_settings: Some(MlsSettings::default()),
1361            initiator: true,
1362            metadata: std::collections::HashMap::new(),
1363        };
1364
1365        let (spy_moderator_mgr, mut rx_spy_moderator) = SpySubscriptionManager::new();
1366        let moderator = SessionController::builder()
1367            .with_id(session_id)
1368            .with_source(moderator_name.clone())
1369            .with_destination(participant_name.clone())
1370            .with_config(moderator_config)
1371            .with_identity_provider(SharedSecret::new("moderator", SHARED_SECRET).unwrap())
1372            .with_identity_verifier(SharedSecret::new("moderator", SHARED_SECRET).unwrap())
1373            .with_slim_tx(tx_slim_moderator.clone())
1374            .with_app_tx(tx_app_moderator.clone())
1375            .with_tx_to_session_layer(tx_session_layer_moderator)
1376            .with_subscription_manager(spy_moderator_mgr)
1377            .ready()
1378            .expect("failed to validate builder")
1379            .build()
1380            .unwrap();
1381
1382        // create a SessionParticipant
1383        let (tx_slim_participant, mut rx_slim_participant) = tokio::sync::mpsc::channel(10);
1384        let (tx_app_participant, mut rx_app_participant) = tokio::sync::mpsc::unbounded_channel();
1385        let (tx_session_layer_participant, _rx_session_layer_participant) =
1386            tokio::sync::mpsc::channel(10);
1387
1388        let participant_config = SessionConfig {
1389            session_type: slim_datapath::api::ProtoSessionType::PointToPoint,
1390            max_retries: Some(5),
1391            interval: Some(Duration::from_millis(200)),
1392            mls_settings: Some(MlsSettings::default()),
1393            initiator: false,
1394            metadata: std::collections::HashMap::new(),
1395        };
1396
1397        let (spy_participant_mgr, mut rx_spy_participant) = SpySubscriptionManager::new();
1398        let participant = SessionController::builder()
1399            .with_id(session_id)
1400            .with_source(participant_name_id.clone())
1401            .with_destination(moderator_name.clone())
1402            .with_config(participant_config)
1403            .with_identity_provider(SharedSecret::new("participant", SHARED_SECRET).unwrap())
1404            .with_identity_verifier(SharedSecret::new("participant", SHARED_SECRET).unwrap())
1405            .with_slim_tx(tx_slim_participant.clone())
1406            .with_app_tx(tx_app_participant.clone())
1407            .with_tx_to_session_layer(tx_session_layer_participant)
1408            .with_subscription_manager(spy_participant_mgr)
1409            .ready()
1410            .expect("failed to validate builder")
1411            .build()
1412            .unwrap();
1413
1414        let completion_handle = moderator
1415            .invite_participant_internal(&participant_name)
1416            .await
1417            .expect("error inviting participant");
1418
1419        let received_discovery_request =
1420            timeout(Duration::from_millis(100), rx_slim_moderator.recv())
1421                .await
1422                .expect("timeout waiting for discovery request on moderator slim channel")
1423                .expect("channel closed")
1424                .expect("error in discovery request");
1425
1426        assert_eq!(
1427            received_discovery_request.get_session_message_type(),
1428            slim_datapath::api::ProtoSessionMessageType::DiscoveryRequest
1429        );
1430
1431        let discovery_msg_id = received_discovery_request.get_id();
1432
1433        // create a discovery reply and call the on message on the moderator with the reply (direction north)
1434        let mut discovery_reply = Message::builder()
1435            .source(participant_name_id.clone())
1436            .destination(moderator_name.clone())
1437            .identity(test_identity())
1438            .forward_to(1)
1439            .session_type(slim_datapath::api::ProtoSessionType::PointToPoint)
1440            .session_message_type(slim_datapath::api::ProtoSessionMessageType::DiscoveryReply)
1441            .session_id(session_id)
1442            .message_id(discovery_msg_id)
1443            .payload(CommandPayload::builder().discovery_reply().as_content())
1444            .build_publish()
1445            .unwrap();
1446        discovery_reply
1447            .get_slim_header_mut()
1448            .set_incoming_conn(Some(1));
1449
1450        moderator
1451            .on_message_from_slim(discovery_reply)
1452            .await
1453            .expect("error processing discovery reply on moderator");
1454
1455        // moderator sets route for participant after discovery reply
1456        assert_eq!(
1457            rx_spy_moderator.recv().await,
1458            Some(SubscriptionCall::SetRoute),
1459            "moderator should set route after discovery reply"
1460        );
1461
1462        // check that a join request is received by slim
1463        let join_request = timeout(Duration::from_millis(100), rx_slim_moderator.recv())
1464            .await
1465            .expect("timeout waiting for join request on moderator slim channel")
1466            .expect("channel closed")
1467            .expect("error in join request");
1468
1469        assert_eq!(
1470            join_request.get_session_message_type(),
1471            slim_datapath::api::ProtoSessionMessageType::JoinRequest
1472        );
1473        assert_eq!(join_request.get_dst(), participant_name_id);
1474
1475        // call the on message on the participant side with the join request (direction north)
1476        let mut join_request_to_participant = join_request.clone();
1477        join_request_to_participant
1478            .get_slim_header_mut()
1479            .set_incoming_conn(Some(1));
1480
1481        participant
1482            .on_message_from_slim(join_request_to_participant)
1483            .await
1484            .expect("error processing join request on participant");
1485
1486        // participant sets route for moderator after join request
1487        assert_eq!(
1488            rx_spy_participant.recv().await,
1489            Some(SubscriptionCall::SetRoute),
1490            "participant should set route after join request"
1491        );
1492
1493        // check that a join reply is received by slim on the participant
1494        let join_reply = timeout(Duration::from_millis(100), rx_slim_participant.recv())
1495            .await
1496            .expect("timeout waiting for join reply on participant slim channel")
1497            .expect("channel closed")
1498            .expect("error in join reply");
1499
1500        assert_eq!(
1501            join_reply.get_session_message_type(),
1502            slim_datapath::api::ProtoSessionMessageType::JoinReply
1503        );
1504        assert_eq!(join_reply.get_dst(), moderator_name);
1505
1506        // call the on message on the moderator with the reply (direction north)
1507        let mut join_reply_to_moderator = join_reply.clone();
1508        join_reply_to_moderator
1509            .get_slim_header_mut()
1510            .set_incoming_conn(Some(1));
1511
1512        moderator
1513            .on_message_from_slim(join_reply_to_moderator)
1514            .await
1515            .expect("error processing join reply on moderator");
1516
1517        // check that a welcome message is received by slim on the moderator
1518        let welcome_message = timeout(Duration::from_millis(100), rx_slim_moderator.recv())
1519            .await
1520            .expect("timeout waiting for welcome message on moderator slim channel")
1521            .expect("channel closed")
1522            .expect("error in welcome message");
1523
1524        assert_eq!(
1525            welcome_message.get_session_message_type(),
1526            slim_datapath::api::ProtoSessionMessageType::GroupWelcome
1527        );
1528        assert_eq!(welcome_message.get_dst(), participant_name_id);
1529
1530        // call the on message on the participant side with the welcome message (direction north)
1531        let mut welcome_to_participant = welcome_message.clone();
1532        welcome_to_participant
1533            .get_slim_header_mut()
1534            .set_incoming_conn(Some(1));
1535
1536        participant
1537            .on_message_from_slim(welcome_to_participant)
1538            .await
1539            .expect("error processing welcome message on participant");
1540
1541        // check that an ack group is received by slim on the participant
1542        let ack_group = timeout(Duration::from_millis(100), rx_slim_participant.recv())
1543            .await
1544            .expect("timeout waiting for ack group on participant slim channel")
1545            .expect("channel closed")
1546            .expect("error in ack group");
1547
1548        assert_eq!(
1549            ack_group.get_session_message_type(),
1550            slim_datapath::api::ProtoSessionMessageType::GroupAck
1551        );
1552        assert_eq!(ack_group.get_dst(), moderator_name);
1553
1554        // call the on message on the moderator with the ack (direction north)
1555        let mut ack_to_moderator = ack_group.clone();
1556        ack_to_moderator
1557            .get_slim_header_mut()
1558            .set_incoming_conn(Some(1));
1559
1560        moderator
1561            .on_message_from_slim(ack_to_moderator)
1562            .await
1563            .expect("error processing ack group on moderator");
1564
1565        // no other message should be sent
1566        let no_more_moderator = timeout(Duration::from_millis(100), rx_slim_moderator.recv()).await;
1567        assert!(
1568            no_more_moderator.is_err(),
1569            "Expected no more messages on moderator slim channel, received {:?}",
1570            no_more_moderator
1571                .ok()
1572                .and_then(|opt| opt)
1573                .and_then(|res| res.ok())
1574        );
1575
1576        let no_more_participant =
1577            timeout(Duration::from_millis(100), rx_slim_participant.recv()).await;
1578        assert!(
1579            no_more_participant.is_err(),
1580            "Expected no more messages on participant slim channel"
1581        );
1582
1583        // the completion handler should now be complete
1584        completion_handle.await.expect("error in completion handle");
1585
1586        // create an application message using the participant name
1587        let app_data = b"Hello from moderator to participant".to_vec();
1588        let app_message = Message::builder()
1589            .source(moderator_name.clone())
1590            .destination(participant_name.clone())
1591            .identity(test_identity())
1592            .session_type(slim_datapath::api::ProtoSessionType::PointToPoint)
1593            .session_message_type(slim_datapath::api::ProtoSessionMessageType::Msg)
1594            .session_id(session_id)
1595            .message_id(1)
1596            .application_payload("test-app-data", app_data.clone())
1597            .build_publish()
1598            .unwrap();
1599
1600        // call on message on the moderator (direction south)
1601        moderator
1602            .on_message_from_app(app_message)
1603            .await
1604            .expect("error sending application message from moderator");
1605
1606        // check that message is received from slim with destination equal to participant name id
1607        let app_msg_to_slim = timeout(Duration::from_millis(100), rx_slim_moderator.recv())
1608            .await
1609            .expect("timeout waiting for application message on moderator slim channel")
1610            .expect("channel closed")
1611            .expect("error in application message");
1612
1613        assert_eq!(app_msg_to_slim.get_dst(), participant_name_id);
1614        assert!(
1615            app_msg_to_slim.is_publish(),
1616            "message should be a publish message"
1617        );
1618
1619        let app_msg_id = app_msg_to_slim.get_id();
1620
1621        // call the on message on the participant (direction north)
1622        let mut app_msg_to_participant = app_msg_to_slim.clone();
1623        app_msg_to_participant
1624            .get_slim_header_mut()
1625            .set_incoming_conn(Some(1));
1626
1627        participant
1628            .on_message_from_slim(app_msg_to_participant)
1629            .await
1630            .expect("error processing application message on participant");
1631
1632        // check that the message is received by the application
1633        let app_msg_received = timeout(Duration::from_millis(100), rx_app_participant.recv())
1634            .await
1635            .expect("timeout waiting for application message on participant app channel")
1636            .expect("channel closed")
1637            .expect("error in application message to app");
1638
1639        assert_eq!(app_msg_received.get_source(), moderator_name);
1640        assert!(
1641            app_msg_received.is_publish(),
1642            "message should be a publish message"
1643        );
1644
1645        let content = app_msg_received
1646            .get_payload()
1647            .unwrap()
1648            .as_application_payload()
1649            .unwrap()
1650            .blob
1651            .clone();
1652        assert_eq!(content, app_data);
1653
1654        // check that an ack is sent to slim
1655        let ack_msg = timeout(Duration::from_millis(100), rx_slim_participant.recv())
1656            .await
1657            .expect("timeout waiting for ack on participant slim channel")
1658            .expect("channel closed")
1659            .expect("error in ack");
1660
1661        assert_eq!(
1662            ack_msg.get_session_message_type(),
1663            slim_datapath::api::ProtoSessionMessageType::MsgAck,
1664            "message should be an ack"
1665        );
1666        assert_eq!(ack_msg.get_dst(), moderator_name);
1667        assert_eq!(ack_msg.get_id(), app_msg_id);
1668
1669        // call the on message with the ack on the moderator
1670        let mut ack_to_moderator = ack_msg.clone();
1671        ack_to_moderator
1672            .get_slim_header_mut()
1673            .set_incoming_conn(Some(1));
1674
1675        moderator
1676            .on_message_from_slim(ack_to_moderator)
1677            .await
1678            .expect("error processing ack on moderator");
1679
1680        // check that no other message is generated
1681        let no_more_moderator_after_ack =
1682            timeout(Duration::from_millis(100), rx_slim_moderator.recv()).await;
1683        assert!(
1684            no_more_moderator_after_ack.is_err(),
1685            "Expected no more messages on moderator slim channel after ack"
1686        );
1687
1688        let no_more_participant_after_ack =
1689            timeout(Duration::from_millis(100), rx_slim_participant.recv()).await;
1690        assert!(
1691            no_more_participant_after_ack.is_err(),
1692            "Expected no more messages on participant slim channel after ack"
1693        );
1694
1695        // create a leave request and send to moderator on message (direction south)
1696        let leave_request = Message::builder()
1697            .source(moderator_name.clone())
1698            .destination(participant_name.clone())
1699            .identity(test_identity())
1700            .session_type(slim_datapath::api::ProtoSessionType::PointToPoint)
1701            .session_message_type(slim_datapath::api::ProtoSessionMessageType::LeaveRequest)
1702            .session_id(session_id)
1703            .message_id(rand::random::<u32>())
1704            .payload(CommandPayload::builder().leave_request().as_content())
1705            .build_publish()
1706            .unwrap();
1707
1708        moderator
1709            .on_message_from_app(leave_request)
1710            .await
1711            .expect("error sending leave request");
1712
1713        // check that the request is received by slim on the moderator
1714        let received_leave_request = timeout(Duration::from_millis(100), rx_slim_moderator.recv())
1715            .await
1716            .expect("timeout waiting for leave request on moderator slim channel")
1717            .expect("channel closed")
1718            .expect("error in leave request");
1719
1720        assert_eq!(
1721            received_leave_request.get_session_message_type(),
1722            slim_datapath::api::ProtoSessionMessageType::LeaveRequest
1723        );
1724        assert_eq!(received_leave_request.get_dst(), participant_name_id);
1725
1726        // send the request to the participant (direction north)
1727        let mut leave_request_to_participant = received_leave_request.clone();
1728        leave_request_to_participant
1729            .get_slim_header_mut()
1730            .set_incoming_conn(Some(1));
1731
1732        participant
1733            .on_message_from_slim(leave_request_to_participant)
1734            .await
1735            .expect("error processing leave request on participant");
1736
1737        // get the leave reply on the participant slim
1738        let leave_reply = timeout(Duration::from_millis(100), rx_slim_participant.recv())
1739            .await
1740            .expect("timeout waiting for leave reply on participant slim channel")
1741            .expect("channel closed")
1742            .expect("error in leave reply");
1743
1744        assert_eq!(
1745            leave_reply.get_session_message_type(),
1746            slim_datapath::api::ProtoSessionMessageType::LeaveReply
1747        );
1748        assert_eq!(leave_reply.get_dst(), moderator_name);
1749
1750        // participant removes route after processing leave request
1751        assert_eq!(
1752            rx_spy_participant.recv().await,
1753            Some(SubscriptionCall::RemoveRoute),
1754            "participant should remove route after leave request"
1755        );
1756
1757        // send the leave reply to the moderator on message (direction north)
1758        let mut leave_reply_to_moderator = leave_reply.clone();
1759        leave_reply_to_moderator
1760            .get_slim_header_mut()
1761            .set_incoming_conn(Some(1));
1762
1763        moderator
1764            .on_message_from_slim(leave_reply_to_moderator)
1765            .await
1766            .expect("error processing leave reply on moderator");
1767
1768        // moderator removes route after processing leave reply
1769        assert_eq!(
1770            rx_spy_moderator.recv().await,
1771            Some(SubscriptionCall::RemoveRoute),
1772            "moderator should remove route after leave reply"
1773        );
1774
1775        // check that no other messages are generated by the moderator
1776        let no_more_moderator_final =
1777            timeout(Duration::from_millis(100), rx_slim_moderator.recv()).await;
1778
1779        assert!(
1780            no_more_moderator_final.is_err(),
1781            "Expected no more messages on moderator slim channel after leave"
1782        );
1783
1784        let no_more_participant_final =
1785            timeout(Duration::from_millis(100), rx_slim_participant.recv()).await;
1786        assert!(
1787            no_more_participant_final.is_err(),
1788            "Expected no more messages on participant slim channel after leave"
1789        );
1790    }
1791
1792    // ============================================================================
1793    // Draining Tests
1794    #[traced_test]
1795    #[tokio::test]
1796    async fn test_internal_draining_via_processing_state_switch() {
1797        use super::*;
1798        use tokio::sync::mpsc;
1799        use tracing::debug;
1800
1801        // Custom handler that flips processing_state to Draining after first normal message
1802        struct InternalDrainHandler {
1803            state: ProcessingState,
1804            messages: Vec<SessionMessage>,
1805            needs_drain: Arc<AtomicBool>,
1806        }
1807
1808        impl InternalDrainHandler {
1809            fn new(needs_drain: Arc<AtomicBool>) -> Self {
1810                Self {
1811                    state: ProcessingState::Active,
1812                    messages: vec![],
1813                    needs_drain,
1814                }
1815            }
1816        }
1817
1818        impl MessageHandler for InternalDrainHandler {
1819            async fn init(&mut self) -> Result<(), SessionError> {
1820                Ok(())
1821            }
1822
1823            async fn on_message(
1824                &mut self,
1825                message: SessionMessage,
1826            ) -> Result<SessionOutput, SessionError> {
1827                debug!(?self.state, "internal-drain-handler received message");
1828                self.messages.push(message);
1829
1830                // when we receive 2 messages, transition to draining state
1831                if self.messages.len() == 2 {
1832                    debug!("internal-drain-handler transitioning to draining");
1833                    self.state = ProcessingState::Draining;
1834                }
1835
1836                Ok(SessionOutput::new())
1837            }
1838
1839            fn needs_drain(&self) -> bool {
1840                self.needs_drain.load(std::sync::atomic::Ordering::SeqCst)
1841            }
1842
1843            fn processing_state(&self) -> ProcessingState {
1844                self.state
1845            }
1846
1847            async fn on_shutdown(&mut self) -> Result<(), SessionError> {
1848                debug!("shutdown called on handler");
1849                Ok(())
1850            }
1851        }
1852
1853        // Build minimal SessionSettings
1854        let (tx_slim, _rx_slim) = mpsc::channel(8);
1855        let (tx_app, _rx_app) = mpsc::unbounded_channel();
1856        let (tx_session, rx_session) = mpsc::channel(32);
1857        let (tx_session_layer, _rx_session_layer) = mpsc::channel(8);
1858
1859        let subscription_manager =
1860            crate::subscription_manager::SubscriptionManager::new(tx_slim.clone());
1861        let settings = SessionSettings {
1862            id: 999,
1863            source: ProtoName::from_strings(["org", "ns", "source"]).with_id(1),
1864            destination: ProtoName::from_strings(["org", "ns", "dest"]).with_id(2),
1865            control: ProtoName::from_strings(["org", "ns", "dest"]).with_id(2),
1866            config: SessionConfig {
1867                session_type: ProtoSessionType::PointToPoint,
1868                max_retries: Some(3),
1869                interval: Some(Duration::from_millis(150)),
1870                mls_settings: None,
1871                initiator: true,
1872                metadata: HashMap::new(),
1873            },
1874            direction: Direction::Bidirectional,
1875            slim_tx: tx_slim,
1876            app_tx: tx_app,
1877            tx_session: tx_session.clone(),
1878            tx_to_session_layer: tx_session_layer,
1879            identity_provider: SharedSecret::new("src", SHARED_SECRET).unwrap(),
1880            identity_verifier: SharedSecret::new("src", SHARED_SECRET).unwrap(),
1881            graceful_shutdown_timeout: Some(Duration::from_secs(10)),
1882            subscription_manager,
1883            service_id: String::new(),
1884        };
1885
1886        let needs_drain = Arc::new(AtomicBool::new(true));
1887        let handler = InternalDrainHandler::new(needs_drain.clone());
1888        let cancellation_token = CancellationToken::new();
1889        let cancellation_token_clone = cancellation_token.clone();
1890
1891        // Spawn processing loop without unnecessary cloning
1892        let processing_handle = tokio::spawn(async move {
1893            SessionController::processing_loop(
1894                handler,
1895                rx_session,
1896                cancellation_token_clone,
1897                settings,
1898            )
1899            .await
1900        });
1901
1902        // Send first regular message
1903        tx_session
1904            .send(create_test_message(1, b"first".to_vec()))
1905            .await
1906            .expect("failed to send first message");
1907
1908        tokio::time::sleep(Duration::from_millis(100)).await;
1909
1910        assert!(logs_contain("internal-drain-handler received message"));
1911
1912        // Send second message; this causes internal handler move to draining (active -> draining)
1913        tx_session
1914            .send(create_test_message(2, b"second".to_vec()))
1915            .await
1916            .expect("failed to send second message");
1917
1918        tokio::time::sleep(Duration::from_millis(100)).await;
1919
1920        assert!(logs_contain("internal-drain-handler received message"));
1921
1922        assert!(logs_contain(
1923            "internal-drain-handler transitioning to draining"
1924        ));
1925
1926        // Send a third message that should not be processed, as draining is active
1927        tx_session
1928            .send(create_test_message(3, b"third".to_vec()))
1929            .await
1930            .expect("failed to send third message");
1931
1932        tokio::time::sleep(Duration::from_millis(100)).await;
1933
1934        assert!(logs_contain(
1935            "session is draining, rejecting new messages from application"
1936        ));
1937
1938        // set needs drain to false to allow shutdown to complete
1939        needs_drain.store(false, std::sync::atomic::Ordering::SeqCst);
1940
1941        // trigger cancellation to exit processing loop
1942        cancellation_token.cancel();
1943
1944        tokio::time::sleep(Duration::from_millis(100)).await;
1945
1946        // Send a session message to trigger the shutdown process
1947        tx_session
1948            .send(SessionMessage::StartDrain {
1949                grace_period: std::time::Duration::from_millis(100),
1950            })
1951            .await
1952            .expect("failed to send timeout message");
1953
1954        // Wait for processing loop to complete
1955        processing_handle.await.expect("processing loop panicked");
1956    }
1957    // ============================================================================
1958
1959    /// Mock handler that tracks draining behavior
1960    struct DrainableHandler {
1961        messages_received: Arc<tokio::sync::Mutex<Vec<SessionMessage>>>,
1962        needs_drain: Arc<AtomicBool>,
1963        shutdown_called: Arc<tokio::sync::Mutex<bool>>,
1964        drain_delay: Option<Duration>,
1965    }
1966
1967    impl DrainableHandler {
1968        fn new() -> Self {
1969            Self {
1970                messages_received: Arc::new(tokio::sync::Mutex::new(Vec::new())),
1971                needs_drain: Arc::new(AtomicBool::new(false)),
1972                shutdown_called: Arc::new(tokio::sync::Mutex::new(false)),
1973                drain_delay: None,
1974            }
1975        }
1976
1977        fn with_needs_drain(self, needs_drain: bool) -> Self {
1978            self.needs_drain
1979                .store(needs_drain, std::sync::atomic::Ordering::SeqCst);
1980            self
1981        }
1982
1983        #[allow(dead_code)]
1984        fn with_drain_delay(mut self, delay: Duration) -> Self {
1985            self.drain_delay = Some(delay);
1986            self
1987        }
1988
1989        #[allow(dead_code)]
1990        async fn get_messages_count(&self) -> usize {
1991            self.messages_received.lock().await.len()
1992        }
1993
1994        #[allow(dead_code)]
1995        async fn was_shutdown_called(&self) -> bool {
1996            *self.shutdown_called.lock().await
1997        }
1998    }
1999
2000    impl MessageHandler for DrainableHandler {
2001        async fn init(&mut self) -> Result<(), SessionError> {
2002            Ok(())
2003        }
2004
2005        async fn on_message(
2006            &mut self,
2007            message: SessionMessage,
2008        ) -> Result<SessionOutput, SessionError> {
2009            self.messages_received.lock().await.push(message);
2010            Ok(SessionOutput::new())
2011        }
2012
2013        fn needs_drain(&self) -> bool {
2014            self.needs_drain.load(std::sync::atomic::Ordering::SeqCst)
2015        }
2016
2017        async fn on_shutdown(&mut self) -> Result<(), SessionError> {
2018            if let Some(delay) = self.drain_delay {
2019                tokio::time::sleep(delay).await;
2020            }
2021            *self.shutdown_called.lock().await = true;
2022            Ok(())
2023        }
2024    }
2025
2026    /// Helper to create test SessionSettings
2027    fn create_test_settings(
2028        graceful_shutdown_timeout: Option<Duration>,
2029    ) -> SessionSettings<SharedSecret, SharedSecret> {
2030        let (tx_slim, _rx_slim) = tokio::sync::mpsc::channel(10);
2031        let (tx_app, _rx_app) = tokio::sync::mpsc::unbounded_channel();
2032        let (tx_session, _rx_session) = tokio::sync::mpsc::channel(10);
2033        let (tx_session_layer, _rx_session_layer) = tokio::sync::mpsc::channel(10);
2034
2035        let subscription_manager =
2036            crate::subscription_manager::SubscriptionManager::new(tx_slim.clone());
2037        SessionSettings {
2038            id: 1,
2039            source: ProtoName::from_strings(["org", "ns", "test"]).with_id(1),
2040            destination: ProtoName::from_strings(["org", "ns", "test"]).with_id(2),
2041            control: ProtoName::from_strings(["org", "ns", "test"]).with_id(2),
2042            config: SessionConfig {
2043                session_type: ProtoSessionType::PointToPoint,
2044                max_retries: Some(5),
2045                interval: Some(Duration::from_millis(200)),
2046                mls_settings: None,
2047                initiator: true,
2048                metadata: HashMap::new(),
2049            },
2050            direction: Direction::Bidirectional,
2051            slim_tx: tx_slim,
2052            app_tx: tx_app,
2053            tx_session,
2054            tx_to_session_layer: tx_session_layer,
2055            identity_provider: SharedSecret::new("test", SHARED_SECRET).unwrap(),
2056            identity_verifier: SharedSecret::new("test", SHARED_SECRET).unwrap(),
2057            graceful_shutdown_timeout,
2058            subscription_manager,
2059            service_id: String::new(),
2060        }
2061    }
2062
2063    /// Helper to create a test message
2064    fn create_test_message(message_id: u32, payload: Vec<u8>) -> SessionMessage {
2065        SessionMessage::OnMessage {
2066            message: Message::builder()
2067                .source(ProtoName::from_strings(["org", "ns", "test"]).with_id(1))
2068                .destination(ProtoName::from_strings(["org", "ns", "test"]).with_id(2))
2069                .identity(test_identity())
2070                .forward_to(1)
2071                .session_type(ProtoSessionType::PointToPoint)
2072                .session_message_type(ProtoSessionMessageType::Msg)
2073                .session_id(1)
2074                .message_id(message_id)
2075                .application_payload("test", payload)
2076                .build_publish()
2077                .unwrap(),
2078            direction: MessageDirection::South,
2079            ack_tx: None,
2080        }
2081    }
2082
2083    async fn count_on_messages(messages: &Arc<tokio::sync::Mutex<Vec<SessionMessage>>>) -> usize {
2084        let messages = messages.lock().await;
2085        messages
2086            .iter()
2087            .filter(|msg| matches!(msg, SessionMessage::OnMessage { .. }))
2088            .count()
2089    }
2090
2091    /// Helper to spawn a processing loop and return the task handle
2092    fn spawn_processing_loop(
2093        handler: DrainableHandler,
2094        rx: tokio::sync::mpsc::Receiver<SessionMessage>,
2095        cancellation_token: CancellationToken,
2096        settings: SessionSettings<SharedSecret, SharedSecret>,
2097    ) -> tokio::task::JoinHandle<()> {
2098        tokio::spawn(async move {
2099            SessionController::processing_loop(handler, rx, cancellation_token, settings).await;
2100        })
2101    }
2102
2103    #[tokio::test]
2104    async fn test_draining_processes_queued_messages() {
2105        let handler = DrainableHandler::new();
2106        let messages_received = handler.messages_received.clone();
2107        let shutdown_called = handler.shutdown_called.clone();
2108
2109        let (tx, rx) = tokio::sync::mpsc::channel(10);
2110        let cancellation_token = CancellationToken::new();
2111        let token_clone = cancellation_token.clone();
2112
2113        let settings = create_test_settings(Some(Duration::from_secs(2)));
2114        let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2115
2116        // Send multiple messages before cancellation
2117        tx.send(create_test_message(1, vec![1, 2, 3]))
2118            .await
2119            .unwrap();
2120        tx.send(create_test_message(2, vec![4, 5, 6]))
2121            .await
2122            .unwrap();
2123        tx.send(create_test_message(3, vec![7, 8, 9]))
2124            .await
2125            .unwrap();
2126
2127        // Give some time for messages to be queued
2128        tokio::time::sleep(Duration::from_millis(50)).await;
2129
2130        // Trigger cancellation
2131        token_clone.cancel();
2132
2133        // Close the channel to signal no more messages
2134        drop(tx);
2135
2136        // Wait for processing to complete
2137        timeout(Duration::from_secs(3), processing_task)
2138            .await
2139            .expect("timeout waiting for processing loop")
2140            .expect("processing loop panicked");
2141
2142        // Verify all messages were processed
2143        let processed_messages = count_on_messages(&messages_received).await;
2144        assert_eq!(
2145            processed_messages, 3,
2146            "All queued messages should be processed during draining"
2147        );
2148        assert!(
2149            *shutdown_called.lock().await,
2150            "Shutdown should have been called"
2151        );
2152    }
2153
2154    #[tokio::test]
2155    async fn test_draining_with_needs_drain_true() {
2156        let handler = DrainableHandler::new().with_needs_drain(true);
2157        let messages_received = handler.messages_received.clone();
2158        let shutdown_called = handler.shutdown_called.clone();
2159
2160        let (tx, rx) = tokio::sync::mpsc::channel(10);
2161        let cancellation_token = CancellationToken::new();
2162        let token_clone = cancellation_token.clone();
2163
2164        let settings = create_test_settings(Some(Duration::from_secs(2)));
2165        let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2166
2167        // Send a message
2168        tx.send(create_test_message(1, vec![1, 2, 3]))
2169            .await
2170            .unwrap();
2171        tokio::time::sleep(Duration::from_millis(50)).await;
2172
2173        // Trigger cancellation and close channel
2174        token_clone.cancel();
2175        drop(tx);
2176
2177        // Wait for processing to complete (should wait for drain timeout)
2178        timeout(Duration::from_secs(3), processing_task)
2179            .await
2180            .expect("timeout waiting for processing loop")
2181            .expect("processing loop panicked");
2182
2183        // Verify message was processed and shutdown was called
2184        let processed_messages = count_on_messages(&messages_received).await;
2185        assert_eq!(processed_messages, 1, "Message should be processed");
2186        assert!(
2187            *shutdown_called.lock().await,
2188            "Shutdown should have been called after draining"
2189        );
2190    }
2191
2192    #[tokio::test]
2193    async fn test_draining_with_needs_drain_false() {
2194        let handler = DrainableHandler::new().with_needs_drain(false);
2195        let messages_received = handler.messages_received.clone();
2196        let shutdown_called = handler.shutdown_called.clone();
2197
2198        let (tx, rx) = tokio::sync::mpsc::channel(10);
2199        let cancellation_token = CancellationToken::new();
2200        let token_clone = cancellation_token.clone();
2201
2202        let settings = create_test_settings(Some(Duration::from_secs(2)));
2203
2204        let start_time = tokio::time::Instant::now();
2205        let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2206
2207        // Send a message
2208        tx.send(create_test_message(1, vec![1, 2, 3]))
2209            .await
2210            .unwrap();
2211        tokio::time::sleep(Duration::from_millis(50)).await;
2212
2213        // Trigger cancellation and close channel
2214        token_clone.cancel();
2215        drop(tx);
2216
2217        // Wait for processing to complete (should exit quickly)
2218        timeout(Duration::from_secs(1), processing_task)
2219            .await
2220            .expect("timeout waiting for processing loop")
2221            .expect("processing loop panicked");
2222
2223        let elapsed = start_time.elapsed();
2224
2225        // Verify message was processed and shutdown was called quickly
2226        let processed_messages = count_on_messages(&messages_received).await;
2227        assert_eq!(processed_messages, 1, "Message should be processed");
2228        assert!(
2229            *shutdown_called.lock().await,
2230            "Shutdown should have been called"
2231        );
2232        assert!(
2233            elapsed < Duration::from_millis(500),
2234            "Should exit quickly when no draining needed"
2235        );
2236    }
2237
2238    #[tokio::test]
2239    async fn test_draining_timeout_enforced() {
2240        // Test that the timeout fires when draining takes too long with needs_drain=true
2241        let handler = DrainableHandler::new().with_needs_drain(true);
2242
2243        let (tx, rx) = tokio::sync::mpsc::channel(10);
2244        let cancellation_token = CancellationToken::new();
2245        let token_clone = cancellation_token.clone();
2246
2247        let settings = create_test_settings(Some(Duration::from_millis(500)));
2248        let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2249
2250        // Give the processing loop a moment to start
2251        tokio::time::sleep(Duration::from_millis(50)).await;
2252
2253        let start_time = tokio::time::Instant::now();
2254
2255        // Trigger cancellation - this starts draining
2256        token_clone.cancel();
2257
2258        // Keep sending messages to prevent channel from closing
2259        // This simulates a scenario where messages keep arriving during drain period
2260        let send_task = tokio::spawn(async move {
2261            for i in 0..10 {
2262                tokio::time::sleep(Duration::from_millis(100)).await;
2263                if tx
2264                    .send(create_test_message(i, vec![i as u8]))
2265                    .await
2266                    .is_err()
2267                {
2268                    break;
2269                }
2270            }
2271        });
2272
2273        // Wait for processing to complete - should timeout after 500ms
2274        timeout(Duration::from_secs(2), processing_task)
2275            .await
2276            .expect("timeout waiting for processing loop")
2277            .expect("processing loop panicked");
2278
2279        let elapsed = start_time.elapsed();
2280
2281        // Verify timeout was enforced (should be around 500ms)
2282        assert!(
2283            elapsed >= Duration::from_millis(400),
2284            "Should wait at least close to the timeout period"
2285        );
2286        assert!(
2287            elapsed < Duration::from_secs(2),
2288            "Should respect the timeout and exit, not wait forever"
2289        );
2290
2291        // Clean up the send task
2292        send_task.abort();
2293    }
2294
2295    #[tokio::test]
2296    async fn test_draining_no_messages_in_queue() {
2297        let handler = DrainableHandler::new().with_needs_drain(true);
2298        let messages_received = handler.messages_received.clone();
2299        let shutdown_called = handler.shutdown_called.clone();
2300
2301        let (tx, rx) = tokio::sync::mpsc::channel(10);
2302        let cancellation_token = CancellationToken::new();
2303        let token_clone = cancellation_token.clone();
2304
2305        let settings = create_test_settings(Some(Duration::from_secs(1)));
2306        let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2307
2308        // Trigger cancellation immediately without sending messages
2309        token_clone.cancel();
2310        drop(tx);
2311
2312        // Wait for processing to complete
2313        timeout(Duration::from_secs(2), processing_task)
2314            .await
2315            .expect("timeout waiting for processing loop")
2316            .expect("processing loop panicked");
2317
2318        // Verify no messages were processed but shutdown was called
2319        let processed_messages = count_on_messages(&messages_received).await;
2320        assert_eq!(processed_messages, 0, "No messages should be processed");
2321        assert!(
2322            *shutdown_called.lock().await,
2323            "Shutdown should still be called"
2324        );
2325    }
2326
2327    #[tokio::test]
2328    async fn test_draining_messages_after_cancellation_processed() {
2329        let handler = DrainableHandler::new();
2330        let messages_received = handler.messages_received.clone();
2331
2332        let (tx, rx) = tokio::sync::mpsc::channel(10);
2333        let cancellation_token = CancellationToken::new();
2334        let token_clone = cancellation_token.clone();
2335
2336        let settings = create_test_settings(Some(Duration::from_secs(2)));
2337
2338        // Send messages before cancellation
2339        tx.send(create_test_message(1, vec![1, 2, 3]))
2340            .await
2341            .unwrap();
2342        tx.send(create_test_message(2, vec![4, 5, 6]))
2343            .await
2344            .unwrap();
2345
2346        // Spawn the processing loop after messages are queued
2347        let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
2348
2349        // Give a moment for processing to start
2350        tokio::time::sleep(Duration::from_millis(50)).await;
2351
2352        // Trigger cancellation while messages are in queue
2353        token_clone.cancel();
2354
2355        // Close channel
2356        drop(tx);
2357
2358        // Wait for processing to complete
2359        timeout(Duration::from_secs(3), processing_task)
2360            .await
2361            .expect("timeout waiting for processing loop")
2362            .expect("processing loop panicked");
2363
2364        // Verify messages in queue when cancellation happened were still processed
2365        let processed_messages = count_on_messages(&messages_received).await;
2366        assert_eq!(
2367            processed_messages, 2,
2368            "Messages in queue during cancellation should be processed"
2369        );
2370    }
2371}