use crate::async_carrier::{self, AsyncCommandSender, DemandBatcher};
use datum::{Flow, Keep, NotUsed, Sink, Source, StreamCompletion, StreamError, StreamResult};
pub use quinn::{self, crypto, rustls};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::{Arc, Mutex, mpsc as std_mpsc};
use tokio::net::ToSocketAddrs;
use tokio::runtime::Handle;
use tokio::sync::{mpsc, watch};
use tokio::task::JoinHandle;
/// Chunk size, in bytes, used by `TokioQuic::bind_default` and `TokioQuic::connect_default`.
pub const DEFAULT_CHUNK_SIZE: usize = 8192;
/// Number of received chunks buffered between the carrier task and the read source; also the
/// batch size given to `DemandBatcher::new` for the read source.
const DEFAULT_RECEIVE_BUFFER: usize = 64;
/// Read half of a QUIC bidirectional stream, emitting byte chunks.
pub type QuicByteSource = Source<Vec<u8>, NotUsed>;
/// Write half of a QUIC bidirectional stream; materializes a `StreamCompletion`.
pub type QuicByteSink = Sink<Vec<u8>, StreamCompletion<NotUsed>>;
/// One reply sent from a background task to a synchronous consumer that asked for an element.
enum DemandResponse<T> {
/// The next element.
Item(T),
/// The producer is exhausted; consumers map this to end-of-stream.
Complete,
/// The producer failed with this error.
Error(StreamError),
}
/// Per-materialization state of the read source built by
/// `single_use_quic_read_source_from_carrier`.
struct ReadResource {
/// Chunks, completion, or errors queued by the carrier task.
receiver: std_mpsc::Receiver<DemandResponse<Vec<u8>>>,
/// Handle used to request more demand and to stop reading.
carrier: QuicCarrier,
/// Counts consumed items and decides when to ask the carrier for more.
demand: DemandBatcher,
/// A response picked up during setup (when the initial demand request failed) that must be
/// delivered before anything is read from `receiver`.
pending: Option<DemandResponse<Vec<u8>>>,
}
// Stop the carrier's read side even when the resource is dropped without going through
// `close_read_resource` (e.g. on cancellation).
impl Drop for ReadResource {
fn drop(&mut self) {
self.carrier.close_read();
}
}
/// Commands sent from the synchronous stream stages to the async carrier task.
enum QuicCarrierCommand {
/// Add this many items to the outstanding read demand.
Demand(usize),
/// Write one chunk to the send stream.
SendOne(Vec<u8>),
/// Write several chunks, in order, to the send stream.
SendBatch(Vec<Vec<u8>>),
/// Stop reading from the receive stream.
CloseRead,
/// Finish the send stream and report the outcome through `ack`.
CloseWrite {
ack: std_mpsc::Sender<StreamResult<()>>,
},
}
/// Cloneable handle to one carrier task; the read source and write sink of a single
/// bidirectional stream each hold a clone.
#[derive(Clone)]
struct QuicCarrier {
inner: Arc<QuicCarrierInner>,
}
/// Shared state behind `QuicCarrier`.
struct QuicCarrierInner {
/// Command channel into the carrier task.
commands: AsyncCommandSender<QuicCarrierCommand>,
/// Errors reported by the carrier task (write failures and, via `report_quic_read_error`,
/// read failures too); drained by `QuicCarrier::check_send_error`.
send_errors: Mutex<std_mpsc::Receiver<StreamError>>,
/// Handle of the spawned carrier task, kept so it can be aborted on drop.
task: Mutex<Option<JoinHandle<()>>>,
}
/// Aborts the carrier task once the last `QuicCarrier` handle is gone.
impl Drop for QuicCarrierInner {
    fn drop(&mut self) {
        // `&mut self` already guarantees exclusive access, so reach the handle with `get_mut`
        // instead of locking. A poisoned mutex is recovered rather than unwrapped: panicking
        // inside `drop` while already unwinding would abort the whole process.
        let task = self
            .task
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .take();
        if let Some(task) = task {
            task.abort();
        }
    }
}
impl QuicCarrier {
    /// Asks the carrier task to stop reading.
    ///
    /// Best-effort: if the command cannot be enqueued (for example because the task has already
    /// exited) it is dropped.
    fn close_read(&self) {
        let _ = self.inner.commands.try_send(QuicCarrierCommand::CloseRead);
    }

    /// Adds `demand` items to the carrier task's outstanding read demand.
    fn request_demand(&self, demand: usize) -> StreamResult<()> {
        self.inner
            .commands
            .send_or_blocking(QuicCarrierCommand::Demand(demand))
    }

    /// Queues `items` to be written in order.
    ///
    /// # Errors
    ///
    /// Fails with an error the carrier task already reported, or if the command cannot be sent.
    fn send_items(&self, items: Vec<Vec<u8>>) -> StreamResult<()> {
        self.check_send_error()?;
        self.inner
            .commands
            .send_or_blocking(QuicCarrierCommand::SendBatch(items))
            .map_err(|error| StreamError::Failed(format!("QUIC send batch failed: {error:?}")))
    }

    /// Queues a single chunk to be written.
    ///
    /// # Errors
    ///
    /// Fails with an error the carrier task already reported, or if the command cannot be sent.
    fn send_one(&self, item: Vec<u8>) -> StreamResult<()> {
        self.check_send_error()?;
        self.inner
            .commands
            .send_or_blocking(QuicCarrierCommand::SendOne(item))
            .map_err(|error| StreamError::Failed(format!("QUIC send failed: {error:?}")))
    }

    /// Finishes the send stream and waits for the carrier task to acknowledge it.
    ///
    /// # Errors
    ///
    /// Returns an error reported by the carrier task, the error from finishing the stream, or
    /// `AbruptTermination` if the task vanished before acknowledging.
    fn close_write(&self) -> StreamResult<()> {
        self.check_send_error()?;
        let (ack_sender, ack_receiver) = std_mpsc::channel();
        if self
            .inner
            .commands
            .send_or_blocking(QuicCarrierCommand::CloseWrite { ack: ack_sender })
            .is_err()
        {
            // The carrier task is gone. It may have reported a failure after the check above,
            // so surface that instead of unconditionally claiming a clean close.
            return self.check_send_error();
        }
        ack_receiver
            .recv()
            .unwrap_or_else(|_| Err(abrupt_termination()))?;
        self.check_send_error()
    }

