redis_pubsub/
redis_sub.rs

1use std::cmp;
2use std::collections::HashSet;
3use std::time::Duration;
4
5use async_stream::stream;
6use rand::{thread_rng, Rng};
7use tokio::{
8    io::{AsyncReadExt, AsyncWriteExt},
9    net::{
10        tcp::{OwnedReadHalf, OwnedWriteHalf},
11        TcpStream,
12    },
13    sync::Mutex,
14    time::sleep,
15};
16use tokio_stream::Stream;
17
18use crate::{parser, Command, Message};
19
20/// Redis subscription object.
21/// This connects to the Redis server.
22#[derive(Debug)]
23pub struct RedisSub {
24    /// Address of the redis server.
25    addr: String,
26    /// Set of channels currently subscribed to.
27    channels: Mutex<HashSet<String>>,
28    /// Set of channels currently subscribed to by pattern.
29    pattern_channels: Mutex<HashSet<String>>,
30    /// TCP socket writer to write commands to.
31    writer: Mutex<Option<OwnedWriteHalf>>,
32}
33
34impl RedisSub {
35    /// Create the new Redis client.
36    /// This does not connect to the server, use `.listen()` for that.
37    #[must_use]
38    pub fn new(addr: &str) -> Self {
39        Self {
40            addr: addr.to_string(),
41            channels: Mutex::new(HashSet::new()),
42            pattern_channels: Mutex::new(HashSet::new()),
43            writer: Mutex::new(None),
44        }
45    }
46
47    /// Publish a message to a channel.
48    ///
49    /// # Errors
50    /// Returns an error if an error happens on the underlying TCP stream.
51    pub async fn publish(&self, channel: String, message: String) -> crate::Result<()> {
52        self.send_cmd(Command::Publish(channel, message)).await
53    }
54
55    /// Subscribe to a channel.
56    ///
57    /// # Errors
58    /// Returns an error if an error happens on the underlying TCP stream.
59    pub async fn subscribe(&self, channel: String) -> crate::Result<()> {
60        self.channels.lock().await.insert(channel.clone());
61
62        self.send_cmd(Command::Subscribe(channel)).await
63    }
64
65    /// Unsubscribe from a channel.
66    ///
67    /// # Errors
68    /// Returns an error if an error happens on the underlying TCP stream.
69    pub async fn unsubscribe(&self, channel: String) -> crate::Result<()> {
70        if !self.channels.lock().await.remove(&channel) {
71            return Err(crate::Error::NotSubscribed);
72        }
73
74        self.send_cmd(Command::Unsubscribe(channel)).await
75    }
76
77    /// Subscribe to a pattern of channels.
78    ///
79    /// # Errors
80    /// Returns an error if an error happens on the underlying TCP stream.
81    pub async fn psubscribe(&self, channel: String) -> crate::Result<()> {
82        self.pattern_channels.lock().await.insert(channel.clone());
83
84        self.send_cmd(Command::PatternSubscribe(channel)).await
85    }
86
87    /// Unsubscribe from a pattern of channels.
88    ///
89    /// # Errors
90    /// Returns an error if an error happens on the underlying TCP stream.
91    pub async fn punsubscribe(&self, channel: String) -> crate::Result<()> {
92        if !self.pattern_channels.lock().await.remove(&channel) {
93            return Err(crate::Error::NotSubscribed);
94        }
95
96        self.send_cmd(Command::PatternUnsubscribe(channel)).await
97    }
98
99    /// Connect to the Redis server specified by `self.addr`.
100    ///
101    /// Handles exponential backoff.
102    ///
103    /// Returns a split TCP stream.
104    ///
105    /// # Errors
106    /// Returns an error if attempting connection failed eight times.
107    pub(crate) async fn connect(
108        &self,
109        fail_fast: bool,
110    ) -> crate::Result<(OwnedReadHalf, OwnedWriteHalf)> {
111        let mut retry_count = 0;
112
113        loop {
114            // Generate jitter for the backoff function.
115            let jitter = thread_rng().gen_range(0..1000);
116            // Connect to the Redis server.
117            match TcpStream::connect(self.addr.as_str()).await {
118                Ok(stream) => return Ok(stream.into_split()),
119                Err(e) if fail_fast => return Err(crate::Error::IoError(e)),
120                Err(e) if retry_count <= 7 => {
121                    // Backoff and reconnect.
122                    warn!(
123                        "failed to connect to redis (attempt {}/8) {:?}",
124                        retry_count, e
125                    );
126                    retry_count += 1;
127                    let timeout = cmp::min(retry_count ^ 2, 64) * 1000 + jitter;
128                    sleep(Duration::from_millis(timeout)).await;
129                    continue;
130                }
131                Err(e) => {
132                    // Retry count has passed 7.
133                    // Assume connection failed and return.
134                    return Err(crate::Error::IoError(e));
135                }
136            };
137        }
138    }
139
140    async fn subscribe_stored(&self) -> crate::Result<()> {
141        for channel in self.channels.lock().await.iter() {
142            self.send_cmd(Command::Subscribe(channel.to_string()))
143                .await?;
144        }
145
146        for channel in self.pattern_channels.lock().await.iter() {
147            self.send_cmd(Command::PatternSubscribe(channel.to_string()))
148                .await?;
149        }
150
151        Ok(())
152    }
153
154    /// Listen for incoming messages.
155    /// Only here the server connects to the Redis server.
156    /// It handles reconnection and backoff for you.
157    ///
158    /// # Errors
159    /// Returns an error if the first connection attempt fails
160    pub async fn listen(&self) -> crate::Result<impl Stream<Item = Message> + '_> {
161        self.connect(true).await?;
162
163        Ok(Box::pin(stream! {
164            loop {
165                let (mut read, write) = match self.connect(false).await {
166                    Ok(t) => t,
167                    Err(e) => {
168                        warn!("failed to connect to server: {:?}", e);
169                        continue;
170                    }
171                };
172
173                // Update the stored writer.
174                {
175                    debug!("updating stored Redis TCP writer");
176                    let mut stored_writer = self.writer.lock().await;
177                    *stored_writer = Some(write);
178                }
179
180                // Subscribe to all stored channels
181                debug!("subscribing to stored channels after connect");
182                if let Err(e) = self.subscribe_stored().await {
183                    warn!("failed to subscribe to stored channels on connection, trying connection again... (err {:?})", e);
184                    continue;
185                }
186
187                // Yield a connect message to the library consumer.
188                yield Message::Connected;
189
190                // Create the read buffers.
191                let mut buf = [0; 64 * 1024];
192                let mut unread_buf = String::new();
193
194                'inner: loop {
195                    debug!("reading incoming data");
196                    // Read incoming data to the buffer.
197                    let res = match read.read(&mut buf).await {
198                        Ok(0) => Err(crate::Error::ZeroBytesRead),
199                        Ok(n) => Ok(n),
200                        Err(e) => Err(crate::Error::from(e)),
201                    };
202
203                    // Disconnect and reconnect if a write error occurred.
204                    let n = match res {
205                        Ok(n) => n,
206                        Err(e) => {
207                            *self.writer.lock().await = None;
208                            yield Message::Disconnected(e);
209                            break 'inner;
210                        }
211                    };
212
213                    let buf_data = match std::str::from_utf8(&buf[..n]) {
214                        Ok(d) => d,
215                        Err(e) => {
216                            yield Message::Error(e.into());
217                            continue;
218                        }
219                    };
220
221                    // Add the new data to the unread buffer.
222                    unread_buf.push_str(buf_data);
223                    // Parse the unread data.
224                    let parsed = parser::parse(&mut unread_buf);
225
226                    // Loop through the parsed commands.
227                    for res in parsed {
228                        debug!("new message");
229                        // Create a message from the parsed command and yield it.
230                        match Message::from_response(res) {
231                            Ok(msg) => yield msg,
232                            Err(e) => {
233                                warn!("failed to parse message: {:?}", e);
234                                continue;
235                            },
236                        };
237                    }
238                }
239            }
240        }))
241    }
242
243    /// Send a command to the server.
244    async fn send_cmd(&self, command: Command) -> crate::Result<()> {
245        if let Some(writer) = &mut *self.writer.lock().await {
246            writer.writable().await?;
247
248            debug!("sending command {:?} to redis", &command);
249            writer.write_all(command.to_string().as_bytes()).await?;
250        }
251
252        Ok(())
253    }
254}
255
#[cfg(test)]
mod tests {
    // Integration tests: these expect a Redis server listening on 127.0.0.1:6379.
    use super::*;
    use redis::AsyncCommands;
    use tokio_stream::StreamExt;

