1use crate::{CreateBucketOutcome, FileStream, StorageLayerError, StorageLayerImpl};
13use aws_config::SdkConfig;
14use aws_sdk_s3::{
15 config::Credentials,
16 error::SdkError,
17 operation::{
18 create_bucket::CreateBucketError, delete_bucket::DeleteBucketError,
19 delete_object::DeleteObjectError, get_object::GetObjectError, head_bucket::HeadBucketError,
20 put_bucket_cors::PutBucketCorsError,
21 put_bucket_notification_configuration::PutBucketNotificationConfigurationError,
22 put_object::PutObjectError,
23 },
24 presigning::{PresignedRequest, PresigningConfig},
25 primitives::ByteStream,
26 types::{
27 BucketLocationConstraint, CorsConfiguration, CorsRule, CreateBucketConfiguration,
28 NotificationConfiguration, QueueConfiguration,
29 },
30};
31use bytes::Bytes;
32use chrono::{DateTime, TimeDelta, Utc};
33use futures::Stream;
34use serde::{Deserialize, Serialize};
35use std::{error::Error, fmt::Debug, time::Duration};
36use thiserror::Error;
37
/// Shorthand for the AWS SDK S3 client type used by the factory and storage layer below.
type S3Client = aws_sdk_s3::Client;
39
#[derive(Debug, Clone, Deserialize, Serialize)]
/// Configuration consumed by [`S3StorageLayerFactory::from_config`].
pub struct S3StorageLayerFactoryConfig {
    /// Which endpoint the S3 client(s) should talk to: AWS itself or a custom endpoint.
    pub endpoint: S3Endpoint,
}
46
#[derive(Debug, Error)]
/// Errors produced while reading the S3 configuration from the environment.
pub enum S3StorageLayerFactoryConfigError {
    #[error("cannot use DOCBOX_S3_ENDPOINT without specifying DOCBOX_S3_ACCESS_KEY_ID")]
    /// `DOCBOX_S3_ENDPOINT` was set but `DOCBOX_S3_ACCESS_KEY_ID` could not be read.
    MissingAccessKeyId,

