use std::{
collections::VecDeque,
fmt::Debug,
panic::panic_any,
sync::{Arc, Condvar, Mutex, MutexGuard},
time::{Duration, Instant},
};
use crate::stream::{BoxStream, NotUsed, Sink, Source, StreamCompletion};
use crate::{StreamError, StreamResult};
/// Upper bound on how long a probe's `expect_*` calls wait before panicking,
/// unless the probe overrides it with `set_timeout`.
const DEFAULT_PROBE_TIMEOUT: Duration = Duration::from_millis(3_000);
/// Entry point for creating probe-controlled test sources.
pub struct TestSource;

impl TestSource {
    /// Returns a source blueprint whose materialized value is a
    /// [`TestPublisherProbe`].
    ///
    /// Each materialization builds a fresh probe/stream pair that shares one
    /// state object, so the test decides exactly what the source emits and when.
    #[must_use]
    pub fn probe<T: Send + 'static>() -> Source<T, TestPublisherProbe<T>> {
        Source::from_materialized_factory(|_| {
            let shared = Arc::new(SourceProbeShared::default());
            let probe = TestPublisherProbe::new(Arc::clone(&shared));
            let stream: BoxStream<T> = Box::new(TestSourceStream {
                shared,
                waiting_for_command: false,
            });
            Ok((stream, probe))
        })
    }
}
pub struct TestSink;
impl TestSink {
#[must_use]
pub fn probe<T: Send + 'static>() -> Sink<T, TestSubscriberProbe<T>> {
Sink::from_runner(|mut input, materializer| {
let shared = Arc::new(SinkProbeShared::default());
let task_shared = Arc::clone(&shared);
let completion = materializer.spawn_stream(move |_cancelled| {
loop {
task_shared.wait_for_request()?;
match input.next() {
Some(Ok(item)) => task_shared.push_event(SinkEvent::Next(item)),
Some(Err(error)) => {
task_shared.push_event(SinkEvent::Error(error.clone()));
return Err(error);
}
None => {
task_shared.push_event(SinkEvent::Complete);
return Ok(NotUsed);
}
}
}
});
Ok(TestSubscriberProbe::new(shared, completion))
})
}
}
/// Asserts that a received element equals the expected one.
///
/// # Panics
///
/// Panics (via `assert_eq!`) with a message naming both values when they differ.
pub fn assert_next_eq<T: Debug + PartialEq>(actual: &T, expected: &T) {
    assert_eq!(
        actual,
        expected,
        "expected next element {want:?}, got {got:?}",
        want = expected,
        got = actual
    );
}
/// Asserts that a received run of elements equals the expected slice.
///
/// # Panics
///
/// Panics (via `assert_eq!`) with a message naming both slices when they differ.
pub fn assert_next_n_eq<T: Debug + PartialEq>(actual: &[T], expected: &[T]) {
    assert_eq!(
        actual,
        expected,
        "expected next elements {want:?}, got {got:?}",
        want = expected,
        got = actual
    );
}
pub struct TestPublisherProbe<T> {
shared: Arc<SourceProbeShared<T>>,
timeout: Duration,
}
impl<T> TestPublisherProbe<T> {
fn new(shared: Arc<SourceProbeShared<T>>) -> Self {
Self {
shared,
timeout: DEFAULT_PROBE_TIMEOUT,
}
}
pub fn set_timeout(&mut self, timeout: Duration) {
self.timeout = timeout;
}
pub fn send_next(&self, element: T) {
self.shared.enqueue(SourceCommand::Next(element));
}
pub fn send_complete(&self) {
self.shared.enqueue(SourceCommand::Complete);
}
pub fn send_error(&self, error: StreamError) {
self.shared.enqueue(SourceCommand::Error(error));
}
pub fn expect_request(&self) -> usize {
self.shared.expect_request(self.timeout)
}
pub fn expect_cancellation(&self) {
self.shared.expect_cancellation(self.timeout);
}
}
impl<T> Drop for TestPublisherProbe<T> {
fn drop(&mut self) {
self.shared.fail_if_open(StreamError::Failed(
"test source probe dropped before completion".to_owned(),
));
}
}
/// Test-side handle for a sink materialized by [`TestSink::probe`].
///
/// The sink pulls from upstream only after [`request`](Self::request) grants
/// demand. Every element, completion, or error it sees is queued and consumed
/// in order by the `expect_*` / `assert_*` methods. Each wait is bounded by the
/// probe timeout, which starts at `DEFAULT_PROBE_TIMEOUT`. Dropping the probe
/// cancels the sink.
pub struct TestSubscriberProbe<T> {
    shared: Arc<SinkProbeShared<T>>,
    timeout: Duration,
    /// Completion of the spawned sink task; released by `cancel` / `drop`.
    completion: Option<StreamCompletion<NotUsed>>,
}

