use aws_sdk_s3::Client;
use aws_sdk_s3::error::SdkError;
use aws_sdk_s3::operation::get_object_tagging::GetObjectTaggingError;
use aws_sdk_s3::types::Tag;
use futures::{StreamExt, stream};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use thiserror::Error;
use tokio::time::sleep;
use crate::command::StreamObject;
#[derive(Error, Debug)]
pub enum TagFetchError {
#[error("Access denied for s3:GetObjectTagging on {bucket}/{key}")]
AccessDenied { bucket: String, key: String },
#[error("Object not found: {bucket}/{key}")]
NotFound { bucket: String, key: String },
#[error("Request throttled by S3 API. Consider reducing --tag-concurrency.")]
Throttled,
#[error("S3 API error: {0}")]
ApiError(String),
#[error("Missing object key")]
MissingKey,
}
impl TagFetchError {
    /// Reports whether repeating the same request could plausibly succeed.
    ///
    /// Only throttling is retryable. Every variant is spelled out (no `_` arm) so
    /// that adding a new variant forces an explicit retry decision here.
    pub fn is_retryable(&self) -> bool {
        match self {
            TagFetchError::Throttled => true,
            TagFetchError::AccessDenied { .. }
            | TagFetchError::NotFound { .. }
            | TagFetchError::ApiError(_)
            | TagFetchError::MissingKey => false,
        }
    }
}
/// Shared counters describing the outcomes of tag-fetch requests.
///
/// Each counter is an independent `AtomicUsize` updated with `Relaxed` ordering,
/// so one instance can be shared across tasks/threads (e.g. behind an `Arc`).
/// Reading several counters while work is in flight does not give a consistent
/// snapshot across them.
#[derive(Default, Debug)]
pub struct TagFetchStats {
    /// Incremented by [`TagFetchStats::record_success`].
    pub success: AtomicUsize,
    /// Incremented by [`TagFetchStats::record_failure`].
    pub failed: AtomicUsize,
    /// Incremented by [`TagFetchStats::record_throttled`]; one request may add several.
    pub throttled: AtomicUsize,
    /// Incremented by [`TagFetchStats::record_access_denied`].
    pub access_denied: AtomicUsize,
}

impl TagFetchStats {
    /// Returns a collector with every counter at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds one to `counter`. The counters are independent, so `Relaxed` is enough.
    fn bump(counter: &AtomicUsize) {
        counter.fetch_add(1, Ordering::Relaxed);
    }

    /// Counts one successful tag fetch.
    pub fn record_success(&self) {
        Self::bump(&self.success);
    }

    /// Counts one failed tag fetch.
    pub fn record_failure(&self) {
        Self::bump(&self.failed);
    }

    /// Counts one throttled response.
    pub fn record_throttled(&self) {
        Self::bump(&self.throttled);
    }

    /// Counts one access-denied response.
    pub fn record_access_denied(&self) {
        Self::bump(&self.access_denied);
    }

    /// Sum of all four counters.
    ///
    /// This counts recorded events, not distinct requests: a request that is
    /// throttled and later gives up is counted in both `throttled` and `failed`.
    pub fn total_events(&self) -> usize {
        [&self.success, &self.failed, &self.throttled, &self.access_denied]
            .into_iter()
            .map(|counter| counter.load(Ordering::Relaxed))
            .sum()
    }
}
/// Tunables for concurrent tag fetching and retry of throttled requests.
#[derive(Clone, Debug)]
pub struct TagFetchConfig {
    /// Maximum number of tag requests allowed in flight at once.
    pub concurrency: usize,
    /// How many extra attempts a throttled request gets before giving up.
    pub max_retries: u32,
    /// Backoff before the first retry, in milliseconds; doubles on each attempt.
    pub base_delay_ms: u64,
    /// Cap on the exponential part of the backoff, in milliseconds
    /// (random jitter is added on top of the capped value).
    pub max_delay_ms: u64,
}

impl Default for TagFetchConfig {
    /// 50 concurrent requests, 3 retries, backoff starting at 100 ms and capped at 5 s.
    fn default() -> Self {
        let (concurrency, max_retries) = (50, 3);
        let (base_delay_ms, max_delay_ms) = (100, 5_000);
        Self {
            concurrency,
            max_retries,
            base_delay_ms,
            max_delay_ms,
        }
    }
}