    /// Returns, and consumes, the next error reported by the carrier task.
    ///
    /// Both an empty and a disconnected error channel mean "no pending error".
    fn check_send_error(&self) -> StreamResult<()> {
        let errors = self
            .inner
            .send_errors
            .lock()
            .expect("QUIC carrier send error receiver poisoned");
        errors.try_recv().map_or(Ok(()), Err)
    }
}
/// Per-materialization state of the write sink.
struct SendResource {
/// Handle used to send chunks to the carrier task.
carrier: QuicCarrier,
/// Chunks buffered until `batch_size` is reached.
pending: Vec<Vec<u8>>,
/// Number of chunks per batch; values of 0 or 1 send every chunk immediately.
batch_size: usize,
}
/// Read-side settings for the carrier task.
#[derive(Clone, Copy)]
struct QuicReadConfig {
/// Size of the read buffer and of each full emitted chunk.
chunk_size: usize,
/// Emit partial chunks as soon as bytes arrive instead of waiting for a full chunk.
emit_available: bool,
}
/// State of a materialized bind source.
struct BindResource {
/// Each pull sends a one-shot reply channel to the bind task over this channel.
demands: mpsc::Sender<std_mpsc::Sender<DemandResponse<QuicIncomingConnection>>>,
/// Sending `true` tells the bind task to stop.
cancel: watch::Sender<bool>,
/// The spawned bind task.
task: JoinHandle<()>,
}
// Cancel and abort the bind task whenever the resource is dropped.
impl Drop for BindResource {
fn drop(&mut self) {
let _ = self.cancel.send(true);
self.task.abort();
}
}
/// State of a materialized accept-bi source.
struct AcceptBiResource {
/// Each pull sends a one-shot reply channel to the accept task over this channel.
demands: mpsc::Sender<std_mpsc::Sender<DemandResponse<QuicBidirectionalStream>>>,
/// Sending `true` tells the accept task to stop.
cancel: watch::Sender<bool>,
/// The spawned accept task.
task: JoinHandle<()>,
}
// Cancel and abort the accept task whenever the resource is dropped.
impl Drop for AcceptBiResource {
fn drop(&mut self) {
let _ = self.cancel.send(true);
self.task.abort();
}
}
fn quic_error(error: impl std::fmt::Display) -> StreamError {
StreamError::Failed(error.to_string())
}
fn io_error(error: std::io::Error) -> StreamError {
StreamError::Failed(error.to_string())
}
fn abrupt_termination() -> StreamError {
StreamError::AbruptTermination
}
fn close_code() -> quinn::VarInt {
quinn::VarInt::from_u32(0)
}
/// Materialized value of a bind source: the address the server endpoint is bound to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct QuicBinding {
/// Local address reported by the endpoint after binding.
pub local_addr: SocketAddr,
}
impl QuicBinding {
/// Local address reported by the endpoint after binding.
#[must_use]
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
}
/// Identifies one QUIC bidirectional stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct QuicStream {
/// The stream id, taken from the stream's send half.
pub id: quinn::StreamId,
}
impl QuicStream {
/// The stream id.
#[must_use]
pub fn id(&self) -> quinn::StreamId {
self.id
}
}
/// An established QUIC connection, either dialed with `TokioQuic::connect` or accepted by a
/// bind source.
#[derive(Debug, Clone)]
pub struct QuicConnection {
/// The endpoint this connection belongs to (exposed via `quinn_endpoint`).
endpoint: quinn::Endpoint,
/// The underlying quinn connection.
connection: quinn::Connection,
/// Runtime handle used to spawn per-stream carrier and accept tasks.
handle: Handle,
/// Local address of this connection.
local_addr: SocketAddr,
/// Address of the remote peer.
remote_addr: SocketAddr,
/// Chunk size used by the `*_default` stream methods.
chunk_size: usize,
}
impl QuicConnection {
    /// Local socket address of this connection.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }

