burn_core/record/
tensor.rs

1use core::marker::PhantomData;
2
3use super::{PrecisionSettings, Record};
4use burn_tensor::{backend::Backend, Bool, DType, Element, Int, Tensor, TensorData};
5use serde::{Deserialize, Serialize};
6
7#[cfg(not(feature = "record-backward-compat"))]
8use alloc::format;
9#[cfg(feature = "record-backward-compat")]
10use burn_tensor::DataSerialize;
11
/// Versioned serde data deserialization to maintain backward compatibility between formats.
///
/// Only compiled when the `record-backward-compat` feature is enabled. The enum is
/// `#[serde(untagged)]`, so no variant tag is stored in the serialized data; the variant
/// is selected from the shape of the data itself.
#[cfg(feature = "record-backward-compat")]
#[allow(deprecated)]
#[derive(Serialize, Deserialize)]
#[serde(untagged)]
enum TensorDataSerde<E> {
    // Earlier layout, carried by `DataSerialize<E>`.
    // NOTE(review): the `allow(deprecated)` above is presumably needed for this type — confirm.
    V1(DataSerialize<E>),
    // Current layout, carried directly by `TensorData`.
    V2(TensorData),
}
21
22/// Deserialize the value into [`TensorData`].
23fn deserialize_data<'de, E, De>(deserializer: De) -> Result<TensorData, De::Error>
24where
25    E: Element + Deserialize<'de>,
26    De: serde::Deserializer<'de>,
27{
28    #[cfg(feature = "record-backward-compat")]
29    {
30        let data = match TensorDataSerde::<E>::deserialize(deserializer)? {
31            TensorDataSerde::V1(data) => data.into_tensor_data(),
32            // NOTE: loading f32 weights with f16 precision will deserialize the f32 weights (bytes) first and then convert to f16
33            TensorDataSerde::V2(data) => data.convert::<E>(),
34        };
35        Ok(data)
36    }
37
38    #[cfg(not(feature = "record-backward-compat"))]
39    {
40        let data = TensorData::deserialize(deserializer).map_err(|e| {
41            serde::de::Error::custom(format!(
42                "{:?}\nThe internal data format has changed since version 0.14.0. If you are trying to load a record saved in a previous version, use the `record-backward-compat` feature flag. Once you have saved the record in the new format, you can disable the feature flag.\n",
43                e
44            ))
45        })?;
46        let data = if let DType::QFloat(_) = data.dtype {
47            data // do not convert quantized tensors
48        } else {
49            data.convert::<E>()
50        };
51        Ok(data)
52    }
53}
54
/// This struct implements serde to lazily serialize and deserialize a float tensor
/// using the given [precision settings](PrecisionSettings).
///
/// Deserialization targets the `S::FloatElem` element type.
#[derive(new, Clone, Debug)]
pub struct FloatTensorSerde<S: PrecisionSettings> {
    // Tensor payload that is handed to the serializer / produced by the deserializer.
    data: TensorData,
    // Zero-sized marker recording the float element type selected by `S`.
    _e: PhantomData<S::FloatElem>,
}
62
/// This struct implements serde to lazily serialize and deserialize an int tensor
/// using the given [precision settings](PrecisionSettings).
///
/// Deserialization targets the `S::IntElem` element type.
#[derive(new, Clone, Debug)]
pub struct IntTensorSerde<S: PrecisionSettings> {
    // Tensor payload that is handed to the serializer / produced by the deserializer.
    data: TensorData,
    // Zero-sized marker recording the int element type selected by `S`.
    _e: PhantomData<S::IntElem>,
}
70
/// This struct implements serde to lazily serialize and deserialize a bool tensor.
///
/// Unlike the float and int variants it takes no precision settings parameter.
#[derive(new, Clone, Debug)]
pub struct BoolTensorSerde {
    // Tensor payload that is handed to the serializer / produced by the deserializer.
    data: TensorData,
}
76
77// --- SERDE IMPLEMENTATIONS --- //
78
79impl<S: PrecisionSettings> Serialize for FloatTensorSerde<S> {
80    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
81    where
82        Se: serde::Serializer,
83    {
84        self.data.serialize(serializer)
85    }
86}
87
88impl<'de, S: PrecisionSettings> Deserialize<'de> for FloatTensorSerde<S> {
89    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
90    where
91        De: serde::Deserializer<'de>,
92    {
93        let data = deserialize_data::<S::FloatElem, De>(deserializer)?;
94
95        Ok(Self::new(data))
96    }
97}
98
99impl<S: PrecisionSettings> Serialize for IntTensorSerde<S> {
100    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
101    where
102        Se: serde::Serializer,
103    {
104        self.data.serialize(serializer)
105    }
106}
107
108impl<'de, S: PrecisionSettings> Deserialize<'de> for IntTensorSerde<S> {
109    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
110    where
111        De: serde::Deserializer<'de>,
112    {
113        let data = deserialize_data::<S::IntElem, De>(deserializer)?;
114
115        Ok(Self::new(data))
116    }
117}
118
119impl Serialize for BoolTensorSerde {
120    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
121    where
122        Se: serde::Serializer,
123    {
124        self.data.serialize(serializer)
125    }
126}
127
128impl<'de> Deserialize<'de> for BoolTensorSerde {
129    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
130    where
131        De: serde::Deserializer<'de>,
132    {
133        let data = deserialize_data::<bool, De>(deserializer)?;
134
135        Ok(Self::new(data))
136    }
137}
138
139// --- RECORD IMPLEMENTATIONS --- //
140
141impl<B: Backend, const D: usize> Record<B> for Tensor<B, D> {
142    type Item<S: PrecisionSettings> = FloatTensorSerde<S>;
143
144    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
145        let data = self.into_data();
146        let data = if let DType::QFloat(_) = data.dtype {
147            data // do not convert quantized tensors
148        } else {
149            data.convert::<S::FloatElem>()
150        };
151        FloatTensorSerde::new(data)
152    }
153
154    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
155        let data = if let DType::QFloat(_) = item.data.dtype {
156            item.data // do not convert quantized tensors
157        } else {
158            item.data.convert::<B::FloatElem>()
159        };
160        Tensor::from_data(data, device)
161    }
162}
163
164impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Int> {
165    type Item<S: PrecisionSettings> = IntTensorSerde<S>;
166
167    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
168        IntTensorSerde::new(self.into_data().convert::<S::IntElem>())
169    }
170
171    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
172        Tensor::from_data(item.data.convert::<B::IntElem>(), device)
173    }
174}
175
176impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Bool> {
177    type Item<S: PrecisionSettings> = BoolTensorSerde;
178
179    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
180        BoolTensorSerde::new(self.into_data())
181    }
182
183    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
184        Tensor::from_data(item.data, device)
185    }
186}