impl TagFetchConfig {
    /// Returns this configuration with `concurrency` replaced; other fields are kept.
    pub fn with_concurrency(self, concurrency: usize) -> Self {
        Self { concurrency, ..self }
    }
}
/// Computes how long to wait before retry number `attempt` (0-based).
///
/// The delay starts at `base_delay_ms` and doubles with every attempt
/// (`base_delay_ms * 2^attempt`), capped at `max_delay_ms`. A random jitter of
/// up to half the capped value is then added, so the result lies roughly in
/// `[capped, 1.5 * capped)`.
///
/// All arithmetic saturates: extreme inputs (huge `attempt`, or delays close to
/// `u64::MAX`) produce a very long delay instead of overflowing.
fn calculate_backoff_delay(attempt: u32, base_delay_ms: u64, max_delay_ms: u64) -> Duration {
    // Shifting a u64 by 64 or more overflows, so clamp the exponent first.
    let safe_attempt = attempt.min(63);
    let delay_ms = base_delay_ms.saturating_mul(1u64 << safe_attempt);
    let capped_delay = delay_ms.min(max_delay_ms);
    let jitter = (rand_jitter() * (capped_delay as f64 / 2.0)) as u64;
    // Saturate here too: with a cap near u64::MAX a plain `+` panics in debug
    // builds and wraps around to a tiny delay in release builds.
    Duration::from_millis(capped_delay.saturating_add(jitter))
}
/// Returns a pseudo-random jitter factor uniformly distributed in `[0.0, 1.0)`.
///
/// Each call hashes with a freshly constructed, randomly keyed `RandomState`.
/// Concurrent tasks that hit throttling at the same instant therefore draw
/// independent factors. A clock-derived value would be nearly identical for
/// such tasks and would make them retry in lock-step.
fn rand_jitter() -> f64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let bits = RandomState::new().build_hasher().finish();
    // Keep the top 53 bits: they fit an f64 mantissa exactly, and dividing by
    // 2^53 maps them onto [0, 1) without ever producing 1.0.
    (bits >> 11) as f64 / (1u64 << 53) as f64
}
/// Fetches the tag set of one object, retrying errors the error type deems retryable.
///
/// Sends `GetObjectTagging` for `bucket`/`key`, pinned to `version_id` when given.
/// The SDK error is classified with `classify_error`. When
/// [`TagFetchError::is_retryable`] is true and fewer than `config.max_retries`
/// retries have been made, the request is repeated after a backoff from
/// `calculate_backoff_delay`. Other errors are returned immediately.
///
/// Stats bookkeeping:
/// - success → `record_success`
/// - every throttled response → `record_throttled`
/// - final access-denied error → `record_access_denied`
/// - any other final error (including exhausted throttling retries) → `record_failure`
///
/// # Errors
/// Returns the classified [`TagFetchError`] of the last attempt.
async fn fetch_object_tags(
    client: &Client,
    bucket: &str,
    key: &str,
    version_id: Option<&str>,
    config: &TagFetchConfig,
    stats: &TagFetchStats,
) -> Result<Vec<Tag>, TagFetchError> {
    let mut attempt: u32 = 0;
    loop {
        let mut request = client.get_object_tagging().bucket(bucket).key(key);
        if let Some(vid) = version_id {
            request = request.version_id(vid);
        }
        let sdk_err = match request.send().await {
            Ok(output) => {
                stats.record_success();
                return Ok(output.tag_set().to_vec());
            }
            Err(err) => err,
        };
        let fetch_error = classify_error(&sdk_err, bucket, key);
        if matches!(fetch_error, TagFetchError::Throttled) {
            stats.record_throttled();
        }
        // The retry policy lives on the error type so the two cannot drift apart.
        if fetch_error.is_retryable() && attempt < config.max_retries {
            let delay =
                calculate_backoff_delay(attempt, config.base_delay_ms, config.max_delay_ms);
            sleep(delay).await;
            attempt += 1;
            continue;
        }
        if matches!(fetch_error, TagFetchError::AccessDenied { .. }) {
            stats.record_access_denied();
        } else {
            stats.record_failure();
        }
        return Err(fetch_error);
    }
}
/// Maps an SDK error from `GetObjectTagging` onto a [`TagFetchError`].
///
/// Only errors that carry an S3 response (`SdkError::ServiceError`) are classified
/// by HTTP status:
/// - 403 → `AccessDenied`
/// - 404 → `NotFound`
/// - 429/503 → `Throttled`
/// - any other status → `ApiError` with the status and the modeled error
///
/// All other SDK errors (e.g. timeouts) become `ApiError` with the error's display text.
fn classify_error(
    err: &SdkError<GetObjectTaggingError>,
    bucket: &str,
    key: &str,
) -> TagFetchError {
    let SdkError::ServiceError(service_err) = err else {
        return TagFetchError::ApiError(err.to_string());
    };
    let code = service_err.raw().status().as_u16();
    match code {
        403 => TagFetchError::AccessDenied {
            bucket: bucket.to_owned(),
            key: key.to_owned(),
        },
        404 => TagFetchError::NotFound {
            bucket: bucket.to_owned(),
            key: key.to_owned(),
        },
        429 | 503 => TagFetchError::Throttled,
        other => TagFetchError::ApiError(format!("HTTP {}: {:?}", other, service_err.err())),
    }
}
/// Fills in `tags` for every object that does not already have them.
///
/// At most `config.concurrency` objects are processed concurrently; a value of
/// 0 is treated as 1. For each object:
/// - `tags` already `Some`: returned untouched, no request is made.
/// - delete markers and objects without a key: get an empty tag list, no request is made.
/// - otherwise: tags are fetched via `fetch_object_tags`, using `version_id` when present.
///   On failure a warning is printed to stderr and the object gets an empty tag list.
///
/// Request outcomes are recorded in `stats`.
///
/// Every input object appears exactly once in the result, but in the order the
/// work *completes* (`buffer_unordered`). This may differ from the input order.
pub async fn fetch_tags_for_objects<I>(
    client: Client,
    bucket: String,
    objects: I,
    config: TagFetchConfig,
    stats: Arc<TagFetchStats>,
) -> Vec<StreamObject>
where
    I: IntoIterator<Item = StreamObject>,
{
    // A limit of 0 must not reach `buffer_unordered`: recent `futures` releases
    // treat it as "no limit", which defeats the purpose of --tag-concurrency.
    let concurrency = config.concurrency.max(1);
    let objects: Vec<StreamObject> = objects.into_iter().collect();
    stream::iter(objects)
        .map(|mut obj| {
            let client = client.clone();
            let bucket = bucket.clone();
            let config = config.clone();
            let stats = Arc::clone(&stats);
            async move {
                if obj.tags.is_some() {
                    return obj;
                }
                // Delete markers are skipped without issuing a request.
                if obj.is_delete_marker {
                    obj.tags = Some(Vec::new());
                    return obj;
                }
                let key = match obj.object.key() {
                    Some(k) => k.to_string(),
                    None => {
                        obj.tags = Some(Vec::new());
                        return obj;
                    }
                };
                let version_id = obj.version_id.as_deref();
                match fetch_object_tags(&client, &bucket, &key, version_id, &config, &stats).await {
                    Ok(tags) => {
                        obj.tags = Some(tags);
                    }
                    Err(e) => {
                        // Best effort: a tag failure must not drop the object from the output.
                        eprintln!("Warning: Failed to fetch tags for {}: {}", key, e);
                        obj.tags = Some(Vec::new());
                    }
                }
                obj
            }
        })
        .buffer_unordered(concurrency)
        .collect()
        .await
}
#[cfg(test)]
mod tests {
    use super::*;
    use aws_config::BehaviorVersion;
    use aws_sdk_s3::types::Object;
    use aws_smithy_runtime::client::http::test_util::{ReplayEvent, StaticReplayClient};
    use aws_smithy_types::body::SdkBody;
    use http::{HeaderValue, StatusCode};

