candle-mi 0.1.3

Mechanistic interpretability for language models in Rust, built on candle
# How to Add a New Model Architecture

> This guide explains how candle-mi supports new model families.  There
> are three paths — from easiest (auto-config) to most flexible (custom
> backend) — depending on how close the new model is to a standard
> decoder-only transformer.

---

## Table of Contents

- [Overview: Three Paths](#overview-three-paths)
- [Path 1: Auto-Config (Zero Code)](#path-1-auto-config-zero-code)
  - [What Auto-Config Detects](#what-auto-config-detects)
  - [Compatibility Check](#compatibility-check)
  - [Limitations](#limitations)
  - [What failure looks like](#what-failure-looks-like)
- [Path 2: Config Parser (Recommended for Known Families)](#path-2-config-parser-recommended-for-known-families)
  - [Step 1: Identify Configuration Axes](#step-1-identify-configuration-axes)
  - [Step 2: Write the Parser](#step-2-write-the-parser)
  - [Step 3: Register the Model Type](#step-3-register-the-model-type)
  - [Step 4: Validate Against Python](#step-4-validate-against-python)
- [Path 3: Custom MIBackend (Non-Transformer Architectures)](#path-3-custom-mibackend-non-transformer-architectures)
  - [The MIBackend Trait](#the-mibackend-trait)
  - [Implementing forward()](#implementing-forward)
  - [Hook Integration Checklist](#hook-integration-checklist)
  - [Registering with MIModel::from_pretrained()](#registering-with-mimodelfrom_pretrained)
- [TransformerConfig Reference](#transformerconfig-reference)
  - [Configuration Axes](#configuration-axes)
  - [Existing Parsers as Templates](#existing-parsers-as-templates)
- [Testing Checklist](#testing-checklist)
- [Weight Naming Convention](#weight-naming-convention)

---

## Overview: Three Paths

| Path | When to use | Effort | Hook support |
|------|-------------|--------|--------------|
| **Auto-config** | Model uses HuggingFace-standard weight naming and is a standard decoder-only transformer | None | Full (automatic) |
| **Config parser** | Model is a decoder-only transformer with quirks (special norm, bias pattern, etc.) | ~30 lines | Full (automatic) |
| **Custom `MIBackend`** | Model is not a decoder-only transformer (e.g., RWKV, Mamba, encoder-decoder) | ~500+ lines | Manual (you place hooks) |

Most HuggingFace transformer models work out of the box via auto-config.
If yours doesn't, a config parser is usually enough.  A custom backend is
only needed for fundamentally different architectures.

---

## Path 1: Auto-Config (Zero Code)

When `MIModel::from_pretrained()` encounters an unknown `model_type` in
`config.json`, it attempts to infer a `TransformerConfig` automatically:

```rust
// Works for any model with standard HF weight naming
let model = MIModel::from_pretrained("some-org/novel-transformer-3B")?;
```

### What Auto-Config Detects

Auto-config reads configuration from two sources:

**Tier 1–2: `config.json` scalars** (same fields as known families):
- `hidden_size`, `num_hidden_layers`, `num_attention_heads`, `intermediate_size`, `vocab_size`
- `num_key_value_heads`, `head_dim`, `rms_norm_eps`, `rope_theta`, `max_position_embeddings`
- `tie_word_embeddings`, `attention_bias`

**Tier 3: safetensors tensor names** (structural inference):
- **QKV layout**: `qkv_proj` → `Fused`; `q_proj` + `k_proj` + `v_proj` → `Separate`
- **MLP layout**: `gate_up_proj` → `GatedFused`; `gate_proj` + `up_proj` → `GatedSeparate`; `fc1` + `fc2` → `Plain`
- **Bias flags**: presence of `.bias` tensors in attention/MLP projections
- **Norm type**: `input_layernorm.bias` → `LayerNorm`; otherwise `RmsNorm`

**Tier 4: `model_type` fixups** (architecture-specific overrides):
- Gemma/Gemma2 → `GemmaRmsNorm`, `embedding_scale`, `alternating_sliding_window`
- Any model with `attn_logit_softcapping` → soft-capping enabled
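
The Tier 3 structural inference reduces to substring checks over the safetensors tensor-name list. A simplified sketch of the QKV case (hypothetical helper name; the crate's actual detector may check more patterns):

```rust
/// Possible QKV projection layouts (mirrors the `QkvLayout` axis above).
#[derive(Debug, PartialEq)]
enum QkvLayout {
    Fused,
    Separate,
    Unknown,
}

/// Infer the QKV layout from the list of safetensors tensor names.
fn infer_qkv_layout(tensor_names: &[&str]) -> QkvLayout {
    let has = |pat: &str| tensor_names.iter().any(|n| n.contains(pat));
    if has("qkv_proj") {
        QkvLayout::Fused
    } else if has("q_proj") && has("k_proj") && has("v_proj") {
        QkvLayout::Separate
    } else {
        QkvLayout::Unknown
    }
}

fn main() {
    let llama = [
        "model.layers.0.self_attn.q_proj.weight",
        "model.layers.0.self_attn.k_proj.weight",
        "model.layers.0.self_attn.v_proj.weight",
    ];
    let phi3 = ["model.layers.0.self_attn.qkv_proj.weight"];
    println!("{:?}", infer_qkv_layout(&llama)); // Separate
    println!("{:?}", infer_qkv_layout(&phi3)); // Fused
}
```

The same pattern extends to the MLP layout and bias checks: each is a handful of `contains` tests over the name list, no weight data required.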

### Compatibility Check

Before loading weights, `check_auto_compatibility()` runs a preflight
check that detects incompatible models:

```rust
use candle_mi::{CompatibilityReport, TransformerConfig};

let report = TransformerConfig::check_auto_compatibility(&config_json, &tensor_names);
if !report.compatible {
    for issue in &report.issues {
        eprintln!("  - {issue}");
    }
}
```

The check verifies:
- Required `config.json` fields are present
- Expected weight tensors exist (`embed_tokens`, `input_layernorm`, `lm_head`, etc.)
- Weight names follow an HF-compatible convention (non-HF conventions such as GGUF or custom prefixes are flagged)

When expected tensors are missing, the report suggests the closest matching name that *was* found ("did you mean?").
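
A "did you mean?" suggestion can be as simple as picking the found tensor name with the smallest edit distance to the missing one. A minimal sketch (the crate's actual heuristic may differ):

```rust
/// Plain Levenshtein distance over bytes.
fn edit_distance(a: &str, b: &str) -> usize {
    let (a, b) = (a.as_bytes(), b.as_bytes());
    let mut prev: Vec<usize> = (0..=b.len()).collect();
    for (i, &ca) in a.iter().enumerate() {
        let mut cur = Vec::with_capacity(b.len() + 1);
        cur.push(i + 1);
        for (j, &cb) in b.iter().enumerate() {
            let cost = if ca == cb { 0 } else { 1 };
            cur.push((prev[j] + cost).min(prev[j + 1] + 1).min(cur[j] + 1));
        }
        prev = cur;
    }
    prev[b.len()]
}

/// Suggest the found tensor name closest to an expected-but-missing one.
fn did_you_mean<'a>(missing: &str, found: &[&'a str]) -> Option<&'a str> {
    found.iter().copied().min_by_key(|n| edit_distance(missing, n))
}

fn main() {
    let found = ["model.final_norm.weight", "lm_head.weight"];
    // Suggests "model.final_norm.weight" for the missing "model.norm.weight"
    println!("{:?}", did_you_mean("model.norm.weight", &found));
}
```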

### Limitations

Auto-config does **not** handle:
- Non-standard weight naming (e.g., `transformer.h.0.attn` instead of `model.layers.0.self_attn`)
- Architectures that aren't decoder-only transformers (RWKV, Mamba, encoder-decoder)
- Novel attention mechanisms (linear attention, local-global hybrid)
- Custom activations not in the standard set (`SiLU`, `GELU`, `GELU tanh`)

If auto-config fails, write a config parser (Path 2) or a custom backend
(Path 3).

### What failure looks like

The `auto_config_dogfood` example demonstrates both success and failure modes:

```bash
# Success — known family, uses manual parser
cargo run --release --features transformer --example auto_config_dogfood -- "meta-llama/Llama-3.2-1B"

# Failure — unsupported architecture (weight name mismatch)
cargo run --release --features transformer --example auto_config_dogfood -- "allenai/OLMo-1B-hf"

# Failure — actionable diagnostics (non-standard naming convention)
cargo run --release --features transformer --example auto_config_dogfood -- "EleutherAI/pythia-1.4b"
```

**OLMo-1B** fails the compatibility check because its weight names
(`model.layers.*.input_layernorm.weight`, `model.final_norm.weight`) do not
match the normalization tensor patterns that `GenericTransformer` expects.

**Pythia 1.4B** uses the `gpt_neox.layers.{i}` weight prefix instead of the
HF-standard `model.layers.{i}`. The error message shows which tensors
*were* found for each expected category (embedding, norm, attention, MLP),
detects the GPT-NeoX / Pythia naming convention, and points to Phase 9
(tensor name remapping) for planned support. This is the diagnostic output
that tells contributors exactly where to look when adding a new model family.
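
At its core, the remapping Phase 9 envisions is a string rewrite from one naming convention to another. A toy sketch of the prefix step (illustrative only — a real remapper must also rename per-layer components such as GPT-NeoX's fused `query_key_value` projection, not just the prefix):

```rust
/// Rewrite a GPT-NeoX-style layer prefix onto the HF-standard one.
/// Names that don't carry the prefix pass through unchanged.
fn remap_prefix(name: &str) -> String {
    match name.strip_prefix("gpt_neox.layers.") {
        Some(rest) => format!("model.layers.{rest}"),
        None => name.to_string(),
    }
}

fn main() {
    // "model.layers.0.input_layernorm.weight"
    println!("{}", remap_prefix("gpt_neox.layers.0.input_layernorm.weight"));
}
```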

---

## Path 2: Config Parser (Recommended for Known Families)

If the model is a decoder-only transformer but auto-config fails (or you
want guaranteed correctness), write a dedicated config parser.  This is
typically ~30 lines of code.

### Step 1: Identify Configuration Axes

Compare the new model to the closest existing family.  The axes to check:

| Axis | Question | Where to look |
|------|----------|---------------|
| Norm type | RmsNorm, LayerNorm, or GemmaRmsNorm? | `config.json` or paper |
| Activation | SiLU, GELU, or GELU tanh? | `hidden_act` field |
| QKV layout | Separate or fused? | Weight tensor names |
| MLP layout | GatedSeparate, GatedFused, or Plain? | Weight tensor names |
| QKV bias | Do Q, K, V have bias terms? | Weight tensor names |
| Use bias | Does every projection have bias? | Weight tensor names |
| Post-norms | 2 or 4 norms per layer? | Weight tensor names |
| Embedding scale | Multiply embeddings by `sqrt(hidden_size)`? | Paper or code |
| Soft-capping | Pre-softmax logit capping? | `attn_logit_softcapping` |
| Tied embeddings | `lm_head.weight` shared with `embed_tokens`? | `tie_word_embeddings` |
| Sliding window | Alternating full/sliding attention? | `sliding_window` |

### Step 2: Write the Parser

Add a `parse_*` function in `src/config.rs`.  Use an existing parser as a
template — `parse_llama` is the simplest baseline:

```rust
/// Parse a Llama-family `config.json`.
fn parse_llama(config: &Value) -> Result<TransformerConfig> {
    let base = parse_common_fields(config)?;
    Ok(TransformerConfig {
        norm_type: NormType::RmsNorm,
        activation: Activation::Silu,
        qkv_layout: QkvLayout::Separate,
        mlp_layout: MlpLayout::GatedSeparate,
        qkv_bias: false,
        use_bias: false,
        use_post_norms: false,
        embedding_scale: None,
        attn_logit_softcapping: None,
        query_pre_attn_scalar: None,
        alternating_sliding_window: false,
        ..base
    })
}
```

For a model with quirks, override only what differs:

```rust
fn parse_qwen2(config: &Value) -> Result<TransformerConfig> {
    let base = parse_common_fields(config)?;
    // Qwen2: attention_bias applies to Q, K, V (NOT o_proj)
    let qkv_bias = config
        .get("attention_bias")
        .and_then(Value::as_bool)
        .unwrap_or(true);
    Ok(TransformerConfig {
        norm_type: NormType::RmsNorm,
        activation: Activation::Silu,
        qkv_layout: QkvLayout::Separate,
        mlp_layout: MlpLayout::GatedSeparate,
        qkv_bias,
        use_bias: false,
        use_post_norms: false,
        embedding_scale: None,
        attn_logit_softcapping: None,
        query_pre_attn_scalar: None,
        alternating_sliding_window: false,
        ..base
    })
}
```

### Step 3: Register the Model Type

1. Add the `model_type` string to `SUPPORTED_MODEL_TYPES` in `src/config.rs`:

```rust
pub const SUPPORTED_MODEL_TYPES: &[&str] = &[
    "gemma", "gemma2", "llama", "mistral",
    "my_new_model",  // ← add here
    "phi3", "qwen2", "starcoder2",
];
```

2. Add a match arm in `from_hf_config()`:

```rust
"my_new_model" => parse_my_new_model(config),
```

### Step 4: Validate Against Python

Run the model in both Python (HuggingFace Transformers) and Rust, and
compare logits:

```python
# Python reference
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("org/model", torch_dtype=torch.float32)
tokens = tokenizer("The capital of France is", return_tensors="pt")
logits = model(**tokens).logits[0, -1, :]  # last position
top5 = torch.topk(logits, 5)
```

```rust
// Rust — should match Python's top-5 exactly
let model = MIModel::from_pretrained("org/model")?;
let tokens = tokenizer.encode("The capital of France is")?;
// ... forward pass, extract last-position logits, compare top-5
```

With F32 on both sides, logits should match to ~6 decimal places.
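
A concrete way to run the comparison is to extract the top-k indices on both sides and assert they match. A minimal, framework-free helper (hypothetical name; adapt to however you extract logits):

```rust
/// Indices of the k largest logits, descending; ties broken by index.
fn top_k_indices(logits: &[f32], k: usize) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..logits.len()).collect();
    idx.sort_by(|&a, &b| {
        logits[b]
            .partial_cmp(&logits[a])
            .unwrap()
            .then(a.cmp(&b))
    });
    idx.truncate(k);
    idx
}

fn main() {
    let logits = [0.1, 3.0, 2.0, 5.0];
    println!("{:?}", top_k_indices(&logits, 2)); // [3, 1]
}
```

Compare the Rust result against the indices from `torch.topk` in the Python reference; with F32 weights they should be identical.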

---

## Path 3: Custom MIBackend (Non-Transformer Architectures)

For architectures that aren't decoder-only transformers (RWKV, Mamba,
encoder-decoder, etc.), implement the `MIBackend` trait directly.

### The MIBackend Trait

```rust
pub trait MIBackend: Send + Sync {
    // --- Required: metadata ---
    fn num_layers(&self) -> usize;
    fn hidden_size(&self) -> usize;
    fn vocab_size(&self) -> usize;
    fn num_heads(&self) -> usize;

    // --- Required: forward pass ---
    fn forward(&self, input_ids: &Tensor, hooks: &HookSpec) -> Result<HookCache>;

    // --- Required: logit projection ---
    fn project_to_vocab(&self, hidden: &Tensor) -> Result<Tensor>;

    // --- Optional ---
    fn chat_template(&self, _prompt: &str, _system_prompt: Option<&str>) -> Option<String> {
        None
    }
    fn embedding_vector(&self, _token_id: u32) -> Result<Tensor> {
        Err(MIError::Hook("embedding_vector not supported".into()))
    }
}
```

**Contract:**
- `forward()` must return a `HookCache` with logits at shape `[batch, seq, vocab_size]`
- When `hooks` is empty, the forward pass must produce **zero extra allocations**
- `project_to_vocab()` applies the final norm + unembedding projection (for logit lens)
- The trait is `Send + Sync` because `MIModel` may be shared across threads

### Implementing forward()

The forward pass must integrate with the hook system.  Here is the
pattern:

```rust
fn forward(&self, input_ids: &Tensor, hooks: &HookSpec) -> Result<HookCache> {
    // 1. Create a placeholder HookCache
    let placeholder = Tensor::zeros((1,), DType::F32, input_ids.device())?;
    let mut cache = HookCache::new(placeholder);

    // 2. Embedding
    let mut hidden = self.embed(input_ids)?;

    // 3. Hook: Embed (capture + intervene)
    if hooks.is_captured(&HookPoint::Embed) {
        cache.store(HookPoint::Embed, hidden.clone());
    }
    for intervention in hooks.interventions_at(&HookPoint::Embed) {
        hidden = apply_intervention(&hidden, intervention)?;
    }

    // 4. Layer loop
    for i in 0..self.num_layers {
        // ResidPre capture + intervene
        if hooks.is_captured(&HookPoint::ResidPre(i)) {
            cache.store(HookPoint::ResidPre(i), hidden.clone());
        }
        for intervention in hooks.interventions_at(&HookPoint::ResidPre(i)) {
            hidden = apply_intervention(&hidden, intervention)?;
        }

        // ... your layer logic with hook points at each stage ...

        // ResidPost capture + intervene
        if hooks.is_captured(&HookPoint::ResidPost(i)) {
            cache.store(HookPoint::ResidPost(i), hidden.clone());
        }
        for intervention in hooks.interventions_at(&HookPoint::ResidPost(i)) {
            hidden = apply_intervention(&hidden, intervention)?;
        }
    }

    // 5. Final logits
    let logits = self.compute_logits(&hidden)?;
    cache.set_output(logits);
    Ok(cache)
}
```

**Key points:**
- `apply_intervention()` is a crate-internal helper (`pub(crate)` in
  `src/hooks.rs`) — if you're implementing a backend inside the crate, use
  it directly; if outside, implement the intervention logic yourself.
- Capture checks (`is_captured`) are cheap `HashSet` lookups — when false,
  the `.clone()` is skipped entirely.
- Intervention iteration (`interventions_at`) returns an empty iterator
  when no interventions target that hook point.
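
Outside the crate, the intervention logic is straightforward to replicate. A toy version over plain `Vec<f32>` (the real helper operates on candle `Tensor`s, and only `Zero` is named elsewhere in this guide — the `Scale` variant here is assumed for illustration):

```rust
/// Illustrative intervention kinds; the crate's real enum may differ.
enum Intervention {
    Zero,        // ablate the activation entirely
    Scale(f32),  // multiply the activation by a constant
}

/// Apply one intervention to a flat activation buffer.
fn apply_intervention(hidden: &[f32], iv: &Intervention) -> Vec<f32> {
    match iv {
        Intervention::Zero => vec![0.0; hidden.len()],
        Intervention::Scale(s) => hidden.iter().map(|x| x * s).collect(),
    }
}

fn main() {
    let hidden = [1.0_f32, -2.0, 0.5];
    println!("{:?}", apply_intervention(&hidden, &Intervention::Zero)); // [0.0, 0.0, 0.0]
    println!("{:?}", apply_intervention(&hidden, &Intervention::Scale(2.0))); // [2.0, -4.0, 1.0]
}
```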

### Hook Integration Checklist

For a new backend, decide which hook points to support.  At minimum:

| Hook Point | Priority | Purpose |
|------------|----------|---------|
| `Embed` | Required | Token embedding output |
| `ResidPre(i)` | Required | Residual stream before each layer |
| `ResidPost(i)` | Required | Residual stream after each layer (logit lens) |
| `FinalNorm` | Required | After final norm (logit lens) |
| `AttnPattern(i)` | Recommended | Attention visualization |
| `AttnScores(i)` | Recommended | Knockout interventions |
| `AttnQ/K/V(i)` | Optional | Q, K, V inspection |
| `MlpPre/Post/Out(i)` | Optional | MLP analysis |

Backend-specific hook points should use `HookPoint::Custom("name")` or
propose a new enum variant.

### Registering with MIModel::from_pretrained()

To make your backend discoverable via `from_pretrained()`, add a match arm
in `src/backend.rs` inside `MIModel::from_pretrained()`:

```rust
#[cfg(feature = "my_backend")]
"my_model_type" => {
    let config = MyConfig::from_hf_config(&json)?;
    let model = MyBackend::load(config, &device, dtype, vb)?;
    Ok(Self::with_tokenizer(Box::new(model), device, tokenizer))
}
```

The backend is gated behind a feature flag, so users only compile what they
need.
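
The feature itself must be declared in `Cargo.toml`. A minimal sketch (the feature name is illustrative; list any optional dependencies the backend pulls in):

```toml
[features]
# Hypothetical feature gating the new backend; add optional
# dependencies it needs, e.g. "dep:my-backend-crate".
my_backend = []
```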

---

## TransformerConfig Reference

### Configuration Axes

The `GenericTransformer` is parameterized by these fields:

| Field | Type | Description |
|-------|------|-------------|
| `hidden_size` | `usize` | Model dimension (`d_model`) |
| `num_layers` | `usize` | Number of decoder blocks |
| `num_attention_heads` | `usize` | Number of query heads |
| `num_kv_heads` | `usize` | Number of key/value heads (GQA) |
| `head_dim` | `usize` | Per-head dimension |
| `intermediate_size` | `usize` | MLP hidden dimension |
| `vocab_size` | `usize` | Vocabulary size |
| `norm_type` | `NormType` | `RmsNorm`, `LayerNorm`, or `GemmaRmsNorm` |
| `norm_eps` | `f64` | Normalization epsilon |
| `activation` | `Activation` | `Silu`, `Gelu`, or `GeluApprox` |
| `qkv_layout` | `QkvLayout` | `Separate` or `Fused` |
| `mlp_layout` | `MlpLayout` | `GatedSeparate`, `GatedFused`, or `Plain` |
| `qkv_bias` | `bool` | Bias on Q, K, V projections |
| `use_bias` | `bool` | Bias on all projections |
| `use_post_norms` | `bool` | 4-norm layers (Gemma 2) |
| `embedding_scale` | `Option<f64>` | Multiply embeddings by this value |
| `attn_logit_softcapping` | `Option<f64>` | Pre-softmax logit soft-capping threshold |
| `query_pre_attn_scalar` | `Option<f64>` | Custom Q scaling (overrides `1/sqrt(head_dim)`) |
| `rope_theta` | `f64` | RoPE base frequency |
| `max_position_embeddings` | `usize` | Maximum sequence length |
| `tie_word_embeddings` | `bool` | Share `embed_tokens` and `lm_head` weights |
| `alternating_sliding_window` | `bool` | Alternating full/sliding attention (Gemma 2) |
| `sliding_window` | `Option<usize>` | Sliding window size |

### Existing Parsers as Templates

| Family | Parser | Key differences from LLaMA baseline |
|--------|--------|-------------------------------------|
| LLaMA | `parse_llama` | Baseline: GQA, SiLU, RmsNorm, separate QKV, gated MLP |
| Qwen2 | `parse_qwen2` | + `qkv_bias: true` (Q, K, V only, not `o_proj`) |
| Gemma | `parse_gemma` | + `GemmaRmsNorm`, `embedding_scale`, `GeluApprox` |
| Gemma 2 | `parse_gemma2` | + post-norms, soft-capping, `query_pre_attn_scalar`, alternating sliding window |
| Phi-3 | `parse_phi3` | + fused QKV, fused gate+up MLP |
| StarCoder2 | `parse_starcoder2` | + `LayerNorm`, `Plain` MLP, `use_bias: true`, `GeluApprox` |
| Mistral | `parse_mistral` | + sliding window attention |

---

## Testing Checklist

When adding a new model family, verify:

- [ ] **Config parsing**: `TransformerConfig::from_hf_config(&json)` produces the correct config for a known model
- [ ] **Forward pass**: top-5 predictions match Python HuggingFace Transformers (F32 on both sides)
- [ ] **Hook capture**: all hook points produce tensors with the expected shapes
- [ ] **Intervention**: `Intervention::Zero` at `ResidPost(0)` changes the output (proves hooks are wired)
- [ ] **Logit lens**: `project_to_vocab()` produces meaningful predictions at intermediate layers
- [ ] **No regression**: existing model tests still pass (`cargo test --features transformer`)

For GPU testing, run with `--test-threads=1` to avoid OOM on 3B+ models.

---

## Weight Naming Convention

The `GenericTransformer` expects HuggingFace-standard weight names under
the `model.` prefix:

```
model.embed_tokens.weight                         # [vocab, hidden]
model.layers.{i}.input_layernorm.weight           # [hidden]
model.layers.{i}.self_attn.q_proj.weight          # [n_heads * head_dim, hidden]
model.layers.{i}.self_attn.k_proj.weight          # [n_kv_heads * head_dim, hidden]
model.layers.{i}.self_attn.v_proj.weight          # [n_kv_heads * head_dim, hidden]
model.layers.{i}.self_attn.o_proj.weight          # [hidden, n_heads * head_dim]
model.layers.{i}.post_attention_layernorm.weight  # [hidden]
model.layers.{i}.mlp.gate_proj.weight             # [intermediate, hidden]
model.layers.{i}.mlp.up_proj.weight               # [intermediate, hidden]
model.layers.{i}.mlp.down_proj.weight             # [hidden, intermediate]
model.norm.weight                                 # [hidden]
lm_head.weight                                    # [vocab, hidden]
```
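
To sanity-check a checkpoint against this convention, the expected names can be generated mechanically. A sketch for the baseline separate-QKV, gated-MLP layout (hypothetical helper; per-family variants are not handled):

```rust
/// Expected HF-standard weight names for a LLaMA-style checkpoint.
fn expected_names(num_layers: usize, tied_embeddings: bool) -> Vec<String> {
    let mut names = vec!["model.embed_tokens.weight".to_string()];
    for i in 0..num_layers {
        for suffix in [
            "input_layernorm.weight",
            "self_attn.q_proj.weight",
            "self_attn.k_proj.weight",
            "self_attn.v_proj.weight",
            "self_attn.o_proj.weight",
            "post_attention_layernorm.weight",
            "mlp.gate_proj.weight",
            "mlp.up_proj.weight",
            "mlp.down_proj.weight",
        ] {
            names.push(format!("model.layers.{i}.{suffix}"));
        }
    }
    names.push("model.norm.weight".to_string());
    if !tied_embeddings {
        names.push("lm_head.weight".to_string());
    }
    names
}

fn main() {
    // 2 layers, untied embeddings: 1 + 2*9 + 1 + 1 = 21 tensors
    println!("{}", expected_names(2, false).len()); // 21
}
```

Diffing this list against the checkpoint's actual tensor names quickly surfaces missing or extra weights.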

Variants by family:

| Family | Difference |
|--------|-----------|
| Phi-3 | `qkv_proj` instead of `q/k/v_proj`; `gate_up_proj` instead of `gate_proj` + `up_proj` |
| StarCoder2 | `fc1`/`fc2` instead of `gate_proj`/`up_proj`/`down_proj`; `.bias` on all projections |
| Gemma 2 | + `pre_feedforward_layernorm`, `post_feedforward_layernorm`, `post_attention_layernorm` (4 norms) |
| Tied embeddings | No `lm_head.weight` — reuses `embed_tokens.weight` |

Models with non-standard naming (e.g., `transformer.h.{i}.attn` instead
of `model.layers.{i}.self_attn`) require a custom `MIBackend` (Path 3).