1use bytes::Bytes;
8use std::time::Duration;
9
10#[cfg(feature = "s3")]
11use aws_config::BehaviorVersion;
12#[cfg(feature = "s3")]
13use aws_sdk_s3::{
14 Client, Config,
15 config::Region,
16 primitives::ByteStream,
17 types::{CompletedMultipartUpload, CompletedPart, ServerSideEncryption, StorageClass},
18};
19
20use crate::auth::Credentials;
21use crate::error::{CloudError, Result, S3Error};
22use crate::retry::{RetryConfig, RetryExecutor};
23
24use super::CloudStorageBackend;
25
/// Server-side encryption (SSE) configuration applied to uploaded objects.
#[derive(Debug, Clone)]
pub enum SseConfig {
    /// No server-side encryption requested.
    None,
    /// SSE-S3: AES-256 encryption with S3-managed keys.
    Aes256,
    /// SSE-KMS: encryption with an AWS KMS-managed key.
    Kms {
        /// The KMS key ID (or ARN) used to encrypt objects.
        key_id: String,
    },
}
39
/// S3 storage class assigned to uploaded objects.
///
/// Variants map onto the AWS SDK's `StorageClass` values (see
/// `to_aws_storage_class`).
#[derive(Debug, Clone)]
pub enum S3StorageClass {
    /// Default, frequent-access storage (`STANDARD`).
    Standard,
    /// Legacy reduced-redundancy storage (`REDUCED_REDUNDANCY`).
    ReducedRedundancy,
    /// Standard-Infrequent Access (`STANDARD_IA`).
    InfrequentAccess,
    /// One Zone-Infrequent Access (`ONEZONE_IA`).
    OneZoneInfrequentAccess,
    /// Glacier Flexible Retrieval (`GLACIER`).
    Glacier,
    /// Glacier Deep Archive (`DEEP_ARCHIVE`).
    GlacierDeepArchive,
    /// Intelligent-Tiering (`INTELLIGENT_TIERING`).
    IntelligentTiering,
}
58
impl S3StorageClass {
    /// Maps this variant to the corresponding AWS SDK `StorageClass` value.
    #[cfg(feature = "s3")]
    fn to_aws_storage_class(&self) -> StorageClass {
        match self {
            Self::Standard => StorageClass::Standard,
            Self::ReducedRedundancy => StorageClass::ReducedRedundancy,
            Self::InfrequentAccess => StorageClass::StandardIa,
            Self::OneZoneInfrequentAccess => StorageClass::OnezoneIa,
            Self::Glacier => StorageClass::Glacier,
            Self::GlacierDeepArchive => StorageClass::DeepArchive,
            Self::IntelligentTiering => StorageClass::IntelligentTiering,
        }
    }
}
74
/// Cloud storage backend targeting Amazon S3, or an S3-compatible service
/// when a custom endpoint is configured.
#[derive(Debug, Clone)]
pub struct S3Backend {
    /// Name of the target bucket.
    pub bucket: String,
    /// Key prefix joined (with a `/`) onto every object key; empty for none.
    pub prefix: String,
    /// AWS region; `None` resolves the region from the ambient environment.
    pub region: Option<String>,
    /// Custom endpoint URL (e.g. MinIO/LocalStack); forces path-style access.
    pub endpoint: Option<String>,
    /// Server-side encryption applied to uploads.
    pub sse: SseConfig,
    /// Storage class applied to uploads.
    pub storage_class: S3StorageClass,
    /// Whether S3 Transfer Acceleration is requested.
    pub transfer_acceleration: bool,
    /// Payload size (bytes) above which `put` uses multipart upload.
    pub multipart_threshold: usize,
    /// Size (bytes) of each multipart part.
    pub multipart_chunk_size: usize,
    /// Request timeout.
    /// NOTE(review): stored but not currently read by `create_client`, so it
    /// is not applied to the SDK client — confirm intended wiring.
    pub timeout: Duration,
    /// Retry policy used by the retried operations (`get`/`put`/`delete`).
    pub retry_config: RetryConfig,
    /// Explicit credentials.
    /// NOTE(review): stored but not currently read by `create_client`, which
    /// uses the ambient credential provider chain — confirm intended wiring.
    pub credentials: Option<Credentials>,
}
103
104impl S3Backend {
105 pub const DEFAULT_MULTIPART_THRESHOLD: usize = 5 * 1024 * 1024;
107
108 pub const DEFAULT_MULTIPART_CHUNK_SIZE: usize = 5 * 1024 * 1024;
110
111 #[must_use]
117 pub fn new(bucket: impl Into<String>, prefix: impl Into<String>) -> Self {
118 Self {
119 bucket: bucket.into(),
120 prefix: prefix.into(),
121 region: None,
122 endpoint: None,
123 sse: SseConfig::None,
124 storage_class: S3StorageClass::Standard,
125 transfer_acceleration: false,
126 multipart_threshold: Self::DEFAULT_MULTIPART_THRESHOLD,
127 multipart_chunk_size: Self::DEFAULT_MULTIPART_CHUNK_SIZE,
128 timeout: Duration::from_secs(300),
129 retry_config: RetryConfig::default(),
130 credentials: None,
131 }
132 }
133
134 #[must_use]
136 pub fn with_region(mut self, region: impl Into<String>) -> Self {
137 self.region = Some(region.into());
138 self
139 }
140
141 #[must_use]
143 pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
144 self.endpoint = Some(endpoint.into());
145 self
146 }
147
148 #[must_use]
150 pub fn with_sse(mut self, sse: SseConfig) -> Self {
151 self.sse = sse;
152 self
153 }
154
155 #[must_use]
157 pub fn with_storage_class(mut self, storage_class: S3StorageClass) -> Self {
158 self.storage_class = storage_class;
159 self
160 }
161
162 #[must_use]
164 pub fn with_transfer_acceleration(mut self, enabled: bool) -> Self {
165 self.transfer_acceleration = enabled;
166 self
167 }
168
169 #[must_use]
171 pub fn with_multipart_threshold(mut self, threshold: usize) -> Self {
172 self.multipart_threshold = threshold;
173 self
174 }
175
176 #[must_use]
178 pub fn with_multipart_chunk_size(mut self, size: usize) -> Self {
179 self.multipart_chunk_size = size;
180 self
181 }
182
183 #[must_use]
185 pub fn with_timeout(mut self, timeout: Duration) -> Self {
186 self.timeout = timeout;
187 self
188 }
189
190 #[must_use]
192 pub fn with_retry_config(mut self, config: RetryConfig) -> Self {
193 self.retry_config = config;
194 self
195 }
196
197 #[must_use]
199 pub fn with_credentials(mut self, credentials: Credentials) -> Self {
200 self.credentials = Some(credentials);
201 self
202 }
203
204 fn full_key(&self, key: &str) -> String {
205 if self.prefix.is_empty() {
206 key.to_string()
207 } else {
208 format!("{}/{}", self.prefix, key)
209 }
210 }
211
212 #[cfg(feature = "s3")]
213 async fn create_client(&self) -> Result<Client> {
214 let mut config_loader = aws_config::defaults(BehaviorVersion::latest());
215
216 if let Some(ref region) = self.region {
217 config_loader = config_loader.region(Region::new(region.clone()));
218 }
219
220 let sdk_config = config_loader.load().await;
221
222 let mut s3_config_builder = Config::builder()
223 .behavior_version(BehaviorVersion::latest())
224 .region(sdk_config.region().cloned());
225
226 if let Some(ref endpoint) = self.endpoint {
227 s3_config_builder = s3_config_builder
228 .endpoint_url(endpoint)
229 .force_path_style(true);
230 }
231
232 let s3_config = s3_config_builder.build();
233 Ok(Client::from_conf(s3_config))
234 }
235
236 #[cfg(feature = "s3")]
237 async fn upload_multipart(&self, key: &str, data: &[u8]) -> Result<()> {
238 let client = self.create_client().await?;
239 let full_key = self.full_key(key);
240
241 let mut create_request = client
243 .create_multipart_upload()
244 .bucket(&self.bucket)
245 .key(&full_key)
246 .storage_class(self.storage_class.to_aws_storage_class());
247
248 create_request = match &self.sse {
250 SseConfig::None => create_request,
251 SseConfig::Aes256 => {
252 create_request.server_side_encryption(ServerSideEncryption::Aes256)
253 }
254 SseConfig::Kms { key_id } => create_request
255 .server_side_encryption(ServerSideEncryption::AwsKms)
256 .ssekms_key_id(key_id),
257 };
258
259 let multipart_upload = create_request.send().await.map_err(|e| {
260 CloudError::S3(S3Error::MultipartUpload {
261 message: format!("Failed to initiate multipart upload: {e}"),
262 })
263 })?;
264
265 let upload_id = multipart_upload.upload_id().ok_or_else(|| {
266 CloudError::S3(S3Error::MultipartUpload {
267 message: "No upload ID returned".to_string(),
268 })
269 })?;
270
271 let mut completed_parts = Vec::new();
273
274 for (idx, chunk) in data.chunks(self.multipart_chunk_size).enumerate() {
275 let part_number = i32::try_from(idx + 1).map_err(|_| {
276 CloudError::S3(S3Error::MultipartUpload {
277 message: format!("Part number overflow at index {idx}"),
278 })
279 })?;
280 let part = client
281 .upload_part()
282 .bucket(&self.bucket)
283 .key(&full_key)
284 .upload_id(upload_id)
285 .part_number(part_number)
286 .body(ByteStream::from(chunk.to_vec()))
287 .send()
288 .await
289 .map_err(|e| {
290 CloudError::S3(S3Error::MultipartUpload {
291 message: format!("Failed to upload part {part_number}: {e}"),
292 })
293 })?;
294
295 if let Some(etag) = part.e_tag() {
296 completed_parts.push(
297 CompletedPart::builder()
298 .e_tag(etag)
299 .part_number(part_number)
300 .build(),
301 );
302 }
303 }
304
305 let completed_upload = CompletedMultipartUpload::builder()
307 .set_parts(Some(completed_parts))
308 .build();
309
310 client
311 .complete_multipart_upload()
312 .bucket(&self.bucket)
313 .key(&full_key)
314 .upload_id(upload_id)
315 .multipart_upload(completed_upload)
316 .send()
317 .await
318 .map_err(|e| {
319 CloudError::S3(S3Error::MultipartUpload {
320 message: format!("Failed to complete multipart upload: {e}"),
321 })
322 })?;
323
324 Ok(())
325 }
326}
327
#[cfg(all(feature = "s3", feature = "async"))]
#[async_trait::async_trait]
impl CloudStorageBackend for S3Backend {
    /// Downloads the object at `key` (relative to the configured prefix),
    /// retrying transient failures per `retry_config`.
    async fn get(&self, key: &str) -> Result<Bytes> {
        let mut executor = RetryExecutor::new(self.retry_config.clone());

        executor
            .execute(|| async {
                let client = self.create_client().await?;
                let full_key = self.full_key(key);

                let response = client
                    .get_object()
                    .bucket(&self.bucket)
                    .key(&full_key)
                    .send()
                    .await
                    .map_err(|e| {
                        CloudError::S3(S3Error::Sdk {
                            message: format!("Failed to get object '{full_key}': {e}"),
                        })
                    })?;

                let data = response.body.collect().await.map_err(|e| {
                    CloudError::S3(S3Error::Sdk {
                        message: format!("Failed to read object body: {e}"),
                    })
                })?;

                Ok(data.into_bytes())
            })
            .await
    }

