use std::sync::{Arc, Mutex};
use cortex_retrieval::{LexicalDocument, LexicalIndex};
use cortex_store::repo::MemoryRepo;
use cortex_store::Pool;
use serde_json::json;
use crate::tool_handler::{GateId, ToolError, ToolHandler};
/// Number of results returned when the caller does not pass `limit`.
const DEFAULT_LIMIT: usize = 10;
/// Upper bound applied to any caller-supplied `limit`.
const MAX_LIMIT: usize = 50;
/// Tool handler (`cortex_search`) that runs a lexical search over active memories.
///
/// Every call reads memories through [`MemoryRepo`] using the shared pool.
/// It then builds an in-memory [`LexicalIndex`] from them and returns the
/// matching memories as JSON.
#[derive(Debug)]
pub struct CortexSearchTool {
    /// Shared store connection pool. `call` locks it only while reading memories.
    pub pool: Arc<Mutex<Pool>>,
}

impl ToolHandler for CortexSearchTool {
    fn name(&self) -> &'static str {
        "cortex_search"
    }

    fn gate_set(&self) -> &'static [GateId] {
        &[GateId::FtsRead, GateId::EmbeddingRead]
    }

    /// Runs a lexical search over memories with status `active`.
    ///
    /// Parameters (JSON object):
    /// - `query` (required): non-blank search string.
    /// - `limit` (optional): maximum number of results. Defaults to
    ///   `DEFAULT_LIMIT` and is capped at `MAX_LIMIT`.
    /// - `semantic` (optional bool): accepted, but the embedding path is not
    ///   wired. `true` only logs a warning, and lexical search is still used.
    /// - `session_id` (optional): accepted and currently ignored.
    ///
    /// Returns a JSON array of `{ "id", "content", "score", "domains" }`
    /// objects. They appear in the order produced by `LexicalIndex::search`,
    /// and `score` is that hit's `lexical_match` value.
    ///
    /// # Errors
    ///
    /// - `ToolError::InvalidParams` for a missing or blank `query`, a malformed
    ///   `limit`, or a query the lexical index rejects.
    /// - `ToolError::Internal` if the pool mutex is poisoned or active
    ///   memories cannot be read.
    fn call(&self, params: serde_json::Value) -> Result<serde_json::Value, ToolError> {
        let query = extract_query(&params)?;
        let limit = extract_limit(&params)?;
        let semantic = params
            .get("semantic")
            .and_then(serde_json::Value::as_bool)
            .unwrap_or(false);
        if semantic {
            tracing::warn!(
                "cortex_search: semantic=true requested but embedding repo is not wired; \
                 falling back to lexical+FTS5 only"
            );
        }

        // Hold the pool lock only for the store read. Indexing and scoring
        // below work on the owned records and don't need the connection.
        let mut memories = {
            let pool = self
                .pool
                .lock()
                .map_err(|e| ToolError::Internal(format!("pool lock poisoned: {e}")))?;
            let repo = MemoryRepo::new(&pool);
            repo.list_by_status("active").map_err(|e| {
                tracing::error!(error = %e, "cortex_search: failed to read active memories");
                ToolError::Internal(format!("failed to read active memories: {e}"))
            })?
        };
        // Defensive: drop any `pending_mcp_commit` record in case the status
        // query ever returns one.
        memories.retain(|m| m.status != "pending_mcp_commit");
        if memories.is_empty() {
            return Ok(json!([]));
        }

        let documents: Vec<LexicalDocument> = memories
            .iter()
            .map(|m| {
                LexicalDocument::accepted_memory(
                    m.id,
                    m.claim.clone(),
                    json_string_array(&m.domains_json),
                )
            })
            .collect();
        // Keep the index in a local: hits may borrow from it.
        let index = LexicalIndex::new(documents);
        let hits = index
            .search(&query)
            .map_err(|e| ToolError::InvalidParams(format!("search query error: {e}")))?;

        // Apply `limit` only after resolving hits back to memories, so an
        // unresolved hit cannot shrink the returned page.
        let results: Vec<serde_json::Value> = hits
            .into_iter()
            .filter_map(|hit| {
                let memory = memories.iter().find(|m| m.id == hit.document.id)?;
                Some(json!({
                    "id": memory.id.to_string(),
                    "content": memory.claim,
                    "score": hit.explanation.lexical_match,
                    "domains": json_string_array(&memory.domains_json),
                }))
            })
            .take(limit)
            .collect();
        Ok(json!(results))
    }
}

