use std::path::Path;
use std::sync::Arc;
use aws_config::BehaviorVersion;
use aws_sdk_s3::config::Region;
use aws_sdk_s3::Client;
use aws_smithy_http_client::tls::rustls_provider::CryptoMode;
use aws_smithy_http_client::tls::Provider as TlsProvider;
use aws_smithy_http_client::Builder as HttpClientBuilder;
use snapdir_core::manifest::Manifest;
use snapdir_core::merkle::{Blake3Hasher, Hasher};
use snapdir_core::store::{manifest_path, object_path, Store, StoreError};
use crate::fetch::fetch_files_concurrent;
use crate::push::{push_objects_concurrent, upload_object};
use crate::transfer::{RateLimiter, TransferConfig};
use tokio::runtime::Runtime;
/// Upper bound on GET attempts `S3Store::fetch_verified` makes for one object
/// whose downloaded bytes keep failing BLAKE3 checksum verification.
/// NOTE(review): despite the name, the loop treats this as the TOTAL attempt
/// count (initial GET included), i.e. at most `MAX_FETCH_RETRIES - 1` retries.
const MAX_FETCH_RETRIES: u32 = 5;
/// Bucket and key prefix an S3-backed store reads from and writes to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct S3Location {
/// Bucket name: the first `/`-separated segment after the URL scheme.
pub bucket: String,
/// Key prefix inside the bucket. When produced by `S3Location::parse` it has no
/// leading or trailing `/`; an empty string means the bucket root.
pub prefix: String,
}
impl S3Location {
    /// Splits a store URL such as `s3://bucket/some/prefix` into bucket and prefix.
    ///
    /// Everything up to and including the first `://` is discarded, so a bare
    /// `bucket/prefix` string is accepted as well. All leading and trailing `/`
    /// characters are stripped from the prefix; a URL naming only a bucket yields
    /// an empty prefix. No validation is done: malformed input produces an empty
    /// or unexpected bucket name rather than an error.
    #[must_use]
    pub fn parse(store_url: &str) -> Self {
        let rest = store_url
            .find("://")
            .map_or(store_url, |scheme_end| &store_url[scheme_end + 3..]);
        // Bucket is everything before the first '/', the prefix everything after.
        let (bucket, prefix) = rest.split_once('/').unwrap_or((rest, ""));
        Self {
            bucket: bucket.to_owned(),
            prefix: prefix.trim_matches('/').to_owned(),
        }
    }

    /// Full S3 key of the content-addressed object with the given `checksum`
    /// (the core `object_path` layout placed under this location's prefix).
    #[must_use]
    pub fn object_key(&self, checksum: &str) -> String {
        let relative = object_path(checksum);
        self.key_for(&relative)
    }

    /// Full S3 key of the manifest for snapshot `id`
    /// (the core `manifest_path` layout placed under this location's prefix).
    #[must_use]
    pub fn manifest_key(&self, id: &str) -> String {
        let relative = manifest_path(id);
        self.key_for(&relative)
    }

    /// Joins a store-relative path onto the prefix with exactly one `/`, and
    /// never yields a leading `/` when the prefix is empty (bucket root).
    fn key_for(&self, rel: &str) -> String {
        let relative = rel.trim_start_matches('/');
        match self.prefix.as_str() {
            "" => relative.to_owned(),
            prefix => format!("{prefix}/{relative}"),
        }
    }
}
/// A snapdir `Store` backed by an S3 bucket, driven synchronously through its
/// own tokio runtime.
pub struct S3Store {
/// AWS SDK S3 client used for every request.
client: Client,
/// Bucket and key prefix all objects and manifests live under.
location: S3Location,
/// Runtime used to `block_on` SDK futures from the synchronous `Store` API.
runtime: Arc<Runtime>,
/// Concurrency and bandwidth settings for fetch/push.
config: TransferConfig,
}
impl S3Store {
    /// Connects to the store at `store_url` using the default [`TransferConfig`].
    ///
    /// Shorthand for `S3Store::connect_with(store_url, endpoint_url, TransferConfig::default())`.
    ///
    /// # Errors
    ///
    /// Returns [`StoreError::Backend`] if the internal tokio runtime cannot be created.
    pub fn connect(store_url: &str, endpoint_url: Option<&str>) -> Result<Self, StoreError> {
        Self::connect_with(store_url, endpoint_url, TransferConfig::default())
    }

