1use std::time::Duration;
12
13use futures_util::{SinkExt, StreamExt};
14use pin_project::pin_project;
15use tokio::{
16 io::{AsyncRead, AsyncWrite},
17 net::TcpStream,
18};
19use tokio_util::codec::{Decoder, Framed};
20
21use crate::{
22 error,
23 resp::{self, RespCodec},
24};
25
26#[pin_project(project = RespConnectionInnerProj)]
27#[cfg_attr(
28 any(feature = "with-rustls", feature = "with-native-tls"),
29 expect(
30 clippy::large_enum_variant,
31 reason = "will only be enabled if selected, and isn't moved once built"
32 )
33)]
34pub enum RespConnectionInner {
35 #[cfg(feature = "with-rustls")]
36 Tls {
37 #[pin]
38 stream: tokio_rustls::client::TlsStream<TcpStream>,
39 },
40 #[cfg(feature = "with-native-tls")]
41 Tls {
42 #[pin]
43 stream: tokio_native_tls::TlsStream<TcpStream>,
44 },
45 Plain {
46 #[pin]
47 stream: TcpStream,
48 },
49}
50
51impl AsyncWrite for RespConnectionInner {
52 fn poll_write(
53 self: std::pin::Pin<&mut Self>,
54 cx: &mut std::task::Context<'_>,
55 buf: &[u8],
56 ) -> std::task::Poll<Result<usize, std::io::Error>> {
57 let this = self.project();
58 match this {
59 #[cfg(feature = "tls")]
60 RespConnectionInnerProj::Tls { stream } => stream.poll_write(cx, buf),
61 RespConnectionInnerProj::Plain { stream } => stream.poll_write(cx, buf),
62 }
63 }
64
65 fn poll_flush(
66 self: std::pin::Pin<&mut Self>,
67 cx: &mut std::task::Context<'_>,
68 ) -> std::task::Poll<Result<(), std::io::Error>> {
69 let this = self.project();
70 match this {
71 #[cfg(feature = "tls")]
72 RespConnectionInnerProj::Tls { stream } => stream.poll_flush(cx),
73 RespConnectionInnerProj::Plain { stream } => stream.poll_flush(cx),
74 }
75 }
76
77 fn poll_shutdown(
78 self: std::pin::Pin<&mut Self>,
79 cx: &mut std::task::Context<'_>,
80 ) -> std::task::Poll<Result<(), std::io::Error>> {
81 let this = self.project();
82 match this {
83 #[cfg(feature = "tls")]
84 RespConnectionInnerProj::Tls { stream } => stream.poll_shutdown(cx),
85 RespConnectionInnerProj::Plain { stream } => stream.poll_shutdown(cx),
86 }
87 }
88}
89
90impl AsyncRead for RespConnectionInner {
91 fn poll_read(
92 self: std::pin::Pin<&mut Self>,
93 cx: &mut std::task::Context<'_>,
94 buf: &mut tokio::io::ReadBuf<'_>,
95 ) -> std::task::Poll<std::io::Result<()>> {
96 let this = self.project();
97 match this {
98 #[cfg(feature = "tls")]
99 RespConnectionInnerProj::Tls { stream } => stream.poll_read(cx, buf),
100 RespConnectionInnerProj::Plain { stream } => stream.poll_read(cx, buf),
101 }
102 }
103}
104
105pub type RespConnection = Framed<RespConnectionInner, RespCodec>;
106
107pub async fn connect(
123 host: &str,
124 port: u16,
125 socket_keepalive: Option<Duration>,
126 socket_timeout: Option<Duration>,
127) -> Result<RespConnection, error::Error> {
128 let tcp_stream = TcpStream::connect((host, port)).await?;
129 apply_keepalive_and_timeouts(&tcp_stream, socket_keepalive, socket_timeout)?;
130 Ok(RespCodec.framed(RespConnectionInner::Plain { stream: tcp_stream }))
131}
132
/// Opens a TLS-encrypted connection (via rustls) to `host:port`.
///
/// Server certificates are verified against the `webpki-roots` trust anchors,
/// and `host` is used as the TLS server name. Every address `host` resolves to
/// is tried in order, just as the plain [`connect`] does. A host whose first
/// resolved address is unreachable (e.g. `::1` for a server bound only to
/// `127.0.0.1`) can therefore still be reached.
///
/// # Errors
///
/// * `error::Error::InvalidDnsName` if `host` is not a valid TLS server name.
/// * `error::Error::Connection(ConnectionFailed)` if `host` resolves to no
///   addresses.
/// * I/O errors from name resolution, the TCP connect, socket configuration,
///   or the TLS handshake.
#[cfg(feature = "with-rustls")]
pub async fn connect_tls(
    host: &str,
    port: u16,
    socket_keepalive: Option<Duration>,
    socket_timeout: Option<Duration>,
) -> Result<RespConnection, error::Error> {
    use std::sync::Arc;
    use tokio_rustls::{
        rustls::{ClientConfig, RootCertStore},
        TlsConnector,
    };

    // Validate the server name up front so an invalid host fails without any
    // network traffic.
    let server_name = String::from(host)
        .try_into()
        .map_err(|_err| error::Error::InvalidDnsName)?;

    let mut root_store = RootCertStore::empty();
    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    let config = ClientConfig::builder()
        .with_root_certificates(root_store)
        .with_no_client_auth();
    let connector = TlsConnector::from(Arc::new(config));

    let addrs: Vec<_> = tokio::net::lookup_host((host, port)).await?.collect();
    if addrs.is_empty() {
        return Err(error::Error::Connection(
            error::ConnectionReason::ConnectionFailed,
        ));
    }
    // Try each resolved address in turn (not just the first).
    let tcp_stream = TcpStream::connect(addrs.as_slice()).await?;
    apply_keepalive_and_timeouts(&tcp_stream, socket_keepalive, socket_timeout)?;

    let stream = connector.connect(server_name, tcp_stream).await?;
    Ok(RespCodec.framed(RespConnectionInner::Tls { stream }))
}
172
/// Opens a TLS-encrypted connection (via native-tls) to `host:port`.
///
/// The connector uses native-tls' default settings, and `host` is passed as
/// the TLS domain. Every address `host` resolves to is tried in order, just as
/// the plain [`connect`] does, rather than only the first one.
///
/// # Errors
///
/// * `error::Error::Connection(ConnectionFailed)` if `host` resolves to no
///   addresses.
/// * Errors from building the TLS connector, name resolution, the TCP connect,
///   socket configuration, or the TLS handshake.
#[cfg(feature = "with-native-tls")]
pub async fn connect_tls(
    host: &str,
    port: u16,
    socket_keepalive: Option<Duration>,
    socket_timeout: Option<Duration>,
) -> Result<RespConnection, error::Error> {
    let connector = native_tls::TlsConnector::builder().build()?;
    let connector = tokio_native_tls::TlsConnector::from(connector);

    let addrs: Vec<_> = tokio::net::lookup_host((host, port)).await?.collect();
    if addrs.is_empty() {
        return Err(error::Error::Connection(
            error::ConnectionReason::ConnectionFailed,
        ));
    }
    // Try each resolved address in turn (not just the first).
    let tcp_stream = TcpStream::connect(addrs.as_slice()).await?;
    apply_keepalive_and_timeouts(&tcp_stream, socket_keepalive, socket_timeout)?;

    let stream = connector.connect(host, tcp_stream).await?;
    Ok(RespCodec.framed(RespConnectionInner::Tls { stream }))
}
196
197pub async fn connect_with_auth(
198 host: &str,
199 port: u16,
200 username: Option<&str>,
201 password: Option<&str>,
202 #[allow(unused_variables)] tls: bool,
203 socket_keepalive: Option<Duration>,
204 socket_timeout: Option<Duration>,
205) -> Result<RespConnection, error::Error> {
206 #[cfg(feature = "tls")]
207 let mut connection = if tls {
208 connect_tls(host, port, socket_keepalive, socket_timeout).await?
209 } else {
210 connect(host, port, socket_keepalive, socket_timeout).await?
211 };
212 #[cfg(not(feature = "tls"))]
213 let mut connection = connect(host, port, socket_keepalive, socket_timeout).await?;
214
215 if let Some(password) = password {
216 let mut auth = resp_array!["AUTH"];
217
218 if let Some(username) = username {
219 auth.push(username);
220 }
221
222 auth.push(password);
223
224 connection.send(auth).await?;
225 match connection.next().await {
226 Some(Ok(value)) => match resp::FromResp::from_resp(value) {
227 Ok(()) => (),
228 Err(e) => return Err(e),
229 },
230 Some(Err(e)) => return Err(e),
231 None => {
232 return Err(error::internal(
233 "Connection closed before authentication complete",
234 ))
235 }
236 }
237 }
238
239 Ok(connection)
240}
241
242fn apply_keepalive_and_timeouts(
244 stream: &TcpStream,
245 socket_keepalive: Option<Duration>,
246 socket_timeout: Option<Duration>,
247) -> Result<(), error::Error> {
248 let sock_ref = socket2::SockRef::from(stream);
249
250 if let Some(interval) = socket_keepalive {
251 let keep_alive = socket2::TcpKeepalive::new()
252 .with_time(interval)
253 .with_interval(interval);
254 #[cfg(any(
256 target_os = "android",
257 target_os = "dragonfly",
258 target_os = "freebsd",
259 target_os = "fuchsia",
260 target_os = "illumos",
261 target_os = "ios",
262 target_os = "linux",
263 target_os = "macos",
264 target_os = "netbsd",
265 target_os = "tvos",
266 target_os = "watchos",
267 ))]
268 let keep_alive = keep_alive.with_retries(1);
269 sock_ref.set_tcp_keepalive(&keep_alive)?;
270 }
271
272 if let Some(timeout) = socket_timeout {
273 sock_ref.set_read_timeout(Some(timeout))?;
274 sock_ref.set_write_timeout(Some(timeout))?;
275 }
276
277 Ok(())
278}
279
/// Integration tests; they expect a Redis server listening on 127.0.0.1:6379.
#[cfg(test)]
mod test {
    use futures_util::{
        sink::SinkExt,
        stream::{self, StreamExt},
    };

