mini_redis/cmd/subscribe.rs
1use crate::cmd::{Parse, ParseError, Unknown};
2use crate::{Command, Connection, Db, Frame, Shutdown};
3
4use bytes::Bytes;
5use std::pin::Pin;
6use tokio::select;
7use tokio::sync::broadcast;
8use tokio_stream::{Stream, StreamExt, StreamMap};
9
/// The `SUBSCRIBE` command: registers the client as a listener on one or more
/// channels.
///
/// After this command has been applied the client is in the subscribed state.
/// While in that state it is only expected to send further SUBSCRIBE,
/// PSUBSCRIBE, UNSUBSCRIBE, PUNSUBSCRIBE, PING and QUIT commands.
#[derive(Debug)]
pub struct Subscribe {
    /// Names of the channels to subscribe to.
    channels: Vec<String>,
}
19
/// The `UNSUBSCRIBE` command: removes the client from one or more channels.
///
/// An empty channel list means "unsubscribe from every channel the client is
/// currently subscribed to".
#[derive(Clone, Debug)]
pub struct Unsubscribe {
    /// Names of the channels to leave; empty means all of them.
    channels: Vec<String>,
}
28
29/// Stream of messages. The stream receives messages from the
30/// `broadcast::Receiver`. We use `stream!` to create a `Stream` that consumes
31/// messages. Because `stream!` values cannot be named, we box the stream using
32/// a trait object.
33type Messages = Pin<Box<dyn Stream<Item = Bytes> + Send>>;
34
35impl Subscribe {
36 /// Creates a new `Subscribe` command to listen on the specified channels.
37 pub(crate) fn new(channels: &[String]) -> Subscribe {
38 Subscribe {
39 channels: channels.to_vec(),
40 }
41 }
42
43 /// Parse a `Subscribe` instance from a received frame.
44 ///
45 /// The `Parse` argument provides a cursor-like API to read fields from the
46 /// `Frame`. At this point, the entire frame has already been received from
47 /// the socket.
48 ///
49 /// The `SUBSCRIBE` string has already been consumed.
50 ///
51 /// # Returns
52 ///
53 /// On success, the `Subscribe` value is returned. If the frame is
54 /// malformed, `Err` is returned.
55 ///
56 /// # Format
57 ///
58 /// Expects an array frame containing two or more entries.
59 ///
60 /// ```text
61 /// SUBSCRIBE channel [channel ...]
62 /// ```
63 pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result<Subscribe> {
64 use ParseError::EndOfStream;
65
66 // The `SUBSCRIBE` string has already been consumed. At this point,
67 // there is one or more strings remaining in `parse`. These represent
68 // the channels to subscribe to.
69 //
70 // Extract the first string. If there is none, the the frame is
71 // malformed and the error is bubbled up.
72 let mut channels = vec![parse.next_string()?];
73
74 // Now, the remainder of the frame is consumed. Each value must be a
75 // string or the frame is malformed. Once all values in the frame have
76 // been consumed, the command is fully parsed.
77 loop {
78 match parse.next_string() {
79 // A string has been consumed from the `parse`, push it into the
80 // list of channels to subscribe to.
81 Ok(s) => channels.push(s),
82 // The `EndOfStream` error indicates there is no further data to
83 // parse.
84 Err(EndOfStream) => break,
85 // All other errors are bubbled up, resulting in the connection
86 // being terminated.
87 Err(err) => return Err(err.into()),
88 }
89 }
90
91 Ok(Subscribe { channels })
92 }
93
94 /// Apply the `Subscribe` command to the specified `Db` instance.
95 ///
96 /// This function is the entry point and includes the initial list of
97 /// channels to subscribe to. Additional `subscribe` and `unsubscribe`
98 /// commands may be received from the client and the list of subscriptions
99 /// are updated accordingly.
100 ///
101 /// [here]: https://redis.io/topics/pubsub
102 pub(crate) async fn apply(
103 mut self,
104 db: &Db,
105 dst: &mut Connection,
106 shutdown: &mut Shutdown,
107 ) -> crate::Result<()> {
108 // Each individual channel subscription is handled using a
109 // `sync::broadcast` channel. Messages are then fanned out to all
110 // clients currently subscribed to the channels.
111 //
112 // An individual client may subscribe to multiple channels and may
113 // dynamically add and remove channels from its subscription set. To
114 // handle this, a `StreamMap` is used to track active subscriptions. The
115 // `StreamMap` merges messages from individual broadcast channels as
116 // they are received.
117 let mut subscriptions = StreamMap::new();
118
119 loop {
120 // `self.channels` is used to track additional channels to subscribe
121 // to. When new `SUBSCRIBE` commands are received during the
122 // execution of `apply`, the new channels are pushed onto this vec.
123 for channel_name in self.channels.drain(..) {
124 subscribe_to_channel(channel_name, &mut subscriptions, db, dst).await?;
125 }
126
127 // Wait for one of the following to happen:
128 //
129 // - Receive a message from one of the subscribed channels.
130 // - Receive a subscribe or unsubscribe command from the client.
131 // - A server shutdown signal.
132 select! {
133 // Receive messages from subscribed channels
134 Some((channel_name, msg)) = subscriptions.next() => {
135 dst.write_frame(&make_message_frame(channel_name, msg)).await?;
136 }
137 res = dst.read_frame() => {
138 let frame = match res? {
139 Some(frame) => frame,
140 // This happens if the remote client has disconnected.
141 None => return Ok(())
142 };
143
144 handle_command(
145 frame,
146 &mut self.channels,
147 &mut subscriptions,
148 dst,
149 ).await?;
150 }
151 _ = shutdown.recv() => {
152 return Ok(());
153 }
154 };
155 }
156 }
157
158 /// Converts the command into an equivalent `Frame`.
159 ///
160 /// This is called by the client when encoding a `Subscribe` command to send
161 /// to the server.
162 pub(crate) fn into_frame(self) -> Frame {
163 let mut frame = Frame::array();
164 frame.push_bulk(Bytes::from("subscribe".as_bytes()));
165 for channel in self.channels {
166 frame.push_bulk(Bytes::from(channel.into_bytes()));
167 }
168 frame
169 }
170}
171
172async fn subscribe_to_channel(
173 channel_name: String,
174 subscriptions: &mut StreamMap<String, Messages>,
175 db: &Db,
176 dst: &mut Connection,
177) -> crate::Result<()> {
178 let mut rx = db.subscribe(channel_name.clone());
179
180 // Subscribe to the channel.
181 let rx = Box::pin(async_stream::stream! {
182 loop {
183 match rx.recv().await {
184 Ok(msg) => yield msg,
185 // If we lagged in consuming messages, just resume.
186 Err(broadcast::error::RecvError::Lagged(_)) => {}
187 Err(_) => break,
188 }
189 }
190 });
191
192 // Track subscription in this client's subscription set.
193 subscriptions.insert(channel_name.clone(), rx);
194
195 // Respond with the successful subscription
196 let response = make_subscribe_frame(channel_name, subscriptions.len());
197 dst.write_frame(&response).await?;
198
199 Ok(())
200}
201
202/// Handle a command received while inside `Subscribe::apply`. Only subscribe
203/// and unsubscribe commands are permitted in this context.
204///
205/// Any new subscriptions are appended to `subscribe_to` instead of modifying
206/// `subscriptions`.
207async fn handle_command(
208 frame: Frame,
209 subscribe_to: &mut Vec<String>,
210 subscriptions: &mut StreamMap<String, Messages>,
211 dst: &mut Connection,
212) -> crate::Result<()> {
213 // A command has been received from the client.
214 //
215 // Only `SUBSCRIBE` and `UNSUBSCRIBE` commands are permitted
216 // in this context.
217 match Command::from_frame(frame)? {
218 Command::Subscribe(subscribe) => {
219 // The `apply` method will subscribe to the channels we add to this
220 // vector.
221 subscribe_to.extend(subscribe.channels.into_iter());
222 }
223 Command::Unsubscribe(mut unsubscribe) => {
224 // If no channels are specified, this requests unsubscribing from
225 // **all** channels. To implement this, the `unsubscribe.channels`
226 // vec is populated with the list of channels currently subscribed
227 // to.
228 if unsubscribe.channels.is_empty() {
229 unsubscribe.channels = subscriptions
230 .keys()
231 .map(|channel_name| channel_name.to_string())
232 .collect();
233 }
234
235 for channel_name in unsubscribe.channels {
236 subscriptions.remove(&channel_name);
237
238 let response = make_unsubscribe_frame(channel_name, subscriptions.len());
239 dst.write_frame(&response).await?;
240 }
241 }
242 command => {
243 let cmd = Unknown::new(command.get_name());
244 cmd.apply(dst).await?;
245 }
246 }
247 Ok(())
248}
249
250/// Creates the response to a subcribe request.
251///
252/// All of these functions take the `channel_name` as a `String` instead of
253/// a `&str` since `Bytes::from` can reuse the allocation in the `String`, and
254/// taking a `&str` would require copying the data. This allows the caller to
255/// decide whether to clone the channel name or not.
256fn make_subscribe_frame(channel_name: String, num_subs: usize) -> Frame {
257 let mut response = Frame::array();
258 response.push_bulk(Bytes::from_static(b"subscribe"));
259 response.push_bulk(Bytes::from(channel_name));
260 response.push_int(num_subs as u64);
261 response
262}
263
264/// Creates the response to an unsubcribe request.
265fn make_unsubscribe_frame(channel_name: String, num_subs: usize) -> Frame {
266 let mut response = Frame::array();
267 response.push_bulk(Bytes::from_static(b"unsubscribe"));
268 response.push_bulk(Bytes::from(channel_name));
269 response.push_int(num_subs as u64);
270 response
271}
272
273/// Creates a message informing the client about a new message on a channel that
274/// the client subscribes to.
275fn make_message_frame(channel_name: String, msg: Bytes) -> Frame {
276 let mut response = Frame::array();
277 response.push_bulk(Bytes::from_static(b"message"));
278 response.push_bulk(Bytes::from(channel_name));
279 response.push_bulk(msg);
280 response
281}
282
283impl Unsubscribe {
284 /// Create a new `Unsubscribe` command with the given `channels`.
285 pub(crate) fn new(channels: &[String]) -> Unsubscribe {
286 Unsubscribe {
287 channels: channels.to_vec(),
288 }
289 }
290
291 /// Parse a `Unsubscribe` instance from a received frame.
292 ///
293 /// The `Parse` argument provides a cursor-like API to read fields from the
294 /// `Frame`. At this point, the entire frame has already been received from
295 /// the socket.
296 ///
297 /// The `UNSUBSCRIBE` string has already been consumed.
298 ///
299 /// # Returns
300 ///
301 /// On success, the `Unsubscribe` value is returned. If the frame is
302 /// malformed, `Err` is returned.
303 ///
304 /// # Format
305 ///
306 /// Expects an array frame containing at least one entry.
307 ///
308 /// ```text
309 /// UNSUBSCRIBE [channel [channel ...]]
310 /// ```
311 pub(crate) fn parse_frames(parse: &mut Parse) -> Result<Unsubscribe, ParseError> {
312 use ParseError::EndOfStream;
313
314 // There may be no channels listed, so start with an empty vec.
315 let mut channels = vec![];
316
317 // Each entry in the frame must be a string or the frame is malformed.
318 // Once all values in the frame have been consumed, the command is fully
319 // parsed.
320 loop {
321 match parse.next_string() {
322 // A string has been consumed from the `parse`, push it into the
323 // list of channels to unsubscribe from.
324 Ok(s) => channels.push(s),
325 // The `EndOfStream` error indicates there is no further data to
326 // parse.
327 Err(EndOfStream) => break,
328 // All other errors are bubbled up, resulting in the connection
329 // being terminated.
330 Err(err) => return Err(err),
331 }
332 }
333
334 Ok(Unsubscribe { channels })
335 }
336
337 /// Converts the command into an equivalent `Frame`.
338 ///
339 /// This is called by the client when encoding an `Unsubscribe` command to
340 /// send to the server.
341 pub(crate) fn into_frame(self) -> Frame {
342 let mut frame = Frame::array();
343 frame.push_bulk(Bytes::from("unsubscribe".as_bytes()));
344
345 for channel in self.channels {
346 frame.push_bulk(Bytes::from(channel.into_bytes()));
347 }
348
349 frame
350 }
351}