1use aws_sdk_s3::Client;
6use aws_sdk_s3::primitives::ByteStream;
7use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart, StorageClass};
8use llm_shield_cloud::{
9 async_trait, CloudError, CloudStorage, GetObjectOptions, ObjectMetadata, PutObjectOptions,
10 Result,
11};
12use std::time::SystemTime;
13
/// Objects larger than this (5 MiB) are uploaded with the S3 multipart API
/// instead of a single `PutObject` request.
const MULTIPART_THRESHOLD: usize = 5 * 1024 * 1024;

/// Size of each part in a multipart upload (5 MiB — S3's minimum part size
/// for all parts except the last).
const MULTIPART_CHUNK_SIZE: usize = 5 * 1024 * 1024;
19
/// AWS S3 implementation of the `CloudStorage` trait.
///
/// Wraps an `aws_sdk_s3::Client` bound to a single bucket; large uploads
/// are transparently switched to the multipart API.
pub struct AwsS3Storage {
    /// Configured S3 client (credentials/region resolved at construction).
    client: Client,
    /// Bucket that all operations target.
    bucket: String,
    /// Resolved AWS region (falls back to "us-east-1" when the environment
    /// provides none — see `new`).
    region: String,
}
53
54impl AwsS3Storage {
55 pub async fn new(bucket: impl Into<String>) -> Result<Self> {
65 let config = aws_config::load_from_env().await;
66 let region = config
67 .region()
68 .map(|r| r.to_string())
69 .unwrap_or_else(|| "us-east-1".to_string());
70
71 let client = Client::new(&config);
72 let bucket = bucket.into();
73
74 tracing::info!(
75 "Initialized AWS S3 storage client for bucket: {} in region: {}",
76 bucket,
77 region
78 );
79
80 Ok(Self {
81 client,
82 bucket,
83 region,
84 })
85 }
86
87 pub async fn new_with_region(
98 bucket: impl Into<String>,
99 region: impl Into<String>,
100 ) -> Result<Self> {
101 let region_str = region.into();
102 let config = aws_config::from_env()
103 .region(aws_config::Region::new(region_str.clone()))
104 .load()
105 .await;
106
107 let client = Client::new(&config);
108 let bucket = bucket.into();
109
110 tracing::info!(
111 "Initialized AWS S3 storage client for bucket: {} in region: {}",
112 bucket,
113 region_str
114 );
115
116 Ok(Self {
117 client,
118 bucket,
119 region: region_str,
120 })
121 }
122
123 pub fn bucket(&self) -> &str {
125 &self.bucket
126 }
127
128 pub fn region(&self) -> &str {
130 &self.region
131 }
132
133 async fn put_object_multipart(&self, key: &str, data: &[u8]) -> Result<()> {
135 tracing::debug!(
136 "Starting multipart upload for key: {} ({} bytes)",
137 key,
138 data.len()
139 );
140
141 let multipart_upload = self
143 .client
144 .create_multipart_upload()
145 .bucket(&self.bucket)
146 .key(key)
147 .send()
148 .await
149 .map_err(|e| CloudError::storage_put(key, e.to_string()))?;
150
151 let upload_id = multipart_upload
152 .upload_id()
153 .ok_or_else(|| CloudError::storage_put(key, "No upload ID received"))?;
154
155 let mut completed_parts = Vec::new();
157 let mut part_number = 1;
158
159 for chunk in data.chunks(MULTIPART_CHUNK_SIZE) {
160 let upload_part_response = self
161 .client
162 .upload_part()
163 .bucket(&self.bucket)
164 .key(key)
165 .upload_id(upload_id)
166 .part_number(part_number)
167 .body(ByteStream::from(chunk.to_vec()))
168 .send()
169 .await
170 .map_err(|e| CloudError::storage_put(key, e.to_string()))?;
171
172 completed_parts.push(
173 CompletedPart::builder()
174 .part_number(part_number)
175 .e_tag(upload_part_response.e_tag().unwrap_or_default())
176 .build(),
177 );
178
179 part_number += 1;
180 }
181
182 let completed_upload = CompletedMultipartUpload::builder()
184 .set_parts(Some(completed_parts))
185 .build();
186
187 self.client
188 .complete_multipart_upload()
189 .bucket(&self.bucket)
190 .key(key)
191 .upload_id(upload_id)
192 .multipart_upload(completed_upload)
193 .send()
194 .await
195 .map_err(|e| CloudError::storage_put(key, e.to_string()))?;
196
197 tracing::info!("Successfully completed multipart upload for key: {}", key);
198
199 Ok(())
200 }
201}
202
203#[async_trait]
204impl CloudStorage for AwsS3Storage {
205 async fn get_object(&self, key: &str) -> Result<Vec<u8>> {
206 tracing::debug!("Fetching object from S3: {}", key);
207
208 let response = self
209 .client
210 .get_object()
211 .bucket(&self.bucket)
212 .key(key)
213 .send()
214 .await
215 .map_err(|e| CloudError::storage_get(key, e.to_string()))?;
216
217 let data = response
218 .body
219 .collect()
220 .await
221 .map_err(|e| CloudError::storage_get(key, e.to_string()))?
222 .into_bytes()
223 .to_vec();
224
225 tracing::info!("Successfully fetched object: {} ({} bytes)", key, data.len());
226
227 Ok(data)
228 }
229
230 async fn put_object(&self, key: &str, data: &[u8]) -> Result<()> {
231 tracing::debug!("Uploading object to S3: {} ({} bytes)", key, data.len());
232
233 if data.len() > MULTIPART_THRESHOLD {
235 return self.put_object_multipart(key, data).await;
236 }
237
238 self.client
240 .put_object()
241 .bucket(&self.bucket)
242 .key(key)
243 .body(ByteStream::from(data.to_vec()))
244 .send()
245 .await
246 .map_err(|e| CloudError::storage_put(key, e.to_string()))?;
247
248 tracing::info!("Successfully uploaded object: {}", key);
249
250 Ok(())
251 }
252
253 async fn delete_object(&self, key: &str) -> Result<()> {
254 tracing::debug!("Deleting object from S3: {}", key);
255
256 self.client
257 .delete_object()
258 .bucket(&self.bucket)
259 .key(key)
260 .send()
261 .await
262 .map_err(|e| CloudError::storage_delete(key, e.to_string()))?;
263
264 tracing::info!("Successfully deleted object: {}", key);
265
266 Ok(())
267 }
268
269 async fn list_objects(&self, prefix: &str) -> Result<Vec<String>> {
270 tracing::debug!("Listing objects in S3 with prefix: {}", prefix);
271
272 let mut object_keys = Vec::new();
273 let mut continuation_token: Option<String> = None;
274
275 loop {
276 let mut request = self
277 .client
278 .list_objects_v2()
279 .bucket(&self.bucket)
280 .prefix(prefix);
281
282 if let Some(token) = continuation_token {
283 request = request.continuation_token(token);
284 }
285
286 let response = request
287 .send()
288 .await
289 .map_err(|e| CloudError::StorageList {
290 prefix: prefix.to_string(),
291 error: e.to_string(),
292 })?;
293
294 for object in response.contents() {
295 if let Some(key) = object.key() {
296 object_keys.push(key.to_string());
297 }
298 }
299
300 continuation_token = response.next_continuation_token().map(String::from);
301
302 if continuation_token.is_none() {
303 break;
304 }
305 }
306
307 tracing::info!("Listed {} objects with prefix: {}", object_keys.len(), prefix);
308
309 Ok(object_keys)
310 }
311
312 async fn object_exists(&self, key: &str) -> Result<bool> {
313 tracing::debug!("Checking if object exists in S3: {}", key);
314
315 match self
316 .client
317 .head_object()
318 .bucket(&self.bucket)
319 .key(key)
320 .send()
321 .await
322 {
323 Ok(_) => {
324 tracing::debug!("Object exists: {}", key);
325 Ok(true)
326 }
327 Err(e) => {
328 let error_message = e.to_string();
329 if error_message.contains("404") || error_message.contains("NotFound") {
330 tracing::debug!("Object does not exist: {}", key);
331 Ok(false)
332 } else {
333 Err(CloudError::storage_get(key, error_message))
334 }
335 }
336 }
337 }
338
339 async fn get_object_metadata(&self, key: &str) -> Result<ObjectMetadata> {
340 tracing::debug!("Fetching object metadata from S3: {}", key);
341
342 let response = self
343 .client
344 .head_object()
345 .bucket(&self.bucket)
346 .key(key)
347 .send()
348 .await
349 .map_err(|e| CloudError::storage_get(key, e.to_string()))?;
350
351 let size = response.content_length().unwrap_or(0) as u64;
352 let last_modified = response
353 .last_modified()
354 .and_then(|dt| {
355 SystemTime::UNIX_EPOCH
356 .checked_add(std::time::Duration::from_secs(dt.secs() as u64))
357 })
358 .unwrap_or_else(SystemTime::now);
359
360 let content_type = response.content_type().map(String::from);
361 let etag = response.e_tag().map(String::from);
362 let storage_class = response.storage_class().map(|sc| sc.as_str().to_string());
363
364 tracing::debug!("Retrieved metadata for object: {} ({} bytes)", key, size);
365
366 Ok(ObjectMetadata {
367 size,
368 last_modified,
369 content_type,
370 etag,
371 storage_class,
372 })
373 }
374
375 async fn copy_object(&self, from_key: &str, to_key: &str) -> Result<()> {
376 tracing::debug!("Copying object in S3: {} -> {}", from_key, to_key);
377
378 let copy_source = format!("{}/{}", self.bucket, from_key);
379
380 self.client
381 .copy_object()
382 .bucket(&self.bucket)
383 .copy_source(©_source)
384 .key(to_key)
385 .send()
386 .await
387 .map_err(|e| CloudError::storage_put(to_key, e.to_string()))?;
388
389 tracing::info!("Successfully copied object: {} -> {}", from_key, to_key);
390
391 Ok(())
392 }
393
394 async fn get_object_with_options(
395 &self,
396 key: &str,
397 options: &GetObjectOptions,
398 ) -> Result<Vec<u8>> {
399 tracing::debug!("Fetching object from S3 with options: {}", key);
400
401 let mut request = self.client.get_object().bucket(&self.bucket).key(key);
402
403 if let Some((start, end)) = options.range {
404 let range_str = format!("bytes={}-{}", start, end);
405 request = request.range(range_str);
406 }
407
408 let response = request
409 .send()
410 .await
411 .map_err(|e| CloudError::storage_get(key, e.to_string()))?;
412
413 let data = response
414 .body
415 .collect()
416 .await
417 .map_err(|e| CloudError::storage_get(key, e.to_string()))?
418 .into_bytes()
419 .to_vec();
420
421 tracing::info!("Successfully fetched object with options: {}", key);
422
423 Ok(data)
424 }
425
426 async fn put_object_with_options(
427 &self,
428 key: &str,
429 data: &[u8],
430 options: &PutObjectOptions,
431 ) -> Result<()> {
432 tracing::debug!(
433 "Uploading object to S3 with options: {} ({} bytes)",
434 key,
435 data.len()
436 );
437
438 let mut request = self
441 .client
442 .put_object()
443 .bucket(&self.bucket)
444 .key(key)
445 .body(ByteStream::from(data.to_vec()));
446
447 if let Some(ref content_type) = options.content_type {
448 request = request.content_type(content_type.clone());
449 }
450
451 if let Some(ref storage_class_str) = options.storage_class {
452 if let Ok(storage_class) = storage_class_str.parse::<StorageClass>() {
453 request = request.storage_class(storage_class);
454 }
455 }
456
457 if let Some(ref encryption) = options.encryption {
458 request = request.server_side_encryption(
459 encryption
460 .parse()
461 .unwrap_or(aws_sdk_s3::types::ServerSideEncryption::Aes256),
462 );
463 }
464
465 for (key, value) in &options.metadata {
467 request = request.metadata(key.clone(), value.clone());
468 }
469
470 request
471 .send()
472 .await
473 .map_err(|e| CloudError::storage_put(key, e.to_string()))?;
474
475 tracing::info!("Successfully uploaded object with options: {}", key);
476
477 Ok(())
478 }
479
480 async fn delete_objects(&self, keys: &[String]) -> Result<()> {
481 tracing::debug!("Deleting {} objects from S3", keys.len());
482
483 if keys.is_empty() {
484 return Ok(());
485 }
486
487 for chunk in keys.chunks(1000) {
489 let object_identifiers: Vec<_> = chunk
490 .iter()
491 .map(|key| {
492 aws_sdk_s3::types::ObjectIdentifier::builder()
493 .key(key.clone())
494 .build()
495 .expect("Failed to build ObjectIdentifier")
496 })
497 .collect();
498
499 let delete_request = aws_sdk_s3::types::Delete::builder()
500 .set_objects(Some(object_identifiers))
501 .build()
502 .map_err(|e| CloudError::StorageDelete {
503 key: "batch".to_string(),
504 error: e.to_string(),
505 })?;
506
507 self.client
508 .delete_objects()
509 .bucket(&self.bucket)
510 .delete(delete_request)
511 .send()
512 .await
513 .map_err(|e| CloudError::StorageDelete {
514 key: "batch".to_string(),
515 error: e.to_string(),
516 })?;
517 }
518
519 tracing::info!("Successfully deleted {} objects", keys.len());
520
521 Ok(())
522 }
523
524 async fn list_objects_with_metadata(&self, prefix: &str) -> Result<Vec<ObjectMetadata>> {
525 tracing::debug!("Listing objects with metadata in S3, prefix: {}", prefix);
526
527 let mut object_metadata = Vec::new();
528 let mut continuation_token: Option<String> = None;
529
530 loop {
531 let mut request = self
532 .client
533 .list_objects_v2()
534 .bucket(&self.bucket)
535 .prefix(prefix);
536
537 if let Some(token) = continuation_token {
538 request = request.continuation_token(token);
539 }
540
541 let response = request
542 .send()
543 .await
544 .map_err(|e| CloudError::StorageList {
545 prefix: prefix.to_string(),
546 error: e.to_string(),
547 })?;
548
549 for object in response.contents() {
550 if let Some(key) = object.key() {
551 let size = object.size().unwrap_or(0) as u64;
552 let last_modified = object
553 .last_modified()
554 .and_then(|dt| {
555 SystemTime::UNIX_EPOCH.checked_add(
556 std::time::Duration::from_secs(dt.secs() as u64),
557 )
558 })
559 .unwrap_or_else(SystemTime::now);
560
561 let etag = object.e_tag().map(String::from);
562 let storage_class =
563 object.storage_class().map(|sc| sc.as_str().to_string());
564
565 object_metadata.push(ObjectMetadata {
566 size,
567 last_modified,
568 content_type: None, etag,
570 storage_class,
571 });
572 }
573 }
574
575 continuation_token = response.next_continuation_token().map(String::from);
576
577 if continuation_token.is_none() {
578 break;
579 }
580 }
581
582 tracing::info!(
583 "Listed {} objects with metadata, prefix: {}",
584 object_metadata.len(),
585 prefix
586 );
587
588 Ok(object_metadata)
589 }
590}
591
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multipart_threshold() {
        // Both constants are pinned to 5 MiB.
        let five_mib = 5 * 1024 * 1024;
        assert_eq!(MULTIPART_THRESHOLD, five_mib);
        assert_eq!(MULTIPART_CHUNK_SIZE, five_mib);
    }

    #[test]
    fn test_storage_bucket_region() {
        let bucket = "test-bucket";
        let region = "us-west-2";

        assert_eq!(bucket, "test-bucket");
        assert_eq!(region, "us-west-2");
    }

    #[test]
    fn test_copy_source_format() {
        // Mirrors the "bucket/key" copy-source string built by copy_object.
        let bucket = "my-bucket";
        let from_key = "path/to/source.txt";

        assert_eq!(
            format!("{}/{}", bucket, from_key),
            "my-bucket/path/to/source.txt"
        );
    }

    #[test]
    fn test_chunking_logic() {
        // 15 MiB of data splits into exactly three full 5 MiB parts.
        let data = vec![0u8; 15 * 1024 * 1024];
        let chunks: Vec<_> = data.chunks(MULTIPART_CHUNK_SIZE).collect();

        assert_eq!(chunks.len(), 3);
        for part in &chunks {
            assert_eq!(part.len(), 5 * 1024 * 1024);
        }
    }
}