use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use parking_lot::{Mutex, RwLock};
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tracing::{debug, info, warn};
use super::connection::{BrokerConnection, ConnectionConfig};
use crate::BrokerId;
use crate::error::{KrafkaError, Result};
use crate::metrics::ConnectionMetrics;
/// Retry policy used when establishing or re-establishing a broker connection.
///
/// Obtain one through [`Default`] or by filling in the fields explicitly.
#[derive(Debug, Clone)]
pub struct ConnectionRetryConfig {
    /// Number of retries allowed after the initial attempt.
    pub(crate) max_retries: u32,
    /// Delay used for the first retry.
    pub(crate) initial_backoff: Duration,
    /// Upper bound applied to the exponentially grown delay.
    pub(crate) max_backoff: Duration,
    /// Growth factor applied to the delay for each further retry.
    pub(crate) backoff_multiplier: f64,
    /// Fraction of the delay used as the half-width of the random jitter.
    pub(crate) jitter_factor: f64,
}

impl Default for ConnectionRetryConfig {
    /// Three retries, starting at 100 ms and doubling up to 10 s, with 20% jitter.
    fn default() -> Self {
        let initial_backoff = Duration::from_millis(100);
        let max_backoff = Duration::from_secs(10);
        Self {
            jitter_factor: 0.2,
            backoff_multiplier: 2.0,
            max_backoff,
            initial_backoff,
            max_retries: 3,
        }
    }
}
impl ConnectionRetryConfig {
    /// Returns a builder whose starting values are the [`Default`] configuration.
    pub fn builder() -> ConnectionRetryConfigBuilder {
        ConnectionRetryConfigBuilder::default()
    }

    /// Number of retries allowed after the initial attempt.
    #[inline]
    pub fn max_retries(&self) -> u32 {
        self.max_retries
    }

    /// Delay used for the first retry.
    #[inline]
    pub fn initial_backoff(&self) -> Duration {
        self.initial_backoff
    }

    /// Cap applied to the exponentially grown delay (before jitter is added).
    #[inline]
    pub fn max_backoff(&self) -> Duration {
        self.max_backoff
    }

    /// Growth factor applied to the delay for each further retry.
    #[inline]
    pub fn backoff_multiplier(&self) -> f64 {
        self.backoff_multiplier
    }

    /// Fraction of the capped delay used as the half-width of the random jitter.
    #[inline]
    pub fn jitter_factor(&self) -> f64 {
        self.jitter_factor
    }

