// candle_mi/interp/intervention.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Attention intervention for causal experiments.
4//!
5//! Enables causal intervention experiments by surgically modifying
6//! attention edges and measuring impact on model outputs.
7//!
8//! ## Intervention Types
9//!
10//! - **Knockout**: Remove attention edges (pre-softmax, add `-inf`)
11//! - **Scale**: Multiply attention by a factor (post-softmax, then renormalize)
12//! - **`SetValue`**: Set attention to a specific value (post-softmax, then renormalize)
13//!
14//! ## Intervention Mechanism
15//!
16//! Knockout is implemented by adding negative infinity to specified attention
17//! scores BEFORE softmax. After softmax, these edges become exactly 0,
18//! completely removing their contribution to the output.
19//!
20//! Steering (Scale/SetValue) is applied AFTER softmax, modifying attention
21//! weights and renormalizing rows to maintain valid probability distributions.
22
23use std::collections::{HashMap, HashSet};
24
25use candle_core::{DType, Device, Tensor};
26
27use crate::error::{MIError, Result};
28
29// ---------------------------------------------------------------------------
30// Internal helpers
31// ---------------------------------------------------------------------------
32
33/// Extract a 4D tensor to nested `Vec`s.
34///
35/// Candle doesn't provide `to_vec4()`, so we flatten and reshape manually.
36///
37/// # Shapes
38///
39/// - `tensor`: `[d0, d1, d2, d3]`
40///
41/// # Errors
42///
43/// Returns [`MIError::Intervention`] if the tensor is not 4D.
44fn tensor_to_vec4(tensor: &Tensor) -> Result<Vec<Vec<Vec<Vec<f32>>>>> {
45    let shape = tensor.dims();
46    if shape.len() != 4 {
47        return Err(MIError::Intervention(format!(
48            "expected 4D tensor, got {}D",
49            shape.len()
50        )));
51    }
52    let s0 = shape.first().copied().unwrap_or(0);
53    let s1 = shape.get(1).copied().unwrap_or(0);
54    let s2 = shape.get(2).copied().unwrap_or(0);
55    let s3 = shape.get(3).copied().unwrap_or(0);
56
57    let flat: Vec<f32> = tensor.flatten_all()?.to_vec1()?;
58
59    let mut result = Vec::with_capacity(s0);
60    let mut iter = flat.into_iter();
61    for _ in 0..s0 {
62        let mut axis1 = Vec::with_capacity(s1);
63        for _ in 0..s1 {
64            let mut axis2 = Vec::with_capacity(s2);
65            for _ in 0..s2 {
66                let row: Vec<f32> = iter.by_ref().take(s3).collect();
67                axis2.push(row);
68            }
69            axis1.push(axis2);
70        }
71        result.push(axis1);
72    }
73
74    Ok(result)
75}
76
77/// Convert logits to a probability distribution (softmax).
78fn softmax_to_vec(logits: &Tensor) -> Result<Vec<f32>> {
79    // PROMOTE: softmax needs f32 for numerical stability
80    let logits_f32 = logits.to_dtype(DType::F32)?;
81    let probs = candle_nn::ops::softmax_last_dim(&logits_f32)?;
82    Ok(probs.flatten_all()?.to_vec1()?)
83}
84
85/// Expand edge specifications, resolving sentinel values (`usize::MAX`).
86///
87/// - `(from, usize::MAX)` → all edges FROM `from` to every position
88/// - `(usize::MAX, to)` → all edges TO `to` from every position
89fn expand_edges(edges: &[AttentionEdge], seq_len: usize) -> Vec<AttentionEdge> {
90    let mut expanded = Vec::new();
91
92    for edge in edges {
93        match (edge.from_pos, edge.to_pos) {
94            (from, usize::MAX) if from != usize::MAX => {
95                for to in 0..seq_len {
96                    expanded.push(AttentionEdge::new(from, to));
97                }
98            }
99            (usize::MAX, to) if to != usize::MAX => {
100                for from in 0..seq_len {
101                    expanded.push(AttentionEdge::new(from, to));
102                }
103            }
104            (from, to) if from != usize::MAX && to != usize::MAX => {
105                expanded.push(*edge);
106            }
107            _ => {} // Invalid sentinel combination, skip
108        }
109    }
110
111    expanded
112}
113
114// ===========================================================================
115// Part 1: Knockout Specification
116// ===========================================================================
117
/// Specification for which layers an intervention targets.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum LayerSpec {
    /// Apply to all layers.
    All,
    /// Apply only to the listed layer indices.
    Specific(Vec<usize>),
    /// Apply to a contiguous range of layers (both endpoints inclusive).
    Range {
        /// First layer (inclusive).
        start: usize,
        /// Last layer (inclusive).
        end: usize,
    },
}
134
/// Specification for which attention heads an intervention targets.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum HeadSpec {
    /// Apply to all heads.
    All,
    /// Apply only to the listed head indices.
    Specific(Vec<usize>),
}
144
/// A single attention edge from one position to another.
///
/// `usize::MAX` is a sentinel meaning "all positions"; it is expanded
/// at mask-creation time once the actual sequence length is known.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AttentionEdge {
    /// Token position that is attending (row in the attention matrix).
    pub from_pos: usize,
    /// Token position being attended to (column in the attention matrix).
    pub to_pos: usize,
}
156
impl AttentionEdge {
    /// Create an edge from `from_pos` to `to_pos` (either may be the
    /// `usize::MAX` "all positions" sentinel).
    #[must_use]
    pub const fn new(from_pos: usize, to_pos: usize) -> Self {
        Self { from_pos, to_pos }
    }
}
164
/// Specification for which attention edges to knock out.
///
/// An "edge" is attention from one token position to another.
/// Knockout removes the edge completely by setting its attention weight
/// to 0 (via pre-softmax `-inf` masking; see [`create_knockout_mask`]).
///
/// # Example
///
/// ```
/// use candle_mi::KnockoutSpec;
///
/// let spec = KnockoutSpec::new()
///     .layer(10)
///     .from_to_positions(5, &[0, 1, 2, 3]);
/// assert_eq!(spec.edges.len(), 4);
/// ```
#[derive(Debug, Clone)]
#[must_use]
pub struct KnockoutSpec {
    /// Layer indices to apply intervention.
    pub layers: LayerSpec,
    /// Head indices to apply intervention.
    pub heads: HeadSpec,
    /// Attention edges to knock out: `(from_position, to_position)`.
    pub edges: Vec<AttentionEdge>,
}
191
192impl KnockoutSpec {
193    /// Create a new empty knockout specification (all layers, all heads).
194    pub const fn new() -> Self {
195        Self {
196            layers: LayerSpec::All,
197            heads: HeadSpec::All,
198            edges: Vec::new(),
199        }
200    }
201
202    /// Target a single layer.
203    pub fn layer(mut self, layer: usize) -> Self {
204        self.layers = LayerSpec::Specific(vec![layer]);
205        self
206    }
207
208    /// Target multiple specific layers.
209    pub fn layers(mut self, layers: &[usize]) -> Self {
210        self.layers = LayerSpec::Specific(layers.to_vec());
211        self
212    }
213
214    /// Target a range of layers (inclusive).
215    pub fn layer_range(mut self, start: usize, end: usize) -> Self {
216        self.layers = LayerSpec::Range { start, end };
217        self
218    }
219
220    /// Target a single head.
221    pub fn head(mut self, head: usize) -> Self {
222        self.heads = HeadSpec::Specific(vec![head]);
223        self
224    }
225
226    /// Target multiple specific heads.
227    pub fn heads(mut self, heads: &[usize]) -> Self {
228        self.heads = HeadSpec::Specific(heads.to_vec());
229        self
230    }
231
232    /// Add a single edge to knock out.
233    pub fn edge(mut self, from_pos: usize, to_pos: usize) -> Self {
234        self.edges.push(AttentionEdge::new(from_pos, to_pos));
235        self
236    }
237
238    /// Knock out all attention FROM a specific position.
239    pub fn from_position(mut self, from_pos: usize) -> Self {
240        self.edges.push(AttentionEdge::new(from_pos, usize::MAX));
241        self
242    }
243
244    /// Knock out all attention TO a specific position.
245    pub fn to_position(mut self, to_pos: usize) -> Self {
246        self.edges.push(AttentionEdge::new(usize::MAX, to_pos));
247        self
248    }
249
250    /// Add edges from one position to several positions.
251    pub fn from_to_positions(mut self, from_pos: usize, to_positions: &[usize]) -> Self {
252        for &to_pos in to_positions {
253            self.edges.push(AttentionEdge::new(from_pos, to_pos));
254        }
255        self
256    }
257
258    /// Check if this layer should have intervention applied.
259    #[must_use]
260    pub fn applies_to_layer(&self, layer: usize) -> bool {
261        match &self.layers {
262            LayerSpec::All => true,
263            LayerSpec::Specific(layers) => layers.contains(&layer),
264            LayerSpec::Range { start, end } => layer >= *start && layer <= *end,
265        }
266    }
267
268    /// Check if this head should have intervention applied.
269    #[must_use]
270    pub fn applies_to_head(&self, head: usize) -> bool {
271        match &self.heads {
272            HeadSpec::All => true,
273            HeadSpec::Specific(heads) => heads.contains(&head),
274        }
275    }
276
277    /// Validate the spec against model dimensions.
278    ///
279    /// # Errors
280    ///
281    /// Returns [`MIError::Intervention`] if any layer, head, or edge
282    /// position is out of range.
283    pub fn validate(&self, n_layers: usize, n_heads: usize, seq_len: usize) -> Result<()> {
284        validate_layers(&self.layers, n_layers)?;
285        validate_heads(&self.heads, n_heads)?;
286        validate_edges(&self.edges, seq_len)?;
287        Ok(())
288    }
289}
290
impl Default for KnockoutSpec {
    /// Equivalent to [`KnockoutSpec::new`]: all layers, all heads, no edges.
    fn default() -> Self {
        Self::new()
    }
}
296
297// ===========================================================================
298// Part 2: Steering Specification
299// ===========================================================================
300
/// Type of intervention to apply to attention weights.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum InterventionType {
    /// Set attention to zero (pre-softmax: add `-inf` to the score).
    #[default]
    Knockout,
    /// Multiply attention by a factor (post-softmax, then renormalize the row).
    Scale(f32),
    /// Set attention to a specific target value (post-softmax, then renormalize the row).
    SetValue(f32),
}
313
/// Specification for attention steering interventions.
///
/// Unlike knockout, which removes edges, steering modifies attention weights
/// by scaling or setting values, then renormalizing affected rows to
/// maintain valid probability distributions.
///
/// # Example
///
/// ```
/// use candle_mi::SteeringSpec;
///
/// let spec = SteeringSpec::scale(3.0)
///     .layer(16)
///     .from_to_positions(5, &[0, 1, 2]);
/// assert_eq!(spec.edges.len(), 3);
/// ```
#[derive(Debug, Clone)]
#[must_use]
pub struct SteeringSpec {
    /// Layer indices to apply intervention.
    pub layers: LayerSpec,
    /// Head indices to apply intervention.
    pub heads: HeadSpec,
    /// Attention edges to modify.
    pub edges: Vec<AttentionEdge>,
    /// Type of intervention to apply.
    pub intervention_type: InterventionType,
}
342
343impl SteeringSpec {
344    /// Create a new steering specification with the given intervention type.
345    pub const fn new(intervention_type: InterventionType) -> Self {
346        Self {
347            layers: LayerSpec::All,
348            heads: HeadSpec::All,
349            edges: Vec::new(),
350            intervention_type,
351        }
352    }
353
354    /// Create a scaling intervention.
355    pub const fn scale(factor: f32) -> Self {
356        Self::new(InterventionType::Scale(factor))
357    }
358
359    /// Create a set-value intervention.
360    pub const fn set_value(target: f32) -> Self {
361        Self::new(InterventionType::SetValue(target))
362    }
363
364    /// Target a single layer.
365    pub fn layer(mut self, layer: usize) -> Self {
366        self.layers = LayerSpec::Specific(vec![layer]);
367        self
368    }
369
370    /// Target multiple specific layers.
371    pub fn layers(mut self, layers: &[usize]) -> Self {
372        self.layers = LayerSpec::Specific(layers.to_vec());
373        self
374    }
375
376    /// Target a range of layers (inclusive).
377    pub fn layer_range(mut self, start: usize, end: usize) -> Self {
378        self.layers = LayerSpec::Range { start, end };
379        self
380    }
381
382    /// Target a single head.
383    pub fn head(mut self, head: usize) -> Self {
384        self.heads = HeadSpec::Specific(vec![head]);
385        self
386    }
387
388    /// Target multiple specific heads.
389    pub fn heads(mut self, heads: &[usize]) -> Self {
390        self.heads = HeadSpec::Specific(heads.to_vec());
391        self
392    }
393
394    /// Add a single edge to modify.
395    pub fn edge(mut self, from_pos: usize, to_pos: usize) -> Self {
396        self.edges.push(AttentionEdge::new(from_pos, to_pos));
397        self
398    }
399
400    /// Steer all attention FROM a specific position.
401    pub fn from_position(mut self, from_pos: usize) -> Self {
402        self.edges.push(AttentionEdge::new(from_pos, usize::MAX));
403        self
404    }
405
406    /// Steer all attention TO a specific position.
407    pub fn to_position(mut self, to_pos: usize) -> Self {
408        self.edges.push(AttentionEdge::new(usize::MAX, to_pos));
409        self
410    }
411
412    /// Add edges from one position to several positions.
413    pub fn from_to_positions(mut self, from_pos: usize, to_positions: &[usize]) -> Self {
414        for &to_pos in to_positions {
415            self.edges.push(AttentionEdge::new(from_pos, to_pos));
416        }
417        self
418    }
419
420    /// Check if this layer should have intervention applied.
421    #[must_use]
422    pub fn applies_to_layer(&self, layer: usize) -> bool {
423        match &self.layers {
424            LayerSpec::All => true,
425            LayerSpec::Specific(layers) => layers.contains(&layer),
426            LayerSpec::Range { start, end } => layer >= *start && layer <= *end,
427        }
428    }
429
430    /// Check if this head should have intervention applied.
431    #[must_use]
432    pub fn applies_to_head(&self, head: usize) -> bool {
433        match &self.heads {
434            HeadSpec::All => true,
435            HeadSpec::Specific(heads) => heads.contains(&head),
436        }
437    }
438
439    /// Validate the spec against model dimensions.
440    ///
441    /// # Errors
442    ///
443    /// Returns [`MIError::Intervention`] if any layer, head, edge, or
444    /// intervention parameter is out of range.
445    pub fn validate(&self, n_layers: usize, n_heads: usize, seq_len: usize) -> Result<()> {
446        validate_layers(&self.layers, n_layers)?;
447        validate_heads(&self.heads, n_heads)?;
448        validate_edges(&self.edges, seq_len)?;
449
450        match self.intervention_type {
451            InterventionType::Scale(factor) => {
452                if factor < 0.0 {
453                    return Err(MIError::Intervention(format!(
454                        "scale factor must be non-negative, got {factor}"
455                    )));
456                }
457            }
458            InterventionType::SetValue(value) => {
459                if !(0.0..=1.0).contains(&value) {
460                    return Err(MIError::Intervention(format!(
461                        "set value must be in [0, 1], got {value}"
462                    )));
463                }
464            }
465            InterventionType::Knockout => {}
466        }
467
468        Ok(())
469    }
470
471    /// Get the intervention type.
472    #[must_use]
473    pub const fn intervention_type(&self) -> InterventionType {
474        self.intervention_type
475    }
476
477    /// Check if this is a knockout intervention.
478    #[must_use]
479    pub const fn is_knockout(&self) -> bool {
480        matches!(self.intervention_type, InterventionType::Knockout)
481    }
482
483    /// Check if this is a post-softmax steering intervention.
484    #[must_use]
485    pub const fn is_steering(&self) -> bool {
486        matches!(
487            self.intervention_type,
488            InterventionType::Scale(_) | InterventionType::SetValue(_)
489        )
490    }
491
492    /// Check if steering only affects positions within the prompt.
493    ///
494    /// If all edges have `from_pos < prompt_len`, the steering can be
495    /// applied once during prompt processing, cached, and reused for
496    /// generation (no steering needed for generated tokens).
497    #[must_use]
498    pub fn is_prompt_only(&self, prompt_len: usize) -> bool {
499        for edge in &self.edges {
500            if edge.from_pos == usize::MAX {
501                return false;
502            }
503            if edge.from_pos >= prompt_len {
504                return false;
505            }
506        }
507        true
508    }
509
510    /// Maximum `from_pos` among all edges (excluding sentinels).
511    #[must_use]
512    pub fn max_from_pos(&self) -> Option<usize> {
513        self.edges
514            .iter()
515            .filter(|e| e.from_pos != usize::MAX)
516            .map(|e| e.from_pos)
517            .max()
518    }
519
520    /// Maximum `to_pos` among all edges (excluding sentinels).
521    #[must_use]
522    pub fn max_to_pos(&self) -> Option<usize> {
523        self.edges
524            .iter()
525            .filter(|e| e.to_pos != usize::MAX)
526            .map(|e| e.to_pos)
527            .max()
528    }
529}
530
/// Convert a [`KnockoutSpec`] to a [`SteeringSpec`] for unified handling.
///
/// Layers, heads, and edges carry over unchanged; the intervention type
/// is fixed to [`InterventionType::Knockout`].
impl From<KnockoutSpec> for SteeringSpec {
    fn from(spec: KnockoutSpec) -> Self {
        Self {
            layers: spec.layers,
            heads: spec.heads,
            edges: spec.edges,
            intervention_type: InterventionType::Knockout,
        }
    }
}
542
543// ===========================================================================
544// Result types
545// ===========================================================================
546
547/// Result of an ablation (knockout) experiment.
548///
549/// Carries baseline and ablated logits so the caller can compute
550/// KL divergence, logit diffs, and top-changed-token analyses.
551#[derive(Debug)]
552pub struct AblationResult {
553    /// Logits from baseline forward pass (no intervention).
554    pub baseline_logits: Tensor,
555    /// Logits from the knocked-out forward pass.
556    pub ablated_logits: Tensor,
557    /// The knockout specification used.
558    pub spec: KnockoutSpec,
559}
560
impl AblationResult {
    /// Create a new ablation result from baseline/ablated logits and the
    /// spec that produced them.
    #[must_use]
    pub const fn new(baseline_logits: Tensor, ablated_logits: Tensor, spec: KnockoutSpec) -> Self {
        Self {
            baseline_logits,
            ablated_logits,
            spec,
        }
    }

