pub struct SparseAutoencoder { /* private fields */ }
A Sparse Autoencoder for mechanistic interpretability.
Loads SAE weights from SAELens-format safetensors + cfg.json,
encodes model activations into sparse feature vectors, decodes
back to activation space, and produces steering vectors for injection.
Each SAE targets a single hook point in the model (e.g., resid_post
at layer 5). Multiple SAEs can be loaded independently for different
hook points.
§Example
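```rust
use candle_core::Device;
use candle_mi::sae::SparseAutoencoder;

// The boxed error return assumes MIError implements std::error::Error.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let sae = SparseAutoencoder::from_pretrained(
        "jbloom/Gemma-2-2B-Residual-Stream-SAEs",
        "gemma-2-2b-res-jb/blocks.20.hook_resid_post",
        &Device::Cpu,
    )?;
    println!("SAE: d_in={}, d_sae={}", sae.d_in(), sae.d_sae());
    Ok(())
}
```
Implementations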
impl SparseAutoencoder
pub fn from_local(dir: &Path, device: &Device) -> Result<Self>
Load an SAE from a local directory containing safetensors + cfg.json.
Expects either sae_weights.safetensors or model.safetensors
plus a cfg.json file.
§Errors
Returns MIError::Config if cfg.json is missing or malformed.
Returns MIError::Config if weight shapes don’t match cfg.json dimensions.
Returns MIError::Model on tensor loading failure.
Returns MIError::Io if files cannot be read.
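The file-discovery step described above can be sketched with std only. This is a minimal sketch, not candle_mi's code: the helper name `resolve_sae_files` is hypothetical, and preferring `sae_weights.safetensors` over `model.safetensors` when both exist is an assumption (the docs only say either name is accepted).

```rust
use std::path::{Path, PathBuf};

/// Hypothetical helper: locate the weights file and cfg.json in an SAE directory.
/// Assumption: sae_weights.safetensors is preferred when both weight files exist.
fn resolve_sae_files(dir: &Path) -> Result<(PathBuf, PathBuf), String> {
    let cfg = dir.join("cfg.json");
    if !cfg.is_file() {
        return Err(format!("missing {}", cfg.display()));
    }
    // Try each accepted file name in order and keep the first one present.
    let weights = ["sae_weights.safetensors", "model.safetensors"]
        .iter()
        .map(|name| dir.join(name))
        .find(|p| p.is_file())
        .ok_or_else(|| format!("no safetensors weights in {}", dir.display()))?;
    Ok((weights, cfg))
}
```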
pub fn from_npz(
    npz_path: &Path,
    hook_layer: usize,
    device: &Device,
) -> Result<Self>
Load an SAE from a Gemma Scope NPZ file (params.npz).
The NPZ file must contain W_enc, W_dec, b_enc, b_dec arrays,
and optionally threshold (for JumpReLU). Config is inferred from
tensor shapes since NPZ files have no cfg.json.
§Arguments
- npz_path: Path to the params.npz file
- hook_layer: Which model layer this SAE hooks into
- device: Target device (CPU or CUDA)
§Errors
Returns MIError::Config if required tensors are missing or shapes
are inconsistent.
Returns MIError::Io if the file cannot be read.
pub fn from_pretrained_npz(
    repo_id: &str,
    npz_path: &str,
    hook_layer: usize,
    device: &Device,
) -> Result<Self>
Load an SAE from a HuggingFace repository containing an NPZ file.
Downloads the NPZ file via hf-fetch-model, then delegates to
from_npz.
§Arguments
- repo_id: HuggingFace repository ID (e.g., "google/gemma-scope-2b-pt-res")
- npz_path: Path within the repo to the NPZ file (e.g., "layer_0/width_16k/average_l0_105/params.npz")
- hook_layer: Which model layer this SAE hooks into
- device: Target device (CPU or CUDA)
§Errors
Returns MIError::Download if the file cannot be fetched.
Returns MIError::Config if the NPZ format is invalid.
pub fn from_pretrained(
    repo_id: &str,
    sae_id: &str,
    device: &Device,
) -> Result<Self>
Load an SAE from a HuggingFace repository.
Downloads safetensors + cfg.json via hf-fetch-model, then delegates
to from_local.
§Arguments
- repo_id: HuggingFace repository ID (e.g., "jbloom/Gemma-2-2B-Residual-Stream-SAEs")
- sae_id: Subdirectory within the repo (e.g., "gemma-2-2b-res-jb/blocks.20.hook_resid_post")
- device: Target device (CPU or CUDA)
§Errors
Returns MIError::Download if files cannot be fetched.
Returns MIError::Config if the SAE format is invalid.
pub const fn hook_point(&self) -> &HookPoint
The hook point this SAE targets.
pub fn encode(&self, x: &Tensor) -> Result<Tensor>
Encode activations into SAE feature space (dense output).
Applies the full encoder: pre_acts = x @ W_enc + b_enc, then the
architecture-specific activation function (ReLU, JumpReLU, or TopK).
Uses TopKStrategy::Auto for TopK SAEs.
§Shapes
- x: [..., d_in] (activations with any leading dimensions)
- returns: [..., d_sae] (encoded features as a dense tensor, mostly zeros)
§Errors
Returns MIError::Config if the last dimension of x != d_in.
Returns MIError::Model on tensor operation failure.
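The encoder math above can be illustrated on a single vector with plain slices. This is a std-only sketch of pre_acts = x @ W_enc + b_enc followed by the three activation functions, not candle_mi's tensor implementation; the `Activation` enum is hypothetical, W_enc is assumed row-major [d_in, d_sae], and applying ReLU to the kept TopK values follows the common SAELens convention (an assumption here).

```rust
/// Hypothetical stand-in for the three architectures named above.
enum Activation {
    Relu,
    /// One threshold per feature (length d_sae).
    JumpRelu { threshold: Vec<f32> },
    TopK { k: usize },
}

/// pre_acts = x @ W_enc + b_enc for one vector; `w_enc` is row-major [d_in, d_sae].
fn pre_activations(x: &[f32], w_enc: &[f32], b_enc: &[f32]) -> Vec<f32> {
    let d_sae = b_enc.len();
    let mut out = b_enc.to_vec();
    for (i, &xi) in x.iter().enumerate() {
        let row = &w_enc[i * d_sae..(i + 1) * d_sae];
        for (o, &w) in out.iter_mut().zip(row) {
            *o += xi * w;
        }
    }
    out
}

fn activate(pre: &[f32], act: &Activation) -> Vec<f32> {
    match act {
        Activation::Relu => pre.iter().map(|&v| v.max(0.0)).collect(),
        // JumpReLU keeps a pre-activation only if it exceeds its feature's threshold.
        Activation::JumpRelu { threshold } => pre
            .iter()
            .zip(threshold)
            .map(|(&v, &t)| if v > t { v } else { 0.0 })
            .collect(),
        // TopK keeps the k largest pre-activations (ReLU'd) and zeroes the rest.
        Activation::TopK { k } => {
            let mut idx: Vec<usize> = (0..pre.len()).collect();
            idx.sort_by(|&a, &b| pre[b].total_cmp(&pre[a]));
            let mut out = vec![0.0; pre.len()];
            for &i in idx.iter().take(*k) {
                out[i] = pre[i].max(0.0);
            }
            out
        }
    }
}
```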
pub fn encode_with_strategy(
    &self,
    x: &Tensor,
    strategy: &TopKStrategy,
) -> Result<Tensor>
Encode activations with an explicit TopKStrategy.
Same as encode() but allows overriding the TopK
computation strategy.
§Shapes
- x: [..., d_in] (activations with any leading dimensions)
- returns: [..., d_sae] (encoded features as a dense tensor, mostly zeros)
§Errors
Returns MIError::Config if the last dimension of x != d_in.
Returns MIError::Model on tensor operation failure.
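The variants of TopKStrategy are not listed here, so the following only illustrates why a TopK computation can have more than one strategy. It is a std-only sketch with a hypothetical `TopKSketch` enum: a full O(n log n) sort versus an average O(n) partial selection with `slice::select_nth_unstable_by`. Both return the same index set when values are distinct; ties may be broken differently.

```rust
/// Hypothetical stand-in for TopKStrategy; the real variants may differ.
enum TopKSketch {
    FullSort,
    PartialSelect,
}

/// Indices of the k largest values (order within the result is not guaranteed).
fn top_k_indices(v: &[f32], k: usize, strategy: &TopKSketch) -> Vec<usize> {
    let k = k.min(v.len());
    let mut idx: Vec<usize> = (0..v.len()).collect();
    // Comparator that orders indices by descending value.
    let desc = |a: &usize, b: &usize| v[*b].total_cmp(&v[*a]);
    match strategy {
        // Sort every index, then keep the first k.
        TopKSketch::FullSort => idx.sort_by(desc),
        // Partition so the k largest end up in idx[..k], without a full sort.
        TopKSketch::PartialSelect => {
            if k > 0 && k < idx.len() {
                idx.select_nth_unstable_by(k - 1, desc);
            }
        }
    }
    idx.truncate(k);
    idx
}
```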
pub fn encode_sparse(
    &self,
    x: &Tensor,
) -> Result<SparseActivations<SaeFeatureId>>
Encode a single activation vector into sparse SAE features.
Returns only non-zero features, sorted by magnitude descending.
§Shapes
- x: [d_in] (a single activation vector)
- returns: SparseActivations<SaeFeatureId> holding (SaeFeatureId, f32) pairs
§Errors
Returns MIError::Config if x does not have shape [d_in].
Returns MIError::Model on tensor operation failure.
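The sparse output described above (non-zero features only, sorted by magnitude, largest first) amounts to the following step on a dense feature vector. A std-only sketch using plain `(usize, f32)` pairs in place of SparseActivations<SaeFeatureId>; the function name `to_sparse` is hypothetical.

```rust
/// Collects (feature_index, value) pairs for the non-zero entries of a dense
/// feature vector, sorted by absolute value, largest first.
fn to_sparse(dense: &[f32]) -> Vec<(usize, f32)> {
    let mut pairs: Vec<(usize, f32)> = dense
        .iter()
        .enumerate()
        .filter(|&(_, &v)| v != 0.0) // drop inactive features
        .map(|(i, &v)| (i, v))
        .collect();
    pairs.sort_by(|a, b| b.1.abs().total_cmp(&a.1.abs()));
    pairs
}
```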
pub fn decode(&self, features: &Tensor) -> Result<Tensor>
Decode SAE features back to activation space.
§Shapes
- features: [..., d_sae] (encoded feature activations)
- returns: [..., d_in] (reconstructed activations)
§Errors
Returns MIError::Model on tensor operation failure.
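For a single feature vector, decoding is a matrix-vector product with the decoder weights. This std-only sketch assumes W_dec is row-major [d_sae, d_in] and that b_dec is added, which is the usual SAELens convention but is not stated in these docs; `decode` here is a hypothetical free function, not the method above.

```rust
/// Sketch of x_hat = f @ W_dec + b_dec for one vector.
/// `w_dec` is row-major [d_sae, d_in]; adding b_dec is an assumed convention.
fn decode(features: &[f32], w_dec: &[f32], b_dec: &[f32]) -> Vec<f32> {
    let d_in = b_dec.len();
    let mut out = b_dec.to_vec();
    for (j, &f) in features.iter().enumerate() {
        if f == 0.0 {
            continue; // inactive feature: its decoder row contributes nothing
        }
        let row = &w_dec[j * d_in..(j + 1) * d_in];
        for (o, &w) in out.iter_mut().zip(row) {
            *o += f * w;
        }
    }
    out
}
```

Skipping zero features is only an optimization; the result is identical to the full product.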
pub fn reconstruct(&self, x: &Tensor) -> Result<Tensor>
Reconstruct activations through the SAE (encode then decode).
§Shapes
- x: [..., d_in] (original activations)
- returns: [..., d_in] (reconstructed activations)
§Errors
Returns MIError::Config if the last dimension of x != d_in.
Returns MIError::Model on tensor operation failure.
pub fn reconstruction_error(&self, x: &Tensor) -> Result<f64>
Compute the reconstruction loss: the mean squared error between x and its reconstruction.
§Shapes
- x: [..., d_in] (original activations)
- returns: scalar f64 mean squared error
§Errors
Returns MIError::Config if the last dimension of x != d_in.
Returns MIError::Model on tensor operation failure.
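Once x and its reconstruction are flattened, the MSE is a single reduction. This std-only sketch assumes the mean is taken over every element (all leading dimensions flattened) and accumulates in f64; both are assumptions rather than documented behaviour, and `mse` is a hypothetical helper.

```rust
/// Mean squared error over all elements of x and its reconstruction x_hat.
/// Averaging over every element and accumulating in f64 are assumptions.
fn mse(x: &[f32], x_hat: &[f32]) -> f64 {
    assert_eq!(x.len(), x_hat.len(), "shape mismatch");
    if x.is_empty() {
        return 0.0;
    }
    let sum: f64 = x
        .iter()
        .zip(x_hat)
        .map(|(&a, &b)| {
            let d = f64::from(a) - f64::from(b);
            d * d
        })
        .sum();
    sum / x.len() as f64
}
```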
pub fn decoder_vector(&self, feature_idx: usize) -> Result<Tensor>
Extract a single feature’s decoder vector (steering direction).
§Shapes
- returns: [d_in] (decoder vector on the SAE’s device)
§Errors
Returns MIError::Config if feature_idx >= d_sae.
Returns MIError::Model on tensor operation failure.
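A feature's steering direction is one row of the decoder matrix. The std-only sketch below shows that slice and the feature_idx >= d_sae check on a row-major [d_sae, d_in] buffer; the layout and the helper name `decoder_row` are assumptions for illustration.

```rust
/// Row `feature_idx` of a row-major W_dec [d_sae, d_in]: the direction this
/// feature writes into activation space.
fn decoder_row(w_dec: &[f32], d_in: usize, feature_idx: usize) -> Result<&[f32], String> {
    assert!(d_in > 0 && w_dec.len() % d_in == 0, "W_dec must be [d_sae, d_in]");
    let d_sae = w_dec.len() / d_in;
    if feature_idx >= d_sae {
        return Err(format!("feature_idx {feature_idx} >= d_sae {d_sae}"));
    }
    Ok(&w_dec[feature_idx * d_in..(feature_idx + 1) * d_in])
}
```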
pub fn prepare_hook_injection(
    &self,
    features: &[(usize, f32)],
    position: usize,
    seq_len: usize,
    device: &Device,
) -> Result<HookSpec>
Build a HookSpec that injects SAE decoder vectors into the model.
Creates an Intervention::Add at this SAE’s hook point with the
accumulated (scaled) decoder vectors placed at the given position.
§Shapes
- Internally constructs a [1, seq_len, d_in] tensor with the vector at position.
§Arguments
- features: List of (feature_index, strength) pairs
- position: Token position in the sequence to inject at
- seq_len: Total sequence length
- device: Device to construct injection tensors on
§Errors
Returns MIError::Config if any feature_index >= d_sae.
Returns MIError::Model on tensor construction failure.
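The injection tensor described above can be pictured as a flat [1, seq_len, d_in] buffer in which only the slot at `position` holds the summed, strength-scaled decoder rows. This std-only sketch is hypothetical (`injection_buffer` is not a candle_mi function): zero-filling the other positions and rejecting position >= seq_len are assumptions, and W_dec is taken as row-major [d_sae, d_in].

```rust
/// Flat buffer for a [1, seq_len, d_in] additive injection: zeros except at
/// `position`, which receives sum_i strength_i * W_dec[feature_i].
fn injection_buffer(
    w_dec: &[f32],
    d_in: usize,
    features: &[(usize, f32)],
    position: usize,
    seq_len: usize,
) -> Result<Vec<f32>, String> {
    assert!(d_in > 0 && w_dec.len() % d_in == 0, "W_dec must be [d_sae, d_in]");
    if position >= seq_len {
        return Err(format!("position {position} >= seq_len {seq_len}"));
    }
    let d_sae = w_dec.len() / d_in;
    let mut buf = vec![0.0f32; seq_len * d_in];
    // Mutable view of the d_in values belonging to the target token.
    let slot = &mut buf[position * d_in..(position + 1) * d_in];
    for &(idx, strength) in features {
        if idx >= d_sae {
            return Err(format!("feature_index {idx} >= d_sae {d_sae}"));
        }
        let row = &w_dec[idx * d_in..(idx + 1) * d_in];
        for (s, &w) in slot.iter_mut().zip(row) {
            *s += strength * w;
        }
    }
    Ok(buf)
}
```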
Auto Trait Implementations
impl Freeze for SparseAutoencoder
impl !RefUnwindSafe for SparseAutoencoder
impl Send for SparseAutoencoder
impl Sync for SparseAutoencoder
impl Unpin for SparseAutoencoder
impl UnsafeUnpin for SparseAutoencoder
impl !UnwindSafe for SparseAutoencoder
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise.