#![allow(clippy::doc_markdown)]
#![allow(clippy::missing_docs_in_private_items)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::too_many_lines)]
use candle_mi::interp::intervention::kl_divergence;
use candle_mi::{FullActivationCache, HookPoint, HookSpec, Intervention, MIModel, sample_token};
#[cfg(feature = "memory")]
use candle_mi::{MemoryReport, MemorySnapshot};
use clap::Parser;
use serde::Serialize;
use std::path::{Path, PathBuf};
use std::time::Instant;
#[derive(Parser)]
#[command(name = "counterfact_patching")]
#[command(about = "CounterFact activation patching (Li et al., 2025 / Transluce protocol)")]
struct Args {
#[arg(default_value = "meta-llama/Llama-3.2-1B")]
model: String,
#[arg(long)]
output: Option<PathBuf>,
#[arg(long, default_value = "4")]
block_size: usize,
#[arg(long)]
limit: Option<usize>,
#[arg(long, default_value = "data/counterfact_transluce_test.json")]
data: PathBuf,
#[arg(long)]
no_runtime: bool,
}
#[derive(serde::Deserialize)]
struct CounterFactPair {
original_prompt: String,
counterfactual_prompt: String,
gt_original_target: String,
gt_counterfactual_target: String,
}
impl CounterFactPair {
fn fact_span_len(&self) -> usize {
self.original_prompt
.find("\n\n")
.unwrap_or(self.original_prompt.len())
}
fn label(&self) -> String {
let end = self.fact_span_len().min(30);
self.original_prompt.chars().take(end).collect()
}
}
fn load_pairs(path: &Path) -> candle_mi::Result<Vec<CounterFactPair>> {
let data = std::fs::read_to_string(path).map_err(|e| {
candle_mi::MIError::Config(format!("failed to read {}: {e}", path.display()))
})?;
let pairs: Vec<CounterFactPair> = serde_json::from_str(&data).map_err(|e| {
candle_mi::MIError::Config(format!("failed to parse {}: {e}", path.display()))
})?;
Ok(pairs)
}
#[derive(Serialize)]
struct JsonOutput {
model_id: String,
block_size: usize,
n_layers: usize,
n_blocks: usize,
total_time_secs: f64,
results: Vec<PromptResult>,
summary: Summary,
}
#[derive(Serialize)]
struct PromptResult {
label: String,
original_prompt: String,
counterfactual_prompt: String,
gt_original_target: String,
gt_counterfactual_target: String,
input_tokens: Vec<String>,
original_continuation: String,
cf_continuation: String,
time_secs: f64,
patches: Vec<PatchResult>,
}
#[derive(Serialize)]
struct PatchResult {
layer: Vec<usize>,
orig_pos: usize,
cf_pos: usize,
ablated_continuation: String,
is_different: bool,
kl_from_original: f32,
token_type: String,
}
#[derive(Serialize)]
struct Summary {
pct_different_by_block: Vec<BlockSummary>,
pct_different_by_token_type: Vec<TypeSummary>,
total_patches: usize,
total_different: usize,
}
#[derive(Serialize)]
struct BlockSummary {
block: usize,
layers: Vec<usize>,
pct_different: f32,
count_different: usize,
count_total: usize,
}
#[derive(Serialize)]
struct TypeSummary {
token_type: String,
pct_different: f32,
count_different: usize,
count_total: usize,
}
fn replace_position(
base: &candle_core::Tensor,
patch_vector: &candle_core::Tensor,
patch_pos: usize,
seq_len: usize,
hidden: usize,
device: &candle_core::Device,
) -> candle_mi::Result<candle_core::Tensor> {
let mut mask_data = vec![0.0_f32; seq_len * hidden];
for i in 0..hidden {
#[allow(clippy::indexing_slicing)]
{
mask_data[patch_pos * hidden + i] = 1.0;
}
}
let mask = candle_core::Tensor::from_vec(mask_data, (seq_len, hidden), device)?;
let patch_broadcast = patch_vector.unsqueeze(0)?.broadcast_as((seq_len, hidden))?;
let one_minus_mask = (1.0 - &mask)?;
let result = (base * &one_minus_mask)? + (patch_broadcast * &mask)?;
Ok(result?)
}
fn build_block_patch(
orig_acts: &FullActivationCache,
cf_acts: &FullActivationCache,
block_layers: &[usize],
orig_pos: usize,
cf_pos: usize,
orig_seq_len: usize,
hidden: usize,
device: &candle_core::Device,
) -> candle_mi::Result<HookSpec> {
let mut hooks = HookSpec::new();
for &layer in block_layers {
let orig_resid = orig_acts
.get_layer(layer)
.ok_or_else(|| candle_mi::MIError::Hook(format!("layer {layer} not in orig cache")))?;
let cf_vector = cf_acts.get_position(layer, cf_pos)?;
let patched = replace_position(
orig_resid,
&cf_vector,
orig_pos,
orig_seq_len,
hidden,
device,
)?
.unsqueeze(0)?;
hooks.intervene(HookPoint::ResidPost(layer), Intervention::Replace(patched));
}
Ok(hooks)
}
fn compute_blocks(n_layers: usize, block_size: usize) -> Vec<Vec<usize>> {
let mut blocks = Vec::new();
let mut start = 0;
while start < n_layers {
let end = (start + block_size).min(n_layers);
blocks.push((start..end).collect());
start = end;
}
blocks
}
fn main() {
if let Err(e) = run() {
eprintln!("Error: {e}");
std::process::exit(1);
}
}
fn run() -> candle_mi::Result<()> {
let args = Args::parse();
#[cfg(feature = "memory")]
let mem_before = MemorySnapshot::now(
&candle_core::Device::cuda_if_available(0).unwrap_or(candle_core::Device::Cpu),
)?;
println!("Loading {}...", args.model);
let t0 = Instant::now();
let model = MIModel::from_pretrained(&args.model)?;
let load_time = t0.elapsed();
let tokenizer = model
.tokenizer()
.ok_or_else(|| candle_mi::MIError::Config("model has no tokenizer".into()))?;
let n_layers = model.num_layers();
let hidden = model.hidden_size();
let blocks = compute_blocks(n_layers, args.block_size);
let n_blocks = blocks.len();
println!(
" Layers: {n_layers}, hidden: {hidden}, device: {:?}",
model.device()
);
if !args.no_runtime {
println!(" Load time: {load_time:.2?}");
}
println!(" Block size: {}, blocks: {n_blocks}", args.block_size);
#[cfg(feature = "memory")]
{
let mem_after = MemorySnapshot::now(model.device())?;
MemoryReport::new(mem_before, mem_after).print_before_after("Model load");
}
let pairs = load_pairs(&args.data)?;
println!(" Data: {} pairs from {}", pairs.len(), args.data.display());
let n_pairs = args.limit.unwrap_or(pairs.len()).min(pairs.len());
let mut block_different: Vec<usize> = vec![0; n_blocks];
let mut block_total: Vec<usize> = vec![0; n_blocks];
let mut type_different: std::collections::HashMap<String, usize> =
std::collections::HashMap::new();
let mut type_total: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
let mut total_patches: usize = 0;
let mut total_different: usize = 0;
let mut prompt_results: Vec<PromptResult> = Vec::with_capacity(n_pairs);
let total_start = Instant::now();
for (pair_idx, pair) in pairs.iter().take(n_pairs).enumerate() {
let pair_start = Instant::now();
let label = pair.label();
println!("\n --- [{}/{}] {} ---", pair_idx + 1, n_pairs, label);
let capture_start = Instant::now();
let mut capture_hooks = HookSpec::new();
for layer in 0..n_layers {
capture_hooks.capture(HookPoint::ResidPost(layer));
}
let orig_result = model.forward_text(&pair.original_prompt, &capture_hooks)?;
let cf_result = model.forward_text(&pair.counterfactual_prompt, &capture_hooks)?;
let capture_time = capture_start.elapsed();
let orig_seq_len = orig_result.seq_len();
let cf_seq_len = cf_result.seq_len();
let fact_len = pair.fact_span_len();
let token_labels = orig_result.encoding().label_spans(&[("fact", 0..fact_len)]);
println!(
" GT: {} -> {}",
pair.gt_original_target, pair.gt_counterfactual_target
);
println!(" Tokens: {orig_seq_len} (original), {cf_seq_len} (counterfactual)");
let orig_logits = orig_result.output().get(0)?.get(orig_seq_len - 1)?; let orig_token_id = sample_token(&orig_logits, 0.0)?;
let orig_continuation = tokenizer.decode(&[orig_token_id])?;
let cf_logits = cf_result.output().get(0)?.get(cf_seq_len - 1)?;
let cf_token_id = sample_token(&cf_logits, 0.0)?;
let cf_continuation = tokenizer.decode(&[cf_token_id])?;
let mut orig_acts = FullActivationCache::with_capacity(n_layers);
for layer in 0..n_layers {
let resid = orig_result.require(&HookPoint::ResidPost(layer))?; orig_acts.push(resid.get(0)?); }
let mut cf_acts = FullActivationCache::with_capacity(n_layers);
for layer in 0..n_layers {
let resid = cf_result.require(&HookPoint::ResidPost(layer))?;
cf_acts.push(resid.get(0)?);
}
let orig_matches_gt = orig_continuation.trim() == pair.gt_original_target;
let cf_matches_gt = cf_continuation.trim() == pair.gt_counterfactual_target;
println!(
" Original prediction: {orig_continuation} {} | CF prediction: {cf_continuation} {}",
if orig_matches_gt { "correct" } else { "WRONG" },
if cf_matches_gt { "correct" } else { "WRONG" },
);
let sweep_start = Instant::now();
#[allow(clippy::as_conversions)]
let offset = cf_seq_len as isize - orig_seq_len as isize;
let orig_input = candle_core::Tensor::new(&orig_result.encoding().ids[..], model.device())?
.unsqueeze(0)?;
print!("\n {:>3} {:>14}", "Pos", "Token");
for block in &blocks {
#[allow(clippy::indexing_slicing)]
let first = block[0];
#[allow(clippy::indexing_slicing)]
let last = block[block.len() - 1];
print!(" {:>12}", format!("[{first}-{last}]"));
}
println!(" Type");
print!(" {:->3} {:->14}", "", "");
for _ in &blocks {
print!(" {:->12}", "");
}
println!(" {:->14}", "");
let mut patches: Vec<PatchResult> = Vec::new();
for orig_pos in 0..orig_seq_len {
#[allow(clippy::indexing_slicing)]
let ttype = &token_labels[orig_pos];
#[allow(clippy::as_conversions)]
let cf_pos = if ttype.starts_with("fact") {
orig_pos.min(cf_seq_len.saturating_sub(1))
} else {
let mapped = (orig_pos as isize + offset).max(0);
(mapped as usize).min(cf_seq_len.saturating_sub(1))
};
#[allow(clippy::indexing_slicing)]
let token_str = &orig_result.tokens()[orig_pos];
let short_label: String = token_str.chars().take(12).collect();
print!(" {orig_pos:>3} {short_label:>14}");
for (bi, block) in blocks.iter().enumerate() {
let patch_hooks = build_block_patch(
&orig_acts,
&cf_acts,
block,
orig_pos,
cf_pos,
orig_seq_len,
hidden,
model.device(),
)?;
let patched_cache = model.forward(&orig_input, &patch_hooks)?;
let patched_logits = patched_cache.output().get(0)?.get(orig_seq_len - 1)?;
let patched_token_id = sample_token(&patched_logits, 0.0)?;
let ablated_continuation = tokenizer.decode(&[patched_token_id])?;
let is_different = ablated_continuation != orig_continuation;
let kl = kl_divergence(&orig_logits, &patched_logits)?;
let marker = if is_different { "CHANGED" } else { "same" };
print!(" {marker:>12}");
#[allow(clippy::indexing_slicing)]
{
block_total[bi] += 1;
if is_different {
block_different[bi] += 1;
}
}
*type_total.entry(ttype.clone()).or_insert(0) += 1;
if is_different {
*type_different.entry(ttype.clone()).or_insert(0) += 1;
total_different += 1;
}
total_patches += 1;
patches.push(PatchResult {
layer: block.clone(),
orig_pos,
cf_pos,
ablated_continuation: ablated_continuation.clone(),
is_different,
kl_from_original: kl,
token_type: ttype.clone(),
});
}
println!(" {ttype}");
}
let sweep_time = sweep_start.elapsed();
let pair_time = pair_start.elapsed();
let pair_secs = pair_time.as_secs_f64();
let n_patches = patches.len();
let n_changed: usize = patches.iter().filter(|p| p.is_different).count();
if !args.no_runtime {
println!(
"\n Pair time: {pair_time:.2?} (capture: {capture_time:.2?}, sweep: {sweep_time:.2?}, {n_patches} patches, {n_changed} changed)"
);
}
prompt_results.push(PromptResult {
label,
original_prompt: pair.original_prompt.clone(),
counterfactual_prompt: pair.counterfactual_prompt.clone(),
gt_original_target: pair.gt_original_target.clone(),
gt_counterfactual_target: pair.gt_counterfactual_target.clone(),
input_tokens: orig_result.tokens().to_vec(),
original_continuation: orig_continuation,
cf_continuation,
time_secs: pair_secs,
patches,
});
}
let total_time = total_start.elapsed();
if args.no_runtime {
println!("\n === Summary ({n_pairs} prompts, {total_patches} patches) ===\n");
} else {
println!(
"\n === Summary ({n_pairs} prompts, {total_patches} patches, {total_time:.1?}) ===\n"
);
}
let mut block_summaries = Vec::with_capacity(n_blocks);
for (bi, block) in blocks.iter().enumerate() {
#[allow(clippy::indexing_slicing)]
let diff = block_different[bi];
#[allow(clippy::indexing_slicing)]
let total = block_total[bi];
#[allow(clippy::as_conversions)]
let pct = if total > 0 {
diff as f32 / total as f32 * 100.0
} else {
0.0
};
#[allow(clippy::indexing_slicing)]
let first = block[0];
#[allow(clippy::indexing_slicing)]
let last = block[block.len() - 1];
println!(" Block [{first:>2}-{last:>2}]: {diff:>4}/{total:<4} changed ({pct:.1}%)");
block_summaries.push(BlockSummary {
block: bi,
layers: block.clone(),
pct_different: pct,
count_different: diff,
count_total: total,
});
}
println!();
let mut type_summaries = Vec::new();
let mut type_keys: Vec<String> = type_total.keys().cloned().collect();
type_keys.sort();
for key in &type_keys {
let total = type_total.get(key).copied().unwrap_or(0);
let diff = type_different.get(key).copied().unwrap_or(0);
#[allow(clippy::as_conversions)]
let pct = if total > 0 {
diff as f32 / total as f32 * 100.0
} else {
0.0
};
println!(" {key:<16} {diff:>4}/{total:<4} changed ({pct:.1}%)");
type_summaries.push(TypeSummary {
token_type: key.clone(),
pct_different: pct,
count_different: diff,
count_total: total,
});
}
if let Some(path) = &args.output {
let output = JsonOutput {
model_id: args.model.clone(),
block_size: args.block_size,
n_layers,
n_blocks,
total_time_secs: total_time.as_secs_f64(),
results: prompt_results,
summary: Summary {
pct_different_by_block: block_summaries,
pct_different_by_token_type: type_summaries,
total_patches,
total_different,
},
};
write_json(path, &output)?;
println!("\n JSON written to {}", path.display());
}
println!();
Ok(())
}
fn write_json(path: &Path, output: &JsonOutput) -> candle_mi::Result<()> {
let json = serde_json::to_string_pretty(output)
.map_err(|e| candle_mi::MIError::Config(format!("JSON serialization failed: {e}")))?;
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent).map_err(|e| {
candle_mi::MIError::Config(format!("failed to create {}: {e}", parent.display()))
})?;
}
std::fs::write(path, &json).map_err(|e| {
candle_mi::MIError::Config(format!("failed to write {}: {e}", path.display()))
})?;
Ok(())
}