1use std::collections::HashMap;
6use std::sync::Arc;
7
8use display_error_chain::ErrorChainExt;
9use parking_lot::RwLock as SyncRwLock;
11use rand::Rng;
12
13use tokio::sync::Semaphore;
14use tokio::sync::mpsc::Sender;
15use tracing::{Instrument, debug, error, warn};
16
17use slim_auth::traits::{TokenProvider, Verifier};
18use slim_datapath::api::{
19 EncodedName, NameId, ParticipantSettings, ProtoMessage as Message, ProtoName,
20 ProtoSessionMessageType, ProtoSessionType,
21};
22
23use crate::common::SessionMessage;
24use crate::completion_handle::CompletionHandle;
25use crate::notification::Notification;
26use crate::session_config::SessionConfig;
27use crate::session_controller::SessionController;
28use crate::subscription_manager::SubscriptionManager;
29
30use super::context::SessionContext;
32
33use super::{SESSION_RANGE, SlimChannelSender};
34use super::{SessionError, session_controller::handle_channel_discovery_message};
/// Direction of application data flow for the local participant.
///
/// Each variant maps to a pair of "sends data" / "receives data" settings
/// (see `Direction::to_participant_settings`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Direction {
    /// Sends data but does not receive it.
    Send,
    /// Receives data but does not send it.
    Recv,
    /// Both sends and receives data.
    Bidirectional,
    /// Neither sends nor receives data.
    None,
}
44
45impl Direction {
46 pub fn to_flags(self) -> (bool, bool) {
47 match self {
48 Direction::Send => (false, true),
49 Direction::Recv => (true, false),
50 Direction::Bidirectional => (false, false),
51 Direction::None => (true, true),
52 }
53 }
54
55 pub fn to_participant_settings(self) -> ParticipantSettings {
56 match self {
57 Direction::Send => ParticipantSettings {
59 sends_data: true,
60 receives_data: false,
61 },
62 Direction::Recv => ParticipantSettings {
63 sends_data: false,
64 receives_data: true,
65 },
66 Direction::Bidirectional => ParticipantSettings {
67 sends_data: true,
68 receives_data: true,
69 },
70 Direction::None => ParticipantSettings {
71 sends_data: false,
72 receives_data: false,
73 },
74 }
75 }
76}
77
/// Per-application session manager.
///
/// Owns the pool of active sessions, the set of local names the application
/// answers to, and the channels used to talk to SLIM, to the application and
/// to the sessions themselves.
pub struct SessionLayer<P, V>
where
    P: TokenProvider + Send + Sync + Clone + 'static,
    V: Verifier + Send + Sync + Clone + 'static,
{
    /// Active sessions keyed by session id. The `Arc` lets the background
    /// listener task remove sessions when they report `DeleteSession`.
    pool: Arc<SyncRwLock<HashMap<u32, Arc<SessionController>>>>,

    /// Id of the application's primary name, taken from `app_name.id()` at
    /// construction; local names handed to sessions carry this id.
    app_id: u128,

    /// Local names, normalized so their id component is
    /// `NameId::NULL_COMPONENT`, mapped to their subscription id. Incoming
    /// requests are only accepted for destinations present here.
    app_names: SyncRwLock<HashMap<EncodedName, u64>>,

    /// Produces identity tokens; cloned into every session and used to sign
    /// discovery replies.
    identity_provider: P,

    /// Verifies identities; cloned into every session and used to check
    /// pre-session discovery requests.
    identity_verifier: V,

    /// Connection id supplied by the caller; only stored and exposed via
    /// `conn_id()` in this file.
    conn_id: u64,

    /// Outgoing channel towards SLIM.
    tx_slim: SlimChannelSender,
    /// Channel used to deliver notifications (e.g. new sessions) to the app.
    tx_app: Sender<Result<Notification, SessionError>>,

    /// Sender cloned into every session so sessions can report back to the
    /// layer (e.g. `SessionMessage::DeleteSession`).
    tx_session: tokio::sync::mpsc::Sender<Result<SessionMessage, SessionError>>,

    /// Sessions created from join requests that have not been announced to
    /// the application yet; each is sent as `Notification::NewSession` when
    /// its `GroupWelcome` arrives.
    to_notify: SyncRwLock<HashMap<u32, SessionContext>>,

    /// Data-flow direction passed to every session controller.
    direction: Direction,

    /// Subscription manager built on `tx_slim`, cloned into every session.
    subscription_manager: SubscriptionManager,

    /// Identifier of the owning service; recorded on tracing spans and passed
    /// to every session.
    service_id: String,

    /// Bounds how many identity verifications for discovery requests on
    /// unknown sessions may run concurrently.
    pre_session_verify_slots: Arc<Semaphore>,
}
127
128impl<P, V> SessionLayer<P, V>
129where
130 P: TokenProvider + Send + Sync + Clone + 'static,
131 V: Verifier + Send + Sync + Clone + 'static,
132{
133 const PRE_SESSION_VERIFY_SLOTS: usize = 128;
134
135 #[allow(clippy::too_many_arguments)]
137 pub fn new(
138 app_name: ProtoName,
139 identity_provider: P,
140 identity_verifier: V,
141 conn_id: u64,
142 tx_slim: SlimChannelSender,
143 tx_app: Sender<Result<Notification, SessionError>>,
144 direction: Direction,
145 service_id: String,
146 ) -> Self {
147 let (tx_session, rx_session) = tokio::sync::mpsc::channel(16);
148
149 let subscription_manager = SubscriptionManager::new(tx_slim.clone());
150
151 let initial_key = Self::name_to_key(&app_name);
152 let sl = SessionLayer {
153 pool: Arc::new(SyncRwLock::new(HashMap::new())),
154 app_id: app_name.id(),
155 app_names: SyncRwLock::new(HashMap::from([(initial_key, 0)])),
156 identity_provider,
157 identity_verifier,
158 conn_id,
159 tx_slim,
160 tx_app,
161 tx_session,
162 to_notify: SyncRwLock::new(HashMap::new()),
163 direction,
164 subscription_manager,
165 service_id,
166 pre_session_verify_slots: Arc::new(Semaphore::new(Self::PRE_SESSION_VERIFY_SLOTS)),
167 };
168
169 sl.listen_from_sessions(rx_session);
170
171 sl
172 }
173
174 pub fn tx_slim(&self) -> SlimChannelSender {
175 self.tx_slim.clone()
176 }
177
178 pub fn subscription_manager(&self) -> SubscriptionManager {
179 self.subscription_manager.clone()
180 }
181
182 pub fn tx_app(&self) -> Sender<Result<Notification, SessionError>> {
183 self.tx_app.clone()
184 }
185
186 #[allow(dead_code)]
187 pub fn conn_id(&self) -> u64 {
188 self.conn_id
189 }
190
191 pub fn app_id(&self) -> u128 {
192 self.app_id
193 }
194
195 fn name_to_key(name: &ProtoName) -> EncodedName {
197 let enc = name.name.as_ref().unwrap();
198 EncodedName {
199 component_0: enc.component_0,
200 component_1: enc.component_1,
201 component_2: enc.component_2,
202 name_id: Some(NameId::from(NameId::NULL_COMPONENT)),
203 }
204 }
205
206 pub fn add_app_name(&self, name: ProtoName, subscription_id: u64) {
207 let key = Self::name_to_key(&name);
208 self.app_names.write().insert(key, subscription_id);
209 }
210
211 pub fn remove_app_name(&self, name: &ProtoName) -> Option<u64> {
212 let key = Self::name_to_key(name);
213 let removed = self.app_names.write().remove(&key);
214 if removed.is_none() {
215 warn!(%name, "tried to remove unknown app name");
216 }
217 removed
218 }
219
220 fn get_local_name_for_session(&self, dst: ProtoName) -> Result<ProtoName, SessionError> {
221 let key = Self::name_to_key(&dst);
222 if self.app_names.read().contains_key(&key) {
223 Ok(dst.with_id(self.app_id))
224 } else {
225 Err(SessionError::SubscriptionNotFound(dst))
226 }
227 }
228
229 pub fn get_identity_token(&self) -> Result<String, SessionError> {
231 let token = self.identity_provider.get_token()?;
232 Ok(token)
233 }
234
235 #[tracing::instrument(skip_all, fields(service_id = %self.service_id))]
237 pub async fn create_session(
238 &self,
239 mut session_config: SessionConfig,
240 local_name: ProtoName,
241 destination: ProtoName,
242 id: Option<u32>,
243 ) -> Result<(SessionContext, CompletionHandle), SessionError> {
244 session_config.initiator = true;
246
247 let is_p2p = session_config.session_type == ProtoSessionType::PointToPoint;
249 let destination_proto = destination.clone();
250
251 let session = self.create_session_internal(session_config, local_name, destination, id)?;
252
253 let init_ack = if is_p2p {
256 session
257 .session()
258 .upgrade()
259 .ok_or(SessionError::SessionNotFound(u32::MAX))?
260 .invite_participant_internal(&destination_proto)
261 .await
262 .inspect_err(|_| {
263 let _ = self.remove_session(session.session_id());
265 })?
266 } else {
267 let (tx, rx) = tokio::sync::oneshot::channel();
269 let _ = tx.send(Ok(()));
270 CompletionHandle::from_oneshot_receiver(rx)
271 };
272
273 Ok((session, init_ack))
275 }
276
277 fn create_session_internal(
279 &self,
280 session_config: SessionConfig,
281 local_name: ProtoName,
282 destination: ProtoName,
283 id: Option<u32>,
284 ) -> Result<SessionContext, SessionError> {
285 loop {
287 let session_id = {
289 let pool = self.pool.read();
290
291 match id {
293 Some(id) => {
294 if !SESSION_RANGE.contains(&id) {
296 return Err(SessionError::InvalidSessionId(id));
297 }
298
299 if pool.contains_key(&id) {
301 return Err(SessionError::SessionIdAlreadyUsed(id));
302 }
303
304 id
305 }
306 None => {
307 loop {
309 let session_id = rand::rng().random_range(SESSION_RANGE);
310 if !pool.contains_key(&session_id) {
311 break session_id;
312 }
313 }
314 }
315 }
316 }; let (app_tx, app_rx) = tokio::sync::mpsc::unbounded_channel();
320
321 let builder = SessionController::builder()
324 .with_id(session_id)
325 .with_source(local_name.clone())
326 .with_destination(destination.clone())
327 .with_config(session_config.clone())
328 .with_identity_provider(self.identity_provider.clone())
329 .with_identity_verifier(self.identity_verifier.clone())
330 .with_slim_tx(self.tx_slim.clone())
331 .with_app_tx(app_tx)
332 .with_tx_to_session_layer(self.tx_session.clone())
333 .with_direction(self.direction)
334 .with_subscription_manager(self.subscription_manager.clone())
335 .with_service_id(self.service_id.clone())
336 .ready()?;
337
338 let session_controller = Arc::new(builder.build()?);
340
341 let mut pool = self.pool.write();
343
344 if pool.contains_key(&session_id) {
346 if id.is_some() {
348 return Err(SessionError::SessionIdAlreadyUsed(session_id));
349 }
350 continue;
352 }
353
354 let ret = pool.insert(session_id, session_controller.clone());
355
356 if ret.is_some() {
358 error!(
359 %session_id,
360 "session ID was taken during insertion: this should not happen",
361 );
362 return Err(SessionError::SessionIdAlreadyUsed(session_id));
363 }
364
365 return Ok(SessionContext::new(session_controller, app_rx));
366 }
367 }
368
369 pub fn listen_from_sessions(
370 &self,
371 mut rx_session: tokio::sync::mpsc::Receiver<Result<SessionMessage, SessionError>>,
372 ) {
373 let pool_clone = self.pool.clone();
374 let sessions_span = tracing::info_span!(parent: None, "listen_from_sessions", service_id = %self.service_id);
375
376 tokio::spawn(async move {
377 loop {
378 tokio::select! {
379 next = rx_session.recv() => {
380 match next {
381 Some(Ok(SessionMessage::DeleteSession { session_id })) => {
382 debug!(%session_id, "received closing signal, cancel session from the pool");
383 if pool_clone.write().remove(&session_id).is_none() {
384 warn!(%session_id, "requested to delete unknown session");
385 }
386 }
387 Some(Ok(m)) => {
388 error!(?m, "received unexpected message");
389 }
390 Some(Err(e)) => {
391 warn!(error = %e.chain(), "error from session");
392 }
393 None => {
394 break;
396 }
397 }
398 }
399 }
400 }
401 }.instrument(sessions_span));
402 }
403
404 #[tracing::instrument(skip_all, fields(service_id = %self.service_id, session_id = id))]
406 pub fn remove_session(&self, id: u32) -> Result<CompletionHandle, SessionError> {
407 debug!(%id, "try to remove session");
408 let binding = self.pool.read();
410 let session = binding.get(&id).ok_or(SessionError::SessionNotFound(id))?;
411
412 let join_handle = session.close()?;
414
415 Ok(CompletionHandle::from_join_handle(join_handle))
417 }
418
419 pub fn clear_all_sessions(&self) -> HashMap<u32, Result<CompletionHandle, SessionError>> {
421 let pool = {
422 let mut pool = self.pool.write();
423 let copy = pool.clone();
424 pool.clear();
425 copy
426 };
427
428 pool.iter()
430 .map(|(id, session)| {
431 let result = session.close().map(CompletionHandle::from_join_handle);
432 (*id, result)
433 })
434 .collect()
435 }
436
437 #[tracing::instrument(skip_all, fields(service_id = %self.service_id))]
439 pub async fn handle_error_from_slim(&self, error: SessionError) -> Result<(), SessionError> {
440 let Some(session_ctx) = error.session_context() else {
442 debug!(
443 error = %error.chain(),
444 "received error without session context in handle_error_from_slim",
445 );
446 return Ok(());
447 };
448
449 let session_id = session_ctx.session_id;
450 let session_controller = self.pool.read().get(&session_id).cloned();
451
452 if let Some(controller) = session_controller {
453 debug!(
454 error = %error.chain(),
455 session_id = %session_id,
456 "received error from SLIM for session id",
457 );
458
459 return controller.on_error_message_from_slim(error).await;
461 }
462
463 debug!(
464 error = %error.chain(),
465 "received error from SLIM for unknown session id",
466 );
467
468 Ok(())
469 }
470
471 #[tracing::instrument(skip_all, fields(service_id = %self.service_id))]
474 pub async fn handle_message_from_slim(
475 self: &Arc<Self>,
476 message: Message,
477 ) -> Result<(), SessionError> {
478 tracing::trace!(
479 msg_type = %message.get_session_message_type().as_str_name(),
480 session_id = %message.get_id(),
481 "received message from SLIM",
482 );
483
484 let (id, session_type, session_message_type) = {
485 let header = message.get_session_header();
486 (
487 header.session_id,
488 header.session_type(),
489 header.session_message_type(),
490 )
491 };
492
493 let session_controller = self.pool.read().get(&id).cloned();
497 if let Some(controller) = session_controller {
498 controller.on_message_from_slim(message).await?;
499
500 if session_message_type == ProtoSessionMessageType::GroupWelcome {
501 let new_session = self
502 .to_notify
503 .write()
504 .remove(&id)
505 .ok_or(SessionError::NewSessionSendFailed)?;
506 return self
507 .tx_app
508 .send(Ok(Notification::NewSession(new_session)))
509 .await
510 .map_err(|_e| SessionError::NewSessionSendFailed);
511 }
512
513 return Ok(());
514 }
515
516 match session_message_type {
523 ProtoSessionMessageType::JoinRequest => {
524 self.handle_join_request(message, id, session_type).await
525 }
526 ProtoSessionMessageType::DiscoveryRequest => {
527 self.handle_discovery_request(message, id, session_type, session_message_type)
528 }
529 _ => {
530 tracing::debug!(?message, "received channel message with unknown session id");
531 Ok(())
532 }
533 }
534 }
535
536 fn handle_discovery_request(
537 self: &Arc<Self>,
538 message: Message,
539 id: u32,
540 session_type: ProtoSessionType,
541 session_message_type: ProtoSessionMessageType,
542 ) -> Result<(), SessionError> {
543 let layer = self.clone();
544 tokio::spawn(async move {
545 let _permit = match layer.pre_session_verify_slots.clone().acquire_owned().await {
546 Ok(p) => p,
547 Err(_) => return,
548 };
549
550 if let Err(e) =
551 crate::session_controller::verify_identity(&message, &layer.identity_verifier).await
552 {
553 debug!(
554 error = %e.chain(),
555 msg_type = %session_message_type.as_str_name(),
556 "dropping pre-session message: identity verification failed",
557 );
558 return;
559 }
560
561 let local_name =
562 match layer.get_local_name_for_session(message.get_slim_header().get_dst()) {
563 Ok(n) => n,
564 Err(e) => {
565 debug!(error = %e.chain(), "error handling discovery request");
566 return;
567 }
568 };
569
570 let mut reply =
571 match handle_channel_discovery_message(&message, &local_name, id, session_type) {
572 Ok(r) => r,
573 Err(e) => {
574 debug!(error = %e.chain(), "error building discovery reply");
575 return;
576 }
577 };
578
579 let identity = match layer.identity_provider.get_token() {
580 Ok(t) => t,
581 Err(e) => {
582 debug!(error = %e.chain(), "error getting identity token for discovery reply");
583 return;
584 }
585 };
586 reply.get_slim_header_mut().set_identity(identity);
587 if let Err(e) = layer.tx_slim.send(Ok(reply)).await {
588 debug!(error = %e.chain(), "error sending discovery reply");
589 }
590 });
591
592 Ok(())
593 }
594
595 async fn handle_join_request(
596 &self,
597 message: Message,
598 id: u32,
599 session_type: ProtoSessionType,
600 ) -> Result<(), SessionError> {
601 let local_name = self.get_local_name_for_session(message.get_slim_header().get_dst())?;
602
603 let new_session = match session_type {
604 ProtoSessionType::PointToPoint => {
605 let conf = crate::SessionConfig::from_join_request(
606 ProtoSessionType::PointToPoint,
607 message.extract_command_payload()?,
608 message.get_metadata_map(),
609 false,
610 )?;
611 self.create_session_internal(conf, local_name, message.get_source(), Some(id))?
612 }
613 ProtoSessionType::Multicast => {
614 let payload = message.extract_join_request()?;
615 if payload.timer_settings.is_none() {
616 return Err(SessionError::MissingPayload {
617 context: "timer options",
618 });
619 }
620 let channel = payload
621 .channel
622 .clone()
623 .ok_or(SessionError::MissingChannelName)?;
624 let conf = crate::SessionConfig::from_join_request(
625 ProtoSessionType::Multicast,
626 message.extract_command_payload()?,
627 message.get_metadata_map(),
628 false,
629 )?;
630 self.create_session_internal(conf, local_name, channel, Some(id))?
631 }
632 _ => {
633 warn!(
634 session_type = %session_type.as_str_name(),
635 "received channel join request with unknown session type",
636 );
637 return Err(SessionError::SessionTypeUnknown(session_type));
638 }
639 };
640
641 let session_controller = new_session
642 .session()
643 .upgrade()
644 .ok_or(SessionError::SessionClosed)?;
645
646 session_controller.on_message_from_slim(message).await?;
647
648 self.to_notify
649 .write()
650 .insert(new_session.session_id(), new_session);
651
652 Ok(())
653 }
654
655 pub fn is_pool_empty(&self) -> bool {
657 self.pool.read().is_empty()
658 }
659
660 pub fn pool_size(&self) -> usize {
662 self.pool.read().len()
663 }
664
665 pub fn get_session(&self, id: u32) -> Option<Arc<SessionController>> {
667 self.pool.read().get(&id).cloned()
668 }
669}
670
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{MockTokenProvider, MockVerifier};
    use slim_datapath::Status;
    use slim_datapath::api::{NameId, ProtoName, ProtoSessionType};
    use tokio::sync::mpsc;

    type TestSessionLayer = Arc<SessionLayer<MockTokenProvider, MockVerifier>>;
    type SlimReceiver = mpsc::Receiver<Result<Message, Status>>;
    type AppReceiver = mpsc::Receiver<Result<Notification, SessionError>>;

    /// Builds a three-component name with id `0`.
    fn make_name(parts: &[&str; 3]) -> ProtoName {
        ProtoName::from_strings(*parts).with_id(0)
    }

    /// Point-to-point configuration shared by the session-creation tests.
    fn p2p_config() -> SessionConfig {
        SessionConfig {
            session_type: ProtoSessionType::PointToPoint,
            max_retries: Some(3),
            interval: Some(std::time::Duration::from_secs(1)),
            mls_settings: None,
            initiator: true,
            metadata: Default::default(),
        }
    }

    /// The `(local, remote)` name pair used by the session-creation tests.
    fn local_and_remote() -> (ProtoName, ProtoName) {
        (
            make_name(&["local", "app", "v1"]),
            make_name(&["remote", "app", "v1"]),
        )
    }

    /// Builds a bidirectional layer for `test/app/v1` with connection id
    /// 12345, returning it with the receiving ends of its SLIM and app channels.
    fn setup_session_layer() -> (TestSessionLayer, SlimReceiver, AppReceiver) {
        let app_name = make_name(&["test", "app", "v1"]);
        let (tx_slim, rx_slim) = mpsc::channel(16);
        let (tx_app, rx_app) = mpsc::channel(16);

        let layer = SessionLayer::new(
            app_name,
            MockTokenProvider,
            MockVerifier,
            12345u64,
            tx_slim,
            tx_app,
            Direction::Bidirectional,
            "test-service".to_string(),
        );

        (Arc::new(layer), rx_slim, rx_app)
    }

    #[tokio::test]
    async fn test_new_session_layer() {
        let (layer, _rx_slim, _rx_app) = setup_session_layer();

        assert_eq!(layer.app_id(), 0);
        assert_eq!(layer.conn_id(), 12345);
        assert!(layer.is_pool_empty());
    }

    #[tokio::test]
    async fn test_add_and_remove_app_name() {
        let (layer, _rx_slim, _rx_app) = setup_session_layer();
        let registered = || layer.app_names.read().len();

        let first = make_name(&["service", "v1", "api"]);
        let second = make_name(&["service", "v2", "api"]);

        layer.add_app_name(first.clone(), 0);
        layer.add_app_name(second.clone(), 0);
        // The application's own name plus the two just added.
        assert_eq!(registered(), 3);

        layer.remove_app_name(&first);
        assert_eq!(registered(), 2);

        layer.remove_app_name(&second);
        assert_eq!(registered(), 1);
    }

    #[tokio::test]
    async fn test_get_identity_token() {
        let (layer, _rx_slim, _rx_app) = setup_session_layer();

        let token = layer
            .get_identity_token()
            .expect("mock provider should return a token");
        assert_eq!(token, "");
    }

    #[tokio::test]
    async fn test_create_session_with_auto_id() {
        let (layer, _rx_slim, _rx_app) = setup_session_layer();
        let (local, remote) = local_and_remote();

        let created = layer.create_session_internal(p2p_config(), local, remote, None);

        assert!(created.is_ok());
        assert_eq!(layer.pool_size(), 1);
    }

    #[tokio::test]
    async fn test_create_session_with_specific_id() {
        const SESSION_ID: u32 = 100;
        let (layer, _rx_slim, _rx_app) = setup_session_layer();
        let (local, remote) = local_and_remote();

        let created = layer.create_session_internal(p2p_config(), local, remote, Some(SESSION_ID));

        assert!(created.is_ok());
        assert_eq!(layer.pool_size(), 1);
        assert!(layer.get_session(SESSION_ID).is_some());
    }

    #[tokio::test]
    async fn test_create_session_with_invalid_id() {
        let (layer, _rx_slim, _rx_app) = setup_session_layer();
        let (local, remote) = local_and_remote();

        let invalid_id = u32::MAX - 500;
        let outcome = layer.create_session_internal(p2p_config(), local, remote, Some(invalid_id));

        assert!(
            matches!(outcome, Err(SessionError::InvalidSessionId(_))),
            "Expected InvalidSessionId error"
        );
    }

    #[tokio::test]
    async fn test_create_session_with_duplicate_id() {
        const SESSION_ID: u32 = 100;
        let (layer, _rx_slim, _rx_app) = setup_session_layer();
        let (local, remote) = local_and_remote();

        let first = layer.create_session_internal(
            p2p_config(),
            local.clone(),
            remote.clone(),
            Some(SESSION_ID),
        );
        assert!(first.is_ok());

        let second = layer.create_session_internal(p2p_config(), local, remote, Some(SESSION_ID));
        assert!(
            matches!(second, Err(SessionError::SessionIdAlreadyUsed(_))),
            "Expected SessionIdAlreadyUsed error"
        );
    }

    #[tokio::test]
    async fn test_remove_session() {
        const SESSION_ID: u32 = 100;
        let (layer, _rx_slim, _rx_app) = setup_session_layer();
        let (local, remote) = local_and_remote();

        let _context = layer
            .create_session_internal(p2p_config(), local, remote, Some(SESSION_ID))
            .unwrap();
        assert_eq!(layer.pool_size(), 1);

        layer
            .remove_session(SESSION_ID)
            .expect("error removing connection")
            .await
            .expect("error awaiting the handler");
        assert!(layer.is_pool_empty());

        // The session is gone, so a second removal must fail.
        assert!(layer.remove_session(SESSION_ID).is_err());
    }

    #[tokio::test]
    async fn test_get_local_name_for_session() {
        let (layer, _rx_slim, _rx_app) = setup_session_layer();

        let name = make_name(&["service", "api", "v1"]);
        layer.add_app_name(name.clone(), 0);

        let resolved = layer
            .get_local_name_for_session(name.with_id(123))
            .expect("registered name should resolve");
        assert_eq!(resolved.id(), layer.app_id());
    }

    #[tokio::test]
    async fn test_get_local_name_for_session_not_found() {
        let (layer, _rx_slim, _rx_app) = setup_session_layer();

        let outcome =
            layer.get_local_name_for_session(make_name(&["unknown", "service", "v1"]));

        assert!(
            matches!(outcome, Err(SessionError::SubscriptionNotFound(_))),
            "Expected SubscriptionNotFound error"
        );
    }

    #[tokio::test]
    async fn test_tx_slim_and_tx_app_cloning() {
        let (layer, _rx_slim, _rx_app) = setup_session_layer();

        let slim_sender = layer.tx_slim();
        let app_sender = layer.tx_app();

        let _slim_sender_copy = slim_sender.clone();
        let _app_sender_copy = app_sender.clone();
    }

    #[tokio::test]
    async fn test_handle_discovery_request_without_session() {
        let (layer, mut rx_slim, _rx_app) = setup_session_layer();

        let local = make_name(&["local", "app", "v1"]);
        layer.add_app_name(local.clone(), 0);

        let remote = make_name(&["remote", "app", "v1"]);
        let request = Message::builder()
            .source(remote.clone())
            .destination(local.clone().with_id(layer.app_id()))
            .identity("")
            .forward_to(0)
            .incoming_conn(12345)
            .session_type(ProtoSessionType::PointToPoint)
            .session_message_type(ProtoSessionMessageType::DiscoveryRequest)
            .session_id(100)
            .message_id(0)
            .application_payload("", vec![])
            .build_publish()
            .unwrap();

        layer.handle_message_from_slim(request).await.unwrap();

        let reply = tokio::time::timeout(std::time::Duration::from_secs(1), rx_slim.recv())
            .await
            .expect("expected a discovery reply")
            .expect("slim channel closed")
            .expect("slim delivered an error");

        assert_eq!(
            reply.get_session_header().session_message_type(),
            ProtoSessionMessageType::DiscoveryReply
        );
    }

    #[tokio::test]
    async fn test_pre_session_unknown_message_is_dropped() {
        let (layer, mut rx_slim, _rx_app) = setup_session_layer();

        let local = make_name(&["local", "app", "v1"]);
        layer.add_app_name(local.clone(), 0);

        let remote = make_name(&["remote", "app", "v1"]);
        let mut stray = Message::builder()
            .source(remote.clone())
            .destination(local.clone().with_id(layer.app_id()))
            .application_payload("application/octet-stream", vec![])
            .build_publish()
            .unwrap();

        let header = stray.get_session_header_mut();
        header.set_session_type(ProtoSessionType::PointToPoint);
        header.set_session_message_type(ProtoSessionMessageType::Msg);
        header.session_id = 100;

        layer.handle_message_from_slim(stray).await.unwrap();

        // Give any (unexpected) spawned work a chance to send something.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        assert!(rx_slim.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_multiple_sessions_in_pool() {
        let (layer, _rx_slim, _rx_app) = setup_session_layer();

        let local = make_name(&["local", "app", "v1"]);
        let config = p2p_config();

        for index in 0..5 {
            let app = format!("app{}", index);
            let remote = make_name(&["remote", &app, "v1"]);
            let created =
                layer.create_session_internal(config.clone(), local.clone(), remote, None);
            assert!(created.is_ok());
        }

        assert_eq!(layer.pool_size(), 5);
    }

    #[test]
    fn test_direction_to_participant_settings() {
        // (direction, expected sends_data, expected receives_data)
        let cases = [
            (Direction::Send, true, false),
            (Direction::Recv, false, true),
            (Direction::Bidirectional, true, true),
            (Direction::None, false, false),
        ];

        for (direction, sends, receives) in cases {
            let settings = direction.to_participant_settings();
            assert_eq!(settings.sends_data, sends, "{direction:?}");
            assert_eq!(settings.receives_data, receives, "{direction:?}");
        }
    }

    #[tokio::test]
    async fn test_remove_app_name_with_null_component() {
        let (layer, _rx_slim, _rx_app) = setup_session_layer();

        let name = make_name(&["service", "v1", "api"]).with_id(123);
        layer.add_app_name(name.clone(), 0);

        layer.remove_app_name(&name);

        // The stored key is id-agnostic, so the null-id form must be gone too.
        let name_null = name.with_id(NameId::NULL_COMPONENT);
        let key = SessionLayer::<MockTokenProvider, MockVerifier>::name_to_key(&name_null);
        assert!(!layer.app_names.read().contains_key(&key));
    }
}