use crate::tls::{
TlsConnection, TokioTls, rustls, tls_connection_execution, tls_flow_from_stream_with_execution,
};
use datum::{Flow, StreamCompletion, StreamError, StreamResult};
use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::{TcpStream, ToSocketAddrs};
use tokio::runtime::Handle;
use tokio::time::{sleep, timeout};
use tokio_rustls::TlsConnector;
use tokio_rustls::rustls::pki_types::ServerName;
/// Chunk size handed to the TLS flow when the caller does not choose one (8 KiB).
const DEFAULT_CHUNK_SIZE: usize = 8 * 1024;
/// Default limit on the TCP connect step of connection establishment.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Default limit on the TLS handshake step, applied after TCP is connected.
const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Default wait after the first failed connection attempt.
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(100);
/// Default upper bound on any single wait between attempts.
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_millis(5_000);
/// Default factor by which the wait grows after each further failure.
const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
/// Controls how many times connection establishment is attempted and how long to
/// wait between attempts.
///
/// The wait after failed attempt `n` (1-based) is
/// `initial_backoff * backoff_multiplier^(n - 1)` (exponent capped at 32), clamped to
/// `max_backoff`. No wait follows the final attempt.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryPolicy {
    /// Total number of attempts, including the first; values below 1 are treated as 1.
    pub max_attempts: usize,
    /// Wait after the first failed attempt; zero disables waiting entirely.
    pub initial_backoff: Duration,
    /// Growth factor per failed attempt; non-finite values or values below `1.0` are
    /// treated as `1.0`.
    pub backoff_multiplier: f64,
    /// Upper bound on any single wait; zero disables waiting entirely.
    pub max_backoff: Duration,
}
impl Default for RetryPolicy {
fn default() -> Self {
Self {
max_attempts: 1,
initial_backoff: DEFAULT_INITIAL_BACKOFF,
backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
max_backoff: DEFAULT_MAX_BACKOFF,
}
}
}
impl RetryPolicy {
    /// Creates the default policy; equivalent to [`RetryPolicy::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the total number of connection attempts, including the first one.
    ///
    /// Values below 1 are raised to 1 so that at least one attempt is always made.
    #[must_use]
    pub fn max_attempts(mut self, max_attempts: usize) -> Self {
        self.max_attempts = max_attempts.max(1);
        self
    }

    /// Sets the wait after the first failed attempt.
    ///
    /// A zero value disables waiting between attempts entirely.
    #[must_use]
    pub fn initial_backoff(mut self, initial_backoff: Duration) -> Self {
        self.initial_backoff = initial_backoff;
        self
    }

    /// Sets the factor by which the wait grows after each further failure.
    ///
    /// Non-finite values and values below `1.0` are replaced by `1.0`, which gives a
    /// constant backoff.
    #[must_use]
    pub fn backoff_multiplier(mut self, backoff_multiplier: f64) -> Self {
        self.backoff_multiplier = sane_multiplier(backoff_multiplier);
        self
    }

    /// Sets the upper bound on any single wait between attempts.
    ///
    /// A zero value disables waiting between attempts entirely. `Duration::MAX` is
    /// accepted and effectively leaves the growth uncapped.
    #[must_use]
    pub fn max_backoff(mut self, max_backoff: Duration) -> Self {
        self.max_backoff = max_backoff;
        self
    }

    /// Number of attempts to make. It is never less than 1, even if the public
    /// `max_attempts` field was set to 0 directly.
    fn attempts(&self) -> usize {
        self.max_attempts.max(1)
    }

