use std::cell::RefCell;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicI32, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use crate::error::{KrafkaError, Result};
thread_local! {
static JITTER_RNG: RefCell<SmallRng> = RefCell::new(SmallRng::from_os_rng());
}
/// Configuration for exponential retry backoff with optional random jitter.
///
/// See `BackoffPolicy::calculate_backoff` for how the fields combine into a
/// concrete delay.
#[derive(Clone, Debug)]
pub struct BackoffPolicy {
    /// Delay for the first retry; also the floor for every retry delay.
    pub initial_backoff: Duration,
    /// Cap on the un-jittered delay (treated as `initial_backoff` when smaller).
    pub max_backoff: Duration,
    /// Growth factor applied once per additional attempt.
    pub backoff_multiplier: f64,
    /// Fraction of the delay used as a symmetric random jitter range; values
    /// that are not strictly positive disable jitter.
    pub jitter_factor: f64,
}
impl Default for BackoffPolicy {
fn default() -> Self {
Self {
initial_backoff: Duration::from_millis(100),
max_backoff: Duration::from_secs(10),
backoff_multiplier: 2.0,
jitter_factor: 0.1,
}
}
}
impl BackoffPolicy {
    /// Computes the delay to wait before attempt number `attempt`.
    ///
    /// - `attempt == 0` returns [`Duration::ZERO`].
    /// - For `attempt >= 1` the un-jittered delay is
    ///   `initial_backoff * backoff_multiplier^(attempt - 1)`, capped at
    ///   `max(max_backoff, initial_backoff)`.
    /// - When `jitter_factor > 0`, a uniformly distributed offset in
    ///   `±delay * jitter_factor` is added. The jittered value may therefore
    ///   exceed `max_backoff` by up to that fraction.
    /// - The result is never shorter than `initial_backoff`.
    ///
    /// Misconfigured policies (negative, NaN or extreme multipliers or jitter
    /// factors) never panic. A value that cannot be represented as a
    /// [`Duration`] falls back to the effective maximum and logs a warning.
    #[inline]
    pub fn calculate_backoff(&self, attempt: u32) -> Duration {
        if attempt == 0 {
            return Duration::ZERO;
        }
        // A max below the initial value would invert the range, so the cap is
        // never allowed to drop below the initial delay.
        let effective_max = self.max_backoff.max(self.initial_backoff);
        let initial_secs = self.initial_backoff.as_secs_f64();
        let effective_max_secs = effective_max.as_secs_f64();
        let exponent = attempt.saturating_sub(1).min(i32::MAX as u32) as i32;
        let base_backoff = if initial_secs >= effective_max_secs {
            effective_max_secs
        } else {
            // `powi` saturates on its own: +inf for multipliers > 1, 0 for
            // multipliers < 1, and exactly 1 for a constant-backoff multiplier
            // of 1.0. `f64::min` returns the non-NaN operand, so large attempt
            // numbers need no special case. A former `exponent >= 1024`
            // shortcut made constant/slow backoff jump to the max here.
            (initial_secs * self.backoff_multiplier.powi(exponent)).min(effective_max_secs)
        };
        let jitter = if self.jitter_factor > 0.0 {
            // Sample a unit offset and scale it, rather than sampling from
            // `-range..=range`: the latter panics when the range is empty
            // (negative base) or has non-finite bounds.
            let unit: f64 = JITTER_RNG.with(|rng| rng.borrow_mut().random_range(-1.0..=1.0));
            base_backoff * self.jitter_factor * unit
        } else {
            0.0
        };
        // `max` also replaces a NaN sum with the initial delay.
        let final_backoff = (base_backoff + jitter).max(initial_secs);
        match Duration::try_from_secs_f64(final_backoff) {
            Ok(backoff) => backoff,
            Err(_) => {
                // Non-finite or larger than `Duration::MAX`.
                tracing::warn!(
                    "BackoffPolicy::calculate_backoff produced out-of-range value ({final_backoff}); capping at max_backoff"
                );
                effective_max
            }
        }
    }

    /// Returns the configured delay for the first retry.
    #[inline]
    pub fn initial_backoff(&self) -> Duration {
        self.initial_backoff
    }

    /// Returns the configured cap on the un-jittered delay.
    #[inline]
    pub fn max_backoff(&self) -> Duration {
        self.max_backoff
    }

    /// Returns the per-attempt growth factor.
    #[inline]
    pub fn backoff_multiplier(&self) -> f64 {
        self.backoff_multiplier
    }

