pub use tokio_rustls::rustls;
use crate::async_carrier::{self, AsyncCommandSender, DemandBatcher};
use datum::{Flow, Keep, NotUsed, Sink, Source, StreamCompletion, StreamError, StreamResult};
use std::net::SocketAddr;
use std::sync::{Arc, Mutex, atomic::AtomicUsize, mpsc as std_mpsc};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio::runtime::Handle;
use tokio::sync::{mpsc, watch};
use tokio::task::JoinHandle;
use tokio_rustls::rustls::pki_types::ServerName;
use tokio_rustls::{TlsAcceptor, TlsConnector};
/// Default `chunk_size` used by `outgoing_connection_default` and `bind_default`: the size of
/// the read buffer and of each byte chunk emitted downstream (the last chunk before EOF may be
/// shorter).
const DEFAULT_CHUNK_SIZE: usize = 8192;
/// Read-side window used by `single_use_tls_halves`: it sizes the source's `DemandBatcher`,
/// and `start_tls_carrier` uses it plus one as the capacity of the bounded receive channel.
const DEFAULT_RECEIVE_BUFFER: usize = 64;
/// Counter handed to `async_carrier::sharded_tokio_carrier_execution` by
/// `tls_connection_execution`; presumably tracks live TLS connections for shard selection —
/// confirm in `async_carrier`.
static ACTIVE_TLS_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
/// Stream of bytes read from a TLS connection, one `Vec<u8>` per chunk.
pub type TlsByteSource = Source<Vec<u8>, NotUsed>;
/// Sink that writes byte chunks to a TLS connection and shuts down the write side when it
/// completes.
pub type TlsByteSink = Sink<Vec<u8>, StreamCompletion<NotUsed>>;
/// Answer to one pull: an element, end-of-stream, or a failure.
///
/// Used both for byte chunks sent by the carrier task and for connections sent by the bind task.
enum DemandResponse<T> {
    /// The next element.
    Item(T),
    /// No more elements will follow.
    Complete,
    /// The producer failed.
    Error(StreamError),
}
/// Per-materialization state of the TLS byte source.
struct ReadResource {
    /// Bounded channel the carrier task fills with chunks, completion, or an error.
    receiver: std_mpsc::Receiver<DemandResponse<Vec<u8>>>,
    /// Handle used to request more demand and to close the read side.
    carrier: TlsCarrier,
    /// Tracks consumption; `record_consumed` yields the next demand to request, if any.
    demand: DemandBatcher,
    /// Response picked up while the initial demand request was failing. It is delivered before
    /// `receiver` is read again.
    pending: Option<DemandResponse<Vec<u8>>>,
}
impl Drop for ReadResource {
    /// Runs whenever the source's resource is released. The carrier therefore always
    /// receives a best-effort request to stop reading.
    fn drop(&mut self) {
        let carrier = &self.carrier;
        carrier.close_read();
    }
}
/// Requests sent from the source/sink handles to the carrier task
/// (handled by `handle_tls_carrier_command`).
enum TlsCarrierCommand {
    /// Allow the task to deliver this many more chunks.
    Demand(usize),
    /// Write one chunk, then flush.
    SendOne(Vec<u8>),
    /// Write every chunk, then flush once.
    SendBatch(Vec<Vec<u8>>),
    /// Stop reading from the stream.
    CloseRead,
    /// Flush and shut down the writer, then report the outcome on `ack`.
    CloseWrite {
        ack: std_mpsc::Sender<StreamResult<()>>,
    },
}
/// Cloneable handle to one connection's carrier task.
///
/// The source and sink halves each hold a clone.
#[derive(Clone)]
struct TlsCarrier {
    inner: Arc<TlsCarrierInner>,
}
/// Shared state behind every `TlsCarrier` clone.
struct TlsCarrierInner {
    /// Command queue into the carrier task.
    commands: AsyncCommandSender<TlsCarrierCommand>,
    /// Errors reported by the task (read and write failures alike); drained by
    /// `TlsCarrier::check_send_error`.
    send_errors: Mutex<std_mpsc::Receiver<StreamError>>,
    /// The spawned carrier task. It is aborted when this value is dropped.
    task: Mutex<Option<JoinHandle<()>>>,
    /// Held only to keep the execution alive as long as the carrier; presumably it owns a
    /// runtime shard or slot — confirm in `async_carrier`.
    _execution: async_carrier::ShardedTokioCarrierExecution,
}
impl Drop for TlsCarrierInner {
    /// Aborts the carrier task once the last `TlsCarrier` handle is gone.
    ///
    /// `&mut self` gives exclusive access, so `Mutex::get_mut` is used instead of locking.
    /// A poisoned mutex is recovered instead of unwrapped. Panicking inside `drop` while
    /// already unwinding would abort the process, and the task must be aborted either way.
    fn drop(&mut self) {
        let task = self
            .task
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .take();
        if let Some(task) = task {
            task.abort();
        }
    }
}
impl TlsCarrier {
    /// Asks the carrier task to stop reading.
    ///
    /// Best effort: the command is offered with `try_send` and any failure is ignored. This
    /// is called from `ReadResource::drop`, where nothing could be reported anyway.
    fn close_read(&self) {
        let _ = self.inner.commands.try_send(TlsCarrierCommand::CloseRead);
    }

    /// Grants the carrier task `demand` more chunks to deliver.
    ///
    /// # Errors
    /// Returns the command channel's error if the request cannot be delivered.
    fn request_demand(&self, demand: usize) -> StreamResult<()> {
        self.inner
            .commands
            .send_or_blocking(TlsCarrierCommand::Demand(demand))
    }

