// candle-mi 0.1.12
//
// Mechanistic interpretability for language models in Rust, built on candle
// Documentation
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Integration test: PLT (Per-Layer Transcoder) encoder parity against the
//! from-first-principles Python oracle in `scripts/plt_llama_validation.py`.
//!
//! Consumes the frozen reference JSON (`scripts/plt_llama_reference.json`)
//! generated by V3 Step 1.4 and verifies that candle-mi's
//! [`CrossLayerTranscoder`] encoder produces matching output when fed the same
//! residual vectors. Acceptance bar per V3 Step 1.5:
//!
//! - Detected schema is [`TranscoderSchema::PltBundle`].
//! - `(d_model, n_features_per_layer)` match the Python run.
//! - Per test case: active-feature count matches, top-10 feature indices match
//!   exactly, activation magnitudes within `abs diff < 1e-4` at F32.
//!
//! Runs on CPU: the Python oracle is CPU-only, so same-device comparison gives
//! the closest possible numerical match (no CUDA-vs-CPU rounding noise).
//!
//! Requires `mntss/transcoder-Llama-3.2-1B` cached in `~/.cache/huggingface/hub/`
//! (~16 GiB of safetensors). `#[ignore]`-gated so it does not run by default.
//!
//! Run:
//!   `cargo test --test validate_plt --features clt,transformer -- --ignored`

#![allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::indexing_slicing,
    clippy::cast_possible_truncation,
    clippy::as_conversions,
    clippy::missing_docs_in_private_items,
    clippy::missing_panics_doc,
    missing_docs
)]

use std::collections::HashMap;
use std::path::PathBuf;

use candle_core::{Device, Tensor};
use candle_mi::clt::{CrossLayerTranscoder, TranscoderSchema};

/// Location of the frozen Python-oracle reference JSON:
/// `<crate manifest dir>/scripts/plt_llama_reference.json`.
///
/// The manifest directory is baked in at compile time via
/// `env!("CARGO_MANIFEST_DIR")`.
fn reference_path() -> PathBuf {
    let mut json_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    json_path.push("scripts");
    json_path.push("plt_llama_reference.json");
    json_path
}

