1use crate::delegate::{
11 run_batch, DelegateConfig, DelegateTask, ProviderFactory, ToolsetFactory,
12 DEFAULT_MAX_CONCURRENT,
13};
14use async_trait::async_trait;
15use cersei_tools::{PermissionLevel, Tool, ToolContext, ToolResult};
16use serde::Deserialize;
17use serde_json::{json, Value};
18
19pub struct DelegateTool {
20 provider_factory: ProviderFactory,
21 toolset_factory: ToolsetFactory,
22 model: Option<String>,
23 max_turns: u32,
24 max_concurrent: usize,
25}
26
27impl DelegateTool {
28 pub fn new(provider_factory: ProviderFactory, toolset_factory: ToolsetFactory) -> Self {
29 Self {
30 provider_factory,
31 toolset_factory,
32 model: None,
33 max_turns: 30,
34 max_concurrent: DEFAULT_MAX_CONCURRENT,
35 }
36 }
37
38 pub fn with_model(mut self, m: impl Into<String>) -> Self {
39 self.model = Some(m.into());
40 self
41 }
42
43 pub fn with_max_turns(mut self, n: u32) -> Self {
44 self.max_turns = n;
45 self
46 }
47
48 pub fn with_max_concurrent(mut self, n: usize) -> Self {
49 self.max_concurrent = n.max(1);
50 self
51 }
52}
53
54#[derive(Debug, Deserialize)]
55struct TaskInput {
56 goal: String,
57 #[serde(default)]
58 context: Option<String>,
59 #[serde(default)]
60 workspace: Option<String>,
61}
62
63#[derive(Debug, Deserialize)]
64struct Input {
65 #[serde(default)]
66 goal: Option<String>,
67 #[serde(default)]
68 context: Option<String>,
69 #[serde(default)]
70 tasks: Option<Vec<TaskInput>>,
71}
72
73#[async_trait]
74impl Tool for DelegateTool {
75 fn name(&self) -> &str {
76 "delegate"
77 }
78
79 fn description(&self) -> &str {
80 "Delegate one or more focused sub-tasks to isolated sub-agents running in \
81 parallel. Each child starts with a fresh conversation, a restricted toolset, \
82 and cannot spawn further sub-agents. Use `tasks` for a batch; otherwise pass \
83 a single `goal`. Returns a combined summary with one block per task."
84 }
85
86 fn permission_level(&self) -> PermissionLevel {
87 PermissionLevel::None
88 }
89
90 fn input_schema(&self) -> Value {
91 json!({
92 "type": "object",
93 "properties": {
94 "goal": {
95 "type": "string",
96 "description": "Single-task mode: what the sub-agent should accomplish."
97 },
98 "context": {
99 "type": "string",
100 "description": "Optional background context shared with the sub-agent(s)."
101 },
102 "tasks": {
103 "type": "array",
104 "description": "Batch mode: multiple tasks to run in parallel.",
105 "items": {
106 "type": "object",
107 "properties": {
108 "goal": { "type": "string" },
109 "context": { "type": "string" },
110 "workspace": { "type": "string" }
111 },
112 "required": ["goal"]
113 }
114 }
115 }
116 })
117 }
118
119 async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
120 let parsed: Input = match serde_json::from_value(input) {
121 Ok(i) => i,
122 Err(e) => return ToolResult::error(format!("Invalid input: {e}")),
123 };
124
125 let tasks: Vec<DelegateTask> = if let Some(batch) = parsed.tasks {
126 if batch.is_empty() {
127 return ToolResult::error("tasks array is empty");
128 }
129 batch
130 .into_iter()
131 .map(|t| {
132 let mut task = DelegateTask::new(t.goal);
133 if let Some(c) = t.context {
134 task = task.with_context(c);
135 }
136 if let Some(w) = t.workspace {
137 task = task.with_workspace(std::path::PathBuf::from(w));
138 } else {
139 task = task.with_workspace(ctx.working_dir.clone());
140 }
141 task
142 })
143 .collect()
144 } else if let Some(goal) = parsed.goal {
145 let mut task = DelegateTask::new(goal).with_workspace(ctx.working_dir.clone());
146 if let Some(c) = parsed.context {
147 task = task.with_context(c);
148 }
149 vec![task]
150 } else {
151 return ToolResult::error("must provide either `goal` or `tasks`");
152 };
153
154 let cfg = DelegateConfig {
155 tasks,
156 provider_factory: self.provider_factory.clone(),
157 toolset_factory: self.toolset_factory.clone(),
158 model: self.model.clone(),
159 max_turns: self.max_turns,
160 max_concurrent: self.max_concurrent,
161 depth: 1,
162 extra_blocked: Vec::new(),
163 };
164
165 match run_batch(cfg).await {
166 Ok(results) => {
167 let total = results.len();
168 let failures = results.iter().filter(|r| !r.is_ok()).count();
169 let mut out = String::new();
170 for (i, r) in results.iter().enumerate() {
171 out.push_str(&format!(
172 "── Task {}/{}: {}\n",
173 i + 1,
174 total,
175 truncate(&r.goal, 120)
176 ));
177 if let Some(err) = &r.error {
178 out.push_str(&format!(" ERROR: {err}\n\n"));
179 } else {
180 out.push_str(&format!("{}\n\n", r.summary.trim()));
181 }
182 }
183 let meta = json!({
184 "tasks": total,
185 "failures": failures,
186 });
187 if failures == total && total > 0 {
188 ToolResult::error(out).with_metadata(meta)
189 } else {
190 ToolResult::success(out).with_metadata(meta)
191 }
192 }
193 Err(e) => ToolResult::error(format!("delegate batch failed: {e}")),
194 }
195 }
196}
197
/// Shortens `s` to at most `n` characters, appending `…` when anything was cut.
///
/// The limit counts Unicode scalar values (`char`s), not bytes. This keeps
/// non-ASCII text (CJK, accented letters, emoji) from being cut far shorter
/// than ASCII text of the same visible length. Slicing at a `char_indices`
/// offset can never split a multi-byte character.
fn truncate(s: &str, n: usize) -> String {
    match s.char_indices().nth(n) {
        // `cut` is the byte offset of the first character that does not fit.
        Some((cut, _)) => format!("{}…", &s[..cut]),
        None => s.to_string(),
    }
}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::delegate::{ProviderFactory, ToolsetFactory};
    use cersei_provider::{CompletionRequest, CompletionStream, Provider, ProviderCapabilities};
    use cersei_tools::permissions::AllowAll;
    use cersei_tools::{CostTracker, Extensions};
    use cersei_types::*;
    use std::sync::Arc;
    use tokio::sync::mpsc;

    /// Provider stub that answers each request with `done: <last message text>`.
    struct EchoProvider;

    #[async_trait]
    impl Provider for EchoProvider {
        fn name(&self) -> &str {
            "echo"
        }

        fn context_window(&self, _: &str) -> u64 {
            4096
        }

        fn capabilities(&self, _: &str) -> ProviderCapabilities {
            ProviderCapabilities {
                streaming: true,
                tool_use: false,
                ..Default::default()
            }
        }

        async fn complete(&self, req: CompletionRequest) -> cersei_types::Result<CompletionStream> {
            let prompt = req
                .messages
                .last()
                .and_then(|m| m.get_text())
                .unwrap_or("")
                .to_string();
            let (tx, rx) = mpsc::channel(16);
            tokio::spawn(async move {
                // A single text turn, streamed in order. Send failures (for
                // example, the receiver was dropped) are ignored.
                let events = vec![
                    StreamEvent::MessageStart {
                        id: "1".into(),
                        model: "echo".into(),
                    },
                    StreamEvent::ContentBlockStart {
                        index: 0,
                        block_type: "text".into(),
                        id: None,
                        name: None,
                    },
                    StreamEvent::TextDelta {
                        index: 0,
                        text: format!("done: {prompt}"),
                    },
                    StreamEvent::ContentBlockStop { index: 0 },
                    StreamEvent::MessageDelta {
                        stop_reason: Some(StopReason::EndTurn),
                        usage: Some(Usage {
                            input_tokens: 10,
                            output_tokens: 5,
                            ..Default::default()
                        }),
                    },
                    StreamEvent::MessageStop,
                ];
                for event in events {
                    let _ = tx.send(event).await;
                }
            });
            Ok(CompletionStream::new(rx))
        }
    }

    /// Minimal tool context rooted at the OS temp directory, allowing everything.
    fn test_ctx() -> ToolContext {
        ToolContext {
            working_dir: std::env::temp_dir(),
            session_id: "t".into(),
            permissions: Arc::new(AllowAll),
            cost_tracker: Arc::new(CostTracker::new()),
            mcp_manager: None,
            extensions: Extensions::default(),
        }
    }

    /// Factories producing an `EchoProvider` and an empty toolset.
    fn echo_factories() -> (ProviderFactory, ToolsetFactory) {
        let pf: ProviderFactory = Arc::new(|| Box::new(EchoProvider));
        let tf: ToolsetFactory = Arc::new(|| Vec::new());
        (pf, tf)
    }

    #[tokio::test]
    async fn single_goal_runs_one_child() {
        let (pf, tf) = echo_factories();
        let tool = DelegateTool::new(pf, tf).with_max_turns(2);
        let result = tool.execute(json!({ "goal": "ping" }), &test_ctx()).await;
        assert!(!result.is_error, "{}", result.content);
        assert!(result.content.contains("Task 1/1"));
        assert!(result.content.contains("done:"));
    }

    #[tokio::test]
    async fn batch_mode_runs_all_tasks() {
        let (pf, tf) = echo_factories();
        let tool = DelegateTool::new(pf, tf)
            .with_max_turns(2)
            .with_max_concurrent(2);
        let input = json!({ "tasks": [{"goal": "a"}, {"goal": "b"}, {"goal": "c"}] });
        let result = tool.execute(input, &test_ctx()).await;
        assert!(!result.is_error, "{}", result.content);
        assert!(result.content.contains("Task 1/3"));
        assert!(result.content.contains("Task 3/3"));
    }

    #[tokio::test]
    async fn rejects_missing_goal_and_tasks() {
        let (pf, tf) = echo_factories();
        let tool = DelegateTool::new(pf, tf);
        let result = tool.execute(json!({}), &test_ctx()).await;
        assert!(result.is_error);
    }

    #[tokio::test]
    async fn rejects_empty_tasks_array() {
        let (pf, tf) = echo_factories();
        let tool = DelegateTool::new(pf, tf);
        let result = tool.execute(json!({ "tasks": [] }), &test_ctx()).await;
        assert!(result.is_error);
    }
}