use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc, RwLock};
use uuid::Uuid;
use std::collections::HashMap;
/// Envelope for every message carried by [`UnifiedBus`]: one variant per message family,
/// each wrapping that family's payload struct.
///
/// NOTE(review): serde's derived impls identify enum variants by name in self-describing
/// formats such as JSON but by index in some binary formats, so reordering variants may
/// change the wire format there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum UnifiedMessage {
/// Message about a session; published on the `"session"` topic by `UnifiedBus::send`.
Session(SessionMessage),
/// Message exchanged between agents; published on the `"coordination"` topic.
Coordination(CoordinationMessage),
/// Message about a task; published on the `"task"` topic.
Task(TaskMessage),
/// Event notification; published on the `"event"` topic.
Event(EventMessage),
/// Point-to-point message; not published on any topic, only delivered to the agent
/// registered under its `to_agent` (plus the broadcast channel and history).
Direct(DirectMessage),
/// Loosely-typed message; published on the `"ipc"` topic.
Ipc(IpcMessage),
}
/// A message scoped to a single session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionMessage {
/// Unique message id; `UnifiedBus::create_session_message` fills it with a random UUID v4.
pub id: String,
/// Identifier of the session this message concerns.
pub session_id: String,
/// Kind of session message.
pub msg_type: SessionMessageType,
/// Free-form JSON payload; its schema is not defined in this file.
pub payload: serde_json::Value,
/// Creation time in UTC (`chrono::Utc::now()` when built through the bus helper).
pub timestamp: chrono::DateTime<chrono::Utc>,
}
/// Kind of a [`SessionMessage`].
///
/// NOTE(review): the precise meaning of each variant, and what its accompanying `payload`
/// contains, is defined by the producers of these messages, not in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SessionMessageType {
Created,
Started,
Stopped,
Output,
Input,
StatusChange,
}
/// A coordination message sent by one agent, optionally addressed to another.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoordinationMessage {
/// Unique message id; `UnifiedBus::create_coordination_message` uses a random UUID v4.
pub id: String,
/// Id of the sending agent.
pub from_agent: String,
/// Intended recipient, or `None` (presumably "all agents" — confirm with producers).
/// NOTE: `UnifiedBus::send` does not route on this field; coordination messages go to the
/// broadcast channel and the `"coordination"` topic regardless of its value.
pub to_agent: Option<String>,
/// Kind of coordination message.
pub msg_type: CoordinationMessageType,
/// Free-form JSON payload; its schema is not defined in this file.
pub payload: serde_json::Value,
/// Creation time in UTC (`chrono::Utc::now()` when built through the bus helper).
pub timestamp: chrono::DateTime<chrono::Utc>,
}
/// Kind of a [`CoordinationMessage`].
///
/// NOTE(review): variant semantics are defined by the agents exchanging these messages,
/// not in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CoordinationMessageType {
TaskAssignment,
TaskAccepted,
TaskRejected,
TaskCompleted,
StatusUpdate,
HelpRequest,
KnowledgeShare,
}
/// A message about a single task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskMessage {
/// Unique message id; `UnifiedBus::create_task_message` fills it with a random UUID v4.
pub id: String,
/// Identifier of the task this message concerns.
pub task_id: String,
/// Kind of task message.
pub msg_type: TaskMessageType,
/// Free-form JSON payload; its schema is not defined in this file.
pub payload: serde_json::Value,
/// Creation time in UTC (`chrono::Utc::now()` when built through the bus helper).
pub timestamp: chrono::DateTime<chrono::Utc>,
}
/// Kind of a [`TaskMessage`].
///
/// NOTE(review): variant semantics (and which payload each carries) are defined by the
/// producers of task messages, not in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskMessageType {
Created,
Assigned,
Started,
Progress,
Completed,
Failed,
Cancelled,
}
/// An event notification emitted by some component.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EventMessage {
/// Unique message id; `UnifiedBus::create_event_message` fills it with a random UUID v4.
pub id: String,
/// Free-form name of the emitter of this event.
pub source: String,
/// Kind of event.
pub event_type: EventType,
/// Free-form JSON payload; its schema is not defined in this file.
pub payload: serde_json::Value,
/// Creation time in UTC (`chrono::Utc::now()` when built through the bus helper).
pub timestamp: chrono::DateTime<chrono::Utc>,
}
/// Kind of an [`EventMessage`].
///
/// NOTE(review): when each variant is emitted is decided by producers, not in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EventType {
SystemStartup,
SystemShutdown,
AgentConnected,
AgentDisconnected,
Error,
Warning,
Info,
}
/// A point-to-point text message from one agent to another.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DirectMessage {
/// Unique message id; `UnifiedBus::create_direct_message` fills it with a random UUID v4.
pub id: String,
/// Id of the sending agent.
pub from_agent: String,
/// Id of the recipient; `UnifiedBus::send` looks this up among agents registered with
/// `register_agent` and silently drops the message if none matches.
pub to_agent: String,
/// Message text.
pub content: String,
/// Extra key/value data; `UnifiedBus::create_direct_message` leaves it empty.
pub metadata: HashMap<String, serde_json::Value>,
/// Creation time in UTC (`chrono::Utc::now()` when built through the bus helper).
pub timestamp: chrono::DateTime<chrono::Utc>,
}
/// A loosely-typed message whose kind is an arbitrary string rather than an enum.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IpcMessage {
/// Unique message id.
pub id: String,
/// Free-form message kind; the set of valid values is not defined in this file.
pub msg_type: String,
/// Free-form JSON payload; its schema is not defined in this file.
pub payload: serde_json::Value,
/// Creation time in UTC.
pub timestamp: chrono::DateTime<chrono::Utc>,
}
/// In-process message bus built on tokio channels.
///
/// Every sent [`UnifiedMessage`] goes to:
/// - a broadcast channel;
/// - the subscribers of the message's topic;
/// - for direct messages, a per-agent channel.
///
/// The bus also keeps a bounded history of sent messages.
pub struct UnifiedBus {
/// Sender side of the `subscribe_all` broadcast channel (capacity 1024, set in `new`).
broadcast_tx: broadcast::Sender<UnifiedMessage>,
/// Senders for direct messages, keyed by agent id (populated by `register_agent`).
direct_channels: Arc<RwLock<HashMap<String, mpsc::Sender<UnifiedMessage>>>>,
/// Senders for topic subscribers, keyed by topic name (populated by `subscribe_topic`).
topic_subscribers: Arc<RwLock<HashMap<String, Vec<mpsc::Sender<UnifiedMessage>>>>>,
/// Recently sent messages, oldest first, trimmed to at most `history_limit` entries.
message_history: Arc<RwLock<Vec<UnifiedMessage>>>,
/// Maximum number of messages retained in `message_history`.
history_limit: usize,
}
impl UnifiedBus {
    /// Creates a bus that retains at most `history_limit` recently sent messages.
    ///
    /// The broadcast channel behind [`UnifiedBus::subscribe_all`] has a fixed capacity of
    /// 1024. The bus keeps no broadcast receiver of its own.
    pub fn new(history_limit: usize) -> Self {
        let (broadcast_tx, _) = broadcast::channel(1024);
        Self {
            broadcast_tx,
            direct_channels: Arc::new(RwLock::new(HashMap::new())),
            topic_subscribers: Arc::new(RwLock::new(HashMap::new())),
            message_history: Arc::new(RwLock::new(Vec::new())),
            history_limit,
        }
    }

