pub struct NestedPrefixPenalty {
pub target: PsiSlice,
pub target_tier: PenaltyTier,
pub prefix_sizes: Vec<usize>,
pub shell_weights: Vec<f64>,
pub eps: f64,
pub rho_indices: Vec<usize>,
pub weight_schedule: Option<ScalarWeightSchedule>,
}
Nested-prefix sparsity penalty used by the Matryoshka SAE (Bussmann/Nabeshima/Karvonen/Nanda, ICML 2025, arXiv:2503.17547).
Given K nested prefix sizes m_1 < m_2 < ... < m_K ≤ F over the latent
dimension F, and per-shell weights λ_k = w_k · exp(ρ_k), the penalty is
P(t; ρ) = Σ_k λ_k · Σ_{i=0}^{m_k - 1} sqrt(t_i² + ε²)
summed over all rows of the latent target. Equivalently, coordinate i
contributes with effective weight W_i = Σ_{k: m_k > i} λ_k, so the
earliest atoms (small i) are penalized by every shell (= strongest L¹)
and the latest atoms only by the outermost shell. This is exactly the
mask-weighted sum-of-L¹ over K prefixes used to enforce shell-wise
reconstruction during Matryoshka training.
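The per-coordinate view can be sketched in a few lines of plain Rust. This is only an illustration of the formula W_i = Σ_{k: m_k > i} λ_k, not the crate's implementation: the free function `effective_weights` and its `latent_dim` argument are hypothetical names, and plain slices stand in for the crate's array types.

```rust
/// Effective per-coordinate weights W_i = Σ_{k: m_k > i} λ_k,
/// with λ_k = w_k · exp(ρ_k). Hypothetical helper for illustration only.
fn effective_weights(
    prefix_sizes: &[usize],
    shell_weights: &[f64],
    rho: &[f64],
    latent_dim: usize,
) -> Vec<f64> {
    let mut w = vec![0.0; latent_dim];
    for ((&m, &base), &r) in prefix_sizes.iter().zip(shell_weights).zip(rho) {
        let lambda = base * r.exp();
        // Shell k covers coordinates 0..m_k, so each of them gains λ_k.
        for wi in w.iter_mut().take(m) {
            *wi += lambda;
        }
    }
    w
}
```

With prefix sizes [2, 4], base weights [1.0, 0.5] and ρ = [0, 0], coordinates 0 and 1 sit inside both shells and get weight 1.5, while coordinates 2 and 3 sit only inside the outer shell and get 0.5. Any coordinate at or beyond the largest prefix receives weight 0 in this sketch.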
Closed forms (per row, summed across all rows):
∂P/∂t_i = W_i · t_i / sqrt(t_i² + ε²)
Hess_diag(i) = W_i · ε² / (t_i² + ε²)^{3/2} (PSD)
∂P/∂ρ_k = λ_k · Σ_{i < m_k} sqrt(t_i² + ε²)
target lays out n_rows × latent_dim in row-major order (row * F + col).
latent_dim is taken from PsiSlice::latent_dim; if absent we fall back to
the maximum prefix size, which is the standard Matryoshka convention.
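The closed forms above translate directly into code. The following is a minimal sketch under stated assumptions: `Shells` is a hypothetical stand-in for the penalty's parameters, plain slices replace `ArrayView1`, and the column of a flat index is recovered as `idx % latent_dim` per the row-major layout. It is not the crate's implementation.

```rust
/// Smoothed absolute value sqrt(x² + ε²).
fn smooth_abs(x: f64, eps: f64) -> f64 {
    (x * x + eps * eps).sqrt()
}

/// Hypothetical stand-in for the penalty's parameters (illustration only).
struct Shells<'a> {
    prefix_sizes: &'a [usize],
    shell_weights: &'a [f64],
    eps: f64,
    latent_dim: usize,
}

impl Shells<'_> {
    /// λ_k = w_k · exp(ρ_k).
    fn lambda(&self, rho: &[f64], k: usize) -> f64 {
        self.shell_weights[k] * rho[k].exp()
    }

    /// W_col = Σ_{k: m_k > col} λ_k.
    fn weight(&self, rho: &[f64], col: usize) -> f64 {
        (0..self.prefix_sizes.len())
            .filter(|&k| self.prefix_sizes[k] > col)
            .map(|k| self.lambda(rho, k))
            .sum()
    }

    /// P(t; ρ), summed over all rows of the row-major target.
    fn value(&self, t: &[f64], rho: &[f64]) -> f64 {
        t.iter()
            .enumerate()
            .map(|(idx, &x)| self.weight(rho, idx % self.latent_dim) * smooth_abs(x, self.eps))
            .sum()
    }

    /// ∂P/∂t_i = W_i · t_i / sqrt(t_i² + ε²).
    fn grad_target(&self, t: &[f64], rho: &[f64]) -> Vec<f64> {
        t.iter()
            .enumerate()
            .map(|(idx, &x)| self.weight(rho, idx % self.latent_dim) * x / smooth_abs(x, self.eps))
            .collect()
    }

    /// Hess_diag(i) = W_i · ε² / (t_i² + ε²)^{3/2}.
    fn hessian_diag(&self, t: &[f64], rho: &[f64]) -> Vec<f64> {
        t.iter()
            .enumerate()
            .map(|(idx, &x)| {
                let s = smooth_abs(x, self.eps);
                self.weight(rho, idx % self.latent_dim) * self.eps * self.eps / (s * s * s)
            })
            .collect()
    }

    /// ∂P/∂ρ_k = λ_k · Σ_{i < m_k} sqrt(t_i² + ε²), summed over all rows.
    fn grad_rho(&self, t: &[f64], rho: &[f64]) -> Vec<f64> {
        (0..self.prefix_sizes.len())
            .map(|k| {
                let m = self.prefix_sizes[k];
                let shell_sum: f64 = t
                    .iter()
                    .enumerate()
                    .filter(|&(idx, _)| idx % self.latent_dim < m)
                    .map(|(_, &x)| smooth_abs(x, self.eps))
                    .sum();
                self.lambda(rho, k) * shell_sum
            })
            .collect()
    }
}
```

A central finite difference of `value` in either `t` or `rho` is a quick way to check the two gradients. Because the Hessian is diagonal, a Hessian-vector product in this sketch would reduce to an elementwise product of `hessian_diag` with `v`.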
Fields§
target: PsiSlice
target_tier: PenaltyTier
prefix_sizes: Vec<usize>
Sorted strictly-increasing prefix sizes m_1 < m_2 < ... < m_K.
shell_weights: Vec<f64>
Per-shell base weights w_k. The effective strength is
λ_k = w_k · exp(ρ_k).
eps: f64
Smoothing parameter ε > 0 for the smoothed-L¹ surrogate
sqrt(x² + ε²); the Hessian needs ε > 0 for differentiability at 0.
rho_indices: Vec<usize>
Local ρ indices for the K per-shell log-strengths.
weight_schedule: Option<ScalarWeightSchedule>
Implementations§
impl NestedPrefixPenalty
pub fn new(
    target: PsiSlice,
    target_tier: PenaltyTier,
    prefix_sizes: Vec<usize>,
    shell_weights: Vec<f64>,
    eps: f64,
) -> Result<Self, String>
Build a new nested-prefix penalty.
Errors when:
- prefix_sizes is empty.
- prefix_sizes is not strictly increasing.
- any prefix exceeds the latent dimension (when known).
- shell_weights.len() != prefix_sizes.len().
- eps <= 0 (the smoothed-L¹ gradient x/sqrt(x²+ε²) and Hessian ε²/(x²+ε²)^{3/2} both need ε > 0).
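The documented error conditions can be mirrored by a small standalone check. This is a sketch, not the constructor's actual code: `validate` is a hypothetical function, the error strings are invented for illustration, and `latent_dim: Option<usize>` models the "when known" case.

```rust
/// Hypothetical validation that mirrors the documented error conditions.
/// The function name and error messages are illustrative assumptions.
fn validate(
    prefix_sizes: &[usize],
    shell_weights: &[f64],
    eps: f64,
    latent_dim: Option<usize>,
) -> Result<(), String> {
    if prefix_sizes.is_empty() {
        return Err("prefix_sizes is empty".to_string());
    }
    // Strictly increasing: every adjacent pair must satisfy m_k < m_{k+1}.
    if prefix_sizes.windows(2).any(|pair| pair[0] >= pair[1]) {
        return Err("prefix_sizes is not strictly increasing".to_string());
    }
    // The bound can only be checked when the latent dimension is known.
    if let Some(f) = latent_dim {
        if prefix_sizes.iter().any(|&m| m > f) {
            return Err(format!("a prefix exceeds the latent dimension {f}"));
        }
    }
    if shell_weights.len() != prefix_sizes.len() {
        return Err("shell_weights.len() != prefix_sizes.len()".to_string());
    }
    if eps <= 0.0 {
        return Err("eps must be > 0".to_string());
    }
    Ok(())
}
```

For example, `validate(&[2, 8], &[1.0, 0.5], 1e-3, Some(4))` fails because the prefix 8 exceeds the latent dimension 4, while the same call with `None` passes because the bound cannot be checked.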
pub fn with_weight_schedule(self, schedule: ScalarWeightSchedule) -> Self
Attach a global annealing schedule shared by all shell weights. The REML loop still picks per-shell ρ_k on top of this baseline.
Trait Implementations§
impl AnalyticPenalty for NestedPrefixPenalty
fn tier(&self) -> PenaltyTier
fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64
P(target; ρ). The strength factor
exp(ρ) (or whatever parameterization the penalty uses) is folded in.
fn grad_target(
&self,
target: ArrayView1<'_, f64>,
rho: ArrayView1<'_, f64>,
) -> Array1<f64>
∂P/∂target, same length as target.
fn hessian_diag(
&self,
target: ArrayView1<'_, f64>,
rho: ArrayView1<'_, f64>,
) -> Option<Array1<f64>>
diag(∂²P/∂target²) when the Hessian is
block-diagonal. Returns None for penalties whose Hessian is dense
(Isometry); those implement Self::hvp instead. The default
signals “no closed-form diagonal” by returning None for any
non-empty target; concrete penalties either override with their
own analytic diagonal or rely on the matrix-free hvp path.
fn grad_rho(
&self,
target: ArrayView1<'_, f64>,
rho: ArrayView1<'_, f64>,
) -> Array1<f64>
∂P/∂ρ, with length Self::rho_count.
fn rho_count(&self) -> usize
fn apply_schedule(&mut self, iter: usize)
fn hvp(
&self,
target: ArrayView1<'_, f64>,
rho: ArrayView1<'_, f64>,
v: ArrayView1<'_, f64>,
) -> Array1<f64>
H v = (∂²P/∂target²) v, in closed form.
fn psd_majorizer_diag(
&self,
target: ArrayView1<'_, f64>,
rho: ArrayView1<'_, f64>,
) -> Option<Array1<f64>>
diag(B(target; ρ)) with
B ⪰ ∂²P/∂target² everywhere and B ⪰ 0. This is a different
operator from Self::hessian_diag: for nonconvex penalties (log
sparsity, JumpReLU) the exact Hessian is indefinite, but the inner
Newton / PIRLS solve and the log-det / preconditioner pipeline require
a PSD curvature block. For convex penalties the majorizer coincides
with the exact Hessian, so the default simply delegates to
Self::hessian_diag; nonconvex penalties override.
fn psd_majorizer_hvp(
&self,
target: ArrayView1<'_, f64>,
rho: ArrayView1<'_, f64>,
v: ArrayView1<'_, f64>,
) -> Array1<f64>
B(target; ρ) v
(see Self::psd_majorizer_diag). For convex penalties this is the
exact Hessian-vector product, so the default delegates to
Self::hvp; nonconvex penalties override to return their PSD
surrogate instead of the indefinite true Hessian.
impl Clone for NestedPrefixPenalty
fn clone(&self) -> NestedPrefixPenalty
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.
impl Debug for NestedPrefixPenalty
impl PenaltyManifest for NestedPrefixPenalty
const KIND_TAG: &'static str = "nested_prefix"
const PYTHON_WRAPPER: &'static str = "NestedPrefixPenalty"
const ROW_BLOCK_DIAGONAL: bool = true
fn dispatch_tier(&self) -> PenaltyTier
Auto Trait Implementations§
impl Freeze for NestedPrefixPenalty
impl RefUnwindSafe for NestedPrefixPenalty
impl Send for NestedPrefixPenalty
impl Sync for NestedPrefixPenalty
impl Unpin for NestedPrefixPenalty
impl UnsafeUnpin for NestedPrefixPenalty
impl UnwindSafe for NestedPrefixPenalty
Blanket Implementations§
impl<T> Allocation for T
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<ST, DT> CastableFrom<ST, Initialized, Initialized> for DT
impl<ST, DT> CastableFrom<ST, Uninit, Uninit> for DT
impl<T> CloneToUninit for T
where
    T: Clone,
impl<T> DistributionExt for T
where
    T: ?Sized,
impl<T, U> Imply<T> for U
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise.
impl<T> Pointable for T
impl<T> Read<Exclusive, BecauseExclusive> for T
where
    T: ?Sized,
impl<SS, SP> SupersetOf<SS> for SP
where
    SS: SubsetOf<SP>,
fn to_subset(&self) -> Option<SS>
The inverse inclusion map: attempts to construct self from the equivalent element of its
superset.
fn is_in_subset(&self) -> bool
Checks if self is actually part of its subset T (and can be converted to it).
fn to_subset_unchecked(&self) -> SS
Use with care! Same as self.to_subset but without any property checks. Always succeeds.
fn from_subset(element: &SS) -> SP
The inclusion map: converts self to the equivalent element of its superset.