1use displaydoc::Display;
2use redis::{FromRedisValue, RedisError, RedisResult, RedisWrite, ToRedisArgs, Value};
3use std::{
4 cmp::Ord,
5 fmt,
6 hash::Hash,
7 num::ParseIntError,
8 str::FromStr,
9 time::{Duration, SystemTime},
10};
11use strum::IntoStaticStr;
12use thiserror::Error;
13
14#[derive(Debug, Display, Error)]
15pub enum Error {
16 InvalidMessageId(ParseIntError, ParseIntError, String),
18 InvalidMessageIdTimestamp(ParseIntError, String),
20 InvalidMessageIdSequence(ParseIntError, String),
22 MalformedMessageId(String),
24}
25
26#[derive(Debug, Clone)]
27pub struct ReadStream<'s, S>
28where
29 &'s S: ToRedisArgs,
30{
31 pub id: &'s S,
32 pub offset: MessageId,
33}
34
35#[derive(Debug, Copy, Clone, Default)]
37pub struct WriteStreamOptions {
38 pub disable_create: bool,
40 pub capacity: Option<(Trim, TrimPrecision)>,
42}
43
/// Options for reading from streams; both are omitted from the command when `None`.
#[derive(Clone, Copy, Debug, Default)]
pub struct ReadStreamOptions {
    /// If set, emitted as `BLOCK <milliseconds>` (sub-millisecond precision is dropped).
    pub block: Option<Duration>,
    /// If set, emitted as `COUNT <n>`.
    pub count: Option<u64>,
}
52
53#[derive(Debug, Copy, Clone)]
54pub enum Trim {
55 Length(u64),
56 Id(MessageId),
57}
58
59#[derive(Debug, Copy, Clone, IntoStaticStr)]
60pub enum TrimPrecision {
61 #[strum(to_string = "=")]
62 Exact,
63 #[strum(to_string = "~")]
64 Approximate,
65}
66
67impl ToRedisArgs for &WriteStreamOptions {
68 fn write_redis_args<W>(&self, out: &mut W)
69 where
70 W: ?Sized + RedisWrite,
71 {
72 if self.disable_create {
73 out.write_arg(b"NOMKSTREAM")
74 }
75 match self.capacity {
76 Some((Trim::Length(len), precision)) => {
77 let precision: &'static str = precision.into();
78 out.write_arg(b"MAXLEN");
79 out.write_arg(precision.as_bytes());
80 out.write_arg(len.to_string().as_bytes());
81 }
82 Some((Trim::Id(id), precision)) => {
83 let precision: &'static str = precision.into();
84 out.write_arg(b"MINID");
85 out.write_arg(precision.as_bytes());
86 out.write_arg(id.to_string().as_bytes());
87 }
88 None => {}
89 }
90 }
91}
92
93impl ToRedisArgs for &ReadStreamOptions {
94 fn write_redis_args<W>(&self, out: &mut W)
95 where
96 W: ?Sized + RedisWrite,
97 {
98 if let Some(count) = self.count {
99 out.write_arg(b"COUNT");
100 out.write_arg(count.to_string().as_bytes());
101 }
102 if let Some(block) = self.block {
103 out.write_arg(b"BLOCK");
104 out.write_arg(block.as_millis().to_string().as_bytes());
105 }
106 }
107}
108
109#[derive(Debug)]
110pub struct StreamReadReply<S: FromRedisValue, T: FromRedisValue>(pub Vec<StreamItems<S, T>>);
111
112#[derive(Debug)]
113pub struct StreamItems<S: FromRedisValue, T: FromRedisValue> {
114 pub id: S,
115 pub items: Vec<StreamItem<T>>,
116}
117
118#[derive(Debug)]
119pub struct StreamItem<T: FromRedisValue> {
120 pub offset: MessageId,
121 pub payload: Result<T, RedisError>,
122}
123
/// Identifier of a stream entry: a millisecond timestamp plus a sequence number.
///
/// Field order matters. The derived `PartialOrd`/`Ord` compare `timestamp_ms` first and only
/// fall back to `sequence` on a tie.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct MessageId {
    /// Millisecond component (the part before the `-` in the text form).
    pub timestamp_ms: u128,
    /// Sequence component (the part after the `-` in the text form).
    pub sequence: u64,
}

