aahc/util/
io.rs

1use futures_core::ready;
2use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite};
3use std::future::Future;
4use std::io::{IoSliceMut, Result};
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8/// A set of additional utility functions available on any type implementing `AsyncBufRead`.
9pub trait AsyncBufReadExt: AsyncBufRead {
10	/// Fills the internal buffer, then invokes a callback which can consume some bytes from that
11	/// buffer.
12	///
13	/// This function adapts the `poll_fill_buf` and `consume` methods into a proper future.
14	///
15	/// The type parameter `CallbackReturn` is the type returned by the callback, and `Callback` is
16	/// the callback itself.
17	///
18	/// The parameter `callback` is the callback function which can use data from the buffer. It is
19	/// called at most once per call to `read_buf` (it may not be called at all if an error
20	/// occurs). It is passed the bytes in the buffer. Its return value identifies how many bytes
21	/// to consume from the buffer, along with an arbitrary value to pass back to the caller of
22	/// `read_buf`.
23	fn read_buf<CallbackReturn, Callback: FnOnce(&'_ [u8]) -> (usize, CallbackReturn) + Unpin>(
24		self: Pin<&mut Self>,
25		callback: Callback,
26	) -> ReadBufFuture<'_, Self, CallbackReturn, Callback> {
27		ReadBufFuture {
28			source: self,
29			callback: Some(callback),
30		}
31	}
32}
33
34impl<R: AsyncBufRead + ?Sized> AsyncBufReadExt for R {}
35
36/// A set of additional utility functions available on any type implementing `AsyncRead`.
37pub trait AsyncReadExt: AsyncRead {
38	/// Reads data to a caller-provided buffer.
39	fn read<'buffer>(
40		self: Pin<&mut Self>,
41		buffer: &'buffer mut [u8],
42	) -> ReadFuture<'_, 'buffer, Self> {
43		ReadFuture {
44			source: self,
45			buffer,
46		}
47	}
48
49	/// Reads data to multiple caller-provided buffers.
50	fn read_vectored<'buffers>(
51		self: Pin<&mut Self>,
52		buffers: &'buffers mut [IoSliceMut<'buffers>],
53	) -> ReadVectoredFuture<'_, 'buffers, Self> {
54		ReadVectoredFuture {
55			source: self,
56			buffers,
57		}
58	}
59
60	/// Performs a vectored read with an additional length limit.
61	fn read_vectored_bounded<'bufs>(
62		self: Pin<&mut Self>,
63		bufs: &'bufs mut [IoSliceMut<'bufs>],
64		limit: u64,
65	) -> ReadVectoredBoundedFuture<'_, 'bufs, Self> {
66		ReadVectoredBoundedFuture {
67			source: self,
68			bufs,
69			limit,
70		}
71	}
72
73	/// A wrapper around [`futures_io::AsyncRead::poll_read_vectored`] that limits the amount of
74	/// data returned to a specified quantity.
75	fn poll_read_vectored_bounded(
76		self: Pin<&mut Self>,
77		cx: &mut Context<'_>,
78		bufs: &mut [IoSliceMut<'_>],
79		limit: u64,
80	) -> Poll<Result<usize>> {
81		if limit == 0 {
82			Ok(0).into()
83		} else {
84			let limit = std::cmp::min(limit, usize::MAX as u64) as usize;
85			let first_buffer = &mut bufs[0];
86			if first_buffer.len() >= limit {
87				// The first IoSlice alone covers at least limit bytes. Our read will only touch that
88				// slice and no more.
89				self.poll_read(cx, &mut first_buffer[..limit])
90			} else {
91				// The first IoSlice alone is smaller than limit bytes. Do a vectored read. To avoid
92				// modifying any of the IoSlices, choose only enough IoSlices to add up to ≤limit
93				// bytes. This might even mean just one IoSlice (if the first slice is smaller than
94				// limit but the first two added are larger), but in general it could be more than one.
95				let buf_count: usize = bufs
96					.iter()
97					.scan(0_usize, |size_so_far, elt| {
98						*size_so_far += elt.len();
99						Some(*size_so_far > limit)
100					})
101					.enumerate()
102					.find(|elt| elt.1)
103					.unwrap_or((bufs.len(), false))
104					.0;
105				self.poll_read_vectored(cx, &mut bufs[..buf_count])
106			}
107		}
108	}
109}
110
111impl<R: AsyncRead + ?Sized> AsyncReadExt for R {}
112
113/// A set of additional utility functions available on any type implementing `AsyncWrite`.
114pub trait AsyncWriteExt: AsyncWrite {
115	/// Writes a block of bytes to the writeable.
116	///
117	/// This function performs repeated writes into the writeable until the entire requested data
118	/// has been written.
119	fn write_all<'a>(self: Pin<&'a mut Self>, data: &'a [u8]) -> WriteAllFuture<'a, Self> {
120		WriteAllFuture { sink: self, data }
121	}
122}
123
124impl<W: AsyncWrite + ?Sized> AsyncWriteExt for W {}
125
126/// A future that fills an `AsyncBufRead`’s internal buffer and then invokes a callback to consume
127/// some or all of the data.
128#[derive(Debug)]
129pub struct ReadBufFuture<
130	'source,
131	Source: AsyncBufRead + ?Sized,
132	CallbackReturn,
133	Callback: FnOnce(&[u8]) -> (usize, CallbackReturn) + Unpin,
134> {
135	source: Pin<&'source mut Source>,
136	callback: Option<Callback>,
137}
138
139impl<
140		Source: AsyncBufRead + ?Sized,
141		CallbackReturn,
142		Callback: FnOnce(&[u8]) -> (usize, CallbackReturn) + Unpin,
143	> Future for ReadBufFuture<'_, Source, CallbackReturn, Callback>
144{
145	type Output = Result<CallbackReturn>;
146
147	fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
148		let this = self.get_mut();
149		let data = ready!(this.source.as_mut().poll_fill_buf(cx))?;
150		let (consumed, ret) = (this.callback.take().unwrap())(data);
151		this.source.as_mut().consume(consumed);
152		Ok(ret).into()
153	}
154}
155
156/// A future that reads from an `AsyncRead` into a single caller-provided buffer.
157#[derive(Debug)]
158pub struct ReadFuture<'source, 'buffer, Source: AsyncRead + ?Sized> {
159	source: Pin<&'source mut Source>,
160	buffer: &'buffer mut [u8],
161}
162
163impl<Source: AsyncRead + ?Sized> Future for ReadFuture<'_, '_, Source> {
164	type Output = Result<usize>;
165
166	fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
167		let this = self.get_mut();
168		this.source.as_mut().poll_read(cx, this.buffer)
169	}
170}
171
172/// A future that reads from an `AsyncRead` into a collection of caller-provided buffers.
173#[derive(Debug)]
174pub struct ReadVectoredFuture<'source, 'buffers, Source: AsyncRead + ?Sized> {
175	source: Pin<&'source mut Source>,
176	buffers: &'buffers mut [IoSliceMut<'buffers>],
177}
178
179impl<Source: AsyncRead + ?Sized> Future for ReadVectoredFuture<'_, '_, Source> {
180	type Output = Result<usize>;
181
182	fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
183		let this = self.get_mut();
184		this.source.as_mut().poll_read_vectored(cx, this.buffers)
185	}
186}
187
188/// A future that reads from an `AsyncRead` into a collection of caller-provided buffers, with an
189/// additional length limit.
190#[derive(Debug)]
191pub struct ReadVectoredBoundedFuture<'source, 'bufs, Source: AsyncRead + ?Sized> {
192	source: Pin<&'source mut Source>,
193	bufs: &'bufs mut [IoSliceMut<'bufs>],
194	limit: u64,
195}
196
197impl<Source: AsyncRead + ?Sized> Future for ReadVectoredBoundedFuture<'_, '_, Source> {
198	type Output = Result<usize>;
199
200	fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
201		let this = self.get_mut();
202		this.source
203			.as_mut()
204			.poll_read_vectored_bounded(cx, this.bufs, this.limit)
205	}
206}
207
208/// A future that writes all of an array to an `AsyncWrite`.
209#[derive(Debug)]
210pub struct WriteAllFuture<'a, T: AsyncWrite + ?Sized> {
211	sink: Pin<&'a mut T>,
212	data: &'a [u8],
213}
214
215impl<T: AsyncWrite + ?Sized> Future for WriteAllFuture<'_, T> {
216	type Output = Result<()>;
217
218	fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
219		while !self.data.is_empty() {
220			let data = self.data;
221			let bytes_written = ready!(self.sink.as_mut().poll_write(cx, data))?;
222			self.data = &self.data[bytes_written..];
223		}
224		Ok(()).into()
225	}
226}
227
/// Fills the caller-provided buffer completely by issuing repeated reads.
///
/// # Errors
/// Propagates any error returned by a read. Fails with `ErrorKind::UnexpectedEof` if a read
/// reports zero bytes while part of the buffer is still unfilled.
#[cfg(test)]
pub async fn read_all<'buffer, Source: AsyncRead + ?Sized>(
	mut src: Pin<&mut Source>,
	mut buffer: &'buffer mut [u8],
) -> Result<()> {
	loop {
		if buffer.is_empty() {
			break Ok(());
		}
		let count = src.as_mut().read(buffer).await?;
		if count == 0 {
			break Err(std::io::ErrorKind::UnexpectedEof.into());
		}
		buffer = &mut buffer[count..];
	}
}
244
#[cfg(test)]
mod test {
	use super::*;
	use futures_executor::block_on;
	use futures_io::AsyncWrite;
	use std::pin::Pin;
	use std::task::{Context, Poll};

