use std::time::Duration;
/// Tree policy selected for a search, together with that policy's parameters.
///
/// Serde encodes this as an internally tagged value: the variant name is
/// written to a `kind` field in `snake_case` (`uct`, `puct`,
/// `thompson_sampling`, `gumbel`) next to the variant's own fields.
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
#[derive(Clone, Debug, PartialEq, Default, serde::Serialize, serde::Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
#[non_exhaustive]
pub enum TreePolicy {
/// The `Default` variant; it carries no parameters.
#[default]
Uct,
/// Selected with `kind = "puct"`.
Puct {
/// `SearchConfig::sanitize` resets this to `1.0` when it is non-finite
/// or negative.
prior_weight: f64,
},
/// Selected with `kind = "thompson_sampling"`.
ThompsonSampling {
/// `SearchConfig::sanitize` resets this to `1.0` when it is non-finite
/// or negative.
temperature: f64,
},
/// Selected with `kind = "gumbel"`.
Gumbel {
/// `SearchConfig::sanitize` replaces `0` with `16`.
sampled_actions: usize,
/// `SearchConfig::sanitize` resets this to `50.0` when it is non-finite
/// or negative.
max_completions_coeff: f64,
},
}
/// RAVE settings stored in `SearchConfig::rave`.
///
/// `#[serde(default)]` makes every field optional on input: fields missing
/// from the serialized form are filled from this type's `Default` impl.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(default)]
pub struct RaveConfig {
/// On/off switch; `true` by default.
pub enabled: bool,
/// Defaults to `300.0`. `SearchConfig::sanitize` resets it to that default
/// when it is non-finite or negative.
pub bias: f64,
}
impl Default for RaveConfig {
fn default() -> Self {
Self {
enabled: true,
bias: 300.0,
}
}
}
/// Progressive-widening parameters stored in
/// `SearchConfig::progressive_widening`.
///
/// `#[serde(default)]` makes every field optional on input: fields missing
/// from the serialized form are filled from this type's `Default` impl.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(default)]
pub struct ProgressiveWideningConfig {
/// Defaults to `1`. `SearchConfig::sanitize` raises `0` to `1`.
pub minimum_children: usize,
/// Defaults to `1.5`. `SearchConfig::sanitize` resets it to `1.5` when it
/// is non-finite or negative.
pub coefficient: f64,
/// Defaults to `0.5`. `SearchConfig::sanitize` resets it to `0.5` when it
/// is non-finite or negative.
pub exponent: f64,
}
impl Default for ProgressiveWideningConfig {
fn default() -> Self {
Self {
minimum_children: 1,
coefficient: 1.5,
exponent: 0.5,
}
}
}
/// Top-level search configuration.
///
/// `#[serde(default)]` makes every field optional on input: fields missing
/// from the serialized form are filled from `SearchConfig::default()`.
/// Values can be repaired after loading with `SearchConfig::sanitize`.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(default)]
pub struct SearchConfig {
/// Defaults to `10_000`. `sanitize` resets `0` to the default.
pub iterations: usize,
/// Defaults to `sqrt(2)`. `sanitize` resets it to the default when it is
/// non-finite or negative.
pub exploration_constant: f64,
/// Defaults to `50`. Not checked by `sanitize`.
pub max_depth: usize,
/// Defaults to `TreePolicy::Uct`; the variant's parameters are checked by
/// `sanitize`.
pub tree_policy: TreePolicy,
/// Defaults to `0.35`. `sanitize` clamps finite values into `[0, 1]` and
/// resets non-finite values to the default.
pub heuristic_weight: f64,
/// RAVE settings; see `RaveConfig`.
pub rave: RaveConfig,
/// `None` by default; when present its fields are checked by `sanitize`.
pub progressive_widening: Option<ProgressiveWideningConfig>,
/// `None` by default. Marked `#[serde(skip)]`, so it is neither written
/// when serializing nor read when deserializing; set it in code (for
/// example through `SearchConfigBuilder::time_budget`).
#[serde(skip)]
pub time_budget: Option<Duration>,
}
impl Default for SearchConfig {
fn default() -> Self {
Self {
iterations: 10_000,
exploration_constant: std::f64::consts::SQRT_2,
max_depth: 50,
tree_policy: TreePolicy::default(),
heuristic_weight: 0.35,
rave: RaveConfig::default(),
progressive_widening: None,
time_budget: None,
}
}
}
impl SearchConfig {
pub fn builder() -> SearchConfigBuilder {
SearchConfigBuilder(Self::default())
}
#[cfg(feature = "toml")]
pub fn from_toml_str(input: &str) -> Result<Self, toml::de::Error> {
let mut cfg: SearchConfig = toml::from_str(input)?;
let _warnings = cfg.sanitize();
Ok(cfg)
}
#[cfg(feature = "toml")]
pub fn from_toml_file(
path: impl AsRef<std::path::Path>,
) -> Result<Self, SearchConfigLoadError> {
let path = path.as_ref();
let contents = std::fs::read_to_string(path).map_err(SearchConfigLoadError::Io)?;
let mut cfg: SearchConfig =
toml::from_str(&contents).map_err(SearchConfigLoadError::Toml)?;
let _warnings = cfg.sanitize();
Ok(cfg)
}
#[must_use]
pub fn sanitize(&mut self) -> Vec<String> {
let default = SearchConfig::default();
let mut warnings = Vec::new();
if self.iterations == 0 {
warnings.push(format!(
"iterations must be >= 1, resetting to default {}",
default.iterations
));
self.iterations = default.iterations;
}
if !self.exploration_constant.is_finite() || self.exploration_constant < 0.0 {
warnings.push(format!(
"exploration_constant invalid ({}), resetting to default {}",
self.exploration_constant, default.exploration_constant
));
self.exploration_constant = default.exploration_constant;
}
if !self.heuristic_weight.is_finite() {
warnings.push(format!(
"heuristic_weight invalid ({}), resetting to default {}",
self.heuristic_weight, default.heuristic_weight
));
self.heuristic_weight = default.heuristic_weight;
} else if !(0.0..=1.0).contains(&self.heuristic_weight) {
warnings.push(format!(
"heuristic_weight ({}) out of [0,1], clamping",
self.heuristic_weight
));
self.heuristic_weight = self.heuristic_weight.clamp(0.0, 1.0);
}
if !self.rave.bias.is_finite() || self.rave.bias < 0.0 {
warnings.push(format!(
"rave.bias invalid ({}), resetting to default {}",
self.rave.bias, default.rave.bias
));
self.rave.bias = default.rave.bias;
}
if let Some(pw) = &mut self.progressive_widening {
if pw.minimum_children == 0 {
warnings.push(
"progressive_widening.minimum_children must be >= 1, setting to 1".to_string(),
);
pw.minimum_children = 1;
}
if !pw.coefficient.is_finite() || pw.coefficient < 0.0 {
let default_coeff = default
.progressive_widening
.as_ref()
.map_or(1.5, |c| c.coefficient);
warnings.push(format!(
"progressive_widening.coefficient invalid ({}), resetting to {}",
pw.coefficient, default_coeff
));
pw.coefficient = default_coeff;
}
if !pw.exponent.is_finite() || pw.exponent < 0.0 {
let default_exp = default
.progressive_widening
.as_ref()
.map_or(0.5, |c| c.exponent);
warnings.push(format!(
"progressive_widening.exponent invalid ({}), resetting to {}",
pw.exponent, default_exp
));
pw.exponent = default_exp;
}
}
self.sanitize_tree_policy(&mut warnings);
warnings
}
fn sanitize_tree_policy(&mut self, warnings: &mut Vec<String>) {
match &mut self.tree_policy {
TreePolicy::Puct { prior_weight } => {
if !prior_weight.is_finite() || *prior_weight < 0.0 {
warnings.push(format!(
"tree_policy.puct.prior_weight invalid ({prior_weight}), resetting to default 1.0",
));
*prior_weight = 1.0;
}
}
TreePolicy::ThompsonSampling { temperature } => {
if !temperature.is_finite() || *temperature < 0.0 {
warnings.push(format!(
"tree_policy.thompson_sampling.temperature invalid ({temperature}), resetting to default 1.0",
));
*temperature = 1.0;
}
}
TreePolicy::Gumbel {
sampled_actions,
max_completions_coeff,
} => {
if *sampled_actions == 0 {
warnings.push(
"tree_policy.gumbel.sampled_actions set to 0, defaulting to 16".to_string(),
);
*sampled_actions = 16;
}
if !max_completions_coeff.is_finite() || *max_completions_coeff < 0.0 {
warnings.push(format!(
"tree_policy.gumbel.max_completions_coeff invalid ({max_completions_coeff}), resetting to default 50.0",
));
*max_completions_coeff = 50.0;
}
}
TreePolicy::Uct => {}
}
}
}
/// Fluent builder for [`SearchConfig`].
///
/// Every setter consumes the builder and returns it, so calls can be chained.
/// Nothing is validated until [`SearchConfigBuilder::build`], which runs
/// [`SearchConfig::sanitize`] over the accumulated values.
#[derive(Clone, Debug)]
#[must_use = "a builder does nothing until `build` is called"]
pub struct SearchConfigBuilder(SearchConfig);

