1use std::num::NonZeroU64;
2use std::time::Duration;
3
4use bb8::ManageConnection;
5use tokio::time::timeout;
6
7use crate::prelude::*;
8use crate::settings::Settings;
9use crate::{Client, ClientBuilder, ClientOptions, ConnectionStatus, Destination, Error, Result};
10
11pub type NativeConnectionPoolBuilder = ConnectionPoolBuilder<NativeFormat>;
13pub type ArrowConnectionPoolBuilder = ConnectionPoolBuilder<ArrowFormat>;
15pub type NativeConnectionManager = ConnectionManager<NativeFormat>;
17pub type ArrowConnectionManager = ConnectionManager<ArrowFormat>;
19pub type PoolBuilder<T> = bb8::Builder<ConnectionManager<T>>;
21pub type NativePoolBuilder = bb8::Builder<ConnectionManager<NativeFormat>>;
23pub type ArrowPoolBuilder = bb8::Builder<ConnectionManager<ArrowFormat>>;
25pub type ConnectionPool<T> = bb8::Pool<ConnectionManager<T>>;
27
28pub struct ConnectionPoolBuilder<T: ClientFormat> {
30 client_builder: ClientBuilder,
31 pool: PoolBuilder<T>,
32 check_health: bool,
33}
34
35impl<T: ClientFormat> ConnectionPoolBuilder<T> {
36 pub fn new<A: Into<Destination>>(destination: A) -> Self {
39 let client_builder = ClientBuilder::new().with_destination(destination);
40 Self { pool: bb8::Builder::new(), client_builder, check_health: false }
41 }
42
43 pub fn with_client_builder(client_builder: ClientBuilder) -> Self {
45 Self { pool: bb8::Builder::new(), client_builder, check_health: false }
46 }
47
48 pub fn connection_identifier(&self) -> String { self.client_builder.connection_identifier() }
50
51 pub fn client_options(&self) -> &ClientOptions { self.client_builder.options() }
53
54 pub fn client_settings(&self) -> Option<&Settings> { self.client_builder.settings() }
56
57 #[must_use]
59 pub fn with_check(mut self) -> Self {
60 self.check_health = true;
61 self
62 }
63
64 #[must_use]
66 pub fn configure_client<F>(mut self, f: F) -> Self
67 where
68 F: FnOnce(ClientBuilder) -> ClientBuilder,
69 {
70 self.client_builder = f(self.client_builder);
71 self
72 }
73
74 #[must_use]
76 pub fn configure_pool<F>(mut self, f: F) -> Self
77 where
78 F: FnOnce(PoolBuilder<T>) -> PoolBuilder<T>,
79 {
80 self.pool = f(self.pool);
81 self
82 }
83
84 pub async fn build_manager(&self) -> Result<ConnectionManager<T>> {
89 Ok(ConnectionManager::try_new_with_builder(self.client_builder.clone())
90 .await?
91 .with_check(self.check_health))
92 }
93
94 pub async fn build(self) -> Result<ConnectionPool<T>> {
100 let manager = ConnectionManager::try_new_with_builder(self.client_builder)
101 .await?
102 .with_check(self.check_health);
103 self.pool.build(manager).await
104 }
105}
106
107#[derive(Clone)]
109pub struct ConnectionManager<T: ClientFormat> {
110 builder: ClientBuilder,
111 check_health: bool,
112 _phantom: std::marker::PhantomData<Client<T>>,
113}
114
115impl<T: ClientFormat> ConnectionManager<T> {
116 #[instrument(
131 level = "trace",
132 name = "clickhouse.pool.try_new",
133 fields(db.system = "clickhouse"),
134 skip_all
135 )]
136 pub async fn try_new<A: Into<Destination>, S: Into<Settings>>(
137 destination: A,
138 options: ClientOptions,
139 settings: Option<S>,
140 span: Option<NonZeroU64>,
141 ) -> Result<Self> {
142 let builder = ClientBuilder::new()
143 .with_options(options)
144 .with_destination(destination)
145 .with_trace_context(TraceContext::from(span))
146 .with_settings(settings.map(Into::into).unwrap_or_default());
147 Self::try_new_with_builder(builder).await
148 }
149
150 #[instrument(
162 level = "trace",
163 name = "clickhouse.pool.try_new_with_builder",
164 fields(db.system = "clickhouse"),
165 skip_all
166 )]
167 pub async fn try_new_with_builder(builder: ClientBuilder) -> Result<Self> {
168 let builder = builder.verify().await?;
170 Ok(Self { builder, check_health: false, _phantom: std::marker::PhantomData })
171 }
172
173 #[must_use]
175 pub fn with_check(mut self, check: bool) -> Self {
176 self.check_health = check;
177 self
178 }
179
180 #[cfg(feature = "cloud")]
183 #[must_use]
184 pub fn with_cloud_track(
185 mut self,
186 track: std::sync::Arc<std::sync::atomic::AtomicBool>,
187 ) -> Self {
188 self.builder = self.builder.with_cloud_track(track);
189 self
190 }
191
192 pub fn connection_identifier(&self) -> String { self.builder.connection_identifier() }
194
195 async fn connect(&self) -> Result<Client<T>> { self.builder.clone().build().await }
196}
197
198impl<T: ClientFormat> ManageConnection for ConnectionManager<T> {
199 type Connection = Client<T>;
200 type Error = Error;
201
202 async fn connect(&self) -> Result<Self::Connection, Self::Error> {
203 debug!("Connecting to ClickHouse...");
204 self.connect()
205 .await
206 .inspect(|c| trace!({ { ATT_CID } = c.client_id }, "Connection established"))
207 .inspect_err(|error| error!(?error, "Connection failed"))
208 }
209
210 async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
211 match conn.status() {
212 ConnectionStatus::Error => {
213 error!("Connection validation failed: Error");
214 Err(Error::ConnectionGone("Connection in error state"))
215 }
216 ConnectionStatus::Closed => {
217 warn!("Connection validation failed: Closed");
218 Err(Error::ConnectionGone("Connection in closed state"))
219 }
220 ConnectionStatus::Open => {
221 let id = conn.client_id;
222 let timeout_duration = Duration::from_secs(2);
223 return match timeout(timeout_duration, conn.health_check(self.check_health)).await {
227 Ok(Ok(())) => Ok(()),
228 Ok(Err(error)) => {
229 warn!(?error, { ATT_CID } = id, "Health check failed");
230 Err(error)
231 }
232 Err(_) => Err(Error::ConnectionTimeout("Health check timed out".into())),
233 };
234 }
235 }
236 }
237
238 fn has_broken(&self, conn: &mut Self::Connection) -> bool {
239 matches!(conn.status(), ConnectionStatus::Error | ConnectionStatus::Closed)
240 }
241}
242
/// Exponential backoff schedule: each delay is the base interval times `factor^(n-1)`
/// (for the n-th call), capped per delay by `max_interval`, with an optional budget on
/// the total of all delays handed out.
#[derive(Debug, Clone, Copy)]
pub struct ExponentialBackoff {
    /// Base delay returned by the first call. It is never mutated; later delays are
    /// derived from it by scaling with `factor`.
    current_interval: Duration,
    /// Multiplier applied once per additional attempt.
    factor: f64,
    /// Upper bound on any single returned delay.
    max_interval: Duration,
    /// Budget for the sum of all returned delays; `None` means unlimited.
    max_elapsed_time: Option<Duration>,
    /// Number of `next_backoff` calls made so far (saturating).
    attempts: u32,
    /// Sum of all delays returned so far.
    elapsed: Duration,
}