	/// Performs one bounded vectored read of `input` into two fresh four-byte buffers. Returns
	/// the reported byte count followed by the final contents of each buffer.
	fn bounded_read_into_two(mut input: &[u8], limit: u64) -> (usize, [u8; 4], [u8; 4]) {
		let mut first = [0_u8; 4];
		let mut second = [0_u8; 4];
		let mut slices = [
			IoSliceMut::new(&mut first[..]),
			IoSliceMut::new(&mut second[..]),
		];
		let count = block_on(Pin::new(&mut input).read_vectored_bounded(&mut slices, limit))
			.unwrap();
		(count, first, second)
	}

	/// Drives `write_all(b"abcdefgh")` on `sink` to completion and returns its result.
	fn write_abcdefgh<W: AsyncWrite + Unpin>(sink: &mut W) -> Result<()> {
		block_on(Pin::new(sink).write_all(b"abcdefgh"))
	}

	/// Tests calling `read` on a source.
	#[test]
	fn test_read() {
		let mut src: &[u8] = b"abcdefgh";
		let mut buffer = [0_u8; 4];
		let bytes_read = block_on(Pin::new(&mut src).read(&mut buffer[..])).unwrap();
		assert_eq!(bytes_read, 4);
		assert_eq!(&buffer, b"abcd");
	}

