// datum-net 0.6.0
//
// Network sources and sinks for Datum streams, built on datum-core.
// Documentation
//! Connection lifecycle utilities for `datum-net` transports.
//!
//! The lifecycle layer keeps connection establishment lazy: TCP connect, TLS
//! handshake, timeout handling, and retry attempts start only when the returned
//! Datum flow is materialized and pulled. Completing the upstream side of a
//! connection byte flow gracefully shuts down the write direction while leaving
//! the read direction open for the peer's response.

use crate::tls::{TlsConnection, TokioTls, rustls, tls_flow_from_stream};
use datum::{Flow, StreamCompletion, StreamError, StreamResult};
use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::{TcpStream, ToSocketAddrs};
use tokio::runtime::Handle;
use tokio::time::{sleep, timeout};
use tokio_rustls::TlsConnector;
use tokio_rustls::rustls::pki_types::ServerName;

/// Chunk size in bytes (8 KiB) handed to the TLS flow when the caller does not choose one.
const DEFAULT_CHUNK_SIZE: usize = 8192;
/// Time limit for one TCP connect attempt in `ConnectionSettings::default()`.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Time limit for one TLS handshake in `ConnectionSettings::default()`.
const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(30);
/// Delay after the first failed attempt in `RetryPolicy::default()`.
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(100);
/// Largest delay between two attempts in `RetryPolicy::default()`.
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(5);
/// Growth factor applied to the delay per failed attempt in `RetryPolicy::default()`.
const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;

/// Retry settings for connection establishment.
///
/// `max_attempts` counts the initial attempt. Values below one are treated as
/// one attempt so direct field construction cannot create a zero-attempt
/// connection.
///
/// The delay after failed attempt `n` is
/// `initial_backoff * backoff_multiplier^(n - 1)`, limited to `max_backoff`.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryPolicy {
    /// Total number of connection attempts, including the first one.
    pub max_attempts: usize,
    /// Delay after the first failed attempt. A zero value disables all
    /// backoff delays.
    pub initial_backoff: Duration,
    /// Factor by which the delay grows after each further failed attempt.
    /// Non-finite values and values below `1.0` are treated as `1.0`.
    pub backoff_multiplier: f64,
    /// Upper bound for a single delay. A zero value disables all backoff
    /// delays.
    pub max_backoff: Duration,
}

impl Default for RetryPolicy {
    /// Returns a policy that makes a single attempt (no retries) and carries
    /// the module's default exponential backoff parameters.
    fn default() -> Self {
        // One attempt means retries are strictly opt-in.
        let max_attempts = 1;
        Self {
            max_attempts,
            backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
            initial_backoff: DEFAULT_INITIAL_BACKOFF,
            max_backoff: DEFAULT_MAX_BACKOFF,
        }
    }
}

impl RetryPolicy {
    /// Creates a retry policy with one attempt and exponential backoff defaults.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the total number of attempts, counting the initial one.
    ///
    /// Zero is raised to one so the policy always performs at least one
    /// attempt.
    #[must_use]
    pub fn max_attempts(mut self, max_attempts: usize) -> Self {
        self.max_attempts = max_attempts.max(1);
        self
    }

    /// Sets the delay used after the first failed attempt.
    ///
    /// A zero duration disables backoff delays altogether.
    #[must_use]
    pub fn initial_backoff(mut self, initial_backoff: Duration) -> Self {
        self.initial_backoff = initial_backoff;
        self
    }

    /// Sets the factor by which the delay grows after each failed attempt.
    ///
    /// Non-finite values and values below `1.0` are replaced by `1.0`.
    #[must_use]
    pub fn backoff_multiplier(mut self, backoff_multiplier: f64) -> Self {
        self.backoff_multiplier = sane_multiplier(backoff_multiplier);
        self
    }

    /// Sets the upper bound for a single backoff delay.
    ///
    /// A zero duration disables backoff delays altogether.
    #[must_use]
    pub fn max_backoff(mut self, max_backoff: Duration) -> Self {
        self.max_backoff = max_backoff;
        self
    }

    /// Number of attempts to make; never less than one, even when the field
    /// was set to zero directly.
    fn attempts(&self) -> usize {
        self.max_attempts.max(1)
    }

