1use alloc::{string::String, vec, vec::Vec};
2use core::{fmt, marker::PhantomData};
3
4use super::tensor::{BoolTensorSerde, FloatTensorSerde, IntTensorSerde};
5use super::{PrecisionSettings, Record};
6use crate::module::{Param, ParamId};
7
8use burn_tensor::{Bool, Int, Tensor, backend::Backend};
9
10use hashbrown::HashMap;
11use serde::{
12 Deserialize, Serialize,
13 de::{Error, SeqAccess, Visitor},
14 ser::SerializeTuple,
15};
16
17impl<B> Record<B> for ()
18where
19 B: Backend,
20{
21 type Item<S: PrecisionSettings> = ();
22
23 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {}
24
25 fn from_item<S: PrecisionSettings>(_item: Self::Item<S>, _device: &B::Device) -> Self {}
26}
27
28impl<T, B> Record<B> for Vec<T>
29where
30 T: Record<B>,
31 B: Backend,
32{
33 type Item<S: PrecisionSettings> = Vec<T::Item<S>>;
34
35 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
36 self.into_iter().map(Record::into_item).collect()
37 }
38
39 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
40 item.into_iter()
41 .map(|i| Record::from_item(i, device))
42 .collect()
43 }
44}
45
46impl<T, B> Record<B> for Option<T>
47where
48 T: Record<B>,
49 B: Backend,
50{
51 type Item<S: PrecisionSettings> = Option<T::Item<S>>;
52
53 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
54 self.map(Record::into_item)
55 }
56
57 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
58 item.map(|i| Record::from_item(i, device))
59 }
60}
61
62impl<const N: usize, T, B> Record<B> for [T; N]
63where
64 T: Record<B>,
65 B: Backend,
66{
67 type Item<S: PrecisionSettings> = Array<N, T::Item<S>>;
73
74 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
75 Array(self.map(Record::into_item))
76 }
77
78 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
79 item.0.map(|i| Record::from_item(i, device))
80 }
81}
82
/// Implements `Record` for one tuple arity.
///
/// The macro takes two bracketed, comma-separated lists of equal length:
/// the generic type parameter names of the tuple elements, and the matching
/// tuple field indices, e.g. `impl_record_tuple!([R0, R1][0, 1]);`.
///
/// The generated item is the tuple of the element items, converted field by
/// field in index order.
macro_rules! impl_record_tuple {
    ([$($ty:ident),*][$($idx:tt),*]) => {
        impl<B, $($ty,)*> Record<B> for ($($ty,)*)
        where
            B: Backend,
            $($ty: Record<B>),*
        {
            type Item<S: PrecisionSettings> = ($($ty::Item<S>,)*);

            fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
                ($($ty::into_item::<S>(self.$idx),)*)
            }

            fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
                ($($ty::from_item::<S>(item.$idx, device),)*)
            }
        }
    };
}
109
// Implement `Record` for tuples of arity 1 through 10. Each invocation lists
// the element type parameters followed by the matching tuple field indices.
impl_record_tuple!([R0][0]);
impl_record_tuple!([R0, R1][0, 1]);
impl_record_tuple!([R0, R1, R2][0, 1, 2]);
impl_record_tuple!([R0, R1, R2, R3][0, 1, 2, 3]);
impl_record_tuple!([R0, R1, R2, R3, R4][0, 1, 2, 3, 4]);
impl_record_tuple!([R0, R1, R2, R3, R4, R5][0, 1, 2, 3, 4, 5]);
impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6][0, 1, 2, 3, 4, 5, 6]);
impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7][0, 1, 2, 3, 4, 5, 6, 7]);
impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8][0, 1, 2, 3, 4, 5, 6, 7, 8]);
impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8, R9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
120
121impl<T, B> Record<B> for HashMap<ParamId, T>
122where
123 T: Record<B>,
124 B: Backend,
125{
126 type Item<S: PrecisionSettings> = HashMap<String, T::Item<S>>;
127
128 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
129 let mut items = HashMap::with_capacity(self.len());
130 self.into_iter().for_each(|(id, record)| {
131 items.insert(id.serialize(), record.into_item());
132 });
133 items
134 }
135
136 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
137 let mut record = HashMap::with_capacity(item.len());
138 item.into_iter().for_each(|(id, item)| {
139 record.insert(ParamId::deserialize(&id), T::from_item(item, device));
140 });
141 record
142 }
143}
144
/// Serializable form of a parameter: its id as a string plus the serialized
/// parameter value.
///
/// The `new` derive provides the `ParamSerde::new(id, param)` constructor
/// used by the `Param` record implementations in this file.
#[derive(new, Debug, Clone, Serialize, Deserialize)]
pub struct ParamSerde<T> {
    /// Serialized parameter id (produced by `ParamId::serialize`).
    id: String,
    /// Serialized parameter value.
    param: T,
}
151
152impl<B, const D: usize> Record<B> for Param<Tensor<B, D>>
153where
154 B: Backend,
155{
156 type Item<S: PrecisionSettings> = ParamSerde<FloatTensorSerde<S>>;
157
158 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
159 let (id, tensor, mapper) = self.consume();
160 let tensor = mapper.on_save(tensor);
161 ParamSerde::new(id.serialize(), tensor.into_item())
162 }
163
164 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
165 B::memory_persistent_allocations(device, item, |item| {
166 Param::initialized(
167 ParamId::deserialize(&item.id),
168 Tensor::from_item(item.param, device).require_grad(), )
171 })
172 }
173}
174
175impl<B, const D: usize> Record<B> for Param<Tensor<B, D, Int>>
176where
177 B: Backend,
178{
179 type Item<S: PrecisionSettings> = ParamSerde<IntTensorSerde<S>>;
180
181 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
182 let (id, tensor, mapper) = self.consume();
183 let tensor = mapper.on_save(tensor);
184 ParamSerde::new(id.serialize(), tensor.into_item())
185 }
186
187 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
188 B::memory_persistent_allocations(device, item, |item| {
189 Param::initialized(
190 ParamId::deserialize(&item.id),
191 Tensor::from_item(item.param, device),
192 )
193 })
194 }
195}
196
197impl<B, const D: usize> Record<B> for Param<Tensor<B, D, Bool>>
198where
199 B: Backend,
200{
201 type Item<S: PrecisionSettings> = ParamSerde<BoolTensorSerde>;
202
203 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
204 let (id, tensor, mapper) = self.consume();
205 let tensor = mapper.on_save(tensor);
206 ParamSerde::new(id.serialize(), tensor.into_item::<S>())
207 }
208
209 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
210 B::memory_persistent_allocations(device, item, |item| {
211 Param::initialized(
212 ParamId::deserialize(&item.id),
213 Tensor::from_item::<S>(item.param, device),
214 )
215 })
216 }
217}
218
/// Implements `Record` for a type that is its own item: conversion in both
/// directions returns the value unchanged and the device is ignored.
macro_rules! primitive {
    ($ty:ty) => {
        impl<B> Record<B> for $ty
        where
            B: Backend,
        {
            type Item<S: PrecisionSettings> = $ty;

            fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
                // The value is already in its serializable form.
                self
            }

            fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
                // Nothing to load onto the device; hand the value back as is.
                item
            }
        }
    };
}
235
// General types
primitive!(alloc::string::String);
primitive!(bool);

