use std::future::Future;
use std::net::SocketAddr;
use std::time::Duration;
use thiserror::Error;
use tokio::task::JoinHandle;
use tracing::{debug, info, warn};
/// Which IP address family [`sort_addresses`] places at the front of the list.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AddressFamily {
    /// Lead with IPv6 addresses.
    IPv6Preferred,
    /// Lead with IPv4 addresses.
    IPv4Preferred,
}
/// Tuning knobs for [`sort_addresses`] and [`race_connect`].
#[derive(Clone, Debug)]
pub struct HappyEyeballsConfig {
    /// How long `race_connect` waits without hearing back from an attempt before
    /// starting the next address in parallel.
    pub connection_attempt_delay: Duration,
    /// Number of preferred-family addresses placed at the front before the two
    /// families start alternating.
    pub first_address_family_count: usize,
    /// The address family to lead with.
    pub preferred_family: AddressFamily,
}
impl Default for HappyEyeballsConfig {
    /// A 250 ms attempt delay, one leading preferred-family address, and IPv6
    /// as the preferred family.
    fn default() -> Self {
        let connection_attempt_delay = Duration::from_millis(250);
        Self {
            preferred_family: AddressFamily::IPv6Preferred,
            first_address_family_count: 1,
            connection_attempt_delay,
        }
    }
}
/// Errors returned by [`race_connect`].
#[derive(Debug, Error)]
pub enum HappyEyeballsError {
    /// The address list was empty, so no attempt could be made.
    #[error("no addresses provided for connection attempts")]
    NoAddresses,
    /// No attempt succeeded. `errors` lists `(address, error message)` pairs for
    /// the attempts that reported a failure.
    #[error(
        "all {count} connection attempts failed: {summary}",
        count = .errors.len(),
        summary = format_error_summary(.errors)
    )]
    AllAttemptsFailed {
        errors: Vec<(SocketAddr, String)>,
    },
    /// Connection racing timed out (not constructed by `race_connect` as written).
    #[error("connection racing timed out")]
    Timeout,
}
/// Renders each `(address, message)` pair as `"addr: message"` and joins the
/// pieces with `"; "`. An empty slice yields an empty string.
fn format_error_summary(errors: &[(SocketAddr, String)]) -> String {
    let mut summary = String::new();
    for (position, (addr, err)) in errors.iter().enumerate() {
        if position > 0 {
            summary.push_str("; ");
        }
        summary.push_str(&format!("{addr}: {err}"));
    }
    summary
}
/// Orders `addresses` for connection racing.
///
/// The input is split by address family, keeping the relative order within each
/// family. Up to `config.first_address_family_count` addresses of the preferred
/// family go first. The remaining addresses then alternate one-for-one, starting
/// with the non-preferred family. Whichever family has addresses left over once
/// the other runs out contributes them at the end, in their original order.
/// An empty input yields an empty vector.
pub fn sort_addresses(addresses: &[SocketAddr], config: &HappyEyeballsConfig) -> Vec<SocketAddr> {
    let wants_first = |addr: &SocketAddr| match config.preferred_family {
        AddressFamily::IPv6Preferred => addr.is_ipv6(),
        AddressFamily::IPv4Preferred => addr.is_ipv4(),
    };
    let (preferred, others): (Vec<SocketAddr>, Vec<SocketAddr>) =
        addresses.iter().copied().partition(|addr| wants_first(addr));

    let lead = config.first_address_family_count.min(preferred.len());
    let (head, rest) = preferred.split_at(lead);

    let mut ordered = Vec::with_capacity(addresses.len());
    ordered.extend_from_slice(head);
    for slot in 0..rest.len().max(others.len()) {
        // Each round contributes the non-preferred address first, then the
        // preferred one; an exhausted side simply contributes nothing.
        ordered.extend(others.get(slot));
        ordered.extend(rest.get(slot));
    }
    ordered
}
/// Outcome a single connection-attempt task reports back to the racer.
enum AttemptResult<C> {
    /// The attempt failed: the address tried and the error rendered as text.
    Failure(SocketAddr, String),
    /// The attempt connected: the connection and the address it reached.
    Success(C, SocketAddr),
}
/// Starts one connection attempt to `addr` on a new Tokio task.
///
/// `connect_fn(addr)` is called immediately to build the attempt's future. The
/// spawned task awaits that future and reports the outcome on `tx` as an
/// [`AttemptResult`], with errors rendered via `Display`. `attempt_num` is only
/// used for logging.
///
/// Returns the task's handle so the caller can abort it.
fn spawn_attempt<F, Fut, C, E>(
    addr: SocketAddr,
    attempt_num: usize,
    connect_fn: &F,
    tx: &tokio::sync::mpsc::UnboundedSender<AttemptResult<C>>,
) -> JoinHandle<()>
where
    F: Fn(SocketAddr) -> Fut,
    Fut: Future<Output = Result<C, E>> + Send + 'static,
    C: Send + 'static,
    E: std::fmt::Display + Send + 'static,
{
    debug!(addr = %addr, attempt = attempt_num, "Starting connection attempt");
    let attempt = connect_fn(addr);
    let reporter = tx.clone();
    tokio::spawn(async move {
        let outcome = match attempt.await {
            Ok(conn) => AttemptResult::Success(conn, addr),
            Err(err) => AttemptResult::Failure(addr, err.to_string()),
        };
        // A send error only means the receiving side is gone; the outcome is
        // then intentionally discarded.
        let _ = reporter.send(outcome);
    })
}
/// Races connection attempts to `addresses` using the Happy Eyeballs algorithm.
///
/// Addresses are ordered with [`sort_addresses`], and the first attempt starts
/// immediately. Each further attempt starts as soon as either of these happens:
///
/// * `config.connection_attempt_delay` elapses without any attempt reporting back;
/// * an in-flight attempt fails.
///
/// The first successful attempt wins, and every other attempt task is aborted.
///
/// `connect_fn` is called once per launched address to build that attempt's
/// future, which runs on its own Tokio task.
///
/// # Errors
///
/// * [`HappyEyeballsError::NoAddresses`] if `addresses` is empty.
/// * [`HappyEyeballsError::AllAttemptsFailed`] once every launched attempt has
///   finished without success. An attempt whose task ended without reporting
///   (for example because its connect future panicked) has no entry in `errors`.
///
/// # Cancellation
///
/// Dropping the returned future aborts every attempt still in flight.
pub async fn race_connect<F, Fut, C, E>(
    addresses: &[SocketAddr],
    config: &HappyEyeballsConfig,
    connect_fn: F,
) -> Result<(C, SocketAddr), HappyEyeballsError>
where
    F: Fn(SocketAddr) -> Fut,
    Fut: Future<Output = Result<C, E>> + Send + 'static,
    C: Send + 'static,
    E: std::fmt::Display + Send + 'static,
{
    if addresses.is_empty() {
        return Err(HappyEyeballsError::NoAddresses);
    }
    let sorted = sort_addresses(addresses, config);
    debug!(
        addresses = ?sorted,
        delay_ms = config.connection_attempt_delay.as_millis(),
        "Starting Happy Eyeballs connection racing"
    );
    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<AttemptResult<C>>();
    // Owns every attempt task. Dropping it aborts whatever is still running,
    // whether on return or because the caller dropped this future.
    let mut launcher = AttemptLauncher {
        sorted: &sorted,
        next_index: 0,
        tx: Some(tx),
        handles: Vec::with_capacity(sorted.len()),
    };
    let mut errors: Vec<(SocketAddr, String)> = Vec::new();
    launcher.launch_next(&connect_fn);
    loop {
        let received = if launcher.peek().is_some() {
            tokio::select! {
                biased;
                received = rx.recv() => received,
                _ = tokio::time::sleep(config.connection_attempt_delay) => {
                    if let Some((addr, attempt)) = launcher.peek() {
                        debug!(
                            addr = %addr,
                            attempt = attempt,
                            "Starting parallel attempt after delay"
                        );
                    }
                    launcher.launch_next(&connect_fn);
                    continue;
                }
            }
        } else {
            // Every address is launched and the launcher has dropped its own
            // sender. This yields `None` once all attempt tasks have finished,
            // including tasks that ended without sending anything.
            rx.recv().await
        };
        match received {
            Some(AttemptResult::Success(conn, addr)) => {
                info!(addr = %addr, "Happy Eyeballs: connection succeeded");
                // `launcher` is dropped on return, aborting the losing attempts.
                return Ok((conn, addr));
            }
            Some(AttemptResult::Failure(addr, err)) => {
                warn!(addr = %addr, error = %err, "Connection attempt failed");
                errors.push((addr, err));
                // On a failure, start the next address now rather than waiting
                // out the delay (no-op if everything is already launched).
                launcher.launch_next(&connect_fn);
            }
            None => break,
        }
    }
    Err(HappyEyeballsError::AllAttemptsFailed { errors })
}

