use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::solver::evidence::{
HybridAtomCandidate, HybridAtomChoice, HybridSplitSelection, select_hybrid_split,
};
use crate::terms::latent_coord::LatentManifold;
use crate::terms::sae_chart_canonicalization::d1_atom_fitted_turning;
use crate::terms::sae_manifold::SaeManifoldAtom;
/// Reduced negative log-evidence used to score hybrid candidates: the residual
/// objective plus half of the log-determinant term.
fn reduced_laplace_nle(residual_objective: f64, log_det_h: f64) -> f64 {
    let half_log_det = log_det_h / 2.0;
    residual_objective + half_log_det
}
/// Linear (secant) image of a 1-D atom, evaluated by `fill_row` as
/// `b0 + (t - t_bar) * b1`.
#[derive(Clone, Debug)]
pub struct AtomLinearImage {
/// Index of the atom in the `atoms` slice passed to `build_hybrid_split_report`.
pub atom_idx: usize,
/// Expansion point subtracted from `t` in `fill_row`; populated with the weighted mean
/// of the atom's latent coordinates from the linear fit.
pub t_bar: f64,
/// Offset term, one entry per decoded column.
pub b0: Array1<f64>,
/// Slope term, one entry per decoded column.
pub b1: Array1<f64>,
}
impl AtomLinearImage {
    /// Writes the linear image evaluated at latent coordinate `t` into `out`.
    ///
    /// Entry `j` becomes `b0[j] + (t - t_bar) * b1[j]` for every `j < out.len()`;
    /// entries of `b0`/`b1` beyond `out.len()` are ignored.
    ///
    /// # Panics
    /// Panics if `out` is longer than `b0` or `b1`.
    pub fn fill_row(&self, t: f64, out: &mut [f64]) {
        let offset = t - self.t_bar;
        out.iter_mut().enumerate().for_each(|(col, value)| {
            *value = self.b0[col] + offset * self.b1[col];
        });
    }
}
/// Outcome of hybrid (linear vs. curved) adjudication for one 1-D atom.
#[derive(Clone, Debug)]
pub struct AtomHybridVerdict {
/// Name copied from the adjudicated `SaeManifoldAtom`.
pub atom_name: String,
/// Choice returned for this atom by `select_hybrid_split`.
pub choice: HybridAtomChoice,
/// `true` when `choice.param` is not the linear special case.
pub kept_curved: bool,
/// Fitted linear image; `Some` exactly when `kept_curved` is `false`.
pub linear_image: Option<AtomLinearImage>,
}
/// Result of `build_hybrid_split_report`: per-atom verdicts plus the raw selection.
#[derive(Clone, Debug)]
pub struct SaeHybridSplitReport {
/// Verdicts for the adjudicated atoms, in the order the atoms were visited.
pub verdicts: Vec<AtomHybridVerdict>,
/// Output of `select_hybrid_split` over all adjudicated atoms' candidate slots.
pub selection: HybridSplitSelection,
}
/// Minimum number of rows required before a linear (secant) fit is attempted.
const MIN_ROWS_FOR_LINEAR_FIT: usize = 3;
/// Fits the weighted linear (secant) special case of a 1-D atom and pairs it with the
/// curved candidate for hybrid adjudication.
///
/// The linear image uses the centered parameterization `g(t) ≈ b0 + (t - t_bar) * b1`,
/// where `t_bar` is the weighted coordinate mean. This is exactly the form evaluated by
/// `AtomLinearImage::fill_row`, so the returned `(t_bar, b0, b1)` can be stored there
/// unchanged. Centering also keeps `b0` well conditioned when coordinates sit far from 0.
///
/// Both candidates are scored with `reduced_laplace_nle`:
/// - the linear candidate uses half its weighted residual sum of squares;
/// - the curved candidate is assigned a zero residual objective;
/// - each pays `num_params * ln(w_sum)` as its log-determinant term.
///
/// Returns `None` (the atom is not adjudicated) when:
/// - there are fewer than `MIN_ROWS_FOR_LINEAR_FIT` rows, the shapes of `coords`,
///   `weights` and `decoded` disagree, or `decoded` has no columns;
/// - any weight is negative or non-finite, or the total weight is not positive and finite;
/// - the weighted coordinate spread is degenerate (or NaN);
/// - either negative log-evidence is non-finite (e.g. NaN entries in `decoded`).
///
/// On success returns `(linear, curved, (t_bar, b0, b1))`.
fn build_atom_candidates(
    coords: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
    decoded: ArrayView2<'_, f64>,
    curved_num_params: usize,
    fitted_turning: Option<f64>,
) -> Option<(
    HybridAtomCandidate,
    HybridAtomCandidate,
    (f64, Array1<f64>, Array1<f64>),
)> {
    let n = coords.len();
    let p = decoded.ncols();
    if n < MIN_ROWS_FOR_LINEAR_FIT || decoded.nrows() != n || weights.len() != n || p == 0 {
        return None;
    }

    // Weighted coordinate mean. Each weight must be finite and non-negative, and the
    // total must be positive and finite (individually finite weights can still overflow).
    let mut w_sum = 0.0_f64;
    let mut wt_sum = 0.0_f64;
    for (&w, &t) in weights.iter().zip(coords.iter()) {
        if !(w.is_finite() && w >= 0.0) {
            return None;
        }
        w_sum += w;
        wt_sum += w * t;
    }
    if !(w_sum > 0.0 && w_sum.is_finite()) {
        return None;
    }
    let t_bar = wt_sum / w_sum;

    // Weighted coordinate spread about `t_bar`; the negated comparison also rejects NaN.
    let s_tt: f64 = weights
        .iter()
        .zip(coords.iter())
        .map(|(&w, &t)| {
            let dt = t - t_bar;
            w * dt * dt
        })
        .sum();
    if !(s_tt > 1e-12 * (1.0 + t_bar * t_bar)) {
        return None;
    }

    // Centered weighted least squares per decoded column. `b0` is the weighted column
    // mean (the fitted value at `t_bar`) and `b1` is the slope, matching `fill_row`.
    let mut b0 = Array1::<f64>::zeros(p);
    let mut b1 = Array1::<f64>::zeros(p);
    for j in 0..p {
        let column = decoded.column(j);
        let g_bar = weights
            .iter()
            .zip(column.iter())
            .map(|(&w, &g)| w * g)
            .sum::<f64>()
            / w_sum;
        let s_tg: f64 = weights
            .iter()
            .zip(coords.iter())
            .zip(column.iter())
            .map(|((&w, &t), &g)| w * (t - t_bar) * (g - g_bar))
            .sum();
        b0[j] = g_bar;
        b1[j] = s_tg / s_tt;
    }

    // Weighted residual sum of squares of the centered line.
    let mut linear_rss = 0.0_f64;
    for ((row, &t), &w) in decoded.outer_iter().zip(coords.iter()).zip(weights.iter()) {
        let dt = t - t_bar;
        for ((&g, &c0), &c1) in row.iter().zip(b0.iter()).zip(b1.iter()) {
            let r = g - (c0 + dt * c1);
            linear_rss += w * r * r;
        }
    }

    let ln_w_sum = w_sum.ln();
    let linear_num_params = 2 * p;
    let linear_residual_objective = 0.5 * linear_rss;
    // The curved candidate is assigned a zero residual objective, so only its complexity
    // term counts. NOTE(review): presumably `decoded` is the curved atom's own
    // reconstruction, which it reproduces exactly — confirm with the caller.
    let curved_residual_objective = 0.0_f64;
    let linear_log_det_h = (linear_num_params as f64) * ln_w_sum;
    let curved_log_det_h = (curved_num_params as f64) * ln_w_sum;
    let linear_nle = reduced_laplace_nle(linear_residual_objective, linear_log_det_h);
    let curved_nle = reduced_laplace_nle(curved_residual_objective, curved_log_det_h);
    if !(linear_nle.is_finite() && curved_nle.is_finite()) {
        return None;
    }
    let linear = HybridAtomCandidate::linear(linear_nle, linear_num_params);
    // Leading `1`: latent dimension of the curved candidate (this path handles 1-D atoms).
    let curved = HybridAtomCandidate::curved(1, curved_nle, curved_num_params, fitted_turning);
    Some((linear, curved, (t_bar, b0, b1)))
}
/// Adjudicates each eligible 1-D atom between its linear special case and its curved
/// parameterization.
///
/// For every index yielded by `eligible_d1`, the callbacks supply that atom's latent
/// coordinates (`coords_for`), row weights (`weights_for`), decoded image (`decoded_for`)
/// and latent manifold (`manifold_for`). They are called in that order.
///
/// Atoms for which `build_atom_candidates` refuses to produce a candidate pair are
/// skipped, and `manifold_for` is not called for them. Atoms on a Euclidean manifold are
/// offered only the linear candidate. All remaining slots are adjudicated jointly by
/// `select_hybrid_split`.
///
/// # Returns
/// - `Ok(None)` if no eligible atom produced a candidate pair.
/// - `Ok(Some(report))` with one verdict per adjudicated atom, in visiting order.
///   `linear_image` is populated only for atoms that did not keep the curved
///   parameterization.
///
/// # Errors
/// - An index from `eligible_d1` is out of bounds for `atoms`.
/// - `select_hybrid_split` fails; its error is propagated.
/// - The selection does not contain exactly one choice per adjudicated atom.
pub fn build_hybrid_split_report<'a, C, W, D, M>(
    atoms: &'a [SaeManifoldAtom],
    eligible_d1: impl Iterator<Item = usize>,
    mut coords_for: C,
    mut weights_for: W,
    mut decoded_for: D,
    mut manifold_for: M,
) -> Result<Option<SaeHybridSplitReport>, String>
where
    C: FnMut(usize) -> Array1<f64>,
    W: FnMut(usize) -> Array1<f64>,
    D: FnMut(usize) -> Array2<f64>,
    M: FnMut(usize) -> LatentManifold,
{
    let mut slots: Vec<Vec<HybridAtomCandidate>> = Vec::new();
    let mut names: Vec<String> = Vec::new();
    let mut linear_images: Vec<AtomLinearImage> = Vec::new();
    for atom_idx in eligible_d1 {
        let atom = atoms.get(atom_idx).ok_or_else(|| {
            format!(
                "hybrid split: eligible atom index {atom_idx} is out of range for {} atoms",
                atoms.len()
            )
        })?;
        let coords = coords_for(atom_idx);
        let weights = weights_for(atom_idx);
        let decoded = decoded_for(atom_idx);
        let curved_num_params = atom.decoder_coefficients.len();
        // Best-effort: a missing evaluator, an error, or no estimate all leave
        // `fitted_turning` as `None` rather than aborting the report.
        let fitted_turning = atom.basis_evaluator.as_ref().and_then(|evaluator| {
            d1_atom_fitted_turning(
                evaluator.as_ref(),
                atom.decoder_coefficients.view(),
                coords.view(),
            )
            .ok()
            .flatten()
        });
        let Some((linear, curved, (t_bar, b0, b1))) = build_atom_candidates(
            coords.view(),
            weights.view(),
            decoded.view(),
            curved_num_params,
            fitted_turning,
        ) else {
            continue;
        };
        // Euclidean atoms are only offered the linear special case.
        let slot = if manifold_for(atom_idx).is_euclidean() {
            vec![linear]
        } else {
            vec![linear, curved]
        };
        slots.push(slot);
        names.push(atom.name.clone());
        linear_images.push(AtomLinearImage {
            atom_idx,
            t_bar,
            b0,
            b1,
        });
    }
    if slots.is_empty() {
        return Ok(None);
    }
    let selection = select_hybrid_split(&slots)?;
    // Verdicts are paired with atoms positionally, so a length mismatch would silently
    // drop or misattribute atoms; refuse it explicitly.
    if selection.atoms.len() != slots.len() {
        return Err(format!(
            "hybrid split: selection returned {} choices for {} adjudicated atoms",
            selection.atoms.len(),
            slots.len()
        ));
    }
    let verdicts: Vec<AtomHybridVerdict> = names
        .into_iter()
        .zip(selection.atoms.iter().copied())
        .zip(linear_images)
        .map(|((atom_name, choice), linear_image)| {
            let kept_curved = !choice.param.is_linear();
            AtomHybridVerdict {
                atom_name,
                choice,
                kept_curved,
                linear_image: (!kept_curved).then_some(linear_image),
            }
        })
        .collect();
    Ok(Some(SaeHybridSplitReport {
        verdicts,
        selection,
    }))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::evidence::{select_hybrid_atom, HybridAtomParam};
    use std::f64::consts::PI;

    /// `n` evenly spaced coordinates `scale * (-1 + 2 i / (n - 1))`.
    fn symmetric_grid(n: usize, scale: f64) -> Array1<f64> {
        let denom = (n - 1) as f64;
        Array1::from_iter((0..n).map(|i| scale * (-1.0 + 2.0 * (i as f64) / denom)))
    }

    /// Straight two-column image `(t / scale, 0.6 * t / scale)` for every coordinate `t`.
    fn straight_image(coords: &Array1<f64>, scale: f64) -> Array2<f64> {
        let mut image = Array2::<f64>::zeros((coords.len(), 2));
        for (row, &t) in coords.iter().enumerate() {
            image[[row, 0]] = t / scale;
            image[[row, 1]] = 0.6 * t / scale;
        }
        image
    }

    /// Unit-circle image `(cos θ, sin θ)`, one row per angle.
    fn circle_image(thetas: impl Iterator<Item = f64>) -> Array2<f64> {
        let thetas: Vec<f64> = thetas.collect();
        let mut image = Array2::<f64>::zeros((thetas.len(), 2));
        for (row, theta) in thetas.into_iter().enumerate() {
            image[[row, 0]] = theta.cos();
            image[[row, 1]] = theta.sin();
        }
        image
    }

    #[test]
    fn straight_image_selects_linear_via_dominance_floor() {
        let coords = symmetric_grid(40, 1.0);
        let weights = Array1::<f64>::ones(coords.len());
        let decoded = straight_image(&coords, 1.0);
        let (linear, curved, _) =
            build_atom_candidates(coords.view(), weights.view(), decoded.view(), 10, Some(0.0))
                .expect("straight image yields a candidate pair");
        let choice = select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
        assert!(
            choice.param.is_linear(),
            "a straight image must keep the linear special case (Θ = 0 dominance floor)"
        );
    }

    #[test]
    fn turning_image_selects_curved_on_evidence() {
        let n = 60;
        let denom = (n - 1) as f64;
        let coords = Array1::from_iter((0..n).map(|i| (i as f64) / denom));
        let weights = Array1::<f64>::ones(n);
        let decoded = circle_image(coords.iter().map(|&t| PI * t));
        let (linear, curved, _) =
            build_atom_candidates(coords.view(), weights.view(), decoded.view(), 5, Some(PI))
                .expect("turning image yields a candidate pair");
        assert!(
            linear.negative_log_evidence.is_finite(),
            "linear candidate must carry a real deviance"
        );
        let choice = select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
        assert_eq!(
            choice.param,
            HybridAtomParam::Curved { latent_dim: 1 },
            "a half-circle image must keep the curved parameterization (5 params, large linear RSS)"
        );
        assert!(
            choice.curved_evidence_margin > 0.0,
            "curved must win a positive evidence margin over the linear secant"
        );
    }

    #[test]
    fn adjudication_is_scale_invariant() {
        let straight_weights = Array1::<f64>::ones(40);
        for c in [-3i32, -1, 0, 1, 3].map(|e| 10.0_f64.powi(e)) {
            let coords = symmetric_grid(40, c);
            let decoded = straight_image(&coords, c);
            let (linear, curved, _) = build_atom_candidates(
                coords.view(),
                straight_weights.view(),
                decoded.view(),
                10,
                Some(0.0),
            )
            .expect("straight image always yields a pair");
            let choice = select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
            assert!(
                choice.param.is_linear(),
                "straight image must select linear at any t-scale (scale={c})"
            );
        }

        let n = 60;
        let denom = (n - 1) as f64;
        let curved_weights = Array1::<f64>::ones(n);
        for c in [-2i32, -1, 0, 1, 2].map(|e| 10.0_f64.powi(e)) {
            let coords = Array1::from_iter((0..n).map(|i| c * (i as f64) / denom));
            let decoded = circle_image((0..n).map(|i| PI * (i as f64) / denom));
            let (linear, curved, _) = build_atom_candidates(
                coords.view(),
                curved_weights.view(),
                decoded.view(),
                5,
                Some(PI),
            )
            .expect("curved image always yields a pair");
            let choice = select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
            assert_eq!(
                choice.param,
                HybridAtomParam::Curved { latent_dim: 1 },
                "curved image must select curved at any t-scale (scale={c})"
            );
        }
    }

    #[test]
    fn degenerate_coordinate_is_refused() {
        let n = 5;
        let coords = Array1::<f64>::from_elem(n, 0.5);
        let weights = Array1::<f64>::ones(n);
        let decoded = Array2::<f64>::zeros((n, 2));
        let refused =
            build_atom_candidates(coords.view(), weights.view(), decoded.view(), 6, Some(0.0))
                .is_none();
        assert!(refused, "a degenerate coordinate span must be refused");
    }
}