1use std::collections::{HashMap, HashSet};
24
25use candle_core::{DType, Device, Tensor};
26
27use crate::error::{MIError, Result};
28
29fn tensor_to_vec4(tensor: &Tensor) -> Result<Vec<Vec<Vec<Vec<f32>>>>> {
45 let shape = tensor.dims();
46 if shape.len() != 4 {
47 return Err(MIError::Intervention(format!(
48 "expected 4D tensor, got {}D",
49 shape.len()
50 )));
51 }
52 let s0 = shape.first().copied().unwrap_or(0);
53 let s1 = shape.get(1).copied().unwrap_or(0);
54 let s2 = shape.get(2).copied().unwrap_or(0);
55 let s3 = shape.get(3).copied().unwrap_or(0);
56
57 let flat: Vec<f32> = tensor.flatten_all()?.to_vec1()?;
58
59 let mut result = Vec::with_capacity(s0);
60 let mut iter = flat.into_iter();
61 for _ in 0..s0 {
62 let mut axis1 = Vec::with_capacity(s1);
63 for _ in 0..s1 {
64 let mut axis2 = Vec::with_capacity(s2);
65 for _ in 0..s2 {
66 let row: Vec<f32> = iter.by_ref().take(s3).collect();
67 axis2.push(row);
68 }
69 axis1.push(axis2);
70 }
71 result.push(axis1);
72 }
73
74 Ok(result)
75}
76
77fn softmax_to_vec(logits: &Tensor) -> Result<Vec<f32>> {
79 let logits_f32 = logits.to_dtype(DType::F32)?;
81 let probs = candle_nn::ops::softmax_last_dim(&logits_f32)?;
82 Ok(probs.flatten_all()?.to_vec1()?)
83}
84
85fn expand_edges(edges: &[AttentionEdge], seq_len: usize) -> Vec<AttentionEdge> {
90 let mut expanded = Vec::new();
91
92 for edge in edges {
93 match (edge.from_pos, edge.to_pos) {
94 (from, usize::MAX) if from != usize::MAX => {
95 for to in 0..seq_len {
96 expanded.push(AttentionEdge::new(from, to));
97 }
98 }
99 (usize::MAX, to) if to != usize::MAX => {
100 for from in 0..seq_len {
101 expanded.push(AttentionEdge::new(from, to));
102 }
103 }
104 (from, to) if from != usize::MAX && to != usize::MAX => {
105 expanded.push(*edge);
106 }
107 _ => {} }
109 }
110
111 expanded
112}
113
/// Selects which transformer layers an intervention targets.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum LayerSpec {
    /// Every layer of the model.
    All,
    /// Only the listed layer indices.
    Specific(Vec<usize>),
    /// The inclusive range `start..=end` (both bounds apply, see `applies_to_layer`).
    Range {
        /// First layer in the range (inclusive).
        start: usize,
        /// Last layer in the range (inclusive).
        end: usize,
    },
}
134
/// Selects which attention heads an intervention targets.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum HeadSpec {
    /// Every head in each targeted layer.
    All,
    /// Only the listed head indices.
    Specific(Vec<usize>),
}
144
/// One `(from, to)` cell of an attention matrix.
///
/// `from_pos` indexes the first sequence axis (rows) of the per-head attention
/// matrix and `to_pos` the second (columns), matching the indexing in
/// `create_knockout_mask`. `usize::MAX` on either side is a wildcard meaning
/// "all positions", expanded by `expand_edges` at apply time.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AttentionEdge {
    /// Source position (row), or `usize::MAX` for all positions.
    pub from_pos: usize,
    /// Target position (column), or `usize::MAX` for all positions.
    pub to_pos: usize,
}
156
impl AttentionEdge {
    /// Builds an edge; pass `usize::MAX` on either side to mean "all positions".
    #[must_use]
    pub const fn new(from_pos: usize, to_pos: usize) -> Self {
        Self { from_pos, to_pos }
    }
}
164
/// Builder-style description of an attention-knockout intervention:
/// which layers/heads to touch and which attention edges to zero out
/// (via the additive `-inf` mask built by `create_knockout_mask`).
#[derive(Debug, Clone)]
#[must_use]
pub struct KnockoutSpec {
    /// Layers the knockout applies to.
    pub layers: LayerSpec,
    /// Heads the knockout applies to.
    pub heads: HeadSpec,
    /// Edges to knock out; may contain `usize::MAX` wildcards.
    pub edges: Vec<AttentionEdge>,
}
191
192impl KnockoutSpec {
193 pub const fn new() -> Self {
195 Self {
196 layers: LayerSpec::All,
197 heads: HeadSpec::All,
198 edges: Vec::new(),
199 }
200 }
201
202 pub fn layer(mut self, layer: usize) -> Self {
204 self.layers = LayerSpec::Specific(vec![layer]);
205 self
206 }
207
208 pub fn layers(mut self, layers: &[usize]) -> Self {
210 self.layers = LayerSpec::Specific(layers.to_vec());
211 self
212 }
213
214 pub fn layer_range(mut self, start: usize, end: usize) -> Self {
216 self.layers = LayerSpec::Range { start, end };
217 self
218 }
219
220 pub fn head(mut self, head: usize) -> Self {
222 self.heads = HeadSpec::Specific(vec![head]);
223 self
224 }
225
226 pub fn heads(mut self, heads: &[usize]) -> Self {
228 self.heads = HeadSpec::Specific(heads.to_vec());
229 self
230 }
231
232 pub fn edge(mut self, from_pos: usize, to_pos: usize) -> Self {
234 self.edges.push(AttentionEdge::new(from_pos, to_pos));
235 self
236 }
237
238 pub fn from_position(mut self, from_pos: usize) -> Self {
240 self.edges.push(AttentionEdge::new(from_pos, usize::MAX));
241 self
242 }
243
244 pub fn to_position(mut self, to_pos: usize) -> Self {
246 self.edges.push(AttentionEdge::new(usize::MAX, to_pos));
247 self
248 }
249
250 pub fn from_to_positions(mut self, from_pos: usize, to_positions: &[usize]) -> Self {
252 for &to_pos in to_positions {
253 self.edges.push(AttentionEdge::new(from_pos, to_pos));
254 }
255 self
256 }
257
258 #[must_use]
260 pub fn applies_to_layer(&self, layer: usize) -> bool {
261 match &self.layers {
262 LayerSpec::All => true,
263 LayerSpec::Specific(layers) => layers.contains(&layer),
264 LayerSpec::Range { start, end } => layer >= *start && layer <= *end,
265 }
266 }
267
268 #[must_use]
270 pub fn applies_to_head(&self, head: usize) -> bool {
271 match &self.heads {
272 HeadSpec::All => true,
273 HeadSpec::Specific(heads) => heads.contains(&head),
274 }
275 }
276
277 pub fn validate(&self, n_layers: usize, n_heads: usize, seq_len: usize) -> Result<()> {
284 validate_layers(&self.layers, n_layers)?;
285 validate_heads(&self.heads, n_heads)?;
286 validate_edges(&self.edges, seq_len)?;
287 Ok(())
288 }
289}
290
impl Default for KnockoutSpec {
    /// Same as [`KnockoutSpec::new`]: all layers/heads, no edges.
    fn default() -> Self {
        Self::new()
    }
}
296
/// What a steering intervention does to the targeted attention weights.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum InterventionType {
    /// Zero out the targeted edges (applied as an additive `-inf` mask).
    #[default]
    Knockout,
    /// Multiply the targeted weights by this factor, then renormalize rows.
    Scale(f32),
    /// Set the targeted weights to this value, redistributing the difference
    /// over the remaining columns; validated to lie in `[0, 1]`.
    SetValue(f32),
}
313
/// Builder-style description of an attention-steering intervention:
/// target layers/heads/edges plus the operation to perform on them.
#[derive(Debug, Clone)]
#[must_use]
pub struct SteeringSpec {
    /// Layers the intervention applies to.
    pub layers: LayerSpec,
    /// Heads the intervention applies to.
    pub heads: HeadSpec,
    /// Edges to modify; may contain `usize::MAX` wildcards.
    pub edges: Vec<AttentionEdge>,
    /// Operation to apply (knockout, scale, or set-value).
    pub intervention_type: InterventionType,
}
342
343impl SteeringSpec {
344 pub const fn new(intervention_type: InterventionType) -> Self {
346 Self {
347 layers: LayerSpec::All,
348 heads: HeadSpec::All,
349 edges: Vec::new(),
350 intervention_type,
351 }
352 }
353
354 pub const fn scale(factor: f32) -> Self {
356 Self::new(InterventionType::Scale(factor))
357 }
358
359 pub const fn set_value(target: f32) -> Self {
361 Self::new(InterventionType::SetValue(target))
362 }
363
364 pub fn layer(mut self, layer: usize) -> Self {
366 self.layers = LayerSpec::Specific(vec![layer]);
367 self
368 }
369
370 pub fn layers(mut self, layers: &[usize]) -> Self {
372 self.layers = LayerSpec::Specific(layers.to_vec());
373 self
374 }
375
376 pub fn layer_range(mut self, start: usize, end: usize) -> Self {
378 self.layers = LayerSpec::Range { start, end };
379 self
380 }
381
382 pub fn head(mut self, head: usize) -> Self {
384 self.heads = HeadSpec::Specific(vec![head]);
385 self
386 }
387
388 pub fn heads(mut self, heads: &[usize]) -> Self {
390 self.heads = HeadSpec::Specific(heads.to_vec());
391 self
392 }
393
394 pub fn edge(mut self, from_pos: usize, to_pos: usize) -> Self {
396 self.edges.push(AttentionEdge::new(from_pos, to_pos));
397 self
398 }
399
400 pub fn from_position(mut self, from_pos: usize) -> Self {
402 self.edges.push(AttentionEdge::new(from_pos, usize::MAX));
403 self
404 }
405
406 pub fn to_position(mut self, to_pos: usize) -> Self {
408 self.edges.push(AttentionEdge::new(usize::MAX, to_pos));
409 self
410 }
411
412 pub fn from_to_positions(mut self, from_pos: usize, to_positions: &[usize]) -> Self {
414 for &to_pos in to_positions {
415 self.edges.push(AttentionEdge::new(from_pos, to_pos));
416 }
417 self
418 }
419
420 #[must_use]
422 pub fn applies_to_layer(&self, layer: usize) -> bool {
423 match &self.layers {
424 LayerSpec::All => true,
425 LayerSpec::Specific(layers) => layers.contains(&layer),
426 LayerSpec::Range { start, end } => layer >= *start && layer <= *end,
427 }
428 }
429
430 #[must_use]
432 pub fn applies_to_head(&self, head: usize) -> bool {
433 match &self.heads {
434 HeadSpec::All => true,
435 HeadSpec::Specific(heads) => heads.contains(&head),
436 }
437 }
438
439 pub fn validate(&self, n_layers: usize, n_heads: usize, seq_len: usize) -> Result<()> {
446 validate_layers(&self.layers, n_layers)?;
447 validate_heads(&self.heads, n_heads)?;
448 validate_edges(&self.edges, seq_len)?;
449
450 match self.intervention_type {
451 InterventionType::Scale(factor) => {
452 if factor < 0.0 {
453 return Err(MIError::Intervention(format!(
454 "scale factor must be non-negative, got {factor}"
455 )));
456 }
457 }
458 InterventionType::SetValue(value) => {
459 if !(0.0..=1.0).contains(&value) {
460 return Err(MIError::Intervention(format!(
461 "set value must be in [0, 1], got {value}"
462 )));
463 }
464 }
465 InterventionType::Knockout => {}
466 }
467
468 Ok(())
469 }
470
471 #[must_use]
473 pub const fn intervention_type(&self) -> InterventionType {
474 self.intervention_type
475 }
476
477 #[must_use]
479 pub const fn is_knockout(&self) -> bool {
480 matches!(self.intervention_type, InterventionType::Knockout)
481 }
482
483 #[must_use]
485 pub const fn is_steering(&self) -> bool {
486 matches!(
487 self.intervention_type,
488 InterventionType::Scale(_) | InterventionType::SetValue(_)
489 )
490 }
491
492 #[must_use]
498 pub fn is_prompt_only(&self, prompt_len: usize) -> bool {
499 for edge in &self.edges {
500 if edge.from_pos == usize::MAX {
501 return false;
502 }
503 if edge.from_pos >= prompt_len {
504 return false;
505 }
506 }
507 true
508 }
509
510 #[must_use]
512 pub fn max_from_pos(&self) -> Option<usize> {
513 self.edges
514 .iter()
515 .filter(|e| e.from_pos != usize::MAX)
516 .map(|e| e.from_pos)
517 .max()
518 }
519
520 #[must_use]
522 pub fn max_to_pos(&self) -> Option<usize> {
523 self.edges
524 .iter()
525 .filter(|e| e.to_pos != usize::MAX)
526 .map(|e| e.to_pos)
527 .max()
528 }
529}
530
impl From<KnockoutSpec> for SteeringSpec {
    /// Reinterprets a knockout spec as a steering spec with
    /// `InterventionType::Knockout`, preserving layers, heads, and edges.
    fn from(spec: KnockoutSpec) -> Self {
        Self {
            layers: spec.layers,
            heads: spec.heads,
            edges: spec.edges,
            intervention_type: InterventionType::Knockout,
        }
    }
}
542
/// Paired baseline/ablated logits from a knockout run, kept together with
/// the spec that produced them so effects can be measured afterwards.
#[derive(Debug)]
pub struct AblationResult {
    /// Logits from the unmodified forward pass.
    pub baseline_logits: Tensor,
    /// Logits from the forward pass with the knockout applied.
    pub ablated_logits: Tensor,
    /// The knockout spec that was applied.
    pub spec: KnockoutSpec,
}
560
impl AblationResult {
    /// Bundles the two logit tensors with the spec that produced them.
    #[must_use]
    pub const fn new(baseline_logits: Tensor, ablated_logits: Tensor, spec: KnockoutSpec) -> Self {
        Self {
            baseline_logits,
            ablated_logits,
            spec,
        }
    }

