Skip to main content

llama_cpp_4/context/
tensor_capture.rs

1//! Capture intermediate tensor outputs during decode via the `cb_eval` callback.
2//!
3//! During `llama_decode`, llama.cpp evaluates a computation graph where each
4//! tensor node has a name (e.g. `"l_out-13"` for layer 13's output,
5//! `"attn_norm-5"` for layer 5's attention norm, `"result_norm"` for the
6//! final norm output).
7//!
8//! The `cb_eval` callback is invoked for every tensor node:
9//! - **Ask phase** (`ask = true`):  return `true` to request this tensor's data.
10//! - **Data phase** (`ask = false`): the tensor data is computed and available
11//!   to copy out via `ggml_backend_tensor_get()`.
12//!
13//! [`TensorCapture`] provides a safe, reusable wrapper around this mechanism.
14//!
15//! # Example
16//!
17//! ```rust,ignore
//! use std::num::NonZeroU32;
//! use llama_cpp_4::context::params::LlamaContextParams;
//! use llama_cpp_4::context::tensor_capture::TensorCapture;
20//!
21//! // Capture layers 13, 20, 27
22//! let mut capture = TensorCapture::for_layers(&[13, 20, 27]);
23//!
24//! let ctx_params = LlamaContextParams::default()
25//!     .with_n_ctx(Some(NonZeroU32::new(2048).unwrap()))
26//!     .with_embeddings(true)
27//!     .with_tensor_capture(&mut capture);
28//!
29//! let mut ctx = model.new_context(&backend, ctx_params)?;
30//! // ... add tokens to batch ...
31//! ctx.decode(&mut batch)?;
32//!
33//! // Read captured hidden states
//! for &layer in &[13, 20, 27] {
//!     if let Some(info) = capture.get_layer(layer) {
//!         println!("Layer {}: shape [{}, {}]", layer, info.n_embd(), info.n_tokens());
//!         // info.data contains n_tokens * n_embd f32 values
//!         // Layout: data[token_idx * n_embd + dim_idx]
//!     }
40//! }
41//! ```
42
43use std::collections::HashMap;
44
/// Information about a single captured tensor.
#[derive(Debug, Clone)]
pub struct CapturedTensor {
    /// The tensor name (e.g. `"l_out-13"`).
    pub name: String,
    /// The layer index parsed from an `"l_out-{N}"` name, or `None` for any
    /// other name (including other per-layer tensors such as `"attn_norm-5"`).
    pub layer: Option<usize>,
    /// First dimension (typically `n_embd` / hidden dimension).
    pub ne0: usize,
    /// Second dimension (typically `n_tokens`).
    pub ne1: usize,
    /// Flattened f32 data, expected to hold `ne0 * ne1` elements.
    ///
    /// Layout (row-major from ggml's perspective):
    /// `data[token_idx * ne0 + dim_idx]`
    ///
    /// This matches the ggml tensor layout where `ne[0]` is the
    /// innermost (contiguous) dimension.
    pub data: Vec<f32>,
}

impl CapturedTensor {
    /// Number of embedding dimensions (alias for `ne0`).
    #[inline]
    pub fn n_embd(&self) -> usize {
        self.ne0
    }

    /// Number of tokens (alias for `ne1`).
    #[inline]
    pub fn n_tokens(&self) -> usize {
        self.ne1
    }

