Skip to main content

google_cloud_spanner/
transaction_retry_policy.rs

// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use crate::Error;
16use google_cloud_gax::backoff_policy::BackoffPolicy;
17use google_cloud_gax::error::rpc::StatusDetails;
18use google_cloud_gax::exponential_backoff::{ExponentialBackoff, ExponentialBackoffBuilder};
19use google_cloud_gax::retry_result::RetryResult;
20use google_cloud_gax::retry_state::RetryState;
21use std::time::Duration;
22
23/// Defines a policy for retrying a transaction when it is aborted by Spanner.
24///
25/// Spanner can abort any read/write transaction due to lock conflicts or other
26/// transient issues. In such cases, the client should retry the complete
27/// transaction.
28pub trait TransactionRetryPolicy: Send + Sync {
29    /// Evaluates whether an aborted transaction should be retried.
30    ///
31    /// * `error` the `Aborted` error that was raised. Note that this policy
32    ///   takes ownership of the error and returns it embedded in the retry result.
33    /// * `attempts` is the number of attempts already made (1 for the first failure).
34    /// * `elapsed` is the total time spent executing the transaction so far.
35    fn on_abort(&self, error: Error, attempts: u32, elapsed: Duration) -> RetryResult;
36}
37
/// A [`TransactionRetryPolicy`] that limits the retries of an aborted
/// transaction by number of attempts and by total elapsed time.
///
/// Either limit is switched off by setting it to zero.
#[derive(Clone, Debug)]
pub struct BasicTransactionRetryPolicy {
    /// Upper bound on the number of attempts; `0` means no bound.
    max_attempts: u32,
    /// Upper bound on the total time spent retrying; a zero duration means no
    /// bound.
    total_timeout: Duration,
}
47
48impl BasicTransactionRetryPolicy {
49    /// Creates a new basic transaction retry policy with no limits.
50    pub fn new() -> Self {
51        Self::default()
52    }
53
54    /// Sets the maximum number of attempts to make.
55    pub fn with_max_attempts(mut self, max_attempts: u32) -> Self {
56        self.max_attempts = max_attempts;
57        self
58    }
59
60    /// Sets the total maximum time to spend retrying.
61    pub fn with_total_timeout(mut self, total_timeout: Duration) -> Self {
62        self.total_timeout = total_timeout;
63        self
64    }
65
66    /// Returns the maximum number of attempts configured.
67    pub fn max_attempts(&self) -> u32 {
68        self.max_attempts
69    }
70
71    /// Returns the total maximum time configured.
72    pub fn total_timeout(&self) -> Duration {
73        self.total_timeout
74    }
75}
76
77impl Default for BasicTransactionRetryPolicy {
78    fn default() -> Self {
79        Self {
80            max_attempts: 0,
81            total_timeout: Duration::from_secs(0),
82        }
83    }
84}
85
86impl TransactionRetryPolicy for BasicTransactionRetryPolicy {
87    fn on_abort(&self, error: Error, attempts: u32, elapsed: Duration) -> RetryResult {
88        if self.max_attempts > 0 && attempts >= self.max_attempts {
89            return RetryResult::Exhausted(error);
90        }
91        if self.total_timeout > Duration::from_secs(0) && elapsed > self.total_timeout {
92            return RetryResult::Exhausted(error);
93        }
94        RetryResult::Continue(error)
95    }
96}
97
98/// Helper method to execute an asynchronous closure, retrying it if the
99/// transaction is aborted by the server.
100///
101/// This is used for operations like Partitioned DML transactions in Cloud Spanner, where
102/// the server may abort the transaction due to transient issues, indicating that the client
103/// should re-attempt the entire operation.
104pub(crate) async fn retry_aborted<T, F, Fut>(
105    policy: &dyn TransactionRetryPolicy,
106    mut f: F,
107    is_emulator: bool,
108) -> crate::Result<T>
109where
110    F: FnMut() -> Fut,
111    Fut: std::future::Future<Output = crate::Result<T>>,
112{
113    let start_time = tokio::time::Instant::now();
114    let mut attempts: u32 = 0;
115
116    // This backoff is only used if Spanner does not return a retry delay.
117    let backoff = default_retry_backoff();
118
119    loop {
120        attempts += 1;
121        match f().await {
122            Ok(v) => return Ok(v),
123            Err(e) => {
124                backoff_if_aborted(
125                    e,
126                    attempts,
127                    start_time.elapsed(),
128                    policy,
129                    &backoff,
130                    is_emulator,
131                )
132                .await?;
133            }
134        }
135    }
136}
137
138pub(crate) fn is_aborted(err: &crate::Error) -> bool {
139    err.status()
140        .is_some_and(|s| s.code == google_cloud_gax::error::rpc::Code::Aborted)
141}
142
143pub(crate) fn extract_retry_delay(err: &crate::Error) -> Option<Duration> {
144    err.status()?.details.iter().find_map(|detail| {
145        let StatusDetails::RetryInfo(retry_info) = detail else {
146            return None;
147        };
148        (*retry_info.retry_delay.as_ref()?).try_into().ok()
149    })
150}
151
152pub(crate) fn default_retry_backoff() -> ExponentialBackoff {
153    ExponentialBackoffBuilder::new()
154        .with_initial_delay(Duration::from_millis(10))
155        .with_maximum_delay(Duration::from_secs(1))
156        .with_scaling(1.3)
157        .build()
158        .unwrap()
159}
160
161pub(crate) fn is_internal_emulator_error(err: &crate::Error) -> bool {
162    if let Some(status) = err.status() {
163        status.code == google_cloud_gax::error::rpc::Code::Internal
164            && status.message.contains("Schema generation")
165            && status
166                .message
167                .contains("was not registered with the Action Manager")
168    } else {
169        false
170    }
171}
172
173/// Evaluates the error against the retry policy and delays execution if a retry is warranted.
174/// Returns Ok(()) after sleeping if a retry should occur, otherwise returns Err with the original error.
175pub(crate) async fn backoff_if_aborted(
176    err: crate::Error,
177    attempts: u32,
178    elapsed: Duration,
179    policy: &dyn TransactionRetryPolicy,
180    backoff: &ExponentialBackoff,
181    is_emulator: bool,
182) -> crate::Result<()> {
183    let should_retry = if is_aborted(&err) {
184        true
185    } else if is_emulator {
186        is_internal_emulator_error(&err)
187    } else {
188        false
189    };
190
191    if !should_retry {
192        return Err(err);
193    }
194
195    let e = match policy.on_abort(err, attempts, elapsed) {
196        RetryResult::Continue(err) => err,
197        RetryResult::Exhausted(err) | RetryResult::Permanent(err) => return Err(err),
198    };
199
200    let sleep_duration = extract_retry_delay(&e)
201        .unwrap_or_else(|| backoff.on_failure(&RetryState::new(true).set_attempt_count(attempts)));
202
203    tokio::time::sleep(sleep_duration).await;
204    Ok(())
205}
206
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::Error;
    use google_cloud_gax::error::rpc::{Code, Status};
    use google_cloud_rpc::model::RetryInfo;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use wkt::Any;

    /// Builds an `Aborted` service error, optionally carrying a `RetryInfo`
    /// detail with the given delay.
    fn create_aborted_error(retry_delay: Option<Duration>) -> Error {
        let mut status = Status::default()
            .set_code(Code::Aborted)
            .set_message("aborted");

        if let Some(delay) = retry_delay {
            let retry_info = RetryInfo::default().set_retry_delay(wkt::Duration::clamp(
                delay.as_secs() as i64,
                delay.subsec_nanos() as i32,
            ));
            status = status.set_details(vec![Any::from_msg(&retry_info).unwrap()]);
        }

        Error::service(status)
    }

    /// Builds a tonic `Aborted` status. Its binary details hold an encoded
    /// `google.rpc.Status` with a single `RetryInfo` carrying `retry_delay`.
    ///
    /// This is `pub(crate)` so other test modules in the crate can reuse it.
    pub(crate) fn create_aborted_status(
        retry_delay: std::time::Duration,
    ) -> gaxi::grpc::tonic::Status {
        use prost::Message;

        // Mirrors the wire layout of `google.rpc.RetryInfo` (field 1 = retry_delay).
        #[derive(Clone, PartialEq, prost::Message)]
        struct MockRetryInfo {
            #[prost(message, optional, tag = "1")]
            retry_delay: Option<prost_types::Duration>,
        }

        let retry_info = MockRetryInfo {
            retry_delay: Some(prost_types::Duration {
                seconds: retry_delay.as_secs() as i64,
                nanos: retry_delay.subsec_nanos() as i32,
            }),
        };

        let mut retry_buf = vec![];
        retry_info.encode(&mut retry_buf).unwrap();

        let status = spanner_grpc_mock::google::rpc::Status {
            code: gaxi::grpc::tonic::Code::Aborted as i32,
            message: "test transaction aborted".to_string(),
            details: vec![prost_types::Any {
                type_url: "type.googleapis.com/google.rpc.RetryInfo".to_string(),
                value: retry_buf,
            }],
        };

        let mut buf = vec![];
        status.encode(&mut buf).unwrap();

        gaxi::grpc::tonic::Status::with_details(
            gaxi::grpc::tonic::Code::Aborted,
            "test transaction aborted",
            bytes::Bytes::from(buf),
        )
    }

    #[test]
    fn auto_traits() {
        static_assertions::assert_impl_all!(
            BasicTransactionRetryPolicy: Send,
            Sync,
            Unpin,
            Clone,
            std::fmt::Debug,
            Default,
            TransactionRetryPolicy,
        );
    }

    #[test]
    fn basic_retry_policy_getters() {
        let policy = BasicTransactionRetryPolicy::new()
            .with_max_attempts(3)
            .with_total_timeout(Duration::from_secs(10));
        assert_eq!(policy.max_attempts(), 3);
        assert_eq!(policy.total_timeout(), Duration::from_secs(10));
    }

    #[test]
    fn basic_retry_policy_on_abort_limits() {
        // Zero limits mean "unlimited": even extreme values keep retrying.
        let unlimited = BasicTransactionRetryPolicy::default();
        assert!(matches!(
            unlimited.on_abort(
                create_aborted_error(None),
                u32::MAX,
                Duration::from_secs(3600)
            ),
            RetryResult::Continue(_)
        ));

        let policy = BasicTransactionRetryPolicy::new()
            .with_max_attempts(3)
            .with_total_timeout(Duration::from_secs(10));
        // Below the attempt limit and exactly at (not over) the timeout.
        assert!(matches!(
            policy.on_abort(create_aborted_error(None), 2, Duration::from_secs(10)),
            RetryResult::Continue(_)
        ));
        // Attempt limit reached.
        assert!(matches!(
            policy.on_abort(create_aborted_error(None), 3, Duration::ZERO),
            RetryResult::Exhausted(_)
        ));
        // Timeout exceeded.
        assert!(matches!(
            policy.on_abort(create_aborted_error(None), 1, Duration::from_secs(11)),
            RetryResult::Exhausted(_)
        ));
    }

    #[tokio::test]
    async fn retry_aborted_success_first_try() {
        let policy = BasicTransactionRetryPolicy::default();
        let res = retry_aborted(
            &policy,
            || async { Ok::<i32, Error>(42) },
            /* is_emulator = */ false,
        )
        .await;
        assert_eq!(res.expect("Transaction should succeed cleanly"), 42);
    }

    #[tokio::test]
    async fn retry_aborted_not_aborted_error() {
        let policy = BasicTransactionRetryPolicy::default();
        let res = retry_aborted(
            &policy,
            || async {
                let status = Status::default()
                    .set_code(Code::Unavailable)
                    .set_message("server unavailable");
                Err::<i32, Error>(Error::service(status))
            },
            /* is_emulator = */ false,
        )
        .await;

        let err = res.unwrap_err();
        assert_eq!(
            err.status().expect("Error should contain a status").code,
            Code::Unavailable
        );
    }

    #[tokio::test(start_paused = true)]
    async fn retry_aborted_max_attempts_exceeded() {
        let policy = BasicTransactionRetryPolicy::new()
            .with_max_attempts(2)
            .with_total_timeout(Duration::from_secs(0));
        let attempts = Arc::new(AtomicU32::new(0));

        let res = retry_aborted(
            &policy,
            || {
                let attempts = attempts.clone();
                async move {
                    attempts.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, Error>(create_aborted_error(None))
                }
            },
            /* is_emulator = */ false,
        )
        .await;

        let err = res.expect_err("Transaction should give up after max attempts");
        assert!(is_aborted(&err));
        assert_eq!(attempts.load(Ordering::SeqCst), 2); // 1 initial + 1 retry
    }

    #[tokio::test(start_paused = true)]
    async fn retry_aborted_with_retry_info() {
        // Longer than the 1s maximum of `default_retry_backoff`. The elapsed time
        // can only reach it if the server-provided delay was honored.
        const RETRY_DELAY: Duration = Duration::from_secs(5);
        let policy = BasicTransactionRetryPolicy::default();
        let attempts = Arc::new(AtomicU32::new(0));

        let start = tokio::time::Instant::now();
        let res = retry_aborted(
            &policy,
            || {
                let attempts = attempts.clone();
                async move {
                    let current = attempts.fetch_add(1, Ordering::SeqCst);
                    if current == 0 {
                        Err::<i32, Error>(create_aborted_error(Some(RETRY_DELAY)))
                    } else {
                        Ok::<i32, Error>(100)
                    }
                }
            },
            /* is_emulator = */ false,
        )
        .await;
        let elapsed = start.elapsed();

        assert_eq!(res.expect("Transaction should succeed after 1 retry"), 100);
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
        assert!(
            elapsed >= RETRY_DELAY,
            "Expected elapsed time to be at least {:?}, but was {:?}",
            RETRY_DELAY,
            elapsed
        );
    }

    #[tokio::test(start_paused = true)]
    async fn retry_aborted_with_default_backoff() {
        let policy = BasicTransactionRetryPolicy::default();
        let attempts = Arc::new(AtomicU32::new(0));

        let res = retry_aborted(
            &policy,
            || {
                let attempts = attempts.clone();
                async move {
                    let current = attempts.fetch_add(1, Ordering::SeqCst);
                    if current == 0 {
                        Err::<i32, Error>(create_aborted_error(None))
                    } else {
                        Ok::<i32, Error>(100)
                    }
                }
            },
            /* is_emulator = */ false,
        )
        .await;

        assert_eq!(
            res.expect("Transaction should succeed using default backoff"),
            100
        );
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    #[tokio::test(start_paused = true)]
    async fn retry_aborted_total_timeout_exceeded() {
        let policy = BasicTransactionRetryPolicy::new()
            .with_max_attempts(0)
            .with_total_timeout(Duration::from_secs(1));
        let attempts = Arc::new(AtomicU32::new(0));

        let res = retry_aborted(
            &policy,
            || {
                let attempts = attempts.clone();
                async move {
                    attempts.fetch_add(1, Ordering::SeqCst);
                    // Every failure requests a 600ms delay, so the failures are observed
                    // at ~0s, ~0.6s and ~1.2s; only the third exceeds the 1s timeout.
                    Err::<i32, Error>(create_aborted_error(Some(Duration::from_millis(600))))
                }
            },
            /* is_emulator = */ false,
        )
        .await;

        let err = res.expect_err("Transaction should give up after the total timeout");
        assert!(is_aborted(&err));
        // 1 initial attempt + 2 retries: the 3rd failure happens ~1.2s in (> 1s).
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn is_aborted_non_status_error() {
        let err = Error::deser("test internal error");
        assert!(!is_aborted(&err));
    }

    #[test]
    fn extract_retry_delay_no_status() {
        let err = Error::deser("test internal error");
        assert_eq!(extract_retry_delay(&err), None);
    }

    #[test]
    fn extract_retry_delay_no_retry_info() {
        let mut status = Status::default().set_code(Code::Aborted);
        // Put a generic empty 'Any' which is not a RetryInfo
        status = status.set_details(vec![Any::default()]);
        let err = Error::service(status);
        assert_eq!(extract_retry_delay(&err), None);
    }

    #[test]
    fn extract_retry_delay_empty_retry_info() {
        let mut status = Status::default().set_code(Code::Aborted);
        let retry_info = RetryInfo::default(); // no retry_delay set
        status = status.set_details(vec![Any::from_msg(&retry_info).unwrap()]);
        let err = Error::service(status);
        assert_eq!(extract_retry_delay(&err), None);
    }

    #[test]
    fn extract_retry_delay_invalid_delay() {
        let mut status = Status::default().set_code(Code::Aborted);
        let retry_info = RetryInfo::default().set_retry_delay(wkt::Duration::clamp(
            -10, // Invalid negative duration
            0,
        ));
        status = status.set_details(vec![Any::from_msg(&retry_info).unwrap()]);
        let err = Error::service(status);
        assert_eq!(extract_retry_delay(&err), None);
    }

    #[tokio::test(start_paused = true)]
    async fn retry_aborted_with_custom_policy() {
        struct CustomPolicy;
        impl TransactionRetryPolicy for CustomPolicy {
            fn on_abort(&self, error: Error, attempts: u32, _elapsed: Duration) -> RetryResult {
                if attempts < 3 {
                    RetryResult::Continue(error)
                } else {
                    RetryResult::Exhausted(error)
                }
            }
        }

        let policy = CustomPolicy;
        let attempts = Arc::new(AtomicU32::new(0));

        let res = retry_aborted(
            &policy,
            || {
                let attempts = attempts.clone();
                async move {
                    attempts.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, Error>(create_aborted_error(None))
                }
            },
            /* is_emulator = */ false,
        )
        .await;

        let err = res.expect_err("Custom policy should stop after 3 attempts");
        assert!(is_aborted(&err));
        // 1 initial attempt + 2 retries; the policy gives up on the 3rd failure.
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test(start_paused = true)]
    async fn retry_aborted_emulator_internal_schema_error() {
        let policy = BasicTransactionRetryPolicy::default();
        let attempts = Arc::new(AtomicU32::new(0));

        let make_schema_error = || {
            let status = Status::default().set_code(Code::Internal).set_message(
                "INTERNAL: Schema generation 0 was not registered with the Action Manager",
            );
            Error::service(status)
        };

        // If not running on emulator, it should fail immediately (no retry)
        let res = retry_aborted(
            &policy,
            || {
                let attempts = attempts.clone();
                let err = make_schema_error();
                async move {
                    attempts.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, Error>(err)
                }
            },
            /* is_emulator = */ false,
        )
        .await;
        let err = res.expect_err("Schema error must not be retried outside the emulator");
        assert_eq!(
            err.status().expect("Error should contain a status").code,
            Code::Internal
        );
        assert_eq!(attempts.load(Ordering::SeqCst), 1);

        // If running on the emulator, it should retry just like aborted error
        attempts.store(0, Ordering::SeqCst);
        let res = retry_aborted(
            &policy,
            || {
                let attempts = attempts.clone();
                let err = make_schema_error();
                async move {
                    let current = attempts.fetch_add(1, Ordering::SeqCst);
                    if current == 0 {
                        Err::<i32, Error>(err)
                    } else {
                        Ok::<i32, Error>(200)
                    }
                }
            },
            /* is_emulator = */ true,
        )
        .await;
        assert_eq!(res.expect("should succeed after retry"), 200);
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }
}