mini_redis/
client.rs

1//! Minimal Redis client implementation
2//!
3//! Provides an async connect and methods for issuing the supported commands.
4
5use crate::cmd::{Get, Publish, Set, Subscribe, Unsubscribe};
6use crate::{Connection, Frame};
7
8use async_stream::try_stream;
9use bytes::Bytes;
10use std::io::{Error, ErrorKind};
11use std::time::Duration;
12use tokio::net::{TcpStream, ToSocketAddrs};
13use tokio_stream::Stream;
14use tracing::{debug, instrument};
15
16/// Established connection with a Redis server.
17///
18/// Backed by a single `TcpStream`, `Client` provides basic network client
19/// functionality (no pooling, retrying, ...). Connections are established using
20/// the [`connect`](fn@connect) function.
21///
22/// Requests are issued using the various methods of `Client`.
23pub struct Client {
24    /// The TCP connection decorated with the redis protocol encoder / decoder
25    /// implemented using a buffered `TcpStream`.
26    ///
27    /// When `Listener` receives an inbound connection, the `TcpStream` is
28    /// passed to `Connection::new`, which initializes the associated buffers.
29    /// `Connection` allows the handler to operate at the "frame" level and keep
30    /// the byte level protocol parsing details encapsulated in `Connection`.
31    connection: Connection,
32}
33
34/// A client that has entered pub/sub mode.
35///
36/// Once clients subscribe to a channel, they may only perform pub/sub related
37/// commands. The `Client` type is transitioned to a `Subscriber` type in order
38/// to prevent non-pub/sub methods from being called.
39pub struct Subscriber {
40    /// The subscribed client.
41    client: Client,
42
43    /// The set of channels to which the `Subscriber` is currently subscribed.
44    subscribed_channels: Vec<String>,
45}
46
47/// A message received on a subscribed channel.
48#[derive(Debug, Clone)]
49pub struct Message {
50    pub channel: String,
51    pub content: Bytes,
52}
53
54/// Establish a connection with the Redis server located at `addr`.
55///
56/// `addr` may be any type that can be asynchronously converted to a
57/// `SocketAddr`. This includes `SocketAddr` and strings. The `ToSocketAddrs`
58/// trait is the Tokio version and not the `std` version.
59///
60/// # Examples
61///
62/// ```no_run
63/// use mini_redis::client;
64///
65/// #[tokio::main]
66/// async fn main() {
67///     let client = match client::connect("localhost:6379").await {
68///         Ok(client) => client,
69///         Err(_) => panic!("failed to establish connection"),
70///     };
71/// # drop(client);
72/// }
73/// ```
74///
75pub async fn connect<T: ToSocketAddrs>(addr: T) -> crate::Result<Client> {
76    // The `addr` argument is passed directly to `TcpStream::connect`. This
77    // performs any asynchronous DNS lookup and attempts to establish the TCP
78    // connection. An error at either step returns an error, which is then
79    // bubbled up to the caller of `mini_redis` connect.
80    let socket = TcpStream::connect(addr).await?;
81
82    // Initialize the connection state. This allocates read/write buffers to
83    // perform redis protocol frame parsing.
84    let connection = Connection::new(socket);
85
86    Ok(Client { connection })
87}
88
89impl Client {
90    /// Get the value of key.
91    ///
92    /// If the key does not exist the special value `None` is returned.
93    ///
94    /// # Examples
95    ///
96    /// Demonstrates basic usage.
97    ///
98    /// ```no_run
99    /// use mini_redis::client;
100    ///
101    /// #[tokio::main]
102    /// async fn main() {
103    ///     let mut client = client::connect("localhost:6379").await.unwrap();
104    ///
105    ///     let val = client.get("foo").await.unwrap();
106    ///     println!("Got = {:?}", val);
107    /// }
108    /// ```
109    #[instrument(skip(self))]
110    pub async fn get(&mut self, key: &str) -> crate::Result<Option<Bytes>> {
111        // Create a `Get` command for the `key` and convert it to a frame.
112        let frame = Get::new(key).into_frame();
113
114        debug!(request = ?frame);
115
116        // Write the frame to the socket. This writes the full frame to the
117        // socket, waiting if necessary.
118        self.connection.write_frame(&frame).await?;
119
120        // Wait for the response from the server
121        //
122        // Both `Simple` and `Bulk` frames are accepted. `Null` represents the
123        // key not being present and `None` is returned.
124        match self.read_response().await? {
125            Frame::Simple(value) => Ok(Some(value.into())),
126            Frame::Bulk(value) => Ok(Some(value)),
127            Frame::Null => Ok(None),
128            frame => Err(frame.to_error()),
129        }
130    }
131
132    /// Set `key` to hold the given `value`.
133    ///
134    /// The `value` is associated with `key` until it is overwritten by the next
135    /// call to `set` or it is removed.
136    ///
137    /// If key already holds a value, it is overwritten. Any previous time to
138    /// live associated with the key is discarded on successful SET operation.
139    ///
140    /// # Examples
141    ///
142    /// Demonstrates basic usage.
143    ///
144    /// ```no_run
145    /// use mini_redis::client;
146    ///
147    /// #[tokio::main]
148    /// async fn main() {
149    ///     let mut client = client::connect("localhost:6379").await.unwrap();
150    ///
151    ///     client.set("foo", "bar".into()).await.unwrap();
152    ///
153    ///     // Getting the value immediately works
154    ///     let val = client.get("foo").await.unwrap().unwrap();
155    ///     assert_eq!(val, "bar");
156    /// }
157    /// ```
158    #[instrument(skip(self))]
159    pub async fn set(&mut self, key: &str, value: Bytes) -> crate::Result<()> {
160        // Create a `Set` command and pass it to `set_cmd`. A separate method is
161        // used to set a value with an expiration. The common parts of both
162        // functions are implemented by `set_cmd`.
163        self.set_cmd(Set::new(key, value, None)).await
164    }
165
166    /// Set `key` to hold the given `value`. The value expires after `expiration`
167    ///
168    /// The `value` is associated with `key` until one of the following:
169    /// - it expires.
170    /// - it is overwritten by the next call to `set`.
171    /// - it is removed.
172    ///
173    /// If key already holds a value, it is overwritten. Any previous time to
174    /// live associated with the key is discarded on a successful SET operation.
175    ///
176    /// # Examples
177    ///
178    /// Demonstrates basic usage. This example is not **guaranteed** to always
179    /// work as it relies on time based logic and assumes the client and server
180    /// stay relatively synchronized in time. The real world tends to not be so
181    /// favorable.
182    ///
183    /// ```no_run
184    /// use mini_redis::client;
185    /// use tokio::time;
186    /// use std::time::Duration;
187    ///
188    /// #[tokio::main]
189    /// async fn main() {
190    ///     let ttl = Duration::from_millis(500);
191    ///     let mut client = client::connect("localhost:6379").await.unwrap();
192    ///
193    ///     client.set_expires("foo", "bar".into(), ttl).await.unwrap();
194    ///
195    ///     // Getting the value immediately works
196    ///     let val = client.get("foo").await.unwrap().unwrap();
197    ///     assert_eq!(val, "bar");
198    ///
199    ///     // Wait for the TTL to expire
200    ///     time::sleep(ttl).await;
201    ///
202    ///     let val = client.get("foo").await.unwrap();
203    ///     assert!(val.is_some());
204    /// }
205    /// ```
206    #[instrument(skip(self))]
207    pub async fn set_expires(
208        &mut self,
209        key: &str,
210        value: Bytes,
211        expiration: Duration,
212    ) -> crate::Result<()> {
213        // Create a `Set` command and pass it to `set_cmd`. A separate method is
214        // used to set a value with an expiration. The common parts of both
215        // functions are implemented by `set_cmd`.
216        self.set_cmd(Set::new(key, value, Some(expiration))).await
217    }
218
219    /// The core `SET` logic, used by both `set` and `set_expires.
220    async fn set_cmd(&mut self, cmd: Set) -> crate::Result<()> {
221        // Convert the `Set` command into a frame
222        let frame = cmd.into_frame();
223
224        debug!(request = ?frame);
225
226        // Write the frame to the socket. This writes the full frame to the
227        // socket, waiting if necessary.
228        self.connection.write_frame(&frame).await?;
229
230        // Wait for the response from the server. On success, the server
231        // responds simply with `OK`. Any other response indicates an error.
232        match self.read_response().await? {
233            Frame::Simple(response) if response == "OK" => Ok(()),
234            frame => Err(frame.to_error()),
235        }
236    }
237
238    /// Posts `message` to the given `channel`.
239    ///
240    /// Returns the number of subscribers currently listening on the channel.
241    /// There is no guarantee that these subscribers receive the message as they
242    /// may disconnect at any time.
243    ///
244    /// # Examples
245    ///
246    /// Demonstrates basic usage.
247    ///
248    /// ```no_run
249    /// use mini_redis::client;
250    ///
251    /// #[tokio::main]
252    /// async fn main() {
253    ///     let mut client = client::connect("localhost:6379").await.unwrap();
254    ///
255    ///     let val = client.publish("foo", "bar".into()).await.unwrap();
256    ///     println!("Got = {:?}", val);
257    /// }
258    /// ```
259    #[instrument(skip(self))]
260    pub async fn publish(&mut self, channel: &str, message: Bytes) -> crate::Result<u64> {
261        // Convert the `Publish` command into a frame
262        let frame = Publish::new(channel, message).into_frame();
263
264        debug!(request = ?frame);
265
266        // Write the frame to the socket
267        self.connection.write_frame(&frame).await?;
268
269        // Read the response
270        match self.read_response().await? {
271            Frame::Integer(response) => Ok(response),
272            frame => Err(frame.to_error()),
273        }
274    }
275
276    /// Subscribes the client to the specified channels.
277    ///
278    /// Once a client issues a subscribe command, it may no longer issue any
279    /// non-pub/sub commands. The function consumes `self` and returns a `Subscriber`.
280    ///
281    /// The `Subscriber` value is used to receive messages as well as manage the
282    /// list of channels the client is subscribed to.
283    #[instrument(skip(self))]
284    pub async fn subscribe(mut self, channels: Vec<String>) -> crate::Result<Subscriber> {
285        // Issue the subscribe command to the server and wait for confirmation.
286        // The client will then have been transitioned into the "subscriber"
287        // state and may only issue pub/sub commands from that point on.
288        self.subscribe_cmd(&channels).await?;
289
290        // Return the `Subscriber` type
291        Ok(Subscriber {
292            client: self,
293            subscribed_channels: channels,
294        })
295    }
296
297    /// The core `SUBSCRIBE` logic, used by misc subscribe fns
298    async fn subscribe_cmd(&mut self, channels: &[String]) -> crate::Result<()> {
299        // Convert the `Subscribe` command into a frame
300        let frame = Subscribe::new(&channels).into_frame();
301
302        debug!(request = ?frame);
303
304        // Write the frame to the socket
305        self.connection.write_frame(&frame).await?;
306
307        // For each channel being subscribed to, the server responds with a
308        // message confirming subscription to that channel.
309        for channel in channels {
310            // Read the response
311            let response = self.read_response().await?;
312
313            // Verify it is confirmation of subscription.
314            match response {
315                Frame::Array(ref frame) => match frame.as_slice() {
316                    // The server responds with an array frame in the form of:
317                    //
318                    // ```
319                    // [ "subscribe", channel, num-subscribed ]
320                    // ```
321                    //
322                    // where channel is the name of the channel and
323                    // num-subscribed is the number of channels that the client
324                    // is currently subscribed to.
325                    [subscribe, schannel, ..]
326                        if *subscribe == "subscribe" && *schannel == channel => {}
327                    _ => return Err(response.to_error()),
328                },
329                frame => return Err(frame.to_error()),
330            };
331        }
332
333        Ok(())
334    }
335
336    /// Reads a response frame from the socket.
337    ///
338    /// If an `Error` frame is received, it is converted to `Err`.
339    async fn read_response(&mut self) -> crate::Result<Frame> {
340        let response = self.connection.read_frame().await?;
341
342        debug!(?response);
343
344        match response {
345            // Error frames are converted to `Err`
346            Some(Frame::Error(msg)) => Err(msg.into()),
347            Some(frame) => Ok(frame),
348            None => {
349                // Receiving `None` here indicates the server has closed the
350                // connection without sending a frame. This is unexpected and is
351                // represented as a "connection reset by peer" error.
352                let err = Error::new(ErrorKind::ConnectionReset, "connection reset by server");
353
354                Err(err.into())
355            }
356        }
357    }
358}
359
360impl Subscriber {
361    /// Returns the set of channels currently subscribed to.
362    pub fn get_subscribed(&self) -> &[String] {
363        &self.subscribed_channels
364    }
365
366    /// Receive the next message published on a subscribed channel, waiting if
367    /// necessary.
368    ///
369    /// `None` indicates the subscription has been terminated.
370    pub async fn next_message(&mut self) -> crate::Result<Option<Message>> {
371        match self.client.connection.read_frame().await? {
372            Some(mframe) => {
373                debug!(?mframe);
374
375                match mframe {
376                    Frame::Array(ref frame) => match frame.as_slice() {
377                        [message, channel, content] if *message == "message" => Ok(Some(Message {
378                            channel: channel.to_string(),
379                            content: Bytes::from(content.to_string()),
380                        })),
381                        _ => Err(mframe.to_error()),
382                    },
383                    frame => Err(frame.to_error()),
384                }
385            }
386            None => Ok(None),
387        }
388    }
389
390    /// Convert the subscriber into a `Stream` yielding new messages published
391    /// on subscribed channels.
392    ///
393    /// `Subscriber` does not implement stream itself as doing so with safe code
394    /// is non trivial. The usage of async/await would require a manual Stream
395    /// implementation to use `unsafe` code. Instead, a conversion function is
396    /// provided and the returned stream is implemented with the help of the
397    /// `async-stream` crate.
398    pub fn into_stream(mut self) -> impl Stream<Item = crate::Result<Message>> {
399        // Uses the `try_stream` macro from the `async-stream` crate. Generators
400        // are not stable in Rust. The crate uses a macro to simulate generators
401        // on top of async/await. There are limitations, so read the
402        // documentation there.
403        try_stream! {
404            while let Some(message) = self.next_message().await? {
405                yield message;
406            }
407        }
408    }
409
410    /// Subscribe to a list of new channels
411    #[instrument(skip(self))]
412    pub async fn subscribe(&mut self, channels: &[String]) -> crate::Result<()> {
413        // Issue the subscribe command
414        self.client.subscribe_cmd(channels).await?;
415
416        // Update the set of subscribed channels.
417        self.subscribed_channels
418            .extend(channels.iter().map(Clone::clone));
419
420        Ok(())
421    }
422
423    /// Unsubscribe to a list of new channels
424    #[instrument(skip(self))]
425    pub async fn unsubscribe(&mut self, channels: &[String]) -> crate::Result<()> {
426        let frame = Unsubscribe::new(&channels).into_frame();
427
428        debug!(request = ?frame);
429
430        // Write the frame to the socket
431        self.client.connection.write_frame(&frame).await?;
432
433        // if the input channel list is empty, server acknowledges as unsubscribing
434        // from all subscribed channels, so we assert that the unsubscribe list received
435        // matches the client subscribed one
436        let num = if channels.is_empty() {
437            self.subscribed_channels.len()
438        } else {
439            channels.len()
440        };
441
442        // Read the response
443        for _ in 0..num {
444            let response = self.client.read_response().await?;
445
446            match response {
447                Frame::Array(ref frame) => match frame.as_slice() {
448                    [unsubscribe, channel, ..] if *unsubscribe == "unsubscribe" => {
449                        let len = self.subscribed_channels.len();
450
451                        if len == 0 {
452                            // There must be at least one channel
453                            return Err(response.to_error());
454                        }
455
456                        // unsubscribed channel should exist in the subscribed list at this point
457                        self.subscribed_channels.retain(|c| *channel != &c[..]);
458
459                        // Only a single channel should be removed from the
460                        // list of subscribed channels.
461                        if self.subscribed_channels.len() != len - 1 {
462                            return Err(response.to_error());
463                        }
464                    }
465                    _ => return Err(response.to_error()),
466                },
467                frame => return Err(frame.to_error()),
468            };
469        }
470
471        Ok(())
472    }
473}