// llama_cpp_4/context/tensor_capture.rs
1//! Capture intermediate tensor outputs during decode via the `cb_eval` callback.
2//!
3//! During `llama_decode`, llama.cpp evaluates a computation graph where each
4//! tensor node has a name (e.g. `"l_out-13"` for layer 13's output,
5//! `"attn_norm-5"` for layer 5's attention norm, `"result_norm"` for the
6//! final norm output).
7//!
8//! The `cb_eval` callback is invoked for every tensor node:
9//! - **Ask phase** (`ask = true`):  return `true` to request this tensor's data.
10//! - **Data phase** (`ask = false`): the tensor data is computed and available
11//!   to copy out via `ggml_backend_tensor_get()`.
12//!
13//! [`TensorCapture`] provides a safe, reusable wrapper around this mechanism.
14//!
15//! # Example
16//!
17//! ```rust,ignore
//! use std::num::NonZeroU32;
//!
//! use llama_cpp_4::context::params::LlamaContextParams;
//! use llama_cpp_4::context::tensor_capture::TensorCapture;
20//!
21//! // Capture layers 13, 20, 27
22//! let mut capture = TensorCapture::for_layers(&[13, 20, 27]);
23//!
24//! let ctx_params = LlamaContextParams::default()
25//!     .with_n_ctx(Some(NonZeroU32::new(2048).unwrap()))
26//!     .with_embeddings(true)
27//!     .with_tensor_capture(&mut capture);
28//!
29//! let mut ctx = model.new_context(&backend, ctx_params)?;
30//! // ... add tokens to batch ...
31//! ctx.decode(&mut batch)?;
32//!
33//! // Read captured hidden states
34//! for &layer in &[13, 20, 27] {
//!     if let Some(info) = capture.get_layer(layer) {
//!         println!("Layer {}: shape [{}, {}]", layer, info.n_embd(), info.n_tokens());
37//!         // info.data contains [n_tokens * n_embd] f32 values
38//!         // Layout: data[token_idx * n_embd + dim_idx]
39//!     }
40//! }
41//! ```
42
43use std::collections::HashMap;
44
/// Information about a single captured tensor.
#[derive(Debug, Clone)]
pub struct CapturedTensor {
    /// The tensor name (e.g. `"l_out-13"`).
    pub name: String,
    /// The layer index extracted from the name, or `None` if the name
    /// doesn't follow the `"prefix-N"` pattern.
    pub layer: Option<usize>,
    /// First dimension (typically `n_embd` / hidden dimension).
    pub ne0: usize,
    /// Second dimension (typically `n_tokens`).
    pub ne1: usize,
    /// Flattened f32 data with `ne0 * ne1` elements.
    ///
    /// Layout (row-major from ggml's perspective):
    /// `data[token_idx * ne0 + dim_idx]`
    ///
    /// This matches the ggml tensor layout where `ne[0]` is the
    /// innermost (contiguous) dimension.
    pub data: Vec<f32>,
}

impl CapturedTensor {
    /// Number of embedding dimensions (alias for `ne0`).
    #[inline]
    #[must_use]
    pub fn n_embd(&self) -> usize {
        self.ne0
    }

    /// Number of tokens (alias for `ne1`).
    #[inline]
    #[must_use]
    pub fn n_tokens(&self) -> usize {
        self.ne1
    }

