lychee_lib/ratelimit/
pool.rs1use dashmap::DashMap;
2use http::Method;
3use reqwest::{Client, Request};
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use crate::ratelimit::{
8 CacheableResponse, Host, HostConfigs, HostKey, HostStats, HostStatsMap, RateLimitConfig,
9};
10use crate::types::Result;
11use crate::{ErrorKind, Uri};
12
13pub type ClientMap = HashMap<HostKey, reqwest::Client>;
15
16#[derive(Debug)]
28pub struct HostPool {
29 hosts: DashMap<HostKey, Arc<Host>>,
31
32 global_config: RateLimitConfig,
34
35 host_configs: HostConfigs,
37
38 default_client: Client,
40
41 client_map: ClientMap,
43}
44
45impl HostPool {
46 #[must_use]
48 pub fn new(
49 global_config: RateLimitConfig,
50 host_configs: HostConfigs,
51 default_client: Client,
52 client_map: ClientMap,
53 ) -> Self {
54 Self {
55 hosts: DashMap::new(),
56 global_config,
57 host_configs,
58 default_client,
59 client_map,
60 }
61 }
62
63 pub(crate) async fn execute_request(
71 &self,
72 request: Request,
73 needs_body: bool,
74 ) -> Result<CacheableResponse> {
75 let url = request.url();
76 let host_key = HostKey::try_from(url)?;
77 let host = self.get_or_create_host(host_key);
78 host.execute_request(request, needs_body).await
79 }
80
81 pub fn build_request(&self, method: Method, uri: &Uri) -> Result<Request> {
89 let host_key = HostKey::try_from(uri)?;
90 let host = self.get_or_create_host(host_key);
91 host.get_client()
92 .request(method, uri.url.clone())
93 .build()
94 .map_err(ErrorKind::BuildRequestClient)
95 }
96
97 fn get_or_create_host(&self, host_key: HostKey) -> Arc<Host> {
99 self.hosts
100 .entry(host_key.clone())
101 .or_insert_with(|| {
102 let host_config = self
103 .host_configs
104 .get(&host_key)
105 .cloned()
106 .unwrap_or_default();
107
108 let client = self
109 .client_map
110 .get(&host_key)
111 .unwrap_or(&self.default_client)
112 .clone();
113
114 Arc::new(Host::new(
115 host_key,
116 &host_config,
117 &self.global_config,
118 client,
119 ))
120 })
121 .value()
122 .clone()
123 }
124
125 #[must_use]
128 pub fn host_stats(&self, hostname: &str) -> HostStats {
129 let host_key = HostKey::from(hostname);
130 self.hosts
131 .get(&host_key)
132 .map(|host| host.stats())
133 .unwrap_or_default()
134 }
135
136 #[must_use]
139 pub fn all_host_stats(&self) -> HostStatsMap {
140 HostStatsMap::from(
141 self.hosts
142 .iter()
143 .map(|entry| {
144 let hostname = entry.key().to_string();
145 let stats = entry.value().stats();
146 (hostname, stats)
147 })
148 .collect::<HashMap<_, _>>(),
149 )
150 }
151
152 #[must_use]
156 pub fn active_host_count(&self) -> usize {
157 self.hosts.len()
158 }
159
160 #[must_use]
163 pub fn host_configurations(&self) -> HostConfigs {
164 self.host_configs.clone()
165 }
166
167 #[must_use]
177 pub fn remove_host(&self, hostname: &str) -> bool {
178 let host_key = HostKey::from(hostname);
179 self.hosts.remove(&host_key).is_some()
180 }
181
182 #[must_use]
184 pub fn cache_stats(&self) -> HashMap<String, (usize, f64)> {
185 self.hosts
186 .iter()
187 .map(|entry| {
188 let hostname = entry.key().to_string();
189 let cache_size = entry.value().cache_size();
190 let hit_rate = entry.value().stats().cache_hit_rate();
191 (hostname, (cache_size, hit_rate))
192 })
193 .collect()
194 }
195
196 pub fn record_persistent_cache_hit(&self, uri: &crate::Uri) {
201 if !uri.is_file() && !uri.is_mail() {
202 match crate::ratelimit::HostKey::try_from(uri) {
203 Ok(key) => {
204 let host = self.get_or_create_host(key);
205 host.record_persistent_cache_hit();
206 }
207 Err(e) => {
208 log::debug!("Failed to record cache hit for {uri}: {e}");
209 }
210 }
211 }
212 }
213}
214
215impl Default for HostPool {
216 fn default() -> Self {
217 Self::new(
218 RateLimitConfig::default(),
219 HostConfigs::default(),
220 Client::default(),
221 HashMap::new(),
222 )
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ratelimit::RateLimitConfig;

    use url::Url;

    #[test]
    fn test_host_pool_creation() {
        let pool = HostPool::new(
            RateLimitConfig::default(),
            HostConfigs::default(),
            Client::default(),
            HashMap::new(),
        );

        assert_eq!(pool.active_host_count(), 0);
    }

    #[test]
    fn test_host_pool_default() {
        let pool = HostPool::default();
        assert_eq!(pool.active_host_count(), 0);
    }

    #[tokio::test]
    async fn test_host_creation_on_demand() {
        let pool = HostPool::default();
        let url: Url = "https://example.com/path".parse().unwrap();
        let host_key = HostKey::try_from(&url).unwrap();

        // Looking up stats must not create the host.
        assert_eq!(pool.active_host_count(), 0);
        assert_eq!(pool.host_stats("example.com").total_requests, 0);

        let host = pool.get_or_create_host(host_key);

        assert_eq!(pool.active_host_count(), 1);
        assert_eq!(pool.host_stats("example.com").total_requests, 0);
        assert_eq!(host.key.as_str(), "example.com");
    }

    #[tokio::test]
    async fn test_host_reuse() {
        let pool = HostPool::default();
        let url: Url = "https://example.com/path1".parse().unwrap();
        let host_key1 = HostKey::try_from(&url).unwrap();

        let url: Url = "https://example.com/path2".parse().unwrap();
        let host_key2 = HostKey::try_from(&url).unwrap();

        let host1 = pool.get_or_create_host(host_key1);
        assert_eq!(pool.active_host_count(), 1);

        // Second lookup for the same host must return the very same instance.
        let host2 = pool.get_or_create_host(host_key2);
        assert_eq!(pool.active_host_count(), 1);

        assert!(Arc::ptr_eq(&host1, &host2));
    }

    #[test]
    fn test_host_config_management() {
        let pool = HostPool::default();

        let configs = pool.host_configurations();
        assert_eq!(configs.len(), 0);
    }

    #[tokio::test]
    async fn test_host_removal() {
        let pool = HostPool::default();

        // Removing an unknown host is a no-op that reports `false`.
        assert!(!pool.remove_host("nonexistent.com"));

        // Removing a known host reports `true` and drops it from the pool.
        let url: Url = "https://example.com/path".parse().unwrap();
        let _host = pool.get_or_create_host(HostKey::try_from(&url).unwrap());
        assert_eq!(pool.active_host_count(), 1);

        assert!(pool.remove_host("example.com"));
        assert_eq!(pool.active_host_count(), 0);

        // A second removal finds nothing.
        assert!(!pool.remove_host("example.com"));
    }
}