use std::path::Path;
use std::sync::Arc;
use aws_config::BehaviorVersion;
use aws_sdk_s3::config::Region;
use aws_sdk_s3::error::{ProvideErrorMetadata, SdkError};
use aws_sdk_s3::Client;
use aws_smithy_http_client::tls::rustls_provider::CryptoMode;
use aws_smithy_http_client::tls::Provider as TlsProvider;
use aws_smithy_http_client::Builder as HttpClientBuilder;
use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
use snapdir_core::manifest::Manifest;
use snapdir_core::merkle::{Blake3Hasher, Hasher};
use snapdir_core::store::{manifest_path, object_path, Store, StoreError};
use snapdir_core::Meter;
use crate::fetch::fetch_files_concurrent;
use crate::push::{push_objects_concurrent, upload_object};
use crate::retry::{parse_retry_after, retry_network, Attempt, DefaultJitter, TokioSleeper};
use crate::stream::StreamStore;
use crate::transfer::{classify_error, RateLimiter, TransferConfig};
use std::error::Error as StdError;
use tokio::runtime::Runtime;
/// Upper bound on how many times `S3Store::fetch_verified` downloads one object
/// whose bytes do not hash to the expected checksum; after the last mismatch it
/// returns `StoreError::Integrity`. Despite the name this counts total GETs
/// (the first download included), not retries after the first one.
const MAX_FETCH_RETRIES: u32 = 5;
/// Bucket and key prefix parsed from a store URL such as `s3://bucket/some/prefix`
/// (see `S3Location::parse`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct S3Location {
/// Bucket name: the first `/`-separated segment after any `scheme://`.
pub bucket: String,
/// Key prefix with leading and trailing slashes removed; empty when the URL
/// names only the bucket. Object and manifest keys are placed under it.
pub prefix: String,
}
impl S3Location {
    /// Splits a `[scheme://]bucket[/prefix]` string into bucket and prefix.
    ///
    /// Everything up to and including the first `://` is ignored, so both
    /// `s3://bucket/prefix` and a bare `bucket/prefix` are accepted. The prefix
    /// has all leading and trailing `/` removed; a URL naming only the bucket
    /// yields an empty prefix. No validation is performed (an empty bucket is
    /// returned as-is).
    #[must_use]
    pub fn parse(store_url: &str) -> Self {
        let rest = store_url
            .split_once("://")
            .map_or(store_url, |(_scheme, tail)| tail);
        let (bucket, prefix) = rest.split_once('/').unwrap_or((rest, ""));
        Self {
            bucket: bucket.to_owned(),
            prefix: prefix.trim_matches('/').to_owned(),
        }
    }

    /// Full S3 key of the content object with the given checksum.
    #[must_use]
    pub fn object_key(&self, checksum: &str) -> String {
        self.key_for(&object_path(checksum))
    }

    /// Full S3 key of the manifest with the given snapshot id.
    #[must_use]
    pub fn manifest_key(&self, id: &str) -> String {
        self.key_for(&manifest_path(id))
    }

    /// Joins a store-relative path onto the prefix, never producing a leading `/`.
    fn key_for(&self, rel: &str) -> String {
        let rel = rel.trim_start_matches('/');
        match self.prefix.as_str() {
            "" => rel.to_owned(),
            prefix => format!("{prefix}/{rel}"),
        }
    }
}
/// Store backend that keeps snapshot objects and manifests in an S3 bucket.
///
/// The `Store` / `StreamStore` methods are synchronous; each one drives the
/// async AWS SDK on this store's own tokio runtime via `block_on`.
pub struct S3Store {
/// S3 client used for every request.
client: Client,
/// Bucket and key prefix that objects and manifests live under.
location: S3Location,
/// Runtime used to block on SDK futures from the synchronous API.
runtime: Arc<Runtime>,
/// Transfer settings; fields used here include `retry`,
/// `max_requests_per_sec` and `max_bytes_per_sec`.
config: TransferConfig,
/// Request-rate limiter handed to `retry_network` for every request,
/// built from `config.max_requests_per_sec`.
req_limiter: RateLimiter,
/// Optional progress meter forwarded to the concurrent fetch/push helpers.
meter: Option<Arc<Meter>>,
}
impl S3Store {
    /// Connects to the S3 store at `store_url` with the default [`TransferConfig`].
    ///
    /// Equivalent to [`S3Store::connect_with`] with `TransferConfig::default()`.
    ///
    /// # Errors
    ///
    /// Returns a [`StoreError`] if the internal tokio runtime cannot be built.
    pub fn connect(store_url: &str, endpoint_url: Option<&str>) -> Result<Self, StoreError> {
        Self::connect_with(store_url, endpoint_url, TransferConfig::default())
    }

