cortex_mcp/tools/
search.rs1use std::sync::{Arc, Mutex};
16
17use cortex_retrieval::{LexicalDocument, LexicalIndex};
18use cortex_store::repo::MemoryRepo;
19use cortex_store::Pool;
20use serde_json::json;
21
22use crate::tool_handler::{GateId, ToolError, ToolHandler};
23
/// Result cap used when the caller omits the `limit` parameter.
const DEFAULT_LIMIT: usize = 10;

/// Hard ceiling applied to any caller-supplied `limit`.
const MAX_LIMIT: usize = 50;
29
30#[derive(Debug)]
45pub struct CortexSearchTool {
46 pub pool: Arc<Mutex<Pool>>,
49}
50
51impl ToolHandler for CortexSearchTool {
52 fn name(&self) -> &'static str {
53 "cortex_search"
54 }
55
56 fn gate_set(&self) -> &'static [GateId] {
57 &[GateId::FtsRead, GateId::EmbeddingRead]
58 }
59
60 fn call(&self, params: serde_json::Value) -> Result<serde_json::Value, ToolError> {
61 let query = extract_query(¶ms)?;
63 let semantic = params
64 .get("semantic")
65 .and_then(|v| v.as_bool())
66 .unwrap_or(false);
67 let limit = extract_limit(¶ms)?;
68
69 let _ = params.get("session_id");
71
72 if semantic {
73 tracing::warn!(
78 "cortex_search: semantic=true requested but embedding repo is not wired; \
79 falling back to lexical+FTS5 only"
80 );
81 }
82
83 let pool = self
85 .pool
86 .lock()
87 .map_err(|e| ToolError::Internal(format!("pool lock poisoned: {e}")))?;
88 let repo = MemoryRepo::new(&pool);
89 let memories = repo.list_by_status("active").map_err(|e| {
90 tracing::error!(error = %e, "cortex_search: failed to read active memories");
91 ToolError::Internal(format!("failed to read active memories: {e}"))
92 })?;
93
94 let memories: Vec<_> = memories
99 .into_iter()
100 .filter(|m| m.status != "pending_mcp_commit")
101 .collect();
102
103 if memories.is_empty() {
104 return Ok(json!([]));
105 }
106
107 let documents: Vec<LexicalDocument> = memories
109 .iter()
110 .map(|m| {
111 let domains = m
112 .domains_json
113 .as_array()
114 .map(|arr| {
115 arr.iter()
116 .filter_map(|v| v.as_str().map(str::to_owned))
117 .collect::<Vec<_>>()
118 })
119 .unwrap_or_default();
120 LexicalDocument::accepted_memory(m.id, m.claim.clone(), domains)
121 })
122 .collect();
123
124 let index = LexicalIndex::new(documents);
125 let hits = index
126 .search(&query)
127 .map_err(|e| ToolError::InvalidParams(format!("search query error: {e}")))?;
128
129 let results: Vec<serde_json::Value> = hits
131 .into_iter()
132 .take(limit)
133 .filter_map(|hit| {
134 let memory = memories.iter().find(|m| m.id == hit.document.id)?;
135 let domains: Vec<String> = memory
136 .domains_json
137 .as_array()
138 .map(|arr| {
139 arr.iter()
140 .filter_map(|v| v.as_str().map(str::to_owned))
141 .collect()
142 })
143 .unwrap_or_default();
144 Some(json!({
145 "id": memory.id.to_string(),
146 "content": memory.claim,
147 "score": hit.explanation.lexical_match,
148 "domains": domains,
149 }))
150 })
151 .collect();
152
153 Ok(json!(results))
154 }
155}
156
157fn extract_query(params: &serde_json::Value) -> Result<String, ToolError> {
158 let query = params
159 .get("query")
160 .and_then(|v| v.as_str())
161 .ok_or_else(|| ToolError::InvalidParams("query parameter is required".to_string()))?;
162 if query.trim().is_empty() {
163 return Err(ToolError::InvalidParams(
164 "query must not be blank".to_string(),
165 ));
166 }
167 Ok(query.to_owned())
168}
169
170fn extract_limit(params: &serde_json::Value) -> Result<usize, ToolError> {
171 match params.get("limit") {
172 None => Ok(DEFAULT_LIMIT),
173 Some(v) => {
174 let n = v.as_u64().ok_or_else(|| {
175 ToolError::InvalidParams("limit must be a non-negative integer".to_string())
176 })?;
177 let n = usize::try_from(n).unwrap_or(MAX_LIMIT);
178 Ok(n.min(MAX_LIMIT))
179 }
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extract_query_rejects_missing() {
        let err = extract_query(&json!({})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    #[test]
    fn extract_query_rejects_non_string() {
        let err = extract_query(&json!({"query": 42})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    #[test]
    fn extract_query_rejects_blank() {
        let err = extract_query(&json!({"query": " "})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    #[test]
    fn extract_query_accepts_non_empty() {
        let q = extract_query(&json!({"query": "rust memory"})).unwrap();
        assert_eq!(q, "rust memory");
    }

    #[test]
    fn extract_query_returns_query_untrimmed() {
        let q = extract_query(&json!({"query": "  rust  "})).unwrap();
        assert_eq!(q, "  rust  ");
    }

    #[test]
    fn extract_limit_defaults_to_ten() {
        assert_eq!(extract_limit(&json!({})).unwrap(), DEFAULT_LIMIT);
    }

    #[test]
    fn extract_limit_caps_at_fifty() {
        assert_eq!(extract_limit(&json!({"limit": 999})).unwrap(), MAX_LIMIT);
    }

    #[test]
    fn extract_limit_caps_u64_max() {
        assert_eq!(
            extract_limit(&json!({"limit": u64::MAX})).unwrap(),
            MAX_LIMIT
        );
    }

    #[test]
    fn extract_limit_accepts_valid() {
        assert_eq!(extract_limit(&json!({"limit": 20})).unwrap(), 20);
    }

    #[test]
    fn extract_limit_accepts_zero() {
        assert_eq!(extract_limit(&json!({"limit": 0})).unwrap(), 0);
    }

    #[test]
    fn extract_limit_rejects_string() {
        let err = extract_limit(&json!({"limit": "bad"})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    #[test]
    fn extract_limit_rejects_negative() {
        let err = extract_limit(&json!({"limit": -1})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    #[test]
    fn extract_limit_rejects_fractional() {
        let err = extract_limit(&json!({"limit": 2.5})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    // NOTE(review): this only pins the expected gate list; it does not call
    // `CortexSearchTool::gate_set`, because building the tool requires a
    // `Pool`. TODO: construct a tool over a test pool and assert on its
    // `gate_set()` instead.
    #[test]
    fn gate_set_is_correct() {
        let gates: &[GateId] = &[GateId::FtsRead, GateId::EmbeddingRead];
        assert_eq!(gates.len(), 2);
    }
}