use crate::custom_family::{ExactNewtonJointGradientEvaluation, ExactNewtonJointHessianWorkspace};
use crate::solver::estimate::reml::unified::{
HyperOperator, ProjectedFactorCache, ProjectedFactorKey,
};
use ndarray::{Array1, Array2, ArrayView2};
use rayon::prelude::*;
use std::sync::Arc;
/// Selection of rows that the row-kernel reductions in this file iterate over.
#[derive(Clone)]
pub enum RowSet {
/// Every row `0..n_total`, each visited with weight `1.0`.
All,
/// An explicit list of weighted rows; only the listed rows are visited.
Subsample {
/// Rows to visit; each entry supplies an `index` and a `weight`.
rows: Arc<Vec<crate::families::marginal_slope_shared::WeightedOuterRow>>,
/// Row count stored by `from_options` from its `n_total` argument.
n_full: usize,
},
}
impl RowSet {
    /// Builds the row set requested by `opts`.
    ///
    /// When `opts.outer_score_subsample` is present its row list is shared
    /// (the `Arc` is cloned, not the rows) and `n_total` is recorded as
    /// `n_full`; otherwise every row is selected.
    pub fn from_options(
        opts: &crate::families::custom_family::BlockwiseFitOptions,
        n_total: usize,
    ) -> Self {
        if let Some(subsample) = opts.outer_score_subsample.as_ref() {
            Self::Subsample {
                rows: Arc::clone(&subsample.rows),
                n_full: n_total,
            }
        } else {
            Self::All
        }
    }

    /// Number of listed rows for a subsample, as `f64`.
    ///
    /// Returns `NaN` for [`RowSet::All`], because this method is not given the
    /// total row count.
    pub fn n_effective(&self) -> f64 {
        match self {
            Self::Subsample { rows, .. } => rows.len() as f64,
            Self::All => f64::NAN,
        }
    }

    /// Materializes the selection as `(row index, weight)` pairs, in order.
    pub fn collect_indexed(&self, n_total: usize) -> Vec<(usize, f64)> {
        match self {
            Self::Subsample { rows, .. } => rows
                .iter()
                .map(|entry| (entry.index, entry.weight))
                .collect(),
            Self::All => (0..n_total).zip(std::iter::repeat(1.0_f64)).collect(),
        }
    }

    /// Number of rows that will be visited, given the full row count `n_total`.
    #[inline]
    pub fn len(&self, n_total: usize) -> usize {
        match self {
            Self::Subsample { rows, .. } => rows.len(),
            Self::All => n_total,
        }
    }

    /// Calls `body(index, weight)` for every selected row on the rayon pool.
    #[inline]
    pub fn par_for_each<F>(&self, n_total: usize, body: F)
    where
        F: Fn(usize, f64) + Send + Sync,
    {
        match self {
            Self::Subsample { rows, .. } => rows
                .par_iter()
                .for_each(|entry| body(entry.index, entry.weight)),
            Self::All => (0..n_total)
                .into_par_iter()
                .for_each(|index| body(index, 1.0)),
        }
    }

    /// Fallible variant of [`RowSet::par_for_each`]; returns an `Err` produced
    /// by `body`, if any.
    #[inline]
    pub fn par_try_for_each<F, E>(&self, n_total: usize, body: F) -> Result<(), E>
    where
        F: Fn(usize, f64) -> Result<(), E> + Send + Sync,
        E: Send,
    {
        match self {
            Self::Subsample { rows, .. } => rows
                .par_iter()
                .try_for_each(|entry| body(entry.index, entry.weight)),
            Self::All => (0..n_total)
                .into_par_iter()
                .try_for_each(|index| body(index, 1.0)),
        }
    }

    /// Parallel fold-then-reduce over the selected rows.
    ///
    /// `init` creates a fresh accumulator, `fold` absorbs one
    /// `(index, weight)` pair into an accumulator, and `reduce` merges two
    /// accumulators.
    #[inline]
    pub fn par_reduce_fold<T, I, F, R>(&self, n_total: usize, init: I, fold: F, reduce: R) -> T
    where
        T: Send,
        I: Fn() -> T + Send + Sync,
        F: Fn(T, usize, f64) -> T + Send + Sync,
        R: Fn(T, T) -> T + Send + Sync,
    {
        // The same identity closure seeds both the fold and the reduce stage.
        let identity = &init;
        let merge = &reduce;
        match self {
            Self::Subsample { rows, .. } => rows
                .par_iter()
                .fold(identity, |partial, entry| {
                    fold(partial, entry.index, entry.weight)
                })
                .reduce(identity, merge),
            Self::All => (0..n_total)
                .into_par_iter()
                .fold(identity, |partial, index| fold(partial, index, 1.0))
                .reduce(identity, merge),
        }
    }

