pub struct IsometryPenalty {
    pub target: PsiSlice,
    pub reference: IsometryReference,
    pub rho_index: usize,
    pub jacobian_cache_slot: RwLock<Option<Arc<Array2<f64>>>>,
    pub jacobian_second_cache_slot: RwLock<Option<Arc<Array2<f64>>>>,
    pub duchon_radial_source: Option<Arc<IsometryDuchonRadialSource>>,
    pub third_decoder_derivative_slot: RwLock<Option<Arc<Array3<f64>>>>,
    pub p_out: usize,
    pub weight: WeightField,
    pub scalar_weight: f64,
    pub weight_schedule: Option<ScalarWeightSchedule>,
}
Isometry-to-reference penalty (canonical-coordinate gauge term).
Lives on ext-coords: the target slice is a row of the LatentCoordValues flat
vector (row-major n_obs × d). Owns one ρ-axis (log μ_iso).
Penalizes ½ μ Σ_n ‖g_n(t) − g^ref(t_n)‖²_F, where the pullback metric
at row n is
g_n = J_n^T W_n J_n, J_n ∈ ℝ^{p × d}
and W_n is a per-row low-rank PSD behavioral metric stored as
W_n = U_n U_n^T with U_n ∈ ℝ^{p × r}. The canonical-coordinate
statement is “one unit of motion in t ↦ one unit of behavioral change”,
so the W_n weighting is load-bearing.
In the SAE objective this is the extension-coordinate gauge fix: it prevents the latent chart from absorbing arbitrary smooth reparameterizations of the decoder manifold. ARD, sparsity, or rank penalties can then select axes or structure in a chart whose metric scale is pinned.
Contraction order invariant. Every place this struct touches W_n,
the contraction is (J^T U_n)(U_n^T J) — never J^T W_n J with W_n
materialized as p × p. Concretely we form M_n = U_n^T J_n ∈ ℝ^{r × d}
once and then g_n = M_n^T M_n (d × d). Cost per row:
O(p · r · d + r · d²), independent of p².
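The contraction order above can be sketched for a single row as follows. This is a minimal illustration, not the crate's implementation: the function name `pullback_metric_row` and the flat row-major `&[f64]` layout (used in place of `Array2<f64>`) are assumptions made to keep the example dependency-free.

```rust
/// Pullback metric for one row: g = M^T M with M = U^T J.
/// `j` is p x d and `u` is p x r, both row-major; returns g as d x d row-major.
/// Hypothetical helper for illustration only.
pub fn pullback_metric_row(j: &[f64], u: &[f64], p: usize, d: usize, r: usize) -> Vec<f64> {
    assert_eq!(j.len(), p * d);
    assert_eq!(u.len(), p * r);
    // M = U^T J (r x d): O(p * r * d) work.
    let mut m = vec![0.0; r * d];
    for i in 0..p {
        for k in 0..r {
            let u_ik = u[i * r + k];
            for a in 0..d {
                m[k * d + a] += u_ik * j[i * d + a];
            }
        }
    }
    // g = M^T M (d x d): O(r * d^2) work. W = U U^T is never formed.
    let mut g = vec![0.0; d * d];
    for k in 0..r {
        for a in 0..d {
            for b in 0..d {
                g[a * d + b] += m[k * d + a] * m[k * d + b];
            }
        }
    }
    g
}
```

With `u` set to the p × p identity (r = p) the result reduces to the unweighted J^T J pullback.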
When to use. Whenever a LatentCoord block is in play without an
auxiliary variable (AuxPrior) to break the diffeomorphism gauge. Fixes
the audit finding that ARD is not a standalone gauge fix. With a Euclidean
reference, the penalty pulls the decoder toward a local isometry, which is
enough to make the inner Hessian on t full-rank and the IFT well-defined.
Math. Let J_n ∈ ℝ^{p × d} be the local decoder Jacobian. Then
g_n = J_n^T W_n J_n and the penalty is
½ μ Σ_n ‖J_n^T W_n J_n − g^ref_n‖²_F. Analytic gradient w.r.t. t_n:
∂P/∂t_{n,c}
= μ Σ_{a,b} (g_n − g^ref_n)_{ab}
[ H_{n,:,a,c}^T W_n J_{n,:,b}
+ J_{n,:,a}^T W_n H_{n,:,b,c} ],
H_{n,i,a,c} = ∂J_{n,i,a}/∂t_{n,c}.
Gotchas:
- The value path returns the configured missing-cache default when the first-jet cache is absent; gradient/HVP paths need the first and second decoder jets and return zeros when the analytic jet source is unavailable.
- The exact Hessian includes a residual-curvature term requiring the third decoder jet. REML/PIRLS curvature should prefer the Gauss-Newton PSD majorizer when a positive curvature block is required.
- W_n is a metric weight, not a scalar confidence. Changing it changes the canonical units of latent motion.
The per-row Jacobian J_n is exactly the radial-derivative jet
design_gradient_wrt_t already computes for LatentCoordValues; the
second derivative ∂J/∂t is built by the shared
[crate::basis::radial_basis_cartesian_derivative] engine from the
radial Hessian identity. A finite-difference oracle for the docstring is
to central-difference value(t ± h e_j) against grad_target(t)[j];
the analytic value follows the oracle until finite-difference
cancellation dominates. No autograd needed.
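The finite-difference oracle described above can be sketched on a toy problem. Everything below is an assumption chosen for illustration: a one-dimensional latent (d = 1) with a two-output decoder f(t) = (sin t, t²), W = I, and g^ref = 1, so J = (cos t, 2t) and H = (−sin t, 2). The names `central_difference`, `toy_value`, and `toy_grad` are hypothetical stand-ins for `value` and `grad_target`.

```rust
/// Central-difference estimate of d value / d t_j with step `h`.
pub fn central_difference<F: Fn(&[f64]) -> f64>(value: F, t: &[f64], j: usize, h: f64) -> f64 {
    let mut plus = t.to_vec();
    let mut minus = t.to_vec();
    plus[j] += h;
    minus[j] -= h;
    (value(&plus[..]) - value(&minus[..])) / (2.0 * h)
}

/// Toy penalty (assumption): 0.5 * mu * sum_n (g_n - 1)^2
/// with g_n = cos^2(t_n) + 4 t_n^2.
pub fn toy_value(t: &[f64], mu: f64) -> f64 {
    t.iter()
        .map(|&x| {
            let g = x.cos().powi(2) + 4.0 * x * x;
            0.5 * mu * (g - 1.0).powi(2)
        })
        .sum()
}

/// Analytic gradient of the toy penalty: mu * (g - 1) * 2 * (H . J),
/// which is the d = 1 case of the gradient formula with W = I.
pub fn toy_grad(t: &[f64], mu: f64) -> Vec<f64> {
    t.iter()
        .map(|&x| {
            let g = x.cos().powi(2) + 4.0 * x * x;
            let h_dot_j = -x.sin() * x.cos() + 4.0 * x;
            mu * (g - 1.0) * 2.0 * h_dot_j
        })
        .collect()
}
```

Comparing `central_difference(|x: &[f64]| toy_value(x, mu), &t, j, h)` against `toy_grad(&t, mu)[j]` for a moderate step such as `h = 1e-5` should show agreement well below the step size; shrinking `h` much further eventually lets floating-point cancellation dominate.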
μ = exp(ρ_iso) is REML-selectable as one extra ρ axis.
jacobian_cache_slot and jacobian_second_cache_slot are interior-mutable
(RwLock<Option<Arc<…>>>) so the SAE outer loop can refresh them in place
each step without needing &mut self on the registry-held penalty (see
refresh_caches and [crate::terms::sae::manifold::refresh_isometry_caches_from_atom]).
Readers go through the Self::jacobian_cache / Self::jacobian_second_cache
accessors, which take the read lock briefly and clone the inner Arc
(refcount bump — no payload copy). Writers go through Self::refresh_caches.
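The interior-mutability pattern described above can be sketched with the standard library alone. `CacheSlot`, `get`, and `set` are hypothetical names, and `Vec<f64>` stands in for `Array2<f64>`; the real accessors may differ in detail.

```rust
use std::sync::{Arc, RwLock};

/// Hypothetical stand-in for one cache slot.
pub struct CacheSlot {
    slot: RwLock<Option<Arc<Vec<f64>>>>,
}

impl CacheSlot {
    pub fn new() -> Self {
        Self { slot: RwLock::new(None) }
    }

    /// Reader: holds the read lock only long enough to clone the Arc
    /// (a refcount bump, not a copy of the payload).
    pub fn get(&self) -> Option<Arc<Vec<f64>>> {
        self.slot.read().expect("cache lock poisoned").clone()
    }

    /// Writer: takes `&self`, so it works through a shared `Arc<CacheSlot>`.
    /// Passing `None` clears the slot.
    pub fn set(&self, value: Option<Arc<Vec<f64>>>) {
        *self.slot.write().expect("cache lock poisoned") = value;
    }
}
```

Because readers receive their own `Arc`, a refresh that swaps the slot contents does not invalidate a payload that an in-flight reader is still using.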
Fields
target: PsiSlice
reference: IsometryReference
rho_index: usize
Index of this penalty’s strength log μ_iso inside the local rho
view this penalty receives. Always 0 for now (single owned axis).
jacobian_cache_slot: RwLock<Option<Arc<Array2<f64>>>>
Cached Jacobian J ∈ ℝ^{n_obs × p × d}, flattened row-major as
(n_obs, p*d). The owning driver refreshes this each IFT outer step
before invoking value / grad_target; in operator-only call sites
(Hessian-vector products) the cache must be live. Access through
Self::jacobian_cache / Self::set_jacobian_cache.
jacobian_second_cache_slot: RwLock<Option<Arc<Array2<f64>>>>
Optional cached per-row Jacobian second derivative
H_n ∈ ℝ^{p × d × d}, flattened row-major as (n_obs, p*d*d).
H_n[i, a, c] = ∂J_n[i, a] / ∂t_{n, c}. Either this cache or
duchon_radial_source must be present for exact isometry
gradient/HVP calls. Access through Self::jacobian_second_cache /
Self::set_jacobian_second_cache.
duchon_radial_source: Option<Arc<IsometryDuchonRadialSource>>
Optional radial-Duchon source used to build jacobian_second_cache
analytically from φ'(r) and the public φ''(r) jet helper. This is
the exact chain-rule path for callers that do not pre-cache ∂J/∂t.
third_decoder_derivative_slot: RwLock<Option<Arc<Array3<f64>>>>
Optional cached per-row Jacobian third derivative
K_n ∈ ℝ^{p × d × d × d}, stored as an Array3 with shape
(n_obs, p, d * d * d) where the third axis packs the index triple
(a, c, dd) in row-major order as ((a * d) + c) * d + dd, with d the
latent dimension. hvp uses the full
residual-curvature Hessian (proposal §4(b)):
B_{ab,cd} = K_{a,cd}^T W J_b + H_{a,c}^T W H_{b,d}
+ H_{a,d}^T W H_{b,c} + J_a^T W K_{b,cd}.
Either this cache or duchon_radial_source must be present for
analytic hvp calls. Interior-mutable (mirrors
jacobian_second_cache_slot) so the SAE outer loop can refresh K in
place each step. Access through Self::third_decoder_derivative /
Self::set_third_decoder_derivative.
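The packing of the index triple on the third axis can be sketched as a pair of helpers. The names `pack_acd` and `unpack_acd` are hypothetical; only the formula ((a * d) + c) * d + dd comes from the description above.

```rust
/// Packed offset of (a, c, dd) on the third axis, for latent dimension `d`.
pub fn pack_acd(a: usize, c: usize, dd: usize, d: usize) -> usize {
    debug_assert!(a < d && c < d && dd < d);
    (a * d + c) * d + dd
}

/// Inverse of `pack_acd`: recover (a, c, dd) from a packed offset.
pub fn unpack_acd(k: usize, d: usize) -> (usize, usize, usize) {
    debug_assert!(k < d * d * d);
    (k / (d * d), (k / d) % d, k % d)
}
```

For d = 3, the triple (1, 2, 0) lands at offset 15, and every offset in 0..27 round-trips through `unpack_acd` and `pack_acd`.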
p_out: usize
Output dimensionality p (row count of each per-row Jacobian J_n ∈ ℝ^{p × d}).
weight: WeightField
Per-row behavioral metric in low-rank factored form. Defaults to
Identity (the unweighted J^T J pullback). When Factored, all
g_n contractions are done via M_n = U_n^T J_n (r × d), keeping
memory and FLOPs scaling at O(p · r · d) per row instead of
O(p²) per row.
scalar_weight: f64
weight_schedule: Option<ScalarWeightSchedule>
Implementations
impl IsometryPenalty
pub const DEFAULT_VALUE_ON_MISSING_CACHE: f64 = 0.0
pub fn new_euclidean(target: PsiSlice, p_out: usize) -> Self
pub fn jacobian_cache(&self) -> Option<Arc<Array2<f64>>>
Read-side accessor: takes the read lock briefly and clones the inner
Arc (refcount bump only; no payload copy). Returns None when the
cache has not been refreshed yet. Internally panics on poisoned lock
— the lock only wraps an Option<Arc<…>>, so the write side cannot
leave it in an invariant-violating state.
pub fn jacobian_second_cache(&self) -> Option<Arc<Array2<f64>>>
Read-side accessor for the per-row Jacobian second derivative.
Mirrors Self::jacobian_cache.
pub fn refresh_caches(
    &self,
    jac: Option<Arc<Array2<f64>>>,
    jac2: Option<Arc<Array2<f64>>>,
)
Per-step refresh entry point. Takes &self (no &mut) so the SAE
outer loop can install fresh caches on an Arc<IsometryPenalty> held
in the analytic-penalty registry without disturbing the surrounding
dispatcher. Pass None for either argument to clear that cache (the
dispatcher will then either fall back to the Duchon radial source if
available, or return the zero safe default).
pub fn set_jacobian_cache(&self, jac: Option<Arc<Array2<f64>>>)
In-place writer for just the Jacobian cache (used by callers that
already own the radial Duchon source and only want to refresh J).
pub fn set_jacobian_second_cache(&self, jac2: Option<Arc<Array2<f64>>>)
In-place writer for just the Jacobian second-derivative cache.
pub fn third_decoder_derivative(&self) -> Option<Arc<Array3<f64>>>
Read-side accessor for the per-row Jacobian third derivative K.
Mirrors Self::jacobian_second_cache.
impl IsometryPenalty
pub fn with_third_decoder_derivative(self, k: Arc<Array3<f64>>) -> Self
Attach a cached third decoder derivative
K_n[i, a, c, dd] = ∂²J_n[i, a] / ∂t_{n, c} ∂t_{n, dd}, stored as an
Array3 of shape (n_obs, p, d * d * d) to match
third_decoder_derivative_slot. The Hessian-vector product
uses the full residual-curvature term in addition to the metric
Gauss-Newton piece.
pub fn with_reference(self, reference: IsometryReference) -> Self
pub fn with_jacobian_cache(self, j: Arc<Array2<f64>>) -> Self
pub fn with_jacobian_second_cache(self, h: Arc<Array2<f64>>) -> Self
pub fn with_duchon_radial_source(
    self,
    source: Arc<IsometryDuchonRadialSource>,
) -> Self
Attach radial Duchon decoder metadata so the exact ∂J/∂t tensor can
be rebuilt from the current target coordinates. A doc-test oracle for
this path is: build J(t) from duchon_radial_first_derivative_nd,
evaluate grad_target(t), then central-difference value(t ± h e_j);
the analytic component should agree to finite-difference tolerance as
h is refined before cancellation dominates.
pub fn with_row_metric(self, metric: &RowMetric) -> Self
Attach the gauge metric from the single
RowMetric that also drives
the reconstruction likelihood. This is the only way an IsometryPenalty
acquires a non-identity behavioral metric: the independent
WeightField setter has been removed so a gauge-metric ≠
likelihood-metric state is structurally unrepresentable. The
contraction-order invariant (M_n = U_n^T J_n, never materializing the
p × p W_n) is preserved by the WeightField::Factored layout the
metric emits.
p_out is taken from the metric so the gauge’s output dimension is
pinned to the metric’s.
pub fn with_weight_schedule(self, schedule: ScalarWeightSchedule) -> Self
Attach a scalar weight schedule, seeding the current weight from the schedule’s stored iteration counter.
pub fn pullback_metric(&self, latent_dim: usize) -> Option<Array2<f64>>
Per-row pullback metric g_n = J_n^T W_n J_n = M_n^T M_n with
M_n = U_n^T J_n ∈ ℝ^{r_n × d}. Returns (n_obs, d, d) flattened
row-major as (n_obs, d*d).
Cost per row: O(p · r · d) for the M_n build (single pass over
U_n and J_n) plus O(r · d²) for M_n^T M_n. The p × p weight
W_n is never materialized.
pub fn grad_jacobian(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
) -> Array2<f64>
Exact closed-form gradient of the isometry penalty with respect to the
cached decoder Jacobian J ∈ ℝ^{n_obs × p × d} (the autograd input that
torch’s _IsometryPenaltyFn differentiates). Returns the flattened
(n_obs, p*d) layout that matches the Jacobian cache.
Derivation (W-aware, reference-aware, weight-aware):
P = ½ μ Σ_n ‖R_n‖²_F, R_n = g_n / gbar − g^ref_n,
gbar = (1 / (N d)) Σ_n tr(g_n),
A_n = ∂(P/μ)/∂g_n,
∂g_{ab}/∂J_{i,c} = δ_{ca} (W J)_{i,b} + δ_{cb} (W J)_{i,a} (W symmetric),
∂P/∂J_{i,c} = μ Σ_{a,b} A_{ab} ∂g_{ab}/∂J_{i,c}
= 2 μ Σ_b A_{cb} (W J)_{i,b}
= 2 μ ((W J) A)_{i,c},
where A includes the exact derivative of the shared gbar normalizer.
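The closed form 2 μ (W J) A can be sketched for one row in a deliberately simplified setting. The assumptions here are loud ones: W = I, and the shared normalizer is frozen at gbar = 1, so A reduces to J^T J − g^ref and the normalizer derivative that the real A carries is omitted. The names `gram`, `value_row`, and `grad_jacobian_row` are hypothetical.

```rust
/// g = J^T J for one row (W = I); `j` is p x d row-major, result is d x d.
fn gram(j: &[f64], p: usize, d: usize) -> Vec<f64> {
    let mut g = vec![0.0; d * d];
    for i in 0..p {
        for a in 0..d {
            for b in 0..d {
                g[a * d + b] += j[i * d + a] * j[i * d + b];
            }
        }
    }
    g
}

/// Simplified per-row value: 0.5 * mu * ||J^T J - g_ref||_F^2 (gbar fixed to 1).
pub fn value_row(j: &[f64], g_ref: &[f64], p: usize, d: usize, mu: f64) -> f64 {
    let g = gram(j, p, d);
    let sq: f64 = g.iter().zip(g_ref).map(|(x, y)| (x - y).powi(2)).sum();
    0.5 * mu * sq
}

/// Simplified per-row gradient w.r.t. J: out[i, c] = 2 * mu * sum_b A[c, b] * J[i, b],
/// with A = J^T J - g_ref. The normalizer derivative is NOT included here.
pub fn grad_jacobian_row(j: &[f64], g_ref: &[f64], p: usize, d: usize, mu: f64) -> Vec<f64> {
    let g = gram(j, p, d);
    let a: Vec<f64> = g.iter().zip(g_ref).map(|(x, y)| x - y).collect();
    let mut out = vec![0.0; p * d];
    for i in 0..p {
        for c in 0..d {
            for b in 0..d {
                out[i * d + c] += 2.0 * mu * a[c * d + b] * j[i * d + b];
            }
        }
    }
    out
}
```

Central-differencing `value_row` entry by entry in `j` gives a quick check of `grad_jacobian_row`; the same oracle applies to the full penalty once the gbar derivative is folded into A.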
Trait Implementations
impl AnalyticPenalty for IsometryPenalty
fn hvp(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
    v: ArrayView1<'_, f64>,
) -> Array1<f64>
Fully analytic - wired through radial_basis_cartesian_derivative.
fn psd_majorizer_hvp(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
    v: ArrayView1<'_, f64>,
) -> Array1<f64>
PSD majorizer-vector product B_GN(target; ρ) v for the nonconvex
isometry penalty.
The Gauss-Newton block differentiates the normalized residual
R = g/gbar - G_ref itself and returns μ DRᵀ DR v. This is PSD by
construction and includes the shared-normalizer derivative exactly;
using only ∂g would reintroduce scale coupling and would not be the
Gauss-Newton operator of the objective being minimized.
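The shape of the Gauss-Newton product μ DRᵀ DR v can be sketched with an explicit matrix. This is only an illustration of why the operator is PSD: the dense m × n row-major `dr` is an assumed stand-in for DR (the residual Jacobian with respect to target), and `gauss_newton_hvp` is a hypothetical name; the actual method may well apply DR and DRᵀ matrix-free.

```rust
/// mu * DR^T (DR v), with `dr` an m x n row-major matrix standing in for DR.
pub fn gauss_newton_hvp(dr: &[f64], m: usize, n: usize, v: &[f64], mu: f64) -> Vec<f64> {
    assert_eq!(dr.len(), m * n);
    assert_eq!(v.len(), n);
    // w = DR v (length m).
    let w: Vec<f64> = (0..m)
        .map(|i| (0..n).map(|k| dr[i * n + k] * v[k]).sum::<f64>())
        .collect();
    // out = mu * DR^T w (length n).
    (0..n)
        .map(|k| mu * (0..m).map(|i| dr[i * n + k] * w[i]).sum::<f64>())
        .collect()
}
```

The quadratic form vᵀ B v equals μ ‖DR v‖², which is nonnegative for any v whenever μ ≥ 0; no curvature term from the residual itself enters, which is where the exact Hessian can become indefinite.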
fn tier(&self) -> PenaltyTier
fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64
P(target; ρ). The strength factor
exp(ρ) (or whatever parameterization the penalty uses) is folded in.
fn grad_target(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
) -> Array1<f64>
∂P/∂target, same length as target.
fn grad_rho(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
) -> Array1<f64>
Self::rho_count.
fn rho_count(&self) -> usize
fn apply_schedule(&mut self, iter: usize)
fn hessian_diag(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
) -> Option<Array1<f64>>
diag(∂²P/∂target²) when the Hessian is
block-diagonal. Returns None for penalties whose Hessian is dense
(Isometry); those implement Self::hvp instead. The default
signals “no closed-form diagonal” by returning None for any
non-empty target — concrete penalties either override with their
own analytic diagonal or rely on the matrix-free hvp path.
fn psd_majorizer_diag(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
) -> Option<Array1<f64>>
diag(B(target; ρ)) with
B ⪰ ∂²P/∂target² everywhere and B ⪰ 0. This is a different
operator from Self::hessian_diag: for nonconvex penalties (log
sparsity, JumpReLU) the exact Hessian is indefinite, but the inner
Newton / PIRLS solve and the log-det / preconditioner pipeline require
a PSD curvature block. For convex penalties the majorizer coincides
with the exact Hessian, so the default simply delegates to
Self::hessian_diag; nonconvex penalties override.
impl Clone for IsometryPenalty
impl Debug for IsometryPenalty
impl PenaltyManifest for IsometryPenalty
const KIND_TAG: &'static str = "isometry"
const PYTHON_WRAPPER: &'static str = "IsometryPenalty"
const ROW_BLOCK_DIAGONAL: bool = false
fn dispatch_tier(&self) -> PenaltyTier
Auto Trait Implementations
impl !Freeze for IsometryPenalty
impl RefUnwindSafe for IsometryPenalty
impl Send for IsometryPenalty
impl Sync for IsometryPenalty
impl Unpin for IsometryPenalty
impl UnsafeUnpin for IsometryPenalty
impl UnwindSafe for IsometryPenalty
Blanket Implementations
impl<T> Allocation for T
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<ST, DT> CastableFrom<ST, Initialized, Initialized> for DT
impl<ST, DT> CastableFrom<ST, Uninit, Uninit> for DT
impl<T> CloneToUninit for T
where
    T: Clone,
impl<T> DistributionExt for T
where
    T: ?Sized,
impl<T, U> Imply<T> for U
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise.
impl<T> Pointable for T
impl<T> Read<Exclusive, BecauseExclusive> for T
where
    T: ?Sized,
impl<SS, SP> SupersetOf<SS> for SP
where
    SS: SubsetOf<SP>,
fn to_subset(&self) -> Option<SS>
The inverse inclusion map: attempts to construct self from the equivalent element of its
superset.
fn is_in_subset(&self) -> bool
Checks if self is actually part of its subset T (and can be converted to it).
fn to_subset_unchecked(&self) -> SS
Use with care! Same as self.to_subset but without any property checks. Always succeeds.
fn from_subset(element: &SS) -> SP
The inclusion map: converts self to the equivalent element of its superset.