	/// Tests calling `read_vectored` on a source.
	#[test]
	fn test_read_vectored() {
		let mut src: &[u8] = b"abcdefgh";
		let mut buf1 = [0_u8; 4];
		let mut buf2 = [0_u8; 2];
		let mut slices = [IoSliceMut::new(&mut buf1), IoSliceMut::new(&mut buf2)];
		let bytes_read = block_on(Pin::new(&mut src).read_vectored(&mut slices)).unwrap();
		assert_eq!(bytes_read, 6);
		assert_eq!(&buf1, b"abcd");
		assert_eq!(&buf2, b"ef");
	}

	/// Tests calling `poll_read_vectored_bounded` with a limit small enough to fill less than one
	/// buffer.
	#[test]
	fn test_poll_read_vectored_bounded_one_partial() {
		let (bytes_read, buf1, buf2) = bounded_read_into_two(b"abcdefgh", 3);
		assert_eq!(bytes_read, 3);
		assert_eq!(&buf1, b"abc\0");
		assert_eq!(&buf2, b"\0\0\0\0");
	}

	/// Tests calling `poll_read_vectored_bounded` with a limit large enough to fill the first
	/// buffer, but not the first two.
	#[test]
	fn test_poll_read_vectored_bounded_one_full() {
		let (bytes_read, buf1, buf2) = bounded_read_into_two(b"abcdefgh", 5);
		assert_eq!(bytes_read, 4);
		assert_eq!(&buf1, b"abcd");
		assert_eq!(&buf2, b"\0\0\0\0");
	}