    /// Get the hidden state for a specific token.
    ///
    /// Returns the `n_embd` floats of row `token_idx`, or `None` if the row
    /// is not available. This happens when `token_idx >= n_tokens()`, or when
    /// `data` is too short to hold that row. All fields are public, so `data`
    /// is not guaranteed to have exactly `ne0 * ne1` elements.
    #[must_use]
    pub fn token_embedding(&self, token_idx: usize) -> Option<&[f32]> {
        if token_idx >= self.ne1 {
            return None;
        }
        // Checked arithmetic and `get` turn inconsistent shapes into `None`
        // instead of an overflow or an out-of-bounds slice panic.
        let start = token_idx.checked_mul(self.ne0)?;
        let end = start.checked_add(self.ne0)?;
        self.data.get(start..end)
    }
}
95
/// Strategy for selecting which tensors to capture.
///
/// `TensorCapture::matches` checks each tensor name reported to the `cb_eval`
/// callback against the active variant.
#[derive(Debug, Clone)]
enum CaptureFilter {
    /// Capture tensors named `"l_out-{N}"` for specific layer indices.
    Layers(Vec<usize>),
    /// Capture tensors whose names exactly match the given strings.
    Names(Vec<String>),
    /// Capture tensors whose names start with the given prefix.
    Prefix(String),
    /// Capture all tensors (warning: can be very large).
    All,
}
108
/// Captures intermediate tensor outputs during `llama_decode`.
///
/// Create a `TensorCapture`, attach it to `LlamaContextParams` via
/// [`with_tensor_capture`](super::params::LlamaContextParams::with_tensor_capture),
/// then call `decode()`. After decode completes, read captured data
/// via [`get`](Self::get), [`get_layer`](Self::get_layer), or [`iter`](Self::iter).
///
/// # Lifetime & Safety
///
/// The `TensorCapture` must outlive the `LlamaContext` it is attached to.
/// [`with_tensor_capture`](super::params::LlamaContextParams::with_tensor_capture)
/// receives the capture as a `&mut` reference.
// NOTE(review): the module-level example reads `capture` while `ctx` is still
// alive, which would not compile if that `&mut` borrow were tied to the
// context's lifetime. The borrow checker therefore presumably does NOT enforce
// the rule above. Confirm against `with_tensor_capture`.
pub struct TensorCapture {
    /// Which tensor names to capture.
    filter: CaptureFilter,
    /// Captured tensors keyed by name.
    captured: HashMap<String, CapturedTensor>,
}
126
127impl std::fmt::Debug for TensorCapture {
128    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129        f.debug_struct("TensorCapture")
130            .field("filter", &self.filter)
131            .field("captured_count", &self.captured.len())
132            .field("captured_keys", &self.captured.keys().collect::<Vec<_>>())
133            .finish()
134    }
135}
136
137impl TensorCapture {
138    /// Create a capture that intercepts layer outputs `"l_out-{N}"` for
139    /// the specified layer indices.
140    ///
141    /// This is the most common use case for extracting per-layer hidden
142    /// states from a language model.
143    ///
144    /// # Example
145    ///
146    /// ```rust,ignore
147    /// // Capture layers 13, 20, 27 (typical for LLaMA-3.2-3B with positions [0.5, 0.75, 1.0])
148    /// let mut capture = TensorCapture::for_layers(&[13, 20, 27]);
149    /// ```
150    #[must_use]
151    pub fn for_layers(layer_indices: &[usize]) -> Self {
152        Self {
153            filter: CaptureFilter::Layers(layer_indices.to_vec()),
154            captured: HashMap::new(),
155        }
156    }
157
158    /// Create a capture that intercepts tensors with exact matching names.
159    ///
160    /// # Example
161    ///
162    /// ```rust,ignore
163    /// let mut capture = TensorCapture::for_names(&["result_norm", "l_out-27"]);
164    /// ```
165    #[must_use]
166    pub fn for_names(names: &[&str]) -> Self {
167        Self {
168            filter: CaptureFilter::Names(
169                names.iter().map(std::string::ToString::to_string).collect(),
170            ),
171            captured: HashMap::new(),
172        }
173    }
174
175    /// Create a capture that intercepts all tensors whose name starts with
176    /// the given prefix.
177    ///
178    /// # Example
179    ///
180    /// ```rust,ignore
181    /// // Capture all attention outputs
182    /// let mut capture = TensorCapture::for_prefix("attn_out");
183    /// ```
184    #[must_use]
185    pub fn for_prefix(prefix: &str) -> Self {
186        Self {
187            filter: CaptureFilter::Prefix(prefix.to_string()),
188            captured: HashMap::new(),
189        }
190    }
191
192    /// Create a capture that intercepts **all** tensors.
193    ///
194    /// ⚠️ Warning: this can produce very large amounts of data.
195    /// Use only for debugging or inspection.
196    #[must_use]
197    pub fn all() -> Self {
198        Self {
199            filter: CaptureFilter::All,
200            captured: HashMap::new(),
201        }
202    }
203
204    /// Clear all previously captured data, keeping the filter configuration.
205    ///
206    /// Call this before a new `decode()` if reusing the capture across
207    /// multiple batches.
208    pub fn clear(&mut self) {
209        self.captured.clear();
210    }
211
212    /// Get a captured tensor by its full name (e.g. `"l_out-13"`).
213    #[must_use]
214    pub fn get(&self, name: &str) -> Option<&CapturedTensor> {
215        self.captured.get(name)
216    }
217
218    /// Get a captured layer output by layer index.
219    ///
220    /// Looks up `"l_out-{layer_idx}"`.
221    #[must_use]
222    pub fn get_layer(&self, layer_idx: usize) -> Option<&CapturedTensor> {
223        self.captured.get(&format!("l_out-{layer_idx}"))
224    }
225
226    /// Returns `true` if the specified layer was captured.
227    #[must_use]
228    pub fn has_layer(&self, layer_idx: usize) -> bool {
229        self.captured.contains_key(&format!("l_out-{layer_idx}"))
230    }
231
232    /// Number of tensors captured so far.
233    #[must_use]
234    pub fn len(&self) -> usize {
235        self.captured.len()
236    }
237
238    /// Returns `true` if no tensors have been captured.
239    #[must_use]
240    pub fn is_empty(&self) -> bool {
241        self.captured.is_empty()
242    }
243
244    /// Iterate over all captured tensors.
245    pub fn iter(&self) -> impl Iterator<Item = (&str, &CapturedTensor)> {
246        self.captured.iter().map(|(k, v)| (k.as_str(), v))
247    }
248
249    /// Get all captured layer indices (sorted).
250    #[must_use]
251    pub fn captured_layers(&self) -> Vec<usize> {
252        let mut layers: Vec<usize> = self.captured.values().filter_map(|ct| ct.layer).collect();
253        layers.sort_unstable();
254        layers.dedup();
255        layers
256    }
257
258    // ── Internal: callback matching ──────────────────────────────────
259
260    /// Check if a tensor name matches the capture filter.
261    fn matches(&self, name: &str) -> bool {
262        match &self.filter {
263            CaptureFilter::Layers(indices) => {
264                if let Some(suffix) = name.strip_prefix("l_out-") {
265                    if let Ok(idx) = suffix.parse::<usize>() {
266                        return indices.contains(&idx);
267                    }
268                }
269                false
270            }
271            CaptureFilter::Names(names) => names.iter().any(|n| n == name),
272            CaptureFilter::Prefix(prefix) => name.starts_with(prefix.as_str()),
273            CaptureFilter::All => true,
274        }
275    }
276
277    /// Store a captured tensor.
278    fn store(&mut self, name: String, ne0: usize, ne1: usize, data: Vec<f32>) {
279        let layer = name
280            .strip_prefix("l_out-")
281            .and_then(|s| s.parse::<usize>().ok());
282
283        self.captured.insert(
284            name.clone(),
285            CapturedTensor {
286                name,
287                layer,
288                ne0,
289                ne1,
290                data,
291            },
292        );
293    }
294}
295
296// ── The extern "C" callback ──────────────────────────────────────────────
297
298/// The `cb_eval` callback function passed to llama.cpp.
299///
300/// # Safety
301///
302/// This function is called from C code during graph evaluation.
303/// `user_data` must point to a valid `TensorCapture` instance.
304pub(crate) unsafe extern "C" fn tensor_capture_callback(
305    t: *mut llama_cpp_sys_4::ggml_tensor,
306    ask: bool,
307    user_data: *mut std::ffi::c_void,
308) -> bool {
309    if t.is_null() || user_data.is_null() {
310        return false;
311    }
312
313    // Read tensor name from the fixed-size C array
314    let name_bytes = &(*t).name;
315    let len = name_bytes
316        .iter()
317        .position(|&b| b == 0)
318        .unwrap_or(name_bytes.len());
319    let name = std::str::from_utf8_unchecked(std::slice::from_raw_parts(
320        name_bytes.as_ptr().cast::<u8>(),
321        len,
322    ));
323
324    let state = &mut *user_data.cast::<TensorCapture>();
325
326    if !state.matches(name) {
327        return false;
328    }
329
330    if ask {
331        return true;
332    }
333
334    // Data phase: copy tensor data out
335    let ne0 = (*t).ne[0] as usize;
336    let ne1 = (*t).ne[1] as usize;
337    let n_elements = ne0 * ne1;
338
339    let mut buf = vec![0f32; n_elements];
340    llama_cpp_sys_4::ggml_backend_tensor_get(
341        t,
342        buf.as_mut_ptr().cast::<std::ffi::c_void>(),
343        0,
344        n_elements * std::mem::size_of::<f32>(),
345    );
346
347    state.store(name.to_string(), ne0, ne1, buf);
348
349    true
350}
351
/// Unit tests for filter matching and storage. They call the private
/// `matches` / `store` methods directly, so no model, context or C callback
/// is involved.
#[cfg(test)]
mod tests {
    use super::*;

