// docbox_storage/lib.rs

1#![forbid(unsafe_code)]
2#![warn(missing_docs)]
3
4//! # Storage
5//!
6//! Docbox storage backend abstraction, handles abstracting the task of working with file
7//! storage to allow for multiple backends and easier testing.
8//!
9//! # Environment Variables
10//!
//! See the [s3] module — currently the only available storage backend.
12
13use aws_config::SdkConfig;
14use aws_sdk_s3::presigning::PresignedRequest;
15use bytes::{Buf, Bytes};
16use bytes_utils::SegmentedBuf;
17use chrono::{DateTime, Utc};
18use docbox_database::models::tenant::Tenant;
19use futures::{Stream, StreamExt};
20use serde::{Deserialize, Serialize};
21use std::{fmt::Debug, pin::Pin, time::Duration};
22use thiserror::Error;
23
24pub mod s3;
25
/// Configuration for a storage layer factory
///
/// Serialized with an internal `provider` tag (snake_case), e.g.
/// `{ "provider": "s3", ... }`, so new backends can be added as variants
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "provider", rename_all = "snake_case")]
pub enum StorageLayerFactoryConfig {
    /// Config for a S3 backend
    S3(s3::S3StorageLayerFactoryConfig),
}
33
/// Errors that could occur when loading the storage layer factory
/// configuration from the environment
#[derive(Debug, Error)]
pub enum StorageLayerFactoryConfigError {
    /// Error from the S3 layer config
    ///
    /// `transparent` so the inner error's message is surfaced directly
    #[error(transparent)]
    S3(#[from] s3::S3StorageLayerFactoryConfigError),
}
42
43impl StorageLayerFactoryConfig {
44    /// Load the configuration from the current environment variables
45    pub fn from_env() -> Result<Self, StorageLayerFactoryConfigError> {
46        s3::S3StorageLayerFactoryConfig::from_env()
47            .map(Self::S3)
48            .map_err(StorageLayerFactoryConfigError::S3)
49    }
50}
51
/// Storage layer factory for creating storage layer instances
/// with some underlying backend implementation
///
/// Create per-tenant layers via
/// [StorageLayerFactory::create_storage_layer]
#[derive(Clone)]
pub enum StorageLayerFactory {
    /// S3 storage backend
    S3(s3::S3StorageLayerFactory),
}
59
/// Errors that can occur when using a storage layer
#[derive(Debug, Error)]
pub enum StorageLayerError {
    /// Error from the S3 layer
    ///
    /// Boxed to keep this enum (and every `Result` that carries it) small
    #[error(transparent)]
    S3(Box<s3::S3StorageError>),

    /// Error collecting streamed response bytes
    /// (see [FileStream::collect_bytes])
    #[error("failed to collect file contents")]
    CollectBytes,
}
71
72impl From<s3::S3StorageError> for StorageLayerError {
73    fn from(value: s3::S3StorageError) -> Self {
74        Self::S3(Box::new(value))
75    }
76}
77
78impl StorageLayerFactory {
79    /// Create a [StorageLayerFactory] from the provided config
80    pub fn from_config(aws_config: &SdkConfig, config: StorageLayerFactoryConfig) -> Self {
81        match config {
82            StorageLayerFactoryConfig::S3(config) => {
83                Self::S3(s3::S3StorageLayerFactory::from_config(aws_config, config))
84            }
85        }
86    }
87
88    /// Create a new storage layer from the factory
89    pub fn create_storage_layer(&self, tenant: &Tenant) -> TenantStorageLayer {
90        match self {
91            StorageLayerFactory::S3(s3) => {
92                let bucket_name = tenant.s3_name.clone();
93                let layer = s3.create_storage_layer(bucket_name);
94                TenantStorageLayer::S3(layer)
95            }
96        }
97    }
98}
99
/// Storage layer for a tenant with different underlying backends
///
/// Obtained from [StorageLayerFactory::create_storage_layer]; all file and
/// bucket operations for a tenant go through this type
#[derive(Clone)]
pub enum TenantStorageLayer {
    /// Storage layer backed by S3
    S3(s3::S3StorageLayer),
}
106
/// Outcome from creating a bucket
///
/// Both variants are success cases; an already-existing bucket is not
/// treated as an error by [TenantStorageLayer::create_bucket]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CreateBucketOutcome {
    /// Fresh bucket was created
    New,
    /// Bucket with the same name already exists
    Existing,
}
115
116impl TenantStorageLayer {
117    /// Creates the tenant storage bucket
118    ///
119    /// In the event that the bucket already exists, this is treated as a
120    /// [`Ok`] result rather than an error
121    #[tracing::instrument(skip(self))]
122    pub async fn create_bucket(&self) -> Result<CreateBucketOutcome, StorageLayerError> {
123        match self {
124            TenantStorageLayer::S3(layer) => layer.create_bucket().await,
125        }
126    }
127
128    /// Checks if the bucket exists
129    #[tracing::instrument(skip(self))]
130    pub async fn bucket_exists(&self) -> Result<bool, StorageLayerError> {
131        match self {
132            TenantStorageLayer::S3(layer) => layer.bucket_exists().await,
133        }
134    }
135
136    /// Deletes the tenant storage bucket
137    ///
138    /// In the event that the bucket did not exist before calling this
139    /// function this is treated as an [`Ok`] result
140    #[tracing::instrument(skip(self))]
141    pub async fn delete_bucket(&self) -> Result<(), StorageLayerError> {
142        match self {
143            TenantStorageLayer::S3(layer) => layer.delete_bucket().await,
144        }
145    }
146
147    /// Create a presigned file upload URL
148    #[tracing::instrument(skip(self))]
149    pub async fn create_presigned(
150        &self,
151        key: &str,
152        size: i64,
153    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
154        match self {
155            TenantStorageLayer::S3(layer) => layer.create_presigned(key, size).await,
156        }
157    }
158
159    /// Create a presigned file download URL
160    ///
161    /// Presigned download creation will succeed even if the requested key
162    /// is not present
163    #[tracing::instrument(skip(self))]
164    pub async fn create_presigned_download(
165        &self,
166        key: &str,
167        expires_in: Duration,
168    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
169        match self {
170            TenantStorageLayer::S3(layer) => layer.create_presigned_download(key, expires_in).await,
171        }
172    }
173
174    /// Uploads a file to the S3 bucket for the tenant
175    #[tracing::instrument(skip(self, body), fields(body_length = body.len()))]
176    pub async fn upload_file(
177        &self,
178        key: &str,
179        content_type: String,
180        body: Bytes,
181    ) -> Result<(), StorageLayerError> {
182        match self {
183            TenantStorageLayer::S3(layer) => layer.upload_file(key, content_type, body).await,
184        }
185    }
186
187    /// Add the SNS notification to a bucket
188    #[tracing::instrument(skip(self))]
189    pub async fn add_bucket_notifications(&self, sns_arn: &str) -> Result<(), StorageLayerError> {
190        match self {
191            TenantStorageLayer::S3(layer) => layer.add_bucket_notifications(sns_arn).await,
192        }
193    }
194
195    /// Sets the allowed CORS origins for accessing the storage from the frontend
196    #[tracing::instrument(skip(self))]
197    pub async fn set_bucket_cors_origins(
198        &self,
199        origins: Vec<String>,
200    ) -> Result<(), StorageLayerError> {
201        match self {
202            TenantStorageLayer::S3(layer) => layer.set_bucket_cors_origins(origins).await,
203        }
204    }
205
206    /// Deletes the file with the provided `key`
207    ///
208    /// In the event that the file did not exist before calling this
209    /// function this is treated as an [`Ok`] result
210    #[tracing::instrument(skip(self))]
211    pub async fn delete_file(&self, key: &str) -> Result<(), StorageLayerError> {
212        match self {
213            TenantStorageLayer::S3(layer) => layer.delete_file(key).await,
214        }
215    }
216
217    /// Gets a byte stream for a file from S3
218    #[tracing::instrument(skip(self))]
219    pub async fn get_file(&self, key: &str) -> Result<FileStream, StorageLayerError> {
220        match self {
221            TenantStorageLayer::S3(layer) => layer.get_file(key).await,
222        }
223    }
224}
225
/// Internal trait defining required async implementations for a storage backend
///
/// The public [TenantStorageLayer] methods delegate to these, so the two
/// sets of signatures must stay in sync
pub(crate) trait StorageLayerImpl {
    /// Create the storage bucket; an already-existing bucket is reported via
    /// [CreateBucketOutcome::Existing] rather than an error
    async fn create_bucket(&self) -> Result<CreateBucketOutcome, StorageLayerError>;