/// Attempt bookkeeping for [`race_connect`].
///
/// Launches addresses from `sorted` in order. After the last launch it drops its
/// own channel sender, so the receiver observes the channel closing once every
/// attempt task has finished. Dropping the launcher aborts every task it started.
struct AttemptLauncher<'a, C> {
    /// Addresses in connection order.
    sorted: &'a [SocketAddr],
    /// Index into `sorted` of the next address to launch.
    next_index: usize,
    /// Sender cloned into each attempt task; `None` once everything is launched.
    tx: Option<tokio::sync::mpsc::UnboundedSender<AttemptResult<C>>>,
    /// Handles of all attempt tasks launched so far.
    handles: Vec<JoinHandle<()>>,
}

impl<C> AttemptLauncher<'_, C> {
    /// Returns the next address to launch and its 1-based attempt number, if any.
    fn peek(&self) -> Option<(SocketAddr, usize)> {
        self.sorted
            .get(self.next_index)
            .map(|&addr| (addr, self.next_index + 1))
    }

    /// Launches the next pending address. Does nothing once all have been launched.
    fn launch_next<F, Fut, E>(&mut self, connect_fn: &F)
    where
        F: Fn(SocketAddr) -> Fut,
        Fut: Future<Output = Result<C, E>> + Send + 'static,
        C: Send + 'static,
        E: std::fmt::Display + Send + 'static,
    {
        if let (Some((addr, attempt)), Some(tx)) = (self.peek(), self.tx.as_ref()) {
            self.handles
                .push(spawn_attempt(addr, attempt, connect_fn, tx));
            self.next_index += 1;
        }
        if self.next_index >= self.sorted.len() {
            // No more launches: release our sender so the channel can close.
            self.tx = None;
        }
    }
}

