use bytes::Bytes;
use futures::{
self, executor::block_on, stream, Sink, SinkExt, TryStream, TryStreamExt,
};
use std::{cmp::min, error, fmt, panic, pin::Pin, result};
use tokio::{io, process::Child, sync::mpsc, task};
use tokio_stream::wrappers::ReceiverStream;
use tracing::Span;
use crate::common::*;
/// Heap-allocated, `Send`, `'static` future that resolves to `Result<T, E>`.
///
/// `E` defaults to the `Error` type brought in by `crate::common` (presumably the crate-wide
/// error type — confirm). This is exactly `futures::future::BoxFuture<'static, _>`, so
/// `FutureExt::boxed()` results can be assigned to it directly.
pub type BoxFuture<T, E = Error> = futures::future::BoxFuture<'static, Result<T, E>>;
/// Heap-allocated, `Send`, `'static` stream whose items are `Result<T, E>`.
///
/// `E` defaults to the same `Error` type as [`BoxFuture`]. This is exactly
/// `futures::stream::BoxStream<'static, _>`, so `StreamExt::boxed()` results fit it.
pub type BoxStream<T, E = Error> = futures::stream::BoxStream<'static, Result<T, E>>;
/// Drive a stream of fallible futures with bounded concurrency and gather their outputs.
pub trait ConsumeWithParallelism<T>: Sized {
    /// Run up to `parallelism` of the yielded futures at a time and collect their results.
    ///
    /// Results are collected in completion order, not in the order the futures were
    /// yielded. Collection stops at the first error, whether it comes from the stream
    /// itself or from one of the futures.
    ///
    /// NOTE(review): how `parallelism == 0` behaves depends on the `futures` version in use
    /// (older versions stall, newer ones treat it as "unbounded") — confirm callers never pass 0.
    fn consume_with_parallelism(self, parallelism: usize) -> BoxFuture<Vec<T>>;
}

