libdd_trace_utils/send_with_retry/
mod.rs

1// Copyright 2024-Present Datadog, Inc. https://www.datadoghq.com/
2// SPDX-License-Identifier: Apache-2.0
3
4//! Provide [`send_with_retry`] utility to send a payload to an [`Endpoint`] with retries if the
5//! request fails.
6
7mod retry_strategy;
8pub use retry_strategy::{RetryBackoffType, RetryStrategy};
9
10use bytes::Bytes;
11use hyper::Method;
12use libdd_common::{hyper_migration, Connect, Endpoint, GenericHttpClient, HttpRequestBuilder};
13use std::{collections::HashMap, time::Duration};
14use tracing::{debug, error};
15
/// Number of request attempts made by [`send_with_retry`]; the first attempt is counted as 1.
pub type Attempts = u32;
17
/// Outcome of [`send_with_retry`]: on success the HTTP response together with the number of
/// attempts that were made, on failure a [`SendWithRetryError`].
pub type SendWithRetryResult =
    Result<(hyper_migration::HttpResponse, Attempts), SendWithRetryError>;
20
/// Error returned by [`send_with_retry`].
///
/// All errors contain the number of attempts after which the final error was returned.
#[derive(Debug)]
pub enum SendWithRetryError {
    /// The last attempt was answered with a client (4xx) or server (5xx) error status code.
    /// The response is handed back to the caller.
    Http(hyper_migration::HttpResponse, Attempts),
    /// The last attempt did not complete within the request timeout.
    Timeout(Attempts),
    /// The HTTP client reported an error while executing the last attempt.
    Network(hyper_migration::ClientError, Attempts),
    /// The request could not be built.
    Build(Attempts),
}
33
34impl std::fmt::Display for SendWithRetryError {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        match self {
37            SendWithRetryError::Http(_, _) => write!(f, "Http error code received"),
38            SendWithRetryError::Timeout(_) => write!(f, "Request timed out"),
39            SendWithRetryError::Network(error, _) => write!(f, "Network error: {error}"),
40            SendWithRetryError::Build(_) => {
41                write!(f, "Failed to build request due to invalid property")
42            }
43        }
44    }
45}
46
// Empty impl: only the default `std::error::Error` methods are provided, so no `source()` is
// exposed even for the `Network` variant.
impl std::error::Error for SendWithRetryError {}
48
49impl SendWithRetryError {
50    fn from_request_error(err: RequestError, request_attempt: Attempts) -> Self {
51        match err {
52            RequestError::Build => SendWithRetryError::Build(request_attempt),
53            RequestError::Network(error) => SendWithRetryError::Network(error, request_attempt),
54            RequestError::TimeoutApi => SendWithRetryError::Timeout(request_attempt),
55        }
56    }
57}
58
/// Failure of a single request attempt, before an attempt count is attached to it (see
/// `SendWithRetryError::from_request_error`).
#[derive(Debug)]
enum RequestError {
    /// The request could not be assembled from the builder and the payload.
    Build,
    /// The HTTP client returned an error while executing the request.
    Network(hyper_migration::ClientError),
    /// The request did not complete before the timeout elapsed.
    TimeoutApi,
}
65
66impl std::fmt::Display for RequestError {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        match self {
69            RequestError::TimeoutApi => write!(f, "Api timeout exhausted"),
70            RequestError::Network(error) => write!(f, "Network error: {error}"),
71            RequestError::Build => write!(f, "Failed to build request due to invalid property"),
72        }
73    }
74}
75
// Empty impl: only the default `std::error::Error` methods are provided.
impl std::error::Error for RequestError {}
77
78/// Send the `payload` with a POST request to `target` using the provided `retry_strategy` if the
79/// request fails.
80///
81/// The request builder from [`Endpoint::to_request_builder`] is used with the associated headers
82/// (api key, test token), and `headers` are added to the request. The request is executed with a
83/// timeout of [`Endpoint::timeout_ms`].
84///
85/// # Arguments
86///
87/// # Returns
88///
89/// Return a [`SendWithRetryResult`] containing the response and the number of attempts or an error
90/// describing the last attempt failure.
91///
92/// # Errors
93/// Fail if the request didn't succeed after applying the retry strategy.
94///
95/// # Example
96///
97/// ```rust, no_run
98/// # use libdd_common::Endpoint;
99/// # use libdd_common::hyper_migration::new_default_client;
100/// # use libdd_trace_utils::send_with_retry::*;
101/// # use std::collections::HashMap;
102/// # async fn run() -> SendWithRetryResult {
103/// let payload: Vec<u8> = vec![0, 1, 2, 3];
104/// let target = Endpoint {
105///     url: "localhost:8126/v04/traces".parse::<hyper::Uri>().unwrap(),
106///     ..Endpoint::default()
107/// };
108/// let headers = HashMap::from([("Content-type", "application/msgpack".to_string())]);
109/// let retry_strategy = RetryStrategy::new(3, 10, RetryBackoffType::Exponential, Some(5));
110/// let client = new_default_client();
111/// send_with_retry(&client, &target, payload, &headers, &retry_strategy).await
112/// # }
113/// ```
114pub async fn send_with_retry<C: Connect>(
115    client: &GenericHttpClient<C>,
116    target: &Endpoint,
117    payload: Vec<u8>,
118    headers: &HashMap<&'static str, String>,
119    retry_strategy: &RetryStrategy,
120) -> SendWithRetryResult {
121    let mut request_attempt = 0;
122    // Wrap the payload in Bytes to avoid expensive clone between retries
123    let payload = Bytes::from(payload);
124
125    debug!(
126        url = %target.url,
127        payload_size = payload.len(),
128        max_retries = retry_strategy.max_retries(),
129        "Sending with retry"
130    );
131
132    loop {
133        request_attempt += 1;
134
135        debug!(
136            attempt = request_attempt,
137            max_retries = retry_strategy.max_retries(),
138            "Attempting request"
139        );
140
141        let mut req = target
142            .to_request_builder(concat!("Tracer/", env!("CARGO_PKG_VERSION")))
143            .or(Err(SendWithRetryError::Build(request_attempt)))?
144            .method(Method::POST);
145        for (key, value) in headers {
146            req = req.header(*key, value.clone());
147        }
148
149        match send_request(
150            client,
151            Duration::from_millis(target.timeout_ms),
152            req,
153            payload.clone(),
154        )
155        .await
156        {
157            // An Ok response doesn't necessarily mean the request was successful, we need to
158            // check the status code and if it's not a 2xx or 3xx we treat it as an error
159            Ok(response) => {
160                let status = response.status();
161                debug!(status = %status, attempt = request_attempt, "Received response");
162
163                if status.is_client_error() || status.is_server_error() {
164                    debug!(
165                        status = %status,
166                        attempt = request_attempt,
167                        max_retries = retry_strategy.max_retries(),
168                        "Received error status code"
169                    );
170
171                    if request_attempt < retry_strategy.max_retries() {
172                        debug!(
173                            attempt = request_attempt,
174                            remaining_retries = retry_strategy.max_retries() - request_attempt,
175                            "Retrying after error status code"
176                        );
177                        retry_strategy.delay(request_attempt).await;
178                        continue;
179                    } else {
180                        error!(
181                            status = %status,
182                            attempts = request_attempt,
183                            "Max retries exceeded, returning HTTP error"
184                        );
185                        return Err(SendWithRetryError::Http(response, request_attempt));
186                    }
187                } else {
188                    debug!(
189                        status = %status,
190                        attempts = request_attempt,
191                        "Request succeeded"
192                    );
193                    return Ok((response, request_attempt));
194                }
195            }
196            Err(e) => {
197                debug!(
198                    error = %e,
199                    attempt = request_attempt,
200                    max_retries = retry_strategy.max_retries(),
201                    "Request failed with error"
202                );
203
204                if request_attempt < retry_strategy.max_retries() {
205                    debug!(
206                        attempt = request_attempt,
207                        remaining_retries = retry_strategy.max_retries() - request_attempt,
208                        "Retrying after request error"
209                    );
210                    retry_strategy.delay(request_attempt).await;
211                    continue;
212                } else {
213                    error!(
214                        error = %e,
215                        attempts = request_attempt,
216                        "Max retries exceeded, returning request error"
217                    );
218                    return Err(SendWithRetryError::from_request_error(e, request_attempt));
219                }
220            }
221        }
222    }
223}
224
225async fn send_request<C: Connect>(
226    client: &GenericHttpClient<C>,
227    timeout: Duration,
228    req: HttpRequestBuilder,
229    payload: Bytes,
230) -> Result<hyper_migration::HttpResponse, RequestError> {
231    let req = req
232        .body(hyper_migration::Body::from_bytes(payload))
233        .or(Err(RequestError::Build))?;
234
235    let req_future = { client.request(req) };
236
237    match tokio::time::timeout(timeout, req_future).await {
238        Ok(resp) => match resp {
239            Ok(body) => Ok(hyper_migration::into_response(body)),
240            Err(e) => Err(RequestError::Network(e)),
241        },
242        Err(_) => Err(RequestError::TimeoutApi),
243    }
244}
245
#[cfg(test)]
mod tests {
    // Every test runs `send_with_retry` in a spawned task while the test body polls the mock
    // server for the expected hits. The task's `JoinHandle` is awaited at the end of each test:
    // a task whose handle is dropped keeps running detached, and a panic raised by one of the
    // assertions inside it would never reach the test, which would then pass regardless.

