use std::{collections::HashMap, time::Duration};
use display_error_chain::ErrorChainExt;
use parking_lot::Mutex;
use tokio::sync::{self, oneshot};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug};
use slim_auth::traits::{TokenProvider, Verifier};
use slim_datapath::{
api::{
CommandPayload, Content, NameId, ProtoMessage as Message, ProtoName,
ProtoSessionMessageType, ProtoSessionType, SlimHeader,
},
messages::utils::SlimHeaderFlags,
};
use crate::{
MessageDirection, SessionError,
common::{OutboundMessage, SessionMessage, SessionOutput},
completion_handle::CompletionHandle,
controller_sender::{ControllerSender, PING_INTERVAL},
session_builder::{ForController, SessionBuilder},
session_config::SessionConfig,
session_settings::SessionSettings,
traits::{MessageHandler, ProcessingState},
};
/// Validates the identity token found in the SLIM header of `msg` with `verifier`.
///
/// The cheap synchronous `try_verify` check runs first. Only when it reports an
/// error is the asynchronous `verify` awaited; its failure is returned to the caller.
///
/// # Errors
/// Whatever `verifier.verify` returns, converted into [`SessionError`].
pub(crate) async fn verify_identity<V>(msg: &Message, verifier: &V) -> Result<(), SessionError>
where
    V: Verifier + Send + Sync,
{
    let token = msg.get_slim_header().get_identity();
    // Collapse the fast-path result to a bool so nothing from it is held across the await.
    let fast_path_accepted = verifier.try_verify(&token).is_ok();
    if !fast_path_accepted {
        verifier.verify(&token).await?;
    }
    Ok(())
}
/// Handle to a running session.
///
/// Constructing one via `from_parts` spawns the session's processing loop as a
/// background task. This struct keeps what is needed to feed that task
/// (`tx_controller`), to ask it to stop (`cancellation_token`) and to join it
/// (`handle`).
pub struct SessionController {
// Session identifier, returned by `id()` and stamped on outgoing messages.
pub(crate) id: u32,
// Local name used as the `source` of messages built by this controller.
pub(crate) source: ProtoName,
// Destination name supplied at construction, returned by `dst()`.
pub(crate) destination: ProtoName,
// Session configuration (type, retries, interval, MLS, initiator flag, metadata).
pub(crate) config: SessionConfig,
// Enqueues `SessionMessage`s for the processing loop. Presumably paired with the
// receiver handed to `from_parts` — NOTE(review): confirm at the builder call site.
tx_controller: sync::mpsc::Sender<SessionMessage>,
// Cancelled by `close()` and on drop; the processing loop starts draining when it fires.
pub(crate) cancellation_token: CancellationToken,
// Join handle of the spawned processing loop; `close()` takes it, leaving `None`.
handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
}
impl SessionController {
/// Starts building a [`SessionController`] through the controller-flavoured
/// [`SessionBuilder`].
pub fn builder<P, V>() -> SessionBuilder<P, V, ForController>
where
    P: TokenProvider + Send + Sync + Clone + 'static,
    V: Verifier + Send + Sync + Clone + 'static,
{
    SessionBuilder::for_controller()
}

/// Assembles a controller from its already-validated parts and spawns the
/// processing loop that drives `inner`.
///
/// `rx` is moved into the spawned loop; `tx` is kept so the public API can
/// enqueue messages for it. The loop runs inside a detached `debug` span
/// carrying the session identity fields.
// One parameter per builder output; grouping them would only add a throwaway struct.
#[allow(clippy::too_many_arguments)]
pub(crate) fn from_parts<I, P, V, M>(
    id: u32,
    source: ProtoName,
    destination: ProtoName,
    config: SessionConfig,
    settings: SessionSettings<P, V, M>,
    tx: sync::mpsc::Sender<SessionMessage>,
    rx: sync::mpsc::Receiver<SessionMessage>,
    inner: I,
) -> Self
where
    I: MessageHandler + Send + Sync + 'static,
    P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static,
    V: slim_auth::traits::Verifier + Send + Sync + Clone + 'static,
    M: crate::subscription_manager::SubscriptionOps,
{
    let cancellation_token = CancellationToken::new();
    let loop_token = cancellation_token.clone();

    // `parent: None` keeps the long-lived loop span from nesting under the caller's span.
    let loop_span = tracing::debug_span!(
        parent: None,
        "session_controller_processing_loop",
        session_id = id,
        service_id = %settings.service_id,
        source = %source,
        destination = %destination,
        session_type = ?config.session_type
    );
    let task = Self::processing_loop(inner, rx, loop_token, settings).instrument(loop_span);
    let handle = Mutex::new(Some(crate::runtime::spawn(task)));

    Self {
        id,
        source,
        destination,
        config,
        tx_controller: tx,
        cancellation_token,
        handle,
    }
}
/// Re-arms `shutdown_deadline` to fire one graceful-shutdown timeout from now.
///
/// The timeout is `settings.graceful_shutdown_timeout`, or 60 seconds when unset.
fn enter_draining_state<P, V, M>(
    shutdown_deadline: &mut std::pin::Pin<&mut tokio::time::Sleep>,
    settings: &SessionSettings<P, V, M>,
) where
    P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static,
    V: slim_auth::traits::Verifier + Send + Sync + Clone + 'static,
    M: crate::subscription_manager::SubscriptionOps,
{
    let grace = match settings.graceful_shutdown_timeout {
        Some(configured) => configured,
        None => Duration::from_secs(60),
    };
    let fire_at = tokio::time::Instant::now() + grace;
    shutdown_deadline.as_mut().reset(fire_at);
}
/// Stamps a fresh identity token from `identity_provider` into the SLIM header
/// of every SLIM-bound message in `output`; application-bound entries are untouched.
///
/// # Errors
/// Propagates the provider's `get_token` failure. The token is requested up
/// front, so this happens even when `output` holds no SLIM-bound messages.
pub(crate) fn apply_identity_to_slim_output<P>(
    output: &mut SessionOutput,
    identity_provider: &P,
) -> Result<(), SessionError>
where
    P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static,
{
    let token = identity_provider.get_token()?;
    for outbound in &mut output.messages {
        let OutboundMessage::ToSlim(slim_msg) = outbound else {
            continue;
        };
        slim_msg.get_slim_header_mut().set_identity(token.clone());
    }
    Ok(())
}
/// Delivers each message in `output`, in order, to its target.
///
/// `ToSlim` messages go to `settings.slim_tx` (awaited) and `ToApp` results go
/// to `settings.app_tx`. Delivery failures are logged and do not stop the
/// remaining messages from being sent.
async fn dispatch_output<P, V, M>(output: SessionOutput, settings: &SessionSettings<P, V, M>)
where
    P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static,
    V: slim_auth::traits::Verifier + Send + Sync + Clone + 'static,
    M: crate::subscription_manager::SubscriptionOps,
{
    for outbound in output.messages {
        match outbound {
            OutboundMessage::ToApp(result) => {
                let delivery = settings.app_tx.send(result);
                if let Err(e) = delivery {
                    tracing::error!(error = %e, "failed to send message to application");
                }
            }
            OutboundMessage::ToSlim(message) => {
                let delivery = settings.slim_tx.send(Ok(message)).await;
                if let Err(e) = delivery {
                    tracing::error!(error = %e, "failed to send message to SLIM");
                }
            }
        }
    }
}
/// Drives a session: feeds queued [`SessionMessage`]s to `inner` and dispatches
/// the resulting output until the session shuts down.
///
/// * `inner.init()` is awaited once up front; a failure is logged, not fatal.
/// * Northbound `OnMessage`s that fail identity verification are dropped.
/// * `GetParticipantsList` requests are answered directly from
///   `inner.participants_list()` and never reach `inner.on_message`.
/// * When `cancellation_token` fires while `inner` is `Active`:
///   1. every message already queued in `rx` is processed;
///   2. `SessionMessage::StartDrain` is delivered to `inner`;
///   3. the shutdown deadline is armed.
///
///   An error in step 1 or 2 exits the loop immediately.
/// * While `Draining`, new southbound (application) messages are rejected with
///   `SessionError::SessionDrainingDrop`, and any error from `inner` exits the loop.
/// * The loop also exits when draining completes (`Draining` and
///   `!inner.needs_drain()`), when the shutdown deadline elapses, or when `rx` closes.
///
/// `inner.on_shutdown()` is awaited on every exit path.
async fn processing_loop<P, V, M>(
    mut inner: impl MessageHandler + 'static,
    mut rx: sync::mpsc::Receiver<SessionMessage>,
    cancellation_token: CancellationToken,
    settings: SessionSettings<P, V, M>,
) where
    P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static,
    V: slim_auth::traits::Verifier + Send + Sync + Clone + 'static,
    M: crate::subscription_manager::SubscriptionOps,
{
    // Effectively disarmed until `enter_draining_state` resets it.
    let mut shutdown_deadline = std::pin::pin!(tokio::time::sleep(Duration::MAX));
    // Loop-invariant; same default as the one applied by `enter_draining_state`.
    let shutdown_timeout = settings
        .graceful_shutdown_timeout
        .unwrap_or(Duration::from_secs(60));

    if let Err(e) = inner.init().await {
        tracing::error!(error = %e.chain(), "error during initialization of session");
    }

    'processing: loop {
        tokio::select! {
            // NOTE(review): once the token is cancelled, `cancelled()` completes
            // immediately, so this branch is ready on every iteration for as long as
            // `inner` stays `Active`. It relies on `StartDrain` moving `inner` out of
            // `Active` — confirm every `MessageHandler` does so.
            _ = cancellation_token.cancelled(), if inner.processing_state() == ProcessingState::Active => {
                debug!("consuming pending messages before entering draining state");
                while let Ok(msg) = rx.try_recv() {
                    // Same special case as the main receive path: participant queries
                    // are answered here instead of being forwarded to `inner`.
                    if let SessionMessage::GetParticipantsList { tx } = msg {
                        let _ = tx.send(inner.participants_list());
                        continue;
                    }
                    if let SessionMessage::OnMessage { message, direction: MessageDirection::North, .. } = &msg
                        && let Err(e) = crate::session_controller::verify_identity(message, &settings.identity_verifier).await
                    {
                        debug!(error = %e.chain(), "dropping inbound message during drain: identity verification failed");
                        continue;
                    }
                    match inner.on_message(msg).await {
                        Ok(output) => Self::dispatch_output(output, &settings).await,
                        Err(e) => {
                            tracing::error!(error = %e.chain(), "error processing message during draining - close immediately.");
                            // Leave the whole processing loop (not just this drain loop),
                            // as the log line promises.
                            break 'processing;
                        }
                    }
                }
                match inner.on_message(SessionMessage::StartDrain {
                    grace_period: shutdown_timeout,
                }).await {
                    Ok(output) => Self::dispatch_output(output, &settings).await,
                    Err(e) => {
                        tracing::error!(error = %e.chain(), "error during start drain");
                        break 'processing;
                    }
                }
                Self::enter_draining_state(&mut shutdown_deadline, &settings);
                debug!("cancellation requested, entering draining state");
            }
            _ = &mut shutdown_deadline => {
                debug!("graceful shutdown timeout reached, forcing exit");
                break 'processing;
            }
            msg = rx.recv() => {
                let Some(session_message) = msg else {
                    debug!("Session channel closed, no more messages can arrive - exiting processing loop");
                    break 'processing;
                };
                if let SessionMessage::GetParticipantsList { tx } = session_message {
                    let _ = tx.send(inner.participants_list());
                    continue;
                }
                if let SessionMessage::OnMessage { message, direction: MessageDirection::North, .. } = &session_message
                    && let Err(e) = crate::session_controller::verify_identity(message, &settings.identity_verifier).await
                {
                    debug!(
                        error = %e.chain(),
                        msg_type = %message.get_session_message_type().as_str_name(),
                        msg_id = %message.get_id(),
                        "dropping inbound message: identity verification failed",
                    );
                    continue;
                }
                let draining = inner.processing_state() == ProcessingState::Draining;
                if draining && matches!(session_message, SessionMessage::OnMessage { direction: MessageDirection::South, .. }) {
                    tracing::debug!("session is draining, rejecting new messages from application");
                    if let SessionMessage::OnMessage { ack_tx: Some(ack_tx), .. } = session_message {
                        let _ = ack_tx.send(Err(SessionError::SessionDrainingDrop));
                    }
                    continue;
                }
                match inner.on_message(session_message).await {
                    Ok(output) => {
                        Self::dispatch_output(output, &settings).await;
                        // `inner` may switch itself to draining while handling a message.
                        if !draining && inner.processing_state() == ProcessingState::Draining {
                            debug!("internal component requested draining, entering draining state");
                            Self::enter_draining_state(&mut shutdown_deadline, &settings);
                        }
                    }
                    Err(e) => {
                        debug!(
                            error = %e.chain(),
                            "Error processing message{}",
                            if draining { " during graceful shutdown" } else { "" }
                        );
                        if draining {
                            debug!("Exiting processing loop due to error while draining");
                            break 'processing;
                        }
                    }
                }
            }
        }

        if inner.processing_state() == ProcessingState::Draining && !inner.needs_drain() {
            debug!("draining complete, exiting processing loop");
            break 'processing;
        }
    }

