mtop_client/
pool.rs

1use crate::core::MtopError;
2use async_trait::async_trait;
3use std::collections::HashMap;
4use std::fmt;
5use std::hash::Hash;
6use std::ops::{Deref, DerefMut};
7use std::sync::atomic::{AtomicU64, Ordering};
8use tokio::sync::Mutex;
9
10/// Trait used by a client pool for creating new client instances when needed.
11///
12/// Implementations are expected to retain any required configuration for client
13/// instances beyond the identifier for an instance (usually a server address).
14#[async_trait]
15pub trait ClientFactory<K, V> {
16    /// Create a new client instance based on its ID.
17    async fn make(&self, key: &K) -> Result<V, MtopError>;
18}
19
/// Settings that control how a client pool behaves.
#[derive(Clone, Debug)]
pub(crate) struct ClientPoolConfig {
    /// Name of the pool; it is attached to the trace events the pool emits.
    pub name: String,
    /// Upper bound on the number of idle clients retained per key. Clients
    /// returned to the pool beyond this limit are not kept.
    pub max_idle: u64,
}
25
26pub(crate) struct ClientPool<K, V>
27where
28    K: Eq + Hash + Clone + fmt::Debug + fmt::Display,
29{
30    clients: Mutex<HashMap<K, Vec<PooledClient<K, V>>>>,
31    config: ClientPoolConfig,
32    factory: Box<dyn ClientFactory<K, V> + Send + Sync>,
33    ids: AtomicU64,
34}
35
36impl<K, V> ClientPool<K, V>
37where
38    K: Eq + Hash + Clone + fmt::Debug + fmt::Display,
39{
40    pub(crate) fn new<F>(config: ClientPoolConfig, factory: F) -> Self
41    where
42        F: ClientFactory<K, V> + Send + Sync + 'static,
43    {
44        Self {
45            clients: Mutex::new(HashMap::new()),
46            factory: Box::new(factory),
47            ids: AtomicU64::new(0),
48            config,
49        }
50    }
51
52    pub(crate) async fn get(&self, key: &K) -> Result<PooledClient<K, V>, MtopError> {
53        // Lock the clients HashMap and try to get an existing client in a limited scope
54        // so that we don't hold the lock while trying to connect if there are no exising
55        // clients.
56        let client = {
57            let mut clients = self.clients.lock().await;
58            clients.get_mut(key).and_then(|v| v.pop())
59        };
60
61        match client {
62            Some(c) => {
63                tracing::trace!(message = "using existing client", pool = self.config.name, server = %key, id = c.id);
64                Ok(c)
65            }
66            None => {
67                tracing::trace!(message = "creating new client", pool = self.config.name, server = %key);
68                let inner = self.factory.make(key).await?;
69                Ok(PooledClient {
70                    id: self.ids.fetch_add(1, Ordering::Relaxed),
71                    key: key.clone(),
72                    inner,
73                })
74            }
75        }
76    }
77
78    pub(crate) async fn put(&self, client: PooledClient<K, V>) {
79        let mut clients = self.clients.lock().await;
80        let entries = clients.entry(client.key.clone()).or_default();
81        if (entries.len() as u64) < self.config.max_idle {
82            entries.push(client);
83        }
84    }
85}
86
87impl<K, V> fmt::Debug for ClientPool<K, V>
88where
89    K: Eq + Hash + Clone + fmt::Debug + fmt::Display,
90{
91    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92        f.debug_struct("ClientPool")
93            .field("config", &self.config)
94            .field("clients", &"...")
95            .field("factory", &"...")
96            .finish()
97    }
98}
99
/// A client owned by a pool.
///
/// The wrapper must be handed back to the pool once the caller is done with it.
#[derive(Debug)]
pub struct PooledClient<K, V> {
    /// Numeric identifier of this client.
    id: u64,
    /// Key the client is stored under in the pool.
    key: K,
    /// The wrapped client itself.
    inner: V,
}
108
109impl<K, V> Deref for PooledClient<K, V> {
110    type Target = V;
111
112    fn deref(&self) -> &Self::Target {
113        &self.inner
114    }
115}
116
117impl<K, V> DerefMut for PooledClient<K, V> {
118    fn deref_mut(&mut self) -> &mut Self::Target {
119        &mut self.inner
120    }
121}
122
#[cfg(test)]
mod test {
    use super::{ClientFactory, ClientPool, ClientPoolConfig};
    use crate::core::MtopError;
    use async_trait::async_trait;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Fake client that bumps a shared counter when it is dropped.
    struct CountingClient {
        dropped: Arc<AtomicU64>,
    }

