use crate::error::ActiveStorageError;
use crate::resource_manager::ResourceManager;
use aws_credential_types::Credentials;
use aws_sdk_s3::operation::head_object::HeadObjectError;
use aws_sdk_s3::Client;
use aws_sdk_s3::{config::BehaviorVersion, error::SdkError};
use aws_smithy_runtime_api::http::Response;
use aws_types::region::Region;
use axum::body::Bytes;
use hashbrown::HashMap;
use tokio::sync::{RwLock, SemaphorePermit};
use tracing::Instrument;
use url::Url;
use urlencoding;
/// Credentials used to authenticate against an S3-compatible object store.
///
/// Derives `Eq` + `Hash` so it can be used as part of the cache key in
/// `S3ClientMap`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum S3Credentials {
    /// Static access key pair.
    AccessKey {
        // S3 access key ID.
        access_key: String,
        // S3 secret access key.
        secret_key: String,
    },
    /// No credentials — anonymous access.
    None,
}
impl S3Credentials {
pub fn access_key(access_key: &str, secret_key: &str) -> Self {
S3Credentials::AccessKey {
access_key: access_key.to_string(),
secret_key: secret_key.to_string(),
}
}
}
/// Cache of [`S3Client`] objects keyed by (endpoint URL, credentials).
///
/// Interior mutability via an async `RwLock` allows many concurrent readers
/// on the fast path while serialising insertion of newly created clients.
#[derive(Debug)]
pub struct S3ClientMap {
    // Map from (endpoint URL, credentials) to the cached client.
    map: RwLock<HashMap<(Url, S3Credentials), S3Client>>,
}
impl S3ClientMap {
    /// Create an empty client map.
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        Self {
            map: RwLock::new(HashMap::new()),
        }
    }

    /// Return a cached [`S3Client`] for the given endpoint URL and
    /// credentials, creating and caching one if none exists yet.
    ///
    /// Uses a read lock for the fast path, then re-checks under the write
    /// lock so that two tasks racing on a cache miss create only one client.
    pub async fn get(&self, url: &Url, credentials: S3Credentials) -> S3Client {
        let key = (url.clone(), credentials.clone());
        // Fast path: client already cached; read lock only.
        if let Some(existing) = self.map.read().await.get(&key) {
            return existing.clone();
        }
        // Slow path: take the write lock and check again, since another task
        // may have inserted a client between the two lock acquisitions.
        let mut guard = self.map.write().await;
        match guard.get(&key) {
            Some(existing) => existing.clone(),
            None => {
                tracing::info!("Creating new S3 client for {}", url);
                let created = S3Client::new(url, credentials).await;
                // Key is known to be absent here (checked under the write
                // lock), so a plain insert cannot replace anything.
                guard.insert(key, created.clone());
                created
            }
        }
    }
}
/// Thin wrapper around the AWS SDK S3 [`Client`].
///
/// NOTE(review): cloning relies on the SDK client being cheap to clone
/// (internally reference-counted per AWS SDK docs) — confirm if that matters
/// for a hot path.
#[derive(Clone, Debug)]
pub struct S3Client {
    // Underlying AWS SDK client configured for one endpoint + credentials.
    client: Client,
}
impl S3Client {
    /// Create an S3 client for the given endpoint URL and credentials.
    ///
    /// The client uses path-style addressing (`endpoint/bucket/key`), as
    /// required by many non-AWS S3-compatible stores.
    pub async fn new(url: &Url, credentials: S3Credentials) -> Self {
        // The SDK config requires a region. With a custom endpoint and
        // path-style addressing it is presumably ignored by the server —
        // NOTE(review): confirm for stores that validate the SigV4 region.
        let region = Region::new("us-east-1");
        let builder = aws_sdk_s3::Config::builder().behavior_version(BehaviorVersion::latest());
        let builder = match credentials {
            S3Credentials::AccessKey {
                access_key,
                secret_key,
            } => {
                // Static key pair; no session token.
                let credentials = Credentials::from_keys(access_key, secret_key, None);
                builder.credentials_provider(credentials)
            }
            // Anonymous access: no credentials provider is registered.
            S3Credentials::None => builder,
        };
        let s3_config = builder
            .region(Some(region))
            .endpoint_url(url.to_string())
            .force_path_style(true)
            .build();
        let client = Client::from_conf(s3_config);
        Self { client }
    }