    /// Wait to apply after failed attempt number `attempt` (1-based).
    ///
    /// Computes `initial_backoff * multiplier^(attempt - 1)` with the exponent capped
    /// at 32, then clamps the result to `max_backoff`. Returns zero when either
    /// `initial_backoff` or `max_backoff` is zero.
    fn backoff_after_attempt(&self, attempt: usize) -> Duration {
        if self.initial_backoff.is_zero() || self.max_backoff.is_zero() {
            return Duration::ZERO;
        }
        let multiplier = sane_multiplier(self.backoff_multiplier);
        // Capping the exponent bounds the float work; the cap below clamps any excess.
        let exponent = attempt.saturating_sub(1).min(32) as i32;
        let delay_secs = self.initial_backoff.as_secs_f64() * multiplier.powi(exponent);
        // Compare in f64 space and hand back `max_backoff` itself once the cap is reached.
        // Converting the cap through `Duration::from_secs_f64` would panic for caps near
        // `Duration::MAX` (the f64 rounds up to 2^64 seconds), and it can also lose precision.
        // Below the cap, `delay_secs` is finite, non-negative and strictly smaller than a
        // representable duration, so the conversion cannot overflow.
        if delay_secs < self.max_backoff.as_secs_f64() {
            Duration::from_secs_f64(delay_secs)
        } else {
            self.max_backoff
        }
    }
}
/// Timeouts and retry behavior used when establishing an outgoing TLS connection.
#[derive(Debug, Clone, PartialEq)]
pub struct ConnectionSettings {
    /// Limit on the TCP connect step; `None` applies no limit from this module.
    pub connect_timeout: Option<Duration>,
    /// Limit on the TLS handshake step; `None` applies no limit from this module.
    pub handshake_timeout: Option<Duration>,
    /// How many connection attempts to make and how long to wait between them.
    pub retry_policy: RetryPolicy,
}
impl Default for ConnectionSettings {
    /// Both timeouts enabled with their module defaults (`DEFAULT_CONNECT_TIMEOUT`,
    /// `DEFAULT_HANDSHAKE_TIMEOUT`), combined with the default [`RetryPolicy`].
    fn default() -> Self {
        let retry_policy = RetryPolicy::default();
        Self {
            retry_policy,
            handshake_timeout: Some(DEFAULT_HANDSHAKE_TIMEOUT),
            connect_timeout: Some(DEFAULT_CONNECT_TIMEOUT),
        }
    }
}
impl ConnectionSettings {
    /// Creates the default settings; same as [`ConnectionSettings::default`].
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Bounds the TCP connect step by `connect_timeout`.
    #[must_use]
    pub fn connect_timeout(self, connect_timeout: Duration) -> Self {
        Self {
            connect_timeout: Some(connect_timeout),
            ..self
        }
    }

    /// Removes the limit on the TCP connect step.
    #[must_use]
    pub fn without_connect_timeout(self) -> Self {
        Self {
            connect_timeout: None,
            ..self
        }
    }

    /// Bounds the TLS handshake step by `handshake_timeout`.
    #[must_use]
    pub fn handshake_timeout(self, handshake_timeout: Duration) -> Self {
        Self {
            handshake_timeout: Some(handshake_timeout),
            ..self
        }
    }

    /// Removes the limit on the TLS handshake step.
    #[must_use]
    pub fn without_handshake_timeout(self) -> Self {
        Self {
            handshake_timeout: None,
            ..self
        }
    }

    /// Replaces the retry policy used for connection establishment.
    #[must_use]
    pub fn retry_policy(self, retry_policy: RetryPolicy) -> Self {
        Self {
            retry_policy,
            ..self
        }
    }
}
/// Stateless namespace for building and adapting byte-stream connection flows.
///
/// It holds no data; everything is exposed through associated functions.
pub struct Connection;
impl Connection {
#[must_use]
pub fn tls_client<A>(
addr: A,
server_name: ServerName<'static>,
client_config: Arc<rustls::ClientConfig>,
settings: ConnectionSettings,
) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
where
A: ToSocketAddrs + Clone + Send + Sync + 'static,
{
TokioTls::outgoing_connection_with_lifecycle(addr, server_name, client_config, settings)
}
#[must_use]
pub fn graceful_shutdown<Mat>(flow: Flow<Vec<u8>, Vec<u8>, Mat>) -> Flow<Vec<u8>, Vec<u8>, Mat>
where
Mat: Send + 'static,
{
flow.graceful_shutdown_on_upstream_finish()
}
#[must_use]
pub fn half_close<Mat>(flow: Flow<Vec<u8>, Vec<u8>, Mat>) -> Flow<Vec<u8>, Vec<u8>, Mat>
where
Mat: Send + 'static,
{
Self::graceful_shutdown(flow)
}
}
/// Lifecycle adapters for connection flows.
///
/// `Mat` is the materialized-value type of the implementing flow.
pub trait ConnectionLifecycleExt<Mat> {
    /// Intended to adapt `self` so that upstream completion shuts the connection down
    /// gracefully, as the method name says. The exact behavior is defined by each
    /// implementation.
    #[must_use]
    fn graceful_shutdown_on_upstream_finish(self) -> Self;