impl MessageId {
    /// Builds an id from the current wall-clock time, with `sequence` set to `0`.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time earlier than the Unix epoch.
    pub fn now() -> Self {
        let since_epoch = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .expect("System time before Unix epoch");
        let timestamp_ms = since_epoch.as_millis();
        MessageId {
            timestamp_ms,
            sequence: 0,
        }
    }
}
141
142impl FromStr for MessageId {
143 type Err = Error;
144
145 fn from_str(s: &str) -> Result<Self, Self::Err> {
146 let parts = s.split('-').collect::<Vec<&str>>();
147 if parts.len() != 2 {
148 return Err(Error::MalformedMessageId(s.to_string()));
149 }
150
151 match (parts[0].parse::<u128>(), parts[1].parse::<u64>()) {
152 (Ok(timestamp_ms), Ok(sequence)) => Ok(MessageId {
153 timestamp_ms,
154 sequence,
155 }),
156 (Err(ts_err), Ok(_)) => Err(Error::InvalidMessageIdTimestamp(ts_err, s.to_string())),
157 (Ok(_), Err(sequence_err)) => {
158 Err(Error::InvalidMessageIdSequence(sequence_err, s.to_string()))
159 }
160 (Err(ts_err), Err(sequence_err)) => {
161 Err(Error::InvalidMessageId(ts_err, sequence_err, s.to_string()))
162 }
163 }
164 }
165}
166
167impl fmt::Display for MessageId {
168 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169 write!(f, "{}-{}", self.timestamp_ms, self.sequence)
170 }
171}
172
173impl FromRedisValue for MessageId {
174 fn from_redis_value(v: &redis::Value) -> redis::RedisResult<Self> {
175 String::from_redis_value(v)?
176 .parse::<MessageId>()
177 .map_err(|e| {
178 redis::RedisError::from((
179 redis::ErrorKind::TypeError,
180 "invalid value for MessageId",
181 e.to_string(),
182 ))
183 })
184 }
185}
186
187impl ToRedisArgs for MessageId {
188 fn write_redis_args<W>(&self, out: &mut W)
189 where
190 W: ?Sized + RedisWrite,
191 {
192 out.write_arg(self.to_string().as_bytes())
193 }
194}
195
196impl<S, T> FromRedisValue for StreamReadReply<S, T>
197where
198 S: FromRedisValue + Eq + Hash,
199 T: FromRedisValue,
200{
201 fn from_redis_value(v: &Value) -> RedisResult<Self> {
202 match v {
203 Value::Bulk(bulk) => {
204 let stream_reply = bulk
205 .iter()
206 .map(|streams| match streams {
207 redis::Value::Bulk(bulk) => {
208 let mut iter = bulk.iter();
209 let id = get_next_value::<_, S>(
210 &mut iter,
211 "missing key `id` in `StreamItems` response",
212 )?;
213 let items = get_next_value::<_, Vec<StreamItem<T>>>(
214 &mut iter,
215 "missing key `items` in `StreamItems` response",
216 )?;
217 Ok(StreamItems { id, items })
218 }
219 _ => Err(redis::RedisError::from((
220 redis::ErrorKind::TypeError,
221 "expecting Value::Bulk for `StreamItems` response",
222 ))),
223 })
224 .collect::<Result<Vec<_>, _>>()?;
225
226 Ok(Self(stream_reply))
227 }
228 Value::Nil => Ok(Self(vec![])),
229 _ => Err(redis::RedisError::from((
230 redis::ErrorKind::TypeError,
231 "expecting Value::Bulk for `StreamReadReply` response",
232 ))),
233 }
234 }
235}
236
237impl<T> FromRedisValue for StreamItem<T>
238where
239 T: FromRedisValue,
240{
241 fn from_redis_value(v: &Value) -> RedisResult<Self> {
242 match v {
243 redis::Value::Bulk(bulk) => {
244 let mut iter = bulk.iter();
245 let offset = get_next_value::<_, MessageId>(
246 &mut iter,
247 "missing key `offset` in `StreamItem` response",
248 )?;
249 let payload = get_next_value::<_, T>(
250 &mut iter,
251 "missing key `payload` in `StreamItem` response",
252 );
253
254 Ok(StreamItem { offset, payload })
255 }
256 _ => Err(redis::RedisError::from((
257 redis::ErrorKind::TypeError,
258 "expecting Value::Bulk for `StreamItem` response",
259 ))),
260 }
261 }
262}
263
264fn get_next_value<'v, I, T>(iter: &mut I, err_msg: &'static str) -> Result<T, redis::RedisError>
265where
266 I: Iterator<Item = &'v Value>,
267 T: FromRedisValue,
268{
269 iter.next()
270 .ok_or_else(|| redis::RedisError::from((redis::ErrorKind::TypeError, err_msg)))
271 .and_then(T::from_redis_value)
272}
273
#[cfg(test)]
mod tests {
    use super::{
        MessageId, ReadStreamOptions, StreamReadReply, Trim, TrimPrecision, WriteStreamOptions,
    };
    use redis::{FromRedisValue, ToRedisArgs, Value};
    use std::{collections::HashMap, str, time::Duration};