    /// Returns the jitter fraction.
    #[inline]
    pub fn jitter_factor(&self) -> f64 {
        self.jitter_factor
    }
}
/// Correlation ID value that `CorrelationIdGenerator::next` never hands out.
///
/// Judging by the name, it tags requests sent without expecting a response.
/// Confirm this at the call sites.
pub(crate) const NO_RESPONSE_CORRELATION_ID: i32 = i32::MIN;
/// Converts `d` to whole milliseconds as an `i32`, clamping at `i32::MAX`
/// (about 24.8 days) instead of wrapping.
///
/// Sub-millisecond precision is truncated. When clamping happens, a warning is
/// logged. It is rate-limited to at most one per hour across all callers in
/// the process, because the limiter state lives in function-local statics.
#[inline]
pub fn duration_to_millis_i32(d: Duration) -> i32 {
    const WARN_INTERVAL_NANOS: u64 = 3600 * 1_000_000_000;
    static BASELINE: OnceLock<Instant> = OnceLock::new();
    static NEXT_WARN_NANOS: AtomicU64 = AtomicU64::new(0);

    let millis = d.as_millis();
    let Ok(in_range) = i32::try_from(millis) else {
        // Nanoseconds since the first clamp event, saturated to u64.
        let elapsed_nanos = BASELINE.get_or_init(Instant::now).elapsed().as_nanos();
        let now_nanos = u64::try_from(elapsed_nanos).unwrap_or(u64::MAX);
        let due_at = NEXT_WARN_NANOS.load(Ordering::Relaxed);
        // Only the caller whose compare-exchange advances the deadline logs.
        let should_warn = now_nanos >= due_at
            && NEXT_WARN_NANOS
                .compare_exchange(
                    due_at,
                    now_nanos + WARN_INTERVAL_NANOS,
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                )
                .is_ok();
        if should_warn {
            tracing::warn!(
                duration_ms = %millis,
                capped_at = i32::MAX,
                "duration exceeds i32::MAX (~24.8 days); clamping to i32::MAX. \
                 Check timeout/deadline configuration. (repeats at most once per hour)"
            );
        }
        return i32::MAX;
    };
    in_range
}
/// Converts `d` to whole milliseconds as an `i64`.
///
/// The sub-millisecond remainder is truncated, and the result saturates at
/// `i64::MAX` instead of wrapping.
#[inline]
pub fn duration_to_millis_i64(d: Duration) -> i64 {
    i64::try_from(d.as_millis()).unwrap_or(i64::MAX)
}
/// Generates a random version-4 UUID (RFC 4122 / RFC 9562 variant) and returns
/// it in canonical lowercase hyphenated form:
/// `xxxxxxxx-xxxx-4xxx-[89ab]xxx-xxxxxxxxxxxx`.
pub fn random_uuid_v4() -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut raw: [u8; 16] = rand::random();
    // High nibble of byte 6 carries the version: 4.
    raw[6] = (raw[6] & 0x0F) | 0x40;
    // Top two bits of byte 8 carry the variant: 0b10.
    raw[8] = (raw[8] & 0x3F) | 0x80;

    let mut out = String::with_capacity(36);
    for (idx, byte) in raw.iter().copied().enumerate() {
        // Hyphens precede bytes 4, 6, 8 and 10, giving the 8-4-4-4-12 layout.
        if matches!(idx, 4 | 6 | 8 | 10) {
            out.push('-');
        }
        out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
    }
    debug_assert_eq!(out.len(), 36, "UUID must be exactly 36 chars");
    out
}
/// Lock-free source of request correlation IDs, shareable across threads.
///
/// IDs start at 1 and advance by one per call. They wrap from `i32::MAX` to
/// `i32::MIN`, and `NO_RESPONSE_CORRELATION_ID` is skipped.
pub struct CorrelationIdGenerator {
    counter: AtomicI32,
}

impl CorrelationIdGenerator {
    /// Creates a generator whose first ID is 1.
    pub const fn new() -> Self {
        CorrelationIdGenerator {
            counter: AtomicI32::new(1),
        }
    }

    /// Returns the next correlation ID, never the reserved sentinel.
    #[inline]
    pub fn next(&self) -> i32 {
        let mut candidate = self.counter.fetch_add(1, Ordering::Relaxed);
        // Only reachable once per full wrap of the counter.
        while candidate == NO_RESPONSE_CORRELATION_ID {
            candidate = self.counter.fetch_add(1, Ordering::Relaxed);
        }
        candidate
    }
}

