llama_cpp_4/context/tensor_capture.rs
1//! Capture intermediate tensor outputs during [`crate::LlamaContext::decode`].
2//!
3//! llama.cpp builds a computation graph for each forward pass. Every node has a
4//! string name — for transformer blocks the layer output is typically
5//! `"l_out-{N}"` (e.g. `"l_out-13"`), attention norms are `"attn_norm-{N}"`, and
6//! the final norm is `"result_norm"`.
7//!
8//! The graph evaluation callback (`cb_eval`) runs in two phases for each node:
9//!
10//! | Phase | `ask` | Behaviour |
11//! |---|---|---|
12//! | Ask | `true` | Return `true` to request a copy of this tensor's data. |
//! | Data | `false` | Tensor is computed; data is copied via `ggml_backend_tensor_get`. Return `true` to continue — returning `false` cancels the rest of the graph compute. |
14//!
15//! [`TensorCapture`] implements that callback and stores matching tensors in a
16//! [`HashMap`] you can read after `decode()` finishes.
17//!
18//! # Typical use cases
19//!
20//! - **Layer probing** — inspect hidden states at specific depths.
21//! - **EAGLE / distillation** — read draft-model anchor layers (see `examples/eagle`).
22//! - **Debugging** — dump norms or attention outputs with [`TensorCapture::for_prefix`].
23//!
24//! # Setup
25//!
26//! 1. Build a [`TensorCapture`] with the filter you need ([`TensorCapture::for_layers`]
27//! is the common case).
28//! 2. Pass it to [`LlamaContextParams::with_tensor_capture`](crate::LlamaContextParams::with_tensor_capture). The capture must
29//! **outlive** the [`LlamaContext`](crate::LlamaContext).
30//! 3. Run [`LlamaContext::decode`](crate::LlamaContext::decode) as usual.
31//! 4. Read [`CapturedTensor`] values via [`TensorCapture::get_layer`],
32//! [`TensorCapture::get`], or [`TensorCapture::iter`].
33//!
34//! Call [`TensorCapture::clear`](crate::TensorCapture::clear) before reusing the same capture on another batch.
35//!
36//! # Example
37//!
38//! ```no_run
39//! use llama_cpp_4::prelude::*;
40//! use std::num::NonZeroU32;
41//!
42//! fn main() {
43//! let backend = LlamaBackend::init().unwrap();
44//! let model = LlamaModel::load_from_file(
45//! &backend,
46//! "model.gguf",
47//! &LlamaModelParams::default(),
48//! )
49//! .unwrap();
50//!
51//! let mut capture = TensorCapture::for_layers(&[13, 20, 27]);
52//! let ctx_params = LlamaContextParams::default()
53//! .with_n_ctx(NonZeroU32::new(512))
54//! .with_tensor_capture(&mut capture);
55//! let mut ctx = model.new_context(&backend, ctx_params).unwrap();
56//!
57//! let tokens = model.str_to_token("Hello", AddBos::Always).unwrap();
58//! let mut batch = LlamaBatch::new(512, 1);
59//! for (i, &tok) in tokens.iter().enumerate() {
60//! batch
61//! .add(tok, i as i32, &[0], i == tokens.len() - 1)
62//! .unwrap();
63//! }
64//! ctx.decode(&mut batch).unwrap();
65//!
66//! for &layer in &[13, 20, 27] {
67//! if let Some(t) = capture.get_layer(layer) {
68//! println!(
69//! "l_out-{layer}: {} tokens × {} dims",
70//! t.n_tokens(),
71//! t.n_embd()
72//! );
73//! if let Some(vec) = t.token_embedding(0) {
74//! println!(" first token, first 3 dims: {:?}", &vec[..3.min(vec.len())]);
75//! }
76//! }
77//! }
78//! }
79//! ```
80//!
81//! # Tensor layout
82//!
83//! Each [`CapturedTensor`] stores a flat `f32` buffer with
84//! `data[token_idx * n_embd + dim_idx]` (ggml row-major: `ne0` = embedding dim,
85//! `ne1` = token count). Use [`CapturedTensor::token_embedding`] to slice one row.
86
87use std::collections::HashMap;
88
/// A single tensor copied out of the decode graph.
///
/// Produced by [`TensorCapture`] after a successful [`crate::LlamaContext::decode`].
/// For layer outputs (`"l_out-N"`), [`Self::layer`] is set to `N`.
#[derive(Debug, Clone)]
pub struct CapturedTensor {
    /// Graph node name (e.g. `"l_out-13"`, `"result_norm"`).
    pub name: String,
    /// Layer index when `name` is `"l_out-{N}"`, otherwise `None`.
    pub layer: Option<usize>,
    /// First dimension (typically `n_embd` / hidden size): the length of one row.
    pub ne0: usize,
    /// Second dimension (typically the number of token positions in the batch):
    /// the number of rows.
    pub ne1: usize,
    /// Flattened `ne0 * ne1` values in ggml row-major order.
    ///
    /// Index as `data[token_idx * ne0 + dim_idx]`.
    pub data: Vec<f32>,
}

