1use crate::cacheable::Cacheable;
2use crate::key::CacheKey;
3use std::panic::{catch_unwind, AssertUnwindSafe};
4
5pub(crate) mod internal {
6 #[derive(Debug, Error)]
7 enum CatchPanicError {
8 #[error("a panic ocurred when loading the database")]
9 Panic(Box<dyn Any + Send + 'static>),
10 }
11
12 use parking_lot::Mutex;
13 use persy::{ByteVec, Config, Persy, ValueMode};
14 use platform_dirs::AppDirs;
15 use std::any::Any;
16 use std::error::Error;
17 use std::panic::catch_unwind;
18 use std::path::PathBuf;
19 use std::sync::OnceLock;
20 use thiserror::Error;
21
22 pub(crate) fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
23 let cache_dir = if let Some(cache_dir) =
24 AppDirs::new(Some("librashader"), false).map(|a| a.cache_dir)
25 {
26 cache_dir
27 } else {
28 let mut current_dir = std::env::current_dir()?;
29 current_dir.push("librashader");
30 current_dir
31 };
32
33 std::fs::create_dir_all(&cache_dir)?;
34
35 Ok(cache_dir)
36 }
37
38 pub(crate) fn remove_cache() {
58 let Ok(cache_dir) = get_cache_dir() else {
59 return;
60 };
61 let path = &cache_dir.join("librashader.db.1");
62 let _ = std::fs::remove_file(path).ok();
63 }
64
65 pub(crate) fn get_cache() -> Result<Persy, Box<dyn Error>> {
66 let cache_dir = get_cache_dir()?;
67 static CACHE: OnceLock<Persy> = OnceLock::new();
68
69 if let Some(persy) = CACHE.get() {
70 return Ok(persy.clone());
71 }
72
73 let persy = match catch_unwind(|| {
74 Persy::open_or_create_with(
75 &cache_dir.join("librashader.db.1"),
76 Config::new(),
77 |persy| {
78 let tx = persy.begin()?;
79 tx.commit()?;
80 Ok(())
81 },
82 )
83 }) {
84 Ok(Ok(conn)) => Ok::<_, Box<dyn Error>>(conn),
85 Ok(Err(e)) => {
86 remove_cache();
87 Err(e)?
88 }
89 Err(e) => {
90 remove_cache();
91 Err(CatchPanicError::Panic(e))?
92 }
93 }?;
94
95 Ok(CACHE.get_or_init(move || persy).clone())
96 }
97
98 pub(crate) fn get_blob(
99 conn: &Persy,
100 index: &str,
101 key: &[u8],
102 ) -> Result<Option<Vec<u8>>, Box<dyn Error>> {
103 if !conn.exists_index(index)? {
104 return Ok(None);
105 }
106
107 let value = conn.get::<_, ByteVec>(index, &ByteVec::from(key))?.next();
108 Ok(value.map(|v| v.to_vec()))
109 }
110
111 pub(crate) fn set_blob(
112 conn: &Persy,
113 index: &str,
114 key: &[u8],
115 value: &[u8],
116 ) -> Result<(), Box<dyn Error>> {
117 static WRITE_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
118 let write_lock = WRITE_LOCK.get_or_init(|| Mutex::new(()));
119 let _guard = write_lock.lock();
120
121 let mut tx = conn.begin()?;
122 if !tx.exists_index(index)? {
123 tx.create_index::<ByteVec, ByteVec>(index, ValueMode::Replace)?;
124 }
125
126 tx.put(index, ByteVec::from(key), ByteVec::from(value))?;
127 tx.commit()?;
128
129 Ok(())
130 }
131}
132
133pub fn cache_shader_object<E, T, R, H, const KEY_SIZE: usize>(
138 index: &str,
139 keys: &[H; KEY_SIZE],
140 factory: impl Fn(&[H; KEY_SIZE]) -> Result<T, E> + std::panic::RefUnwindSafe,
141 load: impl Fn(T) -> Result<R, E> + std::panic::RefUnwindSafe,
142 bypass_cache: bool,
143) -> Result<R, E>
144where
145 H: CacheKey + std::panic::RefUnwindSafe,
146 T: Cacheable,
147{
148 if bypass_cache {
149 return Ok(load(factory(keys)?)?);
150 }
151
152 catch_unwind(|| {
153 let cache = internal::get_cache();
154
155 let Ok(cache) = cache else {
156 return Ok(load(factory(keys)?)?);
157 };
158
159 let hashkey = {
160 let mut hasher = blake3::Hasher::new();
161 for subkeys in keys {
162 hasher.update(subkeys.hash_bytes());
163 }
164 let hash = hasher.finalize();
165 hash
166 };
167
168 'attempt: {
169 if let Ok(Some(blob)) = internal::get_blob(&cache, index, hashkey.as_bytes()) {
170 let cached = T::from_bytes(&blob).map(&load);
171
172 match cached {
173 None => break 'attempt,
174 Some(Err(_)) => break 'attempt,
175 Some(Ok(res)) => return Ok(res),
176 }
177 }
178 };
179
180 let blob = factory(keys)?;
181
182 if let Some(slice) = T::to_bytes(&blob) {
183 let _ = internal::set_blob(&cache, index, hashkey.as_bytes(), &slice);
184 }
185 Ok(load(blob)?)
186 })
187 .unwrap_or_else(|_| {
188 internal::remove_cache();
189 Ok(load(factory(keys)?)?)
190 })
191}
192
193pub fn cache_shader_object_deferred<'a, E, T, R, H, const KEY_SIZE: usize>(
205 index: &str,
206 keys: &[H; KEY_SIZE],
207 factory: impl Fn(&[H; KEY_SIZE]) -> Result<T, E> + std::panic::RefUnwindSafe,
208 load: impl FnOnce(T) -> Result<R, E> + Send + 'a,
209 bypass_cache: bool,
210) -> Result<Box<dyn FnOnce() -> Result<R, E> + Send + 'a>, E>
211where
212 H: CacheKey + std::panic::RefUnwindSafe,
213 T: Cacheable + Send + 'a,
214{
215 let object = if bypass_cache {
216 factory(keys)
217 } else {
218 catch_unwind(|| {
219 let Ok(cache) = internal::get_cache() else {
220 return factory(keys);
221 };
222
223 let hashkey = {
224 let mut hasher = blake3::Hasher::new();
225 for subkeys in keys {
226 hasher.update(subkeys.hash_bytes());
227 }
228 hasher.finalize()
229 };
230
231 if let Ok(Some(blob)) = internal::get_blob(&cache, index, hashkey.as_bytes()) {
232 if let Some(cached) = T::from_bytes(&blob) {
233 return Ok(cached);
234 }
235 }
236
237 let blob = factory(keys)?;
238
239 if let Some(slice) = T::to_bytes(&blob) {
240 let _ = internal::set_blob(&cache, index, hashkey.as_bytes(), &slice);
241 }
242 Ok(blob)
243 })
244 .unwrap_or_else(|_| {
245 internal::remove_cache();
246 factory(keys)
247 })
248 }?;
249
250 Ok(Box::new(move || load(object)))
251}
252
253pub fn cache_pipeline<E, T, R, const KEY_SIZE: usize>(
261 index: &str,
262 keys: &[&dyn CacheKey; KEY_SIZE],
263 restore_pipeline: impl Fn(Option<Vec<u8>>) -> Result<R, E>,
264 fetch_pipeline_state: impl Fn(&R) -> Result<T, E>,
265 bypass_cache: bool,
266) -> Result<R, E>
267where
268 T: Cacheable,
269{
270 if bypass_cache {
271 return Ok(restore_pipeline(None)?);
272 }
273
274 catch_unwind(AssertUnwindSafe(|| {
275 let cache = internal::get_cache();
276
277 let Ok(cache) = cache else {
278 return Ok(restore_pipeline(None)?);
279 };
280
281 let hashkey = {
282 let mut hasher = blake3::Hasher::new();
283 for subkeys in keys {
284 hasher.update(subkeys.hash_bytes());
285 }
286 let hash = hasher.finalize();
287 hash
288 };
289
290 let pipeline = 'attempt: {
291 if let Ok(Some(blob)) = internal::get_blob(&cache, index, hashkey.as_bytes()) {
292 let cached = restore_pipeline(Some(blob));
293 match cached {
294 Ok(res) => {
295 break 'attempt res;
296 }
297 _ => (),
298 }
299 }
300
301 restore_pipeline(None)?
302 };
303
304 if let Ok(state) = fetch_pipeline_state(&pipeline) {
306 if let Some(slice) = T::to_bytes(&state) {
307 let _ = internal::set_blob(&cache, index, hashkey.as_bytes(), &slice);
309 }
310 }
311
312 Ok(pipeline)
313 }))
314 .unwrap_or_else(|_| {
315 internal::remove_cache();
316 Ok(restore_pipeline(None)?)
317 })
318}