use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::Mutex;
use super::scorer::{NodeScorer, ScoringContext};
use crate::document::{DocumentTree, NodeId};
use crate::retrieval::pilot::{Pilot, PilotDecision, SearchState};
/// Cache key: the full query text plus the parent node whose children are being ranked.
///
/// The query is stored verbatim rather than as a 64-bit hash so that two distinct queries
/// can never collide and silently share a cached decision.
type CacheKey = (String, NodeId);

/// Memoizes [`PilotDecision`]s per `(query, parent)` pair.
///
/// Cloning is cheap and yields a handle to the *same* underlying map (it lives behind an
/// [`Arc`]), so clones share entries. Access is serialized through a `tokio` [`Mutex`].
/// There is no eviction; call [`PilotDecisionCache::clear`] to drop all entries.
#[derive(Debug, Clone, Default)]
pub struct PilotDecisionCache {
    inner: Arc<Mutex<HashMap<CacheKey, PilotDecision>>>,
}

impl PilotDecisionCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds the map key for `query` and `parent`.
    fn cache_key(query: &str, parent: NodeId) -> CacheKey {
        (query.to_owned(), parent)
    }

    /// Returns a clone of the decision cached for `(query, parent)`, if any.
    pub async fn get(&self, query: &str, parent: NodeId) -> Option<PilotDecision> {
        let key = Self::cache_key(query, parent);
        let cache = self.inner.lock().await;
        cache.get(&key).cloned()
    }

    /// Caches a clone of `decision` for `(query, parent)`.
    ///
    /// If an entry already exists it is kept and `decision` is ignored (first write wins).
    pub async fn put(&self, query: &str, parent: NodeId, decision: &PilotDecision) {
        let key = Self::cache_key(query, parent);
        let mut cache = self.inner.lock().await;
        cache.entry(key).or_insert_with(|| decision.clone());
    }

    /// Removes every cached decision.
    pub async fn clear(&self) {
        self.inner.lock().await.clear();
    }
}
pub async fn score_candidates(
tree: &DocumentTree,
candidates: &[NodeId],
query: &str,
pilot: Option<&dyn Pilot>,
path: &[NodeId],
visited: &HashSet<NodeId>,
pilot_weight: f32,
cache: Option<&PilotDecisionCache>,
step_reasons: Option<&[Option<String>]>,
) -> Vec<(NodeId, f32)> {
let scored = score_candidates_detailed(
tree,
candidates,
query,
pilot,
path,
visited,
pilot_weight,
cache,
step_reasons,
)
.await;
scored.into_iter().map(|s| (s.node_id, s.score)).collect()
}
/// One ranked candidate as produced by [`score_candidates_detailed`].
#[derive(Debug, Clone)]
pub struct ScoredCandidate {
    /// The candidate node.
    pub node_id: NodeId,
    /// Final ranking score; results are ordered with higher scores first.
    pub score: f32,
    /// The Pilot's explanation for this candidate when the Pilot ranked it, else `None`.
    pub reason: Option<String>,
}
pub async fn score_candidates_detailed(
tree: &DocumentTree,
candidates: &[NodeId],
query: &str,
pilot: Option<&dyn Pilot>,
path: &[NodeId],
visited: &HashSet<NodeId>,
pilot_weight: f32,
cache: Option<&PilotDecisionCache>,
step_reasons: Option<&[Option<String>]>,
) -> Vec<ScoredCandidate> {
if candidates.is_empty() {
return Vec::new();
}
let Some(p) = pilot else {
return score_with_scorer_detailed(tree, candidates, query);
};
if !p.is_active() {
return score_with_scorer_detailed(tree, candidates, query);
}
let parent = path.last().copied().unwrap_or(tree.root());
let prefilter_cfg = &p.config().prefilter;
let pilot_candidates: Vec<NodeId> = if prefilter_cfg.should_prefilter(candidates.len()) {
let scorer = NodeScorer::new(ScoringContext::new(query));
let mut sorted = scorer.score_and_sort(tree, candidates);
let pilot_max = prefilter_cfg.max_to_pilot.min(candidates.len());
sorted.truncate(pilot_max);
let ids: Vec<NodeId> = sorted.into_iter().map(|(id, _)| id).collect();
tracing::debug!(
"Pre-filtered: {} candidates -> {} to Pilot (threshold={})",
candidates.len(),
ids.len(),
prefilter_cfg.threshold,
);
ids
} else {
candidates.to_vec()
};
let prune_cfg = &p.config().prune;
let pilot_candidates = if prune_cfg.should_prune(pilot_candidates.len()) {
let mut prune_state = SearchState::new(tree, query, path, &pilot_candidates, visited);
prune_state.step_reasons = step_reasons;
if let Some(relevant_ids) = p.binary_prune(&prune_state).await {
let relevant_set: HashSet<NodeId> = relevant_ids.iter().copied().collect();
let mut pruned: Vec<NodeId> = pilot_candidates
.iter()
.filter(|id| relevant_set.contains(id))
.copied()
.collect();
if pruned.len() < prune_cfg.min_keep {
for id in &pilot_candidates {
if pruned.len() >= prune_cfg.min_keep {
break;
}
if !relevant_set.contains(id) {
pruned.push(*id);
}
}
}
tracing::debug!(
"Binary prune: {} candidates -> {} relevant (min_keep={})",
pilot_candidates.len(),
pruned.len(),
prune_cfg.min_keep,
);
pruned
} else {
pilot_candidates
}
} else {
pilot_candidates
};
let decision = if let Some(c) = cache {
if let Some(cached) = c.get(query, parent).await {
tracing::trace!("Pilot cache hit for parent={:?}", parent);
cached
} else {
let mut state = SearchState::new(tree, query, path, &pilot_candidates, visited);
state.step_reasons = step_reasons;
let d = p.decide(&state).await;
c.put(query, parent, &d).await;
d
}
} else {
let mut state = SearchState::new(tree, query, path, &pilot_candidates, visited);
state.step_reasons = step_reasons;
p.decide(&state).await
};
let mut pilot_data: HashMap<NodeId, (f32, Option<String>)> = HashMap::new();
for ranked in &decision.ranked_candidates {
pilot_data.insert(ranked.node_id, (ranked.score, ranked.reason.clone()));
}
let scorer_weight = 1.0 - pilot_weight;
let confidence = decision.confidence;
let effective_pilot = pilot_weight * confidence;
let scorer = NodeScorer::new(ScoringContext::new(query));
let mut scored: Vec<ScoredCandidate> = candidates
.iter()
.map(|&node_id| {
let algo_score = scorer.score(tree, node_id);
let (p_score, reason) = pilot_data
.get(&node_id)
.map(|(s, r)| (*s, r.clone()))
.unwrap_or((0.0, None));
let final_score = if effective_pilot > 0.0 && pilot_data.contains_key(&node_id) {
(effective_pilot * p_score + scorer_weight * algo_score)
/ (effective_pilot + scorer_weight)
} else {
algo_score
};
ScoredCandidate {
node_id,
score: final_score,
reason,
}
})
.collect();
scored.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
scored
}
/// Ranks `candidates` using only the algorithmic [`NodeScorer`] (no Pilot involvement).
///
/// Returns `(node_id, score)` pairs in the order produced by `NodeScorer::score_and_sort`.
fn score_with_scorer(
    tree: &DocumentTree,
    candidates: &[NodeId],
    query: &str,
) -> Vec<(NodeId, f32)> {
    NodeScorer::new(ScoringContext::new(query)).score_and_sort(tree, candidates)
}
fn score_with_scorer_detailed(
tree: &DocumentTree,
candidates: &[NodeId],
query: &str,
) -> Vec<ScoredCandidate> {
let scorer = NodeScorer::new(ScoringContext::new(query));
scorer
.score_and_sort(tree, candidates)
.into_iter()
.map(|(node_id, score)| ScoredCandidate {
node_id,
score,
reason: None,
})
.collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::document::TreeNode;
    use indextree::Arena;

    /// Allocates a fresh default node in `arena` and wraps its id as a [`NodeId`].
    fn make_node_id(arena: &mut Arena<TreeNode>) -> NodeId {
        NodeId(arena.new_node(TreeNode::default()))
    }

    #[test]
    fn test_cache_key_deterministic() {
        let mut arena = Arena::new();
        let nid = make_node_id(&mut arena);
        let key1 = PilotDecisionCache::cache_key("hello", nid);
        let key2 = PilotDecisionCache::cache_key("hello", nid);
        assert_eq!(key1, key2);
        let key3 = PilotDecisionCache::cache_key("world", nid);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_cache_key_distinguishes_parent() {
        // The same query under different parents must not share a cached decision.
        let mut arena = Arena::new();
        let first = make_node_id(&mut arena);
        let second = make_node_id(&mut arena);
        assert_ne!(
            PilotDecisionCache::cache_key("hello", first),
            PilotDecisionCache::cache_key("hello", second)
        );
    }
}