redis_async/client/paired.rs

1/*
2 * Copyright 2017-2024 Ben Ashford
3 *
4 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5 * http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6 * <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
7 * option. This file may not be copied, modified, or distributed
8 * except according to those terms.
9 */
10
11use std::collections::VecDeque;
12use std::future::Future;
13use std::marker::PhantomData;
14use std::mem;
15use std::pin::Pin;
16use std::sync::Arc;
17use std::task::{Context, Poll};
18use std::time::Duration;
19
20use futures_channel::{mpsc, oneshot};
21use futures_sink::Sink;
22use futures_util::{future::TryFutureExt, stream::StreamExt};
23
24use super::{
25    connect::{connect_with_auth, RespConnection},
26    ConnectionBuilder,
27};
28
29use crate::{
30    error,
31    reconnect::{reconnect, Reconnect},
32    resp,
33};
34
35/// The state of sending messages to a Redis server
36enum SendStatus {
37    /// The connection is clear, more messages can be sent
38    Ok,
39    /// The connection has closed, nothing more should be sent
40    End,
41    /// The connection reported itself as full, it should be flushed before attempting to send the
42    /// pending message again
43    Full(resp::RespValue),
44}
45
/// Outcome of one attempt to read a reply from the Redis connection.
#[derive(Debug)]
enum ReceiveStatus {
    /// No reply is available yet; the task should yield and wait to be woken.
    NotReady,
    /// A reply was read and handed to its caller; more may already be available.
    ReadyMore,
    /// No more commands will be sent and every outstanding reply has been delivered,
    /// so nothing further should be read.
    ReadyFinished,
}
56
57type CommandResult = Result<resp::RespValue, error::Error>;
58type Responder = oneshot::Sender<CommandResult>;
59type SendPayload = (resp::RespValue, Responder);
60
61// /// The PairedConnectionInner is a spawned future that is responsible for pairing commands and
62// /// results onto a `RespConnection` that is otherwise unpaired
63struct PairedConnectionInner {
64    /// The underlying connection that talks the RESP protocol
65    connection: RespConnection,
66    /// The channel upon which commands are received
67    out_rx: mpsc::UnboundedReceiver<SendPayload>,
68    /// The queue of waiting oneshot's for commands sent but results not yet received
69    waiting: VecDeque<Responder>,
70
71    /// The status of the underlying connection
72    send_status: SendStatus,
73}
74
75impl PairedConnectionInner {
76    fn new(
77        con: RespConnection,
78        out_rx: mpsc::UnboundedReceiver<(resp::RespValue, Responder)>,
79    ) -> Self {
80        PairedConnectionInner {
81            connection: con,
82            out_rx,
83            waiting: VecDeque::new(),
84            send_status: SendStatus::Ok,
85        }
86    }
87
88    fn impl_start_send(
89        &mut self,
90        cx: &mut Context,
91        msg: resp::RespValue,
92    ) -> Result<bool, error::Error> {
93        match Pin::new(&mut self.connection).poll_ready(cx) {
94            Poll::Ready(Ok(())) => (),
95            Poll::Ready(Err(e)) => return Err(e.into()),
96            Poll::Pending => {
97                self.send_status = SendStatus::Full(msg);
98                return Ok(false);
99            }
100        }
101
102        self.send_status = SendStatus::Ok;
103        Pin::new(&mut self.connection).start_send(msg)?;
104        Ok(true)
105    }
106
107    fn poll_start_send(&mut self, cx: &mut Context) -> Result<bool, error::Error> {
108        let mut status = SendStatus::Ok;
109        mem::swap(&mut status, &mut self.send_status);
110
111        let message = match status {
112            SendStatus::End => {
113                self.send_status = SendStatus::End;
114                return Ok(false);
115            }
116            SendStatus::Full(msg) => msg,
117            SendStatus::Ok => match self.out_rx.poll_next_unpin(cx) {
118                Poll::Ready(Some((msg, tx))) => {
119                    self.waiting.push_back(tx);
120                    msg
121                }
122                Poll::Ready(None) => {
123                    self.send_status = SendStatus::End;
124                    return Ok(false);
125                }
126                Poll::Pending => return Ok(false),
127            },
128        };
129
130        self.impl_start_send(cx, message)
131    }
132
133    fn poll_complete(&mut self, cx: &mut Context) -> Result<(), error::Error> {
134        let _ = Pin::new(&mut self.connection).poll_flush(cx)?;
135        Ok(())
136    }
137
138    fn receive(&mut self, cx: &mut Context) -> Result<ReceiveStatus, error::Error> {
139        if let SendStatus::End = self.send_status {
140            if self.waiting.is_empty() {
141                return Ok(ReceiveStatus::ReadyFinished);
142            }
143        }
144        match self.connection.poll_next_unpin(cx) {
145            Poll::Ready(None) => Err(error::unexpected("Connection to Redis closed unexpectedly")),
146            Poll::Ready(Some(Ok(msg))) => {
147                let tx = match self.waiting.pop_front() {
148                    Some(tx) => tx,
149                    None => panic!("Received unexpected message: {:?}", msg),
150                };
151                let _ = tx.send(Ok(msg));
152                Ok(ReceiveStatus::ReadyMore)
153            }
154            Poll::Ready(Some(Err(e))) => Err(e),
155            Poll::Pending => Ok(ReceiveStatus::NotReady),
156        }
157    }
158
159    fn handle_error(&mut self, e: &error::Error) {
160        for tx in self.waiting.drain(..) {
161            let _ = tx.send(Err(error::internal(format!(
162                "Failed due to underlying failure: {}",
163                e
164            ))));
165        }
166
167        log::error!("Internal error in PairedConnectionInner: {}", e);
168    }
169}
170
171impl Future for PairedConnectionInner {
172    type Output = ();
173
174    #[allow(clippy::unit_arg)]
175    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
176        let mut_self = self.get_mut();
177        // If there's something to send, send it...
178        let mut sending = true;
179        while sending {
180            sending = match mut_self.poll_start_send(cx) {
181                Ok(sending) => sending,
182                Err(ref e) => return Poll::Ready(mut_self.handle_error(e)),
183            };
184        }
185
186        if let Err(ref e) = mut_self.poll_complete(cx) {
187            return Poll::Ready(mut_self.handle_error(e));
188        };
189
190        // If there's something to receive, receive it...
191        loop {
192            match mut_self.receive(cx) {
193                Ok(ReceiveStatus::NotReady) => return Poll::Pending,
194                Ok(ReceiveStatus::ReadyMore) => (),
195                Ok(ReceiveStatus::ReadyFinished) => return Poll::Ready(()),
196                Err(ref e) => return Poll::Ready(mut_self.handle_error(e)),
197            }
198        }
199    }
200}
201
202/// A shareable and cheaply cloneable connection to which Redis commands can be sent
203#[derive(Debug, Clone)]
204pub struct PairedConnection {
205    out_tx_c: Arc<Reconnect<SendPayload, mpsc::UnboundedSender<SendPayload>>>,
206}
207
208async fn inner_conn_fn(
209    host: String,
210    port: u16,
211    username: Option<Arc<str>>,
212    password: Option<Arc<str>>,
213    tls: bool,
214    socket_keepalive: Option<Duration>,
215    socket_timeout: Option<Duration>,
216) -> Result<mpsc::UnboundedSender<SendPayload>, error::Error> {
217    let username = username.as_ref().map(|u| u.as_ref());
218    let password = password.as_ref().map(|p| p.as_ref());
219    let connection = connect_with_auth(
220        &host,
221        port,
222        username,
223        password,
224        tls,
225        socket_keepalive,
226        socket_timeout,
227    )
228    .await?;
229    let (out_tx, out_rx) = mpsc::unbounded();
230    let paired_connection_inner = PairedConnectionInner::new(connection, out_rx);
231    tokio::spawn(paired_connection_inner);
232    Ok(out_tx)
233}
234
235impl ConnectionBuilder {
236    pub fn paired_connect(&self) -> impl Future<Output = Result<PairedConnection, error::Error>> {
237        let host = self.host.clone();
238        let port = self.port;
239        let username = self.username.clone();
240        let password = self.password.clone();
241
242        let work_fn = |con: &mpsc::UnboundedSender<SendPayload>, act| {
243            con.unbounded_send(act).map_err(|e| e.into())
244        };
245
246        #[cfg(feature = "tls")]
247        let tls = self.tls;
248        #[cfg(not(feature = "tls"))]
249        let tls = false;
250
251        let socket_keepalive = self.socket_keepalive;
252        let socket_timeout = self.socket_timeout;
253
254        let conn_fn = move || {
255            let con_f = inner_conn_fn(
256                host.clone(),
257                port,
258                username.clone(),
259                password.clone(),
260                tls,
261                socket_keepalive,
262                socket_timeout,
263            );
264            Box::pin(con_f) as Pin<Box<dyn Future<Output = Result<_, error::Error>> + Send + Sync>>
265        };
266
267        let reconnecting_con = reconnect(work_fn, conn_fn);
268        reconnecting_con.map_ok(|con| PairedConnection {
269            out_tx_c: Arc::new(con),
270        })
271    }
272}
273
274/// The default starting point to use most default Redis functionality.
275///
276/// Returns a future that resolves to a `PairedConnection`. The future will complete when the
277/// initial connection is established.
278///
279/// Once the initial connection is established, the connection will attempt to reconnect should
280/// the connection be broken (e.g. the Redis server being restarted), but reconnections occur
281/// asynchronously, so all commands issued while the connection is unavailable will error, it is
282/// the client's responsibility to retry commands as applicable. Also, at least one command needs
283/// to be tried against the connection to trigger the re-connection attempt; this means at least
284/// one command will definitely fail in a disconnect/reconnect scenario.
285pub async fn paired_connect(
286    host: impl Into<String>,
287    port: u16,
288) -> Result<PairedConnection, error::Error> {
289    ConnectionBuilder::new(host, port)?.paired_connect().await
290}
291
292impl PairedConnection {
293    /// Sends a command to Redis.
294    ///
295    /// The message must be in the format of a single RESP message, this can be constructed
296    /// manually or with the `resp_array!` macro.  Returned is a future that resolves to the value
297    /// returned from Redis.  The type must be one for which the `resp::FromResp` trait is defined.
298    ///
299    /// The future will fail for numerous reasons, including but not limited to: IO issues, conversion
300    /// problems, and server-side errors being returned by Redis.
301    ///
302    /// Behind the scenes the message is queued up and sent to Redis asynchronously before the
303    /// future is realised.  As such, it is guaranteed that messages are sent in the same order
304    /// that `send` is called.
305    pub fn send<T>(&self, msg: resp::RespValue) -> SendFuture<T>
306    where
307        T: resp::FromResp + Unpin,
308    {
309        match &msg {
310            resp::RespValue::Array(_) => (),
311            _ => {
312                return SendFuture::new(error::internal("Command must be a RespValue::Array"));
313            }
314        }
315
316        let (tx, rx) = oneshot::channel();
317        match self.out_tx_c.do_work((msg, tx)) {
318            Ok(()) => SendFuture::new(rx),
319            Err(e) => SendFuture::new(e),
320        }
321    }
322
323    #[inline]
324    pub fn send_and_forget(&self, msg: resp::RespValue) {
325        let send_f = self.send::<resp::RespValue>(msg);
326        let forget_f = async {
327            if let Err(e) = send_f.await {
328                log::error!("Error in send_and_forget: {}", e);
329            }
330        };
331        tokio::spawn(forget_f);
332    }
333}
334
335#[derive(Debug)]
336enum SendFutureType {
337    Wait(oneshot::Receiver<Result<resp::RespValue, error::Error>>),
338    Error(Option<error::Error>),
339}
340
341impl From<oneshot::Receiver<Result<resp::RespValue, error::Error>>> for SendFutureType {
342    fn from(from: oneshot::Receiver<Result<resp::RespValue, error::Error>>) -> Self {
343        Self::Wait(from)
344    }
345}
346
347impl From<error::Error> for SendFutureType {
348    fn from(e: error::Error) -> Self {
349        Self::Error(Some(e))
350    }
351}
352
353#[derive(Debug)]
354pub struct SendFuture<T> {
355    send_type: SendFutureType,
356    _phantom: PhantomData<T>,
357}
358
359impl<T> SendFuture<T> {
360    #[inline]
361    fn new(send_type: impl Into<SendFutureType>) -> Self {
362        Self {
363            send_type: send_type.into(),
364            _phantom: Default::default(),
365        }
366    }
367}
368
369impl<T> Future for SendFuture<T>
370where
371    T: resp::FromResp + Unpin,
372{
373    type Output = Result<T, error::Error>;
374
375    #[inline]
376    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
377        match self.get_mut().send_type {
378            SendFutureType::Error(ref mut e) => match e.take() {
379                Some(e) => Poll::Ready(Err(e)),
380                None => panic!("Future polled several times after completion"),
381            },
382            SendFutureType::Wait(ref mut rx) => match Pin::new(rx).poll(cx) {
383                Poll::Ready(Ok(Ok(v))) => Poll::Ready(T::from_resp(v)),
384                Poll::Ready(Ok(Err(e))) => Poll::Ready(Err(e)),
385                Poll::Ready(Err(_)) => Poll::Ready(Err(error::internal(
386                    "Connection closed before response received",
387                ))),
388                Poll::Pending => Poll::Pending,
389            },
390        }
391    }
392}
393
#[cfg(test)]
mod test {
    use super::ConnectionBuilder;

