a3s_code_core/orchestrator/
wrapper.rs1use crate::agent::AgentEvent;
5use crate::error::Result;
6use crate::orchestrator::{
7 ControlSignal, OrchestratorEvent, SubAgentActivity, SubAgentConfig, SubAgentState,
8};
9use std::sync::Arc;
10use tokio::sync::{broadcast, mpsc, RwLock};
11
12struct PendingToolCall {
13 id: String,
14 name: String,
15 args_buffer: String,
16 started_at: std::time::Instant,
17 emitted: bool,
18}
19
20fn parse_tool_args(raw: &str) -> serde_json::Value {
21 let trimmed = raw.trim();
22 if trimmed.is_empty() {
23 serde_json::Value::Null
24 } else {
25 serde_json::from_str(trimmed)
26 .unwrap_or_else(|_| serde_json::Value::String(trimmed.to_string()))
27 }
28}
29
30fn tool_duration_ms(started_at: std::time::Instant) -> u64 {
31 std::cmp::max(1, started_at.elapsed().as_millis() as u64)
32}
33
34pub struct SubAgentWrapper {
35 id: String,
36 config: SubAgentConfig,
37 agent: Option<Arc<crate::Agent>>,
39 event_tx: broadcast::Sender<OrchestratorEvent>,
40 subagent_event_tx: broadcast::Sender<OrchestratorEvent>,
41 event_history: Arc<RwLock<std::collections::VecDeque<OrchestratorEvent>>>,
42 control_rx: mpsc::Receiver<ControlSignal>,
43 state: Arc<RwLock<SubAgentState>>,
44 activity: Arc<RwLock<SubAgentActivity>>,
45 session_registry:
48 Arc<RwLock<std::collections::HashMap<String, Arc<crate::agent_api::AgentSession>>>>,
49}
50
51impl SubAgentWrapper {
52 #[allow(clippy::too_many_arguments)]
53 pub fn new(
54 id: String,
55 config: SubAgentConfig,
56 agent: Option<Arc<crate::Agent>>,
57 event_tx: broadcast::Sender<OrchestratorEvent>,
58 subagent_event_tx: broadcast::Sender<OrchestratorEvent>,
59 event_history: Arc<RwLock<std::collections::VecDeque<OrchestratorEvent>>>,
60 control_rx: mpsc::Receiver<ControlSignal>,
61 state: Arc<RwLock<SubAgentState>>,
62 activity: Arc<RwLock<SubAgentActivity>>,
63 session_registry: Arc<
64 RwLock<std::collections::HashMap<String, Arc<crate::agent_api::AgentSession>>>,
65 >,
66 ) -> Self {
67 Self {
68 id,
69 config,
70 agent,
71 event_tx,
72 subagent_event_tx,
73 event_history,
74 control_rx,
75 state,
76 activity,
77 session_registry,
78 }
79 }
80
81 async fn emit(&self, event: OrchestratorEvent) {
82 let _ = self.event_tx.send(event.clone());
83 let _ = self.subagent_event_tx.send(event.clone());
84
85 let mut history = self.event_history.write().await;
86 history.push_back(event);
87 while history.len() > 1024 {
88 history.pop_front();
89 }
90 }
91
92 async fn flush_tool_start(
93 &self,
94 pending_tool: &mut Option<PendingToolCall>,
95 ) -> std::time::Instant {
96 let pending = pending_tool
97 .as_mut()
98 .expect("flush_tool_start called without a pending tool");
99 if pending.emitted {
100 return pending.started_at;
101 }
102
103 let args = parse_tool_args(&pending.args_buffer);
104 self.emit(OrchestratorEvent::ToolExecutionStarted {
105 id: self.id.clone(),
106 tool_id: pending.id.clone(),
107 tool_name: pending.name.clone(),
108 args: args.clone(),
109 })
110 .await;
111
112 *self.activity.write().await = SubAgentActivity::CallingTool {
113 tool_name: pending.name.clone(),
114 args,
115 };
116 pending.emitted = true;
117 pending.started_at
118 }
119
120 pub async fn execute(mut self) -> Result<String> {
122 self.update_state(SubAgentState::Running).await;
123 let start = std::time::Instant::now();
124
125 let result = if let Some(agent) = self.agent.take() {
126 self.execute_with_agent(agent).await
127 } else {
128 self.execute_placeholder().await
129 };
130
131 let duration_ms = start.elapsed().as_millis() as u64;
132
133 match &result {
134 Ok(output) => {
135 self.update_state(SubAgentState::Completed {
136 success: true,
137 output: output.clone(),
138 })
139 .await;
140 self.emit(OrchestratorEvent::SubAgentCompleted {
141 id: self.id.clone(),
142 success: true,
143 output: output.clone(),
144 duration_ms,
145 token_usage: None,
146 })
147 .await;
148 }
149 Err(e) => {
150 let current = self.state.read().await.clone();
151 if !matches!(current, SubAgentState::Cancelled) {
152 self.update_state(SubAgentState::Error {
153 message: e.to_string(),
154 })
155 .await;
156 }
157 self.emit(OrchestratorEvent::SubAgentCompleted {
158 id: self.id.clone(),
159 success: false,
160 output: e.to_string(),
161 duration_ms,
162 token_usage: None,
163 })
164 .await;
165 }
166 }
167
168 result
169 }
170
171 async fn execute_with_agent(&mut self, agent: Arc<crate::Agent>) -> Result<String> {
176 let registry = crate::AgentRegistry::new();
178 for dir in &self.config.agent_dirs {
179 let agents = crate::load_agents_from_dir(std::path::Path::new(dir));
180 for def in agents {
181 registry.register(def);
182 }
183 }
184
185 let mut opts = crate::SessionOptions::new();
187
188 for dir in &self.config.agent_dirs {
190 opts = opts.with_agent_dir(dir.as_str());
191 }
192 if !self.config.skill_dirs.is_empty() {
193 opts = opts.with_skill_dirs(self.config.skill_dirs.iter().map(|s| s.as_str()));
194 }
195
196 if self.config.permissive {
198 let mut policy = crate::permissions::PermissionPolicy::permissive();
200
201 for rule in &self.config.permissive_deny {
203 policy = policy.deny(rule);
204 }
205
206 if let Some(def) = registry.get(&self.config.agent_type) {
208 for rule in &def.permissions.deny {
209 policy = policy.deny(&rule.rule);
210 }
211 }
212
213 opts = opts.with_permission_checker(Arc::new(policy));
214 }
215
216 if let Some(steps) = self.config.max_steps {
217 opts = opts.with_max_tool_rounds(steps);
218 }
219 if let Some(queue_cfg) = self.config.lane_config.clone() {
220 opts = opts.with_queue_config(queue_cfg);
221 }
222
223 let session = Arc::new(if let Some(def) = registry.get(&self.config.agent_type) {
226 agent.session_for_agent(&self.config.workspace, &def, Some(opts))?
227 } else {
228 agent.session(&self.config.workspace, Some(opts))?
229 });
230
231 self.session_registry
233 .write()
234 .await
235 .insert(self.id.clone(), Arc::clone(&session));
236
237 let (mut rx, _task) = session.stream(&self.config.prompt, None).await?;
239
240 let mut output = String::new();
241 let mut step: usize = 0;
242 let mut pending_tool: Option<PendingToolCall> = None;
243
244 loop {
245 while let Ok(signal) = self.control_rx.try_recv() {
247 self.handle_control_signal(signal).await?;
248 }
249
250 if matches!(*self.state.read().await, SubAgentState::Cancelled) {
252 drop(rx);
254 return Err(anyhow::anyhow!("Cancelled by orchestrator").into());
255 }
256
257 while matches!(*self.state.read().await, SubAgentState::Paused) {
259 *self.activity.write().await = SubAgentActivity::WaitingForControl {
260 reason: "Paused by orchestrator".to_string(),
261 };
262 tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
263 while let Ok(signal) = self.control_rx.try_recv() {
264 self.handle_control_signal(signal).await?;
265 }
266 if matches!(*self.state.read().await, SubAgentState::Cancelled) {
267 drop(rx);
268 return Err(anyhow::anyhow!("Cancelled by orchestrator").into());
269 }
270 }
271
272 match rx.recv().await {
274 Some(AgentEvent::TurnStart { turn }) => {
275 *self.activity.write().await =
276 SubAgentActivity::RequestingLlm { message_count: 0 };
277 self.emit(OrchestratorEvent::SubAgentInternalEvent {
278 id: self.id.clone(),
279 event: AgentEvent::TurnStart { turn },
280 })
281 .await;
282 }
283 Some(AgentEvent::ToolStart { id, name }) => {
284 pending_tool = Some(PendingToolCall {
285 id,
286 name,
287 args_buffer: String::new(),
288 started_at: std::time::Instant::now(),
289 emitted: false,
290 });
291 }
292 Some(AgentEvent::ToolInputDelta { delta }) => {
293 if let Some(pending) = pending_tool.as_mut() {
294 pending.args_buffer.push_str(&delta);
295 }
296 self.emit(OrchestratorEvent::SubAgentInternalEvent {
297 id: self.id.clone(),
298 event: AgentEvent::ToolInputDelta { delta },
299 })
300 .await;
301 }
302 Some(AgentEvent::ToolEnd {
303 id,
304 name,
305 output: tool_out,
306 exit_code,
307 ..
308 }) => {
309 step += 1;
310 let started_at =
311 if pending_tool.as_ref().map(|p| p.id.as_str()) == Some(id.as_str()) {
312 self.flush_tool_start(&mut pending_tool).await
313 } else {
314 std::time::Instant::now()
315 };
316 *self.activity.write().await = SubAgentActivity::Idle;
317 self.emit(OrchestratorEvent::ToolExecutionCompleted {
318 id: self.id.clone(),
319 tool_id: id,
320 tool_name: name,
321 result: tool_out,
322 exit_code,
323 duration_ms: tool_duration_ms(started_at),
324 })
325 .await;
326 pending_tool = None;
327 self.emit(OrchestratorEvent::SubAgentProgress {
328 id: self.id.clone(),
329 step,
330 total_steps: self.config.max_steps.unwrap_or(0),
331 message: format!("Completed tool call {step}"),
332 })
333 .await;
334 }
335 Some(AgentEvent::TextDelta { text }) => {
336 if pending_tool.is_some() {
337 self.flush_tool_start(&mut pending_tool).await;
338 }
339 output.push_str(&text);
340 self.emit(OrchestratorEvent::SubAgentInternalEvent {
341 id: self.id.clone(),
342 event: AgentEvent::TextDelta { text },
343 })
344 .await;
345 }
346 Some(AgentEvent::ExternalTaskPending {
347 task_id,
348 session_id,
349 lane,
350 command_type,
351 payload,
352 timeout_ms,
353 }) => {
354 if pending_tool.is_some() {
355 self.flush_tool_start(&mut pending_tool).await;
356 }
357 self.emit(OrchestratorEvent::ExternalTaskPending {
358 id: self.id.clone(),
359 task_id,
360 lane,
361 command_type,
362 payload,
363 timeout_ms,
364 })
365 .await;
366 let _ = session_id;
368 }
369 Some(AgentEvent::ExternalTaskCompleted {
370 task_id,
371 session_id,
372 success,
373 }) => {
374 if pending_tool.is_some() {
375 self.flush_tool_start(&mut pending_tool).await;
376 }
377 self.emit(OrchestratorEvent::ExternalTaskCompleted {
378 id: self.id.clone(),
379 task_id,
380 success,
381 })
382 .await;
383 let _ = session_id;
384 }
385 Some(AgentEvent::End { text, .. }) => {
386 if pending_tool.is_some() {
387 self.flush_tool_start(&mut pending_tool).await;
388 }
389 output = text;
390 break;
391 }
392 Some(AgentEvent::Error { message }) => {
393 return Err(anyhow::anyhow!("Agent error: {message}").into());
394 }
395 Some(event) => {
397 if pending_tool.is_some() {
398 self.flush_tool_start(&mut pending_tool).await;
399 }
400 self.emit(OrchestratorEvent::SubAgentInternalEvent {
401 id: self.id.clone(),
402 event,
403 })
404 .await;
405 }
406 None => break, }
408 }
409
410 self.session_registry.write().await.remove(&self.id);
412
413 Ok(output)
414 }
415
416 async fn execute_placeholder(&mut self) -> Result<String> {
421 for step in 1..=5 {
422 while let Ok(signal) = self.control_rx.try_recv() {
423 self.handle_control_signal(signal).await?;
424 }
425
426 if matches!(*self.state.read().await, SubAgentState::Cancelled) {
427 return Err(anyhow::anyhow!("Cancelled by orchestrator").into());
428 }
429
430 while matches!(*self.state.read().await, SubAgentState::Paused) {
431 *self.activity.write().await = SubAgentActivity::WaitingForControl {
432 reason: "Paused by orchestrator".to_string(),
433 };
434 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
435 while let Ok(signal) = self.control_rx.try_recv() {
436 self.handle_control_signal(signal).await?;
437 }
438 }
439
440 *self.activity.write().await = SubAgentActivity::CallingTool {
441 tool_name: "read".to_string(),
442 args: serde_json::json!({"path": "/tmp/file.txt"}),
443 };
444
445 self.emit(OrchestratorEvent::ToolExecutionStarted {
446 id: self.id.clone(),
447 tool_id: format!("tool-{step}"),
448 tool_name: "read".to_string(),
449 args: serde_json::json!({"path": "/tmp/file.txt"}),
450 })
451 .await;
452
453 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
454
455 *self.activity.write().await = SubAgentActivity::RequestingLlm { message_count: 3 };
456
457 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
458
459 *self.activity.write().await = SubAgentActivity::Idle;
460
461 self.emit(OrchestratorEvent::SubAgentProgress {
462 id: self.id.clone(),
463 step,
464 total_steps: 5,
465 message: format!("Step {step}/5 completed"),
466 })
467 .await;
468 }
469
470 Ok(format!(
471 "Placeholder result for SubAgent {} ({})",
472 self.id, self.config.agent_type
473 ))
474 }
475
476 async fn handle_control_signal(&mut self, signal: ControlSignal) -> Result<()> {
481 self.emit(OrchestratorEvent::ControlSignalReceived {
482 id: self.id.clone(),
483 signal: signal.clone(),
484 })
485 .await;
486
487 let result = match signal {
488 ControlSignal::Pause => {
489 self.update_state(SubAgentState::Paused).await;
490 Ok(())
491 }
492 ControlSignal::Resume => {
493 self.update_state(SubAgentState::Running).await;
494 Ok(())
495 }
496 ControlSignal::Cancel => {
497 self.update_state(SubAgentState::Cancelled).await;
498 Err(anyhow::anyhow!("Cancelled by orchestrator").into())
499 }
500 ControlSignal::AdjustParams { max_steps, .. } => {
501 if let Some(steps) = max_steps {
502 self.config.max_steps = Some(steps);
503 }
504 Ok(())
505 }
506 ControlSignal::InjectPrompt { ref prompt } => {
507 self.config.prompt.push('\n');
509 self.config.prompt.push_str(prompt);
510 Ok(())
511 }
512 };
513
514 self.emit(OrchestratorEvent::ControlSignalApplied {
515 id: self.id.clone(),
516 signal,
517 success: result.is_ok(),
518 error: result.as_ref().err().map(|e| format!("{e}")),
519 })
520 .await;
521
522 result
523 }
524
525 async fn update_state(&self, new_state: SubAgentState) {
526 let old_state = {
527 let mut state = self.state.write().await;
528 let old = state.clone();
529 *state = new_state.clone();
530 old
531 };
532
533 self.emit(OrchestratorEvent::SubAgentStateChanged {
534 id: self.id.clone(),
535 old_state,
536 new_state,
537 })
538 .await;
539 }
540}
541
542#[cfg(test)]
543mod tests {
544 use super::{parse_tool_args, tool_duration_ms};
545 use serde_json::json;
546 use std::time::{Duration, Instant};
547
548 #[test]
549 fn parse_tool_args_parses_json_object() {
550 assert_eq!(
551 parse_tool_args(r#"{"path":"README.md"}"#),
552 json!({"path": "README.md"})
553 );
554 }
555
556 #[test]
557 fn parse_tool_args_returns_null_for_empty_input() {
558 assert_eq!(parse_tool_args(" "), serde_json::Value::Null);
559 }
560
561 #[test]
562 fn parse_tool_args_preserves_non_json_input_as_string() {
563 assert_eq!(
564 parse_tool_args(r#"{"path":"README.md""#),
565 serde_json::Value::String(r#"{"path":"README.md""#.to_string())
566 );
567 }
568
569 #[test]
570 fn tool_duration_ms_has_one_millisecond_floor() {
571 let started_at = Instant::now();
572 assert_eq!(tool_duration_ms(started_at), 1);
573 }
574
575 #[test]
576 fn tool_duration_ms_preserves_elapsed_milliseconds() {
577 let started_at = Instant::now() - Duration::from_millis(12);
578 assert!(tool_duration_ms(started_at) >= 12);
579 }
580}