1use std::time::Duration;
12
13use futures_util::{SinkExt, StreamExt};
14use pin_project::pin_project;
15use tokio::{
16 io::{AsyncRead, AsyncWrite},
17 net::TcpStream,
18};
19use tokio_util::codec::{Decoder, Framed};
20
21use crate::{
22 error,
23 resp::{self, RespCodec},
24};
25
26#[pin_project(project = RespConnectionInnerProj)]
27pub enum RespConnectionInner {
28 #[cfg(feature = "with-rustls")]
29 Tls {
30 #[pin]
31 stream: tokio_rustls::client::TlsStream<TcpStream>,
32 },
33 #[cfg(feature = "with-native-tls")]
34 Tls {
35 #[pin]
36 stream: tokio_native_tls::TlsStream<TcpStream>,
37 },
38 Plain {
39 #[pin]
40 stream: TcpStream,
41 },
42}
43
44impl AsyncWrite for RespConnectionInner {
45 fn poll_write(
46 self: std::pin::Pin<&mut Self>,
47 cx: &mut std::task::Context<'_>,
48 buf: &[u8],
49 ) -> std::task::Poll<Result<usize, std::io::Error>> {
50 let this = self.project();
51 match this {
52 #[cfg(feature = "tls")]
53 RespConnectionInnerProj::Tls { stream } => stream.poll_write(cx, buf),
54 RespConnectionInnerProj::Plain { stream } => stream.poll_write(cx, buf),
55 }
56 }
57
58 fn poll_flush(
59 self: std::pin::Pin<&mut Self>,
60 cx: &mut std::task::Context<'_>,
61 ) -> std::task::Poll<Result<(), std::io::Error>> {
62 let this = self.project();
63 match this {
64 #[cfg(feature = "tls")]
65 RespConnectionInnerProj::Tls { stream } => stream.poll_flush(cx),
66 RespConnectionInnerProj::Plain { stream } => stream.poll_flush(cx),
67 }
68 }
69
70 fn poll_shutdown(
71 self: std::pin::Pin<&mut Self>,
72 cx: &mut std::task::Context<'_>,
73 ) -> std::task::Poll<Result<(), std::io::Error>> {
74 let this = self.project();
75 match this {
76 #[cfg(feature = "tls")]
77 RespConnectionInnerProj::Tls { stream } => stream.poll_shutdown(cx),
78 RespConnectionInnerProj::Plain { stream } => stream.poll_shutdown(cx),
79 }
80 }
81}
82
83impl AsyncRead for RespConnectionInner {
84 fn poll_read(
85 self: std::pin::Pin<&mut Self>,
86 cx: &mut std::task::Context<'_>,
87 buf: &mut tokio::io::ReadBuf<'_>,
88 ) -> std::task::Poll<std::io::Result<()>> {
89 let this = self.project();
90 match this {
91 #[cfg(feature = "tls")]
92 RespConnectionInnerProj::Tls { stream } => stream.poll_read(cx, buf),
93 RespConnectionInnerProj::Plain { stream } => stream.poll_read(cx, buf),
94 }
95 }
96}
97
98pub type RespConnection = Framed<RespConnectionInner, RespCodec>;
99
100pub async fn connect(
116 host: &str,
117 port: u16,
118 socket_keepalive: Option<Duration>,
119 socket_timeout: Option<Duration>,
120) -> Result<RespConnection, error::Error> {
121 let tcp_stream = TcpStream::connect((host, port)).await?;
122 apply_keepalive_and_timeouts(&tcp_stream, socket_keepalive, socket_timeout)?;
123 Ok(RespCodec.framed(RespConnectionInner::Plain { stream: tcp_stream }))
124}
125
/// Opens a TLS-encrypted RESP connection to `host:port` using `rustls`.
///
/// Server certificates are verified against the `webpki_roots` bundle, and no
/// client certificate is presented. Every address `host` resolves to is tried
/// in turn until one accepts the TCP connection, matching the plaintext `connect`.
///
/// # Errors
///
/// - `Error::InvalidDnsName` if `host` cannot be used as a TLS server name
///   (checked before any network activity).
/// - `Error::Connection(ConnectionReason::ConnectionFailed)` if `host` resolves
///   to no addresses.
/// - Any I/O, socket-configuration or TLS handshake error, propagated via `?`.
#[cfg(feature = "with-rustls")]
pub async fn connect_tls(
    host: &str,
    port: u16,
    socket_keepalive: Option<Duration>,
    socket_timeout: Option<Duration>,
) -> Result<RespConnection, error::Error> {
    use std::sync::Arc;
    use tokio_rustls::{
        rustls::{ClientConfig, RootCertStore},
        TlsConnector,
    };

    // Reject an unusable server name up front instead of after a TCP
    // connection has already been opened (and would then be dropped).
    let server_name = String::from(host)
        .try_into()
        .map_err(|_err| error::Error::InvalidDnsName)?;

    let mut root_store = RootCertStore::empty();
    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    let config = ClientConfig::builder()
        .with_root_certificates(root_store)
        .with_no_client_auth();
    let connector = TlsConnector::from(Arc::new(config));

    let addrs: Vec<_> = tokio::net::lookup_host((host, port)).await?.collect();
    if addrs.is_empty() {
        return Err(error::Error::Connection(
            error::ConnectionReason::ConnectionFailed,
        ));
    }
    // Try each resolved address in turn, as the plaintext `connect` does, rather
    // than failing when only the first one (e.g. an IPv6 address) is unreachable.
    let tcp_stream = TcpStream::connect(addrs.as_slice()).await?;
    apply_keepalive_and_timeouts(&tcp_stream, socket_keepalive, socket_timeout)?;

    let stream = connector.connect(server_name, tcp_stream).await?;
    Ok(RespCodec.framed(RespConnectionInner::Tls { stream }))
}
165
/// Opens a TLS-encrypted RESP connection to `host:port` using `native-tls`.
///
/// Every address `host` resolves to is tried in turn until one accepts the TCP
/// connection, matching the plaintext `connect`. `host` is also the name passed
/// to the TLS handshake.
///
/// # Errors
///
/// - `Error::Connection(ConnectionReason::ConnectionFailed)` if `host` resolves
///   to no addresses.
/// - Any error from building the TLS connector, connecting, configuring the
///   socket, or the TLS handshake, propagated via `?`.
#[cfg(feature = "with-native-tls")]
pub async fn connect_tls(
    host: &str,
    port: u16,
    socket_keepalive: Option<Duration>,
    socket_timeout: Option<Duration>,
) -> Result<RespConnection, error::Error> {
    let connector = native_tls::TlsConnector::builder().build()?;
    let connector = tokio_native_tls::TlsConnector::from(connector);

    let addrs: Vec<_> = tokio::net::lookup_host((host, port)).await?.collect();
    if addrs.is_empty() {
        return Err(error::Error::Connection(
            error::ConnectionReason::ConnectionFailed,
        ));
    }
    // Try each resolved address in turn, as the plaintext `connect` does, rather
    // than failing when only the first one is unreachable.
    let tcp_stream = TcpStream::connect(addrs.as_slice()).await?;
    apply_keepalive_and_timeouts(&tcp_stream, socket_keepalive, socket_timeout)?;
    let stream = connector.connect(host, tcp_stream).await?;

    Ok(RespCodec.framed(RespConnectionInner::Tls { stream }))
}
189
190pub async fn connect_with_auth(
191 host: &str,
192 port: u16,
193 username: Option<&str>,
194 password: Option<&str>,
195 #[allow(unused_variables)] tls: bool,
196 socket_keepalive: Option<Duration>,
197 socket_timeout: Option<Duration>,
198) -> Result<RespConnection, error::Error> {
199 #[cfg(feature = "tls")]
200 let mut connection = if tls {
201 connect_tls(host, port, socket_keepalive, socket_timeout).await?
202 } else {
203 connect(host, port, socket_keepalive, socket_timeout).await?
204 };
205 #[cfg(not(feature = "tls"))]
206 let mut connection = connect(host, port, socket_keepalive, socket_timeout).await?;
207
208 if let Some(password) = password {
209 let mut auth = resp_array!["AUTH"];
210
211 if let Some(username) = username {
212 auth.push(username);
213 }
214
215 auth.push(password);
216
217 connection.send(auth).await?;
218 match connection.next().await {
219 Some(Ok(value)) => match resp::FromResp::from_resp(value) {
220 Ok(()) => (),
221 Err(e) => return Err(e),
222 },
223 Some(Err(e)) => return Err(e),
224 None => {
225 return Err(error::internal(
226 "Connection closed before authentication complete",
227 ))
228 }
229 }
230 }
231
232 Ok(connection)
233}
234
235fn apply_keepalive_and_timeouts(
237 stream: &TcpStream,
238 socket_keepalive: Option<Duration>,
239 socket_timeout: Option<Duration>,
240) -> Result<(), error::Error> {
241 let sock_ref = socket2::SockRef::from(stream);
242
243 if let Some(interval) = socket_keepalive {
244 let keep_alive = socket2::TcpKeepalive::new()
245 .with_time(interval)
246 .with_interval(interval);
247 #[cfg(any(
249 target_os = "android",
250 target_os = "dragonfly",
251 target_os = "freebsd",
252 target_os = "fuchsia",
253 target_os = "illumos",
254 target_os = "ios",
255 target_os = "linux",
256 target_os = "macos",
257 target_os = "netbsd",
258 target_os = "tvos",
259 target_os = "watchos",
260 ))]
261 let keep_alive = keep_alive.with_retries(1);
262 sock_ref.set_tcp_keepalive(&keep_alive)?;
263 }
264
265 if let Some(timeout) = socket_timeout {
266 sock_ref.set_read_timeout(Some(timeout))?;
267 sock_ref.set_write_timeout(Some(timeout))?;
268 }
269
270 Ok(())
271}
272
#[cfg(test)]
mod test {
    // These tests talk to a Redis server expected at 127.0.0.1:6379.
    use futures_util::{
        sink::SinkExt,
        stream::{self, StreamExt},
    };

