use ndarray::{Array1, Array2};
use std::sync::Arc;
use super::jeffreys_subspace::{
floored_inverse, floored_inverse_divided_differences, jeffreys_antiderivative,
jeffreys_antiderivative_floor_sensitivity,
};
use super::reml_outer_engine::{PenaltyCoordinate, PenaltySubspaceTrace};
/// One θ-direction along which the criterion is differentiated.
///
/// `Sensitivity::fill_direction` populates all three fields. Atoms read only the ones
/// they need:
/// - `index` for explicit per-coordinate gradients;
/// - `h_dot_total` for Hessian-drift traces;
/// - `beta_dot` for the β-channel chain rule in `CriterionSum::d1`.
///
/// `Default` yields an empty direction with every field `None`.
#[derive(Debug, Clone, Default)]
pub struct ThetaDirection {
    /// Coordinate index used by atoms that store an explicit per-coordinate gradient.
    pub index: Option<usize>,
    /// Mode response β̇ along this direction.
    pub beta_dot: Option<Arc<Array1<f64>>>,
    /// Total Hessian drift: the frozen drift plus the β̇-induced drift.
    pub h_dot_total: Option<Arc<Array2<f64>>>,
}
/// Factorization summary shared by Hessian-dependent atoms.
pub struct Sensitivity {
    /// Penalty-subspace trace kernel; `u_s.nrows()` fixes the coefficient dimension `p`.
    pub kernel: Arc<PenaltySubspaceTrace>,
    /// Log-determinant reported for the current factorization.
    pub logdet: f64,
    /// Stratum fingerprint of the factorization.
    pub stratum: StratumFingerprint,
}
impl Sensitivity {
    /// Builds a fully populated [`ThetaDirection`] for coordinate `index`.
    ///
    /// Computes `β̇ = op.mode_response(f_beta_theta)` as a single right-hand side, then
    /// `Ḣ_total = h_dot_frozen + cubic_drift(β̇)`.
    ///
    /// Returns `None` when:
    /// - any input dimension disagrees with `p = kernel.u_s.nrows()`;
    /// - `mode_response` fails, or returns something other than a `p × 1` block;
    /// - `cubic_drift` does not return a `p × p` matrix;
    /// - β̇ or Ḣ_total contains a non-finite entry.
    pub fn fill_direction<F>(
        &self,
        index: usize,
        op: &crate::solver::sensitivity::FitSensitivity<'_>,
        f_beta_theta: &Array1<f64>,
        h_dot_frozen: &Array2<f64>,
        cubic_drift: F,
    ) -> Option<ThetaDirection>
    where
        F: FnOnce(&Array1<f64>) -> Array2<f64>,
    {
        let p = self.kernel.u_s.nrows();
        if f_beta_theta.len() != p || op.dim() != p || h_dot_frozen.dim() != (p, p) {
            return None;
        }
        let rhs = f_beta_theta.view().insert_axis(ndarray::Axis(1));
        let beta_dot_col = op.mode_response(rhs)?;
        // One right-hand side must come back as one p-vector. Any other shape would make
        // `column(0)` panic or hand the atoms a β̇ of the wrong length.
        if beta_dot_col.dim() != (p, 1) {
            return None;
        }
        let beta_dot = beta_dot_col.column(0).to_owned();
        if beta_dot.iter().any(|v| !v.is_finite()) {
            return None;
        }
        // Reject a drift of the wrong shape instead of letting `+` panic or broadcast.
        let drift = cubic_drift(&beta_dot);
        if drift.dim() != (p, p) {
            return None;
        }
        let h_dot_total = h_dot_frozen + &drift;
        if h_dot_total.iter().any(|v| !v.is_finite()) {
            return None;
        }
        Some(ThetaDirection {
            index: Some(index),
            beta_dot: Some(Arc::new(beta_dot)),
            h_dot_total: Some(Arc::new(h_dot_total)),
        })
    }
}
/// Summary of the spectral stratum an atom was evaluated on.
///
/// Atoms report it through `CriterionAtom::stratum`. `Copy` lets them hand it out
/// without rebuilding the struct field by field.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct StratumFingerprint {
    /// Number of retained directions in the stratum.
    pub kept_rank: usize,
    /// Minimum relative eigengap recorded by the producer of the stratum.
    pub min_relative_eigengap: f64,
}
/// An atom's gradient with respect to β, contracted with β̇ in `CriterionSum::d1`.
#[derive(Debug, Clone, PartialEq)]
pub struct BetaChannel {
    /// ∂V_atom/∂β. The value-only `PenaltyQuadAtom` carrier emits an empty vector here.
    pub grad_beta: Array1<f64>,
}
/// One additive term of the outer criterion.
///
/// `CriterionSum::d1` forms the profiled derivative of the sum. It adds every atom's
/// `frozen_d1`, then the contraction of each declared β-channel with β̇.
pub trait CriterionAtom {
    /// Stable diagnostic label of the atom.
    fn name(&self) -> &'static str;

    /// Atom value at the current fit.
    fn value(&self) -> f64;

    /// Directional derivative along `dir` with β held fixed. The β-dependence is
    /// reported separately through `beta_channel`.
    fn frozen_d1(&self, dir: &ThetaDirection) -> f64;

    /// Gradient with respect to β, or `None` when the atom has no β-dependence.
    fn beta_channel(&self) -> Option<BetaChannel>;

    /// Stratum the atom was evaluated on, or `None` when it declares none.
    fn stratum(&self) -> Option<StratumFingerprint>;
}
/// Additive composition of criterion atoms: `V = Σ_a V_a`.
pub struct CriterionSum {
    /// The atoms being summed, evaluated in order.
    pub atoms: Vec<Box<dyn CriterionAtom + Send + Sync>>,
}
impl CriterionSum {
    /// Total criterion value: the sum of every atom's `value()`.
    pub fn value(&self) -> f64 {
        self.atoms.iter().map(|a| a.value()).sum()
    }

