use anyhow::{Context as _, Result};
use chrono::prelude::*;
use hyper::{
body::Bytes,
client::{connect::dns::GaiResolver, HttpConnector},
header::HeaderValue,
header::USER_AGENT,
Body, Client, Request, Response,
};
use hyper_rustls::HttpsConnector;
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::{
collections::{hash_map::DefaultHasher, HashMap},
hash::{Hash, Hasher},
path::PathBuf,
};
use tokio::runtime::Runtime;
pub(crate) fn request<T>(
request: Request<()>,
) -> Result<T>
where
T: serde::de::DeserializeOwned,
{
RUNTIME.block_on(async move {
if should_clear_cache() {
clear_cache().ok();
}
let mut request = request.map(|_| Body::empty());
request
.headers_mut()
.insert(USER_AGENT, HeaderValue::from_static("todo-or-die"));
let hash = hash_request(&request);
let response = if let Some(cached_response) =
cached_response(&hash).context("Failed to read cached response")?
{
cached_response
} else {
execute_request_and_cache_response(request, &hash).await?
};
if !response.status().is_success() {
let body = String::from_utf8_lossy(response.body());
anyhow::bail!(
"Received non-success response. status={}, body={:?}",
response.status(),
body
);
}
let value =
serde_json::from_slice::<T>(&*response.body()).context("Failed to parse response")?;
Ok(value)
})
}
/// Sends `request` over the network, buffers the whole response body, and
/// stores the response in the on-disk cache under `hash` when caching is
/// enabled.
///
/// The one-second timeout wraps only the future returned by
/// `http_client().request(..)`; reading the body afterwards is a separate
/// await that is not covered by it.
///
/// # Errors
///
/// Returns an error if the request times out or fails, if the body cannot be
/// read, or if writing the cache entry fails.
async fn execute_request_and_cache_response(
    request: Request<Body>,
    hash: &RequestHash,
) -> Result<Response<Bytes>> {
    let response = tokio::time::timeout(
        std::time::Duration::from_secs(1),
        http_client().request(request),
    )
    .await
    // Outer `Result`: the timeout elapsed before the request future resolved.
    .context("HTTP request timed out")?
    // Inner `Result`: the request itself failed.
    .context("HTTP request failed")?;

    // Collect the streaming body into a single `Bytes` buffer so the response
    // can be both cached and returned.
    let (parts, body) = response.into_parts();
    let body = hyper::body::to_bytes(body)
        .await
        .context("Failed to read response")?;
    let response = Response::from_parts(parts, body);

    if caching_enabled() {
        cache_response(hash, &response).context("Failed to cache response")?;
    }

    Ok(response)
}
/// Shared Tokio runtime used to drive the async HTTP code from synchronous
/// callers.
///
/// It is built lazily on first access as a current-thread runtime with
/// `enable_all()` applied. First access panics if the runtime cannot be built.
static RUNTIME: Lazy<Runtime> = Lazy::new(|| {
    let mut builder = tokio::runtime::Builder::new_current_thread();
    builder.enable_all();
    builder.build().expect("failed to build tokio runtime")
});
/// Concrete type of the hyper client returned by `http_client()`: a client
/// with a `Body` request body whose connector is a `hyper_rustls`
/// `HttpsConnector` wrapping hyper's `HttpConnector` with the `GaiResolver`.
type HyperTlsClient = Client<HttpsConnector<HttpConnector<GaiResolver>>, Body>;
/// Returns the process-wide HTTPS client, constructing it on first use.
///
/// The TLS configuration offers the `h2` and `http/1.1` protocols and trusts
/// the root certificates from `webpki_roots::TLS_SERVER_ROOTS`.
fn http_client() -> &'static HyperTlsClient {
    static CLIENT: Lazy<HyperTlsClient> = Lazy::new(|| {
        let mut tls_config = rustls::ClientConfig::new();
        tls_config.set_protocols(&["h2".into(), "http/1.1".into()]);
        tls_config
            .root_store
            .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);

        let mut connector = hyper::client::HttpConnector::new();
        // Turn off the plain connector's `http`-scheme enforcement; it sits
        // underneath the HTTPS connector built below.
        connector.enforce_http(false);

        let https = hyper_rustls::HttpsConnector::from((connector, tls_config));
        hyper::Client::builder().build::<_, Body>(https)
    });

    // `&Lazy<T>` deref-coerces to `&T`, forcing initialization on first call.
    &CLIENT
}
/// Key identifying a request in the on-disk cache.
///
/// Wraps the string used as the cache entry's file name.
#[derive(Debug, Clone, Eq, PartialEq)]
struct RequestHash(String);
/// Derives the cache key for `request`.
///
/// The key is the decimal rendering of the `DefaultHasher` hash of the
/// request's `Debug` output, so any difference in that output yields a
/// different cache entry.
fn hash_request(request: &Request<Body>) -> RequestHash {
    let rendered = format!("{:?}", request);
    let mut state = DefaultHasher::new();
    rendered.hash(&mut state);
    RequestHash(state.finish().to_string())
}
/// Looks up a still-valid cached response for `hash`.
///
/// Returns `Ok(None)` when caching is disabled, when no entry exists, or when
/// the entry is expired or cannot be decoded. In the last two cases the stale
/// file is removed so the caller's fresh response can replace it.
///
/// # Errors
///
/// Returns an error if the cache directory cannot be prepared, if the entry
/// exists but cannot be read, or if an unusable entry cannot be removed.
fn cached_response(hash: &RequestHash) -> Result<Option<Response<Bytes>>> {
    if !caching_enabled() {
        return Ok(None);
    }

    let path = cache_dir_path_for_this_version()?.join(&hash.0);
    let data = match std::fs::read(&path) {
        Ok(data) => data,
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
        Err(err) => return Err(err.into()),
    };

    // An undecodable entry (e.g. a truncated or corrupt file) is as useless as
    // an expired one. Treat it as a miss instead of failing every request that
    // maps to this hash until the cache is cleared by hand.
    if let Ok(Some(response)) = deserialize_response(data) {
        return Ok(Some(response));
    }

    match std::fs::remove_file(&path) {
        Ok(()) => Ok(None),
        // The entry is already gone, which is the state we wanted anyway.
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(None),
        Err(err) => Err(err.into()),
    }
}
fn cache_response(hash: &RequestHash, response: &Response<Bytes>) -> Result<()> {
let path = cache_dir_path_for_this_version()?.join(&hash.0);
let bytes = serialize_response(response)?;
std::fs::write(path, bytes)?;
Ok(())
}
fn clear_cache() -> Result<()> {
let path = top_level_cache_dir()?;
std::fs::remove_dir_all(path)?;
Ok(())
}
fn serialize_response(response: &Response<Bytes>) -> Result<Vec<u8>> {
let headers = response
.headers()
.iter()
.map(|(key, value)| (key.as_str().to_string(), value.as_bytes().to_vec()))
.collect();
let response = SerializedResponse {
status: response.status().as_u16(),
headers,
body: response.body().to_vec(),
expires_at: Local::now() + cache_ttl(),
};
Ok(serde_json::to_vec(&response)?)
}
/// Returns how long a freshly cached response stays valid.
///
/// Reads `TODO_OR_DIE_HTTP_CACHE_TTL_SECONDS` and interprets it as a whole
/// number of seconds. If the variable is unset or does not parse as an
/// integer, the TTL defaults to one hour.
fn cache_ttl() -> chrono::Duration {
    let configured = std::env::var("TODO_OR_DIE_HTTP_CACHE_TTL_SECONDS")
        .ok()
        .and_then(|raw| raw.parse::<i64>().ok());

    match configured {
        Some(seconds) => chrono::Duration::seconds(seconds),
        None => chrono::Duration::hours(1),
    }
}
/// Decodes a cache entry back into a response.
///
/// Returns `Ok(None)` when the entry's expiry time (compared at whole-second
/// resolution) lies in the past, and `Ok(Some(..))` with the rebuilt status,
/// headers, and body otherwise.
///
/// # Errors
///
/// Returns an error if `data` is not a valid serialized entry, or if the
/// stored status code, header names, or header values are invalid.
fn deserialize_response(data: Vec<u8>) -> Result<Option<Response<Bytes>>> {
    let cached: SerializedResponse = serde_json::from_slice(&data)
        .context("Failed to deserialize cached HTTP response")?;

    // Expired entries are reported as "nothing cached" rather than as errors.
    if Local::now().timestamp() > cached.expires_at.timestamp() {
        return Ok(None);
    }

    let status = hyper::StatusCode::from_u16(cached.status)?;

    let headers = cached
        .headers
        .iter()
        .map(|(name, value)| -> Result<_> {
            Ok((
                hyper::header::HeaderName::from_bytes(name.as_bytes())?,
                HeaderValue::from_bytes(value)?,
            ))
        })
        .collect::<Result<hyper::HeaderMap>>()
        .context("Failed to build header map")?;

    let mut rebuilt = Response::new(Bytes::copy_from_slice(&cached.body));
    *rebuilt.status_mut() = status;
    *rebuilt.headers_mut() = headers;
    Ok(Some(rebuilt))
}
/// On-disk representation of a cached HTTP response, serialized as JSON by
/// `serialize_response` and read back by `deserialize_response`.
#[derive(Serialize, Deserialize)]
struct SerializedResponse {
/// HTTP status code as a plain integer.
status: u16,
/// Response headers keyed by header name, with the raw bytes of each value.
headers: HashMap<String, Vec<u8>>,
/// Raw response body bytes.
body: Vec<u8>,
/// Point in time after which the entry is treated as expired.
expires_at: DateTime<Local>,
}
/// Returns the root cache directory, `todo_or_die_cache` inside the system
/// temporary directory, creating it if it does not exist yet.
///
/// # Errors
///
/// Returns an error if the directory cannot be created.
fn top_level_cache_dir() -> Result<PathBuf> {
    let mut dir = std::env::temp_dir();
    dir.push("todo_or_die_cache");
    std::fs::create_dir_all(&dir).context("Failed to create dir to store HTTP caches")?;
    Ok(dir)
}
/// Returns the cache directory for this crate version, creating it if needed.
///
/// The directory is a child of `top_level_cache_dir()` named after
/// `CARGO_PKG_VERSION`, so entries written by one crate version live in a
/// separate directory from those of other versions.
///
/// # Errors
///
/// Returns an error if either directory cannot be created.
fn cache_dir_path_for_this_version() -> Result<PathBuf> {
    let mut dir = top_level_cache_dir()?;
    dir.push(env!("CARGO_PKG_VERSION"));
    std::fs::create_dir_all(&dir).context("Failed to create dir to store HTTP caches")?;
    Ok(dir)
}
/// Reports whether the on-disk cache may be read and written.
///
/// Caching is off when a cache clear has been requested (see
/// `should_clear_cache()`), or when `TODO_OR_DIE_DISABLE_HTTP_CACHE` can be
/// read from the environment.
fn caching_enabled() -> bool {
    if should_clear_cache() {
        return false;
    }
    std::env::var("TODO_OR_DIE_DISABLE_HTTP_CACHE").is_err()
}
/// Reports whether the cache should be wiped before serving requests.
///
/// True exactly when `TODO_OR_DIE_CLEAR_HTTP_CACHE` can be read from the
/// environment via `std::env::var`; its value is ignored.
fn should_clear_cache() -> bool {
    let lookup = std::env::var("TODO_OR_DIE_CLEAR_HTTP_CACHE");
    lookup.is_ok()
}