1mod inner;
12
13use std::future::Future;
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17use std::time::Duration;
18
19use futures_channel::{mpsc, oneshot};
20use futures_util::{
21 future::TryFutureExt,
22 stream::{Stream, StreamExt},
23};
24
25use super::{connect::connect_with_auth, ConnectionBuilder};
26
27use crate::{
28 error,
29 reconnect::{reconnect, Reconnect},
30 resp,
31};
32
33use self::inner::PubsubConnectionInner;
34
35#[derive(Debug)]
36pub(crate) enum PubsubEvent {
37 Subscribe(String, PubsubSink, oneshot::Sender<()>),
40 Psubscribe(String, PubsubSink, oneshot::Sender<()>),
41 Unsubscribe(String),
44 Punsubscribe(String),
45}
46
47type PubsubStreamInner = mpsc::UnboundedReceiver<Result<resp::RespValue, error::Error>>;
48type PubsubSink = mpsc::UnboundedSender<Result<resp::RespValue, error::Error>>;
49
50#[derive(Debug, Clone)]
52pub struct PubsubConnection {
53 out_tx_c: Arc<Reconnect<PubsubEvent, mpsc::UnboundedSender<PubsubEvent>>>,
54}
55
56async fn inner_conn_fn(
57 host: String,
59 port: u16,
60 username: Option<Arc<str>>,
61 password: Option<Arc<str>>,
62 tls: bool,
63 socket_keepalive: Option<Duration>,
64 socket_timeout: Option<Duration>,
65) -> Result<mpsc::UnboundedSender<PubsubEvent>, error::Error> {
66 let username = username.as_deref();
67 let password = password.as_deref();
68
69 let connection = connect_with_auth(
70 &host,
71 port,
72 username,
73 password,
74 tls,
75 socket_keepalive,
76 socket_timeout,
77 )
78 .await?;
79 let (out_tx, out_rx) = mpsc::unbounded();
80 tokio::spawn(async {
81 match PubsubConnectionInner::new(connection, out_rx).await {
82 Ok(_) => (),
83 Err(e) => log::error!("Pub/Sub error: {:?}", e),
84 }
85 });
86 Ok(out_tx)
87}
88
89impl ConnectionBuilder {
90 pub fn pubsub_connect(&self) -> impl Future<Output = Result<PubsubConnection, error::Error>> {
91 let username = self.username.clone();
92 let password = self.password.clone();
93
94 #[cfg(feature = "tls")]
95 let tls = self.tls;
96 #[cfg(not(feature = "tls"))]
97 let tls = false;
98
99 let host = self.host.clone();
100 let port = self.port;
101
102 let socket_keepalive = self.socket_keepalive;
103 let socket_timeout = self.socket_timeout;
104
105 let reconnecting_f = reconnect(
106 |con: &mpsc::UnboundedSender<PubsubEvent>, act| {
107 con.unbounded_send(act).map_err(|e| e.into())
108 },
109 move || {
110 let con_f = inner_conn_fn(
111 host.clone(),
112 port,
113 username.clone(),
114 password.clone(),
115 tls,
116 socket_keepalive,
117 socket_timeout,
118 );
119 Box::pin(con_f)
120 },
121 self.reconnect_options,
122 );
123 reconnecting_f.map_ok(|con| PubsubConnection {
124 out_tx_c: Arc::new(con),
125 })
126 }
127}
128
129pub async fn pubsub_connect(
136 host: impl Into<String>,
137 port: u16,
138) -> Result<PubsubConnection, error::Error> {
139 ConnectionBuilder::new(host, port)?.pubsub_connect().await
140}
141
142impl PubsubConnection {
143 pub async fn subscribe(&self, topic: &str) -> Result<PubsubStream, error::Error> {
154 let (tx, rx) = mpsc::unbounded();
155 let (signal_t, signal_r) = oneshot::channel();
156 self.out_tx_c
157 .do_work(PubsubEvent::Subscribe(topic.to_owned(), tx, signal_t))?;
158
159 match signal_r.await {
160 Ok(_) => Ok(PubsubStream {
161 topic: topic.to_owned(),
162 underlying: rx,
163 con: self.clone(),
164 is_pattern: false,
165 }),
166 Err(_) => Err(error::internal("Subscription failed, try again later...")),
167 }
168 }
169
170 pub async fn psubscribe(&self, topic: &str) -> Result<PubsubStream, error::Error> {
171 let (tx, rx) = mpsc::unbounded();
172 let (signal_t, signal_r) = oneshot::channel();
173 self.out_tx_c
174 .do_work(PubsubEvent::Psubscribe(topic.to_owned(), tx, signal_t))?;
175
176 match signal_r.await {
177 Ok(_) => Ok(PubsubStream {
178 topic: topic.to_owned(),
179 underlying: rx,
180 con: self.clone(),
181 is_pattern: true,
182 }),
183 Err(_) => Err(error::internal("Subscription failed, try again later...")),
184 }
185 }
186
187 pub fn unsubscribe<T: Into<String>>(&self, topic: T) {
190 let _ = self
193 .out_tx_c
194 .do_work(PubsubEvent::Unsubscribe(topic.into()));
195 }
196
197 pub fn punsubscribe<T: Into<String>>(&self, topic: T) {
198 let _ = self
201 .out_tx_c
202 .do_work(PubsubEvent::Punsubscribe(topic.into()));
203 }
204}
205
206#[derive(Debug)]
207pub struct PubsubStream {
208 topic: String,
209 underlying: PubsubStreamInner,
210 con: PubsubConnection,
211 is_pattern: bool,
212}
213
214impl Stream for PubsubStream {
215 type Item = Result<resp::RespValue, error::Error>;
216
217 #[inline]
218 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
219 self.get_mut().underlying.poll_next_unpin(cx)
220 }
221}
222
223impl Drop for PubsubStream {
224 fn drop(&mut self) {
225 let topic: &str = self.topic.as_ref();
226 if self.is_pattern {
227 self.con.punsubscribe(topic);
228 } else {
229 self.con.unsubscribe(topic);
230 }
231 }
232}
233
#[cfg(test)]
mod test {
    // Integration tests: they require a Redis server on 127.0.0.1:6379. Each
    // test uses its own channel names, apparently so concurrently running
    // tests do not receive each other's messages. The one-second sleeps give
    // unsubscribe requests time to reach the server before the next PUBLISH.

