use std::sync::{Arc, Mutex};
use cortex_retrieval::{EmbedRecord, Embedder, LocalStubEmbedder, STUB_BACKEND_ID};
use cortex_store::repo::{EmbeddingRepo, MemoryRepo};
use cortex_store::Pool;
use serde_json::{json, Value};
use crate::{GateId, ToolError, ToolHandler};
/// Tool handler exposed under the name `cortex_memory_embed`.
///
/// Holds the shared store pool that the handler locks in order to list active
/// memories and to read and write their embeddings.
#[derive(Debug)]
pub struct CortexMemoryEmbedTool {
/// Shared, mutex-guarded store pool; locked for the duration of each call.
pool: Arc<Mutex<Pool>>,
}
impl CortexMemoryEmbedTool {
#[must_use]
pub fn new(pool: Arc<Mutex<Pool>>) -> Self {
Self { pool }
}
}
impl ToolHandler for CortexMemoryEmbedTool {
fn name(&self) -> &'static str {
"cortex_memory_embed"
}
fn gate_set(&self) -> &'static [GateId] {
&[GateId::SessionWrite]
}
fn call(&self, params: Value) -> Result<Value, ToolError> {
let preview = params
.get("preview")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let model_hint = params
.get("model")
.and_then(|v| v.as_str())
.map(ToOwned::to_owned);
tracing::info!(
preview = %preview,
"cortex_memory_embed via MCP: preview={}", preview
);
if let Some(ref m) = model_hint {
tracing::info!(
model = %m,
"cortex_memory_embed: model hint supplied (stub backend only for now)"
);
}
let pool_guard = self
.pool
.lock()
.map_err(|_| ToolError::Internal("pool lock poisoned".into()))?;
let embedder = LocalStubEmbedder::new();
let backend_id = embedder.backend_id().to_owned();
let repo = MemoryRepo::new(&pool_guard);
let memories = repo
.list_by_status("active")
.map_err(|err| ToolError::Internal(format!("failed to read active memories: {err}")))?;
let total = memories.len();
let embed_repo = EmbeddingRepo::new(&pool_guard);
let now = chrono::Utc::now();
let mut enriched: usize = 0;
let mut skipped: usize = 0;
for memory in &memories {
let existing = embed_repo.read(&memory.id, &backend_id).map_err(|err| {
ToolError::Internal(format!(
"failed to read embedding for memory {}: {err}",
memory.id
))
})?;
if existing.is_some() {
skipped += 1;
continue;
}
if preview {
enriched += 1;
continue;
}
let tags: Vec<String> = memory
.domains_json
.as_array()
.into_iter()
.flatten()
.filter_map(|v| v.as_str().map(ToOwned::to_owned))
.collect();
let vec = embedder.embed(&memory.claim, &tags).map_err(|err| {
ToolError::Internal(format!("embed failed for memory {}: {err}", memory.id))
})?;
let record = EmbedRecord::new(memory.id, STUB_BACKEND_ID, vec, now).map_err(|err| {
ToolError::Internal(format!(
"failed to build embed record for memory {}: {err}",
memory.id
))
})?;
embed_repo.write(&record).map_err(|err| {
ToolError::Internal(format!(
"failed to write embedding for memory {}: {err}",
memory.id
))
})?;
enriched += 1;
}
Ok(json!({
"enriched": enriched,
"skipped": skipped,
"total": total,
"backend": backend_id,
"preview": preview,
}))
}
}
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;

    /// Opens a fresh in-memory store with all pending migrations applied.
    fn make_pool() -> Arc<Mutex<Pool>> {
        let pool = cortex_store::Pool::open_in_memory().expect("in-memory sqlite");
        cortex_store::migrate::apply_pending(&pool).expect("in-memory migrations");
        Arc::new(Mutex::new(pool))
    }

    /// Builds a tool backed by its own empty in-memory store.
    fn make_tool() -> CortexMemoryEmbedTool {
        CortexMemoryEmbedTool::new(make_pool())
    }

    /// Inserts one fixture event and one active semantic memory that cites it.
    ///
    /// `hash_tag` is appended to the `pp_` / `eh_` fixture hashes stored on the
    /// event row. Returns the new memory's id in string form. The pool lock is
    /// released before returning, so the caller may invoke the tool right away.
    fn seed_active_memory(pool: &Arc<Mutex<Pool>>, claim: &str, hash_tag: &str) -> String {
        let event_id = cortex_core::EventId::new().to_string();
        let memory_id = cortex_core::MemoryId::new().to_string();
        let payload_hash = format!("pp_{hash_tag}");
        let event_hash = format!("eh_{hash_tag}");

        let guard = pool.lock().unwrap();
        guard
            .execute(
                "INSERT INTO events (
                    id, schema_version, observed_at, recorded_at, source_json,
                    event_type, trace_id, session_id, domain_tags_json, payload_json,
                    payload_hash, prev_event_hash, event_hash
                ) VALUES (
                    ?1, 1, '2026-05-14T00:00:00Z', '2026-05-14T00:00:00Z',
                    '{\"type\":\"tool\",\"name\":\"test\"}', 'cortex.event.tool_result.v1',
                    NULL, NULL, '[]', '{\"fixture\":true}',
                    ?2, NULL, ?3
                );",
                rusqlite::params![event_id, payload_hash, event_hash],
            )
            .expect("insert event");

        let source_json = serde_json::json!([event_id]).to_string();
        guard
            .execute(
                "INSERT INTO memories (
                    id, memory_type, status, claim, source_episodes_json,
                    source_events_json, domains_json, salience_json, confidence,
                    authority, applies_when_json, does_not_apply_when_json,
                    created_at, updated_at
                ) VALUES (
                    ?1, 'semantic', 'active',
                    ?2,
                    '[]', ?3, '[]', json_object('score', 0.8), 0.8, 'user',
                    '[]', '[]',
                    '2026-05-14T00:00:00Z', '2026-05-14T00:00:00Z'
                );",
                rusqlite::params![memory_id, claim, source_json],
            )
            .expect("insert memory");

        memory_id
    }

    #[test]
    fn gate_set_declares_session_write() {
        let tool = make_tool();
        assert!(
            tool.gate_set().contains(&GateId::SessionWrite),
            "gate_set must include SessionWrite"
        );
    }

    #[test]
    fn tool_name_matches_schema_contract() {
        let tool = make_tool();
        assert_eq!(tool.name(), "cortex_memory_embed");
    }

    #[test]
    fn empty_store_returns_zero_counts() {
        let tool = make_tool();
        let result = tool
            .call(serde_json::json!({}))
            .expect("empty store must succeed");
        assert_eq!(result["total"], 0);
        assert_eq!(result["enriched"], 0);
        assert_eq!(result["skipped"], 0);
        assert_eq!(result["preview"], false);
        assert!(result["backend"].as_str().is_some());
    }

    #[test]
    fn preview_true_does_not_write_embeddings() {
        let pool = make_pool();
        let memory_id = seed_active_memory(&pool, "Test memory for embed tool.", "test2");

        let tool = CortexMemoryEmbedTool::new(Arc::clone(&pool));
        let preview_result = tool
            .call(serde_json::json!({ "preview": true }))
            .expect("preview must succeed");
        assert_eq!(preview_result["total"], 1);
        assert_eq!(preview_result["enriched"], 1);
        assert_eq!(preview_result["skipped"], 0);
        assert_eq!(preview_result["preview"], true);

        // The preview run must have left the embedding table untouched.
        let guard = pool.lock().unwrap();
        let embed_repo = EmbeddingRepo::new(&guard);
        let mid: cortex_core::MemoryId = memory_id.parse().unwrap();
        let written = embed_repo.read(&mid, STUB_BACKEND_ID).unwrap();
        assert!(written.is_none(), "preview must not write embeddings");
    }

    #[test]
    fn second_run_skips_already_embedded_memories() {
        let pool = make_pool();
        seed_active_memory(&pool, "Second embed test memory.", "test3");

        let tool = CortexMemoryEmbedTool::new(Arc::clone(&pool));
        let first = tool
            .call(serde_json::json!({}))
            .expect("first run must succeed");
        assert_eq!(first["enriched"], 1);
        assert_eq!(first["skipped"], 0);

        let second = tool
            .call(serde_json::json!({}))
            .expect("second run must succeed");
        assert_eq!(second["enriched"], 0);
        assert_eq!(second["skipped"], 1);
    }
}