//! docbox_core/storage/s3.rs
//!
//! S3-backed implementation of the [`StorageLayer`] trait, supporting both
//! real AWS S3 and custom S3-compatible endpoints (e.g. MinIO).

1use std::{error::Error, time::Duration};
2
3use super::{FileStream, StorageLayer};
4use crate::aws::S3Client;
5use anyhow::Context;
6use aws_config::SdkConfig;
7use aws_sdk_s3::{
8    config::Credentials,
9    presigning::{PresignedRequest, PresigningConfig},
10    primitives::ByteStream,
11    types::{
12        BucketLocationConstraint, CorsConfiguration, CorsRule, CreateBucketConfiguration,
13        NotificationConfiguration, QueueConfiguration,
14    },
15};
16use bytes::Bytes;
17use chrono::{DateTime, TimeDelta, Utc};
18use futures::Stream;
19use reqwest::StatusCode;
20use serde::{Deserialize, Serialize};
21
/// Configuration used to build an [`S3StorageLayerFactory`].
///
/// Currently only selects which S3 endpoint to target (real AWS or a
/// custom S3-compatible service).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct S3StorageLayerFactoryConfig {
    // Endpoint selection: AWS or a custom endpoint with static credentials
    pub endpoint: S3Endpoint,
}
26
27impl S3StorageLayerFactoryConfig {
28    pub fn from_env() -> anyhow::Result<Self> {
29        let endpoint = S3Endpoint::from_env()?;
30
31        Ok(Self { endpoint })
32    }
33}
34
/// Selects which S3-compatible endpoint the storage layer connects to.
///
/// Serialized/deserialized with an internal `type` tag in snake_case,
/// e.g. `{"type": "aws"}` or `{"type": "custom", "endpoint": "...", ...}`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum S3Endpoint {
    /// Real AWS S3; credentials come from the ambient AWS SDK config.
    Aws,
    /// A custom S3-compatible service reached via `endpoint` with the
    /// static credentials below (see `S3StorageLayerFactory::from_config`,
    /// which also forces path-style bucket addressing for this variant).
    Custom {
        // Base URL of the S3-compatible service
        endpoint: String,
        // Static access key ID
        access_key_id: String,
        // Static secret access key
        access_key_secret: String,
    },
}
45
46impl S3Endpoint {
47    pub fn from_env() -> anyhow::Result<Self> {
48        match std::env::var("DOCBOX_S3_ENDPOINT") {
49            // Using a custom S3 endpoint
50            Ok(endpoint_url) => {
51                let access_key_id = std::env::var("DOCBOX_S3_ACCESS_KEY_ID").context(
52                    "cannot use DOCBOX_S3_ENDPOINT without specifying DOCBOX_S3_ACCESS_KEY_ID",
53                )?;
54                let access_key_secret = std::env::var("DOCBOX_S3_ACCESS_KEY_SECRET").context(
55                    "cannot use DOCBOX_S3_ENDPOINT without specifying DOCBOX_S3_ACCESS_KEY_SECRET",
56                )?;
57
58                Ok(S3Endpoint::Custom {
59                    endpoint: endpoint_url,
60                    access_key_id,
61                    access_key_secret,
62                })
63            }
64            Err(_) => Ok(S3Endpoint::Aws),
65        }
66    }
67}
68
/// Factory that produces [`S3StorageLayer`] instances, all sharing a single
/// underlying S3 client.
#[derive(Clone)]
pub struct S3StorageLayerFactory {
    /// Client to access S3 (clones share the underlying SDK client)
    client: S3Client,
}
74
75impl S3StorageLayerFactory {
76    pub fn from_config(aws_config: &SdkConfig, config: S3StorageLayerFactoryConfig) -> Self {
77        let client = match config.endpoint {
78            S3Endpoint::Aws => {
79                tracing::debug!("using aws s3 storage layer");
80                S3Client::new(aws_config)
81            }
82            S3Endpoint::Custom {
83                endpoint,
84                access_key_id,
85                access_key_secret,
86            } => {
87                tracing::debug!("using custom s3 storage layer");
88                let credentials = Credentials::new(
89                    access_key_id,
90                    access_key_secret,
91                    None,
92                    None,
93                    "docbox_key_provider",
94                );
95
96                // Enforces the "path" style for S3 bucket access
97                let config = aws_sdk_s3::config::Builder::from(aws_config)
98                    .force_path_style(true)
99                    .endpoint_url(endpoint)
100                    .credentials_provider(credentials)
101                    .build();
102                S3Client::from_conf(config)
103            }
104        };
105
106        Self { client }
107    }
108
109    pub fn create_storage_layer(&self, bucket_name: String) -> S3StorageLayer {
110        S3StorageLayer {
111            client: self.client.clone(),
112            bucket_name,
113        }
114    }
115}
116
/// Storage layer backed by a single S3 bucket.
///
/// All trait operations ([`StorageLayer`]) act on `bucket_name` via `client`.
#[derive(Clone)]
pub struct S3StorageLayer {
    /// Client to access S3
    client: S3Client,