    /// Returns a receiver that observes every message passed to [`UnifiedBus::send`].
    ///
    /// Only messages sent after this call are received. Per tokio broadcast semantics, a
    /// receiver that falls more than the channel capacity behind gets `RecvError::Lagged`.
    pub fn subscribe_all(&self) -> broadcast::Receiver<UnifiedMessage> {
        self.broadcast_tx.subscribe()
    }

    /// Subscribes to `topic` and returns the receiving end of a new 256-slot channel.
    ///
    /// Non-direct messages are published on `"session"`, `"coordination"`, `"task"`,
    /// `"event"` or `"ipc"`. Any other topic name never receives anything.
    ///
    /// Dropping the receiver unsubscribes; its dead sender is pruned the next time a
    /// message is delivered to this topic.
    pub async fn subscribe_topic(&self, topic: &str) -> mpsc::Receiver<UnifiedMessage> {
        let (tx, rx) = mpsc::channel(256);
        let mut subscribers = self.topic_subscribers.write().await;
        subscribers.entry(topic.to_string()).or_default().push(tx);
        rx
    }

    /// Registers `agent_id` for direct messages and returns its 256-slot receiver.
    ///
    /// Registering an id that is already present replaces the previous sender for it.
    pub async fn register_agent(&self, agent_id: &str) -> mpsc::Receiver<UnifiedMessage> {
        let (tx, rx) = mpsc::channel(256);
        let mut channels = self.direct_channels.write().await;
        channels.insert(agent_id.to_string(), tx);
        rx
    }