    /// Fallible variant of [`RowSet::par_reduce_fold`]: both `fold` and
    /// `reduce` may fail, and an error is returned instead of an accumulator.
    #[inline]
    pub fn par_try_reduce_fold<T, E, I, F, R>(
        &self,
        n_total: usize,
        init: I,
        fold: F,
        reduce: R,
    ) -> Result<T, E>
    where
        T: Send,
        E: Send,
        I: Fn() -> T + Send + Sync,
        F: Fn(T, usize, f64) -> Result<T, E> + Send + Sync,
        R: Fn(T, T) -> Result<T, E> + Send + Sync,
    {
        let identity = &init;
        let merge = &reduce;
        match self {
            Self::Subsample { rows, .. } => rows
                .par_iter()
                .try_fold(identity, |partial, entry| {
                    fold(partial, entry.index, entry.weight)
                })
                .try_reduce(identity, merge),
            Self::All => (0..n_total)
                .into_par_iter()
                .try_fold(identity, |partial, index| fold(partial, index, 1.0))
                .try_reduce(identity, merge),
        }
    }
}
/// Per-row kernel with `K` row-level components, consumed by the `row_kernel_*`
/// reductions in this file.
///
/// Implementations must be `Send + Sync` because rows are evaluated on the
/// rayon thread pool.
pub trait RowKernel<const K: usize>: Send + Sync {
/// Total number of rows; used as the iteration range and the cache length.
fn n_rows(&self) -> usize;
/// Number of coefficients; used as the length of gradient-like accumulators
/// and the side length of the dense matrices built by the reductions.
fn n_coefficients(&self) -> usize;
/// Evaluates one row, returning a scalar, a `K`-vector and a `K x K` block.
/// The cache stores them as `nll`, `gradients` and `hessians` respectively.
fn row_kernel(&self, row: usize) -> Result<(f64, [f64; K], [[f64; K]; K]), String>;
/// Maps a coefficient-space vector `d_beta` to the `K` row-level components
/// for `row`.
fn jacobian_action(&self, row: usize, d_beta: &[f64]) -> [f64; K];
/// Pulls the row-level vector `v` back into coefficient space.
/// NOTE(review): callers here pass a running accumulator as `out`, so
/// implementations are expected to add into it rather than overwrite.
fn jacobian_transpose_action(&self, row: usize, v: &[f64; K], out: &mut [f64]);
/// Adds the coefficient-space contribution of the `K x K` block `h` for `row`
/// into `target` (callers pass a `p x p` accumulator).
fn add_pullback_hessian(&self, row: usize, h: &[[f64; K]; K], target: &mut Array2<f64>);
/// Adds the diagonal of the same contribution into `diag` (callers pass a
/// length-`p` accumulator).
fn add_diagonal_quadratic(&self, row: usize, h: &[[f64; K]; K], diag: &mut [f64]);
/// Row-level `K x K` block for the row-space direction `dir`; callers treat
/// it as the third-order term contracted with one direction.
fn row_third_contracted(&self, row: usize, dir: &[f64; K]) -> Result<[[f64; K]; K], String>;
/// Row-level `K x K` block for the row-space directions `dir_u` and `dir_v`;
/// callers treat it as the fourth-order term contracted with two directions.
fn row_fourth_contracted(
&self,
row: usize,
dir_u: &[f64; K],
dir_v: &[f64; K],
) -> Result<[[f64; K]; K], String>;
/// Optional hook invoked before directional evaluations; the default does
/// nothing and succeeds.
fn warm_up_directional_caches(&self) -> Result<(), String> {
Ok(())
}
}
/// Per-row outputs of `RowKernel::row_kernel`, indexed by row.
///
/// Every vector has length `n`; rows that were not evaluated (outside a
/// subsample) hold zeros.
pub struct RowKernelCache<const K: usize> {
/// Number of rows reported by the kernel.
pub n: usize,
/// Number of coefficients reported by the kernel.
pub p: usize,
/// Scalar output per row; `row_kernel_log_likelihood` returns the negated
/// weighted sum of these values.
pub nll: Vec<f64>,
/// Row-level `K`-vector output per row.
pub gradients: Vec<[f64; K]>,
/// Row-level `K x K` block output per row.
pub hessians: Vec<[[f64; K]; K]>,
}
/// Evaluates `kern.row_kernel` in parallel on every row selected by `rows` and
/// scatters the results into a [`RowKernelCache`] indexed by row.
///
/// The cache always has `kern.n_rows()` entries; rows not selected by a
/// subsample keep their zero initial values. If a subsample lists the same row
/// more than once, the last listed evaluation is the one stored.
///
/// # Errors
///
/// Returns an error if any row evaluation fails, or if a subsample entry
/// references a row index `>= kern.n_rows()` (checked before any evaluation,
/// instead of panicking on the out-of-bounds write).
pub fn build_row_kernel_cache<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    rows: &RowSet,
) -> Result<RowKernelCache<K>, String> {
    let n = kern.n_rows();
    let p = kern.n_coefficients();

    // Both selections produce `(row index, row output)` pairs so a single
    // scatter loop can fill the cache afterwards.
    let evaluated: Vec<(usize, (f64, [f64; K], [[f64; K]; K]))> = match rows {
        RowSet::All => (0..n)
            .into_par_iter()
            .map(|row| kern.row_kernel(row).map(|out| (row, out)))
            .collect::<Result<Vec<_>, String>>()?,
        RowSet::Subsample { rows: list, .. } => {
            if let Some(bad) = list.iter().find(|r| r.index >= n) {
                return Err(format!(
                    "build_row_kernel_cache: subsample row index {} is out of range for {} rows",
                    bad.index, n
                ));
            }
            list.par_iter()
                .map(|r| kern.row_kernel(r.index).map(|out| (r.index, out)))
                .collect::<Result<Vec<_>, String>>()?
        }
    };

    let mut nll = vec![0.0_f64; n];
    let mut gradients = vec![[0.0_f64; K]; n];
    let mut hessians = vec![[[0.0_f64; K]; K]; n];
    for (idx, (l, g, h)) in evaluated {
        nll[idx] = l;
        gradients[idx] = g;
        hessians[idx] = h;
    }

    Ok(RowKernelCache {
        n,
        p,
        nll,
        gradients,
        hessians,
    })
}
/// Hessian-vector product assembled row by row from the cached `K x K` blocks.
///
/// For each selected row the direction is mapped to row space with
/// `jacobian_action`, multiplied by the cached block, scaled by the row
/// weight, and pulled back with `jacobian_transpose_action`. The per-row
/// results are summed into a vector of length `cache.p`.
pub fn row_kernel_hessian_matvec<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    cache: &RowKernelCache<K>,
    rows: &RowSet,
    direction: &[f64],
) -> Array1<f64> {
    let n_coeff = cache.p;

    let absorb_row = |mut partial: Vec<f64>, row: usize, weight: f64| {
        let projected = kern.jacobian_action(row, direction);
        let mut weighted = [0.0_f64; K];
        for (slot, block_row) in weighted.iter_mut().zip(cache.hessians[row].iter()) {
            // Explicit left-to-right accumulation starting from 0.0.
            let dot = block_row
                .iter()
                .zip(projected.iter())
                .fold(0.0_f64, |s, (&h_ab, &d_b)| s + h_ab * d_b);
            *slot = weight * dot;
        }
        kern.jacobian_transpose_action(row, &weighted, &mut partial);
        partial
    };

    let merge = |mut left: Vec<f64>, right: Vec<f64>| {
        for (l, r) in left.iter_mut().zip(right.iter()) {
            *l += *r;
        }
        left
    };

    let summed = rows.par_reduce_fold(cache.n, || vec![0.0_f64; n_coeff], absorb_row, merge);
    Array1::from_vec(summed)
}
/// Diagonal of the weighted Hessian, accumulated through
/// `RowKernel::add_diagonal_quadratic` over the selected rows.
///
/// Rows with weight exactly `1.0` pass the cached block straight through;
/// other rows pass a copy of the block scaled by the weight.
pub fn row_kernel_hessian_diagonal<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    cache: &RowKernelCache<K>,
    rows: &RowSet,
) -> Array1<f64> {
    let n_coeff = cache.p;
    let diagonal = rows.par_reduce_fold(
        cache.n,
        || vec![0.0_f64; n_coeff],
        |mut partial, row, weight| {
            let block = &cache.hessians[row];
            // Storage for the scaled copy; only initialised when it is needed.
            let rescaled: [[f64; K]; K];
            let effective = if weight == 1.0 {
                block
            } else {
                let mut tmp = [[0.0_f64; K]; K];
                for (dst_row, src_row) in tmp.iter_mut().zip(block.iter()) {
                    for (dst, &src) in dst_row.iter_mut().zip(src_row.iter()) {
                        *dst = weight * src;
                    }
                }
                rescaled = tmp;
                &rescaled
            };
            kern.add_diagonal_quadratic(row, effective, &mut partial);
            partial
        },
        |mut left, right| {
            for (l, r) in left.iter_mut().zip(right.iter()) {
                *l += *r;
            }
            left
        },
    );
    Array1::from_vec(diagonal)
}
/// Weighted sum over the selected rows of the cached row-level gradients,
/// pulled back to coefficient space with `jacobian_transpose_action`.
///
/// No sign flip is applied here; the result has length `cache.p`.
pub fn row_kernel_gradient<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    cache: &RowKernelCache<K>,
    rows: &RowSet,
) -> Array1<f64> {
    let n_coeff = cache.p;
    let summed = rows.par_reduce_fold(
        cache.n,
        || vec![0.0_f64; n_coeff],
        |mut partial, row, weight| {
            let row_gradient = &cache.gradients[row];
            if weight != 1.0 {
                // Scale a private copy so the cached values stay untouched.
                let mut rescaled = [0.0_f64; K];
                for (dst, &src) in rescaled.iter_mut().zip(row_gradient.iter()) {
                    *dst = weight * src;
                }
                kern.jacobian_transpose_action(row, &rescaled, &mut partial);
            } else {
                kern.jacobian_transpose_action(row, row_gradient, &mut partial);
            }
            partial
        },
        |mut left, right| {
            for (l, r) in left.iter_mut().zip(right.iter()) {
                *l += *r;
            }
            left
        },
    );
    Array1::from_vec(summed)
}
/// Log-likelihood over the selected rows: the negated weighted sum of the
/// cached per-row `nll` values.
pub fn row_kernel_log_likelihood<const K: usize>(cache: &RowKernelCache<K>, rows: &RowSet) -> f64 {
    let add_row = |partial: f64, row: usize, weight: f64| partial + weight * cache.nll[row];
    let weighted_nll: f64 =
        rows.par_reduce_fold(cache.n, || 0.0_f64, add_row, |lhs, rhs| lhs + rhs);
    -weighted_nll
}
/// Dense `p x p` weighted Hessian, accumulated through
/// `RowKernel::add_pullback_hessian` over the selected rows.
///
/// Rows with weight exactly `1.0` use the cached block as is; other rows use
/// a copy scaled by the weight.
pub fn row_kernel_hessian_dense<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    cache: &RowKernelCache<K>,
    rows: &RowSet,
) -> Array2<f64> {
    let n_coeff = cache.p;
    rows.par_reduce_fold(
        cache.n,
        || Array2::<f64>::zeros((n_coeff, n_coeff)),
        |mut partial, row, weight| {
            let block = &cache.hessians[row];
            if weight != 1.0 {
                let mut rescaled = [[0.0_f64; K]; K];
                for (dst_row, src_row) in rescaled.iter_mut().zip(block.iter()) {
                    for (dst, &src) in dst_row.iter_mut().zip(src_row.iter()) {
                        *dst = weight * src;
                    }
                }
                kern.add_pullback_hessian(row, &rescaled, &mut partial);
            } else {
                kern.add_pullback_hessian(row, block, &mut partial);
            }
            partial
        },
        |left, right| left + right,
    )
}
/// Dense `p x p` matrix built from the third-order row terms contracted with
/// the coefficient direction `d_beta`.
///
/// For each selected row, `d_beta` is mapped to row space, the kernel supplies
/// the contracted `K x K` block, and the (weight-scaled) block is added via
/// `add_pullback_hessian`.
///
/// # Errors
///
/// Propagates any error from `RowKernel::row_third_contracted`.
pub fn row_kernel_directional_derivative<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    rows: &RowSet,
    d_beta: &[f64],
) -> Result<Array2<f64>, String> {
    let n_rows = kern.n_rows();
    let n_coeff = kern.n_coefficients();
    rows.par_try_reduce_fold(
        n_rows,
        || Array2::<f64>::zeros((n_coeff, n_coeff)),
        |mut partial, row, weight| -> Result<_, String> {
            let row_dir = kern.jacobian_action(row, d_beta);
            let block = kern.row_third_contracted(row, &row_dir)?;
            let rescaled: [[f64; K]; K];
            let effective = if weight == 1.0 {
                &block
            } else {
                let mut tmp = [[0.0_f64; K]; K];
                for (dst_row, src_row) in tmp.iter_mut().zip(block.iter()) {
                    for (dst, &src) in dst_row.iter_mut().zip(src_row.iter()) {
                        *dst = weight * src;
                    }
                }
                rescaled = tmp;
                &rescaled
            };
            kern.add_pullback_hessian(row, effective, &mut partial);
            Ok(partial)
        },
        |left, right| Ok(left + right),
    )
}
/// Dense `p x p` matrix built from the fourth-order row terms contracted with
/// the coefficient directions `d_beta_u` and `d_beta_v`.
///
/// Both directions are mapped to row space per row, the kernel supplies the
/// contracted `K x K` block, and the (weight-scaled) block is added via
/// `add_pullback_hessian`.
///
/// # Errors
///
/// Propagates any error from `RowKernel::row_fourth_contracted`.
pub fn row_kernel_second_directional_derivative<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    rows: &RowSet,
    d_beta_u: &[f64],
    d_beta_v: &[f64],
) -> Result<Array2<f64>, String> {
    let n_rows = kern.n_rows();
    let n_coeff = kern.n_coefficients();
    rows.par_try_reduce_fold(
        n_rows,
        || Array2::<f64>::zeros((n_coeff, n_coeff)),
        |mut partial, row, weight| -> Result<_, String> {
            let row_u = kern.jacobian_action(row, d_beta_u);
            let row_v = kern.jacobian_action(row, d_beta_v);
            let block = kern.row_fourth_contracted(row, &row_u, &row_v)?;
            if weight != 1.0 {
                let mut rescaled = [[0.0_f64; K]; K];
                for (dst_row, src_row) in rescaled.iter_mut().zip(block.iter()) {
                    for (dst, &src) in dst_row.iter_mut().zip(src_row.iter()) {
                        *dst = weight * src;
                    }
                }
                kern.add_pullback_hessian(row, &rescaled, &mut partial);
            } else {
                kern.add_pullback_hessian(row, &block, &mut partial);
            }
            Ok(partial)
        },
        |left, right| Ok(left + right),
    )
}
/// Matrix-free operator for the directional derivative along one fixed
/// coefficient direction, evaluated row by row through a shared kernel.
struct RowKernelDirectionalDerivativeOperator<const K: usize, T: RowKernel<K>> {
/// Shared kernel supplying the Jacobian actions and third-order row terms.
kern: Arc<T>,
/// Fixed coefficient-space direction the derivative is taken along.
direction: Vec<f64>,
/// Operator dimension, reported by `dim()`.
p: usize,
/// Rows (and weights) the reductions iterate over.
rows: RowSet,
}
impl<const K: usize, T: RowKernel<K>> HyperOperator
    for RowKernelDirectionalDerivativeOperator<K, T>
{
    /// Dimension of the (square) operator.
    fn dim(&self) -> usize {
        self.p
    }

    /// Applies the operator to `v` without forming the dense matrix.
    ///
    /// A non-contiguous `v` is copied into a contiguous buffer first; this
    /// method cannot report an error, so it no longer panics on such input.
    ///
    /// # Panics
    ///
    /// Panics if the kernel's third-order contraction fails for a row.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        // Borrow the data when it is contiguous; otherwise gather it in
        // logical order.
        let gathered: Vec<f64>;
        let direction: &[f64] = match v.as_slice() {
            Some(contiguous) => contiguous,
            None => {
                gathered = v.iter().copied().collect();
                &gathered
            }
        };
        let out = self.rows.par_reduce_fold(
            self.kern.n_rows(),
            || vec![0.0_f64; self.p],
            |mut acc, row, w| {
                let dir_k = self.kern.jacobian_action(row, &self.direction);
                let vec_k = self.kern.jacobian_action(row, direction);
                let third = self
                    .kern
                    .row_third_contracted(row, &dir_k)
                    .expect("row-kernel third contraction should succeed for validated directions");
                // action = w * (third · vec_k), then pulled back into `acc`.
                let mut action = [0.0_f64; K];
                for a in 0..K {
                    let mut sum = 0.0;
                    for b in 0..K {
                        sum += third[a][b] * vec_k[b];
                    }
                    action[a] = w * sum;
                }
                self.kern.jacobian_transpose_action(row, &action, &mut acc);
                acc
            },
            |mut left, right| {
                for (l, r) in left.iter_mut().zip(right.iter()) {
                    *l += *r;
                }
                left
            },
        );
        Array1::from_vec(out)
    }

    /// Trace against `factor`, recomputing the row-space image of the factor.
    /// Returns `0.0` when the factor has no columns or the kernel has no rows.
    fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
        debug_assert_eq!(factor.nrows(), self.p);
        let rank = factor.ncols();
        let n_rows = self.kern.n_rows();
        if rank == 0 || n_rows == 0 {
            return 0.0;
        }
        let jf = self.compute_jf(factor);
        self.trace_projected_factor_with_jf(factor, jf.view())
    }

    /// Same as `trace_projected_factor`, but the row-space image of the factor
    /// is fetched from (or inserted into) `cache`.
    fn trace_projected_factor_cached(
        &self,
        factor: &Array2<f64>,
        cache: &ProjectedFactorCache,
    ) -> f64 {
        debug_assert_eq!(factor.nrows(), self.p);
        let rank = factor.ncols();
        let n_rows = self.kern.n_rows();
        if rank == 0 || n_rows == 0 {
            return 0.0;
        }
        let jf = self.cached_jf(factor, cache);
        self.trace_projected_factor_with_jf(factor, jf.view())
    }

    /// Materializes the operator as a dense matrix.
    ///
    /// # Panics
    ///
    /// Panics if the dense directional derivative fails.
    fn to_dense(&self) -> Array2<f64> {
        row_kernel_directional_derivative(&*self.kern, &self.rows, &self.direction)
            .expect("row-kernel directional derivative dense materialization should succeed")
    }

    /// Always `true`.
    fn is_implicit(&self) -> bool {
        true
    }
}
impl<const K: usize, T: RowKernel<K>> RowKernelDirectionalDerivativeOperator<K, T> {
    /// Builds the `n_rows x (K * rank)` matrix holding, for every row, the
    /// Jacobian action on each factor column.
    ///
    /// Layout within a row: entry `k * rank + col` is component `k` of the
    /// action on factor column `col`. Every kernel row is filled, not just
    /// the rows in `self.rows`.
    fn compute_jf(&self, factor: &Array2<f64>) -> Array2<f64> {
        let n_rows = self.kern.n_rows();
        let rank = factor.ncols();
        let width = K * rank;
        let mut projected = Array2::<f64>::zeros((n_rows, width));
        if n_rows == 0 || rank == 0 {
            return projected;
        }
        // Transposed copy so each factor column is a contiguous slice.
        let columns: Array2<f64> = factor.t().as_standard_layout().into_owned();
        projected
            .as_slice_mut()
            .expect("row-major JF matrix must be contiguous")
            .par_chunks_mut(width)
            .enumerate()
            .for_each(|(row, out_row)| {
                for col in 0..rank {
                    let column = columns
                        .row(col)
                        .to_slice()
                        .expect("standard-layout row must be contiguous");
                    let image = self.kern.jacobian_action(row, column);
                    for (k, &value) in image.iter().enumerate() {
                        out_row[k * rank + col] = value;
                    }
                }
            });
        projected
    }

