1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
//! Production PCA-based per-atom chart seeding for the SAE manifold fit.
//! Moved out of the (test-only) `tests` module so the production FFI seed
//! path can call it in release builds.
use super::SaeAtomBasisKind;
use gam_linalg::faer_ndarray::FaerSvd;
use ndarray::{Array1, Array2, Array3, ArrayView2};
/// Seeds every SAE atom's latent chart from the principal components of `z`,
/// reading each atom off the leading (un-rotated) principal subspace.
///
/// This is [`sae_pca_seed_initial_coords_with_pc_offset`] called with a
/// principal-component-pair offset of zero. That routine centers `z`, takes its
/// SVD, and fills one chart per atom according to the atom's
/// [`SaeAtomBasisKind`]:
///
/// * periodic atoms get a `[0, 1)` phase on axis 0 and min-max-normalized
///   `[-0.5, 0.5]` projections on the remaining axes;
/// * sphere atoms get `(lat, lon)` read off unit-normalized projections onto up
///   to three PCs;
/// * torus atoms get one `[0, 1)` phase per axis, each from its own PC pair;
/// * every other kind gets singular-value-scaled scores, min-max normalized to
///   `[-0.5, 0.5]`.
///
/// # Returns
///
/// A `(K_atoms, n_obs, d_max)` array, where `K_atoms = basis_kinds.len()`,
/// `n_obs = z.nrows()` and `d_max` is the largest entry of `atom_dim` (at least
/// 1). Slots an atom does not use are left at zero.
///
/// # Errors
///
/// Forwards whatever error string the offset-aware seeding routine reports
/// (for example non-finite input or a failed SVD).
pub fn sae_pca_seed_initial_coords(
    z: ArrayView2<'_, f64>,
    basis_kinds: &[SaeAtomBasisKind],
    atom_dim: &[usize],
) -> Result<Array3<f64>, String> {
    // Zero means "no rotation": atoms read the leading principal components.
    const NO_PC_PAIR_OFFSET: usize = 0;
    sae_pca_seed_initial_coords_with_pc_offset(z, basis_kinds, atom_dim, NO_PC_PAIR_OFFSET)
}
/// PCA seed with a deterministic principal-component-pair ROTATION offset.
///
/// Identical to [`sae_pca_seed_initial_coords`] (which is this with
/// `pc_pair_offset = 0`) except every atom reads its chart off a PC subspace
/// shifted by `pc_pair_offset` pairs. This is the lever the #976 simultaneous
/// co-collapse multi-start uses to make successive reseeds explore GENUINELY
/// DIFFERENT basins: the residual a co-collapsed dictionary leaves is ≈ the
/// target on every retry, so a fixed-offset reseed re-reads the SAME leading
/// PCs and the joint LSQ relaxes back into the SAME degenerate basin — the
/// budget-N multi-start would then be N identical attempts. Shifting the PC
/// pairs by the retry index lands the atoms on a disjoint principal subspace
/// each attempt (top pairs on retry 0, next pairs on retry 1, …), so the basins
/// are distinct by construction. The offset is a pure deterministic function of
/// the retry count (no RNG), so the seed stays bit-reproducible run-to-run and
/// across thread/device counts.
pub fn sae_pca_seed_initial_coords_with_pc_offset(
z: ArrayView2<'_, f64>,
basis_kinds: &[SaeAtomBasisKind],
atom_dim: &[usize],
pc_pair_offset: usize,
) -> Result<Array3<f64>, String> {
let k_atoms = basis_kinds.len();
let (n_obs, _p_out) = z.dim();
let d_max = atom_dim.iter().copied().max().unwrap_or(1).max(1);
let mut out = Array3::<f64>::zeros((k_atoms, n_obs, d_max));
if n_obs == 0 || z.ncols() == 0 {
return Ok(out);
}
// Reject non-finite input up front so a clean error surfaces here rather
// than a silent non-finite seed (or an opaque SVD failure) downstream.
for ((row, col), &value) in z.indexed_iter() {
if !value.is_finite() {
return Err(format!(
"sae_pca_seed: Z must be finite; Z[{row}, {col}] = {value}"
));
}
}
// Accumulate the column mean with Welford's running update
// `mean += (x − mean) / count` instead of a plain running sum. The plain
// sum overflows to `±inf` for huge finite columns (e.g. two rows of
// `1e308` sum to `2e308 = inf`), which poisons the centered matrix and the
// SVD. Welford's update keeps the accumulator bounded by the column's data
// range, so the mean is finite whenever the inputs are.
let mut col_means = Array1::<f64>::zeros(z.ncols());
for col in 0..z.ncols() {
let mut mean = 0.0_f64;
for (count, row) in (0..n_obs).enumerate() {
let x = z[[row, col]];
mean += (x - mean) / (count as f64 + 1.0);
}
col_means[col] = mean;
}
let mut centered = z.to_owned();
for row in 0..n_obs {
for col in 0..z.ncols() {
centered[[row, col]] -= col_means[col];
}
}
// Centering can still overflow if the data span itself is non-finite
// (e.g. `+1e308` and `−1e308` in one column give a finite mean but an
// `inf` deviation). Surface that as a clean error rather than feeding a
// non-finite matrix to the SVD.
for ((row, col), &value) in centered.indexed_iter() {
if !value.is_finite() {
return Err(format!(
"sae_pca_seed: centered Z is non-finite at [{row}, {col}] \
(data span exceeds f64 range); rescale Z before seeding"
));
}
}
let (u_opt, s_vals, vt_opt) = centered
.svd(true, true)
.map_err(|err| format!("sae_pca_seed: SVD failed: {err:?}"))?;
let u = u_opt.ok_or_else(|| "sae_pca_seed: SVD returned no U".to_string())?;
let vt = vt_opt.ok_or_else(|| "sae_pca_seed: SVD returned no Vt".to_string())?;
let vt_rows = vt.nrows();
let u_cols = u.ncols();
let two_pi = std::f64::consts::TAU;
for atom_idx in 0..k_atoms {
let d = atom_dim[atom_idx];
if d == 0 {
continue;
}
match &basis_kinds[atom_idx] {
SaeAtomBasisKind::Periodic => {
if vt_rows >= 1 {
// Diversify the per-atom circle seed (issue #671). The
// previous scheme shared PC0 as the first phase axis for
// *every* atom, so all periodic atoms read off nearly the
// same phase coordinate, producing near-duplicate basis
// designs and a severely ill-conditioned joint decoder LSQ
// seed. Give each atom a disjoint pair of principal
// components `(PC_{2k}, PC_{2k+1})` when the spectrum is
// wide enough, wrapping around only when atoms outnumber the
// available PC pairs. This keeps distinct atoms' seed
// coordinates decorrelated so the decoder seed stays
// well-conditioned and the cross-atom Gram starts small.
let pc_pairs = vt_rows / 2;
let (pc1_row, pc2_row) = if pc_pairs >= 1 {
// Rotate the per-atom PC pair by the multi-start offset so
// a co-collapse reseed retry reads a DISJOINT principal
// subspace (the #976 distinct-basin lever).
let pair = (atom_idx + pc_pair_offset) % pc_pairs;
(2 * pair, 2 * pair + 1)
} else {
(0, 0)
};
let pc1 = vt.row(pc1_row.min(vt_rows - 1));
let phase_offset = if pc_pairs > 0 && pc_pairs < k_atoms {
atom_idx as f64 / k_atoms as f64
} else {
0.0
};
let s0 = s_vals.get(pc1_row).copied().unwrap_or(0.0).abs();
let s1 = s_vals.get(pc2_row).copied().unwrap_or(0.0).abs();
let has_two_dimensional_phase =
vt_rows >= 2 && pc2_row != pc1_row && s1 > 1.0e-10 * s0.max(1.0);
if has_two_dimensional_phase {
let pc2 = vt.row(pc2_row.min(vt_rows - 1));
for row in 0..n_obs {
let mut a = 0.0_f64;
let mut b = 0.0_f64;
for col in 0..centered.ncols() {
a += centered[[row, col]] * pc1[col];
b += centered[[row, col]] * pc2[col];
}
let phase = b.atan2(a) / two_pi + phase_offset;
out[[atom_idx, row, 0]] = phase - phase.floor();
}
} else {
let mut proj = Array1::<f64>::zeros(n_obs);
for row in 0..n_obs {
let mut acc = 0.0_f64;
for col in 0..centered.ncols() {
acc += centered[[row, col]] * pc1[col];
}
proj[row] = acc;
}
let (min_v, max_v) = proj
.iter()
.fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
(lo.min(v), hi.max(v))
});
let span = max_v - min_v;
if span > 0.0 {
for row in 0..n_obs {
let phase = (proj[row] - min_v) / span + phase_offset;
out[[atom_idx, row, 0]] = phase - phase.floor();
}
}
}
}
for axis in 1..d {
if axis >= vt_rows {
break;
}
let pc = vt.row(axis);
let mut proj = Array1::<f64>::zeros(n_obs);
for row in 0..n_obs {
let mut acc = 0.0_f64;
for col in 0..centered.ncols() {
acc += centered[[row, col]] * pc[col];
}
proj[row] = acc;
}
let (min_v, max_v) = proj
.iter()
.fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
(lo.min(v), hi.max(v))
});
let span = max_v - min_v;
if span > 0.0 {
for row in 0..n_obs {
out[[atom_idx, row, axis]] = (proj[row] - min_v) / span - 0.5;
}
}
}
}
SaeAtomBasisKind::Sphere => {
// Seed the sphere chart from the top-3 PCs: drop the centred
// response onto (pc0, pc1, pc2), unit-normalise, and read off
// (lat, lon). This places every row on the chart with
// `lat ∈ (-π/2, π/2)` and `lon ∈ (-π, π]`.
let n_pc = vt_rows.min(3);
if n_pc == 0 {
continue;
}
// Rotate the sphere's leading-PC window by the multi-start offset
// (in PC-pair units, mod the available PCs) so a reseed retry
// reads a distinct 3-PC subspace (the #976 distinct-basin lever).
let base = if vt_rows > 0 {
(2 * pc_pair_offset) % vt_rows
} else {
0
};
let pcs: Vec<_> = (0..n_pc).map(|i| vt.row((base + i) % vt_rows)).collect();
for row in 0..n_obs {
let mut amb = [0.0_f64; 3];
for (i, pc) in pcs.iter().enumerate() {
let mut acc = 0.0_f64;
for col in 0..centered.ncols() {
acc += centered[[row, col]] * pc[col];
}
amb[i] = acc;
}
let norm = (amb[0] * amb[0] + amb[1] * amb[1] + amb[2] * amb[2]).sqrt();
let (x, y, z) = if norm > 0.0 {
(amb[0] / norm, amb[1] / norm, amb[2] / norm)
} else {
(1.0, 0.0, 0.0)
};
let lat = z.clamp(-1.0, 1.0).asin();
let lon = y.atan2(x);
if d >= 1 {
out[[atom_idx, row, 0]] = lat;
}
if d >= 2 {
out[[atom_idx, row, 1]] = lon;
}
}
}
SaeAtomBasisKind::Torus => {
// Seed each torus axis from a disjoint pair of PCs: axis `a`
// uses (pc_{2a}, pc_{2a+1}) projected onto the centred
// response and read off as `atan2`, normalised to `[0, 1)`.
let pc_pairs = vt_rows / 2;
for axis in 0..d {
// Rotate each torus axis's PC pair by the multi-start offset
// (same #976 distinct-basin lever as the periodic arm). With
// `pc_pair_offset == 0` this is the identity (`pair == axis`)
// and the original `pc_b_idx >= vt_rows` break is preserved
// bit-for-bit; a nonzero offset wraps within the available
// pairs so a retry reads a disjoint pair.
let pair = if pc_pair_offset != 0 && pc_pairs > 0 {
(axis + pc_pair_offset) % pc_pairs
} else {
axis
};
let pc_a_idx = 2 * pair;
let pc_b_idx = 2 * pair + 1;
if pc_b_idx >= vt_rows {
break;
}
let pc_a = vt.row(pc_a_idx);
let pc_b = vt.row(pc_b_idx);
for row in 0..n_obs {
let mut a = 0.0_f64;
let mut b = 0.0_f64;
for col in 0..centered.ncols() {
a += centered[[row, col]] * pc_a[col];
b += centered[[row, col]] * pc_b[col];
}
// atan2 ∈ (-π, π]; map to phase ∈ [0, 1).
let phase = b.atan2(a) / two_pi;
let wrapped = phase - phase.floor();
out[[atom_idx, row, axis]] = wrapped;
}
}
}
_ => {
let avail = u_cols.min(s_vals.len());
let k_cols = d.min(avail);
// Rotate the score-column window by the multi-start offset (in
// PC-pair units, mod the available components) so a reseed retry
// reads distinct principal scores (the #976 distinct-basin lever).
let base = if avail > 0 {
(2 * pc_pair_offset) % avail
} else {
0
};
let mut tmp = Array2::<f64>::zeros((n_obs, d));
for col in 0..k_cols {
let src = if avail > 0 { (base + col) % avail } else { col };
let s_col = s_vals[src];
for row in 0..n_obs {
tmp[[row, col]] = u[[row, src]] * s_col;
}
}
for col in 0..d {
let mut min_v = f64::INFINITY;
let mut max_v = f64::NEG_INFINITY;
for row in 0..n_obs {
let v = tmp[[row, col]];
if v < min_v {
min_v = v;
}
if v > max_v {
max_v = v;
}
}
let span = max_v - min_v;
if span > 0.0 {
for row in 0..n_obs {
out[[atom_idx, row, col]] = (tmp[[row, col]] - min_v) / span - 0.5;
}
}
}
}
}
}
Ok(out)
}