impl<T: Send + Sized + 'static> ConsumeWithParallelism<T> for BoxStream<BoxFuture<T>> {
    fn consume_with_parallelism(self, parallelism: usize) -> BoxFuture<Vec<T>> {
        let in_flight = self.try_buffer_unordered(parallelism);
        let gathered = in_flight.try_collect::<Vec<T>>();
        Box::pin(gathered)
    }
}
/// Create a bounded channel of byte chunks and expose its receiving half as a `Stream`.
///
/// `buffer` is the channel capacity passed to `tokio::sync::mpsc::channel`, which panics
/// when it is 0.
pub(crate) fn bytes_channel(
    buffer: usize,
) -> (
    mpsc::Sender<Result<BytesMut>>,
    impl Stream<Item = Result<BytesMut>> + Send + Unpin + 'static,
) {
    let (tx, rx) = mpsc::channel::<Result<BytesMut>>(buffer);
    let chunks = ReceiverStream::new(rx);
    (tx, chunks)
}
/// Forward every value from `stream` into `sink`, then close `sink`.
///
/// # Errors
///
/// Returns an error in three cases:
/// - `stream` yields an error. `sink` is left unclosed in that case.
/// - Sending a value into `sink` fails.
/// - Closing `sink` fails.
pub async fn try_forward<T, St, Si>(mut stream: St, mut sink: Si) -> Result<()>
where
    St: Stream<Item = Result<T>> + Unpin,
    Si: Sink<T> + Unpin,
    Error: From<Si::Error>,
{
    trace!("forwarding stream to sink");
    while let Some(result) = stream.next().await {
        match result {
            Ok(value) => sink
                .send(value)
                .await
                .map_err(Error::from)
                .context("error sending value to sink")?,
            Err(err) => {
                return Err(err.context("error reading from stream"));
            }
        }
    }
    // Closing can fail on its own (e.g. while flushing buffered output), so it gets its
    // own context. Previously it reused the per-item "sending" message.
    sink.close()
        .await
        .map_err(Error::from)
        .context("error closing sink")?;
    trace!("done forwarding stream to sink");
    Ok(())
}
/// Copy each item of `stream` into `sender` until the stream ends or fails.
///
/// If the stream yields an error, that error is forwarded to `sender` as well, so the
/// receiving side observes it. The function then returns an error describing the failure.
///
/// # Errors
///
/// - Returns a `SendError`-derived error if the receiving side is closed while forwarding
///   a successful item.
/// - Returns "error reading from stream: ..." after forwarding a stream error.
/// - Returns "could not forward error to sender: ..." if that error itself could not be
///   delivered.
pub(crate) async fn try_forward_to_sender<T, St>(
    mut stream: St,
    sender: &mut mpsc::Sender<Result<T>>,
) -> Result<()>
where
    T: Send,
    St: Stream<Item = Result<T>> + Unpin,
{
    trace!("forwarding stream to sender");
    while let Some(item) = stream.next().await {
        let value = match item {
            Ok(value) => value,
            Err(err) => {
                // Capture the message before `err` is moved into the channel.
                let ret_err = format_err!("error reading from stream: {}", err);
                return match sender.send(Err(err)).await {
                    Ok(()) => Err(ret_err),
                    Err(_) => Err(format_err!(
                        "could not forward error to sender: {}",
                        ret_err
                    )),
                };
            }
        };
        sender.send(Ok(value)).await.map_send_err()?;
    }
    trace!("done forwarding stream to sender");
    Ok(())
}
pub(crate) async fn copy_stream_to_writer<S, W>(
mut stream: S,
mut wtr: W,
) -> Result<()>
where
S: Stream<Item = Result<BytesMut>> + Unpin + 'static,
W: AsyncWrite + Unpin + 'static,
{
trace!("begin copy_stream_to_writer");
while let Some(result) = stream.next().await {
match result {
Err(err) => {
error!("error reading stream: {}", err);
return Err(err);
}
Ok(bytes) => {
trace!("writing {} bytes", bytes.len());
wtr.write_all(&bytes).await.map_err(|e| {
error!("write error: {}", e);
format_err!("error writing data: {}", e)
})?;
trace!("wrote to writer");
}
}
}
wtr.flush().await?;
trace!("end copy_stream_to_writer");
Ok(())
}
/// Turn an `AsyncRead` into a stream of byte chunks by pumping it from a spawned task.
///
/// The reader is drained in chunks of at most 64 KiB on a task started with `tokio::spawn`.
/// That call panics outside a tokio runtime, so this function must be called from within one.
/// The task stops in three cases:
/// - the reader reports EOF (a zero-byte read);
/// - a read fails, after the error is forwarded as a stream item;
/// - the receiving stream has been dropped.
///
/// The outer `Result` is currently always `Ok`; read failures surface as stream items.
pub(crate) fn copy_reader_to_stream<R>(
    mut rdr: R,
) -> Result<impl Stream<Item = Result<BytesMut>> + Send + 'static>
where
    R: AsyncRead + Send + Unpin + 'static,
{
    let (sender, chunks) = bytes_channel(1);
    let pump: BoxFuture<()> = async move {
        // One scratch buffer reused for every read; each chunk is copied out before sending.
        let mut scratch = vec![0u8; 64 * 1024];
        loop {
            trace!("reading bytes from reader");
            let count = match rdr.read(&mut scratch).await {
                Ok(0) => {
                    trace!("done copying AsyncRead to stream");
                    return Ok(());
                }
                Ok(count) => count,
                Err(err) => {
                    let nice_err = format_err!("read error: {}", err);
                    error!("{}", nice_err);
                    if sender.send(Err(nice_err)).await.is_err() {
                        error!("broken pipe prevented sending error: {}", err);
                    }
                    return Ok(());
                }
            };
            let bytes = BytesMut::from(&scratch[..count]);
            trace!("sending {} bytes to stream", bytes.len());
            if sender.send(Ok(bytes)).await.is_err() {
                // The consumer went away; nothing left to do.
                error!("broken pipe forwarding async data to stream");
                return Ok(());
            }
            trace!("sent bytes to stream");
        }
    }
    .boxed();
    tokio::spawn(pump);
    Ok(chunks)
}
/// Blocking writer whose output becomes an async stream of byte chunks.
///
/// Each operation uses `futures::executor::block_on` to push into a bounded channel, so it
/// blocks the calling thread. NOTE(review): presumably meant to be used from a blocking
/// context (e.g. `spawn_blocking`) rather than an async worker thread — confirm at call sites.
pub(crate) struct SyncStreamWriter {
    /// Sending half of the channel created by `bytes_channel`.
    sender: mpsc::Sender<Result<BytesMut>>,
}