    /// Returns the matrix from `compute_jf`, memoised in `cache` under a key
    /// derived from the kernel's `Arc` pointer and the factor contents.
    fn cached_jf(&self, factor: &Array2<f64>, cache: &ProjectedFactorCache) -> Arc<Array2<f64>> {
        let kernel_id = Arc::as_ptr(&self.kern) as *const () as usize;
        cache.get_or_insert_with(
            ProjectedFactorKey::from_factor_view(kernel_id, factor.view()),
            || self.compute_jf(factor),
        )
    }

    /// Weighted sum over the selected rows of the Frobenius inner product
    /// between the third-order block and the `K x K` Gram matrix of that
    /// row's slice of `jf`.
    ///
    /// # Panics
    ///
    /// Panics if the kernel's third-order contraction fails or if a row of
    /// `jf` is not contiguous.
    fn trace_projected_factor_with_jf(&self, factor: &Array2<f64>, jf: ArrayView2<'_, f64>) -> f64 {
        let rank = factor.ncols();
        let n_rows = self.kern.n_rows();
        debug_assert_eq!(jf.dim(), (n_rows, K * rank));
        let direction = self.direction.as_slice();
        self.rows.par_reduce_fold(
            n_rows,
            || 0.0_f64,
            |partial, row, weight| {
                let row_dir = self.kern.jacobian_action(row, direction);
                let block = self
                    .kern
                    .row_third_contracted(row, &row_dir)
                    .expect("row-kernel third contraction should succeed for validated directions");
                let jf_row = jf.row(row);
                let flat = jf_row
                    .to_slice()
                    .expect("J·F is built standard-layout (row-major)");

                // Symmetric Gram matrix: only the upper triangle is computed.
                let mut gram = [[0.0_f64; K]; K];
                for a in 0..K {
                    let lhs = &flat[a * rank..(a + 1) * rank];
                    for b in a..K {
                        let rhs = &flat[b * rank..(b + 1) * rank];
                        let dot = lhs
                            .iter()
                            .zip(rhs.iter())
                            .fold(0.0_f64, |s, (&x, &y)| s + x * y);
                        gram[a][b] = dot;
                        gram[b][a] = dot;
                    }
                }

                // Accumulate in row-major (a, then b) order.
                let contraction = block.iter().zip(gram.iter()).fold(
                    0.0_f64,
                    |outer, (block_row, gram_row)| {
                        block_row
                            .iter()
                            .zip(gram_row.iter())
                            .fold(outer, |inner, (&t_ab, &g_ab)| inner + t_ab * g_ab)
                    },
                );
                partial + weight * contraction
            },
            |lhs, rhs| lhs + rhs,
        )
    }
}
/// Matrix-free operator for the second directional derivative along two fixed
/// coefficient directions, evaluated row by row through a shared kernel.
struct RowKernelSecondDirectionalDerivativeOperator<const K: usize, T: RowKernel<K>> {
/// Shared kernel supplying the Jacobian actions and fourth-order row terms.
kern: Arc<T>,
/// First fixed coefficient-space direction.
direction_u: Vec<f64>,
/// Second fixed coefficient-space direction.
direction_v: Vec<f64>,
/// Operator dimension, reported by `dim()`.
p: usize,
/// Rows (and weights) the reductions iterate over.
rows: RowSet,
}
impl<const K: usize, T: RowKernel<K>> HyperOperator
    for RowKernelSecondDirectionalDerivativeOperator<K, T>
{
    /// Dimension of the (square) operator.
    fn dim(&self) -> usize {
        self.p
    }

    /// Applies the operator to `v` without forming the dense matrix.
    ///
    /// A non-contiguous `v` is copied into a contiguous buffer first; this
    /// method cannot report an error, so it no longer panics on such input.
    ///
    /// # Panics
    ///
    /// Panics if the kernel's fourth-order contraction fails for a row.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        // Borrow the data when it is contiguous; otherwise gather it in
        // logical order.
        let gathered: Vec<f64>;
        let direction: &[f64] = match v.as_slice() {
            Some(contiguous) => contiguous,
            None => {
                gathered = v.iter().copied().collect();
                &gathered
            }
        };
        let out = self.rows.par_reduce_fold(
            self.kern.n_rows(),
            || vec![0.0_f64; self.p],
            |mut acc, row, w| {
                let dir_u = self.kern.jacobian_action(row, &self.direction_u);
                let dir_v = self.kern.jacobian_action(row, &self.direction_v);
                let vec_k = self.kern.jacobian_action(row, direction);
                let fourth = self.kern.row_fourth_contracted(row, &dir_u, &dir_v).expect(
                    "row-kernel fourth contraction should succeed for validated directions",
                );
                // action = w * (fourth · vec_k), then pulled back into `acc`.
                let mut action = [0.0_f64; K];
                for a in 0..K {
                    let mut sum = 0.0;
                    for b in 0..K {
                        sum += fourth[a][b] * vec_k[b];
                    }
                    action[a] = w * sum;
                }
                self.kern.jacobian_transpose_action(row, &action, &mut acc);
                acc
            },
            |mut left, right| {
                for (l, r) in left.iter_mut().zip(right.iter()) {
                    *l += *r;
                }
                left
            },
        );
        Array1::from_vec(out)
    }

    /// Trace against `factor`, recomputing the row-space image of the factor.
    /// Returns `0.0` when the factor has no columns or the kernel has no rows.
    fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
        debug_assert_eq!(factor.nrows(), self.p);
        let rank = factor.ncols();
        let n_rows = self.kern.n_rows();
        if rank == 0 || n_rows == 0 {
            return 0.0;
        }
        let jf = self.compute_jf(factor);
        self.trace_projected_factor_with_jf(factor, jf.view())
    }

    /// Same as `trace_projected_factor`, but the row-space image of the factor
    /// is fetched from (or inserted into) `cache`.
    fn trace_projected_factor_cached(
        &self,
        factor: &Array2<f64>,
        cache: &ProjectedFactorCache,
    ) -> f64 {
        debug_assert_eq!(factor.nrows(), self.p);
        let rank = factor.ncols();
        let n_rows = self.kern.n_rows();
        if rank == 0 || n_rows == 0 {
            return 0.0;
        }
        let jf = self.cached_jf(factor, cache);
        self.trace_projected_factor_with_jf(factor, jf.view())
    }

    /// Materializes the operator as a dense matrix.
    ///
    /// # Panics
    ///
    /// Panics if the dense second directional derivative fails.
    fn to_dense(&self) -> Array2<f64> {
        row_kernel_second_directional_derivative(
            &*self.kern,
            &self.rows,
            &self.direction_u,
            &self.direction_v,
        )
        .expect("row-kernel second directional derivative dense materialization should succeed")
    }

    /// Always `true`.
    fn is_implicit(&self) -> bool {
        true
    }
}
impl<const K: usize, T: RowKernel<K>> RowKernelSecondDirectionalDerivativeOperator<K, T> {
    /// Builds the `n_rows x (K * rank)` matrix holding, for every row, the
    /// Jacobian action on each factor column.
    ///
    /// Layout within a row: entry `k * rank + col` is component `k` of the
    /// action on factor column `col`. Every kernel row is filled, not just
    /// the rows in `self.rows`.
    fn compute_jf(&self, factor: &Array2<f64>) -> Array2<f64> {
        let row_count = self.kern.n_rows();
        let rank = factor.ncols();
        let row_len = K * rank;
        let mut image = Array2::<f64>::zeros((row_count, row_len));
        if row_count == 0 || rank == 0 {
            return image;
        }
        // Transposed copy so each factor column is a contiguous slice.
        let factor_t: Array2<f64> = factor.t().as_standard_layout().into_owned();
        let fill_row = |(row, target): (usize, &mut [f64])| {
            for col in 0..rank {
                let column = factor_t
                    .row(col)
                    .to_slice()
                    .expect("standard-layout row must be contiguous");
                let components = self.kern.jacobian_action(row, column);
                for (k, &component) in components.iter().enumerate() {
                    target[k * rank + col] = component;
                }
            }
        };
        image
            .as_slice_mut()
            .expect("row-major JF matrix must be contiguous")
            .par_chunks_mut(row_len)
            .enumerate()
            .for_each(fill_row);
        image
    }

