Skip to main content

docbox_storage/
lib.rs

1#![forbid(unsafe_code)]
2#![warn(missing_docs)]
3
4//! # Storage
5//!
6//! Docbox storage backend abstraction, handles abstracting the task of working with file
7//! storage to allow for multiple backends and easier testing.
8//!
9//! # Environment Variables
10//!
//! See the [`s3`] module for the environment variables it reads; S3 is currently the only
//! available storage backend.
12
13use aws_config::SdkConfig;
14use aws_sdk_s3::presigning::PresignedRequest;
15use bytes::{Buf, Bytes};
16use bytes_utils::SegmentedBuf;
17use chrono::{DateTime, Utc};
18use futures::{Stream, StreamExt};
19use serde::{Deserialize, Serialize};
20use std::{fmt::Debug, pin::Pin, time::Duration};
21use thiserror::Error;
22
23pub mod s3;
24
25/// Configuration for a storage layer factory
26#[derive(Debug, Clone, Deserialize, Serialize)]
27#[serde(tag = "provider", rename_all = "snake_case")]
28pub enum StorageLayerFactoryConfig {
29    /// Config for a S3 backend
30    S3(s3::S3StorageLayerFactoryConfig),
31}
32
33impl Default for StorageLayerFactoryConfig {
34    fn default() -> Self {
35        Self::S3(Default::default())
36    }
37}
38
39/// Errors that could occur when loading the storage layer factory
40/// configuration from the environment
41#[derive(Debug, Error)]
42pub enum StorageLayerFactoryConfigError {
43    /// Error from the S3 layer config
44    #[error(transparent)]
45    S3(#[from] s3::S3StorageLayerFactoryConfigError),
46}
47
48impl StorageLayerFactoryConfig {
49    /// Load the configuration from the current environment variables
50    pub fn from_env() -> Result<Self, StorageLayerFactoryConfigError> {
51        s3::S3StorageLayerFactoryConfig::from_env()
52            .map(Self::S3)
53            .map_err(StorageLayerFactoryConfigError::S3)
54    }
55}
56
57/// Storage layer factory for creating storage layer instances
58/// with some underlying backend implementation
59#[derive(Clone)]
60pub enum StorageLayerFactory {
61    /// S3 storage backend
62    S3(s3::S3StorageLayerFactory),
63}
64
65/// Errors that can occur when using a storage layer
66#[derive(Debug, Error)]
67pub enum StorageLayerError {
68    /// Error from the S3 layer
69    #[error(transparent)]
70    S3(Box<s3::S3StorageError>),
71
72    /// Error collecting streamed response bytes
73    #[error("failed to collect file contents")]
74    CollectBytes,
75}
76
77impl From<s3::S3StorageError> for StorageLayerError {
78    fn from(value: s3::S3StorageError) -> Self {
79        Self::S3(Box::new(value))
80    }
81}
82
/// Options passed to [`StorageLayerFactory::create_layer`] to build a
/// [`StorageLayer`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct StorageLayerOptions {
    /// Name of the bucket the created storage layer operates on
    pub bucket_name: String,
}
89
90impl StorageLayerFactory {
91    /// Create a [StorageLayerFactory] from the provided config
92    pub fn from_config(aws_config: &SdkConfig, config: StorageLayerFactoryConfig) -> Self {
93        match config {
94            StorageLayerFactoryConfig::S3(config) => {
95                Self::S3(s3::S3StorageLayerFactory::from_config(aws_config, config))
96            }
97        }
98    }
99
100    /// Create a simple layer for testing purposes
101    #[cfg(debug_assertions)]
102    pub fn create_test_layer(&self) -> StorageLayer {
103        self.create_layer(StorageLayerOptions {
104            bucket_name: "test".to_string(),
105        })
106    }
107
108    /// Create a storage layer from the provided `options`
109    pub fn create_layer(&self, options: StorageLayerOptions) -> StorageLayer {
110        match self {
111            StorageLayerFactory::S3(s3) => {
112                let layer = s3.create_storage_layer(options.bucket_name);
113                StorageLayer::S3(layer)
114            }
115        }
116    }
117}
118
119/// Storage layer for a tenant with different underlying backends
120#[derive(Clone)]
121pub enum StorageLayer {
122    /// Storage layer backed by S3
123    S3(s3::S3StorageLayer),
124}
125
/// Result of a successful [`StorageLayer::create_bucket`] call.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CreateBucketOutcome {
    /// A new bucket was created
    New,
    /// A bucket with the requested name already existed
    Existing,
}
134
135impl StorageLayer {
136    /// Get the name of the bucket
137    pub fn bucket_name(&self) -> String {
138        match self {
139            StorageLayer::S3(layer) => layer.bucket_name(),
140        }
141    }
142
143    /// Creates the tenant storage bucket
144    ///
145    /// In the event that the bucket already exists, this is treated as a
146    /// [`Ok`] result rather than an error
147    #[tracing::instrument(skip(self))]
148    pub async fn create_bucket(&self) -> Result<CreateBucketOutcome, StorageLayerError> {
149        match self {
150            StorageLayer::S3(layer) => layer.create_bucket().await,
151        }
152    }
153
154    /// Checks if the bucket exists
155    #[tracing::instrument(skip(self))]
156    pub async fn bucket_exists(&self) -> Result<bool, StorageLayerError> {
157        match self {
158            StorageLayer::S3(layer) => layer.bucket_exists().await,
159        }
160    }
161
162    /// Deletes the tenant storage bucket
163    ///
164    /// In the event that the bucket did not exist before calling this
165    /// function this is treated as an [`Ok`] result
166    #[tracing::instrument(skip(self))]
167    pub async fn delete_bucket(&self) -> Result<(), StorageLayerError> {
168        match self {
169            StorageLayer::S3(layer) => layer.delete_bucket().await,
170        }
171    }
172
173    /// Create a presigned file upload URL
174    #[tracing::instrument(skip(self))]
175    pub async fn create_presigned(
176        &self,
177        key: &str,
178        size: i64,
179    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
180        match self {
181            StorageLayer::S3(layer) => layer.create_presigned(key, size).await,
182        }
183    }
184
185    /// Create a presigned file download URL
186    ///
187    /// Presigned download creation will succeed even if the requested key
188    /// is not present
189    #[tracing::instrument(skip(self))]
190    pub async fn create_presigned_download(
191        &self,
192        key: &str,
193        expires_in: Duration,
194    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
195        match self {
196            StorageLayer::S3(layer) => layer.create_presigned_download(key, expires_in).await,
197        }
198    }
199
200    /// Uploads a file to the S3 bucket for the tenant
201    #[tracing::instrument(skip(self, body), fields(body_length = body.len()))]
202    pub async fn upload_file(
203        &self,
204        key: &str,
205        content_type: String,
206        body: Bytes,
207    ) -> Result<(), StorageLayerError> {
208        match self {
209            StorageLayer::S3(layer) => layer.upload_file(key, content_type, body).await,
210        }
211    }
212
213    /// Add the SNS notification to a bucket
214    #[tracing::instrument(skip(self))]
215    pub async fn add_bucket_notifications(&self, sns_arn: &str) -> Result<(), StorageLayerError> {
216        match self {
217            StorageLayer::S3(layer) => layer.add_bucket_notifications(sns_arn).await,
218        }
219    }
220
221    /// Sets the allowed CORS origins for accessing the storage from the frontend
222    #[tracing::instrument(skip(self))]
223    pub async fn set_bucket_cors_origins(
224        &self,
225        origins: Vec<String>,
226    ) -> Result<(), StorageLayerError> {
227        match self {
228            StorageLayer::S3(layer) => layer.set_bucket_cors_origins(origins).await,
229        }
230    }
231
232    /// Deletes the file with the provided `key`
233    ///
234    /// In the event that the file did not exist before calling this
235    /// function this is treated as an [`Ok`] result
236    #[tracing::instrument(skip(self))]
237    pub async fn delete_file(&self, key: &str) -> Result<(), StorageLayerError> {
238        match self {
239            StorageLayer::S3(layer) => layer.delete_file(key).await,
240        }
241    }
242
243    /// Gets a byte stream for a file from S3
244    #[tracing::instrument(skip(self))]
245    pub async fn get_file(&self, key: &str) -> Result<FileStream, StorageLayerError> {
246        match self {
247            StorageLayer::S3(layer) => layer.get_file(key).await,
248        }
249    }
250}
251
252/// Internal trait defining required async implementations for a storage backend
253pub(crate) trait StorageLayerImpl {
254    fn bucket_name(&self) -> String;
255
256    async fn create_bucket(&self) -> Result<CreateBucketOutcome, StorageLayerError>;
257
258    async fn bucket_exists(&self) -> Result<bool, StorageLayerError>;
259
260    async fn delete_bucket(&self) -> Result<(), StorageLayerError>;
261
262    async fn create_presigned(
263        &self,
264        key: &str,
265        size: i64,
266    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError>;
267
268    async fn create_presigned_download(
269        &self,
270        key: &str,
271        expires_in: Duration,
272    ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError>;
273
274    async fn upload_file(
275        &self,
276        key: &str,
277        content_type: String,
278        body: Bytes,
279    ) -> Result<(), StorageLayerError>;
280
281    async fn add_bucket_notifications(&self, sns_arn: &str) -> Result<(), StorageLayerError>;
282
283    async fn set_bucket_cors_origins(&self, origins: Vec<String>) -> Result<(), StorageLayerError>;
284
285    async fn delete_file(&self, key: &str) -> Result<(), StorageLayerError>;
286
287    async fn get_file(&self, key: &str) -> Result<FileStream, StorageLayerError>;
288}
289
290/// Stream of bytes from a file
291pub struct FileStream {
292    /// Underlying stream
293    pub stream: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>,
294}
295
296impl Debug for FileStream {
297    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298        f.debug_struct("FileStream").finish()
299    }
300}
301
302impl Stream for FileStream {
303    type Item = std::io::Result<Bytes>;
304
305    fn poll_next(
306        mut self: std::pin::Pin<&mut Self>,
307        cx: &mut std::task::Context<'_>,
308    ) -> std::task::Poll<Option<Self::Item>> {
309        self.stream.as_mut().poll_next(cx)
310    }
311}
312
313impl FileStream {
314    /// Collect the stream to completion as a single [Bytes] buffer
315    pub async fn collect_bytes(mut self) -> Result<Bytes, StorageLayerError> {
316        let mut output = SegmentedBuf::new();
317
318        while let Some(result) = self.next().await {
319            let chunk = result.map_err(|error| {
320                tracing::error!(?error, "failed to collect file stream bytes");
321                StorageLayerError::CollectBytes
322            })?;
323
324            output.push(chunk);
325        }
326
327        Ok(output.copy_to_bytes(output.remaining()))
328    }
329}