Skip to main content

slim_session/
session_layer.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4// Standard library imports
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use display_error_chain::ErrorChainExt;
9// Third-party crates
10use parking_lot::RwLock as SyncRwLock;
11use rand::Rng;
12
13use tokio::sync::Semaphore;
14use tokio::sync::mpsc::Sender;
15use tracing::{Instrument, debug, error, warn};
16
17use slim_auth::traits::{TokenProvider, Verifier};
18use slim_datapath::api::{
19    EncodedName, NameId, ParticipantSettings, ProtoMessage as Message, ProtoName,
20    ProtoSessionMessageType, ProtoSessionType,
21};
22
23use crate::common::SessionMessage;
24use crate::completion_handle::CompletionHandle;
25use crate::notification::Notification;
26use crate::session_config::SessionConfig;
27use crate::session_controller::SessionController;
28use crate::subscription_manager::SubscriptionManager;
29
30// Local crate
31use super::context::SessionContext;
32
33use super::{SESSION_RANGE, SlimChannelSender};
34use super::{SessionError, session_controller::handle_channel_discovery_message};
/// Data direction of a session created by the session layer.
///
/// Each variant states which halves of the data path stay open; the
/// `shutdown_*` values below are the flags returned by `Direction::to_flags`.
#[derive(Clone, Copy, Debug)]
pub enum Direction {
    /// Only sends data messages (shutdown_send: false, shutdown_receive: true).
    Send,
    /// Only receives data messages (shutdown_send: true, shutdown_receive: false).
    Recv,
    /// Sends and receives data messages (shutdown_send: false, shutdown_receive: false).
    Bidirectional,
    /// Neither sends nor receives data messages (shutdown_send: true, shutdown_receive: true).
    None,
}
44
45impl Direction {
46    pub fn to_flags(self) -> (bool, bool) {
47        match self {
48            Direction::Send => (false, true),
49            Direction::Recv => (true, false),
50            Direction::Bidirectional => (false, false),
51            Direction::None => (true, true),
52        }
53    }
54
55    pub fn to_participant_settings(self) -> ParticipantSettings {
56        match self {
57            // None (absent) means true, so only set fields explicitly when false
58            Direction::Send => ParticipantSettings {
59                sends_data: true,
60                receives_data: false,
61            },
62            Direction::Recv => ParticipantSettings {
63                sends_data: false,
64                receives_data: true,
65            },
66            Direction::Bidirectional => ParticipantSettings {
67                sends_data: true,
68                receives_data: true,
69            },
70            Direction::None => ParticipantSettings {
71                sends_data: false,
72                receives_data: false,
73            },
74        }
75    }
76}
77
78/// SessionLayer manages sessions and their lifecycle
79pub struct SessionLayer<P, V>
80where
81    P: TokenProvider + Send + Sync + Clone + 'static,
82    V: Verifier + Send + Sync + Clone + 'static,
83{
84    /// Session pool
85    pool: Arc<SyncRwLock<HashMap<u32, Arc<SessionController>>>>,
86
87    /// Default name of the local app
88    app_id: u128,
89
90    /// Names registered by local app, keyed by encoded name (null id) → subscription_id
91    app_names: SyncRwLock<HashMap<EncodedName, u64>>,
92
93    /// Identity provider for the local app
94    identity_provider: P,
95
96    /// Identity verifier
97    identity_verifier: V,
98
99    /// ID of the local connection
100    conn_id: u64,
101
102    /// Tx channels
103    tx_slim: SlimChannelSender,
104    tx_app: Sender<Result<Notification, SessionError>>,
105
106    /// Channel to clone on session creation
107    tx_session: tokio::sync::mpsc::Sender<Result<SessionMessage, SessionError>>,
108
109    /// map from session id to session context
110    /// once a new session is created on reception of a join request, store it temporarily
111    /// in this map waiting for the welcome message before notifying it to the application
112    to_notify: SyncRwLock<HashMap<u32, SessionContext>>,
113
114    /// direction to use for the new sessions
115    direction: Direction,
116
117    /// Shared subscription manager — used by both this layer and all sessions it creates
118    subscription_manager: SubscriptionManager,
119
120    /// Service ID propagated into every session span
121    service_id: String,
122
123    /// Bounds concurrent identity verifications for messages without a session.
124    /// Caps the blast radius of an unknown-session flood with slow verifications.
125    pre_session_verify_slots: Arc<Semaphore>,
126}
127
128impl<P, V> SessionLayer<P, V>
129where
130    P: TokenProvider + Send + Sync + Clone + 'static,
131    V: Verifier + Send + Sync + Clone + 'static,
132{
133    const PRE_SESSION_VERIFY_SLOTS: usize = 128;
134
135    /// Create a new SessionLayer
136    #[allow(clippy::too_many_arguments)]
137    pub fn new(
138        app_name: ProtoName,
139        identity_provider: P,
140        identity_verifier: V,
141        conn_id: u64,
142        tx_slim: SlimChannelSender,
143        tx_app: Sender<Result<Notification, SessionError>>,
144        direction: Direction,
145        service_id: String,
146    ) -> Self {
147        let (tx_session, rx_session) = tokio::sync::mpsc::channel(16);
148
149        let subscription_manager = SubscriptionManager::new(tx_slim.clone());
150
151        let initial_key = Self::name_to_key(&app_name);
152        let sl = SessionLayer {
153            pool: Arc::new(SyncRwLock::new(HashMap::new())),
154            app_id: app_name.id(),
155            app_names: SyncRwLock::new(HashMap::from([(initial_key, 0)])),
156            identity_provider,
157            identity_verifier,
158            conn_id,
159            tx_slim,
160            tx_app,
161            tx_session,
162            to_notify: SyncRwLock::new(HashMap::new()),
163            direction,
164            subscription_manager,
165            service_id,
166            pre_session_verify_slots: Arc::new(Semaphore::new(Self::PRE_SESSION_VERIFY_SLOTS)),
167        };
168
169        sl.listen_from_sessions(rx_session);
170
171        sl
172    }
173
174    pub fn tx_slim(&self) -> SlimChannelSender {
175        self.tx_slim.clone()
176    }
177
178    pub fn subscription_manager(&self) -> SubscriptionManager {
179        self.subscription_manager.clone()
180    }
181
182    pub fn tx_app(&self) -> Sender<Result<Notification, SessionError>> {
183        self.tx_app.clone()
184    }
185
186    #[allow(dead_code)]
187    pub fn conn_id(&self) -> u64 {
188        self.conn_id
189    }
190
191    pub fn app_id(&self) -> u128 {
192        self.app_id
193    }
194
195    /// Build the HashMap key (EncodedName with null component_3) from a ProtoName.
196    fn name_to_key(name: &ProtoName) -> EncodedName {
197        let enc = name.name.as_ref().unwrap();
198        EncodedName {
199            component_0: enc.component_0,
200            component_1: enc.component_1,
201            component_2: enc.component_2,
202            name_id: Some(NameId::from(NameId::NULL_COMPONENT)),
203        }
204    }
205
206    pub fn add_app_name(&self, name: ProtoName, subscription_id: u64) {
207        let key = Self::name_to_key(&name);
208        self.app_names.write().insert(key, subscription_id);
209    }
210
211    pub fn remove_app_name(&self, name: &ProtoName) -> Option<u64> {
212        let key = Self::name_to_key(name);
213        let removed = self.app_names.write().remove(&key);
214        if removed.is_none() {
215            warn!(%name, "tried to remove unknown app name");
216        }
217        removed
218    }
219
220    fn get_local_name_for_session(&self, dst: ProtoName) -> Result<ProtoName, SessionError> {
221        let key = Self::name_to_key(&dst);
222        if self.app_names.read().contains_key(&key) {
223            Ok(dst.with_id(self.app_id))
224        } else {
225            Err(SessionError::SubscriptionNotFound(dst))
226        }
227    }
228
229    /// Get identity token from the identity provider
230    pub fn get_identity_token(&self) -> Result<String, SessionError> {
231        let token = self.identity_provider.get_token()?;
232        Ok(token)
233    }
234
235    /// Public interface to create a new session
236    #[tracing::instrument(skip_all, fields(service_id = %self.service_id))]
237    pub async fn create_session(
238        &self,
239        mut session_config: SessionConfig,
240        local_name: ProtoName,
241        destination: ProtoName,
242        id: Option<u32>,
243    ) -> Result<(SessionContext, CompletionHandle), SessionError> {
244        // Sanity check
245        session_config.initiator = true;
246
247        // Store values before they are moved
248        let is_p2p = session_config.session_type == ProtoSessionType::PointToPoint;
249        let destination_proto = destination.clone();
250
251        let session = self.create_session_internal(session_config, local_name, destination, id)?;
252
253        // If session is p2p, initiate the discovery request now and return the ack
254        // Otherwise, return an immediately resolved future
255        let init_ack = if is_p2p {
256            session
257                .session()
258                .upgrade()
259                .ok_or(SessionError::SessionNotFound(u32::MAX))?
260                .invite_participant_internal(&destination_proto)
261                .await
262                .inspect_err(|_| {
263                    // If invite_participant_internal fails, remove the session from the pool
264                    let _ = self.remove_session(session.session_id());
265                })?
266        } else {
267            // For non-P2P sessions, return an immediately resolved future
268            let (tx, rx) = tokio::sync::oneshot::channel();
269            let _ = tx.send(Ok(()));
270            CompletionHandle::from_oneshot_receiver(rx)
271        };
272
273        // return the session info and initialization ack
274        Ok((session, init_ack))
275    }
276
277    /// Create a new session and add it to the pool
278    fn create_session_internal(
279        &self,
280        session_config: SessionConfig,
281        local_name: ProtoName,
282        destination: ProtoName,
283        id: Option<u32>,
284    ) -> Result<SessionContext, SessionError> {
285        // Retry loop to handle race conditions when generating random IDs
286        loop {
287            // get a lock on the session pool
288            let session_id = {
289                let pool = self.pool.read();
290
291                // generate a new session ID in the SESSION_RANGE if not provided
292                match id {
293                    Some(id) => {
294                        // make sure provided id is in range
295                        if !SESSION_RANGE.contains(&id) {
296                            return Err(SessionError::InvalidSessionId(id));
297                        }
298
299                        // check if the session ID is already used
300                        if pool.contains_key(&id) {
301                            return Err(SessionError::SessionIdAlreadyUsed(id));
302                        }
303
304                        id
305                    }
306                    None => {
307                        // generate a new session ID
308                        loop {
309                            let session_id = rand::rng().random_range(SESSION_RANGE);
310                            if !pool.contains_key(&session_id) {
311                                break session_id;
312                            }
313                        }
314                    }
315                }
316            }; // lock is dropped here
317
318            // Create app channel for this session
319            let (app_tx, app_rx) = tokio::sync::mpsc::unbounded_channel();
320
321            // Build the session controller (this is async, so no locks are held)
322            // The builder will automatically force DATA_CHANNEL_ID for multicast destinations
323            let builder = SessionController::builder()
324                .with_id(session_id)
325                .with_source(local_name.clone())
326                .with_destination(destination.clone())
327                .with_config(session_config.clone())
328                .with_identity_provider(self.identity_provider.clone())
329                .with_identity_verifier(self.identity_verifier.clone())
330                .with_slim_tx(self.tx_slim.clone())
331                .with_app_tx(app_tx)
332                .with_tx_to_session_layer(self.tx_session.clone())
333                .with_direction(self.direction)
334                .with_subscription_manager(self.subscription_manager.clone())
335                .with_service_id(self.service_id.clone())
336                .ready()?;
337
338            // Perform the async build operation without holding any lock
339            let session_controller = Arc::new(builder.build()?);
340
341            // Reacquire lock to insert the session
342            let mut pool = self.pool.write();
343
344            // Double-check that the ID wasn't taken while we didn't hold the lock
345            if pool.contains_key(&session_id) {
346                // If a specific ID was provided, return an error
347                if id.is_some() {
348                    return Err(SessionError::SessionIdAlreadyUsed(session_id));
349                }
350                // If ID was randomly generated, retry with a new ID
351                continue;
352            }
353
354            let ret = pool.insert(session_id, session_controller.clone());
355
356            // This should never happen, but just in case
357            if ret.is_some() {
358                error!(
359                    %session_id,
360                    "session ID was taken during insertion: this should not happen",
361                );
362                return Err(SessionError::SessionIdAlreadyUsed(session_id));
363            }
364
365            return Ok(SessionContext::new(session_controller, app_rx));
366        }
367    }
368
369    pub fn listen_from_sessions(
370        &self,
371        mut rx_session: tokio::sync::mpsc::Receiver<Result<SessionMessage, SessionError>>,
372    ) {
373        let pool_clone = self.pool.clone();
374        let sessions_span = tracing::info_span!(parent: None, "listen_from_sessions", service_id = %self.service_id);
375
376        tokio::spawn(async move {
377            loop {
378                tokio::select! {
379                    next = rx_session.recv() => {
380                        match next {
381                            Some(Ok(SessionMessage::DeleteSession { session_id })) => {
382                                debug!(%session_id, "received closing signal, cancel session from the pool");
383                                if pool_clone.write().remove(&session_id).is_none() {
384                                    warn!(%session_id, "requested to delete unknown session");
385                                }
386                            }
387                            Some(Ok(m)) => {
388                                error!(?m, "received unexpected message");
389                            }
390                            Some(Err(e)) => {
391                                warn!(error = %e.chain(), "error from session");
392                            }
393                            None => {
394                                // All senders dropped; exit loop.
395                                break;
396                            }
397                        }
398                    }
399                }
400            }
401        }.instrument(sessions_span));
402    }
403
404    /// Remove a session from the pool and return a handle to optionally wait on
405    #[tracing::instrument(skip_all, fields(service_id = %self.service_id, session_id = id))]
406    pub fn remove_session(&self, id: u32) -> Result<CompletionHandle, SessionError> {
407        debug!(%id, "try to remove session");
408        // get the read lock
409        let binding = self.pool.read();
410        let session = binding.get(&id).ok_or(SessionError::SessionNotFound(id))?;
411
412        // close the session and get the join handle
413        let join_handle = session.close()?;
414
415        // Return a CompletionHandle wrapping the oneshot receiver
416        Ok(CompletionHandle::from_join_handle(join_handle))
417    }
418
419    /// Clear all sessions and return completion handles to await on
420    pub fn clear_all_sessions(&self) -> HashMap<u32, Result<CompletionHandle, SessionError>> {
421        let pool = {
422            let mut pool = self.pool.write();
423            let copy = pool.clone();
424            pool.clear();
425            copy
426        };
427
428        // Close all sessions and return completion handles
429        pool.iter()
430            .map(|(id, session)| {
431                let result = session.close().map(CompletionHandle::from_join_handle);
432                (*id, result)
433            })
434            .collect()
435    }
436
437    /// Handle an error coming from SLIM. Forward it to the corresponding session.
438    #[tracing::instrument(skip_all, fields(service_id = %self.service_id))]
439    pub async fn handle_error_from_slim(&self, error: SessionError) -> Result<(), SessionError> {
440        // Extract context and session ID from the error
441        let Some(session_ctx) = error.session_context() else {
442            debug!(
443                error = %error.chain(),
444                "received error without session context in handle_error_from_slim",
445            );
446            return Ok(());
447        };
448
449        let session_id = session_ctx.session_id;
450        let session_controller = self.pool.read().get(&session_id).cloned();
451
452        if let Some(controller) = session_controller {
453            debug!(
454                error = %error.chain(),
455                session_id = %session_id,
456                "received error from SLIM for session id",
457            );
458
459            // pass the error to the session
460            return controller.on_error_message_from_slim(error).await;
461        }
462
463        debug!(
464            error = %error.chain(),
465            "received error from SLIM for unknown session id",
466        );
467
468        Ok(())
469    }
470
471    /// Handle a message from the message processor, and pass it to the
472    /// corresponding session
473    #[tracing::instrument(skip_all, fields(service_id = %self.service_id))]
474    pub async fn handle_message_from_slim(
475        self: &Arc<Self>,
476        message: Message,
477    ) -> Result<(), SessionError> {
478        tracing::trace!(
479            msg_type = %message.get_session_message_type().as_str_name(),
480            session_id = %message.get_id(),
481            "received message from SLIM",
482        );
483
484        let (id, session_type, session_message_type) = {
485            let header = message.get_session_header();
486            (
487                header.session_id,
488                header.session_type(),
489                header.session_message_type(),
490            )
491        };
492
493        // Fast path: known session — route to its controller. The controller's
494        // processing loop verifies identity in its own task, so sessions don't
495        // serialize behind each other.
496        let session_controller = self.pool.read().get(&id).cloned();
497        if let Some(controller) = session_controller {
498            controller.on_message_from_slim(message).await?;
499
500            if session_message_type == ProtoSessionMessageType::GroupWelcome {
501                let new_session = self
502                    .to_notify
503                    .write()
504                    .remove(&id)
505                    .ok_or(SessionError::NewSessionSendFailed)?;
506                return self
507                    .tx_app
508                    .send(Ok(Notification::NewSession(new_session)))
509                    .await
510                    .map_err(|_e| SessionError::NewSessionSendFailed);
511            }
512
513            return Ok(());
514        }
515
516        // Slow path: no session yet. JoinRequest is processed inline so that
517        // the session is registered before the next message (e.g. GroupWelcome)
518        // arrives on this same receive loop. Its identity will be verified by
519        // the new controller's processing loop (single verify, no replay
520        // collision). Stateless DiscoveryRequest is verified off-task before
521        // replying. Everything else is dropped.
522        match session_message_type {
523            ProtoSessionMessageType::JoinRequest => {
524                self.handle_join_request(message, id, session_type).await
525            }
526            ProtoSessionMessageType::DiscoveryRequest => {
527                self.handle_discovery_request(message, id, session_type, session_message_type)
528            }
529            _ => {
530                tracing::debug!(?message, "received channel message with unknown session id");
531                Ok(())
532            }
533        }
534    }
535
536    fn handle_discovery_request(
537        self: &Arc<Self>,
538        message: Message,
539        id: u32,
540        session_type: ProtoSessionType,
541        session_message_type: ProtoSessionMessageType,
542    ) -> Result<(), SessionError> {
543        let layer = self.clone();
544        tokio::spawn(async move {
545            let _permit = match layer.pre_session_verify_slots.clone().acquire_owned().await {
546                Ok(p) => p,
547                Err(_) => return,
548            };
549
550            if let Err(e) =
551                crate::session_controller::verify_identity(&message, &layer.identity_verifier).await
552            {
553                debug!(
554                    error = %e.chain(),
555                    msg_type = %session_message_type.as_str_name(),
556                    "dropping pre-session message: identity verification failed",
557                );
558                return;
559            }
560
561            let local_name =
562                match layer.get_local_name_for_session(message.get_slim_header().get_dst()) {
563                    Ok(n) => n,
564                    Err(e) => {
565                        debug!(error = %e.chain(), "error handling discovery request");
566                        return;
567                    }
568                };
569
570            let mut reply =
571                match handle_channel_discovery_message(&message, &local_name, id, session_type) {
572                    Ok(r) => r,
573                    Err(e) => {
574                        debug!(error = %e.chain(), "error building discovery reply");
575                        return;
576                    }
577                };
578
579            let identity = match layer.identity_provider.get_token() {
580                Ok(t) => t,
581                Err(e) => {
582                    debug!(error = %e.chain(), "error getting identity token for discovery reply");
583                    return;
584                }
585            };
586            reply.get_slim_header_mut().set_identity(identity);
587            if let Err(e) = layer.tx_slim.send(Ok(reply)).await {
588                debug!(error = %e.chain(), "error sending discovery reply");
589            }
590        });
591
592        Ok(())
593    }
594
595    async fn handle_join_request(
596        &self,
597        message: Message,
598        id: u32,
599        session_type: ProtoSessionType,
600    ) -> Result<(), SessionError> {
601        let local_name = self.get_local_name_for_session(message.get_slim_header().get_dst())?;
602
603        let new_session = match session_type {
604            ProtoSessionType::PointToPoint => {
605                let conf = crate::SessionConfig::from_join_request(
606                    ProtoSessionType::PointToPoint,
607                    message.extract_command_payload()?,
608                    message.get_metadata_map(),
609                    false,
610                )?;
611                self.create_session_internal(conf, local_name, message.get_source(), Some(id))?
612            }
613            ProtoSessionType::Multicast => {
614                let payload = message.extract_join_request()?;
615                if payload.timer_settings.is_none() {
616                    return Err(SessionError::MissingPayload {
617                        context: "timer options",
618                    });
619                }
620                let channel = payload
621                    .channel
622                    .clone()
623                    .ok_or(SessionError::MissingChannelName)?;
624                let conf = crate::SessionConfig::from_join_request(
625                    ProtoSessionType::Multicast,
626                    message.extract_command_payload()?,
627                    message.get_metadata_map(),
628                    false,
629                )?;
630                self.create_session_internal(conf, local_name, channel, Some(id))?
631            }
632            _ => {
633                warn!(
634                    session_type = %session_type.as_str_name(),
635                    "received channel join request with unknown session type",
636                );
637                return Err(SessionError::SessionTypeUnknown(session_type));
638            }
639        };
640
641        let session_controller = new_session
642            .session()
643            .upgrade()
644            .ok_or(SessionError::SessionClosed)?;
645
646        session_controller.on_message_from_slim(message).await?;
647
648        self.to_notify
649            .write()
650            .insert(new_session.session_id(), new_session);
651
652        Ok(())
653    }
654
655    /// Check if the session pool is empty (for testing purposes)
656    pub fn is_pool_empty(&self) -> bool {
657        self.pool.read().is_empty()
658    }
659
660    /// Get the number of sessions in the pool (for testing purposes)
661    pub fn pool_size(&self) -> usize {
662        self.pool.read().len()
663    }
664
665    /// Get a session from the pool (for testing purposes)
666    pub fn get_session(&self, id: u32) -> Option<Arc<SessionController>> {
667        self.pool.read().get(&id).cloned()
668    }
669}
670
671#[cfg(test)]
672mod tests {
673    use super::*;
674    use crate::test_utils::{MockTokenProvider, MockVerifier};
675    use slim_datapath::Status;
676    use slim_datapath::api::{NameId, ProtoName, ProtoSessionType};
677    use tokio::sync::mpsc;
678
679    // --- Test Mocks -----------------------------------------------------------------------
680
681    fn make_name(parts: &[&str; 3]) -> ProtoName {
682        ProtoName::from_strings([parts[0], parts[1], parts[2]]).with_id(0)
683    }
684
685    type TestSessionLayer = Arc<SessionLayer<MockTokenProvider, MockVerifier>>;
686    type SlimReceiver = mpsc::Receiver<Result<Message, Status>>;
687    type AppReceiver = mpsc::Receiver<Result<Notification, SessionError>>;
688
689    fn setup_session_layer() -> (TestSessionLayer, SlimReceiver, AppReceiver) {
690        let app_name = make_name(&["test", "app", "v1"]);
691        let identity_provider = MockTokenProvider;
692        let identity_verifier = MockVerifier;
693        let conn_id = 12345u64;
694
695        let (tx_slim, rx_slim) = mpsc::channel(16);
696        let (tx_app, rx_app) = mpsc::channel(16);
697
698        let session_layer = Arc::new(SessionLayer::new(
699            app_name,
700            identity_provider,
701            identity_verifier,
702            conn_id,
703            tx_slim,
704            tx_app,
705            Direction::Bidirectional,
706            "test-service".to_string(),
707        ));
708
709        (session_layer, rx_slim, rx_app)
710    }
711
712    #[tokio::test]
713    async fn test_new_session_layer() {
714        let (session_layer, _rx_slim, _rx_app) = setup_session_layer();
715
716        assert_eq!(session_layer.app_id(), 0);
717        assert_eq!(session_layer.conn_id(), 12345);
718        assert!(session_layer.is_pool_empty());
719    }
720
721    #[tokio::test]
722    async fn test_add_and_remove_app_name() {
723        let (session_layer, _rx_slim, _rx_app) = setup_session_layer();
724
725        let name1 = make_name(&["service", "v1", "api"]);
726        let name2 = make_name(&["service", "v2", "api"]);
727
728        session_layer.add_app_name(name1.clone(), 0);
729        session_layer.add_app_name(name2.clone(), 0);
730
731        // Verify names are added
732        assert_eq!(session_layer.app_names.read().len(), 3); // initial + 2 added
733
734        session_layer.remove_app_name(&name1);
735        assert_eq!(session_layer.app_names.read().len(), 2);
736
737        session_layer.remove_app_name(&name2);
738        assert_eq!(session_layer.app_names.read().len(), 1);
739    }
740
741    #[tokio::test]
742    async fn test_get_identity_token() {
743        let (session_layer, _rx_slim, _rx_app) = setup_session_layer();
744
745        let token = session_layer.get_identity_token();
746        assert!(token.is_ok());
747        assert_eq!(token.unwrap(), "");
748    }
749
750    #[tokio::test]
751    async fn test_create_session_with_auto_id() {
752        let (session_layer, _rx_slim, _rx_app) = setup_session_layer();
753
754        let local_name = make_name(&["local", "app", "v1"]);
755        let destination = make_name(&["remote", "app", "v1"]);
756        let config = SessionConfig {
757            session_type: ProtoSessionType::PointToPoint,
758            max_retries: Some(3),
759            interval: Some(std::time::Duration::from_secs(1)),
760            mls_settings: None,
761            initiator: true,
762            metadata: Default::default(),
763        };
764
765        let result = session_layer.create_session_internal(config, local_name, destination, None);
766
767        assert!(result.is_ok());
768        assert_eq!(session_layer.pool_size(), 1);
769    }
770
771    #[tokio::test]
772    async fn test_create_session_with_specific_id() {
773        let (session_layer, _rx_slim, _rx_app) = setup_session_layer();
774
775        let local_name = make_name(&["local", "app", "v1"]);
776        let destination = make_name(&["remote", "app", "v1"]);
777        let config = SessionConfig {
778            session_type: ProtoSessionType::PointToPoint,
779            max_retries: Some(3),
780            interval: Some(std::time::Duration::from_secs(1)),
781            mls_settings: None,
782            initiator: true,
783            metadata: Default::default(),
784        };
785
786        let session_id = 100u32;
787        let result = session_layer.create_session_internal(
788            config,
789            local_name,
790            destination,
791            Some(session_id),
792        );
793
794        assert!(result.is_ok());
795        assert_eq!(session_layer.pool_size(), 1);
796
797        let session = session_layer.get_session(session_id);
798        assert!(session.is_some());
799    }
800
801    #[tokio::test]
802    async fn test_create_session_with_invalid_id() {
803        let (session_layer, _rx_slim, _rx_app) = setup_session_layer();
804
805        let local_name = make_name(&["local", "app", "v1"]);
806        let destination = make_name(&["remote", "app", "v1"]);
807        let config = SessionConfig {
808            session_type: ProtoSessionType::PointToPoint,
809            max_retries: Some(3),
810            interval: Some(std::time::Duration::from_secs(1)),
811            mls_settings: None,
812            initiator: true,
813            metadata: Default::default(),
814        };
815
816        // Use an ID outside the SESSION_RANGE (SESSION_RANGE is 0..u32::MAX-1000)
817        let invalid_id = u32::MAX - 500; // This is outside SESSION_RANGE
818        let result = session_layer.create_session_internal(
819            config,
820            local_name,
821            destination,
822            Some(invalid_id),
823        );
824
825        assert!(result.is_err());
826        match result {
827            Err(SessionError::InvalidSessionId(_)) => {}
828            _ => panic!("Expected InvalidSessionId error"),
829        }
830    }
831
832    #[tokio::test]
833    async fn test_create_session_with_duplicate_id() {
834        let (session_layer, _rx_slim, _rx_app) = setup_session_layer();
835
836        let local_name = make_name(&["local", "app", "v1"]);
837        let destination = make_name(&["remote", "app", "v1"]);
838        let config = SessionConfig {
839            session_type: ProtoSessionType::PointToPoint,
840            max_retries: Some(3),
841            interval: Some(std::time::Duration::from_secs(1)),
842            mls_settings: None,
843            initiator: true,
844            metadata: Default::default(),
845        };
846
847        let session_id = 100u32;
848
849        // Create first session
850        let result1 = session_layer.create_session_internal(
851            config.clone(),
852            local_name.clone(),
853            destination.clone(),
854            Some(session_id),
855        );
856        assert!(result1.is_ok());
857
858        // Try to create second session with same ID
859        let result2 = session_layer.create_session_internal(
860            config,
861            local_name,
862            destination,
863            Some(session_id),
864        );
865
866        assert!(result2.is_err());
867        match result2 {
868            Err(SessionError::SessionIdAlreadyUsed(_)) => {}
869            _ => panic!("Expected SessionIdAlreadyUsed error"),
870        }
871    }
872
873    #[tokio::test]
874    async fn test_remove_session() {
875        let (session_layer, _rx_slim, _rx_app) = setup_session_layer();
876
877        let local_name = make_name(&["local", "app", "v1"]);
878        let destination = make_name(&["remote", "app", "v1"]);
879        let config = SessionConfig {
880            session_type: ProtoSessionType::PointToPoint,
881            max_retries: Some(3),
882            interval: Some(std::time::Duration::from_secs(1)),
883            mls_settings: None,
884            initiator: true,
885            metadata: Default::default(),
886        };
887
888        let session_id = 100u32;
889        let _context = session_layer
890            .create_session_internal(config, local_name, destination, Some(session_id))
891            .unwrap();
892
893        assert_eq!(session_layer.pool_size(), 1);
894
895        let removed = session_layer
896            .remove_session(session_id)
897            .expect("error removing connection");
898        // await for the handler
899        removed.await.expect("error awaiting the handler");
900        assert!(session_layer.is_pool_empty());
901
902        // Try to remove non-existent session
903        let removed_again = session_layer.remove_session(session_id);
904        assert!(removed_again.is_err());
905    }
906
907    #[tokio::test]
908    async fn test_get_local_name_for_session() {
909        let (session_layer, _rx_slim, _rx_app) = setup_session_layer();
910
911        let name = make_name(&["service", "api", "v1"]);
912        session_layer.add_app_name(name.clone(), 0);
913
914        let dst = name.with_id(123);
915        let result = session_layer.get_local_name_for_session(dst);
916
917        assert!(result.is_ok());
918        let local_name = result.unwrap();
919        assert_eq!(local_name.id(), session_layer.app_id());
920    }
921
922    #[tokio::test]
923    async fn test_get_local_name_for_session_not_found() {
924        let (session_layer, _rx_slim, _rx_app) = setup_session_layer();
925
926        let unknown_name = make_name(&["unknown", "service", "v1"]);
927        let result = session_layer.get_local_name_for_session(unknown_name);
928
929        assert!(result.is_err());
930        match result {
931            Err(SessionError::SubscriptionNotFound(_)) => {}
932            _ => panic!("Expected SubscriptionNotFound error"),
933        }
934    }
935
936    #[tokio::test]
937    async fn test_tx_slim_and_tx_app_cloning() {
938        let (session_layer, _rx_slim, _rx_app) = setup_session_layer();
939
940        let tx_slim = session_layer.tx_slim();
941        let tx_app = session_layer.tx_app();
942
943        // Just verify that we can clone these channels
944        let _tx_slim2 = tx_slim.clone();
945        let _tx_app2 = tx_app.clone();
946    }
947
948    #[tokio::test]
949    async fn test_handle_discovery_request_without_session() {
950        let (session_layer, mut rx_slim, _rx_app) = setup_session_layer();
951
952        let local_name = make_name(&["local", "app", "v1"]);
953        session_layer.add_app_name(local_name.clone(), 0);
954
955        let source = make_name(&["remote", "app", "v1"]);
956        let message = Message::builder()
957            .source(source.clone())
958            .destination(local_name.clone().with_id(session_layer.app_id()))
959            .identity("")
960            .forward_to(0)
961            .incoming_conn(12345)
962            .session_type(ProtoSessionType::PointToPoint)
963            .session_message_type(ProtoSessionMessageType::DiscoveryRequest)
964            .session_id(100)
965            .message_id(0)
966            .application_payload("", vec![])
967            .build_publish()
968            .unwrap();
969
970        session_layer
971            .handle_message_from_slim(message)
972            .await
973            .unwrap();
974
975        let sent = tokio::time::timeout(std::time::Duration::from_secs(1), rx_slim.recv())
976            .await
977            .expect("expected a discovery reply")
978            .expect("slim channel closed")
979            .expect("slim delivered an error");
980
981        assert_eq!(
982            sent.get_session_header().session_message_type(),
983            ProtoSessionMessageType::DiscoveryReply
984        );
985    }
986
987    #[tokio::test]
988    async fn test_pre_session_unknown_message_is_dropped() {
989        let (session_layer, mut rx_slim, _rx_app) = setup_session_layer();
990
991        let local_name = make_name(&["local", "app", "v1"]);
992        session_layer.add_app_name(local_name.clone(), 0);
993
994        let source = make_name(&["remote", "app", "v1"]);
995        let mut message = Message::builder()
996            .source(source.clone())
997            .destination(local_name.clone().with_id(session_layer.app_id()))
998            .application_payload("application/octet-stream", vec![])
999            .build_publish()
1000            .unwrap();
1001        let header = message.get_session_header_mut();
1002        header.set_session_type(ProtoSessionType::PointToPoint);
1003        header.set_session_message_type(ProtoSessionMessageType::Msg);
1004        header.session_id = 100;
1005
1006        session_layer
1007            .handle_message_from_slim(message)
1008            .await
1009            .unwrap();
1010
1011        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1012        assert!(rx_slim.try_recv().is_err());
1013    }
1014
1015    #[tokio::test]
1016    async fn test_multiple_sessions_in_pool() {
1017        let (session_layer, _rx_slim, _rx_app) = setup_session_layer();
1018
1019        let local_name = make_name(&["local", "app", "v1"]);
1020        let config = SessionConfig {
1021            session_type: ProtoSessionType::PointToPoint,
1022            max_retries: Some(3),
1023            interval: Some(std::time::Duration::from_secs(1)),
1024            mls_settings: None,
1025            initiator: true,
1026            metadata: Default::default(),
1027        };
1028
1029        // Create multiple sessions
1030        for i in 0..5 {
1031            let destination = make_name(&["remote", &format!("app{}", i), "v1"]);
1032            let result = session_layer.create_session_internal(
1033                config.clone(),
1034                local_name.clone(),
1035                destination,
1036                None,
1037            );
1038            assert!(result.is_ok());
1039        }
1040
1041        assert_eq!(session_layer.pool_size(), 5);
1042    }
1043
1044    #[test]
1045    fn test_direction_to_participant_settings() {
1046        let s = Direction::Send.to_participant_settings();
1047        assert!(s.sends_data);
1048        assert!(!s.receives_data);
1049
1050        let s = Direction::Recv.to_participant_settings();
1051        assert!(!s.sends_data);
1052        assert!(s.receives_data);
1053
1054        let s = Direction::Bidirectional.to_participant_settings();
1055        assert!(s.sends_data);
1056        assert!(s.receives_data);
1057
1058        let s = Direction::None.to_participant_settings();
1059        assert!(!s.sends_data);
1060        assert!(!s.receives_data);
1061    }
1062
1063    #[tokio::test]
1064    async fn test_remove_app_name_with_null_component() {
1065        let (session_layer, _rx_slim, _rx_app) = setup_session_layer();
1066
1067        let name = make_name(&["service", "v1", "api"]).with_id(123);
1068        session_layer.add_app_name(name.clone(), 0);
1069
1070        // Remove with specific ID (should normalize to NULL_COMPONENT)
1071        session_layer.remove_app_name(&name);
1072
1073        // The name with NULL_COMPONENT should be removed
1074        let name_null = name.with_id(NameId::NULL_COMPONENT);
1075        assert!(
1076            !session_layer
1077                .app_names
1078                .read()
1079                .contains_key(
1080                    &SessionLayer::<MockTokenProvider, MockVerifier>::name_to_key(&name_null)
1081                )
1082        );
1083    }
1084}