serde_aco/
de.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use serde::Deserialize;
18use serde::de::{self, DeserializeSeed, EnumAccess, MapAccess, SeqAccess, VariantAccess, Visitor};
19
20use crate::error::{Error, Result};
21
/// Deserializer for comma-separated `key=value` argument strings.
///
/// Text that cannot be written inline can live in a side table of objects
/// and be referenced by an id that starts with `id_`.
#[derive(Debug)]
pub struct Deserializer<'src, 'obj> {
    /// The part of the argument string that has not been consumed yet.
    input: &'src str,
    /// Optional table from object ids (`id_...`) to their text.
    objects: Option<&'obj HashMap<&'src str, &'src str>>,
    /// `true` while positioned at the outermost value, where strings and
    /// compound values take the remaining input as-is.
    top_level: bool,
    /// Most recently read map key; reported when a key is not recognized.
    key: &'src str,
}
29
30impl<'s> de::Deserializer<'s> for &mut Deserializer<'s, '_> {
31    type Error = Error;
32
33    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
34    where
35        V: Visitor<'s>,
36    {
37        Err(Error::UnknownType)
38    }
39
40    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
41    where
42        V: Visitor<'s>,
43    {
44        let s = self.consume_input();
45        match s {
46            "on" | "true" => visitor.visit_bool(true),
47            "off" | "false" => visitor.visit_bool(false),
48            _ => Err(Error::ExpectedBool),
49        }
50    }
51
52    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
53    where
54        V: Visitor<'s>,
55    {
56        visitor.visit_i8(self.parse_signed()?)
57    }
58
59    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
60    where
61        V: Visitor<'s>,
62    {
63        visitor.visit_i16(self.parse_signed()?)
64    }
65
66    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
67    where
68        V: Visitor<'s>,
69    {
70        visitor.visit_i32(self.parse_signed()?)
71    }
72
73    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
74    where
75        V: Visitor<'s>,
76    {
77        visitor.visit_i64(self.parse_signed()?)
78    }
79
80    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
81    where
82        V: Visitor<'s>,
83    {
84        visitor.visit_u8(self.parse_unsigned()?)
85    }
86
87    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
88    where
89        V: Visitor<'s>,
90    {
91        visitor.visit_u16(self.parse_unsigned()?)
92    }
93
94    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
95    where
96        V: Visitor<'s>,
97    {
98        visitor.visit_u32(self.parse_unsigned()?)
99    }
100
101    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
102    where
103        V: Visitor<'s>,
104    {
105        visitor.visit_u64(self.parse_unsigned()?)
106    }
107
108    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
109    where
110        V: Visitor<'s>,
111    {
112        let s = self.consume_input();
113        visitor.visit_f32(s.parse().map_err(|_| Error::ExpectedFloat)?)
114    }
115
116    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
117    where
118        V: Visitor<'s>,
119    {
120        let s = self.consume_input();
121        visitor.visit_f64(s.parse().map_err(|_| Error::ExpectedFloat)?)
122    }
123
124    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
125    where
126        V: Visitor<'s>,
127    {
128        self.deserialize_str(visitor)
129    }
130
131    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
132    where
133        V: Visitor<'s>,
134    {
135        if self.top_level {
136            visitor.visit_borrowed_str(self.consume_all())
137        } else {
138            let id = self.consume_input();
139            visitor.visit_borrowed_str(self.deref_id(id)?)
140        }
141    }
142
143    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
144    where
145        V: Visitor<'s>,
146    {
147        self.deserialize_str(visitor)
148    }
149
150    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
151    where
152        V: Visitor<'s>,
153    {
154        self.deserialize_seq(visitor)
155    }
156
157    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
158    where
159        V: Visitor<'s>,
160    {
161        self.deserialize_bytes(visitor)
162    }
163
164    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
165    where
166        V: Visitor<'s>,
167    {
168        let id = self.consume_input();
169        let s = self.deref_id(id)?;
170        if id.starts_with("id_") && s.is_empty() {
171            visitor.visit_none()
172        } else {
173            let mut sub_de = Deserializer { input: s, ..*self };
174            visitor.visit_some(&mut sub_de)
175        }
176    }
177
178    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
179    where
180        V: Visitor<'s>,
181    {
182        let s = self.consume_input();
183        if s.is_empty() {
184            visitor.visit_unit()
185        } else {
186            Err(Error::ExpectedUnit)
187        }
188    }
189
190    fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
191    where
192        V: Visitor<'s>,
193    {
194        self.deserialize_unit(visitor)
195    }
196
197    fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
198    where
199        V: Visitor<'s>,
200    {
201        visitor.visit_newtype_struct(self)
202    }
203
204    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
205    where
206        V: Visitor<'s>,
207    {
208        self.deserialize_nested(|de| visitor.visit_seq(CommaSeparated::new(de)))
209    }
210
211    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
212    where
213        V: Visitor<'s>,
214    {
215        self.deserialize_seq(visitor)
216    }
217
218    fn deserialize_tuple_struct<V>(
219        self,
220        _name: &'static str,
221        _len: usize,
222        visitor: V,
223    ) -> Result<V::Value>
224    where
225        V: Visitor<'s>,
226    {
227        self.deserialize_seq(visitor)
228    }
229
230    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
231    where
232        V: Visitor<'s>,
233    {
234        self.deserialize_nested(|de| visitor.visit_map(CommaSeparated::new(de)))
235    }
236
237    fn deserialize_struct<V>(
238        self,
239        _name: &'static str,
240        _fields: &'static [&'static str],
241        visitor: V,
242    ) -> Result<V::Value>
243    where
244        V: Visitor<'s>,
245    {
246        self.deserialize_map(visitor)
247    }
248
249    fn deserialize_enum<V>(
250        self,
251        _name: &'static str,
252        _variants: &'static [&'static str],
253        visitor: V,
254    ) -> Result<V::Value>
255    where
256        V: Visitor<'s>,
257    {
258        self.deserialize_nested(|de| visitor.visit_enum(Enum::new(de)))
259    }
260
261    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
262    where
263        V: Visitor<'s>,
264    {
265        visitor.visit_borrowed_str(self.consume_input())
266    }
267
268    fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value>
269    where
270        V: Visitor<'s>,
271    {
272        Err(Error::Ignored(self.key.to_owned()))
273    }
274}
275
276impl<'s, 'o> Deserializer<'s, 'o> {
277    pub fn from_args(input: &'s str, objects: &'o HashMap<&'s str, &'s str>) -> Self {
278        Deserializer {
279            input,
280            objects: Some(objects),
281            top_level: true,
282            key: "",
283        }
284    }
285
286    pub fn from_arg(input: &'s str) -> Self {
287        Deserializer {
288            input,
289            objects: None,
290            top_level: true,
291            key: "",
292        }
293    }
294
295    fn end(&self) -> Result<()> {
296        if self.input.is_empty() {
297            Ok(())
298        } else {
299            Err(Error::Trailing(self.input.to_owned()))
300        }
301    }
302
303    fn deserialize_nested<F, V>(&mut self, f: F) -> Result<V>
304    where
305        F: FnOnce(&mut Self) -> Result<V>,
306    {
307        let mut sub_de;
308        let de = if !self.top_level {
309            let id = self.consume_input();
310            let sub_input = self.deref_id(id)?;
311            sub_de = Deserializer {
312                input: sub_input,
313                ..*self
314            };
315            &mut sub_de
316        } else {
317            self.top_level = false;
318            self
319        };
320        let val = f(de)?;
321        de.end()?;
322        Ok(val)
323    }
324
325    fn consume_input_until(&mut self, end: char) -> Option<&'s str> {
326        let len = self.input.find(end)?;
327        let s = &self.input[..len];
328        self.input = &self.input[len + end.len_utf8()..];
329        Some(s)
330    }
331
332    fn consume_all(&mut self) -> &'s str {
333        let s = self.input;
334        self.input = "";
335        s
336    }
337
338    fn consume_input(&mut self) -> &'s str {
339        match self.consume_input_until(',') {
340            Some(s) => s,
341            None => self.consume_all(),
342        }
343    }
344
345    fn deref_id(&self, id: &'s str) -> Result<&'s str> {
346        if id.starts_with("id_") {
347            if let Some(s) = self.objects.and_then(|objects| objects.get(id)) {
348                Ok(s)
349            } else {
350                Err(Error::IdNotFound(id.to_owned()))
351            }
352        } else {
353            Ok(id)
354        }
355    }
356
357    fn parse_unsigned<T>(&mut self) -> Result<T>
358    where
359        T: TryFrom<u64>,
360    {
361        let s = self.consume_input();
362        let (num, shift) = if let Some((num, "")) = s.split_once(['k', 'K']) {
363            (num, 10)
364        } else if let Some((num, "")) = s.split_once(['m', 'M']) {
365            (num, 20)
366        } else if let Some((num, "")) = s.split_once(['g', 'G']) {
367            (num, 30)
368        } else if let Some((num, "")) = s.split_once(['t', 'T']) {
369            (num, 40)
370        } else {
371            (s, 0)
372        };
373        let n = if let Some(num_h) = num.strip_prefix("0x") {
374            u64::from_str_radix(num_h, 16)
375        } else if let Some(num_o) = num.strip_prefix("0o") {
376            u64::from_str_radix(num_o, 8)
377        } else if let Some(num_b) = num.strip_prefix("0b") {
378            u64::from_str_radix(num_b, 2)
379        } else {
380            num.parse::<u64>()
381        }
382        .map_err(|_| Error::ExpectedInteger)?;
383
384        let shifted_n = n.checked_shl(shift).ok_or(Error::Overflow)?;
385
386        T::try_from(shifted_n).map_err(|_| Error::Overflow)
387    }
388
389    fn parse_signed<T>(&mut self) -> Result<T>
390    where
391        T: TryFrom<i64>,
392    {
393        let i = if self.input.starts_with('-') {
394            let s = self.consume_input();
395            s.parse().map_err(|_| Error::ExpectedInteger)
396        } else {
397            let n = self.parse_unsigned::<u64>()?;
398            i64::try_from(n).map_err(|_| Error::Overflow)
399        }?;
400        T::try_from(i).map_err(|_| Error::Overflow)
401    }
402}
403
404pub fn from_args<'s, 'o, T>(s: &'s str, objects: &'o HashMap<&'s str, &'s str>) -> Result<T>
405where
406    T: Deserialize<'s>,
407{
408    let mut deserializer = Deserializer::from_args(s, objects);
409    let value = T::deserialize(&mut deserializer)?;
410    deserializer.end()?;
411    Ok(value)
412}
413
414pub fn from_arg<'s, T>(s: &'s str) -> Result<T>
415where
416    T: Deserialize<'s>,
417{
418    let mut deserializer = Deserializer::from_arg(s);
419    let value = T::deserialize(&mut deserializer)?;
420    deserializer.end()?;
421    Ok(value)
422}
423
424struct CommaSeparated<'a, 's: 'a, 'o: 'a> {
425    de: &'a mut Deserializer<'s, 'o>,
426}
427
428impl<'a, 's, 'o> CommaSeparated<'a, 's, 'o> {
429    fn new(de: &'a mut Deserializer<'s, 'o>) -> Self {
430        CommaSeparated { de }
431    }
432}
433
434impl<'s> SeqAccess<'s> for CommaSeparated<'_, 's, '_> {
435    type Error = Error;
436
437    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
438    where
439        T: DeserializeSeed<'s>,
440    {
441        if self.de.input.is_empty() {
442            return Ok(None);
443        }
444        seed.deserialize(&mut *self.de).map(Some)
445    }
446}
447
448impl<'s> MapAccess<'s> for CommaSeparated<'_, 's, '_> {
449    type Error = Error;
450
451    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
452    where
453        K: DeserializeSeed<'s>,
454    {
455        if self.de.input.is_empty() {
456            return Ok(None);
457        }
458        let Some(key) = self.de.consume_input_until('=') else {
459            return Err(Error::ExpectedMapEq);
460        };
461        if key.contains(',') {
462            return Err(Error::ExpectedMapEq);
463        }
464        self.de.key = key;
465        let mut sub_de = Deserializer {
466            input: key,
467            key: "",
468            ..*self.de
469        };
470        seed.deserialize(&mut sub_de).map(Some)
471    }
472
473    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
474    where
475        V: DeserializeSeed<'s>,
476    {
477        seed.deserialize(&mut *self.de)
478    }
479}
480
481struct Enum<'a, 's: 'a, 'o: 'a> {
482    de: &'a mut Deserializer<'s, 'o>,
483}
484
485impl<'a, 's, 'o> Enum<'a, 's, 'o> {
486    fn new(de: &'a mut Deserializer<'s, 'o>) -> Self {
487        Enum { de }
488    }
489}
490
491impl<'s> EnumAccess<'s> for Enum<'_, 's, '_> {
492    type Error = Error;
493    type Variant = Self;
494
495    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
496    where
497        V: DeserializeSeed<'s>,
498    {
499        let val = seed.deserialize(&mut *self.de)?;
500        Ok((val, self))
501    }
502}
503
504impl<'s> VariantAccess<'s> for Enum<'_, 's, '_> {
505    type Error = Error;
506
507    fn unit_variant(self) -> Result<()> {
508        Ok(())
509    }
510
511    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
512    where
513        T: DeserializeSeed<'s>,
514    {
515        self.de.top_level = true;
516        seed.deserialize(self.de)
517    }
518
519    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
520    where
521        V: Visitor<'s>,
522    {
523        visitor.visit_seq(CommaSeparated::new(self.de))
524    }
525
526    fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
527    where
528        V: Visitor<'s>,
529    {
530        visitor.visit_map(CommaSeparated::new(self.de))
531    }
532}
533
#[cfg(test)]
mod test {
    // Tests for the argument deserializer. `from_arg` parses a string with no
    // object table; `from_args` additionally resolves `id_*` references
    // through the supplied map.
    use std::collections::HashMap;
    use std::marker::PhantomData;

