use std::collections::HashMap;
use std::ffi::OsStr;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use anyhow::{Context, Result};
use tokio::sync::Notify;
use walkdir::WalkDir;
/// Format a timestamp as an RFC 7231 HTTP-date (used for `Expires` headers).
fn http_date(t: std::time::SystemTime) -> String {
    let utc: chrono::DateTime<chrono::Utc> = t.into();
    utc.format("%a, %d %b %Y %H:%M:%S GMT").to_string()
}
/// A single downloaded trace tracked by the on-disk cache.
struct CacheEntry {
    /// Absolute path of the cached file on disk.
    path: PathBuf,
    /// File size in bytes, counted against the disk limit.
    size: u64,
    /// Number of outstanding `TraceFile` handles pinning this entry; only
    /// entries with a count of zero may be evicted.
    lock_count: u32,
}
/// Mutable cache state, guarded by the mutex in `DownloaderState`.
struct DownloaderInner {
    /// Cache entries keyed by the trace's path relative to the download dir.
    entries: HashMap<String, CacheEntry>,
    /// Sum of `size` over all entries, compared against the disk limit.
    total_bytes: u64,
}
/// State shared between the downloader and outstanding `TraceFile` handles.
struct DownloaderState {
    inner: Mutex<DownloaderInner>,
    /// Signalled whenever cache space may have been freed (a handle dropped
    /// or an entry evicted), waking tasks blocked in `acquire_capacity`.
    notify: Notify,
}
/// Marker error returned when an upload is rejected with HTTP 403, so
/// callers can distinguish permission failures from transport errors.
#[derive(Debug)]
pub struct UploadForbidden;
impl std::fmt::Display for UploadForbidden {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("HTTP 403 Forbidden")
    }
}
impl std::error::Error for UploadForbidden {}
/// Downloads trace files over HTTP and caches them on disk, optionally
/// bounded by a total-size limit with random eviction of unpinned entries.
pub struct TraceDownloader {
    /// Directory holding downloaded traces (and the cache index root).
    download_dir: PathBuf,
    /// Maximum total bytes to keep on disk; 0 disables the limit.
    disk_limit: u64,
    /// Optional bearer token attached to authenticated requests.
    jwt: Option<String>,
    /// If set, traces found under this local tree are served directly
    /// without caching or downloading.
    local_base: Option<PathBuf>,
    /// Client for API-style requests (default redirect handling).
    client: reqwest::Client,
    /// Client for file downloads; redirects are disabled here and chased
    /// manually so credentials can be dropped after the first hop.
    download_client: reqwest::Client,
    state: Arc<DownloaderState>,
}
/// Handle to a trace file on disk. While alive it pins the corresponding
/// cache entry (if any) so the file cannot be evicted.
pub struct TraceFile {
    path: PathBuf,
    /// Cache key plus shared state used to release the pin on drop; `None`
    /// for traces served straight from the local tree (never cached).
    cache_ref: Option<(String, Arc<DownloaderState>)>,
}
impl TraceFile {
    /// Filesystem path of the trace backing this handle.
    pub fn path(&self) -> &Path {
        self.path.as_path()
    }
}
impl Drop for TraceFile {
    /// Release this handle's pin on its cache entry and wake any tasks that
    /// are waiting for disk capacity.
    fn drop(&mut self) {
        if let Some((key, state)) = &self.cache_ref {
            let mut inner = state.inner.lock().unwrap();
            if let Some(entry) = inner.entries.get_mut(key) {
                // saturating_sub: defensive — the entry may have been
                // force-removed and re-inserted while this handle was alive.
                entry.lock_count = entry.lock_count.saturating_sub(1);
            }
            // Release the lock before notifying so woken tasks can take it.
            drop(inner);
            state.notify.notify_waiters();
        }
    }
}
/// Issue a HEAD request (with optional bearer auth) and return the response
/// headers, failing on any non-success status.
async fn head_headers(
    client: &reqwest::Client,
    jwt: Option<&str>,
    url: &str,
) -> Result<reqwest::header::HeaderMap> {
    let mut request = client.head(url);
    if let Some(token) = jwt {
        request = request.header("Authorization", format!("Bearer {token}"));
    }
    let resp = request.send().await.with_context(|| format!("HEAD {url}"))?;
    anyhow::ensure!(resp.status().is_success(), "HEAD {url}: HTTP {}", resp.status());
    Ok(resp.headers().clone())
}
/// Canonicalize an HTTP ETag for comparison: strip surrounding whitespace,
/// a weak-validator `W/` prefix, and quotes, then lowercase the hex.
fn normalize_etag(raw: &str) -> String {
    let trimmed = raw.trim();
    let without_weak = trimmed.strip_prefix("W/").unwrap_or(trimmed);
    without_weak.trim_matches('"').to_lowercase()
}
/// HEAD `url` and return its normalized ETag, or `None` when the server does
/// not send one (or sends a non-ASCII value).
async fn fetch_server_etag(
    client: &reqwest::Client,
    jwt: Option<&str>,
    url: &str,
) -> Result<Option<String>> {
    let headers = head_headers(client, jwt, url).await?;
    Ok(headers
        .get(reqwest::header::ETAG)
        .and_then(|v| v.to_str().ok())
        .map(normalize_etag))
}
/// Guess the S3 multipart chunk size that produced a `<md5>-<parts>` ETag
/// for a file of `file_size` bytes.
///
/// Tries the common 5 MiB and 10 MiB part sizes first; if neither yields
/// exactly `num_parts` parts, falls back to the smallest whole-MiB chunk
/// size larger than `file_size / num_parts`.
fn guess_s3_chunk_size(file_size: u64, num_parts: usize) -> u64 {
    const MB: u64 = 1024 * 1024;
    // Guard: a malformed ETag such as "abcd-0" would otherwise hit a
    // divide-by-zero panic in the fallback below. Any non-zero chunk size is
    // safe here — the recomputed local ETag simply won't match the server's.
    if num_parts == 0 {
        return MB;
    }
    for &mb in &[5u64, 10] {
        let bytes = mb * MB;
        if file_size.div_ceil(bytes) == num_parts as u64 {
            return bytes;
        }
    }
    (file_size / num_parts as u64 / MB + 1) * MB
}
/// Compute the ETag of the local file at `path` so it can be compared with
/// the server's.
///
/// An S3 multipart ETag has the form `<hex>-<parts>`; in that case the file
/// is re-chunked with the guessed part size and the MD5-of-concatenated-MD5s
/// is recomputed. Otherwise the ETag is a plain MD5 of the whole file.
fn compute_local_etag(path: &Path, server_etag: &str) -> Result<String> {
    let data = std::fs::read(path).with_context(|| format!("reading {}", path.display()))?;
    let num_parts = server_etag
        .rsplit_once('-')
        .and_then(|(_, n_str)| n_str.parse::<usize>().ok());
    let Some(num_parts) = num_parts else {
        // Single-part upload: the ETag is just the MD5 of the content.
        return Ok(format!("{:x}", md5::compute(&data)));
    };
    let chunk_size = guess_s3_chunk_size(data.len() as u64, num_parts) as usize;
    let part_digests: Vec<[u8; 16]> = data
        .chunks(chunk_size)
        .map(|chunk| md5::compute(chunk).0)
        .collect();
    let concatenated: Vec<u8> = part_digests.iter().flatten().copied().collect();
    Ok(format!(
        "{:x}-{}",
        md5::compute(&concatenated),
        part_digests.len()
    ))
}
impl TraceDownloader {
/// Block until the cache is below `disk_limit`, evicting randomly-chosen
/// unpinned entries as needed. A limit of 0 means "unlimited" and returns
/// immediately. `trace_path` is used only for log messages.
async fn acquire_capacity(&self, trace_path: &str) {
    if self.disk_limit == 0 {
        return;
    }
    loop {
        // Register interest in the notification *before* inspecting state so
        // a wakeup between the check and the await below cannot be lost.
        let notified = self.state.notify.notified();
        tokio::pin!(notified);
        notified.as_mut().enable();
        let has_room = {
            let mut inner = self.state.inner.lock().unwrap();
            while inner.total_bytes >= self.disk_limit {
                let unlocked: Vec<String> = inner
                    .entries
                    .iter()
                    .filter(|(_, e)| e.lock_count == 0)
                    .map(|(k, _)| k.clone())
                    .collect();
                if unlocked.is_empty() {
                    // Everything is pinned; wait for a TraceFile to drop.
                    break;
                }
                // Evict one randomly-chosen unpinned entry.
                let key = unlocked[rand::random::<u64>() as usize % unlocked.len()].clone();
                let entry = inner.entries.remove(&key).unwrap();
                inner.total_bytes = inner.total_bytes.saturating_sub(entry.size);
                if let Err(e) = std::fs::remove_file(&entry.path) {
                    log::warn!(
                        "Failed to remove evicted cache file {}: {e}",
                        entry.path.display()
                    );
                }
            }
            inner.total_bytes < self.disk_limit
        };
        if has_room {
            return;
        }
        log::debug!("Waiting for disk space before downloading {trace_path}");
        notified.await;
    }
}
/// Create a downloader rooted at `download_dir`, rebuilding the cache index
/// from files already present on disk (skipped entirely when a local trace
/// tree is configured, since nothing is cached in that mode).
///
/// `disk_limit` of 0 disables eviction. `jwt_path`, if given, must point to
/// a file whose (trimmed) contents are the bearer token.
pub fn new(
    download_dir: PathBuf,
    disk_limit: u64,
    jwt_path: Option<&Path>,
    local_base: Option<PathBuf>,
) -> Result<Self> {
    let mut entries = HashMap::new();
    let mut total_bytes = 0u64;
    if local_base.is_none() {
        std::fs::create_dir_all(&download_dir)
            .with_context(|| format!("creating download dir {}", download_dir.display()))?;
        // Re-index whatever survived previous runs so cross-run cache hits
        // work without re-downloading.
        for entry in WalkDir::new(&download_dir).min_depth(1) {
            let entry = match entry {
                Ok(e) => e,
                Err(_) => continue,
            };
            if !entry.file_type().is_file() {
                continue;
            }
            let path = entry.path();
            // Skip leftover partial downloads from interrupted runs.
            if path.extension() == Some(OsStr::new("tmp")) {
                continue;
            }
            // The cache key is the path relative to the download dir.
            let Some(name) = path
                .strip_prefix(&download_dir)
                .ok()
                .and_then(|r| r.to_str())
                .map(|s| s.to_string())
            else {
                continue;
            };
            if let Ok(meta) = std::fs::metadata(path) {
                let size = meta.len();
                total_bytes += size;
                entries.insert(
                    name,
                    CacheEntry {
                        path: path.to_path_buf(),
                        size,
                        lock_count: 0,
                    },
                );
            }
        }
    }
    let jwt = match jwt_path {
        Some(path) => {
            let content = std::fs::read_to_string(path)
                .with_context(|| format!("reading JWT file {}", path.display()))?;
            Some(content.trim().to_string())
        }
        None => None,
    };
    Ok(TraceDownloader {
        download_dir,
        disk_limit,
        jwt,
        local_base,
        client: reqwest::Client::new(),
        // Redirects are chased manually for downloads (see chase_and_get) so
        // credentials are not forwarded past the first hop.
        download_client: reqwest::Client::builder()
            .redirect(reqwest::redirect::Policy::none())
            .build()
            .context("building download client")?,
        state: Arc::new(DownloaderState {
            inner: Mutex::new(DownloaderInner {
                entries,
                total_bytes,
            }),
            notify: Notify::new(),
        }),
    })
}
/// Return the root directory that `path` lives under: the local trace tree
/// when `path` is inside it, otherwise the download directory.
pub fn root_for_file(&self, path: &Path) -> &Path {
    match &self.local_base {
        Some(local_base) if path.starts_with(local_base) => local_base,
        _ => &self.download_dir,
    }
}
/// GET `url` (with optional bearer auth) and return the full response body,
/// failing on any non-success status.
pub async fn fetch_bytes(&self, url: &str) -> Result<Vec<u8>> {
    let mut request = self.client.get(url);
    if let Some(token) = &self.jwt {
        request = request.header("Authorization", format!("Bearer {token}"));
    }
    let response = request.send().await.with_context(|| format!("GET {url}"))?;
    if !response.status().is_success() {
        anyhow::bail!("HTTP {} from {url}", response.status());
    }
    let body = response
        .bytes()
        .await
        .with_context(|| format!("reading body from {url}"))?;
    Ok(body.to_vec())
}
/// Upload `data` as `<checksum>.png` under `upload_base_url` unless an
/// object with that name already exists; returns the object URL either way.
///
/// Fails with `UploadForbidden` on HTTP 403 so callers can tell permission
/// errors apart from other failures. `expires`, if given, sets an `Expires`
/// header relative to now.
pub async fn upload_if_absent(
    &self,
    upload_base_url: &str,
    checksum: &str,
    data: &[u8],
    expires: Option<std::time::Duration>,
) -> Result<String> {
    let url = format!("{}/{}.png", upload_base_url.trim_end_matches('/'), checksum);
    let mut head_req = self.client.head(&url);
    if let Some(token) = &self.jwt {
        head_req = head_req.header("Authorization", format!("Bearer {token}"));
    }
    // Best-effort existence probe: on transport error, just upload anyway.
    let already_exists = match head_req.send().await {
        Ok(r) => r.status().is_success(),
        Err(_) => false,
    };
    if already_exists {
        log::debug!("Skipping upload of {url} (already exists)");
        return Ok(url);
    }
    log::debug!("Uploading {url}");
    // The PUT targets the directory URL with a multipart form; the part's
    // filename determines the stored object name.
    let dir_url = format!("{}/", upload_base_url.trim_end_matches('/'));
    let checksum_png = format!("{checksum}.png");
    let part = reqwest::multipart::Part::bytes(data.to_vec()).file_name(checksum_png);
    let form = reqwest::multipart::Form::new().part("file", part);
    let mut put_req = self
        .client
        .put(&dir_url)
        .header("x-amz-acl", "public-read-write")
        .multipart(form);
    if let Some(token) = &self.jwt {
        put_req = put_req.header("Authorization", format!("Bearer {token}"));
    }
    if let Some(ttl) = expires {
        put_req = put_req.header("Expires", http_date(std::time::SystemTime::now() + ttl));
    }
    let response = put_req
        .send()
        .await
        .with_context(|| format!("PUT {dir_url}"))?;
    if !response.status().is_success() {
        let status = response.status();
        if status == reqwest::StatusCode::FORBIDDEN {
            return Err(anyhow::Error::new(UploadForbidden));
        }
        let body = response
            .text()
            .await
            .unwrap_or_else(|_| "(unreadable body)".to_string());
        anyhow::bail!("HTTP {} uploading {}: {}", status, url, body);
    }
    Ok(url)
}
/// Return a handle to `trace_path`, preferring (in order) the local trace
/// tree, the on-disk cache, and finally a fresh download from
/// `download_url`. The returned handle pins the cache entry until dropped.
pub async fn get(&self, download_url: &str, trace_path: &str) -> Result<TraceFile> {
    if let Some(local_base) = &self.local_base {
        let path = local_base.join(trace_path);
        // NOTE(review): the second clause treats the path as "existing" when
        // its parent exists but is not a directory — presumably so the
        // eventual open() error surfaces rather than silently falling back
        // to a download. TODO confirm intent.
        let exists = path.exists()
            || path
                .parent()
                .and_then(|p| p.metadata().ok())
                .is_some_and(|m| !m.is_dir());
        if exists {
            log::debug!("Using local trace: {}", path.display());
            return Ok(TraceFile {
                path,
                cache_ref: None,
            });
        }
        log::debug!(
            "Local trace not found in traces-db, falling back to download: {}",
            path.display()
        );
    }
    let filename = trace_path.trim_start_matches('/').to_string();
    // Fast path: already cached. Pin the entry before releasing the lock.
    {
        let mut inner = self.state.inner.lock().unwrap();
        if let Some(entry) = inner.entries.get_mut(&filename) {
            entry.lock_count += 1;
            let path = entry.path.clone();
            log::debug!("Serving {} from cache: {}", trace_path, path.display());
            return Ok(TraceFile {
                path,
                cache_ref: Some((filename, Arc::clone(&self.state))),
            });
        }
    }
    let url = format!(
        "{}/{}",
        download_url.trim_end_matches('/'),
        trace_path.trim_start_matches('/')
    );
    // May evict unpinned entries or block until capacity frees up.
    self.acquire_capacity(trace_path).await;
    // Re-check the cache: another task may have downloaded this trace while
    // we were waiting for capacity.
    {
        let mut inner = self.state.inner.lock().unwrap();
        if let Some(entry) = inner.entries.get_mut(&filename) {
            entry.lock_count += 1;
            log::debug!("Serving {} from cache after capacity wait", trace_path);
            return Ok(TraceFile {
                path: entry.path.clone(),
                cache_ref: Some((filename, Arc::clone(&self.state))),
            });
        }
    }
    let target_path = self.download_dir.join(&filename);
    if let Some(parent) = target_path.parent() {
        std::fs::create_dir_all(parent)
            .with_context(|| format!("creating directory {}", parent.display()))?;
    }
    log::debug!("Downloading {} to {}", trace_path, target_path.display());
    download_file(
        &self.download_client,
        self.jwt.as_deref(),
        &url,
        &target_path,
    )
    .await
    .with_context(|| format!("downloading {trace_path}"))?;
    log::debug!("Downloading {} completed", trace_path);
    let size = std::fs::metadata(&target_path)
        .with_context(|| format!("stat of {}", target_path.display()))?
        .len();
    let mut inner = self.state.inner.lock().unwrap();
    // A concurrent task may have raced us and registered the entry first;
    // reuse its accounting rather than double-counting the bytes.
    if let Some(entry) = inner.entries.get_mut(&filename) {
        entry.lock_count += 1;
        return Ok(TraceFile {
            path: entry.path.clone(),
            cache_ref: Some((filename, Arc::clone(&self.state))),
        });
    }
    inner.total_bytes += size;
    inner.entries.insert(
        filename.clone(),
        CacheEntry {
            path: target_path.clone(),
            size,
            lock_count: 1,
        },
    );
    Ok(TraceFile {
        path: target_path,
        cache_ref: Some((filename, Arc::clone(&self.state))),
    })
}
pub async fn invalidate_if_corrupted(
&self,
download_url: &str,
trace_path: &str,
) -> Result<bool> {
let filename = trace_path.trim_start_matches('/').to_string();
let local_path = {
let inner = self.state.inner.lock().unwrap();
inner.entries.get(&filename).map(|e| e.path.clone())
};
let local_path = match local_path {
Some(p) => p,
None => return Ok(false),
};
let url = format!(
"{}/{}",
download_url.trim_end_matches('/'),
trace_path.trim_start_matches('/')
);
let server_etag = match fetch_server_etag(&self.client, self.jwt.as_deref(), &url).await? {
Some(etag) => etag,
None => {
log::warn!(
"{trace_path}: server returned no ETag, evicting unverifiable cache entry"
);
let mut inner = self.state.inner.lock().unwrap();
if let Some(entry) = inner.entries.remove(&filename) {
inner.total_bytes = inner.total_bytes.saturating_sub(entry.size);
}
if let Err(e) = std::fs::remove_file(&local_path) {
log::warn!(
"Failed to remove unverifiable cache file {}: {e}",
local_path.display()
);
}
drop(inner);
self.state.notify.notify_waiters();
return Ok(true);
}
};
let local_etag = compute_local_etag(&local_path, &server_etag)
.with_context(|| format!("computing ETag for {}", local_path.display()))?;
if local_etag == server_etag {
log::debug!("{trace_path}: ETag matches, cache is valid");
return Ok(false);
}
log::warn!(
"{trace_path}: ETag mismatch (local {local_etag}, server {server_etag}), \
invalidating corrupted cache entry"
);
let mut inner = self.state.inner.lock().unwrap();
if let Some(entry) = inner.entries.remove(&filename) {
inner.total_bytes = inner.total_bytes.saturating_sub(entry.size);
}
if let Err(e) = std::fs::remove_file(&local_path) {
log::warn!(
"Failed to remove invalidated cache file {}: {e}",
local_path.display()
);
}
drop(inner);
self.state.notify.notify_waiters();
Ok(true)
}
}
/// GET `url`, following up to `MAX_REDIRECTS` redirects manually.
///
/// The Authorization header is sent only on the very first request; every
/// redirect target (e.g. a pre-signed storage URL) is fetched without
/// credentials. Relative `Location` values are resolved against the
/// current URL.
///
/// Fix: restores `&current_url` where HTML-entity mojibake (`¤t_url`)
/// had corrupted the borrow expressions, which did not compile.
async fn chase_and_get(
    client: &reqwest::Client,
    jwt: Option<&str>,
    url: &str,
) -> Result<reqwest::Response> {
    const MAX_REDIRECTS: usize = 10;
    let mut current_url = url.to_string();
    let mut send_auth = true;
    for _ in 0..=MAX_REDIRECTS {
        let mut req = client.get(&current_url);
        if send_auth {
            if let Some(token) = jwt {
                req = req.header("Authorization", format!("Bearer {token}"));
            }
        }
        let resp = req
            .send()
            .await
            .with_context(|| format!("GET {current_url}"))?;
        if !resp.status().is_redirection() {
            return Ok(resp);
        }
        let location = resp
            .headers()
            .get(reqwest::header::LOCATION)
            .and_then(|v| v.to_str().ok())
            .ok_or_else(|| anyhow::anyhow!("redirect from {current_url} had no Location header"))?
            .to_string();
        // Location may be relative; resolve it against the current URL.
        let base = reqwest::Url::parse(&current_url)
            .with_context(|| format!("parsing current URL {current_url}"))?;
        let next = base
            .join(&location)
            .with_context(|| format!("resolving Location {location:?} against {current_url}"))?;
        // Never forward credentials past the first hop.
        send_auth = false;
        current_url = next.to_string();
    }
    anyhow::bail!("too many redirects from {url}");
}
/// Download `url` to `target_path`, writing through a temp file in the same
/// directory and renaming into place so readers never see a partial file.
///
/// Produces targeted error messages for auth failures (401/403) and for
/// HTML responses, which usually indicate a login-page redirect.
async fn download_file(
    client: &reqwest::Client,
    jwt: Option<&str>,
    url: &str,
    target_path: &Path,
) -> Result<()> {
    let response = chase_and_get(client, jwt, url)
        .await
        .with_context(|| format!("sending GET {url}"))?;
    let status = response.status();
    let content_type = response
        .headers()
        .get(reqwest::header::CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        .unwrap_or("")
        .to_string();
    // Remember where redirects finally landed, for diagnostics below.
    let final_url = response.url().clone();
    if !status.is_success() {
        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
            anyhow::bail!(
                "HTTP {status} downloading {url} (authentication failure - is --jwt required?)"
            );
        }
        anyhow::bail!("HTTP {} downloading {}", status, url);
    }
    // An HTML body almost certainly means we were bounced to a login page.
    if content_type.contains("text/html") {
        anyhow::bail!(
            "download from {url} returned HTML instead of a trace file \
            (authentication redirect to {final_url}? try --jwt)"
        );
    }
    // NOTE(review): the whole body is buffered in memory before writing to
    // disk; fine for moderate traces, consider streaming for very large ones.
    let bytes = response
        .bytes()
        .await
        .with_context(|| format!("reading response body from {url}"))?;
    let parent = target_path.parent().unwrap_or(Path::new("."));
    let mut tmp = tempfile::NamedTempFile::new_in(parent)
        .with_context(|| format!("creating temp file in {}", parent.display()))?;
    std::io::Write::write_all(&mut tmp, &bytes)
        .with_context(|| format!("writing to temp file for {}", target_path.display()))?;
    // Atomic rename within the same directory.
    tmp.persist(target_path)
        .map_err(|e| anyhow::anyhow!("renaming temp file to {}: {}", target_path.display(), e))?;
    Ok(())
}
#[cfg(test)]
impl TraceDownloader {
/// Test helper: register `path` in the cache as an already-pinned entry of
/// the given size and return the handle pinning it.
fn make_test_token(&self, path: PathBuf, size: u64) -> TraceFile {
    let key = path.file_name().unwrap().to_str().unwrap().to_string();
    let mut inner = self.state.inner.lock().unwrap();
    inner.entries.insert(
        key.clone(),
        CacheEntry {
            path: path.clone(),
            size,
            lock_count: 1,
        },
    );
    inner.total_bytes += size;
    TraceFile {
        path,
        cache_ref: Some((key, Arc::clone(&self.state))),
    }
}
/// Test helper: like `make_test_token`, but first waits for cache capacity
/// (possibly blocking or evicting), sizing the entry from the file on disk.
async fn wait_and_make_token(&self, path: PathBuf) -> TraceFile {
    let key = path.file_name().unwrap().to_str().unwrap().to_string();
    let size = std::fs::metadata(&path).unwrap().len();
    // The display string only feeds acquire_capacity's log messages.
    self.acquire_capacity(&path.display().to_string()).await;
    let mut inner = self.state.inner.lock().unwrap();
    inner.entries.insert(
        key.clone(),
        CacheEntry {
            path: path.clone(),
            size,
            lock_count: 1,
        },
    );
    inner.total_bytes += size;
    TraceFile {
        path,
        cache_ref: Some((key, Arc::clone(&self.state))),
    }
}
}
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicBool, AtomicU32, Ordering},
    };
    use bytes::Bytes;
    use http_body_util::Full;
    use hyper::body::Incoming;
    use hyper::{Method, Request, Response};
    use hyper_util::rt::{TokioExecutor, TokioIo};
    use hyper_util::server::conn::auto::Builder as ConnBuilder;
    use std::convert::Infallible;
    use tokio::net::TcpListener;
    use super::*;
    /// Spawn an HTTP server answering every request with `status` (and an
    /// optional content-type) and an empty body; returns its base URL.
    async fn start_static_server(status: u16, content_type: Option<&'static str>) -> String {
        // Bind synchronously so the port is known before the task starts.
        let std_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        std_listener.set_nonblocking(true).unwrap();
        let port = std_listener.local_addr().unwrap().port();
        tokio::spawn(async move {
            let listener = TcpListener::from_std(std_listener).unwrap();
            let http = Arc::new(ConnBuilder::new(TokioExecutor::new()));
            loop {
                let (stream, _) = match listener.accept().await {
                    Ok(x) => x,
                    Err(_) => break,
                };
                let http = http.clone();
                tokio::spawn(async move {
                    let svc =
                        hyper::service::service_fn(move |_req: Request<Incoming>| async move {
                            let mut builder = Response::builder().status(status);
                            if let Some(ct) = content_type {
                                builder = builder.header("content-type", ct);
                            }
                            Ok::<_, Infallible>(builder.body(Full::new(Bytes::new())).unwrap())
                        });
                    let _ = http.serve_connection(TokioIo::new(stream), svc).await;
                });
            }
        });
        format!("http://127.0.0.1:{port}")
    }
    /// Spawn a server that serves `body` with a matching MD5 ETag on any
    /// path, counting GET (but not HEAD) requests in `get_count`.
    async fn start_mock_file_server(body: Bytes, get_count: Arc<AtomicU32>) -> String {
        let etag = format!("{:x}", md5::compute(&body));
        let std_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        std_listener.set_nonblocking(true).unwrap();
        let port = std_listener.local_addr().unwrap().port();
        tokio::spawn(async move {
            let listener = TcpListener::from_std(std_listener).unwrap();
            let http = Arc::new(ConnBuilder::new(TokioExecutor::new()));
            loop {
                let (stream, _) = match listener.accept().await {
                    Ok(x) => x,
                    Err(_) => break,
                };
                let http = Arc::clone(&http);
                let body = body.clone();
                let etag = etag.clone();
                let get_count = Arc::clone(&get_count);
                tokio::spawn(async move {
                    let svc = hyper::service::service_fn(move |req: Request<Incoming>| {
                        let body = body.clone();
                        let etag = etag.clone();
                        let get_count = Arc::clone(&get_count);
                        async move {
                            let builder = Response::builder()
                                .status(200)
                                .header("content-type", "application/octet-stream")
                                .header("content-length", body.len())
                                .header("etag", format!("\"{etag}\""));
                            let response_body = if req.method() == Method::HEAD {
                                Bytes::new()
                            } else {
                                // Only count actual downloads, not probes.
                                get_count.fetch_add(1, Ordering::SeqCst);
                                body
                            };
                            Ok::<_, Infallible>(builder.body(Full::new(response_body)).unwrap())
                        }
                    });
                    let _ = http.serve_connection(TokioIo::new(stream), svc).await;
                });
            }
        });
        format!("http://127.0.0.1:{port}")
    }
    /// A 401 download should surface the status and suggest --jwt.
    #[tokio::test]
    async fn download_401_reports_auth_error() {
        let tmpdir = tempfile::tempdir().unwrap();
        let base_url = start_static_server(401, None).await;
        let downloader = TraceDownloader::new(tmpdir.path().to_path_buf(), 0, None, None).unwrap();
        let msg = match downloader.get(&base_url, "trace.mock-trace").await {
            Ok(_) => panic!("expected an error on HTTP 401"),
            Err(e) => format!("{e:#}"),
        };
        assert!(msg.contains("401"), "error should mention HTTP 401: {msg}");
        assert!(msg.contains("--jwt"), "error should suggest --jwt: {msg}");
    }
    /// An HTML 200 response (login-page redirect) should also suggest --jwt.
    #[tokio::test]
    async fn download_html_response_reports_auth_error() {
        let tmpdir = tempfile::tempdir().unwrap();
        let base_url = start_static_server(200, Some("text/html; charset=utf-8")).await;
        let downloader = TraceDownloader::new(tmpdir.path().to_path_buf(), 0, None, None).unwrap();
        let msg = match downloader.get(&base_url, "trace.mock-trace").await {
            Ok(_) => panic!("expected an error when server returns HTML"),
            Err(e) => format!("{e:#}"),
        };
        assert!(
            msg.to_lowercase().contains("html"),
            "error should mention HTML response: {msg}"
        );
        assert!(msg.contains("--jwt"), "error should suggest --jwt: {msg}");
    }
    /// With the cache full of pinned entries, acquire_capacity must block
    /// until a handle drops, then evict the unpinned file.
    #[tokio::test]
    async fn disk_limit_blocks_until_token_dropped() {
        let tmpdir = tempfile::tempdir().unwrap();
        let downloader =
            Arc::new(TraceDownloader::new(tmpdir.path().to_path_buf(), 100, None, None).unwrap());
        let held = tmpdir.path().join("held");
        std::fs::write(&held, vec![0u8; 100]).unwrap();
        let token1 = downloader.make_test_token(held.clone(), 100);
        let waiting = tmpdir.path().join("waiting");
        std::fs::write(&waiting, vec![0u8; 10]).unwrap();
        let d2 = Arc::clone(&downloader);
        let w2 = waiting.clone();
        let unblocked = Arc::new(AtomicBool::new(false));
        let unblocked_bg = Arc::clone(&unblocked);
        let handle = tokio::spawn(async move {
            let token = d2.wait_and_make_token(w2).await;
            unblocked_bg.store(true, Ordering::SeqCst);
            token
        });
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        assert!(
            !unblocked.load(Ordering::SeqCst),
            "task should be blocked on disk limit"
        );
        drop(token1);
        let token2 = handle.await.expect("background task panicked");
        assert!(
            unblocked.load(Ordering::SeqCst),
            "task should have unblocked after token drop"
        );
        assert!(
            !held.exists(),
            "held file should be evicted when it is unlocked and space is needed"
        );
        assert!(
            waiting.exists(),
            "waiting file should still exist while token2 is held"
        );
        drop(token2);
        assert!(
            waiting.exists(),
            "waiting file should remain in the cache after token2 is dropped"
        );
    }
    /// A second get() in the same process must hit the in-memory cache.
    #[tokio::test]
    async fn cached_file_not_redownloaded_in_same_run() {
        let tmpdir = tempfile::tempdir().unwrap();
        let trace_path = "subdir/app.mock-trace";
        let get_count = Arc::new(AtomicU32::new(0));
        let base_url =
            start_mock_file_server(Bytes::from_static(b"trace body"), Arc::clone(&get_count)).await;
        let downloader = TraceDownloader::new(tmpdir.path().to_path_buf(), 0, None, None).unwrap();
        let t1 = downloader.get(&base_url, trace_path).await.unwrap();
        assert_eq!(
            get_count.load(Ordering::SeqCst),
            1,
            "first get should download"
        );
        drop(t1);
        let _t2 = downloader.get(&base_url, trace_path).await.unwrap();
        assert_eq!(
            get_count.load(Ordering::SeqCst),
            1,
            "second get in the same run should be served from cache without a new GET"
        );
    }
    /// A fresh downloader instance must re-index and reuse files left on
    /// disk by a previous run.
    #[tokio::test]
    async fn cross_run_cache_reused_when_valid() {
        let tmpdir = tempfile::tempdir().unwrap();
        let trace_path = "subdir/app.mock-trace";
        let get_count = Arc::new(AtomicU32::new(0));
        let base_url =
            start_mock_file_server(Bytes::from_static(b"trace body"), Arc::clone(&get_count)).await;
        let d1 = TraceDownloader::new(tmpdir.path().to_path_buf(), 0, None, None).unwrap();
        let t1 = d1.get(&base_url, trace_path).await.unwrap();
        assert_eq!(get_count.load(Ordering::SeqCst), 1);
        drop(t1);
        drop(d1);
        let d2 = TraceDownloader::new(tmpdir.path().to_path_buf(), 0, None, None).unwrap();
        let _t2 = d2.get(&base_url, trace_path).await.unwrap();
        assert_eq!(
            get_count.load(Ordering::SeqCst),
            1,
            "valid cross-run cache should be served without re-downloading"
        );
    }
    /// Corrupting the cached bytes must be detected by ETag comparison and
    /// trigger eviction plus a clean re-download.
    #[tokio::test]
    async fn content_corrupted_cache_invalidated() {
        let tmpdir = tempfile::tempdir().unwrap();
        let trace_path = "subdir/app.mock-trace";
        let trace_body: &[u8] = b"trace body contents";
        let get_count = Arc::new(AtomicU32::new(0));
        let base_url =
            start_mock_file_server(Bytes::from_static(trace_body), Arc::clone(&get_count)).await;
        let d1 = TraceDownloader::new(tmpdir.path().to_path_buf(), 0, None, None).unwrap();
        let t1 = d1.get(&base_url, trace_path).await.unwrap();
        let cached_path = t1.path().to_path_buf();
        drop(t1);
        drop(d1);
        // Flip a byte in the cached file to simulate on-disk corruption.
        let mut data = std::fs::read(&cached_path).unwrap();
        data[0] ^= 0xff;
        std::fs::write(&cached_path, &data).unwrap();
        let d2 = TraceDownloader::new(tmpdir.path().to_path_buf(), 0, None, None).unwrap();
        let t2 = d2.get(&base_url, trace_path).await.unwrap();
        assert_eq!(
            get_count.load(Ordering::SeqCst),
            1,
            "corruption not yet detected: get() serves from cache without checking"
        );
        let invalidated = d2
            .invalidate_if_corrupted(&base_url, trace_path)
            .await
            .unwrap();
        assert!(invalidated, "ETag mismatch should trigger invalidation");
        drop(t2);
        let t3 = d2.get(&base_url, trace_path).await.unwrap();
        assert_eq!(
            get_count.load(Ordering::SeqCst),
            2,
            "re-download after invalidation"
        );
        assert_eq!(
            std::fs::read(t3.path()).unwrap(),
            trace_body,
            "re-downloaded file should have correct content"
        );
    }
}