Skip to main content

fabryk_mcp_graph/
tools.rs

1//! MCP tools for graph queries.
2//!
3//! Provides `GraphTools` that implements `ToolRegistry` by delegating
4//! queries to `fabryk_graph` algorithms.
5
6use fabryk_mcp_core::error::McpErrorExt;
7use fabryk_mcp_core::model::{CallToolResult, Content, ErrorData, Tool};
8use fabryk_mcp_core::registry::{ToolRegistry, ToolResult};
9
10use fabryk_graph::{
11    EdgeInfo, GraphData, NeighborInfo, NodeSummary, PathStep, Relationship, calculate_centrality,
12    compute_stats, find_bridges, neighborhood, prerequisites_sorted, shortest_path, validate_graph,
13};
14use serde::Deserialize;
15use serde_json::{Value, json};
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19
20// ---------------------------------------------------------------------------
21// Helpers
22// ---------------------------------------------------------------------------
23
24fn json_schema(value: Value) -> Arc<serde_json::Map<String, Value>> {
25    match value {
26        Value::Object(map) => Arc::new(map),
27        _ => Arc::new(serde_json::Map::new()),
28    }
29}
30
31fn make_tool(name: &str, description: &str, schema: Value) -> Tool {
32    Tool::new(
33        name.to_string(),
34        description.to_string(),
35        json_schema(schema),
36    )
37}
38
39fn serialize_response<T: serde::Serialize>(value: &T) -> Result<CallToolResult, ErrorData> {
40    let json = serde_json::to_string_pretty(value)
41        .map_err(|e| ErrorData::internal_error(e.to_string(), None))?;
42    Ok(CallToolResult::success(vec![Content::text(json)]))
43}
44
45fn parse_relationship(s: &str) -> Relationship {
46    match s.to_lowercase().as_str() {
47        "prerequisite" => Relationship::Prerequisite,
48        "leads_to" | "leadsto" => Relationship::LeadsTo,
49        "relates_to" | "relatesto" => Relationship::RelatesTo,
50        "extends" => Relationship::Extends,
51        "introduces" => Relationship::Introduces,
52        "covers" => Relationship::Covers,
53        "variant_of" | "variantof" => Relationship::VariantOf,
54        other => Relationship::Custom(other.to_string()),
55    }
56}
57
58// ---------------------------------------------------------------------------
59// Argument types
60// ---------------------------------------------------------------------------
61
62/// Arguments for graph_related tool.
63#[derive(Debug, Deserialize)]
64pub struct RelatedArgs {
65    /// Node ID to find relations for.
66    pub id: String,
67    /// Optional relationship type filter.
68    pub relationship: Option<String>,
69    /// Maximum results.
70    pub limit: Option<usize>,
71}
72
73/// Arguments for graph_path tool.
74#[derive(Debug, Deserialize)]
75pub struct PathArgs {
76    /// Starting node ID.
77    pub from: String,
78    /// Target node ID.
79    pub to: String,
80}
81
82/// Arguments for graph_prerequisites tool.
83#[derive(Debug, Deserialize)]
84pub struct PrerequisitesArgs {
85    /// Target node ID.
86    pub id: String,
87}
88
89/// Arguments for graph_neighborhood tool.
90#[derive(Debug, Deserialize)]
91pub struct NeighborhoodArgs {
92    /// Center node ID.
93    pub id: String,
94    /// Hops from center (default 1).
95    pub radius: Option<usize>,
96    /// Optional relationship type filter.
97    pub relationship: Option<String>,
98}
99
100// ---------------------------------------------------------------------------
101// GraphTools
102// ---------------------------------------------------------------------------
103
104/// MCP tools for graph queries.
105///
106/// Generates eight tools:
107/// - `graph_related` — find related nodes
108/// - `graph_path` — shortest path between nodes
109/// - `graph_prerequisites` — learning order prerequisites
110/// - `graph_neighborhood` — N-hop neighborhood exploration
111/// - `graph_info` — graph statistics
112/// - `graph_validate` — structure validation
113/// - `graph_centrality` — most central/important nodes
114/// - `graph_bridges` — bridge nodes connecting different areas
115///
116/// # Example
117///
118/// ```rust,ignore
119/// use fabryk_graph::GraphData;
120/// use fabryk_mcp_graph::GraphTools;
121///
122/// let graph = fabryk_graph::load_graph("graph.json")?;
123/// let graph_tools = GraphTools::new(graph);
124/// ```
125pub struct GraphTools {
126    graph: Arc<RwLock<GraphData>>,
127    custom_names: HashMap<String, String>,
128    custom_descriptions: HashMap<String, String>,
129}
130
131impl GraphTools {
132    /// Slot key for the related nodes tool.
133    pub const SLOT_RELATED: &str = "graph_related";
134    /// Slot key for the shortest path tool.
135    pub const SLOT_PATH: &str = "graph_path";
136    /// Slot key for the prerequisites tool.
137    pub const SLOT_PREREQUISITES: &str = "graph_prerequisites";
138    /// Slot key for the neighborhood tool.
139    pub const SLOT_NEIGHBORHOOD: &str = "graph_neighborhood";
140    /// Slot key for the graph info/stats tool.
141    pub const SLOT_INFO: &str = "graph_info";
142    /// Slot key for the validation tool.
143    pub const SLOT_VALIDATE: &str = "graph_validate";
144    /// Slot key for the centrality tool.
145    pub const SLOT_CENTRALITY: &str = "graph_centrality";
146    /// Slot key for the bridges tool.
147    pub const SLOT_BRIDGES: &str = "graph_bridges";
148
149    /// Create new graph tools with owned graph data.
150    pub fn new(graph: GraphData) -> Self {
151        Self {
152            graph: Arc::new(RwLock::new(graph)),
153            custom_names: HashMap::new(),
154            custom_descriptions: HashMap::new(),
155        }
156    }
157
158    /// Create graph tools with a shared graph reference.
159    pub fn with_shared(graph: Arc<RwLock<GraphData>>) -> Self {
160        Self {
161            graph,
162            custom_names: HashMap::new(),
163            custom_descriptions: HashMap::new(),
164        }
165    }
166
167    /// Override tool names by slot key.
168    pub fn with_names(mut self, names: HashMap<String, String>) -> Self {
169        self.custom_names = names;
170        self
171    }
172
173    /// Override tool descriptions by slot key.
174    pub fn with_descriptions(mut self, descriptions: HashMap<String, String>) -> Self {
175        self.custom_descriptions = descriptions;
176        self
177    }
178
179    /// Update the graph data (e.g., after rebuild).
180    pub async fn update_graph(&self, graph: GraphData) {
181        let mut lock = self.graph.write().await;
182        *lock = graph;
183    }
184
185    fn tool_name(&self, slot: &str) -> String {
186        self.custom_names
187            .get(slot)
188            .cloned()
189            .unwrap_or_else(|| slot.to_string())
190    }
191
192    fn tool_description(&self, slot: &str, default: &str) -> String {
193        self.custom_descriptions
194            .get(slot)
195            .cloned()
196            .unwrap_or_else(|| default.to_string())
197    }
198}
199
200impl ToolRegistry for GraphTools {
201    fn tools(&self) -> Vec<Tool> {
202        vec![
203            make_tool(
204                &self.tool_name(Self::SLOT_RELATED),
205                &self.tool_description(Self::SLOT_RELATED, "Find nodes related to a given node"),
206                json!({
207                    "type": "object",
208                    "properties": {
209                        "id": {
210                            "type": "string",
211                            "description": "Node ID"
212                        },
213                        "relationship": {
214                            "type": "string",
215                            "description": "Filter by relationship type (e.g., prerequisite, relates_to)"
216                        },
217                        "limit": {
218                            "type": "integer",
219                            "description": "Maximum results"
220                        }
221                    },
222                    "required": ["id"]
223                }),
224            ),
225            make_tool(
226                &self.tool_name(Self::SLOT_PATH),
227                &self.tool_description(Self::SLOT_PATH, "Find the shortest path between two nodes"),
228                json!({
229                    "type": "object",
230                    "properties": {
231                        "from": {
232                            "type": "string",
233                            "description": "Starting node ID"
234                        },
235                        "to": {
236                            "type": "string",
237                            "description": "Target node ID"
238                        }
239                    },
240                    "required": ["from", "to"]
241                }),
242            ),
243            make_tool(
244                &self.tool_name(Self::SLOT_PREREQUISITES),
245                &self.tool_description(
246                    Self::SLOT_PREREQUISITES,
247                    "Get prerequisites for a node in learning order",
248                ),
249                json!({
250                    "type": "object",
251                    "properties": {
252                        "id": {
253                            "type": "string",
254                            "description": "Node ID"
255                        }
256                    },
257                    "required": ["id"]
258                }),
259            ),
260            make_tool(
261                &self.tool_name(Self::SLOT_NEIGHBORHOOD),
262                &self.tool_description(
263                    Self::SLOT_NEIGHBORHOOD,
264                    "Explore the neighborhood around a node",
265                ),
266                json!({
267                    "type": "object",
268                    "properties": {
269                        "id": {
270                            "type": "string",
271                            "description": "Center node ID"
272                        },
273                        "radius": {
274                            "type": "integer",
275                            "description": "Hops from center (default 1)"
276                        },
277                        "relationship": {
278                            "type": "string",
279                            "description": "Filter by relationship type"
280                        }
281                    },
282                    "required": ["id"]
283                }),
284            ),
285            make_tool(
286                &self.tool_name(Self::SLOT_INFO),
287                &self.tool_description(Self::SLOT_INFO, "Get graph statistics and overview"),
288                json!({
289                    "type": "object",
290                    "properties": {}
291                }),
292            ),
293            make_tool(
294                &self.tool_name(Self::SLOT_VALIDATE),
295                &self.tool_description(
296                    Self::SLOT_VALIDATE,
297                    "Validate graph structure and report issues",
298                ),
299                json!({
300                    "type": "object",
301                    "properties": {}
302                }),
303            ),
304            make_tool(
305                &self.tool_name(Self::SLOT_CENTRALITY),
306                &self.tool_description(Self::SLOT_CENTRALITY, "Get most central/important nodes"),
307                json!({
308                    "type": "object",
309                    "properties": {
310                        "limit": {
311                            "type": "integer",
312                            "description": "Number of results (default 10)"
313                        }
314                    }
315                }),
316            ),
317            make_tool(
318                &self.tool_name(Self::SLOT_BRIDGES),
319                &self.tool_description(
320                    Self::SLOT_BRIDGES,
321                    "Find bridge nodes that connect different areas",
322                ),
323                json!({
324                    "type": "object",
325                    "properties": {
326                        "limit": {
327                            "type": "integer",
328                            "description": "Number of results (default 10)"
329                        }
330                    }
331                }),
332            ),
333        ]
334    }
335
336    fn call(&self, name: &str, args: Value) -> Option<ToolResult> {
337        let graph = Arc::clone(&self.graph);
338
339        if name == self.tool_name(Self::SLOT_RELATED) {
340            return Some(Box::pin(async move {
341                let args: RelatedArgs = serde_json::from_value(args)
342                    .map_err(|e| ErrorData::invalid_params(e.to_string(), None))?;
343                let graph = graph.read().await;
344
345                let rel_filter = args
346                    .relationship
347                    .as_deref()
348                    .map(|r| vec![parse_relationship(r)]);
349
350                let result = neighborhood(&graph, &args.id, 1, rel_filter.as_deref())
351                    .map_err(|e| e.to_mcp_error())?;
352
353                let mut nodes: Vec<NodeSummary> =
354                    result.nodes.iter().map(NodeSummary::from).collect();
355
356                if let Some(limit) = args.limit {
357                    nodes.truncate(limit);
358                }
359
360                let count = nodes.len();
361                let response = json!({
362                    "source": NodeSummary::from(&result.center),
363                    "related": nodes,
364                    "count": count
365                });
366                serialize_response(&response)
367            }));
368        }
369
370        if name == self.tool_name(Self::SLOT_PATH) {
371            return Some(Box::pin(async move {
372                let args: PathArgs = serde_json::from_value(args)
373                    .map_err(|e| ErrorData::invalid_params(e.to_string(), None))?;
374                let graph = graph.read().await;
375
376                let result =
377                    shortest_path(&graph, &args.from, &args.to).map_err(|e| e.to_mcp_error())?;
378
379                if result.found {
380                    let path: Vec<PathStep> = result
381                        .path
382                        .iter()
383                        .enumerate()
384                        .map(|(i, node)| {
385                            let rel = result
386                                .edges
387                                .get(i)
388                                .map(|e| e.relationship.name().to_string());
389                            PathStep {
390                                node: NodeSummary::from(node),
391                                relationship_to_next: rel,
392                            }
393                        })
394                        .collect();
395
396                    let response = json!({
397                        "found": true,
398                        "path": path,
399                        "length": path.len(),
400                        "total_weight": result.total_weight
401                    });
402                    serialize_response(&response)
403                } else {
404                    let response = json!({
405                        "found": false,
406                        "message": format!("No path found from {} to {}", args.from, args.to)
407                    });
408                    serialize_response(&response)
409                }
410            }));
411        }
412
413        if name == self.tool_name(Self::SLOT_PREREQUISITES) {
414            return Some(Box::pin(async move {
415                let args: PrerequisitesArgs = serde_json::from_value(args)
416                    .map_err(|e| ErrorData::invalid_params(e.to_string(), None))?;
417                let graph = graph.read().await;
418
419                let result =
420                    prerequisites_sorted(&graph, &args.id).map_err(|e| e.to_mcp_error())?;
421
422                let prereqs: Vec<NodeSummary> =
423                    result.ordered.iter().map(NodeSummary::from).collect();
424
425                let count = prereqs.len();
426                let response = json!({
427                    "target": NodeSummary::from(&result.target),
428                    "prerequisites": prereqs,
429                    "count": count,
430                    "has_cycles": result.has_cycles
431                });
432                serialize_response(&response)
433            }));
434        }
435
436        if name == self.tool_name(Self::SLOT_NEIGHBORHOOD) {
437            return Some(Box::pin(async move {
438                let args: NeighborhoodArgs = serde_json::from_value(args)
439                    .map_err(|e| ErrorData::invalid_params(e.to_string(), None))?;
440                let graph = graph.read().await;
441
442                let radius = args.radius.unwrap_or(1);
443                let rel_filter = args
444                    .relationship
445                    .as_deref()
446                    .map(|r| vec![parse_relationship(r)]);
447
448                let result = neighborhood(&graph, &args.id, radius, rel_filter.as_deref())
449                    .map_err(|e| e.to_mcp_error())?;
450
451                let nodes: Vec<NeighborInfo> = result
452                    .nodes
453                    .iter()
454                    .map(|n| {
455                        let distance = result.distances.get(&n.id).copied().unwrap_or(0);
456                        NeighborInfo {
457                            node: NodeSummary::from(n),
458                            distance,
459                        }
460                    })
461                    .collect();
462
463                let edges: Vec<EdgeInfo> = result.edges.iter().map(EdgeInfo::from).collect();
464
465                let response = json!({
466                    "center": NodeSummary::from(&result.center),
467                    "radius": radius,
468                    "nodes": nodes,
469                    "edges": edges,
470                    "edge_count": edges.len()
471                });
472                serialize_response(&response)
473            }));
474        }
475
476        if name == self.tool_name(Self::SLOT_INFO) {
477            return Some(Box::pin(async move {
478                let graph = graph.read().await;
479                let stats = compute_stats(&graph);
480                serialize_response(&stats)
481            }));
482        }
483
484        if name == self.tool_name(Self::SLOT_VALIDATE) {
485            return Some(Box::pin(async move {
486                let graph = graph.read().await;
487                let result = validate_graph(&graph);
488                serialize_response(&result)
489            }));
490        }
491
492        if name == self.tool_name(Self::SLOT_CENTRALITY) {
493            return Some(Box::pin(async move {
494                let limit = args
495                    .get("limit")
496                    .and_then(|v| v.as_u64())
497                    .map(|n| n as usize)
498                    .unwrap_or(10);
499
500                let graph = graph.read().await;
501                let scores = calculate_centrality(&graph);
502
503                let top: Vec<_> = scores.into_iter().take(limit).collect();
504                serialize_response(&top)
505            }));
506        }
507
508        if name == self.tool_name(Self::SLOT_BRIDGES) {
509            return Some(Box::pin(async move {
510                let limit = args
511                    .get("limit")
512                    .and_then(|v| v.as_u64())
513                    .map(|n| n as usize)
514                    .unwrap_or(10);
515
516                let graph = graph.read().await;
517                let bridges = find_bridges(&graph, limit);
518
519                let summaries: Vec<NodeSummary> = bridges.iter().map(NodeSummary::from).collect();
520                serialize_response(&summaries)
521            }));
522        }
523
524        None
525    }
526}
527
528// ============================================================================
529// Tests
530// ============================================================================
531
#[cfg(test)]
mod tests {
    use super::*;
    use fabryk_graph::{Edge, Node};