    /// Returns the delay to wait after failed attempt number `attempt`
    /// (1-based) before trying again.
    ///
    /// The result is `initial_backoff * multiplier^(attempt - 1)`, never more
    /// than `max_backoff`, and `Duration::ZERO` when either bound is zero.
    /// This function does not panic, even for extreme settings such as
    /// `max_backoff == Duration::MAX`.
    fn backoff_after_attempt(&self, attempt: usize) -> Duration {
        if self.initial_backoff.is_zero() || self.max_backoff.is_zero() {
            return Duration::ZERO;
        }

        // Re-sanitize: the field is public and may bypass the builder.
        let multiplier = sane_multiplier(self.backoff_multiplier);
        // Clamping to 32 keeps the cast to `i32` lossless and bounds the power.
        let exponent = attempt.saturating_sub(1).min(32) as i32;
        let delay_secs = self.initial_backoff.as_secs_f64() * multiplier.powi(exponent);
        let capped_secs = delay_secs.min(self.max_backoff.as_secs_f64());

        // `Duration::MAX.as_secs_f64()` rounds up to 2^64, which does not fit
        // back into a `Duration`; the panicking `from_secs_f64` would abort
        // the retry loop there. A failed conversion can only mean the delay
        // reached the cap, so fall back to the cap. The final `min` also
        // removes any float rounding that lands slightly above the cap.
        Duration::try_from_secs_f64(capped_secs)
            .map_or(self.max_backoff, |delay| delay.min(self.max_backoff))
    }
}

/// Connection establishment settings shared by lifecycle-aware transports.
#[derive(Debug, Clone, PartialEq)]
pub struct ConnectionSettings {
    /// Time limit for each TCP connect attempt; `None` waits without a limit.
    pub connect_timeout: Option<Duration>,
    /// Time limit for each TLS handshake; `None` waits without a limit.
    pub handshake_timeout: Option<Duration>,
    /// How often, and with which delays, a failed establishment is retried.
    pub retry_policy: RetryPolicy,
}

impl Default for ConnectionSettings {
    /// Returns settings with both timeouts enabled at their default limits and
    /// the default retry policy.
    fn default() -> Self {
        let connect_timeout = Some(DEFAULT_CONNECT_TIMEOUT);
        let handshake_timeout = Some(DEFAULT_HANDSHAKE_TIMEOUT);
        Self {
            connect_timeout,
            handshake_timeout,
            retry_policy: RetryPolicy::default(),
        }
    }
}

impl ConnectionSettings {
    /// Creates lifecycle settings with bounded connect and TLS handshake time.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Limits each TCP connect attempt to `connect_timeout`.
    #[must_use]
    pub fn connect_timeout(self, connect_timeout: Duration) -> Self {
        Self {
            connect_timeout: Some(connect_timeout),
            ..self
        }
    }

    /// Removes the time limit on TCP connect attempts.
    #[must_use]
    pub fn without_connect_timeout(self) -> Self {
        Self {
            connect_timeout: None,
            ..self
        }
    }

    /// Limits each TLS handshake to `handshake_timeout`.
    #[must_use]
    pub fn handshake_timeout(self, handshake_timeout: Duration) -> Self {
        Self {
            handshake_timeout: Some(handshake_timeout),
            ..self
        }
    }

    /// Removes the time limit on TLS handshakes.
    #[must_use]
    pub fn without_handshake_timeout(self) -> Self {
        Self {
            handshake_timeout: None,
            ..self
        }
    }

    /// Replaces the retry policy used for connection establishment.
    #[must_use]
    pub fn retry_policy(self, retry_policy: RetryPolicy) -> Self {
        Self {
            retry_policy,
            ..self
        }
    }
}

/// Namespace for transport-agnostic lifecycle constructors.
///
/// The type holds no data; its associated functions forward to the
/// transport-specific constructors and to [`ConnectionLifecycleExt`].
pub struct Connection;

impl Connection {
    /// Opens a lifecycle-aware TLS client connection with the default chunk size.
    ///
    /// All arguments are passed unchanged to
    /// [`TokioTls::outgoing_connection_with_lifecycle`]; see that function for
    /// timeout, retry, and error behavior.
    #[must_use]
    pub fn tls_client<A>(
        addr: A,
        server_name: ServerName<'static>,
        client_config: Arc<rustls::ClientConfig>,
        settings: ConnectionSettings,
    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        TokioTls::outgoing_connection_with_lifecycle(addr, server_name, client_config, settings)
    }

    /// Marks a connection flow as using graceful half-close on upstream finish.
    ///
    /// Datum TCP/TLS connection flows already map upstream completion to
    /// `AsyncWriteExt::shutdown()` and keep the read side alive. This helper is
    /// an explicit API affordance for that behavior when a call site wants to
    /// state the lifecycle intent.
    ///
    /// Returns the result of calling
    /// `graceful_shutdown_on_upstream_finish` on `flow`.
    #[must_use]
    pub fn graceful_shutdown<Mat>(flow: Flow<Vec<u8>, Vec<u8>, Mat>) -> Flow<Vec<u8>, Vec<u8>, Mat>
    where
        Mat: Send + 'static,
    {
        flow.graceful_shutdown_on_upstream_finish()
    }

