burn_core/record/
tensor.rs

1use core::marker::PhantomData;
2
3use super::{PrecisionSettings, Record};
4use burn_tensor::{Bool, DType, Element, Int, Tensor, TensorData, backend::Backend};
5use serde::{Deserialize, Serialize};
6
7use alloc::format;
8
9/// Deserialize the value into [`TensorData`].
10fn deserialize_data<'de, E, De>(deserializer: De) -> Result<TensorData, De::Error>
11where
12    E: Element + Deserialize<'de>,
13    De: serde::Deserializer<'de>,
14{
15    let data = TensorData::deserialize(deserializer).map_err(|e| {
16        serde::de::Error::custom(format!(
17            "{e:?}\nThe internal data format has changed since version 0.14.0. If you are trying to load a record saved in a previous version, use the `record-backward-compat` feature flag with a previous version (<=0.16.0). Once you have saved the record in the new format, you can upgrade back to the current version.\n"
18        ))
19    })?;
20    let data = if let DType::QFloat(_) = data.dtype {
21        data // do not convert quantized tensors
22    } else {
23        data.convert::<E>()
24    };
25    Ok(data)
26}
27
28/// This struct implements serde to lazily serialize and deserialize a float tensor
29/// using the given [record settings](RecordSettings).
30#[derive(new, Clone, Debug)]
31pub struct FloatTensorSerde<S: PrecisionSettings> {
32    data: TensorData,
33    _e: PhantomData<S::FloatElem>,
34}
35
36/// This struct implements serde to lazily serialize and deserialize an int tensor
37/// using the given [record settings](RecordSettings).
38#[derive(new, Clone, Debug)]
39pub struct IntTensorSerde<S: PrecisionSettings> {
40    data: TensorData,
41    _e: PhantomData<S::IntElem>,
42}
43
44/// This struct implements serde to lazily serialize and deserialize an bool tensor.
45#[derive(new, Clone, Debug)]
46pub struct BoolTensorSerde {
47    data: TensorData,
48}
49
50// --- SERDE IMPLEMENTATIONS --- //
51
52impl<S: PrecisionSettings> Serialize for FloatTensorSerde<S> {
53    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
54    where
55        Se: serde::Serializer,
56    {
57        self.data.serialize(serializer)
58    }
59}
60
61impl<'de, S: PrecisionSettings> Deserialize<'de> for FloatTensorSerde<S> {
62    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
63    where
64        De: serde::Deserializer<'de>,
65    {
66        let data = deserialize_data::<S::FloatElem, De>(deserializer)?;
67
68        Ok(Self::new(data))
69    }
70}
71
72impl<S: PrecisionSettings> Serialize for IntTensorSerde<S> {
73    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
74    where
75        Se: serde::Serializer,
76    {
77        self.data.serialize(serializer)
78    }
79}
80
81impl<'de, S: PrecisionSettings> Deserialize<'de> for IntTensorSerde<S> {
82    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
83    where
84        De: serde::Deserializer<'de>,
85    {
86        let data = deserialize_data::<S::IntElem, De>(deserializer)?;
87
88        Ok(Self::new(data))
89    }
90}
91
92impl Serialize for BoolTensorSerde {
93    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
94    where
95        Se: serde::Serializer,
96    {
97        self.data.serialize(serializer)
98    }
99}
100
101impl<'de> Deserialize<'de> for BoolTensorSerde {
102    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
103    where
104        De: serde::Deserializer<'de>,
105    {
106        let data = deserialize_data::<bool, De>(deserializer)?;
107
108        Ok(Self::new(data))
109    }
110}
111
112// --- RECORD IMPLEMENTATIONS --- //
113
114impl<B: Backend, const D: usize> Record<B> for Tensor<B, D> {
115    type Item<S: PrecisionSettings> = FloatTensorSerde<S>;
116
117    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
118        let data = self.into_data();
119        let data = if let DType::QFloat(_) = data.dtype {
120            data // do not convert quantized tensors
121        } else {
122            data.convert::<S::FloatElem>()
123        };
124        FloatTensorSerde::new(data)
125    }
126
127    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
128        let data = if let DType::QFloat(_) = item.data.dtype {
129            item.data // do not convert quantized tensors
130        } else {
131            item.data.convert::<B::FloatElem>()
132        };
133        Tensor::from_data(data, device)
134    }
135}
136
137impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Int> {
138    type Item<S: PrecisionSettings> = IntTensorSerde<S>;
139
140    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
141        IntTensorSerde::new(self.into_data().convert::<S::IntElem>())
142    }
143
144    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
145        Tensor::from_data(item.data.convert::<B::IntElem>(), device)
146    }
147}
148
149impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Bool> {
150    type Item<S: PrecisionSettings> = BoolTensorSerde;
151
152    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
153        BoolTensorSerde::new(self.into_data())
154    }
155
156    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
157        Tensor::from_data(item.data, device)
158    }
159}