github-readme-stats 0.2.0

Fetch GitHub user statistics as JSON
use anyhow::{Context, Result};
use reqwest::{StatusCode, header::HeaderMap};
use std::cmp::min;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::time::sleep;

const RETRY_MAX_ATTEMPTS: usize = 3;
const RETRY_BASE_DELAY_MS: u64 = 500;
const RETRY_MAX_DELAY_MS: u64 = 5000;

pub(crate) async fn send_with_retry<F>(
    mut make_request: F,
    context: &str,
) -> Result<reqwest::Response>
where
    F: FnMut() -> reqwest::RequestBuilder,
{
    let mut attempt = 0;
    loop {
        let response = make_request().send().await;
        match response {
            Ok(resp) => {
                if resp.status().is_success() {
                    return Ok(resp);
                }

                if attempt < RETRY_MAX_ATTEMPTS && should_retry(resp.status(), resp.headers()) {
                    let wait = retry_delay(attempt, resp.status(), resp.headers());
                    sleep(wait).await;
                    attempt += 1;
                    continue;
                }

                let status = resp.status();
                let body = resp.text().await.unwrap_or_default();
                anyhow::bail!("{context} failed with status {status}: {body}");
            }
            Err(err) => {
                if attempt < RETRY_MAX_ATTEMPTS {
                    let wait = retry_delay(
                        attempt,
                        StatusCode::INTERNAL_SERVER_ERROR,
                        &HeaderMap::new(),
                    );
                    sleep(wait).await;
                    attempt += 1;
                    continue;
                }
                return Err(err).with_context(|| format!("{context} failed to send request"));
            }
        }
    }
}

fn should_retry(status: StatusCode, headers: &HeaderMap) -> bool {
    if status == StatusCode::TOO_MANY_REQUESTS {
        return true;
    }
    if status.is_server_error() {
        return true;
    }
    if status == StatusCode::FORBIDDEN {
        if let Some(remaining) = header_value(headers, "x-ratelimit-remaining")
            && remaining == "0"
        {
            return true;
        }
        if headers.get("retry-after").is_some() {
            return true;
        }
    }
    false
}

pub(crate) fn retry_delay(attempt: usize, status: StatusCode, headers: &HeaderMap) -> Duration {
    if let Some(secs) = retry_after_seconds(status, headers) {
        return Duration::from_secs(secs);
    }
    let backoff = RETRY_BASE_DELAY_MS.saturating_mul(2u64.saturating_pow(attempt as u32));
    Duration::from_millis(min(backoff, RETRY_MAX_DELAY_MS))
}

fn retry_after_seconds(status: StatusCode, headers: &HeaderMap) -> Option<u64> {
    if status == StatusCode::TOO_MANY_REQUESTS || status == StatusCode::FORBIDDEN {
        if let Some(value) = header_value(headers, "retry-after")
            && let Ok(secs) = value.parse::<u64>()
        {
            return Some(secs);
        }
        if let Some(value) = header_value(headers, "x-ratelimit-reset")
            && let Ok(epoch) = value.parse::<u64>()
        {
            let now = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs();
            if epoch > now {
                return Some(epoch - now);
            }
        }
    }
    None
}

fn header_value(headers: &HeaderMap, name: &str) -> Option<String> {
    headers
        .get(name)
        .and_then(|v| v.to_str().ok())
        .map(|s| s.to_string())
}

#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::HeaderValue;

    #[test]
    fn retry_delay_prefers_retry_after() {
        let mut headers = HeaderMap::new();
        headers.insert("retry-after", HeaderValue::from_static("3"));
        let delay = retry_delay(0, StatusCode::TOO_MANY_REQUESTS, &headers);
        assert_eq!(delay.as_secs(), 3);
    }
}