impl<T> TestSubscriberProbe<T> {
    fn new(shared: Arc<SinkProbeShared<T>>, completion: StreamCompletion<NotUsed>) -> Self {
        Self {
            shared,
            timeout: DEFAULT_PROBE_TIMEOUT,
            completion: Some(completion),
        }
    }

    /// Sets how long each subsequent `expect_*` call waits before panicking.
    pub fn set_timeout(&mut self, timeout: Duration) {
        self.timeout = timeout;
    }

    /// Grants the sink `n` more elements of demand. Accumulated demand
    /// saturates at `usize::MAX`.
    ///
    /// # Panics
    ///
    /// Panics if `n` is zero.
    pub fn request(&self, n: usize) {
        assert!(n > 0, "request count must be positive, got {n}");
        self.shared.request(n);
    }

    /// Waits for the next element and returns it.
    ///
    /// # Panics
    ///
    /// Panics if the next observation is a completion or an error, or if
    /// nothing arrives within the probe timeout.
    pub fn expect_next(&self) -> T {
        match self.shared.expect_event(self.timeout, "next element") {
            SinkEvent::Next(item) => item,
            // The completion may already have been queued when this was called.
            // The message must not claim the full timeout was spent waiting.
            SinkEvent::Complete => {
                panic_any("expected next element, got stream completion".to_owned())
            }
            SinkEvent::Error(error) => {
                panic_any(format!("expected next element, got stream error {error:?}"))
            }
        }
    }

    /// Waits for the next element and asserts it equals `expected`.
    ///
    /// # Panics
    ///
    /// Panics as [`expect_next`](Self::expect_next) does, or on a mismatch.
    pub fn assert_next(&self, expected: T)
    where
        T: Debug + PartialEq,
    {
        let actual = self.expect_next();
        assert_next_eq(&actual, &expected);
    }

    /// Waits for exactly `n` elements and returns them in arrival order.
    ///
    /// # Panics
    ///
    /// Panics as [`expect_next`](Self::expect_next) does for any of the `n` waits.
    pub fn expect_next_n(&self, n: usize) -> Vec<T> {
        (0..n).map(|_| self.expect_next()).collect()
    }

    /// Waits for as many elements as `expected` yields and asserts they match.
    ///
    /// # Panics
    ///
    /// Panics as [`expect_next_n`](Self::expect_next_n) does, or on a mismatch.
    pub fn assert_next_n<I>(&self, expected: I)
    where
        T: Debug + PartialEq,
        I: IntoIterator<Item = T>,
    {
        let expected: Vec<T> = expected.into_iter().collect();
        let actual = self.expect_next_n(expected.len());
        assert_next_n_eq(&actual, &expected);
    }

    /// Waits for the stream to complete.
    ///
    /// # Panics
    ///
    /// Panics if an element or an error arrives first, or on timeout.
    pub fn expect_complete(&self) {
        match self.shared.expect_event(self.timeout, "stream completion") {
            SinkEvent::Complete => {}
            SinkEvent::Next(_) => panic_any("expected stream completion, got next element"),
            SinkEvent::Error(error) => panic_any(format!(
                "expected stream completion, got stream error {error:?}"
            )),
        }
    }

    /// Waits for the stream to fail and returns the error.
    ///
    /// # Panics
    ///
    /// Panics if an element or a completion arrives first, or on timeout.
    pub fn expect_error(&self) -> StreamError {
        match self.shared.expect_event(self.timeout, "stream error") {
            SinkEvent::Error(error) => error,
            SinkEvent::Next(_) => panic_any("expected stream error, got next element"),
            SinkEvent::Complete => panic_any("expected stream error, got stream completion"),
        }
    }

