use crate::{
ratelimit::{CacheableResponse, headers},
retry::RetryExt,
};
use dashmap::DashMap;
use governor::{
Quota, RateLimiter,
clock::DefaultClock,
state::{InMemoryState, NotKeyed},
};
use http::StatusCode;
use humantime_serde::re::humantime::format_duration;
use log::warn;
use reqwest::{Client as ReqwestClient, Request, Response as ReqwestResponse};
use std::time::{Duration, Instant};
use std::{num::NonZeroU32, sync::Mutex};
use tokio::sync::Semaphore;
use super::key::HostKey;
use super::stats::HostStats;
use crate::Uri;
use crate::types::Result;
use crate::{
ErrorKind,
ratelimit::{HostConfig, RateLimitConfig},
};
/// Upper bound on any server-requested backoff (e.g. via `Retry-After`),
/// protecting against broken or hostile servers demanding huge delays.
const MAXIMUM_BACKOFF: Duration = Duration::from_secs(60);

/// Per-host response cache, keyed by request URI.
type HostCache = DashMap<Uri, CacheableResponse>;
/// Per-host request executor combining interval-based rate limiting,
/// a concurrency cap, adaptive backoff, response caching, and statistics.
#[derive(Debug)]
pub struct Host {
    /// Identifier of the host this state belongs to.
    pub key: HostKey,
    /// Paces requests to this host; `None` disables interval-based limiting.
    rate_limiter: Option<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
    /// Bounds the number of concurrently executing requests to this host.
    semaphore: Semaphore,
    /// HTTP client used for all requests to this host.
    client: ReqwestClient,
    /// Aggregated response and cache-hit/miss statistics.
    stats: Mutex<HostStats>,
    /// Current adaptive backoff delay applied before each request.
    backoff_duration: Mutex<Duration>,
    /// Cached responses keyed by URI.
    cache: HostCache,
}
impl Host {
    /// Creates a new `Host` from per-host and global rate-limit settings.
    ///
    /// The request interval (pacing) and the maximum number of concurrent
    /// requests are resolved through the `effective_*` helpers, which combine
    /// `host_config` with `global_config`.
    #[must_use]
    pub fn new(
        key: HostKey,
        host_config: &HostConfig,
        global_config: &RateLimitConfig,
        client: ReqwestClient,
    ) -> Self {
        // Allow at most one request per interval tick (no burst allowance).
        const MAX_BURST: NonZeroU32 = NonZeroU32::new(1).unwrap();
        let interval = host_config.effective_request_interval(global_config);
        // `Quota::with_period` yields `None` for a zero period; in that case
        // no rate limiter is installed and requests are not paced.
        let rate_limiter =
            Quota::with_period(interval).map(|q| RateLimiter::direct(q.allow_burst(MAX_BURST)));
        let max_concurrent = host_config.effective_concurrency(global_config);
        let semaphore = Semaphore::new(max_concurrent);
        Host {
            key,
            rate_limiter,
            semaphore,
            client,
            stats: Mutex::new(HostStats::default()),
            backoff_duration: Mutex::new(Duration::ZERO),
            cache: DashMap::new(),
        }
    }

    /// Looks up a cached response for `uri`.
    ///
    /// When `needs_body` is set, a cached entry without a stored body does not
    /// count as a hit, because the caller requires the response text.
    fn get_cached_status(&self, uri: &Uri, needs_body: bool) -> Option<CacheableResponse> {
        let cached = self.cache.get(uri)?.clone();
        if needs_body && cached.text.is_none() {
            return None;
        }
        Some(cached)
    }

    fn record_cache_hit(&self) {
        self.stats.lock().unwrap().record_cache_hit();
    }

    fn record_cache_miss(&self) {
        self.stats.lock().unwrap().record_cache_miss();
    }

    /// Stores `response` in the cache unless its status is retryable; a
    /// retryable status may succeed later, so caching it would pin a failure.
    fn cache_result(&self, uri: &Uri, response: CacheableResponse) {
        if !response.status.should_retry() {
            self.cache.insert(uri.clone(), response);
        }
    }

