use crate::body::SdkBody;
use crate::callback::BodyCallback;
use bytes::Buf;
use bytes::Bytes;
use bytes_utils::SegmentedBuf;
use http_body::Body;
use pin_project_lite::pin_project;
use std::error::Error as StdError;
use std::fmt::{Debug, Formatter};
use std::io::IoSlice;
use std::pin::Pin;
use std::task::{Context, Poll};
#[cfg(feature = "rt-tokio")]
mod bytestream_util;
#[cfg(feature = "rt-tokio")]
pub use bytestream_util::Length;
#[cfg(feature = "rt-tokio")]
pub use self::bytestream_util::FsBuilder;
pin_project! {
    /// A stream of binary data backed by an [`SdkBody`].
    ///
    /// `ByteStream` implements [`futures_core::stream::Stream`] with items of type
    /// `Result<Bytes, Error>`, and its entire contents can be buffered in memory with
    /// [`ByteStream::collect`].
    #[derive(Debug)]
    pub struct ByteStream {
        // Structurally pinned so `poll_next` can project to a `Pin<&mut Inner<SdkBody>>`.
        #[pin]
        inner: Inner<SdkBody>
    }
}
impl ByteStream {
    /// Creates a `ByteStream` that wraps `body`.
    pub fn new(body: SdkBody) -> Self {
        Self {
            inner: Inner::new(body),
        }
    }

    /// Creates a `ByteStream` from a `'static` byte slice.
    ///
    /// The slice is wrapped with [`Bytes::from_static`] before being converted into an
    /// [`SdkBody`].
    pub fn from_static(bytes: &'static [u8]) -> Self {
        Self::new(SdkBody::from(Bytes::from_static(bytes)))
    }

    /// Consumes the stream and returns the underlying [`SdkBody`].
    pub fn into_inner(self) -> SdkBody {
        self.inner.body
    }