    /// Returns the matrix from `compute_jf`, memoised in `cache` under a key
    /// derived from the kernel's `Arc` pointer and the factor contents.
    fn cached_jf(&self, factor: &Array2<f64>, cache: &ProjectedFactorCache) -> Arc<Array2<f64>> {
        let kernel_id = Arc::as_ptr(&self.kern) as *const () as usize;
        let lookup = ProjectedFactorKey::from_factor_view(kernel_id, factor.view());
        cache.get_or_insert_with(lookup, || self.compute_jf(factor))
    }

    /// Weighted sum over the selected rows of the Frobenius inner product
    /// between the fourth-order block and the `K x K` Gram matrix of that
    /// row's slice of `jf`.
    ///
    /// # Panics
    ///
    /// Panics if the kernel's fourth-order contraction fails or if a row of
    /// `jf` is not contiguous.
    fn trace_projected_factor_with_jf(&self, factor: &Array2<f64>, jf: ArrayView2<'_, f64>) -> f64 {
        let rank = factor.ncols();
        let n_rows = self.kern.n_rows();
        debug_assert_eq!(jf.dim(), (n_rows, K * rank));
        let direction_u = self.direction_u.as_slice();
        let direction_v = self.direction_v.as_slice();
        self.rows.par_reduce_fold(
            n_rows,
            || 0.0_f64,
            |partial, row, weight| {
                let row_u = self.kern.jacobian_action(row, direction_u);
                let row_v = self.kern.jacobian_action(row, direction_v);
                let block = self.kern.row_fourth_contracted(row, &row_u, &row_v).expect(
                    "row-kernel fourth contraction should succeed for validated directions",
                );
                let jf_row = jf.row(row);
                let flat = jf_row
                    .to_slice()
                    .expect("J·F is built standard-layout (row-major)");

                // Symmetric Gram matrix: only the upper triangle is computed.
                let mut gram = [[0.0_f64; K]; K];
                for a in 0..K {
                    let lhs = &flat[a * rank..(a + 1) * rank];
                    for b in a..K {
                        let rhs = &flat[b * rank..(b + 1) * rank];
                        let dot = lhs
                            .iter()
                            .zip(rhs.iter())
                            .fold(0.0_f64, |s, (&x, &y)| s + x * y);
                        gram[a][b] = dot;
                        gram[b][a] = dot;
                    }
                }

                // Accumulate in row-major (a, then b) order.
                let mut contraction = 0.0_f64;
                for (block_row, gram_row) in block.iter().zip(gram.iter()) {
                    for (&f_ab, &g_ab) in block_row.iter().zip(gram_row.iter()) {
                        contraction += f_ab * g_ab;
                    }
                }
                partial + weight * contraction
            },
            |lhs, rhs| lhs + rhs,
        )
    }
}
/// Workspace bundling a shared row kernel, its per-row cache, and the row
/// selection used by every reduction.
pub struct RowKernelHessianWorkspace<const K: usize, T: RowKernel<K>> {
/// Kernel shared with the operators handed out by this workspace.
kern: Arc<T>,
/// Per-row outputs computed once at construction.
cache: RowKernelCache<K>,
/// Rows (and weights) the reductions iterate over.
rows: RowSet,
}
impl<const K: usize, T: RowKernel<K>> RowKernelHessianWorkspace<K, T> {
    /// Creates a workspace over every row of `kern`.
    ///
    /// # Errors
    ///
    /// Returns the error produced while building the per-row cache.
    pub fn new(kern: T) -> Result<Self, String> {
        Self::with_rows(kern, RowSet::All)
    }

