use std::time::Duration;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use tokio_retry::{
RetryIf,
strategy::{ExponentialBackoff, FixedInterval},
};
use super::{error::Error, misc::make_uuid_header_value};
/// Per-request configuration: retry policy, optional timeout and default headers.
///
/// Values are set through the consuming builder methods on this type and are read by
/// `make_request` when a request is sent.
#[derive(Debug, Clone)]
pub struct RequestSettings {
/// Number of delays taken from the retry strategy in `make_request`.
retries: usize,
/// Delay between retries for the fixed strategy, or the value handed to
/// `ExponentialBackoff::from_millis` when `exp_backoff` is set.
retry_time: Duration,
/// Optional upper bound passed to `ExponentialBackoff::max_delay`; only read when
/// `exp_backoff` is set.
max_retry_time: Option<Duration>,
/// Selects `ExponentialBackoff` (`true`) or `FixedInterval` (`false`) as retry strategy.
exp_backoff: bool,
/// Optional timeout applied to the request builder before the request is built.
timeout: Option<Duration>,
/// Headers merged into every request sent with these settings.
headers: HeaderMap,
}
impl Default for RequestSettings {
    /// Baseline configuration: three retries at a fixed one-second interval, no cap on
    /// the retry delay, no timeout and no extra headers.
    fn default() -> Self {
        let retry_time = Duration::from_secs(1);
        let headers = HeaderMap::new();
        Self {
            retries: 3,
            retry_time,
            max_retry_time: None,
            exp_backoff: false,
            timeout: None,
            headers,
        }
    }
}
impl RequestSettings {
    /// Sets how many delays are taken from the retry strategy, i.e. the maximum number
    /// of retries performed after the initial attempt.
    #[must_use]
    pub const fn retries(mut self, retries: usize) -> Self {
        self.retries = retries;
        self
    }

    /// Uses a constant delay of `interval` between retries.
    ///
    /// Disables exponential backoff and clears any previously configured maximum delay.
    #[must_use]
    pub const fn fixed_retry(mut self, interval: Duration) -> Self {
        self.exp_backoff = false;
        self.retry_time = interval;
        self.max_retry_time = None;
        self
    }

    /// Switches to exponential backoff seeded with `start_time` and optionally capped at
    /// `max_time`.
    ///
    /// NOTE(review): `start_time` is passed (in milliseconds) to
    /// `ExponentialBackoff::from_millis`, which tokio-retry documents as the *base* of a
    /// power series (`base^n` ms). A 1000 ms start would therefore grow 1 s -> 1000 s -> ...
    /// rather than doubling -- confirm this growth is intended.
    #[must_use]
    pub const fn exp_backoff(mut self, start_time: Duration, max_time: Option<Duration>) -> Self {
        self.exp_backoff = true;
        self.retry_time = start_time;
        self.max_retry_time = max_time;
        self
    }

    /// Sets the timeout applied to the request builder; with `None` the builder is left
    /// untouched.
    #[must_use]
    pub const fn timeout(mut self, timeout: Option<Duration>) -> Self {
        self.timeout = timeout;
        self
    }

    /// Registers a header that is merged into every request.
    ///
    /// A later call with the same name replaces the earlier value.
    #[must_use]
    pub fn header(mut self, header: HeaderName, value: HeaderValue) -> Self {
        self.headers.insert(header, value);
        self
    }

    /// Builds `request`, applies these settings and executes it with retries.
    ///
    /// The configured headers are merged with the headers already set on the request,
    /// `X-Correlation-Id` and `X-Request-Id` are filled in when absent, and both ids are
    /// recorded on the current tracing span.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be built, if it cannot be cloned for an
    /// attempt (`Error::RequestNotClone`), or if execution fails and either
    /// `Error::should_retry` declines a retry or the retry strategy is exhausted.
    #[tracing::instrument(level = "debug", skip_all, fields(x_correlation_id, x_request_id))]
    pub(crate) async fn make_request(
        &self,
        mut request: reqwest::RequestBuilder,
    ) -> Result<reqwest::Response, Error> {
        if let Some(timeout) = self.timeout {
            request = request.timeout(timeout);
        }

        let (client, request_result) = request.build_split();
        let mut request = request_result?;

        // Start from the configured defaults and let the request's own headers replace
        // them. The request's map is moved out instead of cloned because it is
        // overwritten with the merged map right after.
        let mut headers = self.headers.clone();
        headers.extend(std::mem::take(request.headers_mut()));
        *request.headers_mut() = headers;

        // Both ids default to the same freshly generated value unless already present.
        let id_value = make_uuid_header_value();
        request.headers_mut().entry("X-Correlation-Id").or_insert_with(|| id_value.clone());
        request.headers_mut().entry("X-Request-Id").or_insert(id_value);

        #[expect(
            clippy::cast_possible_truncation,
            reason = "That high backoff is unrealistic and not an issue"
        )]
        let retry_millis = self.retry_time.as_millis() as u64;
        let strategy: Box<dyn Iterator<Item = Duration> + Send + Sync> = if self.exp_backoff {
            let mut exp_backoff = ExponentialBackoff::from_millis(retry_millis);
            if let Some(max_backoff) = self.max_retry_time {
                exp_backoff = exp_backoff.max_delay(max_backoff);
            }
            Box::new(exp_backoff)
        } else {
            Box::new(FixedInterval::from_millis(retry_millis))
        };

        // Header values that are not visible ASCII are recorded as absent.
        let x_correlation_id =
            request.headers().get("X-Correlation-Id").and_then(|v| v.to_str().ok());
        let x_request_id = request.headers().get("X-Request-Id").and_then(|v| v.to_str().ok());
        tracing::Span::current()
            .record("x_correlation_id", x_correlation_id)
            .record("x_request_id", x_request_id);

        RetryIf::spawn(
            strategy.take(self.retries),
            async || {
                tracing::debug!("Sending {} request to {}", request.method(), request.url());
                // Every attempt needs its own copy because executing consumes the request.
                let request = request.try_clone().ok_or(Error::RequestNotClone)?;
                let response = client.execute(request).await?;
                tracing::debug!("Got response: {}", response.status());
                Ok(response)
            },
            Error::should_retry,
        )
        .await
    }
}