Skip to main content

agentic_memory_mcp/tools/
memory_context.rs

1//! Tool: memory_context — Get full context (subgraph) around a node.
2
3use std::sync::Arc;
4use tokio::sync::Mutex;
5
6use serde::Deserialize;
7use serde_json::{json, Value};
8
9use crate::session::SessionManager;
10use crate::types::{McpError, McpResult, ToolCallResult, ToolDefinition};
11
12#[derive(Debug, Deserialize)]
13struct ContextParams {
14    node_id: u64,
15    #[serde(default = "default_depth")]
16    depth: u32,
17}
18
/// Depth used when the `memory_context` caller does not supply one.
///
/// Matches the `"default": 2` advertised in the tool's input schema.
fn default_depth() -> u32 {
    const DEFAULT_DEPTH: u32 = 2;
    DEFAULT_DEPTH
}
22
23/// Return the tool definition for memory_context.
24pub fn definition() -> ToolDefinition {
25    ToolDefinition {
26        name: "memory_context".to_string(),
27        description: Some("Get the full context (subgraph) around a node".to_string()),
28        input_schema: json!({
29            "type": "object",
30            "properties": {
31                "node_id": { "type": "integer" },
32                "depth": { "type": "integer", "default": 2, "minimum": 1, "maximum": 5 }
33            },
34            "required": ["node_id"]
35        }),
36    }
37}
38
39/// Execute the memory_context tool.
40pub async fn execute(
41    args: Value,
42    session: &Arc<Mutex<SessionManager>>,
43) -> McpResult<ToolCallResult> {
44    let params: ContextParams =
45        serde_json::from_value(args).map_err(|e| McpError::InvalidParams(e.to_string()))?;
46
47    let session = session.lock().await;
48
49    let subgraph = session
50        .query_engine()
51        .context(session.graph(), params.node_id, params.depth)
52        .map_err(|e| McpError::AgenticMemory(format!("Context query failed: {e}")))?;
53
54    let nodes: Vec<Value> = subgraph
55        .nodes
56        .iter()
57        .map(|event| {
58            json!({
59                "id": event.id,
60                "event_type": event.event_type.name(),
61                "content": event.content,
62                "confidence": event.confidence,
63                "session_id": event.session_id,
64            })
65        })
66        .collect();
67
68    let edges: Vec<Value> = subgraph
69        .edges
70        .iter()
71        .map(|e| {
72            json!({
73                "source_id": e.source_id,
74                "target_id": e.target_id,
75                "edge_type": e.edge_type.name(),
76                "weight": e.weight,
77            })
78        })
79        .collect();
80
81    Ok(ToolCallResult::json(&json!({
82        "center_id": subgraph.center_id,
83        "depth": params.depth,
84        "node_count": nodes.len(),
85        "edge_count": edges.len(),
86        "nodes": nodes,
87        "edges": edges,
88    })))
89}