// futures_cache/lib.rs

1//! [<img alt="github" src="https://img.shields.io/badge/github-udoprog/futures--cache-8da0cb?style=for-the-badge&logo=github" height="20">](https://github.com/udoprog/futures-cache)
2//! [<img alt="crates.io" src="https://img.shields.io/crates/v/futures-cache.svg?style=for-the-badge&color=fc8d62&logo=rust" height="20">](https://crates.io/crates/futures-cache)
3//! [<img alt="docs.rs" src="https://img.shields.io/badge/docs.rs-futures--cache-66c2a5?style=for-the-badge&logoColor=white&logo=" height="20">](https://docs.rs/futures-cache)
4//!
5//! Futures-aware cache abstraction.
6//!
7//! Provides a cache for asynchronous operations that persist data on the
8//! filesystem using [sled]. The async cache works by accepting a future, but
9//! will cancel the accepted future in case the answer is already in the cache.
10//!
//! It requires unique cache keys that are [serde] serializable. To distinguish
//! across different sub-components of the cache, they can be namespaced using
//! [Cache::namespaced].
14//!
15//! [sled]: https://github.com/spacejam/sled
16//!
17//! <br>
18//!
19//! ## State
20//!
21//! The state of the library is:
22//! * API is limited to only `wrap`, which includes a timeout ([#1]).
//! * Requests are currently racing in the `wrap` method, so multiple unnecessary
//!   requests might occur when they should instead be queueing up ([#2]).
25//! * Entries only expire when the library is loaded ([#3]).
26//! * Only storage backend is sled ([#4]).
27//!
28//! [#1]: https://github.com/udoprog/futures-cache/issues/1
29//! [#2]: https://github.com/udoprog/futures-cache/issues/2
30//! [#3]: https://github.com/udoprog/futures-cache/issues/3
31//! [#4]: https://github.com/udoprog/futures-cache/issues/4
32//!
33//! <br>
34//!
35//! ## Usage
36//!
37//! This library requires the user to add the following dependencies to use:
38//!
39//! ```toml
40//! futures-cache = "0.10.3"
41//! serde = {version = "1.0", features = ["derive"]}
42//! ```
43//!
44//! <br>
45//!
46//! ## Examples
47//!
48//! Simple example showcasing fetching information on a github repository.
49//!
50//! > This is also available as an example you can run with:
51//! > ```sh
52//! > cargo run --example github -- --user udoprog --repo futures-cache
53//! > ```
54//!
55//! ```rust,no_run
56//! use futures_cache::{Cache, Duration};
57//! use serde::Serialize;
58//!
59//! type Error = Box<dyn std::error::Error>;
60//!
61//! #[derive(Debug, Serialize)]
62//! enum GithubKey<'a> {
63//!     Repo { user: &'a str, repo: &'a str },
64//! }
65//!
66//! async fn github_repo(user: &str, repo: &str) -> Result<String, Error> {
67//!     use reqwest::header;
68//!     use reqwest::{Client, Url};
69//!
70//!     let client = Client::new();
71//!
72//!     let url = Url::parse(&format!("https://api.github.com/repos/{}/{}", user, repo))?;
73//!
74//!     let req = client
75//!         .get(url)
76//!         .header(header::USER_AGENT, "Reqwest/0.10")
77//!         .build()?;
78//!
79//!     let body = client.execute(req).await?.text().await?;
80//!     Ok(body)
81//! }
82//!
83//! #[tokio::main]
84//! async fn main() -> Result<(), Error> {
85//!     let db = sled::open("cache")?;
86//!     let cache = Cache::load(db.open_tree("cache")?)?;
87//!
88//!     let user = "udoprog";
89//!     let repo = "futures-cache";
90//!
91//!     let text = cache
92//!         .wrap(
93//!             GithubKey::Repo {
94//!                 user: user,
95//!                 repo: repo,
96//!             },
97//!             Duration::seconds(60),
98//!             github_repo(user, repo),
99//!         )
100//!         .await?;
101//!
102//!     println!("{}", text);
103//!     Ok(())
104//! }
105//! ```
106//!
107//! [serde]: https://docs.rs/serde
108//! [Cache::namespaced]: https://docs.rs/futures-cache/0/futures_cache/struct.Cache.html#method.namespaced
109
110#![deny(missing_docs)]
111
112use chrono::{DateTime, Utc};
113use crossbeam::queue::SegQueue;
114use futures_channel::oneshot;
115use hashbrown::HashMap;
116use hex::ToHex as _;
117use parking_lot::RwLock;
118use serde::{Deserialize, Serialize};
119use serde_cbor as cbor;
120use serde_hashkey as hashkey;
121use serde_json as json;
122use std::error;
123use std::fmt;
124use std::future::Future;
125use std::sync::atomic::{AtomicUsize, Ordering};
126use std::sync::Arc;
127
128pub use chrono::Duration;
129pub use sled;
130
131/// Error type for the cache.
132#[derive(Debug)]
133pub enum Error {
134    /// An underlying CBOR error.
135    Cbor(cbor::error::Error),
136    /// An underlying HashKey error.
137    HashKey(hashkey::Error),
138    /// An underlying JSON error.
139    Json(json::error::Error),
140    /// An underlying Sled error.
141    Sled(sled::Error),
142    /// The underlying future failed (with an unspecified error).
143    Failed,
144}
145
146impl fmt::Display for Error {
147    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
148        match self {
149            Error::Cbor(e) => write!(fmt, "CBOR error: {}", e),
150            Error::HashKey(e) => write!(fmt, "HashKey error: {}", e),
151            Error::Json(e) => write!(fmt, "JSON error: {}", e),
152            Error::Sled(e) => write!(fmt, "Database error: {}", e),
153            Error::Failed => write!(fmt, "Operation failed"),
154        }
155    }
156}
157
158impl error::Error for Error {
159    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
160        match self {
161            Error::Cbor(e) => Some(e),
162            Error::HashKey(e) => Some(e),
163            Error::Json(e) => Some(e),
164            Error::Sled(e) => Some(e),
165            _ => None,
166        }
167    }
168}
169
170impl From<json::error::Error> for Error {
171    fn from(error: json::error::Error) -> Self {
172        Error::Json(error)
173    }
174}
175
176impl From<cbor::error::Error> for Error {
177    fn from(error: cbor::error::Error) -> Self {
178        Error::Cbor(error)
179    }
180}
181
182impl From<hashkey::Error> for Error {
183    fn from(error: hashkey::Error) -> Self {
184        Error::HashKey(error)
185    }
186}
187
188impl From<sled::Error> for Error {
189    fn from(error: sled::Error) -> Self {
190        Error::Sled(error)
191    }
192}
193
194/// Represents the state of an entry.
195pub enum State<T> {
196    /// Entry is fresh and can be used.
197    Fresh(StoredEntry<T>),
198    /// Entry exists, but is expired.
199    /// Cache is referenced so that the value can be removed if needed.
200    Expired(StoredEntry<T>),
201    /// No entry.
202    Missing,
203}
204
205impl<T> State<T> {
206    /// Get as an option, regardless if it's expired or not.
207    pub fn get(self) -> Option<T> {
208        match self {
209            State::Fresh(e) | State::Expired(e) => Some(e.value),
210            State::Missing => None,
211        }
212    }
213}
214
215/// Entry which have had its type erased into a JSON representation for convenience.
216///
217/// This is necessary in case you want to list all the entries in the database unless you want to deal with raw bytes.
218#[derive(Debug, Serialize, Deserialize)]
219pub struct JsonEntry {
220    /// The key of the entry.
221    pub key: serde_json::Value,
222    /// The stored entry.
223    #[serde(flatten)]
224    pub stored: StoredEntry<serde_json::Value>,
225}
226
227/// A complete stored entry with a type.
228#[derive(Debug, Serialize, Deserialize)]
229pub struct StoredEntry<T> {
230    expires_at: DateTime<Utc>,
231    value: T,
232}
233
234/// A reference to a complete stored entry with a type.
235///
236/// This is used for serialization to avoid taking ownership of the value to serialize.
237#[derive(Debug, Serialize)]
238pub struct StoredEntryRef<'a, T> {
239    expires_at: DateTime<Utc>,
240    value: &'a T,
241}
242
243impl<T> StoredEntry<T> {
244    /// Test if entry is expired.
245    fn is_expired(&self, now: DateTime<Utc>) -> bool {
246        self.expires_at < now
247    }
248}
249
250/// Used to only deserialize part of the stored entry.
251#[derive(Debug, Serialize, Deserialize)]
252struct PartialStoredEntry {
253    expires_at: DateTime<Utc>,
254}
255
256impl PartialStoredEntry {
257    /// Test if entry is expired.
258    fn is_expired(&self, now: DateTime<Utc>) -> bool {
259        self.expires_at < now
260    }
261
262    /// Convert into a stored entry.
263    fn into_stored_entry(self) -> StoredEntry<()> {
264        StoredEntry {
265            expires_at: self.expires_at,
266            value: (),
267        }
268    }
269}
270
271#[derive(Default)]
272struct Waker {
273    /// Number of things waiting for a response.
274    pending: AtomicUsize,
275    /// Channels to use for notifying dependents.
276    channels: SegQueue<oneshot::Sender<bool>>,
277}
278
279impl Waker {
280    /// Spin on performing cleanup, receiving channels to notify until we are in a stable state
281    /// where everything has been reset.
282    fn cleanup(&self, error: bool) {
283        let mut previous = self.pending.load(Ordering::Acquire);
284
285        loop {
286            while previous > 1 {
287                let mut received = 0usize;
288
289                while let Some(waker) = self.channels.pop() {
290                    received += 1;
291                    let _ = waker.send(error);
292                }
293
294                // Subtract the number of notifications sent here. Setting this inside the wrap
295                // function would deadlock on singlethreaded executors since they can't make
296                // progress at the same time as this procedure.
297                previous = self.pending.fetch_sub(received, Ordering::AcqRel);
298            }
299
300            previous =
301                match self
302                    .pending
303                    .compare_exchange(1, 0, Ordering::AcqRel, Ordering::Acquire)
304                {
305                    Ok(n) => n,
306                    Err(n) => n,
307                };
308
309            if previous == 1 {
310                break;
311            }
312        }
313    }
314}
315
316struct Inner {
317    /// The serialized namespace this cache belongs to.
318    ns: Option<hashkey::Key>,
319    /// Underlying storage.
320    db: sled::Tree,
321    /// Things to wake up.
322    /// TODO: clean up wakers that have been idle for a long time in future cleanup loop.
323    wakers: RwLock<HashMap<Vec<u8>, Arc<Waker>>>,
324}
325
326/// Primary cache abstraction.
327///
328/// Can be cheaply cloned and namespaced.
329#[derive(Clone)]
330pub struct Cache {
331    inner: Arc<Inner>,
332}
333
334impl Cache {
335    /// Load the cache from the database.
336    pub fn load(db: sled::Tree) -> Result<Cache, Error> {
337        let cache = Cache {
338            inner: Arc::new(Inner {
339                ns: None,
340                db,
341                wakers: Default::default(),
342            }),
343        };
344        cache.cleanup()?;
345        Ok(cache)
346    }
347
348    /// Delete the given key from the specified namespace.
349    pub fn delete_with_ns<N, K>(&self, ns: Option<&N>, key: &K) -> Result<(), Error>
350    where
351        N: Serialize,
352        K: Serialize,
353    {
354        let ns = match ns {
355            Some(ns) => Some(hashkey::to_key(ns)?.normalize()),
356            None => None,
357        };
358
359        let key = self.key_with_ns(ns.as_ref(), key)?;
360        self.inner.db.remove(key)?;
361        Ok(())
362    }
363
364    /// List all cache entries as JSON.
365    pub fn list_json(&self) -> Result<Vec<JsonEntry>, Error> {
366        let mut out = Vec::new();
367
368        for result in self.inner.db.range::<&[u8], _>(..) {
369            let (key, value) = result?;
370
371            let key: json::Value = match cbor::from_slice(&key) {
372                Ok(key) => key,
373                // key is malformed.
374                Err(_) => continue,
375            };
376
377            let stored = match cbor::from_slice(&value) {
378                Ok(storage) => storage,
379                // something weird stored in there.
380                Err(_) => continue,
381            };
382
383            out.push(JsonEntry { key, stored });
384        }
385
386        Ok(out)
387    }
388
389    /// Clean up stale entries.
390    ///
391    /// This could be called periodically if you want to reclaim space.
392    fn cleanup(&self) -> Result<(), Error> {
393        let now = Utc::now();
394
395        for result in self.inner.db.range::<&[u8], _>(..) {
396            let (key, value) = result?;
397
398            let entry: PartialStoredEntry = match cbor::from_slice(&value) {
399                Ok(entry) => entry,
400                Err(e) => {
401                    if log::log_enabled!(log::Level::Trace) {
402                        log::warn!(
403                            "{}: failed to load: {}: {}",
404                            KeyFormat(&key),
405                            e,
406                            KeyFormat(&value)
407                        );
408                    } else {
409                        log::warn!("{}: failed to load: {}", KeyFormat(&key), e);
410                    }
411
412                    // delete key since it's invalid.
413                    self.inner.db.remove(key)?;
414                    continue;
415                }
416            };
417
418            if entry.is_expired(now) {
419                self.inner.db.remove(key)?;
420            }
421        }
422
423        Ok(())
424    }
425
426    /// Create a namespaced cache.
427    ///
428    /// The namespace must be unique to avoid conflicts.
429    ///
430    /// Each call to this functions will return its own queue for resolving futures.
431    pub fn namespaced<N>(&self, ns: &N) -> Result<Self, Error>
432    where
433        N: Serialize,
434    {
435        Ok(Self {
436            inner: Arc::new(Inner {
437                ns: Some(hashkey::to_key(ns)?.normalize()),
438                db: self.inner.db.clone(),
439                wakers: Default::default(),
440            }),
441        })
442    }
443
444    /// Insert a value into the cache.
445    pub fn insert<K, T>(&self, key: K, age: Duration, value: &T) -> Result<(), Error>
446    where
447        K: Serialize,
448        T: Serialize,
449    {
450        let key = self.key(&key)?;
451        self.inner_insert(&key, age, value)
452    }
453
454    /// Insert a value into the cache.
455    #[inline(always)]
456    fn inner_insert<T>(&self, key: &Vec<u8>, age: Duration, value: &T) -> Result<(), Error>
457    where
458        T: Serialize,
459    {
460        let expires_at = Utc::now() + age;
461
462        let value = match cbor::to_vec(&StoredEntryRef { expires_at, value }) {
463            Ok(value) => value,
464            Err(e) => {
465                log::trace!("store:{} *errored*", KeyFormat(key));
466                return Err(e.into());
467            }
468        };
469
470        log::trace!("store:{}", KeyFormat(key));
471        self.inner.db.insert(key, value)?;
472        Ok(())
473    }
474
475    /// Test an entry from the cache.
476    pub fn test<K>(&self, key: K) -> Result<State<()>, Error>
477    where
478        K: Serialize,
479    {
480        let key = self.key(&key)?;
481        self.inner_test(&key)
482    }
483
484    /// Load an entry from the cache.
485    #[inline(always)]
486    fn inner_test(&self, key: &[u8]) -> Result<State<()>, Error> {
487        let value = match self.inner.db.get(key)? {
488            Some(value) => value,
489            None => {
490                log::trace!("test:{} -> null (missing)", KeyFormat(key));
491                return Ok(State::Missing);
492            }
493        };
494
495        let stored: PartialStoredEntry = match cbor::from_slice(&value) {
496            Ok(value) => value,
497            Err(e) => {
498                if log::log_enabled!(log::Level::Trace) {
499                    log::warn!(
500                        "{}: failed to deserialize: {}: {}",
501                        KeyFormat(key),
502                        e,
503                        KeyFormat(&value)
504                    );
505                } else {
506                    log::warn!("{}: failed to deserialize: {}", KeyFormat(key), e);
507                }
508
509                log::trace!("test:{} -> null (deserialize error)", KeyFormat(key));
510                return Ok(State::Missing);
511            }
512        };
513
514        if stored.is_expired(Utc::now()) {
515            log::trace!("test:{} -> null (expired)", KeyFormat(key));
516            return Ok(State::Expired(stored.into_stored_entry()));
517        }
518
519        log::trace!("test:{} -> *value*", KeyFormat(key));
520        Ok(State::Fresh(stored.into_stored_entry()))
521    }
522
523    /// Load an entry from the cache.
524    pub fn get<K, T>(&self, key: K) -> Result<State<T>, Error>
525    where
526        K: Serialize,
527        T: serde::de::DeserializeOwned,
528    {
529        let key = self.key(&key)?;
530        self.inner_get(&key)
531    }
532
533    /// Load an entry from the cache.
534    #[inline(always)]
535    fn inner_get<T>(&self, key: &[u8]) -> Result<State<T>, Error>
536    where
537        T: serde::de::DeserializeOwned,
538    {
539        let value = match self.inner.db.get(key)? {
540            Some(value) => value,
541            None => {
542                log::trace!("load:{} -> null (missing)", KeyFormat(key));
543                return Ok(State::Missing);
544            }
545        };
546
547        let stored: StoredEntry<T> = match cbor::from_slice(&value) {
548            Ok(value) => value,
549            Err(e) => {
550                if log::log_enabled!(log::Level::Trace) {
551                    log::warn!(
552                        "{}: failed to deserialize: {}: {}",
553                        KeyFormat(key),
554                        e,
555                        KeyFormat(&value)
556                    );
557                } else {
558                    log::warn!("{}: failed to deserialize: {}", KeyFormat(key), e);
559                }
560
561                log::trace!("load:{} -> null (deserialize error)", KeyFormat(key));
562                return Ok(State::Missing);
563            }
564        };
565
566        if stored.is_expired(Utc::now()) {
567            log::trace!("load:{} -> null (expired)", KeyFormat(key));
568            return Ok(State::Expired(stored));
569        }
570
571        log::trace!("load:{} -> *value*", KeyFormat(key));
572        Ok(State::Fresh(stored))
573    }
574
575    /// Get the waker associated with the given key.
576    fn waker(&self, key: &[u8]) -> Arc<Waker> {
577        let wakers = self.inner.wakers.read();
578
579        match wakers.get(key) {
580            Some(waker) => return waker.clone(),
581            None => drop(wakers),
582        }
583
584        self.inner
585            .wakers
586            .write()
587            .entry(key.to_vec())
588            .or_default()
589            .clone()
590    }
591
592    /// Wrap the result of the given future to load and store from cache.
593    pub async fn wrap<K, F, T, E>(&self, key: K, age: Duration, future: F) -> Result<T, E>
594    where
595        K: Serialize,
596        F: Future<Output = Result<T, E>>,
597        T: Serialize + serde::de::DeserializeOwned,
598        E: From<Error>,
599    {
600        let key = self.key(&key)?;
601
602        loop {
603            // There a slight race here. The answer might _just_ have been provided when we perform
604            // this check.
605            //
606            // If that happens, worst case we will end up re-computing the answer again.
607            if let State::Fresh(e) = self.inner_get(&key)? {
608                return Ok(e.value);
609            }
610
611            let waker = self.waker(&key);
612
613            // only pending == 0 will be driving the future for a response.
614            if waker.pending.fetch_add(1, Ordering::AcqRel) > 0 {
615                let (tx, rx) = oneshot::channel();
616                waker.channels.push(tx);
617
618                let result = rx.await;
619
620                // Ignore if sender is cancelled, just loop again.
621                match result {
622                    Ok(true) => return Err(E::from(Error::Failed)),
623                    Err(oneshot::Canceled) | Ok(false) => continue,
624                }
625            }
626
627            // Check key again in case we got really unlucky and had two call do an interleaving
628            // pass for the previous check:
629            //
630            // T1 just went passed the first inner_get test above.
631            // T2 just finished the Waker::cleanup procedure and reduces pending to 0.
632            // T1 notices that it is the first pending thread (pending == 0) and ends up here.
633            if let State::Fresh(e) = self.inner_get(&key)? {
634                waker.cleanup(false);
635                return Ok(e.value);
636            }
637
638            // Guard in case it is cancelled.
639            let result = Guard::new(|| waker.cleanup(false)).wrap(future).await;
640
641            // Compute the answer by polling the underlying future and store it in the cache,
642            // then acquire the wakers lock and dispatch to all pending futures.
643            match result {
644                Ok(output) => {
645                    self.inner_insert(&key, age, &output)?;
646                    waker.cleanup(false);
647                    return Ok(output);
648                }
649                Err(e) => {
650                    waker.cleanup(true);
651                    return Err(e);
652                }
653            }
654        }
655
656        /// Create a stack guard that will run unless it is forgotten.
657        struct Guard<F>
658        where
659            F: FnMut(),
660        {
661            f: F,
662        }
663
664        impl<F> Guard<F>
665        where
666            F: FnMut(),
667        {
668            /// Construct a new finalizer.
669            pub fn new(f: F) -> Self {
670                Self { f }
671            }
672
673            /// Wrap the given future with this cancellation guard.
674            pub async fn wrap<O>(self, future: O) -> O::Output
675            where
676                O: Future,
677            {
678                let result = future.await;
679                std::mem::forget(self);
680                result
681            }
682        }
683
684        impl<F> Drop for Guard<F>
685        where
686            F: FnMut(),
687        {
688            fn drop(&mut self) {
689                (self.f)();
690            }
691        }
692    }
693
694    /// Helper to serialize the key with the default namespace.
695    fn key<T>(&self, key: &T) -> Result<Vec<u8>, Error>
696    where
697        T: Serialize,
698    {
699        self.key_with_ns(self.inner.ns.as_ref(), key)
700    }
701
702    /// Helper to serialize the key with a specific namespace.
703    fn key_with_ns<T>(&self, ns: Option<&hashkey::Key>, key: &T) -> Result<Vec<u8>, Error>
704    where
705        T: Serialize,
706    {
707        let key = hashkey::to_key(key)?.normalize();
708        let key = Key(ns, key);
709        return Ok(cbor::to_vec(&key)?);
710
711        #[derive(Serialize)]
712        struct Key<'a>(Option<&'a hashkey::Key>, hashkey::Key);
713    }
714}
715
716/// Helper formatter to convert cbor bytes to JSON or hex.
717struct KeyFormat<'a>(&'a [u8]);
718
719impl fmt::Display for KeyFormat<'_> {
720    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
721        let value = match cbor::from_slice::<cbor::Value>(self.0) {
722            Ok(value) => value,
723            Err(_) => return self.0.encode_hex::<String>().fmt(fmt),
724        };
725
726        let value = match json::to_string(&value) {
727            Ok(value) => value,
728            Err(_) => return self.0.encode_hex::<String>().fmt(fmt),
729        };
730
731        value.fmt(fmt)
732    }
733}
734
#[cfg(test)]
mod tests {
    use super::{Cache, Duration, Error};
    use std::{error, sync::Arc, thread};
    use tempdir::TempDir;

