1#![forbid(unsafe_code)]
2#![warn(missing_docs)]
3
4use aws_config::SdkConfig;
14use aws_sdk_s3::presigning::PresignedRequest;
15use bytes::{Buf, Bytes};
16use bytes_utils::SegmentedBuf;
17use chrono::{DateTime, Utc};
18use docbox_database::models::tenant::Tenant;
19use futures::{Stream, StreamExt};
20use serde::{Deserialize, Serialize};
21use std::{fmt::Debug, pin::Pin, time::Duration};
22use thiserror::Error;
23
/// S3 storage backend: provides `S3StorageLayerFactory`, `S3StorageLayer` and
/// their configuration and error types used by this crate's dispatch enums.
pub mod s3;
25
/// Configuration selecting and configuring the storage backend.
///
/// Serialized as an internally tagged enum: the backend is chosen by a
/// `provider` field whose value is the snake_case variant name (`"s3"`).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "provider", rename_all = "snake_case")]
pub enum StorageLayerFactoryConfig {
    /// Use the S3 backend with the given configuration.
    S3(s3::S3StorageLayerFactoryConfig),
}
33
/// Errors that can occur while loading a [`StorageLayerFactoryConfig`].
#[derive(Debug, Error)]
pub enum StorageLayerFactoryConfigError {
    /// Loading the S3 backend configuration failed; displays the inner error.
    #[error(transparent)]
    S3(#[from] s3::S3StorageLayerFactoryConfigError),
}
42
43impl StorageLayerFactoryConfig {
44 pub fn from_env() -> Result<Self, StorageLayerFactoryConfigError> {
46 s3::S3StorageLayerFactoryConfig::from_env()
47 .map(Self::S3)
48 .map_err(StorageLayerFactoryConfigError::S3)
49 }
50}
51
/// Factory producing per-tenant [`TenantStorageLayer`]s for the configured
/// storage backend.
#[derive(Clone)]
pub enum StorageLayerFactory {
    /// Factory for the S3 backend.
    S3(s3::S3StorageLayerFactory),
}
59
/// Errors returned by storage layer operations.
#[derive(Debug, Error)]
pub enum StorageLayerError {
    /// Error reported by the S3 backend; displays the inner error.
    ///
    /// Boxed, presumably to keep `Result<_, StorageLayerError>` small.
    #[error(transparent)]
    S3(Box<s3::S3StorageError>),

    /// A chunk of a [`FileStream`] failed to read while collecting it into
    /// memory (the underlying I/O error is logged, not stored).
    #[error("failed to collect file contents")]
    CollectBytes,
}
71
72impl From<s3::S3StorageError> for StorageLayerError {
73 fn from(value: s3::S3StorageError) -> Self {
74 Self::S3(Box::new(value))
75 }
76}
77
78impl StorageLayerFactory {
79 pub fn from_config(aws_config: &SdkConfig, config: StorageLayerFactoryConfig) -> Self {
81 match config {
82 StorageLayerFactoryConfig::S3(config) => {
83 Self::S3(s3::S3StorageLayerFactory::from_config(aws_config, config))
84 }
85 }
86 }
87
88 pub fn create_storage_layer(&self, tenant: &Tenant) -> TenantStorageLayer {
90 match self {
91 StorageLayerFactory::S3(s3) => {
92 let bucket_name = tenant.s3_name.clone();
93 let layer = s3.create_storage_layer(bucket_name);
94 TenantStorageLayer::S3(layer)
95 }
96 }
97 }
98}
99
/// Storage layer bound to a single tenant, as produced by
/// [`StorageLayerFactory::create_storage_layer`].
#[derive(Clone)]
pub enum TenantStorageLayer {
    /// Tenant storage backed by S3.
    S3(s3::S3StorageLayer),
}
106
/// Outcome of [`TenantStorageLayer::create_bucket`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CreateBucketOutcome {
    /// The bucket was newly created.
    New,
    /// The bucket already existed.
    Existing,
}
115
116impl TenantStorageLayer {
117 #[tracing::instrument(skip(self))]
122 pub async fn create_bucket(&self) -> Result<CreateBucketOutcome, StorageLayerError> {
123 match self {
124 TenantStorageLayer::S3(layer) => layer.create_bucket().await,
125 }
126 }
127
128 #[tracing::instrument(skip(self))]
130 pub async fn bucket_exists(&self) -> Result<bool, StorageLayerError> {
131 match self {
132 TenantStorageLayer::S3(layer) => layer.bucket_exists().await,
133 }
134 }
135
136 #[tracing::instrument(skip(self))]
141 pub async fn delete_bucket(&self) -> Result<(), StorageLayerError> {
142 match self {
143 TenantStorageLayer::S3(layer) => layer.delete_bucket().await,
144 }
145 }
146
147 #[tracing::instrument(skip(self))]
149 pub async fn create_presigned(
150 &self,
151 key: &str,
152 size: i64,
153 ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
154 match self {
155 TenantStorageLayer::S3(layer) => layer.create_presigned(key, size).await,
156 }
157 }
158
159 #[tracing::instrument(skip(self))]
164 pub async fn create_presigned_download(
165 &self,
166 key: &str,
167 expires_in: Duration,
168 ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
169 match self {
170 TenantStorageLayer::S3(layer) => layer.create_presigned_download(key, expires_in).await,
171 }
172 }
173
174 #[tracing::instrument(skip(self, body), fields(body_length = body.len()))]
176 pub async fn upload_file(
177 &self,
178 key: &str,
179 content_type: String,
180 body: Bytes,
181 ) -> Result<(), StorageLayerError> {
182 match self {
183 TenantStorageLayer::S3(layer) => layer.upload_file(key, content_type, body).await,
184 }
185 }
186
187 #[tracing::instrument(skip(self))]
189 pub async fn add_bucket_notifications(&self, sns_arn: &str) -> Result<(), StorageLayerError> {
190 match self {
191 TenantStorageLayer::S3(layer) => layer.add_bucket_notifications(sns_arn).await,
192 }
193 }
194
195 #[tracing::instrument(skip(self))]
197 pub async fn set_bucket_cors_origins(
198 &self,
199 origins: Vec<String>,
200 ) -> Result<(), StorageLayerError> {
201 match self {
202 TenantStorageLayer::S3(layer) => layer.set_bucket_cors_origins(origins).await,
203 }
204 }
205
206 #[tracing::instrument(skip(self))]
211 pub async fn delete_file(&self, key: &str) -> Result<(), StorageLayerError> {
212 match self {
213 TenantStorageLayer::S3(layer) => layer.delete_file(key).await,
214 }
215 }
216
217 #[tracing::instrument(skip(self))]
219 pub async fn get_file(&self, key: &str) -> Result<FileStream, StorageLayerError> {
220 match self {
221 TenantStorageLayer::S3(layer) => layer.get_file(key).await,
222 }
223 }
224}
225
/// Operations a storage backend provides.
///
/// The method set mirrors the public API of [`TenantStorageLayer`].
/// NOTE(review): presumably implemented by each backend layer (e.g.
/// `s3::S3StorageLayer`) — confirm in the backend modules.
pub(crate) trait StorageLayerImpl {
    /// Creates the bucket, reporting whether it was new or already existed.
    async fn create_bucket(&self) -> Result<CreateBucketOutcome, StorageLayerError>;

