use std::sync::Arc;
use ndarray::{Array1, Array2, Array3, Axis, s};
use crate::linalg::faer_ndarray::{
FaerEigh, default_rrqr_rank_alpha, fast_ab, fast_ata, fast_atb, fast_xt_diag_y,
rrqr_with_permutation,
};
use faer::Side;
/// Row-wise Jacobian of one parameter block, as consumed by `compile` /
/// `compile_with_dual_metric`.
///
/// `compile_with_dual_metric` rejects an operator whose `k()` or `nrows()`
/// differs from the row Hessian it is compiled against.
pub trait RowJacobianOperator: Send + Sync {
    /// Number of per-row output channels.
    fn k(&self) -> usize;
    /// Number of columns (block width).
    fn ncols(&self) -> usize;
    /// Number of rows.
    fn nrows(&self) -> usize;
    /// Applies the Jacobian of `row` to `delta_beta`, writing the result to `out`.
    /// NOTE(review): presumably `delta_beta.len() == ncols()` and
    /// `out.len() == k()`; the trait itself does not enforce this — confirm per impl.
    fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]);
    /// Returns the dense Jacobian laid out as `(nrows, ncols, k)`
    /// (the layout `scale_block_by_sqrt_h` indexes it with).
    fn evaluate_full(&self) -> Array3<f64>;
}
/// Per-row `k × k` metric (e.g. a row's curvature) used to weight Jacobians.
///
/// The compiler takes a per-row symmetric square root of each block
/// (`symmetric_sqrt_into`), clamping negative eigenvalues to zero, so the
/// blocks are expected to be symmetric.
pub trait RowHessian: Send + Sync {
    /// Number of channels; each row block is `k × k`.
    fn k(&self) -> usize;
    /// Number of rows.
    fn nrows(&self) -> usize;
    /// Writes the `k × k` block of `row` into `out` in row-major order
    /// (`out[c * k + d]`); implementations in this file assert `out.len() == k * k`.
    fn fill_row(&self, row: usize, out: &mut [f64]);
    /// Returns all row blocks as a dense `(nrows, k, k)` array.
    fn evaluate_full(&self) -> Array3<f64>;
}
/// Row metric whose every per-row `k × k` block is the identity matrix.
///
/// `compile` uses it as the structural metric, i.e. structural
/// residualisation there happens in the Euclidean metric of the stacked
/// Jacobian rows.
pub struct IdentityRowHessian {
    // Number of rows reported by `nrows()`.
    n: usize,
    // Channel count reported by `k()`; each row block is `k × k`.
    k: usize,
}
impl IdentityRowHessian {
    /// Creates an identity metric over `n` rows with `k` channels per row.
    pub fn new(n: usize, k: usize) -> Self {
        Self { n, k }
    }
}
impl RowHessian for IdentityRowHessian {
    /// Channel count given at construction.
    fn k(&self) -> usize {
        self.k
    }

    /// Row count given at construction.
    fn nrows(&self) -> usize {
        self.n
    }

    /// Writes the `k × k` identity (row-major) into `out`.
    ///
    /// # Panics
    /// If `row >= nrows()` or `out.len() != k * k`.
    fn fill_row(&self, row: usize, out: &mut [f64]) {
        assert!(
            row < self.n,
            "IdentityRowHessian::fill_row row {row} out of range {n}",
            n = self.n
        );
        assert_eq!(out.len(), self.k * self.k);
        out.fill(0.0);
        // In row-major storage the diagonal entries sit every `k + 1` slots.
        out.iter_mut()
            .step_by(self.k + 1)
            .for_each(|slot| *slot = 1.0);
    }

    /// Returns `n` stacked `k × k` identity matrices.
    fn evaluate_full(&self) -> Array3<f64> {
        Array3::from_shape_fn((self.n, self.k, self.k), |(_, c, d)| {
            if c == d {
                1.0
            } else {
                0.0
            }
        })
    }
}
/// Produces anchor-design rows for a given prediction argument.
///
/// NOTE(review): within this module the evaluator is only stored on
/// `CompiledBlock::anchor_evaluator` (and `compile_with_dual_metric` sets it to
/// `None`); its exact contract is defined by callers elsewhere — confirm.
pub trait AnchorRowEvaluator: Send + Sync {
    /// Returns the anchor rows for `predict_arg`, or an error message.
    /// NOTE(review): presumably one output row per element of `predict_arg`.
    fn anchor_rows(&self, predict_arg: &Array1<f64>) -> Result<Array2<f64>, String>;
    /// Number of columns in the rows returned by `anchor_rows`.
    fn ncols(&self) -> usize;
}
/// One parameter block after compilation by `compile_with_dual_metric`.
pub struct CompiledBlock {
    /// Raw-to-compiled basis for this block, `p_b × kept`; the raw block
    /// design times `t_lw` gives the block's compiled columns before the
    /// anchor correction is subtracted.
    pub t_lw: Array2<f64>,
    /// Projection coefficients onto all previously compiled columns
    /// (`previous compiled cols × kept`). Subtracting
    /// `previous_compiled_design · anchor_correction` from `raw · t_lw` gives a
    /// design orthogonal to earlier blocks in the curvature metric. `None` for
    /// the first block, which has nothing to project onto.
    pub anchor_correction: Option<Array2<f64>>,
    /// Set to the same matrix as `anchor_correction` by `compile_with_dual_metric`;
    /// the audit truncates both in lockstep.
    pub r_lw: Option<Array2<f64>>,
    /// Optional anchor-row evaluator; `compile_with_dual_metric` leaves it `None`.
    pub anchor_evaluator: Option<Arc<dyn AnchorRowEvaluator>>,
}
/// Result of `compile` / `compile_with_dual_metric`.
pub struct CompiledBlocks {
    /// Compiled blocks in the same order as the input operators.
    pub blocks: Vec<CompiledBlock>,
    /// Total number of compiled columns across all blocks (after the audit).
    pub joint_rank: usize,
    /// `(block index, local column index)` pairs removed by the final RRQR audit.
    pub dropped: Vec<(usize, usize)>,
}
/// Failure modes of the identifiability compiler.
#[derive(Debug)]
pub enum CompilerError {
    /// Input lengths or shapes disagree; the payload says which.
    DimensionMismatch(String),
    /// Block `block_idx` retains no direction after residualisation.
    FullyAliased { block_idx: usize, reason: String },
    /// A numerical factorisation (eigendecomposition or RRQR) failed.
    LinalgFailure(String),
}

impl std::fmt::Display for CompilerError {
    /// Renders a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::LinalgFailure(msg) => write!(f, "linalg failure: {msg}"),
            Self::DimensionMismatch(msg) => write!(f, "dimension mismatch: {msg}"),
            Self::FullyAliased { block_idx, reason } => {
                write!(f, "block {block_idx} fully aliased: {reason}")
            }
        }
    }
}

