1use gax::error::Error;
18use gax::polling_backoff_policy::PollingBackoffPolicy;
19use gax::polling_policy::PollingPolicy;
20use gax::Result;
21use std::future::Future;
22use std::marker::PhantomData;
23use std::sync::Arc;
24use std::time::Instant;
25
/// The result of a single polling attempt on a long-running operation (LRO).
///
/// `R` is the type of the operation's final response and `M` is the type of
/// its progress metadata.
#[derive(Debug)]
pub enum PollingResult<R, M> {
    /// The operation is still in progress.
    ///
    /// Contains the operation metadata, if the operation reported any.
    InProgress(Option<M>),
    /// The operation finished; contains either its final response or the
    /// error that ended it.
    Completed(Result<R>),
    /// An error occurred while trying to poll the operation.
    ///
    /// `until_done()` in this module treats this variant as non-terminal: it
    /// waits for the backoff period and polls again.
    /// NOTE(review): whether a given polling error is reported here or as
    /// `Completed(Err(..))` is decided by the polling policy (see the
    /// `details` module) — confirm there.
    PollingError(Error),
}
55
56#[doc(hidden)]
61pub struct Operation<R, M> {
62 inner: longrunning::model::Operation,
63 response: std::marker::PhantomData<R>,
64 metadata: std::marker::PhantomData<M>,
65}
66
67impl<R, M> Operation<R, M> {
68 pub fn new(inner: longrunning::model::Operation) -> Self {
69 Self {
70 inner,
71 response: PhantomData,
72 metadata: PhantomData,
73 }
74 }
75
76 fn name(&self) -> String {
77 self.inner.name.clone()
78 }
79 fn done(&self) -> bool {
80 self.inner.done
81 }
82 fn metadata(&self) -> Option<&wkt::Any> {
83 self.inner.metadata.as_ref()
84 }
85 fn response(&self) -> Option<&wkt::Any> {
86 use longrunning::model::operation::Result;
87 self.inner.result.as_ref().and_then(|r| match r {
88 Result::Error(_) => None,
89 Result::Response(r) => Some(r.as_ref()),
90 _ => None,
91 })
92 }
93 fn error(&self) -> Option<&rpc::model::Status> {
94 use longrunning::model::operation::Result;
95 self.inner.result.as_ref().and_then(|r| match r {
96 Result::Error(rpc) => Some(rpc.as_ref()),
97 Result::Response(_) => None,
98 _ => None,
99 })
100 }
101}
102
/// Drives a long-running operation (LRO) to completion.
///
/// `R` is the operation's response type and `M` is its metadata type.
pub trait Poller<R, M> {
    /// Performs one polling step and returns its outcome.
    ///
    /// Returns `None` once there is nothing left to poll — for example, after
    /// a [PollingResult::Completed] value has already been returned. (In this
    /// module's implementation the first call starts the operation, and later
    /// calls query it by name.)
    fn poll(&mut self) -> impl Future<Output = Option<PollingResult<R, M>>>;

    /// Polls repeatedly until the operation completes, then returns its final
    /// response or the error that ended it.
    fn until_done(self) -> impl Future<Output = Result<R>>;

