1use alloc::rc::Rc;
2use alloc::string::String;
3use alloc::string::ToString;
4use alloc::vec::Vec;
5use burn_core::module::ParamId;
6use burn_tensor::quantization::{QPARAM_ALIGN, QuantParam, params_shape};
7use burn_tensor::{Bool, DType, Int, Shape, Tensor, TensorData, backend::Backend};
8use half::f16;
9
10const fn quant_param_size(param: QuantParam) -> usize {
13 match param {
14 QuantParam::F32 => core::mem::size_of::<f32>(),
15 QuantParam::F16 | QuantParam::BF16 => core::mem::size_of::<f16>(),
16 QuantParam::UE8M0 | QuantParam::UE4M3 => core::mem::size_of::<u8>(),
17 }
18}
19
/// Errors that can be reported when a tensor snapshot produces its data.
#[derive(Debug, Clone)]
pub enum TensorSnapshotError {
    /// An I/O failure, described by the contained message.
    IoError(String),
    /// A data-related failure, described by the contained message.
    DataError(String),
    /// A panic was caught while producing the data; the message describes it.
    PanicError(String),
}

impl core::fmt::Display for TensorSnapshotError {
    /// Writes `"<label>: <message>"`, where the label depends on the variant.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let (label, detail) = match self {
            Self::IoError(detail) => ("I/O error", detail),
            Self::DataError(detail) => ("Data error", detail),
            Self::PanicError(detail) => ("Panic error", detail),
        };
        write!(f, "{label}: {detail}")
    }
}

