1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use tokio::sync::mpsc;
9use uuid::Uuid;
10
11use crate::{AgentId, Error, Result};
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
15pub struct MessageId(Uuid);
16
17impl MessageId {
18 pub fn new() -> Self {
19 Self(Uuid::new_v4())
20 }
21}
22
23impl Default for MessageId {
24 fn default() -> Self {
25 Self::new()
26 }
27}
28
/// The content carried by a [`Message`].
///
/// NOTE: keep the variant order stable. Serde encodes the variant index (not the
/// name) in non-self-describing formats, so reordering would break compatibility there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MessagePayload {
    /// Marker variant with no data; produced by `Message::start`.
    Start,
    /// Arbitrary JSON value; produced by `Message::data`.
    Data(serde_json::Value),
    /// Marker variant with no data; produced by `Message::complete`.
    Complete,
    /// Error description text; produced by `Message::error`.
    Error(String),
    /// Marker variant with no data. Presumably asks the recipient for its status;
    /// TODO confirm against the handlers.
    StatusRequest,
    /// Free-form status text. Presumably the reply to `StatusRequest`; TODO confirm.
    StatusResponse(String),
}
45
/// An envelope addressed from one agent to another and routed by `MessageBus`.
///
/// NOTE: field order is part of the encoding in non-self-describing serde formats;
/// keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Unique identifier; `Message::new` assigns a fresh one.
    pub id: MessageId,
    /// The sending agent.
    pub from: AgentId,
    /// The recipient agent; `MessageBus::send` routes on this field.
    pub to: AgentId,
    /// The message content.
    pub payload: MessagePayload,
    /// Creation time in whole seconds since the Unix epoch (UTC).
    /// `Message::new` sets it from `chrono::Utc::now()`.
    pub timestamp: i64,
}
55
56impl Message {
57 pub fn new(from: AgentId, to: AgentId, payload: MessagePayload) -> Self {
59 Self {
60 id: MessageId::new(),
61 from,
62 to,
63 payload,
64 timestamp: chrono::Utc::now().timestamp(),
65 }
66 }
67
68 pub fn data(from: AgentId, to: AgentId, data: serde_json::Value) -> Self {
70 Self::new(from, to, MessagePayload::Data(data))
71 }
72
73 pub fn start(from: AgentId, to: AgentId) -> Self {
75 Self::new(from, to, MessagePayload::Start)
76 }
77
78 pub fn complete(from: AgentId, to: AgentId) -> Self {
80 Self::new(from, to, MessagePayload::Complete)
81 }
82
83 pub fn error(from: AgentId, to: AgentId, error: impl Into<String>) -> Self {
85 Self::new(from, to, MessagePayload::Error(error.into()))
86 }
87}
88
89pub struct MessageBus {
96 channels: HashMap<AgentId, mpsc::UnboundedSender<Message>>,
97}
98
99impl MessageBus {
100 pub fn new() -> Self {
102 Self {
103 channels: HashMap::new(),
104 }
105 }
106
107 pub fn register(&mut self, agent_id: AgentId) -> mpsc::UnboundedReceiver<Message> {
109 let (tx, rx) = mpsc::unbounded_channel();
110 self.channels.insert(agent_id, tx);
111 rx
112 }
113
114 pub fn unregister(&mut self, agent_id: &AgentId) {
116 self.channels.remove(agent_id);
117 }
118
119 pub async fn send(&self, message: Message) -> Result<()> {
121 let channel = self.channels
122 .get(&message.to)
123 .ok_or_else(|| Error::MessageBus(format!("Agent {} not registered", message.to)))?;
124
125 channel
126 .send(message)
127 .map_err(|e| Error::MessageBus(format!("Failed to send message: {}", e)))?;
128
129 Ok(())
130 }
131
132 pub async fn broadcast(&self, message: Message) -> Result<()> {
134 for channel in self.channels.values() {
135 channel
136 .send(message.clone())
137 .map_err(|e| Error::MessageBus(format!("Failed to broadcast: {}", e)))?;
138 }
139 Ok(())
140 }
141
142 pub fn agent_count(&self) -> usize {
144 self.channels.len()
145 }
146}
147
148impl Default for MessageBus {
149 fn default() -> Self {
150 Self::new()
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_creation() {
        let from = AgentId::new();
        let to = AgentId::new();
        let msg = Message::start(from, to);

        assert_eq!(msg.from, from);
        assert_eq!(msg.to, to);
        assert!(matches!(msg.payload, MessagePayload::Start));
    }

    #[test]
    fn test_message_bus_registration() {
        let mut bus = MessageBus::new();
        let agent_id = AgentId::new();

        let _rx = bus.register(agent_id);
        assert_eq!(bus.agent_count(), 1);

        bus.unregister(&agent_id);
        assert_eq!(bus.agent_count(), 0);
    }

    #[tokio::test]
    async fn test_message_bus_send() {
        let mut bus = MessageBus::new();
        let agent1 = AgentId::new();
        let agent2 = AgentId::new();

        let mut rx2 = bus.register(agent2);
        // agent1's receiver is held only so that its inbox stays open.
        let _rx1 = bus.register(agent1);

        let msg = Message::start(agent1, agent2);
        bus.send(msg.clone()).await.unwrap();

        let received = rx2.recv().await.unwrap();
        assert_eq!(received.id, msg.id);
        assert_eq!(received.from, agent1);
        assert_eq!(received.to, agent2);
    }

    #[tokio::test]
    async fn test_send_to_unregistered_agent_fails() {
        let bus = MessageBus::new();
        let msg = Message::start(AgentId::new(), AgentId::new());

        assert!(bus.send(msg).await.is_err());
    }

    #[tokio::test]
    async fn test_send_to_dropped_receiver_fails() {
        let mut bus = MessageBus::new();
        let agent = AgentId::new();
        drop(bus.register(agent));

        let msg = Message::complete(AgentId::new(), agent);
        assert!(bus.send(msg).await.is_err());
    }

    #[tokio::test]
    async fn test_message_bus_broadcast_reaches_every_agent() {
        let mut bus = MessageBus::new();
        let sender = AgentId::new();
        let agent_a = AgentId::new();
        let agent_b = AgentId::new();

        let mut rx_a = bus.register(agent_a);
        let mut rx_b = bus.register(agent_b);

        let msg = Message::complete(sender, agent_a);
        bus.broadcast(msg.clone()).await.unwrap();

        assert_eq!(rx_a.recv().await.unwrap().id, msg.id);
        assert_eq!(rx_b.recv().await.unwrap().id, msg.id);
    }
}