1#![doc = include_str!("../README.md")]
2
3use std::time::Duration;
4
5use basteh::{
6 dev::{Action, Mutation, OwnedValue, Provider, Value},
7 BastehError, Result,
8};
9use bytes::BytesMut;
10use redis::{aio::ConnectionManager, AsyncCommands, FromRedisValue, RedisResult, ToRedisArgs};
11
12pub use redis::{ConnectionAddr, ConnectionInfo, ErrorKind, RedisConnectionInfo, RedisError};
13use utils::run_mutations;
14
15mod utils;
16
/// Builds the namespaced Redis key for `key` inside `scope`, i.e. the bytes
/// `scope`, then a single `:` separator, then `key`.
#[inline]
fn get_full_key(scope: impl AsRef<[u8]>, key: impl AsRef<[u8]>) -> Vec<u8> {
    let scope_bytes = scope.as_ref();
    let key_bytes = key.as_ref();

    let mut full = Vec::with_capacity(scope_bytes.len() + 1 + key_bytes.len());
    full.extend_from_slice(scope_bytes);
    full.push(b':');
    full.extend_from_slice(key_bytes);
    full
}
21
22#[derive(Clone)]
47pub struct RedisBackend {
48 con: ConnectionManager,
49}
50
51impl RedisBackend {
52 pub async fn connect(connection_info: ConnectionInfo) -> RedisResult<Self> {
54 let client = redis::Client::open(connection_info)?;
55 let con = client.get_tokio_connection_manager().await?;
56 Ok(Self { con })
57 }
58
59 pub async fn connect_default() -> RedisResult<Self> {
61 Self::connect("redis://127.0.0.1/".parse()?).await
62 }
63}
64
65#[async_trait::async_trait]
66impl Provider for RedisBackend {
67 async fn keys(&self, scope: &str) -> Result<Box<dyn Iterator<Item = Vec<u8>>>> {
68 let keys = self
69 .con
70 .clone()
71 .keys::<_, Vec<Vec<u8>>>([scope, ":*"].concat())
72 .await
73 .map_err(BastehError::custom)?
74 .into_iter()
75 .map(move |k| {
76 let ignored = scope.len() + 1;
77 k[ignored..].to_vec()
78 })
79 .collect::<Vec<_>>();
80 Ok(Box::new(keys.into_iter()))
81 }
82
83 async fn set(&self, scope: &str, key: &[u8], value: Value<'_>) -> Result<()> {
84 let full_key = get_full_key(scope, key);
85 match value {
86 Value::List(l) => {
87 redis::pipe()
88 .del(&full_key)
89 .rpush(
90 full_key,
91 l.into_iter().map(ValueWrapper).collect::<Vec<_>>(),
92 )
93 .query_async(&mut self.con.clone())
94 .await
95 .map_err(BastehError::custom)?;
96 }
97 _ => {
98 self.con
99 .clone()
100 .set(full_key, ValueWrapper(value))
101 .await
102 .map_err(BastehError::custom)?;
103 }
104 }
105 Ok(())
106 }
107
108 async fn get(&self, scope: &str, key: &[u8]) -> Result<Option<OwnedValue>> {
109 let full_key = get_full_key(scope, key);
110 self.con
111 .clone()
112 .get::<_, OwnedValueWrapper>(full_key)
113 .await
114 .map(|v| v.0)
115 .map_err(BastehError::custom)
116 }
117
118 async fn get_range(
119 &self,
120 scope: &str,
121 key: &[u8],
122 start: i64,
123 end: i64,
124 ) -> Result<Vec<OwnedValue>> {
125 let full_key = get_full_key(scope, key);
126 self.con
127 .clone()
128 .lrange::<_, OwnedValueWrapper>(full_key, start as isize, end as isize)
129 .await
130 .map(|v| v.0)
131 .map_err(BastehError::custom)
132 .and_then(|v| match v {
133 Some(OwnedValue::List(l)) => Ok(l),
134 Some(OwnedValue::Bytes(b)) => Ok(b
135 .into_iter()
136 .map(Into::<Value>::into)
137 .map(|v| v.into_owned())
138 .collect::<Vec<_>>()),
139 _ => Err(BastehError::TypeConversion),
140 })
141 }
142
143 async fn push(&self, scope: &str, key: &[u8], value: Value<'_>) -> Result<()> {
144 let full_key = get_full_key(scope, key);
145 self.con
146 .clone()
147 .rpush(full_key, ValueWrapper(value))
148 .await
149 .map_err(BastehError::custom)?;
150 Ok(())
151 }
152
153 async fn push_multiple(&self, scope: &str, key: &[u8], value: Vec<Value<'_>>) -> Result<()> {
154 let full_key = get_full_key(scope, key);
155 self.con
156 .clone()
157 .rpush(
158 full_key,
159 value.into_iter().map(ValueWrapper).collect::<Vec<_>>(),
160 )
161 .await
162 .map_err(BastehError::custom)?;
163 Ok(())
164 }
165
166 async fn pop(&self, scope: &str, key: &[u8]) -> Result<Option<OwnedValue>> {
167 let full_key = get_full_key(scope, key);
168 self.con
169 .clone()
170 .rpop::<_, OwnedValueWrapper>(full_key, None)
171 .await
172 .map(|v| v.0)
173 .map_err(BastehError::custom)
174 }
175
176 async fn mutate(&self, scope: &str, key: &[u8], mutations: Mutation) -> Result<i64> {
177 let full_key = get_full_key(scope, key);
178
179 if mutations.len() == 0 {
180 let mut con = self.con.clone();
181
182 let res = con
184 .get::<_, Option<i64>>(&full_key)
185 .await
186 .map_err(BastehError::custom)?;
187
188 if let Some(res) = res {
189 Ok(res)
190 } else {
191 con.set(full_key, 0__i64)
192 .await
193 .map_err(BastehError::custom)?;
194 Ok(0)
195 }
196 } else if mutations.len() == 1 {
197 match mutations.into_iter().next().unwrap() {
198 Action::Incr(delta) => self
199 .con
200 .clone()
201 .incr(full_key, delta)
202 .await
203 .map_err(BastehError::custom),
204 Action::Decr(delta) => self
205 .con
206 .clone()
207 .decr(full_key, delta)
208 .await
209 .map_err(BastehError::custom),
210 Action::Set(value) => {
211 self.con
212 .clone()
213 .set(full_key, value)
214 .await
215 .map_err(BastehError::custom)?;
216 return Ok(value);
217 }
218 action => run_mutations(self.con.clone(), full_key, [action])
219 .await
220 .map_err(|e| BastehError::Custom(Box::new(e))),
221 }
222 } else {
223 run_mutations(self.con.clone(), full_key, mutations.into_iter())
224 .await
225 .map_err(|e| BastehError::Custom(Box::new(e)))
226 }
227 }
228
229 async fn remove(&self, scope: &str, key: &[u8]) -> Result<Option<OwnedValue>> {
230 let full_key = get_full_key(scope, key);
231 Ok(redis::pipe()
232 .get(&full_key)
233 .del(full_key)
234 .ignore()
235 .query_async::<_, Vec<OwnedValueWrapper>>(&mut self.con.clone())
236 .await
237 .map_err(BastehError::custom)?
238 .into_iter()
239 .next()
240 .and_then(|v| v.0))
241 }
242
243 async fn contains_key(&self, scope: &str, key: &[u8]) -> Result<bool> {
244 let full_key = get_full_key(scope, key);
245 let res: u8 = self
246 .con
247 .clone()
248 .exists(full_key)
249 .await
250 .map_err(BastehError::custom)?;
251 Ok(res > 0)
252 }
253
254 async fn persist(&self, scope: &str, key: &[u8]) -> Result<()> {
255 let full_key = get_full_key(scope, key);
256 self.con
257 .clone()
258 .persist(full_key)
259 .await
260 .map_err(BastehError::custom)?;
261 Ok(())
262 }
263
264 async fn expiry(&self, scope: &str, key: &[u8]) -> Result<Option<Duration>> {
265 let full_key = get_full_key(scope, key);
266 let res: i32 = self
267 .con
268 .clone()
269 .ttl(full_key)
270 .await
271 .map_err(BastehError::custom)?;
272 Ok(if res >= 0 {
273 Some(Duration::from_secs(res as u64))
274 } else {
275 None
276 })
277 }
278
279 async fn expire(&self, scope: &str, key: &[u8], expire_in: Duration) -> Result<()> {
280 let full_key = get_full_key(scope, key);
281 self.con
282 .clone()
283 .expire(full_key, expire_in.as_secs() as usize)
284 .await
285 .map_err(BastehError::custom)?;
286 Ok(())
287 }
288
289 async fn set_expiring(
290 &self,
291 scope: &str,
292 key: &[u8],
293 value: Value<'_>,
294 expire_in: Duration,
295 ) -> Result<()> {
296 let full_key = get_full_key(scope, key);
297 self.con
298 .clone()
299 .set_ex(full_key, ValueWrapper(value), expire_in.as_secs() as usize)
300 .await
301 .map_err(BastehError::custom)?;
302 Ok(())
303 }
304}
305
306struct ValueWrapper<'a>(Value<'a>);
307
308impl<'a> ToRedisArgs for ValueWrapper<'a> {
309 fn write_redis_args<W>(&self, out: &mut W)
310 where
311 W: ?Sized + redis::RedisWrite,
312 {
313 match &self.0 {
314 Value::Number(n) => <i64 as ToRedisArgs>::write_redis_args(&n, out),
315 Value::Bytes(b) => <&[u8] as ToRedisArgs>::write_redis_args(&b.as_ref(), out),
316 Value::String(s) => <&str as ToRedisArgs>::write_redis_args(&s.as_ref(), out),
317 Value::List(l) => {
318 for item in l {
319 ValueWrapper(item.clone()).write_redis_args(out);
320 }
321 }
322 }
323 }
324}
325struct OwnedValueWrapper(Option<OwnedValue>);
326
327impl<'a> FromRedisValue for OwnedValueWrapper {
328 fn from_redis_value(v: &redis::Value) -> RedisResult<OwnedValueWrapper> {
329 Ok(OwnedValueWrapper(match v {
330 redis::Value::Nil => None,
332 _ => Some(
334 <i64 as FromRedisValue>::from_redis_value(v)
335 .map(OwnedValue::Number)
336 .or_else(|_| {
337 <String as FromRedisValue>::from_redis_value(v).map(OwnedValue::String)
338 })
339 .or_else(|_| match v {
340 redis::Value::Data(bytes_vec) => {
341 Ok(OwnedValue::Bytes(BytesMut::from(bytes_vec.as_slice())))
342 }
343 _ => Err(RedisError::from((
344 redis::ErrorKind::TypeError,
345 "Response was of incompatible type",
346 ))),
347 })
348 .or_else(|_| {
349 <Vec<OwnedValueWrapper> as FromRedisValue>::from_redis_value(v)
350 .map(|v| v.into_iter().filter_map(|v| v.0).collect())
351 .map(OwnedValue::List)
352 })?,
353 ),
354 }))
355 }
356}
357
#[cfg(test)]
mod test {
    use super::*;
    use basteh::test_utils::*;
    use std::sync::Once;

