1use crate::{CreateBucketOutcome, FileStream, StorageLayerError, StorageLayerImpl};
13use aws_config::SdkConfig;
14use aws_sdk_s3::{
15 config::Credentials,
16 error::SdkError,
17 operation::{
18 create_bucket::CreateBucketError, delete_bucket::DeleteBucketError,
19 delete_object::DeleteObjectError, get_object::GetObjectError, head_bucket::HeadBucketError,
20 put_bucket_cors::PutBucketCorsError,
21 put_bucket_notification_configuration::PutBucketNotificationConfigurationError,
22 put_object::PutObjectError,
23 },
24 presigning::{PresignedRequest, PresigningConfig},
25 primitives::ByteStream,
26 types::{
27 BucketLocationConstraint, CorsConfiguration, CorsRule, CreateBucketConfiguration,
28 NotificationConfiguration, QueueConfiguration,
29 },
30};
31use bytes::Bytes;
32use chrono::{DateTime, TimeDelta, Utc};
33use futures::Stream;
34use serde::{Deserialize, Serialize};
35use std::{error::Error, fmt::Debug, time::Duration};
36use thiserror::Error;
37
/// Shorthand for the AWS SDK S3 client type used throughout this module.
type S3Client = aws_sdk_s3::Client;
39
/// Configuration for [`S3StorageLayerFactory`].
///
/// `#[serde(default)]` makes any missing field fall back to its `Default` value when
/// deserializing.
#[derive(Debug, Default, Clone, Deserialize, Serialize)]
#[serde(default)]
pub struct S3StorageLayerFactoryConfig {
    /// Which S3 endpoint the factory's clients talk to (AWS or a custom endpoint).
    pub endpoint: S3Endpoint,
}
47
/// Errors produced while loading [`S3StorageLayerFactoryConfig`] from the environment.
#[derive(Debug, Error)]
pub enum S3StorageLayerFactoryConfigError {
    /// `DOCBOX_S3_ENDPOINT` was set but `DOCBOX_S3_ACCESS_KEY_ID` could not be read.
    #[error("cannot use DOCBOX_S3_ENDPOINT without specifying DOCBOX_S3_ACCESS_KEY_ID")]
    MissingAccessKeyId,

