1use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use serde_json::json;
9use std::sync::Arc;
10use tracing::{debug, info};
11
12use crate::error::McpError;
13use crate::protocol::{CallToolResult, McpTool, ToolContent};
14use crate::server::ToolHandler;
15
/// Arguments accepted by a graph-backed MCP tool.
///
/// Deserialized from the `arguments` value of a tool call; a value that does
/// not fit this shape is rejected before the graph runs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphMcpInput {
    /// Initial graph state handed to the graph runner. Any JSON value
    /// deserializes here; the advertised tool schema asks clients for an object.
    pub input: serde_json::Value,
    /// Optional iteration cap for cyclic graphs; `None` when omitted.
    /// The adapter only logs this value — applying (and defaulting) the limit
    /// is up to the handler closure.
    #[serde(default)]
    pub max_iterations: Option<u32>,
}
29
/// One node run reported back by a graph handler.
///
/// In cyclic graphs the same `node_id` may appear several times, once per
/// iteration in which it ran.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeExecution {
    /// Identifier of the node that ran.
    pub node_id: String,
    /// Iteration number during which this node ran.
    pub iteration: u32,
    /// Time spent in the node, in milliseconds, as reported by the handler.
    pub duration_ms: u64,
    /// Optional short description of the node's output; the text summary
    /// shows "(no summary)" when this is `None`.
    pub output_summary: Option<String>,
}
42
/// Successful result of a graph run, produced by the handler closure.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphMcpOutput {
    /// Final result payload; copied verbatim under `"result"` in the
    /// structured output block.
    pub result: serde_json::Value,
    /// Free-form completion status (e.g. `"completed"`), rendered as
    /// `Status: ...` in the text summary.
    pub status: String,
    /// Per-node execution trace, rendered in the order given here.
    pub nodes_executed: Vec<NodeExecution>,
    /// Number of graph iterations performed.
    pub iterations: u32,
    /// Total run time in milliseconds, as reported by the handler.
    pub duration_ms: u64,
}
57
/// Naming and presentation settings for a [`GraphMcpHandler`].
#[derive(Debug, Clone)]
pub struct GraphMcpConfig {
    /// Prefix prepended to the builder name to form the tool name
    /// (default `"graph_"`).
    pub name_prefix: String,
    /// Whether the text response lists every executed node (default `true`).
    pub include_node_details: bool,
}
70
71impl Default for GraphMcpConfig {
72 fn default() -> Self {
73 Self {
74 name_prefix: "graph_".to_string(),
75 include_node_details: true,
76 }
77 }
78}
79
/// Type-erased async graph runner stored inside a [`GraphMcpHandler`].
///
/// Receives the parsed [`GraphMcpInput`] and returns a pinned, boxed `Send`
/// future resolving to either the graph output or an error message (the
/// message is reported to the client as `Graph error: ...`). The `Arc` and
/// `Send + Sync` bounds allow the runner to be shared across threads.
pub type GraphHandlerFn = Arc<
    dyn Fn(
        GraphMcpInput,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<GraphMcpOutput, String>> + Send>,
    > + Send
    + Sync,
>;
93
94pub struct GraphMcpHandler {
96 name: String,
97 description: String,
98 capabilities: Vec<String>,
99 handler: GraphHandlerFn,
100 config: GraphMcpConfig,
101}
102
103impl GraphMcpHandler {
104 pub fn builder(name: impl Into<String>) -> GraphMcpHandlerBuilder {
106 GraphMcpHandlerBuilder::new(name)
107 }
108
109 pub fn name(&self) -> &str {
111 &self.name
112 }
113
114 pub fn capabilities(&self) -> &[String] {
116 &self.capabilities
117 }
118}
119
120#[async_trait]
121impl ToolHandler for GraphMcpHandler {
122 fn definition(&self) -> McpTool {
123 let schema = json!({
124 "type": "object",
125 "properties": {
126 "input": {
127 "type": "object",
128 "description": "Initial graph state data"
129 },
130 "max_iterations": {
131 "type": "integer",
132 "description": "Iteration limit for cyclic graphs"
133 }
134 },
135 "required": ["input"]
136 });
137
138 let description = if self.capabilities.is_empty() {
139 self.description.clone()
140 } else {
141 format!(
142 "{}\n\nCapabilities: {}",
143 self.description,
144 self.capabilities.join(", ")
145 )
146 };
147
148 McpTool {
149 name: self.name.clone(),
150 description: Some(description),
151 input_schema: schema,
152 }
153 }
154
155 async fn execute(&self, arguments: serde_json::Value) -> Result<CallToolResult, McpError> {
156 debug!(tool = %self.name, "Executing graph MCP handler");
157
158 let input: GraphMcpInput = serde_json::from_value(arguments)
159 .map_err(|e| McpError::InvalidParams(format!("Invalid input: {}", e)))?;
160
161 info!(
162 tool = %self.name,
163 max_iterations = ?input.max_iterations,
164 "Graph executing"
165 );
166
167 let result = (self.handler)(input).await;
168
169 match result {
170 Ok(output) => {
171 let response_text = build_success_response(&output, &self.config);
172
173 let structured = json!({
174 "status": output.status,
175 "iterations": output.iterations,
176 "duration_ms": output.duration_ms,
177 "nodes_executed_count": output.nodes_executed.len(),
178 "result": output.result,
179 });
180
181 Ok(CallToolResult {
182 content: vec![
183 ToolContent::text(response_text),
184 ToolContent::text(format!(
185 "\n---\nStructured output: {}",
186 serde_json::to_string_pretty(&structured).unwrap_or_default()
187 )),
188 ],
189 is_error: false,
190 })
191 }
192 Err(e) => Ok(CallToolResult {
193 content: vec![ToolContent::text(format!("Graph error: {}", e))],
194 is_error: true,
195 }),
196 }
197 }
198}
199
200fn build_success_response(output: &GraphMcpOutput, config: &GraphMcpConfig) -> String {
202 let mut parts = vec![format!(
203 "Status: {} | Iterations: {} | Duration: {}ms",
204 output.status, output.iterations, output.duration_ms
205 )];
206
207 if config.include_node_details && !output.nodes_executed.is_empty() {
208 let nodes_str = output
209 .nodes_executed
210 .iter()
211 .map(|n| {
212 let summary = n
213 .output_summary
214 .as_deref()
215 .unwrap_or("(no summary)");
216 format!(
217 " - {} [iter {}] ({}ms): {}",
218 n.node_id, n.iteration, n.duration_ms, summary
219 )
220 })
221 .collect::<Vec<_>>()
222 .join("\n");
223 parts.push(format!("\n\nNodes executed:\n{}", nodes_str));
224 }
225
226 parts.join("")
227}
228
229pub struct GraphMcpHandlerBuilder {
235 name: String,
236 description: String,
237 capabilities: Vec<String>,
238 config: GraphMcpConfig,
239}
240
241impl GraphMcpHandlerBuilder {
242 pub fn new(name: impl Into<String>) -> Self {
243 Self {
244 name: name.into(),
245 description: String::new(),
246 capabilities: Vec::new(),
247 config: GraphMcpConfig::default(),
248 }
249 }
250
251 pub fn description(self, description: impl Into<String>) -> Self {
252 Self {
253 description: description.into(),
254 ..self
255 }
256 }
257
258 pub fn capability(self, capability: impl Into<String>) -> Self {
259 let mut capabilities = self.capabilities;
260 capabilities.push(capability.into());
261 Self {
262 capabilities,
263 ..self
264 }
265 }
266
267 pub fn capabilities(self, new_capabilities: Vec<String>) -> Self {
268 let mut capabilities = self.capabilities;
269 capabilities.extend(new_capabilities);
270 Self {
271 capabilities,
272 ..self
273 }
274 }
275
276 pub fn name_prefix(self, prefix: impl Into<String>) -> Self {
277 Self {
278 config: GraphMcpConfig {
279 name_prefix: prefix.into(),
280 ..self.config
281 },
282 ..self
283 }
284 }
285
286 pub fn include_node_details(self, include: bool) -> Self {
287 Self {
288 config: GraphMcpConfig {
289 include_node_details: include,
290 ..self.config
291 },
292 ..self
293 }
294 }
295
296 pub fn config(self, config: GraphMcpConfig) -> Self {
297 Self { config, ..self }
298 }
299
300 pub fn handler<F, Fut>(self, handler: F) -> GraphMcpHandler
302 where
303 F: Fn(GraphMcpInput) -> Fut + Send + Sync + 'static,
304 Fut: std::future::Future<Output = Result<GraphMcpOutput, String>> + Send + 'static,
305 {
306 let tool_name = format!("{}{}", self.config.name_prefix, self.name);
307
308 GraphMcpHandler {
309 name: tool_name,
310 description: self.description,
311 capabilities: self.capabilities,
312 handler: Arc::new(move |input| Box::pin(handler(input))),
313 config: self.config,
314 }
315 }
316}
317
#[cfg(test)]
mod tests {
    use serde_json::json;

    #[test]
    fn test_graph_mcp_input_full_deserialization() {
        use super::GraphMcpInput;

        let json_val = json!({
            "input": {"query": "test", "depth": 3},
            "max_iterations": 10
        });

        let input: GraphMcpInput = serde_json::from_value(json_val).unwrap();
        assert_eq!(input.input["query"], "test");
        assert_eq!(input.input["depth"], 3);
        assert_eq!(input.max_iterations, Some(10));
    }

    #[test]
    fn test_graph_handler_definition_and_schema() {
        use super::*;

        let handler = GraphMcpHandler::builder("pipeline")
            .description("Data processing pipeline")
            .capability("data_transform")
            .capability("validation")
            .handler(|_input: GraphMcpInput| async move {
                Ok(GraphMcpOutput {
                    result: serde_json::json!({}),
                    status: "completed".to_string(),
                    nodes_executed: Vec::new(),
                    iterations: 0,
                    duration_ms: 0,
                })
            });

        let def = handler.definition();
        assert_eq!(def.name, "graph_pipeline");
        let desc = def.description.unwrap();
        assert!(desc.contains("Data processing pipeline"));
        assert!(desc.contains("data_transform"));
        assert!(desc.contains("validation"));

        let schema = &def.input_schema;
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["input"].is_object());
        assert!(schema["properties"]["max_iterations"].is_object());
        assert_eq!(schema["required"][0], "input");
    }

    #[test]
    fn test_graph_handler_custom_prefix() {
        use super::*;

        let handler = GraphMcpHandler::builder("workflow")
            .description("A workflow")
            .name_prefix("wf_")
            .handler(|_input: GraphMcpInput| async move {
                Ok(GraphMcpOutput {
                    result: serde_json::json!({}),
                    status: "completed".to_string(),
                    nodes_executed: Vec::new(),
                    iterations: 0,
                    duration_ms: 0,
                })
            });

        let def = handler.definition();
        assert_eq!(def.name, "wf_workflow");
    }

    #[tokio::test]
    async fn test_graph_handler_execution_with_mock() {
        use super::*;

        let handler = GraphMcpHandler::builder("data_pipeline")
            .description("Three-node pipeline with a cycle")
            .handler(|input: GraphMcpInput| async move {
                let query = input.input["query"].as_str().unwrap_or("unknown");
                let max_iter = input.max_iterations.unwrap_or(5);

                Ok(GraphMcpOutput {
                    result: json!({
                        "query": query,
                        "answer": format!("Processed: {}", query),
                        "max_iterations_used": max_iter,
                    }),
                    status: "completed".to_string(),
                    nodes_executed: vec![
                        NodeExecution {
                            node_id: "node_a".to_string(),
                            iteration: 1,
                            duration_ms: 100,
                            output_summary: Some("Fetched data".to_string()),
                        },
                        NodeExecution {
                            node_id: "node_b".to_string(),
                            iteration: 1,
                            duration_ms: 200,
                            output_summary: Some("Transformed data".to_string()),
                        },
                        NodeExecution {
                            node_id: "node_c".to_string(),
                            iteration: 1,
                            duration_ms: 150,
                            output_summary: Some("Validated — needs retry".to_string()),
                        },
                        NodeExecution {
                            node_id: "node_b".to_string(),
                            iteration: 2,
                            duration_ms: 180,
                            output_summary: Some("Re-transformed".to_string()),
                        },
                        NodeExecution {
                            node_id: "node_c".to_string(),
                            iteration: 2,
                            duration_ms: 120,
                            output_summary: Some("Validated — passed".to_string()),
                        },
                    ],
                    iterations: 2,
                    duration_ms: 750,
                })
            });

        let result = handler
            .execute(json!({
                "input": {"query": "AI trends"},
                "max_iterations": 5
            }))
            .await
            .unwrap();

        assert!(!result.is_error);

        let text = result.content[0].as_text().unwrap();
        assert!(text.contains("Status: completed"));
        assert!(text.contains("Iterations: 2"));
        assert!(text.contains("750ms"));
        assert!(text.contains("node_a"));
        assert!(text.contains("node_b"));
        assert!(text.contains("node_c"));
        assert!(text.contains("Fetched data"));
        assert!(text.contains("Validated — passed"));

        let structured_text = result.content[1].as_text().unwrap();
        assert!(structured_text.contains("\"status\": \"completed\""));
        assert!(structured_text.contains("\"iterations\": 2"));
        assert!(structured_text.contains("750"));
        assert!(structured_text.contains("\"nodes_executed_count\": 5"));
    }

    #[tokio::test]
    async fn test_graph_handler_error_returns_is_error() {
        use super::*;

        let handler = GraphMcpHandler::builder("failing_graph")
            .description("A graph that fails")
            .handler(|_: GraphMcpInput| async move {
                Err("Node 'validate' failed: timeout after 30s".to_string())
            });

        let result = handler
            .execute(json!({"input": {"data": "test"}}))
            .await
            .unwrap();

        assert!(result.is_error);
        let text = result.content[0].as_text().unwrap();
        assert!(text.contains("Graph error"));
        assert!(text.contains("timeout after 30s"));
    }

    #[tokio::test]
    async fn test_graph_handler_invalid_input_returns_error() {
        use super::*;

        let handler = GraphMcpHandler::builder("strict_graph")
            .description("Graph with strict input")
            .handler(|_: GraphMcpInput| async move {
                Ok(GraphMcpOutput {
                    result: json!({}),
                    status: "completed".to_string(),
                    nodes_executed: Vec::new(),
                    iterations: 0,
                    duration_ms: 0,
                })
            });

        // `input` is required, so argument parsing must fail with InvalidParams
        // (not a tool-level error result) before the graph runs.
        let result = handler.execute(json!({"max_iterations": 5})).await;
        assert!(matches!(result, Err(McpError::InvalidParams(_))));
    }

    #[tokio::test]
    async fn test_graph_handler_hides_node_details_when_disabled() {
        use super::*;

        let handler = GraphMcpHandler::builder("quiet_graph")
            .description("Graph without node listing")
            .include_node_details(false)
            .handler(|_: GraphMcpInput| async move {
                Ok(GraphMcpOutput {
                    result: json!({}),
                    status: "completed".to_string(),
                    nodes_executed: vec![NodeExecution {
                        node_id: "hidden_node".to_string(),
                        iteration: 1,
                        duration_ms: 10,
                        output_summary: None,
                    }],
                    iterations: 1,
                    duration_ms: 10,
                })
            });

        let result = handler.execute(json!({"input": {}})).await.unwrap();

        assert!(!result.is_error);
        let text = result.content[0].as_text().unwrap();
        assert!(text.contains("Status: completed"));
        assert!(!text.contains("Nodes executed"));
        assert!(!text.contains("hidden_node"));
    }

    #[test]
    fn test_graph_mcp_input_minimal_deserialization() {
        use super::GraphMcpInput;

        let json_val = json!({"input": {"key": "value"}});
        let input: GraphMcpInput = serde_json::from_value(json_val).unwrap();

        assert_eq!(input.input["key"], "value");
        assert!(input.max_iterations.is_none());
    }
}