use std::net::SocketAddr;
use std::time::Duration;
use tokio::net::TcpSocket;
use tokio::task::JoinSet;
use tokio::time::Instant;
use tracing::debug;
use super::connection::ConnectionConfig;
use crate::error::{KrafkaError, Result};
/// Outcome of one spawned connection attempt, as stored in the `JoinSet`.
///
/// Both variants carry the `SocketAddr` that was tried, so the caller can log
/// which candidate succeeded or failed: `Ok` pairs it with the connected
/// stream, `Err` pairs it with the error produced for that attempt.
type AttemptResult =
    std::result::Result<(SocketAddr, tokio::net::TcpStream), (SocketAddr, KrafkaError)>;
/// Smallest stagger delay allowed between successive connection attempts.
const MIN_CONNECTION_ATTEMPT_DELAY: Duration = Duration::from_millis(100);

/// Largest stagger delay allowed between successive connection attempts.
const MAX_CONNECTION_ATTEMPT_DELAY: Duration = Duration::from_secs(2);

/// Restricts a configured attempt delay to the supported range.
///
/// Values below [`MIN_CONNECTION_ATTEMPT_DELAY`] are raised to it, values above
/// [`MAX_CONNECTION_ATTEMPT_DELAY`] are lowered to it, and anything in between
/// is returned unchanged.
fn clamp_delay(delay: Duration) -> Duration {
    // Raise to the floor first, then cap at the ceiling.
    let at_least_min = delay.max(MIN_CONNECTION_ATTEMPT_DELAY);
    at_least_min.min(MAX_CONNECTION_ATTEMPT_DELAY)
}
/// Resolves `address` and opens a TCP connection to one of its addresses.
///
/// The resolved addresses are reordered so that address families alternate
/// (IPv6 first). A single candidate is connected to directly; several
/// candidates are raced with staggered starts.
///
/// # Errors
///
/// Returns an error if resolution fails or times out, if resolution yields no
/// addresses, or if no candidate could be connected to.
pub(crate) async fn connect_happy_eyeballs(
    address: &str,
    config: &ConnectionConfig,
) -> Result<tokio::net::TcpStream> {
    let resolved = resolve(address, config.connect_timeout).await?;
    if resolved.is_empty() {
        let reason = format!("no addresses resolved for '{address}'");
        return Err(KrafkaError::invalid_state(reason));
    }

    // `resolved` is non-empty and interleaving keeps every address, so
    // indexing the first candidate below cannot panic.
    let ordered = interleave_address_families(&resolved);
    debug!(
        address = %address,
        candidates = ordered.len(),
        first_family = family_label(ordered[0]),
        "Happy Eyeballs: starting connection race"
    );

    match ordered.as_slice() {
        // Nothing to race against: skip the task machinery entirely.
        [only] => connect_one(*only, config).await,
        many => staggered_connect(many, config).await,
    }
}
/// Resolves `address` to socket addresses, giving up after `timeout_dur`.
///
/// # Errors
///
/// Returns a timeout error labelled "DNS resolution" when the lookup does not
/// finish in time, or a network error when the lookup itself fails.
async fn resolve(address: &str, timeout_dur: Duration) -> Result<Vec<SocketAddr>> {
    let lookup = tokio::net::lookup_host(address);
    match tokio::time::timeout(timeout_dur, lookup).await {
        Ok(Ok(found)) => Ok(found.collect()),
        Ok(Err(io_err)) => Err(KrafkaError::network(io_err)),
        Err(_elapsed) => Err(KrafkaError::timeout("DNS resolution")),
    }
}
/// Reorders `addrs` so that IPv6 and IPv4 addresses alternate, IPv6 first.
///
/// The relative order inside each family is preserved. Once the shorter
/// family runs out, the remaining addresses of the longer family follow in
/// order. If only one family is present the addresses come back unchanged.
fn interleave_address_families(addrs: &[SocketAddr]) -> Vec<SocketAddr> {
    // A `SocketAddr` is either V6 or V4, so "not IPv6" is exactly the IPv4 set.
    let (ipv6, ipv4): (Vec<SocketAddr>, Vec<SocketAddr>) =
        addrs.iter().copied().partition(|candidate| candidate.is_ipv6());

    let paired = ipv6.len().min(ipv4.len());
    let mut ordered = Vec::with_capacity(addrs.len());

    // Alternate while both families still have an address to offer.
    for (six, four) in ipv6.iter().zip(ipv4.iter()) {
        ordered.push(*six);
        ordered.push(*four);
    }

    // At most one of these tails is non-empty.
    ordered.extend_from_slice(&ipv6[paired..]);
    ordered.extend_from_slice(&ipv4[paired..]);
    ordered
}
/// Connects to a single address, bounded by `config.connect_timeout`.
///
/// # Errors
///
/// Returns an error if the socket cannot be created or configured, a timeout
/// error labelled "connection" if the connect does not finish in time, or a
/// network error if the connect itself fails.
async fn connect_one(addr: SocketAddr, config: &ConnectionConfig) -> Result<tokio::net::TcpStream> {
    let socket = create_socket(addr, config)?;
    let attempt = tokio::time::timeout(config.connect_timeout, socket.connect(addr)).await;
    match attempt {
        Ok(Ok(stream)) => Ok(stream),
        Ok(Err(io_err)) => Err(KrafkaError::network(io_err)),
        Err(_elapsed) => Err(KrafkaError::timeout("connection")),
    }
}
async fn staggered_connect(
addrs: &[SocketAddr],
config: &ConnectionConfig,
) -> Result<tokio::net::TcpStream> {
let deadline = Instant::now() + config.connect_timeout;
let delay = clamp_delay(config.connection_attempt_delay);
let mut tasks = JoinSet::new();
let mut next_idx = 0;
let total = addrs.len();
let mut errors: Vec<(SocketAddr, KrafkaError)> = Vec::new();
spawn_attempt(&mut tasks, addrs[next_idx], config, deadline);
next_idx += 1;
let stagger = tokio::time::sleep(delay);
tokio::pin!(stagger);
loop {
let all_launched = next_idx >= total;
tokio::select! {
biased;
Some(join_result) = tasks.join_next() => {
match join_result {
Ok(Ok((addr, stream))) => {
debug!(
addr = %addr,
"Happy Eyeballs: connected successfully"
);
return Ok(stream);
}
Ok(Err((addr, e))) => {
debug!(
addr = %addr,
error = %e,
"Happy Eyeballs: attempt failed"
);
errors.push((addr, e));
if next_idx < total {
spawn_attempt(&mut tasks, addrs[next_idx], config, deadline);
next_idx += 1;
stagger.as_mut().reset(Instant::now() + delay);
}
}
Err(join_err) => {
debug!(error = %join_err, "Happy Eyeballs: task join error");
}
}
if tasks.is_empty() && next_idx >= total {
break;
}
}
_ = &mut stagger, if !all_launched => {
debug!(
addr = %addrs[next_idx],
attempt = next_idx,
"Happy Eyeballs: stagger delay elapsed, launching next attempt"
);
spawn_attempt(&mut tasks, addrs[next_idx], config, deadline);
next_idx += 1;
if next_idx < total {
stagger.as_mut().reset(Instant::now() + delay);
}
}
_ = tokio::time::sleep_until(deadline) => {
debug!("Happy Eyeballs: overall connect timeout reached");
return Err(KrafkaError::timeout("connection (Happy Eyeballs)"));
}
}
}
Err(build_combined_error(&errors))
}
/// Spawns one connection attempt to `addr` onto `tasks`.
///
/// The socket options are copied out of `config` before spawning so the task
/// does not borrow from it. The attempt is bounded by `deadline`, and its
/// result is tagged with `addr` on both the success and the failure path.
fn spawn_attempt(
    tasks: &mut JoinSet<AttemptResult>,
    addr: SocketAddr,
    config: &ConnectionConfig,
    deadline: Instant,
) {
    let send_buf = config.send_buffer_size;
    let recv_buf = config.recv_buffer_size;
    let tcp_keepalive = config.tcp_keepalive;

    tasks.spawn(async move {
        let socket = match create_socket_raw(addr, send_buf, recv_buf, tcp_keepalive) {
            Ok(socket) => socket,
            Err(setup_err) => return Err((addr, setup_err)),
        };

        match tokio::time::timeout_at(deadline, socket.connect(addr)).await {
            Ok(Ok(stream)) => Ok((addr, stream)),
            Ok(Err(io_err)) => Err((addr, KrafkaError::network(io_err))),
            Err(_elapsed) => Err((addr, KrafkaError::timeout("connection attempt"))),
        }
    });
}
/// Folds the per-address failures into a single error.
///
/// With no recorded failures a generic message is used; with exactly one the
/// message names that address and its error; otherwise the message lists
/// every address with its error, one per line.
fn build_combined_error(errors: &[(SocketAddr, KrafkaError)]) -> KrafkaError {
    match errors {
        [] => KrafkaError::invalid_state("all connection attempts failed"),
        [(addr, e)] => KrafkaError::invalid_state(format!("connection to {addr} failed: {e}")),
        several => {
            let mut lines = Vec::with_capacity(several.len());
            for (addr, e) in several {
                lines.push(format!(" {addr}: {e}"));
            }
            let listing = lines.join("\n");
            KrafkaError::invalid_state(format!(
                "all {} connection attempts failed:\n{}",
                several.len(),
                listing
            ))
        }
    }
}
/// Creates an unconnected TCP socket for `addr`'s address family and applies
/// the optional socket settings.
///
/// * `send_buffer_size` / `recv_buffer_size` - requested buffer sizes in
///   bytes; `None` leaves the setting untouched. Values larger than
///   `u32::MAX` are saturated to `u32::MAX`.
/// * `tcp_keepalive` - when set, enables TCP keepalive with this value as the
///   keepalive time.
///
/// # Errors
///
/// Returns a network error if the socket cannot be created or if any of the
/// requested options cannot be applied.
fn create_socket_raw(
    addr: SocketAddr,
    send_buffer_size: Option<usize>,
    recv_buffer_size: Option<usize>,
    tcp_keepalive: Option<Duration>,
) -> Result<TcpSocket> {
    // The socket setters take `u32`. A plain `as u32` would wrap oversized
    // values (2^32 would become 0), so cap first; after the `min` the cast is
    // lossless.
    fn saturating_u32(size: usize) -> u32 {
        size.min(u32::MAX as usize) as u32
    }

    let socket = if addr.is_ipv6() {
        TcpSocket::new_v6()
    } else {
        TcpSocket::new_v4()
    }
    .map_err(KrafkaError::network)?;

    if let Some(size) = send_buffer_size {
        socket
            .set_send_buffer_size(saturating_u32(size))
            .map_err(KrafkaError::network)?;
    }
    if let Some(size) = recv_buffer_size {
        socket
            .set_recv_buffer_size(saturating_u32(size))
            .map_err(KrafkaError::network)?;
    }
    if let Some(interval) = tcp_keepalive {
        // tokio's `TcpSocket` is configured through a borrowed socket2 view.
        let sock_ref = socket2::SockRef::from(&socket);
        let keepalive = socket2::TcpKeepalive::new().with_time(interval);
        sock_ref
            .set_tcp_keepalive(&keepalive)
            .map_err(KrafkaError::network)?;
    }
    Ok(socket)
}
/// Creates an unconnected TCP socket for `addr`, configured from `config`.
///
/// Forwards the send/receive buffer sizes and the TCP keepalive setting from
/// `config` to [`create_socket_raw`].
///
/// # Errors
///
/// Returns whatever error [`create_socket_raw`] reports.
pub(super) fn create_socket(addr: SocketAddr, config: &ConnectionConfig) -> Result<TcpSocket> {
    let send_buf = config.send_buffer_size;
    let recv_buf = config.recv_buffer_size;
    let keepalive = config.tcp_keepalive;
    create_socket_raw(addr, send_buf, recv_buf, keepalive)
}
/// Returns a short label for the address family of `addr`, for log output.
fn family_label(addr: SocketAddr) -> &'static str {
    match addr {
        SocketAddr::V6(_) => "IPv6",
        SocketAddr::V4(_) => "IPv4",
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    //! Unit tests for delay clamping, address interleaving, error
    //! aggregation, and a few loopback connection smoke tests.

    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};

    /// Builds `192.168.1.<port>:<port>`. The last octet is `port` truncated to
    /// `u8`, so callers should keep `port` below 256.
    fn v4(port: u16) -> SocketAddr {
        SocketAddr::V4(SocketAddrV4::new(
            Ipv4Addr::new(192, 168, 1, port as u8),
            port,
        ))
    }

    /// Builds `[2001:db8::<port>]:<port>` with zero flow info and scope id.
    fn v6(port: u16) -> SocketAddr {
        SocketAddr::V6(SocketAddrV6::new(
            Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, port),
            port,
            0,
            0,
        ))
    }

    #[test]
    fn clamp_delay_within_bounds() {
        assert_eq!(
            clamp_delay(Duration::from_millis(250)),
            Duration::from_millis(250)
        );
    }

    #[test]
    fn clamp_delay_below_minimum() {
        assert_eq!(
            clamp_delay(Duration::from_millis(5)),
            MIN_CONNECTION_ATTEMPT_DELAY
        );
    }

    #[test]
    fn clamp_delay_above_maximum() {
        assert_eq!(
            clamp_delay(Duration::from_secs(10)),
            MAX_CONNECTION_ATTEMPT_DELAY
        );
    }

    #[test]
    fn interleave_ipv6_preferred_first() {
        let addrs = vec![v4(1), v4(2), v6(3), v6(4)];
        let sorted = interleave_address_families(&addrs);
        assert!(sorted[0].is_ipv6(), "first address should be IPv6");
        assert!(sorted[1].is_ipv4(), "second address should be IPv4");
        // Pin the full order, not just the families of the first two entries.
        assert_eq!(sorted, vec![v6(3), v4(1), v6(4), v4(2)]);
    }

    #[test]
    fn interleave_all_ipv4() {
        let addrs = vec![v4(1), v4(2), v4(3)];
        let sorted = interleave_address_families(&addrs);
        // A single family must come back untouched, in input order.
        assert_eq!(sorted, addrs);
    }

    #[test]
    fn interleave_all_ipv6() {
        let addrs = vec![v6(1), v6(2)];
        let sorted = interleave_address_families(&addrs);
        assert_eq!(sorted, addrs);
    }

    #[test]
    fn interleave_single_address() {
        let addrs = vec![v4(1)];
        let sorted = interleave_address_families(&addrs);
        assert_eq!(sorted, addrs);
    }

    #[test]
    fn interleave_uneven_families() {
        let addrs = vec![v6(1), v6(2), v6(3), v6(4), v6(5), v4(10), v4(11)];
        let sorted = interleave_address_families(&addrs);
        // Alternate while both families last, then append the IPv6 leftovers.
        assert_eq!(
            sorted,
            vec![v6(1), v4(10), v6(2), v4(11), v6(3), v6(4), v6(5)]
        );
    }

    #[test]
    fn interleave_preserves_order_within_family() {
        let addrs = vec![v6(1), v6(2), v6(3), v4(10), v4(11)];
        let sorted = interleave_address_families(&addrs);
        let v6_sorted: Vec<_> = sorted.iter().filter(|a| a.is_ipv6()).collect();
        assert_eq!(v6_sorted[0].port(), 1);
        assert_eq!(v6_sorted[1].port(), 2);
        assert_eq!(v6_sorted[2].port(), 3);
        let v4_sorted: Vec<_> = sorted.iter().filter(|a| a.is_ipv4()).collect();
        assert_eq!(v4_sorted[0].port(), 10);
        assert_eq!(v4_sorted[1].port(), 11);
    }

    #[test]
    fn interleave_empty() {
        let sorted = interleave_address_families(&[]);
        assert!(sorted.is_empty());
    }

    #[test]
    fn combined_error_empty() {
        let err = build_combined_error(&[]);
        assert!(err.to_string().contains("all connection attempts failed"));
    }

    #[test]
    fn combined_error_single() {
        let err = build_combined_error(&[(v4(1), KrafkaError::timeout("test"))]);
        let msg = err.to_string();
        assert!(msg.contains("192.168.1.1:1"));
        // The single-failure form names the address it could not reach.
        assert!(msg.contains("connection to 192.168.1.1:1 failed"));
    }

    #[test]
    fn combined_error_multiple() {
        let err = build_combined_error(&[
            (v4(1), KrafkaError::timeout("t1")),
            (v6(2), KrafkaError::timeout("t2")),
        ]);
        let msg = err.to_string();
        assert!(msg.contains("2 connection attempts failed"));
        assert!(msg.contains("192.168.1.1:1"));
        // Every failed address must be listed, not only the first one.
        assert!(msg.contains("[2001:db8::2]:2"));
    }

    #[tokio::test]
    async fn connect_to_localhost() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let config = ConnectionConfig::default();
        let result = connect_happy_eyeballs(&addr.to_string(), &config).await;
        // The message arguments are only evaluated when the assertion fails,
        // so `unwrap_err` is never called on an `Ok`.
        assert!(
            result.is_ok(),
            "should connect to localhost: {:?}",
            result.unwrap_err()
        );
    }

    #[tokio::test]
    async fn connect_to_unreachable_fails() {
        let config = ConnectionConfig::builder()
            .connect_timeout(Duration::from_millis(200))
            .build();
        // 198.51.100.0/24 is TEST-NET-2, reserved for documentation (RFC 5737).
        let result = connect_happy_eyeballs("198.51.100.1:9092", &config).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn connect_single_address_fast_path() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let stream = connect_one(addr, &ConnectionConfig::default()).await;
        assert!(stream.is_ok());
    }

    #[tokio::test]
    async fn staggered_first_address_fast() {
        let l1 = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let l2 = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addrs = vec![l1.local_addr().unwrap(), l2.local_addr().unwrap()];
        let config = ConnectionConfig::default();
        let stream = staggered_connect(&addrs, &config).await;
        assert!(stream.is_ok());
    }
}