use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::linalg::faer_ndarray::FaerEigh;
use crate::solver::evidence::{
HybridAtomCandidate, HybridAtomChoice, HybridSplitSelection, select_hybrid_split,
};
use crate::terms::latent::LatentManifold;
use crate::terms::sae::chart_canonicalization::d1_atom_fitted_turning;
use crate::terms::sae::manifold::SaeManifoldAtom;
/// Reduced Laplace negative log-evidence: `residual_objective + ½·log|H|`.
///
/// `residual_objective` is the data-fit term (in this file callers pass `½·RSS`)
/// and `log_det_h` is the log-determinant of the curvature matrix `H`.
fn reduced_laplace_nle(residual_objective: f64, log_det_h: f64) -> f64 {
    let occam_penalty = 0.5 * log_det_h;
    residual_objective + occam_penalty
}
/// Log-determinant term for the curved candidate: `p · log|ΦᵀWΦ|`, with
/// `W = diag(assign²)`. This equals `log|I_p ⊗ ΦᵀWΦ|`.
///
/// Only eigenvalues strictly above `1e-12 · λ_max` are counted. A
/// rank-deficient design therefore reports the log pseudo-determinant over its
/// positive directions.
///
/// Returns `None` in any of these cases:
/// - `phi` has no rows or no columns, or `assign` does not have one entry per row of `phi`;
/// - a squared assignment or any Gram entry is non-finite;
/// - the eigendecomposition fails;
/// - the largest eigenvalue is not strictly positive and finite.
fn curved_design_gram_logdet(
    phi: ArrayView2<'_, f64>,
    assign: ArrayView1<'_, f64>,
    p: usize,
) -> Option<f64> {
    let (n, m) = phi.dim();
    if n == 0 || m == 0 || assign.len() != n {
        return None;
    }

    // Accumulate the upper triangle of Σ_i w_i φ_i φ_iᵀ; zero-weight rows add nothing.
    let mut gram = Array2::<f64>::zeros((m, m));
    for (row, &a) in phi.outer_iter().zip(assign.iter()) {
        let weight = a * a;
        if !weight.is_finite() || weight < 0.0 {
            return None;
        }
        if weight == 0.0 {
            continue;
        }
        for col_a in 0..m {
            let scaled = weight * row[col_a];
            for col_b in col_a..m {
                gram[[col_a, col_b]] += scaled * row[col_b];
            }
        }
    }

    // Mirror into the lower triangle, which is the side handed to the eigensolver.
    for col_a in 1..m {
        for col_b in 0..col_a {
            gram[[col_a, col_b]] = gram[[col_b, col_a]];
        }
    }
    if !gram.iter().all(|entry| entry.is_finite()) {
        return None;
    }

    let (eigenvalues, _eigenvectors) = gram.eigh(faer::Side::Lower).ok()?;
    let lambda_max = eigenvalues.iter().copied().fold(0.0_f64, f64::max);
    if !(lambda_max.is_finite() && lambda_max > 0.0) {
        return None;
    }

    // Relative floor: eigenvalues at round-off level are treated as null directions.
    let floor = lambda_max * 1e-12;
    let (log_det, rank) = eigenvalues
        .iter()
        .filter(|&&lambda| lambda > floor)
        .fold((0.0_f64, 0usize), |(acc, count), &lambda| {
            (acc + lambda.ln(), count + 1)
        });
    if rank == 0 || !log_det.is_finite() {
        return None;
    }
    Some(p as f64 * log_det)
}
/// Weighted secant line fitted to one d=1 atom's response residual.
///
/// The line is parameterized about the centering coordinate `t_bar`, so
/// coordinate `t` maps to `b0 + (t − t_bar)·b1`.
#[derive(Clone, Debug)]
pub struct AtomLinearImage {
    /// Index of the atom in the `atoms` slice given to `build_hybrid_split_report`.
    pub atom_idx: usize,
    /// Centering coordinate of the line (the assignment-weighted mean coordinate).
    pub t_bar: f64,
    /// Value of the line at `t_bar`, one entry per output column.
    pub b0: Array1<f64>,
    /// Slope of the line with respect to the latent coordinate, one entry per output column.
    pub b1: Array1<f64>,
}
impl AtomLinearImage {
    /// Writes `b0[j] + (t − t_bar)·b1[j]` into `out[j]` for every slot of `out`.
    ///
    /// The value is not multiplied by any assignment weight.
    ///
    /// # Panics
    ///
    /// Panics if `out` is longer than `b0` or `b1`. A shorter `out` receives
    /// only the leading columns.
    pub fn fill_row(&self, t: f64, out: &mut [f64]) {
        let offset = t - self.t_bar;
        out.iter_mut()
            .enumerate()
            .for_each(|(col, value)| *value = self.b0[col] + offset * self.b1[col]);
    }
}
/// Per-atom outcome of the hybrid linear/curved split.
#[derive(Clone, Debug)]
pub struct AtomHybridVerdict {
    /// Name of the atom, copied from `SaeManifoldAtom::name`.
    pub atom_name: String,
    /// Choice returned by `select_hybrid_split` for this atom.
    pub choice: HybridAtomChoice,
    /// `true` when the chosen parameterization is not the linear one.
    pub kept_curved: bool,
    /// Fitted turning of the atom's curved image, when it could be computed.
    pub fitted_turning: Option<f64>,
    /// Caller-supplied ΔEV for this atom, passed through unchanged.
    pub train_loao_delta_ev: Option<f64>,
    /// Secant line for the atom; `Some` only when the linear choice was selected.
    pub linear_image: Option<AtomLinearImage>,
}
/// Result of scoring and selecting the hybrid split over the eligible atoms.
#[derive(Clone, Debug)]
pub struct SaeHybridSplitReport {
    /// One verdict per atom that produced a candidate pair, in iteration order.
    pub verdicts: Vec<AtomHybridVerdict>,
    /// Selection returned by `select_hybrid_split`, kept verbatim.
    pub selection: HybridSplitSelection,
}
/// Minimum number of rows before `build_atom_candidates` fits a secant. A line
/// has two coefficients per output column, so at least one residual degree of
/// freedom is required.
const MIN_ROWS_FOR_LINEAR_FIT: usize = 3;
/// Builds the `(linear, curved)` candidate pair for one d=1 atom.
///
/// The third element of the result is the secant line `(t̄, b0, b1)` that backs
/// the linear candidate.
///
/// With row weights `w_i = assign_i²`, the two candidates are:
/// - **Linear:** fits `target_resid[i, :] ≈ assign_i · (b0 + (t_i − t̄)·b1)` by
///   weighted least squares. `t̄` is the weighted mean coordinate, so the
///   intercept and slope decouple.
/// - **Curved:** reuses the atom's decoded image,
///   `target_resid[i, :] ≈ assign_i · decoded[i, :]`.
///
/// Both are scored with the reduced Laplace negative log-evidence
/// `½·RSS + ½·log|H|`:
/// - Linear: `log|H| = p·(log Σw + log S_tt)`.
/// - Curved: `log|H| = p·log|ΦᵀWΦ|` when a row-aligned `curved_phi` is given
///   and non-degenerate; otherwise `curved_num_params · log Σw`.
///
/// Returns `None` in any of these cases:
/// - fewer than [`MIN_ROWS_FOR_LINEAR_FIT`] rows;
/// - mismatched shapes, or `p == 0`;
/// - a negative or non-finite assignment;
/// - zero or non-finite total mass;
/// - a degenerate or non-finite weighted coordinate spread;
/// - a non-finite evidence for either candidate.
fn build_atom_candidates(
    coords: ArrayView1<'_, f64>,
    assign: ArrayView1<'_, f64>,
    decoded: ArrayView2<'_, f64>,
    target_resid: ArrayView2<'_, f64>,
    curved_num_params: usize,
    curved_phi: Option<ArrayView2<'_, f64>>,
    fitted_turning: Option<f64>,
) -> Option<(
    HybridAtomCandidate,
    HybridAtomCandidate,
    (f64, Array1<f64>, Array1<f64>),
)> {
    let n = coords.len();
    let p = decoded.ncols();
    if n < MIN_ROWS_FOR_LINEAR_FIT
        || decoded.nrows() != n
        || assign.len() != n
        || target_resid.nrows() != n
        || target_resid.ncols() != p
        || p == 0
    {
        return None;
    }

    // Total mass Σw and weighted mean coordinate t̄ under w_i = assign_i².
    let mut w_sum = 0.0_f64;
    let mut t_bar = 0.0_f64;
    for i in 0..n {
        let a = assign[i];
        if !(a.is_finite() && a >= 0.0) {
            return None;
        }
        let w = a * a;
        w_sum += w;
        t_bar += w * coords[i];
    }
    // Fail fast, before any O(n·p) work: ln(Σw) enters both log-determinants,
    // so the mass must be positive and finite.
    if !(w_sum > 0.0 && w_sum.is_finite()) {
        return None;
    }
    t_bar /= w_sum;

    let mut s_tt = 0.0_f64;
    for i in 0..n {
        let dt = coords[i] - t_bar;
        s_tt += assign[i] * assign[i] * dt * dt;
    }
    // S_tt = Σw·(t − t̄)² grows with the total mass. Degeneracy is therefore
    // judged on the weighted variance S_tt/Σw relative to the coordinate scale;
    // otherwise the verdict would depend on how much assignment mass the atom
    // carries. Non-finite coordinates propagate NaN/∞ into t̄ or S_tt and are
    // refused here as well.
    if !(s_tt.is_finite() && s_tt > 1e-12 * w_sum * (1.0 + t_bar * t_bar)) {
        return None;
    }

    // Closed-form weighted least squares, column by column:
    //   b0 = Σ a·y / Σw,   b1 = Σ a·(t − t̄)·y / S_tt.
    let mut b0 = Array1::<f64>::zeros(p);
    let mut b1 = Array1::<f64>::zeros(p);
    for j in 0..p {
        let mut s_1y = 0.0_f64;
        let mut s_ty = 0.0_f64;
        for i in 0..n {
            let a = assign[i];
            let dt = coords[i] - t_bar;
            let y = target_resid[[i, j]];
            s_1y += a * y;
            s_ty += a * dt * y;
        }
        b0[j] = s_1y / w_sum;
        b1[j] = s_ty / s_tt;
    }

    // Residual sums of squares of both candidates on the same rows and columns.
    let mut curved_rss = 0.0_f64;
    let mut linear_rss = 0.0_f64;
    for i in 0..n {
        let a = assign[i];
        let dt = coords[i] - t_bar;
        for j in 0..p {
            let y = target_resid[[i, j]];
            let r_curved = y - a * decoded[[i, j]];
            curved_rss += r_curved * r_curved;
            let r_linear = y - a * (b0[j] + dt * b1[j]);
            linear_rss += r_linear * r_linear;
        }
    }
    let curved_residual_objective = 0.5 * curved_rss;
    let linear_residual_objective = 0.5 * linear_rss;
    let linear_num_params = 2 * p;

    // Linear Gram of the design rows a·[1, t − t̄] is diag(Σw, S_tt), replicated over p columns.
    let linear_log_det_h = (p as f64) * (w_sum.ln() + s_tt.ln());
    let curved_log_det_h = curved_phi
        .and_then(|phi| {
            if phi.nrows() == n {
                curved_design_gram_logdet(phi, assign, p)
            } else {
                None
            }
        })
        .unwrap_or_else(|| (curved_num_params as f64) * w_sum.ln());

    let linear_nle = reduced_laplace_nle(linear_residual_objective, linear_log_det_h);
    let curved_nle = reduced_laplace_nle(curved_residual_objective, curved_log_det_h);
    if !(linear_nle.is_finite() && curved_nle.is_finite()) {
        return None;
    }
    let linear = HybridAtomCandidate::linear(linear_nle, linear_num_params);
    let curved = HybridAtomCandidate::curved(1, curved_nle, curved_num_params, fitted_turning);
    Some((linear, curved, (t_bar, b0, b1)))
}
/// Scores each eligible d=1 atom as a linear secant versus its curved decoder
/// image, then jointly selects which atoms keep the curved parameterization via
/// [`select_hybrid_split`].
///
/// For each index yielded by `eligible_d1`, the per-atom closures supply the
/// latent coordinates, assignment magnitudes, decoded image, target residual,
/// latent manifold and an optional ΔEV (reported verbatim as
/// `train_loao_delta_ev`).
///
/// Atoms for which [`build_atom_candidates`] returns `None` are skipped. For
/// those atoms, `manifold_for` and `delta_ev_for` are never called. Atoms on a
/// Euclidean manifold are offered only the linear candidate.
///
/// Returns `Ok(None)` when no atom yields a candidate pair. Otherwise returns
/// `Ok(Some(report))` with one verdict per scored atom, in iteration order. A
/// verdict carries its linear image only when the linear choice won.
///
/// # Errors
///
/// Returns `Err` in any of these cases:
/// - an eligible index is out of range for `atoms`;
/// - [`select_hybrid_split`] fails;
/// - the selection does not contain exactly one choice per scored atom.
pub fn build_hybrid_split_report<'a, C, W, D, R, M, E>(
    atoms: &'a [SaeManifoldAtom],
    eligible_d1: impl Iterator<Item = usize>,
    mut coords_for: C,
    mut assign_for: W,
    mut decoded_for: D,
    mut target_resid_for: R,
    mut manifold_for: M,
    mut delta_ev_for: E,
) -> Result<Option<SaeHybridSplitReport>, String>
where
    C: FnMut(usize) -> Array1<f64>,
    W: FnMut(usize) -> Array1<f64>,
    D: FnMut(usize) -> Array2<f64>,
    R: FnMut(usize) -> Array2<f64>,
    M: FnMut(usize) -> LatentManifold,
    E: FnMut(usize) -> Option<f64>,
{
    // Bookkeeping for one scored atom, kept in lock-step with its candidate slot.
    struct ScoredAtom {
        name: String,
        linear_image: AtomLinearImage,
        fitted_turning: Option<f64>,
        train_loao_delta_ev: Option<f64>,
    }

    let mut slots: Vec<Vec<HybridAtomCandidate>> = Vec::new();
    let mut scored: Vec<ScoredAtom> = Vec::new();
    for atom_idx in eligible_d1 {
        // A bad index is a caller error; report it instead of panicking.
        let atom = atoms.get(atom_idx).ok_or_else(|| {
            format!(
                "hybrid split: eligible atom index {atom_idx} is out of range for {} atoms",
                atoms.len()
            )
        })?;
        let coords = coords_for(atom_idx);
        let assign = assign_for(atom_idx);
        let decoded = decoded_for(atom_idx);
        let target_resid = target_resid_for(atom_idx);
        let curved_num_params = atom.decoder_coefficients.len();
        let fitted_turning = atom.basis_evaluator.as_ref().and_then(|evaluator| {
            d1_atom_fitted_turning(
                evaluator.as_ref(),
                atom.decoder_coefficients.view(),
                coords.view(),
            )
            .ok()
            .flatten()
        });
        // The basis evaluator consumes an (n, 1) coordinate column; a reshaped
        // view avoids copying the coordinates.
        let coords_col = coords
            .view()
            .into_shape_with_order((coords.len(), 1))
            .ok();
        let curved_phi = match (atom.basis_evaluator.as_ref(), coords_col.as_ref()) {
            (Some(evaluator), Some(col)) => {
                evaluator.evaluate(col.view()).ok().map(|(phi, _jet)| phi)
            }
            _ => None,
        };
        let Some((linear, curved, (t_bar, b0, b1))) = build_atom_candidates(
            coords.view(),
            assign.view(),
            decoded.view(),
            target_resid.view(),
            curved_num_params,
            curved_phi.as_ref().map(|phi| phi.view()),
            fitted_turning,
        ) else {
            continue;
        };
        // Atoms on a Euclidean manifold are offered only the linear candidate.
        let slot = if manifold_for(atom_idx).is_euclidean() {
            vec![linear]
        } else {
            vec![linear, curved]
        };
        slots.push(slot);
        scored.push(ScoredAtom {
            name: atom.name.clone(),
            linear_image: AtomLinearImage {
                atom_idx,
                t_bar,
                b0,
                b1,
            },
            fitted_turning,
            train_loao_delta_ev: delta_ev_for(atom_idx),
        });
    }
    if slots.is_empty() {
        return Ok(None);
    }

    let selection = select_hybrid_split(&slots)?;
    // Verdicts are paired positionally with the selection. A length mismatch
    // would silently drop or misattribute atoms, so surface it as an error.
    if selection.atoms.len() != scored.len() {
        return Err(format!(
            "hybrid split selection returned {} choices for {} scored atoms",
            selection.atoms.len(),
            scored.len()
        ));
    }
    let verdicts: Vec<AtomHybridVerdict> = scored
        .into_iter()
        .zip(selection.atoms.iter().copied())
        .map(|(atom, choice)| {
            let kept_curved = !choice.param.is_linear();
            AtomHybridVerdict {
                atom_name: atom.name,
                choice,
                kept_curved,
                fitted_turning: atom.fitted_turning,
                train_loao_delta_ev: atom.train_loao_delta_ev,
                // The secant is reported only when it replaces the curved image.
                linear_image: (!kept_curved).then_some(atom.linear_image),
            }
        })
        .collect();
    Ok(Some(SaeHybridSplitReport {
        verdicts,
        selection,
    }))
}
// Unit tests for the candidate builder and the curved Gram log-determinant.
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    // A residual lying exactly on a line: both candidates reproduce it, so the
    // comparison comes down to the log-determinant terms, and the 10-parameter
    // curved fallback (no Φ supplied) pays more than the 2p-parameter secant.
    #[test]
    fn straight_residual_selects_linear() {
        let n = 40;
        let coords = Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
        let assign = Array1::<f64>::ones(n);
        let mut data = Array2::<f64>::zeros((n, 2));
        let mut decoded = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            data[[i, 0]] = coords[i];
            data[[i, 1]] = 0.6 * coords[i];
            decoded[[i, 0]] = coords[i];
            decoded[[i, 1]] = 0.6 * coords[i];
        }
        let (linear, curved, _) = build_atom_candidates(
            coords.view(),
            assign.view(),
            decoded.view(),
            data.view(),
            10,
            None,
            Some(0.0),
        )
        .expect("straight residual yields a candidate pair");
        let choice =
            crate::solver::evidence::select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
        assert!(
            choice.param.is_linear(),
            "a straight response residual must keep the linear special case"
        );
    }

    // A full circle in the residual: the decoded image matches it exactly,
    // while the best secant leaves a large RSS, so curved must win on evidence.
    #[test]
    fn turning_residual_selects_curved_on_evidence() {
        let n = 60;
        let coords = Array1::from_iter((0..n).map(|i| (i as f64) / ((n - 1) as f64)));
        let assign = Array1::<f64>::ones(n);
        let mut data = Array2::<f64>::zeros((n, 2));
        let mut decoded = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            let theta = 2.0 * PI * coords[i];
            data[[i, 0]] = theta.cos();
            data[[i, 1]] = theta.sin();
            decoded[[i, 0]] = theta.cos();
            decoded[[i, 1]] = theta.sin();
        }
        let (linear, curved, _) = build_atom_candidates(
            coords.view(),
            assign.view(),
            decoded.view(),
            data.view(),
            5,
            None,
            Some(2.0 * PI),
        )
        .expect("turning residual yields a candidate pair");
        assert!(
            linear.negative_log_evidence > curved.negative_log_evidence,
            "the line must misfit the circular residual worse than the curve does \
             (linear NLE {} should exceed curved NLE {})",
            linear.negative_log_evidence,
            curved.negative_log_evidence
        );
        let choice =
            crate::solver::evidence::select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
        assert_eq!(
            choice.param,
            crate::solver::evidence::HybridAtomParam::Curved { latent_dim: 1 },
            "a full-circle response residual must keep the curved parameterization"
        );
        assert!(
            choice.curved_evidence_margin > 0.0,
            "curved must win a positive evidence margin over the linear secant"
        );
    }

    // The residual is a straight line but the decoded image bends (second
    // column is t² instead of 0.5·t), so the curved candidate fits worse than
    // the secant on the same data and must lose.
    #[test]
    fn linear_beats_curved_when_curve_misfits_residual() {
        let n = 50;
        let coords = Array1::from_iter((0..n).map(|i| (i as f64) / ((n - 1) as f64)));
        let assign = Array1::<f64>::ones(n);
        let mut data = Array2::<f64>::zeros((n, 2));
        let mut decoded = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            let t = coords[i];
            data[[i, 0]] = t;
            data[[i, 1]] = 0.5 * t;
            decoded[[i, 0]] = t;
            decoded[[i, 1]] = t * t;
        }
        let (linear, curved, _) = build_atom_candidates(
            coords.view(),
            assign.view(),
            decoded.view(),
            data.view(),
            6,
            None,
            Some(1.0),
        )
        .expect("candidate pair");
        let choice =
            crate::solver::evidence::select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
        assert!(
            choice.param.is_linear(),
            "a curved image that fits the data worse than its own line must yield \
             to the linear special case on common-data evidence (#1202)"
        );
    }

    // Pins the linear log-determinant p·(log Σw + log S_tt). The data is
    // exactly assign·line(t), so the secant's RSS is ~0 and 2·NLE recovers log|H|.
    #[test]
    fn linear_logdet_includes_weighted_coordinate_spread() {
        let n = 40;
        let p = 2usize;
        let logdet = |coords: &Array1<f64>, assign: &Array1<f64>| -> f64 {
            let line = |t: f64| -> [f64; 2] { [t, 0.6 * t] };
            let mut decoded = Array2::<f64>::zeros((n, p));
            let mut data = Array2::<f64>::zeros((n, p));
            for i in 0..n {
                let l = line(coords[i]);
                decoded[[i, 0]] = l[0];
                decoded[[i, 1]] = l[1];
                data[[i, 0]] = assign[i] * l[0];
                data[[i, 1]] = assign[i] * l[1];
            }
            let (linear, _curved, _) = build_atom_candidates(
                coords.view(),
                assign.view(),
                decoded.view(),
                data.view(),
                10,
                None,
                Some(0.0),
            )
            .expect("straight residual yields a pair");
            2.0 * linear.negative_log_evidence
        };
        let base_coords =
            Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
        let ones = Array1::<f64>::ones(n);
        // Doubling the coordinates quadruples S_tt and leaves Σw unchanged.
        let wide_coords = base_coords.mapv(|t| 2.0 * t);
        let d_spread = logdet(&wide_coords, &ones) - logdet(&base_coords, &ones);
        assert!(
            (d_spread - (p as f64) * 4.0_f64.ln()).abs() < 1e-9,
            "linear logdet must move by p·log(4) when coordinate spread doubles \
             (got {d_spread}); the spread term log(s_tt) must be present"
        );
        // Doubling every assignment quadruples both Σw and S_tt.
        let twos = Array1::<f64>::from_elem(n, 2.0);
        let d_weight = logdet(&base_coords, &twos) - logdet(&base_coords, &ones);
        assert!(
            (d_weight - 2.0 * (p as f64) * 4.0_f64.ln()).abs() < 1e-9,
            "linear logdet must move by 2p·log(4) when all assignment masses double \
             (got {d_weight})"
        );
    }

    // With Φ = [1, t − t̄], the weighted Gram is diag(Σw, S_tt), so the curved
    // log-determinant must equal the linear one. A duplicated column gives the
    // rank-1 Gram Σw·[[1,1],[1,1]], whose only positive eigenvalue is 2·Σw.
    #[test]
    fn curved_gram_logdet_is_real_weighted_design_determinant() {
        let n = 40;
        let p = 3usize;
        let coords = Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
        let assign = Array1::<f64>::from_iter((0..n).map(|i| 0.5 + 0.01 * (i as f64)));
        let mut w_sum = 0.0;
        let mut t_bar = 0.0;
        for i in 0..n {
            let w = assign[i] * assign[i];
            w_sum += w;
            t_bar += w * coords[i];
        }
        t_bar /= w_sum;
        let mut s_tt = 0.0;
        for i in 0..n {
            let dt = coords[i] - t_bar;
            s_tt += assign[i] * assign[i] * dt * dt;
        }
        let mut phi = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            phi[[i, 0]] = 1.0;
            phi[[i, 1]] = coords[i] - t_bar;
        }
        let got = curved_design_gram_logdet(phi.view(), assign.view(), p)
            .expect("non-degenerate curved design has a determinant");
        let want = (p as f64) * (w_sum.ln() + s_tt.ln());
        assert!(
            (got - want).abs() < 1e-9,
            "curved Gram logdet must be the real p·log|ΦᵀWΦ| = {want}, got {got}"
        );
        let mut phi_dup = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            phi_dup[[i, 0]] = 1.0;
            phi_dup[[i, 1]] = 1.0;
        }
        let got_dup = curved_design_gram_logdet(phi_dup.view(), assign.view(), p)
            .expect("rank-1 design still has a positive determinant");
        let want_dup = (p as f64) * (2.0 * w_sum).ln();
        assert!(
            (got_dup - want_dup).abs() < 1e-9,
            "rank-deficient curved Gram must report only its positive direction \
             (p·log(2·w_sum) = {want_dup}), got {got_dup}"
        );
    }

    // All coordinates identical ⇒ S_tt = 0, so no secant can be fitted.
    #[test]
    fn degenerate_coordinate_is_refused() {
        let n = 5;
        let coords = Array1::<f64>::from_elem(n, 0.5);
        let assign = Array1::<f64>::ones(n);
        let decoded = Array2::<f64>::zeros((n, 2));
        let data = Array2::<f64>::zeros((n, 2));
        assert!(
            build_atom_candidates(
                coords.view(),
                assign.view(),
                decoded.view(),
                data.view(),
                6,
                None,
                Some(0.0)
            )
            .is_none(),
            "a degenerate coordinate span must be refused"
        );
    }
}