use crate::{constraints::Constraint, matrix_operations, numeric::cast, FunctionCallResult};
use num::Float;
use std::marker::PhantomData;
use std::{iter::Sum, ops::AddAssign};
/// Builds the augmented cost `psi` and its gradient `d_psi` that the augmented Lagrangian
/// method (ALM) minimises in each inner iteration.
///
/// With `xi = (c, y)`, where `c = xi[0]` is the penalty parameter and `y = xi[1..]` are the
/// Lagrange multipliers, the factory evaluates
///
/// ```text
/// psi(u; xi) = f(u)
///            + (c/2) * ||t - Proj_C(t)||^2     (only when F1 and C are given)
///            + (c/2) * ||F2(u)||^2             (only when F2 is given)
///
/// t = F1(u) + y / max(c, 1)
/// ```
///
/// The multipliers are divided by `max(c, 1)` rather than by `c`. This keeps both `psi` and
/// `d_psi` well defined for `c = 0`, where they reduce to `f` and its gradient.
pub struct AlmFactory<
    MappingF1,
    JacobianMappingF1Trans,
    MappingF2,
    JacobianMappingF2Trans,
    Cost,
    CostGradient,
    SetC,
    T = f64,
> where
    T: Float + Sum<T> + AddAssign,
    Cost: Fn(&[T], &mut T) -> FunctionCallResult,
    CostGradient: Fn(&[T], &mut [T]) -> FunctionCallResult,
    MappingF1: Fn(&[T], &mut [T]) -> FunctionCallResult,
    JacobianMappingF1Trans: Fn(&[T], &[T], &mut [T]) -> FunctionCallResult,
    MappingF2: Fn(&[T], &mut [T]) -> FunctionCallResult,
    JacobianMappingF2Trans: Fn(&[T], &[T], &mut [T]) -> FunctionCallResult,
    SetC: Constraint<T>,
{
    /// Cost function `f`, called as `f(u, &mut cost)`.
    f: Cost,
    /// Gradient of `f`, called as `df(u, &mut grad)`.
    df: CostGradient,
    /// Optional mapping `F1`, called as `F1(u, &mut out)` with `out.len() == xi.len() - 1`.
    mapping_f1: Option<MappingF1>,
    /// Optional Jacobian-transpose product of `F1`, called as `jf1t(u, d, &mut out)`
    /// with `out.len() == u.len()`.
    jacobian_mapping_f1_trans: Option<JacobianMappingF1Trans>,
    /// Optional mapping `F2`, called as `F2(u, &mut out)` with `out.len() == n2`.
    mapping_f2: Option<MappingF2>,
    /// Optional Jacobian-transpose product of `F2`, called as `jf2t(u, d, &mut out)`
    /// with `out.len() == u.len()`.
    jacobian_mapping_f2_trans: Option<JacobianMappingF2Trans>,
    /// Set `C` onto which `F1(u) + y / max(c, 1)` is projected.
    set_c: Option<SetC>,
    /// Output dimension of `F2` (zero when no `F2` is given).
    n2: usize,
    marker: PhantomData<T>,
}

impl<
        MappingF1,
        JacobianMappingF1Trans,
        MappingF2,
        JacobianMappingF2Trans,
        Cost,
        CostGradient,
        SetC,
        T,
    >
    AlmFactory<
        MappingF1,
        JacobianMappingF1Trans,
        MappingF2,
        JacobianMappingF2Trans,
        Cost,
        CostGradient,
        SetC,
        T,
    >