    if let Err(e) = inner.on_shutdown().await {
        tracing::error!(error = %e.chain(), "error during shutdown of session");
    }
}
pub fn id(&self) -> u32 {
self.id
}
pub fn source(&self) -> &ProtoName {
&self.source
}
pub fn dst(&self) -> &ProtoName {
&self.destination
}
pub fn session_type(&self) -> ProtoSessionType {
self.config.session_type
}
pub fn metadata(&self) -> HashMap<String, String> {
self.config.metadata.clone()
}
pub fn session_config(&self) -> SessionConfig {
self.config.clone()
}
pub fn is_initiator(&self) -> bool {
self.config.initiator
}
/// Asks the processing loop for the current participant list and waits for the reply.
///
/// # Errors
/// `SessionError::ParticipantsListQueryFailed` if the request cannot be queued or
/// the reply channel is dropped before answering.
pub async fn participants_list(&self) -> Result<Vec<ProtoName>, SessionError> {
    let (reply_tx, reply_rx) = oneshot::channel();
    let request = SessionMessage::GetParticipantsList { tx: reply_tx };
    if self.tx_controller.send(request).await.is_err() {
        return Err(SessionError::ParticipantsListQueryFailed);
    }
    match reply_rx.await {
        Ok(participants) => Ok(participants),
        Err(_) => Err(SessionError::ParticipantsListQueryFailed),
    }
}
async fn on_message(
&self,
message: Message,
direction: MessageDirection,
ack_tx: Option<oneshot::Sender<Result<(), SessionError>>>,
) -> Result<(), SessionError> {
self.tx_controller
.send(SessionMessage::OnMessage {
message,
direction,
ack_tx,
})
.await
.map_err(|_e| SessionError::SessionControllerSendFailed)
}
pub async fn on_message_from_app(
&self,
message: Message,
) -> Result<CompletionHandle, SessionError> {
let (ack_tx, ack_rx) = oneshot::channel();
self.on_message(message, MessageDirection::South, Some(ack_tx))
.await?;
let ret = CompletionHandle::from_oneshot_receiver(ack_rx);
Ok(ret)
}
pub async fn on_message_from_slim(&self, message: Message) -> Result<(), SessionError> {
self.on_message(message, MessageDirection::North, None)
.await
}
pub async fn on_error_message_from_slim(
&self,
error: SessionError,
) -> Result<(), SessionError> {
self.tx_controller
.send(SessionMessage::MessageError { error })
.await
.map_err(|_e| SessionError::SessionControllerSendFailed)
}
/// Requests shutdown of the processing loop and hands back its join handle.
///
/// The cancellation token is cancelled on every call. The handle itself can be
/// taken only once.
///
/// # Errors
/// `SessionError::SessionAlreadyClosed` if the handle was taken by an earlier call.
pub fn close(&self) -> Result<tokio::task::JoinHandle<()>, SessionError> {
    self.cancellation_token.cancel();
    let mut slot = self.handle.lock();
    match slot.take() {
        Some(join_handle) => Ok(join_handle),
        None => Err(SessionError::SessionAlreadyClosed),
    }
}
/// Publishes an already-built message through the application (southbound) path.
pub async fn publish_message(
    &self,
    message: Message,
) -> Result<CompletionHandle, SessionError> {
    self.on_message_from_app(message).await
}

/// Publishes `blob` to `name`, setting the header's forward-to field to `forward_to`.
pub async fn publish_to(
    &self,
    name: &ProtoName,
    forward_to: u64,
    blob: Vec<u8>,
    payload_type: Option<String>,
    metadata: Option<HashMap<String, String>>,
) -> Result<CompletionHandle, SessionError> {
    let flags = SlimHeaderFlags::default().with_forward_to(forward_to);
    self.publish_with_flags(name, flags, blob, payload_type, metadata)
        .await
}

/// Publishes `blob` to `name` with default header flags.
pub async fn publish(
    &self,
    name: &ProtoName,
    blob: Vec<u8>,
    payload_type: Option<String>,
    metadata: Option<HashMap<String, String>>,
) -> Result<CompletionHandle, SessionError> {
    let flags = SlimHeaderFlags::default();
    self.publish_with_flags(name, flags, blob, payload_type, metadata)
        .await
}

/// Builds a `Msg` message carrying `blob` and publishes it.
///
/// `payload_type` defaults to `"msg"`. `metadata` is attached only when present
/// and non-empty. The message id is random.
///
/// # Errors
/// Message-building failures, or a failure to queue the message for the session.
pub async fn publish_with_flags(
    &self,
    name: &ProtoName,
    flags: SlimHeaderFlags,
    blob: Vec<u8>,
    payload_type: Option<String>,
    metadata: Option<HashMap<String, String>>,
) -> Result<CompletionHandle, SessionError> {
    let content_type = match payload_type {
        Some(kind) => kind,
        None => String::from("msg"),
    };
    let message_id: u32 = rand::random();
    let mut msg = Message::builder()
        .source(self.source().clone())
        .destination(name.clone())
        .identity("")
        .flags(flags)
        .session_type(self.session_type())
        .session_message_type(ProtoSessionMessageType::Msg)
        .session_id(self.id())
        .message_id(message_id)
        .application_payload(&content_type, blob)
        .build_publish()?;
    if let Some(map) = metadata.filter(|m| !m.is_empty()) {
        msg.set_metadata_map(map);
    }
    self.publish_message(msg).await
}
/// Builds a `DiscoveryRequest` from this session's source to `destination`,
/// with a random message id.
fn create_discovery_request(&self, destination: &ProtoName) -> Result<Message, SessionError> {
    let discovery = CommandPayload::builder().discovery_request().as_content();
    let message_id: u32 = rand::random();
    Ok(Message::builder()
        .source(self.source().clone())
        .destination(destination.clone())
        .identity("")
        .session_type(self.session_type())
        .session_message_type(ProtoSessionMessageType::DiscoveryRequest)
        .session_id(self.id())
        .message_id(message_id)
        .payload(discovery)
        .build_publish()?)
}

/// Sends a discovery request to `destination` without checking session type or role.
pub(crate) async fn invite_participant_internal(
    &self,
    destination: &ProtoName,
) -> Result<CompletionHandle, SessionError> {
    let request = self.create_discovery_request(destination)?;
    self.publish_message(request).await
}

