#![cfg_attr(not(feature = "circ-padding"), expect(dead_code))]
mod boolean_policy;
mod counting_policy;
pub(crate) use boolean_policy::BooleanPolicy;
pub(crate) use counting_policy::CountingPolicy;
use std::{
pin::Pin,
task::{Context, Poll, Waker},
};
use futures::Sink;
use pin_project::pin_project;
use tor_error::Bug;
/// A wrapper around a sink `S` that can refuse to accept new items,
/// as decided by a policy object `P`.
///
/// While the policy reports that it is blocking, `poll_ready` returns
/// `Poll::Pending` without consulting the inner sink, and remembers the
/// caller's waker.  When the policy is later replaced (via `update_policy`)
/// with one that is not blocking, that waker is woken.
///
/// Flushing and closing are passed straight through to the inner sink;
/// the policy is not consulted for them.
#[pin_project]
pub(crate) struct SinkBlocker<S, P = BooleanPolicy> {
/// The wrapped sink.  Pinned structurally, so that its `Sink` methods
/// can be invoked through a pinned reference to the `SinkBlocker`.
#[pin]
inner: S,
/// The policy deciding whether new items may currently be accepted.
policy: P,
/// The waker from the most recent `poll_ready` call that found the policy
/// blocking, if it has not since been consumed by `update_policy`.
waker: Option<Waker>,
}
/// A rule that tells a `SinkBlocker` whether it may currently accept items.
pub(crate) trait Policy {
/// Return true if the sink should currently refuse new items.
///
/// `SinkBlocker::poll_ready` returns `Poll::Pending` whenever this
/// returns true.
fn is_blocking(&self) -> bool;
/// Record that one item has been handed to the inner sink.
///
/// `SinkBlocker::start_send` calls this once per item, after the inner
/// sink's `start_send` has succeeded.  An `Err` here is treated as a
/// programming error: `start_send` panics on it.
fn take_one(&mut self) -> Result<(), Bug>;
}
impl<S, P> SinkBlocker<S, P> {
pub(crate) fn new(inner: S, policy: P) -> Self {
SinkBlocker {
inner,
policy,
waker: None,
}
}
pub(crate) fn as_inner(&self) -> &S {
&self.inner
}
pub(crate) fn as_inner_mut(&mut self) -> &mut S {
&mut self.inner
}
}
impl<S, P: Policy> SinkBlocker<S, P> {
    /// Replace the current policy with `new_policy`.
    ///
    /// If the old policy was blocking and the new one is not, any waker
    /// stored by an earlier `poll_ready` is taken and woken, so that the
    /// pending sender gets a chance to poll again.  In every other case the
    /// stored waker (if any) is left in place.
    pub(crate) fn update_policy(&mut self, new_policy: P) {
        // Query both policies (old first, then new) before installing the new one.
        let transition = (self.policy.is_blocking(), new_policy.is_blocking());
        self.policy = new_policy;

        // Only a blocked -> unblocked transition needs a wakeup.
        if !matches!(transition, (true, false)) {
            return;
        }
        if let Some(pending) = self.waker.take() {
            pending.wake();
        }
    }
}
impl<T, S: Sink<T>, P: Policy> Sink<T> for SinkBlocker<S, P> {
    type Error = S::Error;

