use std::pin::Pin;
use bytes::Bytes;
use futures::{Stream, TryStream, TryStreamExt};
use reqwest::{Method, StatusCode};
use crate::{
api::{self, percent_encode, DecodeResponse, ListObjectOptions, Object, Page},
errors::{Error, NotFoundError},
};
/// Base URL of the Cloud Storage JSON API upload endpoint.
///
/// `GcsBucketClient::new` joins this with `b/{bucket}/o` to build the upload URL.
/// The trailing `/` is required: `Url::join` resolves relative paths against the
/// base's directory, so without it the final `v1` segment would be replaced.
const GCS_UPLOAD_API_URL: &str = "https://www.googleapis.com/upload/storage/v1/";
/// Object operations scoped to a single Cloud Storage bucket.
///
/// [`GcsBucketClient`] is the real implementation. In tests, or with the `mocks`
/// feature enabled, `mockall::automock` additionally generates a mock of this trait.
#[cfg_attr(any(test, feature = "mocks"), mockall::automock)]
#[async_trait::async_trait]
pub trait BucketClient {
/// Name of the bucket this client operates on.
fn bucket_name(&self) -> &str;
/// Checks that the bucket is reachable with the current credentials.
///
/// The GCS implementation lists zero objects, so a missing bucket or missing
/// permissions surface as errors.
async fn ping(&self) -> Result<(), Error>;
/// Returns one page of objects, filtered and paginated according to `options`.
async fn list_objects<'a>(&self, options: ListObjectOptions<'a>)
-> Result<Page<Object>, Error>;
/// Uploads the bytes produced by `value` as object `key` and returns the new
/// object's metadata.
///
/// `value` is streamed as the request body. Its items must convert into `Bytes`
/// and its errors into a boxed `std::error::Error`.
async fn create_object<S>(&self, key: &str, value: S) -> Result<Object, Error>
where
S: TryStream + Send + Sync + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
Bytes: From<S::Ok>;
/// Fetches the metadata (not the contents) of object `key`.
async fn get_object(&self, key: &str) -> Result<Object, Error>;
/// Starts downloading the contents of object `key`.
///
/// Returns a stream of byte chunks. Errors that occur mid-download are yielded
/// as stream items.
async fn download_object(
&self,
key: &str,
) -> Result<Pin<Box<dyn Stream<Item = Result<Bytes, Error>> + Send + Sync>>, Error>;
/// Deletes object `key`.
///
/// The GCS implementation treats an already-missing object as success.
async fn delete_object(&self, key: &str) -> Result<(), Error>;
}
/// [`BucketClient`] implementation backed by the Cloud Storage JSON API.
#[derive(Clone)]
pub struct GcsBucketClient {
/// Crate-level client used to build every request.
client: crate::Client,
/// Bucket name exactly as supplied by the caller (not percent-encoded).
bucket_name: String,
/// Relative API path `b/{percent-encoded bucket}/o`, the prefix for object requests.
object_path: String,
/// Absolute upload URL: `GCS_UPLOAD_API_URL` joined with `object_path`.
upload_url: reqwest::Url,
}
impl GcsBucketClient {
pub(super) fn new(client: crate::Client, bucket_name: String) -> Self {
let encoded_bucket = percent_encode(&bucket_name);
let object_path = format!("b/{}/o", encoded_bucket);
Self {
client,
bucket_name,
upload_url: reqwest::Url::parse(GCS_UPLOAD_API_URL)
.and_then(|u| u.join(&object_path))
.expect("malformed url"),
object_path,
}
}
fn convert_api_error(&self, api_err: api::Error, requested_key: Option<&str>) -> Error {
match api_err {
api::Error::Http(e) => Error::Http(e),
api::Error::Google(e) => {
if e.status == StatusCode::NOT_FOUND {
if e.message.is_empty() || e.message.starts_with("No such object") {
NotFoundError::Object {
bucket: self.bucket_name.clone(),
key: requested_key.unwrap_or_default().into(),
}
.into()
} else {
NotFoundError::Bucket {
bucket: self.bucket_name.clone(),
}
.into()
}
} else if e.status == StatusCode::FORBIDDEN {
Error::PermissionDenied(e.message)
} else {
Error::OtherGoogle(e)
}
}
}
}
}
#[async_trait::async_trait]
impl BucketClient for GcsBucketClient {
fn bucket_name(&self) -> &str {
&self.bucket_name
}
async fn ping(&self) -> Result<(), Error> {
self.list_objects(ListObjectOptions {
max_results: Some(0),
..Default::default()
})
.await?;
Ok(())
}
async fn list_objects<'a>(
&self,
options: ListObjectOptions<'a>,
) -> Result<Page<Object>, Error> {
self.client
.build_request(Method::GET, &self.object_path)
.await?
.query(&options)
.send()
.await?
.decode_response()
.await
.map_err(|e| self.convert_api_error(e, None ))
}
async fn create_object<S>(&self, key: &str, value: S) -> Result<Object, Error>
where
S: TryStream + Send + Sync + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
Bytes: From<S::Ok>,
{
self.client
.build_request_with_url(Method::POST, self.upload_url.clone())
.await?
.query(&[("name", key)])
.body(reqwest::Body::wrap_stream(value))
.send()
.await?
.decode_response()
.await
.map_err(|e| self.convert_api_error(e, Some(key)))
}
async fn get_object(&self, key: &str) -> Result<Object, Error> {
if key.trim().is_empty() {
return Err(Error::NotFound(NotFoundError::Object {
bucket: self.bucket_name.clone(),
key: key.into(),
}));
}
self.client
.build_request(
Method::GET,
&format!("{}/{}", self.object_path, percent_encode(key)),
)
.await?
.send()
.await?
.decode_response()
.await
.map_err(|e| self.convert_api_error(e, Some(key)))
}
async fn download_object(
&self,
key: &str,
) -> Result<Pin<Box<dyn Stream<Item = Result<Bytes, Error>> + Send + Sync>>, Error> {
let res = self
.client
.build_request(
Method::GET,
&format!("{}/{}?alt=media", self.object_path, percent_encode(key)),
)
.await?
.send()
.await?;
if res.status().is_success() {
Ok(Box::pin(res.bytes_stream().map_err(Error::from)))
} else if res.status() == StatusCode::NOT_FOUND {
Err(Error::NotFound(NotFoundError::Object {
bucket: self.bucket_name.clone(),
key: key.into(),
}))
} else {
Err(Error::OtherGoogle(api::GoogleError {
status: res.status(),
message: res.text().await?,
}))
}
}
async fn delete_object(&self, key: &str) -> Result<(), Error> {
self.client
.build_request(
Method::DELETE,
&format!("{}/{}", self.object_path, percent_encode(key)),
)
.await?
.send()
.await?
.decode_response::<()>()
.await
.or_else(|e| match e {
api::Error::Google(api::GoogleError {
status: StatusCode::NOT_FOUND,
..
}) => Ok(()),
_ => Err(e),
})
.map_err(|e| self.convert_api_error(e, Some(key)))
}
}
#[cfg(test)]
mod tests {
    //! Integration tests that talk to a real Cloud Storage bucket.
    //!
    //! Requirements:
    //! * `CLOUD_STORAGE_LITE_TEST_BUCKET` must name the bucket to use;
    //! * service-account credentials must be readable by
    //!   `ServiceAccount::read_from_canonical_env`.
    //!
    //! Tests that create objects delete them again, so repeated runs do not
    //! accumulate objects in the bucket.

