use std::fmt::Debug;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use futures::{Sink, SinkExt, Stream};
use tor_async_utils::SinkTrySend;
use tor_async_utils::peekable_stream::UnobtrusivePeekableStream;
use tor_async_utils::stream_peek::StreamUnobtrusivePeeker;
use tor_cell::relaycell::UnparsedRelayMsg;
use tor_memquota::mq_queue::{self, ChannelSpec, MpscSpec, MpscUnboundedSpec};
use tor_rtcompat::DynTimeProvider;
use crate::memquota::{SpecificAccount, StreamAccount};
/// Channel specification used by the stream queue when the `flowctl-cc`
/// feature is enabled.
#[cfg(feature = "flowctl-cc")]
type Spec = MpscUnboundedSpec;
/// Channel specification used by the stream queue when the `flowctl-cc`
/// feature is disabled.
#[cfg(not(feature = "flowctl-cc"))]
type Spec = MpscSpec;
pub(crate) fn stream_queue(
#[cfg(not(feature = "flowctl-cc"))] size: usize,
memquota: &StreamAccount,
time_prov: &DynTimeProvider,
) -> Result<(StreamQueueSender, StreamQueueReceiver), tor_memquota::Error> {
let (sender, receiver) = {
cfg_if::cfg_if! {
if #[cfg(not(feature = "flowctl-cc"))] {
MpscSpec::new(size).new_mq(time_prov.clone(), memquota.as_raw_account())?
} else {
MpscUnboundedSpec::new().new_mq(time_prov.clone(), memquota.as_raw_account())?
}
}
};
let receiver = StreamUnobtrusivePeeker::new(receiver);
let counter = Arc::new(Mutex::new(0));
Ok((
StreamQueueSender {
sender,
counter: Arc::clone(&counter),
},
StreamQueueReceiver { receiver, counter },
))
}
/// Test-only helper that builds a stream queue without real accounting.
///
/// Calls [`stream_queue`] with `StreamAccount::new_noop()` and a
/// [`DynTimeProvider`] wrapping a default `tor_rtmock::MockRuntime`.
/// The `size` parameter only exists when `flowctl-cc` is disabled.
///
/// # Panics
///
/// Panics if [`stream_queue`] returns an error.
#[cfg(test)]
pub(crate) fn fake_stream_queue(
    #[cfg(not(feature = "flowctl-cc"))] size: usize,
) -> (StreamQueueSender, StreamQueueReceiver) {
    let noop_account = StreamAccount::new_noop();
    let mock_time = DynTimeProvider::new(tor_rtmock::MockRuntime::default());

    stream_queue(
        #[cfg(not(feature = "flowctl-cc"))]
        size,
        &noop_account,
        &mock_time,
    )
    .expect("create fake stream queue")
}
/// Sending half of a stream queue, as returned by `stream_queue`.
///
/// Wraps an `mq_queue::Sender` and adds the data length of every message it
/// accepts to a byte counter shared with the matching `StreamQueueReceiver`.
#[derive(Debug)]
#[pin_project::pin_project]
pub(crate) struct StreamQueueSender {
/// The underlying queue sender that messages are forwarded to.
#[pin]
sender: mq_queue::Sender<UnparsedRelayMsg, Spec>,
/// Byte counter shared with the receiving half; increased on each
/// successful send.
counter: Arc<Mutex<usize>>,
}
/// Receiving half of a stream queue, as returned by `stream_queue`.
///
/// Wraps a peekable `mq_queue::Receiver` and subtracts the data length of
/// every message it yields from a byte counter shared with the matching
/// `StreamQueueSender`.
#[derive(Debug)]
#[pin_project::pin_project]
pub(crate) struct StreamQueueReceiver {
/// The underlying queue receiver, wrapped so the next item can be peeked.
#[pin]
receiver: StreamUnobtrusivePeeker<mq_queue::Receiver<UnparsedRelayMsg, Spec>>,
/// Byte counter shared with the sending half; decreased whenever a message
/// with a non-zero data length is yielded.
counter: Arc<Mutex<usize>>,
}
impl StreamQueueSender {
    /// Read the current value of the byte counter shared with the receiver.
    ///
    /// The value is a snapshot taken under the counter's lock; presumably it
    /// is "approximate" because it may change as soon as the lock is released.
    ///
    /// # Panics
    ///
    /// Panics if the counter's mutex is poisoned.
    pub(crate) fn approx_stream_bytes(&self) -> usize {
        let guard = self.counter.lock().expect("poisoned");
        *guard
    }
}
impl StreamQueueReceiver {
    /// Read the current value of the byte counter shared with the sender.
    ///
    /// The value is a snapshot taken under the counter's lock; presumably it
    /// is "approximate" because it may change as soon as the lock is released.
    ///
    /// # Panics
    ///
    /// Panics if the counter's mutex is poisoned.
    pub(crate) fn approx_stream_bytes(&self) -> usize {
        let snapshot = {
            let guard = self.counter.lock().expect("poisoned");
            *guard
        };
        snapshot
    }
}
impl Sink<UnparsedRelayMsg> for StreamQueueSender {
type Error = <mq_queue::Sender<UnparsedRelayMsg, MpscSpec> as Sink<UnparsedRelayMsg>>::Error;
fn poll_ready(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
self.sender.poll_ready_unpin(cx)
}
fn start_send(
mut self: Pin<&mut Self>,
item: UnparsedRelayMsg,
) -> std::result::Result<(), Self::Error> {
let mut self_ = self.as_mut().project();
let stream_data_len = data_len(&item);
let mut counter = self_.counter.lock().expect("poisoned");
self_.sender.start_send_unpin(item)?;
*counter = counter
.checked_add(stream_data_len.into())
.expect("queue has more than `usize::MAX` bytes?!");
Ok(())
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
self.sender.poll_flush_unpin(cx)
}
fn poll_close(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
self.sender.poll_close_unpin(cx)
}
}
/// Non-blocking send that mirrors the accounting done by `Sink::start_send`.
impl SinkTrySend<UnparsedRelayMsg> for StreamQueueSender {
    type Error =
        <mq_queue::Sender<UnparsedRelayMsg, MpscSpec> as SinkTrySend<UnparsedRelayMsg>>::Error;