impl SearchConfigBuilder {
    /// Sets `iterations`.
    pub fn iterations(mut self, iterations: usize) -> Self {
        self.0.iterations = iterations;
        self
    }

    /// Sets `exploration_constant`.
    pub fn exploration_constant(mut self, exploration_constant: f64) -> Self {
        self.0.exploration_constant = exploration_constant;
        self
    }

    /// Sets `max_depth`.
    pub fn max_depth(mut self, max_depth: usize) -> Self {
        self.0.max_depth = max_depth;
        self
    }

    /// Sets `tree_policy`.
    pub fn tree_policy(mut self, tree_policy: TreePolicy) -> Self {
        self.0.tree_policy = tree_policy;
        self
    }

    /// Sets `heuristic_weight`.
    pub fn heuristic_weight(mut self, heuristic_weight: f64) -> Self {
        self.0.heuristic_weight = heuristic_weight;
        self
    }

    /// Replaces the RAVE settings.
    pub fn rave(mut self, rave: RaveConfig) -> Self {
        self.0.rave = rave;
        self
    }

    /// Stores `widening` as `Some(..)` in `progressive_widening`.
    pub fn progressive_widening(mut self, widening: ProgressiveWideningConfig) -> Self {
        self.0.progressive_widening = Some(widening);
        self
    }

