avantis_utils/
redis.rs

1use async_trait::async_trait;
2use bb8_redis::bb8::RunError;
3use redis_rs::{AsyncCommands, ErrorKind, FromRedisValue, RedisError, ToRedisArgs};
4use serde::{de::DeserializeOwned, Serialize};
5use serde_json::json;
6use std::{
7    future::Future,
8    str::from_utf8,
9    time::{SystemTime, UNIX_EPOCH},
10};
11use thiserror::Error;
12
// TODO: reporting failures only through `tracing::error!` doesn't work well; find another way to tell the user whether the operation succeeded.
14use tracing::error;
15
16pub use connection::Connection;
17pub use connection::Pool;
18pub use connection::RedisConfig;
19
20// TODO: add tests for VecRedisValue
21
22#[async_trait]
23pub trait GetOrFetchExt: AsyncCommands {
24    async fn get_or_fetch<K, V, F, Fut>(
25        &mut self,
26        key: K,
27        data_loader: F,
28        expire_seconds: usize,
29    ) -> Result<V>
30    where
31        K: ToRedisArgs + Send + Sync,
32        V: FromRedisValue + ToRedisArgs + Send + Sync,
33        F: FnOnce() -> Fut + Send,
34        Fut: Future<Output = anyhow::Result<V>> + Send;
35}
36
37#[async_trait]
38impl GetOrFetchExt for redis_cluster_async::Connection {
39    async fn get_or_fetch<K, V, F, Fut>(
40        &mut self,
41        key: K,
42        data_loader: F,
43        expire_seconds: usize,
44    ) -> Result<V>
45    where
46        K: ToRedisArgs + Send + Sync,
47        V: FromRedisValue + ToRedisArgs + Send + Sync,
48        F: FnOnce() -> Fut + Send,
49        Fut: Future<Output = anyhow::Result<V>> + Send,
50    {
51        if cfg!(test) {
52            return Ok(data_loader().await?);
53        }
54
55        match self.get(&key).await {
56            Ok(Some(bytes)) => Ok(bytes),
57            Ok(None) => {
58                let result = data_loader().await?;
59                self.set_ex(&key, &result, expire_seconds).await?;
60                Ok(result)
61            }
62            Err(err) => {
63                error!("redis error: {:?}", err);
64                Ok(data_loader().await?)
65            }
66        }
67    }
68}
69
70#[async_trait]
71pub trait GetOrRefreshExt {
72    async fn get_or_refresh<'a, V, F, Fut>(
73        mut self,
74        key: &str, // Would be nice if key is K: ToRedisArgs + Send + Sync instead.
75        data_loader: F,
76        expire_seconds: usize,
77    ) -> Result<V>
78    where
79        V: FromRedisValue + ToRedisArgs + Send + Sync + 'static,
80        F: FnOnce() -> Fut + Send + 'static,
81        Fut: Future<Output = anyhow::Result<V>> + Send;
82}
83
84#[async_trait]
85impl GetOrRefreshExt for connection::Connection {
86    async fn get_or_refresh<'a, V, F, Fut>(
87        mut self,
88        key: &str,
89        data_loader: F,
90        expire_seconds: usize,
91    ) -> Result<V>
92    where
93        V: FromRedisValue + ToRedisArgs + Send + Sync + 'static,
94        F: FnOnce() -> Fut + Send + 'static,
95        Fut: Future<Output = anyhow::Result<V>> + Send,
96    {
97        if cfg!(test) {
98            return Ok(data_loader().await?);
99        }
100
101        let now = SystemTime::now()
102            .duration_since(UNIX_EPOCH)
103            .expect("Time went backwards")
104            .as_secs();
105        let is_expired = |expired_when: u64| now > expired_when;
106
107        let owned_key = key.to_owned();
108        macro_rules! awaiting_get_and_set {
109            () => {{
110                let new_expired_when = now + expire_seconds as u64;
111
112                let new_value = data_loader().await?;
113
114                let _: () = self
115                    .hset(&owned_key, "expired_when", new_expired_when)
116                    .await?;
117                let _: () = self.hset(&owned_key, "value", &new_value).await?;
118
119                let result: Result<V> = Ok(new_value);
120
121                result
122            }};
123        }
124
125        let expired_when: Result<Option<u64>> = Ok(self.hget(key, "expired_when").await?);
126        let value: Result<Option<V>> = Ok(self.hget(key, "value").await?);
127
128        match (expired_when, value) {
129            (Ok(Some(expired_when)), Ok(Some(value))) if !is_expired(expired_when) => Ok(value),
130            (Ok(Some(_)), Ok(Some(value))) => {
131                tokio::spawn(async move {
132                    if let Err(e) = async { awaiting_get_and_set!() }.await {
133                        error!("Failed to load and set in background: {}", e);
134                    }
135                });
136
137                Ok(value)
138            }
139            (Ok(None), _) | (_, Ok(None)) => {
140                awaiting_get_and_set!()
141            }
142            (Err(err), _) | (_, Err(err)) => {
143                error!("redis error: {:?}", err);
144
145                awaiting_get_and_set!()
146            }
147        }
148    }
149}
150
151#[derive(Error, Debug)]
152pub enum Error {
153    #[error("data error")]
154    Data(#[from] anyhow::Error),
155    #[error("redis error")]
156    Redis(#[from] RedisError),
157    #[error("cluster connection error")]
158    Cluster(#[from] RunError<RedisError>),
159}
160
161pub type Result<T> = std::result::Result<T, Error>;
162
// Builds a `RedisError` of kind `TypeError` for a reply that could not be converted.
// `$value` is the offending reply and `$detail` explains the failure; both are rendered
// with `{:?}` into the error's detail string.
macro_rules! invalid_type_error {
    ($value:expr, $detail:expr) => {{
        let description = format!("{:?} (response was {:?})", $detail, $value);
        RedisError::from((
            ErrorKind::TypeError,
            "Response was of incompatible type",
            description,
        ))
    }};
}
172
/// Newtype around `Vec<T>` that is stored in Redis as a single JSON array.
///
/// The `Serialize + DeserializeOwned` bounds live on the impls that need them
/// (`ToRedisArgs` / `FromRedisValue`), not on the struct, so the wrapper itself can be
/// constructed, compared and debug-printed for any `T`.
#[derive(Debug, Clone, PartialEq)]
pub struct VecRedisValue<T>(pub Vec<T>);
174
175impl<T: Serialize + DeserializeOwned> std::ops::Deref for VecRedisValue<T> {
176    type Target = Vec<T>;
177
178    fn deref(&self) -> &Self::Target {
179        &self.0
180    }
181}
182
183impl<T: Serialize + DeserializeOwned> From<Vec<T>> for VecRedisValue<T> {
184    fn from(value: Vec<T>) -> Self {
185        Self(value)
186    }
187}
188
189impl<T: Serialize + DeserializeOwned> ToRedisArgs for VecRedisValue<T> {
190    fn write_redis_args<W>(&self, out: &mut W)
191    where
192        W: ?Sized + redis_rs::RedisWrite,
193    {
194        let string = json!(self.0).to_string();
195        out.write_arg(string.as_bytes())
196    }
197}
198
199impl<T: Serialize + DeserializeOwned> FromRedisValue for VecRedisValue<T> {
200    fn from_redis_value(v: &redis_rs::Value) -> redis_rs::RedisResult<Self> {
201        match *v {
202            redis_rs::Value::Data(ref bytes) => {
203                let json = from_utf8(bytes)?.to_string();
204                let result = serde_json::from_str::<Vec<T>>(&json).map_err(|err| {
205                    invalid_type_error!(
206                        v,
207                        format!(
208                            "Could not deserialize into {} struct with err {}.",
209                            stringify!($t),
210                            err
211                        )
212                    )
213                })?;
214                Ok(VecRedisValue(result))
215            }
216            _ => Err(invalid_type_error!(
217                v,
218                format!("Could not deserialize into {} struct.", stringify!($t))
219            )),
220        }
221    }
222}
223
224mod connection {
225    use async_trait::async_trait;
226    use bb8_redis::bb8;
227    use redis_rs::aio::ConnectionLike;
228    use redis_rs::IntoConnectionInfo;
229    use redis_rs::RedisError;
230    use redis_rs::RedisResult;
231    use serde::Deserialize;
232
233    use super::Result;
234
235    pub type Pool = bb8::Pool<RedisClusterConnectionManager>;
236    pub type Connection = bb8::PooledConnection<'static, RedisClusterConnectionManager>;
237
238    #[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
239    pub struct RedisConfig {
240        pub hosts_csv: String,
241        pub expire_seconds: usize,
242        pub max_connections: u32,
243    }
244
245    impl RedisConfig {
246        fn hosts(&self) -> Vec<&str> {
247            self.hosts_csv.split(',').collect()
248        }
249
250        pub async fn init_pool(&self) -> Result<Pool> {
251            Ok(bb8::Pool::builder()
252                .max_size(self.max_connections)
253                .build(RedisClusterConnectionManager::new(self.hosts())?)
254                .await?)
255        }
256    }
257
258    pub struct RedisClusterConnectionManager {
259        client: redis_cluster_async::Client,
260    }
261
262    impl RedisClusterConnectionManager {
263        pub fn new<T: IntoConnectionInfo>(info: Vec<T>) -> Result<Self> {
264            Ok(RedisClusterConnectionManager {
265                client: redis_cluster_async::Client::open(info)?,
266            })
267        }
268    }
269
270    #[async_trait]
271    impl bb8::ManageConnection for RedisClusterConnectionManager {
272        type Connection = redis_cluster_async::Connection;
273        type Error = RedisError;
274
275        async fn connect(&self) -> RedisResult<Self::Connection> {
276            self.client.get_connection().await
277        }
278
279        async fn is_valid(&self, connection: &mut Self::Connection) -> RedisResult<()> {
280            connection
281                .req_packed_command(&redis_rs::cmd("PING"))
282                .await
283                .and_then(check_is_pong)
284        }
285
286        fn has_broken(&self, _: &mut Self::Connection) -> bool {
287            false
288        }
289    }
290
291    fn check_is_pong(value: redis_rs::Value) -> RedisResult<()> {
292        match value {
293            redis_rs::Value::Status(string) if &string == "PONG" => RedisResult::Ok(()),
294            _ => RedisResult::Err(RedisError::from((
295                redis_rs::ErrorKind::ResponseError,
296                "ping request",
297            ))),
298        }
299    }
300}