    #[error("cannot use DOCBOX_S3_ENDPOINT without specifying DOCBOX_S3_ACCESS_KEY_SECRET")]
    /// `DOCBOX_S3_ENDPOINT` was set but `DOCBOX_S3_ACCESS_KEY_SECRET` could not be read.
    MissingAccessKeySecret,
}
58
59impl S3StorageLayerFactoryConfig {
60 pub fn from_env() -> Result<Self, S3StorageLayerFactoryConfigError> {
62 let endpoint = S3Endpoint::from_env()?;
63
64 Ok(Self { endpoint })
65 }
66}
67
#[derive(Clone, Deserialize, Serialize)]
/// Selects the S3 endpoint the storage layer connects to.
///
/// Serialized with an internal `"type"` tag whose value is the snake_case
/// variant name.
#[serde(tag = "type", rename_all = "snake_case")]
pub enum S3Endpoint {
    /// Use AWS with the client built directly from the shared SDK configuration.
    Aws,
    /// Use a custom endpoint with explicitly provided credentials.
    Custom {
        /// URL of the endpoint the main client sends requests to.
        endpoint: String,
        /// Optional alternative endpoint URL. When present a second client is
        /// built against it, and that client is used to create presigned requests.
        external_endpoint: Option<String>,
        /// Access key ID used to build the client credentials.
        access_key_id: String,
        /// Access key secret used to build the client credentials.
        access_key_secret: String,
    },
}
86
87impl Debug for S3Endpoint {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 match self {
90 Self::Aws => write!(f, "Aws"),
91 Self::Custom { endpoint, .. } => f
92 .debug_struct("Custom")
93 .field("endpoint", endpoint)
94 .finish(),
95 }
96 }
97}
98
99impl S3Endpoint {
100 pub fn from_env() -> Result<Self, S3StorageLayerFactoryConfigError> {
102 match std::env::var("DOCBOX_S3_ENDPOINT") {
103 Ok(endpoint_url) => {
105 let access_key_id = std::env::var("DOCBOX_S3_ACCESS_KEY_ID")
106 .map_err(|_| S3StorageLayerFactoryConfigError::MissingAccessKeyId)?;
107 let access_key_secret = std::env::var("DOCBOX_S3_ACCESS_KEY_SECRET")
108 .map_err(|_| S3StorageLayerFactoryConfigError::MissingAccessKeySecret)?;
109
110 let external_endpoint = std::env::var("DOCBOX_S3_EXTERNAL_ENDPOINT").ok();
111
112 Ok(S3Endpoint::Custom {
113 endpoint: endpoint_url,
114 external_endpoint,
115 access_key_id,
116 access_key_secret,
117 })
118 }
119 Err(_) => Ok(S3Endpoint::Aws),
120 }
121 }
122}
123
#[derive(Clone)]
/// Holds the S3 client(s) and hands out per-bucket [`S3StorageLayer`] instances.
pub struct S3StorageLayerFactory {
    /// Client used for regular requests.
    client: S3Client,
    /// Optional client built against the external endpoint; only populated
    /// when a custom endpoint with an external endpoint URL is configured.
    external_client: Option<S3Client>,
}
132
133impl S3StorageLayerFactory {
134 pub fn from_config(aws_config: &SdkConfig, config: S3StorageLayerFactoryConfig) -> Self {
136 let (client, external_client) = match config.endpoint {
137 S3Endpoint::Aws => {
138 tracing::debug!("using aws s3 storage layer");
139 (S3Client::new(aws_config), None)
140 }
141 S3Endpoint::Custom {
142 endpoint,
143 external_endpoint,
144 access_key_id,
145 access_key_secret,
146 } => {
147 tracing::debug!("using custom s3 storage layer");
148 let credentials = Credentials::new(
149 access_key_id,
150 access_key_secret,
151 None,
152 None,
153 "docbox_key_provider",
154 );
155
156 let config_builder = aws_sdk_s3::config::Builder::from(aws_config)
158 .force_path_style(true)
159 .endpoint_url(endpoint)
160 .credentials_provider(credentials);
161
162 let external_client = match external_endpoint {
164 Some(external_endpoint) => {
165 let config = config_builder
166 .clone()
167 .endpoint_url(external_endpoint)
168 .build();
169 let client = S3Client::from_conf(config);
170 Some(client)
171 }
172 None => None,
173 };
174
175 let config = config_builder.build();
176 let client = S3Client::from_conf(config);
177
178 (client, external_client)
179 }
180 };
181
182 Self {
183 client,
184 external_client,
185 }
186 }
187
188 pub fn create_storage_layer(&self, bucket_name: String) -> S3StorageLayer {
190 S3StorageLayer::new(
191 self.client.clone(),
192 self.external_client.clone(),
193 bucket_name,
194 )
195 }
196}
197
#[derive(Clone)]
/// Storage layer operating on a single S3 bucket.
pub struct S3StorageLayer {
    /// Name of the bucket every operation targets.
    bucket_name: String,

    /// Client used for all direct bucket and object requests.
    client: S3Client,

    /// Optional client preferred over `client` when creating presigned requests.
    external_client: Option<S3Client>,
}
210
211impl S3StorageLayer {
212 fn new(client: S3Client, external_client: Option<S3Client>, bucket_name: String) -> Self {
214 Self {
215 bucket_name,
216 client,
217 external_client,
218 }
219 }
220}
221
#[derive(Debug, Error)]
/// Errors raised by the S3 storage layer.
///
/// Most display messages are generic; the wrapped SDK error is carried as a
/// payload. Only the variants whose message contains `{0}` include the SDK
/// error in their display output.
pub enum S3StorageError {
    #[error("invalid server configuration (region)")]
    /// The S3 client has no region configured.
    MissingRegion,