    use crate::resp;

    #[tokio::test]
    async fn can_connect() {
        let mut connection = super::connect("127.0.0.1", 6379, None, None)
            .await
            .expect("Cannot connect");
        connection
            .send(resp_array!["PING", "TEST"])
            .await
            .expect("Cannot send PING");
        let values: Vec<_> = connection
            .take(1)
            .map(|r| r.expect("Unexpected invalid data"))
            .collect()
            .await;

        assert_eq!(values.len(), 1);
        assert_eq!(values[0], "TEST".into());
    }

    #[tokio::test]
    async fn complex_test() {
        let mut connection = super::connect("127.0.0.1", 6379, None, None)
            .await
            .expect("Cannot connect");
        let mut ops = Vec::new();
        // Start from an empty set so SMEMBERS returns exactly the members added
        // below. Redis has no `FLUSH` command: sending it only produced an error
        // reply and left any earlier contents of `test_set` in place.
        ops.push(resp_array!["DEL", "test_set"]);
        ops.extend((0..1000).map(|i| resp_array!["SADD", "test_set", format!("VALUE: {}", i)]));
        ops.push(resp_array!["SMEMBERS", "test_set"]);
        let mut ops_stream = stream::iter(ops).map(Ok);
        connection
            .send_all(&mut ops_stream)
            .await
            .expect("Cannot send");
        // Skip the DEL reply plus the 1000 SADD replies; keep the SMEMBERS reply.
        let values: Vec<_> = connection
            .skip(1001)
            .take(1)
            .map(|r| r.expect("Unexpected invalid data"))
            .collect()
            .await;

        assert_eq!(values.len(), 1);
        let values = match &values[0] {
            resp::RespValue::Array(values) => values.clone(),
            _ => panic!("Not an array"),
        };
        assert_eq!(values.len(), 1000);
    }
}