#![allow(clippy::doc_markdown)]
#![allow(clippy::missing_docs_in_private_items)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::cast_precision_loss)]
use candle_mi::interp::intervention::create_knockout_mask;
use candle_mi::interp::logit_lens::format_probability;
use candle_mi::{
AblationResult, HookPoint, HookSpec, Intervention, KnockoutSpec, MIModel, MITokenizer,
SUPPORTED_MODEL_TYPES,
};
#[cfg(feature = "memory")]
use candle_mi::{MemoryReport, MemorySnapshot};
use clap::Parser;
use serde::Serialize;
use std::path::{Path, PathBuf};
use std::time::Instant;
#[derive(Parser)]
#[command(name = "attention_knockout")]
#[command(about = "Attention knockout: ablate an attention edge and measure prediction impact")]
struct Args {
model: Option<String>,
#[arg(long)]
output: Option<PathBuf>,
}
#[derive(Serialize)]
struct JsonOutput {
model_id: String,
prompt: String,
n_layers: usize,
n_heads: usize,
hidden_size: usize,
target_layer: usize,
knockout_edge: [usize; 2],
kl_divergence: f32,
paris_logit_diff: Option<f32>,
top_changed: Vec<JsonChangedToken>,
}
#[derive(Serialize)]
struct JsonChangedToken {
rank: usize,
token_id: u32,
token: String,
baseline_prob: f32,
ablated_prob: f32,
abs_diff: f32,
}
fn main() {
if let Err(e) = run() {
eprintln!("Error: {e}");
std::process::exit(1);
}
}
fn run() -> candle_mi::Result<()> {
let args = Args::parse();
let prompt = "The capital of France is";
if let Some(ref model_id) = args.model {
return run_single_model(model_id, prompt, args.output.as_deref());
}
if args.output.is_some() {
eprintln!("Warning: --output is only supported with a specific model ID; ignoring.");
}
let cached = discover_cached_models();
if cached.is_empty() {
println!("No cached transformer models found in the HuggingFace Hub cache.");
println!("Download one first, e.g.:");
println!(" cargo run --example fast_download -- meta-llama/Llama-3.2-1B");
return Ok(());
}
println!(
"Found {} supported transformer(s) in HF cache:\n",
cached.len()
);
for (model_id, model_type, snapshot) in &cached {
println!("=== {model_id} (model_type: {model_type}) ===");
if let Err(e) = run_model(model_id, snapshot, prompt) {
println!(" Skipped: {e}\n");
}
}
Ok(())
}
fn run_single_model(
model_id: &str,
prompt: &str,
json_path: Option<&Path>,
) -> candle_mi::Result<()> {
println!("=== {model_id} ===");
#[cfg(feature = "memory")]
let mem_before = MemorySnapshot::now(
&candle_core::Device::cuda_if_available(0).unwrap_or(candle_core::Device::Cpu),
)?;
let t0 = Instant::now();
let model = MIModel::from_pretrained(model_id)?;
let load_time = t0.elapsed();
let n_layers = model.num_layers();
let n_heads = model.num_heads();
let hidden = model.hidden_size();
#[allow(clippy::cast_precision_loss, clippy::as_conversions)]
let weight_mb = estimate_weight_mb(n_layers, hidden);
println!(
" Layers: {n_layers}, heads: {n_heads}, hidden: {hidden}, device: {:?}",
model.device()
);
println!(" Estimated F32 weight size: {weight_mb:.0} MB");
println!(" Load time: {load_time:.2?}");
#[cfg(feature = "memory")]
{
let mem_after = MemorySnapshot::now(model.device())?;
MemoryReport::new(mem_before, mem_after).print_before_after("Model load");
}
let tokenizer = model.tokenizer().ok_or(candle_mi::MIError::Tokenizer(
"model has no embedded tokenizer".into(),
))?;
run_knockout(
&model, tokenizer, prompt, n_layers, n_heads, hidden, model_id, json_path,
)
}
fn hf_cache_dir() -> Option<PathBuf> {
if let Ok(cache) = std::env::var("HF_HOME") {
return Some(PathBuf::from(cache).join("hub"));
}
if let Ok(home) = std::env::var("USERPROFILE") {
let p = PathBuf::from(home)
.join(".cache")
.join("huggingface")
.join("hub");
if p.is_dir() {
return Some(p);
}
}
if let Ok(home) = std::env::var("HOME") {
let p = PathBuf::from(home)
.join(".cache")
.join("huggingface")
.join("hub");
if p.is_dir() {
return Some(p);
}
}
None
}
fn find_snapshot(cache_dir: &Path, model_id: &str) -> Option<PathBuf> {
let dir_name = format!("models--{}", model_id.replace('/', "--"));
let snapshots = cache_dir.join(dir_name).join("snapshots");
let entry = std::fs::read_dir(snapshots).ok()?.next()?.ok()?;
Some(entry.path())
}
fn read_model_type(snapshot: &Path) -> Option<String> {
let config_path = snapshot.join("config.json");
let text = std::fs::read_to_string(config_path).ok()?;
let json: serde_json::Value = serde_json::from_str(&text).ok()?;
json.get("model_type")?.as_str().map(String::from)
}
fn discover_cached_models() -> Vec<(String, String, PathBuf)> {
let Some(cache_dir) = hf_cache_dir() else {
return Vec::new();
};
let Ok(entries) = std::fs::read_dir(&cache_dir) else {
return Vec::new();
};
let mut models = Vec::new();
for entry in entries.flatten() {
let name = entry.file_name();
let Some(dir_name) = name.to_str() else {
continue;
};
let Some(rest) = dir_name.strip_prefix("models--") else {
continue;
};
let model_id = rest.replacen("--", "/", 1);
let Some(snapshot) = find_snapshot(&cache_dir, &model_id) else {
continue;
};
let Some(model_type) = read_model_type(&snapshot) else {
continue;
};
if SUPPORTED_MODEL_TYPES.contains(&model_type.as_str()) {
models.push((model_id, model_type, snapshot));
}
}
models.sort_by(|a, b| a.0.cmp(&b.0));
models
}
fn run_model(model_id: &str, snapshot: &Path, prompt: &str) -> candle_mi::Result<()> {
#[cfg(feature = "memory")]
let mem_before = MemorySnapshot::now(
&candle_core::Device::cuda_if_available(0).unwrap_or(candle_core::Device::Cpu),
)?;
let t0 = Instant::now();
let model = MIModel::from_pretrained(model_id)?;
let load_time = t0.elapsed();
let n_layers = model.num_layers();
let n_heads = model.num_heads();
let hidden = model.hidden_size();
#[allow(clippy::cast_precision_loss, clippy::as_conversions)]
let weight_mb = estimate_weight_mb(n_layers, hidden);
println!(
" {} layers, {} heads, {} hidden, device: {:?}",
n_layers,
n_heads,
hidden,
model.device()
);
println!(" Estimated F32 weight size: {weight_mb:.0} MB | Load: {load_time:.2?}");
#[cfg(feature = "memory")]
{
let mem_after = MemorySnapshot::now(model.device())?;
MemoryReport::new(mem_before, mem_after).print_before_after("Memory");
}
let tokenizer_path = snapshot.join("tokenizer.json");
if !tokenizer_path.exists() {
return Err(candle_mi::MIError::Tokenizer(
"tokenizer.json not found in snapshot".into(),
));
}
let tokenizer = MITokenizer::from_hf_path(tokenizer_path)?;
run_knockout(
&model, &tokenizer, prompt, n_layers, n_heads, hidden, model_id, None,
)
}
#[allow(clippy::too_many_arguments)]
fn run_knockout(
model: &MIModel,
tokenizer: &MITokenizer,
prompt: &str,
n_layers: usize,
n_heads: usize,
hidden: usize,
model_id: &str,
json_path: Option<&Path>,
) -> candle_mi::Result<()> {
let token_ids = tokenizer.encode(prompt)?;
let seq_len = token_ids.len();
let input = candle_core::Tensor::new(&token_ids[..], model.device())?.unsqueeze(0)?; println!(" Prompt: \"{prompt}\" ({seq_len} tokens)");
let target_layer = n_layers / 2;
println!(
" Knockout: all {n_heads} heads at layer {target_layer}, \
edge (last → position 0)\n"
);
let t1 = Instant::now();
let baseline_cache = model.forward(&input, &HookSpec::new())?;
let baseline_time = t1.elapsed();
let baseline_logits = baseline_cache.output().get(0)?.get(seq_len - 1)?;
let spec = KnockoutSpec::new().layer(target_layer).edge(seq_len - 1, 0);
let mask = create_knockout_mask(
&spec,
n_heads,
seq_len,
model.device(),
candle_core::DType::F32,
)?;
let mut ablated_hooks = HookSpec::new();
ablated_hooks.intervene(
HookPoint::AttnScores(target_layer),
Intervention::Knockout(mask),
);
let t2 = Instant::now();
let ablated_cache = model.forward(&input, &ablated_hooks)?;
let ablated_time = t2.elapsed();
let ablated_logits = ablated_cache.output().get(0)?.get(seq_len - 1)?;
println!(" Baseline forward : {baseline_time:.2?}");
println!(" Ablated forward : {ablated_time:.2?}");
let result = AblationResult::new(baseline_logits, ablated_logits, spec);
let kl = result.kl_divergence()?;
println!(" KL divergence (baseline || ablated): {kl:.6}");
let paris_tokens = tokenizer.encode(" Paris")?;
let paris_logit_diff = if let Some(&paris_id) = paris_tokens.last() {
let diff = result.logit_diff(paris_id)?;
println!(" Logit diff for \" Paris\" (token {paris_id}): {diff:+.4}");
Some(diff)
} else {
None
};
let changed = result.top_changed_tokens(10)?;
println!("\n Top-10 most changed tokens:");
println!(
" {:>4} {:>15} {:>10} {:>10} {:>10}",
"Rank", "Token", "Baseline", "Ablated", "|Diff|"
);
for (rank, &(token_id, baseline_p, ablated_p, abs_diff)) in changed.iter().enumerate() {
let token_text = tokenizer.decode(&[token_id])?;
println!(
" {:>4} {:>15} {:>10} {:>10} {:>10}",
rank + 1,
format!("\"{}\"", token_text.trim()),
format_probability(baseline_p),
format_probability(ablated_p),
format_probability(abs_diff),
);
}
println!();
if let Some(path) = json_path {
write_knockout_json(
path,
model_id,
prompt,
n_layers,
n_heads,
hidden,
target_layer,
seq_len,
kl,
paris_logit_diff,
&changed,
tokenizer,
)?;
}
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn write_knockout_json(
path: &Path,
model_id: &str,
prompt: &str,
n_layers: usize,
n_heads: usize,
hidden: usize,
target_layer: usize,
seq_len: usize,
kl: f32,
paris_logit_diff: Option<f32>,
changed: &[(u32, f32, f32, f32)],
tokenizer: &MITokenizer,
) -> candle_mi::Result<()> {
let top_changed: Vec<JsonChangedToken> = changed
.iter()
.enumerate()
.map(
|(rank, &(token_id, baseline_prob, ablated_prob, abs_diff))| {
let token = tokenizer
.decode(&[token_id])
.unwrap_or_else(|_| format!("[{token_id}]"));
JsonChangedToken {
rank: rank + 1,
token_id,
token,
baseline_prob,
ablated_prob,
abs_diff,
}
},
)
.collect();
let output = JsonOutput {
model_id: model_id.into(),
prompt: prompt.into(),
n_layers,
n_heads,
hidden_size: hidden,
target_layer,
knockout_edge: [seq_len - 1, 0],
kl_divergence: kl,
paris_logit_diff,
top_changed,
};
let json = serde_json::to_string_pretty(&output)
.map_err(|e| candle_mi::MIError::Config(format!("JSON serialization failed: {e}")))?;
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent).map_err(|e| {
candle_mi::MIError::Config(format!("failed to create {}: {e}", parent.display()))
})?;
}
std::fs::write(path, &json).map_err(|e| {
candle_mi::MIError::Config(format!("failed to write {}: {e}", path.display()))
})?;
println!(" JSON written to {}", path.display());
Ok(())
}
#[allow(clippy::cast_precision_loss, clippy::as_conversions)]
fn estimate_weight_mb(n_layers: usize, hidden: usize) -> f64 {
let params_per_layer = 12.0 * (hidden as f64) * (hidden as f64);
let total_params = (n_layers as f64) * params_per_layer;
total_params * 4.0 / 1_000_000.0
}