aahc/util/
io.rs

1use futures_core::ready;
2use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite};
3use std::io::{IoSliceMut, Result};
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7/// A set of additional utility functions available on any type implementing `AsyncBufRead`.
8pub(crate) trait AsyncBufReadExt: AsyncBufRead {
9	/// Fills the internal buffer, then invokes a callback which can consume some bytes from that
10	/// buffer.
11	///
12	/// This function adapts the `poll_fill_buf` and `consume` methods into a proper future.
13	///
14	/// The type parameter `CallbackReturn` is the type returned by the callback, and `Callback` is
15	/// the callback itself.
16	///
17	/// The parameter `callback` is the callback function which can use data from the buffer. It is
18	/// called at most once per call to `read_buf` (it may not be called at all if an error
19	/// occurs). It is passed the bytes in the buffer. Its return value identifies how many bytes
20	/// to consume from the buffer, along with an arbitrary value to pass back to the caller of
21	/// `read_buf`.
22	fn read_buf<CallbackReturn, Callback: FnOnce(&'_ [u8]) -> (usize, CallbackReturn) + Unpin>(
23		self: Pin<&mut Self>,
24		callback: Callback,
25	) -> ReadBufFuture<'_, Self, CallbackReturn, Callback> {
26		ReadBufFuture {
27			source: self,
28			callback: Some(callback),
29		}
30	}
31}
32
33impl<R: AsyncBufRead + ?Sized> AsyncBufReadExt for R {}
34
35/// A set of additional utility functions available on any type implementing `AsyncRead`.
36pub(crate) trait AsyncReadExt: AsyncRead {
37	/// Reads data to a caller-provided buffer.
38	#[cfg(test)]
39	fn read<'buffer>(
40		self: Pin<&mut Self>,
41		buffer: &'buffer mut [u8],
42	) -> ReadFuture<'_, 'buffer, Self> {
43		ReadFuture {
44			source: self,
45			buffer,
46		}
47	}
48
49	/// Reads data to multiple caller-provided buffers.
50	#[cfg(test)]
51	fn read_vectored<'buffers>(
52		self: Pin<&mut Self>,
53		buffers: &'buffers mut [IoSliceMut<'buffers>],
54	) -> ReadVectoredFuture<'_, 'buffers, Self> {
55		ReadVectoredFuture {
56			source: self,
57			buffers,
58		}
59	}
60
61	/// Performs a vectored read with an additional length limit.
62	#[cfg(test)]
63	fn read_vectored_bounded<'bufs>(
64		self: Pin<&mut Self>,
65		bufs: &'bufs mut [IoSliceMut<'bufs>],
66		limit: u64,
67	) -> ReadVectoredBoundedFuture<'_, 'bufs, Self> {
68		ReadVectoredBoundedFuture {
69			source: self,
70			bufs,
71			limit,
72		}
73	}
74
75	/// A wrapper around [`futures_io::AsyncRead::poll_read_vectored`] that limits the amount of
76	/// data returned to a specified quantity.
77	fn poll_read_vectored_bounded(
78		self: Pin<&mut Self>,
79		cx: &mut Context<'_>,
80		bufs: &mut [IoSliceMut<'_>],
81		limit: u64,
82	) -> Poll<Result<usize>> {
83		if limit == 0 {
84			Ok(0).into()
85		} else {
86			let limit = std::cmp::min(limit, usize::MAX as u64) as usize;
87			let first_buffer = &mut bufs[0];
88			if first_buffer.len() >= limit {
89				// The first IoSlice alone covers at least limit bytes. Our read will only touch that
90				// slice and no more.
91				self.poll_read(cx, &mut first_buffer[..limit])
92			} else {
93				// The first IoSlice alone is smaller than limit bytes. Do a vectored read. To avoid
94				// modifying any of the IoSlices, choose only enough IoSlices to add up to ≤limit
95				// bytes. This might even mean just one IoSlice (if the first slice is smaller than
96				// limit but the first two added are larger), but in general it could be more than one.
97				let buf_count: usize = bufs
98					.iter()
99					.scan(0_usize, |size_so_far, elt| {
100						*size_so_far += elt.len();
101						Some(*size_so_far > limit)
102					})
103					.enumerate()
104					.find(|elt| elt.1)
105					.unwrap_or((bufs.len(), false))
106					.0;
107				self.poll_read_vectored(cx, &mut bufs[..buf_count])
108			}
109		}
110	}
111}
112
113impl<R: AsyncRead + ?Sized> AsyncReadExt for R {}
114
115/// A set of additional utility functions available on any type implementing `AsyncWrite`.
116pub(crate) trait AsyncWriteExt: AsyncWrite {
117	/// Writes a block of bytes to the writeable.
118	///
119	/// This function performs repeated writes into the writeable until the entire requested data
120	/// has been written.
121	fn write_all<'a>(self: Pin<&'a mut Self>, data: &'a [u8]) -> WriteAllFuture<'a, Self> {
122		WriteAllFuture { sink: self, data }
123	}
124}
125
126impl<W: AsyncWrite + ?Sized> AsyncWriteExt for W {}
127
128/// A future that fills an `AsyncBufRead`’s internal buffer and then invokes a callback to consume
129/// some or all of the data.
130#[derive(Debug)]
131pub(crate) struct ReadBufFuture<
132	'source,
133	Source: AsyncBufRead + ?Sized,
134	CallbackReturn,
135	Callback: FnOnce(&[u8]) -> (usize, CallbackReturn) + Unpin,
136> {
137	source: Pin<&'source mut Source>,
138	callback: Option<Callback>,
139}
140
141impl<
142	Source: AsyncBufRead + ?Sized,
143	CallbackReturn,
144	Callback: FnOnce(&[u8]) -> (usize, CallbackReturn) + Unpin,
145> Future for ReadBufFuture<'_, Source, CallbackReturn, Callback>
146{
147	type Output = Result<CallbackReturn>;
148
149	fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
150		let this = self.get_mut();
151		let data = ready!(this.source.as_mut().poll_fill_buf(cx))?;
152		let (consumed, ret) = (this.callback.take().unwrap())(data);
153		this.source.as_mut().consume(consumed);
154		Ok(ret).into()
155	}
156}
157
/// Future returned by `read`: performs a read from an `AsyncRead` into one caller-supplied
/// buffer.
#[cfg(test)]
#[derive(Debug)]
pub(crate) struct ReadFuture<'source, 'buffer, Source>
where
	Source: AsyncRead + ?Sized,
{
	/// The reader being polled.
	source: Pin<&'source mut Source>,
	/// The destination for the bytes read.
	buffer: &'buffer mut [u8],
}
165
#[cfg(test)]
impl<Source> Future for ReadFuture<'_, '_, Source>
where
	Source: AsyncRead + ?Sized,
{
	/// The result of the underlying `poll_read` call.
	type Output = Result<usize>;

	/// Forwards the poll to `poll_read` on the source with the stored buffer.
	fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
		let this = self.get_mut();
		let source = this.source.as_mut();
		let destination = &mut *this.buffer;
		source.poll_read(cx, destination)
	}
}
175
/// Future returned by `read_vectored`: performs a read from an `AsyncRead` into a list of
/// caller-supplied buffers.
#[cfg(test)]
#[derive(Debug)]
pub(crate) struct ReadVectoredFuture<'source, 'buffers, Source>
where
	Source: AsyncRead + ?Sized,
{
	/// The reader being polled.
	source: Pin<&'source mut Source>,
	/// The destinations for the bytes read.
	buffers: &'buffers mut [IoSliceMut<'buffers>],
}
183
#[cfg(test)]
impl<Source> Future for ReadVectoredFuture<'_, '_, Source>
where
	Source: AsyncRead + ?Sized,
{
	/// The result of the underlying `poll_read_vectored` call.
	type Output = Result<usize>;

	/// Forwards the poll to `poll_read_vectored` on the source with the stored buffers.
	fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
		let this = self.get_mut();
		let source = this.source.as_mut();
		let destinations = &mut *this.buffers;
		source.poll_read_vectored(cx, destinations)
	}
}
193
/// Future returned by `read_vectored_bounded`: performs a read from an `AsyncRead` into a list
/// of caller-supplied buffers, subject to an extra cap on the number of bytes read.
#[cfg(test)]
#[derive(Debug)]
pub(crate) struct ReadVectoredBoundedFuture<'source, 'bufs, Source>
where
	Source: AsyncRead + ?Sized,
{
	/// The reader being polled.
	source: Pin<&'source mut Source>,
	/// The destinations for the bytes read.
	bufs: &'bufs mut [IoSliceMut<'bufs>],
	/// The byte limit passed to `poll_read_vectored_bounded`.
	limit: u64,
}
203
#[cfg(test)]
impl<Source> Future for ReadVectoredBoundedFuture<'_, '_, Source>
where
	Source: AsyncRead + ?Sized,
{
	/// The result of the underlying `poll_read_vectored_bounded` call.
	type Output = Result<usize>;

	/// Forwards the poll to `poll_read_vectored_bounded` on the source with the stored buffers
	/// and limit.
	fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
		let this = self.get_mut();
		let limit = this.limit;
		let source = this.source.as_mut();
		let destinations = &mut *this.bufs;
		source.poll_read_vectored_bounded(cx, destinations, limit)
	}
}
215
216/// A future that writes all of an array to an `AsyncWrite`.
217#[derive(Debug)]
218pub(crate) struct WriteAllFuture<'a, T: AsyncWrite + ?Sized> {
219	sink: Pin<&'a mut T>,
220	data: &'a [u8],
221}
222
223impl<T: AsyncWrite + ?Sized> Future for WriteAllFuture<'_, T> {
224	type Output = Result<()>;
225
226	fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
227		while !self.data.is_empty() {
228			let data = self.data;
229			let bytes_written = ready!(self.sink.as_mut().poll_write(cx, data))?;
230			self.data = &self.data[bytes_written..];
231		}
232		Ok(()).into()
233	}
234}
235
/// Keeps reading from `src` until `buffer` has been filled completely.
///
/// # Errors
///
/// Returns any error reported by a read. If a read returns zero bytes while part of the buffer
/// is still unfilled, returns an error of kind `UnexpectedEof`.
#[cfg(test)]
pub(crate) async fn read_all<Source: AsyncRead + ?Sized>(
	mut src: Pin<&mut Source>,
	mut buffer: &mut [u8],
) -> Result<()> {
	loop {
		if buffer.is_empty() {
			return Ok(());
		}
		match src.as_mut().read(buffer).await? {
			// A zero-length read with space remaining means the source ran dry early.
			0 => return Err(std::io::ErrorKind::UnexpectedEof.into()),
			// Advance past the bytes just filled in.
			bytes_read => buffer = &mut buffer[bytes_read..],
		}
	}
}
252
#[cfg(test)]
mod test {
	use super::*;
	use futures_executor::block_on;

