use std::path::Path;
/// Node-selection policy used while descending the search tree.
///
/// In TOML this is an internally tagged table: the `kind` key carries the
/// snake_case variant name (`"uct"`, `"puct"`, `"thompson_sampling"`) and the
/// remaining keys carry that variant's parameters, e.g.
///
/// ```toml
/// [tree_policy]
/// kind = "thompson_sampling"
/// temperature = 0.25
/// ```
#[derive(Clone, Debug, Default, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum TreePolicy {
    /// The `uct` policy; this is the variant returned by `TreePolicy::default()`.
    #[default]
    Uct,
    /// The `puct` policy, parameterized by `prior_weight`.
    Puct { prior_weight: f64 },
    /// The `thompson_sampling` policy, parameterized by `temperature`.
    ThompsonSampling { temperature: f64 },
}
/// RAVE settings for the search.
///
/// Because of `#[serde(default)]`, any key omitted when deserializing falls
/// back to the value produced by [`RaveConfig::default`].
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(default)]
pub struct RaveConfig {
    /// Whether RAVE is switched on. Defaults to `true`.
    pub enabled: bool,
    /// RAVE bias parameter. Defaults to `300.0`.
    // NOTE(review): how the bias enters the value estimate is decided outside
    // this file — confirm against the search code before tuning it.
    pub bias: f64,
}

impl Default for RaveConfig {
    /// RAVE enabled with a bias of `300.0`.
    fn default() -> Self {
        const DEFAULT_BIAS: f64 = 300.0;
        Self {
            bias: DEFAULT_BIAS,
            enabled: true,
        }
    }
}
/// Progressive-widening parameters.
///
/// With `#[serde(default)]`, a partially specified table keeps the
/// [`ProgressiveWideningConfig::default`] value for each missing key.
// NOTE(review): the formula combining these three values lives outside this
// file; presumably something like `coefficient * visits^exponent` bounded
// below by `minimum_children` — confirm against the expansion code.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(default)]
pub struct ProgressiveWideningConfig {
    /// Lower bound on the number of children. Defaults to `1`.
    pub minimum_children: usize,
    /// Widening coefficient. Defaults to `1.5`.
    pub coefficient: f64,
    /// Widening exponent. Defaults to `0.5`.
    pub exponent: f64,
}

impl Default for ProgressiveWideningConfig {
    /// `minimum_children = 1`, `coefficient = 1.5`, `exponent = 0.5`.
    fn default() -> Self {
        Self {
            coefficient: 1.5,
            exponent: 0.5,
            minimum_children: 1,
        }
    }
}
/// Top-level configuration for a search run.
///
/// Build one in code with [`SearchConfig::builder`] or load it from TOML.
/// Thanks to `#[serde(default)]`, every key missing from the TOML document
/// keeps the value from [`SearchConfig::default`].
// Field order is significant: serde serializes fields in declaration order.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(default)]
pub struct SearchConfig {
    /// Number of search iterations. Defaults to `10_000`.
    pub iterations: usize,
    /// Exploration constant. Defaults to `sqrt(2)`.
    pub exploration_constant: f64,
    /// Maximum depth. Defaults to `50`.
    pub max_depth: usize,
    /// Tree policy. Defaults to `TreePolicy::default()`.
    pub tree_policy: TreePolicy,
    /// Heuristic weight. Defaults to `0.35`.
    pub heuristic_weight: f64,
    /// RAVE settings. Defaults to `RaveConfig::default()`.
    pub rave: RaveConfig,
    /// Progressive-widening settings; `None` by default.
    // NOTE(review): presumably `None` means widening is off — confirm in the
    // search code.
    pub progressive_widening: Option<ProgressiveWideningConfig>,
}

impl Default for SearchConfig {
    /// 10 000 iterations, exploration constant `sqrt(2)`, depth limit 50,
    /// default tree policy and RAVE settings, heuristic weight 0.35, and no
    /// progressive-widening settings.
    fn default() -> Self {
        let exploration_constant = std::f64::consts::SQRT_2;
        Self {
            iterations: 10_000,
            max_depth: 50,
            exploration_constant,
            heuristic_weight: 0.35,
            tree_policy: TreePolicy::default(),
            rave: RaveConfig::default(),
            progressive_widening: None,
        }
    }
}
impl SearchConfig {
    /// Starts a [`SearchConfigBuilder`] seeded with [`SearchConfig::default`].
    #[must_use]
    pub fn builder() -> SearchConfigBuilder {
        SearchConfigBuilder(Self::default())
    }

    /// Parses a configuration from TOML text.
    ///
    /// Keys absent from `input` keep their default values, so an empty
    /// document yields [`SearchConfig::default`].
    ///
    /// # Errors
    ///
    /// Returns a [`toml::de::Error`] if `input` is not valid TOML or a value
    /// has the wrong type for its field (e.g. a string for `max_depth`).
    pub fn from_toml_str(input: &str) -> Result<Self, toml::de::Error> {
        toml::from_str(input)
    }

