use std::{env, ops::Range, sync::Arc};
use async_trait::async_trait;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use once_cell::sync::Lazy;
use reqwest::{Body, Client, RequestBuilder, Response, StatusCode};
use serde::{Deserialize, Serialize};
use crate::{
auth::{get_token, TokenProvider},
error::{Result, VercelBlobError},
};
/// Default value sent in the `x-api-version` request header.
/// `VERCEL_BLOB_API_VERSION_OVERRIDE` takes precedence when it is set.
const BLOB_API_VERSION: u32 = 4;
/// Process-wide HTTP client, built on first use and shared by every request this
/// module sends.
static GLOBAL_CLIENT: Lazy<Client> = Lazy::new(Client::new);
/// HTTP client for the Vercel Blob API.
///
/// The base URL and API version are read from the environment when the client is
/// constructed (see `VercelBlobClient::new`), not per request.
pub struct VercelBlobClient {
/// Source of bearer tokens. `None` leaves token resolution to the default
/// behavior of `auth::get_token`.
token_provider: Option<Arc<dyn TokenProvider>>,
/// Base URL that every API path is joined onto.
base_url: String,
/// Value sent in the `x-api-version` request header.
api_version: String,
}
/// The `error` object inside an error response body returned by the Blob API.
#[derive(Deserialize)]
struct BlobApiErrorDetail {
/// Machine-readable error code, e.g. `not_found` or `bad_request`.
code: String,
/// Optional human-readable description; only surfaced to callers for `bad_request`.
message: Option<String>,
}
/// Error response body, shaped `{ "error": { "code": ..., "message": ... } }`.
#[derive(Deserialize)]
struct BlobApiError {
error: BlobApiErrorDetail,
}
impl VercelBlobClient {
pub fn new() -> Self {
Self {
token_provider: None,
base_url: Self::get_base_url(),
api_version: Self::get_api_version(),
}
}
pub fn new_external(token_provider: Arc<dyn TokenProvider>) -> Self {
Self {
token_provider: Some(token_provider),
base_url: Self::get_base_url(),
api_version: Self::get_api_version(),
}
}
fn get_base_url() -> String {
env::var("VERCEL_BLOB_API_URL")
.or(env::var("NEXT_PUBLIC_VERCEL_BLOB_API_URL"))
.unwrap_or_else(|_| "https://blob.vercel-storage.com".to_string())
}
fn get_api_url(&self, pathname: Option<&str>) -> String {
url_join(self.base_url.clone(), pathname.unwrap_or("").to_string())
}
fn get_api_version() -> String {
env::var("VERCEL_BLOB_API_VERSION_OVERRIDE")
.unwrap_or_else(|_| BLOB_API_VERSION.to_string())
}
fn add_api_version_header(&self, request: RequestBuilder) -> RequestBuilder {
request.header("x-api-version", self.api_version.clone())
}
async fn add_authorization_header(
&self,
request: RequestBuilder,
operation: &str,
pathname: Option<&str>,
) -> Result<RequestBuilder> {
let token = get_token(self.token_provider.as_deref(), operation, pathname).await?;
Ok(request.header("authorization", format!("Bearer {}", token)))
}
async fn handle_error(response: Response) -> VercelBlobError {
let status = response.status();
if status.as_u16() >= 500 {
return VercelBlobError::unknown_error(status);
}
let error = response.json::<BlobApiError>().await;
if error.is_err() {
return VercelBlobError::unknown_error(status);
}
let error = error.unwrap();
match error.error.code.as_str() {
"store_suspended" => VercelBlobError::StoreSuspended(),
"forbidden" => VercelBlobError::Forbidden(),
"not_found" => VercelBlobError::BlobNotFound(),
"store_not_found" => VercelBlobError::StoreNotFound(),
"bad_request" => VercelBlobError::BadRequest(
error
.error
.message
.unwrap_or_else(|| "unknown details".to_string()),
),
_ => VercelBlobError::unknown_error(status),
}
}
}
/// Operations supported by a Vercel Blob client.
#[async_trait]
pub trait VercelBlobApi {
/// Lists blobs, filtered and paginated according to `options`.
async fn list(&self, options: ListCommandOptions) -> Result<ListBlobResult>;
/// Uploads `body` under `pathname` and returns the blob metadata reported by the API.
async fn put(
&self,
pathname: &str,
body: impl Into<Body> + Send,
options: PutCommandOptions,
) -> Result<PutBlobResult>;
/// Fetches metadata for the blob at `url`; `Ok(None)` means the blob was not found.
async fn head(&self, url: &str, options: HeadCommandOptions) -> Result<Option<HeadBlobResult>>;
/// Deletes the blob at `url`.
async fn del(&self, url: &str, options: DelCommandOptions) -> Result<()>;
/// Downloads the blob at `url`, optionally limited to `options.byte_range`.
async fn download(&self, url: &str, options: DownloadCommandOptions) -> Result<Bytes>;
}
/// One blob entry in a list response.
#[derive(Debug, Deserialize, Serialize)]
pub struct ListBlobResultBlob {
/// URL of the blob.
pub url: String,
/// Path of the blob within the store.
pub pathname: String,
/// Blob size (presumably in bytes — confirm against the API docs).
pub size: u64,
/// Upload timestamp. `uploadedAt` (API casing) is accepted when deserializing;
/// the field is always serialized as `uploaded_at`.
#[serde(alias = "uploadedAt")]
pub uploaded_at: DateTime<Utc>,
}
/// One page of results from [`VercelBlobApi::list`].
#[derive(Debug, Deserialize, Serialize)]
pub struct ListBlobResult {
/// Blobs on this page.
pub blobs: Vec<ListBlobResultBlob>,
/// Pagination cursor; presumably passed back as `ListCommandOptions::cursor` to
/// fetch the next page.
pub cursor: Option<String>,
/// Whether more results exist beyond this page. `hasMore` is accepted when
/// deserializing.
#[serde(alias = "hasMore")]
pub has_more: bool,
}
/// Options for [`VercelBlobApi::list`]; each field that is `Some` is sent as a query
/// parameter of the same name.
#[derive(Clone, Debug, Default)]
pub struct ListCommandOptions {
/// Maximum number of blobs to return.
pub limit: Option<u64>,
/// Only list blobs whose pathname starts with this prefix.
pub prefix: Option<String>,
/// Cursor from a previous [`ListBlobResult`] to continue paginating.
pub cursor: Option<String>,
}
/// Concatenates two URL fragments so that a single `/` sits at the seam.
///
/// If both sides contribute a slash, one is dropped. If neither does, one is
/// inserted. Otherwise the fragments are concatenated unchanged.
fn url_join(left: String, right: String) -> String {
    match (left.ends_with('/'), right.strip_prefix('/')) {
        (true, Some(tail)) => left + tail,
        (true, None) | (false, Some(_)) => left + &right,
        (false, None) => left + "/" + &right,
    }
}
#[async_trait]
impl VercelBlobApi for VercelBlobClient {
    /// Lists blobs via `GET {base_url}/`.
    ///
    /// Each option that is set is forwarded as a query parameter of the same name.
    ///
    /// # Errors
    /// Fails if no token can be obtained, the request cannot be sent, the API answers
    /// with anything other than `200 OK` (mapped by `handle_error`), or the body is
    /// not a valid [`ListBlobResult`].
    async fn list(&self, options: ListCommandOptions) -> Result<ListBlobResult> {
        let api_url = self.get_api_url(None);
        let mut request = GLOBAL_CLIENT.get(api_url);
        if let Some(limit) = options.limit {
            request = request.query(&[("limit", limit)]);
        }
        if let Some(prefix) = options.prefix {
            request = request.query(&[("prefix", prefix)]);
        }
        if let Some(cursor) = options.cursor {
            request = request.query(&[("cursor", cursor)]);
        }
        request = self.add_api_version_header(request);
        request = self.add_authorization_header(request, "list", None).await?;
        let rsp = request.send().await?;
        if rsp.status() != StatusCode::OK {
            Err(Self::handle_error(rsp).await)
        } else {
            Ok(rsp.json::<ListBlobResult>().await?)
        }
    }

