1use std::collections::LinkedList;
35use std::fmt;
36use std::io;
37use std::marker::PhantomData;
38use std::ops::{Add, Deref, DerefMut};
39use std::sync::{Arc, Mutex, MutexGuard};
40use std::time::{Duration, Instant};
41
42use async_trait::async_trait;
43use tokio::time::{delay_for, timeout};
44
45#[cfg(test)]
46mod test;
47
48#[async_trait]
50pub trait ManageConnection: Send + Sync + 'static {
51 type Connection: Send + 'static;
53
54 async fn connect(&self) -> io::Result<Self::Connection>;
56
57 async fn check(&self, conn: &mut Self::Connection) -> io::Result<()>;
62}
63
/// Builds an [`io::Error`] of kind [`io::ErrorKind::Other`] carrying `msg`.
fn other(msg: &str) -> io::Error {
    let kind = io::ErrorKind::Other;
    io::Error::new(kind, msg.to_owned())
}
67
68pub struct Builder<M>
70where
71 M: ManageConnection,
72{
73 pub max_lifetime: Option<Duration>,
74 pub idle_timeout: Option<Duration>,
75 pub connection_timeout: Option<Duration>,
76 pub max_size: u32,
77 pub check_interval: Option<Duration>,
78 _pd: PhantomData<M>,
79}
80
81impl<M> fmt::Debug for Builder<M>
82where
83 M: ManageConnection,
84{
85 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
86 fmt.debug_struct("Builder")
87 .field("max_size", &self.max_size)
88 .field("max_lifetime", &self.max_lifetime)
89 .field("idle_timeout", &self.idle_timeout)
90 .field("connection_timeout", &self.connection_timeout)
91 .finish()
92 }
93}
94
95impl<M> Default for Builder<M>
96where
97 M: ManageConnection,
98{
99 fn default() -> Self {
100 Builder {
101 max_lifetime: Some(Duration::from_secs(60 * 30)),
102 idle_timeout: Some(Duration::from_secs(3 * 60)),
103 connection_timeout: Some(Duration::from_secs(3)),
104 check_interval: Some(Duration::from_secs(3)),
105 max_size: 0,
106 _pd: PhantomData,
107 }
108 }
109}
110
111impl<M> Builder<M>
112where
113 M: ManageConnection,
114{
115 pub fn new() -> Self {
119 Builder::default()
120 }
121
122 pub fn max_lifetime(mut self, max_lifetime: Option<Duration>) -> Self {
131 if max_lifetime == Some(Duration::from_secs(0)) {
132 self
133 } else {
134 self.max_lifetime = max_lifetime;
135 self
136 }
137 }
138
139 pub fn idle_timeout(mut self, idle_timeout: Option<Duration>) -> Self {
147 if idle_timeout == Some(Duration::from_secs(0)) {
148 self
149 } else {
150 self.idle_timeout = idle_timeout;
151 self
152 }
153 }
154
155 pub fn connection_timeout(mut self, connection_timeout: Option<Duration>) -> Self {
163 if connection_timeout == Some(Duration::from_secs(0)) {
164 self
165 } else {
166 self.connection_timeout = connection_timeout;
167 self
168 }
169 }
170
171 pub fn max_size(mut self, max_size: u32) -> Self {
177 self.max_size = max_size;
178 self
179 }
180
181 pub fn check_interval(mut self, interval: Option<Duration>) -> Self {
185 self.check_interval = interval;
186 self
187 }
188
189 pub fn build(&self, manager: M) -> Pool<M>
191 where
192 M: ManageConnection,
193 {
194 let intervals = PoolInternals {
195 conns: LinkedList::new(),
196 active: 0,
197 };
198
199 let shared = SharedPool {
200 intervals: Mutex::new(intervals),
201 max_lifetime: self.max_lifetime,
202 idle_timeout: self.idle_timeout,
203 connection_timeout: self.connection_timeout,
204 max_size: self.max_size,
205 check_interval: self.check_interval,
206 manager,
207 };
208
209 let pool = Pool(Arc::new(shared));
210 tokio::spawn(pool.clone().check());
211 pool
212 }
213}
214
215pub struct Connection<M>
217where
218 M: ManageConnection,
219{
220 conn: Option<IdleConn<M::Connection>>,
221 pool: Pool<M>,
222}
223
224impl<M> Drop for Connection<M>
225where
226 M: ManageConnection,
227{
228 fn drop(&mut self) {
229 if self.conn.is_some() {
230 self.pool.put(self.conn.take().unwrap());
231 }
232 }
233}
234
235impl<M> Deref for Connection<M>
236where
237 M: ManageConnection,
238{
239 type Target = M::Connection;
240
241 fn deref(&self) -> &M::Connection {
242 &self.conn.as_ref().unwrap().conn
243 }
244}
245
246impl<M> DerefMut for Connection<M>
247where
248 M: ManageConnection,
249{
250 fn deref_mut(&mut self) -> &mut M::Connection {
251 &mut self.conn.as_mut().unwrap().conn
252 }
253}
254
255pub struct Pool<M>(Arc<SharedPool<M>>)
257where
258 M: ManageConnection;
259
260impl<M> Clone for Pool<M>
261where
262 M: ManageConnection,
263{
264 fn clone(&self) -> Pool<M> {
265 Pool(self.0.clone())
266 }
267}
268
269impl<M> Pool<M>
270where
271 M: ManageConnection,
272{
273 pub fn new(manager: M) -> Pool<M> {
275 Pool::builder().build(manager)
276 }
277
278 pub fn builder() -> Builder<M> {
280 Builder::new()
281 }
282
283 pub(crate) fn interval<'a>(&'a self) -> MutexGuard<'a, PoolInternals<M::Connection>> {
284 self.0.intervals.lock().unwrap()
285 }
286
287 fn idle_count(&self) -> usize {
288 self.interval().conns.len()
289 }
290
291 fn incr_active(&self) {
292 self.interval().active += 1;
293 }
294
295 fn decr_active(&self) {
296 self.interval().active -= 1;
297 }
298
299 fn pop_front(&self) -> Option<IdleConn<M::Connection>> {
300 self.interval().conns.pop_front()
301 }
302
303 fn push_back(&mut self, conn: IdleConn<M::Connection>) {
304 self.interval().conns.push_back(conn);
305 }
306
307 fn exceed_idle_timeout(&self, conn: &IdleConn<M::Connection>) -> bool {
308 if let Some(idle_timeout) = self.0.idle_timeout {
309 if idle_timeout.as_micros() > 0 && conn.last_visited.add(idle_timeout) < Instant::now()
310 {
311 return true;
312 }
313 }
314
315 false
316 }
317
318 fn exceed_max_lifetime(&self, conn: &IdleConn<M::Connection>) -> bool {
319 if let Some(max_lifetime) = self.0.max_lifetime {
320 if max_lifetime.as_micros() > 0 && conn.created.add(max_lifetime) < Instant::now() {
321 return true;
322 }
323 }
324
325 false
326 }
327
328 async fn check(mut self) {
329 if let Some(interval) = self.0.check_interval {
330 loop {
331 delay_for(interval).await;
332 let n = self.idle_count();
333 for _ in 0..n {
334 if let Some(mut conn) = self.pop_front() {
335 if self.exceed_idle_timeout(&conn) || self.exceed_max_lifetime(&conn) {
336 self.decr_active();
337 continue;
338 }
339 match self.0.manager.check(&mut conn.conn).await {
340 Ok(_) => {
341 self.push_back(conn);
342 continue;
343 }
344 Err(_) => {
345 self.decr_active();
346 }
347 }
348 continue;
349 }
350 break;
351 }
352 }
353 }
354 }
355
356 fn exceed_limit(&self) -> bool {
357 let max_size = self.0.max_size;
358 if max_size > 0 && self.interval().active > max_size {
359 true
360 } else {
361 false
362 }
363 }
364
365 pub async fn get_timeout(
369 &self,
370 connection_timeout: Option<Duration>,
371 ) -> io::Result<M::Connection> {
372 if let Some(connection_timeout) = connection_timeout {
373 let conn = match timeout(connection_timeout, self.0.manager.connect()).await {
374 Ok(s) => match s {
375 Ok(s) => s,
376 Err(e) => {
377 return Err(other(&e.to_string()));
378 }
379 },
380 Err(e) => {
381 return Err(other(&e.to_string()));
382 }
383 };
384
385 Ok(conn)
386 } else {
387 let conn = self.0.manager.connect().await?;
388 Ok(conn)
389 }
390 }
391
392 pub async fn get(&self) -> io::Result<Connection<M>> {
397 if let Some(conn) = self.pop_front() {
398 return Ok(Connection {
399 conn: Some(conn),
400 pool: self.clone(),
401 });
402 }
403
404 self.incr_active();
405 if self.exceed_limit() {
406 self.decr_active();
407 return Err(other("exceed limit"));
408 }
409
410 let conn = self
411 .get_timeout(self.0.connection_timeout)
412 .await
413 .map_err(|e| {
414 self.decr_active();
415 e
416 })?;
417
418 return Ok(Connection {
419 conn: Some(IdleConn {
420 conn,
421 last_visited: Instant::now(),
422 created: Instant::now(),
423 }),
424 pool: self.clone(),
425 });
426 }
427
428 fn put(&mut self, mut conn: IdleConn<M::Connection>) {
429 conn.last_visited = Instant::now();
430 self.push_back(conn);
431 }
432}
433
434struct SharedPool<M>
435where
436 M: ManageConnection,
437{
438 intervals: Mutex<PoolInternals<M::Connection>>,
439 max_lifetime: Option<Duration>,
440 idle_timeout: Option<Duration>,
441 connection_timeout: Option<Duration>,
442 max_size: u32,
443 check_interval: Option<Duration>,
444 manager: M,
445}
446
/// A pooled connection together with the timestamps used for expiry checks.
struct IdleConn<C> {
    // When the connection was opened; compared against `max_lifetime`.
    created: Instant,
    // When the connection was last handed back to the pool; compared against
    // `idle_timeout`.
    last_visited: Instant,
    // The manager's connection value itself.
    conn: C,
}
452
453struct PoolInternals<C> {
454 conns: LinkedList<IdleConn<C>>,
455 active: u32,
456}