redis_async/client/pubsub/mod.rs

/*
 * Copyright 2017-2024 Ben Ashford
 *
 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
 * http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
 * <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
 * option. This file may not be copied, modified, or distributed
 * except according to those terms.
 */
10
11mod inner;
12
13use std::future::Future;
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17use std::time::Duration;
18
19use futures_channel::{mpsc, oneshot};
20use futures_util::{
21    future::TryFutureExt,
22    stream::{Stream, StreamExt},
23};
24
25use super::{connect::connect_with_auth, ConnectionBuilder};
26
27use crate::{
28    error,
29    reconnect::{reconnect, Reconnect},
30    resp,
31};
32
33use self::inner::PubsubConnectionInner;
34
35#[derive(Debug)]
36pub(crate) enum PubsubEvent {
37    /// The: topic, sink to send messages through, and a oneshot to signal subscription has
38    /// occurred.
39    Subscribe(String, PubsubSink, oneshot::Sender<()>),
40    Psubscribe(String, PubsubSink, oneshot::Sender<()>),
41    /// The name of the topic to unsubscribe from. Unsubscription will be signaled by the stream
42    /// closing without error.
43    Unsubscribe(String),
44    Punsubscribe(String),
45}
46
47type PubsubStreamInner = mpsc::UnboundedReceiver<Result<resp::RespValue, error::Error>>;
48type PubsubSink = mpsc::UnboundedSender<Result<resp::RespValue, error::Error>>;
49
50/// A shareable reference to subscribe to PUBSUB topics
51#[derive(Debug, Clone)]
52pub struct PubsubConnection {
53    out_tx_c: Arc<Reconnect<PubsubEvent, mpsc::UnboundedSender<PubsubEvent>>>,
54}
55
56async fn inner_conn_fn(
57    // Needs to be a String for lifetime reasons
58    host: String,
59    port: u16,
60    username: Option<Arc<str>>,
61    password: Option<Arc<str>>,
62    tls: bool,
63    socket_keepalive: Option<Duration>,
64    socket_timeout: Option<Duration>,
65) -> Result<mpsc::UnboundedSender<PubsubEvent>, error::Error> {
66    let username = username.as_deref();
67    let password = password.as_deref();
68
69    let connection = connect_with_auth(
70        &host,
71        port,
72        username,
73        password,
74        tls,
75        socket_keepalive,
76        socket_timeout,
77    )
78    .await?;
79    let (out_tx, out_rx) = mpsc::unbounded();
80    tokio::spawn(async {
81        match PubsubConnectionInner::new(connection, out_rx).await {
82            Ok(_) => (),
83            Err(e) => log::error!("Pub/Sub error: {:?}", e),
84        }
85    });
86    Ok(out_tx)
87}
88
89impl ConnectionBuilder {
90    pub fn pubsub_connect(&self) -> impl Future<Output = Result<PubsubConnection, error::Error>> {
91        let username = self.username.clone();
92        let password = self.password.clone();
93
94        #[cfg(feature = "tls")]
95        let tls = self.tls;
96        #[cfg(not(feature = "tls"))]
97        let tls = false;
98
99        let host = self.host.clone();
100        let port = self.port;
101
102        let socket_keepalive = self.socket_keepalive;
103        let socket_timeout = self.socket_timeout;
104
105        let reconnecting_f = reconnect(
106            |con: &mpsc::UnboundedSender<PubsubEvent>, act| {
107                con.unbounded_send(act).map_err(|e| e.into())
108            },
109            move || {
110                let con_f = inner_conn_fn(
111                    host.clone(),
112                    port,
113                    username.clone(),
114                    password.clone(),
115                    tls,
116                    socket_keepalive,
117                    socket_timeout,
118                );
119                Box::pin(con_f)
120            },
121        );
122        reconnecting_f.map_ok(|con| PubsubConnection {
123            out_tx_c: Arc::new(con),
124        })
125    }
126}
127
128/// Used for Redis's PUBSUB functionality.
129///
130/// Returns a future that resolves to a `PubsubConnection`. The future will only resolve once the
131/// connection is established; after the intial establishment, if the connection drops for any
132/// reason (e.g. Redis server being restarted), the connection will attempt re-connect, however
133/// any subscriptions will need to be re-subscribed.
134pub async fn pubsub_connect(
135    host: impl Into<String>,
136    port: u16,
137) -> Result<PubsubConnection, error::Error> {
138    ConnectionBuilder::new(host, port)?.pubsub_connect().await
139}
140
141impl PubsubConnection {
142    /// Subscribes to a particular PUBSUB topic.
143    ///
144    /// Returns a future that resolves to a `Stream` that contains all the messages published on
145    /// that particular topic.
146    ///
147    /// The resolved stream will end with `redis_async::error::Error::EndOfStream` if the
148    /// underlying connection is lost for unexpected reasons. In this situation, clients should
149    /// `subscribe` to re-subscribe; the underlying connect will automatically reconnect. However,
150    /// clients should be aware that resubscriptions will only succeed if the underlying connection
151    /// has re-established, so multiple calls to `subscribe` may be required.
152    pub async fn subscribe(&self, topic: &str) -> Result<PubsubStream, error::Error> {
153        let (tx, rx) = mpsc::unbounded();
154        let (signal_t, signal_r) = oneshot::channel();
155        self.out_tx_c
156            .do_work(PubsubEvent::Subscribe(topic.to_owned(), tx, signal_t))?;
157
158        match signal_r.await {
159            Ok(_) => Ok(PubsubStream {
160                topic: topic.to_owned(),
161                underlying: rx,
162                con: self.clone(),
163                is_pattern: false,
164            }),
165            Err(_) => Err(error::internal("Subscription failed, try again later...")),
166        }
167    }
168
169    pub async fn psubscribe(&self, topic: &str) -> Result<PubsubStream, error::Error> {
170        let (tx, rx) = mpsc::unbounded();
171        let (signal_t, signal_r) = oneshot::channel();
172        self.out_tx_c
173            .do_work(PubsubEvent::Psubscribe(topic.to_owned(), tx, signal_t))?;
174
175        match signal_r.await {
176            Ok(_) => Ok(PubsubStream {
177                topic: topic.to_owned(),
178                underlying: rx,
179                con: self.clone(),
180                is_pattern: true,
181            }),
182            Err(_) => Err(error::internal("Subscription failed, try again later...")),
183        }
184    }
185
186    /// Tells the client to unsubscribe from a particular topic. This will return immediately, the
187    /// actual unsubscription will be confirmed when the stream returned from `subscribe` ends.
188    pub fn unsubscribe<T: Into<String>>(&self, topic: T) {
189        // Ignoring any results, as any errors communicating with Redis would de-facto unsubscribe
190        // anyway, and would be reported/logged elsewhere
191        let _ = self
192            .out_tx_c
193            .do_work(PubsubEvent::Unsubscribe(topic.into()));
194    }
195
196    pub fn punsubscribe<T: Into<String>>(&self, topic: T) {
197        // Ignoring any results, as any errors communicating with Redis would de-facto unsubscribe
198        // anyway, and would be reported/logged elsewhere
199        let _ = self
200            .out_tx_c
201            .do_work(PubsubEvent::Punsubscribe(topic.into()));
202    }
203}
204
205#[derive(Debug)]
206pub struct PubsubStream {
207    topic: String,
208    underlying: PubsubStreamInner,
209    con: PubsubConnection,
210    is_pattern: bool,
211}
212
213impl Stream for PubsubStream {
214    type Item = Result<resp::RespValue, error::Error>;
215
216    #[inline]
217    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
218        self.get_mut().underlying.poll_next_unpin(cx)
219    }
220}
221
222impl Drop for PubsubStream {
223    fn drop(&mut self) {
224        let topic: &str = self.topic.as_ref();
225        if self.is_pattern {
226            self.con.punsubscribe(topic);
227        } else {
228            self.con.unsubscribe(topic);
229        }
230    }
231}
232
#[cfg(test)]
mod test {
    use std::mem;