    /// Address of the remote peer.
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        self.remote_addr
    }

    /// Chunk size used by the `*_default` stream methods.
    #[must_use]
    pub fn chunk_size(&self) -> usize {
        self.chunk_size
    }

    /// The underlying quinn connection.
    #[must_use]
    pub fn quinn_connection(&self) -> &quinn::Connection {
        &self.connection
    }

    /// The quinn endpoint this connection belongs to.
    #[must_use]
    pub fn quinn_endpoint(&self) -> &quinn::Endpoint {
        &self.endpoint
    }

    /// Opens a new bidirectional stream when the flow is materialized.
    ///
    /// Bytes flowing in are written to the stream. Bytes read from it flow out in chunks of
    /// `chunk_size` bytes; the final chunk may be shorter.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is zero.
    #[must_use]
    pub fn open_bi(
        &self,
        chunk_size: usize,
    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<QuicStream>> {
        assert!(chunk_size > 0, "chunk size must be greater than zero");
        let connection = self.connection.clone();
        let handle = self.handle.clone();
        Flow::future_flow(move || {
            let connection = connection.clone();
            let handle = handle.clone();
            async move {
                let (send, recv) = connection.open_bi().await.map_err(quic_error)?;
                Ok(quic_bi_stream_from_halves(send, recv, handle, chunk_size, false).into_flow())
            }
        })
    }

    /// [`Self::open_bi`] using this connection's default chunk size.
    #[must_use]
    pub fn open_bi_default(&self) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<QuicStream>> {
        self.open_bi(self.chunk_size)
    }

    /// Source emitting a single newly opened bidirectional stream. The stream's read half emits
    /// full `chunk_size` chunks; only the final chunk may be shorter.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is zero.
    #[must_use]
    pub fn open_bi_stream(
        &self,
        chunk_size: usize,
    ) -> Source<QuicBidirectionalStream, StreamCompletion<QuicStream>> {
        self.open_bi_stream_with(chunk_size, false)
    }

    /// [`Self::open_bi_stream`] using this connection's default chunk size.
    #[must_use]
    pub fn open_bi_stream_default(
        &self,
    ) -> Source<QuicBidirectionalStream, StreamCompletion<QuicStream>> {
        self.open_bi_stream(self.chunk_size)
    }

    /// Like [`Self::open_bi_stream`], but the read half emits whatever bytes are available, up
    /// to `chunk_size`, instead of waiting for full chunks.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is zero.
    #[must_use]
    pub fn open_bi_stream_available(
        &self,
        chunk_size: usize,
    ) -> Source<QuicBidirectionalStream, StreamCompletion<QuicStream>> {
        self.open_bi_stream_with(chunk_size, true)
    }

    /// Source of bidirectional streams opened by the peer. Each pull accepts one stream. The
    /// materialized value is this connection.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is zero.
    #[must_use]
    pub fn accept_bi(&self, chunk_size: usize) -> Source<QuicBidirectionalStream, QuicConnection> {
        self.accept_bi_with(chunk_size, false)
    }

    /// [`Self::accept_bi`] using this connection's default chunk size.
    #[must_use]
    pub fn accept_bi_default(&self) -> Source<QuicBidirectionalStream, QuicConnection> {
        self.accept_bi(self.chunk_size)
    }

    /// Like [`Self::accept_bi`], but each accepted stream's read half emits whatever bytes are
    /// available, up to `chunk_size`, instead of waiting for full chunks.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is zero.
    #[must_use]
    pub fn accept_bi_available(
        &self,
        chunk_size: usize,
    ) -> Source<QuicBidirectionalStream, QuicConnection> {
        self.accept_bi_with(chunk_size, true)
    }

    /// Closes the connection with application error code 0 and the given `reason`.
    pub fn close(&self, reason: &[u8]) {
        self.connection.close(close_code(), reason);
    }

    /// Shared implementation of `open_bi_stream` / `open_bi_stream_available`. The two differ
    /// only in `emit_available`.
    #[must_use]
    fn open_bi_stream_with(
        &self,
        chunk_size: usize,
        emit_available: bool,
    ) -> Source<QuicBidirectionalStream, StreamCompletion<QuicStream>> {
        assert!(chunk_size > 0, "chunk size must be greater than zero");
        let connection = self.connection.clone();
        let handle = self.handle.clone();
        Source::lazy_future_source(move || {
            let connection = connection.clone();
            let handle = handle.clone();
            async move {
                let (send, recv) = connection.open_bi().await.map_err(quic_error)?;
                let stream =
                    quic_bi_stream_from_halves(send, recv, handle, chunk_size, emit_available);
                let metadata = stream.stream();
                // The stream can only be handed out once; the slot is emptied on first open.
                let stream = Arc::new(Mutex::new(Some(stream)));
                Ok(Source::unfold_resource(
                    {
                        let stream = Arc::clone(&stream);
                        move || {
                            stream
                                .lock()
                                .expect("single-use QUIC bidi stream poisoned")
                                .take()
                                .map(Some)
                                .ok_or_else(|| {
                                    StreamError::Failed(
                                        "QUIC bidi stream already materialized".into(),
                                    )
                                })
                        }
                    },
                    // Emit the stream once, then complete.
                    |stream| Ok(stream.take()),
                    |_stream| Ok(()),
                )
                .map_materialized_value(move |_| metadata))
            }
        })
    }

    /// Shared implementation of `accept_bi` / `accept_bi_available`. The two differ only in
    /// `emit_available`.
    #[must_use]
    fn accept_bi_with(
        &self,
        chunk_size: usize,
        emit_available: bool,
    ) -> Source<QuicBidirectionalStream, QuicConnection> {
        assert!(chunk_size > 0, "chunk size must be greater than zero");
        let connection = self.clone();
        Source::unfold_resource(
            {
                let connection = connection.clone();
                move || {
                    let handle = connection.handle.clone();
                    let (demand_sender, demand_receiver) = mpsc::channel(1);
                    let (cancel_sender, cancel_receiver) = watch::channel(false);
                    let task = handle.spawn(run_accept_bi_task(
                        connection.connection.clone(),
                        chunk_size,
                        emit_available,
                        handle.clone(),
                        demand_receiver,
                        cancel_receiver,
                    ));
                    Ok(AcceptBiResource {
                        demands: demand_sender,
                        cancel: cancel_sender,
                        task,
                    })
                }
            },
            receive_demand_response,
            close_accept_bi_resource,
        )
        .map_materialized_value(move |_| connection.clone())
    }
}
/// A connection accepted by a bind source; wraps a `QuicConnection` and delegates to it.
#[derive(Debug, Clone)]
pub struct QuicIncomingConnection {
connection: QuicConnection,
}
impl QuicIncomingConnection {
    /// Borrow of the wrapped connection that the delegating methods forward to.
    fn inner(&self) -> &QuicConnection {
        &self.connection
    }

    /// Local socket address of the accepted connection.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        self.inner().local_addr()
    }

    /// Address of the connecting peer.
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        self.inner().remote_addr()
    }

    /// A clone of the wrapped connection.
    #[must_use]
    pub fn connection(&self) -> QuicConnection {
        self.inner().clone()
    }

    /// Unwraps into the underlying connection.
    #[must_use]
    pub fn into_connection(self) -> QuicConnection {
        let Self { connection } = self;
        connection
    }

    /// See [`QuicConnection::open_bi`].
    #[must_use]
    pub fn open_bi(
        &self,
        chunk_size: usize,
    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<QuicStream>> {
        self.inner().open_bi(chunk_size)
    }

    /// See [`QuicConnection::open_bi_default`].
    #[must_use]
    pub fn open_bi_default(&self) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<QuicStream>> {
        self.inner().open_bi_default()
    }

    /// See [`QuicConnection::open_bi_stream`].
    #[must_use]
    pub fn open_bi_stream(
        &self,
        chunk_size: usize,
    ) -> Source<QuicBidirectionalStream, StreamCompletion<QuicStream>> {
        self.inner().open_bi_stream(chunk_size)
    }

    /// See [`QuicConnection::open_bi_stream_default`].
    #[must_use]
    pub fn open_bi_stream_default(
        &self,
    ) -> Source<QuicBidirectionalStream, StreamCompletion<QuicStream>> {
        self.inner().open_bi_stream_default()
    }

    /// See [`QuicConnection::open_bi_stream_available`].
    #[must_use]
    pub fn open_bi_stream_available(
        &self,
        chunk_size: usize,
    ) -> Source<QuicBidirectionalStream, StreamCompletion<QuicStream>> {
        self.inner().open_bi_stream_available(chunk_size)
    }

    /// See [`QuicConnection::accept_bi`].
    #[must_use]
    pub fn accept_bi(&self, chunk_size: usize) -> Source<QuicBidirectionalStream, QuicConnection> {
        self.inner().accept_bi(chunk_size)
    }

    /// See [`QuicConnection::accept_bi_default`].
    #[must_use]
    pub fn accept_bi_default(&self) -> Source<QuicBidirectionalStream, QuicConnection> {
        self.inner().accept_bi_default()
    }

    /// See [`QuicConnection::accept_bi_available`].
    #[must_use]
    pub fn accept_bi_available(
        &self,
        chunk_size: usize,
    ) -> Source<QuicBidirectionalStream, QuicConnection> {
        self.inner().accept_bi_available(chunk_size)
    }
}
/// An opened or accepted QUIC bidirectional stream that has not yet been split into stream
/// stages.
pub struct QuicBidirectionalStream {
/// Identity of the stream.
stream: QuicStream,
/// Send half.
send: quinn::SendStream,
/// Receive half.
recv: quinn::RecvStream,
/// Runtime handle on which the carrier task is spawned.
handle: Handle,
/// Read chunk size.
chunk_size: usize,
/// Emit partial chunks as soon as bytes arrive.
emit_available: bool,
}
impl QuicBidirectionalStream {
    /// Identity of this stream.
    #[must_use]
    pub fn stream(&self) -> QuicStream {
        self.stream
    }

