use bytes::{Buf, Bytes, BytesMut};
use futures::{
future::BoxFuture,
ready,
stream::Stream,
task::{Context, Poll},
FutureExt,
};
use std::{collections::VecDeque, error::Error, fmt, mem, pin::Pin, sync::Arc};
use tokio::sync::{mpsc, oneshot, Mutex};
use super::{
credit::{ChannelCreditReturner, UsedCredit},
multiplexer::PortEvt,
Request,
};
/// Error returned by [`Receiver::recv`] and [`Receiver::recv_any`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ReceiveError {
    /// The multiplexer terminated; no further messages can be received.
    Multiplexer,
    /// A data message was larger than the configured limit (in bytes), which is carried.
    ExceedsMaxDataSize(usize),
    /// A port message carried more requests than the configured limit, which is carried.
    ExceedsMaxPortCount(usize),
}

impl ReceiveError {
    /// Returns `true` when the error is caused by multiplexer termination.
    pub fn is_terminated(&self) -> bool {
        matches!(*self, ReceiveError::Multiplexer)
    }
}

impl fmt::Display for ReceiveError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            ReceiveError::Multiplexer => f.write_str("multiplexer terminated"),
            ReceiveError::ExceedsMaxDataSize(limit) => {
                write!(f, "data exceeds maximum allowed size of {} bytes", limit)
            }
            ReceiveError::ExceedsMaxPortCount(limit) => {
                write!(f, "port message exceeds maximum allowed count of {} ports", limit)
            }
        }
    }
}

impl Error for ReceiveError {}
/// Error whose only cause is termination of the multiplexer.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ReceiveAnyError {
    /// The multiplexer terminated; no further messages can be received.
    Multiplexer,
}

impl ReceiveAnyError {
    /// Returns `true` when the error is caused by multiplexer termination.
    pub fn is_terminated(&self) -> bool {
        match self {
            ReceiveAnyError::Multiplexer => true,
        }
    }
}

impl fmt::Display for ReceiveAnyError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let text = match self {
            ReceiveAnyError::Multiplexer => "multiplexer terminated",
        };
        f.write_str(text)
    }
}

impl Error for ReceiveAnyError {}
/// Error returned by [`Receiver::recv_chunk`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ReceiveChunkError {
    /// The multiplexer terminated; no further messages can be received.
    Multiplexer,
    /// The chunked message being received was interrupted before its last chunk,
    /// either by the start of a new data message, by port requests, or by the
    /// end of the stream.
    Cancelled,
}

impl ReceiveChunkError {
    /// Returns `true` when the error is caused by multiplexer termination.
    pub fn is_terminated(&self) -> bool {
        matches!(self, Self::Multiplexer)
    }
}

impl fmt::Display for ReceiveChunkError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Multiplexer => write!(f, "multiplexer terminated"),
            Self::Cancelled => write!(f, "transmission cancelled"),
        }
    }
}