    /// Uploads `data` to `key`, applying the configured storage class and SSE.
    ///
    /// Payloads larger than `multipart_threshold` go through the multipart
    /// path (which does not use the retry executor); smaller payloads use a
    /// single retried `PutObject`.
    async fn put(&self, key: &str, data: &[u8]) -> Result<()> {
        if data.len() > self.multipart_threshold {
            return self.upload_multipart(key, data).await;
        }

        let mut executor = RetryExecutor::new(self.retry_config.clone());

        executor
            .execute(|| async {
                let client = self.create_client().await?;
                let full_key = self.full_key(key);

                let mut request = client
                    .put_object()
                    .bucket(&self.bucket)
                    .key(&full_key)
                    .body(ByteStream::from(data.to_vec()))
                    .storage_class(self.storage_class.to_aws_storage_class());

                request = match &self.sse {
                    SseConfig::None => request,
                    SseConfig::Aes256 => {
                        request.server_side_encryption(ServerSideEncryption::Aes256)
                    }
                    SseConfig::Kms { key_id } => request
                        .server_side_encryption(ServerSideEncryption::AwsKms)
                        .ssekms_key_id(key_id),
                };

                request.send().await.map_err(|e| {
                    CloudError::S3(S3Error::Sdk {
                        message: format!("Failed to put object '{full_key}': {e}"),
                    })
                })?;

                Ok(())
            })
            .await
    }

