use crate::{
local_cache::{LocalAccess, LocalCache},
lru_cache::{EntryState, LruCache},
Compressed, Compression,
};
use std::collections::{hash_map::RandomState, HashMap};
use std::hash::{BuildHasher, Hash};
/// A map whose recently used values live decompressed in an LRU cache while the remaining values
/// are stored compressed.
///
/// Every key is either cached (its decompressed value lives in `cache`) or evicted (the cache only
/// remembers that the key is evicted, and its value lives in `compressed`).
// NOTE(review): the `Compression` bound sits on the struct itself, presumably because
// `Compressed<A>` requires it; confirm before moving it to the impl blocks.
pub struct CompressibleMap<K, V, A: Compression<Data = V>, H = RandomState> {
    /// LRU cache of decompressed values; it also tracks which keys are currently evicted.
    cache: LruCache<K, V, H>,
    /// Compressed values for every key the cache reports as evicted.
    compressed: HashMap<K, Compressed<A>, H>,
    /// Parameters handed to `Compression::compress` when a cached value is evicted.
    compression_params: A,
}
impl<K, V, H, A> CompressibleMap<K, V, A, H>
where
    K: Clone + Eq + Hash,
    H: BuildHasher + Default,
    A: Compression<Data = V>,
{
    /// Creates an empty map that compresses evicted values with `compression_params`.
    pub fn new(compression_params: A) -> Self {
        Self {
            cache: LruCache::default(),
            compressed: HashMap::default(),
            compression_params,
        }
    }

    /// Returns the parameters used by [`compress_lru`](Self::compress_lru).
    pub fn compression_params(&self) -> &A {
        &self.compression_params
    }

    /// Builds a map in which every entry of `compressed` starts out compressed.
    ///
    /// Each key is marked as evicted in the LRU cache, so later lookups know to fetch its value
    /// from compressed storage.
    pub fn from_all_compressed(
        compression_params: A,
        compressed: HashMap<K, Compressed<A>, H>,
    ) -> Self {
        let mut cache = LruCache::<K, V, H>::default();
        for key in compressed.keys() {
            cache.evict(key.clone());
        }
        Self {
            cache,
            compressed,
            compression_params,
        }
    }

    /// Inserts a decompressed `value` for `key`.
    ///
    /// Returns the previous value in whichever form it was stored, or `None` if `key` was absent.
    ///
    /// # Panics
    ///
    /// Panics if the cache reports `key` as evicted but no compressed value exists for it (a
    /// broken internal invariant).
    pub fn insert(&mut self, key: K, value: V) -> Option<MaybeCompressed<V, Compressed<A>>> {
        let old_entry = self.cache.insert(key.clone(), value)?;
        let old_value = match old_entry {
            EntryState::Cached(v) => MaybeCompressed::Decompressed(v),
            // The key is now cached again, so its stale compressed copy must go.
            EntryState::Evicted => MaybeCompressed::Compressed(
                self.compressed
                    .remove(&key)
                    .expect("evicted key has no compressed value"),
            ),
        };
        Some(old_value)
    }

    /// Inserts an already-compressed `value` for `key` and marks `key` as evicted in the cache.
    ///
    /// Returns the previous value in whichever form it was stored, or `None` if `key` was absent.
    pub fn insert_compressed(
        &mut self,
        key: K,
        value: Compressed<A>,
    ) -> Option<MaybeCompressed<V, Compressed<A>>> {
        // `evict` reports the previous entry state. Only a previously cached entry carries a
        // decompressed value. A previously evicted entry's old value is the compressed one that
        // the `insert` below replaces.
        let old_cached_value = self
            .cache
            .evict(key.clone())
            .and_then(|entry| entry.some_if_cached());
        self.compressed
            .insert(key, value)
            .map(MaybeCompressed::Compressed)
            .or_else(|| old_cached_value.map(MaybeCompressed::Decompressed))
    }

    /// Inserts `value` using [`insert`](Self::insert) or
    /// [`insert_compressed`](Self::insert_compressed), depending on its form.
    pub fn insert_maybe_compressed(
        &mut self,
        key: K,
        value: MaybeCompressed<V, Compressed<A>>,
    ) -> Option<MaybeCompressed<V, Compressed<A>>> {
        match value {
            MaybeCompressed::Compressed(c) => self.insert_compressed(key, c),
            MaybeCompressed::Decompressed(d) => self.insert(key, d),
        }
    }

    /// Evicts the least-recently-used cached value, compresses it, and stores it in compressed
    /// storage.
    ///
    /// Does nothing when the cache has no entry to evict.
    pub fn compress_lru(&mut self) {
        if let Some((lru_key, lru_value)) = self.cache.evict_lru() {
            let compressed_value = self.compression_params.compress(&lru_value);
            self.compressed.insert(lru_key, compressed_value);
        }
    }

    /// Removes and returns the least-recently-used cached entry without compressing it.
    pub fn remove_lru(&mut self) -> Option<(K, V)> {
        self.cache.remove_lru()
    }

    /// Returns a mutable reference to the value for `key`.
    ///
    /// If the entry was compressed, it is first decompressed back into the cache.
    ///
    /// # Panics
    ///
    /// Panics if `key` is marked as evicted but has no compressed value.
    pub fn get_mut(&mut self, key: K) -> Option<&mut V> {
        // Destructure so the repopulation closure can borrow `compressed` while `cache` is
        // borrowed mutably.
        let CompressibleMap {
            cache, compressed, ..
        } = self;
        cache.get_or_repopulate_with(key.clone(), || {
            compressed
                .remove(&key)
                .expect("evicted key has no compressed value")
                .decompress()
        })
    }

    /// Like [`get_mut`](Self::get_mut), but returns a shared reference.
    pub fn get(&mut self, key: K) -> Option<&V> {
        self.get_mut(key).map(|v| &*v)
    }

    /// Returns the value for `key`.
    ///
    /// A compressed entry is decompressed back into the cache. An absent key gets the value
    /// produced by `on_missing`.
    ///
    /// # Panics
    ///
    /// Panics if `key` is marked as evicted but has no compressed value.
    pub fn get_or_insert_with(&mut self, key: K, on_missing: impl FnOnce() -> V) -> &mut V {
        let CompressibleMap {
            cache, compressed, ..
        } = self;
        let on_evicted = || {
            compressed
                .remove(&key)
                .expect("evicted key has no compressed value")
                .decompress()
        };
        cache.get_or_insert_with(key.clone(), on_evicted, on_missing)
    }

    /// Inserts `value` only if `key` has no entry, then returns the entry's value.
    pub fn insert_if_vacant(&mut self, key: K, value: V) -> &mut V {
        self.get_or_insert_with(key, || value)
    }

    /// Reads the value for `key` through a shared reference to the map.
    ///
    /// A cached hit is recorded in `local_cache`. A compressed entry is decompressed into
    /// `local_cache`, and the returned reference points there. Pass `local_cache` to
    /// [`flush_local_cache`](Self::flush_local_cache) to apply these accesses to the map.
    ///
    /// # Panics
    ///
    /// Panics if `key` is marked as evicted but has no compressed value.
    pub fn get_const<'a>(&'a self, key: K, local_cache: &'a LocalCache<K, V, H>) -> Option<&'a V> {
        let entry = self.cache.get_const(&key)?;
        let value = match entry {
            EntryState::Cached(v) => {
                local_cache.remember_cached_access(key.clone());
                v
            }
            EntryState::Evicted => local_cache.get_or_insert_with(key.clone(), || {
                self.compressed
                    .get(&key)
                    .expect("evicted key has no compressed value")
                    .decompress()
            }),
        };
        Some(value)
    }

    /// Returns a clone of the value for `key` in its stored form, without touching the cache.
    ///
    /// # Panics
    ///
    /// Panics if `key` is marked as evicted but has no compressed value.
    pub fn get_copy_without_caching(&self, key: &K) -> Option<MaybeCompressed<V, Compressed<A>>>
    where
        V: Clone,
        Compressed<A>: Clone,
    {
        let entry = self.cache.get_const(key)?;
        let copy = match entry {
            EntryState::Cached(v) => MaybeCompressed::Decompressed(v.clone()),
            EntryState::Evicted => MaybeCompressed::Compressed(
                self.compressed
                    .get(key)
                    .expect("evicted key has no compressed value")
                    .clone(),
            ),
        };
        Some(copy)
    }

    /// Applies the accesses recorded in `local_cache` by [`get_const`](Self::get_const).
    ///
    /// Cached hits are touched in the LRU cache. Values that were decompressed into the local
    /// cache are moved into the map's cache, and their compressed copies are dropped.
    pub fn flush_local_cache(&mut self, local_cache: LocalCache<K, V, H>) {
        let CompressibleMap {
            cache, compressed, ..
        } = self;
        for (key, access) in local_cache.into_iter() {
            match access {
                LocalAccess::Cached => {
                    cache.get(&key);
                }
                LocalAccess::Missed(value) => {
                    // The closure only runs if the entry still needs repopulating. Otherwise the
                    // already-decompressed `value` is simply dropped.
                    cache.get_or_repopulate_with(key.clone(), || {
                        compressed.remove(&key);
                        value
                    });
                }
            }
        }
    }

    /// Removes `key` and discards its value, whether it was cached or compressed.
    pub fn drop(&mut self, key: &K) {
        self.cache.remove(key);
        self.compressed.remove(key);
    }

    /// Removes `key` and returns its value in whichever form it was stored.
    ///
    /// # Panics
    ///
    /// Panics if `key` is marked as evicted but has no compressed value.
    pub fn remove(&mut self, key: &K) -> Option<MaybeCompressed<V, Compressed<A>>> {
        let entry = self.cache.remove(key)?;
        let removed = match entry {
            EntryState::Cached(v) => MaybeCompressed::Decompressed(v),
            EntryState::Evicted => MaybeCompressed::Compressed(
                self.compressed
                    .remove(key)
                    .expect("evicted key has no compressed value"),
            ),
        };
        Some(removed)
    }

    /// Removes every entry, cached and compressed.
    pub fn clear(&mut self) {
        self.cache.clear();
        self.compressed.clear();
    }

    /// Total number of entries: cached plus compressed.
    pub fn len(&self) -> usize {
        self.len_cached() + self.len_compressed()
    }

    /// Number of entries currently held decompressed in the cache.
    pub fn len_cached(&self) -> usize {
        self.cache.len_cached()
    }

    /// Number of entries currently held compressed.
    pub fn len_compressed(&self) -> usize {
        self.compressed.len()
    }

    /// Returns `true` if the map holds no entries.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Iterates over the keys of the map.
    ///
    /// This delegates to the LRU cache, which also tracks the keys of evicted entries (see
    /// `from_all_compressed`), so compressed keys are expected to appear as well.
    pub fn keys<'a>(&'a self) -> impl Iterator<Item = &K>
    where
        Compressed<A>: 'a,
    {
        self.cache.keys()
    }

    /// Iterates over all entries without changing them.
    ///
    /// Cached entries come first, as `Decompressed`, followed by compressed entries.
    pub fn iter<'a>(&'a self) -> impl Iterator<Item = (&K, MaybeCompressed<&V, &Compressed<A>>)>
    where
        Compressed<A>: 'a,
    {
        let cached = self
            .cache
            .iter()
            .map(|(k, v)| (k, MaybeCompressed::Decompressed(v)));
        let compressed = self
            .compressed
            .iter()
            .map(|(k, v)| (k, MaybeCompressed::Compressed(v)));
        cached.chain(compressed)
    }

    /// Consumes the map and yields every entry in its stored form.
    ///
    /// Cached entries come first, followed by compressed entries.
    pub fn into_iter(self) -> impl Iterator<Item = (K, MaybeCompressed<V, Compressed<A>>)> {
        let cached = self
            .cache
            .into_iter()
            .map(|(k, v)| (k, MaybeCompressed::Decompressed(v)));
        let compressed = self
            .compressed
            .into_iter()
            .map(|(k, v)| (k, MaybeCompressed::Compressed(v)));
        cached.chain(compressed)
    }
}
/// A value stored in one of two forms: decompressed (`T`) or compressed (`C`).
pub enum MaybeCompressed<T, C> {
    /// The plain, ready-to-use value.
    Decompressed(T),
    /// The compressed representation of the value.
    Compressed(C),
}
impl<A: Compression> MaybeCompressed<A::Data, Compressed<A>> {
    /// Consumes `self` and returns the plain data, decompressing it first if necessary.
    pub fn as_decompressed(self) -> A::Data {
        match self {
            Self::Decompressed(data) => data,
            Self::Compressed(packed) => packed.decompress(),
        }
    }