    /// KL divergence KL(baseline || ablated) over softmaxed logits.
    ///
    /// # Errors
    /// Fails if a tensor op (dtype conversion, softmax, flatten) fails.
    pub fn kl_divergence(&self) -> Result<f32> {
        kl_divergence(&self.baseline_logits, &self.ablated_logits)
    }

    /// Difference `baseline - ablated` of the raw logit for `token_id`.
    ///
    /// # Errors
    /// Fails if `token_id` is out of range or a tensor op fails.
    pub fn logit_diff(&self, token_id: u32) -> Result<f32> {
        logit_diff_impl(&self.baseline_logits, &self.ablated_logits, token_id)
    }

    /// Top `k` tokens by absolute probability change, as
    /// `(token_id, baseline_prob, ablated_prob, |delta|)`.
    ///
    /// # Errors
    /// Fails if a tensor op fails.
    pub fn top_changed_tokens(&self, k: usize) -> Result<Vec<(u32, f32, f32, f32)>> {
        top_changed_impl(&self.baseline_logits, &self.ablated_logits, k)
    }
}
601
/// Paired baseline/steered logits from a steering run, with optional
/// before/after attention measurements.
#[derive(Debug)]
#[must_use]
pub struct SteeringResult {
    /// Logits from the unmodified forward pass.
    pub baseline_logits: Tensor,
    /// Logits from the forward pass with the steering applied.
    pub steered_logits: Tensor,
    /// The steering spec that was applied.
    pub spec: SteeringSpec,
    /// Mean attention on the measured positions before steering, if recorded.
    pub baseline_attention_mean: Option<f32>,
    /// Mean attention on the measured positions after steering, if recorded.
    pub steered_attention_mean: Option<f32>,
}
617
618impl SteeringResult {
619 pub const fn new(baseline_logits: Tensor, steered_logits: Tensor, spec: SteeringSpec) -> Self {
621 Self {
622 baseline_logits,
623 steered_logits,
624 spec,
625 baseline_attention_mean: None,
626 steered_attention_mean: None,
627 }
628 }
629
630 pub const fn with_attention_measurements(
632 mut self,
633 baseline_mean: f32,
634 steered_mean: f32,
635 ) -> Self {
636 self.baseline_attention_mean = Some(baseline_mean);
637 self.steered_attention_mean = Some(steered_mean);
638 self
639 }
640
641 pub fn kl_divergence(&self) -> Result<f32> {
647 kl_divergence(&self.baseline_logits, &self.steered_logits)
648 }
649
650 pub fn logit_diff(&self, token_id: u32) -> Result<f32> {
656 logit_diff_impl(&self.baseline_logits, &self.steered_logits, token_id)
657 }
658
659 pub fn top_changed_tokens(&self, k: usize) -> Result<Vec<(u32, f32, f32, f32)>> {
665 top_changed_impl(&self.baseline_logits, &self.steered_logits, k)
666 }
667
668 #[must_use]
670 pub fn attention_ratio(&self) -> Option<f32> {
671 match (self.baseline_attention_mean, self.steered_attention_mean) {
672 (Some(base), Some(steered)) if base > 1e-10 => Some(steered / base),
673 _ => None,
674 }
675 }
676}
677
678fn logit_diff_impl(baseline: &Tensor, other: &Tensor, token_id: u32) -> Result<f32> {
684 let baseline_f32 = baseline.to_dtype(DType::F32)?;
685 let other_f32 = other.to_dtype(DType::F32)?;
686 let baseline_vec: Vec<f32> = baseline_f32.flatten_all()?.to_vec1()?;
687 let other_vec: Vec<f32> = other_f32.flatten_all()?.to_vec1()?;
688
689 #[allow(clippy::as_conversions)]
691 let idx = token_id as usize;
692 let b = baseline_vec
693 .get(idx)
694 .ok_or_else(|| MIError::Intervention(format!("token ID {token_id} out of range")))?;
695 let o = other_vec
696 .get(idx)
697 .ok_or_else(|| MIError::Intervention(format!("token ID {token_id} out of range")))?;
698 Ok(b - o)
699}
700
701#[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
704fn top_changed_impl(
705 baseline: &Tensor,
706 other: &Tensor,
707 k: usize,
708) -> Result<Vec<(u32, f32, f32, f32)>> {
709 let baseline_probs = softmax_to_vec(baseline)?;
710 let other_probs = softmax_to_vec(other)?;
711
712 let mut changes: Vec<(u32, f32, f32, f32)> = baseline_probs
713 .iter()
714 .zip(other_probs.iter())
715 .enumerate()
716 .map(|(idx, (&base, &oth))| (idx as u32, base, oth, (base - oth).abs()))
717 .collect();
718
719 changes.sort_by(|a, b| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal));
720 Ok(changes.into_iter().take(k).collect())
721}
722
723fn validate_layers(layers: &LayerSpec, n_layers: usize) -> Result<()> {
729 match layers {
730 LayerSpec::Specific(ls) => {
731 for &l in ls {
732 if l >= n_layers {
733 return Err(MIError::Intervention(format!(
734 "layer {l} out of range (model has {n_layers} layers)"
735 )));
736 }
737 }
738 }
739 LayerSpec::Range { start, end } => {
740 if *end >= n_layers {
741 return Err(MIError::Intervention(format!(
742 "layer range end {end} out of range (model has {n_layers} layers)"
743 )));
744 }
745 if start > end {
746 return Err(MIError::Intervention(format!(
747 "invalid layer range: start {start} > end {end}"
748 )));
749 }
750 }
751 LayerSpec::All => {}
752 }
753 Ok(())
754}
755
756fn validate_heads(heads: &HeadSpec, n_heads: usize) -> Result<()> {
758 if let HeadSpec::Specific(hs) = heads {
759 for &h in hs {
760 if h >= n_heads {
761 return Err(MIError::Intervention(format!(
762 "head {h} out of range (model has {n_heads} heads)"
763 )));
764 }
765 }
766 }
767 Ok(())
768}
769
770fn validate_edges(edges: &[AttentionEdge], seq_len: usize) -> Result<()> {
772 for edge in edges {
773 if edge.from_pos != usize::MAX && edge.from_pos >= seq_len {
774 return Err(MIError::Intervention(format!(
775 "edge from_pos {} out of range (seq_len is {seq_len})",
776 edge.from_pos,
777 )));
778 }
779 if edge.to_pos != usize::MAX && edge.to_pos >= seq_len {
780 return Err(MIError::Intervention(format!(
781 "edge to_pos {} out of range (seq_len is {seq_len})",
782 edge.to_pos,
783 )));
784 }
785 }
786 Ok(())
787}
788
789#[allow(clippy::indexing_slicing)] pub fn create_knockout_mask(
810 spec: &KnockoutSpec,
811 n_heads: usize,
812 seq_len: usize,
813 device: &Device,
814 dtype: DType,
815) -> Result<Tensor> {
816 let mut mask_data = vec![0.0f32; n_heads * seq_len * seq_len];
817 let expanded_edges = expand_edges(&spec.edges, seq_len);
818
819 for head in 0..n_heads {
820 if !spec.applies_to_head(head) {
821 continue;
822 }
823
824 for edge in &expanded_edges {
825 if edge.from_pos < seq_len && edge.to_pos < seq_len {
826 let idx = head * seq_len * seq_len + edge.from_pos * seq_len + edge.to_pos;
827 mask_data[idx] = f32::NEG_INFINITY;
828 }
829 }
830 }
831
832 let mask = Tensor::from_vec(mask_data, (1, n_heads, seq_len, seq_len), device)?;
833 Ok(mask.to_dtype(dtype)?)
834}
835
836pub fn kl_divergence(baseline_logits: &Tensor, other_logits: &Tensor) -> Result<f32> {
844 let p = softmax_to_vec(baseline_logits)?;
845 let q = softmax_to_vec(other_logits)?;
846
847 let kl: f32 = p
848 .iter()
849 .zip(q.iter())
850 .filter(|&(&pi, &qi)| pi > 1e-10 && qi > 1e-10)
851 .map(|(&pi, &qi)| pi * (pi / qi).ln())
852 .sum();
853
854 Ok(kl)
855}
856
/// Dispatches a steering spec to the matching weight-rewriting routine.
///
/// # Errors
/// Returns [`MIError::Intervention`] for `Knockout` specs — knockouts are
/// applied as an additive mask via `create_knockout_mask`, not by rewriting
/// weights — or propagates any error from the underlying routine.
pub fn apply_steering(
    attn_weights: &Tensor,
    spec: &SteeringSpec,
    n_heads: usize,
    seq_len: usize,
) -> Result<Tensor> {
    match spec.intervention_type {
        InterventionType::Scale(factor) => {
            apply_scale_steering(attn_weights, spec, n_heads, seq_len, factor)
        }
        InterventionType::SetValue(target) => {
            apply_set_value_steering(attn_weights, spec, n_heads, seq_len, target)
        }
        InterventionType::Knockout => Err(MIError::Intervention(
            "knockout should use create_knockout_mask, not apply_steering".into(),
        )),
    }
}
889
890#[allow(clippy::indexing_slicing)] pub fn apply_scale_steering(
902 attn_weights: &Tensor,
903 spec: &SteeringSpec,
904 _n_heads: usize,
905 seq_len: usize,
906 scale_factor: f32,
907) -> Result<Tensor> {
908 let attn_f32 = attn_weights.to_dtype(DType::F32)?;
910 let original_dtype = attn_weights.dtype();
911 let device = attn_weights.device();
912
913 let mut data = tensor_to_vec4(&attn_f32)?;
914 let expanded_edges = expand_edges(&spec.edges, seq_len);
915
916 for batch_data in &mut data {
917 for (h, head_data) in batch_data.iter_mut().enumerate() {
918 if !spec.applies_to_head(h) {
919 continue;
920 }
921
922 let mut rows_modified: HashSet<usize> = HashSet::new();
923
924 for edge in &expanded_edges {
925 if edge.from_pos < seq_len && edge.to_pos < seq_len {
926 head_data[edge.from_pos][edge.to_pos] *= scale_factor;
927 rows_modified.insert(edge.from_pos);
928 }
929 }
930
931 for row in rows_modified {
932 let row_sum: f32 = head_data[row].iter().sum();
933 if row_sum > 1e-10 {
934 for val in &mut head_data[row] {
935 *val /= row_sum;
936 }
937 }
938 }
939 }
940 }
941
942 let result = Tensor::new(data, device)?.to_dtype(original_dtype)?;
943 Ok(result)
944}
945
/// Sets the attention weight of every targeted edge to `target_value`,
/// redistributing the net change evenly across the row's non-target columns,
/// then renormalizes each modified row to sum to 1.
///
/// `_n_heads` is unused: the head count comes from the tensor itself.
///
/// # Errors
/// Fails if `attn_weights` is not 4-D or a tensor op fails.
#[allow(
    clippy::indexing_slicing, clippy::cast_precision_loss,
    clippy::as_conversions,
)]
pub fn apply_set_value_steering(
    attn_weights: &Tensor,
    spec: &SteeringSpec,
    _n_heads: usize,
    seq_len: usize,
    target_value: f32,
) -> Result<Tensor> {
    let attn_f32 = attn_weights.to_dtype(DType::F32)?;
    let original_dtype = attn_weights.dtype();
    let device = attn_weights.device();

    let mut data = tensor_to_vec4(&attn_f32)?;
    let expanded_edges = expand_edges(&spec.edges, seq_len);

    // Group target columns by row so each row is adjusted exactly once.
    let mut edges_by_row: HashMap<usize, Vec<usize>> = HashMap::new();
    for edge in &expanded_edges {
        if edge.from_pos < seq_len && edge.to_pos < seq_len {
            edges_by_row
                .entry(edge.from_pos)
                .or_default()
                .push(edge.to_pos);
        }
    }

    for batch_data in &mut data {
        for (h, head_data) in batch_data.iter_mut().enumerate() {
            if !spec.applies_to_head(h) {
                continue;
            }

            for (&row, target_cols) in &edges_by_row {
                // delta = mass added to (or removed from) the target cells.
                let current_target_sum: f32 =
                    target_cols.iter().map(|&col| head_data[row][col]).sum();
                let new_target_sum = target_value * target_cols.len() as f32;
                let delta = new_target_sum - current_target_sum;

                let non_target_cols: Vec<usize> =
                    (0..seq_len).filter(|i| !target_cols.contains(i)).collect();

                for &col in target_cols {
                    head_data[row][col] = target_value;
                }

                // Spread the compensating mass evenly over the other columns,
                // clamping at zero (a column cannot go negative).
                if !non_target_cols.is_empty() {
                    let adjustment = delta / non_target_cols.len() as f32;
                    for col in non_target_cols {
                        head_data[row][col] = (head_data[row][col] - adjustment).max(0.0);
                    }
                }

                // Clamping can leave the row slightly off 1.0 — renormalize.
                let row_sum: f32 = head_data[row].iter().sum();
                if row_sum > 1e-10 {
                    for val in &mut head_data[row] {
                        *val /= row_sum;
                    }
                }
            }
        }
    }

    let result = Tensor::new(data, device)?.to_dtype(original_dtype)?;
    Ok(result)
}
1027
1028#[allow(clippy::indexing_slicing)] pub fn measure_attention_to_targets(
1039 attn_weights: &Tensor,
1040 from_pos: usize,
1041 to_positions: &[usize],
1042) -> Result<f32> {
1043 let attn_f32 = attn_weights.to_dtype(DType::F32)?;
1044 let data = tensor_to_vec4(&attn_f32)?;
1045
1046 let seq_len = data.first().and_then(|b| b.first()).map_or(0, Vec::len);
1047
1048 if from_pos >= seq_len {
1049 return Err(MIError::Intervention(format!(
1050 "from_pos {from_pos} out of range (seq_len is {seq_len})"
1051 )));
1052 }
1053
1054 let mut total = 0.0_f32;
1055 let mut count = 0_usize;
1056
1057 for batch_data in &data {
1058 for head_data in batch_data {
1059 for &to_pos in to_positions {
1060 if to_pos < seq_len {
1061 total += head_data[from_pos][to_pos];
1062 count += 1;
1063 }
1064 }
1065 }
1066 }
1067
1068 if count == 0 {
1069 Ok(0.0)
1070 } else {
1071 #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
1073 Ok(total / count as f32)
1074 }
1075}
1076
/// Builder-style description of a hidden-state knockout: which sequence
/// positions to knock out and in which layers.
#[derive(Debug, Clone)]
#[must_use]
pub struct StateKnockoutSpec {
    /// Sequence positions whose state is knocked out (must be non-empty to validate).
    pub positions: Vec<usize>,
    /// Layers the knockout applies to.
    pub layers: LayerSpec,
}
1094
1095impl StateKnockoutSpec {
1096 pub const fn new() -> Self {
1098 Self {
1099 positions: Vec::new(),
1100 layers: LayerSpec::All,
1101 }
1102 }
1103
1104 pub fn position(mut self, pos: usize) -> Self {
1106 self.positions.push(pos);
1107 self
1108 }
1109
1110 pub fn positions(mut self, positions: &[usize]) -> Self {
1112 self.positions.extend_from_slice(positions);
1113 self
1114 }
1115
1116 pub fn layer(mut self, layer: usize) -> Self {
1118 self.layers = LayerSpec::Specific(vec![layer]);
1119 self
1120 }
1121
1122 pub fn layers(mut self, layers: &[usize]) -> Self {
1124 self.layers = LayerSpec::Specific(layers.to_vec());
1125 self
1126 }
1127
1128 pub fn layer_range(mut self, start: usize, end: usize) -> Self {
1130 self.layers = LayerSpec::Range { start, end };
1131 self
1132 }
1133
1134 #[must_use]
1136 pub fn applies_to_layer(&self, layer: usize) -> bool {
1137 match &self.layers {
1138 LayerSpec::All => true,
1139 LayerSpec::Specific(layers) => layers.contains(&layer),
1140 LayerSpec::Range { start, end } => layer >= *start && layer <= *end,
1141 }
1142 }
1143
1144 #[must_use]
1146 pub fn position_set(&self) -> HashSet<usize> {
1147 self.positions.iter().copied().collect()
1148 }
1149
1150 pub fn validate(&self, n_layers: usize, seq_len: usize) -> Result<()> {
1157 validate_layers(&self.layers, n_layers)?;
1158
1159 for &pos in &self.positions {
1160 if pos >= seq_len {
1161 return Err(MIError::Intervention(format!(
1162 "position {pos} out of range (seq_len is {seq_len})"
1163 )));
1164 }
1165 }
1166
1167 if self.positions.is_empty() {
1168 return Err(MIError::Intervention(
1169 "StateKnockoutSpec has no positions specified".into(),
1170 ));
1171 }
1172
1173 Ok(())
1174 }
1175}
1176
impl Default for StateKnockoutSpec {
    /// Same as [`StateKnockoutSpec::new`]: no positions, all layers.
    fn default() -> Self {
        Self::new()
    }
}
1182
/// Paired baseline/ablated logits from a state-knockout run.
#[derive(Debug)]
pub struct StateAblationResult {
    /// Logits from the unmodified forward pass.
    pub baseline_logits: Tensor,
    /// Logits from the forward pass with the state knockout applied.
    pub ablated_logits: Tensor,
    /// The state-knockout spec that was applied.
    pub spec: StateKnockoutSpec,
}
1193
impl StateAblationResult {
    /// Bundles the two logit tensors with the spec that produced them.
    #[must_use]
    pub const fn new(
        baseline_logits: Tensor,
        ablated_logits: Tensor,
        spec: StateKnockoutSpec,
    ) -> Self {
        Self {
            baseline_logits,
            ablated_logits,
            spec,
        }
    }

