messagepack_core/encode/array.rs

1use core::{cell::RefCell, marker::PhantomData};
2
3use super::{Encode, Error, Result};
4use crate::{formats::Format, io::IoWrite};
5
/// Encodes only the MessagePack array *header* (format marker plus length)
/// for an array of `self.0` elements; the elements themselves are written
/// separately.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ArrayFormatEncoder(pub usize);

impl ArrayFormatEncoder {
    /// Creates a header encoder for an array with `size` elements.
    pub const fn new(size: usize) -> Self {
        Self(size)
    }
}
12impl<W: IoWrite> Encode<W> for ArrayFormatEncoder {
13    fn encode(&self, writer: &mut W) -> Result<usize, <W as IoWrite>::Error> {
14        match self.0 {
15            0x00..=0b1111 => {
16                let cast = self.0 as u8;
17                writer.write_bytes(&[Format::FixArray(cast).as_byte()])?;
18                Ok(1)
19            }
20            0x10..=0xffff => {
21                let cast = (self.0 as u16).to_be_bytes();
22                writer.write_bytes(&[Format::Array16.as_byte(), cast[0], cast[1]])?;
23
24                Ok(3)
25            }
26            0x10000..=0xffffffff => {
27                let cast = (self.0 as u32).to_be_bytes();
28                writer.write_bytes(&[
29                    Format::Array32.as_byte(),
30                    cast[0],
31                    cast[1],
32                    cast[2],
33                    cast[3],
34                ])?;
35
36                Ok(5)
37            }
38            _ => Err(Error::InvalidFormat),
39        }
40    }
41}
42
/// Encodes every item produced by an iterator, back to back, with no header.
///
/// The iterator is kept in a [`RefCell`] so that [`Encode::encode`], which
/// only receives `&self`, can still advance it. A call to `encode` therefore
/// drains the iterator (up to the first error); a later call only sees
/// whatever the iterator still yields, which for an ordinary iterator is
/// nothing.
pub struct ArrayDataEncoder<I, V> {
    data: RefCell<I>,
    // `I` is already owned through `data`; only the item type needs a marker.
    _phantom: PhantomData<V>,
}

impl<I, V> ArrayDataEncoder<I, V> {
    /// Wraps `data`; its items are encoded in iteration order.
    pub fn new(data: I) -> Self {
        Self {
            data: RefCell::new(data),
            _phantom: PhantomData,
        }
    }
}
56
57impl<W, I, V> Encode<W> for ArrayDataEncoder<I, V>
58where
59    W: IoWrite,
60    I: Iterator<Item = V>,
61    V: Encode<W>,
62{
63    fn encode(&self, writer: &mut W) -> Result<usize, <W as IoWrite>::Error> {
64        let array_len = self
65            .data
66            .borrow_mut()
67            .by_ref()
68            .map(|v| v.encode(writer))
69            .try_fold(0, |acc, v| v.map(|n| acc + n))?;
70        Ok(array_len)
71    }
72}
73
/// Encodes a borrowed slice as a complete MessagePack array: the header
/// produced by `ArrayFormatEncoder` for the slice length, followed by each
/// element in order.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct ArrayEncoder<'array, V>(&'array [V]);

impl<'array, V> ArrayEncoder<'array, V> {
    /// Creates an encoder for `array`.
    ///
    /// The wrapped field is private, so this constructor (or the `From`
    /// impl) is how code outside this module builds an `ArrayEncoder`.
    pub const fn new(array: &'array [V]) -> Self {
        Self(array)
    }
}

impl<'array, V> From<&'array [V]> for ArrayEncoder<'array, V> {
    fn from(array: &'array [V]) -> Self {
        Self::new(array)
    }
}

impl<'array, V> core::ops::Deref for ArrayEncoder<'array, V> {
    type Target = &'array [V];
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
83
84impl<W, V> Encode<W> for ArrayEncoder<'_, V>
85where
86    W: IoWrite,
87    V: Encode<W>,
88{
89    fn encode(&self, writer: &mut W) -> Result<usize, <W as IoWrite>::Error> {
90        let self_len = self.len();
91        let format_len = ArrayFormatEncoder(self_len).encode(writer)?;
92
93        let array_len = ArrayDataEncoder::new(self.iter()).encode(writer)?;
94        Ok(format_len + array_len)
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    #[rstest]
    #[case([1u8, 2u8, 3u8],[0x93, 0x01, 0x02, 0x03])]
    fn encode_fix_array<V: Encode<Vec<u8>>, Array: AsRef<[V]>, E: AsRef<[u8]>>(
        #[case] value: Array,
        #[case] expected: E,
    ) {
        let expected = expected.as_ref();
        let encoder = ArrayEncoder(value.as_ref());

        let mut buf = vec![];
        let n = encoder.encode(&mut buf).unwrap();
        assert_eq!(buf, expected);
        assert_eq!(n, expected.len());
    }

    // Covers each header form and the boundaries between them: empty and
    // 15-element fixarrays, the smallest and largest array16, and the
    // smallest array32. Fixarray cases have no separate size bytes.
    #[rstest]
    #[case(0x90, [0u8; 0], [0u8; 0])]
    #[case(0x9f, [0u8; 0], [0x12; 15])]
    #[case(0xdc, 16_u16.to_be_bytes(), [0x12; 16])]
    #[case(0xdc, 65535_u16.to_be_bytes(),[0x34;65535])]
    #[case(0xdd, 65536_u32.to_be_bytes(),[0x56;65536])]
    fn encode_array_sized<S: AsRef<[u8]>, D: AsRef<[u8]>>(
        #[case] marker: u8,
        #[case] size: S,
        #[case] data: D,
    ) {
        let expected = marker
            .to_be_bytes()
            .iter()
            .chain(size.as_ref())
            .chain(data.as_ref())
            .cloned()
            .collect::<Vec<u8>>();

        let encoder = ArrayEncoder(data.as_ref());

        let mut buf = vec![];
        let n = encoder.encode(&mut buf).unwrap();

        assert_eq!(&buf, &expected);
        assert_eq!(n, expected.len());
    }

    // Header-only checks at the top of the representable range, without
    // allocating billions of elements.
    #[test]
    fn encode_array_format_limits() {
        let mut buf: Vec<u8> = Vec::new();
        let n = ArrayFormatEncoder::new(0xffff_ffff).encode(&mut buf).unwrap();
        assert_eq!(buf, [0xdd, 0xff, 0xff, 0xff, 0xff]);
        assert_eq!(n, 5);

        #[cfg(target_pointer_width = "64")]
        {
            let mut buf: Vec<u8> = Vec::new();
            assert!(ArrayFormatEncoder::new(0x1_0000_0000)
                .encode(&mut buf)
                .is_err());
            assert!(buf.is_empty());
        }
    }
}