// mini_redis/cmd/subscribe.rs

1use crate::cmd::{Parse, ParseError, Unknown};
2use crate::{Command, Connection, Db, Frame, Shutdown};
3
4use bytes::Bytes;
5use std::pin::Pin;
6use tokio::select;
7use tokio::sync::broadcast;
8use tokio_stream::{Stream, StreamExt, StreamMap};
9
/// The `SUBSCRIBE` command: subscribes the client to one or more channels.
///
/// In Redis, a client that has entered the subscribed state may only issue
/// `SUBSCRIBE`, `PSUBSCRIBE`, `UNSUBSCRIBE`, `PUNSUBSCRIBE`, `PING` and `QUIT`.
/// This server acts only on `SUBSCRIBE` and `UNSUBSCRIBE` while subscribed;
/// every other command is routed to the `Unknown` command handler.
#[derive(Debug)]
pub struct Subscribe {
    /// Names of the channels to subscribe to.
    channels: Vec<String>,
}
19
/// The `UNSUBSCRIBE` command: unsubscribes the client from one or more channels.
///
/// An empty channel list means the client is unsubscribed from every channel
/// it is currently subscribed to.
#[derive(Clone, Debug)]
pub struct Unsubscribe {
    /// Channels to leave; an empty list means "all current subscriptions".
    channels: Vec<String>,
}
28
29/// Stream of messages. The stream receives messages from the
30/// `broadcast::Receiver`. We use `stream!` to create a `Stream` that consumes
31/// messages. Because `stream!` values cannot be named, we box the stream using
32/// a trait object.
33type Messages = Pin<Box<dyn Stream<Item = Bytes> + Send>>;
34
35impl Subscribe {
36    /// Creates a new `Subscribe` command to listen on the specified channels.
37    pub(crate) fn new(channels: &[String]) -> Subscribe {
38        Subscribe {
39            channels: channels.to_vec(),
40        }
41    }
42
43    /// Parse a `Subscribe` instance from a received frame.
44    ///
45    /// The `Parse` argument provides a cursor-like API to read fields from the
46    /// `Frame`. At this point, the entire frame has already been received from
47    /// the socket.
48    ///
49    /// The `SUBSCRIBE` string has already been consumed.
50    ///
51    /// # Returns
52    ///
53    /// On success, the `Subscribe` value is returned. If the frame is
54    /// malformed, `Err` is returned.
55    ///
56    /// # Format
57    ///
58    /// Expects an array frame containing two or more entries.
59    ///
60    /// ```text
61    /// SUBSCRIBE channel [channel ...]
62    /// ```
63    pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result<Subscribe> {
64        use ParseError::EndOfStream;
65
66        // The `SUBSCRIBE` string has already been consumed. At this point,
67        // there is one or more strings remaining in `parse`. These represent
68        // the channels to subscribe to.
69        //
70        // Extract the first string. If there is none, the the frame is
71        // malformed and the error is bubbled up.
72        let mut channels = vec![parse.next_string()?];
73
74        // Now, the remainder of the frame is consumed. Each value must be a
75        // string or the frame is malformed. Once all values in the frame have
76        // been consumed, the command is fully parsed.
77        loop {
78            match parse.next_string() {
79                // A string has been consumed from the `parse`, push it into the
80                // list of channels to subscribe to.
81                Ok(s) => channels.push(s),
82                // The `EndOfStream` error indicates there is no further data to
83                // parse.
84                Err(EndOfStream) => break,
85                // All other errors are bubbled up, resulting in the connection
86                // being terminated.
87                Err(err) => return Err(err.into()),
88            }
89        }
90
91        Ok(Subscribe { channels })
92    }
93
94    /// Apply the `Subscribe` command to the specified `Db` instance.
95    ///
96    /// This function is the entry point and includes the initial list of
97    /// channels to subscribe to. Additional `subscribe` and `unsubscribe`
98    /// commands may be received from the client and the list of subscriptions
99    /// are updated accordingly.
100    ///
101    /// [here]: https://redis.io/topics/pubsub
102    pub(crate) async fn apply(
103        mut self,
104        db: &Db,
105        dst: &mut Connection,
106        shutdown: &mut Shutdown,
107    ) -> crate::Result<()> {
108        // Each individual channel subscription is handled using a
109        // `sync::broadcast` channel. Messages are then fanned out to all
110        // clients currently subscribed to the channels.
111        //
112        // An individual client may subscribe to multiple channels and may
113        // dynamically add and remove channels from its subscription set. To
114        // handle this, a `StreamMap` is used to track active subscriptions. The
115        // `StreamMap` merges messages from individual broadcast channels as
116        // they are received.
117        let mut subscriptions = StreamMap::new();
118
119        loop {
120            // `self.channels` is used to track additional channels to subscribe
121            // to. When new `SUBSCRIBE` commands are received during the
122            // execution of `apply`, the new channels are pushed onto this vec.
123            for channel_name in self.channels.drain(..) {
124                subscribe_to_channel(channel_name, &mut subscriptions, db, dst).await?;
125            }
126
127            // Wait for one of the following to happen:
128            //
129            // - Receive a message from one of the subscribed channels.
130            // - Receive a subscribe or unsubscribe command from the client.
131            // - A server shutdown signal.
132            select! {
133                // Receive messages from subscribed channels
134                Some((channel_name, msg)) = subscriptions.next() => {
135                    dst.write_frame(&make_message_frame(channel_name, msg)).await?;
136                }
137                res = dst.read_frame() => {
138                    let frame = match res? {
139                        Some(frame) => frame,
140                        // This happens if the remote client has disconnected.
141                        None => return Ok(())
142                    };
143
144                    handle_command(
145                        frame,
146                        &mut self.channels,
147                        &mut subscriptions,
148                        dst,
149                    ).await?;
150                }
151                _ = shutdown.recv() => {
152                    return Ok(());
153                }
154            };
155        }
156    }
157
158    /// Converts the command into an equivalent `Frame`.
159    ///
160    /// This is called by the client when encoding a `Subscribe` command to send
161    /// to the server.
162    pub(crate) fn into_frame(self) -> Frame {
163        let mut frame = Frame::array();
164        frame.push_bulk(Bytes::from("subscribe".as_bytes()));
165        for channel in self.channels {
166            frame.push_bulk(Bytes::from(channel.into_bytes()));
167        }
168        frame
169    }
170}
171
172async fn subscribe_to_channel(
173    channel_name: String,
174    subscriptions: &mut StreamMap<String, Messages>,
175    db: &Db,
176    dst: &mut Connection,
177) -> crate::Result<()> {
178    let mut rx = db.subscribe(channel_name.clone());
179
180    // Subscribe to the channel.
181    let rx = Box::pin(async_stream::stream! {
182        loop {
183            match rx.recv().await {
184                Ok(msg) => yield msg,
185                // If we lagged in consuming messages, just resume.
186                Err(broadcast::error::RecvError::Lagged(_)) => {}
187                Err(_) => break,
188            }
189        }
190    });
191
192    // Track subscription in this client's subscription set.
193    subscriptions.insert(channel_name.clone(), rx);
194
195    // Respond with the successful subscription
196    let response = make_subscribe_frame(channel_name, subscriptions.len());
197    dst.write_frame(&response).await?;
198
199    Ok(())
200}
201
202/// Handle a command received while inside `Subscribe::apply`. Only subscribe
203/// and unsubscribe commands are permitted in this context.
204///
205/// Any new subscriptions are appended to `subscribe_to` instead of modifying
206/// `subscriptions`.
207async fn handle_command(
208    frame: Frame,
209    subscribe_to: &mut Vec<String>,
210    subscriptions: &mut StreamMap<String, Messages>,
211    dst: &mut Connection,
212) -> crate::Result<()> {
213    // A command has been received from the client.
214    //
215    // Only `SUBSCRIBE` and `UNSUBSCRIBE` commands are permitted
216    // in this context.
217    match Command::from_frame(frame)? {
218        Command::Subscribe(subscribe) => {
219            // The `apply` method will subscribe to the channels we add to this
220            // vector.
221            subscribe_to.extend(subscribe.channels.into_iter());
222        }
223        Command::Unsubscribe(mut unsubscribe) => {
224            // If no channels are specified, this requests unsubscribing from
225            // **all** channels. To implement this, the `unsubscribe.channels`
226            // vec is populated with the list of channels currently subscribed
227            // to.
228            if unsubscribe.channels.is_empty() {
229                unsubscribe.channels = subscriptions
230                    .keys()
231                    .map(|channel_name| channel_name.to_string())
232                    .collect();
233            }
234
235            for channel_name in unsubscribe.channels {
236                subscriptions.remove(&channel_name);
237
238                let response = make_unsubscribe_frame(channel_name, subscriptions.len());
239                dst.write_frame(&response).await?;
240            }
241        }
242        command => {
243            let cmd = Unknown::new(command.get_name());
244            cmd.apply(dst).await?;
245        }
246    }
247    Ok(())
248}
249
250/// Creates the response to a subcribe request.
251///
252/// All of these functions take the `channel_name` as a `String` instead of
253/// a `&str` since `Bytes::from` can reuse the allocation in the `String`, and
254/// taking a `&str` would require copying the data. This allows the caller to
255/// decide whether to clone the channel name or not.
256fn make_subscribe_frame(channel_name: String, num_subs: usize) -> Frame {
257    let mut response = Frame::array();
258    response.push_bulk(Bytes::from_static(b"subscribe"));
259    response.push_bulk(Bytes::from(channel_name));
260    response.push_int(num_subs as u64);
261    response
262}
263
264/// Creates the response to an unsubcribe request.
265fn make_unsubscribe_frame(channel_name: String, num_subs: usize) -> Frame {
266    let mut response = Frame::array();
267    response.push_bulk(Bytes::from_static(b"unsubscribe"));
268    response.push_bulk(Bytes::from(channel_name));
269    response.push_int(num_subs as u64);
270    response
271}
272
273/// Creates a message informing the client about a new message on a channel that
274/// the client subscribes to.
275fn make_message_frame(channel_name: String, msg: Bytes) -> Frame {
276    let mut response = Frame::array();
277    response.push_bulk(Bytes::from_static(b"message"));
278    response.push_bulk(Bytes::from(channel_name));
279    response.push_bulk(msg);
280    response
281}
282
283impl Unsubscribe {
284    /// Create a new `Unsubscribe` command with the given `channels`.
285    pub(crate) fn new(channels: &[String]) -> Unsubscribe {
286        Unsubscribe {
287            channels: channels.to_vec(),
288        }
289    }
290
291    /// Parse a `Unsubscribe` instance from a received frame.
292    ///
293    /// The `Parse` argument provides a cursor-like API to read fields from the
294    /// `Frame`. At this point, the entire frame has already been received from
295    /// the socket.
296    ///
297    /// The `UNSUBSCRIBE` string has already been consumed.
298    ///
299    /// # Returns
300    ///
301    /// On success, the `Unsubscribe` value is returned. If the frame is
302    /// malformed, `Err` is returned.
303    ///
304    /// # Format
305    ///
306    /// Expects an array frame containing at least one entry.
307    ///
308    /// ```text
309    /// UNSUBSCRIBE [channel [channel ...]]
310    /// ```
311    pub(crate) fn parse_frames(parse: &mut Parse) -> Result<Unsubscribe, ParseError> {
312        use ParseError::EndOfStream;
313
314        // There may be no channels listed, so start with an empty vec.
315        let mut channels = vec![];
316
317        // Each entry in the frame must be a string or the frame is malformed.
318        // Once all values in the frame have been consumed, the command is fully
319        // parsed.
320        loop {
321            match parse.next_string() {
322                // A string has been consumed from the `parse`, push it into the
323                // list of channels to unsubscribe from.
324                Ok(s) => channels.push(s),
325                // The `EndOfStream` error indicates there is no further data to
326                // parse.
327                Err(EndOfStream) => break,
328                // All other errors are bubbled up, resulting in the connection
329                // being terminated.
330                Err(err) => return Err(err),
331            }
332        }
333
334        Ok(Unsubscribe { channels })
335    }
336
337    /// Converts the command into an equivalent `Frame`.
338    ///
339    /// This is called by the client when encoding an `Unsubscribe` command to
340    /// send to the server.
341    pub(crate) fn into_frame(self) -> Frame {
342        let mut frame = Frame::array();
343        frame.push_bulk(Bytes::from("unsubscribe".as_bytes()));
344
345        for channel in self.channels {
346            frame.push_bulk(Bytes::from(channel.into_bytes()));
347        }
348
349        frame
350    }
351}