mssf_util/
retry.rs

1// ------------------------------------------------------------
2// Copyright (c) Microsoft Corporation.  All rights reserved.
3// Licensed under the MIT License (MIT). See License.txt in the repo root for license information.
4// ------------------------------------------------------------
5
6use mssf_core::{
7    ErrorCode,
8    runtime::executor::{BoxedCancelToken, Timer},
9};
10use std::{pin::Pin, time::Duration};
11
/// Tracks how much of a fixed time budget has been used since the counter was created.
struct TimeCounter {
    /// Moment the budget started counting down from.
    start: std::time::Instant,
    /// Total length of the budget.
    timeout: Duration,
}
17
18impl TimeCounter {
19    pub fn new(timeout: Duration) -> Self {
20        TimeCounter {
21            timeout,
22            start: std::time::Instant::now(),
23        }
24    }
25
26    pub fn elapsed(&self) -> Duration {
27        self.start.elapsed()
28    }
29
30    pub fn remaining(&self) -> mssf_core::Result<Duration> {
31        if self.elapsed() < self.timeout {
32            Ok(self.timeout - self.elapsed())
33        } else {
34            Err(ErrorCode::FABRIC_E_TIMEOUT.into())
35        }
36    }
37
38    /// returns a future that will sleep until the remaining time is up.
39    pub fn sleep_until_remaining(
40        &self,
41        timer: &dyn Timer,
42    ) -> mssf_core::Result<impl Future<Output = ()>> {
43        let remaining = self.remaining()?;
44        Ok(timer.sleep(remaining))
45    }
46}
47
48#[derive(Default)]
49pub struct OperationRetryerBuilder {
50    timer: Option<Box<dyn Timer>>,
51    default_timeout: Option<Duration>,
52    max_retry_interval: Option<Duration>,
53}
54
55impl OperationRetryerBuilder {
56    pub fn new() -> Self {
57        Self::default()
58    }
59
60    /// With a runtime timer to use for sleeping.
61    pub fn with_timer(mut self, timer: Box<dyn Timer>) -> Self {
62        self.timer = Some(timer);
63        self
64    }
65
66    pub fn with_default_timeout(mut self, timeout: Duration) -> Self {
67        self.default_timeout = Some(timeout);
68        self
69    }
70
71    pub fn with_max_retry_interval(mut self, interval: Duration) -> Self {
72        self.max_retry_interval = Some(interval);
73        self
74    }
75
76    pub fn build(self) -> OperationRetryer {
77        OperationRetryer::new(
78            self.timer.unwrap_or(Box::new(crate::tokio::TokioTimer)),
79            self.default_timeout.unwrap_or(Duration::from_secs(30)),
80            self.max_retry_interval.unwrap_or(Duration::from_secs(5)),
81        )
82    }
83}
84
85/// A helper to retry an operation with transient error and timeout.
86pub struct OperationRetryer {
87    timer: Box<dyn Timer>,
88    default_timeout: Duration,
89    max_retry_interval: Duration,
90}
91
92impl OperationRetryer {
93    pub fn builder() -> OperationRetryerBuilder {
94        OperationRetryerBuilder::new()
95    }
96
97    fn new(timer: Box<dyn Timer>, default_timeout: Duration, max_retry_interval: Duration) -> Self {
98        OperationRetryer {
99            timer,
100            default_timeout,
101            max_retry_interval,
102        }
103    }
104
105    /// Run the operation with retry on transient errors and timeouts.
106    /// User can provide a total timeout and a cancel token.
107    pub async fn run<T, F, Fut>(
108        &self,
109        op: F,
110        timeout: Option<Duration>,
111        token: Option<BoxedCancelToken>,
112    ) -> mssf_core::Result<T>
113    where
114        F: Fn(Duration, Option<BoxedCancelToken>) -> Fut,
115        Fut: Future<Output = mssf_core::Result<T>> + Send,
116        T: Send,
117    {
118        let timeout = timeout.unwrap_or(self.default_timeout);
119        let timer = TimeCounter::new(timeout);
120        let mut cancel: Pin<Box<dyn std::future::Future<Output = ()> + Send>> =
121            if let Some(t) = &token {
122                t.wait()
123            } else {
124                Box::pin(std::future::pending())
125            };
126        loop {
127            let res = tokio::select! {
128                _ = timer.sleep_until_remaining(self.timer.as_ref())? => {
129                    // Timeout reached, return error.
130                    return Err(ErrorCode::FABRIC_E_TIMEOUT.into());
131                }
132                _ = &mut cancel => {
133                    // Cancellation requested, return error.
134                    return Err(ErrorCode::E_ABORT.into());
135                }
136                // Run the operation with the remaining time and cancel token.
137                res = op(timer.remaining()?, token.clone()) => res,
138            };
139            match res {
140                Ok(r) => return Ok(r),
141                Err(e) => match e.try_as_fabric_error_code() {
142                    Ok(ec) => {
143                        if ec == ErrorCode::FABRIC_E_TIMEOUT || ec.is_transient() {
144                            #[cfg(feature = "tracing")]
145                            tracing::debug!(
146                                "Operation transient error {ec}. Remaining time {:?}. Retrying...",
147                                timer.remaining()?
148                            );
149                            // do nothing, retry.
150                        } else {
151                            return Err(e);
152                        }
153                    }
154                    _ => return Err(e),
155                },
156            }
157            // sleep for a while before retrying.
158            tokio::select! {
159                _ = self.timer.sleep(self.max_retry_interval) => {},
160                _ = timer.sleep_until_remaining(self.timer.as_ref())? => {
161                    // Timeout reached, return error.
162                    return Err(ErrorCode::FABRIC_E_TIMEOUT.into());
163                }
164                _ = &mut cancel => {
165                    // Cancellation requested, return error.
166                    return Err(ErrorCode::E_ABORT.into());
167                }
168            }
169        }
170    }
171}