use futures::channel::mpsc;
use futures::io::{AsyncRead, AsyncWrite};
use futures::sink::{Sink, SinkExt};
use futures::stream::Stream;
use std::io::{Error as IoError, ErrorKind, Result as IoResult};
use std::pin::Pin;
use std::task::{Context, Poll};
/// Upper bound on the number of bytes a single `poll_write` call enqueues as one chunk.
const CHUNKSZ: usize = 213;
/// Buffer size passed to `mpsc::channel` for each direction of a stream pair.
const CAPACITY: usize = 4;
pub fn stream_pair() -> (LocalStream, LocalStream) {
let (w1, r2) = mpsc::channel(CAPACITY);
let (w2, r1) = mpsc::channel(CAPACITY);
let s1 = LocalStream {
w: w1,
r: r1,
pending_bytes: Vec::new(),
tls_cert: None,
};
let s2 = LocalStream {
w: w2,
r: r2,
pending_bytes: Vec::new(),
tls_cert: None,
};
(s1, s2)
}
/// One end of an in-memory, bidirectional byte stream, as returned by `stream_pair`.
///
/// Data travels between the two ends as `Vec<u8>` chunks over bounded
/// channels; injected errors travel the same way as `Err` items.
pub struct LocalStream {
    /// Sends chunks (or injected errors) to the peer stream.
    w: mpsc::Sender<IoResult<Vec<u8>>>,
    /// Receives chunks (or injected errors) from the peer stream.
    r: mpsc::Receiver<IoResult<Vec<u8>>>,
    /// Bytes from the most recently received chunk that have not yet been read.
    pending_bytes: Vec<u8>,
    /// When `Some`, plain `poll_read`/`poll_write` on this stream fail.
    ///
    /// NOTE(review): presumably holds certificate bytes for a mocked TLS
    /// connection; confirm against the code that sets it.
    pub(crate) tls_cert: Option<Vec<u8>>,
}
/// Move bytes from the front of `pending_bytes` into the front of `buf`.
///
/// Copies `min(buf.len(), pending_bytes.len())` bytes, removes them from
/// `pending_bytes` (the rest stay queued in order), and returns that count.
/// Bytes of `buf` past the returned count are left untouched.
fn drain_helper(buf: &mut [u8], pending_bytes: &mut Vec<u8>) -> usize {
    let count = buf.len().min(pending_bytes.len());
    for (slot, byte) in buf.iter_mut().zip(pending_bytes.drain(..count)) {
        *slot = byte;
    }
    count
}
impl AsyncRead for LocalStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<IoResult<usize>> {
if buf.is_empty() {
return Poll::Ready(Ok(0));
}
if self.tls_cert.is_some() {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Other,
"attempted to treat a TLS stream as non-TLS!",
)));
}
if !self.pending_bytes.is_empty() {
return Poll::Ready(Ok(drain_helper(buf, &mut self.pending_bytes)));
}
match futures::ready!(Pin::new(&mut self.r).poll_next(cx)) {
Some(Err(e)) => Poll::Ready(Err(e)),
Some(Ok(bytes)) => {
self.pending_bytes = bytes;
let n = drain_helper(buf, &mut self.pending_bytes);
Poll::Ready(Ok(n))
}
None => Poll::Ready(Ok(0)), }
}
}
impl AsyncWrite for LocalStream {
    /// Send up to `CHUNKSZ` bytes from the front of `buf` to the peer as one chunk.
    ///
    /// Waits until the channel is ready to accept a chunk, then returns the
    /// number of bytes enqueued. A zero-length `buf` returns `Ok(0)` once the
    /// channel is ready, without enqueuing anything. An empty chunk would be
    /// read by the peer as `Ok(0)`, which is indistinguishable from EOF.
    ///
    /// # Errors
    ///
    /// * `ErrorKind::Other` if this stream has a TLS certificate set.
    /// * `ErrorKind::BrokenPipe` if the channel reports an error, for example
    ///   because the peer stream was dropped.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<IoResult<usize>> {
        if self.tls_cert.is_some() {
            return Poll::Ready(Err(IoError::new(
                ErrorKind::Other,
                "attempted to treat a TLS stream as non-TLS!",
            )));
        }
        if let Err(e) = futures::ready!(Pin::new(&mut self.w).poll_ready(cx)) {
            return Poll::Ready(Err(IoError::new(ErrorKind::BrokenPipe, e)));
        }
        if buf.is_empty() {
            // Do not enqueue an empty chunk; see the method docs.
            return Poll::Ready(Ok(0));
        }
        let chunk = &buf[..buf.len().min(CHUNKSZ)];
        let len = chunk.len();
        match Pin::new(&mut self.w).start_send(Ok(chunk.to_vec())) {
            Ok(()) => Poll::Ready(Ok(len)),
            Err(e) => Poll::Ready(Err(IoError::new(ErrorKind::BrokenPipe, e))),
        }
    }

    /// Flush the outgoing channel. Channel errors are reported as `ErrorKind::BrokenPipe`.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        Pin::new(&mut self.w)
            .poll_flush(cx)
            .map_err(|e| IoError::new(ErrorKind::BrokenPipe, e))
    }

    /// Close the outgoing channel. The peer then reads EOF after any queued data.
    ///
    /// Channel errors are reported as `ErrorKind::BrokenPipe`, the same as in
    /// `poll_write` and `poll_flush`.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        Pin::new(&mut self.w)
            .poll_close(cx)
            .map_err(|e| IoError::new(ErrorKind::BrokenPipe, e))
    }
}
/// Error payload that `LocalStream::send_err` wraps in an `std::io::Error`.
///
/// It carries no data. Its `Display` text is always `Synthetic error`.
#[derive(Debug, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub struct SyntheticError;