	/// Tests calling `poll_read_vectored_bounded` with a limit large enough to fill the first two
	/// buffers.
	#[test]
	fn test_poll_read_vectored_bounded_two_full() {
		let (bytes_read, buf1, buf2) = bounded_read_into_two(b"abcdefghij", 10);
		assert_eq!(bytes_read, 8);
		assert_eq!(&buf1, b"abcd");
		assert_eq!(&buf2, b"efgh");
	}

	/// Tests calling `write_all` on a sink that accepts unlimited data at a time.
	#[test]
	fn test_write_all_fast() {
		struct GreedySink {
			received: Vec<u8>,
		}
		impl AsyncWrite for GreedySink {
			fn poll_write(
				mut self: Pin<&mut Self>,
				_cx: &mut Context<'_>,
				data: &[u8],
			) -> Poll<Result<usize>> {
				self.received.extend_from_slice(data);
				Poll::Ready(Ok(data.len()))
			}

			fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
				panic!("Should not be called");
			}

			fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
				panic!("Should not be called");
			}
		}
		let mut sink = GreedySink {
			received: Vec::new(),
		};
		write_abcdefgh(&mut sink).unwrap();
		assert_eq!(sink.received.as_slice(), b"abcdefgh");
	}

	/// Tests calling `write_all` on a sink that accepts data only one byte at a time.
	#[test]
	fn test_write_all_slow() {
		struct TrickleSink {
			received: Vec<u8>,
		}
		impl AsyncWrite for TrickleSink {
			fn poll_write(
				mut self: Pin<&mut Self>,
				_cx: &mut Context<'_>,
				data: &[u8],
			) -> Poll<Result<usize>> {
				if let Some(&byte) = data.first() {
					self.received.push(byte);
					Poll::Ready(Ok(1))
				} else {
					Poll::Ready(Ok(0))
				}
			}

			fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
				panic!("Should not be called");
			}

			fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
				panic!("Should not be called");
			}
		}
		let mut sink = TrickleSink {
			received: Vec::new(),
		};
		write_abcdefgh(&mut sink).unwrap();
		assert_eq!(sink.received.as_slice(), b"abcdefgh");
	}

	/// Tests calling `write_all` on a sink that returns an error.
	#[test]
	fn test_write_all_error() {
		struct FailingSink {
			already_called: bool,
		}
		impl AsyncWrite for FailingSink {
			fn poll_write(
				mut self: Pin<&mut Self>,
				_cx: &mut Context<'_>,
				_data: &[u8],
			) -> Poll<Result<usize>> {
				assert!(!self.already_called);
				self.already_called = true;
				Poll::Ready(Err(std::io::Error::new(
					std::io::ErrorKind::Other,
					"Test error message",
				)))
			}

			fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
				panic!("Should not be called");
			}

			fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
				panic!("Should not be called");
			}
		}
		let mut sink = FailingSink {
			already_called: false,
		};
		let e = write_abcdefgh(&mut sink).unwrap_err();
		assert_eq!(e.kind(), std::io::ErrorKind::Other);
		assert_eq!(e.to_string(), "Test error message");
	}
}
445}