use crate::custom_family::{
ExactNewtonJointGradientEvaluation, ExactNewtonJointHessianWorkspace,
JointHessianSourcePreference, MaterializationIntent, use_joint_matrix_free_path,
};
use crate::faer_ndarray::fast_ab;
use crate::matrix::DesignMatrix;
use crate::util::loop_progress::LoopProgress;
use gam_problem::{HyperOperator, ProjectedFactorCache, ProjectedFactorKey};
use ndarray::{Array1, Array2, ArrayView2, s};
use rayon::prelude::*;
use std::sync::Arc;
/// Minimum number of rows to evaluate before `build_row_kernel_cache` creates
/// a `LoopProgress` ticker and emits `[STAGE]` progress log lines.
const ROW_KERNEL_CACHE_PROGRESS_MIN_ROWS: usize = 100_000;
pub use crate::outer_subsample::RowSet;
use crate::outer_subsample::{ARROW_ROW_CHUNK, arrow_row_chunk_count};
/// Byte ceiling (1 GiB) compared against the full J·F projection size by
/// `jf_projection_exceeds_budget`; larger projections take the tiled trace path.
const JF_MATERIALIZATION_BUDGET_BYTES: usize = 1024 * 1024 * 1024;
/// Byte budget (256 MiB) that `jf_tile_rows` divides by the per-row J·F size
/// to choose how many rows one tile holds.
const JF_TILE_BUDGET_BYTES: usize = 256 * 1024 * 1024;
/// Reports whether an `n_rows x (K * rank)` block of `f64` values would need
/// more than `JF_MATERIALIZATION_BUDGET_BYTES` bytes.
fn jf_projection_exceeds_budget<const K: usize>(n_rows: usize, rank: usize) -> bool {
    let required_bytes = jf_projection_bytes::<K>(n_rows, rank);
    JF_MATERIALIZATION_BUDGET_BYTES < required_bytes
}
/// Byte size of an `n_rows x (K * rank)` block of `f64` values.
///
/// Every multiplication saturates, so absurdly large inputs yield
/// `usize::MAX` instead of overflowing.
#[inline]
fn jf_projection_bytes<const K: usize>(n_rows: usize, rank: usize) -> usize {
    let factors = [K, rank, std::mem::size_of::<f64>()];
    factors
        .iter()
        .fold(n_rows, |bytes, &factor| bytes.saturating_mul(factor))
}
/// Number of rows per J·F tile so that one tile of `K * rank` `f64` values per
/// row stays within `JF_TILE_BUDGET_BYTES`.
///
/// The result is always a positive multiple of `ARROW_ROW_CHUNK` (at least one
/// chunk, even when a single row already exceeds the budget).
fn jf_tile_rows<const K: usize>(rank: usize) -> usize {
    // Saturate the whole product: the previous trailing `* size_of::<f64>()`
    // could overflow (or wrap to zero and then divide by zero) for huge ranks.
    let per_row = K
        .saturating_mul(rank)
        .max(1)
        .saturating_mul(std::mem::size_of::<f64>());
    let max_rows = (JF_TILE_BUDGET_BYTES / per_row).max(1);
    // Round down to whole chunks, but never below one chunk.
    (max_rows / ARROW_ROW_CHUNK).max(1) * ARROW_ROW_CHUNK
}
/// Chooses the block size (in rows) used to parallelise the cache build.
///
/// Aims for `OVERSUBSCRIBE` blocks per rayon worker and rounds the block size
/// up to a whole number of `ARROW_ROW_CHUNK` rows. An empty input yields a
/// single chunk size.
fn cache_build_chunk_rows(n_rows: usize) -> usize {
    const OVERSUBSCRIBE: usize = 4;
    if n_rows > 0 {
        let thread_count = rayon::current_num_threads().max(1);
        let wanted_blocks = (thread_count * OVERSUBSCRIBE).max(1);
        let rows_per_block = n_rows.div_ceil(wanted_blocks).max(1);
        let arrow_chunks = rows_per_block.div_ceil(ARROW_ROW_CHUNK).max(1);
        arrow_chunks * ARROW_ROW_CHUNK
    } else {
        ARROW_ROW_CHUNK
    }
}
/// Number of `chunk_rows`-sized blocks needed to cover `n_rows` rows
/// (ceiling division); zero rows need zero blocks.
#[inline]
fn cache_build_block_count(n_rows: usize, chunk_rows: usize) -> usize {
    match n_rows {
        // Handled separately so an empty input never divides by `chunk_rows`.
        0 => 0,
        rows => rows.div_ceil(chunk_rows),
    }
}
impl RowSet {
    /// Builds the row set selected by the fit options.
    ///
    /// Returns `RowSet::All` when no outer-score subsample is configured;
    /// otherwise shares the configured row list and records `n_total` as the
    /// full row count.
    pub fn from_options(
        opts: &crate::families::custom_family::BlockwiseFitOptions,
        n_total: usize,
    ) -> Self {
        if let Some(subsample) = opts.outer_score_subsample.as_ref() {
            Self::Subsample {
                rows: Arc::clone(&subsample.rows),
                n_full: n_total,
            }
        } else {
            Self::All
        }
    }
}
/// Sums `map_chunk(chunk_idx)` over every chunk index covering `n_items`.
///
/// Chunk partials are computed in parallel but collected in index order and
/// then added sequentially, so the floating-point result does not depend on
/// how rayon schedules the work.
#[inline]
fn deterministic_chunked_sum<F>(n_items: usize, map_chunk: F) -> f64
where
    F: Fn(usize) -> f64 + Send + Sync,
{
    let chunk_count = arrow_row_chunk_count(n_items);
    let per_chunk: Vec<f64> = (0..chunk_count)
        .into_par_iter()
        .map(map_chunk)
        .collect();
    // Explicit left-to-right fold starting at +0.0 keeps the summation order fixed.
    per_chunk
        .iter()
        .fold(0.0_f64, |running, &partial| running + partial)
}
/// Row-wise kernel with `K` per-row quantities.
///
/// Implementors supply per-row values, gradients, Hessians and higher-order
/// contractions, plus the Jacobian actions that connect the flat coefficient
/// vector to the `K` per-row quantities. The provided methods are optional
/// fast paths; their defaults fall back to the generic routines in this file.
pub trait RowKernel<const K: usize>: Send + Sync {
    /// Number of rows the kernel can evaluate.
    fn n_rows(&self) -> usize;

    /// Length of the flat coefficient vector (the `p` of the `p x p` outputs).
    fn n_coefficients(&self) -> usize;

    /// Evaluates one row and returns `(value, gradient, hessian)`.
    /// `build_row_kernel_cache` stores these as `nll`, `gradients`, `hessians`.
    fn row_kernel(&self, row: usize) -> Result<(f64, [f64; K], [[f64; K]; K]), String>;

    /// Applies the row's Jacobian to a coefficient-space direction, producing
    /// a `K`-vector.
    fn jacobian_action(&self, row: usize, d_beta: &[f64]) -> [f64; K];

    /// Applies the transposed row Jacobian to `v`, writing into `out`.
    /// Callers in this file pass a running accumulator, so implementations
    /// are expected to add to `out` rather than overwrite it.
    fn jacobian_transpose_action(&self, row: usize, v: &[f64; K], out: &mut [f64]);

    /// Adds the row's pulled-back `K x K` matrix `h` into the `p x p` `target`.
    fn add_pullback_hessian(&self, row: usize, h: &[[f64; K]; K], target: &mut Array2<f64>);

    /// Adds the diagonal of the row's pulled-back `h` into `diag`.
    fn add_diagonal_quadratic(&self, row: usize, h: &[[f64; K]; K], diag: &mut [f64]);

    /// Row-level third-order term contracted with one `K`-direction.
    fn row_third_contracted(&self, row: usize, dir: &[f64; K]) -> Result<[[f64; K]; K], String>;

    /// Row-level fourth-order term contracted with two `K`-directions.
    fn row_fourth_contracted(
        &self,
        row: usize,
        dir_u: &[f64; K],
        dir_v: &[f64; K],
    ) -> Result<[[f64; K]; K], String>;

    /// Hook invoked before directional-derivative sweeps; a no-op by default.
    fn warm_up_directional_caches(&self) -> Result<(), String> {
        Ok(())
    }

    /// Optional all-rows batched evaluation of `(values, gradients, hessians)`.
    /// The default reports "not available" with `None`.
    fn batched_value_grad_hess_all(
        &self,
    ) -> Option<Result<(Vec<f64>, Vec<[f64; K]>, Vec<[[f64; K]; K]>), String>> {
        None
    }

    /// Jacobian applied to every column of `factor` for all rows; the default
    /// uses the generic per-row implementation.
    fn jacobian_action_matrix(&self, factor: ArrayView2<'_, f64>) -> Option<Array2<f64>> {
        let projected = row_kernel_jacobian_action_matrix_generic(self, factor);
        Some(projected)
    }

    /// Same projection restricted to rows `start..end`.
    fn jacobian_action_matrix_rows(
        &self,
        factor: ArrayView2<'_, f64>,
        start: usize,
        end: usize,
    ) -> Array2<f64> {
        row_kernel_jacobian_action_matrix_generic_rows(self, factor, start, end)
    }

    /// Dense directional derivative along `d_beta`; the default delegates to
    /// the generic row sweep.
    fn directional_derivative_dense_override(
        &self,
        rows: &RowSet,
        d_beta: &[f64],
    ) -> Option<Result<Array2<f64>, String>> {
        let dense = row_kernel_directional_derivative_generic(self, rows, d_beta);
        Some(dense)
    }

    /// Optional fused evaluation of the directional derivative along every
    /// coordinate axis. The default only validates `p` and otherwise returns
    /// `None` so the caller runs one sweep per axis.
    fn directional_derivative_all_axes_dense_override(
        &self,
        rows: &RowSet,
        p: usize,
    ) -> Option<Result<Vec<Array2<f64>>, String>> {
        let expected = self.n_coefficients();
        if p == expected {
            return None;
        }
        let all = matches!(rows, RowSet::All);
        Some(Err(format!(
            "directional_derivative_all_axes_dense_override: axis count {} \
             disagrees with n_coefficients() {} (rows::All = {all})",
            p, expected,
        )))
    }

