use crate::body::SdkBody;
use crate::byte_stream::error::Error;
use bytes::Buf;
use bytes::Bytes;
use bytes_utils::SegmentedBuf;
use pin_project_lite::pin_project;
use std::future::poll_fn;
use std::io::IoSlice;
use std::pin::Pin;
use std::task::{Context, Poll};
#[cfg(feature = "rt-tokio")]
mod bytestream_util;
#[cfg(feature = "rt-tokio")]
pub use bytestream_util::Length;
pub mod error;
#[cfg(feature = "rt-tokio")]
pub use self::bytestream_util::FsBuilder;
#[cfg(feature = "http-body-0-4-x")]
pub mod http_body_0_4_x;
#[cfg(feature = "http-body-1-x")]
pub mod http_body_1_x;
pin_project! {
    /// Asynchronous stream of binary data backed by an [`SdkBody`].
    ///
    /// Chunks can be pulled one at a time with [`ByteStream::next`] / [`ByteStream::try_next`],
    /// or the whole stream can be buffered with [`ByteStream::collect`].
    #[derive(Debug)]
    pub struct ByteStream {
        // Structurally pinned so a pinned `ByteStream` can project a pinned `Inner` and poll it.
        #[pin]
        inner: Inner,
    }
}
impl ByteStream {
pub fn new(body: SdkBody) -> Self {
Self {
inner: Inner::new(body),
}
}
pub fn from_static(bytes: &'static [u8]) -> Self {
Self {
inner: Inner::new(SdkBody::from(Bytes::from_static(bytes))),
}
}
pub fn into_inner(self) -> SdkBody {
self.inner.body
}
pub async fn next(&mut self) -> Option<Result<Bytes, Error>> {
Some(self.inner.next().await?.map_err(Error::streaming))
}
#[cfg(feature = "byte-stream-poll-next")]
pub fn poll_next(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Error>>> {
self.project().inner.poll_next(cx).map_err(Error::streaming)
}
pub async fn try_next(&mut self) -> Result<Option<Bytes>, Error> {
self.next().await.transpose()
}
pub fn bytes(&self) -> Option<&[u8]> {
let Inner { body } = &self.inner;
body.bytes()
}
pub fn size_hint(&self) -> (u64, Option<u64>) {
self.inner.size_hint()
}
pub async fn collect(self) -> Result<AggregatedBytes, Error> {
self.inner.collect().await.map_err(Error::streaming)
}
#[cfg(feature = "rt-tokio")]
pub fn read_from() -> crate::byte_stream::FsBuilder {
crate::byte_stream::FsBuilder::new()
}
#[cfg(feature = "rt-tokio")]
pub async fn from_path(
path: impl AsRef<std::path::Path>,
) -> Result<Self, crate::byte_stream::error::Error> {
crate::byte_stream::FsBuilder::new()
.path(path)
.build()
.await
}
#[cfg(feature = "rt-tokio")]
pub fn into_async_read(self) -> impl tokio::io::AsyncBufRead {
struct FuturesStreamCompatByteStream(ByteStream);
impl futures_core::stream::Stream for FuturesStreamCompatByteStream {
type Item = Result<Bytes, Error>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.0.inner)
.poll_next(cx)
.map_err(Error::streaming)
}
}
tokio_util::io::StreamReader::new(FuturesStreamCompatByteStream(self))
}
pub fn map(self, f: impl Fn(SdkBody) -> SdkBody + Send + Sync + 'static) -> ByteStream {
ByteStream::new(self.into_inner().map(f))
}
}
impl Default for ByteStream {
    /// Creates a `ByteStream` backed by an empty body.
    fn default() -> Self {
        let empty = SdkBody::from("");
        Self::new(empty)
    }
}
impl From<SdkBody> for ByteStream {
fn from(inp: SdkBody) -> Self {
ByteStream::new(inp)
}
}
impl From<Bytes> for ByteStream {
fn from(input: Bytes) -> Self {
ByteStream::new(SdkBody::from(input))
}
}
impl From<Vec<u8>> for ByteStream {
fn from(input: Vec<u8>) -> Self {
Self::from(Bytes::from(input))
}
}
#[derive(Debug, Clone)]
pub struct AggregatedBytes(SegmentedBuf<Bytes>);
impl AggregatedBytes {
pub fn into_bytes(mut self) -> Bytes {
self.0.copy_to_bytes(self.0.remaining())
}
pub fn into_segments(self) -> impl Iterator<Item = Bytes> {
self.0.into_inner().into_iter()
}
pub fn to_vec(self) -> Vec<u8> {
self.0.into_inner().into_iter().flatten().collect()
}
}
impl Buf for AggregatedBytes {
fn remaining(&self) -> usize {
self.0.remaining()
}
fn chunk(&self) -> &[u8] {
self.0.chunk()
}
fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize {
self.0.chunks_vectored(dst)
}
fn advance(&mut self, cnt: usize) {
self.0.advance(cnt)
}
fn copy_to_bytes(&mut self, len: usize) -> Bytes {
self.0.copy_to_bytes(len)
}
}
pin_project! {
    /// Private wrapper around the [`SdkBody`] that backs a [`ByteStream`].
    #[derive(Debug)]
    struct Inner {
        // Structurally pinned so `Inner::poll_next` can poll the body through a pinned projection.
        #[pin]
        body: SdkBody,
    }
}
impl Inner {
fn new(body: SdkBody) -> Self {
Self { body }
}
async fn next(&mut self) -> Option<Result<Bytes, crate::body::Error>> {
let mut me = Pin::new(self);
poll_fn(|cx| me.as_mut().poll_next(cx)).await
}
fn poll_next(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, crate::body::Error>>> {
self.project().body.poll_next(cx)
}
async fn collect(self) -> Result<AggregatedBytes, crate::body::Error> {
let mut output = SegmentedBuf::new();
let body = self.body;
pin_utils::pin_mut!(body);
while let Some(buf) = body.next().await {
output.push(buf?);
}
Ok(AggregatedBytes(output))
}
fn size_hint(&self) -> (u64, Option<u64>) {
self.body.bounds_on_remaining_length()
}
}
#[cfg(all(test, feature = "rt-tokio"))]
mod tests {
    use super::{ByteStream, Inner};
    use crate::body::SdkBody;
    use bytes::Bytes;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temporary file containing exactly `contents`.
    fn temp_file_with(contents: &[u8]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(contents).unwrap();
        file
    }

