§candle-mi
Mechanistic interpretability for language models in Rust, built on candle.
candle-mi re-implements model forward passes with built-in hook points
(following the TransformerLens
design), enabling activation capture, attention knockout, steering, logit
lens, and sparse-feature analysis (CLTs and SAEs) — all in pure Rust with
GPU acceleration.
§Supported backends
| Backend | Models | Feature flag |
|---|---|---|
| GenericTransformer | LLaMA, Qwen2, Gemma, Gemma 2, Phi-3, StarCoder2, Mistral (+ auto-config for unknown families) | transformer |
| GenericRwkv | RWKV-6 (Finch), RWKV-7 (Goose) | rwkv |
See BACKENDS.md for how to add a new model architecture.
§Feature flags
| Feature | Default | Description |
|---|---|---|
| transformer | yes | Generic transformer backend (decoder-only) |
| cuda | yes | CUDA GPU acceleration |
| rwkv | no | RWKV-6/7 linear RNN backend |
| rwkv-tokenizer | no | RWKV world tokenizer (required for RWKV inference) |
| clt | no | Cross-Layer Transcoder support |
| sae | no | Sparse Autoencoder support |
| mmap | no | Memory-mapped weight loading (required for sharded models) |
| memory | no | RAM/VRAM memory reporting |
| probing | no | Linear probing via linfa (experimental) |
| metal | no | Apple Metal GPU acceleration |
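Non-default features are enabled through Cargo feature flags in the usual way. For example (an illustrative Cargo.toml fragment; the version is a placeholder to replace with the release you use):

```toml
[dependencies]
candle-mi = { version = "*", features = ["rwkv", "rwkv-tokenizer", "sae"] }
```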
§Quick start
Load a model, run a forward pass, and inspect the output:
```rust
use candle_mi::{HookSpec, MIModel};

let model = MIModel::from_pretrained("meta-llama/Llama-3.2-1B")?;
let tokenizer = model.tokenizer().unwrap();
let tokens = tokenizer.encode("The capital of France is")?;
let input = candle_core::Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;

let cache = model.forward(&input, &HookSpec::new())?;
let logits = cache.output(); // [1, seq, vocab]
let last_logits = logits.get(0)?.get(tokens.len() - 1)?;
let token_id = candle_mi::sample_token(&last_logits, 0.0)?; // greedy
println!("{}", tokenizer.decode(&[token_id])?); // " Paris"
```

§Activation capture
Use HookSpec::capture to snapshot tensors at any
HookPoint during the forward pass:
```rust
use candle_mi::{HookPoint, HookSpec, MIModel};

let mut hooks = HookSpec::new();
hooks.capture(HookPoint::AttnPattern(5))  // post-softmax attention
    .capture(HookPoint::ResidPost(10));   // residual stream at layer 10

let cache = model.forward(&input, &hooks)?;
let attn = cache.require(&HookPoint::AttnPattern(5))?;  // [1, heads, seq, seq]
let resid = cache.require(&HookPoint::ResidPost(10))?;  // [1, seq, hidden]
```

§Interventions
Use HookSpec::intervene to modify activations mid-forward-pass.
Five intervention types are available: Intervention::Replace,
Intervention::Add, Intervention::Knockout,
Intervention::Scale, and Intervention::Zero.
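To quantify how much an ablation shifts the model's predictions, compare the clean and ablated next-token distributions; the crate re-exports a kl_divergence helper for this purpose. A self-contained sketch of the underlying computation over plain f32 slices (illustrative only, not the crate's implementation):

```rust
/// KL(p || q) for two discrete probability distributions.
/// Assumes both slices sum to 1 and q is nonzero wherever p is nonzero.
fn kl_divergence(p: &[f32], q: &[f32]) -> f32 {
    p.iter()
        .zip(q)
        .filter(|(pi, _)| **pi > 0.0) // 0 * ln(0/q) contributes nothing
        .map(|(pi, qi)| pi * (pi / qi).ln())
        .sum()
}

fn main() {
    let clean = [0.7, 0.2, 0.1];
    let ablated = [0.4, 0.4, 0.2];
    println!("{:.4}", kl_divergence(&clean, &ablated));
}
```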
```rust
use candle_mi::{HookPoint, HookSpec, Intervention, KnockoutSpec, create_knockout_mask};

// Knock out the attention edge: last token cannot attend to position 0
let spec = KnockoutSpec::new().layer(8).edge(seq_len - 1, 0);
let mask = create_knockout_mask(
    &spec, model.num_heads(), seq_len, model.device(), candle_core::DType::F32,
)?;

let mut hooks = HookSpec::new();
hooks.intervene(HookPoint::AttnScores(8), Intervention::Knockout(mask));
let ablated = model.forward(&input, &hooks)?;
```

§Logit lens
Project intermediate residual streams to vocabulary space using
MIModel::project_to_vocab:
```rust
use candle_mi::{HookPoint, HookSpec, MIModel};

let mut hooks = HookSpec::new();
for layer in 0..model.num_layers() {
    hooks.capture(HookPoint::ResidPost(layer));
}
let cache = model.forward(&input, &hooks)?;

for layer in 0..model.num_layers() {
    let resid = cache.require(&HookPoint::ResidPost(layer))?;
    let last = resid.get(0)?.get(seq_len - 1)?.unsqueeze(0)?;
    let logits = model.project_to_vocab(&last)?;
    let token_id = candle_mi::sample_token(&logits.flatten_all()?, 0.0)?;
    println!("Layer {layer:>2}: {}", tokenizer.decode(&[token_id])?);
}
```

§Fast downloads
candle-mi uses hf-fetch-model
for high-throughput parallel downloads from the HuggingFace Hub:
```rust
// Async: parallel chunked download with progress bars
let _path = candle_mi::download_model("meta-llama/Llama-3.2-1B".to_owned()).await?;

// Sync: blocking variant (uses local HF cache if already downloaded)
candle_mi::download_model_blocking("meta-llama/Llama-3.2-1B".to_owned())?;
let model = candle_mi::MIModel::from_pretrained("meta-llama/Llama-3.2-1B")?;
```

§Further reading
- HOOKS.md — complete hook point reference with shapes, intervention walkthrough, and worked examples.
- BACKENDS.md — how to add a new model architecture (auto-config, config parser, or custom MIBackend).
- examples/README.md — 15 runnable examples covering inference, logit lens, attention patterns, knockout, steering, activation patching, CLT circuits, SAE encoding, RWKV inference, and more.
Re-exports§
pub use backend::GenerationResult;
pub use backend::MIBackend;
pub use backend::MIModel;
pub use backend::extract_token_prob;
pub use backend::sample_token;
pub use config::Activation;
pub use config::CompatibilityReport;
pub use config::MlpLayout;
pub use config::NormType;
pub use config::QkvLayout;
pub use config::SUPPORTED_MODEL_TYPES;
pub use config::TransformerConfig;
pub use transformer::GenericTransformer;
pub use rwkv::GenericRwkv;
pub use rwkv::RwkvConfig;
pub use rwkv::RwkvLoraDims;
pub use rwkv::RwkvVersion;
pub use sparse::FeatureId;
pub use sparse::SparseActivations;
pub use clt::AttributionEdge;
pub use clt::AttributionGraph;
pub use clt::CltConfig;
pub use clt::CltFeatureId;
pub use clt::CrossLayerTranscoder;
pub use sae::NormalizeActivations;
pub use sae::SaeArchitecture;
pub use sae::SaeConfig;
pub use sae::SaeFeatureId;
pub use sae::SparseAutoencoder;
pub use sae::TopKStrategy;
pub use cache::ActivationCache;
pub use cache::AttentionCache;
pub use cache::FullActivationCache;
pub use cache::KVCache;
pub use error::MIError;
pub use error::Result;
pub use hooks::HookCache;
pub use hooks::HookPoint;
pub use hooks::HookSpec;
pub use hooks::Intervention;
pub use interp::intervention::AblationResult;
pub use interp::intervention::AttentionEdge;
pub use interp::intervention::HeadSpec;
pub use interp::intervention::InterventionType;
pub use interp::intervention::KnockoutSpec;
pub use interp::intervention::LayerSpec;
pub use interp::intervention::StateAblationResult;
pub use interp::intervention::StateKnockoutSpec;
pub use interp::intervention::StateSteeringResult;
pub use interp::intervention::StateSteeringSpec;
pub use interp::intervention::SteeringResult;
pub use interp::intervention::SteeringSpec;
pub use interp::intervention::apply_steering;
pub use interp::intervention::create_knockout_mask;
pub use interp::intervention::kl_divergence;
pub use interp::intervention::measure_attention_to_targets;
pub use interp::logit_lens::LogitLensAnalysis;
pub use interp::logit_lens::LogitLensResult;
pub use interp::logit_lens::TokenPrediction;
pub use interp::steering::DoseResponseCurve;
pub use interp::steering::DoseResponsePoint;
pub use interp::steering::SteeringCalibration;
pub use tokenizer::MITokenizer;
pub use memory::MemoryReport;
pub use memory::MemorySnapshot;
pub use download::download_model;
pub use download::download_model_blocking;
Modules§
- backend
- Core backend trait and model wrapper.
- cache
- Activation, attention, and KV caching for efficient forward passes.
- clt
- Cross-Layer Transcoder (CLT) support.
- config
- Transformer configuration and HuggingFace config.json parsing.
- download
- Fast model download via hf-fetch-model.
- error
- Error types for candle-mi.
- hooks
- Hook system for activation capture and intervention.
- interp
- Interpretability tools: intervention, logit lens, steering calibration.
- memory
- Process and GPU memory reporting.
- rwkv
- RWKV gated-linear RNN backend.
- sae
- Sparse Autoencoder (SAE) support.
- sparse
- Shared sparse-feature types used by both CLT and SAE modules.
- tokenizer
- Tokenizer abstraction: dispatch between HuggingFace and RWKV backends.
- transformer
- Generic transformer implementation.
Structs§
- EncodingWithOffsets
- Encoding result with tokens and their character offsets.
- PcaResult
- Result of a PCA decomposition via power iteration.
- PositionConversion
- Result of converting a character position to a token index.
- RecurrentFeedbackEntry
- A single feedback injection between recurrent passes.
- RecurrentPassSpec
- Specification for a recurrent multi-pass forward through a layer block.
- TokenWithOffset
- Token with its character offset range.
Functions§
- clear_mask_caches
- Clear all cached masks.
- convert_positions
- Convert multiple character positions to token indices.
- create_causal_mask
- Create or retrieve a cached causal mask for the given sequence length.
- create_generation_mask
- Create a causal mask for generation with KV-cache.
- pca_top_k
- Compute the top k principal components of matrix via power iteration with deflation on the kernel matrix.
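As background for pca_top_k, here is a self-contained sketch of plain power iteration on a small symmetric matrix (f64 throughout, no deflation and no kernel trick; illustrative only, not the crate's implementation):

```rust
/// Power-iteration estimate of the dominant eigenpair of a
/// symmetric matrix `a` (n x n, given as rows). Illustrative sketch only.
fn power_iteration(a: &[Vec<f64>], iters: usize) -> (f64, Vec<f64>) {
    let n = a.len();
    let mut v = vec![1.0; n];
    for _ in 0..iters {
        // w = A v
        let w: Vec<f64> = a
            .iter()
            .map(|row| row.iter().zip(&v).map(|(x, y)| x * y).sum())
            .collect();
        // normalize so the iterate does not blow up or vanish
        let norm = w.iter().map(|x| x * x).sum::<f64>().sqrt();
        v = w.iter().map(|x| x / norm).collect();
    }
    // Rayleigh quotient v^T A v estimates the dominant eigenvalue
    let av: Vec<f64> = a
        .iter()
        .map(|row| row.iter().zip(&v).map(|(x, y)| x * y).sum())
        .collect();
    let lambda = av.iter().zip(&v).map(|(x, y)| x * y).sum();
    (lambda, v)
}

fn main() {
    // Dominant eigenvalue of [[2, 1], [1, 2]] is 3, eigenvector [1, 1]/sqrt(2)
    let a = vec![vec![2.0, 1.0], vec![1.0, 2.0]];
    let (lambda, _v) = power_iteration(&a, 50);
    println!("{lambda:.4}"); // ~3.0000
}
```

Subsequent components would be obtained by deflating the matrix (subtracting lambda * v * v^T) and iterating again; pca_top_k's docs indicate it does this on the kernel matrix.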