// rok_cache/cache.rs

1use std::{fmt, future::Future, sync::Arc, time::Duration};
2
3use serde::{de::DeserializeOwned, Serialize};
4
5use crate::{driver::CacheHandle, CacheError};
6
7tokio::task_local! {
8    pub(crate) static CURRENT_CACHE: Arc<CacheHandle>;
9}
10
11pub fn scope_cache<F: Future>(handle: Arc<CacheHandle>, f: F) -> impl Future<Output = F::Output> {
12    CURRENT_CACHE.scope(handle, f)
13}
14
/// Zero-sized facade over the cache handle bound to the current task.
///
/// The type carries no state of its own; it only namespaces the cache
/// operations, so it is trivially `Copy` and prints as `Cache` under `{:?}`.
#[derive(Debug, Clone, Copy)]
pub struct Cache;
16
17impl Cache {
18    fn handle() -> Result<Arc<CacheHandle>, CacheError> {
19        CURRENT_CACHE
20            .try_with(|h| h.clone())
21            .map_err(|_| CacheError::NotConfigured)
22    }
23
24    /// Retrieve a cached value.  Returns `None` on miss or expiry.
25    pub async fn get<T: DeserializeOwned>(key: &str) -> Result<Option<T>, CacheError> {
26        let h = Self::handle()?;
27        let full_key = h.key(key);
28        match h.driver.get(&full_key).await? {
29            None => Ok(None),
30            Some(json) => serde_json::from_str(&json)
31                .map(Some)
32                .map_err(|e| CacheError::Deserialize(e.to_string())),
33        }
34    }
35
36    /// Store a value.  `ttl = None` means no expiry.
37    pub async fn set<T: Serialize>(
38        key: &str,
39        value: &T,
40        ttl: Option<Duration>,
41    ) -> Result<(), CacheError> {
42        let h = Self::handle()?;
43        let full_key = h.key(key);
44        let json =
45            serde_json::to_string(value).map_err(|e| CacheError::Serialize(e.to_string()))?;
46        h.driver.set(&full_key, json, ttl).await
47    }
48
49    /// Remove a key.
50    pub async fn forget(key: &str) -> Result<(), CacheError> {
51        let h = Self::handle()?;
52        h.driver.forget(&h.key(key)).await
53    }
54
55    /// Remove all keys (scoped to the current database / memory store).
56    pub async fn flush() -> Result<(), CacheError> {
57        Self::handle()?.driver.flush().await
58    }
59
60    /// Return the cached value if present; otherwise call `f`, store the
61    /// result for `ttl`, and return it.
62    ///
63    /// The fallback closure may return any error type that implements
64    /// `Display` — it is converted to `CacheError::Fetch`.
65    pub async fn remember<T, F, Fut, E>(key: &str, ttl: Duration, f: F) -> Result<T, CacheError>
66    where
67        T: Serialize + DeserializeOwned + Send,
68        F: FnOnce() -> Fut,
69        Fut: Future<Output = Result<T, E>>,
70        E: fmt::Display,
71    {
72        if let Some(cached) = Self::get::<T>(key).await? {
73            return Ok(cached);
74        }
75        let value = f().await.map_err(|e| CacheError::Fetch(e.to_string()))?;
76        Self::set(key, &value, Some(ttl)).await?;
77        Ok(value)
78    }
79}