// docbox_storage/lib.rs

1#![forbid(unsafe_code)]
2#![warn(missing_docs)]
3
4//! # Storage
5//!
6//! Docbox storage backend abstraction, handles abstracting the task of working with file
7//! storage to allow for multiple backends and easier testing.
8//!
9//! # Environment Variables
10//!
//! See [s3] — this is currently the only available backend for storage
12
13use aws_config::SdkConfig;
14use aws_sdk_s3::presigning::PresignedRequest;
15use bytes::{Buf, Bytes};
16use bytes_utils::SegmentedBuf;
17use chrono::{DateTime, Utc};
18use docbox_database::models::tenant::Tenant;
19use futures::{Stream, StreamExt};
20use serde::{Deserialize, Serialize};
21use std::{fmt::Debug, pin::Pin, time::Duration};
22use thiserror::Error;
23
24pub mod s3;
25
/// Configuration for a storage layer factory
///
/// Internally tagged on a `provider` field (snake_case variant names),
/// so the S3 variant is selected by `provider = "s3"`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "provider", rename_all = "snake_case")]
pub enum StorageLayerFactoryConfig {
    /// Config for a S3 backend
    S3(s3::S3StorageLayerFactoryConfig),
}
33
/// Errors that could occur when loading the storage layer factory
/// configuration from the environment
#[derive(Debug, Error)]
pub enum StorageLayerFactoryConfigError {
    /// Error from the S3 layer config
    ///
    /// `transparent` forwards the inner error's `Display` and `source`
    /// unchanged; `#[from]` allows `?` conversion from the S3 config error
    #[error(transparent)]
    S3(#[from] s3::S3StorageLayerFactoryConfigError),
}
42
43impl StorageLayerFactoryConfig {
44    /// Load the configuration from the current environment variables
45    pub fn from_env() -> Result<Self, StorageLayerFactoryConfigError> {
46        s3::S3StorageLayerFactoryConfig::from_env()
47            .map(Self::S3)
48            .map_err(StorageLayerFactoryConfigError::S3)
49    }
50}
51
/// Storage layer factory for creating storage layer instances
/// with some underlying backend implementation
///
/// Construct with [StorageLayerFactory::from_config], then create
/// per-tenant layers via [StorageLayerFactory::create_storage_layer]
#[derive(Clone)]
pub enum StorageLayerFactory {
    /// S3 storage backend
    S3(s3::S3StorageLayerFactory),
}
59
/// Errors that can occur when using a storage layer
#[derive(Debug, Error)]
pub enum StorageLayerError {
    /// Error from the S3 layer
    ///
    /// Boxed to keep this enum (and every `Result` that carries it)
    /// small; converted via the `From<s3::S3StorageError>` impl
    #[error(transparent)]
    S3(Box<s3::S3StorageError>),