    /// Dense Hessian assembled from per-row `K x K` blocks; the default
    /// delegates to the generic accumulation.
    fn hessian_dense_override(
        &self,
        rows: &RowSet,
        row_hessians: &[[[f64; K]; K]],
    ) -> Option<Array2<f64>> {
        let dense = row_kernel_hessian_dense_generic(self, rows, row_hessians);
        Some(dense)
    }

    /// Optional fused evaluation of the second directional derivative with
    /// `d_beta_u` fixed and the other direction ranging over every axis. The
    /// default only validates the length of `d_beta_u` and otherwise returns
    /// `None`.
    fn second_directional_derivative_all_axes_dense_override(
        &self,
        rows: &RowSet,
        d_beta_u: &[f64],
    ) -> Option<Result<Vec<Array2<f64>>, String>> {
        let expected = self.n_coefficients();
        if d_beta_u.len() == expected {
            return None;
        }
        let all = matches!(rows, RowSet::All);
        Some(Err(format!(
            "second_directional_derivative_all_axes_dense_override: fixed direction has \
             {} entries, expected {} (rows::All = {all})",
            d_beta_u.len(),
            expected,
        )))
    }
}
fn row_kernel_jacobian_action_matrix_generic<const K: usize>(
kern: &(impl RowKernel<K> + ?Sized),
factor: ArrayView2<'_, f64>,
) -> Array2<f64> {
assert_eq!(
factor.nrows(),
kern.n_coefficients(),
"row-kernel JF factor row count must match coefficient dimension"
);
let n_rows = kern.n_rows();
let rank = factor.ncols();
let stride = K * rank;
let mut jf = Array2::<f64>::zeros((n_rows, stride));
if n_rows == 0 || rank == 0 {
return jf;
}
let f_t: Array2<f64> = factor.t().as_standard_layout().into_owned();
jf.as_slice_mut()
.expect("row-major JF matrix must be contiguous")
.par_chunks_mut(stride)
.enumerate()
.for_each(|(row, jf_row)| {
for k_col in 0..rank {
let f_slice = f_t
.row(k_col)
.to_slice()
.expect("standard-layout row must be contiguous");
let vec_k = kern.jacobian_action(row, f_slice);
for k in 0..K {
jf_row[k * rank + k_col] = vec_k[k];
}
}
});
jf
}
/// Projects every column of `factor` through the kernel Jacobian for the rows
/// `start..end`.
///
/// The result has shape `(end - start, K * rank)` (zero rows when
/// `end <= start`); entry `[i, k * rank + c]` is component `k` of
/// `kern.jacobian_action(start + i, factor[:, c])`.
///
/// # Panics
/// Panics when `factor.nrows()` differs from `kern.n_coefficients()`.
pub(crate) fn row_kernel_jacobian_action_matrix_generic_rows<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    factor: ArrayView2<'_, f64>,
    start: usize,
    end: usize,
) -> Array2<f64> {
    assert_eq!(
        factor.nrows(),
        kern.n_coefficients(),
        "row-kernel JF factor row count must match coefficient dimension"
    );
    let rank = factor.ncols();
    let stride = K * rank;
    let block_len = end.saturating_sub(start);
    let mut projected = Array2::<f64>::zeros((block_len, stride));
    if block_len == 0 || rank == 0 {
        return projected;
    }
    // Transposed, contiguous copy so each factor column is a plain slice.
    let factor_columns: Array2<f64> = factor.t().as_standard_layout().into_owned();
    let flat = projected
        .as_slice_mut()
        .expect("row-major JF matrix must be contiguous");
    flat.par_chunks_mut(stride)
        .enumerate()
        .for_each(|(offset, out_row)| {
            let row = start + offset;
            for col in 0..rank {
                let direction = factor_columns
                    .row(col)
                    .to_slice()
                    .expect("standard-layout row must be contiguous");
                let image = kern.jacobian_action(row, direction);
                // Axis-major packing: all `rank` columns of axis 0, then axis 1, ...
                for (axis, value) in image.iter().enumerate() {
                    out_row[axis * rank + col] = *value;
                }
            }
        });
    projected
}
/// Multiplies a design matrix by `factor_block`.
///
/// Dense designs go through `fast_ab`; every other design falls back to one
/// `design.dot` per factor column. A zero-column factor yields an
/// `(n_rows, 0)` array without touching the design.
pub(crate) fn row_kernel_design_jf(
    design: &DesignMatrix,
    factor_block: ArrayView2<'_, f64>,
    n_rows: usize,
) -> Array2<f64> {
    if factor_block.ncols() == 0 {
        return Array2::<f64>::zeros((n_rows, 0));
    }
    let factor = factor_block.as_standard_layout().into_owned();
    if let Some(dense) = design.as_dense_ref() {
        fast_ab(dense, &factor)
    } else {
        row_kernel_design_jf_column_dot(design, &factor, n_rows)
    }
}
/// Multiplies rows `start..end` of a design matrix by `factor_block`.
///
/// Dense designs slice the row range and use `fast_ab`; other designs compute
/// each output entry with `design.dot_row_view`. The result has
/// `end.saturating_sub(start)` rows and `factor_block.ncols()` columns.
pub(crate) fn row_kernel_design_jf_rows(
    design: &DesignMatrix,
    factor_block: ArrayView2<'_, f64>,
    start: usize,
    end: usize,
) -> Array2<f64> {
    let block_len = end.saturating_sub(start);
    let rank = factor_block.ncols();
    if rank == 0 {
        return Array2::<f64>::zeros((block_len, 0));
    }
    let factor = factor_block.as_standard_layout().into_owned();
    if let Some(dense) = design.as_dense_ref() {
        let block = dense.slice(s![start..end, ..]);
        return fast_ab(&block, &factor);
    }
    // Non-dense design: fill the block entry by entry, row-major.
    let mut out = Array2::<f64>::zeros((block_len, rank));
    for (mut out_row, row) in out.outer_iter_mut().zip(start..end) {
        for (col, slot) in out_row.iter_mut().enumerate() {
            *slot = design.dot_row_view(row, factor.column(col));
        }
    }
    out
}
/// Packs per-axis `(n_rows, rank)` blocks into one `(n_rows, K * rank)` array,
/// placing axis `a` in columns `a * rank..(a + 1) * rank`.
///
/// Axes that are not supplied stay zero. When `rank == 0` the `axes` iterator
/// is not consumed.
///
/// # Panics
/// Panics when an axis index is `>= K` or a block is not `(n_rows, rank)`.
pub(crate) fn row_kernel_pack_jf_axes<const K: usize>(
    n_rows: usize,
    rank: usize,
    axes: impl IntoIterator<Item = (usize, Array2<f64>)>,
) -> Array2<f64> {
    let mut packed = Array2::<f64>::zeros((n_rows, K * rank));
    if rank != 0 {
        for (axis, block) in axes {
            assert!(
                axis < K,
                "row-kernel JF axis index {axis} out of range for K={K}"
            );
            assert_eq!(
                block.dim(),
                (n_rows, rank),
                "row-kernel JF axis {axis} block shape must be ({n_rows}, {rank})"
            );
            let first_col = axis * rank;
            let last_col = first_col + rank;
            packed
                .slice_mut(s![.., first_col..last_col])
                .assign(&block);
        }
    }
    packed
}
pub(crate) fn row_kernel_design_jf_column_dot(
design: &DesignMatrix,
factor_block: &Array2<f64>,
n_rows: usize,
) -> Array2<f64> {
let rank = factor_block.ncols();
let mut out = Array2::<f64>::zeros((n_rows, rank));
for c in 0..rank {
let result = design.dot(&factor_block.column(c).to_owned());
out.column_mut(c).assign(&result);
}
out
}
/// Checks that every named cache has exactly `expected_len` entries.
///
/// # Errors
/// Returns a message prefixed with `context` that lists each offending cache
/// as `name=actual` (space separated) followed by the expected length.
pub(crate) fn validate_row_kernel_cache_lengths(
    context: &str,
    expected_len: usize,
    caches: &[(&str, usize)],
) -> Result<(), String> {
    let mut offenders: Vec<String> = Vec::new();
    for &(name, actual) in caches {
        if actual != expected_len {
            offenders.push(format!("{name}={actual}"));
        }
    }
    if offenders.is_empty() {
        return Ok(());
    }
    Err(format!(
        "{context} row-kernel cache length mismatch: {} expected={expected_len}",
        offenders.join(" ")
    ))
}
/// Per-row kernel outputs collected by `build_row_kernel_cache`.
///
/// The three vectors are indexed by row number. When the cache is built from a
/// subsample, rows outside the subsample keep their zero initial values.
pub struct RowKernelCache<const K: usize> {
/// Row count reported by `RowKernel::n_rows` at build time.
pub n: usize,
/// Coefficient count reported by `RowKernel::n_coefficients` at build time.
pub p: usize,
/// First component of each `row_kernel` result; `row_kernel_log_likelihood`
/// negates the weighted sum of these values.
pub nll: Vec<f64>,
/// Per-row `K`-vector returned as the second component of `row_kernel`.
pub gradients: Vec<[f64; K]>,
/// Per-row `K x K` matrix returned as the third component of `row_kernel`.
pub hessians: Vec<[[f64; K]; K]>,
}
pub fn build_row_kernel_cache<const K: usize>(
kern: &(impl RowKernel<K> + ?Sized),
rows: &RowSet,
) -> Result<RowKernelCache<K>, String> {
let n = kern.n_rows();
let p = kern.n_coefficients();
let mut nll = vec![0.0_f64; n];
let mut gradients = vec![[0.0_f64; K]; n];
let mut hessians = vec![[[0.0_f64; K]; K]; n];
let work_count = match rows {
RowSet::All => n,
RowSet::Subsample { rows: list, .. } => list.len(),
};
let progress_ticker =
(work_count >= ROW_KERNEL_CACHE_PROGRESS_MIN_ROWS).then(LoopProgress::default_interval);
match rows {
RowSet::All => {
if let Some(Ok((bv, bg, bh))) = kern.batched_value_grad_hess_all() {
if bv.len() == n && bg.len() == n && bh.len() == n {
return Ok(RowKernelCache {
n,
p,
nll: bv,
gradients: bg,
hessians: bh,
});
}
}
let block_rows = cache_build_chunk_rows(n);
let evaluated_chunks: Vec<Vec<(f64, [f64; K], [[f64; K]; K])>> =
(0..cache_build_block_count(n, block_rows))
.into_par_iter()
.map(|block_idx| {
let start = block_idx * block_rows;
let end = (start + block_rows).min(n);
let mut chunk = Vec::with_capacity(end - start);
for row in start..end {
let out = kern.row_kernel(row)?;
if let Some(ticker) = progress_ticker.as_ref() {
ticker.tick(1, |progress, elapsed| {
log::info!(
"[STAGE] row-kernel cache (all) progress={}/{} ({:.1}%) elapsed={:.1}s threads={}",
progress.min(n),
n,
100.0 * progress.min(n) as f64 / n.max(1) as f64,
elapsed,
rayon::current_num_threads(),
);
});
}
chunk.push(out);
}
Ok(chunk)
})
.collect::<Result<Vec<_>, String>>()?;
for (block_idx, chunk) in evaluated_chunks.into_iter().enumerate() {
let start = block_idx * block_rows;
for (local, (l, g, h)) in chunk.into_iter().enumerate() {
let i = start + local;
nll[i] = l;
gradients[i] = g;
hessians[i] = h;
}
}
}
RowSet::Subsample { rows: list, .. } => {
let total = list.len();
let block_rows = cache_build_chunk_rows(total);
let pair_chunks: Vec<Vec<(usize, (f64, [f64; K], [[f64; K]; K]))>> = list
.par_chunks(block_rows)
.map(|row_chunk| {
let mut chunk = Vec::with_capacity(row_chunk.len());
for r in row_chunk {
let out = kern.row_kernel(r.index).map(|out| (r.index, out))?;
if let Some(ticker) = progress_ticker.as_ref() {
ticker.tick(1, |progress, elapsed| {
log::info!(
"[STAGE] row-kernel cache (subsample) progress={}/{} ({:.1}%) elapsed={:.1}s threads={}",
progress.min(total),
total,
100.0 * progress.min(total) as f64 / total.max(1) as f64,
elapsed,
rayon::current_num_threads(),
);
});
}
chunk.push(out);
}
Ok(chunk)
})
.collect::<Result<Vec<_>, String>>()?;
for chunk in pair_chunks {
for (idx, (l, g, h)) in chunk {
nll[idx] = l;
gradients[idx] = g;
hessians[idx] = h;
}
}
}
}
Ok(RowKernelCache {
n,
p,
nll,
gradients,
hessians,
})
}
/// Hessian-vector product assembled row by row from the cached `K x K` blocks.
///
/// For each selected row with weight `w` this forms `w * H_row * (J_row *
/// direction)` and passes it to `jacobian_transpose_action`, accumulating into
/// a length-`cache.p` vector.
pub fn row_kernel_hessian_matvec<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    cache: &RowKernelCache<K>,
    rows: &RowSet,
    direction: &[f64],
) -> Array1<f64> {
    let coefficient_count = cache.p;
    let product = rows.par_reduce_fold(
        cache.n,
        || vec![0.0_f64; coefficient_count],
        |mut acc, row, w| {
            let projected = kern.jacobian_action(row, direction);
            let row_hessian = &cache.hessians[row];
            let mut weighted = [0.0_f64; K];
            for (slot, hessian_row) in weighted.iter_mut().zip(row_hessian.iter()) {
                let mut dot = 0.0;
                for (entry, component) in hessian_row.iter().zip(projected.iter()) {
                    dot += entry * component;
                }
                *slot = w * dot;
            }
            kern.jacobian_transpose_action(row, &weighted, &mut acc);
            acc
        },
        |mut left, right| {
            for (l, r) in left.iter_mut().zip(right.iter()) {
                *l += *r;
            }
            left
        },
    );
    Array1::from_vec(product)
}
/// Diagonal of the Hessian, accumulated row by row from the cached `K x K`
/// blocks via `add_diagonal_quadratic`. Rows whose weight is not exactly `1.0`
/// have their block scaled by the weight first.
pub fn row_kernel_hessian_diagonal<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    cache: &RowKernelCache<K>,
    rows: &RowSet,
) -> Array1<f64> {
    let coefficient_count = cache.p;
    let diagonal = rows.par_reduce_fold(
        cache.n,
        || vec![0.0_f64; coefficient_count],
        |mut diag, row, w| {
            let row_hessian = &cache.hessians[row];
            if w != 1.0 {
                // Scale a private copy so the cached block stays untouched.
                let mut weighted = *row_hessian;
                for entry in weighted.iter_mut().flatten() {
                    *entry *= w;
                }
                kern.add_diagonal_quadratic(row, &weighted, &mut diag);
            } else {
                kern.add_diagonal_quadratic(row, row_hessian, &mut diag);
            }
            diag
        },
        |mut left, right| {
            for (l, r) in left.iter_mut().zip(right.iter()) {
                *l += *r;
            }
            left
        },
    );
    Array1::from_vec(diagonal)
}
/// Coefficient-space gradient: each selected row's cached `K`-gradient (scaled
/// by the row weight when it is not exactly `1.0`) is pushed through
/// `jacobian_transpose_action` and accumulated.
pub fn row_kernel_gradient<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    cache: &RowKernelCache<K>,
    rows: &RowSet,
) -> Array1<f64> {
    let coefficient_count = cache.p;
    let gradient = rows.par_reduce_fold(
        cache.n,
        || vec![0.0_f64; coefficient_count],
        |mut acc, row, w| {
            let row_gradient = &cache.gradients[row];
            if w != 1.0 {
                let mut weighted = *row_gradient;
                for entry in weighted.iter_mut() {
                    *entry *= w;
                }
                kern.jacobian_transpose_action(row, &weighted, &mut acc);
            } else {
                kern.jacobian_transpose_action(row, row_gradient, &mut acc);
            }
            acc
        },
        |mut left, right| {
            for (l, r) in left.iter_mut().zip(right.iter()) {
                *l += *r;
            }
            left
        },
    );
    Array1::from_vec(gradient)
}
/// Log-likelihood over the selected rows: the negated weighted sum of the
/// cached per-row `nll` values.
pub fn row_kernel_log_likelihood<const K: usize>(cache: &RowKernelCache<K>, rows: &RowSet) -> f64 {
    let weighted_nll = rows.par_reduce_fold(
        cache.n,
        || 0.0_f64,
        |partial, row, w| partial + w * cache.nll[row],
        |left, right| left + right,
    );
    -weighted_nll
}
/// Dense Hessian from the cached per-row blocks, preferring the kernel's
/// `hessian_dense_override` and falling back to the generic accumulation when
/// the override declines.
pub fn row_kernel_hessian_dense<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    cache: &RowKernelCache<K>,
    rows: &RowSet,
) -> Array2<f64> {
    match kern.hessian_dense_override(rows, &cache.hessians) {
        Some(dense) => dense,
        None => row_kernel_hessian_dense_generic(kern, rows, &cache.hessians),
    }
}
/// Generic dense Hessian: adds each selected row's `K x K` block (scaled by
/// the row weight when it is not exactly `1.0`) into a `p x p` matrix through
/// `add_pullback_hessian`.
pub fn row_kernel_hessian_dense_generic<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    rows: &RowSet,
    row_hessians: &[[[f64; K]; K]],
) -> Array2<f64> {
    let p = kern.n_coefficients();
    let row_count = row_hessians.len();
    rows.par_reduce_fold(
        row_count,
        || Array2::<f64>::zeros((p, p)),
        |mut acc, row, w| {
            let block = &row_hessians[row];
            if w != 1.0 {
                let mut weighted = *block;
                for entry in weighted.iter_mut().flatten() {
                    *entry *= w;
                }
                kern.add_pullback_hessian(row, &weighted, &mut acc);
            } else {
                kern.add_pullback_hessian(row, block, &mut acc);
            }
            acc
        },
        |left, right| left + right,
    )
}
/// Dense directional derivative along `d_beta`, preferring the kernel's
/// `directional_derivative_dense_override` and falling back to the generic
/// row sweep when the override declines.
pub fn row_kernel_directional_derivative<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    rows: &RowSet,
    d_beta: &[f64],
) -> Result<Array2<f64>, String> {
    kern.directional_derivative_dense_override(rows, d_beta)
        .unwrap_or_else(|| row_kernel_directional_derivative_generic(kern, rows, d_beta))
}
/// Generic dense directional derivative along `d_beta`.
///
/// For every selected row the direction is mapped to a `K`-vector with
/// `jacobian_action`, contracted through `row_third_contracted`, scaled by the
/// row weight, and added into a `p x p` matrix via `add_pullback_hessian`.
///
/// # Errors
/// Propagates failures from `warm_up_directional_caches` and
/// `row_third_contracted`.
pub fn row_kernel_directional_derivative_generic<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    rows: &RowSet,
    d_beta: &[f64],
) -> Result<Array2<f64>, String> {
    let row_count = kern.n_rows();
    let p = kern.n_coefficients();
    kern.warm_up_directional_caches()?;
    rows.par_try_reduce_fold(
        row_count,
        || Array2::<f64>::zeros((p, p)),
        |mut acc, row, w| -> Result<_, String> {
            let projected = kern.jacobian_action(row, d_beta);
            let mut contracted = kern.row_third_contracted(row, &projected)?;
            if w != 1.0 {
                for entry in contracted.iter_mut().flatten() {
                    *entry *= w;
                }
            }
            kern.add_pullback_hessian(row, &contracted, &mut acc);
            Ok(acc)
        },
        |left, right| Ok(left + right),
    )
}
/// Directional derivatives along every coordinate axis of the coefficient
/// vector, one `p x p` matrix per axis.
///
/// Uses the kernel's fused override when it answers; otherwise runs one
/// directional-derivative sweep per unit vector, in parallel.
pub fn row_kernel_directional_derivative_all_axes<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized + Sync),
    rows: &RowSet,
) -> Result<Vec<Array2<f64>>, String> {
    let p = kern.n_coefficients();
    match kern.directional_derivative_all_axes_dense_override(rows, p) {
        Some(fused) => fused,
        None => (0..p)
            .into_par_iter()
            .map(|axis_index| {
                let mut unit = vec![0.0_f64; p];
                unit[axis_index] = 1.0;
                gam_problem::with_nested_parallel(|| {
                    row_kernel_directional_derivative(kern, rows, &unit)
                })
            })
            .collect::<Result<Vec<_>, _>>(),
    }
}
/// Dense second directional derivative along `d_beta_u` and `d_beta_v`.
///
/// For every selected row both directions are mapped to `K`-vectors with
/// `jacobian_action`, contracted through `row_fourth_contracted`, scaled by
/// the row weight, and added into a `p x p` matrix via `add_pullback_hessian`.
///
/// # Errors
/// Propagates failures from `warm_up_directional_caches` and
/// `row_fourth_contracted`.
pub fn row_kernel_second_directional_derivative<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized),
    rows: &RowSet,
    d_beta_u: &[f64],
    d_beta_v: &[f64],
) -> Result<Array2<f64>, String> {
    let row_count = kern.n_rows();
    let p = kern.n_coefficients();
    kern.warm_up_directional_caches()?;
    rows.par_try_reduce_fold(
        row_count,
        || Array2::<f64>::zeros((p, p)),
        |mut acc, row, w| -> Result<_, String> {
            let projected_u = kern.jacobian_action(row, d_beta_u);
            let projected_v = kern.jacobian_action(row, d_beta_v);
            let mut contracted = kern.row_fourth_contracted(row, &projected_u, &projected_v)?;
            if w != 1.0 {
                for entry in contracted.iter_mut().flatten() {
                    *entry *= w;
                }
            }
            kern.add_pullback_hessian(row, &contracted, &mut acc);
            Ok(acc)
        },
        |left, right| Ok(left + right),
    )
}
/// Second directional derivatives with `d_beta_u` fixed and the other
/// direction ranging over every coordinate axis, one `p x p` matrix per axis.
///
/// Uses the kernel's fused override when it answers; otherwise runs one
/// second-directional-derivative sweep per unit vector, in parallel.
pub fn row_kernel_second_directional_derivative_all_axes<const K: usize>(
    kern: &(impl RowKernel<K> + ?Sized + Sync),
    rows: &RowSet,
    d_beta_u: &[f64],
) -> Result<Vec<Array2<f64>>, String> {
    match kern.second_directional_derivative_all_axes_dense_override(rows, d_beta_u) {
        Some(fused) => fused,
        None => {
            let p = kern.n_coefficients();
            (0..p)
                .into_par_iter()
                .map(|axis_index| {
                    let mut unit = vec![0.0_f64; p];
                    unit[axis_index] = 1.0;
                    gam_problem::with_nested_parallel(|| {
                        row_kernel_second_directional_derivative(kern, rows, d_beta_u, &unit)
                    })
                })
                .collect::<Result<Vec<_>, _>>()
        }
    }
}
/// Matrix-free operator for the directional derivative of the row-kernel
/// Hessian along a fixed coefficient direction.
struct RowKernelDirectionalDerivativeOperator<const K: usize, T: RowKernel<K>> {
/// Shared kernel that supplies Jacobian actions and third-order contractions.
kern: Arc<T>,
/// Fixed coefficient-space direction the derivative is taken along.
direction: Vec<f64>,
/// Operator dimension returned by `dim`.
p: usize,
/// Row selection (with weights) used by `mul_vec`, `projected_matrix` and `to_dense`.
rows: RowSet,
}
/// `HyperOperator` view of the first directional derivative: every method
/// works from per-row `K x K` third-order contractions instead of a stored
/// `p x p` matrix.
impl<const K: usize, T: RowKernel<K>> HyperOperator
for RowKernelDirectionalDerivativeOperator<K, T>
{
/// Operator dimension (`p`).
fn dim(&self) -> usize {
self.p
}
/// Applies the operator to `v`.
///
/// Per selected row: contract the third-order term with `J * direction`,
/// apply it to `J * v`, scale by the row weight, and accumulate through the
/// transposed Jacobian.
///
/// Panics if `v` is not contiguous or a third-order contraction fails.
fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
let direction = v
.as_slice()
.expect("row-kernel directional derivative operator requires contiguous input");
let out = self.rows.par_reduce_fold(
self.kern.n_rows(),
|| vec![0.0_f64; self.p],
|mut acc, row, w| {
// `dir_k` is the fixed derivative direction, `vec_k` the input vector,
// both mapped to the row's K-dimensional space.
let dir_k = self.kern.jacobian_action(row, &self.direction);
let vec_k = self.kern.jacobian_action(row, direction);
let third = self
.kern
.row_third_contracted(row, &dir_k)
.expect("row-kernel third contraction should succeed for validated directions");
let mut action = [0.0_f64; K];
for a in 0..K {
let mut sum = 0.0;
for b in 0..K {
sum += third[a][b] * vec_k[b];
}
action[a] = w * sum;
}
self.kern.jacobian_transpose_action(row, &action, &mut acc);
acc
},
|mut left, right| {
for idx in 0..left.len() {
left[idx] += right[idx];
}
left
},
);
Array1::from_vec(out)
}
/// Trace of the operator projected onto the columns of `factor`.
///
/// Returns `0.0` for an empty factor or an empty kernel. Uses the tiled path
/// when the full J·F projection would exceed the materialization budget,
/// otherwise builds J·F once and sums the per-row quadratic forms.
fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
assert_eq!(factor.nrows(), self.p);
let rank = factor.ncols();
let n_rows = self.kern.n_rows();
if rank == 0 || n_rows == 0 {
return 0.0;
}
if jf_projection_exceeds_budget::<K>(n_rows, rank) {
return self.trace_projected_factor_tiled(factor);
}
let jf = self.compute_jf(factor);
self.trace_projected_factor_with_jf(factor, jf.view())
}
/// Same as `trace_projected_factor`, but obtains J·F through `cache` so it
/// can be reused. The over-budget tiled path does not use the cache.
fn trace_projected_factor_cached(
&self,
factor: &Array2<f64>,
cache: &ProjectedFactorCache,
) -> f64 {
assert_eq!(factor.nrows(), self.p);
let rank = factor.ncols();
let n_rows = self.kern.n_rows();
if rank == 0 || n_rows == 0 {
return 0.0;
}
if jf_projection_exceeds_budget::<K>(n_rows, rank) {
return self.trace_projected_factor_tiled(factor);
}
let jf = self.cached_jf(factor, cache);
self.trace_projected_factor_with_jf(factor, jf.view())
}
/// Returns the `rank x rank` projection of the operator onto `factor`.
///
/// Small problems (below the flop threshold) apply the operator with
/// `mul_mat` and multiply by `factor` transposed. Larger ones build J·F and
/// the weighted per-row contractions, then accumulate one matrix product per
/// axis pair.
fn projected_matrix(&self, factor: &Array2<f64>) -> Array2<f64> {
assert_eq!(factor.nrows(), self.p);
let rank = factor.ncols();
let n_rows = self.kern.n_rows();
if rank == 0 || n_rows == 0 {
return Array2::<f64>::zeros((rank, rank));
}
const BLAS3_PROJECTED_MATRIX_FLOP_THRESHOLD: usize = 2_500_000;
if n_rows.saturating_mul(rank).saturating_mul(rank) < BLAS3_PROJECTED_MATRIX_FLOP_THRESHOLD
{
let op_factor = self.mul_mat(factor);
return factor.t().dot(&op_factor);
}
let jf = self.compute_jf(factor);
assert_eq!(jf.dim(), (n_rows, K * rank));
let direction = self.direction.as_slice();
// `t_flat[row * K * K + a * K + b]` holds the weighted contraction entry
// (a, b) for `row`; rows that are never visited stay zero.
let t_flat = self.rows.par_reduce_fold(
n_rows,
|| vec![0.0_f64; n_rows * K * K],
|mut acc, row, w| {
let dir_k = self.kern.jacobian_action(row, direction);
let third = self
.kern
.row_third_contracted(row, &dir_k)
.expect("row-kernel third contraction should succeed for validated directions");
let base = row * (K * K);
for a in 0..K {
for b in 0..K {
acc[base + a * K + b] = w * third[a][b];
}
}
acc
},
|mut left, right| {
assert_eq!(left.len(), right.len());
for (l, r) in left.iter_mut().zip(right.iter()) {
*l += *r;
}
left
},
);
let mut out = Array2::<f64>::zeros((rank, rank));
// Split J·F into one contiguous `(n_rows, rank)` block per axis.
let mut jf_axis_blocks: Vec<Array2<f64>> = Vec::with_capacity(K);
for a in 0..K {
jf_axis_blocks.push(
jf.slice(s![.., a * rank..(a + 1) * rank])
.as_standard_layout()
.into_owned(),
);
}
let mut w_col = Array1::<f64>::zeros(n_rows);
let mut jf_a_weighted: Array2<f64> = Array2::<f64>::zeros((n_rows, rank));
// Only pairs with a <= b are computed; off-diagonal pairs also add the
// transposed contribution.
for a in 0..K {
for b in a..K {
for r in 0..n_rows {
w_col[r] = t_flat[r * (K * K) + a * K + b];
}
jf_a_weighted.assign(&jf_axis_blocks[a]);
for r in 0..n_rows {
let wr = w_col[r];
if wr == 0.0 {
// Write exact zeros instead of multiplying (presumably so a
// non-finite J·F entry cannot turn 0 * x into NaN).
for c in 0..rank {
jf_a_weighted[[r, c]] = 0.0;
}
} else {
for c in 0..rank {
jf_a_weighted[[r, c]] *= wr;
}
}
}
let contrib = jf_a_weighted.t().dot(&jf_axis_blocks[b]);
if a == b {
out += &contrib;
} else {
out += &contrib;
out += &contrib.t();
}
}
}
// Average with the transpose so the result is exactly symmetric.
let out_t = out.t().to_owned();
out += &out_t;
out.mapv_inplace(|v| 0.5 * v);
out
}
/// Materializes the dense `p x p` operator via
/// `row_kernel_directional_derivative`; panics if that fails.
fn to_dense(&self) -> Array2<f64> {
row_kernel_directional_derivative(&*self.kern, &self.rows, &self.direction)
.expect("row-kernel directional derivative dense materialization should succeed")
}
/// Always `true` for this operator.
fn is_implicit(&self) -> bool {
true
}
}
/// Private helpers for the projected-trace paths of the first-derivative operator.
impl<const K: usize, T: RowKernel<K>> RowKernelDirectionalDerivativeOperator<K, T> {
/// Builds the `(n_rows, K * rank)` J·F projection, using the kernel's
/// `jacobian_action_matrix` when it answers and the generic routine otherwise.
fn compute_jf(&self, factor: &Array2<f64>) -> Array2<f64> {
let n_rows = self.kern.n_rows();
let rank = factor.ncols();
let stride = K * rank;
if n_rows == 0 || rank == 0 {
return Array2::<f64>::zeros((n_rows, stride));
}
let jf = self
.kern
.jacobian_action_matrix(factor.view())
.unwrap_or_else(|| {
row_kernel_jacobian_action_matrix_generic(&*self.kern, factor.view())
});
assert_eq!(jf.dim(), (n_rows, stride));
jf
}
/// Fetches J·F from `cache`, computing it on a miss. The key is built from
/// the kernel's `Arc` pointer address and the factor view.
fn cached_jf(&self, factor: &Array2<f64>, cache: &ProjectedFactorCache) -> Arc<Array2<f64>> {
let design_id = Arc::as_ptr(&self.kern) as *const () as usize;
let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
cache.get_or_insert_with(key, || self.compute_jf(factor))
}
/// Sums, over all `n_rows` rows and all `rank` factor columns, the quadratic
/// form of the row's third-order contraction with the K-vector read from J·F.
///
/// NOTE(review): this loop visits every row with unit weight and never
/// consults `self.rows`, unlike `mul_vec` and `projected_matrix` — confirm
/// that is intended when `rows` is a subsample.
fn trace_projected_factor_with_jf(&self, factor: &Array2<f64>, jf: ArrayView2<'_, f64>) -> f64 {
let rank = factor.ncols();
let n_rows = self.kern.n_rows();
assert_eq!(jf.dim(), (n_rows, K * rank));
let direction = self.direction.as_slice();
deterministic_chunked_sum(n_rows, |chunk_idx| -> f64 {
let start = chunk_idx * ARROW_ROW_CHUNK;
let end = (start + ARROW_ROW_CHUNK).min(n_rows);
let mut chunk_total = 0.0_f64;
for row in start..end {
let dir_k = self.kern.jacobian_action(row, direction);
let third = self
.kern
.row_third_contracted(row, &dir_k)
.expect("row-kernel third contraction should succeed for validated directions");
let jf_row = jf.row(row);
let jf_slice = jf_row
.to_slice()
.expect("J·F is built standard-layout (row-major)");
let mut row_total = 0.0_f64;
for k_col in 0..rank {
// Gather the K components of factor column `k_col` (axis-major layout).
let mut vec_k = [0.0_f64; K];
for k in 0..K {
vec_k[k] = jf_slice[k * rank + k_col];
}
let mut quad = 0.0_f64;
for a in 0..K {
let mut t_dot = 0.0_f64;
for b in 0..K {
t_dot += third[a][b] * vec_k[b];
}
quad += vec_k[a] * t_dot;
}
row_total += quad;
}
chunk_total += row_total;
}
chunk_total
})
}
/// Same trace as `trace_projected_factor_with_jf`, but J·F is produced one
/// row tile at a time (`jf_tile_rows` rows per tile) so the full projection
/// is never held in memory.
///
/// NOTE(review): like the untiled path, this covers every row with unit
/// weight and does not consult `self.rows` — confirm against callers.
fn trace_projected_factor_tiled(&self, factor: &Array2<f64>) -> f64 {
let rank = factor.ncols();
let n_rows = self.kern.n_rows();
let direction = self.direction.as_slice();
let tile = jf_tile_rows::<K>(rank);
let mut total = 0.0_f64;
let mut tile_start = 0;
while tile_start < n_rows {
let tile_end = (tile_start + tile).min(n_rows);
let jf = self
.kern
.jacobian_action_matrix_rows(factor.view(), tile_start, tile_end);
let b = tile_end - tile_start;
total += deterministic_chunked_sum(b, |chunk_idx| -> f64 {
let start = chunk_idx * ARROW_ROW_CHUNK;
let end = (start + ARROW_ROW_CHUNK).min(b);
let mut chunk_total = 0.0_f64;
for local in start..end {
// `local` indexes into the tile; `row` is the global row number.
let row = tile_start + local;
let dir_k = self.kern.jacobian_action(row, direction);
let third = self.kern.row_third_contracted(row, &dir_k).expect(
"row-kernel third contraction should succeed for validated directions",
);
let jf_slice = jf
.row(local)
.to_slice()
.expect("J·F tile is built standard-layout (row-major)");
let mut row_total = 0.0_f64;
for k_col in 0..rank {
let mut vec_k = [0.0_f64; K];
for k in 0..K {
vec_k[k] = jf_slice[k * rank + k_col];
}
let mut quad = 0.0_f64;
for a in 0..K {
let mut t_dot = 0.0_f64;
for b2 in 0..K {
t_dot += third[a][b2] * vec_k[b2];
}
quad += vec_k[a] * t_dot;
}
row_total += quad;
}
chunk_total += row_total;
}
chunk_total
});
tile_start = tile_end;
}
total
}
}
/// Matrix-free operator for the second directional derivative of the
/// row-kernel Hessian along two fixed coefficient directions.
struct RowKernelSecondDirectionalDerivativeOperator<const K: usize, T: RowKernel<K>> {
/// Shared kernel that supplies Jacobian actions and fourth-order contractions.
kern: Arc<T>,
/// First fixed coefficient-space direction.
direction_u: Vec<f64>,
/// Second fixed coefficient-space direction.
direction_v: Vec<f64>,
/// Operator dimension returned by `dim`.
p: usize,
/// Row selection (with weights) used by `mul_vec` and `to_dense`.
rows: RowSet,
}
/// `HyperOperator` view of the second directional derivative: every method
/// works from per-row `K x K` fourth-order contractions instead of a stored
/// `p x p` matrix.
impl<const K: usize, T: RowKernel<K>> HyperOperator
for RowKernelSecondDirectionalDerivativeOperator<K, T>
{
/// Operator dimension (`p`).
fn dim(&self) -> usize {
self.p
}
/// Applies the operator to `v`.
///
/// Per selected row: contract the fourth-order term with `J * direction_u`
/// and `J * direction_v`, apply it to `J * v`, scale by the row weight, and
/// accumulate through the transposed Jacobian.
///
/// Panics if `v` is not contiguous or a fourth-order contraction fails.
fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
let direction = v
.as_slice()
.expect("row-kernel second directional derivative operator requires contiguous input");
let out = self.rows.par_reduce_fold(
self.kern.n_rows(),
|| vec![0.0_f64; self.p],
|mut acc, row, w| {
let dir_u = self.kern.jacobian_action(row, &self.direction_u);
let dir_v = self.kern.jacobian_action(row, &self.direction_v);
let vec_k = self.kern.jacobian_action(row, direction);
let fourth = self.kern.row_fourth_contracted(row, &dir_u, &dir_v).expect(
"row-kernel fourth contraction should succeed for validated directions",
);
let mut action = [0.0_f64; K];
for a in 0..K {
let mut sum = 0.0;
for b in 0..K {
sum += fourth[a][b] * vec_k[b];
}
action[a] = w * sum;
}
self.kern.jacobian_transpose_action(row, &action, &mut acc);
acc
},
|mut left, right| {
for idx in 0..left.len() {
left[idx] += right[idx];
}
left
},
);
Array1::from_vec(out)
}
/// Trace of the operator projected onto the columns of `factor`.
///
/// Returns `0.0` for an empty factor or an empty kernel. Uses the tiled path
/// when the full J·F projection would exceed the materialization budget,
/// otherwise builds J·F once and sums the per-row quadratic forms.
fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
assert_eq!(factor.nrows(), self.p);
let rank = factor.ncols();
let n_rows = self.kern.n_rows();
if rank == 0 || n_rows == 0 {
return 0.0;
}
if jf_projection_exceeds_budget::<K>(n_rows, rank) {
return self.trace_projected_factor_tiled(factor);
}
let jf = self.compute_jf(factor);
self.trace_projected_factor_with_jf(factor, jf.view())
}
/// Same as `trace_projected_factor`, but obtains J·F through `cache` so it
/// can be reused. The over-budget tiled path does not use the cache.
fn trace_projected_factor_cached(
&self,
factor: &Array2<f64>,
cache: &ProjectedFactorCache,
) -> f64 {
assert_eq!(factor.nrows(), self.p);
let rank = factor.ncols();
let n_rows = self.kern.n_rows();
if rank == 0 || n_rows == 0 {
return 0.0;
}
if jf_projection_exceeds_budget::<K>(n_rows, rank) {
return self.trace_projected_factor_tiled(factor);
}
let jf = self.cached_jf(factor, cache);
self.trace_projected_factor_with_jf(factor, jf.view())
}
/// Materializes the dense `p x p` operator via
/// `row_kernel_second_directional_derivative`; panics if that fails.
fn to_dense(&self) -> Array2<f64> {
row_kernel_second_directional_derivative(
&*self.kern,
&self.rows,
&self.direction_u,
&self.direction_v,
)
.expect("row-kernel second directional derivative dense materialization should succeed")
}
/// Always `true` for this operator.
fn is_implicit(&self) -> bool {
true
}
}
/// Private helpers for the projected-trace paths of the second-derivative operator.
impl<const K: usize, T: RowKernel<K>> RowKernelSecondDirectionalDerivativeOperator<K, T> {
/// Builds the `(n_rows, K * rank)` J·F projection, using the kernel's
/// `jacobian_action_matrix` when it answers and the generic routine otherwise.
fn compute_jf(&self, factor: &Array2<f64>) -> Array2<f64> {
let n_rows = self.kern.n_rows();
let rank = factor.ncols();
let stride = K * rank;
if n_rows == 0 || rank == 0 {
return Array2::<f64>::zeros((n_rows, stride));
}
let jf = self
.kern
.jacobian_action_matrix(factor.view())
.unwrap_or_else(|| {
row_kernel_jacobian_action_matrix_generic(&*self.kern, factor.view())
});
assert_eq!(jf.dim(), (n_rows, stride));
jf
}
/// Fetches J·F from `cache`, computing it on a miss. The key is built from
/// the kernel's `Arc` pointer address and the factor view.
fn cached_jf(&self, factor: &Array2<f64>, cache: &ProjectedFactorCache) -> Arc<Array2<f64>> {
let design_id = Arc::as_ptr(&self.kern) as *const () as usize;
let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
cache.get_or_insert_with(key, || self.compute_jf(factor))
}
/// Sums, over all `n_rows` rows and all `rank` factor columns, the quadratic
/// form of the row's fourth-order contraction with the K-vector read from J·F.
///
/// NOTE(review): this loop visits every row with unit weight and never
/// consults `self.rows`, unlike `mul_vec` — confirm that is intended when
/// `rows` is a subsample.
fn trace_projected_factor_with_jf(&self, factor: &Array2<f64>, jf: ArrayView2<'_, f64>) -> f64 {
let rank = factor.ncols();
let n_rows = self.kern.n_rows();
assert_eq!(jf.dim(), (n_rows, K * rank));
let direction_u = self.direction_u.as_slice();
let direction_v = self.direction_v.as_slice();
deterministic_chunked_sum(n_rows, |chunk_idx| -> f64 {
let start = chunk_idx * ARROW_ROW_CHUNK;
let end = (start + ARROW_ROW_CHUNK).min(n_rows);
let mut chunk_total = 0.0_f64;
for row in start..end {
let dir_u = self.kern.jacobian_action(row, direction_u);
let dir_v = self.kern.jacobian_action(row, direction_v);
let fourth = self.kern.row_fourth_contracted(row, &dir_u, &dir_v).expect(
"row-kernel fourth contraction should succeed for validated directions",
);
let jf_row = jf.row(row);
let jf_slice = jf_row
.to_slice()
.expect("J·F is built standard-layout (row-major)");
let mut row_total = 0.0_f64;
for k_col in 0..rank {
// Gather the K components of factor column `k_col` (axis-major layout).
let mut vec_k = [0.0_f64; K];
for k in 0..K {
vec_k[k] = jf_slice[k * rank + k_col];
}
let mut quad = 0.0_f64;
for a in 0..K {
let mut t_dot = 0.0_f64;
for b in 0..K {
t_dot += fourth[a][b] * vec_k[b];
}
quad += vec_k[a] * t_dot;
}
row_total += quad;
}
chunk_total += row_total;
}
chunk_total
})
}
/// Same trace as `trace_projected_factor_with_jf`, but J·F is produced one
/// row tile at a time (`jf_tile_rows` rows per tile) so the full projection
/// is never held in memory.
///
/// NOTE(review): like the untiled path, this covers every row with unit
/// weight and does not consult `self.rows` — confirm against callers.
fn trace_projected_factor_tiled(&self, factor: &Array2<f64>) -> f64 {
let rank = factor.ncols();
let n_rows = self.kern.n_rows();
let direction_u = self.direction_u.as_slice();
let direction_v = self.direction_v.as_slice();
let tile = jf_tile_rows::<K>(rank);
let mut total = 0.0_f64;
let mut tile_start = 0;
while tile_start < n_rows {
let tile_end = (tile_start + tile).min(n_rows);
let jf = self
.kern
.jacobian_action_matrix_rows(factor.view(), tile_start, tile_end);
let b = tile_end - tile_start;
total += deterministic_chunked_sum(b, |chunk_idx| -> f64 {
let start = chunk_idx * ARROW_ROW_CHUNK;
let end = (start + ARROW_ROW_CHUNK).min(b);
let mut chunk_total = 0.0_f64;
for local in start..end {
// `local` indexes into the tile; `row` is the global row number.
let row = tile_start + local;
let dir_u = self.kern.jacobian_action(row, direction_u);
let dir_v = self.kern.jacobian_action(row, direction_v);
let fourth = self.kern.row_fourth_contracted(row, &dir_u, &dir_v).expect(
"row-kernel fourth contraction should succeed for validated directions",
);
let jf_slice = jf
.row(local)
.to_slice()
.expect("J·F tile is built standard-layout (row-major)");
let mut row_total = 0.0_f64;
for k_col in 0..rank {
let mut vec_k = [0.0_f64; K];
for k in 0..K {
vec_k[k] = jf_slice[k * rank + k_col];
}
let mut quad = 0.0_f64;
for a in 0..K {
let mut t_dot = 0.0_f64;
for b2 in 0..K {
t_dot += fourth[a][b2] * vec_k[b2];
}
quad += vec_k[a] * t_dot;
}
row_total += quad;
}
chunk_total += row_total;
}
chunk_total
});
tile_start = tile_end;
}
total
}
}
/// Workspace that pairs a row kernel with its evaluated per-row cache and the
/// row selection the cache was built for.
pub struct RowKernelHessianWorkspace<const K: usize, T: RowKernel<K>> {
/// Shared kernel used for Jacobian actions and higher-order contractions.
kern: Arc<T>,
/// Per-row values, gradients and Hessians produced by `build_row_kernel_cache`.
cache: RowKernelCache<K>,
/// Row selection passed to `build_row_kernel_cache` and to every reduction.
rows: RowSet,
}
impl<const K: usize, T: RowKernel<K>> RowKernelHessianWorkspace<K, T> {
    /// Builds a workspace over every row of `kern`.
    ///
    /// # Errors
    /// Propagates any error from building the row cache.
    pub fn new(kern: T) -> Result<Self, String> {
        Self::with_rows(kern, RowSet::All)
    }

