Skip to main content

a3s_code_core/orchestrator/
wrapper.rs

1//! SubAgent wrapper — executes a real AgentSession and forwards events to the
2//! Orchestrator event bus, with pause/resume/cancel control signal support.
3
4use crate::agent::AgentEvent;
5use crate::error::Result;
6use crate::orchestrator::{
7    ControlSignal, OrchestratorEvent, SubAgentActivity, SubAgentConfig, SubAgentState,
8};
9use std::sync::Arc;
10use tokio::sync::{broadcast, mpsc, RwLock};
11
12pub struct SubAgentWrapper {
13    id: String,
14    config: SubAgentConfig,
15    /// Real agent for LLM execution; `None` → placeholder mode.
16    agent: Option<Arc<crate::Agent>>,
17    event_tx: broadcast::Sender<OrchestratorEvent>,
18    control_rx: mpsc::Receiver<ControlSignal>,
19    state: Arc<RwLock<SubAgentState>>,
20    activity: Arc<RwLock<SubAgentActivity>>,
21    /// Shared map of live sessions; wrapper registers its session here so
22    /// `AgentOrchestrator::complete_external_task()` can reach it.
23    session_registry:
24        Arc<RwLock<std::collections::HashMap<String, Arc<crate::agent_api::AgentSession>>>>,
25}
26
27impl SubAgentWrapper {
28    #[allow(clippy::too_many_arguments)]
29    pub fn new(
30        id: String,
31        config: SubAgentConfig,
32        agent: Option<Arc<crate::Agent>>,
33        event_tx: broadcast::Sender<OrchestratorEvent>,
34        control_rx: mpsc::Receiver<ControlSignal>,
35        state: Arc<RwLock<SubAgentState>>,
36        activity: Arc<RwLock<SubAgentActivity>>,
37        session_registry: Arc<
38            RwLock<std::collections::HashMap<String, Arc<crate::agent_api::AgentSession>>>,
39        >,
40    ) -> Self {
41        Self {
42            id,
43            config,
44            agent,
45            event_tx,
46            control_rx,
47            state,
48            activity,
49            session_registry,
50        }
51    }
52
53    /// Run the SubAgent.  Dispatches to real or placeholder execution.
54    pub async fn execute(mut self) -> Result<String> {
55        self.update_state(SubAgentState::Running).await;
56        let start = std::time::Instant::now();
57
58        let result = if let Some(agent) = self.agent.take() {
59            self.execute_with_agent(agent).await
60        } else {
61            self.execute_placeholder().await
62        };
63
64        let duration_ms = start.elapsed().as_millis() as u64;
65
66        match &result {
67            Ok(output) => {
68                self.update_state(SubAgentState::Completed {
69                    success: true,
70                    output: output.clone(),
71                })
72                .await;
73                let _ = self.event_tx.send(OrchestratorEvent::SubAgentCompleted {
74                    id: self.id.clone(),
75                    success: true,
76                    output: output.clone(),
77                    duration_ms,
78                    token_usage: None,
79                });
80            }
81            Err(e) => {
82                let current = self.state.read().await.clone();
83                if !matches!(current, SubAgentState::Cancelled) {
84                    self.update_state(SubAgentState::Error {
85                        message: e.to_string(),
86                    })
87                    .await;
88                }
89                let _ = self.event_tx.send(OrchestratorEvent::SubAgentCompleted {
90                    id: self.id.clone(),
91                    success: false,
92                    output: e.to_string(),
93                    duration_ms,
94                    token_usage: None,
95                });
96            }
97        }
98
99        result
100    }
101
102    // -------------------------------------------------------------------------
103    // Real execution via AgentSession
104    // -------------------------------------------------------------------------
105
106    async fn execute_with_agent(&mut self, agent: Arc<crate::Agent>) -> Result<String> {
107        // Build an AgentRegistry from built-ins + extra agent_dirs.
108        let registry = crate::AgentRegistry::new();
109        for dir in &self.config.agent_dirs {
110            let agents = crate::load_agents_from_dir(std::path::Path::new(dir));
111            for def in agents {
112                registry.register(def);
113            }
114        }
115
116        // Build session options from SubAgentConfig fields.
117        let mut opts = crate::SessionOptions::new();
118        if self.config.permissive {
119            opts = opts.with_permissive_policy();
120        }
121        if let Some(steps) = self.config.max_steps {
122            opts = opts.with_max_tool_rounds(steps);
123        }
124        if let Some(queue_cfg) = self.config.lane_config.clone() {
125            opts = opts.with_queue_config(queue_cfg);
126        }
127
128        // Create session: use the named agent definition if found, otherwise
129        // fall back to a plain session so unknown agent_types still work.
130        let session = Arc::new(if let Some(def) = registry.get(&self.config.agent_type) {
131            agent.session_for_agent(&self.config.workspace, &def, Some(opts))?
132        } else {
133            agent.session(&self.config.workspace, Some(opts))?
134        });
135
136        // Register session so complete_external_task() can reach it.
137        self.session_registry
138            .write()
139            .await
140            .insert(self.id.clone(), Arc::clone(&session));
141
142        // Stream execution.
143        let (mut rx, _task) = session.stream(&self.config.prompt, None).await?;
144
145        let mut output = String::new();
146        let mut step: usize = 0;
147
148        loop {
149            // Drain pending control signals before each event.
150            while let Ok(signal) = self.control_rx.try_recv() {
151                self.handle_control_signal(signal).await?;
152            }
153
154            // Abort if cancelled.
155            if matches!(*self.state.read().await, SubAgentState::Cancelled) {
156                // Drop rx to signal the background streaming task to stop.
157                drop(rx);
158                return Err(anyhow::anyhow!("Cancelled by orchestrator").into());
159            }
160
161            // Wait while paused (backpressure on rx naturally slows the agent).
162            while matches!(*self.state.read().await, SubAgentState::Paused) {
163                *self.activity.write().await = SubAgentActivity::WaitingForControl {
164                    reason: "Paused by orchestrator".to_string(),
165                };
166                tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
167                while let Ok(signal) = self.control_rx.try_recv() {
168                    self.handle_control_signal(signal).await?;
169                }
170                if matches!(*self.state.read().await, SubAgentState::Cancelled) {
171                    drop(rx);
172                    return Err(anyhow::anyhow!("Cancelled by orchestrator").into());
173                }
174            }
175
176            // Consume the next agent event.
177            match rx.recv().await {
178                Some(AgentEvent::TurnStart { .. }) => {
179                    *self.activity.write().await =
180                        SubAgentActivity::RequestingLlm { message_count: 0 };
181                }
182                Some(AgentEvent::ToolStart { id, name }) => {
183                    *self.activity.write().await = SubAgentActivity::CallingTool {
184                        tool_name: name.clone(),
185                        args: serde_json::Value::Null,
186                    };
187                    let _ = self.event_tx.send(OrchestratorEvent::ToolExecutionStarted {
188                        id: self.id.clone(),
189                        tool_id: id,
190                        tool_name: name,
191                        args: serde_json::Value::Null,
192                    });
193                }
194                Some(AgentEvent::ToolEnd {
195                    id,
196                    name,
197                    output: tool_out,
198                    exit_code,
199                    ..
200                }) => {
201                    step += 1;
202                    *self.activity.write().await = SubAgentActivity::Idle;
203                    let tool_start = std::time::Instant::now();
204                    let _ = self
205                        .event_tx
206                        .send(OrchestratorEvent::ToolExecutionCompleted {
207                            id: self.id.clone(),
208                            tool_id: id,
209                            tool_name: name,
210                            result: tool_out,
211                            exit_code,
212                            duration_ms: tool_start.elapsed().as_millis() as u64,
213                        });
214                    let _ = self.event_tx.send(OrchestratorEvent::SubAgentProgress {
215                        id: self.id.clone(),
216                        step,
217                        total_steps: self.config.max_steps.unwrap_or(0),
218                        message: format!("Completed tool call {step}"),
219                    });
220                }
221                Some(AgentEvent::TextDelta { text }) => {
222                    output.push_str(&text);
223                }
224                Some(AgentEvent::ExternalTaskPending {
225                    task_id,
226                    session_id,
227                    lane,
228                    command_type,
229                    payload,
230                    timeout_ms,
231                }) => {
232                    let _ = self.event_tx.send(OrchestratorEvent::ExternalTaskPending {
233                        id: self.id.clone(),
234                        task_id,
235                        lane,
236                        command_type,
237                        payload,
238                        timeout_ms,
239                    });
240                    // session_id is informational; the orchestrator routes by subagent ID.
241                    let _ = session_id;
242                }
243                Some(AgentEvent::ExternalTaskCompleted {
244                    task_id,
245                    session_id,
246                    success,
247                }) => {
248                    let _ = self
249                        .event_tx
250                        .send(OrchestratorEvent::ExternalTaskCompleted {
251                            id: self.id.clone(),
252                            task_id,
253                            success,
254                        });
255                    let _ = session_id;
256                }
257                Some(AgentEvent::End { text, .. }) => {
258                    output = text;
259                    break;
260                }
261                Some(AgentEvent::Error { message }) => {
262                    return Err(anyhow::anyhow!("Agent error: {message}").into());
263                }
264                // Forward all other events as internal events for observability.
265                Some(event) => {
266                    let _ = self
267                        .event_tx
268                        .send(OrchestratorEvent::SubAgentInternalEvent {
269                            id: self.id.clone(),
270                            event,
271                        });
272                }
273                None => break, // stream closed
274            }
275        }
276
277        // Deregister so the Arc is dropped and the session can be freed.
278        self.session_registry.write().await.remove(&self.id);
279
280        Ok(output)
281    }
282
283    // -------------------------------------------------------------------------
284    // Placeholder execution (backward compatibility when no agent is configured)
285    // -------------------------------------------------------------------------
286
287    async fn execute_placeholder(&mut self) -> Result<String> {
288        for step in 1..=5 {
289            while let Ok(signal) = self.control_rx.try_recv() {
290                self.handle_control_signal(signal).await?;
291            }
292
293            if matches!(*self.state.read().await, SubAgentState::Cancelled) {
294                return Err(anyhow::anyhow!("Cancelled by orchestrator").into());
295            }
296
297            while matches!(*self.state.read().await, SubAgentState::Paused) {
298                *self.activity.write().await = SubAgentActivity::WaitingForControl {
299                    reason: "Paused by orchestrator".to_string(),
300                };
301                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
302                while let Ok(signal) = self.control_rx.try_recv() {
303                    self.handle_control_signal(signal).await?;
304                }
305            }
306
307            *self.activity.write().await = SubAgentActivity::CallingTool {
308                tool_name: "read".to_string(),
309                args: serde_json::json!({"path": "/tmp/file.txt"}),
310            };
311
312            let _ = self.event_tx.send(OrchestratorEvent::ToolExecutionStarted {
313                id: self.id.clone(),
314                tool_id: format!("tool-{step}"),
315                tool_name: "read".to_string(),
316                args: serde_json::json!({"path": "/tmp/file.txt"}),
317            });
318
319            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
320
321            *self.activity.write().await = SubAgentActivity::RequestingLlm { message_count: 3 };
322
323            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
324
325            *self.activity.write().await = SubAgentActivity::Idle;
326
327            let _ = self.event_tx.send(OrchestratorEvent::SubAgentProgress {
328                id: self.id.clone(),
329                step,
330                total_steps: 5,
331                message: format!("Step {step}/5 completed"),
332            });
333        }
334
335        Ok(format!(
336            "Placeholder result for SubAgent {} ({})",
337            self.id, self.config.agent_type
338        ))
339    }
340
341    // -------------------------------------------------------------------------
342    // Control signal handling
343    // -------------------------------------------------------------------------
344
345    async fn handle_control_signal(&mut self, signal: ControlSignal) -> Result<()> {
346        let _ = self
347            .event_tx
348            .send(OrchestratorEvent::ControlSignalReceived {
349                id: self.id.clone(),
350                signal: signal.clone(),
351            });
352
353        let result = match signal {
354            ControlSignal::Pause => {
355                self.update_state(SubAgentState::Paused).await;
356                Ok(())
357            }
358            ControlSignal::Resume => {
359                self.update_state(SubAgentState::Running).await;
360                Ok(())
361            }
362            ControlSignal::Cancel => {
363                self.update_state(SubAgentState::Cancelled).await;
364                Err(anyhow::anyhow!("Cancelled by orchestrator").into())
365            }
366            ControlSignal::AdjustParams { max_steps, .. } => {
367                if let Some(steps) = max_steps {
368                    self.config.max_steps = Some(steps);
369                }
370                Ok(())
371            }
372            ControlSignal::InjectPrompt { ref prompt } => {
373                // Append the injected prompt so the next LLM turn sees it.
374                self.config.prompt.push('\n');
375                self.config.prompt.push_str(prompt);
376                Ok(())
377            }
378        };
379
380        let _ = self.event_tx.send(OrchestratorEvent::ControlSignalApplied {
381            id: self.id.clone(),
382            signal,
383            success: result.is_ok(),
384            error: result.as_ref().err().map(|e| format!("{e}")),
385        });
386
387        result
388    }
389
390    async fn update_state(&self, new_state: SubAgentState) {
391        let old_state = {
392            let mut state = self.state.write().await;
393            let old = state.clone();
394            *state = new_state.clone();
395            old
396        };
397
398        let _ = self.event_tx.send(OrchestratorEvent::SubAgentStateChanged {
399            id: self.id.clone(),
400            old_state,
401            new_state,
402        });
403    }
404}