    /// Converts the poller into a stream that yields each [PollingResult].
    ///
    /// Only available with the `unstable-stream` feature.
    #[cfg(feature = "unstable-stream")]
    fn to_stream(self) -> impl futures::Stream<Item = PollingResult<R, M>>;
}
121
122#[doc(hidden)]
127pub fn new_poller<ResponseType, MetadataType, S, SF, Q, QF>(
128 polling_policy: Arc<dyn PollingPolicy>,
129 polling_backoff_policy: Arc<dyn PollingBackoffPolicy>,
130 start: S,
131 query: Q,
132) -> impl Poller<ResponseType, MetadataType>
133where
134 ResponseType: wkt::message::Message + serde::de::DeserializeOwned,
135 MetadataType: wkt::message::Message + serde::de::DeserializeOwned,
136 S: FnOnce() -> SF + Send + Sync,
137 SF: std::future::Future<Output = Result<Operation<ResponseType, MetadataType>>>
138 + Send
139 + 'static,
140 Q: Fn(String) -> QF + Send + Sync + Clone,
141 QF: std::future::Future<Output = Result<Operation<ResponseType, MetadataType>>>
142 + Send
143 + 'static,
144{
145 PollerImpl::new(polling_policy, polling_backoff_policy, start, query)
146}
147
148struct PollerImpl<ResponseType, MetadataType, S, SF, Q, QF>
170where
171 S: FnOnce() -> SF + Send + Sync,
172 SF: std::future::Future<Output = Result<Operation<ResponseType, MetadataType>>>
173 + Send
174 + 'static,
175 Q: Fn(String) -> QF + Send + Sync + Clone,
176 QF: std::future::Future<Output = Result<Operation<ResponseType, MetadataType>>>
177 + Send
178 + 'static,
179{
180 polling_policy: Arc<dyn PollingPolicy>,
181 backoff_policy: Arc<dyn PollingBackoffPolicy>,
182 start: Option<S>,
183 query: Q,
184 operation: Option<String>,
185 loop_start: Instant,
186 attempt_count: u32,
187}
188
189impl<ResponseType, MetadataType, S, SF, Q, QF> PollerImpl<ResponseType, MetadataType, S, SF, Q, QF>
190where
191 S: FnOnce() -> SF + Send + Sync,
192 SF: std::future::Future<Output = Result<Operation<ResponseType, MetadataType>>>
193 + Send
194 + 'static,
195 Q: Fn(String) -> QF + Send + Sync + Clone,
196 QF: std::future::Future<Output = Result<Operation<ResponseType, MetadataType>>>
197 + Send
198 + 'static,
199{
200 pub fn new(
201 polling_policy: Arc<dyn PollingPolicy>,
202 backoff_policy: Arc<dyn PollingBackoffPolicy>,
203 start: S,
204 query: Q,
205 ) -> Self {
206 Self {
207 polling_policy,
208 backoff_policy,
209 start: Some(start),
210 query,
211 operation: None,
212 loop_start: Instant::now(),
213 attempt_count: 0,
214 }
215 }
216}
217
218impl<ResponseType, MetadataType, S, SF, P, PF> Poller<ResponseType, MetadataType>
219 for PollerImpl<ResponseType, MetadataType, S, SF, P, PF>
220where
221 ResponseType: wkt::message::Message + serde::de::DeserializeOwned,
222 MetadataType: wkt::message::Message + serde::de::DeserializeOwned,
223 S: FnOnce() -> SF + Send + Sync,
224 SF: std::future::Future<Output = Result<Operation<ResponseType, MetadataType>>>
225 + Send
226 + 'static,
227 P: Fn(String) -> PF + Send + Sync + Clone,
228 PF: std::future::Future<Output = Result<Operation<ResponseType, MetadataType>>>
229 + Send
230 + 'static,
231{
232 async fn poll(&mut self) -> Option<PollingResult<ResponseType, MetadataType>> {
233 if let Some(start) = self.start.take() {
234 let result = start().await;
235 let (op, poll) = details::handle_start(result);
236 self.operation = op;
237 return Some(poll);
238 }
239 if let Some(name) = self.operation.take() {
240 self.attempt_count += 1;
241 let result = (self.query)(name.clone()).await;
242 let (op, poll) = details::handle_poll(
243 self.polling_policy.clone(),
244 self.loop_start,
245 self.attempt_count,
246 name,
247 result,
248 );
249 self.operation = op;
250 return Some(poll);
251 }
252 None
253 }
254
255 async fn until_done(mut self) -> Result<ResponseType> {
256 let loop_start = std::time::Instant::now();
257 let mut attempt_count = 0;
258 while let Some(p) = self.poll().await {
259 match p {
260 PollingResult::Completed(r) => return r,
263 PollingResult::InProgress(_) => (),
266 PollingResult::PollingError(_) => (),
269 }
270 attempt_count += 1;
271 tokio::time::sleep(self.backoff_policy.wait_period(loop_start, attempt_count)).await;
272 }
273 unreachable!("loop should exit via the `Completed` branch vs. this line");
277 }
278
279 #[cfg(feature = "unstable-stream")]
280 fn to_stream(self) -> impl futures::Stream<Item = PollingResult<ResponseType, MetadataType>>
281 where
282 ResponseType: wkt::message::Message + serde::de::DeserializeOwned,
283 MetadataType: wkt::message::Message + serde::de::DeserializeOwned,
284 {
285 use futures::stream::unfold;
286 unfold(Some(self), move |state| async move {
287 if let Some(mut poller) = state {
288 if let Some(pr) = poller.poll().await {
289 return Some((pr, Some(poller)));
290 }
291 };
292 None
293 })
294 }
295}
296
297mod details;
298
#[cfg(test)]
mod test {
    use super::*;
    use gax::exponential_backoff::ExponentialBackoff;
    use gax::exponential_backoff::ExponentialBackoffBuilder;
    use gax::polling_policy::*;
    use std::time::Duration;