    /// Connects to the store at `store_url` with an explicit transfer configuration.
    ///
    /// `store_url` is split by [`S3Location::parse`]. AWS settings come from the
    /// `aws_config` default loader, with HTTPS handled by a rustls client using the
    /// `ring` crypto provider. When `endpoint_url` is given it overrides the
    /// endpoint and path-style addressing is forced. If the loaded configuration
    /// has no region, `us-east-1` is used.
    ///
    /// A dedicated multi-threaded tokio runtime is created and owned by the store;
    /// the SDK configuration is loaded with `Runtime::block_on`.
    /// NOTE(review): `block_on` presumably makes this unsafe to call from inside
    /// an async context — confirm callers are synchronous.
    ///
    /// # Errors
    ///
    /// Returns [`StoreError::Backend`] if the tokio runtime cannot be created.
    pub fn connect_with(
        store_url: &str,
        endpoint_url: Option<&str>,
        config: TransferConfig,
    ) -> Result<Self, StoreError> {
        let location = S3Location::parse(store_url);
        let runtime = build_runtime()?;
        let http_client = ring_https_client();
        let endpoint = endpoint_url.map(ToOwned::to_owned);
        let client = runtime.block_on(async move {
            // `http_client` is moved into this block and used once; no clone needed.
            let mut loader =
                aws_config::defaults(BehaviorVersion::latest()).http_client(http_client);
            if let Some(ep) = endpoint.as_deref() {
                loader = loader.endpoint_url(ep);
            }
            let shared = loader.load().await;
            let mut builder = aws_sdk_s3::config::Builder::from(&shared);
            if endpoint.is_some() {
                // Custom endpoints get path-style addressing (`host/bucket/key`).
                builder = builder.force_path_style(true);
            }
            if shared.region().is_none() {
                builder = builder.region(Region::new("us-east-1"));
            }
            Client::from_conf(builder.build())
        });
        Ok(Self {
            client,
            location,
            runtime: Arc::new(runtime),
            config,
        })
    }

    /// Wraps an already-configured SDK `client` that targets `location`.
    ///
    /// A fresh tokio runtime is created for the store, and the default
    /// [`TransferConfig`] is used.
    ///
    /// # Errors
    ///
    /// Returns [`StoreError::Backend`] if the tokio runtime cannot be created.
    pub fn from_client(client: Client, location: S3Location) -> Result<Self, StoreError> {
        Ok(Self {
            client,
            location,
            runtime: Arc::new(build_runtime()?),
            config: TransferConfig::default(),
        })
    }

    /// Bucket and prefix this store reads from and writes to.
    #[must_use]
    pub fn location(&self) -> &S3Location {
        &self.location
    }

    /// Transfer settings (concurrency, bandwidth limit) used by fetch and push.
    #[must_use]
    pub fn transfer_config(&self) -> &TransferConfig {
        &self.config
    }

    /// Checks with `HeadObject` whether `key` exists in the store's bucket.
    ///
    /// A "not found" service error maps to `Ok(false)`; any other failure becomes
    /// a [`StoreError::Backend`].
    async fn key_exists(&self, key: &str) -> Result<bool, StoreError> {
        match self
            .client
            .head_object()
            .bucket(&self.location.bucket)
            .key(key)
            .send()
            .await
        {
            Ok(_) => Ok(true),
            Err(err) => {
                let svc = err.into_service_error();
                if svc.is_not_found() {
                    Ok(false)
                } else {
                    Err(backend("S3 HEAD object failed", svc))
                }
            }
        }
    }

    /// Downloads the whole object at `key` into memory.
    ///
    /// Returns `Ok(None)` when the service reports `NoSuchKey`; every other
    /// request or body-read failure becomes a [`StoreError::Backend`].
    async fn get_bytes(&self, key: &str) -> Result<Option<Vec<u8>>, StoreError> {
        match self
            .client
            .get_object()
            .bucket(&self.location.bucket)
            .key(key)
            .send()
            .await
        {
            Ok(resp) => {
                let data = resp
                    .body
                    .collect()
                    .await
                    .map_err(|e| backend("reading S3 object body", e))?;
                Ok(Some(data.into_bytes().to_vec()))
            }
            Err(err) => {
                let svc = err.into_service_error();
                if svc.is_no_such_key() {
                    Ok(None)
                } else {
                    Err(backend("S3 GET object failed", svc))
                }
            }
        }
    }

    /// Uploads `bytes` as the body of the object at `key`.
    ///
    /// Any SDK failure becomes a [`StoreError::Backend`].
    async fn put_bytes(&self, key: &str, bytes: Vec<u8>) -> Result<(), StoreError> {
        self.client
            .put_object()
            .bucket(&self.location.bucket)
            .key(key)
            .body(bytes.into())
            .send()
            .await
            .map_err(|e| backend("S3 PUT object failed", e))?;
        Ok(())
    }

