Skip to main content

cortex_mcp/tools/
search.rs

1//! `cortex_search` — FTS5-backed memory search MCP tool.
2//!
3//! Mirrors the retrieval path used by `cortex memory search` (ADR 0045 §3
4//! gate-equivalence). Active memories are queried via
5//! [`cortex_retrieval::LexicalIndex`] and optionally boosted by the FTS5
6//! fuzzy scorer. Rows tagged `pending_mcp_commit` are excluded (ADR 0047 §2);
7//! since that status does not yet exist in the store schema,
8//! `list_by_status("active")` already provides the correct filter.
9//!
10//! The `semantic: true` parameter is accepted per the ADR 0045 §4 tool
11//! schema but degrades gracefully to FTS-only with a warning until the
12//! embedding repo is wired (no `EmbeddingRepo` or `LocalStubEmbedder` exist
13//! in the current codebase).
14
15use std::sync::{Arc, Mutex};
16
17use cortex_retrieval::{LexicalDocument, LexicalIndex};
18use cortex_store::repo::MemoryRepo;
19use cortex_store::Pool;
20use serde_json::json;
21
22use crate::tool_handler::{GateId, ToolError, ToolHandler};
23
/// Default result limit when the caller omits `limit`.
const DEFAULT_LIMIT: usize = 10;

/// Server-side cap on `limit` (ADR 0045 §4).
const MAX_LIMIT: usize = 50;

