# Validation Scripts and Reference Data
This directory contains Python reference scripts, comparison reports, and
JSON reference data used to validate candle-mi against Python/PyTorch
implementations.
## RWKV (Phase 2)
| File | Purpose |
|------|---------|
| `rwkv7_validation.py` | Python reference: RWKV-7 Goose 1.5B forward pass with HF Transformers |
| `rwkv7_reference.json` | Python reference output: top-5 token IDs + logits for RWKV-7 |
| `rwkv6_reference.json` | Python reference output: top-5 token IDs + logits for RWKV-6 Finch 1.6B |
| `rwkv7_validation_comparison.md` | Detailed Rust vs Python comparison (logit diffs, dtype analysis) |
**Regeneration:**
```bash
python scripts/rwkv7_validation.py
```
Requires `transformers`, `torch`, `safetensors`. Outputs `rwkv7_reference.json`.
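A minimal sketch of how a Rust top-5 result could be checked against `rwkv7_reference.json`. The JSON schema (`{"top5": [[token_id, logit], ...]}`), the function names, and the default `atol=1e-2` are assumptions for illustration. A real tolerance should follow the dtype analysis in `rwkv7_validation_comparison.md`.

```python
import json


def load_top5(path):
    """Read a reference file. Hypothetical schema: {"top5": [[token_id, logit], ...]}."""
    with open(path, encoding="utf-8") as f:
        return [tuple(pair) for pair in json.load(f)["top5"]]


def compare_top5(reference, candidate, atol=1e-2):
    """Compare two top-5 lists of (token_id, logit) pairs.

    Returns (passed, ids_match, max_abs_logit_diff). Token IDs must match
    in order, and logits must agree within `atol` (an assumed tolerance).
    """
    ids_match = [t for t, _ in reference] == [t for t, _ in candidate]
    max_diff = max(abs(r - c) for (_, r), (_, c) in zip(reference, candidate))
    return ids_match and max_diff <= atol, ids_match, max_diff
```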
## CLT Position Sweep (Phase 3)
| File | Purpose |
|------|---------|
| `clt_position_sweep_validation.py` | Python reference: Gemma 2 2B CLT position sweep |
| `clt_position_sweep_validation_llama.py` | Python reference: Llama 3.2 1B CLT position sweep |
| `clt_position_sweep_reference.json` | Python reference output: per-position top features + causal L2 distances (Gemma 2 2B) |
| `clt_position_sweep_comparison.md` | Detailed Rust vs Python comparison for both models |
**Regeneration:**
```bash
python scripts/clt_position_sweep_validation.py # Gemma 2 2B
python scripts/clt_position_sweep_validation_llama.py # Llama 3.2 1B
```
Requires `transformers`, `torch`, `safetensors`, `pyyaml`. Outputs
`clt_position_sweep_reference.json`.
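The two quantities in the reference output can be sketched as follows. This assumes activations arrive as one vector per position, and that the "causal L2 distance" is the Euclidean distance between a baseline output vector and the output after intervening at that position. The function names, argument layout, and `k` default are hypothetical. `clt_position_sweep_validation.py` is the authority on the exact definitions.

```python
import math


def top_features(acts, k):
    """Indices of the k largest activations, highest first (ties broken by index)."""
    order = sorted(range(len(acts)), key=lambda i: (-acts[i], i))
    return order[:k]


def l2_distance(a, b):
    """Euclidean distance between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def position_sweep(acts_per_position, baseline_out, outputs_per_position, k=3):
    """For each position: top-k features, plus the L2 distance between the
    baseline output and the output after intervening at that position."""
    return [
        {
            "position": pos,
            "top_features": top_features(acts, k),
            "l2": l2_distance(baseline_out, out),
        }
        for pos, (acts, out) in enumerate(zip(acts_per_position, outputs_per_position))
    ]
```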
## Anacrousis / Recurrent Feedback (Phase 3)
| File | Purpose |
|------|---------|
| `anacrousis_reference.json` | Reference output from plip-rs: 28 conditions x 15 couplets (420 measurements) |
The reference data was generated by
[plip-rs](https://github.com/PCfVW/plip-rs) `examples/recurrent_block_rhyme.rs`
on Llama 3.2 1B with greedy decoding.
**plip-rs results** (with KV cache):
- Baseline: 10/15 couplets rhyme
- sustained\_14-15\_s=1.0: 11/15 (strict superset of baseline)
**candle-mi results** (no KV cache — full sequence recompute each step):
- Baseline: 9/15
- unembed\_8-15\_s=2.0: 11/15 (best improvement, superset of baseline)
- Double-pass without feedback: 0/15 (degenerate — out-of-distribution)
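The "superset of baseline" claims compare sets of rhyming couplet indices, not just counts. A minimal sketch using hypothetical couplet indices. The actual per-couplet outcomes live in `anacrousis_reference.json`, and its schema is not assumed here.

```python
def compare_rhymes(baseline, condition, total=15):
    """Summarise a condition against the baseline by rhyming couplet indices."""
    return {
        "score": f"{len(condition)}/{total}",
        "superset": condition >= baseline,        # keeps every baseline rhyme
        "strict_superset": condition > baseline,  # ...and adds at least one more
        "gained": sorted(condition - baseline),
        "lost": sorted(baseline - condition),
    }
```

For example, `compare_rhymes(set(range(10)), set(range(10)) | {12})` reports `11/15` with `strict_superset=True`. That is made-up data with the same shape as the plip-rs `sustained_14-15_s=1.0` result.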
**Regeneration** (from the plip-rs repository):
```bash
cargo run --release --example recurrent_block_rhyme
```
**Validation** (candle-mi):
```bash
cargo test --test validate_anacrousis --features transformer -- --ignored --test-threads=1
```
## SAE (Phase 4)
| File | Purpose |
|------|---------|
| `sae_validation.py` | Python reference: Gemma 2 2B + Gemma Scope SAE encoding (direct NPZ loading, no SAELens) |
| `sae_reference.json` | Python reference output: top features, reconstruction MSE, norms (generated by `sae_validation.py`) |
**Regeneration:**
```bash
python scripts/sae_validation.py
```
Requires `transformers`, `torch`, `huggingface_hub`, `numpy`. Outputs `sae_reference.json`.
SAE weights are loaded from `google/gemma-scope-2b-pt-res` (NPZ format),
the same repo that candle-mi uses via `hf-fetch-model`.
**Rust vs Python results** (Gemma 2 2B, layer 0, width 16k, JumpReLU, prompt "The capital of France is"):
| Metric | Rust | Python | Diff |
|--------|------|--------|------|
| Architecture | JumpReLU (auto-detected) | JumpReLU | — |
| Active features (last pos) | 14 / 16384 | 14 / 16384 | 0 |
| Top feature | #10492 (44.0694) | #10492 (44.0724) | 0.003 |
| Top-10 feature indices | 9/10 match | — | 1 diff (#10: Rust 6320 vs Python 688) |
| Reconstruction MSE | 6.912613 | 6.912656 | 0.000043 |
| Residual norm | 61.7991 | 61.8125 | 0.0134 |
| Decoded norm | 59.7376 | 59.7459 | 0.0083 |
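For a single position, the metrics in the table above are typically computed as in the sketch below. Active features count nonzero JumpReLU outputs, MSE averages the squared reconstruction error over residual dimensions, and the norms are Euclidean. These definitions are assumptions: `sae_validation.py` may aggregate differently, for example over all prompt positions. The function names are hypothetical.

```python
import math


def sae_metrics(residual, acts, decoded):
    """Single-position versions of the table's metrics (assumed definitions)."""
    n = len(residual)
    return {
        "active": sum(1 for a in acts if a > 0),  # JumpReLU outputs are 0 or > threshold
        "mse": sum((x - y) ** 2 for x, y in zip(residual, decoded)) / n,
        "residual_norm": math.sqrt(sum(x * x for x in residual)),
        "decoded_norm": math.sqrt(sum(y * y for y in decoded)),
    }


def metric_diffs(rust, python):
    """The table's Diff column: |Rust - Python| for each shared metric."""
    return {key: abs(rust[key] - python[key]) for key in rust}
```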
**Validation** (candle-mi):
```bash
cargo test --test validate_sae --features sae,transformer -- --ignored --test-threads=1
```
## PLT — Llama 3.2 1B (v0.1.9)
| File | Purpose |
|------|---------|
| `plt_llama_validation.py` | From-first-principles encoder oracle: loads `layer_{L}.safetensors` bundles from `mntss/transcoder-Llama-3.2-1B`, applies `ReLU(W_enc @ residual + b_enc)` in torch on CPU. No circuit-tracer involvement. |
| `plt_llama_reference.json` | Python reference output: 9 test cases (3 seeds × layers {0, 7, 15}) with top-10 feature indices and activations. Frozen oracle for the Rust parity test. |
**Regeneration:**
```bash
python scripts/plt_llama_validation.py
```
Requires `torch`, `safetensors`, `huggingface_hub`. Outputs `plt_llama_reference.json` (~540 KB).
The methodology mirrors plip-rs's `scripts/clt_reference.py` (which achieved
90/90 top-10 parity with max relative error 1.2×10⁻⁶ for CLTs). The layer choice
`[0, 7, 15]` is ends + middle of Llama 3.2 1B's 16-layer stack (plip-rs used
`[0, 12, 25]` for Gemma 2 2B's 26 layers). PltBundle weight layout:
`W_enc [131072, 2048]`, `b_enc [131072]`, rank-2 `W_dec [131072, 2048]`,
`W_skip [2048, 2048]` (Llama PLT linear skip path, present but unused by the
encoder oracle), `b_dec [2048]`.
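The Llama PLT oracle reduces to one matrix-vector product, a ReLU, and a top-k. Below is a dependency-free sketch on toy dimensions. It assumes `W_enc` is indexed `[n_features][d_model]`, as in the layout above. The real script does this with torch tensors of shape `[131072, 2048]`, and the helper names here are hypothetical.

```python
def relu_encode(w_enc, b_enc, residual):
    """acts = ReLU(W_enc @ residual + b_enc), with W_enc as [n_features][d_model]."""
    acts = []
    for row, bias in zip(w_enc, b_enc):
        pre = sum(w * x for w, x in zip(row, residual)) + bias
        acts.append(max(pre, 0.0))
    return acts


def top_k(acts, k=10):
    """(feature_index, activation) pairs for the k largest activations."""
    order = sorted(range(len(acts)), key=lambda i: acts[i], reverse=True)
    return [(i, acts[i]) for i in order[:k]]
```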
**Validation** (candle-mi, lands in V3 Step 1.5):
```bash
cargo test --test validate_plt --features clt,transformer -- --ignored --test-threads=1
```
## PLT — Gemma 2 2B (v0.1.10)
| File | Purpose |
|------|---------|
| `plt_gemma_validation.py` | From-first-principles encoder oracle for the GemmaScope PLT. Two-repo flow: parses `mntss/gemma-scope-transcoders/config.yaml` to discover per-layer NPZ paths, then loads each `params.npz` from `google/gemma-scope-2b-pt-transcoders` directly via `huggingface_hub` + `numpy`, applies `pre = W_enc.T @ residual + b_enc; acts = pre * (pre > threshold)` in torch on CPU. No circuit-tracer involvement. |
| `plt_gemma_reference.json` | Python reference output: 9 test cases (3 seeds × layers {0, 12, 25}) with top-10 feature indices and activations. Frozen oracle for the Rust parity test. |
**Regeneration:**
```bash
python scripts/plt_gemma_validation.py
```
Requires `torch`, `numpy`, `huggingface_hub`, `pyyaml`. Outputs
`plt_gemma_reference.json` (~610 KB). First run downloads ~864 MiB
(3 NPZs × ~288 MiB each FP32) into the HF cache.
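The download size follows from the `width_16k` tensor shapes listed in this section. A quick check of the arithmetic, ignoring NPZ container overhead:

```python
import math

# Shapes of one width_16k GemmaScope transcoder NPZ, as listed in this section.
D_MODEL, N_FEATURES = 2304, 16384
SHAPES = {
    "W_enc": (D_MODEL, N_FEATURES),
    "W_dec": (N_FEATURES, D_MODEL),
    "b_enc": (N_FEATURES,),
    "b_dec": (D_MODEL,),
    "threshold": (N_FEATURES,),
}
FP32_BYTES = 4

per_layer_mib = sum(math.prod(s) for s in SHAPES.values()) * FP32_BYTES / 2**20
print(f"per layer: {per_layer_mib:.1f} MiB, three layers: {3 * per_layer_mib:.1f} MiB")
# prints: per layer: 288.1 MiB, three layers: 864.4 MiB
```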
The methodology mirrors `plt_llama_validation.py`, adapted for the
`GemmaScopeNpz` schema:
- **Two-repo flow** — curation YAML on `mntss/gemma-scope-transcoders`,
weights on `google/gemma-scope-2b-pt-transcoders`. The 26-entry
curation list points at the lowest-`L0` variant per layer.
- **`W_enc` transpose** — GemmaScope stores `W_enc` as
`[d_model, n_features] = [2304, 16384]` on disk; the oracle applies
`.T.contiguous()` to canonicalise the orientation, matching
`circuit-tracer`'s `load_gemma_scope_transcoder()` reference.
- **JumpReLU activation** — `acts = pre * (pre > threshold)` element-wise
with a per-feature `threshold [n_features]` tensor, instead of plain
`ReLU` (Llama PLT) or `ReLU` after CLT decoder skip.
- **Layer choice** `{0, 12, 25}` — ends + middle of Gemma 2 2B's
26-layer stack, mirroring plip-rs's `[0, 12, 25]`.
GemmaScope NPZ tensor layout (`width_16k`):
`W_enc [2304, 16384]`, `W_dec [16384, 2304]`, `b_enc [16384]`,
`b_dec [2304]`, `threshold [16384]`. No `W_skip` (GemmaScope is a
pure JumpReLU transcoder).
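The two GemmaScope-specific steps are canonicalising the on-disk `W_enc` orientation and applying a per-feature JumpReLU threshold. Both can be sketched without torch on toy dimensions. The function names are hypothetical, and the real oracle uses `.T.contiguous()` on a torch tensor rather than a Python transpose.

```python
def transpose(matrix):
    """[d_model][n_features] -> [n_features][d_model] (on-disk -> canonical)."""
    return [list(col) for col in zip(*matrix)]


def jumprelu_encode(w_enc_disk, b_enc, threshold, residual):
    """pre = W_enc.T @ residual + b_enc; acts = pre * (pre > threshold)."""
    w_enc = transpose(w_enc_disk)  # now one row per feature
    acts = []
    for row, bias, theta in zip(w_enc, b_enc, threshold):
        pre = sum(w * x for w, x in zip(row, residual)) + bias
        # Unlike ReLU, a positive pre-activation below its threshold is zeroed.
        acts.append(pre if pre > theta else 0.0)
    return acts
```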
**Validation** (candle-mi, V3 Step 1.6 / Phase A.7 — 9/9 cases pass on
CPU with max abs-diff 4.20e-5):
```bash
cargo test --test validate_plt_gemma --features clt,sae,transformer -- --ignored
```
## Vocab scan and figure13 helpers (v0.1.11)
Helpers for the [Anthropic Appendix B vocabulary-scan protocol](https://transformer-circuits.pub/2025/attribution-graphs/biology.html)
and figure13 preset construction. All four scripts are pure Python (no
`torch` dependency) and operate on the JSON output of the
`vocab_scan` Rust example.
| File | Purpose |
|------|---------|
| `vocab_scan_cmudict_filter.py` | Filter the raw `vocab_scan` JSON (gitignored, ~500 MB to 1.5 GB) through `nltk.corpus.cmudict`. For each feature, look up the CMUdict pronunciation of each top-K token, deduplicate by normalised English word, and flag features whose deduplicated tokens share a single ARPABET rime (cluster size ≥ 3, share ≥ 0.5 of CMU-resolvable words). Emits a phonologically-clean subset JSON (~2 MB, committable) plus a rhyme-group histogram. |
| `pick_features.py` | Read a clean-subset JSON and print the top-5 features per requested rime, sorted by `max_cosine`. Used to choose **suppress** features for `figure13_planning_poems` presets (we want broad rime-cluster coverage). |
| `pick_inject_feature.py` | Read a raw scan JSON and rank all features by the cosine they assign to a specific target word in their top-K. Used to choose **inject** features when the target word's identity matters more than its rime-cluster membership (e.g., `" myself"` for `-ation` poems where the prompt has no natural `-self` prior). |
| `inspect_grid.py` | Pretty-print the per-strength max-ratio profile of a `figure13_planning_poems --strength-grid` output JSON. Quick sanity-check for the 2D position × strength sweeps; shows planning-site probability per strength alongside the per-strength max. |
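The flagging rule in `vocab_scan_cmudict_filter.py` can be approximated as below. The following details are assumptions and are not taken from the script:

- The rime is the phone sequence from the last stressed vowel (stress 1 or 2) to the end, with stress digits stripped.
- Only the first CMUdict pronunciation of each word is used.
- Normalisation is `strip().lower()`.

`pron` has the shape returned by `nltk.corpus.cmudict.dict()`, which maps each word to a list of pronunciations.

```python
from collections import Counter


def rime(phones):
    """Phones from the last stressed vowel (stress 1 or 2) to the end, with
    stress digits stripped. Falls back to the last vowel of any stress."""
    vowels = [i for i, p in enumerate(phones) if p[-1].isdigit()]
    if not vowels:
        return None
    stressed = [i for i in vowels if phones[i][-1] in "12"]
    start = (stressed or vowels)[-1]
    return tuple(p.rstrip("012") for p in phones[start:])


def flag_feature(tokens, pron, min_cluster=3, min_share=0.5):
    """Return the shared rime if the feature's top-K tokens form a rhyme cluster."""
    words = {t.strip().lower() for t in tokens}  # dedupe by normalised word
    rimes = [rime(pron[w][0]) for w in words if w in pron]  # CMU-resolvable only
    rimes = [r for r in rimes if r is not None]
    if not rimes:
        return None
    best, size = Counter(rimes).most_common(1)[0]
    if size >= min_cluster and size / len(rimes) >= min_share:
        return best
    return None
```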
**Typical pipeline** (regenerate any cell in the
[cross-size sweep](../docs/experiments/figure13-qwen3-cross-size.md)):
```bash
# 1. Vocab scan (Rust; outputs ~500 MB–1.5 GB raw JSON, gitignored).
cargo run --release --features clt,transformer,mmap --example vocab_scan -- \
--model <model_id> --clt-repo <clt_repo> --output <raw_json>
# 2. Filter through CMUdict (Python; outputs ~1–2 MB clean JSON, committable).
python scripts/vocab_scan_cmudict_filter.py <raw_json> --clean-only-output \
--output <clean_json>
# 3a. Pick suppress features (top-5 per rime cluster).
python scripts/pick_features.py # edit `rimes` list inline
# 3b. Pick inject features (top-5 by cosine to a target word).
python scripts/pick_inject_feature.py <raw_json> " myself" " duration"
# 4. Run figure13 sweep (2D grid; outputs ~10–50 KB committable JSON).
cargo run --release --features clt,transformer,mmap --example figure13_planning_poems -- \
--preset <preset_name> --strength-grid 0.5,1,2.5,5,10,25,50,100 \
--output <grid_json>
# 5. Quick inspect.
python scripts/inspect_grid.py <grid_json>
```
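Step 3b ranks features by the cosine they assign to a target word. Below is a sketch of that ranking. It assumes a hypothetical scan schema in which each feature has an `index` and a `top_tokens` list of `{"token", "cosine"}` entries; the real `vocab_scan` JSON field names may differ.

```python
def rank_by_target(features, target, top=5):
    """Rank features by the cosine they assign to `target` in their top-K list.

    Features whose top-K list does not contain `target` are skipped.
    """
    scored = []
    for feat in features:
        for entry in feat["top_tokens"]:
            if entry["token"] == target:
                scored.append((entry["cosine"], feat["index"]))
                break
    scored.sort(reverse=True)
    return [(index, cos) for cos, index in scored[:top]]
```

The command line above passes two targets (`" myself"` and `" duration"`). With this sketch, that corresponds to one `rank_by_target` call per target.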
CMUdict is loaded via `nltk.corpus.cmudict` and is assumed to be
downloaded already (`python -c "import nltk; nltk.download('cmudict')"`).
On the Windows development machine it is pre-cached at
`~/AppData/Roaming/nltk_data/corpora/cmudict/cmudict`.
## Maar replication helpers (v0.1.12)
Helpers for the Maar et al. (2026) contrastive activation steering
replication (`examples/maar_contrastive_steering`).
| File | Purpose |
|------|---------|
| `maar_supplementary_fetch.py` | Attempt to download Maar's supplementary `.zip` from OpenReview `Z10pxu0Q7X`. The OpenReview attachment endpoint typically requires a logged-in session, so this script will often print a manual-download URL the user opens in a browser. After download, the supplementary's `paper_experiments/` directory contains the per-model layer + position + strength specs not documented in the paper text. |
| `regenerate_rhyme_prompts.py` | Step B fallback: when Maar's supplementary is unavailable, generate candle-mi-authored prompt sets following Maar's described structure (85 positive + 85 negative + 20 held-out eval per rhyme family, template `"A rhyming couplet:\n{line}"`). Uses hand-curated couplet-stem templates and CMUdict-validated rhyme word lists. Output JSON matches the schema `examples/maar_contrastive_steering` expects. |
| `convert_maar_prompts.py` | Step A continuation: when Maar's supplementary IS available (downloaded by `maar_supplementary_fetch.py` and extracted to `examples/results/maar_contrastive_steering/maar_supp/`), parse `paper_experiments/data/{train,test}/rhyme_family_lines.json` plus the `rhyme_family_words` dict embedded in `shared_utils.py`, and emit the four `*_maar.json` files the `maar_contrastive_steering` example consumes (one per cell). The output JSONs carry `"source": "maar-supplementary"` and a `"source_url"` so reviewers can audit provenance without us committing Maar's 60 MB code drop. Round-trip-verified: rerunning the script reproduces the committed prompts JSONs byte-identically. |
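The Step B prompt layout can be sketched as a small validation-and-wrap helper. Each rhyme family gets 85 positive, 85 negative, and 20 held-out stems, each wrapped in the couplet template. The function name and output keys are hypothetical and do not reproduce the schema `examples/maar_contrastive_steering` expects.

```python
TEMPLATE = "A rhyming couplet:\n{line}"


def build_prompt_set(family, positive, negative, held_out, n_pairs=85, n_eval=20):
    """Wrap couplet stems in the template after checking the split sizes."""
    if len(positive) != n_pairs or len(negative) != n_pairs:
        raise ValueError(f"expected {n_pairs} positive and {n_pairs} negative stems")
    if len(held_out) != n_eval:
        raise ValueError(f"expected {n_eval} held-out stems")

    def wrap(stems):
        return [TEMPLATE.format(line=s) for s in stems]

    return {
        "family": family,
        "positive": wrap(positive),
        "negative": wrap(negative),
        "eval": wrap(held_out),
    }
```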
**Typical pipeline**:
```bash
# 1a. Try to fetch Maar's supplementary (Step A).
python scripts/maar_supplementary_fetch.py \
--output examples/results/maar_contrastive_steering/maar_supplementary.zip
# 1b. If Step A succeeded, extract the .zip to
#     examples/results/maar_contrastive_steering/maar_supp/, then convert
#     Maar's prompts to the candle-mi schema.
python scripts/convert_maar_prompts.py
# 2. If Step A failed, generate prompts ourselves (Step B fallback).
python scripts/regenerate_rhyme_prompts.py \
--family -ee --contrast -out \
--output examples/results/maar_contrastive_steering/prompts/llama32_3b_rhyme_ee.json
# 3. Run the example with Maar-faithful protocol (Maar-supplementary prompts
# + raw direction + generated-couplet metric + 25 new tokens).
cargo run --release --features transformer,mmap --example maar_contrastive_steering -- \
--preset llama32-3b-rhyme-ee \
--prompt-file examples/results/maar_contrastive_steering/prompts/llama32_3b_rhyme_ee_maar.json \
--layer-grid 22 --strength-grid 1.5 \
--normalise=false --position-strategy last \
--metric generated-couplet --max-new-tokens 25 \
--output docs/experiments/maar-replication/llama32_3b_rhyme_ee_maar_prompts.json
```
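Step 3 steers with a contrastive direction. The sketch below assumes the usual difference-of-means recipe and builds the direction from positive and negative activations. It skips unit normalisation when `normalise=False`, which is how `--normalise=false` is read here. It then adds `strength × direction` to the last-position hidden state. This is an illustration, not candle-mi's or Maar's implementation, and layer selection is left out.

```python
import math


def mean_vector(rows):
    """Element-wise mean of equal-length vectors."""
    n = len(rows)
    return [sum(col) / n for col in zip(*rows)]


def contrastive_direction(pos_acts, neg_acts, normalise=False):
    """Difference of mean activations: mean(positive) - mean(negative)."""
    direction = [p - q for p, q in zip(mean_vector(pos_acts), mean_vector(neg_acts))]
    if normalise:
        norm = math.sqrt(sum(d * d for d in direction))
        direction = [d / norm for d in direction]
    return direction


def steer(hidden_last, direction, strength):
    """Add strength * direction to the last-position hidden state."""
    return [h + strength * d for h, d in zip(hidden_last, direction)]
```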
CMUdict is loaded via `nltk.corpus.cmudict` (same dependency as
`vocab_scan_cmudict_filter.py`).
## Integration Test Commands
All integration tests require models cached in `~/.cache/huggingface/hub/`.
The GPU-backed tests also need a CUDA GPU with at least 16 GiB VRAM; the
GemmaScope PLT test runs on CPU. Run with `--test-threads=1` to avoid OOM
on 3B+ models.
```bash
# Model validation (6 families, CPU + GPU)
cargo test --test validate_models --features transformer -- --test-threads=1
# CLT validation (Gemma 2 2B + Llama 3.2 1B)
cargo test --test validate_clt --features clt,transformer -- --ignored --test-threads=1
# Anacrousis validation (Llama 3.2 1B, 28 conditions x 15 couplets)
cargo test --test validate_anacrousis --features transformer -- --ignored --test-threads=1
# SAE validation (Gemma 2 2B + Gemma Scope SAE)
cargo test --test validate_sae --features sae,transformer -- --ignored --test-threads=1
# PLT validation (Llama 3.2 1B, v0.1.9; lands in V3 Step 1.5)
cargo test --test validate_plt --features clt,transformer -- --ignored --test-threads=1
# PLT validation (Gemma 2 2B GemmaScope, v0.1.10 — Phase A.7, CPU)
cargo test --test validate_plt_gemma --features clt,sae,transformer -- --ignored
```