    /// Report readiness: `Pending` while the policy is blocking,
    /// otherwise whatever the inner sink reports.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.project();
        if !this.policy.is_blocking() {
            // Not blocked: readiness is entirely up to the inner sink.
            return this.inner.poll_ready(cx);
        }
        // Blocked: remember who to wake once `update_policy` unblocks us.
        *this.waker = Some(cx.waker().clone());
        Poll::Pending
    }

    /// Hand `item` to the inner sink and, if that succeeds, tell the policy
    /// that one item was consumed.
    ///
    /// # Panics
    ///
    /// Panics if the policy's `take_one` returns an error.
    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        let this = self.project();
        // An error from the inner sink is returned before the policy is touched.
        this.inner.start_send(item)?;
        this.policy.take_one().expect(
            "take_one failed after is_blocking returned false: bug in Policy or SinkBlocker",
        );
        Ok(())
    }

    /// Flush the inner sink; the policy is not consulted.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let inner = self.project().inner;
        inner.poll_flush(cx)
    }

    /// Close the inner sink; the policy is not consulted.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let inner = self.project().inner;
        inner.poll_close(cx)
    }
}
// Unit tests for `SinkBlocker`.
#[cfg(test)]
mod test {
// Lints relaxed for test code.
#![allow(clippy::bool_assert_comparison)]
#![allow(clippy::clone_on_copy)]
#![allow(clippy::dbg_macro)]
#![allow(clippy::mixed_attributes_style)]
#![allow(clippy::print_stderr)]
#![allow(clippy::print_stdout)]
#![allow(clippy::single_char_pattern)]
#![allow(clippy::unwrap_used)]
#![allow(clippy::unchecked_time_subtraction)]
#![allow(clippy::useless_vec)]
#![allow(clippy::needless_pass_by_value)]
use std::sync::{
Arc,
atomic::{AtomicBool, AtomicUsize, Ordering},
};
use super::*;
use futures::{SinkExt as _, StreamExt as _, channel::mpsc, poll};
use tor_rtmock::MockRuntime;
// Check that a blocked `SinkBlocker` holds items back, and that unblocking
// it lets them through.
//
// A transmitter task sends 1 and 2 while unblocked, then blocks the sink,
// queues 3 and 4, verifies that a flush cannot complete, unblocks, and
// finally flushes and closes.  A receiver task checks that values arrive in
// order and that 3 and 4 are never observed while the `blocked` flag is set.
#[test]
fn block_and_unblock() {
MockRuntime::test_with_various(|runtime| async move {
let (tx, mut rx) = mpsc::channel::<u32>(1);
let tx = SinkBlocker::new(tx, BooleanPolicy::Unblocked);
// Wrap the blocker in a `SinkExt::buffer` adapter so that `feed` can
// accept items even while the blocker itself is refusing them.
let mut tx = tx.buffer(5);
// `blocked` mirrors the state the transmitter has put the sink in;
// `n_received` counts the items the receiver has seen so far.
let blocked = Arc::new(AtomicBool::new(false));
let n_received = Arc::new(AtomicUsize::new(0));
let blocked_clone = Arc::clone(&blocked);
let n_received_clone = Arc::clone(&n_received);
let n_received_clone2 = Arc::clone(&n_received);
runtime.spawn_identified("Transmitter", async move {
// Unblocked: these two sends complete normally.
tx.send(1).await.unwrap();
tx.send(2).await.unwrap();
// Set the flag before blocking, so the receiver can never see the
// sink blocked while the flag is still false.
blocked.store(true, Ordering::SeqCst);
// NOTE(review): `set_blocked` / `set_unblocked` are not defined in this
// file; presumably they come from the `boolean_policy` module -- confirm.
tx.get_mut().set_blocked();
tx.feed(3).await.unwrap();
tx.feed(4).await.unwrap();
// Items 3 and 4 must not have reached the receiver yet.
assert!(dbg!(n_received.load(Ordering::SeqCst)) <= 2);
// While blocked, a flush must not be able to complete.
let flush_future = tx.flush();
assert!(poll!(flush_future).is_pending());
// Clear the flag before unblocking, for the same reason as above.
blocked.store(false, Ordering::SeqCst);
tx.get_mut().set_unblocked();
tx.flush().await.unwrap();
tx.close().await.unwrap();
});
runtime.spawn_identified("Receiver", async move {
let n_received = n_received_clone;
let blocked = blocked_clone;
// Values must arrive as 1, 2, 3, ... in order.
let mut expected = 1;
while let Some(val) = rx.next().await {
assert_eq!(val, expected);
expected += 1;
n_received.fetch_add(1, Ordering::SeqCst);
// Items 3 and later were queued while blocked, so they may only
// show up once the transmitter has cleared the flag.
if val >= 3 {
assert_eq!(blocked.load(Ordering::SeqCst), false);
}
}
dbg!(expected);
});
// Run both tasks until neither can make progress, then check that all
// four items were delivered.
runtime.progress_until_stalled().await;
assert_eq!(dbg!(n_received_clone2.load(Ordering::SeqCst)), 4);
});
}
}