    /// Error collecting streamed response bytes
    #[error("failed to collect file contents")]
    CollectBytes,
}
71
72impl From<s3::S3StorageError> for StorageLayerError {
73    fn from(value: s3::S3StorageError) -> Self {
74        Self::S3(Box::new(value))
75    }
76}
77
78impl StorageLayerFactory {
79    /// Create a [StorageLayerFactory] from the provided config
80    pub fn from_config(aws_config: &SdkConfig, config: StorageLayerFactoryConfig) -> Self {
81        match config {
82            StorageLayerFactoryConfig::S3(config) => {
83                Self::S3(s3::S3StorageLayerFactory::from_config(aws_config, config))
84            }
85        }
86    }
87
88    /// Create a new storage layer from the factory
89    pub fn create_storage_layer(&self, tenant: &Tenant) -> TenantStorageLayer {
90        self.create_storage_layer_bucket(tenant.s3_name.clone())
91    }
92
93    /// Create a new storage layer from a bucket name directly
94    pub fn create_storage_layer_bucket(&self, bucket_name: String) -> TenantStorageLayer {
95        match self {
96            StorageLayerFactory::S3(s3) => {
97                let layer = s3.create_storage_layer(bucket_name);
98                TenantStorageLayer::S3(layer)
99            }
100        }
101    }
102}
103
/// Storage layer for a tenant with different underlying backends
///
/// Created via [StorageLayerFactory::create_storage_layer]; each method
/// dispatches to the active backend variant
#[derive(Clone)]
pub enum TenantStorageLayer {
    /// Storage layer backed by S3
    S3(s3::S3StorageLayer),
}
110
/// Outcome from creating a bucket
///
/// Distinguishes a freshly created bucket from a pre-existing one,
/// since [TenantStorageLayer::create_bucket] treats both as success
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CreateBucketOutcome {
    /// Fresh bucket was created
    New,
    /// Bucket with the same name already exists
    Existing,
}
119
120impl TenantStorageLayer {
121    /// Get the name of the bucket
122    pub fn bucket_name(&self) -> String {
123        match self {
124            TenantStorageLayer::S3(layer) => layer.bucket_name(),
125        }
126    }
127
128    /// Creates the tenant storage bucket
129    ///
130    /// In the event that the bucket already exists, this is treated as a
131    /// [`Ok`] result rather than an error
132    #[tracing::instrument(skip(self))]
133    pub async fn create_bucket(&self) -> Result<CreateBucketOutcome, StorageLayerError> {
134        match self {
135            TenantStorageLayer::S3(layer) => layer.create_bucket().await,
136        }
137    }
138
139    /// Checks if the bucket exists
140    #[tracing::instrument(skip(self))]
141    pub async fn bucket_exists(&self) -> Result<bool, StorageLayerError> {
142        match self {
143            TenantStorageLayer::S3(layer) => layer.bucket_exists().await,
144        }
145    }
146
147    /// Deletes the tenant storage bucket
148    ///
149    /// In the event that the bucket did not exist before calling this
150    /// function this is treated as an [`Ok`] result
151    #[tracing::instrument(skip(self))]
152    pub async fn delete_bucket(&self) -> Result<(), StorageLayerError> {
153        match self {
154            TenantStorageLayer::S3(layer) => layer.delete_bucket().await,
155        }
156    }
157
158    /// Create a presigned file upload URL
159    #[tracing::instrument(skip(self))]
160    pub async fn create_presigned(
161        &self,
162        key: &str,
163        size: i64,
164    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
165        match self {
166            TenantStorageLayer::S3(layer) => layer.create_presigned(key, size).await,
167        }
168    }
169
170    /// Create a presigned file download URL
171    ///
172    /// Presigned download creation will succeed even if the requested key
173    /// is not present
174    #[tracing::instrument(skip(self))]
175    pub async fn create_presigned_download(
176        &self,
177        key: &str,
178        expires_in: Duration,
179    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
180        match self {
181            TenantStorageLayer::S3(layer) => layer.create_presigned_download(key, expires_in).await,
182        }
183    }
184
185    /// Uploads a file to the S3 bucket for the tenant
186    #[tracing::instrument(skip(self, body), fields(body_length = body.len()))]
187    pub async fn upload_file(
188        &self,
189        key: &str,
190        content_type: String,
191        body: Bytes,
192    ) -> Result<(), StorageLayerError> {
193        match self {
194            TenantStorageLayer::S3(layer) => layer.upload_file(key, content_type, body).await,
195        }
196    }
197
198    /// Add the SNS notification to a bucket
199    #[tracing::instrument(skip(self))]
200    pub async fn add_bucket_notifications(&self, sns_arn: &str) -> Result<(), StorageLayerError> {
201        match self {
202            TenantStorageLayer::S3(layer) => layer.add_bucket_notifications(sns_arn).await,
203        }
204    }
205
206    /// Sets the allowed CORS origins for accessing the storage from the frontend
207    #[tracing::instrument(skip(self))]
208    pub async fn set_bucket_cors_origins(
209        &self,
210        origins: Vec<String>,
211    ) -> Result<(), StorageLayerError> {
212        match self {
213            TenantStorageLayer::S3(layer) => layer.set_bucket_cors_origins(origins).await,
214        }
215    }
216
217    /// Deletes the file with the provided `key`
218    ///
219    /// In the event that the file did not exist before calling this
220    /// function this is treated as an [`Ok`] result
221    #[tracing::instrument(skip(self))]
222    pub async fn delete_file(&self, key: &str) -> Result<(), StorageLayerError> {
223        match self {
224            TenantStorageLayer::S3(layer) => layer.delete_file(key).await,
225        }
226    }
227
228    /// Gets a byte stream for a file from S3
229    #[tracing::instrument(skip(self))]
230    pub async fn get_file(&self, key: &str) -> Result<FileStream, StorageLayerError> {
231        match self {
232            TenantStorageLayer::S3(layer) => layer.get_file(key).await,
233        }
234    }
235}
236
/// Internal trait defining required async implementations for a storage backend
///
/// Method signatures mirror [TenantStorageLayer]'s public API one-to-one;
/// each backend (currently only [s3]) supplies these operations
pub(crate) trait StorageLayerImpl {
    /// Name of the bucket this layer operates on
    fn bucket_name(&self) -> String;

