1use super::*;
2pub use gam_problem::WeightField;
3
/// Reference metric that the normalized pullback metric of each row is compared
/// against.
#[derive(Clone)]
pub enum IsometryReference {
    /// Identity reference: every row's target metric is the `d × d` identity.
    Euclidean,
    /// Caller-provided per-row reference metrics, one row-major flattened `d × d`
    /// matrix per observation (shape `(n_obs, d * d)`, asserted when used).
    UserSupplied(Arc<Array2<f64>>), }
20
21impl std::fmt::Debug for IsometryReference {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 match self {
24 IsometryReference::Euclidean => f.write_str("Euclidean"),
25 IsometryReference::UserSupplied(a) => f
26 .debug_tuple("UserSupplied")
27 .field(&format_args!("{}×{}", a.nrows(), a.ncols()))
28 .finish(),
29 }
30 }
31}
32
/// Analytic description of a Duchon radial-basis decoder. When the corresponding
/// caches are empty, it is used to materialize the decoder's second and third
/// Cartesian derivatives on demand (via `radial_basis_cartesian_derivative`).
#[derive(Debug, Clone)]
pub struct IsometryDuchonRadialSource {
    /// Basis centers, one per row; `ncols()` must equal the latent dimension `d`.
    pub centers: Arc<Array2<f64>>,
    /// Radial coefficients: one row per center and one column per decoder output
    /// (`ncols()` must equal the owning penalty's `p_out`).
    pub radial_coefficients: Arc<Array2<f64>>,
    /// Optional length scale, forwarded to `radial_basis_cartesian_derivative`.
    pub length_scale: Option<f64>,
    /// Null-space order, forwarded to `radial_basis_cartesian_derivative`.
    pub nullspace_order: DuchonNullspaceOrder,
    /// `power` argument forwarded to `radial_basis_cartesian_derivative`.
    pub power: usize,
}
53
/// Penalty that pulls the decoder's pullback metric, normalized to unit average
/// speed, toward a reference metric:
/// `value = 0.5 * mu * sum_n ||G_n / s - G_ref_n||_F^2`, where `s` is the mean
/// diagonal entry of `G` over all rows and `mu` is the learnable weight.
#[derive(Debug)]
pub struct IsometryPenalty {
    /// Slice of the parameter vector this penalty acts on; must carry `latent_dim`.
    pub target: PsiSlice,
    /// Reference metric compared against the normalized pullback metric.
    pub reference: IsometryReference,
    /// Index into `rho` of this penalty's learnable weight.
    pub rho_index: usize,
    /// Cached decoder Jacobian: one row per observation with `p_out * d` entries,
    /// entry `i * d + a` being d(output i)/d(latent a).
    pub jacobian_cache_slot: RwLock<Option<Arc<Array2<f64>>>>,
    /// Cached decoder second derivative: one row per observation with
    /// `p_out * d * d` entries indexed `(i * d + a) * d + c`.
    pub jacobian_second_cache_slot: RwLock<Option<Arc<Array2<f64>>>>,
    /// Optional analytic Duchon decoder used when the derivative caches are empty.
    pub duchon_radial_source: Option<Arc<IsometryDuchonRadialSource>>,
    /// Cached third decoder derivative, indexed `[n, i, (a * d + c) * d + e]`
    /// (shape `(n_obs, p_out, d * d * d)`).
    pub third_decoder_derivative_slot: RwLock<Option<Arc<ndarray::Array3<f64>>>>,
    /// Number of decoder outputs per observation.
    pub p_out: usize,
    /// Output-space weighting applied when forming the pullback metric.
    pub weight: WeightField,
    /// Base weight combined with `rho[rho_index]` through `resolve_learnable_weight`.
    pub scalar_weight: f64,
    /// Optional schedule for `scalar_weight` (used by the weight-schedule macros).
    pub weight_schedule: Option<ScalarWeightSchedule>,
}
177
/// Quantities that `IsometryPenalty::hvp_state` precomputes once per target point,
/// so repeated Hessian-vector products at that point can reuse them.
pub(crate) struct IsometryHvpState<'a> {
    /// Latent dimension.
    d: usize,
    /// Number of observations (rows).
    n_obs: usize,
    /// Decoder output dimension (`p_out`).
    p: usize,
    /// Second decoder derivative, `(n_obs, p * d * d)`.
    jac2: CowArray<'a, f64, Ix2>,
    /// Third decoder derivative, `(n_obs, p, d * d * d)`.
    jac3: CowArray<'a, f64, Ix3>,
    /// Normalized metric, residual and metric gradient at the target.
    metric: IsometryMetricState,
    /// Weighted Jacobian row (`p × d`) for every observation.
    wj_rows: Vec<Array2<f64>>,
}
187
/// Pullback metric together with its unit-average-speed normalization and the
/// derived residual/gradient quantities shared by value, gradient and HVP code.
#[derive(Debug, Clone)]
struct IsometryMetricState {
    /// Pullback metric, one row-major flattened `d × d` matrix per row.
    g: Array2<f64>,
    /// `g / normalizer - g_ref`, same layout as `g`.
    residual: Array2<f64>,
    /// Gradient of `0.5 * sum ||residual||^2` with respect to `g` (without `mu`).
    metric_grad: Array2<f64>,
    /// Mean diagonal entry of `g`: summed diagonals / `trace_denominator`.
    normalizer: f64,
    /// `n_obs * d` as `f64`.
    trace_denominator: f64,
    /// `sum(residual * g)` over all entries.
    residual_dot_g: f64,
}
197
impl IsometryMetricState {
    /// Directional derivative of the residual and of the normalizer along a
    /// perturbation `delta_g` of the pullback metric (same `(n_obs, d * d)` layout
    /// as `g`).
    ///
    /// With `s = normalizer` and `T = trace_denominator`:
    /// `delta_s = sum_n tr(delta_g_n) / T` and
    /// `delta_residual = delta_g / s - g * delta_s / s^2`.
    /// The reference metric is constant in `g`, so it drops out.
    ///
    /// Returns `(delta_residual, delta_normalizer)`.
    fn residual_direction(&self, delta_g: ArrayView2<'_, f64>, d: usize) -> (Array2<f64>, f64) {
        let n_obs = self.g.nrows();
        let dd = d * d;
        // Sum of the diagonal entries of every row of delta_g.
        let mut delta_trace_sum = 0.0;
        for n in 0..n_obs {
            for a in 0..d {
                delta_trace_sum += delta_g[[n, a * d + a]];
            }
        }
        let delta_normalizer = delta_trace_sum / self.trace_denominator;
        let inv_norm = 1.0 / self.normalizer;
        let inv_norm_sq = inv_norm * inv_norm;
        let mut delta_residual = Array2::<f64>::zeros((n_obs, dd));
        for n in 0..n_obs {
            for k in 0..dd {
                // Quotient rule on g / s.
                delta_residual[[n, k]] =
                    delta_g[[n, k]] * inv_norm - self.g[[n, k]] * delta_normalizer * inv_norm_sq;
            }
        }
        (delta_residual, delta_normalizer)
    }

