candle-mi 0.1.0

Mechanistic interpretability for language models in Rust, built on candle

candle-mi


Mechanistic Interpretability for the Rust of us.

Note: v0.1.0 — the API may change between minor versions. See the CHANGELOG.


What is this?

Mechanistic interpretability (MI) is the study of how a language model arrives at its predictions — not just what it outputs, but what happens inside. By inspecting and manipulating the model's internal activations (attention patterns, residual streams, MLP outputs), researchers can understand which components drive specific behaviors.

candle-mi is a Rust library that makes this possible. It re-implements model forward passes with built-in hook points — named locations where you can:

  • Capture activations (e.g., "what does the attention pattern look like at layer 5?")
  • Intervene on activations mid-forward-pass (e.g., "what happens if I knock out this attention edge?" or "what if I steer the residual stream toward a concept?")

This is the Rust equivalent of Python's TransformerLens, built on candle for GPU acceleration. The hook system is type-safe (typos are caught at compile time, not silently ignored at runtime) and zero-overhead (an empty hook spec adds no allocations or clones to the forward pass).
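To make the capture/intervene distinction and the two design claims concrete, here is a toy sketch in plain Rust. It is not candle-mi's implementation: ToyHookPoint, ToyHookSpec, and run_hook are hypothetical names, and activations are plain f32 slices rather than candle tensors (the real hook API is described in HOOKS.md). The enum makes a misspelled hook point a compile error, and the early return on an empty spec means no lookups and no clones happen.

```rust
use std::collections::HashMap;

/// Hypothetical hook points. Using an enum (not strings) means a misspelled
/// hook point fails to compile instead of silently matching nothing.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum ToyHookPoint {
    ResidPre(usize),    // residual stream entering layer n
    AttnPattern(usize), // attention pattern of layer n
    MlpOut(usize),      // MLP output of layer n
}

/// An intervention rewrites an activation in place.
type Intervention = Box<dyn Fn(&mut [f32])>;

/// Hypothetical spec: which points to capture and which to intervene on.
#[derive(Default)]
struct ToyHookSpec {
    captures: Vec<ToyHookPoint>,
    interventions: HashMap<ToyHookPoint, Intervention>,
}

impl ToyHookSpec {
    fn new() -> Self {
        Self::default()
    }

    fn capture(mut self, point: ToyHookPoint) -> Self {
        self.captures.push(point);
        self
    }

    fn intervene(mut self, point: ToyHookPoint, f: impl Fn(&mut [f32]) + 'static) -> Self {
        self.interventions.insert(point, Box::new(f));
        self
    }

    fn is_empty(&self) -> bool {
        self.captures.is_empty() && self.interventions.is_empty()
    }
}

/// What a forward pass would call at each hook point: apply any intervention,
/// then record the (possibly modified) activation if it was requested.
fn run_hook(
    spec: &ToyHookSpec,
    point: ToyHookPoint,
    act: &mut [f32],
    cache: &mut HashMap<ToyHookPoint, Vec<f32>>,
) {
    if spec.is_empty() {
        return; // empty spec: no lookups, no clones
    }
    if let Some(f) = spec.interventions.get(&point) {
        f(&mut *act);
    }
    if spec.captures.contains(&point) {
        cache.insert(point, act.to_vec());
    }
}
```

In this toy design, capture runs after intervention, so the cache records what downstream layers would actually see.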

Why Rust? Running published MI experiments, such as Anthropic's Scaling Monosemanticity or Planning in poems, quickly hits the limits of CPU-only Python. Cloud GPUs are always an option, but not a frugal one. On a consumer-grade GPU, memory and runtime become the real bottlenecks. candle addresses both: Rust's zero-cost abstractions keep memory overhead low, compiled code runs faster, and candle provides direct CUDA/Metal access without Python's runtime tax. That's how candle-mi started: let's bring MI to local hardware. (See the Figure 13 replication for a concrete example: Planning in poems reproduced on a consumer GPU.)

What can you do with it?

| Technique | What it does | Example |
| --- | --- | --- |
| Logit lens | See what the model "thinks" at each layer by projecting intermediate residual streams to vocabulary space | logit_lens |
| Attention knockout | Block specific attention edges (e.g., "token 5 cannot attend to token 0") and measure how predictions change | attention_knockout |
| Activation steering | Add a direction vector to the residual stream to shift model behavior (e.g., make it more positive or more formal) | steering_dose_response |
| Activation patching | Swap activations between a clean and corrupted run to identify which components causally drive a prediction | activation_patching |
| Attention patterns | Visualize where each attention head attends across the sequence | attention_patterns |
| RWKV state analysis | Inspect and intervene on recurrent state (not just transformers) | rwkv_inference |
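candle-mi runs these techniques on real tensors through its hook points. As a rough illustration of the arithmetic behind four of them, here is a dependency-free sketch on plain f32 vectors. None of these functions (logit_lens, attention_with_knockout, steer, patch) are candle-mi APIs; they are hypothetical helpers, and the unit-gain RMSNorm inside logit_lens is a simplifying assumption standing in for a model's learned final norm.

```rust
/// Dot product of two equal-length vectors.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Index of the largest element (first one on ties). Assumes `xs` is non-empty.
fn argmax(xs: &[f32]) -> usize {
    xs.iter()
        .enumerate()
        .fold(0, |best, (i, &x)| if x > xs[best] { i } else { best })
}

/// Unit-gain RMSNorm, standing in for the model's final normalization layer.
fn rms_norm(x: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let scale = 1.0 / (mean_sq + eps).sqrt();
    x.iter().map(|v| v * scale).collect()
}

/// Logit lens: normalize an intermediate residual vector, project it through
/// the unembedding matrix (one row per vocabulary token), return the top token.
fn logit_lens(resid: &[f32], w_unembed: &[Vec<f32>]) -> usize {
    let normed = rms_norm(resid, 1e-6);
    let logits: Vec<f32> = w_unembed.iter().map(|row| dot(row, &normed)).collect();
    argmax(&logits)
}

/// Numerically stable softmax. Needs at least one finite entry.
fn softmax(xs: &[f32]) -> Vec<f32> {
    let max = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Attention knockout for one query: blocked key positions get a score of
/// -inf before softmax, so they receive exactly zero attention weight.
fn attention_with_knockout(scores: &[f32], blocked: &[usize]) -> Vec<f32> {
    let masked: Vec<f32> = scores
        .iter()
        .enumerate()
        .map(|(k, &s)| if blocked.contains(&k) { f32::NEG_INFINITY } else { s })
        .collect();
    softmax(&masked)
}

/// Activation steering: add `alpha` times a direction to the residual stream.
fn steer(resid: &[f32], direction: &[f32], alpha: f32) -> Vec<f32> {
    resid.iter().zip(direction).map(|(r, d)| r + alpha * d).collect()
}

/// Activation patching: overwrite the corrupted run's activation at `pos`
/// with the clean run's activation at the same position.
fn patch(corrupted: &mut [Vec<f32>], clean: &[Vec<f32>], pos: usize) {
    corrupted[pos].clone_from(&clean[pos]);
}
```

Masking with -inf before softmax (rather than zeroing weights afterwards) keeps the remaining attention weights summing to 1.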

Quick start

use candle_mi::{HookSpec, MIModel};

fn main() -> candle_mi::Result<()> {
    // 1. Load a model (auto-detects architecture from HuggingFace config)
    let model = MIModel::from_pretrained("meta-llama/Llama-3.2-1B")?;
    let tokenizer = model.tokenizer().unwrap();

    // 2. Tokenize a prompt
    let tokens = tokenizer.encode("The capital of France is")?;
    let input = candle_core::Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;

    // 3. Run a forward pass (HookSpec::new() = no hooks, zero overhead)
    let cache = model.forward(&input, &HookSpec::new())?;
    let logits = cache.output();  // [1, seq_len, vocab_size]

    // 4. Decode the top prediction
    let last_logits = logits.get(0)?.get(tokens.len() - 1)?;
    let token_id = candle_mi::sample_token(&last_logits, 0.0)?;  // greedy
    println!("{}", tokenizer.decode(&[token_id])?);  // " Paris"
    Ok(())
}
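Step 4 passes a temperature of 0.0 to candle_mi::sample_token, and the comment marks that as greedy decoding. The sketch below shows what a temperature argument conventionally means. It is a hypothetical stand-in (sample_from_logits is not candle-mi's sample_token), and it takes the uniform random draw `u` as a parameter so it needs no RNG crate.

```rust
/// Hypothetical temperature sampling over a logits vector.
/// `u` must be a uniform random draw in [0, 1).
fn sample_from_logits(logits: &[f32], temperature: f32, u: f32) -> usize {
    if temperature <= 0.0 {
        // Greedy decoding: the highest logit wins (first index on ties).
        return logits
            .iter()
            .enumerate()
            .fold(0, |best, (i, &x)| if x > logits[best] { i } else { best });
    }
    // Softmax of logits / temperature, shifted by the max for stability.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = logits.iter().map(|&l| ((l - max) / temperature).exp()).collect();
    let total: f32 = weights.iter().sum();
    // Inverse-CDF sampling: walk the cumulative distribution until it passes u.
    let mut cumulative = 0.0;
    for (i, w) in weights.iter().enumerate() {
        cumulative += w / total;
        if u < cumulative {
            return i;
        }
    }
    logits.len() - 1 // guard against floating-point round-off
}
```

Lower temperatures sharpen the distribution toward the argmax; as the temperature approaches 0, sampling converges to the greedy choice.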

[Figure: end-to-end run with auto-config loading LLaMA 3.2 1B, showing config detection, the forward pass, and the top-5 predictions]

Supported models

| Backend | Model families | Validated models | Feature flag |
| --- | --- | --- | --- |
| GenericTransformer | LLaMA 1/2/3, Qwen 2/2.5, Gemma, Gemma 2, Phi-3/4, StarCoder2, Mistral | LLaMA 3.2 1B, Qwen2.5-Coder-3B, Gemma 2 2B, Phi-3 Mini, StarCoder2 3B, Mistral 7B v0.1 | transformer (default) |
| GenericRwkv | RWKV-6 (Finch), RWKV-7 (Goose) | RWKV-7 1.6B | rwkv |

Model families are the architectures the config parser accepts by name: any model whose HuggingFace config.json reports the corresponding model_type. Validated models are those verified to match Python/PyTorch reference output. Most other decoder-only HuggingFace transformer models work out of the box via auto-config, with no code changes needed. See BACKENDS.md for details.
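A rough sketch of the routing this paragraph describes, under stated assumptions: Route and route_for are hypothetical names, not candle-mi's config parser. The transformer strings are the model_type values those families report in their HuggingFace config.json (Phi-4 is assumed to report phi3 like Phi-3); the RWKV strings rwkv6 and rwkv7 are assumptions, so check the actual config before relying on them. Unrecognized types fall through to auto-config instead of being rejected.

```rust
/// Hypothetical routing result for a HuggingFace `model_type` string.
#[derive(Debug, PartialEq)]
enum Route {
    /// A family the config parser recognizes by name (GenericTransformer).
    KnownTransformer,
    /// An RWKV family, handled by the GenericRwkv backend.
    KnownRwkv,
    /// Anything else: try to infer the architecture from the config fields.
    AutoConfig,
}

/// Map the `model_type` field of a config.json to a route.
fn route_for(model_type: &str) -> Route {
    match model_type {
        // LLaMA 1/2/3, Qwen 2/2.5, Gemma, Gemma 2, Phi-3 (Phi-4 assumed to
        // report "phi3" as well), StarCoder2, Mistral.
        "llama" | "qwen2" | "gemma" | "gemma2" | "phi3" | "starcoder2" | "mistral" => {
            Route::KnownTransformer
        }
        // Assumed model_type strings for RWKV-6 and RWKV-7 checkpoints.
        "rwkv6" | "rwkv7" => Route::KnownRwkv,
        _ => Route::AutoConfig,
    }
}
```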

Feature flags

| Feature | Default | Description |
| --- | --- | --- |
| transformer | yes | Generic transformer backend (decoder-only) |
| cuda | yes | CUDA GPU acceleration |
| rwkv | no | RWKV-6/7 linear RNN backend |
| rwkv-tokenizer | no | RWKV world tokenizer (required for RWKV inference) |
| clt | no | Cross-Layer Transcoder support |
| sae | no | Sparse Autoencoder support |
| mmap | no | Memory-mapped weight loading (required for sharded models) |
| memory | no | RAM/VRAM memory reporting |
| probing | no | Linear probing via linfa (experimental) |
| metal | no | Apple Metal GPU acceleration |

Documentation

| Document | Description |
| --- | --- |
| API docs (docs.rs) | Crate-level documentation with quick start and examples |
| HOOKS.md | Hook point reference, intervention API walkthrough, and worked examples |
| BACKENDS.md | How to add a new model architecture (auto-config, config parser, custom backend) |
| examples/README.md | 15 runnable examples covering inference, logit lens, knockout, steering, and more |
| CHANGELOG.md | Release history |
| ROADMAP.md | Development roadmap and architecture decisions |

License

MIT OR Apache-2.0

Development

Exclusively developed with Claude Code (dev) and Augment Code (review). Git workflow managed with Fork.