    use futures::{try_join, StreamExt, TryStreamExt};

    use crate::{client, resp};

    /* IMPORTANT: The tests run in parallel, so the topic names used must be exclusive to each test */
    static SUBSCRIBE_TEST_TOPIC: &str = "test-topic";
    static SUBSCRIBE_TEST_NON_TOPIC: &str = "test-not-topic";

    static UNSUBSCRIBE_TOPIC_1: &str = "test-topic-1";
    static UNSUBSCRIBE_TOPIC_2: &str = "test-topic-2";
    static UNSUBSCRIBE_TOPIC_3: &str = "test-topic-3";

    static RESUBSCRIBE_TOPIC: &str = "test-topic-resubscribe";

    static DROP_CONNECTION_TOPIC: &str = "test-topic-drop-connection";

    static PSUBSCRIBE_PATTERN: &str = "ptest.*";
    static PSUBSCRIBE_TOPIC_1: &str = "ptest.1";
    static PSUBSCRIBE_TOPIC_2: &str = "ptest.2";
    static PSUBSCRIBE_TOPIC_3: &str = "ptest.3";

    static UNSUBSCRIBE_TWICE_TOPIC_1: &str = "test-topic-1-twice";
    static UNSUBSCRIBE_TWICE_TOPIC_2: &str = "test-topic-2-twice";

