use crate::{
error::{ErrorKind, ErrorResponse},
http::{
headers::{HeaderName, Headers},
policies::create_public_api_span,
Context, Format, JsonFormat, Response, StatusCode,
},
sleep,
time::{Duration, OffsetDateTime},
tracing::{Span, SpanStatus},
};
use futures::{channel::oneshot, stream::unfold, Stream, StreamExt};
use serde::Deserialize;
use std::{
convert::Infallible,
fmt,
future::{Future, IntoFuture},
pin::Pin,
str::FromStr,
sync::Arc,
task::{Context as TaskContext, Poll},
};
/// Polling interval used by `PollerOptions::default()`; `get_retry_after` falls back to
/// the options' `frequency` when no retry-after header is present.
const DEFAULT_RETRY_TIME: Duration = Duration::seconds(30);
/// Smallest polling frequency accepted by `create_poller_stream`, which asserts
/// (panics) when `PollerOptions::frequency` is below this value.
const MIN_RETRY_TIME: Duration = Duration::seconds(1);
/// Where a poller is within a long-running operation.
///
/// `Initial` means no status request has been issued yet; `More(next)` carries the
/// continuation value produced by the previous in-progress response.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum PollerState<N> {
    /// No request has been issued yet.
    #[default]
    Initial,
    /// A follow-up request should be issued using the contained value.
    More(N),
}

impl<N> PollerState<N> {
    /// Transforms the continuation value with `f`, leaving `Initial` untouched.
    ///
    /// `f` is only invoked for the `More` variant.
    #[inline]
    pub fn map<U, F>(self, f: F) -> PollerState<U>
    where
        F: FnOnce(N) -> U,
    {
        if let Self::More(next) = self {
            PollerState::More(f(next))
        } else {
            PollerState::Initial
        }
    }
}

impl<N: Clone> Clone for PollerState<N> {
    /// Clones the state, cloning the continuation value when present.
    #[inline]
    fn clone(&self) -> Self {
        match self {
            Self::More(next) => Self::More(N::clone(next)),
            Self::Initial => Self::Initial,
        }
    }
}
/// Status of a long-running operation as reported by the service.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub enum PollerStatus {
    /// The operation has not finished yet.
    #[default]
    InProgress,
    /// The operation completed successfully.
    Succeeded,
    /// The operation failed.
    Failed,
    /// The operation was canceled.
    Canceled,
    /// A status string this crate does not recognize, preserved verbatim.
    UnknownValue(String),
}

impl From<&str> for PollerStatus {
    /// Parses a status string, ignoring ASCII case.
    ///
    /// Both "canceled" and "cancelled" map to [`PollerStatus::Canceled`]; anything
    /// unrecognized becomes [`PollerStatus::UnknownValue`] holding the original text.
    fn from(value: &str) -> Self {
        match value.to_ascii_lowercase().as_str() {
            "inprogress" => Self::InProgress,
            "succeeded" => Self::Succeeded,
            "failed" => Self::Failed,
            "canceled" | "cancelled" => Self::Canceled,
            _ => Self::UnknownValue(value.to_owned()),
        }
    }
}

impl FromStr for PollerStatus {
    type Err = Infallible;

    /// Infallible parse delegating to the `From<&str>` conversion.
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        Ok(Self::from(value))
    }
}
impl<'de> Deserialize<'de> for PollerStatus {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct PollerStatusVisitor;
impl serde::de::Visitor<'_> for PollerStatusVisitor {
type Value = PollerStatus;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string representing a PollerStatus")
}
fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
FromStr::from_str(s).map_err(serde::de::Error::custom)
}
}
deserializer.deserialize_str(PollerStatusVisitor)
}
}
/// Options controlling how a [`Poller`] issues status requests.
#[derive(Debug, Clone)]
pub struct PollerOptions<'a> {
    /// Context passed to every status request.
    pub context: Context<'a>,
    /// Delay between status requests when the response does not specify one.
    pub frequency: Duration,
}

impl Default for PollerOptions<'_> {
    /// An empty context with the default polling frequency.
    fn default() -> Self {
        Self {
            context: Context::new(),
            frequency: DEFAULT_RETRY_TIME,
        }
    }
}

impl<'a> PollerOptions<'a> {
    /// Consumes these options, producing a `'static` copy.
    #[must_use]
    pub fn into_owned(self) -> PollerOptions<'static> {
        let Self { context, frequency } = self;
        PollerOptions {
            context: context.into_owned(),
            frequency,
        }
    }

    /// Produces a `'static` copy of these options without consuming them.
    pub fn to_owned(&self) -> PollerOptions<'static> {
        let frequency = self.frequency;
        PollerOptions {
            context: self.context.to_owned(),
            frequency,
        }
    }
}
/// Outcome of a single status request made by a poller callback.
pub enum PollerResult<M: StatusMonitor, N, F: Format = JsonFormat> {
    /// The operation is still running; poll again after `retry_after` using `next`.
    InProgress {
        /// The status response just received.
        response: Response<M, F>,
        /// How long to wait before the next request.
        retry_after: Duration,
        /// Continuation value passed to the next request.
        next: N,
    },
    /// Polling is finished and there is no separate target response.
    Done {
        /// The final status response.
        response: Response<M, F>,
    },
    /// The operation succeeded; `target` produces the final target response.
    Succeeded {
        /// The final status response.
        response: Response<M, F>,
        /// Callback producing the future for the target response.
        target: BoxedCallback<M>,
    },
}

