use std::collections::HashMap;
/// Tunable parameters handed to [`MctsStrategy`].
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build
/// it with a struct literal; start from [`Default::default`] and overwrite the
/// public fields instead.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct MctsConfig {
    /// Depth limit for the search.
    ///
    /// NOTE(review): not read by `MctsStrategy::select_action`; presumably
    /// reserved for a deeper tree search. Confirm against callers.
    pub max_depth: usize,
    /// Number of simulations to run.
    ///
    /// NOTE(review): not read by `MctsStrategy::select_action`; confirm
    /// against callers.
    pub simulations: usize,
    /// Constant added to every candidate's value estimate when
    /// `MctsStrategy::select_action` scores the available actions.
    pub exploration_constant: f32,
}
impl Default for MctsConfig {
    /// Returns the stock configuration: depth 5, 50 simulations and an
    /// exploration constant of 1.414 (approximately the square root of two).
    fn default() -> Self {
        let max_depth: usize = 5;
        let simulations: usize = 50;
        let exploration_constant: f32 = 1.414;

        Self {
            max_depth,
            simulations,
            exploration_constant,
        }
    }
}
/// Action selector parameterised by an [`MctsConfig`].
///
/// The strategy stores nothing besides its configuration; each call to
/// `select_action` works only from the arguments it is given.
pub struct MctsStrategy {
/// Configuration supplied to `new`; `select_action` reads its
/// `exploration_constant` when scoring candidates.
config: MctsConfig,
}
impl MctsStrategy {
    /// Creates a strategy that scores actions using `config`.
    pub const fn new(config: MctsConfig) -> Self {
        Self { config }
    }

    /// Returns the highest-scoring entry of `available_actions`.
    ///
    /// Each action's score is its estimate from `value_estimates` plus the
    /// configured exploration constant. Actions without an estimate use `0.0`;
    /// estimates for names not in `available_actions` are ignored. If a name
    /// appears more than once in `value_estimates`, the last entry wins.
    /// The same constant is added to every candidate, so it shifts all scores
    /// equally.
    ///
    /// A score that is NaN ranks below every non-NaN score, so a corrupt
    /// estimate can never displace a valid one. When several actions share the
    /// best score, the one appearing last in `available_actions` is returned.
    ///
    /// Returns `None` only when `available_actions` is empty.
    pub fn select_action(
        &self,
        available_actions: &[String],
        value_estimates: &[(String, f32)],
    ) -> Option<String> {
        // Skip building the lookup table when there is nothing to choose from.
        if available_actions.is_empty() {
            return None;
        }

        let value_map: HashMap<&str, f32> = value_estimates
            .iter()
            .map(|(action, value)| (action.as_str(), *value))
            .collect();
        let bonus = self.config.exploration_constant;

        available_actions
            .iter()
            // Score borrowed names; only the winner is cloned at the end.
            .map(|action| {
                let value = value_map.get(action.as_str()).copied().unwrap_or(0.0);
                (action, value + bonus)
            })
            .max_by(|(_, lhs), (_, rhs)| {
                // `partial_cmp` fails only when a side is NaN. Order NaN below
                // real numbers (and equal to other NaNs) so the comparison
                // stays consistent instead of resetting the running maximum.
                lhs.partial_cmp(rhs)
                    .unwrap_or_else(|| rhs.is_nan().cmp(&lhs.is_nan()))
            })
            .map(|(action, _)| action.clone())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Strategy under test, built from the default configuration.
    fn default_strategy() -> MctsStrategy {
        MctsStrategy::new(MctsConfig::default())
    }

    /// Turns string literals into the owned action names the API expects.
    fn names(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|name| (*name).to_owned()).collect()
    }

    #[test]
    fn select_action_returns_highest_value() {
        let actions = names(&["tool_a", "tool_b", "tool_c"]);
        let estimates: Vec<(String, f32)> =
            [("tool_a", 0.3_f32), ("tool_b", 0.9_f32), ("tool_c", 0.1_f32)]
                .iter()
                .map(|&(name, value)| (name.to_owned(), value))
                .collect();

        let selected = default_strategy().select_action(&actions, &estimates);

        assert_eq!(selected.as_deref(), Some("tool_b"));
    }

    #[test]
    fn select_action_empty_returns_none() {
        let selected = default_strategy().select_action(&[], &[]);

        assert!(selected.is_none());
    }

    #[test]
    fn select_action_no_estimates_returns_first_by_exploration() {
        let actions = names(&["only_tool"]);

        let selected = default_strategy().select_action(&actions, &[]);

        assert_eq!(selected.as_deref(), Some("only_tool"));
    }
}