Skip to main content

cortex_mcp/tools/
reflect.rs

//! `cortex_reflect` MCP tool handler.
//!
//! Read-only surface for SessionReflection candidates. Given an optional
//! `trace_id`, returns the candidate memory rows whose `source_events_json`
//! references at least one event belonging to that trace. When no `trace_id`
//! is supplied, the most recently opened trace in the store is used
//! (ordered by `opened_at DESC`, ties broken by `id DESC`).
//!
//! Nothing is written on this path — the handler queries `traces`,
//! `trace_events`, and `memories` but never inserts or updates any row. The
//! write path (candidate persistence) lives in `cortex_reflect::orchestrate::reflect`.
//!
//! Gate: [`GateId::FtsRead`].
//! Tier: supervised — logs at every entry.
14
15use std::collections::HashSet;
16use std::sync::{Arc, Mutex};
17
18use cortex_store::repo::MemoryRepo;
19use cortex_store::Pool;
20use rusqlite::{params, OptionalExtension};
21use serde_json::{json, Value};
22
23use crate::{GateId, ToolError, ToolHandler};
24
25/// MCP tool: `cortex_reflect`.
26///
27/// Schema:
28/// ```text
29/// cortex_reflect(
30///   trace_id?: string,
31/// ) -> {
32///   candidates: [{ claim, memory_type, confidence, domains }],
33///   count:      int,
34/// }
35/// ```
36///
37/// Returns the candidate memories recorded for the given trace (or the most
38/// recent trace if `trace_id` is absent) as a read-only preview. Never writes
39/// to the store.
40pub struct CortexReflectTool {
41    pool: Arc<Mutex<Pool>>,
42}
43
44impl std::fmt::Debug for CortexReflectTool {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        f.debug_struct("CortexReflectTool").finish_non_exhaustive()
47    }
48}
49
50impl CortexReflectTool {
51    /// Construct the tool over a shared store connection.
52    #[must_use]
53    pub fn new(pool: Arc<Mutex<Pool>>) -> Self {
54        Self { pool }
55    }
56}
57
58impl ToolHandler for CortexReflectTool {
59    fn name(&self) -> &'static str {
60        "cortex_reflect"
61    }
62
63    fn gate_set(&self) -> &'static [GateId] {
64        &[GateId::FtsRead]
65    }
66
67    fn call(&self, params: Value) -> Result<Value, ToolError> {
68        tracing::info!("cortex_reflect called via MCP");
69
70        let explicit_trace_id = params["trace_id"]
71            .as_str()
72            .filter(|s| !s.trim().is_empty())
73            .map(ToOwned::to_owned);
74
75        let pool = self
76            .pool
77            .lock()
78            .map_err(|err| ToolError::Internal(format!("pool lock poisoned: {err}")))?;
79
80        // Resolve the trace id — either caller-supplied or the most recently
81        // opened trace in the store.
82        let trace_id = match explicit_trace_id {
83            Some(id) => id,
84            None => {
85                let result = pool
86                    .query_row(
87                        "SELECT id FROM traces ORDER BY opened_at DESC, id DESC LIMIT 1",
88                        [],
89                        |row| row.get::<_, String>(0),
90                    )
91                    .optional()
92                    .map_err(|err| {
93                        ToolError::Internal(format!("failed to query most recent trace: {err}"))
94                    })?;
95
96                match result {
97                    Some(id) => id,
98                    None => {
99                        return Ok(json!({ "candidates": [], "count": 0 }));
100                    }
101                }
102            }
103        };
104
105        // Collect event IDs belonging to this trace.
106        let event_ids: HashSet<String> = {
107            let mut stmt = pool
108                .prepare("SELECT event_id FROM trace_events WHERE trace_id = ?1")
109                .map_err(|err| {
110                    ToolError::Internal(format!("failed to prepare trace_events query: {err}"))
111                })?;
112
113            // Collect rows eagerly so that `stmt` can be dropped at block end.
114            let rows: Vec<Result<String, rusqlite::Error>> = stmt
115                .query_map(params![trace_id], |row| row.get::<_, String>(0))
116                .map_err(|err| {
117                    ToolError::Internal(format!(
118                        "failed to query trace_events for trace {trace_id}: {err}"
119                    ))
120                })?
121                .collect();
122
123            rows.into_iter()
124                .collect::<Result<HashSet<_>, _>>()
125                .map_err(|err| {
126                    ToolError::Internal(format!("error reading trace_events row: {err}"))
127                })?
128        };
129
130        if event_ids.is_empty() {
131            return Ok(json!({ "candidates": [], "count": 0 }));
132        }
133
134        // List all candidate memories and filter to those whose
135        // source_events_json intersects with the trace's event IDs.
136        let all_candidates = MemoryRepo::new(&pool)
137            .list_candidates()
138            .map_err(|err| ToolError::Internal(format!("failed to list candidates: {err}")))?;
139
140        let mut candidates: Vec<Value> = Vec::new();
141        for record in all_candidates {
142            // source_events_json is a JSON array of event-id strings.
143            let source_events: Vec<String> = record
144                .source_events_json
145                .as_array()
146                .map(|arr| {
147                    arr.iter()
148                        .filter_map(|v| v.as_str().map(ToOwned::to_owned))
149                        .collect()
150                })
151                .unwrap_or_default();
152
153            let belongs_to_trace = source_events.iter().any(|id| event_ids.contains(id));
154            if !belongs_to_trace {
155                continue;
156            }
157
158            // domains_json is already a JSON array of strings.
159            candidates.push(json!({
160                "claim":       record.claim,
161                "memory_type": record.memory_type,
162                "confidence":  record.confidence,
163                "domains":     record.domains_json,
164            }));
165        }
166
167        let count = candidates.len();
168        Ok(json!({
169            "candidates": candidates,
170            "count": count,
171        }))
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Build a tool over a freshly migrated in-memory store.
    fn make_tool() -> CortexReflectTool {
        let pool = rusqlite::Connection::open_in_memory().expect("in-memory sqlite");
        cortex_store::migrate::apply_pending(&pool).expect("migrations");
        CortexReflectTool::new(Arc::new(Mutex::new(pool)))
    }

    /// Assert that `result` is the empty `{ candidates: [], count: 0 }` payload.
    fn assert_no_candidates(result: &serde_json::Value) {
        assert_eq!(result["candidates"], json!([]));
        assert_eq!(result["count"], 0);
    }

    #[test]
    fn name_and_gate() {
        let tool = make_tool();
        assert_eq!(tool.name(), "cortex_reflect");
        assert!(!tool.gate_set().is_empty());
        assert_eq!(tool.gate_set(), &[GateId::FtsRead]);
    }

    #[test]
    fn empty_store_returns_empty_candidates() {
        let result = make_tool().call(serde_json::Value::Null).unwrap();
        assert_no_candidates(&result);
    }

    #[test]
    fn unknown_trace_id_returns_empty_candidates() {
        let result = make_tool()
            .call(json!({ "trace_id": "01J0000000000000000000000X" }))
            .unwrap();
        assert_no_candidates(&result);
    }

    #[test]
    fn blank_trace_id_falls_back_to_most_recent_trace() {
        // Whitespace-only ids are treated as absent; with no traces stored,
        // the fallback lookup finds nothing.
        let result = make_tool().call(json!({ "trace_id": "   " })).unwrap();
        assert_no_candidates(&result);
    }

    #[test]
    fn null_trace_id_is_treated_as_absent() {
        let result = make_tool().call(json!({ "trace_id": null })).unwrap();
        assert_no_candidates(&result);
    }
}