use bytes::Bytes;
use reqwest::{Client, RequestBuilder, Response};
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Errors returned by the cache client defined in this module.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum Error {
/// An error produced by `reqwest`, forwarded unchanged (`transparent` reuses
/// its message and source).
#[error(transparent)]
Reqwest(#[from] reqwest::Error),
/// The server answered with an error status and a `Retry-After` header whose
/// value parsed as a whole number.
#[error("server rate limited the request, asking to wait {retry_after} seconds")]
RateLimit {
/// Value of the `Retry-After` header, reported as seconds in the message.
retry_after: u64,
/// The `reqwest` status error built from the rejected response.
#[source]
source: reqwest::Error,
},
/// The `ACTIONS_RUNTIME_TOKEN` environment variable could not be read.
#[error("did not find a runtime token in the ACTIONS_RUNTIME_TOKEN environment variable")]
NoRuntimeToken,
/// The `ACTIONS_CACHE_URL` environment variable could not be read.
#[error("did not find the endpoint URL in the ACTIONS_CACHE_URL environment variable")]
NoEndpointUrl,
}
impl Error {
    /// Returns the number of seconds the server asked the caller to wait.
    ///
    /// Yields `Some` only for [`Error::RateLimit`]; every other variant
    /// produces `None`.
    pub fn retry_after(&self) -> Option<u64> {
        match self {
            Self::RateLimit { retry_after, .. } => Some(*retry_after),
            _ => None,
        }
    }
}
/// Result alias used by this module; the error type defaults to [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// Description of a cache entry, deserialized from the lookup response of the
/// cache API.
#[derive(Deserialize, Debug)]
pub struct CacheHit {
/// Key of the entry; read from the `cacheKey` field of the response.
#[serde(rename = "cacheKey")]
pub key: String,
/// Value of the response's `scope` field.
/// NOTE(review): its meaning is defined by the service and is not visible here.
pub scope: String,
}
/// Client for a cache HTTP API whose location and credentials are taken from
/// the process environment (see `Cache::new`).
pub struct Cache {
/// HTTP client used for every request.
client: Client,
/// Bearer token attached to API requests; read from `ACTIONS_RUNTIME_TOKEN`.
token: String,
/// Base URL of the API: `ACTIONS_CACHE_URL` without trailing slashes,
/// followed by `/_apis/artifactcache`.
endpoint: String,
}
impl Cache {
pub fn new(user_agent: &str) -> Result<Self> {
let token = std::env::var("ACTIONS_RUNTIME_TOKEN").map_err(|_| Error::NoRuntimeToken)?;
let endpoint = format!(
"{}/_apis/artifactcache",
std::env::var("ACTIONS_CACHE_URL")
.map_err(|_| Error::NoEndpointUrl)?
.trim_end_matches('/')
);
let client = Client::builder().user_agent(user_agent).build()?;
Ok(Self {
client,
token,
endpoint,
})
}
fn api_request(&self, builder: RequestBuilder) -> RequestBuilder {
builder.bearer_auth(&self.token).header(
reqwest::header::ACCEPT,
"application/json;api-version=6.0-preview.1",
)
}
pub async fn get_url(
&self,
key_space: &str,
key_prefixes: &[&str],
) -> Result<Option<(CacheHit, String)>> {
#[derive(Deserialize)]
pub struct GetResponse {
#[serde(flatten)]
hit: CacheHit,
#[serde(rename = "archiveLocation")]
location: String,
}
let response = self
.api_request(self.client.get(format!("{}/cache", self.endpoint)))
.query(&[("keys", &*key_prefixes.join(",")), ("version", key_space)])
.send()
.await?;
tracing::debug!(response_headers = ?response.headers());
if response.status() == reqwest::StatusCode::NO_CONTENT {
Ok(None)
} else {
let response: GetResponse = error_for_response(response)?.json().await?;
Ok(Some((response.hit, response.location)))
}
}
pub async fn get_bytes(
&self,
key_space: &str,
keys: &[&str],
) -> Result<Option<(CacheHit, Bytes)>> {
if let Some((hit, location)) = self.get_url(key_space, keys).await? {
let response = self.client.get(location).send().await?;
tracing::debug!(response_headers = ?response.headers());
Ok(Some((hit, response.bytes().await?)))
} else {
Ok(None)
}
}
pub async fn put_bytes(&self, key_space: &str, key: &str, data: Bytes) -> Result<()> {
#[derive(Serialize)]
struct ReserveRequest<'a> {
key: &'a str,
version: &'a str,
}
#[derive(Deserialize)]
struct ReserveResponse {
#[serde(rename = "cacheId")]
cache_id: i64,
}
let response = self
.api_request(self.client.post(format!("{}/caches", self.endpoint)))
.json(&ReserveRequest {
key,
version: key_space,
})
.send()
.await?;
tracing::debug!(response_headers = ?response.headers());
let ReserveResponse { cache_id } = error_for_response(response)?.json().await?;
if !data.is_empty() {
let response = self
.api_request(
self.client
.patch(format!("{}/caches/{}", self.endpoint, cache_id)),
)
.header(
reqwest::header::CONTENT_RANGE,
format!("bytes {}-{}/*", 0, data.len() - 1),
)
.header(reqwest::header::CONTENT_TYPE, "application/octet-stream")
.body(data.clone())
.send()
.await?;
tracing::debug!(response_headers = ?response.headers());
error_for_response(response)?;
}
#[derive(Serialize)]
struct RequestBody<'a> {
key: &'a str,
version: &'a str,
}
#[derive(Serialize)]
struct FinalizeRequest {
size: usize,
}
let response = self
.api_request(
self.client
.post(format!("{}/caches/{}", self.endpoint, cache_id)),
)
.json(&FinalizeRequest { size: data.len() })
.send()
.await?;
tracing::debug!(response_headers = ?response.headers());
error_for_response(response)?;
Ok(())
}
}
fn error_for_response(response: Response) -> Result<Response> {
if response.status().is_client_error() || response.status().is_server_error() {
if let Some(retry_after) = response
.headers()
.get(reqwest::header::RETRY_AFTER)
.and_then(|v| v.to_str().ok()?.parse().ok())
{
return Err(Error::RateLimit {
retry_after,
source: response.error_for_status().unwrap_err(),
});
}
}
response.error_for_status().map_err(Into::into)
}