    /// Open the tree `test` in a fresh sled database inside a temporary directory.
    ///
    /// The returned `TempDir` must be kept alive for as long as the tree is in use:
    /// dropping it deletes the directory (and the database files) from disk.
    fn db(name: &str) -> Result<(TempDir, sled::Tree), Box<dyn error::Error>> {
        let dir = TempDir::new(name)?;
        let db = sled::open(dir.path())?;
        let tree = db.open_tree("test")?;
        Ok((dir, tree))
    }

    #[test]
    fn test_cached() -> Result<(), Box<dyn error::Error>> {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let (_dir, db) = db("test_cached")?;
        let cache = Cache::load(db)?;

        let count = Arc::new(AtomicUsize::default());

        let c = count.clone();

        let op1 = cache.wrap("a", Duration::hours(12), async move {
            let _ = c.fetch_add(1, Ordering::SeqCst);
            Ok::<_, Error>(String::from("foo"))
        });

        let c = count.clone();

        let op2 = cache.wrap("a", Duration::hours(12), async move {
            let _ = c.fetch_add(1, Ordering::SeqCst);
            Ok::<_, Error>(String::from("foo"))
        });

        ::futures::executor::block_on(async move {
            let (a, b) = ::futures::future::join(op1, op2).await;
            assert_eq!("foo", a.expect("ok result"));
            assert_eq!("foo", b.expect("ok result"));
            assert_eq!(1, count.load(Ordering::SeqCst));
        });

        Ok(())
    }

