use super::*;
/// Builds a [`SaeManifoldTerm`] from per-atom blocks packed into padded,
/// atom-major arrays.
///
/// Atom `k` uses only the leading `m = basis_sizes[k]` basis columns and the
/// leading `d = latent_dims[k]` latent directions of its padded slab; the rest
/// is ignored. Concretely, atom `k` is built from
///
/// * `basis_values[k, ..n_obs, ..m]`,
/// * `basis_jacobian[k, ..n_obs, ..m, ..d]`,
/// * `decoder_coefficients[k, ..m, ..p_out]`,
/// * `smooth_penalties[k, ..m, ..m]`,
///
/// is named `atom_{k}`, and gets `evaluators[k]` attached as its basis
/// second-jet evaluator when that slot is `Some`. `evaluators` may be empty
/// (no evaluators attached) or have exactly `K` entries.
///
/// The assignment is built from `logits` (which must be `(n_obs, K)`),
/// `coords`, one latent manifold per atom obtained from
/// `basis_kinds[k].latent_manifold(latent_dims[k])`, and `mode`.
///
/// # Errors
///
/// Returns `Err` when
/// * `basis_sizes`, `latent_dims`, `basis_kinds` and `coords` do not all have
///   the same length `K`;
/// * `evaluators` is non-empty and its length differs from `K`;
/// * `logits` does not have shape `(n_obs, K)`;
/// * a padded array is too small for `K`, `n_obs`, `p_out`, or some atom's
///   basis size / latent dimension;
/// * building an atom, the assignment, or the final term fails.
#[must_use = "build error must be handled"]
pub fn term_from_padded_blocks_with_mode(
    n_obs: usize,
    p_out: usize,
    basis_kinds: &[SaeAtomBasisKind],
    basis_values: ArrayView3<'_, f64>,
    basis_jacobian: ArrayView4<'_, f64>,
    basis_sizes: &[usize],
    latent_dims: &[usize],
    decoder_coefficients: ArrayView3<'_, f64>,
    smooth_penalties: ArrayView3<'_, f64>,
    logits: ArrayView2<'_, f64>,
    coords: &[Array2<f64>],
    mode: AssignmentMode,
    evaluators: &[Option<Arc<dyn SaeBasisSecondJet>>],
) -> Result<SaeManifoldTerm, String> {
    let k_atoms = basis_sizes.len();
    if latent_dims.len() != k_atoms || basis_kinds.len() != k_atoms || coords.len() != k_atoms {
        return Err("term_from_padded_blocks: K-length metadata mismatch".into());
    }
    if !evaluators.is_empty() && evaluators.len() != k_atoms {
        return Err(format!(
            "term_from_padded_blocks: evaluators length {} must equal K={k_atoms} or be empty",
            evaluators.len()
        ));
    }
    if logits.dim() != (n_obs, k_atoms) {
        return Err(format!(
            "term_from_padded_blocks: logits must be ({n_obs}, {k_atoms}); got {:?}",
            logits.dim()
        ));
    }

    // Validate padded extents before slicing: `ArrayBase::slice` panics on
    // out-of-range bounds, while this function reports shape problems as `Err`.
    let (vals_k, vals_n, vals_m) = basis_values.dim();
    let (jac_k, jac_n, jac_m, jac_d) = basis_jacobian.dim();
    let (coef_k, coef_m, coef_p) = decoder_coefficients.dim();
    let (pen_k, pen_rows, pen_cols) = smooth_penalties.dim();
    // With K = 0 nothing is sliced, so the observation/output extents are
    // irrelevant; only check them when at least one atom will be read.
    if k_atoms > 0 {
        if vals_k < k_atoms || jac_k < k_atoms || coef_k < k_atoms || pen_k < k_atoms {
            return Err(format!(
                "term_from_padded_blocks: padded blocks must hold at least K={k_atoms} atoms; \
                 got basis_values {vals_k}, basis_jacobian {jac_k}, \
                 decoder_coefficients {coef_k}, smooth_penalties {pen_k}"
            ));
        }
        if vals_n < n_obs || jac_n < n_obs {
            return Err(format!(
                "term_from_padded_blocks: padded observation axis must hold n_obs={n_obs}; \
                 got basis_values {vals_n}, basis_jacobian {jac_n}"
            ));
        }
        if coef_p < p_out {
            return Err(format!(
                "term_from_padded_blocks: decoder_coefficients output axis {coef_p} \
                 is smaller than p_out={p_out}"
            ));
        }
    }
    // Narrowest padded basis axis across all four arrays; every atom's m must fit in it.
    let basis_width = vals_m.min(jac_m).min(coef_m).min(pen_rows).min(pen_cols);

    let mut atoms = Vec::with_capacity(k_atoms);
    let per_atom = basis_sizes.iter().zip(latent_dims).zip(basis_kinds);
    for (k, ((&m, &d), kind)) in per_atom.enumerate() {
        if m > basis_width {
            return Err(format!(
                "term_from_padded_blocks: atom {k} basis size {m} exceeds padded basis width \
                 {basis_width}"
            ));
        }
        if d > jac_d {
            return Err(format!(
                "term_from_padded_blocks: atom {k} latent dim {d} exceeds padded latent width \
                 {jac_d}"
            ));
        }
        let phi = basis_values.slice(s![k, 0..n_obs, 0..m]).to_owned();
        let jet = basis_jacobian.slice(s![k, 0..n_obs, 0..m, 0..d]).to_owned();
        let b = decoder_coefficients.slice(s![k, 0..m, 0..p_out]).to_owned();
        let s = smooth_penalties.slice(s![k, 0..m, 0..m]).to_owned();
        let atom = SaeManifoldAtom::new(format!("atom_{k}"), kind.clone(), d, phi, jet, b, s)?;
        // An empty `evaluators` slice means "attach nothing"; `get` yields None then.
        let atom = match evaluators.get(k).and_then(|slot| slot.clone()) {
            Some(evaluator) => atom.with_basis_second_jet(evaluator),
            None => atom,
        };
        atoms.push(atom);
    }

    let manifolds = basis_kinds
        .iter()
        .zip(latent_dims.iter().copied())
        .map(|(kind, d)| kind.latent_manifold(d))
        .collect();
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits.to_owned(),
        coords.to_vec(),
        manifolds,
        mode,
    )?;
    SaeManifoldTerm::new(atoms, assignment)
}