// docbox_core/storage/s3.rs
use std::{error::Error, time::Duration};

use super::{FileStream, StorageLayer};
use crate::aws::S3Client;
use anyhow::Context;
use aws_config::SdkConfig;
use aws_sdk_s3::{
    config::Credentials,
    presigning::{PresignedRequest, PresigningConfig},
    primitives::ByteStream,
    types::{
        BucketLocationConstraint, CorsConfiguration, CorsRule, CreateBucketConfiguration,
        NotificationConfiguration, QueueConfiguration,
    },
};
use bytes::Bytes;
use chrono::{DateTime, TimeDelta, Utc};
use futures::Stream;
use reqwest::StatusCode;
use serde::Deserialize;
21
22#[derive(Debug, Clone, Deserialize)]
23pub struct S3StorageLayerFactoryConfig {
24 pub endpoint: S3Endpoint,
25}
26
27impl S3StorageLayerFactoryConfig {
28 pub fn from_env() -> anyhow::Result<Self> {
29 let endpoint = S3Endpoint::from_env()?;
30
31 Ok(Self { endpoint })
32 }
33}
34
35#[derive(Debug, Clone, Deserialize)]
36#[serde(tag = "type", rename_all = "snake_case")]
37pub enum S3Endpoint {
38 Aws,
39 Custom {
40 endpoint: String,
41 access_key_id: String,
42 access_key_secret: String,
43 },
44}
45
46impl S3Endpoint {
47 pub fn from_env() -> anyhow::Result<Self> {
48 match std::env::var("DOCBOX_S3_ENDPOINT") {
49 Ok(endpoint_url) => {
51 let access_key_id = std::env::var("DOCBOX_S3_ACCESS_KEY_ID").context(
52 "cannot use DOCBOX_S3_ENDPOINT without specifying DOCBOX_S3_ACCESS_KEY_ID",
53 )?;
54 let access_key_secret = std::env::var("DOCBOX_S3_ACCESS_KEY_SECRET").context(
55 "cannot use DOCBOX_S3_ENDPOINT without specifying DOCBOX_S3_ACCESS_KEY_SECRET",
56 )?;
57
58 Ok(S3Endpoint::Custom {
59 endpoint: endpoint_url,
60 access_key_id,
61 access_key_secret,
62 })
63 }
64 Err(_) => Ok(S3Endpoint::Aws),
65 }
66 }
67}
68
69#[derive(Clone)]
70pub struct S3StorageLayerFactory {
71 client: S3Client,
73}
74
75impl S3StorageLayerFactory {
76 pub fn from_config(aws_config: &SdkConfig, config: S3StorageLayerFactoryConfig) -> Self {
77 let client = match config.endpoint {
78 S3Endpoint::Aws => {
79 tracing::debug!("using aws s3 storage layer");
80 S3Client::new(aws_config)
81 }
82 S3Endpoint::Custom {
83 endpoint,
84 access_key_id,
85 access_key_secret,
86 } => {
87 tracing::debug!("using custom s3 storage layer");
88 let credentials = Credentials::new(
89 access_key_id,
90 access_key_secret,
91 None,
92 None,
93 "docbox_key_provider",
94 );
95
96 let config = aws_sdk_s3::config::Builder::from(aws_config)
98 .force_path_style(true)
99 .endpoint_url(endpoint)
100 .credentials_provider(credentials)
101 .build();
102 S3Client::from_conf(config)
103 }
104 };
105
106 Self { client }
107 }
108
109 pub fn create_storage_layer(&self, bucket_name: String) -> S3StorageLayer {
110 S3StorageLayer {
111 client: self.client.clone(),
112 bucket_name,
113 }
114 }
115}
116
117#[derive(Clone)]
118pub struct S3StorageLayer {
119 client: S3Client,
121
122 bucket_name: String,
124}
125
126impl S3StorageLayer {
127 pub fn new(client: S3Client, bucket_name: String) -> Self {
128 Self {
129 client,
130 bucket_name,
131 }
132 }
133}
134
135impl StorageLayer for S3StorageLayer {
136 async fn create_bucket(&self) -> anyhow::Result<()> {
137 let bucket_region = self
138 .client
139 .config()
140 .region()
141 .context("AWS config missing AWS_REGION")?
142 .to_string();
143
144 let constraint = BucketLocationConstraint::from(bucket_region.as_str());
145
146 let cfg = CreateBucketConfiguration::builder()
147 .location_constraint(constraint)
148 .build();
149
150 if let Err(err) = self
151 .client
152 .create_bucket()
153 .create_bucket_configuration(cfg)
154 .bucket(&self.bucket_name)
155 .send()
156 .await
157 {
158 let already_exists = err
159 .as_service_error()
160 .is_some_and(|value| value.is_bucket_already_owned_by_you());
161
162 if already_exists {
164 return Ok(());
165 }
166
167 tracing::error!(cause = ?err, "failed to create bucket");
168
169 return Err(err.into());
170 }
171
172 Ok(())
173 }
174
175 async fn delete_bucket(&self) -> anyhow::Result<()> {
176 self.client
177 .delete_bucket()
178 .bucket(&self.bucket_name)
179 .send()
180 .await
181 .context("failed to delete bucket")?;
182
183 Ok(())
184 }
185
186 async fn upload_file(
187 &self,
188 key: &str,
189 content_type: String,
190 body: Bytes,
191 ) -> anyhow::Result<()> {
192 self.client
193 .put_object()
194 .bucket(&self.bucket_name)
195 .content_type(content_type)
196 .key(key)
197 .body(body.into())
198 .send()
199 .await
200 .context("failed to store file in s3 bucket")?;
201
202 Ok(())
203 }
204
205 async fn create_presigned(
206 &self,
207 key: &str,
208 size: i64,
209 ) -> anyhow::Result<(PresignedRequest, DateTime<Utc>)> {
210 let expiry_time_minutes = 30;
211 let expires_at = Utc::now()
212 .checked_add_signed(TimeDelta::minutes(expiry_time_minutes))
213 .context("expiry time exceeds unix limit")?;
214
215 let result = self
216 .client
217 .put_object()
218 .bucket(&self.bucket_name)
219 .key(key)
220 .content_length(size)
221 .presigned(
222 PresigningConfig::builder()
223 .expires_in(Duration::from_secs(60 * expiry_time_minutes as u64))
224 .build()?,
225 )
226 .await
227 .context("failed to create presigned request")?;
228
229 Ok((result, expires_at))
230 }
231
232 async fn add_bucket_notifications(&self, sqs_arn: &str) -> anyhow::Result<()> {
233 self.client
235 .put_bucket_notification_configuration()
236 .bucket(&self.bucket_name)
237 .notification_configuration(
238 NotificationConfiguration::builder()
239 .set_queue_configurations(Some(vec![
240 QueueConfiguration::builder()
241 .queue_arn(sqs_arn)
242 .events(aws_sdk_s3::types::Event::S3ObjectCreated)
243 .build()?,
244 ]))
245 .build(),
246 )
247 .send()
248 .await?;
249
250 Ok(())
251 }
252
253 async fn add_bucket_cors(&self, origins: Vec<String>) -> anyhow::Result<()> {
254 if let Err(cause) = self
255 .client
256 .put_bucket_cors()
257 .bucket(&self.bucket_name)
258 .cors_configuration(
259 CorsConfiguration::builder()
260 .cors_rules(
261 CorsRule::builder()
262 .allowed_headers("*")
263 .allowed_methods("PUT")
264 .set_allowed_origins(Some(origins))
265 .set_expose_headers(Some(Vec::new()))
266 .build()?,
267 )
268 .build()?,
269 )
270 .send()
271 .await
272 {
273 if cause.raw_response().is_some_and(|response| {
275 response.status().as_u16() == StatusCode::NOT_IMPLEMENTED.as_u16()
276 }) {
277 return Ok(());
278 }
279
280 return Err(cause.into());
281 };
282
283 Ok(())
284 }
285
286 async fn delete_file(&self, key: &str) -> anyhow::Result<()> {
287 if let Err(cause) = self
288 .client
289 .delete_object()
290 .bucket(&self.bucket_name)
291 .key(key)
292 .send()
293 .await
294 {
295 if cause
298 .as_service_error()
299 .and_then(|err| err.source())
300 .and_then(|source| source.downcast_ref::<aws_sdk_s3::Error>())
301 .is_some_and(|err| matches!(err, aws_sdk_s3::Error::NoSuchKey(_)))
302 {
303 return Ok(());
304 }
305
306 return Err(cause.into());
307 }
308
309 Ok(())
310 }
311
312 async fn get_file(&self, key: &str) -> anyhow::Result<FileStream> {
313 let object = self
314 .client
315 .get_object()
316 .bucket(&self.bucket_name)
317 .key(key)
318 .send()
319 .await?;
320
321 let stream = FileStream {
322 stream: Box::pin(AwsFileStream { inner: object.body }),
323 };
324
325 Ok(stream)
326 }
327}
328
329pub struct AwsFileStream {
330 inner: ByteStream,
331}
332
333impl Stream for AwsFileStream {
334 type Item = std::io::Result<Bytes>;
335
336 fn poll_next(
337 self: std::pin::Pin<&mut Self>,
338 cx: &mut std::task::Context<'_>,
339 ) -> std::task::Poll<Option<Self::Item>> {
340 let this = self.get_mut();
341 let inner = std::pin::Pin::new(&mut this.inner);
342 inner.poll_next(cx).map_err(std::io::Error::other)
343 }
344}