    /// Deletes the object at `key`, retrying transient failures.
    async fn delete(&self, key: &str) -> Result<()> {
        let mut executor = RetryExecutor::new(self.retry_config.clone());

        executor
            .execute(|| async {
                let client = self.create_client().await?;
                let full_key = self.full_key(key);

                client
                    .delete_object()
                    .bucket(&self.bucket)
                    .key(&full_key)
                    .send()
                    .await
                    .map_err(|e| {
                        CloudError::S3(S3Error::Sdk {
                            message: format!("Failed to delete object '{full_key}': {e}"),
                        })
                    })?;

                Ok(())
            })
            .await
    }

    /// Returns whether the object at `key` exists, via `HeadObject`.
    ///
    /// NOTE(review): not-found detection string-matches the error text for
    /// "404"/"NotFound"; matching the typed service error would be more
    /// robust — confirm against the SDK version in use. Unlike `get`/`put`/
    /// `delete`, this call is not retried.
    async fn exists(&self, key: &str) -> Result<bool> {
        let client = self.create_client().await?;
        let full_key = self.full_key(key);

        match client
            .head_object()
            .bucket(&self.bucket)
            .key(&full_key)
            .send()
            .await
        {
            Ok(_) => Ok(true),
            Err(e) => {
                let error_message = format!("{e}");
                if error_message.contains("404") || error_message.contains("NotFound") {
                    Ok(false)
                } else {
                    Err(CloudError::S3(S3Error::Sdk {
                        message: format!("Failed to check object existence '{full_key}': {e}"),
                    }))
                }
            }
        }
    }