    /// Asserts that nothing is observed for `timeout`. The argument is used
    /// here, not the probe timeout.
    ///
    /// # Panics
    ///
    /// Panics if an element, a completion, or an error arrives in that window.
    pub fn expect_no_message(&self, timeout: Duration) {
        self.shared.expect_no_message(timeout);
    }

    /// Requests effectively unbounded demand and collects elements until the
    /// stream completes.
    ///
    /// # Panics
    ///
    /// Panics if the stream fails, or if any single wait exceeds the probe timeout.
    #[must_use]
    pub fn drain_until_complete(&self) -> Vec<T> {
        self.request(usize::MAX / 2);
        let mut values = Vec::new();
        loop {
            match self.shared.expect_event(self.timeout, "stream completion") {
                SinkEvent::Next(item) => values.push(item),
                SinkEvent::Complete => return values,
                SinkEvent::Error(error) => panic_any(format!(
                    "expected stream completion, got stream error {error:?}"
                )),
            }
        }
    }

    /// Cancels the sink. The sink task sees the cancellation the next time it
    /// waits for demand. Also releases the handle to the task's completion.
    pub fn cancel(&mut self) {
        self.shared.cancel();
        self.completion = None;
    }
}

impl<T> Drop for TestSubscriberProbe<T> {
    fn drop(&mut self) {
        // Same teardown as an explicit cancel; repeating it is harmless.
        self.cancel();
    }
}
struct TestSourceStream<T> {
shared: Arc<SourceProbeShared<T>>,
waiting_for_command: bool,
}
impl<T> Iterator for TestSourceStream<T> {
type Item = StreamResult<T>;
fn next(&mut self) -> Option<Self::Item> {
if !self.waiting_for_command {
self.shared.record_demand();
self.waiting_for_command = true;
}
match self.shared.next_command() {
Some(SourceCommand::Next(item)) => {
self.waiting_for_command = false;
Some(Ok(item))
}
Some(SourceCommand::Complete) => {
self.waiting_for_command = false;
None
}
Some(SourceCommand::Error(error)) => {
self.waiting_for_command = false;
Some(Err(error))
}
None => {
self.waiting_for_command = false;
None
}
}
}
}
impl<T> Drop for TestSourceStream<T> {
fn drop(&mut self) {
self.shared.mark_cancelled();
}
}
/// Instruction queued by a [`TestPublisherProbe`] for its source stream.
enum SourceCommand<T> {
    /// Emit one element downstream.
    Next(T),
    /// Complete the stream normally.
    Complete,
    /// Fail the stream with this error.
    Error(StreamError),
}

/// State shared between a [`TestPublisherProbe`] (test side) and the
/// [`TestSourceStream`] it drives (stream side).
struct SourceProbeShared<T> {
    state: Mutex<SourceProbeState<T>>,
    /// Notified on every state change so blocked waiters re-check `state`.
    condvar: Condvar,
}

struct SourceProbeState<T> {
    /// Commands sent by the probe that the stream has not consumed yet.
    commands: VecDeque<SourceCommand<T>>,
    /// Demand signals recorded by the stream, consumed by `expect_request`.
    request_events: VecDeque<usize>,
    /// The stream was dropped before it consumed a terminal command.
    cancelled: bool,
    /// No further commands may be enqueued: a terminal command was queued, or
    /// the stream has been dropped.
    terminated: bool,
    /// The stream consumed a terminal (`Complete` / `Error`) command, so it
    /// ended on its own rather than being cancelled.
    finished: bool,
}

impl<T> Default for SourceProbeShared<T> {
    fn default() -> Self {
        Self {
            state: Mutex::new(SourceProbeState {
                commands: VecDeque::new(),
                request_events: VecDeque::new(),
                cancelled: false,
                terminated: false,
                finished: false,
            }),
            condvar: Condvar::new(),
        }
    }
}

impl<T> SourceProbeShared<T> {
    /// Queues `command` for the stream and wakes it. A `Complete` or `Error`
    /// command closes the probe to further commands.
    ///
    /// # Panics
    ///
    /// Panics if the probe is already terminated.
    fn enqueue(&self, command: SourceCommand<T>) {
        let mut state = lock_unpoison(&self.state);
        if state.terminated {
            panic_any("test source probe is already terminated");
        }
        let is_terminal = !matches!(command, SourceCommand::Next(_));
        state.commands.push_back(command);
        if is_terminal {
            state.terminated = true;
        }
        self.condvar.notify_all();
    }

