qcs_api_client_grpc/tonic/
retry.rs1use super::Body;
2use http::{HeaderValue, Request, Response};
3use qcs_api_client_common::{
4 backoff::{self, backoff::Backoff, ExponentialBackoff},
5 configuration::TokenError,
6};
7use tonic::{client::GrpcService, Status};
8
9use qcs_api_client_common::backoff::duration_from_response as duration_from_http_response;
10use std::{
11 future::{poll_fn, Future},
12 pin::Pin,
13 task::{Context, Poll},
14 time::Duration,
15};
16
17use super::{build_duplicate_request, RequestBodyDuplicationError};
18use tower::Layer;
19
20#[derive(Debug, Clone)]
22pub struct RetryLayer {
23 pub(crate) backoff: ExponentialBackoff,
24}
25
26impl Default for RetryLayer {
27 fn default() -> Self {
28 Self {
29 backoff: backoff::default_backoff(),
30 }
31 }
32}
33
34impl<S: GrpcService<Body>> Layer<S> for RetryLayer {
35 type Service = RetryService<S>;
36
37 fn layer(&self, service: S) -> Self::Service {
38 Self::Service {
39 backoff: self.backoff.clone(),
40 service,
41 }
42 }
43}
44
45#[derive(Clone, Debug)]
54pub struct RetryService<S: GrpcService<Body>> {
55 backoff: ExponentialBackoff,
56 service: S,
57}
58
59fn duration_from_response<T>(
64 response: &Response<T>,
65 backoff: &mut ExponentialBackoff,
66) -> Option<Duration> {
67 if let Some(grpc_status) = Status::from_header_map(response.headers()) {
68 match grpc_status.code() {
69 tonic::Code::Unavailable => backoff.next_backoff(),
71 _ => None,
73 }
74 } else {
75 duration_from_http_response(response.status(), response.headers(), backoff)
76 }
77}
78
79impl<S> GrpcService<Body> for RetryService<S>
80where
81 S: GrpcService<Body> + Send + Clone + 'static,
82 S::Future: Send,
83 S::ResponseBody: Send,
84 super::error::Error<TokenError>: From<S::Error> + From<RequestBodyDuplicationError>,
85{
86 type ResponseBody = <S as GrpcService<Body>>::ResponseBody;
87 type Error = super::error::Error<TokenError>;
88 type Future =
89 Pin<Box<dyn Future<Output = Result<Response<Self::ResponseBody>, Self::Error>> + Send>>;
90
91 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
92 self.service
93 .poll_ready(cx)
94 .map_err(super::error::Error::from)
95 }
96
97 fn call(&mut self, mut req: Request<Body>) -> Self::Future {
98 if let Ok(request_id) = new_request_id() {
99 req.headers_mut().insert(KEY_X_REQUEST_ID, request_id);
100 }
101
102 let mut backoff = self.backoff.clone();
105 let mut service = self.service.clone();
106 std::mem::swap(&mut self.service, &mut service);
112
113 super::common::pin_future_with_otel_context_if_available(async move {
114 let mut attempt = 0;
115 loop {
116 let (mut request, retained) = build_duplicate_request(req).await?;
117 req = retained;
118
119 poll_fn(|cx| -> Poll<Result<(), _>> { service.poll_ready(cx) })
122 .await
123 .map_err(super::error::Error::from)?;
124
125 if let Ok(retry_index_header_value) =
126 http::HeaderValue::from_str(attempt.to_string().as_str())
127 {
128 request
129 .headers_mut()
130 .insert(KEY_X_REQUEST_RETRY_INDEX, retry_index_header_value);
131 }
132 let duration = match service.call(request).await {
133 Ok(response) => {
134 if let Some(duration) = duration_from_response(&response, &mut backoff) {
135 duration
136 } else {
137 break Ok(response);
138 }
139 }
140 Err(error) => break Err(super::error::Error::from(error)),
141 };
142
143 tokio::time::sleep(duration).await;
144 attempt += 1;
145 }
146 })
147 }
148}
149
150fn new_request_id() -> Result<HeaderValue, http::header::InvalidHeaderValue> {
151 let request_id = uuid::Uuid::new_v4().to_string();
152 HeaderValue::from_str(request_id.as_str())
153}
154
/// Name of the header carrying the per-call request id.
const KEY_X_REQUEST_ID: &str = "x-request-id";