where
    T: Float + Sum<T> + AddAssign,
    Cost: Fn(&[T], &mut T) -> FunctionCallResult,
    CostGradient: Fn(&[T], &mut [T]) -> FunctionCallResult,
    MappingF1: Fn(&[T], &mut [T]) -> FunctionCallResult,
    JacobianMappingF1Trans: Fn(&[T], &[T], &mut [T]) -> FunctionCallResult,
    MappingF2: Fn(&[T], &mut [T]) -> FunctionCallResult,
    JacobianMappingF2Trans: Fn(&[T], &[T], &mut [T]) -> FunctionCallResult,
    SetC: Constraint<T>,
{
    /// Creates a new factory.
    ///
    /// # Arguments
    ///
    /// - `f`, `df`: the cost function and its gradient
    /// - `mapping_f1`, `jacobian_mapping_f1_trans`: optional mapping `F1`, and a product
    ///   expected to compute `JF1(u)^T d`
    /// - `mapping_f2`, `jacobian_mapping_f2_trans`: optional mapping `F2`, and a product
    ///   expected to compute `JF2(u)^T d`
    /// - `set_c`: the set `C` associated with `F1`
    /// - `n2`: the output dimension of `F2`
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - `mapping_f2` is given while `n2 == 0`, or omitted while `n2 > 0`;
    /// - only one of `mapping_f2` and `jacobian_mapping_f2_trans` is given;
    /// - only one of `mapping_f1` and `jacobian_mapping_f1_trans` is given;
    /// - only one of `mapping_f1` and `set_c` is given.
    #[allow(clippy::too_many_arguments)] // one argument per ingredient of the problem
    pub fn new(
        f: Cost,
        df: CostGradient,
        mapping_f1: Option<MappingF1>,
        jacobian_mapping_f1_trans: Option<JacobianMappingF1Trans>,
        mapping_f2: Option<MappingF2>,
        jacobian_mapping_f2_trans: Option<JacobianMappingF2Trans>,
        set_c: Option<SetC>,
        n2: usize,
    ) -> Self {
        assert!(
            !(mapping_f2.is_none() ^ (n2 == 0)),
            "if n2 > 0 then and only then should you provide an F2"
        );
        assert!(
            !(jacobian_mapping_f2_trans.is_none() ^ mapping_f2.is_none()),
            "you must have JF2 together with F2"
        );
        assert!(
            !(mapping_f1.is_none() ^ jacobian_mapping_f1_trans.is_none()),
            "if n1 > 0 then and only then should you provide an F1"
        );
        assert!(
            !(mapping_f1.is_none() ^ set_c.is_none()),
            "F1 must be accompanied by a set C"
        );
        AlmFactory {
            f,
            df,
            mapping_f1,
            jacobian_mapping_f1_trans,
            mapping_f2,
            jacobian_mapping_f2_trans,
            set_c,
            n2,
            marker: PhantomData,
        }
    }

    /// Returns `max(c, 1)`, the divisor applied to the Lagrange multipliers.
    ///
    /// `psi` and `d_psi` must use the same divisor, or `d_psi` stops being the gradient of
    /// `psi`. Using `max(c, 1)` instead of `c` also avoids a division by zero at `c = 0`.
    fn multiplier_scale(penalty_parameter: T) -> T {
        if penalty_parameter > T::one() {
            penalty_parameter
        } else {
            T::one()
        }
    }

    /// Evaluates the augmented cost `psi(u; xi)` and writes it into `cost`.
    ///
    /// `xi = (c, y)`: `xi[0]` is the penalty parameter `c` and `xi[1..]` are the multipliers
    /// `y`. `xi` may be empty only when neither `F1` nor `F2` was given.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `f`, `F1`, the projection onto `C`, or `F2`.
    ///
    /// # Panics
    ///
    /// Panics if `F1` or `F2` is configured and `xi` is empty.
    pub fn psi(&self, u: &[T], xi: &[T], cost: &mut T) -> FunctionCallResult {
        (self.f)(u, cost)?; // cost = f(u)

        if let (Some(set_c), Some(mapping_f1)) = (&self.set_c, &self.mapping_f1) {
            let penalty_parameter = xi[0];
            let y_lagrange_mult = &xi[1..];
            let scale = Self::multiplier_scale(penalty_parameter);

            // t = F1(u) + y / max(c, 1)
            let mut t = vec![T::zero(); y_lagrange_mult.len()];
            mapping_f1(u, &mut t)?;
            t.iter_mut()
                .zip(y_lagrange_mult.iter())
                .for_each(|(ti, yi)| *ti += *yi / scale);

            // s = Proj_C(t)
            let mut s = t.clone();
            set_c.project(&mut s)?;

            // cost += (c/2) * ||t - s||^2
            let dist_sq: T = matrix_operations::norm2_squared_diff(&t, &s);
            let scaling: T = cast::<T>(0.5) * penalty_parameter;
            *cost += scaling * dist_sq;
        }

        if let Some(f2) = &self.mapping_f2 {
            let c = xi[0];
            let mut z = vec![T::zero(); self.n2];
            f2(u, &mut z)?; // z = F2(u)
            let norm_sq: T = matrix_operations::norm2_squared(&z);
            let scaling: T = cast::<T>(0.5) * c;
            *cost += scaling * norm_sq; // cost += (c/2) * ||F2(u)||^2
        }
        Ok(())
    }

    /// Evaluates the gradient of `psi(., xi)` at `u` and writes it into `grad`.
    ///
    /// Computes
    /// `grad = df(u) + c * JF1(u)^T (t - Proj_C(t)) + c * JF2(u)^T F2(u)`, with
    /// `t = F1(u) + y / max(c, 1)`. This uses the same shift as [`Self::psi`], so the result
    /// is the gradient of `psi` for every `c >= 0`. This assumes `Proj_C` is the Euclidean
    /// projection onto a closed convex set.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `df`, `F1`, the projection onto `C`, `JF1^T`,
    /// `F2` or `JF2^T`.
    ///
    /// # Panics
    ///
    /// Panics if `F1` or `F2` is configured and `xi` is empty.
    pub fn d_psi(&self, u: &[T], xi: &[T], grad: &mut [T]) -> FunctionCallResult {
        let nu = u.len();
        (self.df)(u, grad)?; // grad = df(u)

        if let (Some(set_c), Some(mapping_f1), Some(jf1t)) = (
            &self.set_c,
            &self.mapping_f1,
            &self.jacobian_mapping_f1_trans,
        ) {
            let c_penalty_parameter = xi[0];
            let y_lagrange_mult = &xi[1..];
            // Must match the divisor used in `psi`; dividing by `c` here made the gradient
            // inconsistent for c < 1 and NaN for c = 0.
            let scale = Self::multiplier_scale(c_penalty_parameter);

            // t = F1(u) + y / max(c, 1)
            let mut t = vec![T::zero(); y_lagrange_mult.len()];
            mapping_f1(u, &mut t)?;
            t.iter_mut()
                .zip(y_lagrange_mult.iter())
                .for_each(|(ti, yi)| *ti += *yi / scale);

            // s = Proj_C(t)
            let mut s_aux_var = t.clone();
            set_c.project(&mut s_aux_var)?;

            // t <- t - Proj_C(t); written without `-=` because `T` is not bound by SubAssign
            t.iter_mut()
                .zip(s_aux_var.iter())
                .for_each(|(ti, si)| *ti = *ti - *si);

            // grad += c * JF1(u)^T (t - Proj_C(t))
            let mut jac_prod = vec![T::zero(); nu];
            jf1t(u, &t, &mut jac_prod)?;
            grad.iter_mut()
                .zip(jac_prod.iter())
                .for_each(|(gradi, jac_prodi)| *gradi += c_penalty_parameter * *jac_prodi);
        }

        if let (Some(f2), Some(jf2)) = (&self.mapping_f2, &self.jacobian_mapping_f2_trans) {
            let c = xi[0];
            let mut f2u_aux = vec![T::zero(); self.n2];
            let mut jf2u_times_f2u_aux = vec![T::zero(); nu];
            f2(u, &mut f2u_aux)?; // F2(u)
            jf2(u, &f2u_aux, &mut jf2u_times_f2u_aux)?; // JF2(u)^T F2(u)
            // grad += c * JF2(u)^T F2(u)
            grad.iter_mut()
                .zip(jf2u_times_f2u_aux.iter())
                .for_each(|(gradi, jf2u_times_f2u_aux_i)| *gradi += c * *jf2u_times_f2u_aux_i);
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use crate::{alm::*, constraints::*, mocks, FunctionCallResult, SolverError};

    /// Affine `F1` with a unit ball `C` and no `F2`: `psi` against a reference value.
    #[test]
    fn t_mocking_alm_factory_psi() {
        let factory = AlmFactory::new(
            mocks::f0,
            mocks::d_f0,
            Some(mocks::mapping_f1_affine),
            Some(mocks::mapping_f1_affine_jacobian_product),
            NO_MAPPING,
            NO_JACOBIAN_MAPPING,
            Some(Ball2::new(None, 1.0)),
            0,
        );
        let (u, xi) = ([3.0, 5.0, 7.0], [2.0, 10.0, 20.0]);
        let mut cost = 0.0;

        let status = factory.psi(&u, &xi, &mut cost);
        assert!(status.is_ok());
        println!("cost = {}", cost);

        let expected = 1_064.986_642_583_36;
        unit_test_utils::assert_nearly_equal(expected, cost, 1e-14, 1e-10, "psi is wrong");
    }

    /// Affine `F1` with a unit ball `C` and no `F2`: `d_psi` against reference values.
    #[test]
    fn t_mocking_alm_factory_grad_psi() {
        let factory = AlmFactory::new(
            mocks::f0,
            mocks::d_f0,
            Some(mocks::mapping_f1_affine),
            Some(mocks::mapping_f1_affine_jacobian_product),
            NO_MAPPING,
            NO_JACOBIAN_MAPPING,
            Some(Ball2::new(None, 1.0)),
            0,
        );
        let (u, xi) = ([3.0, -5.0, 7.0], [2.5, 11.0, 20.0]);
        let mut grad_psi = [0.0; 3];

        let status = factory.d_psi(&u, &xi, &mut grad_psi);
        assert!(status.is_ok());

        let expected = [
            71.734_788_756_561_9,
            -32.222_828_648_567_5,
            46.571_199_153_042_2,
        ];
        unit_test_utils::assert_nearly_equal_array(
            &expected,
            &grad_psi,
            1e-12,
            1e-12,
            "d_psi is wrong",
        );
    }

    /// Linear test mapping `F2: R^3 -> R^4`.
    fn mapping_f2(u: &[f64], res: &mut [f64]) -> FunctionCallResult {
        let (u0, u1, u2) = (u[0], u[1], u[2]);
        res[0] = u0;
        res[1] = u1;
        res[2] = u2 - u0;
        res[3] = u2 - u0 - u1;
        Ok(())
    }

    /// Transposed Jacobian of `mapping_f2` applied to `d` (independent of `u`, since
    /// `F2` is linear).
    fn jac_mapping_f2_tr(_u: &[f64], d: &[f64], res: &mut [f64]) -> FunctionCallResult {
        let (d0, d1, d2, d3) = (d[0], d[1], d[2], d[3]);
        res[0] = d0 - d2 - d3;
        res[1] = d1 - d3;
        res[2] = d2 + d3;
        Ok(())
    }

    /// `psi` with both `F1` and `F2`. `psi` never calls the Jacobian of `F2`, so a
    /// Jacobian that always fails must not affect the result.
    #[test]
    fn t_mocking_alm_factory_psi_with_f2() {
        let failing_jac_f2_tr =
            |_u: &[f64], _d: &[f64], _res: &mut [f64]| -> FunctionCallResult {
                Err(SolverError::NotFiniteComputation(
                    "mock Jacobian-transpose product returned a non-finite result",
                ))
            };
        let factory = AlmFactory::new(
            mocks::f0,
            mocks::d_f0,
            Some(mocks::mapping_f1_affine),
            Some(mocks::mapping_f1_affine_jacobian_product),
            Some(mapping_f2),
            Some(failing_jac_f2_tr),
            Some(Ball2::new(None, 1.0)),
            4,
        );
        let (u, xi) = ([3.0, 5.0, 7.0], [2.0, 10.0, 20.0]);
        let mut cost = 0.0;

        let status = factory.psi(&u, &xi, &mut cost);
        assert!(status.is_ok());
        println!("cost = {}", cost);

        let expected = 1.115_986_642_583_36e+03;
        unit_test_utils::assert_nearly_equal(expected, cost, 1e-12, 1e-10, "psi is wrong");
    }

    /// `d_psi` with both `F1` and `F2` against reference values.
    #[test]
    fn t_mocking_alm_factory_grad_psi_with_f2() {
        let factory = AlmFactory::new(
            mocks::f0,
            mocks::d_f0,
            Some(mocks::mapping_f1_affine),
            Some(mocks::mapping_f1_affine_jacobian_product),
            Some(mapping_f2),
            Some(jac_mapping_f2_tr),
            Some(Ball2::new(None, 1.0)),
            4,
        );
        let (u, xi) = ([3.0, 5.0, 7.0], [2.0, 10.0, 20.0]);
        let mut grad_psi = [0.0; 3];

        let status = factory.d_psi(&u, &xi, &mut grad_psi);
        assert!(status.is_ok());
        println!("grad = {:#?}", &grad_psi);

        let expected = [
            124.214_512_432_589_49,
            180.871_274_908_669_62,
            46.962_043_731_516_474,
        ];
        unit_test_utils::assert_nearly_equal_array(
            &expected,
            &grad_psi,
            1e-12,
            1e-12,
            "d_psi is wrong",
        );
    }

    /// Without `F1`/`F2`/`C`, an empty `xi` must be accepted by `d_psi`.
    #[test]
    fn t_mocking_alm_factory_nomappings() {
        let factory = AlmFactory::new(
            mocks::f0,
            mocks::d_f0,
            NO_MAPPING,
            NO_JACOBIAN_MAPPING,
            NO_MAPPING,
            NO_JACOBIAN_MAPPING,
            NO_SET,
            0,
        );
        let u = [3.0, 5.0, 7.0];
        let xi: [f64; 0] = [];
        let mut grad_psi = [0.0; 3];

        let status = factory.d_psi(&u, &xi, &mut grad_psi);
        assert!(status.is_ok());
        println!("grad = {:#?}", &grad_psi);
    }
}