use super::*;
pub use gam_problem::WeightField;
/// Reference metric that the normalized pullback metric is compared against
/// when forming the isometry residual.
#[derive(Clone)]
pub enum IsometryReference {
/// Identity reference: each observation gets a flattened `d × d` identity block.
Euclidean,
/// Caller-provided reference, stored flattened. When it is materialized it is
/// asserted to have one row per observation and `d * d` columns.
UserSupplied(Arc<Array2<f64>>), }
impl std::fmt::Debug for IsometryReference {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
IsometryReference::Euclidean => f.write_str("Euclidean"),
IsometryReference::UserSupplied(a) => f
.debug_tuple("UserSupplied")
.field(&format_args!("{}×{}", a.nrows(), a.ncols()))
.finish(),
}
}
}
/// Inputs used to evaluate second- and third-order decoder derivatives on
/// demand through `radial_basis_cartesian_derivative` when no precomputed
/// cache is available.
#[derive(Debug, Clone)]
pub struct IsometryDuchonRadialSource {
/// Radial basis centers; the column count is asserted to equal the latent dimension.
pub centers: Arc<Array2<f64>>,
/// Radial coefficients; asserted to have one row per center and `p_out` columns.
pub radial_coefficients: Arc<Array2<f64>>,
/// Forwarded unchanged to `radial_basis_cartesian_derivative`.
pub length_scale: Option<f64>,
/// Forwarded unchanged to `radial_basis_cartesian_derivative`.
pub nullspace_order: DuchonNullspaceOrder,
/// Forwarded unchanged to `radial_basis_cartesian_derivative`.
pub power: usize,
}
/// Penalty on the deviation of the normalized pullback metric from a
/// reference metric, together with the cached decoder derivatives it needs.
///
/// The cache slots sit behind `RwLock`s so they can be refreshed through `&self`.
#[derive(Debug)]
pub struct IsometryPenalty {
/// Slice descriptor; its `latent_dim` must be set (every evaluation `expect`s it).
pub target: PsiSlice,
/// Reference metric subtracted from the normalized pullback metric.
pub reference: IsometryReference,
/// Index into the `rho` vector from which the learnable weight is resolved.
pub rho_index: usize,
/// First-order decoder Jacobian, `n_obs × (p_out * d)`, entry `[n, i * d + a]`.
pub jacobian_cache_slot: RwLock<Option<Arc<Array2<f64>>>>,
/// Second-order decoder derivative, `n_obs × (p_out * d * d)`,
/// entry `[n, (i * d + a) * d + c]`.
pub jacobian_second_cache_slot: RwLock<Option<Arc<Array2<f64>>>>,
/// Optional source used to compute second/third derivatives when their caches are empty.
pub duchon_radial_source: Option<Arc<IsometryDuchonRadialSource>>,
/// Third-order decoder derivative, indexed `[n, i, (a * d + c) * d + e]`.
pub third_decoder_derivative_slot: RwLock<Option<Arc<ndarray::Array3<f64>>>>,
/// Decoder output dimension `p`.
pub p_out: usize,
/// Output-space weighting applied to Jacobian rows (`Identity` or `Factored`).
pub weight: WeightField,
/// Base weight combined with `rho[rho_index]` by `resolve_learnable_weight`.
pub scalar_weight: f64,
/// NOTE(review): presumably consumed by the schedule macros invoked in the
/// impl blocks; not read directly in this file — confirm.
pub weight_schedule: Option<ScalarWeightSchedule>,
}
/// Everything `hvp_with_precomputed_state` needs that does not depend on the
/// direction vector, so it can be built once and reused across products.
pub(crate) struct IsometryHvpState<'a> {
// Latent dimension read from the penalty's `PsiSlice`.
d: usize,
// Number of observations (`target.len() / d`).
n_obs: usize,
// Decoder output dimension (`p_out`).
p: usize,
// Second-order decoder derivative, entry `[n, (i * d + a) * d + c]`.
jac2: CowArray<'a, f64, Ix2>,
// Third-order decoder derivative, entry `[n, i, (a * d + c) * d + e]`.
jac3: CowArray<'a, f64, Ix3>,
// Normalized metric, residual and gradient quantities at the current point.
metric: IsometryMetricState,
// Per-observation weighted Jacobian rows, each `p × d`.
wj_rows: Vec<Array2<f64>>,
}
/// Pullback metric together with the normalized residual and the quantities
/// derived from it, all stored as `n_obs × (d * d)` flattened blocks.
#[derive(Debug, Clone)]
struct IsometryMetricState {
// Un-normalized pullback metric.
g: Array2<f64>,
// `g / normalizer` minus the reference metric.
residual: Array2<f64>,
// `residual / normalizer`, with `residual_dot_g / (normalizer² · trace_denominator)`
// subtracted on the diagonal entries.
metric_grad: Array2<f64>,
// Sum of all diagonal entries of `g` divided by `trace_denominator`.
normalizer: f64,
// `n_obs * d` as a float.
trace_denominator: f64,
// Sum over all entries of `residual * g`.
residual_dot_g: f64,
}
impl IsometryMetricState {
    /// Directional change of the normalized residual for a perturbation
    /// `delta_g` of the pullback metric.
    ///
    /// Returns `(delta_residual, delta_normalizer)`, where `delta_normalizer`
    /// is the sum of the diagonal entries of `delta_g` divided by
    /// `trace_denominator`.
    fn residual_direction(&self, delta_g: ArrayView2<'_, f64>, d: usize) -> (Array2<f64>, f64) {
        let rows = self.g.nrows();
        let width = d * d;

        // Sum the diagonal of every block, row by row.
        let delta_trace_sum = (0..rows)
            .flat_map(|row| (0..d).map(move |diag| [row, diag * d + diag]))
            .fold(0.0, |trace, index| trace + delta_g[index]);
        let delta_normalizer = delta_trace_sum / self.trace_denominator;

        let inv_norm = 1.0 / self.normalizer;
        let inv_norm_sq = inv_norm * inv_norm;
        // Quotient rule applied to `g / normalizer`.
        let delta_residual = Array2::from_shape_fn((rows, width), |(row, col)| {
            delta_g[[row, col]] * inv_norm - self.g[[row, col]] * delta_normalizer * inv_norm_sq
        });
        (delta_residual, delta_normalizer)
    }

