1use crate::error::{DepsError, Result};
2use bytes::Bytes;
3use dashmap::DashMap;
4use reqwest::{Client, StatusCode, header};
5use std::time::Instant;
6
/// Entry count at which the cached-fetch paths run an eviction pass before
/// serving the request.
const MAX_CACHE_ENTRIES: usize = 1000;

/// Timeout, in seconds, applied to every request made by the cache's HTTP client.
const HTTP_TIMEOUT_SECS: u64 = 30;

/// Controls how many entries a single eviction pass removes; with the current
/// values a pass targets 100 entries (10% of `MAX_CACHE_ENTRIES`).
const CACHE_EVICTION_PERCENTAGE: usize = 10;
15
16#[inline]
23fn ensure_https(url: &str) -> Result<()> {
24 #[cfg(not(test))]
25 if !url.starts_with("https://") {
26 return Err(DepsError::CacheError(format!("URL must use HTTPS: {url}")));
27 }
28 #[cfg(test)]
29 let _ = url; Ok(())
31}
32
/// A stored HTTP response body together with the validators needed to
/// revalidate it later.
#[derive(Debug, Clone)]
pub struct CachedResponse {
    /// Raw response body.
    pub body: Bytes,
    /// Value of the response's `ETag` header, if one was present and readable
    /// as a string; sent back as `If-None-Match` on revalidation.
    pub etag: Option<String>,
    /// Value of the response's `Last-Modified` header, if one was present and
    /// readable as a string; sent back as `If-Modified-Since` on revalidation.
    pub last_modified: Option<String>,
    /// Moment the entry was stored; eviction compares entries by this value.
    pub fetched_at: Instant,
}
64
/// In-memory HTTP response cache keyed by request URL.
pub struct HttpCache {
    /// Cached responses, keyed by the exact URL string that was requested.
    entries: DashMap<String, CachedResponse>,
    /// HTTP client used for every request issued by this cache.
    client: Client,
}
96
97impl HttpCache {
98 pub fn new() -> Self {
103 let client = Client::builder()
104 .user_agent(format!("deps-lsp/{}", env!("CARGO_PKG_VERSION")))
105 .timeout(std::time::Duration::from_secs(HTTP_TIMEOUT_SECS))
106 .build()
107 .expect("failed to create HTTP client");
108
109 Self {
110 entries: DashMap::new(),
111 client,
112 }
113 }
114
115 pub async fn get_cached(&self, url: &str) -> Result<Bytes> {
148 if self.entries.len() >= MAX_CACHE_ENTRIES {
150 self.evict_entries();
151 }
152
153 if let Some(cached) = self.entries.get(url).map(|r| r.clone()) {
154 match self.conditional_request(url, &cached).await {
158 Ok(Some(new_body)) => {
159 return Ok(new_body);
160 }
161 Ok(None) => {
162 return Ok(cached.body);
163 }
164 Err(e) => {
165 tracing::warn!("conditional request failed, using cache: {e}");
166 return Ok(cached.body);
167 }
168 }
169 }
170
171 self.fetch_and_store(url).await
173 }
174
175 pub async fn get_cached_with_headers(
180 &self,
181 url: &str,
182 extra_headers: &[(header::HeaderName, &str)],
183 ) -> Result<Bytes> {
184 if self.entries.len() >= MAX_CACHE_ENTRIES {
185 self.evict_entries();
186 }
187
188 if let Some(cached) = self.entries.get(url).map(|r| r.clone()) {
189 match self
190 .conditional_request_with_headers(url, &cached, extra_headers)
191 .await
192 {
193 Ok(Some(new_body)) => return Ok(new_body),
194 Ok(None) => return Ok(cached.body),
195 Err(e) => {
196 tracing::warn!("conditional request failed, using cache: {e}");
197 return Ok(cached.body);
198 }
199 }
200 }
201
202 self.fetch_and_store_with_headers(url, extra_headers).await
203 }
204
205 async fn conditional_request(
216 &self,
217 url: &str,
218 cached: &CachedResponse,
219 ) -> Result<Option<Bytes>> {
220 ensure_https(url)?;
221 let mut request = self.client.get(url);
222
223 if let Some(etag) = &cached.etag {
224 request = request.header(header::IF_NONE_MATCH, etag);
225 }
226 if let Some(last_modified) = &cached.last_modified {
227 request = request.header(header::IF_MODIFIED_SINCE, last_modified);
228 }
229
230 let response = request.send().await.map_err(|e| DepsError::RegistryError {
231 package: url.to_string(),
232 source: e,
233 })?;
234
235 if response.status() == StatusCode::NOT_MODIFIED {
236 return Ok(None);
238 }
239
240 let etag = response
242 .headers()
243 .get(header::ETAG)
244 .and_then(|v| v.to_str().ok())
245 .map(String::from);
246
247 let last_modified = response
248 .headers()
249 .get(header::LAST_MODIFIED)
250 .and_then(|v| v.to_str().ok())
251 .map(String::from);
252
253 let body = response
254 .bytes()
255 .await
256 .map_err(|e| DepsError::RegistryError {
257 package: url.to_string(),
258 source: e,
259 })?;
260
261 self.entries.insert(
263 url.to_string(),
264 CachedResponse {
265 body: body.clone(),
266 etag,
267 last_modified,
268 fetched_at: Instant::now(),
269 },
270 );
271
272 Ok(Some(body))
273 }
274
275 pub(crate) async fn fetch_and_store(&self, url: &str) -> Result<Bytes> {
286 ensure_https(url)?;
287 tracing::debug!("fetching fresh: {url}");
288
289 let response = self
290 .client
291 .get(url)
292 .send()
293 .await
294 .map_err(|e| DepsError::RegistryError {
295 package: url.to_string(),
296 source: e,
297 })?;
298
299 if !response.status().is_success() {
300 let status = response.status();
301 return Err(DepsError::CacheError(format!("HTTP {status} for {url}")));
302 }
303
304 let etag = response
305 .headers()
306 .get(header::ETAG)
307 .and_then(|v| v.to_str().ok())
308 .map(String::from);
309
310 let last_modified = response
311 .headers()
312 .get(header::LAST_MODIFIED)
313 .and_then(|v| v.to_str().ok())
314 .map(String::from);
315
316 let body = response
317 .bytes()
318 .await
319 .map_err(|e| DepsError::RegistryError {
320 package: url.to_string(),
321 source: e,
322 })?;
323
324 self.entries.insert(
325 url.to_string(),
326 CachedResponse {
327 body: body.clone(),
328 etag,
329 last_modified,
330 fetched_at: Instant::now(),
331 },
332 );
333
334 Ok(body)
335 }
336
337 async fn conditional_request_with_headers(
338 &self,
339 url: &str,
340 cached: &CachedResponse,
341 extra_headers: &[(header::HeaderName, &str)],
342 ) -> Result<Option<Bytes>> {
343 ensure_https(url)?;
344 let mut request = self.client.get(url);
345
346 for (name, value) in extra_headers {
347 request = request.header(name, *value);
348 }
349 if let Some(etag) = &cached.etag {
350 request = request.header(header::IF_NONE_MATCH, etag);
351 }
352 if let Some(last_modified) = &cached.last_modified {
353 request = request.header(header::IF_MODIFIED_SINCE, last_modified);
354 }
355
356 let response = request.send().await.map_err(|e| DepsError::RegistryError {
357 package: url.to_string(),
358 source: e,
359 })?;
360
361 if response.status() == StatusCode::NOT_MODIFIED {
362 return Ok(None);
363 }
364
365 if !response.status().is_success() {
366 let status = response.status();
367 return Err(DepsError::CacheError(format!("HTTP {status} for {url}")));
368 }
369
370 let etag = response
371 .headers()
372 .get(header::ETAG)
373 .and_then(|v| v.to_str().ok())
374 .map(String::from);
375 let last_modified = response
376 .headers()
377 .get(header::LAST_MODIFIED)
378 .and_then(|v| v.to_str().ok())
379 .map(String::from);
380 let body = response
381 .bytes()
382 .await
383 .map_err(|e| DepsError::RegistryError {
384 package: url.to_string(),
385 source: e,
386 })?;
387
388 self.entries.insert(
389 url.to_string(),
390 CachedResponse {
391 body: body.clone(),
392 etag,
393 last_modified,
394 fetched_at: Instant::now(),
395 },
396 );
397
398 Ok(Some(body))
399 }
400
401 async fn fetch_and_store_with_headers(
402 &self,
403 url: &str,
404 extra_headers: &[(header::HeaderName, &str)],
405 ) -> Result<Bytes> {
406 ensure_https(url)?;
407 tracing::debug!("fetching fresh with headers: {url}");
408
409 let mut request = self.client.get(url);
410 for (name, value) in extra_headers {
411 request = request.header(name, *value);
412 }
413
414 let response = request.send().await.map_err(|e| DepsError::RegistryError {
415 package: url.to_string(),
416 source: e,
417 })?;
418
419 if !response.status().is_success() {
420 let status = response.status();
421 return Err(DepsError::CacheError(format!("HTTP {status} for {url}")));
422 }
423
424 let etag = response
425 .headers()
426 .get(header::ETAG)
427 .and_then(|v| v.to_str().ok())
428 .map(String::from);
429 let last_modified = response
430 .headers()
431 .get(header::LAST_MODIFIED)
432 .and_then(|v| v.to_str().ok())
433 .map(String::from);
434 let body = response
435 .bytes()
436 .await
437 .map_err(|e| DepsError::RegistryError {
438 package: url.to_string(),
439 source: e,
440 })?;
441
442 self.entries.insert(
443 url.to_string(),
444 CachedResponse {
445 body: body.clone(),
446 etag,
447 last_modified,
448 fetched_at: Instant::now(),
449 },
450 );
451
452 Ok(body)
453 }
454
455 pub fn clear(&self) {
460 self.entries.clear();
461 }
462
463 pub fn len(&self) -> usize {
465 self.entries.len()
466 }
467
468 pub fn is_empty(&self) -> bool {
470 self.entries.is_empty()
471 }
472
473 fn evict_entries(&self) {
481 use std::cmp::Reverse;
482 use std::collections::BinaryHeap;
483
484 let target_removals = MAX_CACHE_ENTRIES / CACHE_EVICTION_PERCENTAGE;
485
486 let mut oldest = BinaryHeap::with_capacity(target_removals);
489
490 for entry in &self.entries {
491 let item = (entry.value().fetched_at, entry.key().clone());
492
493 if oldest.len() < target_removals {
494 oldest.push(Reverse(item));
496 } else if let Some(Reverse(newest_of_oldest)) = oldest.peek() {
497 if item.0 < newest_of_oldest.0 {
500 oldest.pop();
501 oldest.push(Reverse(item));
502 }
503 }
504 }
505
506 let removed = oldest.len();
508 for Reverse((_, url)) in oldest {
509 self.entries.remove(&url);
510 }
511
512 tracing::debug!("evicted {} cache entries (O(N) algorithm)", removed);
513 }
514
515 #[doc(hidden)]
517 pub fn get_for_bench(&self, url: &str) -> Option<Bytes> {
518 self.entries.get(url).map(|entry| entry.body.clone())
519 }
520
521 #[doc(hidden)]
523 pub fn insert_for_bench(&self, url: String, response: CachedResponse) {
524 self.entries.insert(url, response);
525 }
526}
527
528impl Default for HttpCache {
529 fn default() -> Self {
530 Self::new()
531 }
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache entry stamped with the current instant.
    fn entry(
        body: &'static [u8],
        etag: Option<&str>,
        last_modified: Option<&str>,
    ) -> CachedResponse {
        CachedResponse {
            body: Bytes::from_static(body),
            etag: etag.map(String::from),
            last_modified: last_modified.map(String::from),
            fetched_at: Instant::now(),
        }
    }

    /// A new cache starts out empty.
    #[test]
    fn test_cache_creation() {
        let cache = HttpCache::new();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    /// `clear` drops every stored entry.
    #[test]
    fn test_cache_clear() {
        let cache = HttpCache::new();
        cache
            .entries
            .insert("test".to_string(), entry(&[1, 2, 3], None, None));
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert_eq!(cache.len(), 0);
    }

    /// Cloning a response preserves its body and ETag.
    #[test]
    fn test_cached_response_clone() {
        let original = entry(&[1, 2, 3], Some("test"), Some("date"));
        let copy = original.clone();

        assert_eq!(original.body, copy.body);
        assert_eq!(original.etag, copy.etag);
    }

    /// `len` tracks insertions.
    #[test]
    fn test_cache_len() {
        let cache = HttpCache::new();
        assert_eq!(cache.len(), 0);

        cache
            .entries
            .insert("url1".to_string(), entry(&[], None, None));

        assert_eq!(cache.len(), 1);
    }

    /// A URL that is not cached yet is fetched and stored.
    #[tokio::test]
    async fn test_get_cached_fresh_fetch() {
        let mut server = mockito::Server::new_async().await;

        let _mock = server
            .mock("GET", "/api/data")
            .with_status(200)
            .with_header("etag", "\"abc123\"")
            .with_body("test data")
            .create_async()
            .await;

        let cache = HttpCache::new();
        let url = format!("{}/api/data", server.url());
        let body = cache.get_cached(&url).await.unwrap();

        assert_eq!(body.as_ref(), b"test data");
        assert_eq!(cache.len(), 1);
    }

    /// A second request for the same URL is served from the cache after a 304.
    #[tokio::test]
    async fn test_get_cached_cache_hit() {
        let mut server = mockito::Server::new_async().await;
        let url = format!("{}/api/data", server.url());

        let cache = HttpCache::new();

        let initial_mock = server
            .mock("GET", "/api/data")
            .with_status(200)
            .with_header("etag", "\"abc123\"")
            .with_body("original data")
            .create_async()
            .await;

        let first = cache.get_cached(&url).await.unwrap();
        assert_eq!(first.as_ref(), b"original data");
        assert_eq!(cache.len(), 1);

        // Retire the 200 mock so only the conditional mock can answer next.
        drop(initial_mock);

        let _revalidation_mock = server
            .mock("GET", "/api/data")
            .match_header("if-none-match", "\"abc123\"")
            .with_status(304)
            .create_async()
            .await;

        let second = cache.get_cached(&url).await.unwrap();
        assert_eq!(second.as_ref(), b"original data");
    }

    /// A 304 reply to the conditional request yields the cached body.
    #[tokio::test]
    async fn test_get_cached_304_not_modified() {
        let mut server = mockito::Server::new_async().await;
        let url = format!("{}/api/data", server.url());

        let cache = HttpCache::new();

        let initial_mock = server
            .mock("GET", "/api/data")
            .with_status(200)
            .with_header("etag", "\"abc123\"")
            .with_body("original data")
            .create_async()
            .await;

        let first = cache.get_cached(&url).await.unwrap();
        assert_eq!(first.as_ref(), b"original data");

        // Retire the 200 mock so only the conditional mock can answer next.
        drop(initial_mock);

        let _revalidation_mock = server
            .mock("GET", "/api/data")
            .match_header("if-none-match", "\"abc123\"")
            .with_status(304)
            .create_async()
            .await;

        let second = cache.get_cached(&url).await.unwrap();
        assert_eq!(second.as_ref(), b"original data");
    }

    /// The stored ETag is sent as `If-None-Match`.
    #[tokio::test]
    async fn test_get_cached_etag_validation() {
        let mut server = mockito::Server::new_async().await;
        let url = format!("{}/api/data", server.url());

        let cache = HttpCache::new();
        cache
            .entries
            .insert(url.clone(), entry(b"cached", Some("\"tag123\""), None));

        let _mock = server
            .mock("GET", "/api/data")
            .match_header("if-none-match", "\"tag123\"")
            .with_status(304)
            .create_async()
            .await;

        let body = cache.get_cached(&url).await.unwrap();
        assert_eq!(body.as_ref(), b"cached");
    }

    /// The stored Last-Modified value is sent as `If-Modified-Since`.
    #[tokio::test]
    async fn test_get_cached_last_modified_validation() {
        let mut server = mockito::Server::new_async().await;
        let url = format!("{}/api/data", server.url());

        let cache = HttpCache::new();
        cache.entries.insert(
            url.clone(),
            entry(b"cached", None, Some("Wed, 21 Oct 2024 07:28:00 GMT")),
        );

        let _mock = server
            .mock("GET", "/api/data")
            .match_header("if-modified-since", "Wed, 21 Oct 2024 07:28:00 GMT")
            .with_status(304)
            .create_async()
            .await;

        let body = cache.get_cached(&url).await.unwrap();
        assert_eq!(body.as_ref(), b"cached");
    }

    /// When the conditional request cannot be completed, the stale entry is returned.
    #[tokio::test]
    async fn test_get_cached_network_error_fallback() {
        let cache = HttpCache::new();
        let url = "http://invalid.localhost.test/data";

        cache
            .entries
            .insert(url.to_string(), entry(b"stale data", Some("\"old\""), None));

        let body = cache.get_cached(url).await.unwrap();
        assert_eq!(body.as_ref(), b"stale data");
    }

    /// A non-success status on a fresh fetch surfaces as a `CacheError`.
    #[tokio::test]
    async fn test_fetch_and_store_http_error() {
        let mut server = mockito::Server::new_async().await;

        let _mock = server
            .mock("GET", "/api/missing")
            .with_status(404)
            .with_body("Not Found")
            .create_async()
            .await;

        let cache = HttpCache::new();
        let url = format!("{}/api/missing", server.url());
        let outcome = cache.fetch_and_store(&url).await;

        assert!(outcome.is_err());
        match outcome {
            Err(DepsError::CacheError(msg)) => {
                assert!(msg.contains("404"));
            }
            _ => panic!("Expected CacheError"),
        }
    }

    /// A fresh fetch records the response's ETag and Last-Modified headers.
    #[tokio::test]
    async fn test_fetch_and_store_stores_headers() {
        let mut server = mockito::Server::new_async().await;

        let _mock = server
            .mock("GET", "/api/data")
            .with_status(200)
            .with_header("etag", "\"abc123\"")
            .with_header("last-modified", "Wed, 21 Oct 2024 07:28:00 GMT")
            .with_body("test")
            .create_async()
            .await;

        let cache = HttpCache::new();
        let url = format!("{}/api/data", server.url());
        cache.fetch_and_store(&url).await.unwrap();

        let stored = cache.entries.get(&url).unwrap();
        assert_eq!(stored.etag, Some("\"abc123\"".to_string()));
        assert_eq!(
            stored.last_modified,
            Some("Wed, 21 Oct 2024 07:28:00 GMT".to_string())
        );
    }
}