1use std::iter::Take;
2use std::marker::PhantomData;
3
4use google_cloud_gax::grpc::{Code, Status};
5use google_cloud_gax::retry::{CodeCondition, Condition, ExponentialBackoff, Retry, RetrySetting, TryAs};
6
7pub struct TransactionCondition<E>
8where
9 E: TryAs<Status>,
10{
11 inner: CodeCondition,
12 _marker: PhantomData<E>,
13}
14
15impl<E> Condition<E> for TransactionCondition<E>
16where
17 E: TryAs<Status>,
18{
19 fn should_retry(&mut self, error: &E) -> bool {
20 if let Some(status) = error.try_as() {
21 let code = status.code();
22 if code == Code::Internal
23 && !status.message().contains("stream terminated by RST_STREAM")
24 && !status.message().contains("HTTP/2 error code: INTERNAL_ERROR")
25 && !status.message().contains("Connection closed with unknown cause")
26 && !status
27 .message()
28 .contains("Received unexpected EOS on DATA frame from server")
29 {
30 return false;
31 }
32 return self.inner.should_retry(error);
33 }
34 false
35 }
36}
37
38pub struct TransactionRetry<E>
39where
40 E: TryAs<Status>,
41{
42 strategy: Take<ExponentialBackoff>,
43 condition: TransactionCondition<E>,
44}
45
46impl<E> TransactionRetry<E>
47where
48 E: TryAs<Status>,
49{
50 pub async fn next(&mut self, status: E) -> Result<(), E> {
51 let duration = if self.condition.should_retry(&status) {
52 self.strategy.next()
53 } else {
54 None
55 };
56 match duration {
57 Some(duration) => {
58 tokio::time::sleep(duration).await;
59 Ok(())
60 }
61 None => Err(status),
62 }
63 }
64
65 pub fn new() -> Self {
66 let setting = TransactionRetrySetting::default();
67 let strategy = <TransactionRetrySetting as Retry<E, TransactionCondition<E>>>::strategy(&setting);
68 Self {
69 strategy,
70 condition: setting.condition(),
71 }
72 }
73}
74
75impl<E> Default for TransactionRetry<E>
76where
77 E: TryAs<Status>,
78{
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84#[derive(Clone, Debug)]
85pub struct TransactionRetrySetting {
86 pub inner: RetrySetting,
87}
88
89impl<E> Retry<E, TransactionCondition<E>> for TransactionRetrySetting
90where
91 E: TryAs<Status>,
92{
93 fn strategy(&self) -> Take<ExponentialBackoff> {
94 self.inner.strategy()
95 }
96
97 fn condition(&self) -> TransactionCondition<E> {
98 TransactionCondition {
99 inner: CodeCondition::new(self.inner.codes.clone()),
100 _marker: PhantomData,
101 }
102 }
103
104 fn notify(error: &E, duration: std::time::Duration) {
105 if let Some(status) = error.try_as() {
106 tracing::trace!("transaction retry fn, error: {:?}, duration: {:?}", status, duration);
107 };
108 }
109}
110
111impl TransactionRetrySetting {
112 pub fn new(codes: Vec<Code>) -> Self {
113 Self {
114 inner: RetrySetting {
115 codes,
116 ..Default::default()
117 },
118 }
119 }
120}
121
122impl Default for TransactionRetrySetting {
123 fn default() -> Self {
124 TransactionRetrySetting::new(vec![Code::Aborted])
125 }
126}
127
128pub struct StreamingRetry {
129 strategy: Take<ExponentialBackoff>,
130 condition: CodeCondition,
131}
132
133impl StreamingRetry {
134 pub async fn next(&mut self, status: Status) -> Result<(), Status> {
135 let duration = if self.condition.should_retry(&status) {
136 self.strategy.next()
137 } else {
138 None
139 };
140 match duration {
141 Some(duration) => {
142 tokio::time::sleep(duration).await;
143 Ok(())
144 }
145 None => Err(status),
146 }
147 }
148
149 pub fn new() -> Self {
150 let setting = StreamingRetrySetting::default();
151 let strategy = <StreamingRetrySetting as Retry<Status, CodeCondition>>::strategy(&setting);
152 Self {
153 strategy,
154 condition: setting.condition(),
155 }
156 }
157}
158
159impl Default for StreamingRetry {
160 fn default() -> Self {
161 Self::new()
162 }
163}
164
165#[derive(Clone, Debug)]
166pub struct StreamingRetrySetting {
167 pub inner: RetrySetting,
168}
169
170impl Retry<Status, CodeCondition> for StreamingRetrySetting {
171 fn strategy(&self) -> Take<ExponentialBackoff> {
172 self.inner.strategy()
173 }
174
175 fn condition(&self) -> CodeCondition {
176 CodeCondition::new(self.inner.codes.clone())
177 }
178
179 fn notify(error: &Status, duration: std::time::Duration) {
180 tracing::trace!("streaming retry fn, error: {:?}, duration: {:?}", error, duration);
181 }
182}
183
184impl StreamingRetrySetting {
185 pub fn new(codes: Vec<Code>) -> Self {
186 Self {
187 inner: RetrySetting {
188 codes,
189 ..Default::default()
190 },
191 }
192 }
193}
194
195impl Default for StreamingRetrySetting {
196 fn default() -> Self {
197 StreamingRetrySetting::new(vec![Code::Unavailable, Code::ResourceExhausted, Code::Internal])
198 }
199}
200
#[cfg(test)]
mod tests {
    use google_cloud_gax::grpc::{Code, Status};
    use google_cloud_gax::retry::{Condition, Retry};

    use crate::client::Error;
    use crate::retry::{StreamingRetrySetting, TransactionRetrySetting};

    /// Messages that `TransactionCondition` treats as transient `Internal` failures.
    const TRANSIENT_INTERNAL_MESSAGES: [&str; 4] = [
        "stream terminated by RST_STREAM",
        "HTTP/2 error code: INTERNAL_ERROR",
        "Connection closed with unknown cause",
        "Received unexpected EOS on DATA frame from server",
    ];

    #[test]
    fn test_transaction_condition() {
        let err = &Error::GRPC(Status::new(Code::Internal, "stream terminated by RST_STREAM"));
        let default = TransactionRetrySetting::default();
        assert!(!default.condition().should_retry(err));

        let err = &Error::GRPC(Status::new(Code::Aborted, ""));
        assert!(default.condition().should_retry(err));
    }

    #[test]
    fn test_transaction_condition_default_rejects_unconfigured_codes() {
        let default = TransactionRetrySetting::default();
        let err = &Error::GRPC(Status::new(Code::Unavailable, ""));
        assert!(!default.condition().should_retry(err));
    }

    #[test]
    fn test_transaction_condition_internal_allowlist_when_configured() {
        // With Internal configured, only the known transient transport failures are retried.
        let setting = TransactionRetrySetting::new(vec![Code::Aborted, Code::Internal]);
        for message in TRANSIENT_INTERNAL_MESSAGES.iter() {
            let err = &Error::GRPC(Status::new(Code::Internal, *message));
            assert!(setting.condition().should_retry(err), "{}", message);
        }

        let err = &Error::GRPC(Status::new(Code::Internal, "unrelated internal failure"));
        assert!(!setting.condition().should_retry(err));

        // Non-Internal configured codes are unaffected by the message filter.
        let err = &Error::GRPC(Status::new(Code::Aborted, "anything"));
        assert!(setting.condition().should_retry(err));
    }

    #[test]
    fn test_streaming_retry_condition() {
        let setting = StreamingRetrySetting::default();
        assert!(setting.condition().should_retry(&Status::new(Code::Unavailable, "")));
        assert!(setting
            .condition()
            .should_retry(&Status::new(Code::ResourceExhausted, "")));
        assert!(setting.condition().should_retry(&Status::new(Code::Internal, "")));
        assert!(!setting.condition().should_retry(&Status::new(Code::Aborted, "")));
    }
}