    /// Removes the direct channel registered for `agent_id`, if any.
    ///
    /// Once any in-flight delivery finishes, the agent's receiver sees the channel closed.
    pub async fn unregister_agent(&self, agent_id: &str) {
        let mut channels = self.direct_channels.write().await;
        channels.remove(agent_id);
    }

    /// Records `message` in the history and delivers it to every interested party.
    ///
    /// Delivery order:
    /// 1. the history;
    /// 2. the `subscribe_all` broadcast;
    /// 3. the subscribers of the message's topic;
    /// 4. for `Direct` messages, the channel registered under `to_agent`.
    ///
    /// Delivery is best-effort: missing recipients and dropped receivers are ignored.
    ///
    /// This awaits while a topic subscriber's or the recipient's channel is full. No bus
    /// lock is held during that wait, so a slow subscriber cannot block (un)registration.
    ///
    /// # Errors
    ///
    /// Currently never returns `Err`.
    pub async fn send(&self, message: UnifiedMessage) -> Result<()> {
        self.record_history(&message).await;

        // An error here only means there are no live `subscribe_all` receivers.
        let _ = self.broadcast_tx.send(message.clone());

        if let Some(topic) = self.get_message_topic(&message) {
            self.deliver_to_topic(&topic, &message).await;
        }

        if let UnifiedMessage::Direct(ref msg) = message {
            // Clone the sender out so the read guard is dropped before awaiting. Holding it
            // across `send().await` would block `register_agent`/`unregister_agent` for as
            // long as the recipient's channel stays full.
            let recipient = self
                .direct_channels
                .read()
                .await
                .get(&msg.to_agent)
                .cloned();
            if let Some(tx) = recipient {
                // An error means the agent dropped its receiver without unregistering.
                let _ = tx.send(message).await;
            }
        }
        Ok(())
    }

    /// Appends a copy of `message` to the history, discarding the oldest entries beyond
    /// `history_limit`.
    async fn record_history(&self, message: &UnifiedMessage) {
        let mut history = self.message_history.write().await;
        history.push(message.clone());
        if history.len() > self.history_limit {
            let excess = history.len() - self.history_limit;
            history.drain(..excess);
        }
    }

    /// Sends `message` to every subscriber of `topic`, then prunes subscribers whose
    /// receiver has been dropped.
    async fn deliver_to_topic(&self, topic: &str, message: &UnifiedMessage) {
        // Snapshot the senders so no lock is held while awaiting channel capacity.
        let targets = self
            .topic_subscribers
            .read()
            .await
            .get(topic)
            .cloned()
            .unwrap_or_default();

        let mut saw_closed = false;
        for tx in &targets {
            if tx.send(message.clone()).await.is_err() {
                saw_closed = true;
            }
        }

        if saw_closed {
            // Without pruning, senders of dropped receivers would accumulate forever.
            // Subscribers added since the snapshot are still open and so are kept.
            let mut subscribers = self.topic_subscribers.write().await;
            let now_empty = match subscribers.get_mut(topic) {
                Some(subs) => {
                    subs.retain(|tx| !tx.is_closed());
                    subs.is_empty()
                }
                None => false,
            };
            if now_empty {
                subscribers.remove(topic);
            }
        }
    }

