#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::missing_docs_in_private_items)]
#![allow(clippy::too_many_lines)]
use candle_core::{DType, Device, Tensor};
use candle_mi::clt::{CltFeatureId, CrossLayerTranscoder};
use candle_mi::{HookPoint, HookSpec, MIModel, extract_token_prob};
#[cfg(feature = "memory")]
use candle_mi::{MemoryReport, MemorySnapshot};
use clap::Parser;
use serde::Serialize;
use std::path::PathBuf;
use std::time::Instant;
// CLI options for the correction-test experiment.
//
// Field docs are deliberately plain `//` comments rather than `///`: clap's
// derive macro turns doc comments into user-visible --help text, which would
// change the binary's CLI output.
#[derive(Parser)]
#[command(name = "correction_test")]
#[command(about = "Test whether downstream layers can reverse a prolepsis commitment")]
struct Args {
    // Hugging Face model id to load.
    #[arg(long, default_value = "google/gemma-2-2b")]
    model: String,
    // Repo holding the cross-layer transcoder (CLT) weights.
    #[arg(long, default_value = "mntss/clt-gemma-2-2b-426k")]
    clt_repo: String,
    // Prompt text; DEFAULT_PROMPT is used when omitted.
    #[arg(long)]
    prompt: Option<String>,
    // Commitment feature spec, "L<layer>:<index>".
    #[arg(long, default_value = "L22:10243")]
    commit_feature: String,
    // Word whose next-token probability the commitment is tracked against.
    #[arg(long, default_value = "around")]
    commit_word: String,
    // Feature specs to suppress at the planning site (repeatable).
    #[arg(long)]
    suppress: Vec<String>,
    // Injection strength for the commitment; suppression uses its negative.
    #[arg(long, default_value_t = 10.0)]
    commit_strength: f32,
    // Correction feature spec, "L<layer>:<index>".
    #[arg(long, default_value = "L20:12386")]
    correct_feature: String,
    // Word the correction is tracked against.
    #[arg(long, default_value = "back")]
    correct_word: String,
    // First layer where correction hooks apply; validated to be >= the
    // correct_feature's source layer.
    #[arg(long, default_value_t = 23)]
    correct_from_layer: usize,
    // Upper end of the correction-strength sweep.
    #[arg(long, default_value_t = 20.0)]
    max_correct_strength: f32,
    // Number of sweep steps between 0 and max_correct_strength (inclusive).
    #[arg(long, default_value_t = 12)]
    correct_steps: usize,
    // Token position of the planning site; auto-detected when omitted.
    #[arg(long)]
    planning_site: Option<usize>,
    // Optional path for the JSON report.
    #[arg(long)]
    output: Option<PathBuf>,
    // Skip printing wall-clock timings.
    #[arg(long)]
    no_runtime: bool,
}
/// Full experiment report, serialized to JSON when `--output` is given.
#[derive(Serialize)]
struct JsonOutput {
    /// Model repo id used for the run.
    model_id: String,
    /// CLT repo id used for the run.
    clt_repo: String,
    /// Prompt text (before the trailing space added for tokenization).
    prompt: String,
    /// Token position where features are injected.
    planning_site: usize,
    /// Token position whose next-token distribution is measured.
    output_position: usize,
    n_layers: usize,
    n_heads: usize,
    /// Commitment feature spec as given on the CLI ("L<layer>:<index>").
    commit_feature: String,
    commit_word: String,
    /// Suppressed feature specs as given on the CLI.
    suppress_features: Vec<String>,
    commit_strength: f32,
    correct_feature: String,
    correct_word: String,
    correct_from_layer: usize,
    /// Measurements with no intervention.
    baseline: CorrectionPoint,
    /// Measurements with suppression + commitment but zero correction.
    commitment_only: CorrectionPoint,
    /// One entry per sweep step, strength 0 through max_correct_strength.
    correction_sweep: Vec<CorrectionPoint>,
    total_time_secs: f64,
}
/// Measurements from a single forward pass at one correction strength.
#[derive(Serialize)]
struct CorrectionPoint {
    /// Correction injection strength used for this pass (0.0 = none).
    correct_strength: f32,
    /// Probability of the commitment word at the output position.
    p_commit_word: f32,
    /// Probability of the correction word at the output position.
    p_correct_word: f32,
    /// Greedy top-1 token string.
    top1_token: String,
    /// Probability of that top-1 token.
    top1_prob: f32,
    /// Attention change vs. baseline at layer 21, head 5 (the "H5 delta"
    /// column of the sweep table); 0.0 when the model is too small.
    routing_head_delta: f32,
    /// Sum of |attention - baseline| over all (layer, head) pairs.
    total_routing_shift: f32,
    /// Wall-clock time for this pass.
    time_secs: f64,
}
/// Default prompt: a quatrain whose last line is left unfinished so the model
/// must pick the next word. The token "about" in the third line is the
/// auto-detected planning site (see `find_planning_site`).
const DEFAULT_PROMPT: &str = "The stars were twinkling in the night,\n\
                              The lanterns cast a golden light.\n\
                              She wandered in the dark about,\n\
                              And found a hidden passage";
