redis_cacher/
stream.rs

1use displaydoc::Display;
2use redis::{FromRedisValue, RedisError, RedisResult, RedisWrite, ToRedisArgs, Value};
3use std::{
4    cmp::Ord,
5    fmt,
6    hash::Hash,
7    num::ParseIntError,
8    str::FromStr,
9    time::{Duration, SystemTime},
10};
11use strum::IntoStaticStr;
12use thiserror::Error;
13
// Errors produced when parsing a `MessageId` from a string (see its `FromStr` impl).
//
// NOTE: `displaydoc::Display` turns each variant's `///` line into that variant's `Display`
// format string (`{0}`, `{1}`, ... refer to the tuple fields), so those lines are runtime
// text. Explanatory notes are therefore kept in plain `//` comments.
#[derive(Debug, Display, Error)]
pub enum Error {
    // Both the timestamp half and the sequence half failed to parse as integers.
    /// Invalid message id: timestamp={0} sequence={1} id={2}
    InvalidMessageId(ParseIntError, ParseIntError, String),
    // Only the timestamp half (before the `-`) failed to parse.
    /// Invalid message id timestamp: timestamp={0} id={1}
    InvalidMessageIdTimestamp(ParseIntError, String),
    // Only the sequence half (after the `-`) failed to parse.
    /// Invalid message id sequence number: sequence={0} id={1}
    InvalidMessageIdSequence(ParseIntError, String),
    // The input did not split into exactly two `-`-separated parts.
    /// Malformed message id: id={0}
    MalformedMessageId(String),
}
25
26#[derive(Debug, Clone)]
27pub struct ReadStream<'s, S>
28where
29    &'s S: ToRedisArgs,
30{
31    pub id: &'s S,
32    pub offset: MessageId,
33}
34
/// Options for XADD command
#[derive(Debug, Copy, Clone, Default)]
pub struct WriteStreamOptions {
    /// Do not create new stream if nonexistent (emits `NOMKSTREAM` when `true`).
    pub disable_create: bool,
    /// Args for capping the stream length during a write: emits `MAXLEN` or `MINID`, then the
    /// precision marker (`=` / `~`), then the threshold. `None` emits nothing.
    pub capacity: Option<(Trim, TrimPrecision)>,
}
43
/// Options for XREAD command
#[derive(Debug, Copy, Clone, Default)]
pub struct ReadStreamOptions {
    /// Block until response available or timeout (emitted as `BLOCK <milliseconds>`).
    /// `None` emits nothing.
    pub block: Option<Duration>,
    /// Maximum items returned from read (emitted as `COUNT <n>`). `None` emits nothing.
    pub count: Option<u64>,
}
52
/// Trimming strategy for XADD, used through `WriteStreamOptions::capacity`.
#[derive(Debug, Copy, Clone)]
pub enum Trim {
    /// Trim by stream length (serialized as `MAXLEN <precision> <len>`).
    Length(u64),
    /// Trim by minimum message id (serialized as `MINID <precision> <id>`).
    Id(MessageId),
}
58
/// Precision marker emitted after `MAXLEN`/`MINID` in XADD.
///
/// The `strum(to_string = ...)` attributes define the `&'static str` produced through the
/// derived `IntoStaticStr` conversion, which is what gets written as the command argument.
#[derive(Debug, Copy, Clone, IntoStaticStr)]
pub enum TrimPrecision {
    /// Serialized as `=`.
    #[strum(to_string = "=")]
    Exact,
    /// Serialized as `~`.
    #[strum(to_string = "~")]
    Approximate,
}
66
67impl ToRedisArgs for &WriteStreamOptions {
68    fn write_redis_args<W>(&self, out: &mut W)
69    where
70        W: ?Sized + RedisWrite,
71    {
72        if self.disable_create {
73            out.write_arg(b"NOMKSTREAM")
74        }
75        match self.capacity {
76            Some((Trim::Length(len), precision)) => {
77                let precision: &'static str = precision.into();
78                out.write_arg(b"MAXLEN");
79                out.write_arg(precision.as_bytes());
80                out.write_arg(len.to_string().as_bytes());
81            }
82            Some((Trim::Id(id), precision)) => {
83                let precision: &'static str = precision.into();
84                out.write_arg(b"MINID");
85                out.write_arg(precision.as_bytes());
86                out.write_arg(id.to_string().as_bytes());
87            }
88            None => {}
89        }
90    }
91}
92
93impl ToRedisArgs for &ReadStreamOptions {
94    fn write_redis_args<W>(&self, out: &mut W)
95    where
96        W: ?Sized + RedisWrite,
97    {
98        if let Some(count) = self.count {
99            out.write_arg(b"COUNT");
100            out.write_arg(count.to_string().as_bytes());
101        }
102        if let Some(block) = self.block {
103            out.write_arg(b"BLOCK");
104            out.write_arg(block.as_millis().to_string().as_bytes());
105        }
106    }
107}
108
109#[derive(Debug)]
110pub struct StreamReadReply<S: FromRedisValue, T: FromRedisValue>(pub Vec<StreamItems<S, T>>);
111
112#[derive(Debug)]
113pub struct StreamItems<S: FromRedisValue, T: FromRedisValue> {
114    pub id: S,
115    pub items: Vec<StreamItem<T>>,
116}
117
118#[derive(Debug)]
119pub struct StreamItem<T: FromRedisValue> {
120    pub offset: MessageId,
121    pub payload: Result<T, RedisError>,
122}
123
/// A Redis stream message id: `<timestamp_ms>-<sequence>`.
///
/// Field order matters. The derived `PartialOrd`/`Ord` compare `timestamp_ms` first and fall
/// back to `sequence` only when the timestamps are equal.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct MessageId {
    /// Milliseconds since the Unix epoch (the part before the `-`).
    pub timestamp_ms: u128,
    /// Sequence number distinguishing ids with the same timestamp (the part after the `-`).
    pub sequence: u64,
}

