// agentic_memory_mcp/tools/memory_suggest.rs
use std::sync::Arc;

use tokio::sync::Mutex;

use serde::Deserialize;
use serde_json::{json, Value};

use agentic_memory::TextSearchParams;

use crate::session::SessionManager;
use crate::types::{McpError, McpResult, ToolCallResult, ToolDefinition};
14#[derive(Debug, Deserialize)]
15struct SuggestParams {
16 query: String,
17 #[serde(default = "default_limit")]
18 limit: usize,
19}
20
/// Serde default for `SuggestParams::limit`: cap results at five suggestions.
fn default_limit() -> usize {
    5
}
24
25pub fn definition() -> ToolDefinition {
27 ToolDefinition {
28 name: "memory_suggest".to_string(),
29 description: Some(
30 "Find similar memories when a claim doesn't match exactly. Useful for \
31 correcting misremembered facts or finding related knowledge."
32 .to_string(),
33 ),
34 input_schema: json!({
35 "type": "object",
36 "required": ["query"],
37 "properties": {
38 "query": {
39 "type": "string",
40 "description": "The query to find suggestions for"
41 },
42 "limit": {
43 "type": "integer",
44 "default": 5,
45 "description": "Maximum number of suggestions"
46 }
47 }
48 }),
49 }
50}
51
52pub async fn execute(
54 args: Value,
55 session: &Arc<Mutex<SessionManager>>,
56) -> McpResult<ToolCallResult> {
57 let params: SuggestParams =
58 serde_json::from_value(args).map_err(|e| McpError::InvalidParams(e.to_string()))?;
59
60 if params.query.trim().is_empty() {
61 return Ok(ToolCallResult::json(&json!({
62 "query": params.query,
63 "count": 0,
64 "suggestions": []
65 })));
66 }
67
68 let session = session.lock().await;
69 let graph = session.graph();
70
71 let results = session
73 .query_engine()
74 .text_search(
75 graph,
76 graph.term_index.as_ref(),
77 graph.doc_lengths.as_ref(),
78 TextSearchParams {
79 query: params.query.clone(),
80 max_results: params.limit * 2,
81 event_types: Vec::new(),
82 session_ids: Vec::new(),
83 min_score: 0.0,
84 },
85 )
86 .map_err(|e| McpError::AgenticMemory(format!("Suggest search failed: {e}")))?;
87
88 let mut suggestions: Vec<Value> = results
89 .iter()
90 .filter_map(|m| {
91 graph.get_node(m.node_id).map(|node| {
92 json!({
93 "node_id": node.id,
94 "event_type": node.event_type.name(),
95 "content": node.content,
96 "confidence": node.confidence,
97 "relevance_score": m.score,
98 "matched_terms": m.matched_terms,
99 "session_id": node.session_id,
100 })
101 })
102 })
103 .collect();
104
105 if suggestions.len() < params.limit {
107 let query_lower = params.query.to_lowercase();
108 let query_words: Vec<&str> = query_lower.split_whitespace().collect();
109 let existing_ids: Vec<u64> = results.iter().map(|m| m.node_id).collect();
110
111 let mut extra: Vec<(f32, Value)> = Vec::new();
112 for node in graph.nodes() {
113 if existing_ids.contains(&node.id) {
114 continue;
115 }
116 let content_lower = node.content.to_lowercase();
117 let overlap = query_words
118 .iter()
119 .filter(|w| content_lower.contains(**w))
120 .count();
121 if overlap > 0 {
122 let score = overlap as f32 / query_words.len().max(1) as f32;
123 extra.push((
124 score,
125 json!({
126 "node_id": node.id,
127 "event_type": node.event_type.name(),
128 "content": node.content,
129 "confidence": node.confidence,
130 "relevance_score": score,
131 "matched_terms": [],
132 "session_id": node.session_id,
133 }),
134 ));
135 }
136 }
137
138 extra.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
139 for (_, val) in extra.into_iter().take(params.limit - suggestions.len()) {
140 suggestions.push(val);
141 }
142 }
143
144 suggestions.truncate(params.limit);
145
146 Ok(ToolCallResult::json(&json!({
147 "query": params.query,
148 "count": suggestions.len(),
149 "suggestions": suggestions
150 })))
151}