use crate::{Poller, PollingBackoffPolicy, PollingErrorPolicy, PollingResult, Result};
use google_cloud_gax::polling_state::PollingState;
use google_cloud_wkt::Empty;
use google_cloud_wkt::message::Message;
use std::sync::Arc;
/// Public alias for `crate::details::Operation`.
///
/// The first parameter is the operation's response type and the second is its
/// metadata type, matching how the pollers in this module use it.
pub type Operation<Response, Metadata> = crate::details::Operation<Response, Metadata>;
/// Builds the default [`Poller`] for a long-running operation.
///
/// * `polling_error_policy` - consulted when a `query` call fails.
/// * `polling_backoff_policy` - controls the wait between attempts in
///   `until_done()`.
/// * `start` - called at most once, on the first poll, to begin the operation.
/// * `query` - called with the operation name on every later poll.
pub fn new_poller<ResponseType, MetadataType, S, SF, Q, QF>(
    polling_error_policy: Arc<dyn PollingErrorPolicy>,
    polling_backoff_policy: Arc<dyn PollingBackoffPolicy>,
    start: S,
    query: Q,
) -> impl Poller<ResponseType, MetadataType>
where
    ResponseType: Message + serde::ser::Serialize + serde::de::DeserializeOwned + Send,
    MetadataType: Message + serde::ser::Serialize + serde::de::DeserializeOwned + Send,
    S: FnOnce() -> SF + Send + Sync,
    SF: std::future::Future<Output = Result<Operation<ResponseType, MetadataType>>>
        + Send
        + 'static,
    Q: Fn(String) -> QF + Send + Sync + Clone,
    QF: std::future::Future<Output = Result<Operation<ResponseType, MetadataType>>>
        + Send
        + 'static,
{
    PollerImpl::<S, Q>::new(polling_error_policy, polling_backoff_policy, start, query)
}
/// Builds a [`Poller`] for operations whose response type is [`Empty`].
///
/// The wrapped poller yields [`Empty`] responses. This adapter reports them as
/// `()` and passes metadata through untouched.
pub fn new_unit_response_poller<MetadataType, S, SF, Q, QF>(
    polling_error_policy: Arc<dyn PollingErrorPolicy>,
    polling_backoff_policy: Arc<dyn PollingBackoffPolicy>,
    start: S,
    query: Q,
) -> impl Poller<(), MetadataType>
where
    MetadataType: Message + serde::ser::Serialize + serde::de::DeserializeOwned + Send,
    S: FnOnce() -> SF + Send + Sync,
    SF: std::future::Future<Output = Result<Operation<Empty, MetadataType>>> + Send + 'static,
    Q: Fn(String) -> QF + Send + Sync + Clone,
    QF: std::future::Future<Output = Result<Operation<Empty, MetadataType>>> + Send + 'static,
{
    UnitResponsePoller::new(new_poller(
        polling_error_policy,
        polling_backoff_policy,
        start,
        query,
    ))
}
/// Builds a [`Poller`] for operations whose metadata type is [`Empty`].
///
/// The wrapped poller yields [`Empty`] metadata. This adapter reports it as
/// `()` and passes the response through untouched.
pub fn new_unit_metadata_poller<ResponseType, S, SF, Q, QF>(
    polling_error_policy: Arc<dyn PollingErrorPolicy>,
    polling_backoff_policy: Arc<dyn PollingBackoffPolicy>,
    start: S,
    query: Q,
) -> impl Poller<ResponseType, ()>
where
    ResponseType: Message + serde::ser::Serialize + serde::de::DeserializeOwned + Send,
    S: FnOnce() -> SF + Send + Sync,
    SF: std::future::Future<Output = Result<Operation<ResponseType, Empty>>> + Send + 'static,
    Q: Fn(String) -> QF + Send + Sync + Clone,
    QF: std::future::Future<Output = Result<Operation<ResponseType, Empty>>> + Send + 'static,
{
    UnitMetadataPoller::new(new_poller(
        polling_error_policy,
        polling_backoff_policy,
        start,
        query,
    ))
}
/// Builds a [`Poller`] for operations whose response and metadata are both
/// [`Empty`].
///
/// Two adapters are stacked. The inner one maps the metadata to `()` and the
/// outer one maps the response to `()`.
pub fn new_unit_poller<S, SF, Q, QF>(
    polling_error_policy: Arc<dyn PollingErrorPolicy>,
    polling_backoff_policy: Arc<dyn PollingBackoffPolicy>,
    start: S,
    query: Q,
) -> impl Poller<(), ()>
where
    S: FnOnce() -> SF + Send + Sync,
    SF: std::future::Future<Output = Result<Operation<Empty, Empty>>> + Send + 'static,
    Q: Fn(String) -> QF + Send + Sync + Clone,
    QF: std::future::Future<Output = Result<Operation<Empty, Empty>>> + Send + 'static,
{
    let without_metadata = UnitMetadataPoller::new(new_poller(
        polling_error_policy,
        polling_backoff_policy,
        start,
        query,
    ));
    UnitResponsePoller::new(without_metadata)
}
struct UnitResponsePoller<P> {
poller: P,
}
impl<P> UnitResponsePoller<P> {
pub(crate) fn new(poller: P) -> Self {
Self { poller }
}
}
impl<P> crate::sealed::Poller for UnitResponsePoller<P> {}
impl<P, M> Poller<(), M> for UnitResponsePoller<P>
where
P: Poller<Empty, M>,
{
async fn poll(&mut self) -> Option<PollingResult<(), M>> {
self.poller.poll().await.map(self::map_polling_result)
}
async fn until_done(self) -> Result<()> {
self.poller.until_done().await.map(|_| ())
}
#[cfg(feature = "unstable-stream")]
fn into_stream(self) -> impl futures::Stream<Item = PollingResult<(), M>> + Unpin {
use futures::StreamExt;
self.poller.into_stream().map(self::map_polling_result)
}
}
struct UnitMetadataPoller<P> {
poller: P,
}
impl<P> UnitMetadataPoller<P> {
pub(crate) fn new(poller: P) -> Self {
Self { poller }
}
}
impl<P> crate::sealed::Poller for UnitMetadataPoller<P> {}
impl<P, R> Poller<R, ()> for UnitMetadataPoller<P>
where
P: Poller<R, Empty>,
{
async fn poll(&mut self) -> Option<PollingResult<R, ()>> {
self.poller.poll().await.map(self::map_polling_metadata)
}
async fn until_done(self) -> Result<R> {
self.poller.until_done().await
}
#[cfg(feature = "unstable-stream")]
fn into_stream(self) -> impl futures::Stream<Item = PollingResult<R, ()>> + Unpin {
use futures::StreamExt;
self.poller.into_stream().map(self::map_polling_metadata)
}
}
fn map_polling_result<M>(result: PollingResult<Empty, M>) -> PollingResult<(), M> {
match result {
PollingResult::Completed(r) => PollingResult::Completed(r.map(|_| ())),
PollingResult::InProgress(m) => PollingResult::InProgress(m),
PollingResult::PollingError(e) => PollingResult::PollingError(e),
}
}
fn map_polling_metadata<R>(result: PollingResult<R, Empty>) -> PollingResult<R, ()> {
match result {
PollingResult::Completed(r) => PollingResult::Completed(r),
PollingResult::InProgress(m) => PollingResult::InProgress(m.map(|_| ())),
PollingResult::PollingError(e) => PollingResult::PollingError(e),
}
}
/// The default [`Poller`] implementation.
///
/// The first call to `poll()` consumes `start`. After that, while `operation`
/// holds a name, each `poll()` passes that name to `query`. Once `operation`
/// is `None` again, `poll()` returns `None`.
struct PollerImpl<S, Q> {
    /// Handed to `crate::details::handle_poll` on every query.
    error_policy: Arc<dyn PollingErrorPolicy>,
    /// Decides how long `until_done()` sleeps between polls.
    backoff_policy: Arc<dyn PollingBackoffPolicy>,
    /// Starts the operation; `None` once it has been called.
    start: Option<S>,
    /// Fetches the latest state of the named operation.
    query: Q,
    /// Name of the operation still to be polled, if any.
    operation: Option<String>,
    /// Attempt bookkeeping handed to `handle_poll` alongside the error policy.
    state: PollingState,
}

