use super::SaeAtomBasisKind;
use crate::linalg::faer_ndarray::FaerSvd;
use ndarray::{Array1, Array2, Array3, ArrayView2};
/// Seeds initial per-atom coordinates from a PCA of `z`, starting from the
/// first principal-component pair.
///
/// This is shorthand for [`sae_pca_seed_initial_coords_with_pc_offset`] with a
/// `pc_pair_offset` of zero; see that function for the output layout and errors.
pub fn sae_pca_seed_initial_coords(
    z: ArrayView2<'_, f64>,
    basis_kinds: &[SaeAtomBasisKind],
    atom_dim: &[usize],
) -> Result<Array3<f64>, String> {
    const FIRST_PC_PAIR: usize = 0;
    sae_pca_seed_initial_coords_with_pc_offset(z, basis_kinds, atom_dim, FIRST_PC_PAIR)
}
pub fn sae_pca_seed_initial_coords_with_pc_offset(
z: ArrayView2<'_, f64>,
basis_kinds: &[SaeAtomBasisKind],
atom_dim: &[usize],
pc_pair_offset: usize,
) -> Result<Array3<f64>, String> {
let k_atoms = basis_kinds.len();
let (n_obs, _p_out) = z.dim();
let d_max = atom_dim.iter().copied().max().unwrap_or(1).max(1);
let mut out = Array3::<f64>::zeros((k_atoms, n_obs, d_max));
if n_obs == 0 || z.ncols() == 0 {
return Ok(out);
}
for ((row, col), &value) in z.indexed_iter() {
if !value.is_finite() {
return Err(format!(
"sae_pca_seed: Z must be finite; Z[{row}, {col}] = {value}"
));
}
}
let mut col_means = Array1::<f64>::zeros(z.ncols());
for col in 0..z.ncols() {
let mut mean = 0.0_f64;
for (count, row) in (0..n_obs).enumerate() {
let x = z[[row, col]];
mean += (x - mean) / (count as f64 + 1.0);
}
col_means[col] = mean;
}
let mut centered = z.to_owned();
for row in 0..n_obs {
for col in 0..z.ncols() {
centered[[row, col]] -= col_means[col];
}
}
for ((row, col), &value) in centered.indexed_iter() {
if !value.is_finite() {
return Err(format!(
"sae_pca_seed: centered Z is non-finite at [{row}, {col}] \
(data span exceeds f64 range); rescale Z before seeding"
));
}
}
let (u_opt, s_vals, vt_opt) = centered
.svd(true, true)
.map_err(|err| format!("sae_pca_seed: SVD failed: {err:?}"))?;
let u = u_opt.ok_or_else(|| "sae_pca_seed: SVD returned no U".to_string())?;
let vt = vt_opt.ok_or_else(|| "sae_pca_seed: SVD returned no Vt".to_string())?;
let vt_rows = vt.nrows();
let u_cols = u.ncols();
let two_pi = std::f64::consts::TAU;
for atom_idx in 0..k_atoms {
let d = atom_dim[atom_idx];
if d == 0 {
continue;
}
match &basis_kinds[atom_idx] {
SaeAtomBasisKind::Periodic => {
if vt_rows >= 1 {
let pc_pairs = vt_rows / 2;
let (pc1_row, pc2_row) = if pc_pairs >= 1 {
let pair = (atom_idx + pc_pair_offset) % pc_pairs;
(2 * pair, 2 * pair + 1)
} else {
(0, 0)
};
let pc1 = vt.row(pc1_row.min(vt_rows - 1));
let phase_offset = if pc_pairs > 0 && pc_pairs < k_atoms {
atom_idx as f64 / k_atoms as f64
} else {
0.0
};
let s0 = s_vals.get(pc1_row).copied().unwrap_or(0.0).abs();
let s1 = s_vals.get(pc2_row).copied().unwrap_or(0.0).abs();
let has_two_dimensional_phase =
vt_rows >= 2 && pc2_row != pc1_row && s1 > 1.0e-10 * s0.max(1.0);
if has_two_dimensional_phase {
let pc2 = vt.row(pc2_row.min(vt_rows - 1));
for row in 0..n_obs {
let mut a = 0.0_f64;
let mut b = 0.0_f64;
for col in 0..centered.ncols() {
a += centered[[row, col]] * pc1[col];
b += centered[[row, col]] * pc2[col];
}
let phase = b.atan2(a) / two_pi + phase_offset;
out[[atom_idx, row, 0]] = phase - phase.floor();
}
} else {
let mut proj = Array1::<f64>::zeros(n_obs);
for row in 0..n_obs {
let mut acc = 0.0_f64;
for col in 0..centered.ncols() {
acc += centered[[row, col]] * pc1[col];
}
proj[row] = acc;
}
let (min_v, max_v) = proj
.iter()
.fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
(lo.min(v), hi.max(v))
});
let span = max_v - min_v;
if span > 0.0 {
for row in 0..n_obs {
let phase = (proj[row] - min_v) / span + phase_offset;
out[[atom_idx, row, 0]] = phase - phase.floor();
}
}
}
}
for axis in 1..d {
if axis >= vt_rows {
break;
}
let pc = vt.row(axis);
let mut proj = Array1::<f64>::zeros(n_obs);
for row in 0..n_obs {
let mut acc = 0.0_f64;
for col in 0..centered.ncols() {
acc += centered[[row, col]] * pc[col];
}
proj[row] = acc;
}
let (min_v, max_v) = proj
.iter()
.fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
(lo.min(v), hi.max(v))
});
let span = max_v - min_v;
if span > 0.0 {
for row in 0..n_obs {
out[[atom_idx, row, axis]] = (proj[row] - min_v) / span - 0.5;
}
}
}
}
SaeAtomBasisKind::Sphere => {
let n_pc = vt_rows.min(3);
if n_pc == 0 {
continue;
}
let base = if vt_rows > 0 {
(2 * pc_pair_offset) % vt_rows
} else {
0
};
let pcs: Vec<_> = (0..n_pc).map(|i| vt.row((base + i) % vt_rows)).collect();
for row in 0..n_obs {
let mut amb = [0.0_f64; 3];
for (i, pc) in pcs.iter().enumerate() {
let mut acc = 0.0_f64;
for col in 0..centered.ncols() {
acc += centered[[row, col]] * pc[col];
}
amb[i] = acc;
}
let norm = (amb[0] * amb[0] + amb[1] * amb[1] + amb[2] * amb[2]).sqrt();
let (x, y, z) = if norm > 0.0 {
(amb[0] / norm, amb[1] / norm, amb[2] / norm)
} else {
(1.0, 0.0, 0.0)
};
let lat = z.clamp(-1.0, 1.0).asin();
let lon = y.atan2(x);
if d >= 1 {
out[[atom_idx, row, 0]] = lat;
}
if d >= 2 {
out[[atom_idx, row, 1]] = lon;
}
}
}
SaeAtomBasisKind::Torus => {
let pc_pairs = vt_rows / 2;
for axis in 0..d {
let pair = if pc_pair_offset != 0 && pc_pairs > 0 {
(axis + pc_pair_offset) % pc_pairs
} else {
axis
};
let pc_a_idx = 2 * pair;
let pc_b_idx = 2 * pair + 1;
if pc_b_idx >= vt_rows {
break;
}
let pc_a = vt.row(pc_a_idx);
let pc_b = vt.row(pc_b_idx);
for row in 0..n_obs {
let mut a = 0.0_f64;
let mut b = 0.0_f64;
for col in 0..centered.ncols() {
a += centered[[row, col]] * pc_a[col];
b += centered[[row, col]] * pc_b[col];
}
let phase = b.atan2(a) / two_pi;
let wrapped = phase - phase.floor();
out[[atom_idx, row, axis]] = wrapped;
}
}
}
_ => {
let avail = u_cols.min(s_vals.len());
let k_cols = d.min(avail);
let base = if avail > 0 {
(2 * pc_pair_offset) % avail
} else {
0
};
let mut tmp = Array2::<f64>::zeros((n_obs, d));
for col in 0..k_cols {
let src = if avail > 0 { (base + col) % avail } else { col };
let s_col = s_vals[src];
for row in 0..n_obs {
tmp[[row, col]] = u[[row, src]] * s_col;
}
}
for col in 0..d {
let mut min_v = f64::INFINITY;
let mut max_v = f64::NEG_INFINITY;
for row in 0..n_obs {
let v = tmp[[row, col]];
if v < min_v {
min_v = v;
}
if v > max_v {
max_v = v;
}
}
let span = max_v - min_v;
if span > 0.0 {
for row in 0..n_obs {
out[[atom_idx, row, col]] = (tmp[[row, col]] - min_v) / span - 0.5;
}
}
}
}
}
}
Ok(out)
}