use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::sync::Arc;
use tracing::{debug, info};
use crate::error::McpError;
use crate::protocol::{CallToolResult, McpTool, ToolContent};
use crate::server::ToolHandler;
/// Arguments accepted by a graph tool call.
///
/// `GraphMcpHandler::execute` deserializes the raw tool-call JSON into this
/// type before invoking the graph handler.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphMcpInput {
/// Initial graph state data handed to the graph handler.
pub input: serde_json::Value,
/// Iteration limit for cyclic graphs; `None` when the caller omits the field.
#[serde(default)]
pub max_iterations: Option<u32>,
}
/// Record of a single node run reported by a graph handler.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeExecution {
/// Identifier of the node that ran.
pub node_id: String,
/// Iteration the run belongs to (rendered as `[iter N]` in the text response).
pub iteration: u32,
/// Time spent in this node, in milliseconds.
pub duration_ms: u64,
/// Optional summary of the node's output; the text response shows
/// `(no summary)` when this is `None`.
pub output_summary: Option<String>,
}
/// Successful outcome of a graph run, as returned by the graph handler.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphMcpOutput {
/// Result payload produced by the graph; echoed in the structured output.
pub result: serde_json::Value,
/// Status string reported by the handler (the tests use `"completed"`).
pub status: String,
/// Node runs reported by the handler, in the order the handler supplied them.
pub nodes_executed: Vec<NodeExecution>,
/// Number of iterations reported by the handler.
pub iterations: u32,
/// Total duration of the run, in milliseconds.
pub duration_ms: u64,
}
/// Presentation options for a graph exposed as an MCP tool.
#[derive(Debug, Clone)]
pub struct GraphMcpConfig {
    /// Text prepended to the graph name to form the advertised tool name.
    pub name_prefix: String,
    /// Whether the text response lists every executed node.
    pub include_node_details: bool,
}

