Skip to main content

datum_net/
connection.rs

1//! Connection lifecycle utilities for `datum-net` transports.
2//!
3//! The lifecycle layer keeps connection establishment lazy: TCP connect, TLS
4//! handshake, timeout handling, and retry attempts start only when the returned
5//! Datum flow is materialized and pulled. Completing the upstream side of a
6//! connection byte flow gracefully shuts down the write direction while leaving
7//! the read direction open for the peer's response.
8
9use crate::tls::{
10    TlsConnection, TokioTls, rustls, tls_connection_execution, tls_flow_from_stream_with_execution,
11};
12use datum::{Flow, StreamCompletion, StreamError, StreamResult};
13use std::future::Future;
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::net::{TcpStream, ToSocketAddrs};
17use tokio::runtime::Handle;
18use tokio::time::{sleep, timeout};
19use tokio_rustls::TlsConnector;
20use tokio_rustls::rustls::pki_types::ServerName;
21
/// Chunk size (8 KiB) handed to the TLS byte flow when the caller does not pick one.
const DEFAULT_CHUNK_SIZE: usize = 8 * 1024;
/// Default upper bound on the TCP connect step.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Default upper bound on the TLS handshake step.
const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(30);
/// Delay before the first retry in the default [`RetryPolicy`].
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(100);
/// Ceiling for any single backoff delay in the default [`RetryPolicy`].
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_millis(5_000);
/// Factor the backoff delay grows by after each failed attempt, by default.
const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
28
/// How many times, and how patiently, connection establishment is retried.
///
/// `max_attempts` includes the very first attempt. Because the fields are
/// public, a value of zero can be written directly; it is still interpreted
/// as a single attempt, so a connection is never configured to make none.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryPolicy {
    /// Total number of attempts, counting the initial one (zero acts as one).
    pub max_attempts: usize,
    /// Delay slept after the first failed attempt.
    pub initial_backoff: Duration,
    /// Growth factor applied to the delay after every further failure.
    pub backoff_multiplier: f64,
    /// Upper bound on any single delay.
    pub max_backoff: Duration,
}
41
42impl Default for RetryPolicy {
43    fn default() -> Self {
44        Self {
45            max_attempts: 1,
46            initial_backoff: DEFAULT_INITIAL_BACKOFF,
47            backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
48            max_backoff: DEFAULT_MAX_BACKOFF,
49        }
50    }
51}
52
53impl RetryPolicy {
54    /// Creates a retry policy with one attempt and exponential backoff defaults.
55    #[must_use]
56    pub fn new() -> Self {
57        Self::default()
58    }
59
60    #[must_use]
61    pub fn max_attempts(mut self, max_attempts: usize) -> Self {
62        self.max_attempts = max_attempts.max(1);
63        self
64    }
65
66    #[must_use]
67    pub fn initial_backoff(mut self, initial_backoff: Duration) -> Self {
68        self.initial_backoff = initial_backoff;
69        self
70    }
71
72    #[must_use]
73    pub fn backoff_multiplier(mut self, backoff_multiplier: f64) -> Self {
74        self.backoff_multiplier = sane_multiplier(backoff_multiplier);
75        self
76    }
77
78    #[must_use]
79    pub fn max_backoff(mut self, max_backoff: Duration) -> Self {
80        self.max_backoff = max_backoff;
81        self
82    }
83
84    fn attempts(&self) -> usize {
85        self.max_attempts.max(1)
86    }
87
88    fn backoff_after_attempt(&self, attempt: usize) -> Duration {
89        if self.initial_backoff.is_zero() || self.max_backoff.is_zero() {
90            return Duration::ZERO;
91        }
92
93        let multiplier = sane_multiplier(self.backoff_multiplier);
94        let exponent = attempt.saturating_sub(1).min(32) as i32;
95        let delay_secs = self.initial_backoff.as_secs_f64() * multiplier.powi(exponent);
96        let capped_secs = delay_secs.min(self.max_backoff.as_secs_f64());
97        Duration::from_secs_f64(capped_secs)
98    }
99}
100
101/// Connection establishment settings shared by lifecycle-aware transports.
102#[derive(Debug, Clone, PartialEq)]
103pub struct ConnectionSettings {
104    pub connect_timeout: Option<Duration>,
105    pub handshake_timeout: Option<Duration>,
106    pub retry_policy: RetryPolicy,
107}
108
109impl Default for ConnectionSettings {
110    fn default() -> Self {
111        Self {
112            connect_timeout: Some(DEFAULT_CONNECT_TIMEOUT),
113            handshake_timeout: Some(DEFAULT_HANDSHAKE_TIMEOUT),
114            retry_policy: RetryPolicy::default(),
115        }
116    }
117}
118
119impl ConnectionSettings {
120    /// Creates lifecycle settings with bounded connect and TLS handshake time.
121    #[must_use]
122    pub fn new() -> Self {
123        Self::default()
124    }
125
126    #[must_use]
127    pub fn connect_timeout(mut self, connect_timeout: Duration) -> Self {
128        self.connect_timeout = Some(connect_timeout);
129        self
130    }
131
132    #[must_use]
133    pub fn without_connect_timeout(mut self) -> Self {
134        self.connect_timeout = None;
135        self
136    }
137
138    #[must_use]
139    pub fn handshake_timeout(mut self, handshake_timeout: Duration) -> Self {
140        self.handshake_timeout = Some(handshake_timeout);
141        self
142    }
143
144    #[must_use]
145    pub fn without_handshake_timeout(mut self) -> Self {
146        self.handshake_timeout = None;
147        self
148    }
149
150    #[must_use]
151    pub fn retry_policy(mut self, retry_policy: RetryPolicy) -> Self {
152        self.retry_policy = retry_policy;
153        self
154    }
155}
156
/// Zero-sized namespace grouping the transport-agnostic lifecycle
/// constructors (`Connection::tls_client`, `Connection::graceful_shutdown`, …).
pub struct Connection;
159
160impl Connection {
161    /// Opens a lifecycle-aware TLS client connection with the default chunk size.
162    #[must_use]
163    pub fn tls_client<A>(
164        addr: A,
165        server_name: ServerName<'static>,
166        client_config: Arc<rustls::ClientConfig>,
167        settings: ConnectionSettings,
168    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
169    where
170        A: ToSocketAddrs + Clone + Send + Sync + 'static,
171    {
172        TokioTls::outgoing_connection_with_lifecycle(addr, server_name, client_config, settings)
173    }
174
175    /// Marks a connection flow as using graceful half-close on upstream finish.
176    ///
177    /// Datum TCP/TLS connection flows already map upstream completion to
178    /// `AsyncWriteExt::shutdown()` and keep the read side alive. This helper is
179    /// an explicit API affordance for that behavior when a call site wants to
180    /// state the lifecycle intent.
181    #[must_use]
182    pub fn graceful_shutdown<Mat>(flow: Flow<Vec<u8>, Vec<u8>, Mat>) -> Flow<Vec<u8>, Vec<u8>, Mat>
183    where
184        Mat: Send + 'static,
185    {
186        flow.graceful_shutdown_on_upstream_finish()
187    }
188
189    /// Alias for [`Connection::graceful_shutdown`].
190    #[must_use]
191    pub fn half_close<Mat>(flow: Flow<Vec<u8>, Vec<u8>, Mat>) -> Flow<Vec<u8>, Vec<u8>, Mat>
192    where
193        Mat: Send + 'static,
194    {
195        Self::graceful_shutdown(flow)
196    }
197}
198
/// Call-site helpers that spell out half-close intent on connection byte flows.
pub trait ConnectionLifecycleExt<Mat> {
    /// Marks the flow as half-closing when its upstream side completes.
    ///
    /// For Datum TCP/TLS connection flows, upstream completion already shuts
    /// down the write direction while the read direction stays open. The
    /// transport sink performs that shutdown itself, so implementations for
    /// those flows simply hand the flow back.
    #[must_use]
    fn graceful_shutdown_on_upstream_finish(self) -> Self;

    /// Same as
    /// [`ConnectionLifecycleExt::graceful_shutdown_on_upstream_finish`],
    /// under the "half-close" name.
    #[must_use]
    fn half_close_on_upstream_finish(self) -> Self
    where
        Self: Sized,
    {
        ConnectionLifecycleExt::graceful_shutdown_on_upstream_finish(self)
    }
}
219
220impl<Mat> ConnectionLifecycleExt<Mat> for Flow<Vec<u8>, Vec<u8>, Mat>
221where
222    Mat: Send + 'static,
223{
224    fn graceful_shutdown_on_upstream_finish(self) -> Self {
225        self
226    }
227}
228
229impl TokioTls {
230    /// Opens a lifecycle-aware TLS client connection using the default 8 KiB chunk size.
231    ///
232    /// TCP connect and TLS handshake are bounded by [`ConnectionSettings`] and
233    /// retried according to its [`RetryPolicy`]. A timeout or final retry
234    /// failure surfaces as a [`StreamError`] through the materialized
235    /// [`StreamCompletion`] and through the stream.
236    #[must_use]
237    pub fn outgoing_connection_with_lifecycle<A>(
238        addr: A,
239        server_name: ServerName<'static>,
240        client_config: Arc<rustls::ClientConfig>,
241        settings: ConnectionSettings,
242    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
243    where
244        A: ToSocketAddrs + Clone + Send + Sync + 'static,
245    {
246        Self::outgoing_connection_with_lifecycle_and_chunk_size(
247            addr,
248            server_name,
249            client_config,
250            settings,
251            DEFAULT_CHUNK_SIZE,
252        )
253    }
254
255    /// Opens a lifecycle-aware TLS client connection with an explicit chunk size.
256    #[must_use]
257    pub fn outgoing_connection_with_lifecycle_and_chunk_size<A>(
258        addr: A,
259        server_name: ServerName<'static>,
260        client_config: Arc<rustls::ClientConfig>,
261        settings: ConnectionSettings,
262        chunk_size: usize,
263    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
264    where
265        A: ToSocketAddrs + Clone + Send + Sync + 'static,
266    {
267        assert!(chunk_size > 0, "chunk size must be greater than zero");
268        Flow::future_flow(move || {
269            let addr = addr.clone();
270            let server_name = server_name.clone();
271            let client_config = Arc::clone(&client_config);
272            let settings = settings.clone();
273            async move {
274                let handle = Handle::current();
275                retry_tls_client_connect(
276                    addr,
277                    server_name,
278                    client_config,
279                    settings,
280                    handle,
281                    chunk_size,
282                )
283                .await
284            }
285        })
286    }
287}
288
289async fn retry_tls_client_connect<A>(
290    addr: A,
291    server_name: ServerName<'static>,
292    client_config: Arc<rustls::ClientConfig>,
293    settings: ConnectionSettings,
294    handle: Handle,
295    chunk_size: usize,
296) -> StreamResult<Flow<Vec<u8>, Vec<u8>, TlsConnection>>
297where
298    A: ToSocketAddrs + Clone + Send + 'static,
299{
300    let attempts = settings.retry_policy.attempts();
301    for attempt in 1..=attempts {
302        match tls_client_connect_once(
303            addr.clone(),
304            server_name.clone(),
305            Arc::clone(&client_config),
306            &settings,
307            handle.clone(),
308            chunk_size,
309        )
310        .await
311        {
312            Ok(flow) => return Ok(flow),
313            Err(error) if attempt == attempts => {
314                return Err(final_retry_error(error, attempt));
315            }
316            Err(_) => {
317                let delay = settings.retry_policy.backoff_after_attempt(attempt);
318                if !delay.is_zero() {
319                    sleep(delay).await;
320                }
321            }
322        }
323    }
324    Err(StreamError::Failed(
325        "connection retry policy had no attempts".into(),
326    ))
327}
328
329async fn tls_client_connect_once<A>(
330    addr: A,
331    server_name: ServerName<'static>,
332    client_config: Arc<rustls::ClientConfig>,
333    settings: &ConnectionSettings,
334    handle: Handle,
335    chunk_size: usize,
336) -> StreamResult<Flow<Vec<u8>, Vec<u8>, TlsConnection>>
337where
338    A: ToSocketAddrs + Send + 'static,
339{
340    let execution = tls_connection_execution(handle);
341    let connect_timeout = settings.connect_timeout;
342    let handshake_timeout = settings.handshake_timeout;
343    let (tls, connection) = execution
344        .run(async move {
345            let tcp =
346                io_with_optional_timeout("TCP connect", connect_timeout, TcpStream::connect(addr))
347                    .await?;
348            let connection = TlsConnection {
349                local_addr: tcp.local_addr().map_err(io_error)?,
350                remote_addr: tcp.peer_addr().map_err(io_error)?,
351            };
352            let tls = io_with_optional_timeout(
353                "TLS handshake",
354                handshake_timeout,
355                TlsConnector::from(client_config).connect(server_name, tcp),
356            )
357            .await?;
358            Ok((tls, connection))
359        })
360        .await?;
361    Ok(tls_flow_from_stream_with_execution(
362        tls, connection, execution, chunk_size,
363    ))
364}
365
366async fn io_with_optional_timeout<T, Fut>(
367    operation: &'static str,
368    limit: Option<Duration>,
369    future: Fut,
370) -> StreamResult<T>
371where
372    Fut: Future<Output = std::io::Result<T>>,
373{
374    match limit {
375        Some(duration) => match timeout(duration, future).await {
376            Ok(Ok(value)) => Ok(value),
377            Ok(Err(error)) => Err(io_error(error)),
378            Err(_) => Err(StreamError::Failed(format!(
379                "{operation} timed out after {duration:?}"
380            ))),
381        },
382        None => future.await.map_err(io_error),
383    }
384}
385
386fn final_retry_error(error: StreamError, attempts: usize) -> StreamError {
387    if attempts <= 1 {
388        error
389    } else {
390        StreamError::Failed(format!(
391            "connection establishment failed after {attempts} attempts: {error}"
392        ))
393    }
394}
395
396fn io_error(error: std::io::Error) -> StreamError {
397    StreamError::Failed(error.to_string())
398}
399
/// Normalizes a backoff multiplier so delays never shrink or become undefined.
///
/// Finite values of at least `1.0` pass through unchanged. Anything else
/// (`NaN`, infinities, values below `1.0`) becomes `1.0`, i.e. constant backoff.
fn sane_multiplier(multiplier: f64) -> f64 {
    let usable = multiplier.is_finite() && multiplier >= 1.0;
    if usable { multiplier } else { 1.0 }
}