1use bytes::Bytes;
8use std::time::Duration;
9
10#[cfg(feature = "s3")]
11use aws_config::BehaviorVersion;
12#[cfg(feature = "s3")]
13use aws_sdk_s3::{
14 Client, Config,
15 config::Region,
16 primitives::ByteStream,
17 types::{CompletedMultipartUpload, CompletedPart, ServerSideEncryption, StorageClass},
18};
19
20use crate::auth::Credentials;
21use crate::error::{CloudError, Result, S3Error};
22use crate::retry::{RetryConfig, RetryExecutor};
23
24use super::CloudStorageBackend;
25
/// Server-side encryption setting applied to objects written by the
/// S3 backend (used by both the single-shot and multipart upload paths).
#[derive(Debug, Clone)]
pub enum SseConfig {
    /// No server-side encryption requested on uploads.
    None,
    /// SSE-S3: AES-256 encryption with S3-managed keys.
    Aes256,
    /// SSE-KMS: encryption with an AWS KMS key.
    Kms {
        /// Identifier of the KMS key, forwarded as the SSE-KMS key id.
        key_id: String,
    },
}
39
/// Storage class assigned to uploaded objects.
///
/// Translated to the AWS SDK's `StorageClass` by
/// `S3StorageClass::to_aws_storage_class`.
#[derive(Debug, Clone)]
pub enum S3StorageClass {
    /// STANDARD storage.
    Standard,
    /// REDUCED_REDUNDANCY storage.
    ReducedRedundancy,
    /// STANDARD_IA (infrequent access).
    InfrequentAccess,
    /// ONEZONE_IA (single-AZ infrequent access).
    OneZoneInfrequentAccess,
    /// GLACIER archival storage.
    Glacier,
    /// DEEP_ARCHIVE archival storage.
    GlacierDeepArchive,
    /// INTELLIGENT_TIERING automatic tiering.
    IntelligentTiering,
}
58
59impl S3StorageClass {
60 #[cfg(feature = "s3")]
62 fn to_aws_storage_class(&self) -> StorageClass {
63 match self {
64 Self::Standard => StorageClass::Standard,
65 Self::ReducedRedundancy => StorageClass::ReducedRedundancy,
66 Self::InfrequentAccess => StorageClass::StandardIa,
67 Self::OneZoneInfrequentAccess => StorageClass::OnezoneIa,
68 Self::Glacier => StorageClass::Glacier,
69 Self::GlacierDeepArchive => StorageClass::DeepArchive,
70 Self::IntelligentTiering => StorageClass::IntelligentTiering,
71 }
72 }
73}
74
/// Cloud storage backend that reads and writes objects in an S3 bucket,
/// optionally under a key prefix.
///
/// Construct with `S3Backend::new` and refine with the `with_*` builder
/// methods.
#[derive(Debug, Clone)]
pub struct S3Backend {
    /// Name of the target S3 bucket.
    pub bucket: String,
    /// Key prefix joined to every object key with a `/`; may be empty.
    pub prefix: String,
    /// Region override; when `None` the default AWS config chain decides.
    pub region: Option<String>,
    /// Custom endpoint URL (also forces path-style addressing when set).
    pub endpoint: Option<String>,
    /// Server-side encryption applied to uploads.
    pub sse: SseConfig,
    /// Storage class applied to uploads.
    pub storage_class: S3StorageClass,
    /// Transfer-acceleration flag.
    /// NOTE(review): stored but never consulted by `create_client` — confirm
    /// whether it should be wired into the client configuration.
    pub transfer_acceleration: bool,
    /// Payload size (bytes) above which `put` switches to multipart upload.
    pub multipart_threshold: usize,
    /// Size (bytes) of each multipart chunk.
    pub multipart_chunk_size: usize,
    /// Request timeout.
    /// NOTE(review): stored but not currently applied to SDK requests — verify.
    pub timeout: Duration,
    /// Retry policy used by the read/write/delete operations.
    pub retry_config: RetryConfig,
    /// Explicit credentials.
    /// NOTE(review): stored but not currently used by `create_client`,
    /// which relies on the default provider chain — verify.
    pub credentials: Option<Credentials>,
}
103
104impl S3Backend {
105 pub const DEFAULT_MULTIPART_THRESHOLD: usize = 5 * 1024 * 1024;
107
108 pub const DEFAULT_MULTIPART_CHUNK_SIZE: usize = 5 * 1024 * 1024;
110
111 #[must_use]
117 pub fn new(bucket: impl Into<String>, prefix: impl Into<String>) -> Self {
118 Self {
119 bucket: bucket.into(),
120 prefix: prefix.into(),
121 region: None,
122 endpoint: None,
123 sse: SseConfig::None,
124 storage_class: S3StorageClass::Standard,
125 transfer_acceleration: false,
126 multipart_threshold: Self::DEFAULT_MULTIPART_THRESHOLD,
127 multipart_chunk_size: Self::DEFAULT_MULTIPART_CHUNK_SIZE,
128 timeout: Duration::from_secs(300),
129 retry_config: RetryConfig::default(),
130 credentials: None,
131 }
132 }
133
134 #[must_use]
136 pub fn with_region(mut self, region: impl Into<String>) -> Self {
137 self.region = Some(region.into());
138 self
139 }
140
141 #[must_use]
143 pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
144 self.endpoint = Some(endpoint.into());
145 self
146 }
147
148 #[must_use]
150 pub fn with_sse(mut self, sse: SseConfig) -> Self {
151 self.sse = sse;
152 self
153 }
154
155 #[must_use]
157 pub fn with_storage_class(mut self, storage_class: S3StorageClass) -> Self {
158 self.storage_class = storage_class;
159 self
160 }
161
162 #[must_use]
164 pub fn with_transfer_acceleration(mut self, enabled: bool) -> Self {
165 self.transfer_acceleration = enabled;
166 self
167 }
168
169 #[must_use]
171 pub fn with_multipart_threshold(mut self, threshold: usize) -> Self {
172 self.multipart_threshold = threshold;
173 self
174 }
175
176 #[must_use]
178 pub fn with_multipart_chunk_size(mut self, size: usize) -> Self {
179 self.multipart_chunk_size = size;
180 self
181 }
182
183 #[must_use]
185 pub fn with_timeout(mut self, timeout: Duration) -> Self {
186 self.timeout = timeout;
187 self
188 }
189
190 #[must_use]
192 pub fn with_retry_config(mut self, config: RetryConfig) -> Self {
193 self.retry_config = config;
194 self
195 }
196
197 #[must_use]
199 pub fn with_credentials(mut self, credentials: Credentials) -> Self {
200 self.credentials = Some(credentials);
201 self
202 }
203
204 fn full_key(&self, key: &str) -> String {
205 if self.prefix.is_empty() {
206 key.to_string()
207 } else {
208 format!("{}/{}", self.prefix, key)
209 }
210 }
211
212 #[cfg(feature = "s3")]
213 async fn create_client(&self) -> Result<Client> {
214 let mut config_loader = aws_config::defaults(BehaviorVersion::latest());
215
216 if let Some(ref region) = self.region {
217 config_loader = config_loader.region(Region::new(region.clone()));
218 }
219
220 let sdk_config = config_loader.load().await;
221
222 let mut s3_config_builder = Config::builder()
223 .behavior_version(BehaviorVersion::latest())
224 .region(sdk_config.region().cloned());
225
226 if let Some(ref endpoint) = self.endpoint {
227 s3_config_builder = s3_config_builder
228 .endpoint_url(endpoint)
229 .force_path_style(true);
230 }
231
232 let s3_config = s3_config_builder.build();
233 Ok(Client::from_conf(s3_config))
234 }
235
236 #[cfg(feature = "s3")]
237 async fn upload_multipart(&self, key: &str, data: &[u8]) -> Result<()> {
238 let client = self.create_client().await?;
239 let full_key = self.full_key(key);
240
241 let mut create_request = client
243 .create_multipart_upload()
244 .bucket(&self.bucket)
245 .key(&full_key)
246 .storage_class(self.storage_class.to_aws_storage_class());
247
248 create_request = match &self.sse {
250 SseConfig::None => create_request,
251 SseConfig::Aes256 => {
252 create_request.server_side_encryption(ServerSideEncryption::Aes256)
253 }
254 SseConfig::Kms { key_id } => create_request
255 .server_side_encryption(ServerSideEncryption::AwsKms)
256 .ssekms_key_id(key_id),
257 };
258
259 let multipart_upload = create_request.send().await.map_err(|e| {
260 CloudError::S3(S3Error::MultipartUpload {
261 message: format!("Failed to initiate multipart upload: {e}"),
262 })
263 })?;
264
265 let upload_id = multipart_upload.upload_id().ok_or_else(|| {
266 CloudError::S3(S3Error::MultipartUpload {
267 message: "No upload ID returned".to_string(),
268 })
269 })?;
270
271 let mut completed_parts = Vec::new();
273 let mut part_number = 1;
274
275 for chunk in data.chunks(self.multipart_chunk_size) {
276 let part = client
277 .upload_part()
278 .bucket(&self.bucket)
279 .key(&full_key)
280 .upload_id(upload_id)
281 .part_number(part_number)
282 .body(ByteStream::from(chunk.to_vec()))
283 .send()
284 .await
285 .map_err(|e| {
286 CloudError::S3(S3Error::MultipartUpload {
287 message: format!("Failed to upload part {part_number}: {e}"),
288 })
289 })?;
290
291 if let Some(etag) = part.e_tag() {
292 completed_parts.push(
293 CompletedPart::builder()
294 .e_tag(etag)
295 .part_number(part_number)
296 .build(),
297 );
298 }
299
300 part_number += 1;
301 }
302
303 let completed_upload = CompletedMultipartUpload::builder()
305 .set_parts(Some(completed_parts))
306 .build();
307
308 client
309 .complete_multipart_upload()
310 .bucket(&self.bucket)
311 .key(&full_key)
312 .upload_id(upload_id)
313 .multipart_upload(completed_upload)
314 .send()
315 .await
316 .map_err(|e| {
317 CloudError::S3(S3Error::MultipartUpload {
318 message: format!("Failed to complete multipart upload: {e}"),
319 })
320 })?;
321
322 Ok(())
323 }
324}
325
#[cfg(all(feature = "s3", feature = "async"))]
#[async_trait::async_trait]
impl CloudStorageBackend for S3Backend {
    /// Download the object at `key` (relative to the configured prefix),
    /// retrying per `retry_config`.
    async fn get(&self, key: &str) -> Result<Bytes> {
        let mut executor = RetryExecutor::new(self.retry_config.clone());

        executor
            .execute(|| async {
                let client = self.create_client().await?;
                let full_key = self.full_key(key);

                let response = client
                    .get_object()
                    .bucket(&self.bucket)
                    .key(&full_key)
                    .send()
                    .await
                    .map_err(|e| {
                        CloudError::S3(S3Error::Sdk {
                            message: format!("Failed to get object '{full_key}': {e}"),
                        })
                    })?;

                let data = response.body.collect().await.map_err(|e| {
                    CloudError::S3(S3Error::Sdk {
                        message: format!("Failed to read object body: {e}"),
                    })
                })?;

                Ok(data.into_bytes())
            })
            .await
    }