/// Name of the header carrying the attempt counter of a request.
const KEY_X_REQUEST_RETRY_INDEX: &str = "x-request-retry-index";
157
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use crate::tonic::uds_grpc_stream;
    use crate::tonic::wrap_channel_with_retry;

    use super::*;
    use ::backoff::ExponentialBackoffBuilder;
    use tonic::server::NamedService;
    use tonic::Request;
    use tonic_health::pb::health_check_response::ServingStatus;
    use tonic_health::pb::health_server::{Health, HealthServer};
    use tonic_health::{pb::health_client::HealthClient, server::HealthService};

    /// Health service that answers `Unavailable` a fixed number of times
    /// before it starts reporting `Serving`.
    struct FlakyHealthService {
        /// Number of failing responses still to be handed out.
        required_tries_count: AtomicUsize,
    }

    impl FlakyHealthService {
        /// Creates a service that fails `required_tries_count` times.
        const fn new(required_tries_count: usize) -> Self {
            Self {
                required_tries_count: AtomicUsize::new(required_tries_count),
            }
        }

        /// Produces the next response: `Unavailable` while failures remain,
        /// `Serving` afterwards.
        ///
        /// The counter is decremented with a single atomic `fetch_update` that
        /// stops at zero. A plain `fetch_sub` would wrap to `usize::MAX` on
        /// the first successful call and make the service flaky again.
        // `tonic::Status` is a large error type, but it is what the `Health`
        // trait requires.
        #[allow(clippy::result_large_err)]
        fn make_response(&self) -> Result<tonic_health::pb::HealthCheckResponse, Status> {
            let consumed_failure = self
                .required_tries_count
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |remaining| {
                    remaining.checked_sub(1)
                })
                .is_ok();

            if consumed_failure {
                Err(Status::unavailable("unavailable"))
            } else {
                Ok(tonic_health::pb::HealthCheckResponse {
                    status: ServingStatus::Serving as i32,
                })
            }
        }
    }

    impl Default for FlakyHealthService {
        /// Fails three times before succeeding.
        fn default() -> Self {
            Self::new(3)
        }
    }

    #[tonic::async_trait]
    impl Health for FlakyHealthService {
        type WatchStream = tokio_stream::wrappers::ReceiverStream<
            Result<tonic_health::pb::HealthCheckResponse, Status>,
        >;

        async fn check(
            &self,
            _request: Request<tonic_health::pb::HealthCheckRequest>,
        ) -> Result<tonic::Response<tonic_health::pb::HealthCheckResponse>, Status> {
            self.make_response().map(tonic::Response::new)
        }

        async fn watch(
            &self,
            _request: Request<tonic_health::pb::HealthCheckRequest>,
        ) -> Result<tonic::Response<Self::WatchStream>, Status> {
            // Single-item stream carrying one flaky response.
            let (tx, rx) = tokio::sync::mpsc::channel(1);
            tx.send(self.make_response()).await.unwrap();
            Ok(tonic::Response::new(
                tokio_stream::wrappers::ReceiverStream::new(rx),
            ))
        }
    }

    /// A service that fails three times is eventually reported as serving.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_retry_logic() {
        let health_server = HealthServer::new(FlakyHealthService::default());

        uds_grpc_stream::serve(health_server, |channel| async {
            let wrapped_channel = wrap_channel_with_retry(channel);
            let response = HealthClient::new(wrapped_channel)
                .check(Request::new(tonic_health::pb::HealthCheckRequest {
                    service: <HealthServer<HealthService> as NamedService>::NAME.to_string(),
                }))
                .await
                .unwrap();
            assert_eq!(response.into_inner().status(), ServingStatus::Serving);
        })
        .await
        .unwrap();
    }

    /// With a one-second retry budget, a service that needs 50 failures
    /// surfaces `Unavailable` instead of retrying forever.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_retry_is_not_infinite_long() {
        let health_server = HealthServer::new(FlakyHealthService::new(50));

        uds_grpc_stream::serve(health_server, |channel| async {
            let status = HealthClient::new(
                RetryLayer {
                    backoff: ExponentialBackoffBuilder::new()
                        .with_max_interval(Duration::from_millis(100))
                        .with_max_elapsed_time(Some(Duration::from_secs(1)))
                        .build(),
                }
                .layer(channel),
            )
            .check(Request::new(tonic_health::pb::HealthCheckRequest {
                service: <HealthServer<HealthService> as NamedService>::NAME.to_string(),
            }))
            .await
            .unwrap_err();
            assert_eq!(status.code(), tonic::Code::Unavailable);
        })
        .await
        .unwrap();
    }
}