mctrust 0.4.0

Universal search & planning toolkit — MCTS, bandit search, pluggable evaluators, tree reuse, DAG transpositions, root parallelism. Define an Environment, search handles the rest.
Documentation
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Scalarizer for named signals supplied at observation time.
///
/// The scalarizer maps signal names to weights and collapses a list of
/// `(name, value)` signals into a single reward score (see [`Scalarizer::scalarize`]):
///
/// `sum(weight(name) * value)` over the provided signals, where `weight(name)` is
/// `signal_weights[name]` when that name is configured and `default_weight` otherwise.
/// A signal is skipped when either its value or its resolved weight is not finite
/// (NaN or ±infinity).
///
/// The [`Default`] scalarizer has no configured weights and a `default_weight` of
/// `0.0`, so every signal contributes zero until weights are configured.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Scalarizer {
    /// Weights by signal name. An entry here takes precedence over
    /// [`Scalarizer::default_weight`] for that signal.
    pub signal_weights: HashMap<String, f64>,

    /// Weight to apply when a provided signal is not explicitly configured.
    pub default_weight: f64,
}

impl Scalarizer {
    /// Collapses named signals into one reward score.
    ///
    /// Each `(name, value)` pair contributes `weight * value`. Here `weight` is the entry
    /// for `name` in [`Scalarizer::signal_weights`], or [`Scalarizer::default_weight`]
    /// when no entry exists. Contributions are added left to right onto a running
    /// total that starts at `0.0`.
    ///
    /// # Parameters
    ///
    /// - `signals`: Signal name/value pairs. A name that appears more than once
    ///   contributes once per occurrence.
    ///
    /// # Returns
    ///
    /// Returns the weighted sum. A pair is ignored when its value is not finite or when
    /// its resolved weight is not finite. An empty or fully ignored input yields `0.0`.
    pub fn scalarize(&self, signals: &[(&str, f64)]) -> f64 {
        let mut total = 0.0;

        for &(name, value) in signals {
            // NaN / infinite observations would poison the whole sum; drop them.
            if !value.is_finite() {
                continue;
            }

            let weight = match self.signal_weights.get(name) {
                Some(&configured) => configured,
                None => self.default_weight,
            };
            // Same protection for a non-finite weight.
            if !weight.is_finite() {
                continue;
            }

            total += weight * value;
        }

        total
    }
}

impl Default for Scalarizer {
    fn default() -> Self {
        Self {
            signal_weights: HashMap::new(),
            default_weight: 0.0,
        }
    }
}

/// Configuration for [`crate::BanditSearch`].
///
/// Configurations produced by [`BanditConfig::builder`] are passed through
/// [`BanditConfig::sanitize`], which replaces invalid numeric fields with defaults.
/// When deserializing, only `scalarizer` may be omitted (it falls back to
/// [`Scalarizer::default`]); the other fields are required.
///
/// # Examples
///
/// ```
/// use mctrust::BanditConfig;
///
/// let config = BanditConfig::builder()
///     .exploration_constant(2.0)
///     .rave_bias(300.0)
///     .max_pulls(10_000)
///     .build();
/// ```
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct BanditConfig {
    /// UCT exploration constant. Default: `sqrt(2)`.
    ///
    /// Expected to be finite and non-negative; [`BanditConfig::sanitize`] resets any
    /// other value to the default.
    pub exploration_constant: f64,

    /// RAVE bias weight. Controls cross-arm reward propagation.
    ///
    /// When an arm in group A receives a high reward, all unvisited arms
    /// in other groups get a proportional RAVE boost. The bias decays
    /// as visit counts increase.
    ///
    /// - `0.0` = RAVE disabled (pure UCT).
    /// - `500.0` = strong RAVE influence early, decays with visits.
    ///
    /// Default: `500.0`. Expected to be finite and non-negative;
    /// [`BanditConfig::sanitize`] resets any other value to the default.
    pub rave_bias: f64,

    /// Pull budget: the maximum number of arm pulls before stopping. `0` = unlimited.
    pub max_pulls: u64,

    /// Scalarizer used when observing vector signals.
    ///
    /// Falls back to [`Scalarizer::default`] when absent from serialized input.
    #[serde(default)]
    pub scalarizer: Scalarizer,
}

impl Default for BanditConfig {
    fn default() -> Self {
        Self {
            exploration_constant: std::f64::consts::SQRT_2,
            rave_bias: 500.0,
            max_pulls: 0,
            scalarizer: Scalarizer::default(),
        }
    }
}

impl BanditConfig {
    /// Creates a builder initialized with default bandit-search settings.
    ///
    /// # Parameters
    ///
    /// This function takes no additional parameters.
    ///
    /// # Returns
    ///
    /// Returns a `BanditConfigBuilder` seeded from [`BanditConfig::default`].
    ///
    /// # Panics
    ///
    /// This function does not panic.
    pub fn builder() -> BanditConfigBuilder {
        BanditConfigBuilder(Self::default())
    }