    // ------------------------------------------------------------------
    // Pure unit tests (no S3 client involved)
    // ------------------------------------------------------------------

    #[test]
    fn test_tag_fetch_stats() {
        let stats = TagFetchStats::new();
        assert_eq!(stats.total_events(), 0);
        stats.record_success();
        stats.record_success();
        stats.record_failure();
        stats.record_throttled();
        stats.record_access_denied();
        assert_eq!(stats.success.load(Ordering::Relaxed), 2);
        assert_eq!(stats.failed.load(Ordering::Relaxed), 1);
        assert_eq!(stats.throttled.load(Ordering::Relaxed), 1);
        assert_eq!(stats.access_denied.load(Ordering::Relaxed), 1);
        assert_eq!(stats.total_events(), 5);
    }

    #[test]
    fn test_tag_fetch_config_default() {
        let config = TagFetchConfig::default();
        assert_eq!(config.concurrency, 50);
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.base_delay_ms, 100);
        assert_eq!(config.max_delay_ms, 5000);
    }

    #[test]
    fn test_tag_fetch_config_with_concurrency() {
        let config = TagFetchConfig::default().with_concurrency(100);
        assert_eq!(config.concurrency, 100);
    }

    #[test]
    fn test_calculate_backoff_delay() {
        let delay0 = calculate_backoff_delay(0, 100, 5000);
        let delay1 = calculate_backoff_delay(1, 100, 5000);
        let delay2 = calculate_backoff_delay(2, 100, 5000);
        assert!(delay0.as_millis() >= 100 && delay0.as_millis() <= 150);
        assert!(delay1.as_millis() >= 200 && delay1.as_millis() <= 300);
        assert!(delay2.as_millis() >= 400 && delay2.as_millis() <= 600);
    }

    #[test]
    fn test_calculate_backoff_delay_capped() {
        let delay = calculate_backoff_delay(10, 100, 5000);
        assert!(delay.as_millis() <= 7500);
    }

    #[test]
    fn test_tag_fetch_error_retryable() {
        assert!(TagFetchError::Throttled.is_retryable());
        assert!(
            !TagFetchError::AccessDenied {
                bucket: "test".to_string(),
                key: "test".to_string()
            }
            .is_retryable()
        );
        assert!(
            !TagFetchError::NotFound {
                bucket: "test".to_string(),
                key: "test".to_string()
            }
            .is_retryable()
        );
        assert!(!TagFetchError::ApiError("test".to_string()).is_retryable());
        assert!(!TagFetchError::MissingKey.is_retryable());
    }

    #[test]
    fn test_stream_object_with_tags() {
        let object = Object::builder().key("test.txt").build();
        let stream_obj = StreamObject {
            object,
            version_id: None,
            is_latest: None,
            is_delete_marker: false,
            tags: Some(vec![]),
        };
        assert!(stream_obj.tags.is_some());
    }

    #[test]
    fn test_delete_marker_has_empty_tags() {
        // NOTE: this only checks the StreamObject shape; the production path for
        // delete markers is covered by `test_fetch_tags_for_objects_with_delete_marker`.
        let object = Object::builder().key("deleted.txt").build();
        let mut stream_obj = StreamObject {
            object,
            version_id: None,
            is_latest: None,
            is_delete_marker: true,
            tags: None,
        };
        if stream_obj.is_delete_marker {
            stream_obj.tags = Some(Vec::new());
        }
        assert!(stream_obj.tags.is_some());
        assert!(stream_obj.tags.unwrap().is_empty());
    }

    #[test]
    fn test_classify_error_non_service_error() {
        use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
        use aws_smithy_runtime_api::client::result::SdkError;
        let timeout_err: SdkError<GetObjectTaggingError, HttpResponse> = SdkError::timeout_error(
            Box::new(std::io::Error::new(std::io::ErrorKind::TimedOut, "timeout")),
        );
        let result = classify_error(&timeout_err, "bucket", "key");
        assert!(
            matches!(result, TagFetchError::ApiError(_)),
            "Expected ApiError, got: {:?}",
            result
        );
    }

    #[test]
    fn test_tag_fetch_error_display() {
        let err = TagFetchError::AccessDenied {
            bucket: "my-bucket".to_string(),
            key: "my-key".to_string(),
        };
        assert!(err.to_string().contains("Access denied"));
        assert!(err.to_string().contains("my-bucket"));
        assert!(err.to_string().contains("my-key"));
        let err = TagFetchError::NotFound {
            bucket: "bucket".to_string(),
            key: "key".to_string(),
        };
        assert!(err.to_string().contains("Object not found"));
        let err = TagFetchError::Throttled;
        assert!(err.to_string().contains("throttled"));
        assert!(err.to_string().contains("tag-concurrency"));
        let err = TagFetchError::ApiError("Custom error".to_string());
        assert!(err.to_string().contains("Custom error"));
        let err = TagFetchError::MissingKey;
        assert!(err.to_string().contains("Missing object key"));
    }

    #[test]
    fn test_rand_jitter() {
        for _ in 0..100 {
            let jitter = rand_jitter();
            assert!(
                (0.0..1.0).contains(&jitter),
                "Jitter out of range: {}",
                jitter
            );
        }
    }

    #[test]
    fn test_tag_fetch_stats_concurrent_updates() {
        use std::thread;
        let stats = Arc::new(TagFetchStats::new());
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let stats_clone = Arc::clone(&stats);
                thread::spawn(move || {
                    for _ in 0..100 {
                        stats_clone.record_success();
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(stats.success.load(Ordering::Relaxed), 400);
    }

    #[test]
    fn test_calculate_backoff_with_zero_base_delay() {
        // A zero base delay yields zero after capping, and zero jitter on top.
        let delay = calculate_backoff_delay(0, 0, 5000);
        assert_eq!(delay, Duration::ZERO);
    }

    #[test]
    fn test_calculate_backoff_overflow_protection() {
        let delay = calculate_backoff_delay(u32::MAX, 100, 5000);
        assert!(delay.as_millis() <= 7500);
    }

    #[test]
    fn test_object_without_key_handling() {
        let object = Object::builder().build();
        let stream_obj = StreamObject {
            object,
            version_id: None,
            is_latest: None,
            is_delete_marker: false,
            tags: None,
        };
        assert!(stream_obj.key().is_none());
    }

    // ------------------------------------------------------------------
    // Replay-client helpers
    // ------------------------------------------------------------------

    /// Builds a client whose HTTP layer replays `events` in order instead of calling S3.
    fn make_test_client(events: Vec<ReplayEvent>) -> Client {
        let replay_client = StaticReplayClient::new(events);
        Client::from_conf(
            aws_sdk_s3::Config::builder()
                .behavior_version(BehaviorVersion::latest())
                .credentials_provider(aws_sdk_s3::config::Credentials::new(
                    "test", "test", None, None, "test",
                ))
                .region(aws_sdk_s3::config::Region::new("us-east-1"))
                .http_client(replay_client)
                .build(),
        )
    }

    /// Request URI expected for `GetObjectTagging` on `key` in `test-bucket`.
    fn tagging_uri(key: &str) -> String {
        format!(
            "https://test-bucket.s3.us-east-1.amazonaws.com/{}?tagging",
            key
        )
    }

    /// Pairs an expected `GET uri` request with an XML response of `status` and `body`.
    fn replay_event(uri: &str, status: StatusCode, body: impl Into<SdkBody>) -> ReplayEvent {
        let req = http::Request::builder()
            .method("GET")
            .uri(uri)
            .body(SdkBody::empty())
            .unwrap();
        let resp = http::Response::builder()
            .status(status)
            .header("Content-Type", HeaderValue::from_static("application/xml"))
            .body(body.into())
            .unwrap();
        ReplayEvent::new(req, resp)
    }

    /// Successful `GetObjectTagging` XML body containing `tags`.
    fn tagging_xml(tags: &[(&str, &str)]) -> String {
        let tag_xml: String = tags
            .iter()
            .map(|(k, v)| format!("<Tag><Key>{}</Key><Value>{}</Value></Tag>", k, v))
            .collect();
        format!(
            r#"<?xml version="1.0" encoding="UTF-8"?>
<Tagging xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
<TagSet>{}</TagSet>
</Tagging>"#,
            tag_xml
        )
    }

    /// S3-style `<Error>` XML body with the given code and message.
    fn error_xml(code: &str, message: &str) -> String {
        format!(
            r#"<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>{}</Code>
<Message>{}</Message>
</Error>"#,
            code, message
        )
    }

    fn make_tag_response(key: &str, tags: &[(&str, &str)]) -> ReplayEvent {
        replay_event(&tagging_uri(key), StatusCode::OK, tagging_xml(tags))
    }

    fn make_versioned_tag_response(
        key: &str,
        version_id: &str,
        tags: &[(&str, &str)],
    ) -> ReplayEvent {
        let uri = format!("{}&versionId={}", tagging_uri(key), version_id);
        replay_event(&uri, StatusCode::OK, tagging_xml(tags))
    }

    /// Error response with an empty body (no S3 error code), e.g. bare 403/404.
    fn make_error_response(key: &str, status: u16) -> ReplayEvent {
        replay_event(&tagging_uri(key), StatusCode::from_u16(status).unwrap(), "")
    }

    fn make_throttle_response(key: &str) -> ReplayEvent {
        replay_event(
            &tagging_uri(key),
            StatusCode::SERVICE_UNAVAILABLE,
            error_xml("SlowDown", "Reduce your request rate."),
        )
    }

    fn make_429_response(key: &str) -> ReplayEvent {
        replay_event(
            &tagging_uri(key),
            StatusCode::TOO_MANY_REQUESTS,
            error_xml("TooManyRequestsException", "Rate exceeded"),
        )
    }

    fn make_generic_error_response(key: &str, status: u16) -> ReplayEvent {
        replay_event(
            &tagging_uri(key),
            StatusCode::from_u16(status).unwrap(),
            error_xml("InternalError", &format!("Generic error {}", status)),
        )
    }

    /// One request at a time with millisecond-scale backoff, so retry tests stay fast.
    fn fast_retry_config() -> TagFetchConfig {
        TagFetchConfig {
            concurrency: 1,
            max_retries: 3,
            base_delay_ms: 1,
            max_delay_ms: 10,
        }
    }

    /// Runs `fetch_tags_for_objects` against `test-bucket` on a replay client.
    async fn run_fetch(
        events: Vec<ReplayEvent>,
        objects: Vec<StreamObject>,
        config: TagFetchConfig,
    ) -> (Vec<StreamObject>, Arc<TagFetchStats>) {
        let client = make_test_client(events);
        let stats = Arc::new(TagFetchStats::new());
        let results = fetch_tags_for_objects(
            client,
            "test-bucket".to_string(),
            objects,
            config,
            Arc::clone(&stats),
        )
        .await;
        (results, stats)
    }

    // ------------------------------------------------------------------
    // End-to-end tests through the replay client
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_fetch_tags_for_objects_success() {
        let events = vec![
            make_tag_response("file1.txt", &[("env", "prod"), ("team", "data")]),
            make_tag_response("file2.txt", &[("env", "dev")]),
        ];
        let objects = vec![
            StreamObject::from_object(Object::builder().key("file1.txt").build()),
            StreamObject::from_object(Object::builder().key("file2.txt").build()),
        ];
        let (results, stats) =
            run_fetch(events, objects, TagFetchConfig::default().with_concurrency(1)).await;
        assert_eq!(results.len(), 2);
        assert!(results[0].tags.is_some());
        assert!(results[1].tags.is_some());
        assert_eq!(stats.success.load(Ordering::Relaxed), 2);
    }

    #[tokio::test]
    async fn test_fetch_tags_for_objects_with_delete_marker() {
        let objects = vec![StreamObject {
            object: Object::builder().key("deleted.txt").build(),
            version_id: Some("v1".to_string()),
            is_latest: Some(true),
            is_delete_marker: true,
            tags: None,
        }];
        let (results, stats) = run_fetch(vec![], objects, TagFetchConfig::default()).await;
        assert_eq!(results.len(), 1);
        assert!(results[0].tags.is_some());
        assert!(results[0].tags.as_ref().unwrap().is_empty());
        assert_eq!(stats.success.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn test_fetch_tags_for_objects_already_has_tags() {
        let existing_tags = vec![Tag::builder().key("cached").value("true").build().unwrap()];
        let objects = vec![StreamObject {
            object: Object::builder().key("cached.txt").build(),
            version_id: None,
            is_latest: None,
            is_delete_marker: false,
            tags: Some(existing_tags),
        }];
        let (results, stats) = run_fetch(vec![], objects, TagFetchConfig::default()).await;
        assert_eq!(results.len(), 1);
        assert!(results[0].tags.is_some());
        assert_eq!(results[0].tags.as_ref().unwrap().len(), 1);
        assert_eq!(stats.success.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn test_fetch_tags_for_objects_access_denied() {
        let events = vec![make_error_response("forbidden.txt", 403)];
        let objects = vec![StreamObject::from_object(
            Object::builder().key("forbidden.txt").build(),
        )];
        let (results, stats) =
            run_fetch(events, objects, TagFetchConfig::default().with_concurrency(1)).await;
        assert_eq!(results.len(), 1);
        assert!(results[0].tags.is_some());
        assert!(results[0].tags.as_ref().unwrap().is_empty());
        assert_eq!(stats.access_denied.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_fetch_tags_for_objects_not_found() {
        let events = vec![make_error_response("missing.txt", 404)];
        let objects = vec![StreamObject::from_object(
            Object::builder().key("missing.txt").build(),
        )];
        let (results, stats) =
            run_fetch(events, objects, TagFetchConfig::default().with_concurrency(1)).await;
        assert_eq!(results.len(), 1);
        assert!(results[0].tags.is_some());
        assert!(results[0].tags.as_ref().unwrap().is_empty());
        assert_eq!(stats.failed.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_fetch_tags_for_objects_without_key() {
        let objects = vec![StreamObject::from_object(Object::builder().build())];
        let (results, stats) = run_fetch(vec![], objects, TagFetchConfig::default()).await;
        assert_eq!(results.len(), 1);
        assert!(results[0].tags.is_some());
        assert!(results[0].tags.as_ref().unwrap().is_empty());
        assert_eq!(stats.success.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn test_fetch_tags_for_objects_with_version_id() {
        let events = vec![make_versioned_tag_response(
            "versioned.txt",
            "v123",
            &[("version", "v123")],
        )];
        let objects = vec![StreamObject {
            object: Object::builder().key("versioned.txt").build(),
            version_id: Some("v123".to_string()),
            is_latest: Some(true),
            is_delete_marker: false,
            tags: None,
        }];
        let (results, stats) =
            run_fetch(events, objects, TagFetchConfig::default().with_concurrency(1)).await;
        assert_eq!(results.len(), 1);
        assert!(results[0].tags.is_some());
        let tags = results[0].tags.as_ref().unwrap();
        assert_eq!(tags.len(), 1);
        assert_eq!(tags[0].key(), "version");
        assert_eq!(tags[0].value(), "v123");
        assert_eq!(stats.success.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_fetch_tags_for_objects_throttle_then_success() {
        let events = vec![
            make_throttle_response("retry.txt"),
            make_tag_response("retry.txt", &[("retried", "true")]),
        ];
        let objects = vec![StreamObject::from_object(
            Object::builder().key("retry.txt").build(),
        )];
        let (results, stats) = run_fetch(events, objects, fast_retry_config()).await;
        assert_eq!(results.len(), 1);
        assert!(results[0].tags.is_some());
        let tags = results[0].tags.as_ref().unwrap();
        assert_eq!(tags.len(), 1);
        assert_eq!(tags[0].key(), "retried");
        assert_eq!(stats.throttled.load(Ordering::Relaxed), 1);
        assert_eq!(stats.success.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_fetch_tags_for_objects_throttle_exhausts_retries() {
        // One initial attempt plus `max_retries` (3) retries, all throttled.
        let events: Vec<ReplayEvent> = (0..4)
            .map(|_| make_throttle_response("exhaust.txt"))
            .collect();
        let objects = vec![StreamObject::from_object(
            Object::builder().key("exhaust.txt").build(),
        )];
        let (results, stats) = run_fetch(events, objects, fast_retry_config()).await;
        assert_eq!(results.len(), 1);
        assert!(results[0].tags.is_some());
        assert!(results[0].tags.as_ref().unwrap().is_empty());
        assert_eq!(stats.throttled.load(Ordering::Relaxed), 4);
        assert_eq!(stats.failed.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_fetch_tags_for_objects_429_retry() {
        let events = vec![
            make_429_response("rate-limited.txt"),
            make_tag_response("rate-limited.txt", &[("ok", "true")]),
        ];
        let objects = vec![StreamObject::from_object(
            Object::builder().key("rate-limited.txt").build(),
        )];
        let (results, stats) = run_fetch(events, objects, fast_retry_config()).await;
        assert_eq!(results.len(), 1);
        assert!(results[0].tags.is_some());
        assert_eq!(stats.throttled.load(Ordering::Relaxed), 1);
        assert_eq!(stats.success.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_fetch_tags_generic_error_no_retry() {
        let events = vec![make_generic_error_response("error.txt", 500)];
        let objects = vec![StreamObject::from_object(
            Object::builder().key("error.txt").build(),
        )];
        let (results, stats) = run_fetch(events, objects, fast_retry_config()).await;
        assert_eq!(results.len(), 1);
        assert!(results[0].tags.as_ref().unwrap().is_empty());
        assert_eq!(stats.failed.load(Ordering::Relaxed), 1);
        assert_eq!(stats.throttled.load(Ordering::Relaxed), 0);
    }
}