    #[test]
    fn test_contended() -> Result<(), Box<dyn error::Error>> {
        use crossbeam::queue::SegQueue;
        use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

        const THREAD_COUNT: usize = 1_000;

        let (_dir, db) = db("test_contended")?;
        let cache = Cache::load(db)?;

        let started = Arc::new(AtomicBool::new(false));
        let count = Arc::new(AtomicUsize::default());
        let results = Arc::new(SegQueue::new());
        let mut threads = Vec::with_capacity(THREAD_COUNT);

        for _ in 0..THREAD_COUNT {
            let started = started.clone();
            let cache = cache.clone();
            let results = results.clone();
            let count = count.clone();

            let t = thread::spawn(move || {
                let op = cache.wrap("a", Duration::hours(12), async move {
                    let _ = count.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, Error>(String::from("foo"))
                });

                // Yield while waiting so that many spinning threads don't starve the
                // thread that flips `started`.
                while !started.load(Ordering::Acquire) {
                    thread::yield_now();
                }

                ::futures::executor::block_on(async move {
                    results.push(op.await);
                });
            });

            threads.push(t);
        }

        started.store(true, Ordering::Release);

        for t in threads {
            t.join().expect("thread to join");
        }

        assert_eq!(1, count.load(Ordering::SeqCst));

        // Every caller, driver or waiter, must have observed the computed value.
        let mut completed = 0;

        while let Some(result) = results.pop() {
            assert_eq!("foo", result.expect("ok result"));
            completed += 1;
        }

        assert_eq!(THREAD_COUNT, completed);
        Ok(())
    }

