Skip to main content

google_cloud_lro/internal/
discovery.rs

1// Copyright 2025 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! This module implements LROs for discovery-based client libraries.
16//!
17//! The discovery-based services use a different (older) form of LROs, where the
18//! "Operation" type does not include the final result, and the errors, if any,
19//! are not represented using the `google.rpc.Status` proto.
20//!
21//! In discovery-based services, the LRO functions return an "Operation" type.
22//! This type is specific to each service, that is, it is not a shared type that
23//! we can name directly.
24//!
25
26use crate::{
27    Poller, PollingBackoffPolicy, PollingErrorPolicy, PollingResult, Result,
28    sealed::Poller as SealedPoller,
29};
30use google_cloud_gax::error::rpc::Status;
31use google_cloud_gax::polling_state::PollingState;
32use google_cloud_gax::retry_result::RetryResult;
33use std::sync::Arc;
34
35#[cfg(google_cloud_unstable_tracing)]
36use super::LroRecorder;
37
38/// Defines the trait for an "Operation" type in the discovery poller.
39///
40/// In discovery-based services each client library defines a different type as
41/// the "Operation" type for long-running operations.
42///
43/// The client libraries must implement the `DiscoveryOperation` trait for this
44/// type. The trait defines how to determine if an operation has completed, if
45/// it completed with an error, and how to extract its name to perform
46/// additional polling requests.
47///
48/// Extracting the error may require hand-crafted code, as it is service
49/// specific and requires substantial coding.
50pub trait DiscoveryOperation {
51    /// Returns true if the operation has completed, with or without an error.
52    fn done(&self) -> bool;
53
54    /// Returns the name of the operation.
55    ///
56    /// It may be `None` in which case the polling loop stops.
57    fn name(&self) -> Option<&String>;
58
59    /// Returns the error status of the operation, if any.
60    fn error(&self) -> Option<Status> {
61        None
62    }
63}
64
65pub fn new_discovery_poller<S, SF, Q, QF, O>(
66    polling_error_policy: Arc<dyn PollingErrorPolicy>,
67    polling_backoff_policy: Arc<dyn PollingBackoffPolicy>,
68    start: S,
69    query: Q,
70) -> impl Poller<O, O>
71where
72    O: DiscoveryOperation + Send,
73    S: FnOnce() -> SF + Send + Sync,
74    SF: std::future::Future<Output = Result<O>> + Send + 'static,
75    Q: FnMut(String) -> QF + Send + Sync + Clone,
76    QF: std::future::Future<Output = Result<O>> + Send + 'static,
77{
78    DiscoveryPoller::new(polling_error_policy, polling_backoff_policy, start, query)
79}
80
81struct DiscoveryPoller<S, Q> {
82    error_policy: Arc<dyn PollingErrorPolicy>,
83    backoff_policy: Arc<dyn PollingBackoffPolicy>,
84    start: Option<S>,
85    query: Q,
86    operation: Option<String>,
87    state: PollingState,
88}
89
90impl<S, Q> DiscoveryPoller<S, Q> {
91    pub fn new(
92        error_policy: Arc<dyn PollingErrorPolicy>,
93        backoff_policy: Arc<dyn PollingBackoffPolicy>,
94        start: S,
95        query: Q,
96    ) -> Self {
97        Self {
98            error_policy,
99            backoff_policy,
100            start: Some(start),
101            query,
102            operation: None,
103            state: PollingState::default(),
104        }
105    }
106}
107
108impl<S, Q> SealedPoller for DiscoveryPoller<S, Q>
109where
110    S: Send,
111    Q: Send,
112{
113    async fn backoff(&mut self, state: &PollingState) {
114        let backoff = self.backoff_policy.wait_period(state);
115        tokio::time::sleep(backoff).await;
116    }
117}
118
119impl<O, S, SF, Q, QF> crate::Poller<O, O> for DiscoveryPoller<S, Q>
120where
121    O: DiscoveryOperation + Send,
122    S: FnOnce() -> SF + Send + Sync,
123    SF: std::future::Future<Output = Result<O>> + Send + 'static,
124    Q: FnMut(String) -> QF + Send + Sync + Clone,
125    QF: std::future::Future<Output = Result<O>> + Send + 'static,
126{
127    async fn poll(&mut self) -> Option<PollingResult<O, O>> {
128        if let Some(start) = self.start.take() {
129            let result = start().await;
130            #[cfg(google_cloud_unstable_tracing)]
131            if let Ok(ref op) = result {
132                let name = op.name();
133                if let (Some(name), Some(recorder)) = (name, LroRecorder::current()) {
134                    recorder.record_destination_id(name);
135                }
136            }
137            let (op, poll) = self::handle_start(result);
138            self.operation = op;
139            return Some(poll);
140        }
141        if let Some(name) = self.operation.take() {
142            self.state.attempt_count += 1;
143            let result = (self.query)(name.clone()).await;
144            let (op, poll) =
145                self::handle_poll(self.error_policy.clone(), &self.state, name, result);
146            #[cfg(google_cloud_unstable_tracing)]
147            if let (Some(next_name), Some(recorder)) = (&op, LroRecorder::current()) {
148                recorder.record_destination_id(next_name);
149            }
150            self.operation = op;
151            return Some(poll);
152        }
153        None
154    }
155    async fn until_done(self) -> Result<O> {
156        crate::until_done(self).await
157    }
158
159    #[cfg(feature = "unstable-stream")]
160    fn into_stream(self) -> impl futures::Stream<Item = PollingResult<O, O>> + Unpin {
161        crate::into_stream(self)
162    }
163}
164
165fn handle_start<O>(result: Result<O>) -> (Option<String>, PollingResult<O, O>)
166where
167    O: DiscoveryOperation,
168{
169    match result {
170        Err(ref _e) => (None, PollingResult::Completed(result)),
171        Ok(o) if o.done() => (None, PollingResult::Completed(Ok(o))),
172        Ok(o) => handle_polling_success(o),
173    }
174}
175
176fn handle_poll<O>(
177    error_policy: Arc<dyn PollingErrorPolicy>,
178    state: &PollingState,
179    operation_name: String,
180    result: Result<O>,
181) -> (Option<String>, PollingResult<O, O>)
182where
183    O: DiscoveryOperation,
184{
185    match result {
186        Err(e) => {
187            let state = error_policy.on_error(state, e);
188            handle_polling_error(state, operation_name)
189        }
190        Ok(o) if o.done() => (None, PollingResult::Completed(Ok(o))),
191        Ok(o) => handle_polling_success(o),
192    }
193}
194
195fn handle_polling_error<O>(
196    state: RetryResult,
197    operation_name: String,
198) -> (Option<String>, PollingResult<O, O>)
199where
200    O: DiscoveryOperation,
201{
202    match state {
203        RetryResult::Continue(e) => (Some(operation_name), PollingResult::PollingError(e)),
204        RetryResult::Exhausted(e) | RetryResult::Permanent(e) => {
205            (None, PollingResult::Completed(Err(e)))
206        }
207    }
208}
209
210fn handle_polling_success<O>(o: O) -> (Option<String>, PollingResult<O, O>)
211where
212    O: DiscoveryOperation,
213{
214    (o.name().cloned(), PollingResult::InProgress(Some(o)))
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Error;
    use google_cloud_gax::error::rpc::{Code, Status};
    use google_cloud_gax::exponential_backoff::{ExponentialBackoff, ExponentialBackoffBuilder};
    use google_cloud_gax::polling_error_policy::{Aip194Strict, AlwaysContinue};
    use std::time::Duration;

    // Without tracing support the tests still call `.instrument(test_span())`.
    // These stand-ins make that call a no-op so the test bodies are the same
    // in both configurations.
    #[cfg(not(google_cloud_unstable_tracing))]
    pub(crate) struct DummySpan;

    #[cfg(not(google_cloud_unstable_tracing))]
    fn test_span() -> DummySpan {
        DummySpan
    }

    #[cfg(not(google_cloud_unstable_tracing))]
    pub(crate) trait Instrument: Sized {
        fn instrument(self, _span: DummySpan) -> Self {
            self
        }
    }

    #[cfg(not(google_cloud_unstable_tracing))]
    impl<T> Instrument for T {}

    #[cfg(google_cloud_unstable_tracing)]
    use tracing::Instrument;

    #[cfg(google_cloud_unstable_tracing)]
    fn test_span() -> tracing::Span {
        tracing::info_span!(
            "test_span",
            gcp.resource.destination.id = tracing::field::Empty,
        )
    }

    #[tokio::test]
    async fn poller_until_done_success() {
        let start = || async move {
            let op = TestOperation {
                name: Some("start-name".into()),
                ..TestOperation::default()
            };
            Ok(op)
        };
        let query = |_name| async move {
            let op = TestOperation {
                done: true,
                value: Some(42),
                ..TestOperation::default()
            };
            Ok(op)
        };
        let got = new_discovery_poller(
            Arc::new(AlwaysContinue),
            Arc::new(test_backoff()),
            start,
            query,
        )
        .until_done()
        .instrument(test_span())
        .await;
        assert!(
            matches!(
                got,
                Ok(TestOperation {
                    value: Some(42),
                    ..
                })
            ),
            "{got:?}"
        );
    }

    #[tokio::test]
    async fn poller_until_done_success_with_transient() {
        let start = || async move {
            let op = TestOperation {
                name: Some("start-name".into()),
                ..TestOperation::default()
            };
            Ok(op)
        };
        // The first query fails with a transient error, later ones succeed.
        let mut query_count = 0;
        let query = move |_name| {
            query_count += 1;
            let count = query_count;
            async move {
                match count {
                    1 => Err(transient()),
                    _ => {
                        let op = TestOperation {
                            done: true,
                            value: Some(42),
                            ..TestOperation::default()
                        };
                        Ok(op)
                    }
                }
            }
        };
        let got = new_discovery_poller(
            Arc::new(AlwaysContinue),
            Arc::new(test_backoff()),
            start,
            query,
        )
        .until_done()
        .instrument(test_span())
        .await;
        assert!(
            matches!(
                got,
                Ok(TestOperation {
                    value: Some(42),
                    ..
                })
            ),
            "{got:?}"
        );
    }

    #[tokio::test]
    async fn poller_until_done_error_on_start() {
        let start = || async move { Err(Error::service(permanent_status())) };
        // The query function must never run when the start request fails.
        let query = async |_name| -> Result<TestOperation> {
            panic!();
        };
        let got = new_discovery_poller(
            Arc::new(AlwaysContinue),
            Arc::new(test_backoff()),
            start,
            query,
        )
        .until_done()
        .await;
        assert!(
            matches!(
                got,
                Err(ref e) if e.status() == Some(&permanent_status())
            ),
            "{got:?}"
        );
    }

    // `into_stream()` only exists when the `unstable-stream` feature is
    // enabled, so the test must be gated on the same feature.
    #[cfg(feature = "unstable-stream")]
    #[tokio::test]
    async fn poller_into_stream() {
        use futures::StreamExt;
        let start = || async move {
            let op = TestOperation {
                name: Some("start-name".into()),
                ..TestOperation::default()
            };
            Ok(op)
        };
        let query = |_name| async move {
            let op = TestOperation {
                done: true,
                value: Some(42),
                ..TestOperation::default()
            };
            Ok(op)
        };
        let mut stream = new_discovery_poller(
            Arc::new(AlwaysContinue),
            Arc::new(test_backoff()),
            start,
            query,
        )
        .into_stream();
        // The stream should return 2 Some(t) and a None.
        let got = stream.next().await;
        assert!(
            matches!(got, Some(PollingResult::InProgress(Some(_)))),
            "{got:?}"
        );
        let got = stream.next().await;
        assert!(
            matches!(
                got,
                Some(PollingResult::Completed(Ok(TestOperation {
                    value: Some(42),
                    ..
                })))
            ),
            "{got:?}"
        );
        let got = stream.next().await;
        assert!(got.is_none(), "{got:?}");
    }

    #[test]
    fn start_error() {
        let got = handle_start::<TestOperation>(Err(transient()));
        assert!(got.0.is_none(), "{got:?}");
        assert!(
            matches!(&got.1, PollingResult::Completed(Err(_))),
            "{got:?}"
        );
    }

    #[test]
    fn start_done() {
        let input = TestOperation {
            done: true,
            ..TestOperation::default()
        };
        let got = handle_start(Ok(input));
        assert!(got.0.is_none(), "{got:?}");
        assert!(matches!(&got.1, PollingResult::Completed(Ok(_))), "{got:?}");
    }

    #[test]
    fn start_in_progress() {
        let input = TestOperation {
            done: false,
            name: Some("in-progress".to_string()),
            ..TestOperation::default()
        };
        let got = handle_start(Ok(input));
        assert_eq!(got.0.as_deref(), Some("in-progress"), "{got:?}");
        assert!(
            matches!(&got.1, PollingResult::InProgress(Some(_))),
            "{got:?}"
        );
    }

    #[test]
    fn poll_error() {
        let policy = Aip194Strict;
        let state = PollingState::default();
        let got = handle_poll::<TestOperation>(
            Arc::new(policy),
            &state,
            "started".to_string(),
            Err(transient()),
        );
        assert_eq!(got.0.as_deref(), Some("started"), "{got:?}");
        assert!(matches!(got.1, PollingResult::PollingError(_)), "{got:?}");
    }

    #[test]
    fn poll_done_success() {
        let policy = Aip194Strict;
        let state = PollingState::default();
        let input = TestOperation {
            done: true,
            name: Some("in-progress".into()),
            ..TestOperation::default()
        };
        let got = handle_poll(Arc::new(policy), &state, "started".to_string(), Ok(input));
        assert!(got.0.is_none(), "{got:?}");
        assert!(matches!(got.1, PollingResult::Completed(Ok(_))), "{got:?}");
    }

    #[test]
    fn poll_in_progress() {
        let policy = Aip194Strict;
        let state = PollingState::default();
        let input = TestOperation {
            done: false,
            name: Some("in-progress".into()),
            ..TestOperation::default()
        };
        let got = handle_poll(Arc::new(policy), &state, "started".to_string(), Ok(input));
        assert_eq!(got.0.as_deref(), Some("in-progress"), "{got:?}");
        assert!(matches!(got.1, PollingResult::InProgress(_)), "{got:?}");
    }

    #[test]
    fn polling_error() {
        let got = handle_polling_error::<TestOperation>(
            RetryResult::Continue(transient()),
            "name-for-continue".to_string(),
        );
        assert_eq!(got.0.as_deref(), Some("name-for-continue"), "{got:?}");
        assert!(
            matches!(got.1, PollingResult::PollingError(ref e) if is_transient(e)),
            "{got:?}"
        );

        let got = handle_polling_error::<TestOperation>(
            RetryResult::Exhausted(transient()),
            "name-for-exhausted".to_string(),
        );
        assert!(got.0.is_none(), "{got:?}");
        assert!(
            matches!(got.1, PollingResult::Completed(Err(ref e)) if is_transient(e)),
            "{got:?}"
        );

        let got = handle_polling_error::<TestOperation>(
            RetryResult::Permanent(transient()),
            "name-for-permanent".to_string(),
        );
        assert!(got.0.is_none(), "{got:?}");
        assert!(
            matches!(got.1, PollingResult::Completed(Err(ref e)) if is_transient(e)),
            "{got:?}"
        );
    }

    #[test]
    fn polling_success() {
        let input = TestOperation {
            name: Some("in-progress".to_string()),
            ..TestOperation::default()
        };
        let got = handle_polling_success(input);
        assert_eq!(got.0.as_deref(), Some("in-progress"), "{got:?}");
        assert!(
            matches!(&got.1, PollingResult::InProgress(Some(_))),
            "{got:?}"
        );
    }

    /// Returns true if `error` carries the status built by `transient_status()`.
    fn is_transient(error: &Error) -> bool {
        error.status().is_some_and(|s| s == &transient_status())
    }

    fn transient() -> Error {
        Error::service(transient_status())
    }

    fn transient_status() -> Status {
        Status::default()
            .set_code(Code::Unavailable)
            .set_message("try-again")
    }

    fn permanent_status() -> Status {
        Status::default()
            .set_code(Code::PermissionDenied)
            .set_message("uh-oh")
    }

    /// A backoff policy with 1ms delays, to keep the tests fast.
    fn test_backoff() -> ExponentialBackoff {
        ExponentialBackoffBuilder::new()
            .with_initial_delay(Duration::from_millis(1))
            .with_maximum_delay(Duration::from_millis(1))
            .build()
            .expect("hard-coded values should succeed")
    }

    /// A minimal operation type used to exercise the poller.
    #[derive(Debug, Default, PartialEq)]
    struct TestOperation {
        done: bool,
        name: Option<String>,
        value: Option<i32>,
    }

    impl DiscoveryOperation for TestOperation {
        fn done(&self) -> bool {
            self.done
        }
        fn name(&self) -> Option<&String> {
            self.name.as_ref()
        }
    }

    #[cfg(google_cloud_unstable_tracing)]
    #[tokio::test]
    async fn test_discovery_poller_tracing() {
        let guard = google_cloud_test_utils::test_layer::TestLayer::initialize();

        let start = || async move {
            let op = TestOperation {
                name: Some("discovery-operation-123".into()),
                ..TestOperation::default()
            };
            Ok(op)
        };

        // The first query reports the operation as in progress, later
        // queries report it as done.
        let count = Arc::new(std::sync::Mutex::new(0));
        let query_count = count.clone();
        let query = move |_: String| {
            let mut c = query_count.lock().unwrap();
            *c += 1;
            let is_done = *c > 1;
            async move {
                if is_done {
                    let op = TestOperation {
                        done: true,
                        value: Some(42),
                        ..TestOperation::default()
                    };
                    Ok(op)
                } else {
                    let op = TestOperation {
                        name: Some("discovery-operation-123".into()),
                        ..TestOperation::default()
                    };
                    Ok(op)
                }
            }
        };

        let mut poller = DiscoveryPoller::new(
            Arc::new(AlwaysContinue),
            Arc::new(test_backoff()),
            start,
            query,
        );

        // First poll: starts the operation.
        let span = test_span();
        let poller_ref = &mut poller;
        let recorder = crate::internal::LroRecorder::new(span.clone());
        let _ = recorder
            .scope(async move { poller_ref.poll().instrument(span).await })
            .await;

        {
            let captured = google_cloud_test_utils::test_layer::TestLayer::capture(&guard);
            let got = captured
                .iter()
                .find(|s| s.name == "test_span")
                .unwrap_or_else(|| panic!("missing `test_span` in captured spans: {captured:?}"));
            assert_eq!(
                got.attributes
                    .get("gcp.resource.destination.id")
                    .and_then(|v| v.as_string()),
                Some("discovery-operation-123".to_string())
            );
        }

        // Second poll: queries the operation, which is still in progress.
        let span = test_span();
        let poller_ref2 = &mut poller;
        let recorder2 = crate::internal::LroRecorder::new(span.clone());
        let _ = recorder2
            .scope(async move { poller_ref2.poll().instrument(span).await })
            .await;

        {
            let captured = google_cloud_test_utils::test_layer::TestLayer::capture(&guard);
            let got = captured
                .iter()
                .find(|s| s.name == "test_span")
                .unwrap_or_else(|| panic!("missing `test_span` in captured spans: {captured:?}"));
            assert_eq!(
                got.attributes
                    .get("gcp.resource.destination.id")
                    .and_then(|v| v.as_string()),
                Some("discovery-operation-123".to_string())
            );
        }
    }
}