// Like `ReceiveError` and `ReceiveAnyError`, this type must implement `std::error::Error`
// so callers can box it or propagate it with `?` into `Box<dyn Error>`.
impl Error for ReceiveChunkError {}
/// One chunk of a data message delivered to a [`Receiver`].
pub(crate) struct ReceivedData {
    /// Payload of this chunk.
    pub buf: Bytes,
    /// Set on the first chunk of a message; the receiver then discards any partial state.
    pub first: bool,
    /// Set on the last chunk of a message; the receiver then treats the message as complete.
    pub last: bool,
    /// Credit used by this chunk; the receiver hands it to `ChannelCreditReturner::start_return`.
    pub credit: UsedCredit,
}
/// One part of a port-request message delivered to a [`Receiver`].
pub(crate) struct ReceivedPortRequests {
    /// Requests carried by this part.
    pub requests: Vec<Request>,
    /// Set on the first part of a message; the receiver then starts a new request list.
    pub first: bool,
    /// Set on the last part of a message; the receiver then returns the collected requests.
    pub last: bool,
    /// Credit used by this part; the receiver hands it to `ChannelCreditReturner::start_return`.
    pub credit: UsedCredit,
}
/// Message read by a [`Receiver`] from its incoming channel.
pub(crate) enum PortReceiveMsg {
    /// A data chunk.
    Data(ReceivedData),
    /// A part of a port-request message.
    PortRequests(ReceivedPortRequests),
    /// End of the stream: afterwards the receiver's `recv*` methods return `Ok(None)`.
    Finished,
}
/// Received data stored as a sequence of non-contiguous chunks.
///
/// Implements [`Buf`] so the bytes can be read without first copying them into one
/// buffer; convert into [`Bytes`], [`BytesMut`] or `Vec<u8>` for contiguous access.
#[derive(Clone)]
pub struct DataBuf {
    /// Chunks in arrival order; reading consumes from the front.
    bufs: VecDeque<Bytes>,
    /// Total number of unread bytes across all chunks.
    remaining: usize,
}
impl fmt::Debug for DataBuf {
    /// Shows only the number of unread bytes; the chunk contents are omitted.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut dbg = f.debug_struct("DataBuf");
        dbg.field("remaining", &self.remaining);
        dbg.finish_non_exhaustive()
    }
}
impl DataBuf {
    /// Creates an empty buffer.
    fn new() -> Self {
        Self { bufs: VecDeque::new(), remaining: 0 }
    }

    /// Appends `buf` if the total size stays within `max_size` bytes.
    ///
    /// Empty chunks are accepted but not stored. If an empty chunk sat at the front of
    /// `bufs` while later chunks still held data, `Buf::chunk` would return an empty
    /// slice although `Buf::remaining` is non-zero. That violates the `Buf` contract:
    /// default methods such as `copy_to_slice` would loop forever, since `advance(0)`
    /// never pops the empty chunk, and `get_u8` would panic.
    ///
    /// # Errors
    /// Returns `buf` unchanged if appending it would exceed `max_size` or overflow `usize`.
    fn try_push(&mut self, buf: Bytes, max_size: usize) -> Result<(), Bytes> {
        match self.remaining.checked_add(buf.len()) {
            Some(new_size) if new_size <= max_size => {
                if !buf.is_empty() {
                    self.bufs.push_back(buf);
                }
                self.remaining = new_size;
                Ok(())
            }
            _ => Err(buf),
        }
    }
}
impl Default for DataBuf {
fn default() -> Self {
Self::new()
}
}
impl Buf for DataBuf {
    /// Total number of unread bytes across all chunks.
    fn remaining(&self) -> usize {
        self.remaining
    }

    /// Unread bytes of the front chunk only; later chunks become visible after `advance`.
    fn chunk(&self) -> &[u8] {
        if let Some(front) = self.bufs.front() {
            front.chunk()
        } else {
            &[]
        }
    }