    /// Uploads `body` via `PUT {base_url}/{pathname}`.
    ///
    /// # Errors
    /// Returns `VercelBlobError::required("pathname")` for an empty `pathname`.
    /// Otherwise fails under the same conditions as `list`.
    async fn put(
        &self,
        pathname: &str,
        body: impl Into<Body> + Send,
        options: PutCommandOptions,
    ) -> Result<PutBlobResult> {
        if pathname.is_empty() {
            return Err(VercelBlobError::required("pathname"));
        }
        // NOTE(review): `pathname` is placed into the URL path without
        // percent-encoding; confirm callers never pass `?`, `#` or spaces.
        let api_url = self.get_api_url(Some(&format!("/{pathname}")));
        let mut request = GLOBAL_CLIENT.put(api_url);
        request = self.add_api_version_header(request);
        request = self
            .add_authorization_header(request, "put", Some(pathname))
            .await?;
        // Only the opt-out is sent; presumably the server suffixes by default.
        if !options.add_random_suffix {
            request = request.header("x-add-random-suffix", "0");
        }
        if let Some(content_type) = options.content_type {
            request = request.header("x-content-type", content_type);
        }
        if let Some(cache_control_max_age) = options.cache_control_max_age {
            request = request.header("x-cache-control-max-age", cache_control_max_age.to_string());
        }
        request = request.body(body);
        let response = request.send().await?;
        if response.status() != StatusCode::OK {
            Err(Self::handle_error(response).await)
        } else {
            Ok(response.json::<PutBlobResult>().await?)
        }
    }