impl<S, Q> PollerImpl<S, Q> {
    /// Creates a poller that has not started the operation yet.
    pub fn new(
        error_policy: Arc<dyn PollingErrorPolicy>,
        backoff_policy: Arc<dyn PollingBackoffPolicy>,
        start: S,
        query: Q,
    ) -> Self {
        let start = Some(start);
        let state = PollingState::default();
        Self {
            start,
            query,
            operation: None,
            state,
            error_policy,
            backoff_policy,
        }
    }
}

impl<ResponseType, MetadataType, S, SF, P, PF> Poller<ResponseType, MetadataType>
    for PollerImpl<S, P>
where
    ResponseType: Message + serde::ser::Serialize + serde::de::DeserializeOwned + Send,
    MetadataType: Message + serde::ser::Serialize + serde::de::DeserializeOwned + Send,
    S: FnOnce() -> SF + Send + Sync,
    SF: std::future::Future<Output = Result<Operation<ResponseType, MetadataType>>>
        + Send
        + 'static,
    P: Fn(String) -> PF + Send + Sync + Clone,
    PF: std::future::Future<Output = Result<Operation<ResponseType, MetadataType>>>
        + Send
        + 'static,
{
    async fn poll(&mut self) -> Option<PollingResult<ResponseType, MetadataType>> {
        // First call: run `start` exactly once.
        if let Some(start) = self.start.take() {
            let (pending, outcome) = crate::details::handle_start(start().await);
            self.operation = pending;
            return Some(outcome);
        }
        // Nothing left to poll: the stream of results is exhausted.
        let name = self.operation.take()?;
        self.state.attempt_count += 1;
        let response = (self.query)(name.clone()).await;
        let (pending, outcome) =
            crate::details::handle_poll(self.error_policy.clone(), &self.state, name, response);
        self.operation = pending;
        Some(outcome)
    }

    async fn until_done(mut self) -> Result<ResponseType> {
        // Backoff bookkeeping is tracked separately from `self.state`.
        let mut backoff_state = PollingState::default();
        loop {
            // `poll()` only yields `None` after it has yielded `Completed`,
            // and that case returns below, so this `else` cannot run.
            let Some(outcome) = self.poll().await else {
                unreachable!("loop should exit via the `Completed` branch vs. this line");
            };
            if let PollingResult::Completed(result) = outcome {
                return result;
            }
            // `InProgress` or a recoverable `PollingError`: wait, then retry.
            backoff_state.attempt_count += 1;
            tokio::time::sleep(self.backoff_policy.wait_period(&backoff_state)).await;
        }
    }

    #[cfg(feature = "unstable-stream")]
    fn into_stream(
        self,
    ) -> impl futures::Stream<Item = PollingResult<ResponseType, MetadataType>> + Unpin
    where
        ResponseType: Message + serde::de::DeserializeOwned,
        MetadataType: Message + serde::de::DeserializeOwned,
    {
        // Feed the poller through `unfold`, ending the stream when `poll()` returns `None`.
        let stream = futures::stream::unfold(self, |mut poller| async move {
            let item = poller.poll().await?;
            Some((item, poller))
        });
        Box::pin(stream)
    }
}

