Skip to main content

a3s_code_core/orchestrator/
wrapper.rs

1//! SubAgent wrapper — executes a real AgentSession and forwards events to the
2//! Orchestrator event bus, with pause/resume/cancel control signal support.
3
4use crate::agent::AgentEvent;
5use crate::error::Result;
6use crate::orchestrator::{
7    ControlSignal, OrchestratorEvent, SubAgentActivity, SubAgentConfig, SubAgentState,
8};
9use std::sync::Arc;
10use tokio::sync::{broadcast, mpsc, RwLock};
11
12pub struct SubAgentWrapper {
13    id: String,
14    config: SubAgentConfig,
15    /// Real agent for LLM execution; `None` → placeholder mode.
16    agent: Option<Arc<crate::Agent>>,
17    event_tx: broadcast::Sender<OrchestratorEvent>,
18    control_rx: mpsc::Receiver<ControlSignal>,
19    state: Arc<RwLock<SubAgentState>>,
20    activity: Arc<RwLock<SubAgentActivity>>,
21    /// Shared map of live sessions; wrapper registers its session here so
22    /// `AgentOrchestrator::complete_external_task()` can reach it.
23    session_registry:
24        Arc<RwLock<std::collections::HashMap<String, Arc<crate::agent_api::AgentSession>>>>,
25}
26
27impl SubAgentWrapper {
28    #[allow(clippy::too_many_arguments)]
29    pub fn new(
30        id: String,
31        config: SubAgentConfig,
32        agent: Option<Arc<crate::Agent>>,
33        event_tx: broadcast::Sender<OrchestratorEvent>,
34        control_rx: mpsc::Receiver<ControlSignal>,
35        state: Arc<RwLock<SubAgentState>>,
36        activity: Arc<RwLock<SubAgentActivity>>,
37        session_registry: Arc<
38            RwLock<std::collections::HashMap<String, Arc<crate::agent_api::AgentSession>>>,
39        >,
40    ) -> Self {
41        Self {
42            id,
43            config,
44            agent,
45            event_tx,
46            control_rx,
47            state,
48            activity,
49            session_registry,
50        }
51    }
52
53    /// Run the SubAgent.  Dispatches to real or placeholder execution.
54    pub async fn execute(mut self) -> Result<String> {
55        self.update_state(SubAgentState::Running).await;
56        let start = std::time::Instant::now();
57
58        let result = if let Some(agent) = self.agent.take() {
59            self.execute_with_agent(agent).await
60        } else {
61            self.execute_placeholder().await
62        };
63
64        let duration_ms = start.elapsed().as_millis() as u64;
65
66        match &result {
67            Ok(output) => {
68                self.update_state(SubAgentState::Completed {
69                    success: true,
70                    output: output.clone(),
71                })
72                .await;
73                let _ = self.event_tx.send(OrchestratorEvent::SubAgentCompleted {
74                    id: self.id.clone(),
75                    success: true,
76                    output: output.clone(),
77                    duration_ms,
78                    token_usage: None,
79                });
80            }
81            Err(e) => {
82                let current = self.state.read().await.clone();
83                if !matches!(current, SubAgentState::Cancelled) {
84                    self.update_state(SubAgentState::Error {
85                        message: e.to_string(),
86                    })
87                    .await;
88                }
89                let _ = self.event_tx.send(OrchestratorEvent::SubAgentCompleted {
90                    id: self.id.clone(),
91                    success: false,
92                    output: e.to_string(),
93                    duration_ms,
94                    token_usage: None,
95                });
96            }
97        }
98
99        result
100    }
101
102    // -------------------------------------------------------------------------
103    // Real execution via AgentSession
104    // -------------------------------------------------------------------------
105
106    async fn execute_with_agent(&mut self, agent: Arc<crate::Agent>) -> Result<String> {
107        // Build an AgentRegistry from built-ins + extra agent_dirs.
108        let registry = crate::AgentRegistry::new();
109        for dir in &self.config.agent_dirs {
110            let agents = crate::load_agents_from_dir(std::path::Path::new(dir));
111            for def in agents {
112                registry.register(def);
113            }
114        }
115
116        // Build session options from SubAgentConfig fields.
117        let mut opts = crate::SessionOptions::new();
118
119        // Pass agent_dirs and skill_dirs to session options
120        for dir in &self.config.agent_dirs {
121            opts = opts.with_agent_dir(dir.as_str());
122        }
123        if !self.config.skill_dirs.is_empty() {
124            opts = opts.with_skill_dirs(self.config.skill_dirs.iter().map(|s| s.as_str()));
125        }
126
127        // Handle permissive mode with fine-grained deny control
128        if self.config.permissive {
129            // Build a permissive policy that still respects deny rules
130            let mut policy = crate::permissions::PermissionPolicy::permissive();
131
132            // Add deny rules from permissive_deny config
133            for rule in &self.config.permissive_deny {
134                policy = policy.deny(rule);
135            }
136
137            // If we have an agent definition, also add its deny rules
138            if let Some(def) = registry.get(&self.config.agent_type) {
139                for rule in &def.permissions.deny {
140                    policy = policy.deny(&rule.rule);
141                }
142            }
143
144            opts = opts.with_permission_checker(Arc::new(policy));
145        }
146
147        if let Some(steps) = self.config.max_steps {
148            opts = opts.with_max_tool_rounds(steps);
149        }
150        if let Some(queue_cfg) = self.config.lane_config.clone() {
151            opts = opts.with_queue_config(queue_cfg);
152        }
153
154        // Create session: use the named agent definition if found, otherwise
155        // fall back to a plain session so unknown agent_types still work.
156        let session = Arc::new(if let Some(def) = registry.get(&self.config.agent_type) {
157            agent.session_for_agent(&self.config.workspace, &def, Some(opts))?
158        } else {
159            agent.session(&self.config.workspace, Some(opts))?
160        });
161
162        // Register session so complete_external_task() can reach it.
163        self.session_registry
164            .write()
165            .await
166            .insert(self.id.clone(), Arc::clone(&session));
167
168        // Stream execution.
169        let (mut rx, _task) = session.stream(&self.config.prompt, None).await?;
170
171        let mut output = String::new();
172        let mut step: usize = 0;
173
174        loop {
175            // Drain pending control signals before each event.
176            while let Ok(signal) = self.control_rx.try_recv() {
177                self.handle_control_signal(signal).await?;
178            }
179
180            // Abort if cancelled.
181            if matches!(*self.state.read().await, SubAgentState::Cancelled) {
182                // Drop rx to signal the background streaming task to stop.
183                drop(rx);
184                return Err(anyhow::anyhow!("Cancelled by orchestrator").into());
185            }
186
187            // Wait while paused (backpressure on rx naturally slows the agent).
188            while matches!(*self.state.read().await, SubAgentState::Paused) {
189                *self.activity.write().await = SubAgentActivity::WaitingForControl {
190                    reason: "Paused by orchestrator".to_string(),
191                };
192                tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
193                while let Ok(signal) = self.control_rx.try_recv() {
194                    self.handle_control_signal(signal).await?;
195                }
196                if matches!(*self.state.read().await, SubAgentState::Cancelled) {
197                    drop(rx);
198                    return Err(anyhow::anyhow!("Cancelled by orchestrator").into());
199                }
200            }
201
202            // Consume the next agent event.
203            match rx.recv().await {
204                Some(AgentEvent::TurnStart { turn }) => {
205                    *self.activity.write().await =
206                        SubAgentActivity::RequestingLlm { message_count: 0 };
207                    // Forward as internal event for observability
208                    let _ = self
209                        .event_tx
210                        .send(OrchestratorEvent::SubAgentInternalEvent {
211                            id: self.id.clone(),
212                            event: AgentEvent::TurnStart { turn },
213                        });
214                }
215                Some(AgentEvent::ToolStart { id, name }) => {
216                    *self.activity.write().await = SubAgentActivity::CallingTool {
217                        tool_name: name.clone(),
218                        args: serde_json::Value::Null,
219                    };
220                    let _ = self.event_tx.send(OrchestratorEvent::ToolExecutionStarted {
221                        id: self.id.clone(),
222                        tool_id: id,
223                        tool_name: name,
224                        args: serde_json::Value::Null,
225                    });
226                }
227                Some(AgentEvent::ToolEnd {
228                    id,
229                    name,
230                    output: tool_out,
231                    exit_code,
232                    ..
233                }) => {
234                    step += 1;
235                    *self.activity.write().await = SubAgentActivity::Idle;
236                    let tool_start = std::time::Instant::now();
237                    let _ = self
238                        .event_tx
239                        .send(OrchestratorEvent::ToolExecutionCompleted {
240                            id: self.id.clone(),
241                            tool_id: id,
242                            tool_name: name,
243                            result: tool_out,
244                            exit_code,
245                            duration_ms: tool_start.elapsed().as_millis() as u64,
246                        });
247                    let _ = self.event_tx.send(OrchestratorEvent::SubAgentProgress {
248                        id: self.id.clone(),
249                        step,
250                        total_steps: self.config.max_steps.unwrap_or(0),
251                        message: format!("Completed tool call {step}"),
252                    });
253                }
254                Some(AgentEvent::TextDelta { text }) => {
255                    output.push_str(&text);
256                    // Forward as internal event for streaming observability
257                    let _ = self
258                        .event_tx
259                        .send(OrchestratorEvent::SubAgentInternalEvent {
260                            id: self.id.clone(),
261                            event: AgentEvent::TextDelta { text },
262                        });
263                }
264                Some(AgentEvent::ExternalTaskPending {
265                    task_id,
266                    session_id,
267                    lane,
268                    command_type,
269                    payload,
270                    timeout_ms,
271                }) => {
272                    let _ = self.event_tx.send(OrchestratorEvent::ExternalTaskPending {
273                        id: self.id.clone(),
274                        task_id,
275                        lane,
276                        command_type,
277                        payload,
278                        timeout_ms,
279                    });
280                    // session_id is informational; the orchestrator routes by subagent ID.
281                    let _ = session_id;
282                }
283                Some(AgentEvent::ExternalTaskCompleted {
284                    task_id,
285                    session_id,
286                    success,
287                }) => {
288                    let _ = self
289                        .event_tx
290                        .send(OrchestratorEvent::ExternalTaskCompleted {
291                            id: self.id.clone(),
292                            task_id,
293                            success,
294                        });
295                    let _ = session_id;
296                }
297                Some(AgentEvent::End { text, .. }) => {
298                    output = text;
299                    break;
300                }
301                Some(AgentEvent::Error { message }) => {
302                    return Err(anyhow::anyhow!("Agent error: {message}").into());
303                }
304                // Forward all other events as internal events for observability.
305                Some(event) => {
306                    let _ = self
307                        .event_tx
308                        .send(OrchestratorEvent::SubAgentInternalEvent {
309                            id: self.id.clone(),
310                            event,
311                        });
312                }
313                None => break, // stream closed
314            }
315        }
316
317        // Deregister so the Arc is dropped and the session can be freed.
318        self.session_registry.write().await.remove(&self.id);
319
320        Ok(output)
321    }
322
323    // -------------------------------------------------------------------------
324    // Placeholder execution (backward compatibility when no agent is configured)
325    // -------------------------------------------------------------------------
326
327    async fn execute_placeholder(&mut self) -> Result<String> {
328        for step in 1..=5 {
329            while let Ok(signal) = self.control_rx.try_recv() {
330                self.handle_control_signal(signal).await?;
331            }
332
333            if matches!(*self.state.read().await, SubAgentState::Cancelled) {
334                return Err(anyhow::anyhow!("Cancelled by orchestrator").into());
335            }
336
337            while matches!(*self.state.read().await, SubAgentState::Paused) {
338                *self.activity.write().await = SubAgentActivity::WaitingForControl {
339                    reason: "Paused by orchestrator".to_string(),
340                };
341                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
342                while let Ok(signal) = self.control_rx.try_recv() {
343                    self.handle_control_signal(signal).await?;
344                }
345            }
346
347            *self.activity.write().await = SubAgentActivity::CallingTool {
348                tool_name: "read".to_string(),
349                args: serde_json::json!({"path": "/tmp/file.txt"}),
350            };
351
352            let _ = self.event_tx.send(OrchestratorEvent::ToolExecutionStarted {
353                id: self.id.clone(),
354                tool_id: format!("tool-{step}"),
355                tool_name: "read".to_string(),
356                args: serde_json::json!({"path": "/tmp/file.txt"}),
357            });
358
359            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
360
361            *self.activity.write().await = SubAgentActivity::RequestingLlm { message_count: 3 };
362
363            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
364
365            *self.activity.write().await = SubAgentActivity::Idle;
366
367            let _ = self.event_tx.send(OrchestratorEvent::SubAgentProgress {
368                id: self.id.clone(),
369                step,
370                total_steps: 5,
371                message: format!("Step {step}/5 completed"),
372            });
373        }
374
375        Ok(format!(
376            "Placeholder result for SubAgent {} ({})",
377            self.id, self.config.agent_type
378        ))
379    }
380
381    // -------------------------------------------------------------------------
382    // Control signal handling
383    // -------------------------------------------------------------------------
384
385    async fn handle_control_signal(&mut self, signal: ControlSignal) -> Result<()> {
386        let _ = self
387            .event_tx
388            .send(OrchestratorEvent::ControlSignalReceived {
389                id: self.id.clone(),
390                signal: signal.clone(),
391            });
392
393        let result = match signal {
394            ControlSignal::Pause => {
395                self.update_state(SubAgentState::Paused).await;
396                Ok(())
397            }
398            ControlSignal::Resume => {
399                self.update_state(SubAgentState::Running).await;
400                Ok(())
401            }
402            ControlSignal::Cancel => {
403                self.update_state(SubAgentState::Cancelled).await;
404                Err(anyhow::anyhow!("Cancelled by orchestrator").into())
405            }
406            ControlSignal::AdjustParams { max_steps, .. } => {
407                if let Some(steps) = max_steps {
408                    self.config.max_steps = Some(steps);
409                }
410                Ok(())
411            }
412            ControlSignal::InjectPrompt { ref prompt } => {
413                // Append the injected prompt so the next LLM turn sees it.
414                self.config.prompt.push('\n');
415                self.config.prompt.push_str(prompt);
416                Ok(())
417            }
418        };
419
420        let _ = self.event_tx.send(OrchestratorEvent::ControlSignalApplied {
421            id: self.id.clone(),
422            signal,
423            success: result.is_ok(),
424            error: result.as_ref().err().map(|e| format!("{e}")),
425        });
426
427        result
428    }
429
430    async fn update_state(&self, new_state: SubAgentState) {
431        let old_state = {
432            let mut state = self.state.write().await;
433            let old = state.clone();
434            *state = new_state.clone();
435            old
436        };
437
438        let _ = self.event_tx.send(OrchestratorEvent::SubAgentStateChanged {
439            id: self.id.clone(),
440            old_state,
441            new_state,
442        });
443    }
444}