    /// Fetches blob metadata via `GET {base_url}/?url=<url>`.
    ///
    /// Returns `Ok(None)` when the API reports `not_found`.
    ///
    /// # Errors
    /// Fails under the same conditions as `list`, except that "not found" is not an
    /// error.
    async fn head(
        &self,
        url: &str,
        _options: HeadCommandOptions,
    ) -> Result<Option<HeadBlobResult>> {
        let api_url = self.get_api_url(None);
        let mut request = GLOBAL_CLIENT.get(api_url);
        request = request.query(&[("url", url)]);
        request = self.add_api_version_header(request);
        request = self
            .add_authorization_header(request, "head", Some(url))
            .await?;
        let response = request.send().await?;
        if response.status() != StatusCode::OK {
            match Self::handle_error(response).await {
                VercelBlobError::BlobNotFound() => Ok(None),
                err => Err(err),
            }
        } else {
            Ok(Some(response.json::<HeadBlobResult>().await?))
        }
    }

    /// Deletes the blob at `url` via `POST {base_url}/delete` with the JSON body
    /// `{"urls": [url]}`.
    ///
    /// # Errors
    /// Fails if no token can be obtained, the request cannot be sent, or the API
    /// answers with anything other than `200 OK`.
    async fn del(&self, url: &str, _options: DelCommandOptions) -> Result<()> {
        let api_url = self.get_api_url(Some("/delete"));
        let mut request = GLOBAL_CLIENT.post(api_url);
        request = self.add_api_version_header(request);
        request = self
            .add_authorization_header(request, "del", Some(url))
            .await?;
        request = request.header("content-type", "application/json");
        request = request.json(&DelCommandBody {
            urls: vec![url.to_string()],
        });
        let response = request.send().await?;
        if response.status() != StatusCode::OK {
            Err(Self::handle_error(response).await)
        } else {
            Ok(())
        }
    }

    /// Downloads the blob at `url`, optionally restricted to `options.byte_range`.
    ///
    /// An empty range, including an inverted one (`start > end`), yields empty
    /// bytes without sending a request.
    ///
    /// # Errors
    /// Fails if no token can be obtained, the request or body read fails, or the
    /// response status is neither `200 OK` nor `206 Partial Content`.
    async fn download(&self, url: &str, options: DownloadCommandOptions) -> Result<Bytes> {
        let mut request = GLOBAL_CLIENT.get(url);
        request = self.add_api_version_header(request);
        // NOTE(review): the bearer token is sent to whatever host `url` names; only
        // pass URLs that point at the blob store.
        request = self
            .add_authorization_header(request, "download", Some(url))
            .await?;
        if let Some(byte_range) = options.byte_range {
            // `is_empty` covers `start == end` and `start > end`; the latter used to
            // underflow `end - 1` below.
            if byte_range.is_empty() {
                return Ok(Bytes::new());
            }
            // HTTP byte ranges are inclusive; Rust `Range` is end-exclusive.
            request = request.header(
                "range",
                format!("bytes={}-{}", byte_range.start, byte_range.end - 1),
            );
        }
        // Propagate transport/body errors instead of panicking.
        let response = request.send().await?;
        let status = response.status();
        if status != StatusCode::OK && status != StatusCode::PARTIAL_CONTENT {
            Err(Self::handle_error(response).await)
        } else {
            Ok(response.bytes().await?)
        }
    }
}
/// Options for [`VercelBlobApi::put`].
#[derive(Debug)]
pub struct PutCommandOptions {
    /// When `false`, the upload is sent with `x-add-random-suffix: 0`.
    pub add_random_suffix: bool,
    /// Sent as the `x-cache-control-max-age` header when set (presumably seconds).
    pub cache_control_max_age: Option<u64>,
    /// Sent as the `x-content-type` header when set.
    pub content_type: Option<String>,
}