/// Collects the string elements of a JSON array, skipping non-string entries.
///
/// Returns an empty vector when `value` is not an array.
fn json_string_array(value: &serde_json::Value) -> Vec<String> {
    value
        .as_array()
        .map(|arr| {
            arr.iter()
                .filter_map(|v| v.as_str().map(str::to_owned))
                .collect()
        })
        .unwrap_or_default()
}
/// Pulls the required `query` string out of the tool parameters.
///
/// The returned string is the caller's text unchanged (it is not trimmed).
///
/// # Errors
///
/// Returns `ToolError::InvalidParams` in these cases:
/// - `query` is absent or is not a JSON string.
/// - `query` consists only of whitespace.
fn extract_query(params: &serde_json::Value) -> Result<String, ToolError> {
    let raw = match params.get("query").and_then(serde_json::Value::as_str) {
        Some(text) => text,
        None => {
            return Err(ToolError::InvalidParams(
                "query parameter is required".to_string(),
            ));
        }
    };
    if raw.trim().is_empty() {
        Err(ToolError::InvalidParams(
            "query must not be blank".to_string(),
        ))
    } else {
        Ok(raw.to_string())
    }
}
/// Reads the optional `limit` parameter.
///
/// An absent `limit` yields `DEFAULT_LIMIT`. Any present value is clamped to
/// `MAX_LIMIT`.
///
/// # Errors
///
/// Returns `ToolError::InvalidParams` when `limit` is present but is not a
/// non-negative integer, for example a negative number, a fraction, a string
/// or `null`.
fn extract_limit(params: &serde_json::Value) -> Result<usize, ToolError> {
    params.get("limit").map_or(Ok(DEFAULT_LIMIT), |raw| {
        let requested = raw.as_u64().ok_or_else(|| {
            ToolError::InvalidParams("limit must be a non-negative integer".to_string())
        })?;
        // A value that doesn't fit in `usize` (only possible on sub-64-bit
        // targets) is treated as "as many as allowed".
        Ok(usize::try_from(requested).map_or(MAX_LIMIT, |n| n.min(MAX_LIMIT)))
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extract_query_rejects_missing() {
        let err = extract_query(&json!({})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    #[test]
    fn extract_query_rejects_non_string() {
        let err = extract_query(&json!({"query": 5})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    #[test]
    fn extract_query_rejects_blank() {
        let err = extract_query(&json!({"query": " "})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    #[test]
    fn extract_query_accepts_non_empty() {
        let q = extract_query(&json!({"query": "rust memory"})).unwrap();
        assert_eq!(q, "rust memory");
    }

    #[test]
    fn extract_query_does_not_trim() {
        let q = extract_query(&json!({"query": "  rust "})).unwrap();
        assert_eq!(q, "  rust ");
    }

    #[test]
    fn extract_limit_defaults_to_ten() {
        assert_eq!(extract_limit(&json!({})).unwrap(), DEFAULT_LIMIT);
    }

    #[test]
    fn extract_limit_caps_at_fifty() {
        assert_eq!(extract_limit(&json!({"limit": 999})).unwrap(), MAX_LIMIT);
    }

    #[test]
    fn extract_limit_accepts_valid() {
        assert_eq!(extract_limit(&json!({"limit": 20})).unwrap(), 20);
    }

    #[test]
    fn extract_limit_accepts_zero() {
        assert_eq!(extract_limit(&json!({"limit": 0})).unwrap(), 0);
    }

    #[test]
    fn extract_limit_rejects_non_numeric_string() {
        let err = extract_limit(&json!({"limit": "bad"})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    #[test]
    fn extract_limit_rejects_negative() {
        let err = extract_limit(&json!({"limit": -1})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    #[test]
    fn extract_limit_rejects_null() {
        let err = extract_limit(&json!({"limit": null})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    // NOTE(review): this only checks a locally built slice. It does not call
    // `CortexSearchTool::gate_set`, because constructing the tool needs a
    // `Pool`. Replace it with a real assertion once a test pool is available.
    #[test]
    fn gate_set_is_correct() {
        use crate::tool_handler::GateId;
        let gates: &[GateId] = &[GateId::FtsRead, GateId::EmbeddingRead];
        assert_eq!(gates.len(), 2);
    }
}