1use alloc::rc::Rc;
2use alloc::string::String;
3use alloc::string::ToString;
4use alloc::vec::Vec;
5use burn_core::module::ParamId;
6use burn_tensor::{Bool, Int, Tensor, TensorData, backend::Backend};
7
/// Errors reported while materializing the data of a tensor snapshot.
///
/// Every variant carries a human-readable message. The `Display` output is
/// `"<category>: <message>"`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TensorSnapshotError {
    /// An I/O failure; displayed as `I/O error: <message>`.
    IoError(String),
    /// A data-related failure; displayed as `Data error: <message>`.
    DataError(String),
    /// The data-producing closure panicked; displayed as `Panic error: <message>`.
    /// Panics are only converted into this variant when the `std` feature is enabled.
    PanicError(String),
}
18
19impl core::fmt::Display for TensorSnapshotError {
20 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
21 match self {
22 Self::IoError(e) => write!(f, "I/O error: {}", e),
23 Self::DataError(e) => write!(f, "Data error: {}", e),
24 Self::PanicError(e) => write!(f, "Panic error: {}", e),
25 }
26 }
27}
28
29impl core::error::Error for TensorSnapshotError {}
30
31pub struct TensorSnapshot {
40 data_fn: Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>>,
42 pub dtype: burn_tensor::DType,
44 pub shape: Vec<usize>,
46 pub path_stack: Option<Vec<String>>,
48 pub container_stack: Option<Vec<String>>,
50 pub tensor_id: Option<ParamId>,
52}
53
54impl TensorSnapshot {
55 pub fn from_float<B: Backend, const D: usize>(
57 tensor: &Tensor<B, D>,
58 path_stack: Vec<String>,
59 container_stack: Vec<String>,
60 tensor_id: ParamId,
61 ) -> Self {
62 let dtype = tensor.dtype();
63 let shape = tensor.shape().to_vec();
64 let tensor = tensor.clone(); Self {
66 data_fn: Rc::new(move || Ok(tensor.to_data())),
67 dtype,
68 shape,
69 path_stack: Some(path_stack),
70 container_stack: Some(container_stack),
71 tensor_id: Some(tensor_id),
72 }
73 }
74
75 pub fn from_int<B: Backend, const D: usize>(
77 tensor: &Tensor<B, D, Int>,
78 path_stack: Vec<String>,
79 container_stack: Vec<String>,
80 tensor_id: ParamId,
81 ) -> Self {
82 let dtype = tensor.dtype();
83 let shape = tensor.shape().to_vec();
84 let tensor = tensor.clone(); Self {
86 data_fn: Rc::new(move || Ok(tensor.to_data())),
87 dtype,
88 shape,
89 path_stack: Some(path_stack),
90 container_stack: Some(container_stack),
91 tensor_id: Some(tensor_id),
92 }
93 }
94
95 pub fn from_bool<B: Backend, const D: usize>(
97 tensor: &Tensor<B, D, Bool>,
98 path_stack: Vec<String>,
99 container_stack: Vec<String>,
100 tensor_id: ParamId,
101 ) -> Self {
102 let dtype = tensor.dtype();
103 let shape = tensor.shape().to_vec();
104 let tensor = tensor.clone(); Self {
106 data_fn: Rc::new(move || Ok(tensor.to_data())),
107 dtype,
108 shape,
109 path_stack: Some(path_stack),
110 container_stack: Some(container_stack),
111 tensor_id: Some(tensor_id),
112 }
113 }
114
115 #[cfg(feature = "std")]
117 pub fn to_data(&self) -> Result<TensorData, TensorSnapshotError> {
118 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| (self.data_fn)())).unwrap_or_else(
120 |_| {
121 Err(TensorSnapshotError::PanicError(
122 "Panic occurred while loading tensor data".to_string(),
123 ))
124 },
125 )
126 }
127
128 #[cfg(not(feature = "std"))]
130 pub fn to_data(&self) -> Result<TensorData, TensorSnapshotError> {
131 (self.data_fn)() }
133
134 pub fn full_path(&self) -> String {
136 self.path_stack
137 .as_ref()
138 .map(|stack| stack.join("."))
139 .unwrap_or_default()
140 }
141
142 pub fn container_path(&self) -> String {
144 self.container_stack
145 .as_ref()
146 .map(|stack| stack.join("."))
147 .unwrap_or_default()
148 }
149
150 pub fn container_type(&self) -> String {
152 self.container_stack
153 .as_ref()
154 .and_then(|stack| stack.last())
155 .cloned()
156 .unwrap_or_else(|| "Unknown".to_string())
157 }
158
159 pub fn from_closure(
162 data_fn: Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>>,
163 dtype: burn_tensor::DType,
164 shape: Vec<usize>,
165 path_stack: Vec<String>,
166 container_stack: Vec<String>,
167 tensor_id: ParamId,
168 ) -> Self {
169 Self {
170 data_fn,
171 dtype,
172 shape,
173 path_stack: Some(path_stack),
174 container_stack: Some(container_stack),
175 tensor_id: Some(tensor_id),
176 }
177 }
178
179 pub fn from_data(
181 data: TensorData,
182 path_stack: Vec<String>,
183 container_stack: Vec<String>,
184 tensor_id: ParamId,
185 ) -> Self {
186 let dtype = data.dtype;
187 let shape = data.shape.clone();
188 Self {
189 data_fn: Rc::new(move || Ok(data.clone())),
190 dtype,
191 shape,
192 path_stack: Some(path_stack),
193 container_stack: Some(container_stack),
194 tensor_id: Some(tensor_id),
195 }
196 }
197
198 pub fn data_len(&self) -> usize {
200 self.shape.iter().product::<usize>() * self.dtype.size()
201 }
202
203 pub fn clone_data_fn(&self) -> Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>> {
205 self.data_fn.clone()
206 }
207}
208
209impl Clone for TensorSnapshot {
210 fn clone(&self) -> Self {
211 Self {
213 data_fn: self.data_fn.clone(),
214 dtype: self.dtype,
215 shape: self.shape.clone(),
216 path_stack: self.path_stack.clone(),
217 container_stack: self.container_stack.clone(),
218 tensor_id: self.tensor_id,
219 }
220 }
221}
222
223impl core::fmt::Debug for TensorSnapshot {
224 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
225 f.debug_struct("TensorSnapshot")
226 .field("dtype", &self.dtype)
227 .field("shape", &self.shape)
228 .field("path_stack", &self.path_stack)
229 .field("container_stack", &self.container_stack)
230 .field("tensor_id", &self.tensor_id)
231 .finish()
232 }
233}
234
#[cfg(all(test, feature = "std"))]
mod tests {
    // `use super::*` also brings in the parent's `Rc`, `ToString`, `Tensor`, etc.
    use super::*;
    use burn_tensor::DType;

