Skip to main content

mofa_kernel/bus/
mod.rs

1use crate::agent::AgentMetadata;
2use crate::message::AgentMessage;
3use serde::{Deserialize, Serialize};
4use std::collections::{HashMap, HashSet};
5use std::sync::Arc;
6use tokio::sync::{RwLock, broadcast};
7
8/// 通信模式枚举
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
10pub enum CommunicationMode {
11    /// 点对点通信(单发送方 -> 单接收方)
12    PointToPoint(String),
13    /// 广播通信(单发送方 -> 所有智能体)
14    Broadcast,
15    /// 订阅-发布通信(基于主题)
16    PubSub(String),
17}
18
19/// 通信总线核心结构体
20#[derive(Clone)]
21pub struct AgentBus {
22    /// 智能体-通信通道映射
23    agent_channels:
24        Arc<RwLock<HashMap<String, HashMap<CommunicationMode, broadcast::Sender<Vec<u8>>>>>>,
25    /// 主题-订阅者映射(PubSub 模式专用)
26    topic_subscribers: Arc<RwLock<HashMap<String, HashSet<String>>>>,
27    /// 广播通道
28    broadcast_channel: broadcast::Sender<Vec<u8>>,
29}
30
31impl AgentBus {
32    /// 创建通信总线实例
33    pub async fn new() -> anyhow::Result<Self> {
34        let (broadcast_sender, _) = broadcast::channel(100);
35        Ok(Self {
36            agent_channels: Arc::new(RwLock::new(HashMap::new())),
37            topic_subscribers: Arc::new(RwLock::new(HashMap::new())),
38            broadcast_channel: broadcast_sender,
39        })
40    }
41
42    /// 为智能体注册通信通道
43    pub async fn register_channel(
44        &self,
45        agent_metadata: &AgentMetadata,
46        mode: CommunicationMode,
47    ) -> anyhow::Result<()> {
48        let id = &agent_metadata.id;
49        let mut agent_channels = self.agent_channels.write().await;
50        let entry = agent_channels.entry(id.clone()).or_default();
51
52        // 如果是广播模式,不需要单独注册,使用全局广播通道
53        if matches!(mode, CommunicationMode::Broadcast) {
54            return Ok(());
55        }
56
57        // 如果通道已存在,直接返回
58        if entry.contains_key(&mode) {
59            return Ok(());
60        }
61
62        // 创建新的广播通道
63        let (sender, _) = broadcast::channel(100);
64        entry.insert(mode.clone(), sender);
65
66        // PubSub 模式需注册订阅者映射
67        if let CommunicationMode::PubSub(topic) = &mode {
68            let mut topic_subs = self.topic_subscribers.write().await;
69            topic_subs
70                .entry(topic.clone())
71                .or_default()
72                .insert(id.clone());
73        }
74
75        Ok(())
76    }
77
78    // 核心:完善点对点消息发送逻辑
79    pub async fn send_message(
80        &self,
81        sender_id: &str,
82        mode: CommunicationMode,
83        message: &AgentMessage,
84    ) -> anyhow::Result<()> {
85        let message_bytes = bincode::serialize(message)?;
86
87        match mode {
88            // 点对点模式:根据接收方 ID 查找通道并发送
89            CommunicationMode::PointToPoint(receiver_id) => {
90                let agent_channels = self.agent_channels.read().await;
91                // 1. 校验接收方是否存在并注册了对应通道
92                let Some(receiver_channels) = agent_channels.get(&receiver_id) else {
93                    return Err(anyhow::anyhow!("Receiver agent {} not found", receiver_id));
94                };
95                let Some(channel) =
96                    receiver_channels.get(&CommunicationMode::PointToPoint(sender_id.to_string()))
97                else {
98                    return Err(anyhow::anyhow!(
99                        "Receiver {} has no point-to-point channel with sender {}",
100                        receiver_id,
101                        sender_id
102                    ));
103                };
104                // 2. 发送消息
105                channel.send(message_bytes)?;
106            }
107            CommunicationMode::Broadcast => {
108                // 使用全局广播通道
109                self.broadcast_channel.send(message_bytes)?;
110            }
111            CommunicationMode::PubSub(ref topic) => {
112                let topic_subs = self.topic_subscribers.read().await;
113                let subscribers = topic_subs
114                    .get(topic)
115                    .ok_or_else(|| anyhow::anyhow!("No subscribers for topic: {}", topic))?;
116                let agent_channels = self.agent_channels.read().await;
117
118                for sub_id in subscribers {
119                    let Some(channels) = agent_channels.get(sub_id) else {
120                        continue;
121                    };
122                    let Some(channel) = channels.get(&mode) else {
123                        continue;
124                    };
125                    channel.send(message_bytes.clone())?;
126                }
127            }
128        }
129
130        Ok(())
131    }
132
133    pub async fn receive_message(
134        &self,
135        id: &str,
136        mode: CommunicationMode,
137    ) -> anyhow::Result<Option<AgentMessage>> {
138        let agent_channels = self.agent_channels.read().await;
139
140        // 处理广播模式
141        if matches!(mode, CommunicationMode::Broadcast) {
142            let mut receiver = self.broadcast_channel.subscribe();
143            match receiver.recv().await {
144                Ok(data) => {
145                    let message = bincode::deserialize(&data)?;
146                    Ok(Some(message))
147                }
148                Err(_) => Ok(None),
149            }
150        } else {
151            // 处理其他模式
152            let Some(channels) = agent_channels.get(id) else {
153                return Ok(None);
154            };
155            let Some(channel) = channels.get(&mode) else {
156                return Ok(None);
157            };
158
159            let mut receiver = channel.subscribe();
160            match receiver.recv().await {
161                Ok(data) => {
162                    let message = bincode::deserialize(&data)?;
163                    Ok(Some(message))
164                }
165                Err(_) => Ok(None),
166            }
167        }
168    }
169
170    pub async fn unsubscribe_topic(&self, id: &str, topic: &str) -> anyhow::Result<()> {
171        let mut topic_subs = self.topic_subscribers.write().await;
172        if let Some(subscribers) = topic_subs.get_mut(topic) {
173            subscribers.remove(id);
174            if subscribers.is_empty() {
175                topic_subs.remove(topic);
176            }
177        }
178        Ok(())
179    }
180}