    /// Splits the stream into a single-use byte source (read half) and byte sink (write half),
    /// both backed by one carrier task.
    #[must_use]
    pub fn into_parts(self) -> (QuicByteSource, QuicByteSink) {
        let (recv, send, handle, chunk_size, emit_available) = self.into_stream_ref_parts();
        single_use_quic_halves(send, recv, handle, chunk_size, emit_available)
    }

    /// Turns the stream into a flow whose input is written and whose output is read; the
    /// materialized value is this stream's identity.
    #[must_use]
    pub fn into_flow(self) -> Flow<Vec<u8>, Vec<u8>, QuicStream> {
        let identity = self.stream();
        let (source, sink) = self.into_parts();
        Flow::from_sink_and_source(sink, source).map_materialized_value(move |_| identity)
    }

    /// Raw parts for crate-internal consumers: `(recv, send, handle, chunk_size,
    /// emit_available)`.
    pub(crate) fn into_stream_ref_parts(
        self,
    ) -> (quinn::RecvStream, quinn::SendStream, Handle, usize, bool) {
        let Self {
            recv,
            send,
            handle,
            chunk_size,
            emit_available,
            ..
        } = self;
        (recv, send, handle, chunk_size, emit_available)
    }
}
/// Entry point for QUIC server (`bind`) and client (`connect`) sources running on the current
/// Tokio runtime.
pub struct TokioQuic;
/// Short alias for [`TokioQuic`].
pub type Quic = TokioQuic;
impl TokioQuic {
    /// Binds a QUIC server endpoint on `addr` when the source is materialized, and emits
    /// incoming connections as downstream demands them.
    ///
    /// Each emitted connection uses `chunk_size` as its default chunk size. The source
    /// materializes a completion carrying the bound [`QuicBinding`].
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is zero.
    #[must_use]
    pub fn bind<A>(
        addr: A,
        server_config: quinn::ServerConfig,
        chunk_size: usize,
    ) -> Source<QuicIncomingConnection, StreamCompletion<QuicBinding>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        assert!(chunk_size > 0, "chunk size must be greater than zero");
        Source::lazy_future_source(move || {
            let (addr, server_config) = (addr.clone(), server_config.clone());
            async move {
                let handle = Handle::current();
                let bind_addr = resolve_addr(addr).await?;
                let endpoint =
                    quinn::Endpoint::server(server_config, bind_addr).map_err(io_error)?;
                let bound_addr = endpoint.local_addr().map_err(io_error)?;
                Ok(quic_bind_source(endpoint, bound_addr, handle, chunk_size))
            }
        })
    }

    /// [`Self::bind`] with [`DEFAULT_CHUNK_SIZE`].
    #[must_use]
    pub fn bind_default<A>(
        addr: A,
        server_config: quinn::ServerConfig,
    ) -> Source<QuicIncomingConnection, StreamCompletion<QuicBinding>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        TokioQuic::bind(addr, server_config, DEFAULT_CHUNK_SIZE)
    }

    /// Connects to `addr` when the source is materialized, using `server_name` as the TLS
    /// server name, and emits the single resulting connection.
    ///
    /// The source also materializes a completion carrying that connection.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is zero.
    #[must_use]
    pub fn connect<A>(
        addr: A,
        server_name: impl Into<String>,
        client_config: quinn::ClientConfig,
        chunk_size: usize,
    ) -> Source<QuicConnection, StreamCompletion<QuicConnection>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        assert!(chunk_size > 0, "chunk size must be greater than zero");
        let server_name: String = server_name.into();
        Source::lazy_future_source(move || {
            let (addr, server_name, client_config) =
                (addr.clone(), server_name.clone(), client_config.clone());
            async move {
                let remote_addr = resolve_addr(addr).await?;
                let mut endpoint =
                    quinn::Endpoint::client(client_bind_addr(remote_addr)).map_err(io_error)?;
                endpoint.set_default_client_config(client_config);
                let connection = endpoint
                    .connect(remote_addr, &server_name)
                    .map_err(quic_error)?
                    .await
                    .map_err(quic_error)?;
                let endpoint_local_addr = endpoint.local_addr().map_err(io_error)?;
                let local_addr =
                    connection_local_addr(&connection, endpoint_local_addr, remote_addr.ip());
                let peer_addr = connection.remote_address();
                let connection = QuicConnection {
                    endpoint,
                    connection,
                    handle: Handle::current(),
                    local_addr,
                    remote_addr: peer_addr,
                    chunk_size,
                };
                let materialized = connection.clone();
                Ok(Source::single(connection)
                    .map_materialized_value(move |_| materialized.clone()))
            }
        })
    }

    /// [`Self::connect`] with [`DEFAULT_CHUNK_SIZE`].
    #[must_use]
    pub fn connect_default<A>(
        addr: A,
        server_name: impl Into<String>,
        client_config: quinn::ClientConfig,
    ) -> Source<QuicConnection, StreamCompletion<QuicConnection>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        TokioQuic::connect(addr, server_name, client_config, DEFAULT_CHUNK_SIZE)
    }
}
/// Resolves `addr` and returns the first socket address produced.
///
/// # Errors
///
/// Fails if the lookup fails or yields no addresses.
async fn resolve_addr<A>(addr: A) -> StreamResult<SocketAddr>
where
    A: ToSocketAddrs,
{
    let first = tokio::net::lookup_host(addr).await.map_err(io_error)?.next();
    first.ok_or_else(|| StreamError::Failed("address resolved to no socket addresses".into()))
}
/// Wildcard local address, with an OS-assigned port, in the same address family as
/// `remote_addr`.
fn client_bind_addr(remote_addr: SocketAddr) -> SocketAddr {
    let wildcard = match remote_addr {
        SocketAddr::V6(_) => IpAddr::V6(Ipv6Addr::UNSPECIFIED),
        SocketAddr::V4(_) => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
    };
    SocketAddr::new(wildcard, 0)
}
/// Best-known local address of `connection`.
///
/// Uses the connection's own local IP when quinn reports one. Otherwise, if the endpoint is
/// bound to a wildcard IP, `fallback_ip` is substituted. In all other cases the endpoint
/// address itself is used. The port always comes from `endpoint_addr`.
fn connection_local_addr(
    connection: &quinn::Connection,
    endpoint_addr: SocketAddr,
    fallback_ip: IpAddr,
) -> SocketAddr {
    let port = endpoint_addr.port();
    if let Some(ip) = connection.local_ip() {
        SocketAddr::new(ip, port)
    } else if endpoint_addr.ip().is_unspecified() {
        SocketAddr::new(fallback_ip, port)
    } else {
        endpoint_addr
    }
}
/// Bundles the two halves of a quinn bidirectional stream with the settings needed to turn them
/// into stream stages later. The stream identity is taken from the send half.
fn quic_bi_stream_from_halves(
    send: quinn::SendStream,
    recv: quinn::RecvStream,
    handle: Handle,
    chunk_size: usize,
    emit_available: bool,
) -> QuicBidirectionalStream {
    QuicBidirectionalStream {
        stream: QuicStream { id: send.id() },
        send,
        recv,
        handle,
        chunk_size,
        emit_available,
    }
}

