pravega_connection_pool/
connection_pool.rs1use async_trait::async_trait;
12use dashmap::DashMap;
13use pravega_client_shared::PravegaNodeUri;
14use snafu::Snafu;
15use std::fmt;
16use std::ops::{Deref, DerefMut};
17use uuid::Uuid;
18
19#[derive(Debug, Snafu)]
20#[snafu(visibility = "pub(crate)")]
21pub enum ConnectionPoolError {
22 #[snafu(display("Could not establish connection to endpoint: {}", endpoint))]
23 EstablishConnection { endpoint: String, error_msg: String },
24
25 #[snafu(display("No available connection in the internal pool"))]
26 NoAvailableConnection {},
27}
28
29#[async_trait]
74pub trait Manager {
75 type Conn: Send + Sized;
77
78 async fn establish_connection(&self, endpoint: PravegaNodeUri)
80 -> Result<Self::Conn, ConnectionPoolError>;
81
82 fn is_valid(&self, conn: &Self::Conn) -> bool;
85
86 fn get_max_connections(&self) -> u32;
88
89 fn name(&self) -> String;
90}
91
92pub struct ConnectionPool<M>
95where
96 M: Manager,
97{
98 manager: M,
99
100 managed_pool: ManagedPool<M::Conn>,
103}
104
105impl<M> ConnectionPool<M>
106where
107 M: Manager,
108{
109 pub fn new(manager: M) -> Self {
112 let managed_pool = ManagedPool::new(manager.get_max_connections());
113 ConnectionPool {
114 manager,
115 managed_pool,
116 }
117 }
118
119 pub async fn get_connection(
125 &self,
126 endpoint: PravegaNodeUri,
127 ) -> Result<PooledConnection<'_, M::Conn>, ConnectionPoolError> {
128 loop {
130 match self.managed_pool.get_connection(endpoint.clone()) {
131 Ok(internal_conn) => {
132 let conn = internal_conn.conn;
133 if self.manager.is_valid(&conn) {
134 return Ok(PooledConnection {
135 uuid: internal_conn.uuid,
136 inner: Some(conn),
137 endpoint,
138 pool: &self.managed_pool,
139 });
140 }
141
142 }
144 Err(_e) => {
145 let conn = self.manager.establish_connection(endpoint.clone()).await?;
146 return Ok(PooledConnection {
147 uuid: Uuid::new_v4(),
148 inner: Some(conn),
149 endpoint,
150 pool: &self.managed_pool,
151 });
152 }
153 }
154 }
155 }
156
157 pub fn pool_len(&self, endpoint: &PravegaNodeUri) -> usize {
159 self.managed_pool.pool_len(endpoint)
160 }
161}
162
163impl<M: Manager> fmt::Debug for ConnectionPool<M> {
164 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165 f.debug_struct("ConnectionPool")
166 .field("managed pool name", &self.manager.name())
167 .field("managed pool", &self.managed_pool)
168 .finish()
169 }
170}
171
172struct ManagedPool<T: Sized + Send> {
175 map: DashMap<PravegaNodeUri, InternalPool<T>>,
176 max_connections: u32,
177}
178
179impl<T: Sized + Send> ManagedPool<T> {
180 pub fn new(max_connections: u32) -> Self {
181 let map = DashMap::new();
182 ManagedPool { map, max_connections }
183 }
184
185 fn add_connection(&self, endpoint: PravegaNodeUri, connection: InternalConn<T>) {
187 let mut internal = self.map.entry(endpoint).or_insert_with(InternalPool::new);
188 if self.max_connections > internal.conns.len() as u32 {
189 internal.conns.push(connection);
190 }
191 }
192
193 fn get_connection(&self, endpoint: PravegaNodeUri) -> Result<InternalConn<T>, ConnectionPoolError> {
195 let mut internal = self.map.entry(endpoint).or_insert_with(InternalPool::new);
196 if internal.conns.is_empty() {
197 Err(ConnectionPoolError::NoAvailableConnection {})
198 } else {
199 let conn = internal.conns.pop().expect("pop connection from vec");
200 Ok(conn)
201 }
202 }
203
204 fn pool_len(&self, endpoint: &PravegaNodeUri) -> usize {
206 self.map.get(endpoint).map_or(0, |pool| pool.conns.len())
207 }
208}
209
210impl<T: Sized + Send> fmt::Debug for ManagedPool<T> {
211 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
212 f.debug_struct("ManagedPool")
213 .field("internal map", &self.map)
214 .field("max connection", &self.max_connections)
215 .finish()
216 }
217}
218
219struct InternalConn<T> {
221 uuid: Uuid,
222 conn: T,
223}
224
225struct InternalPool<T> {
227 conns: Vec<InternalConn<T>>,
228}
229
230impl<T: Send + Sized> InternalPool<T> {
231 fn new() -> Self {
232 InternalPool { conns: vec![] }
233 }
234}
235
236impl<T> fmt::Debug for InternalPool<T> {
237 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
238 f.debug_struct("InternalPool")
239 .field("pool size", &self.conns.len())
240 .finish()
241 }
242}
243
244pub struct PooledConnection<'a, T: Send + Sized> {
247 uuid: Uuid,
248 endpoint: PravegaNodeUri,
249 inner: Option<T>,
250 pool: &'a ManagedPool<T>,
251}
252
253impl<T: Send + Sized> fmt::Debug for PooledConnection<'_, T> {
254 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
255 fmt::Debug::fmt(&self.uuid, fmt)
256 }
257}
258
259impl<T: Send + Sized> Drop for PooledConnection<'_, T> {
260 fn drop(&mut self) {
261 let conn = self.inner.take().expect("get inner connection");
262 self.pool.add_connection(
263 self.endpoint.clone(),
264 InternalConn {
265 uuid: self.uuid,
266 conn,
267 },
268 )
269 }
270}
271
272impl<T: Send + Sized> Deref for PooledConnection<'_, T> {
273 type Target = T;
274
275 fn deref(&self) -> &T {
276 self.inner.as_ref().expect("borrow inner connection")
277 }
278}
279
280impl<T: Send + Sized> DerefMut for PooledConnection<'_, T> {
281 fn deref_mut(&mut self) -> &mut T {
282 self.inner.as_mut().expect("mutably borrow inner connection")
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::time::Duration;

    /// Stateless stand-in for a real connection.
    struct FooConnection {}

    /// Manager that always succeeds and treats every connection as valid.
    struct FooManager {
        max_connections_in_pool: u32,
    }

    #[async_trait]
    impl Manager for FooManager {
        type Conn = FooConnection;

        async fn establish_connection(
            &self,
            _endpoint: PravegaNodeUri,
        ) -> Result<Self::Conn, ConnectionPoolError> {
            Ok(FooConnection {})
        }

        fn is_valid(&self, _conn: &Self::Conn) -> bool {
            true
        }

        fn get_max_connections(&self) -> u32 {
            self.max_connections_in_pool
        }

        fn name(&self) -> String {
            String::from("foo")
        }
    }

    /// A checked-out connection is not counted as pooled until it is dropped.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_connection_pool_basic() {
        let endpoint = PravegaNodeUri::from("127.0.0.1:1000".to_string());
        let pool = ConnectionPool::new(FooManager {
            max_connections_in_pool: 2,
        });

        assert_eq!(pool.pool_len(&endpoint), 0);
        let connection = pool
            .get_connection(endpoint.clone())
            .await
            .expect("get connection");
        // Still checked out, so nothing is idle yet.
        assert_eq!(pool.pool_len(&endpoint), 0);
        drop(connection);
        assert_eq!(pool.pool_len(&endpoint), 1);
    }

    /// Ten tasks hold a connection at the same time; once all of them have
    /// finished, only `MAX_CONNECTION` connections remain pooled.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_connection_pool_size() {
        const MAX_CONNECTION: u32 = 2;
        let endpoint = PravegaNodeUri::from("127.0.0.1:1234".to_string());
        let pool = Arc::new(ConnectionPool::new(FooManager {
            max_connections_in_pool: MAX_CONNECTION,
        }));

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let cloned_pool = Arc::clone(&pool);
                let endpoint_clone = endpoint.clone();
                tokio::spawn(async move {
                    let _connection = cloned_pool
                        .get_connection(endpoint_clone)
                        .await
                        .expect("get connection");
                    // Hold the connection for a while so the tasks overlap.
                    tokio::time::sleep(Duration::from_millis(500)).await;
                })
            })
            .collect();

        // Await the tasks last-spawned first.
        for handle in handles.into_iter().rev() {
            handle.await.expect("handle should work");
        }

        assert_eq!(pool.pool_len(&endpoint) as u32, MAX_CONNECTION);
    }
}