redis_subscribe/
redis_sub.rs

1use std::cmp;
2use std::collections::HashSet;
3use std::time::Duration;
4
5use async_stream::stream;
6use rand::{thread_rng, Rng};
7use tokio::{
8    io::{AsyncReadExt, AsyncWriteExt},
9    net::{
10        tcp::{OwnedReadHalf, OwnedWriteHalf},
11        TcpStream,
12    },
13    sync::Mutex,
14    time::sleep,
15};
16use tokio_stream::Stream;
17
18use crate::{parser, Command, Message};
19
20/// Redis subscription object.
21/// This connects to the Redis server.
22#[derive(Debug)]
23pub struct RedisSub {
24    /// Address of the redis server.
25    addr: String,
26    /// Set of channels currently subscribed to.
27    channels: Mutex<HashSet<String>>,
28    /// Set of channels currently subscribed to by pattern.
29    pattern_channels: Mutex<HashSet<String>>,
30    /// TCP socket writer to write commands to.
31    writer: Mutex<Option<OwnedWriteHalf>>,
32}
33
34impl RedisSub {
35    /// Create the new Redis client.
36    /// This does not connect to the server, use `.listen()` for that.
37    #[must_use]
38    pub fn new(addr: &str) -> Self {
39        Self {
40            addr: addr.to_string(),
41            channels: Mutex::new(HashSet::new()),
42            pattern_channels: Mutex::new(HashSet::new()),
43            writer: Mutex::new(None),
44        }
45    }
46
47    /// Subscribe to a channel.
48    ///
49    /// # Errors
50    /// Returns an error if an error happens on the underlying TCP stream.
51    pub async fn subscribe(&self, channel: String) -> crate::Result<()> {
52        self.channels.lock().await.insert(channel.clone());
53
54        self.send_cmd(Command::Subscribe(channel)).await
55    }
56
57    /// Unsubscribe from a channel.
58    ///
59    /// # Errors
60    /// Returns an error if an error happens on the underlying TCP stream.
61    pub async fn unsubscribe(&self, channel: String) -> crate::Result<()> {
62        if !self.channels.lock().await.remove(&channel) {
63            return Err(crate::Error::NotSubscribed);
64        }
65
66        self.send_cmd(Command::Unsubscribe(channel)).await
67    }
68
69    /// Subscribe to a pattern of channels.
70    ///
71    /// # Errors
72    /// Returns an error if an error happens on the underlying TCP stream.
73    pub async fn psubscribe(&self, channel: String) -> crate::Result<()> {
74        self.pattern_channels.lock().await.insert(channel.clone());
75
76        self.send_cmd(Command::PatternSubscribe(channel)).await
77    }
78
79    /// Unsubscribe from a pattern of channels.
80    ///
81    /// # Errors
82    /// Returns an error if an error happens on the underlying TCP stream.
83    pub async fn punsubscribe(&self, channel: String) -> crate::Result<()> {
84        if !self.pattern_channels.lock().await.remove(&channel) {
85            return Err(crate::Error::NotSubscribed);
86        }
87
88        self.send_cmd(Command::PatternUnsubscribe(channel)).await
89    }
90
91    /// Connect to the Redis server specified by `self.addr`.
92    ///
93    /// Handles exponential backoff.
94    ///
95    /// Returns a split TCP stream.
96    ///
97    /// # Errors
98    /// Returns an error if attempting connection failed eight times.
99    pub(crate) async fn connect(
100        &self,
101        fail_fast: bool,
102    ) -> crate::Result<(OwnedReadHalf, OwnedWriteHalf)> {
103        let mut retry_count = 0;
104
105        loop {
106            // Generate jitter for the backoff function.
107            let jitter = thread_rng().gen_range(0..1000);
108            // Connect to the Redis server.
109            match TcpStream::connect(self.addr.as_str()).await {
110                Ok(stream) => return Ok(stream.into_split()),
111                Err(e) if fail_fast => return Err(crate::Error::IoError(e)),
112                Err(e) if retry_count <= 7 => {
113                    // Backoff and reconnect.
114                    warn!(
115                        "failed to connect to redis (attempt {}/8) {:?}",
116                        retry_count, e
117                    );
118                    retry_count += 1;
119                    let timeout = cmp::min(retry_count ^ 2, 64) * 1000 + jitter;
120                    sleep(Duration::from_millis(timeout)).await;
121                    continue;
122                }
123                Err(e) => {
124                    // Retry count has passed 7.
125                    // Assume connection failed and return.
126                    return Err(crate::Error::IoError(e));
127                }
128            };
129        }
130    }
131
132    async fn subscribe_stored(&self) -> crate::Result<()> {
133        for channel in self.channels.lock().await.iter() {
134            self.send_cmd(Command::Subscribe(channel.to_string()))
135                .await?;
136        }
137
138        for channel in self.pattern_channels.lock().await.iter() {
139            self.send_cmd(Command::PatternSubscribe(channel.to_string()))
140                .await?;
141        }
142
143        Ok(())
144    }
145
146    /// Listen for incoming messages.
147    /// Only here the server connects to the Redis server.
148    /// It handles reconnection and backoff for you.
149    ///
150    /// # Errors
151    /// Returns an error if the first connection attempt fails
152    pub async fn listen(&self) -> crate::Result<impl Stream<Item = Message> + '_> {
153        self.connect(true).await?;
154
155        Ok(Box::pin(stream! {
156            loop {
157                let (mut read, write) = match self.connect(false).await {
158                    Ok(t) => t,
159                    Err(e) => {
160                        warn!("failed to connect to server: {:?}", e);
161                        continue;
162                    }
163                };
164
165                // Update the stored writer.
166                {
167                    debug!("updating stored Redis TCP writer");
168                    let mut stored_writer = self.writer.lock().await;
169                    *stored_writer = Some(write);
170                }
171
172                // Subscribe to all stored channels
173                debug!("subscribing to stored channels after connect");
174                if let Err(e) = self.subscribe_stored().await {
175                    warn!("failed to subscribe to stored channels on connection, trying connection again... (err {:?})", e);
176                    continue;
177                }
178
179                // Yield a connect message to the library consumer.
180                yield Message::Connected;
181
182                // Create the read buffers.
183                let mut buf = [0; 64 * 1024];
184                let mut unread_buf = String::new();
185
186                'inner: loop {
187                    debug!("reading incoming data");
188                    // Read incoming data to the buffer.
189                    let res = match read.read(&mut buf).await {
190                        Ok(0) => Err(crate::Error::ZeroBytesRead),
191                        Ok(n) => Ok(n),
192                        Err(e) => Err(crate::Error::from(e)),
193                    };
194
195                    // Disconnect and reconnect if a write error occurred.
196                    let n = match res {
197                        Ok(n) => n,
198                        Err(e) => {
199                            *self.writer.lock().await = None;
200                            yield Message::Disconnected(e);
201                            break 'inner;
202                        }
203                    };
204
205                    let buf_data = match std::str::from_utf8(&buf[..n]) {
206                        Ok(d) => d,
207                        Err(e) => {
208                            yield Message::Error(e.into());
209                            continue;
210                        }
211                    };
212
213                    // Add the new data to the unread buffer.
214                    unread_buf.push_str(buf_data);
215                    // Parse the unread data.
216                    let parsed = parser::parse(&mut unread_buf);
217
218                    // Loop through the parsed commands.
219                    for res in parsed {
220                        debug!("new message");
221                        // Create a message from the parsed command and yield it.
222                        match Message::from_response(res) {
223                            Ok(msg) => yield msg,
224                            Err(e) => {
225                                warn!("failed to parse message: {:?}", e);
226                                continue;
227                            },
228                        };
229                    }
230                }
231            }
232        }))
233    }
234
235    /// Send a command to the server.
236    async fn send_cmd(&self, command: Command) -> crate::Result<()> {
237        if let Some(writer) = &mut *self.writer.lock().await {
238            writer.writable().await?;
239
240            debug!("sending command {:?} to redis", &command);
241            writer.write_all(command.to_string().as_bytes()).await?;
242        }
243
244        Ok(())
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use redis::AsyncCommands;
    use tokio_stream::StreamExt;

    /// Build the fixtures shared by the tests: a `redis` crate client and
    /// connection (used to PUBLISH) and an unconnected [`RedisSub`].
    ///
    /// Requires a Redis server listening on `127.0.0.1:6379`.
    async fn get_redis_connections() -> (redis::Client, redis::aio::Connection, RedisSub) {
        let client =
            redis::Client::open("redis://127.0.0.1/").expect("failed to create Redis client");
        let connection = client
            .get_tokio_connection()
            .await
            .expect("failed to open Redis connection");
        let redis_sub = RedisSub::new("127.0.0.1:6379");
        (client, connection, redis_sub)
    }

    /// Wait at most `limit` for the next message on `stream`.
    ///
    /// Panics if the deadline passes or the stream ends, so a misbehaving
    /// stream fails the test instead of hanging it.
    async fn next_within<S>(stream: &mut S, limit: Duration) -> Message
    where
        S: Stream<Item = Message> + Unpin,
    {
        tokio::time::timeout(limit, stream.next())
            .await
            .unwrap_or_else(|_| panic!("timeout duration of {:?} was exceeded", limit))
            .expect("expected a Message")
    }

    #[tokio::test]
    async fn test_redis_sub() {
        let (_client, mut connection, redis_sub) = get_redis_connections().await;

        // Not connected yet: this only records the channel; the SUBSCRIBE is
        // sent once `listen` connects.
        redis_sub
            .subscribe("1234".to_string())
            .await
            .expect("failed to subscribe to new Redis channel");
        let f = tokio::spawn(async move {
            {
                let mut stream = redis_sub
                    .listen()
                    .await
                    .expect("failed to connect to redis");

                let msg = next_within(&mut stream, Duration::from_millis(500)).await;
                assert!(
                    msg.is_connected(),
                    "message after opening stream was not `Connected`: {:?}",
                    msg
                );

                let msg = next_within(&mut stream, Duration::from_millis(500)).await;
                assert!(
                    msg.is_subscription(),
                    "message after connection was not `Subscription`: {:?}",
                    msg
                );

                let msg = next_within(&mut stream, Duration::from_secs(2)).await;
                assert!(
                    msg.is_message(),
                    "message after subscription was not `Message`: {:?}",
                    msg
                );
                match msg {
                    Message::Message { channel, message } => {
                        assert_eq!(channel, "1234".to_string());
                        assert_eq!(message, "1234".to_string());
                    }
                    _ => unreachable!("already checked this is message"),
                }
            }

            // Hand the subscriber back so the second phase can reuse its
            // recorded subscriptions.
            redis_sub
        });

        // Give the background task time to connect and send SUBSCRIBE before
        // publishing; a message published before the subscription takes effect
        // would not be delivered.
        tokio::time::sleep(Duration::from_millis(1100)).await;
        connection
            .publish::<&str, &str, u32>("1234", "1234")
            .await
            .expect("failed to send publish command to Redis");
        let redis_sub = f.await.expect("background future failed");

        // Listen again: the recorded channel is re-subscribed on connect.
        let mut stream = redis_sub
            .listen()
            .await
            .expect("failed to connect to redis");
        // Skip the `Connected` and `Subscription` messages of the new connection.
        let _ = next_within(&mut stream, Duration::from_millis(500)).await;
        let _ = next_within(&mut stream, Duration::from_millis(500)).await;
        redis_sub
            .unsubscribe("1234".to_string())
            .await
            .expect("failed to unsubscribe from Redis channel");
        let msg = next_within(&mut stream, Duration::from_millis(500)).await;
        assert!(
            msg.is_unsubscription(),
            "message after unsubscription was not `Unsubscription`: {:?}",
            msg
        )
    }

    #[tokio::test]
    async fn test_redis_pattern_sub() {
        let (_client, mut connection, redis_sub) = get_redis_connections().await;

        // Not connected yet: this only records the pattern; the PSUBSCRIBE is
        // sent once `listen` connects.
        redis_sub
            .psubscribe("*420*".to_string())
            .await
            .expect("failed to subscribe to new Redis channel");
        let f = tokio::spawn(async move {
            {
                let mut stream = redis_sub
                    .listen()
                    .await
                    .expect("failed to connect to redis");

                let msg = next_within(&mut stream, Duration::from_millis(500)).await;
                assert!(
                    msg.is_connected(),
                    "message after opening stream was not `Connected`: {:?}",
                    msg
                );

                let msg = next_within(&mut stream, Duration::from_millis(500)).await;
                assert!(
                    msg.is_pattern_subscription(),
                    "message after connection was not `PatternSubscription`: {:?}",
                    msg
                );

                let msg = next_within(&mut stream, Duration::from_secs(2)).await;
                assert!(
                    msg.is_pattern_message(),
                    "message after subscription was not `PatternMessage`: {:?}",
                    msg
                );
                match msg {
                    Message::PatternMessage {
                        pattern,
                        channel,
                        message,
                    } => {
                        assert_eq!(pattern, "*420*".to_string());
                        assert_eq!(channel, "64209".to_string());
                        assert_eq!(message, "123456".to_string());
                    }
                    _ => unreachable!("already checked this is message"),
                }
            }

            // Hand the subscriber back so the second phase can reuse its
            // recorded subscriptions.
            redis_sub
        });

        // Give the background task time to connect and send PSUBSCRIBE before
        // publishing; a message published before the subscription takes effect
        // would not be delivered.
        tokio::time::sleep(Duration::from_millis(1100)).await;
        connection
            .publish::<&str, &str, u32>("64209", "123456")
            .await
            .expect("failed to send publish command to Redis");
        let redis_sub = f.await.expect("background future failed");

        // Listen again: the recorded pattern is re-subscribed on connect.
        let mut stream = redis_sub
            .listen()
            .await
            .expect("failed to connect to redis");
        // Skip the `Connected` and `PatternSubscription` messages of the new connection.
        let _ = next_within(&mut stream, Duration::from_millis(500)).await;
        let _ = next_within(&mut stream, Duration::from_millis(500)).await;
        redis_sub
            .punsubscribe("*420*".to_string())
            .await
            .expect("failed to unsubscribe from Redis channel");
        let msg = next_within(&mut stream, Duration::from_millis(500)).await;
        assert!(
            msg.is_pattern_unsubscription(),
            "message after unsubscription was not `PatternUnsubscription`: {:?}",
            msg
        )
    }
}