1use super::{Tool, ToolResult};
7use crate::provider::{ProviderRegistry, parse_model_string};
8use crate::swarm::executor::run_agent_loop;
9use crate::tool::ToolRegistry;
10use anyhow::{Context, Result};
11use async_trait::async_trait;
12use serde_json::{Value, json};
13use std::sync::Arc;
14
/// Tool that runs a batch of tasks on concurrently executing sub-agents.
///
/// The tool itself is stateless; all configuration arrives per call.
#[derive(Default)]
pub struct SwarmExecuteTool;

impl SwarmExecuteTool {
    /// Creates the (stateless) tool; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
28
/// One validated entry from the `tasks` parameter of `swarm_execute`.
#[derive(Clone, Debug)]
struct TaskInput {
    /// Caller-supplied task id; `execute` generates a `task_<uuid>` id when absent.
    id: Option<String>,
    /// Human-readable task name (required by the parameter schema).
    name: String,
    /// Instruction handed to the sub-agent (required by the parameter schema).
    instruction: String,
    /// Optional specialty; it is embedded in the sub-agent's user prompt, so it is
    /// live data and needs no `dead_code` allowance.
    specialty: Option<String>,
}
37
/// Outcome of one sub-agent task, serialized into the `results` array of the
/// `swarm_execute` response.
///
/// Field order is kept as-is: it determines serialized key order when serde_json
/// preserves insertion order.
#[derive(serde::Serialize)]
struct TaskResult {
    /// Task id (caller-supplied or generated).
    task_id: String,
    /// Human-readable task name from the request.
    task_name: String,
    /// Whether the sub-agent run counted as successful.
    success: bool,
    /// Output string returned by `run_agent_loop`.
    output: String,
    /// Failure description when `success` is false; `None` otherwise.
    error: Option<String>,
    /// Step count reported by `run_agent_loop`.
    steps: usize,
    /// Tool-call count reported by `run_agent_loop`.
    tool_calls: usize,
}
48
49#[async_trait]
50impl Tool for SwarmExecuteTool {
51 fn id(&self) -> &str {
52 "swarm_execute"
53 }
54
55 fn name(&self) -> &str {
56 "Swarm Execute"
57 }
58
59 fn description(&self) -> &str {
60 "Execute multiple tasks in parallel across multiple sub-agents. \
61 Each task runs independently in its own agent context. \
62 Returns aggregated results from all swarm participants. \
63 Handles partial failures gracefully based on aggregation strategy."
64 }
65
66 fn parameters(&self) -> Value {
67 json!({
68 "type": "object",
69 "properties": {
70 "tasks": {
71 "type": "array",
72 "items": {
73 "type": "object",
74 "properties": {
75 "id": {
76 "type": "string",
77 "description": "Unique identifier for this task (auto-generated if not provided)"
78 },
79 "name": {
80 "type": "string",
81 "description": "Human-readable name for this task"
82 },
83 "instruction": {
84 "type": "string",
85 "description": "The instruction for the sub-agent to execute"
86 },
87 "specialty": {
88 "type": "string",
89 "description": "Optional specialty for the sub-agent (e.g., 'Code Writer', 'Researcher', 'Tester')"
90 }
91 },
92 "required": ["name", "instruction"]
93 },
94 "description": "Array of tasks to execute in parallel"
95 },
96 "concurrency_limit": {
97 "type": "integer",
98 "description": "Maximum number of concurrent agents (default: 5)",
99 "default": 5
100 },
101 "aggregation_strategy": {
102 "type": "string",
103 "enum": ["all", "first_error", "best_effort"],
104 "description": "How to aggregate results: 'all' (all must succeed), 'first_error' (stop on first error), 'best_effort' (collect all, report failures)",
105 "default": "best_effort"
106 },
107 "model": {
108 "type": "string",
109 "description": "Model to use for sub-agents (provider/model format, e.g., 'anthropic/claude-sonnet-4-20250514'). Defaults to configured default."
110 },
111 "max_steps": {
112 "type": "integer",
113 "description": "Maximum steps per sub-agent (default: 50)",
114 "default": 50
115 },
116 "timeout_secs": {
117 "type": "integer",
118 "description": "Timeout per sub-agent in seconds (default: 300)",
119 "default": 300
120 }
121 },
122 "required": ["tasks"]
123 })
124 }
125
126 async fn execute(&self, params: Value) -> Result<ToolResult> {
127 let example = json!({
128 "tasks": [{"name": "Task 1", "instruction": "Do something"}],
129 "concurrency_limit": 5,
130 "aggregation_strategy": "best_effort"
131 });
132
133 let tasks_val = match params.get("tasks").and_then(|v| v.as_array()) {
135 Some(arr) if !arr.is_empty() => arr,
136 Some(_) => {
137 return Ok(ToolResult::structured_error(
138 "INVALID_FIELD",
139 "swarm_execute",
140 "tasks array must contain at least one task",
141 Some(vec!["tasks"]),
142 Some(example),
143 ));
144 }
145 None => {
146 return Ok(ToolResult::structured_error(
147 "MISSING_FIELD",
148 "swarm_execute",
149 "tasks is required and must be an array of task objects with 'name' and 'instruction' fields",
150 Some(vec!["tasks"]),
151 Some(example),
152 ));
153 }
154 };
155
156 let mut tasks = Vec::new();
157 for (i, task_val) in tasks_val.iter().enumerate() {
158 let name = match task_val.get("name").and_then(|v| v.as_str()) {
159 Some(s) => s.to_string(),
160 None => {
161 return Ok(ToolResult::structured_error(
162 "INVALID_FIELD",
163 "swarm_execute",
164 &format!("tasks[{i}].name is required and must be a string"),
165 Some(vec!["name"]),
166 Some(json!({"name": "Task Name", "instruction": "Do something"})),
167 ));
168 }
169 };
170 let instruction = match task_val.get("instruction").and_then(|v| v.as_str()) {
171 Some(s) => s.to_string(),
172 None => {
173 return Ok(ToolResult::structured_error(
174 "INVALID_FIELD",
175 "swarm_execute",
176 &format!("tasks[{i}].instruction is required and must be a string"),
177 Some(vec!["instruction"]),
178 Some(json!({"name": name, "instruction": "What the sub-agent should do"})),
179 ));
180 }
181 };
182 tasks.push(TaskInput {
183 id: task_val
184 .get("id")
185 .and_then(|v| v.as_str())
186 .map(String::from),
187 name,
188 instruction,
189 specialty: task_val
190 .get("specialty")
191 .and_then(|v| v.as_str())
192 .map(String::from),
193 });
194 }
195
196 let concurrency_limit = params
197 .get("concurrency_limit")
198 .and_then(|v| v.as_u64())
199 .map(|v| v as usize)
200 .unwrap_or(5);
201 let aggregation_strategy = params
202 .get("aggregation_strategy")
203 .and_then(|v| v.as_str())
204 .unwrap_or("best_effort")
205 .to_string();
206 let model = params
207 .get("model")
208 .and_then(|v| v.as_str())
209 .map(String::from);
210 let max_steps = params
211 .get("max_steps")
212 .and_then(|v| v.as_u64())
213 .map(|v| v as usize)
214 .unwrap_or(50);
215 let timeout_secs = params
216 .get("timeout_secs")
217 .and_then(|v| v.as_u64())
218 .unwrap_or(300);
219
220 let concurrency = concurrency_limit.min(20).max(1);
221
222 tracing::info!(
223 task_count = tasks.len(),
224 concurrency = concurrency,
225 strategy = %aggregation_strategy,
226 "Starting swarm execution"
227 );
228
229 let providers = ProviderRegistry::from_vault()
231 .await
232 .context("Failed to load providers")?;
233 let provider_list = providers.list();
234
235 if provider_list.is_empty() {
236 return Ok(ToolResult::error(
237 "No providers available for swarm execution",
238 ));
239 }
240
241 let (provider_name, model_name) = if let Some(ref model_str) = model {
243 let (prov, mod_id) = parse_model_string(model_str);
244 let prov = prov.map(|p| if p == "zhipuai" { "zai" } else { p });
245 (
246 prov.filter(|p| provider_list.contains(p))
247 .unwrap_or(provider_list[0])
248 .to_string(),
249 mod_id.to_string(),
250 )
251 } else {
252 let provider = if provider_list.contains(&"zai") {
254 "zai".to_string()
255 } else if provider_list.contains(&"openrouter") {
256 "openrouter".to_string()
257 } else {
258 provider_list[0].to_string()
259 };
260 let model = "glm-5".to_string();
261 (provider, model)
262 };
263
264 let provider = providers
265 .get(&provider_name)
266 .context("Failed to get provider")?;
267
268 tracing::info!(provider = %provider_name, model = %model_name, "Using provider for swarm");
269
270 let tools = Self::get_subagent_tools();
272
273 let system_prompt = r#"You are a sub-agent in a swarm execution context.
275Your role is to execute the given task independently and report your results.
276Focus on completing your specific task efficiently.
277Use available tools to accomplish your goal.
278When done, provide a clear summary of what you accomplished.
279Share any intermediate results using the swarm_share tool so other agents can benefit."#;
280
281 let semaphore = Arc::new(tokio::sync::Semaphore::new(concurrency));
283 let mut join_handles = Vec::new();
284
285 for task_input in tasks.clone() {
286 let semaphore = semaphore.clone();
287 let provider = provider.clone();
288 let tools = tools.clone();
289 let system_prompt = system_prompt.to_string();
290 let task_id = task_input
291 .id
292 .clone()
293 .unwrap_or_else(|| format!("task_{}", uuid::Uuid::new_v4()));
294 let model_name = model_name.clone();
295 let max_steps = max_steps;
296 let timeout_secs = timeout_secs;
297
298 let handle = tokio::spawn(async move {
299 let _permit = semaphore.acquire().await.unwrap();
300
301 let user_prompt = format!(
302 "Task: {}\nSpecialty: {}\n\nInstruction: {}",
303 task_input.name,
304 task_input
305 .specialty
306 .as_deref()
307 .unwrap_or("Generalist execution"),
308 task_input.instruction
309 );
310
311 let (output, steps, tool_calls, exit) = run_agent_loop(
312 provider,
313 &model_name,
314 &system_prompt,
315 &user_prompt,
316 tools,
317 Arc::new(ToolRegistry::new()),
318 max_steps,
319 timeout_secs,
320 None,
321 task_id.clone(),
322 None,
323 None,
324 )
325 .await?;
326
327 let success = matches!(exit, crate::swarm::executor::AgentLoopExit::Completed)
328 || matches!(exit, crate::swarm::executor::AgentLoopExit::MaxStepsReached);
329
330 Ok::<TaskResult, anyhow::Error>(TaskResult {
331 task_id,
332 task_name: task_input.name,
333 success,
334 output,
335 error: if success {
336 None
337 } else {
338 Some(format!("{:?}", exit))
339 },
340 steps,
341 tool_calls,
342 })
343 });
344
345 join_handles.push(handle);
346 }
347
348 let mut results: Vec<TaskResult> = Vec::new();
350 let mut failures = 0;
351
352 for handle in join_handles {
353 match handle.await {
354 Ok(Ok(result)) => {
355 if !result.success {
356 failures += 1;
357
358 match aggregation_strategy.as_str() {
360 "all" => {
361 return Ok(ToolResult::success(
363 json!({
364 "status": "failed",
365 "failed_task": result.task_name,
366 "error": result.error,
367 "results": [result],
368 "summary": {
369 "total": 1,
370 "success": 0,
371 "failures": 1
372 }
373 })
374 .to_string(),
375 ));
376 }
377 "first_error" => {
378 return Ok(ToolResult::success(
379 json!({
380 "status": "error",
381 "error": result.error,
382 "failed_task": result.task_name,
383 "completed_tasks": results.len(),
384 "results": results,
385 })
386 .to_string(),
387 ));
388 }
389 _ => {} }
391 }
392 results.push(result);
393 }
394 Ok(Err(e)) => {
395 failures += 1;
396 tracing::error!(error = %e, "Task execution failed");
397 }
398 Err(e) => {
399 failures += 1;
400 tracing::error!(error = %e, "Task join failed");
401 }
402 }
403 }
404
405 let total = results.len();
407 let successes = results.iter().filter(|r| r.success).count();
408
409 let response = if failures == 0 {
410 json!({
411 "status": "success",
412 "results": results,
413 "summary": {
414 "total": total,
415 "success": successes,
416 "failures": failures
417 }
418 })
419 } else {
420 match aggregation_strategy.as_str() {
421 "all" => json!({
422 "status": "partial_failure",
423 "results": results,
424 "summary": {
425 "total": total,
426 "success": successes,
427 "failures": failures
428 }
429 }),
430 "first_error" => json!({
431 "status": "error",
432 "results": results,
433 "summary": {
434 "total": total,
435 "success": successes,
436 "failures": failures
437 }
438 }),
439 _ => json!({
440 "status": "partial_success",
441 "results": results,
442 "summary": {
443 "total": total,
444 "success": successes,
445 "failures": failures
446 }
447 }),
448 }
449 };
450
451 Ok(ToolResult::success(response.to_string()))
452 }
453}
454
455impl SwarmExecuteTool {
456 fn get_subagent_tools() -> Vec<crate::provider::ToolDefinition> {
458 let registry = ToolRegistry::new();
460 registry
461 .definitions()
462 .into_iter()
463 .filter(|t| {
464 !matches!(
465 t.name.as_str(),
466 "question"
467 | "confirm_edit"
468 | "confirm_multiedit"
469 | "plan_enter"
470 | "plan_exit"
471 | "swarm_execute"
472 | "agent"
473 )
474 })
475 .collect()
476 }
477}