impl std::fmt::Display for SyntheticError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Synthetic error")
    }
}

impl std::error::Error for SyntheticError {}
impl LocalStream {
pub async fn send_err(&mut self, kind: ErrorKind) {
let _ignore = self.w.send(Err(IoError::new(kind, SyntheticError))).await;
}
}
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use futures_await_test::async_test;
    use rand::Rng;
    use tor_basic_utils::test_rng::testing_rng;

    /// Everything written to one end arrives intact and in order at the other end.
    #[async_test]
    async fn basic_rw() {
        const N_WRITES: usize = 10;
        let (mut s1, mut s2) = stream_pair();
        let mut text1 = vec![0_u8; 9999];
        testing_rng().fill(&mut text1[..]);
        let (v1, v2): (IoResult<()>, IoResult<()>) = futures::join!(
            async {
                for _ in 0..N_WRITES {
                    s1.write_all(&text1[..]).await?;
                }
                s1.close().await?;
                Ok(())
            },
            async {
                let mut text2: Vec<u8> = Vec::new();
                let mut buf = [0_u8; 33];
                loop {
                    let n = s2.read(&mut buf[..]).await?;
                    if n == 0 {
                        break;
                    }
                    text2.extend(&buf[..n]);
                }
                // Without this, an early EOF (or no data at all) would make the
                // per-chunk comparison below pass vacuously.
                assert_eq!(text2.len(), text1.len() * N_WRITES);
                for ch in text2[..].chunks(text1.len()) {
                    assert_eq!(ch, &text1[..]);
                }
                Ok(())
            }
        );
        v1.unwrap();
        v2.unwrap();
    }

    /// An error injected with `send_err` reaches the reader with its kind and payload intact.
    #[async_test]
    async fn send_error() {
        let (mut s1, mut s2) = stream_pair();
        let (v1, v2): (IoResult<()>, IoResult<()>) = futures::join!(
            async {
                s1.write_all(b"hello world").await?;
                s1.send_err(ErrorKind::PermissionDenied).await;
                Ok(())
            },
            async {
                let mut buf = [0_u8; 33];
                loop {
                    let n = s2.read(&mut buf[..]).await?;
                    if n == 0 {
                        break;
                    }
                }
                Ok(())
            }
        );
        v1.unwrap();
        let e = v2.err().unwrap();
        assert_eq!(e.kind(), ErrorKind::PermissionDenied);
        let synth = e.into_inner().unwrap();
        assert_eq!(synth.to_string(), "Synthetic error");
    }

    /// Writing after the peer is dropped fails with `BrokenPipe`.
    #[async_test]
    async fn drop_reader() {
        let (mut s1, s2) = stream_pair();
        let (v1, v2): (IoResult<()>, IoResult<()>) = futures::join!(
            async {
                for _ in 0_u16..1000 {
                    s1.write_all(&[9_u8; 9999]).await?;
                }
                Ok(())
            },
            async {
                drop(s2);
                Ok(())
            }
        );
        v2.unwrap();
        let e = v1.err().unwrap();
        assert_eq!(e.kind(), ErrorKind::BrokenPipe);
    }
}