    /// KL divergence `KL(baseline || ablated)` between the two distributions.
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Model`] if tensor operations fail.
    pub fn kl_divergence(&self) -> Result<f32> {
        kl_divergence(&self.baseline_logits, &self.ablated_logits)
    }

    /// Logit difference for a specific token (`baseline - ablated`).
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Intervention`] if `token_id` is out of range.
    pub fn logit_diff(&self, token_id: u32) -> Result<f32> {
        logit_diff_impl(&self.baseline_logits, &self.ablated_logits, token_id)
    }

    /// Top-k tokens that changed most due to ablation.
    ///
    /// Returns `(token_id, baseline_prob, ablated_prob, abs_diff)` tuples
    /// sorted by descending absolute probability change.
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Model`] if tensor operations fail.
    pub fn top_changed_tokens(&self, k: usize) -> Result<Vec<(u32, f32, f32, f32)>> {
        top_changed_impl(&self.baseline_logits, &self.ablated_logits, k)
    }
}
601
/// Result of a steering experiment.
///
/// Carries baseline and steered logits, plus optional before/after mean
/// attention over the targeted edges.
#[derive(Debug)]
#[must_use]
pub struct SteeringResult {
    /// Logits from baseline forward pass (no intervention).
    pub baseline_logits: Tensor,
    /// Logits from the steered forward pass.
    pub steered_logits: Tensor,
    /// The steering specification used.
    pub spec: SteeringSpec,
    /// Mean attention over the target edges before steering, if measured.
    pub baseline_attention_mean: Option<f32>,
    /// Mean attention over the target edges after steering, if measured.
    pub steered_attention_mean: Option<f32>,
}
617
618impl SteeringResult {
619    /// Create a new steering result.
620    pub const fn new(baseline_logits: Tensor, steered_logits: Tensor, spec: SteeringSpec) -> Self {
621        Self {
622            baseline_logits,
623            steered_logits,
624            spec,
625            baseline_attention_mean: None,
626            steered_attention_mean: None,
627        }
628    }
629
630    /// Add attention measurements.
631    pub const fn with_attention_measurements(
632        mut self,
633        baseline_mean: f32,
634        steered_mean: f32,
635    ) -> Self {
636        self.baseline_attention_mean = Some(baseline_mean);
637        self.steered_attention_mean = Some(steered_mean);
638        self
639    }
640
641    /// KL divergence between baseline and steered distributions.
642    ///
643    /// # Errors
644    ///
645    /// Returns [`MIError::Model`] if tensor operations fail.
646    pub fn kl_divergence(&self) -> Result<f32> {
647        kl_divergence(&self.baseline_logits, &self.steered_logits)
648    }
649
650    /// Logit difference for a specific token (`baseline - steered`).
651    ///
652    /// # Errors
653    ///
654    /// Returns [`MIError::Intervention`] if `token_id` is out of range.
655    pub fn logit_diff(&self, token_id: u32) -> Result<f32> {
656        logit_diff_impl(&self.baseline_logits, &self.steered_logits, token_id)
657    }
658
659    /// Top-k tokens that changed most due to steering.
660    ///
661    /// # Errors
662    ///
663    /// Returns [`MIError::Model`] if tensor operations fail.
664    pub fn top_changed_tokens(&self, k: usize) -> Result<Vec<(u32, f32, f32, f32)>> {
665        top_changed_impl(&self.baseline_logits, &self.steered_logits, k)
666    }
667
668    /// Attention change ratio (`steered_mean / baseline_mean`).
669    #[must_use]
670    pub fn attention_ratio(&self) -> Option<f32> {
671        match (self.baseline_attention_mean, self.steered_attention_mean) {
672            (Some(base), Some(steered)) if base > 1e-10 => Some(steered / base),
673            _ => None,
674        }
675    }
676}
677
678// ===========================================================================
679// Shared result helpers
680// ===========================================================================
681
682/// Compute logit difference for a specific token.
683fn logit_diff_impl(baseline: &Tensor, other: &Tensor, token_id: u32) -> Result<f32> {
684    let baseline_f32 = baseline.to_dtype(DType::F32)?;
685    let other_f32 = other.to_dtype(DType::F32)?;
686    let baseline_vec: Vec<f32> = baseline_f32.flatten_all()?.to_vec1()?;
687    let other_vec: Vec<f32> = other_f32.flatten_all()?.to_vec1()?;
688
689    // CAST: u32 → usize, token ID used as Vec index
690    #[allow(clippy::as_conversions)]
691    let idx = token_id as usize;
692    let b = baseline_vec
693        .get(idx)
694        .ok_or_else(|| MIError::Intervention(format!("token ID {token_id} out of range")))?;
695    let o = other_vec
696        .get(idx)
697        .ok_or_else(|| MIError::Intervention(format!("token ID {token_id} out of range")))?;
698    Ok(b - o)
699}
700
701/// Compute top-k changed tokens between two logit tensors.
702// CAST: usize → u32, vocab index fits in u32
703#[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
704fn top_changed_impl(
705    baseline: &Tensor,
706    other: &Tensor,
707    k: usize,
708) -> Result<Vec<(u32, f32, f32, f32)>> {
709    let baseline_probs = softmax_to_vec(baseline)?;
710    let other_probs = softmax_to_vec(other)?;
711
712    let mut changes: Vec<(u32, f32, f32, f32)> = baseline_probs
713        .iter()
714        .zip(other_probs.iter())
715        .enumerate()
716        .map(|(idx, (&base, &oth))| (idx as u32, base, oth, (base - oth).abs()))
717        .collect();
718
719    changes.sort_by(|a, b| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal));
720    Ok(changes.into_iter().take(k).collect())
721}
722
723// ===========================================================================
724// Shared validation helpers
725// ===========================================================================
726
727/// Validate layer specification against model dimensions.
728fn validate_layers(layers: &LayerSpec, n_layers: usize) -> Result<()> {
729    match layers {
730        LayerSpec::Specific(ls) => {
731            for &l in ls {
732                if l >= n_layers {
733                    return Err(MIError::Intervention(format!(
734                        "layer {l} out of range (model has {n_layers} layers)"
735                    )));
736                }
737            }
738        }
739        LayerSpec::Range { start, end } => {
740            if *end >= n_layers {
741                return Err(MIError::Intervention(format!(
742                    "layer range end {end} out of range (model has {n_layers} layers)"
743                )));
744            }
745            if start > end {
746                return Err(MIError::Intervention(format!(
747                    "invalid layer range: start {start} > end {end}"
748                )));
749            }
750        }
751        LayerSpec::All => {}
752    }
753    Ok(())
754}
755
756/// Validate head specification against model dimensions.
757fn validate_heads(heads: &HeadSpec, n_heads: usize) -> Result<()> {
758    if let HeadSpec::Specific(hs) = heads {
759        for &h in hs {
760            if h >= n_heads {
761                return Err(MIError::Intervention(format!(
762                    "head {h} out of range (model has {n_heads} heads)"
763                )));
764            }
765        }
766    }
767    Ok(())
768}
769
770/// Validate edge positions against sequence length.
771fn validate_edges(edges: &[AttentionEdge], seq_len: usize) -> Result<()> {
772    for edge in edges {
773        if edge.from_pos != usize::MAX && edge.from_pos >= seq_len {
774            return Err(MIError::Intervention(format!(
775                "edge from_pos {} out of range (seq_len is {seq_len})",
776                edge.from_pos,
777            )));
778        }
779        if edge.to_pos != usize::MAX && edge.to_pos >= seq_len {
780            return Err(MIError::Intervention(format!(
781                "edge to_pos {} out of range (seq_len is {seq_len})",
782                edge.to_pos,
783            )));
784        }
785    }
786    Ok(())
787}
788
789// ===========================================================================
790// Mask creation and steering application
791// ===========================================================================
792
793/// Create a knockout mask tensor for the given specification.
794///
795/// Returns a tensor of shape `[1, n_heads, seq_len, seq_len]` where:
796/// - `0.0` = no knockout (attention allowed)
797/// - `-inf` = knockout (attention blocked)
798///
799/// This mask is ADDED to the attention scores before softmax.
800///
801/// # Shapes
802///
803/// - returns: `[1, n_heads, seq_len, seq_len]`
804///
805/// # Errors
806///
807/// Returns [`MIError::Model`] if tensor creation fails.
808#[allow(clippy::indexing_slicing)] // Bounds checked via edge.from_pos < seq_len
809pub fn create_knockout_mask(
810    spec: &KnockoutSpec,
811    n_heads: usize,
812    seq_len: usize,
813    device: &Device,
814    dtype: DType,
815) -> Result<Tensor> {
816    let mut mask_data = vec![0.0f32; n_heads * seq_len * seq_len];
817    let expanded_edges = expand_edges(&spec.edges, seq_len);
818
819    for head in 0..n_heads {
820        if !spec.applies_to_head(head) {
821            continue;
822        }
823
824        for edge in &expanded_edges {
825            if edge.from_pos < seq_len && edge.to_pos < seq_len {
826                let idx = head * seq_len * seq_len + edge.from_pos * seq_len + edge.to_pos;
827                mask_data[idx] = f32::NEG_INFINITY;
828            }
829        }
830    }
831
832    let mask = Tensor::from_vec(mask_data, (1, n_heads, seq_len, seq_len), device)?;
833    Ok(mask.to_dtype(dtype)?)
834}
835
836/// Compute KL divergence between two logit tensors.
837///
838/// Returns `KL(P || Q)` where `P = softmax(baseline)`, `Q = softmax(other)`.
839///
840/// # Errors
841///
842/// Returns [`MIError::Model`] if tensor operations fail.
843pub fn kl_divergence(baseline_logits: &Tensor, other_logits: &Tensor) -> Result<f32> {
844    let p = softmax_to_vec(baseline_logits)?;
845    let q = softmax_to_vec(other_logits)?;
846
847    let kl: f32 = p
848        .iter()
849        .zip(q.iter())
850        .filter(|&(&pi, &qi)| pi > 1e-10 && qi > 1e-10)
851        .map(|(&pi, &qi)| pi * (pi / qi).ln())
852        .sum();
853
854    Ok(kl)
855}
856
857/// Apply steering intervention to attention weights (post-softmax).
858///
859/// Modifies attention weights according to the steering spec and
860/// renormalizes rows to maintain valid probability distributions.
861///
862/// # Shapes
863///
864/// - `attn_weights`: `[batch, heads, seq, seq]`
865/// - returns: `[batch, heads, seq, seq]`
866///
867/// # Errors
868///
869/// Returns [`MIError::Intervention`] if the spec uses knockout
870/// (which should use [`create_knockout_mask`] instead).
871pub fn apply_steering(
872    attn_weights: &Tensor,
873    spec: &SteeringSpec,
874    n_heads: usize,
875    seq_len: usize,
876) -> Result<Tensor> {
877    match spec.intervention_type {
878        InterventionType::Scale(factor) => {
879            apply_scale_steering(attn_weights, spec, n_heads, seq_len, factor)
880        }
881        InterventionType::SetValue(target) => {
882            apply_set_value_steering(attn_weights, spec, n_heads, seq_len, target)
883        }
884        InterventionType::Knockout => Err(MIError::Intervention(
885            "knockout should use create_knockout_mask, not apply_steering".into(),
886        )),
887    }
888}
889
890/// Apply scaling to specified edges, then renormalize rows.
891///
892/// # Shapes
893///
894/// - `attn_weights`: `[batch, heads, seq, seq]`
895/// - returns: `[batch, heads, seq, seq]`
896///
897/// # Errors
898///
899/// Returns [`MIError::Intervention`] on tensor extraction failures.
900#[allow(clippy::indexing_slicing)] // Operating on extracted Vecs with validated bounds
901pub fn apply_scale_steering(
902    attn_weights: &Tensor,
903    spec: &SteeringSpec,
904    _n_heads: usize,
905    seq_len: usize,
906    scale_factor: f32,
907) -> Result<Tensor> {
908    // PROMOTE: needs f32 for numerical manipulation
909    let attn_f32 = attn_weights.to_dtype(DType::F32)?;
910    let original_dtype = attn_weights.dtype();
911    let device = attn_weights.device();
912
913    let mut data = tensor_to_vec4(&attn_f32)?;
914    let expanded_edges = expand_edges(&spec.edges, seq_len);
915
916    for batch_data in &mut data {
917        for (h, head_data) in batch_data.iter_mut().enumerate() {
918            if !spec.applies_to_head(h) {
919                continue;
920            }
921
922            let mut rows_modified: HashSet<usize> = HashSet::new();
923
924            for edge in &expanded_edges {
925                if edge.from_pos < seq_len && edge.to_pos < seq_len {
926                    head_data[edge.from_pos][edge.to_pos] *= scale_factor;
927                    rows_modified.insert(edge.from_pos);
928                }
929            }
930
931            for row in rows_modified {
932                let row_sum: f32 = head_data[row].iter().sum();
933                if row_sum > 1e-10 {
934                    for val in &mut head_data[row] {
935                        *val /= row_sum;
936                    }
937                }
938            }
939        }
940    }
941
942    let result = Tensor::new(data, device)?.to_dtype(original_dtype)?;
943    Ok(result)
944}
945
946/// Set specified edges to a target value, redistributing mass.
947///
948/// # Shapes
949///
950/// - `attn_weights`: `[batch, heads, seq, seq]`
951/// - returns: `[batch, heads, seq, seq]`
952///
953/// # Errors
954///
955/// Returns [`MIError::Intervention`] on tensor extraction failures.
956// CAST: usize → f32, small counts (seq_len, edge count) fit in f32
957#[allow(
958    clippy::indexing_slicing, // Operating on extracted Vecs with validated bounds
959    clippy::cast_precision_loss,
960    clippy::as_conversions,
961)]
962pub fn apply_set_value_steering(
963    attn_weights: &Tensor,
964    spec: &SteeringSpec,
965    _n_heads: usize,
966    seq_len: usize,
967    target_value: f32,
968) -> Result<Tensor> {
969    // PROMOTE: needs f32 for numerical manipulation
970    let attn_f32 = attn_weights.to_dtype(DType::F32)?;
971    let original_dtype = attn_weights.dtype();
972    let device = attn_weights.device();
973
974    let mut data = tensor_to_vec4(&attn_f32)?;
975    let expanded_edges = expand_edges(&spec.edges, seq_len);
976
977    // Group edges by row for efficient row-wise operations.
978    let mut edges_by_row: HashMap<usize, Vec<usize>> = HashMap::new();
979    for edge in &expanded_edges {
980        if edge.from_pos < seq_len && edge.to_pos < seq_len {
981            edges_by_row
982                .entry(edge.from_pos)
983                .or_default()
984                .push(edge.to_pos);
985        }
986    }
987
988    for batch_data in &mut data {
989        for (h, head_data) in batch_data.iter_mut().enumerate() {
990            if !spec.applies_to_head(h) {
991                continue;
992            }
993
994            for (&row, target_cols) in &edges_by_row {
995                let current_target_sum: f32 =
996                    target_cols.iter().map(|&col| head_data[row][col]).sum();
997                let new_target_sum = target_value * target_cols.len() as f32;
998                let delta = new_target_sum - current_target_sum;
999
1000                let non_target_cols: Vec<usize> =
1001                    (0..seq_len).filter(|i| !target_cols.contains(i)).collect();
1002
1003                for &col in target_cols {
1004                    head_data[row][col] = target_value;
1005                }
1006
1007                if !non_target_cols.is_empty() {
1008                    let adjustment = delta / non_target_cols.len() as f32;
1009                    for col in non_target_cols {
1010                        head_data[row][col] = (head_data[row][col] - adjustment).max(0.0);
1011                    }
1012                }
1013
1014                let row_sum: f32 = head_data[row].iter().sum();
1015                if row_sum > 1e-10 {
1016                    for val in &mut head_data[row] {
1017                        *val /= row_sum;
1018                    }
1019                }
1020            }
1021        }
1022    }
1023
1024    let result = Tensor::new(data, device)?.to_dtype(original_dtype)?;
1025    Ok(result)
1026}
1027
1028/// Measure mean attention for specified edges in an attention tensor.
1029///
1030/// # Shapes
1031///
1032/// - `attn_weights`: `[batch, heads, seq, seq]`
1033///
1034/// # Errors
1035///
1036/// Returns [`MIError::Intervention`] if `from_pos` is out of range.
1037#[allow(clippy::indexing_slicing)] // Operating on extracted Vecs with validated bounds
1038pub fn measure_attention_to_targets(
1039    attn_weights: &Tensor,
1040    from_pos: usize,
1041    to_positions: &[usize],
1042) -> Result<f32> {
1043    let attn_f32 = attn_weights.to_dtype(DType::F32)?;
1044    let data = tensor_to_vec4(&attn_f32)?;
1045
1046    let seq_len = data.first().and_then(|b| b.first()).map_or(0, Vec::len);
1047
1048    if from_pos >= seq_len {
1049        return Err(MIError::Intervention(format!(
1050            "from_pos {from_pos} out of range (seq_len is {seq_len})"
1051        )));
1052    }
1053
1054    let mut total = 0.0_f32;
1055    let mut count = 0_usize;
1056
1057    for batch_data in &data {
1058        for head_data in batch_data {
1059            for &to_pos in to_positions {
1060                if to_pos < seq_len {
1061                    total += head_data[from_pos][to_pos];
1062                    count += 1;
1063                }
1064            }
1065        }
1066    }
1067
1068    if count == 0 {
1069        Ok(0.0)
1070    } else {
1071        // CAST: usize → f32, count is number of matching heads — fits in f32
1072        #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
1073        Ok(total / count as f32)
1074    }
1075}
1076
1077// ===========================================================================
1078// Part 3: State Knockout (RWKV-6)
1079// ===========================================================================
1080
/// Specification for RWKV-6 state knockout intervention.
///
/// State knockout makes specific token positions invisible to all future
/// tokens by skipping the recurrent state update at those positions.
/// This is the RNN analogue of all-edge attention knockout in transformers.
///
/// Build with the fluent methods ([`Self::position`], [`Self::layer`], …)
/// and call [`Self::validate`] before use.
#[derive(Debug, Clone)]
#[must_use]
pub struct StateKnockoutSpec {
    /// Token positions where state update is skipped.
    /// Duplicates are tolerated; [`Self::position_set`] deduplicates.
    pub positions: Vec<usize>,
    /// Which layers to apply knockout.
    pub layers: LayerSpec,
}
1094
1095impl StateKnockoutSpec {
1096    /// Create a new empty spec (all layers, no positions yet).
1097    pub const fn new() -> Self {
1098        Self {
1099            positions: Vec::new(),
1100            layers: LayerSpec::All,
1101        }
1102    }
1103
1104    /// Add a single position to knock out.
1105    pub fn position(mut self, pos: usize) -> Self {
1106        self.positions.push(pos);
1107        self
1108    }
1109
1110    /// Add multiple positions to knock out.
1111    pub fn positions(mut self, positions: &[usize]) -> Self {
1112        self.positions.extend_from_slice(positions);
1113        self
1114    }
1115
1116    /// Target a single layer.
1117    pub fn layer(mut self, layer: usize) -> Self {
1118        self.layers = LayerSpec::Specific(vec![layer]);
1119        self
1120    }
1121
1122    /// Target multiple specific layers.
1123    pub fn layers(mut self, layers: &[usize]) -> Self {
1124        self.layers = LayerSpec::Specific(layers.to_vec());
1125        self
1126    }
1127
1128    /// Target a range of layers (inclusive).
1129    pub fn layer_range(mut self, start: usize, end: usize) -> Self {
1130        self.layers = LayerSpec::Range { start, end };
1131        self
1132    }
1133
1134    /// Check if knockout applies to this layer.
1135    #[must_use]
1136    pub fn applies_to_layer(&self, layer: usize) -> bool {
1137        match &self.layers {
1138            LayerSpec::All => true,
1139            LayerSpec::Specific(layers) => layers.contains(&layer),
1140            LayerSpec::Range { start, end } => layer >= *start && layer <= *end,
1141        }
1142    }
1143
1144    /// Get knockout positions as a `HashSet` for O(1) lookup in the WKV loop.
1145    #[must_use]
1146    pub fn position_set(&self) -> HashSet<usize> {
1147        self.positions.iter().copied().collect()
1148    }
1149
1150    /// Validate the spec against model dimensions.
1151    ///
1152    /// # Errors
1153    ///
1154    /// Returns [`MIError::Intervention`] if positions, layers are out of
1155    /// range, or no positions are specified.
1156    pub fn validate(&self, n_layers: usize, seq_len: usize) -> Result<()> {
1157        validate_layers(&self.layers, n_layers)?;
1158
1159        for &pos in &self.positions {
1160            if pos >= seq_len {
1161                return Err(MIError::Intervention(format!(
1162                    "position {pos} out of range (seq_len is {seq_len})"
1163                )));
1164            }
1165        }
1166
1167        if self.positions.is_empty() {
1168            return Err(MIError::Intervention(
1169                "StateKnockoutSpec has no positions specified".into(),
1170            ));
1171        }
1172
1173        Ok(())
1174    }
1175}
1176
impl Default for StateKnockoutSpec {
    /// Equivalent to [`StateKnockoutSpec::new`]: all layers, no positions.
    fn default() -> Self {
        Self::new()
    }
}
1182
/// Result of a state knockout ablation experiment (RWKV-6).
///
/// Pairs the baseline and ablated logits so the comparison metrics
/// ([`Self::kl_divergence`], [`Self::logit_diff`],
/// [`Self::top_changed_tokens`]) can be computed on demand.
#[derive(Debug)]
pub struct StateAblationResult {
    /// Logits from baseline forward pass (no intervention).
    pub baseline_logits: Tensor,
    /// Logits from state-knocked-out forward pass.
    pub ablated_logits: Tensor,
    /// The state knockout specification used.
    pub spec: StateKnockoutSpec,
}
1193
impl StateAblationResult {
    /// Create a new state ablation result.
    ///
    /// Stores the tensors and spec as given; no shape validation happens here.
    #[must_use]
    pub const fn new(
        baseline_logits: Tensor,
        ablated_logits: Tensor,
        spec: StateKnockoutSpec,
    ) -> Self {
        Self {
            baseline_logits,
            ablated_logits,
            spec,
        }
    }

    /// KL divergence between baseline and ablated distributions.
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Model`] if tensor operations fail.
    pub fn kl_divergence(&self) -> Result<f32> {
        kl_divergence(&self.baseline_logits, &self.ablated_logits)
    }

    /// Logit difference for a specific token.
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Intervention`] if `token_id` is out of range.
    pub fn logit_diff(&self, token_id: u32) -> Result<f32> {
        logit_diff_impl(&self.baseline_logits, &self.ablated_logits, token_id)
    }

    /// Top-k tokens that changed most due to state knockout.
    ///
    /// Tuple layout is defined by `top_changed_impl` (token id plus three
    /// f32 metrics); NOTE(review): the exact metric meanings are not visible
    /// here — see that helper's docs.
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Model`] if tensor operations fail.
    pub fn top_changed_tokens(&self, k: usize) -> Result<Vec<(u32, f32, f32, f32)>> {
        top_changed_impl(&self.baseline_logits, &self.ablated_logits, k)
    }
}
1236
1237// ===========================================================================
1238// Part 4: State Steering (RWKV-6)
1239// ===========================================================================
1240
/// Specification for RWKV-6 state steering intervention.
///
/// State steering scales the kv write at specified positions, amplifying
/// or dampening the token's contribution to recurrent state.
///
/// - `scale = 0.0` → knockout (equivalent to [`StateKnockoutSpec`])
/// - `scale = 1.0` → no-op (normal forward pass)
/// - `scale > 1.0` → amplify the token's state write
/// - `scale < 1.0` → dampen the token's state write
///
/// Build with the fluent methods and call [`Self::validate`] before use.
#[derive(Debug, Clone)]
#[must_use]
pub struct StateSteeringSpec {
    /// Token positions where state write is scaled.
    /// Duplicates are tolerated; [`Self::position_set`] deduplicates.
    pub positions: Vec<usize>,
    /// Which layers to apply steering.
    pub layers: LayerSpec,
    /// Scale factor for kv write.
    pub scale: f32,
}
1260
1261impl StateSteeringSpec {
1262    /// Create a new spec with the given scale factor (all layers, no positions).
1263    pub const fn new(scale: f32) -> Self {
1264        Self {
1265            positions: Vec::new(),
1266            layers: LayerSpec::All,
1267            scale,
1268        }
1269    }
1270
1271    /// Add a single position to steer.
1272    pub fn position(mut self, pos: usize) -> Self {
1273        self.positions.push(pos);
1274        self
1275    }
1276
1277    /// Add multiple positions to steer.
1278    pub fn positions(mut self, positions: &[usize]) -> Self {
1279        self.positions.extend_from_slice(positions);
1280        self
1281    }
1282
1283    /// Target a single layer.
1284    pub fn layer(mut self, layer: usize) -> Self {
1285        self.layers = LayerSpec::Specific(vec![layer]);
1286        self
1287    }
1288
1289    /// Target multiple specific layers.
1290    pub fn layers(mut self, layers: &[usize]) -> Self {
1291        self.layers = LayerSpec::Specific(layers.to_vec());
1292        self
1293    }
1294
1295    /// Target a range of layers (inclusive).
1296    pub fn layer_range(mut self, start: usize, end: usize) -> Self {
1297        self.layers = LayerSpec::Range { start, end };
1298        self
1299    }
1300
1301    /// Check if steering applies to this layer.
1302    #[must_use]
1303    pub fn applies_to_layer(&self, layer: usize) -> bool {
1304        match &self.layers {
1305            LayerSpec::All => true,
1306            LayerSpec::Specific(layers) => layers.contains(&layer),
1307            LayerSpec::Range { start, end } => layer >= *start && layer <= *end,
1308        }
1309    }
1310
1311    /// Get steering positions as a `HashSet` for O(1) lookup in the WKV loop.
1312    #[must_use]
1313    pub fn position_set(&self) -> HashSet<usize> {
1314        self.positions.iter().copied().collect()
1315    }
1316
1317    /// Validate the spec against model dimensions.
1318    ///
1319    /// # Errors
1320    ///
1321    /// Returns [`MIError::Intervention`] if positions or layers are out of
1322    /// range, or no positions are specified.
1323    pub fn validate(&self, n_layers: usize, seq_len: usize) -> Result<()> {
1324        validate_layers(&self.layers, n_layers)?;
1325
1326        for &pos in &self.positions {
1327            if pos >= seq_len {
1328                return Err(MIError::Intervention(format!(
1329                    "position {pos} out of range (seq_len is {seq_len})"
1330                )));
1331            }
1332        }
1333
1334        if self.positions.is_empty() {
1335            return Err(MIError::Intervention(
1336                "StateSteeringSpec has no positions specified".into(),
1337            ));
1338        }
1339
1340        Ok(())
1341    }
1342}
1343
/// Result of a state steering experiment (RWKV-6).
///
/// Pairs the baseline and steered logits so comparison metrics can be
/// computed on demand.
#[derive(Debug)]
pub struct StateSteeringResult {
    /// Logits from baseline forward pass (no intervention).
    pub baseline_logits: Tensor,
    /// Logits from the steered forward pass.
    pub steered_logits: Tensor,
    /// The state steering specification used.
    pub spec: StateSteeringSpec,
}
1354
1355impl StateSteeringResult {
1356    /// Create a new state steering result.
1357    #[must_use]
1358    pub const fn new(
1359        baseline_logits: Tensor,
1360        steered_logits: Tensor,
1361        spec: StateSteeringSpec,
1362    ) -> Self {
1363        Self {
1364            baseline_logits,
1365            steered_logits,
1366            spec,
1367        }
1368    }
1369
1370    /// KL divergence between baseline and steered distributions.
1371    ///
1372    /// # Errors
1373    ///
1374    /// Returns [`MIError::Model`] if tensor operations fail.
1375    pub fn kl_divergence(&self) -> Result<f32> {
1376        kl_divergence(&self.baseline_logits, &self.steered_logits)
1377    }
1378
1379    /// Top-k tokens that changed most due to state steering.
1380    ///
1381    /// # Errors
1382    ///
1383    /// Returns [`MIError::Model`] if tensor operations fail.
1384    pub fn top_changed_tokens(&self, k: usize) -> Result<Vec<(u32, f32, f32, f32)>> {
1385        top_changed_impl(&self.baseline_logits, &self.steered_logits, k)
1386    }
1387}
1388
1389// ===========================================================================
1390// Part 5: CLT Injection (feature-gated)
1391// ===========================================================================
1392
/// Pre-accumulated CLT injection vectors for per-layer residual stream injection.
///
/// Created by the CLT encoder's `prepare_injection()` method. The forward
/// pass adds each vector to the residual at the specified position after the
/// target layer completes.
///
/// Call [`Self::validate`] before use to catch out-of-range targets early.
#[cfg(feature = "clt")]
#[derive(Debug, Clone)]
pub struct CltInjectionSpec {
    /// Per-layer injection entries.
    pub injections: Vec<CltLayerInjection>,
}
1404
/// A single CLT injection at one layer and position.
#[cfg(feature = "clt")]
#[derive(Debug, Clone)]
pub struct CltLayerInjection {
    /// Target layer index (injection happens after this layer completes).
    pub target_layer: usize,
    /// Token position in the sequence to inject at.
    pub position: usize,
    /// Pre-accumulated and strength-scaled decoder vector, shape `[d_model]`.
    /// Its length is checked against `d_model` in [`CltInjectionSpec::validate`].
    pub vector: Tensor,
}
1416
#[cfg(feature = "clt")]
impl CltInjectionSpec {
    /// Construct a spec with no injections.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            injections: Vec::new(),
        }
    }

    /// Append one injection entry.
    pub fn add(&mut self, target_layer: usize, position: usize, vector: Tensor) {
        let entry = CltLayerInjection {
            target_layer,
            position,
            vector,
        };
        self.injections.push(entry);
    }

    /// Whether at least one injection targets `layer`.
    #[must_use]
    pub fn applies_to_layer(&self, layer: usize) -> bool {
        self.injections.iter().any(|i| i.target_layer == layer)
    }

    /// All injection entries whose target is `layer`.
    #[must_use]
    pub fn injections_for_layer(&self, layer: usize) -> Vec<&CltLayerInjection> {
        self.injections
            .iter()
            .filter(|i| i.target_layer == layer)
            .collect()
    }

    /// Validate the spec against model dimensions.
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Intervention`] if any target layer, position,
    /// or vector dimension is out of range.
    pub fn validate(&self, n_layers: usize, seq_len: usize, d_model: usize) -> Result<()> {
        for inj in &self.injections {
            let (target, pos) = (inj.target_layer, inj.position);
            if target >= n_layers {
                return Err(MIError::Intervention(format!(
                    "CLT injection target layer {target} out of range (model has {n_layers} layers)"
                )));
            }
            if pos >= seq_len {
                return Err(MIError::Intervention(format!(
                    "CLT injection position {pos} out of range (seq_len={seq_len})"
                )));
            }
            let vec_dim = inj.vector.dim(0)?;
            if vec_dim != d_model {
                return Err(MIError::Intervention(format!(
                    "CLT injection vector dim {vec_dim} doesn't match model d_model={d_model}"
                )));
            }
        }
        Ok(())
    }
}
1481
#[cfg(feature = "clt")]
impl Default for CltInjectionSpec {
    /// Equivalent to [`CltInjectionSpec::new`]: an empty injection list.
    fn default() -> Self {
        Self::new()
    }
}
1488
/// Result of a CLT logit shift test (baseline vs. injected comparison).
///
/// Unlike the state result types, this carries no spec — only the two
/// logit tensors being compared.
#[cfg(feature = "clt")]
#[derive(Debug)]
pub struct CltLogitShiftResult {
    /// Logits from baseline forward pass (no injection).
    pub baseline_logits: Tensor,
    /// Logits from CLT-injected forward pass.
    pub injected_logits: Tensor,
}
1498
#[cfg(feature = "clt")]
impl CltLogitShiftResult {
    /// Create a new CLT logit shift result.
    ///
    /// Stores the tensors as given; no shape validation happens here.
    #[must_use]
    pub const fn new(baseline_logits: Tensor, injected_logits: Tensor) -> Self {
        Self {
            baseline_logits,
            injected_logits,
        }
    }

    /// KL divergence between baseline and injected distributions.
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Model`] if tensor operations fail.
    pub fn kl_divergence(&self) -> Result<f32> {
        kl_divergence(&self.baseline_logits, &self.injected_logits)
    }

    /// Top-k tokens that changed most due to CLT injection.
    ///
    /// Tuple layout is defined by `top_changed_impl` (token id plus three
    /// f32 metrics); NOTE(review): the exact metric meanings are not visible
    /// here — see that helper's docs.
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Model`] if tensor operations fail.
    pub fn top_changed_tokens(&self, k: usize) -> Result<Vec<(u32, f32, f32, f32)>> {
        top_changed_impl(&self.baseline_logits, &self.injected_logits, k)
    }
}
1528
1529// ===========================================================================
1530// Tests
1531// ===========================================================================
1532
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::float_cmp,
    clippy::indexing_slicing
)]
mod tests {
    use super::*;

    // --- Attention knockout spec tests ---

    #[test]
    fn knockout_spec_builder() {
        let spec = KnockoutSpec::new()
            .layer(5)
            .head(2)
            .edge(3, 1)
            .from_to_positions(4, &[0, 1, 2]);

        assert!(matches!(spec.layers, LayerSpec::Specific(_)));
        assert!(matches!(spec.heads, HeadSpec::Specific(_)));
        assert_eq!(spec.edges.len(), 4); // 1 + 3
    }

    #[test]
    fn layer_spec_applies() {
        // Range is inclusive on both ends.
        let spec = KnockoutSpec::new().layer_range(5, 10);

        assert!(!spec.applies_to_layer(4));
        assert!(spec.applies_to_layer(5));
        assert!(spec.applies_to_layer(7));
        assert!(spec.applies_to_layer(10));
        assert!(!spec.applies_to_layer(11));
    }

    #[test]
    fn expand_edges_sentinels() {
        // usize::MAX as to_pos is the "all columns" sentinel.
        let edges = vec![AttentionEdge::new(2, usize::MAX), AttentionEdge::new(1, 0)];

        let expanded = expand_edges(&edges, 4);
        assert_eq!(expanded.len(), 5); // 4 from sentinel + 1 specific
    }

    #[test]
    fn create_knockout_mask_correctness() {
        let spec = KnockoutSpec::new().head(0).edge(2, 1);

        let mask = create_knockout_mask(&spec, 2, 4, &Device::Cpu, DType::F32).unwrap();
        assert_eq!(mask.dims(), &[1, 2, 4, 4]);

        let mask_vec: Vec<f32> = mask.flatten_all().unwrap().to_vec1().unwrap();

        // Head 0, row 2, col 1 = index 0*16 + 2*4 + 1 = 9
        assert!(mask_vec[9].is_infinite() && mask_vec[9].is_sign_negative());

        // Head 1 should not be affected (index 1*16 + 2*4 + 1 = 25)
        assert_eq!(mask_vec[25], 0.0);
    }

    #[test]
    fn validation_catches_errors() {
        let spec = KnockoutSpec::new().layer(100).edge(50, 25);
        assert!(spec.validate(30, 16, 20).is_err());
    }

    #[test]
    fn validation_passes_valid() {
        let spec = KnockoutSpec::new().layer(10).edge(5, 3);
        assert!(spec.validate(30, 16, 20).is_ok());
    }

    // --- Attention steering spec tests ---

    #[test]
    fn steering_spec_builder() {
        let spec = SteeringSpec::scale(2.0)
            .layer(5)
            .head(2)
            .edge(3, 1)
            .from_to_positions(4, &[0, 1, 2]);

        assert!(matches!(spec.layers, LayerSpec::Specific(_)));
        assert!(matches!(spec.heads, HeadSpec::Specific(_)));
        assert_eq!(spec.edges.len(), 4);
        assert!(
            matches!(spec.intervention_type, InterventionType::Scale(f) if (f - 2.0).abs() < 1e-6)
        );
    }

    #[test]
    fn steering_validation() {
        let spec = SteeringSpec::scale(2.0).layer(10).edge(5, 3);
        assert!(spec.validate(30, 16, 20).is_ok());

        // Negative scale factors are rejected.
        let spec = SteeringSpec::scale(-1.0).layer(10).edge(5, 3);
        assert!(spec.validate(30, 16, 20).is_err());

        let spec = SteeringSpec::set_value(0.09).layer(10).edge(5, 3);
        assert!(spec.validate(30, 16, 20).is_ok());

        // Set-value targets above 1.0 are rejected.
        let spec = SteeringSpec::set_value(1.5).layer(10).edge(5, 3);
        assert!(spec.validate(30, 16, 20).is_err());
    }

    #[test]
    fn steering_is_methods() {
        let knockout = SteeringSpec::new(InterventionType::Knockout);
        assert!(knockout.is_knockout());
        assert!(!knockout.is_steering());

        let scale = SteeringSpec::scale(2.0);
        assert!(!scale.is_knockout());
        assert!(scale.is_steering());

        let set_value = SteeringSpec::set_value(0.1);
        assert!(!set_value.is_knockout());
        assert!(set_value.is_steering());
    }

    #[test]
    fn apply_scale_steering_correctness() {
        let data: Vec<f32> = vec![
            // Head 0: uniform attention (each row sums to 1.0)
            0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25,
            0.25, 0.25, // Head 1: same
            0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25,
            0.25, 0.25,
        ];
        let tensor = Tensor::from_vec(data, (1, 2, 4, 4), &Device::Cpu).unwrap();

        let spec = SteeringSpec::scale(2.0).edge(2, 1);
        let result = apply_scale_steering(&tensor, &spec, 2, 4, 2.0).unwrap();
        let result_data = tensor_to_vec4(&result).unwrap();

        // Row 2: edge (2,1) scaled by 2, then renormalized
        // Before: [0.25, 0.25, 0.25, 0.25]
        // After scaling: [0.25, 0.50, 0.25, 0.25], sum = 1.25
        // After renorm: [0.20, 0.40, 0.20, 0.20]
        let row2 = &result_data[0][0][2];
        assert!((row2[0] - 0.20).abs() < 1e-5);
        assert!((row2[1] - 0.40).abs() < 1e-5);
        assert!((row2[2] - 0.20).abs() < 1e-5);
        assert!((row2[3] - 0.20).abs() < 1e-5);

        let row_sum: f32 = row2.iter().sum();
        assert!((row_sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn apply_set_value_steering_correctness() {
        // Uniform 0.25 attention across 2 heads, 4x4.
        let data: Vec<f32> = vec![
            0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25,
            0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25,
            0.25, 0.25, 0.25, 0.25,
        ];
        let tensor = Tensor::from_vec(data, (1, 2, 4, 4), &Device::Cpu).unwrap();

        let spec = SteeringSpec::set_value(0.5).edge(2, 1);
        let result = apply_set_value_steering(&tensor, &spec, 2, 4, 0.5).unwrap();
        let result_data = tensor_to_vec4(&result).unwrap();

        let row2 = &result_data[0][0][2];
        let row_sum: f32 = row2.iter().sum();
        assert!(
            (row_sum - 1.0).abs() < 1e-5,
            "row sum should be 1.0, got {row_sum}"
        );

        // Edge (2,1) should be the largest value
        assert!(row2[1] > row2[0]);
        assert!(row2[1] > row2[2]);
        assert!(row2[1] > row2[3]);
    }

    #[test]
    fn knockout_to_steering_conversion() {
        let knockout = KnockoutSpec::new().layer(5).head(2).edge(3, 1);
        let steering: SteeringSpec = knockout.into();

        assert!(matches!(steering.layers, LayerSpec::Specific(ref v) if v == &[5]));
        assert!(matches!(steering.heads, HeadSpec::Specific(ref v) if v == &[2]));
        assert_eq!(steering.edges.len(), 1);
        assert!(steering.is_knockout());
    }

    #[test]
    fn is_prompt_only() {
        let spec = SteeringSpec::scale(2.0).edge(5, 2).edge(8, 3);
        assert!(spec.is_prompt_only(10));
        assert!(!spec.is_prompt_only(6));
    }

    #[test]
    fn is_prompt_only_with_sentinel() {
        let spec = SteeringSpec::scale(2.0).to_position(5);
        assert!(!spec.is_prompt_only(10));

        let spec2 = SteeringSpec::scale(2.0).from_position(5);
        assert!(spec2.is_prompt_only(10));
    }

    #[test]
    fn max_positions() {
        let spec = SteeringSpec::scale(2.0).edge(5, 2).edge(8, 3).edge(3, 7);
        assert_eq!(spec.max_from_pos(), Some(8));
        assert_eq!(spec.max_to_pos(), Some(7));
    }

    #[test]
    fn max_positions_empty() {
        let spec = SteeringSpec::scale(2.0);
        assert_eq!(spec.max_from_pos(), None);
        assert_eq!(spec.max_to_pos(), None);
    }

    // --- State knockout tests ---

    #[test]
    fn state_knockout_spec_builder() {
        let spec = StateKnockoutSpec::new().position(3).position(5).layer(10);
        assert_eq!(spec.positions, vec![3, 5]);
        assert!(matches!(spec.layers, LayerSpec::Specific(ref v) if v == &[10]));
    }

    #[test]
    fn state_knockout_validation() {
        assert!(
            StateKnockoutSpec::new()
                .position(5)
                .layer(10)
                .validate(24, 20)
                .is_ok()
        );
        assert!(
            StateKnockoutSpec::new()
                .position(25)
                .validate(24, 20)
                .is_err()
        );
        assert!(
            StateKnockoutSpec::new()
                .position(5)
                .layer(30)
                .validate(24, 20)
                .is_err()
        );
        assert!(StateKnockoutSpec::new().validate(24, 20).is_err()); // empty
    }

    #[test]
    fn state_knockout_position_set() {
        let spec = StateKnockoutSpec::new().position(3).position(5).position(3);
        let set = spec.position_set();
        assert_eq!(set.len(), 2); // deduplicated
        assert!(set.contains(&3));
        assert!(set.contains(&5));
    }

    #[test]
    fn state_knockout_layer_range() {
        let spec = StateKnockoutSpec::new().position(0).layer_range(5, 10);
        assert!(!spec.applies_to_layer(4));
        assert!(spec.applies_to_layer(5));
        assert!(spec.applies_to_layer(10));
        assert!(!spec.applies_to_layer(11));
    }
}