Skip to main content

codetether_agent/tool/
swarm_execute.rs

1//! Swarm Execute Tool - Parallel task execution across multiple agents
2//!
3//! This tool enables LLM-driven parallel execution of tasks across multiple
4//! sub-agents in a swarm pattern, with configurable concurrency and aggregation.
5
6use super::{Tool, ToolResult};
7use crate::provider::{ProviderRegistry, parse_model_string};
8use crate::swarm::executor::run_agent_loop;
9use crate::tool::ToolRegistry;
10use anyhow::{Context, Result};
11use async_trait::async_trait;
12use serde_json::{Value, json};
13use std::sync::Arc;
14
/// Tool that fans a list of tasks out to parallel sub-agents and aggregates
/// their results (`swarm_execute`).
///
/// The type carries no state; every call to `execute` loads providers and
/// builds its own sub-agent context.
#[derive(Debug, Clone, Copy, Default)]
pub struct SwarmExecuteTool;

impl SwarmExecuteTool {
    /// Creates the tool. Equivalent to `SwarmExecuteTool::default()`.
    pub const fn new() -> Self {
        Self
    }
}
28
/// A single task parsed from the `tasks` array of a `swarm_execute` call.
#[derive(Debug, Clone)]
struct TaskInput {
    /// Optional caller-supplied id; `execute` generates `task_<uuid>` when absent.
    id: Option<String>,
    /// Human-readable task name (required in the input).
    name: String,
    /// Instruction handed to the sub-agent (required in the input).
    instruction: String,
    /// Optional specialty embedded in the sub-agent prompt; the prompt says
    /// "Generalist execution" when this is absent.
    specialty: Option<String>,
}
37
38#[derive(serde::Serialize)]
39struct TaskResult {
40    task_id: String,
41    task_name: String,
42    success: bool,
43    output: String,
44    error: Option<String>,
45    steps: usize,
46    tool_calls: usize,
47}
48
49#[async_trait]
50impl Tool for SwarmExecuteTool {
51    fn id(&self) -> &str {
52        "swarm_execute"
53    }
54
55    fn name(&self) -> &str {
56        "Swarm Execute"
57    }
58
59    fn description(&self) -> &str {
60        "Execute multiple tasks in parallel across multiple sub-agents. \
61         Each task runs independently in its own agent context. \
62         Returns aggregated results from all swarm participants. \
63         Handles partial failures gracefully based on aggregation strategy."
64    }
65
66    fn parameters(&self) -> Value {
67        json!({
68            "type": "object",
69            "properties": {
70                "tasks": {
71                    "type": "array",
72                    "items": {
73                        "type": "object",
74                        "properties": {
75                            "id": {
76                                "type": "string",
77                                "description": "Unique identifier for this task (auto-generated if not provided)"
78                            },
79                            "name": {
80                                "type": "string",
81                                "description": "Human-readable name for this task"
82                            },
83                            "instruction": {
84                                "type": "string",
85                                "description": "The instruction for the sub-agent to execute"
86                            },
87                            "specialty": {
88                                "type": "string",
89                                "description": "Optional specialty for the sub-agent (e.g., 'Code Writer', 'Researcher', 'Tester')"
90                            }
91                        },
92                        "required": ["name", "instruction"]
93                    },
94                    "description": "Array of tasks to execute in parallel"
95                },
96                "concurrency_limit": {
97                    "type": "integer",
98                    "description": "Maximum number of concurrent agents (default: 5)",
99                    "default": 5
100                },
101                "aggregation_strategy": {
102                    "type": "string",
103                    "enum": ["all", "first_error", "best_effort"],
104                    "description": "How to aggregate results: 'all' (all must succeed), 'first_error' (stop on first error), 'best_effort' (collect all, report failures)",
105                    "default": "best_effort"
106                },
107                "model": {
108                    "type": "string",
109                    "description": "Model to use for sub-agents (provider/model format, e.g., 'anthropic/claude-sonnet-4-20250514'). Defaults to configured default."
110                },
111                "max_steps": {
112                    "type": "integer",
113                    "description": "Maximum steps per sub-agent (default: 50)",
114                    "default": 50
115                },
116                "timeout_secs": {
117                    "type": "integer",
118                    "description": "Timeout per sub-agent in seconds (default: 300)",
119                    "default": 300
120                }
121            },
122            "required": ["tasks"]
123        })
124    }
125
126    async fn execute(&self, params: Value) -> Result<ToolResult> {
127        let example = json!({
128            "tasks": [{"name": "Task 1", "instruction": "Do something"}],
129            "concurrency_limit": 5,
130            "aggregation_strategy": "best_effort"
131        });
132
133        // Parse tasks array
134        let tasks_val = match params.get("tasks").and_then(|v| v.as_array()) {
135            Some(arr) if !arr.is_empty() => arr,
136            Some(_) => {
137                return Ok(ToolResult::structured_error(
138                    "INVALID_FIELD",
139                    "swarm_execute",
140                    "tasks array must contain at least one task",
141                    Some(vec!["tasks"]),
142                    Some(example),
143                ));
144            }
145            None => {
146                return Ok(ToolResult::structured_error(
147                    "MISSING_FIELD",
148                    "swarm_execute",
149                    "tasks is required and must be an array of task objects with 'name' and 'instruction' fields",
150                    Some(vec!["tasks"]),
151                    Some(example),
152                ));
153            }
154        };
155
156        let mut tasks = Vec::new();
157        for (i, task_val) in tasks_val.iter().enumerate() {
158            let name = match task_val.get("name").and_then(|v| v.as_str()) {
159                Some(s) => s.to_string(),
160                None => {
161                    return Ok(ToolResult::structured_error(
162                        "INVALID_FIELD",
163                        "swarm_execute",
164                        &format!("tasks[{i}].name is required and must be a string"),
165                        Some(vec!["name"]),
166                        Some(json!({"name": "Task Name", "instruction": "Do something"})),
167                    ));
168                }
169            };
170            let instruction = match task_val.get("instruction").and_then(|v| v.as_str()) {
171                Some(s) => s.to_string(),
172                None => {
173                    return Ok(ToolResult::structured_error(
174                        "INVALID_FIELD",
175                        "swarm_execute",
176                        &format!("tasks[{i}].instruction is required and must be a string"),
177                        Some(vec!["instruction"]),
178                        Some(json!({"name": name, "instruction": "What the sub-agent should do"})),
179                    ));
180                }
181            };
182            tasks.push(TaskInput {
183                id: task_val
184                    .get("id")
185                    .and_then(|v| v.as_str())
186                    .map(String::from),
187                name,
188                instruction,
189                specialty: task_val
190                    .get("specialty")
191                    .and_then(|v| v.as_str())
192                    .map(String::from),
193            });
194        }
195
196        let concurrency_limit = params
197            .get("concurrency_limit")
198            .and_then(|v| v.as_u64())
199            .map(|v| v as usize)
200            .unwrap_or(5);
201        let aggregation_strategy = params
202            .get("aggregation_strategy")
203            .and_then(|v| v.as_str())
204            .unwrap_or("best_effort")
205            .to_string();
206        let model = params
207            .get("model")
208            .and_then(|v| v.as_str())
209            .map(String::from);
210        let max_steps = params
211            .get("max_steps")
212            .and_then(|v| v.as_u64())
213            .map(|v| v as usize)
214            .unwrap_or(50);
215        let timeout_secs = params
216            .get("timeout_secs")
217            .and_then(|v| v.as_u64())
218            .unwrap_or(300);
219
220        let concurrency = concurrency_limit.min(20).max(1);
221
222        tracing::info!(
223            task_count = tasks.len(),
224            concurrency = concurrency,
225            strategy = %aggregation_strategy,
226            "Starting swarm execution"
227        );
228
229        // Get provider registry from Vault
230        let providers = ProviderRegistry::from_vault()
231            .await
232            .context("Failed to load providers")?;
233        let provider_list = providers.list();
234
235        if provider_list.is_empty() {
236            return Ok(ToolResult::error(
237                "No providers available for swarm execution",
238            ));
239        }
240
241        // Determine model to use
242        let (provider_name, model_name) = if let Some(ref model_str) = model {
243            let (prov, mod_id) = parse_model_string(model_str);
244            let prov = prov.map(|p| if p == "zhipuai" { "zai" } else { p });
245            (
246                prov.filter(|p| provider_list.contains(p))
247                    .unwrap_or(provider_list[0])
248                    .to_string(),
249                mod_id.to_string(),
250            )
251        } else {
252            // Default to GLM-5 via Z.AI for swarm
253            let provider = if provider_list.contains(&"zai") {
254                "zai".to_string()
255            } else if provider_list.contains(&"openrouter") {
256                "openrouter".to_string()
257            } else {
258                provider_list[0].to_string()
259            };
260            let model = "glm-5".to_string();
261            (provider, model)
262        };
263
264        let provider = providers
265            .get(&provider_name)
266            .context("Failed to get provider")?;
267
268        tracing::info!(provider = %provider_name, model = %model_name, "Using provider for swarm");
269
270        // Get tool definitions (filtered for sub-agents)
271        let tools = Self::get_subagent_tools();
272
273        // System prompt for sub-agents
274        let system_prompt = r#"You are a sub-agent in a swarm execution context.
275Your role is to execute the given task independently and report your results.
276Focus on completing your specific task efficiently.
277Use available tools to accomplish your goal.
278When done, provide a clear summary of what you accomplished.
279Share any intermediate results using the swarm_share tool so other agents can benefit."#;
280
281        // Execute tasks concurrently using semaphore for rate limiting
282        let semaphore = Arc::new(tokio::sync::Semaphore::new(concurrency));
283        let mut join_handles = Vec::new();
284
285        for task_input in tasks.clone() {
286            let semaphore = semaphore.clone();
287            let provider = provider.clone();
288            let tools = tools.clone();
289            let system_prompt = system_prompt.to_string();
290            let task_id = task_input
291                .id
292                .clone()
293                .unwrap_or_else(|| format!("task_{}", uuid::Uuid::new_v4()));
294            let model_name = model_name.clone();
295            let max_steps = max_steps;
296            let timeout_secs = timeout_secs;
297
298            let handle = tokio::spawn(async move {
299                let _permit = semaphore.acquire().await.unwrap();
300
301                let user_prompt = format!(
302                    "Task: {}\nSpecialty: {}\n\nInstruction: {}",
303                    task_input.name,
304                    task_input
305                        .specialty
306                        .as_deref()
307                        .unwrap_or("Generalist execution"),
308                    task_input.instruction
309                );
310
311                let (output, steps, tool_calls, exit) = run_agent_loop(
312                    provider,
313                    &model_name,
314                    &system_prompt,
315                    &user_prompt,
316                    tools,
317                    Arc::new(ToolRegistry::new()),
318                    max_steps,
319                    timeout_secs,
320                    None,
321                    task_id.clone(),
322                    None,
323                    None,
324                )
325                .await?;
326
327                let success = matches!(exit, crate::swarm::executor::AgentLoopExit::Completed)
328                    || matches!(exit, crate::swarm::executor::AgentLoopExit::MaxStepsReached);
329
330                Ok::<TaskResult, anyhow::Error>(TaskResult {
331                    task_id,
332                    task_name: task_input.name,
333                    success,
334                    output,
335                    error: if success {
336                        None
337                    } else {
338                        Some(format!("{:?}", exit))
339                    },
340                    steps,
341                    tool_calls,
342                })
343            });
344
345            join_handles.push(handle);
346        }
347
348        // Wait for all tasks to complete
349        let mut results: Vec<TaskResult> = Vec::new();
350        let mut failures = 0;
351
352        for handle in join_handles {
353            match handle.await {
354                Ok(Ok(result)) => {
355                    if !result.success {
356                        failures += 1;
357
358                        // Handle aggregation strategies
359                        match aggregation_strategy.as_str() {
360                            "all" => {
361                                // Return immediately on first failure
362                                return Ok(ToolResult::success(
363                                    json!({
364                                        "status": "failed",
365                                        "failed_task": result.task_name,
366                                        "error": result.error,
367                                        "results": [result],
368                                        "summary": {
369                                            "total": 1,
370                                            "success": 0,
371                                            "failures": 1
372                                        }
373                                    })
374                                    .to_string(),
375                                ));
376                            }
377                            "first_error" => {
378                                return Ok(ToolResult::success(
379                                    json!({
380                                        "status": "error",
381                                        "error": result.error,
382                                        "failed_task": result.task_name,
383                                        "completed_tasks": results.len(),
384                                        "results": results,
385                                    })
386                                    .to_string(),
387                                ));
388                            }
389                            _ => {} // "best_effort" - continue collecting
390                        }
391                    }
392                    results.push(result);
393                }
394                Ok(Err(e)) => {
395                    failures += 1;
396                    tracing::error!(error = %e, "Task execution failed");
397                }
398                Err(e) => {
399                    failures += 1;
400                    tracing::error!(error = %e, "Task join failed");
401                }
402            }
403        }
404
405        // Build aggregation response
406        let total = results.len();
407        let successes = results.iter().filter(|r| r.success).count();
408
409        let response = if failures == 0 {
410            json!({
411                "status": "success",
412                "results": results,
413                "summary": {
414                    "total": total,
415                    "success": successes,
416                    "failures": failures
417                }
418            })
419        } else {
420            match aggregation_strategy.as_str() {
421                "all" => json!({
422                    "status": "partial_failure",
423                    "results": results,
424                    "summary": {
425                        "total": total,
426                        "success": successes,
427                        "failures": failures
428                    }
429                }),
430                "first_error" => json!({
431                    "status": "error",
432                    "results": results,
433                    "summary": {
434                        "total": total,
435                        "success": successes,
436                        "failures": failures
437                    }
438                }),
439                _ => json!({
440                    "status": "partial_success",
441                    "results": results,
442                    "summary": {
443                        "total": total,
444                        "success": successes,
445                        "failures": failures
446                    }
447                }),
448            }
449        };
450
451        Ok(ToolResult::success(response.to_string()))
452    }
453}
454
455impl SwarmExecuteTool {
456    /// Get tool definitions suitable for sub-agents
457    fn get_subagent_tools() -> Vec<crate::provider::ToolDefinition> {
458        // Filter out interactive/blocking tools that don't work well for sub-agents
459        let registry = ToolRegistry::new();
460        registry
461            .definitions()
462            .into_iter()
463            .filter(|t| {
464                !matches!(
465                    t.name.as_str(),
466                    "question"
467                        | "confirm_edit"
468                        | "confirm_multiedit"
469                        | "plan_enter"
470                        | "plan_exit"
471                        | "swarm_execute"
472                        | "agent"
473                )
474            })
475            .collect()
476    }
477}