mock_io/mock_stream/
futures.rs

1use std::{
2    future::Future,
3    io,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use async_channel::{unbounded, Receiver, Sender};
9use futures_io::{AsyncRead, AsyncWrite};
10use pin_project_lite::pin_project;
11
12use crate::{error::Error, futures::Handle};
13
// Evaluates a `Poll` expression: yields the value inside `Poll::Ready`, or makes the
// enclosing function return `Poll::Pending` immediately. Same behavior as
// `std::task::ready!`; a trailing comma after the expression is accepted.
macro_rules! ready {
    ($poll:expr $(,)?) => {
        match $poll {
            ::std::task::Poll::Pending => return ::std::task::Poll::Pending,
            ::std::task::Poll::Ready(value) => value,
        }
    };
}
22
23pin_project! {
24    /// Asynchronous mock IO stream
25    #[derive(Debug)]
26    pub struct MockStream {
27        #[pin]
28        read_half: ReadHalf,
29        #[pin]
30        write_half: WriteHalf,
31    }
32}
33
34impl MockStream {
35    /// Connects to a mock IO listener
36    pub async fn connect(handle: &Handle) -> Result<Self, Error> {
37        let (stream_1, stream_2) = Self::pair();
38        handle.send(stream_2).await?;
39        Ok(stream_1)
40    }
41
42    /// Creates a pair of connected mock streams
43    pub fn pair() -> (Self, Self) {
44        let (sender_1, receiver_1) = unbounded();
45        let (sender_2, receiver_2) = unbounded();
46
47        let stream_1 = Self {
48            read_half: ReadHalf {
49                receiver: receiver_1,
50                remaining: Default::default(),
51            },
52            write_half: WriteHalf { sender: sender_2 },
53        };
54
55        let stream_2 = Self {
56            read_half: ReadHalf {
57                receiver: receiver_2,
58                remaining: Default::default(),
59            },
60            write_half: WriteHalf { sender: sender_1 },
61        };
62
63        (stream_1, stream_2)
64    }
65
66    /// Splits the stream into separate read and write halves
67    pub fn split(self) -> (ReadHalf, WriteHalf) {
68        (self.read_half, self.write_half)
69    }
70}
71
72impl AsyncRead for MockStream {
73    fn poll_read(
74        self: Pin<&mut Self>,
75        cx: &mut Context<'_>,
76        buf: &mut [u8],
77    ) -> Poll<io::Result<usize>> {
78        self.project().read_half.poll_read(cx, buf)
79    }
80}
81
82impl AsyncWrite for MockStream {
83    fn poll_write(
84        self: Pin<&mut Self>,
85        cx: &mut Context<'_>,
86        buf: &[u8],
87    ) -> Poll<io::Result<usize>> {
88        self.project().write_half.poll_write(cx, buf)
89    }
90
91    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
92        self.project().write_half.poll_flush(cx)
93    }
94
95    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
96        self.project().write_half.poll_close(cx)
97    }
98}
99
100/// Read half of asynchronous mock IO stream
101#[derive(Debug)]
102pub struct ReadHalf {
103    receiver: Receiver<Vec<u8>>,
104    remaining: Vec<u8>,
105}
106
107impl ReadHalf {
108    async fn receive(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
109        let available_space = buf.len();
110
111        if self.remaining.is_empty() {
112            self.remaining = self.receiver.recv().await?;
113        }
114
115        let remaining_len = self.remaining.len();
116
117        if remaining_len > available_space {
118            buf.copy_from_slice(&self.remaining[..available_space]);
119            self.remaining = self.remaining[available_space..].to_vec();
120
121            Ok(available_space)
122        } else {
123            buf[..remaining_len].copy_from_slice(&self.remaining);
124            self.remaining = Default::default();
125
126            Ok(remaining_len)
127        }
128    }
129}
130
131impl AsyncRead for ReadHalf {
132    fn poll_read(
133        self: Pin<&mut Self>,
134        cx: &mut Context<'_>,
135        buf: &mut [u8],
136    ) -> Poll<io::Result<usize>> {
137        let mut future = Box::pin(self.get_mut().receive(buf));
138        let result = ready!(future.as_mut().poll(cx));
139
140        Poll::Ready(result.map_err(Into::into))
141    }
142}
143
144/// Write half of asynchronous mock IO stream
145#[derive(Debug, Clone)]
146pub struct WriteHalf {
147    sender: Sender<Vec<u8>>,
148}
149
150impl WriteHalf {
151    async fn send(&mut self, bytes: &[u8]) -> Result<usize, Error> {
152        self.sender
153            .send(bytes.to_vec())
154            .await
155            .map(|_| bytes.len())
156            .map_err(Into::into)
157    }
158}
159
160impl AsyncWrite for WriteHalf {
161    fn poll_write(
162        self: Pin<&mut Self>,
163        cx: &mut Context<'_>,
164        buf: &[u8],
165    ) -> Poll<io::Result<usize>> {
166        let mut future = Box::pin(self.get_mut().send(buf));
167        let result = ready!(future.as_mut().poll(cx));
168
169        Poll::Ready(result.map_err(Into::into))
170    }
171
172    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
173        Poll::Ready(Ok(()))
174    }
175
176    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
177        let _ = self.sender.close();
178        Poll::Ready(Ok(()))
179    }
180}