// candle_mi/cache/activation.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Activation cache for storing intermediate transformer states.
4
5use candle_core::{DType, Tensor};
6
7use crate::error::{MIError, Result};
8
/// Stores per-layer last-token activations from a forward pass.
///
/// Each tensor has shape `[d_model]` — the residual stream activation
/// at the final sequence position for a given layer.
///
/// Layers are indexed by insertion order: index `i` holds the activation
/// pushed for layer `i`. Populated incrementally via `push` during a
/// forward pass, or all at once via `new`.
///
/// # Example
///
/// ```
/// use candle_mi::ActivationCache;
/// use candle_core::{Device, Tensor};
///
/// let mut cache = ActivationCache::with_capacity(32);
/// cache.push(Tensor::zeros(128, candle_core::DType::F32, &Device::Cpu).unwrap());
/// cache.push(Tensor::zeros(128, candle_core::DType::F32, &Device::Cpu).unwrap());
/// assert_eq!(cache.n_layers(), 2);
/// ```
#[derive(Debug)]
pub struct ActivationCache {
    /// Residual stream activations per layer, each shape `[d_model]`.
    /// Position in the vector corresponds to the layer index.
    activations: Vec<Tensor>,
}
30
31impl ActivationCache {
32 /// Create a new cache from collected activations.
33 ///
34 /// # Errors
35 ///
36 /// Currently infallible but returns `Result` for forward compatibility.
37 pub const fn new(activations: Vec<Tensor>) -> Result<Self> {
38 Ok(Self { activations })
39 }
40
41 /// Create an empty cache with capacity for `n_layers` layers.
42 #[must_use]
43 pub fn with_capacity(n_layers: usize) -> Self {
44 Self {
45 activations: Vec::with_capacity(n_layers),
46 }
47 }
48
49 /// Add a layer's activation to the cache.
50 pub fn push(&mut self, tensor: Tensor) {
51 self.activations.push(tensor);
52 }
53
54 /// Get the activation for a specific layer.
55 #[must_use]
56 pub fn get_layer(&self, layer: usize) -> Option<&Tensor> {
57 self.activations.get(layer)
58 }
59
60 /// Number of cached layers.
61 #[must_use]
62 pub const fn n_layers(&self) -> usize {
63 self.activations.len()
64 }
65
66 /// Whether the cache is empty.
67 #[must_use]
68 pub const fn is_empty(&self) -> bool {
69 self.activations.is_empty()
70 }
71
72 /// All cached activations as a slice.
73 #[must_use]
74 pub fn activations(&self) -> &[Tensor] {
75 &self.activations
76 }
77
78 /// Extract activations as `f32` vectors.
79 ///
80 /// Returns one `Vec<f32>` of shape `[d_model]` per layer.
81 ///
82 /// # Errors
83 ///
84 /// Returns [`MIError::Model`] if dtype conversion or flattening fails.
85 pub fn to_f32_vecs(&self) -> Result<Vec<Vec<f32>>> {
86 self.activations
87 .iter()
88 .map(|t| {
89 let flat = t.flatten_all()?;
90 let data: Vec<f32> = flat.to_dtype(DType::F32)?.to_vec1()?;
91 Ok(data)
92 })
93 .collect()
94 }
95}
96
/// Stores all-position activations from a forward pass.
///
/// Unlike [`ActivationCache`] which stores only the last-token activation
/// per layer, this cache stores the full residual stream at every token
/// position. Each tensor has shape `[seq_len, d_model]`.
///
/// All layers are assumed to share the same `seq_len`; `seq_len()` reads
/// it from the first cached layer.
///
/// # Example
///
/// ```
/// use candle_mi::FullActivationCache;
/// use candle_core::{Device, Tensor};
///
/// let mut cache = FullActivationCache::with_capacity(32);
/// // shape [seq_len=10, d_model=128]
/// cache.push(Tensor::zeros((10, 128), candle_core::DType::F32, &Device::Cpu).unwrap());
///
/// // Get a single position's activation for CLT encoding
/// let act = cache.get_position(0, 5).unwrap(); // shape [d_model]
/// ```
#[derive(Debug)]
pub struct FullActivationCache {
    /// Residual stream activations per layer, each shape `[seq_len, d_model]`.
    /// Position in the vector corresponds to the layer index.
    activations: Vec<Tensor>,
}
121
122impl FullActivationCache {
123 /// Create an empty cache with capacity for `n_layers` layers.
124 #[must_use]
125 pub fn with_capacity(n_layers: usize) -> Self {
126 Self {
127 activations: Vec::with_capacity(n_layers),
128 }
129 }
130
131 /// Add a layer's all-position activation to the cache.
132 ///
133 /// The tensor should have shape `[seq_len, d_model]`.
134 pub fn push(&mut self, tensor: Tensor) {
135 self.activations.push(tensor);
136 }
137
138 /// Get the full activation tensor for a specific layer.
139 ///
140 /// Returns shape `[seq_len, d_model]`, or `None` if the layer
141 /// is not in the cache.
142 #[must_use]
143 pub fn get_layer(&self, layer: usize) -> Option<&Tensor> {
144 self.activations.get(layer)
145 }
146
147 /// Get the activation at a specific layer and token position.
148 ///
149 /// Returns shape `[d_model]` — compatible with CLT `encode()`.
150 ///
151 /// # Errors
152 ///
153 /// Returns [`MIError::Hook`] if the layer is not in the cache or
154 /// the position is out of range.
155 pub fn get_position(&self, layer: usize, position: usize) -> Result<Tensor> {
156 let layer_tensor = self
157 .activations
158 .get(layer)
159 .ok_or_else(|| MIError::Hook(format!("layer {layer} not in cache")))?;
160 let seq_len = layer_tensor.dim(0)?;
161 if position >= seq_len {
162 return Err(MIError::Hook(format!(
163 "position {position} out of range (seq_len={seq_len})"
164 )));
165 }
166 Ok(layer_tensor.narrow(0, position, 1)?.squeeze(0)?)
167 }
168
169 /// Number of cached layers.
170 #[must_use]
171 pub const fn n_layers(&self) -> usize {
172 self.activations.len()
173 }
174
175 /// Sequence length (from the first layer's tensor).
176 ///
177 /// # Errors
178 ///
179 /// Returns [`MIError::Hook`] if the cache is empty.
180 pub fn seq_len(&self) -> Result<usize> {
181 let first = self
182 .activations
183 .first()
184 .ok_or_else(|| MIError::Hook("cache is empty".into()))?;
185 Ok(first.dim(0)?)
186 }
187
188 /// Whether the cache is empty.
189 #[must_use]
190 pub const fn is_empty(&self) -> bool {
191 self.activations.is_empty()
192 }
193}
194
195// ---------------------------------------------------------------------------
196// Tests
197// ---------------------------------------------------------------------------
198
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use candle_core::Device;

    #[test]
    fn cache_basic() {
        let dev = Device::Cpu;
        let layer0 = Tensor::zeros((2048,), DType::F32, &dev).unwrap();
        let layer1 = Tensor::zeros((2048,), DType::F32, &dev).unwrap();

        let cache = ActivationCache::new(vec![layer0, layer1]).unwrap();

        assert_eq!(cache.n_layers(), 2);
        // In-range lookups succeed; one past the end does not.
        assert!(cache.get_layer(0).is_some());
        assert!(cache.get_layer(1).is_some());
        assert!(cache.get_layer(2).is_none());
    }

    #[test]
    fn cache_push() {
        let dev = Device::Cpu;
        let mut cache = ActivationCache::with_capacity(2);
        assert!(cache.is_empty());

        cache.push(Tensor::zeros((2048,), DType::F32, &dev).unwrap());

        assert!(!cache.is_empty());
        assert_eq!(cache.n_layers(), 1);
    }

    #[test]
    fn full_cache_basic() {
        let dev = Device::Cpu;
        let (seq_len, d_model) = (10, 2304);

        let mut cache = FullActivationCache::with_capacity(2);
        assert!(cache.is_empty());

        for _ in 0..2 {
            cache.push(Tensor::zeros((seq_len, d_model), DType::F32, &dev).unwrap());
        }

        assert_eq!(cache.n_layers(), 2);
        assert_eq!(cache.seq_len().unwrap(), seq_len);
        assert!(!cache.is_empty());

        // get_layer returns the full 2D [seq_len, d_model] tensor.
        let first_layer = cache.get_layer(0).unwrap();
        assert_eq!(first_layer.dims(), &[seq_len, d_model]);

        // get_position slices out a single 1D [d_model] row.
        let row = cache.get_position(0, 5).unwrap();
        assert_eq!(row.dims(), &[d_model]);

        // Out-of-range position and missing layer both error.
        assert!(cache.get_position(0, seq_len).is_err());
        assert!(cache.get_position(5, 0).is_err());
    }
}
263}