use crate::errors::MessageBusError;
use tokio::sync::broadcast;
/// Broadcast hub that fans [`StageCommand`]s out to every subscribed stage.
///
/// Only the sending half of the channel is stored; receivers are handed out on demand.
pub struct FsmMessageBus {
    /// Sending half of the tokio broadcast channel that carries stage commands.
    stage_commands: broadcast::Sender<StageCommand>,
}
impl Default for FsmMessageBus {
fn default() -> Self {
Self::new()
}
}
impl FsmMessageBus {
    /// Number of commands the channel buffers when the bus is built via [`FsmMessageBus::new`].
    pub const DEFAULT_CAPACITY: usize = 16;

    /// Creates a bus buffering [`Self::DEFAULT_CAPACITY`] commands, with no subscribers yet.
    pub fn new() -> Self {
        Self::with_capacity(Self::DEFAULT_CAPACITY)
    }

    /// Creates a bus whose broadcast channel is built with the given `capacity`.
    ///
    /// The receiver created together with the channel is dropped immediately, so the bus
    /// starts with no subscribers. Call [`Self::subscribe_to_stage_commands`] before sending,
    /// or [`Self::send_stage_command`] will fail.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is `0` or larger than `usize::MAX / 2`, as documented for
    /// `tokio::sync::broadcast::channel`.
    pub fn with_capacity(capacity: usize) -> Self {
        let (stage_commands, _) = broadcast::channel(capacity);
        Self { stage_commands }
    }

    /// Returns a new receiver for commands sent after this call.
    ///
    /// Per tokio's broadcast semantics, a receiver that falls more than the channel
    /// capacity behind observes `RecvError::Lagged` instead of the missed commands.
    pub fn subscribe_to_stage_commands(&self) -> broadcast::Receiver<StageCommand> {
        self.stage_commands.subscribe()
    }

    /// Broadcasts `command` to every currently subscribed receiver.
    ///
    /// The body performs no `.await`. It is kept `async` so existing `.await` call sites
    /// keep compiling.
    ///
    /// # Errors
    ///
    /// Returns [`MessageBusError::NoStageReceivers`] when the send fails. It also logs the
    /// failure at `error` level. `broadcast::Sender::send` fails only when there are no
    /// active receivers.
    pub async fn send_stage_command(&self, command: StageCommand) -> Result<(), MessageBusError> {
        self.stage_commands
            .send(command)
            // On success `send` yields the number of receivers reached; callers only need `()`.
            .map(|_delivered| ())
            .map_err(|_| {
                // NOTE: this function is not `#[track_caller]` (not stable on `async fn`), so
                // `Location::caller()` resolves to this line inside the bus, not to the code
                // that called `send_stage_command`.
                tracing::error!(
                    location = %std::panic::Location::caller(),
                    "No stages listening for commands - receiver count: {}",
                    self.stage_commands.receiver_count()
                );
                MessageBusError::NoStageReceivers
            })
    }
}
/// Lifecycle command broadcast to every stage subscribed to an `FsmMessageBus`.
///
/// `PartialEq`/`Eq` are derived so commands can be compared directly (e.g. with `assert_eq!`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum StageCommand {
    Initialize,
    Start,
    BeginDrain,
    /// Carries a free-form `reason` string supplied by the sender.
    ForceShutdown { reason: String },
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh bus has no receivers; each subscription adds exactly one.
    #[tokio::test]
    async fn test_message_bus_creation() {
        let bus = FsmMessageBus::new();
        // The receiver returned by `broadcast::channel` is discarded in the constructor.
        assert_eq!(bus.stage_commands.receiver_count(), 0);
        let _sub1 = bus.subscribe_to_stage_commands();
        let _sub2 = bus.subscribe_to_stage_commands();
        assert_eq!(bus.stage_commands.receiver_count(), 2);
    }

    /// Every subscriber receives its own copy of a sent command.
    #[tokio::test]
    async fn test_stage_command_broadcast() {
        let bus = FsmMessageBus::new();
        let mut sub1 = bus.subscribe_to_stage_commands();
        let mut sub2 = bus.subscribe_to_stage_commands();
        bus.send_stage_command(StageCommand::Start).await.unwrap();
        let cmd1 = sub1.recv().await.unwrap();
        let cmd2 = sub2.recv().await.unwrap();
        assert!(matches!(cmd1, StageCommand::Start));
        assert!(matches!(cmd2, StageCommand::Start));
    }

    /// Sending with nobody subscribed surfaces `NoStageReceivers` instead of succeeding.
    #[tokio::test]
    async fn test_send_without_subscribers_fails() {
        let bus = FsmMessageBus::new();
        let result = bus.send_stage_command(StageCommand::Initialize).await;
        assert!(matches!(result, Err(MessageBusError::NoStageReceivers)));
    }

    /// Dropping a receiver removes it from the bus's receiver count.
    #[tokio::test]
    async fn test_dropped_subscriber_is_not_counted() {
        let bus = FsmMessageBus::new();
        let sub1 = bus.subscribe_to_stage_commands();
        let _sub2 = bus.subscribe_to_stage_commands();
        drop(sub1);
        assert_eq!(bus.stage_commands.receiver_count(), 1);
    }

    /// Variant payloads are delivered intact to subscribers.
    #[tokio::test]
    async fn test_force_shutdown_reason_delivered() {
        let bus = FsmMessageBus::new();
        let mut sub = bus.subscribe_to_stage_commands();
        bus.send_stage_command(StageCommand::ForceShutdown {
            reason: "test".to_string(),
        })
        .await
        .unwrap();
        match sub.recv().await.unwrap() {
            StageCommand::ForceShutdown { reason } => assert_eq!(reason, "test"),
            other => panic!("unexpected command: {other:?}"),
        }
    }
}