    /// Stores `budget` as `Some(..)` in `time_budget`.
    pub fn time_budget(mut self, budget: Duration) -> Self {
        self.0.time_budget = Some(budget);
        self
    }

    /// Finishes the builder and returns the sanitized configuration.
    ///
    /// Invalid values are repaired by [`SearchConfig::sanitize`]; the
    /// warnings it returns are dropped. Callers that need them can call
    /// `sanitize` on their own `SearchConfig` value instead.
    #[must_use]
    pub fn build(self) -> SearchConfig {
        let Self(mut cfg) = self;
        let _warnings = cfg.sanitize();
        cfg
    }
}
/// Error returned by `SearchConfig::from_toml_file`.
///
/// Only compiled when the `toml` feature is enabled. The enum is
/// `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[cfg(feature = "toml")]
#[derive(Debug)]
#[non_exhaustive]
pub enum SearchConfigLoadError {
/// Reading the file failed; wraps the error from `std::fs::read_to_string`.
Io(std::io::Error),
/// The file contents could not be deserialized; wraps the error from
/// `toml::from_str`.
Toml(toml::de::Error),
}
#[cfg(feature = "toml")]
impl std::fmt::Display for SearchConfigLoadError {
    /// Writes a one-line message made of what failed, the wrapped error's
    /// own text, and a short hint on what to check.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered = match self {
            SearchConfigLoadError::Io(error) => format_args!(
                "failed to read config: {error}: check file path permissions and ownership"
            )
            .to_string(),
            SearchConfigLoadError::Toml(error) => format_args!(
                "failed to parse TOML config: {error}: validate section layout and key/value types"
            )
            .to_string(),
        };
        f.write_str(&rendered)
    }
}
// Marker impl: no method is overridden, so `source()` keeps the trait's
// default and returns `None`. The wrapped error's text is already embedded in
// the `Display` output, and the error itself is reachable by matching on the
// variant.
// NOTE(review): confirm whether a `source()` chain is also wanted; adding one
// would make chain-walking reporters print the cause twice.
#[cfg(feature = "toml")]
impl std::error::Error for SearchConfigLoadError {}
/// Unit tests for the configuration types.
///
/// Tests that touch the TOML loaders, `SearchConfigLoadError` or the `toml`
/// crate are gated on the `toml` feature, because those items only exist when
/// it is enabled; without the gate the whole test module fails to compile
/// when the feature is off.
#[cfg(test)]
mod tests {
    use super::*;

    /// The defaults match the documented baseline.
    #[test]
    fn default_config_values() {
        let c = SearchConfig::default();
        assert_eq!(c.iterations, 10_000);
        assert_eq!(c.tree_policy, TreePolicy::Uct);
        assert!(c.rave.enabled);
        assert!(c.time_budget.is_none());
    }

    /// Builder setters override the corresponding fields.
    #[test]
    fn builder_overrides() {
        let c = SearchConfig::builder()
            .iterations(500)
            .exploration_constant(3.0)
            .max_depth(10)
            .tree_policy(TreePolicy::Puct { prior_weight: 2.0 })
            .heuristic_weight(0.5)
            .build();
        assert_eq!(c.iterations, 500);
        assert_eq!(c.max_depth, 10);
        assert_eq!(c.tree_policy, TreePolicy::Puct { prior_weight: 2.0 });
    }

