// arff/dynamic/de.rs

1use serde::de::{self, Deserialize, DeserializeSeed, IntoDeserializer, MapAccess, SeqAccess,
2                Visitor};
3
4use error::{Error, Result};
5
6use super::DataSet;
7use super::FlatIter;
8use super::Value;
9
10pub fn from_dataset<'a, T>(dset: &'a DataSet) -> Result<T>
11where
12    T: Deserialize<'a>,
13{
14    let mut deserializer = Deserializer::from_dataset(dset);
15    T::deserialize(&mut deserializer)
16}
17
18/// Deserialize from a data set
19pub struct Deserializer<'de> {
20    input: FlatIter<'de>,
21    nested_sequence_depth: u8,
22}
23
24impl<'de> Deserializer<'de> {
25    pub fn from_dataset(input: &'de DataSet) -> Self {
26        Deserializer {
27            input: input.flat_iter(),
28            nested_sequence_depth: 0,
29        }
30    }
31
32    fn next(&mut self) -> Result<(&str, Value)> {
33        let n = self.input.next().ok_or(Error::Eof);
34        n
35    }
36}
37
38impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
39    type Error = Error;
40
41    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
42    where
43        V: Visitor<'de>,
44    {
45        unimplemented!()
46    }
47
48    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
49    where
50        V: Visitor<'de>,
51    {
52        visitor.visit_bool(self.next()?.1.as_bool()?)
53    }
54
55    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
56    where
57        V: Visitor<'de>,
58    {
59        visitor.visit_i8(self.next()?.1.as_i8()?)
60    }
61
62    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
63    where
64        V: Visitor<'de>,
65    {
66        visitor.visit_i16(self.next()?.1.as_i16()?)
67    }
68
69    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
70    where
71        V: Visitor<'de>,
72    {
73        visitor.visit_i32(self.next()?.1.as_i32()?)
74    }
75
76    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
77    where
78        V: Visitor<'de>,
79    {
80        visitor.visit_i64(self.next()?.1.as_i64()?)
81    }
82
83    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
84    where
85        V: Visitor<'de>,
86    {
87        visitor.visit_u8(self.next()?.1.as_u8()?)
88    }
89
90    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
91    where
92        V: Visitor<'de>,
93    {
94        visitor.visit_u16(self.next()?.1.as_u16()?)
95    }
96
97    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
98    where
99        V: Visitor<'de>,
100    {
101        visitor.visit_u32(self.next()?.1.as_u32()?)
102    }
103
104    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
105    where
106        V: Visitor<'de>,
107    {
108        visitor.visit_u64(self.next()?.1.as_u64()?)
109    }
110
111    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
112    where
113        V: Visitor<'de>,
114    {
115        visitor.visit_f32(self.next()?.1.as_f64()? as f32)
116    }
117
118    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
119    where
120        V: Visitor<'de>,
121    {
122        visitor.visit_f64(self.next()?.1.as_f64()?)
123    }
124
125    fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value>
126    where
127        V: Visitor<'de>,
128    {
129        unimplemented!()
130    }
131
132    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
133    where
134        V: Visitor<'de>,
135    {
136        visitor.visit_string(self.next()?.1.as_string()?)
137    }
138
139    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
140    where
141        V: Visitor<'de>,
142    {
143        visitor.visit_string(self.next()?.1.as_string()?)
144    }
145
146    fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value>
147    where
148        V: Visitor<'de>,
149    {
150        unimplemented!()
151    }
152
153    fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value>
154    where
155        V: Visitor<'de>,
156    {
157        unimplemented!()
158    }
159
160    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
161    where
162        V: Visitor<'de>,
163    {
164        match self.input.peek() {
165            None => return Err(Error::Eof),
166            Some((_, Value::Missing)) => {
167                self.next()?;
168                visitor.visit_none()
169            }
170            Some(_) => visitor.visit_some(self),
171        }
172    }
173
174    fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value>
175    where
176        V: Visitor<'de>,
177    {
178        unimplemented!()
179    }
180
181    fn deserialize_unit_struct<V>(self, _name: &'static str, _visitor: V) -> Result<V::Value>
182    where
183        V: Visitor<'de>,
184    {
185        unimplemented!()
186    }
187
188    fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
189    where
190        V: Visitor<'de>,
191    {
192        visitor.visit_newtype_struct(self)
193    }
194
195    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
196    where
197        V: Visitor<'de>,
198    {
199        visitor.visit_seq(SequenceAccessor::new(self))
200    }
201
202    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
203    where
204        V: Visitor<'de>,
205    {
206        self.deserialize_seq(visitor)
207    }
208
209    // Tuple structs look just like sequences in JSON.
210    fn deserialize_tuple_struct<V>(
211        self,
212        _name: &'static str,
213        _len: usize,
214        visitor: V,
215    ) -> Result<V::Value>
216    where
217        V: Visitor<'de>,
218    {
219        self.deserialize_seq(visitor)
220    }
221
222    fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value>
223    where
224        V: Visitor<'de>,
225    {
226        unimplemented!()
227    }
228
229    fn deserialize_struct<V>(
230        mut self,
231        _name: &'static str,
232        fields: &'static [&'static str],
233        visitor: V,
234    ) -> Result<V::Value>
235    where
236        V: Visitor<'de>,
237    {
238        visitor.visit_map(StructAcess {
239            de: &mut self,
240            n_fields: fields.len(),
241        })
242    }
243
244    fn deserialize_enum<V>(
245        self,
246        _name: &'static str,
247        _variants: &'static [&'static str],
248        visitor: V,
249    ) -> Result<V::Value>
250    where
251        V: Visitor<'de>,
252    {
253        visitor.visit_enum(self.next()?.1.as_str()?.into_deserializer())
254    }
255
256    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
257    where
258        V: Visitor<'de>,
259    {
260        let (name, _) = self.input.peek().unwrap();
261        visitor.visit_str(name)
262    }
263
264    fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value>
265    where
266        V: Visitor<'de>,
267    {
268        unimplemented!()
269    }
270}
271
272struct SequenceAccessor<'a, 'de: 'a> {
273    de: &'a mut Deserializer<'de>,
274    my_row: usize,
275}
276
277impl<'a, 'de> SequenceAccessor<'a, 'de> {
278    fn new(de: &'a mut Deserializer<'de>) -> Self {
279        de.nested_sequence_depth += 1;
280        SequenceAccessor {
281            my_row: de.input.row(),
282            de,
283        }
284    }
285}
286
287impl<'a, 'de> Drop for SequenceAccessor<'a, 'de> {
288    fn drop(&mut self) {
289        self.de.nested_sequence_depth -= 1;
290    }
291}
292
293impl<'a, 'de> SeqAccess<'de> for SequenceAccessor<'a, 'de> {
294    type Error = Error;
295
296    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
297    where
298        T: DeserializeSeed<'de>,
299    {
300        if self.de.nested_sequence_depth > 1 && self.de.input.row() != self.my_row {
301            return Ok(None);
302        }
303
304        if self.de.input.peek().is_none() {
305            return Ok(None);
306        }
307
308        seed.deserialize(&mut *self.de).map(Some)
309    }
310}
311
312struct StructAcess<'a, 'de: 'a> {
313    de: &'a mut Deserializer<'de>,
314    n_fields: usize,
315}
316
317impl<'a, 'de> MapAccess<'de> for StructAcess<'a, 'de> {
318    type Error = Error;
319
320    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
321    where
322        K: DeserializeSeed<'de>,
323    {
324        if self.n_fields == 0 || self.de.input.peek().is_none() {
325            return Ok(None);
326        }
327        self.n_fields -= 1;
328        seed.deserialize(&mut *self.de).map(Some)
329    }
330
331    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
332    where
333        V: DeserializeSeed<'de>,
334    {
335        seed.deserialize(&mut *self.de)
336    }
337}
338
339#[cfg(test)]
340use super::column::{Column, ColumnData};
341
#[test]
fn simple() {
    let int_column = Column::new(
        "int",
        ColumnData::U8 {
            values: vec![Some(1), Some(4)],
        },
    );
    let float_column = Column::new(
        "float",
        ColumnData::F64 {
            values: vec![Some(2.0), None],
        },
    );
    let text_column = Column::new(
        "text",
        ColumnData::String {
            values: vec![Some(String::from("three")), Some(String::from("7"))],
        },
    );
    let color_column = Column::new(
        "color",
        ColumnData::Nominal {
            values: vec![Some(2), Some(0)],
            categories: ["red", "green", "blue"].iter().map(|s| s.to_string()).collect(),
        },
    );
    let dset = DataSet::new(
        "Test data",
        vec![int_column, float_column, text_column, color_column],
    );

    // One tuple per row: the missing float becomes `None` and the nominal
    // column yields its category name.
    let rows: Vec<(u8, Option<f64>, String, String)> = from_dataset(&dset).unwrap();

    let expected: Vec<(u8, Option<f64>, String, String)> = vec![
        (1, Some(2.0), String::from("three"), String::from("blue")),
        (4, None, String::from("7"), String::from("red")),
    ];
    assert_eq!(rows, expected);
}
385
#[test]
fn named() {
    // Variant names line up with the nominal categories below.
    #[derive(Debug, Deserialize, PartialEq)]
    enum Color {
        Red,
        Green,
        Blue,
    }

    // Field names line up with the column names below.
    #[derive(Debug, Deserialize, PartialEq)]
    struct Row {
        int: i16,
        float: Option<f32>,
        text: String,
        color: Color,
    }

    let columns = vec![
        Column::new(
            "int",
            ColumnData::U8 {
                values: vec![Some(1), Some(4)],
            },
        ),
        Column::new(
            "float",
            ColumnData::F64 {
                values: vec![Some(2.0), None],
            },
        ),
        Column::new(
            "text",
            ColumnData::String {
                values: vec![Some(String::from("three")), Some(String::from("7"))],
            },
        ),
        Column::new(
            "color",
            ColumnData::Nominal {
                values: vec![Some(2), Some(0)],
                categories: vec![
                    String::from("Red"),
                    String::from("Green"),
                    String::from("Blue"),
                ],
            },
        ),
    ];
    let dset = DataSet::new("Test data", columns);

    let rows: Vec<Row> = from_dataset(&dset).unwrap();

    let first = Row {
        int: 1,
        float: Some(2.0),
        text: String::from("three"),
        color: Color::Blue,
    };
    let second = Row {
        int: 4,
        float: None,
        text: String::from("7"),
        color: Color::Red,
    };
    assert_eq!(rows, vec![first, second]);
}
454
#[test]
fn unknown_length() {
    // The inner `Vec`s have no fixed length, so each must stop at its row's end.
    let ints = Column::new(
        "int",
        ColumnData::U8 {
            values: vec![Some(1), Some(4)],
        },
    );
    let floats = Column::new(
        "float",
        ColumnData::F64 {
            values: vec![Some(2.0), Some(5.0)],
        },
    );
    let dset = DataSet::new("Test data", vec![ints, floats]);

    let rows: Vec<Vec<f64>> = from_dataset(&dset).unwrap();

    let expected: Vec<Vec<f64>> = vec![vec![1.0, 2.0], vec![4.0, 5.0]];
    assert_eq!(rows, expected);
}