// clickhouse_pool/lib.rs

1use clickhouse::Client;
2use std::fmt::{Debug, Formatter, Result as FmtResult};
3use std::sync::{Arc, Mutex};
4use tokio::sync::{OwnedSemaphorePermit, Semaphore};
5mod error;
6use crate::error::Error;
7
8pub struct ConnectionPool {
9    clients: Arc<Mutex<Vec<Client>>>,
10    semaphore: Arc<Semaphore>,
11}
12
13impl ConnectionPool {
14    /// Spawns a new connection pool, given the address to the ClickHouse server,
15    /// and the maximum number of connections that the pool can spawn.
16    ///
17    /// # Errors
18    ///
19    /// Returns an error if any of the connections failed to instantiate.
20    pub async fn spawn(params: impl Into<String>, count: usize) -> Result<Self, Error> {
21        let params = params.into();
22        let mut clients = Vec::with_capacity(count);
23
24        for _ in 0..count {
25            let client = connect(params.clone()).await?;
26            clients.push(client);
27        }
28
29        Ok(ConnectionPool {
30            clients: Arc::new(Mutex::new(clients)),
31            semaphore: Arc::new(Semaphore::new(count)),
32        })
33    }
34
35    /// Acquires a `Client` from the pool.
36    ///
37    /// Returns a `ClientWrapper` which will automatically return the client
38    /// to the pool when dropped.
39    ///
40    /// # Errors
41    ///
42    /// Returns an error if the semaphore is closed or no clients are available.
43    pub async fn acquire(&self) -> Result<ClientWrapper, Error> {
44        let permit = self.semaphore.clone().acquire_owned().await?;
45
46        let client = {
47            let mut clients = self.clients.lock().unwrap();
48            clients.pop()
49        };
50
51        if let Some(client) = client {
52            Ok(ClientWrapper {
53                client: Some(client),
54                pool: self.clone(),
55                _permit: permit,
56            })
57        } else {
58            // This should not happen because the semaphore ensures that clients are available
59            drop(permit);
60            Err(Error::Unknown)
61        }
62    }
63}
64
65impl Clone for ConnectionPool {
66    fn clone(&self) -> Self {
67        ConnectionPool {
68            clients: Arc::clone(&self.clients),
69            semaphore: Arc::clone(&self.semaphore),
70        }
71    }
72}
73
74/// A wrapper around `Client` that returns it to the pool when dropped.
75pub struct ClientWrapper {
76    client: Option<Client>,
77    pool: ConnectionPool,
78    _permit: OwnedSemaphorePermit,
79}
80
81impl ClientWrapper {
82    /// Accesses the `Client`.
83    pub fn client(&self) -> &Client {
84        self.client.as_ref().unwrap()
85    }
86
87    /// Mutably accesses the `Client`.
88    pub fn client_mut(&mut self) -> &mut Client {
89        self.client.as_mut().unwrap()
90    }
91}
92
93impl Drop for ClientWrapper {
94    fn drop(&mut self) {
95        if let Some(client) = self.client.take() {
96            let mut clients = self.pool.clients.lock().unwrap();
97            clients.push(client);
98        }
99    }
100}
101
102impl Debug for ConnectionPool {
103    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
104        f.write_str("ConnectionPool { ... }")
105    }
106}
107
108async fn connect(params: impl Into<String>) -> Result<Client, Error> {
109    let client = Client::default().with_url(params);
110
111    Ok(client)
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use clickhouse::test;
    use futures::future::join_all;
    use once_cell::sync::Lazy;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Only the mock's URL is used; no request is ever sent to it.
    static MOCK: Lazy<test::Mock> = Lazy::new(|| test::Mock::new());

    /// Number of clients currently idle in the pool's list.
    fn idle_clients(pool: &ConnectionPool) -> usize {
        pool.clients.lock().unwrap().len()
    }

    #[tokio::test]
    async fn test_pool_limits() {
        let pool_size = 2;

        let pool = ConnectionPool::spawn(MOCK.url(), pool_size)
            .await
            .expect("Failed to spawn pool");

        let client1 = pool.acquire().await.expect("Failed to acquire client 1");
        let client2 = pool.acquire().await.expect("Failed to acquire client 2");

        // Both clients are checked out, so the pool is exhausted.
        assert_eq!(pool.semaphore.available_permits(), 0);
        assert_eq!(idle_clients(&pool), 0);

        let pool_clone = pool.clone();
        let acquire_future = tokio::spawn(async move {
            pool_clone
                .acquire()
                .await
                .expect("Failed to acquire client 3")
        });

        // Returning client 1 is what lets the spawned acquisition complete.
        drop(client1);

        let client3 = acquire_future.await.expect("Failed to await client 3");

        drop(client2);
        drop(client3);

        // Every client and permit is back once all wrappers are dropped.
        assert_eq!(pool.semaphore.available_permits(), pool_size);
        assert_eq!(idle_clients(&pool), pool_size);
    }

    #[tokio::test]
    async fn test_concurrent_acquisitions() {
        let pool_size = 5;
        let task_count = 10;

        let pool = ConnectionPool::spawn(MOCK.url(), pool_size)
            .await
            .expect("Failed to spawn pool");

        // How many wrappers are alive right now, and the highest value seen.
        let in_use = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let tasks: Vec<_> = (0..task_count)
            .map(|_| {
                let pool = pool.clone();
                let in_use = Arc::clone(&in_use);
                let peak = Arc::clone(&peak);
                tokio::spawn(async move {
                    let client_wrapper = pool.acquire().await.expect("Failed to acquire client");

                    let now = in_use.fetch_add(1, Ordering::SeqCst) + 1;
                    peak.fetch_max(now, Ordering::SeqCst);

                    // Yield while holding the client so other tasks contend
                    // for the remaining permits.
                    tokio::task::yield_now().await;

                    in_use.fetch_sub(1, Ordering::SeqCst);
                    drop(client_wrapper);
                })
            })
            .collect();

        // Check every task's outcome so a panic inside a task fails the test
        // instead of being silently discarded.
        for outcome in join_all(tasks).await {
            outcome.expect("Acquisition task panicked");
        }

        assert!(peak.load(Ordering::SeqCst) <= pool_size);
        assert_eq!(pool.semaphore.available_permits(), pool_size);
        assert_eq!(idle_clients(&pool), pool_size);
    }
}