// qcs_api_client_grpc/tonic/retry.rs

1use super::Body;
2use http::{HeaderValue, Request, Response};
3use qcs_api_client_common::{
4    backoff::{self, backoff::Backoff, ExponentialBackoff},
5    configuration::TokenError,
6};
7use tonic::{client::GrpcService, Status};
8
9use qcs_api_client_common::backoff::duration_from_response as duration_from_http_response;
10use std::{
11    future::{poll_fn, Future},
12    pin::Pin,
13    task::{Context, Poll},
14    time::Duration,
15};
16
17use super::{build_duplicate_request, RequestBodyDuplicationError};
18use tower::Layer;
19
/// The [`Layer`] used to apply exponential backoff retry logic to requests.
///
/// Applying this layer to a [`GrpcService`] produces a [`RetryService`] that
/// retries eligible failures according to the configured backoff schedule.
#[derive(Debug, Clone)]
pub struct RetryLayer {
    /// Backoff schedule cloned into each [`RetryService`] this layer creates.
    pub(crate) backoff: ExponentialBackoff,
}
25
26impl Default for RetryLayer {
27    fn default() -> Self {
28        Self {
29            backoff: backoff::default_backoff(),
30        }
31    }
32}
33
34impl<S: GrpcService<Body>> Layer<S> for RetryLayer {
35    type Service = RetryService<S>;
36
37    fn layer(&self, service: S) -> Self::Service {
38        Self::Service {
39            backoff: self.backoff.clone(),
40            service,
41        }
42    }
43}
44
/// The [`GrpcService`] that wraps the gRPC client in order to provide exponential backoff retry
/// logic.
///
/// This middleware will add a `x-request-id` header to each request with a unique UUID and a
/// `x-request-retry-index` header with the number of retries that have been attempted for the
/// request.
///
/// See also: [`RetryLayer`].
#[derive(Clone, Debug)]
pub struct RetryService<S: GrpcService<Body>> {
    /// Backoff schedule; cloned per request so each call gets a fresh schedule.
    backoff: ExponentialBackoff,
    /// The inner service that actually dispatches each (possibly retried) request.
    service: S,
}
58
59/// Return `Some` if the request should be retried and the provided `backoff`
60/// has another backoff to try, or, for an http request, if the response
61/// specifies a `Retry-After` header. If `None` is returned, the request should
62/// not be retried.
63fn duration_from_response<T>(
64    response: &Response<T>,
65    backoff: &mut ExponentialBackoff,
66) -> Option<Duration> {
67    if let Some(grpc_status) = Status::from_header_map(response.headers()) {
68        match grpc_status.code() {
69            // gRPC has no equivalent to RETRY-AFTER, so just use the backoff
70            tonic::Code::Unavailable => backoff.next_backoff(),
71            // No other gRPC statuses are retried.
72            _ => None,
73        }
74    } else {
75        duration_from_http_response(response.status(), response.headers(), backoff)
76    }
77}
78
impl<S> GrpcService<Body> for RetryService<S>
where
    S: GrpcService<Body> + Send + Clone + 'static,
    S::Future: Send,
    S::ResponseBody: Send,
    super::error::Error<TokenError>: From<S::Error> + From<RequestBodyDuplicationError>,
{
    type ResponseBody = <S as GrpcService<Body>>::ResponseBody;
    type Error = super::error::Error<TokenError>;
    // The retry loop is an `async move` block, so the future must be boxed
    // and pinned to be nameable here.
    type Future =
        Pin<Box<dyn Future<Output = Result<Response<Self::ResponseBody>, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service
            .poll_ready(cx)
            .map_err(super::error::Error::from)
    }

    /// Dispatch `req`, retrying with exponential backoff for as long as
    /// [`duration_from_response`] deems the response retryable.
    ///
    /// A unique `x-request-id` header is attached once per logical request;
    /// an `x-request-retry-index` header is refreshed on every attempt.
    fn call(&mut self, mut req: Request<Body>) -> Self::Future {
        // Tag the logical request with a UUID; silently skipped in the
        // (practically impossible) case the UUID is not a valid header value.
        if let Ok(request_id) = new_request_id() {
            req.headers_mut().insert(KEY_X_REQUEST_ID, request_id);
        }

        // Clone the `backoff` so that new requests don't reuse it
        // and so that the `backoff` can be moved into the async closure.
        let mut backoff = self.backoff.clone();
        let mut service = self.service.clone();
        // It is necessary to replace self.service with the above clone
        // because the cloned version may not be "ready".
        //
        // See this github issue for more context:
        // https://github.com/tower-rs/tower/issues/547
        std::mem::swap(&mut self.service, &mut service);

        super::common::pin_future_with_otel_context_if_available(async move {
            // Zero-based attempt counter, surfaced via the retry-index header.
            let mut attempt = 0;
            loop {
                // Each attempt consumes a request body, so duplicate the
                // request: `request` is sent now, `retained` is kept in `req`
                // for a possible retry on the next loop iteration.
                let (mut request, retained) = build_duplicate_request(req).await?;
                req = retained;

                // Ensure that the service is ready before trying to use it.
                // Failure to do this *will* cause a panic.
                poll_fn(|cx| -> Poll<Result<(), _>> { service.poll_ready(cx) })
                    .await
                    .map_err(super::error::Error::from)?;

                // Record which attempt this is (0 for the initial try).
                if let Ok(retry_index_header_value) =
                    http::HeaderValue::from_str(attempt.to_string().as_str())
                {
                    request
                        .headers_mut()
                        .insert(KEY_X_REQUEST_RETRY_INDEX, retry_index_header_value);
                }
                // `Some(duration)` means "retryable: sleep then loop again";
                // `None` means the response is final and is returned as-is.
                // Transport-level errors are not retried.
                let duration = match service.call(request).await {
                    Ok(response) => {
                        if let Some(duration) = duration_from_response(&response, &mut backoff) {
                            duration
                        } else {
                            break Ok(response);
                        }
                    }
                    Err(error) => break Err(super::error::Error::from(error)),
                };

                tokio::time::sleep(duration).await;
                attempt += 1;
            }
        })
    }
}
149
150fn new_request_id() -> Result<HeaderValue, http::header::InvalidHeaderValue> {
151    let request_id = uuid::Uuid::new_v4().to_string();
152    HeaderValue::from_str(request_id.as_str())
153}
154
/// Header carrying a per-request UUID so retries of one logical request can be correlated.
const KEY_X_REQUEST_ID: &str = "x-request-id";
/// Header carrying the zero-based attempt number of the current try.
const KEY_X_REQUEST_RETRY_INDEX: &str = "x-request-retry-index";
157
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use crate::tonic::uds_grpc_stream;
    use crate::tonic::wrap_channel_with_retry;

    use super::*;
    use ::backoff::ExponentialBackoffBuilder;
    use tonic::server::NamedService;
    use tonic::Request;
    use tonic_health::pb::health_check_response::ServingStatus;
    use tonic_health::pb::health_server::{Health, HealthServer};
    use tonic_health::{pb::health_client::HealthClient, server::HealthService};

    /// A health service that responds `Unavailable` a configurable number of
    /// times before reporting `Serving`, to exercise the retry middleware.
    struct FlakyHealthService {
        /// Number of failures still to serve before the first success.
        required_tries_count: AtomicUsize,
    }

    impl FlakyHealthService {
        /// Creates a service that fails `required_tries_count` times, then succeeds.
        const fn new(required_tries_count: usize) -> Self {
            Self {
                required_tries_count: AtomicUsize::new(required_tries_count),
            }
        }

        /// Fail with `Unavailable` until the failure budget is exhausted, then
        /// respond with `Serving`.
        #[allow(clippy::result_large_err)]
        fn make_response(&self) -> Result<tonic_health::pb::HealthCheckResponse, Status> {
            // Atomically consume one remaining required failure, saturating at
            // zero. The previous implementation used `fetch_sub`, which (a)
            // wrapped the counter to `usize::MAX` on the successful call so
            // every later call failed again, and (b) redundantly re-stored the
            // already-decremented value on the failure path.
            let previous = self
                .required_tries_count
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |count| {
                    Some(count.saturating_sub(1))
                })
                .expect("update closure always returns Some");
            if previous == 0 {
                Ok(tonic_health::pb::HealthCheckResponse {
                    status: ServingStatus::Serving as i32,
                })
            } else {
                Err(Status::unavailable("unavailable"))
            }
        }
    }

    impl Default for FlakyHealthService {
        fn default() -> Self {
            Self::new(3)
        }
    }

    #[tonic::async_trait]
    impl Health for FlakyHealthService {
        type WatchStream = tokio_stream::wrappers::ReceiverStream<
            Result<tonic_health::pb::HealthCheckResponse, Status>,
        >;

        async fn check(
            &self,
            _request: Request<tonic_health::pb::HealthCheckRequest>,
        ) -> Result<tonic::Response<tonic_health::pb::HealthCheckResponse>, Status> {
            self.make_response().map(tonic::Response::new)
        }

        async fn watch(
            &self,
            _request: Request<tonic_health::pb::HealthCheckRequest>,
        ) -> Result<tonic::Response<Self::WatchStream>, Status> {
            let (tx, rx) = tokio::sync::mpsc::channel(1);
            tx.send(self.make_response()).await.unwrap();
            Ok(tonic::Response::new(
                tokio_stream::wrappers::ReceiverStream::new(rx),
            ))
        }
    }

    /// The retry middleware should absorb a few `Unavailable` responses and
    /// ultimately return the successful response.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_retry_logic() {
        let health_server = HealthServer::new(FlakyHealthService::default());

        uds_grpc_stream::serve(health_server, |channel| async {
            let wrapped_channel = wrap_channel_with_retry(channel);
            let response = HealthClient::new(wrapped_channel)
                .check(Request::new(tonic_health::pb::HealthCheckRequest {
                    service: <HealthServer<HealthService> as NamedService>::NAME.to_string(),
                }))
                .await
                .unwrap();
            assert_eq!(response.into_inner().status(), ServingStatus::Serving);
        })
        .await
        .unwrap();
    }

    /// With more required failures than the backoff's elapsed-time budget
    /// allows, the client should eventually surface `Unavailable`.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_retry_is_not_infinite_long() {
        let health_server = HealthServer::new(FlakyHealthService::new(50));

        uds_grpc_stream::serve(health_server, |channel| async {
            let status = HealthClient::new(
                RetryLayer {
                    backoff: ExponentialBackoffBuilder::new()
                        .with_max_interval(Duration::from_millis(100))
                        .with_max_elapsed_time(Some(Duration::from_secs(1)))
                        .build(),
                }
                .layer(channel),
            )
            .check(Request::new(tonic_health::pb::HealthCheckRequest {
                service: <HealthServer<HealthService> as NamedService>::NAME.to_string(),
            }))
            .await
            .unwrap_err();
            assert_eq!(status.code(), tonic::Code::Unavailable);
        })
        .await
        .unwrap();
    }
}