use async_trait::async_trait;
use std::collections::{HashMap, HashSet};
use tracing::debug;
use super::super::RetrievalContext;
use super::super::types::{NavigationDecision, NavigationStep, SearchPath};
use super::{SearchConfig, SearchResult, SearchTree};
use crate::document::{DocumentTree, NodeId};
use crate::retrieval::pilot::Pilot;
use crate::retrieval::pilot::{NodeScorer, PilotDecisionCache, ScoringContext, score_candidates};
/// Per-node statistics accumulated over the search iterations.
///
/// The `Default` value (zero visits, zero score) represents a node that has
/// not been backpropagated through yet.
#[derive(Debug, Clone, Default)]
struct NodeStats {
// Number of times a backpropagated path has included this node.
visits: usize,
// Sum of all simulation scores backpropagated through this node;
// dividing by `visits` gives the node's average score.
total_score: f32,
}
/// Monte Carlo Tree Search strategy over a document tree.
///
/// Children are ranked with a UCT score that blends the average simulation
/// score, a prior relevance score and an exploration bonus.
#[derive(Debug, Clone)]
pub struct MctsSearch {
    /// Multiplier applied to the exploration bonus in the UCT score.
    /// Larger values favour rarely visited children.
    exploration_weight: f32,
}
impl MctsSearch {
/// Creates a searcher with the default exploration weight of `1.414`.
pub fn new() -> Self {
    // Default multiplier for the exploration bonus in the UCT score.
    const DEFAULT_EXPLORATION_WEIGHT: f32 = 1.414;
    Self {
        exploration_weight: DEFAULT_EXPLORATION_WEIGHT,
    }
}
/// Builder-style setter: consumes the searcher and returns it with the
/// exploration weight replaced by `weight`.
pub fn with_exploration(self, weight: f32) -> Self {
    let mut configured = self;
    configured.exploration_weight = weight;
    configured
}
/// Computes the UCT (Upper Confidence bound applied to Trees) score of a child.
///
/// * `child_stats` - visit count and accumulated score of the child.
/// * `parent_visits` - visit count of the parent; values below 1 are treated as 1.
/// * `prior_score` - relevance prior for the child, averaged with the observed mean.
///
/// Returns `f32::INFINITY` for a child that has never been visited, so every
/// child is tried at least once. Otherwise returns
/// `0.5 * (mean + prior) + c * sqrt(ln(N) / n)`, where `c` is the exploration
/// weight, `N` the parent visits and `n` the child visits.
fn uct_score(&self, child_stats: &NodeStats, parent_visits: usize, prior_score: f32) -> f32 {
    if child_stats.visits == 0 {
        return f32::INFINITY;
    }
    let child_visits = child_stats.visits as f32;
    let exploitation = child_stats.total_score / child_visits;
    // Clamp to 1 so `ln` never returns -inf, which would turn the square root into NaN.
    let log_parent_visits = (parent_visits.max(1) as f32).ln();
    // UCB1 bonus: the visit count belongs *inside* the square root. Dividing by the
    // raw count instead makes the bonus vanish far too quickly.
    let exploration = self.exploration_weight * (log_parent_visits / child_visits).sqrt();
    0.5 * (exploitation + prior_score) + exploration
}
async fn select_child(
&self,
tree: &DocumentTree,
context: &RetrievalContext,
node_id: NodeId,
stats: &HashMap<NodeId, NodeStats>,
pilot: Option<&dyn Pilot>,
cache: &PilotDecisionCache,
visited: &HashSet<NodeId>,
) -> Option<(NodeId, f32)> {
let children = tree.children_with_refs(node_id);
if children.is_empty() {
return None;
}
let parent_stats = stats.get(&node_id).cloned().unwrap_or_default();
let parent_visits = parent_stats.visits.max(1);
let priors = score_candidates(
tree,
&children,
&context.query,
pilot,
&[node_id], visited,
0.5, Some(cache),
None, )
.await;
let prior_map: HashMap<NodeId, f32> = priors.into_iter().collect();
let mut best_child = None;
let mut best_score = f32::NEG_INFINITY;
for &child_id in &children {
let prior = prior_map.get(&child_id).copied().unwrap_or_else(|| {
let scorer = NodeScorer::new(ScoringContext::new(&context.query));
scorer.score(tree, child_id)
});
let child_stats = stats.get(&child_id).cloned().unwrap_or_default();
let uct = self.uct_score(&child_stats, parent_visits, prior);
if uct > best_score {
best_score = uct;
best_child = Some((child_id, prior));
}
}
best_child
}
/// Runs a greedy rollout from `node_id` and returns the mean score along it.
///
/// The start node is scored with a `NodeScorer`. Then, for at most `max_depth`
/// steps, the children of the current node are passed to `score_candidates`
/// and the first returned entry is followed (presumably the best-ranked one;
/// confirm against `score_candidates`). The rollout stops early at a node
/// without children or when no candidate is returned.
async fn simulate(
    &self,
    tree: &DocumentTree,
    context: &RetrievalContext,
    node_id: NodeId,
    max_depth: usize,
    pilot: Option<&dyn Pilot>,
    cache: &PilotDecisionCache,
    visited: &HashSet<NodeId>,
) -> f32 {
    let scorer = NodeScorer::new(ScoringContext::new(&context.query));
    let mut cursor = node_id;
    // Every node on the rollout, start node included; its length is the
    // number of scores that were summed.
    let mut trail = vec![node_id];
    let mut score_sum = 0.0f32;
    score_sum += scorer.score(tree, cursor);
    for _ in 0..max_depth {
        let candidates = tree.children_with_refs(cursor);
        if candidates.is_empty() {
            break;
        }
        let ranked = score_candidates(
            tree,
            &candidates,
            &context.query,
            pilot,
            &trail,
            visited,
            0.5,
            Some(cache),
            None,
        )
        .await;
        match ranked.first() {
            Some(&(next_id, next_score)) => {
                score_sum += next_score;
                trail.push(next_id);
                cursor = next_id;
            }
            None => break,
        }
    }
    score_sum / trail.len() as f32
}
/// Credits `score` to every node on `path`: each node's visit count goes up by
/// one and its accumulated score by `score`. Nodes without statistics yet get
/// a fresh default entry first.
fn backpropagate(&self, stats: &mut HashMap<NodeId, NodeStats>, path: &[NodeId], score: f32) {
    path.iter().copied().for_each(|visited_id| {
        let entry = stats.entry(visited_id).or_default();
        entry.total_score += score;
        entry.visits += 1;
    });
}
/// Runs the search from `start_node` for `config.max_iterations` iterations.
///
/// Each iteration:
/// 1. descends from `start_node`, choosing a child with `select_child` until a
///    leaf is reached or no child can be selected;
/// 2. runs `simulate` from the node where the descent stopped;
/// 3. backpropagates the simulation score along the descended path;
/// 4. appends one `NavigationStep` for that node to `result.trace`.
///
/// After the last iteration the best-scoring direct children of `start_node`
/// are written to `result.paths` via `extract_paths`.
async fn search_impl(
    &self,
    tree: &DocumentTree,
    context: &RetrievalContext,
    config: &SearchConfig,
    pilot: Option<&dyn Pilot>,
    start_node: NodeId,
) -> SearchResult {
    // Depth limit handed to `simulate` for each rollout.
    const SIMULATION_MAX_DEPTH: usize = 5;

    let mut result = SearchResult::default();
    let mut stats: HashMap<NodeId, NodeStats> = HashMap::new();
    let cache = PilotDecisionCache::new();
    // Never populated here; passed through to the scoring helpers as an empty set.
    let visited: HashSet<NodeId> = HashSet::new();
    stats.insert(start_node, NodeStats::default());
    debug!(
        "MctsSearch: query='{}', start_node={:?}, max_iterations={}, exploration={:.2}",
        context.query, start_node, config.max_iterations, self.exploration_weight
    );
    let mut pilot_interventions = 0;
    for iteration in 0..config.max_iterations {
        result.iterations = iteration + 1;

        // Selection: follow the best child until a leaf is reached.
        let mut path = vec![start_node];
        let mut current = start_node;
        while !tree.is_leaf(current) {
            match self
                .select_child(tree, context, current, &stats, pilot, &cache, &visited)
                .await
            {
                // NOTE(review): `children_with_refs` presumably also yields reference
                // edges, which could lead back to a node already on the path. Stop
                // instead of looping forever and double-counting nodes; on a strict
                // tree this arm never rejects a child.
                Some((child_id, _prior)) if !path.contains(&child_id) => {
                    path.push(child_id);
                    current = child_id;
                    if pilot.is_some() {
                        pilot_interventions += 1;
                    }
                }
                _ => break,
            }
        }
        result.nodes_visited += path.len();

        // Simulation from the end of the path (`current` is always `path`'s last element).
        let sim_score = self
            .simulate(
                tree,
                context,
                current,
                SIMULATION_MAX_DEPTH,
                pilot,
                &cache,
                &visited,
            )
            .await;
        if pilot.is_some() {
            pilot_interventions += 1;
        }

        self.backpropagate(&mut stats, &path, sim_score);

        let node = tree.get(current);
        result.trace.push(NavigationStep {
            node_id: format!("{:?}", current),
            title: node.map(|n| n.title.clone()).unwrap_or_default(),
            score: sim_score,
            decision: NavigationDecision::ExploreMore,
            depth: node.map(|n| n.depth).unwrap_or(0),
        });
    }

    // Paths are extracted once, after all iterations. Intermediate extractions
    // would be overwritten by this call without ever being read.
    self.extract_paths(
        tree,
        start_node,
        &stats,
        config.min_score,
        config.top_k,
        &mut result,
    );
    result.pilot_interventions = pilot_interventions;
    result
}
/// Replaces `result.paths` with the best-scoring direct children of `root`.
///
/// Only children that have an entry in `stats` are considered. Each is scored
/// by its average backpropagated score (`0.0` when it has no visits).
/// Children scoring below `min_score` are dropped, the rest are sorted by
/// descending score, and at most `top_k` are kept.
fn extract_paths(
    &self,
    tree: &DocumentTree,
    root: NodeId,
    stats: &HashMap<NodeId, NodeStats>,
    min_score: f32,
    top_k: usize,
    result: &mut SearchResult,
) {
    let root_children = tree.children_with_refs(root);
    // Filter before sorting: `avg >= min_score` is false for NaN, so every
    // surviving score is comparable and the comparator below is a valid total
    // order. It also means fewer elements to sort.
    let mut ranked: Vec<(NodeId, f32)> = root_children
        .iter()
        .filter_map(|&child_id| {
            let child_stats = stats.get(&child_id)?;
            let avg_score = if child_stats.visits > 0 {
                child_stats.total_score / child_stats.visits as f32
            } else {
                0.0
            };
            if avg_score >= min_score {
                Some((child_id, avg_score))
            } else {
                None
            }
        })
        .collect();
    // Descending by score. The sort is stable, so ties keep child order.
    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    result.paths = ranked
        .into_iter()
        .take(top_k)
        .map(|(node_id, score)| SearchPath::from_node(node_id, score))
        .collect();
}
}
impl Default for MctsSearch {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl SearchTree for MctsSearch {
    /// Searches the document starting at the tree's root node.
    async fn search(
        &self,
        tree: &DocumentTree,
        context: &RetrievalContext,
        config: &SearchConfig,
        pilot: Option<&dyn Pilot>,
    ) -> SearchResult {
        let root = tree.root();
        self.search_impl(tree, context, config, pilot, root).await
    }

