1use std::collections::VecDeque;
12use std::future::Future;
13use std::marker::PhantomData;
14use std::mem;
15use std::pin::Pin;
16use std::sync::Arc;
17use std::task::{Context, Poll};
18use std::time::Duration;
19
20use futures_channel::{mpsc, oneshot};
21use futures_sink::Sink;
22use futures_util::{future::TryFutureExt, stream::StreamExt};
23
24use super::{
25 connect::{connect_with_auth, RespConnection},
26 ConnectionBuilder,
27};
28
29use crate::{
30 error,
31 reconnect::{reconnect, Reconnect},
32 resp,
33};
34
35enum SendStatus {
37 Ok,
39 End,
41 Full(resp::RespValue),
44}
45
/// Outcome of one `PairedConnectionInner::receive` step.
#[derive(Debug)]
enum ReceiveStatus {
    /// Sending has ended and no responders are waiting: the driver can stop.
    ReadyFinished,
    /// One reply was delivered; more may be readable immediately.
    ReadyMore,
    /// The connection stream returned `Pending`; nothing to read right now.
    NotReady,
}
56
57type CommandResult = Result<resp::RespValue, error::Error>;
58type Responder = oneshot::Sender<CommandResult>;
59type SendPayload = (resp::RespValue, Responder);
60
61struct PairedConnectionInner {
64 connection: RespConnection,
66 out_rx: mpsc::UnboundedReceiver<SendPayload>,
68 waiting: VecDeque<Responder>,
70
71 send_status: SendStatus,
73}
74
75impl PairedConnectionInner {
76 fn new(
77 con: RespConnection,
78 out_rx: mpsc::UnboundedReceiver<(resp::RespValue, Responder)>,
79 ) -> Self {
80 PairedConnectionInner {
81 connection: con,
82 out_rx,
83 waiting: VecDeque::new(),
84 send_status: SendStatus::Ok,
85 }
86 }
87
88 fn impl_start_send(
89 &mut self,
90 cx: &mut Context,
91 msg: resp::RespValue,
92 ) -> Result<bool, error::Error> {
93 match Pin::new(&mut self.connection).poll_ready(cx) {
94 Poll::Ready(Ok(())) => (),
95 Poll::Ready(Err(e)) => return Err(e.into()),
96 Poll::Pending => {
97 self.send_status = SendStatus::Full(msg);
98 return Ok(false);
99 }
100 }
101
102 self.send_status = SendStatus::Ok;
103 Pin::new(&mut self.connection).start_send(msg)?;
104 Ok(true)
105 }
106
107 fn poll_start_send(&mut self, cx: &mut Context) -> Result<bool, error::Error> {
108 let mut status = SendStatus::Ok;
109 mem::swap(&mut status, &mut self.send_status);
110
111 let message = match status {
112 SendStatus::End => {
113 self.send_status = SendStatus::End;
114 return Ok(false);
115 }
116 SendStatus::Full(msg) => msg,
117 SendStatus::Ok => match self.out_rx.poll_next_unpin(cx) {
118 Poll::Ready(Some((msg, tx))) => {
119 self.waiting.push_back(tx);
120 msg
121 }
122 Poll::Ready(None) => {
123 self.send_status = SendStatus::End;
124 return Ok(false);
125 }
126 Poll::Pending => return Ok(false),
127 },
128 };
129
130 self.impl_start_send(cx, message)
131 }
132
133 fn poll_complete(&mut self, cx: &mut Context) -> Result<(), error::Error> {
134 let _ = Pin::new(&mut self.connection).poll_flush(cx)?;
135 Ok(())
136 }
137
138 fn receive(&mut self, cx: &mut Context) -> Result<ReceiveStatus, error::Error> {
139 if let SendStatus::End = self.send_status {
140 if self.waiting.is_empty() {
141 return Ok(ReceiveStatus::ReadyFinished);
142 }
143 }
144 match self.connection.poll_next_unpin(cx) {
145 Poll::Ready(None) => Err(error::unexpected("Connection to Redis closed unexpectedly")),
146 Poll::Ready(Some(Ok(msg))) => {
147 let tx = match self.waiting.pop_front() {
148 Some(tx) => tx,
149 None => panic!("Received unexpected message: {:?}", msg),
150 };
151 let _ = tx.send(Ok(msg));
152 Ok(ReceiveStatus::ReadyMore)
153 }
154 Poll::Ready(Some(Err(e))) => Err(e),
155 Poll::Pending => Ok(ReceiveStatus::NotReady),
156 }
157 }
158
159 fn handle_error(&mut self, e: &error::Error) {
160 for tx in self.waiting.drain(..) {
161 let _ = tx.send(Err(error::internal(format!(
162 "Failed due to underlying failure: {}",
163 e
164 ))));
165 }
166
167 log::error!("Internal error in PairedConnectionInner: {}", e);
168 }
169}
170
171impl Future for PairedConnectionInner {
172 type Output = ();
173
174 #[allow(clippy::unit_arg)]
175 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
176 let mut_self = self.get_mut();
177 let mut sending = true;
179 while sending {
180 sending = match mut_self.poll_start_send(cx) {
181 Ok(sending) => sending,
182 Err(ref e) => return Poll::Ready(mut_self.handle_error(e)),
183 };
184 }
185
186 if let Err(ref e) = mut_self.poll_complete(cx) {
187 return Poll::Ready(mut_self.handle_error(e));
188 };
189
190 loop {
192 match mut_self.receive(cx) {
193 Ok(ReceiveStatus::NotReady) => return Poll::Pending,
194 Ok(ReceiveStatus::ReadyMore) => (),
195 Ok(ReceiveStatus::ReadyFinished) => return Poll::Ready(()),
196 Err(ref e) => return Poll::Ready(mut_self.handle_error(e)),
197 }
198 }
199 }
200}
201
202#[derive(Debug, Clone)]
204pub struct PairedConnection {
205 out_tx_c: Arc<Reconnect<SendPayload, mpsc::UnboundedSender<SendPayload>>>,
206}
207
208async fn inner_conn_fn(
209 host: String,
210 port: u16,
211 username: Option<Arc<str>>,
212 password: Option<Arc<str>>,
213 tls: bool,
214 socket_keepalive: Option<Duration>,
215 socket_timeout: Option<Duration>,
216) -> Result<mpsc::UnboundedSender<SendPayload>, error::Error> {
217 let username = username.as_ref().map(|u| u.as_ref());
218 let password = password.as_ref().map(|p| p.as_ref());
219 let connection = connect_with_auth(
220 &host,
221 port,
222 username,
223 password,
224 tls,
225 socket_keepalive,
226 socket_timeout,
227 )
228 .await?;
229 let (out_tx, out_rx) = mpsc::unbounded();
230 let paired_connection_inner = PairedConnectionInner::new(connection, out_rx);
231 tokio::spawn(paired_connection_inner);
232 Ok(out_tx)
233}
234
235impl ConnectionBuilder {
236 pub fn paired_connect(&self) -> impl Future<Output = Result<PairedConnection, error::Error>> {
237 let host = self.host.clone();
238 let port = self.port;
239 let username = self.username.clone();
240 let password = self.password.clone();
241
242 let work_fn = |con: &mpsc::UnboundedSender<SendPayload>, act| {
243 con.unbounded_send(act).map_err(|e| e.into())
244 };
245
246 #[cfg(feature = "tls")]
247 let tls = self.tls;
248 #[cfg(not(feature = "tls"))]
249 let tls = false;
250
251 let socket_keepalive = self.socket_keepalive;
252 let socket_timeout = self.socket_timeout;
253
254 let conn_fn = move || {
255 let con_f = inner_conn_fn(
256 host.clone(),
257 port,
258 username.clone(),
259 password.clone(),
260 tls,
261 socket_keepalive,
262 socket_timeout,
263 );
264 Box::pin(con_f) as Pin<Box<dyn Future<Output = Result<_, error::Error>> + Send + Sync>>
265 };
266
267 let reconnecting_con = reconnect(work_fn, conn_fn);
268 reconnecting_con.map_ok(|con| PairedConnection {
269 out_tx_c: Arc::new(con),
270 })
271 }
272}
273
274pub async fn paired_connect(
286 host: impl Into<String>,
287 port: u16,
288) -> Result<PairedConnection, error::Error> {
289 ConnectionBuilder::new(host, port)?.paired_connect().await
290}
291
292impl PairedConnection {
293 pub fn send<T>(&self, msg: resp::RespValue) -> SendFuture<T>
306 where
307 T: resp::FromResp + Unpin,
308 {
309 match &msg {
310 resp::RespValue::Array(_) => (),
311 _ => {
312 return SendFuture::new(error::internal("Command must be a RespValue::Array"));
313 }
314 }
315
316 let (tx, rx) = oneshot::channel();
317 match self.out_tx_c.do_work((msg, tx)) {
318 Ok(()) => SendFuture::new(rx),
319 Err(e) => SendFuture::new(e),
320 }
321 }
322
323 #[inline]
324 pub fn send_and_forget(&self, msg: resp::RespValue) {
325 let send_f = self.send::<resp::RespValue>(msg);
326 let forget_f = async {
327 if let Err(e) = send_f.await {
328 log::error!("Error in send_and_forget: {}", e);
329 }
330 };
331 tokio::spawn(forget_f);
332 }
333}
334
335#[derive(Debug)]
336enum SendFutureType {
337 Wait(oneshot::Receiver<Result<resp::RespValue, error::Error>>),
338 Error(Option<error::Error>),
339}
340
341impl From<oneshot::Receiver<Result<resp::RespValue, error::Error>>> for SendFutureType {
342 fn from(from: oneshot::Receiver<Result<resp::RespValue, error::Error>>) -> Self {
343 Self::Wait(from)
344 }
345}
346
347impl From<error::Error> for SendFutureType {
348 fn from(e: error::Error) -> Self {
349 Self::Error(Some(e))
350 }
351}
352
353#[derive(Debug)]
354pub struct SendFuture<T> {
355 send_type: SendFutureType,
356 _phantom: PhantomData<T>,
357}
358
359impl<T> SendFuture<T> {
360 #[inline]
361 fn new(send_type: impl Into<SendFutureType>) -> Self {
362 Self {
363 send_type: send_type.into(),
364 _phantom: Default::default(),
365 }
366 }
367}
368
369impl<T> Future for SendFuture<T>
370where
371 T: resp::FromResp + Unpin,
372{
373 type Output = Result<T, error::Error>;
374
375 #[inline]
376 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
377 match self.get_mut().send_type {
378 SendFutureType::Error(ref mut e) => match e.take() {
379 Some(e) => Poll::Ready(Err(e)),
380 None => panic!("Future polled several times after completion"),
381 },
382 SendFutureType::Wait(ref mut rx) => match Pin::new(rx).poll(cx) {
383 Poll::Ready(Ok(Ok(v))) => Poll::Ready(T::from_resp(v)),
384 Poll::Ready(Ok(Err(e))) => Poll::Ready(Err(e)),
385 Poll::Ready(Err(_)) => Poll::Ready(Err(error::internal(
386 "Connection closed before response received",
387 ))),
388 Poll::Pending => Poll::Pending,
389 },
390 }
391 }
392}
393
#[cfg(test)]
mod test {
    use super::ConnectionBuilder;

