use candle_core::{DType, Tensor};
use crate::error::{MIError, Result};
/// Ordered collection of activation tensors, one entry per hooked layer
/// (index = layer number, as used by `get_layer`).
///
/// NOTE(review): tests store 1-D `(d_model,)` tensors here, but nothing in
/// this type enforces a shape — confirm against the hook that fills it.
#[derive(Debug)]
pub struct ActivationCache {
// Layer-ordered tensors; `Vec` index is the layer index.
activations: Vec<Tensor>,
}
impl ActivationCache {
    /// Builds a cache from an already-collected list of per-layer tensors.
    ///
    /// # Errors
    ///
    /// Currently infallible — the `Result` return type is part of the
    /// public API, so callers still use `?`/`unwrap`.
    pub const fn new(activations: Vec<Tensor>) -> Result<Self> {
        Ok(Self { activations })
    }

    /// Creates an empty cache with room reserved for `n_layers` entries.
    #[must_use]
    pub fn with_capacity(n_layers: usize) -> Self {
        let activations = Vec::with_capacity(n_layers);
        Self { activations }
    }

    /// Appends the activation tensor for the next layer.
    pub fn push(&mut self, tensor: Tensor) {
        self.activations.push(tensor);
    }

    /// Returns the cached tensor for `layer`, or `None` when out of range.
    #[must_use]
    pub fn get_layer(&self, layer: usize) -> Option<&Tensor> {
        self.activations.get(layer)
    }

    /// Number of layers currently stored.
    #[must_use]
    pub const fn n_layers(&self) -> usize {
        self.activations.len()
    }

    /// `true` when no activations have been cached yet.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.activations.is_empty()
    }

    /// Borrows the underlying tensors in layer order.
    #[must_use]
    pub fn activations(&self) -> &[Tensor] {
        &self.activations
    }

    /// Flattens each layer's tensor and copies it to host memory as `f32`.
    ///
    /// # Errors
    ///
    /// Propagates any candle error raised while flattening, converting the
    /// dtype, or extracting the data; stops at the first failing layer.
    pub fn to_f32_vecs(&self) -> Result<Vec<Vec<f32>>> {
        let mut vecs = Vec::with_capacity(self.activations.len());
        for tensor in &self.activations {
            let flat = tensor.flatten_all()?;
            let data: Vec<f32> = flat.to_dtype(DType::F32)?.to_vec1()?;
            vecs.push(data);
        }
        Ok(vecs)
    }
}
/// Per-layer activation tensors covering a whole sequence: dim 0 of each
/// tensor is treated as the sequence dimension (see `get_position` and
/// `seq_len`, which both read `dim(0)`).
#[derive(Debug)]
pub struct FullActivationCache {
// Layer-ordered tensors; `Vec` index is the layer index.
activations: Vec<Tensor>,
}
impl FullActivationCache {
    /// Creates an empty cache with room reserved for `n_layers` entries.
    #[must_use]
    pub fn with_capacity(n_layers: usize) -> Self {
        let activations = Vec::with_capacity(n_layers);
        Self { activations }
    }

    /// Appends the activation tensor for the next layer.
    /// Dim 0 of `tensor` is interpreted as the sequence dimension.
    pub fn push(&mut self, tensor: Tensor) {
        self.activations.push(tensor);
    }

    /// Returns the cached tensor for `layer`, or `None` when out of range.
    #[must_use]
    pub fn get_layer(&self, layer: usize) -> Option<&Tensor> {
        self.activations.get(layer)
    }

