alpaca-rest-http 0.25.1

Shared HTTP transport layer for the alpaca-rust workspace
Documentation
use std::time::Duration;

use reqwest::{Method, StatusCode};

#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RetryConfig {
    retryable_methods: Vec<Method>,
    max_retries: u32,
    retry_on_429: bool,
    respect_retry_after: bool,
    base_backoff: Duration,
    max_backoff: Duration,
    total_retry_budget: Option<Duration>,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RetryDecision {
    DoNotRetry,
    RetryAfter(Duration),
}

impl Default for RetryConfig {
    fn default() -> Self {
        Self {
            retryable_methods: vec![Method::GET],
            max_retries: 3,
            retry_on_429: false,
            respect_retry_after: false,
            base_backoff: Duration::from_millis(50),
            max_backoff: Duration::from_secs(5),
            total_retry_budget: None,
        }
    }
}

impl RetryConfig {
    #[must_use]
    pub fn with_retryable_methods<I>(mut self, methods: I) -> Self
    where
        I: IntoIterator<Item = Method>,
    {
        self.retryable_methods = methods.into_iter().collect();
        self
    }

    #[must_use]
    pub fn with_retry_on_429(mut self, retry_on_429: bool) -> Self {
        self.retry_on_429 = retry_on_429;
        self
    }

    #[must_use]
    pub fn with_respect_retry_after(mut self, respect_retry_after: bool) -> Self {
        self.respect_retry_after = respect_retry_after;
        self
    }

    #[must_use]
    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = max_retries;
        self
    }

    #[must_use]
    pub fn classify_response(
        &self,
        method: &Method,
        status: StatusCode,
        attempt: u32,
        retry_after: Option<Duration>,
        elapsed: Duration,
    ) -> RetryDecision {
        if attempt >= self.max_retries || !self.retryable_methods.iter().any(|item| item == method)
        {
            return RetryDecision::DoNotRetry;
        }

        let wait = if status == StatusCode::TOO_MANY_REQUESTS {
            if !self.retry_on_429 {
                return RetryDecision::DoNotRetry;
            }

            if self.respect_retry_after {
                retry_after.unwrap_or_else(|| self.backoff(attempt + 1))
            } else {
                self.backoff(attempt + 1)
            }
        } else if status.is_server_error() {
            self.backoff(attempt + 1)
        } else {
            return RetryDecision::DoNotRetry;
        };

        let wait = wait.min(self.max_backoff);

        if let Some(total_retry_budget) = self.total_retry_budget {
            let Some(remaining_budget) = total_retry_budget.checked_sub(elapsed) else {
                return RetryDecision::DoNotRetry;
            };
            if remaining_budget.is_zero() {
                return RetryDecision::DoNotRetry;
            }
            return RetryDecision::RetryAfter(wait.min(remaining_budget));
        }

        RetryDecision::RetryAfter(wait)
    }

    #[must_use]
    pub fn classify_transport_error(
        &self,
        method: &Method,
        attempt: u32,
        elapsed: Duration,
    ) -> RetryDecision {
        if attempt >= self.max_retries || !self.retryable_methods.iter().any(|item| item == method)
        {
            return RetryDecision::DoNotRetry;
        }

        let wait = self.backoff(attempt + 1).min(self.max_backoff);

        if let Some(total_retry_budget) = self.total_retry_budget {
            let Some(remaining_budget) = total_retry_budget.checked_sub(elapsed) else {
                return RetryDecision::DoNotRetry;
            };
            if remaining_budget.is_zero() {
                return RetryDecision::DoNotRetry;
            }
            return RetryDecision::RetryAfter(wait.min(remaining_budget));
        }

        RetryDecision::RetryAfter(wait)
    }

    fn backoff(&self, attempt: u32) -> Duration {
        let factor = 1u32
            .checked_shl(attempt.saturating_sub(1))
            .unwrap_or(u32::MAX);
        let millis = self.base_backoff.as_millis();
        let scaled = millis.saturating_mul(u128::from(factor));
        let bounded = scaled.min(self.max_backoff.as_millis());
        Duration::from_millis(u64::try_from(bounded).unwrap_or(u64::MAX))
    }
}