use super::*;
use crate::assignment::{AssignmentMode, SaeAssignment};
use approx::assert_abs_diff_eq;
use gam_solve::inference::residual_factor::{ResidualFactorInput, StructuredResidualModel};
use gam_terms::latent::LatentManifold;
use ndarray::{Array1, Array2};
use super::tests::{TestPeriodicEvaluator, periodic_basis, small_two_atom_periodic_term};
use std::sync::Arc;
/// Parity check between the dense cached REML criterion and the
/// streaming-exact cached criterion on the same small two-atom periodic term.
///
/// Both paths must agree, to an absolute tolerance of `1e-8`, on:
/// - the criterion value;
/// - every component of the analytic outer ρ-gradient built from each cache;
/// - the loss total.
///
/// The dense gradient must also have a positive, finite squared norm. This
/// keeps the component-wise comparison from passing vacuously on an all-zero
/// gradient.
#[test]
fn streaming_cache_outer_gradient_matches_dense_cache() {
    let (term0, target, rho) = small_two_atom_periodic_term();
    // Same starting term for both paths: `dense` works on a copy and
    // `streaming` takes ownership of the original.
    let mut dense = term0.clone();
    let mut streaming = term0;

    let dense_eval = dense
        .reml_criterion_with_cache(target.view(), &rho, None, 2, 0.25, 1.0e-4, 1.0e-4)
        .expect("dense cache criterion");
    let stream_eval = streaming
        .reml_criterion_streaming_exact_with_cache(
            target.view(),
            &rho,
            None,
            2,
            0.25,
            1.0e-4,
            1.0e-4,
        )
        .expect("streaming cache criterion");
    let (dense_cost, dense_loss, dense_cache) = dense_eval;
    let (stream_cost, stream_loss, stream_cache) = stream_eval;
    assert_abs_diff_eq!(stream_cost, dense_cost, epsilon = 1.0e-8);

    // Both solvers are built against the same vector from `lambda_smooth_vec`.
    let smooth = rho.lambda_smooth_vec();

    let dense_solver = dense
        .outer_gradient_arrow_solver(&dense_cache, &smooth)
        .expect("dense outer-gradient solver");
    let dense_grad = dense
        .analytic_outer_rho_gradient_components(
            target.view(),
            &rho,
            &dense_loss,
            &dense_cache,
            &dense_solver,
        )
        .expect("dense outer-gradient components")
        .gradient();

    let stream_solver = streaming
        .outer_gradient_arrow_solver(&stream_cache, &smooth)
        .expect("streaming outer-gradient solver");
    let stream_grad = streaming
        .analytic_outer_rho_gradient_components(
            target.view(),
            &rho,
            &stream_loss,
            &stream_cache,
            &stream_solver,
        )
        .expect("streaming outer-gradient components")
        .gradient();

    assert_eq!(
        dense_grad.len(),
        stream_grad.len(),
        "streaming outer gradient has a different ρ dimension than the dense one"
    );
    // Component-wise: both values finite, then within tolerance of each other.
    let paired = dense_grad.iter().zip(stream_grad.iter());
    for (i, (d, s)) in paired.enumerate() {
        let both_finite = d.is_finite() && s.is_finite();
        assert!(
            both_finite,
            "outer-gradient component {i} must be finite (dense={d}, streaming={s})"
        );
        assert_abs_diff_eq!(d, s, epsilon = 1.0e-8);
    }

    // Guard against a vacuous pass: the dense gradient must not be all zeros.
    let g2: f64 = dense_grad.iter().map(|v| v * v).sum();
    assert!(
        g2 > 0.0 && g2.is_finite(),
        "the dense outer gradient must be non-trivial to make the parity check meaningful; ‖g‖²={g2}"
    );

    assert_abs_diff_eq!(stream_loss.total(), dense_loss.total(), epsilon = 1.0e-8);
}
/// Advances the 64-bit linear congruential state `s` in place and returns a
/// uniform draw in `[0, 1)`.
///
/// The multiplier/increment pair is the one used by Knuth's MMIX LCG. The top
/// 53 bits of the new state are scaled by `2^-53`, so every result is exactly
/// representable as an `f64` and strictly below 1.
fn lcg_uniform(s: &mut u64) -> f64 {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;
    const TWO_POW_53: f64 = (1u64 << 53) as f64;

    *s = s.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    let top_bits = *s >> 11;
    top_bits as f64 / TWO_POW_53
}
/// Draws one standard-normal sample from the LCG state `s`, using the cosine
/// branch of the Box–Muller transform.
///
/// Consumes exactly two uniform draws, first the radius one and then the angle
/// one. The radius uniform is clamped to at least `1e-12` so that `ln` never
/// sees zero.
fn lcg_normal(s: &mut u64) -> f64 {
    let radius_uniform = lcg_uniform(s).max(1e-12);
    let angle_uniform = lcg_uniform(s);
    let radius = (-2.0 * radius_uniform.ln()).sqrt();
    let angle = std::f64::consts::TAU * angle_uniform;
    radius * angle.cos()
}
/// Builds a deterministic `k`-atom manifold term over `n` rows, assigned with
/// `AssignmentMode::softmax(0.8)`.
///
/// - Each atom gets one coordinate column: a per-atom-offset ramp in the row
///   index, wrapped into `[0, 1)` with `rem_euclid(1.0)`.
/// - The atom's basis and jet come from `periodic_basis`, and it uses
///   `TestPeriodicEvaluator` as its basis evaluator.
/// - Each atom has a deterministic `3 × p` decoder.
/// - Every atom sits on `LatentManifold::Circle { period: 1.0 }`.
/// - The logits are a fixed affine function of row and column.
///
/// Panics (via `unwrap`) if any constructor rejects the inputs.
fn build_softmax_term(n: usize, p: usize, k: usize) -> SaeManifoldTerm {
    let coord_cols: Vec<Array2<f64>> = (0..k)
        .map(|atom| {
            let start = 0.03 + 0.11 * atom as f64;
            Array2::<f64>::from_shape_fn((n, 1), |(row, _)| {
                (start + 0.017 * row as f64).rem_euclid(1.0)
            })
        })
        .collect();

    let atoms: Vec<SaeManifoldAtom> = coord_cols
        .iter()
        .enumerate()
        .map(|(i, coords)| {
            let (phi, jet) = periodic_basis(coords);
            // Per-atom scale factor: 1, 2, 3, ...
            let f = i as f64 + 1.0;
            let slope = 0.1 * f;
            let bias = 0.02 * f;
            let decoder = Array2::<f64>::from_shape_fn((3, p), |(m, c)| {
                slope * (m + 1) as f64 - 0.05 * c as f64 + bias
            });
            SaeManifoldAtom::new(
                format!("atom{i}"),
                SaeAtomBasisKind::Periodic,
                1,
                phi,
                jet,
                decoder,
                Array2::<f64>::eye(3),
            )
            .unwrap()
            .with_basis_evaluator(Arc::new(TestPeriodicEvaluator))
        })
        .collect();

    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::from_shape_fn((n, k), |(r, c)| {
            0.3 * (c as f64) - 0.1 * (r as f64) + 0.2
        }),
        coord_cols,
        vec![LatentManifold::Circle { period: 1.0 }; k],
        AssignmentMode::softmax(0.8),
    )
    .unwrap();

    SaeManifoldTerm::new(atoms, assignment).unwrap()
}
/// Fits a `StructuredResidualModel` to synthetic one-common-factor residuals
/// and returns its row metric for `n` rows.
///
/// Row `row` has activity `0.25 + row / n`. Each row draws one shared normal
/// `common` from a fixed-seed LCG. Column `col` then gets
/// `sqrt(activity) * LOADINGS[col % 5] * common + NOISE_SCALE[col % 5] * z`,
/// where `z` is a fresh normal draw. The draw order (row-major: `common`
/// first, then one draw per column) makes the output fully deterministic.
///
/// Panics if the fit or the row-metric construction fails.
fn fit_structured_metric(n: usize, p: usize) -> gam_problem::RowMetric {
    const LOADINGS: [f64; 5] = [1.0, -0.7, 0.4, 0.9, -0.5];
    const NOISE_SCALE: [f64; 5] = [0.10, 0.55, 0.95, 0.30, 0.70];
    let mut state = 0x2026_00D5_1234_ABCDu64;

    let activity = Array1::<f64>::from_shape_fn(n, |row| 0.25 + (row as f64) / (n as f64));
    let mut residuals = Array2::<f64>::zeros((n, p));
    for (mut out, &act) in residuals.outer_iter_mut().zip(activity.iter()) {
        let common = lcg_normal(&mut state);
        let amp = act.sqrt();
        for (col, cell) in out.iter_mut().enumerate() {
            *cell = amp * LOADINGS[col % LOADINGS.len()] * common
                + NOISE_SCALE[col % NOISE_SCALE.len()] * lcg_normal(&mut state);
        }
    }

    let model = StructuredResidualModel::fit(ResidualFactorInput {
        residuals: residuals.view(),
        activity: activity.view(),
        max_factor_rank: 2,
    })
    .expect("StructuredResidualModel::fit");
    model.row_metric(n).expect("row_metric")
}
/// Budget-planner regression for a wide border (`n = 500`, `k = 32`,
/// `p = 128`, `d_max = 1`).
///
/// With a 2 GiB budget and 8 GiB of host memory available, the plan must
/// meet three conditions:
/// - it must reject the dense direct path;
/// - it must admit the matrix-free plan;
/// - it must select streaming.
///
/// `admitted_or_error` must then succeed rather than hard-error.
///
/// Gated to 64-bit targets: 8 GiB does not fit in a 32-bit `usize`. The
/// overflowing constant expression would otherwise be a compile error there,
/// because `arithmetic_overflow` is deny-by-default.
#[cfg(target_pointer_width = "64")]
#[test]
fn wide_border_k32_p128_plan_routes_to_streaming() {
    const GIB: usize = 1024 * 1024 * 1024;

    let (n, p, k, d_max) = (500usize, 128usize, 32usize, 1usize);
    let total_basis = 2 * k;
    let border_dim = total_basis * p;
    let budget = 2 * GIB;
    let host_available = 8 * GIB;
    let chunk_window = SAE_CPU_L2_CACHE_BYTES * SAE_CHUNK_CACHE_MULTIPLE;

    let plan = sae_streaming_plan_from_budget(
        n,
        total_basis,
        k,
        d_max,
        border_dim,
        budget,
        chunk_window,
        host_available,
    );
    assert!(
        !plan.direct_admitted,
        "the dense direct evidence peak ({} bytes) must exceed the 2 GiB budget so the \
         criterion routes to streaming",
        plan.estimated_direct_peak_bytes
    );
    assert!(
        plan.matrix_free_admitted,
        "the matrix-free plan ({} bytes) must be admitted so the fit has a route",
        plan.estimated_matrix_free_peak_bytes
    );
    assert!(
        plan.streaming,
        "a non-direct-admitted plan must select streaming"
    );
    plan.admitted_or_error(n, border_dim, k)
        .expect("matrix-free-admitted plan must not hard-error at the admission gate");
}
/// Smoke test for the streaming-exact REML criterion with a whitening row
/// metric.
///
/// A 128-row, 16-column, 8-atom softmax term is given the row metric fitted by
/// `fit_structured_metric`, and the metric is first asserted to whiten the
/// likelihood. The criterion must then return successfully, with a finite
/// cost, a finite loss total and a finite `data_fit`.
#[test]
fn whitened_streaming_criterion_completes() {
    let (n, p, k) = (128usize, 16usize, 8usize);

    let mut term = build_softmax_term(n, p, k);
    let metric = fit_structured_metric(n, p);
    assert!(
        metric.whitens_likelihood(),
        "the fitted structured-residual metric must whiten the likelihood"
    );
    term.set_row_metric(metric).unwrap();

    // Smooth trend in row and column plus a small period-7 diagonal ripple.
    let target = Array2::<f64>::from_shape_fn((n, p), |(r, c)| {
        let row_frac = r as f64 / n as f64;
        let col_frac = c as f64 / p as f64;
        let ripple = ((r + c) % 7) as f64;
        0.4 - 0.15 * row_frac + 0.25 * col_frac + 0.05 * ripple
    });

    // One single-element, all-zero block per atom.
    let per_atom = vec![Array1::<f64>::from_elem(1, 0.0); k];
    let rho = SaeManifoldRho::new(-1.0_f64, 0.7_f64.ln(), per_atom);

    let outcome =
        term.reml_criterion_streaming_exact(target.view(), &rho, None, 2, 0.25, 1.0e-4, 1.0e-4);
    let (cost, loss) =
        outcome.expect("whitened streaming criterion must complete, not hard-error");

    assert!(
        cost.is_finite(),
        "streaming REML criterion must be finite; got {cost}"
    );
    assert!(
        loss.total().is_finite() && loss.data_fit.is_finite(),
        "whitened loss components must be finite (data_fit={}, total={})",
        loss.data_fit,
        loss.total()
    );
}