Skip to main content

a3s_code_core/orchestrator/
wrapper.rs

1//! SubAgent wrapper — executes a real AgentSession and forwards events to the
2//! Orchestrator event bus, with pause/resume/cancel control signal support.
3
4use crate::agent::AgentEvent;
5use crate::error::Result;
6use crate::orchestrator::{
7    ControlSignal, OrchestratorEvent, SubAgentActivity, SubAgentConfig, SubAgentState,
8};
9use std::sync::Arc;
10use tokio::sync::{broadcast, mpsc, RwLock};
11
/// A tool call announced by `AgentEvent::ToolStart` whose
/// `ToolExecutionStarted` event is held back so it can carry the argument text
/// streamed in by subsequent `ToolInputDelta` events.
struct PendingToolCall {
    /// Tool-call identifier reported by the agent.
    id: String,
    /// Name of the tool being invoked.
    name: String,
    /// Raw argument text accumulated from `ToolInputDelta` events.
    args_buffer: String,
    /// Set once `ToolExecutionStarted` has been emitted for this call.
    emitted: bool,
    /// When the call was first observed; used to compute its duration.
    started_at: std::time::Instant,
}
19
20fn parse_tool_args(raw: &str) -> serde_json::Value {
21    let trimmed = raw.trim();
22    if trimmed.is_empty() {
23        serde_json::Value::Null
24    } else {
25        serde_json::from_str(trimmed)
26            .unwrap_or_else(|_| serde_json::Value::String(trimmed.to_string()))
27    }
28}
29
/// Milliseconds elapsed since `started_at`, never less than 1.
///
/// The 1 ms floor means a tool call never reports a zero duration. The
/// millisecond count (a `u128`) is clamped to `u64::MAX` rather than being
/// silently truncated by a bare `as` cast.
fn tool_duration_ms(started_at: std::time::Instant) -> u64 {
    let elapsed_ms = started_at.elapsed().as_millis().min(u128::from(u64::MAX));
    // Lossless: `elapsed_ms` has just been clamped into `u64` range.
    std::cmp::max(1, elapsed_ms as u64)
}
33
34pub struct SubAgentWrapper {
35    id: String,
36    config: SubAgentConfig,
37    /// Real agent for LLM execution; `None` → placeholder mode.
38    agent: Option<Arc<crate::Agent>>,
39    event_tx: broadcast::Sender<OrchestratorEvent>,
40    subagent_event_tx: broadcast::Sender<OrchestratorEvent>,
41    event_history: Arc<RwLock<std::collections::VecDeque<OrchestratorEvent>>>,
42    control_rx: mpsc::Receiver<ControlSignal>,
43    state: Arc<RwLock<SubAgentState>>,
44    activity: Arc<RwLock<SubAgentActivity>>,
45    /// Shared map of live sessions; wrapper registers its session here so
46    /// `AgentOrchestrator::complete_external_task()` can reach it.
47    session_registry:
48        Arc<RwLock<std::collections::HashMap<String, Arc<crate::agent_api::AgentSession>>>>,
49}
50
51impl SubAgentWrapper {
52    #[allow(clippy::too_many_arguments)]
53    pub fn new(
54        id: String,
55        config: SubAgentConfig,
56        agent: Option<Arc<crate::Agent>>,
57        event_tx: broadcast::Sender<OrchestratorEvent>,
58        subagent_event_tx: broadcast::Sender<OrchestratorEvent>,
59        event_history: Arc<RwLock<std::collections::VecDeque<OrchestratorEvent>>>,
60        control_rx: mpsc::Receiver<ControlSignal>,
61        state: Arc<RwLock<SubAgentState>>,
62        activity: Arc<RwLock<SubAgentActivity>>,
63        session_registry: Arc<
64            RwLock<std::collections::HashMap<String, Arc<crate::agent_api::AgentSession>>>,
65        >,
66    ) -> Self {
67        Self {
68            id,
69            config,
70            agent,
71            event_tx,
72            subagent_event_tx,
73            event_history,
74            control_rx,
75            state,
76            activity,
77            session_registry,
78        }
79    }
80
81    async fn emit(&self, event: OrchestratorEvent) {
82        let _ = self.event_tx.send(event.clone());
83        let _ = self.subagent_event_tx.send(event.clone());
84
85        let mut history = self.event_history.write().await;
86        history.push_back(event);
87        while history.len() > 1024 {
88            history.pop_front();
89        }
90    }
91
92    async fn flush_tool_start(
93        &self,
94        pending_tool: &mut Option<PendingToolCall>,
95    ) -> std::time::Instant {
96        let pending = pending_tool
97            .as_mut()
98            .expect("flush_tool_start called without a pending tool");
99        if pending.emitted {
100            return pending.started_at;
101        }
102
103        let args = parse_tool_args(&pending.args_buffer);
104        self.emit(OrchestratorEvent::ToolExecutionStarted {
105            id: self.id.clone(),
106            tool_id: pending.id.clone(),
107            tool_name: pending.name.clone(),
108            args: args.clone(),
109        })
110        .await;
111
112        *self.activity.write().await = SubAgentActivity::CallingTool {
113            tool_name: pending.name.clone(),
114            args,
115        };
116        pending.emitted = true;
117        pending.started_at
118    }
119
120    /// Run the SubAgent.  Dispatches to real or placeholder execution.
121    pub async fn execute(mut self) -> Result<String> {
122        self.update_state(SubAgentState::Running).await;
123        let start = std::time::Instant::now();
124
125        let result = if let Some(agent) = self.agent.take() {
126            self.execute_with_agent(agent).await
127        } else {
128            self.execute_placeholder().await
129        };
130
131        let duration_ms = start.elapsed().as_millis() as u64;
132
133        match &result {
134            Ok(output) => {
135                self.update_state(SubAgentState::Completed {
136                    success: true,
137                    output: output.clone(),
138                })
139                .await;
140                self.emit(OrchestratorEvent::SubAgentCompleted {
141                    id: self.id.clone(),
142                    success: true,
143                    output: output.clone(),
144                    duration_ms,
145                    token_usage: None,
146                })
147                .await;
148            }
149            Err(e) => {
150                let current = self.state.read().await.clone();
151                if !matches!(current, SubAgentState::Cancelled) {
152                    self.update_state(SubAgentState::Error {
153                        message: e.to_string(),
154                    })
155                    .await;
156                }
157                self.emit(OrchestratorEvent::SubAgentCompleted {
158                    id: self.id.clone(),
159                    success: false,
160                    output: e.to_string(),
161                    duration_ms,
162                    token_usage: None,
163                })
164                .await;
165            }
166        }
167
168        result
169    }
170
171    // -------------------------------------------------------------------------
172    // Real execution via AgentSession
173    // -------------------------------------------------------------------------
174
175    async fn execute_with_agent(&mut self, agent: Arc<crate::Agent>) -> Result<String> {
176        // Build an AgentRegistry from built-ins + extra agent_dirs.
177        let registry = crate::AgentRegistry::new();
178        for dir in &self.config.agent_dirs {
179            let agents = crate::load_agents_from_dir(std::path::Path::new(dir));
180            for def in agents {
181                registry.register(def);
182            }
183        }
184
185        // Build session options from SubAgentConfig fields.
186        let mut opts = crate::SessionOptions::new();
187
188        // Pass agent_dirs and skill_dirs to session options
189        for dir in &self.config.agent_dirs {
190            opts = opts.with_agent_dir(dir.as_str());
191        }
192        if !self.config.skill_dirs.is_empty() {
193            opts = opts.with_skill_dirs(self.config.skill_dirs.iter().map(|s| s.as_str()));
194        }
195
196        // Handle permissive mode with fine-grained deny control
197        if self.config.permissive {
198            // Build a permissive policy that still respects deny rules
199            let mut policy = crate::permissions::PermissionPolicy::permissive();
200
201            // Add deny rules from permissive_deny config
202            for rule in &self.config.permissive_deny {
203                policy = policy.deny(rule);
204            }
205
206            // If we have an agent definition, also add its deny rules
207            if let Some(def) = registry.get(&self.config.agent_type) {
208                for rule in &def.permissions.deny {
209                    policy = policy.deny(&rule.rule);
210                }
211            }
212
213            opts = opts.with_permission_checker(Arc::new(policy));
214        }
215
216        if let Some(steps) = self.config.max_steps {
217            opts = opts.with_max_tool_rounds(steps);
218        }
219        if let Some(queue_cfg) = self.config.lane_config.clone() {
220            opts = opts.with_queue_config(queue_cfg);
221        }
222
223        // Create session: use the named agent definition if found, otherwise
224        // fall back to a plain session so unknown agent_types still work.
225        let session = Arc::new(if let Some(def) = registry.get(&self.config.agent_type) {
226            agent.session_for_agent(&self.config.workspace, &def, Some(opts))?
227        } else {
228            agent.session(&self.config.workspace, Some(opts))?
229        });
230
231        // Register session so complete_external_task() can reach it.
232        self.session_registry
233            .write()
234            .await
235            .insert(self.id.clone(), Arc::clone(&session));
236
237        // Stream execution.
238        let (mut rx, _task) = session.stream(&self.config.prompt, None).await?;
239
240        let mut output = String::new();
241        let mut step: usize = 0;
242        let mut pending_tool: Option<PendingToolCall> = None;
243
244        loop {
245            // Drain pending control signals before each event.
246            while let Ok(signal) = self.control_rx.try_recv() {
247                self.handle_control_signal(signal).await?;
248            }
249
250            // Abort if cancelled.
251            if matches!(*self.state.read().await, SubAgentState::Cancelled) {
252                // Drop rx to signal the background streaming task to stop.
253                drop(rx);
254                return Err(anyhow::anyhow!("Cancelled by orchestrator").into());
255            }
256
257            // Wait while paused (backpressure on rx naturally slows the agent).
258            while matches!(*self.state.read().await, SubAgentState::Paused) {
259                *self.activity.write().await = SubAgentActivity::WaitingForControl {
260                    reason: "Paused by orchestrator".to_string(),
261                };
262                tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
263                while let Ok(signal) = self.control_rx.try_recv() {
264                    self.handle_control_signal(signal).await?;
265                }
266                if matches!(*self.state.read().await, SubAgentState::Cancelled) {
267                    drop(rx);
268                    return Err(anyhow::anyhow!("Cancelled by orchestrator").into());
269                }
270            }
271
272            // Consume the next agent event.
273            match rx.recv().await {
274                Some(AgentEvent::TurnStart { turn }) => {
275                    *self.activity.write().await =
276                        SubAgentActivity::RequestingLlm { message_count: 0 };
277                    self.emit(OrchestratorEvent::SubAgentInternalEvent {
278                        id: self.id.clone(),
279                        event: AgentEvent::TurnStart { turn },
280                    })
281                    .await;
282                }
283                Some(AgentEvent::ToolStart { id, name }) => {
284                    pending_tool = Some(PendingToolCall {
285                        id,
286                        name,
287                        args_buffer: String::new(),
288                        started_at: std::time::Instant::now(),
289                        emitted: false,
290                    });
291                }
292                Some(AgentEvent::ToolInputDelta { delta }) => {
293                    if let Some(pending) = pending_tool.as_mut() {
294                        pending.args_buffer.push_str(&delta);
295                    }
296                    self.emit(OrchestratorEvent::SubAgentInternalEvent {
297                        id: self.id.clone(),
298                        event: AgentEvent::ToolInputDelta { delta },
299                    })
300                    .await;
301                }
302                Some(AgentEvent::ToolEnd {
303                    id,
304                    name,
305                    output: tool_out,
306                    exit_code,
307                    ..
308                }) => {
309                    step += 1;
310                    let started_at =
311                        if pending_tool.as_ref().map(|p| p.id.as_str()) == Some(id.as_str()) {
312                            self.flush_tool_start(&mut pending_tool).await
313                        } else {
314                            std::time::Instant::now()
315                        };
316                    *self.activity.write().await = SubAgentActivity::Idle;
317                    self.emit(OrchestratorEvent::ToolExecutionCompleted {
318                        id: self.id.clone(),
319                        tool_id: id,
320                        tool_name: name,
321                        result: tool_out,
322                        exit_code,
323                        duration_ms: tool_duration_ms(started_at),
324                    })
325                    .await;
326                    pending_tool = None;
327                    self.emit(OrchestratorEvent::SubAgentProgress {
328                        id: self.id.clone(),
329                        step,
330                        total_steps: self.config.max_steps.unwrap_or(0),
331                        message: format!("Completed tool call {step}"),
332                    })
333                    .await;
334                }
335                Some(AgentEvent::TextDelta { text }) => {
336                    if pending_tool.is_some() {
337                        self.flush_tool_start(&mut pending_tool).await;
338                    }
339                    output.push_str(&text);
340                    self.emit(OrchestratorEvent::SubAgentInternalEvent {
341                        id: self.id.clone(),
342                        event: AgentEvent::TextDelta { text },
343                    })
344                    .await;
345                }
346                Some(AgentEvent::ExternalTaskPending {
347                    task_id,
348                    session_id,
349                    lane,
350                    command_type,
351                    payload,
352                    timeout_ms,
353                }) => {
354                    if pending_tool.is_some() {
355                        self.flush_tool_start(&mut pending_tool).await;
356                    }
357                    self.emit(OrchestratorEvent::ExternalTaskPending {
358                        id: self.id.clone(),
359                        task_id,
360                        lane,
361                        command_type,
362                        payload,
363                        timeout_ms,
364                    })
365                    .await;
366                    // session_id is informational; the orchestrator routes by subagent ID.
367                    let _ = session_id;
368                }
369                Some(AgentEvent::ExternalTaskCompleted {
370                    task_id,
371                    session_id,
372                    success,
373                }) => {
374                    if pending_tool.is_some() {
375                        self.flush_tool_start(&mut pending_tool).await;
376                    }
377                    self.emit(OrchestratorEvent::ExternalTaskCompleted {
378                        id: self.id.clone(),
379                        task_id,
380                        success,
381                    })
382                    .await;
383                    let _ = session_id;
384                }
385                Some(AgentEvent::End { text, .. }) => {
386                    if pending_tool.is_some() {
387                        self.flush_tool_start(&mut pending_tool).await;
388                    }
389                    output = text;
390                    break;
391                }
392                Some(AgentEvent::Error { message }) => {
393                    return Err(anyhow::anyhow!("Agent error: {message}").into());
394                }
395                // Forward all other events as internal events for observability.
396                Some(event) => {
397                    if pending_tool.is_some() {
398                        self.flush_tool_start(&mut pending_tool).await;
399                    }
400                    self.emit(OrchestratorEvent::SubAgentInternalEvent {
401                        id: self.id.clone(),
402                        event,
403                    })
404                    .await;
405                }
406                None => break, // stream closed
407            }
408        }
409
410        // Deregister so the Arc is dropped and the session can be freed.
411        self.session_registry.write().await.remove(&self.id);
412
413        Ok(output)
414    }
415
416    // -------------------------------------------------------------------------
417    // Placeholder execution (backward compatibility when no agent is configured)
418    // -------------------------------------------------------------------------
419
420    async fn execute_placeholder(&mut self) -> Result<String> {
421        for step in 1..=5 {
422            while let Ok(signal) = self.control_rx.try_recv() {
423                self.handle_control_signal(signal).await?;
424            }
425
426            if matches!(*self.state.read().await, SubAgentState::Cancelled) {
427                return Err(anyhow::anyhow!("Cancelled by orchestrator").into());
428            }
429
430            while matches!(*self.state.read().await, SubAgentState::Paused) {
431                *self.activity.write().await = SubAgentActivity::WaitingForControl {
432                    reason: "Paused by orchestrator".to_string(),
433                };
434                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
435                while let Ok(signal) = self.control_rx.try_recv() {
436                    self.handle_control_signal(signal).await?;
437                }
438            }
439
440            *self.activity.write().await = SubAgentActivity::CallingTool {
441                tool_name: "read".to_string(),
442                args: serde_json::json!({"path": "/tmp/file.txt"}),
443            };
444
445            self.emit(OrchestratorEvent::ToolExecutionStarted {
446                id: self.id.clone(),
447                tool_id: format!("tool-{step}"),
448                tool_name: "read".to_string(),
449                args: serde_json::json!({"path": "/tmp/file.txt"}),
450            })
451            .await;
452
453            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
454
455            *self.activity.write().await = SubAgentActivity::RequestingLlm { message_count: 3 };
456
457            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
458
459            *self.activity.write().await = SubAgentActivity::Idle;
460
461            self.emit(OrchestratorEvent::SubAgentProgress {
462                id: self.id.clone(),
463                step,
464                total_steps: 5,
465                message: format!("Step {step}/5 completed"),
466            })
467            .await;
468        }
469
470        Ok(format!(
471            "Placeholder result for SubAgent {} ({})",
472            self.id, self.config.agent_type
473        ))
474    }
475
476    // -------------------------------------------------------------------------
477    // Control signal handling
478    // -------------------------------------------------------------------------
479
480    async fn handle_control_signal(&mut self, signal: ControlSignal) -> Result<()> {
481        self.emit(OrchestratorEvent::ControlSignalReceived {
482            id: self.id.clone(),
483            signal: signal.clone(),
484        })
485        .await;
486
487        let result = match signal {
488            ControlSignal::Pause => {
489                self.update_state(SubAgentState::Paused).await;
490                Ok(())
491            }
492            ControlSignal::Resume => {
493                self.update_state(SubAgentState::Running).await;
494                Ok(())
495            }
496            ControlSignal::Cancel => {
497                self.update_state(SubAgentState::Cancelled).await;
498                Err(anyhow::anyhow!("Cancelled by orchestrator").into())
499            }
500            ControlSignal::AdjustParams { max_steps, .. } => {
501                if let Some(steps) = max_steps {
502                    self.config.max_steps = Some(steps);
503                }
504                Ok(())
505            }
506            ControlSignal::InjectPrompt { ref prompt } => {
507                // Append the injected prompt so the next LLM turn sees it.
508                self.config.prompt.push('\n');
509                self.config.prompt.push_str(prompt);
510                Ok(())
511            }
512        };
513
514        self.emit(OrchestratorEvent::ControlSignalApplied {
515            id: self.id.clone(),
516            signal,
517            success: result.is_ok(),
518            error: result.as_ref().err().map(|e| format!("{e}")),
519        })
520        .await;
521
522        result
523    }
524
525    async fn update_state(&self, new_state: SubAgentState) {
526        let old_state = {
527            let mut state = self.state.write().await;
528            let old = state.clone();
529            *state = new_state.clone();
530            old
531        };
532
533        self.emit(OrchestratorEvent::SubAgentStateChanged {
534            id: self.id.clone(),
535            old_state,
536            new_state,
537        })
538        .await;
539    }
540}
541
#[cfg(test)]
mod tests {
    use super::{parse_tool_args, tool_duration_ms};
    use serde_json::json;
    use std::time::{Duration, Instant};