    /// KL divergence KL(baseline || ablated) over softmaxed logits.
    ///
    /// # Errors
    /// Fails if a tensor op fails.
    pub fn kl_divergence(&self) -> Result<f32> {
        kl_divergence(&self.baseline_logits, &self.ablated_logits)
    }

    /// Difference `baseline - ablated` of the raw logit for `token_id`.
    ///
    /// # Errors
    /// Fails if `token_id` is out of range or a tensor op fails.
    pub fn logit_diff(&self, token_id: u32) -> Result<f32> {
        logit_diff_impl(&self.baseline_logits, &self.ablated_logits, token_id)
    }

    /// Top `k` tokens by absolute probability change.
    ///
    /// # Errors
    /// Fails if a tensor op fails.
    pub fn top_changed_tokens(&self, k: usize) -> Result<Vec<(u32, f32, f32, f32)>> {
        top_changed_impl(&self.baseline_logits, &self.ablated_logits, k)
    }
}
1236
/// Builder-style description of a hidden-state steering intervention:
/// which positions/layers to touch and the scale to apply.
#[derive(Debug, Clone)]
#[must_use]
pub struct StateSteeringSpec {
    /// Sequence positions to steer (must be non-empty to validate).
    pub positions: Vec<usize>,
    /// Layers the steering applies to.
    pub layers: LayerSpec,
    /// Scale factor applied to the state.
    /// NOTE(review): unlike `SteeringSpec`, `validate` does not bound-check
    /// this value — negative/NaN scales pass; confirm whether that is intended.
    pub scale: f32,
}
1260
1261impl StateSteeringSpec {
1262 pub const fn new(scale: f32) -> Self {
1264 Self {
1265 positions: Vec::new(),
1266 layers: LayerSpec::All,
1267 scale,
1268 }
1269 }
1270
1271 pub fn position(mut self, pos: usize) -> Self {
1273 self.positions.push(pos);
1274 self
1275 }
1276
1277 pub fn positions(mut self, positions: &[usize]) -> Self {
1279 self.positions.extend_from_slice(positions);
1280 self
1281 }
1282
1283 pub fn layer(mut self, layer: usize) -> Self {
1285 self.layers = LayerSpec::Specific(vec![layer]);
1286 self
1287 }
1288
1289 pub fn layers(mut self, layers: &[usize]) -> Self {
1291 self.layers = LayerSpec::Specific(layers.to_vec());
1292 self
1293 }
1294
1295 pub fn layer_range(mut self, start: usize, end: usize) -> Self {
1297 self.layers = LayerSpec::Range { start, end };
1298 self
1299 }
1300
1301 #[must_use]
1303 pub fn applies_to_layer(&self, layer: usize) -> bool {
1304 match &self.layers {
1305 LayerSpec::All => true,
1306 LayerSpec::Specific(layers) => layers.contains(&layer),
1307 LayerSpec::Range { start, end } => layer >= *start && layer <= *end,
1308 }
1309 }
1310
1311 #[must_use]
1313 pub fn position_set(&self) -> HashSet<usize> {
1314 self.positions.iter().copied().collect()
1315 }
1316
1317 pub fn validate(&self, n_layers: usize, seq_len: usize) -> Result<()> {
1324 validate_layers(&self.layers, n_layers)?;
1325
1326 for &pos in &self.positions {
1327 if pos >= seq_len {
1328 return Err(MIError::Intervention(format!(
1329 "position {pos} out of range (seq_len is {seq_len})"
1330 )));
1331 }
1332 }
1333
1334 if self.positions.is_empty() {
1335 return Err(MIError::Intervention(
1336 "StateSteeringSpec has no positions specified".into(),
1337 ));
1338 }
1339
1340 Ok(())
1341 }
1342}
1343
/// Paired baseline/steered logits from a state-steering run.
#[derive(Debug)]
pub struct StateSteeringResult {
    /// Logits from the unmodified forward pass.
    pub baseline_logits: Tensor,
    /// Logits from the forward pass with the state steering applied.
    pub steered_logits: Tensor,
    /// The state-steering spec that was applied.
    pub spec: StateSteeringSpec,
}
1354
impl StateSteeringResult {
    /// Bundles the two logit tensors with the spec that produced them.
    #[must_use]
    pub const fn new(
        baseline_logits: Tensor,
        steered_logits: Tensor,
        spec: StateSteeringSpec,
    ) -> Self {
        Self {
            baseline_logits,
            steered_logits,
            spec,
        }
    }

