use buffer_unordered_weighted::StreamExt as _;
use futures::{stream, StreamExt as _};
use proptest::prelude::*;
use proptest_derive::Arbitrary;
use std::time::Duration;
/// Randomly generated input for a single run of the weighted-buffer property test.
#[derive(Clone, Debug, Arbitrary)]
struct TestState {
    /// Weight budget passed to `buffer_unordered_weighted`; always in `1..64`.
    #[proptest(strategy = "1..64usize")]
    max_weight: usize,
    /// The futures fed through the buffer, in stream order (0 to 511 of them).
    #[proptest(strategy = "prop::collection::vec(any::<TestFutureDesc>(), 0..512usize)")]
    future_descriptions: Vec<TestFutureDesc>,
}
/// Description of one synthetic future: how long it sleeps and what weight it claims.
#[derive(Copy, Clone, Debug, Arbitrary)]
struct TestFutureDesc {
    /// How long the future sleeps once started (generated by `duration_strategy`).
    #[proptest(strategy = "duration_strategy()")]
    delay: Duration,
    /// Weight paired with the future in the stream; in `0..8`, so it may be zero and may
    /// be larger than the generated `max_weight`.
    #[proptest(strategy = "0..8usize")]
    weight: usize,
}
/// Strategy producing delays of 0 to 999 whole milliseconds.
fn duration_strategy() -> BoxedStrategy<Duration> {
    let millis = 0u64..1000;
    millis.prop_map(Duration::from_millis).boxed()
}
proptest! {
    /// Runs the weighted-buffer checks against arbitrary generated `TestState`s.
    #[test]
    fn proptest_buffer_unordered(state in any::<TestState>()) {
        proptest_buffer_unordered_impl(state);
    }
}
/// Lifecycle notification that each generated future sends over the test's mpsc channel.
///
/// Both variants carry the future's index in the input stream and a copy of its
/// description.
#[derive(Clone, Copy, Debug)]
enum FutureEvent {
    /// Sent as the first action of the future's body, i.e. on its first poll.
    Started(usize, TestFutureDesc),
    /// Sent after the future's sleep has elapsed, just before the future completes.
    Finished(usize, TestFutureDesc),
}
/// Runs one property-test case: feeds the generated futures through
/// `buffer_unordered_weighted` and checks its scheduling invariants.
///
/// Each generated future sends `FutureEvent::Started` when it is first polled and
/// `FutureEvent::Finished` just before it completes. The driver loop consumes those
/// events and asserts that:
///
/// - futures start in stream order (ids 0, 1, 2, ...);
/// - a future only starts while the observed weight of running futures is below
///   `max_weight`;
/// - once the stream is exhausted, the adaptor reports a current weight of 0 and
///   every future has finished.
///
/// # Panics
///
/// Panics (failing the proptest case) if any of those invariants is violated.
fn proptest_buffer_unordered_impl(state: TestState) {
    // With a paused clock, tokio auto-advances time whenever the runtime has nothing
    // else to do, so the generated sleeps complete without real waiting.
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_time()
        .start_paused(true)
        .build()
        .expect("tokio builder succeeded");
    let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
    let futures = state
        .future_descriptions
        .iter()
        .enumerate()
        .map(|(id, desc)| {
            let sender = sender.clone();
            let delay_fut = async move {
                // Report the first poll, simulate work, then report completion.
                sender
                    .send(FutureEvent::Started(id, *desc))
                    .expect("receiver held open by loop");
                tokio::time::sleep(desc.delay).await;
                sender
                    .send(FutureEvent::Finished(id, *desc))
                    .expect("receiver held open by loop");
            };
            (desc.weight, delay_fut)
        });
    let stream = stream::iter(futures);

    // Copy the limit out so the async block below never has to capture `state`, whose
    // `future_descriptions` are still borrowed by `stream`.
    let max_weight = state.max_weight;
    let mut completed_map = vec![false; state.future_descriptions.len()];
    let mut last_started_id: Option<usize> = None;
    // Weight of futures that have reported `Started` but not yet `Finished`.
    let mut current_weight = 0;

    runtime.block_on(async move {
        let mut stream = stream.buffer_unordered_weighted(max_weight);
        loop {
            // Both branch patterns are irrefutable, so neither branch is ever disabled
            // and no `else` arm is needed.
            tokio::select! {
                // Drain every pending event before polling the stream again, so the
                // bookkeeping above reflects all events sent so far.
                biased;
                recv = receiver.recv() => {
                    match recv {
                        Some(FutureEvent::Started(id, desc)) => {
                            let expected_id = last_started_id.map_or(0, |id| id + 1);
                            assert_eq!(expected_id, id, "expected future id to start != actual id that started");
                            last_started_id = Some(id);
                            assert!(
                                current_weight < max_weight,
                                "future {} started while current weight {} >= max weight {}",
                                id,
                                current_weight,
                                max_weight,
                            );
                            current_weight += desc.weight;
                        }
                        Some(FutureEvent::Finished(id, desc)) => {
                            completed_map[id] = true;
                            current_weight -= desc.weight;
                        }
                        // `sender` is borrowed by the `futures` closure and lives until the
                        // end of this function, so the channel cannot close here. Ignoring
                        // `None` would instead spin forever: under `biased`, `recv()` would
                        // keep winning and `stream` would never be polled again.
                        None => unreachable!("event channel closed while the stream was still running"),
                    }
                }
                next = stream.next() => {
                    if next.is_none() {
                        assert_eq!(stream.current_weight(), 0, "all futures complete => current weight is 0");
                        break;
                    }
                }
            }
        }

        let not_completed: Vec<_> = completed_map
            .iter()
            .enumerate()
            .filter_map(|(n, &done)| (!done).then(|| n.to_string()))
            .collect();
        if !not_completed.is_empty() {
            let not_completed_ids = not_completed.join(", ");
            panic!("some futures did not complete: {}", not_completed_ids);
        }
    })
}