    use super::*;
    use crate::test_utils::poll_for_mock_hit;
    use httpmock::MockServer;

    /// A strategy with `max_retries` of 0 still performs one attempt and reports its HTTP error.
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn test_zero_retries_on_error() {
        let server = MockServer::start();

        let mut mock_503 = server
            .mock_async(|_when, then| {
                then.status(503)
                    .header("content-type", "application/json")
                    .body(r#"{"status":"error"}"#);
            })
            .await;

        // We add this mock so that if a second request was made it would be a success and the
        // assertion in the spawned task that the result is an error would fail.
        let _mock_202 = server
            .mock_async(|_when, then| {
                then.status(202)
                    .header("content-type", "application/json")
                    .body(r#"{"status":"ok"}"#);
            })
            .await;

        let target_endpoint = Endpoint {
            url: server.url("").to_owned().parse().unwrap(),
            api_key: Some("test-key".into()),
            ..Default::default()
        };

        let strategy = RetryStrategy::new(0, 2, RetryBackoffType::Constant, None);

        let client = libdd_common::hyper_migration::new_default_client();
        let send_task = tokio::spawn(async move {
            let result = send_with_retry(
                &client,
                &target_endpoint,
                vec![0, 1, 2, 3],
                &HashMap::new(),
                &strategy,
            )
            .await;
            assert!(result.is_err(), "Expected an error result");
            assert!(
                matches!(result.unwrap_err(), SendWithRetryError::Http(_, 1)),
                "Expected an http error with one attempt"
            );
        });

        assert!(poll_for_mock_hit(&mut mock_503, 10, 100, 1, true).await);

        // Propagate a failed assertion (panic) from the spawned task to the test.
        send_task
            .await
            .expect("assertions in the send_with_retry task failed");
    }

    /// A 5xx answer followed by a 2xx answer yields a success after two attempts.
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn test_retry_logic_error_then_success() {
        let server = MockServer::start();

        let mut mock_503 = server
            .mock_async(|_when, then| {
                then.status(503)
                    .header("content-type", "application/json")
                    .body(r#"{"status":"error"}"#);
            })
            .await;

        let mut mock_202 = server
            .mock_async(|_when, then| {
                then.status(202)
                    .header("content-type", "application/json")
                    .body(r#"{"status":"ok"}"#);
            })
            .await;

        let target_endpoint = Endpoint {
            url: server.url("").to_owned().parse().unwrap(),
            api_key: Some("test-key".into()),
            ..Default::default()
        };

        let strategy = RetryStrategy::new(2, 250, RetryBackoffType::Constant, None);

        let client = libdd_common::hyper_migration::new_default_client();
        let send_task = tokio::spawn(async move {
            let result = send_with_retry(
                &client,
                &target_endpoint,
                vec![0, 1, 2, 3],
                &HashMap::new(),
                &strategy,
            )
            .await;
            assert!(
                matches!(result.unwrap(), (_, 2)),
                "Expected an ok result after two attempts"
            );
        });

        assert!(poll_for_mock_hit(&mut mock_503, 10, 100, 1, true).await);
        assert!(
            poll_for_mock_hit(&mut mock_202, 10, 100, 1, true).await,
            "Expected a retry request after a 5xx error"
        );

        // Propagate a failed assertion (panic) from the spawned task to the test.
        send_task
            .await
            .expect("assertions in the send_with_retry task failed");
    }

    /// When every answer is a 5xx, the error is returned once the attempt limit is reached.
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn test_retry_logic_max_errors() {
        let server = MockServer::start();
        let expected_retry_attempts = 3;
        let mut mock_503 = server
            .mock_async(|_when, then| {
                then.status(503)
                    .header("content-type", "application/json")
                    .body(r#"{"status":"error"}"#);
            })
            .await;

        let target_endpoint = Endpoint {
            url: server.url("").to_owned().parse().unwrap(),
            api_key: Some("test-key".into()),
            ..Default::default()
        };

        let strategy = RetryStrategy::new(
            expected_retry_attempts,
            10,
            RetryBackoffType::Constant,
            None,
        );

        let client = libdd_common::hyper_migration::new_default_client();
        let send_task = tokio::spawn(async move {
            let result = send_with_retry(
                &client,
                &target_endpoint,
                vec![0, 1, 2, 3],
                &HashMap::new(),
                &strategy,
            )
            .await;
            assert!(
                matches!(result.unwrap_err(), SendWithRetryError::Http(_, attempts) if attempts == expected_retry_attempts),
                "Expected an error result after max retry attempts"
            );
        });

        assert!(
            poll_for_mock_hit(
                &mut mock_503,
                10,
                100,
                expected_retry_attempts as usize,
                true
            )
            .await,
            "Expected max retry attempts"
        );

        // Propagate a failed assertion (panic) from the spawned task to the test.
        send_task
            .await
            .expect("assertions in the send_with_retry task failed");
    }

    /// A 2xx answer on the first attempt yields a success after a single attempt.
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn test_retry_logic_no_errors() {
        let server = MockServer::start();
        let mut mock_202 = server
            .mock_async(|_when, then| {
                then.status(202)
                    .header("content-type", "application/json")
                    .body(r#"{"status":"Ok"}"#);
            })
            .await;

        let target_endpoint = Endpoint {
            url: server.url("").to_owned().parse().unwrap(),
            api_key: Some("test-key".into()),
            ..Default::default()
        };

        let strategy = RetryStrategy::new(2, 10, RetryBackoffType::Constant, None);

        let client = libdd_common::hyper_migration::new_default_client();
        let send_task = tokio::spawn(async move {
            let result = send_with_retry(
                &client,
                &target_endpoint,
                vec![0, 1, 2, 3],
                &HashMap::new(),
                &strategy,
            )
            .await;
            assert!(
                matches!(result, Ok((_, attempts)) if attempts == 1),
                "Expected an ok result after one attempts"
            );
        });

        assert!(
            poll_for_mock_hit(&mut mock_202, 10, 250, 1, true).await,
            "Expected only one request attempt"
        );

        // Propagate a failed assertion (panic) from the spawned task to the test.
        send_task
            .await
            .expect("assertions in the send_with_retry task failed");
    }
}
451}