use dashmap::DashMap;
use http::Method;
use reqwest::{Client, Request};
use std::collections::HashMap;
use std::sync::Arc;
use crate::ratelimit::{
CacheableResponse, Host, HostConfigs, HostKey, HostStats, HostStatsMap, RateLimitConfig,
};
use crate::types::Result;
use crate::{ErrorKind, Uri};
/// Per-host HTTP clients keyed by [`HostKey`]. Hosts without an entry fall
/// back to the pool's default client.
pub type ClientMap = HashMap<HostKey, reqwest::Client>;
/// Registry of per-host `Host` state, created lazily the first time a host is
/// seen. Requests are routed through the matching `Host`, which is built from
/// the global rate-limit configuration, an optional per-host configuration
/// override, and the HTTP client assigned to that host.
#[derive(Debug)]
pub struct HostPool {
    // One shared `Host` per host key; entries are inserted on first use.
    hosts: DashMap<HostKey, Arc<Host>>,
    // Rate-limit settings handed to every newly created `Host`.
    global_config: RateLimitConfig,
    // Per-host configuration overrides; hosts without an entry use the
    // default host configuration.
    host_configs: HostConfigs,
    // Client used for hosts that have no entry in `client_map`.
    default_client: Client,
    // Per-host clients that take precedence over `default_client`.
    client_map: ClientMap,
}
impl HostPool {
    /// Creates a pool that tracks no hosts yet; per-host state is created
    /// lazily on first use.
    ///
    /// * `global_config` – rate-limit settings passed to every `Host`.
    /// * `host_configs` – per-host overrides looked up by `HostKey`; hosts
    ///   without an entry use the default host configuration.
    /// * `default_client` – HTTP client for hosts absent from `client_map`.
    /// * `client_map` – per-host HTTP clients.
    #[must_use]
    pub fn new(
        global_config: RateLimitConfig,
        host_configs: HostConfigs,
        default_client: Client,
        client_map: ClientMap,
    ) -> Self {
        Self {
            hosts: DashMap::new(),
            global_config,
            host_configs,
            default_client,
            client_map,
        }
    }

    /// Sends `request` through the `Host` responsible for its URL, creating
    /// that host on first use. `needs_body` is forwarded unchanged to the
    /// host.
    ///
    /// # Errors
    ///
    /// Returns an error if no `HostKey` can be derived from the request URL,
    /// and propagates any error returned by the host's `execute_request`.
    pub(crate) async fn execute_request(
        &self,
        request: Request,
        needs_body: bool,
    ) -> Result<CacheableResponse> {
        let host_key = HostKey::try_from(request.url())?;
        let host = self.get_or_create_host(host_key);
        host.execute_request(request, needs_body).await
    }

    /// Builds a `method` request for `uri` using the HTTP client assigned to
    /// the URI's host, creating that host on first use.
    ///
    /// # Errors
    ///
    /// Returns an error if no `HostKey` can be derived from `uri`, or
    /// `ErrorKind::BuildRequestClient` if `reqwest` fails to build the request.
    pub fn build_request(&self, method: Method, uri: &Uri) -> Result<Request> {
        let host_key = HostKey::try_from(uri)?;
        let host = self.get_or_create_host(host_key);
        host.get_client()
            .request(method, uri.url.clone())
            .build()
            .map_err(ErrorKind::BuildRequestClient)
    }

    /// Returns the shared `Host` for `host_key`, creating it if absent.
    ///
    /// The new host uses the per-host config for `host_key`, or the default
    /// config if there is none. It uses the client from `client_map`, falling
    /// back to `default_client`.
    fn get_or_create_host(&self, host_key: HostKey) -> Arc<Host> {
        // Fast path: nearly every call hits an existing host. A shared `get`
        // avoids taking the shard's write lock (which `entry` always does) and
        // cloning the key. `map` consumes the read guard, so it is released
        // before `entry` below; holding a DashMap reference while calling a
        // write method on the same map may deadlock.
        let existing = self
            .hosts
            .get(&host_key)
            .map(|entry| Arc::clone(entry.value()));
        if let Some(host) = existing {
            return host;
        }

        // Slow path: `entry` re-checks under the write lock, so callers racing
        // to create the same host still end up sharing a single `Host`.
        self.hosts
            .entry(host_key.clone())
            .or_insert_with(|| {
                let host_config = self
                    .host_configs
                    .get(&host_key)
                    .cloned()
                    .unwrap_or_default();
                let client = self
                    .client_map
                    .get(&host_key)
                    .unwrap_or(&self.default_client)
                    .clone();
                Arc::new(Host::new(
                    host_key,
                    &host_config,
                    &self.global_config,
                    client,
                ))
            })
            .value()
            .clone()
    }

    /// Returns the statistics recorded for `hostname`, or
    /// `HostStats::default()` if no host with that key has been created.
    ///
    /// The lookup key is `HostKey::from(hostname)`, so `hostname` is expected
    /// to be in the same form as keys derived from request URLs (e.g.
    /// `example.com`).
    #[must_use]
    pub fn host_stats(&self, hostname: &str) -> HostStats {
        let host_key = HostKey::from(hostname);
        self.hosts
            .get(&host_key)
            .map(|host| host.stats())
            .unwrap_or_default()
    }

    /// Returns a snapshot of the statistics of every host created so far,
    /// keyed by the string form of each host key.
    #[must_use]
    pub fn all_host_stats(&self) -> HostStatsMap {
        HostStatsMap::from(
            self.hosts
                .iter()
                .map(|entry| {
                    let hostname = entry.key().to_string();
                    let stats = entry.value().stats();
                    (hostname, stats)
                })
                .collect::<HashMap<_, _>>(),
        )
    }