    use std::mem;
    use std::time::Duration;

    use futures::{try_join, StreamExt, TryStreamExt};
    use tokio::time::sleep;

    use crate::{client, resp};

    static SUBSCRIBE_TEST_TOPIC: &str = "test-topic";
    static SUBSCRIBE_TEST_NON_TOPIC: &str = "test-not-topic";

    static UNSUBSCRIBE_TOPIC_1: &str = "test-topic-1";
    static UNSUBSCRIBE_TOPIC_2: &str = "test-topic-2";
    static UNSUBSCRIBE_TOPIC_3: &str = "test-topic-3";

    static RESUBSCRIBE_TOPIC: &str = "test-topic-resubscribe";

    static DROP_CONNECTION_TOPIC: &str = "test-topic-drop-connection";

    static PSUBSCRIBE_PATTERN: &str = "ptest.*";
    static PSUBSCRIBE_TOPIC_1: &str = "ptest.1";
    static PSUBSCRIBE_TOPIC_2: &str = "ptest.2";
    static PSUBSCRIBE_TOPIC_3: &str = "ptest.3";

    static UNSUBSCRIBE_TWICE_TOPIC_1: &str = "test-topic-1-twice";
    static UNSUBSCRIBE_TWICE_TOPIC_2: &str = "test-topic-2-twice";

    /// Messages published to the subscribed channel arrive in order; messages
    /// for a different channel are not delivered to the stream.
    #[tokio::test]
    async fn subscribe_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let topic_messages = pubsub
            .subscribe(SUBSCRIBE_TEST_TOPIC)
            .await
            .expect("Cannot subscribe to topic");

