1mod inner;
12
13use std::future::Future;
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17use std::time::Duration;
18
19use futures_channel::{mpsc, oneshot};
20use futures_util::{
21 future::TryFutureExt,
22 stream::{Stream, StreamExt},
23};
24
25use super::{connect::connect_with_auth, ConnectionBuilder};
26
27use crate::{
28 error,
29 reconnect::{reconnect, Reconnect},
30 resp,
31};
32
33use self::inner::PubsubConnectionInner;
34
/// Requests sent from a [`PubsubConnection`] handle to the background Pub/Sub
/// task (the task spawned in `inner_conn_fn`, which owns the receiving half of
/// the event channel).
#[derive(Debug)]
pub(crate) enum PubsubEvent {
    /// Subscribe to a channel.
    ///
    /// Fields: the topic, the sending half of the per-subscription channel (its
    /// receiver becomes the returned `PubsubStream`), and a one-shot sender.
    /// `PubsubConnection::subscribe` treats a `()` received on that one-shot as
    /// success and a dropped sender as a failed subscription.
    Subscribe(String, PubsubSink, oneshot::Sender<()>),
    /// Subscribe to a pattern (the `psubscribe` path); fields as for `Subscribe`.
    Psubscribe(String, PubsubSink, oneshot::Sender<()>),
    /// Unsubscribe from a channel.
    Unsubscribe(String),
    /// Unsubscribe from a pattern.
    Punsubscribe(String),
}
46
/// Receiving half of a per-subscription message channel; wrapped by `PubsubStream`.
type PubsubStreamInner = mpsc::UnboundedReceiver<Result<resp::RespValue, error::Error>>;
/// Sending half of a per-subscription message channel; handed to the background
/// task inside `PubsubEvent::Subscribe` / `PubsubEvent::Psubscribe`.
type PubsubSink = mpsc::UnboundedSender<Result<resp::RespValue, error::Error>>;
49
/// A cloneable handle to a Pub/Sub connection.
///
/// Created by `ConnectionBuilder::pubsub_connect` or the free function
/// `pubsub_connect`. Clones share the same underlying sender, which is held in an
/// `Arc`.
#[derive(Debug, Clone)]
pub struct PubsubConnection {
    // `Reconnect` wrapper (from `crate::reconnect`) around the sending half of the
    // event channel that feeds the background Pub/Sub task.
    out_tx_c: Arc<Reconnect<PubsubEvent, mpsc::UnboundedSender<PubsubEvent>>>,
}
55
56async fn inner_conn_fn(
57 host: String,
59 port: u16,
60 username: Option<Arc<str>>,
61 password: Option<Arc<str>>,
62 tls: bool,
63 socket_keepalive: Option<Duration>,
64 socket_timeout: Option<Duration>,
65) -> Result<mpsc::UnboundedSender<PubsubEvent>, error::Error> {
66 let username = username.as_deref();
67 let password = password.as_deref();
68
69 let connection = connect_with_auth(
70 &host,
71 port,
72 username,
73 password,
74 tls,
75 socket_keepalive,
76 socket_timeout,
77 )
78 .await?;
79 let (out_tx, out_rx) = mpsc::unbounded();
80 tokio::spawn(async {
81 match PubsubConnectionInner::new(connection, out_rx).await {
82 Ok(_) => (),
83 Err(e) => log::error!("Pub/Sub error: {:?}", e),
84 }
85 });
86 Ok(out_tx)
87}
88
89impl ConnectionBuilder {
90 pub fn pubsub_connect(&self) -> impl Future<Output = Result<PubsubConnection, error::Error>> {
91 let username = self.username.clone();
92 let password = self.password.clone();
93
94 #[cfg(feature = "tls")]
95 let tls = self.tls;
96 #[cfg(not(feature = "tls"))]
97 let tls = false;
98
99 let host = self.host.clone();
100 let port = self.port;
101
102 let socket_keepalive = self.socket_keepalive;
103 let socket_timeout = self.socket_timeout;
104
105 let reconnecting_f = reconnect(
106 |con: &mpsc::UnboundedSender<PubsubEvent>, act| {
107 con.unbounded_send(act).map_err(|e| e.into())
108 },
109 move || {
110 let con_f = inner_conn_fn(
111 host.clone(),
112 port,
113 username.clone(),
114 password.clone(),
115 tls,
116 socket_keepalive,
117 socket_timeout,
118 );
119 Box::pin(con_f)
120 },
121 );
122 reconnecting_f.map_ok(|con| PubsubConnection {
123 out_tx_c: Arc::new(con),
124 })
125 }
126}
127
128pub async fn pubsub_connect(
135 host: impl Into<String>,
136 port: u16,
137) -> Result<PubsubConnection, error::Error> {
138 ConnectionBuilder::new(host, port)?.pubsub_connect().await
139}
140
141impl PubsubConnection {
142 pub async fn subscribe(&self, topic: &str) -> Result<PubsubStream, error::Error> {
153 let (tx, rx) = mpsc::unbounded();
154 let (signal_t, signal_r) = oneshot::channel();
155 self.out_tx_c
156 .do_work(PubsubEvent::Subscribe(topic.to_owned(), tx, signal_t))?;
157
158 match signal_r.await {
159 Ok(_) => Ok(PubsubStream {
160 topic: topic.to_owned(),
161 underlying: rx,
162 con: self.clone(),
163 is_pattern: false,
164 }),
165 Err(_) => Err(error::internal("Subscription failed, try again later...")),
166 }
167 }
168
169 pub async fn psubscribe(&self, topic: &str) -> Result<PubsubStream, error::Error> {
170 let (tx, rx) = mpsc::unbounded();
171 let (signal_t, signal_r) = oneshot::channel();
172 self.out_tx_c
173 .do_work(PubsubEvent::Psubscribe(topic.to_owned(), tx, signal_t))?;
174
175 match signal_r.await {
176 Ok(_) => Ok(PubsubStream {
177 topic: topic.to_owned(),
178 underlying: rx,
179 con: self.clone(),
180 is_pattern: true,
181 }),
182 Err(_) => Err(error::internal("Subscription failed, try again later...")),
183 }
184 }
185
186 pub fn unsubscribe<T: Into<String>>(&self, topic: T) {
189 let _ = self
192 .out_tx_c
193 .do_work(PubsubEvent::Unsubscribe(topic.into()));
194 }
195
196 pub fn punsubscribe<T: Into<String>>(&self, topic: T) {
197 let _ = self
200 .out_tx_c
201 .do_work(PubsubEvent::Punsubscribe(topic.into()));
202 }
203}
204
/// A stream of messages for a single subscription.
///
/// Returned by `PubsubConnection::subscribe` or `PubsubConnection::psubscribe`.
/// When dropped, it sends the matching unsubscribe request (see its `Drop` impl).
#[derive(Debug)]
pub struct PubsubStream {
    // Channel name or pattern this stream was subscribed with.
    topic: String,
    // Receiving half of the subscription's message channel.
    underlying: PubsubStreamInner,
    // Clone of the connection handle, used to unsubscribe on drop.
    con: PubsubConnection,
    // `true` for streams created by `psubscribe`; selects `punsubscribe` on drop.
    is_pattern: bool,
}
212
213impl Stream for PubsubStream {
214 type Item = Result<resp::RespValue, error::Error>;
215
216 #[inline]
217 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
218 self.get_mut().underlying.poll_next_unpin(cx)
219 }
220}
221
222impl Drop for PubsubStream {
223 fn drop(&mut self) {
224 let topic: &str = self.topic.as_ref();
225 if self.is_pattern {
226 self.con.punsubscribe(topic);
227 } else {
228 self.con.unsubscribe(topic);
229 }
230 }
231}
232
#[cfg(test)]
mod test {
    //! Pub/Sub integration tests.
    //!
    //! These tests expect a Redis server on `127.0.0.1:6379`. Messages are published
    //! through a separate connection opened with `client::paired_connect`.