impl CapturedTensor {
    /// Number of embedding dimensions (alias for [`Self::ne0`]).
    #[inline]
    #[must_use]
    pub fn n_embd(&self) -> usize {
        self.ne0
    }

    /// Number of token positions (alias for [`Self::ne1`]).
    #[inline]
    #[must_use]
    pub fn n_tokens(&self) -> usize {
        self.ne1
    }

    /// Hidden-state vector (one row of [`Self::data`]) for one token index.
    ///
    /// Returns `None` when `token_idx >= n_tokens()`. It also returns `None`
    /// when the requested row does not lie inside `data`, which can only happen
    /// if the public fields were modified so that `data.len() < ne0 * ne1`.
    #[must_use]
    pub fn token_embedding(&self, token_idx: usize) -> Option<&[f32]> {
        if token_idx >= self.ne1 {
            return None;
        }
        // The fields are public and may be inconsistent, so use checked
        // arithmetic and `get` rather than indexing that could panic.
        let start = token_idx.checked_mul(self.ne0)?;
        let end = start.checked_add(self.ne0)?;
        self.data.get(start..end)
    }
}
136
/// Rule deciding which graph nodes a [`TensorCapture`] copies out.
///
/// This type is private; callers choose a variant through the
/// `TensorCapture` constructors.
#[derive(Clone, Debug)]
enum CaptureFilter {
    /// Layer outputs: a node named `"l_out-{N}"` matches when `N` is listed here.
    Layers(Vec<usize>),
    /// A node matches only if its name equals one of these strings exactly.
    Names(Vec<String>),
    /// A node matches if its name begins with this string (e.g. `"attn_out"`).
    Prefix(String),
    /// Every node matches; memory grows with the whole graph, so reserve this for debugging.
    All,
}
149
150/// Captures intermediate tensors during [`crate::LlamaContext::decode`].
151///
152/// Attach with [`LlamaContextParams::with_tensor_capture`](crate::LlamaContextParams::with_tensor_capture) before creating the
153/// context. The same instance can be reused across decodes if you call
154/// [`Self::clear`] between passes.
155///
156/// # Lifetime
157///
158/// The capture must outlive the [`crate::LlamaContext`] it is wired into;
159/// [`LlamaContextParams::with_tensor_capture`](crate::LlamaContextParams::with_tensor_capture) takes `&mut TensorCapture` to
160/// enforce this at compile time.
161pub struct TensorCapture {
162 filter: CaptureFilter,
163 captured: HashMap<String, CapturedTensor>,
164}
165
166impl std::fmt::Debug for TensorCapture {
167 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168 f.debug_struct("TensorCapture")
169 .field("filter", &self.filter)
170 .field("captured_count", &self.captured.len())
171 .field("captured_keys", &self.captured.keys().collect::<Vec<_>>())
172 .finish()
173 }
174}
175
176impl TensorCapture {
177 /// Capture transformer layer outputs `"l_out-{N}"` for the given indices.
178 ///
179 /// This is the usual choice for hidden-state extraction. EAGLE-3 draft models
180 /// often use three layers at ~50%, 75%, and 100% depth — e.g. `[13, 20, 27]`
181 /// on a 28-layer model.
182 #[must_use]
183 pub fn for_layers(layer_indices: &[usize]) -> Self {
184 Self {
185 filter: CaptureFilter::Layers(layer_indices.to_vec()),
186 captured: HashMap::new(),
187 }
188 }
189
190 /// Capture tensors whose graph names match exactly.
191 ///
192 /// Example names: `"result_norm"`, `"l_out-27"`.
193 #[must_use]
194 pub fn for_names(names: &[&str]) -> Self {
195 Self {
196 filter: CaptureFilter::Names(
197 names.iter().map(std::string::ToString::to_string).collect(),
198 ),
199 captured: HashMap::new(),
200 }
201 }
202
203 /// Capture every tensor whose name starts with `prefix`.
204 ///
205 /// Useful for families like `"attn_out-*"` or `"attn_norm-*"`.
206 #[must_use]
207 pub fn for_prefix(prefix: &str) -> Self {
208 Self {
209 filter: CaptureFilter::Prefix(prefix.to_string()),
210 captured: HashMap::new(),
211 }
212 }
213
214 /// Capture **all** graph nodes.
215 ///
216 /// Warning: memory use scales with model size and sequence length. Prefer
217 /// [`Self::for_layers`] or [`Self::for_names`] in production code.
218 #[must_use]
219 pub fn all() -> Self {
220 Self {
221 filter: CaptureFilter::All,
222 captured: HashMap::new(),
223 }
224 }
225
226 /// Drop captured tensors but keep the filter (safe to call before another decode).
227 pub fn clear(&mut self) {
228 self.captured.clear();
229 }
230
231 /// Lookup by full graph name (e.g. `"l_out-13"`).
232 #[must_use]
233 pub fn get(&self, name: &str) -> Option<&CapturedTensor> {
234 self.captured.get(name)
235 }
236
237 /// Lookup a layer output (`"l_out-{layer_idx}"`).
238 #[must_use]
239 pub fn get_layer(&self, layer_idx: usize) -> Option<&CapturedTensor> {
240 self.captured.get(&format!("l_out-{layer_idx}"))
241 }
242
243 /// Whether `"l_out-{layer_idx}"` was captured in the last decode.
244 #[must_use]
245 pub fn has_layer(&self, layer_idx: usize) -> bool {
246 self.captured.contains_key(&format!("l_out-{layer_idx}"))
247 }
248
249 /// Number of tensors stored from the most recent decode.
250 #[must_use]
251 pub fn len(&self) -> usize {
252 self.captured.len()
253 }
254
255 /// `true` when [`Self::len`] is zero.
256 #[must_use]
257 pub fn is_empty(&self) -> bool {
258 self.captured.is_empty()
259 }
260
261 /// Iterate `(name, tensor)` pairs from the last decode.
262 pub fn iter(&self) -> impl Iterator<Item = (&str, &CapturedTensor)> {
263 self.captured.iter().map(|(k, v)| (k.as_str(), v))
264 }
265
266 /// Sorted layer indices present among captured `"l_out-*"` tensors.
267 #[must_use]
268 pub fn captured_layers(&self) -> Vec<usize> {
269 let mut layers: Vec<usize> = self.captured.values().filter_map(|ct| ct.layer).collect();
270 layers.sort_unstable();
271 layers.dedup();
272 layers
273 }
274
275 fn matches(&self, name: &str) -> bool {
276 match &self.filter {
277 CaptureFilter::Layers(indices) => {
278 if let Some(suffix) = name.strip_prefix("l_out-") {
279 if let Ok(idx) = suffix.parse::<usize>() {
280 return indices.contains(&idx);
281 }
282 }
283 false
284 }
285 CaptureFilter::Names(names) => names.iter().any(|n| n == name),
286 CaptureFilter::Prefix(prefix) => name.starts_with(prefix.as_str()),
287 CaptureFilter::All => true,
288 }
289 }
290
291 fn store(&mut self, name: String, ne0: usize, ne1: usize, data: Vec<f32>) {
292 let layer = name
293 .strip_prefix("l_out-")
294 .and_then(|s| s.parse::<usize>().ok());
295
296 self.captured.insert(
297 name.clone(),
298 CapturedTensor {
299 name,
300 layer,
301 ne0,
302 ne1,
303 data,
304 },
305 );
306 }
307}
308
309/// `cb_eval` callback installed by [`LlamaContextParams::with_tensor_capture`](crate::LlamaContextParams::with_tensor_capture).
310///
311/// # Safety
312///
313/// `user_data` must point to a live [`TensorCapture`] for the context lifetime.
314pub(crate) unsafe extern "C" fn tensor_capture_callback(
315 t: *mut llama_cpp_sys_4::ggml_tensor,
316 ask: bool,
317 user_data: *mut std::ffi::c_void,
318) -> bool {
319 if t.is_null() || user_data.is_null() {
320 return false;
321 }
322
323 let name_bytes = &(*t).name;
324 let len = name_bytes
325 .iter()
326 .position(|&b| b == 0)
327 .unwrap_or(name_bytes.len());
328 let name = std::str::from_utf8_unchecked(std::slice::from_raw_parts(
329 name_bytes.as_ptr().cast::<u8>(),
330 len,
331 ));
332
333 let state = &mut *user_data.cast::<TensorCapture>();
334
335 if !state.matches(name) {
336 return false;
337 }
338
339 if ask {
340 return true;
341 }
342
343 let ne0 = usize::try_from((*t).ne[0]).expect("tensor ne[0] must be non-negative");
344 let ne1 = usize::try_from((*t).ne[1]).expect("tensor ne[1] must be non-negative");
345 let n_elements = ne0 * ne1;
346
347 let mut buf = vec![0f32; n_elements];
348 llama_cpp_sys_4::ggml_backend_tensor_get(
349 t,
350 buf.as_mut_ptr().cast::<std::ffi::c_void>(),
351 0,
352 n_elements * std::mem::size_of::<f32>(),
353 );
354
355 state.store(name.to_string(), ne0, ne1, buf);
356
357 true
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `capture` accepts every name in `hits` and rejects every name in `misses`.
    fn assert_filter(capture: &TensorCapture, hits: &[&str], misses: &[&str]) {
        for &name in hits {
            assert!(capture.matches(name), "`{name}` should match");
        }
        for &name in misses {
            assert!(!capture.matches(name), "`{name}` should not match");
        }
    }

    #[test]
    fn test_for_layers_matching() {
        assert_filter(
            &TensorCapture::for_layers(&[13, 20, 27]),
            &["l_out-13", "l_out-20", "l_out-27"],
            &["l_out-0", "l_out-14", "attn_norm-13", "result_norm"],
        );
    }

    #[test]
    fn test_for_names_matching() {
        assert_filter(
            &TensorCapture::for_names(&["result_norm", "l_out-27"]),
            &["result_norm", "l_out-27"],
            &["l_out-13", "result_output"],
        );
    }

    #[test]
    fn test_for_prefix_matching() {
        assert_filter(
            &TensorCapture::for_prefix("attn_out"),
            &["attn_out-0", "attn_out-27"],
            &["attn_norm-0", "l_out-0"],
        );
    }

    #[test]
    fn test_all_matching() {
        assert_filter(
            &TensorCapture::all(),
            &["l_out-13", "result_norm", "anything"],
            &[],
        );
    }

    #[test]
    fn test_store_and_get() {
        let expected = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut capture = TensorCapture::for_layers(&[13]);
        capture.store("l_out-13".to_string(), 3, 2, expected.clone());

        assert_eq!(capture.len(), 1);
        assert!(!capture.is_empty());

        let by_name = capture.get("l_out-13").expect("stored under its full name");
        assert_eq!(by_name.name, "l_out-13");
        assert_eq!(by_name.layer, Some(13));
        assert_eq!((by_name.n_embd(), by_name.n_tokens()), (3, 2));
        assert_eq!(by_name.data, expected);

        let by_layer = capture.get_layer(13).expect("reachable by layer index");
        assert_eq!(by_layer.name, by_name.name);
        assert!(capture.has_layer(13));
        assert!(!capture.has_layer(14));
    }

    #[test]
    fn test_token_embedding() {
        let mut capture = TensorCapture::for_layers(&[5]);
        capture.store("l_out-5".to_string(), 3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        let tensor = capture.get_layer(5).unwrap();
        let rows: Vec<Option<&[f32]>> = (0..3).map(|i| tensor.token_embedding(i)).collect();
        assert_eq!(
            rows,
            vec![Some(&[1.0, 2.0, 3.0][..]), Some(&[4.0, 5.0, 6.0][..]), None]
        );
    }

    #[test]
    fn test_captured_layers() {
        let mut capture = TensorCapture::for_layers(&[5, 10, 20]);
        for name in ["l_out-10", "l_out-5"].iter() {
            capture.store((*name).to_string(), 2, 1, vec![0.0; 2]);
        }
        assert_eq!(capture.captured_layers(), vec![5, 10]);
    }

    #[test]
    fn test_clear() {
        let mut capture = TensorCapture::for_layers(&[5]);
        capture.store("l_out-5".to_string(), 2, 1, vec![0.0; 2]);
        assert_eq!(capture.len(), 1);

        capture.clear();
        assert!(capture.is_empty());
        assert_eq!(capture.len(), 0);
    }

    #[test]
    fn test_non_layer_tensor() {
        let mut capture = TensorCapture::for_names(&["result_norm"]);
        capture.store("result_norm".to_string(), 4, 3, vec![0.0; 12]);

        let tensor = capture.get("result_norm").unwrap();
        assert!(tensor.layer.is_none());
        assert_eq!(tensor.n_embd(), 4);
        assert_eq!(tensor.n_tokens(), 3);
    }
}