redis_oxide/
pubsub.rs

1//! Pub/Sub support for Redis
2//!
3//! This module provides functionality for Redis publish/subscribe messaging.
4//! Redis Pub/Sub allows you to send messages between different parts of your
5//! application or between different applications.
6//!
7//! # Examples
8//!
9//! ## Publisher
10//!
11//! ```no_run
12//! use redis_oxide::{Client, ConnectionConfig};
13//!
14//! # #[tokio::main]
15//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
16//! let config = ConnectionConfig::new("redis://localhost:6379");
17//! let client = Client::connect(config).await?;
18//!
19//! // Publish a message to a channel
20//! let subscribers = client.publish("news", "Breaking news!").await?;
21//! println!("Message sent to {} subscribers", subscribers);
22//! # Ok(())
23//! # }
24//! ```
25//!
26//! ## Subscriber
27//!
28//! ```no_run
29//! use redis_oxide::{Client, ConnectionConfig};
30//! use futures::StreamExt;
31//!
32//! # #[tokio::main]
33//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
34//! let config = ConnectionConfig::new("redis://localhost:6379");
35//! let client = Client::connect(config).await?;
36//!
37//! // Subscribe to channels
38//! let mut subscriber = client.subscriber().await?;
39//! subscriber.subscribe(vec!["news".to_string(), "updates".to_string()]).await?;
40//!
41//! // Listen for messages
42//! while let Some(message) = subscriber.next_message().await? {
43//!     println!("Received: {} on channel {}", message.payload, message.channel);
44//! }
45//! # Ok(())
46//! # }
47//! ```
48
49use crate::core::{
50    error::{RedisError, RedisResult},
51    value::RespValue,
52};
53use futures_util::Stream;
54use std::collections::HashMap;
55use std::pin::Pin;
56use std::sync::Arc;
57use std::task::{Context, Poll};
58use tokio::sync::{mpsc, Mutex};
59use tokio::time::{timeout, Duration};
60
/// A message received from a Redis channel.
///
/// [`parse_pubsub_message`] builds one of these for `message` and `pmessage`
/// replies.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PubSubMessage {
    /// The channel the message was received on
    pub channel: String,
    /// The message payload
    pub payload: String,
    /// The pattern that matched, for messages delivered through a pattern
    /// subscription; `None` for plain channel subscriptions.
    pub pattern: Option<String>,
}
71
72/// Redis Pub/Sub subscriber
73pub struct Subscriber {
74    connection: Arc<Mutex<dyn PubSubConnection + Send + Sync>>,
75    message_rx: mpsc::UnboundedReceiver<PubSubMessage>,
76    subscribed_channels: HashMap<String, bool>,
77    subscribed_patterns: HashMap<String, bool>,
78}
79
80/// Trait for Pub/Sub connections
81#[async_trait::async_trait]
82pub trait PubSubConnection {
83    /// Subscribe to channels
84    async fn subscribe(&mut self, channels: Vec<String>) -> RedisResult<()>;
85
86    /// Unsubscribe from channels
87    async fn unsubscribe(&mut self, channels: Vec<String>) -> RedisResult<()>;
88
89    /// Subscribe to patterns
90    async fn psubscribe(&mut self, patterns: Vec<String>) -> RedisResult<()>;
91
92    /// Unsubscribe from patterns
93    async fn punsubscribe(&mut self, patterns: Vec<String>) -> RedisResult<()>;
94
95    /// Start listening for messages
96    async fn listen(&mut self, message_tx: mpsc::UnboundedSender<PubSubMessage>)
97        -> RedisResult<()>;
98
99    /// Publish a message to a channel
100    async fn publish(&mut self, channel: String, message: String) -> RedisResult<i64>;
101}
102
103impl Subscriber {
104    /// Create a new subscriber
105    pub fn new(connection: Arc<Mutex<dyn PubSubConnection + Send + Sync>>) -> Self {
106        let (message_tx, message_rx) = mpsc::unbounded_channel();
107
108        // Start listening for messages in the background
109        let conn_clone = connection.clone();
110        tokio::spawn(async move {
111            let mut conn = conn_clone.lock().await;
112            if let Err(e) = conn.listen(message_tx).await {
113                eprintln!("Pub/Sub listener error: {}", e);
114            }
115        });
116
117        Self {
118            connection,
119            message_rx,
120            subscribed_channels: HashMap::new(),
121            subscribed_patterns: HashMap::new(),
122        }
123    }
124
125    /// Subscribe to one or more channels
126    ///
127    /// # Examples
128    ///
129    /// ```no_run
130    /// # use redis_oxide::{Client, ConnectionConfig};
131    /// # #[tokio::main]
132    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
133    /// # let config = ConnectionConfig::new("redis://localhost:6379");
134    /// # let client = Client::connect(config).await?;
135    /// let mut subscriber = client.subscriber().await?;
136    ///
137    /// // Subscribe to multiple channels
138    /// subscriber.subscribe(vec![
139    ///     "news".to_string(),
140    ///     "updates".to_string(),
141    ///     "alerts".to_string()
142    /// ]).await?;
143    /// # Ok(())
144    /// # }
145    /// ```
146    pub async fn subscribe(&mut self, channels: Vec<String>) -> RedisResult<()> {
147        let mut connection = self.connection.lock().await;
148        connection.subscribe(channels.clone()).await?;
149
150        for channel in channels {
151            self.subscribed_channels.insert(channel, true);
152        }
153
154        Ok(())
155    }
156
157    /// Unsubscribe from one or more channels
158    ///
159    /// # Examples
160    ///
161    /// ```no_run
162    /// # use redis_oxide::{Client, ConnectionConfig};
163    /// # #[tokio::main]
164    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
165    /// # let config = ConnectionConfig::new("redis://localhost:6379");
166    /// # let client = Client::connect(config).await?;
167    /// let mut subscriber = client.subscriber().await?;
168    /// subscriber.subscribe(vec!["news".to_string()]).await?;
169    ///
170    /// // Later, unsubscribe
171    /// subscriber.unsubscribe(vec!["news".to_string()]).await?;
172    /// # Ok(())
173    /// # }
174    /// ```
175    pub async fn unsubscribe(&mut self, channels: Vec<String>) -> RedisResult<()> {
176        let mut connection = self.connection.lock().await;
177        connection.unsubscribe(channels.clone()).await?;
178
179        for channel in channels {
180            self.subscribed_channels.remove(&channel);
181        }
182
183        Ok(())
184    }
185
186    /// Subscribe to one or more patterns
187    ///
188    /// Patterns support glob-style matching:
189    /// - `*` matches any sequence of characters
190    /// - `?` matches any single character
191    /// - `[abc]` matches any character in the set
192    ///
193    /// # Examples
194    ///
195    /// ```no_run
196    /// # use redis_oxide::{Client, ConnectionConfig};
197    /// # #[tokio::main]
198    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
199    /// # let config = ConnectionConfig::new("redis://localhost:6379");
200    /// # let client = Client::connect(config).await?;
201    /// let mut subscriber = client.subscriber().await?;
202    ///
203    /// // Subscribe to all channels starting with "news"
204    /// subscriber.psubscribe(vec!["news*".to_string()]).await?;
205    ///
206    /// // Subscribe to all channels ending with "log"
207    /// subscriber.psubscribe(vec!["*log".to_string()]).await?;
208    /// # Ok(())
209    /// # }
210    /// ```
211    pub async fn psubscribe(&mut self, patterns: Vec<String>) -> RedisResult<()> {
212        let mut connection = self.connection.lock().await;
213        connection.psubscribe(patterns.clone()).await?;
214
215        for pattern in patterns {
216            self.subscribed_patterns.insert(pattern, true);
217        }
218
219        Ok(())
220    }
221
222    /// Unsubscribe from one or more patterns
223    pub async fn punsubscribe(&mut self, patterns: Vec<String>) -> RedisResult<()> {
224        let mut connection = self.connection.lock().await;
225        connection.punsubscribe(patterns.clone()).await?;
226
227        for pattern in patterns {
228            self.subscribed_patterns.remove(&pattern);
229        }
230
231        Ok(())
232    }
233
234    /// Get the next message from subscribed channels
235    ///
236    /// This method will block until a message is received or an error occurs.
237    ///
238    /// # Examples
239    ///
240    /// ```no_run
241    /// # use redis_oxide::{Client, ConnectionConfig};
242    /// # #[tokio::main]
243    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
244    /// # let config = ConnectionConfig::new("redis://localhost:6379");
245    /// # let client = Client::connect(config).await?;
246    /// let mut subscriber = client.subscriber().await?;
247    /// subscriber.subscribe(vec!["news".to_string()]).await?;
248    ///
249    /// // Wait for the next message
250    /// if let Some(message) = subscriber.next_message().await? {
251    ///     println!("Received: {} on {}", message.payload, message.channel);
252    /// }
253    /// # Ok(())
254    /// # }
255    /// ```
256    pub async fn next_message(&mut self) -> RedisResult<Option<PubSubMessage>> {
257        match self.message_rx.recv().await {
258            Some(message) => Ok(Some(message)),
259            None => Ok(None), // Channel closed
260        }
261    }
262
263    /// Get the next message with a timeout
264    ///
265    /// # Examples
266    ///
267    /// ```no_run
268    /// # use redis_oxide::{Client, ConnectionConfig};
269    /// # use std::time::Duration;
270    /// # #[tokio::main]
271    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
272    /// # let config = ConnectionConfig::new("redis://localhost:6379");
273    /// # let client = Client::connect(config).await?;
274    /// let mut subscriber = client.subscriber().await?;
275    /// subscriber.subscribe(vec!["news".to_string()]).await?;
276    ///
277    /// // Wait for a message with 5 second timeout
278    /// match subscriber.next_message_timeout(Duration::from_secs(5)).await? {
279    ///     Some(message) => println!("Received: {}", message.payload),
280    ///     None => println!("No message received within timeout"),
281    /// }
282    /// # Ok(())
283    /// # }
284    /// ```
285    pub async fn next_message_timeout(
286        &mut self,
287        duration: Duration,
288    ) -> RedisResult<Option<PubSubMessage>> {
289        match timeout(duration, self.message_rx.recv()).await {
290            Ok(Some(message)) => Ok(Some(message)),
291            Ok(None) => Ok(None), // Channel closed
292            Err(_) => Ok(None),   // Timeout
293        }
294    }
295
296    /// Get a list of currently subscribed channels
297    #[must_use]
298    pub fn subscribed_channels(&self) -> Vec<String> {
299        self.subscribed_channels.keys().cloned().collect()
300    }
301
302    /// Get a list of currently subscribed patterns
303    #[must_use]
304    pub fn subscribed_patterns(&self) -> Vec<String> {
305        self.subscribed_patterns.keys().cloned().collect()
306    }
307
308    /// Check if subscribed to a specific channel
309    #[must_use]
310    pub fn is_subscribed_to_channel(&self, channel: &str) -> bool {
311        self.subscribed_channels.contains_key(channel)
312    }
313
314    /// Check if subscribed to a specific pattern
315    #[must_use]
316    pub fn is_subscribed_to_pattern(&self, pattern: &str) -> bool {
317        self.subscribed_patterns.contains_key(pattern)
318    }
319}
320
321/// Stream implementation for Subscriber
322impl Stream for Subscriber {
323    type Item = RedisResult<PubSubMessage>;
324
325    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
326        match self.message_rx.poll_recv(cx) {
327            Poll::Ready(Some(message)) => Poll::Ready(Some(Ok(message))),
328            Poll::Ready(None) => Poll::Ready(None), // Channel closed
329            Poll::Pending => Poll::Pending,
330        }
331    }
332}
333
334/// Publisher for sending messages to Redis channels
335pub struct Publisher {
336    connection: Arc<Mutex<dyn PubSubConnection + Send + Sync>>,
337}
338
339impl Publisher {
340    /// Create a new publisher
341    pub fn new(connection: Arc<Mutex<dyn PubSubConnection + Send + Sync>>) -> Self {
342        Self { connection }
343    }
344
345    /// Publish a message to a channel
346    ///
347    /// Returns the number of subscribers that received the message.
348    ///
349    /// # Examples
350    ///
351    /// ```no_run
352    /// # use redis_oxide::{Client, ConnectionConfig};
353    /// # #[tokio::main]
354    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
355    /// # let config = ConnectionConfig::new("redis://localhost:6379");
356    /// # let client = Client::connect(config).await?;
357    /// let publisher = client.publisher().await?;
358    ///
359    /// let subscribers = publisher.publish("news", "Breaking news!").await?;
360    /// println!("Message delivered to {} subscribers", subscribers);
361    /// # Ok(())
362    /// # }
363    /// ```
364    pub async fn publish(
365        &self,
366        channel: impl Into<String>,
367        message: impl Into<String>,
368    ) -> RedisResult<i64> {
369        let mut connection = self.connection.lock().await;
370        connection.publish(channel.into(), message.into()).await
371    }
372
373    /// Publish multiple messages to different channels
374    ///
375    /// # Examples
376    ///
377    /// ```no_run
378    /// # use redis_oxide::{Client, ConnectionConfig};
379    /// # use std::collections::HashMap;
380    /// # #[tokio::main]
381    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
382    /// # let config = ConnectionConfig::new("redis://localhost:6379");
383    /// # let client = Client::connect(config).await?;
384    /// let publisher = client.publisher().await?;
385    ///
386    /// let mut messages = HashMap::new();
387    /// messages.insert("news".to_string(), "Breaking news!".to_string());
388    /// messages.insert("updates".to_string(), "System update available".to_string());
389    ///
390    /// let results = publisher.publish_multiple(messages).await?;
391    /// for (channel, count) in results {
392    ///     println!("Channel {}: {} subscribers", channel, count);
393    /// }
394    /// # Ok(())
395    /// # }
396    /// ```
397    pub async fn publish_multiple(
398        &self,
399        messages: HashMap<String, String>,
400    ) -> RedisResult<HashMap<String, i64>> {
401        let mut results = HashMap::new();
402
403        for (channel, message) in messages {
404            let count = self.publish(&channel, message).await?;
405            results.insert(channel, count);
406        }
407
408        Ok(results)
409    }
410}
411
412/// Pub/Sub message types for internal parsing
413#[derive(Debug)]
414enum PubSubMessageType {
415    Subscribe,
416    Unsubscribe,
417    Message,
418    PSubscribe,
419    PUnsubscribe,
420    PMessage,
421}
422
423impl PubSubMessageType {
424    fn from_str(s: &str) -> Option<Self> {
425        match s {
426            "subscribe" => Some(Self::Subscribe),
427            "unsubscribe" => Some(Self::Unsubscribe),
428            "message" => Some(Self::Message),
429            "psubscribe" => Some(Self::PSubscribe),
430            "punsubscribe" => Some(Self::PUnsubscribe),
431            "pmessage" => Some(Self::PMessage),
432            _ => None,
433        }
434    }
435}
436
437/// Parse a Pub/Sub message from Redis response
438pub fn parse_pubsub_message(response: RespValue) -> RedisResult<Option<PubSubMessage>> {
439    match response {
440        RespValue::Array(items) if items.len() >= 3 => {
441            let message_type = items[0].as_string()?;
442            let msg_type = PubSubMessageType::from_str(&message_type);
443
444            match msg_type {
445                Some(PubSubMessageType::Message) => {
446                    let channel = items[1].as_string()?;
447                    let payload = items[2].as_string()?;
448
449                    Ok(Some(PubSubMessage {
450                        channel,
451                        payload,
452                        pattern: None,
453                    }))
454                }
455                Some(PubSubMessageType::PMessage) if items.len() >= 4 => {
456                    let pattern = items[1].as_string()?;
457                    let channel = items[2].as_string()?;
458                    let payload = items[3].as_string()?;
459
460                    Ok(Some(PubSubMessage {
461                        channel,
462                        payload,
463                        pattern: Some(pattern),
464                    }))
465                }
466                Some(
467                    PubSubMessageType::Subscribe
468                    | PubSubMessageType::Unsubscribe
469                    | PubSubMessageType::PSubscribe
470                    | PubSubMessageType::PUnsubscribe,
471                ) => {
472                    // These are subscription confirmations, not actual messages
473                    Ok(None)
474                }
475                _ => Err(RedisError::Protocol(format!(
476                    "Unknown pub/sub message type: {}",
477                    message_type
478                ))),
479            }
480        }
481        _ => Err(RedisError::Protocol(format!(
482            "Invalid pub/sub message format: {:?}",
483            response
484        ))),
485    }
486}
487
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::Mutex;

    /// In-memory stand-in for a real connection that records every call.
    struct MockPubSubConnection {
        published_messages: Vec<(String, String)>,
        subscribed_channels: Vec<String>,
        subscribed_patterns: Vec<String>,
    }

    impl MockPubSubConnection {
        fn new() -> Self {
            Self {
                published_messages: Vec::new(),
                subscribed_channels: Vec::new(),
                subscribed_patterns: Vec::new(),
            }
        }
    }

    #[async_trait::async_trait]
    impl PubSubConnection for MockPubSubConnection {
        async fn subscribe(&mut self, channels: Vec<String>) -> RedisResult<()> {
            self.subscribed_channels.extend(channels);
            Ok(())
        }

        async fn unsubscribe(&mut self, channels: Vec<String>) -> RedisResult<()> {
            for channel in channels {
                self.subscribed_channels.retain(|c| c != &channel);
            }
            Ok(())
        }

        async fn psubscribe(&mut self, patterns: Vec<String>) -> RedisResult<()> {
            self.subscribed_patterns.extend(patterns);
            Ok(())
        }

        async fn punsubscribe(&mut self, patterns: Vec<String>) -> RedisResult<()> {
            for pattern in patterns {
                self.subscribed_patterns.retain(|p| p != &pattern);
            }
            Ok(())
        }

        async fn listen(
            &mut self,
            _message_tx: mpsc::UnboundedSender<PubSubMessage>,
        ) -> RedisResult<()> {
            // Returns immediately, dropping the sender (closes the message channel).
            Ok(())
        }

        async fn publish(&mut self, channel: String, message: String) -> RedisResult<i64> {
            self.published_messages.push((channel, message));
            Ok(1) // Mock: 1 subscriber
        }
    }

    /// A mock behind the `Arc<Mutex<_>>` that `Subscriber` / `Publisher` expect.
    /// Keeps the concrete type so tests can inspect what the mock recorded.
    fn shared_mock() -> Arc<Mutex<MockPubSubConnection>> {
        Arc::new(Mutex::new(MockPubSubConnection::new()))
    }

    #[tokio::test]
    async fn test_subscriber_creation() {
        let connection = MockPubSubConnection::new();
        let subscriber = Subscriber::new(Arc::new(Mutex::new(connection)));

        assert!(subscriber.subscribed_channels().is_empty());
        assert!(subscriber.subscribed_patterns().is_empty());
    }

    #[tokio::test]
    async fn test_subscriber_subscribe() {
        let conn = shared_mock();
        let mut subscriber = Subscriber::new(conn.clone());

        subscriber
            .subscribe(vec!["news".to_string(), "updates".to_string()])
            .await
            .unwrap();

        assert_eq!(subscriber.subscribed_channels().len(), 2);
        assert!(subscriber.is_subscribed_to_channel("news"));
        assert!(subscriber.is_subscribed_to_channel("updates"));
        assert_eq!(
            conn.lock().await.subscribed_channels,
            vec!["news".to_string(), "updates".to_string()]
        );
    }

    #[tokio::test]
    async fn test_subscriber_unsubscribe() {
        let conn = shared_mock();
        let mut subscriber = Subscriber::new(conn.clone());

        subscriber
            .subscribe(vec!["news".to_string(), "updates".to_string()])
            .await
            .unwrap();
        subscriber
            .unsubscribe(vec!["news".to_string()])
            .await
            .unwrap();

        assert_eq!(subscriber.subscribed_channels().len(), 1);
        assert!(!subscriber.is_subscribed_to_channel("news"));
        assert!(subscriber.is_subscribed_to_channel("updates"));
        assert_eq!(
            conn.lock().await.subscribed_channels,
            vec!["updates".to_string()]
        );
    }

    #[tokio::test]
    async fn test_subscriber_psubscribe_punsubscribe() {
        let conn = shared_mock();
        let mut subscriber = Subscriber::new(conn.clone());

        subscriber
            .psubscribe(vec!["news*".to_string(), "*log".to_string()])
            .await
            .unwrap();
        assert!(subscriber.is_subscribed_to_pattern("news*"));
        assert!(subscriber.is_subscribed_to_pattern("*log"));

        subscriber
            .punsubscribe(vec!["news*".to_string()])
            .await
            .unwrap();

        assert_eq!(subscriber.subscribed_patterns(), vec!["*log".to_string()]);
        assert!(!subscriber.is_subscribed_to_pattern("news*"));
        assert!(subscriber.subscribed_channels().is_empty());
        assert_eq!(
            conn.lock().await.subscribed_patterns,
            vec!["*log".to_string()]
        );
    }

    #[tokio::test]
    async fn test_next_message_timeout_without_messages() {
        let mut subscriber = Subscriber::new(shared_mock());

        // The mock never sends anything: whether the timeout fires or the
        // channel closes first, the result is `Ok(None)`.
        let message = subscriber
            .next_message_timeout(Duration::from_millis(10))
            .await
            .unwrap();
        assert!(message.is_none());
    }

    #[tokio::test]
    async fn test_publisher_publish() {
        let conn = shared_mock();
        let publisher = Publisher::new(conn.clone());

        let count = publisher.publish("news", "Breaking news!").await.unwrap();
        assert_eq!(count, 1);
        assert_eq!(
            conn.lock().await.published_messages,
            vec![("news".to_string(), "Breaking news!".to_string())]
        );
    }

    #[tokio::test]
    async fn test_publisher_publish_multiple() {
        let conn = shared_mock();
        let publisher = Publisher::new(conn.clone());

        let mut messages = HashMap::new();
        messages.insert("news".to_string(), "a".to_string());
        messages.insert("updates".to_string(), "b".to_string());

        let results = publisher.publish_multiple(messages).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results.get("news"), Some(&1));
        assert_eq!(results.get("updates"), Some(&1));

        // HashMap iteration order is unspecified, so compare sorted.
        let mut published = conn.lock().await.published_messages.clone();
        published.sort();
        assert_eq!(
            published,
            vec![
                ("news".to_string(), "a".to_string()),
                ("updates".to_string(), "b".to_string()),
            ]
        );
    }

    #[test]
    fn test_parse_pubsub_message() {
        // Test regular message
        let response = RespValue::Array(vec![
            RespValue::from("message"),
            RespValue::from("news"),
            RespValue::from("Breaking news!"),
        ]);

        let message = parse_pubsub_message(response).unwrap().unwrap();
        assert_eq!(message.channel, "news");
        assert_eq!(message.payload, "Breaking news!");
        assert!(message.pattern.is_none());
    }

    #[test]
    fn test_parse_pubsub_pattern_message() {
        // Test pattern message
        let response = RespValue::Array(vec![
            RespValue::from("pmessage"),
            RespValue::from("news*"),
            RespValue::from("news-tech"),
            RespValue::from("Tech news!"),
        ]);

        let message = parse_pubsub_message(response).unwrap().unwrap();
        assert_eq!(message.channel, "news-tech");
        assert_eq!(message.payload, "Tech news!");
        assert_eq!(message.pattern, Some("news*".to_string()));
    }

    #[test]
    fn test_parse_pubsub_subscribe_confirmation() {
        // Test subscription confirmation (should return None)
        let response = RespValue::Array(vec![
            RespValue::from("subscribe"),
            RespValue::from("news"),
            RespValue::Integer(1),
        ]);

        let message = parse_pubsub_message(response).unwrap();
        assert!(message.is_none());
    }

    #[test]
    fn test_parse_pubsub_rejects_non_array() {
        assert!(parse_pubsub_message(RespValue::Integer(1)).is_err());
    }

    #[test]
    fn test_parse_pubsub_rejects_unknown_type() {
        let response = RespValue::Array(vec![
            RespValue::from("bogus"),
            RespValue::from("news"),
            RespValue::from("payload"),
        ]);

        assert!(parse_pubsub_message(response).is_err());
    }
}