causal_hub/types/
cache.rs1use std::sync::{Arc, RwLock};
2
3use dry::macro_for;
4
5use crate::{
6 estimators::{CIMEstimator, CPDEstimator},
7 models::{CIM, CPD, Labelled},
8 types::{Labels, Map, Set},
9};
10
11#[derive(Clone, Debug)]
13pub struct Cache<'a, C, K, V> {
14 call: &'a C,
15 cache: Arc<RwLock<Map<K, V>>>,
16}
17
18impl<'a, E, P> Cache<'a, E, (Vec<usize>, Vec<usize>), P>
19where
20 P: Clone,
21{
22 #[inline]
33 pub fn new(call: &'a E) -> Self {
34 let cache = Arc::new(RwLock::new(Map::default()));
36
37 Self { call, cache }
38 }
39}
40
41impl<C, K, V> Labelled for Cache<'_, C, K, V>
42where
43 C: Labelled,
44{
45 #[inline]
46 fn labels(&self) -> &Labels {
47 self.call.labels()
48 }
49}
50
51macro_for!($type in [CPD, CIM] {
52 paste::paste! {
53
54 impl<E, P> [<$type Estimator>]<P> for Cache<'_, E, (Vec<usize>, Vec<usize>), P>
55 where
56 E: [<$type Estimator>]<P>,
57 P: $type + Clone,
58 P::Statistics: Clone,
59 {
60 fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> P {
61 let key: (Vec<_>, Vec<_>) = (
63 x.into_iter().cloned().collect(),
64 z.into_iter().cloned().collect(),
65 );
66 if let Some(value) = self.cache.read().unwrap().get(&key) {
68 return value.clone();
70 }
71 let value = self.call.fit(x, z);
73 self.cache.write().unwrap().insert(key, value.clone());
75 value
77 }
78 }
79
80 }
81});