// Float types
primitive!(f64);
primitive!(f32);

// Half-precision float types (from the `half` crate)
primitive!(half::bf16);
primitive!(half::f16);

// Unsigned integer types
primitive!(usize);
primitive!(u64);
primitive!(u32);
primitive!(u16);
primitive!(u8);

// Signed integer types
primitive!(isize);
primitive!(i64);
primitive!(i32);
primitive!(i16);
primitive!(i8);
260
/// Newtype around a fixed-size array `[T; N]`.
///
/// It is the item type of the `Record` implementation for `[T; N]` and has
/// hand-written `Serialize` / `Deserialize` implementations that treat the
/// array as a tuple of exactly `N` elements.
#[derive(Clone)]
pub struct Array<const N: usize, T>([T; N]);
270
271impl<T: Serialize, const N: usize> Serialize for Array<N, T> {
272 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
273 where
274 S: serde::Serializer,
275 {
276 let mut seq = serializer.serialize_tuple(self.0.len())?;
277 for element in &self.0 {
278 seq.serialize_element(element)?;
279 }
280 seq.end()
281 }
282}
283
284impl<'de, T, const N: usize> Deserialize<'de> for Array<N, T>
285where
286 T: Deserialize<'de>,
287{
288 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
289 where
290 D: serde::Deserializer<'de>,
291 {
292 struct ArrayVisitor<T, const N: usize> {
293 marker: PhantomData<T>,
294 }
295
296 impl<'de, T, const N: usize> Visitor<'de> for ArrayVisitor<T, N>
297 where
298 T: Deserialize<'de>,
299 {
300 type Value = Array<N, T>;
301
302 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
303 formatter.write_str("a fixed size array")
304 }
305
306 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
307 where
308 A: SeqAccess<'de>,
309 {
310 let mut items = vec![];
311
312 for i in 0..N {
313 let item = seq
314 .next_element()?
315 .ok_or_else(|| Error::invalid_length(i, &self))?;
316 items.push(item);
317 }
318
319 let array: [T; N] = items
320 .into_iter()
321 .collect::<Vec<_>>()
322 .try_into()
323 .map_err(|_| "An array of size {N}")
324 .unwrap();
325
326 Ok(Array(array))
327 }
328 }
329
330 deserializer.deserialize_tuple(
331 N,
332 ArrayVisitor {
333 marker: PhantomData,
334 },
335 )
336 }
337}