use async_trait::async_trait;
use std::collections::HashSet;
use tracing::{debug, trace};
use super::super::RetrievalContext;
use super::super::types::{NavigationDecision, NavigationStep, SearchPath};
use super::scorer::{NodeScorer, ScoringContext};
use super::{SearchConfig, SearchResult, SearchTree};
use crate::document::{DocumentTree, NodeId};
use crate::retrieval::pilot::{Pilot, SearchState};
/// Beam-search strategy for walking a document tree.
///
/// Stores the per-instance upper bound on how many candidate paths are kept
/// alive between search iterations.
#[derive(Debug, Clone)]
pub struct BeamSearch {
    /// Maximum number of paths kept in the beam. Both constructors in this
    /// file produce a value of at least 1.
    beam_width: usize,
}
impl BeamSearch {
/// Creates a beam search that keeps up to three paths per iteration.
pub fn new() -> Self {
    const DEFAULT_BEAM_WIDTH: usize = 3;
    Self {
        beam_width: DEFAULT_BEAM_WIDTH,
    }
}
/// Creates a beam search with a caller-chosen beam width.
///
/// A `width` of zero is raised to one, so the stored width is never zero.
pub fn with_width(width: usize) -> Self {
    let beam_width = if width == 0 { 1 } else { width };
    Self { beam_width }
}
/// Builds a [`NodeScorer`] whose scoring context is created from `query`.
fn create_scorer(&self, query: &str) -> NodeScorer {
    let scoring_context = ScoringContext::new(query);
    NodeScorer::new(scoring_context)
}
/// Scores `candidates` against `query` using only the algorithmic scorer.
///
/// Returns whatever the scorer's `score_and_sort` produces for the given
/// tree and candidate list.
fn score_candidates_with_query(
    &self,
    tree: &DocumentTree,
    candidates: &[NodeId],
    query: &str,
) -> Vec<(NodeId, f32)> {
    self.create_scorer(query).score_and_sort(tree, candidates)
}
/// Blends each candidate's algorithmic score with the pilot's ranking.
///
/// The algorithmic score has a fixed weight of 0.4 and the pilot score a
/// weight of `0.6 * pilot_decision.confidence`; the result is their weighted
/// average. A candidate that the pilot did not rank gets a pilot score of
/// 0.0. When the pilot weight is not positive, the algorithmic score is used
/// unchanged.
///
/// Returns `(node, score)` pairs ordered from highest to lowest score.
fn merge_with_pilot_decision(
    &self,
    tree: &DocumentTree,
    candidates: &[NodeId],
    pilot_decision: &crate::retrieval::pilot::PilotDecision,
    query: &str,
) -> Vec<(NodeId, f32)> {
    let algo_weight = 0.4;
    let pilot_weight = 0.6 * pilot_decision.confidence;
    let node_scorer = self.create_scorer(query);

    // Index the pilot's ranking by node id; a later duplicate entry
    // overwrites an earlier one.
    let ranked_by_pilot: std::collections::HashMap<NodeId, f32> = pilot_decision
        .ranked_candidates
        .iter()
        .map(|ranked| (ranked.node_id, ranked.score))
        .collect();

    let mut blended: Vec<(NodeId, f32)> = Vec::with_capacity(candidates.len());
    for &candidate in candidates {
        let algo_score = node_scorer.score(tree, candidate);
        let pilot_score = match ranked_by_pilot.get(&candidate) {
            Some(&score) => score,
            None => 0.0,
        };
        let combined = if pilot_weight > 0.0 {
            (algo_weight * algo_score + pilot_weight * pilot_score)
                / (algo_weight + pilot_weight)
        } else {
            algo_score
        };
        blended.push((candidate, combined));
    }

    // Highest score first; incomparable scores (NaN) are treated as equal.
    blended.sort_by(|lhs, rhs| {
        rhs.1
            .partial_cmp(&lhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    blended
}
/// Runs the beam search over `tree`, starting below `start_node`.
///
/// The children of `start_node` are scored and seed the beam. On every
/// iteration each path in the beam is examined:
///
/// * if the path's `leaf` node is a tree leaf, the path is recorded as a
///   result when its score reaches `config.min_score`;
/// * otherwise the node's children are scored and the path is extended with
///   up to `beam_width` of the best ones, each extension being appended to
///   `result.trace`.
///
/// The extended paths are then sorted by descending score and cut to the
/// beam width. The loop ends after `config.max_iterations`, when the beam is
/// empty, or once at least `config.top_k` paths have been recorded. Paths
/// still in the beam at that point are added if they meet `min_score` and
/// there is room.
///
/// If a `pilot` is supplied it may guide the start node (when its config has
/// `guide_at_start` set) and any later node for which `should_intervene`
/// returns true. Pilot decisions are blended with the algorithmic score by
/// `merge_with_pilot_decision`, and each one increments
/// `pilot_interventions`.
///
/// If nothing was collected and `config.min_score` is positive, the
/// best-scoring children of `start_node` are returned as a fallback.
///
/// The returned paths are sorted by descending score and truncated to
/// `config.top_k`.
async fn search_impl(
    &self,
    tree: &DocumentTree,
    context: &RetrievalContext,
    config: &SearchConfig,
    pilot: Option<&dyn Pilot>,
    start_node: NodeId,
) -> SearchResult {
    let mut result = SearchResult::default();
    // Use the tighter of the per-call and per-instance limits. Clamp to one
    // (as `with_width` does) so a zero in the config cannot silently turn
    // the search into a no-op.
    let beam_width = config.beam_width.min(self.beam_width).max(1);
    let mut visited: HashSet<NodeId> = HashSet::new();
    visited.insert(start_node);
    debug!(
        "BeamSearch: query='{}', start_node={:?}, beam_width={}, min_score={:.2}",
        context.query, start_node, beam_width, config.min_score
    );
    let mut pilot_interventions = 0;
    let start_children = tree.children(start_node);
    debug!("Start node has {} children", start_children.len());

    // Ask the pilot for start guidance only when one is present and its
    // configuration enables it.
    let start_guidance = match pilot {
        Some(p) => {
            debug!(
                "BeamSearch: Pilot is available, name={}, guide_at_start={}",
                p.name(),
                p.config().guide_at_start
            );
            if p.config().guide_at_start {
                p.guide_start(tree, &context.query, start_node).await
            } else {
                None
            }
        }
        None => None,
    };

    let initial_candidates = match start_guidance {
        Some(guidance) => {
            debug!(
                "Pilot provided start guidance with confidence {}",
                guidance.confidence
            );
            pilot_interventions += 1;
            if guidance.has_candidates() {
                self.merge_with_pilot_decision(
                    tree,
                    &start_children,
                    &guidance,
                    &context.query,
                )
            } else {
                self.score_candidates_with_query(tree, &start_children, &context.query)
            }
        }
        None => self.score_candidates_with_query(tree, &start_children, &context.query),
    };

    let mut current_beam: Vec<SearchPath> = initial_candidates
        .into_iter()
        .map(|(node_id, score)| SearchPath::from_node(node_id, score))
        .collect();
    debug!("Initial {} candidates after scoring", current_beam.len());
    current_beam.truncate(beam_width);

    for iteration in 0..config.max_iterations {
        result.iterations = iteration + 1;
        if current_beam.is_empty() {
            break;
        }
        let mut next_beam = Vec::new();
        for path in &current_beam {
            if let Some(leaf_id) = path.leaf {
                visited.insert(leaf_id);

                // A path that ends at a tree leaf cannot be extended: record
                // it if it is good enough and move on.
                if tree.is_leaf(leaf_id) {
                    if path.score >= config.min_score {
                        result.paths.push(path.clone());
                    }
                    result.nodes_visited += 1;
                    continue;
                }

                let children = tree.children(leaf_id);
                let scored_children = if let Some(p) = pilot {
                    let state = SearchState::new(
                        tree,
                        &context.query,
                        &path.nodes,
                        &children,
                        &visited,
                    );
                    if p.should_intervene(&state) {
                        trace!(
                            "Pilot intervening at fork with {} candidates",
                            children.len()
                        );
                        let decision = p.decide(&state).await;
                        pilot_interventions += 1;
                        debug!(
                            "Pilot decision: confidence={}, direction={:?}",
                            decision.confidence,
                            std::mem::discriminant(&decision.direction)
                        );
                        self.merge_with_pilot_decision(
                            tree,
                            &children,
                            &decision,
                            &context.query,
                        )
                    } else {
                        self.score_candidates_with_query(tree, &children, &context.query)
                    }
                } else {
                    self.score_candidates_with_query(tree, &children, &context.query)
                };

                // Extend the path with its best children and log each step.
                for (child_id, child_score) in scored_children.into_iter().take(beam_width) {
                    let new_path = path.extend(child_id, child_score);
                    let child_node = tree.get(child_id);
                    result.trace.push(NavigationStep {
                        node_id: format!("{:?}", child_id),
                        title: child_node.map(|n| n.title.clone()).unwrap_or_default(),
                        score: child_score,
                        decision: NavigationDecision::GoToChild(
                            children.iter().position(|&c| c == child_id).unwrap_or(0),
                        ),
                        depth: child_node.map(|n| n.depth).unwrap_or(0),
                    });
                    next_beam.push(new_path);
                    result.nodes_visited += 1;
                }
            }
        }

        // Keep only the best `beam_width` extensions for the next round.
        next_beam.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        next_beam.truncate(beam_width);
        current_beam = next_beam;
        if result.paths.len() >= config.top_k {
            break;
        }
    }

    // Paths still in the beam count as results when they score high enough
    // and there is room left.
    for path in current_beam {
        if path.score >= config.min_score && result.paths.len() < config.top_k {
            result.paths.push(path);
        }
    }

    if result.paths.is_empty() && config.min_score > 0.0 {
        debug!("No results above min_score, adding best candidates as fallback");
        // Reuse the start node's children fetched above instead of querying
        // the tree a second time.
        let all_candidates =
            self.score_candidates_with_query(tree, &start_children, &context.query);
        for (node_id, score) in all_candidates.into_iter().take(config.top_k) {
            result.paths.push(SearchPath::from_node(node_id, score));
        }
    }

    result.paths.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    result.paths.truncate(config.top_k);
    result.pilot_interventions = pilot_interventions;
    result
}
}
impl Default for BeamSearch {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl SearchTree for BeamSearch {
    /// Searches the whole tree, starting from its root node.
    async fn search(
        &self,
        tree: &DocumentTree,
        context: &RetrievalContext,
        config: &SearchConfig,
        pilot: Option<&dyn Pilot>,
    ) -> SearchResult {
        let root = tree.root();
        let outcome = self.search_impl(tree, context, config, pilot, root).await;
        outcome
    }

    /// Searches the subtree below `start_node`.
    async fn search_from(
        &self,
        tree: &DocumentTree,
        context: &RetrievalContext,
        config: &SearchConfig,
        pilot: Option<&dyn Pilot>,
        start_node: NodeId,
    ) -> SearchResult {
        let outcome = self
            .search_impl(tree, context, config, pilot, start_node)
            .await;
        outcome
    }

    /// Returns the identifier of this search strategy.
    fn name(&self) -> &'static str {
        "beam"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` uses a width of 3; `with_width` stores the requested width.
    #[test]
    fn test_beam_search_creation() {
        assert_eq!(BeamSearch::new().beam_width, 3);
        assert_eq!(BeamSearch::with_width(5).beam_width, 5);
    }

    /// A requested width of zero is raised to one.
    #[test]
    fn test_beam_search_minimum_width() {
        let clamped = BeamSearch::with_width(0);
        assert_eq!(clamped.beam_width, 1);
    }
}