redis_oxide/pubsub.rs
1//! Pub/Sub support for Redis
2//!
3//! This module provides functionality for Redis publish/subscribe messaging.
4//! Redis Pub/Sub allows you to send messages between different parts of your
5//! application or between different applications.
6//!
7//! # Examples
8//!
9//! ## Publisher
10//!
11//! ```no_run
12//! use redis_oxide::{Client, ConnectionConfig};
13//!
14//! # #[tokio::main]
15//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
16//! let config = ConnectionConfig::new("redis://localhost:6379");
17//! let client = Client::connect(config).await?;
18//!
19//! // Publish a message to a channel
20//! let subscribers = client.publish("news", "Breaking news!").await?;
21//! println!("Message sent to {} subscribers", subscribers);
22//! # Ok(())
23//! # }
24//! ```
25//!
26//! ## Subscriber
27//!
28//! ```no_run
29//! use redis_oxide::{Client, ConnectionConfig};
30//! use futures::StreamExt;
31//!
32//! # #[tokio::main]
33//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
34//! let config = ConnectionConfig::new("redis://localhost:6379");
35//! let client = Client::connect(config).await?;
36//!
37//! // Subscribe to channels
38//! let mut subscriber = client.subscriber().await?;
39//! subscriber.subscribe(vec!["news".to_string(), "updates".to_string()]).await?;
40//!
41//! // Listen for messages
42//! while let Some(message) = subscriber.next_message().await? {
43//! println!("Received: {} on channel {}", message.payload, message.channel);
44//! }
45//! # Ok(())
46//! # }
47//! ```
48
49use crate::core::{
50 error::{RedisError, RedisResult},
51 value::RespValue,
52};
53use futures_util::Stream;
54use std::collections::HashMap;
55use std::pin::Pin;
56use std::sync::Arc;
57use std::task::{Context, Poll};
58use tokio::sync::{mpsc, Mutex};
59use tokio::time::{timeout, Duration};
60
/// A message delivered to a subscriber by Redis.
///
/// The same type is used for messages arriving through plain channel
/// subscriptions and through pattern subscriptions; in the latter case
/// [`pattern`](Self::pattern) holds the pattern that matched.
///
/// Messages compare by value (`PartialEq`/`Eq`), so they can be checked
/// directly with `==` or `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PubSubMessage {
    /// The channel the message was received on
    pub channel: String,
    /// The message payload
    pub payload: String,
    /// The pattern that matched (for pattern subscriptions); `None` otherwise
    pub pattern: Option<String>,
}
71
72/// Redis Pub/Sub subscriber
73pub struct Subscriber {
74 connection: Arc<Mutex<dyn PubSubConnection + Send + Sync>>,
75 message_rx: mpsc::UnboundedReceiver<PubSubMessage>,
76 subscribed_channels: HashMap<String, bool>,
77 subscribed_patterns: HashMap<String, bool>,
78}
79
80/// Trait for Pub/Sub connections
81#[async_trait::async_trait]
82pub trait PubSubConnection {
83 /// Subscribe to channels
84 async fn subscribe(&mut self, channels: Vec<String>) -> RedisResult<()>;
85
86 /// Unsubscribe from channels
87 async fn unsubscribe(&mut self, channels: Vec<String>) -> RedisResult<()>;
88
89 /// Subscribe to patterns
90 async fn psubscribe(&mut self, patterns: Vec<String>) -> RedisResult<()>;
91
92 /// Unsubscribe from patterns
93 async fn punsubscribe(&mut self, patterns: Vec<String>) -> RedisResult<()>;
94
95 /// Start listening for messages
96 async fn listen(&mut self, message_tx: mpsc::UnboundedSender<PubSubMessage>)
97 -> RedisResult<()>;
98
99 /// Publish a message to a channel
100 async fn publish(&mut self, channel: String, message: String) -> RedisResult<i64>;
101}
102
103impl Subscriber {
104 /// Create a new subscriber
105 pub fn new(connection: Arc<Mutex<dyn PubSubConnection + Send + Sync>>) -> Self {
106 let (message_tx, message_rx) = mpsc::unbounded_channel();
107
108 // Start listening for messages in the background
109 let conn_clone = connection.clone();
110 tokio::spawn(async move {
111 let mut conn = conn_clone.lock().await;
112 if let Err(e) = conn.listen(message_tx).await {
113 eprintln!("Pub/Sub listener error: {}", e);
114 }
115 });
116
117 Self {
118 connection,
119 message_rx,
120 subscribed_channels: HashMap::new(),
121 subscribed_patterns: HashMap::new(),
122 }
123 }
124
125 /// Subscribe to one or more channels
126 ///
127 /// # Examples
128 ///
129 /// ```no_run
130 /// # use redis_oxide::{Client, ConnectionConfig};
131 /// # #[tokio::main]
132 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
133 /// # let config = ConnectionConfig::new("redis://localhost:6379");
134 /// # let client = Client::connect(config).await?;
135 /// let mut subscriber = client.subscriber().await?;
136 ///
137 /// // Subscribe to multiple channels
138 /// subscriber.subscribe(vec![
139 /// "news".to_string(),
140 /// "updates".to_string(),
141 /// "alerts".to_string()
142 /// ]).await?;
143 /// # Ok(())
144 /// # }
145 /// ```
146 pub async fn subscribe(&mut self, channels: Vec<String>) -> RedisResult<()> {
147 let mut connection = self.connection.lock().await;
148 connection.subscribe(channels.clone()).await?;
149
150 for channel in channels {
151 self.subscribed_channels.insert(channel, true);
152 }
153
154 Ok(())
155 }
156
157 /// Unsubscribe from one or more channels
158 ///
159 /// # Examples
160 ///
161 /// ```no_run
162 /// # use redis_oxide::{Client, ConnectionConfig};
163 /// # #[tokio::main]
164 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
165 /// # let config = ConnectionConfig::new("redis://localhost:6379");
166 /// # let client = Client::connect(config).await?;
167 /// let mut subscriber = client.subscriber().await?;
168 /// subscriber.subscribe(vec!["news".to_string()]).await?;
169 ///
170 /// // Later, unsubscribe
171 /// subscriber.unsubscribe(vec!["news".to_string()]).await?;
172 /// # Ok(())
173 /// # }
174 /// ```
175 pub async fn unsubscribe(&mut self, channels: Vec<String>) -> RedisResult<()> {
176 let mut connection = self.connection.lock().await;
177 connection.unsubscribe(channels.clone()).await?;
178
179 for channel in channels {
180 self.subscribed_channels.remove(&channel);
181 }
182
183 Ok(())
184 }
185
186 /// Subscribe to one or more patterns
187 ///
188 /// Patterns support glob-style matching:
189 /// - `*` matches any sequence of characters
190 /// - `?` matches any single character
191 /// - `[abc]` matches any character in the set
192 ///
193 /// # Examples
194 ///
195 /// ```no_run
196 /// # use redis_oxide::{Client, ConnectionConfig};
197 /// # #[tokio::main]
198 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
199 /// # let config = ConnectionConfig::new("redis://localhost:6379");
200 /// # let client = Client::connect(config).await?;
201 /// let mut subscriber = client.subscriber().await?;
202 ///
203 /// // Subscribe to all channels starting with "news"
204 /// subscriber.psubscribe(vec!["news*".to_string()]).await?;
205 ///
206 /// // Subscribe to all channels ending with "log"
207 /// subscriber.psubscribe(vec!["*log".to_string()]).await?;
208 /// # Ok(())
209 /// # }
210 /// ```
211 pub async fn psubscribe(&mut self, patterns: Vec<String>) -> RedisResult<()> {
212 let mut connection = self.connection.lock().await;
213 connection.psubscribe(patterns.clone()).await?;
214
215 for pattern in patterns {
216 self.subscribed_patterns.insert(pattern, true);
217 }
218
219 Ok(())
220 }
221
222 /// Unsubscribe from one or more patterns
223 pub async fn punsubscribe(&mut self, patterns: Vec<String>) -> RedisResult<()> {
224 let mut connection = self.connection.lock().await;
225 connection.punsubscribe(patterns.clone()).await?;
226
227 for pattern in patterns {
228 self.subscribed_patterns.remove(&pattern);
229 }
230
231 Ok(())
232 }
233
234 /// Get the next message from subscribed channels
235 ///
236 /// This method will block until a message is received or an error occurs.
237 ///
238 /// # Examples
239 ///
240 /// ```no_run
241 /// # use redis_oxide::{Client, ConnectionConfig};
242 /// # #[tokio::main]
243 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
244 /// # let config = ConnectionConfig::new("redis://localhost:6379");
245 /// # let client = Client::connect(config).await?;
246 /// let mut subscriber = client.subscriber().await?;
247 /// subscriber.subscribe(vec!["news".to_string()]).await?;
248 ///
249 /// // Wait for the next message
250 /// if let Some(message) = subscriber.next_message().await? {
251 /// println!("Received: {} on {}", message.payload, message.channel);
252 /// }
253 /// # Ok(())
254 /// # }
255 /// ```
256 pub async fn next_message(&mut self) -> RedisResult<Option<PubSubMessage>> {
257 match self.message_rx.recv().await {
258 Some(message) => Ok(Some(message)),
259 None => Ok(None), // Channel closed
260 }
261 }
262
263 /// Get the next message with a timeout
264 ///
265 /// # Examples
266 ///
267 /// ```no_run
268 /// # use redis_oxide::{Client, ConnectionConfig};
269 /// # use std::time::Duration;
270 /// # #[tokio::main]
271 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
272 /// # let config = ConnectionConfig::new("redis://localhost:6379");
273 /// # let client = Client::connect(config).await?;
274 /// let mut subscriber = client.subscriber().await?;
275 /// subscriber.subscribe(vec!["news".to_string()]).await?;
276 ///
277 /// // Wait for a message with 5 second timeout
278 /// match subscriber.next_message_timeout(Duration::from_secs(5)).await? {
279 /// Some(message) => println!("Received: {}", message.payload),
280 /// None => println!("No message received within timeout"),
281 /// }
282 /// # Ok(())
283 /// # }
284 /// ```
285 pub async fn next_message_timeout(
286 &mut self,
287 duration: Duration,
288 ) -> RedisResult<Option<PubSubMessage>> {
289 match timeout(duration, self.message_rx.recv()).await {
290 Ok(Some(message)) => Ok(Some(message)),
291 Ok(None) => Ok(None), // Channel closed
292 Err(_) => Ok(None), // Timeout
293 }
294 }
295
296 /// Get a list of currently subscribed channels
297 #[must_use]
298 pub fn subscribed_channels(&self) -> Vec<String> {
299 self.subscribed_channels.keys().cloned().collect()
300 }
301
302 /// Get a list of currently subscribed patterns
303 #[must_use]
304 pub fn subscribed_patterns(&self) -> Vec<String> {
305 self.subscribed_patterns.keys().cloned().collect()
306 }
307
308 /// Check if subscribed to a specific channel
309 #[must_use]
310 pub fn is_subscribed_to_channel(&self, channel: &str) -> bool {
311 self.subscribed_channels.contains_key(channel)
312 }
313
314 /// Check if subscribed to a specific pattern
315 #[must_use]
316 pub fn is_subscribed_to_pattern(&self, pattern: &str) -> bool {
317 self.subscribed_patterns.contains_key(pattern)
318 }
319}
320
321/// Stream implementation for Subscriber
322impl Stream for Subscriber {
323 type Item = RedisResult<PubSubMessage>;
324
325 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
326 match self.message_rx.poll_recv(cx) {
327 Poll::Ready(Some(message)) => Poll::Ready(Some(Ok(message))),
328 Poll::Ready(None) => Poll::Ready(None), // Channel closed
329 Poll::Pending => Poll::Pending,
330 }
331 }
332}
333
334/// Publisher for sending messages to Redis channels
335pub struct Publisher {
336 connection: Arc<Mutex<dyn PubSubConnection + Send + Sync>>,
337}
338
339impl Publisher {
340 /// Create a new publisher
341 pub fn new(connection: Arc<Mutex<dyn PubSubConnection + Send + Sync>>) -> Self {
342 Self { connection }
343 }
344
345 /// Publish a message to a channel
346 ///
347 /// Returns the number of subscribers that received the message.
348 ///
349 /// # Examples
350 ///
351 /// ```no_run
352 /// # use redis_oxide::{Client, ConnectionConfig};
353 /// # #[tokio::main]
354 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
355 /// # let config = ConnectionConfig::new("redis://localhost:6379");
356 /// # let client = Client::connect(config).await?;
357 /// let publisher = client.publisher().await?;
358 ///
359 /// let subscribers = publisher.publish("news", "Breaking news!").await?;
360 /// println!("Message delivered to {} subscribers", subscribers);
361 /// # Ok(())
362 /// # }
363 /// ```
364 pub async fn publish(
365 &self,
366 channel: impl Into<String>,
367 message: impl Into<String>,
368 ) -> RedisResult<i64> {
369 let mut connection = self.connection.lock().await;
370 connection.publish(channel.into(), message.into()).await
371 }
372
373 /// Publish multiple messages to different channels
374 ///
375 /// # Examples
376 ///
377 /// ```no_run
378 /// # use redis_oxide::{Client, ConnectionConfig};
379 /// # use std::collections::HashMap;
380 /// # #[tokio::main]
381 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
382 /// # let config = ConnectionConfig::new("redis://localhost:6379");
383 /// # let client = Client::connect(config).await?;
384 /// let publisher = client.publisher().await?;
385 ///
386 /// let mut messages = HashMap::new();
387 /// messages.insert("news".to_string(), "Breaking news!".to_string());
388 /// messages.insert("updates".to_string(), "System update available".to_string());
389 ///
390 /// let results = publisher.publish_multiple(messages).await?;
391 /// for (channel, count) in results {
392 /// println!("Channel {}: {} subscribers", channel, count);
393 /// }
394 /// # Ok(())
395 /// # }
396 /// ```
397 pub async fn publish_multiple(
398 &self,
399 messages: HashMap<String, String>,
400 ) -> RedisResult<HashMap<String, i64>> {
401 let mut results = HashMap::new();
402
403 for (channel, message) in messages {
404 let count = self.publish(&channel, message).await?;
405 results.insert(channel, count);
406 }
407
408 Ok(results)
409 }
410}
411
/// Kinds of Pub/Sub frames, identified by the keyword in a frame's first element.
#[derive(Debug)]
enum PubSubMessageType {
    /// `subscribe` keyword
    Subscribe,
    /// `unsubscribe` keyword
    Unsubscribe,
    /// `message` keyword
    Message,
    /// `psubscribe` keyword
    PSubscribe,
    /// `punsubscribe` keyword
    PUnsubscribe,
    /// `pmessage` keyword
    PMessage,
}