impl SyncStreamWriter {
    /// Create a writer together with the stream that receives everything written to it.
    ///
    /// The channel holds a single pending chunk, so writes wait until the consumer keeps up.
    pub fn pipe() -> (Self, impl Stream<Item = Result<BytesMut>> + Send + 'static) {
        let (sender, chunks) = bytes_channel(1);
        let writer = Self { sender };
        (writer, chunks)
    }

    /// Push `err` onto the stream so the consumer sees a failure item.
    ///
    /// Returns `BrokenPipe` if the receiving stream has already been dropped.
    // Unused at the moment (hence the lint allowance); presumably kept for producers that
    // need to report errors in-band — confirm before removing.
    #[allow(dead_code)]
    pub(crate) fn send_error(&mut self, err: Error) -> io::Result<()> {
        debug!("sending error: {}", err);
        let delivery = block_on(self.sender.send(Err(err)));
        delivery.map_err(|_| io::ErrorKind::BrokenPipe.into())
    }
}
impl Write for SyncStreamWriter {
    /// Copy `buf` into a new chunk and block until it has been queued on the channel.
    ///
    /// Returns `Ok(0)` for an empty `buf` without touching the channel, as `Write::write`
    /// permits. Forwarding an empty chunk would waste a channel slot, and it would trip the
    /// non-empty-chunk assertion in a `SyncStreamReader` consumer.
    ///
    /// # Errors
    ///
    /// Returns `BrokenPipe` if the receiving stream has been dropped.
    #[instrument(level = "trace", skip_all, fields(buf.len = %buf.len()))]
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        trace!("sending {} bytes", buf.len());
        block_on(self.sender.send(Ok(BytesMut::from(buf))))
            .map_err(|_| -> io::Error { io::ErrorKind::BrokenPipe.into() })?;
        Ok(buf.len())
    }

    /// No-op. Every `write` has already handed its chunk to the channel, so there is no
    /// local buffer to flush. This does not wait for the consumer to read the data.
    fn flush(&mut self) -> io::Result<()> {
        trace!("pretending to flush to an async sender");
        Ok(())
    }
}
/// Blocking `Read` adapter over an async stream of byte chunks.
///
/// Whenever its local buffer runs dry, a read blocks (via `block_on`) on the next stream item.
pub(crate) struct SyncStreamReader {
    /// Fused so that polling after end-of-stream keeps yielding `None`.
    stream: stream::Fuse<BoxStream<BytesMut>>,
    /// Set once the stream has produced an error; later reads then fail.
    seen_error: bool,
    /// Bytes already pulled from the stream but not yet returned to a caller.
    buffer: BytesMut,
}

