use std::{
pin::Pin,
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
task::ready,
task::{Context, Poll},
};
use futures::{Stream, sink::Sink, stream::FusedStream};
use pin_project::pin_project;
/// A [`Sink`] wrapper that counts the items accepted by the wrapped sink.
///
/// The count lives in an atomic counter that [`channel`] shares with the
/// [`CountingStream`] it returns alongside this sink. Cloning the sink clones
/// the `Arc`, so every clone updates the same counter.
#[derive(Clone, Debug)]
#[pin_project]
pub struct CountingSink<S> {
/// The wrapped sink; every `Sink` operation is forwarded to it.
#[pin]
inner: S,
/// Approximate number of items sent but not yet received.
count: Arc<AtomicUsize>,
}
/// A [`Stream`] wrapper that decrements a shared counter for each item it yields.
///
/// The counter is the one [`channel`] shares with the [`CountingSink`] it
/// returns alongside this stream, so it approximates how many items have been
/// sent but not yet received.
#[derive(Clone, Debug)]
#[pin_project]
pub struct CountingStream<S> {
/// The wrapped stream; items are pulled from it.
#[pin]
inner: S,
/// Approximate number of items sent but not yet received.
count: Arc<AtomicUsize>,
}
/// Forwards every [`Sink`] operation to the wrapped sink, counting each item
/// that the wrapped sink accepts in `start_send`.
impl<T, S: Sink<T>> Sink<T> for CountingSink<S> {
    type Error = S::Error;

    /// Delegates readiness to the wrapped sink.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_ready(cx)
    }

    /// Hands `item` to the wrapped sink and counts it if the sink accepts it.
    ///
    /// # Errors
    ///
    /// Returns whatever error the wrapped sink reports; in that case the
    /// counter is left unchanged once this call returns.
    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        let self_ = self.project();
        // Count the item *before* handing it over. If we incremented only
        // after a successful send, a receiver running on another thread could
        // take the item and decrement first, wrapping the unsigned counter
        // around to `usize::MAX` until this increment finally landed.
        self_.count.fetch_add(1, Ordering::Relaxed);
        let r = self_.inner.start_send(item);
        if r.is_err() {
            // The item was rejected, so undo the optimistic increment.
            self_.count.fetch_sub(1, Ordering::Relaxed);
        }
        r
    }

    /// Delegates flushing to the wrapped sink.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_flush(cx)
    }

    /// Delegates closing to the wrapped sink.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_close(cx)
    }
}
/// Yields the wrapped stream's items, decrementing the shared counter once
/// per item produced.
impl<S: Stream> Stream for CountingStream<S> {
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        // `ready!` returns `Poll::Pending` from this function if the wrapped
        // stream has nothing for us yet.
        match ready!(this.inner.poll_next(cx)) {
            Some(item) => {
                // One item has left the queue.
                this.count.fetch_sub(1, Ordering::Relaxed);
                Poll::Ready(Some(item))
            }
            // End of stream: nothing was received, so the count is untouched.
            None => Poll::Ready(None),
        }
    }
}
impl<S: FusedStream> FusedStream for CountingStream<S> {
fn is_terminated(&self) -> bool {
self.inner.is_terminated()
}
}
impl<S> CountingStream<S> {
pub fn approx_count(&self) -> usize {
self.count.load(Ordering::Relaxed)
}
pub fn inner(&self) -> &S {
&self.inner
}
pub fn inner_mut(&mut self) -> &mut S {
&mut self.inner
}
}
impl<S> CountingSink<S> {
pub fn approx_count(&self) -> usize {
self.count.load(Ordering::Relaxed)
}
pub fn inner(&self) -> &S {
&self.inner
}
pub fn inner_mut(&mut self) -> &mut S {
&mut self.inner
}
}
/// Wrap the two halves of an existing channel so that they share one item counter.
///
/// `tx` becomes the inner value of the returned [`CountingSink`] and `rx` the
/// inner value of the returned [`CountingStream`]. Both wrappers hold the same
/// atomic counter, which starts at zero.
pub fn channel<T, U>(tx: T, rx: U) -> (CountingSink<T>, CountingStream<U>) {
    let shared = Arc::new(AtomicUsize::new(0));
    let stream_half = CountingStream {
        inner: rx,
        count: Arc::clone(&shared),
    };
    let sink_half = CountingSink {
        inner: tx,
        count: shared,
    };
    (sink_half, stream_half)
}
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use futures::{SinkExt as _, StreamExt as _};

    /// One task sends and nobody receives: both halves report the running total.
    #[test]
    fn send_only_onetask() {
        tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
            let (raw_tx, raw_rx) = futures::channel::mpsc::unbounded::<usize>();
            let (mut sink, stream) = super::channel(raw_tx, raw_rx);
            for sent in 1..10 {
                sink.send(sent).await.unwrap();
                assert_eq!(sink.approx_count(), sent);
                assert_eq!(stream.approx_count(), sent);
            }
        });
    }

    /// Two tasks send through cloned sinks: the stream sees the combined total.
    #[test]
    fn send_only_twotasks() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            let (raw_tx, raw_rx) = futures::channel::mpsc::unbounded::<usize>();
            let (mut sink_a, stream) = super::channel(raw_tx, raw_rx);
            let mut sink_b = sink_a.clone();

            let first = rt.spawn_join("thread1", async move {
                for sent in 1..=10 {
                    sink_a.send(sent).await.unwrap();
                    // The other task may have sent too, so this is only a lower bound.
                    assert!(sink_a.approx_count() >= sent);
                }
            });
            let second = rt.spawn_join("thread2", async move {
                for sent in 1..=10 {
                    sink_b.send(sent).await.unwrap();
                    assert!(sink_b.approx_count() >= sent);
                }
            });

            first.await;
            second.await;
            assert_eq!(stream.approx_count(), 20);
        });
    }

    /// Two senders and one receiver run concurrently: the count ends at zero.
    #[test]
    fn send_and_receive() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            const MAX: usize = 10000;

            let (raw_tx, raw_rx) = futures::channel::mpsc::unbounded::<usize>();
            let (mut sink_a, mut stream) = super::channel(raw_tx, raw_rx);
            let mut sink_b = sink_a.clone();

            let first = rt.spawn_join("thread1", async move {
                for value in 1..=MAX {
                    sink_a.send(value).await.unwrap();
                }
            });
            let second = rt.spawn_join("thread2", async move {
                for value in 1..=MAX {
                    sink_b.send(value).await.unwrap();
                }
            });
            let receiver = rt.spawn_join("receiver", async move {
                let mut sum = 0;
                while let Some(value) = stream.next().await {
                    sum += value;
                    let queued = stream.approx_count();
                    assert!(queued <= MAX * 2);
                }
                // Each of the two senders emits 1..=MAX, so the grand total is
                // 2 * (MAX * (MAX + 1) / 2).
                assert_eq!(sum, MAX * (MAX + 1));
                stream
            });

            first.await;
            second.await;
            let stream = receiver.await;
            assert_eq!(stream.approx_count(), 0);
        });
    }
}