Skip to main content

a3s_code_core/orchestrator/
agent.rs

//! Core implementation of the Agent Orchestrator.
2
3use crate::error::Result;
4use crate::orchestrator::{
5    ControlSignal, OrchestratorConfig, OrchestratorEvent, SubAgentActivity, SubAgentConfig,
6    SubAgentHandle, SubAgentInfo, SubAgentState,
7};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::{broadcast, RwLock};
11
12/// Agent Orchestrator - 主子智能体协调器
13///
14/// 基于事件总线实现统一的监控和控制机制。
15/// 默认使用内存事件通讯,支持用户自定义 NATS provider。
16pub struct AgentOrchestrator {
17    /// 配置
18    config: OrchestratorConfig,
19
20    /// 事件广播通道
21    event_tx: broadcast::Sender<OrchestratorEvent>,
22
23    /// SubAgent 注册表
24    subagents: Arc<RwLock<HashMap<String, SubAgentHandle>>>,
25
26    /// 下一个 SubAgent ID
27    next_id: Arc<RwLock<u64>>,
28}
29
30impl AgentOrchestrator {
31    /// 创建新的 orchestrator(使用内存事件通讯)
32    ///
33    /// 这是默认的创建方式,适用于单进程场景。
34    pub fn new_memory() -> Self {
35        Self::new(OrchestratorConfig::default())
36    }
37
38    /// 使用自定义配置创建 orchestrator
39    pub fn new(config: OrchestratorConfig) -> Self {
40        let (event_tx, _) = broadcast::channel(config.event_buffer_size);
41
42        Self {
43            config,
44            event_tx,
45            subagents: Arc::new(RwLock::new(HashMap::new())),
46            next_id: Arc::new(RwLock::new(1)),
47        }
48    }
49
50    /// 订阅所有 SubAgent 事件
51    ///
52    /// 返回一个接收器,可以接收所有 SubAgent 的事件。
53    pub fn subscribe_all(&self) -> broadcast::Receiver<OrchestratorEvent> {
54        self.event_tx.subscribe()
55    }
56
57    /// 订阅特定 SubAgent 的事件
58    ///
59    /// 返回一个过滤后的接收器,只接收指定 SubAgent 的事件。
60    pub fn subscribe_subagent(&self, id: &str) -> SubAgentEventStream {
61        let rx = self.event_tx.subscribe();
62        SubAgentEventStream {
63            rx,
64            filter_id: id.to_string(),
65        }
66    }
67
68    /// 启动新的 SubAgent
69    ///
70    /// 返回 SubAgent 句柄,可用于控制和查询状态。
71    pub async fn spawn_subagent(&self, config: SubAgentConfig) -> Result<SubAgentHandle> {
72        // 检查并发限制
73        {
74            let subagents = self.subagents.read().await;
75            let active_count = subagents
76                .values()
77                .filter(|h| !h.state().is_terminal())
78                .count();
79
80            if active_count >= self.config.max_concurrent_subagents {
81                return Err(anyhow::anyhow!(
82                    "Maximum concurrent subagents ({}) reached",
83                    self.config.max_concurrent_subagents
84                )
85                .into());
86            }
87        }
88
89        // 生成 SubAgent ID
90        let id = {
91            let mut next_id = self.next_id.write().await;
92            let id = format!("subagent-{}", *next_id);
93            *next_id += 1;
94            id
95        };
96
97        // 创建控制通道
98        let (control_tx, control_rx) = tokio::sync::mpsc::channel(self.config.control_buffer_size);
99
100        // 创建状态
101        let state = Arc::new(RwLock::new(SubAgentState::Initializing));
102
103        // 创建活动跟踪
104        let activity = Arc::new(RwLock::new(SubAgentActivity::Idle));
105
106        // 发布启动事件
107        let _ = self.event_tx.send(OrchestratorEvent::SubAgentStarted {
108            id: id.clone(),
109            agent_type: config.agent_type.clone(),
110            description: config.description.clone(),
111            parent_id: config.parent_id.clone(),
112            config: config.clone(),
113        });
114
115        // 创建 SubAgentWrapper 并启动执行
116        let wrapper = crate::orchestrator::wrapper::SubAgentWrapper::new(
117            id.clone(),
118            config.clone(),
119            self.event_tx.clone(),
120            control_rx,
121            state.clone(),
122            activity.clone(),
123        );
124
125        let task_handle = tokio::spawn(async move { wrapper.execute().await });
126
127        // 创建句柄
128        let handle = SubAgentHandle::new(
129            id.clone(),
130            config,
131            control_tx,
132            state.clone(),
133            activity.clone(),
134            task_handle,
135        );
136
137        // 注册到 orchestrator
138        self.subagents
139            .write()
140            .await
141            .insert(id.clone(), handle.clone());
142
143        Ok(handle)
144    }
145
146    /// 发送控制信号到 SubAgent
147    pub async fn send_control(&self, id: &str, signal: ControlSignal) -> Result<()> {
148        let subagents = self.subagents.read().await;
149        let handle = subagents
150            .get(id)
151            .ok_or_else(|| anyhow::anyhow!("SubAgent '{}' not found", id))?;
152
153        handle.send_control(signal.clone()).await?;
154
155        // 发布控制信号接收事件
156        let _ = self
157            .event_tx
158            .send(OrchestratorEvent::ControlSignalReceived {
159                id: id.to_string(),
160                signal,
161            });
162
163        Ok(())
164    }
165
166    /// 暂停 SubAgent
167    pub async fn pause_subagent(&self, id: &str) -> Result<()> {
168        self.send_control(id, ControlSignal::Pause).await
169    }
170
171    /// 恢复 SubAgent
172    pub async fn resume_subagent(&self, id: &str) -> Result<()> {
173        self.send_control(id, ControlSignal::Resume).await
174    }
175
176    /// 取消 SubAgent
177    pub async fn cancel_subagent(&self, id: &str) -> Result<()> {
178        self.send_control(id, ControlSignal::Cancel).await
179    }
180
181    /// 调整 SubAgent 参数
182    pub async fn adjust_subagent_params(
183        &self,
184        id: &str,
185        max_steps: Option<usize>,
186        timeout_ms: Option<u64>,
187    ) -> Result<()> {
188        self.send_control(
189            id,
190            ControlSignal::AdjustParams {
191                max_steps,
192                timeout_ms,
193            },
194        )
195        .await
196    }
197
198    /// 获取 SubAgent 状态
199    pub async fn get_subagent_state(&self, id: &str) -> Option<SubAgentState> {
200        let subagents = self.subagents.read().await;
201        subagents.get(id).map(|h| h.state())
202    }
203
204    /// 获取所有 SubAgent 的状态
205    pub async fn get_all_states(&self) -> HashMap<String, SubAgentState> {
206        let subagents = self.subagents.read().await;
207        subagents
208            .iter()
209            .map(|(id, handle)| (id.clone(), handle.state()))
210            .collect()
211    }
212
213    /// 获取活跃的 SubAgent 数量
214    pub async fn active_count(&self) -> usize {
215        let subagents = self.subagents.read().await;
216        subagents
217            .values()
218            .filter(|h| !h.state().is_terminal())
219            .count()
220    }
221
222    /// 等待所有 SubAgent 完成
223    pub async fn wait_all(&self) -> Result<()> {
224        loop {
225            let active = self.active_count().await;
226            if active == 0 {
227                break;
228            }
229            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
230        }
231        Ok(())
232    }
233
234    /// 获取所有 SubAgent 的信息列表
235    pub async fn list_subagents(&self) -> Vec<SubAgentInfo> {
236        let subagents = self.subagents.read().await;
237        let mut infos = Vec::new();
238
239        for (id, handle) in subagents.iter() {
240            let state = handle.state_async().await;
241            let activity = handle.activity().await;
242            let config = handle.config();
243
244            infos.push(SubAgentInfo {
245                id: id.clone(),
246                agent_type: config.agent_type.clone(),
247                description: config.description.clone(),
248                state: format!("{:?}", state),
249                parent_id: config.parent_id.clone(),
250                created_at: handle.created_at(),
251                updated_at: std::time::SystemTime::now()
252                    .duration_since(std::time::UNIX_EPOCH)
253                    .unwrap()
254                    .as_millis() as u64,
255                current_activity: Some(activity),
256            });
257        }
258
259        infos
260    }
261
262    /// 获取特定 SubAgent 的详细信息
263    pub async fn get_subagent_info(&self, id: &str) -> Option<SubAgentInfo> {
264        let subagents = self.subagents.read().await;
265        let handle = subagents.get(id)?;
266
267        let state = handle.state_async().await;
268        let activity = handle.activity().await;
269        let config = handle.config();
270
271        Some(SubAgentInfo {
272            id: id.to_string(),
273            agent_type: config.agent_type.clone(),
274            description: config.description.clone(),
275            state: format!("{:?}", state),
276            parent_id: config.parent_id.clone(),
277            created_at: handle.created_at(),
278            updated_at: std::time::SystemTime::now()
279                .duration_since(std::time::UNIX_EPOCH)
280                .unwrap()
281                .as_millis() as u64,
282            current_activity: Some(activity),
283        })
284    }
285
286    /// 获取所有活跃 SubAgent 的当前活动
287    pub async fn get_active_activities(&self) -> HashMap<String, SubAgentActivity> {
288        let subagents = self.subagents.read().await;
289        let mut activities = HashMap::new();
290
291        for (id, handle) in subagents.iter() {
292            if !handle.state().is_terminal() {
293                let activity = handle.activity().await;
294                activities.insert(id.clone(), activity);
295            }
296        }
297
298        activities
299    }
300
301    /// 获取 SubAgent 句柄(用于直接控制)
302    pub async fn get_handle(&self, id: &str) -> Option<SubAgentHandle> {
303        let subagents = self.subagents.read().await;
304        subagents.get(id).cloned()
305    }
306}
307
308impl std::fmt::Debug for AgentOrchestrator {
309    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310        f.debug_struct("AgentOrchestrator")
311            .field("event_buffer_size", &self.config.event_buffer_size)
312            .field(
313                "max_concurrent_subagents",
314                &self.config.max_concurrent_subagents,
315            )
316            .finish()
317    }
318}
319
320/// SubAgent 事件流(过滤特定 SubAgent 的事件)
321pub struct SubAgentEventStream {
322    rx: broadcast::Receiver<OrchestratorEvent>,
323    filter_id: String,
324}
325
326impl SubAgentEventStream {
327    /// 接收下一个事件
328    pub async fn recv(&mut self) -> Option<OrchestratorEvent> {
329        loop {
330            match self.rx.recv().await {
331                Ok(event) => {
332                    if let Some(id) = event.subagent_id() {
333                        if id == self.filter_id {
334                            return Some(event);
335                        }
336                    }
337                }
338                Err(_) => return None,
339            }
340        }
341    }
342}