// docbox_storage/lib.rs
1#![forbid(unsafe_code)]
2#![warn(missing_docs)]
3
4//! # Storage
5//!
6//! Docbox storage backend abstraction, handles abstracting the task of working with file
7//! storage to allow for multiple backends and easier testing.
8//!
9//! # Environment Variables
10//!
//! See [s3] — this is currently the only available backend for storage.
12
13use aws_config::SdkConfig;
14use aws_sdk_s3::presigning::PresignedRequest;
15use bytes::{Buf, Bytes};
16use bytes_utils::SegmentedBuf;
17use chrono::{DateTime, Utc};
18use futures::{Stream, StreamExt};
19use serde::{Deserialize, Serialize};
20use std::{fmt::Debug, pin::Pin, time::Duration};
21use thiserror::Error;
22
23pub mod s3;
24
25/// Configuration for a storage layer factory
/// Configuration for a storage layer factory
///
/// Internally tagged for (de)serialization: the `provider` field selects the
/// backend, with variant names mapped to `snake_case` (e.g. `"provider": "s3"`)
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "provider", rename_all = "snake_case")]
pub enum StorageLayerFactoryConfig {
    /// Config for a S3 backend
    S3(s3::S3StorageLayerFactoryConfig),
}
32
33impl Default for StorageLayerFactoryConfig {
34    fn default() -> Self {
35        Self::S3(Default::default())
36    }
37}
38
/// Errors that could occur when loading the storage layer factory
/// configuration from the environment
#[derive(Debug, Error)]
pub enum StorageLayerFactoryConfigError {
    /// Error from the S3 layer config
    ///
    /// `#[from]` enables direct `?` conversion from the S3 config error
    #[error(transparent)]
    S3(#[from] s3::S3StorageLayerFactoryConfigError),
}
47
48impl StorageLayerFactoryConfig {
49    /// Load the configuration from the current environment variables
50    pub fn from_env() -> Result<Self, StorageLayerFactoryConfigError> {
51        s3::S3StorageLayerFactoryConfig::from_env()
52            .map(Self::S3)
53            .map_err(StorageLayerFactoryConfigError::S3)
54    }
55}
56
/// Storage layer factory for creating storage layer instances
/// with some underlying backend implementation
///
/// Constructed via [StorageLayerFactory::from_config]; produces
/// [StorageLayer] instances through [StorageLayerFactory::create_layer]
#[derive(Clone)]
pub enum StorageLayerFactory {
    /// S3 storage backend
    S3(s3::S3StorageLayerFactory),
}
64
/// Errors that can occur when using a storage layer
#[derive(Debug, Error)]
pub enum StorageLayerError {
    /// Error from the S3 layer
    ///
    /// Boxed so the size of [StorageLayerError] (and every `Result` carrying
    /// it) stays small; conversion is provided by the `From` impl below
    #[error(transparent)]
    S3(Box<s3::S3StorageError>),