    /// Lists all object keys under `prefix` (relative to the configured
    /// prefix), following `ListObjectsV2` pagination. Returned keys have the
    /// backend's own prefix stripped.
    async fn list_prefix(&self, prefix: &str) -> Result<Vec<String>> {
        let client = self.create_client().await?;
        let full_prefix = self.full_key(prefix);

        // Computed once, outside the loop: the store-side prefix stripped
        // from each key so callers get keys relative to `self.prefix`.
        let strip = if self.prefix.is_empty() {
            None
        } else {
            Some(format!("{}/", self.prefix))
        };

        let mut results = Vec::new();
        let mut continuation_token: Option<String> = None;

        loop {
            let mut request = client
                .list_objects_v2()
                .bucket(&self.bucket)
                .prefix(&full_prefix);

            if let Some(ref token) = continuation_token {
                request = request.continuation_token(token);
            }

            let response = request.send().await.map_err(|e| {
                CloudError::S3(S3Error::Sdk {
                    message: format!("Failed to list objects with prefix '{full_prefix}': {e}"),
                })
            })?;

            if let Some(contents) = response.contents {
                for object in contents {
                    if let Some(key) = object.key {
                        let relative_key = match &strip {
                            Some(p) => key.strip_prefix(p.as_str()).unwrap_or(&key).to_string(),
                            None => key,
                        };
                        results.push(relative_key);
                    }
                }
            }

            // FIX: previously, a truncated response with a missing
            // continuation token looped forever, re-listing the first page.
            // Continue only when both the flag and the token are present.
            match (response.is_truncated, response.next_continuation_token) {
                (Some(true), Some(token)) => continuation_token = Some(token),
                _ => break,
            }
        }

        Ok(results)
    }

    /// S3 backends are always read-write.
    fn is_readonly(&self) -> bool {
        false
    }
}
507
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor stores bucket and prefix verbatim.
    #[test]
    fn test_s3_backend_new() {
        let b = S3Backend::new("my-bucket", "data/zarr");
        assert_eq!(b.bucket, "my-bucket");
        assert_eq!(b.prefix, "data/zarr");
    }

    /// Every chained `with_*` builder call records its setting on the value.
    #[test]
    fn test_s3_backend_builder() {
        let backend = S3Backend::new("my-bucket", "data")
            .with_region("us-west-2")
            .with_sse(SseConfig::Aes256)
            .with_storage_class(S3StorageClass::IntelligentTiering)
            .with_transfer_acceleration(true)
            .with_multipart_threshold(10 * 1024 * 1024)
            .with_timeout(Duration::from_secs(600));

        assert_eq!(backend.region.as_deref(), Some("us-west-2"));
        assert!(matches!(backend.sse, SseConfig::Aes256));
        assert!(matches!(
            backend.storage_class,
            S3StorageClass::IntelligentTiering
        ));
        assert!(backend.transfer_acceleration);
        assert_eq!(backend.multipart_threshold, 10 * 1024 * 1024);
        assert_eq!(backend.timeout, Duration::from_secs(600));
    }

    /// Keys are joined under the prefix with a slash; an empty prefix leaves
    /// keys untouched.
    #[test]
    fn test_s3_backend_full_key() {
        let with_prefix = S3Backend::new("bucket", "prefix");
        let without_prefix = S3Backend::new("bucket", "");

        assert_eq!(with_prefix.full_key("file.txt"), "prefix/file.txt");
        assert_eq!(without_prefix.full_key("file.txt"), "file.txt");
    }
}