    /// Builds a workspace over the given row selection, evaluating the row
    /// cache eagerly.
    ///
    /// # Errors
    /// Propagates any error from `build_row_kernel_cache`.
    pub fn with_rows(kern: T, rows: RowSet) -> Result<Self, String> {
        let shared_kernel = Arc::new(kern);
        let built_cache = build_row_kernel_cache(&*shared_kernel, &rows)?;
        Ok(Self {
            kern: shared_kernel,
            cache: built_cache,
            rows,
        })
    }
}
impl<const K: usize, T: RowKernel<K> + 'static> ExactNewtonJointHessianWorkspace
for RowKernelHessianWorkspace<K, T>
{
fn warm_up_outer_caches(&self) -> Result<(), String> {
self.kern.warm_up_directional_caches()
}
fn joint_log_likelihood_evaluation(&self) -> Result<Option<f64>, String> {
Ok(Some(row_kernel_log_likelihood(&self.cache, &self.rows)))
}
fn joint_gradient_evaluation(
&self,
) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
Ok(Some(ExactNewtonJointGradientEvaluation {
log_likelihood: row_kernel_log_likelihood(&self.cache, &self.rows),
gradient: -row_kernel_gradient(&*self.kern, &self.cache, &self.rows),
}))
}
fn hessian_dense(&self) -> Result<Option<Array2<f64>>, String> {
Ok(Some(row_kernel_hessian_dense(
&*self.kern,
&self.cache,
&self.rows,
)))
}
fn hessian_source_preference_for_intent(
&self,
intent: MaterializationIntent,
) -> JointHessianSourcePreference {
match intent {
MaterializationIntent::InnerSolve
if use_joint_matrix_free_path(self.cache.p, self.cache.n) =>
{
JointHessianSourcePreference::Operator
}
MaterializationIntent::InnerSolve => JointHessianSourcePreference::Dense,
MaterializationIntent::LogdetFactorization
| MaterializationIntent::OuterEvaluation
| MaterializationIntent::OuterGradient => JointHessianSourcePreference::Dense,
}
}
fn hessian_matvec_available(&self) -> bool {
true
}
fn hessian_matvec(&self, v: &Array1<f64>) -> Result<Option<Array1<f64>>, String> {
let sl = v.as_slice().ok_or("hessian_matvec: non-contiguous input")?;
Ok(Some(row_kernel_hessian_matvec(
&*self.kern,
&self.cache,
&self.rows,
sl,
)))
}
fn hessian_matvec_into(&self, v: &Array1<f64>, out: &mut Array1<f64>) -> Result<bool, String> {
let result = self
.hessian_matvec(v)?
.ok_or_else(|| "row-kernel hessian_matvec unexpectedly unavailable".to_string())?;
if result.len() != out.len() {
return Err(format!(
"row-kernel hessian_matvec_into: result length {} != out length {}",
result.len(),
out.len()
));
}
out.assign(&result);
Ok(true)
}
fn hessian_diagonal(&self) -> Result<Option<Array1<f64>>, String> {
Ok(Some(row_kernel_hessian_diagonal(
&*self.kern,
&self.cache,
&self.rows,
)))
}
fn directional_derivative(
&self,
d_beta_flat: &Array1<f64>,
) -> Result<Option<Array2<f64>>, String> {
let sl = d_beta_flat
.as_slice()
.ok_or("directional_derivative: non-contiguous input")?;
row_kernel_directional_derivative(&*self.kern, &self.rows, sl).map(Some)
}
fn directional_derivative_operator(
&self,
d_beta_flat: &Array1<f64>,
) -> Result<Option<Arc<dyn HyperOperator>>, String> {
let direction = d_beta_flat
.as_slice()
.ok_or("directional_derivative_operator: non-contiguous input")?
.to_vec();
Ok(Some(Arc::new(RowKernelDirectionalDerivativeOperator {
kern: Arc::clone(&self.kern),
direction,
p: self.cache.p,
rows: self.rows.clone(),
})))
}
fn second_directional_derivative(
&self,
d_beta_u: &Array1<f64>,
d_beta_v: &Array1<f64>,
) -> Result<Option<Array2<f64>>, String> {
let su = d_beta_u
.as_slice()
.ok_or("second_directional_derivative: non-contiguous u")?;
let sv = d_beta_v
.as_slice()
.ok_or("second_directional_derivative: non-contiguous v")?;
row_kernel_second_directional_derivative(&*self.kern, &self.rows, su, sv).map(Some)
}
fn second_directional_derivative_operator(
&self,
d_beta_u: &Array1<f64>,
d_beta_v: &Array1<f64>,
) -> Result<Option<Arc<dyn HyperOperator>>, String> {
let direction_u = d_beta_u
.as_slice()
.ok_or("second_directional_derivative_operator: non-contiguous u")?
.to_vec();
let direction_v = d_beta_v
.as_slice()
.ok_or("second_directional_derivative_operator: non-contiguous v")?
.to_vec();
Ok(Some(Arc::new(
RowKernelSecondDirectionalDerivativeOperator {
kern: Arc::clone(&self.kern),
direction_u,
direction_v,
p: self.cache.p,
rows: self.rows.clone(),
},
)))
}
}
#[cfg(test)]
mod gram_inner_contraction_tests {
use super::*;
use crate::custom_family::{
JointHessianSource, exact_newton_joint_hessian_source_from_workspace,
};
use gam_problem::ProjectedFactorCache;
use ndarray::Array2;
/// Blocks supplied out of order (axis 2 before axis 0) must land at their own
/// axis offset, and the axis that was not supplied must stay zero.
#[test]
fn pack_jf_axes_places_blocks_in_primary_axis_order() {
    let block_axis0 = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let block_axis2 = Array2::from_shape_vec((2, 2), vec![5.0, 6.0, 7.0, 8.0]).unwrap();
    // Per row: [axis 0 | axis 1 (zeros) | axis 2], two columns each.
    let expected: Array2<f64> = Array2::from_shape_vec(
        (2, 6),
        vec![
            1.0, 2.0, 0.0, 0.0, 5.0, 6.0, //
            3.0, 4.0, 0.0, 0.0, 7.0, 8.0,
        ],
    )
    .unwrap();

    let packed = row_kernel_pack_jf_axes::<3>(2, 2, [(2, block_axis2), (0, block_axis0)]);

    assert_eq!(packed.dim(), (2, 6));
    assert_eq!(packed, expected);
}
/// Matching lengths pass; mismatching ones are all listed in a single message.
#[test]
fn validate_row_kernel_cache_lengths_reports_all_mismatches() {
    let matching = validate_row_kernel_cache_lengths("ctx", 3, &[("third", 3), ("fourth", 3)]);
    matching.expect("matching lengths pass");

    let mismatching = validate_row_kernel_cache_lengths("ctx", 3, &[("third", 2), ("fourth", 4)]);
    let message = mismatching.expect_err("mismatches fail");
    assert_eq!(
        message,
        "ctx row-kernel cache length mismatch: third=2 fourth=4 expected=3"
    );
}
/// Deterministic synthetic [`RowKernel`] with four primary axes, used as a
/// fixture by the tests in this module.
struct SyntheticKernel {
// Number of rows; `row_kernel` rejects any `row >= n`.
n: usize,
// Number of coefficients, i.e. columns of each design matrix.
p: usize,
// One `(n, p)` design matrix per primary axis.
designs: [Array2<f64>; 4],
}
impl SyntheticKernel {
    /// Builds four `(n, p)` design matrices from a seeded linear congruential
    /// generator, so identical arguments always give identical designs.
    ///
    /// Each draw keeps the top 31 bits of the state, divides by `u32::MAX` and
    /// subtracts 0.5, so every entry lies in `[-0.5, 0)`.
    fn new(n: usize, p: usize, seed: u64) -> Self {
        let mut state = seed;
        let mut draw = move || -> f64 {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            ((state >> 33) as f64 / (u32::MAX as f64)) - 0.5
        };
        // Matrices are filled one after another, in axis order 0..4.
        let designs: [Array2<f64>; 4] =
            std::array::from_fn(|_| Array2::from_shape_fn((n, p), |_| draw()));
        Self { n, p, designs }
    }
}
impl RowKernel<4> for SyntheticKernel {
    fn n_rows(&self) -> usize {
        self.n
    }

