1use fabryk_mcp_core::error::McpErrorExt;
7use fabryk_mcp_core::model::{CallToolResult, Content, ErrorData, Tool};
8use fabryk_mcp_core::registry::{ToolRegistry, ToolResult};
9
10use fabryk_graph::{
11 EdgeInfo, GraphData, NeighborInfo, NodeSummary, PathStep, Relationship, calculate_centrality,
12 compute_stats, find_bridges, neighborhood, prerequisites_sorted, shortest_path, validate_graph,
13};
14use serde::Deserialize;
15use serde_json::{Value, json};
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19
20fn json_schema(value: Value) -> Arc<serde_json::Map<String, Value>> {
25 match value {
26 Value::Object(map) => Arc::new(map),
27 _ => Arc::new(serde_json::Map::new()),
28 }
29}
30
31fn make_tool(name: &str, description: &str, schema: Value) -> Tool {
32 Tool::new(
33 name.to_string(),
34 description.to_string(),
35 json_schema(schema),
36 )
37}
38
39fn serialize_response<T: serde::Serialize>(value: &T) -> Result<CallToolResult, ErrorData> {
40 let json = serde_json::to_string_pretty(value)
41 .map_err(|e| ErrorData::internal_error(e.to_string(), None))?;
42 Ok(CallToolResult::success(vec![Content::text(json)]))
43}
44
45fn parse_relationship(s: &str) -> Relationship {
46 match s.to_lowercase().as_str() {
47 "prerequisite" => Relationship::Prerequisite,
48 "leads_to" | "leadsto" => Relationship::LeadsTo,
49 "relates_to" | "relatesto" => Relationship::RelatesTo,
50 "extends" => Relationship::Extends,
51 "introduces" => Relationship::Introduces,
52 "covers" => Relationship::Covers,
53 "variant_of" | "variantof" => Relationship::VariantOf,
54 other => Relationship::Custom(other.to_string()),
55 }
56}
57
/// Arguments for the related-nodes tool (slot `GraphTools::SLOT_RELATED`),
/// deserialized from the call's JSON arguments.
#[derive(Debug, Deserialize)]
pub struct RelatedArgs {
    /// ID of the node whose one-hop neighborhood is queried.
    pub id: String,
    /// Optional relationship name, parsed with `parse_relationship` and used
    /// as a single-relationship filter.
    pub relationship: Option<String>,
    /// Optional cap on how many related nodes are returned (applied by
    /// truncating the result list).
    pub limit: Option<usize>,
}
72
/// Arguments for the shortest-path tool (slot `GraphTools::SLOT_PATH`).
#[derive(Debug, Deserialize)]
pub struct PathArgs {
    /// ID of the node the path starts from.
    pub from: String,
    /// ID of the node the path should reach.
    pub to: String,
}
81
/// Arguments for the prerequisites tool (slot `GraphTools::SLOT_PREREQUISITES`).
#[derive(Debug, Deserialize)]
pub struct PrerequisitesArgs {
    /// ID of the node whose prerequisites are requested.
    pub id: String,
}
88
/// Arguments for the neighborhood tool (slot `GraphTools::SLOT_NEIGHBORHOOD`).
#[derive(Debug, Deserialize)]
pub struct NeighborhoodArgs {
    /// ID of the node at the center of the neighborhood.
    pub id: String,
    /// Number of hops to explore from the center; the handler uses 1 when absent.
    pub radius: Option<usize>,
    /// Optional relationship name, parsed with `parse_relationship` and used
    /// as a single-relationship filter.
    pub relationship: Option<String>,
}
99
/// Tool registry exposing read-only queries over a [`GraphData`] graph.
///
/// The graph lives behind an `Arc<RwLock<_>>`. It can therefore be shared with
/// other owners (see `GraphTools::with_shared`) and replaced in place via
/// `GraphTools::update_graph`. Tool names and descriptions can be overridden
/// per slot.
pub struct GraphTools {
    /// Graph queried by every tool call; handlers only take the read lock.
    graph: Arc<RwLock<GraphData>>,
    /// Tool-name overrides keyed by slot (the `SLOT_*` default name).
    custom_names: HashMap<String, String>,
    /// Tool-description overrides keyed by slot (the `SLOT_*` default name).
    custom_descriptions: HashMap<String, String>,
}
130
131impl GraphTools {
132 pub const SLOT_RELATED: &str = "graph_related";
134 pub const SLOT_PATH: &str = "graph_path";
136 pub const SLOT_PREREQUISITES: &str = "graph_prerequisites";
138 pub const SLOT_NEIGHBORHOOD: &str = "graph_neighborhood";
140 pub const SLOT_INFO: &str = "graph_info";
142 pub const SLOT_VALIDATE: &str = "graph_validate";
144 pub const SLOT_CENTRALITY: &str = "graph_centrality";
146 pub const SLOT_BRIDGES: &str = "graph_bridges";
148
149 pub fn new(graph: GraphData) -> Self {
151 Self {
152 graph: Arc::new(RwLock::new(graph)),
153 custom_names: HashMap::new(),
154 custom_descriptions: HashMap::new(),
155 }
156 }
157
158 pub fn with_shared(graph: Arc<RwLock<GraphData>>) -> Self {
160 Self {
161 graph,
162 custom_names: HashMap::new(),
163 custom_descriptions: HashMap::new(),
164 }
165 }
166
167 pub fn with_names(mut self, names: HashMap<String, String>) -> Self {
169 self.custom_names = names;
170 self
171 }
172
173 pub fn with_descriptions(mut self, descriptions: HashMap<String, String>) -> Self {
175 self.custom_descriptions = descriptions;
176 self
177 }
178
179 pub async fn update_graph(&self, graph: GraphData) {
181 let mut lock = self.graph.write().await;
182 *lock = graph;
183 }
184
185 fn tool_name(&self, slot: &str) -> String {
186 self.custom_names
187 .get(slot)
188 .cloned()
189 .unwrap_or_else(|| slot.to_string())
190 }
191
192 fn tool_description(&self, slot: &str, default: &str) -> String {
193 self.custom_descriptions
194 .get(slot)
195 .cloned()
196 .unwrap_or_else(|| default.to_string())
197 }
198}
199
200impl ToolRegistry for GraphTools {
201 fn tools(&self) -> Vec<Tool> {
202 vec![
203 make_tool(
204 &self.tool_name(Self::SLOT_RELATED),
205 &self.tool_description(Self::SLOT_RELATED, "Find nodes related to a given node"),
206 json!({
207 "type": "object",
208 "properties": {
209 "id": {
210 "type": "string",
211 "description": "Node ID"
212 },
213 "relationship": {
214 "type": "string",
215 "description": "Filter by relationship type (e.g., prerequisite, relates_to)"
216 },
217 "limit": {
218 "type": "integer",
219 "description": "Maximum results"
220 }
221 },
222 "required": ["id"]
223 }),
224 ),
225 make_tool(
226 &self.tool_name(Self::SLOT_PATH),
227 &self.tool_description(Self::SLOT_PATH, "Find the shortest path between two nodes"),
228 json!({
229 "type": "object",
230 "properties": {
231 "from": {
232 "type": "string",
233 "description": "Starting node ID"
234 },
235 "to": {
236 "type": "string",
237 "description": "Target node ID"
238 }
239 },
240 "required": ["from", "to"]
241 }),
242 ),
243 make_tool(
244 &self.tool_name(Self::SLOT_PREREQUISITES),
245 &self.tool_description(
246 Self::SLOT_PREREQUISITES,
247 "Get prerequisites for a node in learning order",
248 ),
249 json!({
250 "type": "object",
251 "properties": {
252 "id": {
253 "type": "string",
254 "description": "Node ID"
255 }
256 },
257 "required": ["id"]
258 }),
259 ),
260 make_tool(
261 &self.tool_name(Self::SLOT_NEIGHBORHOOD),
262 &self.tool_description(
263 Self::SLOT_NEIGHBORHOOD,
264 "Explore the neighborhood around a node",
265 ),
266 json!({
267 "type": "object",
268 "properties": {
269 "id": {
270 "type": "string",
271 "description": "Center node ID"
272 },
273 "radius": {
274 "type": "integer",
275 "description": "Hops from center (default 1)"
276 },
277 "relationship": {
278 "type": "string",
279 "description": "Filter by relationship type"
280 }
281 },
282 "required": ["id"]
283 }),
284 ),
285 make_tool(
286 &self.tool_name(Self::SLOT_INFO),
287 &self.tool_description(Self::SLOT_INFO, "Get graph statistics and overview"),
288 json!({
289 "type": "object",
290 "properties": {}
291 }),
292 ),
293 make_tool(
294 &self.tool_name(Self::SLOT_VALIDATE),
295 &self.tool_description(
296 Self::SLOT_VALIDATE,
297 "Validate graph structure and report issues",
298 ),
299 json!({
300 "type": "object",
301 "properties": {}
302 }),
303 ),
304 make_tool(
305 &self.tool_name(Self::SLOT_CENTRALITY),
306 &self.tool_description(Self::SLOT_CENTRALITY, "Get most central/important nodes"),
307 json!({
308 "type": "object",
309 "properties": {
310 "limit": {
311 "type": "integer",
312 "description": "Number of results (default 10)"
313 }
314 }
315 }),
316 ),
317 make_tool(
318 &self.tool_name(Self::SLOT_BRIDGES),
319 &self.tool_description(
320 Self::SLOT_BRIDGES,
321 "Find bridge nodes that connect different areas",
322 ),
323 json!({
324 "type": "object",
325 "properties": {
326 "limit": {
327 "type": "integer",
328 "description": "Number of results (default 10)"
329 }
330 }
331 }),
332 ),
333 ]
334 }
335
336 fn call(&self, name: &str, args: Value) -> Option<ToolResult> {
337 let graph = Arc::clone(&self.graph);
338
339 if name == self.tool_name(Self::SLOT_RELATED) {
340 return Some(Box::pin(async move {
341 let args: RelatedArgs = serde_json::from_value(args)
342 .map_err(|e| ErrorData::invalid_params(e.to_string(), None))?;
343 let graph = graph.read().await;
344
345 let rel_filter = args
346 .relationship
347 .as_deref()
348 .map(|r| vec![parse_relationship(r)]);
349
350 let result = neighborhood(&graph, &args.id, 1, rel_filter.as_deref())
351 .map_err(|e| e.to_mcp_error())?;
352
353 let mut nodes: Vec<NodeSummary> =
354 result.nodes.iter().map(NodeSummary::from).collect();
355
356 if let Some(limit) = args.limit {
357 nodes.truncate(limit);
358 }
359
360 let count = nodes.len();
361 let response = json!({
362 "source": NodeSummary::from(&result.center),
363 "related": nodes,
364 "count": count
365 });
366 serialize_response(&response)
367 }));
368 }
369
370 if name == self.tool_name(Self::SLOT_PATH) {
371 return Some(Box::pin(async move {
372 let args: PathArgs = serde_json::from_value(args)
373 .map_err(|e| ErrorData::invalid_params(e.to_string(), None))?;
374 let graph = graph.read().await;
375
376 let result =
377 shortest_path(&graph, &args.from, &args.to).map_err(|e| e.to_mcp_error())?;
378
379 if result.found {
380 let path: Vec<PathStep> = result
381 .path
382 .iter()
383 .enumerate()
384 .map(|(i, node)| {
385 let rel = result
386 .edges
387 .get(i)
388 .map(|e| e.relationship.name().to_string());
389 PathStep {
390 node: NodeSummary::from(node),
391 relationship_to_next: rel,
392 }
393 })
394 .collect();
395
396 let response = json!({
397 "found": true,
398 "path": path,
399 "length": path.len(),
400 "total_weight": result.total_weight
401 });
402 serialize_response(&response)
403 } else {
404 let response = json!({
405 "found": false,
406 "message": format!("No path found from {} to {}", args.from, args.to)
407 });
408 serialize_response(&response)
409 }
410 }));
411 }
412
413 if name == self.tool_name(Self::SLOT_PREREQUISITES) {
414 return Some(Box::pin(async move {
415 let args: PrerequisitesArgs = serde_json::from_value(args)
416 .map_err(|e| ErrorData::invalid_params(e.to_string(), None))?;
417 let graph = graph.read().await;
418
419 let result =
420 prerequisites_sorted(&graph, &args.id).map_err(|e| e.to_mcp_error())?;
421
422 let prereqs: Vec<NodeSummary> =
423 result.ordered.iter().map(NodeSummary::from).collect();
424
425 let count = prereqs.len();
426 let response = json!({
427 "target": NodeSummary::from(&result.target),
428 "prerequisites": prereqs,
429 "count": count,
430 "has_cycles": result.has_cycles
431 });
432 serialize_response(&response)
433 }));
434 }
435
436 if name == self.tool_name(Self::SLOT_NEIGHBORHOOD) {
437 return Some(Box::pin(async move {
438 let args: NeighborhoodArgs = serde_json::from_value(args)
439 .map_err(|e| ErrorData::invalid_params(e.to_string(), None))?;
440 let graph = graph.read().await;
441
442 let radius = args.radius.unwrap_or(1);
443 let rel_filter = args
444 .relationship
445 .as_deref()
446 .map(|r| vec![parse_relationship(r)]);
447
448 let result = neighborhood(&graph, &args.id, radius, rel_filter.as_deref())
449 .map_err(|e| e.to_mcp_error())?;
450
451 let nodes: Vec<NeighborInfo> = result
452 .nodes
453 .iter()
454 .map(|n| {
455 let distance = result.distances.get(&n.id).copied().unwrap_or(0);
456 NeighborInfo {
457 node: NodeSummary::from(n),
458 distance,
459 }
460 })
461 .collect();
462
463 let edges: Vec<EdgeInfo> = result.edges.iter().map(EdgeInfo::from).collect();
464
465 let response = json!({
466 "center": NodeSummary::from(&result.center),
467 "radius": radius,
468 "nodes": nodes,
469 "edges": edges,
470 "edge_count": edges.len()
471 });
472 serialize_response(&response)
473 }));
474 }
475
476 if name == self.tool_name(Self::SLOT_INFO) {
477 return Some(Box::pin(async move {
478 let graph = graph.read().await;
479 let stats = compute_stats(&graph);
480 serialize_response(&stats)
481 }));
482 }
483
484 if name == self.tool_name(Self::SLOT_VALIDATE) {
485 return Some(Box::pin(async move {
486 let graph = graph.read().await;
487 let result = validate_graph(&graph);
488 serialize_response(&result)
489 }));
490 }
491
492 if name == self.tool_name(Self::SLOT_CENTRALITY) {
493 return Some(Box::pin(async move {
494 let limit = args
495 .get("limit")
496 .and_then(|v| v.as_u64())
497 .map(|n| n as usize)
498 .unwrap_or(10);
499
500 let graph = graph.read().await;
501 let scores = calculate_centrality(&graph);
502
503 let top: Vec<_> = scores.into_iter().take(limit).collect();
504 serialize_response(&top)
505 }));
506 }
507
508 if name == self.tool_name(Self::SLOT_BRIDGES) {
509 return Some(Box::pin(async move {
510 let limit = args
511 .get("limit")
512 .and_then(|v| v.as_u64())
513 .map(|n| n as usize)
514 .unwrap_or(10);
515
516 let graph = graph.read().await;
517 let bridges = find_bridges(&graph, limit);
518
519 let summaries: Vec<NodeSummary> = bridges.iter().map(NodeSummary::from).collect();
520 serialize_response(&summaries)
521 }));
522 }
523
524 None
525 }
526}
527
528#[cfg(test)]
533mod tests {
534 use super::*;
535 use fabryk_graph::{Edge, Node};
536
537 fn make_test_graph() -> GraphData {
538 let mut graph = GraphData::new();
539
540 graph.add_node(Node::new("node-a", "Node A").with_category("alpha"));
542 graph.add_node(Node::new("node-b", "Node B").with_category("beta"));
543 graph.add_node(Node::new("node-c", "Node C").with_category("alpha"));
544
545 let _ = graph.add_edge(Edge::new("node-a", "node-b", Relationship::Prerequisite));
547 let _ = graph.add_edge(Edge::new("node-b", "node-c", Relationship::RelatesTo));
548
549 graph
550 }
551
552 #[test]
555 fn test_graph_tools_creation() {
556 let tools = GraphTools::new(GraphData::new());
557 assert_eq!(tools.tool_count(), 8);
558 }
559
560 #[test]
561 fn test_graph_tools_names() {
562 let tools = GraphTools::new(GraphData::new());
563 let tool_list = tools.tools();
564 let names: Vec<&str> = tool_list.iter().map(|t| t.name.as_ref()).collect();
565 assert!(names.contains(&"graph_related"));
566 assert!(names.contains(&"graph_path"));
567 assert!(names.contains(&"graph_prerequisites"));
568 assert!(names.contains(&"graph_neighborhood"));
569 assert!(names.contains(&"graph_info"));
570 assert!(names.contains(&"graph_validate"));
571 assert!(names.contains(&"graph_centrality"));
572 assert!(names.contains(&"graph_bridges"));
573 }
574
575 #[test]
576 fn test_graph_tools_has_tool() {
577 let tools = GraphTools::new(GraphData::new());
578 assert!(tools.has_tool("graph_related"));
579 assert!(tools.has_tool("graph_info"));
580 assert!(!tools.has_tool("graph_delete"));
581 }
582
583 #[tokio::test]
586 async fn test_graph_info_empty() {
587 let tools = GraphTools::new(GraphData::new());
588 let future = tools.call("graph_info", json!({})).unwrap();
589 let result = future.await.unwrap();
590 assert_eq!(result.is_error, Some(false));
591 }
592
593 #[tokio::test]
594 async fn test_graph_info_with_data() {
595 let tools = GraphTools::new(make_test_graph());
596 let future = tools.call("graph_info", json!({})).unwrap();
597 let result = future.await.unwrap();
598 assert_eq!(result.is_error, Some(false));
599 }
600
601 #[tokio::test]
604 async fn test_graph_validate_empty() {
605 let tools = GraphTools::new(GraphData::new());
606 let future = tools.call("graph_validate", json!({})).unwrap();
607 let result = future.await.unwrap();
608 assert_eq!(result.is_error, Some(false));
609 }
610
611 #[tokio::test]
612 async fn test_graph_validate_with_data() {
613 let tools = GraphTools::new(make_test_graph());
614 let future = tools.call("graph_validate", json!({})).unwrap();
615 let result = future.await.unwrap();
616 assert_eq!(result.is_error, Some(false));
617 }
618
619 #[tokio::test]
622 async fn test_graph_related() {
623 let tools = GraphTools::new(make_test_graph());
624 let future = tools
625 .call("graph_related", json!({"id": "node-a"}))
626 .unwrap();
627 let result = future.await.unwrap();
628 assert_eq!(result.is_error, Some(false));
629 }
630
631 #[tokio::test]
632 async fn test_graph_related_not_found() {
633 let tools = GraphTools::new(make_test_graph());
634 let future = tools
635 .call("graph_related", json!({"id": "missing"}))
636 .unwrap();
637 let result = future.await;
638 assert!(result.is_err());
639 }
640
641 #[tokio::test]
642 async fn test_graph_related_with_limit() {
643 let tools = GraphTools::new(make_test_graph());
644 let future = tools
645 .call("graph_related", json!({"id": "node-b", "limit": 1}))
646 .unwrap();
647 let result = future.await.unwrap();
648 assert_eq!(result.is_error, Some(false));
649 }
650
651 #[tokio::test]
654 async fn test_graph_path() {
655 let tools = GraphTools::new(make_test_graph());
656 let future = tools
657 .call("graph_path", json!({"from": "node-a", "to": "node-c"}))
658 .unwrap();
659 let result = future.await.unwrap();
660 assert_eq!(result.is_error, Some(false));
661 }
662
663 #[tokio::test]
664 async fn test_graph_path_not_found() {
665 let tools = GraphTools::new(make_test_graph());
666 let future = tools
667 .call("graph_path", json!({"from": "node-c", "to": "node-a"}))
668 .unwrap();
669 let result = future.await.unwrap();
670 assert_eq!(result.is_error, Some(false));
672 }
673
674 #[tokio::test]
677 async fn test_graph_prerequisites() {
678 let tools = GraphTools::new(make_test_graph());
679 let future = tools
680 .call("graph_prerequisites", json!({"id": "node-b"}))
681 .unwrap();
682 let result = future.await.unwrap();
683 assert_eq!(result.is_error, Some(false));
684 }
685
686 #[tokio::test]
689 async fn test_graph_neighborhood() {
690 let tools = GraphTools::new(make_test_graph());
691 let future = tools
692 .call("graph_neighborhood", json!({"id": "node-b"}))
693 .unwrap();
694 let result = future.await.unwrap();
695 assert_eq!(result.is_error, Some(false));
696 }
697
698 #[tokio::test]
699 async fn test_graph_neighborhood_with_radius() {
700 let tools = GraphTools::new(make_test_graph());
701 let future = tools
702 .call("graph_neighborhood", json!({"id": "node-a", "radius": 2}))
703 .unwrap();
704 let result = future.await.unwrap();
705 assert_eq!(result.is_error, Some(false));
706 }
707
708 #[tokio::test]
711 async fn test_graph_centrality() {
712 let tools = GraphTools::new(make_test_graph());
713 let future = tools.call("graph_centrality", json!({"limit": 5})).unwrap();
714 let result = future.await.unwrap();
715 assert_eq!(result.is_error, Some(false));
716 }
717
718 #[tokio::test]
721 async fn test_graph_bridges() {
722 let tools = GraphTools::new(make_test_graph());
723 let future = tools.call("graph_bridges", json!({"limit": 5})).unwrap();
724 let result = future.await.unwrap();
725 assert_eq!(result.is_error, Some(false));
726 }
727
728 #[tokio::test]
731 async fn test_graph_update() {
732 let tools = GraphTools::new(GraphData::new());
733
734 let future = tools.call("graph_info", json!({})).unwrap();
736 let _result = future.await.unwrap();
737
738 tools.update_graph(make_test_graph()).await;
740
741 let future = tools.call("graph_info", json!({})).unwrap();
742 let result = future.await.unwrap();
743 assert_eq!(result.is_error, Some(false));
744 }
745
746 #[test]
749 fn test_graph_tools_unknown_tool() {
750 let tools = GraphTools::new(GraphData::new());
751 assert!(tools.call("graph_delete", json!({})).is_none());
752 }
753
754 #[test]
757 fn test_parse_relationship_known() {
758 assert!(matches!(
759 parse_relationship("prerequisite"),
760 Relationship::Prerequisite
761 ));
762 assert!(matches!(
763 parse_relationship("leads_to"),
764 Relationship::LeadsTo
765 ));
766 assert!(matches!(
767 parse_relationship("relates_to"),
768 Relationship::RelatesTo
769 ));
770 assert!(matches!(
771 parse_relationship("extends"),
772 Relationship::Extends
773 ));
774 }
775
776 #[test]
779 fn test_graph_tools_with_custom_names() {
780 let tools = GraphTools::new(GraphData::new()).with_names(HashMap::from([
781 (
782 "graph_related".to_string(),
783 "get_related_concepts".to_string(),
784 ),
785 ("graph_path".to_string(), "find_concept_path".to_string()),
786 ("graph_info".to_string(), "graph_stats".to_string()),
787 ]));
788 let tool_list = tools.tools();
789 let names: Vec<&str> = tool_list.iter().map(|t| t.name.as_ref()).collect();
790 assert!(names.contains(&"get_related_concepts"));
791 assert!(names.contains(&"find_concept_path"));
792 assert!(names.contains(&"graph_stats"));
793 assert!(names.contains(&"graph_prerequisites"));
795 }
796
797 #[tokio::test]
798 async fn test_graph_tools_custom_names_dispatch() {
799 let tools = GraphTools::new(make_test_graph()).with_names(HashMap::from([(
800 "graph_info".to_string(),
801 "graph_stats".to_string(),
802 )]));
803 assert!(tools.call("graph_info", json!({})).is_none());
805 let future = tools.call("graph_stats", json!({})).unwrap();
807 let result = future.await.unwrap();
808 assert_eq!(result.is_error, Some(false));
809 }
810
811 #[test]
812 fn test_parse_relationship_custom() {
813 match parse_relationship("my_custom") {
814 Relationship::Custom(s) => assert_eq!(s, "my_custom"),
815 _ => panic!("Expected Custom relationship"),
816 }
817 }
818}