use anyhow::{Context, Result};
use aws_config::BehaviorVersion;
use aws_credential_types::Credentials;
use aws_sdk_s3::Client;
/// A file held fully in memory together with the S3 object key it is
/// associated with.
#[derive(Debug, Clone)]
pub struct S3File {
/// Raw contents of the file.
pub bytes: Vec<u8>,
/// S3 object key; its last `/`-separated segment is the default file name
/// used by `save_file` when no explicit path is given.
pub key: String,
}
impl S3File {
pub fn new(bytes: Vec<u8>, key: String) -> Self {
Self { bytes, key }
}
pub fn save_file(&self, file_path: Option<&str>) -> Result<String> {
let path = match file_path {
Some(p) => p.to_string(),
None => {
let filename = self
.key
.split('/')
.last()
.unwrap_or(&self.key)
.to_string();
filename
}
};
std::fs::write(&path, &self.bytes)
.context(format!("Failed to write file to: {}", path))?;
Ok(path)
}
pub fn as_bytes(&self) -> &[u8] {
&self.bytes
}
pub fn into_bytes(self) -> Vec<u8> {
self.bytes
}
pub fn key(&self) -> &str {
&self.key
}
}
/// Static AWS credentials plus a region, used to fetch objects from S3.
///
/// `Debug` is implemented by hand so that formatting a client with `{:?}`
/// (for example in logs) never prints the secret access key.
#[derive(Clone)]
pub struct S3Client {
    /// AWS access key ID.
    pub access_key_id: String,
    /// AWS secret access key. Redacted in the `Debug` output.
    pub secret_access_key: String,
    /// AWS region name, e.g. `us-east-1`.
    pub region: String,
}

impl std::fmt::Debug for S3Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("S3Client")
            .field("access_key_id", &self.access_key_id)
            // Never expose the secret through debug formatting.
            .field("secret_access_key", &"<redacted>")
            .field("region", &self.region)
            .finish()
    }
}
impl S3Client {
    /// Creates a client from static credentials and a region name.
    pub fn new(access_key_id: String, secret_access_key: String, region: String) -> Self {
        Self {
            access_key_id,
            secret_access_key,
            region,
        }
    }

