use core::{cell::RefCell, marker::PhantomData, ops::Deref};
use super::{Encode, Error, Result};
use crate::{formats::Format, io::IoWrite};
pub trait KVEncode {
fn encode_kv<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error>;
}
/// Lets a shared reference to an entry be encoded like the entry itself.
///
/// This is what allows iterators yielding `&KV` (such as `slice::iter`) to be
/// fed straight into the map encoders.
impl<KV: KVEncode> KVEncode for &KV {
    fn encode_kv<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error> {
        // `self` is `&&KV`; forward to the underlying entry's implementation.
        <KV as KVEncode>::encode_kv(*self, writer)
    }
}
/// A `(key, value)` tuple is encoded as the key immediately followed by the value.
impl<K: Encode, V: Encode> KVEncode for (K, V) {
    /// Returns the combined byte count of the key and the value.
    fn encode_kv<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error> {
        let key_bytes = self.0.encode(writer)?;
        let value_bytes = self.1.encode(writer)?;
        Ok(key_bytes + value_bytes)
    }
}
/// Encodes only the header of a map: the format marker plus the entry count.
///
/// The wrapped `usize` is the number of key/value entries the map will hold.
pub struct MapFormatEncoder(pub usize);

impl MapFormatEncoder {
    /// Creates a header encoder announcing `size` entries.
    pub fn new(size: usize) -> Self {
        MapFormatEncoder(size)
    }
}
impl Encode for MapFormatEncoder {
    /// Writes the smallest header that can represent the entry count.
    ///
    /// - `0..=0x0f` entries: fixmap, 1 byte (count packed into the marker)
    /// - `0x10..=0xffff` entries: map16, 3 bytes (marker + big-endian u16)
    /// - `0x1_0000..=0xffff_ffff` entries: map32, 5 bytes (marker + big-endian u32)
    ///
    /// Larger counts return `Error::InvalidFormat` without writing anything.
    fn encode<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error> {
        let count = self.0;
        match count {
            0x00..=0x0f => {
                // The range guard guarantees the cast to u8 is lossless.
                let marker = Format::FixMap(count as u8).as_byte();
                writer.write(&[marker])?;
                Ok(1)
            }
            0x10..=0xffff => {
                let [hi, lo] = (count as u16).to_be_bytes();
                writer.write(&[Format::Map16.as_byte(), hi, lo])?;
                Ok(3)
            }
            0x1_0000..=0xffff_ffff => {
                let [b0, b1, b2, b3] = (count as u32).to_be_bytes();
                writer.write(&[Format::Map32.as_byte(), b0, b1, b2, b3])?;
                Ok(5)
            }
            _ => Err(Error::InvalidFormat),
        }
    }
}
/// Encodes only the entries of a map (no header), drawn from an iterator.
///
/// The iterator sits in a `RefCell` so that encoding through `&self` can
/// still advance it; entries are consumed as they are pulled.
pub struct MapDataEncoder<I, J, KV> {
    data: RefCell<J>,
    _phantom: PhantomData<(I, J, KV)>,
}

impl<I, KV> MapDataEncoder<I, I::IntoIter, KV>
where
    I: IntoIterator<Item = KV>,
{
    /// Converts `data` into its iterator right away and stores it for encoding.
    pub fn new(data: I) -> Self {
        let entries = data.into_iter();
        Self {
            data: RefCell::new(entries),
            _phantom: PhantomData,
        }
    }
}
impl<I, J, KV> Encode for MapDataEncoder<I, J, KV>
where
    J: Iterator<Item = KV>,
    KV: KVEncode,
{
    /// Writes every entry still left in the iterator and returns the summed byte count.
    ///
    /// Stops at the first entry whose encoding fails and propagates that error;
    /// entries already pulled from the iterator stay consumed.
    fn encode<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error> {
        let mut entries = self.data.borrow_mut();
        let mut written = 0;
        for entry in &mut *entries {
            written += entry.encode_kv(writer)?;
        }
        Ok(written)
    }
}
/// Writes a map header announcing `len` entries, followed by every entry of `it`.
///
/// `len` is written as-is and never checked against the iterator. Callers must
/// pass the number of items `it` will yield, or the header will not match the data.
/// Returns header bytes plus entry bytes.
fn encode_iter<W, I>(writer: &mut W, len: usize, mut it: I) -> Result<usize, W::Error>
where
    W: IoWrite,
    I: Iterator,
    I::Item: KVEncode,
{
    let header_bytes = MapFormatEncoder::new(len).encode(writer)?;
    let body_bytes = it.try_fold(0, |total, entry| {
        entry.encode_kv(writer).map(|n| total + n)
    })?;
    Ok(header_bytes + body_bytes)
}
/// Encodes a borrowed slice of entries as a complete map (header + entries).
pub struct MapSliceEncoder<'data, KV> {
    data: &'data [KV],
    _phantom: PhantomData<KV>,
}