    #[tokio::test]
    async fn can_paired_connect() {
        let conn = super::paired_connect("127.0.0.1", 6379)
            .await
            .expect("Cannot establish connection");

        // Three commands queued back to back; each reply must reach the matching caller.
        let ping = conn.send(resp_array!["PING", "TEST"]);
        conn.send_and_forget(resp_array!["SET", "X", "123"]);
        let get = conn.send(resp_array!["GET", "X"]);

        let echoed: String = ping.await.expect("Cannot read result of first thing");
        let stored: String = get.await.expect("Cannot read result of second thing");

        assert_eq!(echoed, "TEST");
        assert_eq!(stored, "123");
    }

    #[tokio::test]
    async fn complex_paired_connect() {
        let conn = super::paired_connect("127.0.0.1", 6379)
            .await
            .expect("Cannot establish connection");

        // The output of one command feeds the input of the next.
        let counter: String = conn
            .send(resp_array!["INCR", "CTR"])
            .await
            .expect("Cannot increment counter");
        let reply: String = conn
            .send(resp_array!["SET", "LASTCTR", counter])
            .await
            .expect("Cannot set value");

        assert_eq!(reply, "OK");
    }

    #[tokio::test]
    async fn sending_a_lot_of_data_test() {
        let conn = super::paired_connect("127.0.0.1", 6379)
            .await
            .expect("Cannot connect to Redis");
        let mut gets = Vec::with_capacity(1000);
        for n in 0..1000 {
            let key = format!("X_{}", n);
            conn.send_and_forget(resp_array!["SET", &key, n.to_string()]);
            gets.push(conn.send(resp_array!["GET", key]));
        }
        // Only the final GET is awaited; it was queued after its matching SET.
        let last_get = gets.remove(999);
        let last: String = last_get.await.expect("Cannot wait for result");
        assert_eq!(last, "999");
    }

    #[tokio::test]
    async fn test_builder() {
        let mut builder =
            ConnectionBuilder::new("127.0.0.1", 6379).expect("Cannot construct builder...");
        builder.password("password");
        builder.username(String::from("username"));
        // The credentials are deliberately wrong, so connecting must fail.
        let outcome = builder.paired_connect().await;
        assert!(outcome.is_err());
    }
}