    /// Number of hosts currently tracked by the pool.
    #[must_use]
    pub fn active_host_count(&self) -> usize {
        self.hosts.len()
    }

    /// Returns a copy of the per-host configuration overrides.
    #[must_use]
    pub fn host_configurations(&self) -> HostConfigs {
        self.host_configs.clone()
    }

    /// Stops tracking `hostname` and returns `true` if a host was removed.
    ///
    /// The next request to that host creates a new `Host`. Callers already
    /// holding the old host's `Arc` keep using it until they drop it.
    #[must_use]
    pub fn remove_host(&self, hostname: &str) -> bool {
        let host_key = HostKey::from(hostname);
        self.hosts.remove(&host_key).is_some()
    }

    /// Returns, for each tracked host, `(cache_size, cache_hit_rate)` as
    /// reported by that host, keyed by the string form of the host key.
    #[must_use]
    pub fn cache_stats(&self) -> HashMap<String, (usize, f64)> {
        self.hosts
            .iter()
            .map(|entry| {
                let host = entry.value();
                let hostname = entry.key().to_string();
                let cache_size = host.cache_size();
                let hit_rate = host.stats().cache_hit_rate();
                (hostname, (cache_size, hit_rate))
            })
            .collect()
    }

    /// Records a persistent-cache hit against the host of `uri`, creating the
    /// host on first use.
    ///
    /// File and mail URIs are skipped (presumably because they are not tied to
    /// a network host). If no `HostKey` can be derived, the failure is only
    /// logged at debug level, because this bookkeeping is best-effort.
    pub fn record_persistent_cache_hit(&self, uri: &Uri) {
        if uri.is_file() || uri.is_mail() {
            return;
        }
        match HostKey::try_from(uri) {
            Ok(key) => {
                let host = self.get_or_create_host(key);
                host.record_persistent_cache_hit();
            }
            Err(e) => {
                log::debug!("Failed to record cache hit for {uri}: {e}");
            }
        }
    }
}
impl Default for HostPool {
    /// Builds an empty pool with default rate limits, no per-host overrides,
    /// a default `reqwest` client, and no per-host clients.
    fn default() -> Self {
        let global_config = RateLimitConfig::default();
        let host_configs = HostConfigs::default();
        let default_client = Client::default();
        let client_map = HashMap::new();
        Self::new(global_config, host_configs, default_client, client_map)
    }
}
#[cfg(test)]
mod tests {
    // `super::*` already brings in `RateLimitConfig`, `HostConfigs`, `Client`,
    // `HostKey`, `HashMap` and `Arc`.
    use super::*;
    use url::Url;

    #[test]
    fn test_host_pool_creation() {
        let pool = HostPool::new(
            RateLimitConfig::default(),
            HostConfigs::default(),
            Client::default(),
            HashMap::new(),
        );
        assert_eq!(pool.active_host_count(), 0);
    }

    #[test]
    fn test_host_pool_default() {
        let pool = HostPool::default();
        assert_eq!(pool.active_host_count(), 0);
    }

    #[tokio::test]
    async fn test_host_creation_on_demand() {
        let pool = HostPool::default();
        let url: Url = "https://example.com/path".parse().unwrap();
        let host_key = HostKey::try_from(&url).unwrap();
        assert_eq!(pool.active_host_count(), 0);
        assert_eq!(pool.host_stats("example.com").total_requests, 0);
        let host = pool.get_or_create_host(host_key);
        assert_eq!(pool.active_host_count(), 1);
        assert_eq!(pool.host_stats("example.com").total_requests, 0);
        assert_eq!(host.key.as_str(), "example.com");
    }

    #[tokio::test]
    async fn test_host_reuse() {
        let pool = HostPool::default();
        let url: Url = "https://example.com/path1".parse().unwrap();
        let host_key1 = HostKey::try_from(&url).unwrap();
        let url: Url = "https://example.com/path2".parse().unwrap();
        let host_key2 = HostKey::try_from(&url).unwrap();
        let host1 = pool.get_or_create_host(host_key1);
        assert_eq!(pool.active_host_count(), 1);
        let host2 = pool.get_or_create_host(host_key2);
        assert_eq!(pool.active_host_count(), 1);
        assert!(Arc::ptr_eq(&host1, &host2));
    }

    #[tokio::test]
    async fn test_distinct_hosts_get_distinct_state() {
        let pool = HostPool::default();
        let url_a: Url = "https://example.com/".parse().unwrap();
        let url_b: Url = "https://example.org/".parse().unwrap();
        let host_a = pool.get_or_create_host(HostKey::try_from(&url_a).unwrap());
        let host_b = pool.get_or_create_host(HostKey::try_from(&url_b).unwrap());
        assert_eq!(pool.active_host_count(), 2);
        assert!(!Arc::ptr_eq(&host_a, &host_b));
    }

    #[test]
    fn test_host_config_management() {
        let pool = HostPool::default();
        let configs = pool.host_configurations();
        assert_eq!(configs.len(), 0);
    }

    #[test]
    fn test_host_removal() {
        let pool = HostPool::default();
        assert!(!pool.remove_host("nonexistent.com"));
    }

    #[tokio::test]
    async fn test_host_removal_existing() {
        let pool = HostPool::default();
        let url: Url = "https://example.com/path".parse().unwrap();
        let _host = pool.get_or_create_host(HostKey::try_from(&url).unwrap());
        assert_eq!(pool.active_host_count(), 1);

        assert!(pool.remove_host("example.com"));
        assert_eq!(pool.active_host_count(), 0);
        // Removing again finds nothing.
        assert!(!pool.remove_host("example.com"));
    }
}