use std::sync::{Arc, Mutex};
use cortex_reflect::{extract_deterministic_candidates, AcceptedMemory, PrincipleExtractionWindow};
use cortex_store::repo::MemoryRepo;
use cortex_store::Pool;
use serde_json::{json, Value};
use crate::tool_handler::{GateId, ToolError, ToolHandler};
/// Number of candidates returned when the caller omits `limit` or passes `null`.
const DEFAULT_LIMIT: usize = 10;
/// Upper bound applied to any caller-supplied `limit`; larger requests are clamped to this.
const MAX_LIMIT: usize = 50;
/// Tool handler exposed under the name `cortex_principles_extract`.
///
/// Reads memories with status `"active"` from the store and turns them into
/// principle candidates via `extract_deterministic_candidates`.
#[derive(Debug)]
pub struct CortexPrinciplesExtractTool {
/// Shared handle to the store; the mutex is locked while a call reads from it.
pool: Arc<Mutex<Pool>>,
}
impl CortexPrinciplesExtractTool {
/// Creates a handler that stores the given shared `pool` handle.
///
/// Construction performs no store access; the pool is only locked when the
/// tool is called.
#[must_use]
pub fn new(pool: Arc<Mutex<Pool>>) -> Self {
Self { pool }
}
}
impl ToolHandler for CortexPrinciplesExtractTool {
fn name(&self) -> &'static str {
"cortex_principles_extract"
}
fn gate_set(&self) -> &'static [GateId] {
&[GateId::FtsRead]
}
fn call(&self, params: Value) -> Result<Value, ToolError> {
tracing::info!("cortex_principles_extract called via MCP");
let limit = extract_limit(¶ms)?;
let domain_filter = extract_domains(¶ms)?;
let pool = self
.pool
.lock()
.map_err(|err| ToolError::Internal(format!("pool lock poisoned: {err}")))?;
let repo = MemoryRepo::new(&pool);
let memories = repo.list_by_status("active").map_err(|err| {
tracing::error!(error = %err, "cortex_principles_extract: failed to read active memories");
ToolError::Internal(format!("failed to read active memories: {err}"))
})?;
let accepted: Vec<AcceptedMemory> = memories
.into_iter()
.map(|m| AcceptedMemory {
id: m.id,
claim: m.claim,
domains: json_string_array(&m.domains_json),
applies_when: json_string_array(&m.applies_when_json),
does_not_apply_when: json_string_array(&m.does_not_apply_when_json),
})
.collect();
let window = PrincipleExtractionWindow::new(accepted);
let mut candidates = extract_deterministic_candidates(&window);
if !domain_filter.is_empty() {
let normalized: Vec<String> = domain_filter
.iter()
.map(|d| d.trim().to_ascii_lowercase())
.collect();
candidates.retain(|c| {
c.domains_observed
.iter()
.any(|d| normalized.contains(&d.trim().to_ascii_lowercase()))
});
}
let result: Vec<Value> = candidates
.into_iter()
.take(limit)
.enumerate()
.map(|(i, c)| {
json!({
"id": format!("candidate-{i}"),
"claim": c.statement,
"confidence": c.confidence,
"domains": c.domains_observed,
})
})
.collect();
let count = result.len();
Ok(json!({
"candidates": result,
"count": count,
}))
}
}
/// Reads the optional `limit` parameter.
///
/// A missing or `null` value yields `DEFAULT_LIMIT`. Any other value must be
/// representable by `Value::as_u64`; the result is clamped to `MAX_LIMIT`.
///
/// # Errors
///
/// Returns `ToolError::InvalidParams` when `limit` is present, non-null and
/// not a non-negative integer.
fn extract_limit(params: &Value) -> Result<usize, ToolError> {
    let raw = match params.get("limit") {
        Some(value) if !value.is_null() => value,
        _ => return Ok(DEFAULT_LIMIT),
    };

    let requested = match raw.as_u64() {
        Some(n) => n,
        None => {
            return Err(ToolError::InvalidParams(
                "limit must be a non-negative integer".to_string(),
            ))
        }
    };

    // A value too large for `usize` is certainly above the cap, so it maps
    // straight to `MAX_LIMIT`.
    let capped = usize::try_from(requested).map_or(MAX_LIMIT, |n| n.min(MAX_LIMIT));
    Ok(capped)
}
/// Reads the optional `domains` parameter as a list of owned strings.
///
/// A missing or `null` value yields an empty list. The strings are returned
/// exactly as supplied (no trimming or case folding happens here).
///
/// # Errors
///
/// Returns `ToolError::InvalidParams` when `domains` is not an array, or when
/// any element is not a string (the message names the first offending index).
fn extract_domains(params: &Value) -> Result<Vec<String>, ToolError> {
    let raw = match params.get("domains") {
        None | Some(Value::Null) => return Ok(Vec::new()),
        Some(value) => value,
    };

    let items = match raw {
        Value::Array(items) => items,
        other => {
            return Err(ToolError::InvalidParams(format!(
                "domains must be an array of strings, got {other}"
            )))
        }
    };

    // Collecting into `Result` stops at the first non-string element.
    items
        .iter()
        .enumerate()
        .map(|(index, item)| {
            item.as_str().map(str::to_owned).ok_or_else(|| {
                ToolError::InvalidParams(format!("domains[{index}] must be a string"))
            })
        })
        .collect()
}
fn json_string_array(value: &Value) -> Vec<String> {
value
.as_array()
.into_iter()
.flatten()
.filter_map(|v| v.as_str().map(ToOwned::to_owned))
.collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tool backed by an in-memory SQLite connection with pending
    /// migrations applied.
    fn make_tool() -> CortexPrinciplesExtractTool {
        let conn = rusqlite::Connection::open_in_memory().expect("in-memory sqlite");
        cortex_store::migrate::apply_pending(&conn).expect("migrations");
        let shared = Arc::new(Mutex::new(conn));
        CortexPrinciplesExtractTool::new(shared)
    }

    /// Asserts that `outcome` is an `Err(ToolError::InvalidParams(_))`.
    #[track_caller]
    fn assert_invalid_params<T: std::fmt::Debug>(outcome: Result<T, ToolError>) {
        let err = outcome.unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    #[test]
    fn name_and_gate() {
        let handler = make_tool();
        assert_eq!(handler.name(), "cortex_principles_extract");
        assert_eq!(handler.gate_set(), &[GateId::FtsRead]);
    }

    #[test]
    fn empty_store_returns_empty_candidates() {
        let handler = make_tool();
        let response = handler.call(Value::Null).unwrap();
        assert_eq!(response["candidates"], json!([]));
        assert_eq!(response["count"], 0);
    }

    #[test]
    fn limit_defaults_to_ten() {
        let params = json!({});
        let limit = extract_limit(&params).unwrap();
        assert_eq!(limit, DEFAULT_LIMIT);
    }

    #[test]
    fn limit_capped_at_fifty() {
        let params = json!({"limit": 999});
        let limit = extract_limit(&params).unwrap();
        assert_eq!(limit, MAX_LIMIT);
    }

    #[test]
    fn limit_rejects_non_integer() {
        assert_invalid_params(extract_limit(&json!({"limit": "bad"})));
    }

    #[test]
    fn domains_accepts_empty_array() {
        let params = json!({"domains": []});
        let domains = extract_domains(&params).unwrap();
        assert!(domains.is_empty());
    }

    #[test]
    fn domains_rejects_non_string_element() {
        assert_invalid_params(extract_domains(&json!({"domains": [42]})));
    }

    #[test]
    fn domains_rejects_non_array() {
        assert_invalid_params(extract_domains(&json!({"domains": "bad"})));
    }
}