    #[tokio::test]
    async fn read_from_string_body() {
        let collected = Inner::new(SdkBody::from("a simple body"))
            .collect()
            .await
            .expect("no errors");
        assert_eq!(collected.into_bytes(), Bytes::from("a simple body"));
    }

    #[tokio::test]
    async fn bytestream_into_async_read() {
        use tokio::io::AsyncBufReadExt;

        let reader = tokio::io::BufReader::new(
            ByteStream::from_static(b"data 1\ndata 2\ndata 3").into_async_read(),
        );
        let mut lines = reader.lines();
        for expected in ["data 1", "data 2", "data 3"] {
            assert_eq!(lines.next_line().await.unwrap(), Some(expected.to_owned()));
        }
        assert_eq!(lines.next_line().await.unwrap(), None);
    }

    #[tokio::test]
    async fn valid_size_hint() {
        assert_eq!(ByteStream::from_static(b"hello").size_hint().1, Some(5));
        assert_eq!(ByteStream::from_static(b"").size_hint().1, Some(0));

        // Keep each temp file bound so it outlives the stream built from its path.
        let hello_file = temp_file_with(b"hello");
        let hello = ByteStream::from_path(hello_file.path()).await.unwrap();
        assert_eq!(hello.inner.size_hint().1, Some(5));

        let empty_file = temp_file_with(b"");
        let empty = ByteStream::from_path(empty_file.path()).await.unwrap();
        assert_eq!(empty.inner.size_hint().1, Some(0));
    }

    #[tokio::test]
    async fn valid_eos() {
        assert!(!ByteStream::from_static(b"hello").inner.body.is_end_stream());

        let hello_file = temp_file_with(b"hello");
        let hello = ByteStream::from_path(hello_file.path()).await.unwrap();
        assert_eq!(hello.inner.body.content_length(), Some(5));
        assert!(!hello.inner.body.is_end_stream());

        assert!(ByteStream::from_static(b"").inner.body.is_end_stream());

        let empty_file = temp_file_with(b"");
        let empty = ByteStream::from_path(empty_file.path()).await.unwrap();
        assert_eq!(empty.inner.body.content_length(), Some(0));
        assert!(empty.inner.body.is_end_stream());
    }
}