// candle_mi/lib.rs
// SPDX-License-Identifier: MIT OR Apache-2.0

//! # candle-mi
//!
//! Mechanistic interpretability for language models in Rust, built on
//! [candle](https://github.com/huggingface/candle).
//!
//! candle-mi re-implements model forward passes with built-in hook points
//! (following the [`TransformerLens`](https://github.com/TransformerLensOrg/TransformerLens)
//! design), enabling activation capture, attention knockout, steering, logit
//! lens, and sparse-feature analysis (CLTs and SAEs) — all in pure Rust with
//! GPU acceleration.
//!
//! ## Supported backends
//!
//! | Backend | Models | Feature flag |
//! |---------|--------|-------------|
//! | [`GenericTransformer`] | `LLaMA`, `Qwen2`, Gemma, Gemma 2, `Phi-3`, `StarCoder2`, Mistral (+ auto-config for unknown families) | `transformer` |
//! | `GenericRwkv` | RWKV-6 (Finch), RWKV-7 (Goose) | `rwkv` |
//!
//! See [`BACKENDS.md`](https://github.com/PCfVW/candle-mi/blob/main/BACKENDS.md)
//! for how to add a new model architecture.
//!
//! ## Feature flags
//!
//! | Feature | Default | Description |
//! |---------|---------|-------------|
//! | `transformer` | yes | Generic transformer backend (decoder-only) |
//! | `cuda` | yes | CUDA GPU acceleration |
//! | `rwkv` | no | RWKV-6/7 linear RNN backend |
//! | `rwkv-tokenizer` | no | RWKV world tokenizer (required for RWKV inference) |
//! | `clt` | no | Cross-Layer Transcoder support |
//! | `sae` | no | Sparse Autoencoder support |
//! | `mmap` | no | Memory-mapped weight loading (required for sharded models) |
//! | `memory` | no | RAM/VRAM memory reporting |
//! | `probing` | no | Linear probing via linfa (experimental) |
//! | `metal` | no | Apple Metal GPU acceleration |
//!
//! ## Quick start
//!
//! Load a model, run a forward pass, and inspect the output:
//!
//! ```no_run
//! use candle_mi::{HookSpec, MIModel};
//!
//! # fn main() -> candle_mi::Result<()> {
//! let model = MIModel::from_pretrained("meta-llama/Llama-3.2-1B")?;
//! let tokenizer = model.tokenizer().unwrap();
//!
//! let tokens = tokenizer.encode("The capital of France is")?;
//! let input = candle_core::Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
//!
//! let cache = model.forward(&input, &HookSpec::new())?;
//! let logits = cache.output();  // [1, seq, vocab]
//!
//! let last_logits = logits.get(0)?.get(tokens.len() - 1)?;
//! let token_id = candle_mi::sample_token(&last_logits, 0.0)?;  // greedy
//! println!("{}", tokenizer.decode(&[token_id])?);  // " Paris"
//! # Ok(())
//! # }
//! ```
//!
//! ## Activation capture
//!
//! Use [`HookSpec::capture`] to snapshot tensors at any
//! [`HookPoint`] during the forward pass:
//!
//! ```no_run
//! use candle_mi::{HookPoint, HookSpec, MIModel};
//!
//! # fn main() -> candle_mi::Result<()> {
//! # let model = MIModel::from_pretrained("meta-llama/Llama-3.2-1B")?;
//! # let tokenizer = model.tokenizer().unwrap();
//! # let tokens = tokenizer.encode("The capital of France is")?;
//! # let input = candle_core::Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
//! let mut hooks = HookSpec::new();
//! hooks.capture(HookPoint::AttnPattern(5))       // post-softmax attention
//!      .capture(HookPoint::ResidPost(10));        // residual stream at layer 10
//!
//! let cache = model.forward(&input, &hooks)?;
//!
//! let attn = cache.require(&HookPoint::AttnPattern(5))?;   // [1, heads, seq, seq]
//! let resid = cache.require(&HookPoint::ResidPost(10))?;    // [1, seq, hidden]
//! # Ok(())
//! # }
//! ```
//!
//! ## Interventions
//!
//! Use [`HookSpec::intervene`] to modify activations mid-forward-pass.
//! Five intervention types are available: [`Intervention::Replace`],
//! [`Intervention::Add`], [`Intervention::Knockout`],
//! [`Intervention::Scale`], and [`Intervention::Zero`].
//!
//! ```no_run
//! use candle_mi::{HookPoint, HookSpec, Intervention, KnockoutSpec, create_knockout_mask};
//!
//! # fn main() -> candle_mi::Result<()> {
//! # let model = candle_mi::MIModel::from_pretrained("meta-llama/Llama-3.2-1B")?;
//! # let tokenizer = model.tokenizer().unwrap();
//! # let tokens = tokenizer.encode("The capital of France is")?;
//! # let seq_len = tokens.len();
//! # let input = candle_core::Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
//! // Knock out the attention edge: last token cannot attend to position 0
//! let spec = KnockoutSpec::new().layer(8).edge(seq_len - 1, 0);
//! let mask = create_knockout_mask(
//!     &spec, model.num_heads(), seq_len, model.device(), candle_core::DType::F32,
//! )?;
//!
//! let mut hooks = HookSpec::new();
//! hooks.intervene(HookPoint::AttnScores(8), Intervention::Knockout(mask));
//!
//! let ablated = model.forward(&input, &hooks)?;
//! # Ok(())
//! # }
//! ```
//!
//! ## Logit lens
//!
//! Project intermediate residual streams to vocabulary space using
//! [`MIModel::project_to_vocab`]:
//!
//! ```no_run
//! use candle_mi::{HookPoint, HookSpec, MIModel};
//!
//! # fn main() -> candle_mi::Result<()> {
//! # let model = MIModel::from_pretrained("meta-llama/Llama-3.2-1B")?;
//! # let tokenizer = model.tokenizer().unwrap();
//! # let tokens = tokenizer.encode("The capital of France is")?;
//! # let seq_len = tokens.len();
//! # let input = candle_core::Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
//! let mut hooks = HookSpec::new();
//! for layer in 0..model.num_layers() {
//!     hooks.capture(HookPoint::ResidPost(layer));
//! }
//! let cache = model.forward(&input, &hooks)?;
//!
//! for layer in 0..model.num_layers() {
//!     let resid = cache.require(&HookPoint::ResidPost(layer))?;
//!     let last = resid.get(0)?.get(seq_len - 1)?.unsqueeze(0)?;
//!     let logits = model.project_to_vocab(&last)?;
//!     let token_id = candle_mi::sample_token(&logits.flatten_all()?, 0.0)?;
//!     println!("Layer {layer:>2}: {}", tokenizer.decode(&[token_id])?);
//! }
//! # Ok(())
//! # }
//! ```
//!
//! ## Fast downloads
//!
//! candle-mi uses [`hf-fetch-model`](https://github.com/PCfVW/hf-fetch-model)
//! for high-throughput parallel downloads from the `HuggingFace` Hub:
//!
//! ```rust,no_run
//! # async fn example() -> candle_mi::Result<()> {
//! // Async: parallel chunked download with progress bars
//! let _path = candle_mi::download_model("meta-llama/Llama-3.2-1B".to_owned()).await?;
//! # Ok(())
//! # }
//! ```
//!
//! ```no_run
//! # fn main() -> candle_mi::Result<()> {
//! // Sync: blocking variant (uses local HF cache if already downloaded)
//! candle_mi::download_model_blocking("meta-llama/Llama-3.2-1B".to_owned())?;
//! let model = candle_mi::MIModel::from_pretrained("meta-llama/Llama-3.2-1B")?;
//! # Ok(())
//! # }
//! ```
//!
//! ## Further reading
//!
//! - [`HOOKS.md`](https://github.com/PCfVW/candle-mi/blob/main/HOOKS.md) —
//!   complete hook point reference with shapes, intervention walkthrough, and
//!   worked examples.
//! - [`BACKENDS.md`](https://github.com/PCfVW/candle-mi/blob/main/BACKENDS.md) —
//!   how to add a new model architecture (auto-config, config parser, or
//!   custom `MIBackend`).
//! - [`examples/README.md`](https://github.com/PCfVW/candle-mi/blob/main/examples/README.md) —
//!   15 runnable examples covering inference, logit lens, attention patterns,
//!   knockout, steering, activation patching, CLT circuits, SAE encoding,
//!   RWKV inference, and more.

