1use std::sync::{Arc, Mutex};
17
18use cortex_retrieval::{EmbedRecord, Embedder, LocalStubEmbedder, STUB_BACKEND_ID};
19use cortex_store::repo::{EmbeddingRepo, MemoryRepo};
20use cortex_store::Pool;
21use serde_json::{json, Value};
22
23use crate::{GateId, ToolError, ToolHandler};
24
25#[derive(Debug)]
39pub struct CortexMemoryEmbedTool {
40 pool: Arc<Mutex<Pool>>,
41}
42
43impl CortexMemoryEmbedTool {
44 #[must_use]
46 pub fn new(pool: Arc<Mutex<Pool>>) -> Self {
47 Self { pool }
48 }
49}
50
51impl ToolHandler for CortexMemoryEmbedTool {
52 fn name(&self) -> &'static str {
53 "cortex_memory_embed"
54 }
55
56 fn gate_set(&self) -> &'static [GateId] {
57 &[GateId::SessionWrite]
58 }
59
60 fn call(&self, params: Value) -> Result<Value, ToolError> {
61 let preview = params
62 .get("preview")
63 .and_then(|v| v.as_bool())
64 .unwrap_or(false);
65
66 let model_hint = params
70 .get("model")
71 .and_then(|v| v.as_str())
72 .map(ToOwned::to_owned);
73
74 tracing::info!(
75 preview = %preview,
76 "cortex_memory_embed via MCP: preview={}", preview
77 );
78
79 if let Some(ref m) = model_hint {
80 tracing::info!(
81 model = %m,
82 "cortex_memory_embed: model hint supplied (stub backend only for now)"
83 );
84 }
85
86 let pool_guard = self
87 .pool
88 .lock()
89 .map_err(|_| ToolError::Internal("pool lock poisoned".into()))?;
90
91 let embedder = LocalStubEmbedder::new();
94 let backend_id = embedder.backend_id().to_owned();
95
96 let repo = MemoryRepo::new(&pool_guard);
98 let memories = repo
99 .list_by_status("active")
100 .map_err(|err| ToolError::Internal(format!("failed to read active memories: {err}")))?;
101
102 let total = memories.len();
103 let embed_repo = EmbeddingRepo::new(&pool_guard);
104 let now = chrono::Utc::now();
105
106 let mut enriched: usize = 0;
107 let mut skipped: usize = 0;
108
109 for memory in &memories {
110 let existing = embed_repo.read(&memory.id, &backend_id).map_err(|err| {
112 ToolError::Internal(format!(
113 "failed to read embedding for memory {}: {err}",
114 memory.id
115 ))
116 })?;
117
118 if existing.is_some() {
119 skipped += 1;
120 continue;
121 }
122
123 if preview {
124 enriched += 1;
126 continue;
127 }
128
129 let tags: Vec<String> = memory
131 .domains_json
132 .as_array()
133 .into_iter()
134 .flatten()
135 .filter_map(|v| v.as_str().map(ToOwned::to_owned))
136 .collect();
137
138 let vec = embedder.embed(&memory.claim, &tags).map_err(|err| {
139 ToolError::Internal(format!("embed failed for memory {}: {err}", memory.id))
140 })?;
141
142 let record = EmbedRecord::new(memory.id, STUB_BACKEND_ID, vec, now).map_err(|err| {
143 ToolError::Internal(format!(
144 "failed to build embed record for memory {}: {err}",
145 memory.id
146 ))
147 })?;
148
149 embed_repo.write(&record).map_err(|err| {
150 ToolError::Internal(format!(
151 "failed to write embedding for memory {}: {err}",
152 memory.id
153 ))
154 })?;
155
156 enriched += 1;
157 }
158
159 Ok(json!({
160 "enriched": enriched,
161 "skipped": skipped,
162 "total": total,
163 "backend": backend_id,
164 "preview": preview,
165 }))
166 }
167}
168
#[cfg(test)]
mod tests {
    // `Arc`, `Mutex`, `Pool`, `EmbeddingRepo`, `STUB_BACKEND_ID`, `GateId`,
    // and the tool itself all come from the parent module's imports.
    use super::*;

    /// Opens a fresh in-memory store with all pending migrations applied.
    fn make_pool() -> Arc<Mutex<Pool>> {
        let pool = cortex_store::Pool::open_in_memory().expect("in-memory sqlite");
        cortex_store::migrate::apply_pending(&pool).expect("in-memory migrations");
        Arc::new(Mutex::new(pool))
    }

    /// Builds a tool over an empty, migrated store.
    fn make_tool() -> CortexMemoryEmbedTool {
        CortexMemoryEmbedTool::new(make_pool())
    }

