1use crate::Error;
16use google_cloud_gax::backoff_policy::BackoffPolicy;
17use google_cloud_gax::error::rpc::StatusDetails;
18use google_cloud_gax::exponential_backoff::{ExponentialBackoff, ExponentialBackoffBuilder};
19use google_cloud_gax::retry_result::RetryResult;
20use google_cloud_gax::retry_state::RetryState;
21use std::time::Duration;
22
/// Decides whether a transaction attempt that failed with `ABORTED` should be
/// retried.
pub trait TransactionRetryPolicy: Send + Sync {
    /// Called after an attempt fails with an `ABORTED` status.
    ///
    /// * `error` - the error produced by the failed attempt.
    /// * `attempts` - the number of attempts made so far, including the one
    ///   that just failed (the first attempt is `1`).
    /// * `elapsed` - the time elapsed since the retry loop started, measured
    ///   from just before the first attempt.
    ///
    /// Return [`RetryResult::Continue`] to sleep and try again, or
    /// [`RetryResult::Exhausted`] / [`RetryResult::Permanent`] to stop
    /// retrying; in the latter case the contained error is returned to the
    /// caller of the retry loop.
    fn on_abort(&self, error: Error, attempts: u32, elapsed: Duration) -> RetryResult;
}
37
/// A [`TransactionRetryPolicy`] that bounds retries of aborted transactions
/// by attempt count and/or total elapsed time.
///
/// A limit set to zero is disabled; the default policy disables both limits,
/// so aborted transactions are retried without bound.
#[derive(Clone, Debug)]
pub struct BasicTransactionRetryPolicy {
    // Maximum number of attempts, counting the first one; `0` means unlimited.
    max_attempts: u32,
    // Retries stop once elapsed time is strictly greater than this; zero
    // disables the time limit.
    total_timeout: Duration,
}
47
48impl BasicTransactionRetryPolicy {
49 pub fn new() -> Self {
51 Self::default()
52 }
53
54 pub fn with_max_attempts(mut self, max_attempts: u32) -> Self {
56 self.max_attempts = max_attempts;
57 self
58 }
59
60 pub fn with_total_timeout(mut self, total_timeout: Duration) -> Self {
62 self.total_timeout = total_timeout;
63 self
64 }
65
66 pub fn max_attempts(&self) -> u32 {
68 self.max_attempts
69 }
70
71 pub fn total_timeout(&self) -> Duration {
73 self.total_timeout
74 }
75}
76
77impl Default for BasicTransactionRetryPolicy {
78 fn default() -> Self {
79 Self {
80 max_attempts: 0,
81 total_timeout: Duration::from_secs(0),
82 }
83 }
84}
85
86impl TransactionRetryPolicy for BasicTransactionRetryPolicy {
87 fn on_abort(&self, error: Error, attempts: u32, elapsed: Duration) -> RetryResult {
88 if self.max_attempts > 0 && attempts >= self.max_attempts {
89 return RetryResult::Exhausted(error);
90 }
91 if self.total_timeout > Duration::from_secs(0) && elapsed > self.total_timeout {
92 return RetryResult::Exhausted(error);
93 }
94 RetryResult::Continue(error)
95 }
96}
97
98pub(crate) async fn retry_aborted<T, F, Fut>(
105 policy: &dyn TransactionRetryPolicy,
106 mut f: F,
107) -> crate::Result<T>
108where
109 F: FnMut() -> Fut,
110 Fut: std::future::Future<Output = crate::Result<T>>,
111{
112 let start_time = tokio::time::Instant::now();
113 let mut attempts: u32 = 0;
114
115 let backoff = default_retry_backoff();
117
118 loop {
119 attempts += 1;
120 match f().await {
121 Ok(v) => return Ok(v),
122 Err(e) => {
123 backoff_if_aborted(e, attempts, start_time.elapsed(), policy, &backoff).await?;
124 }
125 }
126 }
127}
128
129pub(crate) fn is_aborted(err: &crate::Error) -> bool {
130 err.status()
131 .is_some_and(|s| s.code == google_cloud_gax::error::rpc::Code::Aborted)
132}
133
134pub(crate) fn extract_retry_delay(err: &crate::Error) -> Option<Duration> {
135 err.status()?.details.iter().find_map(|detail| {
136 let StatusDetails::RetryInfo(retry_info) = detail else {
137 return None;
138 };
139 (*retry_info.retry_delay.as_ref()?).try_into().ok()
140 })
141}
142
143pub(crate) fn default_retry_backoff() -> ExponentialBackoff {
144 ExponentialBackoffBuilder::new()
145 .with_initial_delay(Duration::from_millis(10))
146 .with_maximum_delay(Duration::from_secs(1))
147 .with_scaling(1.3)
148 .build()
149 .unwrap()
150}
151
152pub(crate) async fn backoff_if_aborted(
155 err: crate::Error,
156 attempts: u32,
157 elapsed: Duration,
158 policy: &dyn TransactionRetryPolicy,
159 backoff: &ExponentialBackoff,
160) -> crate::Result<()> {
161 if !is_aborted(&err) {
162 return Err(err);
163 }
164
165 let e = match policy.on_abort(err, attempts, elapsed) {
166 RetryResult::Continue(err) => err,
167 RetryResult::Exhausted(err) | RetryResult::Permanent(err) => return Err(err),
168 };
169
170 let sleep_duration = extract_retry_delay(&e)
171 .unwrap_or_else(|| backoff.on_failure(&RetryState::new(true).set_attempt_count(attempts)));
172
173 tokio::time::sleep(sleep_duration).await;
174 Ok(())
175}
176
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::Error;
    use google_cloud_gax::error::rpc::{Code, Status};
    use google_cloud_rpc::model::RetryInfo;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use wkt::Any;

    /// Builds an `ABORTED` service error, optionally carrying a `RetryInfo`
    /// detail with the given delay.
    fn create_aborted_error(retry_delay: Option<Duration>) -> Error {
        let mut status = Status::default()
            .set_code(Code::Aborted)
            .set_message("aborted");

        if let Some(delay) = retry_delay {
            let retry_info = RetryInfo::default().set_retry_delay(wkt::Duration::clamp(
                delay.as_secs() as i64,
                delay.subsec_nanos() as i32,
            ));
            status = status.set_details(vec![Any::from_msg(&retry_info).unwrap()]);
        }

        Error::service(status)
    }

    /// Builds a tonic `ABORTED` status whose encoded `google.rpc.Status`
    /// details contain a `RetryInfo` with `retry_delay`.
    ///
    /// NOTE(review): `pub(crate)`, presumably for gRPC mock-server tests
    /// elsewhere in the crate — confirm before changing its shape.
    pub(crate) fn create_aborted_status(
        retry_delay: std::time::Duration,
    ) -> gaxi::grpc::tonic::Status {
        use prost::Message;

        // Minimal stand-in for google.rpc.RetryInfo with the same wire layout.
        #[derive(Clone, PartialEq, prost::Message)]
        struct MockRetryInfo {
            #[prost(message, optional, tag = "1")]
            retry_delay: Option<prost_types::Duration>,
        }

        let retry_info = MockRetryInfo {
            retry_delay: Some(prost_types::Duration {
                seconds: retry_delay.as_secs() as i64,
                nanos: retry_delay.subsec_nanos() as i32,
            }),
        };

        let mut retry_buf = vec![];
        retry_info.encode(&mut retry_buf).unwrap();

        let status = spanner_grpc_mock::google::rpc::Status {
            code: gaxi::grpc::tonic::Code::Aborted as i32,
            message: "test transaction aborted".to_string(),
            details: vec![prost_types::Any {
                type_url: "type.googleapis.com/google.rpc.RetryInfo".to_string(),
                value: retry_buf,
            }],
        };

        let mut buf = vec![];
        status.encode(&mut buf).unwrap();

        gaxi::grpc::tonic::Status::with_details(
            gaxi::grpc::tonic::Code::Aborted,
            "test transaction aborted",
            bytes::Bytes::from(buf),
        )
    }

    #[test]
    fn auto_traits() {
        static_assertions::assert_impl_all!(
            BasicTransactionRetryPolicy: Send,
            Sync,
            Unpin,
            Clone,
            std::fmt::Debug,
            Default,
            TransactionRetryPolicy,
        );
    }

    #[test]
    fn basic_retry_policy_getters() {
        let policy = BasicTransactionRetryPolicy::new()
            .with_max_attempts(3)
            .with_total_timeout(Duration::from_secs(10));
        assert_eq!(policy.max_attempts(), 3);
        assert_eq!(policy.total_timeout(), Duration::from_secs(10));
    }

    #[test]
    fn basic_retry_policy_default_is_unbounded() {
        let policy = BasicTransactionRetryPolicy::default();
        assert_eq!(policy.max_attempts(), 0);
        assert_eq!(policy.total_timeout(), Duration::ZERO);
        // With both limits disabled the policy never gives up.
        let result = policy.on_abort(
            create_aborted_error(None),
            u32::MAX,
            Duration::from_secs(3600),
        );
        assert!(matches!(result, RetryResult::Continue(_)));
    }

    #[test]
    fn basic_retry_policy_max_attempts_boundary() {
        let policy = BasicTransactionRetryPolicy::new().with_max_attempts(2);
        assert!(matches!(
            policy.on_abort(create_aborted_error(None), 1, Duration::ZERO),
            RetryResult::Continue(_)
        ));
        assert!(matches!(
            policy.on_abort(create_aborted_error(None), 2, Duration::ZERO),
            RetryResult::Exhausted(_)
        ));
    }

    #[test]
    fn basic_retry_policy_total_timeout_boundary() {
        let policy =
            BasicTransactionRetryPolicy::new().with_total_timeout(Duration::from_secs(1));
        // The timeout only triggers once elapsed is strictly greater.
        assert!(matches!(
            policy.on_abort(create_aborted_error(None), 1, Duration::from_secs(1)),
            RetryResult::Continue(_)
        ));
        assert!(matches!(
            policy.on_abort(create_aborted_error(None), 1, Duration::from_millis(1001)),
            RetryResult::Exhausted(_)
        ));
    }

    #[tokio::test]
    async fn retry_aborted_success_first_try() {
        let policy = BasicTransactionRetryPolicy::default();
        let res = retry_aborted(&policy, || async { Ok::<i32, Error>(42) }).await;
        assert_eq!(res.expect("Transaction should succeed cleanly"), 42);
    }

    #[tokio::test]
    async fn retry_aborted_not_aborted_error() {
        let policy = BasicTransactionRetryPolicy::default();
        let res = retry_aborted(&policy, || async {
            let status = Status::default()
                .set_code(Code::Unavailable)
                .set_message("server unavailable");
            Err::<i32, Error>(Error::service(status))
        })
        .await;

        let err = res.unwrap_err();
        assert_eq!(
            err.status().expect("Error should contain a status").code,
            Code::Unavailable
        );
    }

    #[tokio::test(start_paused = true)]
    async fn retry_aborted_max_attempts_exceeded() {
        let policy = BasicTransactionRetryPolicy::new()
            .with_max_attempts(2)
            .with_total_timeout(Duration::from_secs(0));
        let attempts = Arc::new(AtomicU32::new(0));

        let res = retry_aborted(&policy, || {
            let attempts = attempts.clone();
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<i32, Error>(create_aborted_error(None))
            }
        })
        .await;

        assert!(res.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    #[tokio::test(start_paused = true)]
    async fn retry_aborted_with_retry_info() {
        // Longer than the default backoff's 1s cap, so the elapsed-time
        // assertion below can only pass if the server's RetryInfo was used.
        const RETRY_DELAY: Duration = Duration::from_secs(2);

        let policy = BasicTransactionRetryPolicy::default();
        let attempts = Arc::new(AtomicU32::new(0));

        let start = tokio::time::Instant::now();
        let res = retry_aborted(&policy, || {
            let attempts = attempts.clone();
            async move {
                let current = attempts.fetch_add(1, Ordering::SeqCst);
                if current == 0 {
                    Err::<i32, Error>(create_aborted_error(Some(RETRY_DELAY)))
                } else {
                    Ok::<i32, Error>(100)
                }
            }
        })
        .await;
        let elapsed = start.elapsed();

        assert_eq!(res.expect("Transaction should succeed after 1 retry"), 100);
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
        assert!(
            elapsed >= RETRY_DELAY,
            "Expected elapsed time to be at least {:?}, but was {:?}",
            RETRY_DELAY,
            elapsed
        );
    }

    #[tokio::test(start_paused = true)]
    async fn retry_aborted_with_default_backoff() {
        let policy = BasicTransactionRetryPolicy::default();
        let attempts = Arc::new(AtomicU32::new(0));

        let res = retry_aborted(&policy, || {
            let attempts = attempts.clone();
            async move {
                let current = attempts.fetch_add(1, Ordering::SeqCst);
                if current == 0 {
                    Err::<i32, Error>(create_aborted_error(None))
                } else {
                    Ok::<i32, Error>(100)
                }
            }
        })
        .await;

        assert_eq!(
            res.expect("Transaction should succeed using default backoff"),
            100
        );
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    #[tokio::test(start_paused = true)]
    async fn retry_aborted_total_timeout_exceeded() {
        let policy = BasicTransactionRetryPolicy::new()
            .with_max_attempts(0)
            .with_total_timeout(Duration::from_secs(1));
        let attempts = Arc::new(AtomicU32::new(0));

        let res = retry_aborted(&policy, || {
            let attempts = attempts.clone();
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                // Attempts run at t=0s, 0.6s and 1.2s; only the third fails
                // with elapsed > 1s, so the policy gives up after 3 attempts.
                Err::<i32, Error>(create_aborted_error(Some(Duration::from_millis(600))))
            }
        })
        .await;

        assert!(res.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn is_aborted_status_codes() {
        assert!(is_aborted(&create_aborted_error(None)));
        let unavailable = Error::service(Status::default().set_code(Code::Unavailable));
        assert!(!is_aborted(&unavailable));
    }

    #[test]
    fn is_aborted_non_status_error() {
        let err = Error::deser("test internal error");
        assert!(!is_aborted(&err));
    }

    #[test]
    fn extract_retry_delay_with_retry_info() {
        let err = create_aborted_error(Some(Duration::from_millis(250)));
        assert_eq!(extract_retry_delay(&err), Some(Duration::from_millis(250)));
    }

    #[test]
    fn extract_retry_delay_no_status() {
        let err = Error::deser("test internal error");
        assert_eq!(extract_retry_delay(&err), None);
    }

    #[test]
    fn extract_retry_delay_no_retry_info() {
        let mut status = Status::default().set_code(Code::Aborted);
        status = status.set_details(vec![Any::default()]);
        let err = Error::service(status);
        assert_eq!(extract_retry_delay(&err), None);
    }

    #[test]
    fn extract_retry_delay_empty_retry_info() {
        let mut status = Status::default().set_code(Code::Aborted);
        let retry_info = RetryInfo::default();
        status = status.set_details(vec![Any::from_msg(&retry_info).unwrap()]);
        let err = Error::service(status);
        assert_eq!(extract_retry_delay(&err), None);
    }

    #[test]
    fn extract_retry_delay_invalid_delay() {
        let mut status = Status::default().set_code(Code::Aborted);
        // Negative delays cannot be represented as std::time::Duration.
        let retry_info =
            RetryInfo::default().set_retry_delay(wkt::Duration::clamp(-10, 0));
        status = status.set_details(vec![Any::from_msg(&retry_info).unwrap()]);
        let err = Error::service(status);
        assert_eq!(extract_retry_delay(&err), None);
    }

    #[tokio::test(start_paused = true)]
    async fn retry_aborted_with_custom_policy() {
        struct CustomPolicy;
        impl TransactionRetryPolicy for CustomPolicy {
            fn on_abort(&self, error: Error, attempts: u32, _elapsed: Duration) -> RetryResult {
                if attempts < 3 {
                    RetryResult::Continue(error)
                } else {
                    RetryResult::Exhausted(error)
                }
            }
        }

        let policy = CustomPolicy;
        let attempts = Arc::new(AtomicU32::new(0));

        let res = retry_aborted(&policy, || {
            let attempts = attempts.clone();
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<i32, Error>(create_aborted_error(None))
            }
        })
        .await;

        assert!(res.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }
}