burn_core/record/
tensor.rs

1use core::marker::PhantomData;
2
3use super::{PrecisionSettings, Record};
4use burn_tensor::{Bool, DType, Element, Int, Tensor, TensorData, backend::Backend};
5use serde::{Deserialize, Serialize};
6
7use alloc::format;
8
9/// Deserialize the value into [`TensorData`].
10fn deserialize_data<'de, E, De>(deserializer: De) -> Result<TensorData, De::Error>
11where
12    E: Element + Deserialize<'de>,
13    De: serde::Deserializer<'de>,
14{
15    let data = TensorData::deserialize(deserializer).map_err(|e| {
16        serde::de::Error::custom(format!(
17            "{:?}\nThe internal data format has changed since version 0.14.0. If you are trying to load a record saved in a previous version, use the `record-backward-compat` feature flag with a previous version (<=0.16.0). Once you have saved the record in the new format, you can upgrade back to the current version.\n",
18            e
19        ))
20    })?;
21    let data = if let DType::QFloat(_) = data.dtype {
22        data // do not convert quantized tensors
23    } else {
24        data.convert::<E>()
25    };
26    Ok(data)
27}
28
29/// This struct implements serde to lazily serialize and deserialize a float tensor
30/// using the given [record settings](RecordSettings).
31#[derive(new, Clone, Debug)]
32pub struct FloatTensorSerde<S: PrecisionSettings> {
33    data: TensorData,
34    _e: PhantomData<S::FloatElem>,
35}
36
37/// This struct implements serde to lazily serialize and deserialize an int tensor
38/// using the given [record settings](RecordSettings).
39#[derive(new, Clone, Debug)]
40pub struct IntTensorSerde<S: PrecisionSettings> {
41    data: TensorData,
42    _e: PhantomData<S::IntElem>,
43}
44
45/// This struct implements serde to lazily serialize and deserialize an bool tensor.
46#[derive(new, Clone, Debug)]
47pub struct BoolTensorSerde {
48    data: TensorData,
49}
50
51// --- SERDE IMPLEMENTATIONS --- //
52
53impl<S: PrecisionSettings> Serialize for FloatTensorSerde<S> {
54    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
55    where
56        Se: serde::Serializer,
57    {
58        self.data.serialize(serializer)
59    }
60}
61
62impl<'de, S: PrecisionSettings> Deserialize<'de> for FloatTensorSerde<S> {
63    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
64    where
65        De: serde::Deserializer<'de>,
66    {
67        let data = deserialize_data::<S::FloatElem, De>(deserializer)?;
68
69        Ok(Self::new(data))
70    }
71}
72
73impl<S: PrecisionSettings> Serialize for IntTensorSerde<S> {
74    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
75    where
76        Se: serde::Serializer,
77    {
78        self.data.serialize(serializer)
79    }
80}
81
82impl<'de, S: PrecisionSettings> Deserialize<'de> for IntTensorSerde<S> {
83    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
84    where
85        De: serde::Deserializer<'de>,
86    {
87        let data = deserialize_data::<S::IntElem, De>(deserializer)?;
88
89        Ok(Self::new(data))
90    }
91}
92
93impl Serialize for BoolTensorSerde {
94    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
95    where
96        Se: serde::Serializer,
97    {
98        self.data.serialize(serializer)
99    }
100}
101
102impl<'de> Deserialize<'de> for BoolTensorSerde {
103    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
104    where
105        De: serde::Deserializer<'de>,
106    {
107        let data = deserialize_data::<bool, De>(deserializer)?;
108
109        Ok(Self::new(data))
110    }
111}
112
113// --- RECORD IMPLEMENTATIONS --- //
114
115impl<B: Backend, const D: usize> Record<B> for Tensor<B, D> {
116    type Item<S: PrecisionSettings> = FloatTensorSerde<S>;
117
118    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
119        let data = self.into_data();
120        let data = if let DType::QFloat(_) = data.dtype {
121            data // do not convert quantized tensors
122        } else {
123            data.convert::<S::FloatElem>()
124        };
125        FloatTensorSerde::new(data)
126    }
127
128    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
129        let data = if let DType::QFloat(_) = item.data.dtype {
130            item.data // do not convert quantized tensors
131        } else {
132            item.data.convert::<B::FloatElem>()
133        };
134        Tensor::from_data(data, device)
135    }
136}
137
138impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Int> {
139    type Item<S: PrecisionSettings> = IntTensorSerde<S>;
140
141    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
142        IntTensorSerde::new(self.into_data().convert::<S::IntElem>())
143    }
144
145    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
146        Tensor::from_data(item.data.convert::<B::IntElem>(), device)
147    }
148}
149
150impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Bool> {
151    type Item<S: PrecisionSettings> = BoolTensorSerde;
152
153    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
154        BoolTensorSerde::new(self.into_data())
155    }
156
157    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
158        Tensor::from_data(item.data, device)
159    }
160}