use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Scalarizer {
pub signal_weights: HashMap<String, f64>,
pub default_weight: f64,
}
impl Scalarizer {
    /// Collapses named signal values into a single weighted sum.
    ///
    /// Each signal's weight is looked up in `signal_weights` by name. Names
    /// without an entry fall back to `default_weight`. A signal contributes
    /// `weight * value` only when that product is finite. This skips:
    /// - NaN or infinite values,
    /// - NaN or infinite weights,
    /// - products that overflow `f64`.
    ///
    /// Skipping these terms keeps them from poisoning the result. An empty
    /// slice yields `0.0`. The running sum itself is not clamped, so very
    /// many large finite terms can still overflow.
    pub fn scalarize(&self, signals: &[(&str, f64)]) -> f64 {
        signals
            .iter()
            .map(|&(name, value)| {
                let weight = self
                    .signal_weights
                    .get(name)
                    .copied()
                    .unwrap_or(self.default_weight);
                weight * value
            })
            // A product is finite only if both factors are finite and it did
            // not overflow, so this one check subsumes separate input checks.
            .filter(|term| term.is_finite())
            .fold(0.0, |acc, term| acc + term)
    }
}
impl Default for Scalarizer {
fn default() -> Self {
Self {
signal_weights: HashMap::new(),
default_weight: 0.0,
}
}
}
/// Tunable parameters for the bandit.
///
/// Fields can be repaired in place with [`BanditConfig::sanitize`]. Configs
/// produced through [`BanditConfig::builder`] are sanitized on `build`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BanditConfig {
    /// Exploration coefficient; `sanitize` requires it to be finite and non-negative.
    pub exploration_constant: f64,
    /// RAVE bias; `sanitize` requires it to be finite and non-negative.
    pub rave_bias: f64,
    /// Pull limit, `0` by default. NOTE(review): the meaning of `0` is decided
    /// by the consumer of this config, which is not visible here; confirm.
    pub max_pulls: u64,
    /// Reward scalarizer; taken from `Scalarizer::default()` when absent from
    /// serialized input.
    #[serde(default)]
    pub scalarizer: Scalarizer,
}
impl Default for BanditConfig {
fn default() -> Self {
Self {
exploration_constant: std::f64::consts::SQRT_2,
rave_bias: 500.0,
max_pulls: 0,
scalarizer: Scalarizer::default(),
}
}
}
impl BanditConfig {
    /// Starts a [`BanditConfigBuilder`] seeded with [`BanditConfig::default`].
    pub fn builder() -> BanditConfigBuilder {
        BanditConfigBuilder(Self::default())
    }

    /// Repairs invalid fields in place and returns one warning per repair.
    ///
    /// - `exploration_constant` and `rave_bias` are reset to their
    ///   [`BanditConfig::default`] values when non-finite or negative.
    /// - `scalarizer.default_weight` is reset to its default when non-finite.
    /// - Non-finite entries in `scalarizer.signal_weights` are removed.
    ///   Negative weights are kept.
    ///
    /// `max_pulls` is not checked. The returned vector is empty when nothing
    /// needed fixing. Warnings for removed signal weights are sorted by signal
    /// name, so the output does not depend on `HashMap` iteration order.
    #[must_use]
    pub fn sanitize(&mut self) -> Vec<String> {
        let default = BanditConfig::default();
        let mut warnings = Vec::new();

        if !self.exploration_constant.is_finite() || self.exploration_constant < 0.0 {
            warnings.push(format!(
                "exploration_constant invalid ({}), resetting to default {}",
                self.exploration_constant, default.exploration_constant
            ));
            self.exploration_constant = default.exploration_constant;
        }
        if !self.rave_bias.is_finite() || self.rave_bias < 0.0 {
            warnings.push(format!(
                "rave_bias invalid ({}), resetting to default {}",
                self.rave_bias, default.rave_bias
            ));
            self.rave_bias = default.rave_bias;
        }
        if !self.scalarizer.default_weight.is_finite() {
            warnings.push(format!(
                "scalarizer.default_weight invalid ({}), resetting to default {}",
                self.scalarizer.default_weight, default.scalarizer.default_weight
            ));
            self.scalarizer.default_weight = default.scalarizer.default_weight;
        }

        // Drop non-finite weights in one pass, remembering what was removed
        // so it can be reported.
        let mut removed: Vec<(String, f64)> = Vec::new();
        self.scalarizer.signal_weights.retain(|name, weight| {
            if weight.is_finite() {
                true
            } else {
                removed.push((name.clone(), *weight));
                false
            }
        });
        // HashMap iteration order is unspecified and varies between runs;
        // sort so the warnings are reproducible.
        removed.sort_unstable_by(|a, b| a.0.cmp(&b.0));
        warnings.extend(removed.into_iter().map(|(name, weight)| {
            format!("scalarizer.signal_weights[{name}] invalid ({weight}), removing")
        }));

        warnings
    }
}
/// Fluent builder for [`BanditConfig`]; obtain one via `BanditConfig::builder()`.
///
/// Each setter consumes and returns the builder, so a discarded return value
/// loses the setting. `#[must_use]` makes the compiler flag that mistake.
#[derive(Clone, Debug)]
#[must_use = "builder methods return the updated builder; call `build` to get a BanditConfig"]
pub struct BanditConfigBuilder(BanditConfig);
impl BanditConfigBuilder {
    /// Sets `exploration_constant`; invalid values are repaired on build.
    pub fn exploration_constant(mut self, c: f64) -> Self {
        self.0.exploration_constant = c;
        self
    }

    /// Sets `rave_bias`; invalid values are repaired on build.
    pub fn rave_bias(mut self, bias: f64) -> Self {
        self.0.rave_bias = bias;
        self
    }

    /// Sets `max_pulls`.
    pub fn max_pulls(mut self, n: u64) -> Self {
        self.0.max_pulls = n;
        self
    }

    /// Replaces the whole scalarizer; its invalid weights are repaired on build.
    pub fn scalarizer(mut self, scalarizer: Scalarizer) -> Self {
        self.0.scalarizer = scalarizer;
        self
    }

    /// Finalizes the config, sanitizing it and discarding any warnings.
    ///
    /// Invalid values are silently replaced by `BanditConfig::sanitize`.
    /// Use [`build_with_warnings`](Self::build_with_warnings) to learn what
    /// was changed.
    #[must_use]
    pub fn build(self) -> BanditConfig {
        self.build_with_warnings().0
    }

    /// Finalizes the config and returns it together with the warnings from
    /// `BanditConfig::sanitize` (empty when every value was valid).
    #[must_use]
    pub fn build_with_warnings(self) -> (BanditConfig, Vec<String>) {
        let mut cfg = self.0;
        let warnings = cfg.sanitize();
        (cfg, warnings)
    }
}