impl<S, Q> crate::sealed::Poller for PollerImpl<S, Q> {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Error;
    use google_cloud_gax::error::rpc::{Code, Status};
    use google_cloud_gax::exponential_backoff::ExponentialBackoff;
    use google_cloud_gax::exponential_backoff::ExponentialBackoffBuilder;
    use google_cloud_gax::polling_error_policy::{Aip194Strict, AlwaysContinue};
    use google_cloud_longrunning::model::{
        Operation as OperationAny, operation::Result as ResultAny,
    };
    use google_cloud_wkt::{Any, Duration, Timestamp};
    use std::time::Duration as StdDuration;

    type ResponseType = Duration;
    type MetadataType = Timestamp;
    type TestOperation = Operation<ResponseType, MetadataType>;
    type EmptyResponseOperation = Operation<Empty, MetadataType>;
    type EmptyMetadataOperation = Operation<ResponseType, Empty>;

    #[tokio::test(flavor = "multi_thread")]
    async fn poll_basic_flow() {
        let start = || async move {
            let any = Any::from_msg(&Timestamp::clamp(123, 0))
                .expect("test message deserializes via Any::from_msg");
            let op = OperationAny::default()
                .set_name("test-only-name")
                .set_metadata(any);
            let op = TestOperation::new(op);
            Ok::<TestOperation, Error>(op)
        };
        let query = |_: String| async move {
            let any = Any::from_msg(&Duration::clamp(234, 0))
                .expect("test message deserializes via Any::from_msg");
            let result = ResultAny::Response(any.into());
            let op = OperationAny::default().set_done(true).set_result(result);
            let op = TestOperation::new(op);
            Ok::<TestOperation, Error>(op)
        };
        let mut poller = PollerImpl::new(
            Arc::new(AlwaysContinue),
            Arc::new(ExponentialBackoff::default()),
            start,
            query,
        );
        let p0 = poller.poll().await;
        match p0.unwrap() {
            PollingResult::InProgress(m) => {
                assert_eq!(m, Some(Timestamp::clamp(123, 0)));
            }
            r => {
                panic!("{r:?}");
            }
        }
        let p1 = poller.poll().await;
        match p1.unwrap() {
            PollingResult::Completed(r) => {
                let response = r.unwrap();
                assert_eq!(response, Duration::clamp(234, 0));
            }
            r => {
                panic!("{r:?}");
            }
        }
        let p2 = poller.poll().await;
        assert!(p2.is_none(), "{p2:?}");
    }

