// redis_async/client/pubsub/mod.rs

/*
 * Copyright 2017-2025 Ben Ashford
 *
 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
 * http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
 * <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
 * option. This file may not be copied, modified, or distributed
 * except according to those terms.
 */
10
11mod inner;
12
13use std::future::Future;
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17use std::time::Duration;
18
19use futures_channel::{mpsc, oneshot};
20use futures_util::{
21    future::TryFutureExt,
22    stream::{Stream, StreamExt},
23};
24
25use super::{connect::connect_with_auth, ConnectionBuilder};
26
27use crate::{
28    error,
29    reconnect::{reconnect, Reconnect},
30    resp,
31};
32
33use self::inner::PubsubConnectionInner;
34
35#[derive(Debug)]
36pub(crate) enum PubsubEvent {
37    /// The: topic, sink to send messages through, and a oneshot to signal subscription has
38    /// occurred.
39    Subscribe(String, PubsubSink, oneshot::Sender<()>),
40    Psubscribe(String, PubsubSink, oneshot::Sender<()>),
41    /// The name of the topic to unsubscribe from. Unsubscription will be signaled by the stream
42    /// closing without error.
43    Unsubscribe(String),
44    Punsubscribe(String),
45}
46
47type PubsubStreamInner = mpsc::UnboundedReceiver<Result<resp::RespValue, error::Error>>;
48type PubsubSink = mpsc::UnboundedSender<Result<resp::RespValue, error::Error>>;
49
50/// A shareable reference to subscribe to PUBSUB topics
51#[derive(Debug, Clone)]
52pub struct PubsubConnection {
53    out_tx_c: Arc<Reconnect<PubsubEvent, mpsc::UnboundedSender<PubsubEvent>>>,
54}
55
56async fn inner_conn_fn(
57    // Needs to be a String for lifetime reasons
58    host: String,
59    port: u16,
60    username: Option<Arc<str>>,
61    password: Option<Arc<str>>,
62    tls: bool,
63    socket_keepalive: Option<Duration>,
64    socket_timeout: Option<Duration>,
65) -> Result<mpsc::UnboundedSender<PubsubEvent>, error::Error> {
66    let username = username.as_deref();
67    let password = password.as_deref();
68
69    let connection = connect_with_auth(
70        &host,
71        port,
72        username,
73        password,
74        tls,
75        socket_keepalive,
76        socket_timeout,
77    )
78    .await?;
79    let (out_tx, out_rx) = mpsc::unbounded();
80    tokio::spawn(async {
81        match PubsubConnectionInner::new(connection, out_rx).await {
82            Ok(_) => (),
83            Err(e) => log::error!("Pub/Sub error: {:?}", e),
84        }
85    });
86    Ok(out_tx)
87}
88
89impl ConnectionBuilder {
90    pub fn pubsub_connect(&self) -> impl Future<Output = Result<PubsubConnection, error::Error>> {
91        let username = self.username.clone();
92        let password = self.password.clone();
93
94        #[cfg(feature = "tls")]
95        let tls = self.tls;
96        #[cfg(not(feature = "tls"))]
97        let tls = false;
98
99        let host = self.host.clone();
100        let port = self.port;
101
102        let socket_keepalive = self.socket_keepalive;
103        let socket_timeout = self.socket_timeout;
104
105        let reconnecting_f = reconnect(
106            |con: &mpsc::UnboundedSender<PubsubEvent>, act| {
107                con.unbounded_send(act).map_err(|e| e.into())
108            },
109            move || {
110                let con_f = inner_conn_fn(
111                    host.clone(),
112                    port,
113                    username.clone(),
114                    password.clone(),
115                    tls,
116                    socket_keepalive,
117                    socket_timeout,
118                );
119                Box::pin(con_f)
120            },
121            self.reconnect_options,
122        );
123        reconnecting_f.map_ok(|con| PubsubConnection {
124            out_tx_c: Arc::new(con),
125        })
126    }
127}
128
129/// Used for Redis's PUBSUB functionality.
130///
131/// Returns a future that resolves to a `PubsubConnection`. The future will only resolve once the
132/// connection is established; after the intial establishment, if the connection drops for any
133/// reason (e.g. Redis server being restarted), the connection will attempt re-connect, however
134/// any subscriptions will need to be re-subscribed.
135pub async fn pubsub_connect(
136    host: impl Into<String>,
137    port: u16,
138) -> Result<PubsubConnection, error::Error> {
139    ConnectionBuilder::new(host, port)?.pubsub_connect().await
140}
141
142impl PubsubConnection {
143    /// Subscribes to a particular PUBSUB topic.
144    ///
145    /// Returns a future that resolves to a `Stream` that contains all the messages published on
146    /// that particular topic.
147    ///
148    /// The resolved stream will end with `redis_async::error::Error::EndOfStream` if the
149    /// underlying connection is lost for unexpected reasons. In this situation, clients should
150    /// `subscribe` to re-subscribe; the underlying connect will automatically reconnect. However,
151    /// clients should be aware that resubscriptions will only succeed if the underlying connection
152    /// has re-established, so multiple calls to `subscribe` may be required.
153    pub async fn subscribe(&self, topic: &str) -> Result<PubsubStream, error::Error> {
154        let (tx, rx) = mpsc::unbounded();
155        let (signal_t, signal_r) = oneshot::channel();
156        self.out_tx_c
157            .do_work(PubsubEvent::Subscribe(topic.to_owned(), tx, signal_t))?;
158
159        match signal_r.await {
160            Ok(_) => Ok(PubsubStream {
161                topic: topic.to_owned(),
162                underlying: rx,
163                con: self.clone(),
164                is_pattern: false,
165            }),
166            Err(_) => Err(error::internal("Subscription failed, try again later...")),
167        }
168    }
169
170    pub async fn psubscribe(&self, topic: &str) -> Result<PubsubStream, error::Error> {
171        let (tx, rx) = mpsc::unbounded();
172        let (signal_t, signal_r) = oneshot::channel();
173        self.out_tx_c
174            .do_work(PubsubEvent::Psubscribe(topic.to_owned(), tx, signal_t))?;
175
176        match signal_r.await {
177            Ok(_) => Ok(PubsubStream {
178                topic: topic.to_owned(),
179                underlying: rx,
180                con: self.clone(),
181                is_pattern: true,
182            }),
183            Err(_) => Err(error::internal("Subscription failed, try again later...")),
184        }
185    }
186
187    /// Tells the client to unsubscribe from a particular topic. This will return immediately, the
188    /// actual unsubscription will be confirmed when the stream returned from `subscribe` ends.
189    pub fn unsubscribe<T: Into<String>>(&self, topic: T) {
190        // Ignoring any results, as any errors communicating with Redis would de-facto unsubscribe
191        // anyway, and would be reported/logged elsewhere
192        let _ = self
193            .out_tx_c
194            .do_work(PubsubEvent::Unsubscribe(topic.into()));
195    }
196
197    pub fn punsubscribe<T: Into<String>>(&self, topic: T) {
198        // Ignoring any results, as any errors communicating with Redis would de-facto unsubscribe
199        // anyway, and would be reported/logged elsewhere
200        let _ = self
201            .out_tx_c
202            .do_work(PubsubEvent::Punsubscribe(topic.into()));
203    }
204}
205
206#[derive(Debug)]
207pub struct PubsubStream {
208    topic: String,
209    underlying: PubsubStreamInner,
210    con: PubsubConnection,
211    is_pattern: bool,
212}
213
214impl Stream for PubsubStream {
215    type Item = Result<resp::RespValue, error::Error>;
216
217    #[inline]
218    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
219        self.get_mut().underlying.poll_next_unpin(cx)
220    }
221}
222
223impl Drop for PubsubStream {
224    fn drop(&mut self) {
225        let topic: &str = self.topic.as_ref();
226        if self.is_pattern {
227            self.con.punsubscribe(topic);
228        } else {
229            self.con.unsubscribe(topic);
230        }
231    }
232}
233
#[cfg(test)]
mod test {
    use std::mem;
    use std::time::Duration;