	/// Runs one `read_vectored_bounded` call from `src` into two zeroed four-byte buffers with
	/// the given `limit`, returning the byte count and the final contents of both buffers.
	fn bounded_read_into_two(mut src: &[u8], limit: u64) -> (usize, [u8; 4], [u8; 4]) {
		let mut buf1 = [0_u8; 4];
		let mut buf2 = [0_u8; 4];
		let mut slices = [
			IoSliceMut::new(&mut buf1[..]),
			IoSliceMut::new(&mut buf2[..]),
		];
		let bytes_read = block_on(async {
			Pin::new(&mut src)
				.read_vectored_bounded(&mut slices, limit)
				.await
		})
		.unwrap();
		(bytes_read, buf1, buf2)
	}

	/// Checks that `read` fills a single buffer from a byte-slice source.
	#[test]
	fn read() {
		let mut source: &[u8] = b"abcdefgh";
		let mut buffer = [0_u8; 4];
		let bytes_read =
			block_on(async { Pin::new(&mut source).read(&mut buffer[..]).await }).unwrap();
		assert_eq!(bytes_read, 4);
		assert_eq!(&buffer, b"abcd");
	}

	/// Checks that `read_vectored` fills two buffers in order from a byte-slice source.
	#[test]
	fn read_vectored() {
		let mut source: &[u8] = b"abcdefgh";
		let mut first = [0_u8; 4];
		let mut second = [0_u8; 2];
		let mut slices = [IoSliceMut::new(&mut first), IoSliceMut::new(&mut second)];
		let bytes_read =
			block_on(async { Pin::new(&mut source).read_vectored(&mut slices).await }).unwrap();
		assert_eq!(bytes_read, 6);
		assert_eq!(&first, b"abcd");
		assert_eq!(&second, b"ef");
	}