    /// Connects to the S3 store at `store_url` with an explicit transfer configuration.
    ///
    /// `store_url` is parsed with [`S3Location::parse`]. Credentials and region
    /// come from the `aws_config` default loader; if no region is configured,
    /// `us-east-1` is used. When `endpoint_url` is given, requests go to that
    /// endpoint using path-style addressing. SDK-level retries are disabled
    /// because every request made by this store is already wrapped in its own
    /// retry loop (`retry_network`).
    ///
    /// # Errors
    ///
    /// Returns a [`StoreError`] if the internal tokio runtime cannot be built.
    pub fn connect_with(
        store_url: &str,
        endpoint_url: Option<&str>,
        config: TransferConfig,
    ) -> Result<Self, StoreError> {
        let location = S3Location::parse(store_url);
        let runtime = build_runtime()?;
        let http_client = ring_https_client();
        let endpoint = endpoint_url.map(ToOwned::to_owned);
        let client = runtime.block_on(async move {
            let mut loader = aws_config::defaults(BehaviorVersion::latest())
                .http_client(http_client)
                .retry_config(aws_config::retry::RetryConfig::disabled());
            if let Some(ep) = endpoint.as_deref() {
                loader = loader.endpoint_url(ep);
            }
            let shared = loader.load().await;
            let mut builder = aws_sdk_s3::config::Builder::from(&shared);
            if endpoint.is_some() {
                // Custom endpoints are addressed as `endpoint/bucket/key`.
                builder = builder.force_path_style(true);
            }
            if shared.region().is_none() {
                builder = builder.region(Region::new("us-east-1"));
            }
            Client::from_conf(builder.build())
        });
        let req_limiter = RateLimiter::new(config.max_requests_per_sec);
        Ok(Self {
            client,
            location,
            runtime: Arc::new(runtime),
            config,
            req_limiter,
            meter: None,
        })
    }

    /// Wraps an already-configured `client` with the default [`TransferConfig`].
    ///
    /// The client is used as-is: unlike [`S3Store::connect_with`], its retry
    /// configuration is not changed, so any SDK-level retries it has enabled
    /// run in addition to this store's own retry loop.
    ///
    /// # Errors
    ///
    /// Returns a [`StoreError`] if the internal tokio runtime cannot be built.
    pub fn from_client(client: Client, location: S3Location) -> Result<Self, StoreError> {
        let config = TransferConfig::default();
        let req_limiter = RateLimiter::new(config.max_requests_per_sec);
        Ok(Self {
            client,
            location,
            runtime: Arc::new(build_runtime()?),
            config,
            req_limiter,
            meter: None,
        })
    }

    /// Attaches (or clears, with `None`) the progress meter used by fetch/push.
    #[must_use]
    pub fn with_meter(mut self, meter: Option<Arc<Meter>>) -> Self {
        self.meter = meter;
        self
    }

    /// Bucket and prefix this store reads from and writes to.
    #[must_use]
    pub fn location(&self) -> &S3Location {
        &self.location
    }

    /// Transfer settings in effect for this store.
    #[must_use]
    pub fn transfer_config(&self) -> &TransferConfig {
        &self.config
    }

    /// Returns whether `key` exists, using a HEAD request.
    ///
    /// A not-found response maps to `Ok(false)`; every other failure is
    /// classified by `s3_attempt_from_err` and retried per `config.retry`.
    async fn key_exists(&self, key: &str) -> Result<bool, StoreError> {
        retry_network(
            &self.config.retry,
            &self.req_limiter,
            &TokioSleeper,
            &DefaultJitter::new(),
            || async {
                match self
                    .client
                    .head_object()
                    .bucket(&self.location.bucket)
                    .key(key)
                    .send()
                    .await
                {
                    Ok(_) => Ok(true),
                    Err(err) => {
                        if err.as_service_error().is_some_and(
                            aws_sdk_s3::operation::head_object::HeadObjectError::is_not_found,
                        ) {
                            return Ok(false);
                        }
                        Err(s3_attempt_from_err("S3 HEAD object failed", err))
                    }
                }
            },
        )
        .await
    }