impl<C> Drop for AttemptLauncher<'_, C> {
    fn drop(&mut self) {
        abort_all(&self.handles);
    }
}
fn abort_all(handles: &[JoinHandle<()>]) {
for handle in handles {
handle.abort();
}
}
// Unit tests for address ordering (`sort_addresses`), connection racing
// (`race_connect`), configuration defaults and error formatting.
#[cfg(test)]
mod tests {
#![allow(clippy::unwrap_used)]
use super::*;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
// Parse helpers. Both simply call `str::parse`; the name only documents which
// family the literal is meant to be.
fn v4(s: &str) -> SocketAddr {
s.parse().unwrap()
}
fn v6(s: &str) -> SocketAddr {
s.parse().unwrap()
}
// IPv6 preferred, first-family count 1: one IPv6 address leads, then the
// families alternate starting with IPv4, and the leftover IPv4 trails.
#[test]
fn test_sort_ipv6_preferred() {
let addrs = vec![
v4("192.168.1.1:80"), v6("[::1]:80"), v4("192.168.1.2:80"), v6("[::2]:80"), v4("192.168.1.3:80"), ];
let config = HappyEyeballsConfig {
preferred_family: AddressFamily::IPv6Preferred,
first_address_family_count: 1,
..Default::default()
};
let sorted = sort_addresses(&addrs, &config);
assert_eq!(sorted.len(), 5);
// Expected order: ::1, .1.1, ::2, .1.2, .1.3
assert_eq!(sorted[0], v6("[::1]:80")); assert_eq!(sorted[1], v4("192.168.1.1:80")); assert_eq!(sorted[2], v6("[::2]:80")); assert_eq!(sorted[3], v4("192.168.1.2:80")); assert_eq!(sorted[4], v4("192.168.1.3:80")); }
// Mirror case: IPv4 preferred leads with IPv4, then alternates with IPv6.
#[test]
fn test_sort_ipv4_preferred() {
let addrs = vec![
v6("[::1]:80"),
v4("10.0.0.1:80"),
v6("[::2]:80"),
v4("10.0.0.2:80"),
];
let config = HappyEyeballsConfig {
preferred_family: AddressFamily::IPv4Preferred,
first_address_family_count: 1,
..Default::default()
};
let sorted = sort_addresses(&addrs, &config);
assert_eq!(sorted.len(), 4);
// Expected order: 10.0.0.1, ::1, 10.0.0.2, ::2
assert_eq!(sorted[0], v4("10.0.0.1:80")); assert_eq!(sorted[1], v6("[::1]:80")); assert_eq!(sorted[2], v4("10.0.0.2:80")); assert_eq!(sorted[3], v6("[::2]:80")); }
// Only the non-preferred family is present (default prefers IPv6): the input
// order is kept unchanged.
#[test]
fn test_sort_single_family() {
let addrs = vec![v4("10.0.0.1:80"), v4("10.0.0.2:80"), v4("10.0.0.3:80")];
let config = HappyEyeballsConfig::default();
let sorted = sort_addresses(&addrs, &config);
assert_eq!(sorted.len(), 3);
assert_eq!(sorted[0], v4("10.0.0.1:80"));
assert_eq!(sorted[1], v4("10.0.0.2:80"));
assert_eq!(sorted[2], v4("10.0.0.3:80"));
}
#[test]
fn test_sort_empty() {
let addrs: Vec<SocketAddr> = vec![];
let config = HappyEyeballsConfig::default();
let sorted = sort_addresses(&addrs, &config);
assert!(sorted.is_empty());
}
// First-family count 2: two IPv6 addresses lead before alternation begins.
#[test]
fn test_sort_first_count_two() {
let addrs = vec![
v4("10.0.0.1:80"),
v6("[::1]:80"),
v4("10.0.0.2:80"),
v6("[::2]:80"),
v6("[::3]:80"),
];
let config = HappyEyeballsConfig {
preferred_family: AddressFamily::IPv6Preferred,
first_address_family_count: 2,
..Default::default()
};
let sorted = sort_addresses(&addrs, &config);
assert_eq!(sorted.len(), 5);
assert_eq!(sorted[0], v6("[::1]:80"));
assert_eq!(sorted[1], v6("[::2]:80"));
assert_eq!(sorted[2], v4("10.0.0.1:80"));
assert_eq!(sorted[3], v6("[::3]:80"));
assert_eq!(sorted[4], v4("10.0.0.2:80"));
}
#[tokio::test]
async fn test_race_single_address_success() {
let addrs = vec![v4("10.0.0.1:80")];
let config = HappyEyeballsConfig::default();
let result = race_connect(&addrs, &config, |addr| async move {
Ok::<_, String>(format!("connected to {addr}"))
})
.await;
let (conn, addr) = result.unwrap();
assert_eq!(conn, "connected to 10.0.0.1:80");
assert_eq!(addr, v4("10.0.0.1:80"));
}
// The first attempt succeeds well before the 500 ms delay, so the second
// address must never be tried. The counter is bumped inside the connect future,
// so it counts attempts that were actually polled.
#[tokio::test]
async fn test_race_first_succeeds_fast() {
let attempt_count = Arc::new(AtomicUsize::new(0));
let attempt_count_clone = Arc::clone(&attempt_count);
let addrs = vec![v6("[::1]:80"), v4("10.0.0.1:80")];
let config = HappyEyeballsConfig {
connection_attempt_delay: Duration::from_millis(500),
..Default::default()
};
let result = race_connect(&addrs, &config, move |addr| {
let count = Arc::clone(&attempt_count_clone);
async move {
count.fetch_add(1, Ordering::SeqCst);
Ok::<_, String>(format!("connected to {addr}"))
}
})
.await;
let (conn, addr) = result.unwrap();
assert_eq!(conn, "connected to [::1]:80");
assert_eq!(addr, v6("[::1]:80"));
assert_eq!(attempt_count.load(Ordering::SeqCst), 1);
}
// With a 10 s delay, the second address can only be reached promptly because
// the first attempt's failure triggers it.
#[tokio::test]
async fn test_race_first_fails_second_succeeds() {
let addrs = vec![v6("[::1]:80"), v4("10.0.0.1:80")];
let config = HappyEyeballsConfig {
connection_attempt_delay: Duration::from_secs(10), ..Default::default()
};
let result = race_connect(&addrs, &config, |addr| async move {
if addr == v6("[::1]:80") {
Err("connection refused".to_string())
} else {
Ok(format!("connected to {addr}"))
}
})
.await;
let (conn, addr) = result.unwrap();
assert_eq!(conn, "connected to 10.0.0.1:80");
assert_eq!(addr, v4("10.0.0.1:80"));
}
// The IPv6 attempt takes 2 s. After the 50 ms delay the IPv4 attempt starts in
// parallel and wins after about 10 ms.
#[tokio::test]
async fn test_race_slow_first_fast_second() {
let addrs = vec![v6("[::1]:80"), v4("10.0.0.1:80")];
let config = HappyEyeballsConfig {
connection_attempt_delay: Duration::from_millis(50),
..Default::default()
};
let result = race_connect(&addrs, &config, |addr| async move {
if addr == v6("[::1]:80") {
tokio::time::sleep(Duration::from_secs(2)).await;
Ok::<_, String>(format!("connected to {addr}"))
} else {
tokio::time::sleep(Duration::from_millis(10)).await;
Ok(format!("connected to {addr}"))
}
})
.await;
let (conn, addr) = result.unwrap();
assert_eq!(conn, "connected to 10.0.0.1:80");
assert_eq!(addr, v4("10.0.0.1:80"));
}
// Every attempt fails, so the aggregated error must contain one entry per address.
#[tokio::test]
async fn test_race_all_fail() {
let addrs = vec![v6("[::1]:80"), v4("10.0.0.1:80"), v4("10.0.0.2:80")];
let config = HappyEyeballsConfig {
connection_attempt_delay: Duration::from_millis(10),
..Default::default()
};
let result = race_connect(&addrs, &config, |addr| async move {
Err::<String, _>(format!("failed to connect to {addr}"))
})
.await;
match result {
Err(HappyEyeballsError::AllAttemptsFailed { errors }) => {
assert_eq!(errors.len(), 3, "Expected 3 errors, got {}", errors.len());
let addrs_in_errors: Vec<SocketAddr> =
errors.iter().map(|(addr, _)| *addr).collect();
assert!(addrs_in_errors.contains(&v6("[::1]:80")));
assert!(addrs_in_errors.contains(&v4("10.0.0.1:80")));
assert!(addrs_in_errors.contains(&v4("10.0.0.2:80")));
}
other => panic!("Expected AllAttemptsFailed, got: {other:?}"),
}
}
#[tokio::test]
async fn test_race_empty_addresses() {
let addrs: Vec<SocketAddr> = vec![];
let config = HappyEyeballsConfig::default();
let result = race_connect(&addrs, &config, |addr| async move {
Ok::<_, String>(format!("connected to {addr}"))
})
.await;
match result {
Err(HappyEyeballsError::NoAddresses) => {} other => panic!("Expected NoAddresses, got: {other:?}"),
}
}
#[test]
fn test_default_config() {
let config = HappyEyeballsConfig::default();
assert_eq!(config.connection_attempt_delay, Duration::from_millis(250));
assert_eq!(config.preferred_family, AddressFamily::IPv6Preferred);
assert_eq!(config.first_address_family_count, 1);
}
// Records when each attempt's future first runs. With a 60 s delay, the second
// attempt starting within 500 ms shows it was triggered by the first failure
// rather than by the timer.
#[tokio::test]
async fn test_race_immediate_failure_triggers_next() {
let attempt_times = Arc::new(tokio::sync::Mutex::new(Vec::new()));
let attempt_times_clone = Arc::clone(&attempt_times);
let addrs = vec![v6("[::1]:80"), v4("10.0.0.1:80")];
let config = HappyEyeballsConfig {
connection_attempt_delay: Duration::from_secs(60),
..Default::default()
};
let start = tokio::time::Instant::now();
let result = race_connect(&addrs, &config, move |addr| {
let times = Arc::clone(&attempt_times_clone);
let start_time = start;
async move {
{
let mut t = times.lock().await;
t.push((addr, start_time.elapsed()));
}
if addr == v6("[::1]:80") {
Err("connection refused".to_string())
} else {
Ok(format!("connected to {addr}"))
}
}
})
.await;
let (conn, _addr) = result.unwrap();
assert_eq!(conn, "connected to 10.0.0.1:80");
let times = attempt_times.lock().await;
assert_eq!(times.len(), 2);
let second_start = times[1].1;
assert!(
second_start < Duration::from_millis(500),
"Second attempt took too long to start: {second_start:?} (expected < 500ms, \
indicating failure-triggered immediate start)"
);
}
// Timeline: ::1 starts first (10 s), 10.0.0.1 starts after the 10 ms delay and
// succeeds after 50 ms, and 10.0.0.2 starts after another delay (10 s). After
// the winner returns, waiting a further 100 ms must not see any other attempt
// complete, because the losers were aborted.
#[tokio::test]
async fn test_race_cancels_remaining_on_success() {
let completed = Arc::new(AtomicUsize::new(0));
let completed_clone = Arc::clone(&completed);
let addrs = vec![v6("[::1]:80"), v4("10.0.0.1:80"), v4("10.0.0.2:80")];
let config = HappyEyeballsConfig {
connection_attempt_delay: Duration::from_millis(10),
..Default::default()
};
let result = race_connect(&addrs, &config, move |addr| {
let done = Arc::clone(&completed_clone);
async move {
if addr == v4("10.0.0.1:80") {
tokio::time::sleep(Duration::from_millis(50)).await;
done.fetch_add(1, Ordering::SeqCst);
Ok::<_, String>(format!("connected to {addr}"))
} else {
tokio::time::sleep(Duration::from_secs(10)).await;
done.fetch_add(1, Ordering::SeqCst);
Ok(format!("connected to {addr}"))
}
}
})
.await;
let (_conn, addr) = result.unwrap();
assert_eq!(addr, v4("10.0.0.1:80"));
tokio::time::sleep(Duration::from_millis(100)).await;
assert_eq!(completed.load(Ordering::SeqCst), 1);
}
// Display text of every error variant, including the "addr: message" summary.
#[test]
fn test_error_display() {
let err = HappyEyeballsError::NoAddresses;
assert_eq!(
err.to_string(),
"no addresses provided for connection attempts"
);
let err = HappyEyeballsError::Timeout;
assert_eq!(err.to_string(), "connection racing timed out");
let err = HappyEyeballsError::AllAttemptsFailed {
errors: vec![
(v4("10.0.0.1:80"), "refused".to_string()),
(v6("[::1]:80"), "timeout".to_string()),
],
};
let display = err.to_string();
assert!(display.contains("10.0.0.1:80: refused"));
assert!(display.contains("[::1]:80: timeout"));
}
}