use crate::consts::MAX_LOOPS;
use crate::types::{InStream, OutStream};
use err_derive::Error;
use futures::{channel::mpsc, prelude::*};
use quinn::{ConnectionError, IncomingBiStreams, RecvStream, SendStream};
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use tokio_serde::{formats::SymmetricalBincode, SymmetricallyFramed};
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
/// A newly established incoming stream: `(peer, stream id, outgoing half, incoming half)`.
pub type NewInStream = (String, StreamId, OutStream, InStream);
/// Receiving end of the unbounded channel on which [`NewInStream`]s are delivered.
pub type IncomingStreams = mpsc::UnboundedReceiver<NewInStream>;
/// Identifier of a stream: a `u64` channel id tagged with how the stream was established.
///
/// Streams accepted by the incoming-streams driver in this module are tagged
/// [`StreamId::Proxied`].
#[derive(Clone, Copy, Debug)]
pub enum StreamId {
    /// Channel id of a stream tagged as proxied.
    Proxied(u64),
    /// Channel id of a stream tagged as direct.
    Direct(u64),
}
impl StreamId {
pub fn is_proxied(&self) -> bool {
matches!(self, Self::Proxied(_))
}
pub fn is_direct(&self) -> bool {
matches!(self, Self::Direct(_))
}
}
impl From<StreamId> for u64 {
    /// Strips the proxied/direct tag and yields the raw channel id.
    fn from(id: StreamId) -> Self {
        match id {
            StreamId::Proxied(raw) | StreamId::Direct(raw) => raw,
        }
    }
}
/// Errors returned while driving the acceptance of incoming bidirectional streams.
#[derive(Debug, Error)]
pub enum IncomingStreamsError {
    /// The stream of incoming bidirectional streams yielded `None` (it ended).
    #[error(display = "Unexpected end of Bi stream")]
    EndOfBiStream,
    /// The channel sender reported an error from `poll_ready`, i.e. the
    /// [`IncomingStreams`] receiver is gone and new streams cannot be delivered.
    #[error(display = "IncomingStreams receiver is gone")]
    ReceiverClosed,
    /// Accepting a bidirectional stream failed with the wrapped connection error.
    #[error(display = "Accepting Bi stream: {:?}", _0)]
    AcceptBiStream(#[source] ConnectionError),
}
// Macro defined elsewhere in the crate; presumably generates a conversion from
// `IncomingStreamsError` into a broader error type (TODO confirm).
def_into_error!(IncomingStreamsError);
/// Shared state for accepting incoming bidirectional streams on a quinn connection.
///
/// NOTE: fields are dropped in declaration order; keep the order unless a change in
/// drop order is intended.
pub(super) struct IncomingStreamsInner {
    // Source of newly opened bidirectional streams from the peer.
    ibi: IncomingBiStreams,
    // Not touched by the code in this file; presumably maintained by the items generated
    // by `def_ref!` / `def_driver!` (TODO confirm).
    ref_count: usize,
    // Starts as `None`; presumably holds the driver task's waker and is managed by the
    // `def_driver!`-generated code (TODO confirm).
    driver: Option<Waker>,
    // Channel on which fully initialised incoming streams are handed to the consumer.
    sender: mpsc::UnboundedSender<NewInStream>,
}
impl IncomingStreamsInner {
    /// Spawns a task that completes setup of a freshly accepted bidirectional stream.
    ///
    /// The task reads a single length-delimited, bincode-encoded `(String, u64)` header
    /// frame from `recv`. If the read fails or the stream ends first, it logs a warning
    /// and abandons the stream. Otherwise it wraps `send` in a length-delimited
    /// `FramedWrite` and pushes `(peer, strmid(chanid), outstream, instream)` onto
    /// `sender`. A failed send (receiver gone) is ignored.
    ///
    /// `strmid` decides how the received channel id is tagged (e.g. `StreamId::Proxied`).
    pub(super) fn instream_init<T: 'static>(
        send: SendStream,
        recv: RecvStream,
        sender: mpsc::UnboundedSender<NewInStream>,
        strmid: T,
    ) where
        T: Send + FnOnce(u64) -> StreamId,
    {
        tokio::spawn(async move {
            let transport = FramedRead::new(recv, LengthDelimitedCodec::new());
            let header_codec = SymmetricalBincode::<(String, u64)>::default();
            let mut framed = SymmetricallyFramed::new(transport, header_codec);

            let header = match framed.try_next().await {
                Ok(Some(header)) => header,
                Ok(None) => {
                    tracing::warn!("instream_init: unexpected end of stream");
                    return;
                }
                Err(e) => {
                    tracing::warn!("instream_init: error: {:?}", e);
                    return;
                }
            };
            let (peer, chanid) = header;

            // Unwrap back to the underlying `FramedRead` (not the raw `RecvStream`), which is
            // what the consumer receives as its incoming half.
            let instream = framed.into_inner();
            let outstream = FramedWrite::new(send, LengthDelimitedCodec::new());
            let stream_id = strmid(chanid);

            // Best effort: if the receiver was dropped, the new stream is simply discarded.
            let _ = sender.unbounded_send((peer, stream_id, outstream, instream));
        });
    }

