use crate::context::DEFAULT_TOP_K_RULES;
use crate::context::ann;
use crate::context::embedding::cosine_similarity;
use crate::context::index_db::{self, IndexedRuleChunk, QueryFilter};
use crate::domain::glob_match::{GlobErrorPolicy, glob_match};
use crate::errors::CoreError;
use crate::review_trajectory::{TrajectoryBuilder, TrajectoryStep};
use sqlx::SqlitePool;
use std::collections::{HashMap, HashSet};
use std::time::Duration;
use super::query_embed::embed_query_aligned_to_index;
use super::scoring::{directive_intent_aligned, effective_confidence, infer_rule_kind};
use super::{
ADAPTIVE_INJECT_THRESHOLD, EXPLICIT_RECALL_MIN_RELEVANCE, EXPLICIT_RECALL_RELATIVE_FLOOR,
INTENT_ALIGNMENT_EXEMPT_SCORE, MIN_RELEVANCE_SCORE, RELATIVE_RELEVANCE_FLOOR, RRF_K,
ScoredRuleChunk, concreteness_score, lexical_terms,
};
/// Hard ceiling on how many rules one retrieval call may return; the caller's
/// requested `top_k` is clamped to this value before any ranking happens.
const MAX_RULE_RETRIEVAL_TOP_K: usize = 50;
/// Ceiling applied to the number of nearest-neighbour candidates requested from
/// the ANN index (the request is derived as three times the effective `k`).
const MAX_ANN_CANDIDATES: usize = 150;
pub async fn retrieve_rules(
index_pool: &SqlitePool,
query: &str,
top_k: Option<usize>,
) -> Result<Vec<ScoredRuleChunk>, CoreError> {
retrieve_rules_with_confidence(
index_pool,
query,
RetrievalOptions {
top_k,
..Default::default()
},
)
.await
}
/// Reports whether a rule's file patterns admit `target_file`.
///
/// `file_patterns_json` is forwarded untouched to `glob_match` together with the
/// `GlobErrorPolicy::OverRecall` policy.
// NOTE(review): `OverRecall` presumably means unparsable patterns count as a
// match rather than a rejection — confirm against `glob_match`.
pub(super) fn pattern_allows(file_patterns_json: Option<&str>, target_file: &str) -> bool {
    let on_error = GlobErrorPolicy::OverRecall;
    glob_match(file_patterns_json, target_file, on_error)
}
/// Tuning knobs for `retrieve_rules_with_confidence`.
///
/// Every field has a default (see the `Default` impl), so callers normally use
/// struct-update syntax and set only what they need.
pub struct RetrievalOptions<'a> {
/// Maximum number of rules to return. `None` falls back to
/// `DEFAULT_TOP_K_RULES`; `Some(0)` short-circuits to an empty result; larger
/// values are clamped to `MAX_RULE_RETRIEVAL_TOP_K`.
pub top_k: Option<usize>,
/// Per-skill confidence, looked up by a chunk's `skill_id`. Skills missing from
/// the map get the retrieval's default confidence (0.7); chunks whose
/// confidence is below the minimum (0.2) are excluded.
pub confidence_map: Option<&'a HashMap<String, f64>>,
/// When set, only chunks whose `skill_id` is in this set are considered.
pub eligible_skill_ids: Option<&'a HashSet<String>>,
/// Per-skill age in days, looked up by `skill_id` and passed to
/// `effective_confidence`. Missing entries are treated as `0.0`.
pub age_days_map: Option<&'a HashMap<String, f32>>,
/// When set, chunks are kept only if `pattern_allows` accepts this path
/// against the chunk's file patterns.
pub target_file: Option<&'a str>,
/// Filter handed to both the chunk query and the full-text search. `None`
/// uses `QueryFilter::default()`.
pub filter: Option<&'a QueryFilter>,
/// Whether to try the ANN index first. When `false`, or when the ANN path
/// yields nothing, ranking falls back to brute-force cosine similarity.
pub ann_enabled: bool,
/// Forwarded to `embed_query_aligned_to_index`.
pub embedding_timeout: Option<Duration>,
/// Forwarded to `embed_query_aligned_to_index`.
pub cold_start_retry: bool,
/// Enables the adaptive pruning stage, which only engages when at least five
/// candidates were scored.
pub adaptive_prune: bool,
/// Optional sink that receives `RetrievalFilter`, `AnnRecall` and
/// `HybridFusion` trajectory steps describing this retrieval.
pub trajectory: Option<&'a mut TrajectoryBuilder>,
}
impl Default for RetrievalOptions<'_> {
    /// Baseline configuration: every optional input is absent, the ANN path is
    /// switched on, and both cold-start retry and adaptive pruning are off.
    fn default() -> Self {
        Self {
            // The only switch that defaults to "on".
            ann_enabled: true,
            // Opt-in behaviours.
            cold_start_retry: false,
            adaptive_prune: false,
            // Optional inputs: nothing supplied unless the caller says so.
            top_k: None,
            target_file: None,
            filter: None,
            embedding_timeout: None,
            confidence_map: None,
            eligible_skill_ids: None,
            age_days_map: None,
            trajectory: None,
        }
    }
}
/// Retrieves the rule chunks most relevant to `query` by fusing an embedding
/// ranking with a full-text ranking.
///
/// Stages, in order:
///
/// 1. Resolve `top_k` (default `DEFAULT_TOP_K_RULES`, capped at
///    `MAX_RULE_RETRIEVAL_TOP_K`); a request for zero returns immediately.
/// 2. Embed the query with `embed_query_aligned_to_index`.
/// 3. Load the chunks admitted by `options.filter` and run the full-text search.
/// 4. Optionally narrow the chunks to those whose file patterns allow
///    `options.target_file`.
/// 5. Rank by embedding: via the ANN index when enabled and usable, otherwise
///    by brute-force cosine similarity. Ineligible skills and skills below the
///    minimum confidence are skipped in both paths.
/// 6. Fuse both rankings with reciprocal-rank contributions
///    (`weight / (RRF_K + rank + 1)`), then scale each fused score by a
///    confidence weight and a concreteness weight.
/// 7. Sort, prune below the relevance floors (plus adaptive pruning when
///    enabled), truncate to `k`, and record an activity-stream event.
///
/// When `options.trajectory` is set, filter/ANN/fusion statistics are pushed
/// onto it.
///
/// The result order is fully deterministic: score descending, then `skill_id`,
/// then `content`.
///
/// # Errors
///
/// Returns the error from `index_db::query_rule_chunks` if loading chunks
/// fails. Failures of the unfiltered row count and of the full-text search are
/// deliberately tolerated (treated as `0` / no hits).
pub async fn retrieve_rules_with_confidence(
    index_pool: &SqlitePool,
    query: &str,
    options: RetrievalOptions<'_>,
) -> Result<Vec<ScoredRuleChunk>, CoreError> {
    let RetrievalOptions {
        top_k,
        confidence_map,
        eligible_skill_ids,
        age_days_map,
        target_file,
        filter,
        ann_enabled,
        embedding_timeout,
        cold_start_retry,
        adaptive_prune,
        trajectory,
    } = options;

    let default_filter = QueryFilter::default();
    let filter = filter.unwrap_or(&default_filter);

    let requested_k = top_k.unwrap_or(DEFAULT_TOP_K_RULES);
    if requested_k == 0 {
        return Ok(Vec::new());
    }
    let k = requested_k.min(MAX_RULE_RETRIEVAL_TOP_K);
    let retrieval_start = std::time::Instant::now();

    let embedded_query =
        embed_query_aligned_to_index(index_pool, query, embedding_timeout, cold_start_retry).await;
    let query_emb = embedded_query.vector;
    let is_semantic = embedded_query.semantic;

    // The unfiltered row count is only reported in the trajectory's
    // `RetrievalFilter` step, so skip the query entirely when nobody will read it.
    let wants_filter_step = trajectory.is_some() && !filter.is_empty();
    let unfiltered_count: u32 = if wants_filter_step {
        sqlx::query_scalar!(r#"SELECT COUNT(*) as "n!: i64" FROM rule_chunks"#)
            .fetch_one(index_pool)
            .await
            .unwrap_or(0)
            .try_into()
            .unwrap_or(u32::MAX)
    } else {
        0
    };

    let chunks = index_db::query_rule_chunks(index_pool, filter).await?;
    let after_count: u32 = u32::try_from(chunks.len()).unwrap_or(u32::MAX);

    // Over-fetch FTS hits (4x, at most 200, never fewer than k) so fusion has
    // material to work with. A failing FTS search degrades to "no lexical hits".
    let fts_limit = k.saturating_mul(4).min(200).max(k);
    let fts_hits = index_db::fts_search(index_pool, query, filter, fts_limit)
        .await
        .unwrap_or_default();

    let default_confidence = 0.7;
    let min_confidence = 0.2;

    let matched: Vec<&IndexedRuleChunk> = if let Some(tf) = target_file {
        chunks
            .iter()
            .filter(|c| pattern_allows(c.file_patterns.as_deref(), tf))
            .collect()
    } else {
        chunks.iter().collect()
    };
    // An empty `matched` already yields an empty slice, so no special case is
    // needed for "target file given but nothing matched".
    let active: &[&IndexedRuleChunk] = &matched;
    let id_to_chunk: HashMap<&str, &IndexedRuleChunk> =
        active.iter().map(|c| (c.id.as_str(), *c)).collect();

    let ann_candidates = k.saturating_mul(3).min(MAX_ANN_CANDIDATES).max(k);
    let ann_result = if ann_enabled {
        try_ann_rank(
            &query_emb,
            ann_candidates,
            &id_to_chunk,
            confidence_map,
            eligible_skill_ids,
            default_confidence,
            min_confidence,
        )
        .await
    } else {
        None
    };

    let (mut emb_ranked, ann_used, ann_index_size, ann_returned): (
        Vec<(&IndexedRuleChunk, f64)>,
        bool,
        u32,
        u32,
    ) = if let Some((ranked, idx_size, returned)) = ann_result {
        (ranked, true, idx_size, returned)
    } else {
        // Brute-force fallback: cosine similarity against every active chunk
        // that passes the same eligibility and confidence gates as the ANN path.
        let fallback: Vec<(&IndexedRuleChunk, f64)> = active
            .iter()
            .filter_map(|c: &&IndexedRuleChunk| {
                if !eligible_skill_ids.is_none_or(|ids| ids.contains(&c.skill_id)) {
                    return None;
                }
                let confidence = confidence_map
                    .and_then(|m| m.get(&c.skill_id).copied())
                    .unwrap_or(default_confidence);
                if confidence < min_confidence {
                    return None;
                }
                // Vectors of a different dimension cannot be compared; skip them.
                if query_emb.len() != c.embedding.len() {
                    return None;
                }
                let sim = cosine_similarity(&query_emb, &c.embedding);
                Some((*c, f64::from(sim)))
            })
            .collect();
        (fallback, false, 0, 0)
    };

    // Best similarity first; chunk id breaks ties so ranks are reproducible.
    emb_ranked.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.id.cmp(&b.0.id)));
    let emb_rank_map: HashMap<&str, usize> = emb_ranked
        .iter()
        .enumerate()
        .map(|(i, (c, _))| (c.id.as_str(), i))
        .collect();

    // FTS ranks keep their position in the raw hit list; hits for chunks that
    // are not active (filtered out by target file) are ignored.
    let mut fts_rank_map: HashMap<&str, usize> = HashMap::new();
    let mut fts_kept = 0u32;
    for (i, (id, _)) in fts_hits.iter().enumerate() {
        if id_to_chunk.contains_key(id.as_str()) {
            fts_rank_map.insert(id.as_str(), i);
            fts_kept += 1;
        }
    }

    let overlap: u32 = {
        let fts_ids: HashSet<&str> = fts_rank_map.keys().copied().collect();
        let emb_ids: HashSet<&str> = emb_rank_map.keys().copied().collect();
        u32::try_from(fts_ids.intersection(&emb_ids).count()).unwrap_or(u32::MAX)
    };

    // When the query vector is not semantic, lean mostly on the lexical ranking.
    let (w_emb, w_fts) = if is_semantic { (0.5, 0.5) } else { (0.2, 0.8) };

    let mut fused: HashMap<&str, (f64, &IndexedRuleChunk, f64 /*confidence*/)> = HashMap::new();
    for (chunk, _sim) in &emb_ranked {
        let rank = emb_rank_map.get(chunk.id.as_str()).copied().unwrap_or(0);
        let contrib = w_emb / (RRF_K + rank as f64 + 1.0);
        let confidence = confidence_map
            .and_then(|m| m.get(&chunk.skill_id).copied())
            .unwrap_or(default_confidence);
        fused
            .entry(chunk.id.as_str())
            .and_modify(|e| e.0 += contrib)
            .or_insert((contrib, *chunk, confidence));
    }
    for (id, rank) in &fts_rank_map {
        if let Some(chunk) = id_to_chunk.get(id) {
            if !eligible_skill_ids.is_none_or(|ids| ids.contains(&chunk.skill_id)) {
                continue;
            }
            let contrib = w_fts / (RRF_K + *rank as f64 + 1.0);
            let confidence = confidence_map
                .and_then(|m| m.get(&chunk.skill_id).copied())
                .unwrap_or(default_confidence);
            if confidence < min_confidence {
                continue;
            }
            fused
                .entry(id)
                .and_modify(|e| e.0 += contrib)
                .or_insert((contrib, *chunk, confidence));
        }
    }

    if let Some(t) = trajectory {
        if !filter.is_empty() {
            t.push(TrajectoryStep::RetrievalFilter {
                before: unfiltered_count,
                after: after_count,
            });
        }
        t.push(TrajectoryStep::AnnRecall {
            used: ann_used,
            index_size: ann_index_size,
            candidates: ann_returned,
        });
        t.push(TrajectoryStep::HybridFusion {
            fts_hits: fts_kept,
            emb_hits: u32::try_from(emb_ranked.len()).unwrap_or(u32::MAX),
            overlap,
        });
    }

    let mut scored: Vec<ScoredRuleChunk> = fused
        .into_values()
        .map(|(score, chunk, confidence)| {
            let kind = infer_rule_kind(&chunk.content);
            let age_days = age_days_map
                .and_then(|m| m.get(&chunk.skill_id).copied())
                .unwrap_or(0.0);
            let eff_conf = f64::from(effective_confidence(confidence as f32, &kind, age_days));
            // Confidence scales the score by 0.9..=1.0.
            let conf_weight = 0.1f64.mul_add(eff_conf.clamp(0.0, 1.0), 0.9);
            let conc = concreteness_score(&chunk.content);
            // Concreteness adds 5% per point, counting at most six points.
            let conc_weight = 0.05f64.mul_add(conc.min(6) as f64, 1.0);
            ScoredRuleChunk {
                skill_id: chunk.skill_id.clone(),
                content: chunk.content.clone(),
                score: score * conf_weight * conc_weight,
                confidence,
            }
        })
        .collect();

    // `fused` is a HashMap, so the incoming order is arbitrary. Several chunks
    // may share a skill id, so `content` is the final tie-break that makes the
    // order (and therefore the truncation below) reproducible across runs.
    scored.sort_by(|a, b| {
        b.score
            .total_cmp(&a.score)
            .then_with(|| a.skill_id.cmp(&b.skill_id))
            .then_with(|| a.content.cmp(&b.content))
    });

    let adaptive_eligible = adaptive_prune && scored.len() >= 5;
    if let Some(top_score) = scored.first().map(|s| s.score) {
        if adaptive_eligible && top_score < ADAPTIVE_INJECT_THRESHOLD {
            // Even the best hit is too weak: inject nothing.
            scored.clear();
        } else {
            prune_below_floors(&mut scored, top_score);
            if adaptive_eligible {
                // Keep only the leading run of hits within 60% of the top score,
                // provided that run is a strict, non-empty prefix.
                let strong_floor = top_score * 0.60;
                let strong_count = scored
                    .iter()
                    .take_while(|s| s.score >= strong_floor)
                    .count();
                if strong_count > 0 && strong_count < scored.len() {
                    scored.truncate(strong_count.min(k));
                }
            }
        }
    }
    scored.truncate(k);

    crate::activity_stream::record(
        crate::activity_stream::ActivityPayload::RetrievalEmbedding {
            hits: u32::try_from(scored.len()).unwrap_or(u32::MAX),
            took_ms: u64::try_from(retrieval_start.elapsed().as_millis()).unwrap_or(u64::MAX),
        },
    );
    Ok(scored)
}
/// Drops every hit that fails either relevance floor.
///
/// A hit is kept only when its score is strictly above `MIN_RELEVANCE_SCORE`
/// and at least `top_score * RELATIVE_RELEVANCE_FLOOR`.
fn prune_below_floors(scored: &mut Vec<ScoredRuleChunk>, top_score: f64) {
    let cutoff = top_score * RELATIVE_RELEVANCE_FLOOR;
    scored.retain(|candidate| {
        let clears_absolute = candidate.score > MIN_RELEVANCE_SCORE;
        let clears_relative = candidate.score >= cutoff;
        clears_absolute && clears_relative
    });
}
/// Applies the explicit-recall relevance thresholds to an already ranked list.
///
/// The first element is taken as the top hit, so `scored` is expected to be
/// ordered best-first. If that top score is below
/// `EXPLICIT_RECALL_MIN_RELEVANCE` the whole list is cleared; otherwise only
/// hits scoring at least `top * EXPLICIT_RECALL_RELATIVE_FLOOR` are kept.
/// An empty list is left untouched.
pub fn apply_explicit_recall_threshold(scored: &mut Vec<ScoredRuleChunk>) {
    let leader_score = match scored.first() {
        Some(leader) => leader.score,
        None => return,
    };

    if leader_score < EXPLICIT_RECALL_MIN_RELEVANCE {
        scored.clear();
    } else {
        let band_floor = leader_score * EXPLICIT_RECALL_RELATIVE_FLOOR;
        scored.retain(|hit| hit.score >= band_floor);
    }
}
/// Removes hits whose directive does not line up with the caller's `intent`.
///
/// * An empty `scored` list is returned unchanged without analysing `intent`.
/// * If `lexical_terms(intent)` yields no terms, every hit is dropped.
/// * Otherwise a hit survives when its score reaches
///   `INTENT_ALIGNMENT_EXEMPT_SCORE`, or when `directive_intent_aligned`
///   accepts its content against the intent's terms.
pub fn apply_intent_alignment_gate(scored: &mut Vec<ScoredRuleChunk>, intent: &str) {
    if !scored.is_empty() {
        let intent_terms = lexical_terms(intent);
        if intent_terms.is_empty() {
            scored.clear();
        } else {
            scored.retain(|hit| {
                // Strong hits bypass the alignment check altogether.
                if hit.score >= INTENT_ALIGNMENT_EXEMPT_SCORE {
                    return true;
                }
                directive_intent_aligned(&hit.content, &intent_terms)
            });
        }
    }
}
/// Attempts to rank chunks through the project's ANN index.
///
/// Returns `None` — signalling the caller to use its own fallback — when the
/// query vector is empty, no candidates are requested, the index cannot be
/// obtained, the index has no live entries, the search returns nothing, or no
/// hit survives filtering.
///
/// On success returns `(ranked, index_size, returned)` where `ranked` pairs
/// each surviving chunk with `max(0, 1 - distance)`, `index_size` is the
/// index's live size and `returned` is the raw number of hits before filtering.
///
/// Hits are discarded when their id is not in `id_to_chunk`, when their skill
/// is not in `eligible_skill_ids` (if provided), or when the skill's confidence
/// (from `confidence_map`, else `default_confidence`) is below `min_confidence`.
async fn try_ann_rank<'a>(
    query_emb: &[f32],
    candidates: usize,
    id_to_chunk: &HashMap<&'a str, &'a IndexedRuleChunk>,
    confidence_map: Option<&HashMap<String, f64>>,
    eligible_skill_ids: Option<&HashSet<String>>,
    default_confidence: f64,
    min_confidence: f64,
) -> Option<(Vec<(&'a IndexedRuleChunk, f64)>, u32, u32)> {
    if candidates == 0 || query_emb.is_empty() {
        return None;
    }

    let project_root = crate::db::current_project_root();
    let project_hash = crate::db::project_hash_from_root(&project_root);
    let Ok(ann_arc) = ann::get_ann_for_project(&project_hash, query_emb.len()).await else {
        return None;
    };
    let ann_guard = ann_arc.lock().await;

    let index_size = ann_guard.live_size();
    if index_size == 0 {
        return None;
    }

    let hits = ann_guard.search(query_emb, candidates);
    if hits.is_empty() {
        return None;
    }
    let returned = u32::try_from(hits.len()).unwrap_or(u32::MAX);

    let ranked: Vec<(&'a IndexedRuleChunk, f64)> = hits
        .into_iter()
        .filter_map(|(chunk_id, distance)| {
            // Unknown ids belong to chunks that are not active for this query.
            let chunk = *id_to_chunk.get(chunk_id.as_str())?;

            if let Some(ids) = eligible_skill_ids {
                if !ids.contains(&chunk.skill_id) {
                    return None;
                }
            }

            let confidence = match confidence_map.and_then(|m| m.get(&chunk.skill_id)) {
                Some(&known) => known,
                None => default_confidence,
            };
            if confidence < min_confidence {
                return None;
            }

            // Convert distance to a similarity, floored at zero.
            let similarity = (1.0 - f64::from(distance)).max(0.0);
            Some((chunk, similarity))
        })
        .collect();

    if ranked.is_empty() {
        None
    } else {
        Some((ranked, index_size, returned))
    }
}
// Unit tests for the post-retrieval gates (`apply_explicit_recall_threshold`
// and `apply_intent_alignment_gate`) and for the ordering relationships between
// the threshold constants they rely on.
#[cfg(test)]
mod tests {
use super::super::MIN_INTENT_DIRECTIVE_OVERLAP;
use super::*;
// Builds a minimal scored chunk: `id` doubles as skill id and rule name, the
// body is a stub, and confidence is fixed at 0.7.
fn chunk(id: &str, score: f64) -> ScoredRuleChunk {
ScoredRuleChunk {
skill_id: id.to_owned(),
content: format!("Rule ID: {id}\nRule Name: {id}\n\nbody"),
score,
confidence: 0.7,
}
}
// A top hit of 0.30 with a 0.12 runner-up: both are expected to survive.
#[test]
fn explicit_recall_threshold_strong_top_hit_survives() {
let mut scored = vec![chunk("strong", 0.30), chunk("supporting", 0.12)];
apply_explicit_recall_threshold(&mut scored);
assert_eq!(scored.len(), 2, "strong matches must not be pruned");
assert_eq!(scored[0].skill_id, "strong");
}
// When even the best score is tiny (0.004) the whole list must be cleared.
#[test]
fn explicit_recall_threshold_all_weak_returns_empty() {
let mut scored = vec![
chunk("noise-1", 0.004),
chunk("noise-2", 0.003),
chunk("noise-3", 0.002),
chunk("noise-4", 0.0015),
chunk("noise-5", 0.001),
];
apply_explicit_recall_threshold(&mut scored);
assert!(
scored.is_empty(),
"a query whose only matches are weak must return zero results"
);
}
// With a 0.40 leader, 0.10 is expected inside the relative band while 0.05
// and below are expected to fall outside it.
#[test]
fn explicit_recall_threshold_borderline_keeps_only_strong() {
let mut scored = vec![
chunk("leader", 0.40),
chunk("near", 0.10), chunk("tail-1", 0.05), chunk("tail-2", 0.02),
chunk("tail-3", 0.011),
];
apply_explicit_recall_threshold(&mut scored);
let ids: Vec<&str> = scored.iter().map(|s| s.skill_id.as_str()).collect();
assert_eq!(
ids,
vec!["leader", "near"],
"only the leader and rules within the relative band survive"
);
}
// The absolute floor is inclusive: a top hit scoring exactly the floor stays.
#[test]
fn explicit_recall_threshold_top_hit_at_absolute_floor_is_kept() {
let mut scored = vec![chunk("at-floor", EXPLICIT_RECALL_MIN_RELEVANCE)];
apply_explicit_recall_threshold(&mut scored);
assert_eq!(scored.len(), 1, "top hit at the floor must be kept");
}
// Empty input must pass through without panicking or changing.
#[test]
fn explicit_recall_threshold_empty_input_is_noop() {
let mut scored: Vec<ScoredRuleChunk> = Vec::new();
apply_explicit_recall_threshold(&mut scored);
assert!(scored.is_empty());
}
// Builds a chunk shaped like an indexed directive: header lines (id, name,
// type `convention`, empty tags) followed by the title repeated as the body.
fn directive_chunk(id: &str, title: &str, score: f64) -> ScoredRuleChunk {
ScoredRuleChunk {
skill_id: id.to_owned(),
content: format!(
"Rule ID: {id}\nRule Name: {title}\nType: convention\nTags: \n\n{title}."
),
score,
confidence: 0.7,
}
}
// Rules that merely share a topic word with the intent (or none at all) must
// both be dropped.
#[test]
fn intent_gate_drops_topically_adjacent_different_subject_rule() {
let mut scored = vec![
directive_chunk(
"panic-message-wording",
"Panic messages should describe the violated invariant",
0.12,
),
directive_chunk(
"test-timing",
"Avoid sleep-based waits in tests; poll for the condition",
0.10,
),
];
apply_intent_alignment_gate(
&mut scored,
"return false instead of panic on invalid input",
);
assert!(
scored.is_empty(),
"topically-adjacent, wrong-subject rules must be dropped, kept: {:?}",
scored.iter().map(|s| &s.skill_id).collect::<Vec<_>>()
);
}
// A rule whose title restates the intent must pass the gate.
#[test]
fn intent_gate_keeps_directly_on_subject_rule() {
let mut scored = vec![directive_chunk(
"return-false-not-panic",
"Return false rather than panic on invalid input",
0.12,
)];
apply_intent_alignment_gate(
&mut scored,
"return false instead of panic on invalid input",
);
assert_eq!(
scored
.iter()
.map(|s| s.skill_id.as_str())
.collect::<Vec<_>>(),
vec!["return-false-not-panic"],
"a directly-on-subject directive must survive the intent gate"
);
}
// Mixed input: only the on-subject rule survives; the two off-subject rules
// from the earlier tests are removed.
#[test]
fn intent_gate_keeps_on_subject_drops_adjacent_in_same_set() {
let mut scored = vec![
directive_chunk(
"return-false-not-panic",
"Return false rather than panic on invalid input",
0.12,
),
directive_chunk(
"panic-message-wording",
"Panic messages should describe the violated invariant",
0.11,
),
directive_chunk(
"test-timing",
"Avoid sleep-based waits in tests; poll for the condition",
0.10,
),
];
apply_intent_alignment_gate(
&mut scored,
"return false instead of panic on invalid input",
);
assert_eq!(
scored
.iter()
.map(|s| s.skill_id.as_str())
.collect::<Vec<_>>(),
vec!["return-false-not-panic"],
"only the intent-aligned rule should survive the mixed set"
);
}
// An intent made only of filler words is expected to clear every hit
// (presumably because `lexical_terms` yields nothing for it).
#[test]
fn intent_gate_all_weak_query_returns_zero() {
let mut scored = vec![
directive_chunk("a", "Return false rather than panic on invalid input", 0.12),
directive_chunk("b", "Use structured errors in request handlers", 0.10),
];
apply_intent_alignment_gate(&mut scored, "the and to of");
assert!(
scored.is_empty(),
"an all-weak query must return zero, kept: {:?}",
scored.iter().map(|s| &s.skill_id).collect::<Vec<_>>()
);
}
// A hit scoring 2.7 must bypass the alignment check even though its heading
// has nothing to do with the intent.
#[test]
fn intent_gate_exempts_strongly_scored_hits() {
let mut scored = vec![ScoredRuleChunk {
skill_id: "exact-title-strict".to_owned(),
content: "Rule ID: x\nRule Name: Completely unrelated heading\n\nbody".to_owned(),
score: 2.7,
confidence: 0.7,
}];
apply_intent_alignment_gate(
&mut scored,
"return false instead of panic on invalid input",
);
assert_eq!(
scored.len(),
1,
"a strongly-scored hit must be exempt from the alignment gate"
);
}
// A two-word intent fully contained in the rule title must be kept.
#[test]
fn intent_gate_ratio_path_keeps_short_sharp_query_match() {
let mut scored = vec![directive_chunk(
"panic-safety",
"Document panic safety for unsafe blocks",
0.12,
)];
apply_intent_alignment_gate(&mut scored, "panic safety");
assert_eq!(
scored.len(),
1,
"a half-coverage match on a short query must survive via the ratio path"
);
}
// Empty input must stay empty regardless of the intent.
#[test]
fn intent_gate_empty_input_is_noop() {
let mut scored: Vec<ScoredRuleChunk> = Vec::new();
apply_intent_alignment_gate(&mut scored, "anything");
assert!(scored.is_empty());
}
// Sharing only generic words ("panic", "input") with the intent must not be
// enough to keep a rule about a different concern.
#[test]
fn intent_gate_drops_two_generic_anchor_overlap_without_distinctive_term() {
let mut scored = vec![directive_chunk(
"runtime-error-logging",
"Log every panic and error with the request input id",
0.12,
)];
apply_intent_alignment_gate(&mut scored, "panic on invalid input handling");
assert!(
scored.is_empty(),
"an all-generic-anchor overlap must not establish a concern match, kept: {:?}",
scored.iter().map(|s| &s.skill_id).collect::<Vec<_>>()
);
}
// A single shared distinctive word ("token") in an otherwise unrelated rule
// must not be enough to keep it.
#[test]
fn intent_gate_drops_off_subject_rule_that_namedrops_one_distinctive_token() {
let mut scored = vec![directive_chunk(
"csv-token-splitting",
"Split each CSV row into fields on the comma token boundary carefully",
0.12,
)];
apply_intent_alignment_gate(
&mut scored,
"validate the auth token before issuing session",
);
assert!(
scored.is_empty(),
"a one-token name-drop in an off-subject rule must be dropped, kept: {:?}",
scored.iter().map(|s| &s.skill_id).collect::<Vec<_>>()
);
}
// A long explanatory body must not cause an on-subject rule to be dropped.
#[test]
fn intent_gate_keeps_on_subject_rule_with_verbose_body() {
let verbose_body = "When a handler receives malformed input it should return a typed \
error to the caller rather than calling panic!, because a panic unwinds the worker \
thread and takes down unrelated in-flight requests; prefer Result and propagate. \
See the request lifecycle docs and the error-taxonomy appendix for the full list.";
let mut scored = vec![ScoredRuleChunk {
skill_id: "validate-return-error".to_owned(),
content: format!(
"Rule ID: r\nRule Name: Validate input and return a typed error not panic\nType: correction\nTags: \n\n{verbose_body}"
),
score: 0.12,
confidence: 0.7,
}];
apply_intent_alignment_gate(
&mut scored,
"validate input and return error instead of panic",
);
assert_eq!(
scored
.iter()
.map(|s| s.skill_id.as_str())
.collect::<Vec<_>>(),
vec!["validate-return-error"],
"an on-subject rule with a long body must survive (title-scoped coverage)"
);
}
// Same intent, two rules: one overlapping only on generic anchors (must drop)
// and one that also carries the distinctive word "invalid" (must stay).
#[test]
fn intent_gate_strictly_subsumes_old_overlap_count_on_anchor_only_match() {
let intent = "panic on invalid input";
let mut anchor_only = vec![directive_chunk(
"anchor-only",
"Buffer every panic and input event into the queue",
0.12,
)];
apply_intent_alignment_gate(&mut anchor_only, intent);
assert!(
anchor_only.is_empty(),
"anchor-only overlap (old gate would keep) must now drop"
);
let mut on_subject = vec![directive_chunk(
"on-subject",
"Reject invalid input instead of letting it panic",
0.12,
)];
apply_intent_alignment_gate(&mut on_subject, intent);
assert_eq!(
on_subject
.iter()
.map(|s| s.skill_id.as_str())
.collect::<Vec<_>>(),
vec!["on-subject"],
"the distinctive-token sibling must be kept"
);
}
// Guards the relative ordering of constants. `black_box` keeps the compiler
// (and lints) from treating these as assertions on compile-time constants.
#[test]
fn intent_alignment_exempt_score_sits_above_strong_band_below_exact_title() {
let exempt_score = std::hint::black_box(INTENT_ALIGNMENT_EXEMPT_SCORE);
let explicit_floor = std::hint::black_box(EXPLICIT_RECALL_MIN_RELEVANCE);
let exact_title_floor = std::hint::black_box(2.0);
let min_overlap = std::hint::black_box(MIN_INTENT_DIRECTIVE_OVERLAP);
assert!(
exempt_score > explicit_floor,
"exemption ceiling must be above the explicit relevance floor"
);
assert!(
exempt_score < exact_title_floor,
"exemption ceiling must be below the exact-title-strict (2.0 + conf) band"
);
assert!(
min_overlap >= 2,
"a lone topical-anchor overlap must be insufficient"
);
}
// Guards how the explicit-recall thresholds relate to the thresholds used
// inside retrieval itself.
#[test]
fn explicit_recall_floors_are_conservative_relative_to_in_retrieval_gates() {
let explicit_relative_floor = std::hint::black_box(EXPLICIT_RECALL_RELATIVE_FLOOR);
let retrieval_relative_floor = std::hint::black_box(RELATIVE_RELEVANCE_FLOOR);
let explicit_min = std::hint::black_box(EXPLICIT_RECALL_MIN_RELEVANCE);
let adaptive_threshold = std::hint::black_box(ADAPTIVE_INJECT_THRESHOLD);
let min_relevance = std::hint::black_box(MIN_RELEVANCE_SCORE);
assert!(
explicit_relative_floor < retrieval_relative_floor,
"explicit relative floor must be looser than the in-retrieval one"
);
assert!(
explicit_min > adaptive_threshold,
"explicit absolute floor must sit above the hook zero-inject threshold"
);
assert!(
explicit_min > min_relevance,
"explicit absolute floor must be stricter than the bare RRF noise floor"
);
}
}