//! docbox_storage/s3.rs — S3-backed implementation of the storage layer.

1use crate::{FileStream, StorageLayer};
2use anyhow::Context;
3use aws_config::SdkConfig;
4use aws_sdk_s3::{
5    config::Credentials,
6    presigning::{PresignedRequest, PresigningConfig},
7    primitives::ByteStream,
8    types::{
9        BucketLocationConstraint, CorsConfiguration, CorsRule, CreateBucketConfiguration,
10        NotificationConfiguration, QueueConfiguration,
11    },
12};
13use bytes::Bytes;
14use chrono::{DateTime, TimeDelta, Utc};
15use futures::Stream;
16use serde::{Deserialize, Serialize};
17use std::{error::Error, time::Duration};
18
/// Convenience alias for the AWS SDK S3 client used throughout this module.
pub type S3Client = aws_sdk_s3::Client;
20
/// Configuration required to construct a [`S3StorageLayerFactory`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct S3StorageLayerFactoryConfig {
    /// Which S3 endpoint to target (AWS itself or a custom S3-compatible server)
    pub endpoint: S3Endpoint,
}
25
26impl S3StorageLayerFactoryConfig {
27    pub fn from_env() -> anyhow::Result<Self> {
28        let endpoint = S3Endpoint::from_env()?;
29
30        Ok(Self { endpoint })
31    }
32}
33
/// Selects which S3 endpoint the storage layer talks to.
///
/// Serialized with an internal `type` tag in snake_case,
/// e.g. `{"type": "aws"}` or `{"type": "custom", ...}`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum S3Endpoint {
    /// Default AWS S3 endpoint resolved from the ambient SDK configuration
    Aws,
    /// Custom S3-compatible endpoint (e.g. minio) with static credentials
    Custom {
        /// Base URL of the S3-compatible service
        endpoint: String,
        /// Static access key ID for the service
        access_key_id: String,
        /// Static access key secret for the service
        access_key_secret: String,
    },
}
44
45impl S3Endpoint {
46    pub fn from_env() -> anyhow::Result<Self> {
47        match std::env::var("DOCBOX_S3_ENDPOINT") {
48            // Using a custom S3 endpoint
49            Ok(endpoint_url) => {
50                let access_key_id = std::env::var("DOCBOX_S3_ACCESS_KEY_ID").context(
51                    "cannot use DOCBOX_S3_ENDPOINT without specifying DOCBOX_S3_ACCESS_KEY_ID",
52                )?;
53                let access_key_secret = std::env::var("DOCBOX_S3_ACCESS_KEY_SECRET").context(
54                    "cannot use DOCBOX_S3_ENDPOINT without specifying DOCBOX_S3_ACCESS_KEY_SECRET",
55                )?;
56
57                Ok(S3Endpoint::Custom {
58                    endpoint: endpoint_url,
59                    access_key_id,
60                    access_key_secret,
61                })
62            }
63            Err(_) => Ok(S3Endpoint::Aws),
64        }
65    }
66}
67
/// Factory producing [`S3StorageLayer`] instances that share one S3 client.
#[derive(Clone)]
pub struct S3StorageLayerFactory {
    /// Client to access S3
    client: S3Client,
}
73
74impl S3StorageLayerFactory {
75    pub fn from_config(aws_config: &SdkConfig, config: S3StorageLayerFactoryConfig) -> Self {
76        let client = match config.endpoint {
77            S3Endpoint::Aws => {
78                tracing::debug!("using aws s3 storage layer");
79                S3Client::new(aws_config)
80            }
81            S3Endpoint::Custom {
82                endpoint,
83                access_key_id,
84                access_key_secret,
85            } => {
86                tracing::debug!("using custom s3 storage layer");
87                let credentials = Credentials::new(
88                    access_key_id,
89                    access_key_secret,
90                    None,
91                    None,
92                    "docbox_key_provider",
93                );
94
95                // Enforces the "path" style for S3 bucket access
96                let config = aws_sdk_s3::config::Builder::from(aws_config)
97                    .force_path_style(true)
98                    .endpoint_url(endpoint)
99                    .credentials_provider(credentials)
100                    .build();
101                S3Client::from_conf(config)
102            }
103        };
104
105        Self { client }
106    }
107
108    pub fn create_storage_layer(&self, bucket_name: String) -> S3StorageLayer {
109        S3StorageLayer {
110            client: self.client.clone(),
111            bucket_name,
112        }
113    }
114}
115
/// Storage layer backed by a single S3 bucket.
#[derive(Clone)]
pub struct S3StorageLayer {
    /// Client to access S3
    client: S3Client,

