Struct IsometryPenalty 

pub struct IsometryPenalty {
    pub target: PsiSlice,
    pub reference: IsometryReference,
    pub rho_index: usize,
    pub jacobian_cache_slot: RwLock<Option<Arc<Array2<f64>>>>,
    pub jacobian_second_cache_slot: RwLock<Option<Arc<Array2<f64>>>>,
    pub duchon_radial_source: Option<Arc<IsometryDuchonRadialSource>>,
    pub third_decoder_derivative_slot: RwLock<Option<Arc<Array3<f64>>>>,
    pub p_out: usize,
    pub weight: WeightField,
    pub scalar_weight: f64,
    pub weight_schedule: Option<ScalarWeightSchedule>,
}

Isometry-to-reference penalty (canonical-coordinate gauge term).

Lives on ext-coords: the target slice is a row of the LatentCoordValues flat vector (row-major n_obs × d). Owns one ρ-axis (log μ_iso).

Penalizes ½ μ Σ_n ‖g_n(t) − g^ref(t_n)‖²_F, where the pullback metric at row n is

  g_n = J_n^T W_n J_n,    J_n ∈ ℝ^{p × d}

and W_n is a per-row low-rank PSD behavioral metric stored as W_n = U_n U_n^T with U_n ∈ ℝ^{p × r}. The canonical-coordinate statement is “one unit of motion in t ↦ one unit of behavioral change”, so the W_n weighting is load-bearing.

In the SAE objective this is the extension-coordinate gauge fix: it prevents the latent chart from absorbing arbitrary smooth reparameterizations of the decoder manifold. ARD, sparsity, or rank penalties can then select axes or structure in a chart whose metric scale is pinned.

Contraction order invariant. Every place this struct touches W_n, the contraction is (J^T U_n)(U_n^T J), never J^T W_n J with W_n materialized as p × p. Concretely we form M_n = U_n^T J_n ∈ ℝ^{r × d} once and then g_n = M_n^T M_n (d × d). Cost per row: O(p · r · d + r · d²), with no p² term.
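The contraction order can be sketched on a single row with plain slices. This is an illustration only: the function names, the row-major `Vec<f64>` layout, and the dense reference path are assumptions made for the example, not the crate's implementation (which works on ndarray caches).

```rust
/// Pullback metric for one row via the factored contraction
/// g = M^T M with M = U^T J, never forming W = U U^T (p x p).
/// `j` is p x d and `u` is p x r, both row-major (assumed layout).
fn pullback_metric_row(j: &[f64], u: &[f64], p: usize, d: usize, r: usize) -> Vec<f64> {
    assert_eq!(j.len(), p * d);
    assert_eq!(u.len(), p * r);
    // M = U^T J (r x d): one pass over U and J, O(p * r * d).
    let mut m = vec![0.0; r * d];
    for i in 0..p {
        for k in 0..r {
            let u_ik = u[i * r + k];
            for a in 0..d {
                m[k * d + a] += u_ik * j[i * d + a];
            }
        }
    }
    // g = M^T M (d x d): O(r * d^2).
    let mut g = vec![0.0; d * d];
    for k in 0..r {
        for a in 0..d {
            for b in 0..d {
                g[a * d + b] += m[k * d + a] * m[k * d + b];
            }
        }
    }
    g
}

/// Reference path that does materialize W = U U^T; for checking only.
fn pullback_metric_row_dense(j: &[f64], u: &[f64], p: usize, d: usize, r: usize) -> Vec<f64> {
    let mut w = vec![0.0; p * p];
    for i in 0..p {
        for l in 0..p {
            for k in 0..r {
                w[i * p + l] += u[i * r + k] * u[l * r + k];
            }
        }
    }
    let mut g = vec![0.0; d * d];
    for a in 0..d {
        for b in 0..d {
            for i in 0..p {
                for l in 0..p {
                    g[a * d + b] += j[i * d + a] * w[i * p + l] * j[l * d + b];
                }
            }
        }
    }
    g
}
```

Both paths give the same d × d matrix; only the factored one avoids the p × p intermediate.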

When to use. Whenever a LatentCoord block is in play without an auxiliary variable (AuxPrior) to break the diffeomorphism gauge. Fixes the audit finding that ARD is not a standalone gauge fix. With a Euclidean reference, the penalty pulls the decoder toward a local isometry, which is enough to make the inner Hessian on t full-rank and the IFT well-defined.

Math. Let J_n ∈ ℝ^{p × d} be the local decoder Jacobian. Then g_n = J_n^T W_n J_n and the penalty is ½ μ Σ_n ‖J_n^T W_n J_n − g^ref_n‖²_F. Analytic gradient w.r.t. t_n:

  ∂P/∂t_{n,c}
    = μ Σ_{a,b} (g_n − g^ref_n)_{ab}
        [ H_{n,:,a,c}^T W_n J_{n,:,b}
          + J_{n,:,a}^T W_n H_{n,:,b,c} ],
  H_{n,i,a,c} = ∂J_{n,i,a}/∂t_{n,c}.

Gotchas:

  • The value path returns the configured missing-cache default when the first-jet cache is absent; gradient/HVP paths need the first and second decoder jets and return zeros when the analytic jet source is unavailable.
  • The exact Hessian includes a residual-curvature term requiring the third decoder jet. REML/PIRLS curvature should prefer the Gauss-Newton PSD majorizer when a positive curvature block is required.
  • W_n is a metric weight, not a scalar confidence. Changing it changes the canonical units of latent motion.

The per-row Jacobian J_n is exactly the radial-derivative jet design_gradient_wrt_t already computes for LatentCoordValues; the second derivative ∂J/∂t is built by the shared [crate::basis::radial_basis_cartesian_derivative] engine from the radial Hessian identity. A finite-difference oracle for the docstring is to central-difference value(t ± h e_j) against grad_target(t)[j]; the analytic value follows the oracle until finite-difference cancellation dominates. No autograd needed.
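The finite-difference oracle can be illustrated on a toy problem. The sketch below assumes a hypothetical one-dimensional decoder f(t) = (sin t, t²) with W = I, a Euclidean reference g^ref = 1, and the unnormalized penalty from the Math paragraph; none of these function names exist in the crate.

```rust
/// Toy decoder f(t) = (sin t, t^2) with d = 1, p = 2 (hypothetical).
/// Its Jacobian is J(t) = (cos t, 2t).
fn toy_jacobian(t: f64) -> [f64; 2] {
    [t.cos(), 2.0 * t]
}

/// H(t) = dJ/dt = (-sin t, 2).
fn toy_jacobian_dt(t: f64) -> [f64; 2] {
    [-t.sin(), 2.0]
}

/// Pullback metric g = J^T J (a 1 x 1 matrix here, since W = I and d = 1).
fn toy_metric(t: f64) -> f64 {
    let j = toy_jacobian(t);
    j[0] * j[0] + j[1] * j[1]
}

/// P(t) = 1/2 * mu * (g(t) - g_ref)^2 with g_ref = 1.
fn toy_value(t: f64, mu: f64) -> f64 {
    0.5 * mu * (toy_metric(t) - 1.0).powi(2)
}

/// Analytic gradient mu * (g - g_ref) * (H^T J + J^T H).
fn toy_grad(t: f64, mu: f64) -> f64 {
    let (j, h) = (toy_jacobian(t), toy_jacobian_dt(t));
    let dg = 2.0 * (h[0] * j[0] + h[1] * j[1]);
    mu * (toy_metric(t) - 1.0) * dg
}

/// Central-difference oracle (P(t + h) - P(t - h)) / (2h).
fn toy_central_difference(t: f64, mu: f64, h: f64) -> f64 {
    (toy_value(t + h, mu) - toy_value(t - h, mu)) / (2.0 * h)
}
```

Comparing `toy_grad(t, mu)` with `toy_central_difference(t, mu, h)` for a few step sizes h shows the expected pattern: agreement improves as h shrinks until round-off cancellation takes over.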

μ = exp(ρ_iso) is REML-selectable as one extra ρ axis.
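A minimal sketch of what the log-strength parameterization implies for the ρ-gradient, assuming the strength enters the value only through the multiplicative factor μ = exp(ρ_iso). The names `penalty_value`, `penalty_grad_rho`, and `base` are hypothetical.

```rust
/// `base` stands for the mu-free part of the penalty, e.g.
/// 1/2 * sum_n ||g_n - g_ref_n||_F^2 (hypothetical input).
fn penalty_value(rho_iso: f64, base: f64) -> f64 {
    rho_iso.exp() * base
}

/// d/d(rho) [exp(rho) * base] = exp(rho) * base, i.e. the penalty value itself.
fn penalty_grad_rho(rho_iso: f64, base: f64) -> f64 {
    penalty_value(rho_iso, base)
}
```

Under that assumption the single-axis ρ-gradient equals the penalty value.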

jacobian_cache_slot and jacobian_second_cache_slot are interior-mutable (RwLock<Option<Arc<…>>>) so the SAE outer loop can refresh them in place each step without needing &mut self on the registry-held penalty (see refresh_caches and [crate::terms::sae::manifold::refresh_isometry_caches_from_atom]). Readers go through the Self::jacobian_cache / Self::jacobian_second_cache accessors, which take the read lock briefly and clone the inner Arc (refcount bump — no payload copy). Writers go through Self::refresh_caches.
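The interior-mutability pattern described here can be sketched with the standard library alone. `CacheSlot` is a hypothetical stand-in for one cache field, and the payload is a flat `Vec<f64>` rather than the `Array2<f64>` the real struct holds.

```rust
use std::sync::{Arc, RwLock};

/// Hypothetical stand-in for one cache slot of the penalty.
struct CacheSlot {
    slot: RwLock<Option<Arc<Vec<f64>>>>,
}

impl CacheSlot {
    fn new() -> Self {
        Self { slot: RwLock::new(None) }
    }

    /// Reader: hold the read lock only long enough to clone the Arc
    /// (a refcount bump, not a copy of the payload).
    fn get(&self) -> Option<Arc<Vec<f64>>> {
        self.slot.read().expect("cache lock poisoned").clone()
    }

    /// Writer: takes &self, so it works through a shared Arc<CacheSlot>.
    /// Passing None clears the slot.
    fn refresh(&self, fresh: Option<Arc<Vec<f64>>>) {
        *self.slot.write().expect("cache lock poisoned") = fresh;
    }
}
```

A reader that obtained an `Arc` before a refresh keeps its old payload alive until it drops the handle, so a refresh never invalidates data that is still in use.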

Fields§

§target: PsiSlice

§reference: IsometryReference

§rho_index: usize

Index of this penalty’s strength log μ_iso inside the local rho view this penalty receives. Always 0 for now (single owned axis).

§jacobian_cache_slot: RwLock<Option<Arc<Array2<f64>>>>

Cached Jacobian J ∈ ℝ^{n_obs × p × d}, flattened row-major (n_obs, p*d). The owning driver refreshes this each IFT outer step before invoking value / grad_target; in operator-only call sites (Hessian-vector products) the cache must be live. Access through Self::jacobian_cache / Self::set_jacobian_cache.

§jacobian_second_cache_slot: RwLock<Option<Arc<Array2<f64>>>>

Optional cached per-row derivative of the Jacobian (the decoder's second derivative) H_n ∈ ℝ^{p × d × d}, flattened row-major as (n_obs, p*d*d), with H_n[i, a, c] = ∂J_n[i, a] / ∂t_{n, c}. Either this cache or duchon_radial_source must be present for exact isometry gradient/HVP calls. Access through Self::jacobian_second_cache / Self::set_jacobian_second_cache.

§duchon_radial_source: Option<Arc<IsometryDuchonRadialSource>>

Optional radial-Duchon source used to build jacobian_second_cache analytically from φ'(r) and the public φ''(r) jet helper. This is the exact chain-rule path for callers that do not pre-cache ∂J/∂t.

§third_decoder_derivative_slot: RwLock<Option<Arc<Array3<f64>>>>

Optional cached per-row second derivative of the Jacobian (the decoder's third derivative) K_n ∈ ℝ^{p × d × d × d}, stored as an Array3 with shape (n_obs, p, d * d * d). The third axis packs the index triple (a, c, dd) in row-major order at offset ((a * d) + c) * d + dd, where dd is the second differentiation index (written d in the formula that follows). hvp uses the full residual-curvature Hessian (proposal §4(b)): B_{ab,cd} = K_{a,cd}^T W J_b + H_{a,c}^T W H_{b,d} + H_{a,d}^T W H_{b,c} + J_a^T W K_{b,cd}. Either this cache or duchon_radial_source must be present for analytic hvp calls. Interior-mutable (mirrors jacobian_second_cache_slot) so the SAE outer loop can refresh K in place each step. Access through Self::third_decoder_derivative / Self::set_third_decoder_derivative.
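The packing rule for the third axis can be written as a small index helper. `k_offset` and `k_unpack` are hypothetical names introduced for illustration; only the offset formula itself comes from the field description.

```rust
/// Offset of K[i, a, c, dd] along the packed third axis, following the
/// row-major rule ((a * d) + c) * d + dd, where `d` is the latent dimension.
fn k_offset(a: usize, c: usize, dd: usize, d: usize) -> usize {
    debug_assert!(a < d && c < d && dd < d);
    ((a * d) + c) * d + dd
}

/// Inverse of `k_offset`: recover (a, c, dd) from a packed offset.
fn k_unpack(offset: usize, d: usize) -> (usize, usize, usize) {
    (offset / (d * d), (offset / d) % d, offset % d)
}
```

With d = 3, the entry (a, c, dd) = (1, 2, 0) lands at offset 15, and the packed axis has d³ = 27 slots.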

§p_out: usize

Output dimensionality p (row count of each per-row Jacobian J_n ∈ ℝ^{p × d}).

§weight: WeightField

Per-row behavioral metric in low-rank factored form. Defaults to Identity (the unweighted J^T J pullback). When Factored, all g_n contractions are done via M_n = U_n^T J_n (r × d), keeping memory and FLOPs scaling at O(p · r · d) per row instead of O(p²) per row.

§scalar_weight: f64

§weight_schedule: Option<ScalarWeightSchedule>

Implementations§

Source§

impl IsometryPenalty

Source

pub const DEFAULT_VALUE_ON_MISSING_CACHE: f64 = 0.0

Source

pub fn new_euclidean(target: PsiSlice, p_out: usize) -> Self

Source

pub fn jacobian_cache(&self) -> Option<Arc<Array2<f64>>>

Read-side accessor: takes the read lock briefly and clones the inner Arc (refcount bump only; no payload copy). Returns None when the cache has not been refreshed yet. Internally panics on poisoned lock — the lock only wraps an Option<Arc<…>>, so the write side cannot leave it in an invariant-violating state.

Source

pub fn jacobian_second_cache(&self) -> Option<Arc<Array2<f64>>>

Read-side accessor for the per-row Jacobian second derivative. Mirrors Self::jacobian_cache.

Source

pub fn refresh_caches( &self, jac: Option<Arc<Array2<f64>>>, jac2: Option<Arc<Array2<f64>>>, )

Per-step refresh entry point. Takes &self (no &mut) so the SAE outer loop can install fresh caches on an Arc<IsometryPenalty> held in the analytic-penalty registry without disturbing the surrounding dispatcher. Pass None for either argument to clear that cache (the dispatcher will then either fall back to the Duchon radial source if available, or return the zero safe default).

Source

pub fn set_jacobian_cache(&self, jac: Option<Arc<Array2<f64>>>)

In-place writer for just the Jacobian cache (used by callers that already own the radial Duchon source and only want to refresh J).

Source

pub fn set_jacobian_second_cache(&self, jac2: Option<Arc<Array2<f64>>>)

In-place writer for just the Jacobian second-derivative cache.

Source

pub fn third_decoder_derivative(&self) -> Option<Arc<Array3<f64>>>

Read-side accessor for the per-row Jacobian third derivative K. Mirrors Self::jacobian_second_cache.

Source

pub fn set_third_decoder_derivative(&self, jac3: Option<Arc<Array3<f64>>>)

In-place writer for just the Jacobian third-derivative cache K.

Source§

impl IsometryPenalty

Source

pub fn with_third_decoder_derivative(self, k: Arc<Array3<f64>>) -> Self

Attach a cached third decoder derivative K_n[i, a, c, d] = ∂²J_n[i, a] / ∂t_{n, c} ∂t_{n, d}, stored as an Array3 of shape (n_obs, p, d * d * d) to match third_decoder_derivative_slot. The Hessian-vector product uses the full residual-curvature term in addition to the metric Gauss-Newton piece.

Source

pub fn with_reference(self, reference: IsometryReference) -> Self

Source

pub fn with_jacobian_cache(self, j: Arc<Array2<f64>>) -> Self

Source

pub fn with_jacobian_second_cache(self, h: Arc<Array2<f64>>) -> Self

Source

pub fn with_duchon_radial_source( self, source: Arc<IsometryDuchonRadialSource>, ) -> Self

Attach radial Duchon decoder metadata so the exact ∂J/∂t tensor can be rebuilt from the current target coordinates. A doc-test oracle for this path is: build J(t) from duchon_radial_first_derivative_nd, evaluate grad_target(t), then central-difference value(t ± h e_j); the analytic component should agree to finite-difference tolerance as h is refined before cancellation dominates.

Source

pub fn with_row_metric(self, metric: &RowMetric) -> Self

Attach the gauge metric from the single RowMetric that also drives the reconstruction likelihood. This is the only way an IsometryPenalty acquires a non-identity behavioral metric: the independent WeightField setter has been removed so a gauge-metric ≠ likelihood-metric state is structurally unrepresentable. The contraction-order invariant (M_n = U_n^T J_n, never materializing the p × p W_n) is preserved by the WeightField::Factored layout the metric emits.

p_out is taken from the metric so the gauge’s output dimension is pinned to the metric’s.

Source

pub fn with_weight_schedule(self, schedule: ScalarWeightSchedule) -> Self

Attach a scalar weight schedule, seeding the current weight from the schedule’s stored iteration counter.

Source

pub fn pullback_metric(&self, latent_dim: usize) -> Option<Array2<f64>>

Per-row pullback metric g_n = J_n^T W_n J_n = M_n^T M_n with M_n = U_n^T J_n ∈ ℝ^{r_n × d}. Returns (n_obs, d, d) flattened row-major as (n_obs, d*d).

Cost per row: O(p · r · d) for the M_n build (single pass over U_n and J_n) plus O(r · d²) for M_n^T M_n. The p × p weight W_n is never materialized.

Source

pub fn grad_jacobian( &self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>, ) -> Array2<f64>

Exact closed-form gradient of the isometry penalty with respect to the cached decoder Jacobian J ∈ ℝ^{n_obs × p × d} (the autograd input that torch’s _IsometryPenaltyFn differentiates). Returns the flattened (n_obs, p*d) layout that matches the Jacobian cache.

Derivation (W-aware, reference-aware, weight-aware):

  P = ½ μ Σ_n ‖R_n‖²_F,   R_n = g_n / gbar − g^ref_n,   gbar = (1 / (N d)) Σ_n tr(g_n)
  A_n = ∂(P/μ)/∂g_n
  ∂g_{ab}/∂J_{i,c} = δ_{ca} (W J)_{i,b} + δ_{cb} (W J)_{i,a}    (W symmetric)
  ∂P/∂J_{i,c} = μ Σ_{a,b} A_{ab} ∂g_{ab}/∂J_{i,c}
              = 2 μ Σ_b A_{cb} (W J)_{i,b}
              = 2 μ ((W J) A)_{i,c}

where A includes the exact derivative of the shared gbar normalizer.
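One way to spell out A with the normalizer derivative, working only from the definitions in the derivation above: since ∂gbar/∂g_m = I / (N d), differentiating P/μ gives A_m = R_m / gbar − (S / (N d gbar²)) I with S = Σ_n ⟨R_n, g_n⟩_F. The sketch below checks that expression against finite differences. The function names and the `Vec<Vec<f64>>` layout (one row-major d × d matrix per row) are assumptions for illustration, not the crate's code.

```rust
/// Shared normalizer gbar = (1 / (N d)) * sum_n tr(g_n).
fn gbar(g: &[Vec<f64>], d: usize) -> f64 {
    let trace_sum: f64 = g
        .iter()
        .map(|gn| (0..d).map(|a| gn[a * d + a]).sum::<f64>())
        .sum();
    trace_sum / (g.len() * d) as f64
}

/// P / mu = 1/2 * sum_n || g_n / gbar - g_ref_n ||_F^2.
fn value_over_mu(g: &[Vec<f64>], g_ref: &[Vec<f64>], d: usize) -> f64 {
    let gb = gbar(g, d);
    let mut total = 0.0;
    for (gn, rn) in g.iter().zip(g_ref) {
        for (x, y) in gn.iter().zip(rn) {
            total += 0.5 * (x / gb - y).powi(2);
        }
    }
    total
}

/// A_m = R_m / gbar - (S / (N d gbar^2)) I,  S = sum_n <R_n, g_n>_F.
fn a_matrices(g: &[Vec<f64>], g_ref: &[Vec<f64>], d: usize) -> Vec<Vec<f64>> {
    let gb = gbar(g, d);
    let n = g.len();
    // S couples every row through the shared normalizer.
    let mut s = 0.0;
    for (gn, rn) in g.iter().zip(g_ref) {
        for (x, y) in gn.iter().zip(rn) {
            s += (x / gb - y) * x;
        }
    }
    let shift = s / ((n * d) as f64 * gb * gb);
    g.iter()
        .zip(g_ref)
        .map(|(gn, rn)| {
            let mut a: Vec<f64> = gn.iter().zip(rn).map(|(x, y)| (x / gb - y) / gb).collect();
            // The normalizer term only touches the diagonal, because gbar depends on traces.
            for k in 0..d {
                a[k * d + k] -= shift;
            }
            a
        })
        .collect()
}
```

Off-diagonal entries of A_m reduce to R_m / gbar, while the diagonal picks up the shared shift. Dropping that shift would correspond to treating gbar as a constant.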

Trait Implementations§

Source§

impl AnalyticPenalty for IsometryPenalty

Source§

fn hvp( &self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>, v: ArrayView1<'_, f64>, ) -> Array1<f64>

Fully analytic - wired through radial_basis_cartesian_derivative.

Source§

fn psd_majorizer_hvp( &self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>, v: ArrayView1<'_, f64>, ) -> Array1<f64>

PSD majorizer-vector product B_GN(target; ρ) v for the nonconvex isometry penalty.

The Gauss-Newton block differentiates the normalized residual R = g/gbar - G_ref itself and returns μ DRᵀ DR v. This is PSD by construction and includes the shared-normalizer derivative exactly; using only ∂g would reintroduce scale coupling and would not be the Gauss-Newton operator of the objective being minimized.
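A scalar toy case shows why the Gauss-Newton block is the safe choice when positive curvature is required. The sketch assumes a hypothetical decoder f(t) = t²/2 (so J = t and g = t²), W = I, a reference of 1, and no gbar normalization; the function names are illustrative only.

```rust
/// Toy residual r(t) = g(t) - 1 with g = t^2.
fn residual(t: f64) -> f64 {
    t * t - 1.0
}

/// r'(t) = 2t.
fn residual_d1(t: f64) -> f64 {
    2.0 * t
}

/// r''(t) = 2.
fn residual_d2(_t: f64) -> f64 {
    2.0
}

/// Exact Hessian of P = 1/2 * mu * r^2: mu * (r'^2 + r * r'').
/// The r * r'' piece is the residual-curvature term and can be negative.
fn exact_hessian(t: f64, mu: f64) -> f64 {
    mu * (residual_d1(t).powi(2) + residual(t) * residual_d2(t))
}

/// Gauss-Newton block: mu * r'^2, nonnegative by construction.
fn gauss_newton(t: f64, mu: f64) -> f64 {
    mu * residual_d1(t).powi(2)
}
```

Near t = 0.1 the exact curvature of this toy penalty is negative while the Gauss-Newton value stays nonnegative, which is the property a PIRLS or log-det pipeline relies on.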

Source§

fn tier(&self) -> PenaltyTier

Tier the target lives in (β or ext-coord).
Source§

fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64

Scalar penalty contribution P(target; ρ). The strength factor exp(ρ) (or whatever parameterization the penalty uses) is folded in.
Source§

fn grad_target( &self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>, ) -> Array1<f64>

Gradient ∂P/∂target, same length as target.
Source§

fn grad_rho( &self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>, ) -> Array1<f64>

Gradient of the penalty value w.r.t. each owned ρ-axis. Length equals Self::rho_count.
Source§

fn rho_count(&self) -> usize

Number of REML-selectable hyperparameter axes this penalty contributes to the outer ρ vector.
Source§

fn name(&self) -> &str

Human-readable identifier for diagnostics / logging.
Source§

fn apply_schedule(&mut self, iter: usize)

Update any attached scalar weight schedule at the given REML outer iteration. Penalties without schedules keep their stored weight.
Source§

fn hessian_diag( &self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>, ) -> Option<Array1<f64>>

Diagonal of the Hessian diag(∂²P/∂target²) when the Hessian is block-diagonal. Returns None for penalties whose Hessian is dense (Isometry); those implement Self::hvp instead. The default signals “no closed-form diagonal” by returning None for any non-empty target — concrete penalties either override with their own analytic diagonal or rely on the matrix-free hvp path.
Source§

fn psd_majorizer_diag( &self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>, ) -> Option<Array1<f64>>

Diagonal of a PSD majorizer of the Hessian — the positive re-weighted-ℓ₂ / MM surrogate diag(B(target; ρ)) with B ⪰ ∂²P/∂target² everywhere and B ⪰ 0. This is a different operator from Self::hessian_diag: for nonconvex penalties (log sparsity, JumpReLU) the exact Hessian is indefinite, but the inner Newton / PIRLS solve and the log-det / preconditioner pipeline require a PSD curvature block. For convex penalties the majorizer coincides with the exact Hessian, so the default simply delegates to Self::hessian_diag; nonconvex penalties override.
Source§

impl Clone for IsometryPenalty

Source§

fn clone(&self) -> Self

Returns a duplicate of the value. Read more
1.0.0 (const: unstable) · Source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
Source§

impl Debug for IsometryPenalty

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more
Source§

impl PenaltyManifest for IsometryPenalty

Source§

const KIND_TAG: &'static str = "isometry"

Source§

const PYTHON_WRAPPER: &'static str = "IsometryPenalty"

Source§

const ROW_BLOCK_DIAGONAL: bool = false

Source§

fn dispatch_tier(&self) -> PenaltyTier

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Allocation for T
where T: RefUnwindSafe + Send + Sync,

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> ByRef<T> for T

Source§

fn by_ref(&self) -> &T

Source§

impl<ST, DT> CastableFrom<ST, Initialized, Initialized> for DT
where ST: ?Sized, DT: ?Sized,

Source§

impl<ST, DT> CastableFrom<ST, Uninit, Uninit> for DT
where ST: ?Sized, DT: ?Sized,

Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest. Read more
Source§

impl<T> DistributionExt for T
where T: ?Sized,

Source§

fn rand<T>(&self, rng: &mut (impl Rng + ?Sized)) -> T
where Self: Distribution<T>,

Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Imply<T> for U
where T: ?Sized, U: ?Sized,

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T> Read<Exclusive, BecauseExclusive> for T
where T: ?Sized,

Source§

impl<T> Same for T

Source§

type Output = T

Should always be Self
Source§

impl<SS, SP> SupersetOf<SS> for SP
where SS: SubsetOf<SP>,

Source§

fn to_subset(&self) -> Option<SS>

The inverse inclusion map: attempts to construct self from the equivalent element of its superset. Read more
Source§

fn is_in_subset(&self) -> bool

Checks if self is actually part of its subset T (and can be converted to it).
Source§

fn to_subset_unchecked(&self) -> SS

Use with care! Same as self.to_subset but without any property checks. Always succeeds.
Source§

fn from_subset(element: &SS) -> SP

The inclusion map: converts self to the equivalent element of its superset.
Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V