use super::SaeAtomBasisKind;
use faer::Side;
use gam_linalg::faer_ndarray::{FaerEigh, FaerSvd};
use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2};
/// Minimum Euclidean norm a Gram-Schmidt residual must exceed for
/// `generic_ortho_combos` to keep it as a direction; smaller residuals are
/// dropped instead of being normalised.
const SURPLUS_DIR_FLOOR: f64 = 1.0e-6;
/// (sqrt(5) - 1) / 2. Multiplied by the retry offset to shift the periodic
/// phase of PCA-seeded atoms between retries.
const GOLDEN_RATIO_CONJUGATE: f64 = 0.618_033_988_749_894_9;
/// 256 MiB. `topology_seed_max_points` derives the subsample cap from this.
/// NOTE(review): the divisor 16 used there presumably accounts for the two
/// dense `f64` matrices (affinity and Laplacian) — confirm.
const TOPOLOGY_SEED_LAPLACIAN_BUDGET_BYTES: usize = 256 << 20;
/// Upper bound on the number of points the topology seed builds its graph on.
///
/// Evaluates `floor(sqrt(TOPOLOGY_SEED_LAPLACIAN_BUDGET_BYTES / 16))` and
/// never returns less than 4.
fn topology_seed_max_points() -> usize {
    let entries = TOPOLOGY_SEED_LAPLACIAN_BUDGET_BYTES / 16;
    let side = (entries as f64).sqrt() as usize;
    std::cmp::max(side, 4)
}
/// Neighbour count for the kNN graph over `n_points` samples.
///
/// The result is the largest of three floors:
/// * `2 * d_atom + 1`,
/// * `ceil(log2(max(n_points, 2)))`,
/// * the constant 2.
fn topology_seed_knn(n_points: usize, d_atom: usize) -> usize {
    let neighbours_for_tangent = 2 * d_atom + 1;
    let clamped_points = std::cmp::max(n_points, 2);
    let neighbours_for_connectivity = (clamped_points as f64).log2().ceil() as usize;
    std::cmp::max(
        2,
        std::cmp::max(neighbours_for_tangent, neighbours_for_connectivity),
    )
}
/// Returns `true` for the `Periodic`, `Torus` and `Sphere` basis kinds and
/// `false` for every other kind.
fn is_curved_kind(kind: &SaeAtomBasisKind) -> bool {
    match kind {
        SaeAtomBasisKind::Periodic | SaeAtomBasisKind::Torus | SaeAtomBasisKind::Sphere => true,
        _ => false,
    }
}
/// Picks the row indices the topology seed operates on.
///
/// When `n_obs` does not exceed `topology_seed_max_points()` every row is
/// returned. Otherwise exactly `cap` evenly strided rows are returned, the
/// `i`-th being `floor(i * n_obs / cap)`. In both cases the result is strictly
/// increasing and every index is `< n_obs`.
fn topology_seed_subsample(n_obs: usize) -> Vec<usize> {
    let cap = topology_seed_max_points();
    if n_obs <= cap {
        return (0..n_obs).collect();
    }
    // `i * n_obs` can exceed `usize::MAX` (on 32-bit targets as soon as n_obs
    // passes roughly a million rows), which panics in debug builds and wraps
    // silently in release builds. Form the product in u128; the quotient is
    // always < n_obs, so narrowing back to usize is lossless.
    let n_wide = n_obs as u128;
    let cap_wide = cap as u128;
    (0..cap)
        .map(|i| ((i as u128) * n_wide / cap_wide) as usize)
        .collect()
}
/// Squared Euclidean distance between rows `a` and `b` of `z`, accumulated
/// column by column from left to right.
fn squared_distance_rows(z: ArrayView2<'_, f64>, a: usize, b: usize) -> f64 {
    (0..z.ncols()).fold(0.0, |total, col| {
        let diff = z[[a, col]] - z[[b, col]];
        total + diff * diff
    })
}
/// Seeds per-atom coordinates from the eigenvectors of a symmetric-normalised
/// kNN graph Laplacian built on (a subsample of) the rows of `z`.
///
/// The output has shape `(basis_kinds.len(), z.nrows(), d_max)` where `d_max`
/// is the largest entry of `atom_dim` (at least 1). Entries an atom does not
/// fill stay zero.
///
/// # Returns
/// * `Ok(None)` when no atom is curved, `z` has fewer than 4 rows or no
///   columns, the subsample has fewer than 4 points, a graph degree is not a
///   positive finite number, or fewer than 3 eigenvalues come back. The caller
///   is expected to fall back to another seeding strategy.
/// * `Ok(Some(coords))` otherwise.
///
/// # Errors
/// * `atom_dim` is shorter than `basis_kinds`.
/// * `z` contains a non-finite value.
/// * The eigensolver fails.
fn topology_curved_seed_initial_coords(
    z: ArrayView2<'_, f64>,
    basis_kinds: &[SaeAtomBasisKind],
    atom_dim: &[usize],
    pc_pair_offset: usize,
) -> Result<Option<Array3<f64>>, String> {
    // Smallest and largest entry of `values`.
    fn value_range(values: &Array1<f64>) -> (f64, f64) {
        values
            .iter()
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
                (lo.min(v), hi.max(v))
            })
    }

    if !basis_kinds.iter().any(is_curved_kind) || z.nrows() < 4 || z.ncols() == 0 {
        return Ok(None);
    }
    // Every atom below indexes `atom_dim`; report a short slice as an error
    // instead of panicking inside a `Result`-returning function.
    if atom_dim.len() < basis_kinds.len() {
        return Err(format!(
            "sae_pca_seed: atom_dim len {} < basis_kinds len {}",
            atom_dim.len(),
            basis_kinds.len()
        ));
    }
    for ((row, col), &value) in z.indexed_iter() {
        if !value.is_finite() {
            return Err(format!(
                "sae_pca_seed: Z must be finite; Z[{row}, {col}] = {value}"
            ));
        }
    }

    let rows = topology_seed_subsample(z.nrows());
    // The interpolation below looks sampled rows up by binary search, which
    // requires the subsample to be strictly increasing.
    debug_assert!(rows.windows(2).all(|pair| pair[0] < pair[1]));
    let m = rows.len();
    if m < 4 {
        return Ok(None);
    }
    let d_max = atom_dim.iter().copied().max().unwrap_or(1).max(1);
    let k = topology_seed_knn(m, d_max).min(m - 1);

    // Symmetrised kNN affinity matrix with a per-point Gaussian bandwidth equal
    // to the squared distance to the k-th neighbour.
    let mut w = Array2::<f64>::zeros((m, m));
    for (ia, &ra) in rows.iter().enumerate() {
        let mut dists = Vec::with_capacity(m - 1);
        for (ib, &rb) in rows.iter().enumerate() {
            if ia != ib {
                dists.push((squared_distance_rows(z, ra, rb), ib));
            }
        }
        // Ties on distance are broken by index so the graph is deterministic.
        dists.sort_by(|a, b| a.0.total_cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
        let scale = dists[k.saturating_sub(1)].0.max(1.0e-24);
        for &(dist2, ib) in dists.iter().take(k) {
            // The lower clamp keeps every selected edge strictly positive.
            let wij = (-dist2 / scale).exp().max(1.0e-12);
            if wij > w[[ia, ib]] {
                w[[ia, ib]] = wij;
            }
            if wij > w[[ib, ia]] {
                w[[ib, ia]] = wij;
            }
        }
    }

    // Degrees are computed once up front; `w` is not modified past this point,
    // so there is no need to re-sum a row for every edge that touches it.
    let degrees: Vec<f64> = (0..m).map(|i| w.row(i).sum()).collect();
    if degrees.iter().any(|&deg| deg <= 0.0 || !deg.is_finite()) {
        return Ok(None);
    }

    // L = I - D^{-1/2} W D^{-1/2}.
    let mut lap = Array2::<f64>::zeros((m, m));
    for i in 0..m {
        lap[[i, i]] = 1.0;
        let inv_sqrt = 1.0 / degrees[i].sqrt();
        for j in 0..m {
            if i != j && w[[i, j]] != 0.0 {
                lap[[i, j]] = -w[[i, j]] * inv_sqrt / degrees[j].sqrt();
            }
        }
    }
    let (evals, evecs) = lap
        .eigh(Side::Lower)
        .map_err(|err| format!("topology_seed: graph Laplacian eigensolve failed: {err:?}"))?;
    if evals.len() < 3 {
        return Ok(None);
    }

    let mut out = Array3::<f64>::zeros((basis_kinds.len(), z.nrows(), d_max));
    let interp_k = (d_max + 1).max(2).min(m);
    // Extends a function known on the sampled rows to any row of `z`: sampled
    // rows return their own value, other rows an inverse-squared-distance
    // weighted average over their `interp_k` nearest sampled rows.
    let interp = |sample_values: &Array1<f64>, row: usize| -> f64 {
        if let Ok(pos) = rows.binary_search(&row) {
            return sample_values[pos];
        }
        // `best` is kept sorted ascending; its last slot is the current worst.
        let mut best: Vec<(f64, usize)> = vec![(f64::INFINITY, 0usize); interp_k];
        for (i, &r) in rows.iter().enumerate() {
            let d = squared_distance_rows(z, row, r);
            if d < best[interp_k - 1].0 {
                best[interp_k - 1] = (d, i);
                best.sort_by(|a, b| a.0.total_cmp(&b.0));
            }
        }
        let mut num = 0.0;
        let mut den = 0.0;
        for (d, i) in best {
            let ww = 1.0 / d.max(1.0e-24);
            num += ww * sample_values[i];
            den += ww;
        }
        num / den
    };

    // Column 0 is skipped.
    // NOTE(review): this assumes `eigh` returns eigenvalues in ascending order,
    // making column 0 the trivial mode of a connected graph — confirm.
    let harmonic: Vec<ArrayView1<'_, f64>> =
        (1..evecs.ncols()).map(|c| evecs.column(c)).collect();
    let n_harm = harmonic.len();
    let starts = topology_seed_harmonic_starts(basis_kinds, atom_dim);

    for atom_idx in 0..basis_kinds.len() {
        let d = atom_dim[atom_idx];
        let kind = &basis_kinds[atom_idx];
        let need = topology_seed_chart_need(kind, d);
        if need == 0 || n_harm == 0 {
            continue;
        }
        let start = starts[atom_idx];
        // On the first attempt an atom reads its own window of eigenvectors
        // when that window fits; otherwise (or on retries) it gets
        // pseudo-random orthonormal combinations of all of them.
        let canonical = pc_pair_offset == 0 && start + need <= n_harm;
        let fns: Vec<Array1<f64>> = if canonical {
            (0..need).map(|i| harmonic[start + i].to_owned()).collect()
        } else {
            generic_ortho_combos(&harmonic, atom_idx, pc_pair_offset, basis_kinds.len(), need)
        };
        if fns.is_empty() {
            continue;
        }
        match kind {
            SaeAtomBasisKind::Periodic => {
                if fns.len() < 2 {
                    continue;
                }
                // Axis 0: angle of the first function pair, wrapped to [0, 1).
                for row in 0..z.nrows() {
                    let phase =
                        interp(&fns[1], row).atan2(interp(&fns[0], row)) / std::f64::consts::TAU;
                    out[[atom_idx, row, 0]] = phase - phase.floor();
                }
                // Remaining axes: one function each, rescaled by the range it
                // takes on the sampled points.
                for axis in 1..d {
                    let Some(values) = fns.get(axis + 1) else {
                        break;
                    };
                    let (lo, hi) = value_range(values);
                    let span = hi - lo;
                    if span > 0.0 && span.is_finite() {
                        for row in 0..z.nrows() {
                            out[[atom_idx, row, axis]] = (interp(values, row) - lo) / span - 0.5;
                        }
                    }
                }
            }
            SaeAtomBasisKind::Torus => {
                // One function pair per axis, each turned into a wrapped phase.
                for axis in 0..d {
                    let (Some(a), Some(b)) = (fns.get(2 * axis), fns.get(2 * axis + 1)) else {
                        break;
                    };
                    for row in 0..z.nrows() {
                        let phase = interp(b, row).atan2(interp(a, row)) / std::f64::consts::TAU;
                        out[[atom_idx, row, axis]] = phase - phase.floor();
                    }
                }
            }
            SaeAtomBasisKind::Sphere => {
                if fns.len() < 3 {
                    continue;
                }
                // Three functions are read as a 3-vector and converted to
                // latitude (axis 0) and longitude (axis 1).
                for row in 0..z.nrows() {
                    let x = interp(&fns[0], row);
                    let y = interp(&fns[1], row);
                    let zz = interp(&fns[2], row);
                    let norm = (x * x + y * y + zz * zz).sqrt().max(1.0e-24);
                    if d >= 1 {
                        out[[atom_idx, row, 0]] = (zz / norm).clamp(-1.0, 1.0).asin();
                    }
                    if d >= 2 {
                        out[[atom_idx, row, 1]] = y.atan2(x);
                    }
                }
            }
            _ => {
                // Other kinds: one rescaled function per axis.
                for axis in 0..d {
                    let Some(values) = fns.get(axis) else {
                        break;
                    };
                    let (lo, hi) = value_range(values);
                    let span = hi - lo;
                    if span > 0.0 && span.is_finite() {
                        for row in 0..z.nrows() {
                            out[[atom_idx, row, axis]] = (interp(values, row) - lo) / span - 0.5;
                        }
                    }
                }
            }
        }
    }
    Ok(Some(out))
}
/// Number of seed functions an atom of the given kind and dimension consumes.
///
/// With `dims = max(d, 1)`: a sphere needs 3, a torus `2 * dims`, a periodic
/// atom `2 + (dims - 1)`, and every other kind `dims`.
fn topology_seed_chart_need(kind: &SaeAtomBasisKind, d: usize) -> usize {
    let dims = d.max(1);
    match kind {
        SaeAtomBasisKind::Sphere => 3,
        SaeAtomBasisKind::Torus => 2 * dims,
        SaeAtomBasisKind::Periodic => 2 + dims.saturating_sub(1),
        _ => dims,
    }
}
/// Start offset of each atom's window of seed functions.
///
/// Atom `i` starts where the windows of atoms `0..i` end, i.e. at the
/// saturating running sum of their `topology_seed_chart_need` values. The
/// result has one entry per `(kind, dim)` pair, so it is as long as the shorter
/// of the two input slices.
fn topology_seed_harmonic_starts(
    basis_kinds: &[SaeAtomBasisKind],
    atom_dim: &[usize],
) -> Vec<usize> {
    basis_kinds
        .iter()
        .zip(atom_dim)
        .scan(0usize, |cursor, (kind, &d)| {
            let start = *cursor;
            *cursor = cursor.saturating_add(topology_seed_chart_need(kind, d));
            Some(start)
        })
        .collect()
}
/// Hashes `z` with a SplitMix64-style mixer and maps the 64-bit result
/// linearly onto the closed interval `[-1, 1]`.
fn splitmix_unit(z: u64) -> f64 {
    let seeded = z.wrapping_add(0x9E3779B97F4A7C15);
    let stage_one = (seeded ^ (seeded >> 30)).wrapping_mul(0xBF58476D1CE4E5B9);
    let stage_two = (stage_one ^ (stage_one >> 27)).wrapping_mul(0x94D049BB133111EB);
    let mixed = stage_two ^ (stage_two >> 31);
    (mixed as f64 / u64::MAX as f64) * 2.0 - 1.0
}
/// Per-slot salt XOR-ed into the hash key in `generic_ortho_combos`.
///
/// Slot 0 maps to 0 (the key is left unsalted), slot 1 to a fixed constant,
/// and every later slot to a multiply/xor-shift hash of its index.
fn slot_salt(slot: usize) -> u64 {
    if slot == 0 {
        return 0;
    }
    if slot == 1 {
        return 0xD1B54A32D192ED03;
    }
    let seeded = (slot as u64).wrapping_mul(0x9E3779B97F4A7C15) ^ 0xA0761D6478BD642F;
    let folded = (seeded ^ (seeded >> 32)).wrapping_mul(0xE7037ED1A0B428DB);
    folded ^ (folded >> 29)
}
/// Builds up to `m` orthonormal pseudo-random linear combinations of `sources`.
///
/// The weight of source `pc` in slot `slot` is
/// `splitmix_unit((base ^ pc) ^ slot_salt(slot))` with
/// `base = (atom_idx + retry * k_atoms) << 20`, so the output is a
/// deterministic function of the arguments. Each combination is orthogonalised
/// against the directions already accepted (one Gram-Schmidt pass) and kept
/// only if its remaining norm exceeds `SURPLUS_DIR_FLOOR`; the result may
/// therefore hold fewer than `m` vectors. Returns an empty vector when
/// `sources` is empty or `m == 0`.
fn generic_ortho_combos(
    sources: &[ArrayView1<'_, f64>],
    atom_idx: usize,
    retry: usize,
    k_atoms: usize,
    m: usize,
) -> Vec<Array1<f64>> {
    let Some(first) = sources.first() else {
        return Vec::new();
    };
    if m == 0 {
        return Vec::new();
    }
    let dim = first.len();
    let key_base = ((atom_idx as u64) + (retry as u64) * (k_atoms as u64)) << 20;
    let mut basis: Vec<Array1<f64>> = Vec::with_capacity(m);
    for slot in 0..m {
        let salt = slot_salt(slot);
        // Pseudo-random mixture of all sources.
        let mut candidate = Array1::<f64>::zeros(dim);
        for (pc, src) in sources.iter().enumerate() {
            let weight = splitmix_unit((key_base ^ (pc as u64)) ^ salt);
            candidate.scaled_add(weight, src);
        }
        // Remove the components along the directions accepted so far.
        for accepted in &basis {
            let overlap = candidate.dot(accepted);
            candidate.scaled_add(-overlap, accepted);
        }
        let norm = candidate.dot(&candidate).sqrt();
        if norm > SURPLUS_DIR_FLOOR {
            candidate.mapv_inplace(|x| x / norm);
            basis.push(candidate);
        }
    }
    basis
}
/// Draws a pseudo-random 2-plane from the row space of `vt` for a surplus atom.
///
/// Returns `(dir1, dir2, is_plane)`. `is_plane` is `true` only when
/// `generic_ortho_combos` produced two directions; any direction it could not
/// produce is returned as a zero vector of length `vt.ncols()`.
fn surplus_phase_plane(
    vt: ArrayView2<'_, f64>,
    atom_idx: usize,
    pc_pair_offset: usize,
    k_atoms: usize,
) -> (Array1<f64>, Array1<f64>, bool) {
    let zero_dir = || Array1::<f64>::zeros(vt.ncols());
    let sources: Vec<ArrayView1<'_, f64>> = (0..vt.nrows()).map(|r| vt.row(r)).collect();
    let mut dirs =
        generic_ortho_combos(&sources, atom_idx, pc_pair_offset, k_atoms, 2).into_iter();
    match (dirs.next(), dirs.next()) {
        (Some(first), Some(second)) => (first, second, true),
        (Some(first), None) => (first, zero_dir(), false),
        _ => (zero_dir(), zero_dir(), false),
    }
}
/// Computes initial atom coordinates for `z` with a retry offset of zero.
///
/// Equivalent to calling [`sae_pca_seed_initial_coords_with_pc_offset`] with
/// `pc_pair_offset = 0`; the result and any error are passed through unchanged.
pub fn sae_pca_seed_initial_coords(
    z: ArrayView2<'_, f64>,
    basis_kinds: &[SaeAtomBasisKind],
    atom_dim: &[usize],
) -> Result<Array3<f64>, String> {
    let first_attempt_offset = 0usize;
    sae_pca_seed_initial_coords_with_pc_offset(z, basis_kinds, atom_dim, first_attempt_offset)
}
pub fn sae_pca_seed_initial_coords_with_pc_offset(
z: ArrayView2<'_, f64>,
basis_kinds: &[SaeAtomBasisKind],
atom_dim: &[usize],
pc_pair_offset: usize,
) -> Result<Array3<f64>, String> {
if let Some(seed) =
topology_curved_seed_initial_coords(z, basis_kinds, atom_dim, pc_pair_offset)?
{
return Ok(seed);
}
let k_atoms = basis_kinds.len();
let (n_obs, _p_out) = z.dim();
let d_max = atom_dim.iter().copied().max().unwrap_or(1).max(1);
let mut out = Array3::<f64>::zeros((k_atoms, n_obs, d_max));
if n_obs == 0 || z.ncols() == 0 {
return Ok(out);
}
for ((row, col), &value) in z.indexed_iter() {
if !value.is_finite() {
return Err(format!(
"sae_pca_seed: Z must be finite; Z[{row}, {col}] = {value}"
));
}
}
let mut col_means = Array1::<f64>::zeros(z.ncols());
for col in 0..z.ncols() {
let mut mean = 0.0_f64;
for (count, row) in (0..n_obs).enumerate() {
let x = z[[row, col]];
mean += (x - mean) / (count as f64 + 1.0);
}
col_means[col] = mean;
}
let mut centered = z.to_owned();
for row in 0..n_obs {
for col in 0..z.ncols() {
centered[[row, col]] -= col_means[col];
}
}
for ((row, col), &value) in centered.indexed_iter() {
if !value.is_finite() {
return Err(format!(
"sae_pca_seed: centered Z is non-finite at [{row}, {col}] \
(data span exceeds f64 range); rescale Z before seeding"
));
}
}
let (u_opt, s_vals, vt_opt) = centered
.svd(true, true)
.map_err(|err| format!("sae_pca_seed: SVD failed: {err:?}"))?;
let u = u_opt.ok_or_else(|| "sae_pca_seed: SVD returned no U".to_string())?;
let vt = vt_opt.ok_or_else(|| "sae_pca_seed: SVD returned no Vt".to_string())?;
let vt_rows = vt.nrows();
let u_cols = u.ncols();
let two_pi = std::f64::consts::TAU;
for atom_idx in 0..k_atoms {
let d = atom_dim[atom_idx];
if d == 0 {
continue;
}
match &basis_kinds[atom_idx] {
SaeAtomBasisKind::Periodic => {
if vt_rows >= 1 {
let pc_pairs = vt_rows / 2;
let (pc1_row, pc2_row) = if pc_pairs >= 1 {
let pair = (atom_idx + pc_pair_offset) % pc_pairs;
(2 * pair, 2 * pair + 1)
} else {
(0, 0)
};
let surplus = pc_pairs >= 1 && k_atoms > pc_pairs && atom_idx >= pc_pairs;
let phase_offset = if pc_pairs > 0 && pc_pairs < k_atoms {
atom_idx as f64 / k_atoms as f64
+ pc_pair_offset as f64 * GOLDEN_RATIO_CONJUGATE
} else {
0.0
};
let s0 = s_vals.get(pc1_row).copied().unwrap_or(0.0).abs();
let s1 = s_vals.get(pc2_row).copied().unwrap_or(0.0).abs();
let has_two_dimensional_phase =
vt_rows >= 2 && pc2_row != pc1_row && s1 > 1.0e-10 * s0.max(1.0);
let mut two_dimensional_phase = has_two_dimensional_phase;
let dir1: Array1<f64>;
let dir2: Array1<f64>;
if surplus {
let (a, b, two_d) =
surplus_phase_plane(vt.view(), atom_idx, pc_pair_offset, k_atoms);
dir1 = a;
dir2 = b;
two_dimensional_phase = two_d;
} else {
dir1 = vt.row(pc1_row.min(vt_rows - 1)).to_owned();
dir2 = vt.row(pc2_row.min(vt_rows - 1)).to_owned();
}
let pc1 = dir1.view();
if two_dimensional_phase {
let pc2 = dir2.view();
for row in 0..n_obs {
let mut a = 0.0_f64;
let mut b = 0.0_f64;
for col in 0..centered.ncols() {
a += centered[[row, col]] * pc1[col];
b += centered[[row, col]] * pc2[col];
}
let phase = b.atan2(a) / two_pi + phase_offset;
out[[atom_idx, row, 0]] = phase - phase.floor();
}
} else {
let mut proj = Array1::<f64>::zeros(n_obs);
for row in 0..n_obs {
let mut acc = 0.0_f64;
for col in 0..centered.ncols() {
acc += centered[[row, col]] * pc1[col];
}
proj[row] = acc;
}
let (min_v, max_v) = proj
.iter()
.fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
(lo.min(v), hi.max(v))
});
let span = max_v - min_v;
if span > 0.0 {
for row in 0..n_obs {
let phase = (proj[row] - min_v) / span + phase_offset;
out[[atom_idx, row, 0]] = phase - phase.floor();
}
}
}
}
for axis in 1..d {
if axis >= vt_rows {
break;
}
let pc = vt.row(axis);
let mut proj = Array1::<f64>::zeros(n_obs);
for row in 0..n_obs {
let mut acc = 0.0_f64;
for col in 0..centered.ncols() {
acc += centered[[row, col]] * pc[col];
}
proj[row] = acc;
}
let (min_v, max_v) = proj
.iter()
.fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
(lo.min(v), hi.max(v))
});
let span = max_v - min_v;
if span > 0.0 {
for row in 0..n_obs {
out[[atom_idx, row, axis]] = (proj[row] - min_v) / span - 0.5;
}
}
}
}
SaeAtomBasisKind::Sphere => {
let n_pc = vt_rows.min(3);
if n_pc == 0 {
continue;
}
let atom_start = atom_idx.saturating_mul(3);
let canonical = pc_pair_offset == 0 && atom_start + 3 <= vt_rows;
let frame: Vec<Array1<f64>> = if canonical {
(0..n_pc)
.map(|i| vt.row(atom_start + i).to_owned())
.collect()
} else {
let sources: Vec<ArrayView1<'_, f64>> =
(0..vt_rows).map(|r| vt.row(r)).collect();
let dirs = generic_ortho_combos(&sources, atom_idx, pc_pair_offset, k_atoms, 3);
if dirs.is_empty() {
continue;
}
dirs
};
for row in 0..n_obs {
let mut amb = [0.0_f64; 3];
for (i, pc) in frame.iter().enumerate().take(3) {
let mut acc = 0.0_f64;
for col in 0..centered.ncols() {
acc += centered[[row, col]] * pc[col];
}
amb[i] = acc;
}
let norm = (amb[0] * amb[0] + amb[1] * amb[1] + amb[2] * amb[2]).sqrt();
let (x, y, z) = if norm > 0.0 {
(amb[0] / norm, amb[1] / norm, amb[2] / norm)
} else {
(1.0, 0.0, 0.0)
};
let lat = z.clamp(-1.0, 1.0).asin();
let lon = y.atan2(x);
if d >= 1 {
out[[atom_idx, row, 0]] = lat;
}
if d >= 2 {
out[[atom_idx, row, 1]] = lon;
}
}
}
SaeAtomBasisKind::Torus => {
let pc_pairs = vt_rows / 2;
let atom_start = atom_idx.saturating_mul(d);
let canonical = pc_pair_offset == 0 && pc_pairs > 0 && atom_start + d <= pc_pairs;
let generic_dirs: Vec<Array1<f64>> = if canonical {
Vec::new()
} else {
let sources: Vec<ArrayView1<'_, f64>> =
(0..vt_rows).map(|r| vt.row(r)).collect();
generic_ortho_combos(&sources, atom_idx, pc_pair_offset, k_atoms, 2 * d)
};
for axis in 0..d {
let (pc_a, pc_b): (Array1<f64>, Array1<f64>) = if canonical {
let pair = atom_start + axis;
let pc_b_idx = 2 * pair + 1;
if pc_b_idx >= vt_rows {
break;
}
(vt.row(2 * pair).to_owned(), vt.row(pc_b_idx).to_owned())
} else {
match (generic_dirs.get(2 * axis), generic_dirs.get(2 * axis + 1)) {
(Some(a), Some(b)) => (a.clone(), b.clone()),
_ => break,
}
};
for row in 0..n_obs {
let mut a = 0.0_f64;
let mut b = 0.0_f64;
for col in 0..centered.ncols() {
a += centered[[row, col]] * pc_a[col];
b += centered[[row, col]] * pc_b[col];
}
let phase = b.atan2(a) / two_pi;
let wrapped = phase - phase.floor();
out[[atom_idx, row, axis]] = wrapped;
}
}
}
_ => {
let avail = u_cols.min(s_vals.len());
let atom_start = atom_idx.saturating_mul(d);
let canonical = pc_pair_offset == 0 && avail > 0 && atom_start + d <= avail;
if canonical {
let k_cols = d.min(avail);
let mut tmp = Array2::<f64>::zeros((n_obs, d));
for col in 0..k_cols {
let src = atom_start + col;
let s_col = s_vals[src];
for row in 0..n_obs {
tmp[[row, col]] = u[[row, src]] * s_col;
}
}
for col in 0..d {
let mut min_v = f64::INFINITY;
let mut max_v = f64::NEG_INFINITY;
for row in 0..n_obs {
let v = tmp[[row, col]];
if v < min_v {
min_v = v;
}
if v > max_v {
max_v = v;
}
}
let span = max_v - min_v;
if span > 0.0 {
for row in 0..n_obs {
out[[atom_idx, row, col]] = (tmp[[row, col]] - min_v) / span - 0.5;
}
}
}
} else {
let sources: Vec<ArrayView1<'_, f64>> =
(0..vt_rows).map(|r| vt.row(r)).collect();
let dirs = generic_ortho_combos(&sources, atom_idx, pc_pair_offset, k_atoms, d);
for (col, dir) in dirs.iter().enumerate().take(d) {
let mut proj = Array1::<f64>::zeros(n_obs);
for row in 0..n_obs {
let mut acc = 0.0_f64;
for c in 0..centered.ncols() {
acc += centered[[row, c]] * dir[c];
}
proj[row] = acc;
}
let (min_v, max_v) = proj
.iter()
.fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
(lo.min(v), hi.max(v))
});
let span = max_v - min_v;
if span > 0.0 {
for row in 0..n_obs {
out[[atom_idx, row, col]] = (proj[row] - min_v) / span - 0.5;
}
}
}
}
}
}
}
Ok(out)
}
/// Seeds atom coordinates from inner products with anchor rows of `residual`.
///
/// For atom `slot` and axis `axis`, the anchor is row
/// `(anchor_rows[slot] % n_obs + axis) % n_obs`. Every row's inner product
/// with that anchor is computed and the results are rescaled affinely to
/// `[-0.5, 0.5]`. An axis whose inner products are all equal is left at zero.
///
/// # Returns
/// An array of shape `(atom_dim.len(), n_obs, d_max)` with `d_max` the largest
/// entry of `atom_dim` (at least 1). It is all zeros when `residual` has no
/// rows or no columns; in that case `anchor_rows` is not inspected.
///
/// # Errors
/// * `anchor_rows.len()` differs from `atom_dim.len()`.
/// * `residual` contains a non-finite value.
pub fn sae_data_row_anchored_euclidean_coords(
    residual: ArrayView2<'_, f64>,
    atom_dim: &[usize],
    anchor_rows: &[usize],
) -> Result<Array3<f64>, String> {
    let (n_obs, p_out) = residual.dim();
    let k_atoms = atom_dim.len();
    let d_max = atom_dim.iter().copied().max().unwrap_or(1).max(1);
    let mut out = Array3::<f64>::zeros((k_atoms, n_obs, d_max));
    if n_obs == 0 || p_out == 0 {
        return Ok(out);
    }
    if anchor_rows.len() != k_atoms {
        return Err(format!(
            "sae_data_row_anchored_euclidean_coords: anchor_rows len {} != atoms {k_atoms}",
            anchor_rows.len()
        ));
    }
    if let Some(((row, col), value)) = residual
        .indexed_iter()
        .find(|(_, value)| !value.is_finite())
    {
        return Err(format!(
            "sae_data_row_anchored_euclidean_coords: residual must be finite; \
             residual[{row}, {col}] = {value}"
        ));
    }

    // Scratch buffer for one axis worth of inner products, reused throughout.
    let mut sim = vec![0.0_f64; n_obs];
    for (slot, (&d, &anchor_seed)) in atom_dim.iter().zip(anchor_rows).enumerate() {
        if d == 0 {
            continue;
        }
        let base = anchor_seed % n_obs;
        for axis in 0..d {
            // Consecutive axes use consecutive rows, wrapping around the data.
            let anchor = (base + axis) % n_obs;
            let mut lowest = f64::INFINITY;
            let mut highest = f64::NEG_INFINITY;
            for (i, cell) in sim.iter_mut().enumerate() {
                let dot = (0..p_out).fold(0.0_f64, |acc, col| {
                    acc + residual[[i, col]] * residual[[anchor, col]]
                });
                *cell = dot;
                if dot < lowest {
                    lowest = dot;
                }
                if dot > highest {
                    highest = dot;
                }
            }
            let span = highest - lowest;
            if span > 0.0 {
                for (i, &value) in sim.iter().enumerate() {
                    out[[slot, i, axis]] = (value - lowest) / span - 0.5;
                }
            }
        }
    }
    Ok(out)
}
// Unit tests for the seeding helpers and the public seed functions.
#[cfg(test)]
mod tests {
use super::*;
// Deterministic dense `rows x cols` matrix used as a stand-in for the `Vt`
// factor; entries mix a sine term with a linear term in the indices.
fn make_vt(rows: usize, cols: usize) -> Array2<f64> {
let mut m = Array2::<f64>::zeros((rows, cols));
for r in 0..rows {
for c in 0..cols {
m[[r, c]] = ((r * 7 + c * 3 + 1) as f64).sin() + 0.1 * (r as f64 - c as f64);
}
}
m
}
// Window starts must be the running sum of per-atom needs
// (Torus d=2 -> 4, Periodic d=1 -> 2, Sphere -> 3, Linear d=3 -> 3).
#[test]
fn topology_harmonic_windows_are_cumulative_for_mixed_kinds() {
let kinds = vec![
SaeAtomBasisKind::Torus,
SaeAtomBasisKind::Periodic,
SaeAtomBasisKind::Sphere,
SaeAtomBasisKind::Linear,
];
let dims = vec![2usize, 1, 2, 3];
assert_eq!(
topology_seed_harmonic_starts(&kinds, &dims),
vec![0, 4, 6, 9],
"mixed atom kinds must allocate canonical harmonic windows cumulatively; \
atom_idx * per-atom-need overlaps when need varies"
);
}
// A different retry offset must change the first direction, and the two
// directions of one plane must be orthonormal.
#[test]
fn surplus_plane_differs_across_retries() {
let vt = make_vt(6, 5);
let k_atoms = 8;
let atom_idx = 5;
let (d1_0, d2_0, ok0) = surplus_phase_plane(vt.view(), atom_idx, 0, k_atoms);
let (d1_1, _d2_1, ok1) = surplus_phase_plane(vt.view(), atom_idx, 1, k_atoms);
assert!(ok0 && ok1, "well-conditioned vt should give a 2-D plane");
let diff = d1_0
.iter()
.zip(d1_1.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0_f64, f64::max);
assert!(
diff > 1e-6,
"distinct retry offsets must yield distinct dir1 (max diff {diff:.3e})"
);
let dot: f64 = d1_0.iter().zip(d2_0.iter()).map(|(a, b)| a * b).sum();
assert!(
dot.abs() < 1e-9,
"dir2 must be orthogonal to dir1 (dot {dot:.3e})"
);
let n2: f64 = d2_0.dot(&d2_0).sqrt();
assert!(
(n2 - 1.0).abs() < 1e-9,
"dir2 must be unit-normalized (norm {n2})"
);
}
// With retry offset 0, slot 0 is unsalted, so dir1 must equal a direction
// rebuilt here from the key `(atom_idx << 20) ^ pc` with a local copy of the
// mixer.
#[test]
fn surplus_plane_offset_zero_matches_original_key() {
let vt = make_vt(5, 4);
let k_atoms = 7;
let atom_idx = 6;
// Local re-implementation of `splitmix_unit`, kept separate on purpose so
// the test pins the hash itself.
let mix = |mut z: u64| -> f64 {
z = z.wrapping_add(0x9E3779B97F4A7C15);
z = (z ^ (z >> 30)).wrapping_mul(0xBF58476D1CE4E5B9);
z = (z ^ (z >> 27)).wrapping_mul(0x94D049BB133111EB);
z ^= z >> 31;
(z as f64 / u64::MAX as f64) * 2.0 - 1.0
};
let ncols = vt.ncols();
let mut a = Array1::<f64>::zeros(ncols);
for pc in 0..vt.nrows() {
let row_pc = vt.row(pc);
let key = ((atom_idx as u64) << 20) ^ (pc as u64);
let wa = mix(key);
for c in 0..ncols {
a[c] += wa * row_pc[c];
}
}
let na = a.dot(&a).sqrt().max(1.0e-12);
a.mapv_inplace(|v| v / na);
let (d1, _d2, _ok) = surplus_phase_plane(vt.view(), atom_idx, 0, k_atoms);
for (r, o) in d1.iter().zip(a.iter()) {
assert!(
(r - o).abs() < 1e-15,
"offset-0 dir1 must match the original key bit-for-bit ({r} vs {o})"
);
}
}
// With a single source row no second independent direction exists, so the
// plane flag must be false while dir1 stays a unit vector.
#[test]
fn surplus_plane_collinear_falls_back_to_1d() {
let vt = make_vt(1, 4);
let (d1, _d2, ok) = surplus_phase_plane(vt.view(), 3, 0, 5);
assert!(!ok, "rank-1 span must report a 1-D (non-2-plane) result");
let n1: f64 = d1.dot(&d1).sqrt();
assert!(
(n1 - 1.0).abs() < 1e-9 && n1.is_finite(),
"dir1 must stay a finite unit vector (norm {n1})"
);
}
// Offset 0 must reproduce the plain entry point exactly, all phases must lie
// in [0, 1), and a different offset must move a surplus atom's seed.
#[test]
fn surplus_periodic_seed_reproduces_and_diversifies() {
let n_obs = 8;
let p = 4;
let mut zvals = Vec::with_capacity(n_obs * p);
for r in 0..n_obs {
for c in 0..p {
zvals.push(
((r as f64) * 0.9 + 1.0).sin() * ((c + 1) as f64)
+ 0.3 * ((r * c) as f64).cos(),
);
}
}
let z = Array2::from_shape_vec((n_obs, p), zvals).unwrap();
let kinds = vec![SaeAtomBasisKind::Periodic; 5];
let dims = vec![1usize; 5];
let s0 = sae_pca_seed_initial_coords_with_pc_offset(z.view(), &kinds, &dims, 0).unwrap();
let s1 = sae_pca_seed_initial_coords_with_pc_offset(z.view(), &kinds, &dims, 1).unwrap();
let plain = sae_pca_seed_initial_coords(z.view(), &kinds, &dims).unwrap();
assert_eq!(
s0, plain,
"offset-0 must equal the no-offset seed bit-for-bit"
);
for v in s0.iter().chain(s1.iter()) {
assert!(
v.is_finite() && *v >= 0.0 && *v < 1.0,
"periodic phase must be finite in [0, 1): {v}"
);
}
let surplus_atom = 3;
let mut moved = 0.0_f64;
for row in 0..n_obs {
moved = moved.max((s0[[surplus_atom, row, 0]] - s1[[surplus_atom, row, 0]]).abs());
}
assert!(
moved > 1e-6,
"surplus atom must land on a distinct basin across retries (max move {moved:.3e})"
);
}
// Each of the `p` anchor rows tried must give a distinct seed, and every
// coordinate must stay within [-0.5, 0.5].
#[test]
fn data_row_anchored_seed_diversifies_beyond_pc_pool() {
let n = 64usize;
let p = 4usize; let mut residual = Array2::<f64>::zeros((n, p));
for i in 0..n {
for j in 0..p {
residual[[i, j]] = ((i * 7 + j) as f64).sin() + 0.25 * ((i + 3 * j) as f64).cos();
}
}
let dims = vec![1usize];
let mut seeds = std::collections::HashSet::new();
for anchor in 0..p {
let s = sae_data_row_anchored_euclidean_coords(residual.view(), &dims, &[anchor])
.expect("data-row seed");
for v in s.iter() {
assert!(
v.is_finite() && *v >= -0.5 - 1e-12 && *v <= 0.5 + 1e-12,
"seed coord must stay in [-0.5, 0.5]: {v}"
);
}
// Quantise to 1e-6 so seeds can be compared through a hash set.
let key: Vec<i64> = (0..n)
.map(|i| (s[[0, i, 0]] * 1e6).round() as i64)
.collect();
seeds.insert(key);
}
assert!(
seeds.len() >= p,
"data-row anchors must give ≥ p distinct seeds; the PCA reseed caps at \
pc_pairs = {} — got {} distinct",
(n.min(p)) / 2,
seeds.len()
);
}
// Deterministic `n x p` data matrix built from sine, cosine and linear terms.
fn structured_z(n: usize, p: usize) -> Array2<f64> {
let mut zvals = Vec::with_capacity(n * p);
for r in 0..n {
for c in 0..p {
zvals.push(
((r as f64) * 0.37 + (c as f64) * 1.1).sin() * ((c + 1) as f64)
+ 0.3 * (((r * 3 + c * 5) as f64) * 0.21).cos()
+ 0.05 * (r as f64 - c as f64),
);
}
}
Array2::from_shape_vec((n, p), zvals).unwrap()
}
// Counts atoms whose full coordinate block (all rows and axes, quantised to
// 1e-6) differs from every other atom's.
fn distinct_fibers(seed: &Array3<f64>) -> usize {
let (k, n, dm) = seed.dim();
let mut set = std::collections::HashSet::new();
for atom in 0..k {
let mut key: Vec<i64> = Vec::with_capacity(n * dm);
for row in 0..n {
for ax in 0..dm {
key.push((seed[[atom, row, ax]] * 1.0e6).round() as i64);
}
}
set.insert(key);
}
set.len()
}
// With 64 rows and curved kinds the topology seed is eligible; every atom
// must receive a distinct design even at 4x and 40x overcompleteness.
#[test]
fn overcomplete_topology_seeds_pairwise_distinct_all_curved() {
let n = 64usize;
let p = 6usize;
let z = structured_z(n, p);
for (kind, d) in [
(SaeAtomBasisKind::Periodic, 1usize),
(SaeAtomBasisKind::Torus, 2usize),
(SaeAtomBasisKind::Sphere, 2usize),
] {
for mult in [4usize, 40usize] {
let k = mult * p;
let kinds = vec![kind.clone(); k];
let dims = vec![d; k];
let seed = sae_pca_seed_initial_coords(z.view(), &kinds, &dims).unwrap();
for v in seed.iter() {
assert!(v.is_finite(), "{kind:?} K={k}: non-finite seed coord {v}");
}
let distinct = distinct_fibers(&seed);
assert_eq!(
distinct, k,
"{kind:?} K={k} (={mult}·p): every atom's chart must be a distinct \
design — got {distinct}/{k} distinct (duplicate designs ⇒ exact \
Hessian null ⇒ co-collapse)"
);
}
}
}
// Linear atoms contain no curved kind, so the topology seed declines and the
// PCA path runs; surplus atoms must still be pairwise distinct.
#[test]
fn overcomplete_flat_linear_seeds_pairwise_distinct() {
let n = 64usize;
let p = 6usize; let z = structured_z(n, p);
for mult in [4usize, 40usize] {
let k = mult * p;
let kinds = vec![SaeAtomBasisKind::Linear; k];
let dims = vec![1usize; k];
let seed = sae_pca_seed_initial_coords(z.view(), &kinds, &dims).unwrap();
let distinct = distinct_fibers(&seed);
assert_eq!(
distinct, k,
"flat K={k} (={mult}·p): surplus atoms must not wrap onto duplicate \
principal-score designs — got {distinct}/{k} distinct"
);
}
}
// Only 3 rows (fewer than the 4 the topology seed requires), so curved kinds
// go through the PCA path; atoms must remain finite and pairwise distinct.
#[test]
fn overcomplete_linear_curved_fallback_distinct() {
let n = 3usize; let p = 6usize;
let z = structured_z(n, p);
for (kind, d) in [
(SaeAtomBasisKind::Periodic, 1usize),
(SaeAtomBasisKind::Torus, 2usize),
(SaeAtomBasisKind::Sphere, 2usize),
] {
let k = 4 * p;
let kinds = vec![kind.clone(); k];
let dims = vec![d; k];
let seed = sae_pca_seed_initial_coords(z.view(), &kinds, &dims).unwrap();
for v in seed.iter() {
assert!(v.is_finite(), "{kind:?} linear K={k}: non-finite {v}");
}
let distinct = distinct_fibers(&seed);
assert_eq!(
distinct, k,
"{kind:?} linear-fallback K={k}: surplus atoms must be pairwise \
distinct — got {distinct}/{k}"
);
}
}
}