    /// Profiled directional derivative along `dir`.
    ///
    /// Computes `Σ_a frozen_d1_a(dir) + Σ_a grad_β,a · β̇`. The second sum runs over
    /// atoms that declare a non-empty β-channel. An empty channel carries no β-mass and
    /// contributes nothing; the value-only `PenaltyQuadAtom` carrier is one example.
    ///
    /// # Panics
    /// - If an atom declares a non-empty β-channel while `dir.beta_dot` is `None`.
    /// - If a non-empty β-channel's length differs from `beta_dot`'s length.
    pub fn d1(&self, dir: &ThetaDirection) -> f64 {
        let frozen: f64 = self.atoms.iter().map(|a| a.frozen_d1(dir)).sum();
        let mut chained = 0.0;
        for atom in &self.atoms {
            let channel = match atom.beta_channel() {
                Some(channel) if !channel.grad_beta.is_empty() => channel,
                // No channel, or an empty one: nothing to chain through β̇.
                _ => continue,
            };
            let beta_dot = dir
                .beta_dot
                .as_ref()
                .expect("calculus must fill beta_dot before profiled d1");
            assert_eq!(
                channel.grad_beta.len(),
                beta_dot.len(),
                "beta channel of atom `{}` does not match beta_dot length",
                atom.name()
            );
            chained += channel.grad_beta.dot(beta_dot.as_ref());
        }
        frozen + chained
    }
}
/// The `½·logdet` term, read from a shared [`Sensitivity`].
pub struct HessianLogdetAtom {
    /// Shared factorization summary providing the log-determinant, trace kernel and stratum.
    pub sensitivity: Arc<Sensitivity>,
}
impl CriterionAtom for HessianLogdetAtom {
    fn name(&self) -> &'static str {
        "hessian_logdet"
    }

    /// Half the stored log-determinant.
    fn value(&self) -> f64 {
        let logdet = self.sensitivity.logdet;
        0.5 * logdet
    }

    /// Half of `kernel.trace_projected_logdet(Ḣ_total)`.
    ///
    /// # Panics
    /// Panics if `dir.h_dot_total` has not been filled.
    fn frozen_d1(&self, dir: &ThetaDirection) -> f64 {
        let kernel = &self.sensitivity.kernel;
        let drift = match dir.h_dot_total.as_ref() {
            Some(drift) => drift,
            None => panic!("calculus fills h_dot_total before logdet d1"),
        };
        0.5 * kernel.trace_projected_logdet(drift)
    }

    /// The β-dependence enters only through Ḣ_total, so there is no separate channel.
    fn beta_channel(&self) -> Option<BetaChannel> {
        None
    }

    /// A copy of the shared sensitivity's stratum.
    fn stratum(&self) -> Option<StratumFingerprint> {
        let source = &self.sensitivity.stratum;
        Some(StratumFingerprint {
            kept_rank: source.kept_rank,
            min_relative_eigengap: source.min_relative_eigengap,
        })
    }
}
/// Sampled-block marginal atom.
///
/// Holds a precomputed value, an explicit θ-gradient, the matrix `q_bc` contracted
/// with the Hessian drift, and a β-gradient `g_d`.
pub struct SampledBlockAtom {
    /// Atom value.
    pub value: f64,
    /// Explicit θ-gradient, indexed by `ThetaDirection::index`; out-of-range indices give 0.
    pub explicit: Array1<f64>,
    /// Matrix whose trace against Ḣ_total enters `frozen_d1` as `tr(Ḣ·q_bc)`.
    pub q_bc: Arc<Array2<f64>>,
    /// Gradient with respect to β, exposed as the atom's β-channel.
    pub g_d: Array1<f64>,
    /// Stratum fingerprint reported by `stratum()`.
    pub stratum: StratumFingerprint,
}
impl CriterionAtom for SampledBlockAtom {
    fn name(&self) -> &'static str {
        "sampled_block_marginal"
    }

    fn value(&self) -> f64 {
        self.value
    }

    /// `explicit[index] + tr(Ḣ_total · q_bc)`.
    ///
    /// # Panics
    /// - If `dir.h_dot_total` has not been filled.
    /// - If `q_bc` is not shaped like the transpose of Ḣ_total. Without this check a
    ///   larger `q_bc` would silently yield a partial trace.
    fn frozen_d1(&self, dir: &ThetaDirection) -> f64 {
        let explicit = dir
            .index
            .and_then(|idx| self.explicit.get(idx).copied())
            .unwrap_or(0.0);
        let h_dot = dir
            .h_dot_total
            .as_ref()
            .expect("calculus fills h_dot_total before sampled-block d1");
        assert_eq!(
            self.q_bc.dim(),
            (h_dot.ncols(), h_dot.nrows()),
            "sampled-block q_bc shape must match the transpose of h_dot_total"
        );
        // tr(Ḣ·Q) = Σ_{i,j} Ḣ[i,j]·Q[j,i]. Both iterators walk in logical row-major order,
        // so the (i, j) pairs are visited in the same order as a nested index loop.
        let trace = h_dot
            .iter()
            .zip(self.q_bc.t().iter())
            .fold(0.0, |acc, (&h, &q)| acc + h * q);
        explicit + trace
    }

    fn beta_channel(&self) -> Option<BetaChannel> {
        Some(BetaChannel {
            grad_beta: self.g_d.clone(),
        })
    }

    fn stratum(&self) -> Option<StratumFingerprint> {
        Some(StratumFingerprint {
            kept_rank: self.stratum.kept_rank,
            min_relative_eigengap: self.stratum.min_relative_eigengap,
        })
    }
}
/// The penalty quadratic `½ Σ_k λ_k q_k(β)`, stored per block.
///
/// An optional externally supplied stable value can override the reported value.
pub struct PenaltyQuadAtom {
    /// Smoothing parameters λ_k, one per penalty block.
    pub lambdas: Array1<f64>,
    /// Per-block quadratics q_k (computed with unit λ).
    pub block_quadratics: Array1<f64>,
    /// Sum over blocks of the λ-weighted penalty scores; used as the β-channel.
    pub penalty_score: Array1<f64>,
    /// λ-weighted penalty score of each block.
    pub block_penalty_scores: Vec<Array1<f64>>,
    /// When set, `value()` reports this instead of `½ Σ λ_k q_k`.
    pub stable_value: Option<f64>,
}
impl PenaltyQuadAtom {
    /// Builds the atom from penalty coordinates evaluated at `beta`.
    ///
    /// # Errors
    /// Returns `Err` in these cases:
    /// - `lambdas` and `coords` differ in length;
    /// - a λ is non-finite;
    /// - a shifted quadratic is non-finite;
    /// - a score vector has the wrong length or contains a non-finite entry.
    pub(crate) fn from_penalty_coords(
        lambdas: &[f64],
        coords: &[PenaltyCoordinate],
        beta: &Array1<f64>,
    ) -> Result<Self, String> {
        if coords.len() != lambdas.len() {
            return Err(format!(
                "penalty quadratic atom dimension mismatch: lambdas={}, coords={}",
                lambdas.len(),
                coords.len()
            ));
        }
        let n_blocks = coords.len();
        let mut quadratics: Vec<f64> = Vec::with_capacity(n_blocks);
        let mut total_score = Array1::<f64>::zeros(beta.len());
        let mut per_block = Vec::with_capacity(n_blocks);
        for (idx, (coord, &lambda)) in coords.iter().zip(lambdas).enumerate() {
            if !lambda.is_finite() {
                return Err(format!(
                    "penalty quadratic atom received non-finite lambda at coord {idx}: {lambda}"
                ));
            }
            // The quadratic is stored with unit weight; λ is applied when it is consumed.
            let q_k = coord.shifted_quadratic(beta, 1.0);
            if !q_k.is_finite() {
                return Err(format!(
                    "penalty quadratic atom produced non-finite shifted quadratic at coord {idx}: {q_k}"
                ));
            }
            let score_k = coord.apply_shifted_penalty(beta, lambda);
            if score_k.len() != beta.len() {
                return Err(format!(
                    "penalty quadratic atom score length mismatch at coord {idx}: got {}, expected {}",
                    score_k.len(),
                    beta.len()
                ));
            }
            if !score_k.iter().all(|v| v.is_finite()) {
                return Err(format!(
                    "penalty quadratic atom produced a non-finite beta score at coord {idx}"
                ));
            }
            quadratics.push(q_k);
            total_score += &score_k;
            per_block.push(score_k);
        }
        Ok(Self {
            lambdas: Array1::from_vec(lambdas.to_vec()),
            block_quadratics: Array1::from_vec(quadratics),
            penalty_score: total_score,
            block_penalty_scores: per_block,
            stable_value: None,
        })
    }

    /// Value-only carrier. All per-block arrays are empty, so the frozen derivative is 0
    /// and the β-channel is an empty vector.
    pub(crate) fn stable_value_only(half_stable_penalty_term: f64) -> Self {
        Self {
            stable_value: Some(half_stable_penalty_term),
            lambdas: Array1::zeros(0),
            block_quadratics: Array1::zeros(0),
            penalty_score: Array1::zeros(0),
            block_penalty_scores: Vec::new(),
        }
    }

    /// Replaces the reported value while keeping every derivative ingredient.
    pub(crate) fn with_stable_value(self, half_stable_penalty_term: f64) -> Self {
        Self {
            stable_value: Some(half_stable_penalty_term),
            ..self
        }
    }