    /// KL divergence KL(baseline || steered) over softmaxed logits.
    ///
    /// # Errors
    /// Fails if a tensor op fails.
    pub fn kl_divergence(&self) -> Result<f32> {
        kl_divergence(&self.baseline_logits, &self.steered_logits)
    }

    /// Top `k` tokens by absolute probability change.
    ///
    /// # Errors
    /// Fails if a tensor op fails.
    pub fn top_changed_tokens(&self, k: usize) -> Result<Vec<(u32, f32, f32, f32)>> {
        top_changed_impl(&self.baseline_logits, &self.steered_logits, k)
    }
}
1388
/// Collection of CLT feature-vector injections to apply during a forward pass
/// (only available with the `clt` feature).
#[cfg(feature = "clt")]
#[derive(Debug, Clone)]
pub struct CltInjectionSpec {
    /// The queued injections, in insertion order.
    pub injections: Vec<CltLayerInjection>,
}
1404
/// A single vector injection at one (layer, position) site.
#[cfg(feature = "clt")]
#[derive(Debug, Clone)]
pub struct CltLayerInjection {
    /// Layer whose activations receive the injection.
    pub target_layer: usize,
    /// Sequence position receiving the injection.
    pub position: usize,
    /// The vector to inject; `validate` checks its dim 0 equals `d_model`.
    pub vector: Tensor,
}
1416
#[cfg(feature = "clt")]
impl CltInjectionSpec {
    /// Empty spec with no queued injections.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            injections: Vec::new(),
        }
    }

    /// Queues an injection of `vector` at `position` in `target_layer`.
    pub fn add(&mut self, target_layer: usize, position: usize, vector: Tensor) {
        let injection = CltLayerInjection {
            target_layer,
            position,
            vector,
        };
        self.injections.push(injection);
    }

    /// `true` if at least one injection targets `layer`.
    #[must_use]
    pub fn applies_to_layer(&self, layer: usize) -> bool {
        self.injections
            .iter()
            .any(|injection| injection.target_layer == layer)
    }

    /// All injections queued for `layer`, in insertion order.
    #[must_use]
    pub fn injections_for_layer(&self, layer: usize) -> Vec<&CltLayerInjection> {
        self.injections
            .iter()
            .filter(|injection| injection.target_layer == layer)
            .collect()
    }

    /// Checks every injection's layer, position, and vector width.
    ///
    /// # Errors
    /// Fails on an out-of-range layer or position, or a vector whose first
    /// dimension differs from `d_model`.
    pub fn validate(&self, n_layers: usize, seq_len: usize, d_model: usize) -> Result<()> {
        for injection in &self.injections {
            let target = injection.target_layer;
            if target >= n_layers {
                return Err(MIError::Intervention(format!(
                    "CLT injection target layer {target} out of range (model has {n_layers} layers)"
                )));
            }
            let pos = injection.position;
            if pos >= seq_len {
                return Err(MIError::Intervention(format!(
                    "CLT injection position {pos} out of range (seq_len={seq_len})"
                )));
            }
            // NOTE(review): only dim 0 is checked — a higher-rank tensor whose
            // first dim happens to equal d_model would pass; confirm callers
            // always supply 1-D vectors.
            let vec_dim = injection.vector.dim(0)?;
            if vec_dim != d_model {
                return Err(MIError::Intervention(format!(
                    "CLT injection vector dim {vec_dim} doesn't match model d_model={d_model}"
                )));
            }
        }
        Ok(())
    }
}
1481
#[cfg(feature = "clt")]
impl Default for CltInjectionSpec {
    /// Same as [`CltInjectionSpec::new`]: no injections.
    fn default() -> Self {
        Self::new()
    }
}
1488
/// Paired baseline/injected logits from a CLT injection run.
#[cfg(feature = "clt")]
#[derive(Debug)]
pub struct CltLogitShiftResult {
    /// Logits from the unmodified forward pass.
    pub baseline_logits: Tensor,
    /// Logits from the forward pass with the injections applied.
    pub injected_logits: Tensor,
}
1498
#[cfg(feature = "clt")]
impl CltLogitShiftResult {
    /// Bundles the two logit tensors.
    #[must_use]
    pub const fn new(baseline_logits: Tensor, injected_logits: Tensor) -> Self {
        Self {
            baseline_logits,
            injected_logits,
        }
    }

