use std::collections::{BTreeMap, BTreeSet};
use cortex_core::MemoryId;
use cortex_llm::{LlmAdapter, LlmError, LlmMessage, LlmRequest, LlmResponse, LlmRole};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::parse::{parse_principle_candidates, principle_candidate_batch_json_schema};
use crate::schema::PrincipleCandidate;
use crate::ReflectError;
/// Minimum number of distinct memories required as support.
///
/// Applied twice in this module: to the whole window before any extraction
/// is attempted, and to the distinct `supporting_memory_ids` of every
/// candidate returned by the model.
pub const MIN_SUPPORTING_MEMORIES: usize = 3;
/// Minimum number of distinct normalized domains required as support.
///
/// Checked against the domains of the window, against the domains of the
/// memories a candidate cites, and against the candidate's own
/// `domains_observed` list.
pub const MIN_SUPPORTING_DOMAINS: usize = 2;
/// Model identifier placed in the `model` field of the extraction request.
pub const DEFAULT_PRINCIPLE_EXTRACTION_MODEL: &str = "replay-principles-v1";
/// One accepted memory offered as evidence for principle extraction.
///
/// The whole struct is serialized to JSON and embedded in the prompt sent to
/// the LLM adapter.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AcceptedMemory {
/// Identifier used as the key of the support index and cited by candidates.
pub id: MemoryId,
/// Claim text of the memory; this module only forwards it via serialization.
pub claim: String,
/// Domain labels; trimmed and ASCII-lowercased before they are counted,
/// blank labels are ignored.
pub domains: Vec<String>,
/// Conditions under which the memory applies; merged into the deterministic
/// candidate's `applies_when`.
pub applies_when: Vec<String>,
/// Conditions under which the memory does not apply; merged into the
/// deterministic candidate's `does_not_apply_when`.
pub does_not_apply_when: Vec<String>,
}
/// The batch of accepted memories considered in a single extraction pass.
///
/// Serialized as-is into the extraction prompt, so field names are part of
/// what the model sees.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PrincipleExtractionWindow {
/// Memories available as supporting evidence for candidates.
pub accepted_memories: Vec<AcceptedMemory>,
}
impl PrincipleExtractionWindow {
    /// Creates a window over the given accepted memories.
    #[must_use]
    pub fn new(accepted_memories: Vec<AcceptedMemory>) -> Self {
        Self { accepted_memories }
    }