impl MessageId {
    /// Returns an id for the current wall-clock time with sequence `0`.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time before the Unix epoch.
    pub fn now() -> Self {
        let timestamp_ms = SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|elapsed| elapsed.as_millis())
            .expect("System time before Unix epoch");
        Self {
            timestamp_ms,
            sequence: 0,
        }
    }
}
141
142impl FromStr for MessageId {
143    type Err = Error;
144
145    fn from_str(s: &str) -> Result<Self, Self::Err> {
146        let parts = s.split('-').collect::<Vec<&str>>();
147        if parts.len() != 2 {
148            return Err(Error::MalformedMessageId(s.to_string()));
149        }
150
151        match (parts[0].parse::<u128>(), parts[1].parse::<u64>()) {
152            (Ok(timestamp_ms), Ok(sequence)) => Ok(MessageId {
153                timestamp_ms,
154                sequence,
155            }),
156            (Err(ts_err), Ok(_)) => Err(Error::InvalidMessageIdTimestamp(ts_err, s.to_string())),
157            (Ok(_), Err(sequence_err)) => {
158                Err(Error::InvalidMessageIdSequence(sequence_err, s.to_string()))
159            }
160            (Err(ts_err), Err(sequence_err)) => {
161                Err(Error::InvalidMessageId(ts_err, sequence_err, s.to_string()))
162            }
163        }
164    }
165}
166
167impl fmt::Display for MessageId {
168    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169        write!(f, "{}-{}", self.timestamp_ms, self.sequence)
170    }
171}
172
173impl FromRedisValue for MessageId {
174    fn from_redis_value(v: &redis::Value) -> redis::RedisResult<Self> {
175        String::from_redis_value(v)?
176            .parse::<MessageId>()
177            .map_err(|e| {
178                redis::RedisError::from((
179                    redis::ErrorKind::TypeError,
180                    "invalid value for MessageId",
181                    e.to_string(),
182                ))
183            })
184    }
185}
186
187impl ToRedisArgs for MessageId {
188    fn write_redis_args<W>(&self, out: &mut W)
189    where
190        W: ?Sized + RedisWrite,
191    {
192        out.write_arg(self.to_string().as_bytes())
193    }
194}
195
196impl<S, T> FromRedisValue for StreamReadReply<S, T>
197where
198    S: FromRedisValue + Eq + Hash,
199    T: FromRedisValue,
200{
201    fn from_redis_value(v: &Value) -> RedisResult<Self> {
202        match v {
203            Value::Bulk(bulk) => {
204                let stream_reply = bulk
205                    .iter()
206                    .map(|streams| match streams {
207                        redis::Value::Bulk(bulk) => {
208                            let mut iter = bulk.iter();
209                            let id = get_next_value::<_, S>(
210                                &mut iter,
211                                "missing key `id` in `StreamItems` response",
212                            )?;
213                            let items = get_next_value::<_, Vec<StreamItem<T>>>(
214                                &mut iter,
215                                "missing key `items` in `StreamItems` response",
216                            )?;
217                            Ok(StreamItems { id, items })
218                        }
219                        _ => Err(redis::RedisError::from((
220                            redis::ErrorKind::TypeError,
221                            "expecting Value::Bulk for `StreamItems` response",
222                        ))),
223                    })
224                    .collect::<Result<Vec<_>, _>>()?;
225
226                Ok(Self(stream_reply))
227            }
228            Value::Nil => Ok(Self(vec![])),
229            _ => Err(redis::RedisError::from((
230                redis::ErrorKind::TypeError,
231                "expecting Value::Bulk for `StreamReadReply` response",
232            ))),
233        }
234    }
235}
236
237impl<T> FromRedisValue for StreamItem<T>
238where
239    T: FromRedisValue,
240{
241    fn from_redis_value(v: &Value) -> RedisResult<Self> {
242        match v {
243            redis::Value::Bulk(bulk) => {
244                let mut iter = bulk.iter();
245                let offset = get_next_value::<_, MessageId>(
246                    &mut iter,
247                    "missing key `offset` in `StreamItem` response",
248                )?;
249                let payload = get_next_value::<_, T>(
250                    &mut iter,
251                    "missing key `payload` in `StreamItem` response",
252                );
253
254                Ok(StreamItem { offset, payload })
255            }
256            _ => Err(redis::RedisError::from((
257                redis::ErrorKind::TypeError,
258                "expecting Value::Bulk for `StreamItem` response",
259            ))),
260        }
261    }
262}
263
264fn get_next_value<'v, I, T>(iter: &mut I, err_msg: &'static str) -> Result<T, redis::RedisError>
265where
266    I: Iterator<Item = &'v Value>,
267    T: FromRedisValue,
268{
269    iter.next()
270        .ok_or_else(|| redis::RedisError::from((redis::ErrorKind::TypeError, err_msg)))
271        .and_then(T::from_redis_value)
272}
273
#[cfg(test)]
mod tests {
    use super::{
        Error, MessageId, ReadStreamOptions, StreamReadReply, Trim, TrimPrecision,
        WriteStreamOptions,
    };
    use redis::{FromRedisValue, ToRedisArgs, Value};
    use std::{collections::HashMap, str, time::Duration};