    /// Error collecting streamed response bytes
    /// (produced by [FileStream::collect_bytes] when the underlying
    /// stream yields an IO error)
    #[error("failed to collect file contents")]
    CollectBytes,
}
76
77impl From<s3::S3StorageError> for StorageLayerError {
78    fn from(value: s3::S3StorageError) -> Self {
79        Self::S3(Box::new(value))
80    }
81}
82
/// Options required to initialize a storage layer from a [StorageLayerFactory]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StorageLayerOptions {
    /// Name of the storage bucket the created layer will operate on
    pub bucket_name: String,
}
89
90impl StorageLayerFactory {
91    /// Create a [StorageLayerFactory] from the provided config
92    pub fn from_config(aws_config: &SdkConfig, config: StorageLayerFactoryConfig) -> Self {
93        match config {
94            StorageLayerFactoryConfig::S3(config) => {
95                Self::S3(s3::S3StorageLayerFactory::from_config(aws_config, config))
96            }
97        }
98    }
99
100    /// Create a simple layer for testing purposes
101    #[cfg(debug_assertions)]
102    pub fn create_test_layer(&self) -> StorageLayer {
103        self.create_layer(StorageLayerOptions {
104            bucket_name: "test".to_string(),
105        })
106    }
107
108    /// Create a storage layer from the provided `options`
109    pub fn create_layer(&self, options: StorageLayerOptions) -> StorageLayer {
110        match self {
111            StorageLayerFactory::S3(s3) => {
112                let layer = s3.create_storage_layer(options.bucket_name);
113                StorageLayer::S3(layer)
114            }
115        }
116    }
117}
118
/// Storage layer for a tenant with different underlying backends
///
/// Created through [StorageLayerFactory::create_layer]
#[derive(Clone)]
pub enum StorageLayer {
    /// Storage layer backed by S3
    S3(s3::S3StorageLayer),
}
125
/// Outcome from creating a bucket
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CreateBucketOutcome {
    /// Fresh bucket was created
    New,
    /// Bucket with the same name already exists
    /// (treated as success by [StorageLayer::create_bucket])
    Existing,
}
134
/// Options for properties of an uploaded file
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct UploadFileOptions {
    /// Content type (MIME type) of the uploaded file
    pub content_type: String,
    /// Tags to append to the file; [None] applies no tags
    pub tags: Option<Vec<UploadFileTag>>,
}
143
/// Additional behavioral tags to use when uploading the file
///
/// Applied through [UploadFileOptions::tags]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UploadFileTag {
    /// Tag that the file should expire after 1 day
    ExpireDays1,
    /// Tag that the file should expire after 30 days
    ExpireDays30,
}
152
153impl StorageLayer {
154    /// Get the name of the bucket
155    pub fn bucket_name(&self) -> String {
156        match self {
157            StorageLayer::S3(layer) => layer.bucket_name(),
158        }
159    }
160
161    /// Creates the tenant storage bucket
162    ///
163    /// In the event that the bucket already exists, this is treated as a
164    /// [`Ok`] result rather than an error
165    #[tracing::instrument(skip(self))]
166    pub async fn create_bucket(&self) -> Result<CreateBucketOutcome, StorageLayerError> {
167        match self {
168            StorageLayer::S3(layer) => layer.create_bucket().await,
169        }
170    }
171
172    /// Checks if the bucket exists
173    #[tracing::instrument(skip(self))]
174    pub async fn bucket_exists(&self) -> Result<bool, StorageLayerError> {
175        match self {
176            StorageLayer::S3(layer) => layer.bucket_exists().await,
177        }
178    }
179
180    /// Deletes the tenant storage bucket
181    ///
182    /// In the event that the bucket did not exist before calling this
183    /// function this is treated as an [`Ok`] result
184    #[tracing::instrument(skip(self))]
185    pub async fn delete_bucket(&self) -> Result<(), StorageLayerError> {
186        match self {
187            StorageLayer::S3(layer) => layer.delete_bucket().await,
188        }
189    }
190
191    /// Create a presigned file upload URL
192    #[tracing::instrument(skip(self))]
193    pub async fn create_presigned(
194        &self,
195        key: &str,
196        size: i64,
197    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
198        match self {
199            StorageLayer::S3(layer) => layer.create_presigned(key, size).await,
200        }
201    }
202
203    /// Create a presigned file download URL
204    ///
205    /// Presigned download creation will succeed even if the requested key
206    /// is not present
207    #[tracing::instrument(skip(self))]
208    pub async fn create_presigned_download(
209        &self,
210        key: &str,
211        expires_in: Duration,
212    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
213        match self {
214            StorageLayer::S3(layer) => layer.create_presigned_download(key, expires_in).await,
215        }
216    }
217
218    /// Uploads a file to the S3 bucket for the tenant
219    #[tracing::instrument(skip(self, body), fields(body_length = body.len()))]
220    pub async fn upload_file(
221        &self,
222        key: &str,
223        body: Bytes,
224        options: UploadFileOptions,
225    ) -> Result<(), StorageLayerError> {
226        match self {
227            StorageLayer::S3(layer) => layer.upload_file(key, body, options).await,
228        }
229    }
230
231    /// Add the SNS notification to a bucket
232    #[tracing::instrument(skip(self))]
233    pub async fn add_bucket_notifications(&self, sns_arn: &str) -> Result<(), StorageLayerError> {
234        match self {
235            StorageLayer::S3(layer) => layer.add_bucket_notifications(sns_arn).await,
236        }
237    }
238
239    /// Sets the allowed CORS origins for accessing the storage from the frontend
240    #[tracing::instrument(skip(self))]
241    pub async fn set_bucket_cors_origins(
242        &self,
243        origins: Vec<String>,
244    ) -> Result<(), StorageLayerError> {
245        match self {
246            StorageLayer::S3(layer) => layer.set_bucket_cors_origins(origins).await,
247        }
248    }
249
250    /// Deletes the file with the provided `key`
251    ///
252    /// In the event that the file did not exist before calling this
253    /// function this is treated as an [`Ok`] result
254    #[tracing::instrument(skip(self))]
255    pub async fn delete_file(&self, key: &str) -> Result<(), StorageLayerError> {
256        match self {
257            StorageLayer::S3(layer) => layer.delete_file(key).await,
258        }
259    }
260
261    /// Gets a byte stream for a file from S3
262    #[tracing::instrument(skip(self))]
263    pub async fn get_file(&self, key: &str) -> Result<FileStream, StorageLayerError> {
264        match self {
265            StorageLayer::S3(layer) => layer.get_file(key).await,
266        }
267    }
268
269    /// Get pending migrations for the storage layer based on the list of already applied
270    /// migration names
271    #[tracing::instrument(skip(self))]
272    pub async fn get_pending_migrations(
273        &self,
274        applied_names: Vec<String>,
275    ) -> Result<Vec<String>, StorageLayerError> {
276        match self {
277            StorageLayer::S3(layer) => layer.get_pending_migrations(applied_names).await,
278        }
279    }
280
281    /// Apply a migration by name
282    #[tracing::instrument(skip(self))]
283    pub async fn apply_migration(&self, name: &str) -> Result<(), StorageLayerError> {
284        match self {
285            StorageLayer::S3(layer) => layer.apply_migration(name).await,
286        }
287    }
288}
289
/// Internal trait defining required async implementations for a storage backend
///
/// Mirrors the public [StorageLayer] methods; each backend provides these
/// and [StorageLayer] dispatches to them. See the [StorageLayer] methods
/// for the behavioral contracts (e.g. idempotent delete semantics).
pub(crate) trait StorageLayerImpl {
    /// Name of the storage bucket this layer operates on
    fn bucket_name(&self) -> String;

