// buffered_io/asynch/write.rs

1use embedded_io_async::{Read, Write};
2
3use super::BypassError;
4
/// A buffered [`Write`]
///
/// The BufferedWrite will write into the provided buffer to avoid small writes to the inner writer.
/// Buffered bytes are handed to the inner writer once the buffer fills up, or when
/// [`Write::flush`] is called.
///
/// The `T: Write` requirement lives on the `impl` blocks rather than on the type itself.
pub struct BufferedWrite<'buf, T> {
    /// The writer that buffered bytes are eventually written to.
    inner: T,
    /// Backing storage for pending bytes; only `buf[..pos]` holds data.
    buf: &'buf mut [u8],
    /// Number of bytes at the start of `buf` not yet written to `inner`.
    pos: usize,
}
13
14impl<'buf, T: Write> BufferedWrite<'buf, T> {
15    /// Create a new buffered writer
16    pub fn new(inner: T, buf: &'buf mut [u8]) -> Self {
17        Self { inner, buf, pos: 0 }
18    }
19
20    /// Create a new buffered writer with a pre-polulated buffer
21    pub fn new_with_data(inner: T, buf: &'buf mut [u8], written: usize) -> Self {
22        Self {
23            inner,
24            buf,
25            pos: written,
26        }
27    }
28
29    /// Get whether there are any bytes currently buffered
30    pub fn is_empty(&self) -> bool {
31        self.pos == 0
32    }
33
34    /// Get the number of bytes that are currently buffered but not yet written to the inner writer
35    pub fn written(&self) -> usize {
36        self.pos
37    }
38
39    /// Clear the currently buffered, written bytes
40    pub fn clear(&mut self) {
41        self.pos = 0;
42    }
43
44    /// Get the inner writer if there are no currently buffered, written bytes
45    pub fn bypass(&mut self) -> Result<&mut T, BypassError> {
46        match self.pos {
47            0 => Ok(&mut self.inner),
48            _ => Err(BypassError),
49        }
50    }
51
52    /// Get the inner writer if there are no currently buffered, written bytes, and rent the buffer
53    pub fn bypass_with_buf(&mut self) -> Result<(&mut T, &mut [u8]), BypassError> {
54        match self.pos {
55            0 => Ok((&mut self.inner, self.buf)),
56            _ => Err(BypassError),
57        }
58    }
59
60    /// Split the writer to get the inner components
61    pub fn split(&mut self) -> (&mut T, &mut [u8], usize) {
62        (&mut self.inner, self.buf, self.pos)
63    }
64
65    /// Release and get the inner writer
66    pub fn release(self) -> T {
67        self.inner
68    }
69}
70
71impl<T: Write> embedded_io::ErrorType for BufferedWrite<'_, T> {
72    type Error = T::Error;
73}
74
75impl<T: Read + Write> Read for BufferedWrite<'_, T> {
76    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
77        self.inner.read(buf).await
78    }
79
80    async fn read_exact(
81        &mut self,
82        buf: &mut [u8],
83    ) -> Result<(), embedded_io::ReadExactError<Self::Error>> {
84        self.inner.read_exact(buf).await
85    }
86}
87
88impl<T: Write> Write for BufferedWrite<'_, T> {
89    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
90        if buf.is_empty() {
91            return Ok(0);
92        }
93        if self.pos == 0 && buf.len() >= self.buf.len() {
94            // Fast path - nothing in buffer and the buffer to write is large
95            return self.inner.write(buf).await;
96        }
97
98        let buffered = usize::min(buf.len(), self.buf.len() - self.pos);
99        assert!(buffered > 0);
100
101        let mut new_pos = self.pos;
102        self.buf[new_pos..new_pos + buffered].copy_from_slice(&buf[..buffered]);
103        new_pos += buffered;
104
105        if new_pos < self.buf.len() {
106            // The buffer to write could fit in the buffer
107            self.pos = new_pos;
108        } else {
109            // The buffer is full
110            let written = self.inner.write(self.buf).await?;
111
112            // We only assign self.pos _after_ we are sure that the write has completed successfully
113            if written < new_pos {
114                // We only partially wrote the inner buffer
115                self.buf.copy_within(written..new_pos, 0);
116                self.pos = new_pos - written;
117            } else {
118                self.pos = 0;
119            }
120        }
121
122        Ok(buffered)
123    }
124
125    async fn flush(&mut self) -> Result<(), Self::Error> {
126        if self.pos > 0 {
127            self.inner.write_all(&self.buf[..self.pos]).await?;
128            self.pos = 0;
129        }
130
131        self.inner.flush().await
132    }
133}
134
#[cfg(test)]
mod tests {
    use embedded_io::{Error, ErrorKind, ErrorType};

