qcs_api_client_grpc/tonic/
retry.rs1use super::Body;
2use qcs_api_client_common::{
3 backoff::{self, ExponentialBackoff, backoff::Backoff},
4 configuration::TokenError,
5};
6use qcs_dependencies_client::http::{HeaderValue, Request, Response};
7use qcs_dependencies_client::tonic::{Status, client::GrpcService};
8
9use qcs_api_client_common::backoff::duration_from_response as duration_from_http_response;
10use std::{
11 future::{Future, poll_fn},
12 pin::Pin,
13 task::{Context, Poll},
14 time::Duration,
15};
16
17use super::{RequestBodyDuplicationError, build_duplicate_request};
18use qcs_dependencies_client::tower::Layer;
19
20#[derive(Debug, Clone)]
22pub struct RetryLayer {
23 pub(crate) backoff: ExponentialBackoff,
24}
25
26impl Default for RetryLayer {
27 fn default() -> Self {
28 Self {
29 backoff: backoff::default_backoff(),
30 }
31 }
32}
33
34impl<S: GrpcService<Body>> Layer<S> for RetryLayer {
35 type Service = RetryService<S>;
36
37 fn layer(&self, service: S) -> Self::Service {
38 Self::Service {
39 backoff: self.backoff.clone(),
40 service,
41 }
42 }
43}
44
45#[derive(Clone, Debug)]
54pub struct RetryService<S: GrpcService<Body>> {
55 backoff: ExponentialBackoff,
56 service: S,
57}
58
59fn duration_from_response<T>(
64 response: &Response<T>,
65 backoff: &mut ExponentialBackoff,
66) -> Option<Duration> {
67 if let Some(grpc_status) = Status::from_header_map(response.headers()) {
68 match grpc_status.code() {
69 qcs_dependencies_client::tonic::Code::Unavailable => backoff.next_backoff(),
71 _ => None,
73 }
74 } else {
75 duration_from_http_response(response.status(), response.headers(), backoff)
76 }
77}
78
79impl<S> GrpcService<Body> for RetryService<S>
80where
81 S: GrpcService<Body> + Send + Clone + 'static,
82 S::Future: Send,
83 S::ResponseBody: Send,
84 super::error::Error<TokenError>: From<S::Error> + From<RequestBodyDuplicationError>,
85{
86 type ResponseBody = <S as GrpcService<Body>>::ResponseBody;
87 type Error = super::error::Error<TokenError>;
88 type Future =
89 Pin<Box<dyn Future<Output = Result<Response<Self::ResponseBody>, Self::Error>> + Send>>;
90
91 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
92 self.service
93 .poll_ready(cx)
94 .map_err(super::error::Error::from)
95 }
96
97 fn call(&mut self, mut req: Request<Body>) -> Self::Future {
98 if let Ok(request_id) = new_request_id() {
99 req.headers_mut().insert(KEY_X_REQUEST_ID, request_id);
100 }
101
102 let mut backoff = self.backoff.clone();
105 let mut service = self.service.clone();
106 std::mem::swap(&mut self.service, &mut service);
112
113 super::common::pin_future_with_otel_context_if_available(async move {
114 let mut attempt = 0;
115 loop {
116 let (mut request, retained) = build_duplicate_request(req).await?;
117 req = retained;
118
119 poll_fn(|cx| -> Poll<Result<(), _>> { service.poll_ready(cx) })
122 .await
123 .map_err(super::error::Error::from)?;
124
125 if let Ok(retry_index_header_value) =
126 qcs_dependencies_client::http::HeaderValue::from_str(
127 attempt.to_string().as_str(),
128 )
129 {
130 request
131 .headers_mut()
132 .insert(KEY_X_REQUEST_RETRY_INDEX, retry_index_header_value);
133 }
134 let duration = match service.call(request).await {
135 Ok(response) => {
136 if let Some(duration) = duration_from_response(&response, &mut backoff) {
137 duration
138 } else {
139 break Ok(response);
140 }
141 }
142 Err(error) => break Err(super::error::Error::from(error)),
143 };
144
145 tokio::time::sleep(duration).await;
146 attempt += 1;
147 }
148 })
149 }
150}
151
152fn new_request_id() -> Result<HeaderValue, qcs_dependencies_client::http::header::InvalidHeaderValue>
153{
154 let request_id = uuid::Uuid::new_v4().to_string();
155 HeaderValue::from_str(request_id.as_str())
156}
157
/// Name of the header that carries the per-call request id.
const KEY_X_REQUEST_ID: &str = "x-request-id";
/// Name of the header that carries the zero-based index of the current attempt.
const KEY_X_REQUEST_RETRY_INDEX: &str = "x-request-retry-index";
160
161#[cfg(test)]
162mod tests {
163 use std::sync::atomic::{AtomicUsize, Ordering};
164
165 use crate::tonic::uds_grpc_stream;
166 use crate::tonic::wrap_channel_with_retry;
167
168 use super::*;
169 use ::backoff::ExponentialBackoffBuilder;
170 use qcs_dependencies_client::tonic::Request;
171 use qcs_dependencies_client::tonic::server::NamedService;
172 use qcs_dependencies_client::tonic_health::pb::health_check_response::ServingStatus;
173 use qcs_dependencies_client::tonic_health::pb::health_server::{Health, HealthServer};
174 use qcs_dependencies_client::tonic_health::{
175 pb::health_client::HealthClient, server::HealthService,
176 };
177
/// Test-only health service that answers `Unavailable` for a configured number of calls
/// before it reports `Serving`.
struct FlakyHealthService {
    /// How many more calls must fail before a successful response is produced.
    required_tries_count: AtomicUsize,
}
181
182 impl FlakyHealthService {
183 const fn new(required_tries_count: usize) -> Self {
184 Self {
185 required_tries_count: AtomicUsize::new(required_tries_count),
186 }
187 }
188
189 #[allow(clippy::result_large_err)]
190 fn make_response(
191 &self,
192 ) -> Result<qcs_dependencies_client::tonic_health::pb::HealthCheckResponse, Status>
193 {
194 let remaining = self.required_tries_count.fetch_sub(1, Ordering::SeqCst);
195 if remaining == 0 {
196 let response = qcs_dependencies_client::tonic_health::pb::HealthCheckResponse {
197 status: ServingStatus::Serving as i32,
198 };
199 Ok(response)
200 } else {
201 self.required_tries_count
202 .store(remaining - 1, Ordering::SeqCst);
203 Err(Status::unavailable("unavailable"))
204 }
205 }
206 }
207
208 impl Default for FlakyHealthService {
209 fn default() -> Self {
210 Self::new(3)
211 }
212 }
213
214 #[qcs_dependencies_client::tonic::async_trait]
215 impl Health for FlakyHealthService {
216 type WatchStream = tokio_stream::wrappers::ReceiverStream<
217 Result<qcs_dependencies_client::tonic_health::pb::HealthCheckResponse, Status>,
218 >;
219
220 async fn check(
221 &self,
222 _request: Request<qcs_dependencies_client::tonic_health::pb::HealthCheckRequest>,
223 ) -> Result<
224 qcs_dependencies_client::tonic::Response<
225 qcs_dependencies_client::tonic_health::pb::HealthCheckResponse,
226 >,
227 Status,
228 > {
229 self.make_response()
230 .map(qcs_dependencies_client::tonic::Response::new)
231 }
232
233 async fn watch(
234 &self,
235 _request: Request<qcs_dependencies_client::tonic_health::pb::HealthCheckRequest>,
236 ) -> Result<qcs_dependencies_client::tonic::Response<Self::WatchStream>, Status> {
237 let (tx, rx) = tokio::sync::mpsc::channel(1);
238 tx.send(self.make_response()).await.unwrap();
239 Ok(qcs_dependencies_client::tonic::Response::new(
240 tokio_stream::wrappers::ReceiverStream::new(rx),
241 ))
242 }
243 }
244
245 #[tokio::test(flavor = "multi_thread")]
246 async fn test_retry_logic() {
247 let health_server = HealthServer::new(FlakyHealthService::default());
248
249 uds_grpc_stream::serve(health_server, |channel| async {
250 let wrapped_channel = wrap_channel_with_retry(channel);
251 let response = HealthClient::new(wrapped_channel)
252 .check(Request::new(
253 qcs_dependencies_client::tonic_health::pb::HealthCheckRequest {
254 service: <HealthServer<HealthService> as NamedService>::NAME.to_string(),
255 },
256 ))
257 .await
258 .unwrap();
259 assert_eq!(response.into_inner().status(), ServingStatus::Serving);
260 })
261 .await
262 .unwrap();
263 }
264
265 #[tokio::test(flavor = "multi_thread")]
266 async fn test_retry_is_not_infinite_long() {
267 let health_server = HealthServer::new(FlakyHealthService::new(50));
268
269 uds_grpc_stream::serve(health_server, |channel| async {
270 let status = HealthClient::new(
271 RetryLayer {
272 backoff: ExponentialBackoffBuilder::new()
273 .with_max_interval(Duration::from_millis(100))
274 .with_max_elapsed_time(Some(Duration::from_secs(1)))
275 .build(),
276 }
277 .layer(channel),
278 )
279 .check(Request::new(
280 qcs_dependencies_client::tonic_health::pb::HealthCheckRequest {
281 service: <HealthServer<HealthService> as NamedService>::NAME.to_string(),
282 },
283 ))
284 .await
285 .unwrap_err();
286 assert_eq!(
287 status.code(),
288 qcs_dependencies_client::tonic::Code::Unavailable
289 );
290 })
291 .await
292 .unwrap();
293 }
294}