#![allow(
clippy::unwrap_used,
clippy::expect_used,
clippy::panic,
clippy::indexing_slicing,
clippy::cast_possible_truncation,
clippy::as_conversions,
clippy::missing_docs_in_private_items,
clippy::missing_panics_doc,
missing_docs
)]
use std::collections::HashMap;
use std::path::PathBuf;
use candle_core::{Device, Tensor};
use candle_mi::clt::{CrossLayerTranscoder, TranscoderSchema};
/// Location of the JSON oracle emitted by `scripts/plt_llama_validation.py`,
/// resolved relative to this crate's manifest directory.
fn reference_path() -> PathBuf {
    let mut oracle = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    oracle.push("scripts");
    oracle.push("plt_llama_reference.json");
    oracle
}
/// End-to-end parity check of the PLT encoder against a Python oracle.
///
/// Reads the oracle JSON returned by `reference_path()` (generated by
/// `scripts/plt_llama_validation.py`). It then opens the transcoder named by
/// the oracle's `plt_repo` field and checks that its schema, `d_model` and
/// `n_features_per_layer` match the oracle.
///
/// For every recorded test case it encodes the stored residual vector and asserts:
/// * the number of active features equals the oracle's `n_active`;
/// * for each rank in the oracle's `top_10` list, the Rust feature index and
///   layer match, and the activation agrees to within an absolute 1e-4.
///
/// Encoders are loaded one layer at a time, in ascending layer order.
#[test]
#[ignore = "requires mntss/transcoder-Llama-3.2-1B cached (~16 GiB); run with --ignored"]
#[allow(clippy::too_many_lines)]
fn validate_plt_llama_encoder_against_python_oracle() {
    let reference_str = std::fs::read_to_string(reference_path()).expect(
        "failed to read plt_llama_reference.json — run scripts/plt_llama_validation.py first",
    );
    let reference: serde_json::Value = serde_json::from_str(&reference_str).unwrap();
    let plt_repo = reference["plt_repo"].as_str().unwrap();
    let ref_schema = reference["schema"].as_str().unwrap();
    let d_model = reference["d_model"].as_u64().unwrap() as usize;
    let n_features_per_layer = reference["n_features_per_layer"].as_u64().unwrap() as usize;
    let test_cases = reference["test_cases"].as_array().unwrap();
    assert_eq!(
        ref_schema, "PltBundle",
        "oracle JSON schema field must be PltBundle"
    );
    // A parity test that compares nothing must not report success: an empty
    // oracle would otherwise pass with "0 test cases passed".
    assert!(
        !test_cases.is_empty(),
        "oracle JSON contains no test cases — regenerate it with scripts/plt_llama_validation.py"
    );
    println!("Validating PLT encoder parity for {plt_repo}");
    println!(" d_model = {d_model}, n_features_per_layer = {n_features_per_layer}");
    println!(" {} test cases to check", test_cases.len());

    let mut plt = CrossLayerTranscoder::open(plt_repo)
        .expect("failed to open PLT — ensure the model is in the HF cache");
    assert_eq!(
        plt.config().schema,
        TranscoderSchema::PltBundle,
        "open() must detect PltBundle for mntss/transcoder-Llama-3.2-1B"
    );
    assert_eq!(
        plt.config().d_model,
        d_model,
        "d_model mismatch with oracle"
    );
    assert_eq!(
        plt.config().n_features_per_layer,
        n_features_per_layer,
        "n_features_per_layer mismatch with oracle"
    );

    let device = Device::Cpu;

    // Group cases by layer so each layer's encoder is loaded exactly once.
    let mut by_layer: HashMap<usize, Vec<&serde_json::Value>> = HashMap::new();
    for tc in test_cases {
        let layer = tc["layer"].as_u64().unwrap() as usize;
        by_layer.entry(layer).or_default().push(tc);
    }

    let mut total_cases = 0_usize;
    let mut max_abs_diff: f32 = 0.0;
    let mut layers: Vec<usize> = by_layer.keys().copied().collect();
    layers.sort_unstable();
    for layer in layers {
        plt.load_encoder(layer, &device).unwrap();
        println!("Layer {layer}:");
        for tc in &by_layer[&layer] {
            let seed = tc["seed"].as_u64().unwrap();
            let residual_vec: Vec<f32> = tc["residual"]
                .as_array()
                .unwrap()
                .iter()
                .map(|v| v.as_f64().unwrap() as f32)
                .collect();
            let ref_n_active = tc["n_active"].as_u64().unwrap() as usize;
            let ref_top10 = tc["top_10"].as_array().unwrap();
            assert_eq!(
                residual_vec.len(),
                d_model,
                "layer {layer} seed {seed}: residual length {} != d_model {d_model}",
                residual_vec.len()
            );

            let residual = Tensor::from_vec(residual_vec, (d_model,), &device).unwrap();
            let sparse = plt.encode(&residual, layer).unwrap();
            assert_eq!(
                sparse.features.len(),
                ref_n_active,
                "layer {layer} seed {seed}: n_active mismatch (Rust {}, Python {})",
                sparse.features.len(),
                ref_n_active
            );

            // Compare rank by rank. This relies on both sides listing features
            // in the same order (presumably descending activation).
            for (rank, ref_item) in ref_top10.iter().enumerate() {
                let ref_idx = ref_item["index"].as_u64().unwrap() as usize;
                let ref_act = ref_item["activation"].as_f64().unwrap() as f32;
                let (rust_fid, rust_act) = sparse.features.get(rank).unwrap_or_else(|| {
                    panic!(
                        "layer {layer} seed {seed}: Rust top-{} shorter than Python's",
                        rank + 1
                    )
                });
                assert_eq!(
                    rust_fid.index, ref_idx,
                    "layer {layer} seed {seed} rank {rank}: index mismatch \
                     (Rust {}, Python {ref_idx})",
                    rust_fid.index
                );
                assert_eq!(
                    rust_fid.layer, layer,
                    "layer {layer} seed {seed} rank {rank}: feature.layer {} != test layer",
                    rust_fid.layer
                );
                let diff = (*rust_act - ref_act).abs();
                assert!(
                    diff < 1e-4,
                    "layer {layer} seed {seed} rank {rank}: activation abs-diff {diff:.2e} >= 1e-4 \
                     (Rust {rust_act}, Python {ref_act})"
                );
                // `diff` is finite here: a NaN would already have failed the assert above.
                max_abs_diff = max_abs_diff.max(diff);
            }

            // A case with zero active features is legitimate (n_active == 0 was
            // checked above). Indexing `features[0]` would then panic with an
            // opaque out-of-bounds error, so report "none" instead.
            let top_feature = sparse
                .features
                .first()
                .map_or_else(|| "none".to_owned(), |(fid, _)| fid.to_string());
            println!(
                " seed={seed:4}: {} active / {} features, top={top_feature}, top-10 matches \
                 (max abs-diff so far = {max_abs_diff:.2e})",
                sparse.features.len(),
                n_features_per_layer,
            );
            total_cases += 1;
        }
    }
    println!(
        "\n{total_cases} test cases passed; max abs-diff across all top-10 activations = \
         {max_abs_diff:.2e} (bar: 1e-4)"
    );
}