1#![forbid(unsafe_code)]
2#![warn(missing_docs)]
3
4use aws_config::SdkConfig;
14use aws_sdk_s3::presigning::PresignedRequest;
15use bytes::{Buf, Bytes};
16use bytes_utils::SegmentedBuf;
17use chrono::{DateTime, Utc};
18use docbox_database::models::tenant::Tenant;
19use futures::{Stream, StreamExt};
20use serde::{Deserialize, Serialize};
21use std::{fmt::Debug, pin::Pin, time::Duration};
22use thiserror::Error;
23
24pub mod s3;
25
26#[derive(Debug, Clone, Deserialize, Serialize)]
28#[serde(tag = "provider", rename_all = "snake_case")]
29pub enum StorageLayerFactoryConfig {
30 S3(s3::S3StorageLayerFactoryConfig),
32}
33
34#[derive(Debug, Error)]
37pub enum StorageLayerFactoryConfigError {
38 #[error(transparent)]
40 S3(#[from] s3::S3StorageLayerFactoryConfigError),
41}
42
43impl StorageLayerFactoryConfig {
44 pub fn from_env() -> Result<Self, StorageLayerFactoryConfigError> {
46 s3::S3StorageLayerFactoryConfig::from_env()
47 .map(Self::S3)
48 .map_err(StorageLayerFactoryConfigError::S3)
49 }
50}
51
52#[derive(Clone)]
55pub enum StorageLayerFactory {
56 S3(s3::S3StorageLayerFactory),
58}
59
60#[derive(Debug, Error)]
62pub enum StorageLayerError {
63 #[error(transparent)]
65 S3(Box<s3::S3StorageError>),
66
67 #[error("failed to collect file contents")]
69 CollectBytes,
70}
71
72impl From<s3::S3StorageError> for StorageLayerError {
73 fn from(value: s3::S3StorageError) -> Self {
74 Self::S3(Box::new(value))
75 }
76}
77
78impl StorageLayerFactory {
79 pub fn from_config(aws_config: &SdkConfig, config: StorageLayerFactoryConfig) -> Self {
81 match config {
82 StorageLayerFactoryConfig::S3(config) => {
83 Self::S3(s3::S3StorageLayerFactory::from_config(aws_config, config))
84 }
85 }
86 }
87
88 pub fn create_storage_layer(&self, tenant: &Tenant) -> TenantStorageLayer {
90 self.create_storage_layer_bucket(tenant.s3_name.clone())
91 }
92
93 pub fn create_storage_layer_bucket(&self, bucket_name: String) -> TenantStorageLayer {
95 match self {
96 StorageLayerFactory::S3(s3) => {
97 let layer = s3.create_storage_layer(bucket_name);
98 TenantStorageLayer::S3(layer)
99 }
100 }
101 }
102}
103
104#[derive(Clone)]
106pub enum TenantStorageLayer {
107 S3(s3::S3StorageLayer),
109}
110
/// Outcome reported by a successful [`TenantStorageLayer::create_bucket`]
/// call.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CreateBucketOutcome {
    /// Presumably: the bucket did not exist and was newly created
    /// (NOTE(review): inferred from the name; confirm against the backend).
    New,
    /// Presumably: the bucket already existed
    /// (NOTE(review): inferred from the name; confirm against the backend).
    Existing,
}
119
120impl TenantStorageLayer {
121 pub fn bucket_name(&self) -> String {
123 match self {
124 TenantStorageLayer::S3(layer) => layer.bucket_name(),
125 }
126 }
127
128 #[tracing::instrument(skip(self))]
133 pub async fn create_bucket(&self) -> Result<CreateBucketOutcome, StorageLayerError> {
134 match self {
135 TenantStorageLayer::S3(layer) => layer.create_bucket().await,
136 }
137 }
138
139 #[tracing::instrument(skip(self))]
141 pub async fn bucket_exists(&self) -> Result<bool, StorageLayerError> {
142 match self {
143 TenantStorageLayer::S3(layer) => layer.bucket_exists().await,
144 }
145 }
146
147 #[tracing::instrument(skip(self))]
152 pub async fn delete_bucket(&self) -> Result<(), StorageLayerError> {
153 match self {
154 TenantStorageLayer::S3(layer) => layer.delete_bucket().await,
155 }
156 }
157
158 #[tracing::instrument(skip(self))]
160 pub async fn create_presigned(
161 &self,
162 key: &str,
163 size: i64,
164 ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
165 match self {
166 TenantStorageLayer::S3(layer) => layer.create_presigned(key, size).await,
167 }
168 }
169
170 #[tracing::instrument(skip(self))]
175 pub async fn create_presigned_download(
176 &self,
177 key: &str,
178 expires_in: Duration,
179 ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
180 match self {
181 TenantStorageLayer::S3(layer) => layer.create_presigned_download(key, expires_in).await,
182 }
183 }
184
185 #[tracing::instrument(skip(self, body), fields(body_length = body.len()))]
187 pub async fn upload_file(
188 &self,
189 key: &str,
190 content_type: String,
191 body: Bytes,
192 ) -> Result<(), StorageLayerError> {
193 match self {
194 TenantStorageLayer::S3(layer) => layer.upload_file(key, content_type, body).await,
195 }
196 }
197
198 #[tracing::instrument(skip(self))]
200 pub async fn add_bucket_notifications(&self, sns_arn: &str) -> Result<(), StorageLayerError> {
201 match self {
202 TenantStorageLayer::S3(layer) => layer.add_bucket_notifications(sns_arn).await,
203 }
204 }
205
206 #[tracing::instrument(skip(self))]
208 pub async fn set_bucket_cors_origins(
209 &self,
210 origins: Vec<String>,
211 ) -> Result<(), StorageLayerError> {
212 match self {
213 TenantStorageLayer::S3(layer) => layer.set_bucket_cors_origins(origins).await,
214 }
215 }
216
217 #[tracing::instrument(skip(self))]
222 pub async fn delete_file(&self, key: &str) -> Result<(), StorageLayerError> {
223 match self {
224 TenantStorageLayer::S3(layer) => layer.delete_file(key).await,
225 }
226 }
227
228 #[tracing::instrument(skip(self))]
230 pub async fn get_file(&self, key: &str) -> Result<FileStream, StorageLayerError> {
231 match self {
232 TenantStorageLayer::S3(layer) => layer.get_file(key).await,
233 }
234 }
235}
236
237pub(crate) trait StorageLayerImpl {
239 fn bucket_name(&self) -> String;
240
241 async fn create_bucket(&self) -> Result<CreateBucketOutcome, StorageLayerError>;
242
243 async fn bucket_exists(&self) -> Result<bool, StorageLayerError>;
244
245 async fn delete_bucket(&self) -> Result<(), StorageLayerError>;
246
247 async fn create_presigned(
248 &self,
249 key: &str,
250 size: i64,
251 ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError>;
252
253 async fn create_presigned_download(
254 &self,
255 key: &str,
256 expires_in: Duration,
257 ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError>;
258
259 async fn upload_file(
260 &self,
261 key: &str,
262 content_type: String,
263 body: Bytes,
264 ) -> Result<(), StorageLayerError>;
265
266 async fn add_bucket_notifications(&self, sns_arn: &str) -> Result<(), StorageLayerError>;
267
268 async fn set_bucket_cors_origins(&self, origins: Vec<String>) -> Result<(), StorageLayerError>;
269
270 async fn delete_file(&self, key: &str) -> Result<(), StorageLayerError>;
271
272 async fn get_file(&self, key: &str) -> Result<FileStream, StorageLayerError>;
273}
274
275pub struct FileStream {
277 pub stream: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>,
279}
280
281impl Debug for FileStream {
282 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283 f.debug_struct("FileStream").finish()
284 }
285}
286
287impl Stream for FileStream {
288 type Item = std::io::Result<Bytes>;
289
290 fn poll_next(
291 mut self: std::pin::Pin<&mut Self>,
292 cx: &mut std::task::Context<'_>,
293 ) -> std::task::Poll<Option<Self::Item>> {
294 self.stream.as_mut().poll_next(cx)
295 }
296}
297
298impl FileStream {
299 pub async fn collect_bytes(mut self) -> Result<Bytes, StorageLayerError> {
301 let mut output = SegmentedBuf::new();
302
303 while let Some(result) = self.next().await {
304 let chunk = result.map_err(|error| {
305 tracing::error!(?error, "failed to collect file stream bytes");
306 StorageLayerError::CollectBytes
307 })?;
308
309 output.push(chunk);
310 }
311
312 Ok(output.copy_to_bytes(output.remaining()))
313 }
314}