Skip to main content

cersei_agent/
delegate_tool.rs

1//! `DelegateTool` — a batch/parallel variant of `AgentTool`.
2//!
3//! Where `AgentTool` spawns one sub-agent, `DelegateTool` spawns N in parallel
4//! with the isolation / blocklist / depth rules from `crate::delegate`.
5//!
6//! The tool input accepts either a single `goal` or a `tasks` array. When both
7//! are present, `goal` is ignored in favor of `tasks` (matches the hermes-agent
8//! contract).
9
10use crate::delegate::{
11    run_batch, DelegateConfig, DelegateTask, ProviderFactory, ToolsetFactory,
12    DEFAULT_MAX_CONCURRENT,
13};
14use async_trait::async_trait;
15use cersei_tools::{PermissionLevel, Tool, ToolContext, ToolResult};
16use serde::Deserialize;
17use serde_json::{json, Value};
18
/// Tool that hands a single goal, or a batch of goals, to sub-agents through
/// `run_batch`.
pub struct DelegateTool {
    /// Cloned into the `DelegateConfig` built for each `execute` call.
    provider_factory: ProviderFactory,
    /// Cloned into the `DelegateConfig` built for each `execute` call.
    toolset_factory: ToolsetFactory,
    /// Optional model name forwarded to the batch; `None` unless `with_model`
    /// is called.
    model: Option<String>,
    /// Turn limit forwarded to the batch; `new` sets it to 30.
    max_turns: u32,
    /// Concurrency limit forwarded to the batch; `new` starts it at
    /// `DEFAULT_MAX_CONCURRENT` and `with_max_concurrent` keeps it at 1 or more.
    max_concurrent: usize,
}
26
27impl DelegateTool {
28    pub fn new(provider_factory: ProviderFactory, toolset_factory: ToolsetFactory) -> Self {
29        Self {
30            provider_factory,
31            toolset_factory,
32            model: None,
33            max_turns: 30,
34            max_concurrent: DEFAULT_MAX_CONCURRENT,
35        }
36    }
37
38    pub fn with_model(mut self, m: impl Into<String>) -> Self {
39        self.model = Some(m.into());
40        self
41    }
42
43    pub fn with_max_turns(mut self, n: u32) -> Self {
44        self.max_turns = n;
45        self
46    }
47
48    pub fn with_max_concurrent(mut self, n: usize) -> Self {
49        self.max_concurrent = n.max(1);
50        self
51    }
52}
53
/// One entry of the `tasks` array in the tool input.
#[derive(Debug, Deserialize)]
struct TaskInput {
    /// What this task's sub-agent should accomplish; the only required field.
    goal: String,
    /// Optional context attached to this task.
    #[serde(default)]
    context: Option<String>,
    /// Optional workspace path for this task; when omitted, the tool
    /// context's working directory is used.
    #[serde(default)]
    workspace: Option<String>,
}
62
/// Top-level tool input. Every field is optional at the parsing stage;
/// `execute` rejects input that carries neither `goal` nor `tasks`.
#[derive(Debug, Deserialize)]
struct Input {
    /// Single-task mode goal; only consulted when `tasks` is absent.
    #[serde(default)]
    goal: Option<String>,
    /// Optional background context accompanying the request.
    #[serde(default)]
    context: Option<String>,
    /// Batch mode task list; takes precedence over `goal` when present.
    #[serde(default)]
    tasks: Option<Vec<TaskInput>>,
}
72
73#[async_trait]
74impl Tool for DelegateTool {
75    fn name(&self) -> &str {
76        "delegate"
77    }
78
79    fn description(&self) -> &str {
80        "Delegate one or more focused sub-tasks to isolated sub-agents running in \
81         parallel. Each child starts with a fresh conversation, a restricted toolset, \
82         and cannot spawn further sub-agents. Use `tasks` for a batch; otherwise pass \
83         a single `goal`. Returns a combined summary with one block per task."
84    }
85
86    fn permission_level(&self) -> PermissionLevel {
87        PermissionLevel::None
88    }
89
90    fn input_schema(&self) -> Value {
91        json!({
92            "type": "object",
93            "properties": {
94                "goal": {
95                    "type": "string",
96                    "description": "Single-task mode: what the sub-agent should accomplish."
97                },
98                "context": {
99                    "type": "string",
100                    "description": "Optional background context shared with the sub-agent(s)."
101                },
102                "tasks": {
103                    "type": "array",
104                    "description": "Batch mode: multiple tasks to run in parallel.",
105                    "items": {
106                        "type": "object",
107                        "properties": {
108                            "goal":      { "type": "string" },
109                            "context":   { "type": "string" },
110                            "workspace": { "type": "string" }
111                        },
112                        "required": ["goal"]
113                    }
114                }
115            }
116        })
117    }
118
119    async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
120        let parsed: Input = match serde_json::from_value(input) {
121            Ok(i) => i,
122            Err(e) => return ToolResult::error(format!("Invalid input: {e}")),
123        };
124
125        let tasks: Vec<DelegateTask> = if let Some(batch) = parsed.tasks {
126            if batch.is_empty() {
127                return ToolResult::error("tasks array is empty");
128            }
129            batch
130                .into_iter()
131                .map(|t| {
132                    let mut task = DelegateTask::new(t.goal);
133                    if let Some(c) = t.context {
134                        task = task.with_context(c);
135                    }
136                    if let Some(w) = t.workspace {
137                        task = task.with_workspace(std::path::PathBuf::from(w));
138                    } else {
139                        task = task.with_workspace(ctx.working_dir.clone());
140                    }
141                    task
142                })
143                .collect()
144        } else if let Some(goal) = parsed.goal {
145            let mut task = DelegateTask::new(goal).with_workspace(ctx.working_dir.clone());
146            if let Some(c) = parsed.context {
147                task = task.with_context(c);
148            }
149            vec![task]
150        } else {
151            return ToolResult::error("must provide either `goal` or `tasks`");
152        };
153
154        let cfg = DelegateConfig {
155            tasks,
156            provider_factory: self.provider_factory.clone(),
157            toolset_factory: self.toolset_factory.clone(),
158            model: self.model.clone(),
159            max_turns: self.max_turns,
160            max_concurrent: self.max_concurrent,
161            depth: 1,
162            extra_blocked: Vec::new(),
163        };
164
165        match run_batch(cfg).await {
166            Ok(results) => {
167                let total = results.len();
168                let failures = results.iter().filter(|r| !r.is_ok()).count();
169                let mut out = String::new();
170                for (i, r) in results.iter().enumerate() {
171                    out.push_str(&format!(
172                        "── Task {}/{}: {}\n",
173                        i + 1,
174                        total,
175                        truncate(&r.goal, 120)
176                    ));
177                    if let Some(err) = &r.error {
178                        out.push_str(&format!("   ERROR: {err}\n\n"));
179                    } else {
180                        out.push_str(&format!("{}\n\n", r.summary.trim()));
181                    }
182                }
183                let meta = json!({
184                    "tasks": total,
185                    "failures": failures,
186                });
187                if failures == total && total > 0 {
188                    ToolResult::error(out).with_metadata(meta)
189                } else {
190                    ToolResult::success(out).with_metadata(meta)
191                }
192            }
193            Err(e) => ToolResult::error(format!("delegate batch failed: {e}")),
194        }
195    }
196}
197
/// Shortens `s` to at most `n` bytes of its original text, appending `…` when
/// anything was cut. The cut point is moved back to the nearest UTF-8
/// character boundary so the slice never splits a character.
fn truncate(s: &str, n: usize) -> String {
    if s.len() > n {
        // Offset 0 is always a char boundary, so the search always finds one.
        let cut = (0..=n)
            .rev()
            .find(|&offset| s.is_char_boundary(offset))
            .unwrap_or(0);
        let mut shortened = String::with_capacity(cut + '…'.len_utf8());
        shortened.push_str(&s[..cut]);
        shortened.push('…');
        shortened
    } else {
        s.to_owned()
    }
}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::delegate::{ProviderFactory, ToolsetFactory};
    use cersei_provider::{CompletionRequest, CompletionStream, Provider, ProviderCapabilities};
    use cersei_tools::permissions::AllowAll;
    use cersei_tools::{CostTracker, Extensions};
    use cersei_types::*;
    use std::sync::Arc;
    use tokio::sync::mpsc;

    /// Stub provider that answers every request with `done: <last message text>`.
    struct EchoProvider;

    #[async_trait]
    impl Provider for EchoProvider {
        fn name(&self) -> &str {
            "echo"
        }

        fn context_window(&self, _: &str) -> u64 {
            4096
        }

        fn capabilities(&self, _: &str) -> ProviderCapabilities {
            ProviderCapabilities {
                streaming: true,
                tool_use: false,
                ..Default::default()
            }
        }

        async fn complete(&self, req: CompletionRequest) -> cersei_types::Result<CompletionStream> {
            let last_text = req
                .messages
                .last()
                .and_then(|m| m.get_text())
                .unwrap_or("");
            let reply = format!("done: {last_text}");

            // The full event script for a single text block followed by end-of-turn.
            let script = vec![
                StreamEvent::MessageStart {
                    id: "1".into(),
                    model: "echo".into(),
                },
                StreamEvent::ContentBlockStart {
                    index: 0,
                    block_type: "text".into(),
                    id: None,
                    name: None,
                },
                StreamEvent::TextDelta {
                    index: 0,
                    text: reply,
                },
                StreamEvent::ContentBlockStop { index: 0 },
                StreamEvent::MessageDelta {
                    stop_reason: Some(StopReason::EndTurn),
                    usage: Some(Usage {
                        input_tokens: 10,
                        output_tokens: 5,
                        ..Default::default()
                    }),
                },
                StreamEvent::MessageStop,
            ];

            let (tx, rx) = mpsc::channel(16);
            tokio::spawn(async move {
                for event in script {
                    // Send failures are deliberately ignored, one event at a time.
                    let _ = tx.send(event).await;
                }
            });
            Ok(CompletionStream::new(rx))
        }
    }

    /// Tool context rooted at the system temp dir with allow-all permissions.
    fn ctx() -> ToolContext {
        let working_dir = std::env::temp_dir();
        ToolContext {
            working_dir,
            session_id: "t".into(),
            permissions: Arc::new(AllowAll),
            cost_tracker: Arc::new(CostTracker::new()),
            mcp_manager: None,
            extensions: Extensions::default(),
        }
    }

    /// Factories producing the echo provider and an empty toolset.
    fn factories() -> (ProviderFactory, ToolsetFactory) {
        let providers: ProviderFactory = Arc::new(|| Box::new(EchoProvider));
        let toolsets: ToolsetFactory = Arc::new(|| Vec::new());
        (providers, toolsets)
    }

    #[tokio::test]
    async fn single_goal_runs_one_child() {
        let (providers, toolsets) = factories();
        let tool = DelegateTool::new(providers, toolsets).with_max_turns(2);

        let outcome = tool.execute(json!({ "goal": "ping" }), &ctx()).await;

        assert!(!outcome.is_error, "{}", outcome.content);
        for needle in ["Task 1/1", "done:"] {
            assert!(outcome.content.contains(needle));
        }
    }

    #[tokio::test]
    async fn batch_mode_runs_all_tasks() {
        let (providers, toolsets) = factories();
        let tool = DelegateTool::new(providers, toolsets)
            .with_max_turns(2)
            .with_max_concurrent(2);
        let input = json!({ "tasks": [{"goal": "a"}, {"goal": "b"}, {"goal": "c"}] });

        let outcome = tool.execute(input, &ctx()).await;

        assert!(!outcome.is_error, "{}", outcome.content);
        for needle in ["Task 1/3", "Task 3/3"] {
            assert!(outcome.content.contains(needle));
        }
    }

    #[tokio::test]
    async fn rejects_missing_goal_and_tasks() {
        let (providers, toolsets) = factories();
        let outcome = DelegateTool::new(providers, toolsets)
            .execute(json!({}), &ctx())
            .await;
        assert!(outcome.is_error);
    }

    #[tokio::test]
    async fn rejects_empty_tasks_array() {
        let (providers, toolsets) = factories();
        let outcome = DelegateTool::new(providers, toolsets)
            .execute(json!({ "tasks": [] }), &ctx())
            .await;
        assert!(outcome.is_error);
    }
}