Skip to main content

dev_chaos/
async_io.rs

//! Async IO wrappers. Available with the `async-io` feature.
//!
//! Mirror of [`crate::io`] for `tokio::io::AsyncRead` / `AsyncWrite`.
//! Pulls in `tokio` minimally (no runtime, no networking, no scheduler).
//!
//! Schedules are shared with the sync wrappers: a [`FailureSchedule`]
//! built once can be used in either context.
8
9use std::io;
10use std::pin::Pin;
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::task::{Context, Poll};
13
14use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
15
16use crate::FailureSchedule;
17
18/// Async equivalent of [`crate::io::ChaosReader`].
19///
20/// # Example (ignored: requires a tokio runtime)
21///
22/// ```ignore
23/// use dev_chaos::{async_io::AsyncChaosReader, FailureMode, FailureSchedule};
24/// use tokio::io::AsyncReadExt;
25///
26/// # async fn run() {
27/// let data: &[u8] = b"hello";
28/// let schedule = FailureSchedule::on_attempts(&[2], FailureMode::IoError);
29/// let mut reader = AsyncChaosReader::new(data, schedule);
30/// let mut buf = [0u8; 1];
31/// reader.read(&mut buf).await.unwrap();
32/// # }
33/// ```
34pub struct AsyncChaosReader<R: AsyncRead + Unpin> {
35    inner: R,
36    schedule: FailureSchedule,
37    attempt: AtomicUsize,
38}
39
40impl<R: AsyncRead + Unpin> AsyncChaosReader<R> {
41    /// Wrap `inner` with the given schedule.
42    pub fn new(inner: R, schedule: FailureSchedule) -> Self {
43        Self {
44            inner,
45            schedule,
46            attempt: AtomicUsize::new(0),
47        }
48    }
49
50    /// Number of poll attempts so far.
51    pub fn attempt_count(&self) -> usize {
52        self.attempt.load(Ordering::Relaxed)
53    }
54}
55
56impl<R: AsyncRead + Unpin> AsyncRead for AsyncChaosReader<R> {
57    fn poll_read(
58        mut self: Pin<&mut Self>,
59        cx: &mut Context<'_>,
60        buf: &mut ReadBuf<'_>,
61    ) -> Poll<io::Result<()>> {
62        let n = self.attempt.fetch_add(1, Ordering::Relaxed) + 1;
63        if let Err(f) = self.schedule.maybe_fail(n) {
64            return Poll::Ready(Err(f.into()));
65        }
66        Pin::new(&mut self.inner).poll_read(cx, buf)
67    }
68}
69
70/// Async equivalent of [`crate::io::ChaosWriter`].
71///
72/// # Example (ignored: requires a tokio runtime)
73///
74/// ```ignore
75/// use dev_chaos::{async_io::AsyncChaosWriter, FailureMode, FailureSchedule};
76/// use tokio::io::AsyncWriteExt;
77///
78/// # async fn run() {
79/// let mut sink: Vec<u8> = Vec::new();
80/// let schedule = FailureSchedule::on_attempts(&[2], FailureMode::Timeout);
81/// let mut writer = AsyncChaosWriter::new(&mut sink, schedule);
82/// writer.write_all(b"a").await.unwrap();
83/// # }
84/// ```
85pub struct AsyncChaosWriter<W: AsyncWrite + Unpin> {
86    inner: W,
87    schedule: FailureSchedule,
88    attempt: AtomicUsize,
89}
90
91impl<W: AsyncWrite + Unpin> AsyncChaosWriter<W> {
92    /// Wrap `inner` with the given schedule.
93    pub fn new(inner: W, schedule: FailureSchedule) -> Self {
94        Self {
95            inner,
96            schedule,
97            attempt: AtomicUsize::new(0),
98        }
99    }
100
101    /// Number of write attempts so far.
102    pub fn attempt_count(&self) -> usize {
103        self.attempt.load(Ordering::Relaxed)
104    }
105}
106
107impl<W: AsyncWrite + Unpin> AsyncWrite for AsyncChaosWriter<W> {
108    fn poll_write(
109        mut self: Pin<&mut Self>,
110        cx: &mut Context<'_>,
111        buf: &[u8],
112    ) -> Poll<io::Result<usize>> {
113        let n = self.attempt.fetch_add(1, Ordering::Relaxed) + 1;
114        if let Err(f) = self.schedule.maybe_fail(n) {
115            return Poll::Ready(Err(f.into()));
116        }
117        Pin::new(&mut self.inner).poll_write(cx, buf)
118    }
119
120    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
121        Pin::new(&mut self.inner).poll_flush(cx)
122    }
123
124    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
125        Pin::new(&mut self.inner).poll_shutdown(cx)
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FailureMode;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    #[tokio::test(flavor = "current_thread")]
    async fn async_reader_passes_through_then_fails() {
        let data: Vec<u8> = b"hello".to_vec();
        let cursor = std::io::Cursor::new(data);
        let schedule = FailureSchedule::on_attempts(&[2], FailureMode::Timeout);
        let mut reader = AsyncChaosReader::new(cursor, schedule);
        let mut buf = [0u8; 1];
        reader.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"h");
        let err = reader.read_exact(&mut buf).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
        assert_eq!(reader.attempt_count(), 2);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn async_reader_injected_failure_consumes_no_input() {
        let cursor = std::io::Cursor::new(b"hello".to_vec());
        let schedule = FailureSchedule::on_attempts(&[2], FailureMode::Timeout);
        let mut reader = AsyncChaosReader::new(cursor, schedule);
        let mut buf = [0u8; 1];
        reader.read_exact(&mut buf).await.unwrap();
        reader.read_exact(&mut buf).await.unwrap_err();
        // The failed attempt never reached the cursor, so attempt 3 (not
        // scheduled to fail) still sees the second byte.
        reader.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"e");
        assert_eq!(reader.attempt_count(), 3);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn async_writer_writes_then_fails() {
        let sink: Vec<u8> = Vec::new();
        let schedule = FailureSchedule::on_attempts(&[2], FailureMode::ConnectionReset);
        let mut writer = AsyncChaosWriter::new(sink, schedule);
        writer.write_all(b"a").await.unwrap();
        let err = writer.write_all(b"b").await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::ConnectionReset);
        assert_eq!(writer.attempt_count(), 2);
        let sink = writer.inner;
        assert_eq!(sink, b"a");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn async_writer_recovers_after_injected_failure() {
        let schedule = FailureSchedule::on_attempts(&[2], FailureMode::ConnectionReset);
        let mut writer = AsyncChaosWriter::new(Vec::<u8>::new(), schedule);
        writer.write_all(b"a").await.unwrap();
        writer.write_all(b"b").await.unwrap_err();
        writer.write_all(b"c").await.unwrap();
        assert_eq!(writer.attempt_count(), 3);
        // The rejected "b" never reached the sink.
        assert_eq!(writer.inner, b"ac");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn async_writer_flush_and_shutdown_bypass_schedule() {
        // Attempt 1 is scheduled to fail, but flush/shutdown never consult it.
        let schedule = FailureSchedule::on_attempts(&[1], FailureMode::ConnectionReset);
        let mut writer = AsyncChaosWriter::new(Vec::<u8>::new(), schedule);
        writer.flush().await.unwrap();
        writer.shutdown().await.unwrap();
        assert_eq!(writer.attempt_count(), 0);
    }
}