docbox_core/storage/
s3.rs1use std::{error::Error, time::Duration};
2
3use super::{FileStream, StorageLayer};
4use crate::aws::S3Client;
5use anyhow::Context;
6use aws_config::SdkConfig;
7use aws_sdk_s3::{
8 config::Credentials,
9 presigning::{PresignedRequest, PresigningConfig},
10 primitives::ByteStream,
11 types::{
12 BucketLocationConstraint, CorsConfiguration, CorsRule, CreateBucketConfiguration,
13 NotificationConfiguration, QueueConfiguration,
14 },
15};
16use bytes::Bytes;
17use chrono::{DateTime, TimeDelta, Utc};
18use futures::Stream;
19use reqwest::StatusCode;
20use serde::{Deserialize, Serialize};
21
22#[derive(Debug, Clone, Deserialize, Serialize)]
23pub struct S3StorageLayerFactoryConfig {
24 pub endpoint: S3Endpoint,
25}
26
27impl S3StorageLayerFactoryConfig {
28 pub fn from_env() -> anyhow::Result<Self> {
29 let endpoint = S3Endpoint::from_env()?;
30
31 Ok(Self { endpoint })
32 }
33}
34
35#[derive(Debug, Clone, Deserialize, Serialize)]
36#[serde(tag = "type", rename_all = "snake_case")]
37pub enum S3Endpoint {
38 Aws,
39 Custom {
40 endpoint: String,
41 access_key_id: String,
42 access_key_secret: String,
43 },
44}
45
46impl S3Endpoint {
47 pub fn from_env() -> anyhow::Result<Self> {
48 match std::env::var("DOCBOX_S3_ENDPOINT") {
49 Ok(endpoint_url) => {
51 let access_key_id = std::env::var("DOCBOX_S3_ACCESS_KEY_ID").context(
52 "cannot use DOCBOX_S3_ENDPOINT without specifying DOCBOX_S3_ACCESS_KEY_ID",
53 )?;
54 let access_key_secret = std::env::var("DOCBOX_S3_ACCESS_KEY_SECRET").context(
55 "cannot use DOCBOX_S3_ENDPOINT without specifying DOCBOX_S3_ACCESS_KEY_SECRET",
56 )?;
57
58 Ok(S3Endpoint::Custom {
59 endpoint: endpoint_url,
60 access_key_id,
61 access_key_secret,
62 })
63 }
64 Err(_) => Ok(S3Endpoint::Aws),
65 }
66 }
67}
68
69#[derive(Clone)]
70pub struct S3StorageLayerFactory {
71 client: S3Client,
73}
74
75impl S3StorageLayerFactory {
76 pub fn from_config(aws_config: &SdkConfig, config: S3StorageLayerFactoryConfig) -> Self {
77 let client = match config.endpoint {
78 S3Endpoint::Aws => {
79 tracing::debug!("using aws s3 storage layer");
80 S3Client::new(aws_config)
81 }
82 S3Endpoint::Custom {
83 endpoint,
84 access_key_id,
85 access_key_secret,
86 } => {
87 tracing::debug!("using custom s3 storage layer");
88 let credentials = Credentials::new(
89 access_key_id,
90 access_key_secret,
91 None,
92 None,
93 "docbox_key_provider",
94 );
95
96 let config = aws_sdk_s3::config::Builder::from(aws_config)
98 .force_path_style(true)
99 .endpoint_url(endpoint)
100 .credentials_provider(credentials)
101 .build();
102 S3Client::from_conf(config)
103 }
104 };
105
106 Self { client }
107 }
108
109 pub fn create_storage_layer(&self, bucket_name: String) -> S3StorageLayer {
110 S3StorageLayer {
111 client: self.client.clone(),
112 bucket_name,
113 }
114 }
115}
116
117#[derive(Clone)]
118pub struct S3StorageLayer {
119 client: S3Client,
121
122 bucket_name: String,
124}
125
126impl S3StorageLayer {
127 pub fn new(client: S3Client, bucket_name: String) -> Self {
128 Self {
129 client,
130 bucket_name,
131 }
132 }
133}
134
135impl StorageLayer for S3StorageLayer {
136 async fn create_bucket(&self) -> anyhow::Result<()> {
137 let bucket_region = self
138 .client
139 .config()
140 .region()
141 .context("AWS config missing AWS_REGION")?
142 .to_string();
143
144 let constraint = BucketLocationConstraint::from(bucket_region.as_str());
145
146 let cfg = CreateBucketConfiguration::builder()
147 .location_constraint(constraint)
148 .build();
149
150 if let Err(err) = self
151 .client
152 .create_bucket()
153 .create_bucket_configuration(cfg)
154 .bucket(&self.bucket_name)
155 .send()
156 .await
157 {
158 let already_exists = err
159 .as_service_error()
160 .is_some_and(|value| value.is_bucket_already_owned_by_you());
161
162 if already_exists {
164 return Ok(());
165 }
166
167 tracing::error!(cause = ?err, "failed to create bucket");
168
169 return Err(err.into());
170 }
171
172 Ok(())
173 }
174
175 async fn delete_bucket(&self) -> anyhow::Result<()> {
176 self.client
177 .delete_bucket()
178 .bucket(&self.bucket_name)
179 .send()
180 .await
181 .context("failed to delete bucket")?;
182
183 Ok(())
184 }
185
186 async fn upload_file(
187 &self,
188 key: &str,
189 content_type: String,
190 body: Bytes,
191 ) -> anyhow::Result<()> {
192 self.client
193 .put_object()
194 .bucket(&self.bucket_name)
195 .content_type(content_type)
196 .key(key)
197 .body(body.into())
198 .send()
199 .await
200 .context("failed to store file in s3 bucket")?;
201
202 Ok(())
203 }
204
205 async fn create_presigned(
206 &self,
207 key: &str,
208 size: i64,
209 ) -> anyhow::Result<(PresignedRequest, DateTime<Utc>)> {
210 let expiry_time_minutes = 30;
211 let expires_at = Utc::now()
212 .checked_add_signed(TimeDelta::minutes(expiry_time_minutes))
213 .context("expiry time exceeds unix limit")?;
214
215 let result = self
216 .client
217 .put_object()
218 .bucket(&self.bucket_name)
219 .key(key)
220 .content_length(size)
221 .presigned(
222 PresigningConfig::builder()
223 .expires_in(Duration::from_secs(60 * expiry_time_minutes as u64))
224 .build()?,
225 )
226 .await
227 .context("failed to create presigned request")?;
228
229 Ok((result, expires_at))
230 }
231
232 async fn create_presigned_download(
233 &self,
234 key: &str,
235 expires_in: Duration,
236 ) -> anyhow::Result<(PresignedRequest, DateTime<Utc>)> {
237 let expires_at = Utc::now()
238 .checked_add_signed(TimeDelta::seconds(expires_in.as_secs() as i64))
239 .context("expiry time exceeds unix limit")?;
240
241 let result = self
242 .client
243 .get_object()
244 .bucket(&self.bucket_name)
245 .key(key)
246 .presigned(PresigningConfig::expires_in(expires_in)?)
247 .await?;
248
249 Ok((result, expires_at))
250 }
251
252 async fn add_bucket_notifications(&self, sqs_arn: &str) -> anyhow::Result<()> {
253 self.client
255 .put_bucket_notification_configuration()
256 .bucket(&self.bucket_name)
257 .notification_configuration(
258 NotificationConfiguration::builder()
259 .set_queue_configurations(Some(vec![
260 QueueConfiguration::builder()
261 .queue_arn(sqs_arn)
262 .events(aws_sdk_s3::types::Event::S3ObjectCreated)
263 .build()?,
264 ]))
265 .build(),
266 )
267 .send()
268 .await?;
269
270 Ok(())
271 }
272
273 async fn add_bucket_cors(&self, origins: Vec<String>) -> anyhow::Result<()> {
274 if let Err(cause) = self
275 .client
276 .put_bucket_cors()
277 .bucket(&self.bucket_name)
278 .cors_configuration(
279 CorsConfiguration::builder()
280 .cors_rules(
281 CorsRule::builder()
282 .allowed_headers("*")
283 .allowed_methods("PUT")
284 .set_allowed_origins(Some(origins))
285 .set_expose_headers(Some(Vec::new()))
286 .build()?,
287 )
288 .build()?,
289 )
290 .send()
291 .await
292 {
293 if cause.raw_response().is_some_and(|response| {
295 response.status().as_u16() == StatusCode::NOT_IMPLEMENTED.as_u16()
296 }) {
297 return Ok(());
298 }
299
300 return Err(cause.into());
301 };
302
303 Ok(())
304 }
305
306 async fn delete_file(&self, key: &str) -> anyhow::Result<()> {
307 if let Err(cause) = self
308 .client
309 .delete_object()
310 .bucket(&self.bucket_name)
311 .key(key)
312 .send()
313 .await
314 {
315 if cause
318 .as_service_error()
319 .and_then(|err| err.source())
320 .and_then(|source| source.downcast_ref::<aws_sdk_s3::Error>())
321 .is_some_and(|err| matches!(err, aws_sdk_s3::Error::NoSuchKey(_)))
322 {
323 return Ok(());
324 }
325
326 return Err(cause.into());
327 }
328
329 Ok(())
330 }
331
332 async fn get_file(&self, key: &str) -> anyhow::Result<FileStream> {
333 let object = self
334 .client
335 .get_object()
336 .bucket(&self.bucket_name)
337 .key(key)
338 .send()
339 .await?;
340
341 let stream = FileStream {
342 stream: Box::pin(AwsFileStream { inner: object.body }),
343 };
344
345 Ok(stream)
346 }
347}
348
349pub struct AwsFileStream {
350 inner: ByteStream,
351}
352
353impl Stream for AwsFileStream {
354 type Item = std::io::Result<Bytes>;
355
356 fn poll_next(
357 self: std::pin::Pin<&mut Self>,
358 cx: &mut std::task::Context<'_>,
359 ) -> std::task::Poll<Option<Self::Item>> {
360 let this = self.get_mut();
361 let inner = std::pin::Pin::new(&mut this.inner);
362 inner.poll_next(cx).map_err(std::io::Error::other)
363 }
364}