/// Starts one carrier task for the stream. Returns a single-use byte source and byte sink that
/// share it; the sink sends every chunk immediately (batch size 1).
fn single_use_quic_halves(
    send: quinn::SendStream,
    recv: quinn::RecvStream,
    handle: Handle,
    chunk_size: usize,
    emit_available: bool,
) -> (QuicByteSource, QuicByteSink) {
    let receive_buffer = DEFAULT_RECEIVE_BUFFER;
    let (carrier, chunks) =
        start_quic_carrier(send, recv, handle, chunk_size, emit_available, receive_buffer);
    let source = single_use_quic_read_source_from_carrier(carrier.clone(), chunks, receive_buffer);
    let sink = single_use_quic_write_sink_from_carrier(carrier, 1);
    (source, sink)
}
/// Byte source reading from a carrier task's receive queue. It can be materialized only once;
/// later materializations fail.
fn single_use_quic_read_source_from_carrier(
    carrier: QuicCarrier,
    receiver: std_mpsc::Receiver<DemandResponse<Vec<u8>>>,
    receive_buffer: usize,
) -> QuicByteSource {
    let slot = Arc::new(Mutex::new(Some(receiver)));
    Source::unfold_resource(
        {
            let slot = Arc::clone(&slot);
            move || {
                let taken = slot
                    .lock()
                    .expect("single-use QUIC receiver poisoned")
                    .take();
                let Some(receiver) = taken else {
                    return Err(StreamError::Failed("QUIC source already materialized".into()));
                };
                let demand = DemandBatcher::new(receive_buffer);
                // If the initial demand cannot be delivered (presumably because the carrier
                // task has stopped), keep a response it may already have queued before giving
                // up.
                let pending = if let Err(error) = carrier.request_demand(demand.initial()) {
                    match receiver.try_recv() {
                        Ok(response) => Some(response),
                        Err(std_mpsc::TryRecvError::Empty) => return Err(error),
                        Err(std_mpsc::TryRecvError::Disconnected) => {
                            return Err(abrupt_termination());
                        }
                    }
                } else {
                    None
                };
                Ok(ReadResource {
                    receiver,
                    carrier: carrier.clone(),
                    demand,
                    pending,
                })
            }
        },
        read_next_quic_chunk,
        close_read_resource,
    )
}
/// Produces the next chunk for the read source. Returns `Ok(None)` at end of stream.
///
/// A response saved during setup is delivered first. After that, this blocks on the carrier's
/// queue.
fn read_next_quic_chunk(resource: &mut ReadResource) -> StreamResult<Option<Vec<u8>>> {
    let response = if let Some(saved) = resource.pending.take() {
        saved
    } else {
        resource.receiver.recv().map_err(|_| abrupt_termination())?
    };
    match response {
        DemandResponse::Complete => Ok(None),
        DemandResponse::Error(error) => Err(error),
        DemandResponse::Item(chunk) => {
            // Replenish demand in batches. Deliberately best-effort: a failure here is ignored.
            if let Some(more) = resource.demand.record_consumed() {
                let _ = resource.carrier.request_demand(more);
            }
            Ok(Some(chunk))
        }
    }
}

