use {
crate::{
error::{DebianError, Result},
io::{ContentDigest, MultiDigester},
repository::{
RepositoryPathVerification, RepositoryPathVerificationState, RepositoryWrite,
RepositoryWriter,
},
},
async_trait::async_trait,
futures::{AsyncRead, AsyncReadExt as FuturesAsyncReadExt},
rusoto_core::{ByteStream, Client, Region, RusotoError},
rusoto_s3::{
GetBucketLocationRequest, GetObjectError, GetObjectRequest, HeadObjectError,
HeadObjectRequest, PutObjectRequest, S3Client, S3,
},
std::{borrow::Cow, pin::Pin, str::FromStr},
tokio::io::AsyncReadExt as TokioAsyncReadExt,
};
/// Repository writer that stores content as objects in an Amazon S3 bucket.
pub struct S3Writer {
    /// Client used for all S3 requests issued by this writer.
    client: S3Client,
    /// Name of the bucket objects are read from and written to.
    bucket: String,
    /// Optional key prefix (stored without surrounding `/`) prepended to every object key.
    key_prefix: Option<String>,
}
impl S3Writer {
    /// Construct a writer for `bucket` in `region` using a default S3 client.
    ///
    /// `key_prefix`, when given, is prepended to every object key. Leading and trailing
    /// `/` characters are stripped, and a prefix that is empty after stripping is treated
    /// as no prefix at all.
    pub fn new(region: Region, bucket: impl ToString, key_prefix: Option<&str>) -> Self {
        Self {
            client: S3Client::new(region),
            bucket: bucket.to_string(),
            key_prefix: Self::normalize_key_prefix(key_prefix),
        }
    }

    /// Construct a writer for `bucket` in `region` from an existing rusoto [`Client`].
    ///
    /// `key_prefix` is normalized exactly as in [`S3Writer::new`].
    pub fn new_with_client(
        client: Client,
        region: Region,
        bucket: impl ToString,
        key_prefix: Option<&str>,
    ) -> Self {
        Self {
            client: S3Client::new_with_client(client, region),
            bucket: bucket.to_string(),
            key_prefix: Self::normalize_key_prefix(key_prefix),
        }
    }

    /// Compute the S3 object key for a repository-relative `path`.
    ///
    /// Leading and trailing `/` are stripped from `path`. If a non-empty key prefix is
    /// configured, the result is `<prefix>/<path>`; otherwise it is just `<path>`.
    pub fn path_to_key(&self, path: &str) -> String {
        let path = path.trim_matches('/');

        match self.key_prefix.as_deref() {
            // Guard against an empty prefix: it would otherwise yield a key with a
            // leading `/`, which S3 treats as a distinct key.
            Some(prefix) if !prefix.is_empty() => format!("{}/{}", prefix, path),
            _ => path.to_string(),
        }
    }