    /// Builds a three-node chain: A -prerequisite-> B -relates_to-> C.
    fn make_test_graph() -> GraphData {
        let mut graph = GraphData::new();

        // Add nodes
        graph.add_node(Node::new("node-a", "Node A").with_category("alpha"));
        graph.add_node(Node::new("node-b", "Node B").with_category("beta"));
        graph.add_node(Node::new("node-c", "Node C").with_category("alpha"));

        // Add edges: A -> B (prerequisite), B -> C (relates_to)
        let _ = graph.add_edge(Edge::new("node-a", "node-b", Relationship::Prerequisite));
        let _ = graph.add_edge(Edge::new("node-b", "node-c", Relationship::RelatesTo));

        graph
    }

    // -- Tool creation tests ------------------------------------------------

    #[test]
    fn test_graph_tools_creation() {
        let tools = GraphTools::new(GraphData::new());
        assert_eq!(tools.tool_count(), 8);
    }

    #[test]
    fn test_graph_tools_names() {
        let tools = GraphTools::new(GraphData::new());
        let tool_list = tools.tools();
        let names: Vec<&str> = tool_list.iter().map(|t| t.name.as_ref()).collect();
        assert!(names.contains(&"graph_related"));
        assert!(names.contains(&"graph_path"));
        assert!(names.contains(&"graph_prerequisites"));
        assert!(names.contains(&"graph_neighborhood"));
        assert!(names.contains(&"graph_info"));
        assert!(names.contains(&"graph_validate"));
        assert!(names.contains(&"graph_centrality"));
        assert!(names.contains(&"graph_bridges"));
    }

