messagepack_core/encode/map.rs

use core::{cell::RefCell, convert::TryFrom, marker::PhantomData, ops::Deref};

use super::{Encode, Error, Result};
use crate::{formats::Format, io::IoWrite};
5
6pub trait KVEncode<W>
7where
8    W: IoWrite,
9{
10    fn encode(&self, writer: &mut W) -> Result<usize, W::Error>;
11}
12
13impl<W: IoWrite, KV: KVEncode<W>> KVEncode<W> for &KV {
14    fn encode(&self, writer: &mut W) -> Result<usize, W::Error> {
15        KV::encode(self, writer)
16    }
17}
18
19impl<W: IoWrite, K: Encode<W>, V: Encode<W>> KVEncode<W> for (K, V) {
20    fn encode(&self, writer: &mut W) -> Result<usize, W::Error> {
21        let (k, v) = self;
22        let k_len = k.encode(writer)?;
23        let v_len = v.encode(writer)?;
24        Ok(k_len + v_len)
25    }
26}
27
/// Encoder for the header of a MessagePack map.
///
/// The wrapped value is the number of key/value entries the map will hold.
pub struct MapFormatEncoder(pub usize);

impl MapFormatEncoder {
    /// Creates a header encoder for a map of `size` entries.
    pub fn new(size: usize) -> Self {
        MapFormatEncoder(size)
    }
}
34
35impl<W: IoWrite> Encode<W> for MapFormatEncoder {
36    fn encode(&self, writer: &mut W) -> Result<usize, W::Error> {
37        match self.0 {
38            0x00..=0xf => {
39                let cast = self.0 as u8;
40                writer.write_bytes(&[Format::FixMap(cast).as_byte()])?;
41
42                Ok(1)
43            }
44            0x10..=0xffff => {
45                let cast = (self.0 as u16).to_be_bytes();
46                writer.write_bytes(&[Format::Map16.as_byte(), cast[0], cast[1]])?;
47
48                Ok(3)
49            }
50            0x10000..=0xffffffff => {
51                let cast = (self.0 as u32).to_be_bytes();
52                writer.write_bytes(&[
53                    Format::Map32.as_byte(),
54                    cast[0],
55                    cast[1],
56                    cast[2],
57                    cast[3],
58                ])?;
59
60                Ok(5)
61            }
62            _ => Err(Error::InvalidFormat),
63        }
64    }
65}
66
/// Encoder for the body of a map: every entry yielded by an iterator, written
/// back to back with no length header.
///
/// The iterator lives in a `RefCell` so it can be advanced from
/// `encode(&self, ..)`. Encoding therefore consumes entries, and a later call
/// only sees whatever the iterator still has left.
pub struct MapDataEncoder<I, J, KV> {
    data: RefCell<J>,
    _phantom: PhantomData<(I, J, KV)>,
}

impl<I, KV> MapDataEncoder<I, I::IntoIter, KV>
where
    I: IntoIterator<Item = KV>,
{
    /// Builds a body encoder from anything that iterates over map entries.
    pub fn new(data: I) -> Self {
        let entries = data.into_iter();
        Self {
            _phantom: PhantomData,
            data: RefCell::new(entries),
        }
    }
}
83
84impl<W, I, J, KV> Encode<W> for MapDataEncoder<I, J, KV>
85where
86    W: IoWrite,
87    J: Iterator<Item = KV>,
88    KV: KVEncode<W>,
89{
90    fn encode(&self, writer: &mut W) -> Result<usize, W::Error> {
91        let map_len = self
92            .data
93            .borrow_mut()
94            .by_ref()
95            .map(|kv| kv.encode(writer))
96            .try_fold(0, |acc, v| v.map(|n| acc + n))?;
97        Ok(map_len)
98    }
99}
100
/// Encoder for a complete map backed by a borrowed slice of entries.
///
/// The slice reference already ties the struct to `KV`, so no `PhantomData`
/// marker is needed. Dropping the marker also means `Send` depends only on
/// `KV: Sync`, as for `&[KV]` itself.
pub struct MapSliceEncoder<'data, KV> {
    data: &'data [KV],
}

impl<'data, KV> MapSliceEncoder<'data, KV> {
    /// Wraps `data`. Nothing is written until `encode` is called.
    pub fn new(data: &'data [KV]) -> Self {
        Self { data }
    }
}

impl<'data, KV> Deref for MapSliceEncoder<'data, KV> {
    type Target = &'data [KV];

