a3s_code_core/orchestrator/
agent.rs1use crate::error::Result;
4use crate::orchestrator::{
5 ControlSignal, OrchestratorConfig, OrchestratorEvent, SubAgentActivity, SubAgentConfig,
6 SubAgentHandle, SubAgentInfo, SubAgentState,
7};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::{broadcast, RwLock};
11
12pub struct AgentOrchestrator {
17 config: OrchestratorConfig,
19
20 event_tx: broadcast::Sender<OrchestratorEvent>,
22
23 subagents: Arc<RwLock<HashMap<String, SubAgentHandle>>>,
25
26 next_id: Arc<RwLock<u64>>,
28}
29
30impl AgentOrchestrator {
31 pub fn new_memory() -> Self {
35 Self::new(OrchestratorConfig::default())
36 }
37
38 pub fn new(config: OrchestratorConfig) -> Self {
40 let (event_tx, _) = broadcast::channel(config.event_buffer_size);
41
42 Self {
43 config,
44 event_tx,
45 subagents: Arc::new(RwLock::new(HashMap::new())),
46 next_id: Arc::new(RwLock::new(1)),
47 }
48 }
49
50 pub fn subscribe_all(&self) -> broadcast::Receiver<OrchestratorEvent> {
54 self.event_tx.subscribe()
55 }
56
57 pub fn subscribe_subagent(&self, id: &str) -> SubAgentEventStream {
61 let rx = self.event_tx.subscribe();
62 SubAgentEventStream {
63 rx,
64 filter_id: id.to_string(),
65 }
66 }
67
68 pub async fn spawn_subagent(&self, config: SubAgentConfig) -> Result<SubAgentHandle> {
72 {
74 let subagents = self.subagents.read().await;
75 let active_count = subagents
76 .values()
77 .filter(|h| !h.state().is_terminal())
78 .count();
79
80 if active_count >= self.config.max_concurrent_subagents {
81 return Err(anyhow::anyhow!(
82 "Maximum concurrent subagents ({}) reached",
83 self.config.max_concurrent_subagents
84 )
85 .into());
86 }
87 }
88
89 let id = {
91 let mut next_id = self.next_id.write().await;
92 let id = format!("subagent-{}", *next_id);
93 *next_id += 1;
94 id
95 };
96
97 let (control_tx, control_rx) = tokio::sync::mpsc::channel(self.config.control_buffer_size);
99
100 let state = Arc::new(RwLock::new(SubAgentState::Initializing));
102
103 let activity = Arc::new(RwLock::new(SubAgentActivity::Idle));
105
106 let _ = self.event_tx.send(OrchestratorEvent::SubAgentStarted {
108 id: id.clone(),
109 agent_type: config.agent_type.clone(),
110 description: config.description.clone(),
111 parent_id: config.parent_id.clone(),
112 config: config.clone(),
113 });
114
115 let wrapper = crate::orchestrator::wrapper::SubAgentWrapper::new(
117 id.clone(),
118 config.clone(),
119 self.event_tx.clone(),
120 control_rx,
121 state.clone(),
122 activity.clone(),
123 );
124
125 let task_handle = tokio::spawn(async move { wrapper.execute().await });
126
127 let handle = SubAgentHandle::new(
129 id.clone(),
130 config,
131 control_tx,
132 state.clone(),
133 activity.clone(),
134 task_handle,
135 );
136
137 self.subagents
139 .write()
140 .await
141 .insert(id.clone(), handle.clone());
142
143 Ok(handle)
144 }
145
146 pub async fn send_control(&self, id: &str, signal: ControlSignal) -> Result<()> {
148 let subagents = self.subagents.read().await;
149 let handle = subagents
150 .get(id)
151 .ok_or_else(|| anyhow::anyhow!("SubAgent '{}' not found", id))?;
152
153 handle.send_control(signal.clone()).await?;
154
155 let _ = self
157 .event_tx
158 .send(OrchestratorEvent::ControlSignalReceived {
159 id: id.to_string(),
160 signal,
161 });
162
163 Ok(())
164 }
165
166 pub async fn pause_subagent(&self, id: &str) -> Result<()> {
168 self.send_control(id, ControlSignal::Pause).await
169 }
170
171 pub async fn resume_subagent(&self, id: &str) -> Result<()> {
173 self.send_control(id, ControlSignal::Resume).await
174 }
175
176 pub async fn cancel_subagent(&self, id: &str) -> Result<()> {
178 self.send_control(id, ControlSignal::Cancel).await
179 }
180
181 pub async fn adjust_subagent_params(
183 &self,
184 id: &str,
185 max_steps: Option<usize>,
186 timeout_ms: Option<u64>,
187 ) -> Result<()> {
188 self.send_control(
189 id,
190 ControlSignal::AdjustParams {
191 max_steps,
192 timeout_ms,
193 },
194 )
195 .await
196 }
197
198 pub async fn get_subagent_state(&self, id: &str) -> Option<SubAgentState> {
200 let subagents = self.subagents.read().await;
201 subagents.get(id).map(|h| h.state())
202 }
203
204 pub async fn get_all_states(&self) -> HashMap<String, SubAgentState> {
206 let subagents = self.subagents.read().await;
207 subagents
208 .iter()
209 .map(|(id, handle)| (id.clone(), handle.state()))
210 .collect()
211 }
212
213 pub async fn active_count(&self) -> usize {
215 let subagents = self.subagents.read().await;
216 subagents
217 .values()
218 .filter(|h| !h.state().is_terminal())
219 .count()
220 }
221
222 pub async fn wait_all(&self) -> Result<()> {
224 loop {
225 let active = self.active_count().await;
226 if active == 0 {
227 break;
228 }
229 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
230 }
231 Ok(())
232 }
233
234 pub async fn list_subagents(&self) -> Vec<SubAgentInfo> {
236 let subagents = self.subagents.read().await;
237 let mut infos = Vec::new();
238
239 for (id, handle) in subagents.iter() {
240 let state = handle.state_async().await;
241 let activity = handle.activity().await;
242 let config = handle.config();
243
244 infos.push(SubAgentInfo {
245 id: id.clone(),
246 agent_type: config.agent_type.clone(),
247 description: config.description.clone(),
248 state: format!("{:?}", state),
249 parent_id: config.parent_id.clone(),
250 created_at: handle.created_at(),
251 updated_at: std::time::SystemTime::now()
252 .duration_since(std::time::UNIX_EPOCH)
253 .unwrap()
254 .as_millis() as u64,
255 current_activity: Some(activity),
256 });
257 }
258
259 infos
260 }
261
262 pub async fn get_subagent_info(&self, id: &str) -> Option<SubAgentInfo> {
264 let subagents = self.subagents.read().await;
265 let handle = subagents.get(id)?;
266
267 let state = handle.state_async().await;
268 let activity = handle.activity().await;
269 let config = handle.config();
270
271 Some(SubAgentInfo {
272 id: id.to_string(),
273 agent_type: config.agent_type.clone(),
274 description: config.description.clone(),
275 state: format!("{:?}", state),
276 parent_id: config.parent_id.clone(),
277 created_at: handle.created_at(),
278 updated_at: std::time::SystemTime::now()
279 .duration_since(std::time::UNIX_EPOCH)
280 .unwrap()
281 .as_millis() as u64,
282 current_activity: Some(activity),
283 })
284 }
285
286 pub async fn get_active_activities(&self) -> HashMap<String, SubAgentActivity> {
288 let subagents = self.subagents.read().await;
289 let mut activities = HashMap::new();
290
291 for (id, handle) in subagents.iter() {
292 if !handle.state().is_terminal() {
293 let activity = handle.activity().await;
294 activities.insert(id.clone(), activity);
295 }
296 }
297
298 activities
299 }
300
301 pub async fn get_handle(&self, id: &str) -> Option<SubAgentHandle> {
303 let subagents = self.subagents.read().await;
304 subagents.get(id).cloned()
305 }
306}
307
308impl std::fmt::Debug for AgentOrchestrator {
309 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310 f.debug_struct("AgentOrchestrator")
311 .field("event_buffer_size", &self.config.event_buffer_size)
312 .field(
313 "max_concurrent_subagents",
314 &self.config.max_concurrent_subagents,
315 )
316 .finish()
317 }
318}
319
320pub struct SubAgentEventStream {
322 rx: broadcast::Receiver<OrchestratorEvent>,
323 filter_id: String,
324}
325
326impl SubAgentEventStream {
327 pub async fn recv(&mut self) -> Option<OrchestratorEvent> {
329 loop {
330 match self.rx.recv().await {
331 Ok(event) => {
332 if let Some(id) = event.subagent_id() {
333 if id == self.filter_id {
334 return Some(event);
335 }
336 }
337 }
338 Err(_) => return None,
339 }
340 }
341 }
342}