	/// Checks `poll_read_vectored_bounded` when the limit is smaller than the first buffer, so
	/// only part of that buffer is filled.
	#[test]
	fn poll_read_vectored_bounded_one_partial() {
		let (bytes_read, buf1, buf2) = bounded_read_into_two(b"abcdefgh", 3);
		assert_eq!(bytes_read, 3);
		assert_eq!(&buf1, b"abc\0");
		assert_eq!(&buf2, b"\0\0\0\0");
	}

	/// Checks `poll_read_vectored_bounded` when the limit covers the first buffer but not the
	/// first two together, so only the first buffer is filled.
	#[test]
	fn poll_read_vectored_bounded_one_full() {
		let (bytes_read, buf1, buf2) = bounded_read_into_two(b"abcdefgh", 5);
		assert_eq!(bytes_read, 4);
		assert_eq!(&buf1, b"abcd");
		assert_eq!(&buf2, b"\0\0\0\0");
	}

	/// Checks `poll_read_vectored_bounded` when the limit covers both buffers, so both are
	/// filled.
	#[test]
	fn poll_read_vectored_bounded_two_full() {
		let (bytes_read, buf1, buf2) = bounded_read_into_two(b"abcdefghij", 10);
		assert_eq!(bytes_read, 8);
		assert_eq!(&buf1, b"abcd");
		assert_eq!(&buf2, b"efgh");
	}