    /// Exposes the wrapped slice reference.
    fn deref(&self) -> &Self::Target {
        &self.data
    }
}
121
122impl<W, KV> Encode<W> for MapSliceEncoder<'_, KV>
123where
124    W: IoWrite,
125    KV: KVEncode<W>,
126{
127    fn encode(&self, writer: &mut W) -> Result<usize, W::Error> {
128        let self_len = self.data.len();
129        let format_len = MapFormatEncoder::new(self_len).encode(writer)?;
130        let map_len = MapDataEncoder::new(self.data.iter()).encode(writer)?;
131
132        Ok(format_len + map_len)
133    }
134}
135
136pub struct MapEncoder<W, I, J, KV> {
137    map: RefCell<J>,
138    _phantom: PhantomData<(W, I, J, KV)>,
139}
140
141impl<W, I, KV> MapEncoder<W, I, I::IntoIter, KV>
142where
143    W: IoWrite,
144    I: IntoIterator<Item = KV>,
145    KV: KVEncode<W>,
146{
147    pub fn new(map: I) -> Self {
148        Self {
149            map: RefCell::new(map.into_iter()),
150            _phantom: Default::default(),
151        }
152    }
153}
154
155impl<W, I, J, KV> Encode<W> for MapEncoder<W, I, J, KV>
156where
157    W: IoWrite,
158    J: Iterator<Item = KV> + ExactSizeIterator,
159    KV: KVEncode<W>,
160{
161    fn encode(&self, writer: &mut W) -> Result<usize, W::Error> {
162        let self_len = self.map.borrow().len();
163        let format_len = MapFormatEncoder::new(self_len).encode(writer)?;
164        let map_len = MapDataEncoder::new(self.map.borrow_mut().by_ref()).encode(writer)?;
165
166        Ok(format_len + map_len)
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encode::int::EncodeMinimizeInt;
    use rstest::rstest;

    /// The header uses the smallest format that fits the entry count:
    /// fixmap up to 15 entries, map16 up to `u16::MAX`, map32 up to `u32::MAX`.
    #[rstest]
    #[case(0, [0x80])]
    #[case(0x0f, [0x8f])]
    #[case(0x10, [0xde, 0x00, 0x10])]
    #[case(0xffff, [0xde, 0xff, 0xff])]
    #[case(0x1_0000, [0xdf, 0x00, 0x01, 0x00, 0x00])]
    #[case(0xffff_ffff, [0xdf, 0xff, 0xff, 0xff, 0xff])]
    fn encode_map_format<E>(#[case] len: usize, #[case] expected: E)
    where
        E: AsRef<[u8]> + Sized,
    {
        let expected = expected.as_ref();

        let mut buf: Vec<u8> = vec![];
        let n = MapFormatEncoder::new(len).encode(&mut buf).unwrap();
        assert_eq!(buf, expected);
        assert_eq!(n, expected.len());
    }

    /// Lengths beyond `u32::MAX` have no map header and are rejected before
    /// anything is written.
    #[cfg(target_pointer_width = "64")]
    #[test]
    fn encode_map_format_rejects_len_above_u32_max() {
        let mut buf: Vec<u8> = vec![];
        let result = MapFormatEncoder::new(0x1_0000_0000).encode(&mut buf);
        assert!(matches!(result, Err(Error::InvalidFormat)));
        assert!(buf.is_empty());
    }

    /// A slice of (key, value) pairs encodes as a fixmap header plus entries.
    #[rstest]
    #[case([("123", EncodeMinimizeInt(123)), ("456", EncodeMinimizeInt(456))], [0x82, 0xa3, 0x31, 0x32, 0x33, 0x7b, 0xa3, 0x34, 0x35, 0x36, 0xcd, 0x01, 0xc8])]
    fn encode_slice_fix_map<K, V, Map, E>(#[case] value: Map, #[case] expected: E)
    where
        K: Encode<Vec<u8>>,
        V: Encode<Vec<u8>>,
        Map: AsRef<[(K, V)]>,
        E: AsRef<[u8]> + Sized,
    {
        let expected = expected.as_ref();
        let encoder = MapSliceEncoder::new(value.as_ref());

        let mut buf = vec![];
        let n = encoder.encode(&mut buf).unwrap();
        assert_eq!(buf, expected);
        assert_eq!(n, expected.len());
    }

    /// An exact-size iterator of entries produces the same bytes as a slice.
    #[rstest]
    #[case([("123", EncodeMinimizeInt(123)), ("456", EncodeMinimizeInt(456))], [0x82, 0xa3, 0x31, 0x32, 0x33, 0x7b, 0xa3, 0x34, 0x35, 0x36, 0xcd, 0x01, 0xc8])]
    fn encode_iter_fix_map<I, KV, E>(#[case] value: I, #[case] expected: E)
    where
        I: IntoIterator<Item = KV>,
        I::IntoIter: ExactSizeIterator,
        KV: KVEncode<Vec<u8>>,
        E: AsRef<[u8]> + Sized,
    {
        let expected = expected.as_ref();

        let encoder = MapEncoder::new(value.into_iter());
        let mut buf = vec![];
        let n = encoder.encode(&mut buf).unwrap();
        assert_eq!(buf, expected);
        assert_eq!(n, expected.len());
    }
}