1use std::path::Path;
4
5use aws_config::BehaviorVersion;
6use aws_sdk_s3::primitives::ByteStream;
7use aws_sdk_s3::Client as S3Client;
8
9use crate::error::{EngramError, Result};
10
11pub struct CloudStorage {
13 client: S3Client,
14 bucket: String,
15 key: String,
16 encrypt: bool,
17 encryption_key: Option<Vec<u8>>,
18}
19
20impl CloudStorage {
21 pub async fn from_uri(uri: &str, encrypt: bool) -> Result<Self> {
23 let uri = uri
24 .strip_prefix("s3://")
25 .ok_or_else(|| EngramError::Config("URI must start with s3://".to_string()))?;
26
27 let parts: Vec<&str> = uri.splitn(2, '/').collect();
28 if parts.len() != 2 {
29 return Err(EngramError::Config(
30 "URI must be s3://bucket/path".to_string(),
31 ));
32 }
33
34 let bucket = parts[0].to_string();
35 let key = parts[1].to_string();
36
37 let config = aws_config::defaults(BehaviorVersion::latest()).load().await;
39 let client = S3Client::new(&config);
40
41 let encryption_key = if encrypt {
43 Some(generate_encryption_key()?)
44 } else {
45 None
46 };
47
48 Ok(Self {
49 client,
50 bucket,
51 key,
52 encrypt,
53 encryption_key,
54 })
55 }
56
57 pub async fn upload(&self, local_path: &Path) -> Result<u64> {
59 let data = tokio::fs::read(local_path).await?;
60 let size = data.len() as u64;
61
62 let body = if self.encrypt {
63 let encrypted = self.encrypt_data(&data)?;
64 ByteStream::from(encrypted)
65 } else {
66 ByteStream::from(data)
67 };
68
69 self.client
70 .put_object()
71 .bucket(&self.bucket)
72 .key(&self.key)
73 .body(body)
74 .send()
75 .await
76 .map_err(|e| EngramError::CloudStorage(e.to_string()))?;
77
78 tracing::info!(
79 "Uploaded {} bytes to s3://{}/{}",
80 size,
81 self.bucket,
82 self.key
83 );
84 Ok(size)
85 }
86
87 pub async fn download(&self, local_path: &Path) -> Result<u64> {
89 let response = self
90 .client
91 .get_object()
92 .bucket(&self.bucket)
93 .key(&self.key)
94 .send()
95 .await
96 .map_err(|e| EngramError::CloudStorage(e.to_string()))?;
97
98 let data = response
99 .body
100 .collect()
101 .await
102 .map_err(|e| EngramError::CloudStorage(e.to_string()))?
103 .into_bytes();
104
105 let decrypted = if self.encrypt {
106 self.decrypt_data(&data)?
107 } else {
108 data.to_vec()
109 };
110
111 let size = decrypted.len() as u64;
112
113 if let Some(parent) = local_path.parent() {
115 tokio::fs::create_dir_all(parent).await?;
116 }
117
118 tokio::fs::write(local_path, &decrypted).await?;
119
120 tracing::info!(
121 "Downloaded {} bytes from s3://{}/{}",
122 size,
123 self.bucket,
124 self.key
125 );
126 Ok(size)
127 }
128
129 pub async fn exists(&self) -> Result<bool> {
131 match self
132 .client
133 .head_object()
134 .bucket(&self.bucket)
135 .key(&self.key)
136 .send()
137 .await
138 {
139 Ok(_) => Ok(true),
140 Err(e) => {
141 let service_error = e.into_service_error();
142 if service_error.is_not_found() {
143 Ok(false)
144 } else {
145 Err(EngramError::CloudStorage(service_error.to_string()))
146 }
147 }
148 }
149 }
150
151 pub async fn metadata(&self) -> Result<CloudMetadata> {
153 let response = self
154 .client
155 .head_object()
156 .bucket(&self.bucket)
157 .key(&self.key)
158 .send()
159 .await
160 .map_err(|e| EngramError::CloudStorage(e.to_string()))?;
161
162 Ok(CloudMetadata {
163 size: response.content_length().unwrap_or(0) as u64,
164 last_modified: response.last_modified().map(|dt| dt.to_string()),
165 etag: response.e_tag().map(String::from),
166 })
167 }
168
169 pub async fn delete(&self) -> Result<()> {
171 self.client
172 .delete_object()
173 .bucket(&self.bucket)
174 .key(&self.key)
175 .send()
176 .await
177 .map_err(|e| EngramError::CloudStorage(e.to_string()))?;
178
179 Ok(())
180 }
181
182 fn encrypt_data(&self, data: &[u8]) -> Result<Vec<u8>> {
184 use aes_gcm::{
185 aead::{Aead, KeyInit},
186 Aes256Gcm, Nonce,
187 };
188 use rand::RngCore;
189
190 let key = self
191 .encryption_key
192 .as_ref()
193 .ok_or_else(|| EngramError::Encryption("No encryption key".to_string()))?;
194
195 let cipher =
196 Aes256Gcm::new_from_slice(key).map_err(|e| EngramError::Encryption(e.to_string()))?;
197
198 let mut nonce_bytes = [0u8; 12];
200 rand::thread_rng().fill_bytes(&mut nonce_bytes);
201 let nonce = Nonce::from_slice(&nonce_bytes);
202
203 let ciphertext = cipher
204 .encrypt(nonce, data)
205 .map_err(|e| EngramError::Encryption(e.to_string()))?;
206
207 let mut result = Vec::with_capacity(12 + ciphertext.len());
209 result.extend_from_slice(&nonce_bytes);
210 result.extend_from_slice(&ciphertext);
211
212 Ok(result)
213 }
214
215 fn decrypt_data(&self, data: &[u8]) -> Result<Vec<u8>> {
217 use aes_gcm::{
218 aead::{Aead, KeyInit},
219 Aes256Gcm, Nonce,
220 };
221
222 if data.len() < 12 {
223 return Err(EngramError::Encryption("Data too short".to_string()));
224 }
225
226 let key = self
227 .encryption_key
228 .as_ref()
229 .ok_or_else(|| EngramError::Encryption("No encryption key".to_string()))?;
230
231 let cipher =
232 Aes256Gcm::new_from_slice(key).map_err(|e| EngramError::Encryption(e.to_string()))?;
233
234 let nonce = Nonce::from_slice(&data[..12]);
235 let ciphertext = &data[12..];
236
237 let plaintext = cipher
238 .decrypt(nonce, ciphertext)
239 .map_err(|e| EngramError::Encryption(e.to_string()))?;
240
241 Ok(plaintext)
242 }
243}
244
/// Object metadata returned by [`CloudStorage::metadata`] from an S3 `HeadObject` call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CloudMetadata {
    /// Stored object size in bytes (`0` when S3 does not report it).
    ///
    /// For encrypted uploads this is the ciphertext size, not the plaintext size.
    pub size: u64,
    /// Last-modified timestamp rendered through its `Display` impl, if reported.
    pub last_modified: Option<String>,
    /// Entity tag exactly as reported by S3, if present.
    pub etag: Option<String>,
}
252
253fn generate_encryption_key() -> Result<Vec<u8>> {
255 use rand::RngCore;
256 let mut key = vec![0u8; 32];
257 rand::thread_rng().fill_bytes(&mut key);
258 Ok(key)
259}
260
261#[allow(dead_code)]
263pub fn derive_key_from_passphrase(passphrase: &str, salt: &[u8]) -> Result<Vec<u8>> {
264 use std::num::NonZeroU32;
265
266 let iterations = NonZeroU32::new(100_000).unwrap();
268 let mut key = vec![0u8; 32];
269
270 let mut hasher = std::collections::hash_map::DefaultHasher::new();
272 use std::hash::{Hash, Hasher};
273 for _ in 0..iterations.get() {
274 passphrase.hash(&mut hasher);
275 salt.hash(&mut hasher);
276 }
277 let hash = hasher.finish();
278 key[..8].copy_from_slice(&hash.to_le_bytes());
279
280 for i in 1..4 {
282 let mut h = std::collections::hash_map::DefaultHasher::new();
283 key[..i * 8].hash(&mut h);
284 passphrase.hash(&mut h);
285 let hash = h.finish();
286 key[i * 8..(i + 1) * 8].copy_from_slice(&hash.to_le_bytes());
287 }
288
289 Ok(key)
290}