/// Closes the read source by telling the carrier to stop reading.
fn close_read_resource(resource: ReadResource) -> StreamResult<()> {
    let carrier = &resource.carrier;
    carrier.close_read();
    Ok(())
}
/// Spawns the carrier task that owns both halves of a QUIC stream.
///
/// Returns the command handle plus the receiving end of the bounded queue of read responses.
fn start_quic_carrier(
    send: quinn::SendStream,
    recv: quinn::RecvStream,
    handle: Handle,
    chunk_size: usize,
    emit_available: bool,
    receive_buffer: usize,
) -> (QuicCarrier, std_mpsc::Receiver<DemandResponse<Vec<u8>>>) {
    // At least the default command capacity, grown to `receive_buffer` when that is larger.
    let capacity = async_carrier::DEFAULT_COMMAND_BUFFER.max(receive_buffer);
    let (commands, command_receiver) = async_carrier::command_channel(capacity, "QUIC");
    let (error_tx, error_rx) = std_mpsc::channel();
    // One slot beyond `receive_buffer`, presumably headroom for a read that queues slightly
    // more than was requested.
    let (chunk_tx, chunk_rx) = std_mpsc::sync_channel(receive_buffer.saturating_add(1));
    let task = handle.spawn(run_quic_carrier_task(
        send,
        recv,
        QuicReadConfig {
            chunk_size,
            emit_available,
        },
        chunk_tx,
        error_tx,
        commands.clone(),
        command_receiver,
    ));
    let inner = QuicCarrierInner {
        commands,
        send_errors: Mutex::new(error_rx),
        task: Mutex::new(Some(task)),
    };
    (
        QuicCarrier {
            inner: Arc::new(inner),
        },
        chunk_rx,
    )
}
/// Background task that owns both halves of one QUIC bidirectional stream.
///
/// It applies `QuicCarrierCommand`s sent by the synchronous stream stages. While the read side
/// is open and there is outstanding demand (`requested > 0`), it also reads from `recv` and
/// queues chunks of up to `read_config.chunk_size` bytes into `receive_sender`. Partial chunks
/// are emitted immediately when `read_config.emit_available` is set. The task returns when both
/// sides are closed, when the command handler signals stop, or on a fatal read or queue error.
///
/// `_command_keepalive` is a sender clone held by the task itself. Presumably this means
/// `commands.recv()` does not end just because every external handle was dropped; shutdown in
/// that case relies on `QuicCarrierInner`'s `Drop` aborting the task.
async fn run_quic_carrier_task(
mut send: quinn::SendStream,
mut recv: quinn::RecvStream,
read_config: QuicReadConfig,
receive_sender: std_mpsc::SyncSender<DemandResponse<Vec<u8>>>,
send_error_sender: std_mpsc::Sender<StreamError>,
_command_keepalive: AsyncCommandSender<QuicCarrierCommand>,
mut commands: mpsc::Receiver<QuicCarrierCommand>,
) {
let mut buffer = vec![0_u8; read_config.chunk_size];
// Bytes read but not yet emitted because they do not fill a whole chunk.
let mut pending_tail = Vec::with_capacity(read_config.chunk_size);
// Items downstream has asked for but not yet received.
let mut requested = 0_usize;
let mut read_open = true;
let mut write_open = true;
loop {
if !read_open && !write_open {
return;
}
// Poll the receive stream only when there is demand; otherwise just wait for commands
// (this is the backpressure point).
if read_open && requested > 0 {
tokio::select! {
// Commands (demand, writes, closes) take priority over reading.
biased;
command = commands.recv() => {
let Some(command) = command else {
return;
};
if !handle_quic_carrier_command(
&mut send,
command,
&send_error_sender,
&mut read_open,
&mut write_open,
&mut requested,
).await {
return;
}
}
read = recv.read(&mut buffer) => {
match read {
Ok(Some(read)) => {
match queue_quic_read_chunks(
&receive_sender,
&send_error_sender,
read_config.chunk_size,
&mut pending_tail,
&buffer[..read],
read_config.emit_available,
) {
QuicReadQueueResult::Queued(queued) => {
requested = requested.saturating_sub(queued);
}
// The read source hung up: stop reading, keep serving writes.
QuicReadQueueResult::Closed => {
read_open = false;
}
QuicReadQueueResult::Failed => {
return;
}
}
}
// End of stream: flush any partial chunk, then signal completion.
Ok(None) => {
if !pending_tail.is_empty() {
match try_send_quic_read_response(
&receive_sender,
DemandResponse::Item(std::mem::take(&mut pending_tail)),
) {
QuicQueueOutcome::Queued => {
requested = requested.saturating_sub(1);
}
QuicQueueOutcome::Closed => {
read_open = false;
continue;
}
QuicQueueOutcome::Full => {
report_quic_read_error(
&receive_sender,
&send_error_sender,
quic_receive_buffer_overflow(),
);
return;
}
}
}
match try_send_quic_read_response(
&receive_sender,
DemandResponse::Complete,
) {
QuicQueueOutcome::Queued | QuicQueueOutcome::Closed => {
read_open = false;
}
QuicQueueOutcome::Full => {
report_quic_read_error(
&receive_sender,
&send_error_sender,
quic_receive_buffer_overflow(),
);
return;
}
}
}
// Read failures are reported to both the source and the sink.
Err(error) => {
report_quic_read_error(
&receive_sender,
&send_error_sender,
quic_error(error),
);
return;
}
}
}
}
} else {
let Some(command) = commands.recv().await else {
return;
};
if !handle_quic_carrier_command(
&mut send,
command,
&send_error_sender,
&mut read_open,
&mut write_open,
&mut requested,
)
.await
{
return;
}
}
}
}
/// Applies one command to the carrier state. Returns `false` when the carrier task should stop.
///
/// Write failures are reported through `send_error_sender` and close the write side. The task
/// then keeps running only while the read side is still open.
async fn handle_quic_carrier_command(
    send: &mut quinn::SendStream,
    command: QuicCarrierCommand,
    send_error_sender: &std_mpsc::Sender<StreamError>,
    read_open: &mut bool,
    write_open: &mut bool,
    requested: &mut usize,
) -> bool {
    match command {
        QuicCarrierCommand::Demand(extra) => {
            *requested = requested.saturating_add(extra);
            true
        }
        QuicCarrierCommand::SendOne(chunk) => {
            write_quic_chunks(
                send,
                send_error_sender,
                *read_open,
                write_open,
                std::slice::from_ref(&chunk),
            )
            .await
        }
        QuicCarrierCommand::SendBatch(chunks) => {
            write_quic_chunks(send, send_error_sender, *read_open, write_open, &chunks).await
        }
        QuicCarrierCommand::CloseRead => {
            *read_open = false;
            true
        }
        QuicCarrierCommand::CloseWrite { ack } => {
            *write_open = false;
            match close_quic_writer(send).await {
                Ok(()) => {
                    let _ = ack.send(Ok(()));
                    true
                }
                Err(error) => {
                    report_quic_write_error(send_error_sender, error.clone());
                    let _ = ack.send(Err(error));
                    *read_open
                }
            }
        }
    }
}

