use candle_core::{DType, Tensor};
use crate::error::{MIError, Result};
/// Per-layer cache of attention patterns captured during a forward pass.
///
/// Each pushed tensor is one layer's attention pattern. The positional
/// accessors index it as `(batch, heads, seq_q, seq_k)` — dim 2 is labelled
/// `seq_q` and dim 3 `seq_k` in their error messages, and only batch 0 is
/// ever read.
#[derive(Debug)]
pub struct AttentionCache {
    // One attention tensor per layer, in push (layer) order.
    patterns: Vec<Tensor>,
}
impl AttentionCache {
    /// Creates an empty cache with capacity reserved for `n_layers` patterns.
    #[must_use]
    pub fn with_capacity(n_layers: usize) -> Self {
        Self {
            patterns: Vec::with_capacity(n_layers),
        }
    }

    /// Appends one layer's attention pattern.
    ///
    /// The positional accessors assume the tensor is laid out as
    /// `(batch, heads, seq_q, seq_k)`.
    pub fn push(&mut self, pattern: Tensor) {
        self.patterns.push(pattern);
    }

    /// Number of layers cached so far.
    #[must_use]
    pub const fn n_layers(&self) -> usize {
        self.patterns.len()
    }

    /// `true` when no patterns have been pushed yet.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.patterns.is_empty()
    }

    /// Borrows the pattern for `layer`, or `None` if that layer is not cached.
    #[must_use]
    pub fn get_layer(&self, layer: usize) -> Option<&Tensor> {
        self.patterns.get(layer)
    }

    /// Attention weights *from* query `position` to every key position,
    /// averaged over heads (batch 0 only). The result has length `seq_k`.
    ///
    /// # Errors
    /// Returns `MIError::Hook` if `layer` is not cached or
    /// `position >= seq_q`; tensor-op failures are propagated.
    pub fn attention_from_position(&self, layer: usize, position: usize) -> Result<Vec<f32>> {
        self.head_averaged_slice(layer, 2, position, "seq_q")
    }

    /// Attention weights flowing *to* key `position` from every query
    /// position, averaged over heads (batch 0 only). The result has
    /// length `seq_q`.
    ///
    /// # Errors
    /// Returns `MIError::Hook` if `layer` is not cached or
    /// `position >= seq_k`; tensor-op failures are propagated.
    pub fn attention_to_position(&self, layer: usize, position: usize) -> Result<Vec<f32>> {
        self.head_averaged_slice(layer, 3, position, "seq_k")
    }

    /// Shared implementation for the positional accessors: fixes `position`
    /// along `seq_dim` (2 = query axis, 3 = key axis) of the `(batch, heads,
    /// seq_q, seq_k)` pattern and averages over heads. `axis_name` is only
    /// used in the out-of-range error message.
    fn head_averaged_slice(
        &self,
        layer: usize,
        seq_dim: usize,
        position: usize,
        axis_name: &str,
    ) -> Result<Vec<f32>> {
        let pattern = self
            .patterns
            .get(layer)
            .ok_or_else(|| MIError::Hook(format!("layer {layer} not in attention cache")))?;
        let seq_len = pattern.dim(seq_dim)?;
        if position >= seq_len {
            return Err(MIError::Hook(format!(
                "position {position} out of range ({axis_name}={seq_len})"
            )));
        }
        // Work in f32 so the final `to_vec1::<f32>` has a stable element type
        // regardless of the cached pattern's dtype.
        let attn_f32 = pattern.to_dtype(DType::F32)?;
        // Only batch 0 is inspected.
        let batch0 = attn_f32.narrow(0, 0, 1)?;
        // Keep a size-1 window at `position` along the requested axis.
        let slice = batch0.narrow(seq_dim, position, 1)?;
        // Drop the batch dim first; the size-1 positional dim then sits one
        // index lower than `seq_dim` (squeeze(1) for dim 2, squeeze(2) for dim 3).
        let slice = slice.squeeze(0)?.squeeze(seq_dim - 1)?;
        // Average across heads (dim 0 after the squeezes).
        let avg = slice.mean(0)?;
        Ok(avg.to_vec1()?)
    }

    /// Top-`k` key positions receiving the most attention from
    /// `from_position`, sorted by descending weight. NaN weights compare as
    /// equal, and the stable sort keeps tied positions in index order.
    /// Returns fewer than `k` entries when the sequence is shorter than `k`.
    ///
    /// # Errors
    /// Propagates errors from `attention_from_position`.
    pub fn top_attended_positions(
        &self,
        layer: usize,
        from_position: usize,
        k: usize,
    ) -> Result<Vec<(usize, f32)>> {
        let attn = self.attention_from_position(layer, from_position)?;
        let mut indexed: Vec<(usize, f32)> = attn.into_iter().enumerate().collect();
        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        indexed.truncate(k);
        Ok(indexed)
    }

    /// All cached patterns, in layer order.
    #[must_use]
    pub fn patterns(&self) -> &[Tensor] {
        &self.patterns
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::float_cmp)]
mod tests {
    use super::*;
    use candle_core::Device;