    /// Get the hidden state for a specific token.
    ///
    /// Returns a slice of `n_embd` floats, or `None` if `token_idx` is out of
    /// range. `None` is also returned when `data` is too short to contain that
    /// token. Because the fields are public, a hand-built value may violate
    /// the `ne0 * ne1` length invariant, and this must not panic.
    pub fn token_embedding(&self, token_idx: usize) -> Option<&[f32]> {
        if token_idx >= self.ne1 {
            return None;
        }
        // Checked arithmetic + `get` turn an inconsistent value into `None`
        // instead of an overflow or an out-of-bounds slice panic.
        let start = token_idx.checked_mul(self.ne0)?;
        let end = start.checked_add(self.ne0)?;
        self.data.get(start..end)
    }
}
92
/// Strategy for selecting which tensors to capture.
///
/// Consulted by `TensorCapture::matches` for every tensor name the eval
/// callback is asked about.
#[derive(Debug, Clone)]
enum CaptureFilter {
    /// Capture tensors named `"l_out-{N}"` whose `N` is one of these layer indices.
    Layers(Vec<usize>),
    /// Capture tensors whose names exactly match one of these strings.
    Names(Vec<String>),
    /// Capture tensors whose names start with this prefix.
    Prefix(String),
    /// Capture every tensor (warning: can be very large).
    All,
}
105
/// Captures intermediate tensor outputs during `llama_decode`.
///
/// Create a `TensorCapture`, attach it to `LlamaContextParams` via
/// [`with_tensor_capture`](super::params::LlamaContextParams::with_tensor_capture),
/// then call `decode()`. After decode completes, read captured data
/// via [`get`](Self::get), [`get_layer`](Self::get_layer), or [`iter`](Self::iter).
///
/// # Lifetime & Safety
///
/// The `TensorCapture` must outlive the `LlamaContext` it is attached to.
/// The `cb_eval` callback in this module reinterprets its `user_data`
/// pointer as `*mut TensorCapture` and stores captured tensors into it
/// while decoding.
/// [`with_tensor_capture`](super::params::LlamaContextParams::with_tensor_capture)
/// takes the capture as `&mut TensorCapture` (see the module example).
// NOTE(review): the previous wording said the borrow is enforced by a
// `&mut self` receiver. Confirm in `with_tensor_capture`'s signature that
// the `&mut TensorCapture` borrow is actually tied to the context's lifetime.
pub struct TensorCapture {
    /// Decides which tensor names the callback requests.
    filter: CaptureFilter,
    /// Captured tensors keyed by name.
    captured: HashMap<String, CapturedTensor>,
}
123
124impl std::fmt::Debug for TensorCapture {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        f.debug_struct("TensorCapture")
127            .field("filter", &self.filter)
128            .field("captured_count", &self.captured.len())
129            .field(
130                "captured_keys",
131                &self.captured.keys().collect::<Vec<_>>(),
132            )
133            .finish()
134    }
135}
136
137impl TensorCapture {
138    /// Create a capture that intercepts layer outputs `"l_out-{N}"` for
139    /// the specified layer indices.
140    ///
141    /// This is the most common use case for extracting per-layer hidden
142    /// states from a language model.
143    ///
144    /// # Example
145    ///
146    /// ```rust,ignore
147    /// // Capture layers 13, 20, 27 (typical for LLaMA-3.2-3B with positions [0.5, 0.75, 1.0])
148    /// let mut capture = TensorCapture::for_layers(&[13, 20, 27]);
149    /// ```
150    pub fn for_layers(layer_indices: &[usize]) -> Self {
151        Self {
152            filter: CaptureFilter::Layers(layer_indices.to_vec()),
153            captured: HashMap::new(),
154        }
155    }
156
157    /// Create a capture that intercepts tensors with exact matching names.
158    ///
159    /// # Example
160    ///
161    /// ```rust,ignore
162    /// let mut capture = TensorCapture::for_names(&["result_norm", "l_out-27"]);
163    /// ```
164    pub fn for_names(names: &[&str]) -> Self {
165        Self {
166            filter: CaptureFilter::Names(names.iter().map(|s| s.to_string()).collect()),
167            captured: HashMap::new(),
168        }
169    }
170
171    /// Create a capture that intercepts all tensors whose name starts with
172    /// the given prefix.
173    ///
174    /// # Example
175    ///
176    /// ```rust,ignore
177    /// // Capture all attention outputs
178    /// let mut capture = TensorCapture::for_prefix("attn_out");
179    /// ```
180    pub fn for_prefix(prefix: &str) -> Self {
181        Self {
182            filter: CaptureFilter::Prefix(prefix.to_string()),
183            captured: HashMap::new(),
184        }
185    }
186
187    /// Create a capture that intercepts **all** tensors.
188    ///
189    /// ⚠️ Warning: this can produce very large amounts of data.
190    /// Use only for debugging or inspection.
191    pub fn all() -> Self {
192        Self {
193            filter: CaptureFilter::All,
194            captured: HashMap::new(),
195        }
196    }
197
198    /// Clear all previously captured data, keeping the filter configuration.
199    ///
200    /// Call this before a new `decode()` if reusing the capture across
201    /// multiple batches.
202    pub fn clear(&mut self) {
203        self.captured.clear();
204    }
205
206    /// Get a captured tensor by its full name (e.g. `"l_out-13"`).
207    pub fn get(&self, name: &str) -> Option<&CapturedTensor> {
208        self.captured.get(name)
209    }
210
211    /// Get a captured layer output by layer index.
212    ///
213    /// Looks up `"l_out-{layer_idx}"`.
214    pub fn get_layer(&self, layer_idx: usize) -> Option<&CapturedTensor> {
215        self.captured.get(&format!("l_out-{layer_idx}"))
216    }
217
218    /// Returns `true` if the specified layer was captured.
219    pub fn has_layer(&self, layer_idx: usize) -> bool {
220        self.captured.contains_key(&format!("l_out-{layer_idx}"))
221    }
222
223    /// Number of tensors captured so far.
224    pub fn len(&self) -> usize {
225        self.captured.len()
226    }
227
228    /// Returns `true` if no tensors have been captured.
229    pub fn is_empty(&self) -> bool {
230        self.captured.is_empty()
231    }
232
233    /// Iterate over all captured tensors.
234    pub fn iter(&self) -> impl Iterator<Item = (&str, &CapturedTensor)> {
235        self.captured.iter().map(|(k, v)| (k.as_str(), v))
236    }
237
238    /// Get all captured layer indices (sorted).
239    pub fn captured_layers(&self) -> Vec<usize> {
240        let mut layers: Vec<usize> = self
241            .captured
242            .values()
243            .filter_map(|ct| ct.layer)
244            .collect();
245        layers.sort_unstable();
246        layers.dedup();
247        layers
248    }
249
250    // ── Internal: callback matching ──────────────────────────────────
251
252    /// Check if a tensor name matches the capture filter.
253    fn matches(&self, name: &str) -> bool {
254        match &self.filter {
255            CaptureFilter::Layers(indices) => {
256                if let Some(suffix) = name.strip_prefix("l_out-") {
257                    if let Ok(idx) = suffix.parse::<usize>() {
258                        return indices.contains(&idx);
259                    }
260                }
261                false
262            }
263            CaptureFilter::Names(names) => names.iter().any(|n| n == name),
264            CaptureFilter::Prefix(prefix) => name.starts_with(prefix.as_str()),
265            CaptureFilter::All => true,
266        }
267    }
268
269    /// Store a captured tensor.
270    fn store(&mut self, name: String, ne0: usize, ne1: usize, data: Vec<f32>) {
271        let layer = name
272            .strip_prefix("l_out-")
273            .and_then(|s| s.parse::<usize>().ok());
274
275        self.captured.insert(
276            name.clone(),
277            CapturedTensor {
278                name,
279                layer,
280                ne0,
281                ne1,
282                data,
283            },
284        );
285    }
286}
287
288// ── The extern "C" callback ──────────────────────────────────────────────
289
290/// The `cb_eval` callback function passed to llama.cpp.
291///
292/// # Safety
293///
294/// This function is called from C code during graph evaluation.
295/// `user_data` must point to a valid `TensorCapture` instance.
296pub(crate) unsafe extern "C" fn tensor_capture_callback(
297    t: *mut llama_cpp_sys_4::ggml_tensor,
298    ask: bool,
299    user_data: *mut std::ffi::c_void,
300) -> bool {
301    if t.is_null() || user_data.is_null() {
302        return false;
303    }
304
305    // Read tensor name from the fixed-size C array
306    let name_bytes = &(*t).name;
307    let len = name_bytes
308        .iter()
309        .position(|&b| b == 0)
310        .unwrap_or(name_bytes.len());
311    let name = std::str::from_utf8_unchecked(std::slice::from_raw_parts(
312        name_bytes.as_ptr() as *const u8,
313        len,
314    ));
315
316    let state = &mut *(user_data as *mut TensorCapture);
317
318    if !state.matches(name) {
319        return false;
320    }
321
322    if ask {
323        return true;
324    }
325
326    // Data phase: copy tensor data out
327    let ne0 = (*t).ne[0] as usize;
328    let ne1 = (*t).ne[1] as usize;
329    let n_elements = ne0 * ne1;
330
331    let mut buf = vec![0f32; n_elements];
332    llama_cpp_sys_4::ggml_backend_tensor_get(
333        t,
334        buf.as_mut_ptr() as *mut std::ffi::c_void,
335        0,
336        n_elements * std::mem::size_of::<f32>(),
337    );
338
339    state.store(name.to_string(), ne0, ne1, buf);
340
341    true
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_for_layers_matching() {
        let capture = TensorCapture::for_layers(&[13, 20, 27]);
        for hit in ["l_out-13", "l_out-20", "l_out-27"] {
            assert!(capture.matches(hit), "{hit} should match");
        }
        for miss in ["l_out-0", "l_out-14", "attn_norm-13", "result_norm"] {
            assert!(!capture.matches(miss), "{miss} should not match");
        }
    }

    #[test]
    fn test_for_names_matching() {
        let capture = TensorCapture::for_names(&["result_norm", "l_out-27"]);
        let cases = [
            ("result_norm", true),
            ("l_out-27", true),
            ("l_out-13", false),
            ("result_output", false),
        ];
        for (candidate, expected) in cases {
            assert_eq!(capture.matches(candidate), expected, "case {candidate}");
        }
    }

    #[test]
    fn test_for_prefix_matching() {
        let capture = TensorCapture::for_prefix("attn_out");
        let cases = [
            ("attn_out-0", true),
            ("attn_out-27", true),
            ("attn_norm-0", false),
            ("l_out-0", false),
        ];
        for (candidate, expected) in cases {
            assert_eq!(capture.matches(candidate), expected, "case {candidate}");
        }
    }

    #[test]
    fn test_all_matching() {
        let capture = TensorCapture::all();
        for candidate in ["l_out-13", "result_norm", "anything"] {
            assert!(capture.matches(candidate), "{candidate} should match");
        }
    }

    #[test]
    fn test_store_and_get() {
        let mut capture = TensorCapture::for_layers(&[13]);
        let expected = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        capture.store(String::from("l_out-13"), 3, 2, expected.clone());

        assert!(!capture.is_empty());
        assert_eq!(capture.len(), 1);

        let by_name = capture.get("l_out-13").expect("stored under its name");
        assert_eq!(by_name.name, "l_out-13");
        assert_eq!(by_name.layer, Some(13));
        assert_eq!((by_name.n_embd(), by_name.n_tokens()), (3, 2));
        assert_eq!(by_name.data, expected);

        // The same entry is reachable through its layer index.
        let by_layer = capture.get_layer(13).expect("reachable by layer index");
        assert_eq!(by_layer.name, by_name.name);
        assert!(capture.has_layer(13));
        assert!(!capture.has_layer(14));
    }

    #[test]
    fn test_token_embedding() {
        let mut capture = TensorCapture::for_layers(&[5]);
        // Two tokens of three dims each: token 0 = [1, 2, 3], token 1 = [4, 5, 6].
        capture.store("l_out-5".into(), 3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        let tensor = capture.get_layer(5).unwrap();
        let first: &[f32] = &[1.0, 2.0, 3.0];
        let second: &[f32] = &[4.0, 5.0, 6.0];
        assert_eq!(tensor.token_embedding(0), Some(first));
        assert_eq!(tensor.token_embedding(1), Some(second));
        assert!(tensor.token_embedding(2).is_none());
    }

    #[test]
    fn test_captured_layers() {
        let mut capture = TensorCapture::for_layers(&[5, 10, 20]);
        for layer in [10, 5] {
            capture.store(format!("l_out-{layer}"), 2, 1, vec![0.0; 2]);
        }
        assert_eq!(capture.captured_layers(), vec![5, 10]);
    }

    #[test]
    fn test_clear() {
        let mut capture = TensorCapture::for_layers(&[5]);
        capture.store("l_out-5".into(), 2, 1, vec![0.0; 2]);
        assert_eq!(capture.len(), 1);

        capture.clear();
        assert_eq!(capture.len(), 0);
        assert!(capture.is_empty());
    }

    #[test]
    fn test_non_layer_tensor() {
        let mut capture = TensorCapture::for_names(&["result_norm"]);
        capture.store("result_norm".into(), 4, 3, vec![0.0; 12]);

        let tensor = capture.get("result_norm").unwrap();
        assert!(tensor.layer.is_none());
        assert_eq!(tensor.n_embd(), 4);
        assert_eq!(tensor.n_tokens(), 3);
    }
}