    /// Directional derivative of `metric_grad` along `delta_g`.
    ///
    /// `metric_grad = residual / s - I * residual_dot_g / (s^2 T)`. Differentiating
    /// gives `delta_residual / s - residual * delta_s / s^2` everywhere, minus a
    /// diagonal correction `delta_trace_coeff` equal to the derivative of
    /// `residual_dot_g / (s^2 T)`.
    fn metric_grad_direction(&self, delta_g: ArrayView2<'_, f64>, d: usize) -> Array2<f64> {
        let n_obs = self.g.nrows();
        let dd = d * d;
        let (delta_residual, delta_normalizer) = self.residual_direction(delta_g, d);
        // Product rule: d(residual . g) = delta_residual . g + residual . delta_g.
        let mut delta_residual_dot_g = 0.0;
        for n in 0..n_obs {
            for k in 0..dd {
                delta_residual_dot_g += delta_residual[[n, k]] * self.g[[n, k]];
                delta_residual_dot_g += self.residual[[n, k]] * delta_g[[n, k]];
            }
        }
        let inv_norm = 1.0 / self.normalizer;
        let inv_norm_sq = inv_norm * inv_norm;
        // d/d(eps) of residual_dot_g / (s^2 T); the s^-2 factor contributes -2 delta_s / s^3.
        let delta_trace_coeff = delta_residual_dot_g * inv_norm_sq / self.trace_denominator
            - 2.0 * self.residual_dot_g * delta_normalizer * inv_norm_sq * inv_norm
                / self.trace_denominator;
        let mut out = Array2::<f64>::zeros((n_obs, dd));
        for n in 0..n_obs {
            for a in 0..d {
                for b in 0..d {
                    let k = a * d + b;
                    let mut value = delta_residual[[n, k]] * inv_norm
                        - self.residual[[n, k]] * delta_normalizer * inv_norm_sq;
                    if a == b {
                        value -= delta_trace_coeff;
                    }
                    out[[n, k]] = value;
                }
            }
        }
        out
    }
}
254
255fn isometry_dg_entry(
256 jac2: ArrayView2<'_, f64>,
257 wj: ArrayView2<'_, f64>,
258 n: usize,
259 d: usize,
260 p: usize,
261 a: usize,
262 b: usize,
263 c: usize,
264) -> f64 {
265 let mut s = 0.0;
266 for i in 0..p {
267 s += jac2[[n, (i * d + a) * d + c]] * wj[[i, b]];
268 s += wj[[i, a]] * jac2[[n, (i * d + b) * d + c]];
269 }
270 s
271}
272
273fn isometry_row_delta_g(
274 jac2: ArrayView2<'_, f64>,
275 wj: ArrayView2<'_, f64>,
276 v: ArrayView1<'_, f64>,
277 n: usize,
278 d: usize,
279 p: usize,
280) -> Array2<f64> {
281 let mut delta_g = Array2::<f64>::zeros((d, d));
282 for a in 0..d {
283 for b in 0..d {
284 let mut s = 0.0;
285 for c in 0..d {
286 s += isometry_dg_entry(jac2, wj, n, d, p, a, b, c) * v[n * d + c];
287 }
288 delta_g[[a, b]] = s;
289 }
290 }
291 delta_g
292}
293
294impl IsometryPenalty {
295 pub const DEFAULT_VALUE_ON_MISSING_CACHE: f64 = 0.0;
296
297 #[must_use]
298 pub fn new_euclidean(target: PsiSlice, p_out: usize) -> Self {
299 Self {
300 target,
301 reference: IsometryReference::Euclidean,
302 rho_index: 0,
303 jacobian_cache_slot: RwLock::new(None),
304 jacobian_second_cache_slot: RwLock::new(None),
305 duchon_radial_source: None,
306 third_decoder_derivative_slot: RwLock::new(None),
307 p_out,
308 weight: WeightField::Identity,
309 scalar_weight: 1.0,
310 weight_schedule: None,
311 }
312 }
313
314 #[must_use]
320 pub fn jacobian_cache(&self) -> Option<Arc<Array2<f64>>> {
321 self.jacobian_cache_slot
322 .read()
323 .expect("IsometryPenalty::jacobian_cache_slot poisoned")
324 .clone()
325 }
326
327 #[must_use]
330 pub fn jacobian_second_cache(&self) -> Option<Arc<Array2<f64>>> {
331 self.jacobian_second_cache_slot
332 .read()
333 .expect("IsometryPenalty::jacobian_second_cache_slot poisoned")
334 .clone()
335 }
336
337 pub fn refresh_caches(&self, jac: Option<Arc<Array2<f64>>>, jac2: Option<Arc<Array2<f64>>>) {
344 *self
345 .jacobian_cache_slot
346 .write()
347 .expect("IsometryPenalty::jacobian_cache_slot poisoned") = jac;
348 *self
349 .jacobian_second_cache_slot
350 .write()
351 .expect("IsometryPenalty::jacobian_second_cache_slot poisoned") = jac2;
352 }
353
354 pub fn set_jacobian_cache(&self, jac: Option<Arc<Array2<f64>>>) {
357 *self
358 .jacobian_cache_slot
359 .write()
360 .expect("IsometryPenalty::jacobian_cache_slot poisoned") = jac;
361 }
362
363 pub fn set_jacobian_second_cache(&self, jac2: Option<Arc<Array2<f64>>>) {
365 *self
366 .jacobian_second_cache_slot
367 .write()
368 .expect("IsometryPenalty::jacobian_second_cache_slot poisoned") = jac2;
369 }
370
371 #[must_use]
374 pub fn third_decoder_derivative(&self) -> Option<Arc<ndarray::Array3<f64>>> {
375 self.third_decoder_derivative_slot
376 .read()
377 .expect("IsometryPenalty::third_decoder_derivative_slot poisoned")
378 .clone()
379 }
380
381 pub fn set_third_decoder_derivative(&self, jac3: Option<Arc<ndarray::Array3<f64>>>) {
383 *self
384 .third_decoder_derivative_slot
385 .write()
386 .expect("IsometryPenalty::third_decoder_derivative_slot poisoned") = jac3;
387 }
388}
389
390impl Clone for IsometryPenalty {
391 fn clone(&self) -> Self {
392 Self {
393 target: self.target.clone(),
394 reference: self.reference.clone(),
395 rho_index: self.rho_index,
396 jacobian_cache_slot: RwLock::new(self.jacobian_cache()),
397 jacobian_second_cache_slot: RwLock::new(self.jacobian_second_cache()),
398 duchon_radial_source: self.duchon_radial_source.clone(),
399 third_decoder_derivative_slot: RwLock::new(self.third_decoder_derivative()),
400 p_out: self.p_out,
401 weight: self.weight.clone(),
402 scalar_weight: self.scalar_weight,
403 weight_schedule: self.weight_schedule.clone(),
404 }
405 }
406}
407
408impl IsometryPenalty {
409 #[must_use]
415 pub fn with_third_decoder_derivative(self, k: Arc<ndarray::Array3<f64>>) -> Self {
416 self.set_third_decoder_derivative(Some(k));
417 self
418 }
419
420 #[must_use]
421 pub fn with_reference(mut self, reference: IsometryReference) -> Self {
422 self.reference = reference;
423 self
424 }
425
426 #[must_use]
427 pub fn with_jacobian_cache(self, j: Arc<Array2<f64>>) -> Self {
428 self.set_jacobian_cache(Some(j));
429 self
430 }
431
432 #[must_use]
433 pub fn with_jacobian_second_cache(self, h: Arc<Array2<f64>>) -> Self {
434 self.set_jacobian_second_cache(Some(h));
435 self
436 }
437
438 #[must_use]
445 pub fn with_duchon_radial_source(mut self, source: Arc<IsometryDuchonRadialSource>) -> Self {
446 self.duchon_radial_source = Some(source);
447 self
448 }
449
450 #[must_use]
463 pub fn with_row_metric(mut self, metric: &gam_problem::RowMetric) -> Self {
464 if metric.drives_gauge() {
471 self.weight = metric.to_weight_field();
472 }
473 self.p_out = metric.p_out();
474 self
475 }
476
477 impl_with_weight_schedule!(scalar_weight);
478
479 fn missing_cache_default(&self, method: &str, detail: &str) {
480 log::warn!(
481 "IsometryPenalty::{method} missing required derivative state: {detail}; \
482 returning the zero safe default"
483 );
484 }
485
486 fn has_jacobian_cache(&self, method: &str) -> bool {
487 if self.jacobian_cache().is_some() {
488 true
489 } else {
490 self.missing_cache_default(method, "jacobian_cache is None");
491 false
492 }
493 }
494
495 fn has_jacobian_second_source(&self, method: &str) -> bool {
496 if self.jacobian_second_cache().is_some() || self.duchon_radial_source.is_some() {
497 true
498 } else {
499 self.missing_cache_default(
500 method,
501 "both jacobian_second_cache and duchon_radial_source are None",
502 );
503 false
504 }
505 }
506
507 fn has_jacobian_third_source(&self, method: &str) -> bool {
508 if self.third_decoder_derivative().is_some() || self.duchon_radial_source.is_some() {
509 true
510 } else {
511 self.missing_cache_default(
512 method,
513 "both third_decoder_derivative cache and duchon_radial_source are None",
514 );
515 false
516 }
517 }
518
519 fn projected_jacobian_row(&self, n: usize, d: usize) -> Option<Array2<f64>> {
526 let Some(jac) = self.jacobian_cache() else {
527 self.missing_cache_default("projected_jacobian_row", "jacobian_cache is None");
528 return None;
529 };
530 let jac_row = jac.row(n);
531 let jac_slice = jac_row
532 .as_slice()
533 .expect("jacobian cache must be in standard row-major layout");
534 match &self.weight {
535 WeightField::Identity => {
536 let p = self.p_out;
537 let mut m = Array2::<f64>::zeros((p, d));
538 for i in 0..p {
539 for a in 0..d {
540 m[[i, a]] = jac_slice[i * d + a];
541 }
542 }
543 Some(m)
544 }
545 WeightField::Factored { u, rank, p_out } => {
546 let u_row = u.row(n);
547 let u_slice = u_row
548 .as_slice()
549 .expect("weight factor U must be in standard row-major layout");
550 Some(WeightField::project_jac_row_with_u(
551 u_slice, jac_slice, *p_out, *rank, d,
552 ))
553 }
554 }
555 }
556
557 fn weighted_jacobian_row(&self, n: usize, d: usize) -> Option<Array2<f64>> {
559 let Some(jac) = self.jacobian_cache() else {
560 self.missing_cache_default("weighted_jacobian_row", "jacobian_cache is None");
561 return None;
562 };
563 let p = self.p_out;
564 match &self.weight {
565 WeightField::Identity => {
566 let mut out = Array2::<f64>::zeros((p, d));
567 for i in 0..p {
568 for a in 0..d {
569 out[[i, a]] = jac[[n, i * d + a]];
570 }
571 }
572 Some(out)
573 }
574 WeightField::Factored { u, rank, p_out } => {
575 assert_eq!(p, *p_out);
576 let r = *rank;
577 let m_n = self.projected_jacobian_row(n, d)?;
578 let mut out = Array2::<f64>::zeros((p, d));
579 for i in 0..p {
580 for a in 0..d {
581 let mut s = 0.0;
582 for k in 0..r {
583 s += u[[n, i * r + k]] * m_n[[k, a]];
584 }
585 out[[i, a]] = s;
586 }
587 }
588 Some(out)
589 }
590 }
591 }
592
593 fn weighted_dot_decoder_vectors<F, G>(&self, n: usize, p: usize, x: F, y: G) -> f64
594 where
595 F: Fn(usize) -> f64,
596 G: Fn(usize) -> f64,
597 {
598 match &self.weight {
599 WeightField::Identity => {
600 let mut s = 0.0;
601 for i in 0..p {
602 s += x(i) * y(i);
603 }
604 s
605 }
606 WeightField::Factored { u, rank, p_out } => {
607 assert_eq!(p, *p_out);
608 let r = *rank;
609 let mut s = 0.0;
610 for k in 0..r {
611 let mut ux = 0.0;
612 let mut uy = 0.0;
613 for i in 0..p {
614 let uik = u[[n, i * r + k]];
615 ux += uik * x(i);
616 uy += uik * y(i);
617 }
618 s += ux * uy;
619 }
620 s
621 }
622 }
623 }
624
625 fn target_matrix(target: ArrayView1<'_, f64>, n_obs: usize, d: usize) -> Array2<f64> {
626 let mut out = Array2::<f64>::zeros((n_obs, d));
627 for n in 0..n_obs {
628 for a in 0..d {
629 out[[n, a]] = target[n * d + a];
630 }
631 }
632 out
633 }
634
635 fn duchon_radial_jacobian_second(
643 &self,
644 target: ArrayView1<'_, f64>,
645 n_obs: usize,
646 d: usize,
647 source: &IsometryDuchonRadialSource,
648 ) -> Result<Array2<f64>, BasisError> {
649 assert_eq!(source.centers.ncols(), d);
650 assert_eq!(source.radial_coefficients.nrows(), source.centers.nrows());
651 assert_eq!(source.radial_coefficients.ncols(), self.p_out);
652 let t = Self::target_matrix(target, n_obs, d);
653 radial_basis_cartesian_derivative(
654 2,
655 t.view(),
656 source.centers.view(),
657 source.radial_coefficients.view(),
658 source.length_scale,
659 source.nullspace_order,
660 source.power,
661 )
662 }
663
664 fn duchon_radial_jacobian_third(
672 &self,
673 target: ArrayView1<'_, f64>,
674 n_obs: usize,
675 d: usize,
676 source: &IsometryDuchonRadialSource,
677 ) -> Result<ndarray::Array3<f64>, BasisError> {
678 assert_eq!(source.centers.ncols(), d);
679 assert_eq!(source.radial_coefficients.nrows(), source.centers.nrows());
680 assert_eq!(source.radial_coefficients.ncols(), self.p_out);
681 let t = Self::target_matrix(target, n_obs, d);
682 let flat = radial_basis_cartesian_derivative(
683 3,
684 t.view(),
685 source.centers.view(),
686 source.radial_coefficients.view(),
687 source.length_scale,
688 source.nullspace_order,
689 source.power,
690 )?;
691 Ok(flat
692 .into_shape_with_order((n_obs, self.p_out, d * d * d))
693 .expect("radial_basis_cartesian_derivative order-3 output reshapes to (n_obs, p, d³)"))
694 }
695
696 fn jacobian_second<'a>(
697 &'a self,
698 target: ArrayView1<'_, f64>,
699 n_obs: usize,
700 d: usize,
701 ) -> Option<CowArray<'a, f64, Ix2>> {
702 if let Some(jac2) = self.jacobian_second_cache() {
703 return Some(CowArray::from((*jac2).clone()));
710 }
711 let source = self.duchon_radial_source.as_ref()?;
712 match self.duchon_radial_jacobian_second(target, n_obs, d, source) {
713 Ok(jac2) => Some(CowArray::from(jac2)),
714 Err(err) => {
715 self.missing_cache_default(
716 "jacobian_second",
717 &format!("failed to materialize Duchon radial second derivative: {err}"),
718 );
719 None
720 }
721 }
722 }
723
724 fn jacobian_third<'a>(
725 &'a self,
726 target: ArrayView1<'_, f64>,
727 n_obs: usize,
728 d: usize,
729 ) -> Option<CowArray<'a, f64, Ix3>> {
730 if let Some(jac3) = self.third_decoder_derivative() {
731 return Some(CowArray::from(jac3.as_ref().clone()));
732 }
733 let source = self.duchon_radial_source.as_ref()?;
734 match self.duchon_radial_jacobian_third(target, n_obs, d, source) {
735 Ok(jac3) => Some(CowArray::from(jac3)),
736 Err(err) => {
737 self.missing_cache_default(
738 "jacobian_third",
739 &format!("failed to materialize Duchon radial third derivative: {err}"),
740 );
741 None
742 }
743 }
744 }
745
746 pub(crate) fn hvp_state<'a>(
747 &'a self,
748 target: ArrayView1<'_, f64>,
749 ) -> Option<IsometryHvpState<'a>> {
750 let d = self
751 .target
752 .latent_dim
753 .expect("IsometryPenalty requires latent_dim on its PsiSlice");
754 let n_obs = target.len() / d;
755 if !self.has_jacobian_cache("hvp")
756 || !self.has_jacobian_second_source("hvp")
757 || !self.has_jacobian_third_source("hvp")
758 {
759 return None;
760 }
761 let p = self.p_out;
762 let jac2 = self.jacobian_second(target.view(), n_obs, d)?;
763 let jac3 = self.jacobian_third(target.view(), n_obs, d)?;
764 let g = self.pullback_metric(d)?;
765 let metric = self.normalized_metric_state(g, n_obs, d)?;
766 let mut wj_rows = Vec::with_capacity(n_obs);
767 for n in 0..n_obs {
768 wj_rows.push(self.weighted_jacobian_row(n, d)?);
769 }
770 Some(IsometryHvpState {
771 d,
772 n_obs,
773 p,
774 jac2,
775 jac3,
776 metric,
777 wj_rows,
778 })
779 }
780
781 pub(crate) fn hvp_with_precomputed_state(
782 &self,
783 state: &IsometryHvpState<'_>,
784 rho: ArrayView1<'_, f64>,
785 v: ArrayView1<'_, f64>,
786 ) -> Array1<f64> {
787 let mu = resolve_learnable_weight(self.scalar_weight, rho[self.rho_index]);
788 let d = state.d;
789 let n_obs = state.n_obs;
790 let p = state.p;
791 let jac2 = &state.jac2;
792 let jac3 = &state.jac3;
793 let metric = &state.metric;
794 let mut out = Array1::<f64>::zeros(v.len());
795 let mut delta_g = Array2::<f64>::zeros((n_obs, d * d));
796 for n in 0..n_obs {
797 let wj = &state.wj_rows[n];
798 let row_delta = isometry_row_delta_g(jac2.view(), wj.view(), v, n, d, p);
799 for a in 0..d {
800 for b in 0..d {
801 delta_g[[n, a * d + b]] = row_delta[[a, b]];
802 }
803 }
804 }
805 let delta_metric_grad = metric.metric_grad_direction(delta_g.view(), d);
806
807 for n in 0..n_obs {
808 let wj = &state.wj_rows[n];
809 for c in 0..d {
810 let mut acc = 0.0;
811 for a in 0..d {
812 for b in 0..d {
813 let dg = isometry_dg_entry(jac2.view(), wj.view(), n, d, p, a, b, c);
814 acc += dg * delta_metric_grad[[n, a * d + b]];
815 }
816 }
817 out[n * d + c] = mu * acc;
818 }
819
820 for c in 0..d {
821 let mut acc_res = 0.0;
822 for a in 0..d {
823 for b in 0..d {
824 let metric_grad = metric.metric_grad[[n, a * d + b]];
825 if metric_grad == 0.0 {
826 continue;
827 }
828 let mut bv = 0.0;
829 for dd in 0..d {
830 let vd = v[n * d + dd];
831 if vd == 0.0 {
832 continue;
833 }
834 let mut k_a_cd_w_j_b = 0.0;
835 for i in 0..p {
836 k_a_cd_w_j_b += jac3[[n, i, ((a * d) + c) * d + dd]] * wj[[i, b]];
837 }
838 let h_a_c_w_h_b_d = self.weighted_dot_decoder_vectors(
839 n,
840 p,
841 |i| jac2[[n, (i * d + a) * d + c]],
842 |i| jac2[[n, (i * d + b) * d + dd]],
843 );
844 let h_a_d_w_h_b_c = self.weighted_dot_decoder_vectors(
845 n,
846 p,
847 |i| jac2[[n, (i * d + a) * d + dd]],
848 |i| jac2[[n, (i * d + b) * d + c]],
849 );
850 let mut j_a_w_k_b_cd = 0.0;
851 for i in 0..p {
852 j_a_w_k_b_cd += wj[[i, a]] * jac3[[n, i, ((b * d) + c) * d + dd]];
853 }
854 bv +=
855 (k_a_cd_w_j_b + h_a_c_w_h_b_d + h_a_d_w_h_b_c + j_a_w_k_b_cd) * vd;
856 }
857 acc_res += metric_grad * bv;
858 }
859 }
860 out[n * d + c] += mu * acc_res;
861 }
862 }
863 out
864 }
865
866 pub fn pullback_metric(&self, latent_dim: usize) -> Option<Array2<f64>> {
874 let Some(jac) = self.jacobian_cache() else {
875 self.missing_cache_default("pullback_metric", "jacobian_cache is None");
876 return None;
877 };
878 let n_obs = jac.nrows();
879 let p = self.p_out;
880 assert_eq!(jac.ncols(), p * latent_dim);
881 let mut g_all = Array2::<f64>::zeros((n_obs, latent_dim * latent_dim));
882 for n in 0..n_obs {
883 let m = self.projected_jacobian_row(n, latent_dim)?;
885 let r = m.nrows();
886 for a in 0..latent_dim {
888 for b in 0..latent_dim {
889 let mut s = 0.0;
890 for k in 0..r {
891 s += m[[k, a]] * m[[k, b]];
892 }
893 g_all[[n, a * latent_dim + b]] = s;
894 }
895 }
896 }
897 Some(g_all)
898 }
899
900 pub fn metric_normalizer(&self, latent_dim: usize) -> Option<f64> {
919 let g = self.pullback_metric(latent_dim)?;
920 let n_obs = g.nrows();
921 let trace_denominator = (n_obs * latent_dim) as f64;
922 let mut trace_sum = 0.0;
923 for n in 0..n_obs {
924 for a in 0..latent_dim {
925 trace_sum += g[[n, a * latent_dim + a]];
926 }
927 }
928 let normalizer = trace_sum / trace_denominator;
929 (normalizer.is_finite() && normalizer > f64::MIN_POSITIVE).then_some(normalizer)
930 }
931
932 fn reference_metric(&self, n_obs: usize, d: usize) -> CowArray<'_, f64, Ix2> {
934 match &self.reference {
935 IsometryReference::Euclidean => {
936 let mut out = Array2::<f64>::zeros((n_obs, d * d));
937 for n in 0..n_obs {
938 for a in 0..d {
939 out[[n, a * d + a]] = 1.0;
940 }
941 }
942 CowArray::from(out)
943 }
944 IsometryReference::UserSupplied(a) => {
945 assert_eq!(a.nrows(), n_obs);
946 assert_eq!(a.ncols(), d * d);
947 CowArray::from(a.view())
948 }
949 }
950 }
951
952 fn normalized_metric_state(
963 &self,
964 g: Array2<f64>,
965 n_obs: usize,
966 d: usize,
967 ) -> Option<IsometryMetricState> {
968 let dd = d * d;
969 let trace_denominator = (n_obs * d) as f64;
970 let mut trace_sum = 0.0;
971 for n in 0..n_obs {
972 for a in 0..d {
973 trace_sum += g[[n, a * d + a]];
974 }
975 }
976 let normalizer = trace_sum / trace_denominator;
977 if !(normalizer.is_finite() && normalizer > f64::MIN_POSITIVE) {
978 self.missing_cache_default(
979 "normalized_metric_state",
980 &format!(
981 "unit-average-speed normalizer is non-positive or non-finite: {normalizer}"
982 ),
983 );
984 return None;
985 }
986 let g_ref = self.reference_metric(n_obs, d);
987 let mut residual = Array2::<f64>::zeros((n_obs, dd));
988 let inv_norm = 1.0 / normalizer;
989 for n in 0..n_obs {
990 for k in 0..dd {
991 residual[[n, k]] = g[[n, k]] * inv_norm - g_ref[[n, k]];
992 }
993 }
994 let mut residual_dot_g = 0.0;
995 for n in 0..n_obs {
996 for k in 0..dd {
997 residual_dot_g += residual[[n, k]] * g[[n, k]];
998 }
999 }
1000 let trace_coeff = residual_dot_g / (normalizer * normalizer * trace_denominator);
1001 let mut metric_grad = Array2::<f64>::zeros((n_obs, dd));
1002 for n in 0..n_obs {
1003 for a in 0..d {
1004 for b in 0..d {
1005 let k = a * d + b;
1006 let mut value = residual[[n, k]] * inv_norm;
1007 if a == b {
1008 value -= trace_coeff;
1009 }
1010 metric_grad[[n, k]] = value;
1011 }
1012 }
1013 }
1014 Some(IsometryMetricState {
1015 g,
1016 residual,
1017 metric_grad,
1018 normalizer,
1019 trace_denominator,
1020 residual_dot_g,
1021 })
1022 }
1023
1024 pub fn grad_jacobian(
1044 &self,
1045 target: ArrayView1<'_, f64>,
1046 rho: ArrayView1<'_, f64>,
1047 ) -> Array2<f64> {
1048 let d = self
1049 .target
1050 .latent_dim
1051 .expect("IsometryPenalty requires latent_dim on its PsiSlice");
1052 let n_obs = target.len() / d;
1053 let p = self.p_out;
1054 let mut grad = Array2::<f64>::zeros((n_obs, p * d));
1055 if !self.has_jacobian_cache("grad_jacobian") {
1056 return grad;
1057 }
1058 let Some(g) = self.pullback_metric(d) else {
1059 return grad;
1060 };
1061 let Some(metric) = self.normalized_metric_state(g, n_obs, d) else {
1062 return grad;
1063 };
1064 let mu = resolve_learnable_weight(self.scalar_weight, rho[self.rho_index]);
1065 for n in 0..n_obs {
1066 let Some(wj) = self.weighted_jacobian_row(n, d) else {
1067 return Array2::<f64>::zeros((n_obs, p * d));
1068 };
1069 for i in 0..p {
1070 for c in 0..d {
1071 let mut acc = 0.0;
1072 for b in 0..d {
1073 acc += metric.metric_grad[[n, c * d + b]] * wj[[i, b]];
1074 }
1075 grad[[n, i * d + c]] = 2.0 * mu * acc;
1076 }
1077 }
1078 }
1079 grad
1080 }
1081}
1082
1083impl AnalyticPenalty for IsometryPenalty {
1084 fn tier(&self) -> PenaltyTier {
1085 PenaltyTier::Psi
1086 }
1087
1088 fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
1089 let d = self
1090 .target
1091 .latent_dim
1092 .expect("IsometryPenalty requires latent_dim on its PsiSlice");
1093 let n_obs = target.len() / d;
1094 if !self.has_jacobian_cache("value") {
1095 return Self::DEFAULT_VALUE_ON_MISSING_CACHE;
1096 }
1097 let Some(g) = self.pullback_metric(d) else {
1098 return Self::DEFAULT_VALUE_ON_MISSING_CACHE;
1099 };
1100 let Some(metric) = self.normalized_metric_state(g, n_obs, d) else {
1101 return Self::DEFAULT_VALUE_ON_MISSING_CACHE;
1102 };
1103 let mu = resolve_learnable_weight(self.scalar_weight, rho[self.rho_index]);
1104 let mut acc = 0.0;
1105 for n in 0..n_obs {
1106 for k in 0..(d * d) {
1107 let diff = metric.residual[[n, k]];
1108 acc += diff * diff;
1109 }
1110 }
1111 0.5 * mu * acc
1112 }
1113
1114 fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
1115 let d = self
1130 .target
1131 .latent_dim
1132 .expect("IsometryPenalty requires latent_dim on its PsiSlice");
1133 let n_obs = target.len() / d;
1134 if !self.has_jacobian_cache("grad_target")
1135 || !self.has_jacobian_second_source("grad_target")
1136 {
1137 return Array1::<f64>::zeros(target.len());
1138 }
1139 let Some(g) = self.pullback_metric(d) else {
1140 return Array1::<f64>::zeros(target.len());
1141 };
1142 let Some(metric) = self.normalized_metric_state(g, n_obs, d) else {
1143 return Array1::<f64>::zeros(target.len());
1144 };
1145 let p = self.p_out;
1146 let mu = resolve_learnable_weight(self.scalar_weight, rho[self.rho_index]);
1147 let mut grad = Array1::<f64>::zeros(target.len());
1148 let Some(jac2) = self.jacobian_second(target, n_obs, d) else {
1149 return grad;
1150 };
1151 assert_eq!(jac2.ncols(), p * d * d);
1152
1153 for n in 0..n_obs {
1154 let Some(wj) = self.weighted_jacobian_row(n, d) else {
1155 return grad;
1156 };
1157 for c in 0..d {
1158 let mut acc = 0.0;
1159 for a in 0..d {
1160 for b in 0..d {
1161 let mut dg = 0.0;
1162 for i in 0..p {
1163 dg += jac2[[n, (i * d + a) * d + c]] * wj[[i, b]];
1164 dg += wj[[i, a]] * jac2[[n, (i * d + b) * d + c]];
1165 }
1166 acc += metric.metric_grad[[n, a * d + b]] * dg;
1167 }
1168 }
1169 grad[n * d + c] = mu * acc;
1170 }
1171 }
1172 grad
1173 }
1174
1175 fn hvp(
1177 &self,
1178 target: ArrayView1<'_, f64>,
1179 rho: ArrayView1<'_, f64>,
1180 v: ArrayView1<'_, f64>,
1181 ) -> Array1<f64> {
1182 let Some(state) = self.hvp_state(target) else {
1195 return Array1::<f64>::zeros(v.len());
1196 };
1197 self.hvp_with_precomputed_state(&state, rho, v)
1198 }
1199
1200 fn psd_majorizer_hvp(
1209 &self,
1210 target: ArrayView1<'_, f64>,
1211 rho: ArrayView1<'_, f64>,
1212 v: ArrayView1<'_, f64>,
1213 ) -> Array1<f64> {
1214 let d = self
1215 .target
1216 .latent_dim
1217 .expect("IsometryPenalty requires latent_dim on its PsiSlice");
1218 let n_obs = target.len() / d;
1219 if !self.has_jacobian_cache("psd_majorizer_hvp")
1220 || !self.has_jacobian_second_source("psd_majorizer_hvp")
1221 {
1222 return Array1::<f64>::zeros(v.len());
1223 }
1224 let Some(jac2) = self.jacobian_second(target, n_obs, d) else {
1225 return Array1::<f64>::zeros(v.len());
1226 };
1227 let Some(g) = self.pullback_metric(d) else {
1228 return Array1::<f64>::zeros(v.len());
1229 };
1230 let Some(metric) = self.normalized_metric_state(g, n_obs, d) else {
1231 return Array1::<f64>::zeros(v.len());
1232 };
1233 let p = self.p_out;
1234 let mu = resolve_learnable_weight(self.scalar_weight, rho[self.rho_index]);
1235 let mut out = Array1::<f64>::zeros(v.len());
1236 let mut wj_rows = Vec::with_capacity(n_obs);
1237 for n in 0..n_obs {
1238 let Some(wj) = self.weighted_jacobian_row(n, d) else {
1239 return Array1::<f64>::zeros(v.len());
1240 };
1241 wj_rows.push(wj);
1242 }
1243 let mut delta_g = Array2::<f64>::zeros((n_obs, d * d));
1244 for n in 0..n_obs {
1245 let row_delta = isometry_row_delta_g(jac2.view(), wj_rows[n].view(), v, n, d, p);
1246 for a in 0..d {
1247 for b in 0..d {
1248 delta_g[[n, a * d + b]] = row_delta[[a, b]];
1249 }
1250 }
1251 }
1252 let (delta_residual, _delta_normalizer) = metric.residual_direction(delta_g.view(), d);
1253 let mut g_dot_delta_residual = 0.0;
1254 for n in 0..n_obs {
1255 for k in 0..(d * d) {
1256 g_dot_delta_residual += metric.g[[n, k]] * delta_residual[[n, k]];
1257 }
1258 }
1259 let inv_norm = 1.0 / metric.normalizer;
1260 let inv_norm_sq = inv_norm * inv_norm;
1261 for n in 0..n_obs {
1262 let wj = &wj_rows[n];
1263 for c in 0..d {
1264 let mut trace_dg = 0.0;
1265 for a in 0..d {
1266 trace_dg += isometry_dg_entry(jac2.view(), wj.view(), n, d, p, a, a, c);
1267 }
1268 let delta_normalizer_c = trace_dg / metric.trace_denominator;
1269 let mut acc = -delta_normalizer_c * inv_norm_sq * g_dot_delta_residual;
1270 for a in 0..d {
1271 for b in 0..d {
1272 let dg = isometry_dg_entry(jac2.view(), wj.view(), n, d, p, a, b, c);
1273 acc += dg * inv_norm * delta_residual[[n, a * d + b]];
1274 }
1275 }
1276 out[n * d + c] = mu * acc;
1277 }
1278 }
1279 out
1280 }
1281
1282 fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
1283 let mut out = Array1::<f64>::zeros(self.rho_count());
1286 out[self.rho_index] = self.value(target, rho);
1287 out
1288 }
1289
1290 fn rho_count(&self) -> usize {
1291 1
1292 }
1293
1294 fn name(&self) -> &str {
1295 "isometry"
1296 }
1297
1298 impl_scalar_apply_schedule!(scalar_weight);
1299}