Skip to main content

burn_core/record/
primitive.rs

1use alloc::{string::String, vec, vec::Vec};
2use core::{fmt, marker::PhantomData};
3
4use super::tensor::{BoolTensorSerde, FloatTensorSerde, IntTensorSerde};
5use super::{PrecisionSettings, Record};
6use crate::module::{Param, ParamId};
7
8use burn_tensor::{Bool, Int, Tensor, backend::Backend};
9
10use hashbrown::HashMap;
11use serde::{
12    Deserialize, Serialize,
13    de::{Error, SeqAccess, Visitor},
14    ser::SerializeTuple,
15};
16
17impl<B> Record<B> for ()
18where
19    B: Backend,
20{
21    type Item<S: PrecisionSettings> = ();
22
23    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {}
24
25    fn from_item<S: PrecisionSettings>(_item: Self::Item<S>, _device: &B::Device) -> Self {}
26}
27
28impl<T, B> Record<B> for Vec<T>
29where
30    T: Record<B>,
31    B: Backend,
32{
33    type Item<S: PrecisionSettings> = Vec<T::Item<S>>;
34
35    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
36        self.into_iter().map(Record::into_item).collect()
37    }
38
39    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
40        item.into_iter()
41            .map(|i| Record::from_item(i, device))
42            .collect()
43    }
44}
45
46impl<T, B> Record<B> for Option<T>
47where
48    T: Record<B>,
49    B: Backend,
50{
51    type Item<S: PrecisionSettings> = Option<T::Item<S>>;
52
53    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
54        self.map(Record::into_item)
55    }
56
57    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
58        item.map(|i| Record::from_item(i, device))
59    }
60}
61
62impl<const N: usize, T, B> Record<B> for [T; N]
63where
64    T: Record<B>,
65    B: Backend,
66{
67    /// The record item is an array of the record item of the elements.
68    /// The reason why we wrap the array in a struct is because serde does not support
69    /// deserializing arrays of variable size,
70    /// see [serde/issues/1937](https://github.com/serde-rs/serde/issues/1937).
71    /// for backward compatibility reasons. Serde APIs were created before const generics.
72    type Item<S: PrecisionSettings> = Array<N, T::Item<S>>;
73
74    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
75        Array(self.map(Record::into_item))
76    }
77
78    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
79        item.0.map(|i| Record::from_item(i, device))
80    }
81}
82
83/// A macro for generating implementations for tuple records of different sizes.
84/// For example: `impl_record_tuple!([R0, R1][0, 1])`.
85/// Would generate an implementation for a tuple of size 2.
86/// For this macro to work properly, please adhere to the convention:
87/// `impl_record_tuple!([R0, R1, ..., Rn][0, 1, ..., n])`.
88macro_rules! impl_record_tuple {
89    // `$r` represents the generic records.
90    // `$i` represents the indices of the records in the tuple.
91    ([$($r:ident),*][$($i:tt),*]) => {
92        impl<B, $($r,)*> Record<B> for ($($r,)*)
93        where
94            B: Backend,
95            $($r: Record<B>),*
96        {
97            type Item<S: PrecisionSettings> = ($($r::Item<S>,)*);
98
99            fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
100                ($(self.$i.into_item(),)*)
101            }
102
103            fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
104                ($(Record::from_item(item.$i, device),)*)
105            }
106        }
107    };
108}
109
110impl_record_tuple!([R0][0]);
111impl_record_tuple!([R0, R1][0, 1]);
112impl_record_tuple!([R0, R1, R2][0, 1, 2]);
113impl_record_tuple!([R0, R1, R2, R3][0, 1, 2, 3]);
114impl_record_tuple!([R0, R1, R2, R3, R4][0, 1, 2, 3, 4]);
115impl_record_tuple!([R0, R1, R2, R3, R4, R5][0, 1, 2, 3, 4, 5]);
116impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6][0, 1, 2, 3, 4, 5, 6]);
117impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7][0, 1, 2, 3, 4, 5, 6, 7]);
118impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8][0, 1, 2, 3, 4, 5, 6, 7, 8]);
119impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8, R9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
120
121impl<T, B> Record<B> for HashMap<ParamId, T>
122where
123    T: Record<B>,
124    B: Backend,
125{
126    type Item<S: PrecisionSettings> = HashMap<String, T::Item<S>>;
127
128    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
129        let mut items = HashMap::with_capacity(self.len());
130        self.into_iter().for_each(|(id, record)| {
131            items.insert(id.serialize(), record.into_item());
132        });
133        items
134    }
135
136    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
137        let mut record = HashMap::with_capacity(item.len());
138        item.into_iter().for_each(|(id, item)| {
139            record.insert(ParamId::deserialize(&id), T::from_item(item, device));
140        });
141        record
142    }
143}
144
145/// (De)serialize parameters into a clean format.
146#[derive(new, Debug, Clone, Serialize, Deserialize)]
147pub struct ParamSerde<T> {
148    id: String,
149    param: T,
150}
151
152impl<B, const D: usize> Record<B> for Param<Tensor<B, D>>
153where
154    B: Backend,
155{
156    type Item<S: PrecisionSettings> = ParamSerde<FloatTensorSerde<S>>;
157
158    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
159        let (id, tensor, mapper) = self.consume();
160        let tensor = mapper.on_save(tensor);
161        ParamSerde::new(id.serialize(), tensor.into_item())
162    }
163
164    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
165        B::memory_persistent_allocations(device, item, |item| {
166            Param::initialized(
167                ParamId::deserialize(&item.id),
168                Tensor::from_item(item.param, device).require_grad(), // Same behavior as when we create a new
169                                                                      // Param from a tensor.
170            )
171        })
172    }
173}
174
175impl<B, const D: usize> Record<B> for Param<Tensor<B, D, Int>>
176where
177    B: Backend,
178{
179    type Item<S: PrecisionSettings> = ParamSerde<IntTensorSerde<S>>;
180
181    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
182        let (id, tensor, mapper) = self.consume();
183        let tensor = mapper.on_save(tensor);
184        ParamSerde::new(id.serialize(), tensor.into_item())
185    }
186
187    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
188        B::memory_persistent_allocations(device, item, |item| {
189            Param::initialized(
190                ParamId::deserialize(&item.id),
191                Tensor::from_item(item.param, device),
192            )
193        })
194    }
195}
196
197impl<B, const D: usize> Record<B> for Param<Tensor<B, D, Bool>>
198where
199    B: Backend,
200{
201    type Item<S: PrecisionSettings> = ParamSerde<BoolTensorSerde>;
202
203    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
204        let (id, tensor, mapper) = self.consume();
205        let tensor = mapper.on_save(tensor);
206        ParamSerde::new(id.serialize(), tensor.into_item::<S>())
207    }
208
209    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
210        B::memory_persistent_allocations(device, item, |item| {
211            Param::initialized(
212                ParamId::deserialize(&item.id),
213                Tensor::from_item::<S>(item.param, device),
214            )
215        })
216    }
217}
218
219// Type that can be serialized as is without any conversion.
220macro_rules! primitive {
221    ($type:ty) => {
222        impl<B: Backend> Record<B> for $type {
223            type Item<S: PrecisionSettings> = $type;
224
225            fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
226                self
227            }
228
229            fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
230                item
231            }
232        }
233    };
234}
235
236// General Types
237primitive!(alloc::string::String);
238primitive!(bool);
239
240// Float Types
241primitive!(f64);
242primitive!(f32);
243
244primitive!(half::bf16);
245primitive!(half::f16);
246
247// Unsigned Integer Types
248primitive!(usize);
249primitive!(u64);
250primitive!(u32);
251primitive!(u16);
252primitive!(u8);
253
254// Signed Integer Types
255primitive!(isize);
256primitive!(i64);
257primitive!(i32);
258primitive!(i16);
259primitive!(i8);
260
/// A wrapper around an array of size `N` so that it can be serialized and deserialized
/// with serde.
///
/// The wrapper is needed because serde cannot deserialize arrays of an arbitrary
/// const-generic length: its APIs were created before const generics and cannot change
/// for backward-compatibility reasons
/// (see [serde/issues/1937](https://github.com/serde-rs/serde/issues/1937)).
#[derive(Clone, Debug)]
pub struct Array<const N: usize, T>([T; N]);
270
271impl<T: Serialize, const N: usize> Serialize for Array<N, T> {
272    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
273    where
274        S: serde::Serializer,
275    {
276        let mut seq = serializer.serialize_tuple(self.0.len())?;
277        for element in &self.0 {
278            seq.serialize_element(element)?;
279        }
280        seq.end()
281    }
282}
283
284impl<'de, T, const N: usize> Deserialize<'de> for Array<N, T>
285where
286    T: Deserialize<'de>,
287{
288    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
289    where
290        D: serde::Deserializer<'de>,
291    {
292        struct ArrayVisitor<T, const N: usize> {
293            marker: PhantomData<T>,
294        }
295
296        impl<'de, T, const N: usize> Visitor<'de> for ArrayVisitor<T, N>
297        where
298            T: Deserialize<'de>,
299        {
300            type Value = Array<N, T>;
301
302            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
303                formatter.write_str("a fixed size array")
304            }
305
306            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
307            where
308                A: SeqAccess<'de>,
309            {
310                let mut items = vec![];
311
312                for i in 0..N {
313                    let item = seq
314                        .next_element()?
315                        .ok_or_else(|| Error::invalid_length(i, &self))?;
316                    items.push(item);
317                }
318
319                let array: [T; N] = items
320                    .into_iter()
321                    .collect::<Vec<_>>()
322                    .try_into()
323                    .map_err(|_| "An array of size {N}")
324                    .unwrap();
325
326                Ok(Array(array))
327            }
328        }
329
330        deserializer.deserialize_tuple(
331            N,
332            ArrayVisitor {
333                marker: PhantomData,
334            },
335        )
336    }
337}