    /// Returns the topic a message is published on.
    ///
    /// Returns `None` for direct messages, which only go to their recipient.
    fn get_message_topic(&self, message: &UnifiedMessage) -> Option<String> {
        match message {
            UnifiedMessage::Session(_) => Some("session".to_string()),
            UnifiedMessage::Coordination(_) => Some("coordination".to_string()),
            UnifiedMessage::Task(_) => Some("task".to_string()),
            UnifiedMessage::Event(_) => Some("event".to_string()),
            UnifiedMessage::Direct(_) => None,
            UnifiedMessage::Ipc(_) => Some("ipc".to_string()),
        }
    }

    /// Returns a copy of the retained message history.
    ///
    /// - `None` returns every retained message, oldest first.
    /// - `Some(n)` returns at most the `n` most recent messages, **newest first**.
    // NOTE(review): the two orderings differ. Confirm whether callers rely on newest-first
    // for `Some(n)` before making them consistent.
    pub async fn get_history(&self, limit: Option<usize>) -> Vec<UnifiedMessage> {
        let history = self.message_history.read().await;
        match limit {
            Some(n) => history.iter().rev().take(n).cloned().collect(),
            None => history.clone(),
        }
    }

    /// Builds a [`UnifiedMessage::Session`] with a random UUID v4 `id` and the current UTC
    /// time as `timestamp`.
    pub fn create_session_message(
        session_id: &str,
        msg_type: SessionMessageType,
        payload: serde_json::Value,
    ) -> UnifiedMessage {
        UnifiedMessage::Session(SessionMessage {
            id: Uuid::new_v4().to_string(),
            session_id: session_id.to_string(),
            msg_type,
            payload,
            timestamp: chrono::Utc::now(),
        })
    }

    /// Builds a [`UnifiedMessage::Coordination`] with a random UUID v4 `id` and the current
    /// UTC time as `timestamp`.
    pub fn create_coordination_message(
        from_agent: &str,
        to_agent: Option<&str>,
        msg_type: CoordinationMessageType,
        payload: serde_json::Value,
    ) -> UnifiedMessage {
        UnifiedMessage::Coordination(CoordinationMessage {
            id: Uuid::new_v4().to_string(),
            from_agent: from_agent.to_string(),
            to_agent: to_agent.map(|s| s.to_string()),
            msg_type,
            payload,
            timestamp: chrono::Utc::now(),
        })
    }

    /// Builds a [`UnifiedMessage::Task`] with a random UUID v4 `id` and the current UTC
    /// time as `timestamp`.
    pub fn create_task_message(
        task_id: &str,
        msg_type: TaskMessageType,
        payload: serde_json::Value,
    ) -> UnifiedMessage {
        UnifiedMessage::Task(TaskMessage {
            id: Uuid::new_v4().to_string(),
            task_id: task_id.to_string(),
            msg_type,
            payload,
            timestamp: chrono::Utc::now(),
        })
    }

    /// Builds a [`UnifiedMessage::Event`] with a random UUID v4 `id` and the current UTC
    /// time as `timestamp`.
    pub fn create_event_message(
        source: &str,
        event_type: EventType,
        payload: serde_json::Value,
    ) -> UnifiedMessage {
        UnifiedMessage::Event(EventMessage {
            id: Uuid::new_v4().to_string(),
            source: source.to_string(),
            event_type,
            payload,
            timestamp: chrono::Utc::now(),
        })
    }

