1use crate::{FileStream, StorageLayer};
2use anyhow::Context;
3use aws_config::SdkConfig;
4use aws_sdk_s3::{
5 config::Credentials,
6 presigning::{PresignedRequest, PresigningConfig},
7 primitives::ByteStream,
8 types::{
9 BucketLocationConstraint, CorsConfiguration, CorsRule, CreateBucketConfiguration,
10 NotificationConfiguration, QueueConfiguration,
11 },
12};
13use bytes::Bytes;
14use chrono::{DateTime, TimeDelta, Utc};
15use futures::Stream;
16use serde::{Deserialize, Serialize};
17use std::{error::Error, time::Duration};
18
/// Alias for the AWS SDK S3 client used throughout this storage layer.
pub type S3Client = aws_sdk_s3::Client;
20
21#[derive(Debug, Clone, Deserialize, Serialize)]
22pub struct S3StorageLayerFactoryConfig {
23 pub endpoint: S3Endpoint,
24}
25
26impl S3StorageLayerFactoryConfig {
27 pub fn from_env() -> anyhow::Result<Self> {
28 let endpoint = S3Endpoint::from_env()?;
29
30 Ok(Self { endpoint })
31 }
32}
33
34#[derive(Debug, Clone, Deserialize, Serialize)]
35#[serde(tag = "type", rename_all = "snake_case")]
36pub enum S3Endpoint {
37 Aws,
38 Custom {
39 endpoint: String,
40 access_key_id: String,
41 access_key_secret: String,
42 },
43}
44
45impl S3Endpoint {
46 pub fn from_env() -> anyhow::Result<Self> {
47 match std::env::var("DOCBOX_S3_ENDPOINT") {
48 Ok(endpoint_url) => {
50 let access_key_id = std::env::var("DOCBOX_S3_ACCESS_KEY_ID").context(
51 "cannot use DOCBOX_S3_ENDPOINT without specifying DOCBOX_S3_ACCESS_KEY_ID",
52 )?;
53 let access_key_secret = std::env::var("DOCBOX_S3_ACCESS_KEY_SECRET").context(
54 "cannot use DOCBOX_S3_ENDPOINT without specifying DOCBOX_S3_ACCESS_KEY_SECRET",
55 )?;
56
57 Ok(S3Endpoint::Custom {
58 endpoint: endpoint_url,
59 access_key_id,
60 access_key_secret,
61 })
62 }
63 Err(_) => Ok(S3Endpoint::Aws),
64 }
65 }
66}
67
68#[derive(Clone)]
69pub struct S3StorageLayerFactory {
70 client: S3Client,
72}
73
74impl S3StorageLayerFactory {
75 pub fn from_config(aws_config: &SdkConfig, config: S3StorageLayerFactoryConfig) -> Self {
76 let client = match config.endpoint {
77 S3Endpoint::Aws => {
78 tracing::debug!("using aws s3 storage layer");
79 S3Client::new(aws_config)
80 }
81 S3Endpoint::Custom {
82 endpoint,
83 access_key_id,
84 access_key_secret,
85 } => {
86 tracing::debug!("using custom s3 storage layer");
87 let credentials = Credentials::new(
88 access_key_id,
89 access_key_secret,
90 None,
91 None,
92 "docbox_key_provider",
93 );
94
95 let config = aws_sdk_s3::config::Builder::from(aws_config)
97 .force_path_style(true)
98 .endpoint_url(endpoint)
99 .credentials_provider(credentials)
100 .build();
101 S3Client::from_conf(config)
102 }
103 };
104
105 Self { client }
106 }
107
108 pub fn create_storage_layer(&self, bucket_name: String) -> S3StorageLayer {
109 S3StorageLayer {
110 client: self.client.clone(),
111 bucket_name,
112 }
113 }
114}
115
116#[derive(Clone)]
117pub struct S3StorageLayer {
118 client: S3Client,
120
121 bucket_name: String,
123}
124
125impl S3StorageLayer {
126 pub fn new(client: S3Client, bucket_name: String) -> Self {
127 Self {
128 client,
129 bucket_name,
130 }
131 }
132}
133
134impl StorageLayer for S3StorageLayer {
135 async fn create_bucket(&self) -> anyhow::Result<()> {
136 let bucket_region = self
137 .client
138 .config()
139 .region()
140 .context("AWS config missing AWS_REGION")?
141 .to_string();
142
143 let constraint = BucketLocationConstraint::from(bucket_region.as_str());
144
145 let cfg = CreateBucketConfiguration::builder()
146 .location_constraint(constraint)
147 .build();
148
149 if let Err(err) = self
150 .client
151 .create_bucket()
152 .create_bucket_configuration(cfg)
153 .bucket(&self.bucket_name)
154 .send()
155 .await
156 {
157 let already_exists = err
158 .as_service_error()
159 .is_some_and(|value| value.is_bucket_already_owned_by_you());
160
161 if already_exists {
163 return Ok(());
164 }
165
166 tracing::error!(cause = ?err, "failed to create bucket");
167
168 return Err(err.into());
169 }
170
171 Ok(())
172 }
173
174 async fn delete_bucket(&self) -> anyhow::Result<()> {
175 self.client
176 .delete_bucket()
177 .bucket(&self.bucket_name)
178 .send()
179 .await
180 .context("failed to delete bucket")?;
181
182 Ok(())
183 }
184
185 async fn upload_file(
186 &self,
187 key: &str,
188 content_type: String,
189 body: Bytes,
190 ) -> anyhow::Result<()> {
191 self.client
192 .put_object()
193 .bucket(&self.bucket_name)
194 .content_type(content_type)
195 .key(key)
196 .body(body.into())
197 .send()
198 .await
199 .context("failed to store file in s3 bucket")?;
200
201 Ok(())
202 }
203
204 async fn create_presigned(
205 &self,
206 key: &str,
207 size: i64,
208 ) -> anyhow::Result<(PresignedRequest, DateTime<Utc>)> {
209 let expiry_time_minutes = 30;
210 let expires_at = Utc::now()
211 .checked_add_signed(TimeDelta::minutes(expiry_time_minutes))
212 .context("expiry time exceeds unix limit")?;
213
214 let result = self
215 .client
216 .put_object()
217 .bucket(&self.bucket_name)
218 .key(key)
219 .content_length(size)
220 .presigned(
221 PresigningConfig::builder()
222 .expires_in(Duration::from_secs(60 * expiry_time_minutes as u64))
223 .build()?,
224 )
225 .await
226 .context("failed to create presigned request")?;
227
228 Ok((result, expires_at))
229 }
230
231 async fn create_presigned_download(
232 &self,
233 key: &str,
234 expires_in: Duration,
235 ) -> anyhow::Result<(PresignedRequest, DateTime<Utc>)> {
236 let expires_at = Utc::now()
237 .checked_add_signed(TimeDelta::seconds(expires_in.as_secs() as i64))
238 .context("expiry time exceeds unix limit")?;
239
240 let result = self
241 .client
242 .get_object()
243 .bucket(&self.bucket_name)
244 .key(key)
245 .presigned(PresigningConfig::expires_in(expires_in)?)
246 .await?;
247
248 Ok((result, expires_at))
249 }
250
251 async fn add_bucket_notifications(&self, sqs_arn: &str) -> anyhow::Result<()> {
252 self.client
254 .put_bucket_notification_configuration()
255 .bucket(&self.bucket_name)
256 .notification_configuration(
257 NotificationConfiguration::builder()
258 .set_queue_configurations(Some(vec![
259 QueueConfiguration::builder()
260 .queue_arn(sqs_arn)
261 .events(aws_sdk_s3::types::Event::S3ObjectCreated)
262 .build()?,
263 ]))
264 .build(),
265 )
266 .send()
267 .await?;
268
269 Ok(())
270 }
271
272 async fn add_bucket_cors(&self, origins: Vec<String>) -> anyhow::Result<()> {
273 if let Err(cause) = self
274 .client
275 .put_bucket_cors()
276 .bucket(&self.bucket_name)
277 .cors_configuration(
278 CorsConfiguration::builder()
279 .cors_rules(
280 CorsRule::builder()
281 .allowed_headers("*")
282 .allowed_methods("PUT")
283 .set_allowed_origins(Some(origins))
284 .set_expose_headers(Some(Vec::new()))
285 .build()?,
286 )
287 .build()?,
288 )
289 .send()
290 .await
291 {
292 if cause
294 .raw_response()
295 .is_some_and(|response| response.status().as_u16() == 501)
297 {
298 tracing::warn!("storage s3 backend does not support PutBucketCors.. skipping..");
299 return Ok(());
300 }
301
302 return Err(cause.into());
303 };
304
305 Ok(())
306 }
307
308 async fn delete_file(&self, key: &str) -> anyhow::Result<()> {
309 if let Err(cause) = self
310 .client
311 .delete_object()
312 .bucket(&self.bucket_name)
313 .key(key)
314 .send()
315 .await
316 {
317 if cause
320 .as_service_error()
321 .and_then(|err| err.source())
322 .and_then(|source| source.downcast_ref::<aws_sdk_s3::Error>())
323 .is_some_and(|err| matches!(err, aws_sdk_s3::Error::NoSuchKey(_)))
324 {
325 return Ok(());
326 }
327
328 return Err(cause.into());
329 }
330
331 Ok(())
332 }
333
334 async fn get_file(&self, key: &str) -> anyhow::Result<FileStream> {
335 let object = self
336 .client
337 .get_object()
338 .bucket(&self.bucket_name)
339 .key(key)
340 .send()
341 .await?;
342
343 let stream = FileStream {
344 stream: Box::pin(AwsFileStream { inner: object.body }),
345 };
346
347 Ok(stream)
348 }
349}
350
351pub struct AwsFileStream {
352 inner: ByteStream,
353}
354
355impl AwsFileStream {
356 pub fn into_inner(self) -> ByteStream {
357 self.inner
358 }
359}
360
361impl Stream for AwsFileStream {
362 type Item = std::io::Result<Bytes>;
363
364 fn poll_next(
365 self: std::pin::Pin<&mut Self>,
366 cx: &mut std::task::Context<'_>,
367 ) -> std::task::Poll<Option<Self::Item>> {
368 let this = self.get_mut();
369 let inner = std::pin::Pin::new(&mut this.inner);
370 inner.poll_next(cx).map_err(std::io::Error::other)
371 }
372}