use crate::core::barrier::LogBarrier;
use crate::core::constraint::LinearInequalityConstraints;
use crate::core::executor::run_loop;
use crate::core::inner::WarmStart;
use crate::core::math::{
MatTransposeVec, MatVec, NegInPlace, NormSquared, ScaledAdd, VectorIndex, VectorLen,
};
use crate::core::problem::{CostFunction, Gradient, Problem};
use crate::core::solver::Solver;
use crate::core::state::{BasicState, CountsMirror, GradientState, State};
use crate::core::termination::{
GradientTolerance, MaxIter, TerminationCriterion, TerminationReason,
};
/// Log-barrier (interior-point) method for problems with linear inequality
/// constraints of the form `A x <= b`.
///
/// Each outer iteration does the following:
/// - builds a [`LogBarrier`] subproblem from the original problem and the
///   current barrier parameter `mu`;
/// - solves it with `inner_solver`, seeded from the current outer iterate;
/// - adopts the result and divides `mu` by `reduction`.
///
/// The method reports convergence once the gap estimate `m * mu` is at most
/// `tol`, where `m` is the length of `b`.
///
/// The starting point must be strictly feasible (`b - A x > 0` componentwise).
/// Otherwise the first outer iteration returns `SolverFailed`.
pub struct BarrierMethod<So> {
    /// Solver used for every barrier subproblem.
    inner_solver: So,
    /// Iteration cap for each inner solve.
    inner_max_iter: u64,
    /// Tolerance handed to the inner `GradientTolerance` criterion.
    inner_grad_tol: f64,
    /// Initial barrier parameter; `mu` is reset to this by `init`.
    mu0: f64,
    /// Barrier parameter used by the next outer iteration.
    mu: f64,
    /// Factor `mu` is divided by after each outer iteration (always > 1).
    reduction: f64,
    /// Convergence threshold on `gap`.
    tol: f64,
    /// Gap estimate `m * mu` from the most recent completed outer iteration
    /// (infinite before the first one).
    gap: f64,
    /// Set by `init` when the start point is not strictly feasible.
    infeasible: bool,
}

impl<So> BarrierMethod<So> {
    /// Creates a barrier method that drives `inner_solver`.
    ///
    /// Defaults: `mu0 = 1`, `reduction = 10`, `tol = 1e-8`,
    /// `inner_max_iter = 50`, `inner_grad_tol = 1e-8`.
    #[must_use]
    pub fn new(inner_solver: So) -> Self {
        Self {
            inner_solver,
            inner_max_iter: 50,
            inner_grad_tol: 1e-8,
            mu0: 1.0,
            mu: 1.0,
            reduction: 10.0,
            tol: 1e-8,
            gap: f64::INFINITY,
            infeasible: false,
        }
    }

    /// Sets the initial barrier parameter.
    ///
    /// # Panics
    /// Panics unless `mu0` is positive (this also rejects NaN) and finite.
    /// An infinite `mu0` stays infinite under repeated division, so the gap
    /// estimate could never reach `tol`.
    #[must_use = "builder methods consume the solver and return the updated one"]
    pub fn mu0(mut self, mu0: f64) -> Self {
        assert!(mu0 > 0.0, "mu0 must be > 0");
        assert!(mu0.is_finite(), "mu0 must be finite");
        self.mu0 = mu0;
        self
    }

    /// Sets the factor `mu` is divided by after every outer iteration.
    ///
    /// # Panics
    /// Panics unless `reduction` is greater than 1 (this also rejects NaN) and
    /// finite. An infinite `reduction` would set `mu` to exactly 0 after the
    /// first outer iteration. The next subproblem would then be built with
    /// `mu = 0`, and its gap of 0 would be reported as convergence.
    #[must_use = "builder methods consume the solver and return the updated one"]
    pub fn reduction(mut self, reduction: f64) -> Self {
        assert!(reduction > 1.0, "reduction must be > 1");
        assert!(reduction.is_finite(), "reduction must be finite");
        self.reduction = reduction;
        self
    }

    /// Sets the convergence threshold on the gap estimate `m * mu`.
    ///
    /// # Panics
    /// Panics unless `tol > 0` (this also rejects NaN).
    #[must_use = "builder methods consume the solver and return the updated one"]
    pub fn tol(mut self, tol: f64) -> Self {
        assert!(tol > 0.0, "tol must be > 0");
        self.tol = tol;
        self
    }

    /// Sets the iteration cap used for every inner solve.
    ///
    /// # Panics
    /// Panics if `inner_max_iter` is 0.
    #[must_use = "builder methods consume the solver and return the updated one"]
    pub fn inner_max_iter(mut self, inner_max_iter: u64) -> Self {
        assert!(inner_max_iter >= 1, "inner_max_iter must be ≥ 1");
        self.inner_max_iter = inner_max_iter;
        self
    }

