// concrete_utils/keycache.rs

1use once_cell::sync::OnceCell;
2use serde::de::DeserializeOwned;
3use serde::Serialize;
4use std::fs::File;
5use std::io::{BufReader, BufWriter};
6use std::ops::Deref;
7use std::path::PathBuf;
8use std::sync::{Arc, RwLock};
9
/// A backing store that keeps keys around between program runs.
///
/// `P` is the parameter set a key was built from and doubles as the lookup
/// handle; `K` is the key type itself.
pub trait PersistentStorage<P, K> {
    /// Fetches the key previously saved for `param`.
    ///
    /// `None` means nothing usable was found for these parameters.
    fn load(&self, param: P) -> Option<K>;

    /// Saves `key` so that a later `load` with the same `param` can return it.
    fn store(&self, param: P, key: &K);
}
14
/// A parameter set that can describe itself with a name.
///
/// Storage backends may use this name as an identifier, e.g. as the file name
/// under which a key generated from these parameters is cached.
pub trait NamedParam {
    /// Returns the name identifying this parameter set.
    fn name(&self) -> String;
}
18
/// Generates the body of a `NamedParam::name` implementation by comparing a
/// value against a list of named constants.
///
/// The result is the identifier of the first constant equal to the value, as a
/// `String`. If no constant matches, it panics with `"Unnamed parameters"`.
///
/// Two forms are accepted:
/// - `named_params_impl!(self == (A, B, C))`: the identifier is dereferenced
///   (`*self`) before comparing;
/// - `named_params_impl!({ expr } == (A, B, C))`: `expr` is compared as-is.
///
/// A trailing comma is allowed in the constant list. Because the expansion
/// uses `return`, it is meant to be used inside a function returning `String`.
#[macro_export]
macro_rules! named_params_impl(
    ( $thing:ident == ( $($const_param:ident),* $(,)? )) => {
        // `$crate::` makes the recursive call resolve even when the caller
        // did not import this macro by name (e.g. invoked it by path).
        $crate::named_params_impl!({ *$thing } == ( $($const_param),* ))
    };

    ( { $thing:expr } == ( $($const_param:ident),* $(,)? )) => {{
        // Evaluate the compared expression once, not once per constant.
        // Borrowing (rather than copying) keeps non-`Copy` types working.
        #[allow(unused_variables)]
        let value = &$thing;
        $(
            if *value == $const_param {
                return stringify!($const_param).to_string();
            }
        )*

        panic!("Unnamed parameters")
    }};
);
35
/// A `PersistentStorage` backend that keeps each key in its own file.
///
/// A key for parameters named `name` lives at `<prefix>/<name>.bin`.
pub struct FileStorage {
    /// Directory holding the cached key files.
    prefix: String,
}

impl FileStorage {
    /// Creates a storage rooted at the directory `prefix`.
    ///
    /// Nothing is read or created on disk by this constructor.
    pub fn new(prefix: String) -> Self {
        FileStorage { prefix }
    }
}
45
46impl<P, K> PersistentStorage<P, K> for FileStorage
47where
48    P: NamedParam + DeserializeOwned + Serialize + PartialEq,
49    K: DeserializeOwned + Serialize,
50{
51    fn load(&self, param: P) -> Option<K> {
52        let mut path_buf = PathBuf::with_capacity(256);
53        path_buf.push(&self.prefix);
54        path_buf.push(param.name());
55        path_buf.set_extension("bin");
56
57        if path_buf.exists() {
58            let file = BufReader::new(File::open(&path_buf).unwrap());
59            bincode::deserialize_from::<_, (P, K)>(file)
60                .ok()
61                .and_then(|(p, k)| if p == param { Some(k) } else { None })
62        } else {
63            None
64        }
65    }
66
67    fn store(&self, param: P, key: &K) {
68        let mut path_buf = PathBuf::with_capacity(256);
69        path_buf.push(&self.prefix);
70        std::fs::create_dir_all(&path_buf).unwrap();
71        path_buf.push(param.name());
72        path_buf.set_extension("bin");
73
74        let file = BufWriter::new(File::create(&path_buf).unwrap());
75        bincode::serialize_into(file, &(param, key)).unwrap();
76    }
77}
78
79pub struct SharedKey<K> {
80    inner: Arc<OnceCell<K>>,
81}
82
83impl<K> Clone for SharedKey<K> {
84    fn clone(&self) -> Self {
85        Self {
86            inner: self.inner.clone(),
87        }
88    }
89}
90
91impl<K> Deref for SharedKey<K> {
92    type Target = K;
93
94    fn deref(&self) -> &Self::Target {
95        self.inner.get().unwrap()
96    }
97}
98
99pub struct KeyCache<P, K, S> {
100    // Where the keys will be stored persistently
101    // So they are not generated between each run
102    persistent_storage: S,
103    // Temporary memory storage to avoid querying the persistent storage each time
104    // the outer Arc makes it so that we don't clone the OnceCell contents when initializing it
105    memory_storage: RwLock<Vec<(P, SharedKey<K>)>>,
106}
107
108impl<P, K, S> KeyCache<P, K, S> {
109    pub fn new(storage: S) -> Self {
110        Self {
111            persistent_storage: storage,
112            memory_storage: RwLock::new(vec![]),
113        }
114    }
115}
116
117impl<P, K, S> KeyCache<P, K, S>
118where
119    P: Copy + PartialEq + NamedParam,
120    S: PersistentStorage<P, K>,
121    K: From<P> + Clone,
122{
123    pub fn get(&self, param: P) -> SharedKey<K> {
124        self.with_key(param, |k| k.clone())
125    }
126
127    pub fn with_key<F, R>(&self, param: P, f: F) -> R
128    where
129        F: FnOnce(&SharedKey<K>) -> R,
130    {
131        let load_from_persistent_storage = || {
132            // we check if we can load the key from persistent storage
133            let persistent_storage = &self.persistent_storage;
134            let maybe_key = persistent_storage.load(param);
135            match maybe_key {
136                Some(key) => key,
137                None => {
138                    let key = K::from(param);
139                    persistent_storage.store(param, &key);
140                    key
141                }
142            }
143        };
144
145        let try_load_from_memory_and_init = || {
146            // we only hold a read lock for a short duration to find the key
147            let memory_storage = self.memory_storage.read().unwrap();
148            let maybe_shared_cell = memory_storage
149                .iter()
150                .find(|(p, _)| *p == param)
151                .map(|param_key| param_key.1.clone());
152            drop(memory_storage);
153
154            if let Some(shared_cell) = maybe_shared_cell {
155                shared_cell.inner.get_or_init(load_from_persistent_storage);
156                Ok(shared_cell)
157            } else {
158                Err(())
159            }
160        };
161
162        match try_load_from_memory_and_init() {
163            Ok(result) => f(&result),
164            Err(()) => {
165                {
166                    // we only hold a write lock for a short duration to push the lazily evaluated
167                    // key without actually evaluating the key
168                    let mut memory_storage = self.memory_storage.write().unwrap();
169                    if !memory_storage.iter().any(|(p, _)| *p == param) {
170                        memory_storage.push((
171                            param,
172                            SharedKey {
173                                inner: Arc::new(OnceCell::new()),
174                            },
175                        ));
176                    }
177                }
178                f(&try_load_from_memory_and_init().ok().unwrap())
179            }
180        }
181    }
182}