Skip to main content

cortexai_mcp/
agent_handler.rs

1//! Agent-as-MCP-Tool
2//!
3//! Exposes cortex agents as MCP tools, allowing external MCP clients
4//! (like Claude Desktop) to invoke agents directly.
5//!
6//! # Architecture
7//!
8//! ```text
9//! ┌─────────────────────────────────────────────────────────────┐
10//! │                     MCP Client (Claude)                     │
11//! └─────────────────────────────────────────────────────────────┘
12//!                              │
13//!                    tools/call "agent_xxx"
14//!                              │
15//!                              ▼
16//! ┌─────────────────────────────────────────────────────────────┐
17//! │                     McpServer                               │
18//! │  ┌─────────────────────────────────────────────────────┐   │
19//! │  │              AgentMcpHandler                         │   │
20//! │  │                                                      │   │
21//! │  │  - Wraps AgentTool as MCP ToolHandler               │   │
22//! │  │  - Converts MCP params to AgentToolInput            │   │
23//! │  │  - Converts AgentToolOutput to MCP CallToolResult   │   │
24//! │  └─────────────────────────────────────────────────────┘   │
25//! └─────────────────────────────────────────────────────────────┘
26//!                              │
27//!                              ▼
28//! ┌─────────────────────────────────────────────────────────────┐
29//! │                     AgentTool / AgentEngine                 │
30//! │  - Executes agent logic                                     │
31//! │  - Returns structured response                              │
32//! └─────────────────────────────────────────────────────────────┘
33//! ```
34//!
35//! # Example
36//!
37//! ```rust,ignore
38//! use cortexai_mcp::{McpServer, AgentMcpHandler};
39//! use cortexai_agents::agent_tool::AgentTool;
40//!
41//! // Create an agent tool
42//! let research_agent = AgentTool::builder("research_agent")
43//!     .description("Researches topics and provides summaries")
44//!     .handler(|input| async move {
45//!         // Agent logic here
46//!         Ok(AgentToolOutput::success("Research results..."))
47//!     });
48//!
49//! // Wrap as MCP handler
50//! let mcp_handler = AgentMcpHandler::from_agent_tool(research_agent);
51//!
52//! // Add to MCP server
53//! let server = McpServer::builder()
54//!     .name("agent-server")
55//!     .add_tool(mcp_handler)
56//!     .build();
57//! ```
58
59use async_trait::async_trait;
60use serde::{Deserialize, Serialize};
61use serde_json::json;
62use std::collections::HashMap;
63use std::sync::Arc;
64use tracing::{debug, info};
65
66use crate::error::McpError;
67use crate::protocol::{CallToolResult, McpTool, ToolContent};
68use crate::server::ToolHandler;
69
70/// Input schema for agent MCP tools
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct AgentMcpInput {
73    /// The query or task for the agent
74    pub query: String,
75    /// Additional context as key-value pairs
76    #[serde(default)]
77    pub context: HashMap<String, String>,
78    /// Conversation history (optional)
79    #[serde(default)]
80    pub history: Vec<String>,
81    /// Maximum tokens for response (optional hint)
82    #[serde(default)]
83    pub max_tokens: Option<usize>,
84}
85
86/// Output structure for agent responses
87#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct AgentMcpOutput {
89    /// The response content
90    pub content: String,
91    /// Whether the agent successfully completed the task
92    pub success: bool,
93    /// Confidence score (0.0 to 1.0)
94    pub confidence: f32,
95    /// Additional metadata
96    #[serde(default)]
97    pub metadata: HashMap<String, String>,
98    /// Time taken in milliseconds
99    pub duration_ms: u64,
100    /// Tools used by the agent
101    #[serde(default)]
102    pub tools_used: Vec<String>,
103}
104
/// Options controlling how an agent MCP handler names itself and what it
/// reports back to the client.
#[derive(Clone, Debug)]
pub struct AgentMcpConfig {
    /// Prepended to the agent name to form the MCP tool name.
    pub name_prefix: String,
    /// Whether the agent's metadata is included in the response.
    pub include_metadata: bool,
    /// Whether the list of tools the agent used is included in the response.
    pub include_tools_used: bool,
}