    /// Queues `error` as a terminal command, unless the probe is already
    /// terminated.
    fn fail_if_open(&self, error: StreamError) {
        let mut state = lock_unpoison(&self.state);
        if state.terminated {
            return;
        }
        state.commands.push_back(SourceCommand::Error(error));
        state.terminated = true;
        self.condvar.notify_all();
    }

    /// Records one unit of downstream demand, unless the probe is terminated.
    fn record_demand(&self) {
        let mut state = lock_unpoison(&self.state);
        if state.terminated {
            return;
        }
        state.request_events.push_back(1);
        self.condvar.notify_all();
    }

    /// Blocks until a command is available and returns it. Returns `None` once
    /// the probe is terminated and nothing is left queued.
    fn next_command(&self) -> Option<SourceCommand<T>> {
        let mut state = lock_unpoison(&self.state);
        loop {
            if let Some(command) = state.commands.pop_front() {
                if matches!(command, SourceCommand::Complete | SourceCommand::Error(_)) {
                    state.terminated = true;
                    // The stream reached its end on its own, so a later drop is a
                    // normal shutdown and must not be reported as a cancellation.
                    state.finished = true;
                }
                return Some(command);
            }
            if state.terminated {
                return None;
            }
            state = wait_unpoison(&self.condvar, state);
        }
    }

    /// Waits for the next recorded demand signal and returns its amount.
    ///
    /// # Panics
    ///
    /// Panics if the stream was cancelled or already finished with no demand
    /// left to report, or if `timeout` elapses.
    fn expect_request(&self, timeout: Duration) -> usize {
        let deadline = Instant::now() + timeout;
        let mut state = lock_unpoison(&self.state);
        loop {
            if let Some(requested) = state.request_events.pop_front() {
                return requested;
            }
            if state.cancelled {
                panic_any("expected downstream demand, but the stream was cancelled");
            }
            // A finished stream never records demand again (`terminated` is set).
            if state.finished {
                panic_any(
                    "expected downstream demand, but the stream already completed or failed",
                );
            }
            state = wait_until(&self.condvar, state, deadline, "downstream demand");
        }
    }

    /// Waits until the stream is dropped without having finished on its own.
    ///
    /// # Panics
    ///
    /// Panics if the stream already completed or failed (it can then never be
    /// cancelled), or if `timeout` elapses.
    fn expect_cancellation(&self, timeout: Duration) {
        let deadline = Instant::now() + timeout;
        let mut state = lock_unpoison(&self.state);
        while !state.cancelled {
            if state.finished {
                panic_any(
                    "expected stream cancellation, but the stream already completed or failed",
                );
            }
            state = wait_until(&self.condvar, state, deadline, "stream cancellation");
        }
    }

    /// Called when the stream is dropped.
    ///
    /// The drop counts as a cancellation only if the stream had not consumed a
    /// terminal command. Either way the probe is terminated, so no further
    /// commands can be sent.
    fn mark_cancelled(&self) {
        let mut state = lock_unpoison(&self.state);
        if !state.finished {
            state.cancelled = true;
        }
        state.terminated = true;
        self.condvar.notify_all();
    }
}
/// Observation recorded by the probe-backed sink task.
enum SinkEvent<T> {
    /// An element arrived from upstream.
    Next(T),
    /// Upstream completed.
    Complete,
    /// Upstream failed.
    Error(StreamError),
}

/// State shared between a [`TestSubscriberProbe`] (test side) and the sink
/// task spawned by [`TestSink::probe`].
struct SinkProbeShared<T> {
    state: Mutex<SinkProbeState<T>>,
    /// Notified on every state change so blocked waiters re-check `state`.
    condvar: Condvar,
}

struct SinkProbeState<T> {
    /// Outstanding demand granted by the test and not yet used by the sink.
    requested: usize,
    /// Observations not yet consumed by an `expect_*` call.
    events: VecDeque<SinkEvent<T>>,
    /// Set once the test cancels the sink.
    cancelled: bool,
}