    /// Name of the bucket to use
    bucket_name: String,
}
125
impl S3StorageLayer {
    /// Creates a storage layer over `bucket_name` using the provided S3 client.
    ///
    /// The bucket itself is not created or validated here; see
    /// `StorageLayer::create_bucket`.
    pub fn new(client: S3Client, bucket_name: String) -> Self {
        Self {
            client,
            bucket_name,
        }
    }
}
134
/// S3 implementation of the [`StorageLayer`] trait; every operation targets
/// the single bucket named by `self.bucket_name`.
impl StorageLayer for S3StorageLayer {
    /// Creates the backing bucket in the client's configured region.
    ///
    /// Returns `Ok(())` when the bucket already exists and is owned by this
    /// account; any other failure is logged and propagated.
    async fn create_bucket(&self) -> anyhow::Result<()> {
        let bucket_region = self
            .client
            .config()
            .region()
            .context("AWS config missing AWS_REGION")?
            .to_string();

        // NOTE(review): AWS rejects an explicit LocationConstraint of
        // "us-east-1" (it must be omitted for that region) — confirm this path
        // is only exercised with other regions or custom endpoints.
        let constraint = BucketLocationConstraint::from(bucket_region.as_str());

        let cfg = CreateBucketConfiguration::builder()
            .location_constraint(constraint)
            .build();

        if let Err(err) = self
            .client
            .create_bucket()
            .create_bucket_configuration(cfg)
            .bucket(&self.bucket_name)
            .send()
            .await
        {
            // "BucketAlreadyOwnedByYou" is treated as success (idempotent create)
            let already_exists = err
                .as_service_error()
                .is_some_and(|value| value.is_bucket_already_owned_by_you());

            // Bucket has already been created
            if already_exists {
                return Ok(());
            }

            tracing::error!(cause = ?err, "failed to create bucket");

            return Err(err.into());
        }

        Ok(())
    }

    /// Deletes the backing bucket.
    ///
    /// Note: S3 only deletes empty buckets; callers are presumably expected to
    /// remove objects first — TODO confirm against call sites.
    async fn delete_bucket(&self) -> anyhow::Result<()> {
        self.client
            .delete_bucket()
            .bucket(&self.bucket_name)
            .send()
            .await
            .context("failed to delete bucket")?;

        Ok(())
    }

    /// Uploads `body` to the bucket under `key` with the given content type.
    ///
    /// Fully buffered upload (single PutObject); no multipart handling here.
    async fn upload_file(
        &self,
        key: &str,
        content_type: String,
        body: Bytes,
    ) -> anyhow::Result<()> {
        self.client
            .put_object()
            .bucket(&self.bucket_name)
            .content_type(content_type)
            .key(key)
            .body(body.into())
            .send()
            .await
            .context("failed to store file in s3 bucket")?;

        Ok(())
    }

    /// Creates a presigned PUT request for uploading `size` bytes to `key`.
    ///
    /// Returns the presigned request together with its expiry timestamp. The
    /// expiry is hardcoded to 30 minutes; both the returned `DateTime` and the
    /// presigning config are derived from the same constant so they agree.
    async fn create_presigned(
        &self,
        key: &str,
        size: i64,
    ) -> anyhow::Result<(PresignedRequest, DateTime<Utc>)> {
        let expiry_time_minutes = 30;
        let expires_at = Utc::now()
            .checked_add_signed(TimeDelta::minutes(expiry_time_minutes))
            .context("expiry time exceeds unix limit")?;

        let result = self
            .client
            .put_object()
            .bucket(&self.bucket_name)
            .key(key)
            // Pins the upload size: the presigned URL is only valid for a
            // request with exactly this Content-Length
            .content_length(size)
            .presigned(
                PresigningConfig::builder()
                    .expires_in(Duration::from_secs(60 * expiry_time_minutes as u64))
                    .build()?,
            )
            .await
            .context("failed to create presigned request")?;

        Ok((result, expires_at))
    }

    /// Creates a presigned GET request for downloading `key`, valid for
    /// `expires_in`, returning the request and its expiry timestamp.
    async fn create_presigned_download(
        &self,
        key: &str,
        expires_in: Duration,
    ) -> anyhow::Result<(PresignedRequest, DateTime<Utc>)> {
        // NOTE(review): `as_secs() as i64` wraps for durations beyond
        // i64::MAX seconds — harmless for realistic expiries, but a
        // TimeDelta::from_std conversion would be lossless.
        let expires_at = Utc::now()
            .checked_add_signed(TimeDelta::seconds(expires_in.as_secs() as i64))
            .context("expiry time exceeds unix limit")?;

        let result = self
            .client
            .get_object()
            .bucket(&self.bucket_name)
            .key(key)
            .presigned(PresigningConfig::expires_in(expires_in)?)
            .await?;

        Ok((result, expires_at))
    }