	/// Checks `write_all` against a sink that accepts everything offered in one call.
	#[test]
	fn write_all_fast() {
		struct GreedySink {
			received: Vec<u8>,
		}
		impl AsyncWrite for GreedySink {
			fn poll_write(
				self: Pin<&mut Self>,
				_cx: &mut Context<'_>,
				buf: &[u8],
			) -> Poll<Result<usize>> {
				self.get_mut().received.extend_from_slice(buf);
				Poll::Ready(Ok(buf.len()))
			}

			fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
				panic!("Should not be called");
			}

			fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
				panic!("Should not be called");
			}
		}

		let mut sink = GreedySink { received: vec![] };
		block_on(Pin::new(&mut sink).write_all(b"abcdefgh")).unwrap();
		assert_eq!(sink.received.as_slice(), b"abcdefgh");
	}

	/// Checks `write_all` against a sink that accepts only a single byte per call.
	#[test]
	fn write_all_slow() {
		struct OneByteSink {
			received: Vec<u8>,
		}
		impl AsyncWrite for OneByteSink {
			fn poll_write(
				self: Pin<&mut Self>,
				_cx: &mut Context<'_>,
				buf: &[u8],
			) -> Poll<Result<usize>> {
				let Some(&byte) = buf.first() else {
					return Poll::Ready(Ok(0));
				};
				self.get_mut().received.push(byte);
				Poll::Ready(Ok(1))
			}

			fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
				panic!("Should not be called");
			}

			fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
				panic!("Should not be called");
			}
		}

		let mut sink = OneByteSink { received: vec![] };
		block_on(Pin::new(&mut sink).write_all(b"abcdefgh")).unwrap();
		assert_eq!(sink.received.as_slice(), b"abcdefgh");
	}

	/// Checks that `write_all` stops at, and reports, an error returned by the sink.
	#[test]
	fn write_all_error() {
		struct FailingSink {
			already_called: bool,
		}
		impl AsyncWrite for FailingSink {
			fn poll_write(
				self: Pin<&mut Self>,
				_cx: &mut Context<'_>,
				_data: &[u8],
			) -> Poll<Result<usize>> {
				let this = self.get_mut();
				// A second call would mean write_all kept going after the error.
				assert!(!this.already_called);
				this.already_called = true;
				Poll::Ready(Err(std::io::Error::other("Test error message")))
			}

			fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
				panic!("Should not be called");
			}

			fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
				panic!("Should not be called");
			}
		}

		let mut sink = FailingSink {
			already_called: false,
		};
		let error = block_on(Pin::new(&mut sink).write_all(b"abcdefgh")).unwrap_err();
		assert_eq!(error.kind(), std::io::ErrorKind::Other);
		assert_eq!(error.to_string(), "Test error message");
	}
}