gam-sae 0.3.130

Sparse-autoencoder latent-manifold terms for the gam penalized-likelihood engine
Documentation
//! Tests `sae_row_jet_program_matches_production_row_jets_on_converged_cache`
//! and `ibp_map_outer_objective_advertises_analytic_gradient`, split verbatim
//! out of `tests.rs` to keep that tracked file under the #780 10k-line gate.
//! Declared as a sibling `#[cfg(test)] mod` in `mod.rs`; the shared
//! `gamma_fd_tiny_fixture` is sourced from the sibling `tests` module.

use super::*;
use super::tests::gamma_fd_tiny_fixture;

/// #932 follow-up (the issue-comment cache-seam ask): the SAE row
/// jet-program oracle driven directly from a CONVERGED production
/// `ArrowFactorCache`, not a mirrored test layout.
///
/// For every row of the converged tiny fixture, the production
/// `row_jets_for_logdet` channels — the exact `first`/`second` tensors the
/// #1006 `logdet_theta_adjoint` contracts — are rebuilt as a
/// [`SaeReconstructionRowProgram`] from the SAME production inputs (the
/// term's basis value/jacobian tensors, `atom_second_jets`, decoder
/// blocks, gate logits/assignments, and the cache's own
/// `row_vars_for_cache_row` primary layout) and compared column by
/// column. The hand path sums sparse cross terms per (logit, coord)
/// variable pair; the tower derives them by Leibniz from one expression —
/// independent arithmetic, so agreement is a correctness proof of the
/// production packing on a real converged state. The β border channels
/// (`beta`, `beta_deriv`, `beta_l_deriv`) are pinned the same way against
/// `beta_border_tower`. The `weighted` arm exercises the #977
/// `set_row_loss_weights` √w seam, which scales every production channel by
/// `sqrt(w_row)`.
#[test]
pub(crate) fn sae_row_jet_program_matches_production_row_jets_on_converged_cache() {
    use crate::row_jet_program::{
        AtomRowBasisJet, RowGate, SaeReconstructionRowProgram,
    };

    // Tiny-fixture row arity: softmax gauges the last logit as the fixed
    // reference (assignment_coord_dim = k_atoms − 1 = 1 free logit), plus
    // 2 atoms × 1 latent coord.
    const K: usize = 3;
    for weighted in [false, true] {
        let (mut term, target, rho) = gamma_fd_tiny_fixture();
        if weighted {
            let weights: Vec<f64> = (0..term.n_obs())
                .map(|row| 0.5 + 0.17 * row as f64)
                .collect();
            term.set_row_loss_weights(weights)
                .expect("set row loss weights");
        }
        let (_value, _loss, cache) = term
            .reml_criterion_with_cache(target.view(), &rho, None, 5, 0.4, 1.0e-6, 1.0e-6)
            .expect("converged cache");
        let second_jets = term.atom_second_jets().expect("second jets");
        let border = term
            .border_channels_for_cache(&cache)
            .expect("border channels");
        let AssignmentMode::Softmax { temperature, .. } = term.assignment.mode else {
            panic!("gamma fixture is softmax-gated");
        };
        let inv_tau = 1.0 / temperature;
        let p = term.output_dim();
        let k_atoms = term.k_atoms();

        for row in 0..term.n_obs() {
            let vars = term.row_vars_for_cache_row(row, &cache).expect("row vars");
            assert_eq!(
                vars.len(),
                K,
                "tiny fixture rows carry 1 free softmax logit + 2 coords"
            );
            let assignments = term
                .assignment
                .try_assignments_row(row)
                .expect("assignments row");
            let jets = term
                .row_jets_for_logdet(
                    &rho,
                    row,
                    vars.clone(),
                    assignments.view(),
                    &second_jets,
                    &border,
                )
                .expect("production row jets");

            // Primary layout exactly as the cache rows it: slot positions
            // come from the production `row_vars_for_cache_row`, not a
            // re-derived convention.
            let mut logit_slot = vec![None; k_atoms];
            let mut coord_slot: Vec<Vec<usize>> = term
                .atoms
                .iter()
                .map(|atom| vec![usize::MAX; atom.latent_dim])
                .collect();
            for (pos, var) in vars.iter().enumerate() {
                match *var {
                    SaeLocalRowVar::Logit { atom } => logit_slot[atom] = Some(pos),
                    SaeLocalRowVar::Coord { atom, axis } => coord_slot[atom][axis] = pos,
                }
            }
            // Every latent coordinate must be placed by the production layout.
            // A leftover `usize::MAX` sentinel would otherwise be handed to the
            // tower program as a bogus primary index and fail far from its
            // cause; catch it here with a layout-specific diagnostic instead.
            for (atom, slots) in coord_slot.iter().enumerate() {
                for (axis, &slot) in slots.iter().enumerate() {
                    assert!(
                        slot < K,
                        "weighted={weighted} row {row}: production layout left \
                         atom {atom} coord {axis} unplaced (slot {slot})"
                    );
                }
            }

            // Per-atom basis jets straight from the production tensors the
            // hand path consumes: basis_values / basis_jacobian /
            // atom_second_jets / decoder_coefficients.
            let atoms: Vec<AtomRowBasisJet> = term
                .atoms
                .iter()
                .enumerate()
                .map(|(k, atom)| {
                    let m = atom.basis_size();
                    let d = atom.latent_dim;
                    AtomRowBasisJet {
                        phi: (0..m).map(|b| atom.basis_values[[row, b]]).collect(),
                        d_phi: (0..m)
                            .map(|b| {
                                (0..d)
                                    .map(|axis| atom.basis_jacobian[[row, b, axis]])
                                    .collect()
                            })
                            .collect(),
                        d2_phi: (0..m)
                            .map(|b| {
                                (0..d)
                                    .map(|aa| {
                                        (0..d).map(|bb| second_jets[k][[row, b, aa, bb]]).collect()
                                    })
                                    .collect()
                            })
                            .collect(),
                        decoder: (0..m)
                            .map(|b| (0..p).map(|c| atom.decoder_coefficients[[b, c]]).collect())
                            .collect(),
                        latent_dim: d,
                    }
                })
                .collect();

            let prog = SaeReconstructionRowProgram {
                atoms,
                gate_value: assignments.to_vec(),
                logits: term.assignment.logits.row(row).to_vec(),
                gate_scale: vec![1.0; k_atoms],
                gate_shift: vec![0.0; k_atoms],
                gate: RowGate::Softmax { inv_tau },
                logit_slot,
                coord_slot,
                n_primaries: K,
            };
            // The production channels carry the √w row-loss weight (#977
            // single seam); the program is the unweighted reconstruction.
            let sqrt_row_w = term
                .row_loss_weights
                .as_deref()
                .map_or(1.0, |w| w[row].sqrt());
            if weighted {
                assert!(
                    (sqrt_row_w - 1.0).abs() > 1e-6,
                    "weighted arm must exercise a non-unit √w (row {row}, √w={sqrt_row_w})"
                );
            }

            for out_col in 0..p {
                let tower = prog.reconstruction_column::<K>(out_col);
                // Tolerances are relative to the largest |entry| of this output
                // column's first / second channel (floored at 1e-12), so
                // near-zero entries are judged on the channel's scale.
                let g_floor = (0..K)
                    .map(|a| jets.first[a][out_col].abs())
                    .fold(1e-12_f64, f64::max);
                let h_floor = (0..K)
                    .flat_map(|a| (0..K).map(move |b| (a, b)))
                    .map(|(a, b)| jets.second[a][b][out_col].abs())
                    .fold(1e-12_f64, f64::max);
                for a in 0..K {
                    let want = sqrt_row_w * tower.g[a];
                    assert!(
                        (jets.first[a][out_col] - want).abs() <= 1e-9 * g_floor,
                        "weighted={weighted} row {row} col {out_col} first[{a}]: \
                             production {} vs tower {}",
                        jets.first[a][out_col],
                        want
                    );
                    for b in 0..K {
                        let want2 = sqrt_row_w * tower.h[a][b];
                        assert!(
                            (jets.second[a][b][out_col] - want2).abs() <= 1e-9 * h_floor,
                            "weighted={weighted} row {row} col {out_col} \
                                 second[{a}][{b}]: production {} vs tower {}",
                            jets.second[a][b][out_col],
                            want2
                        );
                    }
                }
            }

            // β BORDER CHANNELS (#932): the hand path packs `beta`
            // (value ∂ẑ_c/∂β = ζ_k·Φ_b·output_c) and `beta_deriv` /
            // `beta_l_deriv` (the mixed ∂²ẑ_c/∂β∂p_a = ∂(ζ_k·Φ_b)/∂p_a·output_c)
            // term by term in `row_jets_for_logdet`, with NO tower oracle
            // previously. The arrow β coefficient multiplies the channel's
            // (frame / identity) `output` vector — NOT the current decoder
            // matrix — so the local-variable dependence is exactly
            // s = ζ_k(ℓ)·Φ_b(t_k) = `beta_border_tower` (built from the SAME
            // gate_tower / basis_tower primitives as the reconstruction column);
            // production multiplies that scalar by `channel.output[c]·√w`. Pin
            // every β channel (value + both mixed-derivative arrays) to it at
            // ~1e-9, using a per-entry relative tolerance floored at 1e-12.
            for (beta_pos, channel) in border.iter().enumerate() {
                let s = prog.beta_border_tower::<K>(channel.atom, channel.basis_col);
                for out_col in 0..p {
                    let out_c = channel.output[out_col];
                    let want_v = sqrt_row_w * s.v * out_c;
                    let v_floor = want_v.abs().max(1e-12);
                    assert!(
                        (jets.beta[beta_pos][out_col] - want_v).abs() <= 1e-9 * v_floor,
                        "weighted={weighted} row {row} col {out_col} \
                         beta[{beta_pos}] (atom {} basis {}): production {} vs tower {}",
                        channel.atom,
                        channel.basis_col,
                        jets.beta[beta_pos][out_col],
                        want_v
                    );
                    for a in 0..K {
                        let want_d = sqrt_row_w * s.g[a] * out_c;
                        let d_floor = want_d.abs().max(1e-12);
                        // `beta_deriv` and `beta_l_deriv` are the SAME mixed
                        // ∂²ẑ_c/∂β∂p_a derivative the linear-in-β reconstruction
                        // produces (the hand path fills both identically); both
                        // must equal the tower's first-derivative channel × out_c.
                        assert!(
                            (jets.beta_deriv[a][beta_pos][out_col] - want_d).abs()
                                <= 1e-9 * d_floor,
                            "weighted={weighted} row {row} col {out_col} \
                             beta_deriv[{a}][{beta_pos}]: production {} vs tower {}",
                            jets.beta_deriv[a][beta_pos][out_col],
                            want_d
                        );
                        assert!(
                            (jets.beta_l_deriv[a][beta_pos][out_col] - want_d).abs()
                                <= 1e-9 * d_floor,
                            "weighted={weighted} row {row} col {out_col} \
                             beta_l_deriv[{a}][{beta_pos}]: production {} vs tower {}",
                            jets.beta_l_deriv[a][beta_pos][out_col],
                            want_d
                        );
                    }
                }
            }
        }
    }
}

/// The IBP-MAP assignment mode must advertise an analytic outer gradient.
///
/// The IBP-MAP empirical-π third channel, including the cross-row M_k
/// coupling, is now assembled exactly in `logdet_theta_adjoint` (#1006). The
/// outer objective built on the tiny γ fixture switched to IBP-MAP must
/// therefore report `Derivative::Analytic`, the same as every other
/// assignment mode.
#[test]
pub(crate) fn ibp_map_outer_objective_advertises_analytic_gradient() {
    let (mut term, target, rho) = gamma_fd_tiny_fixture();
    term.assignment.mode = AssignmentMode::ibp_map(0.9, 1.0, false);

    let objective =
        SaeManifoldOuterObjective::new(term, target, None, rho, 5, 0.4, 1.0e-6, 1.0e-6);
    let capability = objective.capability();
    assert_eq!(capability.gradient, Derivative::Analytic);
}