    /// Downloads `key` fully into memory; `Ok(None)` when S3 reports `NoSuchKey`.
    async fn get_bytes(&self, key: &str) -> Result<Option<Vec<u8>>, StoreError> {
        retry_network(
            &self.config.retry,
            &self.req_limiter,
            &TokioSleeper,
            &DefaultJitter::new(),
            || async {
                match self
                    .client
                    .get_object()
                    .bucket(&self.location.bucket)
                    .key(key)
                    .send()
                    .await
                {
                    Ok(resp) => {
                        // The response headers have already arrived, so a failure
                        // while streaming the body is treated as a transport problem
                        // (dropped or stalled connection), not a verdict on the
                        // object. Retry the whole GET; the retry policy bounds it.
                        let data = resp.body.collect().await.map_err(|e| Attempt {
                            transient: true,
                            retry_after: None,
                            err: backend("reading S3 object body", e),
                        })?;
                        Ok(Some(data.into_bytes().to_vec()))
                    }
                    Err(err) => {
                        if err.as_service_error().is_some_and(
                            aws_sdk_s3::operation::get_object::GetObjectError::is_no_such_key,
                        ) {
                            return Ok(None);
                        }
                        Err(s3_attempt_from_err("S3 GET object failed", err))
                    }
                }
            },
        )
        .await
    }

    /// Uploads `bytes` to `key`, retrying per `config.retry`.
    async fn put_bytes(&self, key: &str, bytes: Vec<u8>) -> Result<(), StoreError> {
        retry_network(
            &self.config.retry,
            &self.req_limiter,
            &TokioSleeper,
            &DefaultJitter::new(),
            || {
                // The request body consumes its buffer, so each attempt needs a copy.
                let bytes = bytes.clone();
                async move {
                    self.client
                        .put_object()
                        .bucket(&self.location.bucket)
                        .key(key)
                        .body(bytes.into())
                        .send()
                        .await
                        .map(|_| ())
                        .map_err(|err| s3_attempt_from_err("S3 PUT object failed", err))
                }
            },
        )
        .await
    }

    /// Downloads `key` and returns its bytes only if they BLAKE3-hash to `expected`.
    ///
    /// On a hash mismatch the object is downloaded again, for at most
    /// `MAX_FETCH_RETRIES` downloads in total. After that the last mismatch is
    /// reported as [`StoreError::Integrity`]. A missing key yields
    /// [`StoreError::ObjectNotFound`]; network errors propagate from `get_bytes`.
    async fn fetch_verified(&self, key: &str, expected: &str) -> Result<Vec<u8>, StoreError> {
        let hasher = Blake3Hasher::new();
        let mut attempts_left = MAX_FETCH_RETRIES;
        loop {
            match self.get_bytes(key).await? {
                Some(bytes) => {
                    let actual = hasher.hash_hex(&bytes);
                    if actual == expected {
                        return Ok(bytes);
                    }
                    attempts_left = attempts_left.saturating_sub(1);
                    if attempts_left == 0 {
                        return Err(StoreError::Integrity {
                            address: format!("s3://{}/{key}", self.location.bucket),
                            expected: expected.to_owned(),
                            actual,
                        });
                    }
                }
                None => {
                    return Err(StoreError::ObjectNotFound {
                        checksum: expected.to_owned(),
                    });
                }
            }
        }
    }
}
impl Store for S3Store {
    /// Downloads, parses and verifies the manifest stored under snapshot `id`.
    ///
    /// Fails with [`StoreError::ManifestNotFound`] if it is absent, a backend
    /// error if it is not UTF-8, and [`StoreError::Integrity`] if the parsed
    /// manifest's snapshot id differs from `id`.
    fn get_manifest(&self, id: &str) -> Result<Manifest, StoreError> {
        let key = self.location.manifest_key(id);
        let bytes = self.runtime.block_on(async {
            match self.get_bytes(&key).await? {
                Some(b) => Ok(b),
                None => Err(StoreError::ManifestNotFound { id: id.to_owned() }),
            }
        })?;
        let text = String::from_utf8(bytes).map_err(|err| StoreError::Backend {
            message: format!("manifest {id} is not valid UTF-8"),
            source: Some(Box::new(err)),
        })?;
        let manifest = Manifest::parse(&text)?;
        let actual = snapdir_core::merkle::snapshot_id(&manifest, &Blake3Hasher::new());
        if actual != id {
            return Err(StoreError::Integrity {
                address: key,
                expected: id.to_owned(),
                actual,
            });
        }
        Ok(manifest)
    }

