use super::SaeAtomBasisKind;
use gam_linalg::faer_ndarray::FaerSvd;
use ndarray::{Array1, Array2, Array3, ArrayView2};
/// Minimum norm the second surplus direction must retain after being orthogonalized against the
/// first one for `surplus_phase_plane` to report a usable 2-D phase plane.
const SURPLUS_DIR_FLOOR: f64 = 1.0e-6;
/// Conjugate golden ratio `1/φ = φ - 1`. Multiples of it (scaled by the retry offset) are added
/// to periodic seed phases; presumably chosen because `k/φ mod 1` spreads retries evenly around
/// the circle.
const GOLDEN_RATIO_CONJUGATE: f64 = 0.618_033_988_749_894_9;
/// Builds a pseudo-random phase plane for a "surplus" periodic atom, i.e. one that has no PC
/// pair of its own.
///
/// Both directions are random linear combinations of the rows of `vt`. Their weights come
/// deterministically from a SplitMix64 hash of `(atom_idx, pc_pair_offset, k_atoms, pc)`. The
/// same atom and retry offset therefore always yield the same plane, and different offsets
/// yield different ones. `dir1` is normalized (unless it is numerically zero). `dir2` is
/// Gram–Schmidt-orthogonalized against `dir1`.
///
/// Returns `(dir1, dir2, is_plane)`. When `is_plane` is `false`, the residual of `dir2` fell to
/// `SURPLUS_DIR_FLOOR` or below (e.g. `vt` has a single row). In that case `dir2` is neither
/// unit-length nor meaningful, and callers should fall back to a 1-D phase along `dir1`.
fn surplus_phase_plane(
    vt: ArrayView2<'_, f64>,
    atom_idx: usize,
    pc_pair_offset: usize,
    k_atoms: usize,
) -> (Array1<f64>, Array1<f64>, bool) {
    let ncols = vt.ncols();
    // Offset 0 reduces to the key `atom_idx << 20`. Wrapping arithmetic keeps huge offsets from
    // panicking on overflow in debug builds; in-range keys are unchanged.
    let atom_key =
        (atom_idx as u64).wrapping_add((pc_pair_offset as u64).wrapping_mul(k_atoms as u64));
    let mut a = Array1::<f64>::zeros(ncols);
    let mut b = Array1::<f64>::zeros(ncols);
    for (pc, row_pc) in vt.outer_iter().enumerate() {
        let key = (atom_key << 20) ^ (pc as u64);
        let wa = splitmix_signed_unit(key);
        let wb = splitmix_signed_unit(key ^ 0xD1B5_4A32_D192_ED03);
        for ((a_c, b_c), &v) in a.iter_mut().zip(b.iter_mut()).zip(row_pc.iter()) {
            *a_c += wa * v;
            *b_c += wb * v;
        }
    }
    let na = a.dot(&a).sqrt().max(1.0e-12);
    a.mapv_inplace(|v| v / na);
    let nb = b.dot(&b).sqrt().max(1.0e-12);
    b.mapv_inplace(|v| v / nb);
    // Gram–Schmidt: remove the component of `b` along `a`.
    let proj = b.dot(&a);
    b.scaled_add(-proj, &a);
    let nb_res = b.dot(&b).sqrt();
    let is_plane = nb_res > SURPLUS_DIR_FLOOR;
    if is_plane {
        b.mapv_inplace(|v| v / nb_res);
    }
    (a, b, is_plane)
}

