use std::num::NonZeroU64;
use std::time::Duration;
use bb8::ManageConnection;
use tokio::time::timeout;
use crate::prelude::*;
use crate::settings::Settings;
use crate::{Client, ClientBuilder, ClientOptions, ConnectionStatus, Destination, Error, Result};
/// [`ConnectionPoolBuilder`] for clients using [`NativeFormat`].
pub type NativeConnectionPoolBuilder = ConnectionPoolBuilder<NativeFormat>;
/// [`ConnectionPoolBuilder`] for clients using [`ArrowFormat`].
pub type ArrowConnectionPoolBuilder = ConnectionPoolBuilder<ArrowFormat>;

/// [`ConnectionManager`] producing [`NativeFormat`] clients.
pub type NativeConnectionManager = ConnectionManager<NativeFormat>;
/// [`ConnectionManager`] producing [`ArrowFormat`] clients.
pub type ArrowConnectionManager = ConnectionManager<ArrowFormat>;

/// Raw `bb8` pool builder whose connections come from a [`ConnectionManager<T>`].
pub type PoolBuilder<T> = bb8::Builder<ConnectionManager<T>>;
/// [`PoolBuilder`] specialised to [`NativeFormat`].
pub type NativePoolBuilder = PoolBuilder<NativeFormat>;
/// [`PoolBuilder`] specialised to [`ArrowFormat`].
pub type ArrowPoolBuilder = PoolBuilder<ArrowFormat>;

/// `bb8` pool whose connections are `Client<T>` values handed out by a [`ConnectionManager<T>`].
pub type ConnectionPool<T> = bb8::Pool<ConnectionManager<T>>;
/// Builder for a [`ConnectionPool`].
///
/// Combines a [`ClientBuilder`], which describes how each pooled connection is opened, with a
/// `bb8` [`PoolBuilder`] that holds the pool-level settings.
pub struct ConnectionPoolBuilder<T: ClientFormat> {
    /// Client configuration used for every connection the pool opens.
    client_builder: ClientBuilder,
    /// Pool-level settings handed to `bb8` when the pool is built.
    pool: PoolBuilder<T>,
    /// Forwarded to [`ConnectionManager::with_check`] when the manager is created.
    check_health: bool,
}

impl<T: ClientFormat> ConnectionPoolBuilder<T> {
    /// Creates a builder that targets `destination`.
    ///
    /// Pool settings start at the `bb8` defaults and health checks start disabled.
    pub fn new<A: Into<Destination>>(destination: A) -> Self {
        Self::with_client_builder(ClientBuilder::new().with_destination(destination))
    }

    /// Creates a builder from an already-configured [`ClientBuilder`].
    ///
    /// Pool settings start at the `bb8` defaults and health checks start disabled.
    pub fn with_client_builder(client_builder: ClientBuilder) -> Self {
        Self { client_builder, pool: bb8::Builder::new(), check_health: false }
    }

    /// Connection identifier reported by the underlying [`ClientBuilder`].
    pub fn connection_identifier(&self) -> String {
        self.client_builder.connection_identifier()
    }

    /// Client options currently configured on the underlying [`ClientBuilder`].
    pub fn client_options(&self) -> &ClientOptions {
        self.client_builder.options()
    }

    /// Settings currently configured on the underlying [`ClientBuilder`], if any.
    pub fn client_settings(&self) -> Option<&Settings> {
        self.client_builder.settings()
    }

    /// Turns on health checking for connections produced by the resulting manager.
    #[must_use]
    pub fn with_check(self) -> Self {
        Self { check_health: true, ..self }
    }

    /// Replaces the client builder with whatever `f` returns when given the current one.
    #[must_use]
    pub fn configure_client<F>(self, f: F) -> Self
    where
        F: FnOnce(ClientBuilder) -> ClientBuilder,
    {
        Self { client_builder: f(self.client_builder), ..self }
    }

    /// Replaces the `bb8` pool builder with whatever `f` returns when given the current one.
    #[must_use]
    pub fn configure_pool<F>(self, f: F) -> Self
    where
        F: FnOnce(PoolBuilder<T>) -> PoolBuilder<T>,
    {
        Self { pool: f(self.pool), ..self }
    }

    /// Builds a standalone [`ConnectionManager`] from a clone of the current client builder.
    ///
    /// # Errors
    /// Propagates any error from [`ConnectionManager::try_new_with_builder`].
    pub async fn build_manager(&self) -> Result<ConnectionManager<T>> {
        Self::make_manager(self.client_builder.clone(), self.check_health).await
    }

    /// Consumes the builder and produces the connection pool.
    ///
    /// # Errors
    /// Propagates errors from creating the manager and from `bb8` while building the pool.
    pub async fn build(self) -> Result<ConnectionPool<T>> {
        let Self { client_builder, pool, check_health } = self;
        let manager = Self::make_manager(client_builder, check_health).await?;
        pool.build(manager).await
    }

