redis_async/client/
connect.rs

1/*
2 * Copyright 2017-2024 Ben Ashford
3 *
4 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5 * http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6 * <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
7 * option. This file may not be copied, modified, or distributed
8 * except according to those terms.
9 */
10
11use std::time::Duration;
12
13use futures_util::{SinkExt, StreamExt};
14use pin_project::pin_project;
15use tokio::{
16    io::{AsyncRead, AsyncWrite},
17    net::TcpStream,
18};
19use tokio_util::codec::{Decoder, Framed};
20
21use crate::{
22    error,
23    resp::{self, RespCodec},
24};
25
26#[pin_project(project = RespConnectionInnerProj)]
27pub enum RespConnectionInner {
28    #[cfg(feature = "with-rustls")]
29    Tls {
30        #[pin]
31        stream: tokio_rustls::client::TlsStream<TcpStream>,
32    },
33    #[cfg(feature = "with-native-tls")]
34    Tls {
35        #[pin]
36        stream: tokio_native_tls::TlsStream<TcpStream>,
37    },
38    Plain {
39        #[pin]
40        stream: TcpStream,
41    },
42}
43
44impl AsyncWrite for RespConnectionInner {
45    fn poll_write(
46        self: std::pin::Pin<&mut Self>,
47        cx: &mut std::task::Context<'_>,
48        buf: &[u8],
49    ) -> std::task::Poll<Result<usize, std::io::Error>> {
50        let this = self.project();
51        match this {
52            #[cfg(feature = "tls")]
53            RespConnectionInnerProj::Tls { stream } => stream.poll_write(cx, buf),
54            RespConnectionInnerProj::Plain { stream } => stream.poll_write(cx, buf),
55        }
56    }
57
58    fn poll_flush(
59        self: std::pin::Pin<&mut Self>,
60        cx: &mut std::task::Context<'_>,
61    ) -> std::task::Poll<Result<(), std::io::Error>> {
62        let this = self.project();
63        match this {
64            #[cfg(feature = "tls")]
65            RespConnectionInnerProj::Tls { stream } => stream.poll_flush(cx),
66            RespConnectionInnerProj::Plain { stream } => stream.poll_flush(cx),
67        }
68    }
69
70    fn poll_shutdown(
71        self: std::pin::Pin<&mut Self>,
72        cx: &mut std::task::Context<'_>,
73    ) -> std::task::Poll<Result<(), std::io::Error>> {
74        let this = self.project();
75        match this {
76            #[cfg(feature = "tls")]
77            RespConnectionInnerProj::Tls { stream } => stream.poll_shutdown(cx),
78            RespConnectionInnerProj::Plain { stream } => stream.poll_shutdown(cx),
79        }
80    }
81}
82
83impl AsyncRead for RespConnectionInner {
84    fn poll_read(
85        self: std::pin::Pin<&mut Self>,
86        cx: &mut std::task::Context<'_>,
87        buf: &mut tokio::io::ReadBuf<'_>,
88    ) -> std::task::Poll<std::io::Result<()>> {
89        let this = self.project();
90        match this {
91            #[cfg(feature = "tls")]
92            RespConnectionInnerProj::Tls { stream } => stream.poll_read(cx, buf),
93            RespConnectionInnerProj::Plain { stream } => stream.poll_read(cx, buf),
94        }
95    }
96}
97
98pub type RespConnection = Framed<RespConnectionInner, RespCodec>;
99
100/// Connect to a Redis server and return a Future that resolves to a
101/// `RespConnection` for reading and writing asynchronously.
102///
103/// Each `RespConnection` implements both `Sink` and `Stream` and read and
104/// writes `RESP` objects.
105///
106/// This is a low-level interface to enable the creation of higher-level
107/// functionality.
108///
109/// The sink and stream sides behave independently of each other, it is the
110/// responsibility of the calling application to determine what results are
111/// paired to a particular command.
112///
113/// But since most Redis usages involve issue commands that result in one
114/// single result, this library also implements `paired_connect`.
115pub async fn connect(
116    host: &str,
117    port: u16,
118    socket_keepalive: Option<Duration>,
119    socket_timeout: Option<Duration>,
120) -> Result<RespConnection, error::Error> {
121    let tcp_stream = TcpStream::connect((host, port)).await?;
122    apply_keepalive_and_timeouts(&tcp_stream, socket_keepalive, socket_timeout)?;
123    Ok(RespCodec.framed(RespConnectionInner::Plain { stream: tcp_stream }))
124}
125
/// Connect to a Redis server over TLS (using `rustls`) and return a
/// `RespConnection`.
///
/// The server certificate is verified against the
/// `webpki_roots::TLS_SERVER_ROOTS` trust anchors, and `host` is used as the
/// TLS server name. No client certificate is presented. `socket_keepalive`
/// and `socket_timeout` are applied to the underlying TCP socket before the
/// handshake.
///
/// # Errors
///
/// * `Error::Connection(ConnectionReason::ConnectionFailed)` if `host`
///   resolves to no addresses.
/// * `Error::InvalidDnsName` if `host` is not a valid TLS server name.
/// * Errors from resolution, the TCP connect (the last address's error when
///   every address fails), setting socket options, or the TLS handshake.
#[cfg(feature = "with-rustls")]
pub async fn connect_tls(
    host: &str,
    port: u16,
    socket_keepalive: Option<Duration>,
    socket_timeout: Option<Duration>,
) -> Result<RespConnection, error::Error> {
    use std::sync::Arc;
    use tokio_rustls::{
        rustls::{ClientConfig, RootCertStore},
        TlsConnector,
    };

    let mut root_store = RootCertStore::empty();
    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    let config = ClientConfig::builder()
        .with_root_certificates(root_store)
        .with_no_client_auth();
    let connector = TlsConnector::from(Arc::new(config));

    // Try every resolved address in order, matching the fallback behavior of
    // the plain `connect`. Using only the first result fails whenever that
    // address is unreachable (e.g. `localhost` resolving to `::1` first while
    // the server only listens on `127.0.0.1`).
    let addrs: Vec<_> = tokio::net::lookup_host((host, port)).await?.collect();
    if addrs.is_empty() {
        return Err(error::Error::Connection(
            error::ConnectionReason::ConnectionFailed,
        ));
    }
    let tcp_stream = TcpStream::connect(addrs.as_slice()).await?;
    apply_keepalive_and_timeouts(&tcp_stream, socket_keepalive, socket_timeout)?;

    let stream = connector
        .connect(
            String::from(host)
                .try_into()
                .map_err(|_err| error::Error::InvalidDnsName)?,
            tcp_stream,
        )
        .await?;
    Ok(RespCodec.framed(RespConnectionInner::Tls { stream }))
}
165
/// Connect to a Redis server over TLS (using `native-tls`) and return a
/// `RespConnection`.
///
/// The TLS connector is built with `native_tls` default settings, and `host`
/// is passed as the domain for the handshake. `socket_keepalive` and
/// `socket_timeout` are applied to the underlying TCP socket before the
/// handshake.
///
/// # Errors
///
/// * `Error::Connection(ConnectionReason::ConnectionFailed)` if `host`
///   resolves to no addresses.
/// * Errors from building the TLS connector, resolution, the TCP connect (the
///   last address's error when every address fails), setting socket options,
///   or the TLS handshake.
#[cfg(feature = "with-native-tls")]
pub async fn connect_tls(
    host: &str,
    port: u16,
    socket_keepalive: Option<Duration>,
    socket_timeout: Option<Duration>,
) -> Result<RespConnection, error::Error> {
    let connector = native_tls::TlsConnector::builder().build()?;
    let connector = tokio_native_tls::TlsConnector::from(connector);

    // Try every resolved address in order, matching the fallback behavior of
    // the plain `connect`. Using only the first result fails whenever that
    // address is unreachable (e.g. `localhost` resolving to `::1` first while
    // the server only listens on `127.0.0.1`).
    let addrs: Vec<_> = tokio::net::lookup_host((host, port)).await?.collect();
    if addrs.is_empty() {
        return Err(error::Error::Connection(
            error::ConnectionReason::ConnectionFailed,
        ));
    }
    let tcp_stream = TcpStream::connect(addrs.as_slice()).await?;
    apply_keepalive_and_timeouts(&tcp_stream, socket_keepalive, socket_timeout)?;

    let stream = connector.connect(host, tcp_stream).await?;
    Ok(RespCodec.framed(RespConnectionInner::Tls { stream }))
}
189
190pub async fn connect_with_auth(
191    host: &str,
192    port: u16,
193    username: Option<&str>,
194    password: Option<&str>,
195    #[allow(unused_variables)] tls: bool,
196    socket_keepalive: Option<Duration>,
197    socket_timeout: Option<Duration>,
198) -> Result<RespConnection, error::Error> {
199    #[cfg(feature = "tls")]
200    let mut connection = if tls {
201        connect_tls(host, port, socket_keepalive, socket_timeout).await?
202    } else {
203        connect(host, port, socket_keepalive, socket_timeout).await?
204    };
205    #[cfg(not(feature = "tls"))]
206    let mut connection = connect(host, port, socket_keepalive, socket_timeout).await?;
207
208    if let Some(password) = password {
209        let mut auth = resp_array!["AUTH"];
210
211        if let Some(username) = username {
212            auth.push(username);
213        }
214
215        auth.push(password);
216
217        connection.send(auth).await?;
218        match connection.next().await {
219            Some(Ok(value)) => match resp::FromResp::from_resp(value) {
220                Ok(()) => (),
221                Err(e) => return Err(e),
222            },
223            Some(Err(e)) => return Err(e),
224            None => {
225                return Err(error::internal(
226                    "Connection closed before authentication complete",
227                ))
228            }
229        }
230    }
231
232    Ok(connection)
233}
234
235/// Apply a custom keep-alive value to the connection
236fn apply_keepalive_and_timeouts(
237    stream: &TcpStream,
238    socket_keepalive: Option<Duration>,
239    socket_timeout: Option<Duration>,
240) -> Result<(), error::Error> {
241    let sock_ref = socket2::SockRef::from(stream);
242
243    if let Some(interval) = socket_keepalive {
244        let keep_alive = socket2::TcpKeepalive::new()
245            .with_time(interval)
246            .with_interval(interval);
247        // Not windows
248        #[cfg(any(
249            target_os = "android",
250            target_os = "dragonfly",
251            target_os = "freebsd",
252            target_os = "fuchsia",
253            target_os = "illumos",
254            target_os = "ios",
255            target_os = "linux",
256            target_os = "macos",
257            target_os = "netbsd",
258            target_os = "tvos",
259            target_os = "watchos",
260        ))]
261        let keep_alive = keep_alive.with_retries(1);
262        sock_ref.set_tcp_keepalive(&keep_alive)?;
263    }
264
265    if let Some(timeout) = socket_timeout {
266        sock_ref.set_read_timeout(Some(timeout))?;
267        sock_ref.set_write_timeout(Some(timeout))?;
268    }
269
270    Ok(())
271}
272
#[cfg(test)]
mod test {
    use futures_util::{
        sink::SinkExt,
        stream::{self, StreamExt},
    };

