use std::fs;
use std::fs::File;
use std::io;
use std::path::PathBuf;
use std::sync::Arc;
use anyhow::Context;
use anyhow::Result;
use anyhow::bail;
use futures::FutureExt;
use http::HeaderMap;
use http::Response;
use http::StatusCode;
use http::Version;
use http::response::Parts;
use http_body::Body;
use http_cache_semantics::CachePolicy;
use serde::Deserialize;
use serde::Serialize;
use tracing::debug;
use super::StoredResponse;
use crate::body::CacheBody;
use crate::runtime;
use crate::storage::CacheStorage;
/// Version component of the on-disk layout; every path helper in this file
/// pushes it as the first segment beneath the storage root.
const STORAGE_VERSION: &str = "v1";
/// Name of the directory, beneath the version directory, that holds one
/// response metadata file per cache key.
const RESPONSE_DIRECTORY_NAME: &str = "responses";
/// Name of the directory, beneath the version directory, that holds response
/// body files named by their digest.
const CONTENT_DIRECTORY_NAME: &str = "content";
/// Name of the directory, beneath the version directory, that is handed to
/// `CacheBody::from_caching_upstream` as its temporary directory.
const TEMP_DIRECTORY_NAME: &str = "tmp";
/// Borrowed view of a cache entry's metadata, used when writing a response file.
///
/// This is the serialization counterpart of `CachedResponse`: `write_response`
/// encodes this type and `read_response` decodes a `CachedResponse` from the same
/// file. Keep the fields of the two definitions in the same order and with
/// matching serde attributes so that they describe the same layout.
#[derive(Serialize)]
struct CachedResponseRef<'a> {
/// Status code of the response.
#[serde(with = "http_serde::status_code")]
status: StatusCode,
/// HTTP version of the response.
#[serde(with = "http_serde::version")]
version: Version,
/// Headers of the response.
#[serde(with = "http_serde::header_map")]
headers: &'a HeaderMap,
/// Digest of the response body; it names the body file in the content directory.
digest: &'a str,
/// Cache policy recorded alongside the response.
policy: &'a CachePolicy,
}
/// Owned cache entry metadata, as decoded from a response file.
///
/// This is the deserialization counterpart of `CachedResponseRef`; the two must
/// keep the same field order and serde attributes because one is decoded from
/// what the other encoded.
#[derive(Deserialize)]
struct CachedResponse {
/// Status code of the stored response.
#[serde(with = "http_serde::status_code")]
status: StatusCode,
/// HTTP version of the stored response.
#[serde(with = "http_serde::version")]
version: Version,
/// Headers of the stored response.
#[serde(with = "http_serde::header_map")]
headers: HeaderMap,
/// Digest of the stored body; it names the body file in the content directory.
digest: String,
/// Cache policy that was recorded with the response.
policy: CachePolicy,
}
/// File-system backed implementation of `CacheStorage`.
///
/// Entries live beneath a root directory in a versioned layout: response
/// metadata files named by cache key, body files named by digest, and a
/// temporary directory. The state sits behind an `Arc`, so clones of this type
/// share the same root directory.
#[derive(Clone)]
pub struct DefaultCacheStorage(Arc<DefaultCacheStorageInner>);
impl DefaultCacheStorage {
    /// Creates a storage rooted at `root_dir`.
    ///
    /// This only records the root path; the directories beneath it are created
    /// on demand by the operations that write to the cache.
    pub fn new(root_dir: impl Into<PathBuf>) -> Self {
        let root: PathBuf = root_dir.into();
        let inner = DefaultCacheStorageInner(root);
        Self(Arc::new(inner))
    }
}
impl CacheStorage for DefaultCacheStorage {
async fn get<B: Body + Send>(&self, key: &str) -> Result<Option<StoredResponse<B>>> {
let cached = match self.0.read_response(key).await? {
Some(response) => response,
None => return Ok(None),
};
let path = self.body_path(&cached.digest);
let body = match runtime::File::open(&path)
.await
.map(Some)
.or_else(|e| {
if e.kind() == io::ErrorKind::NotFound {
Ok(None)
} else {
Err(e)
}
})
.with_context(|| {
format!(
"failed to open response body `{path}`",
path = path.display()
)
})? {
Some(file) => file,
None => return Ok(None),
};
let mut builder = Response::builder()
.version(cached.version)
.status(cached.status);
let headers = builder.headers_mut().expect("should be valid");
headers.extend(cached.headers);
Ok(Some(StoredResponse {
response: builder
.body(CacheBody::from_file(body).await.with_context(|| {
format!(
"failed to create response body for `{path}`",
path = path.display()
)
})?)
.expect("should be valid"),
policy: cached.policy,
digest: cached.digest,
}))
}
async fn put(
&self,
key: &str,
parts: &Parts,
policy: &CachePolicy,
digest: &str,
) -> Result<()> {
self.0
.write_response(
key,
CachedResponseRef {
status: parts.status,
version: parts.version,
headers: &parts.headers,
digest,
policy,
},
)
.await
}
/// Stores a new response under `key` while handing the body back to the caller.
///
/// The upstream `body` is wrapped with `CacheBody::from_caching_upstream`, which is
/// given the temporary directory and a completion callback. The callback receives
/// the body's digest and the path of the temporary file, moves that file to its
/// content path, and then writes the response metadata file for `key`.
///
/// NOTE(review): presumably the callback runs only once the returned body has been
/// read to completion; that is decided by `CacheBody`, not by this function.
///
/// # Errors
///
/// Returns an error if the temporary directory cannot be created or if
/// `CacheBody::from_caching_upstream` fails.
async fn store<B: Body + Send>(
&self,
key: String,
parts: Parts,
body: B,
policy: CachePolicy,
) -> Result<Response<CacheBody<B>>> {
// The callback must own its state, so it gets its own handle to the shared inner.
let inner = self.0.clone();
let temp_dir = inner.temp_dir_path();
fs::create_dir_all(&temp_dir).with_context(|| {
format!(
"failed to create temporary directory `{path}`",
path = temp_dir.display()
)
})?;
// Copy what the callback needs out of `parts`; `parts` itself is moved into the
// response returned at the end of this function.
let status = parts.status;
let version = parts.version;
let headers = parts.headers.clone();
let body = CacheBody::from_caching_upstream(body, &temp_dir, move |digest, path| {
async move {
// Move the temporary file to the content path named by its digest.
let content_path = inner.content_path(&digest);
fs::create_dir_all(content_path.parent().expect("should have parent"))
.context("failed to create content directory")?;
path.persist(&content_path).with_context(|| {
format!(
"failed to persist downloaded body to content path `{path}`",
path = content_path.display()
)
})?;
// The metadata is written only after the body file is in place.
inner
.write_response(
&key,
CachedResponseRef {
status,
version,
headers: &headers,
digest: &digest,
policy: &policy,
},
)
.await?;
debug!(key, digest, "response body stored successfully");
Ok(())
}
.boxed()
})
.await?;
Ok(Response::from_parts(parts, body))
}
/// Removes the cached response for `key`.
///
/// Deletion is done by truncation: `lock_response_exclusive` empties the
/// response file once it holds the lock, and `read_response` reports a file it
/// cannot decode as `None`. The file itself stays on disk, and is created if it
/// did not exist.
///
/// # Errors
///
/// Returns an error if the response file cannot be opened, locked or truncated.
async fn delete(&self, key: &str) -> Result<()> {
    let truncated = self.0.lock_response_exclusive(key).await?;
    // Nothing is written; closing the now-empty file is all that remains.
    drop(truncated);
    Ok(())
}
fn body_path(&self, digest: &str) -> PathBuf {
self.0.content_path(digest)
}
}
/// State shared by all clones of `DefaultCacheStorage`: the root directory under
/// which every cache path is built.
struct DefaultCacheStorageInner(PathBuf);
impl DefaultCacheStorageInner {
/// Returns the path of the response metadata file for `key`:
/// `<root>/<STORAGE_VERSION>/<RESPONSE_DIRECTORY_NAME>/<key>`.
///
/// NOTE(review): `key` is appended as-is; this assumes callers pass a key that
/// is a single relative path component — confirm against how keys are produced.
fn response_path(&self, key: &str) -> PathBuf {
    let mut path = self.0.clone();
    path.extend([STORAGE_VERSION, RESPONSE_DIRECTORY_NAME, key]);
    path
}
/// Returns the path of the body file for `digest`:
/// `<root>/<STORAGE_VERSION>/<CONTENT_DIRECTORY_NAME>/<digest>`.
fn content_path(&self, digest: &str) -> PathBuf {
    let mut path = self.0.clone();
    path.extend([STORAGE_VERSION, CONTENT_DIRECTORY_NAME, digest]);
    path
}
/// Returns the path of the temporary directory:
/// `<root>/<STORAGE_VERSION>/<TEMP_DIRECTORY_NAME>`.
fn temp_dir_path(&self) -> PathBuf {
    let mut path = self.0.clone();
    path.extend([STORAGE_VERSION, TEMP_DIRECTORY_NAME]);
    path
}
async fn read_response(&self, key: &str) -> Result<Option<CachedResponse>> {
let mut response = match self.lock_response_shared(key).await? {
Some(file) => file,
None => return Ok(None),
};
Ok(bincode::deserialize_from(&mut response)
.inspect_err(|e| {
debug!(
"failed to deserialize response file `{path}`: {e} (cache entry will be \
ignored)",
path = self.response_path(key).display()
);
})
.ok())
}
async fn write_response(&self, key: &str, response: CachedResponseRef<'_>) -> Result<()> {
let mut file = self.lock_response_exclusive(key).await?;
bincode::serialize_into(&mut file, &response)
.with_context(|| format!("failed to serialize response data for cache key `{key}`"))
.map(|_| ())
}
async fn lock_response_shared(&self, key: &str) -> Result<Option<File>> {
let path = self.response_path(key);
match fs::OpenOptions::new()
.read(true)
.open(&path)
.map(Some)
.or_else(|e| {
if e.kind() == io::ErrorKind::NotFound {
Ok(None)
} else {
Err(e)
}
})
.with_context(|| {
format!(
"failed to open response file `{path}`",
path = path.display()
)
})? {
Some(file) => {
match runtime::unwrap_task_output(
runtime::spawn_blocking(move || {
file.lock_shared()
.context("failed to acquire shared lock on response file")?;
Ok(file)
})
.await,
) {
Some(res) => res.map(Some),
None => bail!("failed to wait for file lock"),
}
}
None => Ok(None),
}
}
async fn lock_response_exclusive(&self, key: &str) -> Result<File> {
let path = self.response_path(key);
let dir = path.parent().expect("should have parent directory");
fs::create_dir_all(dir)
.with_context(|| format!("failed to create directory `{dir}`", dir = dir.display()))?;
let mut options = fs::OpenOptions::new();
options.create(true).write(true);
#[cfg(unix)]
{
use std::os::unix::fs::OpenOptionsExt;
options.mode(0o600);
}
let file = options.open(&path).with_context(|| {
format!(
"failed to create response file `{path}`",
path = path.display()
)
})?;
let file = match runtime::unwrap_task_output(
runtime::spawn_blocking(move || {
file.lock()
.context("failed to acquire exclusive lock on response file")?;
anyhow::Ok(file)
})
.await,
) {
Some(res) => res?,
None => bail!("failed to wait for file lock"),
};
file.set_len(0).with_context(|| {
format!(
"failed to truncate response file `{path}`",
path = path.display()
)
})?;
Ok(file)
}
}
// Tests for the file-system storage; they are compiled only for test builds
// with the `tokio` feature enabled.
#[cfg(all(test, feature = "tokio"))]
mod test {
use futures::StreamExt;
use http::Request;
use http_body_util::BodyDataStream;
use http_cache_semantics::CachePolicy;
use tempfile::tempdir;
use super::*;
/// Looking up a key that was never stored is a miss, not an error.
#[tokio::test]
async fn cache_miss() {
let dir = tempdir().unwrap();
let storage = DefaultCacheStorage::new(dir.path());
assert!(
storage
.get::<String>("does-not-exist")
.await
.expect("should not fail")
.is_none()
);
}
/// Walks one key through store, get, put, get and delete.
#[tokio::test]
async fn cache_hit() {
const KEY: &str = "key";
const BODY: &str = "hello world";
// Digest the storage is expected to report for `BODY`.
// NOTE(review): the hash algorithm is chosen by `CacheBody` and is not visible here.
const DIGEST: &str = "d74981efa70a0c880b8d8c1985d075dbcbf679b99a5f9914e5aaf96b831a9e24";
const HEADER_NAME: &str = "foo";
const HEADER_VALUE: &str = "bar";
let dir = tempdir().unwrap();
let storage = DefaultCacheStorage::new(dir.path());
// Nothing has been stored yet.
assert!(storage.get::<String>(KEY).await.unwrap().is_none());
// Store a response with no extra headers.
let request = Request::builder().body("").unwrap();
let response = Response::builder().body(BODY.to_string()).unwrap();
let policy: CachePolicy = CachePolicy::new(&request, &response);
let (parts, body) = response.into_parts();
let response = storage
.store(KEY.to_string(), parts, body, policy)
.await
.unwrap();
// Read the returned body to its end and drop it before looking the entry up.
let mut stream = BodyDataStream::new(response.into_body());
let data = stream.next().await.unwrap().unwrap();
assert!(stream.next().await.is_none());
assert_eq!(data, BODY);
drop(stream);
// The entry is now a hit, with the original body and no `foo` header.
let cached = storage.get::<String>(KEY).await.unwrap().unwrap();
assert!(cached.response.headers().get(HEADER_NAME).is_none());
let data = BodyDataStream::new(cached.response.into_body())
.next()
.await
.unwrap()
.unwrap();
assert_eq!(data, BODY);
assert_eq!(cached.digest, DIGEST);
// Update only the metadata: same digest, but now with a `foo` header.
let response = Response::builder()
.header(HEADER_NAME, HEADER_VALUE)
.body(BODY.to_string())
.unwrap();
let policy = CachePolicy::new(&request, &response);
let (parts, _) = response.into_parts();
storage.put(KEY, &parts, &policy, DIGEST).await.unwrap();
// The new header is visible and the body is unchanged.
let cached = storage.get::<String>(KEY).await.unwrap().unwrap();
assert_eq!(
cached
.response
.headers()
.get(HEADER_NAME)
.map(|v| v.to_str().unwrap()),
Some(HEADER_VALUE)
);
let data = BodyDataStream::new(cached.response.into_body())
.next()
.await
.unwrap()
.unwrap();
assert_eq!(data, BODY);
assert_eq!(cached.digest, DIGEST);
// After deletion the key is a miss again.
storage.delete(KEY).await.unwrap();
assert!(storage.get::<String>(KEY).await.unwrap().is_none());
}
}