impl std::error::Error for CompilerError {}
/// Role tag for each block passed to the compiler.
///
/// NOTE(review): within this module the ordering slice is only checked to have
/// the same length as the operators / raw ranges; blocks are processed in slice
/// order regardless of tag. The variant semantics are defined by callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockOrder {
    Time,
    Marginal,
    Logslope,
    ScoreWarp,
    LinkDev,
}
/// Compiles `operators` using `row_hess` as the curvature metric and the
/// identity as the structural metric.
///
/// This is `compile_with_dual_metric` with an `IdentityRowHessian` sized to
/// match `row_hess`; see that function for the algorithm and error cases.
pub fn compile(
    operators: &[Arc<dyn RowJacobianOperator>],
    row_hess: &dyn RowHessian,
    ordering: &[BlockOrder],
) -> Result<CompiledBlocks, CompilerError> {
    let structural = IdentityRowHessian::new(row_hess.nrows(), row_hess.k());
    compile_with_dual_metric(operators, row_hess, &structural, ordering)
}
/// Compiles `operators` into metric-orthogonalised blocks using two row metrics.
///
/// Blocks are processed in slice order. For block `b` with Jacobian `J_b`:
///
/// 1. **Structural stage.** `W_s = S^{1/2} J_b` (per-row square root of
///    `row_structural`) is residualised against the structural anchor built
///    from earlier blocks. The positive eigenspace `D` of the residual Gram is
///    kept; an empty `D` is `FullyAliased`.
/// 2. **Curvature stage.** `W_h D` (with `W_h = H^{1/2} J_b`) is residualised
///    against the curvature anchor. The positive eigenspace `T` of that
///    residual Gram is kept; an empty `T` is `FullyAliased`.
/// 3. The block's basis is `t_lw = D T`, and the projection coefficients onto
///    earlier compiled columns become `anchor_correction` (duplicated into
///    `r_lw`).
///
/// Eigenvalue thresholds are measured relative to the *raw* (pre-residual)
/// Gram trace of the block, so that a block whose residual is pure round-off
/// is reported as aliased rather than kept. After all blocks, an RRQR audit of
/// the joint curvature anchor may drop trailing columns of the last block
/// (reported in `dropped`).
///
/// `ordering` is only length-checked against `operators` here.
///
/// # Errors
/// - `DimensionMismatch` if `operators`/`ordering` lengths differ, or any
///   operator / the structural metric disagrees with `row_hess` on `k` or `nrows`.
/// - `FullyAliased` if a block keeps no direction in either stage.
/// - `LinalgFailure` if an eigendecomposition or the audit RRQR fails.
pub fn compile_with_dual_metric(
    operators: &[Arc<dyn RowJacobianOperator>],
    row_hess: &dyn RowHessian,
    row_structural: &dyn RowHessian,
    ordering: &[BlockOrder],
) -> Result<CompiledBlocks, CompilerError> {
    if operators.len() != ordering.len() {
        return Err(CompilerError::DimensionMismatch(format!(
            "operators ({}) and ordering ({}) length mismatch",
            operators.len(),
            ordering.len()
        )));
    }
    if operators.is_empty() {
        return Ok(CompiledBlocks {
            blocks: Vec::new(),
            joint_rank: 0,
            dropped: Vec::new(),
        });
    }
    let k = row_hess.k();
    let n = row_hess.nrows();
    if row_structural.k() != k {
        return Err(CompilerError::DimensionMismatch(format!(
            "structural row metric has K={} but curvature row Hessian has K={k}",
            row_structural.k()
        )));
    }
    if row_structural.nrows() != n {
        return Err(CompilerError::DimensionMismatch(format!(
            "structural row metric has nrows={} but curvature row Hessian has nrows={n}",
            row_structural.nrows()
        )));
    }
    for (idx, op) in operators.iter().enumerate() {
        if op.k() != k {
            return Err(CompilerError::DimensionMismatch(format!(
                "operator {idx} has K={} but row Hessian has K={k}",
                op.k()
            )));
        }
        if op.nrows() != n {
            return Err(CompilerError::DimensionMismatch(format!(
                "operator {idx} has nrows={} but row Hessian has nrows={n}",
                op.nrows()
            )));
        }
    }
    let h_full = row_hess.evaluate_full();
    let s_full = row_structural.evaluate_full();
    let j_full: Vec<Array3<f64>> = operators.iter().map(|op| op.evaluate_full()).collect();
    // Stacked, metric-scaled designs: (n * k) rows, one column per raw parameter.
    let scaled_h: Vec<Array2<f64>> = j_full
        .iter()
        .map(|jb| scale_block_by_sqrt_h(jb, &h_full))
        .collect();
    let scaled_s: Vec<Array2<f64>> = j_full
        .iter()
        .map(|jb| scale_block_by_sqrt_h(jb, &s_full))
        .collect();
    let mut compiled: Vec<CompiledBlock> = Vec::with_capacity(operators.len());
    // Cumulative spans of all previously compiled columns in each metric.
    let mut anchor_h: Array2<f64> = Array2::zeros((n * k, 0));
    let mut anchor_s: Array2<f64> = Array2::zeros((n * k, 0));
    for idx in 0..operators.len() {
        let w_h = &scaled_h[idx];
        let w_s = &scaled_s[idx];
        let p_b = w_h.ncols();

        // Structural stage.
        let (residual_s, _) = residualise_in_metric(&anchor_s, w_s)?;
        let g_s = fast_atb(&residual_s, &residual_s);
        // Threshold against the raw block Gram trace tr(W_sᵀW_s) = ‖W_s‖²_F, not
        // the residual trace: when the block is fully aliased, the residual Gram
        // is round-off noise and a residual-relative threshold would keep it.
        // This matches `compile_from_raw_grams`, which passes the raw `g_s_bb` trace.
        let w_s_trace: f64 = w_s.iter().map(|v| v * v).sum();
        let d = keep_positive_eigenspace(&g_s, n, k, w_s_trace)?;
        if d.ncols() == 0 {
            return Err(CompilerError::FullyAliased {
                block_idx: idx,
                reason: format!(
                    "structural residual Gram has no positive eigenspace (block of width {p_b} fully aliased by cumulative structural anchor)"
                ),
            });
        }

        // Curvature stage, restricted to the structurally kept directions.
        let w_h_d = fast_ab(w_h, &d);
        let (residual_h, m_h_inner_opt) = residualise_in_metric(&anchor_h, &w_h_d)?;
        let g_h = fast_atb(&residual_h, &residual_h);
        let p_d = d.ncols();
        // Same reasoning: scale by the pre-residual Gram trace ‖W_h D‖²_F.
        let w_h_d_trace: f64 = w_h_d.iter().map(|v| v * v).sum();
        let t_inner = keep_positive_eigenspace(&g_h, n, k, w_h_d_trace)?;
        if t_inner.ncols() == 0 {
            return Err(CompilerError::FullyAliased {
                block_idx: idx,
                reason: format!(
                    "curvature residual Gram has no positive eigenspace within structurally-kept basis (block of width {p_b}, structural-kept {p_d})"
                ),
            });
        }

        // Compose the two stages and extend both anchors with this block.
        let v = fast_ab(&d, &t_inner);
        let residual_h_t = fast_ab(&residual_h, &t_inner);
        anchor_h = concat_cols(&anchor_h, &residual_h_t);
        let residual_s_v = fast_ab(&residual_s, &v);
        anchor_s = concat_cols(&anchor_s, &residual_s_v);
        let m_compiled = m_h_inner_opt.as_ref().map(|m| fast_ab(m, &t_inner));
        compiled.push(CompiledBlock {
            t_lw: v,
            anchor_correction: m_compiled.clone(),
            r_lw: m_compiled,
            anchor_evaluator: None,
        });
    }
    let dropped = audit_and_drop_trailing_pivots(&anchor_h, &mut compiled)?;
    let joint_rank: usize = compiled.iter().map(|b| b.t_lw.ncols()).sum();
    Ok(CompiledBlocks {
        blocks: compiled,
        joint_rank,
        dropped,
    })
}
/// Stacks `H_i^{1/2} J_iᵀ` for every row `i` into an `(n * k) × p` matrix.
///
/// `jb` is laid out `(n, p, k)` and `h_full` `(n, k, k)`. Output row
/// `i * k + c`, column `a` holds `Σ_{c'} sqrt(H_i)[c, c'] · J[i, a, c']`, so the
/// Gram of the result equals `Σ_i J_i H_i J_iᵀ` for symmetric PSD `H_i`.
///
/// # Panics
/// If `h_full` is not `(n, k, k)` for the `n`, `k` taken from `jb`.
fn scale_block_by_sqrt_h(jb: &Array3<f64>, h_full: &Array3<f64>) -> Array2<f64> {
    let (n, p, k) = jb.dim();
    assert_eq!(h_full.shape(), &[n, k, k]);
    let mut stacked = Array2::<f64>::zeros((n * k, p));
    let mut root = Array2::<f64>::zeros((k, k));
    for row in 0..n {
        let h_row = h_full.index_axis(Axis(0), row).to_owned();
        root.fill(0.0);
        symmetric_sqrt_into(&h_row, &mut root);
        // (p, k) view of this row's Jacobian.
        let j_row = jb.index_axis(Axis(0), row);
        for c in 0..k {
            let out_row = row * k + c;
            for a in 0..p {
                stacked[[out_row, a]] =
                    (0..k).fold(0.0, |sum, cp| sum + root[[c, cp]] * j_row[[a, cp]]);
            }
        }
    }
    stacked
}
/// Writes the symmetric positive-semidefinite square root of `m` into `out`.
///
/// `m` is eigendecomposed (`Side::Lower`) as `U Λ Uᵀ` and
/// `out = U · sqrt(max(Λ, 0)) · Uᵀ`, so negative eigenvalues are clamped to 0.
/// A `1 × 1` input takes a scalar fast path. If the eigendecomposition fails,
/// `out` falls back to a diagonal matrix holding the square roots of the
/// clamped diagonal of `m`.
///
/// # Panics
/// If `m` is not square or `out` does not have the same `k × k` shape.
fn symmetric_sqrt_into(m: &Array2<f64>, out: &mut Array2<f64>) {
    let k = m.nrows();
    assert_eq!(m.ncols(), k);
    assert_eq!(out.shape(), &[k, k]);
    if k == 1 {
        out[[0, 0]] = m[[0, 0]].max(0.0).sqrt();
        return;
    }
    let (evals, evecs) = match m.eigh(Side::Lower) {
        Ok(pair) => pair,
        Err(_) => {
            out.fill(0.0);
            for i in 0..k {
                out[[i, i]] = m[[i, i]].max(0.0).sqrt();
            }
            return;
        }
    };
    // Square roots of the clamped eigenvalues, computed once instead of once
    // per (i, j) entry. The previous version also built the same product via
    // `fast_atb` on two transposed copies and then discarded it with
    // `out.fill(0.0)`; that dead work is gone.
    let roots: Vec<f64> = evals.iter().map(|&lam| lam.max(0.0).sqrt()).collect();
    for i in 0..k {
        for j in 0..k {
            let mut acc = 0.0;
            for l in 0..k {
                acc += evecs[[i, l]] * roots[l] * evecs[[j, l]];
            }
            out[[i, j]] = acc;
        }
    }
}
/// Residualises the columns of `b_scaled` against the span of `a_scaled`.
///
/// Both inputs are already metric-scaled, so this is ordinary least squares:
/// `M = (AᵀA)^+ AᵀB` (pseudo-inverse via `solve_psd_system`) and the residual
/// is `B − A M`. Returns `(residual, Some(M))`, or `(B.clone(), None)` when
/// `A` has no columns.
///
/// # Errors
/// Propagates `LinalgFailure` from `solve_psd_system`.
fn residualise_in_metric(
    a_scaled: &Array2<f64>,
    b_scaled: &Array2<f64>,
) -> Result<(Array2<f64>, Option<Array2<f64>>), CompilerError> {
    if a_scaled.ncols() == 0 {
        return Ok((b_scaled.clone(), None));
    }
    let coeffs = solve_psd_system(
        &fast_atb(a_scaled, a_scaled),
        &fast_atb(a_scaled, b_scaled),
    )?;
    let fitted = fast_ab(a_scaled, &coeffs);
    Ok((b_scaled - &fitted, Some(coeffs)))
}
fn solve_psd_system(g: &Array2<f64>, r: &Array2<f64>) -> Result<Array2<f64>, CompilerError> {
let n = g.nrows();
if n == 0 {
return Ok(Array2::zeros((0, r.ncols())));
}
let (evals, evecs) = g
.eigh(Side::Lower)
.map_err(|err| CompilerError::LinalgFailure(format!("Gram eigh failed: {err:?}")))?;
let lambda_max = evals.iter().cloned().fold(0.0_f64, f64::max).max(0.0);
let tol = lambda_max * 64.0 * (n.max(1) as f64) * f64::EPSILON;
let u_t_r = fast_atb(&evecs, r);
let mut scaled = u_t_r.clone();
for i in 0..n {
let lam = evals[i];
let inv = if lam > tol { 1.0 / lam } else { 0.0 };
for j in 0..scaled.ncols() {
scaled[[i, j]] *= inv;
}
}
let m = fast_ab(&evecs, &scaled);
Ok(m)
}
/// Returns an orthonormal basis of the numerically positive eigenspace of `g_tilde`.
///
/// The threshold is `max(λ_max, g_bb_trace) · 64 · max(n·k, p, 1) · ε`, where
/// `g_bb_trace` is the caller-supplied scale reference. Kept eigenvectors form
/// the columns of the result (`p × kept`), ordered by decreasing eigenvalue.
/// A `0 × 0` input yields a `0 × 0` result.
///
/// # Errors
/// `LinalgFailure` if the eigendecomposition fails.
fn keep_positive_eigenspace(
    g_tilde: &Array2<f64>,
    n: usize,
    k: usize,
    g_bb_trace: f64,
) -> Result<Array2<f64>, CompilerError> {
    let dim = g_tilde.nrows();
    if dim == 0 {
        return Ok(Array2::zeros((0, 0)));
    }
    let (evals, evecs) = g_tilde.eigh(Side::Lower).map_err(|err| {
        CompilerError::LinalgFailure(format!("residual Gram eigh failed: {err:?}"))
    })?;
    let top = evals
        .iter()
        .fold(0.0_f64, |acc, &lam| acc.max(lam))
        .max(0.0);
    let reference = top.max(g_bb_trace);
    let effective_rows = n.saturating_mul(k).max(dim).max(1) as f64;
    let threshold = reference * 64.0 * effective_rows * f64::EPSILON;
    let mut order: Vec<usize> = (0..dim).filter(|&i| evals[i] > threshold).collect();
    // Largest eigenvalue first; stable sort keeps ties in original order.
    order.sort_by(|&a, &b| {
        evals[b]
            .partial_cmp(&evals[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    let basis = Array2::from_shape_fn((dim, order.len()), |(row, col)| evecs[[row, order[col]]]);
    Ok(basis)
}
/// Horizontally concatenates `left` and `right` into a new matrix.
///
/// The row count is the larger of the two inputs. Callers in this module
/// always pass equal row counts; unequal, non-broadcastable inputs panic in
/// `assign`.
fn concat_cols(left: &Array2<f64>, right: &Array2<f64>) -> Array2<f64> {
    let rows = std::cmp::max(left.nrows(), right.nrows());
    let split = left.ncols();
    let mut joined = Array2::<f64>::zeros((rows, split + right.ncols()));
    if split != 0 {
        joined.slice_mut(s![.., ..split]).assign(left);
    }
    if right.ncols() != 0 {
        joined.slice_mut(s![.., split..]).assign(right);
    }
    joined
}
/// Rank-audits the joint curvature anchor and trims the last block if deficient.
///
/// Runs RRQR on `w_joint`. If the detected rank is below the total number of
/// compiled columns, the shortfall is removed from the *trailing* columns of
/// the last block only. Its `t_lw`, `anchor_correction` and `r_lw` are
/// truncated together. Returns the dropped `(block index, local column)` pairs.
///
/// NOTE(review): if the shortfall exceeds the last block's width, the block is
/// truncated to zero columns and the remainder is not reported.
///
/// # Errors
/// `LinalgFailure` if the RRQR fails.
fn audit_and_drop_trailing_pivots(
    w_joint: &Array2<f64>,
    compiled: &mut [CompiledBlock],
) -> Result<Vec<(usize, usize)>, CompilerError> {
    let p_total: usize = compiled.iter().map(|b| b.t_lw.ncols()).sum();
    if p_total == 0 || w_joint.nrows() == 0 {
        return Ok(Vec::new());
    }
    let rank = rrqr_with_permutation(w_joint, default_rrqr_rank_alpha())
        .map_err(|err| CompilerError::LinalgFailure(format!("audit RRQR failed: {err:?}")))?
        .rank;
    // Nothing to drop unless the audit found strictly fewer pivots than columns.
    let Some(drop_count) = p_total.checked_sub(rank).filter(|&d| d > 0) else {
        return Ok(Vec::new());
    };
    let latest_idx = compiled.len() - 1;
    let latest = &mut compiled[latest_idx];
    let width = latest.t_lw.ncols();
    let keep = width.saturating_sub(drop_count);
    let truncate = |m: &Array2<f64>| m.slice(s![.., ..keep]).to_owned();
    latest.t_lw = truncate(&latest.t_lw);
    latest.anchor_correction = latest.anchor_correction.as_ref().map(|m| truncate(m));
    latest.r_lw = latest.r_lw.as_ref().map(|m| truncate(m));
    Ok((keep..width).map(|c| (latest_idx, c)).collect())
}
/// Per-block, per-channel raw design matrices.
///
/// `blocks[b][c]` is block `b`'s design for channel `c`;
/// `build_raw_grams_from_channel_blocks` requires each `Some` entry to be
/// `n × p_b`. `None` means the channel has no entry for that block and is
/// skipped when forming Grams.
pub struct PrimaryChannelBlocks {
    pub blocks: Vec<Vec<Option<Array2<f64>>>>,
}
/// Builds the raw curvature Gram over all blocks from per-channel designs.
///
/// The tile at `raw_block_ranges[a] × raw_block_ranges[b]` accumulates
/// `fast_xt_diag_y(X_{a,c}, H[·, c, d], X_{b,d})` over every channel pair
/// `(c, d)` whose slots are both `Some`. Only the upper block triangle is
/// computed; the lower triangle is mirrored from it.
///
/// # Errors
/// `DimensionMismatch` when:
/// - the block and range counts differ;
/// - the ranges are not contiguous from 0, or a range is reversed;
/// - a block does not carry exactly `k` channel slots;
/// - a channel matrix is not `n × p_b`;
/// - `row_hess.evaluate_full()` is not `(n, k, k)`.
pub fn build_raw_grams_from_channel_blocks(
    channel_blocks: &PrimaryChannelBlocks,
    row_hess: &dyn RowHessian,
    raw_block_ranges: &[std::ops::Range<usize>],
) -> Result<Array2<f64>, CompilerError> {
    let num_blocks = channel_blocks.blocks.len();
    if num_blocks != raw_block_ranges.len() {
        return Err(CompilerError::DimensionMismatch(format!(
            "channel_blocks ({num_blocks}) and raw_block_ranges ({}) length mismatch",
            raw_block_ranges.len()
        )));
    }
    if num_blocks == 0 {
        return Ok(Array2::<f64>::zeros((0, 0)));
    }
    // Check true contiguity range by range. Comparing the summed widths to the
    // last `end` is not enough: `[0..2, 0..1, 3..4]` passes that test but
    // overlaps and leaves a gap. Checking before any `end - start` also turns a
    // reversed range into an error instead of an underflow panic.
    let mut expected_start = 0usize;
    for (b, r) in raw_block_ranges.iter().enumerate() {
        if r.start != expected_start {
            return Err(CompilerError::DimensionMismatch(format!(
                "raw_block_ranges must be contiguous from 0; block {b} starts at {} expected {expected_start}",
                r.start
            )));
        }
        if r.end < r.start {
            return Err(CompilerError::DimensionMismatch(format!(
                "raw_block_ranges block {b} is reversed: {}..{}",
                r.start, r.end
            )));
        }
        expected_start = r.end;
    }
    let p_total = expected_start;
    let k = row_hess.k();
    let n = row_hess.nrows();
    for (b, slots) in channel_blocks.blocks.iter().enumerate() {
        if slots.len() != k {
            return Err(CompilerError::DimensionMismatch(format!(
                "block {b}: expected {k} channel slots, got {}",
                slots.len()
            )));
        }
        let p_b = raw_block_ranges[b].end - raw_block_ranges[b].start;
        for (c, mat) in slots.iter().enumerate() {
            if let Some(x) = mat.as_ref() {
                if x.nrows() != n {
                    return Err(CompilerError::DimensionMismatch(format!(
                        "block {b} channel {c}: nrows={} but row Hessian nrows={n}",
                        x.nrows()
                    )));
                }
                if x.ncols() != p_b {
                    return Err(CompilerError::DimensionMismatch(format!(
                        "block {b} channel {c}: ncols={} but block width={p_b}",
                        x.ncols()
                    )));
                }
            }
        }
    }
    let h_full = row_hess.evaluate_full();
    if h_full.shape() != &[n, k, k] {
        return Err(CompilerError::DimensionMismatch(format!(
            "row Hessian evaluate_full shape {:?} != [n={n}, k={k}, k={k}]",
            h_full.shape()
        )));
    }
    // Per-row weight vectors H[·, c, d], indexed as c * k + d.
    let mut h_pairs: Vec<Array1<f64>> = Vec::with_capacity(k * k);
    for c in 0..k {
        for d in 0..k {
            h_pairs.push(h_full.slice(s![.., c, d]).to_owned());
        }
    }
    let mut gram = Array2::<f64>::zeros((p_total, p_total));
    for a in 0..num_blocks {
        let range_a = raw_block_ranges[a].clone();
        for b in a..num_blocks {
            let range_b = raw_block_ranges[b].clone();
            let mut block_acc =
                Array2::<f64>::zeros((range_a.end - range_a.start, range_b.end - range_b.start));
            for c in 0..k {
                let Some(x_a_c) = channel_blocks.blocks[a][c].as_ref() else {
                    continue;
                };
                for d in 0..k {
                    let Some(x_b_d) = channel_blocks.blocks[b][d].as_ref() else {
                        continue;
                    };
                    let h_cd = &h_pairs[c * k + d];
                    let contrib = fast_xt_diag_y(x_a_c, h_cd, x_b_d);
                    block_acc += &contrib;
                }
            }
            gram.slice_mut(s![range_a.start..range_a.end, range_b.start..range_b.end])
                .assign(&block_acc);
        }
    }
    // Mirror the computed upper block triangle into the lower triangle.
    for i in 0..p_total {
        for j in 0..i {
            let v = gram[[j, i]];
            gram[[i, j]] = v;
        }
    }
    Ok(gram)
}
/// Builds the raw structural Gram: for every block pair `(a, b)` the tile is
/// `Σ_c X_{a,c}ᵀ X_{b,c}`, i.e. only matching channels are paired (an identity
/// per-row metric).
///
/// Slots that are `None` on either side are skipped. Only the upper block
/// triangle is computed; the lower triangle is mirrored from it.
///
/// # Panics
/// If the block and range counts differ, or two blocks carry a different
/// number of channel slots.
pub fn build_raw_grams_structural(
    channel_blocks: &PrimaryChannelBlocks,
    raw_block_ranges: &[std::ops::Range<usize>],
) -> Array2<f64> {
    let num_blocks = channel_blocks.blocks.len();
    assert_eq!(
        num_blocks,
        raw_block_ranges.len(),
        "channel_blocks ({num_blocks}) and raw_block_ranges ({}) length mismatch",
        raw_block_ranges.len()
    );
    let Some(last_range) = raw_block_ranges.last() else {
        return Array2::<f64>::zeros((0, 0));
    };
    let p_total = last_range.end;
    let mut gram = Array2::<f64>::zeros((p_total, p_total));
    for (a, range_a) in raw_block_ranges.iter().enumerate() {
        let slots_a = &channel_blocks.blocks[a];
        for (b, range_b) in raw_block_ranges.iter().enumerate().skip(a) {
            let slots_b = &channel_blocks.blocks[b];
            let width_a = range_a.end - range_a.start;
            let width_b = range_b.end - range_b.start;
            let k_a = slots_a.len();
            let k_b = slots_b.len();
            assert_eq!(
                k_a, k_b,
                "structural Gram: block {a} has {k_a} channels but block {b} has {k_b}",
            );
            let mut tile = Array2::<f64>::zeros((width_a, width_b));
            for (slot_a, slot_b) in slots_a.iter().zip(slots_b) {
                let (Some(x_a), Some(x_b)) = (slot_a.as_ref(), slot_b.as_ref()) else {
                    continue;
                };
                // Diagonal tiles use the dedicated XᵀX kernel.
                let contrib = if a == b {
                    fast_ata(x_a)
                } else {
                    fast_atb(x_a, x_b)
                };
                tile += &contrib;
            }
            gram.slice_mut(s![range_a.start..range_a.end, range_b.start..range_b.end])
                .assign(&tile);
        }
    }
    for i in 0..p_total {
        for j in (i + 1)..p_total {
            let upper = gram[[i, j]];
            gram[[j, i]] = upper;
        }
    }
    gram
}
/// Builds the `(curvature, structural)` raw Gram pair, preferring the CUDA path.
///
/// The GPU path is attempted only when `row_hess.k()` equals the GPU kernel's
/// `CHANNELS` and the row Hessian packs successfully. If the CUDA call returns
/// `None`, or either precondition fails, the CPU builders are used instead.
/// Each outcome logs which path produced the Grams.
///
/// # Errors
/// On the CPU path, propagates `DimensionMismatch` from
/// `build_raw_grams_from_channel_blocks`.
pub fn build_primary_grams_gpu_or_cpu(
    channel_blocks: &PrimaryChannelBlocks,
    row_hess: &dyn RowHessian,
    raw_block_ranges: &[std::ops::Range<usize>],
) -> Result<(Array2<f64>, Array2<f64>), CompilerError> {
    if row_hess.k() == crate::gpu::identifiability_compile::CHANNELS {
        if let Some(h_packed) = pack_row_hessian_symmetric(row_hess) {
            // The CUDA entry point borrows the blocks. Pass the caller's storage
            // directly: it has the same type the old deep copy had, and this avoids
            // cloning every design matrix, even when the GPU path is unavailable.
            if let Some(bundle) = crate::gpu::identifiability_compile::try_primary_state_gram_cuda(
                &channel_blocks.blocks,
                &h_packed,
                raw_block_ranges,
            ) {
                log::info!("[identifiability_compile] gram path = gpu");
                return Ok((bundle.gram_h, bundle.gram_struct));
            }
        }
    }
    log::info!("[identifiability_compile] gram path = cpu");
    let gram_h = build_raw_grams_from_channel_blocks(channel_blocks, row_hess, raw_block_ranges)?;
    let gram_struct = build_raw_grams_structural(channel_blocks, raw_block_ranges);
    Ok((gram_h, gram_struct))
}
/// Packs the upper triangle (`c <= d`) of each row's `CHANNELS × CHANNELS`
/// Hessian into one row of an `n × PACKED_LEN` matrix, using `packed_index`.
///
/// Returns `None` when the channel count differs from `CHANNELS` or when
/// `evaluate_full()` does not have the expected `(n, CHANNELS, CHANNELS)` shape.
fn pack_row_hessian_symmetric(row_hess: &dyn RowHessian) -> Option<Array2<f64>> {
    use crate::gpu::identifiability_compile::{CHANNELS, PACKED_LEN, packed_index};
    if row_hess.k() != CHANNELS {
        return None;
    }
    let n = row_hess.nrows();
    let h_full = row_hess.evaluate_full();
    if h_full.shape() != [n, CHANNELS, CHANNELS] {
        return None;
    }
    let mut packed = Array2::<f64>::zeros((n, PACKED_LEN));
    for (i, h_row) in h_full.outer_iter().enumerate() {
        for c in 0..CHANNELS {
            for d in c..CHANNELS {
                packed[[i, packed_index(c, d)]] = h_row[[c, d]];
            }
        }
    }
    Some(packed)
}
/// Result of `compile_from_raw_grams`.
#[derive(Debug)]
pub struct CompiledMap {
    /// `p_raw × p_compiled` basis mapping compiled coefficients to raw ones
    /// (raw = `raw_from_compiled · compiled`).
    pub raw_from_compiled: Array2<f64>,
    /// Column range of each block within `raw_from_compiled`, in block order.
    pub compiled_block_ranges: Vec<std::ops::Range<usize>>,
    /// Copy of the raw block ranges that were compiled.
    pub raw_block_ranges: Vec<std::ops::Range<usize>>,
}
/// Compiles raw-coordinate Grams into a block-triangular change of basis.
///
/// Works on `p_raw × p_raw` Grams (`gram_h` for curvature, `gram_struct` for
/// the structural metric). For each block, in order:
/// 1. The block's raw directions `E_b` are residualised against the cumulative
///    compiled basis `T` in the structural metric (Schur complement). The
///    positive eigenspace `Q_+` of the residual Gram is kept, giving
///    `D = (E_b − T R_s) Q_+`.
/// 2. `D` is residualised against `T` in the curvature metric and restricted
///    to the positive eigenspace `U` of that residual Gram, giving the block's
///    compiled columns `(D − T R_h) U`, which are appended to `T`.
///
/// Thresholds are relative to the raw (pre-residual) Gram traces. `ordering`
/// is only length-checked here.
///
/// # Errors
/// - `DimensionMismatch` for length/shape mismatches, non-contiguous or
///   reversed `raw_block_ranges`.
/// - `FullyAliased` if a block keeps no direction in either stage.
/// - `LinalgFailure` on eigendecomposition failure or a non-finite result.
pub fn compile_from_raw_grams(
    gram_h: &Array2<f64>,
    gram_struct: &Array2<f64>,
    raw_block_ranges: &[std::ops::Range<usize>],
    ordering: &[BlockOrder],
) -> Result<CompiledMap, CompilerError> {
    if raw_block_ranges.len() != ordering.len() {
        return Err(CompilerError::DimensionMismatch(format!(
            "raw_block_ranges ({}) and ordering ({}) length mismatch",
            raw_block_ranges.len(),
            ordering.len()
        )));
    }
    let p_raw = raw_block_ranges.last().map(|r| r.end).unwrap_or(0);
    if gram_h.shape() != [p_raw, p_raw] {
        return Err(CompilerError::DimensionMismatch(format!(
            "gram_h shape {:?} != [p_raw={p_raw}, p_raw={p_raw}]",
            gram_h.shape()
        )));
    }
    if gram_struct.shape() != [p_raw, p_raw] {
        return Err(CompilerError::DimensionMismatch(format!(
            "gram_struct shape {:?} != [p_raw={p_raw}, p_raw={p_raw}]",
            gram_struct.shape()
        )));
    }
    if raw_block_ranges.is_empty() {
        return Ok(CompiledMap {
            raw_from_compiled: Array2::<f64>::zeros((0, 0)),
            compiled_block_ranges: Vec::new(),
            raw_block_ranges: Vec::new(),
        });
    }
    let mut expected_start = 0usize;
    for (b, r) in raw_block_ranges.iter().enumerate() {
        if r.start != expected_start {
            return Err(CompilerError::DimensionMismatch(format!(
                "raw_block_ranges must be contiguous from 0; block {b} starts at {} expected {expected_start}",
                r.start
            )));
        }
        // A reversed range passes the start check but would underflow
        // `end - start` in the main loop; report it as an error instead.
        if r.end < r.start {
            return Err(CompilerError::DimensionMismatch(format!(
                "raw_block_ranges block {b} is reversed: {}..{}",
                r.start, r.end
            )));
        }
        expected_start = r.end;
    }
    let mut t_cum: Array2<f64> = Array2::<f64>::zeros((p_raw, 0));
    let mut compiled_block_ranges: Vec<std::ops::Range<usize>> =
        Vec::with_capacity(raw_block_ranges.len());
    for (idx, range_b) in raw_block_ranges.iter().enumerate() {
        let p_b = range_b.end - range_b.start;

        // Structural stage: Schur-complement residual of the block against T in K^S.
        let ks_t = fast_ab(gram_struct, &t_cum);
        let g_s_aa = fast_atb(&t_cum, &ks_t);
        let ks_pb = gram_struct
            .slice(s![.., range_b.start..range_b.end])
            .to_owned();
        let g_s_ab = fast_atb(&t_cum, &ks_pb);
        let g_s_bb = gram_struct
            .slice(s![range_b.start..range_b.end, range_b.start..range_b.end])
            .to_owned();
        let r_s = solve_psd_system(&g_s_aa, &g_s_ab)?;
        let g_s_res_raw = &g_s_bb - &fast_atb(&g_s_ab, &r_s);
        let g_s_res = symmetrise(&g_s_res_raw);
        let g_s_bb_trace: f64 = (0..p_b).map(|i| g_s_bb[[i, i]].max(0.0)).sum();
        let q_plus = keep_positive_eigenspace(&g_s_res, p_raw, 1, g_s_bb_trace)?;
        if q_plus.ncols() == 0 {
            return Err(CompilerError::FullyAliased {
                block_idx: idx,
                reason: format!(
                    "structural residual Gram has no positive eigenspace (block of width {p_b} fully aliased by cumulative anchor in K^S)"
                ),
            });
        }
        // Structurally residualised directions in raw coordinates:
        // diff = E_b − T R_s, where E_b selects the block's raw columns.
        let mut diff = if t_cum.ncols() > 0 {
            fast_ab(&t_cum, &r_s).mapv(|v| -v)
        } else {
            Array2::<f64>::zeros((p_raw, p_b))
        };
        for j in 0..p_b {
            diff[[range_b.start + j, j]] += 1.0;
        }
        let d_mat = fast_ab(&diff, &q_plus);

        // Curvature stage: residualise D against T in K^H.
        let kh_t = fast_ab(gram_h, &t_cum);
        let g_h_aa = fast_atb(&t_cum, &kh_t);
        let kh_d = fast_ab(gram_h, &d_mat);
        let g_h_ad = fast_atb(&t_cum, &kh_d);
        let r_h = solve_psd_system(&g_h_aa, &g_h_ad)?;
        let d_t_kh_d = fast_atb(&d_mat, &kh_d);
        let g_h_res_raw = &d_t_kh_d - &fast_atb(&g_h_ad, &r_h);
        let g_h_res = symmetrise(&g_h_res_raw);
        let k_kept = q_plus.ncols();
        let g_h_dd_trace: f64 = (0..k_kept).map(|i| d_t_kh_d[[i, i]].max(0.0)).sum();
        let u_mat = keep_positive_eigenspace(&g_h_res, p_raw, 1, g_h_dd_trace)?;
        if u_mat.ncols() == 0 {
            return Err(CompilerError::FullyAliased {
                block_idx: idx,
                reason: format!(
                    "curvature residual Gram has no positive eigenspace within structurally-kept basis (block of width {p_b}, structural-kept {k_kept})"
                ),
            });
        }
        let mut e_mat = d_mat.clone();
        if t_cum.ncols() > 0 {
            let t_rh = fast_ab(&t_cum, &r_h);
            e_mat = &e_mat - &t_rh;
        }
        let t_b = fast_ab(&e_mat, &u_mat);
        let start = t_cum.ncols();
        let end = start + t_b.ncols();
        compiled_block_ranges.push(start..end);
        t_cum = concat_cols(&t_cum, &t_b);
    }
    if t_cum.iter().any(|v| !v.is_finite()) {
        return Err(CompilerError::LinalgFailure(
            "compile_from_raw_grams produced non-finite entry in raw_from_compiled".to_string(),
        ));
    }
    Ok(CompiledMap {
        raw_from_compiled: t_cum,
        compiled_block_ranges,
        raw_block_ranges: raw_block_ranges.to_vec(),
    })
}
/// Returns the symmetric part `(M + Mᵀ) / 2` of a square matrix.
///
/// Used to remove round-off asymmetry from Schur-complement Grams before
/// eigendecomposition.
///
/// # Panics
/// If `m` is not square.
fn symmetrise(m: &Array2<f64>) -> Array2<f64> {
    let (rows, cols) = m.dim();
    assert_eq!(rows, cols, "symmetrise expects square matrix");
    Array2::from_shape_fn((rows, cols), |(i, j)| 0.5 * (m[[i, j]] + m[[j, i]]))
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{Array1, Array2};
struct DenseScalarOperator {
design: Array2<f64>,
}
impl DenseScalarOperator {
fn new(design: Array2<f64>) -> Self {
Self { design }
}
}
impl RowJacobianOperator for DenseScalarOperator {
fn k(&self) -> usize {
1
}
fn ncols(&self) -> usize {
self.design.ncols()
}
fn nrows(&self) -> usize {
self.design.nrows()
}
fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
assert_eq!(out.len(), 1);
let mut acc = 0.0;
for (j, &b) in delta_beta.iter().enumerate() {
acc += self.design[[row, j]] * b;
}
out[0] = acc;
}
fn evaluate_full(&self) -> Array3<f64> {
let n = self.design.nrows();
let p = self.design.ncols();
let mut out = Array3::<f64>::zeros((n, p, 1));
for i in 0..n {
for j in 0..p {
out[[i, j, 0]] = self.design[[i, j]];
}
}
out
}
}
struct DiagonalScalarRowHessian {
w: Array1<f64>,
}
impl DiagonalScalarRowHessian {
fn new(w: Array1<f64>) -> Self {
Self { w }
}
}
impl RowHessian for DiagonalScalarRowHessian {
fn k(&self) -> usize {
1
}
fn nrows(&self) -> usize {
self.w.len()
}
fn fill_row(&self, row: usize, out: &mut [f64]) {
assert_eq!(out.len(), 1);
out[0] = self.w[row];
}
fn evaluate_full(&self) -> Array3<f64> {
let n = self.w.len();
let mut out = Array3::<f64>::zeros((n, 1, 1));
for i in 0..n {
out[[i, 0, 0]] = self.w[i];
}
out
}
}
fn op(design: Array2<f64>) -> Arc<dyn RowJacobianOperator> {
Arc::new(DenseScalarOperator::new(design))
}
#[test]
fn compile_two_block_orthogonalises_under_metric() {
let n = 50;
let a = Array2::from_shape_fn((n, 3), |(i, j)| ((i + 1) as f64).sin().powi((j + 1) as i32));
let b = Array2::from_shape_fn((n, 2), |(i, j)| {
0.5 * a[[i, 0]] + ((i as f64) * 0.13 + j as f64).cos()
});
let hess = IdentityRowHessian::new(n, 1);
let ops = vec![op(a.clone()), op(b.clone())];
let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
.expect("compile should succeed");
let v_b = &compiled.blocks[1].t_lw;
let m_b = compiled.blocks[1]
.anchor_correction
.as_ref()
.expect("second block must carry an anchor correction");
let b_v = b.dot(v_b);
let a_m = a.dot(m_b);
let b_compiled = &b_v - &a_m;
let cross = a.t().dot(&b_compiled);
let max_err = cross.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
assert!(
max_err < 1e-10,
"orthogonality residual too large: {max_err:e}"
);
}
#[test]
fn compile_three_block_chain() {
let n = 80;
let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.1 + j as f64).sin());
let b = Array2::from_shape_fn((n, 2), |(i, j)| {
0.3 * a[[i, 0]] + (j as f64) * (i as f64).cos()
});
let c = Array2::from_shape_fn((n, 2), |(i, j)| {
0.2 * a[[i, 1]] + 0.4 * b[[i, 0]] + ((i + j) as f64).tan().min(5.0).max(-5.0)
});
let hess = IdentityRowHessian::new(n, 1);
let ops = vec![op(a), op(b), op(c)];
let compiled = compile(
&ops,
&hess,
&[
BlockOrder::Marginal,
BlockOrder::Logslope,
BlockOrder::LinkDev,
],
)
.expect("compile should succeed");
let total: usize = compiled.blocks.iter().map(|b| b.t_lw.ncols()).sum();
assert_eq!(
compiled.joint_rank, total,
"audit must report full rank on synthetic full-rank design"
);
}
#[test]
fn compile_weighted_metric_nontrivial() {
let n = 32;
let a: Array2<f64> = Array2::from_shape_fn((n, 1), |(i, _)| (i as f64 + 1.0).sqrt());
let b: Array2<f64> =
Array2::from_shape_fn((n, 1), |(i, _)| 0.7 * a[[i, 0]] + (i as f64 * 0.05).cos());
let w = Array1::from_shape_fn(n, |i| 0.5 + (i as f64 * 0.2).sin().abs());
let hess = DiagonalScalarRowHessian::new(w.clone());
let ops = vec![op(a.clone()), op(b.clone())];
let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
.expect("compile should succeed");
let m = compiled.blocks[1]
.anchor_correction
.as_ref()
.expect("anchor correction present");
let analytic_num: f64 = (0..n).map(|i| w[i] * a[[i, 0]] * b[[i, 0]]).sum();
let analytic_den: f64 = (0..n).map(|i| w[i] * a[[i, 0]] * a[[i, 0]]).sum();
let analytic = analytic_num / analytic_den;
assert!(m.dim() == (1, 1));
assert!(
(m[[0, 0]] - analytic).abs() < 1e-10,
"weighted projection mismatch: got {got}, analytic {analytic}",
got = m[[0, 0]]
);
}
#[test]
fn compile_drops_trailing_pivots_from_latest_block() {
let n = 40;
let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 + 1.0).ln() * (j as f64 + 1.0));
let c = Array2::from_shape_fn((n, 2), |(i, j)| {
if j == 0 {
a[[i, 0]]
} else {
(i as f64 * 0.1).cos()
}
});
let hess = IdentityRowHessian::new(n, 1);
let ops = vec![op(a), op(c)];
let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
.expect("compile should succeed");
let v1_cols = compiled.blocks[1].t_lw.ncols();
assert!(
v1_cols < 2 || !compiled.dropped.is_empty(),
"expected rank loss attributed to block 1, got v1_cols={v1_cols}, dropped={dropped:?}",
dropped = compiled.dropped
);
for (block_idx, _) in &compiled.dropped {
assert_eq!(
*block_idx, 1,
"audit drops must come from the latest block only"
);
}
}
#[test]
fn audit_truncation_keeps_t_lw_and_anchor_correction_in_lockstep() {
let n = 40;
let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 + 1.0).ln() * (j as f64 + 1.0));
let c = Array2::from_shape_fn((n, 2), |(i, j)| {
if j == 0 {
a[[i, 0]]
} else {
(i as f64 * 0.1).cos()
}
});
let hess = IdentityRowHessian::new(n, 1);
let ops = vec![op(a), op(c)];
let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
.expect("compile should succeed");
for (idx, block) in compiled.blocks.iter().enumerate() {
let k_kept = block.t_lw.ncols();
if let Some(m) = block.anchor_correction.as_ref() {
assert_eq!(
m.ncols(),
k_kept,
"block {idx}: anchor_correction.ncols()={ac} must equal t_lw.ncols()={k_kept} \
after audit truncation",
ac = m.ncols(),
);
}
if let Some(r) = block.r_lw.as_ref() {
assert_eq!(
r.ncols(),
k_kept,
"block {idx}: r_lw.ncols()={r_cols} must equal t_lw.ncols()={k_kept} \
after audit truncation",
r_cols = r.ncols(),
);
}
}
}
struct MockAnchorEvaluator {
rows: Array2<f64>,
}
impl AnchorRowEvaluator for MockAnchorEvaluator {
fn anchor_rows(&self, predict_arg: &Array1<f64>) -> Result<Array2<f64>, String> {
assert_eq!(
predict_arg.len(),
self.rows.nrows(),
"MockAnchorEvaluator: predict_arg length {} must match stored rows {}",
predict_arg.len(),
self.rows.nrows(),
);
Ok(self.rows.clone())
}
fn ncols(&self) -> usize {
self.rows.ncols()
}
}
#[test]
fn compile_flex_anchor_is_first_class() {
let n = 60;
let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.07 + j as f64).sin());
let b = Array2::from_shape_fn((n, 2), |(i, j)| {
0.4 * a[[i, 0]] + (j as f64) * (i as f64 + 1.0).ln()
});
let hess = IdentityRowHessian::new(n, 1);
let ops_param = vec![op(a.clone()), op(b.clone())];
let compiled_param = compile(
&ops_param,
&hess,
&[BlockOrder::Marginal, BlockOrder::Logslope],
)
.expect("compile should succeed");
let _flex_eval: Arc<dyn AnchorRowEvaluator> =
Arc::new(MockAnchorEvaluator { rows: a.clone() });
let ops_flex = vec![op(a.clone()), op(b.clone())];
let compiled_flex = compile(
&ops_flex,
&hess,
&[BlockOrder::ScoreWarp, BlockOrder::LinkDev],
)
.expect("compile should succeed");
let m_param = compiled_param.blocks[1].anchor_correction.as_ref().unwrap();
let m_flex = compiled_flex.blocks[1].anchor_correction.as_ref().unwrap();
assert_eq!(m_param.dim(), m_flex.dim());
let max_diff = (m_param - m_flex)
.iter()
.fold(0.0_f64, |acc, &v| acc.max(v.abs()));
assert!(
max_diff < 1e-12,
"flex vs parametric anchor correction mismatch: {max_diff:e}"
);
}
#[test]
fn bernoulli_row_hessian_matches_irls_weight() {
let w = Array1::from(vec![0.1, 0.5, 0.9, 0.25, 0.75]);
let hess = DiagonalScalarRowHessian::new(w.clone());
let full = hess.evaluate_full();
assert_eq!(full.shape(), &[5, 1, 1]);
for i in 0..5 {
assert_eq!(full[[i, 0, 0]], w[i]);
let mut buf = [0.0_f64; 1];
hess.fill_row(i, &mut buf);
assert_eq!(buf[0], w[i]);
}
}
#[test]
fn compiler_predict_path_roundtrip() {
let n = 24;
let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.21).cos() + j as f64);
let b = Array2::from_shape_fn((n, 2), |(i, j)| {
0.3 * a[[i, 0]] + (i as f64 + j as f64).sqrt()
});
let hess = IdentityRowHessian::new(n, 1);
let ops = vec![op(a.clone()), op(b.clone())];
let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
.expect("compile should succeed");
let v_b = &compiled.blocks[1].t_lw;
let m_b = compiled.blocks[1].anchor_correction.as_ref().unwrap();
let predict_design = b.dot(v_b) - a.dot(m_b);
assert_eq!(predict_design.nrows(), n);
assert_eq!(predict_design.ncols(), v_b.ncols());
for &val in predict_design.iter() {
assert!(val.is_finite(), "predict design produced non-finite entry");
}
}
#[test]
fn compile_exposes_r_lw_equal_to_m_dot_v() {
let n = 40;
let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.17 + j as f64).sin());
let b = Array2::from_shape_fn((n, 2), |(i, j)| {
0.6 * a[[i, 0]] + ((i as f64) * 0.11 + j as f64).cos()
});
let hess = IdentityRowHessian::new(n, 1);
let ops = vec![op(a.clone()), op(b.clone())];
let compiled = compile(&ops, &hess, &[BlockOrder::Marginal, BlockOrder::Logslope])
.expect("compile should succeed");
assert!(compiled.blocks[0].r_lw.is_none());
assert!(compiled.blocks[0].anchor_correction.is_none());
let v_a = &compiled.blocks[0].t_lw;
let v_b = &compiled.blocks[1].t_lw;
let m_compiled = compiled.blocks[1]
.anchor_correction
.as_ref()
.expect("second block must carry an anchor correction");
let r_lw = compiled.blocks[1]
.r_lw
.as_ref()
.expect("second block must expose r_lw");
let p_a_kept = v_a.ncols();
let p_b_kept = v_b.ncols();
assert_eq!(
m_compiled.dim(),
(p_a_kept, p_b_kept),
"anchor_correction must be at compiled width"
);
assert_eq!(r_lw.dim(), (p_a_kept, p_b_kept));
let diff = r_lw - m_compiled;
let max_diff = diff.iter().fold(0.0_f64, |acc, &x| acc.max(x.abs()));
assert!(
max_diff == 0.0,
"r_lw and anchor_correction must be identical"
);
let b_compiled = b.dot(v_b) - a.dot(m_compiled);
let cross = a.t().dot(&b_compiled);
let max_cross = cross.iter().fold(0.0_f64, |acc, &x| acc.max(x.abs()));
assert!(
max_cross < 1e-10,
"compiled B-design must be H-orthogonal to A: max cross = {max_cross:e}"
);
}
struct DenseRowHessian {
h: Array3<f64>,
}
impl RowHessian for DenseRowHessian {
fn k(&self) -> usize {
self.h.shape()[1]
}
fn nrows(&self) -> usize {
self.h.shape()[0]
}
fn fill_row(&self, row: usize, out: &mut [f64]) {
let k = self.k();
assert_eq!(out.len(), k * k);
for c in 0..k {
for d in 0..k {
out[c * k + d] = self.h[[row, c, d]];
}
}
}
fn evaluate_full(&self) -> Array3<f64> {
self.h.clone()
}
}
/// Reference Gram used by the closed-form tests. It forms
/// `W = scale_block_by_sqrt_h(J, H)` from the dense Jacobian and per-row Hessians,
/// then returns `WᵀW`.
fn reference_gram_from_w(j_full: &Array3<f64>, h_full: &Array3<f64>) -> Array2<f64> {
    fast_ata(&scale_block_by_sqrt_h(j_full, h_full))
}
/// The closed-form Gram built from per-channel blocks must match a reference
/// computed from the dense `(n, p_total, k)` Jacobian (via `reference_gram_from_w`),
/// and it must be symmetric.
#[test]
fn closed_form_gram_matches_reference_two_block_k4() {
    let n = 17;
    let k = 4;
    let p_a = 3;
    let p_b = 2;
    // One dense design per channel. The channel count is tied to `k` so the
    // blocks cannot silently disagree with the row Hessian's width.
    let make_block = |seed: f64, n: usize, p: usize| -> Vec<Option<Array2<f64>>> {
        (0..k)
            .map(|c| {
                let m = Array2::from_shape_fn((n, p), |(i, j)| {
                    ((i as f64 + 1.0) * (j as f64 + 1.0) * (c as f64 + 1.0) + seed).sin()
                });
                Some(m)
            })
            .collect()
    };
    let block_a = make_block(0.3, n, p_a);
    let block_b = make_block(1.1, n, p_b);
    // Per-row H = M·Mᵀ + 0.5·I with M a cosine matrix, hence symmetric positive definite.
    let h = Array3::from_shape_fn((n, k, k), |(i, c, d)| {
        let mut acc = 0.0;
        for r in 0..k {
            let mc = ((i + 1) as f64 * (c + 1) as f64 * (r + 1) as f64 * 0.13).cos();
            let md = ((i + 1) as f64 * (d + 1) as f64 * (r + 1) as f64 * 0.13).cos();
            acc += mc * md;
        }
        acc + if c == d { 0.5 } else { 0.0 }
    });
    let row_hess = DenseRowHessian { h: h.clone() };
    let channel_blocks = PrimaryChannelBlocks {
        blocks: vec![block_a.clone(), block_b.clone()],
    };
    let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
    let gram = build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges)
        .expect("closed-form Gram should succeed");
    let p_total = p_a + p_b;
    // Dense Jacobian: block A fills raw columns 0..p_a and block B fills
    // p_a..p_total. Channels without a design stay zero.
    let mut j_full = Array3::<f64>::zeros((n, p_total, k));
    for (c, (chan_a, chan_b)) in block_a.iter().zip(&block_b).enumerate() {
        if let Some(xa) = chan_a {
            j_full.slice_mut(s![.., 0..p_a, c]).assign(xa);
        }
        if let Some(xb) = chan_b {
            j_full.slice_mut(s![.., p_a..p_total, c]).assign(xb);
        }
    }
    let ref_gram = reference_gram_from_w(&j_full, &h);
    let diff = &gram - &ref_gram;
    let max_err = diff.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
    let scale = ref_gram.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
    assert!(
        max_err < 1e-9 * scale.max(1.0),
        "closed-form Gram mismatches reference: max_err={max_err:e}, scale={scale:e}"
    );
    for i in 0..p_total {
        for j in 0..p_total {
            assert!(
                (gram[[i, j]] - gram[[j, i]]).abs() < 1e-12,
                "closed-form Gram not symmetric at ({i},{j})"
            );
        }
    }
}
/// Block A lives only on channel 0 and block B only on channel 3. The A–B cross
/// Gram must therefore equal `Xaᵀ·diag(h_03)·Xb`, and it must vanish once the
/// (0, 3) Hessian entry is zero.
#[test]
fn closed_form_gram_channel_sparsity() {
    let (n, k) = (13, 4);
    let (p_a, p_b) = (2, 2);
    let xa = Array2::from_shape_fn((n, p_a), |(i, j)| ((i + 1) as f64 * 0.21 + j as f64).cos());
    let xb = Array2::from_shape_fn((n, p_b), |(i, j)| {
        ((i + 1) as f64 * 0.17 + j as f64).sin() + 0.5
    });
    let channel_blocks = PrimaryChannelBlocks {
        blocks: vec![
            vec![Some(xa.clone()), None, None, None],
            vec![None, None, None, Some(xb.clone())],
        ],
    };
    let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];

    // Channels 0 and 3 are coupled row-by-row; every other off-diagonal is zero.
    let h_03_vec = Array1::from_shape_fn(n, |i| 0.7 + 0.3 * ((i as f64) * 0.4).sin());
    let h = Array3::from_shape_fn((n, k, k), |(i, c, d)| match (c, d) {
        (0, 3) | (3, 0) => h_03_vec[i],
        _ if c == d => 2.0,
        _ => 0.0,
    });
    let row_hess = DenseRowHessian { h };
    let gram = build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges)
        .expect("closed-form Gram should succeed");
    let cross = gram.slice(s![0..p_a, p_a..(p_a + p_b)]).to_owned();
    let expected = fast_xt_diag_y(&xa, &h_03_vec, &xb);
    let max_err = (&cross - &expected)
        .iter()
        .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
    assert!(
        max_err < 1e-12,
        "cross-block Gram must equal Xaᵀ·diag(h_03)·Xb: max_err={max_err:e}"
    );

    // Same layout, but with no coupling between channels 0 and 3.
    let h_zero = Array3::from_shape_fn((n, k, k), |(_, c, d)| if c == d { 2.0 } else { 0.0 });
    let row_hess_zero = DenseRowHessian { h: h_zero };
    let gram_zero =
        build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess_zero, &raw_ranges)
            .expect("closed-form Gram should succeed");
    let max_zero = gram_zero
        .slice(s![0..p_a, p_a..(p_a + p_b)])
        .iter()
        .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
    assert!(
        max_zero < 1e-12,
        "cross-block Gram must vanish when coupling channel pair is zero: got {max_zero:e}"
    );
}
/// The structural Gram's A–B cross block must equal `Σ_c Xa_cᵀ·Xb_c`, summed over
/// channels present in both blocks (channel 1 is absent here). The full Gram
/// must be symmetric.
#[test]
fn structural_gram_matches_within_channel_sum() {
    let n = 11;
    let p_a = 2;
    let p_b = 3;
    let make_block = |seed: f64, n: usize, p: usize| -> Vec<Option<Array2<f64>>> {
        (0..4)
            .map(|c| {
                // Channel 1 carries no design in either block.
                (c != 1).then(|| {
                    Array2::from_shape_fn((n, p), |(i, j)| {
                        ((i as f64 + 1.0) * (j as f64 + 1.0) + seed * (c as f64 + 1.0)).sin()
                    })
                })
            })
            .collect()
    };
    let block_a = make_block(0.1, n, p_a);
    let block_b = make_block(0.7, n, p_b);
    let channel_blocks = PrimaryChannelBlocks {
        blocks: vec![block_a.clone(), block_b.clone()],
    };
    let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
    let gram = build_raw_grams_structural(&channel_blocks, &raw_ranges);

    // Accumulate channel by channel, in ascending channel order.
    let expected_cross = block_a.iter().zip(&block_b).fold(
        Array2::<f64>::zeros((p_a, p_b)),
        |mut acc, pair| {
            if let (Some(xa), Some(xb)) = pair {
                acc += &fast_atb(xa, xb);
            }
            acc
        },
    );
    let cross = gram.slice(s![0..p_a, p_a..(p_a + p_b)]).to_owned();
    let max_err = (&cross - &expected_cross)
        .iter()
        .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
    assert!(
        max_err < 1e-12,
        "structural cross-block must equal Σ_c Xaᵀ·Xb: max_err={max_err:e}"
    );

    let p_total = p_a + p_b;
    for row in 0..p_total {
        for col in 0..p_total {
            assert!(
                (gram[[row, col]] - gram[[col, row]]).abs() < 1e-12,
                "structural Gram not symmetric at ({row},{col})",
                row = row,
                col = col
            );
        }
    }
}
/// Shorthand for building a `DiagonalScalarRowHessian` from a weight vector.
/// In these tests the vector holds one entry per design row.
fn diag_hess(weights: Array1<f64>) -> DiagonalScalarRowHessian {
    DiagonalScalarRowHessian::new(weights)
}
/// When the structural metric carries the same weights as the curvature metric,
/// the dual-metric compile must reproduce the single-metric compile block by
/// block. That covers the same V, the same anchor correction M, and the same
/// joint rank.
#[test]
fn dual_metric_with_equal_metrics_matches_single_metric() {
    let n = 36;
    let a = Array2::from_shape_fn((n, 2), |(i, j)| (i as f64 * 0.13 + j as f64).sin());
    let b = Array2::from_shape_fn((n, 2), |(i, j)| {
        0.4 * a[[i, 0]] + (i as f64 * 0.07 + j as f64).cos()
    });
    let w = Array1::from_shape_fn(n, |i| 0.5 + (i as f64 * 0.17).sin().abs());
    let curvature = diag_hess(w.clone());
    let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];

    let ops_single = vec![op(a.clone()), op(b.clone())];
    let single = compile(&ops_single, &curvature, &ordering)
        .expect("single-metric compile should succeed");

    // Separate instance with identical weights, used as the structural metric.
    let structural_same = diag_hess(w);
    let ops_dual = vec![op(a), op(b)];
    let dual = compile_with_dual_metric(&ops_dual, &curvature, &structural_same, &ordering)
        .expect("dual-metric compile should succeed");

    // Largest entry-wise absolute difference of two equally-shaped matrices.
    let max_abs_diff = |x: &Array2<f64>, y: &Array2<f64>| {
        x.iter()
            .zip(y.iter())
            .fold(0.0_f64, |acc, (&p, &q)| acc.max((p - q).abs()))
    };

    assert_eq!(single.blocks.len(), dual.blocks.len());
    for (idx, (blk_single, blk_dual)) in single.blocks.iter().zip(&dual.blocks).enumerate() {
        assert_eq!(
            blk_single.t_lw.dim(),
            blk_dual.t_lw.dim(),
            "block {idx}: V dims differ"
        );
        let max_v = max_abs_diff(&blk_single.t_lw, &blk_dual.t_lw);
        assert!(max_v < 1e-10, "block {idx}: V mismatch {max_v:e}");
        match (
            blk_single.anchor_correction.as_ref(),
            blk_dual.anchor_correction.as_ref(),
        ) {
            (None, None) => {}
            (Some(m_single), Some(m_dual)) => {
                assert_eq!(m_single.dim(), m_dual.dim());
                let max_m = max_abs_diff(m_single, m_dual);
                assert!(max_m < 1e-10, "block {idx}: M mismatch {max_m:e}");
            }
            _ => panic!("block {idx}: one side has anchor correction, the other does not"),
        }
    }
    assert_eq!(single.joint_rank, dual.joint_rank);
}
/// On rows 0..6, the only rows with non-zero curvature weight, B is exactly 2·A.
/// Under H alone, B therefore looks aliased. The identity structural metric also
/// sees the remaining rows, so B's column must survive the structural pass.
#[test]
fn dual_metric_resists_pilot_curvature_alias() {
    let n = 12;
    let a = Array2::from_shape_fn((n, 1), |(i, _)| (i as f64) + 1.0);
    let b = Array2::from_shape_fn((n, 1), |(i, _)| match i {
        0..=5 => 2.0 * a[[i, 0]],
        _ => ((i as f64) * 0.3).cos() + 0.5,
    });
    // Unit weight on the aliased rows, zero weight everywhere else.
    let w = Array1::from_shape_fn(n, |i| if i < 6 { 1.0 } else { 0.0 });
    let curvature = diag_hess(w);
    let id_struct = IdentityRowHessian::new(n, 1);
    let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];

    let ops_dual = vec![op(a.clone()), op(b.clone())];
    let dual = compile_with_dual_metric(&ops_dual, &curvature, &id_struct, &ordering);

    // Control: using H as the structural metric too must reject B's column.
    let ops_h_only = vec![op(a.clone()), op(b.clone())];
    let h_only = compile_with_dual_metric(&ops_h_only, &curvature, &curvature, &ordering);
    match h_only {
        Err(CompilerError::FullyAliased { block_idx, .. }) => {
            assert_eq!(block_idx, 1, "H-only path must alias block 1");
        }
        Ok(out) => {
            let v1_cols = out.blocks[1].t_lw.ncols();
            let rejected = v1_cols == 0 || !out.dropped.is_empty();
            assert!(
                rejected,
                "H-only path should reject B's curvature-aliased column; v1_cols={v1_cols}, dropped={dropped:?}",
                dropped = out.dropped,
            );
        }
        Err(other) => panic!("unexpected H-only error: {other:?}"),
    }

    let dual =
        dual.expect("dual-metric must succeed: identity-structural sees B as independent");
    assert_eq!(dual.blocks.len(), 2);
    assert_eq!(dual.blocks[0].t_lw.ncols(), 1, "A must keep its column");
    // B's single column is either still present or listed as dropped by the audit.
    let kept_plus_dropped = dual.blocks[1].t_lw.ncols() + dual.dropped.len();
    assert_eq!(
        kept_plus_dropped,
        1,
        "structural pass kept B's column; audit may demote it but the pre-audit width was 1"
    );
}
/// With generic A and B designs and an identity structural metric, every column
/// is kept. That means two per block, nothing dropped, and a joint rank of 4.
#[test]
fn dual_metric_identity_structural_preserves_full_rank() {
    let n = 24;
    let a = Array2::from_shape_fn((n, 2), |(i, j)| ((i + 1) as f64 + j as f64).sqrt());
    let b = Array2::from_shape_fn((n, 2), |(i, j)| {
        ((i + 1) as f64).ln() + (i as f64 * 0.1 + j as f64).cos()
    });
    let weights = Array1::from_shape_fn(n, |i| 0.4 + (i as f64 * 0.05).sin().powi(2));
    let curvature = diag_hess(weights);
    let structural = IdentityRowHessian::new(n, 1);
    let ops = vec![op(a), op(b)];
    let out = compile_with_dual_metric(
        &ops,
        &curvature,
        &structural,
        &[BlockOrder::Marginal, BlockOrder::Logslope],
    )
    .expect("compile");

    let (width_a, width_b) = (out.blocks[0].t_lw.ncols(), out.blocks[1].t_lw.ncols());
    assert_eq!(width_a, 2);
    assert_eq!(width_b, 2);
    assert_eq!(out.dropped.len(), 0);
    assert_eq!(out.joint_rank, 4);
}
/// The GPU-or-CPU dispatch helper must reproduce both the CPU curvature Gram and
/// the CPU structural Gram: same shape, with every entry within a relative 1e-9.
#[test]
fn build_primary_grams_gpu_or_cpu_two_block_k4_matches_cpu() {
    let n = 11;
    let k = 4;
    let p_a = 2;
    let p_b = 3;
    // One dense design per channel. The channel count is tied to `k` so the
    // blocks cannot silently disagree with the row Hessian's width.
    let make_block = |seed: f64, n: usize, p: usize| -> Vec<Option<Array2<f64>>> {
        (0..k)
            .map(|c| {
                let m = Array2::from_shape_fn((n, p), |(i, j)| {
                    ((i as f64 + 1.0) * (j as f64 + 1.0) * (c as f64 + 1.0) + seed).sin()
                });
                Some(m)
            })
            .collect()
    };
    let block_a = make_block(0.7, n, p_a);
    let block_b = make_block(-0.4, n, p_b);
    // Per-row H = M·Mᵀ + 0.25·I with M a cosine matrix, hence symmetric positive definite.
    let h = Array3::from_shape_fn((n, k, k), |(i, c, d)| {
        let mut acc = 0.0;
        for r in 0..k {
            let mc = ((i + 1) as f64 * (c + 1) as f64 * (r + 1) as f64 * 0.11).cos();
            let md = ((i + 1) as f64 * (d + 1) as f64 * (r + 1) as f64 * 0.11).cos();
            acc += mc * md;
        }
        acc + if c == d { 0.25 } else { 0.0 }
    });
    let row_hess = DenseRowHessian { h };
    let channel_blocks = PrimaryChannelBlocks {
        blocks: vec![block_a, block_b],
    };
    let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
    let (gram_h, gram_struct) =
        build_primary_grams_gpu_or_cpu(&channel_blocks, &row_hess, &raw_ranges)
            .expect("dispatch helper should succeed");
    let cpu_h = build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges)
        .expect("CPU curvature Gram should succeed");
    let cpu_s = build_raw_grams_structural(&channel_blocks, &raw_ranges);

    // Pin the shapes first. Otherwise a larger helper output would pass unnoticed,
    // and a smaller one would panic on indexing with no useful message.
    assert_eq!(gram_h.dim(), cpu_h.dim(), "gram_h shape mismatch");
    assert_eq!(gram_struct.dim(), cpu_s.dim(), "gram_struct shape mismatch");

    let tol = 1e-9_f64;
    for (idx, &cpu) in cpu_h.indexed_iter() {
        let helper = gram_h[idx];
        assert!(
            (helper - cpu).abs() <= tol * cpu.abs().max(1.0),
            "gram_h mismatch at {idx:?}: helper={helper} cpu={cpu}"
        );
    }
    for (idx, &cpu) in cpu_s.indexed_iter() {
        let helper = gram_struct[idx];
        assert!(
            (helper - cpu).abs() <= tol * cpu.abs().max(1.0),
            "gram_struct mismatch at {idx:?}: helper={helper} cpu={cpu}"
        );
    }
}
/// Builds the curvature Gram (diagonal scalar row Hessian from `w`) and the
/// structural Gram for a two-block, single-channel raw layout `[a | b]`.
///
/// Returns `(gram_h, gram_struct, raw_ranges)`, where `raw_ranges` gives each
/// block's raw column span.
///
/// # Panics
/// Panics if `a` and `b` have different row counts, or if the curvature Gram
/// build fails.
fn scalar_grams_two_block(
    a: &Array2<f64>,
    b: &Array2<f64>,
    w: &Array1<f64>,
) -> (Array2<f64>, Array2<f64>, Vec<std::ops::Range<usize>>) {
    // Both blocks describe the same observations, so they must share rows.
    assert_eq!(
        a.nrows(),
        b.nrows(),
        "scalar_grams_two_block: blocks must have the same number of rows"
    );
    let (p_a, p_b) = (a.ncols(), b.ncols());
    let raw_ranges = vec![0..p_a, p_a..(p_a + p_b)];
    let channel_blocks = PrimaryChannelBlocks {
        blocks: vec![vec![Some(a.clone())], vec![Some(b.clone())]],
    };
    let row_hess = DiagonalScalarRowHessian::new(w.clone());
    let gram_h = build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges)
        .expect("closed-form curvature Gram should succeed");
    let gram_struct = build_raw_grams_structural(&channel_blocks, &raw_ranges);
    (gram_h, gram_struct, raw_ranges)
}
/// B is an invertible recombination of A (`B = A·L`), so all of block B is
/// structurally aliased. The raw-Gram compile must fail with `FullyAliased` on
/// block 1.
#[test]
fn compile_from_raw_grams_full_structural_alias() {
    let n = 10;
    let a = Array2::from_shape_fn((n, 2), |(i, j)| ((i + 1) as f64 * (j + 1) as f64).sin());
    // det(L) = 1 + 0.125 != 0, so span(B) == span(A).
    let mixing = Array2::from_shape_vec((2, 2), vec![1.0, 0.5, -0.25, 1.0]).unwrap();
    let b = a.dot(&mixing);
    let unit_weights = Array1::ones(n);
    let (gram_h, gram_struct, raw_ranges) = scalar_grams_two_block(&a, &b, &unit_weights);
    let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];
    match compile_from_raw_grams(&gram_h, &gram_struct, &raw_ranges, &ordering) {
        Err(CompilerError::FullyAliased { block_idx, .. }) => {
            assert_eq!(block_idx, 1, "alias should fire on block 1, not 0");
        }
        other => panic!("expected FullyAliased on block 1, got {other:?}"),
    }
}
/// B repeats A's first column and adds one new column, so the raw-Gram compile
/// must drop exactly one B column. Two further checks follow:
/// - the compiled design must have full column rank `p_a + 1`;
/// - the operator-based dual-metric compile must agree on the total width.
#[test]
fn compile_from_raw_grams_partial_alias_matches_w_reference() {
    let n = 25;
    let a = Array2::from_shape_fn((n, 2), |(i, j)| {
        ((i + 1) as f64 * (j + 1) as f64 * 0.3).sin()
    });
    // Column 0 duplicates A's column 0; column 1 is new information.
    let b = Array2::from_shape_fn((n, 2), |(i, j)| match j {
        0 => a[[i, 0]],
        _ => ((i + 1) as f64 * 0.7).cos(),
    });
    let w = Array1::from_shape_fn(n, |i| 1.0 + 0.1 * (i as f64));
    let (gram_h, gram_struct, raw_ranges) = scalar_grams_two_block(&a, &b, &w);
    let ordering = [BlockOrder::Marginal, BlockOrder::Logslope];
    let compiled = compile_from_raw_grams(&gram_h, &gram_struct, &raw_ranges, &ordering)
        .expect("closed-form compile must succeed");

    let (p_a, p_b) = (a.ncols(), b.ncols());
    let map_shape = compiled.raw_from_compiled.shape();
    assert_eq!(map_shape[0], p_a + p_b);
    assert_eq!(
        map_shape[1],
        p_a + 1,
        "partial alias should leave compiled width = p_a + 1 (one column dropped from B)"
    );
    assert_eq!(compiled.compiled_block_ranges[0], 0..p_a);
    let b_range = &compiled.compiled_block_ranges[1];
    assert_eq!(b_range.end - b_range.start, 1);

    // Map the raw design [A | B] into compiled coordinates and measure its rank.
    let mut x_raw = Array2::<f64>::zeros((n, p_a + p_b));
    x_raw.slice_mut(s![.., ..p_a]).assign(&a);
    x_raw.slice_mut(s![.., p_a..]).assign(&b);
    let x_compiled = fast_ab(&x_raw, &compiled.raw_from_compiled);
    let g_compiled = fast_ata(&x_compiled);
    let (evals, _) = g_compiled.eigh(Side::Lower).unwrap();
    let lam_max = evals.iter().copied().fold(0.0_f64, f64::max);
    // Relative eigenvalue cutoff scaled by dimension and machine epsilon.
    let tol = lam_max * 64.0 * (g_compiled.nrows() as f64) * f64::EPSILON;
    let rank_compiled = evals.iter().filter(|&&l| l > tol).count();
    assert_eq!(
        rank_compiled,
        p_a + 1,
        "compiled design column rank must equal p_a + 1 after dropping the alias"
    );

    // Cross-check against the operator-based dual-metric path.
    let ops_dual: Vec<Arc<dyn RowJacobianOperator>> = vec![op(a.clone()), op(b.clone())];
    let curvature = DiagonalScalarRowHessian::new(w.clone());
    let id_struct = IdentityRowHessian::new(n, 1);
    let dual = compile_with_dual_metric(&ops_dual, &curvature, &id_struct, &ordering)
        .expect("dual metric compile should succeed");
    let dual_total: usize = dual.blocks.iter().map(|blk| blk.t_lw.ncols()).sum();
    assert_eq!(dual_total, p_a + 1, "W-reference total width should match");
}
/// Three blocks where B and C each repeat one column of A. The earliest block in
/// the raw layout keeps its full width, and later blocks lose the column they
/// duplicate. The expected per-block widths are asserted for both the ABC and
/// BAC layouts, and the total (4) must be order-invariant.
#[test]
fn compile_from_raw_grams_three_block_ordering_matters() {
    let n = 30;
    let a = Array2::from_shape_fn((n, 2), |(i, j)| {
        ((i + 1) as f64 * (j + 2) as f64 * 0.2).sin()
    });
    // B = [new, A column 0]
    let mut b = Array2::<f64>::zeros((n, 2));
    for i in 0..n {
        b[[i, 0]] = ((i + 1) as f64 * 0.4).cos();
        b[[i, 1]] = a[[i, 0]];
    }
    // C = [new, A column 1]
    let mut c = Array2::<f64>::zeros((n, 2));
    for i in 0..n {
        c[[i, 0]] = ((i + 1) as f64 * 0.55).sin();
        c[[i, 1]] = a[[i, 1]];
    }
    let w = Array1::ones(n);
    let ordering = [
        BlockOrder::Marginal,
        BlockOrder::Logslope,
        BlockOrder::LinkDev,
    ];
    // Curvature Gram, structural Gram and raw column ranges for layout [b0 | b1 | b2].
    let build = |b0: &Array2<f64>, b1: &Array2<f64>, b2: &Array2<f64>| {
        let (p0, p1, p2) = (b0.ncols(), b1.ncols(), b2.ncols());
        let raw_ranges = vec![0..p0, p0..(p0 + p1), (p0 + p1)..(p0 + p1 + p2)];
        let channel_blocks = PrimaryChannelBlocks {
            blocks: vec![
                vec![Some(b0.clone())],
                vec![Some(b1.clone())],
                vec![Some(b2.clone())],
            ],
        };
        let row_hess = DiagonalScalarRowHessian::new(w.clone());
        let gram_h =
            build_raw_grams_from_channel_blocks(&channel_blocks, &row_hess, &raw_ranges)
                .expect("closed-form curvature Gram should succeed");
        let gram_struct = build_raw_grams_structural(&channel_blocks, &raw_ranges);
        (gram_h, gram_struct, raw_ranges)
    };

    let (gh, gs, rr) = build(&a, &b, &c);
    let order_abc = compile_from_raw_grams(&gh, &gs, &rr, &ordering).expect("ABC compile");
    // A first: A keeps 2; B's second column and C's second column duplicate A.
    assert_eq!(order_abc.compiled_block_ranges[0].len(), 2);
    assert_eq!(order_abc.compiled_block_ranges[1].len(), 1);
    assert_eq!(order_abc.compiled_block_ranges[2].len(), 1);

    let (gh2, gs2, rr2) = build(&b, &a, &c);
    let order_bac = compile_from_raw_grams(&gh2, &gs2, &rr2, &ordering).expect("BAC compile");
    // B first: B keeps 2. A's first column duplicates B's second column, and
    // C's second column duplicates A's second column.
    assert_eq!(order_bac.compiled_block_ranges[0].len(), 2);
    assert_eq!(order_bac.compiled_block_ranges[1].len(), 1);
    assert_eq!(order_bac.compiled_block_ranges[2].len(), 1);

    let total_abc: usize = order_abc
        .compiled_block_ranges
        .iter()
        .map(|r| r.len())
        .sum();
    let total_bac: usize = order_bac
        .compiled_block_ranges
        .iter()
        .map(|r| r.len())
        .sum();
    assert_eq!(total_abc, total_bac);
    assert_eq!(total_abc, 4);
}
}