use core::any::{Any, TypeId};
use core::hash::Hash;
use hashbrown::HashMap;
use smallbox::SmallBox;
use smallbox::space::S8;
#[cfg(not(feature = "std"))]
use alloc::boxed::Box;
#[cfg(feature = "std")]
use std::boxed::Box;
use crate::prelude::*;
/// Default number of distinct value types `DedupeEncoder::new` pre-sizes its
/// outer type map for.
const DEFAULT_NUM_TYPES: usize = 4;
/// Default capacity hint: the size of each per-type table created by
/// `DedupeEncoder::new`, and the number of value slots `DedupeDecoder::new`
/// reserves.
const DEFAULT_INITIAL_CAPACITY: usize = 128;
/// Opt-in marker for types whose `Encode` impl should deduplicate repeated values.
///
/// Implementing this trait gives the type a blanket `Encode` impl which, when the
/// encoder context carries a `DedupeEncoder`, writes each distinct value once and
/// refers back to it by ID afterwards; without one it falls back to `Pack::pack`.
pub trait DedupeEncodeable
where
    Self: Hash + Eq + Pack + Clone + Send + Sync + 'static,
{
}
impl<T: DedupeEncodeable> Encode for T {
#[inline(always)]
fn encode_ext(
&self,
writer: &mut impl Write,
ctx: Option<&mut crate::context::EncoderContext>,
) -> Result<usize> {
if let Some(ctx) = ctx
&& let Some(encoder) = ctx.dedupe.as_mut()
{
return encoder.encode(self, writer);
}
self.pack(writer)
}
#[inline(always)]
fn encode_slice(items: &[Self], writer: &mut impl Write) -> Result<usize> {
T::pack_slice(items, writer)
}
}
/// Opt-in marker for types whose `Decode` impl should resolve deduplicated values.
///
/// Implementing this trait gives the type a blanket `Decode` impl which, when the
/// decoder context carries a `DedupeDecoder`, reads either a full value or a
/// back-reference ID; without one it falls back to `Pack::unpack`.
pub trait DedupeDecodeable
where
    Self: Pack + Clone + Hash + Eq + Send + Sync + 'static,
{
}
impl<T: DedupeDecodeable> Decode for T {
#[inline(always)]
fn decode_ext(
reader: &mut impl Read,
ctx: Option<&mut crate::context::DecoderContext>,
) -> Result<Self> {
if let Some(ctx) = ctx
&& let Some(decoder) = ctx.dedupe.as_mut()
{
return decoder.decode(reader);
}
T::unpack(reader)
}
#[inline(always)]
fn decode_vec(reader: &mut impl Read, count: usize) -> Result<Vec<Self>> {
T::unpack_vec(reader, count)
}
}
/// Encoder-side dedupe state.
///
/// Each distinct value is given a numeric ID the first time it is encoded, so
/// later repeats can be written as just that ID.
pub struct DedupeEncoder {
    /// Next ID to hand out. IDs start at 1 because `0` on the wire marks a value
    /// that is written in full.
    next_id: usize,
    /// Capacity given to a per-type table when it is first created.
    initial_capacity: usize,
    /// One type-erased `HashMap<T, usize>` (value -> ID) per encoded type,
    /// keyed by `TypeId::of::<T>()`.
    type_stores: HashMap<TypeId, SmallBox<dyn Any + Send + Sync, S8>>,
}
impl Default for DedupeEncoder {
#[inline(always)]
fn default() -> Self {
Self::new()
}
}
impl DedupeEncoder {
    /// Creates an encoder whose per-type tables start with
    /// `DEFAULT_INITIAL_CAPACITY` slots and whose outer map has room for
    /// `DEFAULT_NUM_TYPES` types.
    #[inline(always)]
    pub fn new() -> Self {
        Self::with_capacity(DEFAULT_INITIAL_CAPACITY, DEFAULT_NUM_TYPES)
    }

    /// Creates an encoder with explicit sizing.
    ///
    /// * `initial_capacity` - capacity of each per-type table when it is created.
    /// * `num_types` - initial capacity of the outer map of per-type tables.
    #[inline(always)]
    pub fn with_capacity(initial_capacity: usize, num_types: usize) -> Self {
        Self {
            type_stores: HashMap::with_capacity(num_types),
            next_id: 1,
            initial_capacity,
        }
    }

    /// Forgets every recorded value and restarts ID assignment at 1.
    ///
    /// The matching `DedupeDecoder` must be cleared at the same point in the
    /// stream, otherwise new IDs would resolve to stale decoder slots.
    #[inline(always)]
    pub fn clear(&mut self) {
        self.type_stores.clear();
        self.next_id = 1;
    }