impl core::error::Error for TensorSnapshotError {}
42
43pub struct TensorSnapshot {
52 data_fn: Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>>,
54 pub dtype: burn_tensor::DType,
56 pub shape: Shape,
58 pub path_stack: Option<Vec<String>>,
60 pub container_stack: Option<Vec<String>>,
62 pub tensor_id: Option<ParamId>,
64}
65
66impl TensorSnapshot {
67 pub fn from_float<B: Backend, const D: usize>(
69 tensor: &Tensor<B, D>,
70 path_stack: Vec<String>,
71 container_stack: Vec<String>,
72 tensor_id: ParamId,
73 ) -> Self {
74 let dtype = tensor.dtype();
75 let shape = tensor.shape();
76 let tensor = tensor.clone(); Self {
78 data_fn: Rc::new(move || Ok(tensor.to_data())),
79 dtype,
80 shape,
81 path_stack: Some(path_stack),
82 container_stack: Some(container_stack),
83 tensor_id: Some(tensor_id),
84 }
85 }
86
87 pub fn from_int<B: Backend, const D: usize>(
89 tensor: &Tensor<B, D, Int>,
90 path_stack: Vec<String>,
91 container_stack: Vec<String>,
92 tensor_id: ParamId,
93 ) -> Self {
94 let dtype = tensor.dtype();
95 let shape = tensor.shape();
96 let tensor = tensor.clone(); Self {
98 data_fn: Rc::new(move || Ok(tensor.to_data())),
99 dtype,
100 shape,
101 path_stack: Some(path_stack),
102 container_stack: Some(container_stack),
103 tensor_id: Some(tensor_id),
104 }
105 }
106
107 pub fn from_bool<B: Backend, const D: usize>(
109 tensor: &Tensor<B, D, Bool>,
110 path_stack: Vec<String>,
111 container_stack: Vec<String>,
112 tensor_id: ParamId,
113 ) -> Self {
114 let dtype = tensor.dtype();
115 let shape = tensor.shape();
116 let tensor = tensor.clone(); Self {
118 data_fn: Rc::new(move || Ok(tensor.to_data())),
119 dtype,
120 shape,
121 path_stack: Some(path_stack),
122 container_stack: Some(container_stack),
123 tensor_id: Some(tensor_id),
124 }
125 }
126
127 #[cfg(feature = "std")]
129 pub fn to_data(&self) -> Result<TensorData, TensorSnapshotError> {
130 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| (self.data_fn)())).unwrap_or_else(
132 |_| {
133 Err(TensorSnapshotError::PanicError(
134 "Panic occurred while loading tensor data".to_string(),
135 ))
136 },
137 )
138 }
139
140 #[cfg(not(feature = "std"))]
142 pub fn to_data(&self) -> Result<TensorData, TensorSnapshotError> {
143 (self.data_fn)() }
145
146 pub fn full_path(&self) -> String {
148 self.path_stack
149 .as_ref()
150 .map(|stack| stack.join("."))
151 .unwrap_or_default()
152 }
153
154 pub fn container_path(&self) -> String {
156 self.container_stack
157 .as_ref()
158 .map(|stack| stack.join("."))
159 .unwrap_or_default()
160 }
161
162 pub fn module_type(&self) -> Option<String> {
174 self.container_stack.as_ref().and_then(|stack| {
175 stack
177 .iter()
178 .rev()
179 .find(|ct| ct.starts_with("Struct:") || ct.starts_with("Enum:"))
180 .cloned()
181 })
182 }
183
184 pub fn container_type(&self) -> String {
195 self.container_stack
196 .as_ref()
197 .and_then(|stack| stack.last())
198 .cloned()
199 .unwrap_or_else(|| "Unknown".to_string())
200 }
201
202 pub fn from_closure(
205 data_fn: Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>>,
206 dtype: burn_tensor::DType,
207 shape: Shape,
208 path_stack: Vec<String>,
209 container_stack: Vec<String>,
210 tensor_id: ParamId,
211 ) -> Self {
212 Self {
213 data_fn,
214 dtype,
215 shape,
216 path_stack: Some(path_stack),
217 container_stack: Some(container_stack),
218 tensor_id: Some(tensor_id),
219 }
220 }
221
222 pub fn from_data(
224 data: TensorData,
225 path_stack: Vec<String>,
226 container_stack: Vec<String>,
227 tensor_id: ParamId,
228 ) -> Self {
229 let dtype = data.dtype;
230 let shape = data.shape.clone();
231 Self {
232 data_fn: Rc::new(move || Ok(data.clone())),
233 dtype,
234 shape,
235 path_stack: Some(path_stack),
236 container_stack: Some(container_stack),
237 tensor_id: Some(tensor_id),
238 }
239 }
240
241 pub fn data_len(&self) -> usize {
250 const BITS_PER_BYTE: usize = 8;
251
252 let num_elements: usize = self.shape.iter().product();
253
254 match self.dtype {
255 DType::QFloat(scheme) => {
256 let num_storage_elements = num_elements.div_ceil(scheme.num_quants());
258 let value_bytes =
259 num_storage_elements * (scheme.size_bits_stored() / BITS_PER_BYTE);
260
261 let num_params = params_shape(&self.shape, scheme.level).num_elements();
263
264 let aligned_value_bytes = value_bytes.div_ceil(QPARAM_ALIGN) * QPARAM_ALIGN;
265 let scale_bytes = num_params * quant_param_size(scheme.param);
266
267 aligned_value_bytes + scale_bytes
268 }
269 _ => num_elements * self.dtype.size(),
270 }
271 }
272
273 pub fn clone_data_fn(&self) -> Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>> {
275 self.data_fn.clone()
276 }
277}
278
279impl Clone for TensorSnapshot {
280 fn clone(&self) -> Self {
281 Self {
283 data_fn: self.data_fn.clone(),
284 dtype: self.dtype,
285 shape: self.shape.clone(),
286 path_stack: self.path_stack.clone(),
287 container_stack: self.container_stack.clone(),
288 tensor_id: self.tensor_id,
289 }
290 }
291}
292
293impl core::fmt::Debug for TensorSnapshot {
294 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
295 f.debug_struct("TensorSnapshot")
296 .field("dtype", &self.dtype)
297 .field("shape", &self.shape)
298 .field("path_stack", &self.path_stack)
299 .field("container_stack", &self.container_stack)
300 .field("tensor_id", &self.tensor_id)
301 .finish()
302 }
303}
304
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    type TestBackend = burn_flex::Flex;
    use alloc::string::ToString;
    use burn_tensor::{BoolStore, DType, shape};

    /// Builds an owned `Vec<String>` from string slices.
    fn strings(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|part| (*part).to_string()).collect()
    }

    #[test]
    fn tensor_view_float() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);

        let snapshot = TensorSnapshot::from_float(
            &tensor,
            strings(&["test", "weight"]),
            strings(&["TestModule", "Param"]),
            ParamId::new(),
        );

        // Metadata recorded on the snapshot.
        assert_eq!(snapshot.dtype, DType::F32);
        assert_eq!(snapshot.shape, shape![2, 2]);
        assert_eq!(snapshot.full_path(), "test.weight");
        assert_eq!(snapshot.container_path(), "TestModule.Param");

        // Data produced by the closure.
        let data = snapshot.to_data().unwrap();
        assert_eq!(data.shape, shape![2, 2]);
        assert_eq!(data.dtype, DType::F32);
    }

    #[test]
    fn tensor_view_int() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 2, Int>::from_data([[1, 2], [3, 4]], &device);

        let snapshot = TensorSnapshot::from_int(
            &tensor,
            strings(&["test", "int"]),
            strings(&["TestModule", "Param"]),
            ParamId::new(),
        );

        assert_eq!(snapshot.dtype, DType::I32);
        assert_eq!(snapshot.shape, shape![2, 2]);

        let data = snapshot.to_data().unwrap();
        assert_eq!(data.shape, shape![2, 2]);
        assert_eq!(data.dtype, DType::I32);
    }

    #[test]
    fn tensor_view_bool() {
        let device = Default::default();
        let tensor =
            Tensor::<TestBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);

        let snapshot = TensorSnapshot::from_bool(
            &tensor,
            strings(&["test", "bool"]),
            strings(&["TestModule", "Param"]),
            ParamId::new(),
        );

        assert_eq!(snapshot.dtype, DType::Bool(BoolStore::Native));
        assert_eq!(snapshot.shape, shape![2, 2]);

        let data = snapshot.to_data().unwrap();
        assert_eq!(data.shape, shape![2, 2]);
        assert_eq!(data.dtype, DType::Bool(BoolStore::Native));
    }

    #[test]
    fn data_len() {
        let device = Default::default();

        // 2x2 float tensor: 4 elements, expected 16 bytes.
        let tensor_f32 = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
        let view_f32 = TensorSnapshot::from_float(
            &tensor_f32,
            strings(&["test"]),
            strings(&["Module"]),
            ParamId::new(),
        );
        assert_eq!(view_f32.data_len(), 16);

        // 2x2x2 int tensor: 8 elements, expected 32 bytes.
        let tensor_int =
            Tensor::<TestBackend, 3, Int>::from_data([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], &device);
        let view_int = TensorSnapshot::from_int(
            &tensor_int,
            strings(&["test"]),
            strings(&["Module"]),
            ParamId::new(),
        );
        assert_eq!(view_int.data_len(), 32);

        // 2x2 bool tensor: 4 elements, expected 4 bytes.
        let tensor_bool =
            Tensor::<TestBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);
        let view_bool = TensorSnapshot::from_bool(
            &tensor_bool,
            strings(&["test"]),
            strings(&["Module"]),
            ParamId::new(),
        );
        assert_eq!(view_bool.data_len(), 4);
    }

    #[test]
    fn from_closure() {
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);
        let dtype = data.dtype;
        let shape = data.shape.clone();

        let snapshot = TensorSnapshot::from_closure(
            Rc::new(move || Ok(data.clone())),
            dtype,
            shape.clone(),
            strings(&["model", "layer"]),
            strings(&["Model", "Layer"]),
            ParamId::new(),
        );

        assert_eq!(snapshot.dtype, DType::F32);
        assert_eq!(snapshot.shape, shape![4]);
        assert_eq!(snapshot.full_path(), "model.layer");
        assert_eq!(snapshot.data_len(), 16);

        let materialized = snapshot.to_data().unwrap();
        assert_eq!(materialized.shape, shape![4]);
    }

    #[test]
    fn from_data() {
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let original_dtype = data.dtype;
        let original_shape = data.shape.clone();

        let snapshot = TensorSnapshot::from_data(
            data,
            strings(&["encoder", "weight"]),
            strings(&["Struct:Encoder", "Struct:Dense"]),
            ParamId::new(),
        );

        assert_eq!(snapshot.dtype, original_dtype);
        assert_eq!(snapshot.shape, original_shape);
        assert_eq!(snapshot.full_path(), "encoder.weight");
        assert_eq!(snapshot.container_type(), "Struct:Dense");
        assert_eq!(snapshot.data_len(), 24);

        let materialized = snapshot.to_data().unwrap();
        assert_eq!(materialized.shape, original_shape);
    }

    #[test]
    #[cfg(feature = "std")]
    fn panic_catching_in_to_data() {
        // Built field by field so the closure can panic when invoked.
        let snapshot = TensorSnapshot {
            data_fn: Rc::new(|| panic!("Test panic in data_fn")),
            dtype: DType::F32,
            shape: shape![2, 2],
            path_stack: Some(strings(&["test"])),
            container_stack: Some(strings(&["Test"])),
            tensor_id: Some(ParamId::new()),
        };

        let result = snapshot.to_data();
        assert!(result.is_err());

        let Err(TensorSnapshotError::PanicError(msg)) = result else {
            panic!("Expected PanicError with panic message");
        };
        assert!(msg.contains("Panic occurred"));
    }

    #[test]
    fn error_propagation_in_closure() {
        let snapshot = TensorSnapshot::from_closure(
            Rc::new(|| Err(TensorSnapshotError::IoError("Simulated IO error".into()))),
            DType::F32,
            shape![2, 2],
            strings(&["error_test"]),
            Vec::new(),
            ParamId::new(),
        );

        let result = snapshot.to_data();
        assert!(result.is_err());

        let Err(TensorSnapshotError::IoError(msg)) = result else {
            panic!("Expected IoError");
        };
        assert!(msg.contains("Simulated IO error"));
    }

    #[test]
    fn container_type_extraction() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 1>::from_data([1.0, 2.0, 3.0], &device);

        let snapshot = TensorSnapshot::from_float(
            &tensor,
            strings(&["model", "layer1", "weight"]),
            strings(&["Struct:Model", "Struct:Conv2d", "Struct:Param"]),
            ParamId::new(),
        );

        assert_eq!(snapshot.container_type(), "Struct:Param");
        assert_eq!(snapshot.module_type(), Some("Struct:Param".to_string()));
        assert_eq!(
            snapshot.container_path(),
            "Struct:Model.Struct:Conv2d.Struct:Param"
        );
        assert_eq!(snapshot.full_path(), "model.layer1.weight");
    }

    #[test]
    fn container_type_vs_module_type() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 1>::from_data([1.0, 2.0, 3.0], &device);

        // A struct nested inside a Vec: both accessors report the struct.
        let snapshot = TensorSnapshot::from_float(
            &tensor,
            strings(&["model", "layers", "0", "weight"]),
            strings(&["Struct:Model", "Vec", "Struct:Linear"]),
            ParamId::new(),
        );

        assert_eq!(snapshot.container_type(), "Struct:Linear");
        assert_eq!(snapshot.module_type(), Some("Struct:Linear".to_string()));

        // Only a Vec in the stack: there is a container type but no module type.
        let snapshot2 = TensorSnapshot::from_float(
            &tensor,
            strings(&["data", "0"]),
            strings(&["Vec"]),
            ParamId::new(),
        );

        assert_eq!(snapshot2.container_type(), "Vec");
        assert_eq!(snapshot2.module_type(), None);

        // Vec and Array entries precede the innermost struct.
        let snapshot3 = TensorSnapshot::from_float(
            &tensor,
            strings(&["model", "layers", "0", "sublayers", "1", "weight"]),
            strings(&["Struct:Model", "Vec", "Array", "Struct:Linear"]),
            ParamId::new(),
        );

        assert_eq!(snapshot3.container_type(), "Struct:Linear");
        assert_eq!(snapshot3.module_type(), Some("Struct:Linear".to_string()));
    }
}