    /// Downloads the manifest's file objects into `dest` via
    /// `fetch_files_concurrent`, verifying each one against its checksum.
    /// Download bandwidth is limited by `config.max_bytes_per_sec`.
    fn fetch_files(&self, manifest: &Manifest, dest: &Path) -> Result<(), StoreError> {
        let limiter = RateLimiter::new(self.config.max_bytes_per_sec);
        let meter = self.meter.as_deref();
        let meter_arc = self.meter.clone();
        self.runtime.block_on(async {
            fetch_files_concurrent(
                manifest,
                dest,
                &self.config,
                &limiter,
                meter,
                meter_arc,
                |entry| async {
                    let key = self.location.object_key(&entry.checksum);
                    self.fetch_verified(&key, &entry.checksum).await
                },
            )
            .await
        })
    }

    /// Uploads the objects referenced by `manifest` from `source`, then the
    /// manifest itself.
    ///
    /// If the manifest key already exists remotely, nothing is uploaded.
    /// Otherwise the manifest's serialized form (its `to_string()` plus a
    /// trailing newline) must hash to its snapshot id. This is checked before
    /// any object upload, so a mismatch fails with [`StoreError::Integrity`]
    /// without having uploaded anything.
    fn push(&self, manifest: &Manifest, source: &Path) -> Result<(), StoreError> {
        let hasher = Blake3Hasher::new();
        let id = snapdir_core::merkle::snapshot_id(manifest, &hasher);
        let manifest_key = self.location.manifest_key(&id);
        let mut text = manifest.to_string();
        text.push('\n');
        let manifest_actual = hasher.hash_hex(text.as_bytes());
        let manifest_bytes = text.into_bytes();
        let limiter = RateLimiter::new(self.config.max_bytes_per_sec);
        let meter = self.meter.as_deref();
        let meter_arc = self.meter.clone();
        self.runtime.block_on(async {
            if self.key_exists(&manifest_key).await? {
                return Ok(());
            }
            // Fail fast: previously this check ran only after every object had
            // been uploaded, wasting the whole transfer on a mismatch.
            if manifest_actual != id {
                return Err(StoreError::Integrity {
                    address: manifest_key.clone(),
                    expected: id.clone(),
                    actual: manifest_actual.clone(),
                });
            }
            push_objects_concurrent(
                manifest,
                &self.config,
                &limiter,
                meter,
                meter_arc,
                |entry| {
                    let object_key = self.location.object_key(&entry.checksum);
                    upload_object(
                        entry,
                        object_key,
                        source,
                        &limiter,
                        meter,
                        |key| async move { self.key_exists(&key).await },
                        |key, bytes| async move { self.put_bytes(&key, bytes).await },
                    )
                },
                // The manifest is written last, after all of its objects.
                || async { self.put_bytes(&manifest_key, manifest_bytes.clone()).await },
            )
            .await
        })
    }
}
impl StreamStore for S3Store {
    /// Returns whether the object with `checksum` exists (HEAD request).
    fn has_object(&self, checksum: &str) -> Result<bool, StoreError> {
        let key = self.location.object_key(checksum);
        self.runtime.block_on(self.key_exists(&key))
    }

    /// Downloads the object with `checksum` and verifies its BLAKE3 hash.
    ///
    /// Delegates to `fetch_verified`, the same path `fetch_files` uses, so a
    /// corrupted download is retried (up to `MAX_FETCH_RETRIES` downloads)
    /// before failing with [`StoreError::Integrity`]. A missing object yields
    /// [`StoreError::ObjectNotFound`].
    fn get_object(&self, checksum: &str) -> Result<Vec<u8>, StoreError> {
        let key = self.location.object_key(checksum);
        self.runtime.block_on(self.fetch_verified(&key, checksum))
    }

