1use std::io;
10use std::pin::Pin;
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::task::{Context, Poll};
13
14use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
15
16use crate::FailureSchedule;
17
18pub struct AsyncChaosReader<R: AsyncRead + Unpin> {
35 inner: R,
36 schedule: FailureSchedule,
37 attempt: AtomicUsize,
38}
39
40impl<R: AsyncRead + Unpin> AsyncChaosReader<R> {
41 pub fn new(inner: R, schedule: FailureSchedule) -> Self {
43 Self {
44 inner,
45 schedule,
46 attempt: AtomicUsize::new(0),
47 }
48 }
49
50 pub fn attempt_count(&self) -> usize {
52 self.attempt.load(Ordering::Relaxed)
53 }
54}
55
56impl<R: AsyncRead + Unpin> AsyncRead for AsyncChaosReader<R> {
57 fn poll_read(
58 mut self: Pin<&mut Self>,
59 cx: &mut Context<'_>,
60 buf: &mut ReadBuf<'_>,
61 ) -> Poll<io::Result<()>> {
62 let n = self.attempt.fetch_add(1, Ordering::Relaxed) + 1;
63 if let Err(f) = self.schedule.maybe_fail(n) {
64 return Poll::Ready(Err(f.into()));
65 }
66 Pin::new(&mut self.inner).poll_read(cx, buf)
67 }
68}
69
70pub struct AsyncChaosWriter<W: AsyncWrite + Unpin> {
86 inner: W,
87 schedule: FailureSchedule,
88 attempt: AtomicUsize,
89}
90
91impl<W: AsyncWrite + Unpin> AsyncChaosWriter<W> {
92 pub fn new(inner: W, schedule: FailureSchedule) -> Self {
94 Self {
95 inner,
96 schedule,
97 attempt: AtomicUsize::new(0),
98 }
99 }
100
101 pub fn attempt_count(&self) -> usize {
103 self.attempt.load(Ordering::Relaxed)
104 }
105}
106
107impl<W: AsyncWrite + Unpin> AsyncWrite for AsyncChaosWriter<W> {
108 fn poll_write(
109 mut self: Pin<&mut Self>,
110 cx: &mut Context<'_>,
111 buf: &[u8],
112 ) -> Poll<io::Result<usize>> {
113 let n = self.attempt.fetch_add(1, Ordering::Relaxed) + 1;
114 if let Err(f) = self.schedule.maybe_fail(n) {
115 return Poll::Ready(Err(f.into()));
116 }
117 Pin::new(&mut self.inner).poll_write(cx, buf)
118 }
119
120 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
121 Pin::new(&mut self.inner).poll_flush(cx)
122 }
123
124 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
125 Pin::new(&mut self.inner).poll_shutdown(cx)
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FailureMode;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// The first one-byte read reaches the cursor; the second attempt is
    /// scheduled to fail and must surface as `TimedOut`.
    #[tokio::test(flavor = "current_thread")]
    async fn async_reader_passes_through_then_fails() {
        let source = io::Cursor::new(b"hello".to_vec());
        let mut reader = AsyncChaosReader::new(
            source,
            FailureSchedule::on_attempts(&[2], FailureMode::Timeout),
        );
        let mut byte = [0u8; 1];

        reader.read_exact(&mut byte).await.unwrap();

        let failure = reader.read_exact(&mut byte).await.unwrap_err();
        assert_eq!(failure.kind(), io::ErrorKind::TimedOut);
    }

    /// The first write lands in the sink; the second attempt is scheduled to
    /// fail with `ConnectionReset` and must leave the sink untouched.
    #[tokio::test(flavor = "current_thread")]
    async fn async_writer_writes_then_fails() {
        let mut writer = AsyncChaosWriter::new(
            Vec::<u8>::new(),
            FailureSchedule::on_attempts(&[2], FailureMode::ConnectionReset),
        );

        writer.write_all(b"a").await.unwrap();

        let failure = writer.write_all(b"b").await.unwrap_err();
        assert_eq!(failure.kind(), io::ErrorKind::ConnectionReset);

        // Take the sink back out of the wrapper to inspect what was written.
        let AsyncChaosWriter { inner: written, .. } = writer;
        assert_eq!(written, b"a");
    }
}