    #[test]
    fn stream_write_message_id_to_args() {
        let bytes = MessageId {
            timestamp_ms: 111,
            sequence: 22,
        }
        .to_redis_args();
        assert_eq!(bytes.len(), 1);
        assert_eq!(str::from_utf8(bytes[0].as_slice()).unwrap(), "111-22");
    }

    #[test]
    fn stream_write_options() {
        let bytes = (&WriteStreamOptions::default()).to_redis_args();
        assert_eq!(bytes.len(), 0);
        let bytes = (&WriteStreamOptions {
            disable_create: true,
            capacity: Some((Trim::Length(10), TrimPrecision::Approximate)),
        })
            .to_redis_args();
        assert_eq!(bytes.len(), 4);
        assert_eq!(str::from_utf8(bytes[0].as_slice()).unwrap(), "NOMKSTREAM");
        assert_eq!(str::from_utf8(bytes[1].as_slice()).unwrap(), "MAXLEN");
        assert_eq!(str::from_utf8(bytes[2].as_slice()).unwrap(), "~");
        assert_eq!(str::from_utf8(bytes[3].as_slice()).unwrap(), "10");
    }

    /// Covers the `MINID` trimming branch and the exact (`=`) precision marker.
    #[test]
    fn stream_write_options_min_id() {
        let bytes = (&WriteStreamOptions {
            disable_create: false,
            capacity: Some((
                Trim::Id(MessageId {
                    timestamp_ms: 5,
                    sequence: 1,
                }),
                TrimPrecision::Exact,
            )),
        })
            .to_redis_args();
        assert_eq!(bytes.len(), 3);
        assert_eq!(str::from_utf8(bytes[0].as_slice()).unwrap(), "MINID");
        assert_eq!(str::from_utf8(bytes[1].as_slice()).unwrap(), "=");
        assert_eq!(str::from_utf8(bytes[2].as_slice()).unwrap(), "5-1");
    }

    #[test]
    fn stream_read_options() {
        let bytes = (&ReadStreamOptions::default()).to_redis_args();
        assert_eq!(bytes.len(), 0);
        let bytes = (&ReadStreamOptions {
            block: Some(Duration::from_secs(1)),
            count: Some(50),
        })
            .to_redis_args();
        assert_eq!(bytes.len(), 4);
        assert_eq!(str::from_utf8(bytes[0].as_slice()).unwrap(), "COUNT");
        assert_eq!(str::from_utf8(bytes[1].as_slice()).unwrap(), "50");
        assert_eq!(str::from_utf8(bytes[2].as_slice()).unwrap(), "BLOCK");
        assert_eq!(str::from_utf8(bytes[3].as_slice()).unwrap(), "1000");
    }

