use rayon::prelude::*;
use crate::{
Sample,
Booster,
WeakLearner,
Classifier,
WeightedMajority,
research::Research,
};
use std::ops::ControlFlow;
/// State of the SmoothBoost boosting procedure over a borrowed training sample.
///
/// `F` is the type of the hypotheses returned by the weak learner; every
/// hypothesis obtained during boosting is kept in `hypotheses`.
pub struct SmoothBoost<'a, F> {
/// Training sample, borrowed for the lifetime of the booster.
sample: &'a Sample,
/// Tolerance parameter. It is checked to lie in `(0, 1)` before boosting,
/// and it appears in both the stopping test and the iteration bound.
kappa: f64,
/// Threshold derived from `gamma` as `gamma / (2 + gamma)`. It is subtracted
/// from each example's margin when the accumulators in `n` are updated.
theta: f64,
/// Parameter reported as "Gamma (WL guarantee)" in `info`. It is checked to
/// lie in `[theta, 0.5)` before boosting.
gamma: f64,
/// Number of examples, i.e. the first component of `sample.shape()`.
n_sample: usize,
/// Iteration index most recently passed to `boost`.
current: usize,
/// Denominator of the uniform hypothesis weight `1 / terminated`.
/// Starts at `usize::MAX`, is reset to `max_iter` in `preprocess`, and is
/// set to `iteration - 1` when boosting stops early.
terminated: usize,
/// Upper bound on the number of iterations, computed in `preprocess`.
max_iter: usize,
/// Hypotheses returned by the weak learner, in the order they were produced.
hypotheses: Vec<F>,
/// Per-example weights. They are reset to `1.0` in `preprocess` and divided
/// by their sum to form the distribution handed to the weak learner.
m: Vec<f64>,
/// Per-example running sum of `margin - theta`, reset to `1.0` in
/// `preprocess`. The weights in `m` are recomputed from these values.
n: Vec<f64>,
}
impl<'a, F> SmoothBoost<'a, F> {
    /// Creates a booster over `sample` with default parameters.
    ///
    /// Defaults are `kappa = 0.5` and `gamma = 0.5`, with `theta` derived as
    /// `gamma / (2 + gamma)`. The iteration limits start at `usize::MAX` and
    /// the working buffers start empty; both are re-initialised when boosting
    /// is prepared.
    ///
    /// Note that the default `gamma` of `0.5` lies outside the half-open range
    /// `[theta, 0.5)` accepted by `check_preconditions`, so a smaller value has
    /// to be supplied through `gamma` before boosting.
    pub fn init(sample: &'a Sample) -> Self {
        let n_sample = sample.shape().0;
        let gamma = 0.5;

        Self {
            sample,

            kappa: 0.5,
            theta: gamma / (2.0 + gamma),
            gamma,

            n_sample,

            current: 0_usize,
            terminated: usize::MAX,
            max_iter: usize::MAX,

            hypotheses: Vec::new(),

            m: Vec::new(),
            n: Vec::new(),
        }
    }

    /// Sets the tolerance parameter `kappa` and returns the updated booster.
    ///
    /// The value is not validated here; `check_preconditions` later requires
    /// it to lie in `(0, 1)`.
    #[inline]
    pub fn tolerance(mut self, kappa: f64) -> Self {
        self.kappa = kappa;
        self
    }

    /// Sets the parameter `gamma` and returns the updated booster.
    ///
    /// The derived threshold `theta` is recomputed immediately so that the
    /// two fields never disagree.
    ///
    /// # Panics
    ///
    /// Panics if `gamma` is not in `[0.0, 0.5)`.
    #[inline]
    pub fn gamma(mut self, gamma: f64) -> Self {
        assert!(
            (0.0..0.5).contains(&gamma),
            "The parameter `gamma` must be in [0.0, 0.5), got {}",
            gamma,
        );
        self.gamma = gamma;
        // `theta` is a function of `gamma`; refresh it here instead of leaving
        // the value that belonged to the previous `gamma`.
        self.theta();
        self
    }

    /// Recomputes `theta` from the current `gamma` as `gamma / (2 + gamma)`.
    fn theta(&mut self) {
        self.theta = self.gamma / (2.0 + self.gamma);
    }

    /// Returns the iteration bound
    /// `ceil(2 / (kappa * gamma^2 * sqrt(1 - gamma)))`.
    fn max_loop(&self) -> usize {
        let denom = self.kappa
            * self.gamma.powi(2)
            * (1.0 - self.gamma).sqrt();

        (2.0 / denom).ceil() as usize
    }

