use crate::{
codec::{EncodeSize, Read, Write},
error::Error,
RangeCfg,
};
use bytes::{Buf, BufMut};
use std::{
cmp::Ordering,
collections::{BTreeMap, HashMap},
hash::Hash,
};
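// Human-readable type names used in decoding error messages.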
const BTREEMAP_TYPE: &str = "BTreeMap";
const HASHMAP_TYPE: &str = "HashMap";
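/// Reads `len` key/value pairs from `buf`, enforcing strictly ascending key order.
///
/// Each decoded pair is handed to `insert`. A repeated key produces a
/// "Duplicate key" error and an out-of-order key produces a "Keys must ascend"
/// error, both reported under `map_type`.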
fn read_ordered_map<K, V, F>(
buf: &mut impl Buf,
len: usize,
k_cfg: &K::Cfg,
v_cfg: &V::Cfg,
mut insert: F,
map_type: &'static str,
) -> Result<(), Error>
where
K: Read + Ord,
V: Read,
F: FnMut(K, V) -> Option<V>,
{
let mut last: Option<(K, V)> = None;
for _ in 0..len {
let key = K::read_cfg(buf, k_cfg)?;
if let Some((ref last_key, _)) = last {
match key.cmp(last_key) {
Ordering::Equal => return Err(Error::Invalid(map_type, "Duplicate key")),
Ordering::Less => return Err(Error::Invalid(map_type, "Keys must ascend")),
_ => {}
}
}
let value = V::read_cfg(buf, v_cfg)?;
if let Some((last_key, last_value)) = last.take() {
insert(last_key, last_value);
}
last = Some((key, value));
}
if let Some((last_key, last_value)) = last {
insert(last_key, last_value);
}
Ok(())
}
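/// A `BTreeMap` is encoded as its length (written as a `usize`) followed by
/// each key/value pair in ascending key order.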
impl<K: Ord + Hash + Eq + Write, V: Write> Write for BTreeMap<K, V> {
fn write(&self, buf: &mut impl BufMut) {
self.len().write(buf);
for (k, v) in self {
k.write(buf);
v.write(buf);
}
}
}
impl<K: Ord + Hash + Eq + EncodeSize, V: EncodeSize> EncodeSize for BTreeMap<K, V> {
fn encode_size(&self) -> usize {
let mut size = self.len().encode_size();
for (k, v) in self {
size += k.encode_size();
size += v.encode_size();
}
size
}
}
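/// Decodes a `BTreeMap` from its length-prefixed encoding. The number of
/// entries must fall within the provided `RangeCfg`, and keys must appear in
/// strictly ascending order with no duplicates; otherwise an error is returned.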
impl<K: Read + Clone + Ord + Hash + Eq, V: Read + Clone> Read for BTreeMap<K, V> {
type Cfg = (RangeCfg, (K::Cfg, V::Cfg));
fn read_cfg(buf: &mut impl Buf, (range, (k_cfg, v_cfg)): &Self::Cfg) -> Result<Self, Error> {
let len = usize::read_cfg(buf, range)?;
let mut map = BTreeMap::new();
read_ordered_map(
buf,
len,
k_cfg,
v_cfg,
|k, v| map.insert(k, v),
BTREEMAP_TYPE,
)?;
Ok(map)
}
}
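/// A `HashMap` is encoded exactly like a `BTreeMap`: its length followed by
/// each key/value pair. Entries are sorted by key before writing so the
/// encoding is deterministic regardless of iteration order.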
impl<K: Ord + Hash + Eq + Write, V: Write> Write for HashMap<K, V> {
fn write(&self, buf: &mut impl BufMut) {
self.len().write(buf);
let mut entries: Vec<_> = self.iter().collect();
entries.sort_by(|a, b| a.0.cmp(b.0));
for (k, v) in entries {
k.write(buf);
v.write(buf);
}
}
}
impl<K: Ord + Hash + Eq + EncodeSize, V: EncodeSize> EncodeSize for HashMap<K, V> {
fn encode_size(&self) -> usize {
let mut size = self.len().encode_size();
for (k, v) in self {
size += k.encode_size();
size += v.encode_size();
}
size
}
}
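/// Decodes a `HashMap` from the same length-prefixed, key-ordered encoding
/// used by `BTreeMap`. The entry count must fall within the provided
/// `RangeCfg`, and keys must be strictly ascending with no duplicates.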
impl<K: Read + Clone + Ord + Hash + Eq, V: Read + Clone> Read for HashMap<K, V> {
type Cfg = (RangeCfg, (K::Cfg, V::Cfg));
fn read_cfg(buf: &mut impl Buf, (range, (k_cfg, v_cfg)): &Self::Cfg) -> Result<Self, Error> {
let len = usize::read_cfg(buf, range)?;
let mut map = HashMap::with_capacity(len);
read_ordered_map(
buf,
len,
k_cfg,
v_cfg,
|k, v| map.insert(k, v),
HASHMAP_TYPE,
)?;
Ok(map)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{Decode, Encode, FixedSize};
use bytes::{Bytes, BytesMut};
use std::fmt::Debug;
fn round_trip_btree<K, V, KCfg, VCfg>(
map: &BTreeMap<K, V>,
range_cfg: RangeCfg,
k_cfg: KCfg,
v_cfg: VCfg,
) where
K: Write + EncodeSize + Read<Cfg = KCfg> + Clone + Ord + Hash + Eq + PartialEq + Debug,
V: Write + EncodeSize + Read<Cfg = VCfg> + Clone + PartialEq + Debug,
BTreeMap<K, V>: Read<Cfg = (RangeCfg, (K::Cfg, V::Cfg))>
+ Decode<Cfg = (RangeCfg, (K::Cfg, V::Cfg))>
+ PartialEq
+ Write
+ EncodeSize,
{
let encoded = map.encode();
assert_eq!(encoded.len(), map.encode_size());
let config_tuple = (range_cfg, (k_cfg, v_cfg));
let decoded = BTreeMap::<K, V>::decode_cfg(encoded, &config_tuple)
.expect("decode_cfg failed for BTreeMap");
assert_eq!(map, &decoded);
}
fn round_trip_hash<K, V, KCfg, VCfg>(
map: &HashMap<K, V>,
range_cfg: RangeCfg,
k_cfg: KCfg,
v_cfg: VCfg,
) where
K: Write + EncodeSize + Read<Cfg = KCfg> + Clone + Ord + Hash + Eq + PartialEq + Debug,
V: Write + EncodeSize + Read<Cfg = VCfg> + Clone + PartialEq + Debug,
HashMap<K, V>: Read<Cfg = (RangeCfg, (K::Cfg, V::Cfg))>
+ Decode<Cfg = (RangeCfg, (K::Cfg, V::Cfg))>
+ PartialEq
+ Write
+ EncodeSize,
{
let encoded = map.encode();
assert_eq!(encoded.len(), map.encode_size());
let config_tuple = (range_cfg, (k_cfg, v_cfg));
let decoded = HashMap::<K, V>::decode_cfg(encoded, &config_tuple)
.expect("decode_cfg failed for HashMap");
assert_eq!(map, &decoded);
}
fn round_trip<K, V>(map: &HashMap<K, V>, range_cfg: RangeCfg, k_cfg: K::Cfg, v_cfg: V::Cfg)
where
K: Write + EncodeSize + Read + Clone + Ord + Hash + Eq + PartialEq + Debug,
V: Write + EncodeSize + Read + Clone + PartialEq + Debug,
HashMap<K, V>: Read<Cfg = (RangeCfg, (K::Cfg, V::Cfg))> + PartialEq + Write + EncodeSize,
{
let encoded = map.encode();
let config_tuple = (range_cfg, (k_cfg, v_cfg));
let decoded = HashMap::<K, V>::decode_cfg(encoded, &config_tuple)
.expect("decode_cfg failed for HashMap");
assert_eq!(map, &decoded);
}
#[test]
fn test_empty_btreemap() {
let map = BTreeMap::<u32, u64>::new();
round_trip_btree(&map, (..).into(), (), ());
        assert_eq!(map.encode_size(), 1);
        let encoded = map.encode();
        assert_eq!(encoded, Bytes::from_static(&[0]));
}
#[test]
fn test_simple_btreemap_u32_u64() {
let mut map = BTreeMap::new();
map.insert(1u32, 100u64);
map.insert(5u32, 500u64);
map.insert(2u32, 200u64);
round_trip_btree(&map, (..).into(), (), ());
assert_eq!(map.encode_size(), 1 + 3 * (u32::SIZE + u64::SIZE));
let mut expected = BytesMut::new();
        3usize.write(&mut expected);
        1u32.write(&mut expected);
100u64.write(&mut expected);
2u32.write(&mut expected);
200u64.write(&mut expected);
5u32.write(&mut expected);
500u64.write(&mut expected);
assert_eq!(map.encode(), expected.freeze());
}
#[test]
fn test_large_btreemap() {
let mut map = BTreeMap::new();
for i in 0..1000 {
map.insert(i as u16, i as u64 * 2);
}
round_trip_btree(&map, (0..=1000).into(), (), ());
let mut map = BTreeMap::new();
for i in 0..1000usize {
map.insert(i, 1000usize + i);
}
round_trip_btree(
&map,
(0..=1000).into(),
(..=1000).into(),
(1000..=2000).into(),
);
}
#[test]
fn test_btreemap_with_variable_values() {
let mut map = BTreeMap::new();
map.insert(Bytes::from_static(b"apple"), vec![1, 2]);
map.insert(Bytes::from_static(b"banana"), vec![3, 4, 5]);
map.insert(Bytes::from_static(b"cherry"), vec![]);
let map_range = 0..=10;
let key_range = ..=10;
let val_range = 0..=100;
round_trip_btree(
&map,
map_range.into(),
key_range.into(),
(val_range.into(), ()),
);
}
#[test]
fn test_btreemap_decode_length_limit_exceeded() {
let mut map = BTreeMap::new();
map.insert(1u32, 100u64);
map.insert(5u32, 500u64);
let encoded = map.encode();
let config_tuple = ((0..=1).into(), ((), ()));
let result = BTreeMap::<u32, u64>::decode_cfg(encoded, &config_tuple);
assert!(matches!(result, Err(Error::InvalidLength(2))));
}
#[test]
fn test_btreemap_decode_value_length_limit_exceeded() {
let mut map = BTreeMap::new();
map.insert(Bytes::from_static(b"key1"), vec![1, 2, 3, 4, 5]);
let key_range = ..=10;
let map_range = 0..=10;
let restrictive_val_range = 0..=3;
let encoded = map.encode();
let config_tuple = (
map_range.into(),
(key_range.into(), (restrictive_val_range.into(), ())),
);
let result = BTreeMap::<Bytes, Vec<u8>>::decode_cfg(encoded, &config_tuple);
assert!(matches!(result, Err(Error::InvalidLength(5))));
}
#[test]
fn test_btreemap_decode_invalid_key_order() {
let mut encoded = BytesMut::new();
        2usize.write(&mut encoded);
        5u32.write(&mut encoded);
        500u64.write(&mut encoded);
        2u32.write(&mut encoded);
        200u64.write(&mut encoded);
let range = (..).into();
let config_tuple = (range, ((), ()));
let result = BTreeMap::<u32, u64>::decode_cfg(encoded, &config_tuple);
assert!(matches!(
result,
Err(Error::Invalid("BTreeMap", "Keys must ascend"))
));
}
#[test]
fn test_btreemap_decode_duplicate_key() {
let mut encoded = BytesMut::new();
        2usize.write(&mut encoded);
        1u32.write(&mut encoded);
        100u64.write(&mut encoded);
        1u32.write(&mut encoded);
        200u64.write(&mut encoded);
let range = (..).into();
let config_tuple = (range, ((), ()));
let result = BTreeMap::<u32, u64>::decode_cfg(encoded, &config_tuple);
assert!(matches!(
result,
Err(Error::Invalid("BTreeMap", "Duplicate key"))
));
}
#[test]
fn test_btreemap_decode_end_of_buffer_key() {
let mut map = BTreeMap::new();
map.insert(1u32, 100u64);
map.insert(5u32, 500u64);
let mut encoded = map.encode();
encoded.truncate(map.encode_size() - 10);
let range = (..).into();
let config_tuple = (range, ((), ()));
let result = BTreeMap::<u32, u64>::decode_cfg(encoded, &config_tuple);
assert!(matches!(result, Err(Error::EndOfBuffer)));
}
#[test]
fn test_btreemap_decode_end_of_buffer_value() {
let mut map = BTreeMap::new();
map.insert(1u32, 100u64);
map.insert(5u32, 500u64);
let mut encoded = map.encode();
encoded.truncate(map.encode_size() - 4);
let range = (..).into();
let config_tuple = (range, ((), ()));
let result = BTreeMap::<u32, u64>::decode_cfg(encoded, &config_tuple);
assert!(matches!(result, Err(Error::EndOfBuffer)));
}
#[test]
fn test_btreemap_decode_extra_data() {
let mut map = BTreeMap::new();
map.insert(1u32, 100u64);
let mut encoded = map.encode();
encoded.put_u8(0xFF);
let config_tuple = ((..).into(), ((), ()));
let result = BTreeMap::<u32, u64>::decode_cfg(encoded.clone(), &config_tuple);
assert!(matches!(result, Err(Error::ExtraData(1))));
let read_result = BTreeMap::<u32, u64>::read_cfg(&mut encoded.clone(), &config_tuple);
assert!(read_result.is_ok());
let decoded_map = read_result.unwrap();
assert_eq!(decoded_map.len(), 1);
assert_eq!(decoded_map.get(&1u32), Some(&100u64));
}
#[test]
fn test_btreemap_deterministic_encoding() {
let mut map2 = BTreeMap::new();
(0..=1000u32).for_each(|i| {
map2.insert(i, i * 2);
});
let mut map1 = BTreeMap::new();
(0..=1000u32).rev().for_each(|i| {
map1.insert(i, i * 2);
});
assert_eq!(map1.encode(), map2.encode());
}
#[test]
fn test_btreemap_conformity() {
let map1 = BTreeMap::<u8, u16>::new();
let mut expected1 = BytesMut::new();
        0usize.write(&mut expected1);
        assert_eq!(map1.encode(), expected1.freeze());
assert_eq!(map1.encode_size(), 1);
let mut map2 = BTreeMap::<u8, u16>::new();
        map2.insert(2u8, 0xBBBBu16);
        map2.insert(1u8, 0xAAAAu16);
let mut expected2 = BytesMut::new();
        2usize.write(&mut expected2);
        1u8.write(&mut expected2);
        0xAAAAu16.write(&mut expected2);
        2u8.write(&mut expected2);
        0xBBBBu16.write(&mut expected2);
        assert_eq!(map2.encode(), expected2.freeze());
assert_eq!(map2.encode_size(), 1 + (u8::SIZE + u16::SIZE) * 2);
let mut map3 = BTreeMap::<u16, bool>::new();
map3.insert(0x0303u16, true);
map3.insert(0x0101u16, false);
map3.insert(0x0202u16, true);
let mut expected3 = BytesMut::new();
        3usize.write(&mut expected3);
        0x0101u16.write(&mut expected3);
        false.write(&mut expected3);
        0x0202u16.write(&mut expected3);
        true.write(&mut expected3);
        0x0303u16.write(&mut expected3);
        true.write(&mut expected3);
        assert_eq!(map3.encode(), expected3.freeze());
assert_eq!(map3.encode_size(), 1 + (u16::SIZE + bool::SIZE) * 3);
let mut map4 = BTreeMap::<Bytes, Vec<u8>>::new();
map4.insert(Bytes::from_static(b"b"), vec![20u8, 21u8]);
map4.insert(Bytes::from_static(b"a"), vec![10u8]);
let mut expected4 = BytesMut::new();
2usize.write(&mut expected4);
Bytes::from_static(b"a").write(&mut expected4);
vec![10u8].write(&mut expected4);
Bytes::from_static(b"b").write(&mut expected4);
vec![20u8, 21u8].write(&mut expected4);
assert_eq!(map4.encode(), expected4.freeze());
        let expected_size = 2usize.encode_size()
+ Bytes::from_static(b"a").encode_size()
+ vec![10u8].encode_size()
+ Bytes::from_static(b"b").encode_size()
+ vec![20u8, 21u8].encode_size();
assert_eq!(map4.encode_size(), expected_size);
}
#[test]
fn test_empty_hashmap() {
let map = HashMap::<u32, u64>::new();
round_trip_hash(&map, (..).into(), (), ());
assert_eq!(map.encode_size(), 1);
let encoded = map.encode();
assert_eq!(encoded, 0usize.encode());
}
#[test]
fn test_simple_hashmap_u32_u64() {
let mut map = HashMap::new();
map.insert(1u32, 100u64);
map.insert(5u32, 500u64);
map.insert(2u32, 200u64);
round_trip(&map, (..).into(), (), ());
assert_eq!(map.encode_size(), 1 + 3 * (u32::SIZE + u64::SIZE));
}
#[test]
fn test_large_hashmap() {
let mut map = HashMap::new();
for i in 0..1000 {
map.insert(i as u16, i as u64 * 2);
}
round_trip_hash(&map, (0..=1000).into(), (), ());
let mut map = HashMap::new();
for i in 0..1000usize {
map.insert(i, 1000usize + i);
}
round_trip_hash(
&map,
(0..=1000).into(),
(..=1000).into(),
(1000..=2000).into(),
);
}
#[test]
fn test_hashmap_with_variable_values() {
let mut map = HashMap::new();
map.insert(Bytes::from_static(b"apple"), vec![1, 2]);
map.insert(Bytes::from_static(b"banana"), vec![3, 4, 5]);
map.insert(Bytes::from_static(b"cherry"), vec![]);
let map_range = RangeCfg::from(0..=10);
let key_range = RangeCfg::from(..=10);
let val_range = RangeCfg::from(0..=100);
round_trip_hash(&map, map_range, key_range, (val_range, ()));
}
#[test]
fn test_hashmap_decode_length_limit_exceeded() {
let mut map = HashMap::new();
map.insert(1u32, 100u64);
map.insert(5u32, 500u64);
let encoded = map.encode();
let config_tuple = ((0..=1).into(), ((), ()));
let result = HashMap::<u32, u64>::decode_cfg(encoded, &config_tuple);
assert!(matches!(result, Err(Error::InvalidLength(2))));
}
#[test]
fn test_hashmap_decode_value_length_limit_exceeded() {
let mut map = HashMap::new();
map.insert(Bytes::from_static(b"key1"), vec![1u8, 2u8, 3u8, 4u8, 5u8]);
let key_range = RangeCfg::from(..=10);
let map_range = RangeCfg::from(0..=10);
let restrictive_val_range = RangeCfg::from(0..=3);
let encoded = map.encode();
let config_tuple = (map_range, (key_range, (restrictive_val_range, ())));
let result = HashMap::<Bytes, Vec<u8>>::decode_cfg(encoded, &config_tuple);
assert!(matches!(result, Err(Error::InvalidLength(5))));
}
#[test]
fn test_hashmap_decode_invalid_key_order() {
let mut encoded = BytesMut::new();
        2usize.write(&mut encoded);
        5u32.write(&mut encoded);
        500u64.write(&mut encoded);
        2u32.write(&mut encoded);
        200u64.write(&mut encoded);
let range = (..).into();
let config_tuple = (range, ((), ()));
let result = HashMap::<u32, u64>::decode_cfg(encoded, &config_tuple);
assert!(matches!(
result,
Err(Error::Invalid("HashMap", "Keys must ascend"))
));
}
#[test]
fn test_hashmap_decode_duplicate_key() {
let mut encoded = BytesMut::new();
        2usize.write(&mut encoded);
        1u32.write(&mut encoded);
        100u64.write(&mut encoded);
        1u32.write(&mut encoded);
        200u64.write(&mut encoded);
let range = (..).into();
let config_tuple = (range, ((), ()));
let result = HashMap::<u32, u64>::decode_cfg(encoded, &config_tuple);
assert!(matches!(
result,
Err(Error::Invalid("HashMap", "Duplicate key"))
));
}
#[test]
fn test_hashmap_decode_end_of_buffer_key() {
let mut map = HashMap::new();
map.insert(1u32, 100u64);
map.insert(5u32, 500u64);
let mut encoded = map.encode();
encoded.truncate(map.encode_size() - 10);
let range = (..).into();
let config_tuple = (range, ((), ()));
let result = HashMap::<u32, u64>::decode_cfg(encoded, &config_tuple);
assert!(matches!(result, Err(Error::EndOfBuffer)));
}
#[test]
fn test_hashmap_decode_end_of_buffer_value() {
let mut map = HashMap::new();
map.insert(1u32, 100u64);
map.insert(5u32, 500u64);
let mut encoded = map.encode();
encoded.truncate(map.encode_size() - 4);
let range = RangeCfg::from(..);
let config_tuple = (range, ((), ()));
let result = HashMap::<u32, u64>::decode_cfg(encoded, &config_tuple);
assert!(matches!(result, Err(Error::EndOfBuffer)));
}
#[test]
fn test_hashmap_decode_extra_data() {
let mut map = HashMap::new();
map.insert(1u32, 100u64);
let mut encoded = map.encode();
encoded.put_u8(0xFF);
let config_tuple = ((..).into(), ((), ()));
let result = HashMap::<u32, u64>::decode_cfg(encoded.clone(), &config_tuple);
assert!(matches!(result, Err(Error::ExtraData(1))));
let read_result = HashMap::<u32, u64>::read_cfg(&mut encoded, &config_tuple);
assert!(read_result.is_ok());
let decoded_map = read_result.unwrap();
assert_eq!(decoded_map.len(), 1);
assert_eq!(decoded_map.get(&1u32), Some(&100u64));
}
#[test]
fn test_hashmap_deterministic_encoding() {
let mut map2 = HashMap::new();
(0..=1000u32).for_each(|i| {
map2.insert(i, i * 2);
});
let mut map1 = HashMap::new();
(0..=1000u32).rev().for_each(|i| {
map1.insert(i, i * 2);
});
assert_eq!(map1.encode(), map2.encode());
}
#[test]
fn test_hashmap_conformity() {
let map1 = HashMap::<u8, u16>::new();
let mut expected1 = BytesMut::new();
        0usize.write(&mut expected1);
        assert_eq!(map1.encode(), expected1.freeze());
let mut map2 = HashMap::<u8, u16>::new();
        map2.insert(2u8, 0xBBBBu16);
        map2.insert(1u8, 0xAAAAu16);
let mut expected2 = BytesMut::new();
        2usize.write(&mut expected2);
        1u8.write(&mut expected2);
        0xAAAAu16.write(&mut expected2);
        2u8.write(&mut expected2);
        0xBBBBu16.write(&mut expected2);
        assert_eq!(map2.encode(), expected2.freeze());
let mut map3 = HashMap::<u16, bool>::new();
map3.insert(0x0303u16, true);
map3.insert(0x0101u16, false);
map3.insert(0x0202u16, true);
let mut expected3 = BytesMut::new();
        3usize.write(&mut expected3);
        0x0101u16.write(&mut expected3);
        false.write(&mut expected3);
        0x0202u16.write(&mut expected3);
        true.write(&mut expected3);
        0x0303u16.write(&mut expected3);
        true.write(&mut expected3);
        assert_eq!(map3.encode(), expected3.freeze());
let mut map4 = HashMap::<Bytes, Vec<u8>>::new();
map4.insert(Bytes::from_static(b"b"), vec![20u8, 21u8]);
map4.insert(Bytes::from_static(b"a"), vec![10u8]);
let mut expected4 = BytesMut::new();
2usize.write(&mut expected4);
Bytes::from_static(b"a").write(&mut expected4);
vec![10u8].write(&mut expected4);
Bytes::from_static(b"b").write(&mut expected4);
vec![20u8, 21u8].write(&mut expected4);
assert_eq!(map4.encode(), expected4.freeze());
}
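    // A sketch of a cross-container check (not part of the original suite):
    // because the `HashMap` writer sorts its entries by key before encoding, a
    // `HashMap` and a `BTreeMap` holding the same pairs should produce
    // byte-identical encodings and report the same encoded size.
    #[test]
    fn test_hashmap_matches_btreemap_encoding() {
        let pairs = [(3u32, 30u64), (1u32, 10u64), (2u32, 20u64)];
        let hash: HashMap<u32, u64> = pairs.iter().copied().collect();
        let btree: BTreeMap<u32, u64> = pairs.iter().copied().collect();
        assert_eq!(hash.encode(), btree.encode());
        assert_eq!(hash.encode_size(), btree.encode_size());
    }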
}