use grb::prelude::*;
use crate::Sample;
use crate::hypothesis::Classifier;
/// Stopping tolerance for the successive QP solves in `QPModel::update`.
/// Iteration stops once the objective decrease between two consecutive
/// solves falls below this value.
const QP_TOLERANCE: f64 = 0.000_000_001;
/// Gurobi model for the (Taylor-approximated) distribution-update problem.
///
/// Holds the solver model together with handles to its variables and to
/// the per-hypothesis edge constraints added so far.
pub(super) struct QPModel {
    /// Regularization parameter; the regularizer is scaled by `1 / eta`.
    eta: f64,
    /// Underlying Gurobi model.
    model: Model,
    /// Free variable that upper-bounds the edge of every added hypothesis.
    gamma: Var,
    /// Distribution variables, one per example; constrained to sum to one.
    dist: Vec<Var>,
    /// One edge constraint per hypothesis, in insertion order.
    constrs: Vec<Constr>,
}
impl QPModel {
    /// Builds the Gurobi model with `size` distribution variables.
    ///
    /// The model contains the following:
    /// - a free variable `gamma`;
    /// - variables `d[i]` bounded by `0.0` and `upper_bound`;
    /// - the constraint `sum_i d[i] == 1`.
    ///
    /// No hypothesis constraints exist yet. They are added by `update`.
    ///
    /// # Panics
    /// Panics if the Gurobi environment or model cannot be created, or if
    /// adding a variable or constraint fails.
    pub(super) fn init(eta: f64, size: usize, upper_bound: f64) -> Self {
        let mut env = Env::new("").unwrap();
        // Silence solver logging.
        env.set(param::OutputFlag, 0).unwrap();
        let mut model = Model::with_env("MLPBoost", env).unwrap();

        // `gamma` has no bounds in either direction.
        let gamma = add_ctsvar!(model, name: "gamma", bounds: ..).unwrap();

        let dist = (0..size)
            .map(|i| {
                let name = format!("d[{i}]");
                add_ctsvar!(model, name: &name, bounds: 0.0..upper_bound).unwrap()
            })
            .collect::<Vec<Var>>();

        model
            .add_constr("sum_is_1", c!(dist.iter().grb_sum() == 1.0))
            .unwrap();
        model.update().unwrap();

        Self {
            eta,
            model,
            gamma,
            dist,
            constrs: Vec::new(),
        }
    }

    /// Adds the edge constraint for hypothesis `clf` and re-solves. The new
    /// distribution is written into `dist`.
    ///
    /// The added constraint is `sum_i d[i] * y_i * h_i <= gamma`. Here `y_i`
    /// is the `i`-th entry of `sample.target()` and
    /// `h_i = clf.confidence(sample, i)`.
    ///
    /// The objective is `gamma + (1 / eta) * R(d)`. The regularizer `R` is
    /// `sum_i [ln(d0_i) * d[i] + d[i]^2 / (2 * d0_i)]`, where `d0` is the
    /// current `dist`. Up to an additive constant, this is the second-order
    /// Taylor expansion of `sum_i d[i] ln d[i]` around `d0`.
    ///
    /// The QP is re-expanded around each new solution and solved again. The
    /// loop stops when the objective decreases by less than `QP_TOLERANCE`,
    /// or when some weight is non-positive (the expansion is undefined there).
    ///
    /// NOTE(review): `dist` must be strictly positive on entry.
    /// - A previous call may stop because a weight reached zero.
    /// - If that distribution is passed back unchanged, `ln(0)` and
    ///   `0.5 / 0` produce infinite coefficients.
    /// - Confirm the caller guards against this.
    ///
    /// # Panics
    /// Panics on any Gurobi error, or if a solve ends with a status other
    /// than `Optimal` or `SubOptimal`.
    pub(super) fn update<F>(&mut self, sample: &Sample, dist: &mut [f64], clf: &F)
    where
        F: Classifier,
    {
        // Edge of `clf`, expressed in terms of the distribution variables.
        let edge = sample
            .target()
            .into_iter()
            .enumerate()
            .map(|(i, y)| y * clf.confidence(sample, i))
            .zip(self.dist.iter().copied())
            .map(|(yh, d)| d * yh)
            .grb_sum();

        let name = format!("{t}-th hypothesis", t = self.constrs.len());
        let constr = self
            .model
            .add_constr(&name, c!(edge <= self.gamma))
            .unwrap();
        self.constrs.push(constr);
        self.model.update().unwrap();

        // Starting at infinity guarantees the first solve never passes the
        // convergence test.
        let mut old_objval = f64::INFINITY;
        loop {
            // Quadratic model of the entropy term around the current `dist`.
            let regularizer = dist
                .iter()
                .copied()
                .zip(self.dist.iter())
                .map(|(d, &grb_d)| {
                    let l_term = d.ln() * grb_d;
                    let q_term = (0.5_f64 / d) * (grb_d * grb_d);
                    l_term + q_term
                })
                .grb_sum();
            let objective = self.gamma + ((1.0_f64 / self.eta) * regularizer);
            self.model.set_objective(objective, Minimize).unwrap();
            self.model.optimize().unwrap();

            let status = self.model.status().unwrap();
            if status != Status::Optimal && status != Status::SubOptimal {
                panic!("Status ({status:?}) is not optimal.");
            }
            let objval: f64 = self.model.get_attr(attr::ObjVal).unwrap();

            // Copy the solution back into `dist`. Gurobi may return values
            // slightly below the `0.0` bound (within its feasibility
            // tolerance). Treat every non-positive weight as hitting zero:
            // expanding around a negative value would take `ln` of a
            // negative number (NaN).
            let mut any_nonpositive = false;
            for (d, grb_d) in dist.iter_mut().zip(&self.dist) {
                let g: f64 = self.model.get_obj_attr(attr::X, grb_d).unwrap();
                any_nonpositive |= g <= 0.0;
                *d = g;
            }

            if any_nonpositive || old_objval - objval < QP_TOLERANCE {
                break;
            }
            old_objval = objval;
        }
    }

    /// Returns the value of every distribution variable from the most
    /// recent solve.
    ///
    /// # Panics
    /// Panics if Gurobi cannot report attribute `X` (for example, before
    /// any solve).
    pub(super) fn distribution(&self) -> Vec<f64> {
        self.dist
            .iter()
            .map(|d| self.model.get_obj_attr(attr::X, d).unwrap())
            .collect::<Vec<_>>()
    }

    /// Solves the problem with objective `minimize gamma` (no regularizer).
    ///
    /// Returns the absolute dual value (`Pi`) of each hypothesis constraint,
    /// in the order the hypotheses were added.
    ///
    /// This overwrites the model's objective; the next `update` sets its
    /// own objective again.
    ///
    /// # Panics
    /// Panics on any Gurobi error, or if the solve status is not `Optimal`.
    pub(super) fn weight(&mut self) -> impl Iterator<Item = f64> + '_ {
        self.model.set_objective(self.gamma, Minimize).unwrap();
        self.model.update().unwrap();
        self.model.optimize().unwrap();

        let status = self.model.status().unwrap();
        if status != Status::Optimal {
            panic!("Cannot solve the primal problem. Status: {status:?}");
        }

        self.constrs
            .iter()
            .map(|c| self.model.get_obj_attr(attr::Pi, c).unwrap().abs())
    }
}