impl Default for PutCommandOptions {
    /// Random suffixing enabled; no cache-control or content-type override.
    fn default() -> Self {
        let add_random_suffix = true;
        PutCommandOptions {
            content_type: None,
            cache_control_max_age: None,
            add_random_suffix,
        }
    }
}
/// Response of a successful [`VercelBlobApi::put`].
///
/// camelCase aliases are accepted when deserializing; fields serialize in snake_case.
#[derive(Debug, Deserialize, Serialize)]
pub struct PutBlobResult {
/// URL of the stored blob.
pub url: String,
/// Pathname of the stored blob, as reported by the API.
pub pathname: String,
/// Content type reported by the API (e.g. `text/plain`).
#[serde(alias = "contentType")]
pub content_type: String,
/// Content disposition reported by the API (e.g. `inline`).
#[serde(alias = "contentDisposition")]
pub content_disposition: String,
}
/// Blob metadata returned by [`VercelBlobApi::head`].
///
/// camelCase aliases are accepted when deserializing; fields serialize in snake_case.
#[derive(Debug, Deserialize, Serialize)]
pub struct HeadBlobResult {
/// URL of the blob.
pub url: String,
/// Blob size (presumably in bytes — confirm against the API docs).
pub size: u64,
/// Upload timestamp.
#[serde(alias = "uploadedAt")]
pub uploaded_at: DateTime<Utc>,
/// Path of the blob within the store.
pub pathname: String,
/// Content type of the blob (e.g. `text/plain`).
#[serde(alias = "contentType")]
pub content_type: String,
/// Content disposition of the blob (e.g. `inline`).
#[serde(alias = "contentDisposition")]
pub content_disposition: String,
/// Cache-control value reported by the API (e.g. `public, max-age=...`).
#[serde(alias = "cacheControl")]
pub cache_control: String,
}
/// Options for [`VercelBlobApi::head`]; currently has no fields and is ignored by
/// `VercelBlobClient`.
#[derive(Debug, Default)]
pub struct HeadCommandOptions {}
/// JSON body posted to the `/delete` endpoint.
#[derive(Debug, Serialize)]
struct DelCommandBody {
/// URLs of the blobs to delete.
urls: Vec<String>,
}
/// Options for [`VercelBlobApi::del`]; currently has no fields and is ignored by
/// `VercelBlobClient`.
#[derive(Debug, Default)]
pub struct DelCommandOptions {}
/// Options for [`VercelBlobApi::download`].
#[derive(Debug, Default)]
pub struct DownloadCommandOptions {
/// Half-open byte range `start..end` to fetch; it is translated into the inclusive
/// HTTP header `range: bytes=start-(end-1)`. `None` downloads the whole blob.
pub byte_range: Option<Range<usize>>,
}
#[cfg(test)]
mod tests {
    use all_asserts::{assert_false, assert_true};
    use mockito::{Matcher, Mock, ServerGuard};

    use super::*;

    const EXAMPLE_CACHE_CONTROL: &str = "public, max-age=31536000, s-maxage=300";

    /// Builds a list response with `num_files` synthetic blobs whose URLs live under
    /// `url`.
    fn mock_list_rsp(
        url: &str,
        num_files: u32,
        has_more: bool,
        cursor: Option<String>,
    ) -> ListBlobResult {
        ListBlobResult {
            blobs: (0..num_files)
                .map(|i| ListBlobResultBlob {
                    url: format!("{}/somefile-{}.txt", url, i),
                    pathname: format!("somefile-{}.txt", i),
                    size: 123,
                    uploaded_at: Utc::now(),
                })
                .collect(),
            cursor,
            has_more,
        }
    }