impl Default for CorrelationIdGenerator {
    /// Equivalent to [`CorrelationIdGenerator::new`].
    fn default() -> Self {
        CorrelationIdGenerator::new()
    }
}
#[inline]
pub fn crc32c(data: &[u8]) -> u32 {
crc32c::crc32c(data)
}
pub mod varint {
    //! Base-128 variable-length integers.
    //!
    //! Each byte carries 7 payload bits, least-significant group first, with
    //! the high bit set on every byte except the last. Signed values are
    //! ZigZag-mapped to unsigned first, so small magnitudes of either sign
    //! stay short.

    use bytes::{Buf, BufMut};

    use crate::error::{KrafkaError, ProtocolErrorKind, Result};

    /// ZigZag-maps a signed 32-bit value: 0, -1, 1, -2, ... -> 0, 1, 2, 3, ...
    #[inline]
    const fn zigzag_encode_32(value: i32) -> u32 {
        ((value << 1) ^ (value >> 31)) as u32
    }

    /// ZigZag-maps a signed 64-bit value (64-bit twin of `zigzag_encode_32`).
    #[inline]
    const fn zigzag_encode_64(value: i64) -> u64 {
        ((value << 1) ^ (value >> 63)) as u64
    }

    /// Inverse of `zigzag_encode_32`.
    #[inline]
    const fn zigzag_decode_32(raw: u32) -> i32 {
        ((raw >> 1) as i32) ^ -((raw & 1) as i32)
    }

    /// Inverse of `zigzag_encode_64`.
    #[inline]
    const fn zigzag_decode_64(raw: u64) -> i64 {
        ((raw >> 1) as i64) ^ -((raw & 1) as i64)
    }

    /// Number of bytes [`encode_unsigned_varint`] emits for `value` (1..=5).
    #[inline]
    pub const fn unsigned_varint_size(value: u32) -> usize {
        // `| 1` makes zero count as one significant bit, so it takes one byte.
        let significant_bits = u32::BITS - (value | 1).leading_zeros();
        ((significant_bits + 6) / 7) as usize
    }

    /// Number of bytes [`encode_signed_varint`] emits for `value`.
    #[inline]
    pub const fn signed_varint_size(value: i32) -> usize {
        unsigned_varint_size(zigzag_encode_32(value))
    }

    /// Number of bytes [`encode_unsigned_varlong`] emits for `value` (1..=10).
    #[inline]
    pub const fn unsigned_varlong_size(value: u64) -> usize {
        let significant_bits = u64::BITS - (value | 1).leading_zeros();
        ((significant_bits + 6) / 7) as usize
    }

    /// Number of bytes [`encode_signed_varlong`] emits for `value`.
    #[inline]
    pub const fn signed_varlong_size(value: i64) -> usize {
        unsigned_varlong_size(zigzag_encode_64(value))
    }

    /// Writes `value` ZigZag-mapped as an unsigned varint.
    #[inline]
    pub fn encode_signed_varint(value: i32, buf: &mut impl BufMut) {
        encode_unsigned_varint(zigzag_encode_32(value), buf);
    }

    /// Writes `value` as a varint of 1 to 5 bytes.
    #[inline]
    pub fn encode_unsigned_varint(mut value: u32, buf: &mut impl BufMut) {
        loop {
            let low7 = (value & 0x7F) as u8;
            value >>= 7;
            if value == 0 {
                buf.put_u8(low7);
                return;
            }
            // More groups follow: set the continuation bit.
            buf.put_u8(low7 | 0x80);
        }
    }

    /// Writes `value` ZigZag-mapped as an unsigned varlong.
    #[inline]
    pub fn encode_signed_varlong(value: i64, buf: &mut impl BufMut) {
        encode_unsigned_varlong(zigzag_encode_64(value), buf);
    }

    /// Writes `value` as a varlong of 1 to 10 bytes.
    #[inline]
    pub fn encode_unsigned_varlong(mut value: u64, buf: &mut impl BufMut) {
        loop {
            let low7 = (value & 0x7F) as u8;
            value >>= 7;
            if value == 0 {
                buf.put_u8(low7);
                return;
            }
            buf.put_u8(low7 | 0x80);
        }
    }

