// Source path: burn_core/record/tensor.rs
tensor.rs1use core::marker::PhantomData;
2
3use super::{PrecisionSettings, Record};
4use burn_tensor::{Bool, DType, Element, Int, Tensor, TensorData, backend::Backend};
5use serde::{Deserialize, Serialize};
6
7use alloc::format;
8
9fn deserialize_data<'de, E, De>(deserializer: De) -> Result<TensorData, De::Error>
11where
12 E: Element + Deserialize<'de>,
13 De: serde::Deserializer<'de>,
14{
15 let data = TensorData::deserialize(deserializer).map_err(|e| {
16 serde::de::Error::custom(format!(
17 "{e:?}\nThe internal data format has changed since version 0.14.0. If you are trying to load a record saved in a previous version, use the `record-backward-compat` feature flag with a previous version (<=0.16.0). Once you have saved the record in the new format, you can upgrade back to the current version.\n"
18 ))
19 })?;
20 let data = if let DType::QFloat(_) = data.dtype {
21 data } else {
23 data.convert::<E>()
24 };
25 Ok(data)
26}
27
28#[derive(new, Clone, Debug)]
31pub struct FloatTensorSerde<S: PrecisionSettings> {
32 data: TensorData,
33 _e: PhantomData<S::FloatElem>,
34}
35
36#[derive(new, Clone, Debug)]
39pub struct IntTensorSerde<S: PrecisionSettings> {
40 data: TensorData,
41 _e: PhantomData<S::IntElem>,
42}
43
44#[derive(new, Clone, Debug)]
46pub struct BoolTensorSerde {
47 data: TensorData,
48}
49
50impl<S: PrecisionSettings> Serialize for FloatTensorSerde<S> {
53 fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
54 where
55 Se: serde::Serializer,
56 {
57 self.data.serialize(serializer)
58 }
59}
60
61impl<'de, S: PrecisionSettings> Deserialize<'de> for FloatTensorSerde<S> {
62 fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
63 where
64 De: serde::Deserializer<'de>,
65 {
66 let data = deserialize_data::<S::FloatElem, De>(deserializer)?;
67
68 Ok(Self::new(data))
69 }
70}
71
72impl<S: PrecisionSettings> Serialize for IntTensorSerde<S> {
73 fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
74 where
75 Se: serde::Serializer,
76 {
77 self.data.serialize(serializer)
78 }
79}
80
81impl<'de, S: PrecisionSettings> Deserialize<'de> for IntTensorSerde<S> {
82 fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
83 where
84 De: serde::Deserializer<'de>,
85 {
86 let data = deserialize_data::<S::IntElem, De>(deserializer)?;
87
88 Ok(Self::new(data))
89 }
90}
91
92impl Serialize for BoolTensorSerde {
93 fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
94 where
95 Se: serde::Serializer,
96 {
97 self.data.serialize(serializer)
98 }
99}
100
101impl<'de> Deserialize<'de> for BoolTensorSerde {
102 fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
103 where
104 De: serde::Deserializer<'de>,
105 {
106 let data = deserialize_data::<bool, De>(deserializer)?;
107
108 Ok(Self::new(data))
109 }
110}
111
112impl<B: Backend, const D: usize> Record<B> for Tensor<B, D> {
115 type Item<S: PrecisionSettings> = FloatTensorSerde<S>;
116
117 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
118 let data = self.into_data();
119 let data = if let DType::QFloat(_) = data.dtype {
120 data } else {
122 data.convert::<S::FloatElem>()
123 };
124 FloatTensorSerde::new(data)
125 }
126
127 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
128 let data = if let DType::QFloat(_) = item.data.dtype {
129 item.data } else {
131 item.data.convert::<B::FloatElem>()
132 };
133 Tensor::from_data(data, device)
134 }
135}
136
137impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Int> {
138 type Item<S: PrecisionSettings> = IntTensorSerde<S>;
139
140 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
141 IntTensorSerde::new(self.into_data().convert::<S::IntElem>())
142 }
143
144 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
145 Tensor::from_data(item.data.convert::<B::IntElem>(), device)
146 }
147}
148
149impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Bool> {
150 type Item<S: PrecisionSettings> = BoolTensorSerde;
151
152 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
153 BoolTensorSerde::new(self.into_data())
154 }
155
156 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
157 Tensor::from_data(item.data, device)
158 }
159}