Skip to main content

agentic_memory_mcp/tools/
memory_ground.rs

1//! Tool: memory_ground — Verify a claim has memory backing (anti-hallucination).
2
3use std::sync::Arc;
4use tokio::sync::Mutex;
5
6use serde::Deserialize;
7use serde_json::{json, Value};
8
9use agentic_memory::TextSearchParams;
10
11use crate::session::SessionManager;
12use crate::types::{McpError, McpResult, ToolCallResult, ToolDefinition};
13
14#[derive(Debug, Deserialize)]
15struct GroundParams {
16    claim: String,
17    #[serde(default = "default_threshold")]
18    threshold: f32,
19}
20
/// Serde default for `GroundParams::threshold` when the caller omits it.
fn default_threshold() -> f32 {
    const DEFAULT_MIN_SCORE: f32 = 0.3;
    DEFAULT_MIN_SCORE
}
24
25/// Return the tool definition for memory_ground.
26pub fn definition() -> ToolDefinition {
27    ToolDefinition {
28        name: "memory_ground".to_string(),
29        description: Some(
30            "Verify a claim has memory backing. Returns verified/partial/ungrounded status \
31             to prevent hallucination about what was previously remembered."
32                .to_string(),
33        ),
34        input_schema: json!({
35            "type": "object",
36            "required": ["claim"],
37            "properties": {
38                "claim": {
39                    "type": "string",
40                    "description": "The claim to verify against stored memories"
41                },
42                "threshold": {
43                    "type": "number",
44                    "default": 0.3,
45                    "description": "Minimum BM25 score to consider a match (0.0-10.0)"
46                }
47            }
48        }),
49    }
50}
51
52/// Execute the memory_ground tool.
53pub async fn execute(
54    args: Value,
55    session: &Arc<Mutex<SessionManager>>,
56) -> McpResult<ToolCallResult> {
57    let params: GroundParams =
58        serde_json::from_value(args).map_err(|e| McpError::InvalidParams(e.to_string()))?;
59
60    if params.claim.trim().is_empty() {
61        return Ok(ToolCallResult::json(&json!({
62            "status": "ungrounded",
63            "claim": params.claim,
64            "reason": "Empty claim",
65            "suggestions": []
66        })));
67    }
68
69    let session = session.lock().await;
70    let graph = session.graph();
71
72    // Use BM25 text search to find matching memories
73    let results = session
74        .query_engine()
75        .text_search(
76            graph,
77            graph.term_index.as_ref(),
78            graph.doc_lengths.as_ref(),
79            TextSearchParams {
80                query: params.claim.clone(),
81                max_results: 10,
82                event_types: Vec::new(),
83                session_ids: Vec::new(),
84                min_score: 0.0,
85            },
86        )
87        .map_err(|e| McpError::AgenticMemory(format!("Grounding search failed: {e}")))?;
88
89    let threshold = params.threshold;
90
91    // Classify results
92    let strong: Vec<&agentic_memory::TextMatch> =
93        results.iter().filter(|m| m.score >= threshold).collect();
94
95    if strong.is_empty() {
96        // Try fuzzy suggestions
97        let suggestions = suggest_similar_content(graph, &params.claim);
98        return Ok(ToolCallResult::json(&json!({
99            "status": "ungrounded",
100            "claim": params.claim,
101            "reason": "No memory nodes match this claim",
102            "suggestions": suggestions
103        })));
104    }
105
106    // Build evidence from strong matches
107    let evidence: Vec<Value> = strong
108        .iter()
109        .filter_map(|m| {
110            graph.get_node(m.node_id).map(|node| {
111                json!({
112                    "node_id": node.id,
113                    "event_type": node.event_type.name(),
114                    "content": node.content,
115                    "confidence": node.confidence,
116                    "session_id": node.session_id,
117                    "created_at": node.created_at,
118                    "score": m.score,
119                    "matched_terms": m.matched_terms,
120                })
121            })
122        })
123        .collect();
124
125    let avg_score: f32 = strong.iter().map(|m| m.score).sum::<f32>() / strong.len() as f32;
126    let confidence = (avg_score / (avg_score + 1.0)).min(1.0);
127
128    Ok(ToolCallResult::json(&json!({
129        "status": "verified",
130        "claim": params.claim,
131        "confidence": confidence,
132        "evidence_count": evidence.len(),
133        "evidence": evidence
134    })))
135}
136
137/// Find memory content that is similar to the query (for suggestions).
138fn suggest_similar_content(graph: &agentic_memory::MemoryGraph, query: &str) -> Vec<String> {
139    let query_lower = query.to_lowercase();
140    let query_words: Vec<&str> = query_lower.split_whitespace().collect();
141    let mut suggestions: Vec<(f32, String)> = Vec::new();
142
143    for node in graph.nodes() {
144        let content_lower = node.content.to_lowercase();
145        // Count overlapping words
146        let overlap = query_words
147            .iter()
148            .filter(|w| content_lower.contains(**w))
149            .count();
150        if overlap > 0 {
151            let score = overlap as f32 / query_words.len().max(1) as f32;
152            let preview = if node.content.len() > 80 {
153                format!("{}...", &node.content[..80])
154            } else {
155                node.content.clone()
156            };
157            suggestions.push((score, preview));
158        }
159    }
160
161    suggestions.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
162    suggestions.truncate(5);
163    suggestions.into_iter().map(|(_, s)| s).collect()
164}