    fn n_coefficients(&self) -> usize {
        self.p
    }

    /// Returns `(value, gradient, hessian)` for `row`.
    ///
    /// The gradient entry for axis `k` is the sum of that axis' design row,
    /// the Hessian is diagonal, and the value is half the squared gradient
    /// norm. Rows at or beyond `n` are rejected.
    fn row_kernel(&self, row: usize) -> Result<(f64, [f64; 4], [[f64; 4]; 4]), String> {
        if row >= self.n {
            return Err(format!("synthetic row {row} outside n={}", self.n));
        }
        let grad: [f64; 4] = std::array::from_fn(|k| self.designs[k].row(row).sum());
        let mut hess = [[0.0_f64; 4]; 4];
        for (k, hess_row) in hess.iter_mut().enumerate() {
            hess_row[k] = 1.0 + (row as f64 + k as f64).abs() * 1.0e-6;
        }
        let value = 0.5 * grad.iter().map(|g| g * g).sum::<f64>();
        Ok((value, grad, hess))
    }

    /// Per-axis dot product of the design row with the first `p` entries of
    /// `d_beta`.
    fn jacobian_action(&self, row: usize, d_beta: &[f64]) -> [f64; 4] {
        let coefficients = &d_beta[..self.p];
        std::array::from_fn(|k| {
            let design_row = self.designs[k].row(row);
            design_row
                .iter()
                .zip(coefficients)
                .fold(0.0_f64, |acc, (x, beta)| acc + x * beta)
        })
    }