    /// ∂/∂ρ_idx of `½ λ_idx q_idx` with `λ = exp(ρ)`, i.e. `½ λ_idx q_idx`.
    /// Indices past the last block give 0.
    pub(crate) fn rho_frozen_d1(&self, idx: usize) -> f64 {
        if idx >= self.lambdas.len() {
            return 0.0;
        }
        0.5 * self.lambdas[idx] * self.block_quadratics[idx]
    }

    /// λ-weighted penalty score of each block.
    pub(crate) fn block_penalty_scores(&self) -> &[Array1<f64>] {
        self.block_penalty_scores.as_slice()
    }
}
impl CriterionAtom for PenaltyQuadAtom {
    fn name(&self) -> &'static str {
        "penalty_quadratic"
    }

    /// The stable override if present, otherwise `½ Σ_k λ_k q_k`.
    fn value(&self) -> f64 {
        if let Some(stable) = self.stable_value {
            return stable;
        }
        let weighted = self
            .lambdas
            .iter()
            .zip(self.block_quadratics.iter())
            .map(|(&lambda, &quad)| lambda * quad)
            .sum::<f64>();
        0.5 * weighted
    }

    /// The ρ-derivative for `dir.index`, or 0 when the direction has no index.
    fn frozen_d1(&self, dir: &ThetaDirection) -> f64 {
        dir.index.map_or(0.0, |k| self.rho_frozen_d1(k))
    }

    /// Always declares a channel carrying the aggregated penalty score.
    fn beta_channel(&self) -> Option<BetaChannel> {
        let grad_beta = self.penalty_score.clone();
        Some(BetaChannel { grad_beta })
    }

    fn stratum(&self) -> Option<StratumFingerprint> {
        None
    }
}
/// Gated Jeffreys-type log-determinant over a set of eigenvalues.
///
/// Value: `gate_weight · ½ · Σ_i jeffreys_antiderivative(λ_i, floor)`.
pub struct JeffreysLogdetAtom {
    /// Eigenvalues λ_i entering the antiderivative.
    pub eigvals: Array1<f64>,
    /// Floor passed to the floored inverse and the antiderivative.
    pub floor: f64,
    /// Multiplicative gate applied to the value and its derivatives.
    pub gate_weight: f64,
    /// Per-axis m × m reduced drift matrices, where m = `eigvals.len()`.
    /// NOTE(review): presumably expressed in the eigenbasis of `eigvals`; confirm with
    /// the producer.
    pub reduced_drift: std::collections::HashMap<usize, Arc<Array2<f64>>>,
    /// Per-axis derivative of `floor`.
    pub floor_drift: std::collections::HashMap<usize, f64>,
    /// Stratum fingerprint reported by `stratum()`.
    pub stratum: StratumFingerprint,
}
impl JeffreysLogdetAtom {
    /// `floored_inverse(λ_i, floor)` for every eigenvalue.
    fn floored_inv_diag(&self) -> Array1<f64> {
        let floor = self.floor;
        self.eigvals.mapv(move |lam| floored_inverse(lam, floor))
    }

    /// `Σ_i jeffreys_antiderivative_floor_sensitivity(λ_i, floor)`.
    fn floor_sensitivity_sum(&self) -> f64 {
        let floor = self.floor;
        self.eigvals
            .iter()
            .map(|&lam| jeffreys_antiderivative_floor_sensitivity(lam, floor))
            .sum::<f64>()
    }

    /// Second-order curvature matrix over the first `axis_count` axes.
    ///
    /// Entry (a, b) is `-½ · gate · Σ_{ij} ψ_ij · A^a_ij · A^b_ij`. Here ψ is
    /// `floored_inverse_divided_differences(eigvals, floor)` and A^a is the reduced
    /// drift for axis a.
    ///
    /// # Errors
    /// Returns `Err` if an axis has no reduced drift, or if its drift is not m × m.
    pub fn second_order_curvature(&self, axis_count: usize) -> Result<Array2<f64>, String> {
        let m = self.eigvals.len();
        let psi = floored_inverse_divided_differences(&self.eigvals, self.floor);
        let width = m * m;
        // Row `axis` holds A^axis flattened in row-major order. The weighted copy
        // holds ψ ∘ A^axis.
        let mut plain = Array2::<f64>::zeros((axis_count, width));
        let mut weighted = Array2::<f64>::zeros((axis_count, width));
        for axis in 0..axis_count {
            let reduced = match self.reduced_drift.get(&axis) {
                Some(reduced) => reduced,
                None => {
                    return Err(format!(
                        "jeffreys_logdet second-order curvature missing reduced drift for axis {axis}"
                    ))
                }
            };
            if reduced.dim() != (m, m) {
                return Err(format!(
                    "jeffreys_logdet reduced drift shape for axis {axis} is {:?}, expected ({m}, {m})",
                    reduced.dim()
                ));
            }
            for ((i, j), &a_ij) in reduced.indexed_iter() {
                let col = i * m + j;
                plain[[axis, col]] = a_ij;
                weighted[[axis, col]] = psi[[i, j]] * a_ij;
            }
        }
        let scale = -0.5 * self.gate_weight;
        let mut curvature = crate::linalg::faer_ndarray::fast_abt(&weighted, &plain);
        curvature.mapv_inplace(|v| scale * v);
        Ok(curvature)
    }
}
impl CriterionAtom for JeffreysLogdetAtom {
    fn name(&self) -> &'static str {
        "jeffreys_logdet"
    }

    /// `gate_weight · ½ · Σ_i jeffreys_antiderivative(λ_i, floor)`.
    fn value(&self) -> f64 {
        self.gate_weight
            * 0.5
            * self
                .eigvals
                .iter()
                .map(|&lam| jeffreys_antiderivative(lam, self.floor))
                .sum::<f64>()
    }

    /// `gate · ½ · (Σ_i f(λ_i)·A_ii + floor_sensitivity · ḟloor)` along axis `dir.index`.
    ///
    /// Here `f = floored_inverse` and A is the reduced drift for the axis. Each part is
    /// included only when its drift is registered for the axis. A floor-only direction
    /// (floor drift present, reduced drift absent) therefore still contributes its floor
    /// term instead of being dropped. Directions with neither drift, or without an
    /// index, return 0.
    ///
    /// # Panics
    /// Panics if a registered reduced drift is not m × m, the same shape contract
    /// `second_order_curvature` enforces.
    fn frozen_d1(&self, dir: &ThetaDirection) -> f64 {
        let idx = match dir.index {
            Some(idx) => idx,
            None => return 0.0,
        };
        let reduced = self.reduced_drift.get(&idx);
        let floor_dot = self.floor_drift.get(&idx);
        if reduced.is_none() && floor_dot.is_none() {
            return 0.0;
        }
        let mut trace = 0.0;
        if let Some(reduced) = reduced {
            let d = self.floored_inv_diag();
            let m = d.len();
            assert_eq!(
                reduced.dim(),
                (m, m),
                "jeffreys_logdet reduced drift for axis {idx} must be ({m}, {m})"
            );
            for i in 0..m {
                trace += d[i] * reduced[[i, i]];
            }
        }
        if let Some(floor_dot) = floor_dot {
            trace += self.floor_sensitivity_sum() * floor_dot;
        }
        self.gate_weight * 0.5 * trace
    }

    /// The β-dependence enters through the shared drift, so there is no separate channel.
    fn beta_channel(&self) -> Option<BetaChannel> {
        None
    }

    fn stratum(&self) -> Option<StratumFingerprint> {
        Some(StratumFingerprint {
            kept_rank: self.stratum.kept_rank,
            min_relative_eigengap: self.stratum.min_relative_eigengap,
        })
    }
}
/// θ-only atom exposing one configured ρ-prior evaluation.
///
/// The evaluation supplies the cost, the gradient and an optional Hessian.
pub struct ConfiguredRhoPriorAtom {
    /// The prior evaluation projected by this atom.
    pub eval: crate::rho_prior_eval::RhoPriorEval,
}
impl ConfiguredRhoPriorAtom {
    /// The evaluated prior cost.
    pub fn cost(&self) -> f64 {
        let eval = &self.eval;
        eval.cost
    }