    /// Check whether the configured credentials may access an object.
    ///
    /// Issues a HEAD request for `key` in `bucket`. Returns `Ok(true)` on
    /// success, `Ok(false)` if the service responds 403 (Forbidden), and
    /// propagates any other error unchanged (including 404).
    pub async fn is_authorised(
        &self,
        bucket: &str,
        key: &str,
    ) -> Result<bool, SdkError<HeadObjectError, Response>> {
        let response = self
            .client
            .head_object()
            .bucket(bucket)
            .key(key)
            .send()
            .instrument(tracing::Span::current())
            .await;
        match response {
            Ok(_) => Ok(true),
            Err(err) => match &err {
                // Only service errors carry an HTTP status to inspect.
                aws_smithy_runtime_api::client::result::SdkError::ServiceError(inner) => {
                    match inner.raw().status().as_u16() {
                        // 403 means "not authorised"; anything else is an error.
                        403 => Ok(false), _ => Err(err),
                    }
                }
                _ => Err(err),
            },
        }
    }

    /// Download an object (optionally a byte range) into an aligned buffer.
    ///
    /// # Arguments
    ///
    /// * `bucket`: bucket containing the object
    /// * `key`: object key
    /// * `range`: optional HTTP `Range` header value (e.g. `bytes=0-99`)
    /// * `resource_manager`: accounts for the memory used by the download
    /// * `mem_permits`: in/out memory permits; acquired here once the
    ///   response size is known, unless the caller already holds a
    ///   non-empty permit
    ///
    /// # Errors
    ///
    /// Fails if the GET request fails, if the response lacks a
    /// `Content-Length` header, or if memory permits cannot be acquired.
    pub async fn download_object<'a>(
        &self,
        bucket: &str,
        key: &str,
        range: Option<String>,
        resource_manager: &'a ResourceManager,
        mem_permits: &mut Option<SemaphorePermit<'a>>,
    ) -> Result<Bytes, ActiveStorageError> {
        let mut response = self
            .client
            .get_object()
            .bucket(bucket)
            .key(key)
            .set_range(range)
            .send()
            .instrument(tracing::Span::current())
            .await?;
        // Content-Length is required so the buffer can be sized and the
        // memory accounted for before the body is streamed.
        let content_length: usize = response
            .content_length()
            .ok_or(ActiveStorageError::S3ContentLengthMissing)?
            .try_into()?;
        // Acquire memory permits for the body unless the caller already
        // holds a permit covering at least one unit.
        match mem_permits {
            None => {
                *mem_permits = resource_manager.memory(content_length).await?;
            }
            Some(permits) => {
                if permits.num_permits() == 0 {
                    *mem_permits = resource_manager.memory(content_length).await?;
                }
            }
        }
        // 8-byte-aligned allocation — presumably so the bytes can later be
        // reinterpreted as wider numeric types. TODO(review): confirm
        // against the callers that consume the returned Bytes.
        let mut buf = maligned::align_first::<u8, maligned::A8>(content_length);
        // Stream the response body chunk by chunk into the buffer.
        while let Some(bytes) = response
            .body
            .try_next()
            .instrument(tracing::Span::current())
            .await?
        {
            buf.extend_from_slice(&bytes)
        }
        Ok(buf.into())
    }
}
/// Split an S3 URL of the form `scheme://host[:port]/<bucket>/<object...>`
/// into (endpoint URL, bucket name, object key).
///
/// The bucket and object are percent-decoded. The object key may itself
/// contain `/` separators: all path segments after the bucket are rejoined.
/// The returned endpoint URL is the input URL with its path reset to `/`.
///
/// # Errors
///
/// Returns [`ActiveStorageError::S3RequestError`] if the URL cannot be split
/// into path segments, if the bucket or object is missing or empty, or if
/// percent-decoding fails.
pub fn parse_s3_url(url: &Url) -> Result<(Url, String, String), ActiveStorageError> {
    let mut segments = url
        .path_segments()
        .ok_or_else(|| ActiveStorageError::S3RequestError {
            error: "S3 URL must have path segments".to_string(),
        })?
        .peekable();
    let bucket = segments
        .next()
        .ok_or_else(|| ActiveStorageError::S3RequestError {
            error: "S3 URL must have bucket".to_string(),
        })?
        .to_string();
    let bucket = urlencoding::decode(&bucket)
        .map_err(|e| ActiveStorageError::S3RequestError {
            error: format!("Failed to decode bucket name: {e}"),
        })?
        .to_string();
    // A URL such as "http://host/" or "http://host//object" yields an empty
    // first segment, which previously slipped through as an empty bucket.
    if bucket.is_empty() {
        return Err(ActiveStorageError::S3RequestError {
            error: "S3 URL must have bucket".to_string(),
        });
    }
    if segments.peek().is_none() {
        return Err(ActiveStorageError::S3RequestError {
            error: "S3 URL must have object".to_string(),
        });
    }
    // The object key may contain slashes; rejoin the remaining segments.
    let object = segments.collect::<Vec<_>>().join("/");
    let object = urlencoding::decode(&object)
        .map_err(|e| ActiveStorageError::S3RequestError {
            error: format!("Failed to decode object name: {e}"),
        })?
        .to_string();
    // A URL such as "http://host/bucket/" yields a single empty trailing
    // segment, which previously produced an empty object key.
    if object.is_empty() {
        return Err(ActiveStorageError::S3RequestError {
            error: "S3 URL must have object".to_string(),
        });
    }
    // The endpoint is the original URL with the path stripped.
    let mut source_url = url.clone();
    source_url.set_path("/");
    Ok((source_url, bucket, object))
}
/// Build an HTTP `Range` header value from an optional byte offset and size.
///
/// HTTP byte ranges are inclusive: a range of `size` bytes starting at
/// `offset` is `bytes=offset-(offset + size - 1)`. With only an offset the
/// open-ended form `bytes=offset-` is produced; with neither, `None`.
///
/// NOTE(review): assumes `size >= 1` when provided — a size of zero would
/// underflow the end calculation. Confirm callers validate this.
pub fn get_range(offset: Option<usize>, size: Option<usize>) -> Option<String> {
    match (offset, size) {
        (_, Some(size)) => {
            let start = offset.unwrap_or(0);
            let end = start + size - 1;
            Some(format!("bytes={}-{}", start, end))
        }
        (Some(start), None) => Some(format!("bytes={}-", start)),
        (None, None) => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use url::Url;

    // Helper: credentials for a fixed test user.
    fn make_access_key() -> S3Credentials {
        S3Credentials::access_key("user", "password")
    }

    // Helper: credentials for a second user, to exercise distinct cache keys.
    fn make_alt_access_key() -> S3Credentials {
        S3Credentials::access_key("user2", "password")
    }

    // Identical (URL, credentials) pairs must share one cached client;
    // differing credentials must each get their own entry.
    #[tokio::test]
    async fn s3_client_map() {
        let url = Url::parse("http://example.com").unwrap();
        let map = S3ClientMap::new();
        map.get(&url, make_access_key()).await;
        map.get(&url, make_access_key()).await;
        assert_eq!(map.map.read().await.len(), 1);
        map.get(&url, make_alt_access_key()).await;
        assert_eq!(map.map.read().await.len(), 2);
        map.get(&url, S3Credentials::None).await;
        map.get(&url, S3Credentials::None).await;
        assert_eq!(map.map.read().await.len(), 3);
    }

    // Client construction with access-key credentials (no network I/O).
    #[tokio::test]
    async fn new() {
        let url = Url::parse("http://example.com").unwrap();
        S3Client::new(&url, make_access_key()).await;
    }

    // Client construction with anonymous credentials.
    #[tokio::test]
    async fn new_no_auth() {
        let url = Url::parse("http://example.com").unwrap();
        S3Client::new(&url, S3Credentials::None).await;
    }

    // Bucket plus single-segment object key.
    #[test]
    fn parse_s3_url_valid() {
        let url = Url::parse("http://example.com:8080/bucket/test--operation-min-dtype-uint64--shape-[10, 5, 2]-etc.bin").unwrap();
        let (source_url, bucket, object) = parse_s3_url(&url).unwrap();
        assert_eq!(source_url.as_str(), "http://example.com:8080/");
        assert_eq!(bucket, "bucket");
        assert_eq!(
            object,
            "test--operation-min-dtype-uint64--shape-[10, 5, 2]-etc.bin"
        );
    }

    // Object key containing one slash is rejoined after the bucket.
    #[test]
    fn parse_s3_url_valid2() {
        let url = Url::parse("http://example.com:8080/bucket/a/test--operation-min-dtype-uint64--shape-[10, 5, 2]-etc.bin").unwrap();
        let (source_url, bucket, object) = parse_s3_url(&url).unwrap();
        assert_eq!(source_url.as_str(), "http://example.com:8080/");
        assert_eq!(bucket, "bucket");
        assert_eq!(
            object,
            "a/test--operation-min-dtype-uint64--shape-[10, 5, 2]-etc.bin"
        );
    }

    // Object key containing multiple slashes.
    #[test]
    fn parse_s3_url_valid3() {
        let url = Url::parse("http://example.com:8080/bucket/a/b/test--operation-min-dtype-uint64--shape-[10, 5, 2]-etc.bin").unwrap();
        let (source_url, bucket, object) = parse_s3_url(&url).unwrap();
        assert_eq!(source_url.as_str(), "http://example.com:8080/");
        assert_eq!(bucket, "bucket");
        assert_eq!(
            object,
            "a/b/test--operation-min-dtype-uint64--shape-[10, 5, 2]-etc.bin"
        );
    }

    // NOTE(review): without "http://" this parses as a cannot-be-a-base URL
    // (scheme "example.com"), so path_segments() is None and parsing fails.
    #[test]
    fn parse_s3_url_invalid_source_url() {
        let url = Url::parse("example.com:8080/bucket/object.bin").unwrap();
        assert!(
            parse_s3_url(&url).is_err(),
            "S3 URL must have path segments"
        );
    }

    // URL with no bucket segment must be rejected.
    #[test]
    fn parse_s3_url_invalid_bucket() {
        let url = Url::parse("http://example.com:8080/").unwrap();
        assert!(parse_s3_url(&url).is_err(), "S3 URL must have bucket");
    }

    // NOTE(review): like the invalid_source_url case, this fixture is a
    // cannot-be-a-base URL, so it fails at path_segments() rather than on
    // the intended missing-object check — consider "http://example.com:8080/bucket/".
    #[test]
    fn parse_s3_url_invalid_object() {
        let url = Url::parse("example.com:8080/bucket/").unwrap();
        assert!(parse_s3_url(&url).is_err(), "S3 URL must have object");
    }

    // No offset and no size: no Range header.
    #[test]
    fn get_range_none() {
        assert_eq!(None, get_range(None, None));
    }

    // Offset and size: inclusive closed range.
    #[test]
    fn get_range_both() {
        assert_eq!(Some("bytes=1-2".to_string()), get_range(Some(1), Some(2)));
    }

    // Offset only: open-ended range.
    #[test]
    fn get_range_offset() {
        assert_eq!(Some("bytes=1-".to_string()), get_range(Some(1), None));
    }

    // Size only: range starts at offset 0.
    #[test]
    fn get_range_size() {
        assert_eq!(Some("bytes=0-1".to_string()), get_range(None, Some(2)));
    }
}