use crate::{
constraints,
core::{fbs::FBSCache, AlgorithmEngine, Problem},
matrix_operations, SolverError,
};
/// Engine that runs FBS iterations (a gradient step followed by a
/// projection onto the constraint set).
///
/// It pairs the [`Problem`] being solved with a mutable borrow of an
/// [`FBSCache`]. The cache supplies the workspace buffers and parameters
/// that every iteration reads and updates.
pub struct FBSEngine<'a, GradientType, ConstraintType, CostType>
where
    ConstraintType: constraints::Constraint,
    CostType: Fn(&[f64], &mut f64) -> Result<(), SolverError>,
    GradientType: Fn(&[f64], &mut [f64]) -> Result<(), SolverError>,
{
    /// The problem being solved (gradient, cost and constraint types).
    pub(crate) problem: Problem<'a, GradientType, ConstraintType, CostType>,
    /// Workspace, mutably borrowed for the whole lifetime `'a` of the engine.
    pub(crate) cache: &'a mut FBSCache,
}
impl<'a, GradientType, ConstraintType, CostType>
    FBSEngine<'a, GradientType, ConstraintType, CostType>
where
    GradientType: Fn(&[f64], &mut [f64]) -> Result<(), SolverError>,
    CostType: Fn(&[f64], &mut f64) -> Result<(), SolverError>,
    ConstraintType: constraints::Constraint,
{
    /// Creates a new engine from a problem and a cache.
    ///
    /// # Arguments
    ///
    /// - `problem`: the problem to be solved (gradient, cost and constraints)
    /// - `cache`: workspace the engine reads from and writes to on every step;
    ///   it stays mutably borrowed for as long as the engine lives
    pub fn new(
        problem: Problem<'a, GradientType, ConstraintType, CostType>,
        cache: &'a mut FBSCache,
    ) -> FBSEngine<'a, GradientType, ConstraintType, CostType> {
        FBSEngine { problem, cache }
    }

    /// Performs the gradient step in place: `u <- u - gamma * grad f(u)`.
    ///
    /// The gradient is evaluated into `cache.work_gradient_u`. The step size
    /// `gamma` is read from `cache.gamma`. Entries are paired with `zip`, so
    /// only the first `min(u_current.len(), work_gradient_u.len())`
    /// components are updated.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the gradient function. In that case
    /// `u_current` is left unmodified.
    fn gradient_step(&mut self, u_current: &mut [f64]) -> Result<(), SolverError> {
        (self.problem.gradf)(u_current, &mut self.cache.work_gradient_u)?;

        // The step size is constant over the update; read it once.
        let gamma = self.cache.gamma;
        u_current
            .iter_mut()
            .zip(self.cache.work_gradient_u.iter())
            .for_each(|(u, w)| *u -= gamma * *w);
        Ok(())
    }

    /// Performs the projection step in place, `u <- Proj_C(u)`, by delegating
    /// to the problem's constraint set.
    fn projection_step(&mut self, u_current: &mut [f64]) {
        self.problem.constraints.project(u_current);
    }
}

impl<'a, GradientType, ConstraintType, CostType> AlgorithmEngine
    for FBSEngine<'a, GradientType, ConstraintType, CostType>
where
    GradientType: Fn(&[f64], &mut [f64]) -> Result<(), SolverError> + 'a,
    CostType: Fn(&[f64], &mut f64) -> Result<(), SolverError> + 'a,
    ConstraintType: constraints::Constraint + 'a,
{
    /// Takes one projected-gradient step, `u <- Proj_C(u - gamma * grad f(u))`,
    /// updating `u_current` in place.
    ///
    /// The previous iterate is saved in `cache.work_u_previous`. The distance
    /// between the new and previous iterates, computed by
    /// `matrix_operations::norm_inf_diff`, is stored in `cache.norm_fpr`.
    ///
    /// # Returns
    ///
    /// `Ok(true)` if `cache.norm_fpr` is strictly greater than `cache.tolerance`,
    /// and `Ok(false)` otherwise. The caller can use this to decide whether
    /// to keep iterating.
    ///
    /// # Errors
    ///
    /// Returns the gradient function's error if the gradient cannot be
    /// evaluated. The iterate is then left unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `u_current.len()` differs from `cache.work_u_previous.len()`
    /// (a requirement of `copy_from_slice`).
    fn step(&mut self, u_current: &mut [f64]) -> Result<bool, SolverError> {
        self.cache.work_u_previous.copy_from_slice(u_current);
        self.gradient_step(u_current)?;
        self.projection_step(u_current);
        self.cache.norm_fpr =
            matrix_operations::norm_inf_diff(u_current, &self.cache.work_u_previous);
        Ok(self.cache.norm_fpr > self.cache.tolerance)
    }

    /// This engine needs no initialization; always returns `Ok(())`.
    fn init(&mut self, _u_current: &mut [f64]) -> Result<(), SolverError> {
        Ok(())
    }
}