    /// Upload `data` to `key`. Payloads larger than `multipart_threshold`
    /// go through the multipart path; smaller ones are a single
    /// `PutObject` with the configured storage class and SSE settings,
    /// retried per `retry_config`.
    async fn put(&self, key: &str, data: &[u8]) -> Result<()> {
        // NOTE(review): the multipart path is not wrapped in the retry
        // executor, so a transient part failure aborts the whole upload —
        // confirm whether per-part retries are desired.
        if data.len() > self.multipart_threshold {
            return self.upload_multipart(key, data).await;
        }

        let mut executor = RetryExecutor::new(self.retry_config.clone());

        executor
            .execute(|| async {
                let client = self.create_client().await?;
                let full_key = self.full_key(key);

                let mut request = client
                    .put_object()
                    .bucket(&self.bucket)
                    .key(&full_key)
                    .body(ByteStream::from(data.to_vec()))
                    .storage_class(self.storage_class.to_aws_storage_class());

                request = match &self.sse {
                    SseConfig::None => request,
                    SseConfig::Aes256 => {
                        request.server_side_encryption(ServerSideEncryption::Aes256)
                    }
                    SseConfig::Kms { key_id } => request
                        .server_side_encryption(ServerSideEncryption::AwsKms)
                        .ssekms_key_id(key_id),
                };

                request.send().await.map_err(|e| {
                    CloudError::S3(S3Error::Sdk {
                        message: format!("Failed to put object '{full_key}': {e}"),
                    })
                })?;

                Ok(())
            })
            .await
    }

