use super::Stream;
use pin_project::pin_project;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
const ANY_ALL_COOPERATIVE_BUDGET: usize = 1024;
#[pin_project]
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
pub struct Any<S, P> {
#[pin]
stream: S,
predicate: P,
result: Option<bool>,
}
impl<S, P> Any<S, P> {
#[inline]
pub(crate) fn new(stream: S, predicate: P) -> Self {
Self {
stream,
predicate,
result: None,
}
}
}
impl<S, P> Future for Any<S, P>
where
S: Stream,
P: FnMut(&S::Item) -> bool,
{
type Output = bool;
#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<bool> {
let mut this = self.project();
assert!(this.result.is_none(), "Any polled after completion");
let mut scanned_this_poll = 0usize;
loop {
match this.stream.as_mut().poll_next(cx) {
Poll::Ready(Some(item)) => {
if (this.predicate)(&item) {
*this.result = Some(true);
return Poll::Ready(true);
}
scanned_this_poll += 1;
if scanned_this_poll >= ANY_ALL_COOPERATIVE_BUDGET {
cx.waker().wake_by_ref();
return Poll::Pending;
}
}
Poll::Ready(None) => {
*this.result = Some(false);
return Poll::Ready(false);
}
Poll::Pending => return Poll::Pending,
}
}
}
}
#[pin_project]
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
pub struct All<S, P> {
#[pin]
stream: S,
predicate: P,
result: Option<bool>,
}
impl<S, P> All<S, P> {
#[inline]
pub(crate) fn new(stream: S, predicate: P) -> Self {
Self {
stream,
predicate,
result: None,
}
}
}
impl<S, P> Future for All<S, P>
where
S: Stream,
P: FnMut(&S::Item) -> bool,
{
type Output = bool;
#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<bool> {
let mut this = self.project();
assert!(this.result.is_none(), "All polled after completion");
let mut scanned_this_poll = 0usize;
loop {
match this.stream.as_mut().poll_next(cx) {
Poll::Ready(Some(item)) => {
if !(this.predicate)(&item) {
*this.result = Some(false);
return Poll::Ready(false);
}
scanned_this_poll += 1;
if scanned_this_poll >= ANY_ALL_COOPERATIVE_BUDGET {
cx.waker().wake_by_ref();
return Poll::Pending;
}
}
Poll::Ready(None) => {
*this.result = Some(true);
return Poll::Ready(true);
}
Poll::Pending => return Poll::Pending,
}
}
}
}
#[cfg(test)]
mod tests {
#![allow(
clippy::pedantic,
clippy::nursery,
clippy::expect_fun_call,
clippy::map_unwrap_or,
clippy::cast_possible_wrap,
clippy::future_not_send
)]
use super::*;
use crate::stream::iter;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::{Poll, Waker};
fn noop_waker() -> Waker {
std::task::Waker::noop().clone()
}
struct TrackWaker(Arc<AtomicBool>);
use std::task::Wake;
impl Wake for TrackWaker {
fn wake(self: Arc<Self>) {
self.0.store(true, Ordering::SeqCst);
}
fn wake_by_ref(self: &Arc<Self>) {
self.0.store(true, Ordering::SeqCst);
}
}
#[derive(Debug, Default)]
struct AlwaysReadyCounter {
next: usize,
end: usize,
}
impl AlwaysReadyCounter {
fn new(end: usize) -> Self {
Self { next: 0, end }
}
}
impl Stream for AlwaysReadyCounter {
type Item = usize;
fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.next >= self.end {
return Poll::Ready(None);
}
let item = self.next;
self.next += 1;
Poll::Ready(Some(item))
}
}
#[derive(Debug, Default)]
struct MatchThenPanicStream {
emitted: bool,
}
impl Stream for MatchThenPanicStream {
type Item = usize;
fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
assert!(
!self.emitted,
"inner stream repolled after early completion"
);
self.emitted = true;
Poll::Ready(Some(1))
}
}
#[derive(Debug, Default)]
struct OneThenDoneThenPanicStream {
emitted: bool,
completed: bool,
}
impl Stream for OneThenDoneThenPanicStream {
type Item = usize;
fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
assert!(!self.completed, "inner stream repolled after completion");
if self.emitted {
self.completed = true;
Poll::Ready(None)
} else {
self.emitted = true;
Poll::Ready(Some(2))
}
}
}
fn init_test(name: &str) {
crate::test_utils::init_test_logging();
crate::test_phase!(name);
}
fn poll_bool<F>(future: &mut F) -> bool
where
F: Future<Output = bool> + Unpin,
{
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
match Pin::new(future).poll(&mut cx) {
Poll::Ready(value) => value,
Poll::Pending => panic!("expected Ready"),
}
}
fn poll_bool_to_completion<F>(future: &mut F) -> (bool, usize)
where
F: Future<Output = bool> + Unpin,
{
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let mut pending_polls = 0usize;
loop {
match Pin::new(&mut *future).poll(&mut cx) {
Poll::Ready(value) => return (value, pending_polls),
Poll::Pending => {
pending_polls += 1;
assert!(
pending_polls <= 8,
"future did not complete after {pending_polls} pending polls",
);
}
}
}
}
#[test]
fn any_found() {
init_test("any_found");
let mut future = Any::new(iter(vec![1i32, 2, 3, 4, 5]), |&x: &i32| x > 3);
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
match Pin::new(&mut future).poll(&mut cx) {
Poll::Ready(found) => {
crate::assert_with_log!(found, "any found", true, found);
}
Poll::Pending => panic!("expected Ready"),
}
crate::test_complete!("any_found");
}
#[test]
fn any_not_found() {
init_test("any_not_found");
let mut future = Any::new(iter(vec![1i32, 2, 3]), |&x: &i32| x > 5);
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
match Pin::new(&mut future).poll(&mut cx) {
Poll::Ready(found) => {
crate::assert_with_log!(!found, "any not found", false, found);
}
Poll::Pending => panic!("expected Ready"),
}
crate::test_complete!("any_not_found");
}
#[test]
fn any_empty() {
init_test("any_empty");
let mut future = Any::new(iter(Vec::<i32>::new()), |_: &i32| true);
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
match Pin::new(&mut future).poll(&mut cx) {
Poll::Ready(found) => {
crate::assert_with_log!(!found, "empty false", false, found);
}
Poll::Pending => panic!("expected Ready"),
}
crate::test_complete!("any_empty");
}
#[test]
fn all_pass() {
init_test("all_pass");
let mut future = All::new(iter(vec![2i32, 4, 6, 8]), |&x: &i32| x % 2 == 0);
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
match Pin::new(&mut future).poll(&mut cx) {
Poll::Ready(all_pass) => {
crate::assert_with_log!(all_pass, "all pass", true, all_pass);
}
Poll::Pending => panic!("expected Ready"),
}
crate::test_complete!("all_pass");
}
#[test]
fn all_fail() {
init_test("all_fail");
let mut future = All::new(iter(vec![2i32, 4, 5, 8]), |&x: &i32| x % 2 == 0);
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
match Pin::new(&mut future).poll(&mut cx) {
Poll::Ready(all_pass) => {
crate::assert_with_log!(!all_pass, "all fail", false, all_pass);
}
Poll::Pending => panic!("expected Ready"),
}
crate::test_complete!("all_fail");
}
#[test]
fn all_empty() {
init_test("all_empty");
let mut future = All::new(iter(Vec::<i32>::new()), |_: &i32| false);
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
match Pin::new(&mut future).poll(&mut cx) {
Poll::Ready(all_pass) => {
crate::assert_with_log!(all_pass, "empty true", true, all_pass);
}
Poll::Pending => panic!("expected Ready"),
}
crate::test_complete!("all_empty");
}
#[test]
fn any_all_duality_matches_negated_predicate() {
init_test("any_all_duality_matches_negated_predicate");
let data = vec![-3i32, -1, 0, 2, 5];
let mut all_with_no_counterexample = All::new(iter(data.clone()), |&x: &i32| x <= 5);
let mut any_counterexample = Any::new(iter(data.clone()), |&x: &i32| x > 5);
let all_result = poll_bool(&mut all_with_no_counterexample);
let any_negated_result = poll_bool(&mut any_counterexample);
crate::assert_with_log!(
all_result != any_negated_result,
"all(p) equals !any(!p) when all items satisfy p",
true,
all_result != any_negated_result
);
let mut all_with_counterexample = All::new(iter(data.clone()), |&x: &i32| x < 5);
let mut any_counterexample = Any::new(iter(data), |&x: &i32| x >= 5);
let all_result = poll_bool(&mut all_with_counterexample);
let any_negated_result = poll_bool(&mut any_counterexample);
crate::assert_with_log!(
all_result != any_negated_result,
"all(p) equals !any(!p) when a counterexample exists",
true,
all_result != any_negated_result
);
crate::test_complete!("any_all_duality_matches_negated_predicate");
}
#[test]
fn mr_any_all_partition_composition_matches_combined_input() {
init_test("mr_any_all_partition_composition_matches_combined_input");
let left = vec![-7i32, -2, 0, 3];
let right = vec![4i32, 9, 12];
let mut combined = left.clone();
combined.extend(right.iter().copied());
let predicate = |&x: &i32| x >= 9;
let combined_any = poll_bool(&mut Any::new(iter(combined.clone()), predicate));
let left_any = poll_bool(&mut Any::new(iter(left.clone()), predicate));
let right_any = poll_bool(&mut Any::new(iter(right.clone()), predicate));
assert_eq!(combined_any, left_any || right_any);
let predicate = |&x: &i32| x >= -7;
let combined_all = poll_bool(&mut All::new(iter(combined), predicate));
let left_all = poll_bool(&mut All::new(iter(left), predicate));
let right_all = poll_bool(&mut All::new(iter(right), predicate));
assert_eq!(combined_all, left_all && right_all);
crate::test_complete!("mr_any_all_partition_composition_matches_combined_input");
}
#[test]
fn mr_any_all_threshold_translation_preserves_truth_values() {
init_test("mr_any_all_threshold_translation_preserves_truth_values");
let input = vec![-10i32, -3, 0, 7, 14];
let offset = 11i32;
let shifted: Vec<_> = input.iter().map(|item| item + offset).collect();
let baseline_any = poll_bool(&mut Any::new(iter(input.clone()), |&x: &i32| x > 6));
let shifted_any = poll_bool(&mut Any::new(iter(shifted.clone()), |&x: &i32| {
x > 6 + offset
}));
assert_eq!(shifted_any, baseline_any);
let baseline_all = poll_bool(&mut All::new(iter(input), |&x: &i32| x < 20));
let shifted_all = poll_bool(&mut All::new(iter(shifted), |&x: &i32| x < 20 + offset));
assert_eq!(shifted_all, baseline_all);
crate::test_complete!("mr_any_all_threshold_translation_preserves_truth_values");
}
#[test]
fn mr_any_all_append_monotonicity_respects_witnesses() {
init_test("mr_any_all_append_monotonicity_respects_witnesses");
let matching_prefix = vec![1i32, 4, 8];
let mut matching_with_tail = matching_prefix.clone();
matching_with_tail.extend([-100, -200]);
let prefix_any = poll_bool(&mut Any::new(iter(matching_prefix), |&x: &i32| x >= 8));
let tail_any = poll_bool(&mut Any::new(iter(matching_with_tail), |&x: &i32| x >= 8));
assert!(prefix_any);
assert_eq!(tail_any, prefix_any);
let all_prefix = vec![2i32, 4, 6];
let mut all_with_satisfying_tail = all_prefix.clone();
all_with_satisfying_tail.extend([8, 10]);
let mut all_with_counterexample_tail = all_prefix.clone();
all_with_counterexample_tail.extend([8, 11]);
let prefix_all = poll_bool(&mut All::new(iter(all_prefix), |&x: &i32| x % 2 == 0));
let satisfying_tail_all =
poll_bool(&mut All::new(iter(all_with_satisfying_tail), |&x: &i32| {
x % 2 == 0
}));
let counterexample_tail_all = poll_bool(&mut All::new(
iter(all_with_counterexample_tail),
|&x: &i32| x % 2 == 0,
));
assert!(prefix_all);
assert!(satisfying_tail_all);
assert!(!counterexample_tail_all);
crate::test_complete!("mr_any_all_append_monotonicity_respects_witnesses");
}
#[test]
fn mr_any_all_cooperative_segmentation_preserves_results() {
init_test("mr_any_all_cooperative_segmentation_preserves_results");
let len = (ANY_ALL_COOPERATIVE_BUDGET * 2) + 17;
let mut any_future = Any::new(AlwaysReadyCounter::new(len), |&x: &usize| x == len - 1);
let (any_result, any_pending) = poll_bool_to_completion(&mut any_future);
assert!(any_result);
assert_eq!(any_pending, len / ANY_ALL_COOPERATIVE_BUDGET);
let mut all_future = All::new(AlwaysReadyCounter::new(len), |&x: &usize| x < len);
let (all_result, all_pending) = poll_bool_to_completion(&mut all_future);
assert!(all_result);
assert_eq!(all_pending, len / ANY_ALL_COOPERATIVE_BUDGET);
crate::test_complete!("mr_any_all_cooperative_segmentation_preserves_results");
}
#[test]
fn any_yields_after_budget_when_predicate_never_matches() {
init_test("any_yields_after_budget_when_predicate_never_matches");
let mut future = Any::new(
AlwaysReadyCounter::new(ANY_ALL_COOPERATIVE_BUDGET + 5),
|_: &usize| false,
);
let woke = Arc::new(AtomicBool::new(false));
let waker = Waker::from(Arc::new(TrackWaker(woke.clone())));
let mut cx = Context::from_waker(&waker);
let first = Pin::new(&mut future).poll(&mut cx);
crate::assert_with_log!(
matches!(first, Poll::Pending),
"first poll yields cooperatively",
"Poll::Pending",
first
);
crate::assert_with_log!(
future.stream.next == ANY_ALL_COOPERATIVE_BUDGET,
"upstream advanced only to budget",
ANY_ALL_COOPERATIVE_BUDGET,
future.stream.next
);
crate::assert_with_log!(
woke.load(Ordering::SeqCst),
"self-wake requested",
true,
woke.load(Ordering::SeqCst)
);
let second = Pin::new(&mut future).poll(&mut cx);
crate::assert_with_log!(
second == Poll::Ready(false),
"second poll completes with no match",
Poll::Ready(false),
second
);
crate::test_complete!("any_yields_after_budget_when_predicate_never_matches");
}
#[test]
fn all_yields_after_budget_when_predicate_stays_true() {
init_test("all_yields_after_budget_when_predicate_stays_true");
let mut future = All::new(
AlwaysReadyCounter::new(ANY_ALL_COOPERATIVE_BUDGET + 5),
|_: &usize| true,
);
let woke = Arc::new(AtomicBool::new(false));
let waker = Waker::from(Arc::new(TrackWaker(woke.clone())));
let mut cx = Context::from_waker(&waker);
let first = Pin::new(&mut future).poll(&mut cx);
crate::assert_with_log!(
matches!(first, Poll::Pending),
"first poll yields cooperatively",
"Poll::Pending",
first
);
crate::assert_with_log!(
future.stream.next == ANY_ALL_COOPERATIVE_BUDGET,
"upstream advanced only to budget",
ANY_ALL_COOPERATIVE_BUDGET,
future.stream.next
);
crate::assert_with_log!(
woke.load(Ordering::SeqCst),
"self-wake requested",
true,
woke.load(Ordering::SeqCst)
);
let second = Pin::new(&mut future).poll(&mut cx);
crate::assert_with_log!(
second == Poll::Ready(true),
"second poll completes with all true",
Poll::Ready(true),
second
);
crate::test_complete!("all_yields_after_budget_when_predicate_stays_true");
}
#[test]
fn any_repoll_panics_after_completion() {
init_test("any_repoll_panics_after_completion");
let mut future = Any::new(MatchThenPanicStream::default(), |_: &usize| true);
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let first = Pin::new(&mut future).poll(&mut cx);
crate::assert_with_log!(
first == Poll::Ready(true),
"first poll finds match",
Poll::Ready(true),
first
);
let second = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
Pin::new(&mut future).poll(&mut cx)
}));
let payload = second.expect_err("repoll after completion must panic");
let message = payload
.downcast_ref::<&str>()
.map(ToString::to_string)
.or_else(|| payload.downcast_ref::<String>().cloned())
.unwrap_or_default();
crate::assert_with_log!(
message.contains("Any polled after completion"),
"second poll fails closed",
true,
message.contains("Any polled after completion")
);
crate::test_complete!("any_repoll_panics_after_completion");
}
#[test]
fn all_repoll_panics_after_completion() {
init_test("all_repoll_panics_after_completion");
let mut future = All::new(OneThenDoneThenPanicStream::default(), |_: &usize| true);
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let first = Pin::new(&mut future).poll(&mut cx);
crate::assert_with_log!(
first == Poll::Ready(true),
"first poll exhausts stream",
Poll::Ready(true),
first
);
let second = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
Pin::new(&mut future).poll(&mut cx)
}));
let payload = second.expect_err("repoll after completion must panic");
let message = payload
.downcast_ref::<&str>()
.map(ToString::to_string)
.or_else(|| payload.downcast_ref::<String>().cloned())
.unwrap_or_default();
crate::assert_with_log!(
message.contains("All polled after completion"),
"second poll fails closed",
true,
message.contains("All polled after completion")
);
crate::test_complete!("all_repoll_panics_after_completion");
}
}