Skip to main content

datum_net/
connection.rs

1//! Connection lifecycle utilities for `datum-net` transports.
2//!
3//! The lifecycle layer keeps connection establishment lazy: TCP connect, TLS
4//! handshake, timeout handling, and retry attempts start only when the returned
5//! Datum flow is materialized and pulled. Completing the upstream side of a
6//! connection byte flow gracefully shuts down the write direction while leaving
7//! the read direction open for the peer's response.
8
9use crate::tls::{TlsConnection, TokioTls, rustls, tls_flow_from_stream};
10use datum::{Flow, StreamCompletion, StreamError, StreamResult};
11use std::future::Future;
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::net::{TcpStream, ToSocketAddrs};
15use tokio::runtime::Handle;
16use tokio::time::{sleep, timeout};
17use tokio_rustls::TlsConnector;
18use tokio_rustls::rustls::pki_types::ServerName;
19
/// Chunk size, in bytes, handed to the transport flow when none is specified (8 KiB).
const DEFAULT_CHUNK_SIZE: usize = 8 * 1024;
/// Limit applied to each TCP connect attempt unless overridden.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Limit applied to each TLS handshake attempt unless overridden.
const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Delay slept after the first failed attempt when retries are enabled.
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(100);
/// Ceiling for any single computed retry delay.
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_millis(5_000);
/// Growth factor between consecutive retry delays.
const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
26
/// Retry configuration applied while establishing a connection.
///
/// `max_attempts` includes the first try, so `1` means "no retries". A zero
/// value (reachable only through direct field construction) is clamped to a
/// single attempt when the policy is used, so a zero-attempt connection cannot
/// be built.
#[derive(Clone, Debug, PartialEq)]
pub struct RetryPolicy {
    /// Total number of attempts, counting the initial one.
    pub max_attempts: usize,
    /// Delay slept after the first failed attempt.
    pub initial_backoff: Duration,
    /// Factor applied to the delay after each further failure; non-finite
    /// values and values below `1.0` behave as `1.0`.
    pub backoff_multiplier: f64,
    /// Upper bound on any single retry delay.
    pub max_backoff: Duration,
}
39
40impl Default for RetryPolicy {
41    fn default() -> Self {
42        Self {
43            max_attempts: 1,
44            initial_backoff: DEFAULT_INITIAL_BACKOFF,
45            backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
46            max_backoff: DEFAULT_MAX_BACKOFF,
47        }
48    }
49}
50
51impl RetryPolicy {
52    /// Creates a retry policy with one attempt and exponential backoff defaults.
53    #[must_use]
54    pub fn new() -> Self {
55        Self::default()
56    }
57
58    #[must_use]
59    pub fn max_attempts(mut self, max_attempts: usize) -> Self {
60        self.max_attempts = max_attempts.max(1);
61        self
62    }
63
64    #[must_use]
65    pub fn initial_backoff(mut self, initial_backoff: Duration) -> Self {
66        self.initial_backoff = initial_backoff;
67        self
68    }
69
70    #[must_use]
71    pub fn backoff_multiplier(mut self, backoff_multiplier: f64) -> Self {
72        self.backoff_multiplier = sane_multiplier(backoff_multiplier);
73        self
74    }
75
76    #[must_use]
77    pub fn max_backoff(mut self, max_backoff: Duration) -> Self {
78        self.max_backoff = max_backoff;
79        self
80    }
81
82    fn attempts(&self) -> usize {
83        self.max_attempts.max(1)
84    }
85
86    fn backoff_after_attempt(&self, attempt: usize) -> Duration {
87        if self.initial_backoff.is_zero() || self.max_backoff.is_zero() {
88            return Duration::ZERO;
89        }
90
91        let multiplier = sane_multiplier(self.backoff_multiplier);
92        let exponent = attempt.saturating_sub(1).min(32) as i32;
93        let delay_secs = self.initial_backoff.as_secs_f64() * multiplier.powi(exponent);
94        let capped_secs = delay_secs.min(self.max_backoff.as_secs_f64());
95        Duration::from_secs_f64(capped_secs)
96    }
97}
98
99/// Connection establishment settings shared by lifecycle-aware transports.
100#[derive(Debug, Clone, PartialEq)]
101pub struct ConnectionSettings {
102    pub connect_timeout: Option<Duration>,
103    pub handshake_timeout: Option<Duration>,
104    pub retry_policy: RetryPolicy,
105}
106
107impl Default for ConnectionSettings {
108    fn default() -> Self {
109        Self {
110            connect_timeout: Some(DEFAULT_CONNECT_TIMEOUT),
111            handshake_timeout: Some(DEFAULT_HANDSHAKE_TIMEOUT),
112            retry_policy: RetryPolicy::default(),
113        }
114    }
115}
116
117impl ConnectionSettings {
118    /// Creates lifecycle settings with bounded connect and TLS handshake time.
119    #[must_use]
120    pub fn new() -> Self {
121        Self::default()
122    }
123
124    #[must_use]
125    pub fn connect_timeout(mut self, connect_timeout: Duration) -> Self {
126        self.connect_timeout = Some(connect_timeout);
127        self
128    }
129
130    #[must_use]
131    pub fn without_connect_timeout(mut self) -> Self {
132        self.connect_timeout = None;
133        self
134    }
135
136    #[must_use]
137    pub fn handshake_timeout(mut self, handshake_timeout: Duration) -> Self {
138        self.handshake_timeout = Some(handshake_timeout);
139        self
140    }
141
142    #[must_use]
143    pub fn without_handshake_timeout(mut self) -> Self {
144        self.handshake_timeout = None;
145        self
146    }
147
148    #[must_use]
149    pub fn retry_policy(mut self, retry_policy: RetryPolicy) -> Self {
150        self.retry_policy = retry_policy;
151        self
152    }
153}
154
/// Namespace for transport-agnostic lifecycle constructors.
///
/// This unit struct carries no state. It only groups associated functions
/// such as `Connection::tls_client` and `Connection::graceful_shutdown`.
/// `Debug`, `Clone`, and `Copy` are derived so the public type follows the
/// usual API conventions; deriving them has no runtime cost for a unit struct.
#[derive(Debug, Clone, Copy)]
pub struct Connection;
157
158impl Connection {
159    /// Opens a lifecycle-aware TLS client connection with the default chunk size.
160    #[must_use]
161    pub fn tls_client<A>(
162        addr: A,
163        server_name: ServerName<'static>,
164        client_config: Arc<rustls::ClientConfig>,
165        settings: ConnectionSettings,
166    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
167    where
168        A: ToSocketAddrs + Clone + Send + Sync + 'static,
169    {
170        TokioTls::outgoing_connection_with_lifecycle(addr, server_name, client_config, settings)
171    }
172
173    /// Marks a connection flow as using graceful half-close on upstream finish.
174    ///
175    /// Datum TCP/TLS connection flows already map upstream completion to
176    /// `AsyncWriteExt::shutdown()` and keep the read side alive. This helper is
177    /// an explicit API affordance for that behavior when a call site wants to
178    /// state the lifecycle intent.
179    #[must_use]
180    pub fn graceful_shutdown<Mat>(flow: Flow<Vec<u8>, Vec<u8>, Mat>) -> Flow<Vec<u8>, Vec<u8>, Mat>
181    where
182        Mat: Send + 'static,
183    {
184        flow.graceful_shutdown_on_upstream_finish()
185    }
186
187    /// Alias for [`Connection::graceful_shutdown`].
188    #[must_use]
189    pub fn half_close<Mat>(flow: Flow<Vec<u8>, Vec<u8>, Mat>) -> Flow<Vec<u8>, Vec<u8>, Mat>
190    where
191        Mat: Send + 'static,
192    {
193        Self::graceful_shutdown(flow)
194    }
195}
196
/// Extension methods for connection byte flows.
pub trait ConnectionLifecycleExt<Mat> {
    /// Declares, at the call site, that finishing upstream should half-close
    /// the connection.
    ///
    /// Datum TCP/TLS connection flows already shut down the write direction
    /// when upstream completes, while leaving the read direction open. Their
    /// implementation therefore returns the flow unchanged: the transport sink
    /// performs the shutdown itself.
    #[must_use]
    fn graceful_shutdown_on_upstream_finish(self) -> Self;

    /// Alias for [`ConnectionLifecycleExt::graceful_shutdown_on_upstream_finish`].
    #[must_use]
    fn half_close_on_upstream_finish(self) -> Self
    where
        Self: Sized,
    {
        Self::graceful_shutdown_on_upstream_finish(self)
    }
}
217
218impl<Mat> ConnectionLifecycleExt<Mat> for Flow<Vec<u8>, Vec<u8>, Mat>
219where
220    Mat: Send + 'static,
221{
222    fn graceful_shutdown_on_upstream_finish(self) -> Self {
223        self
224    }
225}
226
227impl TokioTls {
228    /// Opens a lifecycle-aware TLS client connection using the default 8 KiB chunk size.
229    ///
230    /// TCP connect and TLS handshake are bounded by [`ConnectionSettings`] and
231    /// retried according to its [`RetryPolicy`]. A timeout or final retry
232    /// failure surfaces as a [`StreamError`] through the materialized
233    /// [`StreamCompletion`] and through the stream.
234    #[must_use]
235    pub fn outgoing_connection_with_lifecycle<A>(
236        addr: A,
237        server_name: ServerName<'static>,
238        client_config: Arc<rustls::ClientConfig>,
239        settings: ConnectionSettings,
240    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
241    where
242        A: ToSocketAddrs + Clone + Send + Sync + 'static,
243    {
244        Self::outgoing_connection_with_lifecycle_and_chunk_size(
245            addr,
246            server_name,
247            client_config,
248            settings,
249            DEFAULT_CHUNK_SIZE,
250        )
251    }
252
253    /// Opens a lifecycle-aware TLS client connection with an explicit chunk size.
254    #[must_use]
255    pub fn outgoing_connection_with_lifecycle_and_chunk_size<A>(
256        addr: A,
257        server_name: ServerName<'static>,
258        client_config: Arc<rustls::ClientConfig>,
259        settings: ConnectionSettings,
260        chunk_size: usize,
261    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
262    where
263        A: ToSocketAddrs + Clone + Send + Sync + 'static,
264    {
265        assert!(chunk_size > 0, "chunk size must be greater than zero");
266        Flow::future_flow(move || {
267            let addr = addr.clone();
268            let server_name = server_name.clone();
269            let client_config = Arc::clone(&client_config);
270            let settings = settings.clone();
271            async move {
272                let handle = Handle::current();
273                retry_tls_client_connect(
274                    addr,
275                    server_name,
276                    client_config,
277                    settings,
278                    handle,
279                    chunk_size,
280                )
281                .await
282            }
283        })
284    }
285}
286
287async fn retry_tls_client_connect<A>(
288    addr: A,
289    server_name: ServerName<'static>,
290    client_config: Arc<rustls::ClientConfig>,
291    settings: ConnectionSettings,
292    handle: Handle,
293    chunk_size: usize,
294) -> StreamResult<Flow<Vec<u8>, Vec<u8>, TlsConnection>>
295where
296    A: ToSocketAddrs + Clone + Send + 'static,
297{
298    let attempts = settings.retry_policy.attempts();
299    for attempt in 1..=attempts {
300        match tls_client_connect_once(
301            addr.clone(),
302            server_name.clone(),
303            Arc::clone(&client_config),
304            &settings,
305            handle.clone(),
306            chunk_size,
307        )
308        .await
309        {
310            Ok(flow) => return Ok(flow),
311            Err(error) if attempt == attempts => {
312                return Err(final_retry_error(error, attempt));
313            }
314            Err(_) => {
315                let delay = settings.retry_policy.backoff_after_attempt(attempt);
316                if !delay.is_zero() {
317                    sleep(delay).await;
318                }
319            }
320        }
321    }
322    Err(StreamError::Failed(
323        "connection retry policy had no attempts".into(),
324    ))
325}
326
327async fn tls_client_connect_once<A>(
328    addr: A,
329    server_name: ServerName<'static>,
330    client_config: Arc<rustls::ClientConfig>,
331    settings: &ConnectionSettings,
332    handle: Handle,
333    chunk_size: usize,
334) -> StreamResult<Flow<Vec<u8>, Vec<u8>, TlsConnection>>
335where
336    A: ToSocketAddrs + Send + 'static,
337{
338    let tcp = io_with_optional_timeout(
339        "TCP connect",
340        settings.connect_timeout,
341        TcpStream::connect(addr),
342    )
343    .await?;
344    let connection = TlsConnection {
345        local_addr: tcp.local_addr().map_err(io_error)?,
346        remote_addr: tcp.peer_addr().map_err(io_error)?,
347    };
348    let tls = io_with_optional_timeout(
349        "TLS handshake",
350        settings.handshake_timeout,
351        TlsConnector::from(client_config).connect(server_name, tcp),
352    )
353    .await?;
354    Ok(tls_flow_from_stream(tls, connection, handle, chunk_size))
355}
356
357async fn io_with_optional_timeout<T, Fut>(
358    operation: &'static str,
359    limit: Option<Duration>,
360    future: Fut,
361) -> StreamResult<T>
362where
363    Fut: Future<Output = std::io::Result<T>>,
364{
365    match limit {
366        Some(duration) => match timeout(duration, future).await {
367            Ok(Ok(value)) => Ok(value),
368            Ok(Err(error)) => Err(io_error(error)),
369            Err(_) => Err(StreamError::Failed(format!(
370                "{operation} timed out after {duration:?}"
371            ))),
372        },
373        None => future.await.map_err(io_error),
374    }
375}
376
377fn final_retry_error(error: StreamError, attempts: usize) -> StreamError {
378    if attempts <= 1 {
379        error
380    } else {
381        StreamError::Failed(format!(
382            "connection establishment failed after {attempts} attempts: {error}"
383        ))
384    }
385}
386
387fn io_error(error: std::io::Error) -> StreamError {
388    StreamError::Failed(error.to_string())
389}
390
/// Returns `multiplier` when it is a usable growth factor (finite and at
/// least `1.0`). Otherwise falls back to `1.0`, i.e. constant backoff.
fn sane_multiplier(multiplier: f64) -> f64 {
    let usable = multiplier.is_finite() && multiplier >= 1.0;
    if !usable {
        return 1.0;
    }
    multiplier
}