1use core::{cell::RefCell, marker::PhantomData, ops::Deref};
4
5use super::{Encode, Error, Result};
6use crate::{formats::Format, io::IoWrite};
7
8pub trait KVEncode<W>
10where
11 W: IoWrite,
12{
13 fn encode(&self, writer: &mut W) -> Result<usize, W::Error>;
15}
16
17impl<W: IoWrite, KV: KVEncode<W>> KVEncode<W> for &KV {
18 fn encode(&self, writer: &mut W) -> Result<usize, W::Error> {
19 KV::encode(self, writer)
20 }
21}
22
23impl<W: IoWrite, K: Encode<W>, V: Encode<W>> KVEncode<W> for (K, V) {
24 fn encode(&self, writer: &mut W) -> Result<usize, W::Error> {
25 let (k, v) = self;
26 let k_len = k.encode(writer)?;
27 let v_len = v.encode(writer)?;
28 Ok(k_len + v_len)
29 }
30}
31
/// Encoder for just the header of a map: the marker byte chosen from the entry
/// count, followed (for the 16/32-bit forms) by the count itself in big-endian order.
///
/// The wrapped `usize` is the number of key/value entries the map will contain.
pub struct MapFormatEncoder(pub usize);

impl MapFormatEncoder {
    /// Creates a header encoder announcing `size` entries.
    pub fn new(size: usize) -> Self {
        MapFormatEncoder(size)
    }
}
40
41impl<W: IoWrite> Encode<W> for MapFormatEncoder {
42 fn encode(&self, writer: &mut W) -> Result<usize, W::Error> {
43 match self.0 {
44 0x00..=0xf => {
45 let cast = self.0 as u8;
46 writer.write(&[Format::FixMap(cast).as_byte()])?;
47
48 Ok(1)
49 }
50 0x10..=0xffff => {
51 let cast = (self.0 as u16).to_be_bytes();
52 writer.write(&[Format::Map16.as_byte(), cast[0], cast[1]])?;
53
54 Ok(3)
55 }
56 0x10000..=0xffffffff => {
57 let cast = (self.0 as u32).to_be_bytes();
58 writer.write(&[Format::Map32.as_byte(), cast[0], cast[1], cast[2], cast[3]])?;
59
60 Ok(5)
61 }
62 _ => Err(Error::InvalidFormat),
63 }
64 }
65}
66
/// Encoder for the entries of a map only (no header), pulled from an iterator `J`.
///
/// The iterator sits in a `RefCell` so that `encode(&self, ..)` can advance it;
/// encoding therefore consumes the entries.
pub struct MapDataEncoder<I, J, KV> {
    data: RefCell<J>,
    _phantom: PhantomData<(I, J, KV)>,
}

impl<I, KV> MapDataEncoder<I, I::IntoIter, KV>
where
    I: IntoIterator<Item = KV>,
{
    /// Converts `data` into its iterator right away and stores it for later encoding.
    pub fn new(data: I) -> Self {
        let entries = data.into_iter();
        MapDataEncoder {
            data: RefCell::new(entries),
            _phantom: PhantomData,
        }
    }
}
85
86impl<W, I, J, KV> Encode<W> for MapDataEncoder<I, J, KV>
87where
88 W: IoWrite,
89 J: Iterator<Item = KV>,
90 KV: KVEncode<W>,
91{
92 fn encode(&self, writer: &mut W) -> Result<usize, W::Error> {
93 let map_len = self
94 .data
95 .borrow_mut()
96 .by_ref()
97 .map(|kv| kv.encode(writer))
98 .try_fold(0, |acc, v| v.map(|n| acc + n))?;
99 Ok(map_len)
100 }
101}
102
103fn encode_iter<W, I>(writer: &mut W, len: usize, it: I) -> Result<usize, W::Error>
104where
105 W: IoWrite,
106 I: Iterator,
107 I::Item: KVEncode<W>,
108{
109 let format_len = MapFormatEncoder::new(len).encode(writer)?;
110 let data_len = it
111 .map(|kv| kv.encode(writer))
112 .try_fold(0, |acc, v| v.map(|n| acc + n))?;
113 Ok(format_len + data_len)
114}
115
/// Encoder that writes a borrowed slice of key/value entries as a complete map
/// (header sized by the slice length, then each entry in slice order).
///
/// The slice is only borrowed and never consumed, so the same encoder can be
/// encoded more than once.
pub struct MapSliceEncoder<'data, KV> {
    // `KV` is already used by this field, so no `PhantomData` marker is needed to
    // tie the type parameter to the struct.
    data: &'data [KV],
}