    /// Alias for [`Connection::graceful_shutdown`].
    ///
    /// Provided for call sites that prefer the "half-close" terminology.
    #[must_use]
    pub fn half_close<Mat>(flow: Flow<Vec<u8>, Vec<u8>, Mat>) -> Flow<Vec<u8>, Vec<u8>, Mat>
    where
        Mat: Send + 'static,
    {
        Self::graceful_shutdown(flow)
    }
}

/// Extension methods for connection byte flows.
///
/// `Mat` is the third type parameter of the `Flow` the trait is implemented
/// for in this module.
pub trait ConnectionLifecycleExt<Mat> {
    /// Makes the half-close behavior explicit at the call site.
    ///
    /// Completing the upstream side of Datum TCP/TLS connection flows shuts
    /// down the write direction and keeps the read direction alive. The method
    /// returns the original flow because the transport sink already performs
    /// the shutdown.
    #[must_use]
    fn graceful_shutdown_on_upstream_finish(self) -> Self;

    /// Alias for [`ConnectionLifecycleExt::graceful_shutdown_on_upstream_finish`].
    ///
    /// The default implementation simply forwards to that method.
    #[must_use]
    fn half_close_on_upstream_finish(self) -> Self
    where
        Self: Sized,
    {
        self.graceful_shutdown_on_upstream_finish()
    }
}

// Lifecycle extension for flows whose first two type parameters are `Vec<u8>`.
impl<Mat> ConnectionLifecycleExt<Mat> for Flow<Vec<u8>, Vec<u8>, Mat>
where
    Mat: Send + 'static,
{
    // Returns the flow unchanged: this implementation adds nothing to it and
    // exists only so call sites can state the half-close intent.
    fn graceful_shutdown_on_upstream_finish(self) -> Self {
        self
    }
}

impl TokioTls {
    /// Builds a lifecycle-aware TLS client connection flow that uses the
    /// default 8 KiB chunk size.
    ///
    /// Each TCP connect and TLS handshake is limited by the timeouts in
    /// [`ConnectionSettings`], and failed attempts are repeated according to
    /// its [`RetryPolicy`]. A timeout or the failure of the last attempt is
    /// reported as a [`StreamError`] through the materialized
    /// [`StreamCompletion`] and through the stream.
    #[must_use]
    pub fn outgoing_connection_with_lifecycle<A>(
        addr: A,
        server_name: ServerName<'static>,
        client_config: Arc<rustls::ClientConfig>,
        settings: ConnectionSettings,
    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        let chunk_size = DEFAULT_CHUNK_SIZE;
        Self::outgoing_connection_with_lifecycle_and_chunk_size(
            addr,
            server_name,
            client_config,
            settings,
            chunk_size,
        )
    }

    /// Builds a lifecycle-aware TLS client connection flow with a
    /// caller-chosen chunk size.
    ///
    /// Timeout, retry, and error behavior match
    /// [`TokioTls::outgoing_connection_with_lifecycle`].
    ///
    /// # Panics
    ///
    /// Panics immediately if `chunk_size` is zero. The connection future also
    /// calls `Handle::current()` when it runs, so it is expected to be driven
    /// from within a Tokio runtime.
    #[must_use]
    pub fn outgoing_connection_with_lifecycle_and_chunk_size<A>(
        addr: A,
        server_name: ServerName<'static>,
        client_config: Arc<rustls::ClientConfig>,
        settings: ConnectionSettings,
        chunk_size: usize,
    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        assert!(chunk_size > 0, "chunk size must be greater than zero");
        Flow::future_flow(move || {
            // Work on clones so the closure keeps its own captured originals.
            let (addr, server_name, settings) =
                (addr.clone(), server_name.clone(), settings.clone());
            let client_config = Arc::clone(&client_config);
            async move {
                retry_tls_client_connect(
                    addr,
                    server_name,
                    client_config,
                    settings,
                    Handle::current(),
                    chunk_size,
                )
                .await
            }
        })
    }
}