    /// Number of IDs issued since the last [`clear`](Self::clear).
    ///
    /// IDs are never reused, so this still counts values whose per-type table
    /// was dropped by [`clear_type`](Self::clear_type).
    #[inline(always)]
    pub const fn len(&self) -> usize {
        self.next_id - 1
    }

    /// `true` if no ID has been issued since the last [`clear`](Self::clear).
    #[inline(always)]
    pub const fn is_empty(&self) -> bool {
        self.next_id == 1
    }

    /// Number of distinct types that currently have a dedupe table.
    #[inline(always)]
    pub fn num_types(&self) -> usize {
        self.type_stores.len()
    }

    /// Iterates over the `TypeId`s that currently have a dedupe table.
    #[inline(always)]
    pub fn type_ids(&self) -> impl Iterator<Item = TypeId> + '_ {
        self.type_stores.keys().copied()
    }

    /// `true` if a dedupe table exists for `T`.
    #[inline]
    pub fn contains_type<T: 'static>(&self) -> bool {
        self.type_stores.contains_key(&TypeId::of::<T>())
    }

    /// Number of distinct `T` values currently recorded (0 if `T` has no table).
    #[inline]
    pub fn len_for_type<T: Hash + Eq + Send + Sync + 'static>(&self) -> usize {
        self.type_stores
            .get(&TypeId::of::<T>())
            .and_then(|store| store.downcast_ref::<HashMap<T, usize>>())
            .map_or(0, |m| m.len())
    }

    /// Iterates, in unspecified order, over the distinct `T` values currently
    /// recorded. Yields nothing if `T` has no table.
    #[inline]
    pub fn values_for_type<T: Hash + Eq + Send + Sync + 'static>(
        &self,
    ) -> impl Iterator<Item = &T> {
        let type_id = TypeId::of::<T>();
        self.type_stores
            .get(&type_id)
            .and_then(|store| store.downcast_ref::<HashMap<T, usize>>())
            .into_iter()
            .flat_map(|m| m.keys())
    }

    /// Drops the dedupe table for `T`.
    ///
    /// Already-issued IDs are not reused, so the decoder stays in sync: the next
    /// occurrence of a previously seen `T` value is written in full again and
    /// receives a fresh ID.
    #[inline]
    pub fn clear_type<T: Hash + Eq + Send + Sync + 'static>(&mut self) {
        let type_id = TypeId::of::<T>();
        self.type_stores.remove(&type_id);
    }

    /// Rough estimate of heap usage in bytes.
    ///
    /// This is an approximation. It counts the outer map's slots and assumes
    /// three machine words per issued ID. It ignores the actual capacity of
    /// each per-type table and any heap memory owned by the stored values.
    #[inline]
    pub fn memory_usage(&self) -> usize {
        use core::mem::size_of;
        let mut total = self.type_stores.capacity()
            * (size_of::<TypeId>() + size_of::<SmallBox<dyn Any + Send + Sync, S8>>());
        let entry_count = self.len();
        total += entry_count * size_of::<usize>() * 3;
        total
    }

    /// Encodes `val`, deduplicating against every `T` value recorded since the
    /// last [`clear`](Self::clear).
    ///
    /// Wire format:
    /// * The first occurrence of a value is written as varint `0` followed by
    ///   `val.pack(..)`.
    /// * Each later occurrence is written as just the varint ID assigned on its
    ///   first occurrence.
    ///
    /// IDs come from a single counter shared by all types. Returns the number of
    /// bytes written.
    ///
    /// # Errors
    /// Propagates any error from the writer or from `Pack::pack`. On error the
    /// value is *not* registered. This way the encoder never issues an ID for a
    /// record that was not fully written, which would desynchronize it from the
    /// decoder.
    ///
    /// # Panics
    /// Panics with "Type mismatch in type store" if the table stored under
    /// `TypeId::of::<T>()` is not a `HashMap<T, usize>`. Tables are only created
    /// here, under their own `TypeId`, so this indicates a bug.
    #[inline]
    pub fn encode<T: Hash + Eq + Pack + Clone + Send + Sync + 'static>(
        &mut self,
        val: &T,
        writer: &mut impl Write,
    ) -> Result<usize> {
        let initial_capacity = self.initial_capacity;
        let store = self.type_stores.entry(TypeId::of::<T>()).or_insert_with(|| {
            smallbox::smallbox!(HashMap::<T, usize>::with_capacity(initial_capacity))
        });
        let typed_store = store
            .downcast_mut::<HashMap<T, usize>>()
            .expect("Type mismatch in type store");

        if let Some(&existing_id) = typed_store.get(val) {
            return Lencode::encode_varint(existing_id, writer);
        }

        // Write the full record first and only then register the value: if the
        // write fails, no ID has been consumed and the encoder's view of the
        // stream still matches what the decoder can have seen.
        let mut total_bytes = Lencode::encode_varint(0usize, writer)?;
        total_bytes += val.pack(writer)?;

        typed_store.insert(val.clone(), self.next_id);
        self.next_id += 1;
        Ok(total_bytes)
    }
}
#[derive(Default)]
pub struct DedupeDecoder {
values: Vec<Box<dyn Any + Send + Sync>>,
}
impl DedupeDecoder {
#[inline(always)]
pub fn new() -> Self {
Self {
values: Vec::with_capacity(DEFAULT_INITIAL_CAPACITY),
}
}
#[inline(always)]
pub fn with_capacity(capacity: usize) -> Self {
Self {
values: Vec::with_capacity(capacity),
}
}
#[inline(always)]
pub fn clear(&mut self) {
self.values.clear();
}
#[inline(always)]
pub fn len(&self) -> usize {
self.values.len()
}
#[inline(always)]
pub fn is_empty(&self) -> bool {
self.values.is_empty()
}
#[inline]
pub fn memory_usage(&self) -> usize {
use core::mem::size_of;
self.values.capacity() * size_of::<Box<dyn Any + Send + Sync>>()
}
#[inline]
pub fn decode<T: Pack + Clone + Hash + Eq + Send + Sync + 'static>(
&mut self,
reader: &mut impl Read,
) -> Result<T> {
let id = Lencode::decode_varint::<usize>(reader)?;
if id == 0 {
let value = T::unpack(reader)?;
self.values.push(Box::new(value.clone()));
Ok(value)
} else {
let index = id - 1; if let Some(boxed_value) = self.values.get(index)
&& let Some(typed_value) = boxed_value.downcast_ref::<T>()
{
return Ok(typed_value.clone());
}
Err(crate::io::Error::InvalidData)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::Cursor;

    #[test]
    fn test_dedupe_encode_decode_roundtrip() {
        let mut encoder = DedupeEncoder::new();
        let mut decoder = DedupeDecoder::new();
        let mut buffer = Vec::new();
        let values = [42u32, 123u32, 42u32, 456u32, 123u32, 789u32, 42u32];
        for &value in &values {
            encoder.encode(&value, &mut buffer).unwrap();
        }
        let mut cursor = Cursor::new(&buffer);
        let mut decoded_values = Vec::new();
        for _ in &values {
            let decoded: u32 = decoder.decode(&mut cursor).unwrap();
            decoded_values.push(decoded);
        }
        assert_eq!(values.to_vec(), decoded_values);
    }

    #[test]
    fn test_dedupe_clear() {
        let mut encoder = DedupeEncoder::new();
        let mut decoder = DedupeDecoder::new();
        let mut buffer = Vec::new();
        encoder.encode(&42u32, &mut buffer).unwrap();
        encoder.encode(&123u32, &mut buffer).unwrap();
        encoder.clear();
        decoder.clear();
        buffer.clear();
        encoder.encode(&42u32, &mut buffer).unwrap();
        encoder.encode(&42u32, &mut buffer).unwrap();
        let mut cursor = Cursor::new(&buffer);
        let decoded1: u32 = decoder.decode(&mut cursor).unwrap();
        let decoded2: u32 = decoder.decode(&mut cursor).unwrap();
        assert_eq!(decoded1, 42u32);
        assert_eq!(decoded2, 42u32);
    }

    #[test]
    fn test_dedupe_len_for_type() {
        let mut encoder = DedupeEncoder::new();
        let mut buffer = Vec::new();
        assert_eq!(encoder.len_for_type::<u32>(), 0);
        assert_eq!(encoder.num_types(), 0);
        encoder.encode(&42u32, &mut buffer).unwrap();
        encoder.encode(&42u32, &mut buffer).unwrap();
        encoder.encode(&99u32, &mut buffer).unwrap();
        encoder.encode(&7u64, &mut buffer).unwrap();
        assert_eq!(encoder.len_for_type::<u32>(), 2);
        assert_eq!(encoder.len_for_type::<u64>(), 1);
        assert_eq!(encoder.len_for_type::<u16>(), 0);
        assert_eq!(encoder.num_types(), 2);
        assert_eq!(encoder.len(), 3);
    }

    #[test]
    fn test_dedupe_clear_type() {
        let mut encoder = DedupeEncoder::new();
        let mut buffer = Vec::new();
        encoder.encode(&42u32, &mut buffer).unwrap();
        encoder.encode(&7u64, &mut buffer).unwrap();
        assert_eq!(encoder.num_types(), 2);
        encoder.clear_type::<u32>();
        assert_eq!(encoder.len_for_type::<u32>(), 0);
        assert_eq!(encoder.len_for_type::<u64>(), 1);
        assert_eq!(encoder.num_types(), 1);
    }

    #[test]
    fn test_dedupe_clear_type_keeps_decoder_in_sync() {
        let mut encoder = DedupeEncoder::new();
        let mut decoder = DedupeDecoder::new();
        let mut buffer = Vec::new();
        encoder.encode(&42u32, &mut buffer).unwrap(); // full, ID 1
        encoder.encode(&7u64, &mut buffer).unwrap(); // full, ID 2
        encoder.clear_type::<u32>();
        encoder.encode(&42u32, &mut buffer).unwrap(); // full again, ID 3
        encoder.encode(&7u64, &mut buffer).unwrap(); // back-reference to ID 2
        encoder.encode(&42u32, &mut buffer).unwrap(); // back-reference to ID 3
        let mut cursor = Cursor::new(&buffer);
        let a: u32 = decoder.decode(&mut cursor).unwrap();
        let b: u64 = decoder.decode(&mut cursor).unwrap();
        let c: u32 = decoder.decode(&mut cursor).unwrap();
        let d: u64 = decoder.decode(&mut cursor).unwrap();
        let e: u32 = decoder.decode(&mut cursor).unwrap();
        assert_eq!((a, b, c, d, e), (42, 7, 42, 7, 42));
        assert_eq!(decoder.len(), 3);
        assert_eq!(encoder.len(), 3);
    }

    #[test]
    fn test_dedupe_mixed_types_roundtrip() {
        let mut encoder = DedupeEncoder::new();
        let mut decoder = DedupeDecoder::new();
        let mut buffer = Vec::new();
        // Equal numeric values of different types must get distinct IDs.
        encoder.encode(&1u32, &mut buffer).unwrap();
        encoder.encode(&1u64, &mut buffer).unwrap();
        encoder.encode(&1u32, &mut buffer).unwrap();
        encoder.encode(&1u64, &mut buffer).unwrap();
        let mut cursor = Cursor::new(&buffer);
        let a: u32 = decoder.decode(&mut cursor).unwrap();
        let b: u64 = decoder.decode(&mut cursor).unwrap();
        let c: u32 = decoder.decode(&mut cursor).unwrap();
        let d: u64 = decoder.decode(&mut cursor).unwrap();
        assert_eq!((a, b, c, d), (1, 1, 1, 1));
        assert_eq!(decoder.len(), 2);
    }

    #[test]
    fn test_dedupe_type_mismatch_is_error() {
        let mut encoder = DedupeEncoder::new();
        let mut decoder = DedupeDecoder::new();
        let mut buffer = Vec::new();
        encoder.encode(&5u32, &mut buffer).unwrap();
        encoder.encode(&5u32, &mut buffer).unwrap();
        let mut cursor = Cursor::new(&buffer);
        let first: u32 = decoder.decode(&mut cursor).unwrap();
        assert_eq!(first, 5);
        // The back-reference points at a u32 slot; asking for a u64 must fail.
        let second: Result<u64> = decoder.decode(&mut cursor);
        assert!(matches!(second, Err(crate::io::Error::InvalidData)));
    }

    #[test]
    fn test_dedupe_memory_usage() {
        let mut encoder = DedupeEncoder::new();
        let mut buffer = Vec::new();
        let initial = encoder.memory_usage();
        encoder.encode(&42u32, &mut buffer).unwrap();
        encoder.encode(&99u32, &mut buffer).unwrap();
        let after = encoder.memory_usage();
        assert!(
            after > initial,
            "memory usage should increase after storing entries"
        );
    }

    #[test]
    fn test_dedupe_decoder_memory_usage() {
        let decoder = DedupeDecoder::new();
        let usage = decoder.memory_usage();
        // `new()` reserves at least DEFAULT_INITIAL_CAPACITY slots.
        let slot = core::mem::size_of::<Box<dyn core::any::Any + Send + Sync>>();
        assert!(usage >= DEFAULT_INITIAL_CAPACITY * slot);
    }

    #[test]
    fn test_dedupe_invalid_id() {
        let mut decoder = DedupeDecoder::new();
        let mut buffer = Vec::new();
        Lencode::encode_varint(5usize, &mut buffer).unwrap();
        let mut cursor = Cursor::new(&buffer);
        let result: Result<u32> = decoder.decode(&mut cursor);
        assert!(result.is_err());
        // `matches!` only yields a bool; it must be asserted to check anything.
        assert!(matches!(result, Err(crate::io::Error::InvalidData)));
    }
}