1use alloc::rc::Rc;
2use alloc::string::String;
3use alloc::string::ToString;
4use alloc::vec::Vec;
5use burn_core::module::ParamId;
6use burn_tensor::quantization::{QPARAM_ALIGN, QuantParam, params_shape};
7use burn_tensor::{Bool, DType, Int, Shape, Tensor, TensorData, backend::Backend};
8use half::f16;
9
10const fn quant_param_size(param: QuantParam) -> usize {
13 match param {
14 QuantParam::F32 => core::mem::size_of::<f32>(),
15 QuantParam::F16 | QuantParam::BF16 => core::mem::size_of::<f16>(),
16 QuantParam::UE8M0 | QuantParam::UE4M3 => core::mem::size_of::<u8>(),
17 }
18}
19
/// Errors that can be reported while producing the data of a tensor snapshot.
///
/// Every variant wraps a human-readable message; the variant only selects the
/// prefix used when the error is displayed.
#[derive(Debug, Clone)]
pub enum TensorSnapshotError {
    /// Displayed as `I/O error: <message>`.
    IoError(String),
    /// Displayed as `Data error: <message>`.
    DataError(String),
    /// Displayed as `Panic error: <message>`.
    PanicError(String),
}

impl core::fmt::Display for TensorSnapshotError {
    /// Writes the variant-specific prefix followed by the wrapped message.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            TensorSnapshotError::IoError(message) => {
                f.write_fmt(format_args!("I/O error: {}", message))
            }
            TensorSnapshotError::DataError(message) => {
                f.write_fmt(format_args!("Data error: {}", message))
            }
            TensorSnapshotError::PanicError(message) => {
                f.write_fmt(format_args!("Panic error: {}", message))
            }
        }
    }
}

