Skip to main content

agentic_memory_mcp/tools/
memory_suggest.rs

1//! Tool: memory_suggest — Find similar memories for corrections/suggestions.
2
3use std::sync::Arc;
4use tokio::sync::Mutex;
5
6use serde::Deserialize;
7use serde_json::{json, Value};
8
9use agentic_memory::TextSearchParams;
10
11use crate::session::SessionManager;
12use crate::types::{McpError, McpResult, ToolCallResult, ToolDefinition};
13
14#[derive(Debug, Deserialize)]
15struct SuggestParams {
16    query: String,
17    #[serde(default = "default_limit")]
18    limit: usize,
19}
20
/// Number of suggestions returned when the caller omits `limit`.
///
/// Used as the serde default for `SuggestParams::limit`. Keep this in sync
/// with the `"default": 5` advertised in the tool's input schema.
fn default_limit() -> usize {
    const DEFAULT_SUGGESTION_LIMIT: usize = 5;
    DEFAULT_SUGGESTION_LIMIT
}
24
25/// Return the tool definition for memory_suggest.
26pub fn definition() -> ToolDefinition {
27    ToolDefinition {
28        name: "memory_suggest".to_string(),
29        description: Some(
30            "Find similar memories when a claim doesn't match exactly. Useful for \
31             correcting misremembered facts or finding related knowledge."
32                .to_string(),
33        ),
34        input_schema: json!({
35            "type": "object",
36            "required": ["query"],
37            "properties": {
38                "query": {
39                    "type": "string",
40                    "description": "The query to find suggestions for"
41                },
42                "limit": {
43                    "type": "integer",
44                    "default": 5,
45                    "description": "Maximum number of suggestions"
46                }
47            }
48        }),
49    }
50}
51
52/// Execute the memory_suggest tool.
53pub async fn execute(
54    args: Value,
55    session: &Arc<Mutex<SessionManager>>,
56) -> McpResult<ToolCallResult> {
57    let params: SuggestParams =
58        serde_json::from_value(args).map_err(|e| McpError::InvalidParams(e.to_string()))?;
59
60    if params.query.trim().is_empty() {
61        return Ok(ToolCallResult::json(&json!({
62            "query": params.query,
63            "count": 0,
64            "suggestions": []
65        })));
66    }
67
68    let session = session.lock().await;
69    let graph = session.graph();
70
71    // Use text search with low threshold to catch partial matches
72    let results = session
73        .query_engine()
74        .text_search(
75            graph,
76            graph.term_index.as_ref(),
77            graph.doc_lengths.as_ref(),
78            TextSearchParams {
79                query: params.query.clone(),
80                max_results: params.limit * 2,
81                event_types: Vec::new(),
82                session_ids: Vec::new(),
83                min_score: 0.0,
84            },
85        )
86        .map_err(|e| McpError::AgenticMemory(format!("Suggest search failed: {e}")))?;
87
88    let mut suggestions: Vec<Value> = results
89        .iter()
90        .filter_map(|m| {
91            graph.get_node(m.node_id).map(|node| {
92                json!({
93                    "node_id": node.id,
94                    "event_type": node.event_type.name(),
95                    "content": node.content,
96                    "confidence": node.confidence,
97                    "relevance_score": m.score,
98                    "matched_terms": m.matched_terms,
99                    "session_id": node.session_id,
100                })
101            })
102        })
103        .collect();
104
105    // Also add word-overlap suggestions from content scanning
106    if suggestions.len() < params.limit {
107        let query_lower = params.query.to_lowercase();
108        let query_words: Vec<&str> = query_lower.split_whitespace().collect();
109        let existing_ids: Vec<u64> = results.iter().map(|m| m.node_id).collect();
110
111        let mut extra: Vec<(f32, Value)> = Vec::new();
112        for node in graph.nodes() {
113            if existing_ids.contains(&node.id) {
114                continue;
115            }
116            let content_lower = node.content.to_lowercase();
117            let overlap = query_words
118                .iter()
119                .filter(|w| content_lower.contains(**w))
120                .count();
121            if overlap > 0 {
122                let score = overlap as f32 / query_words.len().max(1) as f32;
123                extra.push((
124                    score,
125                    json!({
126                        "node_id": node.id,
127                        "event_type": node.event_type.name(),
128                        "content": node.content,
129                        "confidence": node.confidence,
130                        "relevance_score": score,
131                        "matched_terms": [],
132                        "session_id": node.session_id,
133                    }),
134                ));
135            }
136        }
137
138        extra.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
139        for (_, val) in extra.into_iter().take(params.limit - suggestions.len()) {
140            suggestions.push(val);
141        }
142    }
143
144    suggestions.truncate(params.limit);
145
146    Ok(ToolCallResult::json(&json!({
147        "query": params.query,
148        "count": suggestions.len(),
149        "suggestions": suggestions
150    })))
151}