impl<'data, KV> MapSliceEncoder<'data, KV> {
    /// Wraps `data` without copying it.
    pub fn new(data: &'data [KV]) -> Self {
        Self { data }
    }
}
131
132impl<'data, KV> Deref for MapSliceEncoder<'data, KV> {
133 type Target = &'data [KV];
134 fn deref(&self) -> &Self::Target {
135 &self.data
136 }
137}
138
139impl<W, KV> Encode<W> for MapSliceEncoder<'_, KV>
140where
141 W: IoWrite,
142 KV: KVEncode<W>,
143{
144 fn encode(&self, writer: &mut W) -> Result<usize, W::Error> {
145 encode_iter(writer, self.data.len(), self.data.iter())
146 }
147}
148
#[cfg(feature = "alloc")]
impl<W, K, V> Encode<W> for alloc::collections::BTreeMap<K, V>
where
    W: IoWrite,
    K: Encode<W> + Ord,
    V: Encode<W>,
{
    /// Encodes the map as a header for its entry count followed by each
    /// `(&key, &value)` pair, in the ascending key order `BTreeMap::iter` yields.
    fn encode(&self, writer: &mut W) -> Result<usize, <W as IoWrite>::Error> {
        let entries = self.iter();
        let count = entries.len();
        encode_iter(writer, count, entries)
    }
}
160
#[cfg(feature = "std")]
impl<W, K, V, S> Encode<W> for std::collections::HashMap<K, V, S>
where
    W: IoWrite,
    K: Encode<W> + Eq + core::hash::Hash,
    V: Encode<W>,
    S: std::hash::BuildHasher,
{
    /// Encodes the map as a header for its entry count followed by each
    /// `(&key, &value)` pair. Entries appear in `HashMap`'s arbitrary iteration order.
    fn encode(&self, writer: &mut W) -> Result<usize, <W as IoWrite>::Error> {
        let entries = self.iter();
        let count = entries.len();
        encode_iter(writer, count, entries)
    }
}
173
174pub struct MapEncoder<W, I, J, KV> {
176 map: RefCell<J>,
177 _phantom: PhantomData<(W, I, J, KV)>,
178}
179
180impl<W, I, KV> MapEncoder<W, I, I::IntoIter, KV>
181where
182 W: IoWrite,
183 I: IntoIterator<Item = KV>,
184 KV: KVEncode<W>,
185{
186 pub fn new(map: I) -> Self {
188 Self {
189 map: RefCell::new(map.into_iter()),
190 _phantom: Default::default(),
191 }
192 }
193}
194
195impl<W, I, J, KV> Encode<W> for MapEncoder<W, I, J, KV>
196where
197 W: IoWrite,
198 J: Iterator<Item = KV> + ExactSizeIterator,
199 KV: KVEncode<W>,
200{
201 fn encode(&self, writer: &mut W) -> Result<usize, W::Error> {
202 let self_len = self.map.borrow().len();
203 let format_len = MapFormatEncoder::new(self_len).encode(writer)?;
204 let map_len = MapDataEncoder::new(self.map.borrow_mut().by_ref()).encode(writer)?;
205
206 Ok(format_len + map_len)
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encode::int::EncodeMinimizeInt;
    use rstest::rstest;

    /// A two-entry slice is written as a FixMap header (0x82) followed by each key/value.
    #[rstest]
    #[case([("123", EncodeMinimizeInt(123)), ("456", EncodeMinimizeInt(456))], [0x82, 0xa3, 0x31, 0x32, 0x33, 0x7b, 0xa3, 0x34, 0x35, 0x36, 0xcd, 0x01, 0xc8])]
    fn encode_slice_fix_map<K, V, Map, E>(#[case] value: Map, #[case] expected: E)
    where
        K: Encode<Vec<u8>>,
        V: Encode<Vec<u8>>,
        Map: AsRef<[(K, V)]>,
        E: AsRef<[u8]> + Sized,
    {
        let expected = expected.as_ref();
        let encoder = MapSliceEncoder::new(value.as_ref());

        let mut buf = vec![];
        let n = encoder.encode(&mut buf).unwrap();
        assert_eq!(buf, expected);
        assert_eq!(n, expected.len());
    }

    /// Same payload as above, sourced from an exact-size iterator through `MapEncoder`.
    #[rstest]
    #[case([("123", EncodeMinimizeInt(123)), ("456", EncodeMinimizeInt(456))], [0x82, 0xa3, 0x31, 0x32, 0x33, 0x7b, 0xa3, 0x34, 0x35, 0x36, 0xcd, 0x01, 0xc8])]
    fn encode_iter_fix_map<I, KV, E>(#[case] value: I, #[case] expected: E)
    where
        I: IntoIterator<Item = KV>,
        I::IntoIter: ExactSizeIterator,
        KV: KVEncode<Vec<u8>>,
        E: AsRef<[u8]> + Sized,
    {
        let expected = expected.as_ref();

        let encoder = MapEncoder::new(value.into_iter());
        let mut buf = vec![];
        let n = encoder.encode(&mut buf).unwrap();
        assert_eq!(buf, expected);
        assert_eq!(n, expected.len());
    }

    /// The header switches FixMap -> Map16 at 16 entries and Map16 -> Map32 at
    /// 0x10000 entries; the count is written big-endian.
    #[rstest]
    #[case(0x00, vec![0x80])]
    #[case(0x0f, vec![0x8f])]
    #[case(0x10, vec![0xde, 0x00, 0x10])]
    #[case(0xffff, vec![0xde, 0xff, 0xff])]
    #[case(0x1_0000, vec![0xdf, 0x00, 0x01, 0x00, 0x00])]
    fn encode_map_format_boundaries(#[case] len: usize, #[case] expected: Vec<u8>) {
        let mut buf: Vec<u8> = vec![];
        let n = MapFormatEncoder::new(len).encode(&mut buf).unwrap();
        assert_eq!(buf, expected);
        assert_eq!(n, expected.len());
    }

    /// Counts that do not fit in 32 bits are rejected before anything is written.
    #[cfg(target_pointer_width = "64")]
    #[test]
    fn encode_map_format_rejects_len_above_u32() {
        let mut buf: Vec<u8> = vec![];
        assert!(MapFormatEncoder::new(0x1_0000_0000).encode(&mut buf).is_err());
        assert!(buf.is_empty());
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn encode_btreemap_sorted() {
        let mut m = alloc::collections::BTreeMap::new();
        m.insert(2u8, 20u8);
        m.insert(1u8, 10u8);

        let mut buf: alloc::vec::Vec<u8> = alloc::vec::Vec::new();
        let n = m.encode(&mut buf).unwrap();

        // Keys come out in ascending order regardless of insertion order, and the
        // reported byte count must cover exactly what was written.
        assert_eq!(buf, [0x82u8, 0x01, 0x0a, 0x02, 0x14]);
        assert_eq!(n, buf.len());
    }

    #[cfg(feature = "std")]
    #[test]
    fn encode_hashmap_roundtrip() {
        use crate::decode::Decode;

        let mut m = std::collections::HashMap::<u8, bool>::new();
        m.insert(1, true);
        m.insert(3, false);

        let mut buf: Vec<u8> = Vec::new();
        let n = m.encode(&mut buf).unwrap();
        assert_eq!(n, buf.len());

        let mut r = crate::io::SliceReader::new(&buf);
        let back = <std::collections::HashMap<u8, bool> as Decode>::decode(&mut r).unwrap();
        assert_eq!(back.len(), 2);
        assert_eq!(back.get(&1), Some(&true));
        assert_eq!(back.get(&3), Some(&false));
    }
}