    #[test]
    fn parse_tool_args_parses_json_object() {
        assert_eq!(
            parse_tool_args(r#"{"path":"README.md"}"#),
            json!({"path": "README.md"})
        );
    }

    #[test]
    fn parse_tool_args_ignores_surrounding_whitespace() {
        assert_eq!(
            parse_tool_args(" \n{\"path\":\"README.md\"}\t "),
            json!({"path": "README.md"})
        );
    }

    #[test]
    fn parse_tool_args_returns_null_for_empty_input() {
        assert_eq!(parse_tool_args("   "), serde_json::Value::Null);
    }

    #[test]
    fn parse_tool_args_preserves_non_json_input_as_string() {
        assert_eq!(
            parse_tool_args(r#"{"path":"README.md""#),
            serde_json::Value::String(r#"{"path":"README.md""#.to_string())
        );
    }

    #[test]
    fn parse_tool_args_trims_non_json_input() {
        assert_eq!(
            parse_tool_args("  not json  "),
            serde_json::Value::String("not json".to_string())
        );
    }

    #[test]
    fn tool_duration_ms_has_one_millisecond_floor() {
        // Bound the result by a second measurement taken afterwards instead of
        // asserting an exact `1`. The exact check fails whenever the thread is
        // descheduled for a millisecond or more between the two clock reads.
        let started_at = Instant::now();
        let reported = tool_duration_ms(started_at);
        let upper = std::cmp::max(1, started_at.elapsed().as_millis() as u64);
        assert!(reported >= 1);
        assert!(reported <= upper);
    }

    #[test]
    fn tool_duration_ms_preserves_elapsed_milliseconds() {
        let started_at = Instant::now() - Duration::from_millis(12);
        assert!(tool_duration_ms(started_at) >= 12);
    }
}