google_cloud_lro/
lib.rs

1// Copyright 2025 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Types and functions to make LROs easier to use and to require less boilerplate.
16
17use gax::error::Error;
18use gax::polling_backoff_policy::PollingBackoffPolicy;
19use gax::polling_policy::PollingPolicy;
20use gax::Result;
21use std::future::Future;
22use std::marker::PhantomData;
23use std::sync::Arc;
24use std::time::Instant;
25
26/// The result of polling a Long-Running Operation (LRO).
27///
28/// # Parameters
29/// * `R` - the response type. This is the type returned when the LRO completes
30///   successfully.
31/// * `M` - the metadata type. While operations are in progress the LRO may
32///   return values of this type.
33#[derive(Debug)]
34pub enum PollingResult<R, M> {
35    /// The operation is still in progress.
36    InProgress(Option<M>),
37    /// The operation completed. This includes the result.
38    Completed(Result<R>),
39    /// An error trying to poll the LRO.
40    ///
41    /// Not all errors indicate that the operation failed. For example, this
42    /// may fail because it was not possible to connect to Google Cloud. Such
43    /// transient errors may disappear in the next polling attempt.
44    ///
45    /// Other errors will never recover. For example, a [ServiceError] with
46    /// a [NOT_FOUND], [ABORTED], or [PERMISSION_DENIED] code will never
47    /// recover.
48    ///
49    /// [ServiceError]: gax::error::ServiceError
50    /// [NOT_FOUND]: rpc::model::code::NOT_FOUND
51    /// [ABORTED]: rpc::model::code::ABORTED
52    /// [PERMISSION_DENIED]: rpc::model::code::PERMISSION_DENIED
53    PollingError(Error),
54}
55
56/// A wrapper around [longrunning::model::Operation] with typed responses.
57///
58/// This is intended as an implementation detail of the generated clients.
59/// Applications should have no need to create or use this struct.
60#[doc(hidden)]
61pub struct Operation<R, M> {
62    inner: longrunning::model::Operation,
63    response: std::marker::PhantomData<R>,
64    metadata: std::marker::PhantomData<M>,
65}
66
67impl<R, M> Operation<R, M> {
68    pub fn new(inner: longrunning::model::Operation) -> Self {
69        Self {
70            inner,
71            response: PhantomData,
72            metadata: PhantomData,
73        }
74    }
75
76    fn name(&self) -> String {
77        self.inner.name.clone()
78    }
79    fn done(&self) -> bool {
80        self.inner.done
81    }
82    fn metadata(&self) -> Option<&wkt::Any> {
83        self.inner.metadata.as_ref()
84    }
85    fn response(&self) -> Option<&wkt::Any> {
86        use longrunning::model::operation::Result;
87        self.inner.result.as_ref().and_then(|r| match r {
88            Result::Error(_) => None,
89            Result::Response(r) => Some(r.as_ref()),
90            _ => None,
91        })
92    }
93    fn error(&self) -> Option<&rpc::model::Status> {
94        use longrunning::model::operation::Result;
95        self.inner.result.as_ref().and_then(|r| match r {
96            Result::Error(rpc) => Some(rpc.as_ref()),
97            Result::Response(_) => None,
98            _ => None,
99        })
100    }
101}
102
103/// The trait implemented by LRO helpers.
104///
105/// # Parameters
106/// * `R` - the response type, that is, the type of response included when the
107///   long-running operation completes successfully.
108/// * `M` - the metadata type, that is, the type returned by the service when
109///   the long-running operation is still in progress.
110pub trait Poller<R, M> {
111    /// Query the current status of the long-running operation.
112    fn poll(&mut self) -> impl Future<Output = Option<PollingResult<R, M>>>;
113
114    /// Poll the long-running operation until it completes.
115    fn until_done(self) -> impl Future<Output = Result<R>>;
116
117    /// Convert a poller to a [futures::Stream].
118    #[cfg(feature = "unstable-stream")]
119    fn to_stream(self) -> impl futures::Stream<Item = PollingResult<R, M>>;
120}
121
122/// Creates a new `impl Poller<R, M>` from the closures created by the generator.
123///
124/// This is intended as an implementation detail of the generated clients.
125/// Applications should have no need to create or use this struct.
126#[doc(hidden)]
127pub fn new_poller<ResponseType, MetadataType, S, SF, Q, QF>(
128    polling_policy: Arc<dyn PollingPolicy>,
129    polling_backoff_policy: Arc<dyn PollingBackoffPolicy>,
130    start: S,
131    query: Q,
132) -> impl Poller<ResponseType, MetadataType>
133where
134    ResponseType: wkt::message::Message + serde::de::DeserializeOwned,
135    MetadataType: wkt::message::Message + serde::de::DeserializeOwned,
136    S: FnOnce() -> SF + Send + Sync,
137    SF: std::future::Future<Output = Result<Operation<ResponseType, MetadataType>>>
138        + Send
139        + 'static,
140    Q: Fn(String) -> QF + Send + Sync + Clone,
141    QF: std::future::Future<Output = Result<Operation<ResponseType, MetadataType>>>
142        + Send
143        + 'static,
144{
145    PollerImpl::new(polling_policy, polling_backoff_policy, start, query)
146}
147
148/// An implementation of `Poller` based on closures.
149///
150/// Thanks to this implementation, the code generator (`sidekick`) needs to
151/// produce two closures: one to start the operation, and one to query progress.
152///
153/// Applications should not need to create this type, or use it directly. It is
154/// only public so the generated code can use it.
155///
156/// # Parameters
157/// * `ResponseType` - the response type. Typically this is a message
158///   representing the final disposition of the long-running operation.
159/// * `MetadataType` - the metadata type. The data included with partially
160///   completed instances of this long-running operations.
161/// * `S` - the start closure. Starts a LRO. This implementation expects that
162///   all necessary parameters, and request options, including retry options
163///   are captured by this function.
164/// * `SF` - the type of future returned by `S`.
165/// * `Q` - the query closure. Queries the status of the LRO created by `start`.
166///   It receives the name of the operation as its only input parameter. It
167///   should have captured any stubs and request options.
168/// * `QF` - the type of future returned by `Q`.
169struct PollerImpl<ResponseType, MetadataType, S, SF, Q, QF>
170where
171    S: FnOnce() -> SF + Send + Sync,
172    SF: std::future::Future<Output = Result<Operation<ResponseType, MetadataType>>>
173        + Send
174        + 'static,
175    Q: Fn(String) -> QF + Send + Sync + Clone,
176    QF: std::future::Future<Output = Result<Operation<ResponseType, MetadataType>>>
177        + Send
178        + 'static,
179{
180    polling_policy: Arc<dyn PollingPolicy>,
181    backoff_policy: Arc<dyn PollingBackoffPolicy>,
182    start: Option<S>,
183    query: Q,
184    operation: Option<String>,
185    loop_start: Instant,
186    attempt_count: u32,
187}
188
189impl<ResponseType, MetadataType, S, SF, Q, QF> PollerImpl<ResponseType, MetadataType, S, SF, Q, QF>
190where
191    S: FnOnce() -> SF + Send + Sync,
192    SF: std::future::Future<Output = Result<Operation<ResponseType, MetadataType>>>
193        + Send
194        + 'static,
195    Q: Fn(String) -> QF + Send + Sync + Clone,
196    QF: std::future::Future<Output = Result<Operation<ResponseType, MetadataType>>>
197        + Send
198        + 'static,
199{
200    pub fn new(
201        polling_policy: Arc<dyn PollingPolicy>,
202        backoff_policy: Arc<dyn PollingBackoffPolicy>,
203        start: S,
204        query: Q,
205    ) -> Self {
206        Self {
207            polling_policy,
208            backoff_policy,
209            start: Some(start),
210            query,
211            operation: None,
212            loop_start: Instant::now(),
213            attempt_count: 0,
214        }
215    }
216}
217
218impl<ResponseType, MetadataType, S, SF, P, PF> Poller<ResponseType, MetadataType>
219    for PollerImpl<ResponseType, MetadataType, S, SF, P, PF>
220where
221    ResponseType: wkt::message::Message + serde::de::DeserializeOwned,
222    MetadataType: wkt::message::Message + serde::de::DeserializeOwned,
223    S: FnOnce() -> SF + Send + Sync,
224    SF: std::future::Future<Output = Result<Operation<ResponseType, MetadataType>>>
225        + Send
226        + 'static,
227    P: Fn(String) -> PF + Send + Sync + Clone,
228    PF: std::future::Future<Output = Result<Operation<ResponseType, MetadataType>>>
229        + Send
230        + 'static,
231{
232    async fn poll(&mut self) -> Option<PollingResult<ResponseType, MetadataType>> {
233        if let Some(start) = self.start.take() {
234            let result = start().await;
235            let (op, poll) = details::handle_start(result);
236            self.operation = op;
237            return Some(poll);
238        }
239        if let Some(name) = self.operation.take() {
240            self.attempt_count += 1;
241            let result = (self.query)(name.clone()).await;
242            let (op, poll) = details::handle_poll(
243                self.polling_policy.clone(),
244                self.loop_start,
245                self.attempt_count,
246                name,
247                result,
248            );
249            self.operation = op;
250            return Some(poll);
251        }
252        None
253    }
254
255    async fn until_done(mut self) -> Result<ResponseType> {
256        let loop_start = std::time::Instant::now();
257        let mut attempt_count = 0;
258        while let Some(p) = self.poll().await {
259            match p {
260                // Return, the operation completed or the polling policy is
261                // exhausted.
262                PollingResult::Completed(r) => return r,
263                // Continue, the operation was successfully polled and the
264                // polling policy was queried.
265                PollingResult::InProgress(_) => (),
266                // Continue, the polling policy was queried and decided the
267                // error is recoverable.
268                PollingResult::PollingError(_) => (),
269            }
270            attempt_count += 1;
271            tokio::time::sleep(self.backoff_policy.wait_period(loop_start, attempt_count)).await;
272        }
273        // We can only get here if `poll()` returns `None`, but it only returns
274        // `None` after it returned `Polling::Completed` and therefore this is
275        // never reached.
276        unreachable!("loop should exit via the `Completed` branch vs. this line");
277    }
278
279    #[cfg(feature = "unstable-stream")]
280    fn to_stream(self) -> impl futures::Stream<Item = PollingResult<ResponseType, MetadataType>>
281    where
282        ResponseType: wkt::message::Message + serde::de::DeserializeOwned,
283        MetadataType: wkt::message::Message + serde::de::DeserializeOwned,
284    {
285        use futures::stream::unfold;
286        unfold(Some(self), move |state| async move {
287            if let Some(mut poller) = state {
288                if let Some(pr) = poller.poll().await {
289                    return Some((pr, Some(poller)));
290                }
291            };
292            None
293        })
294    }
295}
296
297mod details;
298
299#[cfg(test)]
300mod test {
301    use super::*;
302    use gax::exponential_backoff::ExponentialBackoff;
303    use gax::exponential_backoff::ExponentialBackoffBuilder;
304    use gax::polling_policy::*;
305    use std::time::Duration;
306
307    type ResponseType = wkt::Duration;
308    type MetadataType = wkt::Timestamp;
309    type TestOperation = Operation<ResponseType, MetadataType>;
310
311    #[test]
312    fn typed_operation_with_metadata() -> Result<()> {
313        let any = wkt::Any::try_from(&wkt::Timestamp::clamp(123, 0))
314            .map_err(|e| Error::other(format!("unexpected error in Any::try_from {e}")))?;
315        let op = longrunning::model::Operation::default()
316            .set_name("test-only-name")
317            .set_metadata(any);
318        let op = TestOperation::new(op);
319        assert_eq!(op.name(), "test-only-name");
320        assert!(!op.done());
321        assert!(matches!(op.metadata(), Some(_)));
322        assert!(matches!(op.response(), None));
323        assert!(matches!(op.error(), None));
324        let got = op
325            .metadata()
326            .unwrap()
327            .try_into_message::<wkt::Timestamp>()
328            .map_err(Error::other)?;
329        assert_eq!(got, wkt::Timestamp::clamp(123, 0));
330
331        Ok(())
332    }
333
334    #[test]
335    fn typed_operation_with_response() -> Result<()> {
336        let any = wkt::Any::try_from(&wkt::Duration::clamp(23, 0))
337            .map_err(|e| Error::other(format!("unexpected error in Any::try_from {e}")))?;
338        let op = longrunning::model::Operation::default()
339            .set_name("test-only-name")
340            .set_result(longrunning::model::operation::Result::Response(any.into()));
341        let op = TestOperation::new(op);
342        assert_eq!(op.name(), "test-only-name");
343        assert!(!op.done());
344        assert!(matches!(op.metadata(), None));
345        assert!(matches!(op.response(), Some(_)));
346        assert!(matches!(op.error(), None));
347        let got = op
348            .response()
349            .unwrap()
350            .try_into_message::<wkt::Duration>()
351            .map_err(Error::other)?;
352        assert_eq!(got, wkt::Duration::clamp(23, 0));
353
354        Ok(())
355    }
356
357    #[test]
358    fn typed_operation_with_error() -> Result<()> {
359        let rpc = rpc::model::Status::default()
360            .set_message("test only")
361            .set_code(16);
362        let op = longrunning::model::Operation::default()
363            .set_name("test-only-name")
364            .set_result(longrunning::model::operation::Result::Error(
365                rpc.clone().into(),
366            ));
367        let op = TestOperation::new(op);
368        assert_eq!(op.name(), "test-only-name");
369        assert!(!op.done());
370        assert!(matches!(op.metadata(), None));
371        assert!(matches!(op.response(), None));
372        assert!(matches!(op.error(), Some(_)));
373        let got = op.error().unwrap();
374        assert_eq!(got, &rpc);
375
376        Ok(())
377    }
378
379    #[tokio::test(flavor = "multi_thread")]
380    async fn poll_basic_flow() {
381        let start = || async move {
382            let any = wkt::Any::try_from(&wkt::Timestamp::clamp(123, 0))
383                .map_err(|e| Error::other(format!("unexpected error in Any::try_from {e}")))?;
384            let op = longrunning::model::Operation::default()
385                .set_name("test-only-name")
386                .set_metadata(any);
387            let op = TestOperation::new(op);
388            Ok::<TestOperation, Error>(op)
389        };
390
391        let query = |_: String| async move {
392            let any = wkt::Any::try_from(&wkt::Duration::clamp(234, 0))
393                .map_err(|e| Error::other(format!("unexpected error in Any::try_from {e}")))?;
394            let result = longrunning::model::operation::Result::Response(any.into());
395            let op = longrunning::model::Operation::default()
396                .set_done(true)
397                .set_result(result);
398            let op = TestOperation::new(op);
399
400            Ok::<TestOperation, Error>(op)
401        };
402
403        let mut poller = PollerImpl::new(
404            Arc::new(AlwaysContinue),
405            Arc::new(ExponentialBackoff::default()),
406            start,
407            query,
408        );
409        let p0 = poller.poll().await;
410        match p0.unwrap() {
411            PollingResult::InProgress(m) => {
412                assert_eq!(m, Some(wkt::Timestamp::clamp(123, 0)));
413            }
414            r => {
415                assert!(false, "{r:?}");
416            }
417        }
418
419        let p1 = poller.poll().await;
420        match p1.unwrap() {
421            PollingResult::Completed(r) => {
422                let response = r.unwrap();
423                assert_eq!(response, wkt::Duration::clamp(234, 0));
424            }
425            r => {
426                assert!(false, "{r:?}");
427            }
428        }
429
430        let p2 = poller.poll().await;
431        assert!(p2.is_none(), "{p2:?}");
432    }
433
434    #[tokio::test(flavor = "multi_thread")]
435    async fn poll_basic_stream() {
436        let start = || async move {
437            let any = wkt::Any::try_from(&wkt::Timestamp::clamp(123, 0))
438                .map_err(|e| Error::other(format!("unexpected error in Any::try_from {e}")))?;
439            let op = longrunning::model::Operation::default()
440                .set_name("test-only-name")
441                .set_metadata(any);
442            let op = TestOperation::new(op);
443            Ok::<TestOperation, Error>(op)
444        };
445
446        let query = |_: String| async move {
447            let any = wkt::Any::try_from(&wkt::Duration::clamp(234, 0))
448                .map_err(|e| Error::other(format!("unexpected error in Any::try_from {e}")))?;
449            let result = longrunning::model::operation::Result::Response(any.into());
450            let op = longrunning::model::Operation::default()
451                .set_done(true)
452                .set_result(result);
453            let op = TestOperation::new(op);
454
455            Ok::<TestOperation, Error>(op)
456        };
457
458        use futures::StreamExt;
459        let mut stream = new_poller(
460            Arc::new(AlwaysContinue),
461            Arc::new(ExponentialBackoff::default()),
462            start,
463            query,
464        )
465        .to_stream();
466        let mut stream = std::pin::pin!(stream);
467        let p0 = stream.next().await;
468        match p0.unwrap() {
469            PollingResult::InProgress(m) => {
470                assert_eq!(m, Some(wkt::Timestamp::clamp(123, 0)));
471            }
472            r => {
473                assert!(false, "{r:?}");
474            }
475        }
476
477        let p1 = stream.next().await;
478        match p1.unwrap() {
479            PollingResult::Completed(r) => {
480                let response = r.unwrap();
481                assert_eq!(response, wkt::Duration::clamp(234, 0));
482            }
483            r => {
484                assert!(false, "{r:?}");
485            }
486        }
487
488        let p2 = stream.next().await;
489        assert!(p2.is_none(), "{p2:?}");
490    }
491
492    #[tokio::test(flavor = "multi_thread")]
493    async fn until_done_basic_flow() -> Result<()> {
494        let start = || async move {
495            let any = wkt::Any::try_from(&wkt::Timestamp::clamp(123, 0))
496                .map_err(|e| Error::other(format!("unexpected error in Any::try_from {e}")))?;
497            let op = longrunning::model::Operation::default()
498                .set_name("test-only-name")
499                .set_metadata(any);
500            let op = TestOperation::new(op);
501            Ok::<TestOperation, Error>(op)
502        };
503
504        let query = |_: String| async move {
505            let any = wkt::Any::try_from(&wkt::Duration::clamp(234, 0))
506                .map_err(|e| Error::other(format!("unexpected error in Any::try_from {e}")))?;
507            let result = longrunning::model::operation::Result::Response(any.into());
508            let op = longrunning::model::Operation::default()
509                .set_done(true)
510                .set_result(result);
511            let op = TestOperation::new(op);
512
513            Ok::<TestOperation, Error>(op)
514        };
515
516        let poller = PollerImpl::new(
517            Arc::new(AlwaysContinue),
518            Arc::new(
519                ExponentialBackoffBuilder::new()
520                    .with_initial_delay(Duration::from_millis(1))
521                    .clamp(),
522            ),
523            start,
524            query,
525        );
526        let response = poller.until_done().await?;
527        assert_eq!(response, wkt::Duration::clamp(234, 0));
528
529        Ok(())
530    }
531
532    #[tokio::test(flavor = "multi_thread")]
533    async fn until_done_with_recoverable_polling_error() -> Result<()> {
534        let start = || async move {
535            let any = wkt::Any::try_from(&wkt::Timestamp::clamp(123, 0))
536                .map_err(|e| Error::other(format!("unexpected error in Any::try_from {e}")))?;
537            let op = longrunning::model::Operation::default()
538                .set_name("test-only-name")
539                .set_metadata(any);
540            let op = TestOperation::new(op);
541            Ok::<TestOperation, Error>(op)
542        };
543
544        let count = Arc::new(std::sync::Mutex::new(0_u32));
545        let query = move |_: String| {
546            let mut guard = count.lock().unwrap();
547            let c = *guard;
548            *guard = c + 1;
549            drop(guard);
550            async move {
551                if c == 0 {
552                    return Err::<TestOperation, Error>(Error::other(
553                        "recoverable (see policy below)",
554                    ));
555                }
556                let any = wkt::Any::try_from(&wkt::Duration::clamp(234, 0))
557                    .map_err(|e| Error::other(format!("unexpected error in Any::try_from {e}")))?;
558                let result = longrunning::model::operation::Result::Response(any.into());
559                let op = longrunning::model::Operation::default()
560                    .set_done(true)
561                    .set_result(result);
562                let op = TestOperation::new(op);
563
564                Ok::<TestOperation, Error>(op)
565            }
566        };
567
568        let poller = PollerImpl::new(
569            Arc::new(AlwaysContinue),
570            Arc::new(
571                ExponentialBackoffBuilder::new()
572                    .with_initial_delay(Duration::from_millis(1))
573                    .clamp(),
574            ),
575            start,
576            query,
577        );
578        let response = poller.until_done().await?;
579        assert_eq!(response, wkt::Duration::clamp(234, 0));
580
581        Ok(())
582    }
583
584    #[tokio::test(flavor = "multi_thread")]
585    async fn until_done_with_unrecoverable_polling_error() -> Result<()> {
586        let start = || async move {
587            let any = wkt::Any::try_from(&wkt::Timestamp::clamp(123, 0))
588                .map_err(|e| Error::other(format!("unexpected error in Any::try_from {e}")))?;
589            let op = longrunning::model::Operation::default()
590                .set_name("test-only-name")
591                .set_metadata(any);
592            let op = TestOperation::new(op);
593            Ok::<TestOperation, Error>(op)
594        };
595
596        let query = move |_: String| async move {
597            return Err::<TestOperation, Error>(Error::other("unrecoverable (see policy below)"));
598        };
599
600        let poller = PollerImpl::new(
601            Arc::new(Aip194Strict),
602            Arc::new(
603                ExponentialBackoffBuilder::new()
604                    .with_initial_delay(Duration::from_millis(1))
605                    .clamp(),
606            ),
607            start,
608            query,
609        );
610        let response = poller.until_done().await;
611        assert!(response.is_err());
612        assert!(
613            format!("{response:?}").contains("unrecoverable"),
614            "{response:?}"
615        );
616
617        Ok(())
618    }
619}