Skip to main content

librashader_cache/
cache.rs

1use crate::cacheable::Cacheable;
2use crate::key::CacheKey;
3use std::panic::{catch_unwind, AssertUnwindSafe};
4
5pub(crate) mod internal {
6    #[derive(Debug, Error)]
7    enum CatchPanicError {
8        #[error("a panic ocurred when loading the database")]
9        Panic(Box<dyn Any + Send + 'static>),
10    }
11
12    use parking_lot::Mutex;
13    use persy::{ByteVec, Config, Persy, ValueMode};
14    use platform_dirs::AppDirs;
15    use std::any::Any;
16    use std::error::Error;
17    use std::panic::catch_unwind;
18    use std::path::PathBuf;
19    use std::sync::OnceLock;
20    use thiserror::Error;
21
22    pub(crate) fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
23        let cache_dir = if let Some(cache_dir) =
24            AppDirs::new(Some("librashader"), false).map(|a| a.cache_dir)
25        {
26            cache_dir
27        } else {
28            let mut current_dir = std::env::current_dir()?;
29            current_dir.push("librashader");
30            current_dir
31        };
32
33        std::fs::create_dir_all(&cache_dir)?;
34
35        Ok(cache_dir)
36    }
37
38    // pub(crate) fn get_cache() -> Result<Connection, Box<dyn Error>> {
39    //     let cache_dir = get_cache_dir()?;
40    //     let mut conn = Connection::open(&cache_dir.join("librashader.db"))?;
41    //
42    //     let tx = conn.transaction()?;
43    //     tx.pragma_update(Some(DatabaseName::Main), "journal_mode", "wal2")?;
44    //     tx.execute(
45    //         r#"create table if not exists cache (
46    //     type text not null,
47    //     id blob not null,
48    //     value blob not null unique,
49    //     primary key (id, type)
50    // )"#,
51    //         [],
52    //     )?;
53    //     tx.commit()?;
54    //     Ok(conn)
55    // }
56
57    pub(crate) fn remove_cache() {
58        let Ok(cache_dir) = get_cache_dir() else {
59            return;
60        };
61        let path = &cache_dir.join("librashader.db.1");
62        let _ = std::fs::remove_file(path).ok();
63    }
64
65    pub(crate) fn get_cache() -> Result<Persy, Box<dyn Error>> {
66        let cache_dir = get_cache_dir()?;
67        static CACHE: OnceLock<Persy> = OnceLock::new();
68
69        if let Some(persy) = CACHE.get() {
70            return Ok(persy.clone());
71        }
72
73        let persy = match catch_unwind(|| {
74            Persy::open_or_create_with(
75                &cache_dir.join("librashader.db.1"),
76                Config::new(),
77                |persy| {
78                    let tx = persy.begin()?;
79                    tx.commit()?;
80                    Ok(())
81                },
82            )
83        }) {
84            Ok(Ok(conn)) => Ok::<_, Box<dyn Error>>(conn),
85            Ok(Err(e)) => {
86                remove_cache();
87                Err(e)?
88            }
89            Err(e) => {
90                remove_cache();
91                Err(CatchPanicError::Panic(e))?
92            }
93        }?;
94
95        Ok(CACHE.get_or_init(move || persy).clone())
96    }
97
98    pub(crate) fn get_blob(
99        conn: &Persy,
100        index: &str,
101        key: &[u8],
102    ) -> Result<Option<Vec<u8>>, Box<dyn Error>> {
103        if !conn.exists_index(index)? {
104            return Ok(None);
105        }
106
107        let value = conn.get::<_, ByteVec>(index, &ByteVec::from(key))?.next();
108        Ok(value.map(|v| v.to_vec()))
109    }
110
111    pub(crate) fn set_blob(
112        conn: &Persy,
113        index: &str,
114        key: &[u8],
115        value: &[u8],
116    ) -> Result<(), Box<dyn Error>> {
117        static WRITE_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
118        let write_lock = WRITE_LOCK.get_or_init(|| Mutex::new(()));
119        let _guard = write_lock.lock();
120
121        let mut tx = conn.begin()?;
122        if !tx.exists_index(index)? {
123            tx.create_index::<ByteVec, ByteVec>(index, ValueMode::Replace)?;
124        }
125
126        tx.put(index, ByteVec::from(key), ByteVec::from(value))?;
127        tx.commit()?;
128
129        Ok(())
130    }
131}
132
133/// Cache a shader object (usually bytecode) created by the keyed objects.
134///
135/// - `factory` is the function that compiles the values passed as keys to a shader object.
136/// - `load` tries to load a compiled shader object to a driver-specialized result.
137pub fn cache_shader_object<E, T, R, H, const KEY_SIZE: usize>(
138    index: &str,
139    keys: &[H; KEY_SIZE],
140    factory: impl Fn(&[H; KEY_SIZE]) -> Result<T, E> + std::panic::RefUnwindSafe,
141    load: impl Fn(T) -> Result<R, E> + std::panic::RefUnwindSafe,
142    bypass_cache: bool,
143) -> Result<R, E>
144where
145    H: CacheKey + std::panic::RefUnwindSafe,
146    T: Cacheable,
147{
148    if bypass_cache {
149        return Ok(load(factory(keys)?)?);
150    }
151
152    catch_unwind(|| {
153        let cache = internal::get_cache();
154
155        let Ok(cache) = cache else {
156            return Ok(load(factory(keys)?)?);
157        };
158
159        let hashkey = {
160            let mut hasher = blake3::Hasher::new();
161            for subkeys in keys {
162                hasher.update(subkeys.hash_bytes());
163            }
164            let hash = hasher.finalize();
165            hash
166        };
167
168        'attempt: {
169            if let Ok(Some(blob)) = internal::get_blob(&cache, index, hashkey.as_bytes()) {
170                let cached = T::from_bytes(&blob).map(&load);
171
172                match cached {
173                    None => break 'attempt,
174                    Some(Err(_)) => break 'attempt,
175                    Some(Ok(res)) => return Ok(res),
176                }
177            }
178        };
179
180        let blob = factory(keys)?;
181
182        if let Some(slice) = T::to_bytes(&blob) {
183            let _ = internal::set_blob(&cache, index, hashkey.as_bytes(), &slice);
184        }
185        Ok(load(blob)?)
186    })
187    .unwrap_or_else(|_| {
188        internal::remove_cache();
189        Ok(load(factory(keys)?)?)
190    })
191}
192
193/// Cache a shader object (usually bytecode) created by the keyed objects, deferring the load step.
194///
195/// This behaves like [`cache_shader_object`], except that `load` is not executed immediately.
196/// Instead, the compiled (or cached) shader object is curried into the returned closure, which can
197/// be invoked later to produce the driver-specialized result. This allows shader objects to be
198/// compiled in parallel via `factory`, then have their driver resources created sequentially by the
199/// returned closures, for drivers whose object creation is not safe to call concurrently.
200///
201/// Because `load` is deferred, a cached shader object can not be validated by `load` at fetch time:
202/// unlike [`cache_shader_object`], a cached object that fails to load will not be transparently
203/// recompiled.
204pub fn cache_shader_object_deferred<'a, E, T, R, H, const KEY_SIZE: usize>(
205    index: &str,
206    keys: &[H; KEY_SIZE],
207    factory: impl Fn(&[H; KEY_SIZE]) -> Result<T, E> + std::panic::RefUnwindSafe,
208    load: impl FnOnce(T) -> Result<R, E> + Send + 'a,
209    bypass_cache: bool,
210) -> Result<Box<dyn FnOnce() -> Result<R, E> + Send + 'a>, E>
211where
212    H: CacheKey + std::panic::RefUnwindSafe,
213    T: Cacheable + Send + 'a,
214{
215    let object = if bypass_cache {
216        factory(keys)
217    } else {
218        catch_unwind(|| {
219            let Ok(cache) = internal::get_cache() else {
220                return factory(keys);
221            };
222
223            let hashkey = {
224                let mut hasher = blake3::Hasher::new();
225                for subkeys in keys {
226                    hasher.update(subkeys.hash_bytes());
227                }
228                hasher.finalize()
229            };
230
231            if let Ok(Some(blob)) = internal::get_blob(&cache, index, hashkey.as_bytes()) {
232                if let Some(cached) = T::from_bytes(&blob) {
233                    return Ok(cached);
234                }
235            }
236
237            let blob = factory(keys)?;
238
239            if let Some(slice) = T::to_bytes(&blob) {
240                let _ = internal::set_blob(&cache, index, hashkey.as_bytes(), &slice);
241            }
242            Ok(blob)
243        })
244        .unwrap_or_else(|_| {
245            internal::remove_cache();
246            factory(keys)
247        })
248    }?;
249
250    Ok(Box::new(move || load(object)))
251}
252
253/// Cache a pipeline state object.
254///
255/// Keys are not used to create the object and are only used to uniquely identify the pipeline state.
256///
257/// - `restore_pipeline` tries to restore the pipeline with either a cached binary pipeline state
258///    cache, or create a new pipeline if no cached value is available.
259/// - `fetch_pipeline_state` fetches the new pipeline state cache after the pipeline was created.
260pub fn cache_pipeline<E, T, R, const KEY_SIZE: usize>(
261    index: &str,
262    keys: &[&dyn CacheKey; KEY_SIZE],
263    restore_pipeline: impl Fn(Option<Vec<u8>>) -> Result<R, E>,
264    fetch_pipeline_state: impl Fn(&R) -> Result<T, E>,
265    bypass_cache: bool,
266) -> Result<R, E>
267where
268    T: Cacheable,
269{
270    if bypass_cache {
271        return Ok(restore_pipeline(None)?);
272    }
273
274    catch_unwind(AssertUnwindSafe(|| {
275        let cache = internal::get_cache();
276
277        let Ok(cache) = cache else {
278            return Ok(restore_pipeline(None)?);
279        };
280
281        let hashkey = {
282            let mut hasher = blake3::Hasher::new();
283            for subkeys in keys {
284                hasher.update(subkeys.hash_bytes());
285            }
286            let hash = hasher.finalize();
287            hash
288        };
289
290        let pipeline = 'attempt: {
291            if let Ok(Some(blob)) = internal::get_blob(&cache, index, hashkey.as_bytes()) {
292                let cached = restore_pipeline(Some(blob));
293                match cached {
294                    Ok(res) => {
295                        break 'attempt res;
296                    }
297                    _ => (),
298                }
299            }
300
301            restore_pipeline(None)?
302        };
303
304        // update the pso every time just in case.
305        if let Ok(state) = fetch_pipeline_state(&pipeline) {
306            if let Some(slice) = T::to_bytes(&state) {
307                // We don't really care if the transaction fails, just try again next time.
308                let _ = internal::set_blob(&cache, index, hashkey.as_bytes(), &slice);
309            }
310        }
311
312        Ok(pipeline)
313    }))
314    .unwrap_or_else(|_| {
315        internal::remove_cache();
316        Ok(restore_pipeline(None)?)
317    })
318}