1use crate::{
27 Poller, PollingBackoffPolicy, PollingErrorPolicy, PollingResult, Result,
28 sealed::Poller as SealedPoller,
29};
30use google_cloud_gax::error::rpc::Status;
31use google_cloud_gax::polling_state::PollingState;
32use google_cloud_gax::retry_result::RetryResult;
33use std::sync::Arc;
34
35#[cfg(google_cloud_unstable_tracing)]
36use super::LroRecorder;
37
38pub trait DiscoveryOperation {
51 fn done(&self) -> bool;
53
54 fn name(&self) -> Option<&String>;
58
59 fn error(&self) -> Option<Status> {
61 None
62 }
63}
64
65pub fn new_discovery_poller<S, SF, Q, QF, O>(
66 polling_error_policy: Arc<dyn PollingErrorPolicy>,
67 polling_backoff_policy: Arc<dyn PollingBackoffPolicy>,
68 start: S,
69 query: Q,
70) -> impl Poller<O, O>
71where
72 O: DiscoveryOperation + Send,
73 S: FnOnce() -> SF + Send + Sync,
74 SF: std::future::Future<Output = Result<O>> + Send + 'static,
75 Q: FnMut(String) -> QF + Send + Sync + Clone,
76 QF: std::future::Future<Output = Result<O>> + Send + 'static,
77{
78 DiscoveryPoller::new(polling_error_policy, polling_backoff_policy, start, query)
79}
80
81struct DiscoveryPoller<S, Q> {
82 error_policy: Arc<dyn PollingErrorPolicy>,
83 backoff_policy: Arc<dyn PollingBackoffPolicy>,
84 start: Option<S>,
85 query: Q,
86 operation: Option<String>,
87 state: PollingState,
88}
89
90impl<S, Q> DiscoveryPoller<S, Q> {
91 pub fn new(
92 error_policy: Arc<dyn PollingErrorPolicy>,
93 backoff_policy: Arc<dyn PollingBackoffPolicy>,
94 start: S,
95 query: Q,
96 ) -> Self {
97 Self {
98 error_policy,
99 backoff_policy,
100 start: Some(start),
101 query,
102 operation: None,
103 state: PollingState::default(),
104 }
105 }
106}
107
108impl<S, Q> SealedPoller for DiscoveryPoller<S, Q>
109where
110 S: Send,
111 Q: Send,
112{
113 async fn backoff(&mut self, state: &PollingState) {
114 let backoff = self.backoff_policy.wait_period(state);
115 tokio::time::sleep(backoff).await;
116 }
117}
118
119impl<O, S, SF, Q, QF> crate::Poller<O, O> for DiscoveryPoller<S, Q>
120where
121 O: DiscoveryOperation + Send,
122 S: FnOnce() -> SF + Send + Sync,
123 SF: std::future::Future<Output = Result<O>> + Send + 'static,
124 Q: FnMut(String) -> QF + Send + Sync + Clone,
125 QF: std::future::Future<Output = Result<O>> + Send + 'static,
126{
127 async fn poll(&mut self) -> Option<PollingResult<O, O>> {
128 if let Some(start) = self.start.take() {
129 let result = start().await;
130 #[cfg(google_cloud_unstable_tracing)]
131 if let Ok(ref op) = result {
132 let name = op.name();
133 if let (Some(name), Some(recorder)) = (name, LroRecorder::current()) {
134 recorder.record_destination_id(name);
135 }
136 }
137 let (op, poll) = self::handle_start(result);
138 self.operation = op;
139 return Some(poll);
140 }
141 if let Some(name) = self.operation.take() {
142 self.state.attempt_count += 1;
143 let result = (self.query)(name.clone()).await;
144 let (op, poll) =
145 self::handle_poll(self.error_policy.clone(), &self.state, name, result);
146 #[cfg(google_cloud_unstable_tracing)]
147 if let (Some(next_name), Some(recorder)) = (&op, LroRecorder::current()) {
148 recorder.record_destination_id(next_name);
149 }
150 self.operation = op;
151 return Some(poll);
152 }
153 None
154 }
155 async fn until_done(self) -> Result<O> {
156 crate::until_done(self).await
157 }
158
159 #[cfg(feature = "unstable-stream")]
160 fn into_stream(self) -> impl futures::Stream<Item = PollingResult<O, O>> + Unpin {
161 crate::into_stream(self)
162 }
163}
164
165fn handle_start<O>(result: Result<O>) -> (Option<String>, PollingResult<O, O>)
166where
167 O: DiscoveryOperation,
168{
169 match result {
170 Err(ref _e) => (None, PollingResult::Completed(result)),
171 Ok(o) if o.done() => (None, PollingResult::Completed(Ok(o))),
172 Ok(o) => handle_polling_success(o),
173 }
174}
175
176fn handle_poll<O>(
177 error_policy: Arc<dyn PollingErrorPolicy>,
178 state: &PollingState,
179 operation_name: String,
180 result: Result<O>,
181) -> (Option<String>, PollingResult<O, O>)
182where
183 O: DiscoveryOperation,
184{
185 match result {
186 Err(e) => {
187 let state = error_policy.on_error(state, e);
188 handle_polling_error(state, operation_name)
189 }
190 Ok(o) if o.done() => (None, PollingResult::Completed(Ok(o))),
191 Ok(o) => handle_polling_success(o),
192 }
193}
194
195fn handle_polling_error<O>(
196 state: RetryResult,
197 operation_name: String,
198) -> (Option<String>, PollingResult<O, O>)
199where
200 O: DiscoveryOperation,
201{
202 match state {
203 RetryResult::Continue(e) => (Some(operation_name), PollingResult::PollingError(e)),
204 RetryResult::Exhausted(e) | RetryResult::Permanent(e) => {
205 (None, PollingResult::Completed(Err(e)))
206 }
207 }
208}
209
210fn handle_polling_success<O>(o: O) -> (Option<String>, PollingResult<O, O>)
211where
212 O: DiscoveryOperation,
213{
214 (o.name().cloned(), PollingResult::InProgress(Some(o)))
215}
216
217#[cfg(test)]
218mod tests {
219 use super::*;
220 use crate::Error;
221 use google_cloud_gax::error::rpc::{Code, Status};
222 use google_cloud_gax::exponential_backoff::{ExponentialBackoff, ExponentialBackoffBuilder};
223 use google_cloud_gax::polling_error_policy::{Aip194Strict, AlwaysContinue};
224 use std::time::Duration;
225
226 #[cfg(not(google_cloud_unstable_tracing))]
227 pub(crate) struct DummySpan;
228
229 #[cfg(not(google_cloud_unstable_tracing))]
230 fn test_span() -> DummySpan {
231 DummySpan
232 }
233
234 #[cfg(not(google_cloud_unstable_tracing))]
235 pub(crate) trait Instrument: Sized {
236 fn instrument(self, _span: DummySpan) -> Self {
237 self
238 }
239 }
240
241 #[cfg(not(google_cloud_unstable_tracing))]
242 impl<T> Instrument for T {}
243
244 #[cfg(google_cloud_unstable_tracing)]
245 use tracing::Instrument;
246
247 #[cfg(google_cloud_unstable_tracing)]
248 fn test_span() -> tracing::Span {
249 tracing::info_span!(
250 "test_span",
251 gcp.resource.destination.id = tracing::field::Empty,
252 )
253 }
254
255 #[tokio::test]
256 async fn poller_until_done_success() {
257 let start = || async move {
258 let op = TestOperation {
259 name: Some("start-name".into()),
260 ..TestOperation::default()
261 };
262 Ok(op)
263 };
264 let query = |_name| async move {
265 let op = TestOperation {
266 done: true,
267 value: Some(42),
268 ..TestOperation::default()
269 };
270 Ok(op)
271 };
272 let got = new_discovery_poller(
273 Arc::new(AlwaysContinue),
274 Arc::new(test_backoff()),
275 start,
276 query,
277 )
278 .until_done()
279 .instrument(test_span())
280 .await;
281 assert!(
282 matches!(
283 got,
284 Ok(TestOperation {
285 value: Some(42),
286 ..
287 })
288 ),
289 "{got:?}"
290 );
291 }
292
293 #[tokio::test]
294 async fn poller_until_done_success_with_transient() {
295 let start = || async move {
296 let op = TestOperation {
297 name: Some("start-name".into()),
298 ..TestOperation::default()
299 };
300 Ok(op)
301 };
302 let mut query_count = 0;
303 let query = move |_name| {
304 query_count += 1;
305 let count = query_count;
306 async move {
307 match count {
308 1 => Err(transient()),
309 _ => {
310 let op = TestOperation {
311 done: true,
312 value: Some(42),
313 ..TestOperation::default()
314 };
315 Ok(op)
316 }
317 }
318 }
319 };
320 let got = new_discovery_poller(
321 Arc::new(AlwaysContinue),
322 Arc::new(test_backoff()),
323 start,
324 query,
325 )
326 .until_done()
327 .instrument(test_span())
328 .await;
329 assert!(
330 matches!(
331 got,
332 Ok(TestOperation {
333 value: Some(42),
334 ..
335 })
336 ),
337 "{got:?}"
338 );
339 }
340
341 #[tokio::test]
342 async fn poller_until_done_error_on_start() {
343 let start = || async move { Err(Error::service(permanent_status())) };
344 let query = async |_name| -> Result<TestOperation> {
345 panic!();
346 };
347 let got = new_discovery_poller(
348 Arc::new(AlwaysContinue),
349 Arc::new(test_backoff()),
350 start,
351 query,
352 )
353 .until_done()
354 .await;
355 assert!(
356 matches!(
357 got,
358 Err(ref e) if e.status() == Some(&permanent_status())
359 ),
360 "{got:?}"
361 );
362 }
363
364 #[tokio::test]
365 async fn poller_into_stream() {
366 use futures::StreamExt;
367 let start = || async move {
368 let op = TestOperation {
369 name: Some("start-name".into()),
370 ..TestOperation::default()
371 };
372 Ok(op)
373 };
374 let query = |_name| async move {
375 let op = TestOperation {
376 done: true,
377 value: Some(42),
378 ..TestOperation::default()
379 };
380 Ok(op)
381 };
382 let mut stream = new_discovery_poller(
383 Arc::new(AlwaysContinue),
384 Arc::new(test_backoff()),
385 start,
386 query,
387 )
388 .into_stream();
389 let got = stream.next().await;
391 assert!(
392 matches!(got, Some(PollingResult::InProgress(Some(_)))),
393 "{got:?}"
394 );
395 let got = stream.next().await;
396 assert!(
397 matches!(
398 got,
399 Some(PollingResult::Completed(Ok(TestOperation {
400 value: Some(42),
401 ..
402 })))
403 ),
404 "{got:?}"
405 );
406 let got = stream.next().await;
407 assert!(got.is_none(), "{got:?}");
408 }
409
410 #[test]
411 fn start_error() {
412 let got = handle_start::<TestOperation>(Err(transient()));
413 assert!(got.0.is_none(), "{got:?}");
414 assert!(
415 matches!(&got.1, PollingResult::Completed(Err(_))),
416 "{got:?}"
417 );
418 }
419
420 #[test]
421 fn start_done() {
422 let input = TestOperation {
423 done: true,
424 ..TestOperation::default()
425 };
426 let got = handle_start(Ok(input));
427 assert!(got.0.is_none(), "{got:?}");
428 assert!(matches!(&got.1, PollingResult::Completed(Ok(_))), "{got:?}");
429 }
430
431 #[test]
432 fn start_in_progress() {
433 let input = TestOperation {
434 done: false,
435 name: Some("in-progress".to_string()),
436 ..TestOperation::default()
437 };
438 let got = handle_start(Ok(input));
439 assert_eq!(got.0.as_deref(), Some("in-progress"), "{got:?}");
440 assert!(
441 matches!(&got.1, PollingResult::InProgress(Some(_))),
442 "{got:?}"
443 );
444 }
445
446 #[test]
447 fn poll_error() {
448 let policy = Aip194Strict;
449 let state = PollingState::default();
450 let got = handle_poll::<TestOperation>(
451 Arc::new(policy),
452 &state,
453 "started".to_string(),
454 Err(transient()),
455 );
456 assert_eq!(got.0.as_deref(), Some("started"), "{got:?}");
457 assert!(matches!(got.1, PollingResult::PollingError(_)), "{got:?}");
458 }
459
460 #[test]
461 fn poll_done_success() {
462 let policy = Aip194Strict;
463 let state = PollingState::default();
464 let input = TestOperation {
465 done: true,
466 name: Some("in-progress".into()),
467 ..TestOperation::default()
468 };
469 let got = handle_poll(Arc::new(policy), &state, "started".to_string(), Ok(input));
470 assert!(got.0.is_none(), "{got:?}");
471 assert!(matches!(got.1, PollingResult::Completed(Ok(_))), "{got:?}");
472 }
473
474 #[test]
475 fn poll_in_progress() {
476 let policy = Aip194Strict;
477 let state = PollingState::default();
478 let input = TestOperation {
479 done: false,
480 name: Some("in-progress".into()),
481 ..TestOperation::default()
482 };
483 let got = handle_poll(Arc::new(policy), &state, "started".to_string(), Ok(input));
484 assert_eq!(got.0.as_deref(), Some("in-progress"), "{got:?}");
485 assert!(matches!(got.1, PollingResult::InProgress(_)), "{got:?}");
486 }
487
488 #[test]
489 fn polling_error() {
490 let got = handle_polling_error::<TestOperation>(
491 RetryResult::Continue(transient()),
492 "name-for-continue".to_string(),
493 );
494 assert_eq!(got.0.as_deref(), Some("name-for-continue"), "{got:?}");
495 assert!(
496 matches!(got.1, PollingResult::PollingError(ref e) if is_transient(e)),
497 "{got:?}"
498 );
499
500 let got = handle_polling_error::<TestOperation>(
501 RetryResult::Exhausted(transient()),
502 "name-for-exhausted".to_string(),
503 );
504 assert!(got.0.is_none(), "{got:?}");
505 assert!(
506 matches!(got.1, PollingResult::Completed(Err(ref e)) if is_transient(e)),
507 "{got:?}"
508 );
509
510 let got = handle_polling_error::<TestOperation>(
511 RetryResult::Permanent(transient()),
512 "name-for-permanent".to_string(),
513 );
514 assert!(got.0.is_none(), "{got:?}");
515 assert!(
516 matches!(got.1, PollingResult::Completed(Err(ref e)) if is_transient(e)),
517 "{got:?}"
518 );
519 }
520
521 #[test]
522 fn polling_success() {
523 let input = TestOperation {
524 name: Some("in-progress".to_string()),
525 ..TestOperation::default()
526 };
527 let got = handle_polling_success(input);
528 assert_eq!(got.0.as_deref(), Some("in-progress"), "{got:?}");
529 assert!(
530 matches!(&got.1, PollingResult::InProgress(Some(_))),
531 "{got:?}"
532 );
533 }
534
535 fn is_transient(error: &Error) -> bool {
536 error.status().is_some_and(|s| s == &transient_status())
537 }
538
539 fn transient() -> Error {
540 Error::service(transient_status())
541 }
542
543 fn transient_status() -> Status {
544 Status::default()
545 .set_code(Code::Unavailable)
546 .set_message("try-again")
547 }
548
549 fn permanent_status() -> Status {
550 Status::default()
551 .set_code(Code::PermissionDenied)
552 .set_message("uh-oh")
553 }
554
555 fn test_backoff() -> ExponentialBackoff {
556 ExponentialBackoffBuilder::new()
557 .with_initial_delay(Duration::from_millis(1))
558 .with_maximum_delay(Duration::from_millis(1))
559 .build()
560 .expect("hard-coded values should succeed")
561 }
562
563 #[derive(Debug, Default, PartialEq)]
564 struct TestOperation {
565 done: bool,
566 name: Option<String>,
567 value: Option<i32>,
568 }
569
570 impl DiscoveryOperation for TestOperation {
571 fn done(&self) -> bool {
572 self.done
573 }
574 fn name(&self) -> Option<&String> {
575 self.name.as_ref()
576 }
577 }
578
579 #[cfg(google_cloud_unstable_tracing)]
580 #[tokio::test]
581 async fn test_discovery_poller_tracing() {
582 let guard = google_cloud_test_utils::test_layer::TestLayer::initialize();
583
584 let start = || async move {
585 let op = TestOperation {
586 name: Some("discovery-operation-123".into()),
587 ..TestOperation::default()
588 };
589 Ok(op)
590 };
591
592 let count = Arc::new(std::sync::Mutex::new(0));
593 let query_count = count.clone();
594 let query = move |_: String| {
595 let mut c = query_count.lock().unwrap();
596 *c += 1;
597 let is_done = *c > 1;
598 async move {
599 if is_done {
600 let op = TestOperation {
601 done: true,
602 value: Some(42),
603 ..TestOperation::default()
604 };
605 Ok(op)
606 } else {
607 let op = TestOperation {
608 name: Some("discovery-operation-123".into()),
609 ..TestOperation::default()
610 };
611 Ok(op)
612 }
613 }
614 };
615
616 let mut poller = DiscoveryPoller::new(
617 Arc::new(AlwaysContinue),
618 Arc::new(test_backoff()),
619 start,
620 query,
621 );
622
623 let span = test_span();
624 let poller_ref = &mut poller;
625 let recorder = crate::internal::LroRecorder::new(span.clone());
626 let _ = recorder
627 .scope(async move { poller_ref.poll().instrument(span).await })
628 .await;
629
630 {
631 let captured = google_cloud_test_utils::test_layer::TestLayer::capture(&guard);
632 let got = captured
633 .iter()
634 .find(|s| s.name == "test_span")
635 .unwrap_or_else(|| panic!("missing `test_span` in captured spans: {captured:?}"));
636 assert_eq!(
637 got.attributes
638 .get("gcp.resource.destination.id")
639 .and_then(|v| v.as_string()),
640 Some("discovery-operation-123".to_string())
641 );
642 }
643
644 let span = test_span();
645 let poller_ref2 = &mut poller;
646 let recorder2 = crate::internal::LroRecorder::new(span.clone());
647 let _ = recorder2
648 .scope(async move { poller_ref2.poll().instrument(span).await })
649 .await;
650
651 {
652 let captured = google_cloud_test_utils::test_layer::TestLayer::capture(&guard);
653 let got = captured
654 .iter()
655 .find(|s| s.name == "test_span")
656 .unwrap_or_else(|| panic!("missing `test_span` in captured spans: {captured:?}"));
657 assert_eq!(
658 got.attributes
659 .get("gcp.resource.destination.id")
660 .and_then(|v| v.as_string()),
661 Some("discovery-operation-123".to_string())
662 );
663 }
664 }
665}