    /// Extracts the activation vector at one sequence `position` of `layer`,
    /// with the sequence dimension squeezed away.
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Hook`] when `layer` is not cached or `position`
    /// is past the end of the sequence; propagates candle errors from the
    /// `dim`/`narrow`/`squeeze` calls.
    pub fn get_position(&self, layer: usize, position: usize) -> Result<Tensor> {
        let Some(layer_tensor) = self.activations.get(layer) else {
            return Err(MIError::Hook(format!("layer {layer} not in cache")));
        };
        let seq_len = layer_tensor.dim(0)?;
        if position >= seq_len {
            return Err(MIError::Hook(format!(
                "position {position} out of range (seq_len={seq_len})"
            )));
        }
        // Take the single row at `position` along dim 0, then drop that dim.
        let row = layer_tensor.narrow(0, position, 1)?;
        Ok(row.squeeze(0)?)
    }

    /// Number of layers currently stored.
    #[must_use]
    pub const fn n_layers(&self) -> usize {
        self.activations.len()
    }

    /// Sequence length (dim 0) of the first cached layer.
    /// NOTE(review): assumes every layer shares this seq_len — not enforced.
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Hook`] when the cache is empty; propagates candle
    /// errors from `dim`.
    pub fn seq_len(&self) -> Result<usize> {
        match self.activations.first() {
            Some(first) => Ok(first.dim(0)?),
            None => Err(MIError::Hook("cache is empty".into())),
        }
    }

    /// `true` when no activations have been cached yet.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.activations.is_empty()
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use candle_core::Device;

    /// Construction via `new` plus indexed access, including out-of-range.
    #[test]
    fn cache_basic() {
        let device = Device::Cpu;
        let t1 = Tensor::zeros((2048,), DType::F32, &device).unwrap();
        let t2 = Tensor::zeros((2048,), DType::F32, &device).unwrap();
        let cache = ActivationCache::new(vec![t1, t2]).unwrap();
        assert_eq!(cache.n_layers(), 2);
        assert!(cache.get_layer(0).is_some());
        assert!(cache.get_layer(1).is_some());
        assert!(cache.get_layer(2).is_none());
    }

    /// Incremental construction via `with_capacity` + `push`.
    #[test]
    fn cache_push() {
        let device = Device::Cpu;
        let mut cache = ActivationCache::with_capacity(2);
        assert!(cache.is_empty());
        let t = Tensor::zeros((2048,), DType::F32, &device).unwrap();
        cache.push(t);
        assert_eq!(cache.n_layers(), 1);
        assert!(!cache.is_empty());
    }

    /// Previously untested: `to_f32_vecs` host extraction and the
    /// `activations()` slice accessor.
    #[test]
    fn cache_to_f32_vecs() {
        let device = Device::Cpu;
        let t = Tensor::ones((4,), DType::F32, &device).unwrap();
        let cache = ActivationCache::new(vec![t]).unwrap();
        assert_eq!(cache.activations().len(), 1);
        let vecs = cache.to_f32_vecs().unwrap();
        assert_eq!(vecs.len(), 1);
        assert_eq!(vecs[0], vec![1.0_f32; 4]);
    }

    /// Full-sequence cache: shapes, positional extraction, and both
    /// error paths of `get_position`.
    #[test]
    fn full_cache_basic() {
        let device = Device::Cpu;
        let seq_len = 10;
        let d_model = 2304;
        let mut cache = FullActivationCache::with_capacity(2);
        assert!(cache.is_empty());
        let t1 = Tensor::zeros((seq_len, d_model), DType::F32, &device).unwrap();
        let t2 = Tensor::zeros((seq_len, d_model), DType::F32, &device).unwrap();
        cache.push(t1);
        cache.push(t2);
        assert_eq!(cache.n_layers(), 2);
        assert_eq!(cache.seq_len().unwrap(), seq_len);
        assert!(!cache.is_empty());
        let layer0 = cache.get_layer(0).unwrap();
        assert_eq!(layer0.dims(), &[seq_len, d_model]);
        let pos = cache.get_position(0, 5).unwrap();
        assert_eq!(pos.dims(), &[d_model]);
        // Position past the sequence end and layer past the cache end
        // must both surface as errors, not panics.
        assert!(cache.get_position(0, seq_len).is_err());
        assert!(cache.get_position(5, 0).is_err());
    }
}