Skip to main content

libdd_trace_utils/send_with_retry/
mod.rs

1// Copyright 2024-Present Datadog, Inc. https://www.datadoghq.com/
2// SPDX-License-Identifier: Apache-2.0
3
4//! Provide [`send_with_retry`] utility to send a payload to an [`Endpoint`] with retries if the
5//! request fails.
6
7mod retry_strategy;
8pub use retry_strategy::{RetryBackoffType, RetryStrategy};
9
10use bytes::Bytes;
11use http::{HeaderMap, Method};
12use libdd_common::{http_common, Connect, Endpoint, GenericHttpClient, HttpRequestBuilder};
13use std::time::Duration;
14use tracing::{debug, error};
15
/// Number of request attempts made by [`send_with_retry`]; the first attempt counts as 1.
pub type Attempts = u32;
17
/// Outcome of [`send_with_retry`]: on success the response together with the number of attempts
/// that were made, otherwise a [`SendWithRetryError`] describing the last failure.
pub type SendWithRetryResult = Result<(http_common::HttpResponse, Attempts), SendWithRetryError>;
19
/// Error returned by [`send_with_retry`] once no further attempt is made.
///
/// All errors contain the number of attempts after which the final error was returned
#[derive(Debug)]
pub enum SendWithRetryError {
    /// The last response carried a client (4xx) or server (5xx) error status; the response
    /// itself is returned so the caller can inspect it.
    Http(http_common::HttpResponse, Attempts),
    /// The last attempt did not complete within the timeout applied to each request.
    Timeout(Attempts),
    /// The HTTP client reported an error while executing the last attempt.
    Network(http_common::ClientError, Attempts),
    /// The request could not be built (either the request builder or its body was rejected).
    Build(Attempts),
}
32
33impl std::fmt::Display for SendWithRetryError {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        match self {
36            SendWithRetryError::Http(_, _) => write!(f, "Http error code received"),
37            SendWithRetryError::Timeout(_) => write!(f, "Request timed out"),
38            SendWithRetryError::Network(error, _) => write!(f, "Network error: {error}"),
39            SendWithRetryError::Build(_) => {
40                write!(f, "Failed to build request due to invalid property")
41            }
42        }
43    }
44}
45
// Marker impl: relies on the `Debug` and `Display` impls above and does not expose a `source()`.
impl std::error::Error for SendWithRetryError {}
47
48impl SendWithRetryError {
49    fn from_request_error(err: RequestError, request_attempt: Attempts) -> Self {
50        match err {
51            RequestError::Build => SendWithRetryError::Build(request_attempt),
52            RequestError::Network(error) => SendWithRetryError::Network(error, request_attempt),
53            RequestError::TimeoutApi => SendWithRetryError::Timeout(request_attempt),
54        }
55    }
56}
57
/// Failure of a single request attempt, before the attempt count is attached to it.
#[derive(Debug)]
enum RequestError {
    /// The payload could not be attached to the request builder.
    Build,
    /// The HTTP client returned an error while executing the request.
    Network(http_common::ClientError),
    /// The request did not complete within the per-request timeout.
    TimeoutApi,
}
64
65impl std::fmt::Display for RequestError {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        match self {
68            RequestError::TimeoutApi => write!(f, "Api timeout exhausted"),
69            RequestError::Network(error) => write!(f, "Network error: {error}"),
70            RequestError::Build => write!(f, "Failed to build request due to invalid property"),
71        }
72    }
73}
74
// Marker impl: relies on the `Debug` and `Display` impls above and does not expose a `source()`.
impl std::error::Error for RequestError {}
76
77/// Send the `payload` with a POST request to `target` using the provided `retry_strategy` if the
78/// request fails.
79///
80/// The request builder from [`Endpoint::to_request_builder`] is used with the associated headers
81/// (api key, test token), and `headers` are added to the request. The request is executed with a
82/// timeout of [`Endpoint::timeout_ms`].
83///
84/// # Arguments
85///
86/// # Returns
87///
88/// Return a [`SendWithRetryResult`] containing the response and the number of attempts or an error
89/// describing the last attempt failure.
90///
91/// # Errors
92/// Fail if the request didn't succeed after applying the retry strategy.
93///
94/// # Example
95///
96/// ```rust, no_run
97/// # use libdd_common::Endpoint;
98/// # use libdd_common::http_common::new_default_client;
99/// # use libdd_trace_utils::send_with_retry::*;
100/// # async fn run() -> SendWithRetryResult {
101/// let payload: Vec<u8> = vec![0, 1, 2, 3];
102/// let target = Endpoint {
103///     url: "localhost:8126/v04/traces".parse::<hyper::Uri>().unwrap(),
104///     ..Endpoint::default()
105/// };
106/// let mut headers = http::HeaderMap::new();
107/// headers.insert(
108///     http::HeaderName::from_static("content-type"),
109///     http::HeaderValue::from_static("application/msgpack"),
110/// );
111/// let retry_strategy = RetryStrategy::new(3, 10, RetryBackoffType::Exponential, Some(5));
112/// let client = new_default_client();
113/// send_with_retry(&client, &target, payload, &headers, &retry_strategy).await
114/// # }
115/// ```
116pub async fn send_with_retry<C: Connect>(
117    client: &GenericHttpClient<C>,
118    target: &Endpoint,
119    payload: Vec<u8>,
120    headers: &HeaderMap,
121    retry_strategy: &RetryStrategy,
122) -> SendWithRetryResult {
123    let mut request_attempt = 0;
124    // Wrap the payload in Bytes to avoid expensive clone between retries
125    let payload = Bytes::from(payload);
126
127    debug!(
128        url = %target.url,
129        payload_size = payload.len(),
130        max_retries = retry_strategy.max_retries(),
131        "Sending with retry"
132    );
133
134    loop {
135        request_attempt += 1;
136
137        debug!(
138            attempt = request_attempt,
139            max_retries = retry_strategy.max_retries(),
140            "Attempting request"
141        );
142
143        let mut req = target
144            .to_request_builder(concat!("Tracer/", env!("CARGO_PKG_VERSION")))
145            .or(Err(SendWithRetryError::Build(request_attempt)))?
146            .method(Method::POST);
147        for (key, value) in headers {
148            req = req.header(key, value);
149        }
150
151        match send_request(
152            client,
153            Duration::from_millis(target.timeout_ms),
154            req,
155            payload.clone(),
156        )
157        .await
158        {
159            // An Ok response doesn't necessarily mean the request was successful, we need to
160            // check the status code and if it's not a 2xx or 3xx we treat it as an error
161            Ok(response) => {
162                let status = response.status();
163                debug!(status = %status, attempt = request_attempt, "Received response");
164
165                if status.is_client_error() || status.is_server_error() {
166                    debug!(
167                        status = %status,
168                        attempt = request_attempt,
169                        max_retries = retry_strategy.max_retries(),
170                        "Received error status code"
171                    );
172
173                    if request_attempt < retry_strategy.max_retries() {
174                        debug!(
175                            attempt = request_attempt,
176                            remaining_retries = retry_strategy.max_retries() - request_attempt,
177                            "Retrying after error status code"
178                        );
179                        retry_strategy.delay(request_attempt).await;
180                        continue;
181                    } else {
182                        error!(
183                            status = %status,
184                            attempts = request_attempt,
185                            "Max retries exceeded, returning HTTP error"
186                        );
187                        return Err(SendWithRetryError::Http(response, request_attempt));
188                    }
189                } else {
190                    debug!(
191                        status = %status,
192                        attempts = request_attempt,
193                        "Request succeeded"
194                    );
195                    return Ok((response, request_attempt));
196                }
197            }
198            Err(e) => {
199                debug!(
200                    error = ?e,
201                    attempt = request_attempt,
202                    max_retries = retry_strategy.max_retries(),
203                    "Request failed with error"
204                );
205
206                if request_attempt < retry_strategy.max_retries() {
207                    debug!(
208                        attempt = request_attempt,
209                        remaining_retries = retry_strategy.max_retries() - request_attempt,
210                        "Retrying after request error"
211                    );
212                    retry_strategy.delay(request_attempt).await;
213                    continue;
214                } else {
215                    error!(
216                        error = ?e,
217                        attempts = request_attempt,
218                        "Max retries exceeded, returning request error"
219                    );
220                    return Err(SendWithRetryError::from_request_error(e, request_attempt));
221                }
222            }
223        }
224    }
225}
226
227async fn send_request<C: Connect>(
228    client: &GenericHttpClient<C>,
229    timeout: Duration,
230    req: HttpRequestBuilder,
231    payload: Bytes,
232) -> Result<http_common::HttpResponse, RequestError> {
233    let req = req
234        .body(http_common::Body::from_bytes(payload))
235        .or(Err(RequestError::Build))?;
236
237    let req_future = { client.request(req) };
238
239    match tokio::time::timeout(timeout, req_future).await {
240        Ok(resp) => match resp {
241            Ok(body) => Ok(http_common::into_response(body)),
242            Err(e) => Err(RequestError::Network(http_common::into_error(e))),
243        },
244        Err(_) => Err(RequestError::TimeoutApi),
245    }
246}
247
#[cfg(test)]
mod tests {
    //! Tests for [`send_with_retry`] against a local mock HTTP server.
    //!
    //! Each test runs `send_with_retry` in a spawned task while the test body polls the mock
    //! server for the expected hits. The task's `JoinHandle` is awaited at the end of every test:
    //! a dropped handle would detach the task, so a failed assertion inside it would go unnoticed
    //! and the task might not even finish before the test returns.

