a3s_code_core/orchestration/
executor.rs1use crate::agent::{AgentEvent, DEFAULT_MAX_PARALLEL_TASKS};
4use crate::ordered_parallel::run_ordered_parallel_with_limit;
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8use tokio::sync::broadcast;
9
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
19pub struct AgentStepSpec {
20 pub task_id: String,
23 pub agent: String,
25 pub description: String,
27 pub prompt: String,
29 #[serde(default, skip_serializing_if = "Option::is_none")]
31 pub max_steps: Option<usize>,
32 #[serde(default, skip_serializing_if = "Option::is_none")]
34 pub parent_session_id: Option<String>,
35 #[serde(default, skip_serializing_if = "Option::is_none")]
40 pub output_schema: Option<serde_json::Value>,
41}
42
43impl AgentStepSpec {
44 pub fn new(
46 task_id: impl Into<String>,
47 agent: impl Into<String>,
48 description: impl Into<String>,
49 prompt: impl Into<String>,
50 ) -> Self {
51 Self {
52 task_id: task_id.into(),
53 agent: agent.into(),
54 description: description.into(),
55 prompt: prompt.into(),
56 max_steps: None,
57 parent_session_id: None,
58 output_schema: None,
59 }
60 }
61
62 pub fn with_max_steps(mut self, max_steps: usize) -> Self {
63 self.max_steps = Some(max_steps);
64 self
65 }
66
67 pub fn with_parent_session_id(mut self, parent_session_id: impl Into<String>) -> Self {
68 self.parent_session_id = Some(parent_session_id.into());
69 self
70 }
71
72 pub fn with_output_schema(mut self, schema: serde_json::Value) -> Self {
74 self.output_schema = Some(schema);
75 self
76 }
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
85pub struct StepOutcome {
86 pub task_id: String,
87 pub session_id: String,
88 pub agent: String,
89 pub output: String,
90 pub success: bool,
91 #[serde(default, skip_serializing_if = "Option::is_none")]
93 pub structured: Option<serde_json::Value>,
94}
95
96impl StepOutcome {
97 pub fn failed(
101 task_id: impl Into<String>,
102 agent: impl Into<String>,
103 message: impl Into<String>,
104 ) -> Self {
105 let task_id = task_id.into();
106 let session_id = format!("task-run-{task_id}");
107 Self {
108 task_id,
109 session_id,
110 agent: agent.into(),
111 output: message.into(),
112 success: false,
113 structured: None,
114 }
115 }
116}
117
118#[async_trait]
128pub trait AgentExecutor: Send + Sync {
129 async fn execute_step(
136 &self,
137 spec: AgentStepSpec,
138 event_tx: Option<broadcast::Sender<AgentEvent>>,
139 ) -> StepOutcome;
140
141 fn concurrency_hint(&self) -> usize {
147 DEFAULT_MAX_PARALLEL_TASKS
148 }
149}
150
151pub async fn execute_steps_parallel(
159 executor: Arc<dyn AgentExecutor>,
160 specs: Vec<AgentStepSpec>,
161 event_tx: Option<broadcast::Sender<AgentEvent>>,
162) -> Vec<StepOutcome> {
163 let limit = executor.concurrency_hint();
164 let labels: Vec<(String, String)> = specs
167 .iter()
168 .map(|s| (s.task_id.clone(), s.agent.clone()))
169 .collect();
170
171 let results = run_ordered_parallel_with_limit(specs, limit, move |_idx, spec| {
172 let executor = Arc::clone(&executor);
173 let event_tx = event_tx.clone();
174 async move { executor.execute_step(spec, event_tx).await }
175 })
176 .await;
177
178 results
179 .into_iter()
180 .map(|result| match result.output {
181 Ok(outcome) => outcome,
182 Err(error) => {
183 let (task_id, agent) = labels
184 .get(result.index)
185 .cloned()
186 .unwrap_or_else(|| ("unknown".to_string(), "unknown".to_string()));
187 StepOutcome::failed(task_id, agent, error.to_string())
188 }
189 })
190 .collect()
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Test double that records how many `execute_step` calls overlap.
    struct MockExecutor {
        /// Value reported from `concurrency_hint`.
        hint: usize,
        /// Number of `execute_step` calls currently in flight.
        active: Arc<AtomicUsize>,
        /// Highest value `active` has reached so far.
        max_active: Arc<AtomicUsize>,
    }

    impl MockExecutor {
        /// Creates a mock reporting `hint`, with both counters at zero.
        fn new(hint: usize) -> Self {
            let zeroed = || Arc::new(AtomicUsize::new(0));
            Self {
                hint,
                active: zeroed(),
                max_active: zeroed(),
            }
        }
    }

    #[async_trait]
    impl AgentExecutor for MockExecutor {
        fn concurrency_hint(&self) -> usize {
            self.hint
        }

        /// Sleeps briefly while counted as in flight, then panics for the
        /// `"boom"` agent, reports failure for `"fail"`, and succeeds otherwise.
        async fn execute_step(
            &self,
            spec: AgentStepSpec,
            _event_tx: Option<broadcast::Sender<AgentEvent>>,
        ) -> StepOutcome {
            // `fetch_add` returns the previous value, so `+ 1` is the count
            // including this call.
            let in_flight = self.active.fetch_add(1, Ordering::SeqCst) + 1;
            self.max_active.fetch_max(in_flight, Ordering::SeqCst);
            tokio::time::sleep(Duration::from_millis(20)).await;
            self.active.fetch_sub(1, Ordering::SeqCst);

            // The "boom" agent simulates a step that panics.
            assert!(spec.agent != "boom", "boom");

            let AgentStepSpec {
                task_id,
                agent,
                prompt,
                ..
            } = spec;
            StepOutcome {
                session_id: format!("task-run-{}", task_id),
                output: format!("ran: {}", prompt),
                success: agent != "fail",
                task_id,
                agent,
                structured: None,
            }
        }
    }

    /// Builds a spec with description `"d"` and prompt `prompt-{id}`.
    fn spec(id: &str, agent: &str) -> AgentStepSpec {
        AgentStepSpec::new(id, agent, "d", format!("prompt-{id}"))
    }

    #[tokio::test]
    async fn fans_out_in_input_order() {
        let executor: Arc<dyn AgentExecutor> = Arc::new(MockExecutor::new(8));
        let batch = vec![spec("a", "explore"), spec("b", "review"), spec("c", "plan")];

        let outcomes = execute_steps_parallel(executor, batch, None).await;

        let ids: Vec<&str> = outcomes.iter().map(|o| o.task_id.as_str()).collect();
        assert_eq!(ids, vec!["a", "b", "c"], "results preserve input order");
        assert!(outcomes.iter().all(|o| o.success));
        assert_eq!(outcomes[0].output, "ran: prompt-a");
    }

    #[tokio::test]
    async fn respects_concurrency_hint() {
        let mock = MockExecutor::new(2);
        // Keep a handle on the counter before the mock is type-erased.
        let peak = Arc::clone(&mock.max_active);
        let executor: Arc<dyn AgentExecutor> = Arc::new(mock);
        let batch: Vec<AgentStepSpec> = (0..6)
            .map(|i| spec(&i.to_string(), "explore"))
            .collect();

        let _ = execute_steps_parallel(executor, batch, None).await;

        assert_eq!(
            peak.load(Ordering::SeqCst),
            2,
            "never more than concurrency_hint steps run at once"
        );
    }

    #[tokio::test]
    async fn isolates_failed_and_panicked_steps() {
        let executor: Arc<dyn AgentExecutor> = Arc::new(MockExecutor::new(8));
        let batch = vec![
            spec("ok", "explore"),
            spec("bad", "fail"),
            spec("crash", "boom"),
            spec("ok2", "review"),
        ];

        let outcomes = execute_steps_parallel(executor, batch, None).await;

        assert_eq!(outcomes.len(), 4, "every step yields a result");
        assert!(outcomes[0].success);
        assert!(
            !outcomes[1].success,
            "explicit failure surfaces as success=false"
        );
        let crashed = &outcomes[2];
        assert!(
            !crashed.success && crashed.agent == "boom",
            "a panicked branch becomes a labelled failed outcome, not a drop"
        );
        assert!(
            outcomes[3].success,
            "later steps unaffected by an earlier panic"
        );
    }

    #[tokio::test]
    async fn default_concurrency_hint_is_the_framework_default() {
        /// Executor that overrides nothing, so the trait default applies.
        struct Bare;

        #[async_trait]
        impl AgentExecutor for Bare {
            async fn execute_step(
                &self,
                spec: AgentStepSpec,
                _tx: Option<broadcast::Sender<AgentEvent>>,
            ) -> StepOutcome {
                StepOutcome::failed(spec.task_id, spec.agent, "unused")
            }
        }

        assert_eq!(Bare.concurrency_hint(), DEFAULT_MAX_PARALLEL_TASKS);
    }

    #[test]
    fn spec_and_outcome_round_trip_including_new_optional_fields() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": { "v": { "type": "string" } },
            "required": ["v"]
        });

        // A fully populated spec survives serialize -> deserialize.
        let full_spec = AgentStepSpec::new("t1", "explore", "d", "p")
            .with_max_steps(3)
            .with_parent_session_id("parent")
            .with_output_schema(schema.clone());
        let spec_json = serde_json::to_string(&full_spec).unwrap();
        let spec_back: AgentStepSpec = serde_json::from_str(&spec_json).unwrap();
        assert_eq!(spec_back, full_spec);
        assert_eq!(spec_back.output_schema, Some(schema));

        // So does an outcome carrying a structured payload.
        let full_outcome = StepOutcome {
            task_id: "t1".into(),
            session_id: "task-run-t1".into(),
            agent: "explore".into(),
            output: "ok".into(),
            success: true,
            structured: Some(serde_json::json!({ "v": "x" })),
        };
        let outcome_json = serde_json::to_string(&full_outcome).unwrap();
        let outcome_back: StepOutcome = serde_json::from_str(&outcome_json).unwrap();
        assert_eq!(outcome_back, full_outcome);

        // Payloads without the optional fields still parse, yielding `None`.
        let legacy_spec: AgentStepSpec =
            serde_json::from_str(r#"{"task_id":"t","agent":"a","description":"d","prompt":"p"}"#)
                .unwrap();
        assert_eq!(legacy_spec.output_schema, None);
        let legacy_outcome: StepOutcome = serde_json::from_str(
            r#"{"task_id":"t","session_id":"s","agent":"a","output":"o","success":true}"#,
        )
        .unwrap();
        assert_eq!(legacy_outcome.structured, None);
    }
}