    /// Creates a workspace restricted to `rows`, evaluating the per-row cache
    /// eagerly.
    ///
    /// # Errors
    ///
    /// Returns the error produced while building the per-row cache.
    pub fn with_rows(kern: T, rows: RowSet) -> Result<Self, String> {
        let shared = Arc::new(kern);
        build_row_kernel_cache(&*shared, &rows).map(|cache| Self {
            kern: shared,
            cache,
            rows,
        })
    }
}
impl<const K: usize, T: RowKernel<K> + 'static> ExactNewtonJointHessianWorkspace
    for RowKernelHessianWorkspace<K, T>
{
    /// Delegates to the kernel's `warm_up_directional_caches` hook.
    fn warm_up_outer_caches(&self) -> Result<(), String> {
        self.kern.warm_up_directional_caches()
    }

    /// Log-likelihood over the selected rows, computed from the cache.
    fn joint_log_likelihood_evaluation(&self) -> Result<Option<f64>, String> {
        let log_likelihood = row_kernel_log_likelihood(&self.cache, &self.rows);
        Ok(Some(log_likelihood))
    }

    /// Log-likelihood together with the negated result of
    /// `row_kernel_gradient`.
    fn joint_gradient_evaluation(
        &self,
    ) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
        let log_likelihood = row_kernel_log_likelihood(&self.cache, &self.rows);
        let gradient = -row_kernel_gradient(&*self.kern, &self.cache, &self.rows);
        Ok(Some(ExactNewtonJointGradientEvaluation {
            log_likelihood,
            gradient,
        }))
    }

    /// Dense Hessian over the selected rows.
    fn hessian_dense(&self) -> Result<Option<Array2<f64>>, String> {
        let dense = row_kernel_hessian_dense(&*self.kern, &self.cache, &self.rows);
        Ok(Some(dense))
    }

    /// Hessian-vector product; fails if `v` is not contiguous.
    fn hessian_matvec(&self, v: &Array1<f64>) -> Result<Option<Array1<f64>>, String> {
        match v.as_slice() {
            Some(sl) => Ok(Some(row_kernel_hessian_matvec(
                &*self.kern,
                &self.cache,
                &self.rows,
                sl,
            ))),
            None => Err(String::from("hessian_matvec: non-contiguous input")),
        }
    }

    /// Diagonal of the Hessian over the selected rows.
    fn hessian_diagonal(&self) -> Result<Option<Array1<f64>>, String> {
        let diagonal = row_kernel_hessian_diagonal(&*self.kern, &self.cache, &self.rows);
        Ok(Some(diagonal))
    }

    /// Dense directional derivative along `d_beta_flat`; fails if the
    /// direction is not contiguous or the kernel reports an error.
    fn directional_derivative(
        &self,
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        match d_beta_flat.as_slice() {
            Some(sl) => row_kernel_directional_derivative(&*self.kern, &self.rows, sl).map(Some),
            None => Err(String::from(
                "directional_derivative: non-contiguous input",
            )),
        }
    }

    /// Matrix-free operator for the directional derivative along
    /// `d_beta_flat`; the direction is copied into the operator.
    fn directional_derivative_operator(
        &self,
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Arc<dyn HyperOperator>>, String> {
        let direction = match d_beta_flat.as_slice() {
            Some(sl) => sl.to_vec(),
            None => {
                return Err(String::from(
                    "directional_derivative_operator: non-contiguous input",
                ))
            }
        };
        let operator: Arc<dyn HyperOperator> = Arc::new(RowKernelDirectionalDerivativeOperator {
            kern: Arc::clone(&self.kern),
            direction,
            p: self.cache.p,
            rows: self.rows.clone(),
        });
        Ok(Some(operator))
    }

    /// Dense second directional derivative along `d_beta_u` and `d_beta_v`;
    /// fails if either direction is not contiguous or the kernel reports an
    /// error.
    fn second_directional_derivative(
        &self,
        d_beta_u: &Array1<f64>,
        d_beta_v: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let su = match d_beta_u.as_slice() {
            Some(sl) => sl,
            None => {
                return Err(String::from(
                    "second_directional_derivative: non-contiguous u",
                ))
            }
        };
        let sv = match d_beta_v.as_slice() {
            Some(sl) => sl,
            None => {
                return Err(String::from(
                    "second_directional_derivative: non-contiguous v",
                ))
            }
        };
        row_kernel_second_directional_derivative(&*self.kern, &self.rows, su, sv).map(Some)
    }

    /// Matrix-free operator for the second directional derivative; both
    /// directions are copied into the operator.
    fn second_directional_derivative_operator(
        &self,
        d_beta_u: &Array1<f64>,
        d_beta_v: &Array1<f64>,
    ) -> Result<Option<Arc<dyn HyperOperator>>, String> {
        let direction_u = match d_beta_u.as_slice() {
            Some(sl) => sl.to_vec(),
            None => {
                return Err(String::from(
                    "second_directional_derivative_operator: non-contiguous u",
                ))
            }
        };
        let direction_v = match d_beta_v.as_slice() {
            Some(sl) => sl.to_vec(),
            None => {
                return Err(String::from(
                    "second_directional_derivative_operator: non-contiguous v",
                ))
            }
        };
        let operator: Arc<dyn HyperOperator> =
            Arc::new(RowKernelSecondDirectionalDerivativeOperator {
                kern: Arc::clone(&self.kern),
                direction_u,
                direction_v,
                p: self.cache.p,
                rows: self.rows.clone(),
            });
        Ok(Some(operator))
    }
}
#[cfg(test)]
mod gram_inner_contraction_tests {
use super::*;
use crate::solver::estimate::reml::unified::ProjectedFactorCache;
use ndarray::Array2;
/// Test kernel with `K = 4`: each row-level component has its own `n x p`
/// pseudo-random design matrix.
struct SyntheticKernel {
/// Number of rows.
n: usize,
/// Number of coefficients.
p: usize,
/// One `n x p` design matrix per row-level component.
designs: [Array2<f64>; 4],
}
impl SyntheticKernel {
    /// Builds a kernel whose four `n x p` designs are filled, in order, from
    /// a deterministic generator seeded with `seed`.
    fn new(n: usize, p: usize, seed: u64) -> Self {
        let mut state = seed;
        // Linear congruential step mapped into roughly [-0.5, 0.5].
        let mut draw = || -> f64 {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            ((state >> 33) as f64 / (u32::MAX as f64)) - 0.5
        };
        let mut random_design =
            || -> Array2<f64> { Array2::from_shape_fn((n, p), |_| draw()) };
        // Array elements are evaluated left to right, so the draw order is
        // design 0, 1, 2, 3.
        let designs = [
            random_design(),
            random_design(),
            random_design(),
            random_design(),
        ];
        Self { n, p, designs }
    }
}
impl RowKernel<4> for SyntheticKernel {
    fn n_rows(&self) -> usize {
        self.n
    }

