Skip to main content

agentzero_tools/
delegate.rs

1use agentzero_core::delegation::{
2    filter_tools, validate_delegation, DelegateConfig, DelegateRequest,
3};
4use agentzero_core::{
5    Agent, AgentConfig, ChatResult, MemoryEntry, MemoryStore, Provider, Tool, ToolContext,
6    ToolResult,
7};
8use async_trait::async_trait;
9use serde::Deserialize;
10use std::collections::HashMap;
11use std::sync::Arc;
12
13/// Function that builds a tool set for sub-agents. The delegate tool calls this
14/// when running in agentic mode, then filters the result based on each agent's
15/// `allowed_tools` configuration.
16pub type ToolBuilder = Arc<dyn Fn() -> anyhow::Result<Vec<Box<dyn Tool>>> + Send + Sync>;
17
18#[derive(Debug, Deserialize)]
19struct Input {
20    agent: String,
21    prompt: String,
22}
23
24pub struct DelegateTool {
25    agents: HashMap<String, DelegateConfig>,
26    current_depth: usize,
27    tool_builder: ToolBuilder,
28}
29
30impl DelegateTool {
31    pub fn new(
32        agents: HashMap<String, DelegateConfig>,
33        current_depth: usize,
34        tool_builder: ToolBuilder,
35    ) -> Self {
36        Self {
37            agents,
38            current_depth,
39            tool_builder,
40        }
41    }
42}
43
44#[async_trait]
45impl Tool for DelegateTool {
46    fn name(&self) -> &'static str {
47        "delegate"
48    }
49
50    fn description(&self) -> &'static str {
51        "Delegate a subtask to a named sub-agent with its own provider, model, and tool set."
52    }
53
54    fn input_schema(&self) -> Option<serde_json::Value> {
55        Some(serde_json::json!({
56            "type": "object",
57            "properties": {
58                "agent": { "type": "string", "description": "Name of the sub-agent to delegate to" },
59                "prompt": { "type": "string", "description": "The prompt/task to send to the sub-agent" }
60            },
61            "required": ["agent", "prompt"],
62            "additionalProperties": false
63        }))
64    }
65
66    async fn execute(&self, input: &str, ctx: &ToolContext) -> anyhow::Result<ToolResult> {
67        let parsed: Input =
68            serde_json::from_str(input).map_err(|e| anyhow::anyhow!("invalid input: {e}"))?;
69
70        let config = self
71            .agents
72            .get(&parsed.agent)
73            .ok_or_else(|| anyhow::anyhow!("unknown agent: {}", parsed.agent))?;
74
75        let request = DelegateRequest {
76            agent_name: parsed.agent.clone(),
77            prompt: parsed.prompt.clone(),
78            current_depth: self.current_depth,
79        };
80        validate_delegation(&request, config)?;
81
82        let api_key = resolve_delegate_api_key(config);
83
84        let provider = agentzero_providers::build_provider(
85            &config.provider_kind,
86            config.provider.clone(),
87            api_key,
88            config.model.clone(),
89        );
90
91        let output = if config.agentic {
92            run_agentic(provider, config, &parsed.prompt, ctx, &self.tool_builder).await?
93        } else {
94            // For single-shot (non-agentic) delegates, use <system> tags so the
95            // provider's `extract_system_prompt()` can parse them.
96            let effective_prompt = match &config.system_prompt {
97                Some(sp) => format!("<system>{sp}</system>\n{}", parsed.prompt),
98                None => parsed.prompt.clone(),
99            };
100            run_single_shot(provider.as_ref(), &effective_prompt).await?
101        };
102
103        Ok(ToolResult { output })
104    }
105}
106
107/// Resolve an API key for a delegate agent. Checks (in order):
108/// 1. Explicit `api_key` in the delegate config
109/// 2. Provider-specific environment variable
110/// 3. Generic `OPENAI_API_KEY` fallback
111fn resolve_delegate_api_key(config: &DelegateConfig) -> String {
112    if let Some(ref key) = config.api_key {
113        if !key.is_empty() {
114            return key.clone();
115        }
116    }
117
118    let provider_env_keys: &[&str] = match config.provider_kind.as_str() {
119        "anthropic" => &["ANTHROPIC_API_KEY"],
120        "openrouter" => &["OPENROUTER_API_KEY"],
121        "openai" => &["OPENAI_API_KEY"],
122        "google" | "gemini" => &["GOOGLE_API_KEY", "GEMINI_API_KEY"],
123        "groq" => &["GROQ_API_KEY"],
124        "together" | "together-ai" => &["TOGETHER_API_KEY"],
125        "deepseek" => &["DEEPSEEK_API_KEY"],
126        "mistral" => &["MISTRAL_API_KEY"],
127        "xai" | "grok" => &["XAI_API_KEY"],
128        _ => &[],
129    };
130
131    for key in provider_env_keys {
132        if let Ok(val) = std::env::var(key) {
133            if !val.is_empty() {
134                return val;
135            }
136        }
137    }
138
139    std::env::var("OPENAI_API_KEY").unwrap_or_default()
140}
141
142async fn run_single_shot(provider: &dyn Provider, prompt: &str) -> anyhow::Result<String> {
143    let result: ChatResult = provider.complete(prompt).await?;
144    Ok(result.output_text)
145}
146
147async fn run_agentic(
148    provider: Box<dyn Provider>,
149    config: &DelegateConfig,
150    prompt: &str,
151    ctx: &ToolContext,
152    tool_builder: &ToolBuilder,
153) -> anyhow::Result<String> {
154    let agent_config = AgentConfig {
155        max_tool_iterations: config.max_iterations,
156        system_prompt: config.system_prompt.clone(),
157        ..Default::default()
158    };
159    let memory = EphemeralMemory::default();
160
161    // Build tools for the sub-agent. The builder creates the full set; we
162    // filter to only those in the agent's allowed_tools (filter_tools also
163    // excludes "delegate" to prevent infinite chains).
164    let all_tools = tool_builder().unwrap_or_else(|_| vec![]);
165    let all_tool_names: Vec<String> = all_tools.iter().map(|t| t.name().to_string()).collect();
166    let allowed_names = filter_tools(&all_tool_names, &config.allowed_tools);
167    let tools: Vec<Box<dyn Tool>> = all_tools
168        .into_iter()
169        .filter(|t| allowed_names.contains(&t.name().to_string()))
170        .collect();
171
172    let agent = Agent::new(agent_config, provider, Box::new(memory), tools);
173
174    let response = agent
175        .respond(
176            agentzero_core::UserMessage {
177                text: prompt.to_string(),
178            },
179            ctx,
180        )
181        .await?;
182    Ok(response.text)
183}
184
185#[derive(Default)]
186struct EphemeralMemory {
187    entries: std::sync::Mutex<Vec<MemoryEntry>>,
188}
189
190#[async_trait]
191impl MemoryStore for EphemeralMemory {
192    async fn append(&self, entry: MemoryEntry) -> anyhow::Result<()> {
193        self.entries
194            .lock()
195            .expect("ephemeral memory lock poisoned")
196            .push(entry);
197        Ok(())
198    }
199
200    async fn recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
201        let entries = self.entries.lock().expect("ephemeral memory lock poisoned");
202        Ok(entries.iter().rev().take(limit).cloned().collect())
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Unroutable endpoint shared by every fixture; no test should reach it.
    const TEST_ENDPOINT: &str = "https://api.example.invalid/v1";

    /// Tool builder that produces an empty tool set.
    fn noop_builder() -> ToolBuilder {
        Arc::new(|| Ok(vec![]))
    }

    /// OpenAI-kind delegate named `name` using `model`. All remaining fields
    /// are defaults for the caller to override via struct-update syntax.
    fn openai_delegate(name: &'static str, model: &'static str) -> DelegateConfig {
        DelegateConfig {
            name: name.into(),
            provider_kind: "openai".into(),
            provider: TEST_ENDPOINT.into(),
            model: model.into(),
            ..Default::default()
        }
    }

    /// Two delegates: a single-shot "researcher" (max depth 3) and an agentic
    /// "coder" (max depth 2).
    fn test_agents() -> HashMap<String, DelegateConfig> {
        HashMap::from([
            (
                "researcher".to_string(),
                DelegateConfig {
                    max_depth: 3,
                    agentic: false,
                    max_iterations: 10,
                    ..openai_delegate("researcher", "gpt-4o-mini")
                },
            ),
            (
                "coder".to_string(),
                DelegateConfig {
                    max_depth: 2,
                    agentic: true,
                    max_iterations: 5,
                    ..openai_delegate("coder", "gpt-4o")
                },
            ),
        ])
    }

    fn test_ctx() -> ToolContext {
        ToolContext::new("/tmp".to_string())
    }

    /// Run `input` through `tool` and return the error message. Panics if the
    /// call unexpectedly succeeds.
    async fn expect_failure(tool: &DelegateTool, input: &str) -> String {
        tool.execute(input, &test_ctx())
            .await
            .expect_err("delegate call should have failed")
            .to_string()
    }

    /// Mirrors the single-shot prompt assembly performed in `execute`.
    fn single_shot_prompt(system_prompt: Option<&str>, user_prompt: &str) -> String {
        match system_prompt {
            Some(sp) => format!("<system>{sp}</system>\n{user_prompt}"),
            None => user_prompt.to_string(),
        }
    }

    #[tokio::test]
    async fn delegate_unknown_agent_returns_error() {
        let tool = DelegateTool::new(test_agents(), 0, noop_builder());
        let message = expect_failure(&tool, r#"{"agent":"nonexistent","prompt":"hello"}"#).await;
        assert!(message.contains("unknown agent"));
    }

    #[tokio::test]
    async fn delegate_depth_exceeded_returns_error() {
        // The researcher fixture has max_depth 3.
        let tool = DelegateTool::new(test_agents(), 3, noop_builder());
        let message = expect_failure(&tool, r#"{"agent":"researcher","prompt":"hello"}"#).await;
        assert!(message.contains("depth limit"));
    }

    #[tokio::test]
    async fn delegate_invalid_input_returns_error() {
        let tool = DelegateTool::new(test_agents(), 0, noop_builder());
        let message = expect_failure(&tool, "not json").await;
        assert!(message.contains("invalid input"));
    }

    #[tokio::test]
    async fn delegate_rejects_agent_with_delegate_in_allowlist() {
        let bad = DelegateConfig {
            max_depth: 3,
            agentic: true,
            allowed_tools: HashSet::from(["delegate".to_string()]),
            ..openai_delegate("bad", "gpt-4o")
        };
        let agents = HashMap::from([("bad".to_string(), bad)]);
        let tool = DelegateTool::new(agents, 0, noop_builder());
        let message = expect_failure(&tool, r#"{"agent":"bad","prompt":"hello"}"#).await;
        assert!(message.contains("delegate"));
    }

    #[test]
    fn resolve_api_key_prefers_explicit_config() {
        let config = DelegateConfig {
            provider_kind: "openai".into(),
            api_key: Some("explicit-key".into()),
            ..Default::default()
        };
        assert_eq!(resolve_delegate_api_key(&config), "explicit-key");
    }

    #[test]
    fn resolve_api_key_uses_provider_specific_env_var() {
        let config = DelegateConfig {
            provider_kind: "anthropic".into(),
            ..Default::default()
        };
        let vars = [
            ("ANTHROPIC_API_KEY", Some("ant-key")),
            ("OPENAI_API_KEY", Some("oai-key")),
        ];
        temp_env::with_vars(vars, || {
            assert_eq!(resolve_delegate_api_key(&config), "ant-key");
        });
    }

    #[test]
    fn resolve_api_key_falls_back_to_openai_env() {
        let config = DelegateConfig {
            provider_kind: "custom".into(),
            ..Default::default()
        };
        let vars = [
            ("OPENAI_API_KEY", Some("oai-fallback")),
            ("ANTHROPIC_API_KEY", None),
        ];
        temp_env::with_vars(vars, || {
            assert_eq!(resolve_delegate_api_key(&config), "oai-fallback");
        });
    }

    #[test]
    fn system_prompt_uses_xml_tags_for_single_shot() {
        let config = DelegateConfig {
            system_prompt: Some("You are a research assistant.".into()),
            ..Default::default()
        };
        let effective = single_shot_prompt(config.system_prompt.as_deref(), "Find docs about X");
        assert!(effective.starts_with("<system>You are a research assistant.</system>"));
        assert!(effective.ends_with("Find docs about X"));
    }

    #[test]
    fn no_system_prompt_passes_user_prompt_unchanged() {
        let config = DelegateConfig::default();
        let effective = single_shot_prompt(config.system_prompt.as_deref(), "Find docs about X");
        assert_eq!(effective, "Find docs about X");
    }

    #[test]
    fn agentic_delegate_passes_system_prompt_via_config() {
        // Agentic mode forwards the system prompt through AgentConfig rather
        // than embedding <system> tags in the user prompt.
        let config = DelegateConfig {
            system_prompt: Some("Be concise.".into()),
            max_iterations: 5,
            ..Default::default()
        };
        let agent_config = AgentConfig {
            max_tool_iterations: config.max_iterations,
            system_prompt: config.system_prompt.clone(),
            ..Default::default()
        };
        assert_eq!(agent_config.system_prompt.as_deref(), Some("Be concise."));
        assert_eq!(agent_config.max_tool_iterations, 5);
    }

    #[test]
    fn tool_builder_filters_by_allowed_tools() {
        /// Minimal tool identified only by its name.
        struct FakeTool(&'static str);

        #[async_trait]
        impl Tool for FakeTool {
            fn name(&self) -> &'static str {
                self.0
            }

            async fn execute(
                &self,
                _input: &str,
                _ctx: &ToolContext,
            ) -> anyhow::Result<ToolResult> {
                Ok(ToolResult {
                    output: "ok".into(),
                })
            }
        }

        let builder: ToolBuilder = Arc::new(|| -> anyhow::Result<Vec<Box<dyn Tool>>> {
            Ok(["read_file", "shell", "delegate", "web_search"]
                .iter()
                .map(|&name| Box::new(FakeTool(name)) as Box<dyn Tool>)
                .collect())
        });

        let names: Vec<String> = builder()
            .unwrap()
            .iter()
            .map(|tool| tool.name().to_owned())
            .collect();

        // An allowlist containing only "read_file" keeps exactly that tool.
        let only_read = filter_tools(&names, &HashSet::from(["read_file".to_string()]));
        assert_eq!(only_read, vec!["read_file".to_string()]);

        // An empty allowlist keeps everything except "delegate".
        let everything = filter_tools(&names, &HashSet::new());
        for expected in ["read_file", "shell", "web_search"] {
            assert!(everything.contains(&expected.to_string()));
        }
        assert!(!everything.contains(&"delegate".to_string()));
    }
}