    /// `DOCBOX_S3_ENDPOINT` was set but `DOCBOX_S3_ACCESS_KEY_SECRET` could not be read.
    #[error("cannot use DOCBOX_S3_ENDPOINT without specifying DOCBOX_S3_ACCESS_KEY_SECRET")]
    MissingAccessKeySecret,
}
59
60impl S3StorageLayerFactoryConfig {
61 pub fn from_env() -> Result<Self, S3StorageLayerFactoryConfigError> {
63 let endpoint = S3Endpoint::from_env()?;
64
65 Ok(Self { endpoint })
66 }
67}
68
/// Which S3 service the storage layer connects to.
///
/// Serialized as an internally tagged object whose `type` field is `"aws"` or `"custom"`.
#[derive(Default, Clone, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum S3Endpoint {
    /// Default: build the client straight from the shared AWS SDK configuration.
    #[default]
    Aws,
    /// A custom endpoint, reached with path-style addressing and static credentials.
    Custom {
        /// Endpoint URL used by the primary client.
        endpoint: String,
        /// Optional endpoint URL for a second client that is used to generate presigned URLs
        /// (presumably one reachable from outside the server's network — confirm).
        external_endpoint: Option<String>,
        /// Access key ID for the static credentials.
        access_key_id: String,
        /// Secret access key for the static credentials.
        access_key_secret: String,
    },
}
88
89impl Debug for S3Endpoint {
90 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91 match self {
92 Self::Aws => write!(f, "Aws"),
93 Self::Custom { endpoint, .. } => f
94 .debug_struct("Custom")
95 .field("endpoint", endpoint)
96 .finish(),
97 }
98 }
99}
100
101impl S3Endpoint {
102 pub fn from_env() -> Result<Self, S3StorageLayerFactoryConfigError> {
104 match std::env::var("DOCBOX_S3_ENDPOINT") {
105 Ok(endpoint_url) => {
107 let access_key_id = std::env::var("DOCBOX_S3_ACCESS_KEY_ID")
108 .map_err(|_| S3StorageLayerFactoryConfigError::MissingAccessKeyId)?;
109 let access_key_secret = std::env::var("DOCBOX_S3_ACCESS_KEY_SECRET")
110 .map_err(|_| S3StorageLayerFactoryConfigError::MissingAccessKeySecret)?;
111
112 let external_endpoint = std::env::var("DOCBOX_S3_EXTERNAL_ENDPOINT").ok();
113
114 Ok(S3Endpoint::Custom {
115 endpoint: endpoint_url,
116 external_endpoint,
117 access_key_id,
118 access_key_secret,
119 })
120 }
121 Err(_) => Ok(S3Endpoint::Aws),
122 }
123 }
124}
125
/// Creates [`S3StorageLayer`]s that share a pre-configured pair of S3 clients.
#[derive(Clone)]
pub struct S3StorageLayerFactory {
    /// Client used for regular bucket and object operations.
    client: S3Client,
    /// Optional client pointed at the external endpoint, handed to each storage layer for
    /// presigning.
    external_client: Option<S3Client>,
}
134
135impl S3StorageLayerFactory {
136 pub fn from_config(aws_config: &SdkConfig, config: S3StorageLayerFactoryConfig) -> Self {
138 let (client, external_client) = match config.endpoint {
139 S3Endpoint::Aws => {
140 tracing::debug!("using aws s3 storage layer");
141 (S3Client::new(aws_config), None)
142 }
143 S3Endpoint::Custom {
144 endpoint,
145 external_endpoint,
146 access_key_id,
147 access_key_secret,
148 } => {
149 tracing::debug!("using custom s3 storage layer");
150 let credentials = Credentials::new(
151 access_key_id,
152 access_key_secret,
153 None,
154 None,
155 "docbox_key_provider",
156 );
157
158 let config_builder = aws_sdk_s3::config::Builder::from(aws_config)
160 .force_path_style(true)
161 .endpoint_url(endpoint)
162 .credentials_provider(credentials);
163
164 let external_client = match external_endpoint {
166 Some(external_endpoint) => {
167 let config = config_builder
168 .clone()
169 .endpoint_url(external_endpoint)
170 .build();
171 let client = S3Client::from_conf(config);
172 Some(client)
173 }
174 None => None,
175 };
176
177 let config = config_builder.build();
178 let client = S3Client::from_conf(config);
179
180 (client, external_client)
181 }
182 };
183
184 Self {
185 client,
186 external_client,
187 }
188 }
189
190 pub fn create_storage_layer(&self, bucket_name: String) -> S3StorageLayer {
192 S3StorageLayer::new(
193 self.client.clone(),
194 self.external_client.clone(),
195 bucket_name,
196 )
197 }
198}
199
/// S3-backed storage layer bound to a single bucket.
#[derive(Clone)]
pub struct S3StorageLayer {
    /// Name of the bucket that every operation targets.
    bucket_name: String,

    /// Client used for bucket and object operations.
    client: S3Client,

    /// When set, this client (instead of `client`) is used to generate presigned URLs.
    external_client: Option<S3Client>,
}
212
213impl S3StorageLayer {
214 fn new(client: S3Client, external_client: Option<S3Client>, bucket_name: String) -> Self {
216 Self {
217 bucket_name,
218 client,
219 external_client,
220 }
221 }
222}
223
/// Errors produced by [`S3StorageLayer`] operations.
#[derive(Debug, Error)]
pub enum S3StorageError {
    /// The S3 client has no region configured; bucket creation needs one.
    #[error("invalid server configuration (region)")]
    MissingRegion,

