1use core::marker::PhantomData;
4
5use super::{DecodeBorrowed, Error, NbyteReader};
6use crate::{formats::Format, io::IoRead};
7
/// Type-level decoder for MessagePack maps (`FixMap`, `Map16` and `Map32` headers).
///
/// `K` and `V` choose how each key and value is decoded. `Map` is the collection built
/// from the decoded `(K::Value, V::Value)` pairs through its `FromIterator` impl.
/// The `PhantomData` field only carries those type parameters; this module never
/// constructs a value of this type.
pub struct MapDecoder<Map, K, V>(PhantomData<(Map, K, V)>);
10
11#[allow(clippy::type_complexity)]
12fn decode_kv<'de, R, K, V>(reader: &mut R) -> Result<(K::Value, V::Value), Error<R::Error>>
13where
14 R: IoRead<'de>,
15 K: DecodeBorrowed<'de>,
16 V: DecodeBorrowed<'de>,
17{
18 let k = K::decode_borrowed(reader)?;
19 let v = V::decode_borrowed(reader)?;
20 Ok((k, v))
21}
22
23impl<'de, Map, K, V> DecodeBorrowed<'de> for MapDecoder<Map, K, V>
24where
25 K: DecodeBorrowed<'de>,
26 V: DecodeBorrowed<'de>,
27 Map: FromIterator<(K::Value, V::Value)>,
28{
29 type Value = Map;
30
31 fn decode_borrowed_with_format<R>(
32 format: Format,
33 reader: &mut R,
34 ) -> Result<Self::Value, Error<R::Error>>
35 where
36 R: IoRead<'de>,
37 {
38 let len = match format {
39 Format::FixMap(len) => len.into(),
40 Format::Map16 => NbyteReader::<2>::read(reader)?,
41 Format::Map32 => NbyteReader::<4>::read(reader)?,
42 _ => return Err(Error::UnexpectedFormat),
43 };
44
45 let mut err: Option<Error<R::Error>> = None;
46 let iter = (0..len).map_while(|_| match decode_kv::<R, K, V>(reader) {
47 Ok((k, v)) => Some((k, v)),
48 Err(e) => {
49 err = Some(e);
50 None
51 }
52 });
53 let res = Map::from_iter(iter);
54 match err {
55 Some(e) => Err(e),
56 None => Ok(res),
57 }
58 }
59}
60
/// Decodes a map straight into an ordered `BTreeMap`.
///
/// `K` and `V` select the key and value decoders. The resulting map stores their
/// decoded `Value` types. All of the work is delegated to [`MapDecoder`], and
/// `BTreeMap::from_iter` decides how repeated keys are resolved.
#[cfg(feature = "alloc")]
impl<'de, K, V> DecodeBorrowed<'de> for alloc::collections::BTreeMap<K, V>
where
    K: DecodeBorrowed<'de>,
    K::Value: Ord,
    V: DecodeBorrowed<'de>,
{
    type Value = alloc::collections::BTreeMap<K::Value, V::Value>;

    fn decode_borrowed_with_format<R>(
        format: Format,
        reader: &mut R,
    ) -> Result<Self::Value, Error<R::Error>>
    where
        R: IoRead<'de>,
    {
        <MapDecoder<Self::Value, K, V> as DecodeBorrowed<'de>>::decode_borrowed_with_format(
            format, reader,
        )
    }
}
80
/// Decodes a map into a `HashMap` using any hasher builder `S`.
///
/// `K` and `V` select the key and value decoders. The resulting map stores their decoded
/// `Value` types and a fresh `S::default()` hasher. Plain `HashMap<K, V>` (hasher
/// `RandomState`) keeps working unchanged. Maps with a custom `BuildHasher + Default`
/// hasher can now be decoded as well. All of the work is delegated to [`MapDecoder`].
#[cfg(feature = "std")]
impl<'de, K, V, S> DecodeBorrowed<'de> for std::collections::HashMap<K, V, S>
where
    K: DecodeBorrowed<'de>,
    V: DecodeBorrowed<'de>,
    K::Value: Eq + core::hash::Hash,
    S: core::hash::BuildHasher + Default,
{
    type Value = std::collections::HashMap<K::Value, V::Value, S>;

    fn decode_borrowed_with_format<R>(
        format: Format,
        reader: &mut R,
    ) -> Result<Self::Value, Error<R::Error>>
    where
        R: IoRead<'de>,
    {
        MapDecoder::<Self::Value, K, V>::decode_borrowed_with_format(format, reader)
    }
}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decode::Decode;
    use rstest::rstest;

    // Covers every map header width (fixmap, map16, map32) and the empty map, and checks
    // that bytes following the map are left unread.
    #[rstest]
    #[case(&[0x82, 0x01, 0x0a, 0x02, 0x14], vec![(1u8, 10u8), (2, 20)], &[])]
    #[case(&[0xde, 0x00, 0x02, 0x01, 0x0a, 0x02, 0x14], vec![(1u8, 10u8), (2, 20)], &[])]
    #[case(&[0xdf, 0x00, 0x00, 0x00, 0x01, 0x03, 0x1e], vec![(3u8, 30u8)], &[])]
    #[case(&[0x80], vec![], &[])]
    #[case(&[0x81, 0x01, 0x02, 0xc0], vec![(1u8, 2u8)], &[0xc0])]
    fn map_decode_success(
        #[case] buf: &[u8],
        #[case] expect: Vec<(u8, u8)>,
        #[case] rest_expect: &[u8],
    ) {
        let mut r = crate::io::SliceReader::new(buf);
        let decoded = MapDecoder::<Vec<(u8, u8)>, u8, u8>::decode(&mut r).unwrap();
        assert_eq!(decoded, expect);
        assert_eq!(r.rest(), rest_expect);
    }

    #[test]
    fn map_decoder_unexpected_format() {
        // 0x91 is a one-element fixarray header, not a map header.
        let buf = &[0x91, 0x00];
        let mut r = crate::io::SliceReader::new(buf);
        let err = MapDecoder::<Vec<(u8, u8)>, u8, u8>::decode(&mut r).unwrap_err();
        assert!(matches!(err, Error::UnexpectedFormat));
    }

    #[test]
    fn map_decode_eof_on_key() {
        let buf = &[0x81];
        let mut r = crate::io::SliceReader::new(buf);
        let err = MapDecoder::<Vec<(u8, u8)>, u8, u8>::decode(&mut r).unwrap_err();
        assert!(matches!(err, Error::Io(_)));
    }

    #[test]
    fn map_decode_key_unexpected_format() {
        // The key is a one-byte string ("a"), which the u8 key decoder rejects.
        let buf = &[0x81, 0xa1, b'a', 0x02];
        let mut r = crate::io::SliceReader::new(buf);
        let err = MapDecoder::<Vec<(u8, u8)>, u8, u8>::decode(&mut r).unwrap_err();
        assert!(matches!(err, Error::UnexpectedFormat));
    }

    #[test]
    fn map_decode_value_error_after_first_pair() {
        // The first pair decodes; the second key is present but its value is missing.
        let buf = &[0x82, 0x01, 0x01, 0x02];
        let mut r = crate::io::SliceReader::new(buf);
        let err = MapDecoder::<Vec<(u8, u8)>, u8, u8>::decode(&mut r).unwrap_err();
        assert!(matches!(err, Error::Io(_)));
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn btreemap_decode_success() {
        let buf = &[0x82, 0x01, 0x0a, 0x02, 0x14];
        let mut r = crate::io::SliceReader::new(buf);
        let m = <alloc::collections::BTreeMap<u8, u8> as Decode>::decode(&mut r).unwrap();
        assert_eq!(m.len(), 2);
        assert_eq!(m.get(&1), Some(&10));
        assert_eq!(m.get(&2), Some(&20));
        assert!(r.rest().is_empty());
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn btreemap_decode_propagates_error() {
        // Second value is missing: the error must surface instead of a partial map.
        let buf = &[0x82, 0x01, 0x0a, 0x02];
        let mut r = crate::io::SliceReader::new(buf);
        let err = <alloc::collections::BTreeMap<u8, u8> as Decode>::decode(&mut r).unwrap_err();
        assert!(matches!(err, Error::Io(_)));
    }

    #[cfg(feature = "std")]
    #[test]
    fn hashmap_decode_success() {
        let buf = &[0x82, 0x01, 0xc3, 0x03, 0xc2];
        let mut r = crate::io::SliceReader::new(buf);
        let m = <std::collections::HashMap<u8, bool> as Decode>::decode(&mut r).unwrap();
        assert_eq!(m.len(), 2);
        assert_eq!(m.get(&1), Some(&true));
        assert_eq!(m.get(&3), Some(&false));
        assert!(r.rest().is_empty());
    }
}