Skip to main content

docbox_storage/
lib.rs

1#![forbid(unsafe_code)]
2#![warn(missing_docs)]
3
4//! # Storage
5//!
6//! Docbox storage backend abstraction, handles abstracting the task of working with file
7//! storage to allow for multiple backends and easier testing.
8//!
9//! # Environment Variables
10//!
11//! See [s3]; this is currently the only available storage backend
12
13use aws_config::SdkConfig;
14use aws_sdk_s3::presigning::PresignedRequest;
15use bytes::{Buf, Bytes};
16use bytes_utils::SegmentedBuf;
17use chrono::{DateTime, Utc};
18use docbox_database::models::tenant::Tenant;
19use futures::{Stream, StreamExt};
20use serde::{Deserialize, Serialize};
21use std::{fmt::Debug, pin::Pin, time::Duration};
22use thiserror::Error;
23
24pub mod s3;
25
26/// Configuration for a storage layer factory
27#[derive(Debug, Clone, Deserialize, Serialize)]
28#[serde(tag = "provider", rename_all = "snake_case")]
29pub enum StorageLayerFactoryConfig {
30    /// Config for a S3 backend
31    S3(s3::S3StorageLayerFactoryConfig),
32}
33
34impl Default for StorageLayerFactoryConfig {
35    fn default() -> Self {
36        Self::S3(Default::default())
37    }
38}
39
40/// Errors that could occur when loading the storage layer factory
41/// configuration from the environment
42#[derive(Debug, Error)]
43pub enum StorageLayerFactoryConfigError {
44    /// Error from the S3 layer config
45    #[error(transparent)]
46    S3(#[from] s3::S3StorageLayerFactoryConfigError),
47}
48
49impl StorageLayerFactoryConfig {
50    /// Load the configuration from the current environment variables
51    pub fn from_env() -> Result<Self, StorageLayerFactoryConfigError> {
52        s3::S3StorageLayerFactoryConfig::from_env()
53            .map(Self::S3)
54            .map_err(StorageLayerFactoryConfigError::S3)
55    }
56}
57
58/// Storage layer factory for creating storage layer instances
59/// with some underlying backend implementation
60#[derive(Clone)]
61pub enum StorageLayerFactory {
62    /// S3 storage backend
63    S3(s3::S3StorageLayerFactory),
64}
65
66/// Errors that can occur when using a storage layer
67#[derive(Debug, Error)]
68pub enum StorageLayerError {
69    /// Error from the S3 layer
70    #[error(transparent)]
71    S3(Box<s3::S3StorageError>),
72
73    /// Error collecting streamed response bytes
74    #[error("failed to collect file contents")]
75    CollectBytes,
76}
77
78impl From<s3::S3StorageError> for StorageLayerError {
79    fn from(value: s3::S3StorageError) -> Self {
80        Self::S3(Box::new(value))
81    }
82}
83
84impl StorageLayerFactory {
85    /// Create a [StorageLayerFactory] from the provided config
86    pub fn from_config(aws_config: &SdkConfig, config: StorageLayerFactoryConfig) -> Self {
87        match config {
88            StorageLayerFactoryConfig::S3(config) => {
89                Self::S3(s3::S3StorageLayerFactory::from_config(aws_config, config))
90            }
91        }
92    }
93
94    /// Create a new storage layer from the factory
95    pub fn create_storage_layer(&self, tenant: &Tenant) -> TenantStorageLayer {
96        self.create_storage_layer_bucket(tenant.s3_name.clone())
97    }
98
99    /// Create a new storage layer from a bucket name directly
100    pub fn create_storage_layer_bucket(&self, bucket_name: String) -> TenantStorageLayer {
101        match self {
102            StorageLayerFactory::S3(s3) => {
103                let layer = s3.create_storage_layer(bucket_name);
104                TenantStorageLayer::S3(layer)
105            }
106        }
107    }
108}
109
110/// Storage layer for a tenant with different underlying backends
111#[derive(Clone)]
112pub enum TenantStorageLayer {
113    /// Storage layer backed by S3
114    S3(s3::S3StorageLayer),
115}
116
/// Result of a successful bucket creation request, distinguishing a fresh
/// bucket from one that was already present
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CreateBucketOutcome {
    /// The bucket did not exist and has now been created
    New,
    /// A bucket with the same name was already present
    Existing,
}
125
126impl TenantStorageLayer {
127    /// Get the name of the bucket
128    pub fn bucket_name(&self) -> String {
129        match self {
130            TenantStorageLayer::S3(layer) => layer.bucket_name(),
131        }
132    }
133
134    /// Creates the tenant storage bucket
135    ///
136    /// In the event that the bucket already exists, this is treated as a
137    /// [`Ok`] result rather than an error
138    #[tracing::instrument(skip(self))]
139    pub async fn create_bucket(&self) -> Result<CreateBucketOutcome, StorageLayerError> {
140        match self {
141            TenantStorageLayer::S3(layer) => layer.create_bucket().await,
142        }
143    }
144
145    /// Checks if the bucket exists
146    #[tracing::instrument(skip(self))]
147    pub async fn bucket_exists(&self) -> Result<bool, StorageLayerError> {
148        match self {
149            TenantStorageLayer::S3(layer) => layer.bucket_exists().await,
150        }
151    }
152
153    /// Deletes the tenant storage bucket
154    ///
155    /// In the event that the bucket did not exist before calling this
156    /// function this is treated as an [`Ok`] result
157    #[tracing::instrument(skip(self))]
158    pub async fn delete_bucket(&self) -> Result<(), StorageLayerError> {
159        match self {
160            TenantStorageLayer::S3(layer) => layer.delete_bucket().await,
161        }
162    }
163
164    /// Create a presigned file upload URL
165    #[tracing::instrument(skip(self))]
166    pub async fn create_presigned(
167        &self,
168        key: &str,
169        size: i64,
170    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
171        match self {
172            TenantStorageLayer::S3(layer) => layer.create_presigned(key, size).await,
173        }
174    }
175
176    /// Create a presigned file download URL
177    ///
178    /// Presigned download creation will succeed even if the requested key
179    /// is not present
180    #[tracing::instrument(skip(self))]
181    pub async fn create_presigned_download(
182        &self,
183        key: &str,
184        expires_in: Duration,
185    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
186        match self {
187            TenantStorageLayer::S3(layer) => layer.create_presigned_download(key, expires_in).await,
188        }
189    }
190
191    /// Uploads a file to the S3 bucket for the tenant
192    #[tracing::instrument(skip(self, body), fields(body_length = body.len()))]
193    pub async fn upload_file(
194        &self,
195        key: &str,
196        content_type: String,
197        body: Bytes,
198    ) -> Result<(), StorageLayerError> {
199        match self {
200            TenantStorageLayer::S3(layer) => layer.upload_file(key, content_type, body).await,
201        }
202    }
203
204    /// Add the SNS notification to a bucket
205    #[tracing::instrument(skip(self))]
206    pub async fn add_bucket_notifications(&self, sns_arn: &str) -> Result<(), StorageLayerError> {
207        match self {
208            TenantStorageLayer::S3(layer) => layer.add_bucket_notifications(sns_arn).await,
209        }
210    }
211
212    /// Sets the allowed CORS origins for accessing the storage from the frontend
213    #[tracing::instrument(skip(self))]
214    pub async fn set_bucket_cors_origins(
215        &self,
216        origins: Vec<String>,
217    ) -> Result<(), StorageLayerError> {
218        match self {
219            TenantStorageLayer::S3(layer) => layer.set_bucket_cors_origins(origins).await,
220        }
221    }
222
223    /// Deletes the file with the provided `key`
224    ///
225    /// In the event that the file did not exist before calling this
226    /// function this is treated as an [`Ok`] result
227    #[tracing::instrument(skip(self))]
228    pub async fn delete_file(&self, key: &str) -> Result<(), StorageLayerError> {
229        match self {
230            TenantStorageLayer::S3(layer) => layer.delete_file(key).await,
231        }
232    }
233
234    /// Gets a byte stream for a file from S3
235    #[tracing::instrument(skip(self))]
236    pub async fn get_file(&self, key: &str) -> Result<FileStream, StorageLayerError> {
237        match self {
238            TenantStorageLayer::S3(layer) => layer.get_file(key).await,
239        }
240    }
241}
242
243/// Internal trait defining required async implementations for a storage backend
244pub(crate) trait StorageLayerImpl {
245    fn bucket_name(&self) -> String;
246
247    async fn create_bucket(&self) -> Result<CreateBucketOutcome, StorageLayerError>;
248
249    async fn bucket_exists(&self) -> Result<bool, StorageLayerError>;
250
251    async fn delete_bucket(&self) -> Result<(), StorageLayerError>;
252
253    async fn create_presigned(
254        &self,
255        key: &str,
256        size: i64,
257    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError>;
258
259    async fn create_presigned_download(
260        &self,
261        key: &str,
262        expires_in: Duration,
263    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError>;
264
265    async fn upload_file(
266        &self,
267        key: &str,
268        content_type: String,
269        body: Bytes,
270    ) -> Result<(), StorageLayerError>;
271
272    async fn add_bucket_notifications(&self, sns_arn: &str) -> Result<(), StorageLayerError>;
273
274    async fn set_bucket_cors_origins(&self, origins: Vec<String>) -> Result<(), StorageLayerError>;
275
276    async fn delete_file(&self, key: &str) -> Result<(), StorageLayerError>;
277
278    async fn get_file(&self, key: &str) -> Result<FileStream, StorageLayerError>;
279}
280
281/// Stream of bytes from a file
282pub struct FileStream {
283    /// Underlying stream
284    pub stream: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>,
285}
286
287impl Debug for FileStream {
288    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289        f.debug_struct("FileStream").finish()
290    }
291}
292
293impl Stream for FileStream {
294    type Item = std::io::Result<Bytes>;
295
296    fn poll_next(
297        mut self: std::pin::Pin<&mut Self>,
298        cx: &mut std::task::Context<'_>,
299    ) -> std::task::Poll<Option<Self::Item>> {
300        self.stream.as_mut().poll_next(cx)
301    }
302}
303
304impl FileStream {
305    /// Collect the stream to completion as a single [Bytes] buffer
306    pub async fn collect_bytes(mut self) -> Result<Bytes, StorageLayerError> {
307        let mut output = SegmentedBuf::new();
308
309        while let Some(result) = self.next().await {
310            let chunk = result.map_err(|error| {
311                tracing::error!(?error, "failed to collect file stream bytes");
312                StorageLayerError::CollectBytes
313            })?;
314
315            output.push(chunk);
316        }
317
318        Ok(output.copy_to_bytes(output.remaining()))
319    }
320}