impl ExponentialBackoff {
    /// Creates a schedule starting at 10 ms and doubling per attempt. Each delay is
    /// capped at 60 s, and the total delay budget is 900 s.
    pub fn new() -> Self {
        ExponentialBackoff {
            current_interval: Duration::from_millis(10),
            factor: 2.0,
            max_interval: Duration::from_secs(60),
            max_elapsed_time: Some(Duration::from_secs(900)),
            attempts: 0,
            elapsed: Duration::ZERO,
        }
    }

    /// Returns the delay to wait before the next attempt.
    ///
    /// Returns `None` once handing out the next delay would push the total of all
    /// returned delays past `max_elapsed_time`. After that it keeps returning `None`.
    /// Never panics: a scaled delay that overflows `Duration` or is not finite is
    /// replaced by `max_interval`.
    pub fn next_backoff(&mut self) -> Option<Duration> {
        self.attempts = self.attempts.saturating_add(1);

        // `attempts >= 1` here. Huge exponents saturate, and the result is capped anyway.
        let exponent = i32::try_from(self.attempts - 1).unwrap_or(i32::MAX);
        let scaled_secs = self.current_interval.as_secs_f64() * self.factor.powi(exponent);
        // `Duration::mul_f64` panics on overflow (10 ms doubled ~72 times) or on
        // non-finite input. Fall back to the cap instead.
        let next = Duration::try_from_secs_f64(scaled_secs)
            .map_or(self.max_interval, |interval| interval.min(self.max_interval));

        let total = self.elapsed.saturating_add(next);
        if self.max_elapsed_time.is_some_and(|max_time| total > max_time) {
            return None;
        }

        self.elapsed = total;
        Some(next)
    }
}

impl Default for ExponentialBackoff {
    fn default() -> Self {
        Self::new()
    }
}