1use alloc::{string::String, vec, vec::Vec};
2use core::{fmt, marker::PhantomData};
3
4use super::tensor::{BoolTensorSerde, FloatTensorSerde, IntTensorSerde};
5use super::{PrecisionSettings, Record};
6use crate::module::{Param, ParamId};
7
8#[allow(deprecated)]
9use burn_tensor::DataSerialize;
10use burn_tensor::{backend::Backend, Bool, Element, Int, Tensor};
11
12use hashbrown::HashMap;
13use serde::{
14 de::{Error, SeqAccess, Visitor},
15 ser::SerializeTuple,
16 Deserialize, Serialize,
17};
18
19impl<B> Record<B> for ()
20where
21 B: Backend,
22{
23 type Item<S: PrecisionSettings> = ();
24
25 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {}
26
27 fn from_item<S: PrecisionSettings>(_item: Self::Item<S>, _device: &B::Device) -> Self {}
28}
29
30impl<T, B> Record<B> for Vec<T>
31where
32 T: Record<B>,
33 B: Backend,
34{
35 type Item<S: PrecisionSettings> = Vec<T::Item<S>>;
36
37 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
38 self.into_iter().map(Record::into_item).collect()
39 }
40
41 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
42 item.into_iter()
43 .map(|i| Record::from_item(i, device))
44 .collect()
45 }
46}
47
48impl<T, B> Record<B> for Option<T>
49where
50 T: Record<B>,
51 B: Backend,
52{
53 type Item<S: PrecisionSettings> = Option<T::Item<S>>;
54
55 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
56 self.map(Record::into_item)
57 }
58
59 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
60 item.map(|i| Record::from_item(i, device))
61 }
62}
63
64impl<const N: usize, T, B> Record<B> for [T; N]
65where
66 T: Record<B>,
67 B: Backend,
68{
69 type Item<S: PrecisionSettings> = Array<N, T::Item<S>>;
75
76 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
77 Array(self.map(Record::into_item))
78 }
79
80 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
81 item.0.map(|i| Record::from_item(i, device))
82 }
83}
84
85macro_rules! impl_record_tuple {
91 ([$($r:ident),*][$($i:tt),*]) => {
94 impl<B, $($r,)*> Record<B> for ($($r,)*)
95 where
96 B: Backend,
97 $($r: Record<B>),*
98 {
99 type Item<S: PrecisionSettings> = ($($r::Item<S>,)*);
100
101 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
102 ($(self.$i.into_item(),)*)
103 }
104
105 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
106 ($(Record::from_item(item.$i, device),)*)
107 }
108 }
109 };
110}
111
112impl_record_tuple!([R0, R1][0, 1]);
113impl_record_tuple!([R0, R1, R2][0, 1, 2]);
114impl_record_tuple!([R0, R1, R2, R3][0, 1, 2, 3]);
115impl_record_tuple!([R0, R1, R2, R3, R4][0, 1, 2, 3, 4]);
116impl_record_tuple!([R0, R1, R2, R3, R4, R5][0, 1, 2, 3, 4, 5]);
117impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6][0, 1, 2, 3, 4, 5, 6]);
118impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7][0, 1, 2, 3, 4, 5, 6, 7]);
119impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8][0, 1, 2, 3, 4, 5, 6, 7, 8]);
120impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8, R9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
121
122impl<T, B> Record<B> for HashMap<ParamId, T>
123where
124 T: Record<B>,
125 B: Backend,
126{
127 type Item<S: PrecisionSettings> = HashMap<String, T::Item<S>>;
128
129 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
130 let mut items = HashMap::with_capacity(self.len());
131 self.into_iter().for_each(|(id, record)| {
132 items.insert(id.serialize(), record.into_item());
133 });
134 items
135 }
136
137 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
138 let mut record = HashMap::with_capacity(item.len());
139 item.into_iter().for_each(|(id, item)| {
140 record.insert(ParamId::deserialize(&id), T::from_item(item, device));
141 });
142 record
143 }
144}
145
146#[allow(deprecated)]
147impl<E, B> Record<B> for DataSerialize<E>
148where
149 E: Element,
150 B: Backend,
151{
152 type Item<S: PrecisionSettings> = DataSerialize<S::FloatElem>;
153
154 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
155 self.convert()
156 }
157
158 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
159 item.convert()
160 }
161}
162
163#[derive(new, Debug, Clone, Serialize, Deserialize)]
165pub struct ParamSerde<T> {
166 id: String,
167 param: T,
168}
169
170impl<B, const D: usize> Record<B> for Param<Tensor<B, D>>
171where
172 B: Backend,
173{
174 type Item<S: PrecisionSettings> = ParamSerde<FloatTensorSerde<S>>;
175
176 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
177 let (id, tensor) = self.consume();
178 ParamSerde::new(id.serialize(), tensor.into_item())
179 }
180
181 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
182 Param::initialized(
183 ParamId::deserialize(&item.id),
184 Tensor::from_item(item.param, device).require_grad(), )
187 }
188}
189
190impl<B, const D: usize> Record<B> for Param<Tensor<B, D, Int>>
191where
192 B: Backend,
193{
194 type Item<S: PrecisionSettings> = ParamSerde<IntTensorSerde<S>>;
195
196 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
197 let (id, tensor) = self.consume();
198 ParamSerde::new(id.serialize(), tensor.into_item())
199 }
200
201 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
202 Param::initialized(
203 ParamId::deserialize(&item.id),
204 Tensor::from_item(item.param, device),
205 )
206 }
207}
208
209impl<B, const D: usize> Record<B> for Param<Tensor<B, D, Bool>>
210where
211 B: Backend,
212{
213 type Item<S: PrecisionSettings> = ParamSerde<BoolTensorSerde>;
214
215 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
216 let (id, tensor) = self.consume();
217 ParamSerde::new(id.serialize(), tensor.into_item::<S>())
218 }
219
220 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
221 Param::initialized(
222 ParamId::deserialize(&item.id),
223 Tensor::from_item::<S>(item.param, device),
224 )
225 }
226}
227
228macro_rules! primitive {
230 ($type:ty) => {
231 impl<B: Backend> Record<B> for $type {
232 type Item<S: PrecisionSettings> = $type;
233
234 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
235 self
236 }
237
238 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {
239 item
240 }
241 }
242 };
243}
244
245primitive!(alloc::string::String);
247primitive!(bool);
248
249primitive!(f64);
251primitive!(f32);
252
253primitive!(half::bf16);
254primitive!(half::f16);
255
256primitive!(usize);
258primitive!(u64);
259primitive!(u32);
260primitive!(u16);
261primitive!(u8);
262
263primitive!(i64);
265primitive!(i32);
266primitive!(i16);
267primitive!(i8);
268
/// Newtype around a fixed-size array `[T; N]`.
///
/// It exists so the array can be given the hand-written `Serialize` and
/// `Deserialize` implementations found in this file, which treat it as a
/// tuple of `N` elements.
pub struct Array<const N: usize, T>(
    /// The wrapped elements.
    [T; N],
);
277
278impl<T: Serialize, const N: usize> Serialize for Array<N, T> {
279 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
280 where
281 S: serde::Serializer,
282 {
283 let mut seq = serializer.serialize_tuple(self.0.len())?;
284 for element in &self.0 {
285 seq.serialize_element(element)?;
286 }
287 seq.end()
288 }
289}
290
291impl<'de, T, const N: usize> Deserialize<'de> for Array<N, T>
292where
293 T: Deserialize<'de>,
294{
295 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
296 where
297 D: serde::Deserializer<'de>,
298 {
299 struct ArrayVisitor<T, const N: usize> {
300 marker: PhantomData<T>,
301 }
302
303 impl<'de, T, const N: usize> Visitor<'de> for ArrayVisitor<T, N>
304 where
305 T: Deserialize<'de>,
306 {
307 type Value = Array<N, T>;
308
309 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
310 formatter.write_str("a fixed size array")
311 }
312
313 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
314 where
315 A: SeqAccess<'de>,
316 {
317 let mut items = vec![];
318
319 for i in 0..N {
320 let item = seq
321 .next_element()?
322 .ok_or_else(|| Error::invalid_length(i, &self))?;
323 items.push(item);
324 }
325
326 let array: [T; N] = items
327 .into_iter()
328 .collect::<Vec<_>>()
329 .try_into()
330 .map_err(|_| "An array of size {N}")
331 .unwrap();
332
333 Ok(Array(array))
334 }
335 }
336
337 deserializer.deserialize_tuple(
338 N,
339 ArrayVisitor {
340 marker: PhantomData,
341 },
342 )
343 }
344}