impl<M: StatusMonitor, N: fmt::Debug, F: Format> fmt::Debug for PollerResult<M, N, F> {
    /// Formats the variant name; only `InProgress` shows its `retry_after` and `next` fields.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InProgress {
                retry_after, next, ..
            } => {
                let mut builder = f.debug_struct("InProgress");
                builder.field("retry_after", retry_after);
                builder.field("next", next);
                builder.finish_non_exhaustive()
            }
            Self::Done { .. } => f.debug_struct("Done").finish_non_exhaustive(),
            Self::Succeeded { .. } => f.debug_struct("Succeeded").finish_non_exhaustive(),
        }
    }
}
/// Implemented by status-response models so a [`Poller`] can tell how a long-running
/// operation is progressing.
pub trait StatusMonitor {
/// Model of the final (target) response returned when the poller is awaited.
type Output;
/// Wire format of the target response; must be `Send` outside wasm32.
#[cfg(not(target_arch = "wasm32"))]
type Format: Format + Send;
/// Wire format of the target response.
#[cfg(target_arch = "wasm32")]
type Format: Format;
/// Returns the operation status carried by this status response.
fn status(&self) -> PollerStatus;
}
/// Boxed stream of status responses backing a [`Poller`]; `Send` outside wasm32.
#[cfg(not(target_arch = "wasm32"))]
type BoxedStream<M, F> = Box<dyn Stream<Item = crate::Result<Response<M, F>>> + Send>;
/// Boxed stream of status responses backing a [`Poller`].
#[cfg(target_arch = "wasm32")]
type BoxedStream<M, F> = Box<dyn Stream<Item = crate::Result<Response<M, F>>>>;
/// Boxed future resolving to the target response (`M::Output` in `M::Format`); `Send`
/// outside wasm32.
#[cfg(not(target_arch = "wasm32"))]
type BoxedFuture<M> = Box<
dyn Future<
Output = crate::Result<
Response<<M as StatusMonitor>::Output, <M as StatusMonitor>::Format>,
>,
> + Send,
>;
/// Boxed future resolving to the target response (`M::Output` in `M::Format`).
#[cfg(target_arch = "wasm32")]
type BoxedFuture<M> = Box<
dyn Future<
Output = crate::Result<
Response<<M as StatusMonitor>::Output, <M as StatusMonitor>::Format>,
>,
>,
>;
/// One-shot callback that starts fetching the target response.
#[cfg(not(target_arch = "wasm32"))]
type BoxedCallback<M> = Box<dyn FnOnce() -> Pin<BoxedFuture<M>> + Send>;
/// One-shot callback that starts fetching the target response.
#[cfg(target_arch = "wasm32")]
type BoxedCallback<M> = Box<dyn FnOnce() -> Pin<BoxedFuture<M>>>;
/// Poller for a long-running operation.
///
/// As a [`Stream`] it yields each intermediate status response; awaiting it (via
/// [`IntoFuture`]) drains the stream and then resolves the final target response.
#[pin_project::pin_project]
pub struct Poller<M, F: Format = JsonFormat>
where
M: StatusMonitor,
{
/// Status responses produced by the operation.
#[pin]
stream: Pin<BoxedStream<M, F>>,
/// Future for the final target response; `None` for pollers built with `from_stream`.
target: Option<BoxedFuture<M>>,
}
impl<M, F> Poller<M, F>
where
M: StatusMonitor,
F: Format + Send,
{
pub fn from_callback<
#[cfg(not(target_arch = "wasm32"))] N: AsRef<str> + Send + 'static,
#[cfg(not(target_arch = "wasm32"))] Fun: Fn(PollerState<N>, PollerOptions<'static>) -> Fut + Send + 'static,
#[cfg(not(target_arch = "wasm32"))] Fut: Future<Output = crate::Result<PollerResult<M, N, F>>> + Send + 'static,
#[cfg(target_arch = "wasm32")] N: AsRef<str> + 'static,
#[cfg(target_arch = "wasm32")] Fun: Fn(PollerState<N>, PollerOptions<'static>) -> Fut + 'static,
#[cfg(target_arch = "wasm32")] Fut: Future<Output = crate::Result<PollerResult<M, N, F>>> + 'static,
>(
make_request: Fun,
options: Option<PollerOptions<'static>>,
) -> Self
where
M: Send + 'static,
M::Output: Send + 'static,
M::Format: Send + 'static,
{
let options = options.unwrap_or_default();
let (stream, target) = create_poller_stream(make_request, options);
Self {
stream: Box::pin(stream),
target: Some(target),
}
}
pub fn from_stream<
#[cfg(not(target_arch = "wasm32"))] S: Stream<Item = crate::Result<Response<M, F>>> + Send + 'static,
#[cfg(target_arch = "wasm32")] S: Stream<Item = crate::Result<Response<M, F>>> + 'static,
>(
stream: S,
) -> Self {
Self {
stream: Box::pin(stream),
target: None,
}
}
}
impl<M, F: Format> Stream for Poller<M, F>
where
    M: StatusMonitor,
{
    type Item = crate::Result<Response<M, F>>;

    /// Yields the next status response; a response whose HTTP status is not one of
    /// 200/201/202/204 is converted into an `HttpResponse` error.
    fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
        let projected = self.project();
        match projected.stream.poll_next(cx) {
            Poll::Ready(Some(Ok(response))) => {
                Poll::Ready(Some(check_status_code(&response).map(|()| response)))
            }
            other => other,
        }
    }
}
#[cfg(not(target_arch = "wasm32"))]
impl<M, F: Format + 'static> IntoFuture for Poller<M, F>
where
M: StatusMonitor + 'static,
M::Output: Send + 'static,
M::Format: Send + 'static,
{
type Output = crate::Result<Response<M::Output, M::Format>>;
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;
fn into_future(mut self) -> Self::IntoFuture {
Box::pin(async move {
while let Some(result) = self.stream.next().await {
result?;
}
let target = self.target.ok_or_else(|| {
crate::Error::new(
ErrorKind::Other,
"poller completed without a target response",
)
})?;
Box::into_pin(target).await
})
}
}
/// On wasm32, awaiting a poller drains every status response and then resolves the
/// target response.
///
/// This impl is generic over the status format `F`, matching the non-wasm impl.
/// Previously only `Poller<M>` (JSON) could be awaited on wasm32.
#[cfg(target_arch = "wasm32")]
impl<M, F: Format + 'static> IntoFuture for Poller<M, F>
where
    M: StatusMonitor + 'static,
    M::Output: 'static,
    M::Format: 'static,
{
    type Output = crate::Result<Response<M::Output, M::Format>>;
    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output>>>;

    /// Drains the status stream (stopping at the first error), then awaits the
    /// target future. Errors if the poller has no target.
    fn into_future(mut self) -> Self::IntoFuture {
        Box::pin(async move {
            while let Some(result) = self.stream.next().await {
                result?;
            }
            let target = self.target.ok_or_else(|| {
                crate::Error::new(
                    ErrorKind::Other,
                    "poller completed without a target response",
                )
            })?;
            Box::into_pin(target).await
        })
    }
}
impl<M: StatusMonitor, F: Format> fmt::Debug for Poller<M, F> {
    /// Prints just the type name; the boxed stream and future have no useful `Debug` form.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Poller")
    }
}
/// Internal progress of the polling stream built by `create_poller_stream`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum State<N> {
/// No request has been made yet.
Init,
/// A follow-up request is pending, using the contained continuation value.
InProgress(N),
/// The stream has finished; the next poll yields `None`.
Done,
}
/// Payload sent to the target future: the target response future, plus the
/// span-carrying context when the poller created its own span (otherwise `None`).
type TargetTransmitterType<'a, M> = (Pin<BoxedFuture<M>>, Option<Context<'a>>);
/// State threaded through `unfold` in `create_poller_stream`.
struct StreamState<'a, M, N, Fun>
where
M: StatusMonitor,
{
/// Which request to issue next.
state: State<N>,
/// Caller-supplied callback that performs one status request.
make_request: Fun,
/// Sender for the target future; taken on the first `Succeeded` result.
target_tx: Option<oneshot::Sender<TargetTransmitterType<'a, M>>>,
/// Options (including the context) cloned into each `make_request` call.
options: PollerOptions<'a>,
/// Whether this poller created a public API span that it must end itself.
added_span: bool,
}
/// Builds the status stream and the deferred target future used by [`Poller::from_callback`].
///
/// The stream calls `make_request` once per item: first with [`PollerState::Initial`],
/// then with [`PollerState::More`] for each [`PollerResult::InProgress`]. Before each
/// in-progress response is yielded, the stream sleeps for that response's `retry_after`.
/// The stream ends after the first error, [`PollerResult::Done`], or
/// [`PollerResult::Succeeded`].
///
/// On `Succeeded`, the `target` callback's future is sent to the returned target future
/// over a oneshot channel. If the sender is dropped without that happening (for example,
/// the stream ended via `Done` or an error), the target future resolves to an
/// [`ErrorKind::Other`] error.
///
/// If a public API span is created on the first request, it is ended on a callback
/// error, on `Done`, or when the target future completes after `Succeeded`.
/// Previously it was never ended on `Done`.
///
/// # Panics
///
/// Panics if `options.frequency` is less than [`MIN_RETRY_TIME`].
fn create_poller_stream<
    M,
    F: Format,
    #[cfg(not(target_arch = "wasm32"))] N: AsRef<str> + Send + 'static,
    #[cfg(not(target_arch = "wasm32"))] Fun: Fn(PollerState<N>, PollerOptions<'static>) -> Fut + Send + 'static,
    #[cfg(not(target_arch = "wasm32"))] Fut: Future<Output = crate::Result<PollerResult<M, N, F>>> + Send + 'static,
    #[cfg(target_arch = "wasm32")] N: AsRef<str> + 'static,
    #[cfg(target_arch = "wasm32")] Fun: Fn(PollerState<N>, PollerOptions<'static>) -> Fut + 'static,
    #[cfg(target_arch = "wasm32")] Fut: Future<Output = crate::Result<PollerResult<M, N, F>>> + 'static,
