1use async_trait::async_trait;
32use parking_lot::RwLock;
33use cortexai_core::{
34 errors::ToolError,
35 tool::{ExecutionContext, Tool, ToolSchema},
36 AgentId, Content, Message,
37};
38use serde::{Deserialize, Serialize};
39use std::collections::HashMap;
40use std::sync::Arc;
41use std::time::Duration;
42use tokio::sync::oneshot;
43use tokio::time::timeout;
44use tracing::{debug, info, warn};
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct AgentInfo {
49 pub id: AgentId,
51 pub name: String,
53 pub description: String,
55 pub tags: Vec<String>,
57 pub available: bool,
59}
60
61impl AgentInfo {
62 pub fn new(id: AgentId, name: impl Into<String>, description: impl Into<String>) -> Self {
63 Self {
64 id,
65 name: name.into(),
66 description: description.into(),
67 tags: Vec::new(),
68 available: true,
69 }
70 }
71
72 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
73 self.tags = tags;
74 self
75 }
76}
77
78pub type MessageSender = Arc<dyn Fn(Message) -> Result<(), String> + Send + Sync>;
80
81pub type ResponseWaiter =
83 Arc<dyn Fn(AgentId, Duration) -> Option<oneshot::Receiver<Message>> + Send + Sync>;
84
85pub struct AgentRegistry {
87 agents: RwLock<HashMap<AgentId, AgentInfo>>,
88 message_sender: MessageSender,
89 response_channels: RwLock<HashMap<AgentId, Vec<oneshot::Sender<Message>>>>,
90}
91
92impl AgentRegistry {
93 pub fn new(message_sender: MessageSender) -> Self {
95 Self {
96 agents: RwLock::new(HashMap::new()),
97 message_sender,
98 response_channels: RwLock::new(HashMap::new()),
99 }
100 }
101
102 pub fn register_agent(&self, info: AgentInfo) {
104 let id = info.id.clone();
105 self.agents.write().insert(id, info);
106 }
107
108 pub fn register(&self, id: AgentId, name: impl Into<String>, description: impl Into<String>) {
110 self.register_agent(AgentInfo::new(id, name, description));
111 }
112
113 pub fn unregister(&self, id: &AgentId) {
115 self.agents.write().remove(id);
116 }
117
118 pub fn get_agent(&self, id: &AgentId) -> Option<AgentInfo> {
120 self.agents.read().get(id).cloned()
121 }
122
123 pub fn list_agents(&self) -> Vec<AgentInfo> {
125 self.agents
126 .read()
127 .values()
128 .filter(|a| a.available)
129 .cloned()
130 .collect()
131 }
132
133 pub fn find_by_tag(&self, tag: &str) -> Vec<AgentInfo> {
135 self.agents
136 .read()
137 .values()
138 .filter(|a| a.available && a.tags.iter().any(|t| t == tag))
139 .cloned()
140 .collect()
141 }
142
143 pub fn set_available(&self, id: &AgentId, available: bool) {
145 if let Some(agent) = self.agents.write().get_mut(id) {
146 agent.available = available;
147 }
148 }
149
150 pub fn send_message(&self, message: Message) -> Result<(), String> {
152 (self.message_sender)(message)
153 }
154
155 pub fn register_response_channel(&self, agent_id: AgentId) -> oneshot::Receiver<Message> {
157 let (tx, rx) = oneshot::channel();
158 self.response_channels
159 .write()
160 .entry(agent_id)
161 .or_default()
162 .push(tx);
163 rx
164 }
165
166 pub fn deliver_response(&self, from_agent: &AgentId, message: Message) {
168 if let Some(channels) = self.response_channels.write().remove(from_agent) {
169 for tx in channels {
170 let _ = tx.send(message.clone());
171 }
172 }
173 }
174
175 pub fn agent_count(&self) -> usize {
177 self.agents.read().len()
178 }
179}
180
181pub struct DelegateAgentTool {
186 registry: Arc<AgentRegistry>,
187 timeout: Duration,
188}
189
190impl DelegateAgentTool {
191 pub fn new(registry: Arc<AgentRegistry>) -> Self {
193 Self {
194 registry,
195 timeout: Duration::from_secs(60),
196 }
197 }
198
199 pub fn with_timeout(mut self, timeout: Duration) -> Self {
201 self.timeout = timeout;
202 self
203 }
204}
205
206#[async_trait]
207impl Tool for DelegateAgentTool {
208 fn schema(&self) -> ToolSchema {
209 let agents = self.registry.list_agents();
211 let agents_desc = if agents.is_empty() {
212 "No agents currently available for delegation.".to_string()
213 } else {
214 agents
215 .iter()
216 .map(|a| format!("- {} ({}): {}", a.name, a.id, a.description))
217 .collect::<Vec<_>>()
218 .join("\n")
219 };
220
221 ToolSchema::new(
222 "delegate_to_agent",
223 format!(
224 "Delegate a task to another specialized agent. Use this when you need help \
225 from an agent with specific expertise. Available agents:\n{}",
226 agents_desc
227 ),
228 )
229 .with_parameters(serde_json::json!({
230 "type": "object",
231 "properties": {
232 "agent_id": {
233 "type": "string",
234 "description": "The ID of the agent to delegate to"
235 },
236 "task": {
237 "type": "string",
238 "description": "The task or question to send to the agent"
239 },
240 "context": {
241 "type": "string",
242 "description": "Optional additional context for the task"
243 }
244 },
245 "required": ["agent_id", "task"]
246 }))
247 }
248
249 async fn execute(
250 &self,
251 ctx: &ExecutionContext,
252 arguments: serde_json::Value,
253 ) -> Result<serde_json::Value, ToolError> {
254 let agent_id_str = arguments["agent_id"]
255 .as_str()
256 .ok_or_else(|| ToolError::InvalidArguments("agent_id is required".to_string()))?;
257
258 let task = arguments["task"]
259 .as_str()
260 .ok_or_else(|| ToolError::InvalidArguments("task is required".to_string()))?;
261
262 let context = arguments["context"].as_str();
263
264 let target_agent_id = AgentId::new(agent_id_str);
265
266 let agent_info = self.registry.get_agent(&target_agent_id).ok_or_else(|| {
268 ToolError::ExecutionFailed(format!("Agent '{}' not found", agent_id_str))
269 })?;
270
271 if !agent_info.available {
272 return Err(ToolError::ExecutionFailed(format!(
273 "Agent '{}' is not currently available",
274 agent_id_str
275 )));
276 }
277
278 info!(
279 from = %ctx.agent_id,
280 to = %target_agent_id,
281 task = %task,
282 "Delegating task to agent"
283 );
284
285 let content = if let Some(ctx_str) = context {
287 format!("{}\n\nContext: {}", task, ctx_str)
288 } else {
289 task.to_string()
290 };
291
292 let response_rx = self
294 .registry
295 .register_response_channel(target_agent_id.clone());
296
297 let message = Message::new(
299 ctx.agent_id.clone(),
300 target_agent_id.clone(),
301 Content::Text(content),
302 );
303
304 self.registry.send_message(message).map_err(|e| {
305 ToolError::ExecutionFailed(format!("Failed to send message to agent: {}", e))
306 })?;
307
308 debug!(
310 target = %target_agent_id,
311 timeout_secs = self.timeout.as_secs(),
312 "Waiting for agent response"
313 );
314
315 match timeout(self.timeout, response_rx).await {
316 Ok(Ok(response)) => {
317 info!(
318 from = %target_agent_id,
319 "Received response from delegated agent"
320 );
321
322 match response.content {
323 Content::Text(text) => Ok(serde_json::json!({
324 "agent": agent_info.name,
325 "agent_id": agent_id_str,
326 "response": text,
327 "success": true
328 })),
329 _ => Ok(serde_json::json!({
330 "agent": agent_info.name,
331 "agent_id": agent_id_str,
332 "response": "Agent returned non-text response",
333 "success": true
334 })),
335 }
336 }
337 Ok(Err(_)) => {
338 warn!(target = %target_agent_id, "Response channel closed");
339 Err(ToolError::ExecutionFailed(
340 "Agent response channel closed unexpectedly".to_string(),
341 ))
342 }
343 Err(_) => {
344 warn!(
345 target = %target_agent_id,
346 timeout_secs = self.timeout.as_secs(),
347 "Timeout waiting for agent response"
348 );
349 Err(ToolError::Timeout(format!(
350 "Timeout waiting for response from agent '{}' after {} seconds",
351 target_agent_id,
352 self.timeout.as_secs()
353 )))
354 }
355 }
356 }
357}
358
359pub struct ListAgentsTool {
363 registry: Arc<AgentRegistry>,
364}
365
366impl ListAgentsTool {
367 pub fn new(registry: Arc<AgentRegistry>) -> Self {
368 Self { registry }
369 }
370}
371
372#[async_trait]
373impl Tool for ListAgentsTool {
374 fn schema(&self) -> ToolSchema {
375 ToolSchema::new(
376 "list_available_agents",
377 "List all agents available for delegation, including their specializations",
378 )
379 .with_parameters(serde_json::json!({
380 "type": "object",
381 "properties": {
382 "tag": {
383 "type": "string",
384 "description": "Optional tag to filter agents by specialization"
385 }
386 },
387 "required": []
388 }))
389 }
390
391 async fn execute(
392 &self,
393 _ctx: &ExecutionContext,
394 arguments: serde_json::Value,
395 ) -> Result<serde_json::Value, ToolError> {
396 let agents = if let Some(tag) = arguments["tag"].as_str() {
397 self.registry.find_by_tag(tag)
398 } else {
399 self.registry.list_agents()
400 };
401
402 let agent_list: Vec<serde_json::Value> = agents
403 .into_iter()
404 .map(|a| {
405 serde_json::json!({
406 "id": a.id.to_string(),
407 "name": a.name,
408 "description": a.description,
409 "tags": a.tags
410 })
411 })
412 .collect();
413
414 Ok(serde_json::json!({
415 "agents": agent_list,
416 "count": agent_list.len()
417 }))
418 }
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a registry whose sender accepts every message and discards it.
    fn create_test_registry() -> Arc<AgentRegistry> {
        let noop_sender: MessageSender = Arc::new(|_msg| Ok(()));
        Arc::new(AgentRegistry::new(noop_sender))
    }

    /// A registered agent is counted and can be fetched by id.
    #[test]
    fn test_agent_registry_register() {
        let registry = create_test_registry();
        let id = AgentId::new("agent-1");

        registry.register(id.clone(), "Research Agent", "Specializes in research");

        assert_eq!(registry.agent_count(), 1);

        let stored = registry.get_agent(&id).unwrap();
        assert_eq!(stored.name, "Research Agent");
    }

    /// Every registered (and therefore available) agent shows up in the listing.
    #[test]
    fn test_agent_registry_list() {
        let registry = create_test_registry();

        for (id, name, description) in [
            ("agent-1", "Agent 1", "Description 1"),
            ("agent-2", "Agent 2", "Description 2"),
        ] {
            registry.register(AgentId::new(id), name, description);
        }

        assert_eq!(registry.list_agents().len(), 2);
    }

    /// Toggling availability hides the agent from the listing and shows it again.
    #[test]
    fn test_agent_availability() {
        let registry = create_test_registry();
        let id = AgentId::new("agent-1");

        registry.register(id.clone(), "Agent 1", "Description 1");
        assert_eq!(registry.list_agents().len(), 1);

        for (available, expected_len) in [(false, 0), (true, 1)] {
            registry.set_available(&id, available);
            assert_eq!(registry.list_agents().len(), expected_len);
        }
    }

    /// Tag lookup returns only the agents carrying the requested tag.
    #[test]
    fn test_find_by_tag() {
        let registry = create_test_registry();

        let researcher = AgentInfo::new(AgentId::new("research-1"), "Researcher", "Does research")
            .with_tags(vec!["research".to_string(), "analysis".to_string()]);
        let writer = AgentInfo::new(AgentId::new("writer-1"), "Writer", "Writes content")
            .with_tags(vec!["writing".to_string(), "content".to_string()]);

        registry.register_agent(researcher);
        registry.register_agent(writer);

        let researchers = registry.find_by_tag("research");
        assert_eq!(researchers.len(), 1);
        assert_eq!(researchers[0].id, AgentId::new("research-1"));

        let writers = registry.find_by_tag("writing");
        assert_eq!(writers.len(), 1);
        assert_eq!(writers[0].id, AgentId::new("writer-1"));
    }

    /// The delegation schema carries the tool name and mentions registered agents.
    #[tokio::test]
    async fn test_delegate_tool_schema() {
        let registry = create_test_registry();
        registry.register(AgentId::new("helper"), "Helper Agent", "Helps with tasks");

        let schema = DelegateAgentTool::new(registry).schema();

        assert_eq!(schema.name, "delegate_to_agent");
        assert!(schema.description.contains("Helper Agent"));
    }

    /// Without a tag filter the listing tool reports every available agent.
    #[tokio::test]
    async fn test_list_agents_tool() {
        let registry = create_test_registry();
        for (id, name, description) in [
            ("agent-1", "Agent 1", "Description 1"),
            ("agent-2", "Agent 2", "Description 2"),
        ] {
            registry.register(AgentId::new(id), name, description);
        }

        let tool = ListAgentsTool::new(registry);
        let ctx = ExecutionContext::new(AgentId::new("supervisor"));

        let result = tool.execute(&ctx, serde_json::json!({})).await.unwrap();

        assert_eq!(result["count"], 2);
        assert_eq!(result["agents"].as_array().unwrap().len(), 2);
    }
}