cortex_mcp/tools/
principles_extract.rs1use std::sync::{Arc, Mutex};
15
16use cortex_reflect::{extract_deterministic_candidates, AcceptedMemory, PrincipleExtractionWindow};
17use cortex_store::repo::MemoryRepo;
18use cortex_store::Pool;
19use serde_json::{json, Value};
20
21use crate::tool_handler::{GateId, ToolError, ToolHandler};
22
/// Number of candidates returned when the caller omits `limit` or passes `null`.
const DEFAULT_LIMIT: usize = 10;

/// Upper bound on `limit`; larger requests are clamped to this value by `extract_limit`.
const MAX_LIMIT: usize = 50;

// `extract_limit` returns `DEFAULT_LIMIT` as-is (it only clamps caller-supplied values),
// so the default itself must never exceed the cap. Fail the build if it ever does.
const _: () = assert!(DEFAULT_LIMIT <= MAX_LIMIT);
28
29#[derive(Debug)]
42pub struct CortexPrinciplesExtractTool {
43 pool: Arc<Mutex<Pool>>,
44}
45
46impl CortexPrinciplesExtractTool {
47 #[must_use]
49 pub fn new(pool: Arc<Mutex<Pool>>) -> Self {
50 Self { pool }
51 }
52}
53
54impl ToolHandler for CortexPrinciplesExtractTool {
55 fn name(&self) -> &'static str {
56 "cortex_principles_extract"
57 }
58
59 fn gate_set(&self) -> &'static [GateId] {
60 &[GateId::FtsRead]
61 }
62
63 fn call(&self, params: Value) -> Result<Value, ToolError> {
64 tracing::info!("cortex_principles_extract called via MCP");
65
66 let limit = extract_limit(¶ms)?;
67 let domain_filter = extract_domains(¶ms)?;
68
69 let pool = self
70 .pool
71 .lock()
72 .map_err(|err| ToolError::Internal(format!("pool lock poisoned: {err}")))?;
73
74 let repo = MemoryRepo::new(&pool);
75 let memories = repo.list_by_status("active").map_err(|err| {
76 tracing::error!(error = %err, "cortex_principles_extract: failed to read active memories");
77 ToolError::Internal(format!("failed to read active memories: {err}"))
78 })?;
79
80 let accepted: Vec<AcceptedMemory> = memories
81 .into_iter()
82 .map(|m| AcceptedMemory {
83 id: m.id,
84 claim: m.claim,
85 domains: json_string_array(&m.domains_json),
86 applies_when: json_string_array(&m.applies_when_json),
87 does_not_apply_when: json_string_array(&m.does_not_apply_when_json),
88 })
89 .collect();
90
91 let window = PrincipleExtractionWindow::new(accepted);
92 let mut candidates = extract_deterministic_candidates(&window);
93
94 if !domain_filter.is_empty() {
97 let normalized: Vec<String> = domain_filter
98 .iter()
99 .map(|d| d.trim().to_ascii_lowercase())
100 .collect();
101 candidates.retain(|c| {
102 c.domains_observed
103 .iter()
104 .any(|d| normalized.contains(&d.trim().to_ascii_lowercase()))
105 });
106 }
107
108 let result: Vec<Value> = candidates
109 .into_iter()
110 .take(limit)
111 .enumerate()
112 .map(|(i, c)| {
113 json!({
114 "id": format!("candidate-{i}"),
115 "claim": c.statement,
116 "confidence": c.confidence,
117 "domains": c.domains_observed,
118 })
119 })
120 .collect();
121
122 let count = result.len();
123 Ok(json!({
124 "candidates": result,
125 "count": count,
126 }))
127 }
128}
129
130fn extract_limit(params: &Value) -> Result<usize, ToolError> {
131 match params.get("limit") {
132 None | Some(Value::Null) => Ok(DEFAULT_LIMIT),
133 Some(v) => {
134 let n = v.as_u64().ok_or_else(|| {
135 ToolError::InvalidParams("limit must be a non-negative integer".to_string())
136 })?;
137 let n = usize::try_from(n).unwrap_or(MAX_LIMIT);
138 Ok(n.min(MAX_LIMIT))
139 }
140 }
141}
142
143fn extract_domains(params: &Value) -> Result<Vec<String>, ToolError> {
144 match params.get("domains") {
145 None | Some(Value::Null) => Ok(Vec::new()),
146 Some(Value::Array(arr)) => {
147 let mut tags = Vec::with_capacity(arr.len());
148 for (i, v) in arr.iter().enumerate() {
149 match v.as_str() {
150 Some(s) => tags.push(s.to_owned()),
151 None => {
152 return Err(ToolError::InvalidParams(format!(
153 "domains[{i}] must be a string"
154 )));
155 }
156 }
157 }
158 Ok(tags)
159 }
160 Some(other) => Err(ToolError::InvalidParams(format!(
161 "domains must be an array of strings, got {other}"
162 ))),
163 }
164}
165
166fn json_string_array(value: &Value) -> Vec<String> {
167 value
168 .as_array()
169 .into_iter()
170 .flatten()
171 .filter_map(|v| v.as_str().map(ToOwned::to_owned))
172 .collect()
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tool over a fresh, fully migrated in-memory SQLite store.
    fn make_tool() -> CortexPrinciplesExtractTool {
        let pool = rusqlite::Connection::open_in_memory().expect("in-memory sqlite");
        cortex_store::migrate::apply_pending(&pool).expect("migrations");
        CortexPrinciplesExtractTool::new(Arc::new(Mutex::new(pool)))
    }

    #[test]
    fn name_and_gate() {
        let tool = make_tool();
        assert_eq!(tool.name(), "cortex_principles_extract");
        assert_eq!(tool.gate_set(), &[GateId::FtsRead]);
    }

    #[test]
    fn empty_store_returns_empty_candidates() {
        let tool = make_tool();
        let result = tool.call(Value::Null).unwrap();
        assert_eq!(result["candidates"], json!([]));
        assert_eq!(result["count"], 0);
    }

    #[test]
    fn empty_store_with_domain_filter_returns_empty_candidates() {
        let tool = make_tool();
        let result = tool.call(json!({"domains": [" Rust "], "limit": 5})).unwrap();
        assert_eq!(result["candidates"], json!([]));
        assert_eq!(result["count"], 0);
    }

    #[test]
    fn call_rejects_invalid_params() {
        let tool = make_tool();
        let err = tool.call(json!({"limit": -3})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
        let err = tool.call(json!({"domains": [1]})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    #[test]
    fn limit_defaults_to_ten() {
        assert_eq!(extract_limit(&json!({})).unwrap(), DEFAULT_LIMIT);
    }

    #[test]
    fn limit_null_uses_default() {
        assert_eq!(extract_limit(&json!({"limit": null})).unwrap(), DEFAULT_LIMIT);
    }

    #[test]
    fn limit_within_range_is_kept() {
        assert_eq!(extract_limit(&json!({"limit": 0})).unwrap(), 0);
        assert_eq!(extract_limit(&json!({"limit": 7})).unwrap(), 7);
        assert_eq!(extract_limit(&json!({"limit": MAX_LIMIT})).unwrap(), MAX_LIMIT);
    }

    #[test]
    fn limit_capped_at_fifty() {
        assert_eq!(extract_limit(&json!({"limit": 999})).unwrap(), MAX_LIMIT);
        assert_eq!(extract_limit(&json!({"limit": u64::MAX})).unwrap(), MAX_LIMIT);
    }

    #[test]
    fn limit_rejects_non_integer() {
        let err = extract_limit(&json!({"limit": "bad"})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    #[test]
    fn limit_rejects_negative_and_fractional() {
        for bad in [json!({"limit": -1}), json!({"limit": 2.5})] {
            let err = extract_limit(&bad).unwrap_err();
            assert!(matches!(err, ToolError::InvalidParams(_)));
        }
    }

    #[test]
    fn domains_missing_or_null_is_empty() {
        assert!(extract_domains(&json!({})).unwrap().is_empty());
        assert!(extract_domains(&json!({"domains": null})).unwrap().is_empty());
    }

    #[test]
    fn domains_accepts_empty_array() {
        let tags = extract_domains(&json!({"domains": []})).unwrap();
        assert!(tags.is_empty());
    }

    #[test]
    fn domains_preserves_values_and_order() {
        let tags = extract_domains(&json!({"domains": ["Rust", " sql "]})).unwrap();
        assert_eq!(tags, vec!["Rust", " sql "]);
    }

    #[test]
    fn domains_rejects_non_string_element() {
        let err = extract_domains(&json!({"domains": [42]})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    #[test]
    fn domains_rejects_non_array() {
        let err = extract_domains(&json!({"domains": "bad"})).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams(_)));
    }

    #[test]
    fn json_string_array_keeps_only_strings() {
        assert_eq!(json_string_array(&json!(["a", 1, null, "b"])), vec!["a", "b"]);
        assert!(json_string_array(&json!({"a": "b"})).is_empty());
        assert!(json_string_array(&Value::Null).is_empty());
    }
}