use async_trait::async_trait;
use std::collections::HashMap;
use super::super::RetrievalContext;
use super::super::types::{NavigationDecision, NavigationStep, SearchPath};
use super::scorer::NodeScorer;
use super::{SearchConfig, SearchResult, SearchTree};
use crate::config::StrategyConfig;
use crate::document::{DocumentTree, NodeId};
use crate::retrieval::pilot::Pilot;
/// Per-node statistics accumulated across search iterations.
///
/// A default value (zero visits, zero score) represents a node that has not
/// been reached by any backpropagated path yet.
#[derive(Debug, Clone, Default)]
struct NodeStats {
/// Number of backpropagated paths that passed through this node.
visits: usize,
/// Sum of the simulation scores backpropagated through this node.
total_score: f32,
}
/// Monte Carlo Tree Search strategy over a document tree.
///
/// Children are selected by a UCT-style score and evaluated with rollouts
/// driven by the node scorer.
pub struct MctsSearch {
/// Scores individual nodes; used both as the selection prior and to drive rollouts.
scorer: NodeScorer,
/// Multiplier applied to the exploration term of the UCT score.
exploration_weight: f32,
}
impl MctsSearch {
pub fn new() -> Self {
Self::with_config(&StrategyConfig::default())
}
pub fn with_config(config: &StrategyConfig) -> Self {
Self {
scorer: NodeScorer::new(Default::default()),
exploration_weight: config.exploration_weight,
}
}
pub fn with_exploration(mut self, weight: f32) -> Self {
self.exploration_weight = weight;
self
}
fn uct_score(&self, child_stats: &NodeStats, parent_visits: usize, prior_score: f32) -> f32 {
if child_stats.visits == 0 {
return f32::INFINITY;
}
let exploitation = child_stats.total_score / child_stats.visits as f32;
let exploration = self.exploration_weight * (parent_visits as f32).ln().sqrt()
/ child_stats.visits as f32;
0.5 * (exploitation + prior_score) + exploration
}
fn select_child(
&self,
tree: &DocumentTree,
node_id: NodeId,
stats: &HashMap<NodeId, NodeStats>,
) -> Option<(NodeId, f32)> {
let children = tree.children(node_id);
if children.is_empty() {
return None;
}
let parent_stats = stats.get(&node_id).cloned().unwrap_or_default();
let parent_visits = parent_stats.visits.max(1);
let mut best_child = None;
let mut best_score = f32::NEG_INFINITY;
for &child_id in &children {
let prior_score = self.scorer.score(tree, child_id);
let child_stats = stats.get(&child_id).cloned().unwrap_or_default();
let uct = self.uct_score(&child_stats, parent_visits, prior_score);
if uct > best_score {
best_score = uct;
best_child = Some((child_id, prior_score));
}
}
best_child
}
fn simulate(&self, tree: &DocumentTree, node_id: NodeId, max_depth: usize) -> f32 {
let mut current = node_id;
let mut depth = 0;
let mut total_score = self.scorer.score(tree, current);
while depth < max_depth {
let children = tree.children(current);
if children.is_empty() {
break;
}
let scored = self.scorer.score_and_sort(tree, &children);
if let Some((child_id, score)) = scored.first() {
total_score += score;
current = *child_id;
} else {
break;
}
depth += 1;
}
total_score / (depth + 1).max(1) as f32
}
fn backpropagate(&self, stats: &mut HashMap<NodeId, NodeStats>, path: &[NodeId], score: f32) {
for &node_id in path {
let node_stats = stats.entry(node_id).or_default();
node_stats.visits += 1;
node_stats.total_score += score;
}
}
}
impl Default for MctsSearch {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl SearchTree for MctsSearch {
    /// Runs `config.max_iterations` rounds of selection, rollout and
    /// backpropagation starting at the tree root, then ranks the root's
    /// children by their mean backpropagated score.
    ///
    /// Each iteration:
    /// 1. descends from the root, picking the child with the best UCT score,
    ///    until a leaf is reached or no child can be selected;
    /// 2. evaluates the reached node with a bounded rollout;
    /// 3. adds the rollout score to every node on the descent path;
    /// 4. appends one `NavigationStep` for the reached node to `result.trace`.
    ///
    /// The returned `paths` contain at most `config.top_k` root children,
    /// best first, restricted to those whose score is at least
    /// `config.min_score`. Root children that were never reached by any
    /// iteration are not reported.
    ///
    /// `context` and `_pilot` are not read by this implementation.
    async fn search(
        &self,
        tree: &DocumentTree,
        context: &RetrievalContext,
        config: &SearchConfig,
        _pilot: Option<&dyn Pilot>,
    ) -> SearchResult {
        /// Maximum number of steps a single rollout may descend.
        const ROLLOUT_DEPTH: usize = 5;

        let mut result = SearchResult::default();
        let mut stats: HashMap<NodeId, NodeStats> = HashMap::new();
        let root = tree.root();
        stats.insert(root, NodeStats::default());

        for iteration in 0..config.max_iterations {
            result.iterations = iteration + 1;

            // Selection: follow the best UCT child until a leaf is reached.
            let mut path = vec![root];
            let mut current = root;
            while !tree.is_leaf(current) {
                if let Some((child_id, _)) = self.select_child(tree, current, &stats) {
                    path.push(child_id);
                    current = child_id;
                } else {
                    break;
                }
            }
            result.nodes_visited += path.len();

            // `path` always starts with `root`, so the fallback is never taken.
            let leaf = path.last().copied().unwrap_or(root);

            // Rollout and backpropagation.
            let sim_score = self.simulate(tree, leaf, ROLLOUT_DEPTH);
            self.backpropagate(&mut stats, &path, sim_score);

            // Record where this iteration ended up.
            let node = tree.get(leaf);
            result.trace.push(NavigationStep {
                node_id: format!("{:?}", leaf),
                title: node.map(|n| n.title.clone()).unwrap_or_default(),
                score: sim_score,
                decision: NavigationDecision::ExploreMore,
                depth: node.map(|n| n.depth).unwrap_or(0),
            });
        }

        // Rank the root's children once, after all iterations have run. Only
        // children that have statistics (i.e. were reached at least once) are
        // considered.
        let root_children = tree.children(root);
        let mut final_paths: Vec<_> = root_children
            .iter()
            .filter_map(|&child_id| {
                stats.get(&child_id).map(|s| {
                    let avg_score = if s.visits > 0 {
                        s.total_score / s.visits as f32
                    } else {
                        // Defensive fallback: fall back to the scorer's prior
                        // for an entry that carries no visits.
                        self.scorer.score(tree, child_id)
                    };
                    SearchPath::from_node(child_id, avg_score)
                })
            })
            .collect();

        // Best score first; incomparable (NaN) scores are treated as equal.
        final_paths.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        final_paths.truncate(config.top_k);

        result.paths = final_paths
            .into_iter()
            .filter(|p| p.score >= config.min_score)
            .collect();
        result
    }

    /// Identifier of this search strategy.
    fn name(&self) -> &'static str {
        "mcts"
    }
}