1use async_trait::async_trait;
2use bb8_redis::bb8::RunError;
3use redis_rs::{AsyncCommands, ErrorKind, FromRedisValue, RedisError, ToRedisArgs};
4use serde::{de::DeserializeOwned, Serialize};
5use serde_json::json;
6use std::{
7 future::Future,
8 str::from_utf8,
9 time::{SystemTime, UNIX_EPOCH},
10};
11use thiserror::Error;
12
13use tracing::error;
15
16pub use connection::Connection;
17pub use connection::Pool;
18pub use connection::RedisConfig;
19
20#[async_trait]
23pub trait GetOrFetchExt: AsyncCommands {
24 async fn get_or_fetch<K, V, F, Fut>(
25 &mut self,
26 key: K,
27 data_loader: F,
28 expire_seconds: usize,
29 ) -> Result<V>
30 where
31 K: ToRedisArgs + Send + Sync,
32 V: FromRedisValue + ToRedisArgs + Send + Sync,
33 F: FnOnce() -> Fut + Send,
34 Fut: Future<Output = anyhow::Result<V>> + Send;
35}
36
37#[async_trait]
38impl GetOrFetchExt for redis_cluster_async::Connection {
39 async fn get_or_fetch<K, V, F, Fut>(
40 &mut self,
41 key: K,
42 data_loader: F,
43 expire_seconds: usize,
44 ) -> Result<V>
45 where
46 K: ToRedisArgs + Send + Sync,
47 V: FromRedisValue + ToRedisArgs + Send + Sync,
48 F: FnOnce() -> Fut + Send,
49 Fut: Future<Output = anyhow::Result<V>> + Send,
50 {
51 if cfg!(test) {
52 return Ok(data_loader().await?);
53 }
54
55 match self.get(&key).await {
56 Ok(Some(bytes)) => Ok(bytes),
57 Ok(None) => {
58 let result = data_loader().await?;
59 self.set_ex(&key, &result, expire_seconds).await?;
60 Ok(result)
61 }
62 Err(err) => {
63 error!("redis error: {:?}", err);
64 Ok(data_loader().await?)
65 }
66 }
67 }
68}
69
70#[async_trait]
71pub trait GetOrRefreshExt {
72 async fn get_or_refresh<'a, V, F, Fut>(
73 mut self,
74 key: &str, data_loader: F,
76 expire_seconds: usize,
77 ) -> Result<V>
78 where
79 V: FromRedisValue + ToRedisArgs + Send + Sync + 'static,
80 F: FnOnce() -> Fut + Send + 'static,
81 Fut: Future<Output = anyhow::Result<V>> + Send;
82}
83
84#[async_trait]
85impl GetOrRefreshExt for connection::Connection {
86 async fn get_or_refresh<'a, V, F, Fut>(
87 mut self,
88 key: &str,
89 data_loader: F,
90 expire_seconds: usize,
91 ) -> Result<V>
92 where
93 V: FromRedisValue + ToRedisArgs + Send + Sync + 'static,
94 F: FnOnce() -> Fut + Send + 'static,
95 Fut: Future<Output = anyhow::Result<V>> + Send,
96 {
97 if cfg!(test) {
98 return Ok(data_loader().await?);
99 }
100
101 let now = SystemTime::now()
102 .duration_since(UNIX_EPOCH)
103 .expect("Time went backwards")
104 .as_secs();
105 let is_expired = |expired_when: u64| now > expired_when;
106
107 let owned_key = key.to_owned();
108 macro_rules! awaiting_get_and_set {
109 () => {{
110 let new_expired_when = now + expire_seconds as u64;
111
112 let new_value = data_loader().await?;
113
114 let _: () = self
115 .hset(&owned_key, "expired_when", new_expired_when)
116 .await?;
117 let _: () = self.hset(&owned_key, "value", &new_value).await?;
118
119 let result: Result<V> = Ok(new_value);
120
121 result
122 }};
123 }
124
125 let expired_when: Result<Option<u64>> = Ok(self.hget(key, "expired_when").await?);
126 let value: Result<Option<V>> = Ok(self.hget(key, "value").await?);
127
128 match (expired_when, value) {
129 (Ok(Some(expired_when)), Ok(Some(value))) if !is_expired(expired_when) => Ok(value),
130 (Ok(Some(_)), Ok(Some(value))) => {
131 tokio::spawn(async move {
132 if let Err(e) = async { awaiting_get_and_set!() }.await {
133 error!("Failed to load and set in background: {}", e);
134 }
135 });
136
137 Ok(value)
138 }
139 (Ok(None), _) | (_, Ok(None)) => {
140 awaiting_get_and_set!()
141 }
142 (Err(err), _) | (_, Err(err)) => {
143 error!("redis error: {:?}", err);
144
145 awaiting_get_and_set!()
146 }
147 }
148 }
149}
150
151#[derive(Error, Debug)]
152pub enum Error {
153 #[error("data error")]
154 Data(#[from] anyhow::Error),
155 #[error("redis error")]
156 Redis(#[from] RedisError),
157 #[error("cluster connection error")]
158 Cluster(#[from] RunError<RedisError>),
159}
160
161pub type Result<T> = std::result::Result<T, Error>;
162
/// Builds a `RedisError` of kind `ErrorKind::TypeError`.
///
/// `$value` is the offending response and `$detail` describes what went wrong;
/// both are rendered with `{:?}` into the error's detail string.
macro_rules! invalid_type_error {
    ($value:expr, $detail:expr) => {{
        let description = format!("{:?} (response was {:?})", $detail, $value);
        RedisError::from((
            ErrorKind::TypeError,
            "Response was of incompatible type",
            description,
        ))
    }};
}
172
173pub struct VecRedisValue<T: Serialize + DeserializeOwned>(pub Vec<T>);
174
175impl<T: Serialize + DeserializeOwned> std::ops::Deref for VecRedisValue<T> {
176 type Target = Vec<T>;
177
178 fn deref(&self) -> &Self::Target {
179 &self.0
180 }
181}
182
183impl<T: Serialize + DeserializeOwned> From<Vec<T>> for VecRedisValue<T> {
184 fn from(value: Vec<T>) -> Self {
185 Self(value)
186 }
187}
188
189impl<T: Serialize + DeserializeOwned> ToRedisArgs for VecRedisValue<T> {
190 fn write_redis_args<W>(&self, out: &mut W)
191 where
192 W: ?Sized + redis_rs::RedisWrite,
193 {
194 let string = json!(self.0).to_string();
195 out.write_arg(string.as_bytes())
196 }
197}
198
199impl<T: Serialize + DeserializeOwned> FromRedisValue for VecRedisValue<T> {
200 fn from_redis_value(v: &redis_rs::Value) -> redis_rs::RedisResult<Self> {
201 match *v {
202 redis_rs::Value::Data(ref bytes) => {
203 let json = from_utf8(bytes)?.to_string();
204 let result = serde_json::from_str::<Vec<T>>(&json).map_err(|err| {
205 invalid_type_error!(
206 v,
207 format!(
208 "Could not deserialize into {} struct with err {}.",
209 stringify!($t),
210 err
211 )
212 )
213 })?;
214 Ok(VecRedisValue(result))
215 }
216 _ => Err(invalid_type_error!(
217 v,
218 format!("Could not deserialize into {} struct.", stringify!($t))
219 )),
220 }
221 }
222}
223
224mod connection {
225 use async_trait::async_trait;
226 use bb8_redis::bb8;
227 use redis_rs::aio::ConnectionLike;
228 use redis_rs::IntoConnectionInfo;
229 use redis_rs::RedisError;
230 use redis_rs::RedisResult;
231 use serde::Deserialize;
232
233 use super::Result;
234
235 pub type Pool = bb8::Pool<RedisClusterConnectionManager>;
236 pub type Connection = bb8::PooledConnection<'static, RedisClusterConnectionManager>;
237
238 #[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
239 pub struct RedisConfig {
240 pub hosts_csv: String,
241 pub expire_seconds: usize,
242 pub max_connections: u32,
243 }
244
245 impl RedisConfig {
246 fn hosts(&self) -> Vec<&str> {
247 self.hosts_csv.split(',').collect()
248 }
249
250 pub async fn init_pool(&self) -> Result<Pool> {
251 Ok(bb8::Pool::builder()
252 .max_size(self.max_connections)
253 .build(RedisClusterConnectionManager::new(self.hosts())?)
254 .await?)
255 }
256 }
257
258 pub struct RedisClusterConnectionManager {
259 client: redis_cluster_async::Client,
260 }
261
262 impl RedisClusterConnectionManager {
263 pub fn new<T: IntoConnectionInfo>(info: Vec<T>) -> Result<Self> {
264 Ok(RedisClusterConnectionManager {
265 client: redis_cluster_async::Client::open(info)?,
266 })
267 }
268 }
269
270 #[async_trait]
271 impl bb8::ManageConnection for RedisClusterConnectionManager {
272 type Connection = redis_cluster_async::Connection;
273 type Error = RedisError;
274
275 async fn connect(&self) -> RedisResult<Self::Connection> {
276 self.client.get_connection().await
277 }
278
279 async fn is_valid(&self, connection: &mut Self::Connection) -> RedisResult<()> {
280 connection
281 .req_packed_command(&redis_rs::cmd("PING"))
282 .await
283 .and_then(check_is_pong)
284 }
285
286 fn has_broken(&self, _: &mut Self::Connection) -> bool {
287 false
288 }
289 }
290
291 fn check_is_pong(value: redis_rs::Value) -> RedisResult<()> {
292 match value {
293 redis_rs::Value::Status(string) if &string == "PONG" => RedisResult::Ok(()),
294 _ => RedisResult::Err(RedisError::from((
295 redis_rs::ErrorKind::ResponseError,
296 "ping request",
297 ))),
298 }
299 }
300}