use bytes::Bytes;
use futures::{Future, FutureExt, Sink, Stream, TryStreamExt, future::BoxFuture};
use std::{
convert::TryInto,
error::Error,
fmt, io,
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite, BufReader, BufWriter};
use tokio_util::codec::LengthDelimitedCodec;
use crate::{
RemoteSend,
chmux::{ChMux, ChMuxError},
codec,
rch::base,
};
/// Error that can occur while establishing a connection via [`Connect`].
///
/// The failure either comes from the channel multiplexer that is set up on
/// top of the transport, or from opening the initial base channel pair with
/// the remote endpoint.
#[cfg_attr(docsrs, doc(cfg(feature = "rch")))]
#[derive(Clone, Debug)]
pub enum ConnectError<TransportSinkError, TransportStreamError> {
    /// Creating or running the channel multiplexer failed.
    ChMux(ChMuxError<TransportSinkError, TransportStreamError>),
    /// Opening the initial base channel with the remote endpoint failed.
    RemoteConnect(base::ConnectError),
}
/// Renders the error as a short context message followed by the wrapped cause.
impl<TransportSinkError, TransportStreamError> fmt::Display
    for ConnectError<TransportSinkError, TransportStreamError>
where
    TransportSinkError: fmt::Display,
    TransportStreamError: fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Select the context prefix and the inner error first, then emit both
        // with a single `write!`, so every variant shares one output layout.
        let (context, cause): (&str, &dyn fmt::Display) = match self {
            Self::ChMux(mux_err) => ("chmux error", mux_err),
            Self::RemoteConnect(connect_err) => ("channel connect failed", connect_err),
        };
        write!(f, "{context}: {cause}")
    }
}
/// Makes [`ConnectError`] a standard error type whenever both transport error
/// types are standard errors.
///
/// [`Error::source`] keeps its default implementation, so the wrapped error is
/// only reachable through the variant payload and the [`fmt::Display`] output.
impl<TransportSinkError: Error, TransportStreamError: Error> Error
    for ConnectError<TransportSinkError, TransportStreamError>
{
}
impl<TransportSinkError, TransportStreamError> From<ChMuxError<TransportSinkError, TransportStreamError>>
for ConnectError<TransportSinkError, TransportStreamError>
{
fn from(err: ChMuxError<TransportSinkError, TransportStreamError>) -> Self {
Self::ChMux(err)
}
}
impl<TransportSinkError, TransportStreamError> From<base::ConnectError>
for ConnectError<TransportSinkError, TransportStreamError>
{
fn from(err: base::ConnectError) -> Self {
Self::RemoteConnect(err)
}
}
#[cfg_attr(docsrs, doc(cfg(feature = "rch")))]
#[must_use = "You must poll or spawn the Connect future for the connection to work."]
pub struct Connect<'transport, TransportSinkError, TransportStreamError>(
BoxFuture<'transport, Result<(), ChMuxError<TransportSinkError, TransportStreamError>>>,
);
impl<'transport, TransportSinkError, TransportStreamError>
    Connect<'transport, TransportSinkError, TransportStreamError>
{
    /// Establishes a connection over an already framed transport.
    ///
    /// A channel multiplexer is created from `cfg`, `transport_sink` and
    /// `transport_stream`, and an initial base channel pair is then opened with
    /// the remote endpoint. On success, the returned [`Connect`] future must be
    /// polled or spawned for the connection to work.
    ///
    /// # Errors
    /// Returns [`ConnectError::ChMux`] in two cases: creating the multiplexer
    /// fails, or the multiplexer terminates with an error while the base channel
    /// is being opened. Returns [`ConnectError::RemoteConnect`] if opening the
    /// base channel fails.
    pub async fn framed<TransportSink, TransportStream, Tx, Rx, Codec>(
        cfg: crate::Cfg, transport_sink: TransportSink, transport_stream: TransportStream,
    ) -> Result<
        (
            Connect<'transport, TransportSinkError, TransportStreamError>,
            base::Sender<Tx, Codec>,
            base::Receiver<Rx, Codec>,
        ),
        ConnectError<TransportSinkError, TransportStreamError>,
    >
    where
        TransportSink: Sink<Bytes, Error = TransportSinkError> + Send + Sync + Unpin + 'transport,
        TransportSinkError: Error + Send + Sync + 'static,
        TransportStream: Stream<Item = Result<Bytes, TransportStreamError>> + Send + Sync + Unpin + 'transport,
        TransportStreamError: Error + Send + Sync + 'static,
        Tx: RemoteSend,
        Rx: RemoteSend,
        Codec: codec::Codec,
    {
        let (mux, client, mut listener) = ChMux::new(cfg, transport_sink, transport_stream).await?;

        // The multiplexer must run concurrently while the base channel is opened.
        // `biased` polls it first, so one of its errors takes precedence. If it
        // finishes with `Ok(())`, the `Err(..)` pattern does not match. That
        // branch is then disabled, and the connect branch decides the outcome.
        let mut mux_task = Self(mux.run().boxed());
        tokio::select! {
            biased;
            Err(mux_err) = &mut mux_task => Err(ConnectError::ChMux(mux_err)),
            connected = base::connect(&client, &mut listener) => {
                match connected {
                    Ok((sender, receiver)) => Ok((mux_task, sender, receiver)),
                    Err(connect_err) => Err(ConnectError::RemoteConnect(connect_err)),
                }
            }
        }
    }
}
impl<'transport> Connect<'transport, io::Error, io::Error> {
    /// Establishes a connection over a pair of raw byte streams.
    ///
    /// `output` is wrapped in a length-delimited frame writer and `input` in a
    /// length-delimited frame reader. Both use a 4-byte little-endian length
    /// field. The resulting framed transport is passed to [`Connect::framed`].
    ///
    /// Incoming frames are limited to `cfg.max_frame_length()` bytes, clamped to
    /// `usize::MAX`. Outgoing frames are limited only by what the 4-byte length
    /// field can express.
    ///
    /// # Errors
    /// Same as [`Connect::framed`], with [`io::Error`] as both transport error
    /// types.
    pub async fn io<Read, Write, Tx, Rx, Codec>(
        cfg: crate::Cfg, input: Read, output: Write,
    ) -> Result<
        (Connect<'transport, io::Error, io::Error>, base::Sender<Tx, Codec>, base::Receiver<Rx, Codec>),
        ConnectError<io::Error, io::Error>,
    >
    where
        Read: AsyncRead + Send + Sync + Unpin + 'transport,
        Write: AsyncWrite + Send + Sync + Unpin + 'transport,
        Tx: RemoteSend,
        Rx: RemoteSend,
        Codec: codec::Codec,
    {
        // Clamp instead of panicking. A limit that does not fit into `usize` can
        // never be reached by a frame that fits in memory anyway.
        let max_recv_frame_length: usize = cfg.max_frame_length().try_into().unwrap_or(usize::MAX);
        // The 4-byte length field caps outgoing frames at `u32::MAX` bytes.
        // This conversion gives the same value as the previous `as` cast on every
        // target, but without a silent truncation.
        let max_send_frame_length: usize = u32::MAX.try_into().unwrap_or(usize::MAX);

        let transport_sink = LengthDelimitedCodec::builder()
            .little_endian()
            .length_field_length(4)
            .max_frame_length(max_send_frame_length)
            .new_write(output);
        let transport_stream = LengthDelimitedCodec::builder()
            .little_endian()
            .length_field_length(4)
            .max_frame_length(max_recv_frame_length)
            .new_read(input)
            // The reader yields mutable buffers; the framed transport expects
            // immutable `Bytes`.
            .map_ok(|item| item.freeze());
        Self::framed(cfg, transport_sink, transport_stream).await
    }

    /// Like [`Connect::io`], but first wraps `input` in a [`BufReader`] and
    /// `output` in a [`BufWriter`], each with a capacity of `buffer` bytes.
    ///
    /// # Errors
    /// Same as [`Connect::io`].
    pub async fn io_buffered<Read, Write, Tx, Rx, Codec>(
        cfg: crate::Cfg, input: Read, output: Write, buffer: usize,
    ) -> Result<
        (Connect<'transport, io::Error, io::Error>, base::Sender<Tx, Codec>, base::Receiver<Rx, Codec>),
        ConnectError<io::Error, io::Error>,
    >
    where
        Read: AsyncRead + Send + Sync + Unpin + 'transport,
        Write: AsyncWrite + Send + Sync + Unpin + 'transport,
        Tx: RemoteSend,
        Rx: RemoteSend,
        Codec: codec::Codec,
    {
        let buf_input = BufReader::with_capacity(buffer, input);
        let buf_output = BufWriter::with_capacity(buffer, output);
        Self::io(cfg, buf_input, buf_output).await
    }
}
/// Polling a [`Connect`] polls the boxed multiplexer future it wraps.
impl<TransportSinkError, TransportStreamError> Future for Connect<'_, TransportSinkError, TransportStreamError> {
    type Output = Result<(), ChMuxError<TransportSinkError, TransportStreamError>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        // `Connect` only holds a pinned box, which makes it `Unpin`. The outer
        // pin can therefore be removed to reach the inner future.
        let mux_future = &mut self.get_mut().0;
        mux_future.as_mut().poll(cx)
    }
}