Skip to main content

cersei_agent/
agent_tool.rs

1//! AgentTool: spawn a sub-agent to handle complex sub-tasks.
2//!
3//! Each sub-agent runs its own agentic loop with independent message history,
4//! cost tracking, and tool access. The `Agent` tool is filtered out of
5//! sub-agents to prevent infinite recursion.
6
7use crate::Agent;
8use async_trait::async_trait;
9use cersei_provider::Provider;
10use cersei_tools::permissions::AllowAll;
11use cersei_tools::{PermissionLevel, Tool, ToolContext, ToolResult};
12use serde::Deserialize;
13use serde_json::{json, Value};
14use std::sync::Arc;
15
16/// The AgentTool — spawns independent sub-agents.
17///
18/// This must be constructed with a reference to the parent's provider
19/// so sub-agents can make their own API calls.
20pub struct AgentTool {
21    provider_factory: Arc<dyn Fn() -> Box<dyn Provider> + Send + Sync>,
22    available_tools: Vec<Box<dyn Tool>>,
23}
24
25impl AgentTool {
26    /// Create an AgentTool with a provider factory and available tools.
27    ///
28    /// The provider factory creates a new provider instance for each sub-agent.
29    /// The `available_tools` list will have "Agent" filtered out automatically.
30    pub fn new(
31        provider_factory: impl Fn() -> Box<dyn Provider> + Send + Sync + 'static,
32        tools: Vec<Box<dyn Tool>>,
33    ) -> Self {
34        Self {
35            provider_factory: Arc::new(provider_factory),
36            available_tools: tools,
37        }
38    }
39}
40
41#[derive(Debug, Deserialize)]
42struct AgentInput {
43    description: String,
44    prompt: String,
45    #[serde(default)]
46    system_prompt: Option<String>,
47    #[serde(default)]
48    max_turns: Option<u32>,
49    #[serde(default)]
50    model: Option<String>,
51}
52
53#[async_trait]
54impl Tool for AgentTool {
55    fn name(&self) -> &str {
56        "Agent"
57    }
58
59    fn description(&self) -> &str {
60        "Launch a new agent to handle complex, multi-step tasks autonomously. \
61         The agent runs its own agentic loop with access to tools and returns \
62         its final result. Use this to delegate sub-tasks, run parallel \
63         workstreams, or handle tasks that require many tool calls."
64    }
65
66    fn permission_level(&self) -> PermissionLevel {
67        PermissionLevel::None
68    }
69
70    fn input_schema(&self) -> Value {
71        json!({
72            "type": "object",
73            "properties": {
74                "description": {
75                    "type": "string",
76                    "description": "Short description of the agent's task (3-5 words)"
77                },
78                "prompt": {
79                    "type": "string",
80                    "description": "The complete task for the agent to perform"
81                },
82                "system_prompt": {
83                    "type": "string",
84                    "description": "Optional system prompt override for the sub-agent"
85                },
86                "max_turns": {
87                    "type": "integer",
88                    "description": "Max turns for the sub-agent (default 10)"
89                },
90                "model": {
91                    "type": "string",
92                    "description": "Optional model override"
93                }
94            },
95            "required": ["description", "prompt"]
96        })
97    }
98
99    async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
100        let input: AgentInput = match serde_json::from_value(input) {
101            Ok(i) => i,
102            Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
103        };
104
105        tracing::info!(description = %input.description, "Spawning sub-agent");
106
107        // Create a fresh provider for the sub-agent
108        let provider = (self.provider_factory)();
109
110        // Filter out "Agent" tool to prevent recursion
111        let sub_tools: Vec<Box<dyn Tool>> = self
112            .available_tools
113            .iter()
114            .filter(|t| t.name() != "Agent")
115            .map(|t| {
116                // We can't clone Box<dyn Tool>, so we rebuild tool sets
117                // This is a limitation — in practice, sub-agents get the
118                // standard tool sets minus Agent
119                cersei_tools::all()
120                    .into_iter()
121                    .find(|st| st.name() == t.name())
122            })
123            .flatten()
124            .collect();
125
126        // Use standard tools if filtering resulted in empty set
127        let sub_tools = if sub_tools.is_empty() {
128            cersei_tools::all()
129                .into_iter()
130                .filter(|t| t.name() != "Agent")
131                .collect()
132        } else {
133            sub_tools
134        };
135
136        let mut builder = Agent::builder()
137            .provider(provider)
138            .tools(sub_tools)
139            .max_turns(input.max_turns.unwrap_or(10))
140            .permission_policy(AllowAll)
141            .working_dir(&ctx.working_dir);
142
143        if let Some(sys) = input.system_prompt {
144            builder = builder.system_prompt(sys);
145        } else {
146            builder = builder.system_prompt(
147                "You are a specialized sub-agent. Complete the given task thoroughly and return your findings.",
148            );
149        }
150
151        if let Some(model) = input.model {
152            builder = builder.model(model);
153        }
154
155        let agent = match builder.build() {
156            Ok(a) => a,
157            Err(e) => return ToolResult::error(format!("Failed to build sub-agent: {}", e)),
158        };
159
160        match agent.run(&input.prompt).await {
161            Ok(output) => {
162                let text = output.text().to_string();
163                let meta = json!({
164                    "turns": output.turns,
165                    "tool_calls": output.tool_calls.len(),
166                    "input_tokens": output.usage.input_tokens,
167                    "output_tokens": output.usage.output_tokens,
168                });
169                ToolResult::success(text).with_metadata(meta)
170            }
171            Err(e) => ToolResult::error(format!("Sub-agent failed: {}", e)),
172        }
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use cersei_provider::{CompletionRequest, CompletionStream, ProviderCapabilities};
    use cersei_tools::permissions::AllowAll;
    use cersei_tools::{CostTracker, Extensions};
    use cersei_types::*;
    use tokio::sync::mpsc;

    /// Test double: replies to every request with one text block echoing the
    /// last message's text, then ends the turn.
    struct EchoProvider;

    #[async_trait]
    impl Provider for EchoProvider {
        fn name(&self) -> &str {
            "echo"
        }

        fn context_window(&self, _: &str) -> u64 {
            4096
        }

        fn capabilities(&self, _: &str) -> ProviderCapabilities {
            ProviderCapabilities {
                streaming: true,
                tool_use: false,
                ..Default::default()
            }
        }

        async fn complete(&self, req: CompletionRequest) -> cersei_types::Result<CompletionStream> {
            let last_text = req
                .messages
                .last()
                .and_then(|m| m.get_text())
                .unwrap_or("");
            let reply = format!("Echo: {}", last_text);

            // The full event sequence for one streamed text response.
            let script = vec![
                StreamEvent::MessageStart {
                    id: "1".into(),
                    model: "echo".into(),
                },
                StreamEvent::ContentBlockStart {
                    index: 0,
                    block_type: "text".into(),
                    id: None,
                    name: None,
                },
                StreamEvent::TextDelta {
                    index: 0,
                    text: reply,
                },
                StreamEvent::ContentBlockStop { index: 0 },
                StreamEvent::MessageDelta {
                    stop_reason: Some(StopReason::EndTurn),
                    usage: Some(Usage {
                        input_tokens: 10,
                        output_tokens: 5,
                        ..Default::default()
                    }),
                },
                StreamEvent::MessageStop,
            ];

            let (tx, rx) = mpsc::channel(16);
            tokio::spawn(async move {
                for event in script {
                    // A send error only means the consumer went away; ignore it.
                    let _ = tx.send(event).await;
                }
            });
            Ok(CompletionStream::new(rx))
        }
    }

    /// Parent-side tool context rooted in the system temp dir, allowing everything.
    fn parent_ctx() -> ToolContext {
        ToolContext {
            working_dir: std::env::temp_dir(),
            session_id: "parent".into(),
            permissions: Arc::new(AllowAll),
            cost_tracker: Arc::new(CostTracker::new()),
            mcp_manager: None,
            extensions: Extensions::default(),
        }
    }

    #[tokio::test]
    async fn test_agent_tool_spawns_sub_agent() {
        let tool = AgentTool::new(|| Box::new(EchoProvider), cersei_tools::filesystem());
        let ctx = parent_ctx();
        let args = json!({
            "description": "test sub-agent",
            "prompt": "Hello from parent"
        });

        let result = tool.execute(args, &ctx).await;

        assert!(
            !result.is_error,
            "Sub-agent should succeed: {}",
            result.content
        );
        assert!(
            result.content.contains("Echo:"),
            "Should contain echo response"
        );
        assert!(result.metadata.is_some(), "Should have metadata");
    }

    #[tokio::test]
    async fn test_agent_tool_filters_self() {
        // Handing the full tool set (which may include "Agent") must still
        // produce a working sub-agent: it gets the tools minus "Agent".
        let tool = AgentTool::new(|| Box::new(EchoProvider), cersei_tools::all());
        let ctx = parent_ctx();
        let args = json!({
            "description": "test no recursion",
            "prompt": "Do something"
        });

        let result = tool.execute(args, &ctx).await;

        assert!(!result.is_error);
    }
}