use crate::error::{DepsError, Result};
use bytes::Bytes;
use dashmap::DashMap;
use reqwest::{Client, StatusCode, header};
use std::time::Instant;
/// Entry count at which `get_cached` / `get_cached_with_headers` run an eviction pass
/// before serving the request.
const MAX_CACHE_ENTRIES: usize = 1000;
/// Timeout, in seconds, passed to the HTTP client builder in `HttpCache::new`.
const HTTP_TIMEOUT_SECS: u64 = 30;
/// Controls how many entries one eviction pass removes. With the current values a pass
/// targets 100 entries, i.e. 10% of `MAX_CACHE_ENTRIES`.
const CACHE_EVICTION_PERCENTAGE: usize = 10;
/// Rejects URLs that do not use the `https://` scheme.
///
/// The check only exists in non-test builds. In test builds it is compiled out and every
/// URL is accepted, so the unit tests can use plain `http://` URLs.
///
/// # Errors
///
/// In non-test builds, returns `DepsError::CacheError` when `url` does not start with
/// `https://`.
#[inline]
fn ensure_https(url: &str) -> Result<()> {
    #[cfg(not(test))]
    if url.strip_prefix("https://").is_none() {
        let message = format!("URL must use HTTPS: {url}");
        return Err(DepsError::CacheError(message));
    }

    // The parameter is otherwise unused when the check above is compiled out.
    #[cfg(test)]
    let _ = url;

    Ok(())
}
/// A stored HTTP response body together with the validators used to revalidate it.
#[derive(Debug, Clone)]
pub struct CachedResponse {
/// Response body that is handed back to callers of the cache.
pub body: Bytes,
/// Value of the response's `ETag` header, if one was present and readable as text.
/// Sent back as `If-None-Match` on conditional requests.
pub etag: Option<String>,
/// Value of the response's `Last-Modified` header, if one was present and readable as
/// text. Sent back as `If-Modified-Since` on conditional requests.
pub last_modified: Option<String>,
/// Set to `Instant::now()` when the cache stores a response; eviction orders entries by it.
pub fetched_at: Instant,
}
/// In-memory cache of HTTP response bodies that revalidates entries with conditional requests.
pub struct HttpCache {
/// Cached responses keyed by the request URL string.
entries: DashMap<String, CachedResponse>,
/// HTTP client used for every request made by this cache (configured in `new`).
client: Client,
}
impl HttpCache {
/// Creates an empty cache with its own HTTP client.
///
/// The client sends a `deps-lsp/<crate version>` user agent and is built with a timeout of
/// `HTTP_TIMEOUT_SECS` seconds.
///
/// # Panics
///
/// Panics if the HTTP client cannot be built.
pub fn new() -> Self {
    let user_agent = format!("deps-lsp/{}", env!("CARGO_PKG_VERSION"));
    let timeout = std::time::Duration::from_secs(HTTP_TIMEOUT_SECS);

    let client = Client::builder()
        .user_agent(user_agent)
        .timeout(timeout)
        .build()
        .expect("failed to create HTTP client");

    Self {
        client,
        entries: DashMap::new(),
    }
}
/// Returns the body for `url`, revalidating a cached copy when one exists.
///
/// * Nothing cached: the URL is fetched and the response is stored.
/// * Cached and the conditional request yields a new body: that body is returned.
/// * Cached and the conditional request reports "not modified": the cached body is returned.
/// * Cached and the conditional request fails: a warning is logged and the cached body is
///   returned.
///
/// Before the lookup, an eviction pass runs if the cache already holds
/// `MAX_CACHE_ENTRIES` entries or more.
///
/// # Errors
///
/// Fails only when nothing is cached for `url` and the fresh fetch fails.
pub async fn get_cached(&self, url: &str) -> Result<Bytes> {
    let cache_is_full = self.entries.len() >= MAX_CACHE_ENTRIES;
    if cache_is_full {
        self.evict_entries();
    }

    // Work on a clone of the entry; the map is not borrowed during the network call.
    let snapshot = self.entries.get(url).map(|entry| entry.value().clone());
    let cached = match snapshot {
        Some(cached) => cached,
        None => return self.fetch_and_store(url).await,
    };

    let body = match self.conditional_request(url, &cached).await {
        Ok(Some(fresh)) => fresh,
        Ok(None) => cached.body,
        Err(e) => {
            tracing::warn!("conditional request failed, using cache: {e}");
            cached.body
        }
    };
    Ok(body)
}
/// Like `get_cached`, but sends `extra_headers` with every request it makes.
///
/// Entries are keyed by `url` alone; `extra_headers` are not part of the cache key.
///
/// Before the lookup, an eviction pass runs if the cache already holds
/// `MAX_CACHE_ENTRIES` entries or more. When a cached entry exists and the conditional
/// request fails, a warning is logged and the cached body is returned.
///
/// # Errors
///
/// Fails only when nothing is cached for `url` and the fresh fetch fails.
pub async fn get_cached_with_headers(
    &self,
    url: &str,
    extra_headers: &[(header::HeaderName, &str)],
) -> Result<Bytes> {
    let cache_is_full = self.entries.len() >= MAX_CACHE_ENTRIES;
    if cache_is_full {
        self.evict_entries();
    }

    // Work on a clone of the entry; the map is not borrowed during the network call.
    let snapshot = self.entries.get(url).map(|entry| entry.value().clone());
    let cached = match snapshot {
        Some(cached) => cached,
        None => return self.fetch_and_store_with_headers(url, extra_headers).await,
    };

    let revalidated = self
        .conditional_request_with_headers(url, &cached, extra_headers)
        .await;
    let body = match revalidated {
        Ok(Some(fresh)) => fresh,
        Ok(None) => cached.body,
        Err(e) => {
            tracing::warn!("conditional request failed, using cache: {e}");
            cached.body
        }
    };
    Ok(body)
}
/// Revalidates `cached` against the server with a conditional GET.
///
/// `If-None-Match` is sent when the cached entry has an ETag and `If-Modified-Since` when
/// it has a Last-Modified value.
///
/// Returns `Ok(None)` when the server answers 304 Not Modified. On a 2xx answer the new
/// response replaces the entry stored under `url` and its body is returned as
/// `Ok(Some(body))`.
///
/// # Errors
///
/// * Whatever `ensure_https` returns for a rejected URL.
/// * `DepsError::RegistryError` when sending the request or reading the body fails.
/// * `DepsError::CacheError` for any status that is neither 304 nor 2xx. The existing cache
///   entry is left untouched in that case, matching `conditional_request_with_headers`.
async fn conditional_request(
    &self,
    url: &str,
    cached: &CachedResponse,
) -> Result<Option<Bytes>> {
    ensure_https(url)?;

    let mut request = self.client.get(url);
    if let Some(etag) = &cached.etag {
        request = request.header(header::IF_NONE_MATCH, etag);
    }
    if let Some(last_modified) = &cached.last_modified {
        request = request.header(header::IF_MODIFIED_SINCE, last_modified);
    }

    let response = request.send().await.map_err(|e| DepsError::RegistryError {
        package: url.to_string(),
        source: e,
    })?;

    let status = response.status();
    if status == StatusCode::NOT_MODIFIED {
        return Ok(None);
    }
    // Without this guard an error page (404, 500, ...) would be stored over the good entry
    // and handed to the caller as if it were valid content.
    if !status.is_success() {
        return Err(DepsError::CacheError(format!("HTTP {status} for {url}")));
    }

    let etag = response
        .headers()
        .get(header::ETAG)
        .and_then(|v| v.to_str().ok())
        .map(String::from);
    let last_modified = response
        .headers()
        .get(header::LAST_MODIFIED)
        .and_then(|v| v.to_str().ok())
        .map(String::from);

    let body = response
        .bytes()
        .await
        .map_err(|e| DepsError::RegistryError {
            package: url.to_string(),
            source: e,
        })?;

    self.entries.insert(
        url.to_string(),
        CachedResponse {
            body: body.clone(),
            etag,
            last_modified,
            fetched_at: Instant::now(),
        },
    );
    Ok(Some(body))
}
/// Fetches `url` unconditionally and stores the response in the cache.
///
/// On a 2xx response the body is stored under `url` together with the response's `ETag`
/// and `Last-Modified` header values (when present and readable as text) and the current
/// instant, and the body is returned.
///
/// # Errors
///
/// * Whatever `ensure_https` returns for a rejected URL.
/// * `DepsError::RegistryError` when sending the request or reading the body fails.
/// * `DepsError::CacheError` when the response status is not 2xx.
pub(crate) async fn fetch_and_store(&self, url: &str) -> Result<Bytes> {
    ensure_https(url)?;
    tracing::debug!("fetching fresh: {url}");

    let sent = self.client.get(url).send().await;
    let response = sent.map_err(|source| DepsError::RegistryError {
        package: url.to_owned(),
        source,
    })?;

    let status = response.status();
    if !status.is_success() {
        return Err(DepsError::CacheError(format!("HTTP {status} for {url}")));
    }

    // Read the validators before `bytes()` consumes the response.
    let (etag, last_modified) = {
        let headers = response.headers();
        let text_of = |name: header::HeaderName| {
            headers
                .get(name)
                .and_then(|raw| raw.to_str().ok())
                .map(str::to_owned)
        };
        (text_of(header::ETAG), text_of(header::LAST_MODIFIED))
    };

    let received = response.bytes().await;
    let body = received.map_err(|source| DepsError::RegistryError {
        package: url.to_owned(),
        source,
    })?;

    let entry = CachedResponse {
        body: body.clone(),
        etag,
        last_modified,
        fetched_at: Instant::now(),
    };
    self.entries.insert(url.to_owned(), entry);
    Ok(body)
}
/// Revalidates `cached` with a conditional GET that also carries `extra_headers`.
///
/// `extra_headers` are added first, then `If-None-Match` / `If-Modified-Since` for whichever
/// validators the cached entry has.
///
/// Returns `Ok(None)` when the server answers 304 Not Modified. On a 2xx answer the new
/// response replaces the entry stored under `url` and its body is returned as
/// `Ok(Some(body))`.
///
/// # Errors
///
/// * Whatever `ensure_https` returns for a rejected URL.
/// * `DepsError::RegistryError` when sending the request or reading the body fails.
/// * `DepsError::CacheError` for any status that is neither 304 nor 2xx.
async fn conditional_request_with_headers(
    &self,
    url: &str,
    cached: &CachedResponse,
    extra_headers: &[(header::HeaderName, &str)],
) -> Result<Option<Bytes>> {
    ensure_https(url)?;

    let mut request = extra_headers
        .iter()
        .fold(self.client.get(url), |builder, (name, value)| {
            builder.header(name, *value)
        });
    if let Some(etag) = cached.etag.as_ref() {
        request = request.header(header::IF_NONE_MATCH, etag);
    }
    if let Some(last_modified) = cached.last_modified.as_ref() {
        request = request.header(header::IF_MODIFIED_SINCE, last_modified);
    }

    let sent = request.send().await;
    let response = sent.map_err(|source| DepsError::RegistryError {
        package: url.to_owned(),
        source,
    })?;

    let status = response.status();
    if status == StatusCode::NOT_MODIFIED {
        return Ok(None);
    }
    if !status.is_success() {
        return Err(DepsError::CacheError(format!("HTTP {status} for {url}")));
    }

    // Read the validators before `bytes()` consumes the response.
    let (etag, last_modified) = {
        let headers = response.headers();
        let text_of = |name: header::HeaderName| {
            headers
                .get(name)
                .and_then(|raw| raw.to_str().ok())
                .map(str::to_owned)
        };
        (text_of(header::ETAG), text_of(header::LAST_MODIFIED))
    };

    let received = response.bytes().await;
    let body = received.map_err(|source| DepsError::RegistryError {
        package: url.to_owned(),
        source,
    })?;

    let entry = CachedResponse {
        body: body.clone(),
        etag,
        last_modified,
        fetched_at: Instant::now(),
    };
    self.entries.insert(url.to_owned(), entry);
    Ok(Some(body))
}
/// Fetches `url` unconditionally with `extra_headers` attached and stores the response.
///
/// On a 2xx response the body is stored under `url` together with the response's `ETag`
/// and `Last-Modified` header values (when present and readable as text) and the current
/// instant, and the body is returned.
///
/// # Errors
///
/// * Whatever `ensure_https` returns for a rejected URL.
/// * `DepsError::RegistryError` when sending the request or reading the body fails.
/// * `DepsError::CacheError` when the response status is not 2xx.
async fn fetch_and_store_with_headers(
    &self,
    url: &str,
    extra_headers: &[(header::HeaderName, &str)],
) -> Result<Bytes> {
    ensure_https(url)?;
    tracing::debug!("fetching fresh with headers: {url}");

    let request = extra_headers
        .iter()
        .fold(self.client.get(url), |builder, (name, value)| {
            builder.header(name, *value)
        });

    let sent = request.send().await;
    let response = sent.map_err(|source| DepsError::RegistryError {
        package: url.to_owned(),
        source,
    })?;

    let status = response.status();
    if !status.is_success() {
        return Err(DepsError::CacheError(format!("HTTP {status} for {url}")));
    }

    // Read the validators before `bytes()` consumes the response.
    let (etag, last_modified) = {
        let headers = response.headers();
        let text_of = |name: header::HeaderName| {
            headers
                .get(name)
                .and_then(|raw| raw.to_str().ok())
                .map(str::to_owned)
        };
        (text_of(header::ETAG), text_of(header::LAST_MODIFIED))
    };

    let received = response.bytes().await;
    let body = received.map_err(|source| DepsError::RegistryError {
        package: url.to_owned(),
        source,
    })?;

    let entry = CachedResponse {
        body: body.clone(),
        etag,
        last_modified,
        fetched_at: Instant::now(),
    };
    self.entries.insert(url.to_owned(), entry);
    Ok(body)
}
/// Removes every entry from the cache.
pub fn clear(&self) {
self.entries.clear();
}
/// Returns the number of entries currently stored.
pub fn len(&self) -> usize {
self.entries.len()
}
/// Returns `true` when the cache holds no entries.
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
fn evict_entries(&self) {
use std::cmp::Reverse;
use std::collections::BinaryHeap;
let target_removals = MAX_CACHE_ENTRIES / CACHE_EVICTION_PERCENTAGE;
let mut oldest = BinaryHeap::with_capacity(target_removals);
for entry in &self.entries {
let item = (entry.value().fetched_at, entry.key().clone());
if oldest.len() < target_removals {
oldest.push(Reverse(item));
} else if let Some(Reverse(newest_of_oldest)) = oldest.peek() {
if item.0 < newest_of_oldest.0 {
oldest.pop();
oldest.push(Reverse(item));
}
}
}
let removed = oldest.len();
for Reverse((_, url)) in oldest {
self.entries.remove(&url);
}
tracing::debug!("evicted {} cache entries (O(N) algorithm)", removed);
}
/// Returns a clone of the body stored for `url`, without making any request.
///
/// Hidden from the rendered docs; going by the name, intended for benchmarks.
#[doc(hidden)]
pub fn get_for_bench(&self, url: &str) -> Option<Bytes> {
    let entry = self.entries.get(url)?;
    Some(entry.body.clone())
}
/// Stores `response` under `url` directly, replacing any existing entry for that key.
/// No request is made and no eviction check runs.
///
/// Hidden from the rendered docs; going by the name, intended for benchmarks.
#[doc(hidden)]
pub fn insert_for_bench(&self, url: String, response: CachedResponse) {
self.entries.insert(url, response);
}
}
/// `Default` builds the same cache as `HttpCache::new`.
impl Default for HttpCache {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
/// A freshly built cache holds no entries.
#[test]
fn test_cache_creation() {
    let fresh = HttpCache::new();
    assert!(fresh.is_empty());
    assert_eq!(fresh.len(), 0);
}
/// `clear` drops an entry that was inserted directly into the map.
#[test]
fn test_cache_clear() {
    let cache = HttpCache::new();
    let entry = CachedResponse {
        body: Bytes::from_static(&[1, 2, 3]),
        etag: None,
        last_modified: None,
        fetched_at: Instant::now(),
    };
    cache.entries.insert(String::from("test"), entry);
    assert_eq!(cache.len(), 1);

    cache.clear();
    assert_eq!(cache.len(), 0);
}
/// Cloning a `CachedResponse` preserves its body and ETag.
#[test]
fn test_cached_response_clone() {
    let original = CachedResponse {
        body: Bytes::from_static(&[1, 2, 3]),
        etag: Some(String::from("test")),
        last_modified: Some(String::from("date")),
        fetched_at: Instant::now(),
    };
    let copy = original.clone();

    assert_eq!(original.body, copy.body);
    assert_eq!(original.etag, copy.etag);
}
/// `len` goes from 0 to 1 after one direct insertion.
#[test]
fn test_cache_len() {
    let cache = HttpCache::new();
    assert_eq!(cache.len(), 0);

    let empty_entry = CachedResponse {
        body: Bytes::new(),
        etag: None,
        last_modified: None,
        fetched_at: Instant::now(),
    };
    cache.entries.insert(String::from("url1"), empty_entry);
    assert_eq!(cache.len(), 1);
}
/// With nothing cached, `get_cached` fetches the body and stores one entry.
#[tokio::test]
async fn test_get_cached_fresh_fetch() {
    let mut server = mockito::Server::new_async().await;
    let url = format!("{}/api/data", server.url());
    let _mock = server
        .mock("GET", "/api/data")
        .with_status(200)
        .with_header("etag", "\"abc123\"")
        .with_body("test data")
        .create_async()
        .await;

    let cache = HttpCache::new();
    let body: Bytes = cache.get_cached(&url).await.unwrap();

    assert_eq!(body.as_ref(), b"test data");
    assert_eq!(cache.len(), 1);
}
/// A second `get_cached` for the same URL revalidates with `If-None-Match` and, on a 304,
/// serves the body stored by the first call.
#[tokio::test]
async fn test_get_cached_cache_hit() {
    let mut server = mockito::Server::new_async().await;
    let url = format!("{}/api/data", server.url());
    let cache = HttpCache::new();

    let initial = server
        .mock("GET", "/api/data")
        .with_status(200)
        .with_header("etag", "\"abc123\"")
        .with_body("original data")
        .create_async()
        .await;
    let first: Bytes = cache.get_cached(&url).await.unwrap();
    assert_eq!(first.as_ref(), b"original data");
    assert_eq!(cache.len(), 1);
    // Dropping the mock takes it off the server so only the 304 mock below can answer.
    drop(initial);

    let revalidation = server
        .mock("GET", "/api/data")
        .match_header("if-none-match", "\"abc123\"")
        .with_status(304)
        .create_async()
        .await;
    let second: Bytes = cache.get_cached(&url).await.unwrap();
    assert_eq!(second.as_ref(), b"original data");

    // Prove the conditional request (with the stored ETag) actually reached the server,
    // instead of inferring it from how an unmatched request happens to be handled.
    revalidation.assert_async().await;
    assert_eq!(cache.len(), 1);
}
/// A 304 answer to the revalidation request leaves the originally fetched body in effect.
#[tokio::test]
async fn test_get_cached_304_not_modified() {
    let mut server = mockito::Server::new_async().await;
    let url = format!("{}/api/data", server.url());
    let cache = HttpCache::new();

    let initial = server
        .mock("GET", "/api/data")
        .with_status(200)
        .with_header("etag", "\"abc123\"")
        .with_body("original data")
        .create_async()
        .await;
    let first: Bytes = cache.get_cached(&url).await.unwrap();
    assert_eq!(first.as_ref(), b"original data");
    // Dropping the mock takes it off the server so only the 304 mock below can answer.
    drop(initial);

    let not_modified = server
        .mock("GET", "/api/data")
        .match_header("if-none-match", "\"abc123\"")
        .with_status(304)
        .create_async()
        .await;
    let second: Bytes = cache.get_cached(&url).await.unwrap();
    assert_eq!(second.as_ref(), b"original data");

    // The 304 mock must have been hit exactly once with the matching `If-None-Match`.
    not_modified.assert_async().await;
}
/// A pre-seeded entry with an ETag is revalidated with `If-None-Match` and served on 304.
#[tokio::test]
async fn test_get_cached_etag_validation() {
    let mut server = mockito::Server::new_async().await;
    let url = format!("{}/api/data", server.url());
    let cache = HttpCache::new();
    cache.entries.insert(
        url.clone(),
        CachedResponse {
            body: Bytes::from_static(b"cached"),
            etag: Some("\"tag123\"".into()),
            last_modified: None,
            fetched_at: Instant::now(),
        },
    );

    let revalidation = server
        .mock("GET", "/api/data")
        .match_header("if-none-match", "\"tag123\"")
        .with_status(304)
        .create_async()
        .await;
    let result: Bytes = cache.get_cached(&url).await.unwrap();
    assert_eq!(result.as_ref(), b"cached");

    // Verify the request carrying the stored ETag really reached the server.
    revalidation.assert_async().await;
}
/// A pre-seeded entry with a Last-Modified value is revalidated with `If-Modified-Since`
/// and served on 304.
#[tokio::test]
async fn test_get_cached_last_modified_validation() {
    let mut server = mockito::Server::new_async().await;
    let url = format!("{}/api/data", server.url());
    let cache = HttpCache::new();
    cache.entries.insert(
        url.clone(),
        CachedResponse {
            body: Bytes::from_static(b"cached"),
            etag: None,
            last_modified: Some("Wed, 21 Oct 2024 07:28:00 GMT".into()),
            fetched_at: Instant::now(),
        },
    );

    let revalidation = server
        .mock("GET", "/api/data")
        .match_header("if-modified-since", "Wed, 21 Oct 2024 07:28:00 GMT")
        .with_status(304)
        .create_async()
        .await;
    let result: Bytes = cache.get_cached(&url).await.unwrap();
    assert_eq!(result.as_ref(), b"cached");

    // Verify the request carrying the stored timestamp really reached the server.
    revalidation.assert_async().await;
}
/// When revalidation fails, `get_cached` returns the stale cached body instead of an error.
/// The request to this host is expected to fail.
#[tokio::test]
async fn test_get_cached_network_error_fallback() {
    let url = "http://invalid.localhost.test/data";
    let cache = HttpCache::new();
    let stale = CachedResponse {
        body: Bytes::from_static(b"stale data"),
        etag: Some(String::from("\"old\"")),
        last_modified: None,
        fetched_at: Instant::now(),
    };
    cache.entries.insert(url.to_owned(), stale);

    let served: Bytes = cache.get_cached(url).await.unwrap();
    assert_eq!(served.as_ref(), b"stale data");
}
/// A 404 from the server surfaces as `DepsError::CacheError` mentioning the status code.
#[tokio::test]
async fn test_fetch_and_store_http_error() {
    let mut server = mockito::Server::new_async().await;
    let url = format!("{}/api/missing", server.url());
    let _mock = server
        .mock("GET", "/api/missing")
        .with_status(404)
        .with_body("Not Found")
        .create_async()
        .await;

    let cache = HttpCache::new();
    let outcome: Result<Bytes> = cache.fetch_and_store(&url).await;

    assert!(outcome.is_err());
    if let Err(DepsError::CacheError(msg)) = outcome {
        assert!(msg.contains("404"));
    } else {
        panic!("Expected CacheError");
    }
}
/// `fetch_and_store` keeps the response's ETag and Last-Modified values in the cache entry.
#[tokio::test]
async fn test_fetch_and_store_stores_headers() {
    let mut server = mockito::Server::new_async().await;
    let url = format!("{}/api/data", server.url());
    let _mock = server
        .mock("GET", "/api/data")
        .with_status(200)
        .with_header("etag", "\"abc123\"")
        .with_header("last-modified", "Wed, 21 Oct 2024 07:28:00 GMT")
        .with_body("test")
        .create_async()
        .await;

    let cache = HttpCache::new();
    let _: Bytes = cache.fetch_and_store(&url).await.unwrap();

    let stored = cache.entries.get(&url).unwrap();
    assert_eq!(stored.etag.as_deref(), Some("\"abc123\""));
    assert_eq!(
        stored.last_modified.as_deref(),
        Some("Wed, 21 Oct 2024 07:28:00 GMT")
    );
}
}