1use super::*;
2pub use gam_problem::WeightField;
3
4#[derive(Clone)]
16pub enum IsometryReference {
17 Euclidean,
18 UserSupplied(Arc<Array2<f64>>), }
20
21impl std::fmt::Debug for IsometryReference {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 match self {
24 IsometryReference::Euclidean => f.write_str("Euclidean"),
25 IsometryReference::UserSupplied(a) => f
26 .debug_tuple("UserSupplied")
27 .field(&format_args!("{}×{}", a.nrows(), a.ncols()))
28 .finish(),
29 }
30 }
31}
32
33#[derive(Debug, Clone)]
41pub struct IsometryDuchonRadialSource {
42 pub centers: Arc<Array2<f64>>,
43 pub radial_coefficients: Arc<Array2<f64>>,
44 pub length_scale: Option<f64>,
45 pub nullspace_order: DuchonNullspaceOrder,
46 pub power: usize,
52}
53
54#[derive(Debug)]
130pub struct IsometryPenalty {
131 pub target: PsiSlice,
132 pub reference: IsometryReference,
133 pub rho_index: usize,
136 pub jacobian_cache_slot: RwLock<Option<Arc<Array2<f64>>>>,
142 pub jacobian_second_cache_slot: RwLock<Option<Arc<Array2<f64>>>>,
149 pub duchon_radial_source: Option<Arc<IsometryDuchonRadialSource>>,
153 pub third_decoder_derivative_slot: RwLock<Option<Arc<ndarray::Array3<f64>>>>,
166 pub p_out: usize,
168 pub weight: WeightField,
174 pub scalar_weight: f64,
175 pub weight_schedule: Option<ScalarWeightSchedule>,
176}
177
178pub(crate) struct IsometryHvpState<'a> {
179 d: usize,
180 n_obs: usize,
181 p: usize,
182 jac2: CowArray<'a, f64, Ix2>,
183 jac3: CowArray<'a, f64, Ix3>,
184 metric: IsometryMetricState,
185 wj_rows: Vec<Array2<f64>>,
186}
187
188#[derive(Debug, Clone)]
189struct IsometryMetricState {
190 g: Array2<f64>,
191 residual: Array2<f64>,
192 metric_grad: Array2<f64>,
193 normalizer: f64,
194 trace_denominator: f64,
195 residual_dot_g: f64,
196}
197
198impl IsometryMetricState {
199 fn residual_direction(&self, delta_g: ArrayView2<'_, f64>, d: usize) -> (Array2<f64>, f64) {
200 let n_obs = self.g.nrows();
201 let dd = d * d;
202 let mut delta_trace_sum = 0.0;
203 for n in 0..n_obs {
204 for a in 0..d {
205 delta_trace_sum += delta_g[[n, a * d + a]];
206 }
207 }
208 let delta_normalizer = delta_trace_sum / self.trace_denominator;
209 let inv_norm = 1.0 / self.normalizer;
210 let inv_norm_sq = inv_norm * inv_norm;
211 let mut delta_residual = Array2::<f64>::zeros((n_obs, dd));
212 for n in 0..n_obs {
213 for k in 0..dd {
214 delta_residual[[n, k]] =
215 delta_g[[n, k]] * inv_norm - self.g[[n, k]] * delta_normalizer * inv_norm_sq;
216 }
217 }
218 (delta_residual, delta_normalizer)
219 }
220
221 fn metric_grad_direction(&self, delta_g: ArrayView2<'_, f64>, d: usize) -> Array2<f64> {
222 let n_obs = self.g.nrows();
223 let dd = d * d;
224 let (delta_residual, delta_normalizer) = self.residual_direction(delta_g, d);
225 let mut delta_residual_dot_g = 0.0;
226 for n in 0..n_obs {
227 for k in 0..dd {
228 delta_residual_dot_g += delta_residual[[n, k]] * self.g[[n, k]];
229 delta_residual_dot_g += self.residual[[n, k]] * delta_g[[n, k]];
230 }
231 }
232 let inv_norm = 1.0 / self.normalizer;
233 let inv_norm_sq = inv_norm * inv_norm;
234 let delta_trace_coeff = delta_residual_dot_g * inv_norm_sq / self.trace_denominator
235 - 2.0 * self.residual_dot_g * delta_normalizer * inv_norm_sq * inv_norm
236 / self.trace_denominator;
237 let mut out = Array2::<f64>::zeros((n_obs, dd));
238 for n in 0..n_obs {
239 for a in 0..d {
240 for b in 0..d {
241 let k = a * d + b;
242 let mut value = delta_residual[[n, k]] * inv_norm
243 - self.residual[[n, k]] * delta_normalizer * inv_norm_sq;
244 if a == b {
245 value -= delta_trace_coeff;
246 }
247 out[[n, k]] = value;
248 }
249 }
250 }
251 out
252 }
253}
254
255fn isometry_dg_entry(
256 jac2: ArrayView2<'_, f64>,
257 wj: ArrayView2<'_, f64>,
258 n: usize,
259 d: usize,
260 p: usize,
261 a: usize,
262 b: usize,
263 c: usize,
264) -> f64 {
265 let mut s = 0.0;
266 for i in 0..p {
267 s += jac2[[n, (i * d + a) * d + c]] * wj[[i, b]];
268 s += wj[[i, a]] * jac2[[n, (i * d + b) * d + c]];
269 }
270 s
271}
272
273fn isometry_row_delta_g(
274 jac2: ArrayView2<'_, f64>,
275 wj: ArrayView2<'_, f64>,
276 v: ArrayView1<'_, f64>,
277 n: usize,
278 d: usize,
279 p: usize,
280) -> Array2<f64> {
281 let mut delta_g = Array2::<f64>::zeros((d, d));
282 for a in 0..d {
283 for b in 0..d {
284 let mut s = 0.0;
285 for c in 0..d {
286 s += isometry_dg_entry(jac2, wj, n, d, p, a, b, c) * v[n * d + c];
287 }
288 delta_g[[a, b]] = s;
289 }
290 }
291 delta_g
292}
293
294impl IsometryPenalty {
295 pub const DEFAULT_VALUE_ON_MISSING_CACHE: f64 = 0.0;
296
297 #[must_use]
298 pub fn new_euclidean(target: PsiSlice, p_out: usize) -> Self {
299 Self {
300 target,
301 reference: IsometryReference::Euclidean,
302 rho_index: 0,
303 jacobian_cache_slot: RwLock::new(None),
304 jacobian_second_cache_slot: RwLock::new(None),
305 duchon_radial_source: None,
306 third_decoder_derivative_slot: RwLock::new(None),
307 p_out,
308 weight: WeightField::Identity,
309 scalar_weight: 1.0,
310 weight_schedule: None,
311 }
312 }
313
314 #[must_use]
320 pub fn jacobian_cache(&self) -> Option<Arc<Array2<f64>>> {
321 self.jacobian_cache_slot
322 .read()
323 .expect("IsometryPenalty::jacobian_cache_slot poisoned")
324 .clone()
325 }
326
327 #[must_use]
330 pub fn jacobian_second_cache(&self) -> Option<Arc<Array2<f64>>> {
331 self.jacobian_second_cache_slot
332 .read()
333 .expect("IsometryPenalty::jacobian_second_cache_slot poisoned")
334 .clone()
335 }
336
337 pub fn refresh_caches(&self, jac: Option<Arc<Array2<f64>>>, jac2: Option<Arc<Array2<f64>>>) {
344 *self
345 .jacobian_cache_slot
346 .write()
347 .expect("IsometryPenalty::jacobian_cache_slot poisoned") = jac;
348 *self
349 .jacobian_second_cache_slot
350 .write()
351 .expect("IsometryPenalty::jacobian_second_cache_slot poisoned") = jac2;
352 }
353
354 pub fn set_jacobian_cache(&self, jac: Option<Arc<Array2<f64>>>) {
357 *self
358 .jacobian_cache_slot
359 .write()
360 .expect("IsometryPenalty::jacobian_cache_slot poisoned") = jac;
361 }
362
363 pub fn set_jacobian_second_cache(&self, jac2: Option<Arc<Array2<f64>>>) {
365 *self
366 .jacobian_second_cache_slot
367 .write()
368 .expect("IsometryPenalty::jacobian_second_cache_slot poisoned") = jac2;
369 }
370
371 #[must_use]
374 pub fn third_decoder_derivative(&self) -> Option<Arc<ndarray::Array3<f64>>> {
375 self.third_decoder_derivative_slot
376 .read()
377 .expect("IsometryPenalty::third_decoder_derivative_slot poisoned")
378 .clone()
379 }
380
381 pub fn set_third_decoder_derivative(&self, jac3: Option<Arc<ndarray::Array3<f64>>>) {
383 *self
384 .third_decoder_derivative_slot
385 .write()
386 .expect("IsometryPenalty::third_decoder_derivative_slot poisoned") = jac3;
387 }
388}
389
390impl Clone for IsometryPenalty {
391 fn clone(&self) -> Self {
392 Self {
393 target: self.target.clone(),
394 reference: self.reference.clone(),
395 rho_index: self.rho_index,
396 jacobian_cache_slot: RwLock::new(self.jacobian_cache()),
397 jacobian_second_cache_slot: RwLock::new(self.jacobian_second_cache()),
398 duchon_radial_source: self.duchon_radial_source.clone(),
399 third_decoder_derivative_slot: RwLock::new(self.third_decoder_derivative()),
400 p_out: self.p_out,
401 weight: self.weight.clone(),
402 scalar_weight: self.scalar_weight,
403 weight_schedule: self.weight_schedule.clone(),
404 }
405 }
406}
407
408impl IsometryPenalty {
409 #[must_use]
415 pub fn with_third_decoder_derivative(self, k: Arc<ndarray::Array3<f64>>) -> Self {
416 self.set_third_decoder_derivative(Some(k));
417 self
418 }
419
420 #[must_use]
421 pub fn with_reference(mut self, reference: IsometryReference) -> Self {
422 self.reference = reference;
423 self
424 }
425
426 #[must_use]
427 pub fn with_jacobian_cache(self, j: Arc<Array2<f64>>) -> Self {
428 self.set_jacobian_cache(Some(j));
429 self
430 }
431
432 #[must_use]
433 pub fn with_jacobian_second_cache(self, h: Arc<Array2<f64>>) -> Self {
434 self.set_jacobian_second_cache(Some(h));
435 self
436 }
437
438 #[must_use]
445 pub fn with_duchon_radial_source(mut self, source: Arc<IsometryDuchonRadialSource>) -> Self {
446 self.duchon_radial_source = Some(source);
447 self
448 }
449
450 #[must_use]
463 pub fn with_row_metric(mut self, metric: &gam_problem::RowMetric) -> Self {
464 if metric.drives_gauge() {
471 self.weight = metric.to_weight_field();
472 }
473 self.p_out = metric.p_out();
474 self
475 }
476
477 impl_with_weight_schedule!(scalar_weight);
478
479 fn missing_cache_default(&self, method: &str, detail: &str) {
480 log::warn!(
481 "IsometryPenalty::{method} missing required derivative state: {detail}; \
482 returning the zero safe default"
483 );
484 }
485
486 fn has_jacobian_cache(&self, method: &str) -> bool {
487 if self.jacobian_cache().is_some() {
488 true
489 } else {
490 self.missing_cache_default(method, "jacobian_cache is None");
491 false
492 }
493 }
494
495 fn has_jacobian_second_source(&self, method: &str) -> bool {
496 if self.jacobian_second_cache().is_some() || self.duchon_radial_source.is_some() {
497 true
498 } else {
499 self.missing_cache_default(
500 method,
501 "both jacobian_second_cache and duchon_radial_source are None",
502 );
503 false
504 }
505 }
506
507 fn has_jacobian_third_source(&self, method: &str) -> bool {
508 if self.third_decoder_derivative().is_some() || self.duchon_radial_source.is_some() {
509 true
510 } else {
511 self.missing_cache_default(
512 method,
513 "both third_decoder_derivative cache and duchon_radial_source are None",
514 );
515 false
516 }
517 }
518
519 fn projected_jacobian_row(&self, n: usize, d: usize) -> Option<Array2<f64>> {
526 let Some(jac) = self.jacobian_cache() else {
527 self.missing_cache_default("projected_jacobian_row", "jacobian_cache is None");
528 return None;
529 };
530 let jac_row = jac.row(n);
531 let jac_slice = jac_row
532 .as_slice()
533 .expect("jacobian cache must be in standard row-major layout");
534 match &self.weight {
535 WeightField::Identity => {
536 let p = self.p_out;
537 let mut m = Array2::<f64>::zeros((p, d));
538 for i in 0..p {
539 for a in 0..d {
540 m[[i, a]] = jac_slice[i * d + a];
541 }
542 }
543 Some(m)
544 }
545 WeightField::Factored { u, rank, p_out } => {
546 let u_row = u.row(n);
547 let u_slice = u_row
548 .as_slice()
549 .expect("weight factor U must be in standard row-major layout");
550 Some(WeightField::project_jac_row_with_u(
551 u_slice, jac_slice, *p_out, *rank, d,
552 ))
553 }
554 }
555 }
556
557 fn weighted_jacobian_row(&self, n: usize, d: usize) -> Option<Array2<f64>> {
559 let Some(jac) = self.jacobian_cache() else {
560 self.missing_cache_default("weighted_jacobian_row", "jacobian_cache is None");
561 return None;
562 };
563 let p = self.p_out;
564 match &self.weight {
565 WeightField::Identity => {
566 let mut out = Array2::<f64>::zeros((p, d));
567 for i in 0..p {
568 for a in 0..d {
569 out[[i, a]] = jac[[n, i * d + a]];
570 }
571 }
572 Some(out)
573 }
574 WeightField::Factored { u, rank, p_out } => {
575 assert_eq!(p, *p_out);
576 let r = *rank;
577 let m_n = self.projected_jacobian_row(n, d)?;
578 let mut out = Array2::<f64>::zeros((p, d));
579 for i in 0..p {
580 for a in 0..d {
581 let mut s = 0.0;
582 for k in 0..r {
583 s += u[[n, i * r + k]] * m_n[[k, a]];
584 }
585 out[[i, a]] = s;
586 }
587 }
588 Some(out)
589 }
590 }
591 }
592
593 fn weighted_dot_decoder_vectors<F, G>(&self, n: usize, p: usize, x: F, y: G) -> f64
594 where
595 F: Fn(usize) -> f64,
596 G: Fn(usize) -> f64,
597 {
598 match &self.weight {
599 WeightField::Identity => {
600 let mut s = 0.0;
601 for i in 0..p {
602 s += x(i) * y(i);
603 }
604 s
605 }
606 WeightField::Factored { u, rank, p_out } => {
607 assert_eq!(p, *p_out);
608 let r = *rank;
609 let mut s = 0.0;
610 for k in 0..r {
611 let mut ux = 0.0;
612 let mut uy = 0.0;
613 for i in 0..p {
614 let uik = u[[n, i * r + k]];
615 ux += uik * x(i);
616 uy += uik * y(i);
617 }
618 s += ux * uy;
619 }
620 s
621 }
622 }
623 }
624
625 fn target_matrix(target: ArrayView1<'_, f64>, n_obs: usize, d: usize) -> Array2<f64> {
626 let mut out = Array2::<f64>::zeros((n_obs, d));
627 for n in 0..n_obs {
628 for a in 0..d {
629 out[[n, a]] = target[n * d + a];
630 }
631 }
632 out
633 }
634
635 fn duchon_radial_jacobian_second(
643 &self,
644 target: ArrayView1<'_, f64>,
645 n_obs: usize,
646 d: usize,
647 source: &IsometryDuchonRadialSource,
648 ) -> Result<Array2<f64>, BasisError> {
649 assert_eq!(source.centers.ncols(), d);
650 assert_eq!(source.radial_coefficients.nrows(), source.centers.nrows());
651 assert_eq!(source.radial_coefficients.ncols(), self.p_out);
652 let t = Self::target_matrix(target, n_obs, d);
653 radial_basis_cartesian_derivative(
654 2,
655 t.view(),
656 source.centers.view(),
657 source.radial_coefficients.view(),
658 source.length_scale,
659 source.nullspace_order,
660 source.power,
661 )
662 }
663
664 fn duchon_radial_jacobian_third(
672 &self,
673 target: ArrayView1<'_, f64>,
674 n_obs: usize,
675 d: usize,
676 source: &IsometryDuchonRadialSource,
677 ) -> Result<ndarray::Array3<f64>, BasisError> {
678 assert_eq!(source.centers.ncols(), d);
679 assert_eq!(source.radial_coefficients.nrows(), source.centers.nrows());
680 assert_eq!(source.radial_coefficients.ncols(), self.p_out);
681 let t = Self::target_matrix(target, n_obs, d);
682 let flat = radial_basis_cartesian_derivative(
683 3,
684 t.view(),
685 source.centers.view(),
686 source.radial_coefficients.view(),
687 source.length_scale,
688 source.nullspace_order,
689 source.power,
690 )?;
691 Ok(flat
692 .into_shape_with_order((n_obs, self.p_out, d * d * d))
693 .expect("radial_basis_cartesian_derivative order-3 output reshapes to (n_obs, p, d³)"))
694 }
695
696 fn jacobian_second<'a>(
697 &'a self,
698 target: ArrayView1<'_, f64>,
699 n_obs: usize,
700 d: usize,
701 ) -> Option<CowArray<'a, f64, Ix2>> {
702 if let Some(jac2) = self.jacobian_second_cache() {
703 return Some(CowArray::from((*jac2).clone()));
710 }
711 let source = self.duchon_radial_source.as_ref()?;
712 match self.duchon_radial_jacobian_second(target, n_obs, d, source) {
713 Ok(jac2) => Some(CowArray::from(jac2)),
714 Err(err) => {
715 self.missing_cache_default(
716 "jacobian_second",
717 &format!("failed to materialize Duchon radial second derivative: {err}"),
718 );
719 None
720 }
721 }
722 }
723
724 fn jacobian_third<'a>(
725 &'a self,
726 target: ArrayView1<'_, f64>,
727 n_obs: usize,
728 d: usize,
729 ) -> Option<CowArray<'a, f64, Ix3>> {
730 if let Some(jac3) = self.third_decoder_derivative() {
731 return Some(CowArray::from(jac3.as_ref().clone()));
732 }
733 let source = self.duchon_radial_source.as_ref()?;
734 match self.duchon_radial_jacobian_third(target, n_obs, d, source) {
735 Ok(jac3) => Some(CowArray::from(jac3)),
736 Err(err) => {
737 self.missing_cache_default(
738 "jacobian_third",
739 &format!("failed to materialize Duchon radial third derivative: {err}"),
740 );
741 None
742 }
743 }
744 }
745
746 pub(crate) fn hvp_state<'a>(
747 &'a self,
748 target: ArrayView1<'_, f64>,
749 ) -> Option<IsometryHvpState<'a>> {
750 let d = self
751 .target
752 .latent_dim
753 .expect("IsometryPenalty requires latent_dim on its PsiSlice");
754 let n_obs = target.len() / d;
755 if !self.has_jacobian_cache("hvp")
756 || !self.has_jacobian_second_source("hvp")
757 || !self.has_jacobian_third_source("hvp")
758 {
759 return None;
760 }
761 let p = self.p_out;
762 let jac2 = self.jacobian_second(target.view(), n_obs, d)?;
763 let jac3 = self.jacobian_third(target.view(), n_obs, d)?;
764 let g = self.pullback_metric(d)?;
765 let metric = self.normalized_metric_state(g, n_obs, d)?;
766 let mut wj_rows = Vec::with_capacity(n_obs);
767 for n in 0..n_obs {
768 wj_rows.push(self.weighted_jacobian_row(n, d)?);
769 }
770 Some(IsometryHvpState {
771 d,
772 n_obs,
773 p,
774 jac2,
775 jac3,
776 metric,
777 wj_rows,
778 })
779 }
780
781 pub(crate) fn hvp_with_precomputed_state(
782 &self,
783 state: &IsometryHvpState<'_>,
784 rho: ArrayView1<'_, f64>,
785 v: ArrayView1<'_, f64>,
786 ) -> Array1<f64> {
787 let mu = resolve_learnable_weight(self.scalar_weight, rho[self.rho_index]);
788 let d = state.d;
789 let n_obs = state.n_obs;
790 let p = state.p;
791 let jac2 = &state.jac2;
792 let jac3 = &state.jac3;
793 let metric = &state.metric;
794 let mut out = Array1::<f64>::zeros(v.len());
795 let mut delta_g = Array2::<f64>::zeros((n_obs, d * d));
796 for n in 0..n_obs {
797 let wj = &state.wj_rows[n];
798 let row_delta = isometry_row_delta_g(jac2.view(), wj.view(), v, n, d, p);
799 for a in 0..d {
800 for b in 0..d {
801 delta_g[[n, a * d + b]] = row_delta[[a, b]];
802 }
803 }
804 }
805 let delta_metric_grad = metric.metric_grad_direction(delta_g.view(), d);
806
807 for n in 0..n_obs {
808 let wj = &state.wj_rows[n];
809 for c in 0..d {
810 let mut acc = 0.0;
811 for a in 0..d {
812 for b in 0..d {
813 let dg = isometry_dg_entry(jac2.view(), wj.view(), n, d, p, a, b, c);
814 acc += dg * delta_metric_grad[[n, a * d + b]];
815 }
816 }
817 out[n * d + c] = mu * acc;
818 }
819
820 for c in 0..d {
821 let mut acc_res = 0.0;
822 for a in 0..d {
823 for b in 0..d {
824 let metric_grad = metric.metric_grad[[n, a * d + b]];
825 if metric_grad == 0.0 {
826 continue;
827 }
828 let mut bv = 0.0;
829 for dd in 0..d {
830 let vd = v[n * d + dd];
831 if vd == 0.0 {
832 continue;
833 }
834 let mut k_a_cd_w_j_b = 0.0;
835 for i in 0..p {
836 k_a_cd_w_j_b += jac3[[n, i, ((a * d) + c) * d + dd]] * wj[[i, b]];
837 }
838 let h_a_c_w_h_b_d = self.weighted_dot_decoder_vectors(
839 n,
840 p,
841 |i| jac2[[n, (i * d + a) * d + c]],
842 |i| jac2[[n, (i * d + b) * d + dd]],
843 );
844 let h_a_d_w_h_b_c = self.weighted_dot_decoder_vectors(
845 n,
846 p,
847 |i| jac2[[n, (i * d + a) * d + dd]],
848 |i| jac2[[n, (i * d + b) * d + c]],
849 );
850 let mut j_a_w_k_b_cd = 0.0;
851 for i in 0..p {
852 j_a_w_k_b_cd += wj[[i, a]] * jac3[[n, i, ((b * d) + c) * d + dd]];
853 }
854 bv +=
855 (k_a_cd_w_j_b + h_a_c_w_h_b_d + h_a_d_w_h_b_c + j_a_w_k_b_cd) * vd;
856 }
857 acc_res += metric_grad * bv;
858 }
859 }
860 out[n * d + c] += mu * acc_res;
861 }
862 }
863 out
864 }
865
866 pub fn pullback_metric(&self, latent_dim: usize) -> Option<Array2<f64>> {
874 let Some(jac) = self.jacobian_cache() else {
875 self.missing_cache_default("pullback_metric", "jacobian_cache is None");
876 return None;
877 };
878 let n_obs = jac.nrows();
879 let p = self.p_out;
880 assert_eq!(jac.ncols(), p * latent_dim);
881 let mut g_all = Array2::<f64>::zeros((n_obs, latent_dim * latent_dim));
882 for n in 0..n_obs {
883 let m = self.projected_jacobian_row(n, latent_dim)?;
885 let r = m.nrows();
886 for a in 0..latent_dim {
888 for b in 0..latent_dim {
889 let mut s = 0.0;
890 for k in 0..r {
891 s += m[[k, a]] * m[[k, b]];
892 }
893 g_all[[n, a * latent_dim + b]] = s;
894 }
895 }
896 }
897 Some(g_all)
898 }
899
900 fn reference_metric(&self, n_obs: usize, d: usize) -> CowArray<'_, f64, Ix2> {
902 match &self.reference {
903 IsometryReference::Euclidean => {
904 let mut out = Array2::<f64>::zeros((n_obs, d * d));
905 for n in 0..n_obs {
906 for a in 0..d {
907 out[[n, a * d + a]] = 1.0;
908 }
909 }
910 CowArray::from(out)
911 }
912 IsometryReference::UserSupplied(a) => {
913 assert_eq!(a.nrows(), n_obs);
914 assert_eq!(a.ncols(), d * d);
915 CowArray::from(a.view())
916 }
917 }
918 }
919
920 fn normalized_metric_state(
931 &self,
932 g: Array2<f64>,
933 n_obs: usize,
934 d: usize,
935 ) -> Option<IsometryMetricState> {
936 let dd = d * d;
937 let trace_denominator = (n_obs * d) as f64;
938 let mut trace_sum = 0.0;
939 for n in 0..n_obs {
940 for a in 0..d {
941 trace_sum += g[[n, a * d + a]];
942 }
943 }
944 let normalizer = trace_sum / trace_denominator;
945 if !(normalizer.is_finite() && normalizer > f64::MIN_POSITIVE) {
946 self.missing_cache_default(
947 "normalized_metric_state",
948 &format!(
949 "unit-average-speed normalizer is non-positive or non-finite: {normalizer}"
950 ),
951 );
952 return None;
953 }
954 let g_ref = self.reference_metric(n_obs, d);
955 let mut residual = Array2::<f64>::zeros((n_obs, dd));
956 let inv_norm = 1.0 / normalizer;
957 for n in 0..n_obs {
958 for k in 0..dd {
959 residual[[n, k]] = g[[n, k]] * inv_norm - g_ref[[n, k]];
960 }
961 }
962 let mut residual_dot_g = 0.0;
963 for n in 0..n_obs {
964 for k in 0..dd {
965 residual_dot_g += residual[[n, k]] * g[[n, k]];
966 }
967 }
968 let trace_coeff = residual_dot_g / (normalizer * normalizer * trace_denominator);
969 let mut metric_grad = Array2::<f64>::zeros((n_obs, dd));
970 for n in 0..n_obs {
971 for a in 0..d {
972 for b in 0..d {
973 let k = a * d + b;
974 let mut value = residual[[n, k]] * inv_norm;
975 if a == b {
976 value -= trace_coeff;
977 }
978 metric_grad[[n, k]] = value;
979 }
980 }
981 }
982 Some(IsometryMetricState {
983 g,
984 residual,
985 metric_grad,
986 normalizer,
987 trace_denominator,
988 residual_dot_g,
989 })
990 }
991
992 pub fn grad_jacobian(
1012 &self,
1013 target: ArrayView1<'_, f64>,
1014 rho: ArrayView1<'_, f64>,
1015 ) -> Array2<f64> {
1016 let d = self
1017 .target
1018 .latent_dim
1019 .expect("IsometryPenalty requires latent_dim on its PsiSlice");
1020 let n_obs = target.len() / d;
1021 let p = self.p_out;
1022 let mut grad = Array2::<f64>::zeros((n_obs, p * d));
1023 if !self.has_jacobian_cache("grad_jacobian") {
1024 return grad;
1025 }
1026 let Some(g) = self.pullback_metric(d) else {
1027 return grad;
1028 };
1029 let Some(metric) = self.normalized_metric_state(g, n_obs, d) else {
1030 return grad;
1031 };
1032 let mu = resolve_learnable_weight(self.scalar_weight, rho[self.rho_index]);
1033 for n in 0..n_obs {
1034 let Some(wj) = self.weighted_jacobian_row(n, d) else {
1035 return Array2::<f64>::zeros((n_obs, p * d));
1036 };
1037 for i in 0..p {
1038 for c in 0..d {
1039 let mut acc = 0.0;
1040 for b in 0..d {
1041 acc += metric.metric_grad[[n, c * d + b]] * wj[[i, b]];
1042 }
1043 grad[[n, i * d + c]] = 2.0 * mu * acc;
1044 }
1045 }
1046 }
1047 grad
1048 }
1049}
1050
1051impl AnalyticPenalty for IsometryPenalty {
1052 fn tier(&self) -> PenaltyTier {
1053 PenaltyTier::Psi
1054 }
1055
1056 fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
1057 let d = self
1058 .target
1059 .latent_dim
1060 .expect("IsometryPenalty requires latent_dim on its PsiSlice");
1061 let n_obs = target.len() / d;
1062 if !self.has_jacobian_cache("value") {
1063 return Self::DEFAULT_VALUE_ON_MISSING_CACHE;
1064 }
1065 let Some(g) = self.pullback_metric(d) else {
1066 return Self::DEFAULT_VALUE_ON_MISSING_CACHE;
1067 };
1068 let Some(metric) = self.normalized_metric_state(g, n_obs, d) else {
1069 return Self::DEFAULT_VALUE_ON_MISSING_CACHE;
1070 };
1071 let mu = resolve_learnable_weight(self.scalar_weight, rho[self.rho_index]);
1072 let mut acc = 0.0;
1073 for n in 0..n_obs {
1074 for k in 0..(d * d) {
1075 let diff = metric.residual[[n, k]];
1076 acc += diff * diff;
1077 }
1078 }
1079 0.5 * mu * acc
1080 }
1081
1082 fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
1083 let d = self
1098 .target
1099 .latent_dim
1100 .expect("IsometryPenalty requires latent_dim on its PsiSlice");
1101 let n_obs = target.len() / d;
1102 if !self.has_jacobian_cache("grad_target")
1103 || !self.has_jacobian_second_source("grad_target")
1104 {
1105 return Array1::<f64>::zeros(target.len());
1106 }
1107 let Some(g) = self.pullback_metric(d) else {
1108 return Array1::<f64>::zeros(target.len());
1109 };
1110 let Some(metric) = self.normalized_metric_state(g, n_obs, d) else {
1111 return Array1::<f64>::zeros(target.len());
1112 };
1113 let p = self.p_out;
1114 let mu = resolve_learnable_weight(self.scalar_weight, rho[self.rho_index]);
1115 let mut grad = Array1::<f64>::zeros(target.len());
1116 let Some(jac2) = self.jacobian_second(target, n_obs, d) else {
1117 return grad;
1118 };
1119 assert_eq!(jac2.ncols(), p * d * d);
1120
1121 for n in 0..n_obs {
1122 let Some(wj) = self.weighted_jacobian_row(n, d) else {
1123 return grad;
1124 };
1125 for c in 0..d {
1126 let mut acc = 0.0;
1127 for a in 0..d {
1128 for b in 0..d {
1129 let mut dg = 0.0;
1130 for i in 0..p {
1131 dg += jac2[[n, (i * d + a) * d + c]] * wj[[i, b]];
1132 dg += wj[[i, a]] * jac2[[n, (i * d + b) * d + c]];
1133 }
1134 acc += metric.metric_grad[[n, a * d + b]] * dg;
1135 }
1136 }
1137 grad[n * d + c] = mu * acc;
1138 }
1139 }
1140 grad
1141 }
1142
1143 fn hvp(
1145 &self,
1146 target: ArrayView1<'_, f64>,
1147 rho: ArrayView1<'_, f64>,
1148 v: ArrayView1<'_, f64>,
1149 ) -> Array1<f64> {
1150 let Some(state) = self.hvp_state(target) else {
1163 return Array1::<f64>::zeros(v.len());
1164 };
1165 self.hvp_with_precomputed_state(&state, rho, v)
1166 }
1167
1168 fn psd_majorizer_hvp(
1177 &self,
1178 target: ArrayView1<'_, f64>,
1179 rho: ArrayView1<'_, f64>,
1180 v: ArrayView1<'_, f64>,
1181 ) -> Array1<f64> {
1182 let d = self
1183 .target
1184 .latent_dim
1185 .expect("IsometryPenalty requires latent_dim on its PsiSlice");
1186 let n_obs = target.len() / d;
1187 if !self.has_jacobian_cache("psd_majorizer_hvp")
1188 || !self.has_jacobian_second_source("psd_majorizer_hvp")
1189 {
1190 return Array1::<f64>::zeros(v.len());
1191 }
1192 let Some(jac2) = self.jacobian_second(target, n_obs, d) else {
1193 return Array1::<f64>::zeros(v.len());
1194 };
1195 let Some(g) = self.pullback_metric(d) else {
1196 return Array1::<f64>::zeros(v.len());
1197 };
1198 let Some(metric) = self.normalized_metric_state(g, n_obs, d) else {
1199 return Array1::<f64>::zeros(v.len());
1200 };
1201 let p = self.p_out;
1202 let mu = resolve_learnable_weight(self.scalar_weight, rho[self.rho_index]);
1203 let mut out = Array1::<f64>::zeros(v.len());
1204 let mut wj_rows = Vec::with_capacity(n_obs);
1205 for n in 0..n_obs {
1206 let Some(wj) = self.weighted_jacobian_row(n, d) else {
1207 return Array1::<f64>::zeros(v.len());
1208 };
1209 wj_rows.push(wj);
1210 }
1211 let mut delta_g = Array2::<f64>::zeros((n_obs, d * d));
1212 for n in 0..n_obs {
1213 let row_delta = isometry_row_delta_g(jac2.view(), wj_rows[n].view(), v, n, d, p);
1214 for a in 0..d {
1215 for b in 0..d {
1216 delta_g[[n, a * d + b]] = row_delta[[a, b]];
1217 }
1218 }
1219 }
1220 let (delta_residual, _delta_normalizer) = metric.residual_direction(delta_g.view(), d);
1221 let mut g_dot_delta_residual = 0.0;
1222 for n in 0..n_obs {
1223 for k in 0..(d * d) {
1224 g_dot_delta_residual += metric.g[[n, k]] * delta_residual[[n, k]];
1225 }
1226 }
1227 let inv_norm = 1.0 / metric.normalizer;
1228 let inv_norm_sq = inv_norm * inv_norm;
1229 for n in 0..n_obs {
1230 let wj = &wj_rows[n];
1231 for c in 0..d {
1232 let mut trace_dg = 0.0;
1233 for a in 0..d {
1234 trace_dg += isometry_dg_entry(jac2.view(), wj.view(), n, d, p, a, a, c);
1235 }
1236 let delta_normalizer_c = trace_dg / metric.trace_denominator;
1237 let mut acc = -delta_normalizer_c * inv_norm_sq * g_dot_delta_residual;
1238 for a in 0..d {
1239 for b in 0..d {
1240 let dg = isometry_dg_entry(jac2.view(), wj.view(), n, d, p, a, b, c);
1241 acc += dg * inv_norm * delta_residual[[n, a * d + b]];
1242 }
1243 }
1244 out[n * d + c] = mu * acc;
1245 }
1246 }
1247 out
1248 }
1249
1250 fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
1251 let mut out = Array1::<f64>::zeros(self.rho_count());
1254 out[self.rho_index] = self.value(target, rho);
1255 out
1256 }
1257
1258 fn rho_count(&self) -> usize {
1259 1
1260 }
1261
1262 fn name(&self) -> &str {
1263 "isometry"
1264 }
1265
1266 impl_scalar_apply_schedule!(scalar_weight);
1267}