    fn n_coefficients(&self) -> usize {
        self.p
    }

    /// Returns all-zero outputs; the trace test never reads them.
    fn row_kernel(&self, _row: usize) -> Result<(f64, [f64; 4], [[f64; 4]; 4]), String> {
        Ok((0.0, [0.0; 4], [[0.0; 4]; 4]))
    }

    /// Component `k` is the dot product of row `row` of design `k` with
    /// `d_beta`.
    fn jacobian_action(&self, row: usize, d_beta: &[f64]) -> [f64; 4] {
        let mut out = [0.0_f64; 4];
        for (slot, design) in out.iter_mut().zip(self.designs.iter()) {
            let design_row = design.row(row);
            *slot = (0..self.p).fold(0.0_f64, |s, j| s + design_row[j] * d_beta[j]);
        }
        out
    }

    /// Adds `sum_k design_k[row, j] * v[k]` into `out[j]`.
    fn jacobian_transpose_action(&self, row: usize, v: &[f64; 4], out: &mut [f64]) {
        for (design, &component) in self.designs.iter().zip(v.iter()) {
            let design_row = design.row(row);
            for j in 0..self.p {
                out[j] += design_row[j] * component;
            }
        }
    }

    fn add_pullback_hessian(&self, _row: usize, _h: &[[f64; 4]; 4], _target: &mut Array2<f64>) {
        unreachable!("not used in this regression test")
    }

