agent_diva_core/bus/
queue.rs1use super::events::{AgentBusEvent, AgentEvent, InboundMessage, OutboundMessage};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::{broadcast, mpsc, RwLock};
7use tracing::debug;
8
9pub type OutboundSender = mpsc::UnboundedSender<OutboundMessage>;
11pub type OutboundReceiver = mpsc::UnboundedReceiver<OutboundMessage>;
12
13type OutboundCallback = Arc<
14 dyn Fn(OutboundMessage) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>
15 + Send
16 + Sync,
17>;
18
19#[derive(Clone)]
24pub struct MessageBus {
25 inbound_tx: mpsc::UnboundedSender<InboundMessage>,
27 inbound_rx: Arc<RwLock<Option<mpsc::UnboundedReceiver<InboundMessage>>>>,
28 outbound_tx: mpsc::UnboundedSender<OutboundMessage>,
30 outbound_rx: Arc<RwLock<Option<mpsc::UnboundedReceiver<OutboundMessage>>>>,
31 subscribers: Arc<RwLock<HashMap<String, Vec<OutboundCallback>>>>,
33 event_tx: broadcast::Sender<AgentBusEvent>,
35 running: Arc<RwLock<bool>>,
37}
38
39impl MessageBus {
40 pub fn new() -> Self {
42 let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
43 let (outbound_tx, outbound_rx) = mpsc::unbounded_channel();
44 let (event_tx, _) = broadcast::channel(1024);
45
46 Self {
47 inbound_tx,
48 inbound_rx: Arc::new(RwLock::new(Some(inbound_rx))),
49 outbound_tx,
50 outbound_rx: Arc::new(RwLock::new(Some(outbound_rx))),
51 subscribers: Arc::new(RwLock::new(HashMap::new())),
52 event_tx,
53 running: Arc::new(RwLock::new(false)),
54 }
55 }
56
57 pub fn publish_event(
59 &self,
60 channel: impl Into<String>,
61 chat_id: impl Into<String>,
62 event: AgentEvent,
63 ) -> crate::Result<()> {
64 let bus_event = AgentBusEvent {
65 channel: channel.into(),
66 chat_id: chat_id.into(),
67 event,
68 };
69 let _ = self.event_tx.send(bus_event);
71 Ok(())
72 }
73
74 pub fn subscribe_events(&self) -> broadcast::Receiver<AgentBusEvent> {
76 self.event_tx.subscribe()
77 }
78
79 pub async fn take_inbound_receiver(&self) -> Option<mpsc::UnboundedReceiver<InboundMessage>> {
81 self.inbound_rx.write().await.take()
82 }
83
84 pub async fn take_outbound_receiver(&self) -> Option<mpsc::UnboundedReceiver<OutboundMessage>> {
86 self.outbound_rx.write().await.take()
87 }
88
89 pub fn publish_inbound(&self, msg: InboundMessage) -> crate::Result<()> {
91 self.inbound_tx
92 .send(msg)
93 .map_err(|_| crate::Error::Channel("Inbound channel closed".to_string()))
94 }
95
96 pub fn publish_outbound(&self, msg: OutboundMessage) -> crate::Result<()> {
98 self.outbound_tx
99 .send(msg)
100 .map_err(|_| crate::Error::Channel("Outbound channel closed".to_string()))
101 }
102
103 pub async fn subscribe_outbound<F, Fut>(&self, channel: impl Into<String>, callback: F)
105 where
106 F: Fn(OutboundMessage) -> Fut + Send + Sync + 'static,
107 Fut: std::future::Future<Output = ()> + Send + 'static,
108 {
109 let channel = channel.into();
110 let wrapped: OutboundCallback = Arc::new(move |msg| Box::pin(callback(msg)));
111
112 let mut subscribers = self.subscribers.write().await;
113 subscribers.entry(channel).or_default().push(wrapped);
114 }
115
116 pub async fn dispatch_outbound_loop(&self) {
119 let mut outbound_rx = match self.take_outbound_receiver().await {
120 Some(rx) => rx,
121 None => {
122 debug!("Outbound receiver already taken");
123 return;
124 }
125 };
126
127 *self.running.write().await = true;
128 debug!("Starting outbound dispatcher");
129
130 while *self.running.read().await {
131 tokio::select! {
132 Some(msg) = outbound_rx.recv() => {
133 let channel = msg.channel.clone();
134 let subscribers = self.subscribers.read().await;
135
136 if let Some(callbacks) = subscribers.get(&channel) {
137 for callback in callbacks {
138 let msg_clone = msg.clone();
139 let future = callback(msg_clone);
140 tokio::spawn(async move {
142 future.await;
143 });
144 }
145 } else {
146 debug!("No subscribers for channel: {}", channel);
147 }
148 }
149 _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {
150 continue;
152 }
153 }
154 }
155
156 debug!("Outbound dispatcher stopped");
157 }
158
159 pub async fn stop(&self) {
161 *self.running.write().await = false;
162 }
163
164 pub async fn is_running(&self) -> bool {
166 *self.running.read().await
167 }
168}
169
170impl Default for MessageBus {
171 fn default() -> Self {
172 Self::new()
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_message_bus_creation() {
        // A freshly built bus has not started its dispatcher.
        let running = MessageBus::new().is_running().await;
        assert!(!running);
    }

    #[tokio::test]
    async fn test_publish_inbound() {
        let bus = MessageBus::new();
        let mut rx = bus
            .take_inbound_receiver()
            .await
            .expect("inbound receiver is available on a fresh bus");

        let message = InboundMessage::new("test", "user1", "chat1", "Hello");
        let published = bus.publish_inbound(message.clone());
        assert!(published.is_ok());

        // The message is queued synchronously, so it is readable without awaiting.
        assert!(rx.try_recv().is_ok());
    }

    #[tokio::test]
    async fn test_subscribe_outbound() {
        let bus = MessageBus::new();
        bus.subscribe_outbound("telegram", |_msg| async move {}).await;

        // Registering a subscriber must not start the dispatcher.
        assert!(!bus.is_running().await);
    }
}