    /// Queues `items` to be written and then flushed as one batch.
    ///
    /// The write runs asynchronously in the carrier task. A failure surfaces on a later call
    /// through `check_send_error`.
    ///
    /// # Errors
    /// Fails with a previously reported carrier error, or if the command cannot be delivered.
    fn send_items(&self, items: Vec<Vec<u8>>) -> StreamResult<()> {
        self.check_send_error()?;
        self.inner
            .commands
            .send_or_blocking(TlsCarrierCommand::SendBatch(items))
            .map_err(|error| StreamError::Failed(format!("TLS send batch failed: {error:?}")))
    }

    /// Queues one chunk to be written and flushed.
    ///
    /// # Errors
    /// Fails with a previously reported carrier error, or if the command cannot be delivered.
    fn send_one(&self, item: Vec<u8>) -> StreamResult<()> {
        self.check_send_error()?;
        self.inner
            .commands
            .send_or_blocking(TlsCarrierCommand::SendOne(item))
            .map_err(|error| StreamError::Failed(format!("TLS send failed: {error:?}")))
    }

    /// Flushes and shuts down the write side, waiting for the carrier task's acknowledgement.
    ///
    /// # Errors
    /// Returns one of:
    /// - any error reported by the carrier;
    /// - the flush/shutdown error;
    /// - `StreamError::AbruptTermination` if the acknowledgement sender is dropped without
    ///   replying.
    fn close_write(&self) -> StreamResult<()> {
        self.check_send_error()?;
        let (ack_sender, ack_receiver) = std_mpsc::channel();
        if self
            .inner
            .commands
            .send_or_blocking(TlsCarrierCommand::CloseWrite { ack: ack_sender })
            .is_err()
        {
            // A failed send presumably means the carrier task has already stopped, so there is
            // no writer left to shut down. The task may have reported its failure after the
            // check above. Surface that failure instead of reporting a clean close.
            return self.check_send_error();
        }
        match ack_receiver.recv() {
            Ok(result) => result?,
            Err(_) => return Err(abrupt_termination()),
        }
        self.check_send_error()
    }

    /// Returns (and consumes) the next error reported by the carrier task, if any.
    ///
    /// Read and write failures share this channel. An empty or disconnected channel counts as
    /// "no error".
    fn check_send_error(&self) -> StreamResult<()> {
        let errors = self
            .inner
            .send_errors
            .lock()
            .expect("TLS carrier send error receiver poisoned");
        match errors.try_recv() {
            Ok(error) => Err(error),
            Err(_) => Ok(()),
        }
    }
}
/// Per-materialization state of the TLS byte sink.
struct SendResource {
    /// Handle used to send chunks and to close the write side.
    carrier: TlsCarrier,
    /// Chunks buffered until `batch_size` is reached. Unused when `batch_size <= 1`.
    pending: Vec<Vec<u8>>,
    /// Number of chunks per `SendBatch` command; `<= 1` sends each chunk immediately.
    batch_size: usize,
}
/// Per-materialization state of the bind source.
struct BindResource {
    /// Each pull sends a fresh reply channel to the accept task through this queue.
    demands: mpsc::Sender<std_mpsc::Sender<DemandResponse<TlsIncomingConnection>>>,
    /// Set to `true` to ask the accept task to stop.
    cancel: watch::Sender<bool>,
    /// The spawned `run_tls_bind_task`.
    task: JoinHandle<()>,
}
impl Drop for BindResource {
    /// Stops the accept task. It first signals cancellation, which the task watches while
    /// waiting for demand, accepting, and handshaking. It then aborts the task outright.
    fn drop(&mut self) {
        self.cancel.send(true).ok();
        self.task.abort();
    }
}
/// Converts an I/O error into a stream failure carrying the error's display text.
fn io_error(error: std::io::Error) -> StreamError {
    let message = error.to_string();
    StreamError::Failed(message)
}
/// Error used when a peer channel disappears without delivering a reply.
fn abrupt_termination() -> StreamError {
    StreamError::AbruptTermination
}
/// Socket addresses of an established TLS connection, taken from its underlying TCP stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TlsConnection {
    /// Address of the local end of the socket.
    pub local_addr: SocketAddr,
    /// Address of the peer.
    pub remote_addr: SocketAddr,
}
impl TlsConnection {
    /// Returns the local socket address (the `local_addr` field).
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr, .. } = *self;
        local_addr
    }
    /// Returns the peer socket address (the `remote_addr` field).
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        let Self { remote_addr, .. } = *self;
        remote_addr
    }
}
/// Address a TLS server socket is listening on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TlsBinding {
    /// The bound local address, as reported by the listener.
    pub local_addr: SocketAddr,
}
impl TlsBinding {
    /// Returns the bound local address (the `local_addr` field).
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr } = *self;
        local_addr
    }
}
/// A server-side TLS connection yielded by `TokioTls::bind`.
///
/// It carries the connection's addresses plus single-use byte source and sink halves.
pub struct TlsIncomingConnection {
    connection: TlsConnection,
    source: TlsByteSource,
    sink: TlsByteSink,
}
impl TlsIncomingConnection {
    /// Returns the local socket address of this connection.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        self.connection().local_addr()
    }
    /// Returns the peer socket address of this connection.
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        self.connection().remote_addr()
    }
    /// Returns both addresses as a [`TlsConnection`].
    #[must_use]
    pub fn connection(&self) -> TlsConnection {
        let Self { connection, .. } = self;
        *connection
    }
    /// Splits the connection into its byte source (reads) and byte sink (writes).
    #[must_use]
    pub fn into_parts(self) -> (TlsByteSource, TlsByteSink) {
        let Self { source, sink, .. } = self;
        (source, sink)
    }
    /// Joins both halves into one flow via `Flow::from_sink_and_source_coupled`: elements fed
    /// into the flow are written, and bytes read from the peer are emitted.
    #[must_use]
    pub fn into_flow(self) -> Flow<Vec<u8>, Vec<u8>, NotUsed> {
        let (source, sink) = self.into_parts();
        Flow::from_sink_and_source_coupled(sink, source).map_materialized_value(|_| NotUsed)
    }
}
/// Entry points for TLS client connections and TLS servers built on Tokio and rustls.
pub struct TokioTls;
/// Short alias for [`TokioTls`].
pub type Tls = TokioTls;
impl TokioTls {
    /// Client flow: connects to `addr` over TCP, performs a TLS handshake for `server_name`
    /// with `client_config`, then exchanges bytes in chunks of `chunk_size`.
    ///
    /// The connect-and-handshake future is created by the factory passed to
    /// `Flow::future_flow`. It picks up the runtime with `Handle::current()` when polled.
    ///
    /// # Panics
    /// Panics if `chunk_size` is zero.
    #[must_use]
    pub fn outgoing_connection<A>(
        addr: A,
        server_name: ServerName<'static>,
        client_config: Arc<rustls::ClientConfig>,
        chunk_size: usize,
    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        assert!(chunk_size > 0, "chunk size must be greater than zero");
        Flow::future_flow(move || {
            // Clone the captures so the factory keeps its own copies (it may run more than once).
            let target = addr.clone();
            let name = server_name.clone();
            let config = Arc::clone(&client_config);
            async move {
                let runtime = Handle::current();
                tls_client_connect(target, name, config, runtime, chunk_size).await
            }
        })
    }