    /// Subscribes the SQS queue at `sqs_arn` to object-created events on the
    /// bucket (used for file upload notifications).
    ///
    /// Note: this replaces the bucket's entire notification configuration
    /// with this single queue configuration.
    async fn add_bucket_notifications(&self, sqs_arn: &str) -> anyhow::Result<()> {
        // Connect the S3 bucket for file upload notifications
        self.client
            .put_bucket_notification_configuration()
            .bucket(&self.bucket_name)
            .notification_configuration(
                NotificationConfiguration::builder()
                    .set_queue_configurations(Some(vec![
                        QueueConfiguration::builder()
                            .queue_arn(sqs_arn)
                            .events(aws_sdk_s3::types::Event::S3ObjectCreated)
                            .build()?,
                    ]))
                    .build(),
            )
            .send()
            .await?;

        Ok(())
    }

    /// Sets a CORS rule on the bucket allowing PUT requests (any header) from
    /// the given `origins` — required for browser-based presigned uploads.
    ///
    /// A 501 "NotImplemented" response is tolerated so local S3-compatible
    /// servers without CORS support (e.g. MinIO) do not fail setup.
    async fn add_bucket_cors(&self, origins: Vec<String>) -> anyhow::Result<()> {
        if let Err(cause) = self
            .client
            .put_bucket_cors()
            .bucket(&self.bucket_name)
            .cors_configuration(
                CorsConfiguration::builder()
                    .cors_rules(
                        CorsRule::builder()
                            .allowed_headers("*")
                            .allowed_methods("PUT")
                            .set_allowed_origins(Some(origins))
                            .set_expose_headers(Some(Vec::new()))
                            .build()?,
                    )
                    .build()?,
            )
            .send()
            .await
        {
            // Handle "NotImplemented" errors (Local minio testing server does not have CORS support)
            if cause.raw_response().is_some_and(|response| {
                response.status().as_u16() == StatusCode::NOT_IMPLEMENTED.as_u16()
            }) {
                return Ok(());
            }

            return Err(cause.into());
        };

        Ok(())
    }

    /// Deletes the object at `key`, treating a missing key as success.
    async fn delete_file(&self, key: &str) -> anyhow::Result<()> {
        if let Err(cause) = self
            .client
            .delete_object()
            .bucket(&self.bucket_name)
            .key(key)
            .send()
            .await
        {
            // Handle keys that don't exist in the bucket
            // (This is not a failure and indicates the file is already deleted)
            // NOTE(review): S3 DeleteObject is normally idempotent and does not
            // report NoSuchKey, and this source()-downcast chain may never
            // match in practice — verify against the SDK error structure.
            if cause
                .as_service_error()
                .and_then(|err| err.source())
                .and_then(|source| source.downcast_ref::<aws_sdk_s3::Error>())
                .is_some_and(|err| matches!(err, aws_sdk_s3::Error::NoSuchKey(_)))
            {
                return Ok(());
            }

            return Err(cause.into());
        }

        Ok(())
    }

    /// Fetches the object at `key` as a streaming body.
    ///
    /// The AWS `ByteStream` is wrapped in [`AwsFileStream`] so it can be
    /// consumed as a generic `FileStream`. Errors surface lazily while the
    /// stream is polled, except the initial GetObject failure which is
    /// returned here.
    async fn get_file(&self, key: &str) -> anyhow::Result<FileStream> {
        let object = self
            .client
            .get_object()
            .bucket(&self.bucket_name)
            .key(key)
            .send()
            .await?;

        let stream = FileStream {
            stream: Box::pin(AwsFileStream { inner: object.body }),
        };

        Ok(stream)
    }
}
348
/// Adapter exposing an AWS SDK [`ByteStream`] as a [`futures::Stream`] of
/// `std::io::Result<Bytes>` chunks.
pub struct AwsFileStream {
    // Underlying AWS byte stream for the S3 object body
    inner: ByteStream,
}
352
353impl Stream for AwsFileStream {
354    type Item = std::io::Result<Bytes>;
355
356    fn poll_next(
357        self: std::pin::Pin<&mut Self>,
358        cx: &mut std::task::Context<'_>,
359    ) -> std::task::Poll<Option<Self::Item>> {
360        let this = self.get_mut();
361        let inner = std::pin::Pin::new(&mut this.inner);
362        inner.poll_next(cx).map_err(std::io::Error::other)
363    }
364}