    /// `CreateBucket` failed for a reason other than the bucket already being owned by us.
    #[error("failed to create storage bucket")]
    CreateBucket(SdkError<CreateBucketError>),

    /// `DeleteBucket` failed for a reason other than the bucket not existing.
    #[error("failed to delete storage bucket")]
    DeleteBucket(SdkError<DeleteBucketError>),

    /// `HeadBucket` failed for a reason other than the bucket not being found.
    #[error("failed to get storage bucket")]
    HeadBucket(SdkError<HeadBucketError>),

    /// `PutObject` failed while uploading a file.
    #[error("failed to store file object")]
    PutObject(SdkError<PutObjectError>),

    /// The expiry timestamp reported alongside a presigned URL could not be computed.
    #[error("failed to calculate expiry timestamp")]
    UnixTimeCalculation,

    /// Presigning a `PutObject` request failed.
    #[error("failed to create presigned store file object")]
    PutObjectPresigned(SdkError<PutObjectError>),

    /// The presigning configuration could not be built.
    #[error("failed to create presigned config")]
    PresignedConfig,

    /// Presigning a `GetObject` request failed.
    #[error("failed to get presigned store file object")]
    GetObjectPresigned(SdkError<GetObjectError>),

    /// The SQS queue notification configuration could not be built.
    #[error("failed to create bucket notification queue config")]
    QueueConfig,

    /// `PutBucketNotificationConfiguration` failed.
    #[error("failed to add bucket notification queue: {0}")]
    PutBucketNotification(SdkError<PutBucketNotificationConfigurationError>),

    /// The CORS rule or CORS configuration could not be built.
    #[error("failed to create bucket cors config")]
    CreateCorsConfig,

    /// `PutBucketCors` failed with something other than HTTP 501 (Not Implemented).
    #[error("failed to set bucket cors rules: {0}")]
    PutBucketCors(SdkError<PutBucketCorsError>),

    /// `DeleteObject` failed for a reason other than the key not existing.
    #[error("failed to delete file object")]
    DeleteObject(SdkError<DeleteObjectError>),

