1use crate::Error;
16use google_cloud_gax::backoff_policy::BackoffPolicy;
17use google_cloud_gax::error::rpc::StatusDetails;
18use google_cloud_gax::exponential_backoff::{ExponentialBackoff, ExponentialBackoffBuilder};
19use google_cloud_gax::retry_result::RetryResult;
20use google_cloud_gax::retry_state::RetryState;
21use std::time::Duration;
22
23pub trait TransactionRetryPolicy: Send + Sync {
29 fn on_abort(&self, error: Error, attempts: u32, elapsed: Duration) -> RetryResult;
36}
37
/// A retry policy bounded by an attempt count and a total elapsed time.
///
/// A zero value for either limit disables that limit. The default policy has
/// both limits disabled.
#[derive(Clone, Debug)]
pub struct BasicTransactionRetryPolicy {
    /// Maximum number of attempts; `0` means no attempt limit.
    max_attempts: u32,
    /// Maximum total elapsed time; a zero duration means no time limit.
    total_timeout: Duration,
}

impl BasicTransactionRetryPolicy {
    /// Creates a policy with both limits disabled.
    pub fn new() -> Self {
        Self {
            max_attempts: 0,
            total_timeout: Duration::ZERO,
        }
    }

    /// Returns the policy with its attempt limit replaced by `max_attempts`.
    pub fn with_max_attempts(self, max_attempts: u32) -> Self {
        Self {
            max_attempts,
            ..self
        }
    }

    /// Returns the policy with its time limit replaced by `total_timeout`.
    pub fn with_total_timeout(self, total_timeout: Duration) -> Self {
        Self {
            total_timeout,
            ..self
        }
    }

    /// The configured attempt limit (`0` when unlimited).
    pub fn max_attempts(&self) -> u32 {
        self.max_attempts
    }

    /// The configured time limit (zero when unlimited).
    pub fn total_timeout(&self) -> Duration {
        self.total_timeout
    }
}

