//! Tensor snapshots (`burn_store/tensor_snapshot.rs`): lightweight, lazily
//! materialized views over tensor data used for module inspection and serialization.

1use alloc::rc::Rc;
2use alloc::string::String;
3use alloc::string::ToString;
4use alloc::vec::Vec;
5use burn_core::module::ParamId;
6use burn_tensor::quantization::{QPARAM_ALIGN, QuantParam, params_shape};
7use burn_tensor::{Bool, DType, Int, Shape, Tensor, TensorData, backend::Backend};
8use half::f16;
9
10/// Returns the byte size of a quantization parameter type.
11// TODO: Add `size_bytes()` method to `QuantParam` in cubecl and use it here.
12const fn quant_param_size(param: QuantParam) -> usize {
13    match param {
14        QuantParam::F32 => core::mem::size_of::<f32>(),
15        QuantParam::F16 | QuantParam::BF16 => core::mem::size_of::<f16>(),
16        QuantParam::UE8M0 | QuantParam::UE4M3 => core::mem::size_of::<u8>(),
17    }
18}
19
/// Error type for TensorSnapshot operations
#[derive(Debug, Clone)]
pub enum TensorSnapshotError {
    /// I/O error occurred while loading tensor data
    IoError(String),
    /// Data corruption or invalid format
    DataError(String),
    /// Panic occurred while loading tensor data
    /// (only produced by the `std` build, where `to_data` catches panics)
    PanicError(String),
}
30
31impl core::fmt::Display for TensorSnapshotError {
32    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
33        match self {
34            Self::IoError(e) => write!(f, "I/O error: {}", e),
35            Self::DataError(e) => write!(f, "Data error: {}", e),
36            Self::PanicError(e) => write!(f, "Panic error: {}", e),
37        }
38    }
39}
40
// Marker impl so the error works with `?`, `Box<dyn Error>` and other
// error-trait consumers (uses `core::error` to stay `no_std`-compatible).
impl core::error::Error for TensorSnapshotError {}
42
/// A lightweight snapshot of a tensor that can lazily produce TensorData.
///
/// TensorSnapshot stores a cloned tensor internally (which is cheap due to reference counting)
/// and only materializes the actual data when `to_data()` is called. This allows
/// efficient inspection of module structure without the overhead of copying all tensor data.
///
/// The dtype and shape are cached for efficient access without requiring data materialization,
/// which is particularly useful for serialization formats that need metadata upfront.
pub struct TensorSnapshot {
    /// Function to get tensor data when needed (`Rc` lets snapshots be cloned
    /// cheaply — the closure is shared, never re-created)
    data_fn: Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>>,
    /// Data type of the tensor (cached for efficient access)
    pub dtype: burn_tensor::DType,
    /// Shape of the tensor (cached for efficient access)
    pub shape: Shape,
    /// Path stack representing the module hierarchy,
    /// e.g. `["encoder", "weight"]`; `None` when no path was recorded
    pub path_stack: Option<Vec<String>>,
    /// Container stack representing the container types at each level,
    /// e.g. `["Struct:Model", "Vec", "Struct:Linear"]`
    pub container_stack: Option<Vec<String>>,
    /// Unique identifier for the tensor parameter
    pub tensor_id: Option<ParamId>,
}
65
66impl TensorSnapshot {
67    /// Create a new tensor snapshot from a float tensor
68    pub fn from_float<B: Backend, const D: usize>(
69        tensor: &Tensor<B, D>,
70        path_stack: Vec<String>,
71        container_stack: Vec<String>,
72        tensor_id: ParamId,
73    ) -> Self {
74        let dtype = tensor.dtype();
75        let shape = tensor.shape();
76        let tensor = tensor.clone(); // Clone is cheap (reference counted)
77        Self {
78            data_fn: Rc::new(move || Ok(tensor.to_data())),
79            dtype,
80            shape,
81            path_stack: Some(path_stack),
82            container_stack: Some(container_stack),
83            tensor_id: Some(tensor_id),
84        }
85    }
86
87    /// Create a new tensor snapshot from an int tensor
88    pub fn from_int<B: Backend, const D: usize>(
89        tensor: &Tensor<B, D, Int>,
90        path_stack: Vec<String>,
91        container_stack: Vec<String>,
92        tensor_id: ParamId,
93    ) -> Self {
94        let dtype = tensor.dtype();
95        let shape = tensor.shape();
96        let tensor = tensor.clone(); // Clone is cheap (reference counted)
97        Self {
98            data_fn: Rc::new(move || Ok(tensor.to_data())),
99            dtype,
100            shape,
101            path_stack: Some(path_stack),
102            container_stack: Some(container_stack),
103            tensor_id: Some(tensor_id),
104        }
105    }
106
107    /// Create a new tensor snapshot from a bool tensor
108    pub fn from_bool<B: Backend, const D: usize>(
109        tensor: &Tensor<B, D, Bool>,
110        path_stack: Vec<String>,
111        container_stack: Vec<String>,
112        tensor_id: ParamId,
113    ) -> Self {
114        let dtype = tensor.dtype();
115        let shape = tensor.shape();
116        let tensor = tensor.clone(); // Clone is cheap (reference counted)
117        Self {
118            data_fn: Rc::new(move || Ok(tensor.to_data())),
119            dtype,
120            shape,
121            path_stack: Some(path_stack),
122            container_stack: Some(container_stack),
123            tensor_id: Some(tensor_id),
124        }
125    }
126
127    /// Convert to TensorData (this is where actual data copy happens)
128    #[cfg(feature = "std")]
129    pub fn to_data(&self) -> Result<TensorData, TensorSnapshotError> {
130        // Use AssertUnwindSafe since we're working with Rc which is not UnwindSafe
131        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| (self.data_fn)())).unwrap_or_else(
132            |_| {
133                Err(TensorSnapshotError::PanicError(
134                    "Panic occurred while loading tensor data".to_string(),
135                ))
136            },
137        )
138    }
139
    /// Convert to TensorData (this is where actual data copy happens)
    ///
    /// # Errors
    ///
    /// Propagates any error produced by the data closure. Unlike the `std`
    /// variant, a panic inside the closure cannot be caught here.
    #[cfg(not(feature = "std"))]
    pub fn to_data(&self) -> Result<TensorData, TensorSnapshotError> {
        (self.data_fn)() // Can't catch panics in no-std, do it when core::panic::AssertUnwindSafe is available
    }