    #[test]
    fn test_graph_tools_has_tool() {
        let tools = GraphTools::new(GraphData::new());
        assert!(tools.has_tool("graph_related"));
        assert!(tools.has_tool("graph_info"));
        assert!(!tools.has_tool("graph_delete"));
    }

    // -- Argument parsing tests ----------------------------------------------

    #[test]
    fn test_neighborhood_args_optional_fields() {
        let args: NeighborhoodArgs = serde_json::from_value(json!({"id": "node-a"})).unwrap();
        assert_eq!(args.id, "node-a");
        assert_eq!(args.radius, None);
        assert_eq!(args.relationship, None);
    }

    #[test]
    fn test_path_args_require_both_ends() {
        assert!(serde_json::from_value::<PathArgs>(json!({"from": "node-a"})).is_err());
        assert!(serde_json::from_value::<PathArgs>(json!({"to": "node-a"})).is_err());
    }

    // -- graph_info tests ---------------------------------------------------

    #[tokio::test]
    async fn test_graph_info_empty() {
        let tools = GraphTools::new(GraphData::new());
        let future = tools.call("graph_info", json!({})).unwrap();
        let result = future.await.unwrap();
        assert_eq!(result.is_error, Some(false));
    }

    #[tokio::test]
    async fn test_graph_info_with_data() {
        let tools = GraphTools::new(make_test_graph());
        let future = tools.call("graph_info", json!({})).unwrap();
        let result = future.await.unwrap();
        assert_eq!(result.is_error, Some(false));
    }

