#![cfg_attr(
feature = "memquota",
doc = "let config = tor_memquota::Config::builder().max(1024*1024*1024).build().unwrap();",
doc = "let trk = MemoryQuotaTracker::new(&runtime, config).unwrap();"
)]
#![cfg_attr(
not(feature = "memquota"),
doc = "let trk = MemoryQuotaTracker::new_noop();"
)]
#![forbid(unsafe_code)]
use tor_async_utils::peekable_stream::UnobtrusivePeekableStream;
use crate::internal_prelude::*;
use std::task::{Context, Poll, Poll::*};
use tor_async_utils::{ErasedSinkTrySendError, SinkCloseChannel, SinkTrySend};
/// Sending half of a queue created by [`ChannelSpec::new_mq`].
///
/// Wraps the underlying channel's sender.  Every item handed to this sender is
/// wrapped in an `Entry` (item plus timestamp) and claimed through `mq`
/// before it is pushed into the underlying channel.
#[derive(Educe)]
#[educe(Debug, Clone(bound = "C::Sender<Entry<T>>: Clone"))]
pub struct Sender<T: Debug + Send + 'static, C: ChannelSpec> {
/// The underlying channel sender; it carries `Entry<T>` rather than bare `T`.
tx: C::Sender<Entry<T>>,
/// Participation handle through which each outgoing entry is claimed.
mq: TypedParticipation<Entry<T>>,
/// Time source used to stamp each entry's `when` field at send time.
/// Skipped in the `Debug` output.
#[educe(Debug(ignore))] runtime: DynTimeProvider,
}
/// Receiving half of a queue created by [`ChannelSpec::new_mq`].
///
/// Implements `Stream<Item = T>`; the stream yields `None` once the queue has
/// been collapsed by a reclaim.
#[derive(Educe)] #[educe(Debug)]
pub struct Receiver<T: Debug + Send + 'static, C: ChannelSpec> {
/// Shared state.  The same `Arc` is produced inside the
/// `register_participant_with` callback in `new_mq`, and `ReceiverInner`
/// implements `IsParticipant`.
inner: Arc<ReceiverInner<T, C>>,
}
/// Shared, lockable state behind a [`Receiver`].
///
/// This is the type that implements `IsParticipant` for the queue.
#[derive(Educe)]
#[educe(Debug)]
struct ReceiverInner<T: Debug + Send + 'static, C: ChannelSpec> {
/// `Ok` while the queue is live; replaced by `Err(CollapsedDueToReclaim)`
/// when `reclaim` runs.  Access goes through `ReceiverInner::lock`.
state: Mutex<Result<ReceiverState<T, C>, CollapsedDueToReclaim>>,
}
/// Live state of a queue's receiving side.
///
/// Has a `Drop` impl which destroys the participant, runs the collapse
/// callbacks and drains the underlying channel.
#[derive(Educe)]
#[educe(Debug)]
struct ReceiverState<T: Debug + Send + 'static, C: ChannelSpec> {
/// The underlying channel receiver, wrapped so that the oldest queued entry
/// can be inspected (`unobtrusive_peek`) without being removed.
rx: StreamUnobtrusivePeeker<C::Receiver<Entry<T>>>,
/// Participation handle through which the cost of each received entry is released.
mq: TypedParticipation<Entry<T>>,
/// Callbacks to invoke when the queue collapses.
/// `Debug` prints only how many callbacks are registered.
#[educe(Debug(method = "receiver_state_debug_collapse_notify"))]
collapse_callbacks: Vec<CollapseCallback>,
}
/// What actually travels through the underlying channel: the user's item
/// together with the time at which it was sent.
#[derive(Debug)]
struct Entry<T> {
/// The item supplied by the caller of the `Sender`.
t: T,
/// Coarse timestamp taken from the sender's time provider when the item was sent;
/// reported by `get_oldest` for the entry at the head of the queue.
when: CoarseInstant,
}
/// Error returned by the `Sink` implementation of [`Sender`].
///
/// `CE` is the error type of the underlying channel's sender.
#[derive(Error, Clone, Debug)]
#[non_exhaustive]
pub enum SendError<CE> {
/// The underlying channel reported an error (from `poll_ready`, `start_send`,
/// `poll_flush` or `poll_close`).
#[error("channel send failed")]
Channel(#[source] CE),
/// Claiming memory for the item failed.
#[error("memory quota exhausted, queue reclaimed")]
Memquota(#[from] Error),
}
/// Callback invoked (at most once) when a queue collapses; it is passed the
/// reason for the collapse.  Registered with `Receiver::register_collapse_hook`.
pub type CollapseCallback = Box<dyn FnOnce(CollapseReason) + Send + Sync + 'static>;
/// Why a queue collapsed; passed to each [`CollapseCallback`].
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
#[non_exhaustive]
pub enum CollapseReason {
/// The receiver state was dropped while callbacks were still registered.
ReceiverDropped,
/// The queue was torn down by a memory reclaim.
MemoryReclaimed,
}
/// Marker stored in `ReceiverInner::state` (as the `Err` value) once `reclaim`
/// has torn the queue down.  Converts into `CollapseReason::MemoryReclaimed`.
#[derive(Debug, Clone, Copy)]
struct CollapsedDueToReclaim;
/// Description of an underlying channel that can be wrapped into a
/// memory-tracked queue by [`ChannelSpec::new_mq`].
///
/// Implementors say which raw sender/receiver types to use and how to create
/// and close them.  The trait has a `Sealed` supertrait (presumably so that it
/// cannot be implemented outside this crate -- confirm against `Sealed`'s
/// definition).
pub trait ChannelSpec: Sealed + Sized + 'static {
    /// The raw sending half, for items of type `T`.
    type Sender<T: Debug + Send + 'static>: Sink<T, Error = Self::SendError>
        + Debug
        + Unpin
        + Sized;

    /// The raw receiving half, for items of type `T`.
    type Receiver<T: Debug + Send + 'static>: Stream<Item = T> + Debug + Unpin + Send + Sized;

    /// The error type produced by the raw sender.
    type SendError: std::error::Error;

    /// Create a new queue of this kind, registered as a participant in `account`.
    ///
    /// `runtime` supplies the coarse time used both for the registration and,
    /// afterwards, for time-stamping every item that is sent.
    ///
    /// # Errors
    ///
    /// Returns an error if registering the participant with `account` fails.
    #[allow(clippy::type_complexity)] // the return type is just a pair of our own halves
    fn new_mq<T>(
        self,
        runtime: DynTimeProvider,
        account: &Account,
    ) -> crate::Result<(Sender<T, Self>, Receiver<T, Self>)>
    where
        T: HasMemoryCost + Debug + Send + 'static,
    {
        let (rx, (tx, mq)) = account.register_participant_with(
            runtime.now_coarse(),
            move |mq| {
                let mq = TypedParticipation::new(mq);
                let collapse_callbacks = vec![];
                let (tx, rx) = self.raw_channel::<Entry<T>>();
                let rx = StreamUnobtrusivePeeker::new(rx);
                // The receiver state keeps its own clone of the participation;
                // the other copy is handed back for the sender.
                let state = ReceiverState {
                    rx,
                    mq: mq.clone(),
                    collapse_callbacks,
                };
                let state = Mutex::new(Ok(state));
                let inner = ReceiverInner { state };
                Ok::<_, crate::Error>((inner.into(), (tx, mq)))
            },
        )??;

        // `runtime` is owned and not needed afterwards, so it is moved into the
        // sender directly rather than cloned.
        let tx = Sender { runtime, tx, mq };
        let rx = Receiver { inner: rx };
        Ok((tx, rx))
    }

    /// Create the raw (untracked) sender/receiver pair for items of type `T`.
    fn raw_channel<T: Debug + Send + 'static>(self) -> (Self::Sender<T>, Self::Receiver<T>);

    /// Close the raw receiver.
    ///
    /// Called from the receiver state's `Drop` impl just before the remaining
    /// queued items are drained.
    fn close_receiver<T: Debug + Send + 'static>(rx: &mut Self::Receiver<T>);
}
/// [`ChannelSpec`] for a bounded MPSC channel.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Constructor)]
#[allow(clippy::exhaustive_structs)] pub struct MpscSpec {
/// Buffer size, passed unchanged to `mpsc_channel_no_memquota` when the raw
/// channel is created.
pub buffer: usize,
}
/// [`ChannelSpec`] for an unbounded MPSC channel (created with `mpsc::unbounded`).
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Constructor, Default)]
#[allow(clippy::exhaustive_structs)] pub struct MpscUnboundedSpec;
// `Sealed` is a supertrait of `ChannelSpec`, so both channel specs must implement it.
impl Sealed for MpscSpec {}
impl Sealed for MpscUnboundedSpec {}
/// Bounded MPSC channels as the raw transport.
impl ChannelSpec for MpscSpec {
    type Sender<T: Debug + Send + 'static> = mpsc::Sender<T>;
    type Receiver<T: Debug + Send + 'static> = mpsc::Receiver<T>;
    type SendError = mpsc::SendError;

