Skip to main content

qcs_api_client_grpc/tonic/
retry.rs

1use super::Body;
2use qcs_api_client_common::{
3    backoff::{self, ExponentialBackoff, backoff::Backoff},
4    configuration::TokenError,
5};
6use qcs_dependencies_client::http::{HeaderValue, Request, Response};
7use qcs_dependencies_client::tonic::{Status, client::GrpcService};
8
9use qcs_api_client_common::backoff::duration_from_response as duration_from_http_response;
10use std::{
11    future::{Future, poll_fn},
12    pin::Pin,
13    task::{Context, Poll},
14    time::Duration,
15};
16
17use super::{RequestBodyDuplicationError, build_duplicate_request};
18use qcs_dependencies_client::tower::Layer;
19
20/// The [`Layer`] used to apply exponential backoff retry logic to requests.
21#[derive(Debug, Clone)]
22pub struct RetryLayer {
23    pub(crate) backoff: ExponentialBackoff,
24}
25
26impl Default for RetryLayer {
27    fn default() -> Self {
28        Self {
29            backoff: backoff::default_backoff(),
30        }
31    }
32}
33
34impl<S: GrpcService<Body>> Layer<S> for RetryLayer {
35    type Service = RetryService<S>;
36
37    fn layer(&self, service: S) -> Self::Service {
38        Self::Service {
39            backoff: self.backoff.clone(),
40            service,
41        }
42    }
43}
44
45/// The [`GrpcService`] that wraps the gRPC client in order to provide exponential backoff retry
46/// logic.
47///
48/// This middleware will add a `x-request-id` header to each request with a unique UUID and a
49/// `x-request-retry-index` header with the number of retries that have been attempted for the
50/// request.
51///
52/// See also: [`RetryLayer`].
53#[derive(Clone, Debug)]
54pub struct RetryService<S: GrpcService<Body>> {
55    backoff: ExponentialBackoff,
56    service: S,
57}
58
59/// Return `Some` if the request should be retried and the provided `backoff`
60/// has another backoff to try, or, for an http request, if the response
61/// specifies a `Retry-After` header. If `None` is returned, the request should
62/// not be retried.
63fn duration_from_response<T>(
64    response: &Response<T>,
65    backoff: &mut ExponentialBackoff,
66) -> Option<Duration> {
67    if let Some(grpc_status) = Status::from_header_map(response.headers()) {
68        match grpc_status.code() {
69            // gRPC has no equivalent to RETRY-AFTER, so just use the backoff
70            qcs_dependencies_client::tonic::Code::Unavailable => backoff.next_backoff(),
71            // No other gRPC statuses are retried.
72            _ => None,
73        }
74    } else {
75        duration_from_http_response(response.status(), response.headers(), backoff)
76    }
77}
78
79impl<S> GrpcService<Body> for RetryService<S>
80where
81    S: GrpcService<Body> + Send + Clone + 'static,
82    S::Future: Send,
83    S::ResponseBody: Send,
84    super::error::Error<TokenError>: From<S::Error> + From<RequestBodyDuplicationError>,
85{
86    type ResponseBody = <S as GrpcService<Body>>::ResponseBody;
87    type Error = super::error::Error<TokenError>;
88    type Future =
89        Pin<Box<dyn Future<Output = Result<Response<Self::ResponseBody>, Self::Error>> + Send>>;
90
91    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
92        self.service
93            .poll_ready(cx)
94            .map_err(super::error::Error::from)
95    }
96
97    fn call(&mut self, mut req: Request<Body>) -> Self::Future {
98        if let Ok(request_id) = new_request_id() {
99            req.headers_mut().insert(KEY_X_REQUEST_ID, request_id);
100        }
101
102        // Clone the `backoff` so that new requests don't reuse it
103        // and so that the `backoff` can be moved into the async closure.
104        let mut backoff = self.backoff.clone();
105        let mut service = self.service.clone();
106        // It is necessary to replace self.service with the above clone
107        // because the cloned version may not be "ready".
108        //
109        // See this github issue for more context:
110        // https://github.com/tower-rs/tower/issues/547
111        std::mem::swap(&mut self.service, &mut service);
112
113        super::common::pin_future_with_otel_context_if_available(async move {
114            let mut attempt = 0;
115            loop {
116                let (mut request, retained) = build_duplicate_request(req).await?;
117                req = retained;
118
119                // Ensure that the service is ready before trying to use it.
120                // Failure to do this *will* cause a panic.
121                poll_fn(|cx| -> Poll<Result<(), _>> { service.poll_ready(cx) })
122                    .await
123                    .map_err(super::error::Error::from)?;
124
125                if let Ok(retry_index_header_value) =
126                    qcs_dependencies_client::http::HeaderValue::from_str(
127                        attempt.to_string().as_str(),
128                    )
129                {
130                    request
131                        .headers_mut()
132                        .insert(KEY_X_REQUEST_RETRY_INDEX, retry_index_header_value);
133                }
134                let duration = match service.call(request).await {
135                    Ok(response) => {
136                        if let Some(duration) = duration_from_response(&response, &mut backoff) {
137                            duration
138                        } else {
139                            break Ok(response);
140                        }
141                    }
142                    Err(error) => break Err(super::error::Error::from(error)),
143                };
144
145                tokio::time::sleep(duration).await;
146                attempt += 1;
147            }
148        })
149    }
150}
151
152fn new_request_id() -> Result<HeaderValue, qcs_dependencies_client::http::header::InvalidHeaderValue>
153{
154    let request_id = uuid::Uuid::new_v4().to_string();
155    HeaderValue::from_str(request_id.as_str())
156}
157
/// Name of the header that carries the per-request unique identifier.
const KEY_X_REQUEST_ID: &str = "x-request-id";
/// Name of the header that carries the retry index of the current attempt.
const KEY_X_REQUEST_RETRY_INDEX: &str = "x-request-retry-index";
160
161#[cfg(test)]
162mod tests {
163    use std::sync::atomic::{AtomicUsize, Ordering};
164
165    use crate::tonic::uds_grpc_stream;
166    use crate::tonic::wrap_channel_with_retry;
167
168    use super::*;
169    use ::backoff::ExponentialBackoffBuilder;
170    use qcs_dependencies_client::tonic::Request;
171    use qcs_dependencies_client::tonic::server::NamedService;
172    use qcs_dependencies_client::tonic_health::pb::health_check_response::ServingStatus;
173    use qcs_dependencies_client::tonic_health::pb::health_server::{Health, HealthServer};
174    use qcs_dependencies_client::tonic_health::{
175        pb::health_client::HealthClient, server::HealthService,
176    };
177
178    struct FlakyHealthService {
179        required_tries_count: AtomicUsize,
180    }
181
182    impl FlakyHealthService {
183        const fn new(required_tries_count: usize) -> Self {
184            Self {
185                required_tries_count: AtomicUsize::new(required_tries_count),
186            }
187        }
188
189        #[allow(clippy::result_large_err)]
190        fn make_response(
191            &self,
192        ) -> Result<qcs_dependencies_client::tonic_health::pb::HealthCheckResponse, Status>
193        {
194            let remaining = self.required_tries_count.fetch_sub(1, Ordering::SeqCst);
195            if remaining == 0 {
196                let response = qcs_dependencies_client::tonic_health::pb::HealthCheckResponse {
197                    status: ServingStatus::Serving as i32,
198                };
199                Ok(response)
200            } else {
201                self.required_tries_count
202                    .store(remaining - 1, Ordering::SeqCst);
203                Err(Status::unavailable("unavailable"))
204            }
205        }
206    }
207
208    impl Default for FlakyHealthService {
209        fn default() -> Self {
210            Self::new(3)
211        }
212    }
213
214    #[qcs_dependencies_client::tonic::async_trait]
215    impl Health for FlakyHealthService {
216        type WatchStream = tokio_stream::wrappers::ReceiverStream<
217            Result<qcs_dependencies_client::tonic_health::pb::HealthCheckResponse, Status>,
218        >;
219
220        async fn check(
221            &self,
222            _request: Request<qcs_dependencies_client::tonic_health::pb::HealthCheckRequest>,
223        ) -> Result<
224            qcs_dependencies_client::tonic::Response<
225                qcs_dependencies_client::tonic_health::pb::HealthCheckResponse,
226            >,
227            Status,
228        > {
229            self.make_response()
230                .map(qcs_dependencies_client::tonic::Response::new)
231        }
232
233        async fn watch(
234            &self,
235            _request: Request<qcs_dependencies_client::tonic_health::pb::HealthCheckRequest>,
236        ) -> Result<qcs_dependencies_client::tonic::Response<Self::WatchStream>, Status> {
237            let (tx, rx) = tokio::sync::mpsc::channel(1);
238            tx.send(self.make_response()).await.unwrap();
239            Ok(qcs_dependencies_client::tonic::Response::new(
240                tokio_stream::wrappers::ReceiverStream::new(rx),
241            ))
242        }
243    }
244
245    #[tokio::test(flavor = "multi_thread")]
246    async fn test_retry_logic() {
247        let health_server = HealthServer::new(FlakyHealthService::default());
248
249        uds_grpc_stream::serve(health_server, |channel| async {
250            let wrapped_channel = wrap_channel_with_retry(channel);
251            let response = HealthClient::new(wrapped_channel)
252                .check(Request::new(
253                    qcs_dependencies_client::tonic_health::pb::HealthCheckRequest {
254                        service: <HealthServer<HealthService> as NamedService>::NAME.to_string(),
255                    },
256                ))
257                .await
258                .unwrap();
259            assert_eq!(response.into_inner().status(), ServingStatus::Serving);
260        })
261        .await
262        .unwrap();
263    }
264
265    #[tokio::test(flavor = "multi_thread")]
266    async fn test_retry_is_not_infinite_long() {
267        let health_server = HealthServer::new(FlakyHealthService::new(50));
268
269        uds_grpc_stream::serve(health_server, |channel| async {
270            let status = HealthClient::new(
271                RetryLayer {
272                    backoff: ExponentialBackoffBuilder::new()
273                        .with_max_interval(Duration::from_millis(100))
274                        .with_max_elapsed_time(Some(Duration::from_secs(1)))
275                        .build(),
276                }
277                .layer(channel),
278            )
279            .check(Request::new(
280                qcs_dependencies_client::tonic_health::pb::HealthCheckRequest {
281                    service: <HealthServer<HealthService> as NamedService>::NAME.to_string(),
282                },
283            ))
284            .await
285            .unwrap_err();
286            assert_eq!(
287                status.code(),
288                qcs_dependencies_client::tonic::Code::Unavailable
289            );
290        })
291        .await
292        .unwrap();
293    }
294}