    /// The evaluated prior gradient with respect to ρ.
    pub fn gradient(&self) -> &Array1<f64> {
        let eval = &self.eval;
        &eval.gradient
    }

    /// The evaluated prior Hessian, when it was computed.
    pub fn hessian(&self) -> Option<&Array2<f64>> {
        let eval = &self.eval;
        eval.hessian.as_ref()
    }
}
impl CriterionAtom for ConfiguredRhoPriorAtom {
    fn name(&self) -> &'static str {
        "configured_rho_prior"
    }

    fn value(&self) -> f64 {
        self.eval.cost
    }

    /// The gradient entry for `dir.index`. Returns 0 for a missing or out-of-range index.
    fn frozen_d1(&self, dir: &ThetaDirection) -> f64 {
        dir.index
            .and_then(|idx| self.eval.gradient.get(idx).copied())
            .unwrap_or(0.0)
    }

    fn beta_channel(&self) -> Option<BetaChannel> {
        None
    }

    fn stratum(&self) -> Option<StratumFingerprint> {
        None
    }
}
/// Separable soft guard prior on ρ.
///
/// Value: `w · Σ_i ln cosh(a·(ρ_i − anchor))` with `a = sharpness / bound`.
/// Gradient: `w·a·tanh(·)`. Hessian diagonal: `w·a²·(1 − tanh²(·))`.
pub struct SoftRhoGuardPriorAtom {
    /// Prior value.
    pub value: f64,
    /// Gradient with respect to ρ, one entry per coordinate.
    pub gradient: Array1<f64>,
    /// Hessian diagonal. `None` when weight is zero, ρ is empty, or every entry is 0.
    pub hessian_diag: Option<Array1<f64>>,
}
impl SoftRhoGuardPriorAtom {
    /// Evaluates the prior, gradient and Hessian diagonal at `rho`, centered on `anchor`.
    ///
    /// An empty `rho` or zero `weight` yields value 0, a zero gradient and no Hessian.
    pub fn evaluate_anchored(
        rho: &Array1<f64>,
        weight: f64,
        sharpness: f64,
        bound: f64,
        anchor: f64,
    ) -> Self {
        let len = rho.len();
        let mut gradient = Array1::<f64>::zeros(len);
        if len == 0 || weight == 0.0 {
            return Self {
                value: 0.0,
                gradient,
                hessian_diag: None,
            };
        }
        let a = sharpness / bound;
        let grad_prefactor = weight * a;
        let hess_prefactor = weight * a * a;
        let mut value = 0.0;
        let mut hess = Array1::<f64>::zeros(len);
        for (i, &ri) in rho.iter().enumerate() {
            let scaled = a * (ri - anchor);
            let t = scaled.tanh();
            // A plain `cosh().ln()` overflows to +∞ past |x| ≈ 710 while the gradient
            // stays finite. `log_cosh` keeps value and gradient on the same chain.
            value += weight * Self::log_cosh(scaled);
            gradient[i] = grad_prefactor * t;
            hess[i] = hess_prefactor * (1.0 - t * t);
        }
        let hessian_diag = hess.iter().any(|&v| v != 0.0).then_some(hess);
        Self {
            value,
            gradient,
            hessian_diag,
        }
    }

    /// Overflow-safe `ln(cosh(x))`.
    ///
    /// For |x| below the threshold it uses the direct formula. Above it, it uses the
    /// identity `ln cosh x = |x| − ln 2 + ln(1 + e^{−2|x|})`, which never overflows.
    /// NaN and ±∞ propagate as with the direct formula.
    fn log_cosh(x: f64) -> f64 {
        // At |x| = 20, e^{-2|x|} ≈ 4e-18 is far below one ulp of |x| − ln 2, so switching
        // forms here changes results only at rounding level.
        const ASYMPTOTIC_FROM: f64 = 20.0;
        let ax = x.abs();
        if ax < ASYMPTOTIC_FROM {
            x.cosh().ln()
        } else {
            ax - std::f64::consts::LN_2 + (-2.0 * ax).exp().ln_1p()
        }
    }

    /// The prior value.
    pub fn cost(&self) -> f64 {
        self.value
    }

    /// The gradient with respect to ρ.
    pub fn gradient(&self) -> &Array1<f64> {
        &self.gradient
    }

    /// The dense diagonal Hessian, or `None` when it is identically zero.
    pub fn hessian(&self) -> Option<Array2<f64>> {
        let diag = self.hessian_diag.as_ref()?;
        Some(Array2::from_diag(diag))
    }
}
impl CriterionAtom for SoftRhoGuardPriorAtom {
    fn name(&self) -> &'static str {
        "soft_rho_guard_prior"
    }

    fn value(&self) -> f64 {
        self.value
    }

    /// The gradient entry for `dir.index`. Returns 0 for a missing or out-of-range index.
    fn frozen_d1(&self, dir: &ThetaDirection) -> f64 {
        let entry = dir.index.and_then(|idx| self.gradient.get(idx).copied());
        entry.unwrap_or(0.0)
    }

    /// The prior depends only on θ.
    fn beta_channel(&self) -> Option<BetaChannel> {
        None
    }

    /// The prior is smooth everywhere.
    fn stratum(&self) -> Option<StratumFingerprint> {
        None
    }
}
pub(crate) trait ThetaCorrectionProjection: CriterionAtom {
fn cost(&self) -> f64 {
self.value()
}
fn gradient(&self) -> Option<&Array1<f64>>;
fn hessian(&self) -> Option<&Array2<f64>>;
}
/// Tierney–Kadane correction atom wrapping precomputed correction terms.
pub struct TierneyKadaneAtom {
    terms: super::outer_eval::TkCorrectionTerms,
}
impl TierneyKadaneAtom {
    /// Wraps already-evaluated correction terms.
    pub(crate) fn from_terms(terms: super::outer_eval::TkCorrectionTerms) -> Self {
        TierneyKadaneAtom { terms }
    }

    /// Gradient of the correction with respect to θ, when it was computed.
    pub fn gradient(&self) -> Option<&Array1<f64>> {
        let terms = &self.terms;
        terms.gradient.as_ref()
    }

    /// Hessian of the correction with respect to θ, when it was computed.
    pub fn hessian(&self) -> Option<&Array2<f64>> {
        let terms = &self.terms;
        terms.hessian.as_ref()
    }
}
impl CriterionAtom for TierneyKadaneAtom {
    fn name(&self) -> &'static str {
        "tierney_kadane"
    }

    fn value(&self) -> f64 {
        self.terms.value
    }

    /// The gradient entry for `dir.index`. Returns 0 when there is no gradient, no
    /// index, or the index is out of range.
    fn frozen_d1(&self, dir: &ThetaDirection) -> f64 {
        dir.index
            .zip(self.terms.gradient.as_ref())
            .and_then(|(idx, gradient)| gradient.get(idx).copied())
            .unwrap_or(0.0)
    }

    fn beta_channel(&self) -> Option<BetaChannel> {
        None
    }

    fn stratum(&self) -> Option<StratumFingerprint> {
        None
    }
}
impl ThetaCorrectionProjection for TierneyKadaneAtom {
    fn gradient(&self) -> Option<&Array1<f64>> {
        self.terms.gradient.as_ref()
    }