    /// [`TokioTls::outgoing_connection`] with the default chunk size (8192 bytes).
    #[must_use]
    pub fn outgoing_connection_default<A>(
        addr: A,
        server_name: ServerName<'static>,
        client_config: Arc<rustls::ClientConfig>,
    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        TokioTls::outgoing_connection(addr, server_name, client_config, DEFAULT_CHUNK_SIZE)
    }

    /// Server source: binds a TCP listener on `addr` and yields one [`TlsIncomingConnection`]
    /// per pull, after a TLS handshake using `server_config`.
    ///
    /// The listener is bound lazily, inside the future created for `Source::lazy_future_source`.
    /// Bind failures are reported as stream errors.
    ///
    /// # Panics
    /// Panics if `chunk_size` is zero.
    #[must_use]
    pub fn bind<A>(
        addr: A,
        server_config: Arc<rustls::ServerConfig>,
        chunk_size: usize,
    ) -> Source<TlsIncomingConnection, StreamCompletion<TlsBinding>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        assert!(chunk_size > 0, "chunk size must be greater than zero");
        Source::lazy_future_source(move || {
            let target = addr.clone();
            let config = Arc::clone(&server_config);
            async move {
                let runtime = Handle::current();
                let listener = TcpListener::bind(target).await.map_err(io_error)?;
                let bound_addr = listener.local_addr().map_err(io_error)?;
                Ok(tls_bind_source(listener, config, bound_addr, runtime, chunk_size))
            }
        })
    }

    /// [`TokioTls::bind`] with the default chunk size (8192 bytes).
    #[must_use]
    pub fn bind_default<A>(
        addr: A,
        server_config: Arc<rustls::ServerConfig>,
    ) -> Source<TlsIncomingConnection, StreamCompletion<TlsBinding>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        TokioTls::bind(addr, server_config, DEFAULT_CHUNK_SIZE)
    }
}
/// Wraps an established TLS stream as a byte flow whose materialized value is `connection`.
///
/// NOTE(review): this joins the halves with `Flow::from_sink_and_source`, while
/// `TlsIncomingConnection::into_flow` uses the `_coupled` variant — confirm the difference is
/// intentional.
pub(crate) fn tls_flow_from_stream_with_execution<S>(
    stream: S,
    connection: TlsConnection,
    execution: async_carrier::ShardedTokioCarrierExecution,
    chunk_size: usize,
) -> Flow<Vec<u8>, Vec<u8>, TlsConnection>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    let (reads, writes) = single_use_tls_halves(stream, execution, chunk_size);
    let joined = Flow::from_sink_and_source(writes, reads);
    joined.map_materialized_value(move |_| connection)
}
/// Packages an accepted TLS stream as a [`TlsIncomingConnection`] with fresh single-use halves.
fn tls_incoming_connection<S>(
    stream: S,
    connection: TlsConnection,
    execution: async_carrier::ShardedTokioCarrierExecution,
    chunk_size: usize,
) -> TlsIncomingConnection
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    let halves = single_use_tls_halves(stream, execution, chunk_size);
    TlsIncomingConnection {
        connection,
        source: halves.0,
        sink: halves.1,
    }
}
/// Starts a carrier task for `stream` and returns its single-use byte source and sink.
///
/// Both halves share one carrier. The sink uses a batch size of 1, so every chunk is written
/// and flushed on its own.
fn single_use_tls_halves<S>(
    stream: S,
    execution: async_carrier::ShardedTokioCarrierExecution,
    chunk_size: usize,
) -> (TlsByteSource, TlsByteSink)
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    let (carrier, chunk_receiver) =
        start_tls_carrier(stream, execution, chunk_size, DEFAULT_RECEIVE_BUFFER);
    let sink_carrier = carrier.clone();
    let source =
        single_use_tls_source_from_carrier(carrier, chunk_receiver, DEFAULT_RECEIVE_BUFFER);
    let sink = single_use_tls_sink_from_carrier(sink_carrier, 1);
    (source, sink)
}
/// Builds the single-use byte source half of a TLS connection.
///
/// The chunk receiver is parked in a shared slot and claimed by the first materialization.
/// Later materializations fail with "TLS source already materialized". On open, the initial
/// demand is requested. If that request fails, a response already waiting in the channel is
/// kept and delivered first; otherwise the open fails.
fn single_use_tls_source_from_carrier(
    carrier: TlsCarrier,
    receiver: std_mpsc::Receiver<DemandResponse<Vec<u8>>>,
    receive_buffer: usize,
) -> TlsByteSource {
    let receiver_slot = Arc::new(Mutex::new(Some(receiver)));
    Source::unfold_resource(
        {
            let receiver_slot = Arc::clone(&receiver_slot);
            move || {
                let claimed = receiver_slot
                    .lock()
                    .expect("single-use TLS receiver poisoned")
                    .take();
                let Some(receiver) = claimed else {
                    return Err(StreamError::Failed("TLS source already materialized".into()));
                };
                let demand = DemandBatcher::new(receive_buffer);
                let pending = if let Err(error) = carrier.request_demand(demand.initial()) {
                    // The task may already have produced a final response (e.g. an error).
                    match receiver.try_recv() {
                        Ok(response) => Some(response),
                        Err(std_mpsc::TryRecvError::Empty) => return Err(error),
                        Err(std_mpsc::TryRecvError::Disconnected) => {
                            return Err(abrupt_termination());
                        }
                    }
                } else {
                    None
                };
                Ok(ReadResource {
                    receiver,
                    carrier: carrier.clone(),
                    demand,
                    pending,
                })
            }
        },
        read_next_chunk,
        close_read_resource,
    )
}
fn read_next_chunk(resource: &mut ReadResource) -> StreamResult<Option<Vec<u8>>> {
let response = match resource.pending.take() {
Some(response) => response,
None => resource.receiver.recv().map_err(|_| abrupt_termination())?,
};
match response {
DemandResponse::Item(chunk) => {
if let Some(demand) = resource.demand.record_consumed() {
let _ = resource.carrier.request_demand(demand);
}
Ok(Some(chunk))
}
DemandResponse::Complete => Ok(None),
DemandResponse::Error(error) => Err(error),
}
}
/// Close callback for the TLS source: asks the carrier to stop reading.
///
/// `ReadResource::drop` repeats the same best-effort request when `resource` goes out of scope
/// here. The duplicate is harmless: handling `CloseRead` only clears the read flag.
fn close_read_resource(resource: ReadResource) -> StreamResult<()> {
    let carrier = &resource.carrier;
    carrier.close_read();
    Ok(())
}
/// Spawns the carrier task for `stream` on `execution`. Returns the shared handle plus the
/// receiving end of the chunk channel.
///
/// The command channel is sized to the larger of `DEFAULT_COMMAND_BUFFER` and
/// `receive_buffer`. The chunk channel holds `receive_buffer + 1` responses (saturating).
fn start_tls_carrier<S>(
    stream: S,
    execution: async_carrier::ShardedTokioCarrierExecution,
    chunk_size: usize,
    receive_buffer: usize,
) -> (TlsCarrier, std_mpsc::Receiver<DemandResponse<Vec<u8>>>)
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    let capacity = receive_buffer.max(async_carrier::DEFAULT_COMMAND_BUFFER);
    let (commands, command_receiver) = async_carrier::command_channel(capacity, "TLS");
    let (error_sender, error_receiver) = std_mpsc::channel();
    let (chunk_sender, chunk_receiver) =
        std_mpsc::sync_channel(receive_buffer.saturating_add(1));
    let (read_half, write_half) = tokio::io::split(stream);
    let task = execution.handle().spawn(run_tls_carrier_task(
        read_half,
        write_half,
        chunk_size,
        chunk_sender,
        error_sender,
        commands.clone(),
        command_receiver,
    ));
    let inner = TlsCarrierInner {
        commands,
        send_errors: Mutex::new(error_receiver),
        task: Mutex::new(Some(task)),
        _execution: execution,
    };
    (TlsCarrier { inner: Arc::new(inner) }, chunk_receiver)
}
/// Event loop driving one TLS connection.
///
/// While the read side is open and downstream has granted demand (`requested > 0`), it races
/// the next command against the next read. Otherwise it only waits for commands. Read bytes
/// are re-chunked by `queue_tls_read_chunks` and pushed into the bounded `receive_sender`
/// channel. Read failures go through `report_tls_read_error`, which also posts them on
/// `send_error_sender`.
///
/// The task returns when any of these happens:
/// - both sides are closed;
/// - the command channel yields `None`;
/// - a command handler returns `false`;
/// - a fatal read or queue error occurs.
///
/// NOTE(review): `_command_keepalive` is never used. Holding a sender presumably keeps
/// `commands.recv()` from returning `None` while the task runs; shutdown instead happens via
/// `abort` in `TlsCarrierInner::drop`. Confirm.
///
/// NOTE(review): only full `chunk_size` chunks are emitted before EOF. A partial tail waits in
/// `pending_tail` until more bytes or EOF arrive, so request/response protocols whose messages
/// are smaller than `chunk_size` may stall. Confirm this is intended.
///
/// NOTE(review): on `TlsQueueOutcome::Full` the channel is full, so the `Error` item that
/// `report_tls_read_error` offers with `try_send` is likely dropped. The source then sees
/// `AbruptTermination` once this task exits.
async fn run_tls_carrier_task<R, W>(
    mut reader: R,
    mut writer: W,
    chunk_size: usize,
    receive_sender: std_mpsc::SyncSender<DemandResponse<Vec<u8>>>,
    send_error_sender: std_mpsc::Sender<StreamError>,
    _command_keepalive: AsyncCommandSender<TlsCarrierCommand>,
    mut commands: mpsc::Receiver<TlsCarrierCommand>,
) where
    R: AsyncRead + Unpin + Send + 'static,
    W: AsyncWrite + Unpin + Send + 'static,
{
    // Scratch buffer for a single `read` call.
    let mut buffer = vec![0_u8; chunk_size];
    // Bytes already read that do not yet fill a whole chunk.
    let mut pending_tail = Vec::with_capacity(chunk_size);
    // Chunks downstream has asked for but not yet been sent.
    let mut requested = 0_usize;
    let mut read_open = true;
    let mut write_open = true;
    loop {
        if !read_open && !write_open {
            return;
        }
        // Backpressure: only read while the read side is open and demand is outstanding.
        if read_open && requested > 0 {
            tokio::select! {
                // Poll commands first so demand/close/write requests are not starved by reads.
                biased;
                command = commands.recv() => {
                    let Some(command) = command else {
                        return;
                    };
                    if !handle_tls_carrier_command(
                        &mut writer,
                        command,
                        &send_error_sender,
                        &mut read_open,
                        &mut write_open,
                        &mut requested,
                    ).await {
                        return;
                    }
                }
                read = reader.read(&mut buffer) => {
                    match read {
                        // EOF: emit any partial tail as a final chunk, then signal completion.
                        Ok(0) => {
                            if !pending_tail.is_empty() {
                                match try_send_tls_read_response(
                                    &receive_sender,
                                    DemandResponse::Item(std::mem::take(&mut pending_tail)),
                                ) {
                                    TlsQueueOutcome::Queued => {
                                        requested = requested.saturating_sub(1);
                                    }
                                    // The source dropped its receiver: stop reading.
                                    TlsQueueOutcome::Closed => {
                                        read_open = false;
                                        continue;
                                    }
                                    TlsQueueOutcome::Full => {
                                        report_tls_read_error(
                                            &receive_sender,
                                            &send_error_sender,
                                            tls_receive_buffer_overflow(),
                                        );
                                        return;
                                    }
                                }
                            }
                            match try_send_tls_read_response(
                                &receive_sender,
                                DemandResponse::Complete,
                            ) {
                                TlsQueueOutcome::Queued | TlsQueueOutcome::Closed => {
                                    read_open = false;
                                }
                                TlsQueueOutcome::Full => {
                                    report_tls_read_error(
                                        &receive_sender,
                                        &send_error_sender,
                                        tls_receive_buffer_overflow(),
                                    );
                                    return;
                                }
                            }
                        }
                        Ok(read) => {
                            match queue_tls_read_chunks(
                                &receive_sender,
                                &send_error_sender,
                                chunk_size,
                                &mut pending_tail,
                                &buffer[..read],
                            ) {
                                TlsReadQueueResult::Queued(queued) => {
                                    requested = requested.saturating_sub(queued);
                                }
                                TlsReadQueueResult::Closed => {
                                    read_open = false;
                                }
                                // The error was already reported by `queue_tls_read_chunks`.
                                TlsReadQueueResult::Failed => {
                                    return;
                                }
                            }
                        }
                        // Interrupted reads are simply retried on the next iteration.
                        Err(error) if error.kind() == std::io::ErrorKind::Interrupted => {}
                        Err(error) => {
                            report_tls_read_error(
                                &receive_sender,
                                &send_error_sender,
                                io_error(error),
                            );
                            return;
                        }
                    }
                }
            }
        } else {
            // No demand (or read side closed): wait for the next command only.
            let Some(command) = commands.recv().await else {
                return;
            };
            if !handle_tls_carrier_command(
                &mut writer,
                command,
                &send_error_sender,
                &mut read_open,
                &mut write_open,
                &mut requested,
            )
            .await
            {
                return;
            }
        }
    }
}
/// Applies one command to the carrier state and performs any write it requests.
///
/// Returns `true` to keep the task running. After a write failure, or a write attempted on a
/// closed writer, it returns `*read_open`: the task continues only while the read side is
/// still open. Write errors are posted on `send_error_sender`, not returned.
///
/// NOTE(review): `CloseWrite` is not guarded by `write_open`. A second close, or a close after
/// a failed write, flushes and shuts down the writer again.
async fn handle_tls_carrier_command<W>(
    writer: &mut W,
    command: TlsCarrierCommand,
    send_error_sender: &std_mpsc::Sender<StreamError>,
    read_open: &mut bool,
    write_open: &mut bool,
    requested: &mut usize,
) -> bool
where
    W: AsyncWrite + Unpin,
{
    match command {
        TlsCarrierCommand::Demand(demand) => {
            *requested = requested.saturating_add(demand);
            true
        }
        TlsCarrierCommand::SendOne(chunk) => {
            if !*write_open {
                report_tls_write_error(
                    send_error_sender,
                    StreamError::Failed("TLS write side is closed".to_owned()),
                );
                return *read_open;
            }
            if write_one_tls_chunk(writer, send_error_sender, &chunk).await {
                true
            } else {
                *write_open = false;
                *read_open
            }
        }
        TlsCarrierCommand::SendBatch(chunks) => {
            if !*write_open {
                report_tls_write_error(
                    send_error_sender,
                    StreamError::Failed("TLS write side is closed".to_owned()),
                );
                return *read_open;
            }
            // Write every chunk, then flush once for the whole batch.
            for chunk in &chunks {
                if let Err(error) = writer.write_all(chunk).await.map_err(io_error) {
                    report_tls_write_error(send_error_sender, error);
                    *write_open = false;
                    return *read_open;
                }
            }
            if let Err(error) = writer.flush().await.map_err(io_error) {
                report_tls_write_error(send_error_sender, error);
                *write_open = false;
                return *read_open;
            }
            true
        }
        TlsCarrierCommand::CloseRead => {
            *read_open = false;
            true
        }
        TlsCarrierCommand::CloseWrite { ack } => {
            *write_open = false;
            let result = close_tls_writer(writer).await;
            match result {
                Ok(()) => {
                    // The closer may have stopped waiting; a failed ack is ignored.
                    let _ = ack.send(Ok(()));
                    true
                }
                Err(error) => {
                    // Reported on both channels: the error queue and the direct ack.
                    report_tls_write_error(send_error_sender, error.clone());
                    let _ = ack.send(Err(error));
                    *read_open
                }
            }
        }
    }
}
async fn write_one_tls_chunk<W>(
writer: &mut W,
send_error_sender: &std_mpsc::Sender<StreamError>,
chunk: &[u8],
) -> bool
where
W: AsyncWrite + Unpin,
{
if let Err(error) = writer.write_all(chunk).await.map_err(io_error) {
report_tls_write_error(send_error_sender, error);
return false;
}
if let Err(error) = writer.flush().await.map_err(io_error) {
report_tls_write_error(send_error_sender, error);
return false;
}
true
}
async fn close_tls_writer<W>(writer: &mut W) -> StreamResult<()>
where
W: AsyncWrite + Unpin,
{
writer.flush().await.map_err(io_error)?;
writer.shutdown().await.map_err(io_error)
}
/// Outcome of queueing the chunks produced by one read (see `queue_tls_read_chunks`).
enum TlsReadQueueResult {
    /// All complete chunks were queued; the count is subtracted from outstanding demand.
    Queued(usize),
    /// The source dropped its receiver.
    Closed,
    /// The receive channel was full; the error has already been reported.
    Failed,
}
/// Outcome of one non-blocking send on the receive channel.
enum TlsQueueOutcome {
    /// The response was enqueued.
    Queued,
    /// The bounded channel had no room.
    Full,
    /// The receiver is gone.
    Closed,
}
/// Re-chunks freshly read bytes into `chunk_size` items and queues them for the source.
///
/// A partial chunk carried over in `pending_tail` is topped up first, so byte order is kept.
/// Then every whole chunk in the remaining input is queued, and any leftover bytes become the
/// new tail. Queueing stops at the first closed or full channel; a full channel is reported as
/// a receive-buffer overflow.
fn queue_tls_read_chunks(
    sender: &std_mpsc::SyncSender<DemandResponse<Vec<u8>>>,
    send_error_sender: &std_mpsc::Sender<StreamError>,
    chunk_size: usize,
    pending_tail: &mut Vec<u8>,
    read_buffer: &[u8],
) -> TlsReadQueueResult {
    /// Queues one chunk. `None` means it was queued; `Some` carries the result to stop with.
    fn enqueue_chunk(
        sender: &std_mpsc::SyncSender<DemandResponse<Vec<u8>>>,
        send_error_sender: &std_mpsc::Sender<StreamError>,
        chunk: Vec<u8>,
    ) -> Option<TlsReadQueueResult> {
        match try_send_tls_read_response(sender, DemandResponse::Item(chunk)) {
            TlsQueueOutcome::Queued => None,
            TlsQueueOutcome::Closed => Some(TlsReadQueueResult::Closed),
            TlsQueueOutcome::Full => {
                report_tls_read_error(sender, send_error_sender, tls_receive_buffer_overflow());
                Some(TlsReadQueueResult::Failed)
            }
        }
    }

    let mut queued = 0_usize;
    let mut rest = read_buffer;

    if !pending_tail.is_empty() {
        let fill = (chunk_size - pending_tail.len()).min(rest.len());
        let (head, remainder) = rest.split_at(fill);
        pending_tail.extend_from_slice(head);
        rest = remainder;
        if pending_tail.len() == chunk_size {
            let completed = std::mem::take(pending_tail);
            if let Some(stop) = enqueue_chunk(sender, send_error_sender, completed) {
                return stop;
            }
            queued += 1;
        }
    }

    let mut whole_chunks = rest.chunks_exact(chunk_size);
    for chunk in whole_chunks.by_ref() {
        if let Some(stop) = enqueue_chunk(sender, send_error_sender, chunk.to_vec()) {
            return stop;
        }
        queued += 1;
    }
    pending_tail.extend_from_slice(whole_chunks.remainder());
    TlsReadQueueResult::Queued(queued)
}
/// Offers `item` to the bounded receive channel without blocking and classifies the result.
fn try_send_tls_read_response(
    sender: &std_mpsc::SyncSender<DemandResponse<Vec<u8>>>,
    item: DemandResponse<Vec<u8>>,
) -> TlsQueueOutcome {
    let Err(rejected) = sender.try_send(item) else {
        return TlsQueueOutcome::Queued;
    };
    match rejected {
        std_mpsc::TrySendError::Full(_) => TlsQueueOutcome::Full,
        std_mpsc::TrySendError::Disconnected(_) => TlsQueueOutcome::Closed,
    }
}
fn report_tls_read_error(
receive_sender: &std_mpsc::SyncSender<DemandResponse<Vec<u8>>>,
send_error_sender: &std_mpsc::Sender<StreamError>,
error: StreamError,
) {
let _ = send_error_sender.send(error.clone());
let _ = receive_sender.try_send(DemandResponse::Error(error));
}
/// Posts a write-side failure on the carrier's error channel.
///
/// Best effort: if the receiver (held by `TlsCarrierInner`) is gone, the error is dropped.
fn report_tls_write_error(send_error_sender: &std_mpsc::Sender<StreamError>, error: StreamError) {
    send_error_sender.send(error).ok();
}
/// Error raised when the bounded receive channel is full while the carrier has data to queue.
fn tls_receive_buffer_overflow() -> StreamError {
    let message = String::from("TLS receive buffer filled without downstream demand");
    StreamError::Failed(message)
}
/// Builds the single-use byte sink half of a TLS connection.
///
/// The carrier is parked in a shared slot and claimed by the first materialization. Later
/// materializations fail with "TLS sink already materialized". Each element is forwarded by
/// `send_tls_chunk`. On completion, `close_tls_send_resource` flushes any batch and closes the
/// write side.
fn single_use_tls_sink_from_carrier(carrier: TlsCarrier, batch_size: usize) -> TlsByteSink {
    let carrier_slot = Arc::new(Mutex::new(Some(carrier)));
    let writer_stage = Flow::<Vec<u8>, Vec<u8>>::identity().map_with_resource(
        {
            let carrier_slot = Arc::clone(&carrier_slot);
            move || {
                let claimed = carrier_slot
                    .lock()
                    .expect("single-use TLS carrier poisoned")
                    .take();
                match claimed {
                    Some(carrier) => Ok(SendResource {
                        carrier,
                        pending: Vec::with_capacity(batch_size),
                        batch_size,
                    }),
                    None => Err(StreamError::Failed("TLS sink already materialized".into())),
                }
            }
        },
        |resource, chunk| send_tls_chunk(resource, chunk).map(|()| NotUsed),
        close_tls_send_resource,
    );
    writer_stage.to_mat(Sink::ignore(), Keep::right)
}
/// Completion callback for the TLS sink: sends any buffered batch, then closes the write side.
/// It never emits a final element.
fn close_tls_send_resource(mut resource: SendResource) -> StreamResult<Option<NotUsed>> {
    flush_tls_send_resource(&mut resource)?;
    resource.carrier.close_write().map(|()| None)
}
/// Forwards one chunk to the carrier.
///
/// With `batch_size <= 1` the chunk is sent immediately. Otherwise it is buffered, and the
/// buffer is flushed as one batch once it holds `batch_size` chunks.
fn send_tls_chunk(resource: &mut SendResource, chunk: Vec<u8>) -> StreamResult<()> {
    if resource.batch_size > 1 {
        resource.pending.push(chunk);
        let batch_full = resource.pending.len() >= resource.batch_size;
        return if batch_full {
            flush_tls_send_resource(resource)
        } else {
            Ok(())
        };
    }
    resource.carrier.send_one(chunk)
}
/// Sends the buffered batch, if any.
///
/// With nothing buffered, this still surfaces any error the carrier has reported.
fn flush_tls_send_resource(resource: &mut SendResource) -> StreamResult<()> {
    if resource.pending.is_empty() {
        resource.carrier.check_send_error()
    } else {
        let batch = std::mem::take(&mut resource.pending);
        resource.carrier.send_items(batch)
    }
}
fn tls_bind_source(
listener: TcpListener,
server_config: Arc<rustls::ServerConfig>,
local_addr: SocketAddr,
handle: Handle,
chunk_size: usize,
) -> Source<TlsIncomingConnection, TlsBinding> {
let listener = Arc::new(Mutex::new(Some(listener)));
Source::unfold_resource(
{
let listener = Arc::clone(&listener);
let handle = handle.clone();
move || {
let listener = listener
.lock()
.expect("single-use TLS listener poisoned")
.take()
.ok_or_else(|| {
StreamError::Failed("TLS listener already materialized".into())
})?;
let (demand_sender, demand_receiver) = mpsc::channel(1);
let (cancel_sender, cancel_receiver) = watch::channel(false);
let task = handle.spawn(run_tls_bind_task(
listener,
Arc::clone(&server_config),
local_addr,
chunk_size,
handle.clone(),
demand_receiver,
cancel_receiver,
));
Ok(BindResource {
demands: demand_sender,
cancel: cancel_sender,
task,
})
}
},
|resource| {
let (reply_sender, reply_receiver) = std_mpsc::channel();
resource
.demands
.blocking_send(reply_sender)
.map_err(|_| abrupt_termination())?;
match reply_receiver.recv() {
Ok(DemandResponse::Item(connection)) => Ok(Some(connection)),
Ok(DemandResponse::Complete) => Ok(None),
Ok(DemandResponse::Error(error)) => Err(error),
Err(_) => Err(abrupt_termination()),
}
},
close_bind_resource,
)
.map_materialized_value(move |_| TlsBinding { local_addr })
}
/// Close callback for the bind source: signals cancellation and aborts the accept task.
///
/// `BindResource::drop` repeats both steps when `resource` goes out of scope here, which is
/// harmless.
fn close_bind_resource(resource: BindResource) -> StreamResult<()> {
    resource.cancel.send(true).ok();
    resource.task.abort();
    Ok(())
}
/// Accept loop behind `tls_bind_source`.
///
/// For each reply channel received on `demands`, it accepts one TCP connection, completes
/// the TLS handshake on a per-connection execution, and replies with the connection.
///
/// Exits when any of these happens:
/// - the demand channel closes;
/// - cancellation fires (including the cancel sender being dropped, which makes `changed()`
///   return an error);
/// - a non-transient accept error occurs;
/// - a handshake fails;
/// - the puller stops waiting for its reply.
///
/// NOTE(review): a single failed TLS handshake (e.g. a client presenting a bad certificate) is
/// sent as an error and ends the whole accept loop, failing the bind source. Confirm this is
/// intended rather than skipping that client.
async fn run_tls_bind_task(
    listener: TcpListener,
    server_config: Arc<rustls::ServerConfig>,
    local_addr: SocketAddr,
    chunk_size: usize,
    handle: Handle,
    mut demands: mpsc::Receiver<std_mpsc::Sender<DemandResponse<TlsIncomingConnection>>>,
    mut cancel: watch::Receiver<bool>,
) {
    let acceptor = TlsAcceptor::from(server_config);
    loop {
        // Wait for the next pull (a reply channel) or for cancellation.
        let reply = tokio::select! {
            demand = demands.recv() => match demand {
                Some(reply) => reply,
                None => return,
            },
            changed = cancel.changed() => {
                // Either a cancel signal or a dropped sender: stop in both cases.
                let _ = changed;
                return;
            }
        };
        // Accept a TCP connection, retrying transient errors immediately (no backoff).
        let (tcp, remote_addr) = loop {
            let accepted = tokio::select! {
                accepted = listener.accept() => accepted,
                changed = cancel.changed() => {
                    let _ = changed;
                    return;
                }
            };
            match accepted {
                Ok(accepted) => break accepted,
                Err(error) if is_transient_accept_error(&error) => continue,
                Err(error) => {
                    let _ = reply.send(DemandResponse::Error(io_error(error)));
                    return;
                }
            }
        };
        let connection = TlsConnection {
            // Falls back to the listener's address if the socket cannot report its own.
            local_addr: tcp.local_addr().unwrap_or(local_addr),
            remote_addr,
        };
        // The handshake and all later I/O for this connection run on its own execution.
        let execution = tls_connection_execution(handle.clone());
        let accepted = tokio::select! {
            accepted = accept_tls_on_execution(tcp, acceptor.clone(), &execution) => accepted,
            changed = cancel.changed() => {
                let _ = changed;
                return;
            }
        };
        match accepted {
            Ok(stream) => {
                let incoming = tls_incoming_connection(stream, connection, execution, chunk_size);
                // The puller stopped waiting: stop accepting.
                if reply.send(DemandResponse::Item(incoming)).is_err() {
                    return;
                }
            }
            Err(error) => {
                let _ = reply.send(DemandResponse::Error(error));
                return;
            }
        }
    }
}
fn is_transient_accept_error(error: &std::io::Error) -> bool {
matches!(
error.kind(),
std::io::ErrorKind::Interrupted
| std::io::ErrorKind::ConnectionAborted
| std::io::ErrorKind::ConnectionReset
) || error.raw_os_error().is_some_and(is_transient_accept_errno)
}
/// Raw-errno fallback for [`is_transient_accept_error`] on Linux: EINTR, ECONNABORTED and
/// ECONNRESET.
///
/// NOTE(review): these numbers match the asm-generic errno table (x86, ARM, ...). They differ
/// on some Linux architectures (e.g. MIPS, SPARC, Alpha), where the `ErrorKind` check in the
/// caller is what actually applies.
#[cfg(target_os = "linux")]
fn is_transient_accept_errno(code: i32) -> bool {
    const EINTR: i32 = 4;
    const ECONNABORTED: i32 = 103;
    const ECONNRESET: i32 = 104;
    matches!(code, EINTR | ECONNABORTED | ECONNRESET)
}
/// No raw-errno fallback outside Linux; only the `ErrorKind` classification applies.
#[cfg(not(target_os = "linux"))]
fn is_transient_accept_errno(_errno: i32) -> bool {
    false
}
pub(crate) fn tls_connection_execution(
fallback: Handle,
) -> async_carrier::ShardedTokioCarrierExecution {
async_carrier::sharded_tokio_carrier_execution(fallback, &ACTIVE_TLS_CONNECTIONS)
}
/// Connects to `addr`, performs the TLS client handshake for `server_name`, and wraps the
/// resulting stream as a byte flow.
///
/// The connect and the handshake run on the connection's own execution (see
/// `tls_connection_execution`). Any I/O or TLS failure is returned as a stream error.
pub(crate) async fn tls_client_connect<A>(
    addr: A,
    server_name: ServerName<'static>,
    client_config: Arc<rustls::ClientConfig>,
    fallback: Handle,
    chunk_size: usize,
) -> StreamResult<Flow<Vec<u8>, Vec<u8>, TlsConnection>>
where
    A: ToSocketAddrs + Send + 'static,
{
    let execution = tls_connection_execution(fallback);
    let handshake = execution.run(async move {
        let tcp = TcpStream::connect(addr).await.map_err(io_error)?;
        let local_addr = tcp.local_addr().map_err(io_error)?;
        let remote_addr = tcp.peer_addr().map_err(io_error)?;
        let connector = TlsConnector::from(client_config);
        let tls = connector
            .connect(server_name, tcp)
            .await
            .map_err(io_error)?;
        Ok((tls, TlsConnection { local_addr, remote_addr }))
    });
    let (tls, connection) = handshake.await?;
    Ok(tls_flow_from_stream_with_execution(
        tls, connection, execution, chunk_size,
    ))
}
/// Runs the server-side TLS handshake for `tcp` on `execution`.
///
/// When the execution is sharded, the socket is first detached from the current runtime with
/// `into_std`. It is re-registered with `from_std` inside `execution.run`, presumably so the
/// shard's runtime drives its I/O.
async fn accept_tls_on_execution(
    tcp: TcpStream,
    acceptor: TlsAcceptor,
    execution: &async_carrier::ShardedTokioCarrierExecution,
) -> StreamResult<tokio_rustls::server::TlsStream<TcpStream>> {
    /// The accepted socket: still registered with this runtime, or detached for a hand-off.
    enum Handoff {
        Registered(TcpStream),
        Detached(std::net::TcpStream),
    }
    let handoff = if execution.is_sharded() {
        Handoff::Detached(tcp.into_std().map_err(io_error)?)
    } else {
        Handoff::Registered(tcp)
    };
    execution
        .run(async move {
            let tcp = match handoff {
                Handoff::Registered(tcp) => tcp,
                Handoff::Detached(std_tcp) => TcpStream::from_std(std_tcp).map_err(io_error)?,
            };
            acceptor.accept(tcp).await.map_err(io_error)
        })
        .await
}