    /// Builds the lookup from memory id to that memory's normalized domains.
    ///
    /// Domain labels are trimmed and ASCII-lowercased; blank labels are
    /// dropped, so a memory without usable domains maps to an empty set.
    ///
    /// When the same id occurs more than once in the window, the domains of
    /// all occurrences are merged. Collecting pairs straight into the map
    /// would keep only the last occurrence, which makes the threshold checks
    /// depend on the order of the window.
    fn support_index(&self) -> BTreeMap<MemoryId, BTreeSet<String>> {
        let mut index: BTreeMap<MemoryId, BTreeSet<String>> = BTreeMap::new();
        for memory in &self.accepted_memories {
            index.entry(memory.id).or_default().extend(
                memory
                    .domains
                    .iter()
                    .filter_map(|domain| normalized_domain(domain)),
            );
        }
        index
    }
}
/// Failures that can occur while extracting principle candidates via the LLM.
#[derive(Debug, Error)]
pub enum PrincipleExtractionError {
/// The LLM adapter returned an error while completing the request.
#[error("llm adapter failed: {0}")]
Adapter(#[from] LlmError),
/// The adapter's output could not be parsed into principle candidates.
#[error("principle candidate parse failed: {0}")]
Parse(#[from] ReflectError),
/// The window could not be serialized to JSON for the prompt.
#[error("principle extraction window serialization failed: {0}")]
WindowSerialization(#[from] serde_json::Error),
}
/// Asks the LLM adapter for principle candidates and keeps the supported ones.
///
/// A window that does not meet the memory/domain thresholds yields an empty
/// list without contacting the adapter. Otherwise the adapter's output is
/// parsed and every candidate failing the per-candidate threshold check is
/// discarded.
///
/// # Errors
///
/// Returns [`PrincipleExtractionError`] when the window cannot be serialized,
/// the adapter call fails, or the output cannot be parsed.
pub async fn extract_candidates(
    window: PrincipleExtractionWindow,
    adapter: &dyn LlmAdapter,
) -> Result<Vec<PrincipleCandidate>, PrincipleExtractionError> {
    let support_index = window.support_index();
    if window_meets_threshold(&support_index) {
        let request = extraction_request(&window)?;
        let response = adapter.complete(request).await?;
        let raw_output = response_output(&response);
        let batch = parse_principle_candidates(&raw_output)?;

        // Keep only candidates whose cited evidence exists in this window.
        let mut kept = Vec::new();
        for candidate in batch.candidate_principles {
            if candidate_meets_threshold(&candidate, &support_index) {
                kept.push(candidate);
            }
        }
        Ok(kept)
    } else {
        Ok(Vec::new())
    }
}
/// Builds a single principle candidate from the window without calling an LLM.
///
/// Returns an empty list when the window does not meet the memory/domain
/// thresholds. Otherwise the one candidate cites every memory of the window,
/// lists the sorted union of their normalized domains, and merges their
/// `applies_when` / `does_not_apply_when` entries (trimmed, de-duplicated,
/// sorted), substituting a fixed fallback sentence when a scope list is empty.
#[must_use]
pub fn extract_deterministic_candidates(
    window: &PrincipleExtractionWindow,
) -> Vec<PrincipleCandidate> {
    let support_index = window.support_index();
    if !window_meets_threshold(&support_index) {
        return Vec::new();
    }

    // Cite each memory once, in first-seen window order. The threshold checks
    // count distinct ids, so a repeated id must not be emitted twice here.
    let mut seen_ids = BTreeSet::new();
    let supporting_memory_ids = window
        .accepted_memories
        .iter()
        .map(|memory| memory.id)
        .filter(|id| seen_ids.insert(*id))
        .collect::<Vec<_>>();

    // Union of all normalized domains, sorted by going through a set.
    let domains_observed = support_index
        .values()
        .flatten()
        .cloned()
        .collect::<BTreeSet<_>>()
        .into_iter()
        .collect::<Vec<_>>();

    let applies_when = collect_scope(
        window
            .accepted_memories
            .iter()
            .flat_map(|memory| memory.applies_when.iter()),
    );
    let does_not_apply_when = collect_scope(
        window
            .accepted_memories
            .iter()
            .flat_map(|memory| memory.does_not_apply_when.iter()),
    );

    vec![PrincipleCandidate {
        statement: format!(
            "Preserve evidence-bound guidance across {} before doctrine promotion.",
            domains_observed.join(", ")
        ),
        supporting_memory_ids,
        contradicting_memory_ids: Vec::new(),
        domains_observed,
        applies_when: fallback_scope(applies_when, "the same pattern recurs across domains"),
        does_not_apply_when: fallback_scope(does_not_apply_when, "support is below threshold"),
        alternative_interpretations: vec![
            "The pattern may be local to the current active memory window.".to_string(),
        ],
        confidence: 0.7,
        overgeneralisation_risk: 0.3,
    }]
}
/// Builds the LLM request that asks for principle candidates over `window`.
///
/// The window is pretty-printed as JSON and appended to the user prompt; the
/// request carries the candidate batch JSON schema.
///
/// # Errors
///
/// Returns the `serde_json` error if the window cannot be serialized.
fn extraction_request(window: &PrincipleExtractionWindow) -> Result<LlmRequest, serde_json::Error> {
    let window_json = serde_json::to_string_pretty(window)?;
    let prompt = format!(
        "Extract cross-domain principle candidates from these accepted memories. Each candidate must cite at least {MIN_SUPPORTING_MEMORIES} supporting memories across at least {MIN_SUPPORTING_DOMAINS} domains.\n\n{window_json}"
    );
    let user_message = LlmMessage {
        role: LlmRole::User,
        content: prompt,
    };
    let request = LlmRequest {
        model: String::from(DEFAULT_PRINCIPLE_EXTRACTION_MODEL),
        system: String::from("Return PrincipleCandidateBatch JSON matching the supplied schema. Propose candidates only; do not promote doctrine."),
        messages: vec![user_message],
        temperature: 0.0,
        max_tokens: 4096,
        json_schema: Some(principle_candidate_batch_json_schema()),
        timeout_ms: 30_000,
    };
    Ok(request)
}
fn response_output(response: &LlmResponse) -> String {
response
.parsed_json
.as_ref()
.map_or_else(|| response.text.clone(), serde_json::Value::to_string)
}
/// Reports whether the window as a whole has enough evidence to extract from.
///
/// Requires at least `MIN_SUPPORTING_MEMORIES` indexed memories and at least
/// `MIN_SUPPORTING_DOMAINS` distinct domains across all of them.
fn window_meets_threshold(support_index: &BTreeMap<MemoryId, BTreeSet<String>>) -> bool {
    // Evaluated lazily so the domain union is only built when the memory
    // count already suffices.
    let distinct_domain_count = || {
        support_index
            .values()
            .flatten()
            .collect::<BTreeSet<_>>()
            .len()
    };
    support_index.len() >= MIN_SUPPORTING_MEMORIES
        && distinct_domain_count() >= MIN_SUPPORTING_DOMAINS
}
/// Reports whether a candidate is backed by enough evidence from the window.
///
/// A candidate passes when it cites at least `MIN_SUPPORTING_MEMORIES`
/// distinct memory ids, every cited id is present in `support_index`, the
/// cited memories together span at least `MIN_SUPPORTING_DOMAINS` domains,
/// and the candidate's own `domains_observed` contains at least that many
/// distinct non-blank normalized labels.
///
/// NOTE(review): `domains_observed` is only counted; it is not compared with
/// the domains of the cited memories.
fn candidate_meets_threshold(
    candidate: &PrincipleCandidate,
    support_index: &BTreeMap<MemoryId, BTreeSet<String>>,
) -> bool {
    let cited_ids: BTreeSet<MemoryId> = candidate.supporting_memory_ids.iter().copied().collect();
    if cited_ids.len() < MIN_SUPPORTING_MEMORIES {
        return false;
    }

    // Collecting into `Option` yields `None` as soon as one cited id is
    // missing from the window, which rejects the candidate.
    let Some(domains_per_memory) = cited_ids
        .iter()
        .map(|id| support_index.get(id))
        .collect::<Option<Vec<_>>>()
    else {
        return false;
    };
    let backed_domains: BTreeSet<&str> = domains_per_memory
        .into_iter()
        .flatten()
        .map(String::as_str)
        .collect();

    let claimed_domains: BTreeSet<String> = candidate
        .domains_observed
        .iter()
        .filter_map(|domain| normalized_domain(domain))
        .collect();

    backed_domains.len() >= MIN_SUPPORTING_DOMAINS
        && claimed_domains.len() >= MIN_SUPPORTING_DOMAINS
}
/// Canonicalizes a domain label for comparison.
///
/// Surrounding whitespace is removed and ASCII letters are lowercased.
/// Returns `None` when nothing but whitespace was supplied.
fn normalized_domain(domain: &str) -> Option<String> {
    match domain.trim() {
        "" => None,
        label => Some(label.to_ascii_lowercase()),
    }
}
/// Merges scope entries into a sorted list without duplicates.
///
/// Each entry is trimmed; entries that are blank after trimming are skipped.
/// Case is preserved, so entries differing only in case stay separate.
fn collect_scope<'a>(items: impl Iterator<Item = &'a String>) -> Vec<String> {
    let mut unique = BTreeSet::new();
    for raw in items {
        let entry = raw.trim();
        if !entry.is_empty() {
            unique.insert(entry.to_string());
        }
    }
    unique.into_iter().collect()
}
/// Guarantees a non-empty scope list.
///
/// Returns `scope` untouched when it has entries; otherwise returns a list
/// holding only `fallback`.
fn fallback_scope(scope: Vec<String>, fallback: &str) -> Vec<String> {
    if scope.is_empty() {
        vec![fallback.to_string()]
    } else {
        scope
    }
}