1use crate::error::{DepsError, Result};
2use bytes::Bytes;
3use dashmap::DashMap;
4use reqwest::{Client, StatusCode, header};
5use std::time::Instant;
6
/// Entry count at which the cache starts evicting before taking on more data.
const MAX_CACHE_ENTRIES: usize = 1_000;

/// Per-request timeout, in seconds, applied to the shared HTTP client.
const HTTP_TIMEOUT_SECS: u64 = 30;

/// Share of `MAX_CACHE_ENTRIES`, in percent, dropped by one eviction pass
/// (10% of 1000 = 100 entries).
const CACHE_EVICTION_PERCENTAGE: usize = 10;
15
16#[inline]
23fn ensure_https(url: &str) -> Result<()> {
24 #[cfg(not(test))]
25 if !url.starts_with("https://") {
26 return Err(DepsError::CacheError(format!("URL must use HTTPS: {url}")));
27 }
28 #[cfg(test)]
29 let _ = url; Ok(())
31}
32
33#[derive(Debug, Clone)]
58pub struct CachedResponse {
59 pub body: Bytes,
60 pub etag: Option<String>,
61 pub last_modified: Option<String>,
62 pub fetched_at: Instant,
63}
64
65pub struct HttpCache {
93 entries: DashMap<String, CachedResponse>,
94 client: Client,
95}
96
97impl HttpCache {
98 pub fn new() -> Self {
103 let client = Client::builder()
104 .user_agent(format!("deps-lsp/{}", env!("CARGO_PKG_VERSION")))
105 .timeout(std::time::Duration::from_secs(HTTP_TIMEOUT_SECS))
106 .build()
107 .expect("failed to create HTTP client");
108
109 Self {
110 entries: DashMap::new(),
111 client,
112 }
113 }
114
115 pub async fn get_cached(&self, url: &str) -> Result<Bytes> {
148 if self.entries.len() >= MAX_CACHE_ENTRIES {
150 self.evict_entries();
151 }
152
153 if let Some(cached) = self.entries.get(url).map(|r| r.clone()) {
154 match self.conditional_request(url, &cached).await {
158 Ok(Some(new_body)) => {
159 return Ok(new_body);
160 }
161 Ok(None) => {
162 return Ok(cached.body);
163 }
164 Err(e) => {
165 tracing::warn!("conditional request failed, using cache: {e}");
166 return Ok(cached.body);
167 }
168 }
169 }
170
171 self.fetch_and_store(url).await
173 }
174
175 async fn conditional_request(
186 &self,
187 url: &str,
188 cached: &CachedResponse,
189 ) -> Result<Option<Bytes>> {
190 ensure_https(url)?;
191 let mut request = self.client.get(url);
192
193 if let Some(etag) = &cached.etag {
194 request = request.header(header::IF_NONE_MATCH, etag);
195 }
196 if let Some(last_modified) = &cached.last_modified {
197 request = request.header(header::IF_MODIFIED_SINCE, last_modified);
198 }
199
200 let response = request.send().await.map_err(|e| DepsError::RegistryError {
201 package: url.to_string(),
202 source: e,
203 })?;
204
205 if response.status() == StatusCode::NOT_MODIFIED {
206 return Ok(None);
208 }
209
210 let etag = response
212 .headers()
213 .get(header::ETAG)
214 .and_then(|v| v.to_str().ok())
215 .map(String::from);
216
217 let last_modified = response
218 .headers()
219 .get(header::LAST_MODIFIED)
220 .and_then(|v| v.to_str().ok())
221 .map(String::from);
222
223 let body = response
224 .bytes()
225 .await
226 .map_err(|e| DepsError::RegistryError {
227 package: url.to_string(),
228 source: e,
229 })?;
230
231 self.entries.insert(
233 url.to_string(),
234 CachedResponse {
235 body: body.clone(),
236 etag,
237 last_modified,
238 fetched_at: Instant::now(),
239 },
240 );
241
242 Ok(Some(body))
243 }
244
245 pub(crate) async fn fetch_and_store(&self, url: &str) -> Result<Bytes> {
256 ensure_https(url)?;
257 tracing::debug!("fetching fresh: {url}");
258
259 let response = self
260 .client
261 .get(url)
262 .send()
263 .await
264 .map_err(|e| DepsError::RegistryError {
265 package: url.to_string(),
266 source: e,
267 })?;
268
269 if !response.status().is_success() {
270 let status = response.status();
271 return Err(DepsError::CacheError(format!("HTTP {status} for {url}")));
272 }
273
274 let etag = response
275 .headers()
276 .get(header::ETAG)
277 .and_then(|v| v.to_str().ok())
278 .map(String::from);
279
280 let last_modified = response
281 .headers()
282 .get(header::LAST_MODIFIED)
283 .and_then(|v| v.to_str().ok())
284 .map(String::from);
285
286 let body = response
287 .bytes()
288 .await
289 .map_err(|e| DepsError::RegistryError {
290 package: url.to_string(),
291 source: e,
292 })?;
293
294 self.entries.insert(
295 url.to_string(),
296 CachedResponse {
297 body: body.clone(),
298 etag,
299 last_modified,
300 fetched_at: Instant::now(),
301 },
302 );
303
304 Ok(body)
305 }
306
307 pub fn clear(&self) {
312 self.entries.clear();
313 }
314
315 pub fn len(&self) -> usize {
317 self.entries.len()
318 }
319
320 pub fn is_empty(&self) -> bool {
322 self.entries.is_empty()
323 }
324
325 fn evict_entries(&self) {
333 use std::cmp::Reverse;
334 use std::collections::BinaryHeap;
335
336 let target_removals = MAX_CACHE_ENTRIES / CACHE_EVICTION_PERCENTAGE;
337
338 let mut oldest = BinaryHeap::with_capacity(target_removals);
341
342 for entry in &self.entries {
343 let item = (entry.value().fetched_at, entry.key().clone());
344
345 if oldest.len() < target_removals {
346 oldest.push(Reverse(item));
348 } else if let Some(Reverse(newest_of_oldest)) = oldest.peek() {
349 if item.0 < newest_of_oldest.0 {
352 oldest.pop();
353 oldest.push(Reverse(item));
354 }
355 }
356 }
357
358 let removed = oldest.len();
360 for Reverse((_, url)) in oldest {
361 self.entries.remove(&url);
362 }
363
364 tracing::debug!("evicted {} cache entries (O(N) algorithm)", removed);
365 }
366
367 #[doc(hidden)]
369 pub fn get_for_bench(&self, url: &str) -> Option<Bytes> {
370 self.entries.get(url).map(|entry| entry.body.clone())
371 }
372
373 #[doc(hidden)]
375 pub fn insert_for_bench(&self, url: String, response: CachedResponse) {
376 self.entries.insert(url, response);
377 }
378}
379
380impl Default for HttpCache {
381 fn default() -> Self {
382 Self::new()
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    /// Path served by most of the mock endpoints below.
    const DATA_PATH: &str = "/api/data";
    /// `Last-Modified` value shared by the validator tests.
    const HTTP_DATE: &str = "Wed, 21 Oct 2024 07:28:00 GMT";

    /// Builds a cache entry stamped with the current instant.
    fn entry(
        body: &'static [u8],
        etag: Option<&str>,
        last_modified: Option<&str>,
    ) -> CachedResponse {
        CachedResponse {
            body: Bytes::from_static(body),
            etag: etag.map(String::from),
            last_modified: last_modified.map(String::from),
            fetched_at: Instant::now(),
        }
    }

    // A new cache starts out empty.
    #[test]
    fn test_cache_creation() {
        let cache = HttpCache::new();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    // `clear` drops every stored entry.
    #[test]
    fn test_cache_clear() {
        let cache = HttpCache::new();
        cache
            .entries
            .insert(String::from("test"), entry(&[1, 2, 3], None, None));
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert_eq!(cache.len(), 0);
    }

    // Cloning an entry preserves body and ETag.
    #[test]
    fn test_cached_response_clone() {
        let original = entry(&[1, 2, 3], Some("test"), Some("date"));
        let copy = original.clone();

        assert_eq!(original.body, copy.body);
        assert_eq!(original.etag, copy.etag);
    }

    // `len` follows insertions.
    #[test]
    fn test_cache_len() {
        let cache = HttpCache::new();
        assert_eq!(cache.len(), 0);

        cache
            .entries
            .insert(String::from("url1"), entry(b"", None, None));

        assert_eq!(cache.len(), 1);
    }

    // A miss downloads the body and stores exactly one entry.
    #[tokio::test]
    async fn test_get_cached_fresh_fetch() {
        let mut server = mockito::Server::new_async().await;

        let _mock = server
            .mock("GET", DATA_PATH)
            .with_status(200)
            .with_header("etag", "\"abc123\"")
            .with_body("test data")
            .create_async()
            .await;

        let cache = HttpCache::new();
        let url = format!("{}{}", server.url(), DATA_PATH);
        let body = cache.get_cached(&url).await.unwrap();

        assert_eq!(body.as_ref(), b"test data");
        assert_eq!(cache.len(), 1);
    }

    // The second lookup revalidates with the stored ETag and reuses the body.
    #[tokio::test]
    async fn test_get_cached_cache_hit() {
        let mut server = mockito::Server::new_async().await;
        let url = format!("{}{}", server.url(), DATA_PATH);

        let cache = HttpCache::new();

        let initial = server
            .mock("GET", DATA_PATH)
            .with_status(200)
            .with_header("etag", "\"abc123\"")
            .with_body("original data")
            .create_async()
            .await;

        let first = cache.get_cached(&url).await.unwrap();
        assert_eq!(first.as_ref(), b"original data");
        assert_eq!(cache.len(), 1);

        // Retire the 200 mock so only the conditional mock can answer.
        drop(initial);

        let _revalidation = server
            .mock("GET", DATA_PATH)
            .match_header("if-none-match", "\"abc123\"")
            .with_status(304)
            .create_async()
            .await;

        let second = cache.get_cached(&url).await.unwrap();
        assert_eq!(second.as_ref(), b"original data");
    }

    // A 304 answer leaves the previously fetched body in effect.
    #[tokio::test]
    async fn test_get_cached_304_not_modified() {
        let mut server = mockito::Server::new_async().await;
        let url = format!("{}{}", server.url(), DATA_PATH);

        let cache = HttpCache::new();

        let initial = server
            .mock("GET", DATA_PATH)
            .with_status(200)
            .with_header("etag", "\"abc123\"")
            .with_body("original data")
            .create_async()
            .await;

        let first = cache.get_cached(&url).await.unwrap();
        assert_eq!(first.as_ref(), b"original data");

        // Retire the 200 mock so only the conditional mock can answer.
        drop(initial);

        let _revalidation = server
            .mock("GET", DATA_PATH)
            .match_header("if-none-match", "\"abc123\"")
            .with_status(304)
            .create_async()
            .await;

        let second = cache.get_cached(&url).await.unwrap();
        assert_eq!(second.as_ref(), b"original data");
    }

    // A pre-seeded ETag is sent as `If-None-Match`.
    #[tokio::test]
    async fn test_get_cached_etag_validation() {
        let mut server = mockito::Server::new_async().await;
        let url = format!("{}{}", server.url(), DATA_PATH);

        let cache = HttpCache::new();
        cache
            .entries
            .insert(url.clone(), entry(b"cached", Some("\"tag123\""), None));

        let _mock = server
            .mock("GET", DATA_PATH)
            .match_header("if-none-match", "\"tag123\"")
            .with_status(304)
            .create_async()
            .await;

        let body = cache.get_cached(&url).await.unwrap();
        assert_eq!(body.as_ref(), b"cached");
    }

    // A pre-seeded Last-Modified value is sent as `If-Modified-Since`.
    #[tokio::test]
    async fn test_get_cached_last_modified_validation() {
        let mut server = mockito::Server::new_async().await;
        let url = format!("{}{}", server.url(), DATA_PATH);

        let cache = HttpCache::new();
        cache
            .entries
            .insert(url.clone(), entry(b"cached", None, Some(HTTP_DATE)));

        let _mock = server
            .mock("GET", DATA_PATH)
            .match_header("if-modified-since", HTTP_DATE)
            .with_status(304)
            .create_async()
            .await;

        let body = cache.get_cached(&url).await.unwrap();
        assert_eq!(body.as_ref(), b"cached");
    }

    // When the conditional request cannot be completed, the stale body is served.
    #[tokio::test]
    async fn test_get_cached_network_error_fallback() {
        let cache = HttpCache::new();
        let url = "http://invalid.localhost.test/data";

        cache
            .entries
            .insert(url.to_owned(), entry(b"stale data", Some("\"old\""), None));

        let body = cache.get_cached(url).await.unwrap();
        assert_eq!(body.as_ref(), b"stale data");
    }

    // A non-success status surfaces as `CacheError` mentioning the code.
    #[tokio::test]
    async fn test_fetch_and_store_http_error() {
        let mut server = mockito::Server::new_async().await;

        let _mock = server
            .mock("GET", "/api/missing")
            .with_status(404)
            .with_body("Not Found")
            .create_async()
            .await;

        let cache = HttpCache::new();
        let url = format!("{}/api/missing", server.url());
        let outcome = cache.fetch_and_store(&url).await;

        assert!(outcome.is_err());
        match outcome {
            Err(DepsError::CacheError(msg)) => assert!(msg.contains("404")),
            _ => panic!("Expected CacheError"),
        }
    }

    // Both validators from the response end up in the stored entry.
    #[tokio::test]
    async fn test_fetch_and_store_stores_headers() {
        let mut server = mockito::Server::new_async().await;

        let _mock = server
            .mock("GET", DATA_PATH)
            .with_status(200)
            .with_header("etag", "\"abc123\"")
            .with_header("last-modified", HTTP_DATE)
            .with_body("test")
            .create_async()
            .await;

        let cache = HttpCache::new();
        let url = format!("{}{}", server.url(), DATA_PATH);
        cache.fetch_and_store(&url).await.unwrap();

        let stored = cache.entries.get(&url).unwrap();
        assert_eq!(stored.etag, Some(String::from("\"abc123\"")));
        assert_eq!(stored.last_modified, Some(String::from(HTTP_DATE)));
    }
}