    /// Reads the file at `path` and parses it as a TOML configuration.
    ///
    /// # Errors
    ///
    /// Returns [`SearchConfigLoadError::Io`] if the file cannot be read, and
    /// [`SearchConfigLoadError::Toml`] if its contents fail to parse.
    pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self, SearchConfigLoadError> {
        let contents = std::fs::read_to_string(path).map_err(SearchConfigLoadError::Io)?;
        // Share a single parse path with `from_toml_str` so string and file
        // loading can never drift apart.
        Self::from_toml_str(&contents).map_err(SearchConfigLoadError::Toml)
    }
}
/// Fluent builder for [`SearchConfig`].
///
/// Obtain one with [`SearchConfig::builder`] or `SearchConfigBuilder::default()`;
/// the two are equivalent. Each setter overrides one field and returns the
/// builder. [`build`](Self::build) yields the finished configuration. Fields
/// that are never set keep their [`SearchConfig::default`] values.
// `Debug`/`Clone` per the API guidelines for public types; `Default` starts from
// `SearchConfig::default()`, matching `SearchConfig::builder()`.
#[derive(Clone, Debug, Default)]
pub struct SearchConfigBuilder(SearchConfig);

impl SearchConfigBuilder {
    /// Sets [`SearchConfig::iterations`].
    #[must_use]
    pub fn iterations(mut self, iterations: usize) -> Self {
        self.0.iterations = iterations;
        self
    }

    /// Sets [`SearchConfig::exploration_constant`].
    #[must_use]
    pub fn exploration_constant(mut self, exploration_constant: f64) -> Self {
        self.0.exploration_constant = exploration_constant;
        self
    }

    /// Sets [`SearchConfig::max_depth`].
    #[must_use]
    pub fn max_depth(mut self, max_depth: usize) -> Self {
        self.0.max_depth = max_depth;
        self
    }

    /// Sets [`SearchConfig::tree_policy`].
    #[must_use]
    pub fn tree_policy(mut self, tree_policy: TreePolicy) -> Self {
        self.0.tree_policy = tree_policy;
        self
    }

    /// Sets [`SearchConfig::heuristic_weight`].
    #[must_use]
    pub fn heuristic_weight(mut self, heuristic_weight: f64) -> Self {
        self.0.heuristic_weight = heuristic_weight;
        self
    }

    /// Replaces the whole [`SearchConfig::rave`] section.
    #[must_use]
    pub fn rave(mut self, rave: RaveConfig) -> Self {
        self.0.rave = rave;
        self
    }

    /// Stores `Some(widening)` in [`SearchConfig::progressive_widening`].
    #[must_use]
    pub fn progressive_widening(mut self, widening: ProgressiveWideningConfig) -> Self {
        self.0.progressive_widening = Some(widening);
        self
    }

