causal_hub/types/
cache.rs

1use std::sync::{Arc, RwLock};
2
3use dry::macro_for;
4
5use crate::{
6    estimators::{CIMEstimator, CPDEstimator},
7    models::{CIM, CPD, Labelled},
8    types::{Labels, Map, Set},
9};
10
11/// A cache for calling a function with a key and value.
12#[derive(Clone, Debug)]
13pub struct Cache<'a, C, K, V> {
14    call: &'a C,
15    cache: Arc<RwLock<Map<K, V>>>,
16}
17
18impl<'a, E, P> Cache<'a, E, (Vec<usize>, Vec<usize>), P>
19where
20    P: Clone,
21{
22    /// Create a new cache.
23    ///
24    /// # Arguments
25    ///
26    /// * `call` - The function to call.
27    ///
28    /// # Returns
29    ///
30    /// A new cache.
31    ///
32    #[inline]
33    pub fn new(call: &'a E) -> Self {
34        // Create a new cache.
35        let cache = Arc::new(RwLock::new(Map::default()));
36
37        Self { call, cache }
38    }
39}
40
41impl<C, K, V> Labelled for Cache<'_, C, K, V>
42where
43    C: Labelled,
44{
45    #[inline]
46    fn labels(&self) -> &Labels {
47        self.call.labels()
48    }
49}
50
51macro_for!($type in [CPD, CIM] {
52    paste::paste! {
53
54        impl<E, P> [<$type Estimator>]<P> for Cache<'_, E, (Vec<usize>, Vec<usize>), P>
55        where
56            E: [<$type Estimator>]<P>,
57            P: $type + Clone,
58            P::Statistics: Clone,
59        {
60            fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> P {
61                // Get the key.
62                let key: (Vec<_>, Vec<_>) = (
63                    x.into_iter().cloned().collect(),
64                    z.into_iter().cloned().collect(),
65                );
66                // Check if the key is in the cache.
67                if let Some(value) = self.cache.read().unwrap().get(&key) {
68                    // If it is, return the value.
69                    return value.clone();
70                }
71                // If it is not, call the function.
72                let value = self.call.fit(x, z);
73                // Insert the value into the cache.
74                self.cache.write().unwrap().insert(key, value.clone());
75                // Return the value.
76                value
77            }
78        }
79
80    }
81});