burn_core/record/
primitive.rs

1use alloc::{string::String, vec, vec::Vec};
2use core::{fmt, marker::PhantomData};
3
4use super::tensor::{BoolTensorSerde, FloatTensorSerde, IntTensorSerde};
5use super::{PrecisionSettings, Record};
6use crate::module::{Param, ParamId};
7
8#[allow(deprecated)]
9use burn_tensor::DataSerialize;
10use burn_tensor::{backend::Backend, Bool, Element, Int, Tensor};
11
12use hashbrown::HashMap;
13use serde::{
14    de::{Error, SeqAccess, Visitor},
15    ser::SerializeTuple,
16    Deserialize, Serialize,
17};
18
19impl<B> Record<B> for ()
20where
21    B: Backend,
22{
23    type Item<S: PrecisionSettings> = ();
24
25    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {}
26
27    fn from_item<S: PrecisionSettings>(_item: Self::Item<S>, _device: &B::Device) -> Self {}
28}
29
30impl<T, B> Record<B> for Vec<T>
31where
32    T: Record<B>,
33    B: Backend,
34{
35    type Item<S: PrecisionSettings> = Vec<T::Item<S>>;
36
37    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
38        self.into_iter().map(Record::into_item).collect()
39    }
40
41    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
42        item.into_iter()
43            .map(|i| Record::from_item(i, device))
44            .collect()
45    }
46}
47
48impl<T, B> Record<B> for Option<T>
49where
50    T: Record<B>,
51    B: Backend,
52{
53    type Item<S: PrecisionSettings> = Option<T::Item<S>>;
54
55    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
56        self.map(Record::into_item)
57    }
58
59    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
60        item.map(|i| Record::from_item(i, device))
61    }
62}
63
64impl<const N: usize, T, B> Record<B> for [T; N]
65where
66    T: Record<B>,
67    B: Backend,
68{
69    /// The record item is an array of the record item of the elements.
70    /// The reason why we wrap the array in a struct is because serde does not support
71    /// deserializing arrays of variable size,
72    /// see [serde/issues/1937](https://github.com/serde-rs/serde/issues/1937).
73    /// for backward compatibility reasons. Serde APIs were created before const generics.
74    type Item<S: PrecisionSettings> = Array<N, T::Item<S>>;
75
76    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
77        Array(self.map(Record::into_item))
78    }
79
80    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
81        item.0.map(|i| Record::from_item(i, device))
82    }
83}
84
85/// A macro for generating implementations for tuple records of different sizes.
86/// For example: `impl_record_tuple!([R0, R1][0, 1])`.
87/// Would generate an implementation for a tuple of size 2.
88/// For this macro to work properly, please adhear to the convention:
89/// `impl_record_tuple!([R0, R1, ..., Rn][0, 1, ..., n])`.
90macro_rules! impl_record_tuple {
91    // `$r` represents the generic records.
92    // `$i` represents the indices of the records in the tuple.
93    ([$($r:ident),*][$($i:tt),*]) => {
94        impl<B, $($r,)*> Record<B> for ($($r,)*)
95        where
96            B: Backend,
97            $($r: Record<B>),*
98        {
99            type Item<S: PrecisionSettings> = ($($r::Item<S>,)*);
100
101            fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
102                ($(self.$i.into_item(),)*)
103            }
104
105            fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
106                ($(Record::from_item(item.$i, device),)*)
107            }
108        }
109    };
110}
111
112impl_record_tuple!([R0, R1][0, 1]);
113impl_record_tuple!([R0, R1, R2][0, 1, 2]);
114impl_record_tuple!([R0, R1, R2, R3][0, 1, 2, 3]);
115impl_record_tuple!([R0, R1, R2, R3, R4][0, 1, 2, 3, 4]);
116impl_record_tuple!([R0, R1, R2, R3, R4, R5][0, 1, 2, 3, 4, 5]);
117impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6][0, 1, 2, 3, 4, 5, 6]);
118impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7][0, 1, 2, 3, 4, 5, 6, 7]);
119impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8][0, 1, 2, 3, 4, 5, 6, 7, 8]);
120impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8, R9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
121
122impl<T, B> Record<B> for HashMap<ParamId, T>
123where
124    T: Record<B>,
125    B: Backend,
126{
127    type Item<S: PrecisionSettings> = HashMap<String, T::Item<S>>;
128
129    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
130        let mut items = HashMap::with_capacity(self.len());
131        self.into_iter().for_each(|(id, record)| {
132            items.insert(id.serialize(), record.into_item());
133        });
134        items
135    }
136
137    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
138        let mut record = HashMap::with_capacity(item.len());
139        item.into_iter().for_each(|(id, item)| {
140            record.insert(ParamId::deserialize(&id), T::from_item(item, device));
141        });
142        record
143    }
144}
145
146#[allow(deprecated)]
147impl<E, B> Record<B> for DataSerialize<E>
148where
149    E: Element,
150    B: Backend,
151{
152    type Item<S: PrecisionSettings> = DataSerialize<S::FloatElem>;
153
154    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
155        self.convert()
156    }
157
158    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
159        item.convert()
160    }
161}
162
163/// (De)serialize parameters into a clean format.
164#[derive(new, Debug, Clone, Serialize, Deserialize)]
165pub struct ParamSerde<T> {
166    id: String,
167    param: T,
168}
169
170impl<B, const D: usize> Record<B> for Param<Tensor<B, D>>
171where
172    B: Backend,
173{
174    type Item<S: PrecisionSettings> = ParamSerde<FloatTensorSerde<S>>;
175
176    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
177        let (id, tensor) = self.consume();
178        ParamSerde::new(id.serialize(), tensor.into_item())
179    }
180
181    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
182        Param::initialized(
183            ParamId::deserialize(&item.id),
184            Tensor::from_item(item.param, device).require_grad(), // Same behavior as when we create a new
185                                                                  // Param from a tensor.
186        )
187    }
188}
189
190impl<B, const D: usize> Record<B> for Param<Tensor<B, D, Int>>
191where
192    B: Backend,
193{
194    type Item<S: PrecisionSettings> = ParamSerde<IntTensorSerde<S>>;
195
196    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
197        let (id, tensor) = self.consume();
198        ParamSerde::new(id.serialize(), tensor.into_item())
199    }
200
201    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
202        Param::initialized(
203            ParamId::deserialize(&item.id),
204            Tensor::from_item(item.param, device),
205        )
206    }
207}
208
209impl<B, const D: usize> Record<B> for Param<Tensor<B, D, Bool>>
210where
211    B: Backend,
212{
213    type Item<S: PrecisionSettings> = ParamSerde<BoolTensorSerde>;
214
215    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
216        let (id, tensor) = self.consume();
217        ParamSerde::new(id.serialize(), tensor.into_item::<S>())
218    }
219
220    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
221        Param::initialized(
222            ParamId::deserialize(&item.id),
223            Tensor::from_item::<S>(item.param, device),
224        )
225    }
226}
227
228// Type that can be serialized as is without any conversion.
229macro_rules! primitive {
230    ($type:ty) => {
231        impl<B: Backend> Record<B> for $type {
232            type Item<S: PrecisionSettings> = $type;
233
234            fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
235                self
236            }
237
238            fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
239                item
240            }
241        }
242    };
243}
244
245// General Types
246primitive!(alloc::string::String);
247primitive!(bool);
248
249// Float Types
250primitive!(f64);
251primitive!(f32);
252
253primitive!(half::bf16);
254primitive!(half::f16);
255
256// Unsigned Integer Types
257primitive!(usize);
258primitive!(u64);
259primitive!(u32);
260primitive!(u16);
261primitive!(u8);
262
263// Signed Integer Types
264primitive!(i64);
265primitive!(i32);
266primitive!(i16);
267primitive!(i8);
268
/// A wrapper around an array of size `N`, so that it can be serialized and deserialized
/// using serde.
///
/// The wrapper is needed because serde does not support deserializing arrays of arbitrary
/// size `N`: its APIs were created before const generics and cannot be changed for
/// backward-compatibility reasons (see
/// [serde/issues/1937](https://github.com/serde-rs/serde/issues/1937)).
pub struct Array<const N: usize, T>([T; N]);
277
278impl<T: Serialize, const N: usize> Serialize for Array<N, T> {
279    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
280    where
281        S: serde::Serializer,
282    {
283        let mut seq = serializer.serialize_tuple(self.0.len())?;
284        for element in &self.0 {
285            seq.serialize_element(element)?;
286        }
287        seq.end()
288    }
289}
290
291impl<'de, T, const N: usize> Deserialize<'de> for Array<N, T>
292where
293    T: Deserialize<'de>,
294{
295    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
296    where
297        D: serde::Deserializer<'de>,
298    {
299        struct ArrayVisitor<T, const N: usize> {
300            marker: PhantomData<T>,
301        }
302
303        impl<'de, T, const N: usize> Visitor<'de> for ArrayVisitor<T, N>
304        where
305            T: Deserialize<'de>,
306        {
307            type Value = Array<N, T>;
308
309            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
310                formatter.write_str("a fixed size array")
311            }
312
313            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
314            where
315                A: SeqAccess<'de>,
316            {
317                let mut items = vec![];
318
319                for i in 0..N {
320                    let item = seq
321                        .next_element()?
322                        .ok_or_else(|| Error::invalid_length(i, &self))?;
323                    items.push(item);
324                }
325
326                let array: [T; N] = items
327                    .into_iter()
328                    .collect::<Vec<_>>()
329                    .try_into()
330                    .map_err(|_| "An array of size {N}")
331                    .unwrap();
332
333                Ok(Array(array))
334            }
335        }
336
337        deserializer.deserialize_tuple(
338            N,
339            ArrayVisitor {
340                marker: PhantomData,
341            },
342        )
343    }
344}