1use crate::tls::{
10 TlsConnection, TokioTls, rustls, tls_connection_execution, tls_flow_from_stream_with_execution,
11};
12use datum::{Flow, StreamCompletion, StreamError, StreamResult};
13use std::future::Future;
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::net::{TcpStream, ToSocketAddrs};
17use tokio::runtime::Handle;
18use tokio::time::{sleep, timeout};
19use tokio_rustls::TlsConnector;
20use tokio_rustls::rustls::pki_types::ServerName;
21
// Connection-phase time limits used by `ConnectionSettings::default`.

/// Default time limit for the TCP connect phase of one attempt.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Default time limit for the TLS handshake phase of one attempt.
const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(30);

// Backoff parameters used by `RetryPolicy::default`.

/// Default delay after the first failed attempt.
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(100);
/// Default upper bound for a single backoff delay.
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(5);
/// Default growth factor applied to the backoff delay after each further failure.
const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;

/// Default `chunk_size` (8 KiB) forwarded to the TLS flow.
const DEFAULT_CHUNK_SIZE: usize = 8 * 1024;
28
29#[derive(Debug, Clone, PartialEq)]
35pub struct RetryPolicy {
36 pub max_attempts: usize,
37 pub initial_backoff: Duration,
38 pub backoff_multiplier: f64,
39 pub max_backoff: Duration,
40}
41
42impl Default for RetryPolicy {
43 fn default() -> Self {
44 Self {
45 max_attempts: 1,
46 initial_backoff: DEFAULT_INITIAL_BACKOFF,
47 backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
48 max_backoff: DEFAULT_MAX_BACKOFF,
49 }
50 }
51}
52
53impl RetryPolicy {
54 #[must_use]
56 pub fn new() -> Self {
57 Self::default()
58 }
59
60 #[must_use]
61 pub fn max_attempts(mut self, max_attempts: usize) -> Self {
62 self.max_attempts = max_attempts.max(1);
63 self
64 }
65
66 #[must_use]
67 pub fn initial_backoff(mut self, initial_backoff: Duration) -> Self {
68 self.initial_backoff = initial_backoff;
69 self
70 }
71
72 #[must_use]
73 pub fn backoff_multiplier(mut self, backoff_multiplier: f64) -> Self {
74 self.backoff_multiplier = sane_multiplier(backoff_multiplier);
75 self
76 }
77
78 #[must_use]
79 pub fn max_backoff(mut self, max_backoff: Duration) -> Self {
80 self.max_backoff = max_backoff;
81 self
82 }
83
84 fn attempts(&self) -> usize {
85 self.max_attempts.max(1)
86 }
87
88 fn backoff_after_attempt(&self, attempt: usize) -> Duration {
89 if self.initial_backoff.is_zero() || self.max_backoff.is_zero() {
90 return Duration::ZERO;
91 }
92
93 let multiplier = sane_multiplier(self.backoff_multiplier);
94 let exponent = attempt.saturating_sub(1).min(32) as i32;
95 let delay_secs = self.initial_backoff.as_secs_f64() * multiplier.powi(exponent);
96 let capped_secs = delay_secs.min(self.max_backoff.as_secs_f64());
97 Duration::from_secs_f64(capped_secs)
98 }
99}
100
101#[derive(Debug, Clone, PartialEq)]
103pub struct ConnectionSettings {
104 pub connect_timeout: Option<Duration>,
105 pub handshake_timeout: Option<Duration>,
106 pub retry_policy: RetryPolicy,
107}
108
109impl Default for ConnectionSettings {
110 fn default() -> Self {
111 Self {
112 connect_timeout: Some(DEFAULT_CONNECT_TIMEOUT),
113 handshake_timeout: Some(DEFAULT_HANDSHAKE_TIMEOUT),
114 retry_policy: RetryPolicy::default(),
115 }
116 }
117}
118
119impl ConnectionSettings {
120 #[must_use]
122 pub fn new() -> Self {
123 Self::default()
124 }
125
126 #[must_use]
127 pub fn connect_timeout(mut self, connect_timeout: Duration) -> Self {
128 self.connect_timeout = Some(connect_timeout);
129 self
130 }
131
132 #[must_use]
133 pub fn without_connect_timeout(mut self) -> Self {
134 self.connect_timeout = None;
135 self
136 }
137
138 #[must_use]
139 pub fn handshake_timeout(mut self, handshake_timeout: Duration) -> Self {
140 self.handshake_timeout = Some(handshake_timeout);
141 self
142 }
143
144 #[must_use]
145 pub fn without_handshake_timeout(mut self) -> Self {
146 self.handshake_timeout = None;
147 self
148 }
149
150 #[must_use]
151 pub fn retry_policy(mut self, retry_policy: RetryPolicy) -> Self {
152 self.retry_policy = retry_policy;
153 self
154 }
155}
156
157pub struct Connection;
159
160impl Connection {
161 #[must_use]
163 pub fn tls_client<A>(
164 addr: A,
165 server_name: ServerName<'static>,
166 client_config: Arc<rustls::ClientConfig>,
167 settings: ConnectionSettings,
168 ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
169 where
170 A: ToSocketAddrs + Clone + Send + Sync + 'static,
171 {
172 TokioTls::outgoing_connection_with_lifecycle(addr, server_name, client_config, settings)
173 }
174
175 #[must_use]
182 pub fn graceful_shutdown<Mat>(flow: Flow<Vec<u8>, Vec<u8>, Mat>) -> Flow<Vec<u8>, Vec<u8>, Mat>
183 where
184 Mat: Send + 'static,
185 {
186 flow.graceful_shutdown_on_upstream_finish()
187 }
188
189 #[must_use]
191 pub fn half_close<Mat>(flow: Flow<Vec<u8>, Vec<u8>, Mat>) -> Flow<Vec<u8>, Vec<u8>, Mat>
192 where
193 Mat: Send + 'static,
194 {
195 Self::graceful_shutdown(flow)
196 }
197}
198
/// Lifecycle adapters for byte-oriented connection flows.
///
/// `Mat` is the implementing flow's materialized-value type; the trait's methods do not
/// use it directly.
pub trait ConnectionLifecycleExt<Mat> {
    /// Returns `self` configured for graceful shutdown once upstream finishes.
    ///
    /// What this actually does is up to the implementation.
    #[must_use]
    fn graceful_shutdown_on_upstream_finish(self) -> Self;

    /// Alias for [`ConnectionLifecycleExt::graceful_shutdown_on_upstream_finish`].
    #[must_use]
    fn half_close_on_upstream_finish(self) -> Self
    where
        Self: Sized,
    {
        Self::graceful_shutdown_on_upstream_finish(self)
    }
}
219
220impl<Mat> ConnectionLifecycleExt<Mat> for Flow<Vec<u8>, Vec<u8>, Mat>
221where
222 Mat: Send + 'static,
223{
224 fn graceful_shutdown_on_upstream_finish(self) -> Self {
225 self
226 }
227}
228
229impl TokioTls {
230 #[must_use]
237 pub fn outgoing_connection_with_lifecycle<A>(
238 addr: A,
239 server_name: ServerName<'static>,
240 client_config: Arc<rustls::ClientConfig>,
241 settings: ConnectionSettings,
242 ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
243 where
244 A: ToSocketAddrs + Clone + Send + Sync + 'static,
245 {
246 Self::outgoing_connection_with_lifecycle_and_chunk_size(
247 addr,
248 server_name,
249 client_config,
250 settings,
251 DEFAULT_CHUNK_SIZE,
252 )
253 }
254
255 #[must_use]
257 pub fn outgoing_connection_with_lifecycle_and_chunk_size<A>(
258 addr: A,
259 server_name: ServerName<'static>,
260 client_config: Arc<rustls::ClientConfig>,
261 settings: ConnectionSettings,
262 chunk_size: usize,
263 ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
264 where
265 A: ToSocketAddrs + Clone + Send + Sync + 'static,
266 {
267 assert!(chunk_size > 0, "chunk size must be greater than zero");
268 Flow::future_flow(move || {
269 let addr = addr.clone();
270 let server_name = server_name.clone();
271 let client_config = Arc::clone(&client_config);
272 let settings = settings.clone();
273 async move {
274 let handle = Handle::current();
275 retry_tls_client_connect(
276 addr,
277 server_name,
278 client_config,
279 settings,
280 handle,
281 chunk_size,
282 )
283 .await
284 }
285 })
286 }
287}
288
289async fn retry_tls_client_connect<A>(
290 addr: A,
291 server_name: ServerName<'static>,
292 client_config: Arc<rustls::ClientConfig>,
293 settings: ConnectionSettings,
294 handle: Handle,
295 chunk_size: usize,
296) -> StreamResult<Flow<Vec<u8>, Vec<u8>, TlsConnection>>
297where
298 A: ToSocketAddrs + Clone + Send + 'static,
299{
300 let attempts = settings.retry_policy.attempts();
301 for attempt in 1..=attempts {
302 match tls_client_connect_once(
303 addr.clone(),
304 server_name.clone(),
305 Arc::clone(&client_config),
306 &settings,
307 handle.clone(),
308 chunk_size,
309 )
310 .await
311 {
312 Ok(flow) => return Ok(flow),
313 Err(error) if attempt == attempts => {
314 return Err(final_retry_error(error, attempt));
315 }
316 Err(_) => {
317 let delay = settings.retry_policy.backoff_after_attempt(attempt);
318 if !delay.is_zero() {
319 sleep(delay).await;
320 }
321 }
322 }
323 }
324 Err(StreamError::Failed(
325 "connection retry policy had no attempts".into(),
326 ))
327}
328
329async fn tls_client_connect_once<A>(
330 addr: A,
331 server_name: ServerName<'static>,
332 client_config: Arc<rustls::ClientConfig>,
333 settings: &ConnectionSettings,
334 handle: Handle,
335 chunk_size: usize,
336) -> StreamResult<Flow<Vec<u8>, Vec<u8>, TlsConnection>>
337where
338 A: ToSocketAddrs + Send + 'static,
339{
340 let execution = tls_connection_execution(handle);
341 let connect_timeout = settings.connect_timeout;
342 let handshake_timeout = settings.handshake_timeout;
343 let (tls, connection) = execution
344 .run(async move {
345 let tcp =
346 io_with_optional_timeout("TCP connect", connect_timeout, TcpStream::connect(addr))
347 .await?;
348 let connection = TlsConnection {
349 local_addr: tcp.local_addr().map_err(io_error)?,
350 remote_addr: tcp.peer_addr().map_err(io_error)?,
351 };
352 let tls = io_with_optional_timeout(
353 "TLS handshake",
354 handshake_timeout,
355 TlsConnector::from(client_config).connect(server_name, tcp),
356 )
357 .await?;
358 Ok((tls, connection))
359 })
360 .await?;
361 Ok(tls_flow_from_stream_with_execution(
362 tls, connection, execution, chunk_size,
363 ))
364}
365
366async fn io_with_optional_timeout<T, Fut>(
367 operation: &'static str,
368 limit: Option<Duration>,
369 future: Fut,
370) -> StreamResult<T>
371where
372 Fut: Future<Output = std::io::Result<T>>,
373{
374 match limit {
375 Some(duration) => match timeout(duration, future).await {
376 Ok(Ok(value)) => Ok(value),
377 Ok(Err(error)) => Err(io_error(error)),
378 Err(_) => Err(StreamError::Failed(format!(
379 "{operation} timed out after {duration:?}"
380 ))),
381 },
382 None => future.await.map_err(io_error),
383 }
384}
385
386fn final_retry_error(error: StreamError, attempts: usize) -> StreamError {
387 if attempts <= 1 {
388 error
389 } else {
390 StreamError::Failed(format!(
391 "connection establishment failed after {attempts} attempts: {error}"
392 ))
393 }
394}
395
396fn io_error(error: std::io::Error) -> StreamError {
397 StreamError::Failed(error.to_string())
398}
399
/// Returns `multiplier` if it is finite and at least `1.0`; otherwise returns `1.0`.
///
/// Used for the retry backoff factor, so delays never shrink between attempts and never
/// become NaN or infinite because of the factor.
fn sane_multiplier(multiplier: f64) -> f64 {
    Some(multiplier)
        .filter(|candidate| candidate.is_finite() && *candidate >= 1.0)
        .unwrap_or(1.0)
}