    /// Downloads `key` and checks its BLAKE3 hex digest equals `expected`.
    ///
    /// On a digest mismatch the GET is repeated; after `MAX_FETCH_RETRIES` total
    /// attempts have all mismatched, [`StoreError::Integrity`] is returned with an
    /// `s3://bucket/key` address. A missing object yields
    /// [`StoreError::ObjectNotFound`]. Other GET errors are propagated
    /// immediately, without a retry at this level.
    async fn fetch_verified(&self, key: &str, expected: &str) -> Result<Vec<u8>, StoreError> {
        let hasher = Blake3Hasher::new();
        let mut attempts_left = MAX_FETCH_RETRIES;
        loop {
            match self.get_bytes(key).await? {
                Some(bytes) => {
                    let actual = hasher.hash_hex(&bytes);
                    if actual == expected {
                        return Ok(bytes);
                    }
                    attempts_left = attempts_left.saturating_sub(1);
                    if attempts_left == 0 {
                        return Err(StoreError::Integrity {
                            address: format!("s3://{}/{key}", self.location.bucket),
                            expected: expected.to_owned(),
                            actual,
                        });
                    }
                }
                None => {
                    return Err(StoreError::ObjectNotFound {
                        checksum: expected.to_owned(),
                    });
                }
            }
        }
    }
}
impl Store for S3Store {
    /// Downloads the manifest stored for snapshot `id` and verifies it.
    ///
    /// The manifest is parsed and its snapshot id is recomputed with BLAKE3. A
    /// mismatch is reported as [`StoreError::Integrity`] using the same
    /// `s3://bucket/key` address form as object integrity failures.
    ///
    /// # Errors
    ///
    /// [`StoreError::ManifestNotFound`] if no manifest exists for `id`;
    /// [`StoreError::Backend`] for S3 failures or non-UTF-8 content; parse errors
    /// from `Manifest::parse`; [`StoreError::Integrity`] on an id mismatch.
    fn get_manifest(&self, id: &str) -> Result<Manifest, StoreError> {
        let key = self.location.manifest_key(id);
        let bytes = self.runtime.block_on(async {
            match self.get_bytes(&key).await? {
                Some(b) => Ok(b),
                None => Err(StoreError::ManifestNotFound { id: id.to_owned() }),
            }
        })?;
        let text = String::from_utf8(bytes).map_err(|err| StoreError::Backend {
            message: format!("manifest {id} is not valid UTF-8"),
            source: Some(Box::new(err)),
        })?;
        let manifest = Manifest::parse(&text)?;
        let actual = snapdir_core::merkle::snapshot_id(&manifest, &Blake3Hasher::new());
        if actual != id {
            return Err(StoreError::Integrity {
                // Reuse `key` and report a full `s3://` address, matching the
                // format used for object integrity failures.
                address: format!("s3://{}/{key}", self.location.bucket),
                expected: id.to_owned(),
                actual,
            });
        }
        Ok(manifest)
    }

    /// Downloads every object referenced by `manifest` into `dest`.
    ///
    /// Objects are fetched concurrently under the store's [`TransferConfig`] and
    /// bandwidth limit, and each one is checksum-verified before it is used.
    fn fetch_files(&self, manifest: &Manifest, dest: &Path) -> Result<(), StoreError> {
        let limiter = RateLimiter::new(self.config.max_bytes_per_sec);
        self.runtime.block_on(async {
            fetch_files_concurrent(manifest, dest, &self.config, &limiter, |entry| async {
                let key = self.location.object_key(&entry.checksum);
                self.fetch_verified(&key, &entry.checksum).await
            })
            .await
        })
    }

