walker-common 0.17.0

Common functionality for SBOM and CSAF walker
Documentation
//! Fetching remote resources

mod data;
use backon::{ExponentialBuilder, Retryable};
pub use data::*;

use crate::http::{calculate_retry_after_from_response_header, get_client_error};
use reqwest::{Client, ClientBuilder, IntoUrl, Method, Response, StatusCode};
use std::fmt::Debug;
use std::future::Future;
use std::marker::PhantomData;
use std::time::Duration;
use url::Url;

/// Fetch data using HTTP.
///
/// This is some functionality sitting on top of an HTTP client, allowing for additional options
/// like retries.
#[derive(Clone, Debug)]
pub struct Fetcher {
    /// The HTTP client used to build and send requests.
    client: Client,
    /// Maximum number of retries, handed to the backoff builder for every fetch.
    retries: usize,
    /// *default_retry_after* is used when a 429 response does not include a Retry-After header
    default_retry_after: Duration,
}

/// Error when retrieving
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// A [`reqwest::Error`], e.g. from converting the URL, sending the request, or processing
    /// the response.
    #[error("Request error: {0}")]
    Request(#[from] reqwest::Error),
    /// The request was rate limited; carries the delay to wait before the next attempt.
    #[error("Rate limited (HTTP 429), retry after {0:?}")]
    RateLimited(Duration),
    /// Processing the response failed with a status that `get_client_error` classified as a
    /// client error. The fetcher does not retry these.
    #[error("Client error: {0}")]
    ClientError(StatusCode),
}

/// Options for the [`Fetcher`]
///
/// Used in builder style: start from [`FetcherOptions::new`] (or [`Default::default`]) and chain
/// the setters, each of which consumes and returns the options.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct FetcherOptions {
    /// Timeout for the HTTP client.
    timeout: Duration,
    /// Maximum number of retries.
    retries: usize,
    /// Delay used when a 429 response doesn't include a Retry-After header.
    default_retry_after: Duration,
    /// Upper bound which `default_retry_after` is validated against.
    max_retry_after: Duration,
}

impl FetcherOptions {
    /// Create a new instance, using the default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the timeout.
    pub fn timeout(mut self, timeout: impl Into<Duration>) -> Self {
        self.timeout = timeout.into();
        self
    }

    /// Set the number of retries.
    pub fn retries(mut self, retries: usize) -> Self {
        self.retries = retries;
        self
    }

    /// Set the default retry-after duration when a 429 response doesn't include a Retry-After header.
    ///
    /// # Panics
    ///
    /// Panics if `duration` is greater than the currently configured maximum retry-after
    /// (5 minutes unless changed through [`FetcherOptions::retry_after_with_max`]).
    pub fn retry_after(mut self, duration: Duration) -> Self {
        // Report the limit that is actually in effect: it may have been changed through
        // `retry_after_with_max`. With the default limit this renders as "(300s)".
        assert!(
            duration <= self.max_retry_after,
            "Default retry-after cannot be greater than max retry-after ({:?})",
            self.max_retry_after
        );
        self.default_retry_after = duration;
        self
    }

    /// Set both the default retry-after duration (used when a 429 response doesn't include a
    /// Retry-After header) and the maximum retry-after duration.
    ///
    /// # Panics
    ///
    /// Panics if `default` is greater than `max`.
    pub fn retry_after_with_max(mut self, default: Duration, max: Duration) -> Self {
        assert!(
            default <= max,
            "Default retry-after cannot be greater than max retry-after"
        );
        self.default_retry_after = default;
        self.max_retry_after = max;
        self
    }
}

impl Default for FetcherOptions {
    /// 30 second timeout, 5 retries, 10 second default retry-after, 5 minute max retry-after.
    fn default() -> Self {
        Self {
            timeout: Duration::from_secs(30),
            retries: 5,
            default_retry_after: Duration::from_secs(10),
            // 5 minutes
            max_retry_after: Duration::from_secs(5 * 60),
        }
    }
}