    /// Build the raw bounded channel using this spec's `buffer` size.
    fn raw_channel<T: Debug + Send + 'static>(self) -> (Self::Sender<T>, Self::Receiver<T>) {
        let Self { buffer } = self;
        mpsc_channel_no_memquota(buffer)
    }

    /// Close the raw receiver.
    fn close_receiver<T: Debug + Send + 'static>(rx: &mut Self::Receiver<T>) {
        mpsc::Receiver::close(rx);
    }
}
/// Unbounded MPSC channels as the raw transport.
impl ChannelSpec for MpscUnboundedSpec {
    type Sender<T: Debug + Send + 'static> = mpsc::UnboundedSender<T>;
    type Receiver<T: Debug + Send + 'static> = mpsc::UnboundedReceiver<T>;
    type SendError = mpsc::SendError;

    /// Build the raw unbounded channel.
    fn raw_channel<T: Debug + Send + 'static>(self) -> (Self::Sender<T>, Self::Receiver<T>) {
        mpsc::unbounded::<T>()
    }

    /// Close the raw receiver.
    fn close_receiver<T: Debug + Send + 'static>(rx: &mut Self::Receiver<T>) {
        mpsc::UnboundedReceiver::close(rx);
    }
}
impl<T, C> Sink<T> for Sender<T, C>
where
T: HasMemoryCost + Debug + Send + 'static,
C: ChannelSpec,
{
type Error = SendError<C::SendError>;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.get_mut()
.tx
.poll_ready_unpin(cx)
.map_err(SendError::Channel)
}
fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
let self_ = self.get_mut();
let item = Entry {
t: item,
when: self_.runtime.now_coarse(),
};
self_.mq.try_claim(item, |item| {
self_.tx.start_send_unpin(item).map_err(SendError::Channel)
})?
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tx
.poll_flush_unpin(cx)
.map(|r| r.map_err(SendError::Channel))
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tx
.poll_close_unpin(cx)
.map(|r| r.map_err(SendError::Channel))
}
}
impl<T, C> SinkTrySend<T> for Sender<T, C>
where
T: HasMemoryCost + Debug + Send + 'static,
C: ChannelSpec,
C::Sender<Entry<T>>: SinkTrySend<Entry<T>>,
<C::Sender<Entry<T>> as SinkTrySend<Entry<T>>>::Error: Send + Sync,
{
type Error = ErasedSinkTrySendError;
fn try_send_or_return(
self: Pin<&mut Self>,
item: T,
) -> Result<(), (<Self as SinkTrySend<T>>::Error, T)> {
let self_ = self.get_mut();
let item = Entry {
t: item,
when: self_.runtime.now_coarse(),
};
use ErasedSinkTrySendError as ESTSE;
self_
.mq
.try_claim_or_return(item, |item| {
Pin::new(&mut self_.tx).try_send_or_return(item)
})
.map_err(|(mqe, unsent)| (ESTSE::Other(Arc::new(mqe)), unsent.t))?
.map_err(|(tse, unsent)| (ESTSE::from(tse), unsent.t))
}
}
/// Channel closing for the tracked sender, available when the raw sender
/// supports `SinkCloseChannel`.
impl<T, C> SinkCloseChannel<T> for Sender<T, C>
where
    T: HasMemoryCost + Debug + Send,
    C: ChannelSpec,
    C::Sender<Entry<T>>: SinkCloseChannel<Entry<T>>,
{
    /// Forward the close request to the raw sender.
    fn close_channel(self: Pin<&mut Self>) {
        let this = self.get_mut();
        let raw = Pin::new(&mut this.tx);
        raw.close_channel();
    }
}
impl<T, C> Sender<T, C>
where
    T: Debug + Send + 'static,
    C: ChannelSpec,
{
    /// Borrow the time provider this sender uses to timestamp outgoing items.
    pub fn time_provider(&self) -> &DynTimeProvider {
        let Self { runtime, .. } = self;
        runtime
    }
}
/// `Stream` for the tracked receiver: yields the items (without their
/// timestamps) and releases each item's memory cost as it is handed out.
impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> Stream for Receiver<T, C> {
    type Item = T;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut guard = self.inner.lock();

        // A collapsed queue behaves like an ended stream.
        let Ok(state) = &mut *guard else {
            return Ready(None);
        };

        match state.rx.poll_next_unpin(cx) {
            Ready(Some(entry)) => {
                // The entry is leaving the queue, so give back what was claimed for it.
                if let Some(enabled) = EnabledToken::new_if_compiled_in() {
                    let cost = entry.typed_memory_cost(enabled);
                    state.mq.release(&cost);
                }
                Ready(Some(entry.t))
            }
            Ready(None) => Ready(None),
            Pending => Pending,
        }
    }
}
/// `FusedStream` for the tracked receiver, available when the raw receiver is fused.
impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> FusedStream for Receiver<T, C>
where
    C::Receiver<Entry<T>>: FusedStream,
{
    /// A collapsed queue is always terminated; otherwise ask the raw receiver.
    fn is_terminated(&self) -> bool {
        let guard = self.inner.lock();
        if let Ok(state) = &*guard {
            state.rx.is_terminated()
        } else {
            true
        }
    }
}
impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> Receiver<T, C> {
pub fn register_collapse_hook(&self, call: CollapseCallback) {
let mut state = self.inner.lock();
let state = match &mut *state {
Ok(y) => y,
Err(reason) => {
let reason = (*reason).into();
drop::<MutexGuard<_>>(state);
call(reason);
return;
}
};
state.collapse_callbacks.push(call);
}
}
impl<T: Debug + Send + 'static, C: ChannelSpec> ReceiverInner<T, C> {
    /// Lock the shared state.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    fn lock(&self) -> MutexGuard<'_, Result<ReceiverState<T, C>, CollapsedDueToReclaim>> {
        let locked = self.state.lock();
        locked.expect("mq_mpsc lock poisoned")
    }
}
/// The queue as a memory-quota participant.
impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> IsParticipant
    for ReceiverInner<T, C>
{
    /// Timestamp of the entry at the head of the queue, if there is one and
    /// the queue has not collapsed.
    fn get_oldest(&self, _: EnabledToken) -> Option<CoarseInstant> {
        let mut guard = self.lock();
        let Ok(state) = &mut *guard else {
            return None;
        };
        let head = Pin::new(&mut state.rx).unobtrusive_peek();
        head.map(|entry| entry.when)
    }

    /// Collapse the queue: mark the state as collapsed, then (outside the
    /// lock) run the collapse callbacks and drop the previous state.
    fn reclaim(self: Arc<Self>, _: EnabledToken) -> mtracker::ReclaimFuture {
        Box::pin(async move {
            let reason = CollapsedDueToReclaim;

            // Swap the live state out under the lock, then release the lock
            // before running any callbacks.
            let mut guard = self.lock();
            let previous = mem::replace(&mut *guard, Err(reason));
            drop::<MutexGuard<_>>(guard);

            // If it had already collapsed, there is nothing more to do.
            if let Ok(mut live) = previous {
                for callback in live.collapse_callbacks.drain(..) {
                    callback(reason.into());
                }
                drop::<ReceiverState<_, _>>(live);
            }

            mtracker::Reclaimed::Collapsing
        })
    }
}
/// Teardown of the receiving side.  The order of the steps below matters.
impl<T: Debug + Send + 'static, C: ChannelSpec> Drop for ReceiverState<T, C> {
fn drop(&mut self) {
// Swap a dangling participation into `self.mq` so the real one can be taken
// by value and its participant destroyed.
mem::replace(&mut self.mq, Participation::new_dangling().into())
.into_raw()
.destroy_participant();
// Any callbacks still registered here were not consumed by `reclaim`
// (which drains them first), so report the drop as the reason.
for call in self.collapse_callbacks.drain(..) {
call(CollapseReason::ReceiverDropped);
}
// Polling below is only to drain the queue; wakeups are not wanted.
let mut noop_cx = Context::from_waker(Waker::noop());
// Close the raw receiver before draining.
// NOTE(review): presumably this stops senders from adding entries while we
// drain -- confirm against the underlying channel's `close` semantics.
if let Some(mut rx_inner) =
StreamUnobtrusivePeeker::as_raw_inner_pin_mut(Pin::new(&mut self.rx))
{
C::close_receiver(&mut rx_inner);
}
// Drop every entry that is immediately available.
while let Ready(Some(item)) = self.rx.poll_next_unpin(&mut noop_cx) {
drop::<Entry<T>>(item);
}
}
}
/// `Debug` helper for `ReceiverState::collapse_callbacks`.
///
/// The callbacks themselves are not `Debug`, so only their count is printed,
/// honouring whatever formatting flags `f` carries.
fn receiver_state_debug_collapse_notify(
    v: &[CollapseCallback],
    f: &mut fmt::Formatter,
) -> fmt::Result {
    let count: usize = v.len();
    <usize as Debug>::fmt(&count, f)
}
impl<T: HasMemoryCost> HasMemoryCost for Entry<T> {
fn memory_cost(&self, enabled: EnabledToken) -> usize {
let time_size = std::alloc::Layout::new::<CoarseInstant>().size();
self.t.memory_cost(enabled).saturating_add(time_size)
}
}
impl From<CollapsedDueToReclaim> for CollapseReason {
fn from(CollapsedDueToReclaim: CollapsedDueToReclaim) -> CollapseReason {
CollapseReason::MemoryReclaimed
}
}
// Tests for the memory-tracked queue.  Only built with the `memquota` feature
// and not under miri.
#[cfg(all(test, feature = "memquota", not(miri) /* coarsetime */))]
mod test {
#![allow(clippy::bool_assert_comparison)]
#![allow(clippy::clone_on_copy)]
#![allow(clippy::dbg_macro)]
#![allow(clippy::mixed_attributes_style)]
#![allow(clippy::print_stderr)]
#![allow(clippy::print_stdout)]
#![allow(clippy::single_char_pattern)]
#![allow(clippy::unwrap_used)]
#![allow(clippy::unchecked_time_subtraction)]
#![allow(clippy::useless_vec)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::string_slice)] #![allow(clippy::arithmetic_side_effects)]
use super::*;
use crate::mtracker::test::*;
use tor_rtmock::MockRuntime;
use tracing::debug;
use tracing_test::traced_test;
/// Hands out `Item`s and counts how many of them are still alive.
#[derive(Default, Debug)]
struct ItemTracker {
state: Mutex<ItemTrackerState>,
}
/// Counters behind an `ItemTracker`.
#[derive(Default, Debug)]
struct ItemTrackerState {
/// Number of `Item`s created and not yet dropped.
existing: usize,
/// Id to give to the next `Item`.
next_id: usize,
}
/// Test payload; decrements its tracker's `existing` count when dropped.
#[derive(Debug)]
struct Item {
id: usize,
tracker: Arc<ItemTracker>,
}
impl ItemTracker {
/// Create a new item with the next id, counting it as existing.
fn new_item(self: &Arc<Self>) -> Item {
let mut state = self.lock();
let id = state.next_id;
state.existing += 1;
state.next_id += 1;
debug!("new {id}");
Item {
tracker: self.clone(),
id,
}
}
/// Create a fresh tracker with all counters at zero.
fn new_tracker() -> Arc<Self> {
Arc::default()
}
/// Lock the counters; panics if the mutex is poisoned.
fn lock(&self) -> MutexGuard<ItemTrackerState> {
self.state.lock().unwrap()
}
}
impl Drop for Item {
fn drop(&mut self) {
debug!("old {}", self.id);
self.tracker.state.lock().unwrap().existing -= 1;
}
}
// Every `Item` claims to cost `mbytes(1)`.
impl HasMemoryCost for Item {
fn memory_cost(&self, _: EnabledToken) -> usize {
mbytes(1)
}
}
/// Common fixtures shared by the tests.
struct Setup {
dtp: DynTimeProvider,
trk: Arc<mtracker::MemoryQuotaTracker>,
acct: Account,
itrk: Arc<ItemTracker>,
}
/// Build the fixtures on top of the given mock runtime.
fn setup(rt: &MockRuntime) -> Setup {
let dtp = DynTimeProvider::new(rt.clone());
let trk = mk_tracker(rt);
let acct = trk.new_account(None).unwrap();
let itrk = ItemTracker::new_tracker();
Setup {
dtp,
trk,
acct,
itrk,
}
}
/// Payload that claims to cost `mbytes(100)`.
#[derive(Debug)]
struct Gigantic;
impl HasMemoryCost for Gigantic {
fn memory_cost(&self, _et: EnabledToken) -> usize {
mbytes(100)
}
}
impl Setup {
/// Assert that the tracker's approximate usage is at most
/// `n_queues * 2 * MAX_CACHE`, i.e. zero apart from that allowed slop.
fn check_zero_claimed(&self, n_queues: usize) {
let used = self.trk.used_current_approx();
debug!(
"checking zero balance (with slop {n_queues} * 2 * {}; used={used:?}",
*mtracker::MAX_CACHE,
);
assert!(used.unwrap() <= n_queues * 2 * *mtracker::MAX_CACHE);
}
}
// Send and receive one item, then queue 20 more without receiving them.
// After the runtime stalls, every item must have been dropped, the receiver
// must report end-of-stream, and further sends must fail.
// NOTE(review): presumably 20 x 1 MiB exceeds the quota configured by
// `mk_tracker`, triggering a reclaim -- confirm against `mk_tracker`.
#[traced_test]
#[test]
fn lifecycle() {
MockRuntime::test_with_various(|rt| async move {
let s = setup(&rt);
let (mut tx, mut rx) = MpscUnboundedSpec.new_mq(s.dtp.clone(), &s.acct).unwrap();
tx.send(s.itrk.new_item()).await.unwrap();
let _: Item = rx.next().await.unwrap();
for _ in 0..20 {
tx.send(s.itrk.new_item()).await.unwrap();
}
debug!("still existing items {}", s.itrk.lock().existing);
rt.advance_until_stalled().await;
assert!(s.itrk.lock().existing == 0);
assert!(rx.next().await.is_none());
let _: SendError<_> = tx.send(s.itrk.new_item()).await.unwrap_err();
});
}
// Queue COUNT items, receive them all, and check that (apart from the
// allowed slop) nothing remains claimed.
#[traced_test]
#[test]
fn fill_and_empty() {
MockRuntime::test_with_various(|rt| async move {
let s = setup(&rt);
let (mut tx, mut rx) = MpscUnboundedSpec.new_mq(s.dtp.clone(), &s.acct).unwrap();
const COUNT: usize = 19;
for _ in 0..COUNT {
tx.send(s.itrk.new_item()).await.unwrap();
}
rt.advance_until_stalled().await;
for _ in 0..COUNT {
let _: Item = rx.next().await.unwrap();
}
rt.advance_until_stalled().await;
s.check_zero_claimed(1);
});
}
// Check how errors from the underlying sink, and from the memory quota,
// are reported by `send` and `try_send_or_return`.
#[traced_test]
#[test]
fn sink_error() {
/// A sink whose `poll_ready` and `try_send_or_return` always fail with `error`.
#[derive(Debug, Copy, Clone)]
struct BustedSink {
error: BustedError,
}
impl<T> Sink<T> for BustedSink {
type Error = BustedError;
fn poll_ready(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Ready(Err(self.error))
}
fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
panic!("poll_ready always gives error, start_send should not be called");
}
fn poll_flush(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Ready(Ok(()))
}
fn poll_close(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Ready(Ok(()))
}
}
impl<T> SinkTrySend<T> for BustedSink {
type Error = BustedError;
fn try_send_or_return(self: Pin<&mut Self>, item: T) -> Result<(), (BustedError, T)> {
Err((self.error, item))
}
}
impl tor_async_utils::SinkTrySendError for BustedError {
fn is_disconnected(&self) -> bool {
self.is_disconnected
}
fn is_full(&self) -> bool {
false
}
}
/// Error produced by `BustedSink`; `is_disconnected` controls what
/// `SinkTrySendError::is_disconnected` reports.
#[derive(Error, Debug, Clone, Copy)]
#[error("busted, for testing, dc={is_disconnected:?}")]
struct BustedError {
is_disconnected: bool,
}
/// Channel spec pairing a `BustedSink` with a receiver that is always pending.
struct BustedQueueSpec {
error: BustedError,
}
impl Sealed for BustedQueueSpec {}
impl ChannelSpec for BustedQueueSpec {
type Sender<T: Debug + Send + 'static> = BustedSink;
type Receiver<T: Debug + Send + 'static> = futures::stream::Pending<T>;
type SendError = BustedError;
fn raw_channel<T: Debug + Send + 'static>(self) -> (BustedSink, Self::Receiver<T>) {
(BustedSink { error: self.error }, futures::stream::pending())
}
fn close_receiver<T: Debug + Send + 'static>(_rx: &mut Self::Receiver<T>) {}
}
use ErasedSinkTrySendError as ESTSE;
MockRuntime::test_with_various(|rt| async move {
let error = BustedError {
is_disconnected: true,
};
let s = setup(&rt);
let (mut tx, _rx) = BustedQueueSpec { error }
.new_mq(s.dtp.clone(), &s.acct)
.unwrap();
// `send` on the busted sink: the sink's error is wrapped in
// `SendError::Channel`, and the item has been dropped.
let e = tx.send(s.itrk.new_item()).await.unwrap_err();
assert!(matches!(e, SendError::Channel(BustedError { .. })));
assert_eq!(s.itrk.lock().existing, 0);
/// `Ok` iff `e` is `ESTSE::Other` wrapping an error of type `E`.
fn error_is_other_of<E>(e: ESTSE) -> Result<(), impl Debug>
where
E: std::error::Error + 'static,
{
match e {
ESTSE::Other(e) if e.is::<E>() => Ok(()),
other => Err(other),
}
}
// `try_send_or_return` with a "disconnected" sink error: reported as
// `Disconnected`, and the item is handed back.
let item = s.itrk.new_item();
let (e, item) = Pin::new(&mut tx).try_send_or_return(item).unwrap_err();
assert!(matches!(e, ESTSE::Disconnected), "{e:?}");
// The same with a sink error that is neither disconnected nor full:
// reported as `Other` wrapping the `BustedError`.
let error = BustedError {
is_disconnected: false,
};
let (mut tx, _rx) = BustedQueueSpec { error }
.new_mq(s.dtp.clone(), &s.acct)
.unwrap();
let (e, item) = Pin::new(&mut tx).try_send_or_return(item).unwrap_err();
error_is_other_of::<BustedError>(e).unwrap();
s.check_zero_claimed(1);
// Push a `Gigantic` item through a separate queue on the same account.
// NOTE(review): presumably this exceeds the quota and causes a reclaim
// affecting the account's queues -- confirm against `mk_tracker`.
{
let (mut tx, _rx) = MpscUnboundedSpec.new_mq(s.dtp.clone(), &s.acct).unwrap();
tx.send(Gigantic).await.unwrap();
rt.advance_until_stalled().await;
}
// Now the failure on the busted queue comes from the memory quota
// (`crate::Error`) rather than from the sink, and the item is still returned.
let (e, item) = Pin::new(&mut tx).try_send_or_return(item).unwrap_err();
error_is_other_of::<crate::Error>(e).unwrap();
drop::<Item>(item);
});
}
}