    /// Executes `request` against this host, honoring the concurrency limit,
    /// adaptive backoff, and rate limiter, and consulting the cache.
    ///
    /// The cache is checked twice: once right after the concurrency permit is
    /// acquired, and again after the backoff/rate-limit waits, because a
    /// concurrent request for the same URI may have completed in the interim.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying HTTP request fails.
    pub(crate) async fn execute_request(
        &self,
        request: Request,
        needs_body: bool,
    ) -> Result<CacheableResponse> {
        // Cache per URI with the fragment stripped: fragments are client-side
        // only and do not affect the server's response.
        let mut url = request.url().clone();
        url.set_fragment(None);
        let uri = Uri::from(url);
        let _permit = self.acquire_semaphore().await;
        if let Some(cached) = self.get_cached_status(&uri, needs_body) {
            self.record_cache_hit();
            return Ok(cached);
        }
        self.await_backoff().await;
        if let Some(rate_limiter) = &self.rate_limiter {
            rate_limiter.until_ready().await;
        }
        // Re-check: another task may have populated the cache while we waited.
        if let Some(cached) = self.get_cached_status(&uri, needs_body) {
            self.record_cache_hit();
            return Ok(cached);
        }
        self.record_cache_miss();
        self.perform_request(request, uri, needs_body).await
    }

    /// Returns the HTTP client used for requests to this host.
    pub(crate) const fn get_client(&self) -> &ReqwestClient {
        &self.client
    }

    /// Sends the request, then updates stats, backoff state, and the cache
    /// based on the response.
    async fn perform_request(
        &self,
        request: Request,
        uri: Uri,
        needs_body: bool,
    ) -> Result<CacheableResponse> {
        let start_time = Instant::now();
        let response = self
            .client
            .execute(request)
            .await
            .map_err(ErrorKind::NetworkRequest)?;
        self.update_stats(response.status(), start_time.elapsed());
        self.update_backoff(response.status());
        self.handle_rate_limit_headers(&response);
        let response = CacheableResponse::from_response(response, needs_body).await?;
        self.cache_result(&uri, response.clone());
        Ok(response)
    }

    /// Sleeps for the current backoff duration, if any.
    async fn await_backoff(&self) {
        // Copy the duration out in a scoped block so the `std::sync::Mutex`
        // guard is dropped before awaiting; holding it across `.await` would
        // block other tasks for the whole sleep.
        let backoff_duration = {
            let backoff = self.backoff_duration.lock().unwrap();
            *backoff
        };
        if !backoff_duration.is_zero() {
            log::debug!(
                "Host {} applying backoff delay of {}ms due to previous rate limiting or errors",
                self.key,
                backoff_duration.as_millis()
            );
            tokio::time::sleep(backoff_duration).await;
        }
    }