    /// Delete the object at `key`, retrying per `retry_config`.
    async fn delete(&self, key: &str) -> Result<()> {
        let mut executor = RetryExecutor::new(self.retry_config.clone());

        executor
            .execute(|| async {
                let client = self.create_client().await?;
                let full_key = self.full_key(key);

                client
                    .delete_object()
                    .bucket(&self.bucket)
                    .key(&full_key)
                    .send()
                    .await
                    .map_err(|e| {
                        CloudError::S3(S3Error::Sdk {
                            message: format!("Failed to delete object '{full_key}': {e}"),
                        })
                    })?;

                Ok(())
            })
            .await
    }

    /// Check whether the object at `key` exists via `HeadObject`.
    async fn exists(&self, key: &str) -> Result<bool> {
        let client = self.create_client().await?;
        let full_key = self.full_key(key);

        match client
            .head_object()
            .bucket(&self.bucket)
            .key(&full_key)
            .send()
            .await
        {
            Ok(_) => Ok(true),
            Err(e) => {
                // NOTE(review): detecting "not found" by substring-matching
                // the rendered error is brittle; prefer inspecting the typed
                // service error (e.g. the HeadObject not-found variant) —
                // confirm against the SDK version in use.
                let error_message = format!("{e}");
                if error_message.contains("404") || error_message.contains("NotFound") {
                    Ok(false)
                } else {
                    Err(CloudError::S3(S3Error::Sdk {
                        message: format!("Failed to check object existence '{full_key}': {e}"),
                    }))
                }
            }
        }
    }

