1use std::sync::Arc;
4use std::time::Duration;
5
6use async_trait::async_trait;
7use serde::{de::DeserializeOwned, Serialize};
8
9use crate::Error;
10
11#[async_trait]
12pub trait CacheDriver: Send + Sync {
13 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error>;
14 async fn put(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<(), Error>;
15 async fn forget(&self, key: &str) -> Result<(), Error>;
16 async fn flush(&self) -> Result<(), Error>;
17}
18
19#[derive(Clone)]
20pub struct CacheStore {
21 driver: Arc<dyn CacheDriver>,
22}
23
24impl CacheStore {
25 pub fn new(driver: Arc<dyn CacheDriver>) -> Self {
26 Self { driver }
27 }
28
29 pub fn null() -> Self {
30 Self {
31 driver: Arc::new(NullDriver),
32 }
33 }
34
35 pub fn moka(capacity: u64) -> Self {
36 Self {
37 driver: Arc::new(MokaDriver::new(capacity)),
38 }
39 }
40
41 pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, Error> {
42 match self.driver.get(key).await? {
43 Some(bytes) => Ok(Some(serde_json::from_slice(&bytes)?)),
44 None => Ok(None),
45 }
46 }
47
48 pub async fn put<T: Serialize>(
49 &self,
50 key: &str,
51 value: &T,
52 ttl: Option<Duration>,
53 ) -> Result<(), Error> {
54 let bytes = serde_json::to_vec(value)?;
55 self.driver.put(key, bytes, ttl).await
56 }
57
58 pub async fn forget(&self, key: &str) -> Result<(), Error> {
59 self.driver.forget(key).await
60 }
61
62 pub async fn flush(&self) -> Result<(), Error> {
63 self.driver.flush().await
64 }
65
66 pub async fn remember<T, F, Fut>(&self, key: &str, ttl: Duration, loader: F) -> Result<T, Error>
68 where
69 T: Serialize + DeserializeOwned + Send + Sync,
70 F: FnOnce() -> Fut + Send,
71 Fut: std::future::Future<Output = Result<T, Error>> + Send,
72 {
73 if let Some(hit) = self.get::<T>(key).await? {
74 return Ok(hit);
75 }
76 let value = loader().await?;
77 self.put(key, &value, Some(ttl)).await?;
78 Ok(value)
79 }
80}
81
82struct NullDriver;
83
84#[async_trait]
85impl CacheDriver for NullDriver {
86 async fn get(&self, _key: &str) -> Result<Option<Vec<u8>>, Error> {
87 Ok(None)
88 }
89 async fn put(&self, _: &str, _: Vec<u8>, _: Option<Duration>) -> Result<(), Error> {
90 Ok(())
91 }
92 async fn forget(&self, _: &str) -> Result<(), Error> {
93 Ok(())
94 }
95 async fn flush(&self) -> Result<(), Error> {
96 Ok(())
97 }
98}
99
100pub struct MokaDriver {
101 inner: moka::future::Cache<String, Vec<u8>>,
102}
103
104impl MokaDriver {
105 pub fn new(capacity: u64) -> Self {
106 Self {
107 inner: moka::future::Cache::builder()
108 .max_capacity(capacity)
109 .build(),
110 }
111 }
112}
113
114#[async_trait]
115impl CacheDriver for MokaDriver {
116 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
117 Ok(self.inner.get(key).await)
118 }
119
120 async fn put(&self, key: &str, value: Vec<u8>, _ttl: Option<Duration>) -> Result<(), Error> {
121 self.inner.insert(key.to_string(), value).await;
122 Ok(())
123 }
124
125 async fn forget(&self, key: &str) -> Result<(), Error> {
126 self.inner.invalidate(key).await;
127 Ok(())
128 }
129
130 async fn flush(&self) -> Result<(), Error> {
131 self.inner.invalidate_all();
132 Ok(())
133 }
134}
135
136pub struct RedisDriver {
137 pool: redis::aio::ConnectionManager,
138}
139
140impl RedisDriver {
141 pub async fn connect(url: &str) -> Result<Self, Error> {
142 let client = redis::Client::open(url).map_err(|e| Error::Cache(e.to_string()))?;
143 let pool = redis::aio::ConnectionManager::new(client)
144 .await
145 .map_err(|e| Error::Cache(e.to_string()))?;
146 Ok(Self { pool })
147 }
148}
149
150#[async_trait]
151impl CacheDriver for RedisDriver {
152 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
153 use redis::AsyncCommands;
154 let mut conn = self.pool.clone();
155 let val: Option<Vec<u8>> = conn
156 .get(key)
157 .await
158 .map_err(|e| Error::Cache(e.to_string()))?;
159 Ok(val)
160 }
161
162 async fn put(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<(), Error> {
163 use redis::AsyncCommands;
164 let mut conn = self.pool.clone();
165 if let Some(ttl) = ttl {
166 let _: () = conn
167 .set_ex(key, value, ttl.as_secs())
168 .await
169 .map_err(|e| Error::Cache(e.to_string()))?;
170 } else {
171 let _: () = conn
172 .set(key, value)
173 .await
174 .map_err(|e| Error::Cache(e.to_string()))?;
175 }
176 Ok(())
177 }
178
179 async fn forget(&self, key: &str) -> Result<(), Error> {
180 use redis::AsyncCommands;
181 let mut conn = self.pool.clone();
182 let _: () = conn
183 .del(key)
184 .await
185 .map_err(|e| Error::Cache(e.to_string()))?;
186 Ok(())
187 }
188
189 async fn flush(&self) -> Result<(), Error> {
190 let mut conn = self.pool.clone();
191 let _: () = redis::cmd("FLUSHDB")
192 .query_async(&mut conn)
193 .await
194 .map_err(|e| Error::Cache(e.to_string()))?;
195 Ok(())
196 }
197}