    /// Create the storage bucket, reporting whether it was newly
    /// created or already existed
    async fn create_bucket(&self) -> Result<CreateBucketOutcome, StorageLayerError>;

    /// Check whether the storage bucket exists
    async fn bucket_exists(&self) -> Result<bool, StorageLayerError>;

    /// Delete the storage bucket
    async fn delete_bucket(&self) -> Result<(), StorageLayerError>;

    /// Create a presigned upload URL for `key` expecting `size` bytes,
    /// returning the request and its expiry timestamp
    async fn create_presigned(
        &self,
        key: &str,
        size: i64,
    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError>;

    /// Create a presigned download URL for `key` valid for `expires_in`,
    /// returning the request and its expiry timestamp
    async fn create_presigned_download(
        &self,
        key: &str,
        expires_in: Duration,
    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError>;

    /// Upload `body` under `key` with the provided `options`
    async fn upload_file(
        &self,
        key: &str,
        body: Bytes,
        options: UploadFileOptions,
    ) -> Result<(), StorageLayerError>;

    /// Attach the SNS notification identified by `sns_arn` to the bucket
    async fn add_bucket_notifications(&self, sns_arn: &str) -> Result<(), StorageLayerError>;

    /// Set the allowed CORS origins on the bucket
    async fn set_bucket_cors_origins(&self, origins: Vec<String>) -> Result<(), StorageLayerError>;

    /// Delete the file stored under `key`
    async fn delete_file(&self, key: &str) -> Result<(), StorageLayerError>;

    /// Get a byte stream for the file stored under `key`
    async fn get_file(&self, key: &str) -> Result<FileStream, StorageLayerError>;

    /// Get pending migration names given the already applied migration names
    async fn get_pending_migrations(
        &self,
        applied_names: Vec<String>,
    ) -> Result<Vec<String>, StorageLayerError>;

    /// Apply a migration by name
    async fn apply_migration(&self, name: &str) -> Result<(), StorageLayerError>;
}
334
/// Stream of bytes from a file
///
/// Wraps a dynamically dispatched byte stream so that different backend
/// stream types can be returned behind one concrete type
pub struct FileStream {
    /// Underlying stream of IO results containing chunks of file bytes
    pub stream: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>,
}
340
341impl Debug for FileStream {
342    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343        f.debug_struct("FileStream").finish()
344    }
345}
346
impl Stream for FileStream {
    type Item = std::io::Result<Bytes>;

    /// Delegates polling straight to the boxed inner stream
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // `stream` is already heap-pinned (`Pin<Box<...>>`), so `as_mut`
        // re-borrows it as `Pin<&mut dyn Stream>` without any unsafe projection
        self.stream.as_mut().poll_next(cx)
    }
}
357
358impl FileStream {
359    /// Collect the stream to completion as a single [Bytes] buffer
360    pub async fn collect_bytes(mut self) -> Result<Bytes, StorageLayerError> {
361        let mut output = SegmentedBuf::new();
362
363        while let Some(result) = self.next().await {
364            let chunk = result.map_err(|error| {
365                tracing::error!(?error, "failed to collect file stream bytes");
366                StorageLayerError::CollectBytes
367            })?;
368
369            output.push(chunk);
370        }
371
372        Ok(output.copy_to_bytes(output.remaining()))
373    }
374}