use async_trait::async_trait;
use super::command_message::{CommandMessage, MessageType};
use super::error::IpcError;
/// Contract implemented by every module that participates in IPC message routing.
///
/// Implementors must provide [`handle_message`](Self::handle_message), the two
/// lifecycle hooks ([`on_initialize`](Self::on_initialize) and
/// [`on_finalize`](Self::on_finalize)) and [`domain`](Self::domain). Every other
/// method ships with a no-op / empty default that can be overridden as needed.
#[async_trait]
pub trait ModuleHandler: Send + Sync {
    /// Handles a message addressed to this module and returns the reply to send back.
    async fn handle_message(&mut self, msg: CommandMessage) -> CommandMessage;

    /// Lifecycle hook run when the module receives an `initialize` control command.
    async fn on_initialize(&mut self) -> anyhow::Result<()>;

    /// Lifecycle hook run when the module receives a `finalize` control command.
    async fn on_finalize(&mut self) -> anyhow::Result<()>;

    /// Called when a subscriber registers for `_topic`. Default: accept and do nothing.
    async fn on_subscribe(&mut self, _topic: &str, _subscriber_id: &str) -> anyhow::Result<()> {
        Ok(())
    }

    /// Called when a subscriber leaves `_topic`. Default: accept and do nothing.
    async fn on_unsubscribe(&mut self, _topic: &str, _subscriber_id: &str) -> anyhow::Result<()> {
        Ok(())
    }

    /// Called for every heartbeat message. Default: do nothing.
    async fn on_heartbeat(&mut self) -> anyhow::Result<()> {
        Ok(())
    }

    /// Domain name identifying this module.
    fn domain(&self) -> &str;

    /// Version string of the module; defaults to `"1.0.0"`.
    fn version(&self) -> &str {
        "1.0.0"
    }

    /// Capability names advertised by the module; empty by default.
    fn capabilities(&self) -> Vec<String> {
        vec![]
    }

    /// Names of the commands/entries this module exposes; empty by default.
    fn get_catalog(&self) -> Vec<String> {
        vec![]
    }

    /// Names of shared-memory variables the module wants; empty by default.
    fn shm_variable_names(&self) -> Vec<String> {
        vec![]
    }

    /// Receives the configured shared-memory map. Default: ignore it and succeed.
    async fn on_shm_configured(&mut self, _shm_map: crate::shm::ShmMap) -> anyhow::Result<()> {
        Ok(())
    }
}
/// Protocol-level dispatch layered on top of [`ModuleHandler`].
///
/// Implemented for every `ModuleHandler` by a blanket impl in this module.
#[async_trait]
pub trait ModuleHandlerExt: ModuleHandler {
    /// Routes one incoming message to the appropriate handler hook.
    ///
    /// Returns the reply message paired with a flag telling the caller whether
    /// the module should stop (the blanket implementation sets it only after a
    /// `finalize` control command). Hook failures surface as `IpcError`.
    async fn process_message(
        &mut self,
        message: CommandMessage,
    ) -> Result<(CommandMessage, bool), IpcError>;
}
#[async_trait]
impl<T: ModuleHandler + ?Sized> ModuleHandlerExt for T {
async fn process_message(&mut self, msg: CommandMessage) -> Result<(CommandMessage, bool), IpcError> {
match msg.message_type {
MessageType::NoOp => {
Ok((msg.into_response(serde_json::Value::Null), false))
}
MessageType::Heartbeat => {
self.on_heartbeat().await.map_err(|e| IpcError::Handler(e.to_string()))?;
Ok((CommandMessage::heartbeat(), false))
}
MessageType::Control => {
let subtopic = msg.subtopic();
let control_type = msg.data.get("action")
.and_then(|a| a.as_str())
.unwrap_or(&subtopic);
match control_type {
"initialize" => {
self.on_initialize().await.map_err(|e| IpcError::Handler(e.to_string()))?;
Ok((msg.into_response(serde_json::Value::Null), false))
}
"finalize" => {
log::info!("Received finalize command, shutting down...");
self.on_finalize().await.map_err(|e| IpcError::Handler(e.to_string()))?;
Ok((msg.into_response(serde_json::json!({"finalized": true})), true))
}
_ => {
let response = self.handle_message(msg).await;
Ok((response, false))
}
}
}
MessageType::Subscribe => {
let topic = &msg.topic;
let subscriber = msg.data.get("subscriber")
.and_then(|v| v.as_str())
.unwrap_or("unknown");
self.on_subscribe(topic, subscriber).await
.map_err(|e| IpcError::Handler(e.to_string()))?;
Ok((msg.into_response(serde_json::Value::Null), false))
}
MessageType::Unsubscribe => {
let topic = &msg.topic;
let subscriber = msg.data.get("subscriber")
.and_then(|v| v.as_str())
.unwrap_or("unknown");
self.on_unsubscribe(topic, subscriber).await
.map_err(|e| IpcError::Handler(e.to_string()))?;
Ok((msg.into_response(serde_json::Value::Null), false))
}
MessageType::Request | MessageType::Read | MessageType::Write => {
let response = self.handle_message(msg).await;
Ok((response, false))
}
MessageType::Response | MessageType::Broadcast => {
Ok((msg, false))
}
}
}
}
/// Minimal, ready-to-use module handler that only knows how to report its catalog.
///
/// Build it with `BaseModuleHandler::new` plus the `with_*` builder methods, then
/// populate the catalog with `register_fqdn`. `Debug` and `Clone` are derived so
/// the handler can be logged and duplicated like any other plain-data type.
#[derive(Debug, Clone)]
pub struct BaseModuleHandler {
    /// Domain name reported by the handler.
    domain: String,
    /// Version string; `new` sets it to `"1.0.0"` and `with_version` overrides it.
    version: String,
    /// Capability names advertised by the handler.
    capabilities: Vec<String>,
    /// Fully-qualified names registered via `register_fqdn`; served for `get_catalog`.
    pub catalog: Vec<String>,
}
impl BaseModuleHandler {
    /// Creates a handler for `domain` with version `"1.0.0"`, no capabilities
    /// and an empty catalog.
    pub fn new(domain: &str) -> Self {
        Self {
            domain: String::from(domain),
            version: String::from("1.0.0"),
            capabilities: vec![],
            catalog: vec![],
        }
    }

