1use std::collections::HashMap;
44
/// A tensor copied out of a compute graph, with its name and 2-D extent.
#[derive(Debug, Clone)]
pub struct CapturedTensor {
    /// Name of the tensor the data was copied from.
    pub name: String,
    /// Layer index parsed from an `l_out-<N>` name; `None` for any other name.
    pub layer: Option<usize>,
    /// Extent of the first dimension, i.e. the number of values per token row.
    pub ne0: usize,
    /// Extent of the second dimension, i.e. the number of token rows.
    pub ne1: usize,
    /// Flat values; token `i` occupies `data[i * ne0..(i + 1) * ne0]`.
    pub data: Vec<f32>,
}

impl CapturedTensor {
    /// Returns the number of values per token (`ne0`).
    #[inline]
    #[must_use]
    pub fn n_embd(&self) -> usize {
        self.ne0
    }

    /// Returns the number of token rows (`ne1`).
    #[inline]
    #[must_use]
    pub fn n_tokens(&self) -> usize {
        self.ne1
    }

    /// Returns the `ne0` values belonging to token `token_idx`.
    ///
    /// Returns `None` when `token_idx >= ne1`. Because every field is public and
    /// may be set inconsistently, `None` is also returned (instead of panicking)
    /// when `data` is too short for the requested row or when the index
    /// arithmetic would overflow `usize`.
    #[must_use]
    pub fn token_embedding(&self, token_idx: usize) -> Option<&[f32]> {
        if token_idx >= self.ne1 {
            return None;
        }
        // Checked arithmetic: `ne0` is caller-controlled, so the products may overflow.
        let start = token_idx.checked_mul(self.ne0)?;
        let end = start.checked_add(self.ne0)?;
        // `get` yields `None` for an out-of-range row rather than panicking.
        self.data.get(start..end)
    }
}
95
/// Rule deciding which tensor names a `TensorCapture` accepts.
#[derive(Debug, Clone)]
enum CaptureFilter {
    /// Accept names of the form `l_out-<N>` whose parsed index `N` is in the list.
    Layers(Vec<usize>),
    /// Accept names that are exactly equal to one of the listed names.
    Names(Vec<String>),
    /// Accept names that start with the given prefix.
    Prefix(String),
    /// Accept every name.
    All,
}
108
/// Collects tensors whose names pass a filter, keyed by tensor name.
pub struct TensorCapture {
    /// Rule deciding which tensor names are accepted.
    filter: CaptureFilter,
    /// Stored tensors, keyed by tensor name.
    captured: HashMap<String, CapturedTensor>,
}
126
127impl std::fmt::Debug for TensorCapture {
128 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129 f.debug_struct("TensorCapture")
130 .field("filter", &self.filter)
131 .field("captured_count", &self.captured.len())
132 .field("captured_keys", &self.captured.keys().collect::<Vec<_>>())
133 .finish()
134 }
135}
136
137impl TensorCapture {
138 #[must_use]
151 pub fn for_layers(layer_indices: &[usize]) -> Self {
152 Self {
153 filter: CaptureFilter::Layers(layer_indices.to_vec()),
154 captured: HashMap::new(),
155 }
156 }
157
158 #[must_use]
166 pub fn for_names(names: &[&str]) -> Self {
167 Self {
168 filter: CaptureFilter::Names(
169 names.iter().map(std::string::ToString::to_string).collect(),
170 ),
171 captured: HashMap::new(),
172 }
173 }
174
175 #[must_use]
185 pub fn for_prefix(prefix: &str) -> Self {
186 Self {
187 filter: CaptureFilter::Prefix(prefix.to_string()),
188 captured: HashMap::new(),
189 }
190 }
191
192 #[must_use]
197 pub fn all() -> Self {
198 Self {
199 filter: CaptureFilter::All,
200 captured: HashMap::new(),
201 }
202 }
203
204 pub fn clear(&mut self) {
209 self.captured.clear();
210 }
211
212 #[must_use]
214 pub fn get(&self, name: &str) -> Option<&CapturedTensor> {
215 self.captured.get(name)
216 }
217
218 #[must_use]
222 pub fn get_layer(&self, layer_idx: usize) -> Option<&CapturedTensor> {
223 self.captured.get(&format!("l_out-{layer_idx}"))
224 }
225
226 #[must_use]
228 pub fn has_layer(&self, layer_idx: usize) -> bool {
229 self.captured.contains_key(&format!("l_out-{layer_idx}"))
230 }
231
232 #[must_use]
234 pub fn len(&self) -> usize {
235 self.captured.len()
236 }
237
238 #[must_use]
240 pub fn is_empty(&self) -> bool {
241 self.captured.is_empty()
242 }
243
244 pub fn iter(&self) -> impl Iterator<Item = (&str, &CapturedTensor)> {
246 self.captured.iter().map(|(k, v)| (k.as_str(), v))
247 }
248
249 #[must_use]
251 pub fn captured_layers(&self) -> Vec<usize> {
252 let mut layers: Vec<usize> = self.captured.values().filter_map(|ct| ct.layer).collect();
253 layers.sort_unstable();
254 layers.dedup();
255 layers
256 }
257
258 fn matches(&self, name: &str) -> bool {
262 match &self.filter {
263 CaptureFilter::Layers(indices) => {
264 if let Some(suffix) = name.strip_prefix("l_out-") {
265 if let Ok(idx) = suffix.parse::<usize>() {
266 return indices.contains(&idx);
267 }
268 }
269 false
270 }
271 CaptureFilter::Names(names) => names.iter().any(|n| n == name),
272 CaptureFilter::Prefix(prefix) => name.starts_with(prefix.as_str()),
273 CaptureFilter::All => true,
274 }
275 }
276
277 fn store(&mut self, name: String, ne0: usize, ne1: usize, data: Vec<f32>) {
279 let layer = name
280 .strip_prefix("l_out-")
281 .and_then(|s| s.parse::<usize>().ok());
282
283 self.captured.insert(
284 name.clone(),
285 CapturedTensor {
286 name,
287 layer,
288 ne0,
289 ne1,
290 data,
291 },
292 );
293 }
294}
295
296pub(crate) unsafe extern "C" fn tensor_capture_callback(
305 t: *mut llama_cpp_sys_4::ggml_tensor,
306 ask: bool,
307 user_data: *mut std::ffi::c_void,
308) -> bool {
309 if t.is_null() || user_data.is_null() {
310 return false;
311 }
312
313 let name_bytes = &(*t).name;
315 let len = name_bytes
316 .iter()
317 .position(|&b| b == 0)
318 .unwrap_or(name_bytes.len());
319 let name = std::str::from_utf8_unchecked(std::slice::from_raw_parts(
320 name_bytes.as_ptr().cast::<u8>(),
321 len,
322 ));
323
324 let state = &mut *user_data.cast::<TensorCapture>();
325
326 if !state.matches(name) {
327 return false;
328 }
329
330 if ask {
331 return true;
332 }
333
334 let ne0 = (*t).ne[0] as usize;
336 let ne1 = (*t).ne[1] as usize;
337 let n_elements = ne0 * ne1;
338
339 let mut buf = vec![0f32; n_elements];
340 llama_cpp_sys_4::ggml_backend_tensor_get(
341 t,
342 buf.as_mut_ptr().cast::<std::ffi::c_void>(),
343 0,
344 n_elements * std::mem::size_of::<f32>(),
345 );
346
347 state.store(name.to_string(), ne0, ne1, buf);
348
349 true
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `capture` accepts every name in `accepted` and rejects every
    /// name in `rejected`.
    fn assert_filter(capture: &TensorCapture, accepted: &[&str], rejected: &[&str]) {
        for &name in accepted {
            assert!(capture.matches(name));
        }
        for &name in rejected {
            assert!(!capture.matches(name));
        }
    }

    /// A layer filter accepts only `l_out-<N>` names for the listed indices.
    #[test]
    fn test_for_layers_matching() {
        let capture = TensorCapture::for_layers(&[13, 20, 27]);
        assert_filter(
            &capture,
            &["l_out-13", "l_out-20", "l_out-27"],
            &["l_out-0", "l_out-14", "attn_norm-13", "result_norm"],
        );
    }

    /// A name filter accepts exact matches only.
    #[test]
    fn test_for_names_matching() {
        let capture = TensorCapture::for_names(&["result_norm", "l_out-27"]);
        assert_filter(
            &capture,
            &["result_norm", "l_out-27"],
            &["l_out-13", "result_output"],
        );
    }

    /// A prefix filter accepts any name starting with the prefix.
    #[test]
    fn test_for_prefix_matching() {
        let capture = TensorCapture::for_prefix("attn_out");
        assert_filter(
            &capture,
            &["attn_out-0", "attn_out-27"],
            &["attn_norm-0", "l_out-0"],
        );
    }

    /// The catch-all filter accepts every name.
    #[test]
    fn test_all_matching() {
        let capture = TensorCapture::all();
        assert_filter(&capture, &["l_out-13", "result_norm", "anything"], &[]);
    }

    /// A stored tensor is retrievable both by name and by layer index.
    #[test]
    fn test_store_and_get() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut capture = TensorCapture::for_layers(&[13]);
        capture.store(String::from("l_out-13"), 3, 2, values.clone());

        assert_eq!(capture.len(), 1);
        assert!(!capture.is_empty());

        let by_name = capture.get("l_out-13").unwrap();
        assert_eq!(by_name.name, "l_out-13");
        assert_eq!(by_name.layer, Some(13));
        assert_eq!(by_name.n_embd(), 3);
        assert_eq!(by_name.n_tokens(), 2);
        assert_eq!(by_name.data, values);

        let by_layer = capture.get_layer(13).unwrap();
        assert_eq!(by_layer.name, by_name.name);
        assert!(capture.has_layer(13));
        assert!(!capture.has_layer(14));
    }

    /// Each token row is a contiguous `ne0`-wide slice; out-of-range rows yield `None`.
    #[test]
    fn test_token_embedding() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut capture = TensorCapture::for_layers(&[5]);
        capture.store(String::from("l_out-5"), 3, 2, values);

        let tensor = capture.get_layer(5).unwrap();
        assert_eq!(tensor.token_embedding(0), Some(&[1.0, 2.0, 3.0][..]));
        assert_eq!(tensor.token_embedding(1), Some(&[4.0, 5.0, 6.0][..]));
        assert_eq!(tensor.token_embedding(2), None);
    }

    /// Captured layer indices come back sorted regardless of insertion order.
    #[test]
    fn test_captured_layers() {
        let mut capture = TensorCapture::for_layers(&[5, 10, 20]);
        for name in ["l_out-10", "l_out-5"] {
            capture.store(name.to_string(), 2, 1, vec![0.0, 0.0]);
        }
        assert_eq!(capture.captured_layers(), vec![5, 10]);
    }

    /// Clearing removes every stored tensor.
    #[test]
    fn test_clear() {
        let mut capture = TensorCapture::for_layers(&[5]);
        capture.store(String::from("l_out-5"), 2, 1, vec![0.0, 0.0]);
        assert_eq!(capture.len(), 1);

        capture.clear();
        assert_eq!(capture.len(), 0);
        assert!(capture.is_empty());
    }

    /// Names that are not `l_out-<N>` are stored without a layer index.
    #[test]
    fn test_non_layer_tensor() {
        let mut capture = TensorCapture::for_names(&["result_norm"]);
        capture.store(String::from("result_norm"), 4, 3, vec![0.0; 12]);

        let tensor = capture.get("result_norm").unwrap();
        assert_eq!(tensor.layer, None);
        assert_eq!(tensor.n_embd(), 4);
        assert_eq!(tensor.n_tokens(), 3);
    }
}