use crate::{
Sample,
Classifier,
common::{utils, checker},
};
use std::fmt;
/// Interval width at which the step-size bisection loops in this file stop:
/// they iterate while `ub - lb` is larger than this value.
const SUB_TOLERANCE: f64 = 1e-9;
/// Selects which step-size rule the Frank-Wolfe update applies.
///
/// Besides `Clone`/`Copy`, the type derives `Debug`, `PartialEq` and `Eq` so
/// that a selected rule can be logged and compared.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FWType {
    /// Step size `2 / (iteration + 1)`, independent of the data.
    Classic,
    /// Step size computed from the distribution-weighted margin difference
    /// between the new hypothesis and the current combination, clamped to
    /// `[0, 1]`.
    ShortStep,
    /// Step size found by bisection on `[0, 1]`.
    LineSearch,
    /// Either shifts weight between two hypotheses that already carry
    /// positive weight, or falls back to the line-search rule.
    BlendedPairwise,
}
impl fmt::Display for FWType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let fw = match self {
Self::Classic => "Classic FW",
Self::ShortStep => "Short-step FW",
Self::LineSearch => "Line-search FW",
Self::BlendedPairwise => "Blended Pairwise (BP) FW",
};
write!(f, "{fw}")
}
}
/// Parameters and rule selection for computing the next Frank-Wolfe iterate.
pub(crate) struct FrankWolfe {
    /// Passed to the exponential-distribution helpers and used as the scale
    /// of the denominator in the short-step rule.
    eta: f64,
    /// Passed, together with `eta`, to the exponential-distribution helpers.
    nu: f64,
    /// Step-size rule applied by `next_iterate`.
    fw_type: FWType,
}
impl FrankWolfe {
pub(crate) fn new(eta: f64, nu: f64, fw_type: FWType) -> Self {
Self { eta, nu, fw_type, }
}
/// Replaces the stored `eta` with the given value.
pub(crate) fn eta(&mut self, eta: f64) {
    self.eta = eta;
}
/// Replaces the stored `nu` with the given value.
pub(crate) fn nu(&mut self, nu: f64) {
    self.nu = nu;
}
pub(crate) fn current_type(&self) -> FWType {
self.fw_type
}
/// Switches the step-size rule used by subsequent calls to `next_iterate`.
pub(crate) fn fw_type(&mut self, fw_type: FWType) {
    self.fw_type = fw_type;
}
/// Computes the next weight vector with the currently selected rule.
///
/// `position_of_new_one` is the index in `hypotheses` of the hypothesis the
/// step moves toward, and `weights` is the current weight vector (consumed).
/// `iteration` is read only by the classic rule; `dist` is read only by the
/// short-step and blended pairwise rules.
///
/// # Panics
///
/// Panics if `position_of_new_one` is not a valid index into `hypotheses`.
pub(crate) fn next_iterate<H>(
    &self,
    iteration: usize,
    sample: &Sample,
    dist: &[f64],
    hypotheses: &[H],
    position_of_new_one: usize,
    weights: Vec<f64>,
) -> Vec<f64>
where
    H: Classifier,
{
    let n_hypotheses = hypotheses.len();
    assert!((0..n_hypotheses).contains(&position_of_new_one));

    let target = position_of_new_one;
    match self.fw_type {
        FWType::Classic => self.classic(iteration, target, weights),
        FWType::ShortStep => {
            self.short_step(sample, dist, hypotheses, target, weights)
        }
        FWType::LineSearch => {
            self.line_search(sample, hypotheses, target, weights)
        }
        FWType::BlendedPairwise => {
            self.blended_pairwise(sample, dist, hypotheses, target, weights)
        }
    }
}
/// Classic rule: moves toward the new hypothesis with step size
/// `2 / (iteration + 1)`.
///
/// For `iteration == 0` the step size evaluates to `2.0`; the value is
/// handed to `interior_point`, which passes it to `checker::check_stepsize`.
fn classic(
    &self,
    iteration: usize,
    position_of_new_one: usize,
    weights: Vec<f64>,
) -> Vec<f64> {
    let denominator = (iteration + 1) as f64;
    interior_point(2.0_f64 / denominator, position_of_new_one, weights)
}
/// Short-step rule.
///
/// With `diff_i = margin_i(new hypothesis) - margin_i(current combination)`,
/// the step size is
/// `sum_i dist[i] * diff_i / (eta * max_i |diff_i|^2)`, clamped to `[0, 1]`.
/// The result is the current `weights` moved by that step toward the unit
/// vector at `position_of_new_one`.
///
/// Special cases:
/// * exactly one hypothesis: returns `vec![1.0]`;
/// * every `diff_i` is zero: returns `weights` unchanged.
fn short_step<H>(
    &self,
    sample: &Sample,
    dist: &[f64],
    hypotheses: &[H],
    position_of_new_one: usize,
    weights: Vec<f64>,
) -> Vec<f64>
where
    H: Classifier,
{
    if hypotheses.len() == 1 {
        return vec![1.0];
    }

    let h = &hypotheses[position_of_new_one];
    let old_margins = utils::margins_of_weighted_hypothesis(
        sample, &weights[..], hypotheses,
    );
    let new_margins = utils::margins_of_hypothesis(sample, h);

    let mut numer: f64 = 0.0;
    let mut denom: f64 = f64::MIN;
    for ((new_margin, old_margin), &d) in
        new_margins.into_iter().zip(old_margins).zip(dist)
    {
        let diff = new_margin - old_margin;
        numer += d * diff;
        denom = denom.max(diff.abs());
    }

    // If the new hypothesis has the same margin as the current combination
    // on every example, both `numer` and `denom` are zero and the quotient
    // below would be NaN, which `clamp` passes through unchanged. The
    // direction carries no information in that case, so keep the weights.
    if denom == 0.0 {
        return weights;
    }

    let step = numer / (self.eta * denom.powi(2));
    let step_size = step.clamp(0f64, 1f64);

    interior_point(step_size, position_of_new_one, weights)
}
/// Line-search rule: picks a step size in `[0, 1]` along the segment from
/// the current `weights` to the unit vector at `position_of_new_one`.
///
/// Let `dir` be that direction and `dot(s)` the inner product between the
/// exponential distribution at margins `base_margins + s * dir_margins` and
/// `dir_margins`.
///
/// * If `dot(1) <= 0`, the unit vector at `position_of_new_one` is returned.
/// * Otherwise `[0, 1]` is bisected: `dot(s) < 0` raises the lower bound,
///   `dot(s) > 0` lowers the upper bound, `dot(s) == 0` stops. The search
///   ends once the interval is no wider than `SUB_TOLERANCE`, and the
///   midpoint is used as the step size.
///
/// The end-point probe evaluates the distribution at
/// `base_margins + dir_margins`, i.e. at step size 1, the same quantity the
/// bisection evaluates for intermediate steps. Building it from `dir` used
/// as a weight vector would describe margins `dir_margins` alone, which is
/// not a point on the segment.
fn line_search<H>(
    &self,
    sample: &Sample,
    hypotheses: &[H],
    position_of_new_one: usize,
    weights: Vec<f64>,
) -> Vec<f64>
where
    H: Classifier,
{
    // Direction from the current weights to the target unit vector.
    let dir = weights.iter()
        .enumerate()
        .map(|(j, &w)| {
            if j == position_of_new_one { 1.0 - w } else { -w }
        })
        .collect::<Vec<_>>();

    let dir_margins = utils::margins_of_weighted_hypothesis(
        sample, &dir[..], hypotheses,
    );
    let base_margins = utils::margins_of_weighted_hypothesis(
        sample, &weights[..], hypotheses,
    );

    // Probe the far end of the segment (step size 1).
    let end_margins = base_margins.iter()
        .zip(&dir_margins[..])
        .map(|(&b, &d)| b + d);
    let dist = utils::exp_distribution_from_margins(
        self.eta, self.nu, end_margins,
    );
    let dot = utils::inner_product(&dist[..], &dir_margins[..]);
    if dot <= 0.0 {
        let mut vertex = vec![0.0; weights.len()];
        vertex[position_of_new_one] = 1.0;
        return vertex;
    }

    // Bisection on the step size.
    let mut ub = 1.0;
    let mut lb = 0.0;
    while ub - lb > SUB_TOLERANCE {
        let step_size = (lb + ub) / 2.0;

        let margins = base_margins.iter()
            .zip(&dir_margins[..])
            .map(|(&b, &d)| b + step_size * d);
        let dist = utils::exp_distribution_from_margins(
            self.eta, self.nu, margins,
        );
        let dot = utils::inner_product(&dist[..], &dir_margins[..]);

        if dot < 0.0 {
            lb = step_size;
        } else if dot > 0.0 {
            ub = step_size;
        } else {
            // Neither bound moved, so the midpoint below equals `step_size`.
            break;
        }
    }

    let step_size = (lb + ub) / 2.0;
    interior_point(step_size, position_of_new_one, weights)
}
/// Blended pairwise rule.
///
/// Among the hypotheses whose weight is strictly positive, finds
/// * the one with the smallest edge (`worst`),
/// * the one with the largest edge other than `position_of_new_one`
///   (`local best`), and
/// * the one with the largest edge overall (`global best`).
///
/// If `local_best_edge - worst_edge >= global_best_edge - current_edge`,
/// where `current_edge` is the edge of the weighted combination, weight is
/// shifted along the direction `local best - worst`, by at most the weight
/// currently held by `worst`. Otherwise the line-search rule is applied
/// toward `position_of_new_one`.
///
/// The distribution used to test the largest admissible step is built with
/// `exp_distribution_from_margins` from the per-example margins at that
/// step. Those margins have one entry per example, so they must not be
/// handed to `exp_distribution`, whose slice argument is a per-hypothesis
/// weight vector (that is how `line_search` calls it).
fn blended_pairwise<H>(
    &self,
    sample: &Sample,
    dist: &[f64],
    hypotheses: &[H],
    position_of_new_one: usize,
    mut weights: Vec<f64>,
) -> Vec<f64>
where
    H: Classifier,
{
    // Initial sentinels; presumably chosen outside the range of any edge
    // so the first active hypothesis always replaces them (TODO confirm).
    let mut worst_edge = 2.0;
    let mut local_best_edge = -2.0;
    let mut global_best_edge = -2.0;

    let mut position_of_worst_one = hypotheses.len();
    let mut position_of_local_best_one = hypotheses.len();
    let mut position_of_global_best_one = position_of_new_one;

    // Only hypotheses with strictly positive weight take part in the scan.
    weights.iter()
        .zip(hypotheses)
        .enumerate()
        .filter_map(|(j, (w, h))| {
            if *w <= 0.0 {
                None
            } else {
                let edge = utils::edge_of_hypothesis(sample, dist, h);
                Some((j, edge))
            }
        })
        .for_each(|(j, edge)| {
            if j != position_of_new_one && edge > local_best_edge {
                local_best_edge = edge;
                position_of_local_best_one = j;
            }
            if edge > global_best_edge {
                global_best_edge = edge;
                position_of_global_best_one = j;
            }
            if edge < worst_edge {
                worst_edge = edge;
                position_of_worst_one = j;
            }
        });

    // When the global best is not the last hypothesis, it is also used as
    // the local best.
    // NOTE(review): this compares against the last index, not against
    // `position_of_new_one`; confirm callers always pass the last index.
    if position_of_global_best_one + 1 != hypotheses.len() {
        local_best_edge = global_best_edge;
        position_of_local_best_one = position_of_global_best_one;
    }

    let current_edge = utils::edge_of_weighted_hypothesis(
        sample, dist, &weights[..], hypotheses
    );

    let lhs = local_best_edge - worst_edge;
    let rhs = global_best_edge - current_edge;

    if lhs >= rhs {
        // Pairwise step: at most the whole weight of `worst` can be moved.
        let max_stepsize = weights[position_of_worst_one];

        let local_best_margins = utils::margins_of_hypothesis(
            sample, &hypotheses[position_of_local_best_one]
        );
        let worst_margins = utils::margins_of_hypothesis(
            sample, &hypotheses[position_of_worst_one]
        );
        let dir_margins = local_best_margins.into_iter()
            .zip(worst_margins)
            .map(|(a, b)| a - b)
            .collect::<Vec<_>>();
        let base_margins = utils::margins_of_weighted_hypothesis(
            sample, &weights[..], hypotheses,
        );

        // Distribution at the largest admissible step, computed from the
        // per-example margins `base + max_stepsize * dir`.
        let margins_at_max = base_margins.iter()
            .zip(&dir_margins[..])
            .map(|(&b, &d)| b + max_stepsize * d);
        let dist = utils::exp_distribution_from_margins(
            self.eta, self.nu, margins_at_max,
        );

        if utils::inner_product(&dist[..], &dir_margins[..]) <= 0.0 {
            // NOTE(review): the removed mass is written to
            // `position_of_new_one` (overwriting whatever weight it held),
            // while the partial step below adds to
            // `position_of_local_best_one`. If the new hypothesis already
            // has positive weight the total weight changes. Confirm which
            // index is intended before altering this.
            weights[position_of_new_one] = max_stepsize;
            weights[position_of_worst_one] = 0.0;
            return weights;
        }

        // Bisection on the step size over `[0, max_stepsize]`.
        let mut ub = max_stepsize;
        let mut lb = 0.0;
        while ub - lb > SUB_TOLERANCE {
            let step_size = (lb + ub) / 2.0;

            let margins = base_margins.iter()
                .zip(&dir_margins[..])
                .map(|(&b, &d)| b + step_size * d);
            let dist = utils::exp_distribution_from_margins(
                self.eta, self.nu, margins,
            );
            let dot = utils::inner_product(&dist[..], &dir_margins[..]);

            if dot < 0.0 {
                lb = step_size;
            } else if dot > 0.0 {
                ub = step_size;
            } else {
                break;
            }
        }

        let step_size = (lb + ub) / 2.0;
        weights[position_of_local_best_one] += step_size;
        weights[position_of_worst_one] -= step_size;
        weights
    } else {
        self.line_search(
            sample, hypotheses, position_of_new_one, weights,
        )
    }
}
}
/// Moves `base` toward the unit vector at index `new_basis`.
///
/// Every entry `b` becomes `b + step_size * dir`, where `dir` is `1 - b` at
/// `new_basis` and `-b` elsewhere. `step_size` is first handed to
/// `checker::check_stepsize`.
pub(crate) fn interior_point(
    step_size: f64,
    new_basis: usize,
    mut base: Vec<f64>,
) -> Vec<f64> {
    checker::check_stepsize(step_size);

    for (index, weight) in base.iter_mut().enumerate() {
        let dir = if index == new_basis { 1.0 - *weight } else { -*weight };
        *weight += step_size * dir;
    }
    base
}