    use std::mem;

    use futures::{try_join, StreamExt, TryStreamExt};

    use crate::{client, resp};

    static SUBSCRIBE_TEST_TOPIC: &str = "test-topic";
    static SUBSCRIBE_TEST_NON_TOPIC: &str = "test-not-topic";

    static UNSUBSCRIBE_TOPIC_1: &str = "test-topic-1";
    static UNSUBSCRIBE_TOPIC_2: &str = "test-topic-2";
    static UNSUBSCRIBE_TOPIC_3: &str = "test-topic-3";

    static RESUBSCRIBE_TOPIC: &str = "test-topic-resubscribe";

    static DROP_CONNECTION_TOPIC: &str = "test-topic-drop-connection";

    static PSUBSCRIBE_PATTERN: &str = "ptest.*";
    static PSUBSCRIBE_TOPIC_1: &str = "ptest.1";
    static PSUBSCRIBE_TOPIC_2: &str = "ptest.2";
    static PSUBSCRIBE_TOPIC_3: &str = "ptest.3";

    static UNSUBSCRIBE_TWICE_TOPIC_1: &str = "test-topic-1-twice";
    static UNSUBSCRIBE_TWICE_TOPIC_2: &str = "test-topic-2-twice";

    /// Messages published to the subscribed topic arrive in order. A message
    /// published to a different topic in between is not delivered to the stream.
    #[tokio::test]
    async fn subscribe_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let topic_messages = pubsub
            .subscribe(SUBSCRIBE_TEST_TOPIC)
            .await
            .expect("Cannot subscribe to topic");

