use crate::config::ProjectConfig;
use crate::domain::{
CandidateNote, LifecycleCandidate, MatchedModule, MatchedProject, MatchedScene, MemoryRecord,
MemoryScope, Note, RouteInput, ScoredNote,
};
use crate::engine::scorer;
use std::collections::HashSet;
/// Score multiplier for a result pulled in through a one-hop relation (a wikilink,
/// a `related_memory` frontmatter entry, or a `related_records` id) instead of being
/// matched directly.
const HOP_PENALTY: f64 = 0.7;
/// Score multiplier for `User`/`Agent`/`Team`-scoped lifecycle candidates when a
/// project has been matched.
const CROSS_PROJECT_PENALTY: f64 = 0.6;
/// Reciprocal Rank Fusion smoothing constant: each ranked list contributes
/// `1 / (RRF_K + rank + 1)` for a zero-based `rank`. Declared once for every feature
/// that enables a fusion path so the value cannot drift between them.
#[cfg(any(feature = "bm25", feature = "embedding"))]
const RRF_K: f64 = 60.0;
/// Scores every note against the routed input and returns at most `limit` notes.
///
/// Each note is scored with [`scorer::score_note`]; notes scoring `<= 0` are dropped.
/// Survivors are ranked by descending score, ties broken by ascending `relative_path`,
/// and the top `limit` become the seed set.
///
/// One-hop expansion: the seeds' wikilinks and `related_memory` frontmatter strings
/// (with surrounding `[[`/`]]` stripped) become lowercase match targets. Any other scored
/// note whose lowercase title or path contains a target, or whose non-empty title is
/// contained in a target, is re-added with its score multiplied by [`HOP_PENALTY`].
/// Blank targets are ignored because an empty string matches everything.
/// The combined list is re-ranked and truncated to `limit`.
pub fn select_scored_notes(
    project_config: Option<&ProjectConfig>,
    project: Option<&MatchedProject>,
    modules: &[MatchedModule],
    scenes: &[MatchedScene],
    notes: &[Note],
    input: &RouteInput,
    limit: usize,
) -> Vec<ScoredNote> {
    // Descending score, then ascending path, so ties resolve deterministically.
    fn rank(left: &ScoredNote, right: &ScoredNote) -> std::cmp::Ordering {
        right
            .score
            .cmp(&left.score)
            .then_with(|| left.note.relative_path.cmp(&right.note.relative_path))
    }

    // Records a lowercase expansion target, skipping blank ones: an empty target is a
    // substring of every title and path and would mark every note as related.
    fn push_target(targets: &mut HashSet<String>, raw: &str) {
        if !raw.trim().is_empty() {
            targets.insert(raw.to_lowercase());
        }
    }

    let mut scored_notes: Vec<ScoredNote> = notes
        .iter()
        .filter_map(|note| {
            let (score, reasons, score_breakdown, confidence) =
                scorer::score_note(project_config, project, modules, scenes, note, input);
            if score <= 0 {
                return None;
            }
            Some(ScoredNote {
                note: note.clone(),
                score,
                reasons,
                score_breakdown,
                confidence,
                excerpt: note.excerpt_for_input(input, 220),
            })
        })
        .collect();
    scored_notes.sort_by(rank);

    let seeds = &scored_notes[..limit.min(scored_notes.len())];
    let selected_paths: HashSet<&str> = seeds
        .iter()
        .map(|s| s.note.relative_path.as_str())
        .collect();

    let mut expand_targets: HashSet<String> = HashSet::new();
    for seed in seeds {
        for link in &seed.note.wikilinks {
            push_target(&mut expand_targets, link);
        }
        if let Some(related) = seed.note.frontmatter.get("related_memory")
            && let Some(arr) = related.as_array()
        {
            for item in arr {
                if let Some(s) = item.as_str() {
                    push_target(
                        &mut expand_targets,
                        s.trim_start_matches("[[").trim_end_matches("]]"),
                    );
                }
            }
        }
    }

    // NOTE(review): every non-seed note scores no higher than the lowest seed, and the
    // hop penalty lowers it further, so as written an expanded note cannot survive the
    // final `truncate(limit)`. TODO confirm whether expansion should reach past `limit`.
    let mut hops: Vec<ScoredNote> = Vec::new();
    if !expand_targets.is_empty() {
        for scored in &scored_notes {
            if selected_paths.contains(scored.note.relative_path.as_str()) {
                continue;
            }
            let title_lc = scored.note.title.to_lowercase();
            let path_lc = scored.note.relative_path.to_lowercase();
            let is_related = expand_targets.iter().any(|target| {
                title_lc.contains(target.as_str())
                    || path_lc.contains(target.as_str())
                    // An empty title is contained in every target; don't treat that as a match.
                    || (!title_lc.is_empty() && target.contains(title_lc.as_str()))
            });
            if !is_related {
                continue;
            }
            let penalized_score = ((scored.score as f64) * HOP_PENALTY) as i32;
            if penalized_score > 0 {
                let mut expanded_note = scored.clone();
                expanded_note.score = penalized_score;
                expanded_note.reasons.push(format!(
                    "relation-expanded (1-hop, {:.0}% penalty)",
                    (1.0 - HOP_PENALTY) * 100.0
                ));
                hops.push(expanded_note);
            }
        }
    }

    let mut expanded: Vec<ScoredNote> = seeds.to_vec();
    expanded.extend(hops);
    expanded.sort_by(rank);
    expanded.truncate(limit);
    expanded
}
/// Runs [`select_scored_notes`] with the same arguments and converts every selected
/// note into a [`CandidateNote`], preserving the ranking order.
pub fn select_candidates(
    project_config: Option<&ProjectConfig>,
    project: Option<&MatchedProject>,
    modules: &[MatchedModule],
    scenes: &[MatchedScene],
    notes: &[Note],
    input: &RouteInput,
    limit: usize,
) -> Vec<CandidateNote> {
    let ranked = select_scored_notes(
        project_config,
        project,
        modules,
        scenes,
        notes,
        input,
        limit,
    );
    let mut candidates = Vec::with_capacity(ranked.len());
    for scored in ranked {
        candidates.push(CandidateNote::from(scored));
    }
    candidates
}
/// Ranks lifecycle memory records for `input` and returns at most `limit` candidates.
///
/// 1. Every record not in `excluded_record_ids` is scored with
///    [`scorer::score_lifecycle_candidate`]; records it returns `None` for are skipped.
/// 2. When a project is matched, `User`/`Agent`/`Team`-scoped candidates are multiplied
///    by [`CROSS_PROJECT_PENALTY`] (a reason is recorded when the score changes).
/// 3. Candidates are ranked by descending score, ties broken by ascending `record_id`;
///    the top `limit` become the seed set.
/// 4. One-hop expansion: each `related_records` id of a seed that is neither a seed nor
///    excluded is added. If it already has a candidate score, that score is multiplied
///    by [`HOP_PENALTY`]; otherwise it gets [`HOP_PENALTY`] times the best score among the
///    seeds that reference it. Entries whose penalized score is not positive are dropped.
/// 5. The combined list is re-ranked and truncated to `limit`.
///
/// Returns an empty vector when `limit` is 0.
pub fn select_lifecycle_candidates(
    project: Option<&MatchedProject>,
    records: &[(String, MemoryRecord)],
    input: &RouteInput,
    limit: usize,
    excluded_record_ids: &HashSet<String>,
    reference_map: Option<&crate::reference_tracker::ReferenceMap>,
) -> Vec<LifecycleCandidate> {
    // Descending score, then ascending record id, so ties resolve deterministically.
    fn rank(left: &LifecycleCandidate, right: &LifecycleCandidate) -> std::cmp::Ordering {
        right
            .score
            .cmp(&left.score)
            .then_with(|| left.record_id.cmp(&right.record_id))
    }

    if limit == 0 {
        return Vec::new();
    }

    let mut candidates: Vec<LifecycleCandidate> = records
        .iter()
        .filter(|(record_id, _)| !excluded_record_ids.contains(record_id))
        .filter_map(|(record_id, record)| {
            scorer::score_lifecycle_candidate(
                project,
                record_id,
                record,
                input,
                reference_map,
                Some(records),
            )
        })
        .collect();

    if project.is_some() {
        for candidate in &mut candidates {
            if matches!(
                candidate.scope,
                MemoryScope::User | MemoryScope::Agent | MemoryScope::Team
            ) {
                let penalized = ((candidate.score as f64) * CROSS_PROJECT_PENALTY) as i32;
                if penalized != candidate.score {
                    candidate.score = penalized;
                    candidate.reasons.push(format!(
                        "cross-project penalty ({:.0}%)",
                        (1.0 - CROSS_PROJECT_PENALTY) * 100.0
                    ));
                }
            }
        }
    }
    candidates.sort_by(rank);

    // Index records and candidates by id once instead of scanning linearly per lookup.
    // `or_insert` keeps the first occurrence, matching what a linear `find` returned.
    let mut record_index: std::collections::HashMap<&str, &MemoryRecord> =
        std::collections::HashMap::with_capacity(records.len());
    for (record_id, record) in records {
        record_index.entry(record_id.as_str()).or_insert(record);
    }
    let mut candidate_index: std::collections::HashMap<&str, &LifecycleCandidate> =
        std::collections::HashMap::with_capacity(candidates.len());
    for candidate in &candidates {
        candidate_index
            .entry(candidate.record_id.as_str())
            .or_insert(candidate);
    }

    let seeds = &candidates[..limit.min(candidates.len())];
    let selected_ids: HashSet<&str> = seeds.iter().map(|c| c.record_id.as_str()).collect();

    // Map each one-hop target to the best score among the seeds that reference it.
    // Only seeds count as referrers, so the result does not depend on the (unordered)
    // iteration order of the targets.
    let mut referrer_scores: std::collections::HashMap<&str, i32> =
        std::collections::HashMap::new();
    for seed in seeds {
        let Some(record) = record_index.get(seed.record_id.as_str()).copied() else {
            continue;
        };
        for related_id in &record.related_records {
            if selected_ids.contains(related_id.as_str())
                || excluded_record_ids.contains(related_id)
            {
                continue;
            }
            referrer_scores
                .entry(related_id.as_str())
                .and_modify(|best| *best = (*best).max(seed.score))
                .or_insert(seed.score);
        }
    }

    // NOTE(review): a target that already has a candidate score ranks below every seed,
    // so after the final `truncate(limit)` only the no-direct-score branch can add
    // anything. TODO confirm this is intended.
    let mut expanded: Vec<LifecycleCandidate> = seeds.to_vec();
    for (&target_id, &referrer_score) in &referrer_scores {
        if let Some(&candidate) = candidate_index.get(target_id) {
            let penalized_score = ((candidate.score as f64) * HOP_PENALTY) as i32;
            if penalized_score > 0 {
                let mut expanded_candidate = candidate.clone();
                expanded_candidate.score = penalized_score;
                expanded_candidate.reasons.push(format!(
                    "relation-expanded (1-hop, {:.0}% penalty)",
                    (1.0 - HOP_PENALTY) * 100.0
                ));
                expanded.push(expanded_candidate);
            }
        } else if let Some(&record) = record_index.get(target_id) {
            // The scorer produced no candidate for this record; derive its score from
            // the strongest seed pointing at it.
            let penalized_score = ((referrer_score as f64) * HOP_PENALTY) as i32;
            if penalized_score > 0 {
                expanded.push(LifecycleCandidate {
                    record_id: target_id.to_string(),
                    title: record.title.clone(),
                    summary: record.summary.clone(),
                    memory_type: record.memory_type.clone(),
                    scope: record.scope,
                    state: record.state,
                    score: penalized_score,
                    reasons: vec![format!(
                        "relation-expanded (1-hop, {:.0}% penalty, no direct score)",
                        (1.0 - HOP_PENALTY) * 100.0
                    )],
                    project_id: record.project_id.clone(),
                    confidence: crate::domain::ConfidenceTier::Medium,
                    contradicts: Vec::new(),
                });
            }
        }
    }

    expanded.sort_by(rank);
    expanded.truncate(limit);
    expanded
}
/// Collects the `record_id` frontmatter value of every scored note.
///
/// Notes whose frontmatter has no `record_id`, or whose `record_id` is not a string,
/// contribute nothing.
pub fn excluded_record_ids_from_scored(scored: &[ScoredNote]) -> HashSet<String> {
    let mut ids = HashSet::new();
    for entry in scored {
        let maybe_id = entry
            .note
            .frontmatter
            .get("record_id")
            .and_then(|value| value.as_str());
        if let Some(id) = maybe_id {
            ids.insert(id.to_string());
        }
    }
    ids
}
pub fn superseded_record_ids(records: &[(String, MemoryRecord)]) -> HashSet<String> {
use crate::domain::MemoryLifecycleState;
let mut superseded: HashSet<String> = HashSet::new();
for (_record_id, record) in records {
if !matches!(
record.state,
MemoryLifecycleState::Accepted | MemoryLifecycleState::Canonical
) {
continue;
}
if record.memory_type == "knowledge" {
for source_id in &record.related_records {
superseded.insert(source_id.clone());
}
}
if let Some(ref replaces) = record.supersedes {
superseded.insert(replaces.clone());
}
}
superseded
}
/// Fuses structured lifecycle ranking with a BM25 index search using Reciprocal Rank
/// Fusion, returning at most `limit` candidates.
///
/// Both rankers are over-fetched to `2 * limit` (saturating). When `bm25_index_path` is
/// `None`, the path does not exist, the index cannot be opened, the search fails, or it
/// returns no hits, the structured ranking alone is returned (truncated to `limit`).
/// Otherwise each list contributes `1 / (RRF_K + rank + 1)` per record. Structured
/// candidates are re-scored to `rrf * 1000` (truncated to `i32`); BM25-only hits that
/// exist in `records`, are not excluded, and score above 0 are added once each.
#[cfg(feature = "bm25")]
pub fn select_lifecycle_candidates_with_bm25(
    project: Option<&MatchedProject>,
    records: &[(String, MemoryRecord)],
    input: &RouteInput,
    limit: usize,
    excluded_record_ids: &HashSet<String>,
    reference_map: Option<&crate::reference_tracker::ReferenceMap>,
    bm25_index_path: Option<&std::path::Path>,
) -> Vec<LifecycleCandidate> {
    // Reciprocal Rank Fusion weight for a zero-based rank.
    fn rrf_weight(rank: usize) -> f64 {
        1.0 / (RRF_K + (rank as f64) + 1.0)
    }

    // Over-fetch so fusion has headroom; saturate so a huge `limit` cannot overflow.
    let fetch = limit.saturating_mul(2);
    let structured_candidates = select_lifecycle_candidates(
        project,
        records,
        input,
        fetch,
        excluded_record_ids,
        reference_map,
    );

    // Best-effort: any missing/unopenable index or failed search yields no BM25 hits.
    let bm25_results = bm25_index_path
        .filter(|path| path.exists())
        .and_then(|path| crate::engine::bm25::Bm25Index::open_or_create(path).ok())
        .and_then(|index| index.search(&input.task, fetch).ok())
        .unwrap_or_default();
    if bm25_results.is_empty() {
        let mut result = structured_candidates;
        result.truncate(limit);
        return result;
    }

    let mut rrf_scores: std::collections::HashMap<String, f64> = std::collections::HashMap::new();
    for (rank, candidate) in structured_candidates.iter().enumerate() {
        *rrf_scores.entry(candidate.record_id.clone()).or_default() += rrf_weight(rank);
    }
    for (rank, (record_id, _score)) in bm25_results.iter().enumerate() {
        if excluded_record_ids.contains(record_id) {
            continue;
        }
        *rrf_scores.entry(record_id.clone()).or_default() += rrf_weight(rank);
    }

    let mut fused: Vec<LifecycleCandidate> = structured_candidates
        .into_iter()
        .map(|mut c| {
            let rrf = rrf_scores.get(&c.record_id).copied().unwrap_or(0.0);
            c.score = (rrf * 1000.0) as i32;
            c.reasons
                .push(format!("RRF fused score (bm25+structured): {:.4}", rrf));
            c
        })
        .collect();

    // Index records once (first occurrence wins, like a linear `find`).
    let mut record_index: std::collections::HashMap<&str, &MemoryRecord> =
        std::collections::HashMap::with_capacity(records.len());
    for (record_id, record) in records {
        record_index.entry(record_id.as_str()).or_insert(record);
    }

    // Ids already present in `fused`; also guards against a BM25 hit listed twice.
    let mut emitted: HashSet<String> = fused.iter().map(|c| c.record_id.clone()).collect();
    for (record_id, _bm25_score) in &bm25_results {
        if emitted.contains(record_id) || excluded_record_ids.contains(record_id) {
            continue;
        }
        let Some(record) = record_index.get(record_id.as_str()).copied() else {
            continue;
        };
        let rrf = rrf_scores.get(record_id).copied().unwrap_or(0.0);
        let score = (rrf * 1000.0) as i32;
        if score <= 0 {
            continue;
        }
        emitted.insert(record_id.clone());
        fused.push(LifecycleCandidate {
            record_id: record_id.clone(),
            title: record.title.clone(),
            summary: record.summary.clone(),
            memory_type: record.memory_type.clone(),
            scope: record.scope,
            state: record.state,
            score,
            reasons: vec![format!("BM25-only hit, RRF score: {:.4}", rrf)],
            project_id: record.project_id.clone(),
            confidence: crate::domain::ConfidenceTier::Medium,
            contradicts: Vec::new(),
        });
    }

    fused.sort_by(|left, right| {
        right
            .score
            .cmp(&left.score)
            .then_with(|| left.record_id.cmp(&right.record_id))
    });
    fused.truncate(limit);
    fused
}
/// Fuses structured lifecycle ranking with optional BM25 hits and the supplied
/// embedding hits using Reciprocal Rank Fusion, returning at most `limit` candidates.
///
/// Structured ranking and BM25 are over-fetched to `2 * limit` (saturating). BM25 is
/// best-effort: a missing path, unopenable index, or failed search contributes no hits.
/// If neither BM25 nor embeddings produced hits, the structured ranking is returned
/// (truncated to `limit`). Otherwise each list contributes `1 / (RRF_K + rank + 1)` per
/// record. Only list position is used; the `f32` scores are ignored, so
/// `embedding_results` presumably needs to be ordered best-first (TODO confirm with
/// callers). Structured candidates are re-scored to `rrf * 1000` (truncated to `i32`).
/// Non-structured hits that exist in `records`, are not excluded, and score above 0 are
/// added once each.
#[cfg(feature = "embedding")]
pub fn select_lifecycle_candidates_fused(
    project: Option<&MatchedProject>,
    records: &[(String, MemoryRecord)],
    input: &RouteInput,
    limit: usize,
    excluded_record_ids: &HashSet<String>,
    reference_map: Option<&crate::reference_tracker::ReferenceMap>,
    #[cfg(feature = "bm25")] bm25_index_path: Option<&std::path::Path>,
    embedding_results: &[(String, f32)],
) -> Vec<LifecycleCandidate> {
    // Reciprocal Rank Fusion weight for a zero-based rank.
    fn rrf_weight(rank: usize) -> f64 {
        1.0 / (RRF_K + (rank as f64) + 1.0)
    }

    // Over-fetch so fusion has headroom; saturate so a huge `limit` cannot overflow.
    let fetch = limit.saturating_mul(2);
    let structured_candidates = select_lifecycle_candidates(
        project,
        records,
        input,
        fetch,
        excluded_record_ids,
        reference_map,
    );
    #[cfg(feature = "bm25")]
    let bm25_results: Vec<(String, f32)> = bm25_index_path
        .filter(|p| p.exists())
        .and_then(|p| crate::engine::bm25::Bm25Index::open_or_create(p).ok())
        .and_then(|idx| idx.search(&input.task, fetch).ok())
        .unwrap_or_default();
    #[cfg(not(feature = "bm25"))]
    let bm25_results: Vec<(String, f32)> = Vec::new();

    let has_bm25 = !bm25_results.is_empty();
    let has_embedding = !embedding_results.is_empty();
    if !has_bm25 && !has_embedding {
        let mut result = structured_candidates;
        result.truncate(limit);
        return result;
    }

    let mut rrf_scores: std::collections::HashMap<String, f64> = std::collections::HashMap::new();
    for (rank, candidate) in structured_candidates.iter().enumerate() {
        *rrf_scores.entry(candidate.record_id.clone()).or_default() += rrf_weight(rank);
    }
    for (rank, (record_id, _)) in bm25_results.iter().enumerate() {
        if excluded_record_ids.contains(record_id) {
            continue;
        }
        *rrf_scores.entry(record_id.clone()).or_default() += rrf_weight(rank);
    }
    for (rank, (record_id, _)) in embedding_results.iter().enumerate() {
        if excluded_record_ids.contains(record_id) {
            continue;
        }
        *rrf_scores.entry(record_id.clone()).or_default() += rrf_weight(rank);
    }

    // The signal label is identical for every fused candidate, so build it once.
    let signals = {
        let mut parts = vec!["structured"];
        if has_bm25 {
            parts.push("bm25");
        }
        if has_embedding {
            parts.push("embedding");
        }
        parts.join("+")
    };

    let mut fused: Vec<LifecycleCandidate> = structured_candidates
        .into_iter()
        .map(|mut c| {
            let rrf = rrf_scores.get(&c.record_id).copied().unwrap_or(0.0);
            c.score = (rrf * 1000.0) as i32;
            c.reasons.push(format!("RRF fused ({}): {:.4}", signals, rrf));
            c
        })
        .collect();

    let structured_ids: HashSet<String> = fused.iter().map(|c| c.record_id.clone()).collect();
    // A set, so a record hit by both BM25 and embeddings is added only once.
    let extra_ids: HashSet<String> = bm25_results
        .iter()
        .chain(embedding_results.iter())
        .map(|(id, _)| id.clone())
        .filter(|id| !structured_ids.contains(id) && !excluded_record_ids.contains(id))
        .collect();

    // Index records once (first occurrence wins, like a linear `find`).
    let mut record_index: std::collections::HashMap<&str, &MemoryRecord> =
        std::collections::HashMap::with_capacity(records.len());
    for (record_id, record) in records {
        record_index.entry(record_id.as_str()).or_insert(record);
    }

    for record_id in &extra_ids {
        let Some(record) = record_index.get(record_id.as_str()).copied() else {
            continue;
        };
        let rrf = rrf_scores.get(record_id).copied().unwrap_or(0.0);
        let score = (rrf * 1000.0) as i32;
        if score > 0 {
            fused.push(LifecycleCandidate {
                record_id: record_id.clone(),
                title: record.title.clone(),
                summary: record.summary.clone(),
                memory_type: record.memory_type.clone(),
                scope: record.scope,
                state: record.state,
                score,
                reasons: vec![format!("RRF extra hit: {:.4}", rrf)],
                project_id: record.project_id.clone(),
                confidence: crate::domain::ConfidenceTier::Medium,
                contradicts: Vec::new(),
            });
        }
    }

    fused.sort_by(|left, right| {
        right
            .score
            .cmp(&left.score)
            .then_with(|| left.record_id.cmp(&right.record_id))
    });
    fused.truncate(limit);
    fused
}
// Unit tests for this module, compiled only in test builds; the module body lives in a
// separate out-of-line `tests` file.
#[cfg(test)]
mod tests;