1use async_trait::async_trait;
5use futures::{
6 channel::oneshot,
7 future::{self, BoxFuture},
8};
9use redis::aio::{ConnectionLike, MultiplexedConnection};
10use redis::{Client, Cmd, IntoConnectionInfo, Pipeline, RedisError, RedisFuture, Value};
11use tracing::{debug, debug_span, warn, Instrument};
12
13use std::{
14 convert::{AsMut, AsRef},
15 ops::{Deref, DerefMut},
16};
17
18type Result<T> = std::result::Result<T, RedisError>;
19
20#[derive(Debug)]
22pub struct RedisConnectionManager {
23 client: redis::Client,
24}
25
26impl RedisConnectionManager {
27 pub fn new(params: impl IntoConnectionInfo) -> Result<RedisConnectionManager> {
29 Ok(RedisConnectionManager {
30 client: Client::open(params)?,
31 })
32 }
33}
34
35pub struct AsyncConnection {
36 pub conn: MultiplexedConnection,
37 done_rx: oneshot::Receiver<()>,
38 drop_tx: Option<oneshot::Sender<()>>,
39 broken: bool,
40}
41
42impl Drop for AsyncConnection {
46 fn drop(&mut self) {
47 if let Some(drop_tx) = self.drop_tx.take() {
48 let _ = drop_tx.send(());
49 }
50 }
51}
52
53impl Deref for AsyncConnection {
54 type Target = MultiplexedConnection;
55
56 fn deref(&self) -> &Self::Target {
57 &self.conn
58 }
59}
60
61impl DerefMut for AsyncConnection {
62 fn deref_mut(&mut self) -> &mut Self::Target {
63 &mut self.conn
64 }
65}
66
67impl AsMut<MultiplexedConnection> for AsyncConnection {
68 fn as_mut(&mut self) -> &mut MultiplexedConnection {
69 &mut self.conn
70 }
71}
72
73impl AsRef<MultiplexedConnection> for AsyncConnection {
74 fn as_ref(&self) -> &MultiplexedConnection {
75 &self.conn
76 }
77}
78
79impl ConnectionLike for AsyncConnection {
80 fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
81 self.conn.req_packed_command(cmd)
82 }
83
84 fn req_packed_commands<'a>(
85 &'a mut self,
86 cmd: &'a Pipeline,
87 offset: usize,
88 count: usize,
89 ) -> RedisFuture<'a, Vec<Value>> {
90 self.conn.req_packed_commands(cmd, offset, count)
91 }
92
93 fn get_db(&self) -> i64 {
94 self.conn.get_db()
95 }
96}
97
98pub async fn async_transaction<C, K, T, F, Args>(
125 con: &mut C,
126 keys: &[K],
127 args: &mut Args,
128 func: F,
129) -> redis::RedisResult<T>
130where
131 C: ConnectionLike,
132 K: redis::ToRedisArgs,
133 F: for<'a> FnMut(
134 &'a mut C,
135 &'a mut Pipeline,
136 &'a mut Args,
137 ) -> BoxFuture<'a, redis::RedisResult<Option<T>>>,
138{
139 let mut func = func;
140 loop {
141 redis::cmd("WATCH")
142 .arg(keys)
143 .query_async::<_, ()>(&mut *con)
144 .await?;
145
146 let mut p = redis::pipe();
147 let response: Option<T> = func(con, p.atomic(), args).await?;
148 match response {
149 None => {
150 continue;
151 }
152 Some(response) => {
153 redis::cmd("UNWATCH")
156 .query_async::<_, ()>(&mut *con)
157 .await?;
158 return Ok(response);
159 }
160 }
161 }
162}
163
164#[async_trait]
165impl l337::ManageConnection for RedisConnectionManager {
166 type Connection = AsyncConnection;
167 type Error = RedisError;
168
169 async fn connect(&self) -> std::result::Result<Self::Connection, l337::Error<Self::Error>> {
170 let (connection, future) = self
171 .client
172 .create_multiplexed_tokio_connection()
173 .instrument(debug_span!("connect: open new redis connection"))
174 .await
175 .map_err(l337::Error::External)?;
176
177 let (done_tx, done_rx) = oneshot::channel();
178 let (drop_tx, drop_rx) = oneshot::channel();
179
180 tokio::spawn(async move {
181 debug!("connect: spawn future backing redis connection");
182 futures::pin_mut!(future, drop_rx);
183
184 future::select(future, drop_rx).await;
185 debug!("Future backing redis connection ended, future calls to this redis connection will fail");
186
187 let _ = done_tx.send(());
191 });
192
193 debug!("connect: redis connection established");
194 Ok(AsyncConnection {
195 conn: connection,
196 broken: false,
197 done_rx,
198 drop_tx: Some(drop_tx),
199 })
200 }
201
202 async fn is_valid(
203 &self,
204 conn: &mut Self::Connection,
205 ) -> std::result::Result<(), l337::Error<Self::Error>> {
206 let result = redis::cmd("PING")
207 .query_async::<_, ()>(conn)
208 .await
209 .map_err(l337::Error::External);
210
211 if result.is_err() {
212 conn.broken = true;
213 }
214
215 result
216 }
217
218 fn has_broken(&self, conn: &mut Self::Connection) -> bool {
219 if conn.broken {
220 return true;
221 }
222
223 match conn.done_rx.try_recv() {
227 Ok(Some(())) => {
230 conn.broken = true;
231 true
232 }
233 Ok(None) => false,
236 Err(error) => {
239 warn!(%error, "cannot receive from connection future");
240 conn.broken = true;
241 true
242 }
243 }
244 }
245
246 fn timed_out(&self) -> l337::Error<Self::Error> {
247 unimplemented!()
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use l337::{Config, Pool};

    /// Redis URL used when the `REDIS_URL` environment variable is not set.
    const DEFAULT_REDIS_URL: &str = "redis://redis:6379/0";

    /// Integration test: checks out a pooled connection and round-trips a `PING`.
    ///
    /// Requires a reachable Redis server; set `REDIS_URL` to point at one other than
    /// the default.
    #[tokio::test]
    async fn it_works() {
        let url = std::env::var("REDIS_URL").unwrap_or_else(|_| DEFAULT_REDIS_URL.to_string());
        let mngr = RedisConnectionManager::new(url.as_str()).unwrap();

        let config: Config = Default::default();

        let pool = Pool::new(mngr, config).await.unwrap();
        let mut conn = pool.connection().await.unwrap();

        let reply = redis::cmd("PING")
            .query_async::<_, String>(&mut *conn)
            .await
            .unwrap();
        assert_eq!(reply, "PONG");
    }
}