Skip to main content

codetether_agent/tool/
swarm_execute.rs

//! Swarm Execute Tool - Parallel task execution across multiple agents
//!
//! This tool enables LLM-driven parallel execution of tasks across multiple
//! sub-agents in a swarm pattern, with configurable concurrency and result-aggregation strategies.
5
6use super::{Tool, ToolResult};
7use crate::provider::{ProviderRegistry, parse_model_string};
8use crate::swarm::executor::run_agent_loop;
9use crate::swarm::orchestrator::{choose_default_provider, default_model_for_provider};
10use crate::tool::ToolRegistry;
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use serde_json::{Value, json};
14use std::sync::Arc;
15
/// Tool that fans a list of tasks out to concurrently running sub-agents.
///
/// The type is a stateless marker: everything the tool needs arrives through
/// its call parameters, so constructing it is free.
pub struct SwarmExecuteTool;

impl SwarmExecuteTool {
    /// Creates the (stateless) tool instance.
    pub fn new() -> Self {
        SwarmExecuteTool
    }
}

impl Default for SwarmExecuteTool {
    /// Equivalent to [`SwarmExecuteTool::new`].
    fn default() -> Self {
        SwarmExecuteTool::new()
    }
}
29
/// One validated entry from the `tasks` parameter of `swarm_execute`.
#[derive(Debug, Clone)]
struct TaskInput {
    /// Caller-supplied identifier; a `task_<uuid>` id is generated when absent.
    id: Option<String>,
    /// Human-readable task name (required).
    name: String,
    /// Instruction handed to the sub-agent (required).
    instruction: String,
    /// Optional specialty label. It is interpolated into the sub-agent's user
    /// prompt (falling back to "Generalist execution"), so it is live data and
    /// needs no `dead_code` allowance.
    specialty: Option<String>,
}
38
/// Per-task outcome reported back to the calling LLM; serialized into the
/// `results` array of the tool's JSON response.
#[derive(serde::Serialize)]
struct TaskResult {
    /// Caller-supplied task id, or a generated `task_<uuid>` when none was given.
    task_id: String,
    /// Human-readable task name copied from the task input.
    task_name: String,
    /// Whether the sub-agent run counts as successful for aggregation purposes.
    success: bool,
    /// Output text returned by `run_agent_loop` for this task.
    output: String,
    /// Failure description; `None` whenever `success` is true.
    error: Option<String>,
    /// Agent-loop step count as reported by `run_agent_loop`.
    steps: usize,
    /// Tool-call count as reported by `run_agent_loop`.
    tool_calls: usize,
}
49
50#[async_trait]
51impl Tool for SwarmExecuteTool {
52    fn id(&self) -> &str {
53        "swarm_execute"
54    }
55
56    fn name(&self) -> &str {
57        "Swarm Execute"
58    }
59
60    fn description(&self) -> &str {
61        "Execute multiple tasks in parallel across multiple sub-agents. \
62         Each task runs independently in its own agent context. \
63         Returns aggregated results from all swarm participants. \
64         Handles partial failures gracefully based on aggregation strategy."
65    }
66
67    fn parameters(&self) -> Value {
68        json!({
69            "type": "object",
70            "properties": {
71                "tasks": {
72                    "type": "array",
73                    "items": {
74                        "type": "object",
75                        "properties": {
76                            "id": {
77                                "type": "string",
78                                "description": "Unique identifier for this task (auto-generated if not provided)"
79                            },
80                            "name": {
81                                "type": "string",
82                                "description": "Human-readable name for this task"
83                            },
84                            "instruction": {
85                                "type": "string",
86                                "description": "The instruction for the sub-agent to execute"
87                            },
88                            "specialty": {
89                                "type": "string",
90                                "description": "Optional specialty for the sub-agent (e.g., 'Code Writer', 'Researcher', 'Tester')"
91                            }
92                        },
93                        "required": ["name", "instruction"]
94                    },
95                    "description": "Array of tasks to execute in parallel"
96                },
97                "concurrency_limit": {
98                    "type": "integer",
99                    "description": "Maximum number of concurrent agents (default: 5)",
100                    "default": 5
101                },
102                "aggregation_strategy": {
103                    "type": "string",
104                    "enum": ["all", "first_error", "best_effort"],
105                    "description": "How to aggregate results: 'all' (all must succeed), 'first_error' (stop on first error), 'best_effort' (collect all, report failures)",
106                    "default": "best_effort"
107                },
108                "model": {
109                    "type": "string",
110                    "description": "Model to use for sub-agents (provider/model format, e.g., 'anthropic/claude-sonnet-4-20250514'). Defaults to configured default."
111                },
112                "max_steps": {
113                    "type": "integer",
114                    "description": "Maximum steps per sub-agent (default: 50)",
115                    "default": 50
116                },
117                "timeout_secs": {
118                    "type": "integer",
119                    "description": "Timeout per sub-agent in seconds (default: 300)",
120                    "default": 300
121                }
122            },
123            "required": ["tasks"]
124        })
125    }
126
127    async fn execute(&self, params: Value) -> Result<ToolResult> {
128        let example = json!({
129            "tasks": [{"name": "Task 1", "instruction": "Do something"}],
130            "concurrency_limit": 5,
131            "aggregation_strategy": "best_effort"
132        });
133
134        // Parse tasks array
135        let tasks_val = match params.get("tasks").and_then(|v| v.as_array()) {
136            Some(arr) if !arr.is_empty() => arr,
137            Some(_) => {
138                return Ok(ToolResult::structured_error(
139                    "INVALID_FIELD",
140                    "swarm_execute",
141                    "tasks array must contain at least one task",
142                    Some(vec!["tasks"]),
143                    Some(example),
144                ));
145            }
146            None => {
147                return Ok(ToolResult::structured_error(
148                    "MISSING_FIELD",
149                    "swarm_execute",
150                    "tasks is required and must be an array of task objects with 'name' and 'instruction' fields",
151                    Some(vec!["tasks"]),
152                    Some(example),
153                ));
154            }
155        };
156
157        let mut tasks = Vec::new();
158        for (i, task_val) in tasks_val.iter().enumerate() {
159            let name = match task_val.get("name").and_then(|v| v.as_str()) {
160                Some(s) => s.to_string(),
161                None => {
162                    return Ok(ToolResult::structured_error(
163                        "INVALID_FIELD",
164                        "swarm_execute",
165                        &format!("tasks[{i}].name is required and must be a string"),
166                        Some(vec!["name"]),
167                        Some(json!({"name": "Task Name", "instruction": "Do something"})),
168                    ));
169                }
170            };
171            let instruction = match task_val.get("instruction").and_then(|v| v.as_str()) {
172                Some(s) => s.to_string(),
173                None => {
174                    return Ok(ToolResult::structured_error(
175                        "INVALID_FIELD",
176                        "swarm_execute",
177                        &format!("tasks[{i}].instruction is required and must be a string"),
178                        Some(vec!["instruction"]),
179                        Some(json!({"name": name, "instruction": "What the sub-agent should do"})),
180                    ));
181                }
182            };
183            tasks.push(TaskInput {
184                id: task_val
185                    .get("id")
186                    .and_then(|v| v.as_str())
187                    .map(String::from),
188                name,
189                instruction,
190                specialty: task_val
191                    .get("specialty")
192                    .and_then(|v| v.as_str())
193                    .map(String::from),
194            });
195        }
196
197        let concurrency_limit = params
198            .get("concurrency_limit")
199            .and_then(|v| v.as_u64())
200            .map(|v| v as usize)
201            .unwrap_or(5);
202        let aggregation_strategy = params
203            .get("aggregation_strategy")
204            .and_then(|v| v.as_str())
205            .unwrap_or("best_effort")
206            .to_string();
207        let model = params
208            .get("model")
209            .and_then(|v| v.as_str())
210            .map(String::from);
211        let max_steps = params
212            .get("max_steps")
213            .and_then(|v| v.as_u64())
214            .map(|v| v as usize)
215            .unwrap_or(50);
216        let timeout_secs = params
217            .get("timeout_secs")
218            .and_then(|v| v.as_u64())
219            .unwrap_or(300);
220
221        let concurrency = concurrency_limit.min(20).max(1);
222
223        tracing::info!(
224            task_count = tasks.len(),
225            concurrency = concurrency,
226            strategy = %aggregation_strategy,
227            "Starting swarm execution"
228        );
229
230        // Get provider registry from Vault
231        let providers = ProviderRegistry::from_vault()
232            .await
233            .context("Failed to load providers")?;
234        let provider_list = providers.list();
235
236        if provider_list.is_empty() {
237            return Ok(ToolResult::error(
238                "No providers available for swarm execution",
239            ));
240        }
241
242        // Determine provider/model to use
243        let (provider_name, model_name) = if let Some(ref model_str) = model {
244            let (prov, mod_id) = parse_model_string(model_str);
245            let prov = prov.map(|p| if p == "zhipuai" { "zai" } else { p });
246            let provider_name = if let Some(explicit_provider) = prov {
247                if provider_list.contains(&explicit_provider) {
248                    explicit_provider.to_string()
249                } else {
250                    return Ok(ToolResult::error(format!(
251                        "Provider '{}' selected explicitly but is unavailable. Available providers: {}",
252                        explicit_provider,
253                        provider_list.join(", ")
254                    )));
255                }
256            } else {
257                choose_default_provider(provider_list.as_slice())
258                    .ok_or_else(|| anyhow::anyhow!("No providers available for swarm execution"))?
259                    .to_string()
260            };
261            let model_name = if mod_id.trim().is_empty() {
262                default_model_for_provider(&provider_name)
263            } else {
264                mod_id.to_string()
265            };
266            (provider_name, model_name)
267        } else {
268            let provider_name = choose_default_provider(provider_list.as_slice())
269                .ok_or_else(|| anyhow::anyhow!("No providers available for swarm execution"))?
270                .to_string();
271            let model_name = default_model_for_provider(&provider_name);
272            (provider_name, model_name)
273        };
274
275        let provider = providers
276            .get(&provider_name)
277            .context("Failed to get provider")?;
278
279        tracing::info!(provider = %provider_name, model = %model_name, "Using provider for swarm");
280
281        // Get tool definitions (filtered for sub-agents)
282        let tools = Self::get_subagent_tools();
283
284        // System prompt for sub-agents
285        let system_prompt = r#"You are a sub-agent in a swarm execution context.
286Your role is to execute the given task independently and report your results.
287Focus on completing your specific task efficiently.
288Use available tools to accomplish your goal.
289When done, provide a clear summary of what you accomplished.
290Share any intermediate results using the swarm_share tool so other agents can benefit."#;
291
292        // Execute tasks concurrently using semaphore for rate limiting
293        let semaphore = Arc::new(tokio::sync::Semaphore::new(concurrency));
294        let mut join_handles = Vec::new();
295
296        for task_input in tasks.clone() {
297            let semaphore = semaphore.clone();
298            let provider = provider.clone();
299            let tools = tools.clone();
300            let system_prompt = system_prompt.to_string();
301            let task_id = task_input
302                .id
303                .clone()
304                .unwrap_or_else(|| format!("task_{}", uuid::Uuid::new_v4()));
305            let model_name = model_name.clone();
306            let max_steps = max_steps;
307            let timeout_secs = timeout_secs;
308
309            let handle = tokio::spawn(async move {
310                let _permit = semaphore
311                    .acquire()
312                    .await
313                    .expect("swarm semaphore closed unexpectedly");
314
315                let user_prompt = format!(
316                    "Task: {}\nSpecialty: {}\n\nInstruction: {}",
317                    task_input.name,
318                    task_input
319                        .specialty
320                        .as_deref()
321                        .unwrap_or("Generalist execution"),
322                    task_input.instruction
323                );
324
325                let (output, steps, tool_calls, exit) = run_agent_loop(
326                    provider,
327                    &model_name,
328                    &system_prompt,
329                    &user_prompt,
330                    tools,
331                    Arc::new(ToolRegistry::new()),
332                    max_steps,
333                    timeout_secs,
334                    None,
335                    task_id.clone(),
336                    None,
337                    None,
338                )
339                .await?;
340
341                let success = matches!(exit, crate::swarm::executor::AgentLoopExit::Completed)
342                    || matches!(exit, crate::swarm::executor::AgentLoopExit::MaxStepsReached);
343
344                Ok::<TaskResult, anyhow::Error>(TaskResult {
345                    task_id,
346                    task_name: task_input.name,
347                    success,
348                    output,
349                    error: if success {
350                        None
351                    } else {
352                        Some(format!("{:?}", exit))
353                    },
354                    steps,
355                    tool_calls,
356                })
357            });
358
359            join_handles.push(handle);
360        }
361
362        // Wait for all tasks to complete
363        let mut results: Vec<TaskResult> = Vec::new();
364        let mut failures = 0;
365
366        for handle in join_handles {
367            match handle.await {
368                Ok(Ok(result)) => {
369                    if !result.success {
370                        failures += 1;
371
372                        // Handle aggregation strategies
373                        match aggregation_strategy.as_str() {
374                            "all" => {
375                                // Return immediately on first failure
376                                return Ok(ToolResult::success(
377                                    json!({
378                                        "status": "failed",
379                                        "failed_task": result.task_name,
380                                        "error": result.error,
381                                        "results": [result],
382                                        "summary": {
383                                            "total": 1,
384                                            "success": 0,
385                                            "failures": 1
386                                        }
387                                    })
388                                    .to_string(),
389                                ));
390                            }
391                            "first_error" => {
392                                return Ok(ToolResult::success(
393                                    json!({
394                                        "status": "error",
395                                        "error": result.error,
396                                        "failed_task": result.task_name,
397                                        "completed_tasks": results.len(),
398                                        "results": results,
399                                    })
400                                    .to_string(),
401                                ));
402                            }
403                            _ => {} // "best_effort" - continue collecting
404                        }
405                    }
406                    results.push(result);
407                }
408                Ok(Err(e)) => {
409                    failures += 1;
410                    tracing::error!(error = %e, "Task execution failed");
411                }
412                Err(e) => {
413                    failures += 1;
414                    tracing::error!(error = %e, "Task join failed");
415                }
416            }
417        }
418
419        // Build aggregation response
420        let total = results.len();
421        let successes = results.iter().filter(|r| r.success).count();
422
423        let response = if failures == 0 {
424            json!({
425                "status": "success",
426                "results": results,
427                "summary": {
428                    "total": total,
429                    "success": successes,
430                    "failures": failures
431                }
432            })
433        } else {
434            match aggregation_strategy.as_str() {
435                "all" => json!({
436                    "status": "partial_failure",
437                    "results": results,
438                    "summary": {
439                        "total": total,
440                        "success": successes,
441                        "failures": failures
442                    }
443                }),
444                "first_error" => json!({
445                    "status": "error",
446                    "results": results,
447                    "summary": {
448                        "total": total,
449                        "success": successes,
450                        "failures": failures
451                    }
452                }),
453                _ => json!({
454                    "status": "partial_success",
455                    "results": results,
456                    "summary": {
457                        "total": total,
458                        "success": successes,
459                        "failures": failures
460                    }
461                }),
462            }
463        };
464
465        Ok(ToolResult::success(response.to_string()))
466    }
467}
468
469impl SwarmExecuteTool {
470    /// Get tool definitions suitable for sub-agents
471    fn get_subagent_tools() -> Vec<crate::provider::ToolDefinition> {
472        // Filter out interactive/blocking tools that don't work well for sub-agents
473        let registry = ToolRegistry::new();
474        registry
475            .definitions()
476            .into_iter()
477            .filter(|t| {
478                !matches!(
479                    t.name.as_str(),
480                    "question"
481                        | "confirm_edit"
482                        | "confirm_multiedit"
483                        | "plan_enter"
484                        | "plan_exit"
485                        | "swarm_execute"
486                        | "agent"
487                )
488            })
489            .collect()
490    }
491}