    /// `for_layers` accepts only `"l_out-{N}"` names whose index is listed.
    #[test]
    fn test_for_layers_matching() {
        let capture = TensorCapture::for_layers(&[13, 20, 27]);
        assert!(capture.matches("l_out-13"));
        assert!(capture.matches("l_out-20"));
        assert!(capture.matches("l_out-27"));
        assert!(!capture.matches("l_out-0"));
        assert!(!capture.matches("l_out-14"));
        // Same index, different prefix: not a layer output.
        assert!(!capture.matches("attn_norm-13"));
        assert!(!capture.matches("result_norm"));
    }

    /// `for_names` is an exact, whole-name match.
    #[test]
    fn test_for_names_matching() {
        let capture = TensorCapture::for_names(&["result_norm", "l_out-27"]);
        assert!(capture.matches("result_norm"));
        assert!(capture.matches("l_out-27"));
        assert!(!capture.matches("l_out-13"));
        // Shares the "result_" prefix but is not an exact match.
        assert!(!capture.matches("result_output"));
    }

    /// `for_prefix` matches any name starting with the prefix.
    #[test]
    fn test_for_prefix_matching() {
        let capture = TensorCapture::for_prefix("attn_out");
        assert!(capture.matches("attn_out-0"));
        assert!(capture.matches("attn_out-27"));
        assert!(!capture.matches("attn_norm-0"));
        assert!(!capture.matches("l_out-0"));
    }