    #[test]
    fn stream_message_id_ordering() {
        let msg = MessageId {
            timestamp_ms: 5,
            sequence: 0,
        };

        assert!(
            msg > MessageId {
                timestamp_ms: 4,
                sequence: 0
            }
        );
        assert!(
            msg == MessageId {
                timestamp_ms: 5,
                sequence: 0
            }
        );
        assert!(
            msg < MessageId {
                timestamp_ms: 5,
                sequence: 1
            }
        );
        assert!(
            msg < MessageId {
                timestamp_ms: 6,
                sequence: 0
            }
        );
    }

    #[test]
    fn stream_message_id_display() {
        assert_eq!(
            MessageId {
                timestamp_ms: 1000,
                sequence: 3
            }
            .to_string(),
            "1000-3".to_string()
        )
    }

    /// `Display` and `FromStr` are inverses; ids without exactly one `-` are rejected.
    #[test]
    fn stream_message_id_roundtrip() {
        let id = MessageId {
            timestamp_ms: 1636634305271,
            sequence: 7,
        };
        assert_eq!(id.to_string().parse::<MessageId>().unwrap(), id);
        assert!("1-2-3".parse::<MessageId>().is_err());
        assert!("".parse::<MessageId>().is_err());
        assert!("-".parse::<MessageId>().is_err());
    }

    #[test]
    fn stream_message_id_deser() {
        let raw = Value::Data(b"1-0".to_vec());
        let deserialized = MessageId::from_redis_value(&raw);
        assert_eq!(
            deserialized.unwrap(),
            MessageId {
                timestamp_ms: 1,
                sequence: 0
            }
        );
        let raw = Value::Data(b"1636634305271-1".to_vec());
        let deserialized = MessageId::from_redis_value(&raw);
        assert_eq!(
            deserialized.unwrap(),
            MessageId {
                timestamp_ms: 1636634305271,
                sequence: 1
            }
        );

        let raw = Value::Data(b"18446744073709551615-99".to_vec());
        let deserialized = MessageId::from_redis_value(&raw);
        assert_eq!(
            deserialized.unwrap(),
            MessageId {
                timestamp_ms: 18446744073709551615,
                sequence: 99
            }
        );

        // Missing sequence, non-numeric sequence and non-numeric timestamp are all rejected.
        let raw = Value::Data(b"123".to_vec());
        let deserialized = MessageId::from_redis_value(&raw);
        assert!(deserialized.is_err());
        let raw = Value::Data(b"123-a".to_vec());
        let deserialized = MessageId::from_redis_value(&raw);
        assert!(deserialized.is_err());
        let raw = Value::Data(b"a23-0".to_vec());
        let deserialized = MessageId::from_redis_value(&raw);
        assert!(deserialized.is_err());
    }

    #[test]
    fn stream_read_reply_deser_nil() {
        let nil = Value::Nil;
        let deserialized =
            StreamReadReply::<String, Vec<(String, String)>>::from_redis_value(&nil).unwrap();
        assert!(deserialized.0.is_empty());
    }

    #[test]
    fn stream_read_reply_deser_wrong_type() {
        let raw = Value::Int(1);
        assert!(
            StreamReadReply::<String, HashMap<String, String>>::from_redis_value(&raw).is_err()
        );
    }

    #[test]
    fn stream_read_reply_deser_one() {
        let raw = Value::Bulk(vec![Value::Bulk(vec![
            Value::Data(b"stream-id".to_vec()),
            Value::Bulk(vec![Value::Bulk(vec![
                Value::Data(b"1000-2".to_vec()),
                Value::Bulk(vec![
                    Value::Data(b"key".to_vec()),
                    Value::Data(b"value".to_vec()),
                    Value::Data(b"key2".to_vec()),
                    Value::Data(b"value".to_vec()),
                ]),
            ])]),
        ])]);
        let deserialized =
            StreamReadReply::<String, HashMap<String, String>>::from_redis_value(&raw).unwrap();
        assert_eq!(deserialized.0.len(), 1);
        assert_eq!(deserialized.0[0].id, "stream-id".to_string());
        assert_eq!(deserialized.0[0].items.len(), 1);
        assert_eq!(
            deserialized.0[0].items[0].offset,
            MessageId {
                timestamp_ms: 1000,
                sequence: 2
            }
        );
        assert_eq!(
            deserialized.0[0].items[0].payload,
            Ok(HashMap::from([
                ("key".to_string(), "value".to_string()),
                ("key2".to_string(), "value".to_string())
            ]))
        );
    }