/// Entry point: delegate to `run`, and on failure print the error to stderr
/// and exit with a non-zero status code.
fn main() {
    match run() {
        Ok(()) => {}
        Err(e) => {
            eprintln!("Error: {e}");
            std::process::exit(1);
        }
    }
}
/// Orchestrate the experiment: load the model and CLT, run a clean baseline
/// pass, a commitment-only pass, then sweep the correction strength from 0 to
/// `max_correct_strength`, printing progress to stderr and optionally writing
/// a JSON report.
fn run() -> candle_mi::Result<()> {
    tracing_subscriber::fmt::init();
    let args = Args::parse();
    let t_total = Instant::now();
    // Parse all feature specs up front so malformed CLI input fails early.
    let commit_feature = parse_clt_feature(&args.commit_feature)?;
    let correct_feature = parse_clt_feature(&args.correct_feature)?;
    let suppress_features: Vec<CltFeatureId> = args
        .suppress
        .iter()
        .map(|s| parse_clt_feature(s))
        .collect::<candle_mi::Result<Vec<_>>>()?;
    // The correction window must start at or after the layer that produces
    // the correcting feature.
    if args.correct_from_layer < correct_feature.layer {
        return Err(candle_mi::MIError::Config(format!(
            "correct_from_layer ({}) must be >= correct_feature source layer ({})",
            args.correct_from_layer, correct_feature.layer
        )));
    }
    eprintln!("=== Correction Test: Can Downstream Layers Reverse a Commitment? ===\n");
    eprintln!("Model: {}", args.model);
    eprintln!("CLT: {}", args.clt_repo);
    eprintln!(
        "Commit: {} (\"{}\")",
        commit_feature, args.commit_word
    );
    eprintln!(
        "Correct: {} (\"{}\")",
        correct_feature, args.correct_word
    );
    eprintln!("Correct from: L{}", args.correct_from_layer);
    eprintln!("Commit str: {}", args.commit_strength);
    eprintln!("Max correct str:{}", args.max_correct_strength);
    eprintln!("Steps: {}", args.correct_steps);
    if !suppress_features.is_empty() {
        eprintln!("Suppress: {:?}", suppress_features);
    }
    eprintln!("\nLoading model...");
    // With the `memory` feature enabled, snapshot device memory around the
    // model load and print a before/after report.
    #[cfg(feature = "memory")]
    let mem_before =
        MemorySnapshot::now(&candle_core::Device::cuda_if_available(0).unwrap_or(Device::Cpu))?;
    let t_load = Instant::now();
    let model = MIModel::from_pretrained(&args.model)?;
    let load_time = t_load.elapsed();
    #[cfg(feature = "memory")]
    {
        let mem_after = MemorySnapshot::now(model.device())?;
        MemoryReport::new(mem_before, mem_after).print_before_after("Model load");
    }
    let n_layers = model.num_layers();
    let n_heads = model.num_heads();
    let device = model.device().clone();
    let tokenizer = model
        .tokenizer()
        .ok_or_else(|| candle_mi::MIError::Tokenizer("model has no bundled tokenizer".into()))?;
    if !args.no_runtime {
        eprintln!(" Load time: {load_time:.2?}");
    }
    eprintln!(" {n_layers} layers, {n_heads} heads/layer, device={device:?}");
    eprintln!("Opening CLT: {}...", args.clt_repo);
    let mut clt = CrossLayerTranscoder::open(&args.clt_repo)?;
    // Pre-cache decoder vectors for every feature we will steer with, for all
    // downstream layers, so the per-pass hook setup is cheap.
    let mut all_features: Vec<CltFeatureId> = suppress_features.clone();
    all_features.push(commit_feature);
    all_features.push(correct_feature);
    eprintln!("Caching decoder vectors for all downstream layers...");
    clt.cache_steering_vectors_all_downstream(&all_features, &device)?;
    // Trailing space appended before tokenization — presumably so the next
    // predicted token starts a fresh word; confirm against the tokenizer.
    let prompt = args.prompt.as_deref().unwrap_or(DEFAULT_PROMPT);
    let prompt_with_space = format!("{prompt} ");
    let token_ids = tokenizer.encode(&prompt_with_space)?;
    let seq_len = token_ids.len();
    let output_pos = seq_len - 1;
    // Planning site: explicit CLI position, or the token "about" in the prompt.
    let planning_site = match args.planning_site {
        Some(pos) => pos,
        None => find_planning_site(tokenizer, &token_ids, "about")?,
    };
    let commit_token_id = tokenizer.find_token_id(&args.commit_word)?;
    let correct_token_id = tokenizer.find_token_id(&args.correct_word)?;
    eprintln!(
        " Commit token: \"{}\" (id={commit_token_id})",
        args.commit_word
    );
    eprintln!(
        " Correct token: \"{}\" (id={correct_token_id})",
        args.correct_word
    );
    eprintln!(" Planning site: position {planning_site}");
    eprintln!(" Output position: {output_pos}");
    // Each (feature, layer) pair is one injection entry: suppress/commit run
    // from the feature's own source layer to the top of the model, while
    // correction runs from `correct_from_layer` to the top.
    let suppress_entries: Vec<(CltFeatureId, usize)> = suppress_features
        .iter()
        .flat_map(|feat| (feat.layer..n_layers).map(move |l| (*feat, l)))
        .collect();
    let commit_entries: Vec<(CltFeatureId, usize)> = (commit_feature.layer..n_layers)
        .map(|l| (commit_feature, l))
        .collect();
    let correct_entries: Vec<(CltFeatureId, usize)> = (args.correct_from_layer..n_layers)
        .map(|l| (correct_feature, l))
        .collect();
    eprintln!(
        " Suppress: {} entries, Commit: {} entries, Correct: {} entries (L{}--L{})",
        suppress_entries.len(),
        commit_entries.len(),
        correct_entries.len(),
        args.correct_from_layer,
        n_layers - 1
    );
    let input = Tensor::new(&token_ids[..], &device)?.unsqueeze(0)?;
    // --- Pass 1: clean baseline, attention capture only. ---
    eprintln!("\n--- Baseline (no intervention) ---");
    let t_step = Instant::now();
    let mut baseline_hooks = HookSpec::new();
    for layer in 0..n_layers {
        baseline_hooks.capture(HookPoint::AttnPattern(layer));
    }
    let baseline_cache = model.forward(&input, &baseline_hooks)?;
    let baseline_attn = extract_attention_weights(
        &baseline_cache,
        n_layers,
        n_heads,
        output_pos,
        planning_site,
    )?;
    let baseline_p_commit = extract_token_prob(baseline_cache.output(), commit_token_id)?;
    let baseline_p_correct = extract_token_prob(baseline_cache.output(), correct_token_id)?;
    let baseline_top1 = extract_top1(&baseline_cache, output_pos, tokenizer)?;
    let baseline_time = t_step.elapsed();
    eprintln!(
        " P(\"{}\") = {baseline_p_commit:.6e}, P(\"{}\") = {baseline_p_correct:.6e}, top1 = \"{}\" ({:.4})",
        args.commit_word, args.correct_word, baseline_top1.0, baseline_top1.1
    );
    let baseline_point = CorrectionPoint {
        correct_strength: 0.0,
        p_commit_word: baseline_p_commit,
        p_correct_word: baseline_p_correct,
        top1_token: baseline_top1.0.clone(),
        top1_prob: baseline_top1.1,
        // Baseline is its own reference, so routing deltas are zero.
        routing_head_delta: 0.0,
        total_routing_shift: 0.0,
        time_secs: baseline_time.as_secs_f64(),
    };
    // --- Pass 2: commitment without any correction (strength 0.0). ---
    eprintln!("\n--- Commitment only (suppress + inject, no correction) ---");
    let commitment_point = run_correction_pass(
        &model,
        &input,
        &clt,
        &suppress_entries,
        &commit_entries,
        &correct_entries,
        n_layers,
        n_heads,
        seq_len,
        planning_site,
        output_pos,
        args.commit_strength,
        0.0, // correction strength: none
        commit_token_id,
        correct_token_id,
        &baseline_attn,
        tokenizer,
        &args.commit_word,
        &args.correct_word,
        &device,
        !args.no_runtime,
    )?;
    // --- Passes 3..N: sweep correction strength from 0 to the maximum. ---
    eprintln!(
        "\n--- Correction sweep (commit_str={}, correct L{}+) ---",
        args.commit_strength, args.correct_from_layer
    );
    eprintln!(
        " {:>6} {:>12} {:>12} {:>10} {:>10} {:>10}",
        "str", "P(commit)", "P(correct)", "top1", "H5 delta", "total"
    );
    let mut sweep: Vec<CorrectionPoint> = Vec::with_capacity(args.correct_steps + 1);
    let step_size = args.max_correct_strength / args.correct_steps as f32;
    for step in 0..=args.correct_steps {
        let correct_strength = step as f32 * step_size;
        let point = run_correction_pass(
            &model,
            &input,
            &clt,
            &suppress_entries,
            &commit_entries,
            &correct_entries,
            n_layers,
            n_heads,
            seq_len,
            planning_site,
            output_pos,
            args.commit_strength,
            correct_strength,
            commit_token_id,
            correct_token_id,
            &baseline_attn,
            tokenizer,
            &args.commit_word,
            &args.correct_word,
            &device,
            false, // quiet: the table row printed below is the output
        )?;
        eprintln!(
            " {:>6.1} {:>12.6e} {:>12.6e} {:>10} {:>+10.4} {:>10.4}",
            correct_strength,
            point.p_commit_word,
            point.p_correct_word,
            point.top1_token,
            point.routing_head_delta,
            point.total_routing_shift
        );
        sweep.push(point);
    }
    let total_time = t_total.elapsed();
    eprintln!("\n=== Summary ===");
    eprintln!(
        " Baseline P(\"{}\") = {baseline_p_commit:.6e}",
        args.commit_word
    );
    eprintln!(
        " Committed P(\"{}\") = {:.6e}",
        args.commit_word, commitment_point.p_commit_word
    );
    // `sweep` is never empty: the 0..=correct_steps loop pushes at least one
    // point, so this indexing cannot panic.
    eprintln!(
        " At max correction (str={}): P(\"{}\") = {:.6e}, P(\"{}\") = {:.6e}",
        args.max_correct_strength,
        args.commit_word,
        sweep[sweep.len() - 1].p_commit_word,
        args.correct_word,
        sweep[sweep.len() - 1].p_correct_word
    );
    if !args.no_runtime {
        eprintln!(" Total time: {total_time:.2?}");
    }
    let output = JsonOutput {
        model_id: args.model,
        clt_repo: args.clt_repo,
        prompt: prompt.to_owned(),
        planning_site,
        output_position: output_pos,
        n_layers,
        n_heads,
        commit_feature: args.commit_feature,
        commit_word: args.commit_word.to_owned(),
        suppress_features: args.suppress,
        commit_strength: args.commit_strength,
        correct_feature: args.correct_feature,
        correct_word: args.correct_word.to_owned(),
        correct_from_layer: args.correct_from_layer,
        baseline: baseline_point,
        commitment_only: commitment_point,
        correction_sweep: sweep,
        total_time_secs: total_time.as_secs_f64(),
    };
    if let Some(ref path) = args.output {
        write_json(path, &output)?;
    }
    Ok(())
}
/// Run one hooked forward pass and measure its effect.
///
/// Builds injection hooks — suppression (negated strength), commitment at
/// `commit_strength`, and, when `correct_strength > 0`, correction — plus
/// attention capture at every layer, then reports the commit/correct token
/// probabilities, the greedy top-1 token, and how much attention routing
/// from `output_pos` to `planning_site` moved relative to `baseline_attn`.
#[allow(clippy::too_many_arguments)]
fn run_correction_pass(
    model: &MIModel,
    input: &Tensor,
    clt: &CrossLayerTranscoder,
    suppress_entries: &[(CltFeatureId, usize)],
    commit_entries: &[(CltFeatureId, usize)],
    correct_entries: &[(CltFeatureId, usize)],
    n_layers: usize,
    n_heads: usize,
    seq_len: usize,
    planning_site: usize,
    output_pos: usize,
    commit_strength: f32,
    correct_strength: f32,
    commit_token_id: u32,
    correct_token_id: u32,
    baseline_attn: &[Vec<f32>],
    tokenizer: &candle_mi::MITokenizer,
    commit_word: &str,
    correct_word: &str,
    device: &Device,
    verbose: bool,
) -> candle_mi::Result<CorrectionPoint> {
    let t_step = Instant::now();
    // Suppression features are injected with the *negated* commit strength
    // (presumably to cancel their contribution at the planning site).
    let mut hooks = if !suppress_entries.is_empty() {
        clt.prepare_hook_injection(
            suppress_entries,
            planning_site,
            seq_len,
            -commit_strength,
            device,
        )?
    } else {
        HookSpec::new()
    };
    // The commitment injection is always applied.
    let commit_hooks = clt.prepare_hook_injection(
        commit_entries,
        planning_site,
        seq_len,
        commit_strength,
        device,
    )?;
    hooks.extend(&commit_hooks);
    // Correction hooks only when a positive strength was requested.
    if correct_strength > 0.0 && !correct_entries.is_empty() {
        let correct_hooks = clt.prepare_hook_injection(
            correct_entries,
            planning_site,
            seq_len,
            correct_strength,
            device,
        )?;
        hooks.extend(&correct_hooks);
    }
    // Capture every layer's attention pattern for the routing metrics.
    for layer in 0..n_layers {
        hooks.capture(HookPoint::AttnPattern(layer));
    }
    let cache = model.forward(input, &hooks)?;
    let p_commit = extract_token_prob(cache.output(), commit_token_id)?;
    let p_correct = extract_token_prob(cache.output(), correct_token_id)?;
    let (top1_token, top1_prob) = extract_top1(&cache, output_pos, tokenizer)?;
    let attn = extract_attention_weights(&cache, n_layers, n_heads, output_pos, planning_site)?;
    // Attention change at layer 21, head 5 (the "H5 delta" column of the
    // sweep table), guarded for models too small to have that head.
    let routing_head_delta = if n_layers > 21 && n_heads > 5 {
        attn[21][5] - baseline_attn[21][5]
    } else {
        0.0
    };
    // L1 distance between hooked and baseline attention maps over every
    // (layer, head) pair at (output_pos -> planning_site).
    let total_routing_shift: f32 = (0..n_layers)
        .flat_map(|l| (0..n_heads).map(move |h| (l, h)))
        .map(|(l, h)| (attn[l][h] - baseline_attn[l][h]).abs())
        .sum();
    let step_time = t_step.elapsed();
    if verbose {
        eprintln!(
            " P(\"{commit_word}\") = {p_commit:.6e}, P(\"{correct_word}\") = {p_correct:.6e}, \
            top1 = \"{top1_token}\" ({top1_prob:.4})"
        );
    }
    Ok(CorrectionPoint {
        correct_strength,
        p_commit_word: p_commit,
        p_correct_word: p_correct,
        top1_token,
        top1_prob,
        routing_head_delta,
        total_routing_shift,
        time_secs: step_time.as_secs_f64(),
    })
}
/// Extract the attention weight at (query_pos -> key_pos) for every head of
/// every captured layer, returning one f32-per-head vector per layer.
fn extract_attention_weights(
    cache: &candle_mi::HookCache,
    n_layers: usize,
    n_heads: usize,
    query_pos: usize,
    key_pos: usize,
) -> candle_mi::Result<Vec<Vec<f32>>> {
    (0..n_layers)
        .map(|layer| -> candle_mi::Result<Vec<f32>> {
            let pattern = cache.require(&HookPoint::AttnPattern(layer))?;
            // Drop the batch dim, then select the single query row and the
            // single key column, keeping the head axis intact.
            let per_query = pattern.get(0)?.narrow(1, query_pos, 1)?.squeeze(1)?;
            let per_key = per_query.narrow(1, key_pos, 1)?.squeeze(1)?;
            let weights: Vec<f32> = per_key.to_dtype(DType::F32)?.to_vec1()?;
            assert!(
                weights.len() == n_heads,
                "expected {n_heads} heads, got {}",
                weights.len()
            );
            Ok(weights)
        })
        .collect()
}
/// Return the most probable next token at `output_pos` as a decoded string
/// together with its softmax probability.
fn extract_top1(
    cache: &candle_mi::HookCache,
    output_pos: usize,
    tokenizer: &candle_mi::MITokenizer,
) -> candle_mi::Result<(String, f32)> {
    // Select the logit row for the output position (output indexed as
    // [batch, seq, vocab]) and promote it to f32.
    let batch_row = cache.output().get(0)?;
    let position_logits = batch_row
        .narrow(0, output_pos, 1)?
        .squeeze(0)?
        .to_dtype(DType::F32)?;
    // softmax_last_dim needs a leading dim, so wrap and unwrap around it.
    let probs = candle_nn::ops::softmax_last_dim(&position_logits.unsqueeze(0)?)?.squeeze(0)?;
    let vocab_probs: Vec<f32> = probs.to_vec1()?;
    // Argmax over the vocabulary; incomparable values (NaN) are treated as
    // ties, matching the original `max_by` semantics exactly.
    let best = vocab_probs
        .iter()
        .enumerate()
        .max_by(|(_, a): &(usize, &f32), (_, b): &(usize, &f32)| {
            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
        });
    let (top_idx, &top_prob) =
        best.ok_or_else(|| candle_mi::MIError::Config("empty probability vector".into()))?;
    let token_str = tokenizer.decode(&[top_idx as u32])?;
    Ok((token_str, top_prob))
}
/// Locate the first prompt token whose trimmed decoding equals `word`,
/// print the detected position, and return its index. Errors when the word
/// never appears in the tokenized prompt.
fn find_planning_site(
    tokenizer: &candle_mi::MITokenizer,
    tokens: &[u32],
    word: &str,
) -> candle_mi::Result<usize> {
    let mut i = 0usize;
    for &tid in tokens {
        let decoded = tokenizer.decode(&[tid])?;
        let trimmed = decoded.trim();
        if trimmed == word {
            eprintln!(" Auto-detected planning site: position {i} (token \"{trimmed}\")");
            return Ok(i);
        }
        i += 1;
    }
    Err(candle_mi::MIError::Tokenizer(format!(
        "could not find \"{word}\" in tokenized prompt"
    )))
}
/// Parse a CLT feature spec of the form `L<layer>:<index>` (e.g. `L22:10243`)
/// into a `CltFeatureId`.
///
/// # Errors
/// Returns a config error when the leading `L`, the `:` separator, or either
/// numeric component is missing or malformed.
fn parse_clt_feature(s: &str) -> candle_mi::Result<CltFeatureId> {
    let s = s.trim();
    // `strip_prefix` both validates the leading 'L' and yields the remainder,
    // replacing the starts_with check + manual `&s[1..]` byte slice.
    let rest = s.strip_prefix('L').ok_or_else(|| {
        candle_mi::MIError::Config(format!("CLT feature must start with 'L', got \"{s}\""))
    })?;
    // `split_once` yields both halves directly, avoiding the intermediate Vec
    // that `splitn(2, ':').collect()` allocated just to check arity.
    let (layer_str, index_str) = rest.split_once(':').ok_or_else(|| {
        candle_mi::MIError::Config(format!("CLT feature must be \"L<layer>:<index>\", got \"{s}\""))
    })?;
    let layer: usize = layer_str
        .parse()
        .map_err(|_| candle_mi::MIError::Config(format!("invalid layer number in \"{s}\"")))?;
    let index: usize = index_str
        .parse()
        .map_err(|_| candle_mi::MIError::Config(format!("invalid feature index in \"{s}\"")))?;
    Ok(CltFeatureId { layer, index })
}
/// Serialize `output` as pretty-printed JSON to `path`, creating any missing
/// parent directories first, and announce the destination on stderr.
fn write_json(path: &std::path::Path, output: &JsonOutput) -> candle_mi::Result<()> {
    let json = serde_json::to_string_pretty(output)
        .map_err(|e| candle_mi::MIError::Config(format!("JSON serialization failed: {e}")))?;
    // Ensure the destination directory exists before writing.
    if let Some(dir) = path.parent() {
        std::fs::create_dir_all(dir).map_err(|e| {
            candle_mi::MIError::Config(format!("failed to create {}: {e}", dir.display()))
        })?;
    }
    std::fs::write(path, json.as_bytes()).map_err(|e| {
        candle_mi::MIError::Config(format!("failed to write {}: {e}", path.display()))
    })?;
    eprintln!("\nOutput written to {}", path.display());
    Ok(())
}