//! Tensor snapshot utilities (burn_store/tensor_snapshot.rs).

1use alloc::rc::Rc;
2use alloc::string::String;
3use alloc::string::ToString;
4use alloc::vec::Vec;
5use burn_core::module::ParamId;
6use burn_tensor::{Bool, Int, Tensor, TensorData, backend::Backend};
7
/// Error type for TensorSnapshot operations.
///
/// Returned by `TensorSnapshot::to_data` and by the closures passed to
/// `TensorSnapshot::from_closure`; each variant carries a human-readable message.
#[derive(Debug, Clone)]
pub enum TensorSnapshotError {
    /// I/O error occurred while loading tensor data
    IoError(String),
    /// Data corruption or invalid format
    DataError(String),
    /// Panic occurred while loading tensor data
    PanicError(String),
}
18
19impl core::fmt::Display for TensorSnapshotError {
20    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
21        match self {
22            Self::IoError(e) => write!(f, "I/O error: {}", e),
23            Self::DataError(e) => write!(f, "Data error: {}", e),
24            Self::PanicError(e) => write!(f, "Panic error: {}", e),
25        }
26    }
27}
28
// Marker impl: `Debug` (derived) and `Display` (implemented above) satisfy the
// default `core::error::Error` methods; no overrides are needed.
impl core::error::Error for TensorSnapshotError {}
30
/// A lightweight snapshot of a tensor that can lazily produce TensorData.
///
/// TensorSnapshot stores a cloned tensor internally (which is cheap due to reference counting)
/// and only materializes the actual data when `to_data()` is called. This allows
/// efficient inspection of module structure without the overhead of copying all tensor data.
///
/// The dtype and shape are cached for efficient access without requiring data materialization,
/// which is particularly useful for serialization formats that need metadata upfront.
pub struct TensorSnapshot {
    /// Function to get tensor data when needed (`Rc` allows cheap cloning of the
    /// closure handle; note that `Rc` also makes this type `!Send`/`!Sync`)
    data_fn: Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>>,
    /// Data type of the tensor (cached for efficient access)
    pub dtype: burn_tensor::DType,
    /// Shape of the tensor (cached for efficient access)
    pub shape: Vec<usize>,
    /// Path stack representing the module hierarchy (joined with `.` by `full_path()`);
    /// `None` when the snapshot carries no path information
    pub path_stack: Option<Vec<String>>,
    /// Container stack representing the container types at each level
    /// (joined with `.` by `container_path()`)
    pub container_stack: Option<Vec<String>>,
    /// Unique identifier for the tensor parameter
    pub tensor_id: Option<ParamId>,
}
53
54impl TensorSnapshot {
55    /// Create a new tensor snapshot from a float tensor
56    pub fn from_float<B: Backend, const D: usize>(
57        tensor: &Tensor<B, D>,
58        path_stack: Vec<String>,
59        container_stack: Vec<String>,
60        tensor_id: ParamId,
61    ) -> Self {
62        let dtype = tensor.dtype();
63        let shape = tensor.shape().to_vec();
64        let tensor = tensor.clone(); // Clone is cheap (reference counted)
65        Self {
66            data_fn: Rc::new(move || Ok(tensor.to_data())),
67            dtype,
68            shape,
69            path_stack: Some(path_stack),
70            container_stack: Some(container_stack),
71            tensor_id: Some(tensor_id),
72        }
73    }
74
75    /// Create a new tensor snapshot from an int tensor
76    pub fn from_int<B: Backend, const D: usize>(
77        tensor: &Tensor<B, D, Int>,
78        path_stack: Vec<String>,
79        container_stack: Vec<String>,
80        tensor_id: ParamId,
81    ) -> Self {
82        let dtype = tensor.dtype();
83        let shape = tensor.shape().to_vec();
84        let tensor = tensor.clone(); // Clone is cheap (reference counted)
85        Self {
86            data_fn: Rc::new(move || Ok(tensor.to_data())),
87            dtype,
88            shape,
89            path_stack: Some(path_stack),
90            container_stack: Some(container_stack),
91            tensor_id: Some(tensor_id),
92        }
93    }
94
95    /// Create a new tensor snapshot from a bool tensor
96    pub fn from_bool<B: Backend, const D: usize>(
97        tensor: &Tensor<B, D, Bool>,
98        path_stack: Vec<String>,
99        container_stack: Vec<String>,
100        tensor_id: ParamId,
101    ) -> Self {
102        let dtype = tensor.dtype();
103        let shape = tensor.shape().to_vec();
104        let tensor = tensor.clone(); // Clone is cheap (reference counted)
105        Self {
106            data_fn: Rc::new(move || Ok(tensor.to_data())),
107            dtype,
108            shape,
109            path_stack: Some(path_stack),
110            container_stack: Some(container_stack),
111            tensor_id: Some(tensor_id),
112        }
113    }
114
115    /// Convert to TensorData (this is where actual data copy happens)
116    #[cfg(feature = "std")]
117    pub fn to_data(&self) -> Result<TensorData, TensorSnapshotError> {
118        // Use AssertUnwindSafe since we're working with Rc which is not UnwindSafe
119        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| (self.data_fn)())).unwrap_or_else(
120            |_| {
121                Err(TensorSnapshotError::PanicError(
122                    "Panic occurred while loading tensor data".to_string(),
123                ))
124            },
125        )
126    }
127
128    /// Convert to TensorData (this is where actual data copy happens)
129    #[cfg(not(feature = "std"))]
130    pub fn to_data(&self) -> Result<TensorData, TensorSnapshotError> {
131        (self.data_fn)() // Can't catch panics in no-std, do it when core::panic::AssertUnwindSafe is available
132    }
133
134    /// Get the full path by joining the path stack
135    pub fn full_path(&self) -> String {
136        self.path_stack
137            .as_ref()
138            .map(|stack| stack.join("."))
139            .unwrap_or_default()
140    }
141
142    /// Get the full container path by joining the container stack
143    pub fn container_path(&self) -> String {
144        self.container_stack
145            .as_ref()
146            .map(|stack| stack.join("."))
147            .unwrap_or_default()
148    }
149
150    /// Get the immediate container type (last in the container stack)
151    pub fn container_type(&self) -> String {
152        self.container_stack
153            .as_ref()
154            .and_then(|stack| stack.last())
155            .cloned()
156            .unwrap_or_else(|| "Unknown".to_string())
157    }
158
159    /// Create a TensorSnapshot from a closure that produces TensorData
160    /// This is used internally for lazy loading
161    pub fn from_closure(
162        data_fn: Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>>,
163        dtype: burn_tensor::DType,
164        shape: Vec<usize>,
165        path_stack: Vec<String>,
166        container_stack: Vec<String>,
167        tensor_id: ParamId,
168    ) -> Self {
169        Self {
170            data_fn,
171            dtype,
172            shape,
173            path_stack: Some(path_stack),
174            container_stack: Some(container_stack),
175            tensor_id: Some(tensor_id),
176        }
177    }
178
179    /// Create a TensorSnapshot from TensorData directly
180    pub fn from_data(
181        data: TensorData,
182        path_stack: Vec<String>,
183        container_stack: Vec<String>,
184        tensor_id: ParamId,
185    ) -> Self {
186        let dtype = data.dtype;
187        let shape = data.shape.clone();
188        Self {
189            data_fn: Rc::new(move || Ok(data.clone())),
190            dtype,
191            shape,
192            path_stack: Some(path_stack),
193            container_stack: Some(container_stack),
194            tensor_id: Some(tensor_id),
195        }
196    }
197
198    /// Get the size of the tensor data in bytes without materializing it
199    pub fn data_len(&self) -> usize {
200        self.shape.iter().product::<usize>() * self.dtype.size()
201    }
202
203    /// Clone the data function for lazy composition
204    pub fn clone_data_fn(&self) -> Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>> {
205        self.data_fn.clone()
206    }
207}
208
209impl Clone for TensorSnapshot {
210    fn clone(&self) -> Self {
211        // Clone lazily - keep the same data function
212        Self {
213            data_fn: self.data_fn.clone(),
214            dtype: self.dtype,
215            shape: self.shape.clone(),
216            path_stack: self.path_stack.clone(),
217            container_stack: self.container_stack.clone(),
218            tensor_id: self.tensor_id,
219        }
220    }
221}
222
223impl core::fmt::Debug for TensorSnapshot {
224    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
225        f.debug_struct("TensorSnapshot")
226            .field("dtype", &self.dtype)
227            .field("shape", &self.shape)
228            .field("path_stack", &self.path_stack)
229            .field("container_stack", &self.container_stack)
230            .field("tensor_id", &self.tensor_id)
231            .finish()
232    }
233}
234
#[cfg(all(test, feature = "std"))]
mod tests {
    // `use super::*` brings in the parent's imports too (Rc, String, ToString, ...),
    // so no duplicate local imports are needed in individual tests.
    use super::*;
    type TestBackend = burn_ndarray::NdArray;
    use burn_tensor::DType;

