messagepack_core/encode/
map.rs

1//! Map encoders.
2
3use core::{cell::RefCell, marker::PhantomData, ops::Deref};
4
5use super::{Encode, Error, Result};
6use crate::{formats::Format, io::IoWrite};
7
8/// A key-value encoder that writes a single `key, value` pair.
9pub trait KVEncode<W>
10where
11    W: IoWrite,
12{
13    /// Encode this key‑value pair to the writer and return the number of bytes written.
14    fn encode(&self, writer: &mut W) -> Result<usize, W::Error>;
15}
16
17impl<W: IoWrite, KV: KVEncode<W>> KVEncode<W> for &KV {
18    fn encode(&self, writer: &mut W) -> Result<usize, W::Error> {
19        KV::encode(self, writer)
20    }
21}
22
23impl<W: IoWrite, K: Encode<W>, V: Encode<W>> KVEncode<W> for (K, V) {
24    fn encode(&self, writer: &mut W) -> Result<usize, W::Error> {
25        let (k, v) = self;
26        let k_len = k.encode(writer)?;
27        let v_len = v.encode(writer)?;
28        Ok(k_len + v_len)
29    }
30}
31
/// Writes only the MessagePack map header announcing a given number of pairs.
///
/// The wrapped value is the number of key-value pairs, not a byte length.
pub struct MapFormatEncoder(pub usize);

