Skip to main content

google_cloud_spanner/
transaction_retry_policy.rs

1// Copyright 2026 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::Error;
16use google_cloud_gax::backoff_policy::BackoffPolicy;
17use google_cloud_gax::error::rpc::StatusDetails;
18use google_cloud_gax::exponential_backoff::{ExponentialBackoff, ExponentialBackoffBuilder};
19use google_cloud_gax::retry_result::RetryResult;
20use google_cloud_gax::retry_state::RetryState;
21use std::time::Duration;
22
/// Defines a policy for retrying a transaction when it is aborted by Spanner.
///
/// Spanner can abort any read/write transaction due to lock conflicts or other
/// transient issues. In such cases, the client should retry the complete
/// transaction.
///
/// Implementations must be `Send + Sync`.
pub trait TransactionRetryPolicy: Send + Sync {
    /// Evaluates whether an aborted transaction should be retried.
    ///
    /// # Parameters
    ///
    /// * `error` is the `Aborted` error that was raised. Note that this policy
    ///   takes ownership of the error and returns it embedded in the retry result.
    /// * `attempts` is the number of attempts already made (1 for the first failure).
    /// * `elapsed` is the total time spent executing the transaction so far.
    ///
    /// # Returns
    ///
    /// The retry helpers in this module retry the transaction when the result
    /// is `RetryResult::Continue`, and stop and surface the embedded error
    /// when it is `RetryResult::Exhausted` or `RetryResult::Permanent`.
    fn on_abort(&self, error: Error, attempts: u32, elapsed: Duration) -> RetryResult;
}
37
/// Retry policy for aborted transactions that is bounded by an attempt count
/// and by the total elapsed time.
///
/// A limit that is left at zero is not enforced.
#[derive(Debug, Clone)]
pub struct BasicTransactionRetryPolicy {
    /// Upper bound on the number of attempts. A value of `0` disables this limit.
    max_attempts: u32,
    /// Upper bound on the total time spent retrying. A zero duration disables
    /// this limit.
    total_timeout: Duration,
}
47
48impl BasicTransactionRetryPolicy {
49    /// Creates a new basic transaction retry policy with no limits.
50    pub fn new() -> Self {
51        Self::default()
52    }
53
54    /// Sets the maximum number of attempts to make.
55    pub fn with_max_attempts(mut self, max_attempts: u32) -> Self {
56        self.max_attempts = max_attempts;
57        self
58    }
59
60    /// Sets the total maximum time to spend retrying.
61    pub fn with_total_timeout(mut self, total_timeout: Duration) -> Self {
62        self.total_timeout = total_timeout;
63        self
64    }
65
66    /// Returns the maximum number of attempts configured.
67    pub fn max_attempts(&self) -> u32 {
68        self.max_attempts
69    }
70
71    /// Returns the total maximum time configured.
72    pub fn total_timeout(&self) -> Duration {
73        self.total_timeout
74    }
75}
76
77impl Default for BasicTransactionRetryPolicy {
78    fn default() -> Self {
79        Self {
80            max_attempts: 0,
81            total_timeout: Duration::from_secs(0),
82        }
83    }
84}
85
86impl TransactionRetryPolicy for BasicTransactionRetryPolicy {
87    fn on_abort(&self, error: Error, attempts: u32, elapsed: Duration) -> RetryResult {
88        if self.max_attempts > 0 && attempts >= self.max_attempts {
89            return RetryResult::Exhausted(error);
90        }
91        if self.total_timeout > Duration::from_secs(0) && elapsed > self.total_timeout {
92            return RetryResult::Exhausted(error);
93        }
94        RetryResult::Continue(error)
95    }
96}
97
98/// Helper method to execute an asynchronous closure, retrying it if the
99/// transaction is aborted by the server.
100///
101/// This is used for operations like Partitioned DML transactions in Cloud Spanner, where
102/// the server may abort the transaction due to transient issues, indicating that the client
103/// should re-attempt the entire operation.
104pub(crate) async fn retry_aborted<T, F, Fut>(
105    policy: &dyn TransactionRetryPolicy,
106    mut f: F,
107) -> crate::Result<T>
108where
109    F: FnMut() -> Fut,
110    Fut: std::future::Future<Output = crate::Result<T>>,
111{
112    let start_time = tokio::time::Instant::now();
113    let mut attempts: u32 = 0;
114
115    // This backoff is only used if Spanner does not return a retry delay.
116    let backoff = default_retry_backoff();
117
118    loop {
119        attempts += 1;
120        match f().await {
121            Ok(v) => return Ok(v),
122            Err(e) => {
123                backoff_if_aborted(e, attempts, start_time.elapsed(), policy, &backoff).await?;
124            }
125        }
126    }
127}
128
129pub(crate) fn is_aborted(err: &crate::Error) -> bool {
130    err.status()
131        .is_some_and(|s| s.code == google_cloud_gax::error::rpc::Code::Aborted)
132}
133
134pub(crate) fn extract_retry_delay(err: &crate::Error) -> Option<Duration> {
135    err.status()?.details.iter().find_map(|detail| {
136        let StatusDetails::RetryInfo(retry_info) = detail else {
137            return None;
138        };
139        (*retry_info.retry_delay.as_ref()?).try_into().ok()
140    })
141}
142
143pub(crate) fn default_retry_backoff() -> ExponentialBackoff {
144    ExponentialBackoffBuilder::new()
145        .with_initial_delay(Duration::from_millis(10))
146        .with_maximum_delay(Duration::from_secs(1))
147        .with_scaling(1.3)
148        .build()
149        .unwrap()
150}
151
152/// Evaluates the error against the retry policy and delays execution if a retry is warranted.
153/// Returns Ok(()) after sleeping if a retry should occur, otherwise returns Err with the original error.
154pub(crate) async fn backoff_if_aborted(
155    err: crate::Error,
156    attempts: u32,
157    elapsed: Duration,
158    policy: &dyn TransactionRetryPolicy,
159    backoff: &ExponentialBackoff,
160) -> crate::Result<()> {
161    if !is_aborted(&err) {
162        return Err(err);
163    }
164
165    let e = match policy.on_abort(err, attempts, elapsed) {
166        RetryResult::Continue(err) => err,
167        RetryResult::Exhausted(err) | RetryResult::Permanent(err) => return Err(err),
168    };
169
170    let sleep_duration = extract_retry_delay(&e)
171        .unwrap_or_else(|| backoff.on_failure(&RetryState::new(true).set_attempt_count(attempts)));
172
173    tokio::time::sleep(sleep_duration).await;
174    Ok(())
175}
176
// Unit tests for the retry policy and the retry helpers. The module and
// `create_aborted_status` are `pub(crate)`, presumably so other test modules
// in the crate can reuse the helper (confirm against callers).
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::Error;
    use google_cloud_gax::error::rpc::{Code, Status};
    use google_cloud_rpc::model::RetryInfo;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use wkt::Any;

    /// Builds a service error with code `Aborted` and message "aborted".
    ///
    /// When `retry_delay` is `Some`, a `RetryInfo` detail carrying that delay
    /// is attached to the status.
    fn create_aborted_error(retry_delay: Option<Duration>) -> Error {
        let mut status = Status::default()
            .set_code(Code::Aborted)
            .set_message("aborted");

        if let Some(delay) = retry_delay {
            let retry_info = RetryInfo::default().set_retry_delay(wkt::Duration::clamp(
                delay.as_secs() as i64,
                delay.subsec_nanos() as i32,
            ));
            status = status.set_details(vec![Any::from_msg(&retry_info).unwrap()]);
        }

        Error::service(status)
    }

    /// Builds a tonic `Aborted` status whose details bytes hold an encoded
    /// `google.rpc.Status` with a single `Any` detail typed as
    /// `google.rpc.RetryInfo` and carrying `retry_delay`.
    pub(crate) fn create_aborted_status(
        retry_delay: std::time::Duration,
    ) -> gaxi::grpc::tonic::Status {
        use prost::Message;

        // Local stand-in message with a single optional duration at tag 1; it
        // is encoded below under the `google.rpc.RetryInfo` type URL.
        #[derive(Clone, PartialEq, prost::Message)]
        struct MockRetryInfo {
            #[prost(message, optional, tag = "1")]
            retry_delay: Option<prost_types::Duration>,
        }

        let retry_info = MockRetryInfo {
            retry_delay: Some(prost_types::Duration {
                seconds: retry_delay.as_secs() as i64,
                nanos: retry_delay.subsec_nanos() as i32,
            }),
        };

        let mut retry_buf = vec![];
        retry_info.encode(&mut retry_buf).unwrap();

        // Wrap the encoded retry info in an `Any` inside a `google.rpc.Status`.
        let status = spanner_grpc_mock::google::rpc::Status {
            code: gaxi::grpc::tonic::Code::Aborted as i32,
            message: "test transaction aborted".to_string(),
            details: vec![prost_types::Any {
                type_url: "type.googleapis.com/google.rpc.RetryInfo".to_string(),
                value: retry_buf,
            }],
        };

        let mut buf = vec![];
        status.encode(&mut buf).unwrap();

        // The encoded status becomes the details payload of the tonic status.
        gaxi::grpc::tonic::Status::with_details(
            gaxi::grpc::tonic::Code::Aborted,
            "test transaction aborted",
            bytes::Bytes::from(buf),
        )
    }

    // Compile-time check of the traits implemented by the basic policy.
    #[test]
    fn auto_traits() {
        static_assertions::assert_impl_all!(
            BasicTransactionRetryPolicy: Send,
            Sync,
            Unpin,
            Clone,
            std::fmt::Debug,
            Default,
            TransactionRetryPolicy,
        );
    }

    // The getters return what the builder methods stored.
    #[test]
    fn basic_retry_policy_getters() {
        let policy = BasicTransactionRetryPolicy::new()
            .with_max_attempts(3)
            .with_total_timeout(Duration::from_secs(10));
        assert_eq!(policy.max_attempts(), 3);
        assert_eq!(policy.total_timeout(), Duration::from_secs(10));
    }

    // A closure that succeeds immediately has its value returned as-is.
    #[tokio::test]
    async fn retry_aborted_success_first_try() {
        let policy = BasicTransactionRetryPolicy::default();
        let res = retry_aborted(&policy, || async { Ok::<i32, Error>(42) }).await;
        assert_eq!(res.expect("Transaction should succeed cleanly"), 42);
    }

    // An error whose code is not `Aborted` is returned to the caller.
    #[tokio::test]
    async fn retry_aborted_not_aborted_error() {
        let policy = BasicTransactionRetryPolicy::default();
        let res = retry_aborted(&policy, || async {
            let status = Status::default()
                .set_code(Code::Unavailable)
                .set_message("server unavailable");
            Err::<i32, Error>(Error::service(status))
        })
        .await;

        let err = res.unwrap_err();
        assert_eq!(
            err.status().expect("Error should contain a status").code,
            Code::Unavailable
        );
    }

    // With `max_attempts = 2` and no timeout, the closure runs exactly twice.
    #[tokio::test(start_paused = true)]
    async fn retry_aborted_max_attempts_exceeded() {
        let policy = BasicTransactionRetryPolicy::new()
            .with_max_attempts(2)
            .with_total_timeout(Duration::from_secs(0));
        let attempts = Arc::new(AtomicU32::new(0));

        let res = retry_aborted(&policy, || {
            let attempts = attempts.clone();
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<i32, Error>(create_aborted_error(None))
            }
        })
        .await;

        assert!(res.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 2); // 1 initial + 1 retry
    }

    // The first attempt fails with a 1ns retry delay attached; the second
    // succeeds. At least that delay must have elapsed on the tokio clock.
    #[tokio::test(start_paused = true)]
    async fn retry_aborted_with_retry_info() {
        let policy = BasicTransactionRetryPolicy::default();
        let attempts = Arc::new(AtomicU32::new(0));

        let start = tokio::time::Instant::now();
        let res = retry_aborted(&policy, || {
            let attempts = attempts.clone();
            async move {
                let current = attempts.fetch_add(1, Ordering::SeqCst);
                if current == 0 {
                    Err::<i32, Error>(create_aborted_error(Some(Duration::from_nanos(1))))
                } else {
                    Ok::<i32, Error>(100)
                }
            }
        })
        .await;
        let elapsed = start.elapsed();

        assert_eq!(res.expect("Transaction should succeed after 1 retry"), 100);
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
        assert!(
            elapsed >= Duration::from_nanos(1),
            "Expected elapsed time to be at least 1ns, but was {:?}",
            elapsed
        );
    }

    // The first attempt fails without a retry delay, so the default backoff
    // supplies the sleep duration; the second attempt succeeds.
    #[tokio::test(start_paused = true)]
    async fn retry_aborted_with_default_backoff() {
        let policy = BasicTransactionRetryPolicy::default();
        let attempts = Arc::new(AtomicU32::new(0));

        let res = retry_aborted(&policy, || {
            let attempts = attempts.clone();
            async move {
                let current = attempts.fetch_add(1, Ordering::SeqCst);
                if current == 0 {
                    Err::<i32, Error>(create_aborted_error(None))
                } else {
                    Ok::<i32, Error>(100)
                }
            }
        })
        .await;

        assert_eq!(
            res.expect("Transaction should succeed using default backoff"),
            100
        );
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    // No attempt limit, 1s total timeout, every attempt fails with a 600ms
    // retry delay. Assuming the paused clock only advances through the retry
    // sleeps, the policy sees 0s, 600ms and 1.2s on the three failures.
    #[tokio::test(start_paused = true)]
    async fn retry_aborted_total_timeout_exceeded() {
        let policy = BasicTransactionRetryPolicy::new()
            .with_max_attempts(0)
            .with_total_timeout(Duration::from_secs(1));
        let attempts = Arc::new(AtomicU32::new(0));

        let res = retry_aborted(&policy, || {
            let attempts = attempts.clone();
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                // Return a retry delay of 600ms so that after 2 attempts (1.2s total delay),
                // we should definitely exceed the 1 second timeout for the 3rd fail check.
                Err::<i32, Error>(create_aborted_error(Some(Duration::from_millis(600))))
            }
        })
        .await;

        assert!(res.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 3); // 2 delays of 600ms = 1.2s elapsed when the 3rd failure is checked, which exceeds the 1s timeout
    }

    // An error without a status is not treated as aborted.
    #[test]
    fn is_aborted_non_status_error() {
        let err = Error::deser("test internal error");
        assert!(!is_aborted(&err));
    }

    // An error without a status yields no retry delay.
    #[test]
    fn extract_retry_delay_no_status() {
        let err = Error::deser("test internal error");
        assert_eq!(extract_retry_delay(&err), None);
    }

    // A status whose only detail is not a `RetryInfo` yields no retry delay.
    #[test]
    fn extract_retry_delay_no_retry_info() {
        let mut status = Status::default().set_code(Code::Aborted);
        // Put a generic empty 'Any' which is not a RetryInfo
        status = status.set_details(vec![Any::default()]);
        let err = Error::service(status);
        assert_eq!(extract_retry_delay(&err), None);
    }

    // A `RetryInfo` detail without a delay yields no retry delay.
    #[test]
    fn extract_retry_delay_empty_retry_info() {
        let mut status = Status::default().set_code(Code::Aborted);
        let retry_info = RetryInfo::default(); // no retry_delay set
        status = status.set_details(vec![Any::from_msg(&retry_info).unwrap()]);
        let err = Error::service(status);
        assert_eq!(extract_retry_delay(&err), None);
    }

    // A negative delay is expected to be rejected rather than returned.
    #[test]
    fn extract_retry_delay_invalid_delay() {
        let mut status = Status::default().set_code(Code::Aborted);
        let retry_info = RetryInfo::default().set_retry_delay(wkt::Duration::clamp(
            -10, // Invalid negative duration
            0,
        ));
        status = status.set_details(vec![Any::from_msg(&retry_info).unwrap()]);
        let err = Error::service(status);
        assert_eq!(extract_retry_delay(&err), None);
    }

    // A user-defined policy drives the retry loop: it continues while fewer
    // than three attempts have been made and gives up on the third.
    #[tokio::test(start_paused = true)]
    async fn retry_aborted_with_custom_policy() {
        struct CustomPolicy;
        impl TransactionRetryPolicy for CustomPolicy {
            fn on_abort(&self, error: Error, attempts: u32, _elapsed: Duration) -> RetryResult {
                if attempts < 3 {
                    RetryResult::Continue(error)
                } else {
                    RetryResult::Exhausted(error)
                }
            }
        }

        let policy = CustomPolicy;
        let attempts = Arc::new(AtomicU32::new(0));

        let res = retry_aborted(&policy, || {
            let attempts = attempts.clone();
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<i32, Error>(create_aborted_error(None))
            }
        })
        .await;

        assert!(res.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 3); // Continue on attempts 1 and 2, Exhausted on attempt 3
    }
}