    /// Strip surrounding `/` from a key prefix, discarding it if nothing remains.
    fn normalize_key_prefix(key_prefix: Option<&str>) -> Option<String> {
        key_prefix
            .map(|prefix| prefix.trim_matches('/'))
            .filter(|prefix| !prefix.is_empty())
            .map(str::to_string)
    }
}
/// Stores repository content as objects in the writer's S3 bucket.
///
/// Repository paths are mapped to object keys via [`S3Writer::path_to_key`].
#[async_trait]
impl RepositoryWriter for S3Writer {
    /// Check whether `path` exists in the bucket and, optionally, whether its content
    /// matches an expected size and digest.
    ///
    /// When `expected_content` is `Some`, the object is fetched with `GetObject` and its
    /// body is digested and compared against the expected digest. Otherwise, only a
    /// `HeadObject` existence check is performed.
    ///
    /// # Errors
    ///
    /// Returns [`DebianError::RepositoryIoPath`] if reading the object body fails or if
    /// S3 reports an error other than the object being absent.
    async fn verify_path<'path>(
        &self,
        path: &'path str,
        expected_content: Option<(u64, ContentDigest)>,
    ) -> Result<RepositoryPathVerification<'path>> {
        if let Some((expected_size, expected_digest)) = expected_content {
            let req = GetObjectRequest {
                bucket: self.bucket.clone(),
                key: self.path_to_key(path),
                ..Default::default()
            };

            match self.client.get_object(req).await {
                Ok(output) => {
                    // Reject on a reported size mismatch without reading the body.
                    if let Some(cl) = output.content_length {
                        if cl as u64 != expected_size {
                            return Ok(RepositoryPathVerification {
                                path,
                                state: RepositoryPathVerificationState::ExistsIntegrityMismatch,
                            });
                        }
                    }

                    if let Some(body) = output.body {
                        let mut digester = MultiDigester::default();
                        let mut remaining = expected_size;
                        let mut reader = body.into_async_read();
                        let mut buf = [0u8; 16384];

                        // Digest chunks until roughly `expected_size` bytes have been
                        // consumed or the stream ends.
                        loop {
                            let size = reader
                                .read(&mut buf[..])
                                .await
                                .map_err(|e| DebianError::RepositoryIoPath(path.to_string(), e))?;

                            digester.update(&buf[0..size]);

                            let size = size as u64;
                            if size >= remaining || size == 0 {
                                break;
                            }
                            remaining -= size;
                        }

                        let digests = digester.finish();

                        Ok(RepositoryPathVerification {
                            path,
                            state: if !digests.matches_digest(&expected_digest) {
                                RepositoryPathVerificationState::ExistsIntegrityMismatch
                            } else {
                                RepositoryPathVerificationState::ExistsIntegrityVerified
                            },
                        })
                    } else {
                        Ok(RepositoryPathVerification {
                            path,
                            state: RepositoryPathVerificationState::Missing,
                        })
                    }
                }
                Err(RusotoError::Service(GetObjectError::NoSuchKey(_))) => {
                    Ok(RepositoryPathVerification {
                        path,
                        state: RepositoryPathVerificationState::Missing,
                    })
                }
                Err(e) => Err(DebianError::RepositoryIoPath(
                    path.to_string(),
                    std::io::Error::new(std::io::ErrorKind::Other, format!("S3 error: {:?}", e)),
                )),
            }
        } else {
            let req = HeadObjectRequest {
                bucket: self.bucket.clone(),
                key: self.path_to_key(path),
                ..Default::default()
            };

            match self.client.head_object(req).await {
                Ok(_) => Ok(RepositoryPathVerification {
                    path,
                    state: RepositoryPathVerificationState::ExistsNoIntegrityCheck,
                }),
                Err(RusotoError::Service(HeadObjectError::NoSuchKey(_))) => {
                    Ok(RepositoryPathVerification {
                        path,
                        state: RepositoryPathVerificationState::Missing,
                    })
                }
                // HEAD responses carry no body, so rusoto cannot parse the S3 error code
                // and reports a missing object as `Unknown` with a 404 status instead of
                // `NoSuchKey`.
                Err(RusotoError::Unknown(response)) if response.status.as_u16() == 404 => {
                    Ok(RepositoryPathVerification {
                        path,
                        state: RepositoryPathVerificationState::Missing,
                    })
                }
                Err(e) => Err(DebianError::RepositoryIoPath(
                    path.to_string(),
                    std::io::Error::new(std::io::ErrorKind::Other, format!("S3 error: {:?}", e)),
                )),
            }
        }
    }

    /// Upload all bytes from `reader` to the object key for `path`.
    ///
    /// The reader is fully buffered in memory before uploading.
    ///
    /// # Errors
    ///
    /// Returns [`DebianError::RepositoryIoPath`] if reading from `reader` fails or if the
    /// `PutObject` request fails.
    async fn write_path<'path, 'reader>(
        &self,
        path: Cow<'path, str>,
        mut reader: Pin<Box<dyn AsyncRead + Send + 'reader>>,
    ) -> Result<RepositoryWrite<'path>> {
        // rusoto wants a byte stream rather than an `AsyncRead`, so buffer locally.
        let mut buf = vec![];
        reader
            .read_to_end(&mut buf)
            .await
            .map_err(|e| DebianError::RepositoryIoPath(path.to_string(), e))?;

        let bytes_written = buf.len() as u64;

        let req = PutObjectRequest {
            bucket: self.bucket.clone(),
            key: self.path_to_key(path.as_ref()),
            // `From<Vec<u8>>` records the body length as the stream's size hint, so a
            // `Content-Length` can be sent. S3 `PutObject` does not accept chunked uploads
            // of unknown length, which is what `ByteStream::new` would produce.
            body: Some(ByteStream::from(buf)),
            ..Default::default()
        };

        match self.client.put_object(req).await {
            Ok(_) => Ok(RepositoryWrite {
                path,
                bytes_written,
            }),
            Err(e) => Err(DebianError::RepositoryIoPath(
                path.to_string(),
                std::io::Error::new(std::io::ErrorKind::Other, format!("S3 error: {:?}", e)),
            )),
        }
    }
}
/// Resolve the AWS region hosting `bucket`.
///
/// The `GetBucketLocation` request is issued through a default client configured for
/// `us-east-1`; see [`get_bucket_region_with_client`] to supply a custom client.
pub async fn get_bucket_region(bucket: impl ToString) -> Result<Region> {
    let lookup_client = S3Client::new(Region::UsEast1);

    get_bucket_region_with_client(lookup_client, bucket).await
}
pub async fn get_bucket_region_with_client(
client: S3Client,
bucket: impl ToString,
) -> Result<Region> {
let req = GetBucketLocationRequest {
bucket: bucket.to_string(),
..Default::default()
};
match client.get_bucket_location(req).await {
Ok(res) => {
if let Some(constraint) = res.location_constraint {
Ok(Region::from_str(&constraint)
.map_err(|_| DebianError::S3BadRegion(constraint))?)
} else {
Ok(Region::UsEast1)
}
}
Err(e) => Err(DebianError::Io(std::io::Error::new(
std::io::ErrorKind::Other,
format!("S3 error: {:?}", e),
))),
}
}