Skip to main content

futures_testing/
driver.rs

1use core::future::Future;
2use core::marker::PhantomData;
3use core::ops::{AsyncFnMut, ControlFlow};
4use core::pin::pin;
5use core::task::Context;
6use std::{
7    pin::Pin,
8    task::{Poll, Waker},
9};
10
11use arbitrary::{Arbitrary, Unstructured};
12use futures_util::Sink;
13
14use crate::Driver;
15
pin_project_lite::pin_project!(
    /// [`Driver`] that hands each arbitrary argument to a synchronous,
    /// poll-style closure.
    ///
    /// See [`drive_poll_fn`]
    pub struct PollFnDriver<F, A> {
        // Closure invoked once per `Driver::poll` call.
        f: F,
        // Zero-sized marker naming the argument type `A`; no `A` is stored.
        _arg: PhantomData<A>,
    }
);
23
24/// Construct a [`Driver`] from a synchronous [`FnMut`].
25///
26/// Use this when the driver logic doesn't need `.await` (e.g. `try_send` on a
27/// channel).
28///
29/// The function receives an arbitrary argument and returns `Poll<ControlFlow<()>>`:
30/// - `Poll::Ready(ControlFlow::Continue(()))` - progress made
31/// - `Poll::Ready(ControlFlow::Break(()))` - driver is done
32/// - `Poll::Pending` - no progress
33pub fn drive_poll_fn<A, F>(f: F) -> PollFnDriver<F, A>
34where
35    A: for<'a> Arbitrary<'a>,
36    F: FnMut(A) -> Poll<ControlFlow<()>>,
37{
38    PollFnDriver {
39        f,
40        _arg: PhantomData,
41    }
42}
43
44impl<'a, A, F> Driver<'a> for PollFnDriver<F, A>
45where
46    A: Arbitrary<'a>,
47    F: FnMut(A) -> Poll<ControlFlow<()>>,
48{
49    fn poll(
50        self: Pin<&mut Self>,
51        args: &mut Unstructured<'a>,
52    ) -> arbitrary::Result<Poll<ControlFlow<()>>> {
53        Ok((self.project().f)(args.arbitrary()?))
54    }
55}
56
/// [`Driver`] that hands each arbitrary argument to an async closure and
/// polls the resulting future once.
///
/// See [`drive_fn`]
pub struct AsyncFnDriver<F, A> {
    // Async closure invoked once per `Driver::poll` call.
    f: F,
    // Zero-sized marker naming the argument type `A`; no `A` is stored.
    _arg: PhantomData<fn(A)>,
}
62
63/// Construct a [`Driver`] from an [`AsyncFnMut`].
64///
65/// Use this when the driver needs `.await` (e.g. `tx.send(item).await`).
66///
67/// The async function receives an arbitrary argument and returns `ControlFlow<()>`:
68/// - `ControlFlow::Continue(())` - progress made, future should be polled
69/// - `ControlFlow::Break(())` - driver is done
70pub fn drive_fn<A, F>(f: F) -> AsyncFnDriver<F, A>
71where
72    A: for<'a> Arbitrary<'a>,
73    F: AsyncFnMut(A) -> ControlFlow<()>,
74{
75    AsyncFnDriver {
76        f,
77        _arg: PhantomData,
78    }
79}
80
81impl<'a, A, F> Driver<'a> for AsyncFnDriver<F, A>
82where
83    A: Arbitrary<'a>,
84    F: AsyncFnMut(A) -> ControlFlow<()> + Unpin,
85{
86    fn poll(
87        self: Pin<&mut Self>,
88        args: &mut Unstructured<'a>,
89    ) -> arbitrary::Result<Poll<ControlFlow<()>>> {
90        let this = self.get_mut();
91        let cx = &mut Context::from_waker(Waker::noop());
92        let arg: A = args.arbitrary()?;
93        let mut fut = pin!((this.f)(arg));
94        match fut.as_mut().poll(cx) {
95            Poll::Ready(cf) => Ok(Poll::Ready(cf)),
96            Poll::Pending => Ok(Poll::Pending),
97        }
98    }
99}
100
101/// Construct a [`Driver`] from a [`Sink`].
102///
103/// Use this when the driver side already implements [`Sink`] (e.g. the sender
104/// half of a `futures::channel::mpsc`). Handles `poll_ready`, `start_send`,
105/// and `poll_close` automatically.
106pub fn drive_sink<A, S>(sink: S) -> SinkDriver<S, A>
107where
108    A: for<'a> Arbitrary<'a>,
109    S: Sink<A, Error: std::fmt::Debug>,
110{
111    SinkDriver {
112        sink,
113        closing: false,
114        closed: false,
115        _arg: PhantomData,
116    }
117}
118
pin_project_lite::pin_project!(
    /// [`Driver`] that feeds arbitrary items into a [`Sink`] and eventually
    /// closes it.
    ///
    /// See [`drive_sink`]
    pub struct SinkDriver<S, A> {
        // The sink being driven; pinned because `Sink` methods take
        // `Pin<&mut Self>`.
        #[pin]
        sink: S,
        // Set once the driver has decided to close the sink; from then on
        // only `poll_close` is called.
        closing: bool,
        // Set once `poll_close` has completed; every later poll reports
        // `Break`.
        closed: bool,
        // Zero-sized marker naming the item type `A`; no `A` is stored.
        _arg: PhantomData<A>,
    }
);
128
129impl<'a, S, A> Driver<'a> for SinkDriver<S, A>
130where
131    A: Arbitrary<'a>,
132    S: Sink<A, Error: std::fmt::Debug>,
133{
134    fn poll(
135        self: Pin<&mut Self>,
136        args: &mut Unstructured<'a>,
137    ) -> arbitrary::Result<Poll<ControlFlow<()>>> {
138        let mut this = self.project();
139        let mut cx = Context::from_waker(Waker::noop());
140
141        if *this.closed {
142            return Ok(Poll::Ready(ControlFlow::Break(())));
143        }
144
145        // rare: close the sink
146        *this.closing = *this.closing || args.ratio(1u8, 255u8)?;
147        if *this.closing {
148            if let Poll::Ready(res) = this.sink.poll_close(&mut cx) {
149                res.unwrap();
150                *this.closed = true;
151                return Ok(Poll::Ready(ControlFlow::Break(())));
152            }
153        } else {
154            let Poll::Ready(res) = this.sink.as_mut().poll_ready(&mut cx) else {
155                return Ok(Poll::Pending);
156            };
157            res.unwrap();
158
159            this.sink.as_mut().start_send(args.arbitrary()?).unwrap();
160        }
161
162        Ok(Poll::Pending)
163    }
164}