    use super::*;

    /// Scripted test writer: the n-th call to `write` accepts `writeable[n]` bytes, or fails
    /// when that entry is `0`. Accepted bytes are appended to `written`.
    #[derive(Default)]
    struct UnstableWrite {
        written: Vec<u8>,
        writes: usize,
        writeable: Vec<usize>,
    }

    /// Error reported by [`UnstableWrite`] for a scripted failure.
    #[derive(Debug)]
    struct UnstableError;

    impl core::fmt::Display for UnstableError {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            f.write_str("UnstableError")
        }
    }

    impl std::error::Error for UnstableError {}

    impl Error for UnstableError {
        fn kind(&self) -> ErrorKind {
            ErrorKind::Other
        }
    }

    impl ErrorType for UnstableWrite {
        type Error = UnstableError;
    }

    impl Write for UnstableWrite {
        async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
            let budget = self.writeable[self.writes];
            self.writes += 1;
            match budget {
                0 => Err(UnstableError),
                n => {
                    self.written.extend_from_slice(&buf[..n]);
                    Ok(n)
                }
            }
        }

        async fn flush(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn can_append_to_buffer() {
        let mut sink = Vec::new();
        let mut storage = [0; 8];
        let mut writer = BufferedWrite::new(&mut sink, &mut storage);

        assert_eq!(2, writer.write(&[1, 2]).await.unwrap());
        assert_eq!(2, writer.pos);
        assert!(writer.inner.is_empty());

        assert_eq!(2, writer.write(&[3, 4]).await.unwrap());
        assert_eq!(4, writer.pos);
        assert!(writer.inner.is_empty());

        // Exactly filling the buffer pushes everything through to the inner writer.
        assert_eq!(4, writer.write(&[5, 6, 7, 8]).await.unwrap());
        assert_eq!(0, writer.pos);
        assert_eq!(8, writer.inner.len());
        assert_eq!(&[1, 2, 3, 4, 5, 6, 7, 8], writer.inner.as_slice());
    }

    #[tokio::test]
    async fn bypass_large_write_when_empty() {
        let mut sink = Vec::new();
        let mut storage = [0; 8];
        let mut writer = BufferedWrite::new(&mut sink, &mut storage);

        assert_eq!(8, writer.write(&[1, 2, 3, 4, 5, 6, 7, 8]).await.unwrap());
        assert_eq!(0, writer.pos);
        assert_eq!(8, writer.inner.len());
    }

    #[tokio::test]
    async fn large_write_when_not_empty() {
        let mut sink = Vec::new();
        let mut storage = [0; 8];
        let mut writer = BufferedWrite::new(&mut sink, &mut storage);

        assert_eq!(1, writer.write(&[1]).await.unwrap());
        assert_eq!(1, writer.pos);
        assert!(writer.inner.is_empty());

        // Only the 7 bytes that fit are accepted; the full buffer is then written out.
        assert_eq!(7, writer.write(&[2, 3, 4, 5, 6, 7, 8, 9]).await.unwrap());
        assert_eq!(0, writer.pos);
        assert_eq!(8, writer.inner.len());
    }

    #[tokio::test]
    async fn large_write_when_not_empty_can_handle_write_errors() {
        // The first inner write fails, the second accepts all 8 bytes.
        let mut sink = UnstableWrite {
            writeable: vec![0, 8],
            ..Default::default()
        };
        let mut storage = [0; 8];
        let mut writer = BufferedWrite::new(&mut sink, &mut storage);

        assert_eq!(1, writer.write(&[1]).await.unwrap());
        assert_eq!(1, writer.pos);
        assert!(writer.inner.written.is_empty());

        assert!(writer.write(&[2, 3, 4, 5, 6, 7, 8]).await.is_err());

        assert_eq!(7, writer.write(&[2, 3, 4, 5, 6, 7, 8]).await.unwrap());
        assert_eq!(0, writer.pos);
        assert_eq!(8, writer.inner.written.len());
    }

    #[tokio::test]
    async fn flush_clears_buffer() {
        let mut sink = Vec::new();
        let mut storage = [0; 8];
        let mut writer = BufferedWrite::new(&mut sink, &mut storage);

        assert_eq!(2, writer.write(&[1, 2]).await.unwrap());
        assert_eq!(2, writer.pos);
        assert!(writer.inner.is_empty());

        writer.flush().await.unwrap();
        assert_eq!(0, writer.pos);
        assert_eq!(2, writer.inner.len());
    }
}