1use crate::Agent;
8use async_trait::async_trait;
9use cersei_provider::Provider;
10use cersei_tools::permissions::AllowAll;
11use cersei_tools::{PermissionLevel, Tool, ToolContext, ToolResult};
12use serde::Deserialize;
13use serde_json::{json, Value};
14use std::sync::Arc;
15
16pub struct AgentTool {
21 provider_factory: Arc<dyn Fn() -> Box<dyn Provider> + Send + Sync>,
22 available_tools: Vec<Box<dyn Tool>>,
23}
24
25impl AgentTool {
26 pub fn new(
31 provider_factory: impl Fn() -> Box<dyn Provider> + Send + Sync + 'static,
32 tools: Vec<Box<dyn Tool>>,
33 ) -> Self {
34 Self {
35 provider_factory: Arc::new(provider_factory),
36 available_tools: tools,
37 }
38 }
39}
40
41#[derive(Debug, Deserialize)]
42struct AgentInput {
43 description: String,
44 prompt: String,
45 #[serde(default)]
46 system_prompt: Option<String>,
47 #[serde(default)]
48 max_turns: Option<u32>,
49 #[serde(default)]
50 model: Option<String>,
51}
52
53#[async_trait]
54impl Tool for AgentTool {
55 fn name(&self) -> &str {
56 "Agent"
57 }
58
59 fn description(&self) -> &str {
60 "Launch a new agent to handle complex, multi-step tasks autonomously. \
61 The agent runs its own agentic loop with access to tools and returns \
62 its final result. Use this to delegate sub-tasks, run parallel \
63 workstreams, or handle tasks that require many tool calls."
64 }
65
66 fn permission_level(&self) -> PermissionLevel {
67 PermissionLevel::None
68 }
69
70 fn input_schema(&self) -> Value {
71 json!({
72 "type": "object",
73 "properties": {
74 "description": {
75 "type": "string",
76 "description": "Short description of the agent's task (3-5 words)"
77 },
78 "prompt": {
79 "type": "string",
80 "description": "The complete task for the agent to perform"
81 },
82 "system_prompt": {
83 "type": "string",
84 "description": "Optional system prompt override for the sub-agent"
85 },
86 "max_turns": {
87 "type": "integer",
88 "description": "Max turns for the sub-agent (default 10)"
89 },
90 "model": {
91 "type": "string",
92 "description": "Optional model override"
93 }
94 },
95 "required": ["description", "prompt"]
96 })
97 }
98
99 async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
100 let input: AgentInput = match serde_json::from_value(input) {
101 Ok(i) => i,
102 Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
103 };
104
105 tracing::info!(description = %input.description, "Spawning sub-agent");
106
107 let provider = (self.provider_factory)();
109
110 let sub_tools: Vec<Box<dyn Tool>> = self
112 .available_tools
113 .iter()
114 .filter(|t| t.name() != "Agent")
115 .map(|t| {
116 cersei_tools::all()
120 .into_iter()
121 .find(|st| st.name() == t.name())
122 })
123 .flatten()
124 .collect();
125
126 let sub_tools = if sub_tools.is_empty() {
128 cersei_tools::all()
129 .into_iter()
130 .filter(|t| t.name() != "Agent")
131 .collect()
132 } else {
133 sub_tools
134 };
135
136 let mut builder = Agent::builder()
137 .provider(provider)
138 .tools(sub_tools)
139 .max_turns(input.max_turns.unwrap_or(10))
140 .permission_policy(AllowAll)
141 .working_dir(&ctx.working_dir);
142
143 if let Some(sys) = input.system_prompt {
144 builder = builder.system_prompt(sys);
145 } else {
146 builder = builder.system_prompt(
147 "You are a specialized sub-agent. Complete the given task thoroughly and return your findings.",
148 );
149 }
150
151 if let Some(model) = input.model {
152 builder = builder.model(model);
153 }
154
155 let agent = match builder.build() {
156 Ok(a) => a,
157 Err(e) => return ToolResult::error(format!("Failed to build sub-agent: {}", e)),
158 };
159
160 match agent.run(&input.prompt).await {
161 Ok(output) => {
162 let text = output.text().to_string();
163 let meta = json!({
164 "turns": output.turns,
165 "tool_calls": output.tool_calls.len(),
166 "input_tokens": output.usage.input_tokens,
167 "output_tokens": output.usage.output_tokens,
168 });
169 ToolResult::success(text).with_metadata(meta)
170 }
171 Err(e) => ToolResult::error(format!("Sub-agent failed: {}", e)),
172 }
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use cersei_provider::{CompletionRequest, CompletionStream, ProviderCapabilities};
    use cersei_tools::permissions::AllowAll;
    use cersei_tools::{CostTracker, Extensions};
    use cersei_types::*;
    use tokio::sync::mpsc;

    /// Test provider that streams back `"Echo: "` plus the text of the last
    /// request message, then reports fixed token usage and ends the turn.
    struct EchoProvider;

    #[async_trait]
    impl Provider for EchoProvider {
        fn name(&self) -> &str {
            "echo"
        }
        fn context_window(&self, _: &str) -> u64 {
            4096
        }
        fn capabilities(&self, _: &str) -> ProviderCapabilities {
            ProviderCapabilities {
                streaming: true,
                tool_use: false,
                ..Default::default()
            }
        }
        async fn complete(&self, req: CompletionRequest) -> cersei_types::Result<CompletionStream> {
            let prompt = req
                .messages
                .last()
                .and_then(|m| m.get_text())
                .unwrap_or("")
                .to_string();
            let (tx, rx) = mpsc::channel(16);
            tokio::spawn(async move {
                let _ = tx
                    .send(StreamEvent::MessageStart {
                        id: "1".into(),
                        model: "echo".into(),
                    })
                    .await;
                let _ = tx
                    .send(StreamEvent::ContentBlockStart {
                        index: 0,
                        block_type: "text".into(),
                        id: None,
                        name: None,
                    })
                    .await;
                let _ = tx
                    .send(StreamEvent::TextDelta {
                        index: 0,
                        text: format!("Echo: {}", prompt),
                    })
                    .await;
                let _ = tx.send(StreamEvent::ContentBlockStop { index: 0 }).await;
                let _ = tx
                    .send(StreamEvent::MessageDelta {
                        stop_reason: Some(StopReason::EndTurn),
                        usage: Some(Usage {
                            input_tokens: 10,
                            output_tokens: 5,
                            ..Default::default()
                        }),
                    })
                    .await;
                let _ = tx.send(StreamEvent::MessageStop).await;
            });
            Ok(CompletionStream::new(rx))
        }
    }

    /// Builds the parent-side `ToolContext` shared by these tests. It uses the
    /// system temp dir as working directory, a permissive policy, and a fresh
    /// cost tracker.
    fn parent_ctx() -> ToolContext {
        ToolContext {
            working_dir: std::env::temp_dir(),
            session_id: "parent".into(),
            permissions: Arc::new(AllowAll),
            cost_tracker: Arc::new(CostTracker::new()),
            mcp_manager: None,
            extensions: Extensions::default(),
        }
    }

    /// Creates an `AgentTool` backed by `EchoProvider` with the given tools.
    fn echo_agent_tool(tools: Vec<Box<dyn Tool>>) -> AgentTool {
        AgentTool::new(|| Box::new(EchoProvider), tools)
    }

    #[tokio::test]
    async fn test_agent_tool_spawns_sub_agent() {
        let agent_tool = echo_agent_tool(cersei_tools::filesystem());

        let result = agent_tool
            .execute(
                json!({
                    "description": "test sub-agent",
                    "prompt": "Hello from parent"
                }),
                &parent_ctx(),
            )
            .await;

        assert!(
            !result.is_error,
            "Sub-agent should succeed: {}",
            result.content
        );
        assert!(
            result.content.contains("Echo:"),
            "Should contain echo response"
        );
        assert!(result.metadata.is_some(), "Should have metadata");
    }

    #[tokio::test]
    async fn test_agent_tool_filters_self() {
        // Offer the builtin set plus a nested Agent tool, so the "Agent"
        // filter is actually exercised when the sub-agent's tools are chosen.
        let mut tools = cersei_tools::all();
        tools.push(Box::new(echo_agent_tool(Vec::new())));
        let agent_tool = echo_agent_tool(tools);

        let result = agent_tool
            .execute(
                json!({
                    "description": "test no recursion",
                    "prompt": "Do something"
                }),
                &parent_ctx(),
            )
            .await;

        assert!(
            !result.is_error,
            "Sub-agent should succeed: {}",
            result.content
        );
    }

    #[tokio::test]
    async fn test_agent_tool_rejects_invalid_input() {
        let agent_tool = echo_agent_tool(Vec::new());

        // `prompt` is required by the input schema.
        let result = agent_tool
            .execute(json!({ "description": "missing prompt" }), &parent_ctx())
            .await;

        assert!(result.is_error, "Missing prompt should be rejected");
        assert!(
            result.content.contains("Invalid input"),
            "Unexpected error text: {}",
            result.content
        );
    }
}