Struct SparseLinear 

pub struct SparseLinear {
    pub weight: Parameter,
    pub bias: Option<Parameter>,
    pub threshold: Parameter,
    /* private fields */
}

A linear layer with a differentiable magnitude pruning mask.

During the forward pass, a soft mask is computed via sigmoid soft thresholding:

mask = sigmoid((|weight| - threshold) * temperature)
effective_weight = weight * mask
y = x @ effective_weight^T + bias
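The three formulas above can be traced with a standalone sketch that does not depend on this crate. The free function `soft_masked_forward` is hypothetical. It handles a single input vector, stores `weight` row-major as (out_features, in_features), and assumes one threshold per weight (the unstructured layout). `temperature` is passed in explicitly because the struct does not expose it as a public field.

```rust
/// Logistic sigmoid, 1 / (1 + e^-z).
fn sigmoid(z: f32) -> f32 {
    1.0 / (1.0 + (-z).exp())
}

/// Hypothetical helper (not part of the SparseLinear API): a soft-masked
/// linear layer applied to one input vector `x`.
/// `weight` and `threshold` are row-major (out_features, in_features).
fn soft_masked_forward(
    x: &[f32],
    weight: &[f32],
    threshold: &[f32],
    bias: Option<&[f32]>,
    out_features: usize,
    temperature: f32,
) -> Vec<f32> {
    let in_features = x.len();
    assert_eq!(weight.len(), out_features * in_features);
    assert_eq!(threshold.len(), weight.len());

    (0..out_features)
        .map(|i| {
            let row = i * in_features;
            // y_i = sum_j x_j * effective_weight[i][j]
            let dot: f32 = (0..in_features)
                .map(|j| {
                    let w = weight[row + j];
                    // mask = sigmoid((|weight| - threshold) * temperature)
                    let mask = sigmoid((w.abs() - threshold[row + j]) * temperature);
                    // effective_weight = weight * mask
                    x[j] * w * mask
                })
                .sum();
            // + bias (treated as zero when the layer has no bias)
            dot + bias.map_or(0.0, |b| b[i])
        })
        .collect()
}
```

With a large temperature the mask is close to 0 or 1, so a weight well above its threshold passes through almost unchanged while one well below contributes almost nothing. Smaller temperatures give a smoother mask whose gradient is spread over a wider band of magnitudes around the threshold.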

The sigmoid makes the mask differentiable, so gradients flow through it and the network learns which weights to prune. The threshold parameter is learnable and included in parameters().
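To see how gradients reach the threshold, differentiate a single effective weight e = w * s, where s = sigmoid((|w| - t) * temperature) and sigmoid' = s * (1 - s). This gives de/dt = -w * temperature * s * (1 - s) and, for w != 0, de/dw = s + |w| * temperature * s * (1 - s). The hypothetical functions below are not the crate's autograd. They evaluate these closed forms in f64 so the results can be checked against central finite differences.

```rust
fn sigmoid(z: f64) -> f64 {
    1.0 / (1.0 + (-z).exp())
}

/// Hypothetical helper: effective weight e = w * sigmoid((|w| - t) * temperature).
fn effective(w: f64, t: f64, temperature: f64) -> f64 {
    w * sigmoid((w.abs() - t) * temperature)
}

/// Closed-form (de/dw, de/dt). Valid for w != 0, where |w| is differentiable.
fn effective_grads(w: f64, t: f64, temperature: f64) -> (f64, f64) {
    let s = sigmoid((w.abs() - t) * temperature);
    let ds = s * (1.0 - s); // sigmoid'(z) = s * (1 - s)
    // Product rule: s + w * sigmoid' * temperature * sign(w) = s + |w| * temperature * sigmoid'
    let de_dw = s + w.abs() * temperature * ds;
    // Chain rule through (|w| - t): d/dt contributes a factor of -temperature
    let de_dt = -w * temperature * ds;
    (de_dw, de_dt)
}
```

For a positive weight de/dt is negative: raising the threshold shrinks the effective weight. This is the signal that lets an optimizer move each threshold.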

§Structured vs Unstructured

  • Structured (structured=true): One threshold per output neuron. Entire output channels can be pruned, yielding hardware-friendly sparsity.
  • Unstructured (structured=false): One threshold per weight element. Finer-grained but less hardware-friendly.
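The two modes differ in how a threshold is looked up for weight (i, j). The sketch below is a hypothetical illustration, not the crate's code. It assumes that in structured mode the per-neuron threshold t_i is broadcast across row i and compared against each |w_ij|, as the mask formula suggests. Under that assumption, an output channel is pruned entirely once t_i exceeds the largest magnitude in its row.

```rust
/// Hypothetical threshold storage mirroring the two documented shapes.
enum Thresholds {
    /// Structured: shape (out_features,), one value per output neuron.
    Structured(Vec<f32>),
    /// Unstructured: shape (out_features, in_features), row-major.
    Unstructured(Vec<f32>),
}

impl Thresholds {
    /// Threshold applied to weight (row, col).
    fn at(&self, row: usize, col: usize, in_features: usize) -> f32 {
        match self {
            // Assumed: the row's threshold is broadcast across every column.
            Thresholds::Structured(t) => t[row],
            Thresholds::Unstructured(t) => t[row * in_features + col],
        }
    }
}

/// Output rows in which every weight fails the hard rule |w| >= t,
/// i.e. channels that would be pruned entirely.
fn fully_pruned_rows(weight: &[f32], thresholds: &Thresholds, in_features: usize) -> Vec<usize> {
    weight
        .chunks(in_features)
        .enumerate()
        .filter(|(i, row)| {
            row.iter()
                .enumerate()
                .all(|(j, w)| w.abs() < thresholds.at(*i, j, in_features))
        })
        .map(|(i, _)| i)
        .collect()
}
```

With the same 2x2 weight matrix, a structured threshold of 0.1 per row removes the second row outright. Per-weight thresholds can still zero individual small entries in that row without having to remove the whole channel.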

§Example

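// Assumes the Module trait is in scope (forward is a Module method) and that
// `input` is a Variable whose last dimension is 784 (in_features).
let layer = SparseLinear::new(784, 256);
let output = layer.forward(&input);
println!("Density: {:.1}%", layer.density() * 100.0);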

Fields§

§weight: Parameter

Weight matrix of shape (out_features, in_features).

§bias: Option<Parameter>

Optional bias vector of shape (out_features).

§threshold: Parameter

Learnable magnitude thresholds. Shape depends on structured:

  • Structured: (out_features,)
  • Unstructured: (out_features, in_features)

Implementations§

impl SparseLinear

pub fn new(in_features: usize, out_features: usize) -> Self

Creates a new SparseLinear layer with structured pruning and bias.

§Arguments
  • in_features - Size of each input sample
  • out_features - Size of each output sample

pub fn unstructured(in_features: usize, out_features: usize) -> Self

Creates a new SparseLinear layer with unstructured (per-weight) pruning.

§Arguments
  • in_features - Size of each input sample
  • out_features - Size of each output sample

pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self

Creates a new SparseLinear layer with configurable bias.

§Arguments
  • in_features - Size of each input sample
  • out_features - Size of each output sample
  • bias - Whether to include a learnable bias

pub fn in_features(&self) -> usize

Returns the input feature dimension.

pub fn out_features(&self) -> usize

Returns the output feature dimension.

pub fn is_structured(&self) -> bool

Returns whether this layer uses structured pruning.

pub fn density(&self) -> f32

Returns the fraction of weights that are active (at or above the threshold).

Uses hard thresholding at |weight| >= threshold.

pub fn sparsity(&self) -> f32

Returns the fraction of weights that are pruned.

Equivalent to 1.0 - density().

pub fn num_active(&self) -> usize

Returns the number of active (non-pruned) weights.
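The counting methods above all reduce to one hard comparison per weight. Here is a minimal sketch with hypothetical free functions (not the crate's implementation). It assumes the thresholds have already been expanded to one value per weight; in structured mode each row's threshold would be repeated across that row.

```rust
/// Hypothetical helper: count weights kept by the hard rule |w| >= t.
/// `threshold` holds one value per weight, in the same order as `weight`.
fn num_active(weight: &[f32], threshold: &[f32]) -> usize {
    assert_eq!(weight.len(), threshold.len());
    weight
        .iter()
        .zip(threshold)
        .filter(|(w, t)| w.abs() >= **t)
        .count()
}

/// Fraction of weights that are active.
fn density(weight: &[f32], threshold: &[f32]) -> f32 {
    assert!(!weight.is_empty(), "density is undefined for an empty layer");
    num_active(weight, threshold) as f32 / weight.len() as f32
}

/// Fraction of weights that are pruned: 1.0 - density.
fn sparsity(weight: &[f32], threshold: &[f32]) -> f32 {
    1.0 - density(weight, threshold)
}
```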

pub fn hard_prune(&mut self)

Permanently applies the pruning mask to the weights.

After calling this, pruned weights are zeroed and the threshold is reset to zero. This is an irreversible, inference-time optimization: the zeroed weights cannot be recovered.
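An in-place version of this step might look like the hypothetical sketch below, which again assumes one threshold per weight and is not the crate's code. It zeroes every weight that fails |w| >= t and then sets every threshold to zero, matching the description above. Because it overwrites `weight`, the original values are gone afterwards.

```rust
/// Hypothetical helper: permanently apply the hard mask, then reset thresholds.
fn hard_prune(weight: &mut [f32], threshold: &mut [f32]) {
    assert_eq!(weight.len(), threshold.len());
    for (w, t) in weight.iter_mut().zip(threshold.iter()) {
        // Pruned weights are overwritten; no copy is kept to restore from.
        if w.abs() < *t {
            *w = 0.0;
        }
    }
    // The threshold is reset to zero once the mask has been baked into `weight`.
    threshold.fill(0.0);
}
```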

pub fn reset_threshold(&mut self, value: f32)

Resets the threshold to a specific value.

§Arguments
  • value - The new threshold value
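Because the hard rule is |w| >= t, a larger uniform threshold can only keep fewer weights. Density therefore never increases as the value passed to reset_threshold grows, so a starting threshold for a target density can be found by scanning candidates in ascending order. The sketch below is hypothetical. It assumes reset_threshold writes the same value into every threshold entry, and `density_at` and `threshold_for_density` are illustrative helpers, not part of the crate.

```rust
/// Hypothetical helper: density under the hard rule when every threshold
/// entry equals `value` (the state assumed after reset_threshold(value)).
fn density_at(weight: &[f32], value: f32) -> f32 {
    assert!(!weight.is_empty());
    let active = weight.iter().filter(|w| w.abs() >= value).count();
    active as f32 / weight.len() as f32
}

/// First candidate (candidates sorted ascending) whose density is at most
/// `target`, or None if no candidate prunes enough.
fn threshold_for_density(weight: &[f32], candidates: &[f32], target: f32) -> Option<f32> {
    candidates
        .iter()
        .copied()
        .find(|&v| density_at(weight, v) <= target)
}
```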

pub fn effective_weight(&self) -> Tensor<f32>

Returns the effective weight (weight * hard_mask) for inspection.

This shows what the weight matrix looks like after hard pruning, without actually modifying the layer.
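effective_weight applies the hard mask, whereas the forward pass uses the soft mask, so the two differ slightly during training. The hypothetical sketch below uses per-weight thresholds and an explicit temperature, and none of its functions are part of the crate. It computes both versions without mutating its inputs and measures the largest gap between them.

```rust
fn sigmoid(z: f32) -> f32 {
    1.0 / (1.0 + (-z).exp())
}

/// Hypothetical: hard-masked copy, keeping w where |w| >= t and 0 otherwise.
fn hard_effective(weight: &[f32], threshold: &[f32]) -> Vec<f32> {
    weight
        .iter()
        .zip(threshold)
        .map(|(&w, &t)| if w.abs() >= t { w } else { 0.0 })
        .collect()
}

/// Hypothetical: soft-masked copy, w * sigmoid((|w| - t) * temperature).
fn soft_effective(weight: &[f32], threshold: &[f32], temperature: f32) -> Vec<f32> {
    weight
        .iter()
        .zip(threshold)
        .map(|(&w, &t)| w * sigmoid((w.abs() - t) * temperature))
        .collect()
}

/// Largest absolute element-wise difference between two equally long vectors.
fn max_gap(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y).abs()).fold(0.0, f32::max)
}
```

The gap shrinks as the temperature grows, except for weights sitting exactly at their threshold. There sigmoid(0) = 0.5, so the soft value stays at half the hard one regardless of temperature.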

Trait Implementations§

impl Debug for SparseLinear

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

impl Module for SparseLinear

fn forward(&self, input: &Variable) -> Variable

Performs the forward pass.

fn parameters(&self) -> Vec<Parameter>

Returns all parameters of this module.

fn named_parameters(&self) -> HashMap<String, Parameter>

Returns named parameters of this module.

fn name(&self) -> &'static str

Returns the module name for debugging.

fn num_parameters(&self) -> usize

Returns the number of trainable parameters.

fn train(&mut self)

Sets the module to training mode.

fn eval(&mut self)

Sets the module to evaluation mode.

fn set_training(&mut self, _training: bool)

Sets the training mode.

fn is_training(&self) -> bool

Returns whether the module is in training mode.

fn zero_grad(&self)

Zeros all gradients of parameters.

fn to_device(&self, device: Device)

Moves all parameters to the specified device.

Auto Trait Implementations§

Blanket Implementations§

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

impl<T> Pointable for T

const ALIGN: usize

The alignment of pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V