1use async_trait::async_trait;
2use crossbeam::queue::ArrayQueue;
3use futures::future::BoxFuture;
4use std::{
5 fmt::Debug,
6 io,
7 ops::{Deref, DerefMut},
8 sync::{
9 atomic::{self, AtomicUsize},
10 Arc, Weak,
11 },
12 time::Duration,
13};
14use tokio::{
15 sync::{OwnedSemaphorePermit, Semaphore},
16 time::sleep,
17};
18
19#[async_trait]
20pub trait ConnectionManager {
21 type Address: Clone + Send + Sync;
24
25 type Connection: Sized + Send + Sync;
27
28 type Error: From<io::Error> + Send;
30
31 async fn connect(address: &Self::Address) -> Result<Self::Connection, Self::Error>;
33
34 fn check_alive(connection: &Self::Connection) -> Option<bool>;
37
38 async fn ping(connection: &mut Self::Connection) -> Result<(), Self::Error>;
40
41 fn reset_connection(
44 _connection: &mut Self::Connection,
45 ) -> Option<BoxFuture<'_, Result<(), Self::Error>>> {
46 None
47 }
48}
49
50#[derive(Debug, Clone, PartialEq, Eq)]
51pub struct ConfigBuilder<C: ConnectionManager> {
52 pub address: Option<C::Address>,
53 pub min_size: Option<usize>,
54 pub max_size: Option<usize>,
55}
56
57impl<C: ConnectionManager> Default for ConfigBuilder<C> {
59 fn default() -> Self {
60 ConfigBuilder {
61 address: None,
62 min_size: None,
63 max_size: None,
64 }
65 }
66}
67
68impl<C: ConnectionManager> ConfigBuilder<C> {
69 pub fn new() -> ConfigBuilder<C> {
70 Self::default()
71 }
72
73 pub fn address(&mut self, val: C::Address) -> &mut Self {
74 self.address = Some(val);
75 self
76 }
77
78 pub fn min_size(&mut self, val: Option<usize>) -> &mut Self {
79 self.min_size = val;
80 self
81 }
82
83 pub fn max_size(&mut self, val: Option<usize>) -> &mut Self {
84 self.max_size = val;
85 self
86 }
87
88 pub fn build(&mut self) -> Config<C> {
89 Config {
90 address: self
91 .address
92 .take()
93 .expect("ConfigBuilder address not specified"),
94 min_size: self.min_size.take().unwrap_or(0),
95 max_size: self.max_size.take().unwrap_or(100),
96 }
97 }
98}
99
100#[derive(Debug, Clone, PartialEq, Eq)]
101pub struct Config<C: ConnectionManager> {
102 pub address: C::Address,
103 pub min_size: usize,
104 pub max_size: usize,
105}
106
107struct PoolShared<C: ConnectionManager> {
108 config: Config<C>,
109 idle_queue: ArrayQueue<C::Connection>,
110 idle_queue_len: AtomicUsize,
112 permits: Arc<Semaphore>,
113}
114
115pub struct Pool<C: ConnectionManager>(Arc<PoolShared<C>>);
116
117impl<C: ConnectionManager> Clone for Pool<C> {
118 fn clone(&self) -> Pool<C> {
119 Pool(self.0.clone())
120 }
121}
122
123pub struct PoolConnection<C: ConnectionManager + 'static> {
124 connection: Option<(C::Connection, OwnedSemaphorePermit)>,
125 pool: Pool<C>,
126}
127
128impl<C: ConnectionManager> Deref for PoolConnection<C> {
129 type Target = C::Connection;
130
131 fn deref(&self) -> &C::Connection {
132 &self
133 .connection
134 .as_ref()
135 .expect("PoolConnection doesn't have an underlying connection")
136 .0
137 }
138}
139
140impl<C: ConnectionManager> DerefMut for PoolConnection<C> {
141 fn deref_mut(&mut self) -> &mut C::Connection {
142 &mut self
143 .connection
144 .as_mut()
145 .expect("PoolConnection doesn't have an underlying connection")
146 .0
147 }
148}
149
150impl<C: ConnectionManager> Drop for PoolConnection<C> {
151 fn drop(&mut self) {
152 let connection = match self.connection.take() {
153 Some(c) => c,
154 None => return,
155 };
156 let pool = self.pool.clone();
157 tokio::spawn(async move {
158 let (mut conn, _permit) = connection;
159 let is_alive = match C::reset_connection(&mut conn) {
160 Some(fut) => Some(fut.await.is_ok()),
161 None => None,
162 };
163 let is_alive = match is_alive {
164 Some(x) => x,
165 None => C::check_alive(&conn).unwrap_or(true),
166 };
167 if is_alive {
168 if pool.0.idle_queue.push(conn).is_ok() {
171 pool.0
172 .idle_queue_len
173 .fetch_add(1, atomic::Ordering::Relaxed);
174 }
175 }
176 });
177 }
178}
179
180impl<C: ConnectionManager> Pool<C>
181where
182 C: 'static,
183 C::Address: Debug,
184 C::Error: Debug,
185{
186 pub async fn new(config: Config<C>) -> Self {
188 let idle_queue = ArrayQueue::new(config.max_size);
189 let mut init_len = 0;
190 assert!(config.max_size >= config.min_size);
191 let mut some_failed = false;
192 for _ in 0..config.min_size {
193 match C::connect(&config.address).await {
194 Ok(conn) => {
195 idle_queue
196 .push(conn)
197 .ok()
198 .expect("Pool queue must have the capacity to allocate idle connections");
199 init_len += 1;
200 }
201 Err(err) => {
202 if !some_failed {
203 some_failed = true;
204 tracing::warn!(
205 "During pool initial connections to {:?} {:?}",
206 config.address,
207 err,
208 );
209 }
210 }
211 }
212 }
213 let permits = Arc::new(Semaphore::new(config.max_size));
214 let this = Pool(Arc::new(PoolShared {
215 idle_queue,
216 idle_queue_len: AtomicUsize::new(init_len),
217 config,
218 permits,
219 }));
220 tokio::spawn(Self::keepalive(Arc::downgrade(&this.0)));
221 this
222 }
223
224 async fn keepalive(weak: Weak<PoolShared<C>>) {
226 loop {
227 let mut idle_count;
228 {
229 let this = match weak.upgrade() {
230 Some(arc) => Pool(arc),
231 None => return,
232 };
233
234 if let Some(mut conn) = this.try_get_idle_connection().await {
235 if let Err(err) = C::ping(&mut conn).await {
236 tracing::warn!("Failed to ping DB connection: {:?}", err);
237 }
238 }
239 idle_count = this.0.idle_queue_len.load(atomic::Ordering::Relaxed);
240 }
241 if idle_count == 0 {
242 idle_count = 1;
243 }
244 let delay = Duration::from_secs(60) / (idle_count as u32);
245 sleep(delay).await;
246 }
247 }
248
249 fn idle_connection(&self) -> Option<C::Connection> {
251 let connection = self.0.idle_queue.pop()?;
252 self.0
253 .idle_queue_len
254 .fetch_sub(1, atomic::Ordering::Relaxed);
255 Some(connection)
256 }
257
258 async fn try_get_idle_connection(&self) -> Option<PoolConnection<C>> {
261 let permit = self.0.permits.clone().try_acquire_owned().ok()?;
262 Some(PoolConnection {
263 connection: Some((self.idle_connection()?, permit)),
264 pool: (*self).clone(),
265 })
266 }
267
268 pub async fn get_connection(&self) -> Result<PoolConnection<C>, C::Error> {
271 let permit = self
272 .0
273 .permits
274 .clone()
275 .acquire_owned()
276 .await
277 .map_err(|_| io::Error::new(io::ErrorKind::Other, "Connection pool closed"))?;
278 self.get_connection_internal(permit).await
279 }
280
281 pub async fn try_get_connection(&self) -> Result<PoolConnection<C>, C::Error> {
284 let permit = self.0.permits.clone().try_acquire_owned().map_err(|_| {
285 io::Error::new(io::ErrorKind::Other, "Connection pool size reached maximum")
286 })?;
287 self.get_connection_internal(permit).await
288 }
289
290 pub async fn get_connection_timeout(
293 &self,
294 timeout: Duration,
295 ) -> Result<PoolConnection<C>, C::Error> {
296 let permit = tokio::time::timeout(timeout, self.0.permits.clone().acquire_owned())
297 .await
298 .map_err(|_| {
299 io::Error::new(io::ErrorKind::Other, "Connection pool size reached maximum")
300 })?
301 .map_err(|_| io::Error::new(io::ErrorKind::Other, "Connection pool closed"))?;
302 self.get_connection_internal(permit).await
303 }
304
305 fn create_connection(
306 &self,
307 permit: OwnedSemaphorePermit,
308 conn: C::Connection,
309 ) -> PoolConnection<C> {
310 PoolConnection {
311 connection: Some((conn, permit)),
312 pool: (*self).clone(),
313 }
314 }
315
316 async fn get_connection_internal(
318 &self,
319 mut permit: OwnedSemaphorePermit,
320 ) -> Result<PoolConnection<C>, C::Error> {
321 loop {
322 match self.idle_connection() {
323 Some(c) => {
324 let mut conn = self.create_connection(permit, c);
327 let alive = match C::check_alive(&conn) {
328 Some(alive) => alive,
329 None => C::ping(&mut conn).await.is_ok(),
330 };
331 if alive {
332 break Ok(conn);
333 } else {
334 let c = conn
337 .connection
338 .take()
339 .expect("PoolConnection doesn't have an underlying connection");
340 permit = c.1;
341 }
342 }
343 None => {
344 let conn = C::connect(&self.0.config.address).await?;
345 break Ok(self.create_connection(permit, conn));
346 }
347 }
348 }
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Error;
    use std::time::Duration;
    use tokio::time::timeout;

    /// Stateless stand-in connection that always connects, is always alive
    /// and always answers pings.
    struct TestConnection;

    #[async_trait]
    impl ConnectionManager for TestConnection {
        type Address = ();
        type Connection = Self;
        type Error = Error;

        async fn connect(_address: &Self::Address) -> Result<Self::Connection, Self::Error> {
            Ok(TestConnection)
        }

        fn check_alive(_connection: &Self::Connection) -> Option<bool> {
            Some(true)
        }

        async fn ping(_connection: &mut Self::Connection) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// The pool hands out at most `max_size` connections and frees a slot
    /// again once a connection has been dropped.
    #[tokio::test]
    async fn test_connection_pool() {
        let mut builder = ConfigBuilder::<TestConnection>::new();
        builder.address(()).max_size(Some(3));
        let config = builder.build();

        let pool = Pool::<TestConnection>::new(config).await;

        // Check out every slot of the pool.
        let mut held = Vec::with_capacity(3);
        while held.len() < 3 {
            let conn = pool
                .try_get_connection()
                .await
                .expect("Unable to get connection");
            held.push(conn);
        }

        // With all slots taken, a non-waiting request must fail.
        assert!(pool.try_get_connection().await.is_err());

        // Giving one connection back frees a slot.
        drop(held.pop());

        // The waiting variant must now succeed promptly.
        let reacquired = timeout(Duration::from_millis(1), pool.get_connection()).await;
        reacquired
            .expect("get_connection timed out")
            .expect("Unable to get connection");
    }
}