    /// Uploads the objects referenced by `manifest` from `source`, then the manifest.
    ///
    /// If a manifest already exists under this snapshot's id, nothing is
    /// uploaded. The manifest text (with a trailing newline) is hashed before
    /// upload and must equal the computed snapshot id; otherwise
    /// [`StoreError::Integrity`] is returned and the manifest is not written.
    fn push(&self, manifest: &Manifest, source: &Path) -> Result<(), StoreError> {
        let hasher = Blake3Hasher::new();
        let id = snapdir_core::merkle::snapshot_id(manifest, &hasher);
        let limiter = RateLimiter::new(self.config.max_bytes_per_sec);
        self.runtime.block_on(async {
            let manifest_key = self.location.manifest_key(&id);
            // An existing manifest means this snapshot was already pushed.
            if self.key_exists(&manifest_key).await? {
                return Ok(());
            }
            push_objects_concurrent(
                manifest,
                &self.config,
                |entry| {
                    let object_key = self.location.object_key(&entry.checksum);
                    upload_object(
                        entry,
                        object_key,
                        source,
                        &limiter,
                        |key| async move { self.key_exists(&key).await },
                        |key, bytes| async move { self.put_bytes(&key, bytes).await },
                    )
                },
                || async {
                    let mut text = manifest.to_string();
                    text.push('\n');
                    let manifest_actual = hasher.hash_hex(text.as_bytes());
                    if manifest_actual != id {
                        return Err(StoreError::Integrity {
                            address: format!("s3://{}/{manifest_key}", self.location.bucket),
                            expected: id.clone(),
                            actual: manifest_actual,
                        });
                    }
                    self.put_bytes(&manifest_key, text.into_bytes()).await
                },
            )
            .await
        })
    }
}
/// Creates the multi-threaded tokio runtime (all drivers enabled) that an
/// `S3Store` uses to drive the async AWS SDK from synchronous code.
///
/// A build failure is wrapped in a `StoreError::Backend`.
fn build_runtime() -> Result<Runtime, StoreError> {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.enable_all();
    builder
        .build()
        .map_err(|io_err| backend("creating tokio runtime for S3Store", io_err))
}
/// Builds the HTTPS client passed to the AWS SDK: rustls TLS using the `ring`
/// crypto provider.
fn ring_https_client() -> aws_smithy_runtime_api::client::http::SharedHttpClient {
    let tls = TlsProvider::Rustls(CryptoMode::Ring);
    HttpClientBuilder::new().tls_provider(tls).build_https()
}
fn backend<E>(message: &str, source: E) -> StoreError
where
E: std::error::Error + Send + Sync + 'static,
{
StoreError::Backend {
message: message.to_owned(),
source: Some(Box::new(source)),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use snapdir_core::manifest::PathType;

    // NOTE(review): this helper exists only in the test module, so its test does
    // not exercise production code — confirm whether it should mirror a real helper.
    fn strip_leading_dot_slash(path: &str) -> &str {
        let trimmed = path.strip_prefix("./").unwrap_or(path);
        trimmed.strip_suffix('/').unwrap_or(trimmed)
    }

    /// Deletes the wrapped directory tree when dropped, so the live round-trip
    /// test cleans up its temp dirs even when an `expect`/`assert` panics.
    struct RemoveDirOnDrop(std::path::PathBuf);

    impl Drop for RemoveDirOnDrop {
        fn drop(&mut self) {
            // Best-effort: the directory may never have been created.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    const FOO_CHECKSUM: &str = "49dc870df1de7fd60794cebce449f5ccdae575affaa67a24b62acb03e039db92";
    const FOO_SHARDED: &str = "49d/c87/0df/1de7fd60794cebce449f5ccdae575affaa67a24b62acb03e039db92";
    const MANIFEST_ID: &str = "aa91e498f401ea9e6ddbaa1138a0dbeb030fab8defc1252d80c77ebefafbc70d";
    const MANIFEST_SHARDED: &str =
        "aa9/1e4/98f/401ea9e6ddbaa1138a0dbeb030fab8defc1252d80c77ebefafbc70d";

    #[test]
    fn s3_store_parses_bucket_and_prefix() {
        let loc = S3Location::parse("s3://my-bucket/long/term/storage");
        assert_eq!(loc.bucket, "my-bucket");
        assert_eq!(loc.prefix, "long/term/storage");
    }

    #[test]
    fn s3_store_parse_matches_oracle_cut_fields() {
        let loc = S3Location::parse("s3://bucket/a/b/c");
        assert_eq!(loc.bucket, "bucket");
        assert_eq!(loc.prefix, "a/b/c");
    }

    #[test]
    fn s3_store_parse_strips_trailing_slash() {
        let loc = S3Location::parse("s3://bucket/prefix/");
        assert_eq!(loc.bucket, "bucket");
        assert_eq!(loc.prefix, "prefix");
    }

    #[test]
    fn s3_store_parse_bucket_root_has_empty_prefix() {
        let loc = S3Location::parse("s3://bucket");
        assert_eq!(loc.bucket, "bucket");
        assert_eq!(loc.prefix, "");
        let loc_slash = S3Location::parse("s3://bucket/");
        assert_eq!(loc_slash.bucket, "bucket");
        assert_eq!(loc_slash.prefix, "");
    }

    #[test]
    fn s3_store_parse_accepts_bare_bucket_prefix_without_scheme() {
        let loc = S3Location::parse("bucket/some/prefix");
        assert_eq!(loc.bucket, "bucket");
        assert_eq!(loc.prefix, "some/prefix");
    }

    #[test]
    fn s3_store_object_key_matches_sharded_scheme() {
        let loc = S3Location::parse("s3://b/long/term/storage");
        assert_eq!(
            loc.object_key(FOO_CHECKSUM),
            format!("long/term/storage/.objects/{FOO_SHARDED}")
        );
    }

    #[test]
    fn s3_store_manifest_key_matches_sharded_scheme() {
        let loc = S3Location::parse("s3://b/long/term/storage");
        assert_eq!(
            loc.manifest_key(MANIFEST_ID),
            format!("long/term/storage/.manifests/{MANIFEST_SHARDED}")
        );
    }

    #[test]
    fn s3_store_keys_have_no_leading_slash_at_bucket_root() {
        let loc = S3Location::parse("s3://bucket");
        assert_eq!(
            loc.object_key(FOO_CHECKSUM),
            format!(".objects/{FOO_SHARDED}")
        );
        assert_eq!(
            loc.manifest_key(MANIFEST_ID),
            format!(".manifests/{MANIFEST_SHARDED}")
        );
    }

    #[test]
    fn s3_store_object_key_uses_core_object_path() {
        let loc = S3Location::parse("s3://b");
        assert_eq!(loc.object_key(FOO_CHECKSUM), object_path(FOO_CHECKSUM));
    }

    #[test]
    fn s3_store_strip_leading_dot_slash() {
        assert_eq!(strip_leading_dot_slash("./foo"), "foo");
        assert_eq!(strip_leading_dot_slash("./a/b/c"), "a/b/c");
        assert_eq!(strip_leading_dot_slash("./a/"), "a");
        assert_eq!(strip_leading_dot_slash("./"), "");
    }

    /// Pushes a one-file snapshot to a real S3-compatible endpoint, reads the
    /// manifest back, and fetches the file. Skipped unless both
    /// `SNAPDIR_S3_TEST_ENDPOINT` and `SNAPDIR_S3_TEST_STORE` are set.
    #[test]
    fn s3_store_live_round_trip_when_configured() {
        use snapdir_core::manifest::ManifestEntry;
        let (Ok(endpoint), Ok(store)) = (
            std::env::var("SNAPDIR_S3_TEST_ENDPOINT"),
            std::env::var("SNAPDIR_S3_TEST_STORE"),
        ) else {
            eprintln!(
                "skipping s3_store live round-trip: set SNAPDIR_S3_TEST_ENDPOINT \
                 and SNAPDIR_S3_TEST_STORE (s3://bucket/prefix) to run it"
            );
            return;
        };
        let hasher = Blake3Hasher::new();
        let src = std::env::temp_dir().join(format!("snapdir-s3-live-{}", std::process::id()));
        // Guard removes `src` on every exit path, including panics below.
        let _src_cleanup = RemoveDirOnDrop(src.clone());
        std::fs::create_dir_all(&src).unwrap();
        std::fs::write(src.join("foo"), b"foo\n").unwrap();
        let foo_sum = hasher.hash_hex(b"foo\n");
        let root_sum = snapdir_core::merkle::directory_checksum([foo_sum.as_str()], &hasher);
        let mut manifest = Manifest::new();
        manifest.push(ManifestEntry::new(
            PathType::Directory,
            "700",
            root_sum,
            4,
            "./",
        ));
        manifest.push(ManifestEntry::new(
            PathType::File,
            "600",
            foo_sum,
            4,
            "./foo",
        ));
        let manifest = Manifest::from_entries(manifest.entries().to_vec());
        let id = snapdir_core::merkle::snapshot_id(&manifest, &hasher);
        let s3 = S3Store::connect(&store, Some(&endpoint)).expect("connect");
        s3.push(&manifest, &src).expect("push");
        let read_back = s3.get_manifest(&id).expect("get_manifest");
        assert_eq!(read_back, manifest);
        let dest = std::env::temp_dir().join(format!("snapdir-s3-dest-{}", std::process::id()));
        // Guard removes `dest` on every exit path, including panics below.
        let _dest_cleanup = RemoveDirOnDrop(dest.clone());
        std::fs::create_dir_all(&dest).unwrap();
        s3.fetch_files(&read_back, &dest).expect("fetch_files");
        assert_eq!(std::fs::read(dest.join("foo")).unwrap(), b"foo\n");
    }
}