impl<T> Default for SinkProbeShared<T> {
    fn default() -> Self {
        let initial = SinkProbeState {
            requested: 0,
            events: VecDeque::new(),
            cancelled: false,
        };
        Self {
            state: Mutex::new(initial),
            condvar: Condvar::new(),
        }
    }
}

impl<T> SinkProbeShared<T> {
    /// Adds `n` to the outstanding demand (saturating) and wakes the sink task.
    fn request(&self, n: usize) {
        let mut guard = lock_unpoison(&self.state);
        guard.requested = guard.requested.saturating_add(n);
        self.condvar.notify_all();
    }

    /// Blocks until one unit of demand can be consumed. Returns
    /// `Err(StreamError::Cancelled)` as soon as the sink is cancelled.
    fn wait_for_request(&self) -> StreamResult<()> {
        let mut guard = lock_unpoison(&self.state);
        while !guard.cancelled {
            if let Some(left) = guard.requested.checked_sub(1) {
                guard.requested = left;
                return Ok(());
            }
            guard = wait_unpoison(&self.condvar, guard);
        }
        Err(StreamError::Cancelled)
    }

    /// Appends an observation and wakes any waiting `expect_*` call.
    fn push_event(&self, event: SinkEvent<T>) {
        let mut guard = lock_unpoison(&self.state);
        guard.events.push_back(event);
        self.condvar.notify_all();
    }

    /// Removes and returns the oldest observation, waiting up to `timeout`.
    ///
    /// # Panics
    ///
    /// Panics with a "timed out waiting for {expected}" message on timeout.
    fn expect_event(&self, timeout: Duration, expected: &str) -> SinkEvent<T> {
        let deadline = Instant::now() + timeout;
        let mut guard = lock_unpoison(&self.state);
        let event = loop {
            if let Some(event) = guard.events.pop_front() {
                break event;
            }
            guard = wait_until(&self.condvar, guard, deadline, expected);
        };
        event
    }

    /// Returns normally only if no observation is queued during `timeout`.
    ///
    /// # Panics
    ///
    /// Panics, naming the observation, if one is present or arrives in time.
    fn expect_no_message(&self, timeout: Duration) {
        let deadline = Instant::now() + timeout;
        let mut guard = lock_unpoison(&self.state);
        loop {
            if let Some(event) = guard.events.pop_front() {
                panic_any(format!(
                    "expected no stream message for {timeout:?}, got {}",
                    describe_event(&event)
                ));
            }
            let remaining = deadline.saturating_duration_since(Instant::now());
            if remaining.is_zero() {
                return;
            }
            let (next_guard, outcome) = wait_timeout_unpoison(&self.condvar, guard, remaining);
            guard = next_guard;
            if outcome.timed_out() && guard.events.is_empty() {
                return;
            }
        }
    }

    /// Marks the sink as cancelled and wakes the sink task.
    fn cancel(&self) {
        let mut guard = lock_unpoison(&self.state);
        guard.cancelled = true;
        self.condvar.notify_all();
    }
}
/// Waits on `condvar` until it is notified or `deadline` passes, then returns
/// the guard so the caller can re-check its condition.
///
/// Callers loop: check the condition and, if it does not hold yet, call this
/// again. A timed-out wait is *not* a failure by itself. The producer may
/// update the state between the timeout firing and this thread re-acquiring
/// the mutex, and the caller's next check will see that update. This function
/// only panics when the deadline has already passed on entry, i.e. the caller
/// re-checked after the deadline and still found nothing.
///
/// # Panics
///
/// Panics with `"timed out waiting for {expected}"` once `deadline` has passed.
fn wait_until<'a, T>(
    condvar: &Condvar,
    state: MutexGuard<'a, T>,
    deadline: Instant,
    expected: &str,
) -> MutexGuard<'a, T> {
    let remaining = deadline.saturating_duration_since(Instant::now());
    if remaining.is_zero() {
        panic_any(format!("timed out waiting for {expected}"));
    }
    // Whether this wait timed out or was woken (possibly spuriously), the
    // caller re-checks; the next call panics if the deadline is gone.
    let (state, _timed_out) = wait_timeout_unpoison(condvar, state, remaining);
    state
}
/// Locks `mutex` and returns its guard even if another thread panicked while
/// holding it. Probe assertions panic while the lock is held, and the shared
/// state must stay usable afterwards.
fn lock_unpoison<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
    match mutex.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Blocks on `condvar` with no time limit and returns the re-acquired guard,
