1use core::{cell::RefCell, marker::PhantomData, ops::Deref};
4
5use super::{Encode, Error, Result};
6use crate::{formats::Format, io::IoWrite};
7
8pub trait KVEncode {
10 fn encode_kv<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error>;
12}
13
14impl<KV: KVEncode> KVEncode for &KV {
15 fn encode_kv<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error> {
16 KV::encode_kv(self, writer)
17 }
18}
19
20impl<K: Encode, V: Encode> KVEncode for (K, V) {
21 fn encode_kv<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error> {
22 let (k, v) = self;
23 let k_len = k.encode(writer)?;
24 let v_len = v.encode(writer)?;
25 Ok(k_len + v_len)
26 }
27}
28
/// Encodes only a map header: the format marker plus the entry count.
///
/// The entries themselves must be written separately, after the header.
pub struct MapFormatEncoder(pub usize);

impl MapFormatEncoder {
    /// Creates a header encoder announcing a map with `size` entries.
    pub fn new(size: usize) -> Self {
        MapFormatEncoder(size)
    }
}
37
38impl Encode for MapFormatEncoder {
39 fn encode<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error> {
40 match self.0 {
41 0x00..=0xf => {
42 let cast = self.0 as u8;
43 writer.write(&[Format::FixMap(cast).as_byte()])?;
44
45 Ok(1)
46 }
47 0x10..=0xffff => {
48 let cast = (self.0 as u16).to_be_bytes();
49 writer.write(&[Format::Map16.as_byte(), cast[0], cast[1]])?;
50
51 Ok(3)
52 }
53 0x10000..=0xffffffff => {
54 let cast = (self.0 as u32).to_be_bytes();
55 writer.write(&[Format::Map32.as_byte(), cast[0], cast[1], cast[2], cast[3]])?;
56
57 Ok(5)
58 }
59 _ => Err(Error::InvalidFormat),
60 }
61 }
62}
63
64pub struct MapDataEncoder<I, J, KV> {
66 data: RefCell<J>,
67 _phantom: PhantomData<(I, J, KV)>,
68}
69
70impl<I, KV> MapDataEncoder<I, I::IntoIter, KV>
71where
72 I: IntoIterator<Item = KV>,
73{
74 pub fn new(data: I) -> Self {
76 Self {
77 data: RefCell::new(data.into_iter()),
78 _phantom: Default::default(),
79 }
80 }
81}
82
83impl<I, J, KV> Encode for MapDataEncoder<I, J, KV>
84where
85 J: Iterator<Item = KV>,
86 KV: KVEncode,
87{
88 fn encode<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error> {
89 let map_len = self
90 .data
91 .borrow_mut()
92 .by_ref()
93 .map(|kv| kv.encode_kv(writer))
94 .try_fold(0, |acc, v| v.map(|n| acc + n))?;
95 Ok(map_len)
96 }
97}
98
99fn encode_iter<W, I>(writer: &mut W, len: usize, it: I) -> Result<usize, W::Error>
100where
101 W: IoWrite,
102 I: Iterator,
103 I::Item: KVEncode,
104{
105 let format_len = MapFormatEncoder::new(len).encode(writer)?;
106 let data_len = it
107 .map(|kv| kv.encode_kv(writer))
108 .try_fold(0, |acc, v| v.map(|n| acc + n))?;
109 Ok(format_len + data_len)
110}
111
112pub struct MapSliceEncoder<'data, KV> {
114 data: &'data [KV],
115 _phantom: PhantomData<KV>,
116}
117
118impl<'data, KV> MapSliceEncoder<'data, KV> {
119 pub fn new(data: &'data [KV]) -> Self {
121 Self {
122 data,
123 _phantom: Default::default(),
124 }
125 }
126}
127
128impl<'data, KV> Deref for MapSliceEncoder<'data, KV> {
129 type Target = &'data [KV];
130 fn deref(&self) -> &Self::Target {
131 &self.data
132 }
133}
134
135impl<KV: KVEncode> Encode for MapSliceEncoder<'_, KV> {
136 fn encode<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error> {
137 encode_iter(writer, self.data.len(), self.data.iter())
138 }
139}
140
/// Encodes a `BTreeMap` as a map. Entries are emitted in the map's iteration
/// order, i.e. ascending key order.
#[cfg(feature = "alloc")]
impl<K: Encode + Ord, V: Encode> Encode for alloc::collections::BTreeMap<K, V> {
    fn encode<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error> {
        let entries = self.iter();
        encode_iter(writer, entries.len(), entries)
    }
}
147
/// Encodes a `HashMap` as a map. Entries are emitted in the map's iteration
/// order, which is unspecified.
#[cfg(feature = "std")]
impl<K, V, S> Encode for std::collections::HashMap<K, V, S>
where
    K: Encode + Eq + core::hash::Hash,
    V: Encode,
    S: std::hash::BuildHasher,
{
    fn encode<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error> {
        let entries = self.iter();
        encode_iter(writer, entries.len(), entries)
    }
}
159
160pub struct MapEncoder<I, J, KV> {
162 map: RefCell<J>,
163 _phantom: PhantomData<(I, J, KV)>,
164}
165
166impl<I, KV> MapEncoder<I, I::IntoIter, KV>
167where
168 I: IntoIterator<Item = KV>,
169 KV: KVEncode,
170{
171 pub fn new(map: I) -> Self {
173 Self {
174 map: RefCell::new(map.into_iter()),
175 _phantom: Default::default(),
176 }
177 }
178}
179
180impl<I, J, KV> Encode for MapEncoder<I, J, KV>
181where
182 J: Iterator<Item = KV> + ExactSizeIterator,
183 KV: KVEncode,
184{
185 fn encode<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error> {
186 let self_len = self.map.borrow().len();
187 let format_len = MapFormatEncoder::new(self_len).encode(writer)?;
188 let map_len = MapDataEncoder::new(self.map.borrow_mut().by_ref()).encode(writer)?;
189
190 Ok(format_len + map_len)
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encode::int::EncodeMinimizeInt;
    use rstest::rstest;

    // Header size boundaries: fixmap holds 0..=15 entries, map16 up to
    // 0xffff, map32 up to 0xffff_ffff.
    #[rstest]
    #[case(0, [0x80])]
    #[case(15, [0x8f])]
    #[case(16, [0xde, 0x00, 0x10])]
    #[case(0xffff, [0xde, 0xff, 0xff])]
    #[case(0x1_0000, [0xdf, 0x00, 0x01, 0x00, 0x00])]
    #[case(0xffff_ffff, [0xdf, 0xff, 0xff, 0xff, 0xff])]
    fn encode_map_format_boundaries<E>(#[case] len: usize, #[case] expected: E)
    where
        E: AsRef<[u8]> + Sized,
    {
        let expected = expected.as_ref();
        let mut buf = vec![];
        let n = MapFormatEncoder::new(len).encode(&mut buf).unwrap();
        assert_eq!(buf, expected);
        assert_eq!(n, expected.len());
    }

    #[test]
    fn encode_map_format_rejects_oversized_len() {
        // Only representable where usize is wider than u32.
        if let Some(len) = (u32::MAX as usize).checked_add(1) {
            let mut buf = vec![];
            assert!(MapFormatEncoder::new(len).encode(&mut buf).is_err());
            assert!(buf.is_empty());
        }
    }

    #[test]
    fn encode_map_data_has_no_header_and_drains() {
        let encoder = MapDataEncoder::new([(1u8, 2u8)]);

        let mut buf = vec![];
        let n = encoder.encode(&mut buf).unwrap();
        assert_eq!(buf, &[0x01, 0x02]);
        assert_eq!(n, 2);

        // The entries were consumed by the first call.
        let mut again = vec![];
        assert_eq!(encoder.encode(&mut again).unwrap(), 0);
        assert!(again.is_empty());
    }

    #[rstest]
    #[case([("123", EncodeMinimizeInt(123)), ("456", EncodeMinimizeInt(456))], [0x82, 0xa3, 0x31, 0x32, 0x33, 0x7b, 0xa3, 0x34, 0x35, 0x36, 0xcd, 0x01, 0xc8])]
    fn encode_slice_fix_map<K, V, Map, E>(#[case] value: Map, #[case] expected: E)
    where
        K: Encode,
        V: Encode,
        Map: AsRef<[(K, V)]>,
        E: AsRef<[u8]> + Sized,
    {
        let expected = expected.as_ref();
        let encoder = MapSliceEncoder::new(value.as_ref());

        let mut buf = vec![];
        let n = encoder.encode(&mut buf).unwrap();
        assert_eq!(buf, expected);
        assert_eq!(n, expected.len());
    }

    #[rstest]
    #[case([("123", EncodeMinimizeInt(123)), ("456", EncodeMinimizeInt(456))], [0x82, 0xa3, 0x31, 0x32, 0x33, 0x7b, 0xa3, 0x34, 0x35, 0x36, 0xcd, 0x01, 0xc8])]
    fn encode_iter_fix_map<I, KV, E>(#[case] value: I, #[case] expected: E)
    where
        I: IntoIterator<Item = KV>,
        I::IntoIter: ExactSizeIterator,
        KV: KVEncode,
        E: AsRef<[u8]> + Sized,
    {
        let expected = expected.as_ref();

        let encoder = MapEncoder::new(value.into_iter());
        let mut buf = vec![];
        let n = encoder.encode(&mut buf).unwrap();
        assert_eq!(buf, expected);
        assert_eq!(n, expected.len());
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn encode_btreemap_sorted() {
        let mut m = alloc::collections::BTreeMap::new();
        m.insert(2u8, 20u8);
        m.insert(1u8, 10u8);

        let mut buf = alloc::vec::Vec::new();
        let n = m.encode(&mut buf).unwrap();

        // Keys come out ascending regardless of insertion order, and the
        // reported byte count covers the whole buffer.
        assert_eq!(buf, [0x82, 0x01, 0x0a, 0x02, 0x14]);
        assert_eq!(n, buf.len());
    }

    #[cfg(feature = "std")]
    #[test]
    fn encode_hashmap_roundtrip() {
        use crate::decode::Decode;

        let mut m = std::collections::HashMap::<u8, bool>::new();
        m.insert(1, true);
        m.insert(3, false);

        let mut buf = Vec::new();
        let n = m.encode(&mut buf).unwrap();
        assert_eq!(n, buf.len());

        let mut r = crate::io::SliceReader::new(&buf);
        let back = <std::collections::HashMap<u8, bool> as Decode>::decode(&mut r).unwrap();
        assert_eq!(back.len(), 2);
        assert_eq!(back.get(&1), Some(&true));
        assert_eq!(back.get(&3), Some(&false));
    }
}