1use std::collections::HashMap;
16
17use serde::Deserialize;
18use serde::de::{self, DeserializeSeed, EnumAccess, MapAccess, SeqAccess, VariantAccess, Visitor};
19
20use crate::error::{Error, Result};
21
/// Deserializer for comma-separated argument strings such as `a=1,b=on,c=id_c`.
///
/// Tokens of the form `id_…` are references that are looked up in `objects`
/// by the string, option and compound-value paths.
#[derive(Debug)]
pub struct Deserializer<'s, 'o> {
    /// Input that has not been consumed yet.
    input: &'s str,
    /// Table used to resolve `id_…` references; `None` when built with `from_arg`.
    objects: Option<&'o HashMap<&'s str, &'s str>>,
    /// When true, the whole remaining input is the value being deserialized
    /// (e.g. a top-level string may contain commas); otherwise values are read
    /// one comma-separated token at a time.
    top_level: bool,
    /// Key of the map entry currently being read; reported by `Error::Ignored`.
    key: &'s str,
}
29
30impl<'s> de::Deserializer<'s> for &mut Deserializer<'s, '_> {
31 type Error = Error;
32
33 fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
34 where
35 V: Visitor<'s>,
36 {
37 Err(Error::UnknownType)
38 }
39
40 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
41 where
42 V: Visitor<'s>,
43 {
44 let s = self.consume_input();
45 match s {
46 "on" | "true" => visitor.visit_bool(true),
47 "off" | "false" => visitor.visit_bool(false),
48 _ => Err(Error::ExpectedBool),
49 }
50 }
51
52 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
53 where
54 V: Visitor<'s>,
55 {
56 visitor.visit_i8(self.parse_signed()?)
57 }
58
59 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
60 where
61 V: Visitor<'s>,
62 {
63 visitor.visit_i16(self.parse_signed()?)
64 }
65
66 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
67 where
68 V: Visitor<'s>,
69 {
70 visitor.visit_i32(self.parse_signed()?)
71 }
72
73 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
74 where
75 V: Visitor<'s>,
76 {
77 visitor.visit_i64(self.parse_signed()?)
78 }
79
80 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
81 where
82 V: Visitor<'s>,
83 {
84 visitor.visit_u8(self.parse_unsigned()?)
85 }
86
87 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
88 where
89 V: Visitor<'s>,
90 {
91 visitor.visit_u16(self.parse_unsigned()?)
92 }
93
94 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
95 where
96 V: Visitor<'s>,
97 {
98 visitor.visit_u32(self.parse_unsigned()?)
99 }
100
101 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
102 where
103 V: Visitor<'s>,
104 {
105 visitor.visit_u64(self.parse_unsigned()?)
106 }
107
108 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
109 where
110 V: Visitor<'s>,
111 {
112 let s = self.consume_input();
113 visitor.visit_f32(s.parse().map_err(|_| Error::ExpectedFloat)?)
114 }
115
116 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
117 where
118 V: Visitor<'s>,
119 {
120 let s = self.consume_input();
121 visitor.visit_f64(s.parse().map_err(|_| Error::ExpectedFloat)?)
122 }
123
124 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
125 where
126 V: Visitor<'s>,
127 {
128 self.deserialize_str(visitor)
129 }
130
131 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
132 where
133 V: Visitor<'s>,
134 {
135 if self.top_level {
136 visitor.visit_borrowed_str(self.consume_all())
137 } else {
138 let id = self.consume_input();
139 visitor.visit_borrowed_str(self.deref_id(id)?)
140 }
141 }
142
143 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
144 where
145 V: Visitor<'s>,
146 {
147 self.deserialize_str(visitor)
148 }
149
150 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
151 where
152 V: Visitor<'s>,
153 {
154 self.deserialize_seq(visitor)
155 }
156
157 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
158 where
159 V: Visitor<'s>,
160 {
161 self.deserialize_bytes(visitor)
162 }
163
164 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
165 where
166 V: Visitor<'s>,
167 {
168 let id = self.consume_input();
169 let s = self.deref_id(id)?;
170 if id.starts_with("id_") && s.is_empty() {
171 visitor.visit_none()
172 } else {
173 let mut sub_de = Deserializer { input: s, ..*self };
174 visitor.visit_some(&mut sub_de)
175 }
176 }
177
178 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
179 where
180 V: Visitor<'s>,
181 {
182 let s = self.consume_input();
183 if s.is_empty() {
184 visitor.visit_unit()
185 } else {
186 Err(Error::ExpectedUnit)
187 }
188 }
189
190 fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
191 where
192 V: Visitor<'s>,
193 {
194 self.deserialize_unit(visitor)
195 }
196
197 fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
198 where
199 V: Visitor<'s>,
200 {
201 visitor.visit_newtype_struct(self)
202 }
203
204 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
205 where
206 V: Visitor<'s>,
207 {
208 self.deserialize_nested(|de| visitor.visit_seq(CommaSeparated::new(de)))
209 }
210
211 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
212 where
213 V: Visitor<'s>,
214 {
215 self.deserialize_seq(visitor)
216 }
217
218 fn deserialize_tuple_struct<V>(
219 self,
220 _name: &'static str,
221 _len: usize,
222 visitor: V,
223 ) -> Result<V::Value>
224 where
225 V: Visitor<'s>,
226 {
227 self.deserialize_seq(visitor)
228 }
229
230 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
231 where
232 V: Visitor<'s>,
233 {
234 self.deserialize_nested(|de| visitor.visit_map(CommaSeparated::new(de)))
235 }
236
237 fn deserialize_struct<V>(
238 self,
239 _name: &'static str,
240 _fields: &'static [&'static str],
241 visitor: V,
242 ) -> Result<V::Value>
243 where
244 V: Visitor<'s>,
245 {
246 self.deserialize_map(visitor)
247 }
248
249 fn deserialize_enum<V>(
250 self,
251 _name: &'static str,
252 _variants: &'static [&'static str],
253 visitor: V,
254 ) -> Result<V::Value>
255 where
256 V: Visitor<'s>,
257 {
258 self.deserialize_nested(|de| visitor.visit_enum(Enum::new(de)))
259 }
260
261 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
262 where
263 V: Visitor<'s>,
264 {
265 visitor.visit_borrowed_str(self.consume_input())
266 }
267
268 fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value>
269 where
270 V: Visitor<'s>,
271 {
272 Err(Error::Ignored(self.key.to_owned()))
273 }
274}
275
276impl<'s, 'o> Deserializer<'s, 'o> {
277 pub fn from_args(input: &'s str, objects: &'o HashMap<&'s str, &'s str>) -> Self {
278 Deserializer {
279 input,
280 objects: Some(objects),
281 top_level: true,
282 key: "",
283 }
284 }
285
286 pub fn from_arg(input: &'s str) -> Self {
287 Deserializer {
288 input,
289 objects: None,
290 top_level: true,
291 key: "",
292 }
293 }
294
295 fn end(&self) -> Result<()> {
296 if self.input.is_empty() {
297 Ok(())
298 } else {
299 Err(Error::Trailing(self.input.to_owned()))
300 }
301 }
302
303 fn deserialize_nested<F, V>(&mut self, f: F) -> Result<V>
304 where
305 F: FnOnce(&mut Self) -> Result<V>,
306 {
307 let mut sub_de;
308 let de = if !self.top_level {
309 let id = self.consume_input();
310 let sub_input = self.deref_id(id)?;
311 sub_de = Deserializer {
312 input: sub_input,
313 ..*self
314 };
315 &mut sub_de
316 } else {
317 self.top_level = false;
318 self
319 };
320 let val = f(de)?;
321 de.end()?;
322 Ok(val)
323 }
324
325 fn consume_input_until(&mut self, end: char) -> Option<&'s str> {
326 let len = self.input.find(end)?;
327 let s = &self.input[..len];
328 self.input = &self.input[len + end.len_utf8()..];
329 Some(s)
330 }
331
332 fn consume_all(&mut self) -> &'s str {
333 let s = self.input;
334 self.input = "";
335 s
336 }
337
338 fn consume_input(&mut self) -> &'s str {
339 match self.consume_input_until(',') {
340 Some(s) => s,
341 None => self.consume_all(),
342 }
343 }
344
345 fn deref_id(&self, id: &'s str) -> Result<&'s str> {
346 if id.starts_with("id_") {
347 if let Some(s) = self.objects.and_then(|objects| objects.get(id)) {
348 Ok(s)
349 } else {
350 Err(Error::IdNotFound(id.to_owned()))
351 }
352 } else {
353 Ok(id)
354 }
355 }
356
357 fn parse_unsigned<T>(&mut self) -> Result<T>
358 where
359 T: TryFrom<u64>,
360 {
361 let s = self.consume_input();
362 let (num, shift) = if let Some((num, "")) = s.split_once(['k', 'K']) {
363 (num, 10)
364 } else if let Some((num, "")) = s.split_once(['m', 'M']) {
365 (num, 20)
366 } else if let Some((num, "")) = s.split_once(['g', 'G']) {
367 (num, 30)
368 } else if let Some((num, "")) = s.split_once(['t', 'T']) {
369 (num, 40)
370 } else {
371 (s, 0)
372 };
373 let n = if let Some(num_h) = num.strip_prefix("0x") {
374 u64::from_str_radix(num_h, 16)
375 } else if let Some(num_o) = num.strip_prefix("0o") {
376 u64::from_str_radix(num_o, 8)
377 } else if let Some(num_b) = num.strip_prefix("0b") {
378 u64::from_str_radix(num_b, 2)
379 } else {
380 num.parse::<u64>()
381 }
382 .map_err(|_| Error::ExpectedInteger)?;
383
384 let shifted_n = n.checked_shl(shift).ok_or(Error::Overflow)?;
385
386 T::try_from(shifted_n).map_err(|_| Error::Overflow)
387 }
388
389 fn parse_signed<T>(&mut self) -> Result<T>
390 where
391 T: TryFrom<i64>,
392 {
393 let i = if self.input.starts_with('-') {
394 let s = self.consume_input();
395 s.parse().map_err(|_| Error::ExpectedInteger)
396 } else {
397 let n = self.parse_unsigned::<u64>()?;
398 i64::try_from(n).map_err(|_| Error::Overflow)
399 }?;
400 T::try_from(i).map_err(|_| Error::Overflow)
401 }
402}
403
404pub fn from_args<'s, 'o, T>(s: &'s str, objects: &'o HashMap<&'s str, &'s str>) -> Result<T>
405where
406 T: Deserialize<'s>,
407{
408 let mut deserializer = Deserializer::from_args(s, objects);
409 let value = T::deserialize(&mut deserializer)?;
410 deserializer.end()?;
411 Ok(value)
412}
413
414pub fn from_arg<'s, T>(s: &'s str) -> Result<T>
415where
416 T: Deserialize<'s>,
417{
418 let mut deserializer = Deserializer::from_arg(s);
419 let value = T::deserialize(&mut deserializer)?;
420 deserializer.end()?;
421 Ok(value)
422}
423
424struct CommaSeparated<'a, 's: 'a, 'o: 'a> {
425 de: &'a mut Deserializer<'s, 'o>,
426}
427
428impl<'a, 's, 'o> CommaSeparated<'a, 's, 'o> {
429 fn new(de: &'a mut Deserializer<'s, 'o>) -> Self {
430 CommaSeparated { de }
431 }
432}
433
434impl<'s> SeqAccess<'s> for CommaSeparated<'_, 's, '_> {
435 type Error = Error;
436
437 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
438 where
439 T: DeserializeSeed<'s>,
440 {
441 if self.de.input.is_empty() {
442 return Ok(None);
443 }
444 seed.deserialize(&mut *self.de).map(Some)
445 }
446}
447
448impl<'s> MapAccess<'s> for CommaSeparated<'_, 's, '_> {
449 type Error = Error;
450
451 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
452 where
453 K: DeserializeSeed<'s>,
454 {
455 if self.de.input.is_empty() {
456 return Ok(None);
457 }
458 let Some(key) = self.de.consume_input_until('=') else {
459 return Err(Error::ExpectedMapEq);
460 };
461 if key.contains(',') {
462 return Err(Error::ExpectedMapEq);
463 }
464 self.de.key = key;
465 let mut sub_de = Deserializer {
466 input: key,
467 key: "",
468 ..*self.de
469 };
470 seed.deserialize(&mut sub_de).map(Some)
471 }
472
473 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
474 where
475 V: DeserializeSeed<'s>,
476 {
477 seed.deserialize(&mut *self.de)
478 }
479}
480
481struct Enum<'a, 's: 'a, 'o: 'a> {
482 de: &'a mut Deserializer<'s, 'o>,
483}
484
485impl<'a, 's, 'o> Enum<'a, 's, 'o> {
486 fn new(de: &'a mut Deserializer<'s, 'o>) -> Self {
487 Enum { de }
488 }
489}
490
491impl<'s> EnumAccess<'s> for Enum<'_, 's, '_> {
492 type Error = Error;
493 type Variant = Self;
494
495 fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
496 where
497 V: DeserializeSeed<'s>,
498 {
499 let val = seed.deserialize(&mut *self.de)?;
500 Ok((val, self))
501 }
502}
503
504impl<'s> VariantAccess<'s> for Enum<'_, 's, '_> {
505 type Error = Error;
506
507 fn unit_variant(self) -> Result<()> {
508 Ok(())
509 }
510
511 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
512 where
513 T: DeserializeSeed<'s>,
514 {
515 self.de.top_level = true;
516 seed.deserialize(self.de)
517 }
518
519 fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
520 where
521 V: Visitor<'s>,
522 {
523 visitor.visit_seq(CommaSeparated::new(self.de))
524 }
525
526 fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
527 where
528 V: Visitor<'s>,
529 {
530 visitor.visit_map(CommaSeparated::new(self.de))
531 }
532}
533
// Unit tests for the argument-string deserializer, driven through the public
// `from_arg` (no object table) and `from_args` (with an `id_…` object table).
#[cfg(test)]
mod test {
    use std::collections::HashMap;
    use std::marker::PhantomData;

