librashader_cache/
cache.rs1use crate::cacheable::Cacheable;
2use crate::key::CacheKey;
3use std::panic::{catch_unwind, AssertUnwindSafe};
4
5pub(crate) mod internal {
6 #[derive(Debug, Error)]
7 enum CatchPanicError {
8 #[error("a panic ocurred when loading the database")]
9 Panic(Box<dyn Any + Send + 'static>),
10 }
11
12 use platform_dirs::AppDirs;
13 use std::any::Any;
14 use std::error::Error;
15 use std::panic::catch_unwind;
16 use std::path::PathBuf;
17
18 use persy::{ByteVec, Config, Persy, ValueMode};
19 use thiserror::Error;
20
21 pub(crate) fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
22 let cache_dir = if let Some(cache_dir) =
23 AppDirs::new(Some("librashader"), false).map(|a| a.cache_dir)
24 {
25 cache_dir
26 } else {
27 let mut current_dir = std::env::current_dir()?;
28 current_dir.push("librashader");
29 current_dir
30 };
31
32 std::fs::create_dir_all(&cache_dir)?;
33
34 Ok(cache_dir)
35 }
36
37 pub(crate) fn remove_cache() {
57 let Ok(cache_dir) = get_cache_dir() else {
58 return;
59 };
60 let path = &cache_dir.join("librashader.db.1");
61 let _ = std::fs::remove_file(path).ok();
62 }
63
64 pub(crate) fn get_cache() -> Result<Persy, Box<dyn Error>> {
65 let cache_dir = get_cache_dir()?;
66 match catch_unwind(|| {
67 Persy::open_or_create_with(
68 &cache_dir.join("librashader.db.1"),
69 Config::new(),
70 |persy| {
71 let tx = persy.begin()?;
72 tx.commit()?;
73 Ok(())
74 },
75 )
76 }) {
77 Ok(Ok(conn)) => Ok(conn),
78 Ok(Err(e)) => {
79 remove_cache();
80 Err(e)?
81 }
82 Err(e) => {
83 remove_cache();
84 Err(CatchPanicError::Panic(e))?
85 }
86 }
87 }
88
89 pub(crate) fn get_blob(
90 conn: &Persy,
91 index: &str,
92 key: &[u8],
93 ) -> Result<Option<Vec<u8>>, Box<dyn Error>> {
94 if !conn.exists_index(index)? {
95 return Ok(None);
96 }
97
98 let value = conn.get::<_, ByteVec>(index, &ByteVec::from(key))?.next();
99 Ok(value.map(|v| v.to_vec()))
100 }
101
102 pub(crate) fn set_blob(
103 conn: &Persy,
104 index: &str,
105 key: &[u8],
106 value: &[u8],
107 ) -> Result<(), Box<dyn Error>> {
108 let mut tx = conn.begin()?;
109 if !tx.exists_index(index)? {
110 tx.create_index::<ByteVec, ByteVec>(index, ValueMode::Replace)?;
111 }
112
113 tx.put(index, ByteVec::from(key), ByteVec::from(value))?;
114 tx.commit()?;
115
116 Ok(())
117 }
118}
119
120pub fn cache_shader_object<E, T, R, H, const KEY_SIZE: usize>(
125 index: &str,
126 keys: &[H; KEY_SIZE],
127 factory: impl Fn(&[H; KEY_SIZE]) -> Result<T, E> + std::panic::RefUnwindSafe,
128 load: impl Fn(T) -> Result<R, E> + std::panic::RefUnwindSafe,
129 bypass_cache: bool,
130) -> Result<R, E>
131where
132 H: CacheKey + std::panic::RefUnwindSafe,
133 T: Cacheable,
134{
135 if bypass_cache {
136 return Ok(load(factory(keys)?)?);
137 }
138
139 catch_unwind(|| {
140 let cache = internal::get_cache();
141
142 let Ok(cache) = cache else {
143 return Ok(load(factory(keys)?)?);
144 };
145
146 let hashkey = {
147 let mut hasher = blake3::Hasher::new();
148 for subkeys in keys {
149 hasher.update(subkeys.hash_bytes());
150 }
151 let hash = hasher.finalize();
152 hash
153 };
154
155 'attempt: {
156 if let Ok(Some(blob)) = internal::get_blob(&cache, index, hashkey.as_bytes()) {
157 let cached = T::from_bytes(&blob).map(&load);
158
159 match cached {
160 None => break 'attempt,
161 Some(Err(_)) => break 'attempt,
162 Some(Ok(res)) => return Ok(res),
163 }
164 }
165 };
166
167 let blob = factory(keys)?;
168
169 if let Some(slice) = T::to_bytes(&blob) {
170 let _ = internal::set_blob(&cache, index, hashkey.as_bytes(), &slice);
171 }
172 Ok(load(blob)?)
173 })
174 .unwrap_or_else(|_| {
175 internal::remove_cache();
176 Ok(load(factory(keys)?)?)
177 })
178}
179
180pub fn cache_pipeline<E, T, R, const KEY_SIZE: usize>(
188 index: &str,
189 keys: &[&dyn CacheKey; KEY_SIZE],
190 restore_pipeline: impl Fn(Option<Vec<u8>>) -> Result<R, E>,
191 fetch_pipeline_state: impl Fn(&R) -> Result<T, E>,
192 bypass_cache: bool,
193) -> Result<R, E>
194where
195 T: Cacheable,
196{
197 if bypass_cache {
198 return Ok(restore_pipeline(None)?);
199 }
200
201 catch_unwind(AssertUnwindSafe(|| {
202 let cache = internal::get_cache();
203
204 let Ok(cache) = cache else {
205 return Ok(restore_pipeline(None)?);
206 };
207
208 let hashkey = {
209 let mut hasher = blake3::Hasher::new();
210 for subkeys in keys {
211 hasher.update(subkeys.hash_bytes());
212 }
213 let hash = hasher.finalize();
214 hash
215 };
216
217 let pipeline = 'attempt: {
218 if let Ok(Some(blob)) = internal::get_blob(&cache, index, hashkey.as_bytes()) {
219 let cached = restore_pipeline(Some(blob));
220 match cached {
221 Ok(res) => {
222 break 'attempt res;
223 }
224 _ => (),
225 }
226 }
227
228 restore_pipeline(None)?
229 };
230
231 if let Ok(state) = fetch_pipeline_state(&pipeline) {
233 if let Some(slice) = T::to_bytes(&state) {
234 let _ = internal::set_blob(&cache, index, hashkey.as_bytes(), &slice);
236 }
237 }
238
239 Ok(pipeline)
240 }))
241 .unwrap_or_else(|_| {
242 internal::remove_cache();
243 Ok(restore_pipeline(None)?)
244 })
245}