    /// Time allowed for `Connected` and (un)subscription confirmations.
    const SHORT_WAIT: Duration = Duration::from_millis(500);
    /// Time allowed for a published message to arrive.
    const LONG_WAIT: Duration = Duration::from_secs(2);

    /// Open a `redis`-crate connection (used to publish) and create an
    /// unconnected `RedisSub`, both for the server at 127.0.0.1:6379.
    async fn get_redis_connections() -> (redis::Client, redis::aio::Connection, RedisSub) {
        let client =
            redis::Client::open("redis://127.0.0.1/").expect("failed to create Redis client");
        let connection = client
            .get_tokio_connection()
            .await
            .expect("failed to open Redis connection");
        let redis_sub = RedisSub::new("127.0.0.1:6379");
        (client, connection, redis_sub)
    }

    /// Await the next item from `stream`.
    ///
    /// Panics if it does not arrive within `limit` or if the stream has ended,
    /// so a missing message fails the test instead of hanging it.
    async fn next_message<S>(stream: &mut S, limit: Duration) -> Message
    where
        S: Stream<Item = Message> + Unpin,
    {
        tokio::time::timeout(limit, stream.next())
            .await
            .unwrap_or_else(|_| panic!("timeout duration of {:?} was exceeded", limit))
            .expect("expected a Message")
    }

