// pravega_connection_pool/connection_pool.rs

//
// Copyright (c) Dell Inc., or its subsidiaries. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//

use async_trait::async_trait;
use dashmap::DashMap;
use pravega_client_shared::PravegaNodeUri;
use snafu::Snafu;
use std::fmt;
use std::ops::{Deref, DerefMut};
use uuid::Uuid;

19#[derive(Debug, Snafu)]
20#[snafu(visibility = "pub(crate)")]
21pub enum ConnectionPoolError {
22    #[snafu(display("Could not establish connection to endpoint: {}", endpoint))]
23    EstablishConnection { endpoint: String, error_msg: String },
24
25    #[snafu(display("No available connection in the internal pool"))]
26    NoAvailableConnection {},
27}
28
29/// Manager is a trait for defining custom connections. User can implement their own
30/// type of connection and their own way of establishing the connection in this trait.
31/// ConnectionPool will accept an implementation of this trait and manage the customized
32/// connection using the method that user provides
33/// # Example
34///
35/// ```no_run
36/// use async_trait::async_trait;
37/// use pravega_connection_pool::connection_pool::{Manager, ConnectionPoolError, ConnectionPool};
38/// use pravega_client_shared::PravegaNodeUri;
39/// use tokio::runtime::Runtime;
40///
41/// struct FooConnection {}
42///
43/// struct FooManager {}
44///
45/// #[async_trait]
46/// impl Manager for FooManager {
47/// type Conn = FooConnection;
48///
49/// async fn establish_connection(&self, endpoint: PravegaNodeUri) -> Result<Self::Conn, ConnectionPoolError> {
50///         unimplemented!()
51///     }
52///
53/// fn is_valid(&self,conn: &Self::Conn) -> bool {
54///         unimplemented!()
55///     }
56///
57/// fn get_max_connections(&self) -> u32 {
58///         unimplemented!()
59///     }
60///
61///
62/// fn name(&self) -> String {
63///         unimplemented!()
64///     }
65/// }
66///
67/// let mut rt = Runtime::new().unwrap();
68/// let manager = FooManager{};
69/// let pool = ConnectionPool::new(manager);
70/// let endpoint = PravegaNodeUri::from("tcp://127.0.0.1:12345");
71/// let connection = rt.block_on(pool.get_connection(endpoint));
72/// ```
73#[async_trait]
74pub trait Manager {
75    /// The customized connection must implement Send and Sized marker trait
76    type Conn: Send + Sized;
77
78    /// Define how to establish the customized connection
79    async fn establish_connection(&self, endpoint: PravegaNodeUri)
80        -> Result<Self::Conn, ConnectionPoolError>;
81
82    /// Check whether this connection is still valid. This method will be used to filter out
83    /// invalid connections when putting connection back to the pool
84    fn is_valid(&self, conn: &Self::Conn) -> bool;
85
86    /// Get the maximum connections in the pool
87    fn get_max_connections(&self) -> u32;
88
89    fn name(&self) -> String;
90}
91
92/// ConnectionPool creates a pool of connections for reuse.
93/// It is thread safe.
94pub struct ConnectionPool<M>
95where
96    M: Manager,
97{
98    manager: M,
99
100    /// managed_pool holds a map that maps endpoint to the internal pool.
101    /// each endpoint has its own internal pool.
102    managed_pool: ManagedPool<M::Conn>,
103}
104
105impl<M> ConnectionPool<M>
106where
107    M: Manager,
108{
109    /// Create a new ConnectionPoolImpl instances by passing into a ClientConfig. It will create
110    /// a Runtime, a map and a ConnectionFactory.
111    pub fn new(manager: M) -> Self {
112        let managed_pool = ManagedPool::new(manager.get_max_connections());
113        ConnectionPool {
114            manager,
115            managed_pool,
116        }
117    }
118
119    /// get_connection takes an endpoint and returns a PooledConnection. The PooledConnection is a
120    /// wrapper that contains a Connection that can be used to send and read.
121    ///
122    /// This method is thread safe and can be called concurrently. It will return an error if it fails
123    /// to establish connection to the remote server.
124    pub async fn get_connection(
125        &self,
126        endpoint: PravegaNodeUri,
127    ) -> Result<PooledConnection<'_, M::Conn>, ConnectionPoolError> {
128        // use an infinite loop.
129        loop {
130            match self.managed_pool.get_connection(endpoint.clone()) {
131                Ok(internal_conn) => {
132                    let conn = internal_conn.conn;
133                    if self.manager.is_valid(&conn) {
134                        return Ok(PooledConnection {
135                            uuid: internal_conn.uuid,
136                            inner: Some(conn),
137                            endpoint,
138                            pool: &self.managed_pool,
139                        });
140                    }
141
142                    //if it is not valid, will be deleted automatically
143                }
144                Err(_e) => {
145                    let conn = self.manager.establish_connection(endpoint.clone()).await?;
146                    return Ok(PooledConnection {
147                        uuid: Uuid::new_v4(),
148                        inner: Some(conn),
149                        endpoint,
150                        pool: &self.managed_pool,
151                    });
152                }
153            }
154        }
155    }
156
157    /// Returns the pool length of a specific internal pool
158    pub fn pool_len(&self, endpoint: &PravegaNodeUri) -> usize {
159        self.managed_pool.pool_len(endpoint)
160    }
161}
162
163impl<M: Manager> fmt::Debug for ConnectionPool<M> {
164    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165        f.debug_struct("ConnectionPool")
166            .field("managed pool name", &self.manager.name())
167            .field("managed pool", &self.managed_pool)
168            .finish()
169    }
170}
171
172// ManagedPool maintains a map that maps endpoint to InternalPool.
173// The map is a concurrent map named Dashmap, which supports multi-threading with high performance.
174struct ManagedPool<T: Sized + Send> {
175    map: DashMap<PravegaNodeUri, InternalPool<T>>,
176    max_connections: u32,
177}
178
179impl<T: Sized + Send> ManagedPool<T> {
180    pub fn new(max_connections: u32) -> Self {
181        let map = DashMap::new();
182        ManagedPool { map, max_connections }
183    }
184
185    // add a connection to the internal pool
186    fn add_connection(&self, endpoint: PravegaNodeUri, connection: InternalConn<T>) {
187        let mut internal = self.map.entry(endpoint).or_insert_with(InternalPool::new);
188        if self.max_connections > internal.conns.len() as u32 {
189            internal.conns.push(connection);
190        }
191    }
192
193    // get a connection from the internal pool. If there is no available connections, returns an error
194    fn get_connection(&self, endpoint: PravegaNodeUri) -> Result<InternalConn<T>, ConnectionPoolError> {
195        let mut internal = self.map.entry(endpoint).or_insert_with(InternalPool::new);
196        if internal.conns.is_empty() {
197            Err(ConnectionPoolError::NoAvailableConnection {})
198        } else {
199            let conn = internal.conns.pop().expect("pop connection from vec");
200            Ok(conn)
201        }
202    }
203
204    // return the pool length of the internal pool
205    fn pool_len(&self, endpoint: &PravegaNodeUri) -> usize {
206        self.map.get(endpoint).map_or(0, |pool| pool.conns.len())
207    }
208}
209
210impl<T: Sized + Send> fmt::Debug for ManagedPool<T> {
211    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
212        f.debug_struct("ManagedPool")
213            .field("internal map", &self.map)
214            .field("max connection", &self.max_connections)
215            .finish()
216    }
217}
218
219// An internal connection struct that stores the uuid of the connection
220struct InternalConn<T> {
221    uuid: Uuid,
222    conn: T,
223}
224
225// An InternalPool that maintains a vector that stores all the connections.
226struct InternalPool<T> {
227    conns: Vec<InternalConn<T>>,
228}
229
230impl<T: Send + Sized> InternalPool<T> {
231    fn new() -> Self {
232        InternalPool { conns: vec![] }
233    }
234}
235
236impl<T> fmt::Debug for InternalPool<T> {
237    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
238        f.debug_struct("InternalPool")
239            .field("pool size", &self.conns.len())
240            .finish()
241    }
242}
243
244/// A smart pointer wrapping a Connection so that the inner Connection can return to the ConnectionPool once
245/// this pointer is dropped.
246pub struct PooledConnection<'a, T: Send + Sized> {
247    uuid: Uuid,
248    endpoint: PravegaNodeUri,
249    inner: Option<T>,
250    pool: &'a ManagedPool<T>,
251}
252
253impl<T: Send + Sized> fmt::Debug for PooledConnection<'_, T> {
254    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
255        fmt::Debug::fmt(&self.uuid, fmt)
256    }
257}
258
259impl<T: Send + Sized> Drop for PooledConnection<'_, T> {
260    fn drop(&mut self) {
261        let conn = self.inner.take().expect("get inner connection");
262        self.pool.add_connection(
263            self.endpoint.clone(),
264            InternalConn {
265                uuid: self.uuid,
266                conn,
267            },
268        )
269    }
270}
271
272impl<T: Send + Sized> Deref for PooledConnection<'_, T> {
273    type Target = T;
274
275    fn deref(&self) -> &T {
276        self.inner.as_ref().expect("borrow inner connection")
277    }
278}
279
280impl<T: Send + Sized> DerefMut for PooledConnection<'_, T> {
281    fn deref_mut(&mut self) -> &mut T {
282        self.inner.as_mut().expect("mutably borrow inner connection")
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::time::Duration;

    /// Stateless dummy connection.
    struct FooConnection {}

    /// Test manager whose connections are always established successfully and always valid.
    struct FooManager {
        max_connections_in_pool: u32,
    }

    #[async_trait]
    impl Manager for FooManager {
        type Conn = FooConnection;

        async fn establish_connection(
            &self,
            _endpoint: PravegaNodeUri,
        ) -> Result<Self::Conn, ConnectionPoolError> {
            Ok(FooConnection {})
        }

        fn is_valid(&self, _conn: &Self::Conn) -> bool {
            true
        }

        fn get_max_connections(&self) -> u32 {
            self.max_connections_in_pool
        }

        fn name(&self) -> String {
            String::from("foo")
        }
    }

    /// A checked-out connection leaves the pool empty and goes back into it on drop.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_connection_pool_basic() {
        let pool = ConnectionPool::new(FooManager {
            max_connections_in_pool: 2,
        });
        let endpoint = PravegaNodeUri::from("127.0.0.1:1000".to_string());

        assert_eq!(pool.pool_len(&endpoint), 0);
        let conn = pool
            .get_connection(endpoint.clone())
            .await
            .expect("get connection");
        assert_eq!(pool.pool_len(&endpoint), 0);
        drop(conn);
        assert_eq!(pool.pool_len(&endpoint), 1);
    }

    /// Ten concurrent borrowers each hold a connection for a while; after all of them are
    /// returned, the pool keeps only `MAX_CONNECTION` of them.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_connection_pool_size() {
        const MAX_CONNECTION: u32 = 2;
        let pool = Arc::new(ConnectionPool::new(FooManager {
            max_connections_in_pool: MAX_CONNECTION,
        }));
        let endpoint = PravegaNodeUri::from("127.0.0.1:1234".to_string());

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let pool = Arc::clone(&pool);
                let endpoint = endpoint.clone();
                tokio::spawn(async move {
                    let _connection = pool
                        .get_connection(endpoint)
                        .await
                        .expect("get connection");
                    tokio::time::sleep(Duration::from_millis(500)).await;
                })
            })
            .collect();

        // Await the most recently spawned task first.
        for handle in handles.into_iter().rev() {
            handle.await.expect("handle should work");
        }

        assert_eq!(pool.pool_len(&endpoint) as u32, MAX_CONNECTION);
    }
}