use std::{
collections::{HashMap, VecDeque},
time::Duration,
};
use display_error_chain::ErrorChainExt;
use slim_auth::traits::{TokenProvider, Verifier};
use slim_datapath::{
api::{
CommandPayload, MlsPayload, NameId, Participant, ProtoMessage as Message, ProtoMlsSettings,
ProtoName, ProtoSessionMessageType, ProtoSessionType,
},
messages::utils::{DELETE_GROUP, DISCONNECTION_DETECTED, LEAVING_SESSION, TRUE_VAL},
};
use slim_mls::mls::Mls;
use tokio::sync::oneshot;
use tracing::debug;
use crate::{
common::{MessageDirection, SessionMessage, SessionOutput},
errors::SessionError,
mls_state::{MlsModeratorState, MlsState},
moderator_task::{
AddParticipant, ModeratorTask, NotifyParticipants, RemoveParticipant, TaskUpdate,
},
runtime::maybe_await,
session_controller::SessionControllerCommon,
session_settings::SessionSettings,
subscription_manager::{SubscriptionManager, SubscriptionOps},
traits::{MessageHandler, ProcessingState},
};
/// Moderator side of a session.
///
/// The moderator owns the group membership list and performs membership
/// changes (add, remove, disconnect, close) one at a time as `ModeratorTask`s.
/// Requests arriving while a task is running are queued in `tasks_todo`.
/// Command (control) messages are handled here. Data messages are forwarded
/// to the wrapped `inner` handler.
pub struct SessionModerator<P, V, I, M = SubscriptionManager>
where
    P: TokenProvider + Send + Sync + Clone + 'static,
    V: Verifier + Send + Sync + Clone + 'static,
    I: MessageHandler + Send + Sync + 'static,
    M: SubscriptionOps,
{
    /// Control messages that arrived while another task was running, each
    /// with the optional channel used to report the task's outcome.
    tasks_todo: VecDeque<(Message, Option<oneshot::Sender<Result<(), SessionError>>>)>,
    /// Membership task currently in progress, if any.
    current_task: Option<ModeratorTask>,
    /// MLS state; `None` when the session configuration has no MLS settings.
    mls_state: Option<MlsModeratorState<P, V>>,
    /// Group members keyed by name with the id reset; the value keeps the
    /// full participant (including its real id).
    group_list: HashMap<ProtoName, Participant>,
    /// State and helpers shared with the other session controllers.
    common: SessionControllerCommon<P, V, M>,
    /// Leave request held back until the matching GroupRemove has been acked.
    postponed_message: Option<Message>,
    /// Whether `join` already ran for this moderator.
    subscribed: bool,
    /// Connection recorded by the first `join`; used on shutdown to remove
    /// multicast routes and subscriptions.
    conn_id: Option<u64>,
    /// Handler that processes data (non-command) messages.
    inner: I,
}
impl<P, V, I, M> SessionModerator<P, V, I, M>
where
P: TokenProvider + Send + Sync + Clone + 'static,
V: Verifier + Send + Sync + Clone + 'static,
I: MessageHandler + Send + Sync + 'static,
M: SubscriptionOps,
{
pub(crate) fn new(inner: I, settings: SessionSettings<P, V, M>) -> Self {
let common = SessionControllerCommon::new(settings);
SessionModerator {
tasks_todo: vec![].into(),
current_task: None,
mls_state: None,
group_list: HashMap::new(),
common,
postponed_message: None,
subscribed: false,
conn_id: None,
inner,
}
}
}
impl<P, V, I, M> MessageHandler for SessionModerator<P, V, I, M>
where
P: TokenProvider + Send + Sync + Clone + 'static,
V: Verifier + Send + Sync + Clone + 'static,
I: MessageHandler + Send + Sync + 'static,
M: SubscriptionOps,
{
/// Initializes the moderator's optional MLS state.
///
/// When the session configuration carries MLS settings, this builds a new
/// `Mls` instance from the configured identity provider and verifier, and
/// wraps it in an `MlsModeratorState`. Otherwise MLS stays disabled (`None`).
///
/// # Errors
///
/// Propagates the error returned while creating the MLS state, instead of
/// panicking, so the caller can fail session creation cleanly.
async fn init(&mut self) -> Result<(), SessionError> {
    self.mls_state = match &self.common.settings.config.mls_settings {
        Some(mls_settings) => {
            let mls = Mls::new(
                self.common.settings.identity_provider.clone(),
                self.common.settings.identity_verifier.clone(),
            );
            let mls_state =
                MlsState::new(mls, mls_settings.header_integrity_validation_percent).await?;
            Some(MlsModeratorState::new(mls_state))
        }
        None => None,
    };
    Ok(())
}
/// Entry point for every event delivered to the moderator.
///
/// Dispatch by event kind:
/// - `OnMessage`: command messages go to `process_control_message`. Other
///   messages are forwarded to the inner handler. Before forwarding,
///   southbound messages in a point-to-point session are re-addressed to the
///   session destination, and northbound messages are passed through the MLS
///   state when MLS is enabled.
/// - `MessageError`: delegated to `handle_message_error`.
/// - `TimerTimeout` / `TimerFailure`: command-message timers are handled by
///   the sender or by `handle_failure`; other timers go to the inner handler.
/// - `StartDrain`: switches to draining and closes the whole group.
/// - `ParticipantDisconnected`: builds a leave request flagged with
///   `DISCONNECTION_DETECTED` and runs the disconnection logic.
///
/// Every produced output is passed through `encrypt_output` before returning.
///
/// # Errors
///
/// Returns `SessionMessageInternalUnexpected` for event kinds the moderator
/// does not handle, and propagates errors from the individual handlers.
async fn on_message(&mut self, message: SessionMessage) -> Result<SessionOutput, SessionError> {
    let mut output = SessionOutput::new();
    match message {
        SessionMessage::OnMessage {
            mut message,
            direction,
            ack_tx,
        } => {
            if message.get_session_message_type().is_command_message() {
                debug!(
                    message = ?message.get_session_message_type(),
                    source = %message.get_source(),
                    "received message",
                );
                output.extend(self.process_control_message(message, ack_tx).await?);
            } else {
                // Point-to-point: outgoing data always targets the session destination.
                if direction == MessageDirection::South
                    && self.common.settings.config.session_type == ProtoSessionType::PointToPoint
                {
                    message
                        .get_slim_header_mut()
                        .set_destination(self.common.settings.destination.clone());
                }
                // Incoming data is processed by the MLS state before the inner handler sees it.
                if direction == MessageDirection::North
                    && let Some(mls_state) = &mut self.mls_state
                {
                    maybe_await!(mls_state.common.process_message(&mut message, direction))?;
                }
                let inner_output = self
                    .inner
                    .on_message(SessionMessage::OnMessage {
                        message,
                        direction,
                        ack_tx,
                    })
                    .await?;
                output.extend(inner_output);
            }
        }
        SessionMessage::MessageError { error } => {
            output.extend(self.handle_message_error(error).await?);
        }
        SessionMessage::TimerTimeout {
            message_id,
            message_type,
            name,
            timeouts,
        } => {
            if message_type.is_command_message() {
                output.extend(
                    self.common
                        .sender
                        .on_timer_timeout(message_id, message_type)?,
                );
            } else {
                let inner_output = self
                    .inner
                    .on_message(SessionMessage::TimerTimeout {
                        message_id,
                        message_type,
                        name,
                        timeouts,
                    })
                    .await?;
                output.extend(inner_output);
            }
        }
        SessionMessage::TimerFailure {
            message_id,
            message_type,
            name,
            timeouts,
        } => {
            if message_type.is_command_message() {
                self.handle_failure(
                    message_id,
                    message_type,
                    SessionError::MessageSendRetryFailed { id: message_id },
                )
                .await?;
            } else {
                output.extend(
                    self.inner
                        .on_message(SessionMessage::TimerFailure {
                            message_id,
                            message_type,
                            name,
                            timeouts,
                        })
                        .await?,
                );
            }
        }
        SessionMessage::StartDrain { grace_period: _ } => {
            debug!("start draining by calling delete_all");
            self.common.processing_state = ProcessingState::Draining;
            // NOTE(review): `leave_msg` is built and tagged with DELETE_GROUP but
            // is never sent, stored or returned; only `delete_all` below produces
            // output. Confirm whether it should be dispatched or removed (it is
            // also the only user of the DELETE_GROUP import).
            let p = CommandPayload::builder().leave_request().as_content();
            let destination = self.common.settings.control.clone();
            let mut leave_msg = self.common.create_control_message(
                &destination,
                ProtoSessionMessageType::LeaveRequest,
                rand::random::<u32>(),
                p,
                false,
            )?;
            leave_msg.insert_metadata(DELETE_GROUP.to_string(), TRUE_VAL.to_string());
            output.extend(self.delete_all(None).await?);
        }
        SessionMessage::ParticipantDisconnected {
            name: opt_participant,
        } => {
            let participant =
                opt_participant.ok_or(SessionError::MissingParticipantNameOnDisconnection)?;
            debug!(
                %participant,
                "Participant not anymore connected to the current session",
            );
            // Build a leave request addressed to the lost participant and flag it,
            // so the regular disconnection path removes it from the group.
            let mut msg = self.common.create_control_message(
                &participant,
                ProtoSessionMessageType::LeaveRequest,
                rand::random::<u32>(),
                CommandPayload::builder().leave_request().as_content(),
                false,
            )?;
            msg.insert_metadata(DISCONNECTION_DETECTED.to_string(), TRUE_VAL.to_string());
            output.extend(self.on_disconnection_detected(msg, None).await?);
        }
        _ => {
            return Err(SessionError::SessionMessageInternalUnexpected(Box::new(
                message,
            )));
        }
    }
    maybe_await!(self.encrypt_output(&mut output))?;
    Ok(output)
}
/// Registers `endpoint` with the inner handler and returns whatever output
/// the inner handler produces for it.
async fn add_endpoint(
    &mut self,
    endpoint: &Participant,
) -> Result<SessionOutput, SessionError> {
    let inner = &mut self.inner;
    inner.add_endpoint(endpoint).await
}
/// Unregisters `endpoint` from the inner handler.
fn remove_endpoint(&mut self, endpoint: &ProtoName) {
    let inner = &mut self.inner;
    inner.remove_endpoint(endpoint)
}
/// Reports whether there is still work to flush before the session can close.
///
/// Work is pending while the sender has not completed its drain, while the
/// inner handler needs draining, or while membership tasks are queued.
fn needs_drain(&self) -> bool {
    let sender_pending = !self.common.sender.drain_completed();
    sender_pending || self.inner.needs_drain() || !self.tasks_todo.is_empty()
}
fn processing_state(&self) -> ProcessingState {
self.common.processing_state
}
/// Lists the group members with their real ids restored.
///
/// Keys in `group_list` have their id reset. Each returned name takes the id
/// from the stored participant, or `NameId::NULL_COMPONENT` when the
/// participant has no name.
fn participants_list(&self) -> Vec<ProtoName> {
    let mut names = Vec::with_capacity(self.group_list.len());
    for (base, participant) in &self.group_list {
        let real_id = participant
            .name
            .as_ref()
            .map_or(NameId::NULL_COMPONENT, |n| n.id());
        names.push(base.clone().with_id(real_id));
    }
    names
}
/// Tears the moderator down.
///
/// Steps, in order:
/// 1. Marks the moderator as unsubscribed and closes the sender.
/// 2. For a multicast session that joined, removes the data-channel route and
///    subscription and the control-channel route.
/// 3. Shuts down the inner handler.
/// 4. Asks the session layer to delete this session.
async fn on_shutdown(&mut self) -> Result<(), SessionError> {
    self.subscribed = false;
    self.common.sender.close();

    let multicast_conn =
        if self.common.settings.config.session_type == ProtoSessionType::Multicast {
            self.conn_id
        } else {
            None
        };
    if let Some(conn) = multicast_conn {
        self.common
            .delete_route(self.common.settings.destination.clone(), conn)
            .await?;
        self.common
            .delete_subscription(self.common.settings.destination.clone(), conn)
            .await?;
        self.common
            .delete_route(self.common.settings.control.clone(), conn)
            .await?;
    }

    MessageHandler::on_shutdown(&mut self.inner).await?;
    self.send_close_signal().await;
    Ok(())
}
}
impl<P, V, I, M> SessionModerator<P, V, I, M>
where
P: TokenProvider + Send + Sync + Clone + 'static,
V: Verifier + Send + Sync + Clone + 'static,
I: MessageHandler + Send + Sync + 'static,
M: SubscriptionOps,
{
/// Finalizes an outgoing batch.
///
/// First applies the local identity to the SLIM-bound messages in `output`.
/// Then, when MLS is enabled, lets the MLS state encrypt the batch.
#[maybe_async::maybe_async]
async fn encrypt_output(&mut self, output: &mut SessionOutput) -> Result<(), SessionError> {
    let provider = &self.common.settings.identity_provider;
    crate::session_controller::SessionController::apply_identity_to_slim_output(output, provider)?;
    let Some(mls) = self.mls_state.as_mut() else {
        return Ok(());
    };
    mls.common.encrypt_output(output).await?;
    Ok(())
}
/// Routes a send/processing error to whoever owns the failed message.
///
/// Errors about command messages abort the current moderator task through
/// `handle_failure`. All other errors, including those without a session
/// context, are forwarded to the inner handler.
async fn handle_message_error(
    &mut self,
    error: SessionError,
) -> Result<SessionOutput, SessionError> {
    let Some(ctx) = error.session_context() else {
        tracing::warn!("Received MessageError without session context");
        return self
            .inner
            .on_message(SessionMessage::MessageError { error })
            .await;
    };
    if !error.is_command_message_error() {
        return self
            .inner
            .on_message(SessionMessage::MessageError { error })
            .await;
    }
    let message_id = ctx.message_id;
    let message_type = ctx.get_session_message_type();
    self.handle_failure(message_id, message_type, error).await?;
    Ok(SessionOutput::new())
}
/// Aborts the current task after a command message could not be delivered.
///
/// Steps, in order:
/// 1. Tells the sender about the failure.
/// 2. Reports the error (via `failure_message`) to the task's requester, if
///    one is waiting.
/// 3. Clears the task and starts the next queued one.
async fn handle_failure(
    &mut self,
    message_id: u32,
    message_type: ProtoSessionMessageType,
    error: SessionError,
) -> Result<(), SessionError> {
    self.common.sender.on_failure(message_id, message_type);
    if let Some(mut failed) = self.current_task.take()
        && let Some(reply_to) = failed.ack_tx_take()
    {
        // The requester may have gone away; nothing to do in that case.
        let _ = reply_to.send(Err(failed.failure_message(error)));
    }
    self.pop_task().await
}
fn handle_task_error(&mut self, error: SessionError) -> SessionError {
if let Some(task) = self.current_task.take() {
let ack_tx = match task {
ModeratorTask::Add(t) => t.ack_tx,
ModeratorTask::Remove(t) => t.ack_tx,
ModeratorTask::Update(t) => t.ack_tx,
ModeratorTask::CloseOrDisconnect(t) => t.ack_tx,
};
if let Some(tx) = ack_tx {
let _ = tx.send(Err(SessionError::cleanup_failed(&error)));
}
}
self.current_task = None;
error
}
/// Puts the moderator into draining mode before the session closes.
///
/// Steps, in order:
/// 1. Drops the MLS state and any queued tasks.
/// 2. Clears the sender's timers.
/// 3. Asks the inner handler to drain with a 60-second grace period.
/// 4. Starts draining the sender.
async fn prepare_shutdown(&mut self) -> Result<(), SessionError> {
    debug!("Preparing for shutdown: cleaning up state");
    self.common.processing_state = ProcessingState::Draining;
    self.mls_state = None;
    self.tasks_todo.clear();
    self.common.sender.clear_timers();

    let drain_request = SessionMessage::StartDrain {
        grace_period: Duration::from_secs(60),
    };
    self.inner.on_message(drain_request).await?;

    self.common.sender.start_drain();
    Ok(())
}
/// Removes `participant` from the group and prepares the MLS commit, if any.
///
/// Returns the membership snapshot taken *before* the removal (names with
/// their real ids, so the removed participant is included), together with
/// the MLS commit payload when MLS is enabled.
///
/// # Errors
///
/// Fails if a member has no valid name. If the MLS removal fails, the current
/// task is aborted through `handle_task_error` and that error is returned.
async fn remove_participant_and_compute_mls(
    &mut self,
    participant: &ProtoName,
    msg: &Message,
) -> Result<(Vec<ProtoName>, Option<MlsPayload>), SessionError> {
    let mut members = Vec::with_capacity(self.group_list.len());
    for (key, member) in &self.group_list {
        let real_id = member.get_name()?.id();
        members.push(key.clone().with_id(real_id));
    }

    let mut key = participant.clone();
    key.reset_id();
    self.group_list.remove(&key);
    self.remove_endpoint(participant);

    let mls_payload = if let Some(state) = self.mls_state.as_mut() {
        let mls_content = maybe_await!(state.remove_participant(msg))
            .map_err(|e| self.handle_task_error(e))?;
        // The MLS state is still present: the removal above succeeded.
        let commit_id = self.mls_state.as_mut().unwrap().get_next_mls_mgs_id();
        Some(MlsPayload {
            commit_id,
            mls_content,
        })
    } else {
        None
    };
    Ok((members, mls_payload))
}
/// Broadcasts a `GroupRemove` update on the control channel.
///
/// The message is sent with a random id and `true` as the last
/// `send_control_message` argument. Returns that id together with the
/// produced output.
async fn send_group_remove(
    &mut self,
    removed_participant: ProtoName,
    participants: Vec<ProtoName>,
    mls_payload: Option<MlsPayload>,
) -> Result<(u32, SessionOutput), SessionError> {
    let payload = CommandPayload::builder()
        .group_remove(removed_participant, participants, mls_payload)
        .as_content();
    let message_id = rand::random::<u32>();
    let control_channel = self.common.settings.control.clone();
    let output = self.common.send_control_message(
        &control_channel,
        ProtoSessionMessageType::GroupRemove,
        message_id,
        payload,
        None,
        true,
    )?;
    Ok((message_id, output))
}
async fn process_control_message(
&mut self,
message: Message,
ack_tx: Option<oneshot::Sender<Result<(), SessionError>>>,
) -> Result<SessionOutput, SessionError> {
match message.get_session_message_type() {
ProtoSessionMessageType::DiscoveryRequest => {
self.on_discovery_request(message, ack_tx).await
}
ProtoSessionMessageType::DiscoveryReply => self.on_discovery_reply(message).await,
ProtoSessionMessageType::JoinReply => self.on_join_reply(message).await,
ProtoSessionMessageType::LeaveRequest => {
if message.contains_metadata(DISCONNECTION_DETECTED)
|| message.contains_metadata(LEAVING_SESSION)
{
return self.on_disconnection_detected(message, ack_tx).await;
}
self.on_leave_request(message, ack_tx).await
}
ProtoSessionMessageType::LeaveReply => self.on_leave_reply(message).await,
ProtoSessionMessageType::GroupAck => self.on_group_ack(message).await,
ProtoSessionMessageType::Ping => self.common.sender.on_message(&message),
ProtoSessionMessageType::GroupProposal => todo!(),
ProtoSessionMessageType::JoinRequest
| ProtoSessionMessageType::GroupAdd
| ProtoSessionMessageType::GroupRemove
| ProtoSessionMessageType::GroupWelcome
| ProtoSessionMessageType::GroupClose
| ProtoSessionMessageType::GroupNack => Err(
SessionError::SessionMessageTypeUnexpected(message.get_session_message_type()),
),
_ => Err(SessionError::SessionMessageTypeUnknown(
message.get_session_message_type(),
)),
}
}
/// Starts an "add participant" task by sending a discovery request.
///
/// If another task is running, the request is queued and processed later.
/// Otherwise this method:
/// 1. Opens an `AddParticipant` task.
/// 2. Rejects participants already in the group.
/// 3. Gives the message a fresh id and records it as the discovery phase.
/// 4. Sends the message with a retransmission timer.
///
/// # Errors
///
/// `ParticipantAlreadyInGroup` if the destination is already a member. Any
/// failure aborts the new task via `handle_task_error`.
async fn on_discovery_request(
    &mut self,
    mut msg: Message,
    ack_tx: Option<oneshot::Sender<Result<(), SessionError>>>,
) -> Result<SessionOutput, SessionError> {
    debug!(%self.common.settings.id, "received discovery request");
    if self.current_task.is_some() {
        debug!(
            "Moderator is busy. Add invite participant task to the list and process it later"
        );
        self.tasks_todo.push_back((msg, ack_tx));
        return Ok(SessionOutput::new());
    }
    debug!("Create AddParticipant task with ack_tx");
    self.current_task = Some(ModeratorTask::Add(AddParticipant::new(ack_tx)));

    let new_participant_name = msg.get_dst();
    // `group_list` is keyed by names with the id reset. Normalize the lookup
    // key so that a destination carrying an id cannot bypass the duplicate check.
    let mut member_key = new_participant_name.clone();
    member_key.reset_id();
    if self.group_list.contains_key(&member_key) {
        let err = SessionError::ParticipantAlreadyInGroup(new_participant_name);
        return Err(self.handle_task_error(err));
    }

    let id = rand::random::<u32>();
    msg.get_session_header_mut().set_message_id(id);
    self.current_task
        .as_mut()
        .unwrap()
        .discovery_start(id)
        .map_err(|e| self.handle_task_error(e))?;
    debug!(
        dst = %msg.get_dst(),
        id = msg.get_id(),
        "send discovery request",
    );
    self.common
        .send_with_timer(msg)
        .map_err(|e| self.handle_task_error(e))
}
/// Handles the reply to a discovery request and sends the join request.
///
/// Steps, in order:
/// 1. Closes the discovery phase of the current task.
/// 2. Makes the moderator join the session on the reply's connection.
/// 3. Adds a route to the replier (plus the data and control channels for
///    multicast sessions).
/// 4. Sends a `JoinRequest` carrying the session settings and opens the
///    join phase.
///
/// A reply arriving when no task is in progress (e.g. a duplicate after the
/// task was aborted) is ignored instead of panicking, matching
/// `on_group_ack`.
async fn on_discovery_reply(&mut self, msg: Message) -> Result<SessionOutput, SessionError> {
    debug!(
        source = %msg.get_source(),
        id = msg.get_id(),
        "discovery reply",
    );
    let mut output = self.common.sender.on_message(&msg)?;
    let Some(task) = self.current_task.as_mut() else {
        debug!(
            id = msg.get_id(),
            "received discovery reply for completed/unknown task, ignoring",
        );
        return Ok(output);
    };
    task.discovery_complete(msg.get_id())?;

    self.join(msg.get_source(), msg.get_incoming_conn()).await?;
    self.common
        .add_route(msg.get_source(), msg.get_incoming_conn())
        .await?;
    if self.common.settings.config.session_type == ProtoSessionType::Multicast {
        self.common
            .add_route(
                self.common.settings.destination.clone(),
                msg.get_incoming_conn(),
            )
            .await?;
        self.common
            .add_route(
                self.common.settings.control.clone(),
                msg.get_incoming_conn(),
            )
            .await?;
    }

    let msg_id = rand::random::<u32>();
    let channel = if self.common.settings.config.session_type == ProtoSessionType::Multicast {
        Some(self.common.settings.destination.clone())
    } else {
        None
    };
    let mls_settings = self
        .common
        .settings
        .config
        .mls_settings
        .as_ref()
        .map(|s| ProtoMlsSettings {
            header_integrity_validation_percent: s.header_integrity_validation_percent,
        });
    let payload = CommandPayload::builder()
        .join_request(
            self.common.settings.config.max_retries,
            self.common.settings.config.interval,
            channel,
            mls_settings,
        )
        .as_content();
    debug!(
        dst = %msg.get_slim_header().get_source(),
        id = msg_id,
        "send join request",
    );
    output.extend(self.common.send_control_message(
        &msg.get_slim_header().get_source(),
        ProtoSessionMessageType::JoinRequest,
        msg_id,
        payload,
        Some(self.common.settings.config.metadata.clone()),
        false,
    )?);
    // Nothing above clears `current_task`, so it is still the task checked earlier.
    self.current_task.as_mut().unwrap().join_start(msg_id)?;
    Ok(output)
}
/// Completes the join of a new participant.
///
/// Steps, in order:
/// 1. Records the new participant in the group.
/// 2. Computes the MLS commit and welcome payloads when MLS is enabled.
/// 3. Notifies existing members with a `GroupAdd`. This is only needed when
///    more than the moderator and the newcomer are in the group; otherwise
///    the update phase is closed immediately.
/// 4. Sends the `GroupWelcome` to the newcomer.
///
/// A reply arriving when no task is in progress is ignored instead of
/// panicking. An MLS failure aborts the task through `handle_task_error`, so
/// its requester is notified and the moderator stops reporting a task in
/// progress, consistent with the removal path.
async fn on_join_reply(&mut self, msg: Message) -> Result<SessionOutput, SessionError> {
    // Placeholder id used to open and immediately close the group-update phase
    // when no `GroupAdd` needs to be sent.
    const NO_GROUP_UPDATE_ID: u32 = 12345;

    debug!(
        source = %msg.get_source(),
        id = msg.get_id(),
        "join reply",
    );
    let mut output = self.common.sender.on_message(&msg)?;
    let Some(task) = self.current_task.as_mut() else {
        debug!(
            id = msg.get_id(),
            "received join reply for completed/unknown task, ignoring",
        );
        return Ok(output);
    };
    task.join_complete(msg.get_id())?;

    let new_participant = msg
        .extract_join_reply()?
        .participant
        .clone()
        .ok_or(SessionError::MissingParticipantSettings)?;
    let mut new_name = new_participant.get_name()?;
    debug!(session_name = %new_name, "add endpoint");
    // NOTE(review): the output returned by the inner handler is discarded here;
    // confirm the inner handler never produces messages on add_endpoint.
    self.add_endpoint(&new_participant).await?;
    new_name.reset_id();
    self.group_list.insert(new_name, new_participant.clone());

    let (commit, welcome) = match self.mls_state.as_mut() {
        Some(mls_state) => {
            let (commit_payload, welcome_payload) =
                maybe_await!(mls_state.add_participant(&msg))
                    .map_err(|e| self.handle_task_error(e))?;
            let commit_id = self.mls_state.as_mut().unwrap().get_next_mls_mgs_id();
            (
                Some(MlsPayload {
                    commit_id,
                    mls_content: commit_payload,
                }),
                Some(MlsPayload {
                    commit_id,
                    mls_content: welcome_payload,
                }),
            )
        }
        None => (None, None),
    };

    let participants_vec = self.group_list.values().cloned().collect::<Vec<_>>();
    if participants_vec.len() > 2 {
        debug!("participant len is > 2, send a group update");
        let update_payload = CommandPayload::builder()
            .group_add(new_participant, participants_vec.clone(), commit)
            .as_content();
        let add_msg_id = rand::random::<u32>();
        debug!(id = %add_msg_id, "send add update to channel");
        output.extend(self.common.send_control_message(
            &self.common.settings.control.clone(),
            ProtoSessionMessageType::GroupAdd,
            add_msg_id,
            update_payload,
            None,
            true,
        )?);
        self.current_task
            .as_mut()
            .unwrap()
            .commit_start(add_msg_id)?;
    } else {
        debug!("cancel the a group update task");
        let task = self.current_task.as_mut().unwrap();
        task.commit_start(NO_GROUP_UPDATE_ID)?;
        task.update_phase_completed(NO_GROUP_UPDATE_ID)?;
    }

    let welcome_msg_id = rand::random::<u32>();
    let welcome_payload = CommandPayload::builder()
        .group_welcome(participants_vec, welcome)
        .as_content();
    debug!(
        dst = %msg.get_slim_header().get_source(),
        id = %welcome_msg_id,
        "send welcome message",
    );
    output.extend(self.common.send_control_message(
        &msg.get_slim_header().get_source(),
        ProtoSessionMessageType::GroupWelcome,
        welcome_msg_id,
        welcome_payload,
        None,
        false,
    )?);
    self.current_task
        .as_mut()
        .unwrap()
        .welcome_start(welcome_msg_id)?;
    Ok(output)
}
/// Starts a moderator-initiated removal of a participant.
///
/// If another task is running, the request is queued. Otherwise this method
/// resolves the participant's real id, re-addresses the leave request to it,
/// and removes the participant locally (plus its MLS state). Then:
/// - If other members remain to be informed, it broadcasts a `GroupRemove`
///   and postpones the leave request until that update is acked.
/// - Otherwise it sends the leave request immediately.
///
/// # Errors
///
/// `ParticipantNotFound` when the destination is not a member (the task is
/// aborted via `handle_task_error`).
async fn on_leave_request(
    &mut self,
    mut msg: Message,
    ack_tx: Option<oneshot::Sender<Result<(), SessionError>>>,
) -> Result<SessionOutput, SessionError> {
    // Placeholder id used to open and immediately close the group-update phase
    // when no `GroupRemove` needs to be sent.
    const NO_GROUP_UPDATE_ID: u32 = 12345;

    if self.current_task.is_some() {
        debug!("Moderator is busy. Add leave request task to the list and process it later");
        self.tasks_todo.push_back((msg, ack_tx));
        return Ok(SessionOutput::new());
    }
    debug!("Create RemoveParticipant task");
    self.current_task = Some(ModeratorTask::Remove(RemoveParticipant::new(ack_tx)));

    let requested = msg.get_dst();
    // `group_list` is keyed by names with the id reset; normalize the lookup key.
    let mut dst_without_id = requested.clone();
    dst_without_id.reset_id();
    let id = match self.group_list.get(&dst_without_id) {
        Some(p) => p.get_name()?.id(),
        None => {
            let err = SessionError::ParticipantNotFound(requested);
            return Err(self.handle_task_error(err));
        }
    };
    let dst_with_id = dst_without_id.with_id(id);
    msg.get_slim_header_mut().set_destination(dst_with_id);
    msg.set_message_id(rand::random::<u32>());
    let leave_message = msg;
    debug!(
        session_name = %leave_message.get_dst(),
        "remove endpoint from the session",
    );

    let (participants_vec, mls_payload) = self
        .remove_participant_and_compute_mls(&leave_message.get_dst(), &leave_message)
        .await?;

    if participants_vec.len() > 2 {
        let (msg_id, output) = self
            .send_group_remove(leave_message.get_dst(), participants_vec, mls_payload)
            .await?;
        self.current_task.as_mut().unwrap().commit_start(msg_id)?;
        // Sent by `on_group_ack` once the GroupRemove has been acked.
        self.postponed_message = Some(leave_message);
        Ok(output)
    } else {
        let task = self.current_task.as_mut().unwrap();
        task.commit_start(NO_GROUP_UPDATE_ID)?;
        task.update_phase_completed(NO_GROUP_UPDATE_ID)?;
        let output = self.common.sender.on_message(&leave_message)?;
        self.current_task
            .as_mut()
            .unwrap()
            .leave_start(leave_message.get_id())?;
        Ok(output)
    }
}
/// Handles a participant leaving the group without a moderator-initiated
/// removal.
///
/// There are two cases: a detected disconnection (`DISCONNECTION_DETECTED`,
/// the lost participant is the message destination) and a voluntary leave
/// (`LEAVING_SESSION`, the leaving participant is the message source).
///
/// The application is always notified with `ParticipantDisconnected`. A
/// voluntary leave is also answered with a `LeaveReply`. When no other member
/// would remain (point-to-point, or only the moderator and this participant
/// in the group), the session is prepared for shutdown. Otherwise, when no
/// other task is running, the participant is removed and a `GroupRemove` is
/// broadcast under a `CloseOrDisconnect` task; when a task is running, the
/// message is queued.
async fn on_disconnection_detected(
    &mut self,
    mut msg: Message,
    ack_tx: Option<oneshot::Sender<Result<(), SessionError>>>,
) -> Result<SessionOutput, SessionError> {
    // Voluntary leave: the participant sent the message. Detected
    // disconnection: the message is addressed to the lost participant.
    let disconnected = if msg.contains_metadata(LEAVING_SESSION) {
        msg.get_source()
    } else {
        msg.get_dst()
    };
    let mut disconnected_no_id = disconnected.clone();
    disconnected_no_id.reset_id();
    if !self.group_list.contains_key(&disconnected_no_id) {
        debug!(
            "detected disconnection of participant {} that is not part of the group, ignore the message",
            disconnected
        );
        return Ok(SessionOutput::new());
    }
    debug!(%disconnected, "disconnection detected");
    // NOTE(review): when the moderator is busy, this message is queued below and
    // later re-processed through this same function, which emits this app
    // notification a second time; confirm whether that is intended.
    let error = SessionError::ParticipantDisconnected(disconnected.clone());
    let mut output = SessionOutput::to_app(Err(error));
    if msg.contains_metadata(LEAVING_SESSION) {
        // Acknowledge the voluntary leave right away.
        let reply = self.common.create_control_message(
            &disconnected,
            ProtoSessionMessageType::LeaveReply,
            msg.get_id(),
            CommandPayload::builder().leave_reply().as_content(),
            false,
        )?;
        self.common.sender.remove_participant(&disconnected);
        output.extend(SessionOutput::to_slim(reply));
        // Rewrite the message as a detected disconnection addressed to the
        // leaving participant, so a queued re-run does not send the reply again.
        msg.remove_metadata(LEAVING_SESSION);
        msg.insert_metadata(DISCONNECTION_DETECTED.to_string(), TRUE_VAL.to_string());
        let header = msg.get_slim_header_mut();
        header.set_destination(disconnected.clone());
        header.set_source(self.common.settings.source.clone());
    }
    // `group_list` also contains the moderator itself (inserted by `join`), so a
    // length of 2 means only the moderator and the departing participant remain.
    if self.common.settings.config.session_type == ProtoSessionType::PointToPoint
        || self.group_list.len() == 2
    {
        debug!("no one is left connected connected to the session, close it");
        self.prepare_shutdown().await?;
        self.remove_endpoint(&msg.get_dst());
        return Ok(output);
    }
    if self.current_task.is_some() {
        debug!(
            "Moderator is busy. Add disconnection handling task to the list and process it later"
        );
        self.tasks_todo.push_back((msg, ack_tx));
        return Ok(output);
    }
    debug!("Create disconnected task for the disconnection handling");
    self.current_task = Some(ModeratorTask::CloseOrDisconnect(NotifyParticipants::new(
        ack_tx,
    )));
    debug!(
        endpoint = %disconnected,
        "remove disconnected endpoint from the session",
    );
    let (participants_vec, mls_payload) = self
        .remove_participant_and_compute_mls(&disconnected, &msg)
        .await?;
    let (msg_id, remove_output) = self
        .send_group_remove(disconnected, participants_vec, mls_payload)
        .await?;
    output.extend(remove_output);
    self.current_task.as_mut().unwrap().commit_start(msg_id)?;
    Ok(output)
}
/// Closes the whole group: prepares shutdown and sends a `GroupClose` to
/// every member.
///
/// When the group has no member besides the moderator (or is empty because
/// the moderator never joined), there is nobody to notify. In that case the
/// requester, if any, is told the operation succeeded, instead of having its
/// sender dropped, and no close task is opened.
///
/// # Errors
///
/// Propagates failures from `prepare_shutdown`, from resolving member names,
/// and from creating or sending the close message.
async fn delete_all(
    &mut self,
    ack_tx: Option<oneshot::Sender<Result<(), SessionError>>>,
) -> Result<SessionOutput, SessionError> {
    debug!("receive a close channel message, send signals to all participants");
    self.prepare_shutdown().await?;
    let participants: Vec<ProtoName> = self
        .group_list
        .iter()
        .map(|(n, p)| p.get_name().map(|name| n.clone().with_id(name.id())))
        .collect::<Result<Vec<ProtoName>, _>>()?;
    if participants.len() <= 1 {
        if let Some(tx) = ack_tx {
            // The requester may already be gone; nothing else to do then.
            let _ = tx.send(Ok(()));
        }
        return Ok(SessionOutput::new());
    }
    let destination = self.common.settings.control.clone();
    let close_id = rand::random::<u32>();
    let close = self.common.create_control_message(
        &destination,
        ProtoSessionMessageType::GroupClose,
        close_id,
        CommandPayload::builder()
            .group_close(participants)
            .as_content(),
        true,
    )?;
    self.current_task = Some(ModeratorTask::CloseOrDisconnect(NotifyParticipants::new(
        ack_tx,
    )));
    self.current_task.as_mut().unwrap().commit_start(close_id)?;
    self.common.sender.on_message(&close)
}
/// Handles a participant's `LeaveReply`.
///
/// Removes the route to the participant and notifies the sender. Once no
/// copy of the leave request is pending any more, closes the leave phase and
/// finishes the task if it is complete.
///
/// A reply arriving after the task has already finished (e.g. the answer to
/// a retransmitted leave request) is ignored instead of panicking, matching
/// `on_group_ack`.
async fn on_leave_reply(&mut self, msg: Message) -> Result<SessionOutput, SessionError> {
    debug!(
        from = %msg.get_source(),
        id = %msg.get_id(),
        "received leave reply",
    );
    let msg_id = msg.get_id();
    self.common
        .delete_route(msg.get_source(), msg.get_incoming_conn())
        .await?;
    let output = self.common.sender.on_message(&msg)?;
    let Some(task) = self.current_task.as_mut() else {
        debug!(
            id = %msg_id,
            "received leave reply for completed/unknown task, ignoring",
        );
        return Ok(output);
    };
    if !self.common.sender.is_still_pending(msg_id) {
        task.leave_complete(msg_id)?;
    }
    self.task_done().await?;
    Ok(output)
}
/// Handles a `GroupAck` for a group-wide update whose id was registered with
/// `commit_start` (GroupAdd, GroupRemove or GroupClose in this file).
///
/// Nothing happens until the sender reports that no ack is still pending for
/// `msg.get_id()`. Then the update phase of the current task is closed. For a
/// removal whose leave request was postponed by `on_leave_request`, that
/// request is sent now. Finally `task_done` finishes the task if complete.
async fn on_group_ack(&mut self, msg: Message) -> Result<SessionOutput, SessionError> {
    debug!(
        from = %msg.get_source(),
        id = %msg.get_id(),
        "received group ack",
    );
    let mut output = self.common.sender.on_message(&msg)?;
    let msg_id = msg.get_id();
    if !self.common.sender.is_still_pending(msg_id) {
        debug!(
            id = %msg_id,
            "process group ack. try to close task",
        );
        let Some(task) = self.current_task.as_mut() else {
            debug!(
                id = %msg_id,
                "received group ack for completed/unknown task, ignoring",
            );
            return Ok(output);
        };
        task.update_phase_completed(msg_id)?;
        if !self.current_task.as_mut().unwrap().task_complete() {
            // A removal still has its leave phase to run: send the leave request
            // that was held back until the GroupRemove was acknowledged.
            if let Some(leave_message) = &self.postponed_message
                && matches!(self.current_task, Some(ModeratorTask::Remove(_)))
            {
                output.extend(self.common.sender.on_message(leave_message)?);
                self.current_task
                    .as_mut()
                    .unwrap()
                    .leave_start(leave_message.get_id())?;
                self.postponed_message = None;
            }
        }
        self.task_done().await?;
    } else {
        debug!(
            id = %msg_id,
            "timer for message still pending, do not close the task",
        );
    }
    Ok(output)
}
async fn task_done(&mut self) -> Result<(), SessionError> {
if !self.current_task.as_ref().unwrap().task_complete() {
debug!("Current task is NOT completed");
return Ok(());
}
self.current_task = None;
self.pop_task().await
}
/// Starts the next queued task, if the moderator is idle and the queue is
/// not empty.
///
/// The queued control message is re-submitted to the session's own
/// processing loop (`tx_session`) as a southbound `OnMessage`, together with
/// its original ack channel.
///
/// # Errors
///
/// `SlimMessageSendFailed` if the processing loop can no longer receive.
async fn pop_task(&mut self) -> Result<(), SessionError> {
    if self.current_task.is_some() {
        return Ok(());
    }
    let Some((message, ack_tx)) = self.tasks_todo.pop_front() else {
        debug!("No tasks left to perform");
        return Ok(());
    };
    debug!("Re-enqueue a task from the todo list onto the processing loop");
    let resubmitted = SessionMessage::OnMessage {
        message,
        direction: MessageDirection::South,
        ack_tx,
    };
    if self.common.settings.tx_session.send(resubmitted).await.is_err() {
        return Err(SessionError::SlimMessageSendFailed);
    }
    Ok(())
}
/// Makes the moderator itself part of the session.
///
/// Only the first call has an effect. It records `conn` and then:
/// - Point-to-point: points the session destination at `remote`.
/// - Otherwise: subscribes to the data channel.
///
/// It then initializes MLS as moderator (if enabled) and inserts the local
/// participant into the group.
async fn join(&mut self, remote: ProtoName, conn: u64) -> Result<(), SessionError> {
    if self.subscribed {
        return Ok(());
    }
    self.subscribed = true;
    self.conn_id = Some(conn);

    if self.common.settings.config.session_type != ProtoSessionType::PointToPoint {
        let data_channel = self.common.settings.destination.clone();
        self.common.add_subscription(data_channel, conn).await?;
    } else {
        self.common.settings.destination = remote;
    }

    if let Some(mls) = &mut self.mls_state {
        mls.init_moderator().await?;
    }

    let full_name = self.common.settings.source.clone();
    let mut member_key = full_name.clone();
    member_key.reset_id();
    let me = Participant::new(
        full_name,
        self.common.settings.direction.to_participant_settings(),
    );
    self.group_list.insert(member_key, me);
    Ok(())
}
/// Placeholder for acknowledging an MLS proposal; not implemented yet.
///
/// # Panics
///
/// Always panics via `todo!()`.
// dead_code is allowed because nothing in this file calls this stub yet.
#[allow(dead_code)]
async fn ack_msl_proposal(&mut self, _msg: &Message) -> Result<(), SessionError> {
    todo!()
}
/// Placeholder for handling an incoming MLS proposal; not implemented yet.
///
/// # Panics
///
/// Always panics via `todo!()`.
// dead_code is allowed because nothing in this file calls this stub yet.
#[allow(dead_code)]
async fn on_mls_proposal(&mut self, _msg: Message) -> Result<(), SessionError> {
    todo!()
}
/// Asks the session layer to delete this session.
///
/// Best effort: a failed send is logged with its full error chain and
/// otherwise ignored.
async fn send_close_signal(&mut self) {
    debug!("Signal session layer to close the session, all tasks are done");
    let delete = SessionMessage::DeleteSession {
        session_id: self.common.settings.id,
    };
    if let Err(e) = self
        .common
        .settings
        .tx_to_session_layer
        .send(Ok(delete))
        .await
    {
        tracing::error!(error = %e.chain(), "an error occurred while signaling session close");
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::Direction;
use crate::common::OutboundMessage;
use crate::session_config::SessionConfig;
use crate::session_settings::SessionSettings;
use crate::test_utils::{MockInnerHandler, MockTokenProvider, MockVerifier};
use slim_datapath::Status;
use slim_datapath::api::{CommandPayload, ParticipantSettings, ProtoSessionType};
use tokio::sync::mpsc;
/// Drives `fut` to completion while answering every message it emits on
/// `rx_slim` that carries a subscription id with a positive subscription ack.
///
/// Presumably this lets subscription calls that wait for an ack (e.g. inside
/// `join`) complete in tests. TODO confirm against `SubscriptionManager`.
// NOTE(review): once `rx_slim` is closed, `recv()` yields `None` immediately and
// this loop spins until `fut` completes.
async fn run_with_acks<F, T>(
    fut: F,
    rx_slim: &mut mpsc::Receiver<Result<Message, Status>>,
    sub_mgr: &crate::subscription_manager::SubscriptionManager,
) -> T
where
    F: std::future::Future<Output = T>,
{
    let mut pinned = Box::pin(fut);
    loop {
        tokio::select! {
            res = &mut pinned => return res,
            msg = rx_slim.recv() => {
                if let Some(Ok(msg)) = msg && let Some(ack_id) = msg.get_subscription_id() {
                    let ack = Message::builder().build_subscription_ack(ack_id, true, "");
                    sub_mgr.resolve_ack(ack.get_subscription_ack());
                }
            }
        }
    }
}
/// Builds a three-component test name with id 0.
fn make_name(parts: &[&str; 3]) -> ProtoName {
    let [org, namespace, app] = *parts;
    ProtoName::from_strings([org, namespace, app]).with_id(0)
}
/// Builds a multicast moderator without MLS, wired to mock identity and a
/// mock inner handler.
///
/// Returns the moderator, the receiver of messages sent towards SLIM, and the
/// receiver of messages sent to the session layer. The app and session-loop
/// receivers are dropped when this function returns.
fn setup_moderator() -> (
    SessionModerator<MockTokenProvider, MockVerifier, MockInnerHandler>,
    mpsc::Receiver<Result<Message, Status>>,
    mpsc::Receiver<Result<SessionMessage, SessionError>>,
) {
    let (tx_slim, rx_slim) = mpsc::channel(16);
    let (tx_app, _rx_app) = mpsc::unbounded_channel();
    let (tx_session, _rx_session) = mpsc::channel(16);
    let (tx_session_layer, rx_session_layer) = mpsc::channel(16);

    let channel = ["channel", "name", "v1"];
    let subscription_manager =
        crate::subscription_manager::SubscriptionManager::new(tx_slim.clone());

    let settings = SessionSettings {
        id: 1,
        source: make_name(&["local", "moderator", "v1"]).with_id(100),
        destination: make_name(&channel).with_id(NameId::DATA_CHANNEL_ID),
        control: make_name(&channel).with_id(NameId::CONTROL_CHANNEL_ID),
        config: SessionConfig {
            session_type: ProtoSessionType::Multicast,
            max_retries: Some(3),
            interval: Some(std::time::Duration::from_secs(1)),
            mls_settings: None,
            initiator: true,
            metadata: Default::default(),
        },
        direction: Direction::Bidirectional,
        slim_tx: tx_slim,
        app_tx: tx_app,
        tx_session,
        tx_to_session_layer: tx_session_layer,
        identity_provider: MockTokenProvider,
        identity_verifier: MockVerifier,
        graceful_shutdown_timeout: None,
        subscription_manager,
        service_id: String::new(),
    };

    (
        SessionModerator::new(MockInnerHandler::new(), settings),
        rx_slim,
        rx_session_layer,
    )
}
/// A freshly built moderator is idle: no tasks, no MLS, empty group, not subscribed.
#[tokio::test]
async fn test_moderator_new() {
    let (moderator, _rx_slim, _rx_session_layer) = setup_moderator();
    assert!(moderator.tasks_todo.is_empty());
    assert!(moderator.current_task.is_none());
    assert!(moderator.mls_state.is_none());
    assert!(moderator.group_list.is_empty());
    assert!(moderator.postponed_message.is_none());
    assert!(!moderator.subscribed);
}
/// `init` succeeds and leaves MLS disabled when the config has no MLS settings.
#[tokio::test]
async fn test_moderator_init() {
    let (mut moderator, _rx_slim, _rx_session_layer) = setup_moderator();
    let result = moderator.init().await;
    assert!(result.is_ok());
    // setup_moderator uses `mls_settings: None`.
    assert!(moderator.mls_state.is_none());
}
/// A discovery request on an idle moderator opens an Add task and emits the
/// discovery request towards SLIM.
#[tokio::test]
async fn test_moderator_discovery_request_starts_task() {
    let (mut moderator, _rx_slim, _rx_session_layer) = setup_moderator();
    moderator.init().await.unwrap();
    let source = make_name(&["requester", "app", "v1"]).with_id(300);
    let destination = moderator.common.settings.source.clone();
    let discovery_msg = Message::builder()
        .source(source.clone())
        .destination(destination)
        .identity("")
        .forward_to(0)
        .incoming_conn(12345)
        .session_type(ProtoSessionType::Multicast)
        .session_message_type(ProtoSessionMessageType::DiscoveryRequest)
        .session_id(1)
        .message_id(100)
        .payload(CommandPayload::builder().discovery_request().as_content())
        .build_publish()
        .unwrap();
    let result = moderator.on_discovery_request(discovery_msg, None).await;
    assert!(result.is_ok());
    assert!(moderator.current_task.is_some());
    assert!(matches!(
        moderator.current_task,
        Some(ModeratorTask::Add(_))
    ));
    let output = result.unwrap();
    assert!(!output.is_empty());
    // The first outbound message must be the discovery request itself.
    let msg = match &output.messages[0] {
        OutboundMessage::ToSlim(m) => m,
        _ => panic!("Expected ToSlim message"),
    };
    assert_eq!(
        msg.get_session_header().session_message_type(),
        ProtoSessionMessageType::DiscoveryRequest
    );
}
/// A discovery request arriving while a task is running is queued, not started.
#[tokio::test]
async fn test_moderator_discovery_request_when_busy() {
    let (mut moderator, _rx_slim, _rx_session_layer) = setup_moderator();
    moderator.init().await.unwrap();
    // Simulate an in-progress task.
    moderator.current_task = Some(ModeratorTask::Add(AddParticipant::new(None)));
    let source = make_name(&["requester", "app", "v1"]).with_id(300);
    let destination = moderator.common.settings.source.clone();
    let discovery_msg = Message::builder()
        .source(source.clone())
        .destination(destination)
        .identity("")
        .forward_to(0)
        .incoming_conn(12345)
        .session_type(ProtoSessionType::Multicast)
        .session_message_type(ProtoSessionMessageType::DiscoveryRequest)
        .session_id(1)
        .message_id(100)
        .payload(CommandPayload::builder().discovery_request().as_content())
        .build_publish()
        .unwrap();
    let result = moderator.on_discovery_request(discovery_msg, None).await;
    assert!(result.is_ok());
    assert_eq!(moderator.tasks_todo.len(), 1);
}
/// The moderator rejects an incoming `JoinRequest`: only participants receive it.
// NOTE(review): despite the "passthrough" name, this asserts an error, because
// `process_control_message` maps JoinRequest to SessionMessageTypeUnexpected.
#[tokio::test]
async fn test_moderator_join_request_passthrough() {
    let (mut moderator, _rx_slim, _rx_session_layer) = setup_moderator();
    moderator.init().await.unwrap();
    let source = make_name(&["requester", "app", "v1"]).with_id(300);
    let destination = moderator.common.settings.source.clone();
    let join_msg = Message::builder()
        .source(source.clone())
        .destination(destination.clone())
        .identity("")
        .forward_to(0)
        .incoming_conn(12345)
        .session_type(ProtoSessionType::Multicast)
        .session_message_type(ProtoSessionMessageType::JoinRequest)
        .session_id(1)
        .message_id(100)
        .payload(
            CommandPayload::builder()
                .join_request(Some(3), Some(std::time::Duration::from_secs(1)), None, None)
                .as_content(),
        )
        .build_publish()
        .unwrap();
    let result = moderator.process_control_message(join_msg, None).await;
    assert!(result.is_err());
}
/// A southbound application (non-command) message is forwarded to the inner handler.
#[tokio::test]
async fn test_moderator_application_message_forwarding() {
    let (mut moderator, _rx_slim, _rx_session_layer) = setup_moderator();
    moderator.init().await.unwrap();
    let source = moderator.common.settings.source.clone();
    let destination = moderator.common.settings.destination.clone();
    let app_msg = Message::builder()
        .source(source)
        .destination(destination)
        .identity("")
        .forward_to(0)
        .session_type(ProtoSessionType::Multicast)
        .session_message_type(ProtoSessionMessageType::Msg)
        .session_id(1)
        .message_id(100)
        .application_payload("application/octet-stream", vec![1, 2, 3, 4])
        .build_publish()
        .unwrap();
    let result = moderator
        .on_message(SessionMessage::OnMessage {
            message: app_msg,
            direction: MessageDirection::South,
            ack_tx: None,
        })
        .await;
    assert!(result.is_ok());
    assert_eq!(moderator.inner.get_messages_count().await, 1);
}
#[tokio::test]
async fn test_moderator_add_and_remove_endpoint() {
    // Adding an endpoint and then removing it is reflected in the inner
    // handler's added/removed counters.
    let (mut moderator, _rx_slim, _rx_session_layer) = setup_moderator();
    moderator.init().await.unwrap();

    let name = make_name(&["participant", "app", "v1"]).with_id(400);
    let participant = Participant::new(name.clone(), ParticipantSettings::bidirectional());

    let added = moderator.add_endpoint(&participant).await;
    assert!(added.is_ok());
    assert_eq!(moderator.inner.get_endpoints_added_count().await, 1);

    moderator.remove_endpoint(&name);
    // Yield briefly before reading the counter; the removal may be
    // completed asynchronously.
    tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
    assert_eq!(moderator.inner.get_endpoints_removed_count().await, 1);
}
#[tokio::test]
async fn test_moderator_join_sets_subscribed() {
    // A successful `join` flips `subscribed` from false to true and leaves
    // the group list non-empty.
    let (mut moderator, mut rx_slim, _rx_session_layer) = setup_moderator();
    moderator.init().await.unwrap();
    assert!(!moderator.subscribed);

    let subscriptions = moderator.common.settings.subscription_manager.clone();
    let peer = make_name(&["remote", "app", "v1"]).with_id(200);

    let joined = run_with_acks(moderator.join(peer, 12345), &mut rx_slim, &subscriptions).await;

    assert!(joined.is_ok());
    assert!(moderator.subscribed);
    assert!(!moderator.group_list.is_empty());
}
#[tokio::test]
async fn test_moderator_join_only_once() {
    // Joining the same remote a second time must not put another message on
    // the SLIM channel.
    let (mut moderator, mut rx_slim, _rx_session_layer) = setup_moderator();
    moderator.init().await.unwrap();

    let subscriptions = moderator.common.settings.subscription_manager.clone();
    let peer = make_name(&["remote", "app", "v1"]).with_id(200);

    run_with_acks(moderator.join(peer.clone(), 12345), &mut rx_slim, &subscriptions)
        .await
        .unwrap();

    // The repeat join is awaited directly, without the ack helper.
    moderator.join(peer, 12345).await.unwrap();

    // Nothing new should be waiting on the SLIM channel.
    let leftover = rx_slim.try_recv();
    assert!(leftover.is_err());
}
#[tokio::test]
async fn test_moderator_on_shutdown() {
    // After `init`, shutting down succeeds and leaves the moderator
    // unsubscribed.
    let (mut moderator, _rx_slim, _rx_session_layer) = setup_moderator();
    moderator.init().await.unwrap();

    let shutdown = moderator.on_shutdown().await;

    assert!(shutdown.is_ok());
    assert!(!moderator.subscribed);
}
#[tokio::test]
async fn test_moderator_delete_all_creates_leave_tasks() {
    // Seeds the group list with three participants, runs `delete_all`, and
    // checks that `mls_state` is `None` afterwards.
    //
    // The `Result` of `delete_all` is deliberately not asserted: asserting
    // `is_ok() || is_err()` can never fail, so it verified nothing.
    // NOTE(review): despite its name, this test does not check that leave
    // tasks were created; consider asserting on `current_task`/`tasks_todo`
    // once the expected behavior is pinned down.
    let (mut moderator, _rx_slim, _rx_session_layer) = setup_moderator();
    moderator.init().await.unwrap();

    // The map key is the id-less name; the stored participant carries the id.
    for (label, id) in [
        ("participant1", 401u128),
        ("participant2", 402u128),
        ("participant3", 403u128),
    ] {
        moderator.group_list.insert(
            make_name(&[label, "app", "v1"]),
            Participant::new(
                make_name(&[label, "app", "v1"]).with_id(id),
                ParticipantSettings::bidirectional(),
            ),
        );
    }

    let _outcome = moderator.delete_all(None).await;
    assert!(moderator.mls_state.is_none());
}
#[tokio::test]
async fn test_moderator_timer_timeout_for_control_message() {
    // Smoke test: a DiscoveryRequest timer timeout for a message id (100)
    // that this test never sent must be handled without panicking.
    //
    // The returned `Result` is deliberately not asserted. Asserting
    // `is_ok() || is_err()` is a tautology and verified nothing.
    // NOTE(review): pin the expected outcome here once it is specified.
    let (mut moderator, _rx_slim, _rx_session_layer) = setup_moderator();
    moderator.init().await.unwrap();

    let timeout = SessionMessage::TimerTimeout {
        message_id: 100,
        message_type: ProtoSessionMessageType::DiscoveryRequest,
        name: None,
        timeouts: 1,
    };
    let _outcome = moderator.on_message(timeout).await;
}
#[tokio::test]
async fn test_moderator_timer_timeout_for_app_message() {
    // A timer timeout for an application (`Msg`) message reaches the inner
    // handler, which records exactly one message.
    let (mut moderator, _rx_slim, _rx_session_layer) = setup_moderator();
    moderator.init().await.unwrap();

    let timeout = SessionMessage::TimerTimeout {
        message_id: 100,
        message_type: ProtoSessionMessageType::Msg,
        name: None,
        timeouts: 1,
    };
    let outcome = moderator.on_message(timeout).await;

    assert!(outcome.is_ok());
    assert_eq!(moderator.inner.get_messages_count().await, 1);
}
#[tokio::test]
async fn test_moderator_point_to_point_destination_update() {
    // Builds a PointToPoint moderator whose control name equals its
    // destination. It then checks that a southbound application message is
    // accepted by `on_message`.
    //
    // NOTE(review): despite its name, this test does not inspect the
    // destination of the forwarded message. Add an assertion once the expected
    // destination rewrite is known.
    let source = make_name(&["local", "app", "v1"]).with_id(100);
    let destination = make_name(&["remote", "app", "v1"]).with_id(200);
    let identity_provider = MockTokenProvider;
    let identity_verifier = MockVerifier;
    let (tx_slim, _rx_slim) = mpsc::channel(16);
    let (tx_app, _rx_app) = mpsc::unbounded_channel();
    let (tx_session, _rx_session) = mpsc::channel(16);
    let (tx_session_layer, _rx_session_layer) = mpsc::channel(16);
    let subscription_manager =
        crate::subscription_manager::SubscriptionManager::new(tx_slim.clone());
    let config = SessionConfig {
        session_type: ProtoSessionType::PointToPoint,
        max_retries: Some(3),
        interval: Some(std::time::Duration::from_secs(1)),
        mls_settings: None,
        initiator: true,
        metadata: Default::default(),
    };
    let settings = SessionSettings {
        id: 1,
        source: source.clone(),
        destination: destination.clone(),
        control: destination.clone(),
        config,
        direction: Direction::Bidirectional,
        slim_tx: tx_slim,
        app_tx: tx_app,
        tx_session,
        tx_to_session_layer: tx_session_layer,
        identity_provider,
        identity_verifier,
        graceful_shutdown_timeout: None,
        subscription_manager,
        service_id: String::new(),
    };
    let mut moderator = SessionModerator::new(MockInnerHandler::new(), settings);
    moderator.init().await.unwrap();

    let app_msg = Message::builder()
        .source(source)
        .destination(destination)
        .identity("")
        .forward_to(0)
        .session_type(ProtoSessionType::PointToPoint)
        .session_message_type(ProtoSessionMessageType::Msg)
        .session_id(1)
        .message_id(100)
        .application_payload("application/octet-stream", vec![1, 2, 3])
        .build_publish()
        .unwrap();

    let result = moderator
        .on_message(SessionMessage::OnMessage {
            message: app_msg,
            direction: MessageDirection::South,
            ack_tx: None,
        })
        .await;
    assert!(result.is_ok());
}
#[tokio::test]
async fn test_moderator_graceful_leave_with_two_participants() {
    // Scenario:
    // - The group has exactly two entries.
    // - A participant sends a LeaveRequest tagged with LEAVING_SESSION.
    // The moderator must surface `ParticipantDisconnected` on the app output
    // and switch to the Draining state.
    let moderator_name = ProtoName::from_strings(["agntcy", "ns", "moderator"]).with_id(100);
    let channel = ProtoName::from_strings(["agntcy", "ns", "chat"]);
    let control_channel =
        ProtoName::from_strings(["agntcy", "ns", "chat"]).with_id(NameId::CONTROL_CHANNEL_ID);

    let (tx_slim, mut rx_slim) = mpsc::channel(16);
    let (tx_app, _rx_app) = mpsc::unbounded_channel();
    let (tx_session, _rx_session) = mpsc::channel(16);
    let (tx_session_layer, _rx_session_layer) = mpsc::channel(16);
    let subscription_manager =
        crate::subscription_manager::SubscriptionManager::new(tx_slim.clone());

    let settings = SessionSettings {
        id: 1,
        source: moderator_name.clone(),
        destination: channel,
        control: control_channel,
        config: SessionConfig {
            session_type: ProtoSessionType::Multicast,
            max_retries: Some(3),
            interval: Some(std::time::Duration::from_secs(1)),
            mls_settings: None,
            initiator: true,
            metadata: Default::default(),
        },
        direction: Direction::Bidirectional,
        slim_tx: tx_slim,
        app_tx: tx_app,
        tx_session,
        tx_to_session_layer: tx_session_layer,
        identity_provider: MockTokenProvider,
        identity_verifier: MockVerifier,
        graceful_shutdown_timeout: None,
        subscription_manager,
        service_id: String::new(),
    };
    let mut moderator = SessionModerator::new(MockInnerHandler::new(), settings);
    moderator.init().await.unwrap();

    let remote = ProtoName::from_strings(["agntcy", "ns", "participant"]).with_id(200);
    let subscriptions = moderator.common.settings.subscription_manager.clone();
    run_with_acks(moderator.join(remote, 12345), &mut rx_slim, &subscriptions)
        .await
        .unwrap();

    // The participant is built from the name as created. It is keyed in the
    // group list by that name after `reset_id()`.
    let mut member_key = ProtoName::from_strings(["agntcy", "ns", "participant"]);
    let member = Participant::new(member_key.clone(), ParticipantSettings::bidirectional());
    member_key.reset_id();
    moderator.group_list.insert(member_key.clone(), member);

    assert_eq!(
        moderator.group_list.len(),
        2,
        "Should have exactly 2 participants"
    );
    assert_eq!(moderator.processing_state(), ProcessingState::Active);

    let leaver = member_key.with_id(401u128);
    let mut leave_request = Message::builder()
        .source(leaver)
        .destination(moderator_name)
        .identity("")
        .forward_to(0)
        .incoming_conn(12345)
        .session_type(ProtoSessionType::Multicast)
        .session_message_type(ProtoSessionMessageType::LeaveRequest)
        .session_id(1)
        .message_id(100)
        .payload(CommandPayload::builder().leave_request().as_content())
        .build_publish()
        .unwrap();
    leave_request.insert_metadata(LEAVING_SESSION.to_string(), TRUE_VAL.to_string());

    let handled = moderator
        .on_disconnection_detected(leave_request, None)
        .await;
    assert!(handled.is_ok(), "Should succeed with open app channel");
    let output = handled.unwrap();

    let to_app = output
        .messages
        .iter()
        .find(|m| matches!(m, OutboundMessage::ToApp(_)));
    assert!(to_app.is_some(), "Expected error to be sent to app output");

    let disconnected = match to_app {
        Some(OutboundMessage::ToApp(Err(SessionError::ParticipantDisconnected(name)))) => name,
        _ => panic!("Expected ParticipantDisconnected error"),
    };
    let rendered = disconnected.to_string();
    assert!(
        rendered.contains("agntcy/ns/participant"),
        "Error message should mention the participant, got: {}",
        rendered
    );

    assert_eq!(
        moderator.processing_state(),
        ProcessingState::Draining,
        "Session should be in draining state"
    );
}
#[tokio::test]
async fn test_moderator_concurrent_leave_requests() {
    // Two leave requests (LeaveRequest + LEAVING_SESSION metadata) arrive back
    // to back.
    // - The first starts a moderator task and removes its sender from the
    //   group list.
    // - The second is queued in `tasks_todo`, tagged with
    //   DISCONNECTION_DETECTED, and its sender is still in the group.
    let source = ProtoName::from_strings(["agntcy", "ns", "moderator"]).with_id(100);
    let destination =
        ProtoName::from_strings(["agntcy", "ns", "chat"]).with_id(NameId::DATA_CHANNEL_ID);
    let control =
        ProtoName::from_strings(["agntcy", "ns", "chat"]).with_id(NameId::CONTROL_CHANNEL_ID);
    let identity_provider = MockTokenProvider;
    let identity_verifier = MockVerifier;
    let (tx_slim, mut rx_slim) = mpsc::channel(16);
    let (tx_app, _rx_app) = mpsc::unbounded_channel();
    let (tx_session, _rx_session) = mpsc::channel(16);
    let (tx_session_layer, _rx_session_layer) = mpsc::channel(16);
    let subscription_manager =
        crate::subscription_manager::SubscriptionManager::new(tx_slim.clone());
    let config = SessionConfig {
        session_type: ProtoSessionType::Multicast,
        max_retries: Some(3),
        interval: Some(std::time::Duration::from_secs(1)),
        mls_settings: None,
        initiator: true,
        metadata: Default::default(),
    };
    let settings = SessionSettings {
        id: 1,
        source: source.clone(),
        destination,
        control,
        config,
        direction: Direction::Bidirectional,
        slim_tx: tx_slim,
        app_tx: tx_app,
        tx_session,
        tx_to_session_layer: tx_session_layer,
        identity_provider,
        identity_verifier,
        graceful_shutdown_timeout: None,
        subscription_manager,
        service_id: String::new(),
    };
    let mut moderator = SessionModerator::new(MockInnerHandler::new(), settings);
    moderator.init().await.unwrap();

    let remote = ProtoName::from_strings(["agntcy", "ns", "participant1"]).with_id(200);
    let sub_mgr = moderator.common.settings.subscription_manager.clone();
    run_with_acks(moderator.join(remote, 12345), &mut rx_slim, &sub_mgr)
        .await
        .unwrap();

    // Register three members. Each `Participant` is built from the name as
    // created. The group-list key is that name after `reset_id()`.
    let mut participant1_name = ProtoName::from_strings(["agntcy", "ns", "participant1"]);
    let mut participant2_name = ProtoName::from_strings(["agntcy", "ns", "participant2"]);
    let mut participant3_name = ProtoName::from_strings(["agntcy", "ns", "participant3"]);
    let participant1 =
        Participant::new(participant1_name.clone(), ParticipantSettings::bidirectional());
    let participant2 =
        Participant::new(participant2_name.clone(), ParticipantSettings::bidirectional());
    let participant3 =
        Participant::new(participant3_name.clone(), ParticipantSettings::bidirectional());
    participant1_name.reset_id();
    participant2_name.reset_id();
    participant3_name.reset_id();
    moderator
        .group_list
        .insert(participant1_name.clone(), participant1);
    moderator
        .group_list
        .insert(participant2_name.clone(), participant2);
    moderator.group_list.insert(participant3_name, participant3);

    let mut leave_msg1 = Message::builder()
        .source(participant1_name.clone().with_id(401))
        .destination(source.clone())
        .identity("")
        .forward_to(0)
        .incoming_conn(12345)
        .session_type(ProtoSessionType::Multicast)
        .session_message_type(ProtoSessionMessageType::LeaveRequest)
        .session_id(1)
        .message_id(101)
        .payload(CommandPayload::builder().leave_request().as_content())
        .build_publish()
        .unwrap();
    leave_msg1.insert_metadata(LEAVING_SESSION.to_string(), TRUE_VAL.to_string());

    let mut leave_msg2 = Message::builder()
        .source(participant2_name.clone().with_id(402))
        .destination(source)
        .identity("")
        .forward_to(0)
        .incoming_conn(12345)
        .session_type(ProtoSessionType::Multicast)
        .session_message_type(ProtoSessionMessageType::LeaveRequest)
        .session_id(1)
        .message_id(102)
        .payload(CommandPayload::builder().leave_request().as_content())
        .build_publish()
        .unwrap();
    leave_msg2.insert_metadata(LEAVING_SESSION.to_string(), TRUE_VAL.to_string());

    // For the first request only its side effect (a started task) is checked.
    // Its `Result` is deliberately not asserted: `is_ok() || is_err()` can
    // never fail.
    let _result1 = moderator.on_disconnection_detected(leave_msg1, None).await;
    assert!(
        moderator.current_task.is_some(),
        "First leave should create a task"
    );

    let result2 = moderator.on_disconnection_detected(leave_msg2, None).await;
    assert!(result2.is_ok());
    assert_eq!(
        moderator.tasks_todo.len(),
        1,
        "Second leave request should be queued while first is processing"
    );
    if let Some((queued_msg, _)) = moderator.tasks_todo.front() {
        assert!(
            queued_msg.contains_metadata(DISCONNECTION_DETECTED),
            "Queued message should have DISCONNECTION_DETECTED metadata"
        );
    } else {
        panic!("Expected queued task for participant2");
    }

    // Discard anything the moderator sent toward SLIM while handling the requests.
    while rx_slim.try_recv().is_ok() {}

    assert!(
        !moderator.group_list.contains_key(&participant1_name),
        "Participant1 should be removed after first leave request"
    );
    assert!(
        moderator.group_list.contains_key(&participant2_name),
        "Participant2 should still be in group (task queued, not processed)"
    );
}
#[tokio::test]
async fn test_group_ack_ignored_when_no_current_task() {
    // A GroupAck arriving while no task is in progress is accepted and does
    // not cause a task to be created.
    let (mut moderator, _rx_slim, _rx_session_layer) = setup_moderator();
    moderator.init().await.unwrap();
    assert!(moderator.current_task.is_none());

    let stray_ack = Message::builder()
        .source(make_name(&["participant", "app", "v1"]).with_id(300))
        .destination(moderator.common.settings.source.clone())
        .identity("")
        .forward_to(0)
        .incoming_conn(12345)
        .session_type(ProtoSessionType::Multicast)
        .session_message_type(ProtoSessionMessageType::GroupAck)
        .session_id(1)
        .message_id(999)
        .payload(CommandPayload::builder().group_ack().as_content())
        .build_publish()
        .unwrap();

    let outcome = moderator.process_control_message(stray_ack, None).await;

    assert!(outcome.is_ok());
    assert!(moderator.current_task.is_none());
}
}