gcloud_spanner/
retry.rs

1use std::iter::Take;
2use std::marker::PhantomData;
3
4use google_cloud_gax::grpc::{Code, Status};
5use google_cloud_gax::retry::{CodeCondition, Condition, ExponentialBackoff, Retry, RetrySetting, TryAs};
6
7pub struct TransactionCondition<E>
8where
9    E: TryAs<Status>,
10{
11    inner: CodeCondition,
12    _marker: PhantomData<E>,
13}
14
15impl<E> Condition<E> for TransactionCondition<E>
16where
17    E: TryAs<Status>,
18{
19    fn should_retry(&mut self, error: &E) -> bool {
20        if let Some(status) = error.try_as() {
21            let code = status.code();
22            if code == Code::Internal
23                && !status.message().contains("stream terminated by RST_STREAM")
24                && !status.message().contains("HTTP/2 error code: INTERNAL_ERROR")
25                && !status.message().contains("Connection closed with unknown cause")
26                && !status
27                    .message()
28                    .contains("Received unexpected EOS on DATA frame from server")
29            {
30                return false;
31            }
32            return self.inner.should_retry(error);
33        }
34        false
35    }
36}
37
38pub struct TransactionRetry<E>
39where
40    E: TryAs<Status>,
41{
42    strategy: Take<ExponentialBackoff>,
43    condition: TransactionCondition<E>,
44}
45
46impl<E> TransactionRetry<E>
47where
48    E: TryAs<Status>,
49{
50    pub async fn next(&mut self, status: E) -> Result<(), E> {
51        let duration = if self.condition.should_retry(&status) {
52            self.strategy.next()
53        } else {
54            None
55        };
56        match duration {
57            Some(duration) => {
58                tokio::time::sleep(duration).await;
59                Ok(())
60            }
61            None => Err(status),
62        }
63    }
64
65    pub fn new() -> Self {
66        let setting = TransactionRetrySetting::default();
67        let strategy = <TransactionRetrySetting as Retry<E, TransactionCondition<E>>>::strategy(&setting);
68        Self {
69            strategy,
70            condition: setting.condition(),
71        }
72    }
73}
74
75impl<E> Default for TransactionRetry<E>
76where
77    E: TryAs<Status>,
78{
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
84#[derive(Clone, Debug)]
85pub struct TransactionRetrySetting {
86    pub inner: RetrySetting,
87}
88
89impl<E> Retry<E, TransactionCondition<E>> for TransactionRetrySetting
90where
91    E: TryAs<Status>,
92{
93    fn strategy(&self) -> Take<ExponentialBackoff> {
94        self.inner.strategy()
95    }
96
97    fn condition(&self) -> TransactionCondition<E> {
98        TransactionCondition {
99            inner: CodeCondition::new(self.inner.codes.clone()),
100            _marker: PhantomData,
101        }
102    }
103
104    fn notify(error: &E, duration: std::time::Duration) {
105        if let Some(status) = error.try_as() {
106            tracing::trace!("transaction retry fn, error: {:?}, duration: {:?}", status, duration);
107        };
108    }
109}
110
111impl TransactionRetrySetting {
112    pub fn new(codes: Vec<Code>) -> Self {
113        Self {
114            inner: RetrySetting {
115                codes,
116                ..Default::default()
117            },
118        }
119    }
120}
121
122impl Default for TransactionRetrySetting {
123    fn default() -> Self {
124        TransactionRetrySetting::new(vec![Code::Aborted])
125    }
126}
127
#[cfg(test)]
mod tests {
    use google_cloud_gax::grpc::{Code, Status};
    use google_cloud_gax::retry::{Condition, Retry};

    use crate::client::Error;
    use crate::retry::TransactionRetrySetting;

    /// With the default setting only `Aborted` is retryable. `Internal` is rejected even
    /// when its message matches a known transient transport failure, because `Internal`
    /// is not in the configured code list.
    #[test]
    fn test_transaction_condition() {
        let err = &Error::GRPC(Status::new(Code::Internal, "stream terminated by RST_STREAM"));
        let default = TransactionRetrySetting::default();
        assert!(!default.condition().should_retry(err));

        let err = &Error::GRPC(Status::new(Code::Aborted, ""));
        assert!(default.condition().should_retry(err));
    }

    /// When `Internal` is configured as retryable, it is retried only if the message matches
    /// one of the allow-listed transient transport failures.
    #[test]
    fn test_transaction_condition_internal_configured() {
        let setting = TransactionRetrySetting::new(vec![Code::Aborted, Code::Internal]);

        for message in [
            "stream terminated by RST_STREAM",
            "HTTP/2 error code: INTERNAL_ERROR",
            "Connection closed with unknown cause",
            "Received unexpected EOS on DATA frame from server",
        ] {
            let err = &Error::GRPC(Status::new(Code::Internal, message));
            assert!(setting.condition().should_retry(err), "{}", message);
        }

        // An `Internal` error with an unrecognised message is a genuine failure: no retry.
        let err = &Error::GRPC(Status::new(Code::Internal, "unexpected server failure"));
        assert!(!setting.condition().should_retry(err));

        let err = &Error::GRPC(Status::new(Code::Aborted, ""));
        assert!(setting.condition().should_retry(err));
    }
}
145}