Skip to main content

cortexai_tools/
delegation.rs

//! Agent-to-agent delegation tools
//!
//! Allows agents to delegate tasks to other specialized agents,
//! enabling hierarchical agent architectures and supervisor/worker patterns.
//!
//! # Example
//! ```ignore
//! use cortexai_tools::delegation::*;
//!
//! // Create the agent registry. The callback routes outgoing delegation
//! // messages to their target agents (`route_message` is your own function).
//! let sender: MessageSender = Arc::new(|msg| route_message(msg));
//! let registry = Arc::new(AgentRegistry::new(sender));
//!
//! // Register agents with descriptions
//! registry.register(
//!     researcher_id,
//!     "Research Agent",
//!     "Specialized in web research and gathering information",
//! );
//!
//! registry.register(
//!     writer_id,
//!     "Writer Agent",
//!     "Specialized in writing clear, concise content",
//! );
//!
//! // Create delegation tool for supervisor
//! let delegate_tool = DelegateAgentTool::new(registry.clone());
//! supervisor_tools.register(Arc::new(delegate_tool));
//!
//! // When a worker replies, hand its message back to the registry so the
//! // pending delegation call can complete:
//! registry.deliver_response(&worker_id, reply);
//! ```
30
31use async_trait::async_trait;
32use parking_lot::RwLock;
33use cortexai_core::{
34    errors::ToolError,
35    tool::{ExecutionContext, Tool, ToolSchema},
36    AgentId, Content, Message,
37};
38use serde::{Deserialize, Serialize};
39use std::collections::HashMap;
40use std::sync::Arc;
41use std::time::Duration;
42use tokio::sync::oneshot;
43use tokio::time::timeout;
44use tracing::{debug, info, warn};
45
46/// Information about a registered agent available for delegation
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct AgentInfo {
49    /// Agent ID
50    pub id: AgentId,
51    /// Human-readable name
52    pub name: String,
53    /// Description of what the agent specializes in
54    pub description: String,
55    /// Optional tags for categorization
56    pub tags: Vec<String>,
57    /// Whether agent is currently available
58    pub available: bool,
59}
60
61impl AgentInfo {
62    pub fn new(id: AgentId, name: impl Into<String>, description: impl Into<String>) -> Self {
63        Self {
64            id,
65            name: name.into(),
66            description: description.into(),
67            tags: Vec::new(),
68            available: true,
69        }
70    }
71
72    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
73        self.tags = tags;
74        self
75    }
76}
77
78/// Callback type for sending messages to agents
79pub type MessageSender = Arc<dyn Fn(Message) -> Result<(), String> + Send + Sync>;
80
81/// Callback type for waiting for agent response
82pub type ResponseWaiter =
83    Arc<dyn Fn(AgentId, Duration) -> Option<oneshot::Receiver<Message>> + Send + Sync>;
84
85/// Registry of agents available for delegation
86pub struct AgentRegistry {
87    agents: RwLock<HashMap<AgentId, AgentInfo>>,
88    message_sender: MessageSender,
89    response_channels: RwLock<HashMap<AgentId, Vec<oneshot::Sender<Message>>>>,
90}
91
92impl AgentRegistry {
93    /// Create a new agent registry with a message sender callback
94    pub fn new(message_sender: MessageSender) -> Self {
95        Self {
96            agents: RwLock::new(HashMap::new()),
97            message_sender,
98            response_channels: RwLock::new(HashMap::new()),
99        }
100    }
101
102    /// Register an agent for delegation
103    pub fn register_agent(&self, info: AgentInfo) {
104        let id = info.id.clone();
105        self.agents.write().insert(id, info);
106    }
107
108    /// Register an agent with basic info
109    pub fn register(&self, id: AgentId, name: impl Into<String>, description: impl Into<String>) {
110        self.register_agent(AgentInfo::new(id, name, description));
111    }
112
113    /// Unregister an agent
114    pub fn unregister(&self, id: &AgentId) {
115        self.agents.write().remove(id);
116    }
117
118    /// Get agent info
119    pub fn get_agent(&self, id: &AgentId) -> Option<AgentInfo> {
120        self.agents.read().get(id).cloned()
121    }
122
123    /// List all available agents
124    pub fn list_agents(&self) -> Vec<AgentInfo> {
125        self.agents
126            .read()
127            .values()
128            .filter(|a| a.available)
129            .cloned()
130            .collect()
131    }
132
133    /// Find agents by tag
134    pub fn find_by_tag(&self, tag: &str) -> Vec<AgentInfo> {
135        self.agents
136            .read()
137            .values()
138            .filter(|a| a.available && a.tags.iter().any(|t| t == tag))
139            .cloned()
140            .collect()
141    }
142
143    /// Set agent availability
144    pub fn set_available(&self, id: &AgentId, available: bool) {
145        if let Some(agent) = self.agents.write().get_mut(id) {
146            agent.available = available;
147        }
148    }
149
150    /// Send a message to an agent
151    pub fn send_message(&self, message: Message) -> Result<(), String> {
152        (self.message_sender)(message)
153    }
154
155    /// Register a response channel for an agent
156    pub fn register_response_channel(&self, agent_id: AgentId) -> oneshot::Receiver<Message> {
157        let (tx, rx) = oneshot::channel();
158        self.response_channels
159            .write()
160            .entry(agent_id)
161            .or_default()
162            .push(tx);
163        rx
164    }
165
166    /// Deliver a response to waiting channels
167    pub fn deliver_response(&self, from_agent: &AgentId, message: Message) {
168        if let Some(channels) = self.response_channels.write().remove(from_agent) {
169            for tx in channels {
170                let _ = tx.send(message.clone());
171            }
172        }
173    }
174
175    /// Get count of registered agents
176    pub fn agent_count(&self) -> usize {
177        self.agents.read().len()
178    }
179}
180
181/// Tool for delegating tasks to other agents
182///
183/// This tool allows a supervisor agent to delegate specific tasks
184/// to specialized worker agents and receive their responses.
185pub struct DelegateAgentTool {
186    registry: Arc<AgentRegistry>,
187    timeout: Duration,
188}
189
190impl DelegateAgentTool {
191    /// Create a new delegation tool
192    pub fn new(registry: Arc<AgentRegistry>) -> Self {
193        Self {
194            registry,
195            timeout: Duration::from_secs(60),
196        }
197    }
198
199    /// Set timeout for waiting for agent response
200    pub fn with_timeout(mut self, timeout: Duration) -> Self {
201        self.timeout = timeout;
202        self
203    }
204}
205
206#[async_trait]
207impl Tool for DelegateAgentTool {
208    fn schema(&self) -> ToolSchema {
209        // Build available agents description
210        let agents = self.registry.list_agents();
211        let agents_desc = if agents.is_empty() {
212            "No agents currently available for delegation.".to_string()
213        } else {
214            agents
215                .iter()
216                .map(|a| format!("- {} ({}): {}", a.name, a.id, a.description))
217                .collect::<Vec<_>>()
218                .join("\n")
219        };
220
221        ToolSchema::new(
222            "delegate_to_agent",
223            format!(
224                "Delegate a task to another specialized agent. Use this when you need help \
225                from an agent with specific expertise. Available agents:\n{}",
226                agents_desc
227            ),
228        )
229        .with_parameters(serde_json::json!({
230            "type": "object",
231            "properties": {
232                "agent_id": {
233                    "type": "string",
234                    "description": "The ID of the agent to delegate to"
235                },
236                "task": {
237                    "type": "string",
238                    "description": "The task or question to send to the agent"
239                },
240                "context": {
241                    "type": "string",
242                    "description": "Optional additional context for the task"
243                }
244            },
245            "required": ["agent_id", "task"]
246        }))
247    }
248
249    async fn execute(
250        &self,
251        ctx: &ExecutionContext,
252        arguments: serde_json::Value,
253    ) -> Result<serde_json::Value, ToolError> {
254        let agent_id_str = arguments["agent_id"]
255            .as_str()
256            .ok_or_else(|| ToolError::InvalidArguments("agent_id is required".to_string()))?;
257
258        let task = arguments["task"]
259            .as_str()
260            .ok_or_else(|| ToolError::InvalidArguments("task is required".to_string()))?;
261
262        let context = arguments["context"].as_str();
263
264        let target_agent_id = AgentId::new(agent_id_str);
265
266        // Verify agent exists and is available
267        let agent_info = self.registry.get_agent(&target_agent_id).ok_or_else(|| {
268            ToolError::ExecutionFailed(format!("Agent '{}' not found", agent_id_str))
269        })?;
270
271        if !agent_info.available {
272            return Err(ToolError::ExecutionFailed(format!(
273                "Agent '{}' is not currently available",
274                agent_id_str
275            )));
276        }
277
278        info!(
279            from = %ctx.agent_id,
280            to = %target_agent_id,
281            task = %task,
282            "Delegating task to agent"
283        );
284
285        // Build message content
286        let content = if let Some(ctx_str) = context {
287            format!("{}\n\nContext: {}", task, ctx_str)
288        } else {
289            task.to_string()
290        };
291
292        // Register response channel before sending
293        let response_rx = self
294            .registry
295            .register_response_channel(target_agent_id.clone());
296
297        // Send message to target agent
298        let message = Message::new(
299            ctx.agent_id.clone(),
300            target_agent_id.clone(),
301            Content::Text(content),
302        );
303
304        self.registry.send_message(message).map_err(|e| {
305            ToolError::ExecutionFailed(format!("Failed to send message to agent: {}", e))
306        })?;
307
308        // Wait for response with timeout
309        debug!(
310            target = %target_agent_id,
311            timeout_secs = self.timeout.as_secs(),
312            "Waiting for agent response"
313        );
314
315        match timeout(self.timeout, response_rx).await {
316            Ok(Ok(response)) => {
317                info!(
318                    from = %target_agent_id,
319                    "Received response from delegated agent"
320                );
321
322                match response.content {
323                    Content::Text(text) => Ok(serde_json::json!({
324                        "agent": agent_info.name,
325                        "agent_id": agent_id_str,
326                        "response": text,
327                        "success": true
328                    })),
329                    _ => Ok(serde_json::json!({
330                        "agent": agent_info.name,
331                        "agent_id": agent_id_str,
332                        "response": "Agent returned non-text response",
333                        "success": true
334                    })),
335                }
336            }
337            Ok(Err(_)) => {
338                warn!(target = %target_agent_id, "Response channel closed");
339                Err(ToolError::ExecutionFailed(
340                    "Agent response channel closed unexpectedly".to_string(),
341                ))
342            }
343            Err(_) => {
344                warn!(
345                    target = %target_agent_id,
346                    timeout_secs = self.timeout.as_secs(),
347                    "Timeout waiting for agent response"
348                );
349                Err(ToolError::Timeout(format!(
350                    "Timeout waiting for response from agent '{}' after {} seconds",
351                    target_agent_id,
352                    self.timeout.as_secs()
353                )))
354            }
355        }
356    }
357}
358
359/// Tool for listing available agents
360///
361/// Useful for supervisor agents to discover what specialists are available.
362pub struct ListAgentsTool {
363    registry: Arc<AgentRegistry>,
364}
365
366impl ListAgentsTool {
367    pub fn new(registry: Arc<AgentRegistry>) -> Self {
368        Self { registry }
369    }
370}
371
372#[async_trait]
373impl Tool for ListAgentsTool {
374    fn schema(&self) -> ToolSchema {
375        ToolSchema::new(
376            "list_available_agents",
377            "List all agents available for delegation, including their specializations",
378        )
379        .with_parameters(serde_json::json!({
380            "type": "object",
381            "properties": {
382                "tag": {
383                    "type": "string",
384                    "description": "Optional tag to filter agents by specialization"
385                }
386            },
387            "required": []
388        }))
389    }
390
391    async fn execute(
392        &self,
393        _ctx: &ExecutionContext,
394        arguments: serde_json::Value,
395    ) -> Result<serde_json::Value, ToolError> {
396        let agents = if let Some(tag) = arguments["tag"].as_str() {
397            self.registry.find_by_tag(tag)
398        } else {
399            self.registry.list_agents()
400        };
401
402        let agent_list: Vec<serde_json::Value> = agents
403            .into_iter()
404            .map(|a| {
405                serde_json::json!({
406                    "id": a.id.to_string(),
407                    "name": a.name,
408                    "description": a.description,
409                    "tags": a.tags
410                })
411            })
412            .collect();
413
414        Ok(serde_json::json!({
415            "agents": agent_list,
416            "count": agent_list.len()
417        }))
418    }
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    /// Registry whose message sender accepts and discards every message.
    fn create_test_registry() -> Arc<AgentRegistry> {
        let sender: MessageSender = Arc::new(|_msg| Ok(()));
        Arc::new(AgentRegistry::new(sender))
    }

    #[test]
    fn test_agent_registry_register() {
        let registry = create_test_registry();

        registry.register(
            AgentId::new("agent-1"),
            "Research Agent",
            "Specializes in research",
        );

        assert_eq!(registry.agent_count(), 1);

        let agent = registry.get_agent(&AgentId::new("agent-1")).unwrap();
        assert_eq!(agent.name, "Research Agent");
    }

    #[test]
    fn test_agent_registry_list() {
        let registry = create_test_registry();

        registry.register(AgentId::new("agent-1"), "Agent 1", "Description 1");
        registry.register(AgentId::new("agent-2"), "Agent 2", "Description 2");

        let agents = registry.list_agents();
        assert_eq!(agents.len(), 2);
    }

    #[test]
    fn test_agent_availability() {
        let registry = create_test_registry();

        registry.register(AgentId::new("agent-1"), "Agent 1", "Description 1");

        // Initially available
        let agents = registry.list_agents();
        assert_eq!(agents.len(), 1);

        // Set unavailable
        registry.set_available(&AgentId::new("agent-1"), false);
        let agents = registry.list_agents();
        assert_eq!(agents.len(), 0);

        // Set available again
        registry.set_available(&AgentId::new("agent-1"), true);
        let agents = registry.list_agents();
        assert_eq!(agents.len(), 1);
    }

    #[test]
    fn test_find_by_tag() {
        let registry = create_test_registry();

        registry.register_agent(
            AgentInfo::new(AgentId::new("research-1"), "Researcher", "Does research")
                .with_tags(vec!["research".to_string(), "analysis".to_string()]),
        );

        registry.register_agent(
            AgentInfo::new(AgentId::new("writer-1"), "Writer", "Writes content")
                .with_tags(vec!["writing".to_string(), "content".to_string()]),
        );

        let researchers = registry.find_by_tag("research");
        assert_eq!(researchers.len(), 1);
        assert_eq!(researchers[0].id, AgentId::new("research-1"));

        let writers = registry.find_by_tag("writing");
        assert_eq!(writers.len(), 1);
        assert_eq!(writers[0].id, AgentId::new("writer-1"));
    }

    #[tokio::test]
    async fn test_delegate_tool_schema() {
        let registry = create_test_registry();
        registry.register(AgentId::new("helper"), "Helper Agent", "Helps with tasks");

        let tool = DelegateAgentTool::new(registry);
        let schema = tool.schema();

        assert_eq!(schema.name, "delegate_to_agent");
        assert!(schema.description.contains("Helper Agent"));
    }

    #[tokio::test]
    async fn test_list_agents_tool() {
        let registry = create_test_registry();
        registry.register(AgentId::new("agent-1"), "Agent 1", "Description 1");
        registry.register(AgentId::new("agent-2"), "Agent 2", "Description 2");

        let tool = ListAgentsTool::new(registry);
        let ctx = ExecutionContext::new(AgentId::new("supervisor"));

        let result = tool.execute(&ctx, serde_json::json!({})).await.unwrap();

        assert_eq!(result["count"], 2);
        assert_eq!(result["agents"].as_array().unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_list_agents_tool_filters_by_tag() {
        let registry = create_test_registry();
        registry.register_agent(
            AgentInfo::new(AgentId::new("research-1"), "Researcher", "Does research")
                .with_tags(vec!["research".to_string()]),
        );
        registry.register(AgentId::new("plain"), "Plain", "No tags");

        let tool = ListAgentsTool::new(registry);
        let ctx = ExecutionContext::new(AgentId::new("supervisor"));

        let result = tool
            .execute(&ctx, serde_json::json!({ "tag": "research" }))
            .await
            .unwrap();

        assert_eq!(result["count"], 1);
        assert_eq!(
            result["agents"][0]["id"],
            AgentId::new("research-1").to_string()
        );
    }

    #[tokio::test]
    async fn test_delegate_round_trip() {
        // Forward every outgoing request to a simulated worker.
        let (request_tx, mut request_rx) = tokio::sync::mpsc::unbounded_channel::<Message>();
        let sender: MessageSender = Arc::new(move |msg| {
            request_tx
                .send(msg)
                .map_err(|_| "request channel closed".to_string())
        });
        let registry = Arc::new(AgentRegistry::new(sender));
        registry.register(AgentId::new("worker"), "Worker", "Does work");

        // The tool registers its response channel before sending, so the
        // worker's delivery cannot race ahead of it.
        let worker_registry = Arc::clone(&registry);
        let worker = tokio::spawn(async move {
            let request = request_rx.recv().await.expect("delegated request");
            worker_registry.deliver_response(
                &AgentId::new("worker"),
                Message::new(
                    AgentId::new("worker"),
                    AgentId::new("supervisor"),
                    Content::Text("done".to_string()),
                ),
            );
            request
        });

        let tool = DelegateAgentTool::new(registry).with_timeout(Duration::from_secs(5));
        let ctx = ExecutionContext::new(AgentId::new("supervisor"));
        let result = tool
            .execute(
                &ctx,
                serde_json::json!({
                    "agent_id": "worker",
                    "task": "Summarize",
                    "context": "Q3 report"
                }),
            )
            .await
            .unwrap();

        assert_eq!(result["response"], "done");
        assert_eq!(result["agent"], "Worker");
        assert_eq!(result["success"], true);

        let request = worker.await.unwrap();
        match request.content {
            Content::Text(text) => assert_eq!(text, "Summarize\n\nContext: Q3 report"),
            _ => panic!("expected a text request"),
        }
    }

    #[tokio::test]
    async fn test_delegate_to_unknown_agent_fails() {
        let tool = DelegateAgentTool::new(create_test_registry());
        let ctx = ExecutionContext::new(AgentId::new("supervisor"));

        let err = tool
            .execute(
                &ctx,
                serde_json::json!({ "agent_id": "missing", "task": "anything" }),
            )
            .await
            .unwrap_err();

        assert!(matches!(err, ToolError::ExecutionFailed(_)));
    }

    #[tokio::test]
    async fn test_delegate_to_unavailable_agent_fails() {
        let registry = create_test_registry();
        registry.register(AgentId::new("busy"), "Busy", "Currently busy");
        registry.set_available(&AgentId::new("busy"), false);

        let tool = DelegateAgentTool::new(registry);
        let ctx = ExecutionContext::new(AgentId::new("supervisor"));

        let err = tool
            .execute(&ctx, serde_json::json!({ "agent_id": "busy", "task": "anything" }))
            .await
            .unwrap_err();

        assert!(matches!(err, ToolError::ExecutionFailed(_)));
    }

    #[tokio::test]
    async fn test_delegate_requires_task() {
        let registry = create_test_registry();
        registry.register(AgentId::new("helper"), "Helper", "Helps");

        let tool = DelegateAgentTool::new(registry);
        let ctx = ExecutionContext::new(AgentId::new("supervisor"));

        let err = tool
            .execute(&ctx, serde_json::json!({ "agent_id": "helper" }))
            .await
            .unwrap_err();

        assert!(matches!(err, ToolError::InvalidArguments(_)));
    }

    #[tokio::test]
    async fn test_delegate_times_out_without_response() {
        let registry = create_test_registry();
        registry.register(AgentId::new("silent"), "Silent", "Never answers");

        let tool = DelegateAgentTool::new(registry).with_timeout(Duration::from_millis(20));
        let ctx = ExecutionContext::new(AgentId::new("supervisor"));

        let err = tool
            .execute(&ctx, serde_json::json!({ "agent_id": "silent", "task": "ping" }))
            .await
            .unwrap_err();

        assert!(matches!(err, ToolError::Timeout(_)));
    }
}