    #[test]
    fn stream_write_message_id_to_args() {
        let bytes = MessageId {
            timestamp_ms: 111,
            sequence: 22,
        }
        .to_redis_args();
        assert_eq!(bytes.len(), 1);
        assert_eq!(str::from_utf8(bytes[0].as_slice()).unwrap(), "111-22");
    }

    #[test]
    fn stream_write_options() {
        let bytes = (&WriteStreamOptions::default()).to_redis_args();
        assert_eq!(bytes.len(), 0);
        let bytes = (&WriteStreamOptions {
            disable_create: true,
            capacity: Some((Trim::Length(10), TrimPrecision::Approximate)),
        })
            .to_redis_args();
        assert_eq!(bytes.len(), 4);
        assert_eq!(str::from_utf8(bytes[0].as_slice()).unwrap(), "NOMKSTREAM");
        assert_eq!(str::from_utf8(bytes[1].as_slice()).unwrap(), "MAXLEN");
        assert_eq!(str::from_utf8(bytes[2].as_slice()).unwrap(), "~");
        assert_eq!(str::from_utf8(bytes[3].as_slice()).unwrap(), "10");
    }

    // Covers the `MINID` branch and the exact (`=`) precision marker.
    #[test]
    fn stream_write_options_min_id_exact() {
        let bytes = (&WriteStreamOptions {
            disable_create: false,
            capacity: Some((
                Trim::Id(MessageId {
                    timestamp_ms: 42,
                    sequence: 7,
                }),
                TrimPrecision::Exact,
            )),
        })
            .to_redis_args();
        assert_eq!(bytes.len(), 3);
        assert_eq!(str::from_utf8(bytes[0].as_slice()).unwrap(), "MINID");
        assert_eq!(str::from_utf8(bytes[1].as_slice()).unwrap(), "=");
        assert_eq!(str::from_utf8(bytes[2].as_slice()).unwrap(), "42-7");
    }

