Crate candle_mi

§candle-mi

Mechanistic interpretability for language models in Rust, built on candle.

candle-mi re-implements model forward passes with built-in hook points (following the TransformerLens design), enabling activation capture, attention knockout, steering, logit lens, and sparse-feature analysis (CLTs and SAEs) — all in pure Rust with GPU acceleration.

§Supported backends

| Backend | Models | Feature flag |
|---|---|---|
| `GenericTransformer` | LLaMA, Qwen2, Gemma, Gemma 2, Phi-3, StarCoder2, Mistral (+ auto-config for unknown families) | `transformer` |
| `GenericRwkv` | RWKV-6 (Finch), RWKV-7 (Goose) | `rwkv` |

See BACKENDS.md for how to add a new model architecture.

§Feature flags

| Feature | Default | Description |
|---|---|---|
| `transformer` | yes | Generic transformer backend (decoder-only) |
| `cuda` | yes | CUDA GPU acceleration |
| `rwkv` | no | RWKV-6/7 linear RNN backend |
| `rwkv-tokenizer` | no | RWKV world tokenizer (required for RWKV inference) |
| `clt` | no | Cross-Layer Transcoder support |
| `sae` | no | Sparse Autoencoder support |
| `mmap` | no | Memory-mapped weight loading (required for sharded models) |
| `memory` | no | RAM/VRAM memory reporting |
| `probing` | no | Linear probing via linfa (experimental) |
| `metal` | no | Apple Metal GPU acceleration |

§Quick start

Load a model, run a forward pass, and inspect the output:

```rust
use candle_mi::{HookSpec, MIModel};

let model = MIModel::from_pretrained("meta-llama/Llama-3.2-1B")?;
let tokenizer = model.tokenizer().unwrap();

let tokens = tokenizer.encode("The capital of France is")?;
let input = candle_core::Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;

let cache = model.forward(&input, &HookSpec::new())?;
let logits = cache.output();  // [1, seq, vocab]

let last_logits = logits.get(0)?.get(tokens.len() - 1)?;
let token_id = candle_mi::sample_token(&last_logits, 0.0)?;  // greedy
println!("{}", tokenizer.decode(&[token_id])?);  // " Paris"
```
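At temperature 0.0, sampling reduces to an argmax over the final-position logits. A minimal pure-Rust sketch of that selection step (an illustration only — candle-mi's `sample_token` operates on a `Tensor`, not a slice):

```rust
/// Greedy selection over a logits slice: the index of the maximum value.
/// This mirrors what temperature-0 ("greedy") sampling does conceptually.
fn argmax(logits: &[f32]) -> usize {
    logits
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(i, _)| i)
        .expect("logits must be non-empty")
}

fn main() {
    let logits = [0.1_f32, 2.5, -1.0, 0.7];
    // The highest logit (index 1) wins.
    assert_eq!(argmax(&logits), 1);
    println!("greedy token id: {}", argmax(&logits));
}
```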

§Activation capture

Use HookSpec::capture to snapshot tensors at any HookPoint during the forward pass:

```rust
use candle_mi::{HookPoint, HookSpec, MIModel};

let mut hooks = HookSpec::new();
hooks.capture(HookPoint::AttnPattern(5))   // post-softmax attention
     .capture(HookPoint::ResidPost(10));   // residual stream at layer 10

let cache = model.forward(&input, &hooks)?;

let attn = cache.require(&HookPoint::AttnPattern(5))?;   // [1, heads, seq, seq]
let resid = cache.require(&HookPoint::ResidPost(10))?;   // [1, seq, hidden]
```
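A captured pattern can then be analyzed with ordinary tensor math. As one illustration (not part of the candle_mi API), the Shannon entropy of an attention row measures how diffuse a head's attention is — `ln(seq)` for uniform attention, 0 for a one-hot row:

```rust
/// Shannon entropy (in nats) of one post-softmax attention row.
/// The row is assumed to be a probability distribution
/// (non-negative entries summing to 1).
fn attention_entropy(row: &[f32]) -> f32 {
    row.iter()
        .filter(|&&p| p > 0.0)
        .map(|&p| -p * p.ln())
        .sum()
}

fn main() {
    let uniform = [0.25_f32; 4];
    let focused = [1.0_f32, 0.0, 0.0, 0.0];
    println!("uniform: {:.4}", attention_entropy(&uniform)); // ≈ ln 4 ≈ 1.3863
    println!("focused: {:.4}", attention_entropy(&focused)); // 0
}
```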

§Interventions

Use HookSpec::intervene to modify activations mid-forward-pass. Five intervention types are available: Intervention::Replace, Intervention::Add, Intervention::Knockout, Intervention::Scale, and Intervention::Zero.

```rust
use candle_mi::{HookPoint, HookSpec, Intervention, KnockoutSpec, create_knockout_mask};

// Knock out the attention edge: last token cannot attend to position 0.
// (seq_len is the length of the tokenized prompt.)
let spec = KnockoutSpec::new().layer(8).edge(seq_len - 1, 0);
let mask = create_knockout_mask(
    &spec, model.num_heads(), seq_len, model.device(), candle_core::DType::F32,
)?;

let mut hooks = HookSpec::new();
hooks.intervene(HookPoint::AttnScores(8), Intervention::Knockout(mask));

let ablated = model.forward(&input, &hooks)?;
```
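Mechanically, an attention knockout is an additive mask: a large negative value at the knocked (query, key) position drives that edge's post-softmax weight to effectively zero. A self-contained sketch of that semantics in plain Rust, assuming the mask is added to raw attention scores before softmax (`create_knockout_mask` builds the tensor equivalent):

```rust
/// Numerically stable softmax over a row of attention scores.
fn softmax(scores: &[f32]) -> Vec<f32> {
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.into_iter().map(|e| e / sum).collect()
}

fn main() {
    // Raw scores for the last query position over 4 key positions.
    let scores = [1.0_f32, 0.5, 0.2, 0.3];
    // Knock out the edge to key position 0 with a large negative additive mask.
    let mask = [-1e9_f32, 0.0, 0.0, 0.0];
    let masked: Vec<f32> = scores.iter().zip(&mask).map(|(s, m)| s + m).collect();
    let weights = softmax(&masked);
    // The knocked edge receives ~zero attention; the rest renormalizes.
    assert!(weights[0] < 1e-6);
    println!("{weights:?}");
}
```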

§Logit lens

Project intermediate residual streams to vocabulary space using MIModel::project_to_vocab:

```rust
use candle_mi::{HookPoint, HookSpec, MIModel};

let mut hooks = HookSpec::new();
for layer in 0..model.num_layers() {
    hooks.capture(HookPoint::ResidPost(layer));
}
let cache = model.forward(&input, &hooks)?;

for layer in 0..model.num_layers() {
    let resid = cache.require(&HookPoint::ResidPost(layer))?;
    let last = resid.get(0)?.get(seq_len - 1)?.unsqueeze(0)?;
    let logits = model.project_to_vocab(&last)?;
    let token_id = candle_mi::sample_token(&logits.flatten_all()?, 0.0)?;
    println!("Layer {layer:>2}: {}", tokenizer.decode(&[token_id])?);
}
```
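Conceptually, the logit-lens projection is a matrix–vector product between the residual vector and the unembedding matrix (with the model's final norm applied first; `project_to_vocab` handles both). A toy pure-Rust sketch with a 2-dimensional hidden state and a 3-token vocabulary, for illustration only:

```rust
/// Project a hidden-state vector to vocabulary logits via an unembedding
/// matrix W_U of shape [vocab, hidden]: logits[v] = dot(W_U[v], hidden).
fn project_to_vocab(hidden: &[f32], w_u: &[Vec<f32>]) -> Vec<f32> {
    w_u.iter()
        .map(|row| row.iter().zip(hidden).map(|(w, h)| w * h).sum())
        .collect()
}

fn main() {
    let hidden = [1.0_f32, -1.0];
    let w_u = vec![
        vec![1.0_f32, 0.0], // token 0
        vec![0.0, 1.0],     // token 1
        vec![1.0, 1.0],     // token 2
    ];
    let logits = project_to_vocab(&hidden, &w_u);
    println!("{logits:?}"); // [1.0, -1.0, 0.0] → token 0 wins at this layer
}
```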

§Fast downloads

candle-mi uses hf-fetch-model for high-throughput parallel downloads from the HuggingFace Hub:

```rust
// Async: parallel chunked download with progress bars
let _path = candle_mi::download_model("meta-llama/Llama-3.2-1B".to_owned()).await?;
// Sync: blocking variant (uses local HF cache if already downloaded)
candle_mi::download_model_blocking("meta-llama/Llama-3.2-1B".to_owned())?;
let model = candle_mi::MIModel::from_pretrained("meta-llama/Llama-3.2-1B")?;
```

§Further reading

  • HOOKS.md — complete hook point reference with shapes, intervention walkthrough, and worked examples.
  • BACKENDS.md — how to add a new model architecture (auto-config, config parser, or custom MIBackend).
  • examples/README.md — 15 runnable examples covering inference, logit lens, attention patterns, knockout, steering, activation patching, CLT circuits, SAE encoding, RWKV inference, and more.

Re-exports§

pub use backend::GenerationResult;
pub use backend::MIBackend;
pub use backend::MIModel;
pub use backend::extract_token_prob;
pub use backend::sample_token;
pub use config::Activation;
pub use config::CompatibilityReport;
pub use config::MlpLayout;
pub use config::NormType;
pub use config::QkvLayout;
pub use config::SUPPORTED_MODEL_TYPES;
pub use config::TransformerConfig;
pub use transformer::GenericTransformer;
pub use rwkv::GenericRwkv;
pub use rwkv::RwkvConfig;
pub use rwkv::RwkvLoraDims;
pub use rwkv::RwkvVersion;
pub use sparse::FeatureId;
pub use sparse::SparseActivations;
pub use clt::AttributionEdge;
pub use clt::AttributionGraph;
pub use clt::CltConfig;
pub use clt::CltFeatureId;
pub use clt::CrossLayerTranscoder;
pub use sae::NormalizeActivations;
pub use sae::SaeArchitecture;
pub use sae::SaeConfig;
pub use sae::SaeFeatureId;
pub use sae::SparseAutoencoder;
pub use sae::TopKStrategy;
pub use cache::ActivationCache;
pub use cache::AttentionCache;
pub use cache::FullActivationCache;
pub use cache::KVCache;
pub use error::MIError;
pub use error::Result;
pub use hooks::HookCache;
pub use hooks::HookPoint;
pub use hooks::HookSpec;
pub use hooks::Intervention;
pub use interp::intervention::AblationResult;
pub use interp::intervention::AttentionEdge;
pub use interp::intervention::HeadSpec;
pub use interp::intervention::InterventionType;
pub use interp::intervention::KnockoutSpec;
pub use interp::intervention::LayerSpec;
pub use interp::intervention::StateAblationResult;
pub use interp::intervention::StateKnockoutSpec;
pub use interp::intervention::StateSteeringResult;
pub use interp::intervention::StateSteeringSpec;
pub use interp::intervention::SteeringResult;
pub use interp::intervention::SteeringSpec;
pub use interp::intervention::apply_steering;
pub use interp::intervention::create_knockout_mask;
pub use interp::intervention::kl_divergence;
pub use interp::intervention::measure_attention_to_targets;
pub use interp::logit_lens::LogitLensAnalysis;
pub use interp::logit_lens::LogitLensResult;
pub use interp::logit_lens::TokenPrediction;
pub use interp::steering::DoseResponseCurve;
pub use interp::steering::DoseResponsePoint;
pub use interp::steering::SteeringCalibration;
pub use tokenizer::MITokenizer;
pub use memory::MemoryReport;
pub use memory::MemorySnapshot;
pub use download::download_model;
pub use download::download_model_blocking;

Modules§

backend
Core backend trait and model wrapper.
cache
Activation, attention, and KV caching for efficient forward passes.
clt
Cross-Layer Transcoder (CLT) support.
config
Transformer configuration and HuggingFace config.json parsing.
download
Fast model download via hf-fetch-model.
error
Error types for candle-mi.
hooks
Hook system for activation capture and intervention.
interp
Interpretability tools: intervention, logit lens, steering calibration.
memory
Process and GPU memory reporting.
rwkv
RWKV gated-linear RNN backend.
sae
Sparse Autoencoder (SAE) support.
sparse
Shared sparse-feature types used by both CLT and SAE modules.
tokenizer
Tokenizer abstraction: dispatch between HuggingFace and RWKV backends.
transformer
Generic transformer implementation.

Structs§

EncodingWithOffsets
Encoding result with tokens and their character offsets.
PcaResult
Result of a PCA decomposition via power iteration.
PositionConversion
Result of converting a character position to a token index.
RecurrentFeedbackEntry
A single feedback injection between recurrent passes.
RecurrentPassSpec
Specification for a recurrent multi-pass forward through a layer block.
TokenWithOffset
Token with its character offset range.

Functions§

clear_mask_caches
Clear all cached masks.
convert_positions
Convert multiple character positions to token indices.
create_causal_mask
Create or retrieve a cached causal mask for the given sequence length.
create_generation_mask
Create a causal mask for generation with KV-cache.
pca_top_k
Compute the top k principal components of a matrix via power iteration with deflation on the kernel matrix.