mofa_foundation/messaging/
mod.rs1use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::{RwLock, broadcast};
13
14#[derive(Clone)]
16pub struct MessageBus<T, U>
17where
18 T: Clone + Send + 'static,
19 U: Clone + Send + 'static,
20{
21 inbound: broadcast::Sender<T>,
23 outbound: broadcast::Sender<U>,
25 outbound_subscribers: Arc<RwLock<HashMap<String, Vec<OutboundCallback<U>>>>>,
27}
28
/// Shared, thread-safe async callback: takes an outbound message by value and returns a
/// boxed, pinned `Send` future that resolves to `()`.
type OutboundCallback<U> = Arc<
    dyn Fn(U) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>>
        + Send
        + Sync
        + 'static,
>;
31
32impl<T, U> MessageBus<T, U>
33where
34 T: Clone + Send + 'static,
35 U: Clone + Send + 'static,
36{
37 pub fn new(capacity: usize) -> Self {
39 let (inbound_tx, _) = broadcast::channel(capacity);
40 let (outbound_tx, _) = broadcast::channel(capacity);
41
42 Self {
43 inbound: inbound_tx,
44 outbound: outbound_tx,
45 outbound_subscribers: Arc::new(RwLock::new(HashMap::new())),
46 }
47 }
48
49 pub fn default_capacity() -> Self {
51 Self::new(100)
52 }
53
54 pub async fn publish_inbound(&self, msg: T) -> Result<(), broadcast::error::SendError<T>> {
56 self.inbound.send(msg)?;
57 Ok(())
58 }
59
60 pub fn subscribe_inbound(&self) -> broadcast::Receiver<T> {
62 self.inbound.subscribe()
63 }
64
65 pub async fn publish_outbound(&self, msg: U) -> Result<(), broadcast::error::SendError<U>> {
67 self.outbound.send(msg)?;
68 Ok(())
69 }
70
71 pub fn subscribe_outbound(&self) -> broadcast::Receiver<U> {
73 self.outbound.subscribe()
74 }
75
76 pub async fn subscribe_outbound_key<F, Fut>(&self, key: String, callback: F)
78 where
79 F: Fn(U) -> Fut + Send + Sync + 'static,
80 Fut: std::future::Future<Output = ()> + Send + 'static,
81 {
82 let mut subscribers = self.outbound_subscribers.write().await;
83 subscribers
84 .entry(key)
85 .or_insert_with(Vec::new)
86 .push(Arc::new(move |msg| Box::pin(callback(msg))));
87 }
88
89 pub fn inbound_subscriber_count(&self) -> usize {
91 self.inbound.receiver_count()
92 }
93
94 pub fn outbound_subscriber_count(&self) -> usize {
96 self.outbound.receiver_count()
97 }
98}
99
100impl<T, U> Default for MessageBus<T, U>
101where
102 T: Clone + Send + 'static,
103 U: Clone + Send + 'static,
104{
105 fn default() -> Self {
106 Self::default_capacity()
107 }
108}
109
110pub trait InboundMessage: Clone + Send {
112 fn session_key(&self) -> String;
114
115 fn content(&self) -> &str;
117
118 fn media(&self) -> &[String] {
120 &[]
121 }
122
123 fn metadata(&self) -> &HashMap<String, serde_json::Value> {
125 use std::sync::OnceLock;
126 static EMPTY: OnceLock<HashMap<String, serde_json::Value>> = OnceLock::new();
127 EMPTY.get_or_init(HashMap::new)
128 }
129}
130
/// Read-only view over a message leaving the bus.
pub trait OutboundMessage: Clone + Send {
    /// Channel name the message is addressed to.
    fn channel(&self) -> &str;

    /// Chat identifier within [`OutboundMessage::channel`].
    fn chat_id(&self) -> &str;

    /// Text body of the message.
    fn content(&self) -> &str;