    /// Accumulates `sum_k v[k] * design_k[row, j]` into `out[j]`.
    fn jacobian_transpose_action(&self, row: usize, v: &[f64; 4], out: &mut [f64]) {
        for (design, &weight) in self.designs.iter().zip(v) {
            let design_row = design.row(row);
            for (slot, &x) in out[..self.p].iter_mut().zip(design_row.iter()) {
                *slot += x * weight;
            }
        }
    }

    /// Adds `h[a][b] * x_a x_b^T` to `target` for every axis pair with a
    /// non-zero `h[a][b]`, where `x_k` is the design row of axis `k`.
    fn add_pullback_hessian(&self, row: usize, h: &[[f64; 4]; 4], target: &mut Array2<f64>) {
        for (a, h_row) in h.iter().enumerate() {
            let row_a = self.designs[a].row(row);
            for (b, &scale) in h_row.iter().enumerate() {
                if scale == 0.0 {
                    continue;
                }
                let row_b = self.designs[b].row(row);
                for i in 0..self.p {
                    let left = scale * row_a[i];
                    for j in 0..self.p {
                        target[[i, j]] += left * row_b[j];
                    }
                }
            }
        }
    }

    /// Adds the diagonal of the pulled-back quadratic form to `diag`.
    fn add_diagonal_quadratic(&self, row: usize, h: &[[f64; 4]; 4], diag: &mut [f64]) {
        for (j, slot) in diag[..self.p].iter_mut().enumerate() {
            let mut acc = 0.0_f64;
            for (a, h_row) in h.iter().enumerate() {
                let x_a = self.designs[a][[row, j]];
                for (b, &h_ab) in h_row.iter().enumerate() {
                    acc += h_ab * x_a * self.designs[b][[row, j]];
                }
            }
            *slot += acc;
        }
    }

