use std::{cmp::min, future::Future as StdFuture, pin::Pin, result, thread};
use tokio::io;
use crate::common::*;
/// A heap-allocated, pinned, `Send` standard-library future that resolves to
/// `Result<T>`.
pub type BoxFuture<T> = Pin<Box<dyn StdFuture<Output = Result<T>> + Send + 'static>>;
/// A heap-allocated `Send` stream yielding `T` items or `Error`s.
///
/// NOTE(review): this is the `Item`/`Error` associated-type form of `Stream`
/// (futures 0.1 style), unlike `BoxFuture`, which is a `std` future; values
/// cross between the two via `.compat()` elsewhere in this file.
pub type BoxStream<T> = Box<dyn Stream<Item = T, Error = Error> + Send + 'static>;
/// Drive a collection of futures to completion with bounded concurrency.
pub trait ConsumeWithParallelism<T>: Sized {
    /// Run the contained futures with at most `parallelism` in flight and
    /// collect their outputs into a `Vec`.
    fn consume_with_parallelism(self, parallelism: usize) -> BoxFuture<Vec<T>>;
}
impl<T: Send + Sized + 'static> ConsumeWithParallelism<T> for BoxStream<BoxFuture<T>> {
    /// Run the futures produced by this stream, keeping at most `parallelism`
    /// of them in flight, and collect their results into a `Vec`.
    ///
    /// Results are collected in completion order (`buffer_unordered`), not in
    /// stream order. The first error from the stream or from any future is
    /// returned.
    ///
    /// A `parallelism` of `0` is treated as `1`.
    fn consume_with_parallelism(self, parallelism: usize) -> BoxFuture<Vec<T>> {
        // A futures-0.1 `buffer_unordered(0)` never pulls an item from the
        // inner stream and never registers a wakeup, so the returned future
        // would stay pending forever. Always allow at least one in flight.
        let limit = parallelism.max(1);
        self.map(|fut: BoxFuture<T>| fut.compat())
            .buffer_unordered(limit)
            .collect()
            .compat()
            .boxed()
    }
}
pub(crate) fn bytes_channel(
buffer: usize,
) -> (
mpsc::Sender<Result<BytesMut>>,
impl Stream<Item = BytesMut, Error = Error> + Send + 'static,
) {
let (sender, receiver) = mpsc::channel(buffer);
let receiver = receiver
.map_err(|_| format_err!("stream read error"))
.and_then(|result| result);
(sender, receiver)
}
pub(crate) async fn copy_stream_to_writer<S, W>(
ctx: Context,
mut stream: S,
mut wtr: W,
) -> Result<()>
where
S: Stream<Item = BytesMut, Error = Error> + 'static,
W: AsyncWrite + 'static,
{
loop {
trace!(ctx.log(), "reading from stream");
match stream.into_future().compat().await {
Err((err, _rest_of_stream)) => {
error!(ctx.log(), "error reading stream: {}", err);
return Err(err);
}
Ok((None, _rest_of_stream)) => {
trace!(ctx.log(), "end of stream");
return Ok(());
}
Ok((Some(bytes), rest_of_stream)) => {
stream = rest_of_stream;
trace!(ctx.log(), "writing {} bytes", bytes.len());
io::write_all(&mut wtr, bytes).compat().await.map_err(|e| {
error!(ctx.log(), "write error: {}", e);
format_err!("error writing data: {}", e)
})?;
trace!(ctx.log(), "wrote to writer");
}
}
}
}
/// Adapt an `AsyncRead` into a stream of `BytesMut` chunks.
///
/// A background task is started with `tokio::spawn`. It repeatedly reads up
/// to 64 KiB from `rdr` and sends each non-empty chunk through a channel with
/// a buffer size of 1 (see `bytes_channel`):
///
/// - The returned stream ends when the reader reports EOF (a zero-byte read).
/// - A read error is forwarded to the stream as an error item, and then the
///   task stops.
/// - If the receiving stream has been dropped, the task logs and stops.
///
/// NOTE(review): because `tokio::spawn` is used, this presumably has to be
/// called from inside a running tokio executor — confirm against callers.
///
/// # Errors
///
/// This function only ever returns `Ok`; failures are delivered through the
/// returned stream.
pub(crate) fn copy_reader_to_stream<R>(
    ctx: Context,
    mut rdr: R,
) -> Result<impl Stream<Item = BytesMut, Error = Error> + Send + 'static>
where
    R: AsyncRead + Send + 'static,
{
    let (mut sender, receiver) = bytes_channel(1);
    let worker = async move {
        // Scratch buffer reused for every read; each chunk is copied out of it.
        let mut buffer = vec![0; 64 * 1024];
        loop {
            trace!(ctx.log(), "reading bytes from reader");
            // `io::read` takes the reader by value and hands it back on
            // success, so `rdr` is reassigned from `new_rdr` below.
            match io::read(rdr, &mut buffer).compat().await {
                Err(err) => {
                    let nice_err = format_err!("stream read error: {}", err);
                    error!(ctx.log(), "{}", nice_err);
                    // Best effort: if the stream was dropped, there is nobody
                    // left to report the error to, so just log it.
                    if sender.send(Err(nice_err)).compat().await.is_err() {
                        error!(
                            ctx.log(),
                            "broken pipe prevented sending error: {}", err
                        );
                    }
                    return Ok(());
                }
                Ok((new_rdr, data, count)) => {
                    // A zero-byte read means EOF; returning drops `sender`,
                    // which ends the stream.
                    if count == 0 {
                        trace!(ctx.log(), "done copying AsyncRead to stream");
                        return Ok(());
                    }
                    rdr = new_rdr;
                    let bytes = BytesMut::from(&data[..count]);
                    trace!(ctx.log(), "sending {} bytes to stream", bytes.len());
                    // `send` consumes the sender and returns it on success.
                    match sender.send(Ok(bytes)).compat().await {
                        Ok(new_sender) => {
                            trace!(ctx.log(), "sent bytes to stream");
                            sender = new_sender;
                        }
                        Err(_err) => {
                            // The receiving stream was dropped; stop quietly.
                            error!(
                                ctx.log(),
                                "broken pipe forwarding async data to stream"
                            );
                            return Ok(());
                        }
                    }
                }
            }
        }
    };
    tokio::spawn(worker.boxed().compat());
    Ok(receiver)
}
/// A blocking [`Write`] adapter that forwards written bytes into a stream.
///
/// Each operation blocks the calling thread (via `.wait()`) until the
/// underlying channel accepts the item.
pub(crate) struct SyncStreamWriter {
    /// Context used for logging (`ctx.log()`).
    ctx: Context,
    /// Sending half of the channel. It becomes `None` after a send or flush
    /// fails (the receiver is gone); from then on, every operation reports
    /// `BrokenPipe`.
    sender: Option<mpsc::Sender<Result<BytesMut>>>,
}
impl SyncStreamWriter {
    /// Create a connected writer/stream pair.
    ///
    /// Bytes written to the returned `SyncStreamWriter` come out of the
    /// returned stream. The channel between them is created with a buffer of
    /// 1, so writes may block until the stream side consumes earlier chunks.
    pub fn pipe(ctx: Context) -> (Self, impl Stream<Item = BytesMut, Error = Error>) {
        let (tx, byte_stream) = bytes_channel(1);
        let writer = SyncStreamWriter {
            ctx,
            sender: Some(tx),
        };
        (writer, byte_stream)
    }
}
impl SyncStreamWriter {
    /// Push `err` into the stream as an error item, blocking until the
    /// channel accepts it.
    ///
    /// Returns `BrokenPipe` if the receiving stream is gone, whether that is
    /// detected now or on an earlier call. After a failed send, the writer
    /// stays broken.
    // NOTE(review): the `allow` suggests this method may have no callers in
    // some builds; it silences the unused-code warning.
    #[allow(dead_code)]
    pub(crate) fn send_error(&mut self, err: Error) -> io::Result<()> {
        debug!(self.ctx.log(), "sending error: {}", err);
        let sender = match self.sender.take() {
            Some(sender) => sender,
            None => return Err(io::ErrorKind::BrokenPipe.into()),
        };
        // On failure, `self.sender` stays `None`, which marks the pipe broken.
        let sender = sender
            .send(Err(err))
            .wait()
            .map_err(|_send_err| io::ErrorKind::BrokenPipe)?;
        self.sender = Some(sender);
        Ok(())
    }
}
impl Write for SyncStreamWriter {
    /// Forward `buf` to the stream as a single chunk, blocking until the
    /// channel accepts it.
    ///
    /// An empty `buf` is a no-op: it returns `Ok(0)` without touching the
    /// channel. If the receiving stream has been dropped, returns `BrokenPipe`,
    /// and every later call also fails with `BrokenPipe`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        trace!(self.ctx.log(), "sending {} bytes", buf.len());
        if buf.is_empty() {
            // A zero-length chunk carries no data, and stream consumers may not
            // expect empty items (e.g. a reader that asserts every chunk is
            // non-empty), so never put one on the stream.
            return Ok(0);
        }
        if let Some(sender) = self.sender.take() {
            match sender.send(Ok(BytesMut::from(buf))).wait() {
                Ok(sender) => {
                    self.sender = Some(sender);
                    Ok(buf.len())
                }
                // Leave `self.sender` as `None`: the pipe is now permanently broken.
                Err(_err) => Err(io::ErrorKind::BrokenPipe.into()),
            }
        } else {
            Err(io::ErrorKind::BrokenPipe.into())
        }
    }

    /// Block until the channel has finished accepting previously sent chunks.
    ///
    /// Returns `BrokenPipe` if the receiving stream has been dropped.
    fn flush(&mut self) -> io::Result<()> {
        trace!(self.ctx.log(), "flushing");
        if let Some(sender) = self.sender.take() {
            match sender.flush().wait() {
                Ok(sender) => {
                    self.sender = Some(sender);
                    Ok(())
                }
                Err(_err) => Err(io::ErrorKind::BrokenPipe.into()),
            }
        } else {
            Err(io::ErrorKind::BrokenPipe.into())
        }
    }
}
/// A blocking [`Read`] adapter over a [`BoxStream`] of byte chunks.
///
/// When its buffer runs dry, it blocks the calling thread (via `.wait()`)
/// until the stream yields the next chunk.
pub(crate) struct SyncStreamReader {
    /// Context used for logging (`ctx.log()`).
    ctx: Context,
    /// The rest of the stream; `None` once it has ended or produced an error.
    stream: Option<BoxStream<BytesMut>>,
    /// Bytes from the most recent chunk that `read` has not yet handed out.
    buffer: BytesMut,
}
impl SyncStreamReader {
pub(crate) fn new(ctx: Context, stream: BoxStream<BytesMut>) -> Self {
Self {
ctx,
stream: Some(stream),
buffer: BytesMut::default(),
}
}
}
impl Read for SyncStreamReader {
    /// Copy buffered bytes into `buf`. When the buffer is empty, first pull
    /// the next chunk from the stream, blocking via `.wait()`.
    ///
    /// Returns `Ok(0)` in three cases: `buf` is empty, the stream has ended,
    /// or the stream was already consumed or failed earlier. Zero-length
    /// chunks from the stream are skipped rather than treated as an error.
    ///
    /// # Errors
    ///
    /// A stream error is returned as an `io::ErrorKind::Other` error. After
    /// that, the stream is discarded.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // The `Read` contract allows an empty destination; answer it without
        // touching (or blocking on) the stream.
        if buf.is_empty() {
            return Ok(0);
        }
        // Refill until there is data. Looping (rather than asserting) tolerates
        // producers that emit zero-length chunks.
        while self.buffer.is_empty() {
            let stream = match self.stream.take() {
                Some(stream) => stream,
                None => {
                    trace!(self.ctx.log(), "stream is already closed");
                    return Ok(0);
                }
            };
            match stream.into_future().wait() {
                Ok((None, _rest_of_stream)) => {
                    trace!(self.ctx.log(), "end of stream");
                    return Ok(0);
                }
                Ok((Some(bytes), rest_of_stream)) => {
                    self.stream = Some(rest_of_stream);
                    trace!(
                        self.ctx.log(),
                        "read {} bytes from stream",
                        bytes.len()
                    );
                    self.buffer = bytes;
                }
                Err((err, _rest_of_stream)) => {
                    error!(self.ctx.log(), "error reading from stream: {}", err);
                    return Err(io::Error::new(
                        io::ErrorKind::Other,
                        Box::new(err.compat()),
                    ));
                }
            }
        }
        let count = min(self.buffer.len(), buf.len());
        buf[..count].copy_from_slice(&self.buffer.split_to(count));
        trace!(self.ctx.log(), "read returned {} bytes", count);
        Ok(count)
    }
}
/// Box a single-item stream that yields `value` and then ends. An `Ok` value
/// becomes a data item; an `Err` value becomes a stream error.
pub(crate) fn box_stream_once<T>(value: Result<T>) -> BoxStream<T>
where
    T: Send + 'static,
{
    let single_item = stream::once(value);
    Box::new(single_item)
}
/// Run the blocking closure `f` on a new OS thread named `thread_name`, and
/// asynchronously wait for its result.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - the thread cannot be spawned;
/// - `f` returns an error;
/// - `f` panics (reported as "background thread panicked").
pub(crate) async fn run_sync_fn_in_background<F, T>(
    thread_name: String,
    f: F,
) -> Result<T>
where
    F: (FnOnce() -> Result<T>) + Send + 'static,
    T: Send + 'static,
{
    let (sender, receiver) = mpsc::channel(1);
    let thr = thread::Builder::new().name(thread_name);
    let handle = thr
        .spawn(move || {
            // If the awaiting future has been dropped, nobody wants this result.
            // Discard it instead of panicking on this detached thread.
            let _ = sender.send(f()).wait();
        })
        .context("could not spawn thread")?;
    // The channel yields either the worker's result or, if `f` panicked
    // before sending, end-of-stream (its sender was dropped while unwinding).
    // Either way the worker has finished or is about to return, so the
    // blocking `join` below is brief.
    let background_result = receiver.into_future().compat().await;
    let join_result = handle.join();
    match (background_result, join_result) {
        (Ok((Some(result), _receiver)), Ok(())) => result,
        // No result was sent, or the thread died abnormally: in both cases the
        // worker panicked. (The receiver's error arm is also covered here.)
        _ => Err(format_err!("background thread panicked")),
    }
}
pub fn run_futures_with_runtime(
cmd_future: BoxFuture<()>,
worker_future: BoxFuture<()>,
) -> Result<()> {
let combined_fut = async move {
cmd_future
.compat()
.join(worker_future.compat())
.compat()
.await?;
let result: Result<()> = Ok(());
result
};
let mut runtime =
tokio::runtime::Runtime::new().expect("Unable to create a runtime");
runtime.block_on(combined_fut.boxed().compat())?;
Ok(())
}
/// Read everything from `input` until EOF and return the collected bytes.
///
/// # Errors
///
/// Returns any I/O error raised while reading.
pub(crate) async fn async_read_to_end<R>(input: R) -> Result<Vec<u8>>
where
    R: AsyncRead + Send,
{
    let sink: Vec<u8> = Vec::new();
    // `read_to_end` gives back both the reader and the filled buffer; only
    // the bytes are needed.
    let (_reader, contents) = io::read_to_end(input, sink).compat().await?;
    Ok(contents)
}
/// Read all of `input` and decode it as UTF-8.
///
/// # Errors
///
/// Fails if reading fails or if the bytes are not valid UTF-8.
pub(crate) async fn async_read_to_string<R>(input: R) -> Result<String>
where
    R: AsyncRead + Send,
{
    let raw = async_read_to_end(input).await?;
    let text = String::from_utf8(raw)?;
    Ok(text)
}
/// Let the synchronous writer callback `f` fill an in-memory buffer, then copy
/// that buffer to the asynchronous writer `wtr`, and return `wtr`.
///
/// Nothing is written to `wtr` unless `f` succeeds. `wtr` is not flushed.
///
/// # Errors
///
/// Returns `f`'s error (converted into `Error`) or any I/O error from writing.
pub(crate) async fn buffer_sync_write_and_copy_to_async<W, F, E>(
    wtr: W,
    f: F,
) -> Result<W>
where
    W: AsyncWrite + Send,
    F: FnOnce(&mut dyn Write) -> result::Result<(), E>,
    E: Into<Error>,
{
    let mut staged: Vec<u8> = Vec::new();
    if let Err(e) = f(&mut staged) {
        return Err(e.into());
    }
    let (wtr, _staged) = io::write_all(wtr, staged).compat().await?;
    Ok(wtr)
}