    /// Directional change of `metric_grad` for a perturbation `delta_g` of the
    /// pullback metric, returned as an `n_obs × (d * d)` array.
    fn metric_grad_direction(&self, delta_g: ArrayView2<'_, f64>, d: usize) -> Array2<f64> {
        let rows = self.g.nrows();
        let width = d * d;
        let (delta_residual, delta_normalizer) = self.residual_direction(delta_g, d);

        // Product rule on `residual · g`; the two contributions are added
        // separately so the accumulation order stays fixed.
        let mut delta_residual_dot_g = 0.0;
        for row in 0..rows {
            for col in 0..width {
                delta_residual_dot_g += delta_residual[[row, col]] * self.g[[row, col]];
                delta_residual_dot_g += self.residual[[row, col]] * delta_g[[row, col]];
            }
        }

        let inv_norm = 1.0 / self.normalizer;
        let inv_norm_sq = inv_norm * inv_norm;
        let numerator_term = delta_residual_dot_g * inv_norm_sq / self.trace_denominator;
        let normalizer_term = 2.0 * self.residual_dot_g * delta_normalizer * inv_norm_sq * inv_norm
            / self.trace_denominator;
        let delta_trace_coeff = numerator_term - normalizer_term;

        Array2::from_shape_fn((rows, width), |(row, col)| {
            let scaled = delta_residual[[row, col]] * inv_norm
                - self.residual[[row, col]] * delta_normalizer * inv_norm_sq;
            // `col / d == col % d` selects the diagonal entries of the block.
            if col / d == col % d {
                scaled - delta_trace_coeff
            } else {
                scaled
            }
        })
    }
}
/// Derivative of pullback-metric entry `(a, b)` of observation `n` with
/// respect to latent coordinate `c`.
///
/// `jac2` holds second derivatives at `[n, (i * d + a) * d + c]` and `wj` is
/// the `p × d` weighted Jacobian row of the same observation.
fn isometry_dg_entry(
    jac2: ArrayView2<'_, f64>,
    wj: ArrayView2<'_, f64>,
    n: usize,
    d: usize,
    p: usize,
    a: usize,
    b: usize,
    c: usize,
) -> f64 {
    (0..p).fold(0.0, |partial, out_dim| {
        let left = jac2[[n, (out_dim * d + a) * d + c]] * wj[[out_dim, b]];
        let right = wj[[out_dim, a]] * jac2[[n, (out_dim * d + b) * d + c]];
        (partial + left) + right
    })
}
/// Directional derivative of observation `n`'s `d × d` pullback metric along
/// `v`, using the entries `v[n * d .. n * d + d]`.
fn isometry_row_delta_g(
    jac2: ArrayView2<'_, f64>,
    wj: ArrayView2<'_, f64>,
    v: ArrayView1<'_, f64>,
    n: usize,
    d: usize,
    p: usize,
) -> Array2<f64> {
    Array2::from_shape_fn((d, d), |(a, b)| {
        (0..d).fold(0.0, |sum, c| {
            sum + isometry_dg_entry(jac2, wj, n, d, p, a, b, c) * v[n * d + c]
        })
    })
}
impl IsometryPenalty {
    /// Value reported when a required cache is missing or the metric cannot
    /// be normalized.
    pub const DEFAULT_VALUE_ON_MISSING_CACHE: f64 = 0.0;

    /// Creates a penalty against the Euclidean reference with identity
    /// weighting, unit scalar weight, `rho_index` 0 and every cache empty.
    #[must_use]
    pub fn new_euclidean(target: PsiSlice, p_out: usize) -> Self {
        Self {
            target,
            p_out,
            reference: IsometryReference::Euclidean,
            rho_index: 0,
            weight: WeightField::Identity,
            scalar_weight: 1.0,
            weight_schedule: None,
            duchon_radial_source: None,
            jacobian_cache_slot: RwLock::new(None),
            jacobian_second_cache_slot: RwLock::new(None),
            third_decoder_derivative_slot: RwLock::new(None),
        }
    }

    /// Returns a handle to the cached first-order Jacobian, if any.
    ///
    /// # Panics
    /// Panics if the slot's lock is poisoned.
    #[must_use]
    pub fn jacobian_cache(&self) -> Option<Arc<Array2<f64>>> {
        self.jacobian_cache_slot
            .read()
            .expect("IsometryPenalty::jacobian_cache_slot poisoned")
            .as_ref()
            .map(Arc::clone)
    }

    /// Returns a handle to the cached second-order derivative, if any.
    ///
    /// # Panics
    /// Panics if the slot's lock is poisoned.
    #[must_use]
    pub fn jacobian_second_cache(&self) -> Option<Arc<Array2<f64>>> {
        self.jacobian_second_cache_slot
            .read()
            .expect("IsometryPenalty::jacobian_second_cache_slot poisoned")
            .as_ref()
            .map(Arc::clone)
    }

    /// Replaces the first- and second-order caches, in that order.
    pub fn refresh_caches(&self, jac: Option<Arc<Array2<f64>>>, jac2: Option<Arc<Array2<f64>>>) {
        self.set_jacobian_cache(jac);
        self.set_jacobian_second_cache(jac2);
    }

