use std::future::Future;
use std::sync::{Arc, Weak};
use crate::common::AppChannelReceiver;
use crate::session_controller::SessionController;
/// Pairs a weak handle to a [`SessionController`] with the receiver that is
/// created alongside it for the application.
///
/// The context never keeps the session alive on its own: it stores a `Weak`,
/// so the controller is freed as soon as every strong `Arc` elsewhere is gone.
#[derive(Debug)]
pub struct SessionContext {
/// Weak reference to the session controller; upgrading it yields `None`
/// once all strong references have been dropped.
pub session: Weak<SessionController>,
/// Application channel receiver; `spawn_receiver` hands it to the
/// user-supplied closure.
pub rx: AppChannelReceiver,
}
impl SessionContext {
    /// Creates a context from a strong session handle and its receiver.
    ///
    /// Only a weak reference is retained; the `Arc` passed in is released when
    /// this constructor returns, so the caller must hold another strong
    /// reference for the session to stay alive.
    pub fn new(session: Arc<SessionController>, rx: AppChannelReceiver) -> Self {
        let session = Arc::downgrade(&session);
        Self { session, rx }
    }

    /// Borrows the weak session handle stored in this context.
    pub fn session(&self) -> &Weak<SessionController> {
        &self.session
    }

    /// Attempts to obtain a strong handle to the session.
    ///
    /// Returns `None` when the session controller has already been dropped.
    pub fn session_arc(&self) -> Option<Arc<SessionController>> {
        Weak::upgrade(&self.session)
    }

    /// Consumes the context and returns its weak session handle and receiver.
    pub fn into_parts(self) -> (Weak<SessionController>, AppChannelReceiver) {
        let Self { session, rx } = self;
        (session, rx)
    }