    /// Sets the tolerance passed to the inner `GradientTolerance` criterion.
    ///
    /// # Panics
    /// Panics unless `inner_grad_tol >= 0` (this also rejects NaN).
    #[must_use = "builder methods consume the solver and return the updated one"]
    pub fn inner_grad_tol(mut self, inner_grad_tol: f64) -> Self {
        assert!(inner_grad_tol >= 0.0, "inner_grad_tol must be ≥ 0");
        self.inner_grad_tol = inner_grad_tol;
        self
    }
}
impl<P, V, M, So> Solver<P, BasicState<V>> for BarrierMethod<So>
where
    P: CostFunction<Param = V, Output = f64>
        + Gradient<Gradient = V>
        + LinearInequalityConstraints<Param = V, Matrix = M>,
    M: MatVec<V> + MatTransposeVec<V>,
    V: ScaledAdd<f64> + NegInPlace + VectorIndex + VectorLen + NormSquared + Clone,
    So: WarmStart<V>
        + for<'a> Solver<LogBarrier<'a, P>, So::State, Error = <P as CostFunction>::Error>,
    So::State: GradientState<Param = V> + CountsMirror,
{
    type Error = <P as CostFunction>::Error;

    /// Resets the barrier schedule and checks strict feasibility of the start
    /// point. Then evaluates the original cost and gradient there.
    ///
    /// Infeasibility is not returned as an error here. It is recorded in
    /// `self.infeasible`, and the first `next_iter` reports it as `SolverFailed`.
    ///
    /// # Errors
    /// Propagates any error from evaluating the cost/gradient at the start point.
    fn init(
        &mut self,
        problem: &mut Problem<P>,
        mut state: BasicState<V>,
    ) -> Result<BasicState<V>, Self::Error> {
        self.mu = self.mu0;
        self.gap = f64::INFINITY;

        // slack = b - A x; the method requires every entry to be strictly positive.
        let mut slack = problem.inner().a().matvec(state.param());
        slack.neg_in_place();
        slack.scaled_add(1.0, problem.inner().b());

        // Ask "is every slack > 0?" rather than "is some slack <= 0?". A NaN slack
        // (e.g. from a NaN in the start point) fails every comparison. The old form
        // therefore accepted it as feasible; this form rejects it.
        self.infeasible = !(0..slack.vec_len()).all(|i| slack.get_scalar(i) > 0.0);

        let (cost, grad) = problem.cost_and_gradient(state.param())?;
        state.cost = Some(cost);
        state.gradient = Some(grad);
        Ok(state)
    }

    /// Runs one outer iteration, in order:
    /// 1. Solves the `LogBarrier` subproblem for the current `mu` with the inner
    ///    solver, seeded from the current iterate.
    /// 2. Adopts the subproblem's solution.
    /// 3. Re-evaluates the original cost and gradient there.
    /// 4. Updates the gap estimate and shrinks `mu`.
    ///
    /// Returns `SolverFailed` with `state` unchanged if the start point was not
    /// strictly feasible or the inner solve ended with a failure reason. The
    /// inner solve's evaluation counts are added to `problem` in both the
    /// success and the inner-failure case.
    ///
    /// # Errors
    /// Propagates errors from the inner run loop and from evaluating the
    /// original cost/gradient at the new iterate.
    fn next_iter(
        &mut self,
        problem: &mut Problem<P>,
        mut state: BasicState<V>,
    ) -> Result<(BasicState<V>, Option<TerminationReason>), Self::Error> {
        if self.infeasible {
            return Ok((state, Some(TerminationReason::SolverFailed)));
        }

        let mut barrier_wrapper = Problem::new(LogBarrier::new(problem.inner(), self.mu));
        let mut criteria: Vec<Box<dyn TerminationCriterion<So::State>>> = vec![
            Box::new(MaxIter(self.inner_max_iter)),
            Box::new(GradientTolerance(self.inner_grad_tol)),
        ];
        let inner_state = self.inner_solver.seed(state.param());
        let result = run_loop(
            &mut barrier_wrapper,
            inner_state,
            &mut self.inner_solver,
            &mut criteria,
            self.inner_max_iter,
        )?;

        // Charge the subproblem's evaluations to the outer problem before looking
        // at the inner outcome, so they are counted even when the inner solve fails.
        let inner_counts = *barrier_wrapper.counts();
        problem.counts_mut().add(&inner_counts);
        if result.reason.is_failure() {
            return Ok((state, Some(TerminationReason::SolverFailed)));
        }

        state.param = result.state.param().clone();
        let (cost, grad) = problem.cost_and_gradient(&state.param)?;
        state.cost = Some(cost);
        state.gradient = Some(grad);

        // The gap is computed with the `mu` that was just used, before shrinking it.
        // NOTE(review): `m * mu` is the duality gap at an exact minimiser of the
        // subproblem only if `LogBarrier` is `f - mu * sum(log(b - A x))` (TODO
        // confirm). It is only an estimate when the inner solve stopped early.
        self.gap = problem.inner().b().vec_len() as f64 * self.mu;
        self.mu /= self.reduction;
        Ok((state, None))
    }

    /// Reports `SolverConverged` once the gap estimate from the last completed
    /// outer iteration is at most `tol`. `init` resets the gap to infinity.
    fn terminate(&self, _state: &BasicState<V>) -> Option<TerminationReason> {
        if self.gap <= self.tol {
            Some(TerminationReason::SolverConverged)
        } else {
            None
        }
    }
}
#[cfg(test)]
mod tests {
    //! Builder-validation tests: each setter must panic on out-of-range input,
    //! with a message naming the violated bound.
    use super::BarrierMethod;

    /// A fresh solver with a unit inner solver; the setters never touch it.
    fn solver() -> BarrierMethod<()> {
        BarrierMethod::new(())
    }

    #[test]
    #[should_panic(expected = "mu0 must be > 0")]
    fn rejects_nonpositive_mu0() {
        let _ = solver().mu0(0.0);
    }

    #[test]
    #[should_panic(expected = "reduction must be > 1")]
    fn rejects_reduction_not_greater_than_one() {
        let _ = solver().reduction(1.0);
    }

    #[test]
    #[should_panic(expected = "tol must be > 0")]
    fn rejects_nonpositive_tol() {
        let _ = solver().tol(0.0);
    }

    #[test]
    #[should_panic(expected = "inner_max_iter must be ≥ 1")]
    fn rejects_zero_inner_max_iter() {
        let _ = solver().inner_max_iter(0);
    }

    #[test]
    #[should_panic(expected = "inner_grad_tol must be ≥ 0")]
    fn rejects_negative_inner_grad_tol() {
        let _ = solver().inner_grad_tol(-1.0);
    }
}