    /// Seeds one fixture event plus one `active` semantic memory citing it.
    ///
    /// `hash_tag` gives the event a per-test `payload_hash` / `event_hash`
    /// (`pp_<tag>` / `eh_<tag>`). Returns the new memory id as a string.
    fn seed_active_memory(pool: &Arc<Mutex<Pool>>, claim: &str, hash_tag: &str) -> String {
        let event_id = cortex_core::EventId::new().to_string();
        let memory_id = cortex_core::MemoryId::new().to_string();
        let guard = pool.lock().unwrap();
        guard
            .execute(
                "INSERT INTO events (
                    id, schema_version, observed_at, recorded_at, source_json,
                    event_type, trace_id, session_id, domain_tags_json, payload_json,
                    payload_hash, prev_event_hash, event_hash
                ) VALUES (
                    ?1, 1, '2026-05-14T00:00:00Z', '2026-05-14T00:00:00Z',
                    '{\"type\":\"tool\",\"name\":\"test\"}', 'cortex.event.tool_result.v1',
                    NULL, NULL, '[]', '{\"fixture\":true}',
                    ?2, NULL, ?3
                );",
                rusqlite::params![
                    event_id,
                    format!("pp_{hash_tag}"),
                    format!("eh_{hash_tag}")
                ],
            )
            .expect("insert event");
        let source_json = serde_json::json!([event_id]).to_string();
        guard
            .execute(
                "INSERT INTO memories (
                    id, memory_type, status, claim, source_episodes_json,
                    source_events_json, domains_json, salience_json, confidence,
                    authority, applies_when_json, does_not_apply_when_json,
                    created_at, updated_at
                ) VALUES (
                    ?1, 'semantic', 'active',
                    ?2,
                    '[]', ?3, '[]', json_object('score', 0.8), 0.8, 'user',
                    '[]', '[]',
                    '2026-05-14T00:00:00Z', '2026-05-14T00:00:00Z'
                );",
                rusqlite::params![memory_id, claim, source_json],
            )
            .expect("insert memory");
        memory_id
    }

    #[test]
    fn gate_set_declares_session_write() {
        let tool = make_tool();
        assert!(
            tool.gate_set().contains(&GateId::SessionWrite),
            "gate_set must include SessionWrite"
        );
    }

    #[test]
    fn tool_name_matches_schema_contract() {
        let tool = make_tool();
        assert_eq!(tool.name(), "cortex_memory_embed");
    }

    #[test]
    fn empty_store_returns_zero_counts() {
        let tool = make_tool();
        let result = tool
            .call(serde_json::json!({}))
            .expect("empty store must succeed");

        assert_eq!(result["total"], 0);
        assert_eq!(result["enriched"], 0);
        assert_eq!(result["skipped"], 0);
        assert_eq!(result["preview"], false);
        assert!(result["backend"].as_str().is_some());
    }

    #[test]
    fn preview_true_does_not_write_embeddings() {
        let pool = make_pool();
        let memory_id = seed_active_memory(&pool, "Test memory for embed tool.", "test2");

        let tool = CortexMemoryEmbedTool::new(Arc::clone(&pool));

        let preview_result = tool
            .call(serde_json::json!({ "preview": true }))
            .expect("preview must succeed");
        assert_eq!(preview_result["total"], 1);
        assert_eq!(preview_result["enriched"], 1);
        assert_eq!(preview_result["skipped"], 0);
        assert_eq!(preview_result["preview"], true);

        let guard = pool.lock().unwrap();
        let embed_repo = EmbeddingRepo::new(&guard);
        let mid: cortex_core::MemoryId = memory_id.parse().unwrap();
        let written = embed_repo.read(&mid, STUB_BACKEND_ID).unwrap();
        assert!(written.is_none(), "preview must not write embeddings");
    }

    #[test]
    fn second_run_skips_already_embedded_memories() {
        let pool = make_pool();
        let memory_id = seed_active_memory(&pool, "Second embed test memory.", "test3");

        let tool = CortexMemoryEmbedTool::new(Arc::clone(&pool));

        let first = tool
            .call(serde_json::json!({}))
            .expect("first run must succeed");
        assert_eq!(first["enriched"], 1);
        assert_eq!(first["skipped"], 0);

        // The embedding must be stored under the backend id the tool reports,
        // or the skip check on the next run could never find it. Scoped so the
        // guard is released before the tool locks the pool again.
        {
            let backend = first["backend"].as_str().expect("backend must be a string");
            let guard = pool.lock().unwrap();
            let embed_repo = EmbeddingRepo::new(&guard);
            let mid: cortex_core::MemoryId = memory_id.parse().unwrap();
            let written = embed_repo.read(&mid, backend).unwrap();
            assert!(
                written.is_some(),
                "first run must store an embedding under the reported backend"
            );
        }

        let second = tool
            .call(serde_json::json!({}))
            .expect("second run must succeed");
        assert_eq!(second["enriched"], 0);
        assert_eq!(second["skipped"], 1);
    }
}