1use crate::tls::{TlsConnection, TokioTls, rustls, tls_flow_from_stream};
10use datum::{Flow, StreamCompletion, StreamError, StreamResult};
11use std::future::Future;
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::net::{TcpStream, ToSocketAddrs};
15use tokio::runtime::Handle;
16use tokio::time::{sleep, timeout};
17use tokio_rustls::TlsConnector;
18use tokio_rustls::rustls::pki_types::ServerName;
19
/// Chunk size handed to `tls_flow_from_stream` when the caller does not pick one.
const DEFAULT_CHUNK_SIZE: usize = 8 * 1024;
/// Default limit for establishing the TCP connection.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Default limit for completing the TLS handshake.
const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Default delay after the first failed connection attempt.
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(100);
/// Default upper bound for any single retry delay.
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_millis(5_000);
/// Default growth factor applied to the delay after each further failure.
const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
26
/// How many times a connection is attempted and how long to wait in between.
///
/// Fields are public. The effective values are still sanitized at use time:
/// a `max_attempts` of `0` counts as `1`, and a non-finite or `< 1.0`
/// multiplier counts as `1.0`.
#[derive(Clone, Debug, PartialEq)]
pub struct RetryPolicy {
    /// Total number of attempts, including the first one.
    pub max_attempts: usize,
    /// Delay after the first failed attempt; zero disables waiting.
    pub initial_backoff: Duration,
    /// Factor by which the delay grows after each further failed attempt.
    pub backoff_multiplier: f64,
    /// Upper bound for any single delay; zero disables waiting.
    pub max_backoff: Duration,
}
39
40impl Default for RetryPolicy {
41 fn default() -> Self {
42 Self {
43 max_attempts: 1,
44 initial_backoff: DEFAULT_INITIAL_BACKOFF,
45 backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
46 max_backoff: DEFAULT_MAX_BACKOFF,
47 }
48 }
49}
50
51impl RetryPolicy {
52 #[must_use]
54 pub fn new() -> Self {
55 Self::default()
56 }
57
58 #[must_use]
59 pub fn max_attempts(mut self, max_attempts: usize) -> Self {
60 self.max_attempts = max_attempts.max(1);
61 self
62 }
63
64 #[must_use]
65 pub fn initial_backoff(mut self, initial_backoff: Duration) -> Self {
66 self.initial_backoff = initial_backoff;
67 self
68 }
69
70 #[must_use]
71 pub fn backoff_multiplier(mut self, backoff_multiplier: f64) -> Self {
72 self.backoff_multiplier = sane_multiplier(backoff_multiplier);
73 self
74 }
75
76 #[must_use]
77 pub fn max_backoff(mut self, max_backoff: Duration) -> Self {
78 self.max_backoff = max_backoff;
79 self
80 }
81
82 fn attempts(&self) -> usize {
83 self.max_attempts.max(1)
84 }
85
86 fn backoff_after_attempt(&self, attempt: usize) -> Duration {
87 if self.initial_backoff.is_zero() || self.max_backoff.is_zero() {
88 return Duration::ZERO;
89 }
90
91 let multiplier = sane_multiplier(self.backoff_multiplier);
92 let exponent = attempt.saturating_sub(1).min(32) as i32;
93 let delay_secs = self.initial_backoff.as_secs_f64() * multiplier.powi(exponent);
94 let capped_secs = delay_secs.min(self.max_backoff.as_secs_f64());
95 Duration::from_secs_f64(capped_secs)
96 }
97}
98
99#[derive(Debug, Clone, PartialEq)]
101pub struct ConnectionSettings {
102 pub connect_timeout: Option<Duration>,
103 pub handshake_timeout: Option<Duration>,
104 pub retry_policy: RetryPolicy,
105}
106
107impl Default for ConnectionSettings {
108 fn default() -> Self {
109 Self {
110 connect_timeout: Some(DEFAULT_CONNECT_TIMEOUT),
111 handshake_timeout: Some(DEFAULT_HANDSHAKE_TIMEOUT),
112 retry_policy: RetryPolicy::default(),
113 }
114 }
115}
116
117impl ConnectionSettings {
118 #[must_use]
120 pub fn new() -> Self {
121 Self::default()
122 }
123
124 #[must_use]
125 pub fn connect_timeout(mut self, connect_timeout: Duration) -> Self {
126 self.connect_timeout = Some(connect_timeout);
127 self
128 }
129
130 #[must_use]
131 pub fn without_connect_timeout(mut self) -> Self {
132 self.connect_timeout = None;
133 self
134 }
135
136 #[must_use]
137 pub fn handshake_timeout(mut self, handshake_timeout: Duration) -> Self {
138 self.handshake_timeout = Some(handshake_timeout);
139 self
140 }
141
142 #[must_use]
143 pub fn without_handshake_timeout(mut self) -> Self {
144 self.handshake_timeout = None;
145 self
146 }
147
148 #[must_use]
149 pub fn retry_policy(mut self, retry_policy: RetryPolicy) -> Self {
150 self.retry_policy = retry_policy;
151 self
152 }
153}
154
155pub struct Connection;
157
158impl Connection {
159 #[must_use]
161 pub fn tls_client<A>(
162 addr: A,
163 server_name: ServerName<'static>,
164 client_config: Arc<rustls::ClientConfig>,
165 settings: ConnectionSettings,
166 ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
167 where
168 A: ToSocketAddrs + Clone + Send + Sync + 'static,
169 {
170 TokioTls::outgoing_connection_with_lifecycle(addr, server_name, client_config, settings)
171 }
172
173 #[must_use]
180 pub fn graceful_shutdown<Mat>(flow: Flow<Vec<u8>, Vec<u8>, Mat>) -> Flow<Vec<u8>, Vec<u8>, Mat>
181 where
182 Mat: Send + 'static,
183 {
184 flow.graceful_shutdown_on_upstream_finish()
185 }
186
187 #[must_use]
189 pub fn half_close<Mat>(flow: Flow<Vec<u8>, Vec<u8>, Mat>) -> Flow<Vec<u8>, Vec<u8>, Mat>
190 where
191 Mat: Send + 'static,
192 {
193 Self::graceful_shutdown(flow)
194 }
195}
196
/// Lifecycle adapters for connection flows.
///
/// `Mat` is not used by the methods themselves. It lets implementors be
/// written per materialized-value type.
pub trait ConnectionLifecycleExt<Mat> {
    /// Contract, as named: return a flow that shuts down gracefully once its
    /// upstream finishes. The mechanics are left to the implementor.
    #[must_use]
    fn graceful_shutdown_on_upstream_finish(self) -> Self;

    /// Alias that forwards to
    /// [`ConnectionLifecycleExt::graceful_shutdown_on_upstream_finish`].
    #[must_use]
    fn half_close_on_upstream_finish(self) -> Self
    where
        Self: Sized,
    {
        Self::graceful_shutdown_on_upstream_finish(self)
    }
}
217
218impl<Mat> ConnectionLifecycleExt<Mat> for Flow<Vec<u8>, Vec<u8>, Mat>
219where
220 Mat: Send + 'static,
221{
222 fn graceful_shutdown_on_upstream_finish(self) -> Self {
223 self
224 }
225}
226
227impl TokioTls {
228 #[must_use]
235 pub fn outgoing_connection_with_lifecycle<A>(
236 addr: A,
237 server_name: ServerName<'static>,
238 client_config: Arc<rustls::ClientConfig>,
239 settings: ConnectionSettings,
240 ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
241 where
242 A: ToSocketAddrs + Clone + Send + Sync + 'static,
243 {
244 Self::outgoing_connection_with_lifecycle_and_chunk_size(
245 addr,
246 server_name,
247 client_config,
248 settings,
249 DEFAULT_CHUNK_SIZE,
250 )
251 }
252
253 #[must_use]
255 pub fn outgoing_connection_with_lifecycle_and_chunk_size<A>(
256 addr: A,
257 server_name: ServerName<'static>,
258 client_config: Arc<rustls::ClientConfig>,
259 settings: ConnectionSettings,
260 chunk_size: usize,
261 ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
262 where
263 A: ToSocketAddrs + Clone + Send + Sync + 'static,
264 {
265 assert!(chunk_size > 0, "chunk size must be greater than zero");
266 Flow::future_flow(move || {
267 let addr = addr.clone();
268 let server_name = server_name.clone();
269 let client_config = Arc::clone(&client_config);
270 let settings = settings.clone();
271 async move {
272 let handle = Handle::current();
273 retry_tls_client_connect(
274 addr,
275 server_name,
276 client_config,
277 settings,
278 handle,
279 chunk_size,
280 )
281 .await
282 }
283 })
284 }
285}
286
287async fn retry_tls_client_connect<A>(
288 addr: A,
289 server_name: ServerName<'static>,
290 client_config: Arc<rustls::ClientConfig>,
291 settings: ConnectionSettings,
292 handle: Handle,
293 chunk_size: usize,
294) -> StreamResult<Flow<Vec<u8>, Vec<u8>, TlsConnection>>
295where
296 A: ToSocketAddrs + Clone + Send + 'static,
297{
298 let attempts = settings.retry_policy.attempts();
299 for attempt in 1..=attempts {
300 match tls_client_connect_once(
301 addr.clone(),
302 server_name.clone(),
303 Arc::clone(&client_config),
304 &settings,
305 handle.clone(),
306 chunk_size,
307 )
308 .await
309 {
310 Ok(flow) => return Ok(flow),
311 Err(error) if attempt == attempts => {
312 return Err(final_retry_error(error, attempt));
313 }
314 Err(_) => {
315 let delay = settings.retry_policy.backoff_after_attempt(attempt);
316 if !delay.is_zero() {
317 sleep(delay).await;
318 }
319 }
320 }
321 }
322 Err(StreamError::Failed(
323 "connection retry policy had no attempts".into(),
324 ))
325}
326
327async fn tls_client_connect_once<A>(
328 addr: A,
329 server_name: ServerName<'static>,
330 client_config: Arc<rustls::ClientConfig>,
331 settings: &ConnectionSettings,
332 handle: Handle,
333 chunk_size: usize,
334) -> StreamResult<Flow<Vec<u8>, Vec<u8>, TlsConnection>>
335where
336 A: ToSocketAddrs + Send + 'static,
337{
338 let tcp = io_with_optional_timeout(
339 "TCP connect",
340 settings.connect_timeout,
341 TcpStream::connect(addr),
342 )
343 .await?;
344 let connection = TlsConnection {
345 local_addr: tcp.local_addr().map_err(io_error)?,
346 remote_addr: tcp.peer_addr().map_err(io_error)?,
347 };
348 let tls = io_with_optional_timeout(
349 "TLS handshake",
350 settings.handshake_timeout,
351 TlsConnector::from(client_config).connect(server_name, tcp),
352 )
353 .await?;
354 Ok(tls_flow_from_stream(tls, connection, handle, chunk_size))
355}
356
357async fn io_with_optional_timeout<T, Fut>(
358 operation: &'static str,
359 limit: Option<Duration>,
360 future: Fut,
361) -> StreamResult<T>
362where
363 Fut: Future<Output = std::io::Result<T>>,
364{
365 match limit {
366 Some(duration) => match timeout(duration, future).await {
367 Ok(Ok(value)) => Ok(value),
368 Ok(Err(error)) => Err(io_error(error)),
369 Err(_) => Err(StreamError::Failed(format!(
370 "{operation} timed out after {duration:?}"
371 ))),
372 },
373 None => future.await.map_err(io_error),
374 }
375}
376
377fn final_retry_error(error: StreamError, attempts: usize) -> StreamError {
378 if attempts <= 1 {
379 error
380 } else {
381 StreamError::Failed(format!(
382 "connection establishment failed after {attempts} attempts: {error}"
383 ))
384 }
385}
386
387fn io_error(error: std::io::Error) -> StreamError {
388 StreamError::Failed(error.to_string())
389}
390
/// Returns `multiplier` if it is finite and at least `1.0`, otherwise `1.0`.
///
/// NaN and infinities fail `is_finite`, so they also map to `1.0`.
fn sane_multiplier(multiplier: f64) -> f64 {
    Some(multiplier)
        .filter(|m| m.is_finite() && *m >= 1.0)
        .unwrap_or(1.0)
}