    use futures::{try_join, StreamExt, TryStreamExt};
    use tokio::time::sleep;

    use crate::{client, resp};

    /* IMPORTANT: The tests run in parallel, so the topic names used must be exclusive to each test */
    static SUBSCRIBE_TEST_TOPIC: &str = "test-topic";
    static SUBSCRIBE_TEST_NON_TOPIC: &str = "test-not-topic";

    static UNSUBSCRIBE_TOPIC_1: &str = "test-topic-1";
    static UNSUBSCRIBE_TOPIC_2: &str = "test-topic-2";
    static UNSUBSCRIBE_TOPIC_3: &str = "test-topic-3";

    static RESUBSCRIBE_TOPIC: &str = "test-topic-resubscribe";

    static DROP_CONNECTION_TOPIC: &str = "test-topic-drop-connection";

    static PSUBSCRIBE_PATTERN: &str = "ptest.*";
    static PSUBSCRIBE_TOPIC_1: &str = "ptest.1";
    static PSUBSCRIBE_TOPIC_2: &str = "ptest.2";
    static PSUBSCRIBE_TOPIC_3: &str = "ptest.3";

    static UNSUBSCRIBE_TWICE_TOPIC_1: &str = "test-topic-1-twice";
    static UNSUBSCRIBE_TWICE_TOPIC_2: &str = "test-topic-2-twice";

