Skip to main content

agent_diva_core/bus/
queue.rs

1//! Async message queue implementation
2
3use super::events::{AgentBusEvent, AgentEvent, InboundMessage, OutboundMessage};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::{broadcast, mpsc, RwLock};
7use tracing::debug;
8
9/// Type alias for message channel senders
10pub type OutboundSender = mpsc::UnboundedSender<OutboundMessage>;
11pub type OutboundReceiver = mpsc::UnboundedReceiver<OutboundMessage>;
12
13type OutboundCallback = Arc<
14    dyn Fn(OutboundMessage) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>
15        + Send
16        + Sync,
17>;
18
19/// Async message bus that decouples chat channels from the agent core
20///
21/// Channels push messages to the inbound queue, and the agent processes
22/// them and pushes responses to the outbound queue.
23#[derive(Clone)]
24pub struct MessageBus {
25    /// Inbound messages from channels
26    inbound_tx: mpsc::UnboundedSender<InboundMessage>,
27    inbound_rx: Arc<RwLock<Option<mpsc::UnboundedReceiver<InboundMessage>>>>,
28    /// Outbound messages to channels
29    outbound_tx: mpsc::UnboundedSender<OutboundMessage>,
30    outbound_rx: Arc<RwLock<Option<mpsc::UnboundedReceiver<OutboundMessage>>>>,
31    /// Outbound subscribers by channel
32    subscribers: Arc<RwLock<HashMap<String, Vec<OutboundCallback>>>>,
33    /// Event broadcast channel
34    event_tx: broadcast::Sender<AgentBusEvent>,
35    /// Running state
36    running: Arc<RwLock<bool>>,
37}
38
39impl MessageBus {
40    /// Create a new message bus
41    pub fn new() -> Self {
42        let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
43        let (outbound_tx, outbound_rx) = mpsc::unbounded_channel();
44        let (event_tx, _) = broadcast::channel(1024);
45
46        Self {
47            inbound_tx,
48            inbound_rx: Arc::new(RwLock::new(Some(inbound_rx))),
49            outbound_tx,
50            outbound_rx: Arc::new(RwLock::new(Some(outbound_rx))),
51            subscribers: Arc::new(RwLock::new(HashMap::new())),
52            event_tx,
53            running: Arc::new(RwLock::new(false)),
54        }
55    }
56
57    /// Publish an event to the broadcast channel
58    pub fn publish_event(
59        &self,
60        channel: impl Into<String>,
61        chat_id: impl Into<String>,
62        event: AgentEvent,
63    ) -> crate::Result<()> {
64        let bus_event = AgentBusEvent {
65            channel: channel.into(),
66            chat_id: chat_id.into(),
67            event,
68        };
69        // We ignore the error if there are no receivers
70        let _ = self.event_tx.send(bus_event);
71        Ok(())
72    }
73
74    /// Subscribe to the event broadcast channel
75    pub fn subscribe_events(&self) -> broadcast::Receiver<AgentBusEvent> {
76        self.event_tx.subscribe()
77    }
78
79    /// Take the inbound receiver (can only be called once)
80    pub async fn take_inbound_receiver(&self) -> Option<mpsc::UnboundedReceiver<InboundMessage>> {
81        self.inbound_rx.write().await.take()
82    }
83
84    /// Take the outbound receiver (can only be called once)
85    pub async fn take_outbound_receiver(&self) -> Option<mpsc::UnboundedReceiver<OutboundMessage>> {
86        self.outbound_rx.write().await.take()
87    }
88
89    /// Publish a message from a channel to the agent
90    pub fn publish_inbound(&self, msg: InboundMessage) -> crate::Result<()> {
91        self.inbound_tx
92            .send(msg)
93            .map_err(|_| crate::Error::Channel("Inbound channel closed".to_string()))
94    }
95
96    /// Publish a response from the agent to channels
97    pub fn publish_outbound(&self, msg: OutboundMessage) -> crate::Result<()> {
98        self.outbound_tx
99            .send(msg)
100            .map_err(|_| crate::Error::Channel("Outbound channel closed".to_string()))
101    }
102
103    /// Subscribe to outbound messages for a specific channel with a callback
104    pub async fn subscribe_outbound<F, Fut>(&self, channel: impl Into<String>, callback: F)
105    where
106        F: Fn(OutboundMessage) -> Fut + Send + Sync + 'static,
107        Fut: std::future::Future<Output = ()> + Send + 'static,
108    {
109        let channel = channel.into();
110        let wrapped: OutboundCallback = Arc::new(move |msg| Box::pin(callback(msg)));
111
112        let mut subscribers = self.subscribers.write().await;
113        subscribers.entry(channel).or_default().push(wrapped);
114    }
115
116    /// Dispatch outbound messages to subscribed channels
117    /// Run this as a background task
118    pub async fn dispatch_outbound_loop(&self) {
119        let mut outbound_rx = match self.take_outbound_receiver().await {
120            Some(rx) => rx,
121            None => {
122                debug!("Outbound receiver already taken");
123                return;
124            }
125        };
126
127        *self.running.write().await = true;
128        debug!("Starting outbound dispatcher");
129
130        while *self.running.read().await {
131            tokio::select! {
132                Some(msg) = outbound_rx.recv() => {
133                    let channel = msg.channel.clone();
134                    let subscribers = self.subscribers.read().await;
135
136                    if let Some(callbacks) = subscribers.get(&channel) {
137                        for callback in callbacks {
138                            let msg_clone = msg.clone();
139                            let future = callback(msg_clone);
140                            // Spawn to avoid blocking
141                            tokio::spawn(async move {
142                                future.await;
143                            });
144                        }
145                    } else {
146                        debug!("No subscribers for channel: {}", channel);
147                    }
148                }
149                _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {
150                    // Check running state periodically
151                    continue;
152                }
153            }
154        }
155
156        debug!("Outbound dispatcher stopped");
157    }
158
159    /// Stop the dispatcher loop
160    pub async fn stop(&self) {
161        *self.running.write().await = false;
162    }
163
164    /// Check if the bus is running
165    pub async fn is_running(&self) -> bool {
166        *self.running.read().await
167    }
168}
169
170impl Default for MessageBus {
171    fn default() -> Self {
172        Self::new()
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_message_bus_creation() {
        let bus = MessageBus::new();
        assert!(!bus.is_running().await);
    }

    #[tokio::test]
    async fn test_publish_inbound() {
        let bus = MessageBus::new();
        let mut inbound_rx = bus.take_inbound_receiver().await.unwrap();

        let msg = InboundMessage::new("test", "user1", "chat1", "Hello");
        assert!(bus.publish_inbound(msg.clone()).is_ok());

        // Verify message was received
        let received = inbound_rx.try_recv();
        assert!(received.is_ok());
    }

    #[tokio::test]
    async fn test_receivers_can_only_be_taken_once() {
        let bus = MessageBus::new();

        assert!(bus.take_inbound_receiver().await.is_some());
        assert!(bus.take_inbound_receiver().await.is_none());

        assert!(bus.take_outbound_receiver().await.is_some());
        assert!(bus.take_outbound_receiver().await.is_none());
    }

    #[tokio::test]
    async fn test_publish_inbound_fails_after_receiver_dropped() {
        let bus = MessageBus::new();
        drop(bus.take_inbound_receiver().await);

        let msg = InboundMessage::new("test", "user1", "chat1", "Hello");
        assert!(bus.publish_inbound(msg).is_err());
    }

    #[tokio::test]
    async fn test_subscribe_outbound() {
        let bus = MessageBus::new();

        bus.subscribe_outbound("telegram", |_msg| async move {
            // Callback function
        })
        .await;

        {
            // Scoped so the read guard is released before `is_running`.
            let subscribers = bus.subscribers.read().await;
            assert_eq!(subscribers.get("telegram").map(Vec::len), Some(1));
            assert!(subscribers.get("discord").is_none());
        }

        // Subscribing does not start the dispatcher.
        assert!(!bus.is_running().await);
    }

    #[tokio::test]
    async fn test_dispatch_returns_when_receiver_already_taken() {
        let bus = MessageBus::new();
        let _outbound_rx = bus.take_outbound_receiver().await.unwrap();

        // Must return immediately rather than looping forever.
        bus.dispatch_outbound_loop().await;
        assert!(!bus.is_running().await);
    }

    #[tokio::test]
    async fn test_dispatcher_stops_when_requested() {
        let bus = MessageBus::new();

        let stopper = async {
            while !bus.is_running().await {
                tokio::task::yield_now().await;
            }
            bus.stop().await;
        };

        tokio::time::timeout(tokio::time::Duration::from_secs(5), async {
            tokio::join!(bus.dispatch_outbound_loop(), stopper)
        })
        .await
        .expect("dispatcher did not stop after stop()");

        assert!(!bus.is_running().await);
    }
}