    fn hessian(&self) -> Option<&Array2<f64>> {
        self.terms.hessian.as_ref()
    }
}
/// Labelled θ-only correction atom holding a value and optional derivatives.
pub struct ThetaOnlyCorrectionAtom {
    /// Label reported by `CriterionAtom::name`.
    pub label: &'static str,
    /// Correction value.
    pub value: f64,
    /// Gradient with respect to θ, when available.
    pub gradient: Option<Array1<f64>>,
    /// Hessian with respect to θ, when available.
    pub hessian: Option<Array2<f64>>,
}
impl ThetaOnlyCorrectionAtom {
    /// Moves the value and derivatives out of `terms` under the given label.
    pub(crate) fn from_tk_terms(
        label: &'static str,
        terms: super::outer_eval::TkCorrectionTerms,
    ) -> Self {
        let super::outer_eval::TkCorrectionTerms {
            value,
            gradient,
            hessian,
        } = terms;
        Self {
            label,
            value,
            gradient,
            hessian,
        }
    }

    /// Borrowed gradient, when available.
    pub fn gradient(&self) -> Option<&Array1<f64>> {
        Option::as_ref(&self.gradient)
    }

    /// Borrowed Hessian, when available.
    pub fn hessian(&self) -> Option<&Array2<f64>> {
        Option::as_ref(&self.hessian)
    }
}
/// Exposes the stored derivatives through the projection trait.
///
/// The fields are read directly. Calling `self.gradient()` inside a trait method also
/// named `gradient` works only through inherent-method precedence, and it would turn
/// into silent infinite recursion if the inherent accessor were ever renamed or removed.
impl ThetaCorrectionProjection for ThetaOnlyCorrectionAtom {
    fn gradient(&self) -> Option<&Array1<f64>> {
        self.gradient.as_ref()
    }

