Module tensor_capture


Capture intermediate tensor outputs during crate::LlamaContext::decode.

llama.cpp builds a computation graph for each forward pass. Every node has a string name — for transformer blocks the layer output is typically "l_out-{N}" (e.g. "l_out-13"), attention norms are "attn_norm-{N}", and the final norm is "result_norm".
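The naming scheme above can be matched with plain string handling. The following is a minimal sketch, not part of this crate: `layer_index` is a hypothetical helper that extracts N from a name of the form "{prefix}-{N}".

```rust
/// Hypothetical helper (not a crate API): extract the layer index from a
/// graph node name such as "l_out-13" or "attn_norm-4".
/// Returns None if the name does not have the form "{prefix}-{N}".
fn layer_index(name: &str, prefix: &str) -> Option<usize> {
    name.strip_prefix(prefix)? // drop e.g. "l_out"
        .strip_prefix('-')?    // drop the separator
        .parse()               // the remainder must be a number
        .ok()
}
```

For example, `layer_index("l_out-13", "l_out")` yields `Some(13)`, while a name without a numeric suffix, such as "result_norm", yields `None`.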

The graph evaluation callback (cb_eval) runs in two phases for each node:

Phase   ask     Behaviour
Ask     true    Return true to request a copy of this tensor’s data.
Data    false   Tensor is computed; data is copied via ggml_backend_tensor_get.

TensorCapture implements that callback and stores matching tensors in a HashMap you can read after decode() finishes.
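To make the two-phase protocol concrete, here is a self-contained simulation. It does not use llama.cpp or this crate: `Node`, `MockCapture`, and `run_graph` are hypothetical stand-ins, and the real callback receives a ggml tensor pointer rather than a slice.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a graph node (not a crate type).
struct Node<'a> {
    name: &'a str,
    data: &'a [f32],
}

/// Hypothetical capture: the names we want and the copies made so far.
struct MockCapture {
    wanted: Vec<String>,
    captured: HashMap<String, Vec<f32>>,
}

impl MockCapture {
    /// Ask phase (ask = true): report whether this tensor is wanted.
    /// Data phase (ask = false): the tensor is computed, so copy it out.
    fn on_node(&mut self, node: &Node<'_>, ask: bool) -> bool {
        let matches = self.wanted.iter().any(|w| w == node.name);
        if ask {
            return matches;
        }
        if matches {
            self.captured
                .insert(node.name.to_string(), node.data.to_vec());
        }
        true // assumption: returning true lets graph evaluation continue
    }
}

/// Simulated scheduler: ask first, deliver data only if it was requested.
fn run_graph(capture: &mut MockCapture, nodes: &[Node<'_>]) {
    for node in nodes {
        if capture.on_node(node, true) {
            capture.on_node(node, false);
        }
    }
}
```

In this sketch, the ask phase means tensors that do not match the filter are never copied.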

§Typical use cases

  • Layer probing — inspect hidden states at specific depths.
  • EAGLE / distillation — read draft-model anchor layers (see examples/eagle).
  • Debugging — dump norms or attention outputs with TensorCapture::for_prefix.

§Setup

  1. Build a TensorCapture with the filter you need (TensorCapture::for_layers is the common case).
  2. Pass it to LlamaContextParams::with_tensor_capture. The capture must outlive the LlamaContext.
  3. Run LlamaContext::decode as usual.
  4. Read CapturedTensor values via TensorCapture::get_layer, TensorCapture::get, or TensorCapture::iter.

Call TensorCapture::clear before reusing the same capture on another batch.

§Example

use llama_cpp_4::prelude::*;
use std::num::NonZeroU32;

fn main() {
    let backend = LlamaBackend::init().unwrap();
    let model = LlamaModel::load_from_file(
        &backend,
        "model.gguf",
        &LlamaModelParams::default(),
    )
    .unwrap();

    let mut capture = TensorCapture::for_layers(&[13, 20, 27]);
    let ctx_params = LlamaContextParams::default()
        .with_n_ctx(NonZeroU32::new(512))
        .with_tensor_capture(&mut capture);
    let mut ctx = model.new_context(&backend, ctx_params).unwrap();

    let tokens = model.str_to_token("Hello", AddBos::Always).unwrap();
    let mut batch = LlamaBatch::new(512, 1);
    for (i, &tok) in tokens.iter().enumerate() {
        batch
            .add(tok, i as i32, &[0], i == tokens.len() - 1)
            .unwrap();
    }
    ctx.decode(&mut batch).unwrap();

    for &layer in &[13, 20, 27] {
        if let Some(t) = capture.get_layer(layer) {
            println!(
                "l_out-{layer}: {} tokens × {} dims",
                t.n_tokens(),
                t.n_embd()
            );
            if let Some(vec) = t.token_embedding(0) {
                println!("  first token, first 3 dims: {:?}", &vec[..3.min(vec.len())]);
            }
        }
    }
}

§Tensor layout

Each CapturedTensor stores a flat f32 buffer in which element (token_idx, dim_idx) lives at data[token_idx * n_embd + dim_idx] (ggml row-major: ne0 = embedding dim, ne1 = token count), so the values for one token are contiguous. Use CapturedTensor::token_embedding to slice one row.
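The indexing rule can be sketched on a plain slice. The helpers below are hypothetical and only illustrate the layout; whether CapturedTensor::token_embedding is implemented this way is an assumption.

```rust
/// Hypothetical helper: the row of `n_embd` values for one token in a
/// flat buffer laid out as data[token_idx * n_embd + dim_idx].
/// Returns None if the token index is out of range.
fn token_row(data: &[f32], n_embd: usize, token_idx: usize) -> Option<&[f32]> {
    let start = token_idx.checked_mul(n_embd)?;
    let end = start.checked_add(n_embd)?;
    data.get(start..end)
}

/// Hypothetical helper: a single element at (token_idx, dim_idx).
fn element(data: &[f32], n_embd: usize, token_idx: usize, dim_idx: usize) -> Option<f32> {
    token_row(data, n_embd, token_idx)?.get(dim_idx).copied()
}
```

With two tokens and n_embd = 3, the buffer [0.0, 1.0, 2.0, 3.0, 4.0, 5.0] holds token 0 in the first three values and token 1 in the last three.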

Structs§

CapturedTensor
A single tensor copied out of the decode graph.
TensorCapture
Captures intermediate tensors during crate::LlamaContext::decode.