1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
// Copyright 2018 The Grin Developers
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Custom impls of read_exact and write_all to work around async stream restrictions.

use std::io;
use std::io::prelude::*;
use std::thread;
use std::time::Duration;

/// Reads exactly `buf.len()` bytes from `stream` into `buf`, polling the
/// stream until the buffer is full.
///
/// The standard `Read::read_exact` gives up with an error as soon as a
/// non-blocking stream (e.g. a `TcpStream`) reports `WouldBlock`, even when
/// the buffer is only partially filled. This implementation instead keeps
/// retrying, with a short sleep between attempts, until the buffer has been
/// completely filled or `timeout` has elapsed.
///
/// `timeout` is measured as wall-clock time since the call started. It is only
/// checked between read attempts, so the call may overrun it by the duration
/// of one read plus one back-off sleep.
///
/// To support callers that want to poll instead of block, `block_on_empty`
/// set to `false` makes the function return `io::ErrorKind::WouldBlock`
/// immediately when the stream has nothing to offer and no byte has been read
/// yet. Once at least one byte has been read, the function always waits for
/// the rest (up to `timeout`).
///
/// An empty `buf` succeeds immediately without touching the stream.
///
/// # Errors
///
/// * `io::ErrorKind::ConnectionAborted` if the stream reports end-of-stream
///   (`Ok(0)`) before the buffer is full.
/// * `io::ErrorKind::WouldBlock` if nothing was available and
///   `block_on_empty` is `false`.
/// * `io::ErrorKind::TimedOut` if the buffer could not be filled in time; in
///   that case `buf` may be partially filled.
/// * Any other error returned by the underlying stream (except
///   `Interrupted`, which is retried).
pub fn read_exact(
	stream: &mut dyn Read,
	mut buf: &mut [u8],
	timeout: Duration,
	block_on_empty: bool,
) -> io::Result<()> {
	use std::time::Instant;

	// Back-off between polls when the stream has no data ready.
	let sleep_time = Duration::from_micros(10);
	let started = Instant::now();

	// Total bytes read so far; only used to tell "nothing read yet" apart
	// from "partially read" when the stream would block.
	let mut read = 0;

	// Checking for emptiness up front means a zero-length read is a no-op
	// rather than being mistaken for a closed connection below.
	while !buf.is_empty() {
		match stream.read(buf) {
			Ok(0) => {
				return Err(io::Error::new(
					io::ErrorKind::ConnectionAborted,
					"read_exact",
				));
			}
			Ok(n) => {
				// Advance the window past the bytes just received.
				let tmp = buf;
				buf = &mut tmp[n..];
				read += n;
			}
			// Interrupted reads are simply retried.
			Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
			Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
				if read == 0 && !block_on_empty {
					return Err(io::Error::new(io::ErrorKind::WouldBlock, "read_exact"));
				}
				// Nothing available right now: yield briefly before polling
				// again. No sleep is needed when progress was made.
				thread::sleep(sleep_time);
			}
			Err(e) => return Err(e),
		}
		// Use real elapsed time so that slow reads and oversleeping both
		// count against the timeout.
		if !buf.is_empty() && started.elapsed() > timeout {
			return Err(io::Error::new(
				io::ErrorKind::TimedOut,
				"reading from stream",
			));
		}
	}
	Ok(())
}

/// Writes the whole of `buf` to `stream`, polling the stream until every
/// byte has been accepted.
///
/// This is the writing counterpart of `read_exact`: on a non-blocking stream
/// a `WouldBlock` is not treated as a failure; the write is retried after a
/// short sleep until all of `buf` has been written or `timeout` has elapsed.
///
/// `timeout` is measured as wall-clock time since the call started. It is only
/// checked between write attempts, so the call may overrun it by the duration
/// of one write plus one back-off sleep.
///
/// An empty `buf` succeeds immediately without touching the stream. The
/// stream is not flushed by this function.
///
/// # Errors
///
/// * `io::ErrorKind::WriteZero` if the stream accepts zero bytes while data
///   remains to be written.
/// * `io::ErrorKind::TimedOut` if the buffer could not be fully written in
///   time; part of `buf` may already have been written.
/// * Any other error returned by the underlying stream (except
///   `Interrupted`, which is retried).
pub fn write_all(stream: &mut dyn Write, mut buf: &[u8], timeout: Duration) -> io::Result<()> {
	use std::time::Instant;

	// Back-off between attempts when the stream cannot accept data yet.
	let sleep_time = Duration::from_micros(10);
	let started = Instant::now();

	while !buf.is_empty() {
		match stream.write(buf) {
			Ok(0) => {
				return Err(io::Error::new(
					io::ErrorKind::WriteZero,
					"failed to write whole buffer",
				));
			}
			// Drop the bytes that were accepted and carry on with the rest.
			Ok(n) => buf = &buf[n..],
			// Interrupted writes are simply retried.
			Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
			// Only back off when the stream is actually not ready; after a
			// successful partial write we retry straight away.
			Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => thread::sleep(sleep_time),
			Err(e) => return Err(e),
		}
		// Use real elapsed time so that slow writes and oversleeping both
		// count against the timeout.
		if !buf.is_empty() && started.elapsed() > timeout {
			return Err(io::Error::new(io::ErrorKind::TimedOut, "writing to stream"));
		}
	}
	Ok(())
}