impl SyncStreamReader {
    /// Wrap `stream` so it can be consumed through blocking `Read` calls.
    pub(crate) fn new(stream: BoxStream<BytesMut>) -> Self {
        let fused = stream.fuse();
        SyncStreamReader {
            buffer: BytesMut::new(),
            seen_error: false,
            stream: fused,
        }
    }
}
impl Read for SyncStreamReader {
    /// Copy up to `buf.len()` buffered bytes into `buf`, pulling from the stream when empty.
    ///
    /// Returns `Ok(0)` in two cases: when `buf` is empty, which `Read::read` explicitly
    /// allows, and at end of stream. Empty chunks from the stream are skipped rather than
    /// returned, because `Ok(0)` would otherwise be taken as EOF.
    ///
    /// # Errors
    ///
    /// If the stream yields an error, it is returned as an `ErrorKind::Other` I/O error.
    /// Every later read also fails with `ErrorKind::Other`.
    #[instrument(level = "trace", skip_all, fields(buf.max = %buf.len(), buf.read))]
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // Zero-length reads are legal and must not consume stream data.
        if buf.is_empty() {
            return Ok(0);
        }
        // Refill until we have data, hit EOF, or see an error.
        while self.buffer.is_empty() {
            if self.seen_error {
                error!("tried to read from stream after error");
                return Err(io::ErrorKind::Other.into());
            }
            match block_on(self.stream.next()) {
                None => {
                    trace!("end of stream");
                    return Ok(0);
                }
                Some(Ok(bytes)) => {
                    trace!("read {} bytes from stream", bytes.len());
                    self.buffer = bytes;
                }
                Some(Err(err)) => {
                    error!("error reading from stream: {}", err);
                    self.seen_error = true;
                    return Err(io::Error::new(io::ErrorKind::Other, err));
                }
            }
        }
        let count = min(self.buffer.len(), buf.len());
        buf[..count].copy_from_slice(&self.buffer.split_to(count));
        Span::current().record("buf.read", count);
        trace!("read returned {} bytes", count);
        Ok(count)
    }
}
pub(crate) fn box_stream_once<T>(value: Result<T>) -> BoxStream<T>
where
T: Send + 'static,
{
stream::once(async { value }).boxed()
}
/// Run the blocking closure `f` on tokio's blocking thread pool and await its result.
///
/// `f` runs inside the caller's current tracing span. If `f` panics, the panic is re-raised
/// on the awaiting task via `panic::resume_unwind`.
///
/// # Errors
///
/// Returns `f`'s own error, or "background thread failed: ..." if the blocking task
/// failed without panicking (e.g. it was cancelled).
pub async fn spawn_blocking<F, T>(f: F) -> Result<T>
where
    F: (FnOnce() -> Result<T>) + Send + 'static,
    T: Send + 'static,
{
    let parent = Span::current();
    let outcome = task::spawn_blocking(move || -> Result<T> {
        let _entered = parent.entered();
        f()
    })
    .await;
    match outcome {
        Ok(f_result) => f_result,
        Err(join_err) if join_err.is_panic() => panic::resume_unwind(join_err.into_panic()),
        Err(join_err) => Err(format_err!("background thread failed: {}", join_err)),
    }
}
/// Read all remaining bytes from `input` into a freshly allocated vector.
///
/// # Errors
///
/// Propagates any I/O error raised while reading.
pub(crate) async fn async_read_to_end<R>(mut input: R) -> Result<Vec<u8>>
where
    R: AsyncRead + Send + Unpin,
{
    let mut contents = Vec::new();
    input.read_to_end(&mut contents).await?;
    Ok(contents)
}
/// Read all remaining bytes from `input` and decode them as UTF-8.
///
/// # Errors
///
/// Returns an I/O error from reading, or a UTF-8 decoding error if the bytes are not valid
/// UTF-8.
pub(crate) async fn async_read_to_string<R>(input: R) -> Result<String>
where
    R: AsyncRead + Send + Unpin,
{
    let text = String::from_utf8(async_read_to_end(input).await?)?;
    Ok(text)
}
/// Write `data` to the stdin of `child`, then shut the pipe down so the child sees EOF.
///
/// The stdin handle is taken out of `child`, so this can only succeed once per child.
/// `child_name` is used only in error messages.
///
/// # Errors
///
/// Fails if the child has no stdin handle, or if writing or shutting down the pipe fails.
// Currently unused (hence the lint allowance); NOTE(review): confirm whether it is still needed.
#[allow(dead_code)]
pub(crate) async fn write_to_stdin(
    child_name: &str,
    child: &mut Child,
    data: &[u8],
) -> Result<()> {
    let mut child_stdin = match child.stdin.take() {
        Some(stdin) => stdin,
        None => {
            return Err(format_err!("`{}` doesn't have a stdin handle", child_name));
        }
    };
    child_stdin
        .write_all(data)
        .await
        .with_context(|| format!("error piping to `{}`", child_name))?;
    child_stdin
        .shutdown()
        .await
        .with_context(|| format!("error shutting down pipe to `{}`", child_name))?;
    Ok(())
}
/// Run a synchronous writer callback into memory, then copy the result to an async writer.
///
/// `f` writes into an in-memory `Vec<u8>`. The collected bytes are then written to `wtr`
/// with `write_all`, and `wtr` is handed back. It is not flushed here.
///
/// # Errors
///
/// Returns `f`'s error converted into `Error`, or the I/O error from writing to `wtr`.
pub(crate) async fn buffer_sync_write_and_copy_to_async<W, F, E>(
    mut wtr: W,
    f: F,
) -> Result<W>
where
    W: AsyncWrite + Send + Unpin,
    F: FnOnce(&mut dyn Write) -> result::Result<(), E>,
    E: Into<Error>,
{
    let mut staged: Vec<u8> = Vec::new();
    if let Err(e) = f(&mut staged) {
        return Err(e.into());
    }
    wtr.write_all(&staged).await?;
    Ok(wtr)
}
/// Error produced when data cannot be delivered because the channel's receiver is gone.
#[derive(Debug)]
pub(crate) struct SendError;