    /// `GetObject` failed while fetching a file.
    #[error("failed to get file storage object")]
    GetObject(SdkError<GetObjectError>),
}
298
299impl StorageLayerImpl for S3StorageLayer {
300 fn bucket_name(&self) -> String {
301 self.bucket_name.clone()
302 }
303
304 async fn create_bucket(&self) -> Result<CreateBucketOutcome, StorageLayerError> {
305 let bucket_region = self
306 .client
307 .config()
308 .region()
309 .ok_or(S3StorageError::MissingRegion)?
310 .to_string();
311
312 let constraint = BucketLocationConstraint::from(bucket_region.as_str());
313
314 let cfg = CreateBucketConfiguration::builder()
315 .location_constraint(constraint)
316 .build();
317
318 if let Err(error) = self
319 .client
320 .create_bucket()
321 .create_bucket_configuration(cfg)
322 .bucket(&self.bucket_name)
323 .send()
324 .await
325 {
326 let already_exists = error
327 .as_service_error()
328 .is_some_and(|value| value.is_bucket_already_owned_by_you());
329
330 if already_exists {
332 tracing::debug!("bucket already exists");
333 return Ok(CreateBucketOutcome::Existing);
334 }
335
336 tracing::error!(?error, "failed to create bucket");
337 return Err(S3StorageError::CreateBucket(error).into());
338 }
339
340 Ok(CreateBucketOutcome::New)
341 }
342
343 async fn bucket_exists(&self) -> Result<bool, StorageLayerError> {
344 if let Err(error) = self
345 .client
346 .head_bucket()
347 .bucket(&self.bucket_name)
348 .send()
349 .await
350 {
351 if error
353 .as_service_error()
354 .is_some_and(|error| error.is_not_found())
355 {
356 return Ok(false);
357 }
358
359 return Err(S3StorageError::HeadBucket(error).into());
360 }
361
362 Ok(true)
363 }
364
365 async fn delete_bucket(&self) -> Result<(), StorageLayerError> {
366 if let Err(error) = self
367 .client
368 .delete_bucket()
369 .bucket(&self.bucket_name)
370 .send()
371 .await
372 {
373 if error
376 .as_service_error()
377 .and_then(|err| err.meta().code())
378 .is_some_and(|code| code == "NoSuchBucket")
379 {
380 tracing::debug!("bucket did not exist");
381 return Ok(());
382 }
383
384 tracing::error!(?error, "failed to delete bucket");
385
386 return Err(S3StorageError::DeleteBucket(error).into());
387 }
388
389 Ok(())
390 }
391
392 async fn upload_file(
393 &self,
394 key: &str,
395 content_type: String,
396 body: Bytes,
397 ) -> Result<(), StorageLayerError> {
398 self.client
399 .put_object()
400 .bucket(&self.bucket_name)
401 .content_type(content_type)
402 .key(key)
403 .body(body.into())
404 .send()
405 .await
406 .map_err(|error| {
407 tracing::error!(?error, "failed to store file object");
408 S3StorageError::PutObject(error)
409 })?;
410
411 Ok(())
412 }
413
414 async fn create_presigned(
415 &self,
416 key: &str,
417 size: i64,
418 ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
419 let expiry_time_minutes = 30;
420 let expires_at = Utc::now()
421 .checked_add_signed(TimeDelta::minutes(expiry_time_minutes))
422 .ok_or(S3StorageError::UnixTimeCalculation)?;
423
424 let client = match self.external_client.as_ref() {
425 Some(external_client) => external_client,
426 None => &self.client,
427 };
428
429 let result = client
430 .put_object()
431 .bucket(&self.bucket_name)
432 .key(key)
433 .content_length(size)
434 .presigned(
435 PresigningConfig::builder()
436 .expires_in(Duration::from_secs(60 * expiry_time_minutes as u64))
437 .build()
438 .map_err(|error| {
439 tracing::error!(?error, "Failed to create presigned store config");
440 S3StorageError::PresignedConfig
441 })?,
442 )
443 .await
444 .map_err(|error| {
445 tracing::error!(?error, "failed to create presigned store file object");
446 S3StorageError::PutObjectPresigned(error)
447 })?;
448
449 Ok((result, expires_at))
450 }
451
452 async fn create_presigned_download(
453 &self,
454 key: &str,
455 expires_in: Duration,
456 ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
457 let expires_at = Utc::now()
458 .checked_add_signed(TimeDelta::seconds(expires_in.as_secs() as i64))
459 .ok_or(S3StorageError::UnixTimeCalculation)?;
460
461 let client = match self.external_client.as_ref() {
462 Some(external_client) => external_client,
463 None => &self.client,
464 };
465
466 let result = client
467 .get_object()
468 .bucket(&self.bucket_name)
469 .key(key)
470 .presigned(PresigningConfig::expires_in(expires_in).map_err(|error| {
471 tracing::error!(?error, "failed to create presigned download config");
472 S3StorageError::PresignedConfig
473 })?)
474 .await
475 .map_err(|error| {
476 tracing::error!(?error, "failed to create presigned download");
477 S3StorageError::GetObjectPresigned(error)
478 })?;
479
480 Ok((result, expires_at))
481 }
482
483 async fn add_bucket_notifications(&self, sqs_arn: &str) -> Result<(), StorageLayerError> {
484 self.client
486 .put_bucket_notification_configuration()
487 .bucket(&self.bucket_name)
488 .notification_configuration(
489 NotificationConfiguration::builder()
490 .set_queue_configurations(Some(vec![
491 QueueConfiguration::builder()
492 .queue_arn(sqs_arn)
493 .events(aws_sdk_s3::types::Event::S3ObjectCreated)
494 .build()
495 .map_err(|error| {
496 tracing::error!(
497 ?error,
498 "failed to create bucket notification queue config"
499 );
500 S3StorageError::QueueConfig
501 })?,
502 ]))
503 .build(),
504 )
505 .send()
506 .await
507 .map_err(|error| {
508 tracing::error!(?error, "failed to add bucket notification queue");
509 S3StorageError::PutBucketNotification(error)
510 })?;
511
512 Ok(())
513 }
514
515 async fn set_bucket_cors_origins(&self, origins: Vec<String>) -> Result<(), StorageLayerError> {
516 if let Err(error) = self
517 .client
518 .put_bucket_cors()
519 .bucket(&self.bucket_name)
520 .cors_configuration(
521 CorsConfiguration::builder()
522 .cors_rules(
523 CorsRule::builder()
524 .allowed_headers("*")
525 .allowed_methods("PUT")
526 .set_allowed_origins(Some(origins))
527 .set_expose_headers(Some(Vec::new()))
528 .build()
529 .map_err(|error| {
530 tracing::error!(?error, "failed to create cors rule");
531 S3StorageError::CreateCorsConfig
532 })?,
533 )
534 .build()
535 .map_err(|error| {
536 tracing::error!(?error, "failed to create cors config");
537 S3StorageError::CreateCorsConfig
538 })?,
539 )
540 .send()
541 .await
542 {
543 if error
545 .raw_response()
546 .is_some_and(|response| response.status().as_u16() == 501)
548 {
549 tracing::warn!("storage s3 backend does not support PutBucketCors.. skipping..");
550 return Ok(());
551 }
552
553 tracing::error!(?error, "failed to add bucket cors");
554 return Err(S3StorageError::PutBucketCors(error).into());
555 };
556
557 Ok(())
558 }
559
560 async fn delete_file(&self, key: &str) -> Result<(), StorageLayerError> {
561 if let Err(error) = self
562 .client
563 .delete_object()
564 .bucket(&self.bucket_name)
565 .key(key)
566 .send()
567 .await
568 {
569 if error
572 .as_service_error()
573 .and_then(|err| err.source())
574 .and_then(|source| source.downcast_ref::<aws_sdk_s3::Error>())
575 .is_some_and(|err| matches!(err, aws_sdk_s3::Error::NoSuchKey(_)))
576 {
577 return Ok(());
578 }
579
580 tracing::error!(?error, "failed to delete file object");
581 return Err(S3StorageError::DeleteObject(error).into());
582 }
583
584 Ok(())
585 }
586
587 async fn get_file(&self, key: &str) -> Result<FileStream, StorageLayerError> {
588 let object = self
589 .client
590 .get_object()
591 .bucket(&self.bucket_name)
592 .key(key)
593 .send()
594 .await
595 .map_err(|error| {
596 tracing::error!(?error, "failed to get file storage object");
597 S3StorageError::GetObject(error)
598 })?;
599
600 let stream = FileStream {
601 stream: Box::pin(AwsFileStream { inner: object.body }),
602 };
603
604 Ok(stream)
605 }
606}
607
/// Adapts an SDK [`ByteStream`] into a [`Stream`] of `std::io::Result<Bytes>` chunks.
pub struct AwsFileStream {
    /// Underlying SDK body stream.
    inner: ByteStream,
}
612
613impl AwsFileStream {
614 pub fn into_inner(self) -> ByteStream {
616 self.inner
617 }
618}
619
620impl Stream for AwsFileStream {
621 type Item = std::io::Result<Bytes>;
622
623 fn poll_next(
624 self: std::pin::Pin<&mut Self>,
625 cx: &mut std::task::Context<'_>,
626 ) -> std::task::Poll<Option<Self::Item>> {
627 let this = self.get_mut();
628 let inner = std::pin::Pin::new(&mut this.inner);
629 inner.poll_next(cx).map_err(std::io::Error::other)
630 }
631}