use faer::{Col, Mat};
use crate::bytecode_tape::BytecodeTape;
use crate::BReverse;
pub fn grad_faer(f: impl FnOnce(&[BReverse<f64>]) -> BReverse<f64>, x: &Col<f64>) -> Col<f64> {
let xs: Vec<f64> = (0..x.nrows()).map(|i| x[i]).collect();
let (mut tape, _) = crate::api::record(f, &xs);
let g = tape.gradient(&xs);
Col::from_fn(g.len(), |i| g[i])
}
pub fn grad_faer_val(
f: impl FnOnce(&[BReverse<f64>]) -> BReverse<f64>,
x: &Col<f64>,
) -> (f64, Col<f64>) {
let xs: Vec<f64> = (0..x.nrows()).map(|i| x[i]).collect();
let (mut tape, val) = crate::api::record(f, &xs);
let g = tape.gradient(&xs);
(val, Col::from_fn(g.len(), |i| g[i]))
}
pub fn hessian_faer(
f: impl FnOnce(&[BReverse<f64>]) -> BReverse<f64>,
x: &Col<f64>,
) -> (f64, Col<f64>, Mat<f64>) {
let xs: Vec<f64> = (0..x.nrows()).map(|i| x[i]).collect();
let (tape, _) = crate::api::record(f, &xs);
let (val, grad, hess) = tape.hessian(&xs);
let n = xs.len();
let g = Col::from_fn(n, |i| grad[i]);
let h = Mat::from_fn(n, n, |i, j| hess[i][j]);
(val, g, h)
}
pub fn jacobian_faer(
f: impl FnOnce(&[BReverse<f64>]) -> Vec<BReverse<f64>>,
x: &Col<f64>,
) -> Mat<f64> {
let xs: Vec<f64> = (0..x.nrows()).map(|i| x[i]).collect();
let (mut tape, _) = crate::api::record_multi(f, &xs);
let jac = tape.jacobian(&xs);
let m = jac.len();
let n = if m > 0 { jac[0].len() } else { xs.len() };
Mat::from_fn(m, n, |i, j| jac[i][j])
}
pub fn tape_gradient_faer(tape: &mut BytecodeTape<f64>, x: &Col<f64>) -> Col<f64> {
let xs: Vec<f64> = (0..x.nrows()).map(|i| x[i]).collect();
let g = tape.gradient(&xs);
Col::from_fn(g.len(), |i| g[i])
}
#[must_use]
pub fn tape_hessian_faer(tape: &BytecodeTape<f64>, x: &Col<f64>) -> (f64, Col<f64>, Mat<f64>) {
let xs: Vec<f64> = (0..x.nrows()).map(|i| x[i]).collect();
let (val, grad, hess) = tape.hessian(&xs);
let n = xs.len();
let g = Col::from_fn(n, |i| grad[i]);
let h = Mat::from_fn(n, n, |i, j| hess[i][j]);
(val, g, h)
}
pub fn hvp_faer(
f: impl FnOnce(&[BReverse<f64>]) -> BReverse<f64>,
x: &Col<f64>,
v: &Col<f64>,
) -> (Col<f64>, Col<f64>) {
let xs: Vec<f64> = (0..x.nrows()).map(|i| x[i]).collect();
let vs: Vec<f64> = (0..v.nrows()).map(|i| v[i]).collect();
let (grad, hvp) = crate::api::hvp(f, &xs, &vs);
(
Col::from_fn(grad.len(), |i| grad[i]),
Col::from_fn(hvp.len(), |i| hvp[i]),
)
}
#[must_use]
pub fn tape_hvp_faer(tape: &BytecodeTape<f64>, x: &Col<f64>, v: &Col<f64>) -> (Col<f64>, Col<f64>) {
let xs: Vec<f64> = (0..x.nrows()).map(|i| x[i]).collect();
let vs: Vec<f64> = (0..v.nrows()).map(|i| v[i]).collect();
let (grad, hvp) = tape.hvp(&xs, &vs);
(
Col::from_fn(grad.len(), |i| grad[i]),
Col::from_fn(hvp.len(), |i| hvp[i]),
)
}
pub fn sparse_hessian_faer(
f: impl FnOnce(&[BReverse<f64>]) -> BReverse<f64>,
x: &Col<f64>,
) -> (f64, Col<f64>, crate::sparse::SparsityPattern, Vec<f64>) {
let xs: Vec<f64> = (0..x.nrows()).map(|i| x[i]).collect();
let (val, grad, pattern, values) = crate::api::sparse_hessian(f, &xs);
let g = Col::from_fn(grad.len(), |i| grad[i]);
(val, g, pattern, values)
}
#[must_use]
pub fn tape_sparse_hessian_faer(
tape: &BytecodeTape<f64>,
x: &Col<f64>,
) -> (f64, Col<f64>, crate::sparse::SparsityPattern, Vec<f64>) {
let xs: Vec<f64> = (0..x.nrows()).map(|i| x[i]).collect();
let (val, grad, pattern, values) = tape.sparse_hessian(&xs);
let g = Col::from_fn(grad.len(), |i| grad[i]);
(val, g, pattern, values)
}
#[must_use]
pub fn solve_dense_lu_faer(a: &Mat<f64>, b: &Col<f64>) -> Col<f64> {
use faer::linalg::solvers::Solve;
a.partial_piv_lu().solve(b)
}
#[must_use]
pub fn solve_dense_cholesky_faer(a: &Mat<f64>, b: &Col<f64>) -> Option<Col<f64>> {
use faer::linalg::solvers::Solve;
Some(a.llt(faer::Side::Lower).ok()?.solve(b))
}
#[must_use]
pub fn sparsity_to_faer_symmetric(
pattern: &crate::sparse::SparsityPattern,
values: &[f64],
) -> Option<faer::sparse::SparseColMat<usize, f64>> {
assert_eq!(pattern.nnz(), values.len());
let mut triplets: Vec<faer::sparse::Triplet<usize, usize, f64>> =
Vec::with_capacity(pattern.nnz() * 2);
for ((&row, &col), &v) in pattern.rows.iter().zip(&pattern.cols).zip(values) {
let r = row as usize;
let c = col as usize;
triplets.push(faer::sparse::Triplet {
row: r,
col: c,
val: v,
});
if r != c {
triplets.push(faer::sparse::Triplet {
row: c,
col: r,
val: v,
});
}
}
faer::sparse::SparseColMat::<usize, f64>::try_new_from_triplets(
pattern.dim,
pattern.dim,
&triplets,
)
.ok()
}
#[must_use]
pub fn solve_sparse_cholesky_faer(
pattern: &crate::sparse::SparsityPattern,
values: &[f64],
b: &Col<f64>,
) -> Option<Col<f64>> {
use faer::linalg::solvers::Solve;
let mat = sparsity_to_faer_symmetric(pattern, values)?;
let chol = mat.sp_cholesky(faer::Side::Lower).ok()?;
Some(chol.solve(b))
}
#[must_use]
pub fn solve_sparse_lu_faer(
pattern: &crate::sparse::SparsityPattern,
values: &[f64],
b: &Col<f64>,
) -> Option<Col<f64>> {
use faer::linalg::solvers::Solve;
let mat = sparsity_to_faer_symmetric(pattern, values)?;
let lu = mat.sp_lu().ok()?;
Some(lu.solve(b))
}