use http::{HeaderMap, HeaderValue};
type BoxError = Box<dyn std::error::Error + Send + Sync>;
pub trait BodyCallback: Send + Sync {
fn update(&mut self, bytes: &[u8]) -> Result<(), BoxError> {
let _ = bytes;
Ok(())
}
fn trailers(&self) -> Result<Option<HeaderMap<HeaderValue>>, BoxError> {
Ok(None)
}
fn make_new(&self) -> Box<dyn BodyCallback>;
}
impl BodyCallback for Box<dyn BodyCallback> {
fn update(&mut self, bytes: &[u8]) -> Result<(), BoxError> {
self.as_mut().update(bytes)
}
fn trailers(&self) -> Result<Option<HeaderMap<HeaderValue>>, BoxError> {
self.as_ref().trailers()
}
fn make_new(&self) -> Box<dyn BodyCallback> {
self.as_ref().make_new()
}
}
#[cfg(test)]
mod tests {
use super::{BodyCallback, BoxError};
use crate::body::SdkBody;
use crate::byte_stream::ByteStream;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
#[tracing_test::traced_test]
#[tokio::test]
async fn callbacks_are_called_for_update() {
struct CallbackA;
struct CallbackB;
impl BodyCallback for CallbackA {
fn update(&mut self, _bytes: &[u8]) -> Result<(), BoxError> {
tracing::debug!("callback A was called");
Ok(())
}
fn make_new(&self) -> Box<dyn BodyCallback> {
Box::new(Self)
}
}
impl BodyCallback for CallbackB {
fn update(&mut self, _bytes: &[u8]) -> Result<(), BoxError> {
tracing::debug!("callback B was called");
Ok(())
}
fn make_new(&self) -> Box<dyn BodyCallback> {
Box::new(Self)
}
}
let mut body = SdkBody::from("test");
body.with_callback(Box::new(CallbackA))
.with_callback(Box::new(CallbackB));
let body = ByteStream::from(body).collect().await.unwrap().into_bytes();
let body = std::str::from_utf8(&body).unwrap();
assert_eq!(body, "test");
assert!(logs_contain("callback A was called"));
assert!(logs_contain("callback B was called"));
}
struct TestCallback {
times_called: Arc<AtomicUsize>,
}
impl BodyCallback for TestCallback {
fn update(&mut self, _bytes: &[u8]) -> Result<(), BoxError> {
self.times_called.fetch_add(1, Ordering::SeqCst);
Ok(())
}
fn make_new(&self) -> Box<dyn BodyCallback> {
Box::new(Self {
times_called: Arc::new(AtomicUsize::new(0)),
})
}
}
#[tokio::test]
async fn callback_for_buffered_body_is_called_once() {
let times_called = Arc::new(AtomicUsize::new(0));
let test_text: String = (0..=1000)
.into_iter()
.map(|n| format!("line {}\n", n))
.collect();
{
let mut body = SdkBody::from(test_text);
let callback = TestCallback {
times_called: times_called.clone(),
};
body.with_callback(Box::new(callback));
let _body = ByteStream::new(body).collect().await.unwrap().into_bytes();
}
let times_called = Arc::try_unwrap(times_called).unwrap();
let times_called = times_called.into_inner();
assert_eq!(times_called, 1);
}
#[tracing_test::traced_test]
#[tokio::test]
async fn callback_for_streaming_body_is_called_per_chunk() {
let times_called = Arc::new(AtomicUsize::new(0));
{
let test_stream = tokio_stream::iter(
(1..=1000)
.into_iter()
.map(|n| -> Result<String, std::io::Error> { Ok(format!("line {}\n", n)) }),
);
let mut body = SdkBody::from(hyper::body::Body::wrap_stream(test_stream));
tracing::trace!("{:?}", body);
assert!(logs_contain("Streaming(Body(Streaming))"));
let callback = TestCallback {
times_called: times_called.clone(),
};
body.with_callback(Box::new(callback));
let _body = ByteStream::new(body).collect().await.unwrap().into_bytes();
}
let times_called = Arc::try_unwrap(times_called).unwrap();
let times_called = times_called.into_inner();
assert_eq!(times_called, 1000);
}
}