use super::service::{validate_object_name, ObjectInfo, Service};
use crate::{
errors::Result,
server::{cloud::iter::AsyncObjectIterator, http},
};
use async_trait::async_trait;
use aws_config::{
environment::EnvironmentVariableRegionProvider,
meta::region::RegionProviderChain,
profile::{self, ProfileFileCredentialsProvider},
BehaviorVersion, Region,
};
use aws_credential_types::Credentials;
use aws_sdk_s3::{
self as s3,
config::http::{HttpRequest, HttpResponse},
error::ProvideErrorMetadata,
operation::{get_object::GetObjectOutput, list_objects_v2::ListObjectsV2Output},
};
use aws_smithy_runtime_api::{
client::{
http::{HttpClient, HttpConnector, HttpConnectorFuture, SharedHttpConnector},
result::ConnectorError,
retries::ErrorKind,
},
http::{Headers, StatusCode},
};
use reqwest::Method;
/// Object storage backed by a single S3 bucket, reached through the AWS SDK client
/// (optionally at a custom endpoint URL; see `AwsService::new`).
pub(in crate::server) struct AwsService {
    /// SDK client used for every object operation.
    client: s3::Client,
    /// Name of the bucket that every operation targets.
    bucket: String,
}
/// How `AwsService::new` should obtain AWS credentials.
#[non_exhaustive]
pub enum AwsCredentials {
    /// Use this static access key pair (no session token is supplied).
    AccessKey {
        access_key_id: String,
        secret_access_key: String,
    },
    /// Read credentials for the named profile from the AWS profile files.
    Profile { profile_name: String },
    /// Configure no explicit credentials provider, leaving the `aws_config` defaults in effect.
    Default,
}
impl AwsService {
    /// Build a service that stores objects in `bucket`.
    ///
    /// * `region` - explicit region. If `None`, the region is taken from the environment, then
    ///   from the AWS profile, and finally falls back to `us-east-1`.
    /// * `creds` - how to obtain credentials (see [`AwsCredentials`]).
    /// * `endpoint_url` - optional endpoint override for the SDK.
    /// * `force_path_style` - passed straight to the S3 config's `force_path_style`.
    ///
    /// # Errors
    /// Fails if the underlying reqwest HTTP client cannot be constructed.
    pub(in crate::server) async fn new(
        region: Option<String>,
        bucket: String,
        creds: AwsCredentials,
        endpoint_url: Option<String>,
        force_path_style: bool,
    ) -> Result<Self> {
        let loader = aws_config::defaults(BehaviorVersion::latest());

        // Install an explicit credentials provider unless the defaults were requested.
        let loader = match creds {
            AwsCredentials::AccessKey {
                access_key_id,
                secret_access_key,
            } => loader.credentials_provider(Credentials::from_keys(
                access_key_id,
                secret_access_key,
                None,
            )),
            AwsCredentials::Profile { profile_name } => loader.credentials_provider(
                ProfileFileCredentialsProvider::builder()
                    .profile_name(profile_name)
                    .build(),
            ),
            AwsCredentials::Default => loader,
        };

        // Route all SDK traffic through the crate's reqwest client, then resolve the region
        // from the most specific source available.
        let mut loader = loader.http_client(ReqwestClient::new()?).region(
            RegionProviderChain::first_try(region.map(Region::new))
                .or_else(EnvironmentVariableRegionProvider::new())
                .or_else(profile::region::Builder::default().build())
                .or_else(Region::new("us-east-1")),
        );
        if let Some(url) = endpoint_url {
            loader = loader.endpoint_url(url);
        }

        let shared_config = loader.load().await;
        let s3_config = aws_sdk_s3::config::Builder::from(&shared_config)
            .force_path_style(force_path_style)
            .build();
        Ok(Self {
            client: aws_sdk_s3::Client::from_conf(s3_config),
            bucket,
        })
    }
}
/// An SDK [`HttpClient`] that routes every request through a reqwest client obtained from
/// `http::client`.
#[derive(Debug)]
struct ReqwestClient {
    /// The connector handed out for every `http_connector` call.
    connector: SharedHttpConnector,
}

impl ReqwestClient {
    /// Wrap a freshly built reqwest client in a shared connector.
    ///
    /// # Errors
    /// Propagates any failure from `http::client`.
    fn new() -> Result<Self> {
        let connector = SharedHttpConnector::new(ReqwestConnector {
            client: http::client()?,
        });
        Ok(Self { connector })
    }
}