    #[test]
    fn tensor_view_float() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);

        let snapshot = TensorSnapshot::from_float(
            &tensor,
            vec!["test".to_string(), "weight".to_string()],
            vec!["TestModule".to_string(), "Param".to_string()],
            ParamId::new(),
        );

        // Test metadata access without materialization
        assert_eq!(snapshot.dtype, DType::F32);
        assert_eq!(snapshot.shape, vec![2, 2]);
        assert_eq!(snapshot.full_path(), "test.weight");
        assert_eq!(snapshot.container_path(), "TestModule.Param");

        // Test data materialization
        let data = snapshot.to_data().unwrap();
        assert_eq!(data.shape, vec![2, 2]);
        assert_eq!(data.dtype, DType::F32);
    }

    #[test]
    fn tensor_view_int() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 2, Int>::from_data([[1, 2], [3, 4]], &device);

        let snapshot = TensorSnapshot::from_int(
            &tensor,
            vec!["test".to_string(), "int".to_string()],
            vec!["TestModule".to_string(), "Param".to_string()],
            ParamId::new(),
        );

        // Test metadata access without materialization
        // TestBackend uses I64 for integers
        assert_eq!(snapshot.dtype, DType::I64);
        assert_eq!(snapshot.shape, vec![2, 2]);

        let data = snapshot.to_data().unwrap();
        assert_eq!(data.shape, vec![2, 2]);
        assert_eq!(data.dtype, DType::I64);
    }

    #[test]
    fn tensor_view_bool() {
        let device = Default::default();
        let tensor =
            Tensor::<TestBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);

        let snapshot = TensorSnapshot::from_bool(
            &tensor,
            vec!["test".to_string(), "bool".to_string()],
            vec!["TestModule".to_string(), "Param".to_string()],
            ParamId::new(),
        );

        // Test metadata access without materialization
        assert_eq!(snapshot.dtype, DType::Bool);
        assert_eq!(snapshot.shape, vec![2, 2]);

        let data = snapshot.to_data().unwrap();
        assert_eq!(data.shape, vec![2, 2]);
        assert_eq!(data.dtype, DType::Bool);
    }

    #[test]
    fn data_len() {
        let device = Default::default();

        // Test F32 tensor (4 bytes per element)
        let tensor_f32 = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
        let view_f32 = TensorSnapshot::from_float(
            &tensor_f32,
            vec!["test".to_string()],
            vec!["Module".to_string()],
            ParamId::new(),
        );
        assert_eq!(view_f32.data_len(), 16); // 4 elements * 4 bytes

        // Test I64 tensor (8 bytes per element) - TestBackend uses I64 for Int
        let tensor_i64 =
            Tensor::<TestBackend, 3, Int>::from_data([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], &device);
        let view_i64 = TensorSnapshot::from_int(
            &tensor_i64,
            vec!["test".to_string()],
            vec!["Module".to_string()],
            ParamId::new(),
        );
        assert_eq!(view_i64.data_len(), 64); // 8 elements * 8 bytes (I64)

        // Test Bool tensor (1 byte per element)
        let tensor_bool =
            Tensor::<TestBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);
        let view_bool = TensorSnapshot::from_bool(
            &tensor_bool,
            vec!["test".to_string()],
            vec!["Module".to_string()],
            ParamId::new(),
        );
        assert_eq!(view_bool.data_len(), 4); // 4 elements * 1 byte
    }

    #[test]
    fn from_closure() {
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);
        let dtype = data.dtype;
        let shape = data.shape.clone();

        let snapshot = TensorSnapshot::from_closure(
            Rc::new(move || Ok(data.clone())),
            dtype,
            shape.clone(),
            vec!["model".to_string(), "layer".to_string()],
            vec!["Model".to_string(), "Layer".to_string()],
            ParamId::new(),
        );

        // Test metadata access
        assert_eq!(snapshot.dtype, DType::F32);
        assert_eq!(snapshot.shape, vec![4]);
        assert_eq!(snapshot.full_path(), "model.layer");
        assert_eq!(snapshot.data_len(), 16); // 4 * 4 bytes

        // Test data materialization
        let materialized = snapshot.to_data().unwrap();
        assert_eq!(materialized.shape, vec![4]);
    }

    #[test]
    fn from_data() {
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let original_dtype = data.dtype;
        let original_shape = data.shape.clone();

        let snapshot = TensorSnapshot::from_data(
            data,
            vec!["encoder".to_string(), "weight".to_string()],
            vec!["Encoder".to_string(), "Dense".to_string()],
            ParamId::new(),
        );

        // Test metadata
        assert_eq!(snapshot.dtype, original_dtype);
        assert_eq!(snapshot.shape, original_shape);
        assert_eq!(snapshot.full_path(), "encoder.weight");
        assert_eq!(snapshot.container_type(), "Dense");
        assert_eq!(snapshot.data_len(), 24); // 6 * 4 bytes

        // Test data materialization
        let materialized = snapshot.to_data().unwrap();
        assert_eq!(materialized.shape, original_shape);
    }

    // NOTE: no extra #[cfg(feature = "std")] needed here — the whole module is
    // already gated on `all(test, feature = "std")`.
    #[test]
    fn panic_catching_in_to_data() {
        // Create a TensorSnapshot with a closure that panics
        let snapshot = TensorSnapshot {
            data_fn: Rc::new(|| panic!("Test panic in data_fn")),
            dtype: DType::F32,
            shape: vec![2, 2],
            path_stack: Some(vec!["test".to_string()]),
            container_stack: Some(vec!["Test".to_string()]),
            tensor_id: Some(ParamId::new()),
        };

        // When std is available, to_data should catch the panic and return an error
        let result = snapshot.to_data();
        assert!(result.is_err());

        match result {
            Err(TensorSnapshotError::PanicError(msg)) => {
                assert!(msg.contains("Panic occurred"));
            }
            _ => panic!("Expected PanicError with panic message"),
        }
    }

    #[test]
    fn error_propagation_in_closure() {
        // Create a snapshot with a closure that returns an error
        let snapshot = TensorSnapshot::from_closure(
            Rc::new(|| Err(TensorSnapshotError::IoError("Simulated IO error".into()))),
            DType::F32,
            vec![2, 2],
            vec!["error_test".into()],
            vec![],
            ParamId::new(),
        );

        // Should return an error when trying to get data
        let result = snapshot.to_data();
        assert!(result.is_err());
        match result {
            Err(TensorSnapshotError::IoError(msg)) => {
                assert!(msg.contains("Simulated IO error"));
            }
            _ => panic!("Expected IoError"),
        }
    }

    #[test]
    fn container_type_extraction() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 1>::from_data([1.0, 2.0, 3.0], &device);

        let snapshot = TensorSnapshot::from_float(
            &tensor,
            vec![
                "model".to_string(),
                "layer1".to_string(),
                "weight".to_string(),
            ],
            vec![
                "Model".to_string(),
                "Conv2d".to_string(),
                "Param".to_string(),
            ],
            ParamId::new(),
        );

        assert_eq!(snapshot.container_type(), "Param");
        assert_eq!(snapshot.container_path(), "Model.Conv2d.Param");
        assert_eq!(snapshot.full_path(), "model.layer1.weight");
    }
}