    #[test]
    fn stream_read_options() {
        let bytes = (&ReadStreamOptions::default()).to_redis_args();
        assert_eq!(bytes.len(), 0);
        let bytes = (&ReadStreamOptions {
            block: Some(Duration::from_secs(1)),
            count: Some(50),
        })
            .to_redis_args();
        assert_eq!(bytes.len(), 4);
        assert_eq!(str::from_utf8(bytes[0].as_slice()).unwrap(), "COUNT");
        assert_eq!(str::from_utf8(bytes[1].as_slice()).unwrap(), "50");
        assert_eq!(str::from_utf8(bytes[2].as_slice()).unwrap(), "BLOCK");
        assert_eq!(str::from_utf8(bytes[3].as_slice()).unwrap(), "1000");
    }

    #[test]
    fn stream_message_id_ordering() {
        let msg = MessageId {
            timestamp_ms: 5,
            sequence: 0,
        };

        assert!(
            msg > MessageId {
                timestamp_ms: 4,
                sequence: 0
            }
        );
        assert!(
            msg == MessageId {
                timestamp_ms: 5,
                sequence: 0
            }
        );
        assert!(
            msg < MessageId {
                timestamp_ms: 5,
                sequence: 1
            }
        );
        assert!(
            msg < MessageId {
                timestamp_ms: 6,
                sequence: 0
            }
        );
    }

    #[test]
    fn stream_message_id_display() {
        assert_eq!(
            MessageId {
                timestamp_ms: 1000,
                sequence: 3
            }
            .to_string(),
            "1000-3".to_string()
        )
    }

    // `Display` and `FromStr` must agree on the textual form.
    #[test]
    fn stream_message_id_display_parse_round_trip() {
        let id = MessageId {
            timestamp_ms: 1636634305271,
            sequence: 42,
        };
        assert_eq!(id.to_string().parse::<MessageId>().unwrap(), id);
    }

    // Each parse failure mode maps to its dedicated error variant.
    #[test]
    fn stream_message_id_parse_errors() {
        assert!(matches!(
            "123".parse::<MessageId>(),
            Err(Error::MalformedMessageId(_))
        ));
        assert!(matches!(
            "1-2-3".parse::<MessageId>(),
            Err(Error::MalformedMessageId(_))
        ));
        assert!(matches!(
            "a-1".parse::<MessageId>(),
            Err(Error::InvalidMessageIdTimestamp(_, _))
        ));
        assert!(matches!(
            "1-a".parse::<MessageId>(),
            Err(Error::InvalidMessageIdSequence(_, _))
        ));
        assert!(matches!(
            "a-b".parse::<MessageId>(),
            Err(Error::InvalidMessageId(_, _, _))
        ));
    }

    #[test]
    fn stream_message_id_deser() {
        let raw = Value::Data(b"1-0".to_vec());
        let deserialized = MessageId::from_redis_value(&raw);
        assert_eq!(
            deserialized.unwrap(),
            MessageId {
                timestamp_ms: 1,
                sequence: 0
            }
        );
        let raw = Value::Data(b"1636634305271-1".to_vec());
        let deserialized = MessageId::from_redis_value(&raw);
        assert_eq!(
            deserialized.unwrap(),
            MessageId {
                timestamp_ms: 1636634305271,
                sequence: 1
            }
        );

        let raw = Value::Data(b"18446744073709551615-99".to_vec());
        let deserialized = MessageId::from_redis_value(&raw);
        assert_eq!(
            deserialized.unwrap(),
            MessageId {
                timestamp_ms: 18446744073709551615,
                sequence: 99
            }
        );

        // invalid
        let raw = Value::Data(b"123".to_vec());
        let deserialized = MessageId::from_redis_value(&raw);
        assert!(deserialized.is_err());
        let raw = Value::Data(b"123-a".to_vec());
        let deserialized = MessageId::from_redis_value(&raw);
        assert!(deserialized.is_err());
        let raw = Value::Data(b"a23-0".to_vec());
        let deserialized = MessageId::from_redis_value(&raw);
        assert!(deserialized.is_err());
    }

