1use embedded_io_async::{Read, Write};
2
3use super::BypassError;
4
/// Write adapter that stages outgoing bytes in a caller-provided buffer and
/// forwards them to the wrapped writer `T` once the buffer fills up or is
/// flushed.
///
/// The `T: Write` requirement lives on the `impl` blocks rather than on the
/// type itself, so merely naming the type places no bound on `T`.
pub struct BufferedWrite<'buf, T> {
    /// The wrapped writer that ultimately receives the data.
    inner: T,
    /// Staging area; only `buf[..pos]` holds pending bytes.
    buf: &'buf mut [u8],
    /// Number of bytes at the start of `buf` not yet passed to `inner`.
    pos: usize,
}
13
14impl<'buf, T: Write> BufferedWrite<'buf, T> {
15 pub fn new(inner: T, buf: &'buf mut [u8]) -> Self {
17 Self { inner, buf, pos: 0 }
18 }
19
20 pub fn new_with_data(inner: T, buf: &'buf mut [u8], written: usize) -> Self {
22 Self {
23 inner,
24 buf,
25 pos: written,
26 }
27 }
28
29 pub fn is_empty(&self) -> bool {
31 self.pos == 0
32 }
33
34 pub fn written(&self) -> usize {
36 self.pos
37 }
38
39 pub fn clear(&mut self) {
41 self.pos = 0;
42 }
43
44 pub fn bypass(&mut self) -> Result<&mut T, BypassError> {
46 match self.pos {
47 0 => Ok(&mut self.inner),
48 _ => Err(BypassError),
49 }
50 }
51
52 pub fn bypass_with_buf(&mut self) -> Result<(&mut T, &mut [u8]), BypassError> {
54 match self.pos {
55 0 => Ok((&mut self.inner, self.buf)),
56 _ => Err(BypassError),
57 }
58 }
59
60 pub fn split(&mut self) -> (&mut T, &mut [u8], usize) {
62 (&mut self.inner, self.buf, self.pos)
63 }
64
65 pub fn release(self) -> T {
67 self.inner
68 }
69}
70
71impl<T: Write> embedded_io::ErrorType for BufferedWrite<'_, T> {
72 type Error = T::Error;
73}
74
75impl<T: Read + Write> Read for BufferedWrite<'_, T> {
76 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
77 self.inner.read(buf).await
78 }
79
80 async fn read_exact(
81 &mut self,
82 buf: &mut [u8],
83 ) -> Result<(), embedded_io::ReadExactError<Self::Error>> {
84 self.inner.read_exact(buf).await
85 }
86}
87
88impl<T: Write> Write for BufferedWrite<'_, T> {
89 async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
90 if buf.is_empty() {
91 return Ok(0);
92 }
93 if self.pos == 0 && buf.len() >= self.buf.len() {
94 return self.inner.write(buf).await;
96 }
97
98 let buffered = usize::min(buf.len(), self.buf.len() - self.pos);
99 assert!(buffered > 0);
100
101 let mut new_pos = self.pos;
102 self.buf[new_pos..new_pos + buffered].copy_from_slice(&buf[..buffered]);
103 new_pos += buffered;
104
105 if new_pos < self.buf.len() {
106 self.pos = new_pos;
108 } else {
109 let written = self.inner.write(self.buf).await?;
111
112 if written < new_pos {
114 self.buf.copy_within(written..new_pos, 0);
116 self.pos = new_pos - written;
117 } else {
118 self.pos = 0;
119 }
120 }
121
122 Ok(buffered)
123 }
124
125 async fn flush(&mut self) -> Result<(), Self::Error> {
126 if self.pos > 0 {
127 self.inner.write_all(&self.buf[..self.pos]).await?;
128 self.pos = 0;
129 }
130
131 self.inner.flush().await
132 }
133}
134
#[cfg(test)]
mod tests {
    use embedded_io::{Error, ErrorKind, ErrorType};

    use super::*;

    /// Writer double whose n-th `write` call accepts `writeable[n]` bytes.
    ///
    /// A scripted `0` makes that call fail with [`UnstableError`]; accepted
    /// bytes are appended to `written`.
    #[derive(Default)]
    struct UnstableWrite {
        written: Vec<u8>,
        writes: usize,
        writeable: Vec<usize>,
    }

    /// Error returned by [`UnstableWrite`] for calls scripted to fail.
    #[derive(Debug)]
    struct UnstableError;