        paired.send_and_forget(resp_array!["PUBLISH", SUBSCRIBE_TEST_TOPIC, "test-message"]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            SUBSCRIBE_TEST_NON_TOPIC,
            "test-message-1.5"
        ]);
        // Awaiting the reply to the last PUBLISH (presumably) guarantees the
        // earlier fire-and-forget commands on this connection were handled too.
        let _: resp::RespValue = paired
            .send(resp_array![
                "PUBLISH",
                SUBSCRIBE_TEST_TOPIC,
                "test-message2"
            ])
            .await
            .expect("Cannot send to topic");

        let result: Vec<_> = topic_messages
            .take(2)
            .try_collect()
            .await
            .expect("Cannot collect two values");

        assert_eq!(result.len(), 2);
        assert_eq!(result[0], "test-message".into());
        assert_eq!(result[1], "test-message2".into());
    }

    /// Explicitly unsubscribing (topic 2) or dropping the stream (topic 3)
    /// stops delivery for that topic without affecting the others; a stream
    /// whose topic was unsubscribed ends with `None`.
    #[tokio::test]
    async fn unsubscribe_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let mut topic_1 = pubsub
            .subscribe(UNSUBSCRIBE_TOPIC_1)
            .await
            .expect("Cannot subscribe to topic");
        let mut topic_2 = pubsub
            .subscribe(UNSUBSCRIBE_TOPIC_2)
            .await
            .expect("Cannot subscribe to topic");
        let mut topic_3 = pubsub
            .subscribe(UNSUBSCRIBE_TOPIC_3)
            .await
            .expect("Cannot subscribe to topic");

        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_1,
            "test-message-1"
        ]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_2,
            "test-message-2"
        ]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_3,
            "test-message-3"
        ]);

        let result1 = topic_1
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result1, "test-message-1".into());

        let result2 = topic_2
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result2, "test-message-2".into());

        let result3 = topic_3
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result3, "test-message-3".into());

        pubsub.unsubscribe(UNSUBSCRIBE_TOPIC_2);

        // Give the UNSUBSCRIBE time to take effect before publishing again.
        sleep(Duration::from_millis(1000)).await;

        // Dropping the stream triggers its own unsubscribe.
        mem::drop(topic_3);

        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_1,
            "test-message-1.5"
        ]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_2,
            "test-message-2.5"
        ]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_3,
            "test-message-3.5"
        ]);

        let result1 = topic_1
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result1, "test-message-1.5".into());

        let result2 = topic_2.next().await;
        assert!(result2.is_none());
    }

    /// After unsubscribing, the old stream ends, and subscribing to the same
    /// topic again yields a new stream that receives fresh messages.
    #[tokio::test]
    async fn resubscribe_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let mut topic_1 = pubsub
            .subscribe(RESUBSCRIBE_TOPIC)
            .await
            .expect("Cannot subscribe to topic");

        paired.send_and_forget(resp_array!["PUBLISH", RESUBSCRIBE_TOPIC, "test-message-1"]);

        let result1 = topic_1
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result1, "test-message-1".into());

        pubsub.unsubscribe(RESUBSCRIBE_TOPIC);

        // Give the UNSUBSCRIBE time to take effect before publishing again.
        sleep(Duration::from_millis(1000)).await;

        paired.send_and_forget(resp_array![
            "PUBLISH",
            RESUBSCRIBE_TOPIC,
            "test-message-1.5"
        ]);

        let result1 = topic_1.next().await;
        assert!(result1.is_none());

        let mut topic_1 = pubsub
            .subscribe(RESUBSCRIBE_TOPIC)
            .await
            .expect("Cannot subscribe to topic");

        paired.send_and_forget(resp_array![
            "PUBLISH",
            RESUBSCRIBE_TOPIC,
            "test-message-1.75"
        ]);

        let result1 = topic_1
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result1, "test-message-1.75".into());
    }

    /// Dropping the caller's `PubsubConnection` does not end existing streams,
    /// because each stream holds its own clone of the connection handle.
    #[tokio::test]
    async fn drop_connection_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let mut topic_1 = pubsub
            .subscribe(DROP_CONNECTION_TOPIC)
            .await
            .expect("Cannot subscribe to topic");

        mem::drop(pubsub);

        paired.send_and_forget(resp_array![
            "PUBLISH",
            DROP_CONNECTION_TOPIC,
            "test-message-1"
        ]);

        let result1 = topic_1
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result1, "test-message-1".into());

        mem::drop(topic_1);
    }

    /// A pattern subscription receives messages from every matching channel,
    /// in publish order.
    #[tokio::test]
    async fn psubscribe_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let topic_messages = pubsub
            .psubscribe(PSUBSCRIBE_PATTERN)
            .await
            .expect("Cannot subscribe to topic");

        paired.send_and_forget(resp_array!["PUBLISH", PSUBSCRIBE_TOPIC_1, "test-message-1"]);
        paired.send_and_forget(resp_array!["PUBLISH", PSUBSCRIBE_TOPIC_2, "test-message-2"]);
        let _: resp::RespValue = paired
            .send(resp_array!["PUBLISH", PSUBSCRIBE_TOPIC_3, "test-message-3"])
            .await
            .expect("Cannot send to topic");

        let result: Vec<_> = topic_messages
            .take(3)
            .try_collect()
            .await
            .expect("Cannot collect three values");

        assert_eq!(result.len(), 3);
        assert_eq!(result[0], "test-message-1".into());
        assert_eq!(result[1], "test-message-2".into());
        assert_eq!(result[2], "test-message-3".into());
    }

    /// Unsubscribing the same topic twice is harmless. Messages already
    /// delivered before an unsubscribe remain readable, after which the
    /// stream ends with `None`.
    #[tokio::test]
    async fn unsubscribe_twice_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let mut topic_1 = pubsub
            .subscribe(UNSUBSCRIBE_TWICE_TOPIC_1)
            .await
            .expect("Cannot subscribe to topic");
        let mut topic_2 = pubsub
            .subscribe(UNSUBSCRIBE_TWICE_TOPIC_2)
            .await
            .expect("Cannot subscribe to topic");

        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TWICE_TOPIC_1,
            "test-message-1"
        ]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TWICE_TOPIC_2,
            "test-message-2"
        ]);

        pubsub.unsubscribe(UNSUBSCRIBE_TWICE_TOPIC_2);
        pubsub.unsubscribe(UNSUBSCRIBE_TWICE_TOPIC_2);

        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TWICE_TOPIC_1,
            "test-message-1.5"
        ]);

        // Let the publishes and unsubscribes settle before unsubscribing topic 1.
        sleep(Duration::from_millis(1000)).await;

        pubsub.unsubscribe(UNSUBSCRIBE_TWICE_TOPIC_1);

        let result1 = topic_1
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result1, "test-message-1".into());

        let result1 = topic_1
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result1, "test-message-1.5".into());

        let result2 = topic_2
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result2, "test-message-2".into());

        let result1 = topic_1.next().await;
        assert!(result1.is_none());

        let result2 = topic_2.next().await;
        assert!(result2.is_none());
    }
}