use crate::{make::var, types::advec, Ad};
use faer::{
sparse::{CreationError, SparseColMat},
Col,
};
use itertools::Itertools;
/// The assembled value, gradient and Hessian of an [`Objective`] at one point `x`,
/// as produced by [`Objective::compute`].
///
/// `N` is the per-element variable count of the objective that produced this
/// result; it is carried only as a type-level tag and does not affect the layout.
#[derive(Debug, Clone)]
pub struct ComputedObjective<const N: usize> {
    /// Sum of the values of all element terms.
    pub value: f64,
    /// Global gradient, with `x.nrows()` entries; each element's local
    /// gradient is accumulated into the rows named by its index array.
    pub grad: Col<f64>,
    /// Global Hessian as `(row, col, value)` triplets. Every element contributes
    /// its full dense `N x N` local block, so entries for the same `(row, col)`
    /// may appear several times and are meant to be combined when assembled.
    pub hess_trips: Vec<(usize, usize, f64)>,
}
/// An objective built as a sum of small element terms, each depending on
/// exactly `N` entries of a global variable vector `x`.
///
/// Implementors only provide [`Objective::eval`], which returns an [`Ad`] value
/// carrying the element's value, local gradient and local Hessian. The default
/// methods evaluate every element listed in `operand_indices` and scatter those
/// local quantities into global (length `x.nrows()`) structures.
pub trait Objective<const N: usize> {
    /// Extra, implementor-defined data forwarded unchanged to every
    /// [`Objective::eval`] call.
    type EvalArgs;

    /// Evaluates a single element term.
    ///
    /// `variables` holds the element's `N` local AD variables, in the same
    /// order as the element's `[usize; N]` index array.
    fn eval(&self, variables: &advec<N, N>, args: &Self::EvalArgs) -> Ad<N>;

    /// Gathers `x[global_inds[k]]` for `k in 0..N`, turns them into AD variables
    /// and evaluates the element with [`Objective::eval`].
    ///
    /// Every index in `global_inds` must be a valid row of `x`.
    fn evaluate_for_indices(
        &self,
        global_inds: [usize; N],
        x: &Col<f64>,
        args: &Self::EvalArgs,
    ) -> Ad<N> {
        let vals = global_inds.map(|i| x[i]);
        let vars = var::vector_from_slice(vals.as_slice());
        self.eval(&vars, args)
    }

    /// Computes value, gradient and Hessian triplets in a single pass.
    ///
    /// Equivalent to calling [`Objective::value`], [`Objective::grad`] and
    /// [`Objective::hess_trips`] separately, but each element is evaluated
    /// only once.
    fn compute(
        &self,
        x: &Col<f64>,
        operand_indices: &[[usize; N]],
        args: &Self::EvalArgs,
    ) -> ComputedObjective<N> {
        let mut value = 0.0;
        let mut grad = Col::zeros(x.nrows());
        // Every element contributes exactly N * N triplets.
        let mut hess_trips = Vec::with_capacity(operand_indices.len() * N * N);
        for global_inds in operand_indices {
            let obj = self.evaluate_for_indices(*global_inds, x, args);
            value += obj.value;
            scatter_element_grad(&mut grad, global_inds, &obj);
            scatter_element_hess(&mut hess_trips, global_inds, &obj);
        }
        ComputedObjective {
            value,
            grad,
            hess_trips,
        }
    }

    /// Sum of all element values (`0.0` when `operand_indices` is empty).
    fn value(&self, x: &Col<f64>, operand_indices: &[[usize; N]], args: &Self::EvalArgs) -> f64 {
        operand_indices.iter().fold(0.0, |acc, &global_inds| {
            acc + self.evaluate_for_indices(global_inds, x, args).value
        })
    }

    /// Global gradient: a `x.nrows()`-long column into which each element's
    /// local gradient is accumulated at the rows named by its index array.
    fn grad(
        &self,
        x: &Col<f64>,
        operand_indices: &[[usize; N]],
        args: &Self::EvalArgs,
    ) -> Col<f64> {
        let mut res = Col::zeros(x.nrows());
        for global_inds in operand_indices {
            let obj = self.evaluate_for_indices(*global_inds, x, args);
            scatter_element_grad(&mut res, global_inds, &obj);
        }
        res
    }

    /// Global Hessian as `(row, col, value)` triplets.
    ///
    /// Each element contributes its full dense `N x N` local Hessian (both
    /// triangles), so the same `(row, col)` may appear multiple times.
    fn hess_trips(
        &self,
        x: &Col<f64>,
        operand_indices: &[[usize; N]],
        args: &Self::EvalArgs,
    ) -> Vec<(usize, usize, f64)> {
        // Every element contributes exactly N * N triplets.
        let mut trips = Vec::with_capacity(operand_indices.len() * N * N);
        for global_inds in operand_indices {
            let obj = self.evaluate_for_indices(*global_inds, x, args);
            scatter_element_hess(&mut trips, global_inds, &obj);
        }
        trips
    }

    /// Assembles [`Objective::hess_trips`] into an `n x n` sparse matrix,
    /// where `n = x.nrows()`.
    ///
    /// Repeated `(row, col)` triplets are handed to
    /// `SparseColMat::try_new_from_triplets` as-is; that constructor is
    /// expected to sum them (see the faer docs).
    ///
    /// # Errors
    /// Returns the [`CreationError`] reported by
    /// `SparseColMat::try_new_from_triplets`.
    fn hess(
        &self,
        x: &Col<f64>,
        operand_indices: &[[usize; N]],
        args: &Self::EvalArgs,
    ) -> Result<SparseColMat<usize, f64>, CreationError> {
        let n = x.nrows();
        SparseColMat::try_new_from_triplets(n, n, &self.hess_trips(x, operand_indices, args))
    }
}

/// Adds one element's local gradient into the global gradient `grad`
/// at the rows listed in `global_inds`.
fn scatter_element_grad<const N: usize>(
    grad: &mut Col<f64>,
    global_inds: &[usize; N],
    obj: &Ad<N>,
) {
    for (ilocal, &iglobal) in global_inds.iter().enumerate() {
        grad[iglobal] += obj.grad[ilocal];
    }
}

/// Appends one element's dense local `N x N` Hessian to `trips` as global
/// `(row, col, value)` triplets, iterating rows in the outer loop.
fn scatter_element_hess<const N: usize>(
    trips: &mut Vec<(usize, usize, f64)>,
    global_inds: &[usize; N],
    obj: &Ad<N>,
) {
    let local_global = global_inds.iter().copied().enumerate();
    trips.extend(
        local_global
            .clone()
            .cartesian_product(local_global)
            .map(|((ixlocal, ixglobal), (iylocal, iyglobal))| {
                (ixglobal, iyglobal, obj.hess[(ixlocal, iylocal)])
            }),
    );
}