// No underlying source error is tracked, so the default trait methods suffice.
impl core::error::Error for TensorSnapshotError {}
42
43pub struct TensorSnapshot {
52 data_fn: Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>>,
54 pub dtype: burn_tensor::DType,
56 pub shape: Vec<usize>,
58 pub path_stack: Option<Vec<String>>,
60 pub container_stack: Option<Vec<String>>,
62 pub tensor_id: Option<ParamId>,
64}
65
66impl TensorSnapshot {
67 pub fn from_float<B: Backend, const D: usize>(
69 tensor: &Tensor<B, D>,
70 path_stack: Vec<String>,
71 container_stack: Vec<String>,
72 tensor_id: ParamId,
73 ) -> Self {
74 let dtype = tensor.dtype();
75 let shape = tensor.shape().to_vec();
76 let tensor = tensor.clone(); Self {
78 data_fn: Rc::new(move || Ok(tensor.to_data())),
79 dtype,
80 shape,
81 path_stack: Some(path_stack),
82 container_stack: Some(container_stack),
83 tensor_id: Some(tensor_id),
84 }
85 }
86
87 pub fn from_int<B: Backend, const D: usize>(
89 tensor: &Tensor<B, D, Int>,
90 path_stack: Vec<String>,
91 container_stack: Vec<String>,
92 tensor_id: ParamId,
93 ) -> Self {
94 let dtype = tensor.dtype();
95 let shape = tensor.shape().to_vec();
96 let tensor = tensor.clone(); Self {
98 data_fn: Rc::new(move || Ok(tensor.to_data())),
99 dtype,
100 shape,
101 path_stack: Some(path_stack),
102 container_stack: Some(container_stack),
103 tensor_id: Some(tensor_id),
104 }
105 }
106
107 pub fn from_bool<B: Backend, const D: usize>(
109 tensor: &Tensor<B, D, Bool>,
110 path_stack: Vec<String>,
111 container_stack: Vec<String>,
112 tensor_id: ParamId,
113 ) -> Self {
114 let dtype = tensor.dtype();
115 let shape = tensor.shape().to_vec();
116 let tensor = tensor.clone(); Self {
118 data_fn: Rc::new(move || Ok(tensor.to_data())),
119 dtype,
120 shape,
121 path_stack: Some(path_stack),
122 container_stack: Some(container_stack),
123 tensor_id: Some(tensor_id),
124 }
125 }
126
127 #[cfg(feature = "std")]
129 pub fn to_data(&self) -> Result<TensorData, TensorSnapshotError> {
130 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| (self.data_fn)())).unwrap_or_else(
132 |_| {
133 Err(TensorSnapshotError::PanicError(
134 "Panic occurred while loading tensor data".to_string(),
135 ))
136 },
137 )
138 }
139
140 #[cfg(not(feature = "std"))]
142 pub fn to_data(&self) -> Result<TensorData, TensorSnapshotError> {
143 (self.data_fn)() }
145
146 pub fn full_path(&self) -> String {
148 self.path_stack
149 .as_ref()
150 .map(|stack| stack.join("."))
151 .unwrap_or_default()
152 }
153
154 pub fn container_path(&self) -> String {
156 self.container_stack
157 .as_ref()
158 .map(|stack| stack.join("."))
159 .unwrap_or_default()
160 }
161
162 pub fn module_type(&self) -> Option<String> {
174 self.container_stack.as_ref().and_then(|stack| {
175 stack
177 .iter()
178 .rev()
179 .find(|ct| ct.starts_with("Struct:") || ct.starts_with("Enum:"))
180 .cloned()
181 })
182 }
183
184 pub fn container_type(&self) -> String {
195 self.container_stack
196 .as_ref()
197 .and_then(|stack| stack.last())
198 .cloned()
199 .unwrap_or_else(|| "Unknown".to_string())
200 }
201
202 pub fn from_closure(
205 data_fn: Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>>,
206 dtype: burn_tensor::DType,
207 shape: Vec<usize>,
208 path_stack: Vec<String>,
209 container_stack: Vec<String>,
210 tensor_id: ParamId,
211 ) -> Self {
212 Self {
213 data_fn,
214 dtype,
215 shape,
216 path_stack: Some(path_stack),
217 container_stack: Some(container_stack),
218 tensor_id: Some(tensor_id),
219 }
220 }
221
222 pub fn from_data(
224 data: TensorData,
225 path_stack: Vec<String>,
226 container_stack: Vec<String>,
227 tensor_id: ParamId,
228 ) -> Self {
229 let dtype = data.dtype;
230 let shape = data.shape.clone();
231 Self {
232 data_fn: Rc::new(move || Ok(data.clone())),
233 dtype,
234 shape,
235 path_stack: Some(path_stack),
236 container_stack: Some(container_stack),
237 tensor_id: Some(tensor_id),
238 }
239 }
240
241 pub fn data_len(&self) -> usize {
250 const BITS_PER_BYTE: usize = 8;
251
252 let num_elements: usize = self.shape.iter().product();
253
254 match self.dtype {
255 DType::QFloat(scheme) => {
256 let num_storage_elements = num_elements.div_ceil(scheme.num_quants());
258 let value_bytes =
259 num_storage_elements * (scheme.size_bits_stored() / BITS_PER_BYTE);
260
261 let num_params =
263 params_shape(&Shape::from(self.shape.clone()), scheme.level).num_elements();
264
265 let aligned_value_bytes = value_bytes.div_ceil(QPARAM_ALIGN) * QPARAM_ALIGN;
266 let scale_bytes = num_params * quant_param_size(scheme.param);
267
268 aligned_value_bytes + scale_bytes
269 }
270 _ => num_elements * self.dtype.size(),
271 }
272 }
273
274 pub fn clone_data_fn(&self) -> Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>> {
276 self.data_fn.clone()
277 }
278}
279
280impl Clone for TensorSnapshot {
281 fn clone(&self) -> Self {
282 Self {
284 data_fn: self.data_fn.clone(),
285 dtype: self.dtype,
286 shape: self.shape.clone(),
287 path_stack: self.path_stack.clone(),
288 container_stack: self.container_stack.clone(),
289 tensor_id: self.tensor_id,
290 }
291 }
292}
293
294impl core::fmt::Debug for TensorSnapshot {
295 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
296 f.debug_struct("TensorSnapshot")
297 .field("dtype", &self.dtype)
298 .field("shape", &self.shape)
299 .field("path_stack", &self.path_stack)
300 .field("container_stack", &self.container_stack)
301 .field("tensor_id", &self.tensor_id)
302 .finish()
303 }
304}
305
306#[cfg(all(test, feature = "std"))]
307mod tests {
308 use super::*;
309 type TestBackend = burn_ndarray::NdArray;
310 use alloc::string::ToString;
311 use burn_tensor::DType;
312
313 #[test]
314 fn tensor_view_float() {
315 let device = Default::default();
316 let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
317
318 let snapshot = TensorSnapshot::from_float(
319 &tensor,
320 vec!["test".to_string(), "weight".to_string()],
321 vec!["TestModule".to_string(), "Param".to_string()],
322 ParamId::new(),
323 );
324
325 assert_eq!(snapshot.dtype, DType::F32);
327 assert_eq!(snapshot.shape, vec![2, 2]);
328 assert_eq!(snapshot.full_path(), "test.weight");
329 assert_eq!(snapshot.container_path(), "TestModule.Param");
330
331 let data = snapshot.to_data().unwrap();
333 assert_eq!(data.shape, vec![2, 2]);
334 assert_eq!(data.dtype, DType::F32);
335 }
336
337 #[test]
338 fn tensor_view_int() {
339 let device = Default::default();
340 let tensor = Tensor::<TestBackend, 2, Int>::from_data([[1, 2], [3, 4]], &device);
341
342 let snapshot = TensorSnapshot::from_int(
343 &tensor,
344 vec!["test".to_string(), "int".to_string()],
345 vec!["TestModule".to_string(), "Param".to_string()],
346 ParamId::new(),
347 );
348
349 assert_eq!(snapshot.dtype, DType::I64);
352 assert_eq!(snapshot.shape, vec![2, 2]);
353
354 let data = snapshot.to_data().unwrap();
355 assert_eq!(data.shape, vec![2, 2]);
356 assert_eq!(data.dtype, DType::I64);
357 }
358
359 #[test]
360 fn tensor_view_bool() {
361 let device = Default::default();
362 let tensor =
363 Tensor::<TestBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);
364
365 let snapshot = TensorSnapshot::from_bool(
366 &tensor,
367 vec!["test".to_string(), "bool".to_string()],
368 vec!["TestModule".to_string(), "Param".to_string()],
369 ParamId::new(),
370 );
371
372 assert_eq!(snapshot.dtype, DType::Bool);
374 assert_eq!(snapshot.shape, vec![2, 2]);
375
376 let data = snapshot.to_data().unwrap();
377 assert_eq!(data.shape, vec![2, 2]);
378 assert_eq!(data.dtype, DType::Bool);
379 }
380
381 #[test]
382 fn data_len() {
383 let device = Default::default();
384
385 let tensor_f32 = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
387 let view_f32 = TensorSnapshot::from_float(
388 &tensor_f32,
389 vec!["test".to_string()],
390 vec!["Module".to_string()],
391 ParamId::new(),
392 );
393 assert_eq!(view_f32.data_len(), 16); let tensor_i64 =
397 Tensor::<TestBackend, 3, Int>::from_data([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], &device);
398 let view_i64 = TensorSnapshot::from_int(
399 &tensor_i64,
400 vec!["test".to_string()],
401 vec!["Module".to_string()],
402 ParamId::new(),
403 );
404 assert_eq!(view_i64.data_len(), 64); let tensor_bool =
408 Tensor::<TestBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);
409 let view_bool = TensorSnapshot::from_bool(
410 &tensor_bool,
411 vec!["test".to_string()],
412 vec!["Module".to_string()],
413 ParamId::new(),
414 );
415 assert_eq!(view_bool.data_len(), 4); }
417
418 #[test]
419 fn from_closure() {
420 let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);
421 let dtype = data.dtype;
422 let shape = data.shape.clone();
423
424 let snapshot = TensorSnapshot::from_closure(
425 Rc::new(move || Ok(data.clone())),
426 dtype,
427 shape.clone(),
428 vec!["model".to_string(), "layer".to_string()],
429 vec!["Model".to_string(), "Layer".to_string()],
430 ParamId::new(),
431 );
432
433 assert_eq!(snapshot.dtype, DType::F32);
435 assert_eq!(snapshot.shape, vec![4]);
436 assert_eq!(snapshot.full_path(), "model.layer");
437 assert_eq!(snapshot.data_len(), 16); let materialized = snapshot.to_data().unwrap();
441 assert_eq!(materialized.shape, vec![4]);
442 }
443
444 #[test]
445 fn from_data() {
446 let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
447 let original_dtype = data.dtype;
448 let original_shape = data.shape.clone();
449
450 let snapshot = TensorSnapshot::from_data(
451 data,
452 vec!["encoder".to_string(), "weight".to_string()],
453 vec!["Struct:Encoder".to_string(), "Struct:Dense".to_string()],
454 ParamId::new(),
455 );
456
457 assert_eq!(snapshot.dtype, original_dtype);
459 assert_eq!(snapshot.shape, original_shape);
460 assert_eq!(snapshot.full_path(), "encoder.weight");
461 assert_eq!(snapshot.container_type(), "Struct:Dense");
462 assert_eq!(snapshot.data_len(), 24); let materialized = snapshot.to_data().unwrap();
466 assert_eq!(materialized.shape, original_shape);
467 }
468
469 #[test]
470 #[cfg(feature = "std")]
471 fn panic_catching_in_to_data() {
472 use alloc::rc::Rc;
473
474 let snapshot = TensorSnapshot {
476 data_fn: Rc::new(|| panic!("Test panic in data_fn")),
477 dtype: DType::F32,
478 shape: vec![2, 2],
479 path_stack: Some(vec!["test".to_string()]),
480 container_stack: Some(vec!["Test".to_string()]),
481 tensor_id: Some(ParamId::new()),
482 };
483
484 let result = snapshot.to_data();
486 assert!(result.is_err());
487
488 match result {
489 Err(TensorSnapshotError::PanicError(msg)) => {
490 assert!(msg.contains("Panic occurred"));
491 }
492 _ => panic!("Expected PanicError with panic message"),
493 }
494 }
495
496 #[test]
497 fn error_propagation_in_closure() {
498 use alloc::rc::Rc;
499
500 let snapshot = TensorSnapshot::from_closure(
502 Rc::new(|| Err(TensorSnapshotError::IoError("Simulated IO error".into()))),
503 DType::F32,
504 vec![2, 2],
505 vec!["error_test".into()],
506 vec![],
507 ParamId::new(),
508 );
509
510 let result = snapshot.to_data();
512 assert!(result.is_err());
513 match result {
514 Err(TensorSnapshotError::IoError(msg)) => {
515 assert!(msg.contains("Simulated IO error"));
516 }
517 _ => panic!("Expected IoError"),
518 }
519 }
520
521 #[test]
522 fn container_type_extraction() {
523 let device = Default::default();
524 let tensor = Tensor::<TestBackend, 1>::from_data([1.0, 2.0, 3.0], &device);
525
526 let snapshot = TensorSnapshot::from_float(
527 &tensor,
528 vec![
529 "model".to_string(),
530 "layer1".to_string(),
531 "weight".to_string(),
532 ],
533 vec![
534 "Struct:Model".to_string(),
535 "Struct:Conv2d".to_string(),
536 "Struct:Param".to_string(),
537 ],
538 ParamId::new(),
539 );
540
541 assert_eq!(snapshot.container_type(), "Struct:Param");
542 assert_eq!(snapshot.module_type(), Some("Struct:Param".to_string()));
543 assert_eq!(
544 snapshot.container_path(),
545 "Struct:Model.Struct:Conv2d.Struct:Param"
546 );
547 assert_eq!(snapshot.full_path(), "model.layer1.weight");
548 }
549
550 #[test]
551 fn container_type_vs_module_type() {
552 let device = Default::default();
553 let tensor = Tensor::<TestBackend, 1>::from_data([1.0, 2.0, 3.0], &device);
554
555 let snapshot = TensorSnapshot::from_float(
558 &tensor,
559 vec![
560 "model".to_string(),
561 "layers".to_string(),
562 "0".to_string(),
563 "weight".to_string(),
564 ],
565 vec![
566 "Struct:Model".to_string(),
567 "Vec".to_string(),
568 "Struct:Linear".to_string(),
569 ],
570 ParamId::new(),
571 );
572
573 assert_eq!(snapshot.container_type(), "Struct:Linear");
575 assert_eq!(snapshot.module_type(), Some("Struct:Linear".to_string()));
577
578 let snapshot2 = TensorSnapshot::from_float(
581 &tensor,
582 vec!["data".to_string(), "0".to_string()],
583 vec!["Vec".to_string()],
584 ParamId::new(),
585 );
586
587 assert_eq!(snapshot2.container_type(), "Vec");
589 assert_eq!(snapshot2.module_type(), None);
591
592 let snapshot3 = TensorSnapshot::from_float(
595 &tensor,
596 vec![
597 "model".to_string(),
598 "layers".to_string(),
599 "0".to_string(),
600 "sublayers".to_string(),
601 "1".to_string(),
602 "weight".to_string(),
603 ],
604 vec![
605 "Struct:Model".to_string(),
606 "Vec".to_string(),
607 "Array".to_string(),
608 "Struct:Linear".to_string(),
609 ],
610 ParamId::new(),
611 );
612
613 assert_eq!(snapshot3.container_type(), "Struct:Linear");
615 assert_eq!(snapshot3.module_type(), Some("Struct:Linear".to_string()));
617 }
618}