1use async_trait::async_trait;
60use serde::{Deserialize, Serialize};
61use serde_json::json;
62use std::collections::HashMap;
63use std::sync::Arc;
64use tracing::{debug, info};
65
66use crate::error::McpError;
67use crate::protocol::{CallToolResult, McpTool, ToolContent};
68use crate::server::ToolHandler;
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct AgentMcpInput {
73 pub query: String,
75 #[serde(default)]
77 pub context: HashMap<String, String>,
78 #[serde(default)]
80 pub history: Vec<String>,
81 #[serde(default)]
83 pub max_tokens: Option<usize>,
84}
85
86#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct AgentMcpOutput {
89 pub content: String,
91 pub success: bool,
93 pub confidence: f32,
95 #[serde(default)]
97 pub metadata: HashMap<String, String>,
98 pub duration_ms: u64,
100 #[serde(default)]
102 pub tools_used: Vec<String>,
103}
104
/// Options controlling how an agent is exposed as an MCP tool.
#[derive(Clone, Debug)]
pub struct AgentMcpConfig {
    /// Prefix prepended to the agent name to form the tool name.
    pub name_prefix: String,
    /// When `true`, non-empty output metadata is appended to the text response.
    pub include_metadata: bool,
    /// When `true`, a non-empty list of used tools is appended to the text
    /// response.
    pub include_tools_used: bool,
}
115
116impl Default for AgentMcpConfig {
117 fn default() -> Self {
118 Self {
119 name_prefix: "agent_".to_string(),
120 include_metadata: true,
121 include_tools_used: true,
122 }
123 }
124}
125
126pub type AgentHandlerFn = Arc<
128 dyn Fn(
129 AgentMcpInput,
130 ) -> std::pin::Pin<
131 Box<dyn std::future::Future<Output = Result<AgentMcpOutput, String>> + Send>,
132 > + Send
133 + Sync,
134>;
135
136pub struct AgentMcpHandler {
138 name: String,
140 description: String,
142 capabilities: Vec<String>,
144 handler: AgentHandlerFn,
146 config: AgentMcpConfig,
148}
149
150impl AgentMcpHandler {
151 pub fn new<F, Fut>(name: impl Into<String>, description: impl Into<String>, handler: F) -> Self
153 where
154 F: Fn(AgentMcpInput) -> Fut + Send + Sync + 'static,
155 Fut: std::future::Future<Output = Result<AgentMcpOutput, String>> + Send + 'static,
156 {
157 let config = AgentMcpConfig::default();
158 let name_str = name.into();
159 let tool_name = format!("{}{}", config.name_prefix, name_str);
160
161 Self {
162 name: tool_name,
163 description: description.into(),
164 capabilities: Vec::new(),
165 handler: Arc::new(move |input| Box::pin(handler(input))),
166 config,
167 }
168 }
169
170 pub fn with_config(mut self, config: AgentMcpConfig) -> Self {
172 let base_name = self
174 .name
175 .strip_prefix(&self.config.name_prefix)
176 .unwrap_or(&self.name)
177 .to_string();
178 self.name = format!("{}{}", config.name_prefix, base_name);
179 self.config = config;
180 self
181 }
182
183 pub fn with_capability(mut self, capability: impl Into<String>) -> Self {
185 self.capabilities.push(capability.into());
186 self
187 }
188
189 pub fn with_capabilities(mut self, capabilities: Vec<String>) -> Self {
191 self.capabilities.extend(capabilities);
192 self
193 }
194
195 pub fn builder(name: impl Into<String>) -> AgentMcpHandlerBuilder {
197 AgentMcpHandlerBuilder::new(name)
198 }
199
200 pub fn name(&self) -> &str {
202 &self.name
203 }
204
205 pub fn capabilities(&self) -> &[String] {
207 &self.capabilities
208 }
209}
210
211#[async_trait]
212impl ToolHandler for AgentMcpHandler {
213 fn definition(&self) -> McpTool {
214 let schema = json!({
215 "type": "object",
216 "properties": {
217 "query": {
218 "type": "string",
219 "description": "The query or task for the agent"
220 },
221 "context": {
222 "type": "object",
223 "description": "Additional context as key-value pairs",
224 "additionalProperties": { "type": "string" }
225 },
226 "history": {
227 "type": "array",
228 "description": "Conversation history (optional)",
229 "items": { "type": "string" }
230 },
231 "max_tokens": {
232 "type": "integer",
233 "description": "Maximum tokens for response (optional hint)"
234 }
235 },
236 "required": ["query"]
237 });
238
239 let description = if self.capabilities.is_empty() {
241 self.description.clone()
242 } else {
243 format!(
244 "{}\n\nCapabilities: {}",
245 self.description,
246 self.capabilities.join(", ")
247 )
248 };
249
250 McpTool {
251 name: self.name.clone(),
252 description: Some(description),
253 input_schema: schema,
254 }
255 }
256
257 async fn execute(&self, arguments: serde_json::Value) -> Result<CallToolResult, McpError> {
258 debug!(tool = %self.name, "Executing agent MCP handler");
259
260 let input: AgentMcpInput = serde_json::from_value(arguments.clone())
262 .map_err(|e| McpError::InvalidParams(format!("Invalid input: {}", e)))?;
263
264 info!(
265 tool = %self.name,
266 query = %input.query,
267 context_keys = ?input.context.keys().collect::<Vec<_>>(),
268 "Agent executing query"
269 );
270
271 let result = (self.handler)(input).await;
273
274 match result {
275 Ok(output) => {
276 let mut response_parts = vec![output.content.clone()];
277
278 if self.config.include_metadata && !output.metadata.is_empty() {
280 let metadata_str = output
281 .metadata
282 .iter()
283 .map(|(k, v)| format!(" {}: {}", k, v))
284 .collect::<Vec<_>>()
285 .join("\n");
286 response_parts.push(format!("\n\nMetadata:\n{}", metadata_str));
287 }
288
289 if self.config.include_tools_used && !output.tools_used.is_empty() {
291 response_parts
292 .push(format!("\n\nTools used: {}", output.tools_used.join(", ")));
293 }
294
295 let response_text = response_parts.join("");
296
297 let structured_output = json!({
299 "success": output.success,
300 "confidence": output.confidence,
301 "duration_ms": output.duration_ms,
302 "metadata": output.metadata,
303 "tools_used": output.tools_used
304 });
305
306 Ok(CallToolResult {
307 content: vec![
308 ToolContent::text(response_text),
309 ToolContent::text(format!(
310 "\n---\nStructured output: {}",
311 serde_json::to_string_pretty(&structured_output).unwrap_or_default()
312 )),
313 ],
314 is_error: !output.success,
315 })
316 }
317 Err(e) => Ok(CallToolResult {
318 content: vec![ToolContent::text(format!("Agent error: {}", e))],
319 is_error: true,
320 }),
321 }
322 }
323}
324
325pub struct AgentMcpHandlerBuilder {
327 name: String,
328 description: String,
329 capabilities: Vec<String>,
330 config: AgentMcpConfig,
331}
332
333impl AgentMcpHandlerBuilder {
334 pub fn new(name: impl Into<String>) -> Self {
335 Self {
336 name: name.into(),
337 description: String::new(),
338 capabilities: Vec::new(),
339 config: AgentMcpConfig::default(),
340 }
341 }
342
343 pub fn description(mut self, description: impl Into<String>) -> Self {
344 self.description = description.into();
345 self
346 }
347
348 pub fn capability(mut self, capability: impl Into<String>) -> Self {
349 self.capabilities.push(capability.into());
350 self
351 }
352
353 pub fn capabilities(mut self, capabilities: Vec<String>) -> Self {
354 self.capabilities.extend(capabilities);
355 self
356 }
357
358 pub fn config(mut self, config: AgentMcpConfig) -> Self {
359 self.config = config;
360 self
361 }
362
363 pub fn name_prefix(mut self, prefix: impl Into<String>) -> Self {
364 self.config.name_prefix = prefix.into();
365 self
366 }
367
368 pub fn include_metadata(mut self, include: bool) -> Self {
369 self.config.include_metadata = include;
370 self
371 }
372
373 pub fn include_tools_used(mut self, include: bool) -> Self {
374 self.config.include_tools_used = include;
375 self
376 }
377
378 pub fn handler<F, Fut>(self, handler: F) -> AgentMcpHandler
380 where
381 F: Fn(AgentMcpInput) -> Fut + Send + Sync + 'static,
382 Fut: std::future::Future<Output = Result<AgentMcpOutput, String>> + Send + 'static,
383 {
384 let tool_name = format!("{}{}", self.config.name_prefix, self.name);
385
386 AgentMcpHandler {
387 name: tool_name,
388 description: self.description,
389 capabilities: self.capabilities,
390 handler: Arc::new(move |input| Box::pin(handler(input))),
391 config: self.config,
392 }
393 }
394}
395
396pub fn simple_agent<F, Fut>(
398 name: impl Into<String>,
399 description: impl Into<String>,
400 handler: F,
401) -> AgentMcpHandler
402where
403 F: Fn(String) -> Fut + Send + Sync + 'static,
404 Fut: std::future::Future<Output = Result<String, String>> + Send + 'static,
405{
406 let handler = Arc::new(handler);
407 AgentMcpHandler::builder(name)
408 .description(description)
409 .handler(move |input: AgentMcpInput| {
410 let h = handler.clone();
411 async move {
412 let start = std::time::Instant::now();
413 match h(input.query).await {
414 Ok(content) => Ok(AgentMcpOutput {
415 content,
416 success: true,
417 confidence: 1.0,
418 metadata: HashMap::new(),
419 duration_ms: start.elapsed().as_millis() as u64,
420 tools_used: Vec::new(),
421 }),
422 Err(e) => Err(e),
423 }
424 }
425 })
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a successful output with full confidence, no metadata and no
    /// tools.
    fn plain_output(content: impl Into<String>, duration_ms: u64) -> AgentMcpOutput {
        AgentMcpOutput {
            content: content.into(),
            success: true,
            confidence: 1.0,
            metadata: HashMap::new(),
            duration_ms,
            tools_used: Vec::new(),
        }
    }

    // The builder applies the default prefix and the handler echoes the query.
    #[tokio::test]
    async fn test_agent_mcp_handler_basic() {
        let handler = AgentMcpHandler::builder("test_agent")
            .description("A test agent")
            .handler(|input: AgentMcpInput| async move {
                Ok(AgentMcpOutput {
                    confidence: 0.95,
                    tools_used: vec![String::from("tool1")],
                    ..plain_output(format!("Processed: {}", input.query), 100)
                })
            });

        let def = handler.definition();
        assert_eq!(def.name, "agent_test_agent");
        let description = def.description.unwrap();
        assert!(description.contains("test agent"));

        let arguments = json!({"query": "Hello world"});
        let result = handler.execute(arguments).await.unwrap();

        assert!(!result.is_error);
        let text = result.content[0].as_text().unwrap();
        assert!(text.contains("Processed: Hello world"));
    }

    // Context entries from the JSON arguments reach the handler.
    #[tokio::test]
    async fn test_agent_mcp_handler_with_context() {
        let handler = AgentMcpHandler::builder("context_agent")
            .description("Agent that uses context")
            .handler(|input: AgentMcpInput| async move {
                let name = input.context.get("name").cloned().unwrap_or_default();
                Ok(plain_output(format!("Hello, {}!", name), 50))
            });

        let arguments = json!({
            "query": "greet",
            "context": {"name": "World"}
        });
        let result = handler.execute(arguments).await.unwrap();

        let text = result.content[0].as_text().unwrap();
        assert!(text.contains("Hello, World!"));
    }

    // A handler error is reported as an error result, not as `Err`.
    #[tokio::test]
    async fn test_agent_mcp_handler_error() {
        let handler = AgentMcpHandler::builder("failing_agent")
            .description("Agent that fails")
            .handler(|_: AgentMcpInput| async move { Err(String::from("Intentional failure")) });

        let arguments = json!({"query": "test"});
        let result = handler.execute(arguments).await.unwrap();

        assert!(result.is_error);
        let text = result.content[0].as_text().unwrap();
        assert!(text.contains("Agent error"));
    }

    // Registered capabilities show up in the advertised description.
    #[tokio::test]
    async fn test_agent_mcp_handler_capabilities() {
        let handler = AgentMcpHandler::builder("capable_agent")
            .description("Agent with capabilities")
            .capability("math")
            .capability("science")
            .handler(|_: AgentMcpInput| async move { Ok(plain_output("OK", 10)) });

        let desc = handler.definition().description.unwrap();
        for expected in ["math", "science"].iter() {
            assert!(desc.contains(expected));
        }
    }

    // `simple_agent` forwards the query string and wraps the reply.
    #[tokio::test]
    async fn test_simple_agent_helper() {
        let handler = simple_agent("simple", "A simple agent", |query: String| async move {
            Ok(format!("Echo: {}", query))
        });

        let arguments = json!({"query": "test message"});
        let result = handler.execute(arguments).await.unwrap();

        assert!(!result.is_error);
        let text = result.content[0].as_text().unwrap();
        assert!(text.contains("Echo: test message"));
    }

    // A custom prefix replaces the default one in the tool name.
    #[tokio::test]
    async fn test_agent_mcp_handler_custom_prefix() {
        let handler = AgentMcpHandler::builder("custom")
            .description("Custom prefix agent")
            .name_prefix("ai_")
            .handler(|_: AgentMcpInput| async move { Ok(plain_output("OK", 10)) });

        assert_eq!(handler.definition().name, "ai_custom");
    }

    // Metadata entries are rendered into the first text item.
    #[tokio::test]
    async fn test_agent_mcp_handler_metadata_output() {
        let handler = AgentMcpHandler::builder("metadata_agent")
            .description("Agent with metadata")
            .include_metadata(true)
            .handler(|_: AgentMcpInput| async move {
                let metadata: HashMap<String, String> =
                    [("source", "database"), ("version", "1.0")]
                        .iter()
                        .map(|(k, v)| (k.to_string(), v.to_string()))
                        .collect();

                Ok(AgentMcpOutput {
                    content: String::from("Result with metadata"),
                    success: true,
                    confidence: 0.9,
                    metadata,
                    duration_ms: 200,
                    tools_used: vec![String::from("db_query")],
                })
            });

        let arguments = json!({"query": "test"});
        let result = handler.execute(arguments).await.unwrap();

        let text = result.content[0].as_text().unwrap();
        assert!(text.contains("Result with metadata"));
        assert!(text.contains("source: database"));
    }

    // All fields are picked up when present.
    #[test]
    fn test_agent_mcp_input_deserialization() {
        let payload = json!({
            "query": "What is 2+2?",
            "context": {"mode": "math"},
            "history": ["previous query"],
            "max_tokens": 100
        });

        let input: AgentMcpInput = serde_json::from_value(payload).unwrap();
        assert_eq!(input.query, "What is 2+2?");
        assert_eq!(input.context.get("mode").map(String::as_str), Some("math"));
        assert_eq!(input.history.len(), 1);
        assert_eq!(input.max_tokens, Some(100));
    }

    // Optional fields fall back to their empty defaults.
    #[test]
    fn test_agent_mcp_input_minimal() {
        let payload = json!({"query": "simple query"});
        let input: AgentMcpInput = serde_json::from_value(payload).unwrap();

        assert_eq!(input.query, "simple query");
        assert!(input.context.is_empty());
        assert!(input.history.is_empty());
        assert!(input.max_tokens.is_none());
    }
}