1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
//! The [AssetCache] drives a [Vfs], producing an output type from an input type with 2 levels of caching.
//!
//! Typically, the [Vfs] is a file-backed implementation, and the output type is something like a decoded image or audio
//! buffer.
//!
//! The cache caches at two points:
//!
//! - First, at the level of reading bytes into an in-memory buffer.
//! - Second, the actual decoded objects themselves.
//!
//! Any asset which is so critical that it must never be unloaded may be pinned with [AssetCache::cache_always], at
//! which point it may only be removed with [AssetCache::remove].
use std::io::{Error as IoError, Read};
use std::sync::{Arc, Mutex, RwLock};

use crate::*;

type CacheHashMap<V> = std::collections::HashMap<String, V, ahash::RandomState>;

/// Configuration for a [AssetCache].
///
/// This type doesn't implement `Default`: applications should carefully consider their memory requirements and decide
/// on appropriate values.
///
/// A builder type, `AssetCacheConfigBuilder`, is generated via `derive_builder`.
#[derive(Debug, derive_builder::Builder)]
pub struct AssetCacheConfig {
    /// Maximum cost of the bytes cache in bytes.
    pub max_bytes_cost: u64,
    /// Maximum cost of the decoded cache in bytes.
    pub max_decoded_cost: u64,
    /// Maximum size of a single vec of bytes before we won't cache it.
    ///
    /// Use this to avoid caching huge objects.
    ///
    /// NOTE(review): this presumably should not exceed `max_bytes_cost` — the cache assumes an object under this
    /// limit can be inserted into (and immediately retrieved from) the bytes LRU; confirm against `CostBasedLru`'s
    /// eviction behavior.
    pub max_single_object_bytes_cost: u64,
    /// Point at which we will avoid caching individual decoded objects.
    ///
    /// For example maybe your audio file is 50mb when decoded, and you'd like to not keep it around.
    ///
    /// Note that even when we choose not to cache such objects, we still keep them around via weak references, so it's
    /// not always the case that the cache will refuse to give it back to you without decoding a second time.
    pub max_single_object_decoded_cost: u64,
}

/// The Asset cache itself.  See crate level documentation for details.
pub struct AssetCache<VfsImpl: Vfs, DecoderImpl: Decoder> {
    // Cost limits and per-object thresholds; fixed at construction.
    config: AssetCacheConfig,
    // Entries pinned via `cache_always`; never evicted, only removed explicitly via `remove`.
    pinned_entries: RwLock<CacheHashMap<Arc<DecoderImpl::Output>>>,
    // LRU of raw (undecoded) byte buffers, keyed by asset key.
    bytes_cache: Mutex<CostBasedLru<str, Vec<u8>>>,
    // LRU of decoded objects, keyed by asset key.
    decoded_cache: Mutex<CostBasedLru<str, DecoderImpl::Output>>,
    /// Mutexes that stop multiple threads trying to decode the same content.
    decoding_guards: Mutex<CacheHashMap<Arc<Mutex<()>>>>,
    /// After eviction, we can still give the item back if something external kept it around; do so unless the user explicitly deleted it.
    weak_refs: RwLock<CacheHashMap<std::sync::Weak<DecoderImpl::Output>>>,
    // Source of raw bytes for assets.
    vfs: VfsImpl,
    // Turns raw bytes (or a reader) into decoded outputs.
    decoder: DecoderImpl,
}