impl HttpClient for ReqwestClient {
    /// Return the shared connector.
    ///
    /// The SDK-provided connector settings and runtime components are ignored, so any
    /// timeouts configured through the SDK are not applied by this connector.
    fn http_connector(
        &self,
        _settings: &aws_smithy_runtime_api::client::http::HttpConnectorSettings,
        _components: &aws_sdk_s3::config::RuntimeComponents,
    ) -> SharedHttpConnector {
        SharedHttpConnector::clone(&self.connector)
    }
}
/// An [`HttpConnector`] that sends SDK requests with a [`reqwest::Client`].
#[derive(Debug)]
struct ReqwestConnector {
    /// The reqwest client used to send every request.
    client: reqwest::Client,
}
/// Convert a reqwest error into a [`ConnectorError`], tagging it with an [`ErrorKind`] for the
/// SDK's retry classification.
///
/// Connection failures and timeouts are tagged `TransientError`; other request errors are
/// tagged `ClientError`; anything else carries no kind. The transient check must come first:
/// in reqwest, errors raised while sending a request (including connect failures and request
/// timeouts) also report `is_request()`. Checking that second unconditionally would overwrite
/// the transient classification.
fn reqwest_error_to_connector(err: reqwest::Error) -> ConnectorError {
    let kind = if err.is_connect() || err.is_timeout() {
        Some(ErrorKind::TransientError)
    } else if err.is_request() {
        Some(ErrorKind::ClientError)
    } else {
        None
    };
    ConnectorError::other(Box::new(err), kind)
}
impl HttpConnector for ReqwestConnector {
fn call(&self, request: HttpRequest) -> HttpConnectorFuture {
use std::str::FromStr;
let method = Method::from_str(request.method()).unwrap();
let mut reqwest_req = self.client.request(method, request.uri());
for (h, v) in request.headers() {
reqwest_req = reqwest_req.header(h, v);
}
if let Some(b) = request.into_body().bytes().map(|b| b.to_vec()) {
reqwest_req = reqwest_req.body(b);
}
HttpConnectorFuture::new(async {
let reqwest_resp = reqwest_req
.send()
.await
.map_err(reqwest_error_to_connector)?;
let status_code = reqwest_resp.status().as_u16();
let mut aws_headers = Headers::new();
for (h, v) in reqwest_resp.headers() {
if let Ok(v) = v.to_str() {
aws_headers.insert(h.to_string(), v.to_owned());
}
}
let body = reqwest_resp
.bytes()
.await
.map_err(reqwest_error_to_connector)?;
let mut resp = HttpResponse::new(
StatusCode::try_from(status_code)
.map_err(|e| ConnectorError::other(Box::new(e), None))?,
body.into(),
);
*resp.headers_mut() = aws_headers;
Ok(resp)
})
}
}
fn aws_err<E: Into<s3::Error>>(err: E) -> s3::Error {
err.into()
}
#[allow(clippy::result_large_err)] fn if_key_exists<T>(
res: std::result::Result<T, s3::Error>,
) -> std::result::Result<Option<T>, s3::Error> {
res
.map(Some)
.or_else(|err| match err {
s3::Error::NoSuchKey(_) => Ok(None),
err => Err(err),
})
}
/// Read the complete body of a `GetObject` response into memory.
///
/// # Errors
/// Fails if the body stream errors while it is being collected.
async fn get_body(get_res: GetObjectOutput) -> Result<Vec<u8>> {
    let collected = get_res.body.collect().await?;
    Ok(collected.to_vec())
}
#[async_trait]
impl Service for AwsService {
    /// Unconditionally store `value` under `name`.
    async fn put(&mut self, name: &str, value: &[u8]) -> Result<()> {
        validate_object_name(name);
        self.client
            .put_object()
            .bucket(self.bucket.clone())
            .key(name)
            .body(value.to_vec().into())
            .send()
            .await
            .map_err(aws_err)?;
        Ok(())
    }

    /// Fetch the object `name`, returning `Ok(None)` if it does not exist.
    async fn get(&mut self, name: &str) -> Result<Option<Vec<u8>>> {
        validate_object_name(name);
        let Some(get_res) = if_key_exists(
            self.client
                .get_object()
                .bucket(self.bucket.clone())
                .key(name)
                .send()
                .await
                .map_err(aws_err),
        )?
        else {
            return Ok(None);
        };
        Ok(Some(get_body(get_res).await?))
    }

