use super::*;
/// Recompute the decoder-derivative caches of `penalty` from a single atom.
///
/// The atom's basis jets are evaluated at `coords` (one row per observation).
/// Each jet is contracted over the basis index `k` with the decoder
/// coefficients `B = atom.decoder_coefficients`, read as `B[[k, i]]` for basis
/// row `k < m` and output column `i < p`. Here `d = atom.latent_dim`.
///
/// * First order, shape `(n_obs, p * d)`:
///   `[n, i * d + a] = sum_k jet[n, k, a] * B[k, i]`.
///   It is always passed to `penalty.refresh_caches`.
/// * Second order, shape `(n_obs, p * d * d)`:
///   `[n, (i * d + a) * d + c]`.
///   It is computed only when the atom has a `basis_second_jet` evaluator;
///   otherwise `None` is passed.
/// * Third order, shape `(n_obs, p, d * d * d)`:
///   `[n, i, (a * d + c) * d + e]`.
///   It is computed only when `penalty.duchon_radial_source` is `None` and the
///   evaluator returns a third jet; otherwise `None` is passed to
///   `penalty.set_third_decoder_derivative`.
///
/// `penalty` is only touched after every jet has been evaluated and validated,
/// so an `Err` return leaves it unmodified by this call.
///
/// # Returns
/// `Ok(true)` when a second-order contraction was computed and passed on,
/// `Ok(false)` otherwise.
///
/// # Errors
/// Returns `Err` when:
/// * the atom has no basis evaluator;
/// * `penalty.p_out` differs from the decoder's column count;
/// * the decoder's row count differs from `atom.basis_size()`;
/// * any jet has an unexpected shape.
///
/// Errors from the jet evaluators themselves are propagated.
pub fn refresh_isometry_caches_from_atom(
    penalty: &IsometryPenalty,
    atom: &SaeManifoldAtom,
    coords: ArrayView2<'_, f64>,
) -> Result<bool, String> {
    let evaluator = atom.basis_evaluator.as_ref().ok_or_else(|| {
        format!(
            "refresh_isometry_caches_from_atom: atom {} has no basis evaluator",
            atom.name
        )
    })?;
    let n_obs = coords.nrows();
    let d = atom.latent_dim;
    let m = atom.basis_size();
    let b = &atom.decoder_coefficients;
    let p = b.ncols();

    // Validate the decoder before paying for any jet evaluation. The row check
    // also guards the `b[[mm, i]]` indexing below, which would otherwise panic
    // on a short decoder (or silently ignore the trailing rows of a long one).
    if penalty.p_out != p {
        return Err(format!(
            "refresh_isometry_caches_from_atom: penalty.p_out={} but atom.decoder.cols={p}",
            penalty.p_out
        ));
    }
    if b.nrows() != m {
        return Err(format!(
            "refresh_isometry_caches_from_atom: atom.decoder.rows={} but atom.basis_size={m}",
            b.nrows()
        ));
    }

    // First-order contraction.
    let (_phi, jet) = evaluator.evaluate(coords)?;
    if jet.dim() != (n_obs, m, d) {
        return Err(format!(
            "refresh_isometry_caches_from_atom: evaluator first jet has shape {:?}, expected ({n_obs}, {m}, {d})",
            jet.dim()
        ));
    }
    let mut jac = Array2::<f64>::zeros((n_obs, p * d));
    for n in 0..n_obs {
        for i in 0..p {
            for a in 0..d {
                let mut acc = 0.0;
                for mm in 0..m {
                    acc += jet[[n, mm, a]] * b[[mm, i]];
                }
                jac[[n, i * d + a]] = acc;
            }
        }
    }

    // Second-order contraction, only if the atom can supply a second jet.
    let jac2_opt = if let Some(second_eval) = atom.basis_second_jet.as_ref() {
        let hess = second_eval.second_jet(coords)?;
        if hess.dim() != (n_obs, m, d, d) {
            return Err(format!(
                "refresh_isometry_caches_from_atom: evaluator second jet has shape {:?}, expected ({n_obs}, {m}, {d}, {d})",
                hess.dim()
            ));
        }
        let mut jac2 = Array2::<f64>::zeros((n_obs, p * d * d));
        for n in 0..n_obs {
            for i in 0..p {
                for a in 0..d {
                    for c in 0..d {
                        let mut acc = 0.0;
                        for mm in 0..m {
                            acc += hess[[n, mm, a, c]] * b[[mm, i]];
                        }
                        jac2[[n, (i * d + a) * d + c]] = acc;
                    }
                }
            }
        }
        Some(Arc::new(jac2))
    } else {
        None
    };

    // Third-order contraction. It is skipped entirely when the penalty carries a
    // Duchon radial source, and skipped when the evaluator offers no third jet.
    let jac3_opt = if penalty.duchon_radial_source.is_none() {
        match evaluator.third_jet_dyn(coords).transpose()? {
            Some(t3) => {
                if t3.dim() != (n_obs, m, d, d, d) {
                    return Err(format!(
                        "refresh_isometry_caches_from_atom: evaluator third jet has shape {:?}, expected ({n_obs}, {m}, {d}, {d}, {d})",
                        t3.dim()
                    ));
                }
                let mut jac3 = Array3::<f64>::zeros((n_obs, p, d * d * d));
                for n in 0..n_obs {
                    for i in 0..p {
                        for a in 0..d {
                            for c in 0..d {
                                for e in 0..d {
                                    let mut acc = 0.0;
                                    for mm in 0..m {
                                        acc += t3[[n, mm, a, c, e]] * b[[mm, i]];
                                    }
                                    jac3[[n, i, ((a * d) + c) * d + e]] = acc;
                                }
                            }
                        }
                    }
                }
                Some(Arc::new(jac3))
            }
            None => None,
        }
    } else {
        None
    };

    // All jets are validated; now update the penalty in one go.
    let installed = jac2_opt.is_some();
    penalty.refresh_caches(Some(Arc::new(jac)), jac2_opt);
    penalty.set_third_decoder_derivative(jac3_opt);
    Ok(installed)
}
/// Refresh the caches of every isometry penalty in `registry` from the atoms of `term`.
///
/// Penalties are matched to atoms by the signature `(latent_dim, p_out)`. The
/// k-th isometry penalty with a given signature (in registry order) is paired
/// with the k-th atom of `term` (in atom order) that:
/// * has that `latent_dim`;
/// * has exactly `p_out` decoder columns;
/// * has a basis evaluator.
///
/// A penalty is skipped when its target has no latent dimension, or when no
/// unpaired matching atom remains. `coords_per_atom[j]` holds the coordinates
/// used for `term.atoms[j]`.
///
/// # Returns
/// The number of penalties for which [`refresh_isometry_caches_from_atom`]
/// reported a second-order contraction.
///
/// # Errors
/// Returns `Err` if `coords_per_atom.len() != term.atoms.len()`. Otherwise it
/// propagates the first error from [`refresh_isometry_caches_from_atom`].
/// Penalties processed before that error keep their refreshed caches.
pub fn refresh_isometry_caches_from_term(
    registry: &AnalyticPenaltyRegistry,
    term: &SaeManifoldTerm,
    coords_per_atom: &[Array2<f64>],
) -> Result<usize, String> {
    let atom_count = term.atoms.len();
    if coords_per_atom.len() != atom_count {
        return Err(format!(
            "refresh_isometry_caches_from_term: coords_per_atom length {} != number of atoms {}",
            coords_per_atom.len(),
            atom_count
        ));
    }

    // For each signature: how many matching atoms earlier penalties have already claimed.
    let mut next_slot: std::collections::HashMap<(usize, usize), usize> =
        std::collections::HashMap::new();
    let mut with_second_jet = 0usize;

    let isometry_penalties = registry.penalties.iter().filter_map(|entry| match entry {
        AnalyticPenaltyKind::Isometry(iso) => Some(iso),
        _ => None,
    });

    for penalty in isometry_penalties {
        let Some(latent_dim) = penalty.target.latent_dim else {
            continue;
        };
        let slot = next_slot.entry((latent_dim, penalty.p_out)).or_insert(0);

        // Pick the `*slot`-th atom (0-based) compatible with this penalty's signature.
        let candidate = term
            .atoms
            .iter()
            .enumerate()
            .filter(|(_, atom)| {
                atom.latent_dim == latent_dim
                    && atom.decoder_coefficients.ncols() == penalty.p_out
                    && atom.basis_evaluator.is_some()
            })
            .nth(*slot);

        if let Some((idx, atom)) = candidate {
            *slot += 1;
            let coords = coords_per_atom[idx].view();
            if refresh_isometry_caches_from_atom(penalty, atom, coords)? {
                with_second_jet += 1;
            }
        }
    }

    Ok(with_second_jet)
}