    impl core::fmt::Display for UnstableError {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            f.write_str("UnstableError")
        }
    }

    impl std::error::Error for UnstableError {}

    impl Error for UnstableError {
        fn kind(&self) -> ErrorKind {
            ErrorKind::Other
        }
    }

    impl ErrorType for UnstableWrite {
        type Error = UnstableError;
    }

    impl Write for UnstableWrite {
        async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
            let accepted = self.writeable[self.writes];
            self.writes += 1;
            if accepted == 0 {
                return Err(UnstableError);
            }
            self.written.extend_from_slice(&buf[..accepted]);
            Ok(accepted)
        }

        async fn flush(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Small writes are staged until the buffer is exactly full, then all
    /// eight bytes reach the inner writer together and in order.
    #[tokio::test]
    async fn can_append_to_buffer() {
        let mut sink = Vec::new();
        let mut storage = [0; 8];
        let mut writer = BufferedWrite::new(&mut sink, &mut storage);

        assert_eq!(2, writer.write(&[1, 2]).await.unwrap());
        assert_eq!(2, writer.pos);
        assert_eq!(0, writer.inner.len());

        assert_eq!(2, writer.write(&[3, 4]).await.unwrap());
        assert_eq!(4, writer.pos);
        assert_eq!(0, writer.inner.len());

        assert_eq!(4, writer.write(&[5, 6, 7, 8]).await.unwrap());
        assert_eq!(0, writer.pos);
        assert_eq!(8, writer.inner.len());
        assert_eq!(&[1, 2, 3, 4, 5, 6, 7, 8], writer.inner.as_slice());
    }

    /// A write at least as large as the buffer skips it when nothing is pending.
    #[tokio::test]
    async fn bypass_large_write_when_empty() {
        let mut sink = Vec::new();
        let mut storage = [0; 8];
        let mut writer = BufferedWrite::new(&mut sink, &mut storage);

        assert_eq!(8, writer.write(&[1, 2, 3, 4, 5, 6, 7, 8]).await.unwrap());
        assert_eq!(0, writer.pos);
        assert_eq!(8, writer.inner.len());
    }

    /// With data pending, a large write only fills the remaining space and
    /// the full buffer is then sent on.
    #[tokio::test]
    async fn large_write_when_not_empty() {
        let mut sink = Vec::new();
        let mut storage = [0; 8];
        let mut writer = BufferedWrite::new(&mut sink, &mut storage);

        assert_eq!(1, writer.write(&[1]).await.unwrap());
        assert_eq!(1, writer.pos);
        assert_eq!(0, writer.inner.len());

        assert_eq!(7, writer.write(&[2, 3, 4, 5, 6, 7, 8, 9]).await.unwrap());
        assert_eq!(0, writer.pos);
        assert_eq!(8, writer.inner.len());
    }

    /// A failed inner write must not count the copied bytes as buffered, so
    /// retrying with the same data succeeds without duplication.
    #[tokio::test]
    async fn large_write_when_not_empty_can_handle_write_errors() {
        // First inner write fails outright; the second accepts all eight bytes.
        let mut flaky = UnstableWrite {
            writeable: vec![0, 8],
            ..Default::default()
        };
        let mut storage = [0; 8];
        let mut writer = BufferedWrite::new(&mut flaky, &mut storage);

        assert_eq!(1, writer.write(&[1]).await.unwrap());
        assert_eq!(1, writer.pos);
        assert_eq!(0, writer.inner.written.len());

        assert!(writer.write(&[2, 3, 4, 5, 6, 7, 8]).await.is_err());

        assert_eq!(7, writer.write(&[2, 3, 4, 5, 6, 7, 8]).await.unwrap());
        assert_eq!(0, writer.pos);
        assert_eq!(8, writer.inner.written.len());
    }

    /// Flushing pushes partially buffered data out and empties the buffer.
    #[tokio::test]
    async fn flush_clears_buffer() {
        let mut sink = Vec::new();
        let mut storage = [0; 8];
        let mut writer = BufferedWrite::new(&mut sink, &mut storage);

        assert_eq!(2, writer.write(&[1, 2]).await.unwrap());
        assert_eq!(2, writer.pos);
        assert_eq!(0, writer.inner.len());

        writer.flush().await.unwrap();
        assert_eq!(0, writer.pos);
        assert_eq!(2, writer.inner.len());
    }
}