burn_core/record/
tensor.rs1use core::marker::PhantomData;
2
3use super::{PrecisionSettings, Record};
4use burn_tensor::{backend::Backend, Bool, DType, Element, Int, Tensor, TensorData};
5use serde::{Deserialize, Serialize};
6
7#[cfg(not(feature = "record-backward-compat"))]
8use alloc::format;
9#[cfg(feature = "record-backward-compat")]
10use burn_tensor::DataSerialize;
11
/// On-disk representations accepted when the `record-backward-compat` feature is enabled.
///
/// The enum is `untagged`, so serde tries the variants in declaration order and keeps the
/// first one that deserializes successfully.
#[cfg(feature = "record-backward-compat")]
#[allow(deprecated)]
#[derive(Serialize, Deserialize)]
#[serde(untagged)]
enum TensorDataSerde<E> {
    /// Legacy layout, stored as a `DataSerialize<E>`.
    V1(DataSerialize<E>),
    /// Current layout, stored as a `TensorData`.
    V2(TensorData),
}
21
22fn deserialize_data<'de, E, De>(deserializer: De) -> Result<TensorData, De::Error>
24where
25 E: Element + Deserialize<'de>,
26 De: serde::Deserializer<'de>,
27{
28 #[cfg(feature = "record-backward-compat")]
29 {
30 let data = match TensorDataSerde::<E>::deserialize(deserializer)? {
31 TensorDataSerde::V1(data) => data.into_tensor_data(),
32 TensorDataSerde::V2(data) => data.convert::<E>(),
34 };
35 Ok(data)
36 }
37
38 #[cfg(not(feature = "record-backward-compat"))]
39 {
40 let data = TensorData::deserialize(deserializer).map_err(|e| {
41 serde::de::Error::custom(format!(
42 "{:?}\nThe internal data format has changed since version 0.14.0. If you are trying to load a record saved in a previous version, use the `record-backward-compat` feature flag. Once you have saved the record in the new format, you can disable the feature flag.\n",
43 e
44 ))
45 })?;
46 let data = if let DType::QFloat(_) = data.dtype {
47 data } else {
49 data.convert::<E>()
50 };
51 Ok(data)
52 }
53}
54
55#[derive(new, Clone, Debug)]
58pub struct FloatTensorSerde<S: PrecisionSettings> {
59 data: TensorData,
60 _e: PhantomData<S::FloatElem>,
61}
62
63#[derive(new, Clone, Debug)]
66pub struct IntTensorSerde<S: PrecisionSettings> {
67 data: TensorData,
68 _e: PhantomData<S::IntElem>,
69}
70
71#[derive(new, Clone, Debug)]
73pub struct BoolTensorSerde {
74 data: TensorData,
75}
76
77impl<S: PrecisionSettings> Serialize for FloatTensorSerde<S> {
80 fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
81 where
82 Se: serde::Serializer,
83 {
84 self.data.serialize(serializer)
85 }
86}
87
88impl<'de, S: PrecisionSettings> Deserialize<'de> for FloatTensorSerde<S> {
89 fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
90 where
91 De: serde::Deserializer<'de>,
92 {
93 let data = deserialize_data::<S::FloatElem, De>(deserializer)?;
94
95 Ok(Self::new(data))
96 }
97}
98
99impl<S: PrecisionSettings> Serialize for IntTensorSerde<S> {
100 fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
101 where
102 Se: serde::Serializer,
103 {
104 self.data.serialize(serializer)
105 }
106}
107
108impl<'de, S: PrecisionSettings> Deserialize<'de> for IntTensorSerde<S> {
109 fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
110 where
111 De: serde::Deserializer<'de>,
112 {
113 let data = deserialize_data::<S::IntElem, De>(deserializer)?;
114
115 Ok(Self::new(data))
116 }
117}
118
119impl Serialize for BoolTensorSerde {
120 fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
121 where
122 Se: serde::Serializer,
123 {
124 self.data.serialize(serializer)
125 }
126}
127
128impl<'de> Deserialize<'de> for BoolTensorSerde {
129 fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
130 where
131 De: serde::Deserializer<'de>,
132 {
133 let data = deserialize_data::<bool, De>(deserializer)?;
134
135 Ok(Self::new(data))
136 }
137}
138
139impl<B: Backend, const D: usize> Record<B> for Tensor<B, D> {
142 type Item<S: PrecisionSettings> = FloatTensorSerde<S>;
143
144 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
145 let data = self.into_data();
146 let data = if let DType::QFloat(_) = data.dtype {
147 data } else {
149 data.convert::<S::FloatElem>()
150 };
151 FloatTensorSerde::new(data)
152 }
153
154 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
155 let data = if let DType::QFloat(_) = item.data.dtype {
156 item.data } else {
158 item.data.convert::<B::FloatElem>()
159 };
160 Tensor::from_data(data, device)
161 }
162}
163
164impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Int> {
165 type Item<S: PrecisionSettings> = IntTensorSerde<S>;
166
167 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
168 IntTensorSerde::new(self.into_data().convert::<S::IntElem>())
169 }
170
171 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
172 Tensor::from_data(item.data.convert::<B::IntElem>(), device)
173 }
174}
175
176impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Bool> {
177 type Item<S: PrecisionSettings> = BoolTensorSerde;
178
179 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
180 BoolTensorSerde::new(self.into_data())
181 }
182
183 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
184 Tensor::from_data(item.data, device)
185 }
186}