    /// Awaits the next message on `stream`, panicking if the stream has ended or yielded an
    /// error.
    async fn next_value(stream: &mut super::PubsubStream) -> resp::RespValue {
        stream
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value")
    }

    #[tokio::test]
    async fn subscribe_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let topic_messages = pubsub
            .subscribe(SUBSCRIBE_TEST_TOPIC)
            .await
            .expect("Cannot subscribe to topic");

        paired.send_and_forget(resp_array!["PUBLISH", SUBSCRIBE_TEST_TOPIC, "test-message"]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            SUBSCRIBE_TEST_NON_TOPIC,
            "test-message-1.5"
        ]);
        let _: resp::RespValue = paired
            .send(resp_array![
                "PUBLISH",
                SUBSCRIBE_TEST_TOPIC,
                "test-message2"
            ])
            .await
            .expect("Cannot send to topic");

        let result: Vec<_> = topic_messages
            .take(2)
            .try_collect()
            .await
            .expect("Cannot collect two values");

        assert_eq!(result.len(), 2);
        assert_eq!(result[0], "test-message".into());
        assert_eq!(result[1], "test-message2".into());
    }

    /// A test to examine the edge-case where a client subscribes to a topic, then the subscription is specifically unsubscribed,
    /// vs. where the subscription is automatically unsubscribed.
    #[tokio::test]
    async fn unsubscribe_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let mut topic_1 = pubsub
            .subscribe(UNSUBSCRIBE_TOPIC_1)
            .await
            .expect("Cannot subscribe to topic");
        let mut topic_2 = pubsub
            .subscribe(UNSUBSCRIBE_TOPIC_2)
            .await
            .expect("Cannot subscribe to topic");
        let mut topic_3 = pubsub
            .subscribe(UNSUBSCRIBE_TOPIC_3)
            .await
            .expect("Cannot subscribe to topic");

        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_1,
            "test-message-1"
        ]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_2,
            "test-message-2"
        ]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_3,
            "test-message-3"
        ]);

        assert_eq!(next_value(&mut topic_1).await, "test-message-1".into());
        assert_eq!(next_value(&mut topic_2).await, "test-message-2".into());
        assert_eq!(next_value(&mut topic_3).await, "test-message-3".into());

        // Unsubscribe from topic 2
        pubsub.unsubscribe(UNSUBSCRIBE_TOPIC_2);

        // Drop the subscription for topic 3
        mem::drop(topic_3);

        // Wait for the topic 2 stream to end *before* publishing more messages. The UNSUBSCRIBE
        // and the PUBLISHes travel over different connections, so without this wait Redis may
        // process "test-message-2.5" first and deliver it to the stream.
        let result2 = topic_2.next().await;
        assert!(result2.is_none());

        // Send some more messages
        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_1,
            "test-message-1.5"
        ]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_2,
            "test-message-2.5"
        ]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_3,
            "test-message-3.5"
        ]);

        // Topic 1 is still subscribed and receives its next message
        assert_eq!(next_value(&mut topic_1).await, "test-message-1.5".into());
    }

    /// Test that we can subscribe, unsubscribe, and resubscribe to a topic.
    #[tokio::test]
    async fn resubscribe_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let mut topic_1 = pubsub
            .subscribe(RESUBSCRIBE_TOPIC)
            .await
            .expect("Cannot subscribe to topic");

        paired.send_and_forget(resp_array!["PUBLISH", RESUBSCRIBE_TOPIC, "test-message-1"]);

        assert_eq!(next_value(&mut topic_1).await, "test-message-1".into());

        // Unsubscribe from topic 1, and wait for the unsubscription to be confirmed (the stream
        // ends) before publishing again.
        pubsub.unsubscribe(RESUBSCRIBE_TOPIC);
        let result1 = topic_1.next().await;
        assert!(result1.is_none());

        // Published while nobody is subscribed. Awaited so Redis has processed it before the
        // resubscription below; otherwise the new stream could receive it first.
        let _: resp::RespValue = paired
            .send(resp_array![
                "PUBLISH",
                RESUBSCRIBE_TOPIC,
                "test-message-1.5"
            ])
            .await
            .expect("Cannot send to topic");

        // Resubscribe to topic 1
        let mut topic_1 = pubsub
            .subscribe(RESUBSCRIBE_TOPIC)
            .await
            .expect("Cannot subscribe to topic");

        // Send some more messages
        paired.send_and_forget(resp_array![
            "PUBLISH",
            RESUBSCRIBE_TOPIC,
            "test-message-1.75"
        ]);

        // Get the next message for topic 1
        assert_eq!(next_value(&mut topic_1).await, "test-message-1.75".into());
    }

    /// Test that dropping the connection doesn't stop the subscriptions. Not initially anyway.
    #[tokio::test]
    async fn drop_connection_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let mut topic_1 = pubsub
            .subscribe(DROP_CONNECTION_TOPIC)
            .await
            .expect("Cannot subscribe to topic");

        mem::drop(pubsub);

        paired.send_and_forget(resp_array![
            "PUBLISH",
            DROP_CONNECTION_TOPIC,
            "test-message-1"
        ]);

        assert_eq!(next_value(&mut topic_1).await, "test-message-1".into());

        mem::drop(topic_1);
    }

    #[tokio::test]
    async fn psubscribe_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let topic_messages = pubsub
            .psubscribe(PSUBSCRIBE_PATTERN)
            .await
            .expect("Cannot subscribe to topic");

        paired.send_and_forget(resp_array!["PUBLISH", PSUBSCRIBE_TOPIC_1, "test-message-1"]);
        paired.send_and_forget(resp_array!["PUBLISH", PSUBSCRIBE_TOPIC_2, "test-message-2"]);
        let _: resp::RespValue = paired
            .send(resp_array!["PUBLISH", PSUBSCRIBE_TOPIC_3, "test-message-3"])
            .await
            .expect("Cannot send to topic");

        let result: Vec<_> = topic_messages
            .take(3)
            .try_collect()
            .await
            .expect("Cannot collect three values");

        assert_eq!(result.len(), 3);
        assert_eq!(result[0], "test-message-1".into());
        assert_eq!(result[1], "test-message-2".into());
        assert_eq!(result[2], "test-message-3".into());
    }

    /// Allow unsubscribe to be called twice
    #[tokio::test]
    async fn unsubscribe_twice_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let mut topic_1 = pubsub
            .subscribe(UNSUBSCRIBE_TWICE_TOPIC_1)
            .await
            .expect("Cannot subscribe to topic");
        let mut topic_2 = pubsub
            .subscribe(UNSUBSCRIBE_TWICE_TOPIC_2)
            .await
            .expect("Cannot subscribe to topic");

        // The PUBLISHes are awaited so Redis has processed them before the UNSUBSCRIBEs below,
        // which are sent over a different connection and could otherwise overtake them.
        let _: resp::RespValue = paired
            .send(resp_array![
                "PUBLISH",
                UNSUBSCRIBE_TWICE_TOPIC_1,
                "test-message-1"
            ])
            .await
            .expect("Cannot send to topic");
        let _: resp::RespValue = paired
            .send(resp_array![
                "PUBLISH",
                UNSUBSCRIBE_TWICE_TOPIC_2,
                "test-message-2"
            ])
            .await
            .expect("Cannot send to topic");

        pubsub.unsubscribe(UNSUBSCRIBE_TWICE_TOPIC_2);
        pubsub.unsubscribe(UNSUBSCRIBE_TWICE_TOPIC_2);

        let _: resp::RespValue = paired
            .send(resp_array![
                "PUBLISH",
                UNSUBSCRIBE_TWICE_TOPIC_1,
                "test-message-1.5"
            ])
            .await
            .expect("Cannot send to topic");

        pubsub.unsubscribe(UNSUBSCRIBE_TWICE_TOPIC_1);

        assert_eq!(next_value(&mut topic_1).await, "test-message-1".into());
        assert_eq!(next_value(&mut topic_1).await, "test-message-1.5".into());
        assert_eq!(next_value(&mut topic_2).await, "test-message-2".into());

        let result1 = topic_1.next().await;
        assert!(result1.is_none());

        let result2 = topic_2.next().await;
        assert!(result2.is_none());
    }
}