    /// Must match the URL hard-coded in `RedisBackend::connect_default`, so the one-time
    /// `FLUSHDB` clears the same server the tests talk to. `localhost` may resolve to `::1`,
    /// which is not necessarily the listener bound on `127.0.0.1`.
    const TEST_REDIS_URL: &str = "redis://127.0.0.1/";

    /// Guards the one-time database flush shared by every test in this module.
    static INIT: Once = Once::new();

    /// Connects via `RedisBackend::connect_default` and flushes the database the first time
    /// any test gets here. Panics if the connection cannot be established.
    ///
    /// NOTE(review): the flush uses the blocking redis client from inside an async test.
    async fn get_connection() -> RedisBackend {
        let con = RedisBackend::connect_default()
            .await
            .unwrap_or_else(|err| panic!("{:?}", err));
        INIT.call_once(|| {
            let mut client = redis::Client::open(TEST_REDIS_URL).unwrap();
            let _: () = redis::cmd("FLUSHDB").query(&mut client).unwrap();
        });
        con
    }

    #[tokio::test]
    async fn test_redis_store() {
        test_store(get_connection().await).await;
    }

    #[tokio::test]
    async fn test_redis_mutations() {
        test_mutations(get_connection().await).await;
    }

    #[tokio::test]
    async fn test_redis_expiry() {
        test_expiry(get_connection().await, 5).await;
    }

    #[tokio::test]
    async fn test_redis_expiry_store() {
        test_expiry_store(get_connection().await, 5).await;
    }
}