    /// Verifies that `kappa` and `gamma` are in their admissible ranges.
    ///
    /// # Panics
    ///
    /// Panics if `kappa` is not in `(0, 1)` or if `gamma` is not in
    /// `[theta, 0.5)`. A NaN parameter fails both checks.
    fn check_preconditions(&self) {
        // Written as a positive test so that NaN is rejected as well.
        let kappa_is_valid = self.kappa > 0.0 && self.kappa < 1.0;
        if !kappa_is_valid {
            panic!(
                "Invalid kappa. \
                 The parameter `kappa` must be in (0.0, 1.0)"
            );
        }

        if !(self.theta..0.5).contains(&self.gamma) {
            panic!(
                "Invalid gamma. \
                 The parameter `gamma` must be in [self.theta, 0.5)"
            );
        }
    }
}
impl<F> Booster<F> for SmoothBoost<'_, F>
    where F: Classifier + Clone,
{
    type Output = WeightedMajority<F>;

    /// Returns the name of this boosting algorithm.
    fn name(&self) -> &str {
        "SmoothBoost"
    }

    /// Returns labelled descriptions of the sample size and the current
    /// parameter values.
    fn info(&self) -> Option<Vec<(&str, String)>> {
        let (n_sample, n_feature) = self.sample.shape();
        let info = Vec::from([
            ("# of examples", n_sample.to_string()),
            ("# of features", n_feature.to_string()),
            ("Tolerance (Kappa)", self.kappa.to_string()),
            ("Max iteration", self.max_iter.to_string()),
            ("Theta", self.theta.to_string()),
            ("Gamma (WL guarantee)", self.gamma.to_string()),
        ]);
        Some(info)
    }

    /// Resets all boosting state before the first call to `boost`.
    ///
    /// Recomputes `theta`, validates `kappa` and `gamma` (panicking if they
    /// are out of range), computes the iteration bound, clears the stored
    /// hypotheses, and sets every entry of `m` and `n` to `1.0`.
    fn preprocess<W>(
        &mut self,
        _weak_learner: &W,
    )
        where W: WeakLearner<Hypothesis = F>
    {
        // NOTE(review): presumably validates that the sample is a binary
        // classification instance; the method is defined elsewhere.
        self.sample.is_valid_binary_instance();

        self.n_sample = self.sample.shape().0;

        // `theta` must be refreshed before the precondition check, which
        // compares `gamma` against it.
        self.theta();
        self.check_preconditions();

        self.current = 0_usize;
        self.max_iter = self.max_loop();
        self.terminated = self.max_iter;

        self.hypotheses = Vec::new();

        self.m = vec![1.0; self.n_sample];
        self.n = vec![1.0; self.n_sample];
    }

    /// Runs one boosting round.
    ///
    /// Returns `ControlFlow::Break(max_iter)` when `iteration` exceeds the
    /// iteration bound, `ControlFlow::Break(iteration)` when the total weight
    /// falls below `n_sample * kappa`, and `ControlFlow::Continue(())` after a
    /// completed round.
    fn boost<W>(
        &mut self,
        weak_learner: &W,
        iteration: usize,
    ) -> ControlFlow<usize>
        where W: WeakLearner<Hypothesis = F>
    {
        if self.max_iter < iteration {
            return ControlFlow::Break(self.max_iter);
        }

        self.current = iteration;

        // Stop once the total weight is below the tolerance threshold.
        let sum = self.m.iter().sum::<f64>();
        if sum < self.n_sample as f64 * self.kappa {
            self.terminated = iteration - 1;
            return ControlFlow::Break(iteration);
        }

        // Normalise the weights into the distribution for the weak learner.
        let dist = self.m.iter()
            .map(|mj| *mj / sum)
            .collect::<Vec<_>>();

        self.hypotheses.push(
            weak_learner.produce(self.sample, &dist[..])
        );

        let h: &F = self.hypotheses.last().unwrap();

        // Margin of the new hypothesis on every example.
        let target = self.sample.target();
        let margins = target.iter()
            .enumerate()
            .map(|(i, y)| y * h.confidence(self.sample, i));

        // Accumulate `margin - theta`. The evaluation order is kept as
        // `(n + margin) - theta` so floating-point results are unchanged.
        self.n.iter_mut()
            .zip(margins)
            .for_each(|(nj, yh)| {
                *nj = *nj + yh - self.theta;
            });

        // Weight is 1 for non-positive accumulators and decays as
        // `(1 - gamma)^(n / 2)` otherwise.
        self.m.par_iter_mut()
            .zip(&self.n[..])
            .for_each(|(mj, nj)| {
                if *nj <= 0.0 {
                    *mj = 1.0;
                } else {
                    *mj = (1.0 - self.gamma).powf(*nj * 0.5);
                }
            });

        ControlFlow::Continue(())
    }

    /// Builds the final combined hypothesis.
    ///
    /// Every stored hypothesis receives the same weight `1 / terminated`.
    fn postprocess<W>(
        &mut self,
        _weak_learner: &W,
    ) -> Self::Output
        where W: WeakLearner<Hypothesis = F>
    {
        let weight = 1.0 / self.terminated as f64;
        // One weight per hypothesis. Sizing this by the number of examples
        // would mismatch the two slices whenever the counts differ.
        let weights = vec![weight; self.hypotheses.len()];
        WeightedMajority::from_slices(&weights[..], &self.hypotheses[..])
    }
}
impl<H> Research for SmoothBoost<'_, H>
    where H: Classifier + Clone,
{
    type Output = WeightedMajority<H>;

    /// Returns the combination of the hypotheses collected so far.
    ///
    /// Every stored hypothesis receives the same weight `1 / terminated`.
    fn current_hypothesis(&self) -> Self::Output {
        let weight = 1.0 / self.terminated as f64;
        // One weight per hypothesis. Sizing this by the number of examples
        // would mismatch the two slices whenever the counts differ.
        let weights = vec![weight; self.hypotheses.len()];
        WeightedMajority::from_slices(&weights[..], &self.hypotheses[..])
    }
}