    /// Consumes `cnt` bytes, dropping chunks that are fully read.
    ///
    /// # Panics
    /// Panics if `cnt` exceeds the number of buffered bytes.
    fn advance(&mut self, mut cnt: usize) {
        while cnt > 0 {
            let front = match self.bufs.front_mut() {
                Some(front) => front,
                None => panic!("cannot advance beyond end of data"),
            };
            let len = front.len();
            if cnt < len {
                // Partially consume the front chunk and stop.
                front.advance(cnt);
                self.remaining -= cnt;
                return;
            }
            // The front chunk is used up entirely.
            self.bufs.pop_front();
            self.remaining -= len;
            cnt -= len;
        }
    }
}
impl From<DataBuf> for BytesMut {
fn from(mut data: DataBuf) -> Self {
let mut continuous = BytesMut::with_capacity(data.remaining);
while let Some(buf) = data.bufs.pop_front() {
continuous.extend_from_slice(&buf);
}
continuous
}
}
impl From<DataBuf> for Bytes {
fn from(mut data: DataBuf) -> Self {
if data.bufs.len() == 1 {
data.bufs.pop_front().unwrap()
} else {
BytesMut::from(data).into()
}
}
}
impl From<DataBuf> for Vec<u8> {
fn from(mut data: DataBuf) -> Self {
let mut continuous = Vec::with_capacity(data.remaining);
while let Some(buf) = data.bufs.pop_front() {
continuous.extend_from_slice(&buf);
}
continuous
}
}
impl From<Bytes> for DataBuf {
    /// Wraps a single contiguous buffer as a one-chunk `DataBuf`.
    fn from(data: Bytes) -> Self {
        let remaining = data.len();
        let bufs: VecDeque<Bytes> = std::iter::once(data).collect();
        Self { bufs, remaining }
    }
}
/// A complete message returned by [`Receiver::recv_any`].
pub enum Received {
    /// Binary data that fit within the receiver's maximum data size.
    Data(DataBuf),
    /// Binary data that exceeded the maximum data size. The chunks received so far are
    /// kept by the receiver and can be read, with the rest of the message, via
    /// [`Receiver::recv_chunk`].
    BigData,
    /// Port requests.
    Requests(Vec<Request>),
}
/// State of the message currently being received.
enum Receiving {
    /// No message in progress.
    Nothing,
    /// A data message being reassembled by `recv_any`.
    Data(DataBuf),
    /// Chunks waiting to be returned by `recv_chunk`. `completed` is the `last` flag of
    /// the most recently received chunk, i.e. whether the message has fully arrived.
    Chunks { chunks: VecDeque<Bytes>, completed: bool },
    /// Port requests being collected by `recv_any`.
    Requests(Vec<Request>),
}
impl Default for Receiving {
fn default() -> Self {
Self::Nothing
}
}
/// Receives data messages and port requests for a single local port.
pub struct Receiver {
    /// Port number on this side.
    local_port: u32,
    /// Port number on the remote side; used when returning credit.
    remote_port: u32,
    /// Maximum size in bytes of a data message reassembled by `recv`/`recv_any`.
    max_data_size: usize,
    /// Maximum number of requests accepted in one port message.
    max_ports: usize,
    /// Outgoing events: credit returns, close and drop notifications.
    tx: mpsc::Sender<PortEvt>,
    /// Incoming messages for this port.
    rx: mpsc::UnboundedReceiver<PortReceiveMsg>,
    /// State of the message currently being received.
    receiving: Receiving,
    /// Returns credit for received messages.
    credits: ChannelCreditReturner,
    /// Whether `close` has already sent `PortEvt::ReceiverClosed`.
    closed: bool,
    /// Whether `PortReceiveMsg::Finished` has been received.
    finished: bool,
    /// Never sent on; dropping it (with the receiver) wakes the task spawned in `new`,
    /// which then sends `PortEvt::ReceiverDropped`.
    _drop_tx: oneshot::Sender<()>,
}
impl fmt::Debug for Receiver {
    /// Shows the port numbers, limits and state flags; channels and buffers are omitted.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut dbg = f.debug_struct("Receiver");
        dbg.field("local_port", &self.local_port);
        dbg.field("remote_port", &self.remote_port);
        dbg.field("max_data_size", &self.max_data_size);
        dbg.field("max_ports", &self.max_ports);
        dbg.field("closed", &self.closed);
        dbg.field("finished", &self.finished);
        dbg.finish_non_exhaustive()
    }
}
impl Receiver {
    /// Creates a receiver for `local_port`, whose peer is `remote_port`.
    ///
    /// Messages are read from `rx`. Credit returns and close/drop notifications are sent
    /// on `tx`. `max_data_size` and `max_port_count` bound what
    /// [`recv_any`](Self::recv_any) accepts.
    ///
    /// Spawns a task with `tokio::spawn`, so this must be called inside a Tokio runtime.
    /// That task waits until the receiver (and with it `_drop_tx`) is dropped, then sends
    /// `PortEvt::ReceiverDropped` on `tx`.
    pub(crate) fn new(
        local_port: u32, remote_port: u32, max_data_size: usize, max_port_count: usize,
        tx: mpsc::Sender<PortEvt>, rx: mpsc::UnboundedReceiver<PortReceiveMsg>, credits: ChannelCreditReturner,
    ) -> Self {
        let (_drop_tx, drop_rx) = oneshot::channel();
        let tx_drop = tx.clone();
        // `_drop_tx` is never sent on, so `drop_rx` only resolves (with an error) once the
        // receiver holding it is dropped.
        tokio::spawn(async move {
            let _ = drop_rx.await;
            let _ = tx_drop.send(PortEvt::ReceiverDropped { local_port }).await;
        });
        Self {
            local_port,
            remote_port,
            max_data_size,
            max_ports: max_port_count,
            tx,
            rx,
            receiving: Receiving::Nothing,
            credits,
            closed: false,
            finished: false,
            _drop_tx,
        }
    }
    /// Local port number.
    pub fn local_port(&self) -> u32 {
        self.local_port
    }
    /// Remote port number.
    pub fn remote_port(&self) -> u32 {
        self.remote_port
    }
    /// Maximum size in bytes of a data message returned by `recv`/`recv_any`.
    pub fn max_data_size(&self) -> usize {
        self.max_data_size
    }
    /// Sets the maximum size in bytes of a data message returned by `recv`/`recv_any`.
    pub fn set_max_data_size(&mut self, max_data_size: usize) {
        self.max_data_size = max_data_size;
    }
    /// Maximum number of requests accepted in one port message.
    pub fn max_ports(&self) -> usize {
        self.max_ports
    }
    /// Sets the maximum number of requests accepted in one port message.
    pub fn set_max_ports(&mut self, max_ports: usize) {
        self.max_ports = max_ports;
    }
    /// Receives the next data message. Port requests received in the meantime are dropped.
    ///
    /// Returns `Ok(None)` once the remote end has finished sending.
    ///
    /// # Errors
    /// * [`ReceiveError::ExceedsMaxDataSize`] if the message is larger than
    ///   [`max_data_size`](Self::max_data_size); its chunks can still be read with
    ///   [`recv_chunk`](Self::recv_chunk).
    /// * Any other error returned by [`recv_any`](Self::recv_any).
    pub async fn recv(&mut self) -> Result<Option<DataBuf>, ReceiveError> {
        loop {
            match self.recv_any().await? {
                Some(Received::Data(data)) => break Ok(Some(data)),
                Some(Received::BigData) => break Err(ReceiveError::ExceedsMaxDataSize(self.max_data_size)),
                Some(Received::Requests(_)) => (),
                None => break Ok(None),
            }
        }
    }
    /// Receives the next chunk of a data message without reassembling it.
    ///
    /// Chunks left over from a [`Received::BigData`] result are returned first.
    /// Returns `Ok(None)` after the last chunk of the current message, or once the
    /// remote end has finished. Port requests received while no chunked message is in
    /// progress are dropped.
    ///
    /// # Errors
    /// * [`ReceiveChunkError::Cancelled`] if the message being chunked is interrupted by
    ///   the start of a new data message, by port requests, or by the end of the stream.
    ///   When a new message was the cause, its first chunk is kept and returned by the
    ///   next call.
    /// * [`ReceiveChunkError::Multiplexer`] if the incoming channel is closed.
    pub async fn recv_chunk(&mut self) -> Result<Option<Bytes>, ReceiveChunkError> {
        if self.finished {
            return Ok(None);
        }
        loop {
            // Flush credit returns queued by `start_return` before taking more input.
            self.credits.return_flush().await;
            match &mut self.receiving {
                // Deliver buffered chunks first.
                Receiving::Chunks { chunks, .. } if !chunks.is_empty() => {
                    return Ok(Some(chunks.pop_front().unwrap()))
                }
                // Buffer drained and the last chunk already seen: the message is done.
                Receiving::Chunks { completed: true, .. } => {
                    self.receiving = Receiving::Nothing;
                    return Ok(None);
                }
                _ => match self.rx.recv().await {
                    Some(PortReceiveMsg::Data(data)) => {
                        self.credits.start_return(data.credit, self.remote_port, &self.tx);
                        match (&self.receiving, data.first) {
                            // A new message started before the current one completed:
                            // keep its first chunk for the next call and report cancellation.
                            (Receiving::Chunks { .. }, true) => {
                                self.receiving =
                                    Receiving::Chunks { chunks: vec![data.buf].into(), completed: data.last };
                                return Err(ReceiveChunkError::Cancelled);
                            }
                            // Continuation of the current message, or start of a new one.
                            (Receiving::Chunks { .. }, false) | (_, true) => {
                                self.receiving =
                                    Receiving::Chunks { chunks: VecDeque::new(), completed: data.last };
                                return Ok(Some(data.buf));
                            }
                            // Continuation chunk with no chunked message in progress: ignore it.
                            (_, false) => (),
                        }
                    }
                    Some(PortReceiveMsg::PortRequests(req)) => {
                        self.credits.start_return(req.credit, self.remote_port, &self.tx);
                        if let Receiving::Chunks { .. } = &self.receiving {
                            self.receiving = Receiving::Nothing;
                            return Err(ReceiveChunkError::Cancelled);
                        }
                    }
                    Some(PortReceiveMsg::Finished) => {
                        self.finished = true;
                        if let Receiving::Chunks { .. } = &self.receiving {
                            self.receiving = Receiving::Nothing;
                            return Err(ReceiveChunkError::Cancelled);
                        } else {
                            return Ok(None);
                        }
                    }
                    None => return Err(ReceiveChunkError::Multiplexer),
                },
            }
        }
    }
    /// Receives the next complete message: either data or port requests.
    ///
    /// Data chunks are reassembled into a [`DataBuf`]. If a data message exceeds
    /// [`max_data_size`](Self::max_data_size), [`Received::BigData`] is returned and the
    /// chunks received so far are kept, so the message can be read with
    /// [`recv_chunk`](Self::recv_chunk). Returns `Ok(None)` once the remote end has finished.
    ///
    /// # Errors
    /// * [`ReceiveError::ExceedsMaxPortCount`] if a port message carries more than
    ///   [`max_ports`](Self::max_ports) requests; the rest of that message is ignored.
    /// * [`ReceiveError::Multiplexer`] if the incoming channel is closed.
    pub async fn recv_any(&mut self) -> Result<Option<Received>, ReceiveError> {
        if self.finished {
            return Ok(None);
        }
        loop {
            // Flush credit returns queued by `start_return` before taking more input.
            self.credits.return_flush().await;
            match self.rx.recv().await {
                Some(PortReceiveMsg::Data(data)) => {
                    self.credits.start_return(data.credit, self.remote_port, &self.tx);
                    // A first chunk always starts a new message, discarding any partial state.
                    if data.first {
                        self.receiving = Receiving::Data(DataBuf::new());
                    }
                    // The state is taken out and put back below unless the message is
                    // complete; a chunk that does not continue a data message is dropped.
                    if let Receiving::Data(mut data_buf) = mem::take(&mut self.receiving) {
                        match data_buf.try_push(data.buf, self.max_data_size) {
                            Ok(()) => {
                                if data.last {
                                    return Ok(Some(Received::Data(data_buf)));
                                } else {
                                    self.receiving = Receiving::Data(data_buf);
                                }
                            }
                            // Too large: hand everything received so far over to `recv_chunk`.
                            Err(buf) => {
                                data_buf.bufs.push_back(buf);
                                self.receiving =
                                    Receiving::Chunks { chunks: data_buf.bufs, completed: data.last };
                                return Ok(Some(Received::BigData));
                            }
                        }
                    }
                }
                Some(PortReceiveMsg::PortRequests(req)) => {
                    self.credits.start_return(req.credit, self.remote_port, &self.tx);
                    if req.first {
                        self.receiving = Receiving::Requests(Vec::new());
                    }
                    if let Receiving::Requests(mut requests) = mem::take(&mut self.receiving) {
                        requests.extend(req.requests);
                        // Leaving the state as `Nothing` makes later parts of this message
                        // be ignored.
                        if requests.len() > self.max_ports {
                            self.receiving = Receiving::Nothing;
                            return Err(ReceiveError::ExceedsMaxPortCount(self.max_ports));
                        }
                        if req.last {
                            return Ok(Some(Received::Requests(requests)));
                        } else {
                            self.receiving = Receiving::Requests(requests);
                        }
                    }
                }
                Some(PortReceiveMsg::Finished) => {
                    self.finished = true;
                    return Ok(None);
                }
                None => return Err(ReceiveError::Multiplexer),
            }
        }
    }
    /// Sends `PortEvt::ReceiverClosed` on the event channel; later calls do nothing.
    ///
    /// A failed send is ignored and the receiver is still marked as closed.
    pub async fn close(&mut self) {
        if !self.closed {
            let _ = self.tx.send(PortEvt::ReceiverClosed { local_port: self.local_port }).await;
            self.closed = true;
        }
    }
    /// Converts this receiver into a [`ReceiverStream`] of data messages.
    pub fn into_stream(self) -> ReceiverStream {
        ReceiverStream::new(self)
    }
}
impl Drop for Receiver {
    /// Intentionally empty. `PortEvt::ReceiverDropped` is sent by the task spawned in
    /// `Receiver::new`, which wakes up once `_drop_tx` is dropped with the other fields.
    fn drop(&mut self) {}
}
/// A [`Stream`] of data messages from a [`Receiver`], created by [`Receiver::into_stream`].
pub struct ReceiverStream {
    /// The wrapped receiver, shared with the in-flight receive future.
    receiver: Arc<Mutex<Receiver>>,
    /// Receive operation currently in progress, if any. It is kept across polls so a
    /// pending receive is resumed rather than restarted.
    receive_fut: Option<BoxFuture<'static, Result<Option<DataBuf>, ReceiveError>>>,
}
impl ReceiverStream {
    /// Wraps `receiver` so it can be shared with `'static` receive futures.
    fn new(receiver: Receiver) -> Self {
        Self { receiver: Arc::new(Mutex::new(receiver)), receive_fut: None }
    }