/// Establishes a TLS client connection, retrying as the retry policy in
/// `settings` prescribes.
///
/// Returns the flow from the first successful attempt. When the last allowed
/// attempt fails, its error is returned via `final_retry_error`; errors from
/// earlier attempts are discarded. Between attempts the task sleeps for the
/// policy's backoff delay unless that delay is zero.
async fn retry_tls_client_connect<A>(
    addr: A,
    server_name: ServerName<'static>,
    client_config: Arc<rustls::ClientConfig>,
    settings: ConnectionSettings,
    handle: Handle,
    chunk_size: usize,
) -> StreamResult<Flow<Vec<u8>, Vec<u8>, TlsConnection>>
where
    A: ToSocketAddrs + Clone + Send + 'static,
{
    let policy = &settings.retry_policy;
    let last_attempt = policy.attempts();
    for attempt in 1..=last_attempt {
        // Every attempt consumes its inputs, so hand it fresh clones.
        let outcome = tls_client_connect_once(
            addr.clone(),
            server_name.clone(),
            Arc::clone(&client_config),
            &settings,
            handle.clone(),
            chunk_size,
        )
        .await;

        let error = match outcome {
            Ok(flow) => return Ok(flow),
            Err(error) => error,
        };
        if attempt == last_attempt {
            return Err(final_retry_error(error, attempt));
        }

        let pause = policy.backoff_after_attempt(attempt);
        if !pause.is_zero() {
            sleep(pause).await;
        }
    }
    // Only reachable if the attempt range were empty.
    Err(StreamError::Failed(
        "connection retry policy had no attempts".into(),
    ))
}

/// Performs a single connection attempt: TCP connect, then TLS handshake.
///
/// The TCP connect is limited by `settings.connect_timeout` and the handshake
/// by `settings.handshake_timeout`. On success the TLS stream is wrapped into
/// a flow via `tls_flow_from_stream`, together with the socket's local and
/// peer addresses. I/O failures and timeouts are returned as `StreamError`.
async fn tls_client_connect_once<A>(
    addr: A,
    server_name: ServerName<'static>,
    client_config: Arc<rustls::ClientConfig>,
    settings: &ConnectionSettings,
    handle: Handle,
    chunk_size: usize,
) -> StreamResult<Flow<Vec<u8>, Vec<u8>, TlsConnection>>
where
    A: ToSocketAddrs + Send + 'static,
{
    let connecting = TcpStream::connect(addr);
    let tcp = io_with_optional_timeout("TCP connect", settings.connect_timeout, connecting).await?;

    // Read the addresses now, before the handshake takes ownership of the socket.
    let local_addr = tcp.local_addr().map_err(io_error)?;
    let remote_addr = tcp.peer_addr().map_err(io_error)?;
    let connection = TlsConnection {
        local_addr,
        remote_addr,
    };

    let connector = TlsConnector::from(client_config);
    let handshake = connector.connect(server_name, tcp);
    let tls =
        io_with_optional_timeout("TLS handshake", settings.handshake_timeout, handshake).await?;

    Ok(tls_flow_from_stream(tls, connection, handle, chunk_size))
}

/// Awaits an I/O future, optionally bounded by a time limit.
///
/// With `limit == None` the future is awaited without a deadline. With
/// `Some(duration)` it is wrapped in `tokio::time::timeout`; if the deadline
/// passes first, the error message names `operation` and the duration. I/O
/// errors from the future are converted with `io_error`.
async fn io_with_optional_timeout<T, Fut>(
    operation: &'static str,
    limit: Option<Duration>,
    future: Fut,
) -> StreamResult<T>
where
    Fut: Future<Output = std::io::Result<T>>,
{
    let io_result = match limit {
        None => future.await,
        Some(duration) => timeout(duration, future).await.map_err(|_elapsed| {
            StreamError::Failed(format!("{operation} timed out after {duration:?}"))
        })?,
    };
    io_result.map_err(io_error)
}

/// Produces the error reported once connection establishment has given up.
///
/// With at most one attempt the original error is returned untouched.
/// Otherwise it is wrapped in a message stating how many attempts were made.
fn final_retry_error(error: StreamError, attempts: usize) -> StreamError {
    match attempts {
        0 | 1 => error,
        _ => StreamError::Failed(format!(
            "connection establishment failed after {attempts} attempts: {error}"
        )),
    }
}

/// Converts an I/O error into a `StreamError::Failed` that carries the
/// error's display text.
fn io_error(error: std::io::Error) -> StreamError {
    let message = error.to_string();
    StreamError::Failed(message)
}

/// Returns `multiplier` when it is a finite number of at least `1.0`, and
/// `1.0` otherwise (NaN, infinities, and values below one).
fn sane_multiplier(multiplier: f64) -> f64 {
    match multiplier {
        usable if usable.is_finite() && usable >= 1.0 => usable,
        _ => 1.0,
    }
}