burn_store/
tensor_snapshot.rs

1use alloc::rc::Rc;
2use alloc::string::String;
3use alloc::string::ToString;
4use alloc::vec::Vec;
5use burn_core::module::ParamId;
6use burn_tensor::quantization::{QPARAM_ALIGN, QuantParam, params_shape};
7use burn_tensor::{Bool, DType, Int, Shape, Tensor, TensorData, backend::Backend};
8use half::f16;
9
10/// Returns the byte size of a quantization parameter type.
11// TODO: Add `size_bytes()` method to `QuantParam` in cubecl and use it here.
12const fn quant_param_size(param: QuantParam) -> usize {
13    match param {
14        QuantParam::F32 => core::mem::size_of::<f32>(),
15        QuantParam::F16 | QuantParam::BF16 => core::mem::size_of::<f16>(),
16        QuantParam::UE8M0 | QuantParam::UE4M3 => core::mem::size_of::<u8>(),
17    }
18}
19
/// Error type for `TensorSnapshot` operations.
///
/// Every variant carries a human-readable message. Errors can be compared with `==`,
/// which is convenient when asserting on a specific failure.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TensorSnapshotError {
    /// I/O error occurred while loading tensor data
    IoError(String),
    /// Data corruption or invalid format
    DataError(String),
    /// Panic occurred while loading tensor data
    PanicError(String),
}
30
31impl core::fmt::Display for TensorSnapshotError {
32    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
33        match self {
34            Self::IoError(e) => write!(f, "I/O error: {}", e),
35            Self::DataError(e) => write!(f, "Data error: {}", e),
36            Self::PanicError(e) => write!(f, "Panic error: {}", e),
37        }
38    }
39}
40
41impl core::error::Error for TensorSnapshotError {}
42
43/// A lightweight snapshot of a tensor that can lazily produce TensorData.
44///
45/// TensorSnapshot stores a cloned tensor internally (which is cheap due to reference counting)
46/// and only materializes the actual data when `to_data()` is called. This allows
47/// efficient inspection of module structure without the overhead of copying all tensor data.
48///
49/// The dtype and shape are cached for efficient access without requiring data materialization,
50/// which is particularly useful for serialization formats that need metadata upfront.
51pub struct TensorSnapshot {
52    /// Function to get tensor data when needed (Rc allows cloning)
53    data_fn: Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>>,
54    /// Data type of the tensor (cached for efficient access)
55    pub dtype: burn_tensor::DType,
56    /// Shape of the tensor (cached for efficient access)
57    pub shape: Vec<usize>,
58    /// Path stack representing the module hierarchy
59    pub path_stack: Option<Vec<String>>,
60    /// Container stack representing the container types at each level
61    pub container_stack: Option<Vec<String>>,
62    /// Unique identifier for the tensor parameter
63    pub tensor_id: Option<ParamId>,
64}
65
66impl TensorSnapshot {
67    /// Create a new tensor snapshot from a float tensor
68    pub fn from_float<B: Backend, const D: usize>(
69        tensor: &Tensor<B, D>,
70        path_stack: Vec<String>,
71        container_stack: Vec<String>,
72        tensor_id: ParamId,
73    ) -> Self {
74        let dtype = tensor.dtype();
75        let shape = tensor.shape().to_vec();
76        let tensor = tensor.clone(); // Clone is cheap (reference counted)
77        Self {
78            data_fn: Rc::new(move || Ok(tensor.to_data())),
79            dtype,
80            shape,
81            path_stack: Some(path_stack),
82            container_stack: Some(container_stack),
83            tensor_id: Some(tensor_id),
84        }
85    }
86
87    /// Create a new tensor snapshot from an int tensor
88    pub fn from_int<B: Backend, const D: usize>(
89        tensor: &Tensor<B, D, Int>,
90        path_stack: Vec<String>,
91        container_stack: Vec<String>,
92        tensor_id: ParamId,
93    ) -> Self {
94        let dtype = tensor.dtype();
95        let shape = tensor.shape().to_vec();
96        let tensor = tensor.clone(); // Clone is cheap (reference counted)
97        Self {
98            data_fn: Rc::new(move || Ok(tensor.to_data())),
99            dtype,
100            shape,
101            path_stack: Some(path_stack),
102            container_stack: Some(container_stack),
103            tensor_id: Some(tensor_id),
104        }
105    }
106
107    /// Create a new tensor snapshot from a bool tensor
108    pub fn from_bool<B: Backend, const D: usize>(
109        tensor: &Tensor<B, D, Bool>,
110        path_stack: Vec<String>,
111        container_stack: Vec<String>,
112        tensor_id: ParamId,
113    ) -> Self {
114        let dtype = tensor.dtype();
115        let shape = tensor.shape().to_vec();
116        let tensor = tensor.clone(); // Clone is cheap (reference counted)
117        Self {
118            data_fn: Rc::new(move || Ok(tensor.to_data())),
119            dtype,
120            shape,
121            path_stack: Some(path_stack),
122            container_stack: Some(container_stack),
123            tensor_id: Some(tensor_id),
124        }
125    }
126
127    /// Convert to TensorData (this is where actual data copy happens)
128    #[cfg(feature = "std")]
129    pub fn to_data(&self) -> Result<TensorData, TensorSnapshotError> {
130        // Use AssertUnwindSafe since we're working with Rc which is not UnwindSafe
131        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| (self.data_fn)())).unwrap_or_else(
132            |_| {
133                Err(TensorSnapshotError::PanicError(
134                    "Panic occurred while loading tensor data".to_string(),
135                ))
136            },
137        )
138    }
139
140    /// Convert to TensorData (this is where actual data copy happens)
141    #[cfg(not(feature = "std"))]
142    pub fn to_data(&self) -> Result<TensorData, TensorSnapshotError> {
143        (self.data_fn)() // Can't catch panics in no-std, do it when core::panic::AssertUnwindSafe is available
144    }
145
146    /// Get the full path by joining the path stack
147    pub fn full_path(&self) -> String {
148        self.path_stack
149            .as_ref()
150            .map(|stack| stack.join("."))
151            .unwrap_or_default()
152    }
153
154    /// Get the full container path by joining the container stack
155    pub fn container_path(&self) -> String {
156        self.container_stack
157            .as_ref()
158            .map(|stack| stack.join("."))
159            .unwrap_or_default()
160    }
161
162    /// Get the module type (last Struct/Enum in the hierarchy)
163    ///
164    /// Returns the last user-defined module type, skipping primitive containers
165    /// like "Vec", "Array". This is useful for determining which user-defined
166    /// module a tensor belongs to.
167    ///
168    /// # Examples
169    /// - `Linear.weight` → `Some("Struct:Linear")`
170    /// - `Vec<Linear>[0].weight` → `Some("Struct:Linear")`
171    /// - `Linear.bias` (Optional) → `Some("Struct:Linear")`
172    /// - `Vec<Param>[0]` (no module) → `None`
173    pub fn module_type(&self) -> Option<String> {
174        self.container_stack.as_ref().and_then(|stack| {
175            // Find the last user-defined type (Struct: or Enum:)
176            stack
177                .iter()
178                .rev()
179                .find(|ct| ct.starts_with("Struct:") || ct.starts_with("Enum:"))
180                .cloned()
181        })
182    }
183
184    /// Get the immediate container type (last in the container stack)
185    ///
186    /// Returns the last element in the container stack, which could be a
187    /// user-defined type ("Struct:", "Enum:") or a collection type ("Vec", "Array").
188    /// This is useful for understanding the full container hierarchy.
189    ///
190    /// # Examples
191    /// - `Linear.weight` → `"Struct:Linear"`
192    /// - `Vec<Linear>[0].weight` → `"Struct:Linear"` (the Linear, not the Vec)
193    /// - `Vec<Param>[0]` → `"Vec"`
194    pub fn container_type(&self) -> String {
195        self.container_stack
196            .as_ref()
197            .and_then(|stack| stack.last())
198            .cloned()
199            .unwrap_or_else(|| "Unknown".to_string())
200    }
201
202    /// Create a TensorSnapshot from a closure that produces TensorData
203    /// This is used internally for lazy loading
204    pub fn from_closure(
205        data_fn: Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>>,
206        dtype: burn_tensor::DType,
207        shape: Vec<usize>,
208        path_stack: Vec<String>,
209        container_stack: Vec<String>,
210        tensor_id: ParamId,
211    ) -> Self {
212        Self {
213            data_fn,
214            dtype,
215            shape,
216            path_stack: Some(path_stack),
217            container_stack: Some(container_stack),
218            tensor_id: Some(tensor_id),
219        }
220    }
221
222    /// Create a TensorSnapshot from TensorData directly
223    pub fn from_data(
224        data: TensorData,
225        path_stack: Vec<String>,
226        container_stack: Vec<String>,
227        tensor_id: ParamId,
228    ) -> Self {
229        let dtype = data.dtype;
230        let shape = data.shape.clone();
231        Self {
232            data_fn: Rc::new(move || Ok(data.clone())),
233            dtype,
234            shape,
235            path_stack: Some(path_stack),
236            container_stack: Some(container_stack),
237            tensor_id: Some(tensor_id),
238        }
239    }
240
241    /// Get the size of the tensor data in bytes without materializing it.
242    ///
243    /// For regular (non-quantized) types, this is simply `shape.product() * dtype.size()`.
244    ///
245    /// For quantized types (`QFloat`), this accounts for:
246    /// - The quantized values (packed according to the quantization scheme)
247    /// - Alignment padding (values are aligned to 4-byte boundary)
248    /// - Quantization parameters (scale values appended to the data)
249    pub fn data_len(&self) -> usize {
250        const BITS_PER_BYTE: usize = 8;
251
252        let num_elements: usize = self.shape.iter().product();
253
254        match self.dtype {
255            DType::QFloat(scheme) => {
256                // Calculate value bytes using scheme's packing information
257                let num_storage_elements = num_elements.div_ceil(scheme.num_quants());
258                let value_bytes =
259                    num_storage_elements * (scheme.size_bits_stored() / BITS_PER_BYTE);
260
261                // Calculate number of quantization parameters (scales)
262                let num_params =
263                    params_shape(&Shape::from(self.shape.clone()), scheme.level).num_elements();
264
265                let aligned_value_bytes = value_bytes.div_ceil(QPARAM_ALIGN) * QPARAM_ALIGN;
266                let scale_bytes = num_params * quant_param_size(scheme.param);
267
268                aligned_value_bytes + scale_bytes
269            }
270            _ => num_elements * self.dtype.size(),
271        }
272    }
273
274    /// Clone the data function for lazy composition
275    pub fn clone_data_fn(&self) -> Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>> {
276        self.data_fn.clone()
277    }
278}
279
280impl Clone for TensorSnapshot {
281    fn clone(&self) -> Self {
282        // Clone lazily - keep the same data function
283        Self {
284            data_fn: self.data_fn.clone(),
285            dtype: self.dtype,
286            shape: self.shape.clone(),
287            path_stack: self.path_stack.clone(),
288            container_stack: self.container_stack.clone(),
289            tensor_id: self.tensor_id,
290        }
291    }
292}
293
294impl core::fmt::Debug for TensorSnapshot {
295    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
296        f.debug_struct("TensorSnapshot")
297            .field("dtype", &self.dtype)
298            .field("shape", &self.shape)
299            .field("path_stack", &self.path_stack)
300            .field("container_stack", &self.container_stack)
301            .field("tensor_id", &self.tensor_id)
302            .finish()
303    }
304}
305
#[cfg(all(test, feature = "std"))]
mod tests {
    // `super::*` also brings in the parent's imports (`Rc`, `DType`, `ToString`, ...).
    use super::*;