    /// Fails with [`IncomingStreamsError::ReceiverClosed`] once the sender's `poll_ready`
    /// reports an error; otherwise succeeds.
    fn handle_events(&mut self, cx: &mut Context) -> Result<(), IncomingStreamsError> {
        let receiver_gone = matches!(self.sender.poll_ready(cx), Poll::Ready(Err(_)));
        if receiver_gone {
            Err(IncomingStreamsError::ReceiverClosed)
        } else {
            Ok(())
        }
    }

    /// Accepts ready bidirectional streams, handing each to [`Self::instream_init`] tagged
    /// as [`StreamId::Proxied`].
    ///
    /// Returns `Ok(true)` as soon as `MAX_LOOPS` streams have been accepted in this call
    /// (more may be ready), or `Ok(false)` once the incoming stream reports `Pending`.
    ///
    /// # Errors
    /// [`IncomingStreamsError::EndOfBiStream`] if the incoming stream ends, and
    /// [`IncomingStreamsError::AcceptBiStream`] if accepting a stream fails.
    fn drive_streams_recv(&mut self, cx: &mut Context) -> Result<bool, IncomingStreamsError> {
        let mut accepted = 0;
        while let Poll::Ready(next) = self.ibi.poll_next_unpin(cx) {
            let (send, recv) = next
                .ok_or(IncomingStreamsError::EndOfBiStream)?
                .map_err(IncomingStreamsError::AcceptBiStream)?;
            Self::instream_init(send, recv, self.sender.clone(), StreamId::Proxied);
            accepted += 1;
            if accepted >= MAX_LOOPS {
                return Ok(true);
            }
        }
        Ok(false)
    }

    /// Alternates event handling and stream acceptance until no more streams are ready.
    ///
    /// If `MAX_LOOPS` rounds each hit the per-round accept cap, this wakes the current
    /// task's waker and returns, so the work continues on the next poll instead of
    /// looping indefinitely inside this one.
    fn run_driver(&mut self, cx: &mut Context) -> Result<(), IncomingStreamsError> {
        let mut rounds = 0;
        loop {
            self.handle_events(cx)?;
            if !self.drive_streams_recv(cx)? {
                return Ok(());
            }
            rounds += 1;
            if rounds >= MAX_LOOPS {
                cx.waker().wake_by_ref();
                return Ok(());
            }
        }
    }
}
// Macro defined elsewhere in the crate; presumably declares `IncomingStreamsRef` as a
// shared handle around `Arc<Mutex<IncomingStreamsInner>>` (TODO confirm).
def_ref!(IncomingStreamsInner, IncomingStreamsRef);

impl IncomingStreamsRef {
    /// Builds a new shared handle over `ibi` that forwards accepted streams into `sender`.
    ///
    /// The handle starts with a reference count of zero and no driver waker registered.
    pub(super) fn new(ibi: IncomingBiStreams, sender: mpsc::UnboundedSender<NewInStream>) -> Self {
        let state = IncomingStreamsInner {
            ibi,
            ref_count: 0,
            driver: None,
            sender,
        };
        Self(Arc::new(Mutex::new(state)))
    }
}
// Macro defined elsewhere in the crate; presumably declares `IncomingStreamsDriver` as a
// future wrapping an `IncomingStreamsRef` (TODO confirm).
def_driver!(IncomingStreamsRef, IncomingStreamsDriver, IncomingStreamsError);

impl Drop for IncomingStreamsDriver {
    /// Disconnects the shared state's channel sender when the driver is dropped.
    ///
    /// Other clones of the sender (e.g. ones held by still-running `instream_init` tasks)
    /// keep the channel open until they are dropped as well.
    fn drop(&mut self) {
        // A poisoned mutex means some holder panicked while it held the lock, for example
        // inside a poll of this driver. `drop` can then run during unwinding, where a second
        // panic aborts the process. Recover the guard so the sender is still disconnected.
        let mut inner = self
            .0
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        inner.sender.disconnect();
    }
}