use candle_core::{DType, Device, Tensor};
use crate::backend::MIModel;
use crate::error::{MIError, Result};
use crate::hooks::{HookPoint, HookSpec, Intervention};
use crate::tokenizer::MITokenizer;
/// Rule for choosing which token position's activation is read from a prompt.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PositionStrategy {
    /// The final token of the prompt.
    Last,
    /// The first token whose id equals the supplied newline token id.
    FirstNewline,
    /// A caller-chosen index; it must be smaller than the prompt length.
    Explicit(usize),
}

impl PositionStrategy {
    /// Turns this strategy into a concrete index into `tokens`.
    ///
    /// `newline_token_id` is consulted only by [`PositionStrategy::FirstNewline`];
    /// the other strategies ignore it.
    ///
    /// # Errors
    ///
    /// Returns `MIError::Config` when `tokens` is empty, when `FirstNewline`
    /// finds no token equal to `newline_token_id`, or when an `Explicit`
    /// index is not below `tokens.len()`.
    pub fn resolve(self, tokens: &[u32], newline_token_id: u32) -> Result<usize> {
        let len = tokens.len();
        // `checked_sub` doubles as the emptiness guard and yields the last index.
        let last = match len.checked_sub(1) {
            Some(idx) => idx,
            None => {
                return Err(MIError::Config(
                    "PositionStrategy::resolve called on empty token sequence".into(),
                ))
            }
        };
        match self {
            Self::Last => Ok(last),
            Self::FirstNewline => match tokens.iter().position(|&id| id == newline_token_id) {
                Some(idx) => Ok(idx),
                None => Err(MIError::Config(format!(
                    "PositionStrategy::FirstNewline: no newline token (id={newline_token_id}) \
                     found in {len}-token prompt"
                ))),
            },
            Self::Explicit(pos) if pos < len => Ok(pos),
            Self::Explicit(pos) => Err(MIError::Config(format!(
                "PositionStrategy::Explicit({pos}) out of range for {len}-token prompt"
            ))),
        }
    }
}
/// A steering direction obtained by contrasting residual-stream activations
/// of two prompt sets at a single layer.
///
/// Instances are produced by `build_contrastive_direction` and turned into a
/// hook payload by `contrastive_intervention`.
#[derive(Debug, Clone)]
pub struct ContrastiveDirection {
    /// Layer index whose `HookPoint::ResidPost` activation the direction was built from.
    pub layer: usize,
    /// Mean positive residual minus mean negative residual. When built by
    /// `build_contrastive_direction` this is a 1-D `F32` tensor, optionally
    /// L2-normalised.
    pub vector: Tensor,
    /// Whether L2 normalisation was requested. Note: if the difference had an
    /// L2 norm below `1e-12`, the vector is kept unscaled even when this is `true`.
    pub is_normalised: bool,
    /// Number of positive prompts that contributed to the mean.
    pub n_positive: usize,
    /// Number of negative prompts that contributed to the mean.
    pub n_negative: usize,
    /// The strategy used to pick the token position within each prompt.
    pub position_strategy: PositionStrategy,
}
pub fn build_contrastive_direction(
model: &MIModel,
tokenizer: &MITokenizer,
positive: &[&str],
negative: &[&str],
layer: usize,
strategy: PositionStrategy,
normalise: bool,
) -> Result<ContrastiveDirection> {
if positive.is_empty() {
return Err(MIError::Config(
"build_contrastive_direction: positive prompt set is empty".into(),
));
}
if negative.is_empty() {
return Err(MIError::Config(
"build_contrastive_direction: negative prompt set is empty".into(),
));
}
let n_layers = model.num_layers();
if layer >= n_layers {
return Err(MIError::Config(format!(
"build_contrastive_direction: layer {layer} >= num_layers {n_layers}"
)));
}
let newline_token_id = resolve_newline_token_id(tokenizer)?;
let hook = HookPoint::ResidPost(layer);
let pos_residuals = capture_per_prompt(
model,
tokenizer,
positive,
&hook,
strategy,
newline_token_id,
)?;
let neg_residuals = capture_per_prompt(
model,
tokenizer,
negative,
&hook,
strategy,
newline_token_id,
)?;
let vector = compute_direction(&pos_residuals, &neg_residuals, normalise, model.device())?;
Ok(ContrastiveDirection {
layer,
vector,
is_normalised: normalise,
n_positive: positive.len(),
n_negative: negative.len(),
position_strategy: strategy,
})
}
/// Wraps a contrastive direction, scaled by `strength`, as an additive intervention.
///
/// The payload is `direction.vector * strength`. It has the same shape and dtype
/// as the stored vector. How it broadcasts onto the hooked activation is decided
/// by whoever applies the `Intervention` (not visible here).
///
/// # Errors
///
/// Propagates any tensor error raised while scaling.
#[must_use = "the intervention must be registered on a HookSpec to take effect"]
pub fn contrastive_intervention(
    direction: &ContrastiveDirection,
    strength: f32,
) -> Result<Intervention> {
    // `x * s` on a tensor is an affine map with multiplier `s` and offset 0.
    let payload = direction.vector.affine(f64::from(strength), 0.0)?;
    Ok(Intervention::Add(payload))
}
/// Builds a `[1, seq_len, hidden]` tensor that carries `direction` at row
/// `position` and zeros at every other sequence position.
///
/// The zero rows come from `direction.zeros_like()`, so they share
/// `direction`'s dtype and device.
///
/// # Errors
///
/// Returns `MIError::Config` if `position >= seq_len` (which includes
/// `seq_len == 0`), or if `direction` is not 1-D. Tensor errors from
/// stacking/reshaping are propagated.
pub fn position_delta(direction: &Tensor, position: usize, seq_len: usize) -> Result<Tensor> {
    if position >= seq_len {
        return Err(MIError::Config(format!(
            "position_delta: position {position} >= seq_len {seq_len}"
        )));
    }
    let dims = direction.dims();
    if dims.len() != 1 {
        return Err(MIError::Config(format!(
            "position_delta: direction must be 1-D [hidden]; got shape {dims:?}"
        )));
    }
    // One zero row is shared by reference for every non-target slot.
    let zero_row = direction.zeros_like()?;
    let rows: Vec<&Tensor> = (0..seq_len)
        .map(|i| if i == position { direction } else { &zero_row })
        .collect();
    // [seq_len, hidden] -> [1, seq_len, hidden]
    let with_batch = Tensor::stack(&rows, 0)?.unsqueeze(0)?;
    Ok(with_batch)
}
/// Looks up the single token id that the tokenizer assigns to `"\n"`.
///
/// # Errors
///
/// Propagates errors from `encode_raw`. Returns `MIError::Tokenizer` when
/// `"\n"` encodes to zero tokens or to more than one token.
fn resolve_newline_token_id(tokenizer: &MITokenizer) -> Result<u32> {
    let encoded = tokenizer.encode_raw("\n")?;
    if let [single] = encoded.as_slice() {
        return Ok(*single);
    }
    Err(MIError::Tokenizer(format!(
        "resolve_newline_token_id: '\\n' encodes to {count} tokens, not 1; \
         PositionStrategy::FirstNewline is not usable with this tokenizer",
        count = encoded.len()
    )))
}
/// Runs each prompt through `model` and returns the activation at `hook` for
/// the token position chosen by `strategy`, one `F32` tensor per prompt.
///
/// Each prompt is encoded with `tokenizer.encode` and fed as a batch of one.
/// The captured activation has its leading (batch) axis squeezed and is then
/// indexed at the resolved position.
///
/// # Errors
///
/// - `MIError::Tokenizer` if encoding a prompt fails.
/// - `MIError::Config` if a prompt encodes to zero tokens or its position
///   cannot be resolved.
/// - Forward-pass, cache-lookup and tensor errors are propagated.
fn capture_per_prompt(
    model: &MIModel,
    tokenizer: &MITokenizer,
    prompts: &[&str],
    hook: &HookPoint,
    strategy: PositionStrategy,
    newline_token_id: u32,
) -> Result<Vec<Tensor>> {
    // The capture request is the same for every prompt, so build it once
    // instead of once per loop iteration.
    let mut hooks = HookSpec::new();
    hooks.capture(hook.clone());

    let mut residuals: Vec<Tensor> = Vec::with_capacity(prompts.len());
    for (i, prompt) in prompts.iter().enumerate() {
        let tokens = tokenizer.encode(prompt).map_err(|e| {
            MIError::Tokenizer(format!(
                "capture_per_prompt: prompt #{i} encode failed: {e}"
            ))
        })?;
        if tokens.is_empty() {
            return Err(MIError::Config(format!(
                "capture_per_prompt: prompt #{i} encoded to zero tokens"
            )));
        }
        let position = strategy
            .resolve(&tokens, newline_token_id)
            .map_err(|e| MIError::Config(format!("capture_per_prompt: prompt #{i}: {e}")))?;
        // Batch of one: [seq] -> [1, seq].
        let input = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
        let cache = model.forward(&input, &hooks)?;
        let residual_3d = cache.require(hook)?;
        // Drop the batch axis, then select the row at the resolved position.
        let residual_2d = residual_3d.squeeze(0)?;
        let residual_1d = residual_2d.get(position)?;
        let residual_f32 = residual_1d.to_dtype(DType::F32)?;
        residuals.push(residual_f32);
    }
    Ok(residuals)
}
/// Computes `mean(positive) - mean(negative)` and optionally L2-normalises it.
///
/// `_device` is accepted but not used.
///
/// # Errors
///
/// Returns `MIError::Config` if either slice is empty. Tensor errors from
/// stacking, averaging, subtraction or normalisation are propagated.
fn compute_direction(
    positive: &[Tensor],
    negative: &[Tensor],
    normalise: bool,
    _device: &Device,
) -> Result<Tensor> {
    let either_empty = positive.is_empty() || negative.is_empty();
    if either_empty {
        return Err(MIError::Config(
            "compute_direction: positive and negative sets must both be non-empty".into(),
        ));
    }
    // Left-to-right evaluation: the positive mean is computed first, as before.
    let centroid_gap = (&stack_and_mean(positive)? - &stack_and_mean(negative)?)?;
    if !normalise {
        return Ok(centroid_gap);
    }
    l2_normalise(&centroid_gap)
}
/// Averages `tensors` element-wise by stacking them on a new leading axis and
/// taking the mean over that axis.
///
/// # Errors
///
/// Propagates candle's error for invalid input (e.g. an empty slice or
/// mismatched shapes).
fn stack_and_mean(tensors: &[Tensor]) -> Result<Tensor> {
    let averaged = Tensor::stack(tensors, 0)?.mean(0)?;
    Ok(averaged)
}
/// Scales `v` to unit L2 norm, returning it unchanged when the norm is below `1e-12`.
///
/// The sum of squares is reduced in `v`'s own dtype. It is then converted to
/// `f64` for the square root and the threshold comparison.
fn l2_normalise(v: &Tensor) -> Result<Tensor> {
    let magnitude = (v * v)?
        .sum_all()?
        .to_dtype(DType::F64)?
        .to_scalar::<f64>()?
        .sqrt();
    // A (near-)zero vector has no meaningful direction; avoid dividing by ~0.
    let normalised = if magnitude < 1e-12_f64 {
        v.clone()
    } else {
        (v / magnitude)?
    };
    Ok(normalised)
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    /// All tests run on the CPU backend.
    fn cpu() -> Device {
        Device::Cpu
    }

    // --- PositionStrategy::resolve -------------------------------------------

    #[test]
    fn position_strategy_last_returns_seq_len_minus_one() {
        let tokens = vec![1_u32, 2, 3, 4, 5];
        let pos = PositionStrategy::Last.resolve(&tokens, 0).unwrap();
        assert_eq!(pos, 4);
    }

    #[test]
    fn position_strategy_last_on_single_token_is_zero() {
        let pos = PositionStrategy::Last.resolve(&[42_u32], 0).unwrap();
        assert_eq!(pos, 0);
    }

    #[test]
    fn position_strategy_first_newline_finds_correct_index() {
        // Two newline tokens: the first one (index 2) must win.
        let tokens = vec![1_u32, 2, 198, 3, 198];
        let pos = PositionStrategy::FirstNewline
            .resolve(&tokens, 198)
            .unwrap();
        assert_eq!(pos, 2);
    }

    #[test]
    fn position_strategy_first_newline_errors_when_absent() {
        let tokens = vec![1_u32, 2, 3, 4, 5];
        let err = PositionStrategy::FirstNewline.resolve(&tokens, 198);
        assert!(err.is_err());
    }

    #[test]
    fn position_strategy_explicit_accepts_boundary_indices() {
        let tokens = vec![1_u32, 2, 3];
        assert_eq!(PositionStrategy::Explicit(0).resolve(&tokens, 0).unwrap(), 0);
        assert_eq!(PositionStrategy::Explicit(2).resolve(&tokens, 0).unwrap(), 2);
    }

    #[test]
    fn position_strategy_explicit_errors_on_out_of_range() {
        let tokens = vec![1_u32, 2, 3];
        let err = PositionStrategy::Explicit(5).resolve(&tokens, 0);
        assert!(err.is_err());
        // Exactly `len` is already out of range.
        assert!(PositionStrategy::Explicit(3).resolve(&tokens, 0).is_err());
    }

    #[test]
    fn position_strategy_resolve_errors_on_empty_tokens() {
        let tokens: Vec<u32> = vec![];
        let err = PositionStrategy::Last.resolve(&tokens, 0);
        assert!(err.is_err());
    }

    // --- compute_direction ----------------------------------------------------

    #[test]
    fn compute_direction_with_identical_sets_is_near_zero() {
        let device = cpu();
        let row = Tensor::new(&[1.0_f32, 2.0, 3.0, 4.0], &device).unwrap();
        let positive = vec![row.clone(), row.clone()];
        let negative = vec![row.clone(), row.clone()];
        let direction = compute_direction(&positive, &negative, false, &device).unwrap();
        let norm_sq = (&direction * &direction).unwrap().sum_all().unwrap();
        let norm_sq_scalar = norm_sq.to_scalar::<f32>().unwrap();
        assert!(norm_sq_scalar < 1e-10, "norm_sq = {norm_sq_scalar}");
    }

    #[test]
    fn compute_direction_with_disjoint_means() {
        let device = cpu();
        let pos_row = Tensor::new(&[2.0_f32, 4.0, 6.0], &device).unwrap();
        let neg_row = Tensor::new(&[1.0_f32, 2.0, 3.0], &device).unwrap();
        let positive = vec![pos_row.clone(), pos_row.clone()];
        let negative = vec![neg_row.clone(), neg_row.clone()];
        let direction = compute_direction(&positive, &negative, false, &device).unwrap();
        let values: Vec<f32> = direction.to_vec1().unwrap();
        assert!((values[0] - 1.0).abs() < 1e-6);
        assert!((values[1] - 2.0).abs() < 1e-6);
        assert!((values[2] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn normalised_direction_has_unit_l2_norm() {
        let device = cpu();
        // 3-4-5 triangle: the raw difference has norm 5.
        let pos_row = Tensor::new(&[3.0_f32, 0.0, 4.0], &device).unwrap();
        let neg_row = Tensor::new(&[0.0_f32, 0.0, 0.0], &device).unwrap();
        let positive = vec![pos_row];
        let negative = vec![neg_row];
        let direction = compute_direction(&positive, &negative, true, &device).unwrap();
        let norm_sq = (&direction * &direction).unwrap().sum_all().unwrap();
        let norm = norm_sq.to_scalar::<f32>().unwrap().sqrt();
        assert!((norm - 1.0_f32).abs() < 1e-6, "norm = {norm}");
    }

    #[test]
    fn normalising_zero_direction_returns_zero_vector() {
        // A zero difference must not be divided by its (zero) norm.
        let device = cpu();
        let row = Tensor::new(&[1.0_f32, 2.0, 3.0], &device).unwrap();
        let direction = compute_direction(&[row.clone()], &[row], true, &device).unwrap();
        let values: Vec<f32> = direction.to_vec1().unwrap();
        assert!(
            values.iter().all(|x| x.is_finite() && x.abs() < 1e-12),
            "values = {values:?}"
        );
    }

    // --- contrastive_intervention --------------------------------------------

    #[test]
    fn contrastive_intervention_payload_scales_with_strength() {
        let device = cpu();
        let direction = ContrastiveDirection {
            layer: 5,
            vector: Tensor::new(&[1.0_f32, 0.0, 0.0], &device).unwrap(),
            is_normalised: true,
            n_positive: 1,
            n_negative: 1,
            position_strategy: PositionStrategy::Last,
        };
        let intervention = contrastive_intervention(&direction, 2.5_f32).unwrap();
        match intervention {
            Intervention::Add(payload) => {
                let values: Vec<f32> = payload.to_vec1().unwrap();
                assert!(
                    (values[0] - 2.5_f32).abs() < 1e-6,
                    "values[0] = {}",
                    values[0]
                );
                assert!(values[1].abs() < 1e-6);
                assert!(values[2].abs() < 1e-6);
            }
            other => panic!("expected Intervention::Add, got {other:?}"),
        }
    }

    // --- position_delta -------------------------------------------------------

    #[test]
    fn position_delta_places_vector_at_correct_index() {
        let device = cpu();
        let direction = Tensor::new(&[7.0_f32, 8.0, 9.0], &device).unwrap();
        let delta = position_delta(&direction, 2, 4).unwrap();
        assert_eq!(delta.dims(), &[1, 4, 3]);
        let squeezed = delta.squeeze(0).unwrap();
        let row0: Vec<f32> = squeezed.get(0).unwrap().to_vec1().unwrap();
        let row2: Vec<f32> = squeezed.get(2).unwrap().to_vec1().unwrap();
        let row3: Vec<f32> = squeezed.get(3).unwrap().to_vec1().unwrap();
        assert!(row0.iter().all(|&x| x.abs() < 1e-6));
        assert!((row2[0] - 7.0).abs() < 1e-6);
        assert!((row2[1] - 8.0).abs() < 1e-6);
        assert!((row2[2] - 9.0).abs() < 1e-6);
        assert!(row3.iter().all(|&x| x.abs() < 1e-6));
    }

    #[test]
    fn position_delta_handles_first_position_and_single_row() {
        let device = cpu();
        let direction = Tensor::new(&[5.0_f32, 6.0], &device).unwrap();
        let delta = position_delta(&direction, 0, 1).unwrap();
        assert_eq!(delta.dims(), &[1, 1, 2]);
        let row0: Vec<f32> = delta
            .squeeze(0)
            .unwrap()
            .get(0)
            .unwrap()
            .to_vec1()
            .unwrap();
        assert!((row0[0] - 5.0).abs() < 1e-6);
        assert!((row0[1] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn position_delta_errors_on_out_of_range() {
        let device = cpu();
        let direction = Tensor::new(&[1.0_f32, 2.0], &device).unwrap();
        let err = position_delta(&direction, 5, 4);
        assert!(err.is_err());
        // `position == seq_len` is the first out-of-range value.
        assert!(position_delta(&direction, 4, 4).is_err());
    }

    #[test]
    fn position_delta_errors_on_non_1d_direction() {
        let device = cpu();
        let direction = Tensor::new(&[[1.0_f32, 2.0], [3.0, 4.0]], &device).unwrap();
        let err = position_delta(&direction, 0, 4);
        assert!(err.is_err());
    }
}