    /// Reads a ZigZag-encoded varint.
    ///
    /// # Errors
    ///
    /// Same as [`decode_unsigned_varint`].
    #[inline]
    pub fn decode_signed_varint(buf: &mut impl Buf) -> Result<i32> {
        decode_unsigned_varint(buf).map(zigzag_decode_32)
    }

    /// Reads an unsigned varint of at most 5 bytes.
    ///
    /// # Errors
    ///
    /// - `TruncatedFrame` if `buf` runs out before the final byte.
    /// - `InvalidLength` if the 5th byte still has its continuation bit set.
    #[inline]
    pub fn decode_unsigned_varint(buf: &mut impl Buf) -> Result<u32> {
        let mut result: u32 = 0;
        // Shifts 0, 7, 14, 21, 28: one per permitted byte.
        for shift in (0u32..35).step_by(7) {
            if !buf.has_remaining() {
                return Err(KrafkaError::protocol_kind(
                    ProtocolErrorKind::TruncatedFrame,
                    "unexpected end of varint",
                ));
            }
            let byte = buf.get_u8();
            result |= u32::from(byte & 0x7F) << shift;
            if byte & 0x80 == 0 {
                return Ok(result);
            }
        }
        Err(KrafkaError::protocol_kind(
            ProtocolErrorKind::InvalidLength,
            "varint too long",
        ))
    }

    /// Reads a ZigZag-encoded varlong.
    ///
    /// # Errors
    ///
    /// Same as [`decode_unsigned_varlong`].
    #[inline]
    pub fn decode_signed_varlong(buf: &mut impl Buf) -> Result<i64> {
        decode_unsigned_varlong(buf).map(zigzag_decode_64)
    }