    /// Acquires a concurrency permit for this host.
    ///
    /// # Panics
    ///
    /// Panics if the semaphore has been closed, which never happens here
    /// because `Host` never closes it.
    async fn acquire_semaphore(&self) -> tokio::sync::SemaphorePermit<'_> {
        self.semaphore
            .acquire()
            .await
            .expect("Semaphore was closed unexpectedly")
    }

    /// Adjusts the adaptive backoff based on the response status: success
    /// clears it, 429 doubles it (starting at 500ms, capped at 30s), and
    /// server errors grow it by 200ms (capped at 10s).
    fn update_backoff(&self, status: StatusCode) {
        let mut backoff = self.backoff_duration.lock().unwrap();
        match status.as_u16() {
            200..=299 => {
                *backoff = Duration::ZERO;
            }
            429 => {
                // Exponential backoff for explicit rate limiting.
                let new_backoff = std::cmp::min(
                    if backoff.is_zero() {
                        Duration::from_millis(500)
                    } else {
                        *backoff * 2
                    },
                    Duration::from_secs(30),
                );
                log::debug!(
                    "Host {} hit rate limit (429), increasing backoff from {}ms to {}ms",
                    self.key,
                    backoff.as_millis(),
                    new_backoff.as_millis()
                );
                *backoff = new_backoff;
            }
            500..=599 => {
                // Gentler, additive backoff for server errors.
                *backoff = std::cmp::min(
                    *backoff + Duration::from_millis(200),
                    Duration::from_secs(10),
                );
            }
            _ => {}
        }
    }

    fn update_stats(&self, status: StatusCode, request_time: Duration) {
        self.stats
            .lock()
            .unwrap()
            .record_response(status.as_u16(), request_time);
    }

    /// Inspects rate-limit-related response headers and raises the backoff
    /// accordingly.
    fn handle_rate_limit_headers(&self, response: &ReqwestResponse) {
        let headers = response.headers();
        self.handle_retry_after_header(headers);
        self.handle_common_rate_limit_header_fields(headers);
    }

    /// Applies a small pre-emptive backoff when the advertised rate-limit
    /// quota is more than 80% consumed, scaling linearly from 0ms at 80%
    /// usage up to 200ms at 100% usage.
    fn handle_common_rate_limit_header_fields(&self, headers: &http::HeaderMap) {
        if let (Some(remaining), Some(limit)) =
            headers::parse_common_rate_limit_header_fields(headers)
            && limit > 0
        {
            #[allow(clippy::cast_precision_loss)]
            let usage_ratio = limit.saturating_sub(remaining) as f64 / limit as f64;
            if usage_ratio > 0.8 {
                #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
                let duration = Duration::from_millis((200.0 * (usage_ratio - 0.8) / 0.2) as u64);
                self.increase_backoff(duration);
            }
        }
    }

    /// Honors an RFC 7231 `Retry-After` header by raising the backoff to the
    /// requested duration; unparseable values are logged and ignored.
    fn handle_retry_after_header(&self, headers: &http::HeaderMap) {
        if let Some(retry_after_value) = headers.get("retry-after") {
            let duration = match headers::parse_retry_after(retry_after_value) {
                Ok(duration) => duration,
                Err(e) => {
                    warn!("Unable to parse Retry-After header as per RFC 7231: {e}");
                    return;
                }
            };
            self.increase_backoff(duration);
        }
    }

    /// Raises the backoff to `increased_backoff` (never lowers it), capping
    /// the value at [`MAXIMUM_BACKOFF`].
    fn increase_backoff(&self, mut increased_backoff: Duration) {
        if increased_backoff > MAXIMUM_BACKOFF {
            warn!(
                "Host {} sent an unexpectedly big rate limit backoff duration of {}. Capping the duration to {} instead.",
                self.key,
                format_duration(increased_backoff),
                format_duration(MAXIMUM_BACKOFF)
            );
            increased_backoff = MAXIMUM_BACKOFF;
        }
        let mut backoff = self.backoff_duration.lock().unwrap();
        *backoff = std::cmp::max(*backoff, increased_backoff);
    }

    /// Returns a snapshot of this host's statistics.
    pub fn stats(&self) -> HostStats {
        self.stats.lock().unwrap().clone()
    }

    /// Records a hit served from a persistent cache layer; counted the same
    /// as an in-memory cache hit in the statistics.
    pub(crate) fn record_persistent_cache_hit(&self) {
        self.record_cache_hit();
    }

    /// Returns the number of responses currently cached for this host.
    pub fn cache_size(&self) -> usize {
        self.cache.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ratelimit::{HostConfig, RateLimitConfig};
    use reqwest::Client;

    /// A freshly constructed `Host` carries its key, the configured permit
    /// count, pristine stats, and an empty cache.
    #[tokio::test]
    async fn test_host_creation() {
        let key = HostKey::from("example.com");
        let host_config = HostConfig::default();
        let global_config = RateLimitConfig::default();
        let host = Host::new(key.clone(), &host_config, &global_config, Client::default());
        assert_eq!(host.key, key);
        // Default configuration grants 10 concurrent request permits.
        assert_eq!(host.semaphore.available_permits(), 10);
        // No requests recorded yet, so the success rate starts at 100%.
        assert!((host.stats().success_rate() - 1.0).abs() < f64::EPSILON);
        assert_eq!(host.cache_size(), 0);
    }
}