use educe::Educe;
use futures::Stream;
use futures::stream::FusedStream;
use pin_project::pin_project;
use crate::peekable_stream::{PeekableStream, UnobtrusivePeekableStream};
use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll, Poll::*, Waker};
/// Stream wrapper that lets callers look at the next item before consuming it.
///
/// Besides the polling peeks (`poll_peek`, `poll_peek_mut`), it supports an
/// "unobtrusive" peek (`unobtrusive_peek_mut`), which takes no `Context` and
/// never returns `Pending`.
///
/// An item that has been peeked at is held in `buffered` until `poll_next`
/// returns it.
#[derive(Debug)]
#[pin_project(project = PeekerProj)]
pub struct StreamUnobtrusivePeeker<S: Stream> {
    /// An item already taken from `inner` but not yet returned by `poll_next`.
    buffered: Option<S::Item>,
    /// Waker of the most recent poll (`poll_next` / `poll_peek` / `poll_peek_mut`)
    /// that returned `Pending`.
    ///
    /// Cleared when such a poll returns `Ready`. Also taken and woken when
    /// `unobtrusive_peek_mut` obtains a `Ready` result from `inner`.
    poll_waker: Option<Waker>,
    /// The underlying stream; set to `None` once it has returned end-of-stream.
    #[pin]
    inner: Option<S>,
}
impl<S: Stream> StreamUnobtrusivePeeker<S> {
    /// Wrap `inner`.
    ///
    /// The new peeker starts with nothing buffered and no recorded waker.
    pub fn new(inner: S) -> Self {
        let inner = Some(inner);
        Self {
            inner,
            buffered: None,
            poll_waker: None,
        }
    }
}
impl<S: Stream> UnobtrusivePeekableStream for StreamUnobtrusivePeeker<S> {
    /// Peek at the next item without a `Context`, and without ever returning `Pending`.
    ///
    /// If nothing is buffered yet, the inner stream is polled once. The poll uses the
    /// waker of whichever task most recently got `Pending` from us, or a no-op waker
    /// if there is no such task.
    ///
    /// Returns `None` if no item is available right now, or if the stream has ended.
    fn unobtrusive_peek_mut<'s>(mut self: Pin<&'s mut Self>) -> Option<&'s mut S::Item> {
        if self.as_mut().project().buffered.is_none() {
            let mut proj = self.as_mut().project();

            // Inner stream already finished: there is nothing to peek at.
            let inner = proj.inner.as_mut().as_pin_mut()?;

            // Poll with the stored waker (if any). If the inner stream is `Pending`,
            // it will then wake the task that is really waiting on us.
            let waker: &Waker = proj.poll_waker.as_ref().unwrap_or(Waker::noop());

            if let Ready(item_or_eof) = inner.poll_next(&mut Context::from_waker(waker)) {
                // We took the readiness that the waiting task was waiting for.
                // Wake that task so it re-polls us and sees the buffered item (or EOF).
                if let Some(waiting) = proj.poll_waker.take() {
                    waiting.wake();
                }
                match item_or_eof {
                    Some(item) => *proj.buffered = Some(item),
                    None => proj.inner.set(None),
                }
            }
        }
        self.project().buffered.as_mut()
    }
}
impl<S: Stream> PeekableStream for StreamUnobtrusivePeeker<S> {
    /// Poll for the next item, yielding a shared reference to it and leaving it buffered.
    fn poll_peek<'s>(self: Pin<&'s mut Self>, cx: &mut Context<'_>) -> Poll<Option<&'s S::Item>> {
        self.impl_poll_next_or_peek(cx, |slot| slot.as_ref())
    }