    /// Reads an unsigned varlong of at most 10 bytes.
    ///
    /// # Errors
    ///
    /// - `TruncatedFrame` if `buf` runs out before the final byte.
    /// - `InvalidLength` if the 10th byte still has its continuation bit set.
    #[inline]
    pub fn decode_unsigned_varlong(buf: &mut impl Buf) -> Result<u64> {
        let mut result: u64 = 0;
        // Shifts 0, 7, ..., 63: one per permitted byte.
        for shift in (0u32..70).step_by(7) {
            if !buf.has_remaining() {
                return Err(KrafkaError::protocol_kind(
                    ProtocolErrorKind::TruncatedFrame,
                    "unexpected end of varlong",
                ));
            }
            let byte = buf.get_u8();
            result |= u64::from(byte & 0x7F) << shift;
            if byte & 0x80 == 0 {
                return Ok(result);
            }
        }
        Err(KrafkaError::protocol_kind(
            ProtocolErrorKind::InvalidLength,
            "varlong too long",
        ))
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use bytes::BytesMut;

    use super::*;

    /// Deterministic policy: 1s initial, doubling, capped at 8s, no jitter.
    fn no_jitter_policy() -> BackoffPolicy {
        BackoffPolicy {
            initial_backoff: Duration::from_secs(1),
            max_backoff: Duration::from_secs(8),
            backoff_multiplier: 2.0,
            jitter_factor: 0.0,
        }
    }

    #[test]
    fn test_backoff_policy_default_values() {
        let policy = BackoffPolicy::default();
        assert_eq!(policy.initial_backoff(), Duration::from_millis(100));
        assert_eq!(policy.max_backoff(), Duration::from_secs(10));
        assert!((policy.backoff_multiplier() - 2.0).abs() < f64::EPSILON);
        assert!((policy.jitter_factor() - 0.1).abs() < f64::EPSILON);
    }

    #[test]
    fn test_backoff_attempt_zero_is_immediate() {
        assert_eq!(BackoffPolicy::default().calculate_backoff(0), Duration::ZERO);
        assert_eq!(no_jitter_policy().calculate_backoff(0), Duration::ZERO);
    }

    #[test]
    fn test_backoff_grows_exponentially_and_caps() {
        let policy = no_jitter_policy();
        assert_eq!(policy.calculate_backoff(1), Duration::from_secs(1));
        assert_eq!(policy.calculate_backoff(2), Duration::from_secs(2));
        assert_eq!(policy.calculate_backoff(3), Duration::from_secs(4));
        assert_eq!(policy.calculate_backoff(4), Duration::from_secs(8));
        assert_eq!(policy.calculate_backoff(5), Duration::from_secs(8));
        assert_eq!(policy.calculate_backoff(u32::MAX), Duration::from_secs(8));
    }

    #[test]
    fn test_backoff_max_below_initial_uses_initial() {
        let policy = BackoffPolicy {
            initial_backoff: Duration::from_secs(2),
            max_backoff: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter_factor: 0.0,
        };
        for attempt in [1, 2, 10] {
            assert_eq!(policy.calculate_backoff(attempt), Duration::from_secs(2));
        }
    }

    #[test]
    fn test_backoff_jitter_stays_within_bounds() {
        let policy = BackoffPolicy {
            jitter_factor: 0.5,
            ..no_jitter_policy()
        };
        // Attempt 3 has a 4s base, so ±50% jitter must land in [2s, 6s].
        for _ in 0..200 {
            let backoff = policy.calculate_backoff(3);
            assert!(
                backoff >= Duration::from_secs(2) && backoff <= Duration::from_secs(6),
                "backoff {backoff:?} out of range"
            );
        }
    }

    #[test]
    fn test_correlation_id_generator() {
        let generator = CorrelationIdGenerator::new();
        assert_eq!(generator.next(), 1);
        assert_eq!(generator.next(), 2);
        assert_eq!(generator.next(), 3);
    }

    #[test]
    fn test_correlation_id_generator_skips_reserved_no_response_id() {
        let generator = CorrelationIdGenerator {
            counter: AtomicI32::new(NO_RESPONSE_CORRELATION_ID),
        };
        assert_eq!(generator.next(), NO_RESPONSE_CORRELATION_ID + 1);
        assert_eq!(generator.next(), NO_RESPONSE_CORRELATION_ID + 2);
    }

    #[test]
    fn test_correlation_id_generator_wraps_past_reserved_id() {
        let generator = CorrelationIdGenerator {
            counter: AtomicI32::new(i32::MAX),
        };
        assert_eq!(generator.next(), i32::MAX);
        assert_eq!(generator.next(), NO_RESPONSE_CORRELATION_ID + 1);
    }

    #[test]
    fn test_varint_encode_decode() {
        let test_values = [0, 1, 127, 128, 255, 300, 16383, 16384, i32::MAX, i32::MIN];
        for value in test_values {
            let mut buf = BytesMut::new();
            varint::encode_signed_varint(value, &mut buf);
            let decoded = varint::decode_signed_varint(&mut buf.freeze()).unwrap();
            assert_eq!(decoded, value, "Failed for value {value}");
        }
    }

    #[test]
    fn test_varlong_encode_decode() {
        let test_values = [
            0i64,
            1,
            127,
            128,
            255,
            300,
            16383,
            16384,
            i64::MAX,
            i64::MIN,
        ];
        for value in test_values {
            let mut buf = BytesMut::new();
            varint::encode_signed_varlong(value, &mut buf);
            let decoded = varint::decode_signed_varlong(&mut buf.freeze()).unwrap();
            assert_eq!(decoded, value, "Failed for value {value}");
        }
    }

    #[test]
    fn test_unsigned_varint_known_encoding() {
        let mut buf = BytesMut::new();
        varint::encode_unsigned_varint(300, &mut buf);
        assert_eq!(&buf[..], &[0xAC, 0x02]);
        let mut buf = BytesMut::new();
        varint::encode_unsigned_varint(u32::MAX, &mut buf);
        assert_eq!(&buf[..], &[0xFF, 0xFF, 0xFF, 0xFF, 0x0F]);
        assert_eq!(varint::decode_unsigned_varint(&mut buf.freeze()).unwrap(), u32::MAX);
    }

    #[test]
    fn test_varint_size_matches_encoded_length() {
        for value in [0, 1, -1, 63, -64, 64, 8191, -8192, 8192, i32::MAX, i32::MIN] {
            let mut buf = BytesMut::new();
            varint::encode_signed_varint(value, &mut buf);
            assert_eq!(buf.len(), varint::signed_varint_size(value), "value {value}");
        }
        for value in [0u32, 127, 128, 16_383, 16_384, u32::MAX] {
            let mut buf = BytesMut::new();
            varint::encode_unsigned_varint(value, &mut buf);
            assert_eq!(buf.len(), varint::unsigned_varint_size(value), "value {value}");
        }
        for value in [0i64, 1, -1, i64::from(i32::MAX) + 1, i64::MAX, i64::MIN] {
            let mut buf = BytesMut::new();
            varint::encode_signed_varlong(value, &mut buf);
            assert_eq!(buf.len(), varint::signed_varlong_size(value), "value {value}");
        }
        for value in [0u64, 127, 128, u64::MAX] {
            let mut buf = BytesMut::new();
            varint::encode_unsigned_varlong(value, &mut buf);
            assert_eq!(buf.len(), varint::unsigned_varlong_size(value), "value {value}");
        }
    }

    #[test]
    fn test_varint_decode_truncated_is_error() {
        let mut truncated: &[u8] = &[0x80];
        assert!(varint::decode_unsigned_varint(&mut truncated).is_err());
        let mut empty: &[u8] = &[];
        assert!(varint::decode_unsigned_varlong(&mut empty).is_err());
    }

    #[test]
    fn test_varint_decode_too_long_is_error() {
        let mut too_long: &[u8] = &[0xFF; 6];
        assert!(varint::decode_unsigned_varint(&mut too_long).is_err());
        let mut too_long: &[u8] = &[0xFF; 11];
        assert!(varint::decode_unsigned_varlong(&mut too_long).is_err());
    }

    #[test]
    fn test_crc32c() {
        let data = b"hello world";
        let crc = crc32c(data);
        assert_eq!(crc, 0xc99465aa);
    }

    #[test]
    fn test_duration_to_millis_i32_normal() {
        assert_eq!(duration_to_millis_i32(Duration::from_millis(100)), 100);
        assert_eq!(duration_to_millis_i32(Duration::from_secs(30)), 30_000);
        assert_eq!(duration_to_millis_i32(Duration::ZERO), 0);
    }

    #[test]
    fn test_duration_to_millis_i32_caps_at_max() {
        let huge = Duration::from_secs(25 * 24 * 3600);
        assert_eq!(duration_to_millis_i32(huge), i32::MAX);
    }

    #[test]
    fn test_duration_to_millis_i32_exact_max() {
        let exact = Duration::from_millis(i32::MAX as u64);
        assert_eq!(duration_to_millis_i32(exact), i32::MAX);
    }

    #[test]
    fn test_duration_to_millis_i64_normal() {
        assert_eq!(duration_to_millis_i64(Duration::from_millis(100)), 100);
        assert_eq!(duration_to_millis_i64(Duration::from_secs(30)), 30_000);
        assert_eq!(duration_to_millis_i64(Duration::ZERO), 0);
    }

    #[test]
    fn test_duration_to_millis_i64_caps_at_max() {
        let huge = Duration::from_secs(u64::MAX);
        assert_eq!(duration_to_millis_i64(huge), i64::MAX);
    }

    #[test]
    fn test_duration_to_millis_i64_exact_max() {
        let exact = Duration::from_millis(i64::MAX as u64);
        assert_eq!(duration_to_millis_i64(exact), i64::MAX);
    }

    #[test]
    fn test_random_uuid_v4_format() {
        let uuid = random_uuid_v4();
        assert_eq!(uuid.len(), 36);
        let parts: Vec<&str> = uuid.split('-').collect();
        assert_eq!(parts.len(), 5);
        assert_eq!(parts[0].len(), 8);
        assert_eq!(parts[1].len(), 4);
        assert_eq!(parts[2].len(), 4);
        assert_eq!(parts[3].len(), 4);
        assert_eq!(parts[4].len(), 12);
        assert!(uuid.chars().all(|c| c.is_ascii_hexdigit() || c == '-'));
    }

    #[test]
    fn test_random_uuid_v4_version_and_variant() {
        let uuid = random_uuid_v4();
        let parts: Vec<&str> = uuid.split('-').collect();
        assert_eq!(
            parts[2].chars().next().unwrap(),
            '4',
            "UUID version nibble must be 4"
        );
        let variant = parts[3].chars().next().unwrap();
        assert!(
            matches!(variant, '8' | '9' | 'a' | 'b'),
            "UUID variant nibble must be 8/9/a/b, got '{variant}'"
        );
    }

    #[test]
    fn test_random_uuid_v4_uniqueness() {
        let a = random_uuid_v4();
        let b = random_uuid_v4();
        assert_ne!(a, b, "Two UUIDs should not be identical");
    }
}
/// Splits a comma-separated bootstrap server list into trimmed, non-empty
/// entries, preserving their order.
///
/// # Errors
///
/// Returns a configuration error when no non-empty entry remains, e.g. for
/// `""` or `" , , "`.
pub fn parse_bootstrap_servers(servers: &str) -> Result<Vec<String>> {
    let mut addrs = Vec::new();
    for entry in servers.split(',') {
        let entry = entry.trim();
        if !entry.is_empty() {
            addrs.push(entry.to_owned());
        }
    }
    if addrs.is_empty() {
        Err(KrafkaError::config("no bootstrap servers specified"))
    } else {
        Ok(addrs)
    }
}
/// Extracts the host part of a `host[:port]` address for use as a TLS SNI name.
///
/// Accepted forms:
/// - `[host]` or `[host]:port`, typically a bracketed IPv6 literal. The bracket
///   contents are returned. A port after `]` must be a non-empty run of ASCII
///   digits.
/// - A bare IPv6 literal such as `2001:db8::1`. It is returned unchanged,
///   because its colons cannot be told apart from a port separator.
/// - `host` or `host:port`. Everything before the last `:` is returned.
///
/// The returned slice borrows from `address`.
///
/// # Errors
///
/// Returns a configuration error for any of these:
/// - an empty address;
/// - mismatched or misplaced brackets;
/// - an empty host, in either `[]` or `:port` form;
/// - a malformed port after `]`.
pub fn extract_sni_hostname(address: &str) -> Result<&str> {
    if address.is_empty() {
        return Err(KrafkaError::config("empty address"));
    }
    let has_open = address.contains('[');
    let close_pos = address.find(']');
    match (has_open, close_pos) {
        (true, Some(end)) => {
            if !address.starts_with('[') {
                return Err(KrafkaError::config(format!(
                    "malformed address ('[' not at start): {address}"
                )));
            }
            // '[' is a single ASCII byte, so index 1 is a char boundary.
            let hostname = &address[1..end];
            if hostname.is_empty() {
                return Err(KrafkaError::config(format!(
                    "empty hostname in brackets: {address}"
                )));
            }
            let after = &address[end + 1..];
            if after.contains('[') || after.contains(']') {
                return Err(KrafkaError::config(format!(
                    "unexpected bracket characters after closing ']': {address}"
                )));
            }
            if !after.is_empty() {
                if !after.starts_with(':') {
                    return Err(KrafkaError::config(format!(
                        "unexpected characters after closing ']': {address}"
                    )));
                }
                let port_str = &after[1..];
                if port_str.is_empty() || !port_str.chars().all(|c| c.is_ascii_digit()) {
                    return Err(KrafkaError::config(format!(
                        "invalid port after closing ']': {address}"
                    )));
                }
            }
            Ok(hostname)
        }
        (true, None) => Err(KrafkaError::config(format!(
            "malformed address (missing closing ']'): {address}"
        ))),
        (false, Some(_)) => Err(KrafkaError::config(format!(
            "malformed address (unexpected ']' without '['): {address}"
        ))),
        (false, None) => {
            if address.parse::<std::net::Ipv6Addr>().is_ok() {
                Ok(address)
            } else {
                // Reject inputs like ":9092" that previously yielded an empty
                // SNI name, matching the bracketed branch's empty-host check.
                match address.rsplit_once(':').map_or(address, |(host, _)| host) {
                    "" => Err(KrafkaError::config(format!(
                        "empty hostname: {address}"
                    ))),
                    host => Ok(host),
                }
            }
        }
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod bootstrap_tests {
    use super::*;

    #[test]
    fn test_parse_bootstrap_servers_basic() {
        let result = parse_bootstrap_servers("localhost:9092,broker:9093").unwrap();
        assert_eq!(result, vec!["localhost:9092", "broker:9093"]);
    }

    #[test]
    fn test_parse_bootstrap_servers_trims_whitespace() {
        let result = parse_bootstrap_servers(" localhost:9092 , broker:9093 ").unwrap();
        assert_eq!(result, vec!["localhost:9092", "broker:9093"]);
    }

    #[test]
    fn test_parse_bootstrap_servers_trims_tabs_and_newlines() {
        let result = parse_bootstrap_servers("\tlocalhost:9092,\n broker:9093\r\n").unwrap();
        assert_eq!(result, vec!["localhost:9092", "broker:9093"]);
    }

    #[test]
    fn test_parse_bootstrap_servers_single_entry() {
        let result = parse_bootstrap_servers("localhost:9092").unwrap();
        assert_eq!(result, vec!["localhost:9092"]);
    }

    #[test]
    fn test_parse_bootstrap_servers_keeps_bracketed_ipv6() {
        let result = parse_bootstrap_servers("[::1]:9092, [2001:db8::1]:9093").unwrap();
        assert_eq!(result, vec!["[::1]:9092", "[2001:db8::1]:9093"]);
    }

    #[test]
    fn test_parse_bootstrap_servers_filters_empty() {
        let result = parse_bootstrap_servers(" , ,localhost:9092, , broker:9093, ").unwrap();
        assert_eq!(result, vec!["localhost:9092", "broker:9093"]);
    }

    #[test]
    fn test_parse_bootstrap_servers_empty_string() {
        assert!(parse_bootstrap_servers("").is_err());
    }

    #[test]
    fn test_parse_bootstrap_servers_only_whitespace() {
        assert!(parse_bootstrap_servers(" , , ").is_err());
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod sni_tests {
    use super::*;

    #[test]
    fn test_extract_sni_bracketed_ipv6_with_port() {
        assert_eq!(extract_sni_hostname("[::1]:9092").unwrap(), "::1");
    }

    #[test]
    fn test_extract_sni_bracketed_ipv6_no_port() {
        assert_eq!(extract_sni_hostname("[::1]").unwrap(), "::1");
    }

    #[test]
    fn test_extract_sni_bare_ipv6() {
        assert_eq!(extract_sni_hostname("2001:db8::1").unwrap(), "2001:db8::1");
    }

    #[test]
    fn test_extract_sni_bare_ipv6_loopback() {
        assert_eq!(extract_sni_hostname("::1").unwrap(), "::1");
    }

    #[test]
    fn test_extract_sni_ipv4_with_port() {
        assert_eq!(
            extract_sni_hostname("192.168.1.1:9092").unwrap(),
            "192.168.1.1"
        );
    }

    #[test]
    fn test_extract_sni_ipv4_no_port() {
        assert_eq!(extract_sni_hostname("10.0.0.1").unwrap(), "10.0.0.1");
    }

    #[test]
    fn test_extract_sni_hostname_with_port() {
        assert_eq!(
            extract_sni_hostname("broker.example.com:9092").unwrap(),
            "broker.example.com"
        );
    }

    #[test]
    fn test_extract_sni_hostname_no_port() {
        assert_eq!(
            extract_sni_hostname("broker.example.com").unwrap(),
            "broker.example.com"
        );
    }

    #[test]
    fn test_extract_sni_bracketed_hostname() {
        assert_eq!(
            extract_sni_hostname("[broker.example.com]:9092").unwrap(),
            "broker.example.com"
        );
    }

    #[test]
    fn test_extract_sni_bracketed_ipv6_full() {
        assert_eq!(
            extract_sni_hostname("[2001:db8::1]:9092").unwrap(),
            "2001:db8::1"
        );
    }

    #[test]
    fn test_extract_sni_bracketed_ipv4_mapped_ipv6() {
        assert_eq!(
            extract_sni_hostname("[::ffff:192.0.2.1]:9092").unwrap(),
            "::ffff:192.0.2.1"
        );
    }

    #[test]
    fn test_extract_sni_ipv6_ambiguous_port() {
        // A bare address that parses as IPv6 is kept whole; otherwise the
        // last ':' is treated as the port separator.
        assert_eq!(
            extract_sni_hostname("2001:db8::1:9092").unwrap(),
            "2001:db8::1:9092"
        );
        assert_eq!(
            extract_sni_hostname("2001:db8::zz:9092").unwrap(),
            "2001:db8::zz"
        );
    }

    #[test]
    fn test_extract_sni_malformed_bracket_returns_error() {
        assert!(extract_sni_hostname("[::1").is_err());
        assert!(extract_sni_hostname("[host").is_err());
        assert!(extract_sni_hostname("[host:9092").is_err());
        assert!(extract_sni_hostname("::1]:9092").is_err());
        assert!(extract_sni_hostname("host]").is_err());
        assert!(extract_sni_hostname("host]:9092").is_err());
        assert!(extract_sni_hostname("foo[::1]:9092").is_err());
        assert!(extract_sni_hostname("[::1]extra").is_err());
        assert!(extract_sni_hostname("[::1]:9092]").is_err());
        assert!(extract_sni_hostname("[::1]:abc").is_err());
        assert!(extract_sni_hostname("[::1]:").is_err());
    }

    #[test]
    fn test_extract_sni_empty_input_returns_error() {
        assert!(extract_sni_hostname("").is_err());
        assert!(extract_sni_hostname("[]").is_err());
        assert!(extract_sni_hostname("[]:9092").is_err());
    }
}