        paired.send_and_forget(resp_array!["PUBLISH", SUBSCRIBE_TEST_TOPIC, "test-message"]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            SUBSCRIBE_TEST_NON_TOPIC,
            "test-message-1.5"
        ]);
        let _: resp::RespValue = paired
            .send(resp_array![
                "PUBLISH",
                SUBSCRIBE_TEST_TOPIC,
                "test-message2"
            ])
            .await
            .expect("Cannot send to topic");

        let result: Vec<_> = topic_messages
            .take(2)
            .try_collect()
            .await
            .expect("Cannot collect two values");

        assert_eq!(result.len(), 2);
        assert_eq!(result[0], "test-message".into());
        assert_eq!(result[1], "test-message2".into());
    }

    /// Topic 2 is unsubscribed explicitly, and topic 3 is unsubscribed by dropping
    /// its stream. After that, topic 1 still receives messages and topic 2's
    /// stream ends.
    #[tokio::test]
    async fn unsubscribe_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let mut topic_1 = pubsub
            .subscribe(UNSUBSCRIBE_TOPIC_1)
            .await
            .expect("Cannot subscribe to topic");
        let mut topic_2 = pubsub
            .subscribe(UNSUBSCRIBE_TOPIC_2)
            .await
            .expect("Cannot subscribe to topic");
        let mut topic_3 = pubsub
            .subscribe(UNSUBSCRIBE_TOPIC_3)
            .await
            .expect("Cannot subscribe to topic");

        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_1,
            "test-message-1"
        ]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_2,
            "test-message-2"
        ]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_3,
            "test-message-3"
        ]);

        let result1 = topic_1
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result1, "test-message-1".into());

        let result2 = topic_2
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result2, "test-message-2".into());

        let result3 = topic_3
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result3, "test-message-3".into());

        pubsub.unsubscribe(UNSUBSCRIBE_TOPIC_2);

        // Dropping the stream triggers its `Drop` impl, which unsubscribes topic 3.
        mem::drop(topic_3);

        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_1,
            "test-message-1.5"
        ]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_2,
            "test-message-2.5"
        ]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TOPIC_3,
            "test-message-3.5"
        ]);

        let result1 = topic_1
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result1, "test-message-1.5".into());

        let result2 = topic_2.next().await;
        assert!(result2.is_none());
    }

    /// After an unsubscribe, the old stream ends. Subscribing to the same topic
    /// again yields a working stream.
    #[tokio::test]
    async fn resubscribe_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let mut topic_1 = pubsub
            .subscribe(RESUBSCRIBE_TOPIC)
            .await
            .expect("Cannot subscribe to topic");

        paired.send_and_forget(resp_array!["PUBLISH", RESUBSCRIBE_TOPIC, "test-message-1"]);

        let result1 = topic_1
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result1, "test-message-1".into());

        pubsub.unsubscribe(RESUBSCRIBE_TOPIC);

        paired.send_and_forget(resp_array![
            "PUBLISH",
            RESUBSCRIBE_TOPIC,
            "test-message-1.5"
        ]);

        let result1 = topic_1.next().await;
        assert!(result1.is_none());

        let mut topic_1 = pubsub
            .subscribe(RESUBSCRIBE_TOPIC)
            .await
            .expect("Cannot subscribe to topic");

        paired.send_and_forget(resp_array![
            "PUBLISH",
            RESUBSCRIBE_TOPIC,
            "test-message-1.75"
        ]);

        let result1 = topic_1
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result1, "test-message-1.75".into());
    }

    /// A stream keeps receiving messages after the `PubsubConnection` it was
    /// created from has been dropped. The stream holds its own clone of the handle.
    #[tokio::test]
    async fn drop_connection_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let mut topic_1 = pubsub
            .subscribe(DROP_CONNECTION_TOPIC)
            .await
            .expect("Cannot subscribe to topic");

        mem::drop(pubsub);

        paired.send_and_forget(resp_array![
            "PUBLISH",
            DROP_CONNECTION_TOPIC,
            "test-message-1"
        ]);

        let result1 = topic_1
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result1, "test-message-1".into());

        mem::drop(topic_1);
    }

    /// A pattern subscription receives, in order, messages published to every
    /// matching channel.
    #[tokio::test]
    async fn psubscribe_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let topic_messages = pubsub
            .psubscribe(PSUBSCRIBE_PATTERN)
            .await
            .expect("Cannot subscribe to topic");

        paired.send_and_forget(resp_array!["PUBLISH", PSUBSCRIBE_TOPIC_1, "test-message-1"]);
        paired.send_and_forget(resp_array!["PUBLISH", PSUBSCRIBE_TOPIC_2, "test-message-2"]);
        let _: resp::RespValue = paired
            .send(resp_array!["PUBLISH", PSUBSCRIBE_TOPIC_3, "test-message-3"])
            .await
            .expect("Cannot send to topic");

        let result: Vec<_> = topic_messages
            .take(3)
            .try_collect()
            .await
            .expect("Cannot collect three values");

        assert_eq!(result.len(), 3);
        assert_eq!(result[0], "test-message-1".into());
        assert_eq!(result[1], "test-message-2".into());
        assert_eq!(result[2], "test-message-3".into());
    }

    /// Unsubscribing the same topic twice is tolerated. Messages already
    /// delivered to each stream are still yielded before the stream ends.
    #[tokio::test]
    async fn unsubscribe_twice_test() {
        let paired_c = client::paired_connect("127.0.0.1", 6379);
        let pubsub_c = super::pubsub_connect("127.0.0.1", 6379);
        let (paired, pubsub) = try_join!(paired_c, pubsub_c).expect("Cannot connect to Redis");

        let mut topic_1 = pubsub
            .subscribe(UNSUBSCRIBE_TWICE_TOPIC_1)
            .await
            .expect("Cannot subscribe to topic");
        let mut topic_2 = pubsub
            .subscribe(UNSUBSCRIBE_TWICE_TOPIC_2)
            .await
            .expect("Cannot subscribe to topic");

        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TWICE_TOPIC_1,
            "test-message-1"
        ]);
        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TWICE_TOPIC_2,
            "test-message-2"
        ]);

        pubsub.unsubscribe(UNSUBSCRIBE_TWICE_TOPIC_2);
        pubsub.unsubscribe(UNSUBSCRIBE_TWICE_TOPIC_2);

        paired.send_and_forget(resp_array![
            "PUBLISH",
            UNSUBSCRIBE_TWICE_TOPIC_1,
            "test-message-1.5"
        ]);

        pubsub.unsubscribe(UNSUBSCRIBE_TWICE_TOPIC_1);

        let result1 = topic_1
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result1, "test-message-1".into());

        let result1 = topic_1
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result1, "test-message-1.5".into());

        let result2 = topic_2
            .next()
            .await
            .expect("Cannot get next value")
            .expect("Cannot get next value");
        assert_eq!(result2, "test-message-2".into());

        let result1 = topic_1.next().await;
        assert!(result1.is_none());

        let result2 = topic_2.next().await;
        assert!(result2.is_none());
    }
}