impl<'data, KV> MapSliceEncoder<'data, KV> {
    /// Wraps `data` by reference; the entries are not copied.
    pub fn new(data: &'data [KV]) -> Self {
        MapSliceEncoder {
            data,
            _phantom: PhantomData,
        }
    }
}

/// Dereferences to the borrowed slice reference.
///
/// Note that the target is `&'data [KV]` (a reference), not `[KV]`.
impl<'data, KV> Deref for MapSliceEncoder<'data, KV> {
    type Target = &'data [KV];

    fn deref(&self) -> &&'data [KV] {
        let slice_ref: &&'data [KV] = &self.data;
        slice_ref
    }
}
/// Encodes the slice as a map: a header sized by the slice length, then each entry in order.
impl<KV: KVEncode> Encode for MapSliceEncoder<'_, KV> {
    fn encode<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error> {
        let entries: &[KV] = self.data;
        let count = entries.len();
        encode_iter(writer, count, entries.iter())
    }
}
/// Encodes a `BTreeMap` as a map; entries are emitted in ascending key order,
/// which is the order `BTreeMap::iter` yields them.
#[cfg(feature = "alloc")]
impl<K: Encode + Ord, V: Encode> Encode for alloc::collections::BTreeMap<K, V> {
    fn encode<W: IoWrite>(&self, writer: &mut W) -> Result<usize, <W as IoWrite>::Error> {
        // The map iterator is exact-size, so its length equals the entry count.
        let entries = self.iter();
        let count = entries.len();
        encode_iter(writer, count, entries)
    }
}
/// Encodes a `HashMap` as a map; entries are emitted in the map's (unspecified)
/// iteration order.
#[cfg(feature = "std")]
impl<K, V, S> Encode for std::collections::HashMap<K, V, S>
where
    K: Encode + Eq + core::hash::Hash,
    V: Encode,
    S: std::hash::BuildHasher,
{
    fn encode<W: IoWrite>(&self, writer: &mut W) -> Result<usize, <W as IoWrite>::Error> {
        // The map iterator is exact-size, so its length equals the entry count.
        let entries = self.iter();
        let count = entries.len();
        encode_iter(writer, count, entries)
    }
}
/// Encodes an iterator of entries as a complete map (header + entries).
///
/// The iterator is kept in a `RefCell` so that encoding through `&self` can drain it.
/// Its `Encode` impl requires the iterator to be an `ExactSizeIterator`.
pub struct MapEncoder<I, J, KV> {
    map: RefCell<J>,
    _phantom: PhantomData<(I, J, KV)>,
}