/// Invites `destination` into a multicast session.
///
/// # Errors
/// * `CannotInviteToP2P` for point-to-point sessions.
/// * `NotInitiator` when this side did not initiate the multicast session.
/// * `SessionTypeUnknown` for any other session type.
pub async fn invite_participant(
    &self,
    destination: &ProtoName,
) -> Result<CompletionHandle, SessionError> {
    let kind = self.session_type();
    match kind {
        ProtoSessionType::Multicast if self.is_initiator() => {
            self.invite_participant_internal(destination).await
        }
        ProtoSessionType::Multicast => Err(SessionError::NotInitiator),
        ProtoSessionType::PointToPoint => Err(SessionError::CannotInviteToP2P),
        other => Err(SessionError::SessionTypeUnknown(other)),
    }
}
/// Asks `destination` to leave a multicast session by publishing a `LeaveRequest`.
///
/// The request targets `destination` with its id replaced by `NameId::NULL_COMPONENT`.
///
/// # Errors
/// * `CannotRemoveFromP2P` for point-to-point sessions.
/// * `NotInitiator` when this side did not initiate the multicast session.
/// * `SessionTypeUnknown` for any other session type.
pub async fn remove_participant(
    &self,
    destination: &ProtoName,
) -> Result<CompletionHandle, SessionError> {
    let kind = self.session_type();
    match kind {
        ProtoSessionType::PointToPoint => Err(SessionError::CannotRemoveFromP2P),
        ProtoSessionType::Multicast if !self.is_initiator() => Err(SessionError::NotInitiator),
        ProtoSessionType::Multicast => {
            let target = destination.clone().with_id(NameId::NULL_COMPONENT);
            let message_id: u32 = rand::random();
            let leave = Message::builder()
                .source(self.source().clone())
                .destination(target)
                .identity("")
                .session_type(ProtoSessionType::Multicast)
                .session_message_type(ProtoSessionMessageType::LeaveRequest)
                .session_id(self.id())
                .message_id(message_id)
                .payload(CommandPayload::builder().leave_request().as_content())
                .build_publish()?;
            self.publish_message(leave).await
        }
        other => Err(SessionError::SessionTypeUnknown(other)),
    }
}
}
impl Drop for SessionController {
    /// Signals the processing loop to start shutting down.
    ///
    /// The loop is not awaited here. If `close()` was never called, the stored
    /// `JoinHandle` is simply dropped along with the controller.
    fn drop(&mut self) {
        let token = &self.cancellation_token;
        token.cancel();
    }
}
/// Builds the `DiscoveryReply` for a received discovery request.
///
/// The reply header is derived from the request:
/// * its destination is the request's source;
/// * its source is the request's destination with the id replaced by `app_name.id()`;
/// * it is forwarded back over the connection the request arrived on;
/// * it reuses the request's message id.
///
/// # Panics
/// If the request's SLIM header lacks a source or destination name.
pub fn handle_channel_discovery_message(
    message: &Message,
    app_name: &ProtoName,
    session_id: u32,
    session_type: ProtoSessionType,
) -> Result<Message, SessionError> {
    let request_header = message.get_slim_header();
    // NOTE(review): these header fields come from a message received off the wire;
    // `unwrap` panics if either is absent. Consider a dedicated `SessionError` variant.
    let reply_to = request_header.source.clone().unwrap();
    let mut reply_from = request_header.destination.clone().unwrap();
    reply_from.set_id(app_name.id());

    let reply_flags = SlimHeaderFlags::default().with_forward_to(message.get_incoming_conn());
    let reply_header = SlimHeader::new(reply_from, reply_to, "", Some(reply_flags));

    let reply = Message::builder()
        .with_slim_header(reply_header)
        .session_type(session_type)
        .session_message_type(ProtoSessionMessageType::DiscoveryReply)
        .session_id(session_id)
        .message_id(message.get_id())
        .payload(CommandPayload::builder().discovery_reply().as_content())
        .build_publish()?;
    Ok(reply)
}
/// Common state and helpers for session controller implementations.
///
/// Holds the session settings, the `ControllerSender` used by `send_with_timer`,
/// the current `ProcessingState` (`Active` after `new`), and the ids of routes
/// and subscriptions this session installed. Presumably embedded by the concrete
/// `MessageHandler` implementations — NOTE(review): confirm.
pub(crate) struct SessionControllerCommon<
P,
V,
M = crate::subscription_manager::SubscriptionManager,
> where
P: TokenProvider + Send + Sync + Clone + 'static,
V: Verifier + Send + Sync + Clone + 'static,
M: crate::subscription_manager::SubscriptionOps,
{
// Session settings: source name, id, config, channels and subscription manager.
pub(crate) settings: SessionSettings<P, V, M>,
// Outbound sender built in `new` from the config's timer settings and `PING_INTERVAL`.
pub(crate) sender: ControllerSender,
// Current processing state; initialised to `ProcessingState::Active`.
pub(crate) processing_state: ProcessingState,
// (kind, remote name, connection) -> subscription id returned by the subscription
// manager. Filled by `add_route`/`add_subscription`, consumed by the delete methods.
subscription_ids: HashMap<(SubscriptionKind, ProtoName, u64), u64>,
}
/// Distinguishes route entries from subscription entries in `subscription_ids`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum SubscriptionKind {
// Installed via `add_route` (subscription manager `set_route`).
Route,
// Installed via `add_subscription` (subscription manager `subscribe`).
Subscription,
}
impl<P, V, M> SessionControllerCommon<P, V, M>
where
P: TokenProvider + Send + Sync + Clone + 'static,
V: Verifier + Send + Sync + Clone + 'static,
M: crate::subscription_manager::SubscriptionOps,
{
pub(crate) fn new(settings: SessionSettings<P, V, M>) -> Self {
let controller_sender = ControllerSender::new(
settings.config.get_timer_settings(),
settings.source.clone(),
settings.config.session_type,
settings.id,
Some(PING_INTERVAL),
settings.config.initiator,
settings.tx_session.clone(),
);
SessionControllerCommon {
settings,
sender: controller_sender,
processing_state: ProcessingState::Active,
subscription_ids: HashMap::new(),
}
}
pub(crate) fn send_with_timer(
&mut self,
message: Message,
) -> Result<SessionOutput, SessionError> {
self.sender.on_message(&message)
}
async fn await_subscription_ack(
rx: tokio::sync::oneshot::Receiver<
Result<(), crate::subscription_manager::SubscriptionAckError>,
>,
) -> Result<(), SessionError> {
crate::subscription_manager::SubscriptionManager::await_ack(rx)
.await
.map_err(SessionError::SubscriptionAckFailed)
}
/// Installs a route from this session's source towards `name` via connection
/// `conn`.
///
/// Waits for the subscription manager to acknowledge the route, then records the
/// returned subscription id so `delete_route` can remove it later. Adding a route
/// to the session's own source name is a no-op.
///
/// # Errors
/// `SessionError::SubscriptionAckFailed` if the route request cannot be issued or
/// is not acknowledged.
pub(crate) async fn add_route(
    &mut self,
    name: ProtoName,
    conn: u64,
) -> Result<(), SessionError> {
    // Compare in place; cloning `source` only to test equality is a wasted allocation.
    if name == self.settings.source {
        return Ok(());
    }
    let source_proto = self.settings.source.clone();
    let (subscription_id, rx) = self
        .settings
        .subscription_manager
        .set_route(&source_proto, &name, conn)
        .await
        .map_err(SessionError::SubscriptionAckFailed)?;
    Self::await_subscription_ack(rx).await?;
    debug!(%name, %conn, %subscription_id, source = %self.settings.source, "route added");
    // Re-adding the same (name, conn) overwrites the previously stored id.
    self.subscription_ids
        .insert((SubscriptionKind::Route, name, conn), subscription_id);
    Ok(())
}
/// Removes the route to `name` over `conn` that was previously installed by
/// `add_route`.
///
/// Routes to the session's own source are ignored. If no subscription id is
/// recorded for the route, a warning is logged and `Ok(())` is returned.
///
/// # Errors
/// `SessionError::SubscriptionAckFailed` if the removal cannot be issued or is
/// not acknowledged. The recorded id is forgotten before the request is sent,
/// so it is not retained on failure.
pub(crate) async fn delete_route(
    &mut self,
    name: ProtoName,
    conn: u64,
) -> Result<(), SessionError> {
    // Compare in place; cloning `source` only to test equality is a wasted allocation.
    if name == self.settings.source {
        return Ok(());
    }
    // Build the key by value for the lookup, then take `name`/`conn` back out of it
    // to avoid cloning the name.
    let key = (SubscriptionKind::Route, name, conn);
    let subscription_id = self.subscription_ids.remove(&key);
    let (_, name, conn) = key;
    match subscription_id {
        Some(subscription_id) => {
            let source_proto = self.settings.source.clone();
            let rx = self
                .settings
                .subscription_manager
                .remove_route(&source_proto, &name, subscription_id, conn)
                .await
                .map_err(SessionError::SubscriptionAckFailed)?;
            Self::await_subscription_ack(rx).await?;
            tracing::debug!(%name, %conn, %subscription_id, "route deleted");
        }
        None => {
            tracing::warn!(
                %name, %conn, io = %self.settings.source,
                "no subscription_id found for route, skipping delete"
            );
        }
    }
    Ok(())
}
/// Subscribes this session's source to `name` on connection `conn`.
///
/// Waits for the acknowledgement, then records the subscription id for
/// `delete_subscription`.
///
/// # Errors
/// `SessionError::SubscriptionAckFailed` if the subscribe request fails or is not
/// acknowledged.
pub(crate) async fn add_subscription(
    &mut self,
    name: ProtoName,
    conn: u64,
) -> Result<(), SessionError> {
    let source_proto = self.settings.source.clone();
    let subscribe = self
        .settings
        .subscription_manager
        .subscribe(&source_proto, &name, Some(conn));
    let (subscription_id, ack_rx) = subscribe
        .await
        .map_err(SessionError::SubscriptionAckFailed)?;
    Self::await_subscription_ack(ack_rx).await?;
    debug!(%name, %conn, %subscription_id, "subscription added");
    let key = (SubscriptionKind::Subscription, name, conn);
    self.subscription_ids.insert(key, subscription_id);
    Ok(())
}

/// Removes the subscription to `name` on `conn` recorded by `add_subscription`.
///
/// If no id is recorded, this logs at debug level and returns `Ok(())`.
///
/// # Errors
/// `SessionError::SubscriptionAckFailed` if the unsubscribe request fails or is
/// not acknowledged.
pub(crate) async fn delete_subscription(
    &mut self,
    name: ProtoName,
    conn: u64,
) -> Result<(), SessionError> {
    let key = (SubscriptionKind::Subscription, name, conn);
    let recorded = self.subscription_ids.remove(&key);
    let (_, name, conn) = key;
    let Some(subscription_id) = recorded else {
        tracing::debug!(
            %name, %conn,
            "no subscription_id found for subscription, skipping delete"
        );
        return Ok(());
    };
    let source_proto = self.settings.source.clone();
    let ack_rx = self
        .settings
        .subscription_manager
        .unsubscribe(&source_proto, &name, subscription_id, Some(conn))
        .await
        .map_err(SessionError::SubscriptionAckFailed)?;
    Self::await_subscription_ack(ack_rx).await?;
    debug!(%name, %conn, %subscription_id, "subscription deleted");
    Ok(())
}
/// Builds a control message of `message_type` from this session's source to `dst`.
///
/// When `broadcast` is set, the builder's fan-out is set to 256 (presumably
/// "deliver to all matching connections" — NOTE(review): confirm the meaning).
pub(crate) fn create_control_message(
    &mut self,
    dst: &ProtoName,
    message_type: ProtoSessionMessageType,
    message_id: u32,
    payload: Content,
    broadcast: bool,
) -> Result<Message, SessionError> {
    let builder = Message::builder()
        .source(self.settings.source.clone())
        .destination(dst.clone())
        .identity("")
        .session_type(self.settings.config.session_type)
        .session_message_type(message_type)
        .session_id(self.settings.id)
        .message_id(message_id)
        .payload(payload);
    let builder = if broadcast { builder.fanout(256) } else { builder };
    Ok(builder.build_publish()?)
}

/// Builds a control message, attaches `metadata` when given, and sends it through
/// `send_with_timer`.
pub(crate) fn send_control_message(
    &mut self,
    dst: &ProtoName,
    message_type: ProtoSessionMessageType,
    message_id: u32,
    payload: Content,
    metadata: Option<HashMap<String, String>>,
    broadcast: bool,
) -> Result<SessionOutput, SessionError> {
    let mut control =
        self.create_control_message(dst, message_type, message_id, payload, broadcast)?;
    if let Some(map) = metadata {
        control.set_metadata_map(map);
    }
    self.send_with_timer(control)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::Direction;
use crate::session_config::MlsSettings;
use crate::subscription_manager::{SpySubscriptionManager, SubscriptionCall};
use slim_auth::shared_secret::SharedSecret;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::time::Duration;
use tokio::time::timeout;
use tracing_test::traced_test;
const SHARED_SECRET: &str = "kjandjansdiasb8udaijdniasdaindasndasndasndasndasndasndasndas";
fn test_identity() -> String {
SharedSecret::new("test", SHARED_SECRET)
.unwrap()
.get_token()
.unwrap()
}
struct SessionControllerTestBuilder {
session_id: u32,
source: ProtoName,
destination: ProtoName,
session_type: ProtoSessionType,
mls_settings: Option<MlsSettings>,
initiator: bool,
max_retries: Option<u32>,
interval: Option<Duration>,
metadata: HashMap<String, String>,
graceful_shutdown_timeout: Option<Duration>,
}
impl SessionControllerTestBuilder {
#[allow(dead_code)]
fn new() -> Self {
Self {
session_id: 10,
source: ProtoName::from_strings(["org", "ns", "source"]).with_id(1),
destination: ProtoName::from_strings(["org", "ns", "dest"]).with_id(2),
session_type: ProtoSessionType::PointToPoint,
mls_settings: None,
initiator: true,
max_retries: Some(5),
interval: Some(Duration::from_millis(200)),
metadata: HashMap::new(),
graceful_shutdown_timeout: Some(Duration::from_secs(10)),
}
}
fn with_session_id(mut self, id: u32) -> Self {
self.session_id = id;
self
}
#[allow(dead_code)]
fn with_source(mut self, source: ProtoName) -> Self {
self.source = source;
self
}
#[allow(dead_code)]
fn with_destination(mut self, destination: ProtoName) -> Self {
self.destination = destination;
self
}
fn with_session_type(mut self, session_type: ProtoSessionType) -> Self {
self.session_type = session_type;
self
}
fn with_mls_enabled(mut self, enabled: bool) -> Self {
self.mls_settings = if enabled {
Some(MlsSettings::default())
} else {
None
};
self
}
fn with_initiator(mut self, initiator: bool) -> Self {
self.initiator = initiator;
self
}
fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
self.metadata = metadata;
self
}
fn with_graceful_shutdown_timeout(mut self, timeout: Duration) -> Self {
self.graceful_shutdown_timeout = Some(timeout);
self
}
fn build(
self,
) -> (
SessionController,
tokio::sync::mpsc::Receiver<Result<Message, slim_datapath::Status>>,
tokio::sync::mpsc::UnboundedReceiver<Result<Message, SessionError>>,
) {
let config = SessionConfig {
session_type: self.session_type,
max_retries: self.max_retries,
interval: self.interval,
mls_settings: self.mls_settings,
initiator: self.initiator,
metadata: self.metadata,
};
let (tx_slim, rx_slim) = tokio::sync::mpsc::channel(10);
let (tx_app, rx_app) = tokio::sync::mpsc::unbounded_channel();
let (tx_session_layer, _rx_session_layer) = tokio::sync::mpsc::channel(10);
let controller = SessionController::builder()
.with_id(self.session_id)
.with_source(self.source.clone())
.with_destination(self.destination.clone())
.with_config(config)
.with_identity_provider(SharedSecret::new("test", SHARED_SECRET).unwrap())
.with_identity_verifier(SharedSecret::new("test", SHARED_SECRET).unwrap())
.with_slim_tx(tx_slim)
.with_app_tx(tx_app)
.with_tx_to_session_layer(tx_session_layer)
.ready()
.expect("failed to validate builder")
.build()
.expect("failed to build controller");
(controller, rx_slim, rx_app)
}
}
#[tokio::test]
async fn test_session_controller_getters() {
let mut metadata = HashMap::new();
metadata.insert("key1".to_string(), "value1".to_string());
let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
.with_session_id(42)
.with_session_type(ProtoSessionType::Multicast)
.with_mls_enabled(true)
.with_metadata(metadata)
.build();
assert_eq!(controller.id(), 42);
assert_eq!(
controller.source(),
&ProtoName::from_strings(["org", "ns", "source"]).with_id(1)
);
assert_eq!(
controller.dst(),
&ProtoName::from_strings(["org", "ns", "dest"]).with_id(NameId::DATA_CHANNEL_ID)
);
assert_eq!(controller.session_type(), ProtoSessionType::Multicast);
assert!(controller.is_initiator());
assert_eq!(
controller.metadata().get("key1"),
Some(&"value1".to_string())
);
let retrieved_config = controller.session_config();
assert_eq!(retrieved_config.session_type, ProtoSessionType::Multicast);
assert_eq!(retrieved_config.max_retries, Some(5));
}
#[tokio::test]
async fn test_publish_basic() {
let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
let target_name = ProtoName::from_strings(["org", "ns", "target"]);
let payload = b"Hello World".to_vec();
controller
.publish(
&target_name,
payload.clone(),
Some("test-type".to_string()),
None,
)
.await
.expect("publish should succeed");
tokio::time::sleep(Duration::from_millis(50)).await;
}
#[tokio::test]
async fn test_publish_to_specific_connection() {
let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
.with_session_type(ProtoSessionType::Multicast)
.build();
let target_name = ProtoName::from_strings(["org", "ns", "target"]);
let payload = b"Hello to specific connection".to_vec();
let connection_id = 123u64;
controller
.publish_to(
&target_name,
connection_id,
payload.clone(),
Some("test-type".to_string()),
None,
)
.await
.expect("publish_to should succeed");
tokio::time::sleep(Duration::from_millis(50)).await;
}
#[tokio::test]
async fn test_publish_with_metadata() {
let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
.with_session_type(ProtoSessionType::Multicast)
.build();
let target_name = ProtoName::from_strings(["org", "ns", "target"]);
let payload = b"Hello with metadata".to_vec();
let mut metadata = HashMap::new();
metadata.insert("custom_key".to_string(), "custom_value".to_string());
controller
.publish(
&target_name,
payload.clone(),
Some("test-type".to_string()),
Some(metadata),
)
.await
.expect("publish with metadata should succeed");
tokio::time::sleep(Duration::from_millis(50)).await;
}
#[tokio::test]
async fn test_invite_participant_in_multicast() {
let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
.with_session_type(ProtoSessionType::Multicast)
.build();
let participant = ProtoName::from_strings(["org", "ns", "participant"]);
controller
.invite_participant(&participant)
.await
.expect("invite should succeed");
tokio::time::sleep(Duration::from_millis(50)).await;
}
#[tokio::test]
async fn test_invite_participant_not_initiator_error() {
let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
.with_session_type(ProtoSessionType::Multicast)
.with_initiator(false)
.build();
let participant = ProtoName::from_strings(["org", "ns", "new_participant"]);
let result = controller.invite_participant(&participant).await;
assert!(result.is_err_and(|e| matches!(e, SessionError::NotInitiator)));
}
#[tokio::test]
async fn test_invite_participant_p2p_error() {
let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
.with_session_type(ProtoSessionType::PointToPoint)
.build();
let participant = ProtoName::from_strings(["org", "ns", "participant"]);
let result = controller.invite_participant(&participant).await;
assert!(result.is_err_and(|e| matches!(e, SessionError::CannotInviteToP2P)));
}
#[tokio::test]
async fn test_remove_participant_in_multicast() {
let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
.with_session_type(ProtoSessionType::Multicast)
.build();
let participant = ProtoName::from_strings(["org", "ns", "participant"]);
controller
.remove_participant(&participant)
.await
.expect("remove should succeed");
tokio::time::sleep(Duration::from_millis(50)).await;
}
#[tokio::test]
async fn test_remove_participant_not_initiator_error() {
let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
.with_session_type(ProtoSessionType::Multicast)
.with_initiator(false)
.build();
let participant = ProtoName::from_strings(["org", "ns", "participant"]);
let result = controller.remove_participant(&participant).await;
assert!(result.is_err_and(|e| matches!(e, SessionError::NotInitiator)));
}
#[tokio::test]
async fn test_remove_participant_p2p_error() {
let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
.with_session_type(ProtoSessionType::PointToPoint)
.build();
let participant = ProtoName::from_strings(["org", "ns", "participant"]);
let result = controller.remove_participant(&participant).await;
assert!(result.is_err_and(|e| matches!(e, SessionError::CannotRemoveFromP2P)));
}
#[test]
fn test_handle_channel_discovery_message() {
let app_name = ProtoName::from_strings(["org", "ns", "app"]).with_id(100);
let session_id = 42;
let discovery_request = Message::builder()
.source(ProtoName::from_strings(["org", "ns", "requester"]).with_id(1))
.destination(ProtoName::from_strings(["org", "ns", "service"]))
.identity(test_identity())
.incoming_conn(999)
.session_type(ProtoSessionType::Multicast)
.session_message_type(ProtoSessionMessageType::DiscoveryRequest)
.session_id(session_id)
.message_id(123)
.payload(CommandPayload::builder().discovery_request().as_content())
.build_publish()
.unwrap();
let response = handle_channel_discovery_message(
&discovery_request,
&app_name,
session_id,
ProtoSessionType::Multicast,
)
.expect("should create discovery response");
assert_eq!(
response.get_session_message_type(),
ProtoSessionMessageType::DiscoveryReply
);
assert_eq!(response.get_session_header().get_session_id(), session_id);
assert_eq!(response.get_id(), 123);
assert_eq!(
response.get_dst(),
ProtoName::from_strings(["org", "ns", "requester"]).with_id(1)
);
assert_eq!(response.get_slim_header().get_forward_to(), Some(999));
}
#[tokio::test]
async fn test_controller_drop_cancels_processing() {
let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
let token = controller.cancellation_token.clone();
assert!(!token.is_cancelled());
drop(controller);
tokio::time::sleep(Duration::from_millis(100)).await;
assert!(token.is_cancelled());
}
#[tokio::test]
async fn test_close_success() {
let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
.with_graceful_shutdown_timeout(std::time::Duration::from_secs(2))
.build();
let token = controller.cancellation_token.clone();
assert!(!token.is_cancelled());
let handle = controller.close();
assert!(handle.is_ok(), "got error {}", handle.unwrap_err());
assert!(token.is_cancelled());
handle
.unwrap()
.await
.expect("processing task should complete");
}
#[tokio::test]
async fn test_close_already_closed() {
let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
let handle = controller.close();
assert!(handle.is_ok());
handle
.unwrap()
.await
.expect("processing task should complete");
let result = controller.close();
assert!(result.is_err());
match result {
Err(SessionError::SessionAlreadyClosed) => {
}
_ => panic!("Expected SessionError::SessionAlreadyClosed"),
}
}
#[tokio::test]
async fn test_close_cancels_token_immediately() {
let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
let token = controller.cancellation_token.clone();
assert!(!token.is_cancelled());
let handle = controller.close();
assert!(handle.is_ok());
assert!(token.is_cancelled());
handle.unwrap().await.expect("processing should complete");
}
#[tokio::test]
async fn test_on_message_direction_north() {
let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new().build();
let test_message = Message::builder()
.source(controller.dst().clone())
.destination(controller.source().clone())
.identity(test_identity())
.session_type(ProtoSessionType::PointToPoint)
.session_message_type(ProtoSessionMessageType::Msg)
.session_id(controller.id())
.message_id(1)
.application_payload("test", b"test data".to_vec())
.build_publish()
.unwrap();
let result = controller.on_message_from_slim(test_message).await;
assert!(result.is_ok());
}
/// A multicast controller builds a `DiscoveryRequest` addressed to the requested
/// target and stamped with the controller's source, session id and session type.
#[tokio::test]
async fn test_create_discovery_request() {
    let (controller, _rx_slim, _rx_app) = SessionControllerTestBuilder::new()
        .with_session_type(ProtoSessionType::Multicast)
        .build();
    let target = ProtoName::from_strings(["org", "ns", "target"]);

    let request = controller
        .create_discovery_request(&target)
        .expect("should create discovery request");

    // Routing: from this controller to the requested target.
    assert_eq!(&request.get_source(), controller.source());
    assert_eq!(request.get_dst(), target);

    // Session metadata mirrors the controller that produced the request.
    assert_eq!(
        request.get_session_message_type(),
        ProtoSessionMessageType::DiscoveryRequest
    );
    let session_id = request.get_session_header().get_session_id();
    assert_eq!(session_id, controller.id());
    assert_eq!(request.get_session_type(), ProtoSessionType::Multicast);
}
/// End-to-end point-to-point session between a moderator (the initiator) and a
/// participant. There is no real network: every message a controller emits on
/// its slim channel is captured by the test, tagged with incoming connection 1
/// and handed to the peer's `on_message_from_slim`.
///
/// Phases:
/// 1. Invite: DiscoveryRequest, DiscoveryReply, JoinRequest, JoinReply,
///    GroupWelcome, GroupAck. Both sides must report `SetRoute` to their
///    subscription manager, and the invite completion handle must resolve.
/// 2. Data: an application `Msg` from the moderator reaches the participant's
///    app channel and is acknowledged with a `MsgAck` carrying the same id.
/// 3. Leave: LeaveRequest, LeaveReply. Both sides must report `RemoveRoute`.
///
/// After each phase, neither slim channel may carry further traffic for 100ms.
#[tokio::test]
#[traced_test]
async fn test_end_to_end_p2p() {
    // Bound the wait for subscription-manager spy notifications so that a
    // missing SetRoute/RemoveRoute fails the test instead of hanging it forever.
    const SPY_TIMEOUT: Duration = Duration::from_secs(1);

    let session_id = 10;
    let moderator_name = ProtoName::from_strings(["org", "ns", "moderator"]).with_id(1);
    let participant_name = ProtoName::from_strings(["org", "ns", "participant"]);
    let participant_name_id = ProtoName::from_strings(["org", "ns", "participant"]).with_id(1);

    // ---- Moderator (initiator) ----
    let (tx_slim_moderator, mut rx_slim_moderator) = tokio::sync::mpsc::channel(10);
    let (tx_app_moderator, _rx_app_moderator) = tokio::sync::mpsc::unbounded_channel();
    let (tx_session_layer_moderator, _rx_session_layer_moderator) =
        tokio::sync::mpsc::channel(10);
    let moderator_config = SessionConfig {
        session_type: slim_datapath::api::ProtoSessionType::PointToPoint,
        max_retries: Some(5),
        interval: Some(Duration::from_millis(1000)),
        mls_settings: Some(MlsSettings::default()),
        initiator: true,
        metadata: std::collections::HashMap::new(),
    };
    let (spy_moderator_mgr, mut rx_spy_moderator) = SpySubscriptionManager::new();
    let moderator = SessionController::builder()
        .with_id(session_id)
        .with_source(moderator_name.clone())
        .with_destination(participant_name.clone())
        .with_config(moderator_config)
        .with_identity_provider(SharedSecret::new("moderator", SHARED_SECRET).unwrap())
        .with_identity_verifier(SharedSecret::new("moderator", SHARED_SECRET).unwrap())
        .with_slim_tx(tx_slim_moderator.clone())
        .with_app_tx(tx_app_moderator.clone())
        .with_tx_to_session_layer(tx_session_layer_moderator)
        .with_subscription_manager(spy_moderator_mgr)
        .ready()
        .expect("failed to validate builder")
        .build()
        .unwrap();

    // ---- Participant ----
    let (tx_slim_participant, mut rx_slim_participant) = tokio::sync::mpsc::channel(10);
    let (tx_app_participant, mut rx_app_participant) = tokio::sync::mpsc::unbounded_channel();
    let (tx_session_layer_participant, _rx_session_layer_participant) =
        tokio::sync::mpsc::channel(10);
    let participant_config = SessionConfig {
        session_type: slim_datapath::api::ProtoSessionType::PointToPoint,
        max_retries: Some(5),
        interval: Some(Duration::from_millis(200)),
        mls_settings: Some(MlsSettings::default()),
        initiator: false,
        metadata: std::collections::HashMap::new(),
    };
    let (spy_participant_mgr, mut rx_spy_participant) = SpySubscriptionManager::new();
    let participant = SessionController::builder()
        .with_id(session_id)
        .with_source(participant_name_id.clone())
        .with_destination(moderator_name.clone())
        .with_config(participant_config)
        .with_identity_provider(SharedSecret::new("participant", SHARED_SECRET).unwrap())
        .with_identity_verifier(SharedSecret::new("participant", SHARED_SECRET).unwrap())
        .with_slim_tx(tx_slim_participant.clone())
        .with_app_tx(tx_app_participant.clone())
        .with_tx_to_session_layer(tx_session_layer_participant)
        .with_subscription_manager(spy_participant_mgr)
        .ready()
        .expect("failed to validate builder")
        .build()
        .unwrap();

    // ==== Phase 1: invite ====
    let completion_handle = moderator
        .invite_participant_internal(&participant_name)
        .await
        .expect("error inviting participant");

    // Moderator -> SLIM: DiscoveryRequest.
    let received_discovery_request =
        timeout(Duration::from_millis(100), rx_slim_moderator.recv())
            .await
            .expect("timeout waiting for discovery request on moderator slim channel")
            .expect("channel closed")
            .expect("error in discovery request");
    assert_eq!(
        received_discovery_request.get_session_message_type(),
        slim_datapath::api::ProtoSessionMessageType::DiscoveryRequest
    );

    // Hand-crafted DiscoveryReply from the participant, reusing the request id.
    let discovery_msg_id = received_discovery_request.get_id();
    let mut discovery_reply = Message::builder()
        .source(participant_name_id.clone())
        .destination(moderator_name.clone())
        .identity(test_identity())
        .forward_to(1)
        .session_type(slim_datapath::api::ProtoSessionType::PointToPoint)
        .session_message_type(slim_datapath::api::ProtoSessionMessageType::DiscoveryReply)
        .session_id(session_id)
        .message_id(discovery_msg_id)
        .payload(CommandPayload::builder().discovery_reply().as_content())
        .build_publish()
        .unwrap();
    discovery_reply
        .get_slim_header_mut()
        .set_incoming_conn(Some(1));
    moderator
        .on_message_from_slim(discovery_reply)
        .await
        .expect("error processing discovery reply on moderator");
    assert_eq!(
        timeout(SPY_TIMEOUT, rx_spy_moderator.recv())
            .await
            .expect("timeout waiting for moderator SetRoute notification"),
        Some(SubscriptionCall::SetRoute),
        "moderator should set route after discovery reply"
    );

    // Moderator -> participant: JoinRequest.
    let join_request = timeout(Duration::from_millis(100), rx_slim_moderator.recv())
        .await
        .expect("timeout waiting for join request on moderator slim channel")
        .expect("channel closed")
        .expect("error in join request");
    assert_eq!(
        join_request.get_session_message_type(),
        slim_datapath::api::ProtoSessionMessageType::JoinRequest
    );
    assert_eq!(join_request.get_dst(), participant_name_id);
    let mut join_request_to_participant = join_request.clone();
    join_request_to_participant
        .get_slim_header_mut()
        .set_incoming_conn(Some(1));
    participant
        .on_message_from_slim(join_request_to_participant)
        .await
        .expect("error processing join request on participant");
    assert_eq!(
        timeout(SPY_TIMEOUT, rx_spy_participant.recv())
            .await
            .expect("timeout waiting for participant SetRoute notification"),
        Some(SubscriptionCall::SetRoute),
        "participant should set route after join request"
    );

    // Participant -> moderator: JoinReply.
    let join_reply = timeout(Duration::from_millis(100), rx_slim_participant.recv())
        .await
        .expect("timeout waiting for join reply on participant slim channel")
        .expect("channel closed")
        .expect("error in join reply");
    assert_eq!(
        join_reply.get_session_message_type(),
        slim_datapath::api::ProtoSessionMessageType::JoinReply
    );
    assert_eq!(join_reply.get_dst(), moderator_name);
    let mut join_reply_to_moderator = join_reply.clone();
    join_reply_to_moderator
        .get_slim_header_mut()
        .set_incoming_conn(Some(1));
    moderator
        .on_message_from_slim(join_reply_to_moderator)
        .await
        .expect("error processing join reply on moderator");

    // Moderator -> participant: GroupWelcome.
    let welcome_message = timeout(Duration::from_millis(100), rx_slim_moderator.recv())
        .await
        .expect("timeout waiting for welcome message on moderator slim channel")
        .expect("channel closed")
        .expect("error in welcome message");
    assert_eq!(
        welcome_message.get_session_message_type(),
        slim_datapath::api::ProtoSessionMessageType::GroupWelcome
    );
    assert_eq!(welcome_message.get_dst(), participant_name_id);
    let mut welcome_to_participant = welcome_message.clone();
    welcome_to_participant
        .get_slim_header_mut()
        .set_incoming_conn(Some(1));
    participant
        .on_message_from_slim(welcome_to_participant)
        .await
        .expect("error processing welcome message on participant");

    // Participant -> moderator: GroupAck.
    let ack_group = timeout(Duration::from_millis(100), rx_slim_participant.recv())
        .await
        .expect("timeout waiting for ack group on participant slim channel")
        .expect("channel closed")
        .expect("error in ack group");
    assert_eq!(
        ack_group.get_session_message_type(),
        slim_datapath::api::ProtoSessionMessageType::GroupAck
    );
    assert_eq!(ack_group.get_dst(), moderator_name);
    let mut ack_to_moderator = ack_group.clone();
    ack_to_moderator
        .get_slim_header_mut()
        .set_incoming_conn(Some(1));
    moderator
        .on_message_from_slim(ack_to_moderator)
        .await
        .expect("error processing ack group on moderator");

    // The handshake is complete: no further traffic on either side.
    let no_more_moderator = timeout(Duration::from_millis(100), rx_slim_moderator.recv()).await;
    assert!(
        no_more_moderator.is_err(),
        "Expected no more messages on moderator slim channel, received {:?}",
        no_more_moderator
            .ok()
            .and_then(|opt| opt)
            .and_then(|res| res.ok())
    );
    let no_more_participant =
        timeout(Duration::from_millis(100), rx_slim_participant.recv()).await;
    assert!(
        no_more_participant.is_err(),
        "Expected no more messages on participant slim channel"
    );
    completion_handle.await.expect("error in completion handle");

    // ==== Phase 2: application data ====
    let app_data = b"Hello from moderator to participant".to_vec();
    let app_message = Message::builder()
        .source(moderator_name.clone())
        .destination(participant_name.clone())
        .identity(test_identity())
        .session_type(slim_datapath::api::ProtoSessionType::PointToPoint)
        .session_message_type(slim_datapath::api::ProtoSessionMessageType::Msg)
        .session_id(session_id)
        .message_id(1)
        .application_payload("test-app-data", app_data.clone())
        .build_publish()
        .unwrap();
    moderator
        .on_message_from_app(app_message)
        .await
        .expect("error sending application message from moderator");

    // The moderator forwards the message to the participant's concrete id.
    let app_msg_to_slim = timeout(Duration::from_millis(100), rx_slim_moderator.recv())
        .await
        .expect("timeout waiting for application message on moderator slim channel")
        .expect("channel closed")
        .expect("error in application message");
    assert_eq!(app_msg_to_slim.get_dst(), participant_name_id);
    assert!(
        app_msg_to_slim.is_publish(),
        "message should be a publish message"
    );
    let app_msg_id = app_msg_to_slim.get_id();
    let mut app_msg_to_participant = app_msg_to_slim.clone();
    app_msg_to_participant
        .get_slim_header_mut()
        .set_incoming_conn(Some(1));
    participant
        .on_message_from_slim(app_msg_to_participant)
        .await
        .expect("error processing application message on participant");

    // The participant delivers the payload unchanged to its application.
    let app_msg_received = timeout(Duration::from_millis(100), rx_app_participant.recv())
        .await
        .expect("timeout waiting for application message on participant app channel")
        .expect("channel closed")
        .expect("error in application message to app");
    assert_eq!(app_msg_received.get_source(), moderator_name);
    assert!(
        app_msg_received.is_publish(),
        "message should be a publish message"
    );
    let content = app_msg_received
        .get_payload()
        .unwrap()
        .as_application_payload()
        .unwrap()
        .blob
        .clone();
    assert_eq!(content, app_data);

    // The participant acknowledges with a MsgAck carrying the same id.
    let ack_msg = timeout(Duration::from_millis(100), rx_slim_participant.recv())
        .await
        .expect("timeout waiting for ack on participant slim channel")
        .expect("channel closed")
        .expect("error in ack");
    assert_eq!(
        ack_msg.get_session_message_type(),
        slim_datapath::api::ProtoSessionMessageType::MsgAck,
        "message should be an ack"
    );
    assert_eq!(ack_msg.get_dst(), moderator_name);
    assert_eq!(ack_msg.get_id(), app_msg_id);
    let mut ack_to_moderator = ack_msg.clone();
    ack_to_moderator
        .get_slim_header_mut()
        .set_incoming_conn(Some(1));
    moderator
        .on_message_from_slim(ack_to_moderator)
        .await
        .expect("error processing ack on moderator");
    let no_more_moderator_after_ack =
        timeout(Duration::from_millis(100), rx_slim_moderator.recv()).await;
    assert!(
        no_more_moderator_after_ack.is_err(),
        "Expected no more messages on moderator slim channel after ack"
    );
    let no_more_participant_after_ack =
        timeout(Duration::from_millis(100), rx_slim_participant.recv()).await;
    assert!(
        no_more_participant_after_ack.is_err(),
        "Expected no more messages on participant slim channel after ack"
    );

    // ==== Phase 3: leave ====
    let leave_request = Message::builder()
        .source(moderator_name.clone())
        .destination(participant_name.clone())
        .identity(test_identity())
        .session_type(slim_datapath::api::ProtoSessionType::PointToPoint)
        .session_message_type(slim_datapath::api::ProtoSessionMessageType::LeaveRequest)
        .session_id(session_id)
        .message_id(rand::random::<u32>())
        .payload(CommandPayload::builder().leave_request().as_content())
        .build_publish()
        .unwrap();
    moderator
        .on_message_from_app(leave_request)
        .await
        .expect("error sending leave request");
    let received_leave_request = timeout(Duration::from_millis(100), rx_slim_moderator.recv())
        .await
        .expect("timeout waiting for leave request on moderator slim channel")
        .expect("channel closed")
        .expect("error in leave request");
    assert_eq!(
        received_leave_request.get_session_message_type(),
        slim_datapath::api::ProtoSessionMessageType::LeaveRequest
    );
    assert_eq!(received_leave_request.get_dst(), participant_name_id);
    let mut leave_request_to_participant = received_leave_request.clone();
    leave_request_to_participant
        .get_slim_header_mut()
        .set_incoming_conn(Some(1));
    participant
        .on_message_from_slim(leave_request_to_participant)
        .await
        .expect("error processing leave request on participant");
    let leave_reply = timeout(Duration::from_millis(100), rx_slim_participant.recv())
        .await
        .expect("timeout waiting for leave reply on participant slim channel")
        .expect("channel closed")
        .expect("error in leave reply");
    assert_eq!(
        leave_reply.get_session_message_type(),
        slim_datapath::api::ProtoSessionMessageType::LeaveReply
    );
    assert_eq!(leave_reply.get_dst(), moderator_name);
    assert_eq!(
        timeout(SPY_TIMEOUT, rx_spy_participant.recv())
            .await
            .expect("timeout waiting for participant RemoveRoute notification"),
        Some(SubscriptionCall::RemoveRoute),
        "participant should remove route after leave request"
    );
    let mut leave_reply_to_moderator = leave_reply.clone();
    leave_reply_to_moderator
        .get_slim_header_mut()
        .set_incoming_conn(Some(1));
    moderator
        .on_message_from_slim(leave_reply_to_moderator)
        .await
        .expect("error processing leave reply on moderator");
    assert_eq!(
        timeout(SPY_TIMEOUT, rx_spy_moderator.recv())
            .await
            .expect("timeout waiting for moderator RemoveRoute notification"),
        Some(SubscriptionCall::RemoveRoute),
        "moderator should remove route after leave reply"
    );
    let no_more_moderator_final =
        timeout(Duration::from_millis(100), rx_slim_moderator.recv()).await;
    assert!(
        no_more_moderator_final.is_err(),
        "Expected no more messages on moderator slim channel after leave"
    );
    let no_more_participant_final =
        timeout(Duration::from_millis(100), rx_slim_participant.recv()).await;
    assert!(
        no_more_participant_final.is_err(),
        "Expected no more messages on participant slim channel after leave"
    );
}
/// Verifies that a handler switching its own `processing_state()` to
/// `ProcessingState::Draining` makes the processing loop reject further
/// application-originated (`MessageDirection::South`) messages.
///
/// The handler records two messages and flips to `Draining` on the second one.
/// The third message is then expected to be rejected with the log line
/// "session is draining, rejecting new messages from application".
/// All checks are made on `tracing` output captured by `#[traced_test]`.
#[traced_test]
#[tokio::test]
async fn test_internal_draining_via_processing_state_switch() {
use super::*;
use tokio::sync::mpsc;
use tracing::debug;
// Test-local handler whose processing state flips from Active to Draining
// once it has received two messages.
struct InternalDrainHandler {
state: ProcessingState,
messages: Vec<SessionMessage>,
// Shared with the test body so it can be cleared right before cancellation.
needs_drain: Arc<AtomicBool>,
}
impl InternalDrainHandler {
fn new(needs_drain: Arc<AtomicBool>) -> Self {
Self {
state: ProcessingState::Active,
messages: vec![],
needs_drain,
}
}
}
impl MessageHandler for InternalDrainHandler {
async fn init(&mut self) -> Result<(), SessionError> {
Ok(())
}
// Records the message; the second one switches the state to Draining.
async fn on_message(
&mut self,
message: SessionMessage,
) -> Result<SessionOutput, SessionError> {
debug!(?self.state, "internal-drain-handler received message");
self.messages.push(message);
if self.messages.len() == 2 {
debug!("internal-drain-handler transitioning to draining");
self.state = ProcessingState::Draining;
}
Ok(SessionOutput::new())
}
fn needs_drain(&self) -> bool {
self.needs_drain.load(std::sync::atomic::Ordering::SeqCst)
}
// Reported to the processing loop; this is what drives the rejection below.
fn processing_state(&self) -> ProcessingState {
self.state
}
async fn on_shutdown(&mut self) -> Result<(), SessionError> {
debug!("shutdown called on handler");
Ok(())
}
}
// Receivers prefixed with `_` are kept alive for the whole test.
let (tx_slim, _rx_slim) = mpsc::channel(8);
let (tx_app, _rx_app) = mpsc::unbounded_channel();
// `rx_session` is the receiver the processing loop reads; `tx_session` is both
// stored in the settings and used by the test to inject messages.
let (tx_session, rx_session) = mpsc::channel(32);
let (tx_session_layer, _rx_session_layer) = mpsc::channel(8);
let subscription_manager =
crate::subscription_manager::SubscriptionManager::new(tx_slim.clone());
let settings = SessionSettings {
id: 999,
source: ProtoName::from_strings(["org", "ns", "source"]).with_id(1),
destination: ProtoName::from_strings(["org", "ns", "dest"]).with_id(2),
control: ProtoName::from_strings(["org", "ns", "dest"]).with_id(2),
config: SessionConfig {
session_type: ProtoSessionType::PointToPoint,
max_retries: Some(3),
interval: Some(Duration::from_millis(150)),
mls_settings: None,
initiator: true,
metadata: HashMap::new(),
},
direction: Direction::Bidirectional,
slim_tx: tx_slim,
app_tx: tx_app,
tx_session: tx_session.clone(),
tx_to_session_layer: tx_session_layer,
identity_provider: SharedSecret::new("src", SHARED_SECRET).unwrap(),
identity_verifier: SharedSecret::new("src", SHARED_SECRET).unwrap(),
graceful_shutdown_timeout: Some(Duration::from_secs(10)),
subscription_manager,
service_id: String::new(),
};
// The handler starts out reporting that it needs draining.
let needs_drain = Arc::new(AtomicBool::new(true));
let handler = InternalDrainHandler::new(needs_drain.clone());
let cancellation_token = CancellationToken::new();
let cancellation_token_clone = cancellation_token.clone();
let processing_handle = tokio::spawn(async move {
SessionController::processing_loop(
handler,
rx_session,
cancellation_token_clone,
settings,
)
.await
});
// Each send is followed by a sleep so the spawned loop has time to process the
// message and emit its log lines before the `logs_contain` checks.
tx_session
.send(create_test_message(1, b"first".to_vec()))
.await
.expect("failed to send first message");
tokio::time::sleep(Duration::from_millis(100)).await;
assert!(logs_contain("internal-drain-handler received message"));
// The second message makes the handler switch to Draining.
tx_session
.send(create_test_message(2, b"second".to_vec()))
.await
.expect("failed to send second message");
tokio::time::sleep(Duration::from_millis(100)).await;
assert!(logs_contain("internal-drain-handler received message"));
assert!(logs_contain(
"internal-drain-handler transitioning to draining"
));
// The third (South / from-application) message must be rejected by the loop.
tx_session
.send(create_test_message(3, b"third".to_vec()))
.await
.expect("failed to send third message");
tokio::time::sleep(Duration::from_millis(100)).await;
assert!(logs_contain(
"session is draining, rejecting new messages from application"
));
// Tear down: stop reporting a pending drain, then cancel the loop.
needs_drain.store(false, std::sync::atomic::Ordering::SeqCst);
cancellation_token.cancel();
tokio::time::sleep(Duration::from_millis(100)).await;
// NOTE(review): this send must succeed. That implies the loop still holds
// `rx_session` ~100ms after cancellation, and presumably this explicit
// StartDrain is what lets it finish; confirm. The expect text says
// "timeout message" even though a StartDrain is sent.
tx_session
.send(SessionMessage::StartDrain {
grace_period: std::time::Duration::from_millis(100),
})
.await
.expect("failed to send timeout message");
processing_handle.await.expect("processing loop panicked");
}
/// Test double for `MessageHandler` that records every message it receives and
/// lets a test decide whether the processing loop should wait for draining.
///
/// All observable state sits behind `Arc`s, so a test can keep clones of it
/// after the handler itself has been moved into the processing loop.
struct DrainableHandler {
    /// Every `SessionMessage` passed to `on_message`, in arrival order.
    messages_received: Arc<tokio::sync::Mutex<Vec<SessionMessage>>>,
    /// Value reported by `needs_drain()`.
    needs_drain: Arc<AtomicBool>,
    /// Becomes `true` once `on_shutdown` has run.
    shutdown_called: Arc<tokio::sync::Mutex<bool>>,
    /// Optional artificial delay at the start of `on_shutdown`.
    drain_delay: Option<Duration>,
}

impl DrainableHandler {
    /// Creates a handler with no recorded messages, `needs_drain() == false`,
    /// shutdown not yet called and no shutdown delay.
    fn new() -> Self {
        let messages_received = Arc::new(tokio::sync::Mutex::new(Vec::new()));
        let needs_drain = Arc::new(AtomicBool::new(false));
        let shutdown_called = Arc::new(tokio::sync::Mutex::new(false));
        Self {
            messages_received,
            needs_drain,
            shutdown_called,
            drain_delay: None,
        }
    }

    /// Builder-style setter for the value reported by `needs_drain()`.
    fn with_needs_drain(self, needs_drain: bool) -> Self {
        let flag = &self.needs_drain;
        flag.store(needs_drain, std::sync::atomic::Ordering::SeqCst);
        self
    }

    // Not every test uses the helpers below, so silence dead-code warnings.
    /// Builder-style setter for the delay applied in `on_shutdown`.
    #[allow(dead_code)]
    fn with_drain_delay(self, delay: Duration) -> Self {
        Self {
            drain_delay: Some(delay),
            ..self
        }
    }

    /// Number of messages recorded so far, of any kind.
    #[allow(dead_code)]
    async fn get_messages_count(&self) -> usize {
        let recorded = self.messages_received.lock().await;
        recorded.len()
    }

    /// Whether `on_shutdown` has completed.
    #[allow(dead_code)]
    async fn was_shutdown_called(&self) -> bool {
        let called = self.shutdown_called.lock().await;
        *called
    }
}
impl MessageHandler for DrainableHandler {
    /// Nothing to initialize for this test double.
    async fn init(&mut self) -> Result<(), SessionError> {
        Ok(())
    }

    /// Appends `message` to the shared log and produces no output.
    async fn on_message(
        &mut self,
        message: SessionMessage,
    ) -> Result<SessionOutput, SessionError> {
        let mut log = self.messages_received.lock().await;
        log.push(message);
        drop(log);
        Ok(SessionOutput::new())
    }

    /// Reports the externally controlled drain flag.
    fn needs_drain(&self) -> bool {
        let flag = &self.needs_drain;
        flag.load(std::sync::atomic::Ordering::SeqCst)
    }

    /// Optionally sleeps for `drain_delay`, then marks shutdown as called.
    async fn on_shutdown(&mut self) -> Result<(), SessionError> {
        if let Some(pause) = self.drain_delay {
            tokio::time::sleep(pause).await;
        }
        let mut called = self.shutdown_called.lock().await;
        *called = true;
        Ok(())
    }
}
/// Builds `SessionSettings` for a point-to-point initiator session with id 1,
/// from `org/ns/test` (id 1) to `org/ns/test` (id 2), using 5 retries at a
/// 200ms interval and the given graceful-shutdown timeout.
///
/// NOTE: the receiving halves of every channel created here are dropped when
/// this function returns. Sends on `slim_tx`, `app_tx`, `tx_session` and
/// `tx_to_session_layer` will therefore fail with a closed-channel error. In
/// particular, `tx_session` is not connected to whatever receiver the caller
/// hands to the processing loop.
fn create_test_settings(
    graceful_shutdown_timeout: Option<Duration>,
) -> SessionSettings<SharedSecret, SharedSecret> {
    let (tx_slim, _rx_slim) = tokio::sync::mpsc::channel(10);
    let (tx_app, _rx_app) = tokio::sync::mpsc::unbounded_channel();
    let (tx_session, _rx_session) = tokio::sync::mpsc::channel(10);
    let (tx_session_layer, _rx_session_layer) = tokio::sync::mpsc::channel(10);

    let subscription_manager =
        crate::subscription_manager::SubscriptionManager::new(tx_slim.clone());

    let local = ProtoName::from_strings(["org", "ns", "test"]).with_id(1);
    let peer = ProtoName::from_strings(["org", "ns", "test"]).with_id(2);

    let config = SessionConfig {
        session_type: ProtoSessionType::PointToPoint,
        max_retries: Some(5),
        interval: Some(Duration::from_millis(200)),
        mls_settings: None,
        initiator: true,
        metadata: HashMap::new(),
    };

    SessionSettings {
        id: 1,
        source: local,
        destination: peer.clone(),
        control: peer,
        config,
        direction: Direction::Bidirectional,
        slim_tx: tx_slim,
        app_tx: tx_app,
        tx_session,
        tx_to_session_layer: tx_session_layer,
        identity_provider: SharedSecret::new("test", SHARED_SECRET).unwrap(),
        identity_verifier: SharedSecret::new("test", SHARED_SECRET).unwrap(),
        graceful_shutdown_timeout,
        subscription_manager,
        service_id: String::new(),
    }
}
/// Wraps a point-to-point `Msg` publish for session 1, from `org/ns/test` (id 1)
/// to `org/ns/test` (id 2), in a southbound `SessionMessage::OnMessage` with no
/// ack channel.
fn create_test_message(message_id: u32, payload: Vec<u8>) -> SessionMessage {
    let message = Message::builder()
        .source(ProtoName::from_strings(["org", "ns", "test"]).with_id(1))
        .destination(ProtoName::from_strings(["org", "ns", "test"]).with_id(2))
        .identity(test_identity())
        .forward_to(1)
        .session_type(ProtoSessionType::PointToPoint)
        .session_message_type(ProtoSessionMessageType::Msg)
        .session_id(1)
        .message_id(message_id)
        .application_payload("test", payload)
        .build_publish()
        .unwrap();

    SessionMessage::OnMessage {
        message,
        direction: MessageDirection::South,
        ack_tx: None,
    }
}
/// Counts how many recorded entries are `SessionMessage::OnMessage`, ignoring
/// every other variant.
async fn count_on_messages(messages: &Arc<tokio::sync::Mutex<Vec<SessionMessage>>>) -> usize {
    let recorded = messages.lock().await;
    let mut total = 0;
    for entry in recorded.iter() {
        if matches!(entry, SessionMessage::OnMessage { .. }) {
            total += 1;
        }
    }
    total
}
/// Runs `SessionController::processing_loop` for `handler` on a new Tokio task
/// and returns the task's join handle.
fn spawn_processing_loop(
    handler: DrainableHandler,
    rx: tokio::sync::mpsc::Receiver<SessionMessage>,
    cancellation_token: CancellationToken,
    settings: SessionSettings<SharedSecret, SharedSecret>,
) -> tokio::task::JoinHandle<()> {
    let run = async move {
        SessionController::processing_loop(handler, rx, cancellation_token, settings).await;
    };
    tokio::spawn(run)
}
#[tokio::test]
async fn test_draining_processes_queued_messages() {
let handler = DrainableHandler::new();
let messages_received = handler.messages_received.clone();
let shutdown_called = handler.shutdown_called.clone();
let (tx, rx) = tokio::sync::mpsc::channel(10);
let cancellation_token = CancellationToken::new();
let token_clone = cancellation_token.clone();
let settings = create_test_settings(Some(Duration::from_secs(2)));
let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
tx.send(create_test_message(1, vec![1, 2, 3]))
.await
.unwrap();
tx.send(create_test_message(2, vec![4, 5, 6]))
.await
.unwrap();
tx.send(create_test_message(3, vec![7, 8, 9]))
.await
.unwrap();
tokio::time::sleep(Duration::from_millis(50)).await;
token_clone.cancel();
drop(tx);
timeout(Duration::from_secs(3), processing_task)
.await
.expect("timeout waiting for processing loop")
.expect("processing loop panicked");
let processed_messages = count_on_messages(&messages_received).await;
assert_eq!(
processed_messages, 3,
"All queued messages should be processed during draining"
);
assert!(
*shutdown_called.lock().await,
"Shutdown should have been called"
);
}
#[tokio::test]
async fn test_draining_with_needs_drain_true() {
let handler = DrainableHandler::new().with_needs_drain(true);
let messages_received = handler.messages_received.clone();
let shutdown_called = handler.shutdown_called.clone();
let (tx, rx) = tokio::sync::mpsc::channel(10);
let cancellation_token = CancellationToken::new();
let token_clone = cancellation_token.clone();
let settings = create_test_settings(Some(Duration::from_secs(2)));
let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
tx.send(create_test_message(1, vec![1, 2, 3]))
.await
.unwrap();
tokio::time::sleep(Duration::from_millis(50)).await;
token_clone.cancel();
drop(tx);
timeout(Duration::from_secs(3), processing_task)
.await
.expect("timeout waiting for processing loop")
.expect("processing loop panicked");
let processed_messages = count_on_messages(&messages_received).await;
assert_eq!(processed_messages, 1, "Message should be processed");
assert!(
*shutdown_called.lock().await,
"Shutdown should have been called after draining"
);
}
#[tokio::test]
async fn test_draining_with_needs_drain_false() {
let handler = DrainableHandler::new().with_needs_drain(false);
let messages_received = handler.messages_received.clone();
let shutdown_called = handler.shutdown_called.clone();
let (tx, rx) = tokio::sync::mpsc::channel(10);
let cancellation_token = CancellationToken::new();
let token_clone = cancellation_token.clone();
let settings = create_test_settings(Some(Duration::from_secs(2)));
let start_time = tokio::time::Instant::now();
let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
tx.send(create_test_message(1, vec![1, 2, 3]))
.await
.unwrap();
tokio::time::sleep(Duration::from_millis(50)).await;
token_clone.cancel();
drop(tx);
timeout(Duration::from_secs(1), processing_task)
.await
.expect("timeout waiting for processing loop")
.expect("processing loop panicked");
let elapsed = start_time.elapsed();
let processed_messages = count_on_messages(&messages_received).await;
assert_eq!(processed_messages, 1, "Message should be processed");
assert!(
*shutdown_called.lock().await,
"Shutdown should have been called"
);
assert!(
elapsed < Duration::from_millis(500),
"Should exit quickly when no draining needed"
);
}
/// With `needs_drain() == true` and a 500ms graceful-shutdown timeout, the loop
/// exits roughly when the timeout expires (between 400ms and 2s after
/// cancellation). This holds even though a background sender keeps the channel
/// open and sends a message every 100ms.
#[tokio::test]
async fn test_draining_timeout_enforced() {
    let handler = DrainableHandler::new().with_needs_drain(true);
    let (sender, receiver) = tokio::sync::mpsc::channel(10);
    let token = CancellationToken::new();
    let canceller = token.clone();
    let settings = create_test_settings(Some(Duration::from_millis(500)));
    let loop_task = spawn_processing_loop(handler, receiver, token, settings);

    tokio::time::sleep(Duration::from_millis(50)).await;
    let cancelled_at = tokio::time::Instant::now();
    canceller.cancel();

    // `sender` moves into this task, so the channel stays open while it runs.
    let feeder = tokio::spawn(async move {
        for seq in 0..10 {
            tokio::time::sleep(Duration::from_millis(100)).await;
            // seq < 10, so the cast to u8 cannot truncate.
            let outcome = sender.send(create_test_message(seq, vec![seq as u8])).await;
            if outcome.is_err() {
                // The loop exited and dropped its receiver.
                break;
            }
        }
    });

    timeout(Duration::from_secs(2), loop_task)
        .await
        .expect("timeout waiting for processing loop")
        .expect("processing loop panicked");
    let took = cancelled_at.elapsed();

    assert!(
        took >= Duration::from_millis(400),
        "Should wait at least close to the timeout period"
    );
    assert!(
        took < Duration::from_secs(2),
        "Should respect the timeout and exit, not wait forever"
    );
    feeder.abort();
}
#[tokio::test]
async fn test_draining_no_messages_in_queue() {
let handler = DrainableHandler::new().with_needs_drain(true);
let messages_received = handler.messages_received.clone();
let shutdown_called = handler.shutdown_called.clone();
let (tx, rx) = tokio::sync::mpsc::channel(10);
let cancellation_token = CancellationToken::new();
let token_clone = cancellation_token.clone();
let settings = create_test_settings(Some(Duration::from_secs(1)));
let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
token_clone.cancel();
drop(tx);
timeout(Duration::from_secs(2), processing_task)
.await
.expect("timeout waiting for processing loop")
.expect("processing loop panicked");
let processed_messages = count_on_messages(&messages_received).await;
assert_eq!(processed_messages, 0, "No messages should be processed");
assert!(
*shutdown_called.lock().await,
"Shutdown should still be called"
);
}
#[tokio::test]
async fn test_draining_messages_after_cancellation_processed() {
let handler = DrainableHandler::new();
let messages_received = handler.messages_received.clone();
let (tx, rx) = tokio::sync::mpsc::channel(10);
let cancellation_token = CancellationToken::new();
let token_clone = cancellation_token.clone();
let settings = create_test_settings(Some(Duration::from_secs(2)));
tx.send(create_test_message(1, vec![1, 2, 3]))
.await
.unwrap();
tx.send(create_test_message(2, vec![4, 5, 6]))
.await
.unwrap();
let processing_task = spawn_processing_loop(handler, rx, cancellation_token, settings);
tokio::time::sleep(Duration::from_millis(50)).await;
token_clone.cancel();
drop(tx);
timeout(Duration::from_secs(3), processing_task)
.await
.expect("timeout waiting for processing loop")
.expect("processing loop panicked");
let processed_messages = count_on_messages(&messages_received).await;
assert_eq!(
processed_messages, 2,
"Messages in queue during cancellation should be processed"
);
}
}