Skip to main content

agentic_memory_mcp/tools/
memory_causal.rs

1//! Tool: memory_causal — Impact analysis: what depends on this node?
2
3use std::sync::Arc;
4use tokio::sync::Mutex;
5
6use serde::Deserialize;
7use serde_json::{json, Value};
8
9use agentic_memory::{CausalParams, EdgeType};
10
11use crate::session::SessionManager;
12use crate::types::{McpError, McpResult, ToolCallResult, ToolDefinition};
13
14#[derive(Debug, Deserialize)]
15struct CausalInputParams {
16    node_id: u64,
17    #[serde(default = "default_max_depth")]
18    max_depth: u32,
19}
20
/// Depth used for `CausalInputParams::max_depth` when the caller omits it.
const fn default_max_depth() -> u32 {
    5
}
24
25/// Return the tool definition for memory_causal.
26pub fn definition() -> ToolDefinition {
27    ToolDefinition {
28        name: "memory_causal".to_string(),
29        description: Some(
30            "Impact analysis — find everything that depends on a given node".to_string(),
31        ),
32        input_schema: json!({
33            "type": "object",
34            "properties": {
35                "node_id": { "type": "integer" },
36                "max_depth": { "type": "integer", "default": 5 }
37            },
38            "required": ["node_id"]
39        }),
40    }
41}
42
43/// Execute the memory_causal tool.
44pub async fn execute(
45    args: Value,
46    session: &Arc<Mutex<SessionManager>>,
47) -> McpResult<ToolCallResult> {
48    let params: CausalInputParams =
49        serde_json::from_value(args).map_err(|e| McpError::InvalidParams(e.to_string()))?;
50
51    let causal_params = CausalParams {
52        node_id: params.node_id,
53        max_depth: params.max_depth,
54        dependency_types: vec![EdgeType::CausedBy, EdgeType::Supports],
55    };
56
57    let session = session.lock().await;
58
59    let result = session
60        .query_engine()
61        .causal(session.graph(), causal_params)
62        .map_err(|e| McpError::AgenticMemory(format!("Causal analysis failed: {e}")))?;
63
64    let dependents: Vec<Value> = result
65        .dependents
66        .iter()
67        .filter_map(|id| {
68            session.graph().get_node(*id).map(|node| {
69                json!({
70                    "id": node.id,
71                    "event_type": node.event_type.name(),
72                    "content": node.content,
73                    "confidence": node.confidence,
74                })
75            })
76        })
77        .collect();
78
79    Ok(ToolCallResult::json(&json!({
80        "root_id": result.root_id,
81        "dependent_count": result.dependents.len(),
82        "affected_decisions": result.affected_decisions,
83        "affected_inferences": result.affected_inferences,
84        "dependents": dependents,
85    })))
86}