    /// KL divergence KL(baseline || injected) over softmaxed logits.
    ///
    /// # Errors
    /// Fails if a tensor op fails.
    pub fn kl_divergence(&self) -> Result<f32> {
        kl_divergence(&self.baseline_logits, &self.injected_logits)
    }

    /// Top `k` tokens by absolute probability change.
    ///
    /// # Errors
    /// Fails if a tensor op fails.
    pub fn top_changed_tokens(&self, k: usize) -> Result<Vec<(u32, f32, f32, f32)>> {
        top_changed_impl(&self.baseline_logits, &self.injected_logits, k)
    }
}
1528
1529#[cfg(test)]
1534#[allow(
1535 clippy::unwrap_used,
1536 clippy::expect_used,
1537 clippy::float_cmp,
1538 clippy::indexing_slicing
1539)]
1540mod tests {
1541 use super::*;
1542
1543 #[test]
1544 fn knockout_spec_builder() {
1545 let spec = KnockoutSpec::new()
1546 .layer(5)
1547 .head(2)
1548 .edge(3, 1)
1549 .from_to_positions(4, &[0, 1, 2]);
1550
1551 assert!(matches!(spec.layers, LayerSpec::Specific(_)));
1552 assert!(matches!(spec.heads, HeadSpec::Specific(_)));
1553 assert_eq!(spec.edges.len(), 4); }
1555
1556 #[test]
1557 fn layer_spec_applies() {
1558 let spec = KnockoutSpec::new().layer_range(5, 10);
1559
1560 assert!(!spec.applies_to_layer(4));
1561 assert!(spec.applies_to_layer(5));
1562 assert!(spec.applies_to_layer(7));
1563 assert!(spec.applies_to_layer(10));
1564 assert!(!spec.applies_to_layer(11));
1565 }
1566
1567 #[test]
1568 fn expand_edges_sentinels() {
1569 let edges = vec![AttentionEdge::new(2, usize::MAX), AttentionEdge::new(1, 0)];
1570
1571 let expanded = expand_edges(&edges, 4);
1572 assert_eq!(expanded.len(), 5); }
1574
1575 #[test]
1576 fn create_knockout_mask_correctness() {
1577 let spec = KnockoutSpec::new().head(0).edge(2, 1);
1578
1579 let mask = create_knockout_mask(&spec, 2, 4, &Device::Cpu, DType::F32).unwrap();
1580 assert_eq!(mask.dims(), &[1, 2, 4, 4]);
1581
1582 let mask_vec: Vec<f32> = mask.flatten_all().unwrap().to_vec1().unwrap();
1583
1584 assert!(mask_vec[9].is_infinite() && mask_vec[9].is_sign_negative());
1586
1587 assert_eq!(mask_vec[25], 0.0);
1589 }
1590
1591 #[test]
1592 fn validation_catches_errors() {
1593 let spec = KnockoutSpec::new().layer(100).edge(50, 25);
1594 assert!(spec.validate(30, 16, 20).is_err());
1595 }
1596
1597 #[test]
1598 fn validation_passes_valid() {
1599 let spec = KnockoutSpec::new().layer(10).edge(5, 3);
1600 assert!(spec.validate(30, 16, 20).is_ok());
1601 }
1602
1603 #[test]
1604 fn steering_spec_builder() {
1605 let spec = SteeringSpec::scale(2.0)
1606 .layer(5)
1607 .head(2)
1608 .edge(3, 1)
1609 .from_to_positions(4, &[0, 1, 2]);
1610
1611 assert!(matches!(spec.layers, LayerSpec::Specific(_)));
1612 assert!(matches!(spec.heads, HeadSpec::Specific(_)));
1613 assert_eq!(spec.edges.len(), 4);
1614 assert!(
1615 matches!(spec.intervention_type, InterventionType::Scale(f) if (f - 2.0).abs() < 1e-6)
1616 );
1617 }
1618
1619 #[test]
1620 fn steering_validation() {
1621 let spec = SteeringSpec::scale(2.0).layer(10).edge(5, 3);
1622 assert!(spec.validate(30, 16, 20).is_ok());
1623
1624 let spec = SteeringSpec::scale(-1.0).layer(10).edge(5, 3);
1625 assert!(spec.validate(30, 16, 20).is_err());
1626
1627 let spec = SteeringSpec::set_value(0.09).layer(10).edge(5, 3);
1628 assert!(spec.validate(30, 16, 20).is_ok());
1629
1630 let spec = SteeringSpec::set_value(1.5).layer(10).edge(5, 3);
1631 assert!(spec.validate(30, 16, 20).is_err());
1632 }
1633
1634 #[test]
1635 fn steering_is_methods() {
1636 let knockout = SteeringSpec::new(InterventionType::Knockout);
1637 assert!(knockout.is_knockout());
1638 assert!(!knockout.is_steering());
1639
1640 let scale = SteeringSpec::scale(2.0);
1641 assert!(!scale.is_knockout());
1642 assert!(scale.is_steering());
1643
1644 let set_value = SteeringSpec::set_value(0.1);
1645 assert!(!set_value.is_knockout());
1646 assert!(set_value.is_steering());
1647 }
1648
1649 #[test]
1650 fn apply_scale_steering_correctness() {
1651 let data: Vec<f32> = vec![
1652 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25,
1654 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25,
1656 0.25, 0.25,
1657 ];
1658 let tensor = Tensor::from_vec(data, (1, 2, 4, 4), &Device::Cpu).unwrap();
1659
1660 let spec = SteeringSpec::scale(2.0).edge(2, 1);
1661 let result = apply_scale_steering(&tensor, &spec, 2, 4, 2.0).unwrap();
1662 let result_data = tensor_to_vec4(&result).unwrap();
1663
1664 let row2 = &result_data[0][0][2];
1669 assert!((row2[0] - 0.20).abs() < 1e-5);
1670 assert!((row2[1] - 0.40).abs() < 1e-5);
1671 assert!((row2[2] - 0.20).abs() < 1e-5);
1672 assert!((row2[3] - 0.20).abs() < 1e-5);
1673
1674 let row_sum: f32 = row2.iter().sum();
1675 assert!((row_sum - 1.0).abs() < 1e-5);
1676 }
1677
1678 #[test]
1679 fn apply_set_value_steering_correctness() {
1680 let data: Vec<f32> = vec![
1681 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25,
1682 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25,
1683 0.25, 0.25, 0.25, 0.25,
1684 ];
1685 let tensor = Tensor::from_vec(data, (1, 2, 4, 4), &Device::Cpu).unwrap();
1686
1687 let spec = SteeringSpec::set_value(0.5).edge(2, 1);
1688 let result = apply_set_value_steering(&tensor, &spec, 2, 4, 0.5).unwrap();
1689 let result_data = tensor_to_vec4(&result).unwrap();
1690
1691 let row2 = &result_data[0][0][2];
1692 let row_sum: f32 = row2.iter().sum();
1693 assert!(
1694 (row_sum - 1.0).abs() < 1e-5,
1695 "row sum should be 1.0, got {row_sum}"
1696 );
1697
1698 assert!(row2[1] > row2[0]);
1700 assert!(row2[1] > row2[2]);
1701 assert!(row2[1] > row2[3]);
1702 }
1703
1704 #[test]
1705 fn knockout_to_steering_conversion() {
1706 let knockout = KnockoutSpec::new().layer(5).head(2).edge(3, 1);
1707 let steering: SteeringSpec = knockout.into();
1708
1709 assert!(matches!(steering.layers, LayerSpec::Specific(ref v) if v == &[5]));
1710 assert!(matches!(steering.heads, HeadSpec::Specific(ref v) if v == &[2]));
1711 assert_eq!(steering.edges.len(), 1);
1712 assert!(steering.is_knockout());
1713 }
1714
1715 #[test]
1716 fn is_prompt_only() {
1717 let spec = SteeringSpec::scale(2.0).edge(5, 2).edge(8, 3);
1718 assert!(spec.is_prompt_only(10));
1719 assert!(!spec.is_prompt_only(6));
1720 }
1721
1722 #[test]
1723 fn is_prompt_only_with_sentinel() {
1724 let spec = SteeringSpec::scale(2.0).to_position(5);
1725 assert!(!spec.is_prompt_only(10));
1726
1727 let spec2 = SteeringSpec::scale(2.0).from_position(5);
1728 assert!(spec2.is_prompt_only(10));
1729 }
1730
1731 #[test]
1732 fn max_positions() {
1733 let spec = SteeringSpec::scale(2.0).edge(5, 2).edge(8, 3).edge(3, 7);
1734 assert_eq!(spec.max_from_pos(), Some(8));
1735 assert_eq!(spec.max_to_pos(), Some(7));
1736 }
1737
1738 #[test]
1739 fn max_positions_empty() {
1740 let spec = SteeringSpec::scale(2.0);
1741 assert_eq!(spec.max_from_pos(), None);
1742 assert_eq!(spec.max_to_pos(), None);
1743 }
1744
1745 #[test]
1748 fn state_knockout_spec_builder() {
1749 let spec = StateKnockoutSpec::new().position(3).position(5).layer(10);
1750 assert_eq!(spec.positions, vec![3, 5]);
1751 assert!(matches!(spec.layers, LayerSpec::Specific(ref v) if v == &[10]));
1752 }
1753
1754 #[test]
1755 fn state_knockout_validation() {
1756 assert!(
1757 StateKnockoutSpec::new()
1758 .position(5)
1759 .layer(10)
1760 .validate(24, 20)
1761 .is_ok()
1762 );
1763 assert!(
1764 StateKnockoutSpec::new()
1765 .position(25)
1766 .validate(24, 20)
1767 .is_err()
1768 );
1769 assert!(
1770 StateKnockoutSpec::new()
1771 .position(5)
1772 .layer(30)
1773 .validate(24, 20)
1774 .is_err()
1775 );
1776 assert!(StateKnockoutSpec::new().validate(24, 20).is_err()); }
1778
1779 #[test]
1780 fn state_knockout_position_set() {
1781 let spec = StateKnockoutSpec::new().position(3).position(5).position(3);
1782 let set = spec.position_set();
1783 assert_eq!(set.len(), 2); assert!(set.contains(&3));
1785 assert!(set.contains(&5));
1786 }
1787
1788 #[test]
1789 fn state_knockout_layer_range() {
1790 let spec = StateKnockoutSpec::new().position(0).layer_range(5, 10);
1791 assert!(!spec.applies_to_layer(4));
1792 assert!(spec.applies_to_layer(5));
1793 assert!(spec.applies_to_layer(10));
1794 assert!(!spec.applies_to_layer(11));
1795 }
1796}