    fn add_diagonal_quadratic(&self, _row: usize, _h: &[[f64; 4]; 4], _diag: &mut [f64]) {
        unreachable!("not used in this regression test")
    }

    /// Symmetric synthetic block depending on the row index and on `dir`.
    fn row_third_contracted(
        &self,
        row: usize,
        dir: &[f64; 4],
    ) -> Result<[[f64; 4]; 4], String> {
        let phase = (row as f64) * 0.013;
        let mut block = [[0.0_f64; 4]; 4];
        for a in 0..4 {
            for b in a..4 {
                let base = (phase + a as f64 * 0.7 + b as f64 * 1.3).sin();
                let value = base + dir[a] * 0.25 + dir[b] * 0.5 + dir[(a + b) % 4] * 0.125;
                block[a][b] = value;
                block[b][a] = value;
            }
        }
        Ok(block)
    }

    /// Symmetric synthetic block depending on the row index and on both
    /// directions.
    fn row_fourth_contracted(
        &self,
        row: usize,
        dir_u: &[f64; 4],
        dir_v: &[f64; 4],
    ) -> Result<[[f64; 4]; 4], String> {
        let phase = (row as f64) * 0.011 + 0.31;
        let mut block = [[0.0_f64; 4]; 4];
        for a in 0..4 {
            for b in a..4 {
                let base = (phase + a as f64 * 0.9 + b as f64 * 1.7).cos();
                let value = base
                    + dir_u[a] * 0.13
                    + dir_v[b] * 0.27
                    + dir_u[(a + b) % 4] * dir_v[(a + 1) % 4] * 0.05;
                block[a][b] = value;
                block[b][a] = value;
            }
        }
        Ok(block)
    }
}
/// Straightforward reference for the first-derivative trace: for every row and
/// every factor column, evaluates the quadratic form of the third-order block
/// in the row-space image of that column, and sums the results with unit
/// weights over all rows.
///
/// The previous fixed 16-entry scratch copy of the row-space direction was
/// never read and would have indexed out of bounds for `K > 16`; it has been
/// removed.
fn reference_trace_first<const K: usize>(
    kern: &impl RowKernel<K>,
    direction: &[f64],
    factor: &Array2<f64>,
) -> f64 {
    let n_rows = kern.n_rows();
    let rank = factor.ncols();
    let p = factor.nrows();
    let mut total = 0.0_f64;
    for row in 0..n_rows {
        let dir_k = kern.jacobian_action(row, direction);
        let third = kern.row_third_contracted(row, &dir_k).expect("third");
        for k_col in 0..rank {
            // Copy the factor column into a contiguous buffer.
            let col: Vec<f64> = (0..p).map(|j| factor[[j, k_col]]).collect();
            let vec_k = kern.jacobian_action(row, &col);
            let mut quad = 0.0_f64;
            for a in 0..K {
                let mut t_dot = 0.0_f64;
                for b in 0..K {
                    t_dot += third[a][b] * vec_k[b];
                }
                quad += vec_k[a] * t_dot;
            }
            total += quad;
        }
    }
    total
}
/// Straightforward reference for the second-derivative trace: for every row
/// and every factor column, evaluates the quadratic form of the fourth-order
/// block in the row-space image of that column, and sums the results with unit
/// weights over all rows.
fn reference_trace_second<const K: usize>(
    kern: &impl RowKernel<K>,
    direction_u: &[f64],
    direction_v: &[f64],
    factor: &Array2<f64>,
) -> f64 {
    let row_count = kern.n_rows();
    let column_count = factor.ncols();
    let coeff_count = factor.nrows();
    let mut running = 0.0_f64;
    for row in 0..row_count {
        let row_u = kern.jacobian_action(row, direction_u);
        let row_v = kern.jacobian_action(row, direction_v);
        let block = kern
            .row_fourth_contracted(row, &row_u, &row_v)
            .expect("fourth");
        for col_idx in 0..column_count {
            // Copy the factor column into a contiguous buffer.
            let mut column = vec![0.0_f64; coeff_count];
            for (j, entry) in column.iter_mut().enumerate() {
                *entry = factor[[j, col_idx]];
            }
            let image = kern.jacobian_action(row, &column);
            let mut quadratic = 0.0_f64;
            for (block_row, &image_a) in block.iter().zip(image.iter()) {
                let inner = block_row
                    .iter()
                    .zip(image.iter())
                    .fold(0.0_f64, |s, (&f_ab, &image_b)| s + f_ab * image_b);
                quadratic += image_a * inner;
            }
            running += quadratic;
        }
    }
    running
}
/// Checks that the Gram-based trace of both operators (uncached and cached)
/// agrees with the straightforward per-column reference to a relative
/// tolerance of `1e-10`, using `RowSet::All`.
#[test]
fn gram_inner_contraction_matches_reference() {
    let (n, p, rank) = (32, 11, 7);
    let kern = Arc::new(SyntheticKernel::new(n, p, 0xC0FFEE));

    let mut state = 0xDEADBEEF_u64;
    let mut draw = || -> f64 {
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((state >> 33) as f64 / (u32::MAX as f64)) - 0.5
    };
    // Draw order matters: direction, direction_u, direction_v, then factor.
    let direction: Vec<f64> = (0..p).map(|_| draw()).collect();
    let direction_u: Vec<f64> = (0..p).map(|_| draw()).collect();
    let direction_v: Vec<f64> = (0..p).map(|_| draw()).collect();
    let factor = Array2::from_shape_fn((p, rank), |_| draw());

    let relative_gap =
        |got: f64, expected: f64| -> f64 { (got - expected).abs() / expected.abs().max(1e-12) };

    // First-derivative operator.
    let first_op = RowKernelDirectionalDerivativeOperator {
        kern: Arc::clone(&kern),
        direction: direction.clone(),
        p,
        rows: RowSet::All,
    };
    let first_cache = ProjectedFactorCache::default();
    let got1_uncached = HyperOperator::trace_projected_factor(&first_op, &factor);
    let got1_cached = first_op.trace_projected_factor_cached(&factor, &first_cache);
    let ref1 = reference_trace_first::<4>(&*kern, &direction, &factor);
    let rel1_uncached = relative_gap(got1_uncached, ref1);
    let rel1_cached = relative_gap(got1_cached, ref1);
    assert!(
        rel1_uncached < 1e-10,
        "first-derivative Gram path drifted: rel={rel1_uncached:.3e} got={got1_uncached} ref={ref1}",
    );
    assert!(
        rel1_cached < 1e-10,
        "first-derivative cached Gram path drifted: rel={rel1_cached:.3e} got={got1_cached} ref={ref1}",
    );

    // Second-derivative operator.
    let second_op = RowKernelSecondDirectionalDerivativeOperator {
        kern: Arc::clone(&kern),
        direction_u: direction_u.clone(),
        direction_v: direction_v.clone(),
        p,
        rows: RowSet::All,
    };
    let second_cache = ProjectedFactorCache::default();
    let got2_uncached = HyperOperator::trace_projected_factor(&second_op, &factor);
    let got2_cached = second_op.trace_projected_factor_cached(&factor, &second_cache);
    let ref2 = reference_trace_second::<4>(&*kern, &direction_u, &direction_v, &factor);
    let rel2_uncached = relative_gap(got2_uncached, ref2);
    let rel2_cached = relative_gap(got2_cached, ref2);
    assert!(
        rel2_uncached < 1e-10,
        "second-derivative Gram path drifted: rel={rel2_uncached:.3e} got={got2_uncached} ref={ref2}",
    );
    assert!(
        rel2_cached < 1e-10,
        "second-derivative cached Gram path drifted: rel={rel2_cached:.3e} got={got2_cached} ref={ref2}",
    );
}
}