impl<I, KV> MapEncoder<I, I::IntoIter, KV>
where
    I: IntoIterator<Item = KV>,
    KV: KVEncode,
{
    /// Converts `map` into its iterator and wraps it for encoding.
    pub fn new(map: I) -> Self {
        let entries = map.into_iter();
        Self {
            map: RefCell::new(entries),
            _phantom: PhantomData,
        }
    }
}
impl<I, J, KV> Encode for MapEncoder<I, J, KV>
where
    J: Iterator<Item = KV> + ExactSizeIterator,
    KV: KVEncode,
{
    /// Writes the map header, then drains the iterator writing each entry.
    ///
    /// The header count comes from `ExactSizeIterator::len` before any entry is
    /// written. If the iterator misreports its length, the header will not
    /// match the entries that follow.
    fn encode<W: IoWrite>(&self, writer: &mut W) -> Result<usize, W::Error> {
        let count = self.map.borrow().len();
        let header_bytes = MapFormatEncoder::new(count).encode(writer)?;

        let mut entries = self.map.borrow_mut();
        let body_bytes = MapDataEncoder::new(&mut *entries).encode(writer)?;

        Ok(header_bytes + body_bytes)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encode::int::EncodeMinimizeInt;
    use rstest::rstest;

    /// The header format switches exactly at the fixmap / map16 / map32 boundaries.
    #[rstest]
    #[case(0x00, &[0x80][..])]
    #[case(0x0f, &[0x8f][..])]
    #[case(0x10, &[0xde, 0x00, 0x10][..])]
    #[case(0xffff, &[0xde, 0xff, 0xff][..])]
    #[case(0x1_0000, &[0xdf, 0x00, 0x01, 0x00, 0x00][..])]
    fn encode_map_format_boundaries(#[case] len: usize, #[case] expected: &[u8]) {
        let mut buf = vec![];
        let n = MapFormatEncoder::new(len).encode(&mut buf).unwrap();
        assert_eq!(buf, expected);
        assert_eq!(n, expected.len());
    }

    /// Entry counts that do not fit in 32 bits are rejected before anything is written.
    #[cfg(target_pointer_width = "64")]
    #[test]
    fn encode_map_format_too_large() {
        let mut buf = vec![0u8; 0];
        let res = MapFormatEncoder::new(0x1_0000_0000).encode(&mut buf);
        assert!(res.is_err());
        assert!(buf.is_empty());
    }

    /// A two-entry slice encodes as a fixmap header (0x82) followed by its entries.
    #[rstest]
    #[case([("123", EncodeMinimizeInt(123)), ("456", EncodeMinimizeInt(456))], [0x82, 0xa3, 0x31, 0x32, 0x33, 0x7b, 0xa3, 0x34, 0x35, 0x36, 0xcd, 0x01, 0xc8])]
    fn encode_slice_fix_map<K, V, Map, E>(#[case] value: Map, #[case] expected: E)
    where
        K: Encode,
        V: Encode,
        Map: AsRef<[(K, V)]>,
        E: AsRef<[u8]>,
    {
        let expected = expected.as_ref();
        let encoder = MapSliceEncoder::new(value.as_ref());
        let mut buf = vec![];
        let n = encoder.encode(&mut buf).unwrap();
        assert_eq!(buf, expected);
        assert_eq!(n, expected.len());
    }

    /// A two-entry exact-size iterator encodes as a fixmap header (0x82) followed by its entries.
    #[rstest]
    #[case([("123", EncodeMinimizeInt(123)), ("456", EncodeMinimizeInt(456))], [0x82, 0xa3, 0x31, 0x32, 0x33, 0x7b, 0xa3, 0x34, 0x35, 0x36, 0xcd, 0x01, 0xc8])]
    fn encode_iter_fix_map<I, KV, E>(#[case] value: I, #[case] expected: E)
    where
        I: IntoIterator<Item = KV>,
        I::IntoIter: ExactSizeIterator,
        KV: KVEncode,
        E: AsRef<[u8]>,
    {
        let expected = expected.as_ref();
        let encoder = MapEncoder::new(value.into_iter());
        let mut buf = vec![];
        let n = encoder.encode(&mut buf).unwrap();
        assert_eq!(buf, expected);
        assert_eq!(n, expected.len());
    }

    /// BTreeMap entries are written in key order regardless of insertion order.
    #[cfg(feature = "alloc")]
    #[test]
    fn encode_btreemap_sorted() {
        let mut m = alloc::collections::BTreeMap::new();
        m.insert(2u8, 20u8);
        m.insert(1u8, 10u8);
        let mut buf = alloc::vec::Vec::<u8>::new();
        let n = m.encode(&mut buf).unwrap();
        // Compare the whole buffer so stray trailing bytes are caught too.
        assert_eq!(buf, [0x82u8, 0x01, 0x0a, 0x02, 0x14]);
        assert_eq!(n, buf.len());
    }

    /// A HashMap survives an encode/decode round trip; order is not checked.
    #[cfg(feature = "std")]
    #[test]
    fn encode_hashmap_roundtrip() {
        use crate::decode::Decode;
        let mut m = std::collections::HashMap::<u8, bool>::new();
        m.insert(1, true);
        m.insert(3, false);
        let mut buf = Vec::new();
        let n = m.encode(&mut buf).unwrap();
        assert_eq!(n, buf.len());
        let mut r = crate::io::SliceReader::new(&buf);
        let back = <std::collections::HashMap<u8, bool> as Decode>::decode(&mut r).unwrap();
        assert_eq!(back.len(), 2);
        assert_eq!(back.get(&1), Some(&true));
        assert_eq!(back.get(&3), Some(&false));
    }
}