use async_trait::async_trait;
use tracing::{debug, trace};
use super::super::RetrievalContext;
use super::super::types::{NavigationDecision, NavigationStep, SearchPath};
use super::scorer::{NodeScorer, ScoringContext};
use super::{SearchConfig, SearchResult, SearchTree};
use crate::document::{DocumentTree, NodeId};
use crate::retrieval::pilot::{Pilot, SearchState};
/// Greedy tree-search strategy.
///
/// At each step the search descends into a single child of the current node:
/// the first scored candidate whose score meets the configured minimum. The
/// type carries no state; all inputs are passed per call.
pub struct GreedySearch;
impl GreedySearch {
    /// Creates a new greedy search strategy.
    pub fn new() -> Self {
        Self
    }

    /// Builds a [`NodeScorer`] whose scoring context is created from `query`.
    fn create_scorer(&self, query: &str) -> NodeScorer {
        NodeScorer::new(ScoringContext::new(query))
    }

    /// Scores `candidates` against `query` using the algorithmic scorer only.
    ///
    /// Returns the output of `NodeScorer::score_and_sort` unchanged. The search
    /// loop consumes it front to back, so it presumably is ordered best-first
    /// (TODO confirm against `NodeScorer`).
    fn score_candidates_with_query(
        &self,
        tree: &DocumentTree,
        candidates: &[NodeId],
        query: &str,
    ) -> Vec<(NodeId, f32)> {
        let scorer = self.create_scorer(query);
        scorer.score_and_sort(tree, candidates)
    }

    /// Blends the algorithmic score of every candidate with the pilot's score.
    ///
    /// The algorithmic score has a fixed weight of `0.4`; the pilot score has a
    /// weight of `0.6 * pilot_decision.confidence`. The blend is normalised by
    /// the sum of both weights. Candidates the pilot did not rank get a pilot
    /// score of `0.0`. When the pilot weight is not positive the algorithmic
    /// score is used on its own.
    ///
    /// Returns `(node, score)` pairs sorted by descending score.
    fn merge_with_pilot_decision(
        &self,
        tree: &DocumentTree,
        candidates: &[NodeId],
        pilot_decision: &crate::retrieval::pilot::PilotDecision,
        query: &str,
    ) -> Vec<(NodeId, f32)> {
        let scorer = self.create_scorer(query);
        let alpha = 0.4;
        let beta = 0.6 * pilot_decision.confidence;

        // Index the pilot's ranking by node so each candidate lookup is O(1).
        // A node listed more than once keeps its last score.
        let mut pilot_scores: std::collections::HashMap<NodeId, f32> =
            std::collections::HashMap::new();
        for ranked in &pilot_decision.ranked_candidates {
            pilot_scores.insert(ranked.node_id, ranked.score);
        }

        let mut merged: Vec<(NodeId, f32)> = candidates
            .iter()
            .map(|&node_id| {
                let algo_score = scorer.score(tree, node_id);
                let pilot_score = pilot_scores.get(&node_id).copied().unwrap_or(0.0);
                let final_score = if beta > 0.0 {
                    (alpha * algo_score + beta * pilot_score) / (alpha + beta)
                } else {
                    algo_score
                };
                (node_id, final_score)
            })
            .collect();

        // Descending order; incomparable scores (NaN) are treated as equal.
        merged.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        merged
    }

