pub use tokio_rustls::rustls;
use datum::{Flow, Keep, NotUsed, Sink, Source, StreamCompletion, StreamError, StreamResult};
use std::net::SocketAddr;
use std::sync::{Arc, Mutex, mpsc as std_mpsc};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio::runtime::Handle;
use tokio::sync::{mpsc, watch};
use tokio::task::JoinHandle;
use tokio_rustls::rustls::pki_types::ServerName;
use tokio_rustls::{TlsAcceptor, TlsConnector};
/// Chunk size (8 KiB) used by [`TokioTls::outgoing_connection_default`] and
/// [`TokioTls::bind_default`].
const DEFAULT_CHUNK_SIZE: usize = 8 * 1024;
/// Source of raw bytes read from one TLS connection, emitted as owned `Vec<u8>` chunks.
pub type TlsByteSource = Source<Vec<u8>, NotUsed>;
/// Sink that writes `Vec<u8>` chunks to one TLS connection; its materialized value is a
/// `StreamCompletion<NotUsed>`.
pub type TlsByteSink = Sink<Vec<u8>, StreamCompletion<NotUsed>>;
/// Reply from a background tokio task to the synchronous stream stage that pulls from it.
enum DemandResponse<T> {
    /// The next element.
    Item(T),
    /// The producing task finished normally; no further elements follow.
    Complete,
    /// The producing task failed; the error is surfaced as the stream's failure.
    Error(StreamError),
}
/// Per-materialization state of a TLS read source: the channel fed by the spawned
/// `run_read_task`, plus what is needed to stop that task.
struct ReadResource {
    /// Chunks, end-of-stream, or an error produced by the read task.
    receiver: mpsc::Receiver<DemandResponse<Vec<u8>>>,
    /// Sending `true` (or dropping this sender) makes the read task's `changed()` wake up
    /// and return.
    cancel: watch::Sender<bool>,
    /// Handle of the spawned read task.
    task: JoinHandle<()>,
}
/// Signals cancellation and aborts the read task whenever the resource is dropped, so the
/// task cannot outlive it.
impl Drop for ReadResource {
    fn drop(&mut self) {
        // The send fails only when the task already dropped its receiver; nothing to do then.
        let _ = self.cancel.send(true);
        self.task.abort();
    }
}
/// Per-materialization state of the TLS bind source.
struct BindResource {
    /// Demand channel to `run_tls_bind_task`: each message is a std reply channel on which
    /// the task answers with the next accepted connection (or an error). A std channel is
    /// used because the pulling side waits on it synchronously with `recv()`.
    demands: mpsc::Sender<std_mpsc::Sender<DemandResponse<TlsIncomingConnection>>>,
    /// Sending `true` (or dropping this sender) makes the bind task return.
    cancel: watch::Sender<bool>,
    /// Handle of the spawned bind task.
    task: JoinHandle<()>,
}
/// Signals cancellation and aborts the bind task whenever the resource is dropped, so the
/// listener task cannot outlive it.
impl Drop for BindResource {
    fn drop(&mut self) {
        // The send fails only when the task already dropped its receiver; nothing to do then.
        let _ = self.cancel.send(true);
        self.task.abort();
    }
}
/// Turns an I/O error into a stream failure whose message is the error's `Display` text.
fn io_error(error: std::io::Error) -> StreamError {
    let message = error.to_string();
    StreamError::Failed(message)
}

/// Failure reported when a background task's channel closes before it sent a final
/// `Complete` or `Error` response.
fn abrupt_termination() -> StreamError {
    StreamError::AbruptTermination
}
/// Socket addresses of the TCP connection underneath an established TLS session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TlsConnection {
    /// Address of this side of the TCP connection.
    pub local_addr: SocketAddr,
    /// Address of the peer.
    pub remote_addr: SocketAddr,
}
impl TlsConnection {
    /// Returns the address of this side of the connection.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }
    /// Returns the peer's address.
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        self.remote_addr
    }
}
/// Information about a bound TLS listener, exposed as the bind source's materialized value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TlsBinding {
    /// Address the listener is actually bound to, as reported by `TcpListener::local_addr`.
    pub local_addr: SocketAddr,
}
impl TlsBinding {
    /// Returns the address the listener is bound to.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }
}
/// A server-side TLS connection, emitted by [`TokioTls::bind`] once its handshake succeeded.
///
/// It carries the connection's addresses together with a single-use byte source (reads)
/// and byte sink (writes) for the TLS stream.
pub struct TlsIncomingConnection {
    connection: TlsConnection,
    source: TlsByteSource,
    sink: TlsByteSink,
}