    /// Reads every chunk of the stream into memory and returns them as [`AggregatedBytes`].
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] wrapping the first error produced by the underlying body; any
    /// chunks read before that error are discarded.
    pub async fn collect(self) -> Result<AggregatedBytes, Error> {
        self.inner.collect().await.map_err(|err| Error(err))
    }

    /// Returns an `FsBuilder` for configuring a file-backed `ByteStream`.
    #[cfg(feature = "rt-tokio")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rt-tokio")))]
    pub fn read_from() -> FsBuilder {
        FsBuilder::new()
    }

    /// Creates a `ByteStream` that reads the file at `path`.
    ///
    /// Shorthand for `ByteStream::read_from().path(path).build().await`.
    ///
    /// # Errors
    ///
    /// Returns any [`Error`] produced by `FsBuilder::build`.
    #[cfg(feature = "rt-tokio")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rt-tokio")))]
    pub async fn from_path(path: impl AsRef<std::path::Path>) -> Result<Self, Error> {
        FsBuilder::new().path(path).build().await
    }

    /// Creates a `ByteStream` that reads from an already-open `tokio::fs::File`.
    ///
    /// Shorthand for `ByteStream::read_from().file(file).build().await`.
    ///
    /// # Errors
    ///
    /// Returns any [`Error`] produced by `FsBuilder::build`.
    #[deprecated(
        since = "0.40.0",
        note = "Prefer the more extensible ByteStream::read_from() API"
    )]
    #[cfg(feature = "rt-tokio")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rt-tokio")))]
    pub async fn from_file(file: tokio::fs::File) -> Result<Self, Error> {
        FsBuilder::new().file(file).build().await
    }

    /// Attaches `body_callback` to the underlying [`SdkBody`] (through `SdkBody::with_callback`)
    /// and returns `self` for chaining.
    pub fn with_body_callback(&mut self, body_callback: Box<dyn BodyCallback>) -> &mut Self {
        self.inner.with_body_callback(body_callback);
        self
    }

    /// Adapts this stream into a `tokio::io::AsyncRead` using `tokio_util::io::StreamReader`.
    ///
    /// Stream errors surface as `std::io::Error`s of kind `Other`.
    // Carries the same docs.rs feature badge as the other `rt-tokio`-gated methods; it was
    // previously missing here.
    #[cfg(feature = "rt-tokio")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rt-tokio")))]
    pub fn into_async_read(self) -> impl tokio::io::AsyncRead {
        tokio_util::io::StreamReader::new(self)
    }

    /// Passes the underlying [`SdkBody`] through `SdkBody::map` with `f` and wraps the result
    /// in a new `ByteStream`.
    pub fn map(self, f: impl Fn(SdkBody) -> SdkBody + Send + Sync + 'static) -> ByteStream {
        ByteStream::new(self.into_inner().map(f))
    }
}
impl Default for ByteStream {
    /// Returns a `ByteStream` over an empty body (`SdkBody::from("")`).
    fn default() -> Self {
        let empty = SdkBody::from("");
        Self::new(empty)
    }
}
impl From<SdkBody> for ByteStream {
fn from(inp: SdkBody) -> Self {
ByteStream::new(inp)
}
}
impl From<Bytes> for ByteStream {
fn from(input: Bytes) -> Self {
ByteStream::new(SdkBody::from(input))
}
}
impl From<Vec<u8>> for ByteStream {
fn from(input: Vec<u8>) -> Self {
Self::from(Bytes::from(input))
}
}
impl From<hyper::Body> for ByteStream {
fn from(input: hyper::Body) -> Self {
ByteStream::new(SdkBody::from(input))
}
}
/// An error that occurred while reading from or building a [`ByteStream`].
///
/// Wraps the boxed error produced by the underlying body (or file source).
// NOTE(review): `Display` forwards the wrapped error's message and `source()` also returns
// that same wrapped error, so reporters that walk the `source()` chain may print the message
// twice. Consider giving this type its own message if that becomes a problem.
#[derive(Debug)]
pub struct Error(Box<dyn StdError + Send + Sync + 'static>);
impl std::fmt::Display for Error {
    /// Writes the wrapped error's `Display` output unchanged.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("{}", self.0))
    }
}
impl StdError for Error {
    /// Always returns the wrapped error.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        // The typed binding drops the `Send + Sync` bounds from the trait object.
        let wrapped: &(dyn StdError + 'static) = &*self.0;
        Some(wrapped)
    }
}
impl From<Error> for std::io::Error {
    /// Wraps a `ByteStream` [`Error`] in a `std::io::Error` of kind `Other`.
    ///
    /// This lets `ByteStream` be used with adapters that require `std::io::Error`, such as
    /// `tokio_util::io::StreamReader`.
    fn from(err: Error) -> Self {
        let kind = std::io::ErrorKind::Other;
        Self::new(kind, err)
    }
}
impl futures_core::stream::Stream for ByteStream {
type Item = Result<Bytes, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().inner.poll_next(cx).map_err(|e| Error(e))
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}
/// Binary data collected from a [`ByteStream`], stored as separate segments rather than one
/// contiguous buffer.
///
/// Each chunk is kept as its own [`Bytes`] segment inside a `SegmentedBuf`. Read the data in
/// place through the [`Buf`] implementation, or call [`AggregatedBytes::into_bytes`] to get a
/// single [`Bytes`] value.
#[derive(Debug, Clone)]
pub struct AggregatedBytes(SegmentedBuf<Bytes>);
impl AggregatedBytes {
    /// Consumes `self` and returns all remaining data as a single [`Bytes`].
    ///
    /// Delegates to `Buf::copy_to_bytes`, which may copy when the data spans more than one
    /// segment.
    pub fn into_bytes(mut self) -> Bytes {
        let total = self.0.remaining();
        self.0.copy_to_bytes(total)
    }
}
impl Buf for AggregatedBytes {
fn remaining(&self) -> usize {
self.0.remaining()
}
fn chunk(&self) -> &[u8] {
self.0.chunk()
}
fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize {
self.0.chunks_vectored(dst)
}
fn advance(&mut self, cnt: usize) {
self.0.advance(cnt)
}
fn copy_to_bytes(&mut self, len: usize) -> Bytes {
self.0.copy_to_bytes(len)
}
}
pin_project! {
    /// Generic wrapper that exposes an `http_body::Body` whose data type is `Bytes` as a
    /// `Stream`, and that can buffer the whole body into [`AggregatedBytes`].
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct Inner<B> {
        // Structurally pinned so `poll_next` can call `poll_data` on a `Pin<&mut B>`.
        #[pin]
        body: B,
    }
}
impl<B> Inner<B> {
fn new(body: B) -> Self {
Self { body }
}
async fn collect(self) -> Result<AggregatedBytes, B::Error>
where
B: http_body::Body<Data = Bytes>,
{
let mut output = SegmentedBuf::new();
let body = self.body;
crate::pin_mut!(body);
while let Some(buf) = body.data().await {
output.push(buf?);
}
Ok(AggregatedBytes(output))
}
}
impl Inner<SdkBody> {
    /// Registers `body_callback` on the wrapped [`SdkBody`] through `SdkBody::with_callback`
    /// and returns `self` for chaining.
    fn with_body_callback(&mut self, body_callback: Box<dyn BodyCallback>) -> &mut Self {
        let body = &mut self.body;
        body.with_callback(body_callback);
        self
    }
}
/// Panic message used by `Inner`'s `Stream::size_hint` when a bound of the body's `u64` size
/// hint cannot be converted to `usize`. This can only happen on targets where `usize` is
/// narrower than 64 bits, for example a stream longer than `u32::MAX` bytes on a 32-bit target.
// NOTE(review): the message says "4.294Gb", but the limit is about 4.294 GB (bytes, not bits).
const SIZE_HINT_32_BIT_PANIC_MESSAGE: &str = r#"
You're running a 32-bit system and this stream's length is too large to be represented with a usize.
Please limit stream length to less than 4.294Gb or run this program on a 64-bit computer architecture.
"#;
impl<B> futures_core::stream::Stream for Inner<B>
where
    B: http_body::Body<Data = Bytes>,
{
    type Item = Result<Bytes, B::Error>;

    /// Yields the body's data chunks one at a time by delegating to `Body::poll_data`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let pinned_body = self.project().body;
        http_body::Body::poll_data(pinned_body, cx)
    }

    /// Converts the body's `u64` size hint into the `usize` bounds a `Stream` reports.
    ///
    /// # Panics
    ///
    /// Panics with `SIZE_HINT_32_BIT_PANIC_MESSAGE` if either bound does not fit in `usize`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let hint = http_body::Body::size_hint(&self.body);
        // Narrows one bound to usize; failure is only possible on sub-64-bit targets.
        let narrow = |bound: u64| -> usize {
            match bound.try_into() {
                Ok(value) => value,
                Err(_) => panic!("{}", SIZE_HINT_32_BIT_PANIC_MESSAGE),
            }
        };
        (narrow(hint.lower()), hint.upper().map(narrow))
    }
}
#[cfg(test)]
mod tests {
    use crate::byte_stream::Inner;
    use bytes::Bytes;

