use std::mem;
use crate::{
Sample,
Booster,
WeakLearner,
Classifier,
WeightedMajority,
common::utils,
common::checker,
common::frank_wolfe::{FrankWolfe, FWType},
research::Research,
};
use std::ops::ControlFlow;
/// Corrective ERLPBoost booster.
///
/// Each round it re-weights the training examples with
/// `utils::exp_distribution`, asks a weak learner for a new hypothesis, and
/// updates the combination weights with a Frank-Wolfe step. It stops when the
/// new hypothesis does not improve the edge by more than `half_tolerance`.
pub struct CERLPBoost<'a, F> {
    // Training sample, borrowed for the booster's lifetime.
    sample: &'a Sample,
    // Current distribution over training examples. It is reset to uniform in
    // `preprocess` and recomputed each round by `utils::exp_distribution`.
    dist: Vec<f64>,
    // Regularization parameter, computed as `ln(m / nu) / half_tolerance`.
    eta: f64,
    // Half of the user-supplied tolerance; also the stopping threshold on the
    // edge improvement in `boost`.
    half_tolerance: f64,
    // Capping parameter, labelled "Capping (outliers)" in `info`.
    // It is validated by `checker::check_nu` when set through `nu()`.
    nu: f64,
    // Frank-Wolfe solver that produces the next weight vector each round.
    frank_wolfe: FrankWolfe,
    // Combination weights; `weights[i]` belongs to `hypotheses[i]`.
    weights: Vec<f64>,
    // Distinct hypotheses collected so far (deduplicated via `PartialEq`).
    hypotheses: Vec<F>,
    // Upper bound on the number of rounds.
    // It is `usize::MAX` until `preprocess` sets it from `max_loop`.
    max_iter: usize,
    // Round at which boosting stopped early; equals `max_iter` otherwise.
    terminated: usize,
}
impl<'a, F> CERLPBoost<'a, F> {
    /// Creates a booster over `sample` with default parameters.
    ///
    /// The defaults are:
    /// - tolerance `0.01` (stored as `half_tolerance = 0.005`);
    /// - capping parameter `nu = 1.0`;
    /// - the `FWType::ShortStep` Frank-Wolfe step rule.
    ///
    /// The regularization parameter is `eta = ln(m / nu) / half_tolerance`,
    /// where `m` is the number of examples in `sample`.
    pub fn init(sample: &'a Sample) -> Self {
        let n_sample = sample.shape().0;
        let half_tolerance = 0.005;
        let nu = 1.0;
        let eta = (n_sample as f64 / nu).ln() / half_tolerance;
        let frank_wolfe = FrankWolfe::new(eta, nu, FWType::ShortStep);

        Self {
            sample,
            dist: Vec::new(),
            half_tolerance,
            eta,
            nu,
            frank_wolfe,
            weights: Vec::new(),
            hypotheses: Vec::new(),
            max_iter: usize::MAX,
            terminated: usize::MAX,
        }
    }

    /// Sets the capping parameter `nu` and refreshes `eta` to match.
    ///
    /// `nu` is validated by `checker::check_nu` against the number of
    /// examples. It is also forwarded to the Frank-Wolfe solver.
    pub fn nu(mut self, nu: f64) -> Self {
        let (n_sample, _) = self.sample.shape();
        checker::check_nu(nu, n_sample);
        self.nu = nu;
        self.frank_wolfe.nu(self.nu);
        self.regularization_param();

        self
    }

    /// Sets the tolerance.
    ///
    /// Half of `tolerance` is stored and used both as the stopping threshold
    /// and in the formula for `eta`. `eta` is refreshed here so the builder
    /// state stays consistent.
    pub fn tolerance(mut self, tolerance: f64) -> Self {
        self.half_tolerance = tolerance / 2.0;
        self.regularization_param();
        self
    }

    /// Selects the Frank-Wolfe step rule.
    pub fn fw_type(mut self, fw_type: FWType) -> Self {
        self.frank_wolfe.fw_type(fw_type);
        self
    }

    /// Recomputes `eta = ln(m / nu) / half_tolerance` and forwards it to the
    /// Frank-Wolfe solver.
    ///
    /// `m` is taken from the sample rather than from `self.dist`. `dist` stays
    /// empty until `preprocess`, so using it would yield `ln(0) = -inf` when
    /// `nu()` is called from the builder.
    fn regularization_param(&mut self) {
        let n_sample = self.sample.shape().0 as f64;
        let ln_part = (n_sample / self.nu).ln();
        self.eta = ln_part / self.half_tolerance;
        self.frank_wolfe.eta(self.eta);
    }

    /// Returns the iteration bound `ceil(8 ln(m / nu) / half_tolerance^2)`.
    ///
    /// `m` is the number of examples in the sample. The result is therefore
    /// meaningful before `preprocess` as well.
    pub fn max_loop(&self) -> usize {
        let n_sample = self.sample.shape().0 as f64;
        let ln_m = (n_sample / self.nu).ln();
        let max_iter = 8.0 * ln_m / self.half_tolerance.powi(2);

        max_iter.ceil() as usize
    }