    type TestBackend = burn_ndarray::NdArray;

    #[test]
    fn tensor_view_float() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);

        let snapshot = TensorSnapshot::from_float(
            &tensor,
            vec!["test".to_string(), "weight".to_string()],
            vec!["TestModule".to_string(), "Param".to_string()],
            ParamId::new(),
        );

        assert_eq!(snapshot.dtype, DType::F32);
        assert_eq!(snapshot.shape, vec![2, 2]);
        assert_eq!(snapshot.full_path(), "test.weight");
        assert_eq!(snapshot.container_path(), "TestModule.Param");

        let data = snapshot.to_data().unwrap();
        assert_eq!(data.shape, vec![2, 2]);
        assert_eq!(data.dtype, DType::F32);
    }

    #[test]
    fn tensor_view_int() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 2, Int>::from_data([[1, 2], [3, 4]], &device);

        let snapshot = TensorSnapshot::from_int(
            &tensor,
            vec!["test".to_string(), "int".to_string()],
            vec!["TestModule".to_string(), "Param".to_string()],
            ParamId::new(),
        );

        assert_eq!(snapshot.dtype, DType::I64);
        assert_eq!(snapshot.shape, vec![2, 2]);

        let data = snapshot.to_data().unwrap();
        assert_eq!(data.shape, vec![2, 2]);
        assert_eq!(data.dtype, DType::I64);
    }

    #[test]
    fn tensor_view_bool() {
        let device = Default::default();
        let tensor =
            Tensor::<TestBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);

        let snapshot = TensorSnapshot::from_bool(
            &tensor,
            vec!["test".to_string(), "bool".to_string()],
            vec!["TestModule".to_string(), "Param".to_string()],
            ParamId::new(),
        );

        assert_eq!(snapshot.dtype, DType::Bool);
        assert_eq!(snapshot.shape, vec![2, 2]);

        let data = snapshot.to_data().unwrap();
        assert_eq!(data.shape, vec![2, 2]);
        assert_eq!(data.dtype, DType::Bool);
    }

    #[test]
    fn data_len() {
        let device = Default::default();

        // 4 f32 elements * 4 bytes.
        let tensor_f32 = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
        let view_f32 = TensorSnapshot::from_float(
            &tensor_f32,
            vec!["test".to_string()],
            vec!["Module".to_string()],
            ParamId::new(),
        );
        assert_eq!(view_f32.data_len(), 16);

        // 8 i64 elements * 8 bytes.
        let tensor_i64 =
            Tensor::<TestBackend, 3, Int>::from_data([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], &device);
        let view_i64 = TensorSnapshot::from_int(
            &tensor_i64,
            vec!["test".to_string()],
            vec!["Module".to_string()],
            ParamId::new(),
        );
        assert_eq!(view_i64.data_len(), 64);

        // 4 bool elements * 1 byte.
        let tensor_bool =
            Tensor::<TestBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);
        let view_bool = TensorSnapshot::from_bool(
            &tensor_bool,
            vec!["test".to_string()],
            vec!["Module".to_string()],
            ParamId::new(),
        );
        assert_eq!(view_bool.data_len(), 4);
    }

    #[test]
    fn from_closure() {
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);
        let dtype = data.dtype;
        let shape = data.shape.clone();

        let snapshot = TensorSnapshot::from_closure(
            Rc::new(move || Ok(data.clone())),
            dtype,
            shape.clone(),
            vec!["model".to_string(), "layer".to_string()],
            vec!["Model".to_string(), "Layer".to_string()],
            ParamId::new(),
        );

        assert_eq!(snapshot.dtype, DType::F32);
        assert_eq!(snapshot.shape, vec![4]);
        assert_eq!(snapshot.full_path(), "model.layer");
        assert_eq!(snapshot.data_len(), 16);

        let materialized = snapshot.to_data().unwrap();
        assert_eq!(materialized.shape, vec![4]);
    }

    #[test]
    fn from_data() {
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let original_dtype = data.dtype;
        let original_shape = data.shape.clone();

        let snapshot = TensorSnapshot::from_data(
            data,
            vec!["encoder".to_string(), "weight".to_string()],
            vec!["Encoder".to_string(), "Dense".to_string()],
            ParamId::new(),
        );

        assert_eq!(snapshot.dtype, original_dtype);
        assert_eq!(snapshot.shape, original_shape);
        assert_eq!(snapshot.full_path(), "encoder.weight");
        assert_eq!(snapshot.container_type(), "Dense");
        assert_eq!(snapshot.data_len(), 24);

        let materialized = snapshot.to_data().unwrap();
        assert_eq!(materialized.shape, original_shape);
    }

    #[test]
    fn panic_catching_in_to_data() {
        let snapshot = TensorSnapshot {
            data_fn: Rc::new(|| panic!("Test panic in data_fn")),
            dtype: DType::F32,
            shape: vec![2, 2],
            path_stack: Some(vec!["test".to_string()]),
            container_stack: Some(vec!["Test".to_string()]),
            tensor_id: Some(ParamId::new()),
        };

        let result = snapshot.to_data();
        assert!(result.is_err());

        match result {
            Err(TensorSnapshotError::PanicError(msg)) => {
                assert!(msg.contains("Panic occurred"));
            }
            _ => panic!("Expected PanicError with panic message"),
        }
    }

    #[test]
    fn error_propagation_in_closure() {
        let snapshot = TensorSnapshot::from_closure(
            Rc::new(|| Err(TensorSnapshotError::IoError("Simulated IO error".into()))),
            DType::F32,
            vec![2, 2],
            vec!["error_test".into()],
            vec![],
            ParamId::new(),
        );

        let result = snapshot.to_data();
        assert!(result.is_err());
        match result {
            Err(TensorSnapshotError::IoError(msg)) => {
                assert!(msg.contains("Simulated IO error"));
            }
            _ => panic!("Expected IoError"),
        }
    }

    #[test]
    fn container_type_extraction() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 1>::from_data([1.0, 2.0, 3.0], &device);

        let snapshot = TensorSnapshot::from_float(
            &tensor,
            vec![
                "model".to_string(),
                "layer1".to_string(),
                "weight".to_string(),
            ],
            vec![
                "Model".to_string(),
                "Conv2d".to_string(),
                "Param".to_string(),
            ],
            ParamId::new(),
        );

        assert_eq!(snapshot.container_type(), "Param");
        assert_eq!(snapshot.container_path(), "Model.Conv2d.Param");
        assert_eq!(snapshot.full_path(), "model.layer1.weight");
    }

    #[test]
    fn clone_shares_data_closure() {
        let snapshot = TensorSnapshot::from_data(
            TensorData::from([1.0f32, 2.0]),
            vec!["a".to_string()],
            vec!["A".to_string()],
            ParamId::new(),
        );
        let cloned = snapshot.clone();

        assert!(Rc::ptr_eq(&snapshot.clone_data_fn(), &cloned.clone_data_fn()));
        assert_eq!(cloned.full_path(), snapshot.full_path());
        assert_eq!(cloned.shape, snapshot.shape);
        assert_eq!(cloned.to_data().unwrap().shape, vec![2]);
    }

    #[test]
    fn missing_stacks_use_defaults() {
        let snapshot = TensorSnapshot {
            data_fn: Rc::new(|| Ok(TensorData::from([0.0f32]))),
            dtype: DType::F32,
            shape: vec![1],
            path_stack: None,
            container_stack: None,
            tensor_id: None,
        };

        assert_eq!(snapshot.full_path(), "");
        assert_eq!(snapshot.container_path(), "");
        assert_eq!(snapshot.container_type(), "Unknown");
    }
}