redis_async/client/
connect.rs

/*
 * Copyright 2017-2024 Ben Ashford
 *
 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
 * http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
 * <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
 * option. This file may not be copied, modified, or distributed
 * except according to those terms.
 */
10
11use std::time::Duration;
12
13use futures_util::{SinkExt, StreamExt};
14use pin_project::pin_project;
15use tokio::{
16    io::{AsyncRead, AsyncWrite},
17    net::TcpStream,
18};
19use tokio_util::codec::{Decoder, Framed};
20
21use crate::{
22    error,
23    resp::{self, RespCodec},
24};
25
26#[pin_project(project = RespConnectionInnerProj)]
27#[cfg_attr(
28    any(feature = "with-rustls", feature = "with-native-tls"),
29    expect(
30        clippy::large_enum_variant,
31        reason = "will only be enabled if selected, and isn't moved once built"
32    )
33)]
34pub enum RespConnectionInner {
35    #[cfg(feature = "with-rustls")]
36    Tls {
37        #[pin]
38        stream: tokio_rustls::client::TlsStream<TcpStream>,
39    },
40    #[cfg(feature = "with-native-tls")]
41    Tls {
42        #[pin]
43        stream: tokio_native_tls::TlsStream<TcpStream>,
44    },
45    Plain {
46        #[pin]
47        stream: TcpStream,
48    },
49}
50
51impl AsyncWrite for RespConnectionInner {
52    fn poll_write(
53        self: std::pin::Pin<&mut Self>,
54        cx: &mut std::task::Context<'_>,
55        buf: &[u8],
56    ) -> std::task::Poll<Result<usize, std::io::Error>> {
57        let this = self.project();
58        match this {
59            #[cfg(feature = "tls")]
60            RespConnectionInnerProj::Tls { stream } => stream.poll_write(cx, buf),
61            RespConnectionInnerProj::Plain { stream } => stream.poll_write(cx, buf),
62        }
63    }
64
65    fn poll_flush(
66        self: std::pin::Pin<&mut Self>,
67        cx: &mut std::task::Context<'_>,
68    ) -> std::task::Poll<Result<(), std::io::Error>> {
69        let this = self.project();
70        match this {
71            #[cfg(feature = "tls")]
72            RespConnectionInnerProj::Tls { stream } => stream.poll_flush(cx),
73            RespConnectionInnerProj::Plain { stream } => stream.poll_flush(cx),
74        }
75    }
76
77    fn poll_shutdown(
78        self: std::pin::Pin<&mut Self>,
79        cx: &mut std::task::Context<'_>,
80    ) -> std::task::Poll<Result<(), std::io::Error>> {
81        let this = self.project();
82        match this {
83            #[cfg(feature = "tls")]
84            RespConnectionInnerProj::Tls { stream } => stream.poll_shutdown(cx),
85            RespConnectionInnerProj::Plain { stream } => stream.poll_shutdown(cx),
86        }
87    }
88}
89
90impl AsyncRead for RespConnectionInner {
91    fn poll_read(
92        self: std::pin::Pin<&mut Self>,
93        cx: &mut std::task::Context<'_>,
94        buf: &mut tokio::io::ReadBuf<'_>,
95    ) -> std::task::Poll<std::io::Result<()>> {
96        let this = self.project();
97        match this {
98            #[cfg(feature = "tls")]
99            RespConnectionInnerProj::Tls { stream } => stream.poll_read(cx, buf),
100            RespConnectionInnerProj::Plain { stream } => stream.poll_read(cx, buf),
101        }
102    }
103}
104
105pub type RespConnection = Framed<RespConnectionInner, RespCodec>;
106
107/// Connect to a Redis server and return a Future that resolves to a
108/// `RespConnection` for reading and writing asynchronously.
109///
110/// Each `RespConnection` implements both `Sink` and `Stream` and read and
111/// writes `RESP` objects.
112///
113/// This is a low-level interface to enable the creation of higher-level
114/// functionality.
115///
116/// The sink and stream sides behave independently of each other, it is the
117/// responsibility of the calling application to determine what results are
118/// paired to a particular command.
119///
120/// But since most Redis usages involve issue commands that result in one
121/// single result, this library also implements `paired_connect`.
122pub async fn connect(
123    host: &str,
124    port: u16,
125    socket_keepalive: Option<Duration>,
126    socket_timeout: Option<Duration>,
127) -> Result<RespConnection, error::Error> {
128    let tcp_stream = TcpStream::connect((host, port)).await?;
129    apply_keepalive_and_timeouts(&tcp_stream, socket_keepalive, socket_timeout)?;
130    Ok(RespCodec.framed(RespConnectionInner::Plain { stream: tcp_stream }))
131}
132
/// Connect to a Redis server over TLS using `rustls` and return a
/// [`RespConnection`].
///
/// Server certificates are verified against the roots in
/// `webpki_roots::TLS_SERVER_ROOTS`, and no client certificate is presented.
/// `host` is used as the TLS server name; `socket_keepalive` and
/// `socket_timeout` are applied to the underlying TCP socket.
///
/// # Errors
/// * `Connection(ConnectionFailed)` if `host` resolves to no addresses;
/// * an I/O error if resolution fails or no resolved address accepts a
///   connection;
/// * `InvalidDnsName` if `host` is not a valid TLS server name;
/// * any error from applying socket options or from the TLS handshake.
#[cfg(feature = "with-rustls")]
pub async fn connect_tls(
    host: &str,
    port: u16,
    socket_keepalive: Option<Duration>,
    socket_timeout: Option<Duration>,
) -> Result<RespConnection, error::Error> {
    use std::sync::Arc;
    use tokio_rustls::{
        rustls::{ClientConfig, RootCertStore},
        TlsConnector,
    };

    let mut root_store = RootCertStore::empty();
    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    let config = ClientConfig::builder()
        .with_root_certificates(root_store)
        .with_no_client_auth();
    let connector = TlsConnector::from(Arc::new(config));

    // Resolve every candidate address up front. An empty answer is still
    // reported as `ConnectionFailed`. Otherwise `TcpStream::connect` tries each
    // address in turn, as the plain-TCP `connect` does, instead of failing
    // when only the first address is unreachable (e.g. IPv6 on an
    // IPv4-only network).
    let addrs: Vec<std::net::SocketAddr> =
        tokio::net::lookup_host((host, port)).await?.collect();
    if addrs.is_empty() {
        return Err(error::Error::Connection(
            error::ConnectionReason::ConnectionFailed,
        ));
    }
    let tcp_stream = TcpStream::connect(addrs.as_slice()).await?;
    apply_keepalive_and_timeouts(&tcp_stream, socket_keepalive, socket_timeout)?;

    // The handshake validates against the caller's host name, not the
    // resolved address.
    let stream = connector
        .connect(
            String::from(host)
                .try_into()
                .map_err(|_err| error::Error::InvalidDnsName)?,
            tcp_stream,
        )
        .await?;
    Ok(RespCodec.framed(RespConnectionInner::Tls { stream }))
}
172
/// Connect to a Redis server over TLS using `native-tls` and return a
/// [`RespConnection`].
///
/// The TLS connector is built with `native_tls` default settings and `host`
/// is passed to the handshake as the server's domain. `socket_keepalive` and
/// `socket_timeout` are applied to the underlying TCP socket.
///
/// # Errors
/// * `Connection(ConnectionFailed)` if `host` resolves to no addresses;
/// * an I/O error if resolution fails or no resolved address accepts a
///   connection;
/// * any error from building the TLS connector, applying socket options, or
///   the TLS handshake.
#[cfg(feature = "with-native-tls")]
pub async fn connect_tls(
    host: &str,
    port: u16,
    socket_keepalive: Option<Duration>,
    socket_timeout: Option<Duration>,
) -> Result<RespConnection, error::Error> {
    let cx = native_tls::TlsConnector::builder().build()?;
    let cx = tokio_native_tls::TlsConnector::from(cx);

    // Resolve every candidate address up front. An empty answer is still
    // reported as `ConnectionFailed`. Otherwise `TcpStream::connect` tries each
    // address in turn, as the plain-TCP `connect` does, instead of failing
    // when only the first address is unreachable (e.g. IPv6 on an
    // IPv4-only network).
    let addrs: Vec<std::net::SocketAddr> =
        tokio::net::lookup_host((host, port)).await?.collect();
    if addrs.is_empty() {
        return Err(error::Error::Connection(
            error::ConnectionReason::ConnectionFailed,
        ));
    }
    let tcp_stream = TcpStream::connect(addrs.as_slice()).await?;
    apply_keepalive_and_timeouts(&tcp_stream, socket_keepalive, socket_timeout)?;
    let stream = cx.connect(host, tcp_stream).await?;

    Ok(RespCodec.framed(RespConnectionInner::Tls { stream }))
}
196
197pub async fn connect_with_auth(
198    host: &str,
199    port: u16,
200    username: Option<&str>,
201    password: Option<&str>,
202    #[allow(unused_variables)] tls: bool,
203    socket_keepalive: Option<Duration>,
204    socket_timeout: Option<Duration>,
205) -> Result<RespConnection, error::Error> {
206    #[cfg(feature = "tls")]
207    let mut connection = if tls {
208        connect_tls(host, port, socket_keepalive, socket_timeout).await?
209    } else {
210        connect(host, port, socket_keepalive, socket_timeout).await?
211    };
212    #[cfg(not(feature = "tls"))]
213    let mut connection = connect(host, port, socket_keepalive, socket_timeout).await?;
214
215    if let Some(password) = password {
216        let mut auth = resp_array!["AUTH"];
217
218        if let Some(username) = username {
219            auth.push(username);
220        }
221
222        auth.push(password);
223
224        connection.send(auth).await?;
225        match connection.next().await {
226            Some(Ok(value)) => match resp::FromResp::from_resp(value) {
227                Ok(()) => (),
228                Err(e) => return Err(e),
229            },
230            Some(Err(e)) => return Err(e),
231            None => {
232                return Err(error::internal(
233                    "Connection closed before authentication complete",
234                ))
235            }
236        }
237    }
238
239    Ok(connection)
240}
241
242/// Apply a custom keep-alive value to the connection
243fn apply_keepalive_and_timeouts(
244    stream: &TcpStream,
245    socket_keepalive: Option<Duration>,
246    socket_timeout: Option<Duration>,
247) -> Result<(), error::Error> {
248    let sock_ref = socket2::SockRef::from(stream);
249
250    if let Some(interval) = socket_keepalive {
251        let keep_alive = socket2::TcpKeepalive::new()
252            .with_time(interval)
253            .with_interval(interval);
254        // Not windows
255        #[cfg(any(
256            target_os = "android",
257            target_os = "dragonfly",
258            target_os = "freebsd",
259            target_os = "fuchsia",
260            target_os = "illumos",
261            target_os = "ios",
262            target_os = "linux",
263            target_os = "macos",
264            target_os = "netbsd",
265            target_os = "tvos",
266            target_os = "watchos",
267        ))]
268        let keep_alive = keep_alive.with_retries(1);
269        sock_ref.set_tcp_keepalive(&keep_alive)?;
270    }
271
272    if let Some(timeout) = socket_timeout {
273        sock_ref.set_read_timeout(Some(timeout))?;
274        sock_ref.set_write_timeout(Some(timeout))?;
275    }
276
277    Ok(())
278}
279
#[cfg(test)]
mod test {
    //! Integration tests for the low-level connection functions.
    //!
    //! These require a Redis server listening on `127.0.0.1:6379`.

