use std::collections::HashSet;
use std::sync::{Arc, Mutex};
use cortex_store::repo::MemoryRepo;
use cortex_store::Pool;
use rusqlite::{params, OptionalExtension};
use serde_json::{json, Value};
use crate::{GateId, ToolError, ToolHandler};
/// MCP tool registered as `cortex_reflect`: reports the memory candidates whose
/// source events belong to a single trace.
pub struct CortexReflectTool {
    /// Shared, mutex-guarded store pool; the tool locks it to run its queries.
    pool: Arc<Mutex<Pool>>,
}
// Hand-written `Debug`: prints only the type name and marks the output as
// non-exhaustive, so the pool field is never formatted.
impl std::fmt::Debug for CortexReflectTool {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = out.debug_struct("CortexReflectTool");
        builder.finish_non_exhaustive()
    }
}
impl CortexReflectTool {
#[must_use]
pub fn new(pool: Arc<Mutex<Pool>>) -> Self {
Self { pool }
}
}
impl ToolHandler for CortexReflectTool {
fn name(&self) -> &'static str {
"cortex_reflect"
}
fn gate_set(&self) -> &'static [GateId] {
&[GateId::FtsRead]
}
fn call(&self, params: Value) -> Result<Value, ToolError> {
tracing::info!("cortex_reflect called via MCP");
let explicit_trace_id = params["trace_id"]
.as_str()
.filter(|s| !s.trim().is_empty())
.map(ToOwned::to_owned);
let pool = self
.pool
.lock()
.map_err(|err| ToolError::Internal(format!("pool lock poisoned: {err}")))?;
let trace_id = match explicit_trace_id {
Some(id) => id,
None => {
let result = pool
.query_row(
"SELECT id FROM traces ORDER BY opened_at DESC, id DESC LIMIT 1",
[],
|row| row.get::<_, String>(0),
)
.optional()
.map_err(|err| {
ToolError::Internal(format!("failed to query most recent trace: {err}"))
})?;
match result {
Some(id) => id,
None => {
return Ok(json!({ "candidates": [], "count": 0 }));
}
}
}
};
let event_ids: HashSet<String> = {
let mut stmt = pool
.prepare("SELECT event_id FROM trace_events WHERE trace_id = ?1")
.map_err(|err| {
ToolError::Internal(format!("failed to prepare trace_events query: {err}"))
})?;
let rows: Vec<Result<String, rusqlite::Error>> = stmt
.query_map(params![trace_id], |row| row.get::<_, String>(0))
.map_err(|err| {
ToolError::Internal(format!(
"failed to query trace_events for trace {trace_id}: {err}"
))
})?
.collect();
rows.into_iter()
.collect::<Result<HashSet<_>, _>>()
.map_err(|err| {
ToolError::Internal(format!("error reading trace_events row: {err}"))
})?
};
if event_ids.is_empty() {
return Ok(json!({ "candidates": [], "count": 0 }));
}
let all_candidates = MemoryRepo::new(&pool)
.list_candidates()
.map_err(|err| ToolError::Internal(format!("failed to list candidates: {err}")))?;
let mut candidates: Vec<Value> = Vec::new();
for record in all_candidates {
let source_events: Vec<String> = record
.source_events_json
.as_array()
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(ToOwned::to_owned))
.collect()
})
.unwrap_or_default();
let belongs_to_trace = source_events.iter().any(|id| event_ids.contains(id));
if !belongs_to_trace {
continue;
}
candidates.push(json!({
"claim": record.claim,
"memory_type": record.memory_type,
"confidence": record.confidence,
"domains": record.domains_json,
}));
}
let count = candidates.len();
Ok(json!({
"candidates": candidates,
"count": count,
}))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Builds a tool backed by a freshly migrated in-memory SQLite database.
    fn make_tool() -> CortexReflectTool {
        let conn = rusqlite::Connection::open_in_memory().expect("in-memory sqlite");
        cortex_store::migrate::apply_pending(&conn).expect("migrations");
        let shared = Arc::new(Mutex::new(conn));
        CortexReflectTool::new(shared)
    }

    /// Checks that a tool response is the empty-candidates shape.
    fn assert_no_candidates(response: &Value) {
        assert_eq!(response["candidates"], json!([]));
        assert_eq!(response["count"], 0);
    }

    #[test]
    fn name_and_gate() {
        let tool = make_tool();
        assert_eq!(tool.name(), "cortex_reflect");
        let gates = tool.gate_set();
        assert!(!gates.is_empty());
        assert_eq!(gates, &[GateId::FtsRead]);
    }

    #[test]
    fn empty_store_returns_empty_candidates() {
        let response = make_tool().call(Value::Null).unwrap();
        assert_no_candidates(&response);
    }

    #[test]
    fn unknown_trace_id_returns_empty_candidates() {
        let request = json!({ "trace_id": "01J0000000000000000000000X" });
        let response = make_tool().call(request).unwrap();
        assert_no_candidates(&response);
    }
}