Skip to main content

messagepack_core/encode/
map.rs

1//! Map encoders.
2
3use core::{cell::RefCell, marker::PhantomData, ops::Deref};
4
5use super::{Encode, Error, Result};
6use crate::{formats::Format, io::IoWrite};
7
8/// A key-value encoder that writes a single `key, value` pair.
9pub trait KVEncode {
10    /// Encode this key‑value pair to the writer and return the number of bytes written.
11    fn encode_kv<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error>;
12}
13
14impl<KV: KVEncode> KVEncode for &KV {
15    fn encode_kv<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error> {
16        KV::encode_kv(self, writer)
17    }
18}
19
20impl<K: Encode, V: Encode> KVEncode for (K, V) {
21    fn encode_kv<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error> {
22        let (k, v) = self;
23        let k_len = k.encode(writer)?;
24        let v_len = v.encode(writer)?;
25        Ok(k_len + v_len)
26    }
27}
28
/// Encodes only the MessagePack map header for a map of a given size.
///
/// The wrapped `usize` is the number of key-value pairs. No entries are
/// written; combine this with an entry encoder such as [`MapDataEncoder`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MapFormatEncoder(pub usize);

impl MapFormatEncoder {
    /// Creates a header encoder for a map holding `size` key-value pairs.
    pub fn new(size: usize) -> Self {
        Self(size)
    }
}
37
38impl Encode for MapFormatEncoder {
39    fn encode<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error> {
40        match self.0 {
41            0x00..=0xf => {
42                let cast = self.0 as u8;
43                writer.write(&[Format::FixMap(cast).as_byte()])?;
44
45                Ok(1)
46            }
47            0x10..=0xffff => {
48                let cast = (self.0 as u16).to_be_bytes();
49                writer.write(&[Format::Map16.as_byte(), cast[0], cast[1]])?;
50
51                Ok(3)
52            }
53            0x10000..=0xffffffff => {
54                let cast = (self.0 as u32).to_be_bytes();
55                writer.write(&[Format::Map32.as_byte(), cast[0], cast[1], cast[2], cast[3]])?;
56
57                Ok(5)
58            }
59            _ => Err(Error::InvalidFormat),
60        }
61    }
62}
63
/// Encodes key-value pairs pulled from an iterator, *without* a map header.
///
/// Only the entries are written. Emit the header separately (e.g. with
/// [`MapFormatEncoder`]), or use [`MapEncoder`] to get both. The iterator sits
/// in a `RefCell` so that `encode(&self, ..)` can advance it. As a result, the
/// first `encode` call drains the pairs, and later calls only see whatever the
/// iterator still yields (normally nothing).
pub struct MapDataEncoder<I, J, KV> {
    data: RefCell<J>,
    _phantom: PhantomData<(I, J, KV)>,
}

impl<I, KV> MapDataEncoder<I, I::IntoIter, KV>
where
    I: IntoIterator<Item = KV>,
{
    /// Wraps any iterable of key-value pairs. `into_iter` is called here, but
    /// no pair is read until encoding.
    pub fn new(data: I) -> Self {
        let iter = data.into_iter();
        MapDataEncoder {
            data: RefCell::new(iter),
            _phantom: PhantomData,
        }
    }
}
82
83impl<I, J, KV> Encode for MapDataEncoder<I, J, KV>
84where
85    J: Iterator<Item = KV>,
86    KV: KVEncode,
87{
88    fn encode<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error> {
89        let map_len = self
90            .data
91            .borrow_mut()
92            .by_ref()
93            .map(|kv| kv.encode_kv(writer))
94            .try_fold(0, |acc, v| v.map(|n| acc + n))?;
95        Ok(map_len)
96    }
97}
98
99fn encode_iter<W, I>(writer: &mut W, len: usize, it: I) -> Result<usize, W::Error>
100where
101    W: IoWrite,
102    I: Iterator,
103    I::Item: KVEncode,
104{
105    let format_len = MapFormatEncoder::new(len).encode(writer)?;
106    let data_len = it
107        .map(|kv| kv.encode_kv(writer))
108        .try_fold(0, |acc, v| v.map(|n| acc + n))?;
109    Ok(format_len + data_len)
110}
111
/// Encodes a borrowed slice of key-value pairs as a map.
///
/// The encoder dereferences to the wrapped slice reference, so slice methods
/// such as `len` can be called on it directly.
pub struct MapSliceEncoder<'data, KV> {
    data: &'data [KV],
    _phantom: PhantomData<KV>,
}

