Skip to main content

cortexai_mcp/
graph_handler.rs

//! Graph-as-MCP-Tool
//!
//! Exposes cortex graph workflows (with cycles, conditionals, state) as MCP tools,
//! allowing external MCP clients to invoke graph executions as a single tool call.
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use serde_json::json;
9use std::sync::Arc;
10use tracing::{debug, info};
11
12use crate::error::McpError;
13use crate::protocol::{CallToolResult, McpTool, ToolContent};
14use crate::server::ToolHandler;
15
16// =============================================================================
17// Input / Output types
18// =============================================================================
19
20/// Input schema for graph MCP tools
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct GraphMcpInput {
23    /// Initial graph state data
24    pub input: serde_json::Value,
25    /// Iteration limit for cyclic graphs
26    #[serde(default)]
27    pub max_iterations: Option<u32>,
28}
29
30/// Per-node execution record
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct NodeExecution {
33    /// Identifier of the executed node
34    pub node_id: String,
35    /// Which iteration this execution occurred in
36    pub iteration: u32,
37    /// Duration of this node's execution in milliseconds
38    pub duration_ms: u64,
39    /// Optional summary of the node's output
40    pub output_summary: Option<String>,
41}
42
43/// Output structure for graph MCP responses
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct GraphMcpOutput {
46    /// Final graph state
47    pub result: serde_json::Value,
48    /// Execution status: completed, failed, max_iterations_reached
49    pub status: String,
50    /// Per-node execution log
51    pub nodes_executed: Vec<NodeExecution>,
52    /// How many iterations were needed
53    pub iterations: u32,
54    /// Total execution duration in milliseconds
55    pub duration_ms: u64,
56}
57
58// =============================================================================
59// Configuration
60// =============================================================================
61
/// Settings that shape a [`GraphMcpHandler`]'s tool name and responses.
#[derive(Clone, Debug)]
pub struct GraphMcpConfig {
    /// Prepended to the builder name to form the MCP tool name (e.g. `"graph_"`).
    pub name_prefix: String,
    /// When `true`, successful responses list every executed node.
    pub include_node_details: bool,
}
70
71impl Default for GraphMcpConfig {
72    fn default() -> Self {
73        Self {
74            name_prefix: "graph_".to_string(),
75            include_node_details: true,
76        }
77    }
78}
79
80// =============================================================================
81// Handler
82// =============================================================================
83
84/// Handler type for graph execution
85pub type GraphHandlerFn = Arc<
86    dyn Fn(
87            GraphMcpInput,
88        ) -> std::pin::Pin<
89            Box<dyn std::future::Future<Output = Result<GraphMcpOutput, String>> + Send>,
90        > + Send
91        + Sync,
92>;
93
94/// MCP ToolHandler that wraps a graph workflow
95pub struct GraphMcpHandler {
96    name: String,
97    description: String,
98    capabilities: Vec<String>,
99    handler: GraphHandlerFn,
100    config: GraphMcpConfig,
101}
102
103impl GraphMcpHandler {
104    /// Create a builder for fluent construction
105    pub fn builder(name: impl Into<String>) -> GraphMcpHandlerBuilder {
106        GraphMcpHandlerBuilder::new(name)
107    }
108
109    /// Get the tool name
110    pub fn name(&self) -> &str {
111        &self.name
112    }
113
114    /// Get the capabilities
115    pub fn capabilities(&self) -> &[String] {
116        &self.capabilities
117    }
118}
119
120#[async_trait]
121impl ToolHandler for GraphMcpHandler {
122    fn definition(&self) -> McpTool {
123        let schema = json!({
124            "type": "object",
125            "properties": {
126                "input": {
127                    "type": "object",
128                    "description": "Initial graph state data"
129                },
130                "max_iterations": {
131                    "type": "integer",
132                    "description": "Iteration limit for cyclic graphs"
133                }
134            },
135            "required": ["input"]
136        });
137
138        let description = if self.capabilities.is_empty() {
139            self.description.clone()
140        } else {
141            format!(
142                "{}\n\nCapabilities: {}",
143                self.description,
144                self.capabilities.join(", ")
145            )
146        };
147
148        McpTool {
149            name: self.name.clone(),
150            description: Some(description),
151            input_schema: schema,
152        }
153    }
154
155    async fn execute(&self, arguments: serde_json::Value) -> Result<CallToolResult, McpError> {
156        debug!(tool = %self.name, "Executing graph MCP handler");
157
158        let input: GraphMcpInput = serde_json::from_value(arguments)
159            .map_err(|e| McpError::InvalidParams(format!("Invalid input: {}", e)))?;
160
161        info!(
162            tool = %self.name,
163            max_iterations = ?input.max_iterations,
164            "Graph executing"
165        );
166
167        let result = (self.handler)(input).await;
168
169        match result {
170            Ok(output) => {
171                let response_text = build_success_response(&output, &self.config);
172
173                let structured = json!({
174                    "status": output.status,
175                    "iterations": output.iterations,
176                    "duration_ms": output.duration_ms,
177                    "nodes_executed_count": output.nodes_executed.len(),
178                    "result": output.result,
179                });
180
181                Ok(CallToolResult {
182                    content: vec![
183                        ToolContent::text(response_text),
184                        ToolContent::text(format!(
185                            "\n---\nStructured output: {}",
186                            serde_json::to_string_pretty(&structured).unwrap_or_default()
187                        )),
188                    ],
189                    is_error: false,
190                })
191            }
192            Err(e) => Ok(CallToolResult {
193                content: vec![ToolContent::text(format!("Graph error: {}", e))],
194                is_error: true,
195            }),
196        }
197    }
198}
199
200/// Build the human-readable success response text
201fn build_success_response(output: &GraphMcpOutput, config: &GraphMcpConfig) -> String {
202    let mut parts = vec![format!(
203        "Status: {} | Iterations: {} | Duration: {}ms",
204        output.status, output.iterations, output.duration_ms
205    )];
206
207    if config.include_node_details && !output.nodes_executed.is_empty() {
208        let nodes_str = output
209            .nodes_executed
210            .iter()
211            .map(|n| {
212                let summary = n
213                    .output_summary
214                    .as_deref()
215                    .unwrap_or("(no summary)");
216                format!(
217                    "  - {} [iter {}] ({}ms): {}",
218                    n.node_id, n.iteration, n.duration_ms, summary
219                )
220            })
221            .collect::<Vec<_>>()
222            .join("\n");
223        parts.push(format!("\n\nNodes executed:\n{}", nodes_str));
224    }
225
226    parts.join("")
227}
228
229// =============================================================================
230// Builder
231// =============================================================================
232
233/// Builder for GraphMcpHandler
234pub struct GraphMcpHandlerBuilder {
235    name: String,
236    description: String,
237    capabilities: Vec<String>,
238    config: GraphMcpConfig,
239}
240
241impl GraphMcpHandlerBuilder {
242    pub fn new(name: impl Into<String>) -> Self {
243        Self {
244            name: name.into(),
245            description: String::new(),
246            capabilities: Vec::new(),
247            config: GraphMcpConfig::default(),
248        }
249    }
250
251    pub fn description(self, description: impl Into<String>) -> Self {
252        Self {
253            description: description.into(),
254            ..self
255        }
256    }
257
258    pub fn capability(self, capability: impl Into<String>) -> Self {
259        let mut capabilities = self.capabilities;
260        capabilities.push(capability.into());
261        Self {
262            capabilities,
263            ..self
264        }
265    }
266
267    pub fn capabilities(self, new_capabilities: Vec<String>) -> Self {
268        let mut capabilities = self.capabilities;
269        capabilities.extend(new_capabilities);
270        Self {
271            capabilities,
272            ..self
273        }
274    }
275
276    pub fn name_prefix(self, prefix: impl Into<String>) -> Self {
277        Self {
278            config: GraphMcpConfig {
279                name_prefix: prefix.into(),
280                ..self.config
281            },
282            ..self
283        }
284    }
285
286    pub fn include_node_details(self, include: bool) -> Self {
287        Self {
288            config: GraphMcpConfig {
289                include_node_details: include,
290                ..self.config
291            },
292            ..self
293        }
294    }
295
296    pub fn config(self, config: GraphMcpConfig) -> Self {
297        Self { config, ..self }
298    }
299
300    /// Build with a handler function
301    pub fn handler<F, Fut>(self, handler: F) -> GraphMcpHandler
302    where
303        F: Fn(GraphMcpInput) -> Fut + Send + Sync + 'static,
304        Fut: std::future::Future<Output = Result<GraphMcpOutput, String>> + Send + 'static,
305    {
306        let tool_name = format!("{}{}", self.config.name_prefix, self.name);
307
308        GraphMcpHandler {
309            name: tool_name,
310            description: self.description,
311            capabilities: self.capabilities,
312            handler: Arc::new(move |input| Box::pin(handler(input))),
313            config: self.config,
314        }
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Output of a graph run that completed without executing any nodes.
    fn empty_completed_output() -> GraphMcpOutput {
        GraphMcpOutput {
            result: json!({}),
            status: "completed".to_string(),
            nodes_executed: Vec::new(),
            iterations: 0,
            duration_ms: 0,
        }
    }

    /// Node-log entry with a summary.
    fn node(id: &str, iteration: u32, duration_ms: u64, summary: &str) -> NodeExecution {
        NodeExecution {
            node_id: id.to_string(),
            iteration,
            duration_ms,
            output_summary: Some(summary.to_string()),
        }
    }

    #[test]
    fn test_graph_mcp_input_full_deserialization() {
        let json_val = json!({
            "input": {"query": "test", "depth": 3},
            "max_iterations": 10
        });

        let input: GraphMcpInput = serde_json::from_value(json_val).unwrap();
        assert_eq!(input.input["query"], "test");
        assert_eq!(input.input["depth"], 3);
        assert_eq!(input.max_iterations, Some(10));
    }

    #[test]
    fn test_graph_mcp_input_minimal_deserialization() {
        let json_val = json!({"input": {"key": "value"}});
        let input: GraphMcpInput = serde_json::from_value(json_val).unwrap();

        assert_eq!(input.input["key"], "value");
        assert!(input.max_iterations.is_none());
    }

    #[test]
    fn test_graph_handler_definition_and_schema() {
        let handler = GraphMcpHandler::builder("pipeline")
            .description("Data processing pipeline")
            .capability("data_transform")
            .capability("validation")
            .handler(|_input: GraphMcpInput| async move { Ok(empty_completed_output()) });

        let def = handler.definition();
        assert_eq!(def.name, "graph_pipeline");
        let desc = def.description.unwrap();
        assert!(desc.contains("Data processing pipeline"));
        assert!(desc.contains("data_transform"));
        assert!(desc.contains("validation"));

        let schema = &def.input_schema;
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["input"].is_object());
        assert!(schema["properties"]["max_iterations"].is_object());
        assert_eq!(schema["required"][0], "input");
    }

    #[test]
    fn test_graph_handler_description_without_capabilities() {
        let handler = GraphMcpHandler::builder("plain")
            .description("Just a graph")
            .handler(|_: GraphMcpInput| async move { Ok(empty_completed_output()) });

        assert!(handler.capabilities().is_empty());
        let def = handler.definition();
        assert_eq!(def.description.as_deref(), Some("Just a graph"));
    }

    #[test]
    fn test_graph_handler_capabilities_accumulate() {
        let handler = GraphMcpHandler::builder("multi")
            .description("Multi")
            .capability("a")
            .capabilities(vec!["b".to_string(), "c".to_string()])
            .handler(|_: GraphMcpInput| async move { Ok(empty_completed_output()) });

        assert_eq!(handler.name(), "graph_multi");
        assert_eq!(
            handler.capabilities().to_vec(),
            vec!["a".to_string(), "b".to_string(), "c".to_string()]
        );

        let desc = handler.definition().description.unwrap();
        assert_eq!(desc, "Multi\n\nCapabilities: a, b, c");
    }

    #[test]
    fn test_graph_handler_custom_prefix() {
        let handler = GraphMcpHandler::builder("workflow")
            .description("A workflow")
            .name_prefix("wf_")
            .handler(|_input: GraphMcpInput| async move { Ok(empty_completed_output()) });

        let def = handler.definition();
        assert_eq!(def.name, "wf_workflow");
    }

    #[test]
    fn test_graph_handler_config_override() {
        let handler = GraphMcpHandler::builder("flow")
            .config(GraphMcpConfig {
                name_prefix: "x_".to_string(),
                include_node_details: false,
            })
            .handler(|_: GraphMcpInput| async move { Ok(empty_completed_output()) });

        assert_eq!(handler.definition().name, "x_flow");
    }

    #[tokio::test]
    async fn test_graph_handler_execution_with_mock() {
        // Simulate a 3-node linear graph + 1 cycle (node_c loops back to node_b once)
        let handler = GraphMcpHandler::builder("data_pipeline")
            .description("Three-node pipeline with a cycle")
            .handler(|input: GraphMcpInput| async move {
                let query = input.input["query"].as_str().unwrap_or("unknown");
                let max_iter = input.max_iterations.unwrap_or(5);

                Ok(GraphMcpOutput {
                    result: json!({
                        "query": query,
                        "answer": format!("Processed: {}", query),
                        "max_iterations_used": max_iter,
                    }),
                    status: "completed".to_string(),
                    nodes_executed: vec![
                        node("node_a", 1, 100, "Fetched data"),
                        node("node_b", 1, 200, "Transformed data"),
                        node("node_c", 1, 150, "Validated — needs retry"),
                        node("node_b", 2, 180, "Re-transformed"),
                        node("node_c", 2, 120, "Validated — passed"),
                    ],
                    iterations: 2,
                    duration_ms: 750,
                })
            });

        let result = handler
            .execute(json!({
                "input": {"query": "AI trends"},
                "max_iterations": 5
            }))
            .await
            .unwrap();

        assert!(!result.is_error);

        let text = result.content[0].as_text().unwrap();
        assert!(text.contains("Status: completed"));
        assert!(text.contains("Iterations: 2"));
        assert!(text.contains("750ms"));
        assert!(text.contains("node_a"));
        assert!(text.contains("node_b"));
        assert!(text.contains("node_c"));
        assert!(text.contains("Fetched data"));
        assert!(text.contains("Validated — passed"));

        // Verify structured output
        let structured_text = result.content[1].as_text().unwrap();
        assert!(structured_text.contains("\"status\": \"completed\""));
        assert!(structured_text.contains("\"iterations\": 2"));
        assert!(structured_text.contains("750"));
        assert!(structured_text.contains("\"nodes_executed_count\": 5"));
    }

    #[tokio::test]
    async fn test_graph_handler_without_node_details() {
        let handler = GraphMcpHandler::builder("quiet")
            .include_node_details(false)
            .handler(|_: GraphMcpInput| async move {
                Ok(GraphMcpOutput {
                    nodes_executed: vec![node("only_node", 1, 10, "ran")],
                    iterations: 1,
                    duration_ms: 10,
                    ..empty_completed_output()
                })
            });

        let result = handler.execute(json!({"input": {}})).await.unwrap();
        assert!(!result.is_error);

        let text = result.content[0].as_text().unwrap();
        assert!(text.contains("Status: completed | Iterations: 1 | Duration: 10ms"));
        assert!(!text.contains("Nodes executed"));
        assert!(!text.contains("only_node"));

        // The structured part still reports how many nodes ran.
        let structured_text = result.content[1].as_text().unwrap();
        assert!(structured_text.contains("\"nodes_executed_count\": 1"));
    }

    #[tokio::test]
    async fn test_graph_handler_error_returns_is_error() {
        let handler = GraphMcpHandler::builder("failing_graph")
            .description("A graph that fails")
            .handler(|_: GraphMcpInput| async move {
                Err("Node 'validate' failed: timeout after 30s".to_string())
            });

        let result = handler
            .execute(json!({"input": {"data": "test"}}))
            .await
            .unwrap();

        assert!(result.is_error);
        let text = result.content[0].as_text().unwrap();
        assert!(text.contains("Graph error"));
        assert!(text.contains("timeout after 30s"));
    }

    #[tokio::test]
    async fn test_graph_handler_invalid_input_returns_error() {
        let handler = GraphMcpHandler::builder("strict_graph")
            .description("Graph with strict input")
            .handler(|_: GraphMcpInput| async move { Ok(empty_completed_output()) });

        // Missing required "input" field
        let result = handler.execute(json!({"max_iterations": 5})).await;
        assert!(matches!(result, Err(McpError::InvalidParams(_))));

        // A negative iteration limit does not fit the u32 field
        let result = handler
            .execute(json!({"input": {}, "max_iterations": -1}))
            .await;
        assert!(matches!(result, Err(McpError::InvalidParams(_))));
    }
}