/// Writes `chunks` in order while the write side is open, closing it on the first failure.
///
/// Returns whether the carrier task should keep running: `true` after a full write, otherwise
/// `read_open`.
async fn write_quic_chunks(
    send: &mut quinn::SendStream,
    send_error_sender: &std_mpsc::Sender<StreamError>,
    read_open: bool,
    write_open: &mut bool,
    chunks: &[Vec<u8>],
) -> bool {
    if !*write_open {
        report_quic_write_error(
            send_error_sender,
            StreamError::Failed("QUIC write side is closed".to_owned()),
        );
        return read_open;
    }
    for chunk in chunks {
        if !write_one_quic_chunk(send, send_error_sender, chunk).await {
            *write_open = false;
            return read_open;
        }
    }
    true
}
async fn write_one_quic_chunk(
send: &mut quinn::SendStream,
send_error_sender: &std_mpsc::Sender<StreamError>,
chunk: &[u8],
) -> bool {
if let Err(error) = send.write_all(chunk).await.map_err(quic_error) {
report_quic_write_error(send_error_sender, error);
return false;
}
true
}
async fn close_quic_writer(send: &mut quinn::SendStream) -> StreamResult<()> {
send.write_all(&[]).await.map_err(quic_error)?;
send.finish().map_err(quic_error)
}
/// Outcome of queueing the chunks produced by one read.
enum QuicReadQueueResult {
/// This many items were queued.
Queued(usize),
/// The read source has gone away; the read side should stop.
Closed,
/// The queue overflowed; the error was already reported and the task should stop.
Failed,
}
/// Outcome of a single non-blocking enqueue attempt.
enum QuicQueueOutcome {
/// The item was queued.
Queued,
/// The bounded queue had no free slot.
Full,
/// The receiving side was dropped.
Closed,
}
/// Splits freshly read bytes into chunks of `chunk_size` and queues them without blocking.
///
/// Bytes left over from earlier reads (`pending_tail`) are completed first. Any trailing bytes
/// that do not fill a whole chunk stay in `pending_tail`, unless `emit_available` is set, in
/// which case they are emitted immediately. A full queue is reported as an overflow error.
fn queue_quic_read_chunks(
    sender: &std_mpsc::SyncSender<DemandResponse<Vec<u8>>>,
    send_error_sender: &std_mpsc::Sender<StreamError>,
    chunk_size: usize,
    pending_tail: &mut Vec<u8>,
    read_buffer: &[u8],
    emit_available: bool,
) -> QuicReadQueueResult {
    // Queues one chunk; `Err` carries the terminal result to hand back immediately.
    let enqueue = |chunk: Vec<u8>| -> Result<(), QuicReadQueueResult> {
        match try_send_quic_read_response(sender, DemandResponse::Item(chunk)) {
            QuicQueueOutcome::Queued => Ok(()),
            QuicQueueOutcome::Closed => Err(QuicReadQueueResult::Closed),
            QuicQueueOutcome::Full => {
                report_quic_read_error(sender, send_error_sender, quic_receive_buffer_overflow());
                Err(QuicReadQueueResult::Failed)
            }
        }
    };
    let mut fill = || -> Result<usize, QuicReadQueueResult> {
        let mut queued = 0_usize;
        let mut rest = read_buffer;
        if !pending_tail.is_empty() {
            let wanted = (chunk_size - pending_tail.len()).min(rest.len());
            let (head, remainder) = rest.split_at(wanted);
            pending_tail.extend_from_slice(head);
            rest = remainder;
            if pending_tail.len() == chunk_size {
                enqueue(std::mem::take(pending_tail))?;
                queued += 1;
            }
        }
        let mut whole_chunks = rest.chunks_exact(chunk_size);
        for chunk in &mut whole_chunks {
            enqueue(chunk.to_vec())?;
            queued += 1;
        }
        pending_tail.extend_from_slice(whole_chunks.remainder());
        if emit_available && !pending_tail.is_empty() {
            enqueue(std::mem::take(pending_tail))?;
            queued += 1;
        }
        Ok(queued)
    };
    match fill() {
        Ok(queued) => QuicReadQueueResult::Queued(queued),
        Err(terminal) => terminal,
    }
}
fn try_send_quic_read_response(
sender: &std_mpsc::SyncSender<DemandResponse<Vec<u8>>>,
item: DemandResponse<Vec<u8>>,
) -> QuicQueueOutcome {
match sender.try_send(item) {
Ok(()) => QuicQueueOutcome::Queued,
Err(std_mpsc::TrySendError::Full(_)) => QuicQueueOutcome::Full,
Err(std_mpsc::TrySendError::Disconnected(_)) => QuicQueueOutcome::Closed,
}
}
fn report_quic_read_error(
receive_sender: &std_mpsc::SyncSender<DemandResponse<Vec<u8>>>,
send_error_sender: &std_mpsc::Sender<StreamError>,
error: StreamError,
) {
let _ = send_error_sender.send(error.clone());
let _ = receive_sender.try_send(DemandResponse::Error(error));
}
fn report_quic_write_error(send_error_sender: &std_mpsc::Sender<StreamError>, error: StreamError) {
let _ = send_error_sender.send(error);
}
fn quic_receive_buffer_overflow() -> StreamError {
StreamError::Failed("QUIC receive buffer filled without downstream demand".to_owned())
}
/// Byte sink that forwards chunks to the carrier, `batch_size` at a time. It can be materialized
/// only once; completing the sink flushes any buffered chunks and finishes the send stream.
fn single_use_quic_write_sink_from_carrier(
    carrier: QuicCarrier,
    batch_size: usize,
) -> QuicByteSink {
    let slot = Arc::new(Mutex::new(Some(carrier)));
    Flow::<Vec<u8>, Vec<u8>>::identity()
        .map_with_resource(
            {
                let slot = Arc::clone(&slot);
                move || {
                    let taken = slot
                        .lock()
                        .expect("single-use QUIC carrier poisoned")
                        .take();
                    match taken {
                        Some(carrier) => Ok(SendResource {
                            carrier,
                            pending: Vec::with_capacity(batch_size),
                            batch_size,
                        }),
                        None => Err(StreamError::Failed("QUIC sink already materialized".into())),
                    }
                }
            },
            |resource, chunk| send_quic_chunk(resource, chunk).map(|()| NotUsed),
            close_quic_send_resource,
        )
        .to_mat(Sink::ignore(), Keep::right)
}
/// Completes the sink: flushes buffered chunks, then finishes the send stream.
fn close_quic_send_resource(mut resource: SendResource) -> StreamResult<Option<NotUsed>> {
    flush_quic_send_resource(&mut resource).and_then(|()| resource.carrier.close_write())?;
    Ok(None)
}

/// Sends `chunk` directly when batching is disabled (`batch_size` of 0 or 1). Otherwise it is
/// buffered, and the batch is flushed once full.
fn send_quic_chunk(resource: &mut SendResource, chunk: Vec<u8>) -> StreamResult<()> {
    match resource.batch_size {
        0 | 1 => resource.carrier.send_one(chunk),
        limit => {
            resource.pending.push(chunk);
            if resource.pending.len() < limit {
                Ok(())
            } else {
                flush_quic_send_resource(resource)
            }
        }
    }
}