    /// Walks the tree from `start_node`, descending into one child per iteration.
    ///
    /// Each iteration:
    /// 1. If the current node has no children, it becomes the path's leaf, the
    ///    path is recorded (unless `config.leaf_only` is set and
    ///    `tree.is_leaf` rejects the node), and the search stops.
    /// 2. Otherwise the children are scored, with pilot input when a pilot is
    ///    supplied and `should_intervene` returns true.
    /// 3. The first scored child with `score >= config.min_score` is taken. If
    ///    none qualifies, the current node becomes the path's leaf, the path is
    ///    recorded when its score is positive, and the search stops.
    ///
    /// The loop runs at most `config.max_iterations` times. The returned
    /// [`SearchResult`] carries the recorded paths, a navigation trace, the
    /// iteration and visited-node counts, and the number of pilot interventions.
    async fn search_impl(
        &self,
        tree: &DocumentTree,
        context: &RetrievalContext,
        config: &SearchConfig,
        pilot: Option<&dyn Pilot>,
        start_node: NodeId,
    ) -> SearchResult {
        let mut result = SearchResult::default();
        let mut current_path = SearchPath::new();
        let mut current_node = start_node;
        // Only handed to the pilot via `SearchState`; this function never reads
        // it to filter candidates.
        let mut visited: std::collections::HashSet<NodeId> = std::collections::HashSet::new();

        debug!(
            "GreedySearch: query='{}', start_node={:?}, max_iterations={}, min_score={:.2}",
            context.query, start_node, config.max_iterations, config.min_score
        );

        let mut pilot_interventions = 0;

        // NOTE(review): if this loop ends by exhausting `max_iterations`, the
        // partial `current_path` is never pushed to `result.paths` — confirm
        // that callers expect an empty result in that case.
        for iteration in 0..config.max_iterations {
            result.iterations = iteration + 1;

            let children = tree.children(current_node);
            if children.is_empty() {
                current_path.leaf = Some(current_node);
                if !config.leaf_only || tree.is_leaf(current_node) {
                    result.paths.push(current_path.clone());
                }
                break;
            }

            let scored_children = if let Some(p) = pilot {
                let state = SearchState::new(
                    tree,
                    &context.query,
                    &current_path.nodes,
                    &children,
                    &visited,
                );
                if p.should_intervene(&state) {
                    trace!(
                        "Pilot intervening at greedy decision point with {} candidates",
                        children.len()
                    );
                    let decision = p.decide(&state).await;
                    pilot_interventions += 1;
                    debug!(
                        "Pilot decision: confidence={}, direction={:?}",
                        decision.confidence,
                        std::mem::discriminant(&decision.direction)
                    );
                    self.merge_with_pilot_decision(tree, &children, &decision, &context.query)
                } else {
                    self.score_candidates_with_query(tree, &children, &context.query)
                }
            } else {
                self.score_candidates_with_query(tree, &children, &context.query)
            };

            // Take the first candidate, in scored order, that clears the threshold.
            let best = scored_children
                .into_iter()
                .find(|(_, score)| *score >= config.min_score);

            if let Some((child_id, best_score)) = best {
                visited.insert(child_id);

                let child_node = tree.get(child_id);
                result.trace.push(NavigationStep {
                    node_id: format!("{:?}", child_id),
                    title: child_node.map(|n| n.title.clone()).unwrap_or_default(),
                    score: best_score,
                    // Index of the chosen child among its siblings; falls back
                    // to 0 if the lookup fails.
                    decision: NavigationDecision::GoToChild(
                        children.iter().position(|&c| c == child_id).unwrap_or(0),
                    ),
                    depth: child_node.map(|n| n.depth).unwrap_or(0),
                });

                current_path = current_path.extend(child_id, best_score);
                current_node = child_id;
                result.nodes_visited += 1;

                // NOTE(review): paths are only pushed immediately before a
                // `break`, so `result.paths` still has its default contents
                // here; this check appears to trip only when `top_k` is 0.
                if result.paths.len() >= config.top_k {
                    break;
                }
            } else {
                // No child is good enough: stop at the current node.
                current_path.leaf = Some(current_node);
                if current_path.score > 0.0 {
                    result.paths.push(current_path);
                }
                break;
            }
        }

        result.pilot_interventions = pilot_interventions;
        result
    }
}
impl Default for GreedySearch {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl SearchTree for GreedySearch {
    /// Runs the greedy search starting at the tree's root node.
    async fn search(
        &self,
        tree: &DocumentTree,
        context: &RetrievalContext,
        config: &SearchConfig,
        pilot: Option<&dyn Pilot>,
    ) -> SearchResult {
        let root = tree.root();
        self.search_impl(tree, context, config, pilot, root).await
    }

    /// Runs the greedy search starting at `start_node`.
    async fn search_from(
        &self,
        tree: &DocumentTree,
        context: &RetrievalContext,
        config: &SearchConfig,
        pilot: Option<&dyn Pilot>,
        start_node: NodeId,
    ) -> SearchResult {
        let walk = self.search_impl(tree, context, config, pilot, start_node);
        walk.await
    }

    /// Identifier of this search strategy.
    fn name(&self) -> &'static str {
        const STRATEGY_NAME: &str = "greedy";
        STRATEGY_NAME
    }
}