cortex_reflect/
principles.rs1use std::collections::{BTreeMap, BTreeSet};
10
11use cortex_core::MemoryId;
12use cortex_llm::{LlmAdapter, LlmError, LlmMessage, LlmRequest, LlmResponse, LlmRole};
13use serde::{Deserialize, Serialize};
14use thiserror::Error;
15
16use crate::parse::{parse_principle_candidates, principle_candidate_batch_json_schema};
17use crate::schema::PrincipleCandidate;
18use crate::ReflectError;
19
/// Minimum number of distinct accepted memories that must back a principle
/// (and that a window must contain before extraction is attempted).
pub const MIN_SUPPORTING_MEMORIES: usize = 3;

/// Minimum number of distinct normalised domains the backing memories must span.
pub const MIN_SUPPORTING_DOMAINS: usize = 2;

/// Model identifier placed in every LLM principle-extraction request.
pub const DEFAULT_PRINCIPLE_EXTRACTION_MODEL: &str = "replay-principles-v1";
28
29#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
31pub struct AcceptedMemory {
32 pub id: MemoryId,
34 pub claim: String,
36 pub domains: Vec<String>,
38 pub applies_when: Vec<String>,
40 pub does_not_apply_when: Vec<String>,
42}
43
44#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
46pub struct PrincipleExtractionWindow {
47 pub accepted_memories: Vec<AcceptedMemory>,
49}
50
51impl PrincipleExtractionWindow {
52 #[must_use]
54 pub fn new(accepted_memories: Vec<AcceptedMemory>) -> Self {
55 Self { accepted_memories }
56 }
57
58 fn support_index(&self) -> BTreeMap<MemoryId, BTreeSet<String>> {
59 self.accepted_memories
60 .iter()
61 .map(|memory| {
62 (
63 memory.id,
64 memory
65 .domains
66 .iter()
67 .filter_map(|domain| normalized_domain(domain))
68 .collect(),
69 )
70 })
71 .collect()
72 }
73}
74
75#[derive(Debug, Error)]
77pub enum PrincipleExtractionError {
78 #[error("llm adapter failed: {0}")]
80 Adapter(#[from] LlmError),
81 #[error("principle candidate parse failed: {0}")]
83 Parse(#[from] ReflectError),
84 #[error("principle extraction window serialization failed: {0}")]
86 WindowSerialization(#[from] serde_json::Error),
87}
88
89pub async fn extract_candidates(
95 window: PrincipleExtractionWindow,
96 adapter: &dyn LlmAdapter,
97) -> Result<Vec<PrincipleCandidate>, PrincipleExtractionError> {
98 let support_index = window.support_index();
99 if !window_meets_threshold(&support_index) {
100 return Ok(Vec::new());
101 }
102
103 let response = adapter.complete(extraction_request(&window)?).await?;
104 let output = response_output(&response);
105 let batch = parse_principle_candidates(&output)?;
106
107 Ok(batch
108 .candidate_principles
109 .into_iter()
110 .filter(|candidate| candidate_meets_threshold(candidate, &support_index))
111 .collect())
112}
113
114#[must_use]
120pub fn extract_deterministic_candidates(
121 window: &PrincipleExtractionWindow,
122) -> Vec<PrincipleCandidate> {
123 let support_index = window.support_index();
124 if !window_meets_threshold(&support_index) {
125 return Vec::new();
126 }
127
128 let supporting_memory_ids = window
129 .accepted_memories
130 .iter()
131 .map(|memory| memory.id)
132 .collect::<Vec<_>>();
133 let domains_observed = support_index
134 .values()
135 .flat_map(|domains| domains.iter().cloned())
136 .collect::<BTreeSet<_>>()
137 .into_iter()
138 .collect::<Vec<_>>();
139 let applies_when = collect_scope(
140 window
141 .accepted_memories
142 .iter()
143 .flat_map(|memory| memory.applies_when.iter()),
144 );
145 let does_not_apply_when = collect_scope(
146 window
147 .accepted_memories
148 .iter()
149 .flat_map(|memory| memory.does_not_apply_when.iter()),
150 );
151
152 vec![PrincipleCandidate {
153 statement: format!(
154 "Preserve evidence-bound guidance across {} before doctrine promotion.",
155 domains_observed.join(", ")
156 ),
157 supporting_memory_ids,
158 contradicting_memory_ids: Vec::new(),
159 domains_observed,
160 applies_when: fallback_scope(applies_when, "the same pattern recurs across domains"),
161 does_not_apply_when: fallback_scope(does_not_apply_when, "support is below threshold"),
162 alternative_interpretations: vec![
163 "The pattern may be local to the current active memory window.".to_string(),
164 ],
165 confidence: 0.7,
166 overgeneralisation_risk: 0.3,
167 }]
168}
169
170fn extraction_request(window: &PrincipleExtractionWindow) -> Result<LlmRequest, serde_json::Error> {
171 let window_json = serde_json::to_string_pretty(window)?;
172 Ok(LlmRequest {
173 model: DEFAULT_PRINCIPLE_EXTRACTION_MODEL.to_string(),
174 system: "Return PrincipleCandidateBatch JSON matching the supplied schema. Propose candidates only; do not promote doctrine.".to_string(),
175 messages: vec![LlmMessage {
176 role: LlmRole::User,
177 content: format!(
178 "Extract cross-domain principle candidates from these accepted memories. Each candidate must cite at least {MIN_SUPPORTING_MEMORIES} supporting memories across at least {MIN_SUPPORTING_DOMAINS} domains.\n\n{window_json}"
179 ),
180 }],
181 temperature: 0.0,
182 max_tokens: 4096,
183 json_schema: Some(principle_candidate_batch_json_schema()),
184 timeout_ms: 30_000,
185 })
186}
187
188fn response_output(response: &LlmResponse) -> String {
189 response
190 .parsed_json
191 .as_ref()
192 .map_or_else(|| response.text.clone(), serde_json::Value::to_string)
193}
194
195fn window_meets_threshold(support_index: &BTreeMap<MemoryId, BTreeSet<String>>) -> bool {
196 if support_index.len() < MIN_SUPPORTING_MEMORIES {
197 return false;
198 }
199
200 let domains = support_index
201 .values()
202 .flat_map(|domains| domains.iter())
203 .collect::<BTreeSet<_>>();
204 domains.len() >= MIN_SUPPORTING_DOMAINS
205}
206
207fn candidate_meets_threshold(
208 candidate: &PrincipleCandidate,
209 support_index: &BTreeMap<MemoryId, BTreeSet<String>>,
210) -> bool {
211 let supporting_ids = candidate
212 .supporting_memory_ids
213 .iter()
214 .copied()
215 .collect::<BTreeSet<_>>();
216 if supporting_ids.len() < MIN_SUPPORTING_MEMORIES {
217 return false;
218 }
219
220 let mut supporting_domains = BTreeSet::new();
221 for id in &supporting_ids {
222 let Some(domains) = support_index.get(id) else {
223 return false;
224 };
225 supporting_domains.extend(domains.iter().cloned());
226 }
227
228 let candidate_domains = candidate
229 .domains_observed
230 .iter()
231 .filter_map(|domain| normalized_domain(domain))
232 .collect::<BTreeSet<_>>();
233
234 supporting_domains.len() >= MIN_SUPPORTING_DOMAINS
235 && candidate_domains.len() >= MIN_SUPPORTING_DOMAINS
236}
237
/// Normalises a domain label for comparison.
///
/// Trims surrounding whitespace and lower-cases the label. Returns `None`
/// when nothing is left after trimming.
///
/// Lower-casing is Unicode-aware, matching the Unicode-aware `trim`. With
/// ASCII-only lower-casing, non-ASCII labels that differ only in case
/// (e.g. `"Économie"` / `"ÉCONOMIE"`) would count as separate domains and
/// inflate the distinct-domain counts checked against the thresholds.
fn normalized_domain(domain: &str) -> Option<String> {
    let trimmed = domain.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_lowercase())
}
246
/// Trims each scope phrase, drops blank ones, and returns the remainder
/// de-duplicated in sorted order.
fn collect_scope<'a>(items: impl Iterator<Item = &'a String>) -> Vec<String> {
    let mut unique = BTreeSet::new();
    for item in items {
        let phrase = item.trim();
        if !phrase.is_empty() {
            unique.insert(phrase.to_string());
        }
    }
    unique.into_iter().collect()
}
261
/// Returns `scope` unchanged, or a one-element list holding `fallback` when
/// `scope` is empty.
fn fallback_scope(scope: Vec<String>, fallback: &str) -> Vec<String> {
    if scope.is_empty() {
        vec![fallback.to_owned()]
    } else {
        scope
    }
}