Skip to main content

anvil_core/
cache.rs

//! Cache subsystem: a trait-object-based [`CacheStore`] with null (no-op), Moka (in-memory), and Redis drivers.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use async_trait::async_trait;
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::Error;
10
11#[async_trait]
12pub trait CacheDriver: Send + Sync {
13    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error>;
14    async fn put(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<(), Error>;
15    async fn forget(&self, key: &str) -> Result<(), Error>;
16    async fn flush(&self) -> Result<(), Error>;
17}
18
19#[derive(Clone)]
20pub struct CacheStore {
21    driver: Arc<dyn CacheDriver>,
22}
23
24impl CacheStore {
25    pub fn new(driver: Arc<dyn CacheDriver>) -> Self {
26        Self { driver }
27    }
28
29    pub fn null() -> Self {
30        Self {
31            driver: Arc::new(NullDriver),
32        }
33    }
34
35    pub fn moka(capacity: u64) -> Self {
36        Self {
37            driver: Arc::new(MokaDriver::new(capacity)),
38        }
39    }
40
41    pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, Error> {
42        match self.driver.get(key).await? {
43            Some(bytes) => Ok(Some(serde_json::from_slice(&bytes)?)),
44            None => Ok(None),
45        }
46    }
47
48    pub async fn put<T: Serialize>(
49        &self,
50        key: &str,
51        value: &T,
52        ttl: Option<Duration>,
53    ) -> Result<(), Error> {
54        let bytes = serde_json::to_vec(value)?;
55        self.driver.put(key, bytes, ttl).await
56    }
57
58    pub async fn forget(&self, key: &str) -> Result<(), Error> {
59        self.driver.forget(key).await
60    }
61
62    pub async fn flush(&self) -> Result<(), Error> {
63        self.driver.flush().await
64    }
65
66    /// `remember` — get from cache, or compute, store, and return.
67    pub async fn remember<T, F, Fut>(
68        &self,
69        key: &str,
70        ttl: Duration,
71        loader: F,
72    ) -> Result<T, Error>
73    where
74        T: Serialize + DeserializeOwned + Send + Sync,
75        F: FnOnce() -> Fut + Send,
76        Fut: std::future::Future<Output = Result<T, Error>> + Send,
77    {
78        if let Some(hit) = self.get::<T>(key).await? {
79            return Ok(hit);
80        }
81        let value = loader().await?;
82        self.put(key, &value, Some(ttl)).await?;
83        Ok(value)
84    }
85}
86
87struct NullDriver;
88
89#[async_trait]
90impl CacheDriver for NullDriver {
91    async fn get(&self, _key: &str) -> Result<Option<Vec<u8>>, Error> {
92        Ok(None)
93    }
94    async fn put(&self, _: &str, _: Vec<u8>, _: Option<Duration>) -> Result<(), Error> {
95        Ok(())
96    }
97    async fn forget(&self, _: &str) -> Result<(), Error> {
98        Ok(())
99    }
100    async fn flush(&self) -> Result<(), Error> {
101        Ok(())
102    }
103}
104
105pub struct MokaDriver {
106    inner: moka::future::Cache<String, Vec<u8>>,
107}
108
109impl MokaDriver {
110    pub fn new(capacity: u64) -> Self {
111        Self {
112            inner: moka::future::Cache::builder()
113                .max_capacity(capacity)
114                .build(),
115        }
116    }
117}
118
119#[async_trait]
120impl CacheDriver for MokaDriver {
121    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
122        Ok(self.inner.get(key).await)
123    }
124
125    async fn put(&self, key: &str, value: Vec<u8>, _ttl: Option<Duration>) -> Result<(), Error> {
126        self.inner.insert(key.to_string(), value).await;
127        Ok(())
128    }
129
130    async fn forget(&self, key: &str) -> Result<(), Error> {
131        self.inner.invalidate(key).await;
132        Ok(())
133    }
134
135    async fn flush(&self) -> Result<(), Error> {
136        self.inner.invalidate_all();
137        Ok(())
138    }
139}
140
141pub struct RedisDriver {
142    pool: redis::aio::ConnectionManager,
143}
144
145impl RedisDriver {
146    pub async fn connect(url: &str) -> Result<Self, Error> {
147        let client = redis::Client::open(url).map_err(|e| Error::Cache(e.to_string()))?;
148        let pool = redis::aio::ConnectionManager::new(client)
149            .await
150            .map_err(|e| Error::Cache(e.to_string()))?;
151        Ok(Self { pool })
152    }
153}
154
155#[async_trait]
156impl CacheDriver for RedisDriver {
157    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
158        use redis::AsyncCommands;
159        let mut conn = self.pool.clone();
160        let val: Option<Vec<u8>> = conn
161            .get(key)
162            .await
163            .map_err(|e| Error::Cache(e.to_string()))?;
164        Ok(val)
165    }
166
167    async fn put(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<(), Error> {
168        use redis::AsyncCommands;
169        let mut conn = self.pool.clone();
170        if let Some(ttl) = ttl {
171            let _: () = conn
172                .set_ex(key, value, ttl.as_secs())
173                .await
174                .map_err(|e| Error::Cache(e.to_string()))?;
175        } else {
176            let _: () = conn
177                .set(key, value)
178                .await
179                .map_err(|e| Error::Cache(e.to_string()))?;
180        }
181        Ok(())
182    }
183
184    async fn forget(&self, key: &str) -> Result<(), Error> {
185        use redis::AsyncCommands;
186        let mut conn = self.pool.clone();
187        let _: () = conn
188            .del(key)
189            .await
190            .map_err(|e| Error::Cache(e.to_string()))?;
191        Ok(())
192    }
193
194    async fn flush(&self) -> Result<(), Error> {
195        let mut conn = self.pool.clone();
196        let _: () = redis::cmd("FLUSHDB")
197            .query_async(&mut conn)
198            .await
199            .map_err(|e| Error::Cache(e.to_string()))?;
200        Ok(())
201    }
202}