    impl Drop for CountingClient {
        fn drop(&mut self) {
            self.dropped.fetch_add(1, Ordering::Release);
        }
    }

    /// Factory that counts how many clients it has created and gives every
    /// client the shared drop counter.
    struct CountingClientFactory {
        created: Arc<AtomicU64>,
        dropped: Arc<AtomicU64>,
    }

    #[async_trait]
    impl ClientFactory<String, CountingClient> for CountingClientFactory {
        async fn make(&self, _key: &String) -> Result<CountingClient, MtopError> {
            self.created.fetch_add(1, Ordering::Release);

            let dropped = Arc::clone(&self.dropped);
            Ok(CountingClient { dropped })
        }
    }

    /// Fresh pair of `(created, dropped)` counters, both starting at zero.
    fn new_counters() -> (Arc<AtomicU64>, Arc<AtomicU64>) {
        (Arc::new(AtomicU64::new(0)), Arc::new(AtomicU64::new(0)))
    }

    /// Build a pool named "test" that keeps at most one idle client per key and
    /// records creations and drops in the given counters.
    fn new_pool(created: Arc<AtomicU64>, dropped: Arc<AtomicU64>) -> ClientPool<String, CountingClient> {
        let pool_cfg = ClientPoolConfig {
            name: "test".to_owned(),
            max_idle: 1,
        };
        let factory = CountingClientFactory { created, dropped };

        ClientPool::new(pool_cfg, factory)
    }

    /// Getting from an empty pool creates exactly one client and drops none.
    #[tokio::test]
    async fn test_client_pool_get_empty_pool() {
        let (created, dropped) = new_counters();
        let pool = new_pool(created.clone(), dropped.clone());
        let server = "whatever".to_owned();

        let _client = pool.get(&server).await.unwrap();

        assert_eq!(1, created.load(Ordering::Acquire));
        assert_eq!(0, dropped.load(Ordering::Acquire));
    }

    /// A client that was put back is reused by the next get for the same key.
    #[tokio::test]
    async fn test_client_pool_get_existing_client() {
        let (created, dropped) = new_counters();
        let pool = new_pool(created.clone(), dropped.clone());
        let server = "whatever".to_owned();

        let first = pool.get(&server).await.unwrap();
        pool.put(first).await;
        let _second = pool.get(&server).await.unwrap();

        assert_eq!(1, created.load(Ordering::Acquire));
        assert_eq!(0, dropped.load(Ordering::Acquire));
    }

    /// Returning more clients than `max_idle` allows drops the surplus one.
    #[tokio::test]
    async fn test_client_pool_put_at_max_idle() {
        let (created, dropped) = new_counters();
        let pool = new_pool(created.clone(), dropped.clone());
        let server = "whatever".to_owned();

        let first = pool.get(&server).await.unwrap();
        let second = pool.get(&server).await.unwrap();
        pool.put(first).await;
        pool.put(second).await;

        assert_eq!(2, created.load(Ordering::Acquire));
        assert_eq!(1, dropped.load(Ordering::Acquire));
    }

    /// With `max_idle` set to zero, every returned client is dropped.
    #[tokio::test]
    async fn test_client_pool_put_zero_max_idle() {
        let (created, dropped) = new_counters();
        let mut pool = new_pool(created.clone(), dropped.clone());
        pool.config.max_idle = 0;
        let server = "whatever".to_owned();

        let client = pool.get(&server).await.unwrap();
        pool.put(client).await;

        assert_eq!(1, created.load(Ordering::Acquire));
        assert_eq!(1, dropped.load(Ordering::Acquire));
    }
}