librashader_cache/
cache.rs

1use crate::cacheable::Cacheable;
2use crate::key::CacheKey;
3use std::panic::{catch_unwind, AssertUnwindSafe};
4
5pub(crate) mod internal {
6    #[derive(Debug, Error)]
7    enum CatchPanicError {
8        #[error("a panic ocurred when loading the database")]
9        Panic(Box<dyn Any + Send + 'static>),
10    }
11
12    use platform_dirs::AppDirs;
13    use std::any::Any;
14    use std::error::Error;
15    use std::panic::catch_unwind;
16    use std::path::PathBuf;
17
18    use persy::{ByteVec, Config, Persy, ValueMode};
19    use thiserror::Error;
20
21    pub(crate) fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
22        let cache_dir = if let Some(cache_dir) =
23            AppDirs::new(Some("librashader"), false).map(|a| a.cache_dir)
24        {
25            cache_dir
26        } else {
27            let mut current_dir = std::env::current_dir()?;
28            current_dir.push("librashader");
29            current_dir
30        };
31
32        std::fs::create_dir_all(&cache_dir)?;
33
34        Ok(cache_dir)
35    }
36
37    // pub(crate) fn get_cache() -> Result<Connection, Box<dyn Error>> {
38    //     let cache_dir = get_cache_dir()?;
39    //     let mut conn = Connection::open(&cache_dir.join("librashader.db"))?;
40    //
41    //     let tx = conn.transaction()?;
42    //     tx.pragma_update(Some(DatabaseName::Main), "journal_mode", "wal2")?;
43    //     tx.execute(
44    //         r#"create table if not exists cache (
45    //     type text not null,
46    //     id blob not null,
47    //     value blob not null unique,
48    //     primary key (id, type)
49    // )"#,
50    //         [],
51    //     )?;
52    //     tx.commit()?;
53    //     Ok(conn)
54    // }
55
56    pub(crate) fn remove_cache() {
57        let Ok(cache_dir) = get_cache_dir() else {
58            return;
59        };
60        let path = &cache_dir.join("librashader.db.1");
61        let _ = std::fs::remove_file(path).ok();
62    }
63
64    pub(crate) fn get_cache() -> Result<Persy, Box<dyn Error>> {
65        let cache_dir = get_cache_dir()?;
66        match catch_unwind(|| {
67            Persy::open_or_create_with(
68                &cache_dir.join("librashader.db.1"),
69                Config::new(),
70                |persy| {
71                    let tx = persy.begin()?;
72                    tx.commit()?;
73                    Ok(())
74                },
75            )
76        }) {
77            Ok(Ok(conn)) => Ok(conn),
78            Ok(Err(e)) => {
79                remove_cache();
80                Err(e)?
81            }
82            Err(e) => {
83                remove_cache();
84                Err(CatchPanicError::Panic(e))?
85            }
86        }
87    }
88
89    pub(crate) fn get_blob(
90        conn: &Persy,
91        index: &str,
92        key: &[u8],
93    ) -> Result<Option<Vec<u8>>, Box<dyn Error>> {
94        if !conn.exists_index(index)? {
95            return Ok(None);
96        }
97
98        let value = conn.get::<_, ByteVec>(index, &ByteVec::from(key))?.next();
99        Ok(value.map(|v| v.to_vec()))
100    }
101
102    pub(crate) fn set_blob(
103        conn: &Persy,
104        index: &str,
105        key: &[u8],
106        value: &[u8],
107    ) -> Result<(), Box<dyn Error>> {
108        let mut tx = conn.begin()?;
109        if !tx.exists_index(index)? {
110            tx.create_index::<ByteVec, ByteVec>(index, ValueMode::Replace)?;
111        }
112
113        tx.put(index, ByteVec::from(key), ByteVec::from(value))?;
114        tx.commit()?;
115
116        Ok(())
117    }
118}
119
120/// Cache a shader object (usually bytecode) created by the keyed objects.
121///
122/// - `factory` is the function that compiles the values passed as keys to a shader object.
123/// - `load` tries to load a compiled shader object to a driver-specialized result.
124pub fn cache_shader_object<E, T, R, H, const KEY_SIZE: usize>(
125    index: &str,
126    keys: &[H; KEY_SIZE],
127    factory: impl Fn(&[H; KEY_SIZE]) -> Result<T, E> + std::panic::RefUnwindSafe,
128    load: impl Fn(T) -> Result<R, E> + std::panic::RefUnwindSafe,
129    bypass_cache: bool,
130) -> Result<R, E>
131where
132    H: CacheKey + std::panic::RefUnwindSafe,
133    T: Cacheable,
134{
135    if bypass_cache {
136        return Ok(load(factory(keys)?)?);
137    }
138
139    catch_unwind(|| {
140        let cache = internal::get_cache();
141
142        let Ok(cache) = cache else {
143            return Ok(load(factory(keys)?)?);
144        };
145
146        let hashkey = {
147            let mut hasher = blake3::Hasher::new();
148            for subkeys in keys {
149                hasher.update(subkeys.hash_bytes());
150            }
151            let hash = hasher.finalize();
152            hash
153        };
154
155        'attempt: {
156            if let Ok(Some(blob)) = internal::get_blob(&cache, index, hashkey.as_bytes()) {
157                let cached = T::from_bytes(&blob).map(&load);
158
159                match cached {
160                    None => break 'attempt,
161                    Some(Err(_)) => break 'attempt,
162                    Some(Ok(res)) => return Ok(res),
163                }
164            }
165        };
166
167        let blob = factory(keys)?;
168
169        if let Some(slice) = T::to_bytes(&blob) {
170            let _ = internal::set_blob(&cache, index, hashkey.as_bytes(), &slice);
171        }
172        Ok(load(blob)?)
173    })
174    .unwrap_or_else(|_| {
175        internal::remove_cache();
176        Ok(load(factory(keys)?)?)
177    })
178}
179
180/// Cache a pipeline state object.
181///
182/// Keys are not used to create the object and are only used to uniquely identify the pipeline state.
183///
184/// - `restore_pipeline` tries to restore the pipeline with either a cached binary pipeline state
185///    cache, or create a new pipeline if no cached value is available.
186/// - `fetch_pipeline_state` fetches the new pipeline state cache after the pipeline was created.
187pub fn cache_pipeline<E, T, R, const KEY_SIZE: usize>(
188    index: &str,
189    keys: &[&dyn CacheKey; KEY_SIZE],
190    restore_pipeline: impl Fn(Option<Vec<u8>>) -> Result<R, E>,
191    fetch_pipeline_state: impl Fn(&R) -> Result<T, E>,
192    bypass_cache: bool,
193) -> Result<R, E>
194where
195    T: Cacheable,
196{
197    if bypass_cache {
198        return Ok(restore_pipeline(None)?);
199    }
200
201    catch_unwind(AssertUnwindSafe(|| {
202        let cache = internal::get_cache();
203
204        let Ok(cache) = cache else {
205            return Ok(restore_pipeline(None)?);
206        };
207
208        let hashkey = {
209            let mut hasher = blake3::Hasher::new();
210            for subkeys in keys {
211                hasher.update(subkeys.hash_bytes());
212            }
213            let hash = hasher.finalize();
214            hash
215        };
216
217        let pipeline = 'attempt: {
218            if let Ok(Some(blob)) = internal::get_blob(&cache, index, hashkey.as_bytes()) {
219                let cached = restore_pipeline(Some(blob));
220                match cached {
221                    Ok(res) => {
222                        break 'attempt res;
223                    }
224                    _ => (),
225                }
226            }
227
228            restore_pipeline(None)?
229        };
230
231        // update the pso every time just in case.
232        if let Ok(state) = fetch_pipeline_state(&pipeline) {
233            if let Some(slice) = T::to_bytes(&state) {
234                // We don't really care if the transaction fails, just try again next time.
235                let _ = internal::set_blob(&cache, index, hashkey.as_bytes(), &slice);
236            }
237        }
238
239        Ok(pipeline)
240    }))
241    .unwrap_or_else(|_| {
242        internal::remove_cache();
243        Ok(restore_pipeline(None)?)
244    })
245}