1use std::collections::HashMap;
44
/// A single tensor copied out of the compute graph.
///
/// Values are stored as one flat `f32` buffer in which every run of `ne0`
/// consecutive values forms one row; row `i` corresponds to token `i`
/// (see [`CapturedTensor::token_embedding`]).
#[derive(Debug, Clone)]
pub struct CapturedTensor {
    /// The ggml tensor name, e.g. `"l_out-13"` or `"result_norm"`.
    pub name: String,
    /// Layer index parsed from an `l_out-N` name; `None` for other tensors.
    pub layer: Option<usize>,
    /// Length of each row (the embedding width, see [`CapturedTensor::n_embd`]).
    pub ne0: usize,
    /// Number of rows (the token count, see [`CapturedTensor::n_tokens`]).
    pub ne1: usize,
    /// Flattened values, `ne0` per row; normally `ne0 * ne1` elements long.
    pub data: Vec<f32>,
}

impl CapturedTensor {
    /// Embedding width of each row (`ne0`).
    #[inline]
    pub fn n_embd(&self) -> usize {
        self.ne0
    }

    /// Number of rows / tokens (`ne1`).
    #[inline]
    pub fn n_tokens(&self) -> usize {
        self.ne1
    }

    /// Returns the `ne0` values belonging to token `token_idx`.
    ///
    /// Returns `None` when `token_idx >= ne1`. Because the fields are public,
    /// `data` may also have been set shorter than `ne0 * ne1`; in that case
    /// (or if the offset would overflow `usize`) this returns `None` instead
    /// of panicking on an out-of-range slice.
    pub fn token_embedding(&self, token_idx: usize) -> Option<&[f32]> {
        if token_idx >= self.ne1 {
            return None;
        }
        let start = token_idx.checked_mul(self.ne0)?;
        let end = start.checked_add(self.ne0)?;
        self.data.get(start..end)
    }
}
92
/// Selection rule that decides which graph tensors get captured.
#[derive(Clone, Debug)]
enum CaptureFilter {
    /// Accept every tensor regardless of its name.
    All,
    /// Accept `l_out-N` tensors whose layer index `N` is listed here.
    Layers(Vec<usize>),
    /// Accept tensors whose name equals one of these strings exactly.
    Names(Vec<String>),
    /// Accept tensors whose name begins with this prefix.
    Prefix(String),
}
105
106pub struct TensorCapture {
119 filter: CaptureFilter,
120 captured: HashMap<String, CapturedTensor>,
122}
123
124impl std::fmt::Debug for TensorCapture {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 f.debug_struct("TensorCapture")
127 .field("filter", &self.filter)
128 .field("captured_count", &self.captured.len())
129 .field(
130 "captured_keys",
131 &self.captured.keys().collect::<Vec<_>>(),
132 )
133 .finish()
134 }
135}
136
137impl TensorCapture {
138 pub fn for_layers(layer_indices: &[usize]) -> Self {
151 Self {
152 filter: CaptureFilter::Layers(layer_indices.to_vec()),
153 captured: HashMap::new(),
154 }
155 }
156
157 pub fn for_names(names: &[&str]) -> Self {
165 Self {
166 filter: CaptureFilter::Names(names.iter().map(|s| s.to_string()).collect()),
167 captured: HashMap::new(),
168 }
169 }
170
171 pub fn for_prefix(prefix: &str) -> Self {
181 Self {
182 filter: CaptureFilter::Prefix(prefix.to_string()),
183 captured: HashMap::new(),
184 }
185 }
186
187 pub fn all() -> Self {
192 Self {
193 filter: CaptureFilter::All,
194 captured: HashMap::new(),
195 }
196 }
197
198 pub fn clear(&mut self) {
203 self.captured.clear();
204 }
205
206 pub fn get(&self, name: &str) -> Option<&CapturedTensor> {
208 self.captured.get(name)
209 }
210
211 pub fn get_layer(&self, layer_idx: usize) -> Option<&CapturedTensor> {
215 self.captured.get(&format!("l_out-{layer_idx}"))
216 }
217
218 pub fn has_layer(&self, layer_idx: usize) -> bool {
220 self.captured.contains_key(&format!("l_out-{layer_idx}"))
221 }
222
223 pub fn len(&self) -> usize {
225 self.captured.len()
226 }
227
228 pub fn is_empty(&self) -> bool {
230 self.captured.is_empty()
231 }
232
233 pub fn iter(&self) -> impl Iterator<Item = (&str, &CapturedTensor)> {
235 self.captured.iter().map(|(k, v)| (k.as_str(), v))
236 }
237
238 pub fn captured_layers(&self) -> Vec<usize> {
240 let mut layers: Vec<usize> = self
241 .captured
242 .values()
243 .filter_map(|ct| ct.layer)
244 .collect();
245 layers.sort_unstable();
246 layers.dedup();
247 layers
248 }
249
250 fn matches(&self, name: &str) -> bool {
254 match &self.filter {
255 CaptureFilter::Layers(indices) => {
256 if let Some(suffix) = name.strip_prefix("l_out-") {
257 if let Ok(idx) = suffix.parse::<usize>() {
258 return indices.contains(&idx);
259 }
260 }
261 false
262 }
263 CaptureFilter::Names(names) => names.iter().any(|n| n == name),
264 CaptureFilter::Prefix(prefix) => name.starts_with(prefix.as_str()),
265 CaptureFilter::All => true,
266 }
267 }
268
269 fn store(&mut self, name: String, ne0: usize, ne1: usize, data: Vec<f32>) {
271 let layer = name
272 .strip_prefix("l_out-")
273 .and_then(|s| s.parse::<usize>().ok());
274
275 self.captured.insert(
276 name.clone(),
277 CapturedTensor {
278 name,
279 layer,
280 ne0,
281 ne1,
282 data,
283 },
284 );
285 }
286}
287
288pub(crate) unsafe extern "C" fn tensor_capture_callback(
297 t: *mut llama_cpp_sys_4::ggml_tensor,
298 ask: bool,
299 user_data: *mut std::ffi::c_void,
300) -> bool {
301 if t.is_null() || user_data.is_null() {
302 return false;
303 }
304
305 let name_bytes = &(*t).name;
307 let len = name_bytes
308 .iter()
309 .position(|&b| b == 0)
310 .unwrap_or(name_bytes.len());
311 let name = std::str::from_utf8_unchecked(std::slice::from_raw_parts(
312 name_bytes.as_ptr() as *const u8,
313 len,
314 ));
315
316 let state = &mut *(user_data as *mut TensorCapture);
317
318 if !state.matches(name) {
319 return false;
320 }
321
322 if ask {
323 return true;
324 }
325
326 let ne0 = (*t).ne[0] as usize;
328 let ne1 = (*t).ne[1] as usize;
329 let n_elements = ne0 * ne1;
330
331 let mut buf = vec![0f32; n_elements];
332 llama_cpp_sys_4::ggml_backend_tensor_get(
333 t,
334 buf.as_mut_ptr() as *mut std::ffi::c_void,
335 0,
336 n_elements * std::mem::size_of::<f32>(),
337 );
338
339 state.store(name.to_string(), ne0, ne1, buf);
340
341 true
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `[1.0, 2.0, ..., n]` as `f32` values.
    fn ramp(n: usize) -> Vec<f32> {
        (1..=n).map(|v| v as f32).collect()
    }

    #[test]
    fn test_for_layers_matching() {
        let capture = TensorCapture::for_layers(&[13, 20, 27]);
        for name in ["l_out-13", "l_out-20", "l_out-27"] {
            assert!(capture.matches(name), "{name} should match");
        }
        for name in ["l_out-0", "l_out-14", "attn_norm-13", "result_norm"] {
            assert!(!capture.matches(name), "{name} should not match");
        }
    }

    #[test]
    fn test_for_names_matching() {
        let capture = TensorCapture::for_names(&["result_norm", "l_out-27"]);
        let cases = [
            ("result_norm", true),
            ("l_out-27", true),
            ("l_out-13", false),
            ("result_output", false),
        ];
        for (name, expected) in cases {
            assert_eq!(capture.matches(name), expected, "name = {name}");
        }
    }

    #[test]
    fn test_for_prefix_matching() {
        let capture = TensorCapture::for_prefix("attn_out");
        let cases = [
            ("attn_out-0", true),
            ("attn_out-27", true),
            ("attn_norm-0", false),
            ("l_out-0", false),
        ];
        for (name, expected) in cases {
            assert_eq!(capture.matches(name), expected, "name = {name}");
        }
    }

    #[test]
    fn test_all_matching() {
        let capture = TensorCapture::all();
        for name in ["l_out-13", "result_norm", "anything"] {
            assert!(capture.matches(name), "{name} should match");
        }
    }

    #[test]
    fn test_store_and_get() {
        let expected = ramp(6);
        let mut capture = TensorCapture::for_layers(&[13]);
        capture.store(String::from("l_out-13"), 3, 2, expected.clone());

        assert_eq!(capture.len(), 1);
        assert!(!capture.is_empty());

        let by_name = capture.get("l_out-13").expect("l_out-13 should be stored");
        assert_eq!(by_name.name, "l_out-13");
        assert_eq!(by_name.layer, Some(13));
        assert_eq!((by_name.n_embd(), by_name.n_tokens()), (3, 2));
        assert_eq!(by_name.data, expected);

        let by_layer = capture.get_layer(13).expect("layer 13 should be stored");
        assert_eq!(by_layer.name, by_name.name);
        assert!(capture.has_layer(13));
        assert!(!capture.has_layer(14));
    }

    #[test]
    fn test_token_embedding() {
        let mut capture = TensorCapture::for_layers(&[5]);
        capture.store(String::from("l_out-5"), 3, 2, ramp(6));

        let ct = capture.get_layer(5).expect("layer 5 should be stored");
        let expected: [Option<&[f32]>; 3] = [
            Some(&[1.0, 2.0, 3.0][..]),
            Some(&[4.0, 5.0, 6.0][..]),
            None,
        ];
        for (idx, want) in expected.iter().enumerate() {
            assert_eq!(ct.token_embedding(idx), *want, "token {idx}");
        }
    }

    #[test]
    fn test_captured_layers() {
        let mut capture = TensorCapture::for_layers(&[5, 10, 20]);
        for name in ["l_out-10", "l_out-5"] {
            capture.store(name.to_string(), 2, 1, vec![0.0; 2]);
        }
        assert_eq!(capture.captured_layers(), [5, 10]);
    }

    #[test]
    fn test_clear() {
        let mut capture = TensorCapture::for_layers(&[5]);
        capture.store(String::from("l_out-5"), 2, 1, vec![0.0; 2]);
        assert_eq!(capture.len(), 1);

        capture.clear();
        assert_eq!(capture.len(), 0);
        assert!(capture.is_empty());
    }

    #[test]
    fn test_non_layer_tensor() {
        let mut capture = TensorCapture::for_names(&["result_norm"]);
        capture.store(String::from("result_norm"), 4, 3, vec![0.0; 12]);

        let ct = capture.get("result_norm").expect("result_norm should be stored");
        assert!(ct.layer.is_none());
        assert_eq!((ct.n_embd(), ct.n_tokens()), (4, 3));
    }
}