Struct SparseAutoencoder 

pub struct SparseAutoencoder { /* private fields */ }

A Sparse Autoencoder for mechanistic interpretability.

Loads SAE weights from SAELens-format safetensors + cfg.json (or from Gemma Scope NPZ files), encodes model activations into sparse feature vectors, decodes them back to activation space, and produces steering vectors for injection.

Each SAE targets a single hook point in the model (e.g., resid_post at layer 5). Multiple SAEs can be loaded independently for different hook points.

§Example

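use candle_mi::sae::SparseAutoencoder;
use candle_core::Device;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Fetch safetensors + cfg.json from the HuggingFace Hub and build the SAE on the CPU.
    let sae = SparseAutoencoder::from_pretrained(
        "jbloom/Gemma-2-2B-Residual-Stream-SAEs",
        "gemma-2-2b-res-jb/blocks.20.hook_resid_post",
        &Device::Cpu,
    )?;
    println!("SAE: d_in={}, d_sae={}", sae.d_in(), sae.d_sae());
    Ok(())
}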

Implementations§

impl SparseAutoencoder

pub fn from_local(dir: &Path, device: &Device) -> Result<Self>

Load an SAE from a local directory containing safetensors + cfg.json.

Expects either sae_weights.safetensors or model.safetensors plus a cfg.json file.

§Errors

Returns MIError::Config if cfg.json is missing or malformed, or if the weight shapes don't match the dimensions in cfg.json. Returns MIError::Model if a tensor fails to load. Returns MIError::Io if the files cannot be read.
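As a rough illustration of the directory layout described above, the sketch below resolves the two expected files with plain std::fs checks. It is not candle_mi code: resolve_sae_files is a hypothetical helper, and preferring sae_weights.safetensors when both weight files are present is an assumption.

```rust
use std::path::{Path, PathBuf};

/// Hypothetical helper (not part of candle_mi): locate the files that
/// `from_local` is documented to expect in `dir`.
/// Assumption: `sae_weights.safetensors` wins when both weight files exist.
fn resolve_sae_files(dir: &Path) -> Result<(PathBuf, PathBuf), String> {
    let cfg = dir.join("cfg.json");
    if !cfg.is_file() {
        return Err(format!("missing {}", cfg.display()));
    }
    // Try the two accepted weight file names in order of preference.
    let weights = ["sae_weights.safetensors", "model.safetensors"]
        .iter()
        .map(|name| dir.join(name))
        .find(|path| path.is_file())
        .ok_or_else(|| format!("no safetensors weights in {}", dir.display()))?;
    Ok((weights, cfg))
}

fn main() {
    // Build a throwaway directory that mimics a downloaded SAE.
    let dir = std::env::temp_dir().join("sae_resolve_demo");
    std::fs::create_dir_all(&dir).expect("create temp dir");
    std::fs::write(dir.join("cfg.json"), "{}").expect("write cfg.json");
    std::fs::write(dir.join("model.safetensors"), b"").expect("write weights");

    let (weights, cfg) = resolve_sae_files(&dir).expect("files present");
    println!("weights: {}, config: {}", weights.display(), cfg.display());
}
```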

pub fn from_npz(
    npz_path: &Path,
    hook_layer: usize,
    device: &Device,
) -> Result<Self>

Load an SAE from a Gemma Scope NPZ file (params.npz).

The NPZ file must contain W_enc, W_dec, b_enc, b_dec arrays, and optionally threshold (for JumpReLU). Config is inferred from tensor shapes since NPZ files have no cfg.json.

§Arguments
  • npz_path — Path to the params.npz file
  • hook_layer — Which model layer this SAE hooks into
  • device — Target device (CPU or CUDA)
§Errors

Returns MIError::Config if required tensors are missing or shapes are inconsistent. Returns MIError::Io if the file cannot be read.
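Because NPZ files carry no cfg.json, the dimensions have to come from the array shapes. The sketch below shows one way to infer and cross-check them using shape lists alone. InferredConfig, check, and infer_config are hypothetical names, and the layouts W_enc: [d_in, d_sae] and W_dec: [d_sae, d_in] are assumptions chosen to agree with pre_acts = x @ W_enc + b_enc.

```rust
/// Hypothetical stand-in for the SaeConfig fields that `from_npz` must infer.
#[derive(Debug, PartialEq)]
struct InferredConfig {
    d_in: usize,
    d_sae: usize,
    jump_relu: bool,
}

/// Compare one array's shape with the shape implied by W_enc.
fn check(name: &str, got: &[usize], want: &[usize]) -> Result<(), String> {
    if got == want {
        Ok(())
    } else {
        Err(format!("{name}: expected shape {want:?}, got {got:?}"))
    }
}

/// Infer dimensions from NPZ array shapes. Assumes W_enc is [d_in, d_sae]
/// (so that x @ W_enc is valid) and W_dec is [d_sae, d_in].
fn infer_config(
    w_enc: &[usize],
    w_dec: &[usize],
    b_enc: &[usize],
    b_dec: &[usize],
    threshold: Option<&[usize]>,
) -> Result<InferredConfig, String> {
    let (d_in, d_sae) = match w_enc {
        [d_in, d_sae] => (*d_in, *d_sae),
        _ => return Err(format!("W_enc must be 2-D, got {w_enc:?}")),
    };
    check("W_dec", w_dec, &[d_sae, d_in])?;
    check("b_enc", b_enc, &[d_sae])?;
    check("b_dec", b_dec, &[d_in])?;
    if let Some(t) = threshold {
        check("threshold", t, &[d_sae])?;
    }
    // The presence of a threshold array is what marks the SAE as JumpReLU.
    Ok(InferredConfig { d_in, d_sae, jump_relu: threshold.is_some() })
}

fn main() {
    // Shapes in the style of a 16k-wide residual SAE for a 2304-dim model.
    let cfg = infer_config(
        &[2304, 16384],
        &[16384, 2304],
        &[16384],
        &[2304],
        Some([16384].as_slice()),
    );
    println!("{cfg:?}");
}
```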

pub fn from_pretrained_npz(
    repo_id: &str,
    npz_path: &str,
    hook_layer: usize,
    device: &Device,
) -> Result<Self>

Load an SAE from a HuggingFace repository containing an NPZ file.

Downloads the NPZ file via hf-fetch-model, then delegates to from_npz.

§Arguments
  • repo_id — HuggingFace repository ID (e.g., "google/gemma-scope-2b-pt-res")
  • npz_path — Path within the repo to the NPZ file (e.g., "layer_0/width_16k/average_l0_105/params.npz")
  • hook_layer — Which model layer this SAE hooks into
  • device — Target device (CPU or CUDA)
§Errors

Returns MIError::Download if the file cannot be fetched. Returns MIError::Config if the NPZ format is invalid.

pub fn from_pretrained(
    repo_id: &str,
    sae_id: &str,
    device: &Device,
) -> Result<Self>

Load an SAE from a HuggingFace repository.

Downloads safetensors + cfg.json via hf-fetch-model, then delegates to from_local.

§Arguments
  • repo_id — HuggingFace repository ID (e.g., "jbloom/Gemma-2-2B-Residual-Stream-SAEs")
  • sae_id — Subdirectory within the repo (e.g., "gemma-2-2b-res-jb/blocks.20.hook_resid_post")
  • device — Target device (CPU or CUDA)
§Errors

Returns MIError::Download if files cannot be fetched. Returns MIError::Config if the SAE format is invalid.
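The sae_id in the example ends in a TransformerLens-style hook name, blocks.20.hook_resid_post, which encodes both the layer and the hook type. The sketch below only illustrates that naming convention; parse_hook_name is hypothetical, and this page does not say how candle_mi actually derives its HookPoint.

```rust
/// Hypothetical parser (not candle_mi's HookPoint API): split a
/// TransformerLens-style name like "blocks.20.hook_resid_post"
/// into its layer index and hook type.
fn parse_hook_name(sae_id: &str) -> Option<(usize, &str)> {
    // Keep only the last path segment of the sae_id.
    let name = sae_id.rsplit('/').next()?;
    let mut parts = name.splitn(3, '.');
    if parts.next()? != "blocks" {
        return None;
    }
    let layer = parts.next()?.parse().ok()?;
    let hook = parts.next()?.strip_prefix("hook_")?;
    Some((layer, hook))
}

fn main() {
    let parsed = parse_hook_name("gemma-2-2b-res-jb/blocks.20.hook_resid_post");
    println!("{parsed:?}"); // Some((20, "resid_post"))
}
```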

pub const fn config(&self) -> &SaeConfig

Access the SAE configuration.

pub const fn hook_point(&self) -> &HookPoint

The hook point this SAE targets.

pub const fn d_sae(&self) -> usize

Dictionary size (number of features).

pub const fn d_in(&self) -> usize

Input dimension.

pub fn encode(&self, x: &Tensor) -> Result<Tensor>

Encode activations into SAE feature space (dense output).

Applies the full encoder: pre_acts = x @ W_enc + b_enc, then the architecture-specific activation function (ReLU, JumpReLU, or TopK).

Uses TopKStrategy::Auto for TopK SAEs.

§Shapes
  • x: [..., d_in] — activations with any leading dimensions
  • returns: [..., d_sae] — encoded features (dense tensor; most entries are zero)
§Errors

Returns MIError::Config if the last dimension of x != d_in. Returns MIError::Model on tensor operation failure.
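The encoder described above can be sketched on plain Vec<f32> data for a single activation vector. The Activation enum, the row-major [d_in, d_sae] layout of w_enc, and clamping kept TopK values at zero are assumptions for illustration; the real method operates on candle Tensors with arbitrary leading dimensions, and its TopKStrategy handling is not reproduced here.

```rust
/// Architecture-specific activation applied to pre_acts (illustrative only).
enum Activation {
    /// max(pre_act, 0)
    Relu,
    /// Per-feature thresholds, length d_sae.
    JumpRelu(Vec<f32>),
    /// Keep the k largest pre-activations.
    TopK(usize),
}

/// Dense encode of one activation vector; `w_enc` is row-major [d_in, d_sae].
fn encode(x: &[f32], w_enc: &[f32], b_enc: &[f32], act: &Activation) -> Result<Vec<f32>, String> {
    let d_sae = b_enc.len();
    if w_enc.len() != x.len() * d_sae {
        return Err(format!("x has {} dims, which does not match W_enc", x.len()));
    }
    if let Activation::JumpRelu(theta) = act {
        if theta.len() != d_sae {
            return Err("threshold must have length d_sae".to_string());
        }
    }
    // pre_acts = x @ W_enc + b_enc
    let mut pre = b_enc.to_vec();
    for (i, &xi) in x.iter().enumerate() {
        let row = &w_enc[i * d_sae..(i + 1) * d_sae];
        for (p, &w) in pre.iter_mut().zip(row) {
            *p += xi * w;
        }
    }
    let out: Vec<f32> = match act {
        Activation::Relu => pre.iter().map(|&p| p.max(0.0)).collect(),
        // JumpReLU: keep a pre-activation only if it exceeds its feature's threshold.
        Activation::JumpRelu(theta) => pre
            .iter()
            .zip(theta)
            .map(|(&p, &t)| if p > t { p } else { 0.0 })
            .collect(),
        // TopK: keep the k largest pre-activations (clamped at 0), zero the rest.
        Activation::TopK(k) => {
            let mut order: Vec<usize> = (0..d_sae).collect();
            order.sort_by(|&a, &b| pre[b].total_cmp(&pre[a]));
            let mut kept = vec![0.0; d_sae];
            for &i in order.iter().take(*k) {
                kept[i] = pre[i].max(0.0);
            }
            kept
        }
    };
    Ok(out)
}

fn main() {
    // d_in = 2, d_sae = 3: W_enc rows are [1, 0, -1] and [0, 1, 1].
    let w_enc = [1.0, 0.0, -1.0, 0.0, 1.0, 1.0];
    let b_enc = [0.0, 0.0, -2.0];
    let x = [1.0, 2.0];
    let acts = [Activation::Relu, Activation::JumpRelu(vec![1.5, 1.5, 0.0]), Activation::TopK(1)];
    for act in &acts {
        println!("{:?}", encode(&x, &w_enc, &b_enc, act));
    }
}
```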

pub fn encode_with_strategy(
    &self,
    x: &Tensor,
    strategy: &TopKStrategy,
) -> Result<Tensor>

Encode activations with an explicit TopKStrategy.

Same as encode() but allows overriding the TopK computation strategy.

§Shapes
  • x: [..., d_in] — activations with any leading dimensions
  • returns: [..., d_sae] — encoded features (dense tensor; most entries are zero)
§Errors

Returns MIError::Config if the last dimension of x != d_in. Returns MIError::Model on tensor operation failure.

pub fn encode_sparse(
    &self,
    x: &Tensor,
) -> Result<SparseActivations<SaeFeatureId>>

Encode a single activation vector into sparse SAE features.

Returns only the non-zero features, sorted by descending magnitude.

§Shapes
§Errors

Returns MIError::Config if x has the wrong dimension. Returns MIError::Model on tensor operation failure.
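Given a dense encoding, the sparse view amounts to dropping zeros and ordering by magnitude. In this minimal sketch, (feature_index, activation) pairs stand in for SparseActivations<SaeFeatureId>, whose actual fields are not shown on this page, and leaving ties in ascending index order (via a stable sort) is an assumption about tie handling.

```rust
/// Keep the non-zero entries of a dense encoding as (feature_index, activation)
/// pairs, ordered by descending magnitude. A stand-in for encode_sparse's output.
fn to_sparse(dense: &[f32]) -> Vec<(usize, f32)> {
    let mut active: Vec<(usize, f32)> = dense
        .iter()
        .copied()
        .enumerate()
        .filter(|&(_, value)| value != 0.0)
        .collect();
    // sort_by is stable, so equal magnitudes keep ascending index order.
    active.sort_by(|a, b| b.1.abs().total_cmp(&a.1.abs()));
    active
}

fn main() {
    let dense = [0.0, 0.5, 0.0, 2.0, 0.5];
    for (feature, activation) in to_sparse(&dense) {
        println!("feature {feature}: {activation}");
    }
}
```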

pub fn decode(&self, features: &Tensor) -> Result<Tensor>

Decode SAE features back to activation space.

§Shapes
  • features: [..., d_sae] — encoded feature activations
  • returns: [..., d_in] — reconstructed activations
§Errors

Returns MIError::Model on tensor operation failure.
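Decoding is a matrix product back into activation space. The sketch below assumes W_dec is stored row-major as [d_sae, d_in] and that a decoder bias b_dec is added, as in the SAELens convention; this page does not spell out either detail for decode.

```rust
/// Decode one feature vector back to activation space.
/// Assumes `w_dec` is row-major [d_sae, d_in] and that b_dec is added.
fn decode(features: &[f32], w_dec: &[f32], b_dec: &[f32]) -> Vec<f32> {
    let d_in = b_dec.len();
    assert_eq!(w_dec.len(), features.len() * d_in, "W_dec must be [d_sae, d_in]");
    // Start from the decoder bias, then add f_j * W_dec[j] for each active feature.
    let mut out = b_dec.to_vec();
    for (j, &f) in features.iter().enumerate() {
        if f == 0.0 {
            continue; // inactive features contribute nothing
        }
        let row = &w_dec[j * d_in..(j + 1) * d_in];
        for (o, &w) in out.iter_mut().zip(row) {
            *o += f * w;
        }
    }
    out
}

fn main() {
    // d_sae = 3, d_in = 2
    let w_dec = [1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
    let b_dec = [0.5, 0.0];
    println!("{:?}", decode(&[0.0, 2.0, 1.0], &w_dec, &b_dec)); // [1.5, 3.0]
}
```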

pub fn reconstruct(&self, x: &Tensor) -> Result<Tensor>

Reconstruct activations through the SAE (encode then decode).

§Shapes
  • x: [..., d_in] — original activations
  • returns: [..., d_in] — reconstructed activations
§Errors

Returns MIError::Config if the last dimension of x != d_in. Returns MIError::Model on tensor operation failure.

pub fn reconstruction_error(&self, x: &Tensor) -> Result<f64>

Compute reconstruction MSE loss.

§Shapes
  • x: [..., d_in] — original activations
  • returns: scalar f64 mean squared error
§Errors

Returns MIError::Config if the last dimension of x != d_in. Returns MIError::Model on tensor operation failure.
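A minimal sketch of the error metric, assuming the mean is taken over every element of x (all leading dimensions and d_in together); the page does not specify the reduction. Here x_hat plays the role of reconstruct(x) flattened into a slice, and the sum is accumulated in f64 to match the f64 return type.

```rust
/// Mean squared error between activations and their reconstruction,
/// averaged over all elements (assumed reduction).
fn mse(x: &[f32], x_hat: &[f32]) -> f64 {
    assert_eq!(x.len(), x_hat.len(), "shapes must match");
    assert!(!x.is_empty(), "MSE of an empty tensor is undefined");
    let sum: f64 = x
        .iter()
        .zip(x_hat)
        .map(|(&a, &b)| {
            let d = f64::from(a) - f64::from(b);
            d * d
        })
        .sum();
    sum / x.len() as f64
}

fn main() {
    let x = [1.0, 2.0, 3.0, 4.0];
    let x_hat = [1.0, 2.0, 3.0, 2.0];
    println!("{}", mse(&x, &x_hat)); // prints 1
}
```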

pub fn decoder_vector(&self, feature_idx: usize) -> Result<Tensor>

Extract a single feature’s decoder vector (steering direction).

§Shapes
  • returns: [d_in] — decoder vector on the SAE’s device
§Errors

Returns MIError::Config if feature_idx >= d_sae. Returns MIError::Model on tensor operation failure.
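Assuming W_dec is laid out as [d_sae, d_in], a feature's decoder vector is simply one row of it. This sketch shows the slice and the feature_idx >= d_sae bounds check on a flat slice of f32; decoder_vector here is a hypothetical free function, not the method above.

```rust
/// Return row `feature_idx` of a row-major [d_sae, d_in] decoder matrix.
fn decoder_vector(w_dec: &[f32], d_in: usize, feature_idx: usize) -> Result<&[f32], String> {
    assert!(d_in > 0 && w_dec.len() % d_in == 0, "W_dec must be [d_sae, d_in]");
    let d_sae = w_dec.len() / d_in;
    if feature_idx >= d_sae {
        return Err(format!("feature_idx {feature_idx} >= d_sae {d_sae}"));
    }
    Ok(&w_dec[feature_idx * d_in..(feature_idx + 1) * d_in])
}

fn main() {
    // d_sae = 3, d_in = 2
    let w_dec = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    println!("{:?}", decoder_vector(&w_dec, 2, 1)); // Ok([3.0, 4.0])
    println!("{:?}", decoder_vector(&w_dec, 2, 3)); // Err("feature_idx 3 >= d_sae 3")
}
```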

pub fn prepare_hook_injection(
    &self,
    features: &[(usize, f32)],
    position: usize,
    seq_len: usize,
    device: &Device,
) -> Result<HookSpec>

Build a HookSpec that injects SAE decoder vectors into the model.

Creates an Intervention::Add at this SAE's hook point. The added tensor is zero everywhere except at position, where it holds the sum of the requested features' decoder vectors, each scaled by its strength.

§Shapes
  • Internally constructs [1, seq_len, d_in] with the vector at position.
§Arguments
  • features — List of (feature_index, strength) pairs
  • position — Token position in the sequence to inject at
  • seq_len — Total sequence length
  • device — Device to construct injection tensors on
§Errors

Returns MIError::Config if any feature_index >= d_sae. Returns MIError::Model on tensor construction failure.
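The injection tensor described above can be sketched without candle: zeros everywhere except at position, where each (feature_index, strength) pair adds strength times that feature's decoder row. In this sketch the batch dimension of 1 is dropped, W_dec is assumed to be row-major [d_sae, d_in], and the position < seq_len check is an added assumption rather than documented behavior.

```rust
/// Build the [seq_len, d_in] delta (batch dimension of 1 dropped) that adds
/// sum_i strength_i * W_dec[feature_i] at `position` and zeros elsewhere.
fn injection_delta(
    w_dec: &[f32],
    d_in: usize,
    features: &[(usize, f32)],
    position: usize,
    seq_len: usize,
) -> Result<Vec<f32>, String> {
    assert!(d_in > 0, "d_in must be positive");
    let d_sae = w_dec.len() / d_in;
    // Assumption: out-of-range positions are rejected rather than ignored.
    if position >= seq_len {
        return Err(format!("position {position} >= seq_len {seq_len}"));
    }
    let mut delta = vec![0.0f32; seq_len * d_in];
    let slot = &mut delta[position * d_in..(position + 1) * d_in];
    for &(idx, strength) in features {
        if idx >= d_sae {
            return Err(format!("feature index {idx} >= d_sae {d_sae}"));
        }
        // Accumulate the scaled decoder row into the target position.
        let row = &w_dec[idx * d_in..(idx + 1) * d_in];
        for (s, &w) in slot.iter_mut().zip(row) {
            *s += strength * w;
        }
    }
    Ok(delta)
}

fn main() {
    // d_sae = 2, d_in = 2; inject 2 * feature 0 - 1 * feature 1 at position 1 of 3.
    let w_dec = [1.0, 0.0, 0.0, 1.0];
    println!("{:?}", injection_delta(&w_dec, 2, &[(0, 2.0), (1, -1.0)], 1, 3));
}
```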

Blanket Implementations§

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T> Instrument for T

fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper.

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

impl<T> Pointable for T

const ALIGN: usize

The alignment of pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

impl<T> PolicyExt for T
where T: ?Sized,

fn and<P, B, E>(self, other: P) -> And<T, P>
where T: Policy<B, E>, P: Policy<B, E>,

Create a new Policy that returns Action::Follow only if self and other return Action::Follow.

fn or<P, B, E>(self, other: P) -> Or<T, P>
where T: Policy<B, E>, P: Policy<B, E>,

Create a new Policy that returns Action::Follow if either self or other returns Action::Follow.

impl<T> Same for T

type Output = T

Should always be Self

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V

impl<T> WithSubscriber for T

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self>
where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper.

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper.

impl<T> ErasedDestructor for T
where T: 'static,