agentic_memory_mcp/tools/
memory_traverse.rs1use std::sync::Arc;
4use tokio::sync::Mutex;
5
6use serde::Deserialize;
7use serde_json::{json, Value};
8
9use agentic_memory::{EdgeType, TraversalDirection, TraversalParams};
10
11use crate::session::SessionManager;
12use crate::types::{McpError, McpResult, ToolCallResult, ToolDefinition};
13
14#[derive(Debug, Deserialize)]
15struct TraverseParams {
16 start_id: u64,
17 #[serde(default)]
18 edge_types: Vec<String>,
19 #[serde(default = "default_direction")]
20 direction: String,
21 #[serde(default = "default_max_depth")]
22 max_depth: u32,
23 #[serde(default = "default_max_results")]
24 max_results: usize,
25 min_confidence: Option<f32>,
26}
27
/// Serde default for `TraverseParams::direction`: the string `"forward"`.
fn default_direction() -> String {
    String::from("forward")
}
31
/// Serde default for `TraverseParams::max_depth`.
fn default_max_depth() -> u32 {
    const DEFAULT_MAX_DEPTH: u32 = 5;
    DEFAULT_MAX_DEPTH
}
35
/// Serde default for `TraverseParams::max_results`.
fn default_max_results() -> usize {
    const DEFAULT_MAX_RESULTS: usize = 20;
    DEFAULT_MAX_RESULTS
}
39
40pub fn definition() -> ToolDefinition {
42 ToolDefinition {
43 name: "memory_traverse".to_string(),
44 description: Some(
45 "Walk the graph from a starting node, following edges of specified types".to_string(),
46 ),
47 input_schema: json!({
48 "type": "object",
49 "properties": {
50 "start_id": { "type": "integer", "description": "Starting node ID" },
51 "edge_types": { "type": "array", "items": { "type": "string" } },
52 "direction": { "type": "string", "enum": ["forward", "backward", "both"], "default": "forward" },
53 "max_depth": { "type": "integer", "default": 5 },
54 "max_results": { "type": "integer", "default": 20 },
55 "min_confidence": { "type": "number" }
56 },
57 "required": ["start_id"]
58 }),
59 }
60}
61
62pub async fn execute(
64 args: Value,
65 session: &Arc<Mutex<SessionManager>>,
66) -> McpResult<ToolCallResult> {
67 let params: TraverseParams =
68 serde_json::from_value(args).map_err(|e| McpError::InvalidParams(e.to_string()))?;
69
70 let edge_types: Vec<EdgeType> = if params.edge_types.is_empty() {
71 vec![
72 EdgeType::CausedBy,
73 EdgeType::Supports,
74 EdgeType::Contradicts,
75 EdgeType::Supersedes,
76 EdgeType::RelatedTo,
77 EdgeType::PartOf,
78 EdgeType::TemporalNext,
79 ]
80 } else {
81 params
82 .edge_types
83 .iter()
84 .filter_map(|name| EdgeType::from_name(name))
85 .collect()
86 };
87
88 let direction = match params.direction.as_str() {
89 "backward" => TraversalDirection::Backward,
90 "both" => TraversalDirection::Both,
91 _ => TraversalDirection::Forward,
92 };
93
94 let traversal = TraversalParams {
95 start_id: params.start_id,
96 edge_types,
97 direction,
98 max_depth: params.max_depth,
99 max_results: params.max_results,
100 min_confidence: params.min_confidence.unwrap_or(0.0),
101 };
102
103 let session = session.lock().await;
104 let result = session
105 .query_engine()
106 .traverse(session.graph(), traversal)
107 .map_err(|e| McpError::AgenticMemory(format!("Traversal failed: {e}")))?;
108
109 let visited: Vec<Value> = result
110 .visited
111 .iter()
112 .filter_map(|id| {
113 session.graph().get_node(*id).map(|node| {
114 json!({
115 "id": node.id,
116 "event_type": node.event_type.name(),
117 "content": node.content,
118 "confidence": node.confidence,
119 "depth": result.depths.get(id).copied().unwrap_or(0),
120 })
121 })
122 })
123 .collect();
124
125 let edges: Vec<Value> = result
126 .edges_traversed
127 .iter()
128 .map(|e| {
129 json!({
130 "source_id": e.source_id,
131 "target_id": e.target_id,
132 "edge_type": e.edge_type.name(),
133 "weight": e.weight,
134 })
135 })
136 .collect();
137
138 Ok(ToolCallResult::json(&json!({
139 "start_id": params.start_id,
140 "visited_count": visited.len(),
141 "visited": visited,
142 "edges_traversed": edges,
143 })))
144}