    type TestBackend = burn_ndarray::NdArray;

    /// Builds an owned `Vec<String>` from string slices to keep the fixtures terse.
    fn strings(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|&part| String::from(part)).collect()
    }

    #[test]
    fn tensor_view_float() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);

        let snapshot = TensorSnapshot::from_float(
            &tensor,
            strings(&["test", "weight"]),
            strings(&["TestModule", "Param"]),
            ParamId::new(),
        );

        // Test metadata access without materialization
        assert_eq!(snapshot.dtype, DType::F32);
        assert_eq!(snapshot.shape, vec![2, 2]);
        assert_eq!(snapshot.full_path(), "test.weight");
        assert_eq!(snapshot.container_path(), "TestModule.Param");

        // Test data materialization
        let data = snapshot.to_data().unwrap();
        assert_eq!(data.shape, vec![2, 2]);
        assert_eq!(data.dtype, DType::F32);
    }

    #[test]
    fn tensor_view_int() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 2, Int>::from_data([[1, 2], [3, 4]], &device);

        let snapshot = TensorSnapshot::from_int(
            &tensor,
            strings(&["test", "int"]),
            strings(&["TestModule", "Param"]),
            ParamId::new(),
        );

        // TestBackend uses I64 for integers
        assert_eq!(snapshot.dtype, DType::I64);
        assert_eq!(snapshot.shape, vec![2, 2]);

        let data = snapshot.to_data().unwrap();
        assert_eq!(data.shape, vec![2, 2]);
        assert_eq!(data.dtype, DType::I64);
    }

    #[test]
    fn tensor_view_bool() {
        let device = Default::default();
        let tensor =
            Tensor::<TestBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);

        let snapshot = TensorSnapshot::from_bool(
            &tensor,
            strings(&["test", "bool"]),
            strings(&["TestModule", "Param"]),
            ParamId::new(),
        );

        // Test metadata access without materialization
        assert_eq!(snapshot.dtype, DType::Bool);
        assert_eq!(snapshot.shape, vec![2, 2]);

        let data = snapshot.to_data().unwrap();
        assert_eq!(data.shape, vec![2, 2]);
        assert_eq!(data.dtype, DType::Bool);
    }

    #[test]
    fn data_len() {
        let device = Default::default();

        // F32 tensor (4 bytes per element)
        let tensor_f32 = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
        let view_f32 = TensorSnapshot::from_float(
            &tensor_f32,
            strings(&["test"]),
            strings(&["Module"]),
            ParamId::new(),
        );
        assert_eq!(view_f32.data_len(), 16); // 4 elements * 4 bytes

        // I64 tensor (8 bytes per element) - TestBackend uses I64 for Int
        let tensor_i64 =
            Tensor::<TestBackend, 3, Int>::from_data([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], &device);
        let view_i64 = TensorSnapshot::from_int(
            &tensor_i64,
            strings(&["test"]),
            strings(&["Module"]),
            ParamId::new(),
        );
        assert_eq!(view_i64.data_len(), 64); // 8 elements * 8 bytes (I64)

        // Bool tensor (1 byte per element)
        let tensor_bool =
            Tensor::<TestBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);
        let view_bool = TensorSnapshot::from_bool(
            &tensor_bool,
            strings(&["test"]),
            strings(&["Module"]),
            ParamId::new(),
        );
        assert_eq!(view_bool.data_len(), 4); // 4 elements * 1 byte
    }

    #[test]
    fn from_closure() {
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);
        let dtype = data.dtype;
        let shape = data.shape.clone();

        let snapshot = TensorSnapshot::from_closure(
            Rc::new(move || Ok(data.clone())),
            dtype,
            shape.clone(),
            strings(&["model", "layer"]),
            strings(&["Model", "Layer"]),
            ParamId::new(),
        );

        // Test metadata access
        assert_eq!(snapshot.dtype, DType::F32);
        assert_eq!(snapshot.shape, vec![4]);
        assert_eq!(snapshot.full_path(), "model.layer");
        assert_eq!(snapshot.data_len(), 16); // 4 * 4 bytes

        // Test data materialization
        let materialized = snapshot.to_data().unwrap();
        assert_eq!(materialized.shape, vec![4]);
    }

    #[test]
    fn from_data() {
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let original_dtype = data.dtype;
        let original_shape = data.shape.clone();

        let snapshot = TensorSnapshot::from_data(
            data,
            strings(&["encoder", "weight"]),
            strings(&["Struct:Encoder", "Struct:Dense"]),
            ParamId::new(),
        );

        // Test metadata
        assert_eq!(snapshot.dtype, original_dtype);
        assert_eq!(snapshot.shape, original_shape);
        assert_eq!(snapshot.full_path(), "encoder.weight");
        assert_eq!(snapshot.container_type(), "Struct:Dense");
        assert_eq!(snapshot.data_len(), 24); // 6 * 4 bytes

        // Test data materialization
        let materialized = snapshot.to_data().unwrap();
        assert_eq!(materialized.shape, original_shape);
    }

    #[test]
    fn panic_catching_in_to_data() {
        // A snapshot whose data function panics
        let snapshot = TensorSnapshot {
            data_fn: Rc::new(|| panic!("Test panic in data_fn")),
            dtype: DType::F32,
            shape: vec![2, 2],
            path_stack: Some(strings(&["test"])),
            container_stack: Some(strings(&["Test"])),
            tensor_id: Some(ParamId::new()),
        };

        // With std available, to_data catches the panic and returns an error
        let result = snapshot.to_data();
        assert!(result.is_err());

        match result {
            Err(TensorSnapshotError::PanicError(msg)) => {
                assert!(msg.contains("Panic occurred"));
            }
            _ => panic!("Expected PanicError with panic message"),
        }
    }

    #[test]
    fn error_propagation_in_closure() {
        // A snapshot whose closure returns an error
        let snapshot = TensorSnapshot::from_closure(
            Rc::new(|| Err(TensorSnapshotError::IoError("Simulated IO error".into()))),
            DType::F32,
            vec![2, 2],
            strings(&["error_test"]),
            Vec::new(),
            ParamId::new(),
        );

        // The error is returned when trying to get data
        let result = snapshot.to_data();
        assert!(result.is_err());
        match result {
            Err(TensorSnapshotError::IoError(msg)) => {
                assert!(msg.contains("Simulated IO error"));
            }
            _ => panic!("Expected IoError"),
        }
    }

    #[test]
    fn container_type_extraction() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 1>::from_data([1.0, 2.0, 3.0], &device);

        let snapshot = TensorSnapshot::from_float(
            &tensor,
            strings(&["model", "layer1", "weight"]),
            strings(&["Struct:Model", "Struct:Conv2d", "Struct:Param"]),
            ParamId::new(),
        );

        assert_eq!(snapshot.container_type(), "Struct:Param");
        assert_eq!(snapshot.module_type(), Some("Struct:Param".to_string()));
        assert_eq!(
            snapshot.container_path(),
            "Struct:Model.Struct:Conv2d.Struct:Param"
        );
        assert_eq!(snapshot.full_path(), "model.layer1.weight");
    }

    #[test]
    fn container_type_vs_module_type() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 1>::from_data([1.0, 2.0, 3.0], &device);

        // Case 1: tensor inside a Vec<Linear>
        let snapshot = TensorSnapshot::from_float(
            &tensor,
            strings(&["model", "layers", "0", "weight"]),
            strings(&["Struct:Model", "Vec", "Struct:Linear"]),
            ParamId::new(),
        );

        // container_type() returns the last element (Struct:Linear in this case)
        assert_eq!(snapshot.container_type(), "Struct:Linear");
        // module_type() also returns Some(Struct:Linear) (skipping Vec)
        assert_eq!(snapshot.module_type(), Some("Struct:Linear".to_string()));

        // Case 2: tensor that's just in a Vec
        let snapshot2 = TensorSnapshot::from_float(
            &tensor,
            strings(&["data", "0"]),
            strings(&["Vec"]),
            ParamId::new(),
        );

        // container_type() returns Vec
        assert_eq!(snapshot2.container_type(), "Vec");
        // module_type() returns None (no Struct/Enum found)
        assert_eq!(snapshot2.module_type(), None);

        // Case 3: nested collections
        let snapshot3 = TensorSnapshot::from_float(
            &tensor,
            strings(&["model", "layers", "0", "sublayers", "1", "weight"]),
            strings(&["Struct:Model", "Vec", "Array", "Struct:Linear"]),
            ParamId::new(),
        );

        // container_type() returns the immediate container
        assert_eq!(snapshot3.container_type(), "Struct:Linear");
        // module_type() returns the last Struct/Enum
        assert_eq!(snapshot3.module_type(), Some("Struct:Linear".to_string()));
    }

    #[test]
    fn module_type_finds_enum() {
        let snapshot = TensorSnapshot::from_data(
            TensorData::from([1.0f32]),
            strings(&["model", "act", "0", "weight"]),
            strings(&["Struct:Model", "Enum:Activation", "Vec"]),
            ParamId::new(),
        );

        assert_eq!(snapshot.container_type(), "Vec");
        assert_eq!(snapshot.module_type(), Some("Enum:Activation".to_string()));
    }

    #[test]
    fn empty_container_stack() {
        let snapshot = TensorSnapshot::from_data(
            TensorData::from([1.0f32, 2.0]),
            strings(&["orphan"]),
            Vec::new(),
            ParamId::new(),
        );

        assert_eq!(snapshot.container_type(), "Unknown");
        assert_eq!(snapshot.container_path(), "");
        assert_eq!(snapshot.module_type(), None);
        assert_eq!(snapshot.full_path(), "orphan");
    }

    #[test]
    fn accessors_without_stacks() {
        let snapshot = TensorSnapshot {
            data_fn: Rc::new(|| Ok(TensorData::from([0.0f32]))),
            dtype: DType::F32,
            shape: vec![1],
            path_stack: None,
            container_stack: None,
            tensor_id: None,
        };

        assert_eq!(snapshot.full_path(), "");
        assert_eq!(snapshot.container_path(), "");
        assert_eq!(snapshot.container_type(), "Unknown");
        assert_eq!(snapshot.module_type(), None);
        assert_eq!(snapshot.data_len(), 4); // 1 element * 4 bytes
    }

    #[test]
    fn clone_shares_data_fn() {
        let snapshot = TensorSnapshot::from_data(
            TensorData::from([1.0f32, 2.0, 3.0]),
            strings(&["block", "weight"]),
            strings(&["Struct:Block"]),
            ParamId::new(),
        );
        let cloned = snapshot.clone();

        // The lazy data function is shared between both snapshots, not duplicated.
        assert_eq!(Rc::strong_count(&snapshot.data_fn), 2);
        assert_eq!(cloned.dtype, snapshot.dtype);
        assert_eq!(cloned.shape, snapshot.shape);
        assert_eq!(cloned.full_path(), "block.weight");
        assert_eq!(cloned.container_path(), "Struct:Block");
        assert_eq!(cloned.to_data().unwrap().shape, vec![3]);
    }
}