    /// Consumes the context and runs `f` on a newly spawned Tokio task.
    ///
    /// `f` is invoked inside the task with the receiver and a clone of the weak
    /// session handle, and the future it returns is awaited there. The task's
    /// join handle is discarded, so the task is detached. The weak session
    /// handle is returned to the caller.
    ///
    /// # Panics
    ///
    /// `tokio::spawn` panics when called outside of a Tokio runtime.
    pub fn spawn_receiver<F, Fut>(self, f: F) -> Weak<SessionController>
    where
        F: FnOnce(AppChannelReceiver, Weak<SessionController>) -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let Self { session, rx } = self;
        // The task gets its own weak handle; the original goes back to the caller.
        let task_session = Weak::clone(&session);
        tokio::spawn(async move {
            f(rx, task_session).await;
        });
        session
    }

    /// Returns the id of the session, or `0` when the session has been dropped.
    ///
    /// Note that `0` is therefore indistinguishable from a live session whose
    /// id happens to be `0`.
    pub fn session_id(&self) -> u32 {
        match self.session.upgrade() {
            Some(controller) => controller.id(),
            None => 0,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::AppChannelSender;
    use crate::session_config::SessionConfig;
    use crate::session_controller::SessionController;
    use crate::test_utils::{MockTokenProvider, MockVerifier};
    use crate::{SessionError, SessionMessage};
    use slim_datapath::api::ProtoName as Name;
    use slim_datapath::api::ProtoSessionType;
    use tokio::sync::mpsc;
    use tokio::sync::oneshot;

    /// Builds a three-component name with its id set to 0.
    fn make_name(parts: [&str; 3]) -> Name {
        Name::from_strings(parts).with_id(0)
    }

    /// Builds a point-to-point `SessionController` with the given id whose
    /// application sender is `app_tx`.
    ///
    /// The receivers of the SLIM and session-layer channels are dropped when
    /// this helper returns; only the senders are handed to the builder.
    async fn build_session_controller_with_app_tx(
        id: u32,
        app_tx: AppChannelSender,
    ) -> Arc<SessionController> {
        use crate::SlimChannelSender;

        let source = make_name(["a", "b", "c"]);
        let destination = make_name(["x", "y", "z"]);
        let cfg = SessionConfig {
            session_type: ProtoSessionType::PointToPoint,
            max_retries: Some(3),
            interval: Some(std::time::Duration::from_secs(1)),
            mls_settings: None,
            initiator: false,
            metadata: Default::default(),
        };
        let (slim_tx, _slim_rx): (SlimChannelSender, _) = mpsc::channel(32);
        let (tx_session, _rx_session): (mpsc::Sender<Result<SessionMessage, SessionError>>, _) =
            mpsc::channel(32);

        // `destination` and `app_tx` are not used afterwards, so they are moved
        // into the builder instead of being cloned.
        Arc::new(
            SessionController::builder()
                .with_id(id)
                .with_source(source)
                .with_destination(destination)
                .with_config(cfg)
                .with_identity_provider(MockTokenProvider)
                .with_identity_verifier(MockVerifier)
                .with_slim_tx(slim_tx)
                .with_app_tx(app_tx)
                .with_tx_to_session_layer(tx_session)
                .ready()
                .expect("Failed to prepare SessionController builder")
                .build()
                .expect("Failed to create SessionController"),
        )
    }

    /// A freshly built context can upgrade its weak handle while the
    /// controller is still alive.
    #[tokio::test]
    async fn context_new_and_upgrade() {
        let (tx_app, rx_app) = mpsc::unbounded_channel();
        let session_controller = build_session_controller_with_app_tx(1, tx_app).await;
        let ctx = SessionContext::new(session_controller.clone(), rx_app);
        assert!(ctx.session_arc().is_some());
    }

    /// `spawn_receiver` runs the closure, and the closure sees a live session.
    #[tokio::test]
    async fn context_spawn_receiver_runs_closure() {
        let (tx_app, rx_app) = mpsc::unbounded_channel();
        let session_controller = build_session_controller_with_app_tx(3, tx_app).await;
        let ctx = SessionContext::new(session_controller.clone(), rx_app);

        // A completion signal replaces a fixed sleep, so the test neither races
        // the spawned task nor waits longer than necessary.
        let (ran_tx, ran_rx) = oneshot::channel();
        let weak = ctx.spawn_receiver(move |_rx, s| async move {
            // Report the observation back: asserting here would only panic the
            // detached task and surface as a misleading timeout.
            let _ = ran_tx.send(s.upgrade().is_some());
        });
        assert!(weak.upgrade().is_some());

        let session_alive = tokio::time::timeout(std::time::Duration::from_secs(1), ran_rx)
            .await
            .expect("closure not executed")
            .expect("closure finished without reporting its result");
        assert!(
            session_alive,
            "closure should observe a live session while a strong ref exists"
        );
    }

    /// The weak handle returned by `spawn_receiver` stops upgrading once the
    /// last strong reference is dropped.
    #[tokio::test]
    async fn context_spawn_receiver_drops_session() {
        let (tx_app, rx_app) = mpsc::unbounded_channel();
        let session_controller = build_session_controller_with_app_tx(4, tx_app).await;
        let ctx = SessionContext::new(session_controller.clone(), rx_app);
        let weak = ctx.spawn_receiver(|_rx, s| async move {
            let _ = s;
        });
        drop(session_controller);
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        assert!(
            weak.upgrade().is_none(),
            "session should be dropped when last strong ref gone"
        );
    }

    /// The receiver loop terminates after the test's sender and the session
    /// controller have both been dropped.
    #[tokio::test]
    async fn context_spawn_receiver_task_finishes_on_session_drop() {
        let (tx_app, rx_app) = mpsc::unbounded_channel();
        let session_controller = build_session_controller_with_app_tx(5, tx_app.clone()).await;
        let ctx = SessionContext::new(session_controller.clone(), rx_app);
        let (done_tx, done_rx) = oneshot::channel();
        let weak = ctx.spawn_receiver(move |mut rx, _s| async move {
            while rx.recv().await.is_some() {}
            let _ = done_tx.send(());
        });
        drop(tx_app);
        drop(session_controller);
        // The outer `expect` covers the timeout; the inner one fails the test
        // if the task went away without sending its completion signal.
        tokio::time::timeout(std::time::Duration::from_millis(200), done_rx)
            .await
            .expect("receiver task did not finish after channel close")
            .expect("receiver task ended without signalling completion");
        assert!(weak.upgrade().is_none(), "session Arc should be gone");
    }

    /// Exercises every `Verifier` method on the mock so they are covered.
    #[tokio::test]
    async fn dummy_verifier_trait_methods_coverage() {
        use slim_auth::traits::Verifier;

        let verifier = MockVerifier;
        verifier.verify("some-token").await.unwrap();
        verifier.try_verify("some-token").unwrap();
        let _: Result<String, _> = verifier.get_claims("some-token").await;
        let _: Result<String, _> = verifier.try_get_claims("some-token");
    }
}