    use crate::resp;

    /// A single PING with an argument should be echoed back.
    #[tokio::test]
    async fn can_connect() {
        let mut connection = super::connect("127.0.0.1", 6379, None, None)
            .await
            .expect("Cannot connect");
        connection
            .send(resp_array!["PING", "TEST"])
            .await
            .expect("Cannot send PING");
        let values: Vec<_> = connection
            .take(1)
            .map(|r| r.expect("Unexpected invalid data"))
            .collect()
            .await;

        assert_eq!(values.len(), 1);
        assert_eq!(values[0], "TEST".into());
    }

    /// Pipelines 1002 commands and checks the reply to the last one.
    #[tokio::test]
    async fn complex_test() {
        let mut connection = super::connect("127.0.0.1", 6379, None, None)
            .await
            .expect("Cannot connect");
        let mut ops = Vec::with_capacity(1002);
        // Reset only the key under test (not the whole database) so the final
        // member count doesn't depend on state left by earlier runs.
        ops.push(resp_array!["DEL", "test_set"]);
        ops.extend((0..1000).map(|i| resp_array!["SADD", "test_set", format!("VALUE: {}", i)]));
        ops.push(resp_array!["SMEMBERS", "test_set"]);
        let mut ops_stream = stream::iter(ops).map(Ok);
        connection
            .send_all(&mut ops_stream)
            .await
            .expect("Cannot send");
        // Skip the DEL reply and the 1000 SADD replies; keep the SMEMBERS reply.
        let values: Vec<_> = connection
            .skip(1001)
            .take(1)
            .map(|r| r.expect("Unexpected invalid data"))
            .collect()
            .await;

        assert_eq!(values.len(), 1);
        let resp::RespValue::Array(members) = &values[0] else {
            panic!("Not an array");
        };
        assert_eq!(members.len(), 1000);
    }
}