agentic_memory_mcp/tools/
memory_causal.rs1use std::sync::Arc;
4use tokio::sync::Mutex;
5
6use serde::Deserialize;
7use serde_json::{json, Value};
8
9use agentic_memory::{CausalParams, EdgeType};
10
11use crate::session::SessionManager;
12use crate::types::{McpError, McpResult, ToolCallResult, ToolDefinition};
13
14#[derive(Debug, Deserialize)]
15struct CausalInputParams {
16 node_id: u64,
17 #[serde(default = "default_max_depth")]
18 max_depth: u32,
19}
20
/// Serde default for `CausalInputParams::max_depth` when `max_depth` is absent
/// from the tool arguments.
fn default_max_depth() -> u32 {
    const FALLBACK_DEPTH: u32 = 5;
    FALLBACK_DEPTH
}
24
25pub fn definition() -> ToolDefinition {
27 ToolDefinition {
28 name: "memory_causal".to_string(),
29 description: Some(
30 "Impact analysis — find everything that depends on a given node".to_string(),
31 ),
32 input_schema: json!({
33 "type": "object",
34 "properties": {
35 "node_id": { "type": "integer" },
36 "max_depth": { "type": "integer", "default": 5 }
37 },
38 "required": ["node_id"]
39 }),
40 }
41}
42
43pub async fn execute(
45 args: Value,
46 session: &Arc<Mutex<SessionManager>>,
47) -> McpResult<ToolCallResult> {
48 let params: CausalInputParams =
49 serde_json::from_value(args).map_err(|e| McpError::InvalidParams(e.to_string()))?;
50
51 let causal_params = CausalParams {
52 node_id: params.node_id,
53 max_depth: params.max_depth,
54 dependency_types: vec![EdgeType::CausedBy, EdgeType::Supports],
55 };
56
57 let session = session.lock().await;
58
59 let result = session
60 .query_engine()
61 .causal(session.graph(), causal_params)
62 .map_err(|e| McpError::AgenticMemory(format!("Causal analysis failed: {e}")))?;
63
64 let dependents: Vec<Value> = result
65 .dependents
66 .iter()
67 .filter_map(|id| {
68 session.graph().get_node(*id).map(|node| {
69 json!({
70 "id": node.id,
71 "event_type": node.event_type.name(),
72 "content": node.content,
73 "confidence": node.confidence,
74 })
75 })
76 })
77 .collect();
78
79 Ok(ToolCallResult::json(&json!({
80 "root_id": result.root_id,
81 "dependent_count": result.dependents.len(),
82 "affected_decisions": result.affected_decisions,
83 "affected_inferences": result.affected_inferences,
84 "dependents": dependents,
85 })))
86}