use crate::custom_family::{ExactNewtonJointGradientEvaluation, ExactNewtonJointHessianWorkspace};
use crate::solver::estimate::reml::unified::{
HyperOperator, ProjectedFactorCache, ProjectedFactorKey,
};
use ndarray::{Array1, Array2, ArrayView2};
use rayon::prelude::*;
use std::sync::Arc;
/// Per-row evaluation kernel for a likelihood whose contribution from each
/// data row lives in a K-dimensional local space (`K` is a compile-time
/// constant).
///
/// Implementations expose each row's negative log-likelihood together with
/// its local gradient/Hessian and higher-order contractions, plus the actions
/// of the row Jacobian mapping between the K-dimensional row space and the
/// full p-dimensional coefficient space. All methods must be thread-safe:
/// callers evaluate rows in parallel via rayon.
pub trait RowKernel<const K: usize>: Send + Sync {
    /// Number of data rows.
    fn n_rows(&self) -> usize;
    /// Dimension `p` of the coefficient vector.
    fn n_coefficients(&self) -> usize;
    /// Evaluates row `row`: returns (negative log-likelihood term, local
    /// K-gradient, local K×K Hessian), or an error message on failure.
    fn row_kernel(&self, row: usize) -> Result<(f64, [f64; K], [[f64; K]; K]), String>;
    /// Applies the row Jacobian: maps a length-p `d_beta` into row K-space.
    fn jacobian_action(&self, row: usize, d_beta: &[f64]) -> [f64; K];
    /// Applies the transposed row Jacobian: accumulates `Jᵀ v` into `out`
    /// (length p). Callers rely on accumulation, not overwrite.
    fn jacobian_transpose_action(&self, row: usize, v: &[f64; K], out: &mut [f64]);
    /// Accumulates the row's pullback of the local K×K matrix `h`
    /// (effectively `Jᵀ h J`) into the p×p `target`.
    fn add_pullback_hessian(&self, row: usize, h: &[[f64; K]; K], target: &mut Array2<f64>);
    /// Accumulates the diagonal of the row's pullback of `h` into `diag`
    /// (length p).
    fn add_diagonal_quadratic(&self, row: usize, h: &[[f64; K]; K], diag: &mut [f64]);
    /// Third-order row tensor contracted once with `dir`, returned as K×K.
    fn row_third_contracted(&self, row: usize, dir: &[f64; K]) -> Result<[[f64; K]; K], String>;
    /// Fourth-order row tensor contracted with `dir_u` and `dir_v`, returned
    /// as K×K.
    fn row_fourth_contracted(
        &self,
        row: usize,
        dir_u: &[f64; K],
        dir_v: &[f64; K],
    ) -> Result<[[f64; K]; K], String>;
    /// Optional hook to pre-populate directional caches before parallel
    /// evaluation; the default implementation is a no-op.
    fn warm_up_directional_caches(&self) -> Result<(), String> {
        Ok(())
    }
}
/// Structure-of-arrays cache of per-row kernel evaluations, indexed by row.
pub struct RowKernelCache<const K: usize> {
    /// Number of cached rows.
    pub n: usize,
    /// Coefficient-space dimension `p`.
    pub p: usize,
    /// Per-row negative log-likelihood terms.
    pub nll: Vec<f64>,
    /// Per-row local K-gradients.
    pub gradients: Vec<[f64; K]>,
    /// Per-row local K×K Hessians.
    pub hessians: Vec<[[f64; K]; K]>,
}
/// Evaluates every row kernel once (in parallel) and stores the per-row
/// negative log-likelihood, gradient, and Hessian in structure-of-arrays
/// form. Returns the first row-level error if any evaluation fails.
pub fn build_row_kernel_cache<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
) -> Result<RowKernelCache<K>, String> {
    let n = kern.n_rows();
    let p = kern.n_coefficients();
    // Parallel evaluation; collect::<Result<…>> short-circuits on the first Err.
    let rows = (0..n)
        .into_par_iter()
        .map(|row| kern.row_kernel(row))
        .collect::<Result<Vec<_>, String>>()?;
    // Unzip the (nll, gradient, hessian) tuples into separate columns.
    let mut cache = RowKernelCache {
        n,
        p,
        nll: Vec::with_capacity(n),
        gradients: Vec::with_capacity(n),
        hessians: Vec::with_capacity(n),
    };
    for (nll, grad, hess) in rows {
        cache.nll.push(nll);
        cache.gradients.push(grad);
        cache.hessians.push(hess);
    }
    Ok(cache)
}
/// Applies the pulled-back Hessian to a coefficient-space `direction` without
/// materializing it: per row, `direction` is pushed through the Jacobian, the
/// cached K×K Hessian is applied, and the result is pulled back and summed
/// across rows in parallel (Σ Jᵀ H (J·direction)).
pub fn row_kernel_hessian_matvec<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    cache: &RowKernelCache<K>,
    direction: &[f64],
) -> Array1<f64> {
    let p = cache.p;
    let accumulated = (0..cache.n)
        .into_par_iter()
        .fold(
            || vec![0.0_f64; p],
            |mut partial, row| {
                // Map the direction into this row's K-space.
                let local_dir = kern.jacobian_action(row, direction);
                let hess = &cache.hessians[row];
                // hv = H · local_dir, computed row by row of the K×K matrix.
                let mut hv = [0.0_f64; K];
                for (a, slot) in hv.iter_mut().enumerate() {
                    *slot = (0..K).map(|b| hess[a][b] * local_dir[b]).sum();
                }
                // Pull back to coefficient space and accumulate.
                kern.jacobian_transpose_action(row, &hv, &mut partial);
                partial
            },
        )
        .reduce(
            || vec![0.0; p],
            |mut lhs, rhs| {
                lhs.iter_mut().zip(rhs).for_each(|(x, y)| *x += y);
                lhs
            },
        );
    Array1::from_vec(accumulated)
}
/// Accumulates the diagonal of the pulled-back Hessian across all rows in
/// parallel, delegating the per-row quadratic diagonal to the kernel.
pub fn row_kernel_hessian_diagonal<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    cache: &RowKernelCache<K>,
) -> Array1<f64> {
    let p = cache.p;
    let diagonal = (0..cache.n)
        .into_par_iter()
        .fold(
            || vec![0.0_f64; p],
            |mut partial, row| {
                kern.add_diagonal_quadratic(row, &cache.hessians[row], &mut partial);
                partial
            },
        )
        .reduce(
            || vec![0.0; p],
            |mut lhs, rhs| {
                for (x, y) in lhs.iter_mut().zip(rhs) {
                    *x += y;
                }
                lhs
            },
        );
    Array1::from_vec(diagonal)
}
/// Pulls the cached per-row K-gradients back to coefficient space and sums
/// them in parallel: Σ_rows Jᵀ g_row.
pub fn row_kernel_gradient<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    cache: &RowKernelCache<K>,
) -> Array1<f64> {
    let p = cache.p;
    let total = (0..cache.n)
        .into_par_iter()
        .fold(
            || vec![0.0_f64; p],
            |mut partial, row| {
                kern.jacobian_transpose_action(row, &cache.gradients[row], &mut partial);
                partial
            },
        )
        .reduce(
            || vec![0.0; p],
            |mut lhs, rhs| {
                for (x, y) in lhs.iter_mut().zip(rhs) {
                    *x += y;
                }
                lhs
            },
        );
    Array1::from_vec(total)
}
/// Total log-likelihood: the negated sum of the cached per-row negative
/// log-likelihood terms.
pub fn row_kernel_log_likelihood<const K: usize>(cache: &RowKernelCache<K>) -> f64 {
    let total_nll: f64 = cache.nll.iter().copied().sum();
    -total_nll
}
/// Materializes the full p×p pulled-back Hessian by summing per-row pullback
/// contributions in parallel.
pub fn row_kernel_hessian_dense<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    cache: &RowKernelCache<K>,
) -> Array2<f64> {
    let p = cache.p;
    // Each rayon split accumulates into its own p×p matrix; partials are
    // combined by matrix addition.
    let accumulate = |mut partial: Array2<f64>, row: usize| {
        kern.add_pullback_hessian(row, &cache.hessians[row], &mut partial);
        partial
    };
    (0..cache.n)
        .into_par_iter()
        .fold(|| Array2::<f64>::zeros((p, p)), accumulate)
        .reduce(|| Array2::zeros((p, p)), |lhs, rhs| lhs + rhs)
}
/// Dense directional derivative of the Hessian along `d_beta`, assembled from
/// the third-order row contractions: Σ_rows pullback of T_row(J·d_beta).
/// Fails with the first row-level contraction error.
pub fn row_kernel_directional_derivative<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    d_beta: &[f64],
) -> Result<Array2<f64>, String> {
    let p = kern.n_coefficients();
    (0..kern.n_rows())
        .into_par_iter()
        .try_fold(
            || Array2::<f64>::zeros((p, p)),
            |mut partial, row| -> Result<Array2<f64>, String> {
                let local_dir = kern.jacobian_action(row, d_beta);
                let contracted = kern.row_third_contracted(row, &local_dir)?;
                kern.add_pullback_hessian(row, &contracted, &mut partial);
                Ok(partial)
            },
        )
        .try_reduce(|| Array2::zeros((p, p)), |lhs, rhs| Ok(lhs + rhs))
}
/// Dense second directional derivative of the Hessian along `d_beta_u` and
/// `d_beta_v`, assembled from the fourth-order row contractions. Fails with
/// the first row-level contraction error.
pub fn row_kernel_second_directional_derivative<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    d_beta_u: &[f64],
    d_beta_v: &[f64],
) -> Result<Array2<f64>, String> {
    let p = kern.n_coefficients();
    (0..kern.n_rows())
        .into_par_iter()
        .try_fold(
            || Array2::<f64>::zeros((p, p)),
            |mut partial, row| -> Result<Array2<f64>, String> {
                let local_u = kern.jacobian_action(row, d_beta_u);
                let local_v = kern.jacobian_action(row, d_beta_v);
                let contracted = kern.row_fourth_contracted(row, &local_u, &local_v)?;
                kern.add_pullback_hessian(row, &contracted, &mut partial);
                Ok(partial)
            },
        )
        .try_reduce(|| Array2::zeros((p, p)), |lhs, rhs| Ok(lhs + rhs))
}
/// Implicit (matrix-free) operator for the Hessian's directional derivative
/// along a fixed `direction`: its action is evaluated on demand from
/// third-order row contractions instead of materializing the p×p matrix.
struct RowKernelDirectionalDerivativeOperator<const K: usize, T: RowKernel<K>> {
    /// Shared kernel; its Arc address also keys the projected-factor cache.
    kern: Arc<T>,
    /// Fixed direction in coefficient space (length `p`).
    direction: Vec<f64>,
    /// Coefficient-space dimension.
    p: usize,
}
impl<const K: usize, T: RowKernel<K>> HyperOperator
    for RowKernelDirectionalDerivativeOperator<K, T>
{
    /// Operator dimension is the coefficient count `p`.
    fn dim(&self) -> usize {
        self.p
    }
    /// Applies the operator to `v` without forming the dense matrix: per row,
    /// both the fixed direction and `v` are pushed into K-space, the
    /// third-order contraction is applied, and the result is pulled back and
    /// summed across rows in parallel. Panics on non-contiguous input or a
    /// failed row contraction.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        let direction = v
            .as_slice()
            .expect("row-kernel directional derivative operator requires contiguous input");
        let out = (0..self.kern.n_rows())
            .into_par_iter()
            .fold(
                || vec![0.0_f64; self.p],
                |mut acc, row| {
                    // Fixed operator direction mapped to row K-space.
                    let dir_k = self.kern.jacobian_action(row, &self.direction);
                    // Input vector mapped to row K-space.
                    let vec_k = self.kern.jacobian_action(row, direction);
                    let third = self.kern.row_third_contracted(row, &dir_k).expect(
                        "row-kernel third contraction should succeed for validated directions",
                    );
                    // action = third · vec_k in K-space.
                    let mut action = [0.0_f64; K];
                    for a in 0..K {
                        let mut sum = 0.0;
                        for b in 0..K {
                            sum += third[a][b] * vec_k[b];
                        }
                        action[a] = sum;
                    }
                    // Pull back to coefficient space and accumulate.
                    self.kern.jacobian_transpose_action(row, &action, &mut acc);
                    acc
                },
            )
            .reduce(
                || vec![0.0_f64; self.p],
                |mut left, right| {
                    for idx in 0..left.len() {
                        left[idx] += right[idx];
                    }
                    left
                },
            );
        Array1::from_vec(out)
    }
    /// Trace of Fᵀ·Op·F, computed row by row from a freshly built J·F
    /// projection.
    fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
        debug_assert_eq!(factor.nrows(), self.p);
        let rank = factor.ncols();
        let n_rows = self.kern.n_rows();
        // Empty factor or empty data contributes nothing.
        if rank == 0 || n_rows == 0 {
            return 0.0;
        }
        let jf = self.compute_jf(factor);
        self.trace_projected_factor_with_jf(factor, jf.view())
    }
    /// Same as `trace_projected_factor`, but fetches (or stores) the J·F
    /// projection through the shared cache.
    fn trace_projected_factor_cached(
        &self,
        factor: &Array2<f64>,
        cache: &ProjectedFactorCache,
    ) -> f64 {
        debug_assert_eq!(factor.nrows(), self.p);
        let rank = factor.ncols();
        let n_rows = self.kern.n_rows();
        if rank == 0 || n_rows == 0 {
            return 0.0;
        }
        let jf = self.cached_jf(factor, cache);
        self.trace_projected_factor_with_jf(factor, jf.view())
    }
    /// Dense p×p materialization; panics if any row contraction fails.
    fn to_dense(&self) -> Array2<f64> {
        row_kernel_directional_derivative(&*self.kern, &self.direction)
            .expect("row-kernel directional derivative dense materialization should succeed")
    }
    /// This operator is matrix-free.
    fn is_implicit(&self) -> bool {
        true
    }
}
impl<const K: usize, T: RowKernel<K>> RowKernelDirectionalDerivativeOperator<K, T> {
    /// Builds the row-major `n_rows × (K·rank)` matrix of per-row Jacobian
    /// actions J_row · f_col; within a row, component `k` of column `col`
    /// lands at offset `k * rank + col`.
    fn compute_jf(&self, factor: &Array2<f64>) -> Array2<f64> {
        let n_rows = self.kern.n_rows();
        let rank = factor.ncols();
        let stride = K * rank;
        let mut jf = Array2::<f64>::zeros((n_rows, stride));
        if n_rows == 0 || rank == 0 {
            return jf;
        }
        // Transpose to standard layout so each factor column is a contiguous
        // slice that can be fed to jacobian_action.
        let f_t: Array2<f64> = factor.t().as_standard_layout().into_owned();
        jf.as_slice_mut()
            .expect("row-major JF matrix must be contiguous")
            .par_chunks_mut(stride)
            .enumerate()
            .for_each(|(row, jf_row)| {
                for k_col in 0..rank {
                    let f_slice = f_t
                        .row(k_col)
                        .to_slice()
                        .expect("standard-layout row must be contiguous");
                    let vec_k = self.kern.jacobian_action(row, f_slice);
                    for k in 0..K {
                        jf_row[k * rank + k_col] = vec_k[k];
                    }
                }
            });
        jf
    }
    /// Fetches (or computes and stores) the J·F projection for `factor`.
    ///
    /// NOTE(review): the cache key uses the kernel Arc's pointer address as a
    /// design id; if a kernel is dropped and a new one allocated at the same
    /// address, a stale entry could be served — confirm the cache's lifetime
    /// guarantees upstream.
    fn cached_jf(&self, factor: &Array2<f64>, cache: &ProjectedFactorCache) -> Arc<Array2<f64>> {
        let design_id = Arc::as_ptr(&self.kern) as *const () as usize;
        let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
        cache.get_or_insert_with(key, || self.compute_jf(factor))
    }
    /// Accumulates Σ_rows Σ_cols (J f_col)ᵀ · T_row(dir) · (J f_col) from the
    /// precomputed J·F matrix; this equals trace(Fᵀ·Op·F).
    fn trace_projected_factor_with_jf(&self, factor: &Array2<f64>, jf: ArrayView2<'_, f64>) -> f64 {
        let rank = factor.ncols();
        let n_rows = self.kern.n_rows();
        debug_assert_eq!(jf.dim(), (n_rows, K * rank));
        let direction = self.direction.as_slice();
        (0..n_rows)
            .into_par_iter()
            .map(|row| -> f64 {
                let dir_k = self.kern.jacobian_action(row, direction);
                let third = self
                    .kern
                    .row_third_contracted(row, &dir_k)
                    .expect("row-kernel third contraction should succeed for validated directions");
                let jf_row = jf.row(row);
                let jf_slice = jf_row
                    .to_slice()
                    .expect("J·F is built standard-layout (row-major)");
                let mut row_total = 0.0_f64;
                for k_col in 0..rank {
                    // Reassemble the K-vector for this factor column.
                    let mut vec_k = [0.0_f64; K];
                    for k in 0..K {
                        vec_k[k] = jf_slice[k * rank + k_col];
                    }
                    // Quadratic form vec_kᵀ · third · vec_k.
                    let mut quad = 0.0_f64;
                    for a in 0..K {
                        let mut t_dot = 0.0_f64;
                        for b in 0..K {
                            t_dot += third[a][b] * vec_k[b];
                        }
                        quad += vec_k[a] * t_dot;
                    }
                    row_total += quad;
                }
                row_total
            })
            .sum()
    }
}
/// Implicit (matrix-free) operator for the second directional derivative of
/// the Hessian along the fixed pair (`direction_u`, `direction_v`), built
/// from fourth-order row contractions without materializing the p×p matrix.
struct RowKernelSecondDirectionalDerivativeOperator<const K: usize, T: RowKernel<K>> {
    /// Shared kernel; its Arc address also keys the projected-factor cache.
    kern: Arc<T>,
    /// First fixed direction in coefficient space (length `p`).
    direction_u: Vec<f64>,
    /// Second fixed direction in coefficient space (length `p`).
    direction_v: Vec<f64>,
    /// Coefficient-space dimension.
    p: usize,
}
impl<const K: usize, T: RowKernel<K>> HyperOperator
    for RowKernelSecondDirectionalDerivativeOperator<K, T>
{
    /// Operator dimension is the coefficient count `p`.
    fn dim(&self) -> usize {
        self.p
    }
    /// Applies the operator to `v` without forming the dense matrix: per row,
    /// both fixed directions and `v` are pushed into K-space, the fourth-order
    /// contraction is applied, and the result is pulled back and summed across
    /// rows in parallel. Panics on non-contiguous input or a failed row
    /// contraction.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        let direction = v
            .as_slice()
            .expect("row-kernel second directional derivative operator requires contiguous input");
        let out = (0..self.kern.n_rows())
            .into_par_iter()
            .fold(
                || vec![0.0_f64; self.p],
                |mut acc, row| {
                    // Both fixed directions mapped to row K-space.
                    let dir_u = self.kern.jacobian_action(row, &self.direction_u);
                    let dir_v = self.kern.jacobian_action(row, &self.direction_v);
                    // Input vector mapped to row K-space.
                    let vec_k = self.kern.jacobian_action(row, direction);
                    let fourth = self.kern.row_fourth_contracted(row, &dir_u, &dir_v).expect(
                        "row-kernel fourth contraction should succeed for validated directions",
                    );
                    // action = fourth · vec_k in K-space.
                    let mut action = [0.0_f64; K];
                    for a in 0..K {
                        let mut sum = 0.0;
                        for b in 0..K {
                            sum += fourth[a][b] * vec_k[b];
                        }
                        action[a] = sum;
                    }
                    // Pull back to coefficient space and accumulate.
                    self.kern.jacobian_transpose_action(row, &action, &mut acc);
                    acc
                },
            )
            .reduce(
                || vec![0.0_f64; self.p],
                |mut left, right| {
                    for idx in 0..left.len() {
                        left[idx] += right[idx];
                    }
                    left
                },
            );
        Array1::from_vec(out)
    }
    /// Trace of Fᵀ·Op·F, computed row by row from a freshly built J·F
    /// projection.
    fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
        debug_assert_eq!(factor.nrows(), self.p);
        let rank = factor.ncols();
        let n_rows = self.kern.n_rows();
        // Empty factor or empty data contributes nothing.
        if rank == 0 || n_rows == 0 {
            return 0.0;
        }
        let jf = self.compute_jf(factor);
        self.trace_projected_factor_with_jf(factor, jf.view())
    }
    /// Same as `trace_projected_factor`, but fetches (or stores) the J·F
    /// projection through the shared cache.
    fn trace_projected_factor_cached(
        &self,
        factor: &Array2<f64>,
        cache: &ProjectedFactorCache,
    ) -> f64 {
        debug_assert_eq!(factor.nrows(), self.p);
        let rank = factor.ncols();
        let n_rows = self.kern.n_rows();
        if rank == 0 || n_rows == 0 {
            return 0.0;
        }
        let jf = self.cached_jf(factor, cache);
        self.trace_projected_factor_with_jf(factor, jf.view())
    }
    /// Dense p×p materialization; panics if any row contraction fails.
    fn to_dense(&self) -> Array2<f64> {
        row_kernel_second_directional_derivative(&*self.kern, &self.direction_u, &self.direction_v)
            .expect("row-kernel second directional derivative dense materialization should succeed")
    }
    /// This operator is matrix-free.
    fn is_implicit(&self) -> bool {
        true
    }
}
impl<const K: usize, T: RowKernel<K>> RowKernelSecondDirectionalDerivativeOperator<K, T> {
    /// Builds the row-major `n_rows × (K·rank)` matrix of per-row Jacobian
    /// actions J_row · f_col; within a row, component `k` of column `col`
    /// lands at offset `k * rank + col`.
    fn compute_jf(&self, factor: &Array2<f64>) -> Array2<f64> {
        let n_rows = self.kern.n_rows();
        let rank = factor.ncols();
        let stride = K * rank;
        let mut jf = Array2::<f64>::zeros((n_rows, stride));
        if n_rows == 0 || rank == 0 {
            return jf;
        }
        // Transpose to standard layout so each factor column is a contiguous
        // slice that can be fed to jacobian_action.
        let f_t: Array2<f64> = factor.t().as_standard_layout().into_owned();
        jf.as_slice_mut()
            .expect("row-major JF matrix must be contiguous")
            .par_chunks_mut(stride)
            .enumerate()
            .for_each(|(row, jf_row)| {
                for k_col in 0..rank {
                    let f_slice = f_t
                        .row(k_col)
                        .to_slice()
                        .expect("standard-layout row must be contiguous");
                    let vec_k = self.kern.jacobian_action(row, f_slice);
                    for k in 0..K {
                        jf_row[k * rank + k_col] = vec_k[k];
                    }
                }
            });
        jf
    }
    /// Fetches (or computes and stores) the J·F projection for `factor`.
    ///
    /// NOTE(review): the cache key uses the kernel Arc's pointer address as a
    /// design id; if a kernel is dropped and a new one allocated at the same
    /// address, a stale entry could be served — confirm the cache's lifetime
    /// guarantees upstream.
    fn cached_jf(&self, factor: &Array2<f64>, cache: &ProjectedFactorCache) -> Arc<Array2<f64>> {
        let design_id = Arc::as_ptr(&self.kern) as *const () as usize;
        let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
        cache.get_or_insert_with(key, || self.compute_jf(factor))
    }
    /// Accumulates Σ_rows Σ_cols (J f_col)ᵀ · T_row(u,v) · (J f_col) from the
    /// precomputed J·F matrix; this equals trace(Fᵀ·Op·F).
    fn trace_projected_factor_with_jf(&self, factor: &Array2<f64>, jf: ArrayView2<'_, f64>) -> f64 {
        let rank = factor.ncols();
        let n_rows = self.kern.n_rows();
        debug_assert_eq!(jf.dim(), (n_rows, K * rank));
        let direction_u = self.direction_u.as_slice();
        let direction_v = self.direction_v.as_slice();
        (0..n_rows)
            .into_par_iter()
            .map(|row| -> f64 {
                let dir_u = self.kern.jacobian_action(row, direction_u);
                let dir_v = self.kern.jacobian_action(row, direction_v);
                let fourth = self.kern.row_fourth_contracted(row, &dir_u, &dir_v).expect(
                    "row-kernel fourth contraction should succeed for validated directions",
                );
                let jf_row = jf.row(row);
                let jf_slice = jf_row
                    .to_slice()
                    .expect("J·F is built standard-layout (row-major)");
                let mut row_total = 0.0_f64;
                for k_col in 0..rank {
                    // Reassemble the K-vector for this factor column.
                    let mut vec_k = [0.0_f64; K];
                    for k in 0..K {
                        vec_k[k] = jf_slice[k * rank + k_col];
                    }
                    // Quadratic form vec_kᵀ · fourth · vec_k.
                    let mut quad = 0.0_f64;
                    for a in 0..K {
                        let mut t_dot = 0.0_f64;
                        for b in 0..K {
                            t_dot += fourth[a][b] * vec_k[b];
                        }
                        quad += vec_k[a] * t_dot;
                    }
                    row_total += quad;
                }
                row_total
            })
            .sum()
    }
}
/// Workspace pairing a row kernel with its per-row evaluation cache, serving
/// likelihood/gradient/Hessian queries for the exact-Newton solver.
pub struct RowKernelHessianWorkspace<const K: usize, T: RowKernel<K>> {
    /// Shared kernel; cloned into the implicit derivative operators.
    kern: Arc<T>,
    /// Per-row (nll, gradient, Hessian) evaluations captured at construction.
    cache: RowKernelCache<K>,
}
impl<const K: usize, T: RowKernel<K>> RowKernelHessianWorkspace<K, T> {
    /// Builds the workspace: wraps the kernel in an `Arc` (shared later with
    /// the implicit operators) and precomputes the per-row evaluation cache.
    /// Fails if any row kernel evaluation fails.
    pub fn new(kern: T) -> Result<Self, String> {
        let shared = Arc::new(kern);
        build_row_kernel_cache(shared.as_ref()).map(|cache| Self {
            kern: shared,
            cache,
        })
    }
}
impl<const K: usize, T: RowKernel<K> + 'static> ExactNewtonJointHessianWorkspace
    for RowKernelHessianWorkspace<K, T>
{
    /// Forwards warm-up to the kernel's directional-cache hook.
    fn warm_up_outer_caches(&self) -> Result<(), String> {
        self.kern.warm_up_directional_caches()
    }
    /// Log-likelihood is served straight from the precomputed row cache.
    fn joint_log_likelihood_evaluation(&self) -> Result<Option<f64>, String> {
        Ok(Some(row_kernel_log_likelihood(&self.cache)))
    }
    /// Joint (log-likelihood, gradient) pair; the gradient is negated to
    /// match the sign convention of `row_kernel_log_likelihood`.
    fn joint_gradient_evaluation(
        &self,
    ) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
        let log_likelihood = row_kernel_log_likelihood(&self.cache);
        let gradient = -row_kernel_gradient(self.kern.as_ref(), &self.cache);
        Ok(Some(ExactNewtonJointGradientEvaluation {
            log_likelihood,
            gradient,
        }))
    }
    /// Dense Hessian assembled from the cached per-row Hessians.
    fn hessian_dense(&self) -> Result<Option<Array2<f64>>, String> {
        let dense = row_kernel_hessian_dense(self.kern.as_ref(), &self.cache);
        Ok(Some(dense))
    }
    /// Matrix-free Hessian action; requires a contiguous input vector.
    fn hessian_matvec(&self, v: &Array1<f64>) -> Result<Option<Array1<f64>>, String> {
        let direction = v.as_slice().ok_or("hessian_matvec: non-contiguous input")?;
        let product = row_kernel_hessian_matvec(self.kern.as_ref(), &self.cache, direction);
        Ok(Some(product))
    }
    /// Diagonal of the pulled-back Hessian.
    fn hessian_diagonal(&self) -> Result<Option<Array1<f64>>, String> {
        let diagonal = row_kernel_hessian_diagonal(self.kern.as_ref(), &self.cache);
        Ok(Some(diagonal))
    }
    /// Dense directional derivative of the Hessian along `d_beta_flat`.
    fn directional_derivative(
        &self,
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let direction = d_beta_flat
            .as_slice()
            .ok_or("directional_derivative: non-contiguous input")?;
        row_kernel_directional_derivative(self.kern.as_ref(), direction).map(Some)
    }
    /// Implicit directional-derivative operator sharing this kernel.
    fn directional_derivative_operator(
        &self,
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Arc<dyn HyperOperator>>, String> {
        let direction = d_beta_flat
            .as_slice()
            .ok_or("directional_derivative_operator: non-contiguous input")?
            .to_vec();
        let operator = RowKernelDirectionalDerivativeOperator {
            kern: Arc::clone(&self.kern),
            direction,
            p: self.cache.p,
        };
        Ok(Some(Arc::new(operator)))
    }
    /// Dense second directional derivative of the Hessian.
    fn second_directional_derivative(
        &self,
        d_beta_u: &Array1<f64>,
        d_beta_v: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let su = d_beta_u
            .as_slice()
            .ok_or("second_directional_derivative: non-contiguous u")?;
        let sv = d_beta_v
            .as_slice()
            .ok_or("second_directional_derivative: non-contiguous v")?;
        row_kernel_second_directional_derivative(self.kern.as_ref(), su, sv).map(Some)
    }
    /// Implicit second directional-derivative operator sharing this kernel.
    fn second_directional_derivative_operator(
        &self,
        d_beta_u: &Array1<f64>,
        d_beta_v: &Array1<f64>,
    ) -> Result<Option<Arc<dyn HyperOperator>>, String> {
        let direction_u = d_beta_u
            .as_slice()
            .ok_or("second_directional_derivative_operator: non-contiguous u")?
            .to_vec();
        let direction_v = d_beta_v
            .as_slice()
            .ok_or("second_directional_derivative_operator: non-contiguous v")?
            .to_vec();
        let operator = RowKernelSecondDirectionalDerivativeOperator {
            kern: Arc::clone(&self.kern),
            direction_u,
            direction_v,
            p: self.cache.p,
        };
        Ok(Some(Arc::new(operator)))
    }
}