    #[tokio::test]
    async fn subscribe_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let topic_messages = pubsub
            .subscribe(SUBSCRIBE_TEST_TOPIC)
            .await
            .expect("Cannot subscribe to topic");

        paired.send_and_forget(resp_array!["PUBLISH", SUBSCRIBE_TEST_TOPIC, "test-message"]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            SUBSCRIBE_TEST_NON_TOPIC,
            "test-message-1.5"
        ]);
        let _: resp::RespValue = paired
            .send(resp_array![
                "PUBLISH",
                SUBSCRIBE_TEST_TOPIC,
                "test-message2"
            ])
            .await
            .expect("Cannot send to topic");

        let result: Vec<_> = topic_messages
            .take(2)
            .try_collect()
            .await
            .expect("Cannot collect two values");

        assert_eq!(result.len(), 2);
        assert_eq!(result[0], "test-message".into());
        assert_eq!(result[1], "test-message2".into());
    }

    /// A test to examine the edge-case where a client subscribes to a topic, then the subscription is specifically unsubscribed,
    /// vs. where the subscription is automatically unsubscribed.
    #[tokio::test]
    async fn unsubscribe_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let mut topic_1 = pubsub
            .subscribe(UNSUBSCRIBE_TOPIC_1)
            .await
            .expect("Cannot subscribe to topic");
        let mut topic_2 = pubsub
            .subscribe(UNSUBSCRIBE_TOPIC_2)
            .await
            .expect("Cannot subscribe to topic");
        let mut topic_3 = pubsub
            .subscribe(UNSUBSCRIBE_TOPIC_3)
            .await
            .expect("Cannot subscribe to topic");

        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_1,
            "test-message-1"
        ]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_2,
            "test-message-2"
        ]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_3,
            "test-message-3"
        ]);

        let result1 = topic_1
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result1, "test-message-1".into());

        let result2 = topic_2
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result2, "test-message-2".into());

        let result3 = topic_3
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result3, "test-message-3".into());

        // Unsubscribe from topic 2
        pubsub.unsubscribe(UNSUBSCRIBE_TOPIC_2);

        // Ensure unsubscription is processed
        sleep(Duration::from_millis(1000)).await;

        // Drop the subscription for topic 3
        mem::drop(topic_3);

        // Send some more messages
        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_1,
            "test-message-1.5"
        ]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_2,
            "test-message-2.5"
        ]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_3,
            "test-message-3.5"
        ]);

        // Get the next message for topic 1
        let result1 = topic_1
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result1, "test-message-1.5".into());

        // Get the next message for topic 2
        let result2 = topic_2.next().await;
        assert!(result2.is_none());
    }

    /// Test that we can subscribe, unsubscribe, and resubscribe to a topic.
    #[tokio::test]
    async fn resubscribe_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let mut topic_1 = pubsub
            .subscribe(RESUBSCRIBE_TOPIC)
            .await
            .expect("Cannot subscribe to topic");

        paired.send_and_forget(resp_array!["PUBLISH", RESUBSCRIBE_TOPIC, "test-message-1"]);

        let result1 = topic_1
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result1, "test-message-1".into());

        // Unsubscribe from topic 1
        pubsub.unsubscribe(RESUBSCRIBE_TOPIC);

        // Yes, I know, just testing...
        sleep(Duration::from_millis(1000)).await;

        // Send some more messages
        paired.send_and_forget(resp_array![
            "PUBLISH",
            RESUBSCRIBE_TOPIC,
            "test-message-1.5"
        ]);

        // Get the next message for topic 1
        let result1 = topic_1.next().await;
        assert!(result1.is_none());

        // Resubscribe to topic 1. Note: shadowing does not drop the first stream; it stays alive
        // (and so does not send its drop-time unsubscribe) until the end of this test.
        let mut topic_1 = pubsub
            .subscribe(RESUBSCRIBE_TOPIC)
            .await
            .expect("Cannot subscribe to topic");

        // Send some more messages
        paired.send_and_forget(resp_array![
            "PUBLISH",
            RESUBSCRIBE_TOPIC,
            "test-message-1.75"
        ]);

        // Get the next message for topic 1
        let result1 = topic_1
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result1, "test-message-1.75".into());
    }

    /// Test that dropping the connection doesn't stop the subscriptions. Not initially anyway.
    #[tokio::test]
    async fn drop_connection_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let mut topic_1 = pubsub
            .subscribe(DROP_CONNECTION_TOPIC)
            .await
            .expect("Cannot subscribe to topic");

        mem::drop(pubsub);

        paired.send_and_forget(resp_array![
            "PUBLISH",
            DROP_CONNECTION_TOPIC,
            "test-message-1"
        ]);

        let result1 = topic_1
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result1, "test-message-1".into());

        mem::drop(topic_1);
    }

    #[tokio::test]
    async fn psubscribe_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let topic_messages = pubsub
            .psubscribe(PSUBSCRIBE_PATTERN)
            .await
            .expect("Cannot subscribe to topic");

        paired.send_and_forget(resp_array!["PUBLISH", PSUBSCRIBE_TOPIC_1, "test-message-1"]);
        paired.send_and_forget(resp_array!["PUBLISH", PSUBSCRIBE_TOPIC_2, "test-message-2"]);
        let _: resp::RespValue = paired
            .send(resp_array!["PUBLISH", PSUBSCRIBE_TOPIC_3, "test-message-3"])
            .await
            .expect("Cannot send to topic");

        let result: Vec<_> = topic_messages
            .take(3)
            .try_collect()
            .await
            .expect("Cannot collect three values");

        assert_eq!(result.len(), 3);
        assert_eq!(result[0], "test-message-1".into());
        assert_eq!(result[1], "test-message-2".into());
        assert_eq!(result[2], "test-message-3".into());
    }

    /// Allow unsubscribe to be called twice
    #[tokio::test]
    async fn unsubscribe_twice_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let mut topic_1 = pubsub
            .subscribe(UNSUBSCRIBE_TWICE_TOPIC_1)
            .await
            .expect("Cannot subscribe to topic");
        let mut topic_2 = pubsub
            .subscribe(UNSUBSCRIBE_TWICE_TOPIC_2)
            .await
            .expect("Cannot subscribe to topic");

        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TWICE_TOPIC_1,
            "test-message-1"
        ]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TWICE_TOPIC_2,
            "test-message-2"
        ]);

        pubsub.unsubscribe(UNSUBSCRIBE_TWICE_TOPIC_2);
        pubsub.unsubscribe(UNSUBSCRIBE_TWICE_TOPIC_2);

        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TWICE_TOPIC_1,
            "test-message-1.5"
        ]);

        // Allow for the messages to be sent
        sleep(Duration::from_millis(1000)).await;

        pubsub.unsubscribe(UNSUBSCRIBE_TWICE_TOPIC_1);

        let result1 = topic_1
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result1, "test-message-1".into());

        let result1 = topic_1
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result1, "test-message-1.5".into());

        let result2 = topic_2
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result2, "test-message-2".into());

        let result1 = topic_1.next().await;
        assert!(result1.is_none());

        let result2 = topic_2.next().await;
        assert!(result2.is_none());
    }
}