    /// Searches the subtree that starts at `start_node`.
    async fn search_from(
        &self,
        tree: &DocumentTree,
        context: &RetrievalContext,
        config: &SearchConfig,
        pilot: Option<&dyn Pilot>,
        start_node: NodeId,
    ) -> SearchResult {
        self.search_impl(tree, context, config, pilot, start_node).await
    }

    /// Identifier of this search strategy.
    fn name(&self) -> &'static str {
        "mcts"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing exploration weights.
    const WEIGHT_TOLERANCE: f32 = 0.01;

    /// A freshly constructed searcher uses the default exploration weight.
    #[test]
    fn test_mcts_creation() {
        let weight = MctsSearch::new().exploration_weight;
        assert!((weight - 1.414).abs() < WEIGHT_TOLERANCE);
    }

    /// `with_exploration` overrides the default weight.
    #[test]
    fn test_mcts_custom_exploration() {
        let weight = MctsSearch::new().with_exploration(2.0).exploration_weight;
        assert!((weight - 2.0).abs() < WEIGHT_TOLERANCE);
    }

    /// A child that was never visited gets an infinite UCT score.
    #[test]
    fn test_uct_unvisited() {
        let unvisited = NodeStats::default();
        let uct = MctsSearch::new().uct_score(&unvisited, 10, 0.5);
        assert!(uct.is_infinite());
    }

    /// A visited child with positive scores gets a finite, positive UCT score.
    #[test]
    fn test_uct_visited() {
        let visited = NodeStats {
            total_score: 3.0,
            visits: 5,
        };
        let uct = MctsSearch::new().uct_score(&visited, 20, 0.8);
        assert!(uct.is_finite());
        assert!(uct > 0.0);
    }
}