llm_shield_cloud_aws/
storage.rs

//! AWS S3 storage integration.
//!
//! Provides an implementation of the `CloudStorage` trait for AWS S3.

use aws_sdk_s3::Client;
use aws_sdk_s3::primitives::ByteStream;
use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart, StorageClass};
use llm_shield_cloud::{
    async_trait, CloudError, CloudStorage, GetObjectOptions, ObjectMetadata, PutObjectOptions,
    Result,
};
use std::time::SystemTime;

/// Threshold above which uploads switch to multipart (5 MiB).
const MULTIPART_THRESHOLD: usize = 5 * 1024 * 1024;

/// Part size for multipart uploads (5 MiB).
///
/// S3 requires every part except the last to be at least 5 MiB and allows
/// at most 10,000 parts per upload.
const MULTIPART_CHUNK_SIZE: usize = 5 * 1024 * 1024;

/// AWS S3 implementation of `CloudStorage`.
///
/// This implementation provides:
/// - Automatic multipart uploads for large objects (> 5 MiB)
/// - Support for all standard S3 storage classes
/// - Server-side encryption (SSE-S3, SSE-KMS)
/// - Object metadata and tagging
/// - Lifecycle policy integration
///
/// # Example
///
/// ```no_run
/// use llm_shield_cloud_aws::AwsS3Storage;
/// use llm_shield_cloud::CloudStorage;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let storage = AwsS3Storage::new("my-bucket").await?;
///
///     let data = b"Hello, S3!";
///     storage.put_object("test.txt", data).await?;
///
///     let retrieved = storage.get_object("test.txt").await?;
///     assert_eq!(data, retrieved.as_slice());
///
///     Ok(())
/// }
/// ```
pub struct AwsS3Storage {
    client: Client,
    bucket: String,
    region: String,
}

impl AwsS3Storage {
    /// Creates a new AWS S3 storage client with the default configuration.
    ///
    /// # Arguments
    ///
    /// * `bucket` - S3 bucket name
    ///
    /// # Errors
    ///
    /// Returns an error if the AWS configuration cannot be loaded.
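    ///
    /// # Example
    ///
    /// A minimal usage sketch (requires valid AWS credentials at runtime):
    ///
    /// ```no_run
    /// use llm_shield_cloud_aws::AwsS3Storage;
    ///
    /// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
    /// // Region and credentials are resolved from the environment.
    /// let storage = AwsS3Storage::new("my-bucket").await?;
    /// assert_eq!(storage.bucket(), "my-bucket");
    /// # Ok(())
    /// # }
    /// ```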
    pub async fn new(bucket: impl Into<String>) -> Result<Self> {
        let config = aws_config::load_from_env().await;
        let region = config
            .region()
            .map(|r| r.to_string())
            .unwrap_or_else(|| "us-east-1".to_string());

        let client = Client::new(&config);
        let bucket = bucket.into();

        tracing::info!(
            "Initialized AWS S3 storage client for bucket: {} in region: {}",
            bucket,
            region
        );

        Ok(Self {
            client,
            bucket,
            region,
        })
    }

    /// Creates a new AWS S3 storage client pinned to a specific region.
    ///
    /// # Arguments
    ///
    /// * `bucket` - S3 bucket name
    /// * `region` - AWS region (e.g., "us-east-1", "eu-west-1")
    ///
    /// # Errors
    ///
    /// Returns an error if the AWS configuration cannot be loaded.
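    ///
    /// # Example
    ///
    /// A minimal usage sketch (requires valid AWS credentials at runtime):
    ///
    /// ```no_run
    /// use llm_shield_cloud_aws::AwsS3Storage;
    ///
    /// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
    /// // Pins the client to eu-west-1 regardless of the environment's
    /// // default region.
    /// let storage = AwsS3Storage::new_with_region("my-bucket", "eu-west-1").await?;
    /// assert_eq!(storage.region(), "eu-west-1");
    /// # Ok(())
    /// # }
    /// ```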
    pub async fn new_with_region(
        bucket: impl Into<String>,
        region: impl Into<String>,
    ) -> Result<Self> {
        let region_str = region.into();
        let config = aws_config::from_env()
            .region(aws_config::Region::new(region_str.clone()))
            .load()
            .await;

        let client = Client::new(&config);
        let bucket = bucket.into();

        tracing::info!(
            "Initialized AWS S3 storage client for bucket: {} in region: {}",
            bucket,
            region_str
        );

        Ok(Self {
            client,
            bucket,
            region: region_str,
        })
    }

    /// Gets the bucket name this client is configured for.
    pub fn bucket(&self) -> &str {
        &self.bucket
    }

    /// Gets the AWS region this client is configured for.
    pub fn region(&self) -> &str {
        &self.region
    }

    /// Uploads a large object using multipart upload.
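    ///
    /// Follows the standard three-step S3 protocol: `CreateMultipartUpload`
    /// to obtain an upload ID, `UploadPart` for each `MULTIPART_CHUNK_SIZE`
    /// slice of the data, and `CompleteMultipartUpload` with the collected
    /// part numbers and ETags. If a part fails to upload, the upload is
    /// aborted on a best-effort basis so orphaned parts do not accrue
    /// storage charges.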
    async fn put_object_multipart(&self, key: &str, data: &[u8]) -> Result<()> {
        tracing::debug!(
            "Starting multipart upload for key: {} ({} bytes)",
            key,
            data.len()
        );

        // Initiate the multipart upload and obtain its upload ID
        let multipart_upload = self
            .client
            .create_multipart_upload()
            .bucket(&self.bucket)
            .key(key)
            .send()
            .await
            .map_err(|e| CloudError::storage_put(key, e.to_string()))?;

        let upload_id = multipart_upload
            .upload_id()
            .ok_or_else(|| CloudError::storage_put(key, "No upload ID received"))?;

        // Upload each chunk as a numbered part, collecting the ETags that
        // are required to complete the upload. Part numbers are 1-based.
        let mut completed_parts = Vec::new();
        let mut part_number = 1;

        for chunk in data.chunks(MULTIPART_CHUNK_SIZE) {
            let upload_part_response = match self
                .client
                .upload_part()
                .bucket(&self.bucket)
                .key(key)
                .upload_id(upload_id)
                .part_number(part_number)
                .body(ByteStream::from(chunk.to_vec()))
                .send()
                .await
            {
                Ok(response) => response,
                Err(e) => {
                    // Abort on a best-effort basis so already-uploaded parts
                    // do not linger and accrue storage charges.
                    let _ = self
                        .client
                        .abort_multipart_upload()
                        .bucket(&self.bucket)
                        .key(key)
                        .upload_id(upload_id)
                        .send()
                        .await;
                    return Err(CloudError::storage_put(key, e.to_string()));
                }
            };

            completed_parts.push(
                CompletedPart::builder()
                    .part_number(part_number)
                    .e_tag(upload_part_response.e_tag().unwrap_or_default())
                    .build(),
            );

            part_number += 1;
        }

        // Complete the multipart upload with the collected parts
        let completed_upload = CompletedMultipartUpload::builder()
            .set_parts(Some(completed_parts))
            .build();

        self.client
            .complete_multipart_upload()
            .bucket(&self.bucket)
            .key(key)
            .upload_id(upload_id)
            .multipart_upload(completed_upload)
            .send()
            .await
            .map_err(|e| CloudError::storage_put(key, e.to_string()))?;

        tracing::info!("Successfully completed multipart upload for key: {}", key);

        Ok(())
    }
}

#[async_trait]
impl CloudStorage for AwsS3Storage {
    async fn get_object(&self, key: &str) -> Result<Vec<u8>> {
        tracing::debug!("Fetching object from S3: {}", key);

        let response = self
            .client
            .get_object()
            .bucket(&self.bucket)
            .key(key)
            .send()
            .await
            .map_err(|e| CloudError::storage_get(key, e.to_string()))?;

        // Drain the streaming body into memory
        let data = response
            .body
            .collect()
            .await
            .map_err(|e| CloudError::storage_get(key, e.to_string()))?
            .into_bytes()
            .to_vec();

        tracing::info!("Successfully fetched object: {} ({} bytes)", key, data.len());

        Ok(data)
    }

    async fn put_object(&self, key: &str, data: &[u8]) -> Result<()> {
        tracing::debug!("Uploading object to S3: {} ({} bytes)", key, data.len());

        // Use multipart upload for large objects
        if data.len() > MULTIPART_THRESHOLD {
            return self.put_object_multipart(key, data).await;
        }

        // Single-part upload for small objects
        self.client
            .put_object()
            .bucket(&self.bucket)
            .key(key)
            .body(ByteStream::from(data.to_vec()))
            .send()
            .await
            .map_err(|e| CloudError::storage_put(key, e.to_string()))?;

        tracing::info!("Successfully uploaded object: {}", key);

        Ok(())
    }

    async fn delete_object(&self, key: &str) -> Result<()> {
        tracing::debug!("Deleting object from S3: {}", key);

        self.client
            .delete_object()
            .bucket(&self.bucket)
            .key(key)
            .send()
            .await
            .map_err(|e| CloudError::storage_delete(key, e.to_string()))?;

        tracing::info!("Successfully deleted object: {}", key);

        Ok(())
    }

    async fn list_objects(&self, prefix: &str) -> Result<Vec<String>> {
        tracing::debug!("Listing objects in S3 with prefix: {}", prefix);

        let mut object_keys = Vec::new();
        let mut continuation_token: Option<String> = None;

        // ListObjectsV2 returns at most 1000 keys per page; follow the
        // continuation token until the listing is exhausted.
        loop {
            let mut request = self
                .client
                .list_objects_v2()
                .bucket(&self.bucket)
                .prefix(prefix);

            if let Some(token) = continuation_token {
                request = request.continuation_token(token);
            }

            let response = request
                .send()
                .await
                .map_err(|e| CloudError::StorageList {
                    prefix: prefix.to_string(),
                    error: e.to_string(),
                })?;

            for object in response.contents() {
                if let Some(key) = object.key() {
                    object_keys.push(key.to_string());
                }
            }

            continuation_token = response.next_continuation_token().map(String::from);

            if continuation_token.is_none() {
                break;
            }
        }

        tracing::info!("Listed {} objects with prefix: {}", object_keys.len(), prefix);

        Ok(object_keys)
    }

    async fn object_exists(&self, key: &str) -> Result<bool> {
        tracing::debug!("Checking if object exists in S3: {}", key);

        match self
            .client
            .head_object()
            .bucket(&self.bucket)
            .key(key)
            .send()
            .await
        {
            Ok(_) => {
                tracing::debug!("Object exists: {}", key);
                Ok(true)
            }
            Err(e) => {
                // Distinguish "not found" from genuine failures via the typed
                // service error rather than brittle string matching.
                let service_error = e.into_service_error();
                if service_error.is_not_found() {
                    tracing::debug!("Object does not exist: {}", key);
                    Ok(false)
                } else {
                    Err(CloudError::storage_get(key, service_error.to_string()))
                }
            }
        }
    }

    async fn get_object_metadata(&self, key: &str) -> Result<ObjectMetadata> {
        tracing::debug!("Fetching object metadata from S3: {}", key);

        let response = self
            .client
            .head_object()
            .bucket(&self.bucket)
            .key(key)
            .send()
            .await
            .map_err(|e| CloudError::storage_get(key, e.to_string()))?;

        let size = response.content_length().unwrap_or(0) as u64;
        // Convert the S3 timestamp (seconds since the Unix epoch) to SystemTime
        let last_modified = response
            .last_modified()
            .and_then(|dt| {
                SystemTime::UNIX_EPOCH
                    .checked_add(std::time::Duration::from_secs(dt.secs() as u64))
            })
            .unwrap_or_else(SystemTime::now);

        let content_type = response.content_type().map(String::from);
        let etag = response.e_tag().map(String::from);
        let storage_class = response.storage_class().map(|sc| sc.as_str().to_string());

        tracing::debug!("Retrieved metadata for object: {} ({} bytes)", key, size);

        Ok(ObjectMetadata {
            size,
            last_modified,
            content_type,
            etag,
            storage_class,
        })
    }

    async fn copy_object(&self, from_key: &str, to_key: &str) -> Result<()> {
        tracing::debug!("Copying object in S3: {} -> {}", from_key, to_key);

        // CopySource is "{bucket}/{key}". Note: the key is not URL-encoded
        // here, so keys containing characters that require percent-encoding
        // would need to be encoded first.
        let copy_source = format!("{}/{}", self.bucket, from_key);

        self.client
            .copy_object()
            .bucket(&self.bucket)
            .copy_source(&copy_source)
            .key(to_key)
            .send()
            .await
            .map_err(|e| CloudError::storage_put(to_key, e.to_string()))?;

        tracing::info!("Successfully copied object: {} -> {}", from_key, to_key);

        Ok(())
    }

    async fn get_object_with_options(
        &self,
        key: &str,
        options: &GetObjectOptions,
    ) -> Result<Vec<u8>> {
        tracing::debug!("Fetching object from S3 with options: {}", key);

        let mut request = self.client.get_object().bucket(&self.bucket).key(key);

        // HTTP Range requests are inclusive on both ends: "bytes=start-end"
        if let Some((start, end)) = options.range {
            let range_str = format!("bytes={}-{}", start, end);
            request = request.range(range_str);
        }

        let response = request
            .send()
            .await
            .map_err(|e| CloudError::storage_get(key, e.to_string()))?;

        let data = response
            .body
            .collect()
            .await
            .map_err(|e| CloudError::storage_get(key, e.to_string()))?
            .into_bytes()
            .to_vec();

        tracing::info!("Successfully fetched object with options: {}", key);

        Ok(data)
    }

    async fn put_object_with_options(
        &self,
        key: &str,
        data: &[u8],
        options: &PutObjectOptions,
    ) -> Result<()> {
        tracing::debug!(
            "Uploading object to S3 with options: {} ({} bytes)",
            key,
            data.len()
        );

        // Even large objects with options use single-part upload for now;
        // multipart with options is more complex and can be added later.
        let mut request = self
            .client
            .put_object()
            .bucket(&self.bucket)
            .key(key)
            .body(ByteStream::from(data.to_vec()));

        if let Some(ref content_type) = options.content_type {
            request = request.content_type(content_type.clone());
        }

        if let Some(ref storage_class_str) = options.storage_class {
            if let Ok(storage_class) = storage_class_str.parse::<StorageClass>() {
                request = request.storage_class(storage_class);
            }
        }

        if let Some(ref encryption) = options.encryption {
            // Unrecognized encryption values fall back to SSE-S3 (AES256)
            request = request.server_side_encryption(
                encryption
                    .parse()
                    .unwrap_or(aws_sdk_s3::types::ServerSideEncryption::Aes256),
            );
        }

        // Attach user-defined metadata; `meta_key` avoids shadowing the
        // object key parameter.
        for (meta_key, meta_value) in &options.metadata {
            request = request.metadata(meta_key.clone(), meta_value.clone());
        }

        request
            .send()
            .await
            .map_err(|e| CloudError::storage_put(key, e.to_string()))?;

        tracing::info!("Successfully uploaded object with options: {}", key);

        Ok(())
    }

    async fn delete_objects(&self, keys: &[String]) -> Result<()> {
        tracing::debug!("Deleting {} objects from S3", keys.len());

        if keys.is_empty() {
            return Ok(());
        }

        // S3 DeleteObjects accepts at most 1000 keys per request
        for chunk in keys.chunks(1000) {
            let object_identifiers = chunk
                .iter()
                .map(|key| {
                    aws_sdk_s3::types::ObjectIdentifier::builder()
                        .key(key.clone())
                        .build()
                        .map_err(|e| CloudError::StorageDelete {
                            key: key.clone(),
                            error: e.to_string(),
                        })
                })
                .collect::<Result<Vec<_>>>()?;

            let delete_request = aws_sdk_s3::types::Delete::builder()
                .set_objects(Some(object_identifiers))
                .build()
                .map_err(|e| CloudError::StorageDelete {
                    key: "batch".to_string(),
                    error: e.to_string(),
                })?;

            let response = self
                .client
                .delete_objects()
                .bucket(&self.bucket)
                .delete(delete_request)
                .send()
                .await
                .map_err(|e| CloudError::StorageDelete {
                    key: "batch".to_string(),
                    error: e.to_string(),
                })?;

            // DeleteObjects can succeed overall while individual keys fail;
            // surface the first per-object error instead of ignoring them.
            if let Some(first_error) = response.errors().first() {
                return Err(CloudError::StorageDelete {
                    key: first_error.key().unwrap_or("unknown").to_string(),
                    error: first_error.message().unwrap_or("unknown error").to_string(),
                });
            }
        }

        tracing::info!("Successfully deleted {} objects", keys.len());

        Ok(())
    }

    async fn list_objects_with_metadata(&self, prefix: &str) -> Result<Vec<ObjectMetadata>> {
        tracing::debug!("Listing objects with metadata in S3, prefix: {}", prefix);

        let mut object_metadata = Vec::new();
        let mut continuation_token: Option<String> = None;

        loop {
            let mut request = self
                .client
                .list_objects_v2()
                .bucket(&self.bucket)
                .prefix(prefix);

            if let Some(token) = continuation_token {
                request = request.continuation_token(token);
            }

            let response = request
                .send()
                .await
                .map_err(|e| CloudError::StorageList {
                    prefix: prefix.to_string(),
                    error: e.to_string(),
                })?;

            for object in response.contents() {
                // Skip entries without a key; `is_some` avoids an unused binding
                if object.key().is_some() {
                    let size = object.size().unwrap_or(0) as u64;
                    let last_modified = object
                        .last_modified()
                        .and_then(|dt| {
                            SystemTime::UNIX_EPOCH.checked_add(
                                std::time::Duration::from_secs(dt.secs() as u64),
                            )
                        })
                        .unwrap_or_else(SystemTime::now);

                    let etag = object.e_tag().map(String::from);
                    let storage_class =
                        object.storage_class().map(|sc| sc.as_str().to_string());

                    object_metadata.push(ObjectMetadata {
                        size,
                        last_modified,
                        content_type: None, // Not available in list responses
                        etag,
                        storage_class,
                    });
                }
            }

            continuation_token = response.next_continuation_token().map(String::from);

            if continuation_token.is_none() {
                break;
            }
        }

        tracing::info!(
            "Listed {} objects with metadata, prefix: {}",
            object_metadata.len(),
            prefix
        );

        Ok(object_metadata)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multipart_threshold() {
        assert_eq!(MULTIPART_THRESHOLD, 5 * 1024 * 1024);
        assert_eq!(MULTIPART_CHUNK_SIZE, 5 * 1024 * 1024);
    }

    #[test]
    fn test_chunking_remainder() {
        // Getter and round-trip coverage requires real AWS credentials and
        // lives in integration tests; here we cover the pure chunking logic.
        // 12 MiB splits into 5 MiB + 5 MiB + 2 MiB: S3 permits only the
        // final part of a multipart upload to be under the 5 MiB minimum.
        let data = vec![0u8; 12 * 1024 * 1024];
        let chunks: Vec<_> = data.chunks(MULTIPART_CHUNK_SIZE).collect();

        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[2].len(), 2 * 1024 * 1024);
    }

    #[test]
    fn test_copy_source_format() {
        let bucket = "my-bucket";
        let from_key = "path/to/source.txt";
        let expected = format!("{}/{}", bucket, from_key);

        assert_eq!(expected, "my-bucket/path/to/source.txt");
    }

    #[test]
    fn test_chunking_logic() {
        let data = vec![0u8; 15 * 1024 * 1024]; // 15 MiB
        let chunks: Vec<_> = data.chunks(MULTIPART_CHUNK_SIZE).collect();

        // Should be split into 3 chunks: 5 MiB + 5 MiB + 5 MiB
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].len(), 5 * 1024 * 1024);
        assert_eq!(chunks[1].len(), 5 * 1024 * 1024);
        assert_eq!(chunks[2].len(), 5 * 1024 * 1024);
    }
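
    // Pins down the inclusive `bytes=start-end` Range header format that
    // `get_object_with_options` builds for ranged reads.
    #[test]
    fn test_range_header_format() {
        let (start, end) = (0, 1023);
        assert_eq!(format!("bytes={}-{}", start, end), "bytes=0-1023");
    }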
}