1use crate::core::MtopError;
2use async_trait::async_trait;
3use std::collections::HashMap;
4use std::fmt;
5use std::hash::Hash;
6use std::ops::{Deref, DerefMut};
7use std::sync::atomic::{AtomicU64, Ordering};
8use tokio::sync::Mutex;
9
10#[async_trait]
15pub trait ClientFactory<K, V> {
16 async fn make(&self, key: &K) -> Result<V, MtopError>;
18}
19
/// Settings controlling a `ClientPool`.
#[derive(Debug, Clone)]
pub(crate) struct ClientPoolConfig {
    /// Name of the pool, included in trace output.
    pub name: String,
    /// Maximum number of idle clients retained per key; clients returned to
    /// the pool beyond this limit are dropped.
    pub max_idle: u64,
}
25
26pub(crate) struct ClientPool<K, V>
27where
28 K: Eq + Hash + Clone + fmt::Debug + fmt::Display,
29{
30 clients: Mutex<HashMap<K, Vec<PooledClient<K, V>>>>,
31 config: ClientPoolConfig,
32 factory: Box<dyn ClientFactory<K, V> + Send + Sync>,
33 ids: AtomicU64,
34}
35
36impl<K, V> ClientPool<K, V>
37where
38 K: Eq + Hash + Clone + fmt::Debug + fmt::Display,
39{
40 pub(crate) fn new<F>(config: ClientPoolConfig, factory: F) -> Self
41 where
42 F: ClientFactory<K, V> + Send + Sync + 'static,
43 {
44 Self {
45 clients: Mutex::new(HashMap::new()),
46 factory: Box::new(factory),
47 ids: AtomicU64::new(0),
48 config,
49 }
50 }
51
52 pub(crate) async fn get(&self, key: &K) -> Result<PooledClient<K, V>, MtopError> {
53 let client = {
57 let mut clients = self.clients.lock().await;
58 clients.get_mut(key).and_then(|v| v.pop())
59 };
60
61 match client {
62 Some(c) => {
63 tracing::trace!(message = "using existing client", pool = self.config.name, server = %key, id = c.id);
64 Ok(c)
65 }
66 None => {
67 tracing::trace!(message = "creating new client", pool = self.config.name, server = %key);
68 let inner = self.factory.make(key).await?;
69 Ok(PooledClient {
70 id: self.ids.fetch_add(1, Ordering::Relaxed),
71 key: key.clone(),
72 inner,
73 })
74 }
75 }
76 }
77
78 pub(crate) async fn put(&self, client: PooledClient<K, V>) {
79 let mut clients = self.clients.lock().await;
80 let entries = clients.entry(client.key.clone()).or_default();
81 if (entries.len() as u64) < self.config.max_idle {
82 entries.push(client);
83 }
84 }
85}
86
87impl<K, V> fmt::Debug for ClientPool<K, V>
88where
89 K: Eq + Hash + Clone + fmt::Debug + fmt::Display,
90{
91 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92 f.debug_struct("ClientPool")
93 .field("config", &self.config)
94 .field("clients", &"...")
95 .field("factory", &"...")
96 .finish()
97 }
98}
99
/// A client handed out by a `ClientPool`.
///
/// It dereferences to the wrapped client and remembers the key it was created
/// for, so `ClientPool::put` can return it to the matching slot.
#[derive(Debug)]
pub struct PooledClient<K, V> {
    /// Id taken from the pool's counter when the client was created.
    id: u64,
    /// Key the client was created for.
    key: K,
    /// The wrapped client.
    inner: V,
}
108
109impl<K, V> Deref for PooledClient<K, V> {
110 type Target = V;
111
112 fn deref(&self) -> &Self::Target {
113 &self.inner
114 }
115}
116
117impl<K, V> DerefMut for PooledClient<K, V> {
118 fn deref_mut(&mut self) -> &mut Self::Target {
119 &mut self.inner
120 }
121}
122
123#[cfg(test)]
124mod test {
125 use super::{ClientFactory, ClientPool, ClientPoolConfig};
126 use crate::core::MtopError;
127 use async_trait::async_trait;
128 use std::sync::Arc;
129 use std::sync::atomic::{AtomicU64, Ordering};
130
131 struct CountingClient {
132 dropped: Arc<AtomicU64>,
133 }
134
135 impl Drop for CountingClient {
136 fn drop(&mut self) {
137 self.dropped.fetch_add(1, Ordering::Release);
138 }
139 }
140
141 struct CountingClientFactory {
142 created: Arc<AtomicU64>,
143 dropped: Arc<AtomicU64>,
144 }
145
146 #[async_trait]
147 impl ClientFactory<String, CountingClient> for CountingClientFactory {
148 async fn make(&self, _key: &String) -> Result<CountingClient, MtopError> {
149 self.created.fetch_add(1, Ordering::Release);
150
151 Ok(CountingClient {
152 dropped: self.dropped.clone(),
153 })
154 }
155 }
156
157 fn new_pool(created: Arc<AtomicU64>, dropped: Arc<AtomicU64>) -> ClientPool<String, CountingClient> {
158 let factory = CountingClientFactory {
159 created: created.clone(),
160 dropped: dropped.clone(),
161 };
162 let pool_cfg = ClientPoolConfig {
163 name: "test".to_owned(),
164 max_idle: 1,
165 };
166
167 ClientPool::new(pool_cfg, factory)
168 }
169
170 #[tokio::test]
171 async fn test_client_pool_get_empty_pool() {
172 let created = Arc::new(AtomicU64::new(0));
173 let dropped = Arc::new(AtomicU64::new(0));
174 let pool = new_pool(created.clone(), dropped.clone());
175
176 let _client = pool.get(&"whatever".to_owned()).await.unwrap();
177
178 assert_eq!(1, created.load(Ordering::Acquire));
179 assert_eq!(0, dropped.load(Ordering::Acquire));
180 }
181
182 #[tokio::test]
183 async fn test_client_pool_get_existing_client() {
184 let created = Arc::new(AtomicU64::new(0));
185 let dropped = Arc::new(AtomicU64::new(0));
186 let pool = new_pool(created.clone(), dropped.clone());
187
188 let client1 = pool.get(&"whatever".to_owned()).await.unwrap();
189 pool.put(client1).await;
190 let _client2 = pool.get(&"whatever".to_owned()).await.unwrap();
191
192 assert_eq!(1, created.load(Ordering::Acquire));
193 assert_eq!(0, dropped.load(Ordering::Acquire));
194 }
195
196 #[tokio::test]
197 async fn test_client_pool_put_at_max_idle() {
198 let created = Arc::new(AtomicU64::new(0));
199 let dropped = Arc::new(AtomicU64::new(0));
200 let pool = new_pool(created.clone(), dropped.clone());
201
202 let client1 = pool.get(&"whatever".to_owned()).await.unwrap();
203 let client2 = pool.get(&"whatever".to_owned()).await.unwrap();
204 pool.put(client1).await;
205 pool.put(client2).await;
206
207 assert_eq!(2, created.load(Ordering::Acquire));
208 assert_eq!(1, dropped.load(Ordering::Acquire));
209 }
210
211 #[tokio::test]
212 async fn test_client_pool_put_zero_max_idle() {
213 let created = Arc::new(AtomicU64::new(0));
214 let dropped = Arc::new(AtomicU64::new(0));
215 let mut pool = new_pool(created.clone(), dropped.clone());
216 pool.config.max_idle = 0;
217
218 let client = pool.get(&"whatever".to_owned()).await.unwrap();
219 pool.put(client).await;
220
221 assert_eq!(1, created.load(Ordering::Acquire));
222 assert_eq!(1, dropped.load(Ordering::Acquire));
223 }
224}