1use clickhouse::Client;
2use std::fmt::{Debug, Formatter, Result as FmtResult};
3use std::sync::{Arc, Mutex};
4use tokio::sync::{OwnedSemaphorePermit, Semaphore};
5mod error;
6use crate::error::Error;
7
8pub struct ConnectionPool {
9 clients: Arc<Mutex<Vec<Client>>>,
10 semaphore: Arc<Semaphore>,
11}
12
13impl ConnectionPool {
14 pub async fn spawn(params: impl Into<String>, count: usize) -> Result<Self, Error> {
21 let params = params.into();
22 let mut clients = Vec::with_capacity(count);
23
24 for _ in 0..count {
25 let client = connect(params.clone()).await?;
26 clients.push(client);
27 }
28
29 Ok(ConnectionPool {
30 clients: Arc::new(Mutex::new(clients)),
31 semaphore: Arc::new(Semaphore::new(count)),
32 })
33 }
34
35 pub async fn acquire(&self) -> Result<ClientWrapper, Error> {
44 let permit = self.semaphore.clone().acquire_owned().await?;
45
46 let client = {
47 let mut clients = self.clients.lock().unwrap();
48 clients.pop()
49 };
50
51 if let Some(client) = client {
52 Ok(ClientWrapper {
53 client: Some(client),
54 pool: self.clone(),
55 _permit: permit,
56 })
57 } else {
58 drop(permit);
60 Err(Error::Unknown)
61 }
62 }
63}
64
65impl Clone for ConnectionPool {
66 fn clone(&self) -> Self {
67 ConnectionPool {
68 clients: Arc::clone(&self.clients),
69 semaphore: Arc::clone(&self.semaphore),
70 }
71 }
72}
73
74pub struct ClientWrapper {
76 client: Option<Client>,
77 pool: ConnectionPool,
78 _permit: OwnedSemaphorePermit,
79}
80
81impl ClientWrapper {
82 pub fn client(&self) -> &Client {
84 self.client.as_ref().unwrap()
85 }
86
87 pub fn client_mut(&mut self) -> &mut Client {
89 self.client.as_mut().unwrap()
90 }
91}
92
93impl Drop for ClientWrapper {
94 fn drop(&mut self) {
95 if let Some(client) = self.client.take() {
96 let mut clients = self.pool.clients.lock().unwrap();
97 clients.push(client);
98 }
99 }
100}
101
102impl Debug for ConnectionPool {
103 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
104 f.write_str("ConnectionPool { ... }")
105 }
106}
107
108async fn connect(params: impl Into<String>) -> Result<Client, Error> {
109 let client = Client::default().with_url(params);
110
111 Ok(client)
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use clickhouse::test;
    use futures::future::join_all;
    use futures::FutureExt;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use once_cell::sync::Lazy;

    // Only `MOCK.url()` is used: `spawn` configures clients but never contacts
    // the server, so no handlers need to be registered.
    static MOCK: Lazy<test::Mock> = Lazy::new(test::Mock::new);

    /// An exhausted pool must make `acquire` wait, and must admit a waiter once
    /// a client is returned.
    #[tokio::test]
    async fn test_pool_limits() {
        let pool_size = 2;

        let pool = ConnectionPool::spawn(MOCK.url(), pool_size)
            .await
            .expect("Failed to spawn pool");

        let client1 = pool.acquire().await.expect("Failed to acquire client 1");
        let client2 = pool.acquire().await.expect("Failed to acquire client 2");

        // With every permit handed out, a further acquire must not be ready.
        assert!(
            pool.acquire().now_or_never().is_none(),
            "acquire completed while the pool was exhausted"
        );

        let pool_clone = pool.clone();
        let acquire_future = tokio::spawn(async move {
            pool_clone
                .acquire()
                .await
                .expect("Failed to acquire client 3")
        });

        drop(client1);

        let client3 = acquire_future.await.expect("Failed to await client 3");

        drop(client2);
        drop(client3);

        // Every client and permit must be back once all wrappers are gone.
        assert_eq!(pool.semaphore.available_permits(), pool_size);
        assert_eq!(pool.clients.lock().unwrap().len(), pool_size);
    }

    /// Many tasks contending for a smaller pool all succeed, never hold more
    /// than `pool_size` clients at once, and return everything afterwards.
    #[tokio::test]
    async fn test_concurrent_acquisitions() {
        let pool_size = 5;
        let task_count = 10;

        let pool = ConnectionPool::spawn(MOCK.url(), pool_size)
            .await
            .expect("Failed to spawn pool");

        // Clients currently checked out, and the highest count observed.
        let in_use = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let tasks: Vec<_> = (0..task_count)
            .map(|_| {
                let pool = pool.clone();
                let in_use = Arc::clone(&in_use);
                let peak = Arc::clone(&peak);
                tokio::spawn(async move {
                    let client_wrapper = pool.acquire().await.expect("Failed to acquire client");

                    let now = in_use.fetch_add(1, Ordering::SeqCst) + 1;
                    peak.fetch_max(now, Ordering::SeqCst);
                    // Yield while holding the client so other tasks contend for the pool.
                    tokio::task::yield_now().await;
                    in_use.fetch_sub(1, Ordering::SeqCst);

                    drop(client_wrapper);
                })
            })
            .collect();

        // Surface panics from spawned tasks; otherwise a failed assertion inside
        // a task would be swallowed and the test would pass.
        for outcome in join_all(tasks).await {
            outcome.expect("acquisition task panicked");
        }

        assert!(
            peak.load(Ordering::SeqCst) <= pool_size,
            "more clients were checked out than the pool holds"
        );
        assert_eq!(in_use.load(Ordering::SeqCst), 0);
        assert_eq!(pool.semaphore.available_permits(), pool_size);
        assert_eq!(pool.clients.lock().unwrap().len(), pool_size);
    }
}