    /// Downloads the object stored under `key` in `bucket_name` and returns it
    /// fully buffered in memory.
    ///
    /// An AWS SDK configuration and S3 client are built on every call from the
    /// stored credentials and region.
    ///
    /// # Errors
    ///
    /// Returns an error if the `GetObject` request fails or if the response
    /// body cannot be read to the end. Both errors name the bucket and key.
    pub async fn get_file_from_s3(&self, bucket_name: &str, key: &str) -> Result<S3File> {
        let credentials = Credentials::new(
            &self.access_key_id,
            &self.secret_access_key,
            None, // no session token
            None, // no expiry
            "static",
        );
        let config = aws_config::defaults(BehaviorVersion::latest())
            .credentials_provider(credentials)
            .region(aws_config::Region::new(self.region.clone()))
            .load()
            .await;
        let client = Client::new(&config);

        // Context messages are built lazily so the success path does not
        // allocate a string it never uses.
        let response = client
            .get_object()
            .bucket(bucket_name)
            .key(key)
            .send()
            .await
            .with_context(|| {
                format!(
                    "Failed to fetch object from S3: bucket={}, key={}",
                    bucket_name, key
                )
            })?;

        // Reading the streamed body can fail independently of the request
        // itself, so report which object was being read.
        let body = response.body.collect().await.with_context(|| {
            format!(
                "Failed to read object body from S3: bucket={}, key={}",
                bucket_name, key
            )
        })?;

        Ok(S3File::new(body.into_bytes().to_vec(), key.to_string()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Restores the process working directory to the stored path when dropped.
    ///
    /// The working directory is process-wide state shared by every test, so a
    /// test that changes it must put it back, even when an assertion panics.
    struct CwdGuard(std::path::PathBuf);

    impl Drop for CwdGuard {
        fn drop(&mut self) {
            // Best effort: a failure here must not turn into a panic while
            // the test is already unwinding.
            let _ = std::env::set_current_dir(&self.0);
        }
    }

    /// Fetches a real object from S3; needs valid credentials and a bucket.
    #[tokio::test]
    #[ignore = "requires real AWS credentials and an existing bucket"]
    async fn test_get_file_from_s3() {
        let client = S3Client::new(
            "YOUR_ACCESS_KEY".to_string(),
            "YOUR_SECRET_KEY".to_string(),
            "us-east-1".to_string(),
        );
        let file = client
            .get_file_from_s3("your-bucket", "test.txt")
            .await
            .unwrap();
        println!("{}", String::from_utf8_lossy(file.as_bytes()));
    }

    /// Fetches a real object from S3 and saves it to an explicit path.
    #[tokio::test]
    #[ignore = "requires real AWS credentials and an existing bucket"]
    async fn test_download_and_save() {
        use tempdir::TempDir;
        let client = S3Client::new(
            "YOUR_ACCESS_KEY".to_string(),
            "YOUR_SECRET_KEY".to_string(),
            "us-east-1".to_string(),
        );
        let file = client
            .get_file_from_s3("your-bucket", "test.txt")
            .await
            .unwrap();
        let temp_dir = TempDir::new("test").unwrap();
        let path = temp_dir.path().join("downloaded.txt");
        let saved_path = file.save_file(Some(path.to_str().unwrap())).unwrap();
        assert_eq!(saved_path, path.to_str().unwrap());
        assert!(std::path::Path::new(&saved_path).exists());
    }

    #[test]
    fn test_s3file_creation() {
        let bytes = vec![1, 2, 3, 4, 5];
        let file = S3File::new(bytes.clone(), "test/file.txt".to_string());
        assert_eq!(file.as_bytes(), &bytes);
        assert_eq!(file.key(), "test/file.txt");
        assert_eq!(file.clone().into_bytes(), bytes);
    }

    #[test]
    fn test_s3file_save_with_default_name() {
        use tempdir::TempDir;
        let temp_dir = TempDir::new("test").unwrap();
        // Declared after `temp_dir` so it is dropped first: the working
        // directory is restored before the temporary directory is deleted,
        // instead of leaving the process inside a removed directory.
        let _cwd_guard = CwdGuard(std::env::current_dir().unwrap());
        std::env::set_current_dir(temp_dir.path()).unwrap();
        let bytes = b"test content".to_vec();
        let file = S3File::new(bytes.clone(), "path/to/myfile.txt".to_string());
        let saved_path = file.save_file(None).unwrap();
        assert_eq!(saved_path, "myfile.txt");
        assert!(std::path::Path::new(&saved_path).exists());
        let content = std::fs::read(&saved_path).unwrap();
        assert_eq!(content, bytes);
    }

    #[test]
    fn test_s3file_save_with_custom_path() {
        use tempdir::TempDir;
        let bytes = b"test content".to_vec();
        let file = S3File::new(bytes.clone(), "original.txt".to_string());
        let temp_dir = TempDir::new("test").unwrap();
        let path = temp_dir.path().join("custom.txt");
        let saved_path = file.save_file(Some(path.to_str().unwrap())).unwrap();
        assert_eq!(saved_path, path.to_str().unwrap());
        let content = std::fs::read(&saved_path).unwrap();
        assert_eq!(content, bytes);
    }

    #[test]
    fn test_s3_client_creation() {
        let client = S3Client::new(
            "test_key".to_string(),
            "test_secret".to_string(),
            "us-west-2".to_string(),
        );
        assert_eq!(client.access_key_id, "test_key");
        assert_eq!(client.secret_access_key, "test_secret");
        assert_eq!(client.region, "us-west-2");
    }
}