    /// Routing key of the form `"{channel}:{chat_id}"`.
    fn routing_key(&self) -> String {
        [self.channel(), self.chat_id()].join(":")
    }
}
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct SimpleInboundMessage {
151 pub channel: String,
153 pub sender_id: String,
155 pub chat_id: String,
157 pub content: String,
159 #[serde(default)]
161 pub media: Vec<String>,
162 #[serde(flatten)]
164 pub metadata: HashMap<String, serde_json::Value>,
165}
166
167impl InboundMessage for SimpleInboundMessage {
168 fn session_key(&self) -> String {
169 format!("{}:{}", self.channel, self.chat_id)
170 }
171
172 fn content(&self) -> &str {
173 &self.content
174 }
175
176 fn media(&self) -> &[String] {
177 &self.media
178 }
179
180 fn metadata(&self) -> &HashMap<String, serde_json::Value> {
181 &self.metadata
182 }
183}
184
185impl SimpleInboundMessage {
186 pub fn new(
188 channel: impl Into<String>,
189 sender_id: impl Into<String>,
190 chat_id: impl Into<String>,
191 content: impl Into<String>,
192 ) -> Self {
193 Self {
194 channel: channel.into(),
195 sender_id: sender_id.into(),
196 chat_id: chat_id.into(),
197 content: content.into(),
198 media: Vec::new(),
199 metadata: HashMap::new(),
200 }
201 }
202
203 pub fn with_media(mut self, media: Vec<String>) -> Self {
205 self.media = media;
206 self
207 }
208
209 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
211 self.metadata.insert(key.into(), value);
212 self
213 }
214}
215
216#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct SimpleOutboundMessage {
219 pub channel: String,
221 pub chat_id: String,
223 pub content: String,
225 #[serde(skip_serializing_if = "Option::is_none")]
227 pub reply_to: Option<String>,
228}
229
230impl OutboundMessage for SimpleOutboundMessage {
231 fn channel(&self) -> &str {
232 &self.channel
233 }
234
235 fn chat_id(&self) -> &str {
236 &self.chat_id
237 }
238
239 fn content(&self) -> &str {
240 &self.content
241 }
242}
243
244impl SimpleOutboundMessage {
245 pub fn new(
247 channel: impl Into<String>,
248 chat_id: impl Into<String>,
249 content: impl Into<String>,
250 ) -> Self {
251 Self {
252 channel: channel.into(),
253 chat_id: chat_id.into(),
254 content: content.into(),
255 reply_to: None,
256 }
257 }
258
259 pub fn with_reply_to(mut self, reply_to: impl Into<String>) -> Self {
261 self.reply_to = Some(reply_to.into());
262 self
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_message_bus_publish() {
        let bus = MessageBus::<SimpleInboundMessage, SimpleOutboundMessage>::new(10);

        let mut rx = bus.subscribe_inbound();
        let msg = SimpleInboundMessage::new("test", "user", "chat", "Hello");

        let publisher = tokio::spawn(async move {
            bus.publish_inbound(msg).await.unwrap();
        });
        // Join the publisher first so a panic inside it fails the test directly. The message
        // stays buffered because `rx` subscribed before the send.
        publisher.await.unwrap();

        let received = rx.recv().await.unwrap();
        assert_eq!(received.content, "Hello");
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let bus = MessageBus::<String, String>::new(10);

        let mut rx1 = bus.subscribe_inbound();
        let mut rx2 = bus.subscribe_inbound();

        bus.publish_inbound("test".to_string()).await.unwrap();

        let received1 = rx1.recv().await.unwrap();
        let received2 = rx2.recv().await.unwrap();

        assert_eq!(received1, "test");
        assert_eq!(received2, "test");
    }

    #[tokio::test]
    async fn test_outbound_subscribe() {
        let bus = MessageBus::<String, SimpleOutboundMessage>::new(10);

        let mut rx = bus.subscribe_outbound();
        let msg = SimpleOutboundMessage::new("telegram", "123", "Response");

        bus.publish_outbound(msg).await.unwrap();

        let received = rx.recv().await.unwrap();
        assert_eq!(received.content, "Response");
    }

    #[tokio::test]
    async fn test_publish_without_subscribers_returns_message() {
        let bus = MessageBus::<String, String>::new(4);

        let err = bus.publish_inbound("lost".to_string()).await.unwrap_err();
        assert_eq!(err.0, "lost");

        let err = bus.publish_outbound("lost".to_string()).await.unwrap_err();
        assert_eq!(err.0, "lost");
    }

    #[test]
    fn test_subscriber_counts() {
        let bus = MessageBus::<String, String>::default();
        assert_eq!(bus.inbound_subscriber_count(), 0);
        assert_eq!(bus.outbound_subscriber_count(), 0);

        let rx_in = bus.subscribe_inbound();
        let _rx_out1 = bus.subscribe_outbound();
        let _rx_out2 = bus.subscribe_outbound();
        assert_eq!(bus.inbound_subscriber_count(), 1);
        assert_eq!(bus.outbound_subscriber_count(), 2);

        drop(rx_in);
        assert_eq!(bus.inbound_subscriber_count(), 0);
    }

    #[test]
    fn test_simple_inbound_message() {
        let msg = SimpleInboundMessage::new("telegram", "user123", "chat456", "Hello");
        assert_eq!(msg.session_key(), "telegram:chat456");
        assert_eq!(msg.content(), "Hello");
        assert!(msg.media().is_empty());
        assert!(msg.metadata().is_empty());
    }

    #[test]
    fn test_simple_inbound_message_with_media() {
        let msg = SimpleInboundMessage::new("telegram", "user123", "chat456", "Hello")
            .with_media(vec!["image.jpg".to_string()]);
        assert_eq!(msg.media().len(), 1);
    }

    #[test]
    fn test_inbound_metadata_round_trips_through_json() {
        let msg = SimpleInboundMessage::new("telegram", "user123", "chat456", "Hello")
            .with_metadata("priority", serde_json::Value::from(7));

        let json = serde_json::to_string(&msg).unwrap();
        let back: SimpleInboundMessage = serde_json::from_str(&json).unwrap();

        assert_eq!(back.content, "Hello");
        assert!(back.media.is_empty());
        assert_eq!(back.metadata.len(), 1);
        assert_eq!(back.metadata.get("priority"), Some(&serde_json::Value::from(7)));
    }

    #[test]
    fn test_simple_outbound_message() {
        let msg = SimpleOutboundMessage::new("telegram", "123", "Response");
        assert_eq!(msg.channel(), "telegram");
        assert_eq!(msg.chat_id(), "123");
        assert_eq!(msg.content(), "Response");
        assert_eq!(msg.routing_key(), "telegram:123");
    }

    #[test]
    fn test_simple_outbound_message_with_reply() {
        let msg = SimpleOutboundMessage::new("telegram", "123", "Response").with_reply_to("msg456");
        assert_eq!(msg.reply_to, Some("msg456".to_string()));
    }

    #[test]
    fn test_outbound_reply_to_omitted_when_absent() {
        let plain = serde_json::to_value(SimpleOutboundMessage::new("telegram", "123", "Response"))
            .unwrap();
        assert!(plain.get("reply_to").is_none());

        let reply = serde_json::to_value(
            SimpleOutboundMessage::new("telegram", "123", "Response").with_reply_to("m1"),
        )
        .unwrap();
        assert_eq!(reply["reply_to"], "m1");
    }
}