Skip to main content

anvil_core/
cache.rs

1//! Cache subsystem. Trait-object based, with Moka (in-memory) and Redis drivers.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use async_trait::async_trait;
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::Error;
10
11#[async_trait]
12pub trait CacheDriver: Send + Sync {
13    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error>;
14    async fn put(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<(), Error>;
15    async fn forget(&self, key: &str) -> Result<(), Error>;
16    async fn flush(&self) -> Result<(), Error>;
17}
18
19#[derive(Clone)]
20pub struct CacheStore {
21    driver: Arc<dyn CacheDriver>,
22}
23
24impl CacheStore {
25    pub fn new(driver: Arc<dyn CacheDriver>) -> Self {
26        Self { driver }
27    }
28
29    pub fn null() -> Self {
30        Self {
31            driver: Arc::new(NullDriver),
32        }
33    }
34
35    pub fn moka(capacity: u64) -> Self {
36        Self {
37            driver: Arc::new(MokaDriver::new(capacity)),
38        }
39    }
40
41    pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, Error> {
42        match self.driver.get(key).await? {
43            Some(bytes) => Ok(Some(serde_json::from_slice(&bytes)?)),
44            None => Ok(None),
45        }
46    }
47
48    pub async fn put<T: Serialize>(
49        &self,
50        key: &str,
51        value: &T,
52        ttl: Option<Duration>,
53    ) -> Result<(), Error> {
54        let bytes = serde_json::to_vec(value)?;
55        self.driver.put(key, bytes, ttl).await
56    }
57
58    pub async fn forget(&self, key: &str) -> Result<(), Error> {
59        self.driver.forget(key).await
60    }
61
62    pub async fn flush(&self) -> Result<(), Error> {
63        self.driver.flush().await
64    }
65
66    /// `remember` — get from cache, or compute, store, and return.
67    pub async fn remember<T, F, Fut>(&self, key: &str, ttl: Duration, loader: F) -> Result<T, Error>
68    where
69        T: Serialize + DeserializeOwned + Send + Sync,
70        F: FnOnce() -> Fut + Send,
71        Fut: std::future::Future<Output = Result<T, Error>> + Send,
72    {
73        if let Some(hit) = self.get::<T>(key).await? {
74            return Ok(hit);
75        }
76        let value = loader().await?;
77        self.put(key, &value, Some(ttl)).await?;
78        Ok(value)
79    }
80}
81
82struct NullDriver;
83
84#[async_trait]
85impl CacheDriver for NullDriver {
86    async fn get(&self, _key: &str) -> Result<Option<Vec<u8>>, Error> {
87        Ok(None)
88    }
89    async fn put(&self, _: &str, _: Vec<u8>, _: Option<Duration>) -> Result<(), Error> {
90        Ok(())
91    }
92    async fn forget(&self, _: &str) -> Result<(), Error> {
93        Ok(())
94    }
95    async fn flush(&self) -> Result<(), Error> {
96        Ok(())
97    }
98}
99
100pub struct MokaDriver {
101    inner: moka::future::Cache<String, Vec<u8>>,
102}
103
104impl MokaDriver {
105    pub fn new(capacity: u64) -> Self {
106        Self {
107            inner: moka::future::Cache::builder()
108                .max_capacity(capacity)
109                .build(),
110        }
111    }
112}
113
114#[async_trait]
115impl CacheDriver for MokaDriver {
116    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
117        Ok(self.inner.get(key).await)
118    }
119
120    async fn put(&self, key: &str, value: Vec<u8>, _ttl: Option<Duration>) -> Result<(), Error> {
121        self.inner.insert(key.to_string(), value).await;
122        Ok(())
123    }
124
125    async fn forget(&self, key: &str) -> Result<(), Error> {
126        self.inner.invalidate(key).await;
127        Ok(())
128    }
129
130    async fn flush(&self) -> Result<(), Error> {
131        self.inner.invalidate_all();
132        Ok(())
133    }
134}
135
136pub struct RedisDriver {
137    pool: redis::aio::ConnectionManager,
138}
139
140impl RedisDriver {
141    pub async fn connect(url: &str) -> Result<Self, Error> {
142        let client = redis::Client::open(url).map_err(|e| Error::Cache(e.to_string()))?;
143        let pool = redis::aio::ConnectionManager::new(client)
144            .await
145            .map_err(|e| Error::Cache(e.to_string()))?;
146        Ok(Self { pool })
147    }
148}
149
150#[async_trait]
151impl CacheDriver for RedisDriver {
152    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
153        use redis::AsyncCommands;
154        let mut conn = self.pool.clone();
155        let val: Option<Vec<u8>> = conn
156            .get(key)
157            .await
158            .map_err(|e| Error::Cache(e.to_string()))?;
159        Ok(val)
160    }
161
162    async fn put(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<(), Error> {
163        use redis::AsyncCommands;
164        let mut conn = self.pool.clone();
165        if let Some(ttl) = ttl {
166            let _: () = conn
167                .set_ex(key, value, ttl.as_secs())
168                .await
169                .map_err(|e| Error::Cache(e.to_string()))?;
170        } else {
171            let _: () = conn
172                .set(key, value)
173                .await
174                .map_err(|e| Error::Cache(e.to_string()))?;
175        }
176        Ok(())
177    }
178
179    async fn forget(&self, key: &str) -> Result<(), Error> {
180        use redis::AsyncCommands;
181        let mut conn = self.pool.clone();
182        let _: () = conn
183            .del(key)
184            .await
185            .map_err(|e| Error::Cache(e.to_string()))?;
186        Ok(())
187    }
188
189    async fn flush(&self) -> Result<(), Error> {
190        let mut conn = self.pool.clone();
191        let _: () = redis::cmd("FLUSHDB")
192            .query_async(&mut conn)
193            .await
194            .map_err(|e| Error::Cache(e.to_string()))?;
195        Ok(())
196    }
197}