1use crate::{
2 local_cache::{LocalAccess, LocalCache},
3 lru_cache::{EntryState, LruCache},
4 Compressed, Compression,
5};
6
7use std::collections::{hash_map::RandomState, HashMap};
8use std::hash::{BuildHasher, Hash};
9
10pub struct CompressibleMap<K, V, A, H = RandomState>
25where
26 A: Compression<Data = V>,
27{
28 cache: LruCache<K, V, H>,
29 compressed: HashMap<K, Compressed<A>, H>,
30 compression_params: A,
31}
32
33impl<K, V, H, A> CompressibleMap<K, V, A, H>
34where
35 K: Clone + Eq + Hash,
36 H: BuildHasher + Default,
37 A: Compression<Data = V>,
38{
39 pub fn new(compression_params: A) -> Self {
40 Self {
41 cache: LruCache::default(),
42 compressed: HashMap::default(),
43 compression_params,
44 }
45 }
46
47 pub fn compression_params(&self) -> &A {
48 &self.compression_params
49 }
50
51 pub fn from_all_compressed(
52 compression_params: A,
53 compressed: HashMap<K, Compressed<A>, H>,
54 ) -> Self {
55 let mut cache = LruCache::<K, V, H>::default();
56 for key in compressed.keys() {
57 cache.evict(key.clone());
58 }
59
60 Self {
61 cache,
62 compressed,
63 compression_params,
64 }
65 }
66
67 pub fn insert(&mut self, key: K, value: V) -> Option<MaybeCompressed<V, Compressed<A>>> {
69 self.cache
70 .insert(key.clone(), value)
71 .map(|old_cache_entry| match old_cache_entry {
72 EntryState::Cached(v) => MaybeCompressed::Decompressed(v),
73 EntryState::Evicted => {
74 let compressed_value = self.compressed.remove(&key).unwrap();
75
76 MaybeCompressed::Compressed(compressed_value)
77 }
78 })
79 }
80
81 pub fn insert_compressed(
83 &mut self,
84 key: K,
85 value: Compressed<A>,
86 ) -> Option<MaybeCompressed<V, Compressed<A>>> {
87 let old_cached_value = self
88 .cache
89 .evict(key.clone())
90 .map(|e| e.some_if_cached())
91 .flatten();
92
93 self.compressed
94 .insert(key, value)
95 .map(|v| MaybeCompressed::Compressed(v))
96 .or(old_cached_value.map(|v| MaybeCompressed::Decompressed(v)))
97 }
98
99 pub fn insert_maybe_compressed(
100 &mut self,
101 key: K,
102 value: MaybeCompressed<V, Compressed<A>>,
103 ) -> Option<MaybeCompressed<V, Compressed<A>>> {
104 match value {
105 MaybeCompressed::Compressed(c) => self.insert_compressed(key, c),
106 MaybeCompressed::Decompressed(c) => self.insert(key, c),
107 }
108 }
109
110 pub fn compress_lru(&mut self) {
111 if let Some((lru_key, lru_value)) = self.cache.evict_lru() {
112 self.compressed
113 .insert(lru_key, self.compression_params.compress(&lru_value));
114 }
115 }
116
117 pub fn remove_lru(&mut self) -> Option<(K, V)> {
118 self.cache.remove_lru()
119 }
120
121 pub fn get_mut(&mut self, key: K) -> Option<&mut V> {
122 let CompressibleMap {
123 cache, compressed, ..
124 } = self;
125
126 cache.get_or_repopulate_with(key.clone(), || {
127 compressed.remove(&key).map(|v| v.decompress()).unwrap()
128 })
129 }
130
131 pub fn get(&mut self, key: K) -> Option<&V> {
132 self.get_mut(key).map(|v| &*v)
134 }
135
136 pub fn get_or_insert_with(&mut self, key: K, on_missing: impl FnOnce() -> V) -> &mut V {
137 let CompressibleMap {
138 cache, compressed, ..
139 } = self;
140
141 let on_evicted = || compressed.remove(&key).unwrap().decompress();
142
143 cache.get_or_insert_with(key.clone(), on_evicted, on_missing)
144 }
145
146 pub fn insert_if_vacant(&mut self, key: K, value: V) -> &mut V {
147 self.get_or_insert_with(key, || value)
148 }
149
150 pub fn get_const<'a>(&'a self, key: K, local_cache: &'a LocalCache<K, V, H>) -> Option<&'a V> {
156 self.cache.get_const(&key).map(|entry| {
157 match entry {
158 EntryState::Cached(v) => {
159 local_cache.remember_cached_access(key.clone());
161
162 v
163 }
164 EntryState::Evicted => {
165 local_cache.get_or_insert_with(key.clone(), || {
167 self.compressed.get(&key).unwrap().decompress()
168 })
169 }
170 }
171 })
172 }
173
174 pub fn get_copy_without_caching(&self, key: &K) -> Option<MaybeCompressed<V, Compressed<A>>>
179 where
180 V: Clone,
181 Compressed<A>: Clone,
182 {
183 self.cache.get_const(key).map(|entry| match entry {
184 EntryState::Cached(v) => MaybeCompressed::Decompressed(v.clone()),
185 EntryState::Evicted => {
186 MaybeCompressed::Compressed(self.compressed.get(key).unwrap().clone())
187 }
188 })
189 }
190
191 pub fn flush_local_cache(&mut self, local_cache: LocalCache<K, V, H>) {
195 let CompressibleMap {
196 cache, compressed, ..
197 } = self;
198 for (key, access) in local_cache.into_iter() {
199 match access {
200 LocalAccess::Cached => {
201 cache.get(&key);
204 }
205 LocalAccess::Missed(value) => {
206 cache.get_or_repopulate_with(key.clone(), || {
210 compressed.remove(&key);
211
212 value
213 });
214 }
215 }
216 }
217 }
218
219 pub fn drop(&mut self, key: &K) {
220 self.cache.remove(key);
221 self.compressed.remove(key);
222 }
223
224 pub fn remove(&mut self, key: &K) -> Option<MaybeCompressed<V, Compressed<A>>> {
226 self.cache.remove(key).map(|entry| match entry {
227 EntryState::Cached(v) => MaybeCompressed::Decompressed(v),
228 EntryState::Evicted => {
229 MaybeCompressed::Compressed(self.compressed.remove(key).unwrap())
230 }
231 })
232 }
233
234 pub fn clear(&mut self) {
235 self.cache.clear();
236 self.compressed.clear();
237 }
238
239 pub fn len(&self) -> usize {
240 self.len_cached() + self.len_compressed()
241 }
242
243 pub fn len_cached(&self) -> usize {
244 self.cache.len_cached()
245 }
246
247 pub fn len_compressed(&self) -> usize {
248 self.compressed.len()
249 }
250
251 pub fn is_empty(&self) -> bool {
252 self.len() == 0
253 }
254
255 pub fn keys<'a>(&'a self) -> impl Iterator<Item = &K>
256 where
257 Compressed<A>: 'a,
258 {
259 self.cache.keys()
260 }
261
262 pub fn iter<'a>(&'a self) -> impl Iterator<Item = (&K, MaybeCompressed<&V, &Compressed<A>>)>
265 where
266 Compressed<A>: 'a,
267 {
268 self.cache
269 .iter()
270 .map(|(k, v)| (k, MaybeCompressed::Decompressed(v)))
271 .chain(
272 self.compressed
273 .iter()
274 .map(|(k, v)| (k, MaybeCompressed::Compressed(v))),
275 )
276 }
277
278 pub fn into_iter(self) -> impl Iterator<Item = (K, MaybeCompressed<V, Compressed<A>>)> {
279 self.cache
280 .into_iter()
281 .map(|(k, v)| (k, MaybeCompressed::Decompressed(v)))
282 .chain(
283 self.compressed
284 .into_iter()
285 .map(|(k, v)| (k, MaybeCompressed::Compressed(v))),
286 )
287 }
288}
289
/// A value that is held either in its decompressed form `D` or in its compressed form `C`.
pub enum MaybeCompressed<D, C> {
    /// The value is available as-is.
    Decompressed(D),
    /// The value is compressed and must be decompressed before use.
    Compressed(C),
}
294
295impl<A: Compression> MaybeCompressed<A::Data, Compressed<A>> {
296 pub fn as_decompressed(self) -> A::Data {
297 match self {
298 MaybeCompressed::Compressed(c) => c.decompress(),
299 MaybeCompressed::Decompressed(d) => d,
300 }
301 }
302
303 pub fn unwrap_decompressed(self) -> A::Data {
304 match self {
305 MaybeCompressed::Compressed(_) => panic!("Must be decompressed"),
306 MaybeCompressed::Decompressed(d) => d,
307 }
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    /// Fake compression scheme: both compressing and decompressing add one to the wrapped
    /// number, so a value that went through a full round trip is larger by two. This makes it
    /// observable in assertions whether a value was compressed and decompressed.
    struct FakeFooCompression;

    impl Compression for FakeFooCompression {
        type Data = Foo;
        type CompressedData = Foo;

        fn compress(&self, data: &Self::Data) -> Compressed<Self> {
            Compressed::new(Foo(data.0 + 1))
        }

        fn decompress(compressed: &Self::CompressedData) -> Self::Data {
            Foo(compressed.0 + 1)
        }
    }

    #[derive(Clone, Debug, Default, Eq, PartialEq)]
    struct Foo(u32);

    #[test]
    fn get_after_compress() {
        let mut map = CompressibleMap::<_, _, _>::new(FakeFooCompression);

        map.insert(1, Foo(0));

        map.compress_lru();

        assert_eq!(map.len_cached(), 0);
        assert_eq!(map.len_compressed(), 1);

        // One compress/decompress round trip adds two.
        assert_eq!(Some(&Foo(2)), map.get(1));

        // The lookup moved the value back into the cache.
        assert_eq!(map.len_cached(), 1);
        assert_eq!(map.len_compressed(), 0);
    }

    #[test]
    fn keys_iterator_has_both_cached_and_compressed() {
        let mut map = CompressibleMap::<_, _, _>::new(FakeFooCompression);

        map.insert(1, Foo(0));
        map.insert(2, Foo(0));

        map.compress_lru();

        let mut keys: Vec<i32> = map.keys().cloned().collect();
        keys.sort();
        assert_eq!(keys, vec![1, 2]);
    }

    #[test]
    fn flush_after_get_const_populates_cache() {
        // Taking the map through a helper pins down the concrete type parameters.
        fn do_test_with_global_cache(map: &mut CompressibleMap<i32, Foo, FakeFooCompression>) {
            map.insert(1, Foo(0));
            map.insert(2, Foo(1));

            // Compress both entries.
            map.compress_lru();
            map.compress_lru();

            let local_cache = LocalCache::default();
            let mut values = Vec::new();
            values.push(map.get_const(1, &local_cache));
            values.push(map.get_const(2, &local_cache));

            assert_eq!(Some(&Foo(2)), values[0]);
            assert_eq!(Some(&Foo(3)), values[1]);

            // `get_const` must not have modified the map itself.
            assert_eq!(map.len_cached(), 0);
            assert_eq!(map.len_compressed(), 2);

            map.flush_local_cache(local_cache);

            // Flushing moves the locally decompressed values into the map's cache.
            assert_eq!(map.len_cached(), 2);
            assert_eq!(map.len_compressed(), 0);

            assert_eq!(Some(&Foo(2)), map.get(1));
            assert_eq!(Some(&Foo(3)), map.get(2));
        }

        let mut map = CompressibleMap::new(FakeFooCompression);
        do_test_with_global_cache(&mut map);
    }

    #[test]
    fn multithreaded_borrows() {
        use crossbeam::thread;

        let mut map = CompressibleMap::<_, _, _>::new(FakeFooCompression);
        for i in 0..100 {
            map.insert(i, Foo(i));
        }

        // Compress the 50 least-recently-used entries, i.e. keys 0..50.
        for _ in 0..50 {
            map.compress_lru();
        }

        // Borrow every value on this thread, then hand the references to worker threads.
        let local_cache = LocalCache::new();
        let mut batch = Vec::new();
        for i in 0..100 {
            batch.push(map.get_const(i, &local_cache));
        }

        thread::scope(|s| {
            for (i, value) in batch.into_iter().enumerate() {
                s.spawn(move |_| {
                    if i < 50 {
                        // Went through a compress/decompress round trip.
                        assert_eq!(value, Some(&Foo((i + 2) as u32)))
                    } else {
                        // Was never compressed.
                        assert_eq!(value, Some(&Foo(i as u32)))
                    }
                });
            }
        })
        .unwrap();

        map.flush_local_cache(local_cache);

        assert_eq!(map.len_cached(), 100);
    }

    #[test]
    fn multithreaded_decompression() {
        use crossbeam::{channel, thread};

        let mut map = CompressibleMap::<_, _, _>::new(FakeFooCompression);
        for i in 0..100 {
            map.insert(i, Foo(i));
        }

        // Compress the 50 least-recently-used entries, i.e. keys 0..50.
        for _ in 0..50 {
            map.compress_lru();
        }

        let map_ref = &map;
        let (tx, rx) = channel::unbounded();
        {
            // One sender per worker thread. All senders are dropped at the end of this block,
            // which is what lets the receive loop below terminate.
            let mut txs = Vec::new();
            for _ in 0..99 {
                txs.push(tx.clone());
            }
            txs.push(tx);
            let txs_ref = &txs;

            thread::scope(|s| {
                for i in 0..100 {
                    s.spawn(move |_| {
                        // Each thread decompresses into its own local cache.
                        let local_cache = LocalCache::new();
                        if i < 50 {
                            assert_eq!(
                                map_ref.get_const(i, &local_cache),
                                Some(&Foo((i + 2) as u32))
                            )
                        } else {
                            assert_eq!(map_ref.get_const(i, &local_cache), Some(&Foo(i as u32)))
                        }

                        txs_ref[i as usize].send(local_cache).unwrap();
                    });
                }
            })
            .unwrap();
        }

        // Drain every local cache back into the map; `recv` fails once all senders are gone.
        while let Ok(cache) = rx.recv() {
            map.flush_local_cache(cache);
        }

        assert_eq!(map.len_cached(), 100);
    }
}