use crate::error::OptimizeResult;
use crate::result::OptimizeResults;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::random::{rng, Rng, RngExt};
#[derive(Debug, Clone)]
pub struct BanditOptimizer {
pub num_arms: usize,
pub arm_rewards: Array1<f64>,
pub arm_counts: Array1<usize>,
}
impl BanditOptimizer {
pub fn new(num_arms: usize) -> Self {
Self {
num_arms,
arm_rewards: Array1::zeros(num_arms),
arm_counts: Array1::zeros(num_arms),
}
}
pub fn select_arm(&self) -> usize {
let total_counts: usize = self.arm_counts.sum();
if total_counts == 0 {
return scirs2_core::random::rng().random_range(0..self.num_arms);
}
let mut best_arm = 0;
let mut best_ucb = f64::NEG_INFINITY;
for arm in 0..self.num_arms {
if self.arm_counts[arm] == 0 {
return arm; }
let average_reward = self.arm_rewards[arm] / self.arm_counts[arm] as f64;
let confidence_interval =
(2.0 * (total_counts as f64).ln() / self.arm_counts[arm] as f64).sqrt();
let ucb = average_reward + confidence_interval;
if ucb > best_ucb {
best_ucb = ucb;
best_arm = arm;
}
}
best_arm
}
pub fn update_arm(&mut self, arm: usize, reward: f64) {
if arm < self.num_arms {
self.arm_rewards[arm] += reward;
self.arm_counts[arm] += 1;
}
}
}
#[allow(dead_code)]
pub fn bandit_optimize<F>(
objective: F,
initial_params: &ArrayView1<f64>,
num_nit: usize,
) -> OptimizeResult<OptimizeResults<f64>>
where
F: Fn(&ArrayView1<f64>) -> f64,
{
let mut bandit = BanditOptimizer::new(3); let mut params = initial_params.to_owned();
let mut best_obj = objective(initial_params);
for _iter in 0..num_nit {
let arm = bandit.select_arm();
let step_size = match arm {
0 => 0.01, 1 => 0.1, _ => 0.001, };
for i in 0..params.len() {
params[i] += (scirs2_core::random::rng().random::<f64>() - 0.5) * step_size;
}
let new_obj = objective(¶ms.view());
let reward = if new_obj < best_obj { 1.0 } else { 0.0 };
bandit.update_arm(arm, reward);
if new_obj < best_obj {
best_obj = new_obj;
}
}
Ok(OptimizeResults::<f64> {
x: params,
fun: best_obj,
success: true,
nit: num_nit,
message: "Bandit optimization completed".to_string(),
jac: None,
hess: None,
constr: None,
nfev: num_nit * 3, njev: 0,
nhev: 0,
maxcv: 0,
status: 0,
})
}
#[allow(dead_code)]
pub fn placeholder() {}