/// ignoring mutex poisoning. May return on a spurious wakeup, so callers
/// re-check their condition.
fn wait_unpoison<'a, T>(condvar: &Condvar, guard: MutexGuard<'a, T>) -> MutexGuard<'a, T> {
    match condvar.wait(guard) {
        Ok(reacquired) => reacquired,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Blocks on `condvar` for at most `timeout`, ignoring mutex poisoning.
///
/// Returns the re-acquired guard together with the standard
/// [`std::sync::WaitTimeoutResult`], which tells whether the timeout elapsed.
fn wait_timeout_unpoison<'a, T>(
    condvar: &Condvar,
    guard: MutexGuard<'a, T>,
    timeout: Duration,
) -> (MutexGuard<'a, T>, std::sync::WaitTimeoutResult) {
    match condvar.wait_timeout(guard, timeout) {
        Ok(outcome) => outcome,
        Err(poisoned) => poisoned.into_inner(),
    }
}
fn describe_event<T>(event: &SinkEvent<T>) -> String {
match event {
SinkEvent::Next(_) => "next element".to_owned(),
SinkEvent::Complete => "stream completion".to_owned(),
SinkEvent::Error(error) => format!("stream error {error:?}"),
}
}
// Tests for the probes, run against graphs materialized by the stream crate.
#[cfg(test)]
mod tests {
use super::*;
use crate::{Keep, Materializer, Sink, Source};
use std::panic::{self, AssertUnwindSafe};
// Extracts the text of a caught panic payload. The probes panic with both
// `String` and `&'static str` payloads, so both are handled.
fn panic_message(payload: Box<dyn std::any::Any + Send>) -> String {
match payload.downcast::<String>() {
Ok(message) => *message,
Err(payload) => match payload.downcast::<&'static str>() {
Ok(message) => (*message).to_owned(),
Err(_) => "<non-string panic payload>".to_owned(),
},
}
}
// Drives a `map` stage with both probes. A sink request of 1 shows up as a
// source demand of 1, elements are transformed on the way through, and the
// completion reaches the sink.
#[test]
fn test_source_and_sink_probes_drive_map_and_completion() {
let materializer = Materializer::new();
let (source, sink) = TestSource::probe::<i32>()
.map(|value| value * 2)
.to_mat(TestSink::probe(), Keep::both)
.run_with_materializer(&materializer)
.expect("test graph materializes");
sink.request(1);
assert_eq!(source.expect_request(), 1);
source.send_next(2);
sink.assert_next(4);
sink.request(1);
assert_eq!(source.expect_request(), 1);
source.send_complete();
sink.expect_complete();
}
// `take(2)` on an iterator source: the sink sees exactly two elements, then
// completion once it requests again.
#[test]
fn test_sink_probe_validates_take_and_completion() {
let sink = Source::from_iter(1..=5)
.map(|value| value + 10)
.take(2)
.run_with(TestSink::probe())
.expect("test sink materializes");
sink.request(2);
sink.assert_next_n([11, 12]);
sink.request(1);
sink.expect_complete();
}
// An error sent through the source probe reaches the sink unchanged.
#[test]
fn test_source_probe_surfaces_stream_errors() {
let materializer = Materializer::new();
let (source, sink) = TestSource::probe::<i32>()
.to_mat(TestSink::probe(), Keep::both)
.run_with_materializer(&materializer)
.expect("test graph materializes");
sink.request(1);
assert_eq!(source.expect_request(), 1);
source.send_error(StreamError::Failed("boom".to_owned()));
assert_eq!(sink.expect_error(), StreamError::Failed("boom".to_owned()));
}
// Once `take(1)` has its element, the upstream source is expected to observe
// cancellation. The collecting sink here pulls without an explicit request.
#[test]
fn test_source_probe_observes_downstream_cancellation() {
let materializer = Materializer::new();
let (source, completion) = TestSource::probe::<i32>()
.take(1)
.to_mat(Sink::collect(), Keep::both)
.run_with_materializer(&materializer)
.expect("test graph materializes");
assert_eq!(source.expect_request(), 1);
source.send_next(7);
assert_eq!(completion.wait().expect("take collects one item"), vec![7]);
source.expect_cancellation();
}
// The sink task only pulls after a request, so even an empty source reports
// completion only after `request`.
#[test]
fn test_sink_probe_observes_empty_source_completion_after_request() {
let sink = Source::<i32>::empty()
.run_with(TestSink::probe())
.expect("test sink materializes");
sink.request(1);
sink.expect_complete();
}
// Same as above for a source that fails immediately.
#[test]
fn test_sink_probe_observes_failed_source_error_after_request() {
let sink = Source::<i32>::failed(StreamError::Failed("boom".to_owned()))
.run_with(TestSink::probe())
.expect("test sink materializes");
sink.request(1);
assert_eq!(sink.expect_error(), StreamError::Failed("boom".to_owned()));
}
// One blueprint materialized twice must yield independent probe pairs:
// traffic, demand and completion on one pair never show up on the other.
#[test]
fn test_testkit_blueprints_materialize_independent_probe_pairs() {
let blueprint = TestSource::probe::<i32>()
.map(|value| value * 10)
.to_mat(TestSink::probe(), Keep::both);
let materializer = Materializer::new();
let (source_a, sink_a) = blueprint
.run_with_materializer(&materializer)
.expect("first probe pair materializes");
let (source_b, sink_b) = blueprint
.run_with_materializer(&materializer)
.expect("second probe pair materializes");
sink_a.request(1);
assert_eq!(source_a.expect_request(), 1);
source_a.send_next(2);
sink_a.assert_next(20);
sink_b.expect_no_message(Duration::from_millis(25));
sink_b.request(1);
assert_eq!(source_b.expect_request(), 1);
source_b.send_next(3);
sink_b.assert_next(30);
sink_a.expect_no_message(Duration::from_millis(25));
sink_a.request(1);
assert_eq!(source_a.expect_request(), 1);
source_a.send_complete();
sink_a.expect_complete();
// Pair B has demand but nothing has been sent, so it must stay silent until
// its own source completes.
sink_b.request(1);
sink_b.expect_no_message(Duration::from_millis(25));
source_b.send_complete();
sink_b.expect_complete();
}
// The mismatch message of `assert_next` names the expected and actual values.
#[test]
fn test_assert_next_reports_expected_and_actual_values() {
let sink = Source::single(1)
.run_with(TestSink::probe())
.expect("test sink materializes");
sink.request(1);
let panic = panic::catch_unwind(AssertUnwindSafe(|| sink.assert_next(2)))
.expect_err("assert_next should panic on mismatch");
let message = panic_message(panic);
assert!(message.contains("expected next element 2, got 1"));
}
// After one element and no terminal signal, `expect_complete` times out. The
// short timeout keeps the test fast. `source` stays alive to the end of the
// test; dropping it would fail the stream and turn the timeout into an error.
#[test]
fn test_expect_complete_times_out_with_clear_message() {
let materializer = Materializer::new();
let (source, mut sink) = TestSource::probe::<i32>()
.to_mat(TestSink::probe(), Keep::both)
.run_with_materializer(&materializer)
.expect("test graph materializes");
sink.set_timeout(Duration::from_millis(50));
sink.request(1);
assert_eq!(source.expect_request(), 1);
source.send_next(1);
sink.assert_next(1);
let panic = panic::catch_unwind(AssertUnwindSafe(|| sink.expect_complete()))
.expect_err("expect_complete should panic on timeout");
let message = panic_message(panic);
assert!(message.contains("timed out waiting for stream completion"));
}
// With demand granted but nothing sent, `expect_next` times out. The binding
// is `_source`, not `_`, so the publisher probe is not dropped immediately;
// dropping it would fail the stream.
#[test]
fn test_expect_next_times_out_with_clear_message() {
let materializer = Materializer::new();
let (_source, mut sink) = TestSource::probe::<i32>()
.to_mat(TestSink::probe(), Keep::both)
.run_with_materializer(&materializer)
.expect("test graph materializes");
sink.set_timeout(Duration::from_millis(50));
sink.request(1);
let panic = panic::catch_unwind(AssertUnwindSafe(|| sink.expect_next()))
.expect_err("expect_next should panic on timeout");
let message = panic_message(panic);
assert!(message.contains("timed out waiting for next element"));
}
}