    /// Validates and corrects invalid configuration fields in place.
    ///
    /// Repairs applied:
    ///
    /// - `exploration_constant` and `rave_bias` that are non-finite or negative are
    ///   reset to their [`BanditConfig::default`] values.
    /// - A non-finite `scalarizer.default_weight` is reset to the default scalarizer's
    ///   value.
    /// - Non-finite entries in `scalarizer.signal_weights` are removed.
    ///
    /// `max_pulls` is not checked; every `u64` value is accepted.
    ///
    /// # Returns
    ///
    /// Returns human-readable warnings describing what was fixed, in the order listed
    /// above. Removed signal weights are reported sorted by signal name, so the list
    /// is reproducible across runs. An empty list means the configuration was already
    /// valid.
    ///
    /// # Panics
    ///
    /// This function does not panic.
    #[must_use]
    pub fn sanitize(&mut self) -> Vec<String> {
        let default = BanditConfig::default();
        let mut warnings = Vec::new();

        if !self.exploration_constant.is_finite() || self.exploration_constant < 0.0 {
            warnings.push(format!(
                "exploration_constant invalid ({}), resetting to default {}",
                self.exploration_constant, default.exploration_constant
            ));
            self.exploration_constant = default.exploration_constant;
        }

        if !self.rave_bias.is_finite() || self.rave_bias < 0.0 {
            warnings.push(format!(
                "rave_bias invalid ({}), resetting to default {}",
                self.rave_bias, default.rave_bias
            ));
            self.rave_bias = default.rave_bias;
        }

        if !self.scalarizer.default_weight.is_finite() {
            warnings.push(format!(
                "scalarizer.default_weight invalid ({}), resetting to default {}",
                self.scalarizer.default_weight, default.scalarizer.default_weight
            ));
            self.scalarizer.default_weight = default.scalarizer.default_weight;
        }

        // Drop non-finite weights in a single pass, remembering what was removed.
        let mut removed: Vec<(String, f64)> = Vec::new();
        self.scalarizer.signal_weights.retain(|name, weight| {
            let keep = weight.is_finite();
            if !keep {
                removed.push((name.clone(), *weight));
            }
            keep
        });
        // `HashMap` iteration order is unspecified and varies between runs
        // (randomly seeded hasher); sort so the warnings come out in a stable order.
        removed.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
        warnings.extend(removed.into_iter().map(|(name, weight)| {
            format!("scalarizer.signal_weights[{name}] invalid ({weight}), removing")
        }));

        warnings
    }
}

/// Fluent builder for [`BanditConfig`].
///
/// Create one with [`BanditConfig::builder`], chain setters, then call
/// [`BanditConfigBuilder::build`], which sanitizes the accumulated configuration.
#[derive(Clone, Debug)]
#[must_use = "a builder does nothing until `build` is called"]
pub struct BanditConfigBuilder(BanditConfig);

impl BanditConfigBuilder {
    /// Sets the UCT exploration constant.
    ///
    /// The value is stored as given; validation is deferred to
    /// [`BanditConfigBuilder::build`].
    ///
    /// # Parameters
    ///
    /// - `c`: Exploration multiplier to store in the builder.
    ///
    /// # Returns
    ///
    /// Returns the updated builder.
    ///
    /// # Panics
    ///
    /// This function does not panic.
    pub fn exploration_constant(mut self, c: f64) -> Self {
        self.0.exploration_constant = c;
        self
    }

    /// Sets the RAVE bias weight.
    ///
    /// The value is stored as given; validation is deferred to
    /// [`BanditConfigBuilder::build`].
    ///
    /// # Parameters
    ///
    /// - `bias`: RAVE bias value to store in the builder.
    ///
    /// # Returns
    ///
    /// Returns the updated builder.
    ///
    /// # Panics
    ///
    /// This function does not panic.
    pub fn rave_bias(mut self, bias: f64) -> Self {
        self.0.rave_bias = bias;
        self
    }

    /// Sets the maximum number of arm pulls.
    ///
    /// # Parameters
    ///
    /// - `n`: Pull budget to enforce, or `0` for unlimited.
    ///
    /// # Returns
    ///
    /// Returns the updated builder.
    ///
    /// # Panics
    ///
    /// This function does not panic.
    pub fn max_pulls(mut self, n: u64) -> Self {
        self.0.max_pulls = n;
        self
    }

    /// Replaces the configured signal scalarizer.
    ///
    /// # Parameters
    ///
    /// - `scalarizer`: Scalarizer used by
    ///   [`crate::BanditSearch::observe_with_signals`].
    ///
    /// # Returns
    ///
    /// Returns the updated builder.
    ///
    /// # Panics
    ///
    /// This function does not panic.
    pub fn scalarizer(mut self, scalarizer: Scalarizer) -> Self {
        self.0.scalarizer = scalarizer;
        self
    }

    /// Finalizes the builder and returns the accumulated [`BanditConfig`].
    ///
    /// The configuration is first passed through [`BanditConfig::sanitize`]. Invalid
    /// values are therefore replaced by defaults (e.g. a negative or non-finite
    /// exploration constant), and non-finite signal weights are dropped.
    /// The sanitization warnings are discarded. To inspect them, construct a
    /// [`BanditConfig`] directly and call [`BanditConfig::sanitize`] yourself.
    ///
    /// # Parameters
    ///
    /// This function takes no additional parameters.
    ///
    /// # Returns
    ///
    /// Returns the built, sanitized [`BanditConfig`].
    ///
    /// # Panics
    ///
    /// This function does not panic.
    #[must_use]
    pub fn build(self) -> BanditConfig {
        let mut cfg = self.0;
        // Best-effort repair: the builder API has no channel to report warnings.
        let _warnings = cfg.sanitize();
        cfg
    }
}