    /// List all object keys under `prefix` (relative to the backend
    /// prefix), following `ListObjectsV2` pagination. Returned keys have
    /// the backend prefix stripped.
    async fn list_prefix(&self, prefix: &str) -> Result<Vec<String>> {
        let client = self.create_client().await?;
        let full_prefix = self.full_key(prefix);

        let mut results = Vec::new();
        let mut continuation_token: Option<String> = None;

        loop {
            let mut request = client
                .list_objects_v2()
                .bucket(&self.bucket)
                .prefix(&full_prefix);

            if let Some(ref token) = continuation_token {
                request = request.continuation_token(token);
            }

            let response = request.send().await.map_err(|e| {
                CloudError::S3(S3Error::Sdk {
                    message: format!("Failed to list objects with prefix '{full_prefix}': {e}"),
                })
            })?;

            if let Some(contents) = response.contents {
                for object in contents {
                    if let Some(key) = object.key {
                        // Strip "<prefix>/" so callers see backend-relative keys.
                        let relative_key = if !self.prefix.is_empty() {
                            key.strip_prefix(&format!("{}/", self.prefix))
                                .unwrap_or(&key)
                                .to_string()
                        } else {
                            key
                        };
                        results.push(relative_key);
                    }
                }
            }

            // Continue only while S3 reports truncation AND supplies a
            // token; a truncated response without a token would otherwise
            // re-request the first page forever.
            if response.is_truncated == Some(true) && response.next_continuation_token.is_some() {
                continuation_token = response.next_continuation_token;
            } else {
                break;
            }
        }

        Ok(results)
    }

    /// This backend supports writes.
    fn is_readonly(&self) -> bool {
        false
    }
}
505
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stores bucket and prefix verbatim.
    #[test]
    fn test_s3_backend_new() {
        let b = S3Backend::new("my-bucket", "data/zarr");
        assert_eq!(b.prefix, "data/zarr");
        assert_eq!(b.bucket, "my-bucket");
    }

    /// Each `with_*` builder sets its own independent field.
    #[test]
    fn test_s3_backend_builder() {
        let b = S3Backend::new("my-bucket", "data")
            .with_timeout(Duration::from_secs(600))
            .with_multipart_threshold(10 * 1024 * 1024)
            .with_transfer_acceleration(true)
            .with_storage_class(S3StorageClass::IntelligentTiering)
            .with_sse(SseConfig::Aes256)
            .with_region("us-west-2");

        assert_eq!(b.timeout, Duration::from_secs(600));
        assert_eq!(b.multipart_threshold, 10 * 1024 * 1024);
        assert!(b.transfer_acceleration);
        assert!(matches!(b.storage_class, S3StorageClass::IntelligentTiering));
        assert!(matches!(b.sse, SseConfig::Aes256));
        assert_eq!(b.region.as_deref(), Some("us-west-2"));
    }

    /// `full_key` joins with `/` only when a prefix is configured.
    #[test]
    fn test_s3_backend_full_key() {
        let with_prefix = S3Backend::new("bucket", "prefix");
        let without_prefix = S3Backend::new("bucket", "");
        assert_eq!(with_prefix.full_key("file.txt"), "prefix/file.txt");
        assert_eq!(without_prefix.full_key("file.txt"), "file.txt");
    }
}
546}