use std::{collections::HashMap, time::Duration};
use slim_auth::traits::{TokenProvider, Verifier};
use slim_datapath::{
api::{
CommandPayload, Participant, ParticipantSettings, ProtoMessage as Message, ProtoName,
ProtoSessionMessageType, ProtoSessionType,
},
messages::utils::{LEAVING_SESSION, TRUE_VAL},
};
use slim_mls::mls::Mls;
use tracing::debug;
use crate::{
common::{MessageDirection, SessionMessage, SessionOutput},
errors::SessionError,
mls_state::MlsState,
runtime::maybe_await,
session_controller::SessionControllerCommon,
session_settings::SessionSettings,
subscription_manager::{SubscriptionManager, SubscriptionOps},
traits::{MessageHandler, ProcessingState},
};
/// Participant-side controller of a session.
///
/// Handles the control-plane exchange with the moderator (join, welcome,
/// group add/remove, leave) and forwards every non-control message to the
/// wrapped `inner` handler.
pub struct SessionParticipant<
    P: TokenProvider + Send + Sync + Clone + 'static,
    V: Verifier + Send + Sync + Clone + 'static,
    I: MessageHandler + Send + Sync + 'static,
    M: SubscriptionOps = SubscriptionManager,
> {
    /// Moderator name, recorded from the source of the join request.
    moderator_name: Option<ProtoName>,
    /// Group members and their settings, as announced by welcome/add/remove messages.
    group_list: HashMap<ProtoName, ParticipantSettings>,
    /// MLS state; only `Some` once `init` ran with MLS settings configured.
    mls_state: Option<MlsState<P, V>>,
    /// Shared controller state (settings, control sender, routing helpers).
    common: SessionControllerCommon<P, V, M>,
    /// Connection the welcome message arrived on, used to add and remove routes.
    conn_id: Option<u64>,
    /// Set by `join` so it only runs once; cleared by `on_shutdown`.
    subscribed: bool,
    /// True while a queued `LeaveCleanup` has not been processed yet.
    pending_leave_cleanup: bool,
    /// Handler that receives all non-control traffic.
    inner: I,
}
impl<
    P: TokenProvider + Send + Sync + Clone + 'static,
    V: Verifier + Send + Sync + Clone + 'static,
    I: MessageHandler + Send + Sync + 'static,
    M: SubscriptionOps,
