Skip to main content

llama_cpp_4/context/
tensor_capture.rs

1//! Capture intermediate tensor outputs during [`crate::LlamaContext::decode`].
2//!
3//! llama.cpp builds a computation graph for each forward pass. Every node has a
4//! string name — for transformer blocks the layer output is typically
5//! `"l_out-{N}"` (e.g. `"l_out-13"`), attention norms are `"attn_norm-{N}"`, and
6//! the final norm is `"result_norm"`.
7//!
8//! The graph evaluation callback (`cb_eval`) runs in two phases for each node:
9//!
10//! | Phase | `ask` | Behaviour |
11//! |---|---|---|
12//! | Ask | `true` | Return `true` to request a copy of this tensor's data. |
13//! | Data | `false` | Tensor is computed; data is copied via `ggml_backend_tensor_get`. |
14//!
15//! [`TensorCapture`] implements that callback and stores matching tensors in a
16//! [`HashMap`] you can read after `decode()` finishes.
17//!
18//! # Typical use cases
19//!
20//! - **Layer probing** — inspect hidden states at specific depths.
21//! - **EAGLE / distillation** — read draft-model anchor layers (see `examples/eagle`).
22//! - **Debugging** — dump norms or attention outputs with [`TensorCapture::for_prefix`].
23//!
24//! # Setup
25//!
26//! 1. Build a [`TensorCapture`] with the filter you need ([`TensorCapture::for_layers`]
27//!    is the common case).
28//! 2. Pass it to [`LlamaContextParams::with_tensor_capture`](crate::LlamaContextParams::with_tensor_capture). The capture must
29//!    **outlive** the [`LlamaContext`](crate::LlamaContext).
30//! 3. Run [`LlamaContext::decode`](crate::LlamaContext::decode) as usual.
31//! 4. Read [`CapturedTensor`] values via [`TensorCapture::get_layer`],
32//!    [`TensorCapture::get`], or [`TensorCapture::iter`].
33//!
34//! Call [`TensorCapture::clear`](crate::TensorCapture::clear) before reusing the same capture on another batch.
35//!
36//! # Example
37//!
38//! ```no_run
39//! use llama_cpp_4::prelude::*;
40//! use std::num::NonZeroU32;
41//!
42//! fn main() {
43//!     let backend = LlamaBackend::init().unwrap();
44//!     let model = LlamaModel::load_from_file(
45//!         &backend,
46//!         "model.gguf",
47//!         &LlamaModelParams::default(),
48//!     )
49//!     .unwrap();
50//!
51//!     let mut capture = TensorCapture::for_layers(&[13, 20, 27]);
52//!     let ctx_params = LlamaContextParams::default()
53//!         .with_n_ctx(NonZeroU32::new(512))
54//!         .with_tensor_capture(&mut capture);
55//!     let mut ctx = model.new_context(&backend, ctx_params).unwrap();
56//!
57//!     let tokens = model.str_to_token("Hello", AddBos::Always).unwrap();
58//!     let mut batch = LlamaBatch::new(512, 1);
59//!     for (i, &tok) in tokens.iter().enumerate() {
60//!         batch
61//!             .add(tok, i as i32, &[0], i == tokens.len() - 1)
62//!             .unwrap();
63//!     }
64//!     ctx.decode(&mut batch).unwrap();
65//!
66//!     for &layer in &[13, 20, 27] {
67//!         if let Some(t) = capture.get_layer(layer) {
68//!             println!(
69//!                 "l_out-{layer}: {} tokens × {} dims",
70//!                 t.n_tokens(),
71//!                 t.n_embd()
72//!             );
73//!             if let Some(vec) = t.token_embedding(0) {
74//!                 println!("  first token, first 3 dims: {:?}", &vec[..3.min(vec.len())]);
75//!             }
76//!         }
77//!     }
78//! }
79//! ```
80//!
81//! # Tensor layout
82//!
83//! Each [`CapturedTensor`] stores a flat `f32` buffer with
84//! `data[token_idx * n_embd + dim_idx]` (ggml row-major: `ne0` = embedding dim,
85//! `ne1` = token count). Use [`CapturedTensor::token_embedding`] to slice one row.
86
87use std::collections::HashMap;
88
/// A single tensor copied out of the decode graph.
///
/// Produced by [`TensorCapture`] after a successful [`crate::LlamaContext::decode`].
/// For layer outputs (`"l_out-N"`), [`Self::layer`] is set to `N`.
///
/// All fields are public, so a hand-built value may carry a `data` buffer whose
/// length differs from `ne0 * ne1`; the accessors below return `None` instead of
/// panicking in that case.
#[derive(Debug, Clone)]
pub struct CapturedTensor {
    /// Graph node name (e.g. `"l_out-13"`, `"result_norm"`).
    pub name: String,
    /// Layer index when `name` is `"l_out-{N}"`, otherwise `None`.
    pub layer: Option<usize>,
    /// First dimension (typically `n_embd` / hidden size).
    pub ne0: usize,
    /// Second dimension (typically the number of token rows in the batch).
    pub ne1: usize,
    /// Flattened `ne0 * ne1` values with `ne0` varying fastest (ggml order).
    ///
    /// Index as `data[token_idx * ne0 + dim_idx]`.
    pub data: Vec<f32>,
}

