Struct SoftmaxAssignmentSparsityPenalty 

pub struct SoftmaxAssignmentSparsityPenalty {
    pub k_atoms: usize,
    pub temperature: f64,
    pub weight: f64,
    pub weight_schedule: Option<ScalarWeightSchedule>,
}

Entropy sparsity over row-wise softmax assignment logits.

This is the SAE-manifold soft-assignment penalty. The target is a flat row-major (N, K) logit matrix. Assignments are a_i = softmax(logits_i / temperature), and the penalty is

  lambda_sparse * sum_i H(a_i)
  H(a_i) = -sum_k a_ik log a_ik

Minimizing entropy drives each row toward a small active support while the softmax keeps a_ik >= 0 and sum_k a_ik = 1. The exact Hessian is dense in each row and can be indefinite because entropy is concave in assignment space, so callers must use the HVP rather than a diagonal Hessian shortcut.
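A minimal sketch of the penalty value, assuming `lambda_sparse` corresponds to the struct's `weight` field and that the logits arrive as a plain row-major `&[f64]`. The helper names (`softmax_row`, `entropy`, `entropy_penalty`) are hypothetical and this is not the crate's implementation. Subtracting each row's maximum before exponentiating keeps the softmax finite without changing it, because softmax is shift-invariant.

```rust
/// Row-wise softmax with temperature: a_k = exp(z_k / tau) / sum_j exp(z_j / tau).
/// Subtracting the row max first avoids overflow; softmax is shift-invariant,
/// so the result is unchanged.
fn softmax_row(logits: &[f64], tau: f64) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&z| ((z - max) / tau).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Shannon entropy H(a) = -sum_k a_k ln a_k, using the convention 0 ln 0 = 0.
fn entropy(a: &[f64]) -> f64 {
    a.iter().filter(|&&p| p > 0.0).map(|&p| -p * p.ln()).sum()
}

/// lambda * sum_i H(softmax(logits_i / tau)) for a flat row-major (N, K) matrix.
/// Hypothetical helper: `lambda` stands in for `lambda_sparse` (assumed to be `weight`).
fn entropy_penalty(logits: &[f64], k: usize, tau: f64, lambda: f64) -> f64 {
    assert!(k > 0 && logits.len() % k == 0, "logits must be a flat (N, K) matrix");
    let total: f64 = logits
        .chunks(k)
        .map(|row| entropy(&softmax_row(row, tau)))
        .sum();
    lambda * total
}
```

For example, two all-zero rows with K = 4 are uniform, each with entropy ln 4, so the penalty is λ·2 ln 4; a sharply peaked row contributes almost nothing.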

Fields

k_atoms: usize
temperature: f64
weight: f64
weight_schedule: Option<ScalarWeightSchedule>

Implementations

impl SoftmaxAssignmentSparsityPenalty

pub fn new(k_atoms: usize, temperature: f64) -> Self

pub fn with_weight_schedule(self, schedule: ScalarWeightSchedule) -> Self

Attach a scalar weight schedule, seeding the current weight from the schedule’s stored iteration counter.

pub fn psd_majorizer_abs_row_sums(&self, row: &[f64], scale: f64) -> Vec<f64>

Absolute row sums of the exact per-row dense entropy Hessian, used as a Gershgorin / diagonal-dominance PSD majorizer.

The exact per-row Hessian wrt logits (symmetric, dense) is

  H_kj = (λ/τ²)·a_k·[ δ_kj·(m − L_k − 1) + a_j·(L_k + L_j + 1 − 2m) ],
  L_k = ln a_k + 1,   m = Σ_j a_j L_j,

whose diagonal coincides with AnalyticPenalty::hessian_diag. Entropy is concave in assignment space, so this block need not be PSD (on near-uniform rows it is negative semidefinite). Setting D_kk = Σ_j |H_kj| makes D − H symmetric, with a nonnegative diagonal, and diagonally dominant: D_kk − H_kk = |H_kk| − H_kk + Σ_{j≠k}|H_kj| ≥ Σ_{j≠k}|(D−H)_kj|. By Gershgorin's theorem D − H is therefore PSD, so D ⪰ H; and D ⪰ 0 because every D_kk ≥ 0. D is thus a genuine PSD diagonal operator that dominates the dense Hessian's quadratic form, unlike the raw indefinite diagonal, which is neither PSD nor a faithful stand-in for the dense operator.
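The majorizer argument can be checked numerically. The sketch below uses hypothetical helper names and takes the row's softmax probabilities `a` directly (rather than logits, as `psd_majorizer_abs_row_sums` does). It builds H from the formula above, forms D_kk = Σ_j |H_kj|, and evaluates xᵀ(D − H)x, which should never be negative.

```rust
/// Dense entropy Hessian of one row w.r.t. its logits, from the formula above:
/// H_kj = scale * a_k * [ d_kj (m - L_k - 1) + a_j (L_k + L_j + 1 - 2m) ],
/// with L_k = ln a_k + 1 and m = sum_j a_j L_j. Assumes every a_k > 0.
fn dense_hessian(a: &[f64], scale: f64) -> Vec<Vec<f64>> {
    let l: Vec<f64> = a.iter().map(|&p| p.ln() + 1.0).collect();
    let m: f64 = a.iter().zip(&l).map(|(p, lk)| p * lk).sum();
    let k = a.len();
    let mut h = vec![vec![0.0; k]; k];
    for r in 0..k {
        for c in 0..k {
            let diag = if r == c { m - l[r] - 1.0 } else { 0.0 };
            h[r][c] = scale * a[r] * (diag + a[c] * (l[r] + l[c] + 1.0 - 2.0 * m));
        }
    }
    h
}

/// Gershgorin / diagonal-dominance majorizer: D_kk = sum_j |H_kj|.
fn abs_row_sums(h: &[Vec<f64>]) -> Vec<f64> {
    h.iter().map(|row| row.iter().map(|x| x.abs()).sum::<f64>()).collect()
}

/// Quadratic form x^T (D - H) x; nonnegative for every x when D majorizes H.
fn gap_quadratic_form(d: &[f64], h: &[Vec<f64>], x: &[f64]) -> f64 {
    let mut q = 0.0;
    for r in 0..x.len() {
        q += d[r] * x[r] * x[r];
        for c in 0..x.len() {
            q -= x[r] * h[r][c] * x[c];
        }
    }
    q
}
```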

pub fn row_dense_hessian(&self, row_logits: &[f64], scale: f64) -> Array2<f64>

Exact per-row dense softmax-entropy Hessian wrt the row’s logits (#1038), scaled by scale = λ/τ². Returns the symmetric K×K block

  H_kj = scale·a_k·[ δ_kj·(m − L_k − 1) + a_j·(L_k + L_j + 1 − 2m) ],
  L_k = ln a_k + 1,   m = Σ_r a_r L_r,

whose diagonal coincides with AnalyticPenalty::hessian_diag and whose quadratic form coincides with AnalyticPenalty::hvp. This is the dense block the Arrow-Schur row factor stores so the criterion’s log|H| and the #1006 θ-adjoint differentiate the SAME operator (not just its diagonal). The entropy block alone is gauge-null (H·𝟙 = 0, softmax shift-invariance); callers must add it to the gauge-breaking data-fit row block before factoring — never factor it in isolation.
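A self-contained sketch of the dense block with a finite-difference check against the gradient. Differentiating λ·H(a) through the softmax (∂H/∂a_k = −L_k, ∂a_k/∂z_j = a_k(δ_kj − a_j)/τ) gives g_j = (λ/τ)·a_j·(m − L_j); differentiating once more gives the block above with scale = λ/τ². The function names and the `(tau, lambda)` parameterization are illustrative assumptions, not the crate's API.

```rust
/// Row-wise softmax with temperature (max-shifted for stability).
fn softmax_row(z: &[f64], tau: f64) -> Vec<f64> {
    let max = z.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = z.iter().map(|&x| ((x - max) / tau).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Per-row quantities: a, L_k = ln a_k + 1, and m = sum_k a_k L_k.
fn row_stats(z: &[f64], tau: f64) -> (Vec<f64>, Vec<f64>, f64) {
    let a = softmax_row(z, tau);
    let l: Vec<f64> = a.iter().map(|&p| p.ln() + 1.0).collect();
    let m: f64 = a.iter().zip(&l).map(|(p, lk)| p * lk).sum();
    (a, l, m)
}

/// Gradient of lambda * H(softmax(z / tau)): g_j = (lambda / tau) a_j (m - L_j).
fn row_gradient(z: &[f64], tau: f64, lambda: f64) -> Vec<f64> {
    let (a, l, m) = row_stats(z, tau);
    (0..a.len()).map(|j| lambda / tau * a[j] * (m - l[j])).collect()
}

/// Dense K x K Hessian of the same function, scale = lambda / tau^2.
fn row_dense_hessian(z: &[f64], tau: f64, lambda: f64) -> Vec<Vec<f64>> {
    let (a, l, m) = row_stats(z, tau);
    let scale = lambda / (tau * tau);
    let k = a.len();
    let mut h = vec![vec![0.0; k]; k];
    for r in 0..k {
        for c in 0..k {
            let diag = if r == c { m - l[r] - 1.0 } else { 0.0 };
            h[r][c] = scale * a[r] * (diag + a[c] * (l[r] + l[c] + 1.0 - 2.0 * m));
        }
    }
    h
}
```

The test also exercises the gauge-null property H·𝟙 = 0: every row of the block sums to zero, which is why it must be combined with a gauge-breaking block before factoring.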

pub fn row_dense_hessian_logit_derivative(&self, row_logits: &[f64], scale: f64, w: usize) -> Array2<f64>

Derivative of the exact per-row dense entropy Hessian Self::row_dense_hessian with respect to a single row logit z_w, scaled by scale = λ/τ². Returns the symmetric K×K block ∂H_kj/∂z_w, the third-derivative tensor slice the #1006 θ-adjoint contracts against the row’s selected inverse. Built from the SAME (a, L, m) as Self::row_dense_hessian (∂a_r/∂z_w = a_r(δ_rw − a_w)/τ), so value, logdet and adjoint stay on one branch.
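An analytic third-derivative slice is easy to get wrong, so a central-difference reference is a useful cross-check. This sketch (hypothetical names; `scale` held fixed, as in the signature above) differentiates the dense Hessian numerically. It is a testing aid, not the closed form the method returns. Any correct ∂H/∂z_w must be symmetric with zero row sums (since H·𝟙 = 0 for every z), and Σ_w ∂H/∂z_w = 0 (since H is unchanged by a common shift of all logits).

```rust
/// Row-wise softmax with temperature (max-shifted for stability).
fn softmax_row(z: &[f64], tau: f64) -> Vec<f64> {
    let max = z.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = z.iter().map(|&x| ((x - max) / tau).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Exact dense row Hessian for a fixed `scale` (= lambda / tau^2).
fn row_dense_hessian(z: &[f64], tau: f64, scale: f64) -> Vec<Vec<f64>> {
    let a = softmax_row(z, tau);
    let l: Vec<f64> = a.iter().map(|&p| p.ln() + 1.0).collect();
    let m: f64 = a.iter().zip(&l).map(|(p, lk)| p * lk).sum();
    let k = a.len();
    let mut h = vec![vec![0.0; k]; k];
    for r in 0..k {
        for c in 0..k {
            let diag = if r == c { m - l[r] - 1.0 } else { 0.0 };
            h[r][c] = scale * a[r] * (diag + a[c] * (l[r] + l[c] + 1.0 - 2.0 * m));
        }
    }
    h
}

/// Central-difference approximation of dH/dz_w, holding `scale` fixed.
fn dense_hessian_logit_derivative_fd(
    z: &[f64],
    tau: f64,
    scale: f64,
    w: usize,
    eps: f64,
) -> Vec<Vec<f64>> {
    let mut zp = z.to_vec();
    let mut zm = z.to_vec();
    zp[w] += eps;
    zm[w] -= eps;
    let hp = row_dense_hessian(&zp, tau, scale);
    let hm = row_dense_hessian(&zm, tau, scale);
    hp.iter()
        .zip(&hm)
        .map(|(rp, rm)| {
            rp.iter()
                .zip(rm)
                .map(|(p, q)| (p - q) / (2.0 * eps))
                .collect::<Vec<f64>>()
        })
        .collect()
}
```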

pub fn row_psd_majorizer(&self, row_logits: &[f64], scale: f64) -> Array2<f64>

Per-row Gershgorin diagonal majorizer D of the exact softmax-entropy Hessian Self::row_dense_hessian, scaled by scale = λ/τ². Returns the K×K diagonal block diag(D_0, …, D_{K−1}) with D_kk = Σ_j |H_kj| (#1419).

Unlike the Fisher metric Self::row_fisher_metric — which is PSD but does NOT satisfy G ⪰ H_entropy (counterexample a=(0.95,0.05), λ=τ=1: G₁₁=0.0475 < H₁₁=0.0784) — this D is a genuine Loewner majorizer: it is diagonally dominant over H (D_kk − H_kk = |H_kk|−H_kk + Σ_{j≠k}|H_kj| ≥ Σ_{j≠k}|(D−H)_kj|), so D − H ⪰ 0, and every D_kk ≥ 0, so D ⪰ 0. It therefore both keeps the assembled evidence block PD (the property the entropy block needs so the Faddeev–Popov deflation never fires) AND actually majorizes the entropy curvature, which the Fisher surrogate did not. The criterion’s log|H|, its θ-adjoint Self::row_psd_majorizer_logit_derivative, and the assembled Hessian all differentiate this SAME operator D, keeping value and adjoint on one exact branch.
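The K = 2 counterexample can be reproduced directly. This sketch (hypothetical helper names, λ = τ = 1 so scale = 1) computes H, G and D for a = (0.95, 0.05). For K = 2 both diagonal entries of H coincide, as do those of G, so the index convention does not matter.

```rust
/// Exact entropy Hessian from probabilities a (scale = lambda / tau^2).
fn entropy_hessian(a: &[f64], scale: f64) -> Vec<Vec<f64>> {
    let l: Vec<f64> = a.iter().map(|&p| p.ln() + 1.0).collect();
    let m: f64 = a.iter().zip(&l).map(|(p, lk)| p * lk).sum();
    let k = a.len();
    let mut h = vec![vec![0.0; k]; k];
    for r in 0..k {
        for c in 0..k {
            let diag = if r == c { m - l[r] - 1.0 } else { 0.0 };
            h[r][c] = scale * a[r] * (diag + a[c] * (l[r] + l[c] + 1.0 - 2.0 * m));
        }
    }
    h
}

/// Softmax Fisher metric G_kj = scale * a_k * (d_kj - a_j).
fn fisher_metric(a: &[f64], scale: f64) -> Vec<Vec<f64>> {
    let k = a.len();
    let mut g = vec![vec![0.0; k]; k];
    for r in 0..k {
        for c in 0..k {
            let delta = if r == c { 1.0 } else { 0.0 };
            g[r][c] = scale * a[r] * (delta - a[c]);
        }
    }
    g
}

/// Gershgorin majorizer diagonal D_kk = sum_j |H_kj|.
fn gershgorin_diag(h: &[Vec<f64>]) -> Vec<f64> {
    h.iter().map(|row| row.iter().map(|x| x.abs()).sum::<f64>()).collect()
}
```

Along e₁, e₁ᵀ(G − H)e₁ = G₁₁ − H₁₁ ≈ 0.0475 − 0.0784 < 0, so G ⋡ H, whereas e₁ᵀ(D − H)e₁ = D₁₁ − H₁₁ = |H₁₀| ≥ 0.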

pub fn row_psd_majorizer_logit_derivative(&self, row_logits: &[f64], scale: f64, w: usize) -> Array2<f64>

Derivative of the per-row Gershgorin majorizer Self::row_psd_majorizer with respect to a single row logit z_w, scaled by scale = λ/τ². Returns the K×K diagonal block diag(∂D_0/∂z_w, …) with ∂D_kk/∂z_w = Σ_j sign(H_kj)·(∂H_kj/∂z_w) (#1419), where H is the exact entropy Hessian Self::row_dense_hessian and ∂H_kj/∂z_w is Self::row_dense_hessian_logit_derivative. The convention is sign(0) = 0: |·| is not differentiable where an entry vanishes, and the zero subgradient is used there. Built from the SAME (a, L, m) derivative convention as the dense Hessian derivative, so the θ-adjoint differentiates the SAME D the assembly added.
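A sketch of the sign-weighted derivative, given H and one slice ∂H/∂z_w as plain nested vectors (hypothetical names). Rust's `f64::signum` returns 1.0 for +0.0 rather than 0, so the sign(0) = 0 convention needs an explicit helper.

```rust
/// sign(x) with sign(0) = 0. `f64::signum` is not suitable here: it returns
/// 1.0 for +0.0 and -1.0 for -0.0.
fn sign0(x: f64) -> f64 {
    if x > 0.0 {
        1.0
    } else if x < 0.0 {
        -1.0
    } else {
        0.0
    }
}

/// dD_kk/dz_w = sum_j sign(H_kj) * dH_kj/dz_w, given H and the slice dH = dH/dz_w.
fn majorizer_diag_derivative(h: &[Vec<f64>], dh: &[Vec<f64>]) -> Vec<f64> {
    h.iter()
        .zip(dh)
        .map(|(h_row, dh_row)| {
            h_row
                .iter()
                .zip(dh_row)
                .map(|(&x, &dx)| sign0(x) * dx)
                .sum::<f64>()
        })
        .collect()
}
```

With this convention the central difference (|x + t·dx| − |x − t·dx|)/(2t) agrees with sign(x)·dx for every entry once t·|dx| < |x| on the nonzero entries, including entries where x = 0.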

pub fn row_fisher_metric(&self, row_logits: &[f64], scale: f64) -> Array2<f64>

Per-row softmax Fisher-information metric G = scale·(diag(a) − a aᵀ) over the row’s logits, with a = softmax(row_logits) and scale = λ/τ² (#1190). Returns the symmetric K×K block

  G_kj = scale·a_k·(δ_kj − a_j).

G is a covariance/Gram matrix, hence exactly PSD and smooth in the logits. It is the Fisher-information metric of the row softmax, NOT a curvature majorizer of the entropy Hessian: G − H_entropy need not be PSD (#1419: K=2, a=(0.95,0.05), λ=τ=1 gives G₁₁=0.0475 < H₁₁=0.0784, so G ⋡ H). The genuine Loewner majorizer the assembled evidence block now uses is Self::row_psd_majorizer (D_kk = Σ_j|H_kj|, which DOES satisfy D ⪰ H and D ⪰ 0). This Fisher metric, together with its derivative Self::row_fisher_metric_logit_derivative, is retained only as a smooth PSD conditioning reference and must not be presented or used as a curvature majorizer.
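Because G = scale·(diag(a) − a aᵀ) is a covariance matrix, its quadratic form is a weighted variance: xᵀGx = scale·(Σ_k a_k x_k² − (Σ_k a_k x_k)²) ≥ 0. A short sketch (hypothetical names, operating on probabilities `a` rather than logits) makes the identity concrete.

```rust
/// Softmax Fisher metric G_kj = scale * a_k * (d_kj - a_j).
fn fisher_metric(a: &[f64], scale: f64) -> Vec<Vec<f64>> {
    let k = a.len();
    let mut g = vec![vec![0.0; k]; k];
    for r in 0..k {
        for c in 0..k {
            let delta = if r == c { 1.0 } else { 0.0 };
            g[r][c] = scale * a[r] * (delta - a[c]);
        }
    }
    g
}

/// x^T G x computed densely.
fn quad_form(g: &[Vec<f64>], x: &[f64]) -> f64 {
    let mut q = 0.0;
    for r in 0..x.len() {
        for c in 0..x.len() {
            q += x[r] * g[r][c] * x[c];
        }
    }
    q
}

/// scale * Var_a(x) = scale * (sum_k a_k x_k^2 - (sum_k a_k x_k)^2).
fn weighted_variance(a: &[f64], x: &[f64], scale: f64) -> f64 {
    let mean: f64 = a.iter().zip(x).map(|(p, v)| p * v).sum();
    let second: f64 = a.iter().zip(x).map(|(p, v)| p * v * v).sum();
    scale * (second - mean * mean)
}
```

A constant vector has zero variance, so G·𝟙 = 0: like the entropy block, G is gauge-null.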

pub fn row_fisher_metric_logit_derivative(&self, row_logits: &[f64], scale: f64, w: usize) -> Array2<f64>

Derivative of the per-row softmax Fisher metric Self::row_fisher_metric with respect to a single row logit z_w, scaled by scale = λ/τ² (#1190). Returns the symmetric K×K block ∂G_kj/∂z_w, the third-derivative tensor slice a θ-adjoint contracts against the row's selected inverse, so that wherever G is assembled the adjoint differentiates the SAME PSD G = scale·(diag(a) − a aᵀ) (value and adjoint on one branch, no deflation needed). It is built from the SAME softmax derivative convention as Self::row_dense_hessian_logit_derivative (∂a_r/∂z_w = a_r(δ_rw − a_w)/τ). For G_kj = scale·a_k(δ_kj − a_j), the product rule gives ∂G_kj/∂z_w = scale·[ (∂a_k/∂z_w)(δ_kj − a_j) − a_k(∂a_j/∂z_w) ].
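A sketch of the product-rule derivative with a finite-difference cross-check. The helper names and the `(z, tau, scale)` parameterization are assumptions for illustration; `scale` is held fixed while z_w varies, matching the convention described above.

```rust
/// Row-wise softmax with temperature (max-shifted for stability).
fn softmax_row(z: &[f64], tau: f64) -> Vec<f64> {
    let max = z.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = z.iter().map(|&x| ((x - max) / tau).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// G(z) = scale * (diag(a) - a a^T) with a = softmax(z / tau); scale held fixed.
fn fisher_from_logits(z: &[f64], tau: f64, scale: f64) -> Vec<Vec<f64>> {
    let a = softmax_row(z, tau);
    let k = a.len();
    let mut g = vec![vec![0.0; k]; k];
    for r in 0..k {
        for c in 0..k {
            let delta = if r == c { 1.0 } else { 0.0 };
            g[r][c] = scale * a[r] * (delta - a[c]);
        }
    }
    g
}

/// dG_kj/dz_w = scale * [ da_k (d_kj - a_j) - a_k da_j ],
/// where da_r = da_r/dz_w = a_r (d_rw - a_w) / tau.
fn fisher_logit_derivative(z: &[f64], tau: f64, scale: f64, w: usize) -> Vec<Vec<f64>> {
    let a = softmax_row(z, tau);
    let k = a.len();
    let da: Vec<f64> = (0..k)
        .map(|r| {
            let delta = if r == w { 1.0 } else { 0.0 };
            a[r] * (delta - a[w]) / tau
        })
        .collect();
    let mut dg = vec![vec![0.0; k]; k];
    for r in 0..k {
        for c in 0..k {
            let delta = if r == c { 1.0 } else { 0.0 };
            dg[r][c] = scale * (da[r] * (delta - a[c]) - a[r] * da[c]);
        }
    }
    dg
}
```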

Trait Implementations

impl AnalyticPenalty for SoftmaxAssignmentSparsityPenalty

fn tier(&self) -> PenaltyTier

Tier the target lives in (β or ext-coord).

fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64

Scalar penalty contribution P(target; ρ). The strength factor exp(ρ) (or whatever parameterization the penalty uses) is folded in.

fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64>

Gradient ∂P/∂target, same length as target.
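For this penalty the per-row gradient follows from ∂H/∂a_k = −L_k and the softmax Jacobian ∂a_k/∂z_j = a_k(δ_kj − a_j)/τ, giving ∂P/∂z_j = (λ/τ)·a_j·(m − L_j). The sketch below assumes a fixed strength λ (how the trait folds ρ into it is not shown here) and uses hypothetical names. Each row's gradient sums to zero, reflecting softmax shift-invariance.

```rust
/// Row-wise softmax with temperature (max-shifted for stability).
fn softmax_row(z: &[f64], tau: f64) -> Vec<f64> {
    let max = z.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = z.iter().map(|&x| ((x - max) / tau).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// P(z) = lambda * sum_i H(softmax(z_i / tau)) over a flat row-major (N, K) matrix.
fn penalty_value(logits: &[f64], k: usize, tau: f64, lambda: f64) -> f64 {
    let total: f64 = logits
        .chunks(k)
        .map(|row| {
            let a = softmax_row(row, tau);
            a.iter().filter(|&&p| p > 0.0).map(|&p| -p * p.ln()).sum::<f64>()
        })
        .sum();
    lambda * total
}

/// dP/dz_ij = (lambda / tau) * a_ij * (m_i - L_ij),
/// with L_ij = ln a_ij + 1 and m_i = sum_j a_ij L_ij.
fn penalty_gradient(logits: &[f64], k: usize, tau: f64, lambda: f64) -> Vec<f64> {
    let mut grad = Vec::with_capacity(logits.len());
    for row in logits.chunks(k) {
        let a = softmax_row(row, tau);
        let l: Vec<f64> = a.iter().map(|&p| p.ln() + 1.0).collect();
        let m: f64 = a.iter().zip(&l).map(|(p, lk)| p * lk).sum();
        for j in 0..k {
            grad.push(lambda / tau * a[j] * (m - l[j]));
        }
    }
    grad
}
```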

fn hessian_diag(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Option<Array1<f64>>

Diagonal of the Hessian diag(∂²P/∂target²) when the Hessian is block-diagonal. Returns None for penalties whose Hessian is dense (Isometry); those implement Self::hvp instead. The default signals “no closed-form diagonal” by returning None for any non-empty target — concrete penalties either override with their own analytic diagonal or rely on the matrix-free hvp path.

fn hvp(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>, v: ArrayView1<'_, f64>) -> Array1<f64>

Hessian-vector product H v = (∂²P/∂target²) v, in closed form.
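Expanding the dense row formula gives an O(K) per-row HVP that never materializes the K×K block: (Hv)_k = scale·a_k·[(m − L_k − 1)·v_k + (L_k + 1 − 2m)·(a·v) + Σ_j a_j L_j v_j], with scale = λ/τ². This is a sketch under stated assumptions (fixed λ, plain slices, hypothetical names), not necessarily how the trait method is implemented.

```rust
/// Row-wise softmax with temperature (max-shifted for stability).
fn softmax_row(z: &[f64], tau: f64) -> Vec<f64> {
    let max = z.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = z.iter().map(|&x| ((x - max) / tau).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Matrix-free Hessian-vector product of lambda * sum_i H(softmax(z_i / tau))
/// over a flat row-major (N, K) logit matrix; the Hessian is block-diagonal by row.
fn entropy_hvp(logits: &[f64], v: &[f64], k: usize, tau: f64, lambda: f64) -> Vec<f64> {
    assert_eq!(logits.len(), v.len());
    let scale = lambda / (tau * tau);
    let mut out = Vec::with_capacity(v.len());
    for (row, vr) in logits.chunks(k).zip(v.chunks(k)) {
        let a = softmax_row(row, tau);
        let l: Vec<f64> = a.iter().map(|&p| p.ln() + 1.0).collect();
        let m: f64 = a.iter().zip(&l).map(|(p, lk)| p * lk).sum();
        // a . v and sum_j a_j L_j v_j are shared by every output entry of the row.
        let av: f64 = (0..k).map(|j| a[j] * vr[j]).sum();
        let alv: f64 = (0..k).map(|j| a[j] * l[j] * vr[j]).sum();
        for i in 0..k {
            let bracket = (m - l[i] - 1.0) * vr[i] + (l[i] + 1.0 - 2.0 * m) * av + alv;
            out.push(scale * a[i] * bracket);
        }
    }
    out
}
```

On a uniform K = 2 row with λ = τ = 1 the block is H = (1/K)(𝟙𝟙ᵀ/K − I), so v = (1, 0) maps to (−0.25, 0.25); a constant v maps to zero on every row.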

fn psd_majorizer_diag(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Option<Array1<f64>>

Diagonal of a PSD majorizer of the Hessian — the positive re-weighted-ℓ₂ / MM surrogate diag(B(target; ρ)) with B ⪰ ∂²P/∂target² everywhere and B ⪰ 0. This is a different operator from Self::hessian_diag: for nonconvex penalties (log sparsity, JumpReLU) the exact Hessian is indefinite, but the inner Newton / PIRLS solve and the log-det / preconditioner pipeline require a PSD curvature block. For convex penalties the majorizer coincides with the exact Hessian, so the default simply delegates to Self::hessian_diag; nonconvex penalties override.

fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64>

Gradient of the penalty value w.r.t. each owned ρ-axis. Length equals Self::rho_count.

fn rho_count(&self) -> usize

Number of REML-selectable hyperparameter axes this penalty contributes to the outer ρ vector.

fn name(&self) -> &str

Human-readable identifier for diagnostics / logging.

fn apply_schedule(&mut self, iter: usize)

Update any attached scalar weight schedule at the given REML outer iteration. Penalties without schedules keep their stored weight.

fn psd_majorizer_hvp(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>, v: ArrayView1<'_, f64>) -> Array1<f64>

Matrix-vector product against the PSD majorizer B(target; ρ) v (see Self::psd_majorizer_diag). For convex penalties this is the exact Hessian-vector product, so the default delegates to Self::hvp; nonconvex penalties override to return their PSD surrogate instead of the indefinite true Hessian.

impl Clone for SoftmaxAssignmentSparsityPenalty

fn clone(&self) -> SoftmaxAssignmentSparsityPenalty

Returns a duplicate of the value.

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.

impl Debug for SoftmaxAssignmentSparsityPenalty

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

impl PenaltyManifest for SoftmaxAssignmentSparsityPenalty

const KIND_TAG: &'static str = "softmax_assignment_sparsity"

const PYTHON_WRAPPER: &'static str = "SoftmaxAssignmentSparsityPenalty"

const ROW_BLOCK_DIAGONAL: bool = true

fn dispatch_tier(&self) -> PenaltyTier
