1use std::iter::Take;
2use std::marker::PhantomData;
3
4use google_cloud_gax::grpc::{Code, Status};
5use google_cloud_gax::retry::{CodeCondition, Condition, ExponentialBackoff, Retry, RetrySetting, TryAs};
6
7pub struct TransactionCondition<E>
8where
9 E: TryAs<Status>,
10{
11 inner: CodeCondition,
12 _marker: PhantomData<E>,
13}
14
15impl<E> Condition<E> for TransactionCondition<E>
16where
17 E: TryAs<Status>,
18{
19 fn should_retry(&mut self, error: &E) -> bool {
20 if let Some(status) = error.try_as() {
21 let code = status.code();
22 if code == Code::Internal
23 && !status.message().contains("stream terminated by RST_STREAM")
24 && !status.message().contains("HTTP/2 error code: INTERNAL_ERROR")
25 && !status.message().contains("Connection closed with unknown cause")
26 && !status
27 .message()
28 .contains("Received unexpected EOS on DATA frame from server")
29 {
30 return false;
31 }
32 return self.inner.should_retry(error);
33 }
34 false
35 }
36}
37
38pub struct TransactionRetry<E>
39where
40 E: TryAs<Status>,
41{
42 strategy: Take<ExponentialBackoff>,
43 condition: TransactionCondition<E>,
44}
45
46impl<E> TransactionRetry<E>
47where
48 E: TryAs<Status>,
49{
50 pub async fn next(&mut self, status: E) -> Result<(), E> {
51 let duration = if self.condition.should_retry(&status) {
52 self.strategy.next()
53 } else {
54 None
55 };
56 match duration {
57 Some(duration) => {
58 tokio::time::sleep(duration).await;
59 Ok(())
60 }
61 None => Err(status),
62 }
63 }
64
65 pub fn new() -> Self {
66 let setting = TransactionRetrySetting::default();
67 let strategy = <TransactionRetrySetting as Retry<E, TransactionCondition<E>>>::strategy(&setting);
68 Self {
69 strategy,
70 condition: setting.condition(),
71 }
72 }
73}
74
75impl<E> Default for TransactionRetry<E>
76where
77 E: TryAs<Status>,
78{
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84#[derive(Clone, Debug)]
85pub struct TransactionRetrySetting {
86 pub inner: RetrySetting,
87}
88
89impl<E> Retry<E, TransactionCondition<E>> for TransactionRetrySetting
90where
91 E: TryAs<Status>,
92{
93 fn strategy(&self) -> Take<ExponentialBackoff> {
94 self.inner.strategy()
95 }
96
97 fn condition(&self) -> TransactionCondition<E> {
98 TransactionCondition {
99 inner: CodeCondition::new(self.inner.codes.clone()),
100 _marker: PhantomData,
101 }
102 }
103
104 fn notify(error: &E, duration: std::time::Duration) {
105 if let Some(status) = error.try_as() {
106 tracing::trace!("transaction retry fn, error: {:?}, duration: {:?}", status, duration);
107 };
108 }
109}
110
111impl TransactionRetrySetting {
112 pub fn new(codes: Vec<Code>) -> Self {
113 Self {
114 inner: RetrySetting {
115 codes,
116 ..Default::default()
117 },
118 }
119 }
120}
121
122impl Default for TransactionRetrySetting {
123 fn default() -> Self {
124 TransactionRetrySetting::new(vec![Code::Aborted])
125 }
126}
127
#[cfg(test)]
mod tests {
    use google_cloud_gax::grpc::{Code, Status};
    use google_cloud_gax::retry::{Condition, Retry};

    use crate::client::Error;
    use crate::retry::TransactionRetrySetting;

    #[test]
    fn test_transaction_condition() {
        let err = &Error::GRPC(Status::new(Code::Internal, "stream terminated by RST_STREAM"));
        let default = TransactionRetrySetting::default();
        assert!(!default.condition().should_retry(err));

        let err = &Error::GRPC(Status::new(Code::Aborted, ""));
        assert!(default.condition().should_retry(err));
    }

    /// Exercises the `Internal` message filter, which the default setting never
    /// reaches because `Internal` is not among its codes.
    #[test]
    fn test_transaction_condition_with_internal_code() {
        let setting = TransactionRetrySetting::new(vec![Code::Aborted, Code::Internal]);

        // Transport-level Internal failures are retried once Internal is configured.
        for message in &[
            "stream terminated by RST_STREAM",
            "HTTP/2 error code: INTERNAL_ERROR",
            "Connection closed with unknown cause",
            "Received unexpected EOS on DATA frame from server",
        ] {
            let err = &Error::GRPC(Status::new(Code::Internal, *message));
            assert!(setting.condition().should_retry(err), "{}", message);
        }

        // Any other Internal failure is rejected even though Internal is configured.
        let err = &Error::GRPC(Status::new(Code::Internal, "unrelated internal failure"));
        assert!(!setting.condition().should_retry(err));

        // Codes outside the configured list are never retried.
        let err = &Error::GRPC(Status::new(Code::Unavailable, ""));
        assert!(!setting.condition().should_retry(err));
    }
}