    #[tokio::test]
    async fn test_redis_sub() {
        let (_client, mut connection, redis_sub) = get_redis_connections().await;

        redis_sub
            .subscribe("1234".to_string())
            .await
            .expect("failed to subscribe to new Redis channel");
        let f = tokio::spawn(async move {
            {
                let mut stream = redis_sub
                    .listen()
                    .await
                    .expect("failed to connect to redis");

                let msg = next_message(&mut stream, SHORT_WAIT).await;
                assert!(
                    msg.is_connected(),
                    "message after opening stream was not `Connected`: {:?}",
                    msg
                );

                let msg = next_message(&mut stream, SHORT_WAIT).await;
                assert!(
                    msg.is_subscription(),
                    "message after connection was not `Subscription`: {:?}",
                    msg
                );

                let msg = next_message(&mut stream, LONG_WAIT).await;
                assert!(
                    msg.is_message(),
                    "message after subscription was not `Message`: {:?}",
                    msg
                );
                match msg {
                    Message::Message { channel, message } => {
                        assert_eq!(channel, "1234".to_string());
                        assert_eq!(message, "1234".to_string());
                    }
                    _ => unreachable!("already checked this is message"),
                }
            }

            redis_sub
        });

        // Give the spawned task time to connect and receive its subscription
        // confirmation (two waits of up to 500 ms each) before publishing.
        tokio::time::sleep(Duration::from_millis(1100)).await;
        connection
            .publish::<&str, &str, u32>("1234", "1234")
            .await
            .expect("failed to send publish command to Redis");
        let redis_sub = f.await.expect("background future failed");

        // Listening again replays the stored subscription.
        let mut stream = redis_sub
            .listen()
            .await
            .expect("failed to connect to redis");
        // Skip `Connected` and the replayed subscription confirmation.
        let _ = next_message(&mut stream, SHORT_WAIT).await;
        let _ = next_message(&mut stream, SHORT_WAIT).await;
        redis_sub
            .unsubscribe("1234".to_string())
            .await
            .expect("failed to unsubscribe from Redis channel");
        let msg = next_message(&mut stream, SHORT_WAIT).await;
        assert!(
            msg.is_unsubscription(),
            "message after unsubscription was not `Unsubscription`: {:?}",
            msg
        )
    }

    #[tokio::test]
    async fn test_redis_pattern_sub() {
        let (_client, mut connection, redis_sub) = get_redis_connections().await;

        redis_sub
            .psubscribe("*420*".to_string())
            .await
            .expect("failed to subscribe to new Redis pattern");
        let f = tokio::spawn(async move {
            {
                let mut stream = redis_sub
                    .listen()
                    .await
                    .expect("failed to connect to redis");

                let msg = next_message(&mut stream, SHORT_WAIT).await;
                assert!(
                    msg.is_connected(),
                    "message after opening stream was not `Connected`: {:?}",
                    msg
                );

                let msg = next_message(&mut stream, SHORT_WAIT).await;
                assert!(
                    msg.is_pattern_subscription(),
                    "message after connection was not `PatternSubscription`: {:?}",
                    msg
                );

                let msg = next_message(&mut stream, LONG_WAIT).await;
                assert!(
                    msg.is_pattern_message(),
                    "message after subscription was not `PatternMessage`: {:?}",
                    msg
                );
                match msg {
                    Message::PatternMessage {
                        pattern,
                        channel,
                        message,
                    } => {
                        assert_eq!(pattern, "*420*".to_string());
                        assert_eq!(channel, "64209".to_string());
                        assert_eq!(message, "123456".to_string());
                    }
                    _ => unreachable!("already checked this is message"),
                }
            }

            redis_sub
        });

        // Give the spawned task time to connect and receive its subscription
        // confirmation (two waits of up to 500 ms each) before publishing.
        tokio::time::sleep(Duration::from_millis(1100)).await;
        connection
            .publish::<&str, &str, u32>("64209", "123456")
            .await
            .expect("failed to send publish command to Redis");
        let redis_sub = f.await.expect("background future failed");

        // Listening again replays the stored pattern subscription.
        let mut stream = redis_sub
            .listen()
            .await
            .expect("failed to connect to redis");
        // Skip `Connected` and the replayed pattern-subscription confirmation.
        let _ = next_message(&mut stream, SHORT_WAIT).await;
        let _ = next_message(&mut stream, SHORT_WAIT).await;
        redis_sub
            .punsubscribe("*420*".to_string())
            .await
            .expect("failed to unsubscribe from Redis pattern");
        let msg = next_message(&mut stream, SHORT_WAIT).await;
        assert!(
            msg.is_pattern_unsubscription(),
            "message after unsubscription was not `PatternUnsubscription`: {:?}",
            msg
        )
    }
}