impl Default for GraphMcpConfig {
    /// Returns the stock configuration: tool names are prefixed with
    /// `graph_` and per-node details are included in responses.
    fn default() -> Self {
        let name_prefix = String::from("graph_");
        GraphMcpConfig {
            include_node_details: true,
            name_prefix,
        }
    }
}
/// Shared, type-erased graph callback.
///
/// The callback takes a [`GraphMcpInput`] and returns a boxed, pinned `Send`
/// future that resolves to either a [`GraphMcpOutput`] or an error message.
/// The `Send + Sync` bounds and the `Arc` wrapper allow the same callback to
/// be shared across threads.
pub type GraphHandlerFn = Arc<
dyn Fn(
GraphMcpInput,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<GraphMcpOutput, String>> + Send>,
> + Send
+ Sync,
>;
/// MCP tool handler that delegates execution to a graph callback.
///
/// Instances are created through [`GraphMcpHandler::builder`].
pub struct GraphMcpHandler {
// Advertised tool name; the builder stores it with the configured prefix applied.
name: String,
// Human-readable description used in the tool definition.
description: String,
// Capability labels appended to the description in the tool definition.
capabilities: Vec<String>,
// Callback that actually runs the graph.
handler: GraphHandlerFn,
// Presentation options used when formatting the response text.
config: GraphMcpConfig,
}
impl GraphMcpHandler {
    /// Starts building a handler for the graph called `name`.
    ///
    /// The builder applies the configured name prefix when the handler is
    /// finalized.
    pub fn builder(name: impl Into<String>) -> GraphMcpHandlerBuilder {
        GraphMcpHandlerBuilder::new(name)
    }

    /// Returns the tool name stored on this handler.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the capability labels registered for this graph.
    pub fn capabilities(&self) -> &[String] {
        self.capabilities.as_slice()
    }
}
#[async_trait]
impl ToolHandler for GraphMcpHandler {
    /// Builds the MCP tool definition advertised for this graph.
    ///
    /// The input schema requires an `input` object and accepts an optional
    /// integer `max_iterations`. Registered capabilities are appended to the
    /// description as a comma-separated `Capabilities:` list.
    fn definition(&self) -> McpTool {
        let schema = json!({
            "type": "object",
            "properties": {
                "input": {
                    "type": "object",
                    "description": "Initial graph state data"
                },
                "max_iterations": {
                    "type": "integer",
                    "description": "Iteration limit for cyclic graphs"
                }
            },
            "required": ["input"]
        });

        // The builder leaves the description empty unless one is supplied.
        // In that case emit the capability line alone, so the advertised
        // description does not begin with two stray newlines.
        let description = match (self.description.is_empty(), self.capabilities.is_empty()) {
            (_, true) => self.description.clone(),
            (true, false) => format!("Capabilities: {}", self.capabilities.join(", ")),
            (false, false) => format!(
                "{}\n\nCapabilities: {}",
                self.description,
                self.capabilities.join(", ")
            ),
        };

        McpTool {
            name: self.name.clone(),
            description: Some(description),
            input_schema: schema,
        }
    }

    /// Runs the graph callback with the supplied tool-call arguments.
    ///
    /// On success the result carries two text items: a human-readable summary
    /// followed by a pretty-printed JSON digest (status, iteration count,
    /// duration, executed-node count and the result payload).
    ///
    /// # Errors
    ///
    /// Returns `McpError::InvalidParams` when `arguments` cannot be
    /// deserialized into [`GraphMcpInput`]. A failure reported by the graph
    /// callback is *not* an `Err`: it is returned as `Ok` with
    /// `is_error: true` and the message in the content.
    async fn execute(&self, arguments: serde_json::Value) -> Result<CallToolResult, McpError> {
        debug!(tool = %self.name, "Executing graph MCP handler");

        let input: GraphMcpInput = serde_json::from_value(arguments)
            .map_err(|e| McpError::InvalidParams(format!("Invalid input: {}", e)))?;

        info!(
            tool = %self.name,
            max_iterations = ?input.max_iterations,
            "Graph executing"
        );

        match (self.handler)(input).await {
            Ok(output) => {
                let response_text = build_success_response(&output, &self.config);
                // Compact digest: only the number of executed nodes is
                // included here, not the per-node records.
                let structured = json!({
                    "status": output.status,
                    "iterations": output.iterations,
                    "duration_ms": output.duration_ms,
                    "nodes_executed_count": output.nodes_executed.len(),
                    "result": output.result,
                });
                Ok(CallToolResult {
                    content: vec![
                        ToolContent::text(response_text),
                        ToolContent::text(format!(
                            "\n---\nStructured output: {}",
                            serde_json::to_string_pretty(&structured).unwrap_or_default()
                        )),
                    ],
                    is_error: false,
                })
            }
            // Graph failures are reported inside the tool result so the
            // caller receives the message rather than a protocol error.
            Err(e) => Ok(CallToolResult {
                content: vec![ToolContent::text(format!("Graph error: {}", e))],
                is_error: true,
            }),
        }
    }
}
fn build_success_response(output: &GraphMcpOutput, config: &GraphMcpConfig) -> String {
let mut parts = vec![format!(
"Status: {} | Iterations: {} | Duration: {}ms",
output.status, output.iterations, output.duration_ms
)];
if config.include_node_details && !output.nodes_executed.is_empty() {
let nodes_str = output
.nodes_executed
.iter()
.map(|n| {
let summary = n
.output_summary
.as_deref()
.unwrap_or("(no summary)");
format!(
" - {} [iter {}] ({}ms): {}",
n.node_id, n.iteration, n.duration_ms, summary
)
})
.collect::<Vec<_>>()
.join("\n");
parts.push(format!("\n\nNodes executed:\n{}", nodes_str));
}
parts.join("")
}
/// Builder for [`GraphMcpHandler`].
///
/// Collects the name, description, capabilities and configuration; the
/// handler itself is produced by the `handler` method.
pub struct GraphMcpHandlerBuilder {
// Graph name without prefix; the prefix is applied when the handler is built.
name: String,
// Description text; empty until `description` is called.
description: String,
// Capability labels accumulated so far.
capabilities: Vec<String>,
// Presentation options; starts as `GraphMcpConfig::default()`.
config: GraphMcpConfig,
}
impl GraphMcpHandlerBuilder {
    /// Creates a builder for the graph called `name` with an empty
    /// description, no capabilities and the default configuration.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        let config = GraphMcpConfig::default();
        Self {
            name,
            description: String::new(),
            capabilities: Vec::new(),
            config,
        }
    }

    /// Sets the description, replacing any previous one.
    pub fn description(mut self, description: impl Into<String>) -> Self {
        self.description = description.into();
        self
    }

    /// Appends a single capability label.
    pub fn capability(mut self, capability: impl Into<String>) -> Self {
        self.capabilities.push(capability.into());
        self
    }

    /// Appends several capability labels, keeping those already registered.
    pub fn capabilities(mut self, new_capabilities: Vec<String>) -> Self {
        self.capabilities.extend(new_capabilities);
        self
    }

    /// Overrides the prefix prepended to the tool name.
    pub fn name_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.config.name_prefix = prefix.into();
        self
    }

    /// Chooses whether the text response lists every executed node.
    pub fn include_node_details(mut self, include: bool) -> Self {
        self.config.include_node_details = include;
        self
    }

    /// Replaces the whole configuration, including any prefix or detail
    /// setting chosen earlier on this builder.
    pub fn config(mut self, config: GraphMcpConfig) -> Self {
        self.config = config;
        self
    }

    /// Finishes the builder by attaching the graph callback.
    ///
    /// The resulting handler's tool name is the configured prefix followed by
    /// the builder's name. The callback is boxed so the handler can store it
    /// as a [`GraphHandlerFn`].
    pub fn handler<F, Fut>(self, handler: F) -> GraphMcpHandler
    where
        F: Fn(GraphMcpInput) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<GraphMcpOutput, String>> + Send + 'static,
    {
        let Self {
            name,
            description,
            capabilities,
            config,
        } = self;
        let prefixed_name = format!("{}{}", config.name_prefix, name);
        GraphMcpHandler {
            name: prefixed_name,
            description,
            capabilities,
            handler: Arc::new(move |input| Box::pin(handler(input))),
            config,
        }
    }
}
// Unit tests for input deserialization, tool definition, and handler execution.
#[cfg(test)]
mod tests {
use serde_json::json;
// A payload carrying both fields deserializes with `max_iterations` populated.
#[test]
fn test_graph_mcp_input_full_deserialization() {
use super::GraphMcpInput;
let json_val = json!({
"input": {"query": "test", "depth": 3},
"max_iterations": 10
});
let input: GraphMcpInput = serde_json::from_value(json_val).unwrap();
assert_eq!(input.input["query"], "test");
assert_eq!(input.input["depth"], 3);
assert_eq!(input.max_iterations, Some(10));
}
// The definition uses the default `graph_` prefix, folds capabilities into
// the description, and exposes the expected input schema.
#[test]
fn test_graph_handler_definition_and_schema() {
use super::*;
let handler = GraphMcpHandler::builder("pipeline")
.description("Data processing pipeline")
.capability("data_transform")
.capability("validation")
.handler(|_input: GraphMcpInput| async move {
Ok(GraphMcpOutput {
result: serde_json::json!({}),
status: "completed".to_string(),
nodes_executed: Vec::new(),
iterations: 0,
duration_ms: 0,
})
});
let def = handler.definition();
assert_eq!(def.name, "graph_pipeline");
let desc = def.description.unwrap();
assert!(desc.contains("Data processing pipeline"));
assert!(desc.contains("data_transform"));
assert!(desc.contains("validation"));
let schema = &def.input_schema;
assert_eq!(schema["type"], "object");
assert!(schema["properties"]["input"].is_object());
assert!(schema["properties"]["max_iterations"].is_object());
assert_eq!(schema["required"][0], "input");
}
// A custom prefix set on the builder replaces the default one in the tool name.
#[test]
fn test_graph_handler_custom_prefix() {
use super::*;
let handler = GraphMcpHandler::builder("workflow")
.description("A workflow")
.name_prefix("wf_")
.handler(|_input: GraphMcpInput| async move {
Ok(GraphMcpOutput {
result: serde_json::json!({}),
status: "completed".to_string(),
nodes_executed: Vec::new(),
iterations: 0,
duration_ms: 0,
})
});
let def = handler.definition();
assert_eq!(def.name, "wf_workflow");
}
// A successful run yields a text summary (with per-node lines) followed by
// a structured JSON digest as the second content item.
#[tokio::test]
async fn test_graph_handler_execution_with_mock() {
use super::*;
let handler = GraphMcpHandler::builder("data_pipeline")
.description("Three-node pipeline with a cycle")
.handler(|input: GraphMcpInput| async move {
let query = input.input["query"].as_str().unwrap_or("unknown");
let max_iter = input.max_iterations.unwrap_or(5);
Ok(GraphMcpOutput {
result: json!({
"query": query,
"answer": format!("Processed: {}", query),
"max_iterations_used": max_iter,
}),
status: "completed".to_string(),
nodes_executed: vec![
NodeExecution {
node_id: "node_a".to_string(),
iteration: 1,
duration_ms: 100,
output_summary: Some("Fetched data".to_string()),
},
NodeExecution {
node_id: "node_b".to_string(),
iteration: 1,
duration_ms: 200,
output_summary: Some("Transformed data".to_string()),
},
NodeExecution {
node_id: "node_c".to_string(),
iteration: 1,
duration_ms: 150,
output_summary: Some("Validated — needs retry".to_string()),
},
NodeExecution {
node_id: "node_b".to_string(),
iteration: 2,
duration_ms: 180,
output_summary: Some("Re-transformed".to_string()),
},
NodeExecution {
node_id: "node_c".to_string(),
iteration: 2,
duration_ms: 120,
output_summary: Some("Validated — passed".to_string()),
},
],
iterations: 2,
duration_ms: 750,
})
});
let result = handler
.execute(json!({
"input": {"query": "AI trends"},
"max_iterations": 5
}))
.await
.unwrap();
assert!(!result.is_error);
let text = result.content[0].as_text().unwrap();
assert!(text.contains("Status: completed"));
assert!(text.contains("Iterations: 2"));
assert!(text.contains("750ms"));
assert!(text.contains("node_a"));
assert!(text.contains("node_b"));
assert!(text.contains("node_c"));
assert!(text.contains("Fetched data"));
assert!(text.contains("Validated — passed"));
let structured_text = result.content[1].as_text().unwrap();
assert!(structured_text.contains("\"status\": \"completed\""));
assert!(structured_text.contains("\"iterations\": 2"));
assert!(structured_text.contains("750"));
assert!(structured_text.contains("\"nodes_executed_count\": 5"));
}
// A handler failure comes back as `Ok` with `is_error` set and the message
// embedded in the content, not as an `Err`.
#[tokio::test]
async fn test_graph_handler_error_returns_is_error() {
use super::*;
let handler = GraphMcpHandler::builder("failing_graph")
.description("A graph that fails")
.handler(|_: GraphMcpInput| async move {
Err("Node 'validate' failed: timeout after 30s".to_string())
});
let result = handler
.execute(json!({"input": {"data": "test"}}))
.await
.unwrap();
assert!(result.is_error);
let text = result.content[0].as_text().unwrap();
assert!(text.contains("Graph error"));
assert!(text.contains("timeout after 30s"));
}
// Arguments missing the required `input` field are rejected with an `Err`.
#[tokio::test]
async fn test_graph_handler_invalid_input_returns_error() {
use super::*;
let handler = GraphMcpHandler::builder("strict_graph")
.description("Graph with strict input")
.handler(|_: GraphMcpInput| async move {
Ok(GraphMcpOutput {
result: json!({}),
status: "completed".to_string(),
nodes_executed: Vec::new(),
iterations: 0,
duration_ms: 0,
})
});
let result = handler.execute(json!({"max_iterations": 5})).await;
assert!(result.is_err());
}
// Omitting `max_iterations` leaves it as `None`.
#[test]
fn test_graph_mcp_input_minimal_deserialization() {
use super::GraphMcpInput;
let json_val = json!({"input": {"key": "value"}});
let input: GraphMcpInput = serde_json::from_value(json_val).unwrap();
assert_eq!(input.input["key"], "value");
assert!(input.max_iterations.is_none());
}
}