    const EPS: f32 = 1e-5;

    /// One-layer fixture of shape (1, 2, 4, 4): head 0 attends uniformly
    /// (all 0.25), head 1 attends mostly to the diagonal (0.70 on it,
    /// 0.10 everywhere else).
    fn sample_cache() -> AttentionCache {
        let mut data = vec![0.25_f32; 16];
        for q in 0..4 {
            for k in 0..4 {
                data.push(if q == k { 0.70 } else { 0.10 });
            }
        }
        let tensor = Tensor::from_vec(data, (1, 2, 4, 4), &Device::Cpu).unwrap();
        let mut cache = AttentionCache::with_capacity(1);
        cache.push(tensor);
        cache
    }

    #[test]
    fn empty_cache() {
        let cache = AttentionCache::with_capacity(32);
        assert!(cache.is_empty());
        assert_eq!(cache.n_layers(), 0);
        assert!(cache.get_layer(0).is_none());
    }

    #[test]
    fn push_and_get_layer() {
        let cache = sample_cache();
        assert!(!cache.is_empty());
        assert_eq!(cache.n_layers(), 1);
        let layer0 = cache.get_layer(0).expect("layer 0 must exist");
        assert_eq!(layer0.dims(), &[1, 2, 4, 4]);
        assert!(cache.get_layer(1).is_none());
    }

    #[test]
    fn attention_from_position_values() {
        let attn = sample_cache().attention_from_position(0, 0).unwrap();
        // Head average of query row 0: (0.25+0.70)/2 then (0.25+0.10)/2.
        let expected = [0.475_f32, 0.175, 0.175, 0.175];
        assert_eq!(attn.len(), expected.len());
        for (got, want) in attn.iter().zip(expected) {
            assert!((got - want).abs() < EPS);
        }
    }

    #[test]
    fn attention_from_position_out_of_range() {
        let cache = sample_cache();
        // Position past seq_q and layer past the cache both fail.
        assert!(cache.attention_from_position(0, 10).is_err());
        assert!(cache.attention_from_position(5, 0).is_err());
    }

    #[test]
    fn attention_to_position_values() {
        let attn = sample_cache().attention_to_position(0, 0).unwrap();
        assert_eq!(attn.len(), 4);
        assert!((attn[0] - 0.475).abs() < EPS);
        assert!((attn[1] - 0.175).abs() < EPS);
    }

    #[test]
    fn attention_to_position_out_of_range() {
        let cache = sample_cache();
        assert!(cache.attention_to_position(0, 10).is_err());
        assert!(cache.attention_to_position(5, 0).is_err());
    }

    #[test]
    fn top_attended_positions_sorted() {
        let top = sample_cache().top_attended_positions(0, 0, 2).unwrap();
        assert_eq!(top.len(), 2);
        // The self-position dominates: averaged weight 0.475.
        let (idx, weight) = top[0];
        assert_eq!(idx, 0);
        assert!((weight - 0.475).abs() < EPS);
    }

    #[test]
    fn top_attended_positions_k_larger_than_seq() {
        // Oversized k is clamped to the sequence length.
        let top = sample_cache().top_attended_positions(0, 0, 100).unwrap();
        assert_eq!(top.len(), 4);
    }

    #[test]
    fn patterns_slice() {
        assert_eq!(sample_cache().patterns().len(), 1);
    }
}