1#![forbid(unsafe_code)]
2#![warn(missing_docs)]
3
4use aws_config::SdkConfig;
14use aws_sdk_s3::presigning::PresignedRequest;
15use bytes::{Buf, Bytes};
16use bytes_utils::SegmentedBuf;
17use chrono::{DateTime, Utc};
18use docbox_database::models::tenant::Tenant;
19use futures::{Stream, StreamExt};
20use serde::{Deserialize, Serialize};
21use std::{fmt::Debug, pin::Pin, time::Duration};
22use thiserror::Error;
23
24pub mod s3;
25
26#[derive(Debug, Clone, Deserialize, Serialize)]
28#[serde(tag = "provider", rename_all = "snake_case")]
29pub enum StorageLayerFactoryConfig {
30 S3(s3::S3StorageLayerFactoryConfig),
32}
33
34impl Default for StorageLayerFactoryConfig {
35 fn default() -> Self {
36 Self::S3(Default::default())
37 }
38}
39
40#[derive(Debug, Error)]
43pub enum StorageLayerFactoryConfigError {
44 #[error(transparent)]
46 S3(#[from] s3::S3StorageLayerFactoryConfigError),
47}
48
49impl StorageLayerFactoryConfig {
50 pub fn from_env() -> Result<Self, StorageLayerFactoryConfigError> {
52 s3::S3StorageLayerFactoryConfig::from_env()
53 .map(Self::S3)
54 .map_err(StorageLayerFactoryConfigError::S3)
55 }
56}
57
58#[derive(Clone)]
61pub enum StorageLayerFactory {
62 S3(s3::S3StorageLayerFactory),
64}
65
66#[derive(Debug, Error)]
68pub enum StorageLayerError {
69 #[error(transparent)]
71 S3(Box<s3::S3StorageError>),
72
73 #[error("failed to collect file contents")]
75 CollectBytes,
76}
77
78impl From<s3::S3StorageError> for StorageLayerError {
79 fn from(value: s3::S3StorageError) -> Self {
80 Self::S3(Box::new(value))
81 }
82}
83
84impl StorageLayerFactory {
85 pub fn from_config(aws_config: &SdkConfig, config: StorageLayerFactoryConfig) -> Self {
87 match config {
88 StorageLayerFactoryConfig::S3(config) => {
89 Self::S3(s3::S3StorageLayerFactory::from_config(aws_config, config))
90 }
91 }
92 }
93
94 pub fn create_storage_layer(&self, tenant: &Tenant) -> TenantStorageLayer {
96 self.create_storage_layer_bucket(tenant.s3_name.clone())
97 }
98
99 pub fn create_storage_layer_bucket(&self, bucket_name: String) -> TenantStorageLayer {
101 match self {
102 StorageLayerFactory::S3(s3) => {
103 let layer = s3.create_storage_layer(bucket_name);
104 TenantStorageLayer::S3(layer)
105 }
106 }
107 }
108}
109
110#[derive(Clone)]
112pub enum TenantStorageLayer {
113 S3(s3::S3StorageLayer),
115}
116
/// Result reported by [`TenantStorageLayer::create_bucket`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CreateBucketOutcome {
    /// The bucket was newly created.
    New,
    /// The bucket already existed.
    Existing,
}
125
126impl TenantStorageLayer {
127 pub fn bucket_name(&self) -> String {
129 match self {
130 TenantStorageLayer::S3(layer) => layer.bucket_name(),
131 }
132 }
133
134 #[tracing::instrument(skip(self))]
139 pub async fn create_bucket(&self) -> Result<CreateBucketOutcome, StorageLayerError> {
140 match self {
141 TenantStorageLayer::S3(layer) => layer.create_bucket().await,
142 }
143 }
144
145 #[tracing::instrument(skip(self))]
147 pub async fn bucket_exists(&self) -> Result<bool, StorageLayerError> {
148 match self {
149 TenantStorageLayer::S3(layer) => layer.bucket_exists().await,
150 }
151 }
152
153 #[tracing::instrument(skip(self))]
158 pub async fn delete_bucket(&self) -> Result<(), StorageLayerError> {
159 match self {
160 TenantStorageLayer::S3(layer) => layer.delete_bucket().await,
161 }
162 }
163
164 #[tracing::instrument(skip(self))]
166 pub async fn create_presigned(
167 &self,
168 key: &str,
169 size: i64,
170 ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
171 match self {
172 TenantStorageLayer::S3(layer) => layer.create_presigned(key, size).await,
173 }
174 }
175
176 #[tracing::instrument(skip(self))]
181 pub async fn create_presigned_download(
182 &self,
183 key: &str,
184 expires_in: Duration,
185 ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
186 match self {
187 TenantStorageLayer::S3(layer) => layer.create_presigned_download(key, expires_in).await,
188 }
189 }
190
191 #[tracing::instrument(skip(self, body), fields(body_length = body.len()))]
193 pub async fn upload_file(
194 &self,
195 key: &str,
196 content_type: String,
197 body: Bytes,
198 ) -> Result<(), StorageLayerError> {
199 match self {
200 TenantStorageLayer::S3(layer) => layer.upload_file(key, content_type, body).await,
201 }
202 }
203
204 #[tracing::instrument(skip(self))]
206 pub async fn add_bucket_notifications(&self, sns_arn: &str) -> Result<(), StorageLayerError> {
207 match self {
208 TenantStorageLayer::S3(layer) => layer.add_bucket_notifications(sns_arn).await,
209 }
210 }
211
212 #[tracing::instrument(skip(self))]
214 pub async fn set_bucket_cors_origins(
215 &self,
216 origins: Vec<String>,
217 ) -> Result<(), StorageLayerError> {
218 match self {
219 TenantStorageLayer::S3(layer) => layer.set_bucket_cors_origins(origins).await,
220 }
221 }
222
223 #[tracing::instrument(skip(self))]
228 pub async fn delete_file(&self, key: &str) -> Result<(), StorageLayerError> {
229 match self {
230 TenantStorageLayer::S3(layer) => layer.delete_file(key).await,
231 }
232 }
233
234 #[tracing::instrument(skip(self))]
236 pub async fn get_file(&self, key: &str) -> Result<FileStream, StorageLayerError> {
237 match self {
238 TenantStorageLayer::S3(layer) => layer.get_file(key).await,
239 }
240 }
241}
242
243pub(crate) trait StorageLayerImpl {
245 fn bucket_name(&self) -> String;
246
247 async fn create_bucket(&self) -> Result<CreateBucketOutcome, StorageLayerError>;
248
249 async fn bucket_exists(&self) -> Result<bool, StorageLayerError>;
250
251 async fn delete_bucket(&self) -> Result<(), StorageLayerError>;
252
253 async fn create_presigned(
254 &self,
255 key: &str,
256 size: i64,
257 ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError>;
258
259 async fn create_presigned_download(
260 &self,
261 key: &str,
262 expires_in: Duration,
263 ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError>;
264
265 async fn upload_file(
266 &self,
267 key: &str,
268 content_type: String,
269 body: Bytes,
270 ) -> Result<(), StorageLayerError>;
271
272 async fn add_bucket_notifications(&self, sns_arn: &str) -> Result<(), StorageLayerError>;
273
274 async fn set_bucket_cors_origins(&self, origins: Vec<String>) -> Result<(), StorageLayerError>;
275
276 async fn delete_file(&self, key: &str) -> Result<(), StorageLayerError>;
277
278 async fn get_file(&self, key: &str) -> Result<FileStream, StorageLayerError>;
279}
280
281pub struct FileStream {
283 pub stream: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>,
285}
286
287impl Debug for FileStream {
288 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289 f.debug_struct("FileStream").finish()
290 }
291}
292
293impl Stream for FileStream {
294 type Item = std::io::Result<Bytes>;
295
296 fn poll_next(
297 mut self: std::pin::Pin<&mut Self>,
298 cx: &mut std::task::Context<'_>,
299 ) -> std::task::Poll<Option<Self::Item>> {
300 self.stream.as_mut().poll_next(cx)
301 }
302}
303
304impl FileStream {
305 pub async fn collect_bytes(mut self) -> Result<Bytes, StorageLayerError> {
307 let mut output = SegmentedBuf::new();
308
309 while let Some(result) = self.next().await {
310 let chunk = result.map_err(|error| {
311 tracing::error!(?error, "failed to collect file stream bytes");
312 StorageLayerError::CollectBytes
313 })?;
314
315 output.push(chunk);
316 }
317
318 Ok(output.copy_to_bytes(output.remaining()))
319 }
320}