// candle_mi/cache/activation.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Activation cache for storing intermediate transformer states.
4
5use candle_core::{DType, Tensor};
6
7use crate::error::{MIError, Result};
8
/// Stores per-layer last-token activations from a forward pass.
///
/// Each tensor has shape `[d_model]` — the residual stream activation
/// at the final sequence position for a given layer. The layer index is
/// implicit: the tensor at vector position `i` belongs to layer `i`.
///
/// # Example
///
/// ```
/// use candle_mi::ActivationCache;
/// use candle_core::{Device, Tensor};
///
/// let mut cache = ActivationCache::with_capacity(32);
/// cache.push(Tensor::zeros(128, candle_core::DType::F32, &Device::Cpu).unwrap());
/// cache.push(Tensor::zeros(128, candle_core::DType::F32, &Device::Cpu).unwrap());
/// assert_eq!(cache.n_layers(), 2);
/// ```
#[derive(Debug)]
pub struct ActivationCache {
    /// Residual stream activations per layer, each shape `[d_model]`.
    /// Vector index corresponds to layer index; pushed in layer order.
    activations: Vec<Tensor>,
}
30
31impl ActivationCache {
32    /// Create a new cache from collected activations.
33    ///
34    /// # Errors
35    ///
36    /// Currently infallible but returns `Result` for forward compatibility.
37    pub const fn new(activations: Vec<Tensor>) -> Result<Self> {
38        Ok(Self { activations })
39    }
40
41    /// Create an empty cache with capacity for `n_layers` layers.
42    #[must_use]
43    pub fn with_capacity(n_layers: usize) -> Self {
44        Self {
45            activations: Vec::with_capacity(n_layers),
46        }
47    }
48
49    /// Add a layer's activation to the cache.
50    pub fn push(&mut self, tensor: Tensor) {
51        self.activations.push(tensor);
52    }
53
54    /// Get the activation for a specific layer.
55    #[must_use]
56    pub fn get_layer(&self, layer: usize) -> Option<&Tensor> {
57        self.activations.get(layer)
58    }
59
60    /// Number of cached layers.
61    #[must_use]
62    pub const fn n_layers(&self) -> usize {
63        self.activations.len()
64    }
65
66    /// Whether the cache is empty.
67    #[must_use]
68    pub const fn is_empty(&self) -> bool {
69        self.activations.is_empty()
70    }
71
72    /// All cached activations as a slice.
73    #[must_use]
74    pub fn activations(&self) -> &[Tensor] {
75        &self.activations
76    }
77
78    /// Extract activations as `f32` vectors.
79    ///
80    /// Returns one `Vec<f32>` of shape `[d_model]` per layer.
81    ///
82    /// # Errors
83    ///
84    /// Returns [`MIError::Model`] if dtype conversion or flattening fails.
85    pub fn to_f32_vecs(&self) -> Result<Vec<Vec<f32>>> {
86        self.activations
87            .iter()
88            .map(|t| {
89                let flat = t.flatten_all()?;
90                let data: Vec<f32> = flat.to_dtype(DType::F32)?.to_vec1()?;
91                Ok(data)
92            })
93            .collect()
94    }
95}
96
/// Stores all-position activations from a forward pass.
///
/// Unlike [`ActivationCache`] which stores only the last-token activation
/// per layer, this cache stores the full residual stream at every token
/// position. Each tensor has shape `[seq_len, d_model]`, and the vector
/// index corresponds to the layer index.
///
/// # Example
///
/// ```
/// use candle_mi::FullActivationCache;
/// use candle_core::{Device, Tensor};
///
/// let mut cache = FullActivationCache::with_capacity(32);
/// // shape [seq_len=10, d_model=128]
/// cache.push(Tensor::zeros((10, 128), candle_core::DType::F32, &Device::Cpu).unwrap());
///
/// // Get a single position's activation for CLT encoding
/// let act = cache.get_position(0, 5).unwrap(); // shape [d_model]
/// ```
#[derive(Debug)]
pub struct FullActivationCache {
    /// Residual stream activations per layer, each shape `[seq_len, d_model]`.
    /// Vector index corresponds to layer index; pushed in layer order.
    activations: Vec<Tensor>,
}
121
122impl FullActivationCache {
123    /// Create an empty cache with capacity for `n_layers` layers.
124    #[must_use]
125    pub fn with_capacity(n_layers: usize) -> Self {
126        Self {
127            activations: Vec::with_capacity(n_layers),
128        }
129    }
130
131    /// Add a layer's all-position activation to the cache.
132    ///
133    /// The tensor should have shape `[seq_len, d_model]`.
134    pub fn push(&mut self, tensor: Tensor) {
135        self.activations.push(tensor);
136    }
137
138    /// Get the full activation tensor for a specific layer.
139    ///
140    /// Returns shape `[seq_len, d_model]`, or `None` if the layer
141    /// is not in the cache.
142    #[must_use]
143    pub fn get_layer(&self, layer: usize) -> Option<&Tensor> {
144        self.activations.get(layer)
145    }
146
147    /// Get the activation at a specific layer and token position.
148    ///
149    /// Returns shape `[d_model]` — compatible with CLT `encode()`.
150    ///
151    /// # Errors
152    ///
153    /// Returns [`MIError::Hook`] if the layer is not in the cache or
154    /// the position is out of range.
155    pub fn get_position(&self, layer: usize, position: usize) -> Result<Tensor> {
156        let layer_tensor = self
157            .activations
158            .get(layer)
159            .ok_or_else(|| MIError::Hook(format!("layer {layer} not in cache")))?;
160        let seq_len = layer_tensor.dim(0)?;
161        if position >= seq_len {
162            return Err(MIError::Hook(format!(
163                "position {position} out of range (seq_len={seq_len})"
164            )));
165        }
166        Ok(layer_tensor.narrow(0, position, 1)?.squeeze(0)?)
167    }
168
169    /// Number of cached layers.
170    #[must_use]
171    pub const fn n_layers(&self) -> usize {
172        self.activations.len()
173    }
174
175    /// Sequence length (from the first layer's tensor).
176    ///
177    /// # Errors
178    ///
179    /// Returns [`MIError::Hook`] if the cache is empty.
180    pub fn seq_len(&self) -> Result<usize> {
181        let first = self
182            .activations
183            .first()
184            .ok_or_else(|| MIError::Hook("cache is empty".into()))?;
185        Ok(first.dim(0)?)
186    }
187
188    /// Whether the cache is empty.
189    #[must_use]
190    pub const fn is_empty(&self) -> bool {
191        self.activations.is_empty()
192    }
193}
194
195// ---------------------------------------------------------------------------
196// Tests
197// ---------------------------------------------------------------------------
198
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use candle_core::Device;

    #[test]
    fn cache_basic() {
        let dev = Device::Cpu;
        let layers = vec![
            Tensor::zeros((2048,), DType::F32, &dev).unwrap(),
            Tensor::zeros((2048,), DType::F32, &dev).unwrap(),
        ];

        let cache = ActivationCache::new(layers).unwrap();

        assert_eq!(cache.n_layers(), 2);
        // In-range layers resolve; one past the end does not.
        assert!(cache.get_layer(0).is_some());
        assert!(cache.get_layer(1).is_some());
        assert!(cache.get_layer(2).is_none());
    }

    #[test]
    fn cache_push() {
        let dev = Device::Cpu;
        let mut cache = ActivationCache::with_capacity(2);
        assert!(cache.is_empty());

        cache.push(Tensor::zeros((2048,), DType::F32, &dev).unwrap());

        assert!(!cache.is_empty());
        assert_eq!(cache.n_layers(), 1);
    }

    #[test]
    fn full_cache_basic() {
        let dev = Device::Cpu;
        let (seq_len, d_model) = (10, 2304);

        let mut cache = FullActivationCache::with_capacity(2);
        assert!(cache.is_empty());

        for _ in 0..2 {
            cache.push(Tensor::zeros((seq_len, d_model), DType::F32, &dev).unwrap());
        }

        assert!(!cache.is_empty());
        assert_eq!(cache.n_layers(), 2);
        assert_eq!(cache.seq_len().unwrap(), seq_len);

        // get_layer returns the full 2D [seq_len, d_model] tensor.
        assert_eq!(cache.get_layer(0).unwrap().dims(), &[seq_len, d_model]);

        // get_position slices down to a 1D [d_model] tensor.
        assert_eq!(cache.get_position(0, 5).unwrap().dims(), &[d_model]);

        // Out-of-range position and missing layer both error.
        assert!(cache.get_position(0, seq_len).is_err());
        assert!(cache.get_position(5, 0).is_err());
    }
}
263}