Skip to main content

libdd_trace_utils/send_with_retry/
mod.rs

1// Copyright 2024-Present Datadog, Inc. https://www.datadoghq.com/
2// SPDX-License-Identifier: Apache-2.0
3
4//! Provide [`send_with_retry`] utility to send a payload to an [`Endpoint`] with retries if the
5//! request fails.
6
7mod retry_strategy;
8pub use retry_strategy::{RetryBackoffType, RetryStrategy};
9
10use bytes::Bytes;
11use http::Method;
12use libdd_common::{http_common, Connect, Endpoint, GenericHttpClient, HttpRequestBuilder};
13use std::{collections::HashMap, time::Duration};
14use tracing::{debug, error};
15
/// Number of request attempts made by [`send_with_retry`], counting the initial try as the
/// first attempt.
pub type Attempts = u32;
17
18pub type SendWithRetryResult = Result<(http_common::HttpResponse, Attempts), SendWithRetryError>;
19
20/// All errors contain the number of attempts after which the final error was returned
21#[derive(Debug)]
22pub enum SendWithRetryError {
23    /// The request received an error HTTP code.
24    Http(http_common::HttpResponse, Attempts),
25    /// Treats timeout errors originated in the transport layer.
26    Timeout(Attempts),
27    /// Treats errors coming from networking.
28    Network(http_common::ClientError, Attempts),
29    /// Treats errors coming from building the request
30    Build(Attempts),
31}
32
33impl std::fmt::Display for SendWithRetryError {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        match self {
36            SendWithRetryError::Http(_, _) => write!(f, "Http error code received"),
37            SendWithRetryError::Timeout(_) => write!(f, "Request timed out"),
38            SendWithRetryError::Network(error, _) => write!(f, "Network error: {error}"),
39            SendWithRetryError::Build(_) => {
40                write!(f, "Failed to build request due to invalid property")
41            }
42        }
43    }
44}
45
46impl std::error::Error for SendWithRetryError {}
47
48impl SendWithRetryError {
49    fn from_request_error(err: RequestError, request_attempt: Attempts) -> Self {
50        match err {
51            RequestError::Build => SendWithRetryError::Build(request_attempt),
52            RequestError::Network(error) => SendWithRetryError::Network(error, request_attempt),
53            RequestError::TimeoutApi => SendWithRetryError::Timeout(request_attempt),
54        }
55    }
56}
57
58#[derive(Debug)]
59enum RequestError {
60    Build,
61    Network(http_common::ClientError),
62    TimeoutApi,
63}
64
65impl std::fmt::Display for RequestError {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        match self {
68            RequestError::TimeoutApi => write!(f, "Api timeout exhausted"),
69            RequestError::Network(error) => write!(f, "Network error: {error}"),
70            RequestError::Build => write!(f, "Failed to build request due to invalid property"),
71        }
72    }
73}
74
75impl std::error::Error for RequestError {}
76
77/// Send the `payload` with a POST request to `target` using the provided `retry_strategy` if the
78/// request fails.
79///
80/// The request builder from [`Endpoint::to_request_builder`] is used with the associated headers
81/// (api key, test token), and `headers` are added to the request. The request is executed with a
82/// timeout of [`Endpoint::timeout_ms`].
83///
84/// # Arguments
85///
86/// # Returns
87///
88/// Return a [`SendWithRetryResult`] containing the response and the number of attempts or an error
89/// describing the last attempt failure.
90///
91/// # Errors
92/// Fail if the request didn't succeed after applying the retry strategy.
93///
94/// # Example
95///
96/// ```rust, no_run
97/// # use libdd_common::Endpoint;
98/// # use libdd_common::http_common::new_default_client;
99/// # use libdd_trace_utils::send_with_retry::*;
100/// # use std::collections::HashMap;
101/// # async fn run() -> SendWithRetryResult {
102/// let payload: Vec<u8> = vec![0, 1, 2, 3];
103/// let target = Endpoint {
104///     url: "localhost:8126/v04/traces".parse::<hyper::Uri>().unwrap(),
105///     ..Endpoint::default()
106/// };
107/// let headers = HashMap::from([("Content-type", "application/msgpack".to_string())]);
108/// let retry_strategy = RetryStrategy::new(3, 10, RetryBackoffType::Exponential, Some(5));
109/// let client = new_default_client();
110/// send_with_retry(&client, &target, payload, &headers, &retry_strategy).await
111/// # }
112/// ```
113pub async fn send_with_retry<C: Connect>(
114    client: &GenericHttpClient<C>,
115    target: &Endpoint,
116    payload: Vec<u8>,
117    headers: &HashMap<&'static str, String>,
118    retry_strategy: &RetryStrategy,
119) -> SendWithRetryResult {
120    let mut request_attempt = 0;
121    // Wrap the payload in Bytes to avoid expensive clone between retries
122    let payload = Bytes::from(payload);
123
124    debug!(
125        url = %target.url,
126        payload_size = payload.len(),
127        max_retries = retry_strategy.max_retries(),
128        "Sending with retry"
129    );
130
131    loop {
132        request_attempt += 1;
133
134        debug!(
135            attempt = request_attempt,
136            max_retries = retry_strategy.max_retries(),
137            "Attempting request"
138        );
139
140        let mut req = target
141            .to_request_builder(concat!("Tracer/", env!("CARGO_PKG_VERSION")))
142            .or(Err(SendWithRetryError::Build(request_attempt)))?
143            .method(Method::POST);
144        for (key, value) in headers {
145            req = req.header(*key, value.clone());
146        }
147
148        match send_request(
149            client,
150            Duration::from_millis(target.timeout_ms),
151            req,
152            payload.clone(),
153        )
154        .await
155        {
156            // An Ok response doesn't necessarily mean the request was successful, we need to
157            // check the status code and if it's not a 2xx or 3xx we treat it as an error
158            Ok(response) => {
159                let status = response.status();
160                debug!(status = %status, attempt = request_attempt, "Received response");
161
162                if status.is_client_error() || status.is_server_error() {
163                    debug!(
164                        status = %status,
165                        attempt = request_attempt,
166                        max_retries = retry_strategy.max_retries(),
167                        "Received error status code"
168                    );
169
170                    if request_attempt < retry_strategy.max_retries() {
171                        debug!(
172                            attempt = request_attempt,
173                            remaining_retries = retry_strategy.max_retries() - request_attempt,
174                            "Retrying after error status code"
175                        );
176                        retry_strategy.delay(request_attempt).await;
177                        continue;
178                    } else {
179                        error!(
180                            status = %status,
181                            attempts = request_attempt,
182                            "Max retries exceeded, returning HTTP error"
183                        );
184                        return Err(SendWithRetryError::Http(response, request_attempt));
185                    }
186                } else {
187                    debug!(
188                        status = %status,
189                        attempts = request_attempt,
190                        "Request succeeded"
191                    );
192                    return Ok((response, request_attempt));
193                }
194            }
195            Err(e) => {
196                debug!(
197                    error = ?e,
198                    attempt = request_attempt,
199                    max_retries = retry_strategy.max_retries(),
200                    "Request failed with error"
201                );
202
203                if request_attempt < retry_strategy.max_retries() {
204                    debug!(
205                        attempt = request_attempt,
206                        remaining_retries = retry_strategy.max_retries() - request_attempt,
207                        "Retrying after request error"
208                    );
209                    retry_strategy.delay(request_attempt).await;
210                    continue;
211                } else {
212                    error!(
213                        error = ?e,
214                        attempts = request_attempt,
215                        "Max retries exceeded, returning request error"
216                    );
217                    return Err(SendWithRetryError::from_request_error(e, request_attempt));
218                }
219            }
220        }
221    }
222}
223
224async fn send_request<C: Connect>(
225    client: &GenericHttpClient<C>,
226    timeout: Duration,
227    req: HttpRequestBuilder,
228    payload: Bytes,
229) -> Result<http_common::HttpResponse, RequestError> {
230    let req = req
231        .body(http_common::Body::from_bytes(payload))
232        .or(Err(RequestError::Build))?;
233
234    let req_future = { client.request(req) };
235
236    match tokio::time::timeout(timeout, req_future).await {
237        Ok(resp) => match resp {
238            Ok(body) => Ok(http_common::into_response(body)),
239            Err(e) => Err(RequestError::Network(http_common::into_error(e))),
240        },
241        Err(_) => Err(RequestError::TimeoutApi),
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::poll_for_mock_hit;
    use httpmock::MockServer;

    // `send_with_retry` runs in a spawned task so the test body can concurrently poll (and
    // delete) mocks. Each test awaits the task's `JoinHandle` at the end. Without that, a failed
    // assertion inside the task is never reported and the test passes regardless.

    /// Endpoint targeting the root of the mock `server`, with a dummy api key.
    fn mock_endpoint(server: &MockServer) -> Endpoint {
        Endpoint {
            url: server.url("").parse().unwrap(),
            api_key: Some("test-key".into()),
            ..Default::default()
        }
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn test_zero_retries_on_error() {
        let server = MockServer::start();

        let mut mock_503 = server
            .mock_async(|_when, then| {
                then.status(503)
                    .header("content-type", "application/json")
                    .body(r#"{"status":"error"}"#);
            })
            .await;

        // We add this mock so that if a second request was made it would be a success and the
        // assertion below that the result is an error would fail.
        let _mock_202 = server
            .mock_async(|_when, then| {
                then.status(202)
                    .header("content-type", "application/json")
                    .body(r#"{"status":"ok"}"#);
            })
            .await;

        let target_endpoint = mock_endpoint(&server);
        let strategy = RetryStrategy::new(0, 2, RetryBackoffType::Constant, None);

        let client = libdd_common::http_common::new_default_client();
        let send_task = tokio::spawn(async move {
            let result = send_with_retry(
                &client,
                &target_endpoint,
                vec![0, 1, 2, 3],
                &HashMap::new(),
                &strategy,
            )
            .await;
            assert!(result.is_err(), "Expected an error result");
            assert!(
                matches!(result.unwrap_err(), SendWithRetryError::Http(_, 1)),
                "Expected an http error with one attempt"
            );
        });

        assert!(poll_for_mock_hit(&mut mock_503, 10, 100, 1, true).await);
        send_task.await.expect("send_with_retry task panicked");
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn test_retry_logic_error_then_success() {
        let server = MockServer::start();

        let mut mock_503 = server
            .mock_async(|_when, then| {
                then.status(503)
                    .header("content-type", "application/json")
                    .body(r#"{"status":"error"}"#);
            })
            .await;

        let mut mock_202 = server
            .mock_async(|_when, then| {
                then.status(202)
                    .header("content-type", "application/json")
                    .body(r#"{"status":"ok"}"#);
            })
            .await;

        let target_endpoint = mock_endpoint(&server);
        let strategy = RetryStrategy::new(2, 250, RetryBackoffType::Constant, None);

        let client = libdd_common::http_common::new_default_client();
        let send_task = tokio::spawn(async move {
            let result = send_with_retry(
                &client,
                &target_endpoint,
                vec![0, 1, 2, 3],
                &HashMap::new(),
                &strategy,
            )
            .await;
            assert!(
                matches!(result.unwrap(), (_, 2)),
                "Expected an ok result after two attempts"
            );
        });

        assert!(poll_for_mock_hit(&mut mock_503, 10, 100, 1, true).await);
        assert!(
            poll_for_mock_hit(&mut mock_202, 10, 100, 1, true).await,
            "Expected a retry request after a 5xx error"
        );
        send_task.await.expect("send_with_retry task panicked");
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn test_retry_logic_max_errors() {
        let server = MockServer::start();
        let expected_retry_attempts = 3;
        let mut mock_503 = server
            .mock_async(|_when, then| {
                then.status(503)
                    .header("content-type", "application/json")
                    .body(r#"{"status":"error"}"#);
            })
            .await;

        let target_endpoint = mock_endpoint(&server);
        let strategy = RetryStrategy::new(
            expected_retry_attempts,
            10,
            RetryBackoffType::Constant,
            None,
        );

        let client = libdd_common::http_common::new_default_client();
        let send_task = tokio::spawn(async move {
            let result = send_with_retry(
                &client,
                &target_endpoint,
                vec![0, 1, 2, 3],
                &HashMap::new(),
                &strategy,
            )
            .await;
            assert!(
                matches!(result.unwrap_err(), SendWithRetryError::Http(_, attempts) if attempts == expected_retry_attempts),
                "Expected an error result after max retry attempts"
            );
        });

        assert!(
            poll_for_mock_hit(
                &mut mock_503,
                10,
                100,
                expected_retry_attempts as usize,
                true
            )
            .await,
            "Expected max retry attempts"
        );
        send_task.await.expect("send_with_retry task panicked");
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn test_retry_logic_no_errors() {
        let server = MockServer::start();
        let mut mock_202 = server
            .mock_async(|_when, then| {
                then.status(202)
                    .header("content-type", "application/json")
                    .body(r#"{"status":"Ok"}"#);
            })
            .await;

        let target_endpoint = mock_endpoint(&server);
        let strategy = RetryStrategy::new(2, 10, RetryBackoffType::Constant, None);

        let client = libdd_common::http_common::new_default_client();
        let send_task = tokio::spawn(async move {
            let result = send_with_retry(
                &client,
                &target_endpoint,
                vec![0, 1, 2, 3],
                &HashMap::new(),
                &strategy,
            )
            .await;
            assert!(
                matches!(result, Ok((_, attempts)) if attempts == 1),
                "Expected an ok result after one attempt"
            );
        });

        assert!(
            poll_for_mock_hit(&mut mock_202, 10, 250, 1, true).await,
            "Expected only one request attempt"
        );
        send_task.await.expect("send_with_retry task panicked");
    }
}