    // A body built from a single in-memory string collects back to exactly that string.
    #[tokio::test]
    async fn read_from_string_body() {
        let body = hyper::Body::from("a simple body");
        assert_eq!(
            Inner::new(body)
                .collect()
                .await
                .expect("no errors")
                .into_bytes(),
            Bytes::from("a simple body")
        );
    }

    // Chunks sent through a hyper channel body from a spawned task are collected in send
    // order. Presumably the body ends once the spawned task finishes and drops `sender`.
    #[tokio::test]
    async fn read_from_channel_body() {
        let (mut sender, body) = hyper::Body::channel();
        let byte_stream = Inner::new(body);
        tokio::spawn(async move {
            sender.send_data(Bytes::from("data 1")).await.unwrap();
            sender.send_data(Bytes::from("data 2")).await.unwrap();
            sender.send_data(Bytes::from("data 3")).await.unwrap();
        });
        assert_eq!(
            byte_stream.collect().await.expect("no errors").into_bytes(),
            Bytes::from("data 1data 2data 3")
        );
    }

    // A path-based stream reports an exact size hint and can be cloned with `try_clone`.
    // Each clone reads the file independently: partially reading `body1` does not affect
    // `body2` or `body3`, and `body1` later yields only the bytes it has not yet read.
    // The file holds 10,000 lines of "Brian was here. Briefly. {i}\n" (26 bytes per line
    // plus the digits of `i`), 298,890 bytes in total.
    #[cfg(feature = "rt-tokio")]
    #[tokio::test]
    async fn path_based_bytestreams() -> Result<(), Box<dyn std::error::Error>> {
        use super::ByteStream;
        use bytes::Buf;
        use http_body::Body;
        use std::io::Write;
        use tempfile::NamedTempFile;
        let mut file = NamedTempFile::new()?;
        for i in 0..10000 {
            writeln!(file, "Brian was here. Briefly. {}", i)?;
        }
        let body = ByteStream::from_path(&file).await?.into_inner();
        assert_eq!(body.size_hint().exact(), Some(298890));
        let mut body1 = body.try_clone().expect("retryable bodies are cloneable");
        // Read only the first chunk of `body1`; the rest is collected at the end.
        let some_data = body1
            .data()
            .await
            .expect("should have some data")
            .expect("read should not fail");
        assert!(!some_data.is_empty());
        let body2 = body.try_clone().expect("retryable bodies are cloneable");
        let body3 = body.try_clone().expect("retryable bodies are cloneable");
        let body2 = ByteStream::new(body2).collect().await?.into_bytes();
        let body3 = ByteStream::new(body3).collect().await?.into_bytes();
        assert_eq!(body2, body3);
        assert!(body2.starts_with(b"Brian was here."));
        assert!(body2.ends_with(b"9999\n"));
        assert_eq!(body2.len(), 298890);
        assert_eq!(
            ByteStream::new(body1).collect().await?.remaining(),
            298890 - some_data.len()
        );
        Ok(())
    }

    // `into_async_read` adapts the stream to `AsyncRead`, so it can be line-buffered with
    // tokio's `BufReader`. The final line has no trailing newline and is still returned.
    #[cfg(feature = "rt-tokio")]
    #[tokio::test]
    async fn bytestream_into_async_read() {
        use super::ByteStream;
        use tokio::io::AsyncBufReadExt;
        let byte_stream = ByteStream::from_static(b"data 1\ndata 2\ndata 3");
        let async_buf_read = tokio::io::BufReader::new(byte_stream.into_async_read());
        let mut lines = async_buf_read.lines();
        assert_eq!(lines.next_line().await.unwrap(), Some("data 1".to_owned()));
        assert_eq!(lines.next_line().await.unwrap(), Some("data 2".to_owned()));
        assert_eq!(lines.next_line().await.unwrap(), Some("data 3".to_owned()));
        assert_eq!(lines.next_line().await.unwrap(), None);
    }
}