    /// Replaces the first-order Jacobian cache (`None` clears it).
    pub fn set_jacobian_cache(&self, jac: Option<Arc<Array2<f64>>>) {
        let mut slot = self
            .jacobian_cache_slot
            .write()
            .expect("IsometryPenalty::jacobian_cache_slot poisoned");
        *slot = jac;
    }

    /// Replaces the second-order derivative cache (`None` clears it).
    pub fn set_jacobian_second_cache(&self, jac2: Option<Arc<Array2<f64>>>) {
        let mut slot = self
            .jacobian_second_cache_slot
            .write()
            .expect("IsometryPenalty::jacobian_second_cache_slot poisoned");
        *slot = jac2;
    }

    /// Returns a handle to the cached third-order derivative, if any.
    ///
    /// # Panics
    /// Panics if the slot's lock is poisoned.
    #[must_use]
    pub fn third_decoder_derivative(&self) -> Option<Arc<ndarray::Array3<f64>>> {
        self.third_decoder_derivative_slot
            .read()
            .expect("IsometryPenalty::third_decoder_derivative_slot poisoned")
            .as_ref()
            .map(Arc::clone)
    }

    /// Replaces the third-order derivative cache (`None` clears it).
    pub fn set_third_decoder_derivative(&self, jac3: Option<Arc<ndarray::Array3<f64>>>) {
        let mut slot = self
            .third_decoder_derivative_slot
            .write()
            .expect("IsometryPenalty::third_decoder_derivative_slot poisoned");
        *slot = jac3;
    }
}
/// Manual `Clone`: `RwLock` is not `Clone`, so each cache slot is rebuilt
/// around a snapshot of its current contents. The cached arrays are shared
/// through their `Arc`s, not copied.
impl Clone for IsometryPenalty {
    fn clone(&self) -> Self {
        let first_order = self.jacobian_cache();
        let second_order = self.jacobian_second_cache();
        let third_order = self.third_decoder_derivative();
        Self {
            target: self.target.clone(),
            reference: self.reference.clone(),
            rho_index: self.rho_index,
            p_out: self.p_out,
            weight: self.weight.clone(),
            scalar_weight: self.scalar_weight,
            weight_schedule: self.weight_schedule.clone(),
            duchon_radial_source: self.duchon_radial_source.clone(),
            jacobian_cache_slot: RwLock::new(first_order),
            jacobian_second_cache_slot: RwLock::new(second_order),
            third_decoder_derivative_slot: RwLock::new(third_order),
        }
    }
}
impl IsometryPenalty {
/// Builder: stores `k` as the cached third-order decoder derivative.
///
/// # Panics
/// Panics if the slot's lock is poisoned.
#[must_use]
pub fn with_third_decoder_derivative(self, k: Arc<ndarray::Array3<f64>>) -> Self {
    {
        let mut slot = self
            .third_decoder_derivative_slot
            .write()
            .expect("IsometryPenalty::third_decoder_derivative_slot poisoned");
        *slot = Some(k);
    }
    self
}
#[must_use]
pub fn with_reference(mut self, reference: IsometryReference) -> Self {
self.reference = reference;
self
}
/// Builder: stores `j` as the cached first-order Jacobian.
///
/// # Panics
/// Panics if the slot's lock is poisoned.
#[must_use]
pub fn with_jacobian_cache(self, j: Arc<Array2<f64>>) -> Self {
    {
        let mut slot = self
            .jacobian_cache_slot
            .write()
            .expect("IsometryPenalty::jacobian_cache_slot poisoned");
        *slot = Some(j);
    }
    self
}
/// Builder: stores `h` as the cached second-order decoder derivative.
///
/// # Panics
/// Panics if the slot's lock is poisoned.
#[must_use]
pub fn with_jacobian_second_cache(self, h: Arc<Array2<f64>>) -> Self {
    {
        let mut slot = self
            .jacobian_second_cache_slot
            .write()
            .expect("IsometryPenalty::jacobian_second_cache_slot poisoned");
        *slot = Some(h);
    }
    self
}
#[must_use]
pub fn with_duchon_radial_source(mut self, source: Arc<IsometryDuchonRadialSource>) -> Self {
self.duchon_radial_source = Some(source);
self
}
/// Builder: adopts the output dimension of `metric` and, only when the metric
/// reports `drives_gauge()`, its weight field as well.
#[must_use]
pub fn with_row_metric(mut self, metric: &gam_problem::RowMetric) -> Self {
    let gauge_weight = metric.drives_gauge().then(|| metric.to_weight_field());
    if let Some(weight) = gauge_weight {
        self.weight = weight;
    }
    self.p_out = metric.p_out();
    self
}
// NOTE(review): `impl_with_weight_schedule!` is defined outside this file; it
// presumably expands to a weight-schedule builder bound to the `scalar_weight`
// field — confirm against the macro definition.
impl_with_weight_schedule!(scalar_weight);
/// Logs a warning that `method` could not find the derivative state described
/// by `detail` and is falling back to its zero default.
///
/// This only logs; the caller is responsible for returning the default.
fn missing_cache_default(&self, method: &str, detail: &str) {
log::warn!(
"IsometryPenalty::{method} missing required derivative state: {detail}; \
returning the zero safe default"
);
}
/// Reports whether the first-order Jacobian cache is populated, logging a
/// warning attributed to `method` when it is not.
fn has_jacobian_cache(&self, method: &str) -> bool {
    let present = self.jacobian_cache().is_some();
    if !present {
        self.missing_cache_default(method, "jacobian_cache is None");
    }
    present
}
/// Reports whether second-order derivatives are obtainable, either from the
/// cache or from the Duchon radial source, logging a warning attributed to
/// `method` when neither is present.
fn has_jacobian_second_source(&self, method: &str) -> bool {
    let available =
        self.jacobian_second_cache().is_some() || self.duchon_radial_source.is_some();
    if !available {
        self.missing_cache_default(
            method,
            "both jacobian_second_cache and duchon_radial_source are None",
        );
    }
    available
}
/// Reports whether third-order derivatives are obtainable, either from the
/// cache or from the Duchon radial source, logging a warning attributed to
/// `method` when neither is present.
fn has_jacobian_third_source(&self, method: &str) -> bool {
    let available =
        self.third_decoder_derivative().is_some() || self.duchon_radial_source.is_some();
    if !available {
        self.missing_cache_default(
            method,
            "both third_decoder_derivative cache and duchon_radial_source are None",
        );
    }
    available
}
/// Builds the projected Jacobian of observation `n`.
///
/// With identity weighting this is the cached row reshaped to `p_out × d`;
/// with factored weighting the row is passed through
/// `WeightField::project_jac_row_with_u`. Returns `None` (after logging) when
/// the Jacobian cache is empty.
///
/// # Panics
/// Panics if the cached row or the weight-factor row is not contiguous.
fn projected_jacobian_row(&self, n: usize, d: usize) -> Option<Array2<f64>> {
    let jac = match self.jacobian_cache() {
        Some(jac) => jac,
        None => {
            self.missing_cache_default("projected_jacobian_row", "jacobian_cache is None");
            return None;
        }
    };
    let jac_row = jac.row(n);
    let jac_slice = jac_row
        .as_slice()
        .expect("jacobian cache must be in standard row-major layout");

    let projected = match &self.weight {
        WeightField::Identity => {
            Array2::from_shape_fn((self.p_out, d), |(i, a)| jac_slice[i * d + a])
        }
        WeightField::Factored { u, rank, p_out } => {
            let u_row = u.row(n);
            let u_slice = u_row
                .as_slice()
                .expect("weight factor U must be in standard row-major layout");
            WeightField::project_jac_row_with_u(u_slice, jac_slice, *p_out, *rank, d)
        }
    };
    Some(projected)
}
/// Builds the `p_out × d` weighted Jacobian of observation `n`.
///
/// With identity weighting this is the cached row reshaped; with factored
/// weighting the projected row is multiplied back by the factor `u`. Returns
/// `None` (after logging) when the Jacobian cache is empty.
///
/// # Panics
/// Panics if a factored weight's `p_out` differs from `self.p_out`.
fn weighted_jacobian_row(&self, n: usize, d: usize) -> Option<Array2<f64>> {
    let jac = match self.jacobian_cache() {
        Some(jac) => jac,
        None => {
            self.missing_cache_default("weighted_jacobian_row", "jacobian_cache is None");
            return None;
        }
    };
    let p = self.p_out;
    match &self.weight {
        WeightField::Identity => Some(Array2::from_shape_fn((p, d), |(i, a)| {
            jac[[n, i * d + a]]
        })),
        WeightField::Factored { u, rank, p_out } => {
            assert_eq!(p, *p_out);
            let r = *rank;
            let projected = self.projected_jacobian_row(n, d)?;
            Some(Array2::from_shape_fn((p, d), |(i, a)| {
                (0..r).fold(0.0, |sum, k| sum + u[[n, i * r + k]] * projected[[k, a]])
            }))
        }
    }
}
/// Weighted inner product of two length-`p` vectors given as index closures,
/// using the weight of observation `n`.
///
/// Identity weighting gives the plain dot product; factored weighting
/// projects both vectors through each of the `rank` columns of `u` and sums
/// the products of the projections.
///
/// # Panics
/// Panics if a factored weight's `p_out` differs from `p`.
fn weighted_dot_decoder_vectors<F, G>(&self, n: usize, p: usize, x: F, y: G) -> f64
where
    F: Fn(usize) -> f64,
    G: Fn(usize) -> f64,
{
    match &self.weight {
        WeightField::Identity => (0..p).fold(0.0_f64, |sum, i| sum + x(i) * y(i)),
        WeightField::Factored { u, rank, p_out } => {
            assert_eq!(p, *p_out);
            let r = *rank;
            (0..r).fold(0.0_f64, |sum, k| {
                let (ux, uy) = (0..p).fold((0.0_f64, 0.0_f64), |(ux, uy), i| {
                    let uik = u[[n, i * r + k]];
                    (ux + uik * x(i), uy + uik * y(i))
                });
                sum + ux * uy
            })
        }
    }
}
/// Copies the flat `target` vector into an `n_obs × d` matrix, reading entry
/// `(n, a)` from `target[n * d + a]`.
fn target_matrix(target: ArrayView1<'_, f64>, n_obs: usize, d: usize) -> Array2<f64> {
    Array2::from_shape_fn((n_obs, d), |(n, a)| target[n * d + a])
}
/// Evaluates the order-2 decoder derivative at the latent points in `target`
/// through `radial_basis_cartesian_derivative`.
///
/// # Panics
/// Panics if the source's shapes disagree with `d` or `self.p_out`.
fn duchon_radial_jacobian_second(
    &self,
    target: ArrayView1<'_, f64>,
    n_obs: usize,
    d: usize,
    source: &IsometryDuchonRadialSource,
) -> Result<Array2<f64>, BasisError> {
    let centers = &source.centers;
    let coefficients = &source.radial_coefficients;
    assert_eq!(centers.ncols(), d);
    assert_eq!(coefficients.nrows(), centers.nrows());
    assert_eq!(coefficients.ncols(), self.p_out);

    let latent_points = Self::target_matrix(target, n_obs, d);
    radial_basis_cartesian_derivative(
        2,
        latent_points.view(),
        centers.view(),
        coefficients.view(),
        source.length_scale,
        source.nullspace_order,
        source.power,
    )
}
/// Evaluates the order-3 decoder derivative at the latent points in `target`
/// through `radial_basis_cartesian_derivative` and reshapes the flat result
/// to `(n_obs, p_out, d³)`.
///
/// # Panics
/// Panics if the source's shapes disagree with `d` or `self.p_out`, or if the
/// flat result cannot be reshaped.
fn duchon_radial_jacobian_third(
    &self,
    target: ArrayView1<'_, f64>,
    n_obs: usize,
    d: usize,
    source: &IsometryDuchonRadialSource,
) -> Result<ndarray::Array3<f64>, BasisError> {
    let centers = &source.centers;
    let coefficients = &source.radial_coefficients;
    assert_eq!(centers.ncols(), d);
    assert_eq!(coefficients.nrows(), centers.nrows());
    assert_eq!(coefficients.ncols(), self.p_out);

    let latent_points = Self::target_matrix(target, n_obs, d);
    let flat = radial_basis_cartesian_derivative(
        3,
        latent_points.view(),
        centers.view(),
        coefficients.view(),
        source.length_scale,
        source.nullspace_order,
        source.power,
    )?;
    let cube_shape = (n_obs, self.p_out, d * d * d);
    let reshaped = flat
        .into_shape_with_order(cube_shape)
        .expect("radial_basis_cartesian_derivative order-3 output reshapes to (n_obs, p, d³)");
    Ok(reshaped)
}
/// Returns the second-order decoder derivative: a copy of the cached array
/// when present, otherwise one computed from the Duchon radial source.
///
/// Returns `None` when neither is available, or (after logging) when the
/// radial evaluation fails.
fn jacobian_second<'a>(
    &'a self,
    target: ArrayView1<'_, f64>,
    n_obs: usize,
    d: usize,
) -> Option<CowArray<'a, f64, Ix2>> {
    let materialized = match self.jacobian_second_cache() {
        // The handle comes out of a lock, so the data is copied to get an owned array.
        Some(cached) => return Some(CowArray::from(cached.as_ref().clone())),
        None => {
            let source = self.duchon_radial_source.as_ref()?;
            self.duchon_radial_jacobian_second(target, n_obs, d, source)
        }
    };
    match materialized {
        Err(err) => {
            self.missing_cache_default(
                "jacobian_second",
                &format!("failed to materialize Duchon radial second derivative: {err}"),
            );
            None
        }
        Ok(jac2) => Some(CowArray::from(jac2)),
    }
}
/// Returns the third-order decoder derivative: a copy of the cached array
/// when present, otherwise one computed from the Duchon radial source.
///
/// Returns `None` when neither is available, or (after logging) when the
/// radial evaluation fails.
fn jacobian_third<'a>(
    &'a self,
    target: ArrayView1<'_, f64>,
    n_obs: usize,
    d: usize,
) -> Option<CowArray<'a, f64, Ix3>> {
    let materialized = match self.third_decoder_derivative() {
        // The handle comes out of a lock, so the data is copied to get an owned array.
        Some(cached) => return Some(CowArray::from((*cached).clone())),
        None => {
            let source = self.duchon_radial_source.as_ref()?;
            self.duchon_radial_jacobian_third(target, n_obs, d, source)
        }
    };
    match materialized {
        Err(err) => {
            self.missing_cache_default(
                "jacobian_third",
                &format!("failed to materialize Duchon radial third derivative: {err}"),
            );
            None
        }
        Ok(jac3) => Some(CowArray::from(jac3)),
    }
}
/// Gathers everything a Hessian-vector product needs that does not depend on
/// the direction vector.
///
/// Returns `None` when any required derivative source is missing or the
/// metric cannot be normalized; the helpers that detect this log a warning.
///
/// # Panics
/// Panics if the penalty's `PsiSlice` has no `latent_dim`.
pub(crate) fn hvp_state<'a>(
    &'a self,
    target: ArrayView1<'_, f64>,
) -> Option<IsometryHvpState<'a>> {
    let d = self
        .target
        .latent_dim
        .expect("IsometryPenalty requires latent_dim on its PsiSlice");
    let n_obs = target.len() / d;

    // Short-circuiting keeps the warnings to the first missing source only.
    let sources_ready = self.has_jacobian_cache("hvp")
        && self.has_jacobian_second_source("hvp")
        && self.has_jacobian_third_source("hvp");
    if !sources_ready {
        return None;
    }

    let jac2 = self.jacobian_second(target.view(), n_obs, d)?;
    let jac3 = self.jacobian_third(target.view(), n_obs, d)?;
    let metric = self
        .pullback_metric(d)
        .and_then(|g| self.normalized_metric_state(g, n_obs, d))?;
    let wj_rows = (0..n_obs)
        .map(|n| self.weighted_jacobian_row(n, d))
        .collect::<Option<Vec<_>>>()?;

    Some(IsometryHvpState {
        d,
        n_obs,
        p: self.p_out,
        jac2,
        jac3,
        metric,
        wj_rows,
    })
}
/// Applies the penalty's Hessian with respect to the target to direction `v`,
/// using quantities gathered once by `hvp_state`.
///
/// The result has the same length as `v` and is the sum of two terms per
/// observation, each scaled by the resolved weight `mu`:
/// 1. the first derivative of the metric contracted with the directional
///    change of `metric_grad` along `v`;
/// 2. `metric_grad` contracted with the second derivative of the metric
///    along `v`.
pub(crate) fn hvp_with_precomputed_state(
&self,
state: &IsometryHvpState<'_>,
rho: ArrayView1<'_, f64>,
v: ArrayView1<'_, f64>,
) -> Array1<f64> {
let mu = resolve_learnable_weight(self.scalar_weight, rho[self.rho_index]);
let d = state.d;
let n_obs = state.n_obs;
let p = state.p;
let jac2 = &state.jac2;
let jac3 = &state.jac3;
let metric = &state.metric;
let mut out = Array1::<f64>::zeros(v.len());
// Directional derivative of every observation's metric along `v`, flattened per row.
let mut delta_g = Array2::<f64>::zeros((n_obs, d * d));
for n in 0..n_obs {
let wj = &state.wj_rows[n];
let row_delta = isometry_row_delta_g(jac2.view(), wj.view(), v, n, d, p);
for a in 0..d {
for b in 0..d {
delta_g[[n, a * d + b]] = row_delta[[a, b]];
}
}
}
// The normalizer couples all observations, so this needs the full `delta_g`.
let delta_metric_grad = metric.metric_grad_direction(delta_g.view(), d);
for n in 0..n_obs {
let wj = &state.wj_rows[n];
// Term 1: d(metric)/d(target_c) contracted with the change in metric_grad.
for c in 0..d {
let mut acc = 0.0;
for a in 0..d {
for b in 0..d {
let dg = isometry_dg_entry(jac2.view(), wj.view(), n, d, p, a, b, c);
acc += dg * delta_metric_grad[[n, a * d + b]];
}
}
out[n * d + c] = mu * acc;
}
// Term 2: metric_grad contracted with the second derivative of the metric along `v`.
for c in 0..d {
let mut acc_res = 0.0;
for a in 0..d {
for b in 0..d {
let metric_grad = metric.metric_grad[[n, a * d + b]];
// Exact zeros contribute nothing; skip the inner contraction.
if metric_grad == 0.0 {
continue;
}
let mut bv = 0.0;
for dd in 0..d {
let vd = v[n * d + dd];
if vd == 0.0 {
continue;
}
// Third derivative (index a) against the weighted Jacobian (index b).
let mut k_a_cd_w_j_b = 0.0;
for i in 0..p {
k_a_cd_w_j_b += jac3[[n, i, ((a * d) + c) * d + dd]] * wj[[i, b]];
}
// Two weighted products of second derivatives, with `c` and `dd` swapped.
let h_a_c_w_h_b_d = self.weighted_dot_decoder_vectors(
n,
p,
|i| jac2[[n, (i * d + a) * d + c]],
|i| jac2[[n, (i * d + b) * d + dd]],
);
let h_a_d_w_h_b_c = self.weighted_dot_decoder_vectors(
n,
p,
|i| jac2[[n, (i * d + a) * d + dd]],
|i| jac2[[n, (i * d + b) * d + c]],
);
// Weighted Jacobian (index a) against the third derivative (index b).
let mut j_a_w_k_b_cd = 0.0;
for i in 0..p {
j_a_w_k_b_cd += wj[[i, a]] * jac3[[n, i, ((b * d) + c) * d + dd]];
}
bv +=
(k_a_cd_w_j_b + h_a_c_w_h_b_d + h_a_d_w_h_b_c + j_a_w_k_b_cd) * vd;
}
acc_res += metric_grad * bv;
}
}
out[n * d + c] += mu * acc_res;
}
}
out
}
/// Computes the pullback metric of every observation from the cached
/// first-order Jacobian.
///
/// Row `n` of the result holds the flattened `latent_dim × latent_dim` Gram
/// matrix of observation `n`'s projected Jacobian. Returns `None` (after
/// logging) when the Jacobian cache is empty.
///
/// # Panics
/// Panics if the cached Jacobian does not have `p_out * latent_dim` columns.
pub fn pullback_metric(&self, latent_dim: usize) -> Option<Array2<f64>> {
    let jac = match self.jacobian_cache() {
        Some(jac) => jac,
        None => {
            self.missing_cache_default("pullback_metric", "jacobian_cache is None");
            return None;
        }
    };
    let n_obs = jac.nrows();
    assert_eq!(jac.ncols(), self.p_out * latent_dim);

    let mut g_all = Array2::<f64>::zeros((n_obs, latent_dim * latent_dim));
    for n in 0..n_obs {
        let projected = self.projected_jacobian_row(n, latent_dim)?;
        let rank = projected.nrows();
        for a in 0..latent_dim {
            for b in 0..latent_dim {
                g_all[[n, a * latent_dim + b]] =
                    (0..rank).fold(0.0, |sum, k| sum + projected[[k, a]] * projected[[k, b]]);
            }
        }
    }
    Some(g_all)
}
/// Returns the reference metric as an `n_obs × (d * d)` array: a borrowed
/// view of the user-supplied array, or freshly built identity blocks for the
/// Euclidean reference.
///
/// # Panics
/// Panics if a user-supplied reference does not have shape `n_obs × (d * d)`.
fn reference_metric(&self, n_obs: usize, d: usize) -> CowArray<'_, f64, Ix2> {
    match &self.reference {
        IsometryReference::UserSupplied(supplied) => {
            assert_eq!(supplied.nrows(), n_obs);
            assert_eq!(supplied.ncols(), d * d);
            CowArray::from(supplied.view())
        }
        IsometryReference::Euclidean => {
            // `k / d == k % d` picks the diagonal of each flattened block.
            let identity_blocks = Array2::<f64>::from_shape_fn((n_obs, d * d), |(_, k)| {
                if k / d == k % d {
                    1.0
                } else {
                    0.0
                }
            });
            CowArray::from(identity_blocks)
        }
    }
}
/// Normalizes the pullback metric `g` by its mean diagonal entry and derives
/// the residual against the reference plus the gradient quantities built
/// from it.
///
/// Returns `None` (after logging) when the normalizer is not finite or not
/// greater than `f64::MIN_POSITIVE`.
fn normalized_metric_state(
    &self,
    g: Array2<f64>,
    n_obs: usize,
    d: usize,
) -> Option<IsometryMetricState> {
    let dd = d * d;
    let trace_denominator = (n_obs * d) as f64;
    let trace_sum = (0..n_obs)
        .flat_map(|n| (0..d).map(move |a| [n, a * d + a]))
        .fold(0.0, |trace, index| trace + g[index]);
    let normalizer = trace_sum / trace_denominator;

    let usable = normalizer.is_finite() && normalizer > f64::MIN_POSITIVE;
    if !usable {
        self.missing_cache_default(
            "normalized_metric_state",
            &format!(
                "unit-average-speed normalizer is non-positive or non-finite: {normalizer}"
            ),
        );
        return None;
    }

    let g_ref = self.reference_metric(n_obs, d);
    let inv_norm = 1.0 / normalizer;
    let residual =
        Array2::from_shape_fn((n_obs, dd), |(n, k)| g[[n, k]] * inv_norm - g_ref[[n, k]]);

    let mut residual_dot_g = 0.0;
    for n in 0..n_obs {
        for k in 0..dd {
            residual_dot_g += residual[[n, k]] * g[[n, k]];
        }
    }

    // The normalizer depends on the diagonal entries only, so its
    // contribution is subtracted from the diagonal of each block.
    let trace_coeff = residual_dot_g / (normalizer * normalizer * trace_denominator);
    let metric_grad = Array2::from_shape_fn((n_obs, dd), |(n, k)| {
        let scaled = residual[[n, k]] * inv_norm;
        if k / d == k % d {
            scaled - trace_coeff
        } else {
            scaled
        }
    });

    Some(IsometryMetricState {
        g,
        residual,
        metric_grad,
        normalizer,
        trace_denominator,
        residual_dot_g,
    })
}
/// Gradient of the penalty with respect to the first-order Jacobian entries,
/// returned as `n_obs × (p_out * d)` with entry `[n, i * d + c]`.
///
/// Returns all zeros when a required cache is missing or the metric cannot
/// be normalized.
///
/// # Panics
/// Panics if the penalty's `PsiSlice` has no `latent_dim`.
pub fn grad_jacobian(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
) -> Array2<f64> {
    let d = self
        .target
        .latent_dim
        .expect("IsometryPenalty requires latent_dim on its PsiSlice");
    let n_obs = target.len() / d;
    let p = self.p_out;
    let mut grad = Array2::<f64>::zeros((n_obs, p * d));
    if !self.has_jacobian_cache("grad_jacobian") {
        return grad;
    }
    let metric = match self
        .pullback_metric(d)
        .and_then(|g| self.normalized_metric_state(g, n_obs, d))
    {
        Some(metric) => metric,
        None => return grad,
    };
    let mu = resolve_learnable_weight(self.scalar_weight, rho[self.rho_index]);
    for n in 0..n_obs {
        let wj = match self.weighted_jacobian_row(n, d) {
            Some(wj) => wj,
            // Discard any rows already filled so the fallback is all zeros.
            None => return Array2::<f64>::zeros((n_obs, p * d)),
        };
        for i in 0..p {
            for c in 0..d {
                let contraction = (0..d).fold(0.0, |acc, b| {
                    acc + metric.metric_grad[[n, c * d + b]] * wj[[i, b]]
                });
                grad[[n, i * d + c]] = 2.0 * mu * contraction;
            }
        }
    }
    grad
}
}
impl AnalyticPenalty for IsometryPenalty {
/// This penalty always reports the `Psi` tier.
fn tier(&self) -> PenaltyTier {
PenaltyTier::Psi
}
/// Penalty value: half the resolved weight times the sum of squared
/// normalized-metric residuals.
///
/// Returns `DEFAULT_VALUE_ON_MISSING_CACHE` when the Jacobian cache is empty
/// or the metric cannot be normalized.
///
/// # Panics
/// Panics if the penalty's `PsiSlice` has no `latent_dim`.
fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
    let d = self
        .target
        .latent_dim
        .expect("IsometryPenalty requires latent_dim on its PsiSlice");
    let n_obs = target.len() / d;
    if !self.has_jacobian_cache("value") {
        return Self::DEFAULT_VALUE_ON_MISSING_CACHE;
    }
    let metric = match self
        .pullback_metric(d)
        .and_then(|g| self.normalized_metric_state(g, n_obs, d))
    {
        Some(metric) => metric,
        None => return Self::DEFAULT_VALUE_ON_MISSING_CACHE,
    };
    let mu = resolve_learnable_weight(self.scalar_weight, rho[self.rho_index]);
    // The residual is `n_obs × (d * d)`; iterate it in row-major order.
    let sum_of_squares = metric
        .residual
        .iter()
        .fold(0.0, |acc, &diff| acc + diff * diff);
    0.5 * mu * sum_of_squares
}
fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
let d = self
.target
.latent_dim
.expect("IsometryPenalty requires latent_dim on its PsiSlice");
let n_obs = target.len() / d;
if !self.has_jacobian_cache("grad_target")
|| !self.has_jacobian_second_source("grad_target")
{
return Array1::<f64>::zeros(target.len());
}
let Some(g) = self.pullback_metric(d) else {
return Array1::<f64>::zeros(target.len());
};
let Some(metric) = self.normalized_metric_state(g, n_obs, d) else {
return Array1::<f64>::zeros(target.len());
};
let p = self.p_out;
let mu = resolve_learnable_weight(self.scalar_weight, rho[self.rho_index]);
let mut grad = Array1::<f64>::zeros(target.len());
let Some(jac2) = self.jacobian_second(target, n_obs, d) else {
return grad;
};
assert_eq!(jac2.ncols(), p * d * d);
for n in 0..n_obs {
let Some(wj) = self.weighted_jacobian_row(n, d) else {
return grad;
};
for c in 0..d {
let mut acc = 0.0;
for a in 0..d {
for b in 0..d {
let mut dg = 0.0;
for i in 0..p {
dg += jac2[[n, (i * d + a) * d + c]] * wj[[i, b]];
dg += wj[[i, a]] * jac2[[n, (i * d + b) * d + c]];
}
acc += metric.metric_grad[[n, a * d + b]] * dg;
}
}
grad[n * d + c] = mu * acc;
}
}
grad
}
/// Hessian-vector product with respect to the target: builds the reusable
/// state and applies it to `v`, or returns zeros of `v`'s length when the
/// state cannot be built.
fn hvp(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
    v: ArrayView1<'_, f64>,
) -> Array1<f64> {
    match self.hvp_state(target) {
        Some(state) => self.hvp_with_precomputed_state(&state, rho, v),
        None => Array1::<f64>::zeros(v.len()),
    }
}
/// Applies to `v` the part of the Hessian built from first derivatives of the
/// residual only: the directional change of the residual along `v` is
/// contracted back with the residual's derivative with respect to each
/// target coordinate.
///
/// Only first- and second-order decoder derivatives are needed. Returns zeros
/// of `v`'s length when a required source is missing or the metric cannot be
/// normalized.
fn psd_majorizer_hvp(
&self,
target: ArrayView1<'_, f64>,
rho: ArrayView1<'_, f64>,
v: ArrayView1<'_, f64>,
) -> Array1<f64> {
let d = self
.target
.latent_dim
.expect("IsometryPenalty requires latent_dim on its PsiSlice");
let n_obs = target.len() / d;
if !self.has_jacobian_cache("psd_majorizer_hvp")
|| !self.has_jacobian_second_source("psd_majorizer_hvp")
{
return Array1::<f64>::zeros(v.len());
}
let Some(jac2) = self.jacobian_second(target, n_obs, d) else {
return Array1::<f64>::zeros(v.len());
};
let Some(g) = self.pullback_metric(d) else {
return Array1::<f64>::zeros(v.len());
};
let Some(metric) = self.normalized_metric_state(g, n_obs, d) else {
return Array1::<f64>::zeros(v.len());
};
let p = self.p_out;
let mu = resolve_learnable_weight(self.scalar_weight, rho[self.rho_index]);
let mut out = Array1::<f64>::zeros(v.len());
// Weighted Jacobian rows are reused by both passes below.
let mut wj_rows = Vec::with_capacity(n_obs);
for n in 0..n_obs {
let Some(wj) = self.weighted_jacobian_row(n, d) else {
return Array1::<f64>::zeros(v.len());
};
wj_rows.push(wj);
}
// Pass 1: directional derivative of every observation's metric along `v`.
let mut delta_g = Array2::<f64>::zeros((n_obs, d * d));
for n in 0..n_obs {
let row_delta = isometry_row_delta_g(jac2.view(), wj_rows[n].view(), v, n, d, p);
for a in 0..d {
for b in 0..d {
delta_g[[n, a * d + b]] = row_delta[[a, b]];
}
}
}
let (delta_residual, _delta_normalizer) = metric.residual_direction(delta_g.view(), d);
// Global coupling through the normalizer: sum over all entries of g * delta_residual.
let mut g_dot_delta_residual = 0.0;
for n in 0..n_obs {
for k in 0..(d * d) {
g_dot_delta_residual += metric.g[[n, k]] * delta_residual[[n, k]];
}
}
let inv_norm = 1.0 / metric.normalizer;
let inv_norm_sq = inv_norm * inv_norm;
// Pass 2: contract delta_residual with the residual's derivative per coordinate.
for n in 0..n_obs {
let wj = &wj_rows[n];
for c in 0..d {
// Derivative of the normalizer with respect to target coordinate (n, c).
let mut trace_dg = 0.0;
for a in 0..d {
trace_dg += isometry_dg_entry(jac2.view(), wj.view(), n, d, p, a, a, c);
}
let delta_normalizer_c = trace_dg / metric.trace_denominator;
let mut acc = -delta_normalizer_c * inv_norm_sq * g_dot_delta_residual;
for a in 0..d {
for b in 0..d {
let dg = isometry_dg_entry(jac2.view(), wj.view(), n, d, p, a, b, c);
acc += dg * inv_norm * delta_residual[[n, a * d + b]];
}
}
out[n * d + c] = mu * acc;
}
}
out
}
/// Gradient with respect to the penalty's rho parameter: the entry at
/// `rho_index` is set to the current penalty value.
///
/// NOTE(review): the output has length `rho_count()` (1) but is indexed by
/// `rho_index`, so a non-zero `rho_index` would panic here — confirm that
/// `rho_index` is always local to this penalty.
fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
    let penalty_value = self.value(target, rho);
    let mut gradient = Array1::<f64>::zeros(self.rho_count());
    gradient[self.rho_index] = penalty_value;
    gradient
}
/// This penalty owns exactly one rho parameter.
fn rho_count(&self) -> usize {
1
}
/// Fixed identifier string for this penalty.
fn name(&self) -> &str {
"isometry"
}
// NOTE(review): `impl_scalar_apply_schedule!` is defined outside this file; it
// presumably supplies the trait's schedule-application method bound to the
// `scalar_weight` field — confirm against the macro definition.
impl_scalar_apply_schedule!(scalar_weight);
}