    type ResponseType = wkt::Duration;
    type MetadataType = wkt::Timestamp;
    type TestOperation = Operation<ResponseType, MetadataType>;

    /// The operation every `start` closure in these tests returns: named
    /// `test-only-name`, not done, with `Timestamp(123, 0)` as metadata.
    fn started_operation() -> Result<TestOperation> {
        let any = wkt::Any::try_from(&wkt::Timestamp::clamp(123, 0))
            .map_err(|e| Error::other(format!("unexpected error in Any::try_from {e}")))?;
        let op = longrunning::model::Operation::default()
            .set_name("test-only-name")
            .set_metadata(any);
        Ok(TestOperation::new(op))
    }

    /// A done operation whose response is `Duration(234, 0)`.
    fn completed_operation() -> Result<TestOperation> {
        let any = wkt::Any::try_from(&wkt::Duration::clamp(234, 0))
            .map_err(|e| Error::other(format!("unexpected error in Any::try_from {e}")))?;
        let result = longrunning::model::operation::Result::Response(any.into());
        let op = longrunning::model::Operation::default()
            .set_done(true)
            .set_result(result);
        Ok(TestOperation::new(op))
    }

    #[test]
    fn typed_operation_with_metadata() -> Result<()> {
        let op = started_operation()?;
        assert_eq!(op.name(), "test-only-name");
        assert!(!op.done());
        assert!(op.metadata().is_some());
        assert!(op.response().is_none());
        assert!(op.error().is_none());
        let got = op
            .metadata()
            .unwrap()
            .try_into_message::<wkt::Timestamp>()
            .map_err(Error::other)?;
        assert_eq!(got, wkt::Timestamp::clamp(123, 0));

        Ok(())
    }

    #[test]
    fn typed_operation_with_response() -> Result<()> {
        let any = wkt::Any::try_from(&wkt::Duration::clamp(23, 0))
            .map_err(|e| Error::other(format!("unexpected error in Any::try_from {e}")))?;
        let op = longrunning::model::Operation::default()
            .set_name("test-only-name")
            .set_result(longrunning::model::operation::Result::Response(any.into()));
        let op = TestOperation::new(op);
        assert_eq!(op.name(), "test-only-name");
        assert!(!op.done());
        assert!(op.metadata().is_none());
        assert!(op.response().is_some());
        assert!(op.error().is_none());
        let got = op
            .response()
            .unwrap()
            .try_into_message::<wkt::Duration>()
            .map_err(Error::other)?;
        assert_eq!(got, wkt::Duration::clamp(23, 0));

        Ok(())
    }