    /// Delete the object `name`.
    async fn del(&mut self, name: &str) -> Result<()> {
        validate_object_name(name);
        self.client
            .delete_object()
            .bucket(self.bucket.clone())
            .key(name)
            .send()
            .await
            .map_err(aws_err)?;
        Ok(())
    }

    /// Return an iterator over the objects whose keys start with `prefix`.
    async fn list<'a>(&'a mut self, prefix: &'a str) -> Box<dyn AsyncObjectIterator + Send + 'a> {
        validate_object_name(prefix);
        Box::new(ObjectIterator {
            service: self,
            prefix: prefix.to_string(),
            last_response: None,
            next_index: 0,
        })
    }

    /// Replace the object `name` with `new_value` only if its current content equals
    /// `existing_value` (`None` meaning "the object must not exist").
    ///
    /// Returns `Ok(false)` when the current state does not match. It also returns `Ok(false)`
    /// when the conditional write is rejected because the object changed after it was read
    /// (`PreconditionFailed`, `ConditionalRequestConflict`, or `NoSuchKey`).
    async fn compare_and_swap(
        &mut self,
        name: &str,
        existing_value: Option<Vec<u8>>,
        new_value: Vec<u8>,
    ) -> Result<bool> {
        validate_object_name(name);
        let get_res = if_key_exists(
            self.client
                .get_object()
                .bucket(self.bucket.clone())
                .key(name)
                .send()
                .await
                .map_err(aws_err),
        )?;

        // Check the current state against the caller's expectation. Remember the ETag (if any)
        // so the write below is conditional on nothing having changed since this read.
        let e_tag = match (get_res, existing_value) {
            (Some(get_res), Some(existing_value)) => {
                let e_tag = get_res.e_tag.clone();
                if get_body(get_res).await? != existing_value {
                    return Ok(false);
                }
                e_tag
            }
            (None, None) => None,
            // The object exists but was expected to be absent, or vice versa.
            (Some(_), None) | (None, Some(_)) => return Ok(false),
        };

        // Test hooks: simulate a concurrent writer between the read and the conditional write.
        #[cfg(test)]
        if name.ends_with("-racing-delete") {
            println!("deleting object {name}");
            self.client
                .delete_object()
                .bucket(self.bucket.clone())
                .key(name)
                .send()
                .await
                .map_err(aws_err)?;
        }
        #[cfg(test)]
        if name.ends_with("-racing-put") {
            println!("changing object {name}");
            self.client
                .put_object()
                .bucket(self.bucket.clone())
                .key(name)
                .body(b"CHANGED".to_vec().into())
                .send()
                .await
                .map_err(aws_err)?;
        }

        // Write only if the object still has the ETag we read, or still does not exist.
        let put_builder = match e_tag {
            Some(e_tag) => self.client.put_object().if_match(e_tag),
            None => self.client.put_object().if_none_match("*"),
        };
        match put_builder
            .bucket(self.bucket.clone())
            .key(name)
            // `new_value` is already owned; move it into the body instead of copying it.
            .body(new_value.into())
            .send()
            .await
            .map_err(aws_err)
        {
            Ok(_) => Ok(true),
            Err(err) if err.code() == Some("NoSuchKey") => Ok(false),
            Err(err) if err.code() == Some("PreconditionFailed") => Ok(false),
            Err(err) if err.code() == Some("ConditionalRequestConflict") => Ok(false),
            Err(e) => Err(e.into()),
        }
    }
}
/// Iterator over the objects under a key prefix, fetching `ListObjectsV2` results one page
/// at a time.
struct ObjectIterator<'a> {
    /// Service supplying the client and bucket; borrowed mutably for the iterator's lifetime.
    service: &'a mut AwsService,
    /// Key prefix passed to every listing request.
    prefix: String,
    /// The most recently fetched page, or `None` when no page is held (e.g. before the first
    /// fetch).
    last_response: Option<ListObjectsV2Output>,
    /// Index into the current page's `contents` of the next object to yield.
    next_index: usize,
}
impl ObjectIterator<'_> {
    /// Fetch the next page of the listing and reset `next_index` to its start.
    ///
    /// The first call (no page held) lists from the beginning. Later calls continue from the
    /// held page's `next_continuation_token`. The held page is replaced only after the request
    /// succeeds. A failed fetch therefore leaves the iterator's position intact, and a retry
    /// resumes from the same point instead of restarting the listing and yielding duplicates.
    async fn fetch_batch(&mut self) -> Result<()> {
        let continuation_token = self
            .last_response
            .as_ref()
            .and_then(|resp| resp.next_continuation_token.clone());

        // Small pages in tests so that pagination is exercised.
        #[cfg(test)]
        let max_keys = Some(8);
        #[cfg(not(test))]
        let max_keys = None;

        let response = self
            .service
            .client
            .list_objects_v2()
            .bucket(self.service.bucket.clone())
            .prefix(self.prefix.clone())
            .set_max_keys(max_keys)
            .set_continuation_token(continuation_token)
            .send()
            .await
            .map_err(aws_err)?;
        self.last_response = Some(response);
        self.next_index = 0;
        Ok(())
    }
}
#[async_trait]
impl AsyncObjectIterator for ObjectIterator<'_> {
    /// Yield the next object's name and creation time, fetching further pages as needed.
    ///
    /// Returns `None` once the last page (one without a continuation token) is exhausted.
    /// Returns `Some(Err(_))` if fetching a page fails. A page with no `contents` is treated
    /// as empty rather than as the end of the listing.
    ///
    /// # Panics
    /// Panics if the listing returns an object without a key.
    async fn next(&mut self) -> Option<Result<ObjectInfo>> {
        // Loop over pages rather than recursing: each page either yields an object, ends the
        // listing, or triggers a fetch of the following page.
        loop {
            if self.last_response.is_none() {
                if let Err(e) = self.fetch_batch().await {
                    return Some(Err(e));
                }
            }
            let response = self.last_response.as_ref()?;
            let objects = response.contents.as_deref().unwrap_or_default();

            if let Some(obj) = objects.get(self.next_index) {
                self.next_index += 1;
                // Missing or pre-epoch timestamps are reported as 0.
                let creation = obj.last_modified.map(|t| t.secs()).unwrap_or(0);
                let creation: u64 = creation.try_into().unwrap_or(0);
                let name = obj.key.clone().expect("object has no key");
                return Some(Ok(ObjectInfo { name, creation }));
            }

            if response.next_continuation_token.is_none() {
                return None;
            }
            if let Err(e) = self.fetch_batch().await {
                return Some(Err(e));
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Read the environment variable `var` needed to run the live AWS tests.
    ///
    /// A missing (or non-Unicode) value yields `None`, so the tests are skipped. When
    /// `fail_if_not_set` is true, a missing value panics instead.
    fn required_env(var: &str, fail_if_not_set: bool) -> Option<String> {
        match std::env::var(var) {
            Ok(value) => Some(value),
            Err(_) if fail_if_not_set => panic!("{var} not set"),
            Err(_) => None,
        }
    }

    /// Build an `AwsService` from the `AWS_TEST_*` environment variables, or return `None`
    /// when they are not configured (unless `AWS_FAIL_IF_NOT_SET` is set).
    async fn make_service() -> Option<AwsService> {
        let fail_if_not_set = std::env::var("AWS_FAIL_IF_NOT_SET").is_ok();
        let region = required_env("AWS_TEST_REGION", fail_if_not_set)?;
        let bucket = required_env("AWS_TEST_BUCKET", fail_if_not_set)?;
        let access_key_id = required_env("AWS_TEST_ACCESS_KEY_ID", fail_if_not_set)?;
        let secret_access_key = required_env("AWS_TEST_SECRET_ACCESS_KEY", fail_if_not_set)?;

        // Optional settings for S3-compatible test endpoints.
        let endpoint_url = std::env::var("AWS_TEST_ENDPOINT_URL").ok();
        let force_path_style = matches!(
            std::env::var("AWS_TEST_FORCE_PATH_STYLE").as_deref(),
            Ok("1") | Ok("true")
        );

        let service = AwsService::new(
            Some(region),
            bucket,
            AwsCredentials::AccessKey {
                access_key_id,
                secret_access_key,
            },
            endpoint_url,
            force_path_style,
        )
        .await
        .unwrap();
        Some(service)
    }

    crate::server::cloud::test::service_tests!(make_service().await);
}