use std::collections::HashMap;
/// A tensor copied out of the compute graph as a flat `f32` buffer.
///
/// `data` is interpreted as `ne1` consecutive rows of `ne0` values each; see
/// [`CapturedTensor::token_embedding`].
#[derive(Debug, Clone)]
pub struct CapturedTensor {
    /// Name under which the tensor was captured (for example `"l_out-13"`).
    pub name: String,
    /// Layer index parsed from an `l_out-<N>` name; `None` for any other name.
    pub layer: Option<usize>,
    /// Size of the first dimension, i.e. the length of one row.
    pub ne0: usize,
    /// Size of the second dimension, i.e. the number of rows.
    pub ne1: usize,
    /// Flattened values; expected to hold `ne0 * ne1` elements.
    pub data: Vec<f32>,
}

impl CapturedTensor {
    /// Returns the length of one row (`ne0`).
    #[inline]
    pub fn n_embd(&self) -> usize {
        self.ne0
    }

    /// Returns the number of rows (`ne1`).
    #[inline]
    pub fn n_tokens(&self) -> usize {
        self.ne1
    }

    /// Returns row `token_idx` of the tensor as a slice of `ne0` values.
    ///
    /// Returns `None` when `token_idx >= ne1`. Because every field is public,
    /// `data` may be shorter than `ne0 * ne1`; in that case (or if the row
    /// offsets would overflow `usize`) this also returns `None` instead of
    /// panicking.
    pub fn token_embedding(&self, token_idx: usize) -> Option<&[f32]> {
        if token_idx >= self.ne1 {
            return None;
        }
        let start = token_idx.checked_mul(self.ne0)?;
        let end = start.checked_add(self.ne0)?;
        // `get` performs the bounds check, so an inconsistent buffer yields `None`.
        self.data.get(start..end)
    }
}
/// Rule deciding which tensor names a [`TensorCapture`] accepts.
#[derive(Debug, Clone)]
enum CaptureFilter {
/// Accept tensors named `l_out-<N>` whose index `N` is in the list.
Layers(Vec<usize>),
/// Accept tensors whose name equals one of the listed names.
Names(Vec<String>),
/// Accept tensors whose name starts with the given prefix.
Prefix(String),
/// Accept every tensor.
All,
}
/// Collects tensors whose names pass a filter, keyed by tensor name.
pub struct TensorCapture {
/// Rule applied to a tensor name to decide whether it is captured.
filter: CaptureFilter,
/// Captured tensors keyed by name; storing the same name again replaces the earlier entry.
captured: HashMap<String, CapturedTensor>,
}
/// Compact `Debug` output: the filter, the number of captured tensors and their
/// names. The tensor data itself is not printed.
impl std::fmt::Debug for TensorCapture {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `HashMap` iteration order is unspecified and differs between runs;
        // sort the names so the output is stable in logs and comparisons.
        let mut keys: Vec<&str> = self.captured.keys().map(String::as_str).collect();
        keys.sort_unstable();
        f.debug_struct("TensorCapture")
            .field("filter", &self.filter)
            .field("captured_count", &self.captured.len())
            .field("captured_keys", &keys)
            .finish()
    }
}
impl TensorCapture {
pub fn for_layers(layer_indices: &[usize]) -> Self {
Self {
filter: CaptureFilter::Layers(layer_indices.to_vec()),
captured: HashMap::new(),
}
}
pub fn for_names(names: &[&str]) -> Self {
Self {
filter: CaptureFilter::Names(names.iter().map(|s| s.to_string()).collect()),
captured: HashMap::new(),
}
}
pub fn for_prefix(prefix: &str) -> Self {
Self {
filter: CaptureFilter::Prefix(prefix.to_string()),
captured: HashMap::new(),
}
}
pub fn all() -> Self {
Self {
filter: CaptureFilter::All,
captured: HashMap::new(),
}
}
pub fn clear(&mut self) {
self.captured.clear();
}
pub fn get(&self, name: &str) -> Option<&CapturedTensor> {
self.captured.get(name)
}
pub fn get_layer(&self, layer_idx: usize) -> Option<&CapturedTensor> {
self.captured.get(&format!("l_out-{layer_idx}"))
}
pub fn has_layer(&self, layer_idx: usize) -> bool {
self.captured.contains_key(&format!("l_out-{layer_idx}"))
}
pub fn len(&self) -> usize {
self.captured.len()
}
pub fn is_empty(&self) -> bool {
self.captured.is_empty()
}
pub fn iter(&self) -> impl Iterator<Item = (&str, &CapturedTensor)> {
self.captured.iter().map(|(k, v)| (k.as_str(), v))
}
pub fn captured_layers(&self) -> Vec<usize> {
let mut layers: Vec<usize> = self
.captured
.values()
.filter_map(|ct| ct.layer)
.collect();
layers.sort_unstable();
layers.dedup();
layers
}
fn matches(&self, name: &str) -> bool {
match &self.filter {
CaptureFilter::Layers(indices) => {
if let Some(suffix) = name.strip_prefix("l_out-") {
if let Ok(idx) = suffix.parse::<usize>() {
return indices.contains(&idx);
}
}
false
}
CaptureFilter::Names(names) => names.iter().any(|n| n == name),
CaptureFilter::Prefix(prefix) => name.starts_with(prefix.as_str()),
CaptureFilter::All => true,
}
}
fn store(&mut self, name: String, ne0: usize, ne1: usize, data: Vec<f32>) {
let layer = name
.strip_prefix("l_out-")
.and_then(|s| s.parse::<usize>().ok());
self.captured.insert(
name.clone(),
CapturedTensor {
name,
layer,
ne0,
ne1,
data,
},
);
}
}
/// C callback that copies tensors accepted by a [`TensorCapture`] into that capture.
///
/// * `ask == true`: only reports whether the tensor's name passes the capture filter.
/// * `ask == false`: copies `ne[0] * ne[1]` `f32` values out of the tensor and stores them
///   under the tensor's name.
///
/// Returns `false` when either pointer is null, when the tensor name is not valid UTF-8,
/// or when the name does not pass the filter. Returns `true` otherwise, including when the
/// tensor dimensions cannot be represented and the copy is skipped.
///
/// NOTE(review): the copy assumes the tensor stores `f32` elements and that its data is
/// fully described by the first two dimensions; neither is checked here — confirm against
/// the tensors this callback is registered for.
///
/// # Safety
///
/// * `t` must be null or point to a valid `ggml_tensor` for the duration of the call.
/// * `user_data` must be null or point to a live `TensorCapture` that is not accessed
///   through any other reference while the callback runs.
pub(crate) unsafe extern "C" fn tensor_capture_callback(
    t: *mut llama_cpp_sys_4::ggml_tensor,
    ask: bool,
    user_data: *mut std::ffi::c_void,
) -> bool {
    if t.is_null() || user_data.is_null() {
        return false;
    }

    // The name lives in a character buffer; use the bytes before the first NUL, or the
    // whole buffer when no terminator is present.
    let raw_name = &(*t).name;
    let name_len = raw_name
        .iter()
        .position(|&b| b == 0)
        .unwrap_or(raw_name.len());
    // SAFETY: `name_len` never exceeds the buffer length and the buffer is part of `*t`,
    // which the caller guarantees to be valid.
    let name_bytes = std::slice::from_raw_parts(raw_name.as_ptr() as *const u8, name_len);
    // Validate instead of assuming UTF-8: building a `&str` from arbitrary bytes is
    // undefined behaviour. A name that is not valid UTF-8 is treated as "not matching".
    let name = match std::str::from_utf8(name_bytes) {
        Ok(name) => name,
        Err(_) => return false,
    };

    // SAFETY: the caller guarantees `user_data` points to an exclusively accessible
    // `TensorCapture`.
    let state = &mut *(user_data as *mut TensorCapture);
    if !state.matches(name) {
        return false;
    }
    if ask {
        return true;
    }

    // Convert the dimensions and sizes with overflow checks so that a negative or huge
    // dimension cannot wrap into a bogus buffer size.
    let ne0 = <usize as std::convert::TryFrom<_>>::try_from((*t).ne[0]).ok();
    let ne1 = <usize as std::convert::TryFrom<_>>::try_from((*t).ne[1]).ok();
    let extent = ne0.zip(ne1).and_then(|(ne0, ne1)| {
        let n_elements = ne0.checked_mul(ne1)?;
        let n_bytes = n_elements.checked_mul(std::mem::size_of::<f32>())?;
        Some((ne0, ne1, n_elements, n_bytes))
    });
    let (ne0, ne1, n_elements, n_bytes) = match extent {
        Some(extent) => extent,
        // Skip the capture but report success: per the ggml eval-callback contract a
        // `false` here would cancel the graph computation — TODO confirm this is the
        // desired policy for malformed tensors.
        None => return true,
    };

    let mut buf = vec![0f32; n_elements];
    llama_cpp_sys_4::ggml_backend_tensor_get(
        t,
        buf.as_mut_ptr() as *mut std::ffi::c_void,
        0,
        n_bytes,
    );
    state.store(name.to_string(), ne0, ne1, buf);
    true
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `capture` accepts every name in `accepted` and rejects every name in
    /// `rejected`.
    fn check_filter(capture: &TensorCapture, accepted: &[&str], rejected: &[&str]) {
        for &name in accepted {
            assert!(capture.matches(name));
        }
        for &name in rejected {
            assert!(!capture.matches(name));
        }
    }

    /// A layer filter accepts only `l_out-<N>` names for the listed indices.
    #[test]
    fn test_for_layers_matching() {
        let by_layer = TensorCapture::for_layers(&[13, 20, 27]);
        check_filter(
            &by_layer,
            &["l_out-13", "l_out-20", "l_out-27"],
            &["l_out-0", "l_out-14", "attn_norm-13", "result_norm"],
        );
    }

    /// A name filter accepts exact names only.
    #[test]
    fn test_for_names_matching() {
        let by_name = TensorCapture::for_names(&["result_norm", "l_out-27"]);
        check_filter(
            &by_name,
            &["result_norm", "l_out-27"],
            &["l_out-13", "result_output"],
        );
    }

    /// A prefix filter accepts any name starting with the prefix.
    #[test]
    fn test_for_prefix_matching() {
        let by_prefix = TensorCapture::for_prefix("attn_out");
        check_filter(
            &by_prefix,
            &["attn_out-0", "attn_out-27"],
            &["attn_norm-0", "l_out-0"],
        );
    }

    /// The catch-all filter accepts everything.
    #[test]
    fn test_all_matching() {
        let everything = TensorCapture::all();
        check_filter(&everything, &["l_out-13", "result_norm", "anything"], &[]);
    }

    /// A stored tensor is retrievable by name and by layer index with its metadata intact.
    #[test]
    fn test_store_and_get() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut capture = TensorCapture::for_layers(&[13]);
        capture.store(String::from("l_out-13"), 3, 2, values.clone());

        assert_eq!(capture.len(), 1);
        assert!(!capture.is_empty());

        let by_name = capture.get("l_out-13").unwrap();
        assert_eq!(by_name.name, "l_out-13");
        assert_eq!(by_name.layer, Some(13));
        assert_eq!(by_name.n_embd(), 3);
        assert_eq!(by_name.n_tokens(), 2);
        assert_eq!(by_name.data, values);

        let by_layer = capture.get_layer(13).unwrap();
        assert_eq!(by_layer.name, by_name.name);

        assert!(capture.has_layer(13));
        assert!(!capture.has_layer(14));
    }

    /// Rows are returned in order, and an out-of-range row index yields `None`.
    #[test]
    fn test_token_embedding() {
        let mut capture = TensorCapture::for_layers(&[5]);
        capture.store(
            String::from("l_out-5"),
            3,
            2,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        );
        let tensor = capture.get_layer(5).unwrap();

        let expected_rows: [&[f32]; 2] = [&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]];
        for (row_idx, row) in expected_rows.iter().enumerate() {
            assert_eq!(tensor.token_embedding(row_idx), Some(*row));
        }
        assert_eq!(tensor.token_embedding(2), None);
    }

    /// Only layers that were actually stored are reported, in ascending order.
    #[test]
    fn test_captured_layers() {
        let mut capture = TensorCapture::for_layers(&[5, 10, 20]);
        for name in ["l_out-10", "l_out-5"].iter() {
            capture.store(name.to_string(), 2, 1, vec![0.0, 0.0]);
        }
        assert_eq!(capture.captured_layers(), vec![5, 10]);
    }

    /// `clear` removes every stored tensor.
    #[test]
    fn test_clear() {
        let mut capture = TensorCapture::for_layers(&[5]);
        capture.store(String::from("l_out-5"), 2, 1, vec![0.0, 0.0]);
        assert_eq!(capture.len(), 1);

        capture.clear();

        assert_eq!(capture.len(), 0);
        assert!(capture.is_empty());
    }

    /// A name without the `l_out-` form is stored with no layer index.
    #[test]
    fn test_non_layer_tensor() {
        let mut capture = TensorCapture::for_names(&["result_norm"]);
        capture.store(String::from("result_norm"), 4, 3, vec![0.0; 12]);

        let tensor = capture.get("result_norm").unwrap();
        assert_eq!(tensor.layer, None);
        assert_eq!(tensor.n_embd(), 4);
        assert_eq!(tensor.n_tokens(), 3);
    }
}