// Crate-wide lint policy. Fixed here: the web-viewer line numbers that were
// fused onto each line (e.g. `184#![deny...`) made these attributes invalid
// Rust tokens; the attributes themselves are restored unchanged.
#![deny(warnings)] // All warns → errors in CI
// Unsafe-code policy: fully forbidden unless a feature that needs scoped FFI
// (`mmap` or `memory`) is enabled, in which case it is downgraded to `deny`
// so individual `#[allow(unsafe_code)]` sites can opt in.
#![cfg_attr(not(any(feature = "mmap", feature = "memory")), forbid(unsafe_code))] // Rule 5: safe by default
#![cfg_attr(any(feature = "mmap", feature = "memory"), deny(unsafe_code))] // mmap/memory: deny for scoped FFI

188pub mod backend;
189pub mod cache;
190#[cfg(feature = "clt")]
191pub mod clt;
192pub mod config;
193pub mod download;
194pub mod error;
195pub mod hooks;
196pub mod interp;
197#[cfg(feature = "memory")]
198pub mod memory;
199#[cfg(feature = "rwkv")]
200pub mod rwkv;
201#[cfg(feature = "sae")]
202pub mod sae;
203pub mod sparse;
204pub mod tokenizer;
205#[cfg(feature = "transformer")]
206pub mod transformer;
207mod util;
208
209// --- Public re-exports ---------------------------------------------------
210
211// Backend
212pub use backend::{GenerationResult, MIBackend, MIModel, extract_token_prob, sample_token};
213
214// Config
215pub use config::{
216    Activation, CompatibilityReport, MlpLayout, NormType, QkvLayout, SUPPORTED_MODEL_TYPES,
217    TransformerConfig,
218};
219
220// Transformer backend
221#[cfg(feature = "transformer")]
222pub use transformer::GenericTransformer;
223
224// Recurrent feedback (anacrousis)
225#[cfg(feature = "transformer")]
226pub use transformer::recurrent::{RecurrentFeedbackEntry, RecurrentPassSpec};
227
228// RWKV backend
229#[cfg(feature = "rwkv")]
230pub use rwkv::{GenericRwkv, RwkvConfig, RwkvLoraDims, RwkvVersion};
231
232// Sparse feature types (shared by CLT and SAE)
233pub use sparse::{FeatureId, SparseActivations};
234
235// CLT (Cross-Layer Transcoder)
236#[cfg(feature = "clt")]
237pub use clt::{AttributionEdge, AttributionGraph, CltConfig, CltFeatureId, CrossLayerTranscoder};
238
239// SAE (Sparse Autoencoder)
240#[cfg(feature = "sae")]
241pub use sae::{
242    NormalizeActivations, SaeArchitecture, SaeConfig, SaeFeatureId, SparseAutoencoder, TopKStrategy,
243};
244
245// Cache
246pub use cache::{ActivationCache, AttentionCache, FullActivationCache, KVCache};
247
248// Error
249pub use error::{MIError, Result};
250
251// Hooks
252pub use hooks::{HookCache, HookPoint, HookSpec, Intervention};
253
254// Interpretability — intervention specs and results
255pub use interp::intervention::{
256    AblationResult, AttentionEdge, HeadSpec, InterventionType, KnockoutSpec, LayerSpec,
257    StateAblationResult, StateKnockoutSpec, StateSteeringResult, StateSteeringSpec, SteeringResult,
258    SteeringSpec, apply_steering, create_knockout_mask, kl_divergence,
259    measure_attention_to_targets,
260};
261
262// Interpretability — logit lens
263pub use interp::logit_lens::{LogitLensAnalysis, LogitLensResult, TokenPrediction};
264
265// Interpretability — steering calibration
266pub use interp::steering::{DoseResponseCurve, DoseResponsePoint, SteeringCalibration};
267
268// Utility — masks
269pub use util::masks::{clear_mask_caches, create_causal_mask, create_generation_mask};
270
271// Utility — PCA
272pub use util::pca::{PcaResult, pca_top_k};
273
274// Utility — positioning
275pub use util::positioning::{
276    EncodingWithOffsets, PositionConversion, TokenWithOffset, convert_positions,
277};
278
279// Tokenizer
280pub use tokenizer::MITokenizer;
281
282// Memory reporting
283#[cfg(feature = "memory")]
284pub use memory::{MemoryReport, MemorySnapshot, sync_and_trim_gpu};
285
286// Download
287pub use download::{download_model, download_model_blocking};