impl MapFormatEncoder {
    /// Creates a header encoder for a map holding `size` key-value pairs.
    pub fn new(size: usize) -> Self {
        MapFormatEncoder(size)
    }
}
40
41impl<W: IoWrite> Encode<W> for MapFormatEncoder {
42    fn encode(&self, writer: &mut W) -> Result<usize, W::Error> {
43        match self.0 {
44            0x00..=0xf => {
45                let cast = self.0 as u8;
46                writer.write(&[Format::FixMap(cast).as_byte()])?;
47
48                Ok(1)
49            }
50            0x10..=0xffff => {
51                let cast = (self.0 as u16).to_be_bytes();
52                writer.write(&[Format::Map16.as_byte(), cast[0], cast[1]])?;
53
54                Ok(3)
55            }
56            0x10000..=0xffffffff => {
57                let cast = (self.0 as u32).to_be_bytes();
58                writer.write(&[Format::Map32.as_byte(), cast[0], cast[1], cast[2], cast[3]])?;
59
60                Ok(5)
61            }
62            _ => Err(Error::InvalidFormat),
63        }
64    }
65}
66
/// Streams key-value pairs from an iterator **without** writing a map header.
///
/// The iterator sits in a `RefCell` so that `encode(&self)` can advance it.
/// Pairs yielded during one `encode` call are therefore not yielded again by a
/// later call; each call continues from wherever the iterator left off.
pub struct MapDataEncoder<I, J, KV> {
    data: RefCell<J>,
    _phantom: PhantomData<(I, J, KV)>,
}
72
73impl<I, KV> MapDataEncoder<I, I::IntoIter, KV>
74where
75    I: IntoIterator<Item = KV>,
76{
77    /// Construct from any iterable of key-value pairs.
78    pub fn new(data: I) -> Self {
79        Self {
80            data: RefCell::new(data.into_iter()),
81            _phantom: Default::default(),
82        }
83    }
84}
85
86impl<W, I, J, KV> Encode<W> for MapDataEncoder<I, J, KV>
87where
88    W: IoWrite,
89    J: Iterator<Item = KV>,
90    KV: KVEncode<W>,
91{
92    fn encode(&self, writer: &mut W) -> Result<usize, W::Error> {
93        let map_len = self
94            .data
95            .borrow_mut()
96            .by_ref()
97            .map(|kv| kv.encode(writer))
98            .try_fold(0, |acc, v| v.map(|n| acc + n))?;
99        Ok(map_len)
100    }
101}
102
103fn encode_iter<W, I>(writer: &mut W, len: usize, it: I) -> Result<usize, W::Error>
104where
105    W: IoWrite,
106    I: Iterator,
107    I::Item: KVEncode<W>,
108{
109    let format_len = MapFormatEncoder::new(len).encode(writer)?;
110    let data_len = it
111        .map(|kv| kv.encode(writer))
112        .try_fold(0, |acc, v| v.map(|n| acc + n))?;
113    Ok(format_len + data_len)
114}
115
/// Encodes a borrowed slice of key-value pairs as a complete map (header plus pairs).
pub struct MapSliceEncoder<'data, KV> {
    data: &'data [KV],
    _phantom: PhantomData<KV>,
}
121
122impl<'data, KV> MapSliceEncoder<'data, KV> {
123    /// Construct from a slice of key-value pairs.
124    pub fn new(data: &'data [KV]) -> Self {
125        Self {
126            data,
127            _phantom: Default::default(),
128        }
129    }
130}
131
132impl<'data, KV> Deref for MapSliceEncoder<'data, KV> {
133    type Target = &'data [KV];
134    fn deref(&self) -> &Self::Target {
135        &self.data
136    }
137}
138
139impl<W, KV> Encode<W> for MapSliceEncoder<'_, KV>
140where
141    W: IoWrite,
142    KV: KVEncode<W>,
143{
144    fn encode(&self, writer: &mut W) -> Result<usize, W::Error> {
145        encode_iter(writer, self.data.len(), self.data.iter())
146    }
147}
148
#[cfg(feature = "alloc")]
impl<W, K, V> Encode<W> for alloc::collections::BTreeMap<K, V>
where
    W: IoWrite,
    K: Encode<W> + Ord,
    V: Encode<W>,
{
    /// Encodes the map with entries in `BTreeMap` iteration order (ascending by key).
    fn encode(&self, writer: &mut W) -> Result<usize, W::Error> {
        let entry_count = self.len();
        let entries = self.iter();
        encode_iter(writer, entry_count, entries)
    }
}
160
#[cfg(feature = "std")]
impl<W, K, V, S> Encode<W> for std::collections::HashMap<K, V, S>
where
    W: IoWrite,
    K: Encode<W> + Eq + core::hash::Hash,
    V: Encode<W>,
    S: std::hash::BuildHasher,
{
    /// Encodes every entry; pair order follows `HashMap` iteration order, which is unspecified.
    fn encode(&self, writer: &mut W) -> Result<usize, W::Error> {
        let entries = self.iter();
        encode_iter(writer, entries.len(), entries)
    }
}
173
/// Encodes a whole map (header plus pairs) from an owned iterator, pulling pairs lazily.
///
/// The header's pair count is taken from the iterator's `ExactSizeIterator::len`
/// when `encode` runs. The iterator lives in a `RefCell` and is advanced by `encode`,
/// so pairs yielded by one call are not yielded again by a later call.
pub struct MapEncoder<W, I, J, KV> {
    map: RefCell<J>,
    _phantom: PhantomData<(W, I, J, KV)>,
}
179
180impl<W, I, KV> MapEncoder<W, I, I::IntoIter, KV>
181where
182    W: IoWrite,
183    I: IntoIterator<Item = KV>,
184    KV: KVEncode<W>,
185{
186    /// Construct from any iterable of key-value pairs.
187    pub fn new(map: I) -> Self {
188        Self {
189            map: RefCell::new(map.into_iter()),
190            _phantom: Default::default(),
191        }
192    }
193}
194
195impl<W, I, J, KV> Encode<W> for MapEncoder<W, I, J, KV>
196where
197    W: IoWrite,
198    J: Iterator<Item = KV> + ExactSizeIterator,
199    KV: KVEncode<W>,
200{
201    fn encode(&self, writer: &mut W) -> Result<usize, W::Error> {
202        let self_len = self.map.borrow().len();
203        let format_len = MapFormatEncoder::new(self_len).encode(writer)?;
204        let map_len = MapDataEncoder::new(self.map.borrow_mut().by_ref()).encode(writer)?;
205
206        Ok(format_len + map_len)
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encode::int::EncodeMinimizeInt;
    use rstest::rstest;

    // Two pairs -> fixmap(2) header 0x82, then str/int key-value encodings.
    #[rstest]
    #[case([("123", EncodeMinimizeInt(123)), ("456", EncodeMinimizeInt(456))], [0x82, 0xa3, 0x31, 0x32, 0x33, 0x7b, 0xa3, 0x34, 0x35, 0x36, 0xcd, 0x01, 0xc8])]
    fn encode_slice_fix_map<K, V, Map, E>(#[case] value: Map, #[case] expected: E)
    where
        K: Encode<Vec<u8>>,
        V: Encode<Vec<u8>>,
        Map: AsRef<[(K, V)]>,
        E: AsRef<[u8]> + Sized,
    {
        let expected = expected.as_ref();
        let encoder = MapSliceEncoder::new(value.as_ref());

        let mut buf = vec![];
        let n = encoder.encode(&mut buf).unwrap();
        assert_eq!(buf, expected);
        assert_eq!(n, expected.len());
    }

    #[rstest]
    #[case([("123", EncodeMinimizeInt(123)), ("456", EncodeMinimizeInt(456))], [0x82, 0xa3, 0x31, 0x32, 0x33, 0x7b, 0xa3, 0x34, 0x35, 0x36, 0xcd, 0x01, 0xc8])]
    fn encode_iter_fix_map<I, KV, E>(#[case] value: I, #[case] expected: E)
    where
        I: IntoIterator<Item = KV>,
        I::IntoIter: ExactSizeIterator,
        KV: KVEncode<Vec<u8>>,
        E: AsRef<[u8]> + Sized,
    {
        let expected = expected.as_ref();

        let encoder = MapEncoder::new(value.into_iter());
        let mut buf = vec![];
        let n = encoder.encode(&mut buf).unwrap();
        assert_eq!(buf, expected);
        assert_eq!(n, expected.len());
    }

    // Header boundaries: fixmap (0x80 | n), map16 (0xde + u16 BE), map32 (0xdf + u32 BE).
    #[rstest]
    #[case(0, [0x80])]
    #[case(0x0f, [0x8f])]
    #[case(0x10, [0xde, 0x00, 0x10])]
    #[case(0xffff, [0xde, 0xff, 0xff])]
    #[case(0x1_0000, [0xdf, 0x00, 0x01, 0x00, 0x00])]
    #[case(0xffff_ffff, [0xdf, 0xff, 0xff, 0xff, 0xff])]
    fn encode_map_header<E>(#[case] len: usize, #[case] expected: E)
    where
        E: AsRef<[u8]> + Sized,
    {
        let expected = expected.as_ref();

        let mut buf = vec![];
        let n = MapFormatEncoder::new(len).encode(&mut buf).unwrap();
        assert_eq!(buf, expected);
        assert_eq!(n, expected.len());
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn encode_btreemap_sorted() {
        let mut m = alloc::collections::BTreeMap::new();
        m.insert(2u8, 20u8);
        m.insert(1u8, 10u8);

        let mut buf = alloc::vec::Vec::new();
        let n = m.encode(&mut buf).unwrap();

        // Compare the whole buffer (not `&buf[..n]`) so stray trailing bytes are caught,
        // and check the reported size separately.
        // Expect keys encoded in sorted order: fixmap(2) {1:10, 2:20}
        assert_eq!(buf.as_slice(), &[0x82, 0x01, 0x0a, 0x02, 0x14]);
        assert_eq!(n, buf.len());
    }

    #[cfg(feature = "std")]
    #[test]
    fn encode_hashmap_roundtrip() {
        use crate::decode::Decode;

        let mut m = std::collections::HashMap::<u8, bool>::new();
        m.insert(1, true);
        m.insert(3, false);

        let mut buf = Vec::new();
        let n = m.encode(&mut buf).unwrap();
        assert_eq!(n, buf.len());

        // Roundtrip decode to HashMap and check contents regardless of order
        let mut r = crate::io::SliceReader::new(&buf);
        let back = <std::collections::HashMap<u8, bool> as Decode>::decode(&mut r).unwrap();
        assert_eq!(back.len(), 2);
        assert_eq!(back.get(&1), Some(&true));
        assert_eq!(back.get(&3), Some(&false));
    }
}
288}