burn_core/record/
tensor.rs1use core::marker::PhantomData;
2
3use super::{PrecisionSettings, Record};
4use burn_tensor::{Bool, DType, Element, Int, Tensor, TensorData, backend::Backend};
5use serde::{Deserialize, Serialize};
6
7use alloc::format;
8
9fn deserialize_data<'de, E, De>(deserializer: De) -> Result<TensorData, De::Error>
11where
12 E: Element + Deserialize<'de>,
13 De: serde::Deserializer<'de>,
14{
15 let data = TensorData::deserialize(deserializer).map_err(|e| {
16 serde::de::Error::custom(format!(
17 "{:?}\nThe internal data format has changed since version 0.14.0. If you are trying to load a record saved in a previous version, use the `record-backward-compat` feature flag with a previous version (<=0.16.0). Once you have saved the record in the new format, you can upgrade back to the current version.\n",
18 e
19 ))
20 })?;
21 let data = if let DType::QFloat(_) = data.dtype {
22 data } else {
24 data.convert::<E>()
25 };
26 Ok(data)
27}
28
29#[derive(new, Clone, Debug)]
32pub struct FloatTensorSerde<S: PrecisionSettings> {
33 data: TensorData,
34 _e: PhantomData<S::FloatElem>,
35}
36
37#[derive(new, Clone, Debug)]
40pub struct IntTensorSerde<S: PrecisionSettings> {
41 data: TensorData,
42 _e: PhantomData<S::IntElem>,
43}
44
45#[derive(new, Clone, Debug)]
47pub struct BoolTensorSerde {
48 data: TensorData,
49}
50
51impl<S: PrecisionSettings> Serialize for FloatTensorSerde<S> {
54 fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
55 where
56 Se: serde::Serializer,
57 {
58 self.data.serialize(serializer)
59 }
60}
61
62impl<'de, S: PrecisionSettings> Deserialize<'de> for FloatTensorSerde<S> {
63 fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
64 where
65 De: serde::Deserializer<'de>,
66 {
67 let data = deserialize_data::<S::FloatElem, De>(deserializer)?;
68
69 Ok(Self::new(data))
70 }
71}
72
73impl<S: PrecisionSettings> Serialize for IntTensorSerde<S> {
74 fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
75 where
76 Se: serde::Serializer,
77 {
78 self.data.serialize(serializer)
79 }
80}
81
82impl<'de, S: PrecisionSettings> Deserialize<'de> for IntTensorSerde<S> {
83 fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
84 where
85 De: serde::Deserializer<'de>,
86 {
87 let data = deserialize_data::<S::IntElem, De>(deserializer)?;
88
89 Ok(Self::new(data))
90 }
91}
92
93impl Serialize for BoolTensorSerde {
94 fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
95 where
96 Se: serde::Serializer,
97 {
98 self.data.serialize(serializer)
99 }
100}
101
102impl<'de> Deserialize<'de> for BoolTensorSerde {
103 fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
104 where
105 De: serde::Deserializer<'de>,
106 {
107 let data = deserialize_data::<bool, De>(deserializer)?;
108
109 Ok(Self::new(data))
110 }
111}
112
113impl<B: Backend, const D: usize> Record<B> for Tensor<B, D> {
116 type Item<S: PrecisionSettings> = FloatTensorSerde<S>;
117
118 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
119 let data = self.into_data();
120 let data = if let DType::QFloat(_) = data.dtype {
121 data } else {
123 data.convert::<S::FloatElem>()
124 };
125 FloatTensorSerde::new(data)
126 }
127
128 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
129 let data = if let DType::QFloat(_) = item.data.dtype {
130 item.data } else {
132 item.data.convert::<B::FloatElem>()
133 };
134 Tensor::from_data(data, device)
135 }
136}
137
138impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Int> {
139 type Item<S: PrecisionSettings> = IntTensorSerde<S>;
140
141 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
142 IntTensorSerde::new(self.into_data().convert::<S::IntElem>())
143 }
144
145 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
146 Tensor::from_data(item.data.convert::<B::IntElem>(), device)
147 }
148}
149
150impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Bool> {
151 type Item<S: PrecisionSettings> = BoolTensorSerde;
152
153 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
154 BoolTensorSerde::new(self.into_data())
155 }
156
157 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
158 Tensor::from_data(item.data, device)
159 }
160}