impl TlsIncomingConnection {
    /// Address of this side of the connection.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        self.connection.local_addr()
    }

    /// Address of the connected client.
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        self.connection.remote_addr()
    }

    /// Both socket addresses as one copyable value.
    #[must_use]
    pub fn connection(&self) -> TlsConnection {
        self.connection
    }

    /// Splits the connection into its read side (source) and write side (sink).
    #[must_use]
    pub fn into_parts(self) -> (TlsByteSource, TlsByteSink) {
        let Self { source, sink, .. } = self;
        (source, sink)
    }

    /// Combines the sink and source into one flow using
    /// `Flow::from_sink_and_source_coupled`; the materialized value is `NotUsed`.
    #[must_use]
    pub fn into_flow(self) -> Flow<Vec<u8>, Vec<u8>, NotUsed> {
        let Self { source, sink, .. } = self;
        Flow::from_sink_and_source_coupled(sink, source).map_materialized_value(|_| NotUsed)
    }
}
/// Field-less entry point for TLS-over-TCP streams backed by tokio and tokio-rustls; all of
/// its functionality is exposed as associated functions.
pub struct TokioTls;
/// Shorter name for [`TokioTls`].
pub type Tls = TokioTls;
impl TokioTls {
    /// Client-side flow: connects over TCP to `addr`, performs a TLS handshake for
    /// `server_name` with `client_config`, then writes incoming chunks to the TLS stream
    /// and emits the bytes read from it.
    ///
    /// Reads use a buffer of `chunk_size` bytes. The materialized value completes with the
    /// connection's [`TlsConnection`] addresses. Address resolution, connect, and handshake
    /// errors are converted with `io_error` and returned from the connection future.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is zero.
    #[must_use]
    pub fn outgoing_connection<A>(
        addr: A,
        server_name: ServerName<'static>,
        client_config: Arc<rustls::ClientConfig>,
        chunk_size: usize,
    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        assert!(chunk_size > 0, "chunk size must be greater than zero");
        Flow::future_flow(move || {
            // Fresh copies per invocation, so the factory closure can be called repeatedly.
            let (addr, server_name) = (addr.clone(), server_name.clone());
            let client_config = Arc::clone(&client_config);
            async move {
                let handle = Handle::current();
                let tcp = TcpStream::connect(addr).await.map_err(io_error)?;
                // Read the addresses now: the TCP stream is moved into the TLS layer below.
                let local_addr = tcp.local_addr().map_err(io_error)?;
                let remote_addr = tcp.peer_addr().map_err(io_error)?;
                let connector = TlsConnector::from(client_config);
                let tls = connector
                    .connect(server_name, tcp)
                    .await
                    .map_err(io_error)?;
                let connection = TlsConnection {
                    local_addr,
                    remote_addr,
                };
                Ok(tls_flow_from_stream(tls, connection, handle, chunk_size))
            }
        })
    }

    /// [`TokioTls::outgoing_connection`] with the default 8 KiB chunk size.
    #[must_use]
    pub fn outgoing_connection_default<A>(
        addr: A,
        server_name: ServerName<'static>,
        client_config: Arc<rustls::ClientConfig>,
    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        TokioTls::outgoing_connection(addr, server_name, client_config, DEFAULT_CHUNK_SIZE)
    }