    #[error("failed to create storage bucket")]
    /// The CreateBucket request failed.
    CreateBucket(SdkError<CreateBucketError>),

    #[error("failed to delete storage bucket")]
    /// The DeleteBucket request failed.
    DeleteBucket(SdkError<DeleteBucketError>),

    #[error("failed to get storage bucket")]
    /// The HeadBucket request failed for a reason other than "not found".
    HeadBucket(SdkError<HeadBucketError>),

    #[error("failed to store file object")]
    /// The PutObject request failed.
    PutObject(SdkError<PutObjectError>),

    #[error("failed to calculate expiry timestamp")]
    /// Adding the expiry duration to the current time failed.
    UnixTimeCalculation,

    #[error("failed to create presigned store file object")]
    /// Presigning a PutObject request failed.
    PutObjectPresigned(SdkError<PutObjectError>),

    #[error("failed to create presigned config")]
    /// The presigning configuration could not be built.
    PresignedConfig,

    #[error("failed to get presigned store file object")]
    /// Presigning a GetObject request failed.
    GetObjectPresigned(SdkError<GetObjectError>),

    #[error("failed to create bucket notification queue config")]
    /// The queue configuration for bucket notifications could not be built.
    QueueConfig,

    #[error("failed to add bucket notification queue: {0}")]
    /// The PutBucketNotificationConfiguration request failed; the SDK error is
    /// included in the display message.
    PutBucketNotification(SdkError<PutBucketNotificationConfigurationError>),

    #[error("failed to create bucket cors config")]
    /// The CORS rule or CORS configuration could not be built.
    CreateCorsConfig,

    #[error("failed to set bucket cors rules: {0}")]
    /// The PutBucketCors request failed; the SDK error is included in the
    /// display message.
    PutBucketCors(SdkError<PutBucketCorsError>),

    #[error("failed to delete file object")]
    /// The DeleteObject request failed.
    DeleteObject(SdkError<DeleteObjectError>),