    /// Computes the delay to sleep before retry number `attempt`.
    ///
    /// `attempt == 0` yields [`Duration::ZERO`]. Otherwise the delay is
    /// `initial_backoff * backoff_multiplier^(attempt - 1)`, capped at `max_backoff`,
    /// plus a uniformly sampled jitter within `±(capped * jitter_factor)`. Because the
    /// jitter is added after the cap, the result can exceed `max_backoff` by up to that
    /// fraction.
    ///
    /// Degenerate field values (set directly rather than through the builder) do not
    /// panic: a negative intermediate is floored at zero, and a result that cannot be
    /// represented as a [`Duration`] falls back to `max_backoff` with a warning.
    #[inline]
    fn calculate_backoff(&self, attempt: u32) -> Duration {
        if attempt == 0 {
            return Duration::ZERO;
        }
        // `powi` takes an `i32`; clamp first so huge attempt counts cannot wrap negative.
        let exponent = attempt.saturating_sub(1).min(i32::MAX as u32) as i32;
        let base_backoff =
            self.initial_backoff.as_secs_f64() * self.backoff_multiplier.powi(exponent);
        // `f64::min` ignores a NaN operand, so an overflowed (inf) or NaN base collapses
        // to the cap.
        let capped_backoff = base_backoff.min(self.max_backoff.as_secs_f64());
        let jitter_range = capped_backoff * self.jitter_factor;
        let jitter = if self.jitter_factor > 0.0 {
            use rand::Rng;
            // Sample a unit factor and scale it instead of sampling `-range..=range`:
            // the latter panics when `range` is negative or too large to span, which
            // hand-built configs (bypassing the builder's validation) can produce.
            jitter_range * rand::rng().random_range(-1.0_f64..=1.0)
        } else {
            0.0
        };
        // `f64::max` discards NaN as well, so the value is now non-negative or +inf.
        let final_backoff = (capped_backoff + jitter).max(0.0);
        // `Duration::from_secs_f64` panics on overflow (e.g. `max_backoff` close to
        // `Duration::MAX`); use the fallible conversion and degrade gracefully instead.
        match Duration::try_from_secs_f64(final_backoff) {
            Ok(backoff) => backoff,
            Err(error) => {
                warn!(
                    attempt,
                    final_backoff,
                    %error,
                    "Backoff calculation produced a value outside the Duration range, falling back to max_backoff"
                );
                self.max_backoff
            }
        }
    }
}
#[must_use = "builders do nothing until .build() is called"]
#[derive(Debug, Default)]
pub struct ConnectionRetryConfigBuilder {
config: ConnectionRetryConfig,
}
impl ConnectionRetryConfigBuilder {
pub fn max_retries(mut self, retries: u32) -> Self {
self.config.max_retries = retries;
self
}
pub fn initial_backoff(mut self, duration: Duration) -> Self {
self.config.initial_backoff = duration;
self
}
pub fn max_backoff(mut self, duration: Duration) -> Self {
self.config.max_backoff = duration;
self
}
pub fn backoff_multiplier(mut self, multiplier: f64) -> Self {
self.config.backoff_multiplier = if multiplier.is_finite() && multiplier > 0.0 {
multiplier
} else {
1.0
};
self
}
pub fn jitter_factor(mut self, factor: f64) -> Self {
self.config.jitter_factor = if factor.is_finite() {
factor.clamp(0.0, 1.0)
} else {
0.0
};
self
}
pub fn build(self) -> ConnectionRetryConfig {
self.config
}
}
pub struct BrokerConnectionBundle {
address: String,
connections: Vec<Arc<BrokerConnection>>,
counter: AtomicUsize,
}
impl BrokerConnectionBundle {
pub async fn connect(address: &str, config: ConnectionConfig) -> Result<Self> {
let num_connections = config.connections_per_broker.max(1);
if num_connections == 1 {
let conn = BrokerConnection::connect(address, config).await?;
return Ok(Self {
address: address.to_string(),
connections: vec![Arc::new(conn)],
counter: AtomicUsize::new(0),
});
}
let addr_owned = address.to_string();
let mut handles = Vec::with_capacity(num_connections);
for _ in 0..num_connections {
let addr = addr_owned.clone();
let cfg = config.clone();
handles.push(tokio::spawn(async move {
BrokerConnection::connect(&addr, cfg).await
}));
}
let mut connections = Vec::with_capacity(num_connections);
for handle in handles {
let conn = handle.await.map_err(|e| {
KrafkaError::invalid_state(format!("Connection task failed: {e}"))
})??;
connections.push(Arc::new(conn));
}
debug!(
"Created connection bundle with {} connections to {}",
connections.len(),
addr_owned
);
Ok(Self {
address: addr_owned,
connections,
counter: AtomicUsize::new(0),
})
}
#[inline]
pub fn select(&self) -> Arc<BrokerConnection> {
if self.connections.len() == 1 {
return self.connections[0].clone();
}
let index = self.counter.fetch_add(1, Ordering::Relaxed) % self.connections.len();
self.connections[index].clone()
}
#[inline]
pub fn get(&self, index: usize) -> Option<Arc<BrokerConnection>> {
self.connections
.get(index % self.connections.len())
.cloned()
}
#[inline]
pub fn first(&self) -> Arc<BrokerConnection> {
self.connections[0].clone()
}
#[inline]
pub fn address(&self) -> &str {
&self.address
}
#[inline]
pub fn len(&self) -> usize {
self.connections.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.connections.is_empty()
}
#[inline]
pub fn all_usable(&self) -> bool {
self.connections.iter().all(|c| c.is_usable())
}
#[inline]
pub fn any_usable(&self) -> bool {
self.connections.iter().any(|c| c.is_usable())
}
#[inline]
pub fn usable_count(&self) -> usize {
self.connections.iter().filter(|c| c.is_usable()).count()
}
pub fn select_usable(&self) -> Option<Arc<BrokerConnection>> {
let len = self.connections.len();
if len == 0 {
return None;
}
let start = self.counter.fetch_add(1, Ordering::Relaxed) % len;
for i in 0..len {
let index = (start + i) % len;
if self.connections[index].is_usable() {
return Some(self.connections[index].clone());
}
}
None
}
pub async fn close_all(&self) {
for conn in &self.connections {
conn.close().await;
}
}
}
/// Default idle threshold for pooled connections: 9 minutes.
///
/// `ConnectionPool::new` and `ConnectionPool::with_retry_config` use this as the
/// initial `max_idle`. The unit tests in this file describe the value as matching the
/// Java client's default.
pub const DEFAULT_MAX_IDLE: Duration = Duration::from_secs(9 * 60);
/// Callers parked on an in-flight (re)connection attempt, grouped by broker address.
type ConnectingWaiters = HashMap<String, Vec<oneshot::Sender<Result<Arc<BrokerConnection>>>>>;