    use assert_matches::assert_matches;
    use serde::Deserialize;
    use serde_bytes::{ByteArray, ByteBuf};

    use crate::{Error, from_arg, from_args};

    // `Option`: an `id_…` reference to an empty object is `None`; inline
    // values (even empty ones) are `Some`; unresolved ids are errors.
    #[test]
    fn test_option() {
        assert_matches!(from_arg::<Option<u32>>(""), Err(Error::ExpectedInteger));
        assert_eq!(from_arg::<Option<u32>>("12").unwrap(), Some(12));

        assert_eq!(from_arg::<Option<&'static str>>("").unwrap(), Some(""));
        assert_eq!(
            from_args::<Option<&'static str>>("id_1", &HashMap::from([("id_1", "")])).unwrap(),
            None
        );
        assert_eq!(from_arg::<Option<&'static str>>("12").unwrap(), Some("12"));
        assert_matches!(
            from_arg::<Option<&'static str>>("id_1"),
            Err(Error::IdNotFound(id)) if id == "id_1"
        );
        // The resolved text is used literally, not dereferenced a second time.
        assert_eq!(
            from_args::<Option<&'static str>>("id_1", &HashMap::from([("id_1", "id_2")])).unwrap(),
            Some("id_2")
        );

        // Options as sequence elements, with and without trailing commas.
        let map_none = HashMap::from([("id_none", "")]);
        assert_eq!(from_arg::<Vec<Option<u32>>>("").unwrap(), vec![]);
        assert_eq!(
            from_args::<Vec<Option<u32>>>("id_none,", &map_none).unwrap(),
            vec![None]
        );
        assert_eq!(from_arg::<Vec<Option<u32>>>("1,").unwrap(), vec![Some(1)]);
        assert_eq!(
            from_arg::<Vec<Option<u32>>>("1,2,").unwrap(),
            vec![Some(1), Some(2)]
        );
        assert_eq!(
            from_args::<Vec<Option<u32>>>("1,2,id_none,", &map_none).unwrap(),
            vec![Some(1), Some(2), None]
        );
        assert_eq!(
            from_args::<Vec<Option<u32>>>("id_none,2", &map_none).unwrap(),
            vec![None, Some(2)]
        );
    }

    // Unit-like values (`()`, `PhantomData`) must be the empty token.
    #[test]
    fn test_unit() {
        assert!(from_arg::<()>("").is_ok());
        assert_matches!(from_arg::<()>("unit"), Err(Error::ExpectedUnit));

        assert!(from_arg::<PhantomData<u8>>("").is_ok());
        assert_matches!(from_arg::<PhantomData<u8>>("12"), Err(Error::ExpectedUnit));

        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct Param {
            p: PhantomData<u8>,
        }
        assert_eq!(from_arg::<Param>("p=").unwrap(), Param { p: PhantomData });
        assert_matches!(from_arg::<Param>("p=1,"), Err(Error::ExpectedUnit));
    }

    // Integer range checks, size suffixes, hex literals and floats.
    #[test]
    fn test_numbers() {
        assert_eq!(from_arg::<i8>("0").unwrap(), 0);
        assert_eq!(from_arg::<i8>("1").unwrap(), 1);
        assert_eq!(from_arg::<i8>("127").unwrap(), 127);
        assert_matches!(from_arg::<i8>("128"), Err(Error::Overflow));
        assert_eq!(from_arg::<i8>("-1").unwrap(), -1);
        assert_eq!(from_arg::<i8>("-128").unwrap(), -128);
        assert_matches!(from_arg::<i8>("-129"), Err(Error::Overflow));

        assert_eq!(from_arg::<i16>("1k").unwrap(), 1 << 10);

        assert_eq!(from_arg::<i32>("1g").unwrap(), 1 << 30);
        assert_matches!(from_arg::<i32>("2g"), Err(Error::Overflow));
        assert_matches!(from_arg::<i32>("0xffffffff"), Err(Error::Overflow));

        assert_eq!(from_arg::<i64>("0xffffffff").unwrap(), 0xffffffff);

        assert_matches!(from_arg::<i64>("gg"), Err(Error::ExpectedInteger));

        assert_matches!(from_arg::<f32>("0.125").unwrap(), 0.125);

        assert_matches!(from_arg::<f64>("-0.5").unwrap(), -0.5);
    }

    // `char` is a one-character string; ids let `,` and `=` be used as map
    // keys and values.
    #[test]
    fn test_char() {
        assert_eq!(from_arg::<char>("=").unwrap(), '=');
        assert_eq!(from_arg::<char>("a").unwrap(), 'a');
        assert_matches!(from_arg::<char>("an"), Err(Error::Message(_)));

        assert_eq!(
            from_args::<HashMap<char, char>>(
                "id_1=a,b=id_2,id_2=id_1",
                &HashMap::from([("id_1", ","), ("id_2", "="),])
            )
            .unwrap(),
            HashMap::from([(',', 'a'), ('b', '='), ('=', ',')])
        );
    }

    // Byte arrays/buffers are comma-separated integer sequences; extra
    // elements are reported as trailing input.
    #[test]
    fn test_bytes() {
        assert!(from_arg::<ByteArray<6>>("0xea,0xd7,0xa8,0xe8,0xc6,0x2f").is_ok());
        assert_matches!(
            from_arg::<ByteArray<5>>("0xea,0xd7,0xa8,0xe8,0xc6,0x2f"),
            Err(Error::Trailing(t)) if t == "0x2f"
        );
        assert_eq!(
            from_arg::<ByteBuf>("0xea,0xd7,0xa8,0xe8,0xc6,0x2f").unwrap(),
            vec![0xea, 0xd7, 0xa8, 0xe8, 0xc6, 0x2f]
        );

        #[derive(Debug, Deserialize, Eq, PartialEq)]
        struct MacAddr {
            addr: ByteArray<6>,
        }
        assert_eq!(
            from_args::<MacAddr>(
                "addr=id_addr",
                &HashMap::from([("id_addr", "0xea,0xd7,0xa8,0xe8,0xc6,0x2f")])
            )
            .unwrap(),
            MacAddr {
                addr: ByteArray::new([0xea, 0xd7, 0xa8, 0xe8, 0xc6, 0x2f])
            }
        )
    }

    // A top-level string takes the whole input; nested strings with `,`/`=`
    // go through ids.
    #[test]
    fn test_string() {
        assert_eq!(
            from_arg::<String>("test,s=1,c").unwrap(),
            "test,s=1,c".to_owned()
        );
        assert_eq!(
            from_args::<HashMap<String, String>>(
                "cmd=id_1",
                &HashMap::from([("id_1", "console=ttyS0")])
            )
            .unwrap(),
            HashMap::from([("cmd".to_owned(), "console=ttyS0".to_owned())])
        )
    }

    // Sequences, tuples, tuple structs and nested sequences of structs.
    #[test]
    fn test_seq() {
        assert_eq!(from_arg::<Vec<u32>>("").unwrap(), vec![]);

        assert_eq!(from_arg::<Vec<u32>>("1").unwrap(), vec![1]);

        assert_eq!(from_arg::<Vec<u32>>("1,2,3,4").unwrap(), vec![1, 2, 3, 4]);

        assert_eq!(from_arg::<(u16, bool)>("12,true").unwrap(), (12, true));
        assert_matches!(
            from_arg::<(u16, bool)>("12,true,false"),
            Err(Error::Trailing(t)) if t == "false"
        );

        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct TestStruct {
            a: (u16, bool),
        }
        assert_eq!(
            from_args::<TestStruct>("a=id_a", &HashMap::from([("id_a", "12,true")])).unwrap(),
            TestStruct { a: (12, true) }
        );
        assert_matches!(
            from_args::<TestStruct>("a=id_a", &HashMap::from([("id_a", "12,true,true")])),
            Err(Error::Trailing(t)) if t == "true"
        );

        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct Node {
            #[serde(default)]
            name: String,
            #[serde(default)]
            start: u64,
            size: u64,
        }
        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct Numa {
            nodes: Vec<Node>,
        }

        assert_eq!(
            from_args::<Numa>(
                "nodes=id_nodes",
                &HashMap::from([
                    ("id_nodes", "id_node1,id_node2"),
                    ("id_node1", "name=a,start=0,size=2g"),
                    ("id_node2", "name=b,start=4g,size=2g"),
                ])
            )
            .unwrap(),
            Numa {
                nodes: vec![
                    Node {
                        name: "a".to_owned(),
                        start: 0,
                        size: 2 << 30
                    },
                    Node {
                        name: "b".to_owned(),
                        start: 4 << 30,
                        size: 2 << 30
                    }
                ]
            }
        );

        // A single inline struct element with a trailing comma.
        assert_eq!(
            from_arg::<Numa>("nodes=size=2g,").unwrap(),
            Numa {
                nodes: vec![Node {
                    name: "".to_owned(),
                    start: 0,
                    size: 2 << 30
                }]
            }
        );

        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct Info(bool, u32);

        assert_eq!(from_arg::<Info>("true,32").unwrap(), Info(true, 32));
    }

    // Unknown struct fields are rejected with their key; maps may use
    // struct keys and values resolved through ids.
    #[test]
    fn test_map() {
        #[derive(Debug, Deserialize, PartialEq, Eq, Hash)]
        struct MapKey {
            name: String,
            id: u32,
        }
        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct MapVal {
            addr: String,
            info: HashMap<String, String>,
        }

        assert_matches!(
            from_arg::<MapKey>("name=a,id=1,addr=b"),
            Err(Error::Ignored(k)) if k == "addr"
        );
        assert_matches!(
            from_arg::<MapKey>("name=a,addr=b,id=1"),
            Err(Error::Ignored(k)) if k == "addr"
        );
        assert_matches!(from_arg::<MapKey>("name=a,ids=b"), Err(Error::Ignored(k)) if k == "ids");
        assert_matches!(from_arg::<MapKey>("name=a,ids=b,id=1"), Err(Error::Ignored(k)) if k == "ids");

        assert_eq!(
            from_args::<HashMap<MapKey, MapVal>>(
                "id_key1=id_val1,id_key2=id_val2",
                &HashMap::from([
                    ("id_key1", "name=gic,id=1"),
                    ("id_key2", "name=pci,id=2"),
                    ("id_val1", "addr=0xff,info=id_info1"),
                    ("id_info1", "compatible=id_gic,msi-controller=,#msi-cells=1"),
                    ("id_gic", "arm,gic-v3-its"),
                    ("id_val2", "addr=0xcc,info=compatible=pci-host-ecam-generic"),
                ])
            )
            .unwrap(),
            HashMap::from([
                (
                    MapKey {
                        name: "gic".to_owned(),
                        id: 1
                    },
                    MapVal {
                        addr: "0xff".to_owned(),
                        info: HashMap::from([
                            ("compatible".to_owned(), "arm,gic-v3-its".to_owned()),
                            ("msi-controller".to_owned(), "".to_owned()),
                            ("#msi-cells".to_owned(), "1".to_owned())
                        ])
                    }
                ),
                (
                    MapKey {
                        name: "pci".to_owned(),
                        id: 2
                    },
                    MapVal {
                        addr: "0xcc".to_owned(),
                        info: HashMap::from([(
                            "compatible".to_owned(),
                            "pci-host-ecam-generic".to_owned()
                        )])
                    }
                )
            ])
        );
    }

    // A struct mixing radix/suffix integers, bools, an optional field, a
    // nested struct via id and a newtype struct; keys without `=` are errors.
    #[test]
    fn test_nested_struct() {
        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct Param {
            byte: u8,
            word: u16,
            dw: u32,
            long: u64,
            enable_1: bool,
            enable_2: bool,
            enable_3: Option<bool>,
            sub: SubParam,
            addr: Addr,
        }

        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct SubParam {
            b: u8,
            w: u16,
            enable: Option<bool>,
            s: String,
        }

        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct Addr(u32);

        assert_eq!(
            from_args::<Param>(
                "byte=0b10,word=0o7k,dw=0x8m,long=10t,enable_1=on,enable_2=off,sub=id_1,addr=1g",
                &[("id_1", "b=1,w=2,s=s1,enable=on")].into()
            )
            .unwrap(),
            Param {
                byte: 0b10,
                word: 0o7 << 10,
                dw: 0x8 << 20,
                long: 10 << 40,
                enable_1: true,
                enable_2: false,
                enable_3: None,
                sub: SubParam {
                    b: 1,
                    w: 2,
                    enable: Some(true),
                    s: "s1".to_owned(),
                },
                addr: Addr(1 << 30)
            }
        );
        assert_matches!(
            from_arg::<SubParam>("b=1,w=2,enable,s=s1"),
            Err(Error::ExpectedMapEq)
        );
        assert_matches!(
            from_arg::<SubParam>("b=1,w=2,s=s1,enable"),
            Err(Error::ExpectedMapEq)
        );
    }

    // Booleans accept on/off and true/false, both top-level and as fields.
    #[test]
    fn test_bool() {
        assert_matches!(from_arg::<bool>("on"), Ok(true));
        assert_matches!(from_arg::<bool>("off"), Ok(false));
        assert_matches!(from_arg::<bool>("true"), Ok(true));
        assert_matches!(from_arg::<bool>("false"), Ok(false));
        assert_matches!(from_arg::<bool>("on,off"), Err(Error::Trailing(t)) if t == "off");

        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct BoolStruct {
            val: bool,
        }
        assert_eq!(
            from_arg::<BoolStruct>("val=on").unwrap(),
            BoolStruct { val: true }
        );
        assert_eq!(
            from_arg::<BoolStruct>("val=off").unwrap(),
            BoolStruct { val: false }
        );
        assert_eq!(
            from_arg::<BoolStruct>("val=true").unwrap(),
            BoolStruct { val: true }
        );
        assert_eq!(
            from_arg::<BoolStruct>("val=false").unwrap(),
            BoolStruct { val: false }
        );
        assert_matches!(from_arg::<BoolStruct>("val=a"), Err(Error::ExpectedBool));

        assert_matches!(
            from_arg::<BoolStruct>("val=on,key=off"),
            Err(Error::Ignored(k)) if k == "key"
        );
    }

    // Enums: struct, newtype, tuple and unit variants, aliases, ids, and
    // trailing-input detection.
    #[test]
    fn test_enum() {
        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct SubStruct {
            a: u32,
            b: bool,
        }

        #[derive(Debug, Deserialize, PartialEq, Eq)]
        enum TestEnum {
            A {
                #[serde(default)]
                val: u32,
            },
            B(u64),
            C(u8, u8),
            D,
            #[serde(alias = "e")]
            E,
            F(SubStruct),
            G(u16, String, bool),
        }

        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct TestStruct {
            num: u32,
            e: TestEnum,
        }

        assert_eq!(
            from_args::<TestStruct>("num=3,e=id_a", &[("id_a", "A,val=1")].into()).unwrap(),
            TestStruct {
                num: 3,
                e: TestEnum::A { val: 1 }
            }
        );
        assert_eq!(
            from_arg::<TestStruct>("num=4,e=A").unwrap(),
            TestStruct {
                num: 4,
                e: TestEnum::A { val: 0 },
            }
        );
        assert_eq!(
            from_args::<TestStruct>("num=4,e=id_a", &[("id_a", "A")].into()).unwrap(),
            TestStruct {
                num: 4,
                e: TestEnum::A { val: 0 },
            }
        );
        assert_eq!(
            from_arg::<TestStruct>("num=4,e=D").unwrap(),
            TestStruct {
                num: 4,
                e: TestEnum::D,
            }
        );
        assert_eq!(
            from_args::<TestStruct>("num=4,e=id_d", &[("id_d", "D")].into()).unwrap(),
            TestStruct {
                num: 4,
                e: TestEnum::D,
            }
        );
        assert_eq!(
            from_arg::<TestStruct>("num=3,e=e").unwrap(),
            TestStruct {
                num: 3,
                e: TestEnum::E
            }
        );
        assert_matches!(
            from_arg::<TestStruct>("num=4,e=id_d"),
            Err(Error::IdNotFound(id)) if id == "id_d"
        );
        assert_matches!(
            from_args::<TestStruct>("num=4,e=id_d", &[].into()),
            Err(Error::IdNotFound(id)) if id == "id_d"
        );
        assert_eq!(from_arg::<TestEnum>("B,1").unwrap(), TestEnum::B(1));
        assert_eq!(from_arg::<TestEnum>("D").unwrap(), TestEnum::D);
        assert_eq!(
            from_arg::<TestEnum>("F,a=1,b=on").unwrap(),
            TestEnum::F(SubStruct { a: 1, b: true })
        );
        assert_eq!(
            from_arg::<TestEnum>("G,1,a,true").unwrap(),
            TestEnum::G(1, "a".to_owned(), true)
        );
        assert_matches!(
            from_arg::<TestEnum>("G,1,a,true,false"),
            Err(Error::Trailing(t)) if t == "false"
        );
        assert_matches!(
            from_args::<TestStruct>(
                "num=4,e=id_e",
                &HashMap::from([("id_e", "G,1,a,true,false")])
            ),
            Err(Error::Trailing(t)) if t == "false"
        );
    }
}