1use core::marker::PhantomData;
4
5use faer_entity::Entity;
6use serde::{
7 de::{DeserializeSeed, SeqAccess, Visitor},
8 ser::{SerializeSeq, SerializeStruct},
9 Deserialize, Serialize, Serializer,
10};
11
12use crate::Mat;
13
14impl<E: Entity> Serialize for Mat<E>
15where
16 E: Serialize,
17{
18 fn serialize<S>(&self, s: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
19 where
20 S: Serializer,
21 {
22 struct MatSequenceSerializer<'a, E: Entity>(&'a Mat<E>);
23
24 impl<'a, E: Entity> Serialize for MatSequenceSerializer<'a, E>
25 where
26 E: Serialize,
27 {
28 fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
29 where
30 S: Serializer,
31 {
32 let mut seq = s.serialize_seq(Some(self.0.nrows() * self.0.ncols()))?;
33 for i in 0..self.0.nrows() {
34 for j in 0..self.0.ncols() {
35 seq.serialize_element(&self.0.read(i, j))?;
36 }
37 }
38 seq.end()
39 }
40 }
41
42 let mut structure = s.serialize_struct("Mat", 3)?;
43 structure.serialize_field("nrows", &self.nrows())?;
44 structure.serialize_field("ncols", &self.ncols())?;
45 structure.serialize_field("data", &MatSequenceSerializer(self))?;
46 structure.end()
47 }
48}
49
50impl<'a, E: Entity> Deserialize<'a> for Mat<E>
51where
52 E: Deserialize<'a>,
53{
54 fn deserialize<D>(d: D) -> Result<Self, <D as serde::Deserializer<'a>>::Error>
55 where
56 D: serde::Deserializer<'a>,
57 {
58 #[derive(Deserialize)]
59 #[serde(field_identifier, rename_all = "lowercase")]
60 enum Field {
61 Nrows,
62 Ncols,
63 Data,
64 }
65 const FIELDS: &'static [&'static str] = &["nrows", "ncols", "data"];
66 struct MatVisitor<E: Entity>(PhantomData<E>);
67 impl<'a, E: Entity + Deserialize<'a>> Visitor<'a> for MatVisitor<E> {
68 type Value = Mat<E>;
69
70 fn expecting(&self, formatter: &mut alloc::fmt::Formatter) -> alloc::fmt::Result {
71 formatter.write_str("a faer matrix")
72 }
73
74 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
75 where
76 A: serde::de::MapAccess<'a>,
77 {
78 enum MatrixOrVec<E: Entity> {
79 Matrix(Mat<E>),
80 Vec(Vec<E>),
81 }
82 impl<E: Entity> MatrixOrVec<E> {
83 fn into_mat(self, nrows: usize, ncols: usize) -> Mat<E> {
84 match self {
85 MatrixOrVec::Matrix(m) => m,
86 MatrixOrVec::Vec(v) => {
87 Mat::from_fn(nrows, ncols, |i, j| v[i * ncols + j])
88 }
89 }
90 }
91 }
92 struct MatrixOrVecDeserializer<'a, E: Entity + Deserialize<'a>> {
93 marker: PhantomData<&'a E>,
94 nrows: Option<usize>,
95 ncols: Option<usize>,
96 }
97 impl<'a, E: Entity + Deserialize<'a>> MatrixOrVecDeserializer<'a, E> {
98 fn new(nrows: Option<usize>, ncols: Option<usize>) -> Self {
99 Self {
100 marker: PhantomData,
101 nrows,
102 ncols,
103 }
104 }
105 }
106 impl<'a, E: Entity> DeserializeSeed<'a> for MatrixOrVecDeserializer<'a, E>
107 where
108 E: Deserialize<'a>,
109 {
110 type Value = MatrixOrVec<E>;
111
112 fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
113 where
114 D: serde::Deserializer<'a>,
115 {
116 deserializer.deserialize_seq(self)
117 }
118 }
119 impl<'a, E: Entity> Visitor<'a> for MatrixOrVecDeserializer<'a, E>
120 where
121 E: Deserialize<'a>,
122 {
123 type Value = MatrixOrVec<E>;
124
125 fn expecting(
126 &self,
127 formatter: &mut alloc::fmt::Formatter,
128 ) -> alloc::fmt::Result {
129 formatter.write_str("a sequence")
130 }
131
132 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
133 where
134 A: SeqAccess<'a>,
135 {
136 match (self.ncols, self.nrows) {
137 (Some(ncols), Some(nrows)) => {
138 let mut data = Mat::<E>::with_capacity(nrows, ncols);
139 unsafe {
140 data.set_dims(nrows, ncols);
141 }
142 let expected_length = nrows * ncols;
143 for i in 0..expected_length {
144 let el = seq.next_element::<E>()?.ok_or_else(|| {
145 serde::de::Error::invalid_length(
146 i,
147 &format!("{} elements", expected_length).as_str(),
148 )
149 })?;
150 data.write(i / ncols, i % ncols, el);
151 }
152 let mut additional = 0usize;
153 while let Some(_) = seq.next_element::<E>()? {
154 additional += 1;
155 }
156 if additional > 0 {
157 return Err(serde::de::Error::invalid_length(
158 additional + expected_length,
159 &format!("{} elements", expected_length).as_str(),
160 ));
161 }
162 Ok(MatrixOrVec::Matrix(data))
163 }
164 _ => {
165 let mut data = Vec::new();
166 while let Some(el) = seq.next_element::<E>()? {
167 data.push(el);
168 }
169 Ok(MatrixOrVec::Vec(data))
170 }
171 }
172 }
173 }
174 let mut nrows = None;
175 let mut ncols = None;
176 let mut data: Option<MatrixOrVec<E>> = None;
177 while let Some(key) = map.next_key()? {
178 match key {
179 Field::Nrows => {
180 if nrows.is_some() {
181 return Err(serde::de::Error::duplicate_field("nrows"));
182 }
183 let value = map.next_value()?;
184 nrows = Some(value);
185 }
186 Field::Ncols => {
187 if ncols.is_some() {
188 return Err(serde::de::Error::duplicate_field("ncols"));
189 }
190 let value = map.next_value()?;
191 ncols = Some(value);
192 }
193 Field::Data => {
194 if data.is_some() {
195 return Err(serde::de::Error::duplicate_field("data"));
196 }
197 data = Some(map.next_value_seed(MatrixOrVecDeserializer::<E>::new(
198 nrows.clone(),
199 ncols.clone(),
200 ))?);
201 }
202 }
203 }
204 let nrows = nrows.ok_or_else(|| serde::de::Error::missing_field("nrows"))?;
205 let ncols = ncols.ok_or_else(|| serde::de::Error::missing_field("ncols"))?;
206 let data = data
207 .ok_or_else(|| serde::de::Error::missing_field("data"))?
208 .into_mat(nrows, ncols);
209 Ok(data)
210 }
211 }
212 d.deserialize_struct("Mat", FIELDS, MatVisitor(PhantomData))
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use serde_test::{assert_de_tokens, assert_de_tokens_error, assert_tokens, Token};

    /// Round-trips a 3x4 matrix; `data` must be emitted in row-major order.
    #[test]
    fn matrix_serialization_normal() {
        let value = Mat::from_fn(3, 4, |i, j| (i + (j * 10)) as f64);
        assert_tokens(
            &value,
            &[
                Token::Struct {
                    name: "Mat",
                    len: 3,
                },
                Token::Str("nrows"),
                Token::U64(3),
                Token::Str("ncols"),
                Token::U64(4),
                Token::Str("data"),
                Token::Seq { len: Some(12) },
                Token::F64(0.0),
                Token::F64(10.0),
                Token::F64(20.0),
                Token::F64(30.0),
                Token::F64(1.0),
                Token::F64(11.0),
                Token::F64(21.0),
                Token::F64(31.0),
                Token::F64(2.0),
                Token::F64(12.0),
                Token::F64(22.0),
                Token::F64(32.0),
                Token::SeqEnd,
                Token::StructEnd,
            ],
        )
    }

    /// Round-trips a tall (12 rows x 1 column) matrix.
    #[test]
    fn matrix_serialization_tall() {
        let value = Mat::from_fn(12, 1, |i, j| (i + (j * 10)) as f64);
        assert_tokens(
            &value,
            &[
                Token::Struct {
                    name: "Mat",
                    len: 3,
                },
                Token::Str("nrows"),
                Token::U64(12),
                Token::Str("ncols"),
                Token::U64(1),
                Token::Str("data"),
                Token::Seq { len: Some(12) },
                Token::F64(0.0),
                Token::F64(1.0),
                Token::F64(2.0),
                Token::F64(3.0),
                Token::F64(4.0),
                Token::F64(5.0),
                Token::F64(6.0),
                Token::F64(7.0),
                Token::F64(8.0),
                Token::F64(9.0),
                Token::F64(10.0),
                Token::F64(11.0),
                Token::SeqEnd,
                Token::StructEnd,
            ],
        )
    }

    /// Round-trips a wide (1 row x 12 columns) matrix.
    #[test]
    fn matrix_serialization_wide() {
        let value = Mat::from_fn(1, 12, |i, j| (i + (j * 10)) as f64);
        assert_tokens(
            &value,
            &[
                Token::Struct {
                    name: "Mat",
                    len: 3,
                },
                Token::Str("nrows"),
                Token::U64(1),
                Token::Str("ncols"),
                Token::U64(12),
                Token::Str("data"),
                Token::Seq { len: Some(12) },
                Token::F64(0.0),
                Token::F64(10.0),
                Token::F64(20.0),
                Token::F64(30.0),
                Token::F64(40.0),
                Token::F64(50.0),
                Token::F64(60.0),
                Token::F64(70.0),
                Token::F64(80.0),
                Token::F64(90.0),
                Token::F64(100.0),
                Token::F64(110.0),
                Token::SeqEnd,
                Token::StructEnd,
            ],
        )
    }

    /// Round-trips an empty 0x0 matrix.
    #[test]
    fn matrix_serialization_zero() {
        let value = Mat::from_fn(0, 0, |i, j| (i + (j * 10)) as f64);
        assert_tokens(
            &value,
            &[
                Token::Struct {
                    name: "Mat",
                    len: 3,
                },
                Token::Str("nrows"),
                Token::U64(0),
                Token::Str("ncols"),
                Token::U64(0),
                Token::Str("data"),
                Token::Seq { len: Some(0) },
                Token::SeqEnd,
                Token::StructEnd,
            ],
        )
    }

    /// Fewer `data` elements than `nrows * ncols` must be rejected.
    #[test]
    fn matrix_serialization_errors_too_small() {
        assert_de_tokens_error::<Mat<f64>>(
            &[
                Token::Struct {
                    name: "Mat",
                    len: 3,
                },
                Token::Str("nrows"),
                Token::U64(3),
                Token::Str("ncols"),
                Token::U64(4),
                Token::Str("data"),
                Token::Seq { len: Some(12) },
                Token::F64(0.0),
                Token::F64(10.0),
                Token::F64(20.0),
                Token::F64(30.0),
                Token::F64(1.0),
                Token::F64(11.0),
                Token::F64(21.0),
                Token::F64(31.0),
                Token::F64(2.0),
                Token::SeqEnd,
            ],
            "invalid length 9, expected 12 elements",
        )
    }

    /// More `data` elements than `nrows * ncols` must be rejected, reporting the full count.
    #[test]
    fn matrix_serialization_errors_too_large() {
        assert_de_tokens_error::<Mat<f64>>(
            &[
                Token::Struct {
                    name: "Mat",
                    len: 3,
                },
                Token::Str("nrows"),
                Token::U64(3),
                Token::Str("ncols"),
                Token::U64(4),
                Token::Str("data"),
                Token::Seq { len: Some(12) },
                Token::F64(0.0),
                Token::F64(10.0),
                Token::F64(20.0),
                Token::F64(30.0),
                Token::F64(1.0),
                Token::F64(11.0),
                Token::F64(21.0),
                Token::F64(31.0),
                Token::F64(2.0),
                Token::F64(12.0),
                Token::F64(22.0),
                Token::F64(32.0),
                Token::F64(32.0),
                Token::F64(32.0),
                Token::SeqEnd,
            ],
            "invalid length 14, expected 12 elements",
        )
    }

    /// Map fields may arrive in any order: `data` before the dimensions must still
    /// decode into the same row-major matrix.
    #[test]
    fn matrix_deserialization_data_before_dimensions() {
        let value = Mat::from_fn(2, 2, |i, j| (i + (j * 10)) as f64);
        assert_de_tokens(
            &value,
            &[
                Token::Struct {
                    name: "Mat",
                    len: 3,
                },
                Token::Str("data"),
                Token::Seq { len: Some(4) },
                Token::F64(0.0),
                Token::F64(10.0),
                Token::F64(1.0),
                Token::F64(11.0),
                Token::SeqEnd,
                Token::Str("nrows"),
                Token::U64(2),
                Token::Str("ncols"),
                Token::U64(2),
                Token::StructEnd,
            ],
        )
    }
}