    /// Consumes `self` and returns the plain data without doing any decompression.
    ///
    /// # Panics
    ///
    /// Panics with "Must be decompressed" if `self` is the `Compressed` variant.
    pub fn unwrap_decompressed(self) -> A::Data {
        if let Self::Decompressed(data) = self {
            data
        } else {
            panic!("Must be decompressed")
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test codec: "compression" adds 1 and decompression adds 1 again.
    ///
    /// A value that has gone through a compress/decompress round trip is therefore observably
    /// `+2`.
    struct FakeFooCompression;

    impl Compression for FakeFooCompression {
        type Data = Foo;
        type CompressedData = Foo;

        fn compress(&self, data: &Self::Data) -> Compressed<Self> {
            Compressed::new(Foo(data.0 + 1))
        }

        fn decompress(compressed: &Self::CompressedData) -> Self::Data {
            Foo(compressed.0 + 1)
        }
    }

    #[derive(Clone, Debug, Default, Eq, PartialEq)]
    struct Foo(u32);

    #[test]
    fn get_after_compress() {
        let mut map = CompressibleMap::<_, _, _>::new(FakeFooCompression);
        map.insert(1, Foo(0));
        map.compress_lru();
        assert_eq!(map.len_cached(), 0);
        assert_eq!(map.len_compressed(), 1);

        // Reading decompresses the entry back into the cache (round trip => +2).
        assert_eq!(Some(&Foo(2)), map.get(1));
        assert_eq!(map.len_cached(), 1);
        assert_eq!(map.len_compressed(), 0);
    }

    #[test]
    fn keys_iterator_has_both_cached_and_compressed() {
        let mut map = CompressibleMap::<_, _, _>::new(FakeFooCompression);
        map.insert(1, Foo(0));
        map.insert(2, Foo(0));
        map.compress_lru();

        let mut keys: Vec<i32> = map.keys().cloned().collect();
        keys.sort_unstable();
        assert_eq!(keys, vec![1, 2]);
    }

    #[test]
    fn flush_after_get_const_populates_cache() {
        fn do_test_with_global_cache(map: &mut CompressibleMap<i32, Foo, FakeFooCompression>) {
            map.insert(1, Foo(0));
            map.insert(2, Foo(1));
            map.compress_lru();
            map.compress_lru();

            let local_cache = LocalCache::default();
            let values = vec![map.get_const(1, &local_cache), map.get_const(2, &local_cache)];
            assert_eq!(Some(&Foo(2)), values[0]);
            assert_eq!(Some(&Foo(3)), values[1]);

            // `get_const` must not have touched the map itself.
            assert_eq!(map.len_cached(), 0);
            assert_eq!(map.len_compressed(), 2);

            map.flush_local_cache(local_cache);
            assert_eq!(map.len_cached(), 2);
            assert_eq!(map.len_compressed(), 0);
            assert_eq!(Some(&Foo(2)), map.get(1));
            assert_eq!(Some(&Foo(3)), map.get(2));
        }

        let mut map = CompressibleMap::new(FakeFooCompression);
        do_test_with_global_cache(&mut map);
    }

    #[test]
    fn multithreaded_borrows() {
        use crossbeam::thread;

        let mut map = CompressibleMap::<_, _, _>::new(FakeFooCompression);
        for i in 0..100 {
            map.insert(i, Foo(i));
        }
        // Keys 0..50 were inserted first, so they are the ones compressed here.
        for _ in 0..50 {
            map.compress_lru();
        }

        let local_cache = LocalCache::new();
        let batch: Vec<_> = (0..100).map(|i| map.get_const(i, &local_cache)).collect();

        thread::scope(|s| {
            for (i, value) in batch.into_iter().enumerate() {
                s.spawn(move |_| {
                    if i < 50 {
                        assert_eq!(value, Some(&Foo((i + 2) as u32)))
                    } else {
                        assert_eq!(value, Some(&Foo(i as u32)))
                    }
                });
            }
        })
        .unwrap();

        map.flush_local_cache(local_cache);
        assert_eq!(map.len_cached(), 100);
    }

    #[test]
    fn multithreaded_decompression() {
        use crossbeam::{channel, thread};

        let mut map = CompressibleMap::<_, _, _>::new(FakeFooCompression);
        for i in 0..100 {
            map.insert(i, Foo(i));
        }
        for _ in 0..50 {
            map.compress_lru();
        }

        let map_ref = &map;
        let (tx, rx) = channel::unbounded();
        {
            // One sender per thread. All of them are dropped at the end of this scope, which
            // lets the receive loop below terminate after the last local cache arrives.
            let txs = vec![tx; 100];
            let txs_ref = &txs;
            thread::scope(|s| {
                for i in 0..100 {
                    s.spawn(move |_| {
                        let local_cache = LocalCache::new();
                        if i < 50 {
                            assert_eq!(map_ref.get_const(i, &local_cache), Some(&Foo(i + 2)))
                        } else {
                            assert_eq!(map_ref.get_const(i, &local_cache), Some(&Foo(i)))
                        }
                        txs_ref[i as usize].send(local_cache).unwrap();
                    });
                }
            })
            .unwrap();
        }

        while let Ok(cache) = rx.recv() {
            map.flush_local_cache(cache);
        }
        assert_eq!(map.len_cached(), 100);
    }
}