use futures::stream::{Stream, StreamExt};
use hyper::body::{Body, Bytes, HttpBody};
use std::collections::VecDeque;
use std::fmt;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
/// A cloneable, message-only error.
///
/// Stored in the shared tee buffer so the same error can be handed to both
/// consumers; it carries only the rendered text of the original error.
#[derive(Clone, Debug)]
pub struct StringError(String);

impl fmt::Display for StringError {
    /// Writes the stored message verbatim (formatter width/fill flags are ignored).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let StringError(message) = self;
        f.write_str(message)
    }
}

impl std::error::Error for StringError {}
/// Per-consumer bookkeeping for one side of the tee.
#[derive(Debug, Default, Clone)]
struct ConsumerState {
    /// Waker stored by this consumer's most recent `Pending` poll; taken and
    /// woken when the reader task pushes new data or the sibling consumer drops.
    waker: Option<Waker>,
    /// Absolute index (counted from the first chunk of the body) of the next
    /// chunk this consumer will read.
    cursor: usize,
    /// `false` once this consumer's `BodyTeeStream` has been dropped; inactive
    /// consumers are ignored when deciding which chunks may be discarded.
    active: bool,
}
/// State shared (behind a mutex) between the background reader task and both
/// tee consumers.
#[derive(Debug)]
struct SharedState {
    /// Chunks, or the terminal error, not yet discarded. `buffer[0]` has the
    /// absolute index `offset`.
    buffer: VecDeque<Result<Bytes, StringError>>,
    /// Number of chunks already discarded from the front of `buffer`.
    offset: usize,
    /// Set once the upstream body has ended or yielded an error; the reader
    /// task pushes nothing further after that.
    is_done: bool,
    /// Consumer bookkeeping, indexed by `BodyTeeStream::id` (0 or 1).
    consumers: [ConsumerState; 2],
}
impl Default for SharedState {
fn default() -> Self {
Self {
buffer: VecDeque::new(),
offset: 0,
is_done: false,
consumers: [
ConsumerState {
active: true,
..Default::default()
},
ConsumerState {
active: true,
..Default::default()
},
],
}
}
}
/// One side of a tee'd body. It reads chunks from the shared buffer at its own
/// pace, independently of its sibling.
#[derive(Debug)]
pub struct BodyTeeStream {
    /// State shared with the reader task and the sibling stream.
    shared: Arc<Mutex<SharedState>>,
    /// Index of this stream's entry in `SharedState::consumers` (0 or 1).
    id: usize,
}
pub async fn tee(mut hyper_body: Body) -> (Body, Body) {
if HttpBody::size_hint(&hyper_body).exact().is_some() {
let bytes = hyper::body::to_bytes(hyper_body)
.await
.expect("Failed to buffer known-size body");
return (hyper::Body::from(bytes.clone()), hyper::Body::from(bytes));
}
let shared_state = Arc::new(Mutex::new(SharedState::default()));
let s1 = BodyTeeStream {
shared: shared_state.clone(),
id: 0,
};
let s2 = BodyTeeStream {
shared: shared_state.clone(),
id: 1,
};
tokio::spawn(async move {
loop {
let result = hyper_body.next().await;
let mut state = shared_state.lock().unwrap();
let finished = if let Some(item) = result {
let item_to_store = item.map_err(|e| StringError(e.to_string()));
let is_err = item_to_store.is_err();
state.buffer.push_back(item_to_store);
is_err
} else {
true
};
if finished {
state.is_done = true;
}
for consumer in state.consumers.iter_mut().filter(|c| c.active) {
if let Some(waker) = consumer.waker.take() {
waker.wake();
}
}
drain_buffer(&mut state);
if finished {
break;
}
}
});
(Body::wrap_stream(s1), Body::wrap_stream(s2))
}
impl HttpBody for BodyTeeStream {
type Data = Bytes;
type Error = Box<dyn std::error::Error + Send + Sync>;
fn poll_data(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
let this = self.get_mut();
let mut state = this.shared.lock().unwrap();
let SharedState {
buffer,
offset,
is_done,
consumers,
..
} = &mut *state;
let consumer = &mut consumers[this.id];
if consumer.cursor >= *offset {
let buffer_index = consumer.cursor - *offset;
if let Some(result) = buffer.get(buffer_index) {
consumer.cursor += 1;
return Poll::Ready(Some(result.clone().map_err(|e| Box::new(e) as Self::Error)));
}
}
if *is_done {
return Poll::Ready(None);
}
consumer.waker = Some(cx.waker().clone());
Poll::Pending
}
fn poll_trailers(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
Poll::Ready(Ok(None))
}
fn is_end_stream(&self) -> bool {
let state = self.shared.lock().unwrap();
if !state.is_done {
return false;
}
let consumer = &state.consumers[self.id];
let total_buffered_chunks = state.offset + state.buffer.len();
consumer.cursor >= total_buffered_chunks
}
}
impl Stream for BodyTeeStream {
type Item = Result<Bytes, Box<dyn std::error::Error + Send + Sync>>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.poll_data(cx)
}
}
impl Drop for BodyTeeStream {
    /// Marks this consumer inactive, wakes the sibling if it is waiting, and
    /// discards chunks that only this consumer was still holding back.
    fn drop(&mut self) {
        // Never panic in `drop`: a panic during unwinding aborts the process.
        // If the reader task panicked while holding the lock, recover the
        // guard anyway. Marking inactive and draining are safe in any state.
        let mut state = match self.shared.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        state.consumers[self.id].active = false;
        let sibling = &mut state.consumers[1 - self.id];
        if sibling.active {
            if let Some(waker) = sibling.waker.take() {
                waker.wake();
            }
        }
        drain_buffer(&mut state);
    }
}
/// Discards buffered chunks that every active consumer has already read.
///
/// When no consumer is active, the entire buffer is discarded. Afterwards
/// `state.offset` equals the smallest active cursor, or the total chunk count
/// if no consumer is active.
fn drain_buffer(state: &mut SharedState) {
    let end = state.offset + state.buffer.len();
    let oldest_needed = state
        .consumers
        .iter()
        .filter_map(|consumer| consumer.active.then_some(consumer.cursor))
        .min()
        .unwrap_or(end);
    if oldest_needed > state.offset {
        let count = oldest_needed - state.offset;
        state.buffer.drain(..count);
        state.offset = oldest_needed;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream::{self, StreamExt};
    use hyper::{Body, body::Bytes};
    use std::convert::Infallible;

    /// Both sides of the tee receive every chunk, in order.
    #[tokio::test]
    async fn test_simple_duplication() {
        let chunks = vec!["hello", " ", "world"];
        let stream = stream::iter(chunks.clone()).map(|s| Ok::<_, Infallible>(Bytes::from(s)));
        let body = Body::wrap_stream(stream);
        let (body1, body2) = tee(body).await;
        let res1_fut = body1
            .map(|chunk_res| chunk_res.unwrap())
            .collect::<Vec<_>>();
        let res2_fut = body2
            .map(|chunk_res| chunk_res.unwrap())
            .collect::<Vec<_>>();
        let (res1, res2) = futures::join!(res1_fut, res2_fut);
        let res1_str: Vec<&str> = res1
            .iter()
            .map(|b| std::str::from_utf8(b).unwrap())
            .collect();
        let res2_str: Vec<&str> = res2
            .iter()
            .map(|b| std::str::from_utf8(b).unwrap())
            .collect();
        assert_eq!(res1_str, chunks);
        assert_eq!(res2_str, chunks);
    }

    /// An upstream error reaches both sides and ends both streams.
    #[tokio::test]
    async fn test_error_propagation() {
        let error = std::io::Error::other("test error");
        let stream = stream::iter(vec![
            Ok(Bytes::from("one")),
            Err(error),
            Ok(Bytes::from("two")),
        ]);
        let body = Body::wrap_stream(stream);
        let (mut body1, mut body2) = tee(body).await;
        assert_eq!(body1.next().await.unwrap().unwrap(), Bytes::from("one"));
        let err1 = body1.next().await.unwrap().unwrap_err();
        assert!(
            err1.to_string().contains("test error"),
            "Got error: {}",
            err1
        );
        assert!(
            body1.next().await.is_none(),
            "Stream should end after error"
        );
        assert_eq!(body2.next().await.unwrap().unwrap(), Bytes::from("one"));
        let err2 = body2.next().await.unwrap().unwrap_err();
        assert!(
            err2.to_string().contains("test error"),
            "Got error: {}",
            err2
        );
        assert!(
            body2.next().await.is_none(),
            "Stream should end after error"
        );
    }

    /// Dropping one side does not prevent the other from seeing data and errors.
    #[tokio::test]
    async fn test_error_with_one_consumer_dropped() {
        let error = std::io::Error::new(std::io::ErrorKind::ConnectionAborted, "aborted");
        let stream = stream::iter(vec![Ok(Bytes::from("first")), Err(error)]);
        let body = Body::wrap_stream(stream);
        let (mut body1, body2) = tee(body).await;
        drop(body2);
        assert_eq!(body1.next().await.unwrap().unwrap(), Bytes::from("first"));
        let err1 = body1.next().await.unwrap().unwrap_err();
        assert!(err1.to_string().contains("aborted"));
        assert!(
            body1.next().await.is_none(),
            "Stream should end after error"
        );
    }

    /// Known-size bodies keep their exact size hint on both sides.
    #[tokio::test]
    async fn test_size_hint_preservation() {
        let data = "this has a known size";
        let body = Body::from(data);
        let original_size_hint = HttpBody::size_hint(&body);
        assert_eq!(original_size_hint.exact(), Some(data.len() as u64));
        let (body1, body2) = tee(body).await;
        assert_eq!(
            HttpBody::size_hint(&body1).exact(),
            original_size_hint.exact()
        );
        assert_eq!(
            HttpBody::size_hint(&body2).exact(),
            original_size_hint.exact()
        );
        let body1_bytes = hyper::body::to_bytes(body1).await.unwrap();
        let body2_bytes = hyper::body::to_bytes(body2).await.unwrap();
        assert_eq!(body1_bytes, data.as_bytes());
        assert_eq!(body2_bytes, data.as_bytes());
    }
}