use super::service::{validate_object_name, ObjectInfo, Service};
use crate::errors::Result;
use aws_config::{
meta::region::RegionProviderChain, profile::ProfileFileCredentialsProvider, BehaviorVersion,
Region,
};
use aws_credential_types::Credentials;
use aws_sdk_s3::{
self as s3,
error::ProvideErrorMetadata,
operation::{get_object::GetObjectOutput, list_objects_v2::ListObjectsV2Output},
};
use std::future::Future;
use tokio::runtime::Runtime;
/// S3-backed implementation of [`Service`] that keeps every object in a single bucket.
///
/// The service owns a dedicated Tokio runtime and drives each async SDK call to
/// completion on it with `block_on`, so the `Service` methods are synchronous.
// NOTE(review): fields drop in declaration order, so `client` is dropped before `rt`;
// keep this order unless it is confirmed not to matter.
pub(in crate::server) struct AwsService {
    /// SDK client used for every S3 request.
    client: s3::Client,
    /// Runtime on which SDK futures are run (see `block_on`).
    rt: Runtime,
    /// Name of the bucket that every operation targets.
    bucket: String,
}
/// How an [`AwsService`] obtains AWS credentials.
#[non_exhaustive]
pub enum AwsCredentials {
    /// A static access key ID / secret access key pair, used without a session token.
    AccessKey {
        access_key_id: String,
        secret_access_key: String,
    },
    /// Credentials for the named profile, loaded via `ProfileFileCredentialsProvider`.
    Profile { profile_name: String },
    /// No explicit provider: keep the credential chain that `aws_config::defaults` sets up.
    Default,
}
impl AwsService {
    /// Builds a service for `bucket` in `region`, authenticating with `creds`.
    ///
    /// A new Tokio runtime is created and stored in the service. The SDK configuration is
    /// loaded on that runtime, and later operations are run on it via `block_on`.
    ///
    /// # Errors
    /// Fails if the Tokio runtime cannot be created.
    pub(in crate::server) fn new(
        region: String,
        bucket: String,
        creds: AwsCredentials,
    ) -> Result<Self> {
        let rt = Runtime::new()?;
        let config = rt.block_on(async {
            let loader = aws_config::defaults(BehaviorVersion::v2024_03_28());
            // Each arm yields the loader to use; only explicit credential sources override it.
            let loader = match creds {
                AwsCredentials::AccessKey {
                    access_key_id,
                    secret_access_key,
                } => loader.credentials_provider(Credentials::from_keys(
                    access_key_id,
                    secret_access_key,
                    None,
                )),
                AwsCredentials::Profile { profile_name } => loader.credentials_provider(
                    ProfileFileCredentialsProvider::builder()
                        .profile_name(profile_name)
                        .build(),
                ),
                AwsCredentials::Default => loader,
            };
            loader
                .region(RegionProviderChain::first_try(Region::new(region)))
                .load()
                .await
        });
        Ok(Self {
            client: s3::client::Client::new(&config),
            rt,
            bucket,
        })
    }

    /// Runs `fut` to completion on this service's runtime and returns its result.
    fn block_on<T, F: Future<Output = Result<T>>>(&self, fut: F) -> Result<T> {
        let runtime = &self.rt;
        runtime.block_on(fut)
    }
}
fn aws_err<E: Into<s3::Error>>(err: E) -> s3::Error {
err.into()
}
#[allow(clippy::result_large_err)] fn if_key_exists<T>(
res: std::result::Result<T, s3::Error>,
) -> std::result::Result<Option<T>, s3::Error> {
res
.map(Some)
.or_else(|err| match err {
s3::Error::NoSuchKey(_) => Ok(None),
err => Err(err),
})
}
/// Reads the entire body of a `GetObject` response into an owned byte vector.
async fn get_body(get_res: GetObjectOutput) -> Result<Vec<u8>> {
    let aggregated = get_res.body.collect().await?;
    Ok(aggregated.to_vec())
}
impl Service for AwsService {
    /// Uploads `value` as object `name`, unconditionally replacing any existing object.
    fn put(&mut self, name: &str, value: &[u8]) -> Result<()> {
        self.block_on(async {
            validate_object_name(name);
            self.client
                .put_object()
                .bucket(self.bucket.clone())
                .key(name)
                .body(value.to_vec().into())
                .send()
                .await
                .map_err(aws_err)?;
            Ok(())
        })
    }