    // -- graph_validate tests -----------------------------------------------

    #[tokio::test]
    async fn test_graph_validate_empty() {
        let tools = GraphTools::new(GraphData::new());
        let future = tools.call("graph_validate", json!({})).unwrap();
        let result = future.await.unwrap();
        assert_eq!(result.is_error, Some(false));
    }

    #[tokio::test]
    async fn test_graph_validate_with_data() {
        let tools = GraphTools::new(make_test_graph());
        let future = tools.call("graph_validate", json!({})).unwrap();
        let result = future.await.unwrap();
        assert_eq!(result.is_error, Some(false));
    }

    // -- graph_related tests ------------------------------------------------

    #[tokio::test]
    async fn test_graph_related() {
        let tools = GraphTools::new(make_test_graph());
        let future = tools
            .call("graph_related", json!({"id": "node-a"}))
            .unwrap();
        let result = future.await.unwrap();
        assert_eq!(result.is_error, Some(false));
    }

    #[tokio::test]
    async fn test_graph_related_not_found() {
        let tools = GraphTools::new(make_test_graph());
        let future = tools
            .call("graph_related", json!({"id": "missing"}))
            .unwrap();
        let result = future.await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_graph_related_missing_id_is_error() {
        let tools = GraphTools::new(make_test_graph());
        let result = tools.call("graph_related", json!({})).unwrap().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_graph_related_with_limit() {
        let tools = GraphTools::new(make_test_graph());
        let future = tools
            .call("graph_related", json!({"id": "node-b", "limit": 1}))
            .unwrap();
        let result = future.await.unwrap();
        assert_eq!(result.is_error, Some(false));
    }

    // -- graph_path tests ---------------------------------------------------

    #[tokio::test]
    async fn test_graph_path() {
        let tools = GraphTools::new(make_test_graph());
        let future = tools
            .call("graph_path", json!({"from": "node-a", "to": "node-c"}))
            .unwrap();
        let result = future.await.unwrap();
        assert_eq!(result.is_error, Some(false));
    }

    #[tokio::test]
    async fn test_graph_path_not_found() {
        let tools = GraphTools::new(make_test_graph());
        let future = tools
            .call("graph_path", json!({"from": "node-c", "to": "node-a"}))
            .unwrap();
        let result = future.await.unwrap();
        // Should return found: false, not an error
        assert_eq!(result.is_error, Some(false));
    }

    #[tokio::test]
    async fn test_graph_path_missing_target_is_error() {
        let tools = GraphTools::new(make_test_graph());
        let result = tools
            .call("graph_path", json!({"from": "node-a"}))
            .unwrap()
            .await;
        assert!(result.is_err());
    }

    // -- graph_prerequisites tests ------------------------------------------

    #[tokio::test]
    async fn test_graph_prerequisites() {
        let tools = GraphTools::new(make_test_graph());
        let future = tools
            .call("graph_prerequisites", json!({"id": "node-b"}))
            .unwrap();
        let result = future.await.unwrap();
        assert_eq!(result.is_error, Some(false));
    }

    // -- graph_neighborhood tests -------------------------------------------

    #[tokio::test]
    async fn test_graph_neighborhood() {
        let tools = GraphTools::new(make_test_graph());
        let future = tools
            .call("graph_neighborhood", json!({"id": "node-b"}))
            .unwrap();
        let result = future.await.unwrap();
        assert_eq!(result.is_error, Some(false));
    }

    #[tokio::test]
    async fn test_graph_neighborhood_with_radius() {
        let tools = GraphTools::new(make_test_graph());
        let future = tools
            .call("graph_neighborhood", json!({"id": "node-a", "radius": 2}))
            .unwrap();
        let result = future.await.unwrap();
        assert_eq!(result.is_error, Some(false));
    }

    // -- graph_centrality tests ---------------------------------------------

    #[tokio::test]
    async fn test_graph_centrality() {
        let tools = GraphTools::new(make_test_graph());
        let future = tools.call("graph_centrality", json!({"limit": 5})).unwrap();
        let result = future.await.unwrap();
        assert_eq!(result.is_error, Some(false));
    }

    #[tokio::test]
    async fn test_graph_centrality_non_integer_limit_uses_default() {
        let tools = GraphTools::new(make_test_graph());
        let future = tools
            .call("graph_centrality", json!({"limit": "five"}))
            .unwrap();
        let result = future.await.unwrap();
        assert_eq!(result.is_error, Some(false));
    }

    #[tokio::test]
    async fn test_graph_centrality_null_args() {
        let tools = GraphTools::new(make_test_graph());
        let future = tools.call("graph_centrality", Value::Null).unwrap();
        let result = future.await.unwrap();
        assert_eq!(result.is_error, Some(false));
    }

    // -- graph_bridges tests ------------------------------------------------

    #[tokio::test]
    async fn test_graph_bridges() {
        let tools = GraphTools::new(make_test_graph());
        let future = tools.call("graph_bridges", json!({"limit": 5})).unwrap();
        let result = future.await.unwrap();
        assert_eq!(result.is_error, Some(false));
    }

    #[tokio::test]
    async fn test_graph_bridges_default_limit() {
        let tools = GraphTools::new(make_test_graph());
        let future = tools.call("graph_bridges", json!({})).unwrap();
        let result = future.await.unwrap();
        assert_eq!(result.is_error, Some(false));
    }

    // -- Shared state tests -------------------------------------------------

    #[tokio::test]
    async fn test_graph_update() {
        let tools = GraphTools::new(GraphData::new());

        // Initial: empty graph
        let future = tools.call("graph_info", json!({})).unwrap();
        let _result = future.await.unwrap();

        // Update with populated graph
        tools.update_graph(make_test_graph()).await;

        let future = tools.call("graph_info", json!({})).unwrap();
        let result = future.await.unwrap();
        assert_eq!(result.is_error, Some(false));
    }

    // -- Unknown tool test --------------------------------------------------

    #[test]
    fn test_graph_tools_unknown_tool() {
        let tools = GraphTools::new(GraphData::new());
        assert!(tools.call("graph_delete", json!({})).is_none());
    }

    // -- parse_relationship tests -------------------------------------------

    #[test]
    fn test_parse_relationship_known() {
        assert!(matches!(
            parse_relationship("prerequisite"),
            Relationship::Prerequisite
        ));
        assert!(matches!(
            parse_relationship("leads_to"),
            Relationship::LeadsTo
        ));
        assert!(matches!(
            parse_relationship("relates_to"),
            Relationship::RelatesTo
        ));
        assert!(matches!(
            parse_relationship("extends"),
            Relationship::Extends
        ));
        assert!(matches!(
            parse_relationship("introduces"),
            Relationship::Introduces
        ));
        assert!(matches!(parse_relationship("covers"), Relationship::Covers));
        assert!(matches!(
            parse_relationship("variant_of"),
            Relationship::VariantOf
        ));
    }

    #[test]
    fn test_parse_relationship_aliases_and_case() {
        assert!(matches!(parse_relationship("leadsto"), Relationship::LeadsTo));
        assert!(matches!(
            parse_relationship("relatesto"),
            Relationship::RelatesTo
        ));
        assert!(matches!(
            parse_relationship("variantof"),
            Relationship::VariantOf
        ));
        assert!(matches!(
            parse_relationship("Prerequisite"),
            Relationship::Prerequisite
        ));
        assert!(matches!(parse_relationship("LEADS_TO"), Relationship::LeadsTo));
    }

    #[test]
    fn test_parse_relationship_custom() {
        match parse_relationship("my_custom") {
            Relationship::Custom(s) => assert_eq!(s, "my_custom"),
            _ => panic!("Expected Custom relationship"),
        }
    }

    #[test]
    fn test_parse_relationship_custom_is_lowercased() {
        match parse_relationship("My_Custom") {
            Relationship::Custom(s) => assert_eq!(s, "my_custom"),
            _ => panic!("Expected Custom relationship"),
        }
    }

    // -- Custom name/description tests -------------------------------------

    #[test]
    fn test_graph_tools_with_custom_names() {
        let tools = GraphTools::new(GraphData::new()).with_names(HashMap::from([
            (
                "graph_related".to_string(),
                "get_related_concepts".to_string(),
            ),
            ("graph_path".to_string(), "find_concept_path".to_string()),
            ("graph_info".to_string(), "graph_stats".to_string()),
        ]));
        let tool_list = tools.tools();
        let names: Vec<&str> = tool_list.iter().map(|t| t.name.as_ref()).collect();
        assert!(names.contains(&"get_related_concepts"));
        assert!(names.contains(&"find_concept_path"));
        assert!(names.contains(&"graph_stats"));
        // Unrenamed tools keep defaults
        assert!(names.contains(&"graph_prerequisites"));
    }

    #[tokio::test]
    async fn test_graph_tools_custom_names_dispatch() {
        let tools = GraphTools::new(make_test_graph()).with_names(HashMap::from([(
            "graph_info".to_string(),
            "graph_stats".to_string(),
        )]));
        // Old name should NOT work
        assert!(tools.call("graph_info", json!({})).is_none());
        // Custom name should work
        let future = tools.call("graph_stats", json!({})).unwrap();
        let result = future.await.unwrap();
        assert_eq!(result.is_error, Some(false));
    }

    #[tokio::test]
    async fn test_graph_tools_custom_names_keep_other_defaults() {
        let tools = GraphTools::new(make_test_graph()).with_names(HashMap::from([(
            "graph_info".to_string(),
            "graph_stats".to_string(),
        )]));
        // A slot that was not renamed still dispatches under its default name.
        let future = tools.call("graph_validate", json!({})).unwrap();
        let result = future.await.unwrap();
        assert_eq!(result.is_error, Some(false));
    }
}