impl Default for AgentMcpConfig {
    /// `"agent_"` prefix, with metadata and tools-used reporting both enabled.
    fn default() -> Self {
        let name_prefix = String::from("agent_");
        Self {
            include_metadata: true,
            include_tools_used: true,
            name_prefix,
        }
    }
}
125
126/// Handler type for agent execution
127pub type AgentHandlerFn = Arc<
128    dyn Fn(
129            AgentMcpInput,
130        ) -> std::pin::Pin<
131            Box<dyn std::future::Future<Output = Result<AgentMcpOutput, String>> + Send>,
132        > + Send
133        + Sync,
134>;
135
136/// MCP ToolHandler that wraps an agent
137pub struct AgentMcpHandler {
138    /// Tool name (as exposed in MCP)
139    name: String,
140    /// Tool description
141    description: String,
142    /// Agent capabilities/tags
143    capabilities: Vec<String>,
144    /// Handler function
145    handler: AgentHandlerFn,
146    /// Configuration
147    config: AgentMcpConfig,
148}
149
150impl AgentMcpHandler {
151    /// Create a new handler with a custom async function
152    pub fn new<F, Fut>(name: impl Into<String>, description: impl Into<String>, handler: F) -> Self
153    where
154        F: Fn(AgentMcpInput) -> Fut + Send + Sync + 'static,
155        Fut: std::future::Future<Output = Result<AgentMcpOutput, String>> + Send + 'static,
156    {
157        let config = AgentMcpConfig::default();
158        let name_str = name.into();
159        let tool_name = format!("{}{}", config.name_prefix, name_str);
160
161        Self {
162            name: tool_name,
163            description: description.into(),
164            capabilities: Vec::new(),
165            handler: Arc::new(move |input| Box::pin(handler(input))),
166            config,
167        }
168    }
169
170    /// Create with custom configuration
171    pub fn with_config(mut self, config: AgentMcpConfig) -> Self {
172        // Update name with new prefix
173        let base_name = self
174            .name
175            .strip_prefix(&self.config.name_prefix)
176            .unwrap_or(&self.name)
177            .to_string();
178        self.name = format!("{}{}", config.name_prefix, base_name);
179        self.config = config;
180        self
181    }
182
183    /// Add a capability tag
184    pub fn with_capability(mut self, capability: impl Into<String>) -> Self {
185        self.capabilities.push(capability.into());
186        self
187    }
188
189    /// Add multiple capabilities
190    pub fn with_capabilities(mut self, capabilities: Vec<String>) -> Self {
191        self.capabilities.extend(capabilities);
192        self
193    }
194
195    /// Create a builder for fluent construction
196    pub fn builder(name: impl Into<String>) -> AgentMcpHandlerBuilder {
197        AgentMcpHandlerBuilder::new(name)
198    }
199
200    /// Get the tool name
201    pub fn name(&self) -> &str {
202        &self.name
203    }
204
205    /// Get the capabilities
206    pub fn capabilities(&self) -> &[String] {
207        &self.capabilities
208    }
209}
210
211#[async_trait]
212impl ToolHandler for AgentMcpHandler {
213    fn definition(&self) -> McpTool {
214        let schema = json!({
215            "type": "object",
216            "properties": {
217                "query": {
218                    "type": "string",
219                    "description": "The query or task for the agent"
220                },
221                "context": {
222                    "type": "object",
223                    "description": "Additional context as key-value pairs",
224                    "additionalProperties": { "type": "string" }
225                },
226                "history": {
227                    "type": "array",
228                    "description": "Conversation history (optional)",
229                    "items": { "type": "string" }
230                },
231                "max_tokens": {
232                    "type": "integer",
233                    "description": "Maximum tokens for response (optional hint)"
234                }
235            },
236            "required": ["query"]
237        });
238
239        // Add capabilities to description if present
240        let description = if self.capabilities.is_empty() {
241            self.description.clone()
242        } else {
243            format!(
244                "{}\n\nCapabilities: {}",
245                self.description,
246                self.capabilities.join(", ")
247            )
248        };
249
250        McpTool {
251            name: self.name.clone(),
252            description: Some(description),
253            input_schema: schema,
254        }
255    }
256
257    async fn execute(&self, arguments: serde_json::Value) -> Result<CallToolResult, McpError> {
258        debug!(tool = %self.name, "Executing agent MCP handler");
259
260        // Parse input
261        let input: AgentMcpInput = serde_json::from_value(arguments.clone())
262            .map_err(|e| McpError::InvalidParams(format!("Invalid input: {}", e)))?;
263
264        info!(
265            tool = %self.name,
266            query = %input.query,
267            context_keys = ?input.context.keys().collect::<Vec<_>>(),
268            "Agent executing query"
269        );
270
271        // Execute agent
272        let result = (self.handler)(input).await;
273
274        match result {
275            Ok(output) => {
276                let mut response_parts = vec![output.content.clone()];
277
278                // Add metadata if configured
279                if self.config.include_metadata && !output.metadata.is_empty() {
280                    let metadata_str = output
281                        .metadata
282                        .iter()
283                        .map(|(k, v)| format!("  {}: {}", k, v))
284                        .collect::<Vec<_>>()
285                        .join("\n");
286                    response_parts.push(format!("\n\nMetadata:\n{}", metadata_str));
287                }
288
289                // Add tools used if configured
290                if self.config.include_tools_used && !output.tools_used.is_empty() {
291                    response_parts
292                        .push(format!("\n\nTools used: {}", output.tools_used.join(", ")));
293                }
294
295                let response_text = response_parts.join("");
296
297                // Add structured data as additional content
298                let structured_output = json!({
299                    "success": output.success,
300                    "confidence": output.confidence,
301                    "duration_ms": output.duration_ms,
302                    "metadata": output.metadata,
303                    "tools_used": output.tools_used
304                });
305
306                Ok(CallToolResult {
307                    content: vec![
308                        ToolContent::text(response_text),
309                        ToolContent::text(format!(
310                            "\n---\nStructured output: {}",
311                            serde_json::to_string_pretty(&structured_output).unwrap_or_default()
312                        )),
313                    ],
314                    is_error: !output.success,
315                })
316            }
317            Err(e) => Ok(CallToolResult {
318                content: vec![ToolContent::text(format!("Agent error: {}", e))],
319                is_error: true,
320            }),
321        }
322    }
323}
324
325/// Builder for AgentMcpHandler
326pub struct AgentMcpHandlerBuilder {
327    name: String,
328    description: String,
329    capabilities: Vec<String>,
330    config: AgentMcpConfig,
331}
332
333impl AgentMcpHandlerBuilder {
334    pub fn new(name: impl Into<String>) -> Self {
335        Self {
336            name: name.into(),
337            description: String::new(),
338            capabilities: Vec::new(),
339            config: AgentMcpConfig::default(),
340        }
341    }
342
343    pub fn description(mut self, description: impl Into<String>) -> Self {
344        self.description = description.into();
345        self
346    }
347
348    pub fn capability(mut self, capability: impl Into<String>) -> Self {
349        self.capabilities.push(capability.into());
350        self
351    }
352
353    pub fn capabilities(mut self, capabilities: Vec<String>) -> Self {
354        self.capabilities.extend(capabilities);
355        self
356    }
357
358    pub fn config(mut self, config: AgentMcpConfig) -> Self {
359        self.config = config;
360        self
361    }
362
363    pub fn name_prefix(mut self, prefix: impl Into<String>) -> Self {
364        self.config.name_prefix = prefix.into();
365        self
366    }
367
368    pub fn include_metadata(mut self, include: bool) -> Self {
369        self.config.include_metadata = include;
370        self
371    }
372
373    pub fn include_tools_used(mut self, include: bool) -> Self {
374        self.config.include_tools_used = include;
375        self
376    }
377
378    /// Build with a handler function
379    pub fn handler<F, Fut>(self, handler: F) -> AgentMcpHandler
380    where
381        F: Fn(AgentMcpInput) -> Fut + Send + Sync + 'static,
382        Fut: std::future::Future<Output = Result<AgentMcpOutput, String>> + Send + 'static,
383    {
384        let tool_name = format!("{}{}", self.config.name_prefix, self.name);
385
386        AgentMcpHandler {
387            name: tool_name,
388            description: self.description,
389            capabilities: self.capabilities,
390            handler: Arc::new(move |input| Box::pin(handler(input))),
391            config: self.config,
392        }
393    }
394}
395
396/// Helper to create a simple agent handler from a closure
397pub fn simple_agent<F, Fut>(
398    name: impl Into<String>,
399    description: impl Into<String>,
400    handler: F,
401) -> AgentMcpHandler
402where
403    F: Fn(String) -> Fut + Send + Sync + 'static,
404    Fut: std::future::Future<Output = Result<String, String>> + Send + 'static,
405{
406    let handler = Arc::new(handler);
407    AgentMcpHandler::builder(name)
408        .description(description)
409        .handler(move |input: AgentMcpInput| {
410            let h = handler.clone();
411            async move {
412                let start = std::time::Instant::now();
413                match h(input.query).await {
414                    Ok(content) => Ok(AgentMcpOutput {
415                        content,
416                        success: true,
417                        confidence: 1.0,
418                        metadata: HashMap::new(),
419                        duration_ms: start.elapsed().as_millis() as u64,
420                        tools_used: Vec::new(),
421                    }),
422                    Err(e) => Err(e),
423                }
424            }
425        })
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Successful output with neutral defaults; tests override individual
    /// fields with struct-update syntax where they matter.
    fn ok_output(content: impl Into<String>) -> AgentMcpOutput {
        AgentMcpOutput {
            content: content.into(),
            success: true,
            confidence: 1.0,
            metadata: HashMap::new(),
            duration_ms: 10,
            tools_used: Vec::new(),
        }
    }

    #[tokio::test]
    async fn test_agent_mcp_handler_basic() {
        let handler = AgentMcpHandler::builder("test_agent")
            .description("A test agent")
            .handler(|input: AgentMcpInput| async move {
                Ok(AgentMcpOutput {
                    confidence: 0.95,
                    duration_ms: 100,
                    tools_used: vec!["tool1".to_string()],
                    ..ok_output(format!("Processed: {}", input.query))
                })
            });

        let def = handler.definition();
        assert_eq!(def.name, "agent_test_agent");
        assert!(def.description.unwrap().contains("test agent"));

        let result = handler
            .execute(json!({"query": "Hello world"}))
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(result.content.len(), 2);
        let text = result.content[0].as_text().unwrap();
        assert!(text.contains("Processed: Hello world"));
        assert!(text.contains("Tools used: tool1"));
    }

    #[tokio::test]
    async fn test_agent_mcp_handler_with_context() {
        let handler = AgentMcpHandler::builder("context_agent")
            .description("Agent that uses context")
            .handler(|input: AgentMcpInput| async move {
                let name = input.context.get("name").cloned().unwrap_or_default();
                Ok(AgentMcpOutput {
                    duration_ms: 50,
                    ..ok_output(format!("Hello, {}!", name))
                })
            });

        let result = handler
            .execute(json!({
                "query": "greet",
                "context": {"name": "World"}
            }))
            .await
            .unwrap();

        assert!(result.content[0]
            .as_text()
            .unwrap()
            .contains("Hello, World!"));
    }

    #[tokio::test]
    async fn test_agent_mcp_handler_error() {
        let handler = AgentMcpHandler::builder("failing_agent")
            .description("Agent that fails")
            .handler(|_: AgentMcpInput| async move { Err("Intentional failure".to_string()) });

        let result = handler.execute(json!({"query": "test"})).await.unwrap();

        assert!(result.is_error);
        let text = result.content[0].as_text().unwrap();
        assert!(text.contains("Agent error"));
        assert!(text.contains("Intentional failure"));
    }

    #[tokio::test]
    async fn test_missing_query_is_invalid_params() {
        let handler = simple_agent("strict", "Needs a query", |query: String| async move {
            Ok(query)
        });

        let result = handler.execute(json!({"context": {"k": "v"}})).await;
        assert!(matches!(result, Err(McpError::InvalidParams(_))));
    }

    #[test]
    fn test_agent_mcp_handler_capabilities() {
        let handler = AgentMcpHandler::builder("capable_agent")
            .description("Agent with capabilities")
            .capability("math")
            .capability("science")
            .handler(|_: AgentMcpInput| async move { Ok(ok_output("OK")) });

        let desc = handler.definition().description.unwrap();
        assert!(desc.contains("math"));
        assert!(desc.contains("science"));
    }

    #[test]
    fn test_capability_helpers_accumulate() {
        let handler = AgentMcpHandler::new("tagged", "Tagged agent", |_: AgentMcpInput| async move {
            Ok(ok_output("OK"))
        })
        .with_capability("math")
        .with_capabilities(vec!["science".to_string(), "history".to_string()]);

        assert_eq!(
            handler.capabilities().to_vec(),
            vec!["math".to_string(), "science".to_string(), "history".to_string()]
        );
    }

    #[tokio::test]
    async fn test_simple_agent_helper() {
        let handler = simple_agent("simple", "A simple agent", |query: String| async move {
            Ok(format!("Echo: {}", query))
        });

        let result = handler
            .execute(json!({"query": "test message"}))
            .await
            .unwrap();

        assert!(!result.is_error);
        assert!(result.content[0]
            .as_text()
            .unwrap()
            .contains("Echo: test message"));
    }

    #[test]
    fn test_agent_mcp_handler_custom_prefix() {
        let handler = AgentMcpHandler::builder("custom")
            .description("Custom prefix agent")
            .name_prefix("ai_")
            .handler(|_: AgentMcpInput| async move { Ok(ok_output("OK")) });

        assert_eq!(handler.definition().name, "ai_custom");
    }

    #[test]
    fn test_new_uses_default_prefix() {
        let handler = AgentMcpHandler::new("plain", "Plain agent", |_: AgentMcpInput| async move {
            Ok(ok_output("OK"))
        });

        assert_eq!(handler.name(), "agent_plain");
    }

    #[test]
    fn test_with_config_replaces_prefix() {
        let handler = AgentMcpHandler::new("plain", "Plain agent", |_: AgentMcpInput| async move {
            Ok(ok_output("OK"))
        })
        .with_config(AgentMcpConfig {
            name_prefix: "ai_".to_string(),
            ..AgentMcpConfig::default()
        });

        assert_eq!(handler.name(), "ai_plain");
        assert_eq!(handler.definition().name, "ai_plain");
    }

    #[tokio::test]
    async fn test_agent_mcp_handler_metadata_output() {
        let handler = AgentMcpHandler::builder("metadata_agent")
            .description("Agent with metadata")
            .include_metadata(true)
            .handler(|_: AgentMcpInput| async move {
                let mut metadata = HashMap::new();
                metadata.insert("source".to_string(), "database".to_string());
                metadata.insert("version".to_string(), "1.0".to_string());

                Ok(AgentMcpOutput {
                    confidence: 0.9,
                    metadata,
                    duration_ms: 200,
                    tools_used: vec!["db_query".to_string()],
                    ..ok_output("Result with metadata")
                })
            });

        let result = handler.execute(json!({"query": "test"})).await.unwrap();

        let text = result.content[0].as_text().unwrap();
        assert!(text.contains("Result with metadata"));
        assert!(text.contains("source: database"));
    }

    #[test]
    fn test_agent_mcp_input_deserialization() {
        let json = json!({
            "query": "What is 2+2?",
            "context": {"mode": "math"},
            "history": ["previous query"],
            "max_tokens": 100
        });

        let input: AgentMcpInput = serde_json::from_value(json).unwrap();
        assert_eq!(input.query, "What is 2+2?");
        assert_eq!(input.context.get("mode").unwrap(), "math");
        assert_eq!(input.history.len(), 1);
        assert_eq!(input.max_tokens, Some(100));
    }

    #[test]
    fn test_agent_mcp_input_minimal() {
        let json = json!({"query": "simple query"});
        let input: AgentMcpInput = serde_json::from_value(json).unwrap();

        assert_eq!(input.query, "simple query");
        assert!(input.context.is_empty());
        assert!(input.history.is_empty());
        assert!(input.max_tokens.is_none());
    }
}