    /// Check whether the storage bucket exists
    async fn bucket_exists(&self) -> Result<bool, StorageLayerError>;

    /// Delete the storage bucket, succeeding if it did not exist
    async fn delete_bucket(&self) -> Result<(), StorageLayerError>;

    /// Create a presigned upload URL for `key` expecting `size` bytes,
    /// returning the request alongside its expiry timestamp
    async fn create_presigned(
        &self,
        key: &str,
        size: i64,
    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError>;

    /// Create a presigned download URL for `key` valid for `expires_in`,
    /// returning the request alongside its expiry timestamp
    async fn create_presigned_download(
        &self,
        key: &str,
        expires_in: Duration,
    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError>;

    /// Upload `body` to `key` with the provided `content_type`
    async fn upload_file(
        &self,
        key: &str,
        content_type: String,
        body: Bytes,
    ) -> Result<(), StorageLayerError>;

    /// Attach bucket event notifications publishing to the provided SNS ARN
    async fn add_bucket_notifications(&self, sns_arn: &str) -> Result<(), StorageLayerError>;

    /// Set the allowed CORS origins for the bucket
    async fn set_bucket_cors_origins(&self, origins: Vec<String>) -> Result<(), StorageLayerError>;

    /// Delete the file at `key`, succeeding if it did not exist
    async fn delete_file(&self, key: &str) -> Result<(), StorageLayerError>;

    /// Get a streaming handle to the file at `key`
    async fn get_file(&self, key: &str) -> Result<FileStream, StorageLayerError>;
}
261
/// Stream of bytes from a file
///
/// Wraps a boxed [Stream] of [Bytes] chunks so callers don't need to name
/// the backend-specific stream type
pub struct FileStream {
    /// Underlying stream
    pub stream: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>,
}
267
268impl Debug for FileStream {
269    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
270        f.debug_struct("FileStream").finish()
271    }
272}
273
274impl Stream for FileStream {
275    type Item = std::io::Result<Bytes>;
276
277    fn poll_next(
278        mut self: std::pin::Pin<&mut Self>,
279        cx: &mut std::task::Context<'_>,
280    ) -> std::task::Poll<Option<Self::Item>> {
281        self.stream.as_mut().poll_next(cx)
282    }
283}
284
285impl FileStream {
286    /// Collect the stream to completion as a single [Bytes] buffer
287    pub async fn collect_bytes(mut self) -> Result<Bytes, StorageLayerError> {
288        let mut output = SegmentedBuf::new();
289
290        while let Some(result) = self.next().await {
291            let chunk = result.map_err(|error| {
292                tracing::error!(?error, "failed to collect file stream bytes");
293                StorageLayerError::CollectBytes
294            })?;
295
296            output.push(chunk);
297        }
298
299        Ok(output.copy_to_bytes(output.remaining()))
300    }
301}