>(
    make_request: Fun,
    options: PollerOptions<'static>,
) -> (
    impl Stream<Item = crate::Result<Response<M, F>>> + 'static,
    BoxedFuture<M>,
)
where
    M: StatusMonitor + 'static,
    M::Output: Send + 'static,
    M::Format: Send + 'static,
{
    let (target_tx, target_rx) = oneshot::channel();
    assert!(
        options.frequency >= MIN_RETRY_TIME,
        "minimum polling frequency is 1 second"
    );
    let stream = unfold(
        StreamState::<M, N, Fun> {
            state: State::Init,
            make_request,
            target_tx: Some(target_tx),
            options,
            added_span: false,
        },
        move |mut poller_stream_state| async move {
            let result = match poller_stream_state.state {
                State::Init => {
                    // Create the operation-level span once and carry it in the context
                    // handed to every subsequent request.
                    let span =
                        create_public_api_span(&poller_stream_state.options.context, None, None);
                    if let Some(ref s) = span {
                        poller_stream_state.added_span = true;
                        poller_stream_state.options.context =
                            poller_stream_state.options.context.with_value(s.clone());
                    }
                    (poller_stream_state.make_request)(
                        PollerState::Initial,
                        poller_stream_state.options.clone(),
                    )
                    .await
                }
                State::InProgress(n) => {
                    tracing::debug!(
                        "subsequent operation request to {:?}",
                        AsRef::<str>::as_ref(&n)
                    );
                    (poller_stream_state.make_request)(
                        PollerState::More(n),
                        poller_stream_state.options.clone(),
                    )
                    .await
                }
                State::Done => {
                    tracing::debug!("done");
                    return None;
                }
            };
            let (item, next_state) = match result {
                Err(e) => {
                    if poller_stream_state.added_span {
                        if let Some(span) =
                            poller_stream_state.options.context.value::<Arc<dyn Span>>()
                        {
                            span.set_status(SpanStatus::Error {
                                description: e.to_string(),
                            });
                            span.set_attribute("error.type", e.kind().to_string().into());
                            span.end();
                        }
                    }
                    poller_stream_state.state = State::Done;
                    return Some((Err(e), poller_stream_state));
                }
                Ok(PollerResult::InProgress {
                    response,
                    retry_after,
                    next: n,
                }) => {
                    tracing::trace!("retry poller in {}s", retry_after.whole_seconds());
                    sleep(retry_after).await;
                    (Ok(response), State::InProgress(n))
                }
                Ok(PollerResult::Done { response }) => {
                    // No target follows a `Done` result, so nothing else would end the
                    // span this poller created; end it here.
                    if poller_stream_state.added_span {
                        if let Some(span) =
                            poller_stream_state.options.context.value::<Arc<dyn Span>>()
                        {
                            record_span_response_status(span, &response);
                            span.end();
                        }
                    }
                    (Ok(response), State::Done)
                }
                Ok(PollerResult::Succeeded {
                    response,
                    target: get_target,
                }) => {
                    if let Some(tx) = poller_stream_state.target_tx.take() {
                        // The receiver may already be gone (poller used only as a
                        // stream); the target is then simply never awaited.
                        let _ = tx.send((
                            get_target(),
                            if poller_stream_state.added_span {
                                Some(poller_stream_state.options.context.clone())
                            } else {
                                None
                            },
                        ));
                    }
                    poller_stream_state.state = State::Done;
                    return Some((Ok(response), poller_stream_state));
                }
            };
            poller_stream_state.state = next_state;
            Some((item, poller_stream_state))
        },
    );
    let target = Box::new(async move {
        match target_rx.await {
            Ok((target_future, span_context)) => {
                let res = target_future.await;
                // Only a context sent by this poller carries a span it must end.
                if let Some(ctx) = span_context {
                    if let Some(span) = ctx.value::<Arc<dyn Span>>() {
                        match &res {
                            Ok(response) => record_span_response_status(span, response),
                            Err(err) => {
                                span.set_status(SpanStatus::Error {
                                    description: err.to_string(),
                                });
                                span.set_attribute("error.type", err.kind().to_string().into());
                            }
                        }
                        span.end();
                    }
                }
                res
            }
            Err(err) => Err(crate::Error::with_error(
                ErrorKind::Other,
                err,
                "poller completed without defining a target",
            )),
        }
    });
    (stream, target)
}

