1use std::sync::Arc;
4use std::time::Duration;
5
6use async_trait::async_trait;
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::Error;
10
11#[async_trait]
12pub trait CacheDriver: Send + Sync {
13 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error>;
14 async fn put(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<(), Error>;
15 async fn forget(&self, key: &str) -> Result<(), Error>;
16 async fn flush(&self) -> Result<(), Error>;
17}
18
19#[derive(Clone)]
20pub struct CacheStore {
21 driver: Arc<dyn CacheDriver>,
22}
23
24impl CacheStore {
25 pub fn new(driver: Arc<dyn CacheDriver>) -> Self {
26 Self { driver }
27 }
28
29 pub fn null() -> Self {
30 Self {
31 driver: Arc::new(NullDriver),
32 }
33 }
34
35 pub fn moka(capacity: u64) -> Self {
36 Self {
37 driver: Arc::new(MokaDriver::new(capacity)),
38 }
39 }
40
41 pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, Error> {
42 match self.driver.get(key).await? {
43 Some(bytes) => Ok(Some(serde_json::from_slice(&bytes)?)),
44 None => Ok(None),
45 }
46 }
47
48 pub async fn put<T: Serialize>(
49 &self,
50 key: &str,
51 value: &T,
52 ttl: Option<Duration>,
53 ) -> Result<(), Error> {
54 let bytes = serde_json::to_vec(value)?;
55 self.driver.put(key, bytes, ttl).await
56 }
57
58 pub async fn forget(&self, key: &str) -> Result<(), Error> {
59 self.driver.forget(key).await
60 }
61
62 pub async fn flush(&self) -> Result<(), Error> {
63 self.driver.flush().await
64 }
65
66 pub async fn remember<T, F, Fut>(
68 &self,
69 key: &str,
70 ttl: Duration,
71 loader: F,
72 ) -> Result<T, Error>
73 where
74 T: Serialize + DeserializeOwned + Send + Sync,
75 F: FnOnce() -> Fut + Send,
76 Fut: std::future::Future<Output = Result<T, Error>> + Send,
77 {
78 if let Some(hit) = self.get::<T>(key).await? {
79 return Ok(hit);
80 }
81 let value = loader().await?;
82 self.put(key, &value, Some(ttl)).await?;
83 Ok(value)
84 }
85}
86
87struct NullDriver;
88
89#[async_trait]
90impl CacheDriver for NullDriver {
91 async fn get(&self, _key: &str) -> Result<Option<Vec<u8>>, Error> {
92 Ok(None)
93 }
94 async fn put(&self, _: &str, _: Vec<u8>, _: Option<Duration>) -> Result<(), Error> {
95 Ok(())
96 }
97 async fn forget(&self, _: &str) -> Result<(), Error> {
98 Ok(())
99 }
100 async fn flush(&self) -> Result<(), Error> {
101 Ok(())
102 }
103}
104
105pub struct MokaDriver {
106 inner: moka::future::Cache<String, Vec<u8>>,
107}
108
109impl MokaDriver {
110 pub fn new(capacity: u64) -> Self {
111 Self {
112 inner: moka::future::Cache::builder()
113 .max_capacity(capacity)
114 .build(),
115 }
116 }
117}
118
119#[async_trait]
120impl CacheDriver for MokaDriver {
121 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
122 Ok(self.inner.get(key).await)
123 }
124
125 async fn put(&self, key: &str, value: Vec<u8>, _ttl: Option<Duration>) -> Result<(), Error> {
126 self.inner.insert(key.to_string(), value).await;
127 Ok(())
128 }
129
130 async fn forget(&self, key: &str) -> Result<(), Error> {
131 self.inner.invalidate(key).await;
132 Ok(())
133 }
134
135 async fn flush(&self) -> Result<(), Error> {
136 self.inner.invalidate_all();
137 Ok(())
138 }
139}
140
141pub struct RedisDriver {
142 pool: redis::aio::ConnectionManager,
143}
144
145impl RedisDriver {
146 pub async fn connect(url: &str) -> Result<Self, Error> {
147 let client = redis::Client::open(url).map_err(|e| Error::Cache(e.to_string()))?;
148 let pool = redis::aio::ConnectionManager::new(client)
149 .await
150 .map_err(|e| Error::Cache(e.to_string()))?;
151 Ok(Self { pool })
152 }
153}
154
155#[async_trait]
156impl CacheDriver for RedisDriver {
157 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
158 use redis::AsyncCommands;
159 let mut conn = self.pool.clone();
160 let val: Option<Vec<u8>> = conn
161 .get(key)
162 .await
163 .map_err(|e| Error::Cache(e.to_string()))?;
164 Ok(val)
165 }
166
167 async fn put(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<(), Error> {
168 use redis::AsyncCommands;
169 let mut conn = self.pool.clone();
170 if let Some(ttl) = ttl {
171 let _: () = conn
172 .set_ex(key, value, ttl.as_secs())
173 .await
174 .map_err(|e| Error::Cache(e.to_string()))?;
175 } else {
176 let _: () = conn
177 .set(key, value)
178 .await
179 .map_err(|e| Error::Cache(e.to_string()))?;
180 }
181 Ok(())
182 }
183
184 async fn forget(&self, key: &str) -> Result<(), Error> {
185 use redis::AsyncCommands;
186 let mut conn = self.pool.clone();
187 let _: () = conn
188 .del(key)
189 .await
190 .map_err(|e| Error::Cache(e.to_string()))?;
191 Ok(())
192 }
193
194 async fn flush(&self) -> Result<(), Error> {
195 let mut conn = self.pool.clone();
196 let _: () = redis::cmd("FLUSHDB")
197 .query_async(&mut conn)
198 .await
199 .map_err(|e| Error::Cache(e.to_string()))?;
200 Ok(())
201 }
202}