1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
use core::hash::Hash;
use dashmap::{DashMap, DashSet};
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::marker::{Send, Sync};

/// Concurrent memoization cache whose entries can be evicted in groups by tag.
///
/// Both keys and values are stored as serialized byte vectors; see
/// `DashmapCache::cached` for how they are produced.
#[derive(Clone, Debug)]
pub struct DashmapCache {
    /// Serialized argument bytes mapped to the serialized result bytes.
    inner: DashMap<Vec<u8>, Vec<u8>>,
    /// Tag name mapped to the set of `inner` keys registered under that tag.
    tags: DashMap<String, DashSet<Vec<u8>>>,
}

/// Errors produced while serializing cache keys/values or deserializing a
/// cached value.
#[derive(Debug)]
pub enum CacheError {
    /// Stored bytes could not be deserialized into the requested value type.
    Decode(rmp_serde::decode::Error),
    /// An argument or a computed value could not be serialized.
    Encode(rmp_serde::encode::Error),
}

impl std::fmt::Display for CacheError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The underlying error is exposed through `source()`, so it is not
        // repeated here; error reporters that walk the chain print it once.
        match self {
            Self::Decode(_) => f.write_str("failed to decode cached value"),
            Self::Encode(_) => f.write_str("failed to encode cache key or value"),
        }
    }
}

impl std::error::Error for CacheError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Decode(err) => Some(err),
            Self::Encode(err) => Some(err),
        }
    }
}

impl From<rmp_serde::decode::Error> for CacheError {
    fn from(value: rmp_serde::decode::Error) -> Self {
        Self::Decode(value)
    }
}

impl From<rmp_serde::encode::Error> for CacheError {
    fn from(value: rmp_serde::encode::Error) -> Self {
        Self::Encode(value)
    }
}

impl DashmapCache {
    /// Creates an empty cache with no entries and no tags.
    pub fn new() -> Self {
        Self {
            inner: DashMap::new(),
            tags: DashMap::new(),
        }
    }

    /// Stores `val` under `key` and registers `key` under every tag in `tags`.
    ///
    /// Returns the value previously stored under `key`, if any.
    fn insert(&self, tags: &[String], key: Vec<u8>, val: Vec<u8>) -> Option<Vec<u8>> {
        // Store the value before tagging it. If a concurrent `invalidate`
        // slips in between the two steps, the entry survives but is still
        // tagged afterwards, so a later invalidation can evict it. With the
        // opposite order the entry could end up cached with no tag at all.
        let previous = self.inner.insert(key.clone(), val);

        for tag in tags {
            // `entry` creates or reuses the tag's key set atomically. A separate
            // `contains_key` check followed by `insert` lets two threads each
            // build a fresh set, with the second overwriting the first.
            self.tags.entry(tag.clone()).or_default().insert(key.clone());
        }

        previous
    }

    /// Returns the memoized result of `closure(&arg)`, computing and storing
    /// it on a cache miss.
    ///
    /// The cache key is the serialized form of `arg` alone. The closure is not
    /// part of the key, so different closures called with an equal `arg`
    /// share one entry.
    ///
    /// `invalidate_keys` lists the tags the new entry is registered under; it
    /// is only consulted on a miss. Pass any of those tags to
    /// [`DashmapCache::invalidate`] to evict the entry.
    ///
    /// No lock is held while `closure` runs, so concurrent misses for the same
    /// `arg` may each invoke it.
    ///
    /// # Errors
    ///
    /// * [`CacheError::Encode`] if `arg` or the freshly computed value cannot
    ///   be serialized.
    /// * [`CacheError::Decode`] if the stored bytes cannot be deserialized
    ///   into `V` (for example, the entry was written with a different value
    ///   type).
    pub fn cached<F, A, V>(
        &self,
        invalidate_keys: &Vec<String>,
        closure: F,
        arg: A,
    ) -> Result<V, CacheError>
    where
        F: Fn(&A) -> V,
        A: Hash + Sync + Send + Eq + Serialize,
        V: Send + Sync + Clone + Serialize + for<'b> Deserialize<'b>,
    {
        let arg_bytes = rmp_serde::to_vec(&arg)?;

        // Hit: decode straight from the stored bytes. The decoded value is
        // already owned, so no extra clone is needed.
        if let Some(hit) = self.inner.get(&arg_bytes) {
            return Ok(rmp_serde::from_slice::<V>(hit.value())?);
        }

        // Miss: compute, then store the serialized form under every tag.
        let val = closure(&arg);
        let val_bytes = rmp_serde::to_vec(&val)?;
        self.insert(invalidate_keys, arg_bytes, val_bytes);
        Ok(val)
    }

    /// Detaches the key set registered under `tag` and evicts each of those
    /// keys from the value map. Does nothing if the tag is unknown.
    fn invalidate_inner(&self, tag: &str) {
        // Removing the set in one step means keys registered concurrently go
        // into a fresh set for this tag instead of being discarded, un-evicted,
        // by a later removal. Owning the set also avoids cloning it and avoids
        // holding a guard on `tags` while touching `inner`.
        if let Some((_, keys)) = self.tags.remove(tag) {
            for key in keys {
                self.inner.remove(&key);
            }
        }
    }

    /// Evicts every cached entry that was registered under `tag` and forgets
    /// the tag itself.
    ///
    /// Keys of evicted entries are not removed from the sets of any other tags
    /// they were registered under.
    pub fn invalidate(&self, tag: &str) {
        self.invalidate_inner(tag);
    }
}

impl Default for DashmapCache {
    /// Equivalent to [`DashmapCache::new`].
    fn default() -> Self {
        Self::new()
    }
}