    /// Builds a client that targets `mock_server` but otherwise uses the defaults of
    /// `VercelBlobClient::new`.
    fn create_client(mock_server: &ServerGuard) -> VercelBlobClient {
        let client = VercelBlobClient::new();
        VercelBlobClient {
            api_version: client.api_version,
            base_url: mock_server.url(),
            token_provider: client.token_provider,
        }
    }

    /// Starts a mock server and returns it with a mock for `http_method`
    /// `http_path`. The mock is not yet created; callers may chain extra matchers
    /// before calling `create_async`.
    ///
    /// The mock answers `200` with `content-type: application/json`. When
    /// `response` returns `Some`, the body is the JSON encoding of that value. The
    /// mock only matches requests carrying `authorization: Bearer xyz`.
    async fn setup_mock_rsp<T, O, P>(
        http_method: &str,
        http_path: P,
        response: T,
    ) -> (ServerGuard, Mock)
    where
        O: Serialize,
        T: FnOnce(&str) -> Option<O>,
        P: Into<Matcher>,
    {
        let mut server = mockito::Server::new_async().await;
        // NOTE(review): process-wide and tests run concurrently; every test writes
        // the same value.
        env::set_var("BLOB_READ_WRITE_TOKEN", "xyz");
        let mut mock = server
            .mock(http_method, http_path)
            .with_status(200)
            .with_header("content-type", "application/json")
            .match_header("authorization", "Bearer xyz");
        let rsp_obj = response(&server.url());
        if let Some(rsp_obj) = rsp_obj {
            let rsp_json = serde_json::to_string(&rsp_obj).unwrap();
            mock = mock.with_body(rsp_json);
        }
        (server, mock)
    }

    #[test]
    fn url_join_inserts_single_separator() {
        let want = "https://example.com/x";
        assert_eq!(url_join("https://example.com/".to_string(), "/x".to_string()), want);
        assert_eq!(url_join("https://example.com/".to_string(), "x".to_string()), want);
        assert_eq!(url_join("https://example.com".to_string(), "/x".to_string()), want);
        assert_eq!(url_join("https://example.com".to_string(), "x".to_string()), want);
    }