// The omitted-`limit` path returns `DEFAULT_LIMIT` without clamping it, so the
// default must never exceed the server-side cap. Checked at compile time.
const _: () = assert!(DEFAULT_LIMIT <= MAX_LIMIT);
29
30/// MCP handler for `cortex_search`.
31///
32/// Schema (ADR 0045 §4):
33/// ```jsonc
34/// cortex_search(
35///   query: string,          // required, non-empty
36///   semantic: bool,         // default false — accepted, FTS-only for now
37///   limit: int,             // default 10, capped at 50
38///   session_id?: string     // optional, accepted and ignored
39/// ) → [{ id, content, score, domains }]
40/// ```
41///
42/// `rusqlite::Connection` is not `Sync`; the pool is wrapped in a `Mutex`
43/// to satisfy the `Send + Sync` bound on [`ToolHandler`].
44#[derive(Debug)]
45pub struct CortexSearchTool {
46    /// SQLite connection pool, mutex-wrapped because `rusqlite::Connection`
47    /// is not `Sync`.
48    pub pool: Arc<Mutex<Pool>>,
49}
50
51impl ToolHandler for CortexSearchTool {
52    fn name(&self) -> &'static str {
53        "cortex_search"
54    }
55
56    fn gate_set(&self) -> &'static [GateId] {
57        &[GateId::FtsRead, GateId::EmbeddingRead]
58    }
59
60    fn call(&self, params: serde_json::Value) -> Result<serde_json::Value, ToolError> {
61        // --- param extraction ---
62        let query = extract_query(&params)?;
63        let semantic = params
64            .get("semantic")
65            .and_then(|v| v.as_bool())
66            .unwrap_or(false);
67        let limit = extract_limit(&params)?;
68
69        // session_id is accepted per schema but not used in the retrieval path.
70        let _ = params.get("session_id");
71
72        if semantic {
73            // Embedding repo and LocalStubEmbedder are not yet wired in this
74            // workspace. The tool degrades to FTS-only and emits a diagnostic
75            // rather than returning an error, preserving the caller's ability to
76            // get search results. Follow-on task: wire embedding read path.
77            tracing::warn!(
78                "cortex_search: semantic=true requested but embedding repo is not wired; \
79                 falling back to lexical+FTS5 only"
80            );
81        }
82
83        // --- retrieval ---
84        let pool = self
85            .pool
86            .lock()
87            .map_err(|e| ToolError::Internal(format!("pool lock poisoned: {e}")))?;
88        let repo = MemoryRepo::new(&pool);
89        let memories = repo.list_by_status("active").map_err(|e| {
90            tracing::error!(error = %e, "cortex_search: failed to read active memories");
91            ToolError::Internal(format!("failed to read active memories: {e}"))
92        })?;
93
94        // Filter out any rows with status pending_mcp_commit (ADR 0047 §2).
95        // list_by_status("active") already excludes these since the status
96        // does not exist in the current schema; the filter is explicit here
97        // so the gate contract is visible when the schema column lands.
98        let memories: Vec<_> = memories
99            .into_iter()
100            .filter(|m| m.status != "pending_mcp_commit")
101            .collect();
102
103        if memories.is_empty() {
104            return Ok(json!([]));
105        }
106
107        // Build lexical index and search.
108        let documents: Vec<LexicalDocument> = memories
109            .iter()
110            .map(|m| {
111                let domains = m
112                    .domains_json
113                    .as_array()
114                    .map(|arr| {
115                        arr.iter()
116                            .filter_map(|v| v.as_str().map(str::to_owned))
117                            .collect::<Vec<_>>()
118                    })
119                    .unwrap_or_default();
120                LexicalDocument::accepted_memory(m.id, m.claim.clone(), domains)
121            })
122            .collect();
123
124        let index = LexicalIndex::new(documents);
125        let hits = index
126            .search(&query)
127            .map_err(|e| ToolError::InvalidParams(format!("search query error: {e}")))?;
128
129        // Apply limit and compose result rows.
130        let results: Vec<serde_json::Value> = hits
131            .into_iter()
132            .take(limit)
133            .filter_map(|hit| {
134                let memory = memories.iter().find(|m| m.id == hit.document.id)?;
135                let domains: Vec<String> = memory
136                    .domains_json
137                    .as_array()
138                    .map(|arr| {
139                        arr.iter()
140                            .filter_map(|v| v.as_str().map(str::to_owned))
141                            .collect()
142                    })
143                    .unwrap_or_default();
144                Some(json!({
145                    "id": memory.id.to_string(),
146                    "content": memory.claim,
147                    "score": hit.explanation.lexical_match,
148                    "domains": domains,
149                }))
150            })
151            .collect();
152
153        Ok(json!(results))
154    }
155}
156
157fn extract_query(params: &serde_json::Value) -> Result<String, ToolError> {
158    let query = params
159        .get("query")
160        .and_then(|v| v.as_str())
161        .ok_or_else(|| ToolError::InvalidParams("query parameter is required".to_string()))?;
162    if query.trim().is_empty() {
163        return Err(ToolError::InvalidParams(
164            "query must not be blank".to_string(),
165        ));
166    }
167    Ok(query.to_owned())
168}
169
170fn extract_limit(params: &serde_json::Value) -> Result<usize, ToolError> {
171    match params.get("limit") {
172        None => Ok(DEFAULT_LIMIT),
173        Some(v) => {
174            let n = v.as_u64().ok_or_else(|| {
175                ToolError::InvalidParams("limit must be a non-negative integer".to_string())
176            })?;
177            let n = usize::try_from(n).unwrap_or(MAX_LIMIT);
178            Ok(n.min(MAX_LIMIT))
179        }
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extract_query_rejects_missing() {
        let err = extract_query(&json!({})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    #[test]
    fn extract_query_rejects_blank() {
        let err = extract_query(&json!({"query": "   "})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    #[test]
    fn extract_query_rejects_non_string() {
        let err = extract_query(&json!({"query": 42})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    #[test]
    fn extract_query_accepts_non_empty() {
        let q = extract_query(&json!({"query": "rust memory"})).unwrap();
        assert_eq!(q, "rust memory");
    }

    #[test]
    fn extract_limit_defaults_to_ten() {
        assert_eq!(extract_limit(&json!({})).unwrap(), DEFAULT_LIMIT);
    }

    #[test]
    fn extract_limit_caps_at_fifty() {
        assert_eq!(extract_limit(&json!({"limit": 999})).unwrap(), MAX_LIMIT);
    }

    #[test]
    fn extract_limit_accepts_valid() {
        assert_eq!(extract_limit(&json!({"limit": 20})).unwrap(), 20);
    }

    #[test]
    fn extract_limit_accepts_zero() {
        assert_eq!(extract_limit(&json!({"limit": 0})).unwrap(), 0);
    }

    #[test]
    fn extract_limit_rejects_non_numeric_string() {
        let err = extract_limit(&json!({"limit": "bad"})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    #[test]
    fn extract_limit_rejects_negative_number() {
        let err = extract_limit(&json!({"limit": -1})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    #[test]
    fn extract_limit_rejects_fractional() {
        let err = extract_limit(&json!({"limit": 2.5})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    #[test]
    fn gate_set_is_correct() {
        // NOTE(review): this only restates the expected gate list. It does not
        // call `CortexSearchTool::gate_set`, because building the tool needs a
        // `Pool` and this module has no store fixture.
        // TODO: construct the tool against an in-memory store and assert on
        // `tool.gate_set()` directly.
        let gates: &[GateId] = &[GateId::FtsRead, GateId::EmbeddingRead];
        assert_eq!(gates.len(), 2);
    }
}