    /// Builds a [`UnifiedMessage::Direct`] with a random UUID v4 `id`, empty `metadata` and
    /// the current UTC time as `timestamp`.
    pub fn create_direct_message(
        from_agent: &str,
        to_agent: &str,
        content: &str,
    ) -> UnifiedMessage {
        UnifiedMessage::Direct(DirectMessage {
            id: Uuid::new_v4().to_string(),
            from_agent: from_agent.to_string(),
            to_agent: to_agent.to_string(),
            content: content.to_string(),
            metadata: HashMap::new(),
            timestamp: chrono::Utc::now(),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the `task_id` of each message, panicking on any non-task message.
    fn task_ids(messages: &[UnifiedMessage]) -> Vec<&str> {
        messages
            .iter()
            .map(|m| match m {
                UnifiedMessage::Task(t) => t.task_id.as_str(),
                other => panic!("unexpected message: {other:?}"),
            })
            .collect()
    }

    /// A session message reaches both `subscribe_all` and the `"session"` topic, and a
    /// direct message reaches its registered recipient.
    #[tokio::test]
    async fn test_unified_bus() {
        let bus = UnifiedBus::new(100);
        let mut all_rx = bus.subscribe_all();
        let mut session_rx = bus.subscribe_topic("session").await;
        let mut agent_rx = bus.register_agent("agent1").await;

        let msg = UnifiedBus::create_session_message(
            "session1",
            SessionMessageType::Created,
            serde_json::json!({"status": "ok"}),
        );
        bus.send(msg).await.unwrap();

        match all_rx.recv().await.unwrap() {
            UnifiedMessage::Session(s) => assert_eq!(s.session_id, "session1"),
            _ => panic!("Wrong message type"),
        }
        match session_rx.recv().await.unwrap() {
            UnifiedMessage::Session(s) => assert_eq!(s.session_id, "session1"),
            _ => panic!("Wrong message type"),
        }

        let direct_msg = UnifiedBus::create_direct_message("agent2", "agent1", "Hello agent1");
        bus.send(direct_msg).await.unwrap();

        match agent_rx.recv().await.unwrap() {
            UnifiedMessage::Direct(d) => {
                assert_eq!(d.to_agent, "agent1");
                assert_eq!(d.content, "Hello agent1");
            }
            _ => panic!("Wrong message type"),
        }
    }

    /// History keeps only the newest `history_limit` messages.
    ///
    /// `None` returns them oldest first; `Some(n)` returns the newest `n`, newest first.
    #[tokio::test]
    async fn history_is_bounded_and_ordered() {
        let bus = UnifiedBus::new(2);
        for i in 0..3 {
            let msg = UnifiedBus::create_task_message(
                &format!("task{i}"),
                TaskMessageType::Created,
                serde_json::json!({}),
            );
            bus.send(msg).await.unwrap();
        }

        assert_eq!(task_ids(&bus.get_history(None).await), ["task1", "task2"]);
        assert_eq!(task_ids(&bus.get_history(Some(1)).await), ["task2"]);
        assert_eq!(
            task_ids(&bus.get_history(Some(5)).await),
            ["task2", "task1"]
        );
    }

    /// Direct messages are not published on topics and only reach their recipient.
    #[tokio::test]
    async fn direct_messages_only_reach_their_recipient() {
        let bus = UnifiedBus::new(10);
        let mut session_rx = bus.subscribe_topic("session").await;
        let mut agent1_rx = bus.register_agent("agent1").await;
        let mut agent2_rx = bus.register_agent("agent2").await;

        bus.send(UnifiedBus::create_direct_message("agent2", "agent1", "ping"))
            .await
            .unwrap();

        match agent1_rx.recv().await {
            Some(UnifiedMessage::Direct(d)) => assert_eq!(d.content, "ping"),
            other => panic!("unexpected: {other:?}"),
        }
        assert!(agent2_rx.try_recv().is_err());
        assert!(session_rx.try_recv().is_err());
    }

    /// After `unregister_agent`, sends to that agent are dropped and its receiver closes.
    #[tokio::test]
    async fn unregistered_agent_channel_is_closed() {
        let bus = UnifiedBus::new(10);
        let mut agent_rx = bus.register_agent("agent1").await;
        bus.unregister_agent("agent1").await;

        bus.send(UnifiedBus::create_direct_message("agent2", "agent1", "lost"))
            .await
            .unwrap();

        assert!(agent_rx.recv().await.is_none());
    }

    /// A dropped topic subscriber neither fails `send` nor disturbs live subscribers.
    #[tokio::test]
    async fn send_tolerates_dropped_topic_subscribers() {
        let bus = UnifiedBus::new(10);
        drop(bus.subscribe_topic("event").await);
        let mut live_rx = bus.subscribe_topic("event").await;

        for _ in 0..2 {
            let msg =
                UnifiedBus::create_event_message("test", EventType::Info, serde_json::json!({}));
            bus.send(msg).await.unwrap();
        }

        assert!(matches!(live_rx.recv().await, Some(UnifiedMessage::Event(_))));
        assert!(matches!(live_rx.recv().await, Some(UnifiedMessage::Event(_))));
    }
}