    /// Locks the shared receiver and waits for its next data message.
    async fn receive(receiver: Arc<Mutex<Receiver>>) -> Result<Option<DataBuf>, ReceiveError> {
        let mut receiver = receiver.lock().await;
        receiver.recv().await
    }

    /// Polls the in-flight receive, starting a new one if none is pending.
    fn poll_next(&mut self, cx: &mut Context) -> Poll<Result<Option<DataBuf>, ReceiveError>> {
        if self.receive_fut.is_none() {
            self.receive_fut = Some(Self::receive(self.receiver.clone()).boxed());
        }
        let fut = self.receive_fut.as_mut().unwrap();
        let res = ready!(fut.as_mut().poll(cx));
        // This receive is complete; the next poll starts a fresh one.
        self.receive_fut = None;
        Poll::Ready(res)
    }

    /// Closes the underlying receiver (see `Receiver::close`).
    ///
    /// Any receive left pending by an earlier poll is dropped first. That future holds the
    /// receiver's lock while it waits for data, and because `&mut self` stops the stream
    /// from being polled during this call, waiting for the lock would never finish.
    /// Partially reassembled data lives in the `Receiver` itself, so the next poll
    /// continues with a fresh receive operation.
    pub async fn close(&mut self) {
        self.receive_fut = None;
        self.receiver.lock().await.close().await
    }
}
impl Stream for ReceiverStream {
    type Item = Result<DataBuf, ReceiveError>;

    /// Yields data messages; ends (`None`) when the receiver reports `Ok(None)`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let this = Pin::get_mut(self);
        // Resolves to the inherent `ReceiverStream::poll_next`, not this trait method.
        this.poll_next(cx).map(Result::transpose)
    }
}