futures_testing/
driver.rs1use core::future::Future;
2use core::marker::PhantomData;
3use core::ops::{AsyncFnMut, ControlFlow};
4use core::pin::pin;
5use core::task::Context;
6use std::{
7 pin::Pin,
8 task::{Poll, Waker},
9};
10
11use arbitrary::{Arbitrary, Unstructured};
12use futures_util::Sink;
13
14use crate::Driver;
15
16pin_project_lite::pin_project!(
17 pub struct PollFnDriver<F, A> {
19 f: F,
20 _arg: PhantomData<A>,
21 }
22);
23
24pub fn drive_poll_fn<A, F>(f: F) -> PollFnDriver<F, A>
34where
35 A: for<'a> Arbitrary<'a>,
36 F: FnMut(A) -> Poll<ControlFlow<()>>,
37{
38 PollFnDriver {
39 f,
40 _arg: PhantomData,
41 }
42}
43
44impl<'a, A, F> Driver<'a> for PollFnDriver<F, A>
45where
46 A: Arbitrary<'a>,
47 F: FnMut(A) -> Poll<ControlFlow<()>>,
48{
49 fn poll(
50 self: Pin<&mut Self>,
51 args: &mut Unstructured<'a>,
52 ) -> arbitrary::Result<Poll<ControlFlow<()>>> {
53 Ok((self.project().f)(args.arbitrary()?))
54 }
55}
56
/// [`Driver`] that runs an async closure, generating a fresh argument for it
/// on every call to `poll`.
///
/// Construct it with [`drive_fn`].
pub struct AsyncFnDriver<F, A> {
    // `fn(A)`: arguments are only consumed, never stored, so auto traits and
    // drop check do not depend on `A`.
    _arg: PhantomData<fn(A)>,
    // The async closure; each `poll` call invokes it once.
    f: F,
}
62
63pub fn drive_fn<A, F>(f: F) -> AsyncFnDriver<F, A>
71where
72 A: for<'a> Arbitrary<'a>,
73 F: AsyncFnMut(A) -> ControlFlow<()>,
74{
75 AsyncFnDriver {
76 f,
77 _arg: PhantomData,
78 }
79}
80
81impl<'a, A, F> Driver<'a> for AsyncFnDriver<F, A>
82where
83 A: Arbitrary<'a>,
84 F: AsyncFnMut(A) -> ControlFlow<()> + Unpin,
85{
86 fn poll(
87 self: Pin<&mut Self>,
88 args: &mut Unstructured<'a>,
89 ) -> arbitrary::Result<Poll<ControlFlow<()>>> {
90 let this = self.get_mut();
91 let cx = &mut Context::from_waker(Waker::noop());
92 let arg: A = args.arbitrary()?;
93 let mut fut = pin!((this.f)(arg));
94 match fut.as_mut().poll(cx) {
95 Poll::Ready(cf) => Ok(Poll::Ready(cf)),
96 Poll::Pending => Ok(Poll::Pending),
97 }
98 }
99}
100
101pub fn drive_sink<A, S>(sink: S) -> SinkDriver<S, A>
107where
108 A: for<'a> Arbitrary<'a>,
109 S: Sink<A, Error: std::fmt::Debug>,
110{
111 SinkDriver {
112 sink,
113 closing: false,
114 closed: false,
115 _arg: PhantomData,
116 }
117}
118
119pin_project_lite::pin_project!(
120 pub struct SinkDriver<S, A> {
121 #[pin]
122 sink: S,
123 closing: bool,
124 closed: bool,
125 _arg: PhantomData<A>,
126 }
127);
128
129impl<'a, S, A> Driver<'a> for SinkDriver<S, A>
130where
131 A: Arbitrary<'a>,
132 S: Sink<A, Error: std::fmt::Debug>,
133{
134 fn poll(
135 self: Pin<&mut Self>,
136 args: &mut Unstructured<'a>,
137 ) -> arbitrary::Result<Poll<ControlFlow<()>>> {
138 let mut this = self.project();
139 let mut cx = Context::from_waker(Waker::noop());
140
141 if *this.closed {
142 return Ok(Poll::Ready(ControlFlow::Break(())));
143 }
144
145 *this.closing = *this.closing || args.ratio(1u8, 255u8)?;
147 if *this.closing {
148 if let Poll::Ready(res) = this.sink.poll_close(&mut cx) {
149 res.unwrap();
150 *this.closed = true;
151 return Ok(Poll::Ready(ControlFlow::Break(())));
152 }
153 } else {
154 let Poll::Ready(res) = this.sink.as_mut().poll_ready(&mut cx) else {
155 return Ok(Poll::Pending);
156 };
157 res.unwrap();
158
159 this.sink.as_mut().start_send(args.arbitrary()?).unwrap();
160 }
161
162 Ok(Poll::Pending)
163 }
164}