Skip to main content

agentic_memory_mcp/tools/memory_traverse.rs

1//! Tool: memory_traverse — Walk the graph from a starting node.
2
3use std::sync::Arc;
4use tokio::sync::Mutex;
5
6use serde::Deserialize;
7use serde_json::{json, Value};
8
9use agentic_memory::{EdgeType, TraversalDirection, TraversalParams};
10
11use crate::session::SessionManager;
12use crate::types::{McpError, McpResult, ToolCallResult, ToolDefinition};
13
14#[derive(Debug, Deserialize)]
15struct TraverseParams {
16    start_id: u64,
17    #[serde(default)]
18    edge_types: Vec<String>,
19    #[serde(default = "default_direction")]
20    direction: String,
21    #[serde(default = "default_max_depth")]
22    max_depth: u32,
23    #[serde(default = "default_max_results")]
24    max_results: usize,
25    min_confidence: Option<f32>,
26}
27
/// Serde default for `TraverseParams::direction`: walk edges forward.
fn default_direction() -> String {
    String::from("forward")
}
31
/// Serde default for `TraverseParams::max_depth`.
fn default_max_depth() -> u32 {
    const MAX_DEPTH: u32 = 5;
    MAX_DEPTH
}
35
/// Serde default for `TraverseParams::max_results`.
fn default_max_results() -> usize {
    const MAX_RESULTS: usize = 20;
    MAX_RESULTS
}
39
40/// Return the tool definition for memory_traverse.
41pub fn definition() -> ToolDefinition {
42    ToolDefinition {
43        name: "memory_traverse".to_string(),
44        description: Some(
45            "Walk the graph from a starting node, following edges of specified types".to_string(),
46        ),
47        input_schema: json!({
48            "type": "object",
49            "properties": {
50                "start_id": { "type": "integer", "description": "Starting node ID" },
51                "edge_types": { "type": "array", "items": { "type": "string" } },
52                "direction": { "type": "string", "enum": ["forward", "backward", "both"], "default": "forward" },
53                "max_depth": { "type": "integer", "default": 5 },
54                "max_results": { "type": "integer", "default": 20 },
55                "min_confidence": { "type": "number" }
56            },
57            "required": ["start_id"]
58        }),
59    }
60}
61
62/// Execute the memory_traverse tool.
63pub async fn execute(
64    args: Value,
65    session: &Arc<Mutex<SessionManager>>,
66) -> McpResult<ToolCallResult> {
67    let params: TraverseParams =
68        serde_json::from_value(args).map_err(|e| McpError::InvalidParams(e.to_string()))?;
69
70    let edge_types: Vec<EdgeType> = if params.edge_types.is_empty() {
71        vec![
72            EdgeType::CausedBy,
73            EdgeType::Supports,
74            EdgeType::Contradicts,
75            EdgeType::Supersedes,
76            EdgeType::RelatedTo,
77            EdgeType::PartOf,
78            EdgeType::TemporalNext,
79        ]
80    } else {
81        params
82            .edge_types
83            .iter()
84            .filter_map(|name| EdgeType::from_name(name))
85            .collect()
86    };
87
88    let direction = match params.direction.as_str() {
89        "backward" => TraversalDirection::Backward,
90        "both" => TraversalDirection::Both,
91        _ => TraversalDirection::Forward,
92    };
93
94    let traversal = TraversalParams {
95        start_id: params.start_id,
96        edge_types,
97        direction,
98        max_depth: params.max_depth,
99        max_results: params.max_results,
100        min_confidence: params.min_confidence.unwrap_or(0.0),
101    };
102
103    let session = session.lock().await;
104    let result = session
105        .query_engine()
106        .traverse(session.graph(), traversal)
107        .map_err(|e| McpError::AgenticMemory(format!("Traversal failed: {e}")))?;
108
109    let visited: Vec<Value> = result
110        .visited
111        .iter()
112        .filter_map(|id| {
113            session.graph().get_node(*id).map(|node| {
114                json!({
115                    "id": node.id,
116                    "event_type": node.event_type.name(),
117                    "content": node.content,
118                    "confidence": node.confidence,
119                    "depth": result.depths.get(id).copied().unwrap_or(0),
120                })
121            })
122        })
123        .collect();
124
125    let edges: Vec<Value> = result
126        .edges_traversed
127        .iter()
128        .map(|e| {
129            json!({
130                "source_id": e.source_id,
131                "target_id": e.target_id,
132                "edge_type": e.edge_type.name(),
133                "weight": e.weight,
134            })
135        })
136        .collect();
137
138    Ok(ToolCallResult::json(&json!({
139        "start_id": params.start_id,
140        "visited_count": visited.len(),
141        "visited": visited,
142        "edges_traversed": edges,
143    })))
144}