    /// `all` matches every name.
    #[test]
    fn test_all_matching() {
        let capture = TensorCapture::all();
        assert!(capture.matches("l_out-13"));
        assert!(capture.matches("result_norm"));
        assert!(capture.matches("anything"));
    }

    /// A stored layer output is reachable by full name and by layer index,
    /// and the layer index is parsed from the name.
    #[test]
    fn test_store_and_get() {
        let mut capture = TensorCapture::for_layers(&[13]);
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        capture.store("l_out-13".to_string(), 3, 2, data.clone());

        assert_eq!(capture.len(), 1);
        assert!(!capture.is_empty());

        let ct = capture.get("l_out-13").unwrap();
        assert_eq!(ct.name, "l_out-13");
        assert_eq!(ct.layer, Some(13));
        assert_eq!(ct.n_embd(), 3);
        assert_eq!(ct.n_tokens(), 2);
        assert_eq!(ct.data, data);

        // Also accessible by layer index
        let ct2 = capture.get_layer(13).unwrap();
        assert_eq!(ct2.name, ct.name);
        assert!(capture.has_layer(13));
        assert!(!capture.has_layer(14));
    }

    /// Rows are `ne0` consecutive values; out-of-range rows yield `None`.
    #[test]
    fn test_token_embedding() {
        let mut capture = TensorCapture::for_layers(&[5]);
        // 2 tokens, 3 dims: token0=[1,2,3], token1=[4,5,6]
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        capture.store("l_out-5".to_string(), 3, 2, data);

        let ct = capture.get_layer(5).unwrap();
        assert_eq!(ct.token_embedding(0), Some(&[1.0, 2.0, 3.0][..]));
        assert_eq!(ct.token_embedding(1), Some(&[4.0, 5.0, 6.0][..]));
        assert_eq!(ct.token_embedding(2), None);
    }

    /// `captured_layers` reports only the layers actually stored, sorted,
    /// regardless of insertion order (layer 20 was requested but never stored).
    #[test]
    fn test_captured_layers() {
        let mut capture = TensorCapture::for_layers(&[5, 10, 20]);
        capture.store("l_out-10".to_string(), 2, 1, vec![0.0, 0.0]);
        capture.store("l_out-5".to_string(), 2, 1, vec![0.0, 0.0]);
        assert_eq!(capture.captured_layers(), vec![5, 10]);
    }

    /// `clear` drops captured data.
    #[test]
    fn test_clear() {
        let mut capture = TensorCapture::for_layers(&[5]);
        capture.store("l_out-5".to_string(), 2, 1, vec![0.0, 0.0]);
        assert_eq!(capture.len(), 1);
        capture.clear();
        assert_eq!(capture.len(), 0);
        assert!(capture.is_empty());
    }

    /// Names without the `"l_out-"` prefix are stored with `layer == None`.
    #[test]
    fn test_non_layer_tensor() {
        let mut capture = TensorCapture::for_names(&["result_norm"]);
        capture.store("result_norm".to_string(), 4, 3, vec![0.0; 12]);
        let ct = capture.get("result_norm").unwrap();
        assert_eq!(ct.layer, None);
        assert_eq!(ct.n_embd(), 4);
        assert_eq!(ct.n_tokens(), 3);
    }
}