impl fmt::Display for SendError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("cannot send data to closed channel")
    }
}

impl error::Error for SendError {}
/// Extension for collapsing tokio `mpsc` send failures into the payload-free `SendError`.
pub(crate) trait SendResultExt<T> {
    /// Replace any `mpsc::error::SendError`, and the value it carries back, with `SendError`.
    fn map_send_err(self) -> Result<T, SendError>;
}

impl<T, ErrInfo> SendResultExt<T> for Result<T, mpsc::error::SendError<ErrInfo>> {
    fn map_send_err(self) -> Result<T, SendError> {
        // The rejected value inside tokio's error is intentionally discarded.
        self.map_err(|_| SendError)
    }
}
/// Pinned, boxed `TryStream` of immutable `Bytes` chunks with a boxed, type-erased error.
///
/// The error type is a plain `Box<dyn std::error::Error + Send + Sync>` rather than this
/// module's `Error`. Presumably this matches the stream type expected by external APIs
/// (e.g. HTTP body constructors) — TODO confirm against callers. `Item` is spelled out
/// because `TryStream` is a `Stream` of `Result<Ok, Error>`. The whole object is
/// `Send + Sync + 'static`.
pub(crate) type IdiomaticBytesStream = Pin<
    Box<
        dyn TryStream<
                Ok = Bytes,
                Error = Box<dyn error::Error + Send + Sync>,
                Item = Result<Bytes, Box<dyn error::Error + Send + Sync>>,
            > + Send
            + Sync
            + 'static,
    >,
>;
pub(crate) fn http_response_stream(
response: reqwest::Response,
) -> BoxStream<BytesMut> {
response
.bytes_stream()
.map_ok(|chunk| BytesMut::from(chunk.as_ref()))
.map_err(|err| err.into())
.boxed()
}
pub(crate) fn idiomatic_bytes_stream(
ctx: &Context,
stream: BoxStream<BytesMut>,
) -> IdiomaticBytesStream {
let to_forward = stream.map_ok(|bytes| bytes.freeze());
let (mut sender, receiver) = mpsc::channel::<Result<Bytes, Error>>(1);
let forwarder: BoxFuture<()> =
async move { try_forward_to_sender(to_forward, &mut sender).await }
.instrument(trace_span!("idiomatic_bytes_stream"))
.boxed();
ctx.spawn_worker(forwarder);
let stream = ReceiverStream::new(receiver)
.map_err(|err| -> Box<dyn error::Error + Send + Sync> { err.into() });
Box::pin(stream)
}