/// Records a response's HTTP outcome on `span` without ending it.
///
/// A 5xx status marks the span as errored. Any 4xx or 5xx status also sets the
/// `error.type` attribute to the status code.
fn record_span_response_status<T, F: Format>(span: &Arc<dyn Span>, response: &Response<T, F>) {
    if response.status().is_server_error() {
        span.set_status(SpanStatus::Error {
            description: "".to_string(),
        });
    }
    if response.status().is_client_error() || response.status().is_server_error() {
        span.set_attribute("error.type", response.status().to_string().into());
    }
}
/// Returns how long a poller should wait before its next status request.
///
/// Uses the first retry-after value found in `retry_headers` (evaluated relative to the
/// current UTC time), falling back to `options.frequency`. With the `test` feature
/// enabled, responses recorded in playback mode always yield a zero delay.
pub fn get_retry_after(
    headers: &Headers,
    retry_headers: &[HeaderName],
    options: &PollerOptions,
) -> Duration {
    let suggested =
        crate::http::policies::get_retry_after(headers, OffsetDateTime::now_utc, retry_headers);
    let duration = suggested.unwrap_or(options.frequency);

    #[cfg(feature = "test")]
    {
        use crate::test::RecordingMode;
        let in_playback = matches!(
            headers.get_optional::<RecordingMode>(),
            Ok(Some(mode)) if mode == RecordingMode::Playback
        );
        if in_playback {
            // Never actually wait while replaying recorded responses.
            if duration > Duration::ZERO {
                tracing::debug!(
                    "overriding {}s poller retry in playback",
                    duration.whole_seconds()
                );
            }
            return Duration::ZERO;
        }
    }

    duration
}
fn check_status_code<T, F: Format>(response: &Response<T, F>) -> crate::Result<()> {
let status = response.status();
match status {
StatusCode::Ok | StatusCode::Accepted | StatusCode::Created | StatusCode::NoContent => {
Ok(())
}
_ => {
let raw_response = Box::new(response.to_raw_response());
let error_code = F::deserialize(raw_response.body())
.ok()
.and_then(|err: ErrorResponse| err.error)
.and_then(|details| details.code);
Err(ErrorKind::HttpResponse {
status,
error_code,
raw_response: Some(raw_response),
}
.into_error())
}
}
}
/// Unit tests for `Poller`, driven by a `MockHttpClient` whose response depends on how
/// many requests it has already served.
#[cfg(test)]
mod tests {
use super::*;
#[cfg(feature = "xml")]
use crate::http::XmlFormat;
use crate::http::{
headers::Headers, AsyncRawResponse, HttpClient, Method, NoFormat, RawResponse, Request,
};
use azure_core_test::http::MockHttpClient;
use futures::{FutureExt as _, TryStreamExt as _};
use std::sync::{Arc, Mutex};
/// JSON status model; `target` optionally holds a URL for the final resource.
#[derive(Debug, serde::Deserialize)]
struct TestStatus {
status: String,
#[serde(default)]
target: Option<String>,
}
/// JSON model for the final (target) response.
#[derive(Debug, serde::Deserialize)]
struct TestOutput {
#[serde(default)]
id: Option<String>,
#[serde(default)]
name: Option<String>,
}
impl StatusMonitor for TestStatus {
type Output = TestOutput;
type Format = JsonFormat;
fn status(&self) -> PollerStatus {
self.status.parse().unwrap_or_default()
}
}
/// XML counterpart of `TestStatus`.
#[cfg(feature = "xml")]
#[derive(Debug, serde::Deserialize)]
struct XmlTestStatus {
status: String,
}
#[cfg(feature = "xml")]
impl StatusMonitor for XmlTestStatus {
type Output = TestOutput;
type Format = XmlFormat;
fn status(&self) -> PollerStatus {
self.status.parse().unwrap_or_default()
}
}
/// Streams InProgress then Succeeded (mapped to `Done`), then ends after two requests.
#[tokio::test]
async fn poller_succeeded() {
let call_count = Arc::new(Mutex::new(0));
let mock_client = {
let call_count = call_count.clone();
Arc::new(MockHttpClient::new(move |_| {
let call_count = call_count.clone();
async move {
let mut count = call_count.lock().unwrap();
*count += 1;
if *count == 1 {
Ok(AsyncRawResponse::from_bytes(
StatusCode::Created,
Headers::new(),
br#"{"status":"InProgress"}"#.to_vec(),
))
} else {
Ok(AsyncRawResponse::from_bytes(
StatusCode::Ok,
Headers::new(),
br#"{"status":"Succeeded"}"#.to_vec(),
))
}
}
.boxed()
}))
};
let mut poller = Poller::from_callback(
move |_, _| {
let client = mock_client.clone();
async move {
let req = Request::new("https://example.com".parse().unwrap(), Method::Get);
let raw_response = client.execute_request(&req).await?;
let (status, headers, body) = raw_response.deconstruct();
let bytes = body.collect().await?;
let test_status: TestStatus = crate::json::from_json(&bytes)?;
let response: Response<TestStatus> =
RawResponse::from_bytes(status, headers, bytes).into();
match test_status.status() {
PollerStatus::InProgress => Ok(PollerResult::InProgress {
response,
retry_after: Duration::ZERO,
next: "",
}),
_ => Ok(PollerResult::Done { response }),
}
}
},
None,
);
let first_result = poller.next().await;
assert!(first_result.is_some());
let first_response = first_result.unwrap().unwrap();
assert_eq!(first_response.status(), StatusCode::Created);
let first_body = first_response.into_model().unwrap();
assert_eq!(first_body.status(), PollerStatus::InProgress);
let second_result = poller.next().await;
assert!(second_result.is_some());
let second_response = second_result.unwrap().unwrap();
assert_eq!(second_response.status(), StatusCode::Ok);
let second_body = second_response.into_model().unwrap();
assert_eq!(second_body.status(), PollerStatus::Succeeded);
let third_result = poller.next().await;
assert!(third_result.is_none());
assert_eq!(*call_count.lock().unwrap(), 2);
}
/// A Failed status with HTTP 200 is still yielded as an `Ok` response, then the stream ends.
#[tokio::test]
async fn poller_failed() {
let call_count = Arc::new(Mutex::new(0));
let mock_client = {
let call_count = call_count.clone();
Arc::new(MockHttpClient::new(move |_| {
let call_count = call_count.clone();
async move {
let mut count = call_count.lock().unwrap();
*count += 1;
if *count == 1 {
Ok(AsyncRawResponse::from_bytes(
StatusCode::Created,
Headers::new(),
br#"{"status":"InProgress"}"#.to_vec(),
))
} else {
Ok(AsyncRawResponse::from_bytes(
StatusCode::Ok,
Headers::new(),
br#"{"status":"Failed"}"#.to_vec(),
))
}
}
.boxed()
}))
};
let mut poller = Poller::from_callback(
move |_, _| {
let client = mock_client.clone();
async move {
let req = Request::new("https://example.com".parse().unwrap(), Method::Get);
let raw_response = client
.execute_request(&req)
.await?
.try_into_raw_response()
.await?;
let (status, headers, body) = raw_response.deconstruct();
let test_status: TestStatus = crate::json::from_json(&body)?;
let response: Response<TestStatus> =
RawResponse::from_bytes(status, headers, body).into();
match test_status.status() {
PollerStatus::InProgress => Ok(PollerResult::InProgress {
response,
retry_after: Duration::ZERO,
next: "",
}),
_ => Ok(PollerResult::Done { response }),
}
}
},
None,
);
let first_result = poller.next().await;
assert!(first_result.is_some());
let first_response = first_result.unwrap().unwrap();
assert_eq!(first_response.status(), StatusCode::Created);
let first_body = first_response.into_model().unwrap();
assert_eq!(first_body.status(), PollerStatus::InProgress);
let second_result = poller.next().await;
assert!(second_result.is_some());
let second_response = second_result.unwrap().unwrap();
assert_eq!(second_response.status(), StatusCode::Ok);
let second_body = second_response.into_model().unwrap();
assert_eq!(second_body.status(), PollerStatus::Failed);
let third_result = poller.next().await;
assert!(third_result.is_none());
assert_eq!(*call_count.lock().unwrap(), 2);
}
/// A 429 response returned as `Done` is turned into an `HttpResponse` error by the stream.
#[tokio::test]
async fn poller_failed_with_http_429() {
let call_count = Arc::new(Mutex::new(0));
let mock_client = {
let call_count = call_count.clone();
Arc::new(MockHttpClient::new(move |_| {
let call_count = call_count.clone();
async move {
let mut count = call_count.lock().unwrap();
*count += 1;
if *count == 1 {
Ok(AsyncRawResponse::from_bytes(
StatusCode::Ok,
Headers::new(),
br#"{"status":"InProgress"}"#.to_vec(),
))
} else {
Ok(AsyncRawResponse::from_bytes(
StatusCode::TooManyRequests,
Headers::new(),
vec![],
))
}
}
.boxed()
}))
};
let mut poller = Poller::from_callback(
move |_, _| {
let client = mock_client.clone();
async move {
let req = Request::new("https://example.com".parse().unwrap(), Method::Get);
let raw_response = client
.execute_request(&req)
.await?
.try_into_raw_response()
.await?;
let (status, headers, body) = raw_response.deconstruct();
if status == StatusCode::Ok {
let test_status: TestStatus = crate::json::from_json(&body)?;
let response: Response<TestStatus> =
RawResponse::from_bytes(status, headers, body).into();
match test_status.status() {
PollerStatus::InProgress => Ok(PollerResult::InProgress {
response,
retry_after: Duration::ZERO,
next: "",
}),
_ => Ok(PollerResult::Done { response }),
}
} else {
let response: Response<TestStatus> =
RawResponse::from_bytes(status, headers, body).into();
Ok(PollerResult::Done { response })
}
}
},
None,
);
let first_result = poller.next().await;
assert!(first_result.is_some());
assert!(first_result.unwrap().is_ok());
let second_result = poller.next().await;
assert!(second_result.is_some());
let error = second_result.unwrap().unwrap_err();
match error.kind() {
ErrorKind::HttpResponse { status, .. } => {
assert_eq!(*status, StatusCode::TooManyRequests);
}
_ => panic!("Expected HttpResponse error, got {:?}", error.kind()),
}
assert_eq!(*call_count.lock().unwrap(), 2);
}
/// Awaiting the poller resolves the target produced by the `Succeeded` callback.
#[tokio::test]
async fn poller_into_future_succeeds() {
let call_count = Arc::new(Mutex::new(0));
let mock_client = {
let call_count = call_count.clone();
Arc::new(MockHttpClient::new(move |_| {
let call_count = call_count.clone();
async move {
let mut count = call_count.lock().unwrap();
*count += 1;
if *count == 1 {
Ok(AsyncRawResponse::from_bytes(
StatusCode::Created,
Headers::new(),
br#"{"status":"InProgress"}"#.to_vec(),
))
} else {
Ok(AsyncRawResponse::from_bytes(
StatusCode::Ok,
Headers::new(),
br#"{"status":"Succeeded","id":"op1","name":"Operation completed successfully"}"#.to_vec(),
))
}
}
.boxed()
}))
};
let poller = Poller::from_callback(
move |_, _| {
let client = mock_client.clone();
async move {
let req = Request::new("https://example.com".parse().unwrap(), Method::Get);
let raw_response = client.execute_request(&req).await?;
let (status, headers, body) = raw_response.deconstruct();
let bytes = body.collect().await?;
let test_status: TestStatus = crate::json::from_json(&bytes)?;
let response: Response<TestStatus> =
RawResponse::from_bytes(status, headers.clone(), bytes.clone()).into();
match test_status.status() {
PollerStatus::InProgress => Ok(PollerResult::InProgress {
response,
retry_after: Duration::ZERO,
next: "",
}),
PollerStatus::Succeeded => {
Ok(PollerResult::Succeeded {
response,
target: Box::new(|| {
Box::pin(async {
use crate::http::headers::Headers;
let headers = Headers::new();
let bytes = bytes::Bytes::from(
r#"{"id": "op1", "name": "Operation completed successfully"}"#,
);
Ok(RawResponse::from_bytes(StatusCode::Ok, headers, bytes)
.into())
})
}),
})
}
_ => Ok(PollerResult::Done { response }),
}
}
},
None,
);
let result = poller.await;
assert!(result.is_ok());
let response = result.unwrap();
assert_eq!(response.status(), StatusCode::Ok);
let output = response.into_model().unwrap();
assert_eq!(output.id.as_deref(), Some("op1"));
assert_eq!(
output.name.as_deref(),
Some("Operation completed successfully")
);
assert_eq!(*call_count.lock().unwrap(), 2);
}
/// The target callback issues a third request to the URL reported in the status body.
#[tokio::test]
async fn poller_into_future_with_target_url() {
let call_count = Arc::new(Mutex::new(0));
let mock_client = {
let call_count = call_count.clone();
Arc::new(MockHttpClient::new(move |req: &Request| {
let call_count = call_count.clone();
let url = req.url().to_string();
async move {
let mut count = call_count.lock().unwrap();
*count += 1;
if *count == 1 {
Ok(AsyncRawResponse::from_bytes(
StatusCode::Accepted,
Headers::new(),
br#"{"status":"InProgress"}"#.to_vec(),
))
} else if *count == 2 {
Ok(AsyncRawResponse::from_bytes(
StatusCode::Ok,
Headers::new(),
br#"{"status":"Succeeded","target":"https://example.com/resources/123"}"#.to_vec(),
))
} else {
assert_eq!(url, "https://example.com/resources/123");
Ok(AsyncRawResponse::from_bytes(
StatusCode::Ok,
Headers::new(),
br#"{"id":"123","name":"Test Resource"}"#.to_vec(),
))
}
}
.boxed()
}))
};
let poller = Poller::from_callback(
move |_, _| {
let client = mock_client.clone();
async move {
let req = Request::new(
"https://example.com/operations/op1".parse().unwrap(),
Method::Get,
);
let raw_response = client.execute_request(&req).await?;
let (status, headers, body) = raw_response.deconstruct();
let bytes = body.collect().await?;
let operation_status: TestStatus = crate::json::from_json(&bytes)?;
let response: Response<TestStatus> =
RawResponse::from_bytes(status, headers.clone(), bytes.clone()).into();
match operation_status.status() {
PollerStatus::InProgress => Ok(PollerResult::InProgress {
response,
retry_after: Duration::ZERO,
next: "",
}),
PollerStatus::Succeeded => {
if let Some(target_url) = operation_status.target {
let client_clone = client.clone();
Ok(PollerResult::Succeeded {
response,
target: Box::new(move || {
Box::pin(async move {
let target_req = Request::new(
target_url.parse().unwrap(),
Method::Get,
);
let target_response =
client_clone.execute_request(&target_req).await?;
let (target_status, target_headers, target_body) =
target_response.deconstruct();
let target_bytes = target_body.collect().await?;
Ok(RawResponse::from_bytes(
target_status,
target_headers,
target_bytes,
)
.into())
})
}),
})
} else {
Err(crate::Error::new(
ErrorKind::Other,
"no target URL in succeeded response",
))
}
}
_ => Ok(PollerResult::Done { response }),
}
}
},
None,
);
let result = poller.await;
assert!(result.is_ok());
let response = result.unwrap();
assert_eq!(response.status(), StatusCode::Ok);
let resource = response.into_model().unwrap();
assert_eq!(resource.id.as_deref(), Some("123"));
assert_eq!(resource.name.as_deref(), Some("Test Resource"));
assert_eq!(*call_count.lock().unwrap(), 3);
}
/// A target whose `Output` is `()` with `NoFormat` (empty body) still resolves successfully.
#[tokio::test]
async fn poller_into_future_no_response_body() {
#[derive(Debug, serde::Deserialize)]
struct NoBodyStatus {
status: String,
}
impl StatusMonitor for NoBodyStatus {
type Output = ();
type Format = NoFormat;
fn status(&self) -> PollerStatus {
self.status.parse().unwrap_or_default()
}
}
let call_count = Arc::new(Mutex::new(0));
let mock_client = {
let call_count = call_count.clone();
Arc::new(MockHttpClient::new(move |_| {
let call_count = call_count.clone();
async move {
let mut count = call_count.lock().unwrap();
*count += 1;
if *count == 1 {
Ok(AsyncRawResponse::from_bytes(
StatusCode::Accepted,
Headers::new(),
br#"{"status":"InProgress"}"#.to_vec(),
))
} else {
Ok(AsyncRawResponse::from_bytes(
StatusCode::Ok,
Headers::new(),
br#"{"status":"Succeeded"}"#.to_vec(),
))
}
}
.boxed()
}))
};
let poller = Poller::from_callback(
move |_, _| {
let client = mock_client.clone();
async move {
let req = Request::new("https://example.com".parse().unwrap(), Method::Get);
let raw_response = client.execute_request(&req).await?;
let (status, headers, body) = raw_response.deconstruct();
let bytes = body.collect().await?;
let no_body_status: NoBodyStatus = crate::json::from_json(&bytes)?;
let response: Response<NoBodyStatus> =
RawResponse::from_bytes(status, headers.clone(), bytes.clone()).into();
match no_body_status.status() {
PollerStatus::InProgress => Ok(PollerResult::InProgress {
response,
retry_after: Duration::ZERO,
next: "",
}),
PollerStatus::Succeeded => {
Ok(PollerResult::Succeeded {
response,
target: Box::new(move || {
Box::pin(async move {
use crate::http::headers::Headers;
let headers = Headers::new();
Ok(RawResponse::from_bytes(status, headers, Vec::new())
.into())
})
}),
})
}
_ => Ok(PollerResult::Done { response }),
}
}
},
None,
);
let result = poller.await;
assert!(result.is_ok());
let response = result.unwrap();
assert_eq!(response.status(), StatusCode::Ok);
assert_eq!(*call_count.lock().unwrap(), 2);
}
/// XML variant of `poller_succeeded`.
#[cfg(feature = "xml")]
#[tokio::test]
async fn poller_succeeded_xml() {
let call_count = Arc::new(Mutex::new(0));
let mock_client = {
let call_count = call_count.clone();
Arc::new(MockHttpClient::new(move |_| {
let call_count = call_count.clone();
async move {
let mut count = call_count.lock().unwrap();
*count += 1;
if *count == 1 {
Ok(AsyncRawResponse::from_bytes(
StatusCode::Created,
Headers::new(),
b"<XmlTestStatus><status>InProgress</status></XmlTestStatus>".to_vec(),
))
} else {
Ok(AsyncRawResponse::from_bytes(
StatusCode::Ok,
Headers::new(),
b"<XmlTestStatus><status>Succeeded</status></XmlTestStatus>".to_vec(),
))
}
}
.boxed()
}))
};
let mut poller = Poller::from_callback(
move |_, _| {
let client = mock_client.clone();
async move {
let req = Request::new("https://example.com".parse().unwrap(), Method::Get);
let raw_response = client.execute_request(&req).await?;
let (status, headers, body) = raw_response.deconstruct();
let bytes = body.collect().await?;
let test_status: XmlTestStatus = crate::xml::from_xml(&bytes)?;
let response: Response<XmlTestStatus, XmlFormat> =
RawResponse::from_bytes(status, headers, bytes).into();
match test_status.status() {
PollerStatus::InProgress => Ok(PollerResult::InProgress {
response,
retry_after: Duration::ZERO,
next: "",
}),
_ => Ok(PollerResult::Done { response }),
}
}
},
None,
);
let first_result = poller.next().await;
assert!(first_result.is_some());
let first_response = first_result.unwrap().unwrap();
assert_eq!(first_response.status(), StatusCode::Created);
let first_body = first_response.into_model().unwrap();
assert_eq!(first_body.status(), PollerStatus::InProgress);
let second_result = poller.next().await;
assert!(second_result.is_some());
let second_response = second_result.unwrap().unwrap();
assert_eq!(second_response.status(), StatusCode::Ok);
let second_body = second_response.into_model().unwrap();
assert_eq!(second_body.status(), PollerStatus::Succeeded);
let third_result = poller.next().await;
assert!(third_result.is_none());
assert_eq!(*call_count.lock().unwrap(), 2);
}
/// XML variant of `poller_into_future_succeeds`.
#[cfg(feature = "xml")]
#[tokio::test]
async fn poller_into_future_succeeds_xml() {
let call_count = Arc::new(Mutex::new(0));
let mock_client = {
let call_count = call_count.clone();
Arc::new(MockHttpClient::new(move |_| {
let call_count = call_count.clone();
async move {
let mut count = call_count.lock().unwrap();
*count += 1;
if *count == 1 {
Ok(AsyncRawResponse::from_bytes(
StatusCode::Created,
Headers::new(),
b"<XmlTestStatus><status>InProgress</status></XmlTestStatus>"
.to_vec(),
))
} else {
Ok(AsyncRawResponse::from_bytes(
StatusCode::Ok,
Headers::new(),
b"<XmlTestStatus><status>Succeeded</status><id>op1</id><name>Operation completed successfully</name></XmlTestStatus>"
.to_vec(),
))
}
}
.boxed()
}))
};
let poller = Poller::from_callback(
move |_, _| {
let client = mock_client.clone();
async move {
let req = Request::new("https://example.com".parse().unwrap(), Method::Get);
let raw_response = client.execute_request(&req).await?;
let (status, headers, body) = raw_response.deconstruct();
let bytes = body.collect().await?;
let test_status: XmlTestStatus = crate::xml::from_xml(&bytes)?;
let response: Response<XmlTestStatus, XmlFormat> =
RawResponse::from_bytes(status, headers.clone(), bytes.clone()).into();
match test_status.status() {
PollerStatus::InProgress => Ok(PollerResult::InProgress {
response,
retry_after: Duration::ZERO,
next: "",
}),
PollerStatus::Succeeded => {
Ok(PollerResult::Succeeded {
response,
target: Box::new(move || {
Box::pin(async move {
let headers = Headers::new();
let bytes = bytes::Bytes::from(
r#"<TestOutput><id>op1</id><name>Operation completed successfully</name></TestOutput>"#,
);
Ok(RawResponse::from_bytes(StatusCode::Ok, headers, bytes)
.into())
})
}),
})
}
_ => Ok(PollerResult::Done { response }),
}
}
},
None,
);
let result = poller.await;
assert!(result.is_ok());
let response = result.unwrap();
assert_eq!(response.status(), StatusCode::Ok);
let output = response.into_model().unwrap();
assert_eq!(output.id.as_deref(), Some("op1"));
assert_eq!(
output.name.as_deref(),
Some("Operation completed successfully")
);
assert_eq!(*call_count.lock().unwrap(), 2);
}
/// When `Output = Self`, awaiting returns the final status body re-served by the target.
#[tokio::test]
async fn poller_into_future_output_is_self() {
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
struct SelfContainedStatus {
status: String,
id: Option<String>,
result: Option<String>,
}
impl StatusMonitor for SelfContainedStatus {
type Output = Self; type Format = JsonFormat;
fn status(&self) -> PollerStatus {
self.status.parse().unwrap_or_default()
}
}
let call_count = Arc::new(Mutex::new(0));
let mock_client = {
let call_count = call_count.clone();
Arc::new(MockHttpClient::new(move |_| {
let call_count = call_count.clone();
async move {
let mut count = call_count.lock().unwrap();
*count += 1;
if *count == 1 {
Ok(AsyncRawResponse::from_bytes(
StatusCode::Created,
Headers::new(),
br#"{"status":"InProgress","id":"op1"}"#.to_vec(),
))
} else {
Ok(AsyncRawResponse::from_bytes(
StatusCode::Ok,
Headers::new(),
br#"{"status":"Succeeded","id":"op1","result":"Operation completed successfully"}"#.to_vec(),
))
}
}
.boxed()
}))
};
let poller = Poller::from_callback(
move |_, _| {
let client = mock_client.clone();
async move {
let req = Request::new("https://example.com".parse().unwrap(), Method::Get);
let raw_response = client.execute_request(&req).await?;
let (status, headers, body) = raw_response.deconstruct();
let bytes = body.collect().await?;
let self_status: SelfContainedStatus = crate::json::from_json(&bytes)?;
let response: Response<SelfContainedStatus> =
RawResponse::from_bytes(status, headers.clone(), bytes.clone()).into();
match self_status.status() {
PollerStatus::InProgress => Ok(PollerResult::InProgress {
response,
retry_after: Duration::ZERO,
next: "",
}),
PollerStatus::Succeeded => {
let final_bytes = bytes.clone();
Ok(PollerResult::Succeeded {
response,
target: Box::new(move || {
Box::pin(async move {
let headers = Headers::new();
Ok(RawResponse::from_bytes(
StatusCode::Ok,
headers,
final_bytes,
)
.into())
})
}),
})
}
_ => Ok(PollerResult::Done { response }),
}
}
},
None,
);
let result = poller.await;
assert!(result.is_ok());
let response = result.unwrap();
assert_eq!(response.status(), StatusCode::Ok);
let output = response.into_model().unwrap();
assert_eq!(output.id.as_deref(), Some("op1"));
assert_eq!(
output.result.as_deref(),
Some("Operation completed successfully")
);
assert_eq!(*call_count.lock().unwrap(), 2);
}
/// With `Output = Self`, consuming the stream yields both status bodies in order.
#[tokio::test]
async fn poller_stream_output_is_self() {
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
struct SelfContainedStatus {
status: String,
id: Option<String>,
result: Option<String>,
}
impl StatusMonitor for SelfContainedStatus {
type Output = Self; type Format = JsonFormat;
fn status(&self) -> PollerStatus {
self.status.parse().unwrap_or_default()
}
}
let call_count = Arc::new(Mutex::new(0));
let mock_client = {
let call_count = call_count.clone();
Arc::new(MockHttpClient::new(move |_| {
let call_count = call_count.clone();
async move {
let mut count = call_count.lock().unwrap();
*count += 1;
if *count == 1 {
Ok(AsyncRawResponse::from_bytes(
StatusCode::Created,
Headers::new(),
br#"{"status":"InProgress","id":"op1"}"#.to_vec(),
))
} else {
Ok(AsyncRawResponse::from_bytes(
StatusCode::Ok,
Headers::new(),
br#"{"status":"Succeeded","id":"op1","result":"Operation completed successfully"}"#.to_vec(),
))
}
}
.boxed()
}))
};
let mut poller = Poller::from_callback(
move |_, _| {
let client = mock_client.clone();
async move {
let req = Request::new("https://example.com".parse().unwrap(), Method::Get);
let raw_response = client.execute_request(&req).await?;
let (status, headers, body) = raw_response.deconstruct();
let bytes = body.collect().await?;
let self_status: SelfContainedStatus = crate::json::from_json(&bytes)?;
let response: Response<SelfContainedStatus> =
RawResponse::from_bytes(status, headers.clone(), bytes.clone()).into();
match self_status.status() {
PollerStatus::InProgress => Ok(PollerResult::InProgress {
response,
retry_after: Duration::ZERO,
next: "",
}),
PollerStatus::Succeeded => {
let final_bytes = bytes.clone();
Ok(PollerResult::Succeeded {
response,
target: Box::new(move || {
Box::pin(async move {
use crate::http::headers::Headers;
let headers = Headers::new();
Ok(RawResponse::from_bytes(
StatusCode::Ok,
headers,
final_bytes,
)
.into())
})
}),
})
}
_ => Ok(PollerResult::Done { response }),
}
}
},
None,
);
let mut statuses = Vec::new();
while let Some(status_response) = poller.try_next().await.unwrap() {
let status = status_response.into_model().unwrap();
statuses.push(status);
}
assert_eq!(statuses.len(), 2);
assert_eq!(statuses[0].status, "InProgress");
assert_eq!(statuses[0].id.as_deref(), Some("op1"));
assert_eq!(statuses[0].result, None);
assert_eq!(statuses[1].status, "Succeeded");
assert_eq!(statuses[1].id.as_deref(), Some("op1"));
assert_eq!(
statuses[1].result.as_deref(),
Some("Operation completed successfully")
);
assert_eq!(*call_count.lock().unwrap(), 2);
}
}