/// An error from attempting to decode via the asset cache.
///
/// `DecoderError` is the [Decoder]'s associated error type.
#[derive(Debug, thiserror::Error)]
pub enum AssetCacheError<DecoderError> {
    /// The error comes from the [Vfs].
    #[error("VFS error reading from cache")]
    Vfs(#[source] IoError),
    /// The error comes from the [Decoder].
    #[error("Decoder error reading from cache")]
    Decoder(#[source] DecoderError),
}

impl<VfsImpl: Vfs, DecoderImpl: Decoder> AssetCache<VfsImpl, DecoderImpl> {
    /// Build a cache over `vfs` and `decoder`, with cost limits taken from `config`.
    pub fn new(
        vfs: VfsImpl,
        decoder: DecoderImpl,
        config: AssetCacheConfig,
    ) -> AssetCache<VfsImpl, DecoderImpl> {
        AssetCache {
            decoder,
            vfs,
            bytes_cache: Mutex::new(CostBasedLru::new(config.max_bytes_cost)),
            decoded_cache: Mutex::new(CostBasedLru::new(config.max_decoded_cost)),
            decoding_guards: Default::default(),
            pinned_entries: RwLock::new(Default::default()),
            weak_refs: RwLock::new(Default::default()),
            config,
        }
    }

    /// Find an item in the cache, returning `None` if it isn't currently cached.
    ///
    /// Checks, in order: the pinned entries, then the decoded LRU, then finally the weak-reference map (which only
    /// succeeds while some externally-held `Arc` keeps the value alive).
    fn search_for_item(&self, key: &str) -> Option<Arc<DecoderImpl::Output>> {
        {
            let guard = self.pinned_entries.read().unwrap();
            if let Some(x) = guard.get(key) {
                return Some((*x).clone());
            }
        }

        {
            // Needs `&mut` (hence `lock`d Mutex): LRU `get` updates recency bookkeeping.
            let mut guard = self.decoded_cache.lock().unwrap();
            if let Some(x) = guard.get(key) {
                return Some(x);
            }
        }

        // The unlikely pessimistic case is that this item is in the weak references; let's try to get it out.
        self.weak_refs
            .read()
            .unwrap()
            .get(key)
            .and_then(|x| x.upgrade())
    }

    /// Decode an item for the cache, assuming we definitely know it isn't present and are holding the guard necessary
    /// to stop other threads from attempting to do so in parallel.
    ///
    /// This is hard to break up into smaller functions, unfortunately.
    fn find_or_decode_postchecked(
        &self,
        key: &str,
    ) -> Result<Arc<DecoderImpl::Output>, AssetCacheError<DecoderImpl::Error>> {
        // First, if we can find the item, return it immediately.  Another thread may have decoded it while we were
        // waiting to acquire the per-key guard.
        if let Some(x) = self.search_for_item(key) {
            return Ok(x);
        }

        // If the size of the item is no more than the single object limit, we cache a vec of bytes and decode from
        // it.  Otherwise, we feed the reader into the decoder directly.  (A failure to get the size is a hard error:
        // the `?` below propagates it rather than falling through to the streaming path.)

        let mut bytes_reader = self.vfs.open(key).map_err(AssetCacheError::Vfs)?;
        let size = bytes_reader.get_size().map_err(AssetCacheError::Vfs)?;
        let decoded = if size <= self.config.max_single_object_bytes_cost {
            let maybe_cached_bytes = self.bytes_cache.lock().unwrap().get(key);
            if let Some(x) = maybe_cached_bytes {
                // Raw bytes already cached: decode straight from the cached buffer.
                self.decoder
                    .decode_bytes(&x[..])
                    .map_err(AssetCacheError::Decoder)?
            } else {
                // Read to a vec, insert that vec, then read from the vec.
                let mut dest = vec![];
                bytes_reader
                    .read_to_end(&mut dest)
                    .map_err(AssetCacheError::Vfs)?;
                let will_use = {
                    let mut guard = self.bytes_cache.lock().unwrap();
                    guard.insert(key.to_string().into(), dest, size);
                    // NOTE(review): this expect assumes the LRU never immediately evicts the entry we just inserted.
                    // That presumably holds while max_single_object_bytes_cost <= max_bytes_cost — confirm against
                    // CostBasedLru's eviction policy.
                    guard.get(key).expect("We just inserted this")
                };
                self.decoder
                    .decode_bytes(&will_use[..])
                    .map_err(AssetCacheError::Decoder)?
            }
        } else {
            // The object was too big for the bytes cache; in this case, we feed the vfs reader directly to the
            // decoder without buffering.
            self.decoder
                .decode(bytes_reader)
                .map_err(AssetCacheError::Decoder)?
        };

        // Cache the decoded object unless it exceeds the per-object decoded limit; over-limit objects are only handed
        // back to the caller (the weak reference recorded below may still let us recover them later).
        let cost = self
            .decoder
            .estimate_cost(&decoded)
            .map_err(AssetCacheError::Decoder)?;
        let res = if cost <= self.config.max_single_object_decoded_cost {
            let mut guard = self.decoded_cache.lock().unwrap();
            guard.insert(key.to_string().into(), decoded, cost);
            // NOTE(review): same just-inserted-entry assumption as the bytes cache above.
            guard.get(key).expect("Just inserted")
        } else {
            Arc::new(decoded)
        };

        // Always record a weak reference, so an evicted or never-cached item can still be returned while some caller
        // keeps the Arc alive.
        let weak = Arc::downgrade(&res);
        self.weak_refs
            .write()
            .unwrap()
            .insert(key.to_string(), weak);
        Ok(res)
    }

    /// Find or decode an item from the cache.
    ///
    /// Fast path: return the item if any cache layer already has it.  Slow path: serialize with any other thread
    /// decoding the same key, then delegate to [Self::find_or_decode_postchecked].
    fn find_or_decode(
        &self,
        key: &str,
    ) -> Result<Arc<DecoderImpl::Output>, AssetCacheError<DecoderImpl::Error>> {
        if let Some(x) = self.search_for_item(key) {
            return Ok(x);
        }

        // Stop any other threads from trying to decode this item, and make them wait on this thread to finish.
        //
        // NOTE(review): entries in decoding_guards are only removed by `remove`, so this map grows with the number of
        // distinct keys ever decoded.
        let mutex = {
            let mut guard_inner = self.decoding_guards.lock().unwrap();
            let tmp = guard_inner
                .entry(key.to_string())
                .or_insert_with(|| Arc::new(Mutex::new(())));
            (*tmp).clone()
        };
        // The type here is important: it makes sure that we actually lock the mutex, by making this variable definitely
        // be a guard.  Any mistakes in the above rather complicated chain to set this up will be caught at compile
        // time.
        let _guard: std::sync::MutexGuard<()> = mutex.lock().unwrap();

        self.find_or_decode_postchecked(key)
    }

    /// Get an item from the cache, decoding if the item isn't present.
    pub fn get(
        &self,
        key: &str,
    ) -> Result<Arc<DecoderImpl::Output>, AssetCacheError<DecoderImpl::Error>> {
        self.find_or_decode(key)
    }

    /// Pin an item, so that it is always present in the cache until explicitly removed.
    ///
    /// Also records a weak reference so the normal lookup path can serve the pinned value.
    pub fn cache_always(&self, key: String, value: Arc<DecoderImpl::Output>) {
        let weak = Arc::downgrade(&value);
        self.pinned_entries
            .write()
            .unwrap()
            .insert(key.clone(), value);
        self.weak_refs.write().unwrap().insert(key, weak);
    }

    /// Remove an item from the cache.
    ///
    /// Purges the key from every layer, including the weak-reference map, so the cache will re-decode the asset even
    /// if some caller still holds an `Arc` to the old value.
    pub fn remove(&self, key: &str) {
        self.pinned_entries.write().unwrap().remove(key);
        self.bytes_cache.lock().unwrap().remove(key);
        self.decoding_guards.lock().unwrap().remove(key);
        self.decoded_cache.lock().unwrap().remove(key);
        self.weak_refs.write().unwrap().remove(key);
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// An in-memory [Vfs] for tests: a mutex-protected map from key to raw bytes.
    struct HashMapVfs(Mutex<HashMap<String, Vec<u8>>>);

    impl Vfs for HashMapVfs {
        type Reader = std::io::Cursor<Vec<u8>>;

        fn open(&self, key: &str) -> Result<Self::Reader, IoError> {
            let map = self.0.lock().unwrap();
            let bytes = map.get(key).cloned().ok_or_else(|| {
                IoError::new(std::io::ErrorKind::NotFound, "Entry not found".to_string())
            })?;
            Ok(std::io::Cursor::new(bytes))
        }
    }

    impl VfsReader for std::io::Cursor<Vec<u8>> {
        fn get_size(&self) -> Result<u64, IoError> {
            Ok(self.get_ref().len() as u64)
        }
    }

    // Helpers for constructing the vfs and seeding it with content.
    impl HashMapVfs {
        fn new() -> HashMapVfs {
            HashMapVfs(Mutex::new(Default::default()))
        }

        /// Insert test data, returning the previous value for the key if any.
        pub fn insert(&self, key: &str, value: Vec<u8>) -> Option<Vec<u8>> {
            self.0.lock().unwrap().insert(key.to_string(), value)
        }
    }

    /// A decoder that interprets the bytes as UTF-8 and yields a `String`.
    struct HashMapDecoder;

    impl Decoder for HashMapDecoder {
        type Error = IoError;
        type Output = String;

        fn decode<R: Read>(&self, mut reader: R) -> Result<String, IoError> {
            let mut decoded = String::new();
            reader.read_to_string(&mut decoded)?;
            Ok(decoded)
        }

        fn estimate_cost(&self, item: &String) -> Result<u64, IoError> {
            Ok(item.len() as u64)
        }
    }

    /// Build a cache with small, known limits so the tests can exercise every threshold.
    fn build_cache() -> (Arc<HashMapVfs>, AssetCache<Arc<HashMapVfs>, HashMapDecoder>) {
        let cfg = AssetCacheConfigBuilder::default()
            .max_bytes_cost(50)
            .max_single_object_bytes_cost(10)
            .max_decoded_cost(60)
            .max_single_object_decoded_cost(12)
            .build()
            .expect("Should build");
        let vfs = Arc::new(HashMapVfs::new());
        let cache = AssetCache::new(vfs.clone(), HashMapDecoder, cfg);
        (vfs, cache)
    }

    /// Exercise get/search/remove on a couple of small keys.
    #[test]
    fn basic_ops() {
        let (vfs, cache) = build_cache();
        vfs.insert("a", "abc".into());
        vfs.insert("b", "def".into());

        assert_eq!(cache.get("a").unwrap().as_str(), "abc");
        assert_eq!(cache.get("b").unwrap().as_str(), "def");

        // Both keys should now be resident somewhere in the cache.
        cache.search_for_item("a").expect("Should find the key");
        cache.search_for_item("b").expect("Should find the item");

        // Removing one key must not disturb the other.
        cache.remove("b");
        assert!(cache.search_for_item("b").is_none());
        cache.search_for_item("a").expect("Key should be found");
    }

    #[test]
    fn test_single_object_limits() {
        let (vfs, cache) = build_cache();

        const SMALL: &str = "small";
        const NO_CACHE_BYTES: &str = "no_cache_bytes";
        const MAX_BYTES: &str = "max_bytes";
        const MAX_DECODED: &str = "max_decoded";
        const NO_CACHE: &str = "no_cache";

        // Sizes chosen around the limits from build_cache: single-object bytes limit 10, decoded limit 12.
        vfs.insert(SMALL, "abc".into());
        // Exactly at the bytes limit: still cached.
        vfs.insert(MAX_BYTES, "abcdefghij".into());
        // One over the bytes limit: decodes fine, but the raw bytes aren't cached.
        vfs.insert(NO_CACHE_BYTES, "abcdefghijk".into());
        // Exactly at the decoded limit: the largest decoded object we'll cache.
        vfs.insert(MAX_DECODED, "abcdefghijkl".into());
        // Over the decoded limit: not cached at all.
        vfs.insert(NO_CACHE, "abcdefghijklm".into());

        // Pull everything through the cache once.
        for key in &[SMALL, MAX_BYTES, MAX_DECODED, NO_CACHE, NO_CACHE_BYTES] {
            cache.get(key).expect("Should decode fine");
        }

        // Everything except NO_CACHE should still be resident.
        assert_eq!(cache.search_for_item(SMALL).unwrap().as_str(), "abc");
        assert_eq!(
            cache.search_for_item(MAX_BYTES).unwrap().as_str(),
            "abcdefghij"
        );
        assert_eq!(
            cache.search_for_item(MAX_DECODED).unwrap().as_str(),
            "abcdefghijkl"
        );
        assert_eq!(
            cache.search_for_item(NO_CACHE_BYTES).unwrap().as_str(),
            "abcdefghijk"
        );
        assert!(cache.search_for_item(NO_CACHE).is_none());
    }

    /// If we cache objects which are otherwise too large for the cache, or if an object is purged, we can still get the
    /// objects via our internal cache of weak references.
    #[test]
    fn test_weak_recovery() {
        let (vfs, cache) = build_cache();

        // Insert 100 keys while holding onto every returned Arc; the small cost limits force early entries out of
        // the LRUs, leaving only weak references behind.
        let mut held = vec![];
        for i in 0..100 {
            let key = i.to_string();
            vfs.insert(&key, key.clone().into_bytes());
            held.push(cache.get(&key).unwrap());
        }

        // Key "1" should have been evicted from both LRU layers...
        assert!(cache.bytes_cache.lock().unwrap().get("1").is_none());
        assert!(cache.decoded_cache.lock().unwrap().get("1").is_none());
        // ...but the weak map still knows about it.
        assert!(cache.weak_refs.read().unwrap().get("1").is_some());

        // So a lookup recovers it through the weak reference.
        assert_eq!(cache.get("1").unwrap().as_str(), "1");

        // Once our strong references are gone, recovery is impossible.
        held.clear();
        assert!(cache.search_for_item("1").is_none());

        // An object too big for any cache layer is still recoverable while the caller's Arc is alive.
        vfs.insert("big", "abcdefghijklmnopqrstuvwxyz".into());
        let strong = cache.get("big");
        assert!(cache.bytes_cache.lock().unwrap().get("big").is_none());
        assert!(cache.decoded_cache.lock().unwrap().get("big").is_none());
        assert_eq!(
            cache.get("big").unwrap().as_str(),
            "abcdefghijklmnopqrstuvwxyz"
        );
        // But dropping the strong reference makes it go away.
        std::mem::drop(strong);
        assert!(cache.search_for_item("big").is_none());
    }
}
}