    /// Builder step: replaces the version string, keeping every other field.
    pub fn with_version(self, version: &str) -> Self {
        Self {
            version: version.to_owned(),
            ..self
        }
    }

    /// Builder step: replaces the capability list, keeping every other field.
    pub fn with_capabilities(self, caps: Vec<String>) -> Self {
        Self {
            capabilities: caps,
            ..self
        }
    }

    /// Appends `fqdn` to the catalog (duplicates are not filtered).
    pub fn register_fqdn(&mut self, fqdn: String) {
        let catalog = &mut self.catalog;
        catalog.push(fqdn);
    }
}
#[async_trait]
impl ModuleHandler for BaseModuleHandler {
    /// Answers `get_catalog` with the registered catalog; every other command
    /// receives an error response stating it is not implemented.
    async fn handle_message(&mut self, msg: CommandMessage) -> CommandMessage {
        let subtopic = msg.subtopic().to_string();
        if subtopic == "get_catalog" {
            // Serialising a `Vec<String>` cannot fail; the `Null` fallback is defensive only.
            let catalog = serde_json::to_value(&self.catalog).unwrap_or(serde_json::Value::Null);
            return msg.into_response(catalog);
        }
        let error_msg = format!("Command '{}' not implemented", subtopic);
        msg.into_error_response(&error_msg)
    }

    /// Logs the initialization; has no other effect.
    async fn on_initialize(&mut self) -> Result<(), anyhow::Error> {
        log::info!("Module {} initialized", self.domain);
        Ok(())
    }

    /// Logs the finalization; has no other effect.
    async fn on_finalize(&mut self) -> Result<(), anyhow::Error> {
        log::info!("Module {} finalized", self.domain);
        Ok(())
    }

    fn domain(&self) -> &str {
        &self.domain
    }

    fn version(&self) -> &str {
        &self.version
    }

    fn capabilities(&self) -> Vec<String> {
        self.capabilities.clone()
    }

    /// Returns the registered catalog.
    ///
    /// Without this override the trait default returns an empty list, so trait
    /// callers would disagree with the `get_catalog` command answered above.
    fn get_catalog(&self) -> Vec<String> {
        self.catalog.clone()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal handler for exercising the `ModuleHandler` contract: it echoes the
    /// subtopic back and records whether the lifecycle hooks have run.
    struct TestModule {
        domain: String,
        initialized: bool,
    }

    impl TestModule {
        fn new(domain: &str) -> Self {
            TestModule {
                domain: domain.into(),
                initialized: false,
            }
        }
    }

    #[async_trait]
    impl ModuleHandler for TestModule {
        async fn handle_message(&mut self, msg: CommandMessage) -> CommandMessage {
            let echoed = msg.subtopic().to_string();
            let payload = serde_json::json!({"echo": echoed});
            msg.into_response(payload)
        }

        async fn on_initialize(&mut self) -> Result<(), anyhow::Error> {
            self.initialized = true;
            Ok(())
        }

        async fn on_finalize(&mut self) -> Result<(), anyhow::Error> {
            self.initialized = false;
            Ok(())
        }

        fn domain(&self) -> &str {
            &self.domain
        }
    }

    #[tokio::test]
    async fn test_module_handler() {
        let mut handler = TestModule::new("TEST");

        handler.on_initialize().await.unwrap();
        assert!(handler.initialized, "on_initialize should set the flag");

        let reply = handler
            .handle_message(CommandMessage::read("TEST.ping"))
            .await;
        assert!(reply.success);
        assert_eq!(reply.data["echo"], "ping");

        handler.on_finalize().await.unwrap();
        assert!(!handler.initialized, "on_finalize should clear the flag");
    }
}