// ugi 0.2.1
//
// Runtime-agnostic Rust request client with HTTP/1.1, HTTP/2, HTTP/3, H2C, WebSocket, SSE, and gRPC support.
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;

use async_net::TcpStream;
use futures_lite::future;
use socket2::{Domain, Protocol, Socket, Type};

use crate::error::{Error, ErrorKind, Result};

/// Connect to the first successful address, trying all candidates sequentially.
/// TCP_NODELAY is set on every connection to disable Nagle's algorithm.
pub(crate) async fn connect_tcp_candidates(
    addrs: impl IntoIterator<Item = SocketAddr>,
    local_addr: Option<SocketAddr>,
    timeout: Option<Duration>,
) -> Result<TcpStream> {
    let mut last_error = None;
    for addr in addrs {
        if let Some(local_addr) = local_addr {
            if local_addr.is_ipv4() != addr.is_ipv4() {
                continue;
            }
        }

        match connect_tcp_addr(addr, local_addr, timeout).await {
            Ok(stream) => return Ok(stream),
            Err(err) => last_error = Some(err),
        }
    }

    Err(last_error.unwrap_or_else(|| Error::new(ErrorKind::Transport, "connect failed")))
}

/// Connect to two candidate address lists concurrently (Happy Eyeballs, RFC 6555).
///
/// Races the primary list (IPv6) against the fallback list (IPv4) with a 250 ms
/// head-start for the primary.  The first *successful* connection wins.  Only
/// if both lists are exhausted does this return an error.
/// Falls back to sequential if either list is empty.
pub(crate) async fn connect_happy_eyeballs(
    primary: Vec<SocketAddr>,
    fallback: Vec<SocketAddr>,
    local_addr: Option<SocketAddr>,
    timeout: Option<Duration>,
) -> Result<TcpStream> {
    if primary.is_empty() {
        return connect_tcp_candidates(fallback, local_addr, timeout).await;
    }
    if fallback.is_empty() {
        return connect_tcp_candidates(primary, local_addr, timeout).await;
    }

    // Give IPv6 a 250 ms head-start (RFC 6555 ยง5.4).
    // Race both: return whichever *succeeds* first.
    // If primary fails quickly (e.g. ECONNREFUSED on ::1), the fallback
    // continues and its result (success or failure) is returned.
    let primary_fut = connect_tcp_candidates(primary, local_addr, timeout);
    let fallback_fut = async move {
        async_io::Timer::after(Duration::from_millis(250)).await;
        connect_tcp_candidates(fallback, local_addr, timeout).await
    };

    race_first_ok(primary_fut, fallback_fut).await
}

/// Race two fallible futures and resolve with the first `Ok` value.
///
/// Both futures are driven concurrently on the calling task. On every wake-up
/// `f1` is polled before `f2`, so `f1` wins if both become ready at once. A
/// future that resolves to `Err` is parked and never polled again, while the
/// other one keeps running. If both fail, the error from the second (fallback)
/// future is returned.
async fn race_first_ok<T, E, F1, F2>(f1: F1, f2: F2) -> std::result::Result<T, E>
where
    F1: std::future::Future<Output = std::result::Result<T, E>>,
    F2: std::future::Future<Output = std::result::Result<T, E>>,
{
    // Pin both futures inside this async fn's own frame. This makes the
    // hand-written, `unsafe` pin projection of a custom combinator unnecessary.
    let mut f1: Pin<&mut F1> = std::pin::pin!(f1);
    let mut f2: Pin<&mut F2> = std::pin::pin!(f2);

    // Becomes `Some(err)` once the corresponding future has failed. A completed
    // future must not be polled again, so these double as "done" flags.
    let mut f1_err: Option<E> = None;
    let mut f2_err: Option<E> = None;

    std::future::poll_fn(|cx: &mut Context<'_>| {
        if f1_err.is_none() {
            if let Poll::Ready(result) = std::future::Future::poll(f1.as_mut(), cx) {
                match result {
                    Ok(value) => return Poll::Ready(Ok(value)),
                    Err(err) => f1_err = Some(err),
                }
            }
        }
        if f2_err.is_none() {
            if let Poll::Ready(result) = std::future::Future::poll(f2.as_mut(), cx) {
                match result {
                    Ok(value) => return Poll::Ready(Ok(value)),
                    Err(err) => f2_err = Some(err),
                }
            }
        }
        // Both failed: report the fallback (f2) error.
        if f1_err.is_some() {
            if let Some(err) = f2_err.take() {
                return Poll::Ready(Err(err));
            }
        }
        Poll::Pending
    })
    .await
}

