use super::*;
use ndarray::array;
use std::sync::Arc;
/// Arrow-Schur assembly must not depend on faer's global parallelism setting.
///
/// Builds `k` periodic manifold atoms that share one harmonic evaluator, a
/// softmax assignment over circle latents (period 1.0), and a fixed target.
/// It then assembles the arrow-Schur system twice on the same term: once under
/// `faer::Par::Seq` and once under `faer::Par::rayon(4)`. Every entry of `gb`,
/// and of each row's `gt`, `htt` and `htbeta`, must be bit-identical between
/// the two runs.
#[test]
pub(crate) fn arrow_schur_assembly_is_faer_parallelism_invariant_1557() {
    use ndarray::Array2;

    /// Restores faer's global parallelism when dropped. Because `Drop` also
    /// runs during panic unwinding, a failing assembly or assertion cannot
    /// leak the temporary setting into later tests in the same process.
    struct RestoreGlobalParallelism(Option<faer::Par>);

    impl Drop for RestoreGlobalParallelism {
        fn drop(&mut self) {
            if let Some(par) = self.0.take() {
                faer::set_global_parallelism(par);
            }
        }
    }

    let n = 128usize;
    let p = 32usize;
    let k = 8usize;
    let m = 5usize;
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(m).unwrap());
    let mut atoms = Vec::with_capacity(k);
    let mut coord_blocks = Vec::with_capacity(k);
    for atom_idx in 0..k {
        // Non-negative argument, so `fract()` wraps it into [0, 1), matching
        // the circle period used for the latent manifolds below.
        let coords = Array2::<f64>::from_shape_fn((n, 1), |(row, _)| {
            (row as f64 * 0.013 + atom_idx as f64 * 0.071).fract()
        });
        let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
        let decoder = Array2::<f64>::from_shape_fn((m, p), |(i, j)| {
            0.1 * ((i as f64 + 1.0) * 0.3 - (j as f64) * 0.017 + atom_idx as f64 * 0.05).sin()
        });
        let atom = SaeManifoldAtom::new(
            format!("periodic_{atom_idx}"),
            SaeAtomBasisKind::Periodic,
            1,
            phi,
            jet,
            decoder,
            Array2::<f64>::eye(m),
        )
        .unwrap()
        .with_basis_evaluator(evaluator.clone());
        atoms.push(atom);
        coord_blocks.push(coords);
    }
    let logits = Array2::<f64>::from_shape_fn((n, k), |(row, col)| {
        0.5 * ((row as f64) * 0.021 + (col as f64) * 0.37).sin()
    });
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits,
        coord_blocks,
        vec![LatentManifold::Circle { period: 1.0 }; k],
        AssignmentMode::softmax(0.8),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(atoms, assignment).unwrap();
    let target = Array2::<f64>::from_shape_fn((n, p), |(row, col)| {
        0.05 * ((row as f64) * 0.011 - (col as f64) * 0.023).cos()
    });
    let rho = SaeManifoldRho::new(
        (-0.3_f64).exp().ln(),
        0.7_f64.ln(),
        vec![array![0.9_f64.ln()]; k],
    );

    // NOTE(review): faer's parallelism setting is process-global, so other
    // tests running concurrently in this binary may observe the temporary
    // values set below. The guard only guarantees restoration, not isolation.
    let restore = RestoreGlobalParallelism(Some(faer::get_global_parallelism()));
    let assemble = |term: &mut SaeManifoldTerm, par: faer::Par| {
        faer::set_global_parallelism(par);
        term.assemble_arrow_schur(target.view(), &rho, None)
            .expect("arrow-Schur assembly must succeed")
    };
    let seq = assemble(&mut term, faer::Par::Seq);
    let par = assemble(&mut term, faer::Par::rayon(4));
    // Restore eagerly so the comparisons below run under the entry setting.
    // If an assembly above panicked, unwinding drops the guard instead.
    drop(restore);

    assert_eq!(seq.gb.len(), par.gb.len(), "gb length mismatch");
    for (i, (&s, &q)) in seq.gb.iter().zip(par.gb.iter()).enumerate() {
        assert_eq!(
            s.to_bits(),
            q.to_bits(),
            "gb[{i}] not bit-identical across faer parallelism (Seq={s}, rayon={q})"
        );
    }
    assert_eq!(seq.rows.len(), par.rows.len(), "row count mismatch");
    assert_eq!(seq.rows.len(), n, "expected n assembled rows");
    for (row, (rs, rq)) in seq.rows.iter().zip(par.rows.iter()).enumerate() {
        assert_eq!(rs.gt.len(), rq.gt.len(), "row {row} gt len mismatch");
        for (a, (&s, &q)) in rs.gt.iter().zip(rq.gt.iter()).enumerate() {
            assert_eq!(
                s.to_bits(),
                q.to_bits(),
                "row {row} gt[{a}] not bit-identical (Seq={s}, rayon={q})"
            );
        }
        // Dimensions are checked first, so indexing `rq` with `rs`'s indices
        // cannot go out of bounds.
        assert_eq!(rs.htt.dim(), rq.htt.dim(), "row {row} htt dim mismatch");
        for ((i, j), &s) in rs.htt.indexed_iter() {
            let q = rq.htt[[i, j]];
            assert_eq!(
                s.to_bits(),
                q.to_bits(),
                "row {row} htt[{i},{j}] not bit-identical (Seq={s}, rayon={q})"
            );
        }
        assert_eq!(
            rs.htbeta.dim(),
            rq.htbeta.dim(),
            "row {row} htbeta dim mismatch"
        );
        for ((i, j), &s) in rs.htbeta.indexed_iter() {
            let q = rq.htbeta[[i, j]];
            assert_eq!(
                s.to_bits(),
                q.to_bits(),
                "row {row} htbeta[{i},{j}] not bit-identical (Seq={s}, rayon={q})"
            );
        }
    }
}