    // These tests expect a Redis server listening on 127.0.0.1:6379.

    #[tokio::test]
    async fn can_paired_connect() {
        let connection = super::paired_connect("127.0.0.1", 6379)
            .await
            .expect("Cannot establish connection");

        // All three commands are queued before any reply is awaited.
        let ping = connection.send(resp_array!["PING", "TEST"]);
        connection.send_and_forget(resp_array!["SET", "X", "123"]);
        let fetch = connection.send(resp_array!["GET", "X"]);

        let echoed: String = ping.await.expect("Cannot read result of first thing");
        let stored: String = fetch.await.expect("Cannot read result of second thing");

        assert_eq!(echoed, "TEST");
        assert_eq!(stored, "123");
    }

    #[tokio::test]
    async fn complex_paired_connect() {
        let connection = super::paired_connect("127.0.0.1", 6379)
            .await
            .expect("Cannot establish connection");

        // The second command depends on the first command's reply.
        let counter: String = connection
            .send(resp_array!["INCR", "CTR"])
            .await
            .expect("Cannot increment counter");
        let reply: String = connection
            .send(resp_array!["SET", "LASTCTR", counter])
            .await
            .expect("Cannot set value");

        assert_eq!(reply, "OK");
    }

    #[tokio::test]
    async fn sending_a_lot_of_data_test() {
        let connection = super::paired_connect("127.0.0.1", 6379)
            .await
            .expect("Cannot connect to Redis");
        let mut gets = Vec::with_capacity(1000);
        for n in 0..1000 {
            let key = format!("X_{}", n);
            connection.send_and_forget(resp_array!["SET", &key, n.to_string()]);
            gets.push(connection.send(resp_array!["GET", key]));
        }
        // Only the final GET is awaited; this relies on replies coming back in
        // request order.
        let last_get = gets.remove(999);
        let value: String = last_get.await.expect("Cannot wait for result");
        assert_eq!(value, "999");
    }

    #[tokio::test]
    async fn test_builder() {
        let mut builder =
            ConnectionBuilder::new("127.0.0.1", 6379).expect("Cannot construct builder...");
        builder.password("password");
        builder.username(String::from("username"));
        // Connecting with these made-up credentials is expected to fail.
        let outcome = builder.paired_connect().await;
        assert!(outcome.is_err());
    }
}