    #[test]
    fn typed_operation_with_error() -> Result<()> {
        let rpc = rpc::model::Status::default()
            .set_message("test only")
            .set_code(16);
        let op = longrunning::model::Operation::default()
            .set_name("test-only-name")
            .set_result(longrunning::model::operation::Result::Error(
                rpc.clone().into(),
            ));
        let op = TestOperation::new(op);
        assert_eq!(op.name(), "test-only-name");
        assert!(!op.done());
        assert!(op.metadata().is_none());
        assert!(op.response().is_none());
        assert!(op.error().is_some());
        let got = op.error().unwrap();
        assert_eq!(got, &rpc);

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn poll_basic_flow() {
        let start = || async move { started_operation() };
        let query = |_: String| async move { completed_operation() };

        let mut poller = PollerImpl::new(
            Arc::new(AlwaysContinue),
            Arc::new(ExponentialBackoff::default()),
            start,
            query,
        );
        let p0 = poller.poll().await;
        match p0.unwrap() {
            PollingResult::InProgress(m) => {
                assert_eq!(m, Some(wkt::Timestamp::clamp(123, 0)));
            }
            r => panic!("expected InProgress, got {r:?}"),
        }

        let p1 = poller.poll().await;
        match p1.unwrap() {
            PollingResult::Completed(r) => {
                let response = r.unwrap();
                assert_eq!(response, wkt::Duration::clamp(234, 0));
            }
            r => panic!("expected Completed, got {r:?}"),
        }

        let p2 = poller.poll().await;
        assert!(p2.is_none(), "{p2:?}");
    }

    // `to_stream()` only exists with the `unstable-stream` feature, so this
    // test must be gated the same way or the test build fails without it.
    #[cfg(feature = "unstable-stream")]
    #[tokio::test(flavor = "multi_thread")]
    async fn poll_basic_stream() {
        use futures::StreamExt;

        let start = || async move { started_operation() };
        let query = |_: String| async move { completed_operation() };

        let stream = new_poller(
            Arc::new(AlwaysContinue),
            Arc::new(ExponentialBackoff::default()),
            start,
            query,
        )
        .to_stream();
        let mut stream = std::pin::pin!(stream);
        let p0 = stream.next().await;
        match p0.unwrap() {
            PollingResult::InProgress(m) => {
                assert_eq!(m, Some(wkt::Timestamp::clamp(123, 0)));
            }
            r => panic!("expected InProgress, got {r:?}"),
        }

        let p1 = stream.next().await;
        match p1.unwrap() {
            PollingResult::Completed(r) => {
                let response = r.unwrap();
                assert_eq!(response, wkt::Duration::clamp(234, 0));
            }
            r => panic!("expected Completed, got {r:?}"),
        }

        let p2 = stream.next().await;
        assert!(p2.is_none(), "{p2:?}");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn until_done_basic_flow() -> Result<()> {
        let start = || async move { started_operation() };
        let query = |_: String| async move { completed_operation() };

        let poller = PollerImpl::new(
            Arc::new(AlwaysContinue),
            Arc::new(
                ExponentialBackoffBuilder::new()
                    .with_initial_delay(Duration::from_millis(1))
                    .clamp(),
            ),
            start,
            query,
        );
        let response = poller.until_done().await?;
        assert_eq!(response, wkt::Duration::clamp(234, 0));

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn until_done_with_recoverable_polling_error() -> Result<()> {
        let start = || async move { started_operation() };

        // The first query fails; `AlwaysContinue` treats that as recoverable,
        // so the second query (which succeeds) must be reached.
        let count = Arc::new(std::sync::Mutex::new(0_u32));
        let query = move |_: String| {
            let c = {
                let mut guard = count.lock().unwrap();
                let c = *guard;
                *guard = c + 1;
                c
            };
            async move {
                if c == 0 {
                    Err::<TestOperation, Error>(Error::other("recoverable (see policy below)"))
                } else {
                    completed_operation()
                }
            }
        };

        let poller = PollerImpl::new(
            Arc::new(AlwaysContinue),
            Arc::new(
                ExponentialBackoffBuilder::new()
                    .with_initial_delay(Duration::from_millis(1))
                    .clamp(),
            ),
            start,
            query,
        );
        let response = poller.until_done().await?;
        assert_eq!(response, wkt::Duration::clamp(234, 0));

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn until_done_with_unrecoverable_polling_error() -> Result<()> {
        let start = || async move { started_operation() };
        let query = |_: String| async move {
            Err::<TestOperation, Error>(Error::other("unrecoverable (see policy below)"))
        };

        let poller = PollerImpl::new(
            Arc::new(Aip194Strict),
            Arc::new(
                ExponentialBackoffBuilder::new()
                    .with_initial_delay(Duration::from_millis(1))
                    .clamp(),
            ),
            start,
            query,
        );
        let response = poller.until_done().await;
        assert!(response.is_err());
        assert!(
            format!("{response:?}").contains("unrecoverable"),
            "{response:?}"
        );

        Ok(())
    }
}