use crate::{
Booster,
WeakLearner,
Classifier,
NaiveAggregation,
Sample,
research::Research,
};
use std::ops::ControlFlow;
use std::collections::HashSet;
/// State of the "Graph Separation Boosting" booster.
///
/// The booster keeps a graph over the example indices of `sample`: two
/// examples are joined by an edge while their targets differ and no
/// collected hypothesis has yet predicted different values for them.
pub struct GraphSepBoost<'a, F> {
/// Training sample borrowed for the lifetime `'a`.
sample: &'a Sample,
/// Adjacency sets: `edges[i]` holds every index `j` currently joined to `i`.
/// Each undirected edge is stored twice, once in each endpoint's set.
edges: Vec<HashSet<usize>>,
/// Hypotheses obtained from the weak learner so far, in the order produced.
hypotheses: Vec<F>,
/// Sum of the sizes of all adjacency sets (so every undirected edge is
/// counted twice). Starts at `usize::MAX` until `preprocess` computes it.
n_edges: usize,
}
impl<'a, F> GraphSepBoost<'a, F> {
    /// Creates a booster bound to `sample`.
    ///
    /// The edge table and the hypothesis list start out empty and the edge
    /// counter starts at `usize::MAX`; `preprocess` fills them in before the
    /// boosting rounds run.
    #[inline]
    pub fn init(sample: &'a Sample) -> Self {
        let edges = Vec::new();
        let hypotheses = Vec::new();
        let n_edges = usize::MAX;

        Self { sample, edges, hypotheses, n_edges }
    }
}
impl<'a, F> GraphSepBoost<'a, F>
    where F: Classifier
{
    /// Drops every edge whose endpoints are predicted differently by `h`.
    ///
    /// For each example `i`, only the neighbours `j` with
    /// `predictions[i] == predictions[j]` are kept in `edges[i]`. Because each
    /// adjacency set is filtered with the same rule, an edge removed from
    /// `edges[i]` is removed from `edges[j]` as well.
    ///
    /// Only the edges that still exist are visited, so the cost of a round
    /// shrinks together with the graph instead of always scanning every pair
    /// of examples.
    ///
    /// # Panics
    ///
    /// Panics if the prediction vector returned by `h` is too short to be
    /// indexed by a stored example index.
    #[inline]
    fn update_params(&mut self, h: &F) {
        let predictions = h.predict_all(self.sample);

        for (i, neighbours) in self.edges.iter_mut().enumerate() {
            // Keep `j` only while `h` fails to separate `i` from `j`.
            neighbours.retain(|&j| predictions[i] == predictions[j]);
        }
    }
}
impl<F> Booster<F> for GraphSepBoost<'_, F>
    where F: Classifier + Clone,
{
    type Output = NaiveAggregation<F>;

    /// Returns the display name of this boosting algorithm.
    fn name(&self) -> &str {
        "Graph Separation Boosting"
    }

    /// Reports the number of examples and features of the training sample
    /// as `(label, value)` pairs.
    fn info(&self) -> Option<Vec<(&str, String)>> {
        let (n_sample, n_feature) = self.sample.shape();
        Some(vec![
            ("# of examples", n_sample.to_string()),
            ("# of features", n_feature.to_string()),
        ])
    }

    /// Resets the booster before the first round.
    ///
    /// Runs the sample's binary-instance check, joins every pair of examples
    /// whose targets differ, records the total size of the adjacency sets in
    /// `n_edges`, and clears the hypothesis list.
    fn preprocess<W>(
        &mut self,
        _weak_learner: &W,
    )
        where W: WeakLearner<Hypothesis = F>
    {
        self.sample.is_valid_binary_instance();

        let (n_sample, _) = self.sample.shape();
        let target = self.sample.target();

        // Build the graph locally, then install it in one go.
        let mut graph = vec![HashSet::new(); n_sample];
        for i in 0..n_sample {
            for j in (i + 1)..n_sample {
                if target[i] != target[j] {
                    // Store the edge in both endpoints' sets.
                    graph[i].insert(j);
                    graph[j].insert(i);
                }
            }
        }

        self.n_edges = graph.iter().map(HashSet::len).sum();
        self.edges = graph;
        self.hypotheses = Vec::new();
    }

    /// Runs one boosting round.
    ///
    /// * If no edges remain, stops with `Break(iteration)` without calling
    ///   the weak learner.
    /// * Otherwise each example is weighted by its share of the remaining
    ///   adjacency entries, the weak learner produces a hypothesis for that
    ///   distribution, the edges are updated with it, and the hypothesis is
    ///   stored.
    /// * If the edge count did not change, a warning is written to stderr
    ///   and the round stops with `Break(iteration + 1)`; otherwise the new
    ///   count is recorded and boosting continues.
    fn boost<W>(
        &mut self,
        weak_learner: &W,
        iteration: usize,
    ) -> ControlFlow<usize>
        where W: WeakLearner<Hypothesis = F>,
    {
        if self.n_edges == 0 {
            return ControlFlow::Break(iteration);
        }

        // Weight of example `i` = |edges[i]| / (sum of all |edges[k]|).
        let total = self.n_edges as f64;
        let dist: Vec<f64> = self.edges
            .iter()
            .map(|neighbours| neighbours.len() as f64 / total)
            .collect();

        let h = weak_learner.produce(self.sample, &dist);
        self.update_params(&h);
        self.hypotheses.push(h);

        let remaining: usize = self.edges.iter().map(HashSet::len).sum();
        if remaining == self.n_edges {
            eprintln!("[WARN] number of edges does not decrease.");
            return ControlFlow::Break(iteration+1);
        }

        self.n_edges = remaining;
        ControlFlow::Continue(())
    }

    /// Moves the collected hypotheses out of the booster (leaving its list
    /// empty) and wraps them, together with the sample, in a
    /// `NaiveAggregation`.
    fn postprocess<W>(
        &mut self,
        _weak_learner: &W,
    ) -> Self::Output
        where W: WeakLearner<Hypothesis = F>
    {
        NaiveAggregation::new(std::mem::take(&mut self.hypotheses), &self.sample)
    }
}
impl<H> Research for GraphSepBoost<'_, H>
    where H: Classifier + Clone,
{
    type Output = NaiveAggregation<H>;

    /// Builds a `NaiveAggregation` from the hypotheses collected so far,
    /// leaving the booster's own list untouched.
    fn current_hypothesis(&self) -> Self::Output {
        let collected = &self.hypotheses;
        let sample = &self.sample;
        NaiveAggregation::from_slice(collected, sample)
    }
}