/// Checks candle-mi's PLT encoder against the frozen Python oracle.
///
/// Flow:
/// 1. Read the reference JSON (see `reference_path`) and extract the repo id,
///    schema name, dimensions, and test cases.
/// 2. Open the transcoder named by the JSON and assert that its detected
///    schema and dimensions agree with the oracle.
/// 3. Group the cases by layer, load each layer's encoder once, and for every
///    case compare the active-feature count plus the index, layer, and
///    activation (abs diff < 1e-4) of each oracle top-10 entry.
///
/// # Panics
///
/// Panics on a missing or ill-typed JSON field, on any loader/encoder error,
/// and on any parity mismatch — that is how the test reports failure.
#[test]
#[ignore = "requires mntss/transcoder-Llama-3.2-1B cached (~16 GiB); run with --ignored"]
// The body is a flat sequence (load reference → open PLT → assert schema →
// group cases by layer → iterate → compare) with a lot of inline validation
// messages. Extracting helpers would fragment a naturally linear test story,
// and the CONVENTIONS annotations + assertion messages account for a sizeable
// share of the length. Pedantic length lint suppressed intentionally.
#[allow(clippy::too_many_lines)]
fn validate_plt_llama_encoder_against_python_oracle() {
    // --- Load the frozen reference JSON ---
    let reference_str = std::fs::read_to_string(reference_path()).expect(
        "failed to read plt_llama_reference.json — run scripts/plt_llama_validation.py first",
    );
    let reference: serde_json::Value = serde_json::from_str(&reference_str).unwrap();

    let plt_repo = reference["plt_repo"].as_str().unwrap();
    let ref_schema = reference["schema"].as_str().unwrap();
    // CAST: u64 → usize, JSON integer known to fit (PLT d_model = 2048)
    let d_model = reference["d_model"].as_u64().unwrap() as usize;
    // CAST: u64 → usize, JSON integer known to fit (PLT n_features = 131072)
    let n_features_per_layer = reference["n_features_per_layer"].as_u64().unwrap() as usize;
    let test_cases = reference["test_cases"].as_array().unwrap();

    assert_eq!(
        ref_schema, "PltBundle",
        "oracle JSON schema field must be PltBundle"
    );

    println!("Validating PLT encoder parity for {plt_repo}");
    println!("  d_model = {d_model}, n_features_per_layer = {n_features_per_layer}");
    println!("  {} test cases to check", test_cases.len());

    // --- Open the PLT; verify schema and dimensions match ---
    let mut plt = CrossLayerTranscoder::open(plt_repo)
        .expect("failed to open PLT — ensure the model is in the HF cache");

    assert_eq!(
        plt.config().schema,
        TranscoderSchema::PltBundle,
        "open() must detect PltBundle for mntss/transcoder-Llama-3.2-1B"
    );
    assert_eq!(
        plt.config().d_model,
        d_model,
        "d_model mismatch with oracle"
    );
    assert_eq!(
        plt.config().n_features_per_layer,
        n_features_per_layer,
        "n_features_per_layer mismatch with oracle"
    );

    let device = Device::Cpu;

    // --- Group test cases by layer so each encoder is loaded exactly once ---
    let mut by_layer: HashMap<usize, Vec<&serde_json::Value>> = HashMap::new();
    for tc in test_cases {
        // CAST: u64 → usize, JSON layer index known to be 0..16 for Llama 3.2 1B
        let layer = tc["layer"].as_u64().unwrap() as usize;
        by_layer.entry(layer).or_default().push(tc);
    }

    let mut total_cases = 0_usize;
    let mut max_abs_diff: f32 = 0.0;

    // Iterate layers in sorted order for reproducible output.
    let mut layers: Vec<usize> = by_layer.keys().copied().collect();
    layers.sort_unstable();

    for layer in layers {
        plt.load_encoder(layer, &device).unwrap();
        println!("Layer {layer}:");

        // INDEX: by_layer was populated from the same `layers` keys we are
        // iterating — every `layer` is guaranteed to be a key.
        for tc in &by_layer[&layer] {
            let seed = tc["seed"].as_u64().unwrap();
            let residual_vec: Vec<f32> = tc["residual"]
                .as_array()
                .unwrap()
                .iter()
                // CAST: f64 → f32, JSON residual stored as Python-float (f64)
                // but candle-mi's encoder works in F32; matches oracle's input dtype.
                .map(|v| v.as_f64().unwrap() as f32)
                .collect();
            // CAST: u64 → usize, JSON count bounded by n_features_per_layer
            let ref_n_active = tc["n_active"].as_u64().unwrap() as usize;
            let ref_top10 = tc["top_10"].as_array().unwrap();

            assert_eq!(
                residual_vec.len(),
                d_model,
                "layer {layer} seed {seed}: residual length {} != d_model {d_model}",
                residual_vec.len()
            );

            // Build the residual tensor on CPU and run the Rust encoder.
            let residual = Tensor::from_vec(residual_vec, (d_model,), &device).unwrap();
            let sparse = plt.encode(&residual, layer).unwrap();

            // --- Active-feature count ---
            assert_eq!(
                sparse.features.len(),
                ref_n_active,
                "layer {layer} seed {seed}: n_active mismatch (Rust {}, Python {})",
                sparse.features.len(),
                ref_n_active
            );

            // --- Top-10 indices + activations ---
            for (rank, ref_item) in ref_top10.iter().enumerate() {
                // CAST: u64 → usize, JSON feature index bounded by n_features_per_layer
                let ref_idx = ref_item["index"].as_u64().unwrap() as usize;
                // CAST: f64 → f32, activation magnitude down-cast to match candle-mi's F32 encoder
                let ref_act = ref_item["activation"].as_f64().unwrap() as f32;

                let (rust_fid, rust_act) = sparse.features.get(rank).unwrap_or_else(|| {
                    panic!(
                        "layer {layer} seed {seed}: Rust top-{} shorter than Python's",
                        rank + 1
                    )
                });

                assert_eq!(
                    rust_fid.index, ref_idx,
                    "layer {layer} seed {seed} rank {rank}: index mismatch \
                     (Rust {}, Python {ref_idx})",
                    rust_fid.index
                );
                assert_eq!(
                    rust_fid.layer, layer,
                    "layer {layer} seed {seed} rank {rank}: feature.layer {} != test layer",
                    rust_fid.layer
                );

                let diff = (*rust_act - ref_act).abs();
                assert!(
                    diff < 1e-4,
                    "layer {layer} seed {seed} rank {rank}: activation abs-diff {diff:.2e} >= 1e-4 \
                     (Rust {rust_act}, Python {ref_act})"
                );
                if diff > max_abs_diff {
                    max_abs_diff = diff;
                }
            }

            // Nothing above guarantees at least one active feature: a case
            // with `n_active == 0` and an empty `top_10` passes every parity
            // check. Use `first()` rather than indexing `[0]` so such a case
            // is reported as `top=none` instead of panicking on the progress
            // line. Non-empty cases print exactly as before.
            let top_label = sparse
                .features
                .first()
                .map_or_else(|| "none".to_owned(), |(fid, _)| fid.to_string());
            println!(
                "  seed={seed:4}: {} active / {} features, top={top_label}, top-10 matches \
                 (max abs-diff so far = {max_abs_diff:.2e})",
                sparse.features.len(),
                n_features_per_layer,
            );
            total_cases += 1;
        }
    }

    println!(
        "\n{total_cases} test cases passed; max abs-diff across all top-10 activations = \
         {max_abs_diff:.2e} (bar: 1e-4)"
    );
}