pub struct DecoderIncoherencePenalty {
    pub target: PsiSlice,
    pub block_sizes: Vec<usize>,
    pub p_out: usize,
    pub k_atoms: usize,
    pub pairs: Vec<(usize, usize, f64)>,
    pub weight: f64,
    pub learnable_weight: bool,
    pub rho_index: usize,
    pub weight_schedule: Option<ScalarWeightSchedule>,
}
Cross-atom decoder column-space incoherence, restricted to co-activating atom pairs (issue #671).
Lives on the β tier and targets the flat SAE decoder coefficient block. The
β layout concatenates the per-atom decoder blocks in atom order: atom k
owns M_k · p_out coefficients, stored as
β[off_k + a·p_out + o] for basis row a and output feature o.
The stored block is B_k ∈ ℝ^{M_k × p_out} with rows B_k[a, :]
representing decoder directions in output space.
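The layout above can be sketched with two small helpers. This is a minimal illustration, not the crate's code: `block_offsets` and `flat_index` are hypothetical names, and the sketch assumes a zero-based target vector.

```rust
/// Start offset of each atom's block in the flat β vector, assuming atom k
/// owns `block_sizes[k] * p_out` contiguous coefficients (hypothetical helper).
fn block_offsets(block_sizes: &[usize], p_out: usize) -> Vec<usize> {
    let mut offsets = Vec::with_capacity(block_sizes.len());
    let mut off = 0;
    for &m in block_sizes {
        offsets.push(off);
        off += m * p_out;
    }
    offsets
}

/// Flat index of B_k[a, o], i.e. off_k + a * p_out + o (hypothetical helper).
fn flat_index(offsets: &[usize], p_out: usize, k: usize, a: usize, o: usize) -> usize {
    offsets[k] + a * p_out + o
}
```

For example, with `block_sizes = [2, 3]` and `p_out = 4`, atom 1 starts at offset 8 and `B_1[2, 3]` lives at flat index 19, the last of the 20 coefficients.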
The penalty is the co-activation-masked cross-column-space overlap
    P = ½ · w · Σ_{j<k} W[j,k] · ‖B_j B_k^T‖²_F,
    W[j,k] = ½ · (coactivation[j,k] + coactivation[k,j]).
coactivation[j,k] is the mean over observations of
gate[n,j] · gate[n,k]; pairs that never co-fire (W[j,k] = 0) contribute
nothing. In the SAE objective this is the separability lever: atoms that
are active on the same examples are discouraged from spanning the same
decoder output directions, while unrelated atoms are not pushed apart just
because they both exist in the dictionary.
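As a rough sketch of the formula above, the value can be evaluated directly from the flat β vector and a list of weighted pairs. The function names here are illustrative assumptions, blocks are taken to be row-major with `p_out > 0` columns, and `w` stands for the already-resolved strength.

```rust
/// Squared Frobenius norm of B_j B_kᵀ for row-major blocks with `p_out` columns.
fn cross_gram_sq_norm(bj: &[f64], bk: &[f64], p_out: usize) -> f64 {
    let mut total = 0.0;
    for row_j in bj.chunks(p_out) {
        for row_k in bk.chunks(p_out) {
            // C[a, b] = <B_j[a, :], B_k[b, :]>
            let c: f64 = row_j.iter().zip(row_k).map(|(x, y)| x * y).sum();
            total += c * c;
        }
    }
    total
}

/// P = ½ · w · Σ_pairs W[j,k] · ‖B_j B_kᵀ‖²_F over the listed (j, k, W[j,k]) pairs.
fn incoherence_value(
    beta: &[f64],
    block_sizes: &[usize],
    p_out: usize,
    pairs: &[(usize, usize, f64)],
    w: f64,
) -> f64 {
    // Contiguous per-atom offsets: atom k owns block_sizes[k] * p_out entries.
    let mut offsets = Vec::with_capacity(block_sizes.len());
    let mut off = 0;
    for &m in block_sizes {
        offsets.push(off);
        off += m * p_out;
    }
    let mut sum = 0.0;
    for &(j, k, w_jk) in pairs {
        let bj = &beta[offsets[j]..offsets[j] + block_sizes[j] * p_out];
        let bk = &beta[offsets[k]..offsets[k] + block_sizes[k] * p_out];
        sum += w_jk * cross_gram_sq_norm(bj, bk, p_out);
    }
    0.5 * w * sum
}
```

Two co-activating atoms with orthogonal decoder rows contribute zero, which is the separability behaviour described above.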
The Hessian used here is the Gauss-Newton (positive-semidefinite) curvature
of the Frobenius objective in the cross-Gram C = B_j B_k^T, dropping the
indefinite second-order (residual) term. This keeps the β-tier Newton / PIRLS
curvature block PSD, matching the other quadratic-on-Gram penalties.
Gotchas:
- block_sizes are decoder basis-row counts M_k, not output widths; every atom shares the same p_out. Stored decoder blocks are (M_k, p_out), so B_j B_k^T is the cross-Gram of decoder directions in output space and remains well-defined for heterogeneous M_k.
- The descriptor path builds a placeholder penalty; live SAE wiring replaces the co-activation matrix with the current mean gate products.
- Offsets are interpreted against the vector passed to this penalty. In the SAE decoder-incoherence path the registered target slice is zero-based; callers using an already sliced target view must keep that convention.
Fields
target: PsiSlice
block_sizes: Vec<usize>
Per-atom decoder basis-function counts M_k. The atom blocks are laid
out contiguously in β order; Σ_k M_k·p_out == target.len().
p_out: usize
Output / feature dimension p_out (decoder column count, shared by all
atoms).
k_atoms: usize
Atom count K. The operator only stores the SPARSE list of penalized
atom pairs (pairs), not the dense K×K co-activation matrix: at
K = 32768 that dense matrix is 8 GiB. Every consumer of this operator
already skipped pairs whose symmetrized weight is 0, so storing only
the nonzero pairs is exactly equivalent to the dense matrix while being
linear in the number of co-active / near-collinear pairs (#1026).
pairs: Vec<(usize, usize, f64)>
Sparse penalized atom pairs (j, k, w) with j < k and the symmetrized
weight w = ½·(W[j,k] + W[k,j]) > 0 (this is exactly the value the old
pair_weight(j, k) returned). Pairs with w == 0 are omitted; the dense
operator skipped them, so results are byte-identical.
weight: f64
Base strength. If learnable_weight is true the resolved strength is
weight·exp(rho[rho_index]); otherwise it is fixed at weight.
learnable_weight: bool
rho_index: usize
weight_schedule: Option<ScalarWeightSchedule>
Implementations
impl DecoderIncoherencePenalty
pub fn new( target: PsiSlice, block_sizes: Vec<usize>, p_out: usize, coactivation: Array2<f64>, weight: f64, learnable_weight: bool, ) -> Result<Self, String>
pub fn new_sparse(
    target: PsiSlice,
    block_sizes: Vec<usize>,
    p_out: usize,
    pairs: Vec<(usize, usize, f64)>,
    weight: f64,
    learnable_weight: bool,
) -> Result<Self, String>
Sparse-pair constructor used by the SAE live wiring (#1026): build the
operator directly from a list of penalized atom pairs (j, k, w) with
j < k and the symmetrized per-pair weight w (exactly the value the old
dense pair_weight(j, k) returned), avoiding any dense K×K allocation.
Pairs with w == 0 are dropped, and out-of-range indices are rejected. This is
equivalent to Self::new fed the dense symmetric matrix with the same
nonzero entries.
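The equivalence with the dense constructor can be illustrated by extracting the sparse list from a dense matrix. This is only a sketch: `sparse_pairs` is a hypothetical helper, the matrix is a plain nested `Vec` rather than an `Array2<f64>`, and the O(K²) scan is for illustration (the live wiring is described as avoiding any dense allocation).

```rust
/// Collect (j, k, w) for j < k with symmetrized weight
/// w = ½ · (coact[j][k] + coact[k][j]) > 0 (hypothetical helper).
fn sparse_pairs(coact: &[Vec<f64>]) -> Vec<(usize, usize, f64)> {
    let k_atoms = coact.len();
    let mut pairs = Vec::new();
    for j in 0..k_atoms {
        for k in (j + 1)..k_atoms {
            let w = 0.5 * (coact[j][k] + coact[k][j]);
            // Zero-weight pairs contribute nothing, so they are not stored.
            if w > 0.0 {
                pairs.push((j, k, w));
            }
        }
    }
    pairs
}
```

In a three-atom dictionary where only atoms 0 and 1 ever co-fire, this yields a single entry, so storage scales with the number of co-active pairs rather than with K².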
pub fn with_weight_schedule(self, schedule: ScalarWeightSchedule) -> Self
Attach a scalar weight schedule, seeding the current weight from the schedule’s stored iteration counter.
pub fn accumulate_psd_majorizer_dense(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
    scale: f64,
    hbb: &mut Array2<f64>,
)
Scatter the Gauss-Newton (PSD majorizer) curvature DIRECTLY into a dense
β × β block, accumulating scale · H_GN onto hbb.
This produces exactly the operator AnalyticPenalty::psd_majorizer_hvp
applies (the include_residual = false branch of Self::hvp_impl), but
assembled block-by-block over the penalized atom pairs instead of
reconstructed column-by-column from β unit-probe HVPs. Since H_GN is
pair-local — it couples only the (j, k) pairs in self.pairs, each within
their (M·p) decoder blocks — reading off the four output loops of
hvp_impl at a unit probe gives, per pair (j, k) with
w = w_sym · λ · scale and G_x = B_xᵀ B_x (the p × p decoder output
Gram of atom x):
- j-block diagonal: H[(j,a,o),(j,a,o')] += w · G_k[o,o']
- k-block diagonal: H[(k,b,o),(k,b,o')] += w · G_j[o,o']
- off-diagonal: H[(j,a,o₁),(k,b,o₂)] += w · B_j[a,o₂] · B_k[b,o₁], and its symmetric transpose into the (k, j) block.
Cost is O(Σ_pairs (M_j·M_k + M_j + M_k)·p²), versus the probe loop’s
O(β · Σ_pairs M_j·M_k·p): once β = K·M·p and the collinearity gate
admits O(K) co-active pairs, the probe loop spends O(K²) time
rebuilding a matrix this assembles in O(K) (#1026).
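The three block updates can be sketched for a single pair as follows. This is an illustrative stand-in, not the method's implementation: `add_pair_gn` is a hypothetical name, the dense matrix is a row-major `&mut [f64]` of size n × n instead of an `Array2<f64>`, and `w` is assumed to already include the pair weight, strength, and scale.

```rust
/// Add the Gauss-Newton blocks of one pair (j, k) into a dense row-major
/// n × n matrix `h`. `bj` and `bk` are row-major blocks with `p` columns.
fn add_pair_gn(
    h: &mut [f64], n: usize,
    bj: &[f64], bk: &[f64],
    off_j: usize, off_k: usize,
    p: usize, w: f64,
) {
    let (mj, mk) = (bj.len() / p, bk.len() / p);
    // p × p output Gram G_x = B_xᵀ B_x.
    let gram = |b: &[f64]| -> Vec<f64> {
        let mut g = vec![0.0; p * p];
        for row in b.chunks(p) {
            for o in 0..p {
                for o2 in 0..p {
                    g[o * p + o2] += row[o] * row[o2];
                }
            }
        }
        g
    };
    let (gj, gk) = (gram(bj), gram(bk));
    for o in 0..p {
        for o2 in 0..p {
            // j-block diagonal: same basis row a, output Gram of atom k.
            for a in 0..mj {
                h[(off_j + a * p + o) * n + off_j + a * p + o2] += w * gk[o * p + o2];
            }
            // k-block diagonal: same basis row b, output Gram of atom j.
            for b in 0..mk {
                h[(off_k + b * p + o) * n + off_k + b * p + o2] += w * gj[o * p + o2];
            }
            // Off-diagonal (j, k) block and its symmetric transpose.
            for a in 0..mj {
                for b in 0..mk {
                    let v = w * bj[a * p + o2] * bk[b * p + o];
                    let (r, c) = (off_j + a * p + o, off_k + b * p + o2);
                    h[r * n + c] += v;
                    h[c * n + r] += v;
                }
            }
        }
    }
}
```

For `B_j = [1, 0]`, `B_k = [1, 1]` and `w = 1`, the result is the outer product of `[1, 1, 1, 0]` with itself, which is the expected JᵀJ for the scalar cross-Gram C = ⟨B_j, B_k⟩.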
Trait Implementations
impl AnalyticPenalty for DecoderIncoherencePenalty
fn hvp(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
    v: ArrayView1<'_, f64>,
) -> Array1<f64>
Exact Hessian-vector product H v = (∂²P/∂target²) v.
P = ½ w Σ_{j<k} w_{jk} ‖C_{jk}‖²_F is biquadratic (quartic) in the
decoder blocks, so the second derivative of the nonlinear-least-squares
objective carries two pieces along a direction V (per pair, with
W = w·w_{jk}):
    (H v)_j[a,o] = W [ Σ_b dC[a,b]·B_k[b,o] + Σ_b C[a,b]·V_k[b,o] ]
namely the Gauss-Newton term Σ dC·B and the residual term Σ C·V, with
dC[a,b] = Σ_o (V_j[a,o]·B_k[b,o] + B_j[a,o]·V_k[b,o]) (and the symmetric
_k block). The residual term is what makes the exact Hessian indefinite;
the GN-only surrogate lives in Self::psd_majorizer_hvp.
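A minimal sketch of the j-block formula for one pair is given below. The names `x_yt` and `hvp_j_block` are hypothetical, blocks are row-major slices with `p` columns, and the `include_residual` flag is borrowed from the description of `hvp_impl` purely for illustration.

```rust
/// Row-major product X Yᵀ for X (mx × p) and Y (my × p), giving mx × my.
fn x_yt(x: &[f64], y: &[f64], p: usize) -> Vec<f64> {
    let mut out = Vec::new();
    for rx in x.chunks(p) {
        for ry in y.chunks(p) {
            out.push(rx.iter().zip(ry).map(|(a, b)| a * b).sum::<f64>());
        }
    }
    out
}

/// j-block of H v for one pair: W (dC B_k + C V_k) when `include_residual`
/// is true, and the Gauss-Newton part W dC B_k otherwise.
fn hvp_j_block(
    bj: &[f64], bk: &[f64],
    vj: &[f64], vk: &[f64],
    p: usize, w: f64,
    include_residual: bool,
) -> Vec<f64> {
    let (mj, mk) = (bj.len() / p, bk.len() / p);
    let c = x_yt(bj, bk, p);
    // dC = V_j B_kᵀ + B_j V_kᵀ
    let dc: Vec<f64> = x_yt(vj, bk, p)
        .iter()
        .zip(x_yt(bj, vk, p))
        .map(|(a, b)| a + b)
        .collect();
    let mut out = vec![0.0; bj.len()];
    for a in 0..mj {
        for b in 0..mk {
            for o in 0..p {
                let mut t = dc[a * mk + b] * bk[b * p + o];
                if include_residual {
                    t += c[a * mk + b] * vk[b * p + o];
                }
                out[a * p + o] += w * t;
            }
        }
    }
    out
}
```

With `B_j = [1, 0]`, `B_k = [1, 1]`, `V_j = 0` and `V_k = [1, 0]`, the Gauss-Newton part is `[1, 1]` and the residual adds `C · V_k = [1, 0]`, so the exact product is `[2, 1]`.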
fn psd_majorizer_hvp(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
    v: ArrayView1<'_, f64>,
) -> Array1<f64>
PSD majorizer-vector product B_GN(target; ρ) v for the nonconvex
decoder-incoherence penalty.
Dropping the indefinite residual term W·Σ C·V from the exact
Self::hvp leaves the Gauss-Newton block W·Jᵀ(J v) with
J = ∂vec(C)/∂vec(B). That block is PSD by construction — a sum of
W ≥ 0 (weight > 0, coactivation ≥ 0) times rank-structured Gram
products JᵀJ — and coincides with the exact Hessian as the cross-Gram
C → 0. The inner Newton / PIRLS curvature block must stay
positive-semidefinite, so the GN block is the correct operator here, mirroring
the other nonconvex penalties (sparsity, JumpReLU, isometry) that override
the majorizer rather than hand back the indefinite true Hessian.
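The PSD claim can be checked from the quadratic form. For a single pair with weight W ≥ 0 and a direction v = (V_j, V_k), using the dC defined for Self::hvp:

```latex
v^\top H_{\mathrm{GN}}\, v
  \;=\; W\,(Jv)^\top (Jv)
  \;=\; W\,\lVert dC \rVert_F^2 \;\ge\; 0,
\qquad
dC \;=\; V_j B_k^\top + B_j V_k^\top .
```

Summing over pairs preserves non-negativity, so the assembled Gauss-Newton operator is PSD whenever every pair weight is non-negative.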
fn tier(&self) -> PenaltyTier
fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64
P(target; ρ). The strength factor
exp(ρ) (or whatever parameterization the penalty uses) is folded in.
fn grad_target(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
) -> Array1<f64>
∂P/∂target, same length as target.
fn grad_rho(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
) -> Array1<f64>
∂P/∂ρ, with length Self::rho_count.
fn rho_count(&self) -> usize
fn apply_schedule(&mut self, iter: usize)
fn hessian_diag(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
) -> Option<Array1<f64>>
diag(∂²P/∂target²) when the Hessian is
block-diagonal. Returns None for penalties whose Hessian is dense
(Isometry); those implement Self::hvp instead. The default
signals “no closed-form diagonal” by returning None for any
non-empty target; concrete penalties either override with their
own analytic diagonal or rely on the matrix-free hvp path.
fn psd_majorizer_diag(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
) -> Option<Array1<f64>>
diag(B(target; ρ)) with
B ⪰ ∂²P/∂target² everywhere and B ⪰ 0. This is a different
operator from Self::hessian_diag: for nonconvex penalties (log
sparsity, JumpReLU) the exact Hessian is indefinite, but the inner
Newton / PIRLS solve and the log-det / preconditioner pipeline require
a PSD curvature block. For convex penalties the majorizer coincides
with the exact Hessian, so the default simply delegates to
Self::hessian_diag; nonconvex penalties override.
impl Clone for DecoderIncoherencePenalty
fn clone(&self) -> DecoderIncoherencePenalty
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.
impl Debug for DecoderIncoherencePenalty
impl PenaltyManifest for DecoderIncoherencePenalty
const KIND_TAG: &'static str = "decoder_incoherence"
const PYTHON_WRAPPER: &'static str = "DecoderIncoherencePenalty"
const ROW_BLOCK_DIAGONAL: bool = false
fn dispatch_tier(&self) -> PenaltyTier
Auto Trait Implementations
impl Freeze for DecoderIncoherencePenalty
impl RefUnwindSafe for DecoderIncoherencePenalty
impl Send for DecoderIncoherencePenalty
impl Sync for DecoderIncoherencePenalty
impl Unpin for DecoderIncoherencePenalty
impl UnsafeUnpin for DecoderIncoherencePenalty
impl UnwindSafe for DecoderIncoherencePenalty
Blanket Implementations
impl<T> Allocation for T
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<ST, DT> CastableFrom<ST, Initialized, Initialized> for DT
impl<ST, DT> CastableFrom<ST, Uninit, Uninit> for DT
impl<T> CloneToUninit for T
where
    T: Clone,
impl<T> DistributionExt for T
where
    T: ?Sized,
impl<T, U> Imply<T> for U
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise.
impl<T> Pointable for T
impl<T> Read<Exclusive, BecauseExclusive> for T
where
    T: ?Sized,
impl<SS, SP> SupersetOf<SS> for SP
where
    SS: SubsetOf<SP>,
fn to_subset(&self) -> Option<SS>
The inverse inclusion map: attempts to construct self from the equivalent element of its
superset.
fn is_in_subset(&self) -> bool
Checks if self is actually part of its subset T (and can be converted to it).
fn to_subset_unchecked(&self) -> SS
Use with care! Same as self.to_subset but without any property checks. Always succeeds.
fn from_subset(element: &SS) -> SP
The inclusion map: converts self to the equivalent element of its superset.