    #[error("failed to get file storage object")]
    /// The GetObject request failed.
    GetObject(SdkError<GetObjectError>),
}
296
297impl StorageLayerImpl for S3StorageLayer {
298 fn bucket_name(&self) -> String {
299 self.bucket_name.clone()
300 }
301
302 async fn create_bucket(&self) -> Result<CreateBucketOutcome, StorageLayerError> {
303 let bucket_region = self
304 .client
305 .config()
306 .region()
307 .ok_or(S3StorageError::MissingRegion)?
308 .to_string();
309
310 let constraint = BucketLocationConstraint::from(bucket_region.as_str());
311
312 let cfg = CreateBucketConfiguration::builder()
313 .location_constraint(constraint)
314 .build();
315
316 if let Err(error) = self
317 .client
318 .create_bucket()
319 .create_bucket_configuration(cfg)
320 .bucket(&self.bucket_name)
321 .send()
322 .await
323 {
324 let already_exists = error
325 .as_service_error()
326 .is_some_and(|value| value.is_bucket_already_owned_by_you());
327
328 if already_exists {
330 tracing::debug!("bucket already exists");
331 return Ok(CreateBucketOutcome::Existing);
332 }
333
334 tracing::error!(?error, "failed to create bucket");
335 return Err(S3StorageError::CreateBucket(error).into());
336 }
337
338 Ok(CreateBucketOutcome::New)
339 }
340
341 async fn bucket_exists(&self) -> Result<bool, StorageLayerError> {
342 if let Err(error) = self
343 .client
344 .head_bucket()
345 .bucket(&self.bucket_name)
346 .send()
347 .await
348 {
349 if error
351 .as_service_error()
352 .is_some_and(|error| error.is_not_found())
353 {
354 return Ok(false);
355 }
356
357 return Err(S3StorageError::HeadBucket(error).into());
358 }
359
360 Ok(true)
361 }
362
363 async fn delete_bucket(&self) -> Result<(), StorageLayerError> {
364 if let Err(error) = self
365 .client
366 .delete_bucket()
367 .bucket(&self.bucket_name)
368 .send()
369 .await
370 {
371 if error
374 .as_service_error()
375 .and_then(|err| err.meta().code())
376 .is_some_and(|code| code == "NoSuchBucket")
377 {
378 tracing::debug!("bucket did not exist");
379 return Ok(());
380 }
381
382 tracing::error!(?error, "failed to delete bucket");
383
384 return Err(S3StorageError::DeleteBucket(error).into());
385 }
386
387 Ok(())
388 }
389
390 async fn upload_file(
391 &self,
392 key: &str,
393 content_type: String,
394 body: Bytes,
395 ) -> Result<(), StorageLayerError> {
396 self.client
397 .put_object()
398 .bucket(&self.bucket_name)
399 .content_type(content_type)
400 .key(key)
401 .body(body.into())
402 .send()
403 .await
404 .map_err(|error| {
405 tracing::error!(?error, "failed to store file object");
406 S3StorageError::PutObject(error)
407 })?;
408
409 Ok(())
410 }
411
412 async fn create_presigned(
413 &self,
414 key: &str,
415 size: i64,
416 ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
417 let expiry_time_minutes = 30;
418 let expires_at = Utc::now()
419 .checked_add_signed(TimeDelta::minutes(expiry_time_minutes))
420 .ok_or(S3StorageError::UnixTimeCalculation)?;
421
422 let client = match self.external_client.as_ref() {
423 Some(external_client) => external_client,
424 None => &self.client,
425 };
426
427 let result = client
428 .put_object()
429 .bucket(&self.bucket_name)
430 .key(key)
431 .content_length(size)
432 .presigned(
433 PresigningConfig::builder()
434 .expires_in(Duration::from_secs(60 * expiry_time_minutes as u64))
435 .build()
436 .map_err(|error| {
437 tracing::error!(?error, "Failed to create presigned store config");
438 S3StorageError::PresignedConfig
439 })?,
440 )
441 .await
442 .map_err(|error| {
443 tracing::error!(?error, "failed to create presigned store file object");
444 S3StorageError::PutObjectPresigned(error)
445 })?;
446
447 Ok((result, expires_at))
448 }
449
450 async fn create_presigned_download(
451 &self,
452 key: &str,
453 expires_in: Duration,
454 ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
455 let expires_at = Utc::now()
456 .checked_add_signed(TimeDelta::seconds(expires_in.as_secs() as i64))
457 .ok_or(S3StorageError::UnixTimeCalculation)?;
458
459 let client = match self.external_client.as_ref() {
460 Some(external_client) => external_client,
461 None => &self.client,
462 };
463
464 let result = client
465 .get_object()
466 .bucket(&self.bucket_name)
467 .key(key)
468 .presigned(PresigningConfig::expires_in(expires_in).map_err(|error| {
469 tracing::error!(?error, "failed to create presigned download config");
470 S3StorageError::PresignedConfig
471 })?)
472 .await
473 .map_err(|error| {
474 tracing::error!(?error, "failed to create presigned download");
475 S3StorageError::GetObjectPresigned(error)
476 })?;
477
478 Ok((result, expires_at))
479 }
480
481 async fn add_bucket_notifications(&self, sqs_arn: &str) -> Result<(), StorageLayerError> {
482 self.client
484 .put_bucket_notification_configuration()
485 .bucket(&self.bucket_name)
486 .notification_configuration(
487 NotificationConfiguration::builder()
488 .set_queue_configurations(Some(vec![
489 QueueConfiguration::builder()
490 .queue_arn(sqs_arn)
491 .events(aws_sdk_s3::types::Event::S3ObjectCreated)
492 .build()
493 .map_err(|error| {
494 tracing::error!(
495 ?error,
496 "failed to create bucket notification queue config"
497 );
498 S3StorageError::QueueConfig
499 })?,
500 ]))
501 .build(),
502 )
503 .send()
504 .await
505 .map_err(|error| {
506 tracing::error!(?error, "failed to add bucket notification queue");
507 S3StorageError::PutBucketNotification(error)
508 })?;
509
510 Ok(())
511 }
512
513 async fn set_bucket_cors_origins(&self, origins: Vec<String>) -> Result<(), StorageLayerError> {
514 if let Err(error) = self
515 .client
516 .put_bucket_cors()
517 .bucket(&self.bucket_name)
518 .cors_configuration(
519 CorsConfiguration::builder()
520 .cors_rules(
521 CorsRule::builder()
522 .allowed_headers("*")
523 .allowed_methods("PUT")
524 .set_allowed_origins(Some(origins))
525 .set_expose_headers(Some(Vec::new()))
526 .build()
527 .map_err(|error| {
528 tracing::error!(?error, "failed to create cors rule");
529 S3StorageError::CreateCorsConfig
530 })?,
531 )
532 .build()
533 .map_err(|error| {
534 tracing::error!(?error, "failed to create cors config");
535 S3StorageError::CreateCorsConfig
536 })?,
537 )
538 .send()
539 .await
540 {
541 if error
543 .raw_response()
544 .is_some_and(|response| response.status().as_u16() == 501)
546 {
547 tracing::warn!("storage s3 backend does not support PutBucketCors.. skipping..");
548 return Ok(());
549 }
550
551 tracing::error!(?error, "failed to add bucket cors");
552 return Err(S3StorageError::PutBucketCors(error).into());
553 };
554
555 Ok(())
556 }
557
558 async fn delete_file(&self, key: &str) -> Result<(), StorageLayerError> {
559 if let Err(error) = self
560 .client
561 .delete_object()
562 .bucket(&self.bucket_name)
563 .key(key)
564 .send()
565 .await
566 {
567 if error
570 .as_service_error()
571 .and_then(|err| err.source())
572 .and_then(|source| source.downcast_ref::<aws_sdk_s3::Error>())
573 .is_some_and(|err| matches!(err, aws_sdk_s3::Error::NoSuchKey(_)))
574 {
575 return Ok(());
576 }
577
578 tracing::error!(?error, "failed to delete file object");
579 return Err(S3StorageError::DeleteObject(error).into());
580 }
581
582 Ok(())
583 }
584
585 async fn get_file(&self, key: &str) -> Result<FileStream, StorageLayerError> {
586 let object = self
587 .client
588 .get_object()
589 .bucket(&self.bucket_name)
590 .key(key)
591 .send()
592 .await
593 .map_err(|error| {
594 tracing::error!(?error, "failed to get file storage object");
595 S3StorageError::GetObject(error)
596 })?;
597
598 let stream = FileStream {
599 stream: Box::pin(AwsFileStream { inner: object.body }),
600 };
601
602 Ok(stream)
603 }
604}
605
/// Wrapper around an SDK [`ByteStream`] that is exposed as a [`Stream`] of
/// `std::io::Result<Bytes>` items.
pub struct AwsFileStream {
    /// The wrapped SDK byte stream.
    inner: ByteStream,
}
610
611impl AwsFileStream {
612 pub fn into_inner(self) -> ByteStream {
614 self.inner
615 }
616}
617
618impl Stream for AwsFileStream {
619 type Item = std::io::Result<Bytes>;
620
621 fn poll_next(
622 self: std::pin::Pin<&mut Self>,
623 cx: &mut std::task::Context<'_>,
624 ) -> std::task::Poll<Option<Self::Item>> {
625 let this = self.get_mut();
626 let inner = std::pin::Pin::new(&mut this.inner);
627 inner.poll_next(cx).map_err(std::io::Error::other)
628 }
629}