cortex_mcp/tools/
reflect.rs1use std::collections::HashSet;
16use std::sync::{Arc, Mutex};
17
18use cortex_store::repo::MemoryRepo;
19use cortex_store::Pool;
20use rusqlite::{params, OptionalExtension};
21use serde_json::{json, Value};
22
23use crate::{GateId, ToolError, ToolHandler};
24
25pub struct CortexReflectTool {
41 pool: Arc<Mutex<Pool>>,
42}
43
44impl std::fmt::Debug for CortexReflectTool {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 f.debug_struct("CortexReflectTool").finish_non_exhaustive()
47 }
48}
49
50impl CortexReflectTool {
51 #[must_use]
53 pub fn new(pool: Arc<Mutex<Pool>>) -> Self {
54 Self { pool }
55 }
56}
57
58impl ToolHandler for CortexReflectTool {
59 fn name(&self) -> &'static str {
60 "cortex_reflect"
61 }
62
63 fn gate_set(&self) -> &'static [GateId] {
64 &[GateId::FtsRead]
65 }
66
67 fn call(&self, params: Value) -> Result<Value, ToolError> {
68 tracing::info!("cortex_reflect called via MCP");
69
70 let explicit_trace_id = params["trace_id"]
71 .as_str()
72 .filter(|s| !s.trim().is_empty())
73 .map(ToOwned::to_owned);
74
75 let pool = self
76 .pool
77 .lock()
78 .map_err(|err| ToolError::Internal(format!("pool lock poisoned: {err}")))?;
79
80 let trace_id = match explicit_trace_id {
83 Some(id) => id,
84 None => {
85 let result = pool
86 .query_row(
87 "SELECT id FROM traces ORDER BY opened_at DESC, id DESC LIMIT 1",
88 [],
89 |row| row.get::<_, String>(0),
90 )
91 .optional()
92 .map_err(|err| {
93 ToolError::Internal(format!("failed to query most recent trace: {err}"))
94 })?;
95
96 match result {
97 Some(id) => id,
98 None => {
99 return Ok(json!({ "candidates": [], "count": 0 }));
100 }
101 }
102 }
103 };
104
105 let event_ids: HashSet<String> = {
107 let mut stmt = pool
108 .prepare("SELECT event_id FROM trace_events WHERE trace_id = ?1")
109 .map_err(|err| {
110 ToolError::Internal(format!("failed to prepare trace_events query: {err}"))
111 })?;
112
113 let rows: Vec<Result<String, rusqlite::Error>> = stmt
115 .query_map(params![trace_id], |row| row.get::<_, String>(0))
116 .map_err(|err| {
117 ToolError::Internal(format!(
118 "failed to query trace_events for trace {trace_id}: {err}"
119 ))
120 })?
121 .collect();
122
123 rows.into_iter()
124 .collect::<Result<HashSet<_>, _>>()
125 .map_err(|err| {
126 ToolError::Internal(format!("error reading trace_events row: {err}"))
127 })?
128 };
129
130 if event_ids.is_empty() {
131 return Ok(json!({ "candidates": [], "count": 0 }));
132 }
133
134 let all_candidates = MemoryRepo::new(&pool)
137 .list_candidates()
138 .map_err(|err| ToolError::Internal(format!("failed to list candidates: {err}")))?;
139
140 let mut candidates: Vec<Value> = Vec::new();
141 for record in all_candidates {
142 let source_events: Vec<String> = record
144 .source_events_json
145 .as_array()
146 .map(|arr| {
147 arr.iter()
148 .filter_map(|v| v.as_str().map(ToOwned::to_owned))
149 .collect()
150 })
151 .unwrap_or_default();
152
153 let belongs_to_trace = source_events.iter().any(|id| event_ids.contains(id));
154 if !belongs_to_trace {
155 continue;
156 }
157
158 candidates.push(json!({
160 "claim": record.claim,
161 "memory_type": record.memory_type,
162 "confidence": record.confidence,
163 "domains": record.domains_json,
164 }));
165 }
166
167 let count = candidates.len();
168 Ok(json!({
169 "candidates": candidates,
170 "count": count,
171 }))
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Builds a tool backed by a fresh in-memory SQLite database with all
    /// migrations applied.
    fn make_tool() -> CortexReflectTool {
        let conn = rusqlite::Connection::open_in_memory().expect("in-memory sqlite");
        cortex_store::migrate::apply_pending(&conn).expect("migrations");
        let shared = Arc::new(Mutex::new(conn));
        CortexReflectTool::new(shared)
    }

    /// Checks the "no candidates" response shape.
    fn assert_no_candidates(response: &Value) {
        assert_eq!(response["candidates"], json!([]));
        assert_eq!(response["count"], 0);
    }

    #[test]
    fn name_and_gate() {
        let tool = make_tool();
        assert_eq!(tool.name(), "cortex_reflect");

        let gates = tool.gate_set();
        assert!(!gates.is_empty());
        assert_eq!(gates, &[GateId::FtsRead]);
    }

    #[test]
    fn empty_store_returns_empty_candidates() {
        let response = make_tool().call(Value::Null).unwrap();
        assert_no_candidates(&response);
    }

    #[test]
    fn unknown_trace_id_returns_empty_candidates() {
        let request = json!({ "trace_id": "01J0000000000000000000000X" });
        let response = make_tool().call(request).unwrap();
        assert_no_candidates(&response);
    }
}