    /// Server-side source: binds a TCP listener on `addr` (lazily, via
    /// `Source::lazy_future_source`) and emits [`TlsIncomingConnection`]s accepted on it,
    /// using `server_config` for the TLS server side.
    ///
    /// Each accepted connection reads with a buffer of `chunk_size` bytes. The materialized
    /// value completes with a [`TlsBinding`] holding the listener's actual local address.
    /// Bind errors are converted with `io_error`.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is zero.
    #[must_use]
    pub fn bind<A>(
        addr: A,
        server_config: Arc<rustls::ServerConfig>,
        chunk_size: usize,
    ) -> Source<TlsIncomingConnection, StreamCompletion<TlsBinding>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        assert!(chunk_size > 0, "chunk size must be greater than zero");
        Source::lazy_future_source(move || {
            let addr = addr.clone();
            let server_config = Arc::clone(&server_config);
            async move {
                let handle = Handle::current();
                let listener = TcpListener::bind(addr).await.map_err(io_error)?;
                let local_addr = listener.local_addr().map_err(io_error)?;
                let source =
                    tls_bind_source(listener, server_config, local_addr, handle, chunk_size);
                Ok(source)
            }
        })
    }

    /// [`TokioTls::bind`] with the default 8 KiB chunk size.
    #[must_use]
    pub fn bind_default<A>(
        addr: A,
        server_config: Arc<rustls::ServerConfig>,
    ) -> Source<TlsIncomingConnection, StreamCompletion<TlsBinding>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        TokioTls::bind(addr, server_config, DEFAULT_CHUNK_SIZE)
    }
}
/// Exposes an established TLS `stream` as one byte [`Flow`].
///
/// The stream is split into independent halves. Chunks fed into the flow go to the write
/// half; bytes from the read half (read with a `chunk_size` buffer) are emitted. The two
/// sides are joined with the uncoupled `Flow::from_sink_and_source`, and the materialized
/// value is replaced by `connection`.
pub(crate) fn tls_flow_from_stream<S>(
    stream: S,
    connection: TlsConnection,
    handle: Handle,
    chunk_size: usize,
) -> Flow<Vec<u8>, Vec<u8>, TlsConnection>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    let (read_half, write_half) = tokio::io::split(stream);
    let outbound = single_use_async_write_sink(write_half, handle.clone());
    let inbound = single_use_async_read_source(read_half, handle, chunk_size);
    let joined = Flow::from_sink_and_source(outbound, inbound);
    joined.map_materialized_value(move |_| connection)
}
/// Packages an accepted, handshaken TLS `stream` as a [`TlsIncomingConnection`], giving it
/// a single-use source over the read half and a single-use sink over the write half.
fn tls_incoming_connection<S>(
    stream: S,
    connection: TlsConnection,
    handle: Handle,
    chunk_size: usize,
) -> TlsIncomingConnection
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    let (read_half, write_half) = tokio::io::split(stream);
    let sink = single_use_async_write_sink(write_half, handle.clone());
    let source = single_use_async_read_source(read_half, handle, chunk_size);
    TlsIncomingConnection {
        connection,
        source,
        sink,
    }
}
fn single_use_async_read_source<R>(reader: R, handle: Handle, chunk_size: usize) -> TlsByteSource
where
R: AsyncRead + Unpin + Send + 'static,
{
let reader = Arc::new(Mutex::new(Some(reader)));
Source::unfold_resource(
{
let reader = Arc::clone(&reader);
move || {
let reader = reader
.lock()
.expect("single-use TLS reader poisoned")
.take()
.ok_or_else(|| StreamError::Failed("TLS reader already materialized".into()))?;
let (sender, receiver) = mpsc::channel(1);
let (cancel_sender, cancel_receiver) = watch::channel(false);
let task = handle.spawn(run_read_task(reader, chunk_size, sender, cancel_receiver));
Ok(ReadResource {
receiver,
cancel: cancel_sender,
task,
})
}
},
|resource| match resource.receiver.blocking_recv() {
Some(DemandResponse::Item(chunk)) => Ok(Some(chunk)),
Some(DemandResponse::Complete) => Ok(None),
Some(DemandResponse::Error(error)) => Err(error),
None => Err(abrupt_termination()),
},
close_read_resource,
)
}
/// Close callback of the read source: stops the background read task.
///
/// Dropping the resource runs `ReadResource`'s `Drop` impl, which sends the cancel signal
/// and then aborts the task — the same two steps, in the same order, as an explicit close.
fn close_read_resource(resource: ReadResource) -> StreamResult<()> {
    drop(resource);
    Ok(())
}
/// Background task that reads `reader` and forwards the bytes to `sender` in
/// `chunk_size`-byte items.
///
/// Each read fills at most one `chunk_size` buffer, and the bytes are re-chunked by
/// `send_read_chunks`. On EOF, any leftover partial chunk is sent first, followed by
/// `Complete`. On a read error, `Error` is sent. The task returns as soon as `cancel` changes
/// (or its sender is dropped), or when a send fails or is cancelled.
///
/// NOTE(review): `send_read_chunks` only emits full `chunk_size` chunks and keeps the rest in
/// `pending_tail` until more bytes arrive or EOF. A peer that sends a short message and then
/// waits for a reply on the still-open connection therefore never has that message
/// delivered downstream. Confirm that this fixed-size framing is intended, rather than
/// emitting each read as it arrives.
async fn run_read_task<R>(
    mut reader: R,
    chunk_size: usize,
    sender: mpsc::Sender<DemandResponse<Vec<u8>>>,
    mut cancel: watch::Receiver<bool>,
) where
    R: AsyncRead + Unpin + Send + 'static,
{
    let mut buffer = vec![0_u8; chunk_size];
    // Bytes from earlier reads that did not yet make up a whole chunk.
    let mut pending_tail = Vec::with_capacity(chunk_size);
    loop {
        // Race the read against cancellation so a stalled peer cannot pin the task.
        let read = tokio::select! {
            read = reader.read(&mut buffer) => read,
            changed = cancel.changed() => {
                let _ = changed;
                return;
            }
        };
        match read {
            // EOF: flush the partial chunk (if any), then signal completion.
            Ok(0) => {
                if !pending_tail.is_empty()
                    && !send_read_item(
                        &sender,
                        DemandResponse::Item(std::mem::take(&mut pending_tail)),
                        &mut cancel,
                    )
                    .await
                {
                    return;
                }
                let _ = send_read_item(&sender, DemandResponse::Complete, &mut cancel).await;
                return;
            }
            Ok(read) => {
                if !send_read_chunks(
                    &sender,
                    chunk_size,
                    &mut pending_tail,
                    &buffer[..read],
                    &mut cancel,
                )
                .await
                {
                    return;
                }
            }
            // Any bytes still in `pending_tail` are not emitted before the error.
            Err(error) => {
                let _ =
                    send_read_item(&sender, DemandResponse::Error(io_error(error)), &mut cancel)
                        .await;
                return;
            }
        }
    }
}
/// Re-chunks freshly read bytes into `chunk_size`-byte items and sends them downstream.
///
/// `pending_tail` holds bytes left over from earlier reads; this function keeps it shorter
/// than `chunk_size`. The function works in three steps:
/// 1. Top up `pending_tail` from `read_buffer` and send it once it is full.
/// 2. Send every complete `chunk_size` slice of the remaining input.
/// 3. Append the leftover bytes to `pending_tail` for the next call.
///
/// Returns `false` as soon as a send fails or is cancelled (see `send_read_item`), and
/// `true` otherwise.
async fn send_read_chunks(
    sender: &mpsc::Sender<DemandResponse<Vec<u8>>>,
    chunk_size: usize,
    pending_tail: &mut Vec<u8>,
    read_buffer: &[u8],
    cancel: &mut watch::Receiver<bool>,
) -> bool {
    let mut rest = read_buffer;

    if !pending_tail.is_empty() {
        let fill = rest.len().min(chunk_size - pending_tail.len());
        let (head, remaining) = rest.split_at(fill);
        pending_tail.extend_from_slice(head);
        rest = remaining;
        if pending_tail.len() == chunk_size {
            let full = std::mem::take(pending_tail);
            if !send_read_item(sender, DemandResponse::Item(full), cancel).await {
                return false;
            }
        }
    }

    let mut whole_chunks = rest.chunks_exact(chunk_size);
    for chunk in whole_chunks.by_ref() {
        if !send_read_item(sender, DemandResponse::Item(chunk.to_vec()), cancel).await {
            return false;
        }
    }
    pending_tail.extend_from_slice(whole_chunks.remainder());
    true
}
/// Hands `item` to `sender`, giving up if cancellation wins the race.
///
/// Returns `true` when the channel accepted the item. It returns `false` when the receiver
/// is gone, or when `cancel` changed or its sender was dropped before the send completed.
async fn send_read_item<T>(
    sender: &mpsc::Sender<DemandResponse<T>>,
    item: DemandResponse<T>,
    cancel: &mut watch::Receiver<bool>,
) -> bool
where
    T: Send + 'static,
{
    tokio::select! {
        delivered = sender.send(item) => delivered.is_ok(),
        _ = cancel.changed() => false,
    }
}
/// Wraps the write half of a TLS stream as a byte sink that can be materialized only once.
///
/// Each chunk is written with `write_all` and then flushed. This is done by blocking the
/// stage's thread on `handle`.
///
/// When the stream ends, the close callback flushes and shuts down the writer and returns
/// `Ok(None)`. The `Option` is presumably an optional final element; confirm against datum's
/// `map_with_resource`.
///
/// The per-chunk `()` outputs are discarded by `Sink::ignore()`, whose materialized value is
/// kept (`Keep::right`).
///
/// NOTE(review): tokio's `Handle::block_on` panics when called from inside an asynchronous
/// execution context. This presumably relies on datum running these callbacks on
/// non-runtime threads — confirm.
///
/// # Errors
///
/// A second materialization fails with `StreamError::Failed` ("TLS writer already
/// materialized"). Write, flush, and shutdown errors are converted with `io_error`.
fn single_use_async_write_sink<W>(writer: W, handle: Handle) -> TlsByteSink
where
    W: AsyncWrite + Unpin + Send + 'static,
{
    // The resource factory may be invoked more than once; the slot hands the writer out once.
    let writer = Arc::new(Mutex::new(Some(writer)));
    Flow::<Vec<u8>, Vec<u8>>::identity()
        .map_with_resource(
            {
                let writer = Arc::clone(&writer);
                move || {
                    writer
                        .lock()
                        .expect("single-use TLS writer poisoned")
                        .take()
                        .ok_or_else(|| {
                            StreamError::Failed("TLS writer already materialized".into())
                        })
                }
            },
            {
                let handle = handle.clone();
                move |writer, chunk| {
                    // Flush every chunk so bytes are pushed through the TLS layer immediately.
                    handle.block_on(async {
                        writer.write_all(&chunk).await.map_err(io_error)?;
                        writer.flush().await.map_err(io_error)
                    })?;
                    Ok(())
                }
            },
            move |mut writer| {
                handle.block_on(async {
                    writer.flush().await.map_err(io_error)?;
                    writer.shutdown().await.map_err(io_error)
                })?;
                Ok(None)
            },
        )
        .to_mat(Sink::ignore(), Keep::right)
}
/// Builds the source of incoming TLS connections for an already-bound `listener`.
///
/// Materializing it spawns `run_tls_bind_task` on `handle`. This is possible only once:
/// the listener is taken out of a shared slot.
///
/// Each pull sends a fresh std reply channel to that task over a capacity-1 demand channel,
/// then blocks the pulling thread until the task answers. The task only starts accepting
/// after it has received a demand.
///
/// If the demand channel or the reply channel is closed, the pull fails with
/// `abrupt_termination()`. The materialized value is replaced with a [`TlsBinding`] for
/// `local_addr`.
fn tls_bind_source(
    listener: TcpListener,
    server_config: Arc<rustls::ServerConfig>,
    local_addr: SocketAddr,
    handle: Handle,
    chunk_size: usize,
) -> Source<TlsIncomingConnection, TlsBinding> {
    let listener = Arc::new(Mutex::new(Some(listener)));
    Source::unfold_resource(
        {
            let listener = Arc::clone(&listener);
            let handle = handle.clone();
            move || {
                let listener = listener
                    .lock()
                    .expect("single-use TLS listener poisoned")
                    .take()
                    .ok_or_else(|| {
                        StreamError::Failed("TLS listener already materialized".into())
                    })?;
                let (demand_sender, demand_receiver) = mpsc::channel(1);
                let (cancel_sender, cancel_receiver) = watch::channel(false);
                let task = handle.spawn(run_tls_bind_task(
                    listener,
                    Arc::clone(&server_config),
                    local_addr,
                    chunk_size,
                    handle.clone(),
                    demand_receiver,
                    cancel_receiver,
                ));
                Ok(BindResource {
                    demands: demand_sender,
                    cancel: cancel_sender,
                    task,
                })
            }
        },
        |resource| {
            // One reply channel per pull; the task answers exactly once on it.
            let (reply_sender, reply_receiver) = std_mpsc::channel();
            resource
                .demands
                .blocking_send(reply_sender)
                .map_err(|_| abrupt_termination())?;
            match reply_receiver.recv() {
                Ok(DemandResponse::Item(connection)) => Ok(Some(connection)),
                Ok(DemandResponse::Complete) => Ok(None),
                Ok(DemandResponse::Error(error)) => Err(error),
                // The task dropped the reply sender without answering.
                Err(_) => Err(abrupt_termination()),
            }
        },
        close_bind_resource,
    )
    .map_materialized_value(move |_| TlsBinding { local_addr })
}
/// Close callback of the bind source: stops the background accept task.
///
/// Dropping the resource runs `BindResource`'s `Drop` impl, which sends the cancel signal
/// and then aborts the task — the same two steps, in the same order, as an explicit close.
fn close_bind_resource(resource: BindResource) -> StreamResult<()> {
    drop(resource);
    Ok(())
}
/// Background task behind `tls_bind_source`: serves one accepted TLS connection per demand.
///
/// Each message on `demands` is a reply channel. For every demand, the task accepts TCP
/// connections from `listener` until one completes the TLS handshake with `server_config`.
/// It then sends that connection back as `DemandResponse::Item`.
///
/// The task returns when:
/// - the demand channel closes, or `cancel` changes or its sender is dropped;
/// - the requester's reply receiver is gone;
/// - `listener.accept()` fails with an error that `is_transient_accept_error` does not
///   consider transient. That error is first reported to the waiting requester as
///   `DemandResponse::Error`.
///
/// A failed TLS handshake only concerns that one client (plain-TCP probe, bad certificate,
/// early disconnect). Its TCP connection is dropped and the task keeps accepting for the
/// same demand, just as transient accept errors are retried. One misbehaving peer therefore
/// cannot tear down the whole listener.
async fn run_tls_bind_task(
    listener: TcpListener,
    server_config: Arc<rustls::ServerConfig>,
    local_addr: SocketAddr,
    chunk_size: usize,
    handle: Handle,
    mut demands: mpsc::Receiver<std_mpsc::Sender<DemandResponse<TlsIncomingConnection>>>,
    mut cancel: watch::Receiver<bool>,
) {
    let acceptor = TlsAcceptor::from(server_config);
    loop {
        // Accept nothing until somebody actually asks for a connection.
        let reply = tokio::select! {
            demand = demands.recv() => match demand {
                Some(reply) => reply,
                None => return,
            },
            changed = cancel.changed() => {
                let _ = changed;
                return;
            }
        };

        // Keep accepting until one client finishes the handshake.
        let incoming = loop {
            let (tcp, remote_addr) = loop {
                let accepted = tokio::select! {
                    accepted = listener.accept() => accepted,
                    changed = cancel.changed() => {
                        let _ = changed;
                        return;
                    }
                };
                match accepted {
                    Ok(accepted) => break accepted,
                    Err(error) if is_transient_accept_error(&error) => continue,
                    Err(error) => {
                        // Listener-level failure: report it and stop serving.
                        let _ = reply.send(DemandResponse::Error(io_error(error)));
                        return;
                    }
                }
            };
            let connection = TlsConnection {
                local_addr: tcp.local_addr().unwrap_or(local_addr),
                remote_addr,
            };
            // NOTE(review): the handshake is awaited inline with no timeout. A peer that
            // connects but never finishes it stalls acceptance for everyone else; consider
            // a handshake timeout or concurrent handshakes.
            let handshake = tokio::select! {
                handshake = acceptor.accept(tcp) => handshake,
                changed = cancel.changed() => {
                    let _ = changed;
                    return;
                }
            };
            match handshake {
                Ok(stream) => {
                    break tls_incoming_connection(stream, connection, handle.clone(), chunk_size);
                }
                // Per-client failure: drop this connection and accept the next one.
                Err(_) => continue,
            }
        };

        if reply.send(DemandResponse::Item(incoming)).is_err() {
            return;
        }
    }
}
/// Reports whether an `accept()` failure concerns only the single pending connection, in
/// which case the listener should retry instead of failing.
///
/// Interrupted calls, and connections that were aborted or reset before being accepted,
/// are always transient. On Linux the raw errno is also checked against the pending-network
/// errors listed in `accept(2)` (see `is_transient_accept_errno`).
fn is_transient_accept_error(error: &std::io::Error) -> bool {
    matches!(
        error.kind(),
        std::io::ErrorKind::Interrupted
            | std::io::ErrorKind::ConnectionAborted
            | std::io::ErrorKind::ConnectionReset
    ) || error.raw_os_error().is_some_and(is_transient_accept_errno)
}