    use assert_matches::assert_matches;
    use serde::Deserialize;
    use serde_bytes::{ByteArray, ByteBuf};

    use crate::{Error, from_arg, from_args};

    // An `id_*` reference to an empty object is `None`; every other value,
    // including an empty literal, is treated as `Some`.
    #[test]
    fn test_option() {
        assert_matches!(from_arg::<Option<u32>>(""), Err(Error::ExpectedInteger));
        assert_eq!(from_arg::<Option<u32>>("12").unwrap(), Some(12));

        assert_eq!(from_arg::<Option<&'static str>>("").unwrap(), Some(""));
        assert_eq!(
            from_args::<Option<&'static str>>("id_1", &HashMap::from([("id_1", "")])).unwrap(),
            None
        );
        assert_eq!(from_arg::<Option<&'static str>>("12").unwrap(), Some("12"));
        assert_matches!(
            from_arg::<Option<&'static str>>("id_1"),
            Err(Error::IdNotFound(id)) if id == "id_1"
        );
        // Only one level of indirection: the referenced text is not itself
        // dereferenced.
        assert_eq!(
            from_args::<Option<&'static str>>("id_1", &HashMap::from([("id_1", "id_2")])).unwrap(),
            Some("id_2")
        );

        let map_none = HashMap::from([("id_none", "")]);
        assert_eq!(from_arg::<Vec<Option<u32>>>("").unwrap(), vec![]);
        assert_eq!(
            from_args::<Vec<Option<u32>>>("id_none,", &map_none).unwrap(),
            vec![None]
        );
        assert_eq!(from_arg::<Vec<Option<u32>>>("1,").unwrap(), vec![Some(1)]);
        assert_eq!(
            from_arg::<Vec<Option<u32>>>("1,2,").unwrap(),
            vec![Some(1), Some(2)]
        );
        assert_eq!(
            from_args::<Vec<Option<u32>>>("1,2,id_none,", &map_none).unwrap(),
            vec![Some(1), Some(2), None]
        );
        assert_eq!(
            from_args::<Vec<Option<u32>>>("id_none,2", &map_none).unwrap(),
            vec![None, Some(2)]
        );
    }

    // Unit and unit-like types accept only an empty value.
    #[test]
    fn test_unit() {
        assert!(from_arg::<()>("").is_ok());
        assert_matches!(from_arg::<()>("unit"), Err(Error::ExpectedUnit));

        assert!(from_arg::<PhantomData<u8>>("").is_ok());
        assert_matches!(from_arg::<PhantomData<u8>>("12"), Err(Error::ExpectedUnit));

        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct Param {
            p: PhantomData<u8>,
        }
        assert_eq!(from_arg::<Param>("p=").unwrap(), Param { p: PhantomData });
        assert_matches!(from_arg::<Param>("p=1,"), Err(Error::ExpectedUnit));
    }

    // Integer range checks, size suffixes, hex literals, and floats.
    #[test]
    fn test_numbers() {
        assert_eq!(from_arg::<i8>("0").unwrap(), 0);
        assert_eq!(from_arg::<i8>("1").unwrap(), 1);
        assert_eq!(from_arg::<i8>("127").unwrap(), 127);
        assert_matches!(from_arg::<i8>("128"), Err(Error::Overflow));
        assert_eq!(from_arg::<i8>("-1").unwrap(), -1);
        assert_eq!(from_arg::<i8>("-128").unwrap(), -128);
        assert_matches!(from_arg::<i8>("-129"), Err(Error::Overflow));

        assert_eq!(from_arg::<i16>("1k").unwrap(), 1 << 10);

        assert_eq!(from_arg::<i32>("1g").unwrap(), 1 << 30);
        assert_matches!(from_arg::<i32>("2g"), Err(Error::Overflow));
        assert_matches!(from_arg::<i32>("0xffffffff"), Err(Error::Overflow));

        assert_eq!(from_arg::<i64>("0xffffffff").unwrap(), 0xffffffff);

        assert_matches!(from_arg::<i64>("gg"), Err(Error::ExpectedInteger));

        assert_matches!(from_arg::<f32>("0.125").unwrap(), 0.125);

        assert_matches!(from_arg::<f64>("-0.5").unwrap(), -0.5);
    }

    // `char` values, including `,` and `=` supplied through object ids
    // because they cannot appear inline.
    #[test]
    fn test_char() {
        assert_eq!(from_arg::<char>("=").unwrap(), '=');
        assert_eq!(from_arg::<char>("a").unwrap(), 'a');
        assert_matches!(from_arg::<char>("an"), Err(Error::Message(_)));

        assert_eq!(
            from_args::<HashMap<char, char>>(
                "id_1=a,b=id_2,id_2=id_1",
                &HashMap::from([("id_1", ","), ("id_2", "="),])
            )
            .unwrap(),
            HashMap::from([(',', 'a'), ('b', '='), ('=', ',')])
        );
    }

    // Byte arrays and buffers are written as comma-separated integers.
    #[test]
    fn test_bytes() {
        assert!(from_arg::<ByteArray<6>>("0xea,0xd7,0xa8,0xe8,0xc6,0x2f").is_ok());
        assert_matches!(
            from_arg::<ByteArray<5>>("0xea,0xd7,0xa8,0xe8,0xc6,0x2f"),
            Err(Error::Trailing(t)) if t == "0x2f"
        );
        assert_eq!(
            from_arg::<ByteBuf>("0xea,0xd7,0xa8,0xe8,0xc6,0x2f").unwrap(),
            vec![0xea, 0xd7, 0xa8, 0xe8, 0xc6, 0x2f]
        );

        #[derive(Debug, Deserialize, Eq, PartialEq)]
        struct MacAddr {
            addr: ByteArray<6>,
        }
        assert_eq!(
            from_args::<MacAddr>(
                "addr=id_addr",
                &HashMap::from([("id_addr", "0xea,0xd7,0xa8,0xe8,0xc6,0x2f")])
            )
            .unwrap(),
            MacAddr {
                addr: ByteArray::new([0xea, 0xd7, 0xa8, 0xe8, 0xc6, 0x2f])
            }
        )
    }

    // A top-level string takes the whole input; a nested string takes one
    // item, or the text of the object it references.
    #[test]
    fn test_string() {
        assert_eq!(
            from_arg::<String>("test,s=1,c").unwrap(),
            "test,s=1,c".to_owned()
        );
        assert_eq!(
            from_args::<HashMap<String, String>>(
                "cmd=id_1",
                &HashMap::from([("id_1", "console=ttyS0")])
            )
            .unwrap(),
            HashMap::from([("cmd".to_owned(), "console=ttyS0".to_owned())])
        )
    }

    // Sequences, tuples, tuple structs, and nested structs reached through
    // object ids, including trailing-input detection.
    #[test]
    fn test_seq() {
        assert_eq!(from_arg::<Vec<u32>>("").unwrap(), vec![]);

        assert_eq!(from_arg::<Vec<u32>>("1").unwrap(), vec![1]);

        assert_eq!(from_arg::<Vec<u32>>("1,2,3,4").unwrap(), vec![1, 2, 3, 4]);

        assert_eq!(from_arg::<(u16, bool)>("12,true").unwrap(), (12, true));
        assert_matches!(
            from_arg::<(u16, bool)>("12,true,false"),
            Err(Error::Trailing(t)) if t == "false"
        );

        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct TestStruct {
            a: (u16, bool),
        }
        assert_eq!(
            from_args::<TestStruct>("a=id_a", &HashMap::from([("id_a", "12,true")])).unwrap(),
            TestStruct { a: (12, true) }
        );
        assert_matches!(
            from_args::<TestStruct>("a=id_a", &HashMap::from([("id_a", "12,true,true")])),
            Err(Error::Trailing(t)) if t == "true"
        );

        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct Node {
            #[serde(default)]
            name: String,
            #[serde(default)]
            start: u64,
            size: u64,
        }
        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct Numa {
            nodes: Vec<Node>,
        }

        assert_eq!(
            from_args::<Numa>(
                "nodes=id_nodes",
                &HashMap::from([
                    ("id_nodes", "id_node1,id_node2"),
                    ("id_node1", "name=a,start=0,size=2g"),
                    ("id_node2", "name=b,start=4g,size=2g"),
                ])
            )
            .unwrap(),
            Numa {
                nodes: vec![
                    Node {
                        name: "a".to_owned(),
                        start: 0,
                        size: 2 << 30
                    },
                    Node {
                        name: "b".to_owned(),
                        start: 4 << 30,
                        size: 2 << 30
                    }
                ]
            }
        );

        // Without object ids, each nesting level sees a single item.
        assert_eq!(
            from_arg::<Numa>("nodes=size=2g,").unwrap(),
            Numa {
                nodes: vec![Node {
                    name: "".to_owned(),
                    start: 0,
                    size: 2 << 30
                }]
            }
        );

        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct Info(bool, u32);

        assert_eq!(from_arg::<Info>("true,32").unwrap(), Info(true, 32));
    }

    // Unknown keys are rejected with `Error::Ignored`; maps may have struct
    // keys and values supplied through object ids.
    #[test]
    fn test_map() {
        #[derive(Debug, Deserialize, PartialEq, Eq, Hash)]
        struct MapKey {
            name: String,
            id: u32,
        }
        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct MapVal {
            addr: String,
            info: HashMap<String, String>,
        }

        assert_matches!(
            from_arg::<MapKey>("name=a,id=1,addr=b"),
            Err(Error::Ignored(k)) if k == "addr"
        );
        assert_matches!(
            from_arg::<MapKey>("name=a,addr=b,id=1"),
            Err(Error::Ignored(k)) if k == "addr"
        );
        assert_matches!(from_arg::<MapKey>("name=a,ids=b"), Err(Error::Ignored(k)) if k == "ids");
        assert_matches!(from_arg::<MapKey>("name=a,ids=b,id=1"), Err(Error::Ignored(k)) if k == "ids");

        assert_eq!(
            from_args::<HashMap<MapKey, MapVal>>(
                "id_key1=id_val1,id_key2=id_val2",
                &HashMap::from([
                    ("id_key1", "name=gic,id=1"),
                    ("id_key2", "name=pci,id=2"),
                    ("id_val1", "addr=0xff,info=id_info1"),
                    ("id_info1", "compatible=id_gic,msi-controller=,#msi-cells=1"),
                    ("id_gic", "arm,gic-v3-its"),
                    ("id_val2", "addr=0xcc,info=compatible=pci-host-ecam-generic"),
                ])
            )
            .unwrap(),
            HashMap::from([
                (
                    MapKey {
                        name: "gic".to_owned(),
                        id: 1
                    },
                    MapVal {
                        addr: "0xff".to_owned(),
                        info: HashMap::from([
                            ("compatible".to_owned(), "arm,gic-v3-its".to_owned()),
                            ("msi-controller".to_owned(), "".to_owned()),
                            ("#msi-cells".to_owned(), "1".to_owned())
                        ])
                    }
                ),
                (
                    MapKey {
                        name: "pci".to_owned(),
                        id: 2
                    },
                    MapVal {
                        addr: "0xcc".to_owned(),
                        info: HashMap::from([(
                            "compatible".to_owned(),
                            "pci-host-ecam-generic".to_owned()
                        )])
                    }
                )
            ])
        );
    }

    // Structs mixing radix prefixes, size suffixes, bools, options, a nested
    // struct and a newtype struct; entries without `=` are rejected.
    #[test]
    fn test_nested_struct() {
        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct Param {
            byte: u8,
            word: u16,
            dw: u32,
            long: u64,
            enable_1: bool,
            enable_2: bool,
            enable_3: Option<bool>,
            sub: SubParam,
            addr: Addr,
        }

        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct SubParam {
            b: u8,
            w: u16,
            enable: Option<bool>,
            s: String,
        }

        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct Addr(u32);

        assert_eq!(
            from_args::<Param>(
                "byte=0b10,word=0o7k,dw=0x8m,long=10t,enable_1=on,enable_2=off,sub=id_1,addr=1g",
                &[("id_1", "b=1,w=2,s=s1,enable=on")].into()
            )
            .unwrap(),
            Param {
                byte: 0b10,
                word: 0o7 << 10,
                dw: 0x8 << 20,
                long: 10 << 40,
                enable_1: true,
                enable_2: false,
                enable_3: None,
                sub: SubParam {
                    b: 1,
                    w: 2,
                    enable: Some(true),
                    s: "s1".to_owned(),
                },
                addr: Addr(1 << 30)
            }
        );
        assert_matches!(
            from_arg::<SubParam>("b=1,w=2,enable,s=s1"),
            Err(Error::ExpectedMapEq)
        );
        assert_matches!(
            from_arg::<SubParam>("b=1,w=2,s=s1,enable"),
            Err(Error::ExpectedMapEq)
        );
    }

    // `on`/`off`/`true`/`false`, standalone and as struct fields.
    #[test]
    fn test_bool() {
        assert_matches!(from_arg::<bool>("on"), Ok(true));
        assert_matches!(from_arg::<bool>("off"), Ok(false));
        assert_matches!(from_arg::<bool>("true"), Ok(true));
        assert_matches!(from_arg::<bool>("false"), Ok(false));
        assert_matches!(from_arg::<bool>("on,off"), Err(Error::Trailing(t)) if t == "off");

        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct BoolStruct {
            val: bool,
        }
        assert_eq!(
            from_arg::<BoolStruct>("val=on").unwrap(),
            BoolStruct { val: true }
        );
        assert_eq!(
            from_arg::<BoolStruct>("val=off").unwrap(),
            BoolStruct { val: false }
        );
        assert_eq!(
            from_arg::<BoolStruct>("val=true").unwrap(),
            BoolStruct { val: true }
        );
        assert_eq!(
            from_arg::<BoolStruct>("val=false").unwrap(),
            BoolStruct { val: false }
        );
        assert_matches!(from_arg::<BoolStruct>("val=a"), Err(Error::ExpectedBool));

        assert_matches!(
            from_arg::<BoolStruct>("val=on,key=off"),
            Err(Error::Ignored(k)) if k == "key"
        );
    }

    // Unit, newtype, tuple and struct enum variants, written inline or
    // through object ids, including aliases and trailing-input detection.
    #[test]
    fn test_enum() {
        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct SubStruct {
            a: u32,
            b: bool,
        }

        #[derive(Debug, Deserialize, PartialEq, Eq)]
        enum TestEnum {
            A {
                #[serde(default)]
                val: u32,
            },
            B(u64),
            C(u8, u8),
            D,
            #[serde(alias = "e")]
            E,
            F(SubStruct),
            G(u16, String, bool),
        }

        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct TestStruct {
            num: u32,
            e: TestEnum,
        }

        assert_eq!(
            from_args::<TestStruct>("num=3,e=id_a", &[("id_a", "A,val=1")].into()).unwrap(),
            TestStruct {
                num: 3,
                e: TestEnum::A { val: 1 }
            }
        );
        assert_eq!(
            from_arg::<TestStruct>("num=4,e=A").unwrap(),
            TestStruct {
                num: 4,
                e: TestEnum::A { val: 0 },
            }
        );
        assert_eq!(
            from_args::<TestStruct>("num=4,e=id_a", &[("id_a", "A")].into()).unwrap(),
            TestStruct {
                num: 4,
                e: TestEnum::A { val: 0 },
            }
        );
        assert_eq!(
            from_arg::<TestStruct>("num=4,e=D").unwrap(),
            TestStruct {
                num: 4,
                e: TestEnum::D,
            }
        );
        assert_eq!(
            from_args::<TestStruct>("num=4,e=id_d", &[("id_d", "D")].into()).unwrap(),
            TestStruct {
                num: 4,
                e: TestEnum::D,
            }
        );
        assert_eq!(
            from_arg::<TestStruct>("num=3,e=e").unwrap(),
            TestStruct {
                num: 3,
                e: TestEnum::E
            }
        );
        assert_matches!(
            from_arg::<TestStruct>("num=4,e=id_d"),
            Err(Error::IdNotFound(id)) if id == "id_d"
        );
        assert_matches!(
            from_args::<TestStruct>("num=4,e=id_d", &[].into()),
            Err(Error::IdNotFound(id)) if id == "id_d"
        );
        assert_eq!(from_arg::<TestEnum>("B,1").unwrap(), TestEnum::B(1));
        assert_eq!(from_arg::<TestEnum>("D").unwrap(), TestEnum::D);
        assert_eq!(
            from_arg::<TestEnum>("F,a=1,b=on").unwrap(),
            TestEnum::F(SubStruct { a: 1, b: true })
        );
        assert_eq!(
            from_arg::<TestEnum>("G,1,a,true").unwrap(),
            TestEnum::G(1, "a".to_owned(), true)
        );
        assert_matches!(
            from_arg::<TestEnum>("G,1,a,true,false"),
            Err(Error::Trailing(t)) if t == "false"
        );
        assert_matches!(
            from_args::<TestStruct>(
                "num=4,e=id_e",
                &HashMap::from([("id_e", "G,1,a,true,false")])
            ),
            Err(Error::Trailing(t)) if t == "false"
        );
    }
}