/// Connect to a single address and set TCP_NODELAY to disable Nagle's algorithm.
pub(crate) async fn connect_tcp_addr(
    addr: SocketAddr,
    local_addr: Option<SocketAddr>,
    timeout: Option<Duration>,
) -> Result<TcpStream> {
    if let Some(local) = local_addr {
        // Need to bind a local address: use socket2 to create + configure,
        // then perform an async connect via async-io.
        let domain = if addr.is_ipv4() {
            Domain::IPV4
        } else {
            Domain::IPV6
        };
        let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP)).map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to create tcp socket", err)
        })?;
        socket.set_nodelay(true).map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to set TCP_NODELAY", err)
        })?;
        socket.bind(&local.into()).map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to bind local tcp socket", err)
        })?;
        socket.set_nonblocking(true).map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to set non-blocking", err)
        })?;

        // Initiate non-blocking connect; EINPROGRESS is expected and handled
        // by the async-io wrapper when we poll for writability.
        let _ = socket.connect(&addr.into()); // EINPROGRESS is OK

        let std_stream = std::net::TcpStream::from(socket);
        let stream = async_io::Async::new(std_stream).map_err(|err| {
            Error::with_source(ErrorKind::Transport, "failed to adopt tcp socket", err)
        })?;

        // Wait for the socket to become writable (connect completes).
        let connect_fut = async {
            stream.writable().await.map_err(|err| {
                Error::with_source(ErrorKind::Transport, "tcp connect failed", err)
            })?;
            // Check for deferred connect error via SO_ERROR.
            let so_error = stream.get_ref().take_error().map_err(|err| {
                Error::with_source(ErrorKind::Transport, "tcp connect failed", err)
            })?;
            if let Some(err) = so_error {
                return Err(Error::with_source(
                    ErrorKind::Transport,
                    "tcp connect failed",
                    err,
                ));
            }
            // Convert to async_net::TcpStream.
            let inner = stream.into_inner().map_err(|err| {
                Error::with_source(ErrorKind::Transport, "failed to finalize tcp stream", err)
            })?;
            async_net::TcpStream::try_from(inner).map_err(|err| {
                Error::with_source(ErrorKind::Transport, "failed to adopt tcp stream", err)
            })
        };

        return apply_timeout(connect_fut, timeout).await;
    }

    // No local bind: use async_net directly (simplest async connect path).
    let connect_fut = async move {
        async_net::TcpStream::connect(addr)
            .await
            .map_err(|err| Error::with_source(ErrorKind::Transport, "tcp connect failed", err))
    };
    let stream = apply_timeout(connect_fut, timeout).await?;
    stream.set_nodelay(true).map_err(|err| {
        Error::with_source(ErrorKind::Transport, "failed to set TCP_NODELAY", err)
    })?;
    Ok(stream)
}

/// Await `fut`, failing with a transport error if it does not finish within `timeout`.
///
/// With `None`, `fut` is awaited with no limit. With `Some(limit)`, `fut` is
/// raced against an `async_io` timer of that length. If the timer wins, the
/// result is a "tcp connect timed out" error of kind `ErrorKind::Transport`.
async fn apply_timeout<T>(
    fut: impl future::Future<Output = Result<T>>,
    timeout: Option<Duration>,
) -> Result<T> {
    let limit = match timeout {
        Some(limit) => limit,
        None => return fut.await,
    };

    // Lazily started: the timer is only created once `future::or` polls it.
    let expiry = async move {
        async_io::Timer::after(limit).await;
        Err(Error::new(ErrorKind::Transport, "tcp connect timed out"))
    };
    futures_lite::pin!(fut);
    future::or(fut, expiry).await
}