145
146    /// Get the full path by joining the path stack
147    pub fn full_path(&self) -> String {
148        self.path_stack
149            .as_ref()
150            .map(|stack| stack.join("."))
151            .unwrap_or_default()
152    }
153
154    /// Get the full container path by joining the container stack
155    pub fn container_path(&self) -> String {
156        self.container_stack
157            .as_ref()
158            .map(|stack| stack.join("."))
159            .unwrap_or_default()
160    }
161
162    /// Get the module type (last Struct/Enum in the hierarchy)
163    ///
164    /// Returns the last user-defined module type, skipping primitive containers
165    /// like "Vec", "Array". This is useful for determining which user-defined
166    /// module a tensor belongs to.
167    ///
168    /// # Examples
169    /// - `Linear.weight` → `Some("Struct:Linear")`
170    /// - `Vec<Linear>[0].weight` → `Some("Struct:Linear")`
171    /// - `Linear.bias` (Optional) → `Some("Struct:Linear")`
172    /// - `Vec<Param>[0]` (no module) → `None`
173    pub fn module_type(&self) -> Option<String> {
174        self.container_stack.as_ref().and_then(|stack| {
175            // Find the last user-defined type (Struct: or Enum:)
176            stack
177                .iter()
178                .rev()
179                .find(|ct| ct.starts_with("Struct:") || ct.starts_with("Enum:"))
180                .cloned()
181        })
182    }
183
184    /// Get the immediate container type (last in the container stack)
185    ///
186    /// Returns the last element in the container stack, which could be a
187    /// user-defined type ("Struct:", "Enum:") or a collection type ("Vec", "Array").
188    /// This is useful for understanding the full container hierarchy.
189    ///
190    /// # Examples
191    /// - `Linear.weight` → `"Struct:Linear"`
192    /// - `Vec<Linear>[0].weight` → `"Struct:Linear"` (the Linear, not the Vec)
193    /// - `Vec<Param>[0]` → `"Vec"`
194    pub fn container_type(&self) -> String {
195        self.container_stack
196            .as_ref()
197            .and_then(|stack| stack.last())
198            .cloned()
199            .unwrap_or_else(|| "Unknown".to_string())
200    }
201
    /// Create a TensorSnapshot from a closure that produces TensorData
    /// This is used internally for lazy loading
    ///
    /// `dtype` and `shape` are caller-provided and are assumed to match what
    /// the closure eventually produces; they are cached so metadata can be
    /// read without invoking the closure.
    pub fn from_closure(
        data_fn: Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>>,
        dtype: burn_tensor::DType,
        shape: Shape,
        path_stack: Vec<String>,
        container_stack: Vec<String>,
        tensor_id: ParamId,
    ) -> Self {
        Self {
            data_fn,
            dtype,
            shape,
            path_stack: Some(path_stack),
            container_stack: Some(container_stack),
            tensor_id: Some(tensor_id),
        }
    }