impl From<Client> for Fetcher {
    fn from(client: Client) -> Self {
        Self::with_client(client, FetcherOptions::default())
    }
}

impl Fetcher {
    /// Create a new downloader from options
    pub async fn new(options: FetcherOptions) -> anyhow::Result<Self> {
        let client = ClientBuilder::new().timeout(options.timeout);

        Ok(Self::with_client(client.build()?, options))
    }

    /// Create a fetcher providing an existing client.
    fn with_client(client: Client, options: FetcherOptions) -> Self {
        Self {
            client,
            retries: options.retries,
            default_retry_after: options.default_retry_after,
        }
    }

    async fn new_request(
        &self,
        method: Method,
        url: Url,
    ) -> Result<reqwest::RequestBuilder, reqwest::Error> {
        Ok(self.client.request(method, url))
    }

    /// fetch data, using a GET request.
    pub async fn fetch<D: Data>(&self, url: impl IntoUrl) -> Result<D, Error> {
        log::debug!("Fetching: {}", url.as_str());
        self.fetch_processed(url, TypedProcessor::<D>::new()).await
    }

    /// fetch data, using a GET request, processing the response data.
    pub async fn fetch_processed<D: DataProcessor>(
        &self,
        url: impl IntoUrl,
        processor: D,
    ) -> Result<D::Type, Error> {
        // if the URL building fails, there is no need to re-try, abort now.
        let url = url.into_url()?;

        let retries = self.retries;
        let retry = ExponentialBuilder::default().with_max_times(retries);

        (|| async { self.fetch_once(url.clone(), &processor).await })
            .retry(retry)
            .when(|e| !matches!(e, Error::ClientError(_)))
            .adjust(|e, dur| {
                if let Error::RateLimited(retry_after) = e {
                    if let Some(dur_value) = dur
                        && dur_value > *retry_after
                    {
                        return dur;
                    }
                    Some(*retry_after) // only use server-provided delay if it's longer
                } else {
                    dur // minimum delay as per backoff strategy
                }
            })
            .await
    }

    async fn fetch_once<D: DataProcessor>(
        &self,
        url: Url,
        processor: &D,
    ) -> Result<D::Type, Error> {
        let response = self.new_request(Method::GET, url).await?.send().await?;

        log::debug!("Response Status: {}", response.status());

        // Check for rate limiting
        if let Some(retry_after) =
            calculate_retry_after_from_response_header(&response, self.default_retry_after)
        {
            log::info!("Rate limited (429), retry after: {:?}", retry_after);
            return Err(Error::RateLimited(retry_after));
        }

        // Now test if we can convert the (possibly failed) response to result data.
        // This includes allowed for 404 becoming `None`.
        match processor.process(response).await {
            // Ok, return
            Ok(data) => Ok(data),
            // Error, extract client error
            Err(err) => {
                if let Some(status_code) = err.status().and_then(get_client_error) {
                    log::debug!("Client error: {status_code}");
                    Err(Error::ClientError(status_code))
                } else {
                    // or return other error as is
                    Err(err.into())
                }
            }
        }
    }
}

/// Processing data returned by a request.
pub trait DataProcessor {
    /// The value produced from a response.
    type Type: Sized;
    /// Consume the response and turn it into [`Self::Type`], or fail with a [`reqwest::Error`].
    fn process(
        &self,
        response: reqwest::Response,
    ) -> impl Future<Output = Result<Self::Type, reqwest::Error>>;
}

/// A [`DataProcessor`] producing a `D`. It stores no data, only the type marker.
struct TypedProcessor<D: Data> {
    /// Marker tying the processor to its output type.
    _marker: PhantomData<D>,
}

impl<D: Data> TypedProcessor<D> {
    /// Create a processor for the payload type `D`.
    pub const fn new() -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

/// Extract response payload which implements [`Data`].
impl<D: Data> DataProcessor for TypedProcessor<D> {
    type Type = D;

    async fn process(&self, response: Response) -> Result<Self::Type, reqwest::Error> {
        D::from_response(response).await
    }
}