    #[test]
    fn stream_read_reply_deser_nil() {
        let nil = Value::Nil;
        let deserialized =
            StreamReadReply::<String, Vec<(String, String)>>::from_redis_value(&nil).unwrap();
        assert!(deserialized.0.is_empty());
    }

    #[test]
    fn stream_read_reply_deser_one() {
        let raw = Value::Bulk(vec![Value::Bulk(vec![
            Value::Data(b"stream-id".to_vec()),
            Value::Bulk(vec![Value::Bulk(vec![
                Value::Data(b"1000-2".to_vec()),
                Value::Bulk(vec![
                    Value::Data(b"key".to_vec()),
                    Value::Data(b"value".to_vec()),
                    Value::Data(b"key2".to_vec()),
                    Value::Data(b"value".to_vec()),
                ]),
            ])]),
        ])]);
        let deserialized =
            StreamReadReply::<String, HashMap<String, String>>::from_redis_value(&raw).unwrap();
        assert_eq!(deserialized.0.len(), 1);
        assert_eq!(deserialized.0[0].id, "stream-id".to_string());
        assert_eq!(deserialized.0[0].items.len(), 1);
        assert_eq!(
            deserialized.0[0].items[0].offset,
            MessageId {
                timestamp_ms: 1000,
                sequence: 2
            }
        );
        assert_eq!(
            deserialized.0[0].items[0].payload,
            Ok(HashMap::from([
                ("key".to_string(), "value".to_string()),
                ("key2".to_string(), "value".to_string())
            ]))
        );
    }

    #[test]
    fn stream_read_reply_deser_many() {
        let raw = Value::Bulk(vec![
            Value::Bulk(vec![
                Value::Data(b"stream-0".to_vec()),
                Value::Bulk(vec![
                    Value::Bulk(vec![
                        Value::Data(b"0-0".to_vec()),
                        Value::Bulk(vec![
                            Value::Data(b"key0-0".to_vec()),
                            Value::Data(b"value0-0".to_vec()),
                            Value::Data(b"key0-1".to_vec()),
                            Value::Data(b"value0-1".to_vec()),
                        ]),
                    ]),
                    Value::Bulk(vec![
                        Value::Data(b"0-1".to_vec()),
                        Value::Bulk(vec![
                            Value::Data(b"key0-0".to_vec()),
                            Value::Data(b"value0-0".to_vec()),
                        ]),
                    ]),
                ]),
            ]),
            Value::Bulk(vec![
                Value::Data(b"stream-1".to_vec()),
                Value::Bulk(vec![Value::Bulk(vec![
                    Value::Data(b"1-0".to_vec()),
                    Value::Bulk(vec![]),
                ])]),
            ]),
            Value::Bulk(vec![
                Value::Data(b"stream-2".to_vec()),
                Value::Bulk(vec![Value::Bulk(vec![
                    Value::Data(b"2-0".to_vec()),
                    Value::Bulk(vec![
                        Value::Data(b"key2-0".to_vec()),
                        Value::Data(b"value2-0".to_vec()),
                        Value::Data(b"key2-1".to_vec()),
                        Value::Data(b"value2-1".to_vec()),
                        Value::Data(b"key2-2".to_vec()),
                        Value::Data(b"value2-2".to_vec()),
                    ]),
                ])]),
            ]),
        ]);

        let deserialized =
            StreamReadReply::<String, HashMap<String, String>>::from_redis_value(&raw).unwrap();
        assert_eq!(deserialized.0.len(), 3);
        for (stream_idx, stream_reply) in deserialized.0.into_iter().enumerate() {
            assert_eq!(stream_reply.id, format!("stream-{}", stream_idx));
            for (item_idx, item) in stream_reply.items.into_iter().enumerate() {
                assert_eq!(
                    item.offset,
                    MessageId {
                        timestamp_ms: stream_idx as u128,
                        sequence: item_idx as u64
                    }
                );
                let payload = item.payload.unwrap();
                // Every payload holds `key{stream}-{i}` -> `value{stream}-{i}` for i in 0..len.
                let expected: HashMap<String, String> = (0..payload.len())
                    .map(|idx| {
                        (
                            format!("key{}-{}", stream_idx, idx),
                            format!("value{}-{}", stream_idx, idx),
                        )
                    })
                    .collect();
                assert_eq!(payload, expected);
            }
        }
    }
}