1use super::{Tool, ToolResult};
7use crate::provider::{ProviderRegistry, parse_model_string};
8use crate::swarm::executor::run_agent_loop;
9use crate::swarm::orchestrator::{choose_default_provider, default_model_for_provider};
10use crate::tool::ToolRegistry;
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use serde_json::{Value, json};
14use std::sync::Arc;
15
/// Tool that fans a list of tasks out to sub-agents and aggregates their results.
///
/// The type carries no state; all configuration arrives per call through the
/// tool parameters.
pub struct SwarmExecuteTool;

impl SwarmExecuteTool {
    /// Builds the tool instance.
    pub fn new() -> Self {
        SwarmExecuteTool
    }
}
23
24impl Default for SwarmExecuteTool {
25 fn default() -> Self {
26 Self::new()
27 }
28}
29
/// One validated entry of the `tasks` array passed to `swarm_execute`.
#[derive(Clone, Debug)]
struct TaskInput {
    /// Caller-supplied task id; when absent a `task_<uuid>` id is generated at spawn time.
    id: Option<String>,
    /// Human-readable task name (required).
    name: String,
    /// Instruction handed to the sub-agent (required).
    instruction: String,
    /// Optional specialty label that is included in the sub-agent's user prompt.
    specialty: Option<String>,
}
38
39#[derive(serde::Serialize)]
40struct TaskResult {
41 task_id: String,
42 task_name: String,
43 success: bool,
44 output: String,
45 error: Option<String>,
46 steps: usize,
47 tool_calls: usize,
48}
49
50#[async_trait]
51impl Tool for SwarmExecuteTool {
52 fn id(&self) -> &str {
53 "swarm_execute"
54 }
55
56 fn name(&self) -> &str {
57 "Swarm Execute"
58 }
59
60 fn description(&self) -> &str {
61 "Execute multiple tasks in parallel across multiple sub-agents. \
62 Each task runs independently in its own agent context. \
63 Returns aggregated results from all swarm participants. \
64 Handles partial failures gracefully based on aggregation strategy."
65 }
66
67 fn parameters(&self) -> Value {
68 json!({
69 "type": "object",
70 "properties": {
71 "tasks": {
72 "type": "array",
73 "items": {
74 "type": "object",
75 "properties": {
76 "id": {
77 "type": "string",
78 "description": "Unique identifier for this task (auto-generated if not provided)"
79 },
80 "name": {
81 "type": "string",
82 "description": "Human-readable name for this task"
83 },
84 "instruction": {
85 "type": "string",
86 "description": "The instruction for the sub-agent to execute"
87 },
88 "specialty": {
89 "type": "string",
90 "description": "Optional specialty for the sub-agent (e.g., 'Code Writer', 'Researcher', 'Tester')"
91 }
92 },
93 "required": ["name", "instruction"]
94 },
95 "description": "Array of tasks to execute in parallel"
96 },
97 "concurrency_limit": {
98 "type": "integer",
99 "description": "Maximum number of concurrent agents (default: 5)",
100 "default": 5
101 },
102 "aggregation_strategy": {
103 "type": "string",
104 "enum": ["all", "first_error", "best_effort"],
105 "description": "How to aggregate results: 'all' (all must succeed), 'first_error' (stop on first error), 'best_effort' (collect all, report failures)",
106 "default": "best_effort"
107 },
108 "model": {
109 "type": "string",
110 "description": "Model to use for sub-agents (provider/model format, e.g., 'anthropic/claude-sonnet-4-20250514'). Defaults to configured default."
111 },
112 "max_steps": {
113 "type": "integer",
114 "description": "Maximum steps per sub-agent (default: 50)",
115 "default": 50
116 },
117 "timeout_secs": {
118 "type": "integer",
119 "description": "Timeout per sub-agent in seconds (default: 300)",
120 "default": 300
121 }
122 },
123 "required": ["tasks"]
124 })
125 }
126
127 async fn execute(&self, params: Value) -> Result<ToolResult> {
128 let example = json!({
129 "tasks": [{"name": "Task 1", "instruction": "Do something"}],
130 "concurrency_limit": 5,
131 "aggregation_strategy": "best_effort"
132 });
133
134 let tasks_val = match params.get("tasks").and_then(|v| v.as_array()) {
136 Some(arr) if !arr.is_empty() => arr,
137 Some(_) => {
138 return Ok(ToolResult::structured_error(
139 "INVALID_FIELD",
140 "swarm_execute",
141 "tasks array must contain at least one task",
142 Some(vec!["tasks"]),
143 Some(example),
144 ));
145 }
146 None => {
147 return Ok(ToolResult::structured_error(
148 "MISSING_FIELD",
149 "swarm_execute",
150 "tasks is required and must be an array of task objects with 'name' and 'instruction' fields",
151 Some(vec!["tasks"]),
152 Some(example),
153 ));
154 }
155 };
156
157 let mut tasks = Vec::new();
158 for (i, task_val) in tasks_val.iter().enumerate() {
159 let name = match task_val.get("name").and_then(|v| v.as_str()) {
160 Some(s) => s.to_string(),
161 None => {
162 return Ok(ToolResult::structured_error(
163 "INVALID_FIELD",
164 "swarm_execute",
165 &format!("tasks[{i}].name is required and must be a string"),
166 Some(vec!["name"]),
167 Some(json!({"name": "Task Name", "instruction": "Do something"})),
168 ));
169 }
170 };
171 let instruction = match task_val.get("instruction").and_then(|v| v.as_str()) {
172 Some(s) => s.to_string(),
173 None => {
174 return Ok(ToolResult::structured_error(
175 "INVALID_FIELD",
176 "swarm_execute",
177 &format!("tasks[{i}].instruction is required and must be a string"),
178 Some(vec!["instruction"]),
179 Some(json!({"name": name, "instruction": "What the sub-agent should do"})),
180 ));
181 }
182 };
183 tasks.push(TaskInput {
184 id: task_val
185 .get("id")
186 .and_then(|v| v.as_str())
187 .map(String::from),
188 name,
189 instruction,
190 specialty: task_val
191 .get("specialty")
192 .and_then(|v| v.as_str())
193 .map(String::from),
194 });
195 }
196
197 let concurrency_limit = params
198 .get("concurrency_limit")
199 .and_then(|v| v.as_u64())
200 .map(|v| v as usize)
201 .unwrap_or(5);
202 let aggregation_strategy = params
203 .get("aggregation_strategy")
204 .and_then(|v| v.as_str())
205 .unwrap_or("best_effort")
206 .to_string();
207 let model = params
208 .get("model")
209 .and_then(|v| v.as_str())
210 .map(String::from);
211 let max_steps = params
212 .get("max_steps")
213 .and_then(|v| v.as_u64())
214 .map(|v| v as usize)
215 .unwrap_or(50);
216 let timeout_secs = params
217 .get("timeout_secs")
218 .and_then(|v| v.as_u64())
219 .unwrap_or(300);
220
221 let concurrency = concurrency_limit.min(20).max(1);
222
223 tracing::info!(
224 task_count = tasks.len(),
225 concurrency = concurrency,
226 strategy = %aggregation_strategy,
227 "Starting swarm execution"
228 );
229
230 let providers = ProviderRegistry::from_vault()
232 .await
233 .context("Failed to load providers")?;
234 let provider_list = providers.list();
235
236 if provider_list.is_empty() {
237 return Ok(ToolResult::error(
238 "No providers available for swarm execution",
239 ));
240 }
241
242 let (provider_name, model_name) = if let Some(ref model_str) = model {
244 let (prov, mod_id) = parse_model_string(model_str);
245 let prov = prov.map(|p| if p == "zhipuai" { "zai" } else { p });
246 let provider_name = if let Some(explicit_provider) = prov {
247 if provider_list.contains(&explicit_provider) {
248 explicit_provider.to_string()
249 } else {
250 return Ok(ToolResult::error(format!(
251 "Provider '{}' selected explicitly but is unavailable. Available providers: {}",
252 explicit_provider,
253 provider_list.join(", ")
254 )));
255 }
256 } else {
257 choose_default_provider(provider_list.as_slice())
258 .ok_or_else(|| anyhow::anyhow!("No providers available for swarm execution"))?
259 .to_string()
260 };
261 let model_name = if mod_id.trim().is_empty() {
262 default_model_for_provider(&provider_name)
263 } else {
264 mod_id.to_string()
265 };
266 (provider_name, model_name)
267 } else {
268 let provider_name = choose_default_provider(provider_list.as_slice())
269 .ok_or_else(|| anyhow::anyhow!("No providers available for swarm execution"))?
270 .to_string();
271 let model_name = default_model_for_provider(&provider_name);
272 (provider_name, model_name)
273 };
274
275 let provider = providers
276 .get(&provider_name)
277 .context("Failed to get provider")?;
278
279 tracing::info!(provider = %provider_name, model = %model_name, "Using provider for swarm");
280
281 let tools = Self::get_subagent_tools();
283
284 let system_prompt = r#"You are a sub-agent in a swarm execution context.
286Your role is to execute the given task independently and report your results.
287Focus on completing your specific task efficiently.
288Use available tools to accomplish your goal.
289When done, provide a clear summary of what you accomplished.
290Share any intermediate results using the swarm_share tool so other agents can benefit."#;
291
292 let semaphore = Arc::new(tokio::sync::Semaphore::new(concurrency));
294 let mut join_handles = Vec::new();
295
296 for task_input in tasks.clone() {
297 let semaphore = semaphore.clone();
298 let provider = provider.clone();
299 let tools = tools.clone();
300 let system_prompt = system_prompt.to_string();
301 let task_id = task_input
302 .id
303 .clone()
304 .unwrap_or_else(|| format!("task_{}", uuid::Uuid::new_v4()));
305 let model_name = model_name.clone();
306 let max_steps = max_steps;
307 let timeout_secs = timeout_secs;
308
309 let handle = tokio::spawn(async move {
310 let _permit = semaphore
311 .acquire()
312 .await
313 .expect("swarm semaphore closed unexpectedly");
314
315 let user_prompt = format!(
316 "Task: {}\nSpecialty: {}\n\nInstruction: {}",
317 task_input.name,
318 task_input
319 .specialty
320 .as_deref()
321 .unwrap_or("Generalist execution"),
322 task_input.instruction
323 );
324
325 let (output, steps, tool_calls, exit) = run_agent_loop(
326 provider,
327 &model_name,
328 &system_prompt,
329 &user_prompt,
330 tools,
331 Arc::new(ToolRegistry::new()),
332 max_steps,
333 timeout_secs,
334 None,
335 task_id.clone(),
336 None,
337 None,
338 )
339 .await?;
340
341 let success = matches!(exit, crate::swarm::executor::AgentLoopExit::Completed)
342 || matches!(exit, crate::swarm::executor::AgentLoopExit::MaxStepsReached);
343
344 Ok::<TaskResult, anyhow::Error>(TaskResult {
345 task_id,
346 task_name: task_input.name,
347 success,
348 output,
349 error: if success {
350 None
351 } else {
352 Some(format!("{:?}", exit))
353 },
354 steps,
355 tool_calls,
356 })
357 });
358
359 join_handles.push(handle);
360 }
361
362 let mut results: Vec<TaskResult> = Vec::new();
364 let mut failures = 0;
365
366 for handle in join_handles {
367 match handle.await {
368 Ok(Ok(result)) => {
369 if !result.success {
370 failures += 1;
371
372 match aggregation_strategy.as_str() {
374 "all" => {
375 return Ok(ToolResult::success(
377 json!({
378 "status": "failed",
379 "failed_task": result.task_name,
380 "error": result.error,
381 "results": [result],
382 "summary": {
383 "total": 1,
384 "success": 0,
385 "failures": 1
386 }
387 })
388 .to_string(),
389 ));
390 }
391 "first_error" => {
392 return Ok(ToolResult::success(
393 json!({
394 "status": "error",
395 "error": result.error,
396 "failed_task": result.task_name,
397 "completed_tasks": results.len(),
398 "results": results,
399 })
400 .to_string(),
401 ));
402 }
403 _ => {} }
405 }
406 results.push(result);
407 }
408 Ok(Err(e)) => {
409 failures += 1;
410 tracing::error!(error = %e, "Task execution failed");
411 }
412 Err(e) => {
413 failures += 1;
414 tracing::error!(error = %e, "Task join failed");
415 }
416 }
417 }
418
419 let total = results.len();
421 let successes = results.iter().filter(|r| r.success).count();
422
423 let response = if failures == 0 {
424 json!({
425 "status": "success",
426 "results": results,
427 "summary": {
428 "total": total,
429 "success": successes,
430 "failures": failures
431 }
432 })
433 } else {
434 match aggregation_strategy.as_str() {
435 "all" => json!({
436 "status": "partial_failure",
437 "results": results,
438 "summary": {
439 "total": total,
440 "success": successes,
441 "failures": failures
442 }
443 }),
444 "first_error" => json!({
445 "status": "error",
446 "results": results,
447 "summary": {
448 "total": total,
449 "success": successes,
450 "failures": failures
451 }
452 }),
453 _ => json!({
454 "status": "partial_success",
455 "results": results,
456 "summary": {
457 "total": total,
458 "success": successes,
459 "failures": failures
460 }
461 }),
462 }
463 };
464
465 Ok(ToolResult::success(response.to_string()))
466 }
467}
468
469impl SwarmExecuteTool {
470 fn get_subagent_tools() -> Vec<crate::provider::ToolDefinition> {
472 let registry = ToolRegistry::new();
474 registry
475 .definitions()
476 .into_iter()
477 .filter(|t| {
478 !matches!(
479 t.name.as_str(),
480 "question"
481 | "confirm_edit"
482 | "confirm_multiedit"
483 | "plan_enter"
484 | "plan_exit"
485 | "swarm_execute"
486 | "agent"
487 )
488 })
489 .collect()
490 }
491}