    /// Uploads `bytes` as the object `checksum` after checking that they hash to it.
    fn put_object(&self, checksum: &str, bytes: Vec<u8>) -> Result<(), StoreError> {
        let key = self.location.object_key(checksum);
        let actual = Blake3Hasher::new().hash_hex(&bytes);
        if actual != checksum {
            return Err(StoreError::Integrity {
                address: key,
                expected: checksum.to_owned(),
                actual,
            });
        }
        self.runtime.block_on(self.put_bytes(&key, bytes))
    }

    /// Uploads `manifest` under snapshot `id` after checking that its serialized
    /// form (with trailing newline) hashes to `id`.
    fn put_manifest(&self, id: &str, manifest: &Manifest) -> Result<(), StoreError> {
        let key = self.location.manifest_key(id);
        let mut text = manifest.to_string();
        text.push('\n');
        let actual = Blake3Hasher::new().hash_hex(text.as_bytes());
        if actual != id {
            return Err(StoreError::Integrity {
                address: key,
                expected: id.to_owned(),
                actual,
            });
        }
        self.runtime.block_on(self.put_bytes(&key, text.into_bytes()))
    }
}
/// Creates the multi-threaded tokio runtime (all drivers enabled) that
/// `S3Store` blocks on to run async SDK calls from its synchronous API.
fn build_runtime() -> Result<Runtime, StoreError> {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.enable_all();
    builder
        .build()
        .map_err(|err| backend("creating tokio runtime for S3Store", err))
}
/// HTTPS client for the SDK, using rustls backed by the `ring` crypto provider.
fn ring_https_client() -> aws_smithy_runtime_api::client::http::SharedHttpClient {
    let tls = TlsProvider::Rustls(CryptoMode::Ring);
    HttpClientBuilder::new().tls_provider(tls).build_https()
}
/// Wraps `source` in a [`StoreError::Backend`] with `message` as its context.
fn backend<E>(message: &str, source: E) -> StoreError
where
    E: std::error::Error + Send + Sync + 'static,
{
    StoreError::Backend {
        source: Some(Box::new(source)),
        message: String::from(message),
    }
}
fn s3_attempt_from_err<E>(message: &str, err: SdkError<E, HttpResponse>) -> Attempt
where
E: ProvideErrorMetadata + StdError + Send + Sync + 'static,
{
let (http_status, retry_after) = match err.raw_response() {
Some(resp) => {
let status = resp.status().as_u16();
let hint = resp
.headers()
.get("retry-after")
.and_then(parse_retry_after);
(Some(status), hint)
}
None => (None, None),
};
let code = err.code().unwrap_or_default().to_owned();
let store_err = backend(message, err);
let transient = http_status.is_some_and(|s| s == 429 || s == 503)
|| matches!(
classify_error(&store_err),
crate::adaptive::OpResult::Throttle
)
|| {
let c = code.to_ascii_lowercase();
c.contains("slowdown")
|| c.contains("throttl")
|| c.contains("requesttimeout")
|| c.contains("serviceunavailable")
|| c.contains("internalerror")
};
Attempt {
transient,
retry_after,
err: store_err,
}
}
// Unit tests for S3 URL parsing, key layout and error classification, plus an
// opt-in live round-trip against a real S3-compatible endpoint.
#[cfg(test)]
mod tests {
use super::*;
use aws_sdk_s3::error::ErrorMetadata;
use aws_sdk_s3::operation::head_object::HeadObjectError;
use aws_sdk_s3::primitives::SdkBody;
use aws_smithy_runtime_api::http::StatusCode;
use snapdir_core::manifest::PathType;
use std::time::Duration;
// Builds an empty-bodied HTTP response with `status` and, if given, a
// `retry-after` header.
fn raw_response(status: u16, retry_after: Option<&str>) -> HttpResponse {
let mut resp = HttpResponse::new(
StatusCode::try_from(status).expect("valid status"),
SdkBody::empty(),
);
if let Some(v) = retry_after {
resp.headers_mut().insert("retry-after", v.to_owned());
}
resp
}
// Feeds a synthetic HeadObject service error (error `code`, HTTP `status`,
// optional Retry-After) through `s3_attempt_from_err`.
fn s3_service_error(code: &str, status: u16, retry_after: Option<&str>) -> Attempt {
let meta = ErrorMetadata::builder().code(code).build();
let svc = HeadObjectError::generic(meta);
let err = SdkError::service_error(svc, raw_response(status, retry_after));
s3_attempt_from_err("S3 op failed", err)
}
#[test]
fn backoff_wire_s3_extract_503_retry_after_is_transient_with_hint() {
let attempt = s3_service_error("SlowDown", 503, Some("12"));
assert!(attempt.transient, "503/SlowDown must be transient");
assert_eq!(
attempt.retry_after,
Some(Duration::from_secs(12)),
"the Retry-After delta-seconds header must be extracted"
);
}
#[test]
fn backoff_wire_s3_extract_429_without_header_is_transient_no_hint() {
let attempt = s3_service_error("Throttling", 429, None);
assert!(attempt.transient, "429 must be transient");
assert_eq!(
attempt.retry_after, None,
"absent Retry-After header => None (backoff handles the delay)"
);
}
#[test]
fn backoff_wire_s3_extract_404_is_not_transient() {
let attempt = s3_service_error("NoSuchKey", 404, None);
assert!(
!attempt.transient,
"a 404/not-found must never be classified transient"
);
assert_eq!(attempt.retry_after, None);
}
// Test-local helper: strips one leading "./" and one trailing "/" from a
// manifest-style path. NOTE(review): nothing outside this test module uses it.
fn strip_leading_dot_slash(path: &str) -> &str {
let trimmed = path.strip_prefix("./").unwrap_or(path);
trimmed.strip_suffix('/').unwrap_or(trimmed)
}
// Fixed 64-hex-char ids and their expected sharded forms: the first three
// 3-character groups become directory levels, the remainder is the leaf name.
const FOO_CHECKSUM: &str = "49dc870df1de7fd60794cebce449f5ccdae575affaa67a24b62acb03e039db92";
const FOO_SHARDED: &str = "49d/c87/0df/1de7fd60794cebce449f5ccdae575affaa67a24b62acb03e039db92";
const MANIFEST_ID: &str = "aa91e498f401ea9e6ddbaa1138a0dbeb030fab8defc1252d80c77ebefafbc70d";
const MANIFEST_SHARDED: &str =
"aa9/1e4/98f/401ea9e6ddbaa1138a0dbeb030fab8defc1252d80c77ebefafbc70d";
#[test]
fn s3_store_parses_bucket_and_prefix() {
let loc = S3Location::parse("s3://my-bucket/long/term/storage");
assert_eq!(loc.bucket, "my-bucket");
assert_eq!(loc.prefix, "long/term/storage");
}
#[test]
fn s3_store_parse_matches_oracle_cut_fields() {
let loc = S3Location::parse("s3://bucket/a/b/c");
assert_eq!(loc.bucket, "bucket");
assert_eq!(loc.prefix, "a/b/c");
}
#[test]
fn s3_store_parse_strips_trailing_slash() {
let loc = S3Location::parse("s3://bucket/prefix/");
assert_eq!(loc.bucket, "bucket");
assert_eq!(loc.prefix, "prefix");
}
#[test]
fn s3_store_parse_bucket_root_has_empty_prefix() {
let loc = S3Location::parse("s3://bucket");
assert_eq!(loc.bucket, "bucket");
assert_eq!(loc.prefix, "");
let loc_slash = S3Location::parse("s3://bucket/");
assert_eq!(loc_slash.bucket, "bucket");
assert_eq!(loc_slash.prefix, "");
}
#[test]
fn s3_store_parse_accepts_bare_bucket_prefix_without_scheme() {
let loc = S3Location::parse("bucket/some/prefix");
assert_eq!(loc.bucket, "bucket");
assert_eq!(loc.prefix, "some/prefix");
}
#[test]
fn s3_store_object_key_matches_sharded_scheme() {
let loc = S3Location::parse("s3://b/long/term/storage");
assert_eq!(
loc.object_key(FOO_CHECKSUM),
format!("long/term/storage/.objects/{FOO_SHARDED}")
);
}
#[test]
fn s3_store_manifest_key_matches_sharded_scheme() {
let loc = S3Location::parse("s3://b/long/term/storage");
assert_eq!(
loc.manifest_key(MANIFEST_ID),
format!("long/term/storage/.manifests/{MANIFEST_SHARDED}")
);
}
// At the bucket root the keys must be relative (no leading "/").
#[test]
fn s3_store_keys_have_no_leading_slash_at_bucket_root() {
let loc = S3Location::parse("s3://bucket");
assert_eq!(
loc.object_key(FOO_CHECKSUM),
format!(".objects/{FOO_SHARDED}")
);
assert_eq!(
loc.manifest_key(MANIFEST_ID),
format!(".manifests/{MANIFEST_SHARDED}")
);
}
#[test]
fn s3_store_object_key_uses_core_object_path() {
let loc = S3Location::parse("s3://b");
assert_eq!(loc.object_key(FOO_CHECKSUM), object_path(FOO_CHECKSUM));
}
#[test]
fn s3_store_strip_leading_dot_slash() {
assert_eq!(strip_leading_dot_slash("./foo"), "foo");
assert_eq!(strip_leading_dot_slash("./a/b/c"), "a/b/c");
assert_eq!(strip_leading_dot_slash("./a/"), "a");
assert_eq!(strip_leading_dot_slash("./"), "");
}
// Opt-in integration test: runs only when both SNAPDIR_S3_TEST_ENDPOINT and
// SNAPDIR_S3_TEST_STORE are set; otherwise it prints a notice and returns
// (reported as passed). It pushes a one-file snapshot, reads the manifest back,
// and fetches the file into a fresh directory.
#[test]
fn s3_store_live_round_trip_when_configured() {
use snapdir_core::manifest::ManifestEntry;
let (Ok(endpoint), Ok(store)) = (
std::env::var("SNAPDIR_S3_TEST_ENDPOINT"),
std::env::var("SNAPDIR_S3_TEST_STORE"),
) else {
eprintln!(
"skipping s3_store live round-trip: set SNAPDIR_S3_TEST_ENDPOINT \
and SNAPDIR_S3_TEST_STORE (s3://bucket/prefix) to run it"
);
return;
};
let hasher = Blake3Hasher::new();
// Source tree: a single file `foo` containing "foo\n".
let src = std::env::temp_dir().join(format!("snapdir-s3-live-{}", std::process::id()));
std::fs::create_dir_all(&src).unwrap();
std::fs::write(src.join("foo"), b"foo\n").unwrap();
let foo_sum = hasher.hash_hex(b"foo\n");
let root_sum = snapdir_core::merkle::directory_checksum([foo_sum.as_str()], &hasher);
let mut manifest = Manifest::new();
manifest.push(ManifestEntry::new(
PathType::Directory,
"700",
root_sum,
4,
"./",
));
manifest.push(ManifestEntry::new(
PathType::File,
"600",
foo_sum,
4,
"./foo",
));
// NOTE(review): rebuilt via `from_entries`, presumably to get the same
// normalized form the store produces — confirm against Manifest::from_entries.
let manifest = Manifest::from_entries(manifest.entries().to_vec());
let id = snapdir_core::merkle::snapshot_id(&manifest, &hasher);
let s3 = S3Store::connect(&store, Some(&endpoint)).expect("connect");
s3.push(&manifest, &src).expect("push");
let read_back = s3.get_manifest(&id).expect("get_manifest");
assert_eq!(read_back, manifest);
let dest = std::env::temp_dir().join(format!("snapdir-s3-dest-{}", std::process::id()));
std::fs::create_dir_all(&dest).unwrap();
s3.fetch_files(&read_back, &dest).expect("fetch_files");
assert_eq!(std::fs::read(dest.join("foo")).unwrap(), b"foo\n");
// NOTE(review): cleanup runs only if every assertion above passed; a failing
// run leaves both temp directories behind.
let _ = std::fs::remove_dir_all(&src);
let _ = std::fs::remove_dir_all(&dest);
}
}