    /// Downloads object `name`, returning `Ok(None)` when S3 reports `NoSuchKey`.
    fn get(&mut self, name: &str) -> Result<Option<Vec<u8>>> {
        self.block_on(async {
            validate_object_name(name);
            let Some(get_res) = if_key_exists(
                self.client
                    .get_object()
                    .bucket(self.bucket.clone())
                    .key(name)
                    .send()
                    .await
                    .map_err(aws_err),
            )?
            else {
                return Ok(None);
            };
            Ok(Some(get_body(get_res).await?))
        })
    }

    /// Deletes object `name`; any SDK error is returned to the caller.
    fn del(&mut self, name: &str) -> Result<()> {
        self.block_on(async {
            validate_object_name(name);
            self.client
                .delete_object()
                .bucket(self.bucket.clone())
                .key(name)
                .send()
                .await
                .map_err(aws_err)?;
            Ok(())
        })
    }

    /// Lists objects whose names start with `prefix`. Pages are fetched lazily as the
    /// returned iterator is advanced.
    fn list<'a>(&'a mut self, prefix: &str) -> Box<dyn Iterator<Item = Result<ObjectInfo>> + 'a> {
        validate_object_name(prefix);
        Box::new(ObjectIterator {
            service: self,
            prefix: prefix.to_string(),
            last_response: None,
            next_index: 0,
        })
    }

    /// Writes `new_value` to `name` only if the current content equals `existing_value`.
    /// `None` means the object must not exist.
    ///
    /// The object is read and compared first, then written with a conditional PUT. The
    /// PUT uses `If-Match` with the ETag that was read, or `If-None-Match: *` when no
    /// ETag is available. A concurrent change between the read and the write therefore
    /// makes the PUT fail instead of silently overwriting.
    ///
    /// Returns `Ok(false)` in two cases:
    /// - the content does not match `existing_value`;
    /// - the conditional PUT is rejected with `NoSuchKey`, `PreconditionFailed` or
    ///   `ConditionalRequestConflict`.
    ///
    /// Other errors are propagated.
    fn compare_and_swap(
        &mut self,
        name: &str,
        existing_value: Option<Vec<u8>>,
        new_value: Vec<u8>,
    ) -> Result<bool> {
        self.block_on(async {
            validate_object_name(name);
            let get_res = if_key_exists(
                self.client
                    .get_object()
                    .bucket(self.bucket.clone())
                    .key(name)
                    .send()
                    .await
                    .map_err(aws_err),
            )?;
            // ETag of the object version that was compared; `None` selects `If-None-Match: *`.
            let e_tag = match (get_res, existing_value) {
                (Some(current), Some(expected)) => {
                    let e_tag = current.e_tag.clone();
                    if get_body(current).await? != expected {
                        return Ok(false);
                    }
                    e_tag
                }
                (None, None) => None,
                // Presence of the object disagrees with what the caller expected.
                (Some(_), None) | (None, Some(_)) => return Ok(false),
            };
            // Test-only hooks: simulate a concurrent writer between the read above and
            // the conditional write below.
            #[cfg(test)]
            if name.ends_with("-racing-delete") {
                println!("deleting object {name}");
                self.client
                    .delete_object()
                    .bucket(self.bucket.clone())
                    .key(name)
                    .send()
                    .await
                    .map_err(aws_err)?;
            }
            #[cfg(test)]
            if name.ends_with("-racing-put") {
                println!("changing object {name}");
                self.client
                    .put_object()
                    .bucket(self.bucket.clone())
                    .key(name)
                    .body(b"CHANGED".to_vec().into())
                    .send()
                    .await
                    .map_err(aws_err)?;
            }
            let mut put_builder = self.client.put_object();
            if let Some(e_tag) = e_tag {
                put_builder = put_builder.if_match(e_tag);
            } else {
                put_builder = put_builder.if_none_match("*");
            }
            match put_builder
                .bucket(self.bucket.clone())
                .key(name)
                // `new_value` is already owned; move it instead of copying it again.
                .body(new_value.into())
                .send()
                .await
                .map_err(aws_err)
            {
                Ok(_) => Ok(true),
                Err(err) if err.code() == Some("NoSuchKey") => Ok(false),
                Err(err) if err.code() == Some("PreconditionFailed") => Ok(false),
                Err(err) if err.code() == Some("ConditionalRequestConflict") => Ok(false),
                Err(e) => Err(e.into()),
            }
        })
    }
}
/// Paginating iterator over the objects whose names start with `prefix`.
///
/// Holds the most recent `ListObjectsV2` page and walks its `contents`. When the current
/// page is exhausted, it fetches the next page via the continuation token.
struct ObjectIterator<'a> {
    /// Service whose client, bucket and runtime are used for listing requests.
    service: &'a mut AwsService,
    /// Key prefix passed to `ListObjectsV2`.
    prefix: String,
    /// Most recently fetched page; `None` when no page is held (e.g. before the first fetch).
    last_response: Option<ListObjectsV2Output>,
    /// Index into `last_response.contents` of the next object to yield.
    next_index: usize,
}
impl ObjectIterator<'_> {
    /// Fetches the next page of the listing into `last_response` and resets `next_index`.
    ///
    /// The continuation token comes from the page currently held, if any, so the first
    /// call starts at the beginning of the listing.
    ///
    /// # Errors
    /// Returns the SDK error if the `ListObjectsV2` request fails. In that case
    /// `last_response` and `next_index` are left untouched, so a later call retries the
    /// same page. Clearing the state would instead restart the listing from the
    /// beginning and re-yield objects that were already returned.
    fn fetch_batch(&mut self) -> Result<()> {
        let continuation_token = self
            .last_response
            .as_ref()
            .and_then(|resp| resp.next_continuation_token.clone());
        // Small pages under test exercise the pagination path; otherwise use S3's default.
        #[cfg(test)]
        let max_keys = Some(8);
        #[cfg(not(test))]
        let max_keys = None;
        let response = self.service.block_on(async {
            Ok(self
                .service
                .client
                .list_objects_v2()
                .bucket(self.service.bucket.clone())
                .prefix(self.prefix.clone())
                .set_max_keys(max_keys)
                .set_continuation_token(continuation_token)
                .send()
                .await
                .map_err(aws_err)?)
        })?;
        self.last_response = Some(response);
        self.next_index = 0;
        Ok(())
    }
}
impl Iterator for ObjectIterator<'_> {
    type Item = Result<ObjectInfo>;

    /// Yields the next object under `prefix`, fetching further pages as needed.
    ///
    /// A page with no `contents`, or whose contents are exhausted, is skipped as long as
    /// it carries a continuation token. Iteration ends once the current page is exhausted
    /// and no continuation token remains. A failed page fetch is yielded as `Some(Err(_))`.
    ///
    /// # Panics
    /// Panics if S3 returns an object without a key.
    fn next(&mut self) -> Option<Self::Item> {
        if self.last_response.is_none() {
            if let Err(e) = self.fetch_batch() {
                return Some(Err(e));
            }
        }
        // Loop rather than recurse so a run of empty pages cannot grow the stack.
        loop {
            let page = self.last_response.as_ref()?;
            // S3 may omit `contents` on a page; treat that as an empty page, not the end.
            let objects = page.contents.as_deref().unwrap_or_default();
            if let Some(obj) = objects.get(self.next_index) {
                self.next_index += 1;
                // Missing or pre-epoch timestamps are reported as 0.
                let secs = obj.last_modified.map(|t| t.secs()).unwrap_or(0);
                let creation = u64::try_from(secs).unwrap_or(0);
                let name = obj.key.clone().expect("object has no key");
                return Some(Ok(ObjectInfo { name, creation }));
            }
            if page.next_continuation_token.is_none() {
                return None;
            }
            if let Err(e) = self.fetch_batch() {
                return Some(Err(e));
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an [`AwsService`] from the `AWS_TEST_*` environment variables.
    ///
    /// Returns `None` if any of `AWS_TEST_REGION`, `AWS_TEST_BUCKET`,
    /// `AWS_TEST_ACCESS_KEY_ID` or `AWS_TEST_SECRET_ACCESS_KEY` cannot be read.
    /// Panics if the service itself cannot be constructed.
    fn make_service() -> Option<AwsService> {
        let env = |key: &str| std::env::var(key).ok();
        let region = env("AWS_TEST_REGION")?;
        let bucket = env("AWS_TEST_BUCKET")?;
        let access_key_id = env("AWS_TEST_ACCESS_KEY_ID")?;
        let secret_access_key = env("AWS_TEST_SECRET_ACCESS_KEY")?;
        let creds = AwsCredentials::AccessKey {
            access_key_id,
            secret_access_key,
        };
        Some(AwsService::new(region, bucket, creds).unwrap())
    }

    crate::server::cloud::test::service_tests!(make_service());
}