    #[test]
    fn test_guards() -> Result<(), Box<dyn error::Error>> {
        use self::futures::PollOnce;
        use ::futures::channel::oneshot;
        use std::sync::atomic::Ordering;

        let (_dir, db) = db("test_guards")?;
        let cache = Cache::load(db)?;

        ::futures::executor::block_on(async move {
            let (op1_tx, op1_rx) = oneshot::channel::<()>();

            let op1 = cache.wrap("a", Duration::hours(12), async move {
                let _ = op1_rx.await;
                Ok::<_, Error>(String::from("foo"))
            });

            pin_utils::pin_mut!(op1);

            let (op2_tx, op2_rx) = oneshot::channel::<()>();

            let op2 = cache.wrap("a", Duration::hours(12), async move {
                let _ = op2_rx.await;
                Ok::<_, Error>(String::from("foo"))
            });

            pin_utils::pin_mut!(op2);

            assert!(PollOnce::new(&mut op1).await.is_none());

            let k = cache.key(&"a")?;
            let waker = cache.inner.wakers.read().get(&k).cloned();
            assert!(waker.is_some());
            let waker = waker.expect("waker to be registered");

            assert_eq!(1, waker.pending.load(Ordering::SeqCst));
            assert!(PollOnce::new(&mut op2).await.is_none());
            assert_eq!(2, waker.pending.load(Ordering::SeqCst));

            op1_tx.send(()).expect("send to op1");
            op2_tx.send(()).expect("send to op2");

            assert!(PollOnce::new(&mut op1).await.is_some());
            assert_eq!(0, waker.pending.load(Ordering::SeqCst));
            assert!(PollOnce::new(&mut op2).await.is_some());

            Ok(())
        })
    }

    mod futures {
        use std::{
            future::Future,
            pin::Pin,
            task::{Context, Poll},
        };

        /// Future adapter that polls the wrapped future exactly once.
        ///
        /// Resolves to `Some(output)` if the inner future completed on that poll,
        /// and to `None` if it was still pending.
        pub struct PollOnce<F> {
            future: F,
        }

        impl<F> PollOnce<F> {
            /// Wrap a new future to be polled once.
            pub fn new(future: F) -> Self {
                Self { future }
            }
        }

        // Requiring `Unpin` keeps this safe: callers pass `&mut Pin<&mut _>`,
        // which is always `Unpin`, so no unsafe pin projection is needed.
        impl<F> Future for PollOnce<F>
        where
            F: Future + Unpin,
        {
            type Output = Option<F::Output>;

            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                match Pin::new(&mut self.future).poll(cx) {
                    Poll::Ready(output) => Poll::Ready(Some(output)),
                    Poll::Pending => Poll::Ready(None),
                }
            }
        }
    }
}