faer_core/
serde_impl.rs

1//! Serde implementations for Mat
2
3use core::marker::PhantomData;
4
5use faer_entity::Entity;
6use serde::{
7    de::{DeserializeSeed, SeqAccess, Visitor},
8    ser::{SerializeSeq, SerializeStruct},
9    Deserialize, Serialize, Serializer,
10};
11
12use crate::Mat;
13
14impl<E: Entity> Serialize for Mat<E>
15where
16    E: Serialize,
17{
18    fn serialize<S>(&self, s: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
19    where
20        S: Serializer,
21    {
22        struct MatSequenceSerializer<'a, E: Entity>(&'a Mat<E>);
23
24        impl<'a, E: Entity> Serialize for MatSequenceSerializer<'a, E>
25        where
26            E: Serialize,
27        {
28            fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
29            where
30                S: Serializer,
31            {
32                let mut seq = s.serialize_seq(Some(self.0.nrows() * self.0.ncols()))?;
33                for i in 0..self.0.nrows() {
34                    for j in 0..self.0.ncols() {
35                        seq.serialize_element(&self.0.read(i, j))?;
36                    }
37                }
38                seq.end()
39            }
40        }
41
42        let mut structure = s.serialize_struct("Mat", 3)?;
43        structure.serialize_field("nrows", &self.nrows())?;
44        structure.serialize_field("ncols", &self.ncols())?;
45        structure.serialize_field("data", &MatSequenceSerializer(self))?;
46        structure.end()
47    }
48}
49
50impl<'a, E: Entity> Deserialize<'a> for Mat<E>
51where
52    E: Deserialize<'a>,
53{
54    fn deserialize<D>(d: D) -> Result<Self, <D as serde::Deserializer<'a>>::Error>
55    where
56        D: serde::Deserializer<'a>,
57    {
58        #[derive(Deserialize)]
59        #[serde(field_identifier, rename_all = "lowercase")]
60        enum Field {
61            Nrows,
62            Ncols,
63            Data,
64        }
65        const FIELDS: &'static [&'static str] = &["nrows", "ncols", "data"];
66        struct MatVisitor<E: Entity>(PhantomData<E>);
67        impl<'a, E: Entity + Deserialize<'a>> Visitor<'a> for MatVisitor<E> {
68            type Value = Mat<E>;
69
70            fn expecting(&self, formatter: &mut alloc::fmt::Formatter) -> alloc::fmt::Result {
71                formatter.write_str("a faer matrix")
72            }
73
74            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
75            where
76                A: serde::de::MapAccess<'a>,
77            {
78                enum MatrixOrVec<E: Entity> {
79                    Matrix(Mat<E>),
80                    Vec(Vec<E>),
81                }
82                impl<E: Entity> MatrixOrVec<E> {
83                    fn into_mat(self, nrows: usize, ncols: usize) -> Mat<E> {
84                        match self {
85                            MatrixOrVec::Matrix(m) => m,
86                            MatrixOrVec::Vec(v) => {
87                                Mat::from_fn(nrows, ncols, |i, j| v[i * ncols + j])
88                            }
89                        }
90                    }
91                }
92                struct MatrixOrVecDeserializer<'a, E: Entity + Deserialize<'a>> {
93                    marker: PhantomData<&'a E>,
94                    nrows: Option<usize>,
95                    ncols: Option<usize>,
96                }
97                impl<'a, E: Entity + Deserialize<'a>> MatrixOrVecDeserializer<'a, E> {
98                    fn new(nrows: Option<usize>, ncols: Option<usize>) -> Self {
99                        Self {
100                            marker: PhantomData,
101                            nrows,
102                            ncols,
103                        }
104                    }
105                }
106                impl<'a, E: Entity> DeserializeSeed<'a> for MatrixOrVecDeserializer<'a, E>
107                where
108                    E: Deserialize<'a>,
109                {
110                    type Value = MatrixOrVec<E>;
111
112                    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
113                    where
114                        D: serde::Deserializer<'a>,
115                    {
116                        deserializer.deserialize_seq(self)
117                    }
118                }
119                impl<'a, E: Entity> Visitor<'a> for MatrixOrVecDeserializer<'a, E>
120                where
121                    E: Deserialize<'a>,
122                {
123                    type Value = MatrixOrVec<E>;
124
125                    fn expecting(
126                        &self,
127                        formatter: &mut alloc::fmt::Formatter,
128                    ) -> alloc::fmt::Result {
129                        formatter.write_str("a sequence")
130                    }
131
132                    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
133                    where
134                        A: SeqAccess<'a>,
135                    {
136                        match (self.ncols, self.nrows) {
137                            (Some(ncols), Some(nrows)) => {
138                                let mut data = Mat::<E>::with_capacity(nrows, ncols);
139                                unsafe {
140                                    data.set_dims(nrows, ncols);
141                                }
142                                let expected_length = nrows * ncols;
143                                for i in 0..expected_length {
144                                    let el = seq.next_element::<E>()?.ok_or_else(|| {
145                                        serde::de::Error::invalid_length(
146                                            i,
147                                            &format!("{} elements", expected_length).as_str(),
148                                        )
149                                    })?;
150                                    data.write(i / ncols, i % ncols, el);
151                                }
152                                let mut additional = 0usize;
153                                while let Some(_) = seq.next_element::<E>()? {
154                                    additional += 1;
155                                }
156                                if additional > 0 {
157                                    return Err(serde::de::Error::invalid_length(
158                                        additional + expected_length,
159                                        &format!("{} elements", expected_length).as_str(),
160                                    ));
161                                }
162                                Ok(MatrixOrVec::Matrix(data))
163                            }
164                            _ => {
165                                let mut data = Vec::new();
166                                while let Some(el) = seq.next_element::<E>()? {
167                                    data.push(el);
168                                }
169                                Ok(MatrixOrVec::Vec(data))
170                            }
171                        }
172                    }
173                }
174                let mut nrows = None;
175                let mut ncols = None;
176                let mut data: Option<MatrixOrVec<E>> = None;
177                while let Some(key) = map.next_key()? {
178                    match key {
179                        Field::Nrows => {
180                            if nrows.is_some() {
181                                return Err(serde::de::Error::duplicate_field("nrows"));
182                            }
183                            let value = map.next_value()?;
184                            nrows = Some(value);
185                        }
186                        Field::Ncols => {
187                            if ncols.is_some() {
188                                return Err(serde::de::Error::duplicate_field("ncols"));
189                            }
190                            let value = map.next_value()?;
191                            ncols = Some(value);
192                        }
193                        Field::Data => {
194                            if data.is_some() {
195                                return Err(serde::de::Error::duplicate_field("data"));
196                            }
197                            data = Some(map.next_value_seed(MatrixOrVecDeserializer::<E>::new(
198                                nrows.clone(),
199                                ncols.clone(),
200                            ))?);
201                        }
202                    }
203                }
204                let nrows = nrows.ok_or_else(|| serde::de::Error::missing_field("nrows"))?;
205                let ncols = ncols.ok_or_else(|| serde::de::Error::missing_field("ncols"))?;
206                let data = data
207                    .ok_or_else(|| serde::de::Error::missing_field("data"))?
208                    .into_mat(nrows, ncols);
209                Ok(data)
210            }
211        }
212        d.deserialize_struct("Mat", FIELDS, MatVisitor(PhantomData))
213    }
214}
215
/// Round-trip and error tests for the `Mat` serde implementations. Matrix data is
/// always encoded row-major, so entry `(i, j)` sits at position `i * ncols + j`.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_test::{assert_de_tokens_error, assert_tokens, Token};

    #[test]
    fn matrix_serialization_normal() {
        let value = Mat::from_fn(3, 4, |i, j| (i + (j * 10)) as f64);
        assert_tokens(
            &value,
            &[
                Token::Struct {
                    name: "Mat",
                    len: 3,
                },
                Token::Str("nrows"),
                Token::U64(3),
                Token::Str("ncols"),
                Token::U64(4),
                Token::Str("data"),
                Token::Seq { len: Some(12) },
                Token::F64(0.0),
                Token::F64(10.0),
                Token::F64(20.0),
                Token::F64(30.0),
                Token::F64(1.0),
                Token::F64(11.0),
                Token::F64(21.0),
                Token::F64(31.0),
                Token::F64(2.0),
                Token::F64(12.0),
                Token::F64(22.0),
                Token::F64(32.0),
                Token::SeqEnd,
                Token::StructEnd,
            ],
        )
    }

    /// A 12x1 column: many rows, one column.
    #[test]
    fn matrix_serialization_tall() {
        let value = Mat::from_fn(12, 1, |i, j| (i + (j * 10)) as f64);
        assert_tokens(
            &value,
            &[
                Token::Struct {
                    name: "Mat",
                    len: 3,
                },
                Token::Str("nrows"),
                Token::U64(12),
                Token::Str("ncols"),
                Token::U64(1),
                Token::Str("data"),
                Token::Seq { len: Some(12) },
                Token::F64(0.0),
                Token::F64(1.0),
                Token::F64(2.0),
                Token::F64(3.0),
                Token::F64(4.0),
                Token::F64(5.0),
                Token::F64(6.0),
                Token::F64(7.0),
                Token::F64(8.0),
                Token::F64(9.0),
                Token::F64(10.0),
                Token::F64(11.0),
                Token::SeqEnd,
                Token::StructEnd,
            ],
        )
    }

    /// A 1x12 row: one row, many columns.
    #[test]
    fn matrix_serialization_wide() {
        let value = Mat::from_fn(1, 12, |i, j| (i + (j * 10)) as f64);
        assert_tokens(
            &value,
            &[
                Token::Struct {
                    name: "Mat",
                    len: 3,
                },
                Token::Str("nrows"),
                Token::U64(1),
                Token::Str("ncols"),
                Token::U64(12),
                Token::Str("data"),
                Token::Seq { len: Some(12) },
                Token::F64(0.0),
                Token::F64(10.0),
                Token::F64(20.0),
                Token::F64(30.0),
                Token::F64(40.0),
                Token::F64(50.0),
                Token::F64(60.0),
                Token::F64(70.0),
                Token::F64(80.0),
                Token::F64(90.0),
                Token::F64(100.0),
                Token::F64(110.0),
                Token::SeqEnd,
                Token::StructEnd,
            ],
        )
    }

    #[test]
    fn matrix_serialization_zero() {
        let value = Mat::from_fn(0, 0, |i, j| (i + (j * 10)) as f64);
        assert_tokens(
            &value,
            &[
                Token::Struct {
                    name: "Mat",
                    len: 3,
                },
                Token::Str("nrows"),
                Token::U64(0),
                Token::Str("ncols"),
                Token::U64(0),
                Token::Str("data"),
                Token::Seq { len: Some(0) },
                Token::SeqEnd,
                Token::StructEnd,
            ],
        )
    }

    /// Fewer elements than `nrows * ncols` must be rejected with the count seen so far.
    #[test]
    fn matrix_deserialization_errors_too_small() {
        assert_de_tokens_error::<Mat<f64>>(
            &[
                Token::Struct {
                    name: "Mat",
                    len: 3,
                },
                Token::Str("nrows"),
                Token::U64(3),
                Token::Str("ncols"),
                Token::U64(4),
                Token::Str("data"),
                Token::Seq { len: Some(12) },
                Token::F64(0.0),
                Token::F64(10.0),
                Token::F64(20.0),
                Token::F64(30.0),
                Token::F64(1.0),
                Token::F64(11.0),
                Token::F64(21.0),
                Token::F64(31.0),
                Token::F64(2.0),
                Token::SeqEnd,
            ],
            "invalid length 9, expected 12 elements",
        )
    }

    /// More elements than `nrows * ncols` must be rejected with the total length.
    #[test]
    fn matrix_deserialization_errors_too_large() {
        assert_de_tokens_error::<Mat<f64>>(
            &[
                Token::Struct {
                    name: "Mat",
                    len: 3,
                },
                Token::Str("nrows"),
                Token::U64(3),
                Token::Str("ncols"),
                Token::U64(4),
                Token::Str("data"),
                Token::Seq { len: Some(12) },
                Token::F64(0.0),
                Token::F64(10.0),
                Token::F64(20.0),
                Token::F64(30.0),
                Token::F64(1.0),
                Token::F64(11.0),
                Token::F64(21.0),
                Token::F64(31.0),
                Token::F64(2.0),
                Token::F64(12.0),
                Token::F64(22.0),
                Token::F64(32.0),
                Token::F64(32.0),
                Token::F64(32.0),
                Token::SeqEnd,
            ],
            "invalid length 14, expected 12 elements",
        )
    }
}