    /// Reports whether the bucket exists.
    async fn bucket_exists(&self) -> Result<bool, StorageLayerError>;

    /// Deletes the bucket.
    async fn delete_bucket(&self) -> Result<(), StorageLayerError>;

    /// Creates a presigned request for `key` for an object of `size` bytes,
    /// returned with a timestamp (presumably its expiry — confirm per backend).
    async fn create_presigned(
        &self,
        key: &str,
        size: i64,
    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError>;

    /// Creates a presigned download request for `key` valid for `expires_in`,
    /// returned with a timestamp (presumably its expiry — confirm per backend).
    async fn create_presigned_download(
        &self,
        key: &str,
        expires_in: Duration,
    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError>;

    /// Uploads `body` to `key` with the given `content_type`.
    async fn upload_file(
        &self,
        key: &str,
        content_type: String,
        body: Bytes,
    ) -> Result<(), StorageLayerError>;

    /// Configures bucket notifications targeting the SNS topic `sns_arn`.
    async fn add_bucket_notifications(&self, sns_arn: &str) -> Result<(), StorageLayerError>;

    /// Sets the bucket's CORS allowed origins to `origins`.
    async fn set_bucket_cors_origins(&self, origins: Vec<String>) -> Result<(), StorageLayerError>;

    /// Deletes the object stored at `key`.
    async fn delete_file(&self, key: &str) -> Result<(), StorageLayerError>;

    /// Opens the object at `key` as a [`FileStream`].
    async fn get_file(&self, key: &str) -> Result<FileStream, StorageLayerError>;
}
261
/// Stream of a stored file's contents, yielded as byte chunks.
pub struct FileStream {
    /// Underlying boxed stream; each item is a chunk of bytes or an I/O error.
    pub stream: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>,
}
267
268impl Debug for FileStream {
269 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
270 f.debug_struct("FileStream").finish()
271 }
272}
273
274impl Stream for FileStream {
275 type Item = std::io::Result<Bytes>;
276
277 fn poll_next(
278 mut self: std::pin::Pin<&mut Self>,
279 cx: &mut std::task::Context<'_>,
280 ) -> std::task::Poll<Option<Self::Item>> {
281 self.stream.as_mut().poll_next(cx)
282 }
283}
284
285impl FileStream {
286 pub async fn collect_bytes(mut self) -> Result<Bytes, StorageLayerError> {
288 let mut output = SegmentedBuf::new();
289
290 while let Some(result) = self.next().await {
291 let chunk = result.map_err(|error| {
292 tracing::error!(?error, "failed to collect file stream bytes");
293 StorageLayerError::CollectBytes
294 })?;
295
296 output.push(chunk);
297 }
298
299 Ok(output.copy_to_bytes(output.remaining()))
300 }
301}