    /// Try to enqueue `item` on the wrapped sender.
    ///
    /// On success the message's `data_len` is added to the shared byte
    /// counter. On failure the wrapped sender's error and the message are
    /// returned to the caller and the counter is left untouched.
    ///
    /// # Panics
    ///
    /// Panics if the counter's mutex is poisoned or if the counter would
    /// overflow `usize`.
    fn try_send_or_return(
        self: Pin<&mut Self>,
        item: UnparsedRelayMsg,
    ) -> Result<
        (),
        (
            <Self as SinkTrySend<UnparsedRelayMsg>>::Error,
            UnparsedRelayMsg,
        ),
    > {
        let this = self.project();

        // Measure before the message is moved into the queue.
        let added = usize::from(data_len(&item));

        // Held across the send so the counter update cannot interleave with
        // the receiving half, which takes the same lock around its dequeue.
        let mut queued = this.counter.lock().expect("poisoned");

        this.sender.try_send_or_return(item)?;

        let updated = queued
            .checked_add(added)
            .expect("queue has more than `usize::MAX` bytes?!");
        *queued = updated;

        Ok(())
    }
}
/// Yields messages from the wrapped receiver, subtracting each message's
/// `data_len` from the shared byte counter as it leaves the queue.
impl Stream for StreamQueueReceiver {
    type Item = UnparsedRelayMsg;

    /// Poll the wrapped receiver for the next message.
    ///
    /// `Pending` and end-of-stream are passed through unchanged and leave the
    /// counter untouched.
    ///
    /// # Panics
    ///
    /// Panics if the counter's mutex is poisoned or if more bytes would be
    /// subtracted than the counter currently holds.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();

        // Taken before polling and held until the counter has been adjusted;
        // the sending half takes the same lock around its enqueue.
        let mut queued = this.counter.lock().expect("poisoned");

        // `ready!` returns `Poll::Pending` from this function when the
        // wrapped receiver is not ready.
        let Some(msg) = std::task::ready!(this.receiver.poll_next(cx)) else {
            return Poll::Ready(None);
        };

        let removed = usize::from(data_len(&msg));
        if removed > 0 {
            let remaining = queued
                .checked_sub(removed)
                .expect("we've removed more bytes than we've added?!");
            *queued = remaining;
        }

        Poll::Ready(Some(msg))
    }
}
/// Peeking is delegated to the wrapped `StreamUnobtrusivePeeker`; the byte
/// counter is not touched.
impl UnobtrusivePeekableStream for StreamQueueReceiver {
    /// Return whatever the wrapped receiver's `unobtrusive_peek_mut` returns.
    fn unobtrusive_peek_mut<'s>(
        self: Pin<&'s mut Self>,
    ) -> Option<&'s mut <Self as futures::Stream>::Item> {
        let this = self.project();
        this.receiver.unobtrusive_peek_mut()
    }
}
/// Number of bytes a message contributes to the queue's byte counter.
///
/// Returns the value reported by `UnparsedRelayMsg::data_len`, or 0 when
/// that call does not produce a length. Every place that adjusts the counter
/// goes through this function so additions and subtractions agree.
fn data_len(item: &UnparsedRelayMsg) -> u16 {
    // `u16::default()` is 0, matching the fallback used for the counter.
    item.data_len().unwrap_or_default()
}