    /// Name of the bucket to use
    bucket_name: String,
}
124
125impl S3StorageLayer {
126    pub fn new(client: S3Client, bucket_name: String) -> Self {
127        Self {
128            client,
129            bucket_name,
130        }
131    }
132}
133
134impl StorageLayer for S3StorageLayer {
135    async fn create_bucket(&self) -> anyhow::Result<()> {
136        let bucket_region = self
137            .client
138            .config()
139            .region()
140            .context("AWS config missing AWS_REGION")?
141            .to_string();
142
143        let constraint = BucketLocationConstraint::from(bucket_region.as_str());
144
145        let cfg = CreateBucketConfiguration::builder()
146            .location_constraint(constraint)
147            .build();
148
149        if let Err(err) = self
150            .client
151            .create_bucket()
152            .create_bucket_configuration(cfg)
153            .bucket(&self.bucket_name)
154            .send()
155            .await
156        {
157            let already_exists = err
158                .as_service_error()
159                .is_some_and(|value| value.is_bucket_already_owned_by_you());
160
161            // Bucket has already been created
162            if already_exists {
163                return Ok(());
164            }
165
166            tracing::error!(cause = ?err, "failed to create bucket");
167
168            return Err(err.into());
169        }
170
171        Ok(())
172    }
173
174    async fn delete_bucket(&self) -> anyhow::Result<()> {
175        self.client
176            .delete_bucket()
177            .bucket(&self.bucket_name)
178            .send()
179            .await
180            .context("failed to delete bucket")?;
181
182        Ok(())
183    }
184
185    async fn upload_file(
186        &self,
187        key: &str,
188        content_type: String,
189        body: Bytes,
190    ) -> anyhow::Result<()> {
191        self.client
192            .put_object()
193            .bucket(&self.bucket_name)
194            .content_type(content_type)
195            .key(key)
196            .body(body.into())
197            .send()
198            .await
199            .context("failed to store file in s3 bucket")?;
200
201        Ok(())
202    }
203
204    async fn create_presigned(
205        &self,
206        key: &str,
207        size: i64,
208    ) -> anyhow::Result<(PresignedRequest, DateTime<Utc>)> {
209        let expiry_time_minutes = 30;
210        let expires_at = Utc::now()
211            .checked_add_signed(TimeDelta::minutes(expiry_time_minutes))
212            .context("expiry time exceeds unix limit")?;
213
214        let result = self
215            .client
216            .put_object()
217            .bucket(&self.bucket_name)
218            .key(key)
219            .content_length(size)
220            .presigned(
221                PresigningConfig::builder()
222                    .expires_in(Duration::from_secs(60 * expiry_time_minutes as u64))
223                    .build()?,
224            )
225            .await
226            .context("failed to create presigned request")?;
227
228        Ok((result, expires_at))
229    }
230
231    async fn create_presigned_download(
232        &self,
233        key: &str,
234        expires_in: Duration,
235    ) -> anyhow::Result<(PresignedRequest, DateTime<Utc>)> {
236        let expires_at = Utc::now()
237            .checked_add_signed(TimeDelta::seconds(expires_in.as_secs() as i64))
238            .context("expiry time exceeds unix limit")?;
239
240        let result = self
241            .client
242            .get_object()
243            .bucket(&self.bucket_name)
244            .key(key)
245            .presigned(PresigningConfig::expires_in(expires_in)?)
246            .await?;
247
248        Ok((result, expires_at))
249    }
250
251    async fn add_bucket_notifications(&self, sqs_arn: &str) -> anyhow::Result<()> {
252        // Connect the S3 bucket for file upload notifications
253        self.client
254            .put_bucket_notification_configuration()
255            .bucket(&self.bucket_name)
256            .notification_configuration(
257                NotificationConfiguration::builder()
258                    .set_queue_configurations(Some(vec![
259                        QueueConfiguration::builder()
260                            .queue_arn(sqs_arn)
261                            .events(aws_sdk_s3::types::Event::S3ObjectCreated)
262                            .build()?,
263                    ]))
264                    .build(),
265            )
266            .send()
267            .await?;
268
269        Ok(())
270    }
271
272    async fn add_bucket_cors(&self, origins: Vec<String>) -> anyhow::Result<()> {
273        if let Err(cause) = self
274            .client
275            .put_bucket_cors()
276            .bucket(&self.bucket_name)
277            .cors_configuration(
278                CorsConfiguration::builder()
279                    .cors_rules(
280                        CorsRule::builder()
281                            .allowed_headers("*")
282                            .allowed_methods("PUT")
283                            .set_allowed_origins(Some(origins))
284                            .set_expose_headers(Some(Vec::new()))
285                            .build()?,
286                    )
287                    .build()?,
288            )
289            .send()
290            .await
291        {
292            // Handle "NotImplemented" errors (minio does not have CORS support)
293            if cause
294                .raw_response()
295                // (501 Not Implemented)
296                .is_some_and(|response| response.status().as_u16() == 501)
297            {
298                tracing::warn!("storage s3 backend does not support PutBucketCors.. skipping..");
299                return Ok(());
300            }
301
302            return Err(cause.into());
303        };
304
305        Ok(())
306    }
307
308    async fn delete_file(&self, key: &str) -> anyhow::Result<()> {
309        if let Err(cause) = self
310            .client
311            .delete_object()
312            .bucket(&self.bucket_name)
313            .key(key)
314            .send()
315            .await
316        {
317            // Handle keys that don't exist in the bucket
318            // (This is not a failure and indicates the file is already deleted)
319            if cause
320                .as_service_error()
321                .and_then(|err| err.source())
322                .and_then(|source| source.downcast_ref::<aws_sdk_s3::Error>())
323                .is_some_and(|err| matches!(err, aws_sdk_s3::Error::NoSuchKey(_)))
324            {
325                return Ok(());
326            }
327
328            return Err(cause.into());
329        }
330
331        Ok(())
332    }
333
334    async fn get_file(&self, key: &str) -> anyhow::Result<FileStream> {
335        let object = self
336            .client
337            .get_object()
338            .bucket(&self.bucket_name)
339            .key(key)
340            .send()
341            .await?;
342
343        let stream = FileStream {
344            stream: Box::pin(AwsFileStream { inner: object.body }),
345        };
346
347        Ok(stream)
348    }
349}
350
/// Wrapper adapting the AWS SDK [`ByteStream`] to the [`futures::Stream`] trait.
pub struct AwsFileStream {
    /// Underlying AWS SDK byte stream
    inner: ByteStream,
}
354
355impl AwsFileStream {
356    pub fn into_inner(self) -> ByteStream {
357        self.inner
358    }
359}
360
361impl Stream for AwsFileStream {
362    type Item = std::io::Result<Bytes>;
363
364    fn poll_next(
365        self: std::pin::Pin<&mut Self>,
366        cx: &mut std::task::Context<'_>,
367    ) -> std::task::Poll<Option<Self::Item>> {
368        let this = self.get_mut();
369        let inner = std::pin::Pin::new(&mut this.inner);
370        inner.poll_next(cx).map_err(std::io::Error::other)
371    }
372}