use crate::{
constraints,
core::{fbs::FBSCache, AlgorithmEngine, Problem},
matrix_operations, FunctionCallResult, SolverError,
};
use num::Float;
/// Engine that performs forward-backward splitting (FBS) iterations: a
/// gradient step on the iterate followed by a projection onto the problem's
/// constraint set.
///
/// The engine owns its [`Problem`] by value and mutably borrows an
/// [`FBSCache`] (its working memory) for the lifetime `'a`.
pub struct FBSEngine<'a, GradientType, ConstraintType, CostType, T = f64>
where
    ConstraintType: constraints::Constraint<T>,
    CostType: Fn(&[T], &mut T) -> FunctionCallResult,
    GradientType: Fn(&[T], &mut [T]) -> FunctionCallResult,
    T: Float,
{
    /// Problem definition; the engine uses its gradient callback (`gradf`)
    /// and its constraint set (`constraints`).
    pub(crate) problem: Problem<'a, GradientType, ConstraintType, CostType, T>,
    /// Working memory (gradient buffer, previous iterate, step size,
    /// tolerance, residual) reused across iterations.
    pub(crate) cache: &'a mut FBSCache<T>,
}
impl<'a, GradientType, ConstraintType, CostType, T>
    FBSEngine<'a, GradientType, ConstraintType, CostType, T>
where
    T: Float,
    GradientType: Fn(&[T], &mut [T]) -> FunctionCallResult,
    CostType: Fn(&[T], &mut T) -> FunctionCallResult,
    ConstraintType: constraints::Constraint<T>,
{
    /// Creates an engine from a problem definition and a cache.
    ///
    /// # Arguments
    ///
    /// - `problem`: the problem to solve; moved into the engine.
    /// - `cache`: working memory, mutably borrowed for as long as the engine
    ///   lives.
    pub fn new(
        problem: Problem<'a, GradientType, ConstraintType, CostType, T>,
        cache: &'a mut FBSCache<T>,
    ) -> FBSEngine<'a, GradientType, ConstraintType, CostType, T> {
        Self { problem, cache }
    }

    /// Forward half of an FBS iteration, applied in place:
    /// `u <- u - gamma * grad f(u)`, where `gamma` is read from the cache.
    ///
    /// The gradient is first written into `cache.work_gradient_u`.
    ///
    /// # Errors
    ///
    /// - Propagates any error returned by the gradient callback.
    /// - Returns `SolverError::NotFiniteComputation` if the gradient has a
    ///   non-finite entry. In that case `u_current` is not modified, because
    ///   the callback only receives it as `&[T]`.
    fn gradient_step(&mut self, u_current: &mut [T]) -> FunctionCallResult {
        (self.problem.gradf)(u_current, &mut self.cache.work_gradient_u)?;

        let gradient = &self.cache.work_gradient_u;
        if !crate::matrix_operations::is_finite(gradient) {
            return Err(SolverError::NotFiniteComputation(
                "gradient evaluation returned a non-finite value during an FBS step",
            ));
        }

        let gamma = self.cache.gamma;
        for (u, &g) in u_current.iter_mut().zip(gradient.iter()) {
            *u = *u - gamma * g;
        }
        Ok(())
    }

    /// Backward half of an FBS iteration. It delegates to the constraint's
    /// `project` method, which receives `u_current` mutably.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `project`.
    fn projection_step(&mut self, u_current: &mut [T]) -> FunctionCallResult {
        self.problem.constraints.project(u_current)
    }
}
impl<'a, GradientType, ConstraintType, CostType, T> AlgorithmEngine<T>
    for FBSEngine<'a, GradientType, ConstraintType, CostType, T>
where
    T: Float,
    GradientType: Fn(&[T], &mut [T]) -> FunctionCallResult + 'a,
    CostType: Fn(&[T], &mut T) -> FunctionCallResult + 'a,
    ConstraintType: constraints::Constraint<T> + 'a,
{
    /// Performs one FBS iteration in place: a gradient step followed by a
    /// projection onto the constraint set, i.e. `u <- Proj_C(u - gamma * grad f(u))`.
    ///
    /// The iterate on entry is saved in `cache.work_u_previous`. After the
    /// update, the residual between the new and previous iterates is
    /// computed with `matrix_operations::norm_inf_diff` and stored in
    /// `cache.norm_fpr`. That residual is presumably `‖u_new − u_prev‖_∞`;
    /// confirm against the helper.
    ///
    /// # Returns
    ///
    /// `Ok(true)` while `cache.norm_fpr` is strictly greater than
    /// `cache.tolerance`, and `Ok(false)` once it is at or below it.
    ///
    /// # Errors
    ///
    /// - Any error returned by the gradient callback or by the projection.
    /// - `SolverError::NotFiniteComputation` if the gradient or the projected
    ///   iterate contains a non-finite value.
    ///
    /// # Panics
    ///
    /// Panics if `u_current.len()` differs from `cache.work_u_previous.len()`,
    /// as required by `slice::copy_from_slice`.
    fn step(&mut self, u_current: &mut [T]) -> Result<bool, SolverError> {
        // Snapshot the iterate so the residual can be measured after the update.
        self.cache.work_u_previous.copy_from_slice(u_current);

        self.gradient_step(u_current)?;
        self.projection_step(u_current)?;

        // The constraint type is user-supplied, so the projection may yield
        // NaN/inf even from a finite input. Reject it before computing the
        // residual: a NaN residual would compare as *not* greater than the
        // tolerance and could be mistaken for convergence.
        if !matrix_operations::is_finite(u_current) {
            return Err(SolverError::NotFiniteComputation(
                "projected iterate contains a non-finite value during an FBS step",
            ));
        }

        self.cache.norm_fpr =
            matrix_operations::norm_inf_diff(u_current, &self.cache.work_u_previous);
        Ok(self.cache.norm_fpr > self.cache.tolerance)
    }

    /// Initialisation hook. This engine does nothing here and always returns
    /// `Ok(())`; `_u_current` is left unchanged.
    fn init(&mut self, _u_current: &mut [T]) -> FunctionCallResult {
        Ok(())
    }
}