1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
//! Padded-FFI term builder, split out of `construction.rs` to keep that tracked
//! file under the #780 10k-line gate. `term_from_padded_blocks_with_mode` is the
//! single entry point Python callers use to assemble a [`SaeManifoldTerm`] from
//! `(K, N, M_max[, D_max])`-padded arrays; it is re-exported from `mod.rs` via
//! `pub use construction_padded_blocks::*;` so every caller keeps reaching it
//! bare through `use super::*`.
use super::*;
/// Helper for padded FFI callers. Arrays use `(K, N, M_max)` and
/// `(K, N, M_max, D_max)` storage, with `basis_sizes` and `latent_dims`
/// selecting each atom's active prefix.
///
/// `evaluators`, when non-empty, must have length `K`. Each entry attaches an
/// optional [`SaeBasisSecondJet`] to the matching atom so the Rust Newton
/// loop can refresh `Phi`/`dPhi/dt` between iterations without rebuilding the
/// term from Python. The evaluator is installed through
/// [`SaeManifoldAtom::with_basis_second_jet`], so its closed-form Hessian slot
/// is populated too — this is what lets the #1117 rank-revealing reduction
/// (`reduce_atoms_to_data_supported_rank`) reparametrize a rank-deficient
/// fixed-width decoder (e.g. the periodic circle's 5-column basis whose data
/// Gram comes out rank 3/5 on a near-degenerate checkpoint) onto its
/// data-supported subspace instead of stalling on the flat REML valley. An
/// empty slice leaves every atom in snapshot-only mode.
#[must_use = "build error must be handled"]
pub fn term_from_padded_blocks_with_mode(
n_obs: usize,
p_out: usize,
basis_kinds: &[SaeAtomBasisKind],
basis_values: ArrayView3<'_, f64>,
basis_jacobian: ArrayView4<'_, f64>,
basis_sizes: &[usize],
latent_dims: &[usize],
decoder_coefficients: ArrayView3<'_, f64>,
smooth_penalties: ArrayView3<'_, f64>,
logits: ArrayView2<'_, f64>,
coords: &[Array2<f64>],
mode: AssignmentMode,
evaluators: &[Option<Arc<dyn SaeBasisSecondJet>>],
) -> Result<SaeManifoldTerm, String> {
let k_atoms = basis_sizes.len();
if latent_dims.len() != k_atoms || basis_kinds.len() != k_atoms || coords.len() != k_atoms {
return Err("term_from_padded_blocks: K-length metadata mismatch".into());
}
if !evaluators.is_empty() && evaluators.len() != k_atoms {
return Err(format!(
"term_from_padded_blocks: evaluators length {} must equal K={k_atoms} or be empty",
evaluators.len()
));
}
if logits.dim() != (n_obs, k_atoms) {
return Err(format!(
"term_from_padded_blocks: logits must be ({n_obs}, {k_atoms}); got {:?}",
logits.dim()
));
}
let mut atoms = Vec::with_capacity(k_atoms);
for k in 0..k_atoms {
let m = basis_sizes[k];
let d = latent_dims[k];
let phi = basis_values.slice(s![k, 0..n_obs, 0..m]).to_owned();
let jet = basis_jacobian.slice(s![k, 0..n_obs, 0..m, 0..d]).to_owned();
let b = decoder_coefficients.slice(s![k, 0..m, 0..p_out]).to_owned();
let s = smooth_penalties.slice(s![k, 0..m, 0..m]).to_owned();
let atom = SaeManifoldAtom::new(
format!("atom_{k}"),
basis_kinds[k].clone(),
d,
phi,
jet,
b,
s,
)?;
let atom = match evaluators.get(k).and_then(|slot| slot.clone()) {
// Install through the second-jet slot so the analytic Hessian is
// available: the #1117 rank-revealing reduction needs it to compose
// the reduced jets when it reparametrizes a rank-deficient atom onto
// its data-supported subspace. All production SAE evaluators
// (periodic/sphere/torus/cylinder/Duchon/Euclidean-patch) implement
// `SaeBasisSecondJet`, so this is the standard install path.
Some(evaluator) => atom.with_basis_second_jet(evaluator),
None => atom,
};
atoms.push(atom);
}
let manifolds = basis_kinds
.iter()
.zip(latent_dims.iter().copied())
.map(|(kind, d)| kind.latent_manifold(d))
.collect();
let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
logits.to_owned(),
coords.to_vec(),
manifolds,
mode,
)?;
SaeManifoldTerm::new(atoms, assignment)
}