basteh_redis/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::time::Duration;
4
5use basteh::{
6    dev::{Action, Mutation, OwnedValue, Provider, Value},
7    BastehError, Result,
8};
9use bytes::BytesMut;
10use redis::{aio::ConnectionManager, AsyncCommands, FromRedisValue, RedisResult, ToRedisArgs};
11
12pub use redis::{ConnectionAddr, ConnectionInfo, ErrorKind, RedisConnectionInfo, RedisError};
13use utils::run_mutations;
14
15mod utils;
16
/// Builds the namespaced Redis key `<scope>:<key>` as raw bytes.
#[inline]
fn get_full_key(scope: impl AsRef<[u8]>, key: impl AsRef<[u8]>) -> Vec<u8> {
    let (scope_bytes, key_bytes) = (scope.as_ref(), key.as_ref());
    let mut full = Vec::with_capacity(scope_bytes.len() + 1 + key_bytes.len());
    full.extend_from_slice(scope_bytes);
    full.push(b':');
    full.extend_from_slice(key_bytes);
    full
}
21
22/// An implementation of [`ExpiryStore`](basteh::dev::ExpiryStore) based on redis
23/// using redis-rs async runtime
24///
25/// ## Example
26/// ```no_run
27/// use basteh::Basteh;
28/// use basteh_redis::{RedisBackend, ConnectionInfo, RedisConnectionInfo, ConnectionAddr};
29///
30/// # async fn your_main() {
31/// let provider = RedisBackend::connect_default();
32/// // OR
33/// let connection_info = ConnectionInfo {
34///     addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 1234).into(),
35///     redis: RedisConnectionInfo{
36///         db: 0,
37///         username: Some("god".to_string()),
38///         password: Some("bless".to_string()),
39///     }
40/// };
41/// let provider = RedisBackend::connect(connection_info).await.expect("Redis connection failed");
42/// let basteh = Basteh::build().provider(provider).finish();
43/// # }
44/// ```
45///
46#[derive(Clone)]
47pub struct RedisBackend {
48    con: ConnectionManager,
49}
50
51impl RedisBackend {
52    /// Connect using the provided connection info
53    pub async fn connect(connection_info: ConnectionInfo) -> RedisResult<Self> {
54        let client = redis::Client::open(connection_info)?;
55        let con = client.get_tokio_connection_manager().await?;
56        Ok(Self { con })
57    }
58
59    /// Connect using the default redis port on local machine
60    pub async fn connect_default() -> RedisResult<Self> {
61        Self::connect("redis://127.0.0.1/".parse()?).await
62    }
63}
64
65#[async_trait::async_trait]
66impl Provider for RedisBackend {
67    async fn keys(&self, scope: &str) -> Result<Box<dyn Iterator<Item = Vec<u8>>>> {
68        let keys = self
69            .con
70            .clone()
71            .keys::<_, Vec<Vec<u8>>>([scope, ":*"].concat())
72            .await
73            .map_err(BastehError::custom)?
74            .into_iter()
75            .map(move |k| {
76                let ignored = scope.len() + 1;
77                k[ignored..].to_vec()
78            })
79            .collect::<Vec<_>>();
80        Ok(Box::new(keys.into_iter()))
81    }
82
83    async fn set(&self, scope: &str, key: &[u8], value: Value<'_>) -> Result<()> {
84        let full_key = get_full_key(scope, key);
85        match value {
86            Value::List(l) => {
87                redis::pipe()
88                    .del(&full_key)
89                    .rpush(
90                        full_key,
91                        l.into_iter().map(ValueWrapper).collect::<Vec<_>>(),
92                    )
93                    .query_async(&mut self.con.clone())
94                    .await
95                    .map_err(BastehError::custom)?;
96            }
97            _ => {
98                self.con
99                    .clone()
100                    .set(full_key, ValueWrapper(value))
101                    .await
102                    .map_err(BastehError::custom)?;
103            }
104        }
105        Ok(())
106    }
107
108    async fn get(&self, scope: &str, key: &[u8]) -> Result<Option<OwnedValue>> {
109        let full_key = get_full_key(scope, key);
110        self.con
111            .clone()
112            .get::<_, OwnedValueWrapper>(full_key)
113            .await
114            .map(|v| v.0)
115            .map_err(BastehError::custom)
116    }
117
118    async fn get_range(
119        &self,
120        scope: &str,
121        key: &[u8],
122        start: i64,
123        end: i64,
124    ) -> Result<Vec<OwnedValue>> {
125        let full_key = get_full_key(scope, key);
126        self.con
127            .clone()
128            .lrange::<_, OwnedValueWrapper>(full_key, start as isize, end as isize)
129            .await
130            .map(|v| v.0)
131            .map_err(BastehError::custom)
132            .and_then(|v| match v {
133                Some(OwnedValue::List(l)) => Ok(l),
134                Some(OwnedValue::Bytes(b)) => Ok(b
135                    .into_iter()
136                    .map(Into::<Value>::into)
137                    .map(|v| v.into_owned())
138                    .collect::<Vec<_>>()),
139                _ => Err(BastehError::TypeConversion),
140            })
141    }
142
143    async fn push(&self, scope: &str, key: &[u8], value: Value<'_>) -> Result<()> {
144        let full_key = get_full_key(scope, key);
145        self.con
146            .clone()
147            .rpush(full_key, ValueWrapper(value))
148            .await
149            .map_err(BastehError::custom)?;
150        Ok(())
151    }
152
153    async fn push_multiple(&self, scope: &str, key: &[u8], value: Vec<Value<'_>>) -> Result<()> {
154        let full_key = get_full_key(scope, key);
155        self.con
156            .clone()
157            .rpush(
158                full_key,
159                value.into_iter().map(ValueWrapper).collect::<Vec<_>>(),
160            )
161            .await
162            .map_err(BastehError::custom)?;
163        Ok(())
164    }
165
166    async fn pop(&self, scope: &str, key: &[u8]) -> Result<Option<OwnedValue>> {
167        let full_key = get_full_key(scope, key);
168        self.con
169            .clone()
170            .rpop::<_, OwnedValueWrapper>(full_key, None)
171            .await
172            .map(|v| v.0)
173            .map_err(BastehError::custom)
174    }
175
176    async fn mutate(&self, scope: &str, key: &[u8], mutations: Mutation) -> Result<i64> {
177        let full_key = get_full_key(scope, key);
178
179        if mutations.len() == 0 {
180            let mut con = self.con.clone();
181
182            // Get the value or set to 0 and return
183            let res = con
184                .get::<_, Option<i64>>(&full_key)
185                .await
186                .map_err(BastehError::custom)?;
187
188            if let Some(res) = res {
189                Ok(res)
190            } else {
191                con.set(full_key, 0__i64)
192                    .await
193                    .map_err(BastehError::custom)?;
194                Ok(0)
195            }
196        } else if mutations.len() == 1 {
197            match mutations.into_iter().next().unwrap() {
198                Action::Incr(delta) => self
199                    .con
200                    .clone()
201                    .incr(full_key, delta)
202                    .await
203                    .map_err(BastehError::custom),
204                Action::Decr(delta) => self
205                    .con
206                    .clone()
207                    .decr(full_key, delta)
208                    .await
209                    .map_err(BastehError::custom),
210                Action::Set(value) => {
211                    self.con
212                        .clone()
213                        .set(full_key, value)
214                        .await
215                        .map_err(BastehError::custom)?;
216                    return Ok(value);
217                }
218                action => run_mutations(self.con.clone(), full_key, [action])
219                    .await
220                    .map_err(|e| BastehError::Custom(Box::new(e))),
221            }
222        } else {
223            run_mutations(self.con.clone(), full_key, mutations.into_iter())
224                .await
225                .map_err(|e| BastehError::Custom(Box::new(e)))
226        }
227    }
228
229    async fn remove(&self, scope: &str, key: &[u8]) -> Result<Option<OwnedValue>> {
230        let full_key = get_full_key(scope, key);
231        Ok(redis::pipe()
232            .get(&full_key)
233            .del(full_key)
234            .ignore()
235            .query_async::<_, Vec<OwnedValueWrapper>>(&mut self.con.clone())
236            .await
237            .map_err(BastehError::custom)?
238            .into_iter()
239            .next()
240            .and_then(|v| v.0))
241    }
242
243    async fn contains_key(&self, scope: &str, key: &[u8]) -> Result<bool> {
244        let full_key = get_full_key(scope, key);
245        let res: u8 = self
246            .con
247            .clone()
248            .exists(full_key)
249            .await
250            .map_err(BastehError::custom)?;
251        Ok(res > 0)
252    }
253
254    async fn persist(&self, scope: &str, key: &[u8]) -> Result<()> {
255        let full_key = get_full_key(scope, key);
256        self.con
257            .clone()
258            .persist(full_key)
259            .await
260            .map_err(BastehError::custom)?;
261        Ok(())
262    }
263
264    async fn expiry(&self, scope: &str, key: &[u8]) -> Result<Option<Duration>> {
265        let full_key = get_full_key(scope, key);
266        let res: i32 = self
267            .con
268            .clone()
269            .ttl(full_key)
270            .await
271            .map_err(BastehError::custom)?;
272        Ok(if res >= 0 {
273            Some(Duration::from_secs(res as u64))
274        } else {
275            None
276        })
277    }
278
279    async fn expire(&self, scope: &str, key: &[u8], expire_in: Duration) -> Result<()> {
280        let full_key = get_full_key(scope, key);
281        self.con
282            .clone()
283            .expire(full_key, expire_in.as_secs() as usize)
284            .await
285            .map_err(BastehError::custom)?;
286        Ok(())
287    }
288
289    async fn set_expiring(
290        &self,
291        scope: &str,
292        key: &[u8],
293        value: Value<'_>,
294        expire_in: Duration,
295    ) -> Result<()> {
296        let full_key = get_full_key(scope, key);
297        self.con
298            .clone()
299            .set_ex(full_key, ValueWrapper(value), expire_in.as_secs() as usize)
300            .await
301            .map_err(BastehError::custom)?;
302        Ok(())
303    }
304}
305
306struct ValueWrapper<'a>(Value<'a>);
307
308impl<'a> ToRedisArgs for ValueWrapper<'a> {
309    fn write_redis_args<W>(&self, out: &mut W)
310    where
311        W: ?Sized + redis::RedisWrite,
312    {
313        match &self.0 {
314            Value::Number(n) => <i64 as ToRedisArgs>::write_redis_args(&n, out),
315            Value::Bytes(b) => <&[u8] as ToRedisArgs>::write_redis_args(&b.as_ref(), out),
316            Value::String(s) => <&str as ToRedisArgs>::write_redis_args(&s.as_ref(), out),
317            Value::List(l) => {
318                for item in l {
319                    ValueWrapper(item.clone()).write_redis_args(out);
320                }
321            }
322        }
323    }
324}
325struct OwnedValueWrapper(Option<OwnedValue>);
326
327impl<'a> FromRedisValue for OwnedValueWrapper {
328    fn from_redis_value(v: &redis::Value) -> RedisResult<OwnedValueWrapper> {
329        Ok(OwnedValueWrapper(match v {
330            // If it's Nil then return None
331            redis::Value::Nil => None,
332            // Otherwise try to decode as Number, String or Bytes in order
333            _ => Some(
334                <i64 as FromRedisValue>::from_redis_value(v)
335                    .map(OwnedValue::Number)
336                    .or_else(|_| {
337                        <String as FromRedisValue>::from_redis_value(v).map(OwnedValue::String)
338                    })
339                    .or_else(|_| match v {
340                        redis::Value::Data(bytes_vec) => {
341                            Ok(OwnedValue::Bytes(BytesMut::from(bytes_vec.as_slice())))
342                        }
343                        _ => Err(RedisError::from((
344                            redis::ErrorKind::TypeError,
345                            "Response was of incompatible type",
346                        ))),
347                    })
348                    .or_else(|_| {
349                        <Vec<OwnedValueWrapper> as FromRedisValue>::from_redis_value(v)
350                            .map(|v| v.into_iter().filter_map(|v| v.0).collect())
351                            .map(OwnedValue::List)
352                    })?,
353            ),
354        }))
355    }
356}
357
#[cfg(test)]
mod test {
    use super::*;
    use basteh::test_utils::*;
    use std::sync::Once;

    /// Same address `RedisBackend::connect_default` uses, so the one-time flush
    /// targets the server the tests actually talk to. `localhost` may resolve
    /// to `::1` instead.
    const TEST_REDIS_URL: &str = "redis://127.0.0.1/";

    static INIT: Once = Once::new();

    /// Connects to the local test server and flushes its DB once per test run.
    async fn get_connection() -> RedisBackend {
        let con = RedisBackend::connect_default()
            .await
            .unwrap_or_else(|err| panic!("{:?}", err));
        INIT.call_once(|| {
            let mut client = redis::Client::open(TEST_REDIS_URL).unwrap();
            let _: () = redis::cmd("FLUSHDB").query(&mut client).unwrap();
        });
        con
    }

    #[tokio::test]
    async fn test_redis_store() {
        test_store(get_connection().await).await;
    }

    #[tokio::test]
    async fn test_redis_mutations() {
        test_mutations(get_connection().await).await;
    }

    #[tokio::test]
    async fn test_redis_expiry() {
        test_expiry(get_connection().await, 5).await;
    }

    #[tokio::test]
    async fn test_redis_expiry_store() {
        test_expiry_store(get_connection().await, 5).await;
    }
}