/// Drop guard owned by the task that performs a reconnect.
///
/// If the guard is dropped while still armed, every waiter registered for its address
/// is removed from the shared map and handed a "cancelled" error, so nobody is left
/// waiting on an attempt that no longer exists.
struct ReconnectGuard {
    /// Shared registry of waiters, keyed by address.
    connecting: Arc<Mutex<ConnectingWaiters>>,
    /// Address being reconnected; `None` once the guard has been defused.
    address: Option<String>,
}

impl ReconnectGuard {
    /// Arms a guard for `address` against the shared `connecting` registry.
    fn new(connecting: &Arc<Mutex<ConnectingWaiters>>, address: String) -> Self {
        let connecting = Arc::clone(connecting);
        let address = Some(address);
        Self {
            connecting,
            address,
        }
    }

    /// Disarms the guard; dropping it afterwards does nothing.
    fn defuse(&mut self) {
        self.address = None;
    }
}

impl Drop for ReconnectGuard {
    fn drop(&mut self) {
        if let Some(address) = self.address.take() {
            // The registry stays locked while the waiters are notified.
            let mut registry = self.connecting.lock();
            let waiters = registry.remove(&address).unwrap_or_default();
            let err = KrafkaError::network(std::io::Error::new(
                std::io::ErrorKind::ConnectionReset,
                format!("reconnection to {address} was cancelled"),
            ));
            waiters.into_iter().for_each(|waiter| {
                // A waiter that already gave up has dropped its receiver; ignore that.
                let _ = waiter.send(Err(err.clone()));
            });
        }
    }
}
/// Pool of broker connections, indexed both by broker id and by address.
pub struct ConnectionPool {
// Connections keyed by broker id; filled by `get_connection_by_id`.
connections: RwLock<HashMap<BrokerId, Arc<BrokerConnection>>>,
// Connections keyed by address; filled by `get_or_reconnect` on success.
connections_by_addr: RwLock<HashMap<String, Arc<BrokerConnection>>>,
// Addresses with a reconnect in flight, each with the callers waiting on its result.
connecting: Arc<Mutex<ConnectingWaiters>>,
// Configuration cloned into every `BrokerConnection::connect` call.
config: ConnectionConfig,
// Retry/backoff policy used by `reconnect_with_backoff`.
retry_config: ConnectionRetryConfig,
// Idle threshold used by `evict_idle`; `None` disables eviction.
max_idle: Option<Duration>,
// Handle of the background evictor task, if one has been started.
evictor_handle: Mutex<Option<JoinHandle<()>>>,
}
impl ConnectionPool {
pub fn new(config: ConnectionConfig) -> Self {
Self {
connections: RwLock::new(HashMap::new()),
connections_by_addr: RwLock::new(HashMap::new()),
connecting: Arc::new(Mutex::new(HashMap::new())),
config,
retry_config: ConnectionRetryConfig::default(),
max_idle: Some(DEFAULT_MAX_IDLE),
evictor_handle: Mutex::new(None),
}
}
/// Creates an empty pool that uses `retry_config` for reconnect backoff.
///
/// The idle threshold starts at `DEFAULT_MAX_IDLE`; no background evictor is started.
pub fn with_retry_config(config: ConnectionConfig, retry_config: ConnectionRetryConfig) -> Self {
    let max_idle = Some(DEFAULT_MAX_IDLE);
    Self {
        config,
        retry_config,
        max_idle,
        // Every container starts out empty / unset.
        connections: RwLock::default(),
        connections_by_addr: RwLock::default(),
        connecting: Arc::default(),
        evictor_handle: Mutex::default(),
    }
}
/// Creates a shared pool via `new` and starts its idle evictor.
///
/// Outside a Tokio runtime the evictor is not installed (see `start_idle_evictor`).
pub fn start(config: ConnectionConfig) -> Arc<Self> {
    let shared = Arc::new(Self::new(config));
    Self::start_idle_evictor(&shared);
    shared
}
/// Returns the connection metrics handle provided by the pool's `ConnectionConfig`.
#[inline]
pub fn metrics(&self) -> Arc<ConnectionMetrics> {
self.config.connection_metrics()
}
/// Replaces the idle threshold and returns the pool; `None` disables idle eviction.
///
/// Consumes `self`, so it has to be called before the pool is wrapped in an `Arc`.
#[must_use]
pub fn with_max_idle(mut self, max_idle: Option<Duration>) -> Self {
self.max_idle = max_idle;
self
}
/// Returns the configured idle threshold, or `None` when eviction is disabled.
#[inline]
pub fn max_idle(&self) -> Option<Duration> {
self.max_idle
}
/// Delegates to `refresh_tls` on the pool's `ConnectionConfig` and returns its result.
///
/// NOTE(review): whether already-open connections are affected is not visible from
/// this file — confirm against `ConnectionConfig::refresh_tls`.
pub async fn refresh_tls(&self) -> crate::error::Result<()> {
self.config.refresh_tls().await
}
async fn reconnect_with_backoff(&self, address: &str) -> Result<Arc<BrokerConnection>> {
let mut last_error: Option<KrafkaError> = None;
for attempt in 0..=self.retry_config.max_retries {
if attempt > 0 {
let backoff = self.retry_config.calculate_backoff(attempt);
debug!(
address = %address,
attempt = attempt,
max_retries = self.retry_config.max_retries,
backoff_ms = backoff.as_millis(),
"Retrying connection after backoff"
);
tokio::time::sleep(backoff).await;
}
match BrokerConnection::connect(address, self.config.clone()).await {
Ok(conn) => {
if attempt > 0 {
info!(
address = %address,
attempt = attempt,
"Successfully reconnected after retries"
);
}
return Ok(Arc::new(conn));
}
Err(e) => {
if !e.is_retriable() {
warn!(
address = %address,
error = %e,
"Non-retriable connection error, not retrying"
);
return Err(e);
}
warn!(
address = %address,
attempt = attempt,
max_retries = self.retry_config.max_retries,
error = %e,
"Connection attempt failed"
);
last_error = Some(e);
}
}
}
Err(last_error.unwrap_or_else(|| {
KrafkaError::network(std::io::Error::new(
std::io::ErrorKind::ConnectionRefused,
format!(
"Failed to connect to {} after {} retries",
address, self.retry_config.max_retries
),
))
}))
}
/// Returns a usable connection for `address`, reconnecting at most once at a time.
///
/// Concurrent callers for the same address are coalesced: the first one becomes the
/// leader and performs the reconnect, later ones park on a oneshot channel and receive
/// a clone of the leader's result. If the leader is cancelled, its `ReconnectGuard`
/// wakes the waiters with a "cancelled" error.
///
/// Two orderings keep the coalescing race-free:
/// * the "is there already a usable connection?" lookup happens while the `connecting`
///   lock is held, and
/// * the leader publishes a new connection to `connections_by_addr` *before* removing
///   its `connecting` entry.
///
/// Together they guarantee that a caller holding the `connecting` lock sees either the
/// in-flight entry or the finished connection. Previously the lookup was sampled before
/// the lock was taken and the entry was removed before the connection was published, so
/// a caller arriving in either window started a redundant second reconnect.
async fn get_or_reconnect(&self, address: &str) -> Result<Arc<BrokerConnection>> {
    // What this caller has to do after consulting the shared state.
    enum CoalesceAction {
        AlreadyConnected(Arc<BrokerConnection>),
        WaitForPeer(oneshot::Receiver<Result<Arc<BrokerConnection>>>),
        Reconnect(String),
    }

    // Informational only: note when a live connection is being replaced because its
    // SASL session needs re-authentication.
    let sasl_expired = self
        .connections_by_addr
        .read()
        .get(address)
        .is_some_and(|c| c.is_alive() && c.needs_reauthentication());
    if sasl_expired {
        info!(
            address = %address,
            "Replacing connection due to SASL session expiry (KIP-368)"
        );
    }

    let action = {
        let mut connecting = self.connecting.lock();
        // Looked up under the `connecting` lock (see the ordering notes above). No code
        // path acquires `connecting` while holding `connections_by_addr`, so nesting
        // the short read lock here cannot deadlock.
        let existing = self
            .connections_by_addr
            .read()
            .get(address)
            .filter(|c| c.is_usable())
            .cloned();
        if let Some(conn) = existing {
            CoalesceAction::AlreadyConnected(conn)
        } else if let Some(waiters) = connecting.get_mut(address) {
            let (tx, rx) = oneshot::channel();
            waiters.push(tx);
            CoalesceAction::WaitForPeer(rx)
        } else {
            let addr_owned = address.to_string();
            connecting.insert(addr_owned.clone(), Vec::new());
            CoalesceAction::Reconnect(addr_owned)
        }
    };

    let addr_owned = match action {
        CoalesceAction::AlreadyConnected(conn) => return Ok(conn),
        CoalesceAction::WaitForPeer(rx) => {
            // A dropped sender means the leader vanished without reporting a result.
            return rx.await.map_err(|_| {
                KrafkaError::network(std::io::Error::new(
                    std::io::ErrorKind::ConnectionReset,
                    format!("reconnection to {address} was cancelled"),
                ))
            })?;
        }
        CoalesceAction::Reconnect(addr_owned) => addr_owned,
    };

    // From here on this task is the leader; the guard covers cancellation at the await.
    let mut guard = ReconnectGuard::new(&self.connecting, addr_owned.clone());
    let result = self.reconnect_with_backoff(address).await;

    // Publish first, then retire the in-flight entry (see the ordering notes above).
    if let Ok(conn) = &result {
        self.connections_by_addr
            .write()
            .insert(addr_owned, Arc::clone(conn));
    }
    let waiters = self.connecting.lock().remove(address).unwrap_or_default();
    for waiter in waiters {
        // A waiter that gave up has dropped its receiver; nothing to do for it.
        let _ = waiter.send(result.clone());
    }
    guard.defuse();
    result
}
/// Returns a usable connection to `address`.
///
/// Serves the address-keyed entry when it is usable; otherwise falls back to the
/// coalesced reconnect path.
pub async fn get_connection(&self, address: &str) -> Result<Arc<BrokerConnection>> {
    // The read guard is a temporary and is released at the end of this statement.
    let cached = self
        .connections_by_addr
        .read()
        .get(address)
        .filter(|conn| conn.is_usable())
        .cloned();
    if let Some(conn) = cached {
        return Ok(conn);
    }
    self.get_or_reconnect(address).await
}
/// Returns a usable connection for `broker_id`, connecting to `address` if needed.
///
/// Serves the id-keyed entry when it is usable. Otherwise obtains a connection through
/// the address-keyed reconnect path and records it under `broker_id`, unless another
/// task has stored a usable entry for that id in the meantime.
pub async fn get_connection_by_id(
    &self,
    broker_id: BrokerId,
    address: &str,
) -> Result<Arc<BrokerConnection>> {
    // The read guard is a temporary and is released at the end of this statement.
    let cached = self
        .connections
        .read()
        .get(&broker_id)
        .filter(|conn| conn.is_usable())
        .cloned();
    if let Some(conn) = cached {
        return Ok(conn);
    }

    let conn = self.get_or_reconnect(address).await?;

    let mut by_id = self.connections.write();
    let missing_or_unusable = by_id
        .get(&broker_id)
        .is_none_or(|current| !current.is_usable());
    if missing_or_unusable {
        by_id.insert(broker_id, Arc::clone(&conn));
    }
    drop(by_id);

    Ok(conn)
}
/// Removes connections that have been idle for at least `max_idle` from both maps.
///
/// Returns the number of distinct connections evicted (a connection present in both
/// maps counts once), or `0` when `max_idle` is `None`. Evicted connections are closed
/// on spawned tasks when a Tokio runtime is available; otherwise they are only dropped.
pub fn evict_idle(&self) -> usize {
let Some(max_idle) = self.max_idle else {
return 0;
};
// Phase 1: collect candidate keys under read locks only.
let stale_ids: Vec<BrokerId> = {
let conns = self.connections.read();
conns
.iter()
.filter(|(_, c)| c.idle_duration() >= max_idle)
.map(|(id, _)| *id)
.collect()
};
let stale_addrs: Vec<String> = {
let conns = self.connections_by_addr.read();
conns
.iter()
.filter(|(_, c)| c.idle_duration() >= max_idle)
.map(|(addr, _)| addr.clone())
.collect()
};
// Nothing stale: return without taking any write lock.
if stale_ids.is_empty() && stale_addrs.is_empty() {
return 0;
}
let mut removed: Vec<Arc<BrokerConnection>> =
Vec::with_capacity(stale_ids.len() + stale_addrs.len());
// Phase 2: re-check each candidate under the write lock; an entry whose idle time
// dropped below the threshold since the snapshot is put back.
if !stale_ids.is_empty() {
let mut conns = self.connections.write();
for id in &stale_ids {
if let Some(c) = conns.remove(id) {
if c.idle_duration() >= max_idle {
removed.push(c);
} else {
conns.insert(*id, c);
}
}
}
}
// NOTE(review): the two maps are re-checked independently, so a connection that
// turns fresh between the two passes could be evicted from one map (and closed)
// while being kept in the other — confirm whether that matters to callers.
if !stale_addrs.is_empty() {
let mut conns = self.connections_by_addr.write();
for addr in &stale_addrs {
if let Some(c) = conns.remove(addr) {
if c.idle_duration() >= max_idle {
removed.push(c);
} else {
conns.insert(addr.clone(), c);
}
}
}
}
// The same connection can be removed from both maps; dedupe by pointer identity so
// it is counted and closed once.
removed.sort_by_key(|c| Arc::as_ptr(c) as usize);
removed.dedup_by(|a, b| Arc::ptr_eq(a, b));
let count = removed.len();
if count > 0 {
debug!(
evicted = count,
max_idle_ms = max_idle.as_millis(),
"Evicted idle connections"
);
}
// `close` is async and this function is not, so closing is handed to spawned tasks;
// spawning needs a runtime, hence the check.
if tokio::runtime::Handle::try_current().is_ok() {
for conn in removed {
tokio::spawn(async move { conn.close().await });
}
}
count
}
/// Spawns the background task that periodically calls `evict_idle`.
///
/// Does nothing when `max_idle` is `None`, and only logs a warning when called outside
/// a Tokio runtime. The tick interval is `max_idle / 9`, kept within 1 to 60 seconds.
/// The task holds only a weak reference to the pool and stops once the pool is gone.
/// A previously installed evictor task is aborted and replaced.
pub fn start_idle_evictor(self: &Arc<Self>) {
    let max_idle = match self.max_idle {
        Some(threshold) => threshold,
        None => return,
    };
    if tokio::runtime::Handle::try_current().is_err() {
        warn!("start_idle_evictor called outside a Tokio runtime; idle eviction disabled");
        return;
    }

    let interval = (max_idle / 9).clamp(Duration::from_secs(1), Duration::from_secs(60));
    let weak = Arc::downgrade(self);
    let handle = tokio::spawn(async move {
        let mut ticker = tokio::time::interval(interval);
        ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        // Consume the first tick up front so the first eviction pass happens only after
        // a further tick.
        ticker.tick().await;
        loop {
            ticker.tick().await;
            match weak.upgrade() {
                Some(pool) => {
                    pool.evict_idle();
                }
                None => break,
            }
        }
    });

    let previous = self.evictor_handle.lock().replace(handle);
    if let Some(stale_task) = previous {
        stale_task.abort();
    }
}
/// Shuts the pool down: stops the evictor, empties both maps, fails pending waiters
/// and closes every distinct connection once.
pub async fn close_all(&self) {
    let evictor = self.evictor_handle.lock().take();
    if let Some(task) = evictor {
        task.abort();
    }

    // Take the contents of both maps, leaving them empty; each write guard is a
    // temporary that is released at the end of its statement.
    let by_id = std::mem::take(&mut *self.connections.write());
    let by_addr = std::mem::take(&mut *self.connections_by_addr.write());

    // A connection may be registered in both maps; key by pointer to close it once.
    let mut unique = HashMap::with_capacity(by_id.len() + by_addr.len());
    for conn in by_id.into_values().chain(by_addr.into_values()) {
        unique.entry(Arc::as_ptr(&conn) as usize).or_insert(conn);
    }

    {
        let mut connecting = self.connecting.lock();
        connecting.drain().for_each(|(addr, waiters)| {
            let err = KrafkaError::network(std::io::Error::new(
                std::io::ErrorKind::ConnectionReset,
                format!("pool closed while reconnecting to {addr}"),
            ));
            waiters.into_iter().for_each(|waiter| {
                // A waiter that gave up has dropped its receiver; ignore that.
                let _ = waiter.send(Err(err.clone()));
            });
        });
    }

    for conn in unique.into_values() {
        conn.close().await;
    }
}
/// Number of usable connections in the broker-id map.
///
/// Connections registered only by address are not counted.
pub fn len(&self) -> usize {
    self.connections
        .read()
        .iter()
        .filter(|(_, conn)| conn.is_usable())
        .count()
}
/// Returns `true` when `len()` is zero, i.e. no usable connection is registered by
/// broker id.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
// Unit tests for the retry configuration and the pool's bookkeeping. None of them opens
// a real broker connection; connection-level cases use the crate's test stubs.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
use super::*;
// Smoke test: a pool can be constructed from the default configuration.
#[test]
fn test_connection_pool_new() {
let pool = ConnectionPool::new(ConnectionConfig::default());
let _ = pool;
}
// Pins the default retry policy values (jitter_factor is not asserted here).
#[test]
fn test_connection_retry_config_default() {
let config = ConnectionRetryConfig::default();
assert_eq!(config.max_retries, 3);
assert_eq!(config.initial_backoff, Duration::from_millis(100));
assert_eq!(config.max_backoff, Duration::from_secs(10));
assert_eq!(config.backoff_multiplier, 2.0);
}
// With jitter disabled the backoff doubles deterministically from 100 ms.
#[test]
fn test_calculate_backoff() {
let config = ConnectionRetryConfig {
jitter_factor: 0.0, ..ConnectionRetryConfig::default()
};
assert_eq!(config.calculate_backoff(0), Duration::ZERO);
assert_eq!(config.calculate_backoff(1), Duration::from_millis(100));
assert_eq!(config.calculate_backoff(2), Duration::from_millis(200));
assert_eq!(config.calculate_backoff(3), Duration::from_millis(400));
}
// 1 s * 10^1 would be 10 s; the 5 s cap must win.
#[test]
fn test_calculate_backoff_capped() {
let config = ConnectionRetryConfig {
max_retries: 10,
initial_backoff: Duration::from_secs(1),
max_backoff: Duration::from_secs(5),
backoff_multiplier: 10.0,
jitter_factor: 0.0, };
assert_eq!(config.calculate_backoff(2), Duration::from_secs(5));
}
// The largest possible attempt number must not panic and must yield the cap.
#[test]
fn test_calculate_backoff_handles_max_attempt() {
let config = ConnectionRetryConfig {
max_retries: u32::MAX,
jitter_factor: 0.0,
..ConnectionRetryConfig::default()
};
assert_eq!(config.calculate_backoff(u32::MAX), config.max_backoff);
}
// A custom retry policy is stored on the pool as given.
#[test]
fn test_connection_pool_with_retry_config() {
let retry_config = ConnectionRetryConfig {
max_retries: 5,
initial_backoff: Duration::from_millis(50),
max_backoff: Duration::from_secs(5),
backoff_multiplier: 3.0,
jitter_factor: 0.2,
};
let pool = ConnectionPool::with_retry_config(ConnectionConfig::default(), retry_config);
assert_eq!(pool.retry_config.max_retries, 5);
assert_eq!(pool.retry_config.initial_backoff, Duration::from_millis(50));
}
// connections_per_broker defaults to 1 and the builder raises 0 to 1.
#[test]
fn test_connections_per_broker_config() {
let config = ConnectionConfig::default();
assert_eq!(config.connections_per_broker, 1);
let config = ConnectionConfig::builder()
.connections_per_broker(4)
.build();
assert_eq!(config.connections_per_broker, 4);
let config = ConnectionConfig::builder()
.connections_per_broker(0)
.build();
assert_eq!(config.connections_per_broker, 1);
}
// close_all on an empty pool completes without error.
#[tokio::test]
async fn test_pool_close_all_clears_both_maps() {
let pool = ConnectionPool::new(ConnectionConfig::default());
assert!(pool.connections.read().is_empty());
assert!(pool.connections_by_addr.read().is_empty());
pool.close_all().await;
}
// The default idle threshold is 9 minutes (540 s).
#[test]
fn test_max_idle_default_matches_java_client() {
let pool = ConnectionPool::new(ConnectionConfig::default());
assert_eq!(pool.max_idle(), Some(Duration::from_millis(9 * 60 * 1000)));
assert_eq!(DEFAULT_MAX_IDLE, Duration::from_secs(540));
}
// max_idle = None turns evict_idle into a no-op.
#[test]
fn test_with_max_idle_none_disables_eviction() {
let pool = ConnectionPool::new(ConnectionConfig::default()).with_max_idle(None);
assert_eq!(pool.max_idle(), None);
assert_eq!(pool.evict_idle(), 0);
}
// Evicting from an empty pool reports zero evictions.
#[test]
fn test_evict_idle_on_empty_pool_is_noop() {
let pool = ConnectionPool::new(ConnectionConfig::default());
assert_eq!(pool.evict_idle(), 0);
}
// Starting the evictor installs a handle, restarting keeps one installed, and
// close_all removes it.
#[tokio::test]
async fn test_start_idle_evictor_installs_and_aborts_task() {
let pool = Arc::new(ConnectionPool::new(ConnectionConfig::default()));
assert!(pool.evictor_handle.lock().is_none());
pool.start_idle_evictor();
assert!(pool.evictor_handle.lock().is_some());
pool.start_idle_evictor();
assert!(pool.evictor_handle.lock().is_some());
pool.close_all().await;
assert!(pool.evictor_handle.lock().is_none());
}
// With eviction disabled no evictor task is installed.
#[tokio::test]
async fn test_start_idle_evictor_noop_when_max_idle_disabled() {
let pool = Arc::new(ConnectionPool::new(ConnectionConfig::default()).with_max_idle(None));
pool.start_idle_evictor();
assert!(pool.evictor_handle.lock().is_none());
}
// Plain #[test] (no runtime): the evictor must not be installed.
#[test]
fn test_start_idle_evictor_noop_outside_tokio_runtime() {
let pool = Arc::new(ConnectionPool::new(ConnectionConfig::default()));
pool.start_idle_evictor();
assert!(
pool.evictor_handle.lock().is_none(),
"evictor must not be installed without a Tokio runtime"
);
}
// One stale connection registered in both maps is removed from both and counted once.
#[test]
fn test_evict_idle_removes_stale_from_both_maps() {
let pool = ConnectionPool::new(ConnectionConfig::default())
.with_max_idle(Some(Duration::from_millis(100)));
let stale = Arc::new(BrokerConnection::test_stub_idle_for(
"b1:9092",
Duration::from_secs(10),
));
pool.connections.write().insert(1, stale.clone());
pool.connections_by_addr
.write()
.insert("b1:9092".to_string(), stale);
assert_eq!(pool.evict_idle(), 1);
assert!(pool.connections.read().is_empty());
assert!(pool.connections_by_addr.read().is_empty());
}
// Only the connection idle beyond the threshold is evicted; the fresh one stays.
#[test]
fn test_evict_idle_retains_fresh_and_evicts_stale() {
let pool = ConnectionPool::new(ConnectionConfig::default())
.with_max_idle(Some(Duration::from_millis(100)));
let stale = Arc::new(BrokerConnection::test_stub_idle_for(
"b1:9092",
Duration::from_secs(10),
));
let fresh = Arc::new(BrokerConnection::test_stub_idle_for(
"b2:9092",
Duration::from_millis(10),
));
{
let mut w = pool.connections.write();
w.insert(1, stale);
w.insert(2, fresh);
}
assert_eq!(pool.evict_idle(), 1);
let kept = pool.connections.read();
assert!(!kept.contains_key(&1));
assert!(kept.contains_key(&2));
}
// A connection marked fresh before the eviction pass is kept.
#[test]
fn test_evict_idle_rescued_after_refresh() {
let pool = ConnectionPool::new(ConnectionConfig::default())
.with_max_idle(Some(Duration::from_millis(100)));
let conn = Arc::new(BrokerConnection::test_stub_idle_for(
"b1:9092",
Duration::from_secs(10),
));
conn.test_mark_fresh();
pool.connections.write().insert(1, conn);
assert_eq!(pool.evict_idle(), 0);
assert!(pool.connections.read().contains_key(&1));
}
}