    /// Verifies `client_builder` via [`ConnectionManager::try_new_with_builder`].
    ///
    /// Then applies the health-check flag to the resulting manager.
    async fn make_manager(
        client_builder: ClientBuilder,
        check_health: bool,
    ) -> Result<ConnectionManager<T>> {
        let manager = ConnectionManager::try_new_with_builder(client_builder).await?;
        Ok(manager.with_check(check_health))
    }
}
#[derive(Clone)]
pub struct ConnectionManager<T: ClientFormat> {
builder: ClientBuilder,
check_health: bool,
_phantom: std::marker::PhantomData<Client<T>>,
}
impl<T: ClientFormat> ConnectionManager<T> {
#[instrument(
level = "trace",
name = "clickhouse.pool.try_new",
fields(db.system = "clickhouse"),
skip_all
)]
pub async fn try_new<A: Into<Destination>, S: Into<Settings>>(
destination: A,
options: ClientOptions,
settings: Option<S>,
span: Option<NonZeroU64>,
) -> Result<Self> {
let builder = ClientBuilder::new()
.with_options(options)
.with_destination(destination)
.with_trace_context(TraceContext::from(span))
.with_settings(settings.map(Into::into).unwrap_or_default());
Self::try_new_with_builder(builder).await
}
#[instrument(
level = "trace",
name = "clickhouse.pool.try_new_with_builder",
fields(db.system = "clickhouse"),
skip_all
)]
pub async fn try_new_with_builder(builder: ClientBuilder) -> Result<Self> {
let builder = builder.verify().await?;
Ok(Self { builder, check_health: false, _phantom: std::marker::PhantomData })
}
#[must_use]
pub fn with_check(mut self, check: bool) -> Self {
self.check_health = check;
self
}
#[cfg(feature = "cloud")]
#[must_use]
pub fn with_cloud_track(
mut self,
track: std::sync::Arc<std::sync::atomic::AtomicBool>,
) -> Self {
self.builder = self.builder.with_cloud_track(track);
self
}
pub fn connection_identifier(&self) -> String { self.builder.connection_identifier() }
async fn connect(&self) -> Result<Client<T>> { self.builder.clone().build().await }
}
impl<T: ClientFormat> ManageConnection for ConnectionManager<T> {
type Connection = Client<T>;
type Error = Error;
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
debug!("Connecting to ClickHouse...");
self.connect()
.await
.inspect(|c| trace!({ { ATT_CID } = c.client_id }, "Connection established"))
.inspect_err(|error| error!(?error, "Connection failed"))
}
async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
match conn.status() {
ConnectionStatus::Error => {
error!("Connection validation failed: Error");
Err(Error::ConnectionGone("Connection in error state"))
}
ConnectionStatus::Closed => {
warn!("Connection validation failed: Closed");
Err(Error::ConnectionGone("Connection in closed state"))
}
ConnectionStatus::Open => {
let id = conn.client_id;
let timeout_duration = Duration::from_secs(2);
return match timeout(timeout_duration, conn.health_check(self.check_health)).await {
Ok(Ok(())) => Ok(()),
Ok(Err(error)) => {
warn!(?error, { ATT_CID } = id, "Health check failed");
Err(error)
}
Err(_) => Err(Error::ConnectionTimeout("Health check timed out".into())),
};
}
}
}
fn has_broken(&self, conn: &mut Self::Connection) -> bool {
matches!(conn.status(), ConnectionStatus::Error | ConnectionStatus::Closed)
}
}
/// Exponential backoff schedule for spacing out retry attempts.
///
/// The `n`-th call to [`ExponentialBackoff::next_backoff`] yields
/// `current_interval * factor^(n - 1)`, capped at `max_interval`. When `max_elapsed_time` is
/// set, the schedule gives up (returns `None`) once the sum of all delays handed out would
/// exceed it.
#[derive(Debug, Clone, Copy)]
pub struct ExponentialBackoff {
    /// Base delay returned for the first attempt; later attempts scale it by `factor`.
    /// (Never mutated despite the name.)
    current_interval: Duration,
    /// Per-attempt growth multiplier.
    factor: f64,
    /// Upper bound on any single delay.
    max_interval: Duration,
    /// Total delay budget across all attempts; `None` means never give up.
    max_elapsed_time: Option<Duration>,
    /// Number of times `next_backoff` has been called.
    attempts: u32,
    /// Sum of every delay returned so far, checked against `max_elapsed_time`.
    elapsed: Duration,
}

impl ExponentialBackoff {
    /// Creates the default schedule: 10 ms initial delay, doubling each attempt, at most 60 s
    /// per delay, and giving up once 900 s of total delay have been handed out.
    #[must_use]
    pub fn new() -> Self {
        ExponentialBackoff {
            current_interval: Duration::from_millis(10),
            factor: 2.0,
            max_interval: Duration::from_secs(60),
            max_elapsed_time: Some(Duration::from_secs(900)),
            attempts: 0,
            elapsed: Duration::ZERO,
        }
    }

    /// Returns the delay to wait before the next attempt.
    ///
    /// Returns `None` once returning another delay would push the cumulative delay past
    /// `max_elapsed_time`. Never panics, even after arbitrarily many calls.
    pub fn next_backoff(&mut self) -> Option<Duration> {
        self.attempts = self.attempts.saturating_add(1);
        let interval = self.interval_for_attempt(self.attempts);
        // Budget against the real sum of delays, not `base * attempts` (which is linear and
        // would let the schedule run for ~90 000 attempts with the default settings).
        let elapsed = self.elapsed.saturating_add(interval);
        if self.max_elapsed_time.is_some_and(|max_time| elapsed > max_time) {
            return None;
        }
        self.elapsed = elapsed;
        Some(interval)
    }

    /// Delay for the 1-based `attempt`, capped at `max_interval`.
    ///
    /// `Duration::mul_f64` panics once `base * factor^(n-1)` overflows `Duration` (around
    /// attempt 72 with the defaults). Any unrepresentable result therefore saturates to
    /// `max_interval` instead.
    fn interval_for_attempt(&self, attempt: u32) -> Duration {
        let exponent = i32::try_from(attempt.saturating_sub(1)).unwrap_or(i32::MAX);
        let secs = self.current_interval.as_secs_f64() * self.factor.powi(exponent);
        Duration::try_from_secs_f64(secs)
            .map_or(self.max_interval, |interval| interval.min(self.max_interval))
    }
}
impl Default for ExponentialBackoff {
fn default() -> Self { Self::new() }
}