impl PubSubMessageType {
    /// Map a frame keyword to its kind.
    ///
    /// Matching is exact and case-sensitive; any other input yields `None`.
    fn from_str(s: &str) -> Option<Self> {
        let kind = match s {
            "subscribe" => Self::Subscribe,
            "unsubscribe" => Self::Unsubscribe,
            "message" => Self::Message,
            "psubscribe" => Self::PSubscribe,
            "punsubscribe" => Self::PUnsubscribe,
            "pmessage" => Self::PMessage,
            _ => return None,
        };
        Some(kind)
    }
}
436
437/// Parse a Pub/Sub message from Redis response
438pub fn parse_pubsub_message(response: RespValue) -> RedisResult<Option<PubSubMessage>> {
439 match response {
440 RespValue::Array(items) if items.len() >= 3 => {
441 let message_type = items[0].as_string()?;
442 let msg_type = PubSubMessageType::from_str(&message_type);
443
444 match msg_type {
445 Some(PubSubMessageType::Message) => {
446 let channel = items[1].as_string()?;
447 let payload = items[2].as_string()?;
448
449 Ok(Some(PubSubMessage {
450 channel,
451 payload,
452 pattern: None,
453 }))
454 }
455 Some(PubSubMessageType::PMessage) if items.len() >= 4 => {
456 let pattern = items[1].as_string()?;
457 let channel = items[2].as_string()?;
458 let payload = items[3].as_string()?;
459
460 Ok(Some(PubSubMessage {
461 channel,
462 payload,
463 pattern: Some(pattern),
464 }))
465 }
466 Some(
467 PubSubMessageType::Subscribe
468 | PubSubMessageType::Unsubscribe
469 | PubSubMessageType::PSubscribe
470 | PubSubMessageType::PUnsubscribe,
471 ) => {
472 // These are subscription confirmations, not actual messages
473 Ok(None)
474 }
475 _ => Err(RedisError::Protocol(format!(
476 "Unknown pub/sub message type: {}",
477 message_type
478 ))),
479 }
480 }
481 _ => Err(RedisError::Protocol(format!(
482 "Invalid pub/sub message format: {:?}",
483 response
484 ))),
485 }
486}
487
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// In-memory stand-in for a Redis connection that only records the calls
    /// it receives.
    #[derive(Default)]
    struct MockPubSubConnection {
        published_messages: Vec<(String, String)>,
        subscribed_channels: Vec<String>,
        subscribed_patterns: Vec<String>,
    }

    impl MockPubSubConnection {
        fn new() -> Self {
            Self::default()
        }
    }

    #[async_trait::async_trait]
    impl PubSubConnection for MockPubSubConnection {
        async fn subscribe(&mut self, channels: Vec<String>) -> RedisResult<()> {
            self.subscribed_channels.extend(channels);
            Ok(())
        }

        async fn unsubscribe(&mut self, channels: Vec<String>) -> RedisResult<()> {
            self.subscribed_channels
                .retain(|name| !channels.contains(name));
            Ok(())
        }

        async fn psubscribe(&mut self, patterns: Vec<String>) -> RedisResult<()> {
            self.subscribed_patterns.extend(patterns);
            Ok(())
        }

        async fn punsubscribe(&mut self, patterns: Vec<String>) -> RedisResult<()> {
            self.subscribed_patterns
                .retain(|name| !patterns.contains(name));
            Ok(())
        }

        async fn listen(
            &mut self,
            _message_tx: mpsc::UnboundedSender<PubSubMessage>,
        ) -> RedisResult<()> {
            // Mock implementation - would normally listen for messages
            Ok(())
        }

        async fn publish(&mut self, channel: String, message: String) -> RedisResult<i64> {
            self.published_messages.push((channel, message));
            Ok(1) // Mock: 1 subscriber
        }
    }

    /// Build a subscriber backed by a fresh mock connection.
    fn mock_subscriber() -> Subscriber {
        Subscriber::new(Arc::new(Mutex::new(MockPubSubConnection::new())))
    }

    /// Turn string literals into the owned names the API expects.
    fn names(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[tokio::test]
    async fn test_subscriber_creation() {
        let subscriber = mock_subscriber();

        assert!(subscriber.subscribed_channels().is_empty());
        assert!(subscriber.subscribed_patterns().is_empty());
    }

    #[tokio::test]
    async fn test_subscriber_subscribe() {
        let mut subscriber = mock_subscriber();

        subscriber
            .subscribe(names(&["news", "updates"]))
            .await
            .unwrap();

        assert_eq!(subscriber.subscribed_channels().len(), 2);
        for expected in ["news", "updates"] {
            assert!(subscriber.is_subscribed_to_channel(expected));
        }
    }

    #[tokio::test]
    async fn test_subscriber_unsubscribe() {
        let mut subscriber = mock_subscriber();

        subscriber
            .subscribe(names(&["news", "updates"]))
            .await
            .unwrap();
        subscriber.unsubscribe(names(&["news"])).await.unwrap();

        assert_eq!(subscriber.subscribed_channels().len(), 1);
        assert!(!subscriber.is_subscribed_to_channel("news"));
        assert!(subscriber.is_subscribed_to_channel("updates"));
    }

    #[tokio::test]
    async fn test_publisher_publish() {
        let publisher = Publisher::new(Arc::new(Mutex::new(MockPubSubConnection::new())));

        let count = publisher.publish("news", "Breaking news!").await.unwrap();
        assert_eq!(count, 1);
    }

    #[test]
    fn test_parse_pubsub_message() {
        // Regular channel message: [kind, channel, payload]
        let frame = RespValue::Array(vec![
            RespValue::from("message"),
            RespValue::from("news"),
            RespValue::from("Breaking news!"),
        ]);

        let PubSubMessage {
            channel,
            payload,
            pattern,
        } = parse_pubsub_message(frame).unwrap().unwrap();

        assert_eq!(channel, "news");
        assert_eq!(payload, "Breaking news!");
        assert!(pattern.is_none());
    }

    #[test]
    fn test_parse_pubsub_pattern_message() {
        // Pattern message: [kind, pattern, channel, payload]
        let frame = RespValue::Array(vec![
            RespValue::from("pmessage"),
            RespValue::from("news*"),
            RespValue::from("news-tech"),
            RespValue::from("Tech news!"),
        ]);

        let PubSubMessage {
            channel,
            payload,
            pattern,
        } = parse_pubsub_message(frame).unwrap().unwrap();

        assert_eq!(channel, "news-tech");
        assert_eq!(payload, "Tech news!");
        assert_eq!(pattern, Some("news*".to_string()));
    }

    #[test]
    fn test_parse_pubsub_subscribe_confirmation() {
        // Subscription confirmations carry no message and must parse to `None`
        let frame = RespValue::Array(vec![
            RespValue::from("subscribe"),
            RespValue::from("news"),
            RespValue::Integer(1),
        ]);

        let parsed = parse_pubsub_message(frame).unwrap();
        assert!(parsed.is_none());
    }
}
645}