    /// Poll for the next item, yielding a mutable reference to it and leaving it buffered.
    fn poll_peek_mut<'s>(
        self: Pin<&'s mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<&'s mut S::Item>> {
        self.impl_poll_next_or_peek(cx, |slot| slot.as_mut())
    }
}
impl<S: Stream> StreamUnobtrusivePeeker<S> {
fn impl_poll_next_or_peek<'s, R: 's>(
self: Pin<&'s mut Self>,
cx: &mut Context<'_>,
return_value_obtainer: impl FnOnce(&'s mut Option<S::Item>) -> Option<R>,
) -> Poll<Option<R>> {
let mut self_ = self.project();
let r = Self::next_or_peek_inner(&mut self_, cx);
let r = r.map(|()| return_value_obtainer(self_.buffered));
Self::return_from_poll(self_.poll_waker, cx, r)
}
fn next_or_peek_inner(self_: &mut PeekerProj<S>, cx: &mut Context<'_>) -> Poll<()> {
if let Some(_item) = self_.buffered.as_ref() {
return Ready(());
}
let Some(inner) = self_.inner.as_mut().as_pin_mut() else {
return Ready(());
};
match inner.poll_next(cx) {
Ready(None) => {
self_.inner.set(None);
Ready(())
}
Ready(Some(item)) => {
*self_.buffered = Some(item);
Ready(())
}
Pending => {
Pending
}
}
}
#[allow(dead_code)] pub fn peek(self: Pin<&mut Self>) -> PeekFuture<Self> {
PeekFuture { peeker: Some(self) }
}
fn return_from_poll<R>(
poll_waker: &mut Option<Waker>,
cx: &mut Context<'_>,
r: Poll<R>,
) -> Poll<R> {
*poll_waker = match &r {
Ready(_) => {
None
}
Pending => {
Some(cx.waker().clone())
}
};
r
}
pub fn as_raw_inner_pin_mut<'s>(self: Pin<&'s mut Self>) -> Option<Pin<&'s mut S>> {
self.project().inner.as_pin_mut()
}
}
impl<S: Stream> Stream for StreamUnobtrusivePeeker<S> {
    type Item = S::Item;

    /// Yield the next item, taking it from the buffer if it was already peeked at.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.impl_poll_next_or_peek(cx, |buffered| buffered.take())
    }

    /// The inner stream's size hint, plus one if an item is buffered.
    ///
    /// The lower bound saturates rather than overflowing, as `std::iter::Peekable`
    /// does. A plain `+` would panic (in debug builds) if the inner stream reports a
    /// lower bound of `usize::MAX` while we also hold a buffered item.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let buffered = usize::from(self.buffered.is_some());
        let (inner_min, inner_max) = match &self.inner {
            Some(inner) => inner.size_hint(),
            // The inner stream has ended: it will yield nothing more.
            None => (0, Some(0)),
        };
        (
            inner_min.saturating_add(buffered),
            inner_max.and_then(|max| max.checked_add(buffered)),
        )
    }
}
impl<S: Stream> FusedStream for StreamUnobtrusivePeeker<S> {
    /// True once the inner stream has ended and no peeked item remains buffered.
    fn is_terminated(&self) -> bool {
        matches!((&self.buffered, &self.inner), (None, None))
    }
}
/// Future that resolves to a reference to the next item of a [`PeekableStream`],
/// without consuming it.
///
/// It resolves to whatever `poll_peek` yields once that is `Ready`.
#[derive(Educe)]
#[educe(Debug(bound("S: Debug")))]
#[must_use = "peek() returns a Future, which does nothing unless awaited"]
pub struct PeekFuture<'s, S> {
    /// The stream being peeked at; `None` once this future has returned `Ready`.
    peeker: Option<Pin<&'s mut S>>,
}
impl<'s, S: PeekableStream> PeekFuture<'s, S> {
pub fn new(stream: Pin<&'s mut S>) -> Self {
Self {
peeker: Some(stream),
}
}
}
impl<'s, S: PeekableStream> Future for PeekFuture<'s, S> {
    type Output = Option<&'s S::Item>;

    /// Drive the stream's `poll_peek` until it is ready.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned `Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<&'s S::Item>> {
        let this = self.get_mut();

        // Phase 1: poll through a short reborrow, only to learn whether we are ready.
        // A reference obtained this way could not be returned with lifetime `'s`.
        let first_ready = this
            .peeker
            .as_mut()
            .expect("PeekFuture polled after Ready")
            .as_mut()
            .poll_peek(cx)
            .is_ready();
        if !first_ready {
            return Pending;
        }

        // Phase 2: give up our `Pin<&'s mut S>` for good and poll through it once more,
        // so that the returned reference borrows for the full `'s`.
        let stream = this.peeker.take().expect("it was Some before!");
        let result = stream.poll_peek(cx);
        assert!(result.is_ready(), "it was Ready before!");
        result
    }
}
#[cfg(test)]
mod test {
    // Tests for `StreamUnobtrusivePeeker`, driven by `tor_rtmock::MockRuntime`.
    // The assertions depend on the relative timing of the sleeps in the
    // "rxr" (receiver) and "tx" (sender) tasks.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use futures::channel::mpsc;
    use futures::{SinkExt as _, StreamExt as _};
    use std::pin::pin;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tor_rtcompat::SleepProvider as _;
    use tor_rtmock::MockRuntime;

    /// Shorthand for a `Duration` of `ms` milliseconds.
    fn ms(ms: u64) -> Duration {
        Duration::from_millis(ms)
    }