    /// The time budget set on the builder survives `build`.
    #[test]
    fn builder_time_budget() {
        let c = SearchConfig::builder()
            .iterations(100)
            .time_budget(Duration::from_millis(50))
            .build();
        assert_eq!(c.time_budget, Some(Duration::from_millis(50)));
    }

    /// Invalid values produce warnings and are repaired in place.
    #[test]
    fn sanitize_returns_warnings() {
        let mut c = SearchConfig {
            iterations: 0,
            heuristic_weight: 5.0,
            ..SearchConfig::default()
        };
        let warnings = c.sanitize();
        assert!(warnings.len() >= 2);
        assert!(warnings[0].contains("iterations"));
        // The repaired values, not just the messages, are part of the contract.
        assert_eq!(c.iterations, 10_000);
        assert!((c.heuristic_weight - 1.0).abs() < f64::EPSILON);
    }

    /// A default configuration needs no repairs.
    #[test]
    fn sanitize_valid_config_returns_empty() {
        let mut c = SearchConfig::default();
        let warnings = c.sanitize();
        assert!(warnings.is_empty());
    }

    /// A partial TOML document overrides only the keys it names.
    #[cfg(feature = "toml")]
    #[test]
    fn parse_from_toml() {
        let config = SearchConfig::from_toml_str(
            r#"
iterations = 64
max_depth = 12
[tree_policy]
kind = "thompson_sampling"
temperature = 0.25
"#,
        )
        .unwrap();
        assert_eq!(config.iterations, 64);
        assert_eq!(
            config.tree_policy,
            TreePolicy::ThompsonSampling { temperature: 0.25 }
        );
    }

    /// Progressive-widening settings survive a serialize/deserialize cycle.
    #[cfg(feature = "toml")]
    #[test]
    fn progressive_widening_roundtrip() {
        let config = SearchConfig::builder()
            .progressive_widening(ProgressiveWideningConfig {
                minimum_children: 2,
                coefficient: 1.75,
                exponent: 0.4,
            })
            .build();
        let serialized = toml::to_string(&config).unwrap();
        let parsed: SearchConfig = toml::from_str(&serialized).unwrap();
        let widening = parsed.progressive_widening.unwrap();
        assert_eq!(widening.minimum_children, 2);
        assert!((widening.exponent - 0.4).abs() < f64::EPSILON);
    }

    /// Nested `[rave]` and `[progressive_widening]` tables are honoured.
    #[cfg(feature = "toml")]
    #[test]
    fn parse_toml_with_all_sections() {
        let config = SearchConfig::from_toml_str(
            r"
iterations = 64
max_depth = 7
heuristic_weight = 0.42
[rave]
enabled = false
bias = 111.0
[progressive_widening]
minimum_children = 2
coefficient = 2.5
exponent = 0.4
",
        )
        .unwrap();
        assert!(!config.rave.enabled);
        assert!((config.rave.bias - 111.0).abs() < f64::EPSILON);
        assert_eq!(
            config
                .progressive_widening
                .as_ref()
                .unwrap()
                .minimum_children,
            2
        );
    }

    /// A missing file surfaces as the `Io` variant.
    #[cfg(feature = "toml")]
    #[test]
    fn parse_from_toml_file_error() {
        let err = SearchConfig::from_toml_file("/does/not/exist.toml").unwrap_err();
        assert!(matches!(err, SearchConfigLoadError::Io(_)));
    }

    /// `TreePolicy::default()` is `Uct`.
    #[test]
    fn tree_policy_default_is_uct() {
        let policy: TreePolicy = TreePolicy::default();
        assert!(matches!(policy, TreePolicy::Uct));
    }

    /// A tagged tree policy survives a serialize/deserialize cycle.
    #[cfg(feature = "toml")]
    #[test]
    fn tree_policy_puct_is_round_trip_toml() {
        let config = SearchConfig::builder()
            .tree_policy(TreePolicy::Puct { prior_weight: 0.8 })
            .build();
        let text = toml::to_string(&config).unwrap();
        let loaded: SearchConfig = toml::from_str(&text).unwrap();
        assert_eq!(loaded.tree_policy, TreePolicy::Puct { prior_weight: 0.8 });
    }

    /// A value of the wrong type is rejected instead of being defaulted.
    #[cfg(feature = "toml")]
    #[test]
    fn parse_bad_toml_reports_error() {
        let bad = "max_depth = 'oops'";
        assert!(SearchConfig::from_toml_str(bad).is_err());
    }
}