Skip to main content

ember_client/
subscriber.rs

1//! Pub/sub subscriber mode.
2//!
3//! When a connection issues [`Client::subscribe`] or [`Client::psubscribe`] it
4//! enters pub/sub mode: the server will push [`Message`] frames whenever a
5//! matching publish event occurs. Normal request-response commands cannot be
6//! issued on the same connection while it is in sub mode.
7//!
8//! Use a separate [`Client`] for regular commands while subscribed.
9
use std::collections::VecDeque;
use std::fmt;

use bytes::Bytes;
use ember_protocol::types::Frame;

use crate::connection::{Client, ClientError};
14
15/// A message pushed by the server to a subscribed connection.
16#[derive(Debug, Clone)]
17pub struct Message {
18    /// The channel the message was published to.
19    pub channel: Bytes,
20    /// The message payload.
21    pub data: Bytes,
22    /// Set only for pattern-matched messages (`PSUBSCRIBE`). Contains the
23    /// pattern that matched the channel.
24    pub pattern: Option<Bytes>,
25}
26
27/// A connection locked into pub/sub mode.
28///
29/// Obtained by calling [`Client::subscribe`] or [`Client::psubscribe`].
30/// The underlying transport is consumed — create a separate [`Client`]
31/// for regular commands while this subscriber is active.
32///
33/// # Example
34///
35/// ```no_run
36/// use ember_client::Client;
37///
38/// #[tokio::main]
39/// async fn main() -> Result<(), ember_client::ClientError> {
40///     let mut publisher = Client::connect("127.0.0.1", 6379).await?;
41///     let subscriber_conn = Client::connect("127.0.0.1", 6379).await?;
42///
43///     let mut sub = subscriber_conn.subscribe(&["news"]).await?;
44///
45///     publisher.publish("news", "breaking: hello world").await?;
46///
47///     let msg = sub.recv().await?;
48///     println!("got: {:?}", msg.data);
49///     Ok(())
50/// }
51/// ```
52pub struct Subscriber {
53    inner: Client,
54}
55
56impl Subscriber {
57    pub(crate) fn new(inner: Client) -> Self {
58        Self { inner }
59    }
60
61    /// Blocks until the next message arrives on any subscribed channel.
62    ///
63    /// Subscription confirmation frames (`subscribe`/`psubscribe`) are
64    /// skipped silently — only actual message frames are returned.
65    pub async fn recv(&mut self) -> Result<Message, ClientError> {
66        loop {
67            let frame = self.inner.read_response().await?;
68            if let Some(msg) = try_parse_message(frame)? {
69                return Ok(msg);
70            }
71            // confirmation frame (subscribe/unsubscribe/psubscribe/punsubscribe)
72            // — loop back and wait for the next frame
73        }
74    }
75
76    /// Subscribes to additional channels without leaving sub mode.
77    pub async fn subscribe(&mut self, channels: &[&str]) -> Result<(), ClientError> {
78        let mut parts = Vec::with_capacity(1 + channels.len());
79        parts.push(Frame::Bulk(Bytes::from_static(b"SUBSCRIBE")));
80        for ch in channels {
81            parts.push(Frame::Bulk(Bytes::copy_from_slice(ch.as_bytes())));
82        }
83        self.inner.write_frame(Frame::Array(parts)).await?;
84        // drain the confirmation frames (one per channel)
85        for _ in 0..channels.len() {
86            self.inner.read_response().await?;
87        }
88        Ok(())
89    }
90
91    /// Unsubscribes from the given channels.
92    ///
93    /// When all subscriptions have been removed the inner [`Client`] is
94    /// returned so the connection can be reused for regular commands.
95    pub async fn unsubscribe(mut self, channels: &[&str]) -> Result<Client, ClientError> {
96        let mut parts = Vec::with_capacity(1 + channels.len());
97        parts.push(Frame::Bulk(Bytes::from_static(b"UNSUBSCRIBE")));
98        for ch in channels {
99            parts.push(Frame::Bulk(Bytes::copy_from_slice(ch.as_bytes())));
100        }
101        self.inner.write_frame(Frame::Array(parts)).await?;
102        // drain the unsubscribe confirmation frames
103        for _ in 0..channels.len() {
104            self.inner.read_response().await?;
105        }
106        Ok(self.inner)
107    }
108}
109
110/// Tries to parse a push frame into a [`Message`].
111///
112/// Returns `Ok(Some(_))` for `message` and `pmessage` frames.
113/// Returns `Ok(None)` for subscription management frames.
114/// Returns `Err` for unexpected or malformed frames.
115fn try_parse_message(frame: Frame) -> Result<Option<Message>, ClientError> {
116    let elems = match frame {
117        Frame::Array(e) => e,
118        Frame::Error(e) => return Err(ClientError::Server(e)),
119        other => {
120            return Err(ClientError::Protocol(format!(
121                "expected array push frame, got {other:?}"
122            )))
123        }
124    };
125
126    if elems.len() < 3 {
127        return Err(ClientError::Protocol(format!(
128            "push frame too short: {} elements",
129            elems.len()
130        )));
131    }
132
133    let kind = match &elems[0] {
134        Frame::Bulk(b) => b.clone(),
135        Frame::Simple(s) => Bytes::copy_from_slice(s.as_bytes()),
136        other => {
137            return Err(ClientError::Protocol(format!(
138                "expected bulk/simple frame kind, got {other:?}"
139            )))
140        }
141    };
142
143    match kind.as_ref() {
144        b"message" => {
145            if elems.len() < 3 {
146                return Err(ClientError::Protocol(
147                    "message frame has fewer than 3 elements".into(),
148                ));
149            }
150            let channel = bulk_bytes(elems[1].clone())?;
151            let data = bulk_bytes(elems[2].clone())?;
152            Ok(Some(Message {
153                channel,
154                data,
155                pattern: None,
156            }))
157        }
158        b"pmessage" => {
159            if elems.len() < 4 {
160                return Err(ClientError::Protocol(
161                    "pmessage frame has fewer than 4 elements".into(),
162                ));
163            }
164            let pattern = bulk_bytes(elems[1].clone())?;
165            let channel = bulk_bytes(elems[2].clone())?;
166            let data = bulk_bytes(elems[3].clone())?;
167            Ok(Some(Message {
168                channel,
169                data,
170                pattern: Some(pattern),
171            }))
172        }
173        // subscribe / unsubscribe / psubscribe / punsubscribe confirmations
174        _ => Ok(None),
175    }
176}
177
178fn bulk_bytes(frame: Frame) -> Result<Bytes, ClientError> {
179    match frame {
180        Frame::Bulk(b) => Ok(b),
181        other => Err(ClientError::Protocol(format!(
182            "expected bulk in push frame, got {other:?}"
183        ))),
184    }
185}