/// Sends buffered chunks as one batch. With nothing buffered, it only surfaces a pending
/// carrier error.
fn flush_quic_send_resource(resource: &mut SendResource) -> StreamResult<()> {
    if !resource.pending.is_empty() {
        let batch = std::mem::take(&mut resource.pending);
        return resource.carrier.send_items(batch);
    }
    resource.carrier.check_send_error()
}
/// Source of incoming connections on an already-bound endpoint. It can be materialized only
/// once. Each pull is served by a spawned bind task, and the materialized value is the binding.
fn quic_bind_source(
    endpoint: quinn::Endpoint,
    local_addr: SocketAddr,
    handle: Handle,
    chunk_size: usize,
) -> Source<QuicIncomingConnection, QuicBinding> {
    let slot = Arc::new(Mutex::new(Some(endpoint)));
    Source::unfold_resource(
        {
            let slot = Arc::clone(&slot);
            let handle = handle.clone();
            move || {
                let taken = slot
                    .lock()
                    .expect("single-use QUIC endpoint poisoned")
                    .take();
                let Some(endpoint) = taken else {
                    return Err(StreamError::Failed(
                        "QUIC endpoint already materialized".into(),
                    ));
                };
                let (demands, demand_receiver) = mpsc::channel(1);
                let (cancel, cancel_receiver) = watch::channel(false);
                let task = handle.spawn(run_quic_bind_task(
                    endpoint,
                    local_addr,
                    chunk_size,
                    handle.clone(),
                    demand_receiver,
                    cancel_receiver,
                ));
                Ok(BindResource {
                    demands,
                    cancel,
                    task,
                })
            }
        },
        receive_demand_response,
        close_bind_resource,
    )
    .map_materialized_value(move |_| QuicBinding { local_addr })
}
fn receive_demand_response<T>(resource: &mut impl DemandResource<T>) -> StreamResult<Option<T>>
where
T: Send + 'static,
{
let (reply_sender, reply_receiver) = std_mpsc::channel();
resource
.demands()
.blocking_send(reply_sender)
.map_err(|_| abrupt_termination())?;
match reply_receiver.recv() {
Ok(DemandResponse::Item(item)) => Ok(Some(item)),
Ok(DemandResponse::Complete) => Ok(None),
Ok(DemandResponse::Error(error)) => Err(error),
Err(_) => Err(abrupt_termination()),
}
}
/// A resource whose background task serves one element per demand request. This lets
/// `receive_demand_response` drive both the bind source and the accept-bi source.
trait DemandResource<T>
where
T: Send + 'static,
{
/// Channel on which each pull sends the reply sender for one `DemandResponse`.
fn demands(&self) -> &mpsc::Sender<std_mpsc::Sender<DemandResponse<T>>>;
}
impl DemandResource<QuicIncomingConnection> for BindResource {
fn demands(&self) -> &mpsc::Sender<std_mpsc::Sender<DemandResponse<QuicIncomingConnection>>> {
&self.demands
}
}
impl DemandResource<QuicBidirectionalStream> for AcceptBiResource {
fn demands(&self) -> &mpsc::Sender<std_mpsc::Sender<DemandResponse<QuicBidirectionalStream>>> {
&self.demands
}
}
/// Explicit close for the bind source: signals cancellation and aborts the bind task.
/// `BindResource`'s `Drop` repeats both steps when `resource` goes out of scope; the repeat is
/// presumably harmless.
fn close_bind_resource(resource: BindResource) -> StreamResult<()> {
let _ = resource.cancel.send(true);
resource.task.abort();
Ok(())
}
/// Explicit close for the accept-bi source: signals cancellation and aborts the accept task.
/// `AcceptBiResource`'s `Drop` repeats both steps when `resource` goes out of scope.
fn close_accept_bi_resource(resource: AcceptBiResource) -> StreamResult<()> {
let _ = resource.cancel.send(true);
resource.task.abort();
Ok(())
}
/// Background task behind `quic_bind_source`: answers each demand with one accepted connection.
///
/// Each demand carries a one-shot reply channel. For each one, the task accepts incoming
/// connections until a handshake completes, then replies with that connection. A failed
/// handshake only concerns that one peer, so it is skipped. Previously it failed the whole
/// bind source, which let any remote client tear down the listener.
///
/// When the endpoint stops accepting, the pending demand is answered with `Complete`. The task
/// exits when the demand channel closes, cancellation is signalled, the endpoint closes, or the
/// requester has gone away.
async fn run_quic_bind_task(
    endpoint: quinn::Endpoint,
    local_addr: SocketAddr,
    chunk_size: usize,
    handle: Handle,
    mut demands: mpsc::Receiver<std_mpsc::Sender<DemandResponse<QuicIncomingConnection>>>,
    mut cancel: watch::Receiver<bool>,
) {
    loop {
        let reply = tokio::select! {
            demand = demands.recv() => match demand {
                Some(reply) => reply,
                None => return,
            },
            _ = cancel.changed() => return,
        };
        let connection = loop {
            let incoming = tokio::select! {
                incoming = endpoint.accept() => incoming,
                _ = cancel.changed() => return,
            };
            let Some(incoming) = incoming else {
                // The endpoint no longer accepts connections: end the source cleanly.
                let _ = reply.send(DemandResponse::Complete);
                return;
            };
            let connected = tokio::select! {
                connected = incoming => connected,
                _ = cancel.changed() => return,
            };
            match connected {
                Ok(connection) => break connection,
                // This peer's handshake failed; keep serving other clients.
                Err(_) => continue,
            }
        };
        let incoming = QuicIncomingConnection {
            connection: QuicConnection {
                endpoint: endpoint.clone(),
                local_addr: connection_local_addr(&connection, local_addr, local_addr.ip()),
                remote_addr: connection.remote_address(),
                connection,
                handle: handle.clone(),
                chunk_size,
            },
        };
        if reply.send(DemandResponse::Item(incoming)).is_err() {
            return;
        }
    }
}
/// Background task behind the accept-bi sources: answers each demand with one bidirectional
/// stream opened by the peer.
///
/// Each demand carries a one-shot reply channel. The task exits when the demand channel closes,
/// cancellation is signalled, accepting fails, or the requester has gone away.
async fn run_accept_bi_task(
    connection: quinn::Connection,
    chunk_size: usize,
    emit_available: bool,
    handle: Handle,
    mut demands: mpsc::Receiver<std_mpsc::Sender<DemandResponse<QuicBidirectionalStream>>>,
    mut cancel: watch::Receiver<bool>,
) {
    loop {
        let next_demand = tokio::select! {
            demand = demands.recv() => demand,
            _ = cancel.changed() => None,
        };
        let Some(reply) = next_demand else {
            return;
        };
        let accepted = tokio::select! {
            accepted = connection.accept_bi() => accepted,
            _ = cancel.changed() => return,
        };
        let response = match accepted {
            Ok((send, recv)) => DemandResponse::Item(quic_bi_stream_from_halves(
                send,
                recv,
                handle.clone(),
                chunk_size,
                emit_available,
            )),
            // NOTE(review): every accept error, including a peer closing the connection
            // normally, fails this source instead of completing it — confirm that is intended.
            Err(error) => {
                let _ = reply.send(DemandResponse::Error(quic_error(error)));
                return;
            }
        };
        if reply.send(response).is_err() {
            return;
        }
    }
}