> SessionParticipant<P, V, I, M>
{
    /// Creates a participant wrapping `inner`.
    ///
    /// The participant starts with no moderator, an empty group, no MLS state
    /// (that is set up by `init`) and no installed routes or subscriptions.
    pub(crate) fn new(inner: I, settings: SessionSettings<P, V, M>) -> Self {
        Self {
            common: SessionControllerCommon::new(settings),
            inner,
            moderator_name: None,
            group_list: HashMap::new(),
            mls_state: None,
            conn_id: None,
            subscribed: false,
            pending_leave_cleanup: false,
        }
    }
}
impl<P, V, I, M> MessageHandler for SessionParticipant<P, V, I, M>
where
P: TokenProvider + Send + Sync + Clone + 'static,
V: Verifier + Send + Sync + Clone + 'static,
I: MessageHandler + Send + Sync + 'static,
M: SubscriptionOps,
{
async fn init(&mut self) -> Result<(), SessionError> {
self.mls_state = if let Some(mls_settings) = &self.common.settings.config.mls_settings {
let mls_state = MlsState::new(
Mls::new(
self.common.settings.identity_provider.clone(),
self.common.settings.identity_verifier.clone(),
),
mls_settings.header_integrity_validation_percent,
)
.await
.expect("failed to create MLS state");
Some(mls_state)
} else {
None
};
Ok(())
}
async fn on_message(&mut self, message: SessionMessage) -> Result<SessionOutput, SessionError> {
let mut output = SessionOutput::new();
match message {
SessionMessage::OnMessage {
mut message,
direction,
ack_tx,
} => {
if message.get_session_message_type().is_command_message() {
debug!(
message = ?message.get_session_message_type(),
source = %message.get_source(),
"received message",
);
output.extend(self.process_control_message(message).await?);
} else {
if direction == MessageDirection::North
&& let Some(mls_state) = &mut self.mls_state
{
maybe_await!(mls_state.process_message(&mut message, direction))?;
}
let inner_output = self
.inner
.on_message(SessionMessage::OnMessage {
message,
direction,
ack_tx,
})
.await?;
output.extend(inner_output);
}
}
SessionMessage::MessageError { error } => {
output.extend(self.handle_message_error(error).await?);
}
SessionMessage::TimerTimeout {
message_id,
message_type,
name,
timeouts,
} => {
if message_type.is_command_message() {
output.extend(
self.common
.sender
.on_timer_timeout(message_id, message_type)?,
);
} else {
let inner_output = self
.inner
.on_message(SessionMessage::TimerTimeout {
message_id,
message_type,
name,
timeouts,
})
.await?;
output.extend(inner_output);
}
}
SessionMessage::TimerFailure {
message_id,
message_type,
name,
timeouts,
} => {
if message_type.is_command_message() {
self.common.sender.on_failure(message_id, message_type);
} else {
output.extend(
self.inner
.on_message(SessionMessage::TimerFailure {
message_id,
message_type,
name,
timeouts,
})
.await?,
);
}
}
SessionMessage::StartDrain {
grace_period: duration,
} => {
debug!("received drain signal");
let p = CommandPayload::builder().leave_request().as_content();
if let Some(moderator) = &self.moderator_name {
let mut msg = self.common.create_control_message(
moderator,
ProtoSessionMessageType::LeaveRequest,
rand::random::<u32>(),
p,
false,
)?;
debug!("start drain and notify the moderator");
msg.insert_metadata(LEAVING_SESSION.to_string(), TRUE_VAL.to_string());
self.disconnect_from_group().await?;
output.extend(self.common.sender.on_message(&msg)?);
}
self.common.processing_state = ProcessingState::Draining;
output.extend(
self.inner
.on_message(SessionMessage::StartDrain {
grace_period: duration,
})
.await?,
);
self.common.sender.start_drain();
}
SessionMessage::ParticipantDisconnected { name: _ } => {
debug!("The moderator is not anymore connected to the current session, close it",);
self.common.processing_state = ProcessingState::Draining;
output.extend(
self.inner
.on_message(SessionMessage::StartDrain {
grace_period: Duration::from_secs(1),
})
.await?,
);
self.common.sender.start_drain();
}
SessionMessage::LeaveCleanup => {
self.disconnect_from_group().await?;
self.disconnect_from_moderator().await?;
self.pending_leave_cleanup = false;
}
_ => {
return Err(SessionError::SessionMessageInternalUnexpected(Box::new(
message,
)));
}
}
maybe_await!(self.encrypt_output(&mut output))?;
Ok(output)
}
async fn add_endpoint(
&mut self,
endpoint: &Participant,
) -> Result<SessionOutput, SessionError> {
self.inner.add_endpoint(endpoint).await
}
fn remove_endpoint(&mut self, endpoint: &ProtoName) {
self.inner.remove_endpoint(endpoint);
}
fn needs_drain(&self) -> bool {
self.pending_leave_cleanup
|| !self.common.sender.drain_completed()
|| self.inner.needs_drain()
}
fn processing_state(&self) -> ProcessingState {
self.common.processing_state
}
fn participants_list(&self) -> Vec<ProtoName> {
self.group_list.keys().cloned().collect()
}
async fn on_shutdown(&mut self) -> Result<(), SessionError> {
self.subscribed = false;
self.common.sender.close();
MessageHandler::on_shutdown(&mut self.inner).await
}
}
impl<P, V, I, M> SessionParticipant<P, V, I, M>
where
P: TokenProvider + Send + Sync + Clone + 'static,
V: Verifier + Send + Sync + Clone + 'static,
I: MessageHandler + Send + Sync + 'static,
M: SubscriptionOps,
{
#[maybe_async::maybe_async]
async fn encrypt_output(&mut self, output: &mut SessionOutput) -> Result<(), SessionError> {
crate::session_controller::SessionController::apply_identity_to_slim_output(
output,
&self.common.settings.identity_provider,
)?;
if let Some(mls_state) = &mut self.mls_state {
mls_state.encrypt_output(output).await?;
}
Ok(())
}
async fn handle_message_error(
&mut self,
error: SessionError,
) -> Result<SessionOutput, SessionError> {
let Some(session_ctx) = error.session_context() else {
tracing::warn!("Received MessageError without session context");
return self
.inner
.on_message(SessionMessage::MessageError { error })
.await;
};
if error.is_command_message_error() {
self.common.sender.on_failure(
session_ctx.message_id,
session_ctx.get_session_message_type(),
);
Ok(SessionOutput::new())
} else {
self.inner
.on_message(SessionMessage::MessageError { error })
.await
}
}
async fn process_control_message(
&mut self,
message: Message,
) -> Result<SessionOutput, SessionError> {
match message.get_session_message_type() {
ProtoSessionMessageType::JoinRequest => self.on_join_request(message).await,
ProtoSessionMessageType::GroupWelcome => self.on_welcome(message).await,
ProtoSessionMessageType::GroupAdd => self.on_group_update_message(message, true).await,
ProtoSessionMessageType::GroupRemove => {
self.on_group_update_message(message, false).await
}
ProtoSessionMessageType::LeaveRequest | ProtoSessionMessageType::GroupClose => {
self.on_leave_request(message).await
}
ProtoSessionMessageType::Ping => self.on_ping(message).await,
ProtoSessionMessageType::LeaveReply => {
if self.common.processing_state == ProcessingState::Draining {
return self.common.sender.on_message(&message);
}
Ok(SessionOutput::new())
}
ProtoSessionMessageType::GroupProposal
| ProtoSessionMessageType::GroupAck
| ProtoSessionMessageType::GroupNack => todo!(),
ProtoSessionMessageType::DiscoveryRequest
| ProtoSessionMessageType::DiscoveryReply
| ProtoSessionMessageType::JoinReply => {
debug!(
control_message_type = ?message.get_session_message_type(),
"Unexpected control message type",
);
Ok(SessionOutput::new())
}
_ => {
debug!(
message_type = ?message.get_session_message_type(),
"Unexpected message type",
);
Ok(SessionOutput::new())
}
}
}
async fn on_join_request(&mut self, msg: Message) -> Result<SessionOutput, SessionError> {
debug!(
name = %self.common.settings.source,
id = %msg.get_id(),
"received join request",
);
let source = msg.get_source();
self.moderator_name = Some(source.clone());
self.common
.add_route(source.clone(), msg.get_incoming_conn())
.await?;
let key_package = if let Some(mls_state) = &mut self.mls_state {
debug!("mls enabled, create the package key");
let key = maybe_await!(mls_state.generate_key_package())?;
Some(key)
} else {
None
};
let participant = Participant::new(
self.common.settings.source.clone(),
self.common.settings.direction.to_participant_settings(),
);
let content = CommandPayload::builder()
.join_reply(key_package, participant)
.as_content();
debug!("send join reply message");
let reply = self.common.create_control_message(
&source,
ProtoSessionMessageType::JoinReply,
msg.get_id(),
content,
false,
)?;
Ok(SessionOutput::to_slim(reply))
}
async fn on_welcome(&mut self, msg: Message) -> Result<SessionOutput, SessionError> {
debug!(
name = %self.common.settings.source,
id = %msg.get_id(),
"received welcome message",
);
if let Some(mls_state) = &mut self.mls_state {
maybe_await!(mls_state.process_welcome_message(&msg))?;
}
self.join(&msg).await?;
let list = &msg
.get_payload()
.unwrap()
.as_command_payload()?
.as_welcome_payload()?
.participants;
for p in list {
let name = p.get_name()?;
self.group_list.insert(name.clone(), *p.get_settings()?);
if name != self.common.settings.source.clone() {
debug!(name = %msg.get_source(), "add endpoint to the session");
if self.moderator_name.as_ref() != Some(&name) {
self.common
.add_route(name.clone(), msg.get_incoming_conn())
.await?;
}
self.add_endpoint(p).await?;
}
}
let ack = self.common.create_control_message(
&msg.get_source(),
ProtoSessionMessageType::GroupAck,
msg.get_id(),
CommandPayload::builder().group_ack().as_content(),
false,
)?;
Ok(SessionOutput::to_slim(ack))
}
async fn on_group_update_message(
&mut self,
msg: Message,
add: bool,
) -> Result<SessionOutput, SessionError> {
debug!(
name = %self.common.settings.source,
id = %msg.get_id(),
"received update",
);
if let Some(mls_state) = &mut self.mls_state {
debug!("process mls control update");
let source_proto = self.common.settings.source.clone();
let ret = maybe_await!(mls_state.process_control_message(msg.clone(), &source_proto))?;
if !ret {
debug!(
id = %msg.get_id(),
"Message already processed, drop it",
);
return Ok(SessionOutput::new());
}
}
if add {
let p = msg
.get_payload()
.unwrap()
.as_command_payload()?
.as_group_add_payload()?;
if let Some(ref new_participant) = p.new_participant {
let name = new_participant.get_name()?;
self.group_list
.insert(name.clone(), *new_participant.get_settings()?);
debug!(name = %msg.get_source(), "add endpoint to session");
self.common.add_route(name, msg.get_incoming_conn()).await?;
self.add_endpoint(new_participant).await?;
}
} else {
let p = msg
.get_payload()
.unwrap()
.as_command_payload()?
.as_group_remove_payload()?;
if let Some(ref removed_participant) = p.removed_participant {
let name = removed_participant.clone();
self.group_list.remove(&name);
debug!(name = %msg.get_source(), "remove endpoint from session");
if name != self.common.settings.source.clone() {
self.common
.delete_route(name.clone(), msg.get_incoming_conn())
.await?;
}
self.inner.remove_endpoint(&name);
}
}
let msg = self.common.create_control_message(
&msg.get_source(),
ProtoSessionMessageType::GroupAck,
msg.get_id(),
CommandPayload::builder().group_ack().as_content(),
false,
)?;
Ok(SessionOutput::to_slim(msg))
}
async fn on_leave_request(&mut self, msg: Message) -> Result<SessionOutput, SessionError> {
debug!("close session");
self.common.processing_state = ProcessingState::Draining;
let (reply_type, reply_content) = match msg.get_session_message_type() {
ProtoSessionMessageType::GroupClose => {
self.on_shutdown().await?;
self.common.sender.close();
(
ProtoSessionMessageType::GroupAck,
CommandPayload::builder().group_ack().as_content(),
)
}
_ => {
self.inner
.on_message(SessionMessage::StartDrain {
grace_period: Duration::from_secs(60), })
.await?;
self.common.sender.start_drain();
(
ProtoSessionMessageType::LeaveReply,
CommandPayload::builder().leave_reply().as_content(),
)
}
};
let reply = self.common.create_control_message(
&msg.get_source(),
reply_type,
msg.get_id(),
reply_content,
false,
)?;
let output = SessionOutput::to_slim(reply);
self.common
.settings
.tx_to_session_layer
.send(Ok(SessionMessage::DeleteSession {
session_id: self.common.settings.id,
}))
.await
.map_err(|_e| SessionError::SessionDeleteMessageSendFailed)?;
self.pending_leave_cleanup = true;
self.common
.settings
.tx_session
.send(SessionMessage::LeaveCleanup)
.await
.map_err(|_| SessionError::SlimMessageSendFailed)?;
Ok(output)
}
async fn on_ping(&mut self, mut msg: Message) -> Result<SessionOutput, SessionError> {
debug!("received ping message, reply");
let mut output = self.common.sender.on_message(&msg)?;
let header = msg.get_slim_header_mut();
let src = header.get_source();
header.set_source(self.common.settings.source.clone());
header.set_destination(src);
output.extend(SessionOutput::to_slim(msg));
Ok(output)
}
async fn join(&mut self, msg: &Message) -> Result<(), SessionError> {
if self.subscribed {
return Ok(());
}
self.subscribed = true;
self.conn_id = Some(msg.get_incoming_conn());
if self.common.settings.config.session_type == ProtoSessionType::PointToPoint {
return Ok(());
}
let destination = self.common.settings.destination.clone();
let control = self.common.settings.control.clone();
self.common
.add_route(destination.clone(), msg.get_incoming_conn())
.await?;
self.common
.add_subscription(destination, msg.get_incoming_conn())
.await?;
self.common
.add_route(control.clone(), msg.get_incoming_conn())
.await?;
self.common
.add_subscription(control, msg.get_incoming_conn())
.await
}
async fn disconnect_from_group(&mut self) -> Result<(), SessionError> {
if self.common.settings.config.session_type == ProtoSessionType::PointToPoint {
return Ok(());
}
if let Some(conn_id) = self.conn_id {
self.common
.delete_route(self.common.settings.destination.clone(), conn_id)
.await?;
self.common
.delete_subscription(self.common.settings.destination.clone(), conn_id)
.await?;
self.common
.delete_route(self.common.settings.control.clone(), conn_id)
.await?;
self.common
.delete_subscription(self.common.settings.control.clone(), conn_id)
.await?;
}
for (n, _) in self.group_list.iter() {
if self.moderator_name.as_ref() != Some(n)
&& let Err(e) = self
.common
.delete_route(n.clone(), self.conn_id.unwrap())
.await
{
tracing::warn!(error = %e, name = %n, "error deleting route");
}
}
Ok(())
}
async fn disconnect_from_moderator(&mut self) -> Result<(), SessionError> {
if let Some(conn_id) = self.conn_id
&& let Err(e) = self
.common
.delete_route(self.moderator_name.as_ref().unwrap().clone(), conn_id)
.await
{
tracing::warn!(error = %e, name = ?self.moderator_name, "error disconnecting from moderator");
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Direction;
    use crate::common::OutboundMessage;
    use crate::session_config::SessionConfig;
    use crate::session_settings::SessionSettings;
    use crate::test_utils::{MockInnerHandler, MockTokenProvider, MockVerifier};
    use slim_datapath::Status;
    use slim_datapath::api::{CommandPayload, NameId, ProtoSessionType};
    use tokio::sync::mpsc;

    /// Drives `fut` to completion while acknowledging every subscription
    /// request it emits on `rx_slim`, so that subscription calls inside `fut`
    /// get their ack.
    async fn run_with_acks<F, T>(
        fut: F,
        rx_slim: &mut mpsc::Receiver<Result<Message, Status>>,
        sub_mgr: &crate::subscription_manager::SubscriptionManager,
    ) -> T
    where
        F: std::future::Future<Output = T>,
    {
        let mut pinned = Box::pin(fut);
        loop {
            tokio::select! {
                res = &mut pinned => return res,
                msg = rx_slim.recv() => {
                    if let Some(Ok(msg)) = msg && let Some(ack_id) = msg.get_subscription_id() {
                        let ack = Message::builder().build_subscription_ack(ack_id, true, "");
                        sub_mgr.resolve_ack(ack.get_subscription_ack());
                    }
                }
            }
        }
    }

    /// Builds a three-component name with id 0; tests override the id with
    /// `with_id` where it matters.
    fn make_name(parts: &[&str; 3]) -> ProtoName {
        ProtoName::from_strings([parts[0], parts[1], parts[2]]).with_id(0)
    }

    /// Creates a participant for `session_type` together with the receivers
    /// for SLIM output, session-layer messages and session self-messages.
    fn setup_participant(
        session_type: ProtoSessionType,
    ) -> (
        SessionParticipant<MockTokenProvider, MockVerifier, MockInnerHandler>,
        mpsc::Receiver<Result<Message, Status>>,
        mpsc::Receiver<Result<SessionMessage, SessionError>>,
        mpsc::Receiver<SessionMessage>,
    ) {
        let source = make_name(&["local", "participant", "v1"]);
        let (destination, control) = match session_type {
            ProtoSessionType::Multicast => (
                make_name(&["channel", "name", "v1"]).with_id(NameId::DATA_CHANNEL_ID),
                make_name(&["channel", "name", "v1"]).with_id(NameId::CONTROL_CHANNEL_ID),
            ),
            ProtoSessionType::PointToPoint => (
                make_name(&["remote", "participant", "v1"]).with_id(100),
                make_name(&["remote", "participant", "v1"]).with_id(100),
            ),
            _ => panic!("Unsupported session type for test setup"),
        };
        let identity_provider = MockTokenProvider;
        let identity_verifier = MockVerifier;
        let (tx_slim, rx_slim) = mpsc::channel(16);
        let (tx_app, _rx_app) = mpsc::unbounded_channel();
        let (tx_session, rx_session) = mpsc::channel(16);
        let subscription_manager =
            crate::subscription_manager::SubscriptionManager::new(tx_slim.clone());
        let (tx_session_layer, rx_session_layer) = mpsc::channel(16);
        let config = SessionConfig {
            session_type,
            max_retries: Some(3),
            interval: Some(std::time::Duration::from_secs(1)),
            mls_settings: None,
            initiator: false,
            metadata: Default::default(),
        };
        let settings = SessionSettings {
            id: 1,
            source,
            destination,
            control,
            config,
            direction: Direction::Bidirectional,
            slim_tx: tx_slim,
            app_tx: tx_app,
            tx_session,
            tx_to_session_layer: tx_session_layer,
            identity_provider,
            identity_verifier,
            graceful_shutdown_timeout: None,
            subscription_manager,
            service_id: String::new(),
        };
        let inner = MockInnerHandler::new();
        let participant = SessionParticipant::new(inner, settings);
        (participant, rx_slim, rx_session_layer, rx_session)
    }

    #[tokio::test]
    async fn test_participant_new() {
        let (participant, _rx_slim, _rx_session_layer, _rx_session) =
            setup_participant(ProtoSessionType::Multicast);
        assert!(participant.moderator_name.is_none());
        assert!(participant.group_list.is_empty());
        assert!(participant.mls_state.is_none());
        assert!(!participant.subscribed);
    }

    #[tokio::test]
    async fn test_participant_init() {
        let (mut participant, _rx_slim, _rx_session_layer, _rx_session) =
            setup_participant(ProtoSessionType::Multicast);
        let result = participant.init().await;
        assert!(result.is_ok());
        // MLS is not configured in the test settings.
        assert!(participant.mls_state.is_none());
    }

    #[tokio::test]
    async fn test_participant_on_join_request() {
        let (mut participant, mut rx_slim, _rx_session_layer, _rx_session) =
            setup_participant(ProtoSessionType::Multicast);
        participant.init().await.unwrap();
        let moderator = make_name(&["moderator", "app", "v1"]).with_id(300);
        let join_msg = Message::builder()
            .source(moderator.clone())
            .destination(participant.common.settings.source.clone())
            .identity("")
            .forward_to(0)
            .incoming_conn(12345)
            .session_type(ProtoSessionType::Multicast)
            .session_message_type(ProtoSessionMessageType::JoinRequest)
            .session_id(1)
            .message_id(100)
            .payload(
                CommandPayload::builder()
                    .join_request(Some(3), Some(std::time::Duration::from_secs(1)), None, None)
                    .as_content(),
            )
            .build_publish()
            .unwrap();
        let sub_mgr = participant.common.settings.subscription_manager.clone();
        let result = run_with_acks(
            participant.on_join_request(join_msg),
            &mut rx_slim,
            &sub_mgr,
        )
        .await;
        assert!(result.is_ok());
        assert_eq!(participant.moderator_name, Some(moderator));
        let output = result.unwrap();
        assert!(
            !output.is_empty(),
            "Should have sent messages including join reply"
        );
    }

    #[tokio::test]
    async fn test_participant_on_welcome_multicast() {
        let (mut participant, mut rx_slim, _rx_session_layer, _rx_session) =
            setup_participant(ProtoSessionType::Multicast);
        participant.init().await.unwrap();
        let moderator = make_name(&["moderator", "app", "v1"]).with_id(300);
        participant.moderator_name = Some(moderator.clone());
        let participant1_name = make_name(&["participant1", "app", "v1"]).with_id(401);
        let participant2_name = make_name(&["participant2", "app", "v1"]).with_id(402);
        let participant1 = Participant::new(
            participant1_name.clone(),
            ParticipantSettings::bidirectional(),
        );
        let participant2 = Participant::new(
            participant2_name.clone(),
            ParticipantSettings::bidirectional(),
        );
        let welcome_msg = Message::builder()
            .source(moderator.clone())
            .destination(participant.common.settings.source.clone())
            .identity("")
            .forward_to(0)
            .incoming_conn(12345)
            .session_type(ProtoSessionType::Multicast)
            .session_message_type(ProtoSessionMessageType::GroupWelcome)
            .session_id(1)
            .message_id(200)
            .payload(
                CommandPayload::builder()
                    .group_welcome(vec![participant1.clone(), participant2.clone()], None)
                    .as_content(),
            )
            .build_publish()
            .unwrap();
        let sub_mgr = participant.common.settings.subscription_manager.clone();
        let result =
            run_with_acks(participant.on_welcome(welcome_msg), &mut rx_slim, &sub_mgr).await;
        assert!(result.is_ok());
        assert!(participant.subscribed);
        assert_eq!(participant.group_list.len(), 2);
        assert_eq!(participant.inner.get_endpoints_added_count().await, 2);
    }

    #[tokio::test]
    async fn test_participant_on_group_add_message() {
        let (mut participant, mut rx_slim, _rx_session_layer, _rx_session) =
            setup_participant(ProtoSessionType::Multicast);
        participant.init().await.unwrap();
        participant.subscribed = true;
        let moderator = make_name(&["moderator", "app", "v1"]).with_id(300);
        participant.moderator_name = Some(moderator.clone());
        let new_participant_name = make_name(&["new_participant", "app", "v1"]).with_id(500);
        let new_participant = Participant::new(
            new_participant_name.clone(),
            ParticipantSettings::bidirectional(),
        );
        let add_msg = Message::builder()
            .source(moderator.clone())
            .destination(participant.common.settings.destination.clone())
            .identity("")
            .forward_to(0)
            .incoming_conn(12345)
            .session_type(ProtoSessionType::Multicast)
            .session_message_type(ProtoSessionMessageType::GroupAdd)
            .session_id(1)
            .message_id(300)
            .payload(
                CommandPayload::builder()
                    .group_add(new_participant.clone(), vec![], None)
                    .as_content(),
            )
            .build_publish()
            .unwrap();
        let sub_mgr = participant.common.settings.subscription_manager.clone();
        let result = run_with_acks(
            participant.on_group_update_message(add_msg, true),
            &mut rx_slim,
            &sub_mgr,
        )
        .await;
        assert!(result.is_ok());
        assert!(participant.group_list.contains_key(&new_participant_name));
        assert_eq!(participant.inner.get_endpoints_added_count().await, 1);
        let output = result.unwrap();
        assert!(!output.is_empty(), "Should have sent group ack");
    }

    #[tokio::test]
    async fn test_participant_on_group_remove_message() {
        let (mut participant, mut rx_slim, _rx_session_layer, _rx_session) =
            setup_participant(ProtoSessionType::Multicast);
        participant.init().await.unwrap();
        participant.subscribed = true;
        let moderator = make_name(&["moderator", "app", "v1"]).with_id(300);
        participant.moderator_name = Some(moderator.clone());
        let removed_participant_name = make_name(&["removed", "app", "v1"]).with_id(500);
        participant.group_list.insert(
            removed_participant_name.clone(),
            ParticipantSettings::bidirectional(),
        );
        let remove_msg = Message::builder()
            .source(moderator.clone())
            .destination(participant.common.settings.destination.clone())
            .identity("")
            .forward_to(0)
            .incoming_conn(12345)
            .session_type(ProtoSessionType::Multicast)
            .session_message_type(ProtoSessionMessageType::GroupRemove)
            .session_id(1)
            .message_id(400)
            .payload(
                CommandPayload::builder()
                    .group_remove(removed_participant_name.clone(), vec![], None)
                    .as_content(),
            )
            .build_publish()
            .unwrap();
        let sub_mgr = participant.common.settings.subscription_manager.clone();
        let result = run_with_acks(
            participant.on_group_update_message(remove_msg, false),
            &mut rx_slim,
            &sub_mgr,
        )
        .await;
        assert!(result.is_ok());
        assert!(
            !participant
                .group_list
                .contains_key(&removed_participant_name)
        );
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        assert_eq!(participant.inner.get_endpoints_removed_count().await, 1);
        let output = result.unwrap();
        assert!(!output.is_empty(), "Should have sent group ack");
    }

    #[tokio::test]
    async fn test_participant_on_leave_request() {
        let (mut participant, _rx_slim, mut rx_session_layer, _rx_session) =
            setup_participant(ProtoSessionType::Multicast);
        participant.init().await.unwrap();
        participant.subscribed = true;
        let moderator = make_name(&["moderator", "app", "v1"]).with_id(300);
        participant.moderator_name = Some(moderator.clone());
        let leave_msg = Message::builder()
            .source(moderator.clone())
            .destination(participant.common.settings.source.clone())
            .identity("")
            .forward_to(0)
            .incoming_conn(12345)
            .session_type(ProtoSessionType::Multicast)
            .session_message_type(ProtoSessionMessageType::LeaveRequest)
            .session_id(1)
            .message_id(500)
            .payload(CommandPayload::builder().leave_request().as_content())
            .build_publish()
            .unwrap();
        let result = participant.on_leave_request(leave_msg).await;
        assert!(result.is_ok());
        let output = result.unwrap();
        assert!(!output.is_empty());
        let msg = match &output.messages[0] {
            OutboundMessage::ToSlim(m) => m,
            _ => panic!("Expected ToSlim message"),
        };
        assert_eq!(
            msg.get_session_header().session_message_type(),
            ProtoSessionMessageType::LeaveReply
        );
        let delete_msg = rx_session_layer.try_recv();
        assert!(delete_msg.is_ok());
        if let Ok(Ok(SessionMessage::DeleteSession { session_id })) = delete_msg {
            assert_eq!(session_id, 1);
        } else {
            panic!("Expected DeleteSession message");
        }
    }

    #[tokio::test]
    async fn test_participant_join_multicast() {
        let (mut participant, mut rx_slim, _rx_session_layer, _rx_session) =
            setup_participant(ProtoSessionType::Multicast);
        participant.init().await.unwrap();
        let moderator = make_name(&["moderator", "app", "v1"]).with_id(300);
        let welcome_msg = Message::builder()
            .source(moderator.clone())
            .destination(participant.common.settings.source.clone())
            .identity("")
            .forward_to(0)
            .incoming_conn(12345)
            .session_type(ProtoSessionType::Multicast)
            .session_message_type(ProtoSessionMessageType::GroupWelcome)
            .session_id(1)
            .message_id(100)
            .payload(
                CommandPayload::builder()
                    .group_welcome(vec![], None)
                    .as_content(),
            )
            .build_publish()
            .unwrap();
        let sub_mgr = participant.common.settings.subscription_manager.clone();
        let result = run_with_acks(participant.join(&welcome_msg), &mut rx_slim, &sub_mgr).await;
        assert!(result.is_ok());
        assert!(participant.subscribed);
    }

    #[tokio::test]
    async fn test_participant_join_point_to_point() {
        let (mut participant, _rx_slim, _rx_session_layer, _rx_session) =
            setup_participant(ProtoSessionType::PointToPoint);
        participant.init().await.unwrap();
        let moderator = make_name(&["moderator", "app", "v1"]).with_id(300);
        let msg = Message::builder()
            .source(moderator.clone())
            .destination(participant.common.settings.source.clone())
            .identity("")
            .forward_to(0)
            .incoming_conn(12345)
            .session_type(ProtoSessionType::PointToPoint)
            .session_message_type(ProtoSessionMessageType::JoinRequest)
            .session_id(1)
            .message_id(100)
            .payload(
                CommandPayload::builder()
                    .join_request(Some(3), Some(std::time::Duration::from_secs(1)), None, None)
                    .as_content(),
            )
            .build_publish()
            .unwrap();
        let result = participant.join(&msg).await;
        assert!(result.is_ok());
        assert!(participant.subscribed);
    }

    #[tokio::test]
    async fn test_participant_join_idempotent() {
        let (mut participant, mut rx_slim, _rx_session_layer, _rx_session) =
            setup_participant(ProtoSessionType::Multicast);
        participant.init().await.unwrap();
        let moderator = make_name(&["moderator", "app", "v1"]).with_id(300);
        let msg = Message::builder()
            .source(moderator.clone())
            .destination(participant.common.settings.source.clone())
            .identity("")
            .forward_to(0)
            .incoming_conn(12345)
            .session_type(ProtoSessionType::Multicast)
            .session_message_type(ProtoSessionMessageType::GroupWelcome)
            .session_id(1)
            .message_id(100)
            .payload(
                CommandPayload::builder()
                    .group_welcome(vec![], None)
                    .as_content(),
            )
            .build_publish()
            .unwrap();
        let sub_mgr = participant.common.settings.subscription_manager.clone();
        run_with_acks(participant.join(&msg), &mut rx_slim, &sub_mgr)
            .await
            .unwrap();
        participant.join(&msg).await.unwrap();
        let second_sub = rx_slim.try_recv();
        assert!(
            second_sub.is_err(),
            "Second join should not send any messages"
        );
    }

    #[tokio::test]
    async fn test_participant_application_message_forwarding() {
        let (mut participant, _rx_slim, _rx_session_layer, _rx_session) =
            setup_participant(ProtoSessionType::Multicast);
        participant.init().await.unwrap();
        let source = participant.common.settings.source.clone();
        let destination = participant.common.settings.destination.clone();
        let app_msg = Message::builder()
            .source(source)
            .destination(destination)
            .identity("")
            .forward_to(0)
            .session_type(ProtoSessionType::Multicast)
            .session_message_type(ProtoSessionMessageType::Msg)
            .session_id(1)
            .message_id(100)
            .application_payload("application/octet-stream", vec![1, 2, 3, 4])
            .build_publish()
            .unwrap();
        let result = participant
            .on_message(SessionMessage::OnMessage {
                message: app_msg,
                direction: crate::MessageDirection::South,
                ack_tx: None,
            })
            .await;
        assert!(result.is_ok());
        assert_eq!(participant.inner.get_messages_count().await, 1);
    }

    #[tokio::test]
    async fn test_participant_timer_timeout_control_message() {
        let (mut participant, _rx_slim, _rx_session_layer, _rx_session) =
            setup_participant(ProtoSessionType::Multicast);
        participant.init().await.unwrap();
        let result = participant
            .on_message(SessionMessage::TimerTimeout {
                message_id: 100,
                message_type: ProtoSessionMessageType::JoinRequest,
                name: None,
                timeouts: 1,
            })
            .await;
        // The sender has no pending message with this id, so whether it
        // reports an error is up to the sender. What this test pins down is
        // that control-message timeouts never reach the inner handler.
        let _ = result;
        assert_eq!(participant.inner.get_messages_count().await, 0);
    }

    #[tokio::test]
    async fn test_participant_timer_timeout_app_message() {
        let (mut participant, _rx_slim, _rx_session_layer, _rx_session) =
            setup_participant(ProtoSessionType::Multicast);
        participant.init().await.unwrap();
        let result = participant
            .on_message(SessionMessage::TimerTimeout {
                message_id: 100,
                message_type: ProtoSessionMessageType::Msg,
                name: None,
                timeouts: 1,
            })
            .await;
        assert!(result.is_ok());
        assert_eq!(participant.inner.get_messages_count().await, 1);
    }

    #[tokio::test]
    async fn test_participant_timer_failure_control_message() {
        let (mut participant, _rx_slim, _rx_session_layer, _rx_session) =
            setup_participant(ProtoSessionType::Multicast);
        participant.init().await.unwrap();
        let result = participant
            .on_message(SessionMessage::TimerFailure {
                message_id: 100,
                message_type: ProtoSessionMessageType::JoinRequest,
                name: None,
                timeouts: 3,
            })
            .await;
        assert!(result.is_ok());
        // Control-message failures are handled by the sender, not the inner handler.
        assert_eq!(participant.inner.get_messages_count().await, 0);
    }

    #[tokio::test]
    async fn test_participant_timer_failure_app_message() {
        let (mut participant, _rx_slim, _rx_session_layer, _rx_session) =
            setup_participant(ProtoSessionType::Multicast);
        participant.init().await.unwrap();
        let result = participant
            .on_message(SessionMessage::TimerFailure {
                message_id: 100,
                message_type: ProtoSessionMessageType::Msg,
                name: None,
                timeouts: 3,
            })
            .await;
        assert!(result.is_ok());
        assert_eq!(participant.inner.get_messages_count().await, 1);
    }

    #[tokio::test]
    async fn test_participant_add_and_remove_endpoint() {
        let (mut participant, _rx_slim, _rx_session_layer, _rx_session) =
            setup_participant(ProtoSessionType::Multicast);
        participant.init().await.unwrap();
        let endpoint_name = make_name(&["endpoint", "app", "v1"]);
        let endpoint =
            Participant::new(endpoint_name.clone(), ParticipantSettings::bidirectional());
        let result = participant.add_endpoint(&endpoint).await;
        assert!(result.is_ok());
        assert_eq!(participant.inner.get_endpoints_added_count().await, 1);
        participant.remove_endpoint(&endpoint_name);
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        assert_eq!(participant.inner.get_endpoints_removed_count().await, 1);
    }

    #[tokio::test]
    async fn test_participant_on_shutdown() {
        let (mut participant, _rx_slim, _rx_session_layer, _rx_session) =
            setup_participant(ProtoSessionType::Multicast);
        participant.init().await.unwrap();
        participant.subscribed = true;
        let result = participant.on_shutdown().await;
        assert!(result.is_ok());
        assert!(!participant.subscribed);
    }

    #[tokio::test]
    async fn test_participant_unexpected_control_messages() {
        let (mut participant, _rx_slim, _rx_session_layer, _rx_session) =
            setup_participant(ProtoSessionType::Multicast);
        participant.init().await.unwrap();
        let discovery_msg = Message::builder()
            .source(make_name(&["someone", "app", "v1"]).with_id(300))
            .destination(participant.common.settings.source.clone())
            .identity("")
            .forward_to(0)
            .session_type(ProtoSessionType::Multicast)
            .session_message_type(ProtoSessionMessageType::DiscoveryRequest)
            .session_id(1)
            .message_id(100)
            .payload(CommandPayload::builder().discovery_request().as_content())
            .build_publish()
            .unwrap();
        let result = participant.process_control_message(discovery_msg).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_participant_leave_multicast_unsubscribes() {
        let (mut participant, mut rx_slim, _rx_session_layer, _rx_session) =
            setup_participant(ProtoSessionType::Multicast);
        participant.init().await.unwrap();
        participant.subscribed = true;
        participant.conn_id = Some(12345);
        let moderator = make_name(&["moderator", "app", "v1"]).with_id(300);
        participant.moderator_name = Some(moderator.clone());
        let sub_mgr = participant.common.settings.subscription_manager.clone();
        let result =
            run_with_acks(participant.disconnect_from_group(), &mut rx_slim, &sub_mgr).await;
        assert!(result.is_ok());
        let result = run_with_acks(
            participant.disconnect_from_moderator(),
            &mut rx_slim,
            &sub_mgr,
        )
        .await;
        assert!(result.is_ok());
    }
}