impl CapturedTensor {
    /// Number of embedding dimensions (alias for [`Self::ne0`]).
    #[inline]
    #[must_use]
    pub fn n_embd(&self) -> usize {
        self.ne0
    }

    /// Number of token positions (alias for [`Self::ne1`]).
    #[inline]
    #[must_use]
    pub fn n_tokens(&self) -> usize {
        self.ne1
    }

    /// Hidden-state vector (`ne0` values) for one token index.
    ///
    /// Returns `None` when `token_idx >= n_tokens()`, or when `data` is too
    /// short to contain that row (only possible for a hand-built value whose
    /// fields disagree).
    #[must_use]
    pub fn token_embedding(&self, token_idx: usize) -> Option<&[f32]> {
        if token_idx >= self.ne1 {
            return None;
        }
        // Checked arithmetic + `get` so inconsistent public fields produce
        // `None` rather than an overflow or out-of-bounds panic.
        let start = token_idx.checked_mul(self.ne0)?;
        let end = start.checked_add(self.ne0)?;
        self.data.get(start..end)
    }
}
136
/// Rule deciding which graph nodes a [`TensorCapture`] copies out.
#[derive(Debug, Clone)]
enum CaptureFilter {
    /// Accept `"l_out-{N}"` whenever `N` is one of the listed layer indices.
    Layers(Vec<usize>),
    /// Accept names equal to one of these strings.
    Names(Vec<String>),
    /// Accept names that begin with this prefix (for example `"attn_out"`).
    Prefix(String),
    /// Accept every node; output can be very large, so use it for debugging.
    All,
}
149
150/// Captures intermediate tensors during [`crate::LlamaContext::decode`].
151///
152/// Attach with [`LlamaContextParams::with_tensor_capture`](crate::LlamaContextParams::with_tensor_capture) before creating the
153/// context. The same instance can be reused across decodes if you call
154/// [`Self::clear`] between passes.
155///
156/// # Lifetime
157///
158/// The capture must outlive the [`crate::LlamaContext`] it is wired into;
159/// [`LlamaContextParams::with_tensor_capture`](crate::LlamaContextParams::with_tensor_capture) takes `&mut TensorCapture` to
160/// enforce this at compile time.
161pub struct TensorCapture {
162    filter: CaptureFilter,
163    captured: HashMap<String, CapturedTensor>,
164}
165
166impl std::fmt::Debug for TensorCapture {
167    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        f.debug_struct("TensorCapture")
169            .field("filter", &self.filter)
170            .field("captured_count", &self.captured.len())
171            .field("captured_keys", &self.captured.keys().collect::<Vec<_>>())
172            .finish()
173    }
174}
175
176impl TensorCapture {
177    /// Capture transformer layer outputs `"l_out-{N}"` for the given indices.
178    ///
179    /// This is the usual choice for hidden-state extraction. EAGLE-3 draft models
180    /// often use three layers at ~50%, 75%, and 100% depth — e.g. `[13, 20, 27]`
181    /// on a 28-layer model.
182    #[must_use]
183    pub fn for_layers(layer_indices: &[usize]) -> Self {
184        Self {
185            filter: CaptureFilter::Layers(layer_indices.to_vec()),
186            captured: HashMap::new(),
187        }
188    }
189
190    /// Capture tensors whose graph names match exactly.
191    ///
192    /// Example names: `"result_norm"`, `"l_out-27"`.
193    #[must_use]
194    pub fn for_names(names: &[&str]) -> Self {
195        Self {
196            filter: CaptureFilter::Names(
197                names.iter().map(std::string::ToString::to_string).collect(),
198            ),
199            captured: HashMap::new(),
200        }
201    }
202
203    /// Capture every tensor whose name starts with `prefix`.
204    ///
205    /// Useful for families like `"attn_out-*"` or `"attn_norm-*"`.
206    #[must_use]
207    pub fn for_prefix(prefix: &str) -> Self {
208        Self {
209            filter: CaptureFilter::Prefix(prefix.to_string()),
210            captured: HashMap::new(),
211        }
212    }
213
214    /// Capture **all** graph nodes.
215    ///
216    /// Warning: memory use scales with model size and sequence length. Prefer
217    /// [`Self::for_layers`] or [`Self::for_names`] in production code.
218    #[must_use]
219    pub fn all() -> Self {
220        Self {
221            filter: CaptureFilter::All,
222            captured: HashMap::new(),
223        }
224    }
225
226    /// Drop captured tensors but keep the filter (safe to call before another decode).
227    pub fn clear(&mut self) {
228        self.captured.clear();
229    }
230
231    /// Lookup by full graph name (e.g. `"l_out-13"`).
232    #[must_use]
233    pub fn get(&self, name: &str) -> Option<&CapturedTensor> {
234        self.captured.get(name)
235    }
236
237    /// Lookup a layer output (`"l_out-{layer_idx}"`).
238    #[must_use]
239    pub fn get_layer(&self, layer_idx: usize) -> Option<&CapturedTensor> {
240        self.captured.get(&format!("l_out-{layer_idx}"))
241    }
242
243    /// Whether `"l_out-{layer_idx}"` was captured in the last decode.
244    #[must_use]
245    pub fn has_layer(&self, layer_idx: usize) -> bool {
246        self.captured.contains_key(&format!("l_out-{layer_idx}"))
247    }
248
249    /// Number of tensors stored from the most recent decode.
250    #[must_use]
251    pub fn len(&self) -> usize {
252        self.captured.len()
253    }
254
255    /// `true` when [`Self::len`] is zero.
256    #[must_use]
257    pub fn is_empty(&self) -> bool {
258        self.captured.is_empty()
259    }
260
261    /// Iterate `(name, tensor)` pairs from the last decode.
262    pub fn iter(&self) -> impl Iterator<Item = (&str, &CapturedTensor)> {
263        self.captured.iter().map(|(k, v)| (k.as_str(), v))
264    }
265
266    /// Sorted layer indices present among captured `"l_out-*"` tensors.
267    #[must_use]
268    pub fn captured_layers(&self) -> Vec<usize> {
269        let mut layers: Vec<usize> = self.captured.values().filter_map(|ct| ct.layer).collect();
270        layers.sort_unstable();
271        layers.dedup();
272        layers
273    }
274
275    fn matches(&self, name: &str) -> bool {
276        match &self.filter {
277            CaptureFilter::Layers(indices) => {
278                if let Some(suffix) = name.strip_prefix("l_out-") {
279                    if let Ok(idx) = suffix.parse::<usize>() {
280                        return indices.contains(&idx);
281                    }
282                }
283                false
284            }
285            CaptureFilter::Names(names) => names.iter().any(|n| n == name),
286            CaptureFilter::Prefix(prefix) => name.starts_with(prefix.as_str()),
287            CaptureFilter::All => true,
288        }
289    }
290
291    fn store(&mut self, name: String, ne0: usize, ne1: usize, data: Vec<f32>) {
292        let layer = name
293            .strip_prefix("l_out-")
294            .and_then(|s| s.parse::<usize>().ok());
295
296        self.captured.insert(
297            name.clone(),
298            CapturedTensor {
299                name,
300                layer,
301                ne0,
302                ne1,
303                data,
304            },
305        );
306    }
307}
308
309/// `cb_eval` callback installed by [`LlamaContextParams::with_tensor_capture`](crate::LlamaContextParams::with_tensor_capture).
310///
311/// # Safety
312///
313/// `user_data` must point to a live [`TensorCapture`] for the context lifetime.
314pub(crate) unsafe extern "C" fn tensor_capture_callback(
315    t: *mut llama_cpp_sys_4::ggml_tensor,
316    ask: bool,
317    user_data: *mut std::ffi::c_void,
318) -> bool {
319    if t.is_null() || user_data.is_null() {
320        return false;
321    }
322
323    let name_bytes = &(*t).name;
324    let len = name_bytes
325        .iter()
326        .position(|&b| b == 0)
327        .unwrap_or(name_bytes.len());
328    let name = std::str::from_utf8_unchecked(std::slice::from_raw_parts(
329        name_bytes.as_ptr().cast::<u8>(),
330        len,
331    ));
332
333    let state = &mut *user_data.cast::<TensorCapture>();
334
335    if !state.matches(name) {
336        return false;
337    }
338
339    if ask {
340        return true;
341    }
342
343    let ne0 = usize::try_from((*t).ne[0]).expect("tensor ne[0] must be non-negative");
344    let ne1 = usize::try_from((*t).ne[1]).expect("tensor ne[1] must be non-negative");
345    let n_elements = ne0 * ne1;
346
347    let mut buf = vec![0f32; n_elements];
348    llama_cpp_sys_4::ggml_backend_tensor_get(
349        t,
350        buf.as_mut_ptr().cast::<std::ffi::c_void>(),
351        0,
352        n_elements * std::mem::size_of::<f32>(),
353    );
354
355    state.store(name.to_string(), ne0, ne1, buf);
356
357    true
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_for_layers_matching() {
        let capture = TensorCapture::for_layers(&[13, 20, 27]);
        assert!(capture.matches("l_out-13"));
        assert!(capture.matches("l_out-20"));
        assert!(capture.matches("l_out-27"));
        assert!(!capture.matches("l_out-0"));
        assert!(!capture.matches("l_out-14"));
        assert!(!capture.matches("attn_norm-13"));
        assert!(!capture.matches("result_norm"));
    }

    #[test]
    fn test_for_names_matching() {
        let capture = TensorCapture::for_names(&["result_norm", "l_out-27"]);
        assert!(capture.matches("result_norm"));
        assert!(capture.matches("l_out-27"));
        assert!(!capture.matches("l_out-13"));
        assert!(!capture.matches("result_output"));
    }

    #[test]
    fn test_for_prefix_matching() {
        let capture = TensorCapture::for_prefix("attn_out");
        assert!(capture.matches("attn_out-0"));
        assert!(capture.matches("attn_out-27"));
        assert!(!capture.matches("attn_norm-0"));
        assert!(!capture.matches("l_out-0"));
    }

    #[test]
    fn test_all_matching() {
        let capture = TensorCapture::all();
        assert!(capture.matches("l_out-13"));
        assert!(capture.matches("result_norm"));
        assert!(capture.matches("anything"));
    }

    #[test]
    fn test_store_and_get() {
        let mut capture = TensorCapture::for_layers(&[13]);
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        capture.store("l_out-13".to_string(), 3, 2, data.clone());

        assert_eq!(capture.len(), 1);
        assert!(!capture.is_empty());

        let ct = capture.get("l_out-13").unwrap();
        assert_eq!(ct.name, "l_out-13");
        assert_eq!(ct.layer, Some(13));
        assert_eq!(ct.n_embd(), 3);
        assert_eq!(ct.n_tokens(), 2);
        assert_eq!(ct.data, data);

        let ct2 = capture.get_layer(13).unwrap();
        assert_eq!(ct2.name, ct.name);
        assert!(capture.has_layer(13));
        assert!(!capture.has_layer(14));
    }

    #[test]
    fn test_token_embedding() {
        let mut capture = TensorCapture::for_layers(&[5]);
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        capture.store("l_out-5".to_string(), 3, 2, data);

        let ct = capture.get_layer(5).unwrap();
        assert_eq!(ct.token_embedding(0), Some(&[1.0, 2.0, 3.0][..]));
        assert_eq!(ct.token_embedding(1), Some(&[4.0, 5.0, 6.0][..]));
        assert_eq!(ct.token_embedding(2), None);
    }

    #[test]
    fn test_captured_layers() {
        let mut capture = TensorCapture::for_layers(&[5, 10, 20]);
        capture.store("l_out-10".to_string(), 2, 1, vec![0.0, 0.0]);
        capture.store("l_out-5".to_string(), 2, 1, vec![0.0, 0.0]);
        assert_eq!(capture.captured_layers(), vec![5, 10]);
    }

    #[test]
    fn test_clear() {
        let mut capture = TensorCapture::for_layers(&[5]);
        capture.store("l_out-5".to_string(), 2, 1, vec![0.0, 0.0]);
        assert_eq!(capture.len(), 1);
        capture.clear();
        assert_eq!(capture.len(), 0);
        assert!(capture.is_empty());
    }

    #[test]
    fn test_non_layer_tensor() {
        let mut capture = TensorCapture::for_names(&["result_norm"]);
        capture.store("result_norm".to_string(), 4, 3, vec![0.0; 12]);
        let ct = capture.get("result_norm").unwrap();
        assert_eq!(ct.layer, None);
        assert_eq!(ct.n_embd(), 4);
        assert_eq!(ct.n_tokens(), 3);
    }

    #[test]
    fn test_layers_filter_ignores_decorated_names() {
        // Derived nodes (e.g. ggml's "<name> (reshaped)") are different tensors
        // and must not be mistaken for the layer output itself.
        let capture = TensorCapture::for_layers(&[13]);
        assert!(!capture.matches("l_out-13 (reshaped)"));
        assert!(!capture.matches("l_out-"));
        assert!(!capture.matches("l_out-abc"));
    }

    #[test]
    fn test_store_overwrites_same_name() {
        let mut capture = TensorCapture::for_layers(&[3]);
        capture.store("l_out-3".to_string(), 2, 1, vec![1.0, 2.0]);
        capture.store("l_out-3".to_string(), 2, 1, vec![3.0, 4.0]);
        assert_eq!(capture.len(), 1);
        assert_eq!(capture.get_layer(3).unwrap().data, vec![3.0, 4.0]);
    }

    #[test]
    fn test_iter_yields_every_entry() {
        let mut capture = TensorCapture::all();
        capture.store("result_norm".to_string(), 1, 1, vec![0.0]);
        capture.store("l_out-1".to_string(), 1, 1, vec![0.0]);
        let mut names: Vec<&str> = capture.iter().map(|(name, _)| name).collect();
        names.sort_unstable();
        assert_eq!(names, vec!["l_out-1", "result_norm"]);
    }

    #[test]
    fn test_token_embedding_empty_tensor() {
        let ct = CapturedTensor {
            name: "empty".to_string(),
            layer: None,
            ne0: 4,
            ne1: 0,
            data: Vec::new(),
        };
        assert_eq!(ct.n_tokens(), 0);
        assert_eq!(ct.token_embedding(0), None);
    }
}