    use futures_util::{
        sink::SinkExt,
        stream::{self, StreamExt},
    };

    use crate::resp;

    /// A single PING round-trip echoes its argument back.
    #[tokio::test]
    async fn can_connect() {
        let mut connection = super::connect("127.0.0.1", 6379, None, None)
            .await
            .expect("Cannot connect");
        connection
            .send(resp_array!["PING", "TEST"])
            .await
            .expect("Cannot send PING");
        let values: Vec<_> = connection
            .take(1)
            .map(|r| r.expect("Unexpected invalid data"))
            .collect()
            .await;

        assert_eq!(values.len(), 1);
        assert_eq!(values[0], "TEST".into());
    }

    /// Pipelines 1002 commands and checks that the final SMEMBERS sees exactly
    /// the 1000 members that were added.
    #[tokio::test]
    async fn complex_test() {
        let mut connection = super::connect("127.0.0.1", 6379, None, None)
            .await
            .expect("Cannot connect");
        let mut ops = Vec::new();
        // Start from an empty set so the member count asserted below is exact.
        // Only this key is removed, so other tests sharing the server are
        // unaffected.
        ops.push(resp_array!["DEL", "test_set"]);
        ops.extend((0..1000).map(|i| resp_array!["SADD", "test_set", format!("VALUE: {}", i)]));
        ops.push(resp_array!["SMEMBERS", "test_set"]);
        let mut ops_stream = stream::iter(ops).map(Ok);
        connection
            .send_all(&mut ops_stream)
            .await
            .expect("Cannot send");
        // Skip the replies to DEL and the 1000 SADDs; keep only SMEMBERS.
        let values: Vec<_> = connection
            .skip(1001)
            .take(1)
            .map(|r| r.expect("Unexpected invalid data"))
            .collect()
            .await;

        assert_eq!(values.len(), 1);
        let values = match &values[0] {
            resp::RespValue::Array(ref values) => values.clone(),
            _ => panic!("Not an array"),
        };
        assert_eq!(values.len(), 1000);
    }
}
336}