/// Linux errno values that `accept(2)` returns for a problem with the pending connection
/// rather than with the listening socket.
///
/// Besides `EINTR`, `ECONNABORTED`, and `ECONNRESET`, the Linux `accept(2)` manual page says
/// that already-pending network errors on the new socket are returned from `accept()`. For
/// TCP/IP it lists `ENETDOWN`, `EPROTO`, `ENOPROTOOPT`, `EHOSTDOWN`, `ENONET`,
/// `EHOSTUNREACH`, `EOPNOTSUPP`, and `ENETUNREACH`, and says they should be treated like
/// `EAGAIN` by retrying.
///
/// Resource exhaustion (`EMFILE`, `ENFILE`, `ENOBUFS`, `ENOMEM`) is deliberately excluded:
/// retrying those immediately would spin without making progress.
///
/// The numbers below are the asm-generic values. MIPS and SPARC use a different errno
/// numbering, so they get the conservative fallback instead.
#[cfg(all(
    target_os = "linux",
    not(any(
        target_arch = "mips",
        target_arch = "mips64",
        target_arch = "mips32r6",
        target_arch = "mips64r6",
        target_arch = "sparc",
        target_arch = "sparc64"
    ))
))]
fn is_transient_accept_errno(code: i32) -> bool {
    const EINTR: i32 = 4;
    const ENONET: i32 = 64;
    const EPROTO: i32 = 71;
    const ENOPROTOOPT: i32 = 92;
    const EOPNOTSUPP: i32 = 95;
    const ENETDOWN: i32 = 100;
    const ENETUNREACH: i32 = 101;
    const ECONNABORTED: i32 = 103;
    const ECONNRESET: i32 = 104;
    const EHOSTDOWN: i32 = 112;
    const EHOSTUNREACH: i32 = 113;
    matches!(
        code,
        EINTR
            | ENONET
            | EPROTO
            | ENOPROTOOPT
            | EOPNOTSUPP
            | ENETDOWN
            | ENETUNREACH
            | ECONNABORTED
            | ECONNRESET
            | EHOSTDOWN
            | EHOSTUNREACH
    )
}

/// Fallback for platforms without a known errno table: only the portable `ErrorKind`
/// checks in `is_transient_accept_error` apply.
#[cfg(not(all(
    target_os = "linux",
    not(any(
        target_arch = "mips",
        target_arch = "mips64",
        target_arch = "mips32r6",
        target_arch = "mips64r6",
        target_arch = "sparc",
        target_arch = "sparc64"
    ))
)))]
fn is_transient_accept_errno(_code: i32) -> bool {
    false
}