impl<'data, KV> MapSliceEncoder<'data, KV> {
    /// Borrows `data`; nothing is written until the encoder is encoded.
    pub fn new(data: &'data [KV]) -> Self {
        MapSliceEncoder {
            _phantom: PhantomData,
            data,
        }
    }
}

impl<'data, KV> Deref for MapSliceEncoder<'data, KV> {
    /// The target is the slice *reference* `&[KV]`, not the unsized `[KV]`.
    type Target = &'data [KV];

    fn deref(&self) -> &&'data [KV] {
        &self.data
    }
}
134
135impl<KV: KVEncode> Encode for MapSliceEncoder<'_, KV> {
136    fn encode<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error> {
137        encode_iter(writer, self.data.len(), self.data.iter())
138    }
139}
140
#[cfg(feature = "alloc")]
/// Writes a map header followed by the entries in the map's iteration order,
/// which for a `BTreeMap` is ascending by key.
impl<K, V> Encode for alloc::collections::BTreeMap<K, V>
where
    K: Encode + Ord,
    V: Encode,
{
    fn encode<W: IoWrite>(&self, writer: &mut W) -> Result<usize, <W as IoWrite>::Error> {
        let entries = self.iter();
        encode_iter(writer, self.len(), entries)
    }
}
147
#[cfg(feature = "std")]
/// Writes a map header followed by the entries in the `HashMap`'s iteration
/// order. That order is unspecified; with the default `RandomState` hasher it
/// can differ between runs, so the byte output is not reproducible.
impl<K, V, S> Encode for std::collections::HashMap<K, V, S>
where
    K: Encode + Eq + core::hash::Hash,
    V: Encode,
    S: std::hash::BuildHasher,
{
    fn encode<W: IoWrite>(&self, writer: &mut W) -> Result<usize, <W as IoWrite>::Error> {
        let count = self.len();
        encode_iter(writer, count, self.iter())
    }
}
159
/// Encodes a complete map (header plus entries) from an exact-size iterator,
/// pulling each pair lazily while writing.
///
/// The iterator is kept in a `RefCell` so `encode(&self, ..)` can advance it.
/// The pairs are therefore consumed by the first `encode` call.
pub struct MapEncoder<I, J, KV> {
    map: RefCell<J>,
    _phantom: PhantomData<(I, J, KV)>,
}
165
166impl<I, KV> MapEncoder<I, I::IntoIter, KV>
167where
168    I: IntoIterator<Item = KV>,
169    KV: KVEncode,
170{
171    /// Construct from any iterable of key-value pairs.
172    pub fn new(map: I) -> Self {
173        Self {
174            map: RefCell::new(map.into_iter()),
175            _phantom: Default::default(),
176        }
177    }
178}
179
180impl<I, J, KV> Encode for MapEncoder<I, J, KV>
181where
182    J: Iterator<Item = KV> + ExactSizeIterator,
183    KV: KVEncode,
184{
185    fn encode<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error> {
186        let self_len = self.map.borrow().len();
187        let format_len = MapFormatEncoder::new(self_len).encode(writer)?;
188        let map_len = MapDataEncoder::new(self.map.borrow_mut().by_ref()).encode(writer)?;
189
190        Ok(format_len + map_len)
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encode::int::EncodeMinimizeInt;
    use rstest::rstest;

    // 0x82 is a fixmap header for two pairs. These encoders emit maps, hence
    // the `fix_map` suffix in the test names.
    #[rstest]
    #[case([("123", EncodeMinimizeInt(123)), ("456", EncodeMinimizeInt(456))], [0x82, 0xa3, 0x31, 0x32, 0x33, 0x7b, 0xa3, 0x34, 0x35, 0x36, 0xcd, 0x01, 0xc8])]
    fn encode_slice_fix_map<K, V, Map, E>(#[case] value: Map, #[case] expected: E)
    where
        K: Encode,
        V: Encode,
        Map: AsRef<[(K, V)]>,
        E: AsRef<[u8]> + Sized,
    {
        let expected = expected.as_ref();
        let encoder = MapSliceEncoder::new(value.as_ref());

        let mut buf = vec![];
        let n = encoder.encode(&mut buf).unwrap();
        assert_eq!(buf, expected);
        assert_eq!(n, expected.len());
    }

    #[rstest]
    #[case([("123", EncodeMinimizeInt(123)), ("456", EncodeMinimizeInt(456))], [0x82, 0xa3, 0x31, 0x32, 0x33, 0x7b, 0xa3, 0x34, 0x35, 0x36, 0xcd, 0x01, 0xc8])]
    fn encode_iter_fix_map<I, KV, E>(#[case] value: I, #[case] expected: E)
    where
        I: IntoIterator<Item = KV>,
        I::IntoIter: ExactSizeIterator,
        KV: KVEncode,
        E: AsRef<[u8]> + Sized,
    {
        let expected = expected.as_ref();

        let encoder = MapEncoder::new(value.into_iter());
        let mut buf = vec![];
        let n = encoder.encode(&mut buf).unwrap();
        assert_eq!(buf, expected);
        assert_eq!(n, expected.len());
    }

    // The header must switch format exactly at the fixmap / map16 / map32 limits.
    #[rstest]
    #[case(0, &[0x80])]
    #[case(0x0f, &[0x8f])]
    #[case(0x10, &[0xde, 0x00, 0x10])]
    #[case(0xffff, &[0xde, 0xff, 0xff])]
    #[case(0x1_0000, &[0xdf, 0x00, 0x01, 0x00, 0x00])]
    #[case(0xffff_ffff, &[0xdf, 0xff, 0xff, 0xff, 0xff])]
    fn encode_map_header_boundaries(#[case] len: usize, #[case] expected: &[u8]) {
        let mut buf = vec![];
        let n = MapFormatEncoder::new(len).encode(&mut buf).unwrap();
        assert_eq!(buf, expected);
        assert_eq!(n, expected.len());
    }

    // Counts above u32::MAX have no header format: error out and write nothing.
    #[cfg(target_pointer_width = "64")]
    #[test]
    fn encode_map_header_rejects_more_than_u32_max_pairs() {
        let mut buf = vec![];
        let result = MapFormatEncoder::new(0x1_0000_0000).encode(&mut buf);
        assert!(matches!(result, Err(Error::InvalidFormat)));
        let empty: &[u8] = &[];
        assert_eq!(buf, empty);
    }

    // MapDataEncoder writes the entries only (no header) and drains its iterator.
    #[test]
    fn encode_data_writes_pairs_without_header() {
        let encoder = MapDataEncoder::new([("a", EncodeMinimizeInt(1))]);
        let mut buf = vec![];
        let n = encoder.encode(&mut buf).unwrap();
        let expected: &[u8] = &[0xa1, 0x61, 0x01]; // fixstr "a", fixint 1
        assert_eq!(buf, expected);
        assert_eq!(n, expected.len());

        // The first call consumed the pairs, so a second call writes nothing.
        buf.clear();
        assert_eq!(encoder.encode(&mut buf).unwrap(), 0);
        assert!(buf.is_empty());
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn encode_btreemap_sorted() {
        let mut m = alloc::collections::BTreeMap::new();
        m.insert(2u8, 20u8);
        m.insert(1u8, 10u8);

        let mut buf = alloc::vec::Vec::new();
        let n = m.encode(&mut buf).unwrap();

        // Expect keys encoded in sorted order: 1, 2
        assert_eq!(
            &buf[..n],
            &[0x82, 0x01, 0x0a, 0x02, 0x14] // fixmap(2) {1:10, 2:20}
        );
    }

    #[cfg(feature = "std")]
    #[test]
    fn encode_hashmap_roundtrip() {
        use crate::decode::Decode;

        let mut m = std::collections::HashMap::<u8, bool>::new();
        m.insert(1, true);
        m.insert(3, false);

        let mut buf = Vec::new();
        let _ = m.encode(&mut buf).unwrap();

        // Roundtrip decode to HashMap and check contents regardless of order
        let mut r = crate::io::SliceReader::new(&buf);
        let back = <std::collections::HashMap<u8, bool> as Decode>::decode(&mut r).unwrap();
        assert_eq!(back.len(), 2);
        assert_eq!(back.get(&1), Some(&true));
        assert_eq!(back.get(&3), Some(&false));
    }
}