use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_types::body::SdkBody;
use bytes::Bytes;
use http_body_1x::{Frame, SizeHint};
use pin_project_lite::pin_project;
use std::future::poll_fn;
use std::pin::{pin, Pin};
use std::task::{Context, Poll};
use tokio::sync::mpsc;
/// Creates a connected ([`Sender`], [`SdkBody`]) pair backed by a bounded channel.
///
/// The channel buffers at most one frame, so `Sender::send_data` waits until the
/// body's reader has taken the previously sent frame before it can enqueue the next.
pub(crate) fn channel_body() -> (Sender, SdkBody) {
    const CAPACITY: usize = 1;

    let (frame_tx, frame_rx) = mpsc::channel(CAPACITY);
    (
        Sender { tx: frame_tx },
        SdkBody::from_body_1_x(ChannelBody { rx: frame_rx }),
    )
}
/// Write half of a body created by [`channel_body`].
///
/// Frames sent through this handle are yielded, in send order, by the paired body.
#[derive(Debug)]
pub(crate) struct Sender {
    // Bounded channel carrying either a data frame or an error that the body
    // hands to its reader.
    tx: mpsc::Sender<Result<Frame<Bytes>, BoxError>>,
}
impl Sender {
    /// Sends `chunk` to the body as a data frame.
    ///
    /// Waits until the bounded channel has room for the frame.
    ///
    /// # Errors
    ///
    /// Returns an error if the receiving half of the channel has been closed
    /// (e.g. the body was dropped), so the frame can never be delivered.
    pub(crate) async fn send_data(&mut self, chunk: Bytes) -> Result<(), BoxError> {
        let frame = Frame::data(chunk);
        self.tx.send(Ok(frame)).await.map_err(|e| e.into())
    }

    /// Aborts the body.
    ///
    /// First tries to enqueue a `"body write aborted"` error for the reader.
    /// Then drops this sender, which closes the channel.
    ///
    /// NOTE(review): this is non-blocking and therefore best-effort. If the
    /// channel is already full (an unread frame is buffered), the error cannot
    /// be enqueued. The reader then sees the buffered frame followed by a
    /// normal end-of-stream instead of an error. If the reader is already gone,
    /// the error is irrelevant. Confirm callers tolerate a possibly truncated
    /// but "successful" body in that case.
    pub(crate) fn abort(self) {
        // `try_send` only needs `&self`; cloning the sender first bought nothing.
        // The result is deliberately ignored (best-effort, see NOTE above).
        let _ = self.tx.try_send(Err("body write aborted".into()));
    }
}
pin_project! {
    /// Read half of a body created by `channel_body`: an `http_body_1x::Body`
    /// whose frames are received from the paired `Sender`.
    struct ChannelBody {
        // Receives frames (or an error) pushed by the `Sender`. No field is
        // marked `#[pin]`, so `project()` simply yields `&mut` access to it.
        rx: mpsc::Receiver<Result<Frame<Bytes>, BoxError>>
    }
}
impl http_body_1x::Body for ChannelBody {
    type Data = Bytes;
    type Error = BoxError;

    /// Yields the next frame (or error) pushed by the paired `Sender`.
    ///
    /// Resolves to `None` once every sender has been dropped and all buffered
    /// messages have been received.
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        let this = self.project();
        this.rx.poll_recv(cx)
    }

    /// Reports end-of-stream only when the channel is both closed AND drained.
    ///
    /// `Receiver::is_closed` turns `true` as soon as the sender is dropped, even
    /// while frames (or an abort error) are still buffered. Reporting
    /// end-of-stream at that point breaks the `Body` contract: `true` must mean
    /// `poll_frame` will yield `None`. A consumer that trusts this method could
    /// then stop polling and lose the tail of the body.
    fn is_end_stream(&self) -> bool {
        self.rx.is_closed() && self.rx.is_empty()
    }

    /// The total length is not known up front, so no bounds are reported.
    fn size_hint(&self) -> SizeHint {
        SizeHint::default()
    }
}
/// Waits for the next frame of `body` and returns its data payload.
///
/// Returns:
/// - `Some(Ok(bytes))` for a data frame,
/// - `Some(Err(e))` if polling the body produced an error,
/// - `None` once the body is exhausted, or when the next frame is not a data
///   frame (e.g. trailers), which is treated as the end of the data.
pub(crate) async fn next_data_frame(body: &mut SdkBody) -> Option<Result<Bytes, BoxError>> {
    use http_body_1x::Body;

    let mut pinned = pin!(body);
    match poll_fn(|cx| pinned.as_mut().poll_frame(cx)).await? {
        // `into_data` hands back `Err(frame)` for non-data frames. `.ok()` maps
        // that to `None` without a separate `is_data` check and `unwrap`.
        Ok(frame) => frame.into_data().ok().map(Ok),
        Err(err) => Some(Err(err)),
    }
}