    fn hessian(&self) -> Option<&Array2<f64>> {
        self.hessian.as_ref()
    }
}
impl CriterionAtom for ThetaOnlyCorrectionAtom {
    /// The caller-supplied label.
    fn name(&self) -> &'static str {
        self.label
    }

    fn value(&self) -> f64 {
        self.value
    }

    /// The gradient entry for `dir.index`. Returns 0 when there is no gradient, no
    /// index, or the index is out of range.
    fn frozen_d1(&self, dir: &ThetaDirection) -> f64 {
        let Some(gradient) = self.gradient.as_ref() else {
            return 0.0;
        };
        dir.index
            .and_then(|idx| gradient.get(idx).copied())
            .unwrap_or(0.0)
    }

    fn beta_channel(&self) -> Option<BetaChannel> {
        None
    }

    fn stratum(&self) -> Option<StratumFingerprint> {
        None
    }
}
// Unit tests for the criterion atoms in this file. Each test checks an atom's
// `value`, `frozen_d1`, β-channel and stratum declarations against numbers worked out
// by hand from the inputs it constructs. Where noted, it also checks central finite
// differences.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::array;
// HessianLogdetAtom and SampledBlockAtom on one shared direction, then their sum.
// - hessian_logdet frozen_d1 = 0.375: presumably 0.5·tr(h_proj_inverse·Ḣ)
//   = 0.5·(0.5·1.0 + 0.25·1.0) for identity `u_s`. TODO confirm against
//   `PenaltySubspaceTrace::trace_projected_logdet`.
// - sampled frozen_d1 = 1.06 = explicit[0] + Σ_ij Ḣ[i,j]·q_bc[j,i]
//   = 0.2 + (0.5 + 0.03 + 0.03 + 0.3).
// - sampled β-channel · β̇ = 1.0·0.5 + (−2.0)·0.5 = −0.5.
// - CriterionSum::d1 = 0.375 + 1.06 − 0.5 = 0.935.
#[test]
pub(crate) fn hessian_logdet_atom_emits_closed_form_value_and_directional_derivative() {
let kernel = Arc::new(PenaltySubspaceTrace {
u_s: array![[1.0, 0.0], [0.0, 1.0]],
h_proj_inverse: array![[0.5, 0.0], [0.0, 0.25]],
});
let stratum = StratumFingerprint {
kept_rank: 2,
min_relative_eigengap: 0.5,
};
let sensitivity = Arc::new(Sensitivity {
kernel: kernel.clone(),
logdet: 8.0_f64.ln(),
stratum: StratumFingerprint {
kept_rank: stratum.kept_rank,
min_relative_eigengap: stratum.min_relative_eigengap,
},
});
let hess = HessianLogdetAtom {
sensitivity: sensitivity.clone(),
};
assert_eq!(hess.name(), "hessian_logdet");
// value() is half the stored log-determinant.
assert!((hess.value() - 0.5 * 8.0_f64.ln()).abs() < 1e-12);
assert!(
hess.beta_channel().is_none(),
"logdet atom has no β-channel"
);
assert_eq!(hess.stratum().expect("declared stratum").kept_rank, 2);
let h_dot = Arc::new(array![[1.0, 0.3], [0.3, 1.0]]);
let dir = ThetaDirection {
index: Some(0),
beta_dot: Some(Arc::new(array![0.5, 0.5])),
h_dot_total: Some(h_dot.clone()),
};
assert!((hess.frozen_d1(&dir) - 0.375).abs() < 1e-12);
let sampled = SampledBlockAtom {
value: -0.4,
explicit: array![0.2, -0.1],
q_bc: Arc::new(array![[0.5, 0.1], [0.1, 0.3]]),
g_d: array![1.0, -2.0],
stratum,
};
assert!((sampled.value() - (-0.4)).abs() < 1e-12);
assert!((sampled.frozen_d1(&dir) - 1.06).abs() < 1e-12);
assert!(
(sampled
.beta_channel()
.expect("sampled atom declares a β-channel")
.grad_beta
.dot(&array![0.5, 0.5])
- (-0.5))
.abs()
< 1e-12
);
// The sum adds both values, then frozen parts plus the sampled atom's β-chain term.
let sum = CriterionSum {
atoms: vec![Box::new(hess), Box::new(sampled)],
};
assert!((sum.value() - (0.5 * 8.0_f64.ln() - 0.4)).abs() < 1e-12);
assert!((sum.d1(&dir) - 0.935).abs() < 1e-12);
}
// PenaltyQuadAtom built directly from its fields, then via `from_penalty_coords`.
// - value = 0.5·(3·2 + 5·4) = 13.
// - frozen_d1 along ρ_0 = 0.5·3·2 = 3; along ρ_1 = 0.5·5·4 = 10; out-of-range index 7 → 0.
// - β-channel · β̇ = 1.5·0.5 + (−0.5)·0.5 = 0.5, so CriterionSum::d1 = 3 + 0.5 = 3.5.
// - Centered coordinate: assuming the shifted quadratic is ‖R(β − μ)‖² with R = I and
//   μ = [1, 1] (TODO confirm `PenaltyCoordinate` semantics), q = 1² + 2² = 5. Then
//   value = rho_frozen_d1 = 0.5·4·5 = 10, and the score λ·S(β − μ) = 4·[1, 2] = [4, 8].
#[test]
pub(crate) fn penalty_quad_atom_emits_closed_form_value_score_and_directional_derivative() {
let atom = PenaltyQuadAtom {
lambdas: array![3.0, 5.0],
block_quadratics: array![2.0, 4.0],
penalty_score: array![1.5, -0.5],
block_penalty_scores: vec![array![1.0, 0.0], array![0.5, -0.5]],
stable_value: None,
};
assert_eq!(atom.name(), "penalty_quadratic");
assert!((atom.value() - 13.0).abs() < 1e-12);
assert!(
atom.stratum().is_none(),
"the penalty quadratic is C^∞ and declares no stratum boundary"
);
let dir0 = ThetaDirection {
index: Some(0),
beta_dot: Some(Arc::new(array![0.5, 0.5])),
h_dot_total: Some(Arc::new(array![[0.0, 0.0], [0.0, 0.0]])),
};
assert!((atom.frozen_d1(&dir0) - 3.0).abs() < 1e-12);
let dir1 = ThetaDirection {
index: Some(1),
beta_dot: Some(Arc::new(array![0.5, 0.5])),
h_dot_total: Some(Arc::new(array![[0.0, 0.0], [0.0, 0.0]])),
};
assert!((atom.frozen_d1(&dir1) - 10.0).abs() < 1e-12);
// An index past the last block contributes nothing.
let dir_none = ThetaDirection {
index: Some(7),
beta_dot: Some(Arc::new(array![0.5, 0.5])),
h_dot_total: Some(Arc::new(array![[0.0, 0.0], [0.0, 0.0]])),
};
assert!(atom.frozen_d1(&dir_none).abs() < 1e-12);
let channel = atom
.beta_channel()
.expect("penalty quadratic declares a β-channel");
assert!((channel.grad_beta.dot(&array![0.5, 0.5]) - 0.5).abs() < 1e-12);
let sum = CriterionSum {
atoms: vec![Box::new(atom)],
};
assert!((sum.value() - 13.0).abs() < 1e-12);
assert!((sum.d1(&dir0) - 3.5).abs() < 1e-12);
let centered_coord = PenaltyCoordinate::from_dense_root_with_mean(
array![[1.0, 0.0], [0.0, 1.0]],
array![1.0, 1.0],
);
let centered_atom =
PenaltyQuadAtom::from_penalty_coords(&[4.0], &[centered_coord], &array![2.0, 3.0])
.expect("centered penalty atom");
assert!((centered_atom.value() - 10.0).abs() < 1e-12);
assert!((centered_atom.rho_frozen_d1(0) - 10.0).abs() < 1e-12);
let centered_channel = centered_atom
.beta_channel()
.expect("centered penalty atom declares beta channel");
assert_eq!(centered_channel.grad_beta, array![4.0, 8.0]);
assert_eq!(centered_atom.block_penalty_scores()[0], array![4.0, 8.0]);
}
// Stable-value handling for PenaltyQuadAtom:
// - A value-only carrier reports its stable value, has zero frozen derivative, and
//   declares an empty β-channel.
// - For an atom built from coordinates, `with_stable_value` overrides only `value()`.
//   `rho_frozen_d1` must still match a central finite difference of the plain value
//   with respect to ρ = ln λ.
// The expected 7.25 = 0.5·(0.7·4 + 1.3·9) assumes the two diagonal roots give
// q_0 = β_0² = 4 and q_1 = β_1² = 9 for zero means. TODO confirm `PenaltyCoordinate`.
#[test]
pub(crate) fn penalty_quad_atom_stable_value_matches_production_and_gradient_is_rho_derivative()
{
let carrier = PenaltyQuadAtom::stable_value_only(7.25);
assert!((carrier.value() - 7.25).abs() < 1e-12);
assert_eq!(carrier.name(), "penalty_quadratic");
let dir0 = ThetaDirection {
index: Some(0),
beta_dot: Some(Arc::new(array![0.0, 0.0])),
h_dot_total: None,
};
assert_eq!(carrier.frozen_d1(&dir0), 0.0);
assert!(carrier.beta_channel().is_some()); assert_eq!(
carrier.beta_channel().unwrap().grad_beta.len(),
0,
"value-only carrier has no β-channel mass"
);
let beta = array![2.0_f64, 3.0];
let coord0 = PenaltyCoordinate::from_dense_root_with_mean(
array![[1.0, 0.0], [0.0, 0.0]],
array![0.0, 0.0],
);
let coord1 = PenaltyCoordinate::from_dense_root_with_mean(
array![[0.0, 0.0], [0.0, 1.0]],
array![0.0, 0.0],
);
let lambdas = [0.7_f64, 1.3];
let coords = vec![coord0.clone(), coord1.clone()];
let build = |lams: &[f64]| {
PenaltyQuadAtom::from_penalty_coords(lams, &coords, &beta).expect("penalty atom")
};
let stable = 0.5 * (lambdas[0] * 4.0 + lambdas[1] * 9.0);
let atom = build(&lambdas).with_stable_value(stable);
assert!((atom.value() - 7.25).abs() < 1e-12);
assert!((atom.rho_frozen_d1(0) - 0.5 * 0.7 * 4.0).abs() < 1e-12);
assert!((atom.rho_frozen_d1(1) - 0.5 * 1.3 * 9.0).abs() < 1e-12);
// Energy of a freshly built (non-overridden) atom as a function of ρ = ln λ.
let energy_at = |rho: [f64; 2]| -> f64 {
let lams = [rho[0].exp(), rho[1].exp()];
build(&lams).value() };
let rho0 = [lambdas[0].ln(), lambdas[1].ln()];
let h = 1e-6;
// Central differences in each ρ coordinate must match rho_frozen_d1.
for k in 0..2 {
let mut rp = rho0;
let mut rm = rho0;
rp[k] += h;
rm[k] -= h;
let fd = (energy_at(rp) - energy_at(rm)) / (2.0 * h);
assert!(
(fd - atom.rho_frozen_d1(k)).abs() < 1e-6,
"rho_frozen_d1[{k}] {} vs FD-of-value {}",
atom.rho_frozen_d1(k),
fd
);
}
let plain = build(&lambdas);
assert!((plain.value() - 7.25).abs() < 1e-12);
}
// JeffreysLogdetAtom:
// 1. `jeffreys_antiderivative` must differentiate to `floored_inverse`. This is probed
//    by central differences at four λ values relative to the floor and cap (one above
//    the cap, one between floor and cap, one below the floor, one negative).
// 2. With eigenvalues [4, 0.5] the expected value is 0.75·0.5·(ln 4 + ln 0.5). This
//    assumes g(λ) = ln λ in this regime. TODO confirm.
// 3. frozen_d1 along axis 0 = 0.75·0.5·(f(4)·1.0 + f(0.5)·3.0) = 2.34375, presumably
//    with f(λ) = 1/λ (0.25 and 2.0).
// 4. Curvature entries follow −0.5·gate·Σ_ij ψ_ij·A^a_ij·A^b_ij. The hard-coded values
//    presumably use ψ_11 = −1/16, ψ_22 = −4 and ψ_12 = ψ_21 = −0.5 (divided
//    differences of 1/λ). TODO confirm.
// 5. Floor-only drift: the eigenvalues sit below the floor, the reduced drift is zero,
//    and only the floor moves. The expected 1250 is presumably the floor-sensitivity
//    sum for these eigenvalues. TODO confirm against
//    `jeffreys_antiderivative_floor_sensitivity`.
#[test]
pub(crate) fn jeffreys_logdet_atom_emits_consistent_value_and_directional_derivative() {
use super::super::jeffreys_subspace::{floored_inverse, jeffreys_antiderivative};
let floor = 1e-3_f64;
let cap = super::super::jeffreys_subspace::jeffreys_cap(floor);
for &lam in &[cap * 4.0, (floor + cap) * 0.5, floor * 0.5, -0.7_f64] {
let h = 1e-7 * lam.abs().max(1e-3);
let fd = (jeffreys_antiderivative(lam + h, floor)
- jeffreys_antiderivative(lam - h, floor))
/ (2.0 * h);
let analytic = floored_inverse(lam, floor);
assert!(
(fd - analytic).abs() <= 1e-4 * analytic.abs().max(1.0),
"g'(λ) desync at λ={lam}: fd={fd} analytic={analytic}"
);
}
let eigvals = array![4.0_f64, 0.5_f64];
let gate = 0.75_f64;
let stratum = StratumFingerprint {
kept_rank: 2,
min_relative_eigengap: (4.0 - 0.5) / 4.0,
};
let mut reduced_drift = std::collections::HashMap::new();
reduced_drift.insert(0_usize, Arc::new(array![[1.0, 0.2], [0.2, 3.0]]));
reduced_drift.insert(1_usize, Arc::new(array![[2.0, 0.5], [0.5, 1.0]]));
let atom = JeffreysLogdetAtom {
eigvals: eigvals.clone(),
floor,
gate_weight: gate,
reduced_drift,
floor_drift: std::collections::HashMap::new(),
stratum,
};
assert_eq!(atom.name(), "jeffreys_logdet");
let expected_value = gate * 0.5 * (4.0_f64.ln() + 0.5_f64.ln());
assert!(
(atom.value() - expected_value).abs() < 1e-12,
"value {} vs {}",
atom.value(),
expected_value
);
assert!(
atom.beta_channel().is_none(),
"Jeffreys logdet rides the shared drift; no β-channel (like HessianLogdetAtom)"
);
assert_eq!(atom.stratum().expect("declared stratum").kept_rank, 2);
let dir0 = ThetaDirection {
index: Some(0),
beta_dot: Some(Arc::new(array![0.0, 0.0])),
h_dot_total: None,
};
assert!(
(atom.frozen_d1(&dir0) - 2.34375).abs() < 1e-12,
"frozen_d1 {} vs 2.34375",
atom.frozen_d1(&dir0)
);
let hphi = atom
.second_order_curvature(2)
.expect("second-order Jeffreys atom curvature");
assert!((hphi[[0, 0]] - 13.5384375).abs() < 1e-12);
assert!((hphi[[0, 1]] - 4.584375).abs() < 1e-12);
assert!((hphi[[1, 0]] - 4.584375).abs() < 1e-12);
assert!((hphi[[1, 1]] - 1.6875).abs() < 1e-12);
// An axis with no registered drift contributes nothing.
let dir_absent = ThetaDirection {
index: Some(9),
beta_dot: None,
h_dot_total: None,
};
assert!(atom.frozen_d1(&dir_absent).abs() < 1e-12);
let sum = CriterionSum {
atoms: vec![Box::new(atom)],
};
assert!((sum.value() - expected_value).abs() < 1e-12);
assert!((sum.d1(&dir0) - 2.34375).abs() < 1e-12);
let mut reduced_drift = std::collections::HashMap::new();
reduced_drift.insert(1_usize, Arc::new(array![[0.0, 0.0], [0.0, 0.0]]));
let mut floor_drift = std::collections::HashMap::new();
floor_drift.insert(1_usize, 2.0e-4);
let floor_atom = JeffreysLogdetAtom {
eigvals: array![0.5 * floor, 0.25 * floor],
floor,
gate_weight: gate,
reduced_drift,
floor_drift,
stratum: StratumFingerprint {
kept_rank: 2,
min_relative_eigengap: 0.25,
},
};
let dir_floor = ThetaDirection {
index: Some(1),
beta_dot: Some(Arc::new(array![0.0, 0.0])),
h_dot_total: None,
};
let expected_floor_d1 = gate * 0.5 * 1250.0 * 2.0e-4;
assert!(
(floor_atom.frozen_d1(&dir_floor) - expected_floor_d1).abs() < 1e-12,
"floor frozen_d1 {} vs {}",
floor_atom.frozen_d1(&dir_floor),
expected_floor_d1
);
}
// ConfiguredRhoPriorAtom passes the evaluation's cost, gradient entries and Hessian
// through unchanged. An out-of-range index gives a zero frozen derivative.
#[test]
pub(crate) fn configured_rho_prior_atom_projects_one_eval() {
let atom = ConfiguredRhoPriorAtom {
eval: crate::rho_prior_eval::RhoPriorEval {
cost: 1.25,
gradient: array![0.5, -1.5, 2.0],
hessian: Some(array![[3.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 5.0]]),
},
};
assert_eq!(atom.name(), "configured_rho_prior");
assert!((atom.value() - 1.25).abs() < 1e-12);
assert!(
atom.beta_channel().is_none(),
"rho prior is theta-only and declares no β-channel"
);
assert!(
atom.stratum().is_none(),
"rho prior is smooth on its configured-valid branch"
);
let dir1 = ThetaDirection {
index: Some(1),
beta_dot: None,
h_dot_total: None,
};
assert!((atom.frozen_d1(&dir1) - (-1.5)).abs() < 1e-12);
let dir_absent = ThetaDirection {
index: Some(9),
beta_dot: None,
h_dot_total: None,
};
assert!(atom.frozen_d1(&dir_absent).abs() < 1e-12);
assert_eq!(atom.gradient(), &array![0.5, -1.5, 2.0]);
assert_eq!(atom.hessian().expect("configured Hessian")[[2, 2]], 5.0);
}
// TierneyKadaneAtom and ThetaOnlyCorrectionAtom project the same correction terms
// identically. Value, per-index gradient and Hessian pass through, and neither
// declares a β-channel or stratum.
#[test]
pub(crate) fn theta_only_correction_atom_projects_value_gradient_and_hessian() {
use crate::solver::estimate::reml::outer_eval::TkCorrectionTerms;
let tk_atom = TierneyKadaneAtom::from_terms(TkCorrectionTerms {
value: -0.75,
gradient: Some(array![0.25, -0.5]),
hessian: Some(array![[2.0, 0.1], [0.1, 3.0]]),
});
assert_eq!(tk_atom.name(), "tierney_kadane");
assert!((tk_atom.value() - (-0.75)).abs() < 1e-12);
assert!(tk_atom.beta_channel().is_none());
assert!(tk_atom.stratum().is_none());
let dir1 = ThetaDirection {
index: Some(1),
beta_dot: None,
h_dot_total: None,
};
assert!((tk_atom.frozen_d1(&dir1) - (-0.5)).abs() < 1e-12);
assert_eq!(tk_atom.gradient().expect("gradient"), &array![0.25, -0.5]);
assert_eq!(tk_atom.hessian().expect("hessian")[[1, 1]], 3.0);
let atom = ThetaOnlyCorrectionAtom {
label: "sampled_block_marginal",
value: -0.75,
gradient: Some(array![0.25, -0.5]),
hessian: Some(array![[2.0, 0.1], [0.1, 3.0]]),
};
assert_eq!(atom.name(), "sampled_block_marginal");
assert!((atom.value() - (-0.75)).abs() < 1e-12);
assert!(atom.beta_channel().is_none());
assert!(atom.stratum().is_none());
assert!((atom.frozen_d1(&dir1) - (-0.5)).abs() < 1e-12);
let dir_absent = ThetaDirection {
index: Some(7),
beta_dot: None,
h_dot_total: None,
};
assert!(atom.frozen_d1(&dir_absent).abs() < 1e-12);
assert_eq!(atom.gradient().expect("gradient"), &array![0.25, -0.5]);
assert_eq!(atom.hessian().expect("hessian")[[1, 1]], 3.0);
}
// SoftRhoGuardPriorAtom:
// - value, gradient and diagonal Hessian match the closed forms w·ln cosh(a(ρ−c)),
//   w·a·tanh(·) and w·a²(1 − tanh²(·)), with a = sharpness / bound;
// - central differences of the value and of the gradient agree with the analytic
//   gradient and Hessian;
// - zero weight and empty ρ both give value 0 and no Hessian.
#[test]
pub(crate) fn soft_rho_guard_prior_atom_value_gradient_hessian_are_one_chain() {
let (w, sharp, bound, anchor) = (2.0_f64, 4.0_f64, 8.0_f64, 0.5_f64);
let rho = array![1.5_f64, 0.5_f64, -1.5_f64];
let a = sharp / bound;
let atom = SoftRhoGuardPriorAtom::evaluate_anchored(&rho, w, sharp, bound, anchor);
assert_eq!(atom.name(), "soft_rho_guard_prior");
assert!(
atom.beta_channel().is_none(),
"soft guard prior is θ-only and separable"
);
assert!(atom.stratum().is_none(), "smooth everywhere");
let expected_value: f64 = w * rho
.iter()
.map(|&r| (a * (r - anchor)).cosh().ln())
.sum::<f64>();
assert!(
(atom.value() - expected_value).abs() < 1e-12,
"value {} vs {}",
atom.value(),
expected_value
);
assert!((atom.cost() - expected_value).abs() < 1e-12);
for (i, &r) in rho.iter().enumerate() {
let g = w * a * (a * (r - anchor)).tanh();
assert!(
(atom.gradient()[i] - g).abs() < 1e-12,
"grad[{i}] {} vs {}",
atom.gradient()[i],
g
);
let dir = ThetaDirection {
index: Some(i),
beta_dot: None,
h_dot_total: None,
};
assert!((atom.frozen_d1(&dir) - g).abs() < 1e-12);
}
let hess = atom.hessian().expect("nonzero curvature");
for i in 0..rho.len() {
let t = (a * (rho[i] - anchor)).tanh();
let h = w * a * a * (1.0 - t * t);
assert!((hess[[i, i]] - h).abs() < 1e-12, "hess[{i},{i}]");
for j in 0..rho.len() {
if i != j {
assert_eq!(hess[[i, j]], 0.0, "off-diagonal must be zero");
}
}
}
// Finite-difference cross-checks: FD of the value vs gradient, FD of the gradient vs Hessian.
let step = 1e-6;
for i in 0..rho.len() {
let mut rp = rho.clone();
let mut rm = rho.clone();
rp[i] += step;
rm[i] -= step;
let vp = SoftRhoGuardPriorAtom::evaluate_anchored(&rp, w, sharp, bound, anchor).value();
let vm = SoftRhoGuardPriorAtom::evaluate_anchored(&rm, w, sharp, bound, anchor).value();
let fd_grad = (vp - vm) / (2.0 * step);
assert!(
(fd_grad - atom.gradient()[i]).abs() < 1e-6,
"FD grad[{i}] {} vs analytic {}",
fd_grad,
atom.gradient()[i]
);
let gp = SoftRhoGuardPriorAtom::evaluate_anchored(&rp, w, sharp, bound, anchor)
.gradient()[i];
let gm = SoftRhoGuardPriorAtom::evaluate_anchored(&rm, w, sharp, bound, anchor)
.gradient()[i];
let fd_hess = (gp - gm) / (2.0 * step);
assert!(
(fd_hess - hess[[i, i]]).abs() < 1e-6,
"FD hess[{i}] {} vs analytic {}",
fd_hess,
hess[[i, i]]
);
}
// Degenerate inputs: zero weight and empty ρ take the early-return path.
let zero_w = SoftRhoGuardPriorAtom::evaluate_anchored(&rho, 0.0, sharp, bound, 0.0);
assert_eq!(zero_w.value(), 0.0);
assert!(zero_w.hessian().is_none());
let empty = SoftRhoGuardPriorAtom::evaluate_anchored(
&Array1::<f64>::zeros(0),
w,
sharp,
bound,
0.0,
);
assert_eq!(empty.value(), 0.0);
assert!(empty.hessian().is_none());
let dir0 = ThetaDirection {
index: Some(0),
beta_dot: Some(Arc::new(array![0.0, 0.0])),
h_dot_total: None,
};
let g0 = w * a * (a * (rho[0] - anchor)).tanh();
let sum = CriterionSum {
atoms: vec![Box::new(atom)],
};
assert!((sum.value() - expected_value).abs() < 1e-12);
assert!((sum.d1(&dir0) - g0).abs() < 1e-12);
}
// End-to-end: fill_direction produces β̇ and Ḣ_total, which then feed a CriterionSum.
// - β̇ = [−0.5, 0.5] is consistent with mode_response returning −H⁻¹·f for
//   H = L·Lᵀ = diag(2, 4) and f = [1, −2]. TODO confirm FitSensitivity's sign convention.
// - Ḣ_total = I + diag(β̇) = diag(0.5, 1.5), following the test's `cubic_drift` closure.
// - hessian_logdet frozen_d1 = 0.3125: presumably 0.5·(0.5·0.5 + 0.25·1.5).
// - Penalty: frozen along ρ_0 = 0.5·3·2 = 3; chained [2, 1]·β̇ = −0.5.
// - Total profiled d1 = 0.3125 + 3 − 0.5 = 2.8125.
#[test]
pub(crate) fn sensitivity_fill_direction_feeds_criterion_sum_end_to_end() {
use crate::solver::sensitivity::FitSensitivity;
let lower = array![[2.0_f64.sqrt(), 0.0], [0.0, 2.0]];
let op = FitSensitivity::from_lower_triangular(&lower);
let kernel = Arc::new(PenaltySubspaceTrace {
u_s: array![[1.0, 0.0], [0.0, 1.0]],
h_proj_inverse: array![[0.5, 0.0], [0.0, 0.25]],
});
let sensitivity = Arc::new(Sensitivity {
kernel: kernel.clone(),
logdet: 8.0_f64.ln(),
stratum: StratumFingerprint {
kept_rank: 2,
min_relative_eigengap: 0.5,
},
});
let f_beta_theta = array![1.0, -2.0];
let h_dot_frozen = array![[1.0, 0.0], [0.0, 1.0]];
let dir = sensitivity
.fill_direction(0, &op, &f_beta_theta, &h_dot_frozen, |beta_dot| {
Array2::from_diag(beta_dot)
})
.expect("finite mode response");
let beta_dot = dir.beta_dot.as_ref().expect("filled β̇");
assert!((beta_dot[0] - (-0.5)).abs() < 1e-12);
assert!((beta_dot[1] - 0.5).abs() < 1e-12);
let h_dot = dir.h_dot_total.as_ref().expect("filled Ḣ_total");
assert!((h_dot[[0, 0]] - 0.5).abs() < 1e-12);
assert!((h_dot[[1, 1]] - 1.5).abs() < 1e-12);
let hess = HessianLogdetAtom {
sensitivity: sensitivity.clone(),
};
let pen = PenaltyQuadAtom {
lambdas: array![3.0, 5.0],
block_quadratics: array![2.0, 4.0],
penalty_score: array![2.0, 1.0],
block_penalty_scores: vec![array![2.0, 0.0], array![0.0, 1.0]],
stable_value: None,
};
assert!((hess.frozen_d1(&dir) - 0.3125).abs() < 1e-12);
let sum = CriterionSum {
atoms: vec![Box::new(hess), Box::new(pen)],
};
assert!(
(sum.d1(&dir) - 2.8125).abs() < 1e-12,
"profiled d1 {} vs 2.8125",
sum.d1(&dir)
);
}
}