    /// An entry without a field list keeps its offset; only its payload carries the error.
    #[test]
    fn stream_read_reply_deser_missing_payload() {
        let raw = Value::Bulk(vec![Value::Bulk(vec![
            Value::Data(b"stream-id".to_vec()),
            Value::Bulk(vec![Value::Bulk(vec![Value::Data(b"3-4".to_vec())])]),
        ])]);
        let deserialized =
            StreamReadReply::<String, HashMap<String, String>>::from_redis_value(&raw).unwrap();
        assert_eq!(deserialized.0.len(), 1);
        assert_eq!(deserialized.0[0].items.len(), 1);
        assert_eq!(
            deserialized.0[0].items[0].offset,
            MessageId {
                timestamp_ms: 3,
                sequence: 4
            }
        );
        assert!(deserialized.0[0].items[0].payload.is_err());
    }

    #[test]
    fn stream_read_reply_deser_many() {
        let raw = Value::Bulk(vec![
            Value::Bulk(vec![
                Value::Data(b"stream-0".to_vec()),
                Value::Bulk(vec![
                    Value::Bulk(vec![
                        Value::Data(b"0-0".to_vec()),
                        Value::Bulk(vec![
                            Value::Data(b"key0-0".to_vec()),
                            Value::Data(b"value0-0".to_vec()),
                            Value::Data(b"key0-1".to_vec()),
                            Value::Data(b"value0-1".to_vec()),
                        ]),
                    ]),
                    Value::Bulk(vec![
                        Value::Data(b"0-1".to_vec()),
                        Value::Bulk(vec![
                            Value::Data(b"key0-0".to_vec()),
                            Value::Data(b"value0-0".to_vec()),
                        ]),
                    ]),
                ]),
            ]),
            Value::Bulk(vec![
                Value::Data(b"stream-1".to_vec()),
                Value::Bulk(vec![Value::Bulk(vec![
                    Value::Data(b"1-0".to_vec()),
                    Value::Bulk(vec![]),
                ])]),
            ]),
            Value::Bulk(vec![
                Value::Data(b"stream-2".to_vec()),
                Value::Bulk(vec![Value::Bulk(vec![
                    Value::Data(b"2-0".to_vec()),
                    Value::Bulk(vec![
                        Value::Data(b"key2-0".to_vec()),
                        Value::Data(b"value2-0".to_vec()),
                        Value::Data(b"key2-1".to_vec()),
                        Value::Data(b"value2-1".to_vec()),
                        Value::Data(b"key2-2".to_vec()),
                        Value::Data(b"value2-2".to_vec()),
                    ]),
                ])]),
            ]),
        ]);

        let deserialized =
            StreamReadReply::<String, HashMap<String, String>>::from_redis_value(&raw).unwrap();
        assert_eq!(deserialized.0.len(), 3);
        for (stream_idx, stream_reply) in deserialized.0.into_iter().enumerate() {
            assert_eq!(stream_reply.id, format!("stream-{}", stream_idx));
            for (item_idx, item) in stream_reply.items.into_iter().enumerate() {
                assert_eq!(
                    item.offset,
                    MessageId {
                        timestamp_ms: stream_idx as u128,
                        sequence: item_idx as u64
                    }
                );
                let payload = item.payload.unwrap();
                // Each payload holds `key{stream}-{i}` -> `value{stream}-{i}` for i in 0..len.
                let expected: HashMap<String, String> = (0..payload.len())
                    .map(|idx| {
                        (
                            format!("key{}-{}", stream_idx, idx),
                            format!("value{}-{}", stream_idx, idx),
                        )
                    })
                    .collect();
                assert_eq!(payload, expected);
            }
        }
    }
}