    /// Alias of [`Self::fw_type`]: selects the Frank-Wolfe step rule.
    pub fn variant(self, fw_type: FWType) -> Self {
        self.fw_type(fw_type)
    }
}
impl<F> CERLPBoost<'_, F>
where F: Classifier + PartialEq,
{
    /// Replaces `self.dist` with the output of `utils::exp_distribution`.
    ///
    /// The inputs are the current `eta`, `nu`, sample, and the weighted
    /// hypotheses collected so far.
    fn update_distribution_mut(&mut self) {
        let (eta, nu) = (self.eta, self.nu);
        let weights = self.weights.as_slice();
        let hypotheses = self.hypotheses.as_slice();

        let fresh = utils::exp_distribution(eta, nu, self.sample, weights, hypotheses);
        self.dist = fresh;
    }
}
impl<F> Booster<F> for CERLPBoost<'_, F>
where F: Classifier + Clone + PartialEq + std::fmt::Debug,
{
    type Output = WeightedMajority<F>;

    fn name(&self) -> &str {
        "Corrective ERLPBoost"
    }

    /// Reports the dataset shape and the booster's configuration.
    fn info(&self) -> Option<Vec<(&str, String)>> {
        let (n_sample, n_feature) = self.sample.shape();
        // Share of examples `nu` may treat as outliers. It is printed with a
        // `%` sign, so it must be scaled by 100 (a bare `nu / m` would show
        // 1% as "0.010 %").
        let ratio = 100.0 * self.nu / n_sample as f64;
        let nu = utils::format_unit(self.nu);
        let fw = self.frank_wolfe.current_type();
        let info = vec![
            ("# of examples", format!("{n_sample}")),
            ("# of features", format!("{n_feature}")),
            ("Tolerance", format!("{}", 2f64 * self.half_tolerance)),
            ("Max iteration", format!("{}", self.max_iter)),
            ("Capping (outliers)", format!("{nu} ({ratio: >7.3} %)")),
            ("Frank-Wolfe", format!("{fw}")),
        ];
        Some(info)
    }

    /// Resets the booster state before a run.
    ///
    /// - Calls `is_valid_binary_instance` on the sample; its return value is
    ///   not inspected here.
    /// - Resets `dist` to uniform and recomputes `eta` and `max_iter`.
    /// - Clears the hypotheses and weights collected so far.
    fn preprocess<W>(
        &mut self,
        _weak_learner: &W,
    )
        where W: WeakLearner<Hypothesis = F>
    {
        self.sample.is_valid_binary_instance();
        let n_sample = self.sample.shape().0;
        let uni = 1.0 / n_sample as f64;
        self.dist = vec![uni; n_sample];

        self.regularization_param();
        self.max_iter = self.max_loop();
        self.terminated = self.max_iter;

        self.weights = Vec::new();
        self.hypotheses = Vec::new();
    }

    /// Performs one boosting round.
    ///
    /// Returns:
    /// - `Break(self.max_iter)` once `iteration` exceeds `max_iter`;
    /// - `Break(iteration)` when the new hypothesis improves the edge over the
    ///   current combination by at most `half_tolerance`;
    /// - `Continue(())` after a Frank-Wolfe update of the weights otherwise.
    fn boost<W>(
        &mut self,
        weak_learner: &W,
        iteration: usize,
    ) -> ControlFlow<usize>
        where W: WeakLearner<Hypothesis = F>,
    {
        if self.max_iter < iteration {
            return ControlFlow::Break(self.max_iter);
        }

        self.update_distribution_mut();

        let h = weak_learner.produce(self.sample, &self.dist);

        let new_edge = utils::edge_of_hypothesis(
            self.sample, &self.dist[..], &h
        );
        let old_edge = utils::edge_of_weighted_hypothesis(
            self.sample, &self.dist[..],
            &self.weights[..], &self.hypotheses[..]
        );

        if new_edge - old_edge <= self.half_tolerance {
            self.terminated = iteration;
            return ControlFlow::Break(iteration);
        }

        // Reuse the slot of an identical hypothesis if one exists. Otherwise
        // append `h` with a zero weight so Frank-Wolfe can move mass onto it.
        let existing = self.hypotheses.iter().position(|f| *f == h);
        let pos = match existing {
            Some(pos) => pos,
            None => {
                self.hypotheses.push(h);
                self.weights.push(0.0);
                self.hypotheses.len() - 1
            }
        };

        // `next_iterate` takes the weights by value, so move them out
        // temporarily instead of cloning.
        let weights = mem::take(&mut self.weights);
        self.weights = self.frank_wolfe.next_iterate(
            iteration, self.sample, &self.dist[..],
            &self.hypotheses[..], pos, weights,
        );

        ControlFlow::Continue(())
    }

    /// Builds the final weighted-majority vote from the collected hypotheses.
    fn postprocess<W>(
        &mut self,
        _weak_learner: &W,
    ) -> Self::Output
        where W: WeakLearner<Hypothesis = F>
    {
        WeightedMajority::from_slices(&self.weights[..], &self.hypotheses[..])
    }
}
impl<H> Research for CERLPBoost<'_, H>
where H: Classifier + Clone,
{
    type Output = WeightedMajority<H>;

    /// Snapshot of the combined hypothesis from the weights and hypotheses
    /// gathered so far, without modifying the booster.
    fn current_hypothesis(&self) -> Self::Output {
        let weights = self.weights.as_slice();
        let hypotheses = self.hypotheses.as_slice();
        WeightedMajority::from_slices(weights, hypotheses)
    }
}