    /// Synthetic symmetric 4x4 matrix depending on `row` and `dir`.
    ///
    /// Only the upper triangle is evaluated; each entry is mirrored below the
    /// diagonal.
    fn row_third_contracted(
        &self,
        row: usize,
        dir: &[f64; 4],
    ) -> Result<[[f64; 4]; 4], String> {
        let phase = (row as f64) * 0.013;
        let mut contracted = [[0.0_f64; 4]; 4];
        for a in 0..4 {
            for b in a..4 {
                let base = (phase + a as f64 * 0.7 + b as f64 * 1.3).sin();
                let entry = base + dir[a] * 0.25 + dir[b] * 0.5 + dir[(a + b) % 4] * 0.125;
                contracted[a][b] = entry;
                contracted[b][a] = entry;
            }
        }
        Ok(contracted)
    }

    /// Synthetic symmetric 4x4 matrix depending on `row`, `dir_u` and `dir_v`.
    ///
    /// Only the upper triangle is evaluated; each entry is mirrored below the
    /// diagonal.
    fn row_fourth_contracted(
        &self,
        row: usize,
        dir_u: &[f64; 4],
        dir_v: &[f64; 4],
    ) -> Result<[[f64; 4]; 4], String> {
        let phase = (row as f64) * 0.011 + 0.31;
        let mut contracted = [[0.0_f64; 4]; 4];
        for a in 0..4 {
            for b in a..4 {
                let base = (phase + a as f64 * 0.9 + b as f64 * 1.7).cos();
                let entry = base
                    + dir_u[a] * 0.13
                    + dir_v[b] * 0.27
                    + dir_u[(a + b) % 4] * dir_v[(a + 1) % 4] * 0.05;
                contracted[a][b] = entry;
                contracted[b][a] = entry;
            }
        }
        Ok(contracted)
    }
}
/// With `p` equal to `JOINT_MATRIX_FREE_MIN_DIM`, an inner-solve request must
/// yield an operator source whose `apply` and `apply_into` agree exactly.
#[test]
fn row_kernel_workspace_routes_inner_solve_to_operator() {
    let p = crate::custom_family::JOINT_MATRIX_FREE_MIN_DIM;
    let workspace: Arc<dyn ExactNewtonJointHessianWorkspace> = Arc::new(
        RowKernelHessianWorkspace::new(SyntheticKernel::new(8, p, 0x979)).expect("workspace"),
    );

    let source = exact_newton_joint_hessian_source_from_workspace(
        &workspace,
        p,
        MaterializationIntent::InnerSolve,
        "row-kernel inner source",
    )
    .expect("source construction succeeds")
    .expect("source is present");

    let (apply, apply_into, diagonal) = match source {
        JointHessianSource::Operator {
            apply,
            apply_into,
            diagonal,
            ..
        } => (apply, apply_into, diagonal),
        _ => panic!("row-kernel inner solve must use operator source"),
    };
    assert_eq!(diagonal.len(), p);

    let probe = Array1::from_shape_fn(p, |i| (i as f64 % 7.0 - 3.0) * 0.125);
    let returned = apply(&probe).expect("operator apply succeeds");
    let mut written = Array1::<f64>::zeros(p);
    apply_into(&probe, &mut written).expect("operator apply_into succeeds");
    assert_eq!(returned, written);
}
/// Straightforward reference for the first-derivative projected trace.
///
/// For every row and every factor column `f`, accumulates `w^T T w`, where
/// `w = jacobian_action(row, f)` and `T` is the row's third contraction along
/// `jacobian_action(row, direction)`.
fn reference_trace_first<const K: usize>(
    kern: &impl RowKernel<K>,
    direction: &[f64],
    factor: &Array2<f64>,
) -> f64 {
    let p = factor.nrows();
    // Copy each factor column into contiguous storage once, up front.
    let columns: Vec<Vec<f64>> = (0..factor.ncols())
        .map(|k_col| (0..p).map(|j| factor[[j, k_col]]).collect())
        .collect();

    let mut total = 0.0_f64;
    for row in 0..kern.n_rows() {
        let row_direction = kern.jacobian_action(row, direction);
        let third = kern
            .row_third_contracted(row, &row_direction)
            .expect("third");
        for column in &columns {
            let vec_k = kern.jacobian_action(row, column);
            let mut quad = 0.0_f64;
            for a in 0..K {
                let t_dot = (0..K).fold(0.0_f64, |acc, b| acc + third[a][b] * vec_k[b]);
                quad += vec_k[a] * t_dot;
            }
            total += quad;
        }
    }
    total
}
/// Straightforward reference for the second-derivative projected trace.
///
/// For every row and every factor column `f`, accumulates `w^T Q w`, where
/// `w = jacobian_action(row, f)` and `Q` is the row's fourth contraction along
/// the row images of `direction_u` and `direction_v`.
fn reference_trace_second<const K: usize>(
    kern: &impl RowKernel<K>,
    direction_u: &[f64],
    direction_v: &[f64],
    factor: &Array2<f64>,
) -> f64 {
    let p = factor.nrows();
    // Copy each factor column into contiguous storage once, up front.
    let columns: Vec<Vec<f64>> = (0..factor.ncols())
        .map(|k_col| (0..p).map(|j| factor[[j, k_col]]).collect())
        .collect();

    let mut total = 0.0_f64;
    for row in 0..kern.n_rows() {
        let row_u = kern.jacobian_action(row, direction_u);
        let row_v = kern.jacobian_action(row, direction_v);
        let fourth = kern
            .row_fourth_contracted(row, &row_u, &row_v)
            .expect("fourth");
        for column in &columns {
            let vec_k = kern.jacobian_action(row, column);
            let mut quad = 0.0_f64;
            for a in 0..K {
                let t_dot = (0..K).fold(0.0_f64, |acc, b| acc + fourth[a][b] * vec_k[b]);
                quad += vec_k[a] * t_dot;
            }
            total += quad;
        }
    }
    total
}
/// The cached and uncached projected traces of both derivative operators must
/// match the reference sums to a relative error below 1e-10.
#[test]
fn gram_inner_contraction_matches_reference() {
    let (n, p, rank) = (32, 11, 7);
    let kern = Arc::new(SyntheticKernel::new(n, p, 0xC0FFEE));

    // Same generator recipe as the synthetic kernel, with its own seed.
    let mut state = 0xDEADBEEF_u64;
    let mut draw = || -> f64 {
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((state >> 33) as f64 / (u32::MAX as f64)) - 0.5
    };
    // Draw order matters: direction, direction_u, direction_v, then factor.
    let direction: Vec<f64> = (0..p).map(|_| draw()).collect();
    let direction_u: Vec<f64> = (0..p).map(|_| draw()).collect();
    let direction_v: Vec<f64> = (0..p).map(|_| draw()).collect();
    let factor = Array2::from_shape_fn((p, rank), |_| draw());

    let relative_error =
        |got: f64, reference: f64| (got - reference).abs() / reference.abs().max(1e-12);

    // First directional derivative.
    let op1 = RowKernelDirectionalDerivativeOperator {
        kern: Arc::clone(&kern),
        direction: direction.clone(),
        p,
        rows: RowSet::All,
    };
    let cache = ProjectedFactorCache::default();
    let got1_uncached = HyperOperator::trace_projected_factor(&op1, &factor);
    let got1_cached = op1.trace_projected_factor_cached(&factor, &cache);
    let ref1 = reference_trace_first::<4>(&*kern, &direction, &factor);
    let rel1_uncached = relative_error(got1_uncached, ref1);
    let rel1_cached = relative_error(got1_cached, ref1);
    assert!(
        rel1_uncached < 1e-10,
        "first-derivative Gram path drifted: rel={rel1_uncached:.3e} got={got1_uncached} ref={ref1}",
    );
    assert!(
        rel1_cached < 1e-10,
        "first-derivative cached Gram path drifted: rel={rel1_cached:.3e} got={got1_cached} ref={ref1}",
    );

    // Second directional derivative.
    let op2 = RowKernelSecondDirectionalDerivativeOperator {
        kern: Arc::clone(&kern),
        direction_u: direction_u.clone(),
        direction_v: direction_v.clone(),
        p,
        rows: RowSet::All,
    };
    let cache2 = ProjectedFactorCache::default();
    let got2_uncached = HyperOperator::trace_projected_factor(&op2, &factor);
    let got2_cached = op2.trace_projected_factor_cached(&factor, &cache2);
    let ref2 = reference_trace_second::<4>(&*kern, &direction_u, &direction_v, &factor);
    let rel2_uncached = relative_error(got2_uncached, ref2);
    let rel2_cached = relative_error(got2_cached, ref2);
    assert!(
        rel2_uncached < 1e-10,
        "second-derivative Gram path drifted: rel={rel2_uncached:.3e} got={got2_uncached} ref={ref2}",
    );
    assert!(
        rel2_cached < 1e-10,
        "second-derivative cached Gram path drifted: rel={rel2_cached:.3e} got={got2_cached} ref={ref2}",
    );
}
/// Test kernel that forwards every [`RowKernel`] call to an inner
/// [`SyntheticKernel`] and carries a shared construction counter.
struct BuildCountingKernel {
// Kernel that actually answers the `RowKernel` queries.
inner: SyntheticKernel,
// Incremented once in `new`; held in an `Arc` so a test can clone the handle
// and read the count.
builds: std::sync::Arc<std::sync::atomic::AtomicUsize>,
}
impl BuildCountingKernel {
    /// Wraps a fresh `SyntheticKernel::new(n, p, seed)` and records this
    /// construction on the instance's own counter, which therefore reads 1.
    fn new(n: usize, p: usize, seed: u64) -> Self {
        let kernel = Self {
            inner: SyntheticKernel::new(n, p, seed),
            builds: std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        };
        kernel
            .builds
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        kernel
    }
}
impl RowKernel<4> for BuildCountingKernel {
fn n_rows(&self) -> usize {
self.inner.n_rows()
}
fn n_coefficients(&self) -> usize {
self.inner.n_coefficients()
}
fn row_kernel(&self, row: usize) -> Result<(f64, [f64; 4], [[f64; 4]; 4]), String> {
self.inner.row_kernel(row)
}
fn jacobian_action(&self, row: usize, d_beta: &[f64]) -> [f64; 4] {
self.inner.jacobian_action(row, d_beta)
}
fn jacobian_transpose_action(&self, row: usize, v: &[f64; 4], out: &mut [f64]) {
self.inner.jacobian_transpose_action(row, v, out)
}
fn add_pullback_hessian(&self, row: usize, h: &[[f64; 4]; 4], target: &mut Array2<f64>) {
self.inner.add_pullback_hessian(row, h, target)
}
fn add_diagonal_quadratic(&self, row: usize, h: &[[f64; 4]; 4], diag: &mut [f64]) {
self.inner.add_diagonal_quadratic(row, h, diag)
}
fn row_third_contracted(
&self,
row: usize,
dir: &[f64; 4],
) -> Result<[[f64; 4]; 4], String> {
self.inner.row_third_contracted(row, dir)
}
fn row_fourth_contracted(
&self,
row: usize,
dir_u: &[f64; 4],
dir_v: &[f64; 4],
) -> Result<[[f64; 4]; 4], String> {
self.inner.row_fourth_contracted(row, dir_u, dir_v)
}
}
/// The batched all-axes directional derivative must return one matrix per
/// coefficient axis, each within 1e-9 of the single-axis evaluation along the
/// matching unit vector.
///
/// NOTE(review): the build counter is per-instance and is only incremented in
/// `BuildCountingKernel::new`, so the `n_builds == 1` assertion can only
/// observe the single construction performed by this test.
#[test]
fn all_axes_directional_derivative_is_build_once_and_matches_per_axis_loop_979() {
    let (n, p) = (24usize, 7usize);
    let kern = BuildCountingKernel::new(n, p, 0x979);
    let builds = std::sync::Arc::clone(&kern.builds);

    let batched =
        row_kernel_directional_derivative_all_axes(&kern, &RowSet::All).expect("all-axes ok");

    let n_builds = builds.load(std::sync::atomic::Ordering::Relaxed);
    assert_eq!(
        n_builds, 1,
        "all-axes Jeffreys sweep rebuilt the row kernel {n_builds} times for p={p}; \
         the #979 fix must build it ONCE and sweep every axis off that single kernel \
         (a revert to the per-axis `SurvivalMarginalSlopeRowKernel::new` loop reads p)",
    );
    assert_eq!(batched.len(), p, "one Hdot matrix per coefficient axis");

    for (a, hdot_a) in batched.iter().enumerate() {
        // Unit vector along coefficient axis `a`.
        let e_a: Vec<f64> = (0..p).map(|j| if j == a { 1.0 } else { 0.0 }).collect();
        let per_axis = row_kernel_directional_derivative(&kern, &RowSet::All, &e_a)
            .expect("per-axis directional derivative ok");
        assert_eq!(hdot_a.dim(), per_axis.dim(), "axis {a} shape mismatch");

        let max_abs = hdot_a
            .iter()
            .zip(per_axis.iter())
            .fold(0.0_f64, |worst, (g, r)| worst.max((g - r).abs()));
        assert!(
            max_abs <= 1e-9,
            "batched all-axes Hdot[e_{a}] diverged from the per-axis sweep by \
             {max_abs:.3e} (> 1e-9); the #979 build-once route must be numerically \
             identical to the per-axis loop it replaced",
        );
    }
}
}