Skip to main content

ai_session/
unified_bus.rs

1//! Unified message bus - Consolidates multiple communication channels
2
3use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::{RwLock, broadcast, mpsc};
8use uuid::Uuid;
9
/// Unified message type that encompasses all communication carried by [`UnifiedBus`].
///
/// `UnifiedBus::send` broadcasts every variant to all `subscribe_all` receivers.
/// Every variant except `Direct` is additionally routed to a topic
/// (`"session"`, `"coordination"`, `"task"`, `"event"`, `"ipc"`), while `Direct`
/// is routed to the channel registered for its `to_agent`.
///
/// NOTE: variant names appear in serde's serialized form, so renaming them (or
/// reordering them, for formats that encode variant indices) changes the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum UnifiedMessage {
    /// Session management message
    Session(SessionMessage),
    /// Agent coordination message
    Coordination(CoordinationMessage),
    /// Task management message
    Task(TaskMessage),
    /// System event message
    Event(EventMessage),
    /// Direct agent-to-agent message (the only variant delivered point-to-point)
    Direct(DirectMessage),
    /// IPC message
    Ipc(IpcMessage),
}
26
/// Session management message (payload of [`UnifiedMessage::Session`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionMessage {
    /// Message id; `UnifiedBus::create_session_message` fills it with a random UUID v4.
    pub id: String,
    /// Identifier of the session this message refers to.
    pub session_id: String,
    /// Kind of session message.
    pub msg_type: SessionMessageType,
    /// Free-form JSON payload; its schema is not defined in this module.
    pub payload: serde_json::Value,
    /// Creation time; `UnifiedBus::create_session_message` sets it to `Utc::now()`.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub enum SessionMessageType {
39    Created,
40    Started,
41    Stopped,
42    Output,
43    Input,
44    StatusChange,
45}
46
/// Agent coordination message (payload of [`UnifiedMessage::Coordination`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoordinationMessage {
    /// Message id; `UnifiedBus::create_coordination_message` fills it with a random UUID v4.
    pub id: String,
    /// Id of the sending agent.
    pub from_agent: String,
    /// Intended recipient, if any. `None` presumably means "any/all agents" — TODO confirm.
    /// NOTE: `UnifiedBus::send` does not use this field for routing; coordination
    /// messages only go to the broadcast channel and the `"coordination"` topic.
    pub to_agent: Option<String>,
    /// Kind of coordination message.
    pub msg_type: CoordinationMessageType,
    /// Free-form JSON payload; its schema is not defined in this module.
    pub payload: serde_json::Value,
    /// Creation time; `UnifiedBus::create_coordination_message` sets it to `Utc::now()`.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub enum CoordinationMessageType {
60    TaskAssignment,
61    TaskAccepted,
62    TaskRejected,
63    TaskCompleted,
64    StatusUpdate,
65    HelpRequest,
66    KnowledgeShare,
67}
68
/// Task management message (payload of [`UnifiedMessage::Task`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskMessage {
    /// Message id; `UnifiedBus::create_task_message` fills it with a random UUID v4.
    pub id: String,
    /// Identifier of the task this message refers to.
    pub task_id: String,
    /// Kind of task message.
    pub msg_type: TaskMessageType,
    /// Free-form JSON payload; its schema is not defined in this module.
    pub payload: serde_json::Value,
    /// Creation time; `UnifiedBus::create_task_message` sets it to `Utc::now()`.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub enum TaskMessageType {
81    Created,
82    Assigned,
83    Started,
84    Progress,
85    Completed,
86    Failed,
87    Cancelled,
88}
89
/// System event message (payload of [`UnifiedMessage::Event`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EventMessage {
    /// Message id; `UnifiedBus::create_event_message` fills it with a random UUID v4.
    pub id: String,
    /// Free-form name of the component that emitted the event.
    pub source: String,
    /// Kind of event.
    pub event_type: EventType,
    /// Free-form JSON payload; its schema is not defined in this module.
    pub payload: serde_json::Value,
    /// Creation time; `UnifiedBus::create_event_message` sets it to `Utc::now()`.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub enum EventType {
102    SystemStartup,
103    SystemShutdown,
104    AgentConnected,
105    AgentDisconnected,
106    Error,
107    Warning,
108    Info,
109}
110
/// Direct agent-to-agent message (payload of [`UnifiedMessage::Direct`]).
///
/// `UnifiedBus::send` delivers it to the channel registered for `to_agent` (if any)
/// in addition to the broadcast channel; it is never routed to a topic.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DirectMessage {
    /// Message id; `UnifiedBus::create_direct_message` fills it with a random UUID v4.
    pub id: String,
    /// Id of the sending agent.
    pub from_agent: String,
    /// Id of the recipient; looked up among agents registered via `register_agent`.
    pub to_agent: String,
    /// Message body.
    pub content: String,
    /// Arbitrary extra data; `UnifiedBus::create_direct_message` leaves it empty.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Creation time; `UnifiedBus::create_direct_message` sets it to `Utc::now()`.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
