Struct DecoderIncoherencePenalty 

pub struct DecoderIncoherencePenalty {
    pub target: PsiSlice,
    pub block_sizes: Vec<usize>,
    pub p_out: usize,
    pub k_atoms: usize,
    pub pairs: Vec<(usize, usize, f64)>,
    pub weight: f64,
    pub learnable_weight: bool,
    pub rho_index: usize,
    pub weight_schedule: Option<ScalarWeightSchedule>,
}

Cross-atom decoder column-space incoherence, restricted to co-activating atom pairs (issue #671).

Lives on the β tier and targets the flat SAE decoder coefficient block. The β layout concatenates the per-atom decoder blocks in atom order: atom k owns M_k · p_out coefficients, stored row-major as β[off_k + a·p_out + o] for basis row a and output feature o, where off_k = Σ_{i<k} M_i · p_out is the start of atom k's block. The stored block is B_k ∈ ℝ^{M_k × p_out}, and each row B_k[a, :] is a decoder direction in output space.
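
As a rough illustration of this layout, the sketch below computes the per-atom start offsets and the flat index of B_k[a, o] using plain slices. The helper names block_offsets and flat_index are hypothetical and not part of the crate; only the indexing rule off_k + a·p_out + o is taken from the description above.

```rust
/// Start offset of each atom's block in the flat coefficient vector,
/// assuming atom k owns block_sizes[k] * p_out contiguous entries.
/// (Hypothetical helper, not the crate's API.)
fn block_offsets(block_sizes: &[usize], p_out: usize) -> Vec<usize> {
    let mut offsets = Vec::with_capacity(block_sizes.len());
    let mut running = 0usize;
    for &m_k in block_sizes {
        offsets.push(running);
        running += m_k * p_out;
    }
    offsets
}

/// Flat index of B_k[a, o] under the row-major per-atom layout.
/// (Hypothetical helper, not the crate's API.)
fn flat_index(offsets: &[usize], p_out: usize, k: usize, a: usize, o: usize) -> usize {
    offsets[k] + a * p_out + o
}
```

For example, with block_sizes = [2, 3, 1] and p_out = 4 the offsets are [0, 8, 20], and B_1[2, 3] (the last coefficient of atom 1) sits at flat index 19.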

The penalty is the co-activation-masked cross-column-space overlap

  P = ½ · w · Σ_{j<k} W[j,k] · ‖B_j B_k^T‖²_F,
  W[j,k] = ½ · (coactivation[j,k] + coactivation[k,j]).

coactivation[j,k] is the mean over observations of gate[n,j] · gate[n,k]; pairs that never co-fire (W[j,k] = 0) contribute nothing. In the SAE objective this is the separability lever: atoms that are active on the same examples are discouraged from spanning the same decoder output directions, while unrelated atoms are not pushed apart just because they both exist in the dictionary.
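
The formula can be sketched on plain slices as follows. This is a minimal illustration under stated assumptions, not the crate's implementation: incoherence_value and cross_gram_sq_norm are hypothetical names, the pair list is assumed to already hold the symmetrized weights W[j,k] for j < k, and weight stands for the resolved strength w.

```rust
/// Squared Frobenius norm of C = B_j B_k^T for row-major blocks
/// b_j (M_j x p_out) and b_k (M_k x p_out). Hypothetical helper.
fn cross_gram_sq_norm(b_j: &[f64], b_k: &[f64], p_out: usize) -> f64 {
    let mut total = 0.0;
    for row_j in b_j.chunks_exact(p_out) {
        for row_k in b_k.chunks_exact(p_out) {
            // C[a, b] is the inner product of two decoder directions.
            let c_ab: f64 = row_j.iter().zip(row_k).map(|(x, y)| x * y).sum();
            total += c_ab * c_ab;
        }
    }
    total
}

/// P = 0.5 * weight * sum over pairs of W[j,k] * ||B_j B_k^T||_F^2.
/// `pairs` holds (j, k, W[j,k]) with j < k. Hypothetical helper.
fn incoherence_value(
    beta: &[f64],
    block_sizes: &[usize],
    p_out: usize,
    pairs: &[(usize, usize, f64)],
    weight: f64,
) -> f64 {
    // Start offset of each atom's block in the flat vector.
    let mut offsets = Vec::with_capacity(block_sizes.len());
    let mut running = 0usize;
    for &m in block_sizes {
        offsets.push(running);
        running += m * p_out;
    }
    let mut total = 0.0;
    for &(j, k, w_jk) in pairs {
        let b_j = &beta[offsets[j]..offsets[j] + block_sizes[j] * p_out];
        let b_k = &beta[offsets[k]..offsets[k] + block_sizes[k] * p_out];
        total += w_jk * cross_gram_sq_norm(b_j, b_k, p_out);
    }
    0.5 * weight * total
}
```

Two atoms with orthogonal decoder rows give a value of zero regardless of how often they co-fire, which is the behavior the penalty rewards.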

The Hessian used here is the Gauss-Newton (positive-semidefinite) curvature of the Frobenius objective in the cross-Gram C = B_j B_k^T: it keeps the JᵀJ term and drops the indefinite residual term, which is the part of the exact Hessian that involves the second derivative of C. This keeps the β-tier Newton / PIRLS curvature block PSD, matching the other quadratic-on-Gram penalties.

Gotchas:

  • block_sizes are decoder basis-row counts M_k, not output widths; every atom shares the same p_out. Stored decoder blocks are (M_k, p_out), so B_j B_k^T is the cross-Gram of decoder directions in output space and remains well-defined for heterogeneous M_k.
  • The descriptor path builds a placeholder penalty; live SAE wiring replaces the co-activation matrix with the current mean gate products.
  • Offsets are interpreted against the vector passed to this penalty. In the SAE decoder-incoherence path the registered target slice is zero-based; callers using an already sliced target view must keep that convention.

Fields§

§target: PsiSlice

§block_sizes: Vec<usize>

Per-atom decoder basis-function counts M_k. The atom blocks are laid out contiguously in β order; Σ_k M_k·p_out == target.len().

§p_out: usize

Output / feature dimension p_out (decoder column count, shared by all atoms).

§k_atoms: usize

Atom count K. The operator only stores the SPARSE list of penalized atom pairs (pairs), not the dense K×K co-activation matrix — at K = 32768 that dense matrix is 8 GiB. Every consumer of this operator already skipped pairs whose symmetrized weight is 0, so storing only the nonzero pairs is exactly equivalent to the dense matrix while being linear in the number of co-active / near-collinear pairs (#1026).
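
The 8 GiB figure follows from K² entries of 8 bytes each. The short sketch below reproduces that arithmetic and compares it with a sparse pair list; both function names are hypothetical, and the per-pair size is whatever the target platform assigns to a (usize, usize, f64) tuple (24 bytes on a typical 64-bit target).

```rust
/// Bytes needed for a dense K x K matrix of f64 (hypothetical helper).
fn dense_coactivation_bytes(k_atoms: u64) -> u64 {
    k_atoms * k_atoms * std::mem::size_of::<f64>() as u64
}

/// Bytes needed for the payload of a Vec<(usize, usize, f64)> holding
/// n_pairs entries, ignoring the Vec header (hypothetical helper).
fn sparse_pairs_bytes(n_pairs: u64) -> u64 {
    n_pairs * std::mem::size_of::<(usize, usize, f64)>() as u64
}
```

At K = 32768 the dense matrix takes 32768 · 32768 · 8 = 8 · 1024³ bytes, while a pair list with on the order of K entries stays well under a mebibyte.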

§pairs: Vec<(usize, usize, f64)>

Sparse penalized atom pairs (j, k, w) with j < k and the symmetrized weight w = ½·(coactivation[j,k] + coactivation[k,j]) > 0, i.e. W[j,k] in the penalty formula (this is exactly the value the old pair_weight(j, k) returned). Pairs with w == 0 are omitted; the dense operator skipped them, so results are byte-identical.
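
A minimal sketch of how such a pair list could be derived from a dense co-activation matrix is shown below. It uses a Vec<Vec<f64>> stand-in rather than the Array2<f64> the constructor takes, and sparse_pairs_from_dense is a hypothetical name, not a function of this crate.

```rust
/// Build (j, k, w) triples with j < k and w > 0 from a dense K x K
/// co-activation matrix, symmetrizing as w = 0.5 * (c[j][k] + c[k][j]).
/// Hypothetical helper; the matrix is assumed square.
fn sparse_pairs_from_dense(coactivation: &[Vec<f64>]) -> Vec<(usize, usize, f64)> {
    let k_atoms = coactivation.len();
    let mut pairs = Vec::new();
    for j in 0..k_atoms {
        for k in (j + 1)..k_atoms {
            let w = 0.5 * (coactivation[j][k] + coactivation[k][j]);
            // Pairs that never co-fire carry no penalty and are not stored.
            if w > 0.0 {
                pairs.push((j, k, w));
            }
        }
    }
    pairs
}
```

Note that this loop is itself O(K²); it only illustrates the equivalence between the two representations, whereas the sparse constructor exists precisely so that the dense matrix never has to be formed.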

§weight: f64

Base strength. If learnable_weight is true the resolved strength is weight·exp(rho[rho_index]); otherwise it is fixed at weight.
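
The rule above amounts to the following one-liner. The free function resolved_strength is a hypothetical illustration of the stated formula, not a method of this type.

```rust
/// Resolved penalty strength: weight * exp(rho[rho_index]) when the
/// weight is learnable, otherwise the fixed base weight.
/// Hypothetical helper mirroring the field documentation.
fn resolved_strength(weight: f64, learnable_weight: bool, rho: &[f64], rho_index: usize) -> f64 {
    if learnable_weight {
        weight * rho[rho_index].exp()
    } else {
        weight
    }
}
```

With rho[rho_index] = 0 the learnable and fixed cases coincide, so weight acts as the strength at the origin of the log-scale parameter.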

§learnable_weight: bool

§rho_index: usize

§weight_schedule: Option<ScalarWeightSchedule>

Implementations§

impl DecoderIncoherencePenalty

pub fn new( target: PsiSlice, block_sizes: Vec<usize>, p_out: usize, coactivation: Array2<f64>, weight: f64, learnable_weight: bool, ) -> Result<Self, String>

pub fn new_sparse( target: PsiSlice, block_sizes: Vec<usize>, p_out: usize, pairs: Vec<(usize, usize, f64)>, weight: f64, learnable_weight: bool, ) -> Result<Self, String>

Sparse-pair constructor used by the SAE live wiring (#1026): build the operator directly from a list of penalized atom pairs (j, k, w) with j < k and the symmetrized per-pair weight w (exactly the value the old dense pair_weight(j, k) returned), avoiding any dense K×K allocation. w == 0 pairs and out-of-range indices are dropped / rejected. This is equivalent to Self::new fed the dense symmetric matrix with the same nonzero entries.

pub fn with_weight_schedule(self, schedule: ScalarWeightSchedule) -> Self

Attach a scalar weight schedule, seeding the current weight from the schedule’s stored iteration counter.

pub fn accumulate_psd_majorizer_dense( &self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>, scale: f64, hbb: &mut Array2<f64>, )

Scatter the Gauss-Newton (PSD majorizer) curvature DIRECTLY into a dense β × β block, accumulating scale · H_GN onto hbb.

This produces exactly the operator AnalyticPenalty::psd_majorizer_hvp applies (the include_residual = false branch of Self::hvp_impl), but assembled block-by-block over the penalized atom pairs instead of reconstructed column-by-column from β unit-probe HVPs. Since H_GN is pair-local — it couples only the (j, k) pairs in self.pairs, each within their (M·p) decoder blocks — reading off the four output loops of hvp_impl at a unit probe gives, per pair (j, k) with w = w_sym · λ · scale and G_x = B_xᵀ B_x (the p × p decoder output Gram of atom x):

  • j-block diagonal H[(j,a,o),(j,a,o')] += w · G_k[o,o']
  • k-block diagonal H[(k,b,o),(k,b,o')] += w · G_j[o,o']
  • off-diagonal H[(j,a,o₁),(k,b,o₂)] += w · B_j[a,o₂] · B_k[b,o₁] and its symmetric transpose into the (k, j) block.

Cost is O(Σ_pairs (M_j·M_k + M_j + M_k)·p²), versus the probe loop’s O(β · Σ_pairs M_j·M_k·p): once β = K·M·p and the collinearity gate admits O(K) co-active pairs, the probe loop spends O(K²) time rebuilding a matrix this assembles in O(K) (#1026).
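
The block entries listed above can be read off from the Gauss-Newton part of the Hessian-vector product. The following derivation is a sketch for the j-block only, writing w for the combined factor w_sym · λ · scale and using the probe direction (V_j, V_k) as in the hvp documentation:

```latex
\begin{aligned}
(H_{GN} v)_j[a,o]
  &= w \sum_b dC[a,b]\, B_k[b,o],
  \qquad dC[a,b] = \sum_{o'} \bigl( V_j[a,o']\, B_k[b,o'] + B_j[a,o']\, V_k[b,o'] \bigr) \\
  &= w \sum_{o'} V_j[a,o'] \underbrace{\sum_b B_k[b,o']\, B_k[b,o]}_{G_k[o',o]}
   \;+\; w \sum_{b,o'} V_k[b,o']\, B_j[a,o']\, B_k[b,o].
\end{aligned}
```

The coefficient of V_j[a,o'] is w · G_k[o,o'] (G_k is symmetric), which is the j-block diagonal entry, and the coefficient of V_k[b,o'] is w · B_j[a,o'] · B_k[b,o], which is the off-diagonal entry with o₁ = o and o₂ = o'. The k-block entries follow by exchanging the roles of j and k.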

Trait Implementations§

impl AnalyticPenalty for DecoderIncoherencePenalty

fn hvp( &self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>, v: ArrayView1<'_, f64>, ) -> Array1<f64>

Exact Hessian-vector product H v = (∂²P/∂target²) v.

P = ½ w Σ_{j<k} w_{jk} ‖C_{jk}‖²_F is biquadratic (quartic) in the decoder blocks, so the second derivative of the nonlinear-least-squares objective carries two pieces along a direction V (per pair, with W = w·w_{jk}):

  (H v)_j[a,o] = W [ Σ_b dC[a,b]·B_k[b,o]   +   Σ_b C[a,b]·V_k[b,o] ]

the Gauss-Newton term Σ dC·B and the residual term Σ C·V, with dC[a,b] = Σ_o (V_j[a,o]·B_k[b,o] + B_j[a,o]·V_k[b,o]) (and the symmetric _k block). The residual term is what makes the exact Hessian indefinite; the GN-only surrogate lives in Self::psd_majorizer_hvp.
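
To make the two pieces concrete, here is a minimal single-pair sketch on row-major slices. It is an illustration under stated assumptions, not the crate's hvp_impl: pair_hvp, mul_abt, add_c_x and add_ct_x are hypothetical names, w stands for the combined factor W, and the k-block is taken to be the transpose-symmetric counterpart W [ dCᵀ B_j + Cᵀ V_j ].

```rust
/// Row-major A * B^T for A (ra x p) and B (rb x p); the result is ra x rb.
fn mul_abt(a: &[f64], b: &[f64], p: usize) -> Vec<f64> {
    let mut out = Vec::new();
    for row_a in a.chunks_exact(p) {
        for row_b in b.chunks_exact(p) {
            out.push(row_a.iter().zip(row_b).map(|(x, y)| x * y).sum::<f64>());
        }
    }
    out
}

/// out[a, o] += scale * sum_b c[a, b] * x[b, o], with c (ra x rb), x (rb x p).
fn add_c_x(out: &mut [f64], c: &[f64], x: &[f64], rb: usize, p: usize, scale: f64) {
    for (a, out_row) in out.chunks_exact_mut(p).enumerate() {
        for b in 0..rb {
            let c_ab = c[a * rb + b];
            for o in 0..p {
                out_row[o] += scale * c_ab * x[b * p + o];
            }
        }
    }
}

/// out[b, o] += scale * sum_a c[a, b] * x[a, o], i.e. c is applied transposed.
fn add_ct_x(out: &mut [f64], c: &[f64], x: &[f64], rb: usize, p: usize, scale: f64) {
    for (a, x_row) in x.chunks_exact(p).enumerate() {
        for b in 0..rb {
            let c_ab = c[a * rb + b];
            for o in 0..p {
                out[b * p + o] += scale * c_ab * x_row[o];
            }
        }
    }
}

/// Hessian-vector product for one pair (j, k). With include_residual = false
/// only the Gauss-Newton term is kept. Returns ((H v)_j, (H v)_k).
fn pair_hvp(
    b_j: &[f64], b_k: &[f64], v_j: &[f64], v_k: &[f64],
    p: usize, w: f64, include_residual: bool,
) -> (Vec<f64>, Vec<f64>) {
    let m_k = b_k.len() / p;
    let c = mul_abt(b_j, b_k, p); // C = B_j B_k^T
    // dC = V_j B_k^T + B_j V_k^T
    let dc: Vec<f64> = mul_abt(v_j, b_k, p)
        .iter()
        .zip(mul_abt(b_j, v_k, p))
        .map(|(x, y)| x + y)
        .collect();
    let mut h_j = vec![0.0; b_j.len()];
    let mut h_k = vec![0.0; b_k.len()];
    add_c_x(&mut h_j, &dc, b_k, m_k, p, w); // Gauss-Newton: dC * B_k
    add_ct_x(&mut h_k, &dc, b_j, m_k, p, w); // Gauss-Newton: dC^T * B_j
    if include_residual {
        add_c_x(&mut h_j, &c, v_k, m_k, p, w); // residual: C * V_k
        add_ct_x(&mut h_k, &c, v_j, m_k, p, w); // residual: C^T * V_j
    }
    (h_j, h_k)
}
```

With B_j = B_k = [1, 0] and the direction V_j = [0, 1], V_k = [0, -1], the Gauss-Newton product is zero while the exact product gives a quadratic form of -2, a small concrete instance of the indefiniteness mentioned above.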

fn psd_majorizer_hvp( &self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>, v: ArrayView1<'_, f64>, ) -> Array1<f64>

PSD majorizer-vector product B_GN(target; ρ) v for the nonconvex decoder-incoherence penalty.

Dropping the indefinite residual term W·Σ C·V from the exact Self::hvp leaves the Gauss-Newton block W·Jᵀ(J v) with J = ∂vec(C)/∂vec(B). That block is PSD by construction — a sum of W ≥ 0 (weight > 0, coactivation ≥ 0) times rank-structured Gram products JᵀJ — and coincides with the exact Hessian as the cross-Gram C → 0. The inner Newton / PIRLS curvature block must stay positive-definite, so the GN block is the correct operator here, mirroring the other nonconvex penalties (sparsity, JumpReLU, isometry) that override the majorizer rather than hand back the indefinite true Hessian.
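
The PSD property can be checked directly on one pair. The short derivation below is a sketch that takes the k-block of the Gauss-Newton term to be (H_GN v)_k[b,o] = W Σ_a dC[a,b]·B_j[a,o], the transpose-symmetric counterpart of the j-block given in the hvp documentation:

```latex
\begin{aligned}
v^\top H_{GN}\, v
  &= W \sum_{a,o} V_j[a,o] \sum_b dC[a,b]\, B_k[b,o]
   + W \sum_{b,o} V_k[b,o] \sum_a dC[a,b]\, B_j[a,o] \\
  &= W \sum_{a,b} dC[a,b] \sum_o \bigl( V_j[a,o]\, B_k[b,o] + B_j[a,o]\, V_k[b,o] \bigr) \\
  &= W \sum_{a,b} dC[a,b]^2 \;=\; W\, \lVert dC \rVert_F^2 \;\ge\; 0 .
\end{aligned}
```

The quadratic form is therefore a nonnegative multiple of a squared norm for every direction v whenever W ≥ 0, and summing over pairs preserves that sign.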

fn tier(&self) -> PenaltyTier

Tier the target lives in (β or ext-coord).

fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64

Scalar penalty contribution P(target; ρ). The strength factor exp(ρ) (or whatever parameterization the penalty uses) is folded in.

fn grad_target( &self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>, ) -> Array1<f64>

Gradient ∂P/∂target, same length as target.

fn grad_rho( &self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>, ) -> Array1<f64>

Gradient of the penalty value w.r.t. each owned ρ-axis. Length equals Self::rho_count.

fn rho_count(&self) -> usize

Number of REML-selectable hyperparameter axes this penalty contributes to the outer ρ vector.

fn name(&self) -> &str

Human-readable identifier for diagnostics / logging.

fn apply_schedule(&mut self, iter: usize)

Update any attached scalar weight schedule at the given REML outer iteration. Penalties without schedules keep their stored weight.

fn hessian_diag( &self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>, ) -> Option<Array1<f64>>

Diagonal of the Hessian diag(∂²P/∂target²) when the Hessian is block-diagonal. Returns None for penalties whose Hessian is dense (Isometry); those implement Self::hvp instead. The default signals “no closed-form diagonal” by returning None for any non-empty target — concrete penalties either override with their own analytic diagonal or rely on the matrix-free hvp path.

fn psd_majorizer_diag( &self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>, ) -> Option<Array1<f64>>

Diagonal of a PSD majorizer of the Hessian — the positive re-weighted-ℓ₂ / MM surrogate diag(B(target; ρ)) with B ⪰ ∂²P/∂target² everywhere and B ⪰ 0. This is a different operator from Self::hessian_diag: for nonconvex penalties (log sparsity, JumpReLU) the exact Hessian is indefinite, but the inner Newton / PIRLS solve and the log-det / preconditioner pipeline require a PSD curvature block. For convex penalties the majorizer coincides with the exact Hessian, so the default simply delegates to Self::hessian_diag; nonconvex penalties override.

impl Clone for DecoderIncoherencePenalty

fn clone(&self) -> DecoderIncoherencePenalty

Returns a duplicate of the value.

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.

impl Debug for DecoderIncoherencePenalty

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

impl PenaltyManifest for DecoderIncoherencePenalty

const KIND_TAG: &'static str = "decoder_incoherence"

const PYTHON_WRAPPER: &'static str = "DecoderIncoherencePenalty"

const ROW_BLOCK_DIAGONAL: bool = false

fn dispatch_tier(&self) -> PenaltyTier

Blanket Implementations§

impl<T> Allocation for T
where T: RefUnwindSafe + Send + Sync,

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> ByRef<T> for T

fn by_ref(&self) -> &T

impl<ST, DT> CastableFrom<ST, Initialized, Initialized> for DT
where ST: ?Sized, DT: ?Sized,

impl<ST, DT> CastableFrom<ST, Uninit, Uninit> for DT
where ST: ?Sized, DT: ?Sized,

impl<T> CloneToUninit for T
where T: Clone,

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest.

impl<T> DistributionExt for T
where T: ?Sized,

fn rand<T>(&self, rng: &mut (impl Rng + ?Sized)) -> T
where Self: Distribution<T>,

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Imply<T> for U
where T: ?Sized, U: ?Sized,

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

impl<T> Pointable for T

const ALIGN: usize

The alignment of pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

impl<T> Read<Exclusive, BecauseExclusive> for T
where T: ?Sized,

impl<T> Same for T

type Output = T

Should always be Self

impl<SS, SP> SupersetOf<SS> for SP
where SS: SubsetOf<SP>,

fn to_subset(&self) -> Option<SS>

The inverse inclusion map: attempts to construct self from the equivalent element of its superset.

fn is_in_subset(&self) -> bool

Checks if self is actually part of its subset T (and can be converted to it).

fn to_subset_unchecked(&self) -> SS

Use with care! Same as self.to_subset but without any property checks. Always succeeds.

fn from_subset(element: &SS) -> SP

The inclusion map: converts self to the equivalent element of its superset.

impl<T> ToOwned for T
where T: Clone,

type Owned = T

The resulting type after obtaining ownership.

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning.

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V