use aws_smithy_async::future::pagination_stream::PaginationStream;
use aws_smithy_types::error::metadata::ProvideErrorMetadata;
use aws_smithy_types_convert::stream::{PaginationStreamExt, PaginationStreamImplStream};
use aws_types::request_id::RequestId;
use futures_util::Stream;
use pin_project_lite::pin_project;
use std::{
cell::Cell,
error::Error,
pin::Pin,
task::{Context, Poll},
};
use crate::{
KeyValue,
middleware::aws::{AwsSpan, AwsSpanBuilder},
};
/// Zero-sized stand-in for an AWS response, used as the `Ok` value handed to
/// `AwsSpan::end` when an instrumented stream completes without an error.
///
/// NOTE(review): the `RequestId` impl presumably exists only to satisfy the bounds
/// `AwsSpan::end` places on its argument — confirm against that signature.
struct Void;

impl RequestId for Void {
    /// There is no concrete AWS response behind a `Void`, so no request id is reported.
    fn request_id(&self) -> Option<&str> {
        Option::None
    }
}
/// Payload-free mirror of `StreamState`'s variants.
///
/// It lets callers branch on the current state without borrowing or moving the
/// span held inside `StreamState`. It is a plain tag, so it derives the usual
/// value-type traits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamStateKind {
    /// The span has been built but not started (stream not yet polled).
    Waiting,
    /// The span is running.
    Flowing,
    /// The span has been ended.
    Finished,
}
/// Lifecycle of the span attached to an [`InstrumentedStream`].
///
/// The state lives in a `Cell` and is moved out with `take()` during each transition.
/// `take()` leaves the `Default` value behind, which is why `Invalid` exists. It is
/// meant to be seen only transiently, before the next state is `set()` back.
#[derive(Default)]
enum StreamState<'a> {
    /// Placeholder left by `Cell::take`; inspecting it via `kind()` panics.
    #[default]
    Invalid,
    /// Not polled yet: the span is built but not started.
    /// Boxed — presumably to keep the enum small; TODO confirm with `size_of`.
    Waiting(Box<AwsSpanBuilder<'a>>),
    /// Polled at least once: the span has been started.
    Flowing(AwsSpan),
    /// The span has been ended.
    Finished,
}
impl<'a> StreamState<'a> {
fn new(span: impl Into<AwsSpanBuilder<'a>>) -> Self {
let span = Into::<AwsSpanBuilder>::into(span);
Self::Waiting(Box::new(
span.attribute(KeyValue::new("aws.pagination_stream", true)),
))
}
fn kind(&self) -> StreamStateKind {
match self {
StreamState::Waiting(_) => StreamStateKind::Waiting,
StreamState::Flowing(_) => StreamStateKind::Flowing,
StreamState::Finished => StreamStateKind::Finished,
StreamState::Invalid => {
panic!("Invalid instrumented stream state")
}
}
}
fn start(self) -> Self {
let Self::Waiting(span) = self else {
panic!("Instrumented stream state is not Waiting");
};
Self::Flowing(span.start())
}
fn end<E: RequestId + ProvideErrorMetadata + Error>(
self,
aws_response: &Result<Void, E>,
) -> Self {
let Self::Flowing(span) = self else {
panic!("Instrumented stream state is not Flowing");
};
span.end(aws_response);
Self::Finished
}
}
pin_project! {
    /// A `Stream` adapter that records an AWS span around the consumption of the
    /// wrapped stream.
    ///
    /// Created by [`AwsStreamInstrument::instrument`].
    pub struct InstrumentedStream<'a, S: Stream> {
        // The wrapped stream, structurally pinned so it can be polled through `Pin`.
        #[pin]
        inner: S,
        // Span lifecycle. It is a `Cell` so the current state can be moved out with
        // `take()` and replaced with `set()` during a transition.
        state: Cell<StreamState<'a>>,
    }
}
impl<T, E, S> Stream for InstrumentedStream<'_, S>
where
    E: RequestId + ProvideErrorMetadata + Error,
    S: Stream<Item = Result<T, E>>,
{
    type Item = S::Item;

    /// Polls the wrapped stream while maintaining the span's lifecycle.
    ///
    /// - On the first poll the span is started.
    /// - When the inner stream is exhausted, the span is ended as a success.
    /// - When it yields an error, the span is ended with that error, which is still
    ///   passed through to the caller.
    /// - Once the span has ended, the inner stream is never polled again and `None`
    ///   is returned.
    ///
    /// The first poll is handled exactly like every later one, so a stream that ends
    /// or fails immediately still gets its span ended and its error recorded.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        match this.state.get_mut().kind() {
            // Polling a stream after it returned `Ready(None)` may panic or misbehave,
            // so a finished instrumented stream stays finished without touching `inner`.
            StreamStateKind::Finished => return Poll::Ready(None),
            // Start the span lazily so its duration covers the real work, then fall
            // through so this poll's outcome is handled like any later one.
            StreamStateKind::Waiting => this.state.set(this.state.take().start()),
            StreamStateKind::Flowing => {}
        }
        match this.inner.poll_next(cx) {
            Poll::Ready(None) => {
                this.state.set(this.state.take().end(&Ok::<_, E>(Void)));
                Poll::Ready(None)
            }
            Poll::Ready(Some(Err(err))) => {
                let aws_result = Err(err);
                this.state.set(this.state.take().end(&aws_result));
                Poll::Ready(aws_result.err().map(Err))
            }
            // `Pending` and successful items pass through untouched.
            result => result,
        }
    }
}
/// Extension trait that wraps a stream of `Result<T, E>` items in an
/// [`InstrumentedStream`], which records an AWS span.
///
/// Implemented for every `Stream` yielding `Result<T, E>` and for the SDK's
/// `PaginationStream`.
pub trait AwsStreamInstrument<T, E, S>
where
    S: Stream<Item = Result<T, E>>,
    E: RequestId + ProvideErrorMetadata + Error,
{
    /// Wraps `self` so that consuming it is recorded on the span built from `span`.
    ///
    /// The span is not started until the returned stream is first polled.
    fn instrument<'a>(self, span: impl Into<AwsSpanBuilder<'a>>) -> InstrumentedStream<'a, S>;
}
impl<T, E, S> AwsStreamInstrument<T, E, S> for S
where
    E: RequestId + ProvideErrorMetadata + Error,
    S: Stream<Item = Result<T, E>>,
{
    /// Wraps any `Result`-yielding stream.
    ///
    /// The span is only built here; it is left in the `Waiting` state to be started
    /// later.
    fn instrument<'a>(self, span: impl Into<AwsSpanBuilder<'a>>) -> InstrumentedStream<'a, S> {
        let state = Cell::new(StreamState::new(span));
        InstrumentedStream { inner: self, state }
    }
}
impl<T, E> AwsStreamInstrument<T, E, PaginationStreamImplStream<Result<T, E>>>
    for PaginationStream<Result<T, E>>
where
    E: RequestId + ProvideErrorMetadata + Error,
{
    /// Converts the SDK pagination stream into a `Stream` with `into_stream_03x`.
    ///
    /// The converted stream is then instrumented through the blanket `Stream`
    /// implementation of this trait.
    fn instrument<'a>(
        self,
        span: impl Into<AwsSpanBuilder<'a>>,
    ) -> InstrumentedStream<'a, PaginationStreamImplStream<Result<T, E>>> {
        let converted = self.into_stream_03x();
        converted.instrument(span)
    }
}