impl Default for BasicTransactionRetryPolicy {
    /// Equivalent to [`BasicTransactionRetryPolicy::new`].
    fn default() -> Self {
        Self::new()
    }
}
85
86impl TransactionRetryPolicy for BasicTransactionRetryPolicy {
87 fn on_abort(&self, error: Error, attempts: u32, elapsed: Duration) -> RetryResult {
88 if self.max_attempts > 0 && attempts >= self.max_attempts {
89 return RetryResult::Exhausted(error);
90 }
91 if self.total_timeout > Duration::from_secs(0) && elapsed > self.total_timeout {
92 return RetryResult::Exhausted(error);
93 }
94 RetryResult::Continue(error)
95 }
96}
97
98pub(crate) async fn retry_aborted<T, F, Fut>(
105 policy: &dyn TransactionRetryPolicy,
106 mut f: F,
107 is_emulator: bool,
108) -> crate::Result<T>
109where
110 F: FnMut() -> Fut,
111 Fut: std::future::Future<Output = crate::Result<T>>,
112{
113 let start_time = tokio::time::Instant::now();
114 let mut attempts: u32 = 0;
115
116 let backoff = default_retry_backoff();
118
119 loop {
120 attempts += 1;
121 match f().await {
122 Ok(v) => return Ok(v),
123 Err(e) => {
124 backoff_if_aborted(
125 e,
126 attempts,
127 start_time.elapsed(),
128 policy,
129 &backoff,
130 is_emulator,
131 )
132 .await?;
133 }
134 }
135 }
136}
137
138pub(crate) fn is_aborted(err: &crate::Error) -> bool {
139 err.status()
140 .is_some_and(|s| s.code == google_cloud_gax::error::rpc::Code::Aborted)
141}
142
143pub(crate) fn extract_retry_delay(err: &crate::Error) -> Option<Duration> {
144 err.status()?.details.iter().find_map(|detail| {
145 let StatusDetails::RetryInfo(retry_info) = detail else {
146 return None;
147 };
148 (*retry_info.retry_delay.as_ref()?).try_into().ok()
149 })
150}
151
152pub(crate) fn default_retry_backoff() -> ExponentialBackoff {
153 ExponentialBackoffBuilder::new()
154 .with_initial_delay(Duration::from_millis(10))
155 .with_maximum_delay(Duration::from_secs(1))
156 .with_scaling(1.3)
157 .build()
158 .unwrap()
159}
160
161pub(crate) fn is_internal_emulator_error(err: &crate::Error) -> bool {
162 if let Some(status) = err.status() {
163 status.code == google_cloud_gax::error::rpc::Code::Internal
164 && status.message.contains("Schema generation")
165 && status
166 .message
167 .contains("was not registered with the Action Manager")
168 } else {
169 false
170 }
171}
172
173pub(crate) async fn backoff_if_aborted(
176 err: crate::Error,
177 attempts: u32,
178 elapsed: Duration,
179 policy: &dyn TransactionRetryPolicy,
180 backoff: &ExponentialBackoff,
181 is_emulator: bool,
182) -> crate::Result<()> {
183 let should_retry = if is_aborted(&err) {
184 true
185 } else if is_emulator {
186 is_internal_emulator_error(&err)
187 } else {
188 false
189 };
190
191 if !should_retry {
192 return Err(err);
193 }
194
195 let e = match policy.on_abort(err, attempts, elapsed) {
196 RetryResult::Continue(err) => err,
197 RetryResult::Exhausted(err) | RetryResult::Permanent(err) => return Err(err),
198 };
199
200 let sleep_duration = extract_retry_delay(&e)
201 .unwrap_or_else(|| backoff.on_failure(&RetryState::new(true).set_attempt_count(attempts)));
202
203 tokio::time::sleep(sleep_duration).await;
204 Ok(())
205}
206
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::Error;
    use google_cloud_gax::error::rpc::{Code, Status};
    use google_cloud_rpc::model::RetryInfo;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use wkt::Any;

    /// Builds an `Aborted` service error, optionally carrying a `RetryInfo`
    /// detail with the given delay.
    fn create_aborted_error(retry_delay: Option<Duration>) -> Error {
        let base = Status::default()
            .set_code(Code::Aborted)
            .set_message("aborted");

        let status = match retry_delay {
            None => base,
            Some(delay) => {
                let proto_delay =
                    wkt::Duration::clamp(delay.as_secs() as i64, delay.subsec_nanos() as i32);
                let retry_info = RetryInfo::default().set_retry_delay(proto_delay);
                base.set_details(vec![Any::from_msg(&retry_info).unwrap()])
            }
        };

        Error::service(status)
    }

    /// Builds a gRPC `Aborted` status whose encoded details carry a
    /// `google.rpc.RetryInfo` message with the given delay.
    pub(crate) fn create_aborted_status(
        retry_delay: std::time::Duration,
    ) -> gaxi::grpc::tonic::Status {
        use gaxi::grpc::tonic::{Code as GrpcCode, Status as GrpcStatus};
        use prost::Message;

        const MESSAGE: &str = "test transaction aborted";

        // Local stand-in with the same wire layout used here for RetryInfo:
        // a single optional Duration at tag 1.
        #[derive(Clone, PartialEq, prost::Message)]
        struct MockRetryInfo {
            #[prost(message, optional, tag = "1")]
            retry_delay: Option<prost_types::Duration>,
        }

        let mock_retry_info = MockRetryInfo {
            retry_delay: Some(prost_types::Duration {
                seconds: retry_delay.as_secs() as i64,
                nanos: retry_delay.subsec_nanos() as i32,
            }),
        };
        let mut encoded_retry_info = vec![];
        mock_retry_info.encode(&mut encoded_retry_info).unwrap();

        let rpc_status = spanner_grpc_mock::google::rpc::Status {
            code: GrpcCode::Aborted as i32,
            message: MESSAGE.to_string(),
            details: vec![prost_types::Any {
                type_url: "type.googleapis.com/google.rpc.RetryInfo".to_string(),
                value: encoded_retry_info,
            }],
        };
        let mut encoded_status = vec![];
        rpc_status.encode(&mut encoded_status).unwrap();

        GrpcStatus::with_details(
            GrpcCode::Aborted,
            MESSAGE,
            bytes::Bytes::from(encoded_status),
        )
    }

    #[test]
    fn auto_traits() {
        static_assertions::assert_impl_all!(
            BasicTransactionRetryPolicy: Send,
            Sync,
            Unpin,
            Clone,
            std::fmt::Debug,
            Default,
            TransactionRetryPolicy,
        );
    }

    #[test]
    fn basic_retry_policy_getters() {
        let timeout = Duration::from_secs(10);
        let policy = BasicTransactionRetryPolicy::new()
            .with_max_attempts(3)
            .with_total_timeout(timeout);

        assert_eq!(policy.max_attempts(), 3);
        assert_eq!(policy.total_timeout(), timeout);
    }

    #[tokio::test]
    async fn retry_aborted_success_first_try() {
        let policy = BasicTransactionRetryPolicy::default();
        let operation = || async { Ok::<i32, Error>(42) };

        let outcome = retry_aborted(&policy, operation, false).await;

        assert_eq!(outcome.expect("Transaction should succeed cleanly"), 42);
    }

    #[tokio::test]
    async fn retry_aborted_not_aborted_error() {
        let policy = BasicTransactionRetryPolicy::default();
        let unavailable = || async {
            let status = Status::default()
                .set_code(Code::Unavailable)
                .set_message("server unavailable");
            Err::<i32, Error>(Error::service(status))
        };

        let err = retry_aborted(&policy, unavailable, false)
            .await
            .unwrap_err();

        // A non-aborted error must come back unchanged.
        assert_eq!(
            err.status().expect("Error should contain a status").code,
            Code::Unavailable
        );
    }

    #[tokio::test(start_paused = true)]
    async fn retry_aborted_max_attempts_exceeded() {
        let policy = BasicTransactionRetryPolicy::new()
            .with_max_attempts(2)
            .with_total_timeout(Duration::ZERO);
        let calls = Arc::new(AtomicU32::new(0));

        let always_aborted = || {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err::<i32, Error>(create_aborted_error(None))
            }
        };
        let outcome = retry_aborted(&policy, always_aborted, false).await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test(start_paused = true)]
    async fn retry_aborted_with_retry_info() {
        let policy = BasicTransactionRetryPolicy::default();
        let calls = Arc::new(AtomicU32::new(0));

        let started = tokio::time::Instant::now();
        // Fails once with a 1ns server-provided delay, then succeeds.
        let fail_once = || {
            let calls = Arc::clone(&calls);
            async move {
                match calls.fetch_add(1, Ordering::SeqCst) {
                    0 => Err::<i32, Error>(create_aborted_error(Some(Duration::from_nanos(1)))),
                    _ => Ok::<i32, Error>(100),
                }
            }
        };
        let outcome = retry_aborted(&policy, fail_once, false).await;
        let elapsed = started.elapsed();

        assert_eq!(
            outcome.expect("Transaction should succeed after 1 retry"),
            100
        );
        assert_eq!(calls.load(Ordering::SeqCst), 2);
        assert!(
            elapsed >= Duration::from_nanos(1),
            "Expected elapsed time to be at least 1ns, but was {:?}",
            elapsed
        );
    }

    #[tokio::test(start_paused = true)]
    async fn retry_aborted_with_default_backoff() {
        let policy = BasicTransactionRetryPolicy::default();
        let calls = Arc::new(AtomicU32::new(0));

        // No RetryInfo on the error, so the default backoff supplies the delay.
        let fail_once = || {
            let calls = Arc::clone(&calls);
            async move {
                match calls.fetch_add(1, Ordering::SeqCst) {
                    0 => Err::<i32, Error>(create_aborted_error(None)),
                    _ => Ok::<i32, Error>(100),
                }
            }
        };
        let outcome = retry_aborted(&policy, fail_once, false).await;

        assert_eq!(
            outcome.expect("Transaction should succeed using default backoff"),
            100
        );
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test(start_paused = true)]
    async fn retry_aborted_total_timeout_exceeded() {
        let policy = BasicTransactionRetryPolicy::new()
            .with_max_attempts(0)
            .with_total_timeout(Duration::from_secs(1));
        let calls = Arc::new(AtomicU32::new(0));

        // Each failure asks for a 600ms delay, so the third attempt is the
        // first one to observe an elapsed time above the 1s limit.
        let always_aborted = || {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err::<i32, Error>(create_aborted_error(Some(Duration::from_millis(600))))
            }
        };
        let outcome = retry_aborted(&policy, always_aborted, false).await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn is_aborted_non_status_error() {
        let without_status = Error::deser("test internal error");
        assert!(!is_aborted(&without_status));
    }

    #[test]
    fn extract_retry_delay_no_status() {
        let without_status = Error::deser("test internal error");
        assert_eq!(extract_retry_delay(&without_status), None);
    }

    #[test]
    fn extract_retry_delay_no_retry_info() {
        let status = Status::default()
            .set_code(Code::Aborted)
            .set_details(vec![Any::default()]);

        let err = Error::service(status);
        assert_eq!(extract_retry_delay(&err), None);
    }

    #[test]
    fn extract_retry_delay_empty_retry_info() {
        // A RetryInfo detail is present, but its delay is unset.
        let retry_info = RetryInfo::default();
        let status = Status::default()
            .set_code(Code::Aborted)
            .set_details(vec![Any::from_msg(&retry_info).unwrap()]);

        let err = Error::service(status);
        assert_eq!(extract_retry_delay(&err), None);
    }

    #[test]
    fn extract_retry_delay_invalid_delay() {
        // A negative delay cannot be converted to `std::time::Duration`.
        let negative_delay = wkt::Duration::clamp(-10, 0);
        let retry_info = RetryInfo::default().set_retry_delay(negative_delay);
        let status = Status::default()
            .set_code(Code::Aborted)
            .set_details(vec![Any::from_msg(&retry_info).unwrap()]);

        let err = Error::service(status);
        assert_eq!(extract_retry_delay(&err), None);
    }

    #[tokio::test(start_paused = true)]
    async fn retry_aborted_with_custom_policy() {
        /// Allows retries until the third attempt has failed.
        struct CustomPolicy;
        impl TransactionRetryPolicy for CustomPolicy {
            fn on_abort(&self, error: Error, attempts: u32, _elapsed: Duration) -> RetryResult {
                match attempts {
                    0..=2 => RetryResult::Continue(error),
                    _ => RetryResult::Exhausted(error),
                }
            }
        }

        let policy = CustomPolicy;
        let calls = Arc::new(AtomicU32::new(0));

        let always_aborted = || {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err::<i32, Error>(create_aborted_error(None))
            }
        };
        let outcome = retry_aborted(&policy, always_aborted, false).await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test(start_paused = true)]
    async fn retry_aborted_emulator_internal_schema_error() {
        /// Builds the internal error that is retried only in emulator mode.
        fn schema_error() -> Error {
            let status = Status::default().set_code(Code::Internal).set_message(
                "INTERNAL: Schema generation 0 was not registered with the Action Manager",
            );
            Error::service(status)
        }

        let policy = BasicTransactionRetryPolicy::default();
        let calls = Arc::new(AtomicU32::new(0));

        // Outside emulator mode the error is returned after a single attempt.
        let always_fail = || {
            let calls = Arc::clone(&calls);
            let err = schema_error();
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err::<i32, Error>(err)
            }
        };
        let outcome = retry_aborted(&policy, always_fail, false).await;
        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // In emulator mode the same error is retried and the second attempt succeeds.
        calls.store(0, Ordering::SeqCst);
        let fail_once = || {
            let calls = Arc::clone(&calls);
            let err = schema_error();
            async move {
                match calls.fetch_add(1, Ordering::SeqCst) {
                    0 => Err::<i32, Error>(err),
                    _ => Ok::<i32, Error>(200),
                }
            }
        };
        let outcome = retry_aborted(&policy, fail_once, true).await;
        assert_eq!(outcome.expect("should succeed after retry"), 200);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}