    /// Consumes the builder and returns the configured [`SearchConfig`].
    #[must_use]
    pub fn build(self) -> SearchConfig {
        self.0
    }
}
#[derive(Debug)]
pub enum SearchConfigLoadError {
Io(std::io::Error),
Toml(toml::de::Error),
}
impl std::fmt::Display for SearchConfigLoadError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Io(error) => {
write!(
f,
"failed to read config: {error}. Fix: check file path permissions and ownership, then retry with a readable configuration file."
)
}
Self::Toml(error) => write!(
f,
"failed to parse TOML config: {error}. Fix: validate the section layout and ensure key/value types match the expected schema."
),
}
}
}
impl std::error::Error for SearchConfigLoadError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a per-test temp path so parallel tests never share a file.
    fn temp_config_path(tag: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "search_config_{tag}_{}.toml",
            std::process::id()
        ))
    }

    #[test]
    fn default_config_values() {
        let c = SearchConfig::default();
        assert_eq!(c.iterations, 10_000);
        assert!((c.exploration_constant - std::f64::consts::SQRT_2).abs() < f64::EPSILON);
        assert_eq!(c.max_depth, 50);
        assert_eq!(c.tree_policy, TreePolicy::Uct);
        assert!((c.heuristic_weight - 0.35).abs() < f64::EPSILON);
        assert!(c.rave.enabled);
        assert!((c.rave.bias - 300.0).abs() < f64::EPSILON);
        assert!(c.progressive_widening.is_none());
    }

    #[test]
    fn builder_overrides() {
        let c = SearchConfig::builder()
            .iterations(500)
            .exploration_constant(3.0)
            .max_depth(10)
            .tree_policy(TreePolicy::Puct { prior_weight: 2.0 })
            .heuristic_weight(0.5)
            .build();
        assert_eq!(c.iterations, 500);
        assert!((c.exploration_constant - 3.0).abs() < f64::EPSILON);
        assert_eq!(c.max_depth, 10);
        assert_eq!(c.tree_policy, TreePolicy::Puct { prior_weight: 2.0 });
        assert!((c.heuristic_weight - 0.5).abs() < f64::EPSILON);
        // Fields the builder never touched keep their defaults.
        assert_eq!(c.rave, RaveConfig::default());
        assert!(c.progressive_widening.is_none());
    }

    #[test]
    fn builder_sets_rave() {
        let rave = RaveConfig {
            enabled: false,
            bias: 42.0,
        };
        let c = SearchConfig::builder().rave(rave.clone()).build();
        assert_eq!(c.rave, rave);
    }

    #[test]
    fn parse_from_toml() {
        let config = SearchConfig::from_toml_str(
            r#"
iterations = 64
max_depth = 12
[tree_policy]
kind = "thompson_sampling"
temperature = 0.25
"#,
        )
        .unwrap();
        assert_eq!(config.iterations, 64);
        assert_eq!(config.max_depth, 12);
        assert_eq!(
            config.tree_policy,
            TreePolicy::ThompsonSampling { temperature: 0.25 }
        );
    }

    #[test]
    fn empty_toml_yields_defaults() {
        let config = SearchConfig::from_toml_str("").unwrap();
        assert_eq!(config, SearchConfig::default());
    }

    #[test]
    fn partial_sections_keep_defaults() {
        let config = SearchConfig::from_toml_str(
            r"
[rave]
bias = 5.0
[progressive_widening]
coefficient = 3.0
",
        )
        .unwrap();
        assert_eq!(config.iterations, 10_000);
        assert!(config.rave.enabled);
        assert!((config.rave.bias - 5.0).abs() < f64::EPSILON);
        let widening = config.progressive_widening.unwrap();
        assert_eq!(widening.minimum_children, 1);
        assert!((widening.coefficient - 3.0).abs() < f64::EPSILON);
        assert!((widening.exponent - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn progressive_widening_roundtrip() {
        let config = SearchConfig::builder()
            .progressive_widening(ProgressiveWideningConfig {
                minimum_children: 2,
                coefficient: 1.75,
                exponent: 0.4,
            })
            .build();
        let serialized = toml::to_string(&config).unwrap();
        let parsed: SearchConfig = toml::from_str(&serialized).unwrap();
        let widening = parsed.progressive_widening.unwrap();
        assert_eq!(widening.minimum_children, 2);
        assert!((widening.coefficient - 1.75).abs() < f64::EPSILON);
        assert!((widening.exponent - 0.4).abs() < f64::EPSILON);
    }

    #[test]
    fn parse_toml_with_all_sections() {
        let config = SearchConfig::from_toml_str(
            r"
iterations = 64
max_depth = 7
heuristic_weight = 0.42
[rave]
enabled = false
bias = 111.0
[progressive_widening]
minimum_children = 2
coefficient = 2.5
exponent = 0.4
",
        )
        .unwrap();
        assert_eq!(config.iterations, 64);
        assert_eq!(config.max_depth, 7);
        assert!((config.heuristic_weight - 0.42).abs() < f64::EPSILON);
        // No `[tree_policy]` table, so the default policy is kept.
        assert_eq!(config.tree_policy, TreePolicy::Uct);
        assert!(!config.rave.enabled);
        assert!((config.rave.bias - 111.0).abs() < f64::EPSILON);
        let widening = config.progressive_widening.as_ref().unwrap();
        assert_eq!(widening.minimum_children, 2);
        assert!((widening.coefficient - 2.5).abs() < f64::EPSILON);
        assert!((widening.exponent - 0.4).abs() < f64::EPSILON);
    }

    #[test]
    fn parse_from_toml_file_error() {
        let err = SearchConfig::from_toml_file("/does/not/exist.toml").unwrap_err();
        assert!(err.to_string().starts_with("failed to read config:"));
        assert!(matches!(err, SearchConfigLoadError::Io(_)));
    }

    #[test]
    fn parse_from_toml_file_ok() {
        let path = temp_config_path("ok");
        std::fs::write(&path, "iterations = 9\nmax_depth = 3\n").unwrap();
        let result = SearchConfig::from_toml_file(&path);
        std::fs::remove_file(&path).unwrap();
        let config = result.unwrap();
        assert_eq!(config.iterations, 9);
        assert_eq!(config.max_depth, 3);
    }

    #[test]
    fn parse_from_toml_file_reports_toml_error() {
        let path = temp_config_path("bad");
        std::fs::write(&path, "max_depth = 'oops'").unwrap();
        let result = SearchConfig::from_toml_file(&path);
        std::fs::remove_file(&path).unwrap();
        let err = result.unwrap_err();
        assert!(err.to_string().starts_with("failed to parse TOML config:"));
        assert!(matches!(err, SearchConfigLoadError::Toml(_)));
    }

    #[test]
    fn tree_policy_default_is_uct() {
        let policy: TreePolicy = TreePolicy::default();
        assert!(matches!(policy, TreePolicy::Uct));
    }

    #[test]
    fn tree_policy_puct_is_round_trip_toml() {
        let config = SearchConfig::builder()
            .tree_policy(TreePolicy::Puct { prior_weight: 0.8 })
            .build();
        let text = toml::to_string(&config).unwrap();
        let loaded: SearchConfig = toml::from_str(&text).unwrap();
        assert_eq!(loaded.tree_policy, TreePolicy::Puct { prior_weight: 0.8 });
    }

    #[test]
    fn parse_bad_toml_reports_error() {
        let bad = "max_depth = 'oops'";
        assert!(SearchConfig::from_toml_str(bad).is_err());
    }
}