121
/// IPC message (payload of [`UnifiedMessage::Ipc`]), routed to the `"ipc"` topic.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IpcMessage {
    /// Message id.
    pub id: String,
    /// Free-form message type string; valid values are not defined in this module.
    pub msg_type: String,
    /// Free-form JSON payload; its schema is not defined in this module.
    pub payload: serde_json::Value,
    /// Creation time.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
130
/// Unified message bus that handles all communication.
///
/// Each message passed to `send` fans out over up to three paths: a broadcast
/// channel (every `subscribe_all` receiver), per-topic `mpsc` channels
/// (`subscribe_topic`), and per-agent `mpsc` channels for direct messages
/// (`register_agent`). A bounded copy of recent messages is kept as history.
pub struct UnifiedBus {
    /// Broadcast channel for all messages; receivers come from `subscribe_all`.
    broadcast_tx: broadcast::Sender<UnifiedMessage>,
    /// Point-to-point channels for direct messages, keyed by agent id
    /// (inserted by `register_agent`, removed by `unregister_agent`).
    direct_channels: Arc<RwLock<HashMap<String, mpsc::Sender<UnifiedMessage>>>>,
    /// Topic subscribers keyed by topic name; one sender per `subscribe_topic` call.
    topic_subscribers: Arc<RwLock<HashMap<String, Vec<mpsc::Sender<UnifiedMessage>>>>>,
    /// Recent messages, oldest first (new ones are pushed, old ones drained from the front).
    message_history: Arc<RwLock<Vec<UnifiedMessage>>>,
    /// Maximum number of messages kept in `message_history`.
    history_limit: usize,
}
144
145impl UnifiedBus {
146    /// Create new unified bus
147    pub fn new(history_limit: usize) -> Self {
148        let (broadcast_tx, _) = broadcast::channel(1024);
149
150        Self {
151            broadcast_tx,
152            direct_channels: Arc::new(RwLock::new(HashMap::new())),
153            topic_subscribers: Arc::new(RwLock::new(HashMap::new())),
154            message_history: Arc::new(RwLock::new(Vec::new())),
155            history_limit,
156        }
157    }
158
159    /// Subscribe to all messages
160    pub fn subscribe_all(&self) -> broadcast::Receiver<UnifiedMessage> {
161        self.broadcast_tx.subscribe()
162    }
163
164    /// Subscribe to specific topic
165    pub async fn subscribe_topic(&self, topic: &str) -> mpsc::Receiver<UnifiedMessage> {
166        let (tx, rx) = mpsc::channel(256);
167        let mut subscribers = self.topic_subscribers.write().await;
168        subscribers.entry(topic.to_string()).or_default().push(tx);
169        rx
170    }
171
172    /// Register direct channel for agent
173    pub async fn register_agent(&self, agent_id: &str) -> mpsc::Receiver<UnifiedMessage> {
174        let (tx, rx) = mpsc::channel(256);
175        let mut channels = self.direct_channels.write().await;
176        channels.insert(agent_id.to_string(), tx);
177        rx
178    }
179
180    /// Unregister agent
181    pub async fn unregister_agent(&self, agent_id: &str) {
182        let mut channels = self.direct_channels.write().await;
183        channels.remove(agent_id);
184    }
185
186    /// Send message to bus
187    pub async fn send(&self, message: UnifiedMessage) -> Result<()> {
188        // Add to history
189        {
190            let mut history = self.message_history.write().await;
191            history.push(message.clone());
192
193            // Trim history if needed
194            if history.len() > self.history_limit {
195                let drain_end = history.len() - self.history_limit;
196                history.drain(0..drain_end);
197            }
198        }
199
200        // Broadcast to all subscribers
201        let _ = self.broadcast_tx.send(message.clone());
202
203        // Send to topic subscribers
204        if let Some(topic) = self.get_message_topic(&message) {
205            let subscribers = self.topic_subscribers.read().await;
206            if let Some(subs) = subscribers.get(&topic) {
207                for sub in subs {
208                    let _ = sub.send(message.clone()).await;
209                }
210            }
211        }
212
213        // Send direct messages
214        if let UnifiedMessage::Direct(ref msg) = message {
215            let channels = self.direct_channels.read().await;
216            if let Some(tx) = channels.get(&msg.to_agent) {
217                let _ = tx.send(message).await;
218            }
219        }
220
221        Ok(())
222    }
223
224    /// Get message topic for routing
225    fn get_message_topic(&self, message: &UnifiedMessage) -> Option<String> {
226        match message {
227            UnifiedMessage::Session(_) => Some("session".to_string()),
228            UnifiedMessage::Coordination(_) => Some("coordination".to_string()),
229            UnifiedMessage::Task(_) => Some("task".to_string()),
230            UnifiedMessage::Event(_) => Some("event".to_string()),
231            UnifiedMessage::Direct(_) => None, // Direct messages don't use topics
232            UnifiedMessage::Ipc(_) => Some("ipc".to_string()),
233        }
234    }
235
236    /// Get message history
237    pub async fn get_history(&self, limit: Option<usize>) -> Vec<UnifiedMessage> {
238        let history = self.message_history.read().await;
239        match limit {
240            Some(n) => history.iter().rev().take(n).cloned().collect(),
241            None => history.clone(),
242        }
243    }
244
245    /// Create session message
246    pub fn create_session_message(
247        session_id: &str,
248        msg_type: SessionMessageType,
249        payload: serde_json::Value,
250    ) -> UnifiedMessage {
251        UnifiedMessage::Session(SessionMessage {
252            id: Uuid::new_v4().to_string(),
253            session_id: session_id.to_string(),
254            msg_type,
255            payload,
256            timestamp: chrono::Utc::now(),
257        })
258    }
259
260    /// Create coordination message
261    pub fn create_coordination_message(
262        from_agent: &str,
263        to_agent: Option<&str>,
264        msg_type: CoordinationMessageType,
265        payload: serde_json::Value,
266    ) -> UnifiedMessage {
267        UnifiedMessage::Coordination(CoordinationMessage {
268            id: Uuid::new_v4().to_string(),
269            from_agent: from_agent.to_string(),
270            to_agent: to_agent.map(|s| s.to_string()),
271            msg_type,
272            payload,
273            timestamp: chrono::Utc::now(),
274        })
275    }
276
277    /// Create task message
278    pub fn create_task_message(
279        task_id: &str,
280        msg_type: TaskMessageType,
281        payload: serde_json::Value,
282    ) -> UnifiedMessage {
283        UnifiedMessage::Task(TaskMessage {
284            id: Uuid::new_v4().to_string(),
285            task_id: task_id.to_string(),
286            msg_type,
287            payload,
288            timestamp: chrono::Utc::now(),
289        })
290    }
291
292    /// Create event message
293    pub fn create_event_message(
294        source: &str,
295        event_type: EventType,
296        payload: serde_json::Value,
297    ) -> UnifiedMessage {
298        UnifiedMessage::Event(EventMessage {
299            id: Uuid::new_v4().to_string(),
300            source: source.to_string(),
301            event_type,
302            payload,
303            timestamp: chrono::Utc::now(),
304        })
305    }
306
307    /// Create direct message
308    pub fn create_direct_message(
309        from_agent: &str,
310        to_agent: &str,
311        content: &str,
312    ) -> UnifiedMessage {
313        UnifiedMessage::Direct(DirectMessage {
314            id: Uuid::new_v4().to_string(),
315            from_agent: from_agent.to_string(),
316            to_agent: to_agent.to_string(),
317            content: content.to_string(),
318            metadata: HashMap::new(),
319            timestamp: chrono::Utc::now(),
320        })
321    }
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check of the three delivery paths: broadcast, topic and direct.
    #[tokio::test]
    async fn test_unified_bus() {
        let bus = UnifiedBus::new(100);

        // One listener per path: everything, the "session" topic, and agent1's inbox.
        let mut everything = bus.subscribe_all();
        let mut session_topic = bus.subscribe_topic("session").await;
        let mut agent1_inbox = bus.register_agent("agent1").await;

        // A session message must reach both the broadcast and the topic subscriber.
        let created = UnifiedBus::create_session_message(
            "session1",
            SessionMessageType::Created,
            serde_json::json!({"status": "ok"}),
        );
        bus.send(created).await.unwrap();

        let UnifiedMessage::Session(seen) = everything.recv().await.unwrap() else {
            panic!("Wrong message type");
        };
        assert_eq!(seen.session_id, "session1");

        let UnifiedMessage::Session(seen) = session_topic.recv().await.unwrap() else {
            panic!("Wrong message type");
        };
        assert_eq!(seen.session_id, "session1");

        // A direct message must land in the addressed agent's inbox.
        let hello = UnifiedBus::create_direct_message("agent2", "agent1", "Hello agent1");
        bus.send(hello).await.unwrap();

        let UnifiedMessage::Direct(delivered) = agent1_inbox.recv().await.unwrap() else {
            panic!("Wrong message type");
        };
        assert_eq!(delivered.to_agent, "agent1");
        assert_eq!(delivered.content, "Hello agent1");
    }
}