    /// Alias for [`ConnectionLifecycleExt::graceful_shutdown_on_upstream_finish`].
    #[must_use]
    fn half_close_on_upstream_finish(self) -> Self
    where
        Self: Sized,
    {
        Self::graceful_shutdown_on_upstream_finish(self)
    }
}
/// Lifecycle adapters for byte flows (`Vec<u8>` in, `Vec<u8>` out) with any
/// `Send + 'static` materialized value.
impl<Mat> ConnectionLifecycleExt<Mat> for Flow<Vec<u8>, Vec<u8>, Mat>
where
    Mat: Send + 'static,
{
    // NOTE(review): this returns the flow unchanged, so it adds no shutdown behavior of
    // its own. Presumably the underlying TLS flow already shuts down its write side when
    // upstream finishes; TODO confirm that, or implement the adapter here.
    fn graceful_shutdown_on_upstream_finish(self) -> Self {
        self
    }
}
impl TokioTls {
    /// Builds an outgoing TLS client flow that uses `DEFAULT_CHUNK_SIZE`.
    ///
    /// See [`TokioTls::outgoing_connection_with_lifecycle_and_chunk_size`] for details.
    #[must_use]
    pub fn outgoing_connection_with_lifecycle<A>(
        addr: A,
        server_name: ServerName<'static>,
        client_config: Arc<rustls::ClientConfig>,
        settings: ConnectionSettings,
    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        let chunk_size = DEFAULT_CHUNK_SIZE;
        Self::outgoing_connection_with_lifecycle_and_chunk_size(
            addr,
            server_name,
            client_config,
            settings,
            chunk_size,
        )
    }