/// One SplitMix64 step (golden-gamma increment, then mix) on `z`, mapped onto a weight in
/// `[-1, 1]`.
fn splitmix_signed_unit(mut z: u64) -> f64 {
    z = z.wrapping_add(0x9E37_79B9_7F4A_7C15);
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^= z >> 31;
    (z as f64 / u64::MAX as f64) * 2.0 - 1.0
}
/// Seeds initial per-atom coordinates from a PCA of `z` without any retry rotation.
///
/// Shorthand for [`sae_pca_seed_initial_coords_with_pc_offset`] with `pc_pair_offset = 0`.
/// See that function for the output layout and error conditions.
pub fn sae_pca_seed_initial_coords(
    z: ArrayView2<'_, f64>,
    basis_kinds: &[SaeAtomBasisKind],
    atom_dim: &[usize],
) -> Result<Array3<f64>, String> {
    let first_attempt_offset: usize = 0;
    sae_pca_seed_initial_coords_with_pc_offset(z, basis_kinds, atom_dim, first_attempt_offset)
}
pub fn sae_pca_seed_initial_coords_with_pc_offset(
z: ArrayView2<'_, f64>,
basis_kinds: &[SaeAtomBasisKind],
atom_dim: &[usize],
pc_pair_offset: usize,
) -> Result<Array3<f64>, String> {
let k_atoms = basis_kinds.len();
let (n_obs, _p_out) = z.dim();
let d_max = atom_dim.iter().copied().max().unwrap_or(1).max(1);
let mut out = Array3::<f64>::zeros((k_atoms, n_obs, d_max));
if n_obs == 0 || z.ncols() == 0 {
return Ok(out);
}
for ((row, col), &value) in z.indexed_iter() {
if !value.is_finite() {
return Err(format!(
"sae_pca_seed: Z must be finite; Z[{row}, {col}] = {value}"
));
}
}
let mut col_means = Array1::<f64>::zeros(z.ncols());
for col in 0..z.ncols() {
let mut mean = 0.0_f64;
for (count, row) in (0..n_obs).enumerate() {
let x = z[[row, col]];
mean += (x - mean) / (count as f64 + 1.0);
}
col_means[col] = mean;
}
let mut centered = z.to_owned();
for row in 0..n_obs {
for col in 0..z.ncols() {
centered[[row, col]] -= col_means[col];
}
}
for ((row, col), &value) in centered.indexed_iter() {
if !value.is_finite() {
return Err(format!(
"sae_pca_seed: centered Z is non-finite at [{row}, {col}] \
(data span exceeds f64 range); rescale Z before seeding"
));
}
}
let (u_opt, s_vals, vt_opt) = centered
.svd(true, true)
.map_err(|err| format!("sae_pca_seed: SVD failed: {err:?}"))?;
let u = u_opt.ok_or_else(|| "sae_pca_seed: SVD returned no U".to_string())?;
let vt = vt_opt.ok_or_else(|| "sae_pca_seed: SVD returned no Vt".to_string())?;
let vt_rows = vt.nrows();
let u_cols = u.ncols();
let two_pi = std::f64::consts::TAU;
for atom_idx in 0..k_atoms {
let d = atom_dim[atom_idx];
if d == 0 {
continue;
}
match &basis_kinds[atom_idx] {
SaeAtomBasisKind::Periodic => {
if vt_rows >= 1 {
let pc_pairs = vt_rows / 2;
let (pc1_row, pc2_row) = if pc_pairs >= 1 {
let pair = (atom_idx + pc_pair_offset) % pc_pairs;
(2 * pair, 2 * pair + 1)
} else {
(0, 0)
};
let surplus = pc_pairs >= 1 && k_atoms > pc_pairs && atom_idx >= pc_pairs;
let phase_offset = if pc_pairs > 0 && pc_pairs < k_atoms {
atom_idx as f64 / k_atoms as f64
+ pc_pair_offset as f64 * GOLDEN_RATIO_CONJUGATE
} else {
0.0
};
let s0 = s_vals.get(pc1_row).copied().unwrap_or(0.0).abs();
let s1 = s_vals.get(pc2_row).copied().unwrap_or(0.0).abs();
let has_two_dimensional_phase =
vt_rows >= 2 && pc2_row != pc1_row && s1 > 1.0e-10 * s0.max(1.0);
let mut two_dimensional_phase = has_two_dimensional_phase;
let dir1: Array1<f64>;
let dir2: Array1<f64>;
if surplus {
let (a, b, two_d) =
surplus_phase_plane(vt.view(), atom_idx, pc_pair_offset, k_atoms);
dir1 = a;
dir2 = b;
two_dimensional_phase = two_d;
} else {
dir1 = vt.row(pc1_row.min(vt_rows - 1)).to_owned();
dir2 = vt.row(pc2_row.min(vt_rows - 1)).to_owned();
}
let pc1 = dir1.view();
if two_dimensional_phase {
let pc2 = dir2.view();
for row in 0..n_obs {
let mut a = 0.0_f64;
let mut b = 0.0_f64;
for col in 0..centered.ncols() {
a += centered[[row, col]] * pc1[col];
b += centered[[row, col]] * pc2[col];
}
let phase = b.atan2(a) / two_pi + phase_offset;
out[[atom_idx, row, 0]] = phase - phase.floor();
}
} else {
let mut proj = Array1::<f64>::zeros(n_obs);
for row in 0..n_obs {
let mut acc = 0.0_f64;
for col in 0..centered.ncols() {
acc += centered[[row, col]] * pc1[col];
}
proj[row] = acc;
}
let (min_v, max_v) = proj
.iter()
.fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
(lo.min(v), hi.max(v))
});
let span = max_v - min_v;
if span > 0.0 {
for row in 0..n_obs {
let phase = (proj[row] - min_v) / span + phase_offset;
out[[atom_idx, row, 0]] = phase - phase.floor();
}
}
}
}
for axis in 1..d {
if axis >= vt_rows {
break;
}
let pc = vt.row(axis);
let mut proj = Array1::<f64>::zeros(n_obs);
for row in 0..n_obs {
let mut acc = 0.0_f64;
for col in 0..centered.ncols() {
acc += centered[[row, col]] * pc[col];
}
proj[row] = acc;
}
let (min_v, max_v) = proj
.iter()
.fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
(lo.min(v), hi.max(v))
});
let span = max_v - min_v;
if span > 0.0 {
for row in 0..n_obs {
out[[atom_idx, row, axis]] = (proj[row] - min_v) / span - 0.5;
}
}
}
}
SaeAtomBasisKind::Sphere => {
let n_pc = vt_rows.min(3);
if n_pc == 0 {
continue;
}
let base = if vt_rows > 0 {
(2 * pc_pair_offset) % vt_rows
} else {
0
};
let pcs: Vec<_> = (0..n_pc).map(|i| vt.row((base + i) % vt_rows)).collect();
for row in 0..n_obs {
let mut amb = [0.0_f64; 3];
for (i, pc) in pcs.iter().enumerate() {
let mut acc = 0.0_f64;
for col in 0..centered.ncols() {
acc += centered[[row, col]] * pc[col];
}
amb[i] = acc;
}
let norm = (amb[0] * amb[0] + amb[1] * amb[1] + amb[2] * amb[2]).sqrt();
let (x, y, z) = if norm > 0.0 {
(amb[0] / norm, amb[1] / norm, amb[2] / norm)
} else {
(1.0, 0.0, 0.0)
};
let lat = z.clamp(-1.0, 1.0).asin();
let lon = y.atan2(x);
if d >= 1 {
out[[atom_idx, row, 0]] = lat;
}
if d >= 2 {
out[[atom_idx, row, 1]] = lon;
}
}
}
SaeAtomBasisKind::Torus => {
let pc_pairs = vt_rows / 2;
for axis in 0..d {
let pair = if pc_pair_offset != 0 && pc_pairs > 0 {
(axis + pc_pair_offset) % pc_pairs
} else {
axis
};
let pc_a_idx = 2 * pair;
let pc_b_idx = 2 * pair + 1;
if pc_b_idx >= vt_rows {
break;
}
let pc_a = vt.row(pc_a_idx);
let pc_b = vt.row(pc_b_idx);
for row in 0..n_obs {
let mut a = 0.0_f64;
let mut b = 0.0_f64;
for col in 0..centered.ncols() {
a += centered[[row, col]] * pc_a[col];
b += centered[[row, col]] * pc_b[col];
}
let phase = b.atan2(a) / two_pi;
let wrapped = phase - phase.floor();
out[[atom_idx, row, axis]] = wrapped;
}
}
}
_ => {
let avail = u_cols.min(s_vals.len());
let k_cols = d.min(avail);
let base = if avail > 0 {
(2 * pc_pair_offset) % avail
} else {
0
};
let atom_pc_offset = atom_idx.saturating_mul(d);
let mut tmp = Array2::<f64>::zeros((n_obs, d));
for col in 0..k_cols {
let src = if avail > 0 {
(base + atom_pc_offset + col) % avail
} else {
col
};
let s_col = s_vals[src];
for row in 0..n_obs {
tmp[[row, col]] = u[[row, src]] * s_col;
}
}
for col in 0..d {
let mut min_v = f64::INFINITY;
let mut max_v = f64::NEG_INFINITY;
for row in 0..n_obs {
let v = tmp[[row, col]];
if v < min_v {
min_v = v;
}
if v > max_v {
max_v = v;
}
}
let span = max_v - min_v;
if span > 0.0 {
for row in 0..n_obs {
out[[atom_idx, row, col]] = (tmp[[row, col]] - min_v) / span - 0.5;
}
}
}
}
}
}
Ok(out)
}
#[cfg(test)]
mod tests {
use super::*;
/// Deterministic dense test matrix standing in for an SVD `Vt`. Unlike a real `Vt`, its rows
/// are not orthonormal.
fn make_vt(rows: usize, cols: usize) -> Array2<f64> {
let mut m = Array2::<f64>::zeros((rows, cols));
for r in 0..rows {
for c in 0..cols {
m[[r, c]] = ((r * 7 + c * 3 + 1) as f64).sin() + 0.1 * (r as f64 - c as f64);
}
}
m
}
/// Different retry offsets must yield a different first direction. The second direction must
/// stay unit-length and orthogonal to the first.
#[test]
fn surplus_plane_differs_across_retries() {
let vt = make_vt(6, 5);
let k_atoms = 8;
let atom_idx = 5;
let (d1_0, d2_0, ok0) = surplus_phase_plane(vt.view(), atom_idx, 0, k_atoms);
let (d1_1, _d2_1, ok1) = surplus_phase_plane(vt.view(), atom_idx, 1, k_atoms);
assert!(ok0 && ok1, "well-conditioned vt should give a 2-D plane");
// Largest per-component change of dir1 between offset 0 and offset 1.
let diff = d1_0
.iter()
.zip(d1_1.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0_f64, f64::max);
assert!(
diff > 1e-6,
"distinct retry offsets must yield distinct dir1 (max diff {diff:.3e})"
);
let dot: f64 = d1_0.iter().zip(d2_0.iter()).map(|(a, b)| a * b).sum();
assert!(dot.abs() < 1e-9, "dir2 must be orthogonal to dir1 (dot {dot:.3e})");
let n2: f64 = d2_0.dot(&d2_0).sqrt();
assert!((n2 - 1.0).abs() < 1e-9, "dir2 must be unit-normalized (norm {n2})");
}
/// With offset 0, `dir1` must match a hand-rolled reconstruction that uses the pre-offset key
/// `(atom_idx << 20) ^ pc`, to within 1e-15.
#[test]
fn surplus_plane_offset_zero_matches_original_key() {
let vt = make_vt(5, 4);
let k_atoms = 7;
let atom_idx = 6;
// Verbatim copy of the SplitMix64-based weight hash used by `surplus_phase_plane`; it must be
// kept in sync for this comparison to be meaningful.
let mix = |mut z: u64| -> f64 {
z = z.wrapping_add(0x9E3779B97F4A7C15);
z = (z ^ (z >> 30)).wrapping_mul(0xBF58476D1CE4E5B9);
z = (z ^ (z >> 27)).wrapping_mul(0x94D049BB133111EB);
z ^= z >> 31;
(z as f64 / u64::MAX as f64) * 2.0 - 1.0
};
let ncols = vt.ncols();
let mut a = Array1::<f64>::zeros(ncols);
for pc in 0..vt.nrows() {
let row_pc = vt.row(pc);
let key = ((atom_idx as u64) << 20) ^ (pc as u64);
let wa = mix(key);
for c in 0..ncols {
a[c] += wa * row_pc[c];
}
}
let na = a.dot(&a).sqrt().max(1.0e-12);
a.mapv_inplace(|v| v / na);
let (d1, _d2, _ok) = surplus_phase_plane(vt.view(), atom_idx, 0, k_atoms);
for (r, o) in d1.iter().zip(a.iter()) {
assert!(
(r - o).abs() < 1e-15,
"offset-0 dir1 must match the original key bit-for-bit ({r} vs {o})"
);
}
}
/// A single-row `vt` spans one dimension, so the orthogonalized second direction collapses.
/// The plane flag must be false while `dir1` stays a finite unit vector.
#[test]
fn surplus_plane_collinear_falls_back_to_1d() {
let vt = make_vt(1, 4);
let (d1, _d2, ok) = surplus_phase_plane(vt.view(), 3, 0, 5);
assert!(!ok, "rank-1 span must report a 1-D (non-2-plane) result");
let n1: f64 = d1.dot(&d1).sqrt();
assert!(
(n1 - 1.0).abs() < 1e-9 && n1.is_finite(),
"dir1 must stay a finite unit vector (norm {n1})"
);
}
/// End-to-end check with 5 periodic atoms on 8×4 data. Four columns give at most 2 PC pairs,
/// so atom 3 is a surplus atom. The test requires that:
/// - offset 0 equals the no-offset entry point exactly;
/// - every phase lies in [0, 1);
/// - the surplus atom's phases change when the offset changes.
#[test]
fn surplus_periodic_seed_reproduces_and_diversifies() {
let n_obs = 8;
let p = 4;
let mut zvals = Vec::with_capacity(n_obs * p);
for r in 0..n_obs {
for c in 0..p {
zvals.push(
((r as f64) * 0.9 + 1.0).sin() * ((c + 1) as f64)
+ 0.3 * ((r * c) as f64).cos(),
);
}
}
let z = Array2::from_shape_vec((n_obs, p), zvals).unwrap();
let kinds = vec![SaeAtomBasisKind::Periodic; 5];
let dims = vec![1usize; 5];
let s0 = sae_pca_seed_initial_coords_with_pc_offset(z.view(), &kinds, &dims, 0).unwrap();
let s1 = sae_pca_seed_initial_coords_with_pc_offset(z.view(), &kinds, &dims, 1).unwrap();
let plain = sae_pca_seed_initial_coords(z.view(), &kinds, &dims).unwrap();
assert_eq!(s0, plain, "offset-0 must equal the no-offset seed bit-for-bit");
for v in s0.iter().chain(s1.iter()) {
assert!(
v.is_finite() && *v >= 0.0 && *v < 1.0,
"periodic phase must be finite in [0, 1): {v}"
);
}
// Index 3 is at or beyond the (at most 2) PC pairs, so it takes the surplus-plane path.
let surplus_atom = 3;
let mut moved = 0.0_f64;
for row in 0..n_obs {
moved = moved.max((s0[[surplus_atom, row, 0]] - s1[[surplus_atom, row, 0]]).abs());
}
assert!(
moved > 1e-6,
"surplus atom must land on a distinct basin across retries (max move {moved:.3e})"
);
}
}