    #[tokio::test]
    async fn can_list_no_paging() {
        let (server, mock) = setup_mock_rsp("GET", "/", |server_url| {
            Some(mock_list_rsp(server_url, 10, false, None))
        })
        .await;
        let mock = mock.create_async().await;
        let client = create_client(&server);
        let results = client
            .list(ListCommandOptions {
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(10, results.blobs.len());
        assert_false!(results.has_more);
        assert_true!(results.cursor.is_none());
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn can_list_paging() {
        let (server, mock) = setup_mock_rsp("GET", "/", |server_url| {
            Some(mock_list_rsp(server_url, 5, true, Some("xyz".to_string())))
        })
        .await;
        let mock = mock.create_async().await;
        let client = create_client(&server);
        let results = client
            .list(ListCommandOptions {
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(5, results.blobs.len());
        assert_true!(results.has_more);
        assert_eq!("xyz", results.cursor.unwrap());
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn can_put_cache_control() {
        let (server, mock) = setup_mock_rsp("PUT", "/somefile.txt", |server_url| {
            Some(PutBlobResult {
                url: format!("{}/somefile.txt", server_url),
                pathname: "somefile.txt".to_string(),
                content_type: "text/plain".to_string(),
                content_disposition: "inline".to_string(),
            })
        })
        .await;
        let mock = mock
            .match_header("x-cache-control-max-age", "100")
            .create_async()
            .await;
        let client = create_client(&server);
        let data = "here are some new contents";
        let pathname = "somefile.txt";
        let result = client
            .put(
                pathname,
                data,
                PutCommandOptions {
                    add_random_suffix: false,
                    cache_control_max_age: Some(100),
                    ..Default::default()
                },
            )
            .await
            .unwrap();
        assert_eq!(result.pathname, "somefile.txt");
        assert_eq!(result.content_type, "text/plain");
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn can_put_no_cache_control() {
        let (server, mock) = setup_mock_rsp("PUT", "/somefile.txt", |server_url| {
            Some(PutBlobResult {
                url: format!("{}/somefile.txt", server_url),
                pathname: "somefile.txt".to_string(),
                content_type: "text/plain".to_string(),
                content_disposition: "inline".to_string(),
            })
        })
        .await;
        let mock = mock
            .match_header("x-cache-control-max-age", Matcher::Missing)
            .create_async()
            .await;
        let client = create_client(&server);
        let data = "here are some new contents";
        let pathname = "somefile.txt";
        let result = client
            .put(
                pathname,
                data,
                PutCommandOptions {
                    add_random_suffix: false,
                    ..Default::default()
                },
            )
            .await
            .unwrap();
        assert_eq!(result.pathname, "somefile.txt");
        assert_eq!(result.content_type, "text/plain");
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn can_head() {
        let (server, mock) = setup_mock_rsp("GET", Matcher::Any, |server_url| {
            Some(HeadBlobResult {
                url: format!("{}/somefile.txt", server_url),
                size: 123,
                uploaded_at: Utc::now(),
                pathname: "somefile.txt".to_string(),
                content_type: "text/plain".to_string(),
                content_disposition: "inline".to_string(),
                cache_control: EXAMPLE_CACHE_CONTROL.to_string(),
            })
        })
        .await;
        let mock = mock
            .match_query(Matcher::UrlEncoded(
                "url".to_string(),
                format!("{}/somefile.txt", server.url()),
            ))
            .create_async()
            .await;
        let client = create_client(&server);
        let maybe_result = client
            .head(
                &format!("{}/somefile.txt", server.url()),
                HeadCommandOptions {
                    ..Default::default()
                },
            )
            .await
            .unwrap();
        assert_true!(maybe_result.is_some());
        let result = maybe_result.unwrap();
        assert_eq!(result.pathname, "somefile.txt");
        assert_eq!(result.cache_control, EXAMPLE_CACHE_CONTROL);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn can_del() {
        let (server, mock) = setup_mock_rsp::<_, (), _>("POST", "/delete", |_| None).await;
        let mock = mock.create_async().await;
        let client = create_client(&server);
        client
            .del(
                &format!("{}/somefile.txt", server.url()),
                DelCommandOptions {
                    ..Default::default()
                },
            )
            .await
            .unwrap();
        mock.assert_async().await;
    }

    /// Arbitrary JSON payload used as a downloadable blob body.
    #[derive(Debug, Deserialize, Serialize)]
    struct MockFile {
        text: String,
    }

    #[tokio::test]
    async fn can_download() {
        let (server, mock) = setup_mock_rsp("GET", "/somefile.txt", |_| {
            Some(MockFile {
                text: "hello".to_string(),
            })
        })
        .await;
        let mock = mock.create_async().await;
        let client = create_client(&server);
        let contents = client
            .download(
                &format!("{}/somefile.txt", server.url()),
                DownloadCommandOptions {
                    ..Default::default()
                },
            )
            .await
            .unwrap();
        mock.assert_async().await;
        let parsed_contents = serde_json::from_slice::<MockFile>(&contents).unwrap();
        assert_eq!(parsed_contents.text, "hello");
    }

    #[tokio::test]
    async fn can_download_range() {
        let (server, mock) = setup_mock_rsp("GET", "/somefile.txt", |_| {
            Some(MockFile {
                text: "hello".to_string(),
            })
        })
        .await;
        // End-exclusive `2..6` must become the inclusive header `bytes=2-5`.
        let mock = mock
            .match_header("range", "bytes=2-5")
            .create_async()
            .await;
        let client = create_client(&server);
        client
            .download(
                &format!("{}/somefile.txt", server.url()),
                DownloadCommandOptions {
                    byte_range: Some(2..6),
                },
            )
            .await
            .unwrap();
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn download_empty_range_sends_no_request() {
        let (server, mock) = setup_mock_rsp::<_, (), _>("GET", "/somefile.txt", |_| None).await;
        let mock = mock.expect(0).create_async().await;
        let client = create_client(&server);
        let contents = client
            .download(
                &format!("{}/somefile.txt", server.url()),
                DownloadCommandOptions {
                    byte_range: Some(3..3),
                },
            )
            .await
            .unwrap();
        assert_true!(contents.is_empty());
        mock.assert_async().await;
    }
}