    /// Interleaves `unobtrusive_peek_mut` and `next()` while the sender sends at
    /// irregular intervals.
    ///
    /// Checks that values are peeked and consumed in order, and that the receiver
    /// notices end-of-stream within the expected time window.
    #[test]
    fn wakeups() {
        MockRuntime::test_with_various(|rt| async move {
            let (mut tx, rx) = mpsc::unbounded();
            // Set by the receiver task once it has seen end-of-stream.
            let ended = Arc::new(Mutex::new(false));
            rt.spawn_identified("rxr", {
                let rt = rt.clone();
                let ended = ended.clone();
                async move {
                    let rx = StreamUnobtrusivePeeker::new(rx);
                    let mut rx = pin!(rx);
                    // The value we expect to receive next.
                    let mut next = 0;
                    loop {
                        rt.sleep(ms(50)).await;
                        eprintln!("rx peek... ");
                        // May be None if nothing has arrived yet.
                        let peeked = rx.as_mut().unobtrusive_peek_mut();
                        eprintln!("rx peeked {peeked:?}");
                        if let Some(peeked) = peeked {
                            assert_eq!(*peeked, next);
                        }
                        rt.sleep(ms(50)).await;
                        eprintln!("rx next... ");
                        let eaten = rx.next().await;
                        eprintln!("rx eaten {eaten:?}");
                        if let Some(eaten) = eaten {
                            assert_eq!(eaten, next);
                            next += 1;
                        } else {
                            break;
                        }
                    }
                    *ended.lock().unwrap() = true;
                    eprintln!("rx ended");
                }
            });
            rt.spawn_identified("tx", {
                let rt = rt.clone();
                async move {
                    let mut numbers = 0..;
                    // Irregular delays (in ms) before each send.
                    for wait in [125, 1, 125, 45, 1, 1, 1, 1000, 20, 1, 125, 125, 1000] {
                        eprintln!("tx sleep {wait}");
                        rt.sleep(ms(wait)).await;
                        let num = numbers.next().unwrap();
                        eprintln!("tx sending {num}");
                        tx.send(num).await.unwrap();
                    }
                    eprintln!("tx final #1");
                    rt.sleep(ms(75)).await;
                    eprintln!("tx EOF");
                    drop(tx);
                    eprintln!("tx final #2");
                    // The receiver must NOT have finished within 10ms of EOF...
                    rt.sleep(ms(10)).await;
                    assert!(!*ended.lock().unwrap());
                    eprintln!("tx final #3");
                    // ...but must have finished within 60ms of it.
                    rt.sleep(ms(50)).await;
                    eprintln!("tx final #4");
                    assert!(*ended.lock().unwrap());
                }
            });
            rt.advance_until_stalled().await;
        });
    }

    /// Uses `peek()` (and hence `poll_peek`) followed by `next()`.
    ///
    /// Checks that peeked and consumed values agree, and that the receiver
    /// finishes promptly after end-of-stream.
    #[test]
    fn poll_peek_paths() {
        MockRuntime::test_with_various(|rt| async move {
            let (mut tx, rx) = mpsc::unbounded();
            // Set by the receiver task once it has seen end-of-stream.
            let ended = Arc::new(Mutex::new(false));
            rt.spawn_identified("rxr", {
                let rt = rt.clone();
                let ended = ended.clone();
                async move {
                    let rx = StreamUnobtrusivePeeker::new(rx);
                    let mut rx = pin!(rx);
                    // `peek()` resolves to None at end-of-stream, which ends the loop.
                    while let Some(peeked) = rx.as_mut().peek().await.copied() {
                        eprintln!("rx peeked {peeked}");
                        let eaten = rx.next().await.unwrap();
                        eprintln!("rx eaten {eaten}");
                        assert_eq!(peeked, eaten);
                        rt.sleep(ms(10)).await;
                        eprintln!("rx slept, peeking");
                    }
                    *ended.lock().unwrap() = true;
                    eprintln!("rx ended");
                }
            });
            rt.spawn_identified("tx", {
                let rt = rt.clone();
                async move {
                    let mut numbers = 0..;
                    // Send the next number from `numbers`.
                    macro_rules! send { {} => {
                        let num = numbers.next().unwrap();
                        eprintln!("tx send {num}");
                        tx.send(num).await.unwrap();
                    } }
                    eprintln!("tx starting");
                    rt.sleep(ms(100)).await;
                    send!();
                    rt.sleep(ms(100)).await;
                    // Two sends with no delay between them.
                    send!();
                    send!();
                    rt.sleep(ms(100)).await;
                    eprintln!("tx dropping");
                    drop(tx);
                    // The receiver must have noticed EOF within 5ms.
                    rt.sleep(ms(5)).await;
                    eprintln!("tx ending");
                    assert!(*ended.lock().unwrap());
                }
            });
            rt.advance_until_stalled().await;
        });
    }
}