use super::tests::small_two_atom_periodic_term;
use approx::assert_abs_diff_eq;
/// Checks that the streaming exact REML path can stand in for the dense path.
///
/// Both paths start from copies of the same `small_two_atom_periodic_term()` fixture and are
/// called with identical arguments. The test then requires the two paths to agree, within
/// the absolute tolerance `TOL`, on:
/// - the REML cost and the loss total,
/// - every entry of the per-atom ARD inverse traces computed from each path's cache,
/// - the reconstruction dispersion computed from each path's loss and cache.
///
/// The dense results are treated as the reference values throughout.
#[test]
fn streaming_cache_is_efs_dropin_for_dense_cache_1026() {
    // Single absolute tolerance shared by every comparison below.
    const TOL: f64 = 1.0e-8;

    let (term0, target, rho) = small_two_atom_periodic_term();
    // The criterion methods are invoked on `mut` bindings, so each path gets its own copy
    // of the term; nothing one path does to its term can leak into the other.
    let mut dense = term0.clone();
    let mut streaming = term0;

    // NOTE: both calls must receive exactly the same trailing arguments; otherwise the
    // comparisons below are not comparing like with like.
    let (dense_cost, dense_loss, dense_cache) = dense
        .reml_criterion_with_cache(target.view(), &rho, None, 2, 0.25, 1.0e-4, 1.0e-4)
        .unwrap();
    let (stream_cost, stream_loss, stream_cache) = streaming
        .reml_criterion_streaming_exact_with_cache(
            target.view(),
            &rho,
            None,
            2,
            0.25,
            1.0e-4,
            1.0e-4,
        )
        .unwrap();

    assert_abs_diff_eq!(stream_cost, dense_cost, epsilon = TOL);
    assert_abs_diff_eq!(stream_loss.total(), dense_loss.total(), epsilon = TOL);

    let dense_traces = dense.ard_inverse_traces(&dense_cache).unwrap();
    let stream_traces = streaming.ard_inverse_traces(&stream_cache).unwrap();
    assert_eq!(
        dense_traces.len(),
        stream_traces.len(),
        "streaming cache yields a different ARD-trace atom count than the dense cache"
    );
    for (k, (d, s)) in dense_traces.iter().zip(stream_traces.iter()).enumerate() {
        assert_eq!(d.len(), s.len(), "ARD-trace latent dim mismatch at atom {k}");
        for (j, (dv, sv)) in d.iter().zip(s.iter()).enumerate() {
            // `assert_abs_diff_eq!` accepts no custom message, so use the boolean
            // `abs_diff_eq!` (same comparison) to report which entry diverged.
            assert!(
                approx::abs_diff_eq!(sv, dv, epsilon = TOL),
                "ARD-trace value mismatch at atom {k}, latent dim {j}: \
                 streaming = {sv:?}, dense = {dv:?}"
            );
        }
    }

    let dense_disp = dense
        .reconstruction_dispersion(&dense_loss, &dense_cache, &rho)
        .unwrap();
    let stream_disp = streaming
        .reconstruction_dispersion(&stream_loss, &stream_cache, &rho)
        .unwrap();
    assert_abs_diff_eq!(stream_disp, dense_disp, epsilon = TOL);
}