    use super::*;
    use crate::test_utils::poll_for_mock_hit;
    use httpmock::MockServer;

    /// With a strategy of zero retries a single request is made and its HTTP error is returned.
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn test_zero_retries_on_error() {
        let server = MockServer::start();

        let mut mock_503 = server
            .mock_async(|_when, then| {
                then.status(503)
                    .header("content-type", "application/json")
                    .body(r#"{"status":"error"}"#);
            })
            .await;

        // We add this mock so that if a second request was made it would be a success and our
        // assertion below that last_result is an error would fail.
        let _mock_202 = server
            .mock_async(|_when, then| {
                then.status(202)
                    .header("content-type", "application/json")
                    .body(r#"{"status":"ok"}"#);
            })
            .await;

        let target_endpoint = Endpoint {
            url: server.url("").to_owned().parse().unwrap(),
            api_key: Some("test-key".into()),
            ..Default::default()
        };

        let strategy = RetryStrategy::new(0, 2, RetryBackoffType::Constant, None);

        let client = libdd_common::http_common::new_default_client();
        let send_task = tokio::spawn(async move {
            let result = send_with_retry(
                &client,
                &target_endpoint,
                vec![0, 1, 2, 3],
                &HeaderMap::new(),
                &strategy,
            )
            .await;
            assert!(result.is_err(), "Expected an error result");
            assert!(
                matches!(result.unwrap_err(), SendWithRetryError::Http(_, 1)),
                "Expected an http error with one attempt"
            );
        });

        assert!(poll_for_mock_hit(&mut mock_503, 10, 100, 1, true).await);

        // Surface any assertion failure raised inside the spawned task.
        send_task.await.expect("send_with_retry task panicked");
    }

    /// A first attempt answered with a 503 is retried and the second attempt succeeds.
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn test_retry_logic_error_then_success() {
        let server = MockServer::start();

        let mut mock_503 = server
            .mock_async(|_when, then| {
                then.status(503)
                    .header("content-type", "application/json")
                    .body(r#"{"status":"error"}"#);
            })
            .await;

        let mut mock_202 = server
            .mock_async(|_when, then| {
                then.status(202)
                    .header("content-type", "application/json")
                    .body(r#"{"status":"ok"}"#);
            })
            .await;

        let target_endpoint = Endpoint {
            url: server.url("").to_owned().parse().unwrap(),
            api_key: Some("test-key".into()),
            ..Default::default()
        };

        let strategy = RetryStrategy::new(2, 250, RetryBackoffType::Constant, None);

        let client = libdd_common::http_common::new_default_client();
        let send_task = tokio::spawn(async move {
            let result = send_with_retry(
                &client,
                &target_endpoint,
                vec![0, 1, 2, 3],
                &HeaderMap::new(),
                &strategy,
            )
            .await;
            assert!(
                matches!(result.unwrap(), (_, 2)),
                "Expected an ok result after two attempts"
            );
        });

        assert!(poll_for_mock_hit(&mut mock_503, 10, 100, 1, true).await);
        assert!(
            poll_for_mock_hit(&mut mock_202, 10, 100, 1, true).await,
            "Expected a retry request after a 5xx error"
        );

        // Surface any assertion failure raised inside the spawned task.
        send_task.await.expect("send_with_retry task panicked");
    }

    /// When every attempt is answered with a 503 the HTTP error is returned after the configured
    /// number of attempts.
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn test_retry_logic_max_errors() {
        let server = MockServer::start();
        let expected_retry_attempts = 3;
        let mut mock_503 = server
            .mock_async(|_when, then| {
                then.status(503)
                    .header("content-type", "application/json")
                    .body(r#"{"status":"error"}"#);
            })
            .await;

        let target_endpoint = Endpoint {
            url: server.url("").to_owned().parse().unwrap(),
            api_key: Some("test-key".into()),
            ..Default::default()
        };

        let strategy = RetryStrategy::new(
            expected_retry_attempts,
            10,
            RetryBackoffType::Constant,
            None,
        );

        let client = libdd_common::http_common::new_default_client();
        let send_task = tokio::spawn(async move {
            let result = send_with_retry(
                &client,
                &target_endpoint,
                vec![0, 1, 2, 3],
                &HeaderMap::new(),
                &strategy,
            )
            .await;
            assert!(
                matches!(result.unwrap_err(), SendWithRetryError::Http(_, attempts) if attempts == expected_retry_attempts),
                "Expected an error result after max retry attempts"
            );
        });

        assert!(
            poll_for_mock_hit(
                &mut mock_503,
                10,
                100,
                expected_retry_attempts as usize,
                true
            )
            .await,
            "Expected max retry attempts"
        );

        // Surface any assertion failure raised inside the spawned task.
        send_task.await.expect("send_with_retry task panicked");
    }

    /// A request answered with a 202 succeeds on the first attempt.
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn test_retry_logic_no_errors() {
        let server = MockServer::start();
        let mut mock_202 = server
            .mock_async(|_when, then| {
                then.status(202)
                    .header("content-type", "application/json")
                    .body(r#"{"status":"Ok"}"#);
            })
            .await;

        let target_endpoint = Endpoint {
            url: server.url("").to_owned().parse().unwrap(),
            api_key: Some("test-key".into()),
            ..Default::default()
        };

        let strategy = RetryStrategy::new(2, 10, RetryBackoffType::Constant, None);

        let client = libdd_common::http_common::new_default_client();
        let send_task = tokio::spawn(async move {
            let result = send_with_retry(
                &client,
                &target_endpoint,
                vec![0, 1, 2, 3],
                &HeaderMap::new(),
                &strategy,
            )
            .await;
            assert!(
                matches!(result, Ok((_, attempts)) if attempts == 1),
                "Expected an ok result after one attempts"
            );
        });

        assert!(
            poll_for_mock_hit(&mut mock_202, 10, 250, 1, true).await,
            "Expected only one request attempt"
        );

        // Surface any assertion failure raised inside the spawned task.
        send_task.await.expect("send_with_retry task panicked");
    }
}