    use super::*;
    use crate::{
        token_provider::oauth::{OAuthTokenProvider, ServiceAccount, SCOPE_STORAGE_FULL_CONTROL},
        Client,
    };

    /// Name of the bucket the tests run against.
    fn test_bucket() -> String {
        std::env::var("CLOUD_STORAGE_LITE_TEST_BUCKET")
            .expect("CLOUD_STORAGE_LITE_TEST_BUCKET must name the test bucket")
    }

    /// Random 8-character alphanumeric string, used to keep object keys unique per run.
    fn random_string() -> String {
        let mut rng = rand::thread_rng();
        std::iter::repeat(())
            .map(|()| rand::Rng::sample(&mut rng, rand::distributions::Alphanumeric))
            .map(char::from)
            .take(8)
            .collect()
    }

    fn get_client() -> Client {
        let token_provider = OAuthTokenProvider::new(
            ServiceAccount::read_from_canonical_env().unwrap(),
            SCOPE_STORAGE_FULL_CONTROL,
        )
        .unwrap();
        Client::new(token_provider)
    }

    fn get_bucket_client() -> impl BucketClient {
        get_client().into_bucket_client(test_bucket())
    }

    #[tokio::test]
    async fn ping() {
        let bucket_client = get_bucket_client();
        bucket_client.ping().await.unwrap();
    }

    #[tokio::test]
    async fn ping_notfound() {
        let bucket_client = get_client().into_bucket_client(test_bucket() + "qqq");
        let result = bucket_client.ping().await;
        assert!(matches!(result, Err(Error::NotFound(_))), "{:?}", result);
    }

    #[tokio::test]
    async fn ping_forbidden() {
        let bucket_client = get_client().into_bucket_client("admin".into());
        let result = bucket_client.ping().await;
        assert!(
            matches!(result, Err(Error::PermissionDenied(_))),
            "{:?}",
            result
        );
    }

