1use crate::error::{DepsError, Result};
2use dashmap::DashMap;
3use reqwest::{Client, StatusCode, header};
4use std::sync::Arc;
5use std::time::Instant;
6
/// Entry count at which `HttpCache::get_cached` runs an eviction pass before
/// serving the request.
const MAX_CACHE_ENTRIES: usize = 1000;

/// Timeout, in seconds, configured on the HTTP client built by `HttpCache::new`.
const HTTP_TIMEOUT_SECS: u64 = 30;

/// Sizing factor for one eviction pass. With the current values a pass targets
/// 100 of the 1000 allowed entries, i.e. 10% of `MAX_CACHE_ENTRIES`.
const CACHE_EVICTION_PERCENTAGE: usize = 10;
15
16#[inline]
23fn ensure_https(url: &str) -> Result<()> {
24 #[cfg(not(test))]
25 if !url.starts_with("https://") {
26 return Err(DepsError::CacheError(format!(
27 "URL must use HTTPS: {}",
28 url
29 )));
30 }
31 #[cfg(test)]
32 let _ = url; Ok(())
34}
35
/// A response body held by [`HttpCache`], together with the validators needed
/// to revalidate it with a conditional request.
///
/// Cloning is cheap with respect to the body: the bytes are shared through the
/// `Arc` rather than copied.
#[derive(Debug, Clone)]
pub struct CachedResponse {
    /// Response body bytes, shared by reference count.
    pub body: Arc<Vec<u8>>,
    /// `ETag` validator; when set it is sent back as `If-None-Match`.
    pub etag: Option<String>,
    /// `Last-Modified` validator; when set it is sent back as `If-Modified-Since`.
    pub last_modified: Option<String>,
    /// Time associated with this entry; eviction removes the oldest values first.
    pub fetched_at: Instant,
}
67
/// In-memory cache of HTTP response bodies keyed by URL, with its own HTTP
/// client for fetching and revalidating entries.
pub struct HttpCache {
    /// Cached responses, keyed by the exact URL string that was requested.
    entries: DashMap<String, CachedResponse>,
    /// HTTP client used for every request made by this cache.
    client: Client,
}
101
102impl HttpCache {
103 pub fn new() -> Self {
108 let client = Client::builder()
109 .user_agent(format!("deps-lsp/{}", env!("CARGO_PKG_VERSION")))
110 .timeout(std::time::Duration::from_secs(HTTP_TIMEOUT_SECS))
111 .build()
112 .expect("failed to create HTTP client");
113
114 Self {
115 entries: DashMap::new(),
116 client,
117 }
118 }
119
120 pub async fn get_cached(&self, url: &str) -> Result<Arc<Vec<u8>>> {
154 if self.entries.len() >= MAX_CACHE_ENTRIES {
156 self.evict_entries();
157 }
158
159 if let Some(cached) = self.entries.get(url) {
160 match self.conditional_request(url, &cached).await {
162 Ok(Some(new_body)) => {
163 return Ok(new_body);
165 }
166 Ok(None) => {
167 return Ok(Arc::clone(&cached.body));
169 }
170 Err(e) => {
171 tracing::warn!("conditional request failed, using cache: {}", e);
173 return Ok(Arc::clone(&cached.body));
174 }
175 }
176 }
177
178 self.fetch_and_store(url).await
180 }
181
182 async fn conditional_request(
193 &self,
194 url: &str,
195 cached: &CachedResponse,
196 ) -> Result<Option<Arc<Vec<u8>>>> {
197 ensure_https(url)?;
198 let mut request = self.client.get(url);
199
200 if let Some(etag) = &cached.etag {
201 request = request.header(header::IF_NONE_MATCH, etag);
202 }
203 if let Some(last_modified) = &cached.last_modified {
204 request = request.header(header::IF_MODIFIED_SINCE, last_modified);
205 }
206
207 let response = request.send().await.map_err(|e| DepsError::RegistryError {
208 package: url.to_string(),
209 source: e,
210 })?;
211
212 if response.status() == StatusCode::NOT_MODIFIED {
213 return Ok(None);
215 }
216
217 let etag = response
219 .headers()
220 .get(header::ETAG)
221 .and_then(|v| v.to_str().ok())
222 .map(String::from);
223
224 let last_modified = response
225 .headers()
226 .get(header::LAST_MODIFIED)
227 .and_then(|v| v.to_str().ok())
228 .map(String::from);
229
230 let body = response
231 .bytes()
232 .await
233 .map_err(|e| DepsError::RegistryError {
234 package: url.to_string(),
235 source: e,
236 })?;
237
238 let body_arc = Arc::new(body.to_vec());
239
240 self.entries.insert(
242 url.to_string(),
243 CachedResponse {
244 body: Arc::clone(&body_arc),
245 etag,
246 last_modified,
247 fetched_at: Instant::now(),
248 },
249 );
250
251 Ok(Some(body_arc))
252 }
253
254 pub(crate) async fn fetch_and_store(&self, url: &str) -> Result<Arc<Vec<u8>>> {
265 ensure_https(url)?;
266 tracing::debug!("fetching fresh: {}", url);
267
268 let response = self
269 .client
270 .get(url)
271 .send()
272 .await
273 .map_err(|e| DepsError::RegistryError {
274 package: url.to_string(),
275 source: e,
276 })?;
277
278 if !response.status().is_success() {
279 return Err(DepsError::CacheError(format!(
280 "HTTP {} for {}",
281 response.status(),
282 url
283 )));
284 }
285
286 let etag = response
287 .headers()
288 .get(header::ETAG)
289 .and_then(|v| v.to_str().ok())
290 .map(String::from);
291
292 let last_modified = response
293 .headers()
294 .get(header::LAST_MODIFIED)
295 .and_then(|v| v.to_str().ok())
296 .map(String::from);
297
298 let body = response
299 .bytes()
300 .await
301 .map_err(|e| DepsError::RegistryError {
302 package: url.to_string(),
303 source: e,
304 })?;
305
306 let body_arc = Arc::new(body.to_vec());
307
308 self.entries.insert(
309 url.to_string(),
310 CachedResponse {
311 body: Arc::clone(&body_arc),
312 etag,
313 last_modified,
314 fetched_at: Instant::now(),
315 },
316 );
317
318 Ok(body_arc)
319 }
320
321 pub fn clear(&self) {
326 self.entries.clear();
327 }
328
329 pub fn len(&self) -> usize {
331 self.entries.len()
332 }
333
334 pub fn is_empty(&self) -> bool {
336 self.entries.is_empty()
337 }
338
339 fn evict_entries(&self) {
344 let target_removals = MAX_CACHE_ENTRIES / CACHE_EVICTION_PERCENTAGE;
345 let mut removed = 0;
346
347 let mut entries_to_remove = Vec::new();
349
350 for entry in self.entries.iter() {
351 entries_to_remove.push((entry.key().clone(), entry.value().fetched_at));
352 if entries_to_remove.len() >= MAX_CACHE_ENTRIES {
353 break;
354 }
355 }
356
357 entries_to_remove.sort_by_key(|(_, time)| *time);
359
360 for (url, _) in entries_to_remove.iter().take(target_removals) {
362 self.entries.remove(url);
363 removed += 1;
364 }
365
366 tracing::debug!("evicted {} cache entries", removed);
367 }
368
369 #[doc(hidden)]
371 pub fn get_for_bench(&self, url: &str) -> Option<Arc<Vec<u8>>> {
372 self.entries.get(url).map(|entry| Arc::clone(&entry.body))
373 }
374
375 #[doc(hidden)]
377 pub fn insert_for_bench(&self, url: String, response: CachedResponse) {
378 self.entries.insert(url, response);
379 }
380}
381
382impl Default for HttpCache {
383 fn default() -> Self {
384 Self::new()
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache entry stamped with the current time.
    fn entry(body: &[u8], etag: Option<&str>, last_modified: Option<&str>) -> CachedResponse {
        CachedResponse {
            body: Arc::new(body.to_vec()),
            etag: etag.map(String::from),
            last_modified: last_modified.map(String::from),
            fetched_at: Instant::now(),
        }
    }

    #[test]
    fn test_cache_creation() {
        let cache = HttpCache::new();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cache_clear() {
        let cache = HttpCache::new();
        cache.entries.insert("test".into(), entry(&[1, 2, 3], None, None));
        assert_eq!(cache.len(), 1);
        cache.clear();
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_cached_response_clone() {
        let response = entry(&[1, 2, 3], Some("test"), Some("date"));
        let cloned = response.clone();
        // Cloning must share the body allocation, not copy it.
        assert!(Arc::ptr_eq(&response.body, &cloned.body));
        assert_eq!(response.etag, cloned.etag);
    }

    #[test]
    fn test_cache_len() {
        let cache = HttpCache::new();
        assert_eq!(cache.len(), 0);

        cache.entries.insert("url1".into(), entry(&[], None, None));

        assert_eq!(cache.len(), 1);
    }

    #[tokio::test]
    async fn test_get_cached_fresh_fetch() {
        let mut server = mockito::Server::new_async().await;

        let mock = server
            .mock("GET", "/api/data")
            .with_status(200)
            .with_header("etag", "\"abc123\"")
            .with_body("test data")
            .create_async()
            .await;

        let cache = HttpCache::new();
        let url = format!("{}/api/data", server.url());
        let result = cache.get_cached(&url).await.unwrap();

        assert_eq!(&**result, b"test data");
        assert_eq!(cache.len(), 1);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_get_cached_cache_hit() {
        let mut server = mockito::Server::new_async().await;
        let url = format!("{}/api/data", server.url());

        let cache = HttpCache::new();

        let initial = server
            .mock("GET", "/api/data")
            .with_status(200)
            .with_header("etag", "\"abc123\"")
            .with_body("original data")
            .create_async()
            .await;

        let result1 = cache.get_cached(&url).await.unwrap();
        assert_eq!(&**result1, b"original data");
        assert_eq!(cache.len(), 1);
        initial.assert_async().await;

        drop(initial);

        let revalidation = server
            .mock("GET", "/api/data")
            .match_header("if-none-match", "\"abc123\"")
            .with_status(304)
            .create_async()
            .await;

        let result2 = cache.get_cached(&url).await.unwrap();
        assert_eq!(&**result2, b"original data");
        // A hit hands back the stored allocation and leaves the cache size alone.
        assert!(Arc::ptr_eq(&result1, &result2));
        assert_eq!(cache.len(), 1);
        revalidation.assert_async().await;
    }

    #[tokio::test]
    async fn test_get_cached_304_not_modified() {
        let mut server = mockito::Server::new_async().await;
        let url = format!("{}/api/data", server.url());

        let cache = HttpCache::new();

        let initial = server
            .mock("GET", "/api/data")
            .with_status(200)
            .with_header("etag", "\"abc123\"")
            .with_body("original data")
            .create_async()
            .await;

        let result1 = cache.get_cached(&url).await.unwrap();
        assert_eq!(&**result1, b"original data");

        drop(initial);

        let not_modified = server
            .mock("GET", "/api/data")
            .match_header("if-none-match", "\"abc123\"")
            .with_status(304)
            .create_async()
            .await;

        let result2 = cache.get_cached(&url).await.unwrap();
        assert_eq!(&**result2, b"original data");
        // Proves the conditional request really reached the server.
        not_modified.assert_async().await;
    }

    #[tokio::test]
    async fn test_get_cached_etag_validation() {
        let mut server = mockito::Server::new_async().await;
        let url = format!("{}/api/data", server.url());

        let cache = HttpCache::new();

        cache
            .entries
            .insert(url.clone(), entry(b"cached", Some("\"tag123\""), None));

        let mock = server
            .mock("GET", "/api/data")
            .match_header("if-none-match", "\"tag123\"")
            .with_status(304)
            .create_async()
            .await;

        let result = cache.get_cached(&url).await.unwrap();
        assert_eq!(&**result, b"cached");
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_get_cached_last_modified_validation() {
        let mut server = mockito::Server::new_async().await;
        let url = format!("{}/api/data", server.url());

        let cache = HttpCache::new();

        cache.entries.insert(
            url.clone(),
            entry(b"cached", None, Some("Wed, 21 Oct 2024 07:28:00 GMT")),
        );

        let mock = server
            .mock("GET", "/api/data")
            .match_header("if-modified-since", "Wed, 21 Oct 2024 07:28:00 GMT")
            .with_status(304)
            .create_async()
            .await;

        let result = cache.get_cached(&url).await.unwrap();
        assert_eq!(&**result, b"cached");
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_get_cached_network_error_fallback() {
        let cache = HttpCache::new();
        let url = "http://invalid.localhost.test/data";

        cache
            .entries
            .insert(url.to_string(), entry(b"stale data", Some("\"old\""), None));

        // The request cannot succeed, so the stale body must be served.
        let result = cache.get_cached(url).await.unwrap();
        assert_eq!(&**result, b"stale data");
    }

    #[tokio::test]
    async fn test_fetch_and_store_http_error() {
        let mut server = mockito::Server::new_async().await;

        let mock = server
            .mock("GET", "/api/missing")
            .with_status(404)
            .with_body("Not Found")
            .create_async()
            .await;

        let cache = HttpCache::new();
        let url = format!("{}/api/missing", server.url());
        let result = cache.fetch_and_store(&url).await;

        match result {
            Err(DepsError::CacheError(msg)) => {
                assert!(msg.contains("404"));
            }
            _ => panic!("Expected CacheError"),
        }
        // A failed fetch must not leave an entry behind.
        assert!(cache.is_empty());
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_fetch_and_store_stores_headers() {
        let mut server = mockito::Server::new_async().await;

        let mock = server
            .mock("GET", "/api/data")
            .with_status(200)
            .with_header("etag", "\"abc123\"")
            .with_header("last-modified", "Wed, 21 Oct 2024 07:28:00 GMT")
            .with_body("test")
            .create_async()
            .await;

        let cache = HttpCache::new();
        let url = format!("{}/api/data", server.url());
        cache.fetch_and_store(&url).await.unwrap();
        mock.assert_async().await;

        let cached = cache.entries.get(&url).unwrap();
        assert_eq!(cached.etag, Some("\"abc123\"".into()));
        assert_eq!(
            cached.last_modified,
            Some("Wed, 21 Oct 2024 07:28:00 GMT".into())
        );
    }
}