// File: redis_cacher/pool.rs

1use async_trait::async_trait;
2use crossbeam::queue::ArrayQueue;
3use futures::future::BoxFuture;
4use std::{
5    fmt::Debug,
6    io,
7    ops::{Deref, DerefMut},
8    sync::{
9        atomic::{self, AtomicUsize},
10        Arc, Weak,
11    },
12    time::Duration,
13};
14use tokio::{
15    sync::{OwnedSemaphorePermit, Semaphore},
16    time::sleep,
17};
18
19#[async_trait]
20pub trait ConnectionManager {
21    /// Any information needed to connect to a database
22    /// e.g. IP, port, username, password
23    type Address: Clone + Send + Sync;
24
25    /// A connection to the database
26    type Connection: Sized + Send + Sync;
27
28    /// All operations may return this error type
29    type Error: From<io::Error> + Send;
30
31    /// Connect to a given address
32    async fn connect(address: &Self::Address) -> Result<Self::Connection, Self::Error>;
33
34    /// Check if the connection is in a good state.
35    /// If None, the status is unknown and the database should be pinged.
36    fn check_alive(connection: &Self::Connection) -> Option<bool>;
37
38    /// Ping the database
39    async fn ping(connection: &mut Self::Connection) -> Result<(), Self::Error>;
40
41    /// Reset the connection to a fresh state for reuse
42    /// If this doesn't perform a network operation, returns None
43    fn reset_connection(
44        _connection: &mut Self::Connection,
45    ) -> Option<BoxFuture<'_, Result<(), Self::Error>>> {
46        None
47    }
48}
49
50#[derive(Debug, Clone, PartialEq, Eq)]
51pub struct ConfigBuilder<C: ConnectionManager> {
52    pub address: Option<C::Address>,
53    pub min_size: Option<usize>,
54    pub max_size: Option<usize>,
55}
56
57// #[derive(Default)] doesn't work, I think b/c of the type parameter
58impl<C: ConnectionManager> Default for ConfigBuilder<C> {
59    fn default() -> Self {
60        ConfigBuilder {
61            address: None,
62            min_size: None,
63            max_size: None,
64        }
65    }
66}
67
68impl<C: ConnectionManager> ConfigBuilder<C> {
69    pub fn new() -> ConfigBuilder<C> {
70        Self::default()
71    }
72
73    pub fn address(&mut self, val: C::Address) -> &mut Self {
74        self.address = Some(val);
75        self
76    }
77
78    pub fn min_size(&mut self, val: Option<usize>) -> &mut Self {
79        self.min_size = val;
80        self
81    }
82
83    pub fn max_size(&mut self, val: Option<usize>) -> &mut Self {
84        self.max_size = val;
85        self
86    }
87
88    pub fn build(&mut self) -> Config<C> {
89        Config {
90            address: self
91                .address
92                .take()
93                .expect("ConfigBuilder address not specified"),
94            min_size: self.min_size.take().unwrap_or(0),
95            max_size: self.max_size.take().unwrap_or(100),
96        }
97    }
98}
99
100#[derive(Debug, Clone, PartialEq, Eq)]
101pub struct Config<C: ConnectionManager> {
102    pub address: C::Address,
103    pub min_size: usize,
104    pub max_size: usize,
105}
106
107struct PoolShared<C: ConnectionManager> {
108    config: Config<C>,
109    idle_queue: ArrayQueue<C::Connection>,
110    /// Approximate at any given time because it's concurrent
111    idle_queue_len: AtomicUsize,
112    permits: Arc<Semaphore>,
113}
114
115pub struct Pool<C: ConnectionManager>(Arc<PoolShared<C>>);
116
117impl<C: ConnectionManager> Clone for Pool<C> {
118    fn clone(&self) -> Pool<C> {
119        Pool(self.0.clone())
120    }
121}
122
123pub struct PoolConnection<C: ConnectionManager + 'static> {
124    connection: Option<(C::Connection, OwnedSemaphorePermit)>,
125    pool: Pool<C>,
126}
127
128impl<C: ConnectionManager> Deref for PoolConnection<C> {
129    type Target = C::Connection;
130
131    fn deref(&self) -> &C::Connection {
132        &self
133            .connection
134            .as_ref()
135            .expect("PoolConnection doesn't have an underlying connection")
136            .0
137    }
138}
139
140impl<C: ConnectionManager> DerefMut for PoolConnection<C> {
141    fn deref_mut(&mut self) -> &mut C::Connection {
142        &mut self
143            .connection
144            .as_mut()
145            .expect("PoolConnection doesn't have an underlying connection")
146            .0
147    }
148}
149
150impl<C: ConnectionManager> Drop for PoolConnection<C> {
151    fn drop(&mut self) {
152        let connection = match self.connection.take() {
153            Some(c) => c,
154            None => return,
155        };
156        let pool = self.pool.clone();
157        tokio::spawn(async move {
158            let (mut conn, _permit) = connection;
159            let is_alive = match C::reset_connection(&mut conn) {
160                Some(fut) => Some(fut.await.is_ok()),
161                None => None,
162            };
163            let is_alive = match is_alive {
164                Some(x) => x,
165                None => C::check_alive(&conn).unwrap_or(true),
166            };
167            if is_alive {
168                // Ignore if we can't recycle because the queue is full?
169                // TODO: should log
170                if pool.0.idle_queue.push(conn).is_ok() {
171                    pool.0
172                        .idle_queue_len
173                        .fetch_add(1, atomic::Ordering::Relaxed);
174                }
175            }
176        });
177    }
178}
179
180impl<C: ConnectionManager> Pool<C>
181where
182    C: 'static,
183    C::Address: Debug,
184    C::Error: Debug,
185{
186    /// Creates a new pool and fills it to the minimum idle size.
187    pub async fn new(config: Config<C>) -> Self {
188        let idle_queue = ArrayQueue::new(config.max_size);
189        let mut init_len = 0;
190        assert!(config.max_size >= config.min_size);
191        let mut some_failed = false;
192        for _ in 0..config.min_size {
193            match C::connect(&config.address).await {
194                Ok(conn) => {
195                    idle_queue
196                        .push(conn)
197                        .ok()
198                        .expect("Pool queue must have the capacity to allocate idle connections");
199                    init_len += 1;
200                }
201                Err(err) => {
202                    if !some_failed {
203                        some_failed = true;
204                        tracing::warn!(
205                            "During pool initial connections to {:?} {:?}",
206                            config.address,
207                            err,
208                        );
209                    }
210                }
211            }
212        }
213        let permits = Arc::new(Semaphore::new(config.max_size));
214        let this = Pool(Arc::new(PoolShared {
215            idle_queue,
216            idle_queue_len: AtomicUsize::new(init_len),
217            config,
218            permits,
219        }));
220        tokio::spawn(Self::keepalive(Arc::downgrade(&this.0)));
221        this
222    }
223
224    /// Keeps idle pool connections alive by pinging them regularly.
225    async fn keepalive(weak: Weak<PoolShared<C>>) {
226        loop {
227            let mut idle_count;
228            {
229                let this = match weak.upgrade() {
230                    Some(arc) => Pool(arc),
231                    None => return,
232                };
233
234                if let Some(mut conn) = this.try_get_idle_connection().await {
235                    if let Err(err) = C::ping(&mut conn).await {
236                        tracing::warn!("Failed to ping DB connection: {:?}", err);
237                    }
238                }
239                idle_count = this.0.idle_queue_len.load(atomic::Ordering::Relaxed);
240            }
241            if idle_count == 0 {
242                idle_count = 1;
243            }
244            let delay = Duration::from_secs(60) / (idle_count as u32);
245            sleep(delay).await;
246        }
247    }
248
249    /// Get an idle connection from the queue
250    fn idle_connection(&self) -> Option<C::Connection> {
251        let connection = self.0.idle_queue.pop()?;
252        self.0
253            .idle_queue_len
254            .fetch_sub(1, atomic::Ordering::Relaxed);
255        Some(connection)
256    }
257
258    /// Attempt to get an idle connection from the pool, along with it's permit.
259    /// This method will not replenish the number of idle connections.
260    async fn try_get_idle_connection(&self) -> Option<PoolConnection<C>> {
261        let permit = self.0.permits.clone().try_acquire_owned().ok()?;
262        Some(PoolConnection {
263            connection: Some((self.idle_connection()?, permit)),
264            pool: (*self).clone(),
265        })
266    }
267
268    /// Get a connection from the pool or create a new one. If the pool is at capacity
269    /// this will wait until a connection is available.
270    pub async fn get_connection(&self) -> Result<PoolConnection<C>, C::Error> {
271        let permit = self
272            .0
273            .permits
274            .clone()
275            .acquire_owned()
276            .await
277            .map_err(|_| io::Error::new(io::ErrorKind::Other, "Connection pool closed"))?;
278        self.get_connection_internal(permit).await
279    }
280
281    /// Attempt to get a connection from the pool or create a new one. If the pool is at
282    /// capacity this will return an error without waiting.
283    pub async fn try_get_connection(&self) -> Result<PoolConnection<C>, C::Error> {
284        let permit = self.0.permits.clone().try_acquire_owned().map_err(|_| {
285            io::Error::new(io::ErrorKind::Other, "Connection pool size reached maximum")
286        })?;
287        self.get_connection_internal(permit).await
288    }
289
290    /// Attempt to get a connection from the pool or create a new one, waiting up to `timeout`
291    /// before returning an error if the pool is at capacity.
292    pub async fn get_connection_timeout(
293        &self,
294        timeout: Duration,
295    ) -> Result<PoolConnection<C>, C::Error> {
296        let permit = tokio::time::timeout(timeout, self.0.permits.clone().acquire_owned())
297            .await
298            .map_err(|_| {
299                io::Error::new(io::ErrorKind::Other, "Connection pool size reached maximum")
300            })?
301            .map_err(|_| io::Error::new(io::ErrorKind::Other, "Connection pool closed"))?;
302        self.get_connection_internal(permit).await
303    }
304
305    fn create_connection(
306        &self,
307        permit: OwnedSemaphorePermit,
308        conn: C::Connection,
309    ) -> PoolConnection<C> {
310        PoolConnection {
311            connection: Some((conn, permit)),
312            pool: (*self).clone(),
313        }
314    }
315
316    /// Returns a connection from the pool. If there's no idle connection, creates a new one.
317    async fn get_connection_internal(
318        &self,
319        mut permit: OwnedSemaphorePermit,
320    ) -> Result<PoolConnection<C>, C::Error> {
321        loop {
322            match self.idle_connection() {
323                Some(c) => {
324                    // Wrap the connection in a PoolConnection to ensure it's returned to the queue
325                    // if this future is canceled
326                    let mut conn = self.create_connection(permit, c);
327                    let alive = match C::check_alive(&conn) {
328                        Some(alive) => alive,
329                        None => C::ping(&mut conn).await.is_ok(),
330                    };
331                    if alive {
332                        break Ok(conn);
333                    } else {
334                        // Extract the bad connection from the PoolConnection so that it's not
335                        // returned to the queue
336                        let c = conn
337                            .connection
338                            .take()
339                            .expect("PoolConnection doesn't have an underlying connection");
340                        permit = c.1;
341                    }
342                }
343                None => {
344                    let conn = C::connect(&self.0.config.address).await?;
345                    break Ok(self.create_connection(permit, conn));
346                }
347            }
348        }
349    }
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Error;
    use std::time::Duration;
    use tokio::time::timeout;

    /// Manager whose connections are always healthy and never touch the network.
    struct TestConnection;

    #[async_trait]
    impl ConnectionManager for TestConnection {
        type Address = ();
        type Connection = Self;
        type Error = Error;

        async fn connect(_address: &Self::Address) -> Result<Self::Connection, Self::Error> {
            Ok(TestConnection)
        }

        fn check_alive(_connection: &Self::Connection) -> Option<bool> {
            Some(true)
        }

        async fn ping(_connection: &mut Self::Connection) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Pool capacity used by the test below.
    const MAX_SIZE: usize = 3;

    #[tokio::test]
    async fn test_connection_pool() {
        let config = ConfigBuilder::<TestConnection>::new()
            .address(())
            .max_size(Some(MAX_SIZE))
            .build();
        let pool = Pool::<TestConnection>::new(config).await;

        // Check out every connection the pool allows.
        let mut checked_out = Vec::with_capacity(MAX_SIZE);
        while checked_out.len() < MAX_SIZE {
            let conn = pool
                .try_get_connection()
                .await
                .expect("Unable to get connection");
            checked_out.push(conn);
        }

        // With the pool exhausted, a non-waiting request must fail.
        assert!(pool.try_get_connection().await.is_err());

        // Releasing one connection frees its slot once the background return task has run.
        drop(checked_out.pop());

        // get_connection waits for a free slot, which should appear almost immediately.
        let outcome = timeout(Duration::from_millis(1), pool.get_connection()).await;
        outcome
            .expect("get_connection timed out")
            .expect("Unable to get connection");
    }
}
411}