    use crate::resp;

    // These tests expect a Redis server listening on 127.0.0.1:6379.

    #[tokio::test]
    async fn can_connect() {
        let mut connection = super::connect("127.0.0.1", 6379, None, None)
            .await
            .expect("Cannot connect");
        connection
            .send(resp_array!["PING", "TEST"])
            .await
            .expect("Cannot send PING");
        let values: Vec<_> = connection
            .take(1)
            .map(|r| r.expect("Unexpected invalid data"))
            .collect()
            .await;

        assert_eq!(values.len(), 1);
        assert_eq!(values[0], "TEST".into());
    }

    #[tokio::test]
    async fn complex_test() {
        let mut connection = super::connect("127.0.0.1", 6379, None, None)
            .await
            .expect("Cannot connect");
        // Reset the set first so SMEMBERS returns exactly the 1000 values
        // added below. `FLUSH` is not a Redis command, so it never reset
        // anything. `DEL` clears only this test's key instead of wiping a
        // database other concurrently running tests may use.
        let mut ops = Vec::with_capacity(1002);
        ops.push(resp_array!["DEL", "test_set"]);
        ops.extend((0..1000).map(|i| resp_array!["SADD", "test_set", format!("VALUE: {}", i)]));
        ops.push(resp_array!["SMEMBERS", "test_set"]);
        let mut ops_stream = stream::iter(ops).map(Ok);
        connection
            .send_all(&mut ops_stream)
            .await
            .expect("Cannot send");
        // Skip the DEL reply and the 1000 SADD replies; keep the SMEMBERS reply.
        let values: Vec<_> = connection
            .skip(1001)
            .take(1)
            .map(|r| r.expect("Unexpected invalid data"))
            .collect()
            .await;

        assert_eq!(values.len(), 1);
        let values = match &values[0] {
            resp::RespValue::Array(ref values) => values.clone(),
            _ => panic!("Not an array"),
        };
        assert_eq!(values.len(), 1000);
    }
}
329}