    // `into_stream()` and `futures::StreamExt` exist only with this feature,
    // so the stream tests must be gated the same way to compile without it.
    #[cfg(feature = "unstable-stream")]
    #[tokio::test(flavor = "multi_thread")]
    async fn poll_basic_stream() {
        let start = || async move {
            let any = Any::from_msg(&Timestamp::clamp(123, 0))
                .expect("test message deserializes via Any::from_msg");
            let op = OperationAny::default()
                .set_name("test-only-name")
                .set_metadata(any);
            let op = TestOperation::new(op);
            Ok::<TestOperation, Error>(op)
        };
        let query = |_: String| async move {
            let any = Any::from_msg(&Duration::clamp(234, 0))
                .expect("test message deserializes via Any::from_msg");
            let result = ResultAny::Response(any.into());
            let op = OperationAny::default().set_done(true).set_result(result);
            let op = TestOperation::new(op);
            Ok::<TestOperation, Error>(op)
        };
        use futures::StreamExt;
        let mut stream = new_poller(
            Arc::new(AlwaysContinue),
            Arc::new(ExponentialBackoff::default()),
            start,
            query,
        )
        .into_stream();
        let p0 = stream.next().await;
        match p0.unwrap() {
            PollingResult::InProgress(m) => {
                assert_eq!(m, Some(Timestamp::clamp(123, 0)));
            }
            r => {
                panic!("{r:?}");
            }
        }
        let p1 = stream.next().await;
        match p1.unwrap() {
            PollingResult::Completed(r) => {
                let response = r.unwrap();
                assert_eq!(response, Duration::clamp(234, 0));
            }
            r => {
                panic!("{r:?}");
            }
        }
        let p2 = stream.next().await;
        assert!(p2.is_none(), "{p2:?}");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn until_done_basic_flow() -> Result<()> {
        let start = || async move {
            let any = Any::from_msg(&Timestamp::clamp(123, 0))
                .expect("test message deserializes via Any::from_msg");
            let op = OperationAny::default()
                .set_name("test-only-name")
                .set_metadata(any);
            let op = TestOperation::new(op);
            Ok::<TestOperation, Error>(op)
        };
        let query = |_: String| async move {
            let any = Any::from_msg(&Duration::clamp(234, 0))
                .expect("test message deserializes via Any::from_msg");
            let result = ResultAny::Response(any.into());
            let op = OperationAny::default().set_done(true).set_result(result);
            let op = TestOperation::new(op);
            Ok::<TestOperation, Error>(op)
        };
        let poller = PollerImpl::new(
            Arc::new(AlwaysContinue),
            Arc::new(
                ExponentialBackoffBuilder::new()
                    .with_initial_delay(StdDuration::from_millis(1))
                    .clamp(),
            ),
            start,
            query,
        );
        let response = poller.until_done().await?;
        assert_eq!(response, Duration::clamp(234, 0));
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn unit_poll_basic_flow() {
        let start = || async move {
            let any = Any::from_msg(&Timestamp::clamp(123, 0))
                .expect("test message deserializes via Any::from_msg");
            let op = OperationAny::default()
                .set_name("test-only-name")
                .set_metadata(any);
            let op = EmptyResponseOperation::new(op);
            Ok::<EmptyResponseOperation, Error>(op)
        };
        let query = |_: String| async move {
            let any = Any::from_msg(&Empty::default())
                .expect("test message deserializes via Any::from_msg");
            let result = ResultAny::Response(any.into());
            let op = OperationAny::default().set_done(true).set_result(result);
            let op = EmptyResponseOperation::new(op);
            Ok::<EmptyResponseOperation, Error>(op)
        };
        let mut poller = new_unit_response_poller(
            Arc::new(AlwaysContinue),
            Arc::new(ExponentialBackoff::default()),
            start,
            query,
        );
        let p0 = poller.poll().await;
        match p0.unwrap() {
            PollingResult::InProgress(m) => {
                assert_eq!(m, Some(Timestamp::clamp(123, 0)));
            }
            r => {
                panic!("{r:?}");
            }
        }
        let p1 = poller.poll().await;
        match p1.unwrap() {
            PollingResult::Completed(Ok(_)) => {}
            PollingResult::Completed(Err(e)) => {
                panic!("{e}");
            }
            r => {
                panic!("{r:?}");
            }
        }
        let p2 = poller.poll().await;
        assert!(p2.is_none(), "{p2:?}");
    }

    #[cfg(feature = "unstable-stream")]
    #[tokio::test(flavor = "multi_thread")]
    async fn unit_poll_basic_stream() {
        let start = || async move {
            let any = Any::from_msg(&Timestamp::clamp(123, 0))
                .expect("test message deserializes via Any::from_msg");
            let op = OperationAny::default()
                .set_name("test-only-name")
                .set_metadata(any);
            let op = EmptyResponseOperation::new(op);
            Ok::<EmptyResponseOperation, Error>(op)
        };
        let query = |_: String| async move {
            let any = Any::from_msg(&Empty::default())
                .expect("test message deserializes via Any::from_msg");
            let result = ResultAny::Response(any.into());
            let op = OperationAny::default().set_done(true).set_result(result);
            let op = EmptyResponseOperation::new(op);
            Ok::<EmptyResponseOperation, Error>(op)
        };
        use futures::StreamExt;
        let mut stream = new_unit_response_poller(
            Arc::new(AlwaysContinue),
            Arc::new(ExponentialBackoff::default()),
            start,
            query,
        )
        .into_stream();
        let p0 = stream.next().await;
        match p0.unwrap() {
            PollingResult::InProgress(m) => {
                assert_eq!(m, Some(Timestamp::clamp(123, 0)));
            }
            r => {
                panic!("{r:?}");
            }
        }
        let p1 = stream.next().await;
        match p1.unwrap() {
            PollingResult::Completed(Ok(_)) => {}
            PollingResult::Completed(Err(e)) => {
                panic!("{e}");
            }
            r => {
                panic!("{r:?}");
            }
        }
        let p2 = stream.next().await;
        assert!(p2.is_none(), "{p2:?}");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn unit_until_done_basic_flow() -> Result<()> {
        let start = || async move {
            let any = Any::from_msg(&Timestamp::clamp(123, 0))
                .expect("test message deserializes via Any::from_msg");
            let op = OperationAny::default()
                .set_name("test-only-name")
                .set_metadata(any);
            let op = EmptyResponseOperation::new(op);
            Ok::<EmptyResponseOperation, Error>(op)
        };
        let query = |_: String| async move {
            let any = Any::from_msg(&Empty::default())
                .expect("test message deserializes via Any::from_msg");
            let result = ResultAny::Response(any.into());
            let op = OperationAny::default().set_done(true).set_result(result);
            let op = EmptyResponseOperation::new(op);
            Ok::<EmptyResponseOperation, Error>(op)
        };
        let poller = new_unit_response_poller(
            Arc::new(AlwaysContinue),
            Arc::new(
                ExponentialBackoffBuilder::new()
                    .with_initial_delay(StdDuration::from_millis(1))
                    .clamp(),
            ),
            start,
            query,
        );
        poller.until_done().await?;
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn unit_metadata_poll_basic_flow() {
        let start = || async move {
            let any = Any::from_msg(&Empty::default())
                .expect("test message deserializes via Any::from_msg");
            let op = OperationAny::default()
                .set_name("test-only-name")
                .set_metadata(any);
            let op = EmptyMetadataOperation::new(op);
            Ok::<EmptyMetadataOperation, Error>(op)
        };
        let query = |_: String| async move {
            let any = Any::from_msg(&Duration::clamp(123, 456))
                .expect("test message deserializes via Any::from_msg");
            let result = ResultAny::Response(any.into());
            let op = OperationAny::default().set_done(true).set_result(result);
            let op = EmptyMetadataOperation::new(op);
            Ok::<EmptyMetadataOperation, Error>(op)
        };
        let mut poller = new_unit_metadata_poller(
            Arc::new(AlwaysContinue),
            Arc::new(ExponentialBackoff::default()),
            start,
            query,
        );
        let p0 = poller.poll().await;
        match p0.unwrap() {
            PollingResult::InProgress(m) => {
                assert_eq!(m, Some(()));
            }
            r => {
                panic!("{r:?}");
            }
        }
        let p1 = poller.poll().await;
        match p1.unwrap() {
            PollingResult::Completed(Ok(_)) => {}
            PollingResult::Completed(Err(e)) => {
                panic!("{e}");
            }
            r => {
                panic!("{r:?}");
            }
        }
        let p2 = poller.poll().await;
        assert!(p2.is_none(), "{p2:?}");
    }

    #[cfg(feature = "unstable-stream")]
    #[tokio::test(flavor = "multi_thread")]
    async fn unit_metadata_poll_basic_stream() {
        let start = || async move {
            let any = Any::from_msg(&Empty::default())
                .expect("test message deserializes via Any::from_msg");
            let op = OperationAny::default()
                .set_name("test-only-name")
                .set_metadata(any);
            let op = EmptyMetadataOperation::new(op);
            Ok::<EmptyMetadataOperation, Error>(op)
        };
        let query = |_: String| async move {
            let any = Any::from_msg(&Duration::clamp(123, 456))
                .expect("test message deserializes via Any::from_msg");
            let result = ResultAny::Response(any.into());
            let op = OperationAny::default().set_done(true).set_result(result);
            let op = EmptyMetadataOperation::new(op);
            Ok::<EmptyMetadataOperation, Error>(op)
        };
        use futures::StreamExt;
        let mut stream = new_unit_metadata_poller(
            Arc::new(AlwaysContinue),
            Arc::new(ExponentialBackoff::default()),
            start,
            query,
        )
        .into_stream();
        let p0 = stream.next().await;
        match p0.unwrap() {
            PollingResult::InProgress(m) => {
                assert_eq!(m, Some(()));
            }
            r => {
                panic!("{r:?}");
            }
        }
        let p1 = stream.next().await;
        match p1.unwrap() {
            PollingResult::Completed(Ok(d)) => {
                assert_eq!(d, Duration::clamp(123, 456));
            }
            PollingResult::Completed(Err(e)) => {
                panic!("{e}");
            }
            r => {
                panic!("{r:?}");
            }
        }
        let p2 = stream.next().await;
        assert!(p2.is_none(), "{p2:?}");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn unit_metadata_until_done_basic_flow() -> Result<()> {
        let start = || async move {
            let any = Any::from_msg(&Empty::default())
                .expect("test message deserializes via Any::from_msg");
            let op = OperationAny::default()
                .set_name("test-only-name")
                .set_metadata(any);
            let op = EmptyMetadataOperation::new(op);
            Ok::<EmptyMetadataOperation, Error>(op)
        };
        let query = |_: String| async move {
            let any = Any::from_msg(&Duration::clamp(123, 456))
                .expect("test message deserializes via Any::from_msg");
            let result = ResultAny::Response(any.into());
            let op = OperationAny::default().set_done(true).set_result(result);
            let op = EmptyMetadataOperation::new(op);
            Ok::<EmptyMetadataOperation, Error>(op)
        };
        let poller = new_unit_metadata_poller(
            Arc::new(AlwaysContinue),
            Arc::new(
                ExponentialBackoffBuilder::new()
                    .with_initial_delay(StdDuration::from_millis(1))
                    .clamp(),
            ),
            start,
            query,
        );
        let d = poller.until_done().await?;
        assert_eq!(d, Duration::clamp(123, 456));
        Ok(())
    }

    #[test]
    fn unit_result_map() {
        use PollingResult::{Completed, InProgress, PollingError};
        type TestResult = PollingResult<Empty, Timestamp>;
        let got = map_polling_result(TestResult::Completed(Ok(Empty::default())));
        assert!(matches!(got, Completed(Ok(_))), "{got:?}");
        let got = map_polling_result(TestResult::Completed(Err(service_error())));
        assert!(
            matches!(&got, Completed(Err(e)) if e.status() == service_error().status()),
            "{got:?}"
        );
        let got = map_polling_result(TestResult::InProgress(None));
        assert!(matches!(got, InProgress(None)), "{got:?}");
        let got = map_polling_result(TestResult::InProgress(Some(Timestamp::clamp(123, 456))));
        // Match by reference so `got` is still available for the failure message.
        assert!(
            matches!(&got, InProgress(Some(t)) if *t == Timestamp::clamp(123, 456)),
            "{got:?}"
        );
        let got = map_polling_result(TestResult::PollingError(polling_error()));
        assert!(matches!(&got, PollingError(e) if e.is_io()), "{got:?}");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn unit_both_until_done_basic_flow() -> Result<()> {
        type EmptyOperation = Operation<Empty, Empty>;
        let start = || async move {
            let any = Any::from_msg(&Empty::default())
                .expect("test message deserializes via Any::from_msg");
            let op = OperationAny::default()
                .set_name("test-only-name")
                .set_metadata(any);
            let op = EmptyOperation::new(op);
            Ok::<EmptyOperation, Error>(op)
        };
        let query = |_: String| async move {
            let any = Any::from_msg(&Empty::default())
                .expect("test message deserializes via Any::from_msg");
            let result = ResultAny::Response(any.into());
            let op = OperationAny::default().set_done(true).set_result(result);
            let op = EmptyOperation::new(op);
            Ok::<EmptyOperation, Error>(op)
        };
        let poller = new_unit_poller(
            Arc::new(AlwaysContinue),
            Arc::new(
                ExponentialBackoffBuilder::new()
                    .with_initial_delay(StdDuration::from_millis(1))
                    .clamp(),
            ),
            start,
            query,
        );
        poller.until_done().await?;
        Ok(())
    }

    #[test]
    fn unit_metadata_map() {
        use PollingResult::{Completed, InProgress, PollingError};
        type TestResult = PollingResult<Duration, Empty>;
        let got = map_polling_metadata(TestResult::Completed(Ok(Duration::clamp(123, 456))));
        assert!(matches!(got, Completed(Ok(_))), "{got:?}");
        let got = map_polling_metadata(TestResult::Completed(Err(service_error())));
        assert!(
            matches!(&got, Completed(Err(e)) if e.status() == service_error().status()),
            "{got:?}"
        );
        let got = map_polling_metadata(TestResult::InProgress(None));
        assert!(matches!(got, InProgress(None)), "{got:?}");
        let got = map_polling_metadata(TestResult::InProgress(Some(Empty::default())));
        assert!(matches!(got, InProgress(Some(_))), "{got:?}");
        let got = map_polling_metadata(TestResult::PollingError(polling_error()));
        assert!(matches!(&got, PollingError(e) if e.is_io()), "{got:?}");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn until_done_with_recoverable_polling_error() -> Result<()> {
        let start = || async move {
            let any = Any::from_msg(&Timestamp::clamp(123, 0))
                .expect("test message deserializes via Any::from_msg");
            let op = OperationAny::default()
                .set_name("test-only-name")
                .set_metadata(any);
            let op = TestOperation::new(op);
            Ok::<TestOperation, Error>(op)
        };
        // The first query fails with an I/O error; later queries succeed.
        let count = Arc::new(std::sync::Mutex::new(0_u32));
        let query = move |_: String| {
            let mut guard = count.lock().unwrap();
            let c = *guard;
            *guard = c + 1;
            drop(guard);
            async move {
                if c == 0 {
                    return Err::<TestOperation, Error>(polling_error());
                }
                let any = Any::from_msg(&Duration::clamp(234, 0))
                    .expect("test message deserializes via Any::from_msg");
                let result = ResultAny::Response(any.into());
                let op = OperationAny::default().set_done(true).set_result(result);
                let op = TestOperation::new(op);
                Ok::<TestOperation, Error>(op)
            }
        };
        let poller = PollerImpl::new(
            Arc::new(AlwaysContinue),
            Arc::new(
                ExponentialBackoffBuilder::new()
                    .with_initial_delay(StdDuration::from_millis(1))
                    .clamp(),
            ),
            start,
            query,
        );
        let response = poller.until_done().await?;
        assert_eq!(response, Duration::clamp(234, 0));
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn until_done_with_unrecoverable_polling_error() -> Result<()> {
        let start = || async move {
            let any = Any::from_msg(&Timestamp::clamp(123, 0))
                .expect("test message deserializes via Any::from_msg");
            let op = OperationAny::default()
                .set_name("test-only-name")
                .set_metadata(any);
            let op = TestOperation::new(op);
            Ok::<TestOperation, Error>(op)
        };
        let query = move |_: String| async move { Err::<TestOperation, Error>(unrecoverable()) };
        let poller = PollerImpl::new(
            Arc::new(Aip194Strict),
            Arc::new(
                ExponentialBackoffBuilder::new()
                    .with_initial_delay(StdDuration::from_millis(1))
                    .clamp(),
            ),
            start,
            query,
        );
        let response = poller.until_done().await;
        assert!(response.is_err(), "{response:?}");
        assert!(
            format!("{response:?}").contains("unrecoverable"),
            "{response:?}"
        );
        Ok(())
    }

    /// A service error with code `ResourceExhausted`.
    fn service_error() -> Error {
        Error::service(
            Status::default()
                .set_code(Code::ResourceExhausted)
                .set_message("too many things"),
        )
    }

    /// A service error with code `Aborted`, used with the `Aip194Strict` policy.
    fn unrecoverable() -> Error {
        Error::service(
            Status::default()
                .set_code(Code::Aborted)
                .set_message("unrecoverable"),
        )
    }

    /// An I/O error, used to simulate a failed poll.
    fn polling_error() -> Error {
        Error::io("something failed")
    }
}