    static TEST_DATA: &str = "test";

    /// Single-chunk stream containing `TEST_DATA`.
    fn make_data_stream() -> impl Stream<Item = Result<Bytes, std::convert::Infallible>> {
        futures::stream::once(futures::future::ok::<_, std::convert::Infallible>(
            Bytes::from(TEST_DATA),
        ))
    }

    #[tokio::test]
    async fn create_object() {
        let bucket_client = get_bucket_client();
        let key = random_string();
        bucket_client
            .create_object(&key, make_data_stream())
            .await
            .unwrap();
        let obj = bucket_client.get_object(&key).await.unwrap();
        // Clean up before asserting so a failed assertion does not leak the object.
        bucket_client.delete_object(&key).await.unwrap();
        assert_eq!(obj.name, key);
        assert_eq!(obj.size, TEST_DATA.len() as u64);
        assert!(obj.id.starts_with(&(test_bucket() + "/" + &key)));
    }

    #[tokio::test]
    async fn get_object_notfound() {
        let bucket_client = get_bucket_client();
        assert!(matches!(
            bucket_client.get_object("thiskeydoesnotexist").await,
            Err(Error::NotFound(NotFoundError::Object { .. }))
        ));
        assert!(matches!(
            bucket_client.get_object("").await,
            Err(Error::NotFound(NotFoundError::Object { .. }))
        ));
    }

    #[tokio::test]
    async fn list_objects() {
        let bucket_client = get_bucket_client();
        let prefix = random_string();
        let key1 = prefix.clone() + "key1";
        let key2 = prefix.clone() + "key2";
        let create_key1 = bucket_client.create_object(&key1, make_data_stream());
        let create_key2 = bucket_client.create_object(&key2, make_data_stream());
        futures::try_join!(create_key1, create_key2).unwrap();

        let by_prefix = bucket_client
            .list_objects(ListObjectOptions {
                prefix: Some(&prefix),
                ..Default::default()
            })
            .await
            .unwrap();
        let by_key1 = bucket_client
            .list_objects(ListObjectOptions {
                prefix: Some(&key1),
                ..Default::default()
            })
            .await
            .unwrap();

        // Clean up before asserting so a failed assertion does not leak objects.
        futures::try_join!(
            bucket_client.delete_object(&key1),
            bucket_client.delete_object(&key2)
        )
        .unwrap();

        assert_eq!(by_prefix.items.len(), 2);
        assert_eq!(by_key1.items.len(), 1);
    }

    #[tokio::test]
    async fn download_object() {
        let bucket_client = get_bucket_client();
        let key = random_string();
        bucket_client
            .create_object(&key, make_data_stream())
            .await
            .unwrap();
        let downloaded_data = bucket_client
            .download_object(&key)
            .await
            .unwrap()
            .try_fold(Vec::new(), |mut buf, chunk| async move {
                buf.extend_from_slice(&chunk);
                Ok(buf)
            })
            .await
            .unwrap();
        // Clean up before asserting so a failed assertion does not leak the object.
        bucket_client.delete_object(&key).await.unwrap();
        assert_eq!(downloaded_data, TEST_DATA.as_bytes());
    }

    #[tokio::test]
    async fn download_notfound() {
        let bucket_client = get_bucket_client();
        let err_res = bucket_client.download_object("thiskeydoesnotexist").await;
        assert!(matches!(
            err_res,
            Err(Error::NotFound(NotFoundError::Object { .. }))
        ));
    }

    #[tokio::test]
    async fn delete_object() {
        let bucket_client = get_bucket_client();
        let key = random_string();
        bucket_client
            .create_object(&key, make_data_stream())
            .await
            .unwrap();
        bucket_client.delete_object(&key).await.unwrap();
        assert!(matches!(
            bucket_client.get_object(&key).await.unwrap_err(),
            Error::NotFound(NotFoundError::Object { .. })
        ));
    }

    #[tokio::test]
    async fn delete_nonexistent() {
        let bucket_client = get_bucket_client();
        bucket_client
            .delete_object("thiskeydoesnotexist")
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn object_lifecycle() {
        let bucket_client = get_bucket_client();
        // A key containing `/` checks that percent-encoding keeps it one path segment.
        let key = random_string() + "/" + &random_string();
        bucket_client
            .create_object(&key, make_data_stream())
            .await
            .unwrap();
        bucket_client.get_object(&key).await.unwrap();
        bucket_client.download_object(&key).await.unwrap();
        bucket_client.delete_object(&key).await.unwrap();
        bucket_client.get_object(&key).await.unwrap_err();
    }
}