burn_core/record/
primitive.rs

1use alloc::{string::String, vec, vec::Vec};
2use core::{fmt, marker::PhantomData};
3
4use super::tensor::{BoolTensorSerde, FloatTensorSerde, IntTensorSerde};
5use super::{PrecisionSettings, Record};
6use crate::module::{Param, ParamId};
7
8use burn_tensor::{Bool, Int, Tensor, backend::Backend};
9
10use hashbrown::HashMap;
11use serde::{
12    Deserialize, Serialize,
13    de::{Error, SeqAccess, Visitor},
14    ser::SerializeTuple,
15};
16
17impl<B> Record<B> for ()
18where
19    B: Backend,
20{
21    type Item<S: PrecisionSettings> = ();
22
23    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {}
24
25    fn from_item<S: PrecisionSettings>(_item: Self::Item<S>, _device: &B::Device) -> Self {}
26}
27
28impl<T, B> Record<B> for Vec<T>
29where
30    T: Record<B>,
31    B: Backend,
32{
33    type Item<S: PrecisionSettings> = Vec<T::Item<S>>;
34
35    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
36        self.into_iter().map(Record::into_item).collect()
37    }
38
39    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
40        item.into_iter()
41            .map(|i| Record::from_item(i, device))
42            .collect()
43    }
44}
45
46impl<T, B> Record<B> for Option<T>
47where
48    T: Record<B>,
49    B: Backend,
50{
51    type Item<S: PrecisionSettings> = Option<T::Item<S>>;
52
53    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
54        self.map(Record::into_item)
55    }
56
57    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
58        item.map(|i| Record::from_item(i, device))
59    }
60}
61
62impl<const N: usize, T, B> Record<B> for [T; N]
63where
64    T: Record<B>,
65    B: Backend,
66{
67    /// The record item is an array of the record item of the elements.
68    /// The reason why we wrap the array in a struct is because serde does not support
69    /// deserializing arrays of variable size,
70    /// see [serde/issues/1937](https://github.com/serde-rs/serde/issues/1937).
71    /// for backward compatibility reasons. Serde APIs were created before const generics.
72    type Item<S: PrecisionSettings> = Array<N, T::Item<S>>;
73
74    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
75        Array(self.map(Record::into_item))
76    }
77
78    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
79        item.0.map(|i| Record::from_item(i, device))
80    }
81}
82
83/// A macro for generating implementations for tuple records of different sizes.
84/// For example: `impl_record_tuple!([R0, R1][0, 1])`.
85/// Would generate an implementation for a tuple of size 2.
86/// For this macro to work properly, please adhere to the convention:
87/// `impl_record_tuple!([R0, R1, ..., Rn][0, 1, ..., n])`.
88macro_rules! impl_record_tuple {
89    // `$r` represents the generic records.
90    // `$i` represents the indices of the records in the tuple.
91    ([$($r:ident),*][$($i:tt),*]) => {
92        impl<B, $($r,)*> Record<B> for ($($r,)*)
93        where
94            B: Backend,
95            $($r: Record<B>),*
96        {
97            type Item<S: PrecisionSettings> = ($($r::Item<S>,)*);
98
99            fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
100                ($(self.$i.into_item(),)*)
101            }
102
103            fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
104                ($(Record::from_item(item.$i, device),)*)
105            }
106        }
107    };
108}
109
110impl_record_tuple!([R0, R1][0, 1]);
111impl_record_tuple!([R0, R1, R2][0, 1, 2]);
112impl_record_tuple!([R0, R1, R2, R3][0, 1, 2, 3]);
113impl_record_tuple!([R0, R1, R2, R3, R4][0, 1, 2, 3, 4]);
114impl_record_tuple!([R0, R1, R2, R3, R4, R5][0, 1, 2, 3, 4, 5]);
115impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6][0, 1, 2, 3, 4, 5, 6]);
116impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7][0, 1, 2, 3, 4, 5, 6, 7]);
117impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8][0, 1, 2, 3, 4, 5, 6, 7, 8]);
118impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8, R9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
119
120impl<T, B> Record<B> for HashMap<ParamId, T>
121where
122    T: Record<B>,
123    B: Backend,
124{
125    type Item<S: PrecisionSettings> = HashMap<String, T::Item<S>>;
126
127    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
128        let mut items = HashMap::with_capacity(self.len());
129        self.into_iter().for_each(|(id, record)| {
130            items.insert(id.serialize(), record.into_item());
131        });
132        items
133    }
134
135    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
136        let mut record = HashMap::with_capacity(item.len());
137        item.into_iter().for_each(|(id, item)| {
138            record.insert(ParamId::deserialize(&id), T::from_item(item, device));
139        });
140        record
141    }
142}
143
144/// (De)serialize parameters into a clean format.
145#[derive(new, Debug, Clone, Serialize, Deserialize)]
146pub struct ParamSerde<T> {
147    id: String,
148    param: T,
149}
150
151impl<B, const D: usize> Record<B> for Param<Tensor<B, D>>
152where
153    B: Backend,
154{
155    type Item<S: PrecisionSettings> = ParamSerde<FloatTensorSerde<S>>;
156
157    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
158        let (id, tensor, mapper) = self.consume();
159        let tensor = mapper.on_save(tensor);
160        ParamSerde::new(id.serialize(), tensor.into_item())
161    }
162
163    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
164        B::memory_persistent_allocations(device, item, |item| {
165            Param::initialized(
166                ParamId::deserialize(&item.id),
167                Tensor::from_item(item.param, device).require_grad(), // Same behavior as when we create a new
168                                                                      // Param from a tensor.
169            )
170        })
171    }
172}
173
174impl<B, const D: usize> Record<B> for Param<Tensor<B, D, Int>>
175where
176    B: Backend,
177{
178    type Item<S: PrecisionSettings> = ParamSerde<IntTensorSerde<S>>;
179
180    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
181        let (id, tensor, mapper) = self.consume();
182        let tensor = mapper.on_save(tensor);
183        ParamSerde::new(id.serialize(), tensor.into_item())
184    }
185
186    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
187        B::memory_persistent_allocations(device, item, |item| {
188            Param::initialized(
189                ParamId::deserialize(&item.id),
190                Tensor::from_item(item.param, device),
191            )
192        })
193    }
194}
195
196impl<B, const D: usize> Record<B> for Param<Tensor<B, D, Bool>>
197where
198    B: Backend,
199{
200    type Item<S: PrecisionSettings> = ParamSerde<BoolTensorSerde>;
201
202    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
203        let (id, tensor, mapper) = self.consume();
204        let tensor = mapper.on_save(tensor);
205        ParamSerde::new(id.serialize(), tensor.into_item::<S>())
206    }
207
208    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
209        B::memory_persistent_allocations(device, item, |item| {
210            Param::initialized(
211                ParamId::deserialize(&item.id),
212                Tensor::from_item::<S>(item.param, device),
213            )
214        })
215    }
216}
217
218// Type that can be serialized as is without any conversion.
219macro_rules! primitive {
220    ($type:ty) => {
221        impl<B: Backend> Record<B> for $type {
222            type Item<S: PrecisionSettings> = $type;
223
224            fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
225                self
226            }
227
228            fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
229                item
230            }
231        }
232    };
233}
234
235// General Types
236primitive!(alloc::string::String);
237primitive!(bool);
238
239// Float Types
240primitive!(f64);
241primitive!(f32);
242
243primitive!(half::bf16);
244primitive!(half::f16);
245
246// Unsigned Integer Types
247primitive!(usize);
248primitive!(u64);
249primitive!(u32);
250primitive!(u16);
251primitive!(u8);
252
253// Signed Integer Types
254primitive!(isize);
255primitive!(i64);
256primitive!(i32);
257primitive!(i16);
258primitive!(i8);
259
/// A wrapper around an array of size `N`, so that it can be serialized and deserialized
/// using serde.
///
/// The array is wrapped in a struct because serde does not support deserializing arrays
/// of arbitrary size: its APIs were created before const generics and are kept as they
/// are for backward compatibility reasons, see
/// [serde/issues/1937](https://github.com/serde-rs/serde/issues/1937).
#[derive(Clone)]
pub struct Array<const N: usize, T>([T; N]);
269
270impl<T: Serialize, const N: usize> Serialize for Array<N, T> {
271    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
272    where
273        S: serde::Serializer,
274    {
275        let mut seq = serializer.serialize_tuple(self.0.len())?;
276        for element in &self.0 {
277            seq.serialize_element(element)?;
278        }
279        seq.end()
280    }
281}
282
283impl<'de, T, const N: usize> Deserialize<'de> for Array<N, T>
284where
285    T: Deserialize<'de>,
286{
287    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
288    where
289        D: serde::Deserializer<'de>,
290    {
291        struct ArrayVisitor<T, const N: usize> {
292            marker: PhantomData<T>,
293        }
294
295        impl<'de, T, const N: usize> Visitor<'de> for ArrayVisitor<T, N>
296        where
297            T: Deserialize<'de>,
298        {
299            type Value = Array<N, T>;
300
301            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
302                formatter.write_str("a fixed size array")
303            }
304
305            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
306            where
307                A: SeqAccess<'de>,
308            {
309                let mut items = vec![];
310
311                for i in 0..N {
312                    let item = seq
313                        .next_element()?
314                        .ok_or_else(|| Error::invalid_length(i, &self))?;
315                    items.push(item);
316                }
317
318                let array: [T; N] = items
319                    .into_iter()
320                    .collect::<Vec<_>>()
321                    .try_into()
322                    .map_err(|_| "An array of size {N}")
323                    .unwrap();
324
325                Ok(Array(array))
326            }
327        }
328
329        deserializer.deserialize_tuple(
330            N,
331            ArrayVisitor {
332                marker: PhantomData,
333            },
334        )
335    }
336}