    /// Create the backing bucket, reporting whether it was newly
    /// created or already existed
    async fn create_bucket(&self) -> Result<CreateBucketOutcome, StorageLayerError>;

    /// Check whether the backing bucket exists
    async fn bucket_exists(&self) -> Result<bool, StorageLayerError>;

    /// Delete the backing bucket
    async fn delete_bucket(&self) -> Result<(), StorageLayerError>;

    /// Create a presigned upload URL for `key` expecting `size` bytes
    async fn create_presigned(
        &self,
        key: &str,
        size: i64,
    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError>;

    /// Create a presigned download URL for `key` valid for `expires_in`
    async fn create_presigned_download(
        &self,
        key: &str,
        expires_in: Duration,
    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError>;

    /// Upload `body` under `key` with the provided content type
    async fn upload_file(
        &self,
        key: &str,
        content_type: String,
        body: Bytes,
    ) -> Result<(), StorageLayerError>;

    /// Attach bucket event notifications targeting the given SNS ARN
    async fn add_bucket_notifications(&self, sns_arn: &str) -> Result<(), StorageLayerError>;

    /// Set the CORS origins allowed to access the bucket
    async fn set_bucket_cors_origins(&self, origins: Vec<String>) -> Result<(), StorageLayerError>;

    /// Delete the file stored at `key`
    async fn delete_file(&self, key: &str) -> Result<(), StorageLayerError>;

    /// Open a byte stream over the file stored at `key`
    async fn get_file(&self, key: &str) -> Result<FileStream, StorageLayerError>;
}
274
/// Stream of bytes from a file
///
/// Wraps a boxed dynamic stream so different backends can supply their
/// own byte sources behind a single concrete type
pub struct FileStream {
    /// Underlying stream of I/O byte chunks
    pub stream: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>,
}
280
281impl Debug for FileStream {
282    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283        f.debug_struct("FileStream").finish()
284    }
285}
286
287impl Stream for FileStream {
288    type Item = std::io::Result<Bytes>;
289
290    fn poll_next(
291        mut self: std::pin::Pin<&mut Self>,
292        cx: &mut std::task::Context<'_>,
293    ) -> std::task::Poll<Option<Self::Item>> {
294        self.stream.as_mut().poll_next(cx)
295    }
296}
297
298impl FileStream {
299    /// Collect the stream to completion as a single [Bytes] buffer
300    pub async fn collect_bytes(mut self) -> Result<Bytes, StorageLayerError> {
301        let mut output = SegmentedBuf::new();
302
303        while let Some(result) = self.next().await {
304            let chunk = result.map_err(|error| {
305                tracing::error!(?error, "failed to collect file stream bytes");
306                StorageLayerError::CollectBytes
307            })?;
308
309            output.push(chunk);
310        }
311
312        Ok(output.copy_to_bytes(output.remaining()))
313    }
314}