pub struct SoftmaxAssignmentSparsityPenalty {
pub k_atoms: usize,
pub temperature: f64,
pub weight: f64,
pub weight_schedule: Option<ScalarWeightSchedule>,
}
Entropy sparsity over row-wise softmax assignment logits.
This is the SAE-manifold soft-assignment penalty. The target is a flat
row-major (N, K) logit matrix. Assignments are
a_i = softmax(logits_i / temperature), and the penalty is
lambda_sparse * sum_i H(a_i)
H(a_i) = -sum_k a_ik log a_ik
Minimizing entropy drives each row toward a small active support while the
softmax keeps a_ik >= 0 and sum_k a_ik = 1. The exact Hessian is dense
in each row and can be indefinite because entropy is concave in assignment
space, so callers must use the HVP rather than a diagonal Hessian shortcut.
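As a rough illustration of the penalty defined above, the following sketch evaluates lambda_sparse * sum_i H(a_i) on a flat row-major (N, K) buffer. The names softmax_row and entropy_penalty, and the use of plain slices in place of ArrayView1, are assumptions made for this example and are not part of the crate.

```rust
/// Softmax of one logit row at temperature `tau` (max-shifted for stability).
fn softmax_row(logits: &[f64], tau: f64) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|z| ((z - max) / tau).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

/// Hypothetical helper: lambda * sum_i H(a_i) over a flat row-major (N, K) buffer.
fn entropy_penalty(target: &[f64], k_atoms: usize, tau: f64, lambda: f64) -> f64 {
    let total_entropy: f64 = target
        .chunks(k_atoms)
        .map(|row| {
            let a = softmax_row(row, tau);
            // H(a) = -sum_k a_k ln a_k, with 0 ln 0 treated as 0.
            -a.iter().filter(|&&p| p > 0.0).map(|&p| p * p.ln()).sum::<f64>()
        })
        .sum();
    lambda * total_entropy
}
```

A uniform row of K logits contributes ln K, while a row dominated by a single logit contributes almost nothing, which is the sparsity pressure the description refers to.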
Fields§
§k_atoms: usize
§temperature: f64
§weight: f64
§weight_schedule: Option<ScalarWeightSchedule>
Implementations§
impl SoftmaxAssignmentSparsityPenalty
pub fn new(k_atoms: usize, temperature: f64) -> Self
pub fn with_weight_schedule(self, schedule: ScalarWeightSchedule) -> Self
Attach a scalar weight schedule, seeding the current weight from the schedule’s stored iteration counter.
pub fn psd_majorizer_abs_row_sums(&self, row: &[f64], scale: f64) -> Vec<f64>
Absolute row sums of the exact per-row dense entropy Hessian, used as a Gershgorin / diagonal-dominance PSD majorizer.
The exact per-row Hessian wrt logits (symmetric, dense) is
H_kj = (λ/τ²)·a_k·[ δ_kj·(m − L_k − 1) + a_j·(L_k + L_j + 1 − 2m) ],
L_k = ln a_k + 1, m = Σ_j a_j L_j,
whose diagonal coincides with AnalyticPenalty::hessian_diag. Entropy
is concave in assignment space, so this block is indefinite (negative on
near-uniform rows). Setting D_kk = Σ_j |H_kj| makes D − H symmetric
with nonnegative diagonal and diagonally dominant
(D_kk − H_kk = |H_kk| − H_kk + Σ_{j≠k}|H_kj| ≥ Σ_{j≠k}|(D−H)_kj|),
hence PSD: D ⪰ H and D ⪰ 0 both hold. D is a genuine PSD diagonal
operator that dominates the dense Hessian’s quadratic form — unlike the
raw indefinite diagonal, which is neither PSD nor a faithful stand-in for
the dense operator.
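The construction above can be sketched with plain vectors. This is only an illustration under stated assumptions: dense_entropy_hessian and abs_row_sums are hypothetical free functions, the temperature is passed explicitly rather than read from self, and Vec<Vec<f64>> stands in for Array2<f64>.

```rust
/// Softmax of one logit row at temperature `tau` (max-shifted for stability).
fn softmax_row(logits: &[f64], tau: f64) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|z| ((z - max) / tau).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

/// Hypothetical helper: the dense block H_kj from the formula above.
fn dense_entropy_hessian(row: &[f64], tau: f64, scale: f64) -> Vec<Vec<f64>> {
    let a = softmax_row(row, tau);
    let n = a.len();
    let l: Vec<f64> = a.iter().map(|p| p.ln() + 1.0).collect();
    let m: f64 = (0..n).map(|r| a[r] * l[r]).sum();
    let mut h = vec![vec![0.0; n]; n];
    for k in 0..n {
        for j in 0..n {
            let delta = if k == j { 1.0 } else { 0.0 };
            h[k][j] = scale
                * a[k]
                * (delta * (m - l[k] - 1.0) + a[j] * (l[k] + l[j] + 1.0 - 2.0 * m));
        }
    }
    h
}

/// Hypothetical helper: D_kk = sum_j |H_kj| for every row k.
fn abs_row_sums(h: &[Vec<f64>]) -> Vec<f64> {
    h.iter()
        .map(|row| row.iter().map(|x| x.abs()).sum::<f64>())
        .collect()
}
```

Each D_kk is nonnegative by construction, and D_kk − H_kk is at least the sum of the off-diagonal magnitudes in row k, which is the diagonal-dominance condition the text relies on.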
pub fn row_dense_hessian(&self, row_logits: &[f64], scale: f64) -> Array2<f64>
Exact per-row dense softmax-entropy Hessian wrt the row’s logits (#1038),
scaled by scale = λ/τ². Returns the symmetric K×K block
H_kj = scale·a_k·[ δ_kj·(m − L_k − 1) + a_j·(L_k + L_j + 1 − 2m) ],
L_k = ln a_k + 1, m = Σ_r a_r L_r,
whose diagonal coincides with AnalyticPenalty::hessian_diag and whose
quadratic form coincides with AnalyticPenalty::hvp. This is the dense
block the Arrow-Schur row factor stores so the criterion’s log|H| and
the #1006 θ-adjoint differentiate the SAME operator (not just its
diagonal). The entropy block alone is gauge-null (H·𝟙 = 0, softmax
shift-invariance); callers must add it to the gauge-breaking data-fit
row block before factoring — never factor it in isolation.
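The gauge-null property and the agreement with a matrix-free product can be checked on a small row. The sketch below is an assumption-laden illustration: row_stats, dense_entropy_hessian and entropy_hvp are hypothetical free functions, and the closed-form product is obtained by summing the H_kj formula above against v; it is not the crate's implementation.

```rust
/// Returns (a, L, m) for one row: a = softmax(z / tau), L_k = ln a_k + 1, m = sum_k a_k L_k.
fn row_stats(row: &[f64], tau: f64) -> (Vec<f64>, Vec<f64>, f64) {
    let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = row.iter().map(|z| ((z - max) / tau).exp()).collect();
    let total: f64 = exps.iter().sum();
    let a: Vec<f64> = exps.iter().map(|e| e / total).collect();
    let l: Vec<f64> = a.iter().map(|p| p.ln() + 1.0).collect();
    let m: f64 = (0..a.len()).map(|r| a[r] * l[r]).sum();
    (a, l, m)
}

/// Hypothetical helper: dense K x K block H_kj.
fn dense_entropy_hessian(row: &[f64], tau: f64, scale: f64) -> Vec<Vec<f64>> {
    let (a, l, m) = row_stats(row, tau);
    let n = a.len();
    let mut h = vec![vec![0.0; n]; n];
    for k in 0..n {
        for j in 0..n {
            let delta = if k == j { 1.0 } else { 0.0 };
            h[k][j] = scale
                * a[k]
                * (delta * (m - l[k] - 1.0) + a[j] * (l[k] + l[j] + 1.0 - 2.0 * m));
        }
    }
    h
}

/// Hypothetical helper: H v without forming H, from
/// (Hv)_k = scale * a_k * [ (m - L_k - 1) v_k + (L_k + 1 - 2m) (a.v) + sum_j a_j L_j v_j ].
fn entropy_hvp(row: &[f64], tau: f64, scale: f64, v: &[f64]) -> Vec<f64> {
    let (a, l, m) = row_stats(row, tau);
    let s0: f64 = a.iter().zip(v).map(|(p, x)| p * x).sum();
    let s1: f64 = (0..a.len()).map(|j| a[j] * l[j] * v[j]).sum();
    (0..a.len())
        .map(|k| scale * a[k] * ((m - l[k] - 1.0) * v[k] + (l[k] + 1.0 - 2.0 * m) * s0 + s1))
        .collect()
}
```

Every row of the dense block sums to zero (H·𝟙 = 0), which is why the block cannot be factored on its own.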
pub fn row_dense_hessian_logit_derivative(
    &self,
    row_logits: &[f64],
    scale: f64,
    w: usize,
) -> Array2<f64>
Derivative of the exact per-row dense entropy Hessian
Self::row_dense_hessian with respect to a single row logit z_w,
scaled by scale = λ/τ². Returns the symmetric K×K block
∂H_kj/∂z_w, the third-derivative tensor slice the #1006 θ-adjoint
contracts against the row’s selected inverse. Built from the SAME
(a, L, m) as Self::row_dense_hessian (∂a_r/∂z_w = a_r(δ_rw − a_w)/τ),
so value, logdet and adjoint stay on one branch.
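One way to read the description above is as a product rule applied to the H_kj formula with ∂a_r/∂z_w = a_r(δ_rw − a_w)/τ, ∂L_r/∂z_w = (δ_rw − a_w)/τ and ∂m/∂z_w = a_w(L_w − m)/τ. The sketch below follows that reading. All function names are hypothetical, the temperature is passed explicitly, and the result is meant to be cross-checked against finite differences rather than taken as the crate's code.

```rust
fn delta(i: usize, j: usize) -> f64 {
    if i == j { 1.0 } else { 0.0 }
}

fn softmax_row(logits: &[f64], tau: f64) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|z| ((z - max) / tau).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

/// Hypothetical helper: dense K x K block H_kj.
fn dense_entropy_hessian(row: &[f64], tau: f64, scale: f64) -> Vec<Vec<f64>> {
    let a = softmax_row(row, tau);
    let n = a.len();
    let l: Vec<f64> = a.iter().map(|p| p.ln() + 1.0).collect();
    let m: f64 = (0..n).map(|r| a[r] * l[r]).sum();
    let mut h = vec![vec![0.0; n]; n];
    for k in 0..n {
        for j in 0..n {
            let e = l[k] + l[j] + 1.0 - 2.0 * m;
            h[k][j] = scale * a[k] * (delta(k, j) * (m - l[k] - 1.0) + a[j] * e);
        }
    }
    h
}

/// Hypothetical helper: dH_kj / dz_w by the product rule (scale held constant).
fn dense_entropy_hessian_dz(row: &[f64], tau: f64, scale: f64, w: usize) -> Vec<Vec<f64>> {
    let a = softmax_row(row, tau);
    let n = a.len();
    let l: Vec<f64> = a.iter().map(|p| p.ln() + 1.0).collect();
    let m: f64 = (0..n).map(|r| a[r] * l[r]).sum();
    // dL_r = (delta_rw - a_w) / tau, da_r = a_r * dL_r, dm = a_w (L_w - m) / tau.
    let dl: Vec<f64> = (0..n).map(|r| (delta(r, w) - a[w]) / tau).collect();
    let da: Vec<f64> = (0..n).map(|r| a[r] * dl[r]).collect();
    let dm = a[w] * (l[w] - m) / tau;
    let mut dh = vec![vec![0.0; n]; n];
    for k in 0..n {
        for j in 0..n {
            let e = l[k] + l[j] + 1.0 - 2.0 * m;
            let inner = delta(k, j) * (m - l[k] - 1.0) + a[j] * e;
            let d_inner =
                delta(k, j) * (dm - dl[k]) + da[j] * e + a[j] * (dl[k] + dl[j] - 2.0 * dm);
            dh[k][j] = scale * (da[k] * inner + a[k] * d_inner);
        }
    }
    dh
}
```

A central finite difference of dense_entropy_hessian in the coordinate z_w is a cheap sanity check for any such closed form.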
pub fn row_psd_majorizer(&self, row_logits: &[f64], scale: f64) -> Array2<f64>
Per-row Gershgorin diagonal majorizer D of the exact softmax-entropy
Hessian Self::row_dense_hessian, scaled by scale = λ/τ². Returns the
K×K diagonal block diag(D_0, …, D_{K−1}) with
D_kk = Σ_j |H_kj| (#1419).
Unlike the Fisher metric Self::row_fisher_metric — which is PSD but
does NOT satisfy G ⪰ H_entropy (counterexample a=(0.95,0.05),
λ=τ=1: G₁₁=0.0475 < H₁₁=0.0784) — this D is a genuine Loewner
majorizer: it is diagonally dominant over H (D_kk − H_kk = |H_kk|−H_kk + Σ_{j≠k}|H_kj| ≥ Σ_{j≠k}|(D−H)_kj|), so D − H ⪰ 0, and
every D_kk ≥ 0, so D ⪰ 0. It therefore both keeps the assembled
evidence block PD (the property the entropy block needs so the
Faddeev–Popov deflation never fires) AND actually majorizes the entropy
curvature, which the Fisher surrogate did not. The criterion’s log|H|,
its θ-adjoint Self::row_psd_majorizer_logit_derivative, and the
assembled Hessian all differentiate this SAME operator D, keeping value
and adjoint on one exact branch.
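The Loewner claims D ⪰ H and D ⪰ 0 can be probed numerically through quadratic forms. The following is a minimal sketch, assuming hypothetical free functions (dense_entropy_hessian, gershgorin_majorizer, quad_form) and nested Vec storage in place of Array2<f64>. It spot-checks the inequality on a handful of vectors and is not a proof.

```rust
fn softmax_row(logits: &[f64], tau: f64) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|z| ((z - max) / tau).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

/// Hypothetical helper: dense K x K entropy Hessian block H_kj.
fn dense_entropy_hessian(row: &[f64], tau: f64, scale: f64) -> Vec<Vec<f64>> {
    let a = softmax_row(row, tau);
    let n = a.len();
    let l: Vec<f64> = a.iter().map(|p| p.ln() + 1.0).collect();
    let m: f64 = (0..n).map(|r| a[r] * l[r]).sum();
    let mut h = vec![vec![0.0; n]; n];
    for k in 0..n {
        for j in 0..n {
            let delta = if k == j { 1.0 } else { 0.0 };
            h[k][j] = scale
                * a[k]
                * (delta * (m - l[k] - 1.0) + a[j] * (l[k] + l[j] + 1.0 - 2.0 * m));
        }
    }
    h
}

/// Hypothetical helper: diag(D_0, ..., D_{K-1}) with D_kk = sum_j |H_kj|.
fn gershgorin_majorizer(h: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n = h.len();
    let mut d = vec![vec![0.0; n]; n];
    for k in 0..n {
        d[k][k] = h[k].iter().map(|x| x.abs()).sum::<f64>();
    }
    d
}

/// Quadratic form x^T M x.
fn quad_form(mat: &[Vec<f64>], x: &[f64]) -> f64 {
    mat.iter()
        .zip(x)
        .map(|(row, xi)| xi * row.iter().zip(x).map(|(mij, xj)| mij * xj).sum::<f64>())
        .sum()
}
```

For any probe vector x, quad_form(D, x) − quad_form(H, x) should be nonnegative up to rounding, and quad_form(D, x) itself is a weighted sum of squares.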
pub fn row_psd_majorizer_logit_derivative(
    &self,
    row_logits: &[f64],
    scale: f64,
    w: usize,
) -> Array2<f64>
Derivative of the per-row Gershgorin majorizer Self::row_psd_majorizer
with respect to a single row logit z_w, scaled by scale = λ/τ².
Returns the K×K diagonal block diag(∂D_0/∂z_w, …) with
∂D_kk/∂z_w = Σ_j sign(H_kj)·(∂H_kj/∂z_w) (#1419), where H is the exact
entropy Hessian Self::row_dense_hessian and ∂H_kj/∂z_w is
Self::row_dense_hessian_logit_derivative. sign(0)=0 (a zero entry
contributes no first-order change to its own magnitude). Built from the
SAME (a, L, m) derivative convention as the dense Hessian derivative, so
the θ-adjoint differentiates the SAME D the assembly added.
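The sign-weighted sum above is easy to get subtly wrong in Rust, because f64::signum returns 1.0 for +0.0 rather than 0. The sketch below assumes H and ∂H/∂z_w are already available as dense blocks. The names sign0 and majorizer_diag_derivative are hypothetical, and it returns only the diagonal entries instead of a K×K block.

```rust
/// Sign with the convention sign(0) = 0 (unlike `f64::signum`, which maps +0.0 to 1.0).
fn sign0(x: f64) -> f64 {
    if x > 0.0 {
        1.0
    } else if x < 0.0 {
        -1.0
    } else {
        0.0
    }
}

/// Hypothetical helper: dD_kk/dz_w = sum_j sign(H_kj) * dH_kj/dz_w for every row k.
fn majorizer_diag_derivative(h: &[Vec<f64>], dh: &[Vec<f64>]) -> Vec<f64> {
    h.iter()
        .zip(dh)
        .map(|(h_row, dh_row)| {
            h_row
                .iter()
                .zip(dh_row)
                .map(|(&x, &dx)| sign0(x) * dx)
                .sum::<f64>()
        })
        .collect()
}
```

With this convention an exactly zero entry of H contributes nothing to the derivative of its row sum, matching the rule stated in the text.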
pub fn row_fisher_metric(&self, row_logits: &[f64], scale: f64) -> Array2<f64>
Per-row softmax Fisher-information metric G = scale·(diag(a) − a aᵀ)
over the row’s logits, with a = softmax(row_logits) and
scale = λ/τ² (#1190). Returns the symmetric K×K block
G_kj = scale·a_k·(δ_kj − a_j).
G is a covariance/Gram matrix, hence exactly PSD and smooth in the
logits. It is the Fisher-information metric of the row softmax, NOT a
curvature majorizer of the entropy Hessian: G − H_entropy can fail to be
PSD (#1419: K=2, a=(0.95,0.05), λ=τ=1 gives
G₁₁=0.0475 < H₁₁=0.0784, so G ⋡ H). The genuine Loewner majorizer the
assembled evidence block now uses is Self::row_psd_majorizer
(D_kk = Σ_j|H_kj|, which DOES satisfy D ⪰ H and D ⪰ 0). This Fisher
metric and its derivative Self::row_fisher_metric_logit_derivative are
retained only as a smooth PSD conditioning reference and must not be
presented or used as a curvature majorizer.
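The K=2 counterexample quoted above can be reproduced directly. This is a sketch under stated assumptions: fisher_metric and dense_entropy_hessian are hypothetical free functions taking the temperature explicitly, and the logits ln 0.95 and ln 0.05 are chosen so that the softmax at τ=1 returns a=(0.95, 0.05).

```rust
fn softmax_row(logits: &[f64], tau: f64) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|z| ((z - max) / tau).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

/// Hypothetical helper: G = scale * (diag(a) - a a^T).
fn fisher_metric(row: &[f64], tau: f64, scale: f64) -> Vec<Vec<f64>> {
    let a = softmax_row(row, tau);
    let n = a.len();
    let mut g = vec![vec![0.0; n]; n];
    for k in 0..n {
        for j in 0..n {
            let delta = if k == j { 1.0 } else { 0.0 };
            g[k][j] = scale * a[k] * (delta - a[j]);
        }
    }
    g
}

/// Hypothetical helper: dense K x K entropy Hessian block H_kj.
fn dense_entropy_hessian(row: &[f64], tau: f64, scale: f64) -> Vec<Vec<f64>> {
    let a = softmax_row(row, tau);
    let n = a.len();
    let l: Vec<f64> = a.iter().map(|p| p.ln() + 1.0).collect();
    let m: f64 = (0..n).map(|r| a[r] * l[r]).sum();
    let mut h = vec![vec![0.0; n]; n];
    for k in 0..n {
        for j in 0..n {
            let delta = if k == j { 1.0 } else { 0.0 };
            h[k][j] = scale
                * a[k]
                * (delta * (m - l[k] - 1.0) + a[j] * (l[k] + l[j] + 1.0 - 2.0 * m));
        }
    }
    h
}
```

For this row both blocks are multiples of [[1, −1], [−1, 1]], so G₁₁ < H₁₁ already implies that xᵀ(G − H)x is negative for x = (1, −1).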
pub fn row_fisher_metric_logit_derivative(
    &self,
    row_logits: &[f64],
    scale: f64,
    w: usize,
) -> Array2<f64>
Derivative of the per-row softmax Fisher metric
Self::row_fisher_metric with respect to a single row logit z_w,
scaled by scale = λ/τ² (#1190). Returns the symmetric K×K block
∂G_kj/∂z_w, the third-derivative tensor slice the θ-adjoint contracts
against the row’s selected inverse so the adjoint differentiates the SAME
PSD G = scale·(diag(a) − a aᵀ) the assembly added (value/adjoint on one
branch, no deflation needed). Built from the SAME softmax derivative
convention as Self::row_dense_hessian_logit_derivative
(∂a_r/∂z_w = a_r(δ_rw − a_w)/τ). For G_kj = scale·a_k(δ_kj − a_j),
the product rule gives
∂G_kj/∂z_w = scale·[ (∂a_k/∂z_w)(δ_kj − a_j) − a_k(∂a_j/∂z_w) ].
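The product-rule formula above translates almost line for line into code. The sketch below is illustrative only: fisher_metric and fisher_metric_dz are hypothetical free functions, the temperature is passed explicitly, and nested Vec storage stands in for Array2<f64>.

```rust
fn delta(i: usize, j: usize) -> f64 {
    if i == j { 1.0 } else { 0.0 }
}

fn softmax_row(logits: &[f64], tau: f64) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|z| ((z - max) / tau).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

/// Hypothetical helper: G_kj = scale * a_k * (delta_kj - a_j).
fn fisher_metric(row: &[f64], tau: f64, scale: f64) -> Vec<Vec<f64>> {
    let a = softmax_row(row, tau);
    let n = a.len();
    (0..n)
        .map(|k| (0..n).map(|j| scale * a[k] * (delta(k, j) - a[j])).collect())
        .collect()
}

/// Hypothetical helper: dG_kj/dz_w = scale * [ da_k (delta_kj - a_j) - a_k da_j ],
/// with da_r = a_r (delta_rw - a_w) / tau.
fn fisher_metric_dz(row: &[f64], tau: f64, scale: f64, w: usize) -> Vec<Vec<f64>> {
    let a = softmax_row(row, tau);
    let n = a.len();
    let da: Vec<f64> = (0..n).map(|r| a[r] * (delta(r, w) - a[w]) / tau).collect();
    (0..n)
        .map(|k| {
            (0..n)
                .map(|j| scale * (da[k] * (delta(k, j) - a[j]) - a[k] * da[j]))
                .collect()
        })
        .collect()
}
```

As with the dense Hessian derivative, a central finite difference of fisher_metric in z_w gives an independent check of the closed form.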
Trait Implementations§
impl AnalyticPenalty for SoftmaxAssignmentSparsityPenalty
fn tier(&self) -> PenaltyTier
fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64
P(target; ρ). The strength factor
exp(ρ) (or whatever parameterization the penalty uses) is folded in.
fn grad_target(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
) -> Array1<f64>
∂P/∂target, same length as target.
fn hessian_diag(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
) -> Option<Array1<f64>>
diag(∂²P/∂target²) when the Hessian is
block-diagonal. Returns None for penalties whose Hessian is dense
(Isometry); those implement Self::hvp instead. The default
signals “no closed-form diagonal” by returning None for any
non-empty target; concrete penalties either override with their
own analytic diagonal or rely on the matrix-free hvp path.
fn hvp(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
    v: ArrayView1<'_, f64>,
) -> Array1<f64>
H v = (∂²P/∂target²) v, in closed form.
fn psd_majorizer_diag(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
) -> Option<Array1<f64>>
diag(B(target; ρ)) with
B ⪰ ∂²P/∂target² everywhere and B ⪰ 0. This is a different
operator from Self::hessian_diag: for nonconvex penalties (log
sparsity, JumpReLU) the exact Hessian is indefinite, but the inner
Newton / PIRLS solve and the log-det / preconditioner pipeline require
a PSD curvature block. For convex penalties the majorizer coincides
with the exact Hessian, so the default simply delegates to
Self::hessian_diag; nonconvex penalties override.
fn grad_rho(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
) -> Array1<f64>
Self::rho_count.
fn rho_count(&self) -> usize
fn apply_schedule(&mut self, iter: usize)
fn psd_majorizer_hvp(
    &self,
    target: ArrayView1<'_, f64>,
    rho: ArrayView1<'_, f64>,
    v: ArrayView1<'_, f64>,
) -> Array1<f64>
B(target; ρ) v
(see Self::psd_majorizer_diag). For convex penalties this is the
exact Hessian-vector product, so the default delegates to
Self::hvp; nonconvex penalties override to return their PSD
surrogate instead of the indefinite true Hessian.
impl Clone for SoftmaxAssignmentSparsityPenalty
fn clone(&self) -> SoftmaxAssignmentSparsityPenalty
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.
impl PenaltyManifest for SoftmaxAssignmentSparsityPenalty
const KIND_TAG: &'static str = "softmax_assignment_sparsity"
const PYTHON_WRAPPER: &'static str = "SoftmaxAssignmentSparsityPenalty"
const ROW_BLOCK_DIAGONAL: bool = true
fn dispatch_tier(&self) -> PenaltyTier
Auto Trait Implementations§
impl Freeze for SoftmaxAssignmentSparsityPenalty
impl RefUnwindSafe for SoftmaxAssignmentSparsityPenalty
impl Send for SoftmaxAssignmentSparsityPenalty
impl Sync for SoftmaxAssignmentSparsityPenalty
impl Unpin for SoftmaxAssignmentSparsityPenalty
impl UnsafeUnpin for SoftmaxAssignmentSparsityPenalty
impl UnwindSafe for SoftmaxAssignmentSparsityPenalty
Blanket Implementations§
impl<T> Allocation for T
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<ST, DT> CastableFrom<ST, Initialized, Initialized> for DT
impl<ST, DT> CastableFrom<ST, Uninit, Uninit> for DT
impl<T> CloneToUninit for T
where
    T: Clone,
impl<T> DistributionExt for T
where
    T: ?Sized,
impl<T, U> Imply<T> for U
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise.
impl<T> Pointable for T
impl<T> Read<Exclusive, BecauseExclusive> for T
where
    T: ?Sized,
impl<SS, SP> SupersetOf<SS> for SP
where
    SS: SubsetOf<SP>,
fn to_subset(&self) -> Option<SS>
The inverse inclusion map: attempts to construct self from the equivalent element of its
superset.
fn is_in_subset(&self) -> bool
Checks if self is actually part of its subset T (and can be converted to it).
fn to_subset_unchecked(&self) -> SS
Use with care! Same as self.to_subset but without any property checks. Always succeeds.
fn from_subset(element: &SS) -> SP
The inclusion map: converts self to the equivalent element of its superset.