221
222    /// Create a TensorSnapshot from TensorData directly
223    pub fn from_data(
224        data: TensorData,
225        path_stack: Vec<String>,
226        container_stack: Vec<String>,
227        tensor_id: ParamId,
228    ) -> Self {
229        let dtype = data.dtype;
230        let shape = data.shape.clone();
231        Self {
232            data_fn: Rc::new(move || Ok(data.clone())),
233            dtype,
234            shape,
235            path_stack: Some(path_stack),
236            container_stack: Some(container_stack),
237            tensor_id: Some(tensor_id),
238        }
239    }
240
241    /// Get the size of the tensor data in bytes without materializing it.
242    ///
243    /// For regular (non-quantized) types, this is simply `shape.product() * dtype.size()`.
244    ///
245    /// For quantized types (`QFloat`), this accounts for:
246    /// - The quantized values (packed according to the quantization scheme)
247    /// - Alignment padding (values are aligned to 4-byte boundary)
248    /// - Quantization parameters (scale values appended to the data)
249    pub fn data_len(&self) -> usize {
250        const BITS_PER_BYTE: usize = 8;
251
252        let num_elements: usize = self.shape.iter().product();
253
254        match self.dtype {
255            DType::QFloat(scheme) => {
256                // Calculate value bytes using scheme's packing information
257                let num_storage_elements = num_elements.div_ceil(scheme.num_quants());
258                let value_bytes =
259                    num_storage_elements * (scheme.size_bits_stored() / BITS_PER_BYTE);
260
261                // Calculate number of quantization parameters (scales)
262                let num_params = params_shape(&self.shape, scheme.level).num_elements();
263
264                let aligned_value_bytes = value_bytes.div_ceil(QPARAM_ALIGN) * QPARAM_ALIGN;
265                let scale_bytes = num_params * quant_param_size(scheme.param);
266
267                aligned_value_bytes + scale_bytes
268            }
269            _ => num_elements * self.dtype.size(),
270        }
271    }
272
273    /// Clone the data function for lazy composition
274    pub fn clone_data_fn(&self) -> Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>> {
275        self.data_fn.clone()
276    }
277}
278
279impl Clone for TensorSnapshot {
280    fn clone(&self) -> Self {
281        // Clone lazily - keep the same data function
282        Self {
283            data_fn: self.data_fn.clone(),
284            dtype: self.dtype,
285            shape: self.shape.clone(),
286            path_stack: self.path_stack.clone(),
287            container_stack: self.container_stack.clone(),
288            tensor_id: self.tensor_id,
289        }
290    }
291}
292
293impl core::fmt::Debug for TensorSnapshot {
294    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
295        f.debug_struct("TensorSnapshot")
296            .field("dtype", &self.dtype)
297            .field("shape", &self.shape)
298            .field("path_stack", &self.path_stack)
299            .field("container_stack", &self.container_stack)
300            .field("tensor_id", &self.tensor_id)
301            .finish()
302    }
303}
304
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    // NOTE(review): `burn_flex::Flex` is the project's test backend; the asserts
    // below assume it defaults Int tensors to I32 and Bool to native storage.
    type TestBackend = burn_flex::Flex;
    use alloc::string::ToString;
    use burn_tensor::{BoolStore, DType, shape};

    // Metadata (dtype/shape/paths) must be readable without materializing data.
    #[test]
    fn tensor_view_float() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);

        let snapshot = TensorSnapshot::from_float(
            &tensor,
            vec!["test".to_string(), "weight".to_string()],
            vec!["TestModule".to_string(), "Param".to_string()],
            ParamId::new(),
        );

        // Test metadata access without materialization
        assert_eq!(snapshot.dtype, DType::F32);
        assert_eq!(snapshot.shape, shape![2, 2]);
        assert_eq!(snapshot.full_path(), "test.weight");
        assert_eq!(snapshot.container_path(), "TestModule.Param");

        // Test data materialization
        let data = snapshot.to_data().unwrap();
        assert_eq!(data.shape, shape![2, 2]);
        assert_eq!(data.dtype, DType::F32);
    }

    #[test]
    fn tensor_view_int() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 2, Int>::from_data([[1, 2], [3, 4]], &device);

        let snapshot = TensorSnapshot::from_int(
            &tensor,
            vec!["test".to_string(), "int".to_string()],
            vec!["TestModule".to_string(), "Param".to_string()],
            ParamId::new(),
        );

        // Test metadata access without materialization
        // TestBackend (Flex) uses I32 for integers
        assert_eq!(snapshot.dtype, DType::I32);
        assert_eq!(snapshot.shape, shape![2, 2]);

        let data = snapshot.to_data().unwrap();
        assert_eq!(data.shape, shape![2, 2]);
        assert_eq!(data.dtype, DType::I32);
    }

    #[test]
    fn tensor_view_bool() {
        let device = Default::default();
        let tensor =
            Tensor::<TestBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);

        let snapshot = TensorSnapshot::from_bool(
            &tensor,
            vec!["test".to_string(), "bool".to_string()],
            vec!["TestModule".to_string(), "Param".to_string()],
            ParamId::new(),
        );

        // Test metadata access without materialization
        assert_eq!(snapshot.dtype, DType::Bool(BoolStore::Native));
        assert_eq!(snapshot.shape, shape![2, 2]);

        let data = snapshot.to_data().unwrap();
        assert_eq!(data.shape, shape![2, 2]);
        assert_eq!(data.dtype, DType::Bool(BoolStore::Native));
    }

    // Byte-size computation for non-quantized dtypes: elements * dtype size.
    #[test]
    fn data_len() {
        let device = Default::default();

        // Test F32 tensor (4 bytes per element)
        let tensor_f32 = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
        let view_f32 = TensorSnapshot::from_float(
            &tensor_f32,
            vec!["test".to_string()],
            vec!["Module".to_string()],
            ParamId::new(),
        );
        assert_eq!(view_f32.data_len(), 16); // 4 elements * 4 bytes

        // Test I32 tensor (4 bytes per element) - TestBackend (Flex) uses I32 for Int
        let tensor_int =
            Tensor::<TestBackend, 3, Int>::from_data([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], &device);
        let view_int = TensorSnapshot::from_int(
            &tensor_int,
            vec!["test".to_string()],
            vec!["Module".to_string()],
            ParamId::new(),
        );
        assert_eq!(view_int.data_len(), 32); // 8 elements * 4 bytes (I32)

        // Test Bool tensor (1 byte per element)
        let tensor_bool =
            Tensor::<TestBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);
        let view_bool = TensorSnapshot::from_bool(
            &tensor_bool,
            vec!["test".to_string()],
            vec!["Module".to_string()],
            ParamId::new(),
        );
        assert_eq!(view_bool.data_len(), 4); // 4 elements * 1 byte
    }

    #[test]
    fn from_closure() {
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);
        let dtype = data.dtype;
        let shape = data.shape.clone();

        let snapshot = TensorSnapshot::from_closure(
            Rc::new(move || Ok(data.clone())),
            dtype,
            shape.clone(),
            vec!["model".to_string(), "layer".to_string()],
            vec!["Model".to_string(), "Layer".to_string()],
            ParamId::new(),
        );

        // Test metadata access
        assert_eq!(snapshot.dtype, DType::F32);
        assert_eq!(snapshot.shape, shape![4]);
        assert_eq!(snapshot.full_path(), "model.layer");
        assert_eq!(snapshot.data_len(), 16); // 4 * 4 bytes

        // Test data materialization
        let materialized = snapshot.to_data().unwrap();
        assert_eq!(materialized.shape, shape![4]);
    }

    #[test]
    fn from_data() {
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let original_dtype = data.dtype;
        let original_shape = data.shape.clone();

        let snapshot = TensorSnapshot::from_data(
            data,
            vec!["encoder".to_string(), "weight".to_string()],
            vec!["Struct:Encoder".to_string(), "Struct:Dense".to_string()],
            ParamId::new(),
        );

        // Test metadata
        assert_eq!(snapshot.dtype, original_dtype);
        assert_eq!(snapshot.shape, original_shape);
        assert_eq!(snapshot.full_path(), "encoder.weight");
        assert_eq!(snapshot.container_type(), "Struct:Dense");
        assert_eq!(snapshot.data_len(), 24); // 6 * 4 bytes

        // Test data materialization
        let materialized = snapshot.to_data().unwrap();
        assert_eq!(materialized.shape, original_shape);
    }

    // The std build must convert panics in the data closure into PanicError
    // instead of unwinding through the caller.
    #[test]
    #[cfg(feature = "std")]
    fn panic_catching_in_to_data() {
        use alloc::rc::Rc;

        // Create a TensorSnapshot with a closure that panics
        let snapshot = TensorSnapshot {
            data_fn: Rc::new(|| panic!("Test panic in data_fn")),
            dtype: DType::F32,
            shape: shape![2, 2],
            path_stack: Some(vec!["test".to_string()]),
            container_stack: Some(vec!["Test".to_string()]),
            tensor_id: Some(ParamId::new()),
        };

        // When std is available, to_data should catch the panic and return an error
        let result = snapshot.to_data();
        assert!(result.is_err());

        match result {
            Err(TensorSnapshotError::PanicError(msg)) => {
                assert!(msg.contains("Panic occurred"));
            }
            _ => panic!("Expected PanicError with panic message"),
        }
    }

    // Errors returned (not panicked) by the closure pass through unchanged.
    #[test]
    fn error_propagation_in_closure() {
        use alloc::rc::Rc;

        // Create a snapshot with a closure that returns an error
        let snapshot = TensorSnapshot::from_closure(
            Rc::new(|| Err(TensorSnapshotError::IoError("Simulated IO error".into()))),
            DType::F32,
            shape![2, 2],
            vec!["error_test".into()],
            vec![],
            ParamId::new(),
        );

        // Should return an error when trying to get data
        let result = snapshot.to_data();
        assert!(result.is_err());
        match result {
            Err(TensorSnapshotError::IoError(msg)) => {
                assert!(msg.contains("Simulated IO error"));
            }
            _ => panic!("Expected IoError"),
        }
    }

    #[test]
    fn container_type_extraction() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 1>::from_data([1.0, 2.0, 3.0], &device);

        let snapshot = TensorSnapshot::from_float(
            &tensor,
            vec![
                "model".to_string(),
                "layer1".to_string(),
                "weight".to_string(),
            ],
            vec![
                "Struct:Model".to_string(),
                "Struct:Conv2d".to_string(),
                "Struct:Param".to_string(),
            ],
            ParamId::new(),
        );

        assert_eq!(snapshot.container_type(), "Struct:Param");
        assert_eq!(snapshot.module_type(), Some("Struct:Param".to_string()));
        assert_eq!(
            snapshot.container_path(),
            "Struct:Model.Struct:Conv2d.Struct:Param"
        );
        assert_eq!(snapshot.full_path(), "model.layer1.weight");
    }

    // container_type() returns the literal last stack entry, while
    // module_type() skips collection containers ("Vec", "Array").
    #[test]
    fn container_type_vs_module_type() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 1>::from_data([1.0, 2.0, 3.0], &device);

        // Test case 1: Tensor inside a Vec<Linear>
        // container_stack: ["Struct:Model", "Vec", "Struct:Linear"]
        let snapshot = TensorSnapshot::from_float(
            &tensor,
            vec![
                "model".to_string(),
                "layers".to_string(),
                "0".to_string(),
                "weight".to_string(),
            ],
            vec![
                "Struct:Model".to_string(),
                "Vec".to_string(),
                "Struct:Linear".to_string(),
            ],
            ParamId::new(),
        );

        // container_type() returns the last element (Struct:Linear in this case)
        assert_eq!(snapshot.container_type(), "Struct:Linear");
        // module_type() also returns Some(Struct:Linear) (skipping Vec)
        assert_eq!(snapshot.module_type(), Some("Struct:Linear".to_string()));

        // Test case 2: Tensor that's just in a Vec
        // container_stack: ["Vec"]
        let snapshot2 = TensorSnapshot::from_float(
            &tensor,
            vec!["data".to_string(), "0".to_string()],
            vec!["Vec".to_string()],
            ParamId::new(),
        );

        // container_type() returns Vec
        assert_eq!(snapshot2.container_type(), "Vec");
        // module_type() returns None (no Struct/Enum found)
        assert_eq!(snapshot2.module_type(), None);

        // Test case 3: Nested collections
        // container_stack: ["Struct:Model", "Vec", "Array", "Struct:Linear"]
        let snapshot3 = TensorSnapshot::from_float(
            &tensor,
            vec![
                "model".to_string(),
                "layers".to_string(),
                "0".to_string(),
                "sublayers".to_string(),
                "1".to_string(),
                "weight".to_string(),
            ],
            vec![
                "Struct:Model".to_string(),
                "Vec".to_string(),
                "Array".to_string(),
                "Struct:Linear".to_string(),
            ],
            ParamId::new(),
        );

        // container_type() returns the immediate container
        assert_eq!(snapshot3.container_type(), "Struct:Linear");
        // module_type() returns the last Struct/Enum
        assert_eq!(snapshot3.module_type(), Some("Struct:Linear".to_string()));
    }
}