    /// Builds an outgoing TLS client flow, wrapped in `Flow::future_flow`.
    ///
    /// Each invocation of the factory clones the captured connection parameters. The
    /// returned future then connects with retries, as configured by `settings`.
    ///
    /// # Panics
    ///
    /// Panics immediately if `chunk_size` is zero. The connection future also calls
    /// `Handle::current()`, which panics if it is polled outside a Tokio runtime.
    #[must_use]
    pub fn outgoing_connection_with_lifecycle_and_chunk_size<A>(
        addr: A,
        server_name: ServerName<'static>,
        client_config: Arc<rustls::ClientConfig>,
        settings: ConnectionSettings,
        chunk_size: usize,
    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        assert!(chunk_size > 0, "chunk size must be greater than zero");
        Flow::future_flow(move || {
            // Every invocation works on its own copies of the captured parameters.
            let (addr, server_name, settings) =
                (addr.clone(), server_name.clone(), settings.clone());
            let client_config = Arc::clone(&client_config);
            async move {
                // Looked up when the future is first polled, not when the closure runs.
                let runtime = Handle::current();
                retry_tls_client_connect(
                    addr,
                    server_name,
                    client_config,
                    settings,
                    runtime,
                    chunk_size,
                )
                .await
            }
        })
    }
}
/// Makes up to `settings.retry_policy.attempts()` connection attempts.
///
/// Between failures it sleeps for the policy's backoff, skipping the sleep when the
/// backoff is zero. It returns the first successful flow. If every attempt fails, it
/// returns the last error, passed through `final_retry_error`; errors from earlier
/// attempts are discarded.
async fn retry_tls_client_connect<A>(
    addr: A,
    server_name: ServerName<'static>,
    client_config: Arc<rustls::ClientConfig>,
    settings: ConnectionSettings,
    handle: Handle,
    chunk_size: usize,
) -> StreamResult<Flow<Vec<u8>, Vec<u8>, TlsConnection>>
where
    A: ToSocketAddrs + Clone + Send + 'static,
{
    let policy = &settings.retry_policy;
    let total_attempts = policy.attempts();
    let mut attempt = 0;
    while attempt < total_attempts {
        attempt += 1;
        let outcome = tls_client_connect_once(
            addr.clone(),
            server_name.clone(),
            Arc::clone(&client_config),
            &settings,
            handle.clone(),
            chunk_size,
        )
        .await;
        let error = match outcome {
            Ok(flow) => return Ok(flow),
            Err(error) => error,
        };
        if attempt == total_attempts {
            return Err(final_retry_error(error, attempt));
        }
        let pause = policy.backoff_after_attempt(attempt);
        if !pause.is_zero() {
            sleep(pause).await;
        }
    }
    // Only reachable if the policy yields zero attempts, which `attempts()` rules out;
    // kept as a defensive fallback.
    Err(StreamError::Failed(
        "connection retry policy had no attempts".into(),
    ))
}
/// Performs a single connection attempt and wraps the resulting stream in a flow.
///
/// Inside the execution returned by `tls_connection_execution(handle)`, it:
/// 1. runs `TcpStream::connect(addr)`, bounded by `settings.connect_timeout` (per tokio's
///    docs, that future resolves `addr` and tries each resulting address in turn);
/// 2. records the socket's local and peer addresses as a [`TlsConnection`];
/// 3. runs the TLS handshake for `server_name`, bounded by `settings.handshake_timeout`.
///
/// The established stream, the recorded addresses and the same execution are then
/// handed to `tls_flow_from_stream_with_execution` along with `chunk_size`.
async fn tls_client_connect_once<A>(
    addr: A,
    server_name: ServerName<'static>,
    client_config: Arc<rustls::ClientConfig>,
    settings: &ConnectionSettings,
    handle: Handle,
    chunk_size: usize,
) -> StreamResult<Flow<Vec<u8>, Vec<u8>, TlsConnection>>
where
    A: ToSocketAddrs + Send + 'static,
{
    let execution = tls_connection_execution(handle);
    // Copy the limits out so the `async move` block below does not borrow `settings`.
    let (tcp_limit, tls_limit) = (settings.connect_timeout, settings.handshake_timeout);
    let (stream, endpoints) = execution
        .run(async move {
            let socket =
                io_with_optional_timeout("TCP connect", tcp_limit, TcpStream::connect(addr))
                    .await?;
            let endpoints = TlsConnection {
                local_addr: socket.local_addr().map_err(io_error)?,
                remote_addr: socket.peer_addr().map_err(io_error)?,
            };
            let connector = TlsConnector::from(client_config);
            let stream = io_with_optional_timeout(
                "TLS handshake",
                tls_limit,
                connector.connect(server_name, socket),
            )
            .await?;
            Ok((stream, endpoints))
        })
        .await?;
    let flow = tls_flow_from_stream_with_execution(stream, endpoints, execution, chunk_size);
    Ok(flow)
}
/// Awaits an I/O future, optionally bounded by `limit`, and maps the outcome into
/// `StreamResult`.
///
/// An I/O error is converted with `io_error`. If the limit elapses first, the result is
/// `StreamError::Failed("<operation> timed out after <limit>")`. With `limit == None`
/// the future is awaited with no limit from this function.
async fn io_with_optional_timeout<T, Fut>(
    operation: &'static str,
    limit: Option<Duration>,
    future: Fut,
) -> StreamResult<T>
where
    Fut: Future<Output = std::io::Result<T>>,
{
    let io_result = if let Some(duration) = limit {
        match timeout(duration, future).await {
            Ok(completed) => completed,
            Err(_elapsed) => {
                return Err(StreamError::Failed(format!(
                    "{operation} timed out after {duration:?}"
                )));
            }
        }
    } else {
        future.await
    };
    io_result.map_err(io_error)
}
/// Builds the error reported once every attempt has failed.
///
/// With more than one attempt, the last error is wrapped in a message that states the
/// attempt count. With a single attempt (or fewer), the error is returned unchanged.
fn final_retry_error(error: StreamError, attempts: usize) -> StreamError {
    if attempts > 1 {
        return StreamError::Failed(format!(
            "connection establishment failed after {attempts} attempts: {error}"
        ));
    }
    error
}
/// Converts an I/O error into `StreamError::Failed`, carrying the error's `Display` text.
fn io_error(error: std::io::Error) -> StreamError {
    let message = error.to_string();
    StreamError::Failed(message)
}
/// Normalizes a backoff multiplier: finite values of at least `1.0` pass through, and
/// everything else (NaN, ±infinity, values below `1.0`) becomes `1.0`.
fn sane_multiplier(multiplier: f64) -> f64 {
    // NaN and the infinities fail `is_finite`; anything below 1.0 would shrink the delay.
    if !multiplier.is_finite() || multiplier < 1.0 {
        return 1.0;
    }
    multiplier
}