docbox_storage/
lib.rs

1use aws_config::SdkConfig;
2use aws_sdk_s3::presigning::PresignedRequest;
3use bytes::{Buf, Bytes};
4use bytes_utils::SegmentedBuf;
5use chrono::{DateTime, Utc};
6use docbox_database::models::tenant::Tenant;
7use futures::{Stream, StreamExt};
8use serde::{Deserialize, Serialize};
9use std::{pin::Pin, time::Duration};
10
11pub mod s3;
12
/// Configuration selecting and configuring the storage backend.
///
/// Serialized as an internally tagged enum: the `provider` field names the
/// variant in snake_case (e.g. `"provider": "s3"`). The variant's own
/// configuration fields sit alongside that tag.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "provider", rename_all = "snake_case")]
pub enum StorageLayerFactoryConfig {
    /// Configuration for the S3 backed storage layer.
    S3(s3::S3StorageLayerFactoryConfig),
}
18
19impl StorageLayerFactoryConfig {
20    pub fn from_env() -> anyhow::Result<Self> {
21        s3::S3StorageLayerFactoryConfig::from_env().map(Self::S3)
22    }
23}
24
/// Factory producing a [`TenantStorageLayer`] for a given tenant, using the
/// configured storage backend.
#[derive(Clone)]
pub enum StorageLayerFactory {
    /// Factory for S3 backed storage layers.
    S3(s3::S3StorageLayerFactory),
}
29
30impl StorageLayerFactory {
31    pub fn from_config(aws_config: &SdkConfig, config: StorageLayerFactoryConfig) -> Self {
32        match config {
33            StorageLayerFactoryConfig::S3(config) => {
34                Self::S3(s3::S3StorageLayerFactory::from_config(aws_config, config))
35            }
36        }
37    }
38
39    pub fn create_storage_layer(&self, tenant: &Tenant) -> TenantStorageLayer {
40        match self {
41            StorageLayerFactory::S3(s3) => {
42                let bucket_name = tenant.s3_name.clone();
43                let layer = s3.create_storage_layer(bucket_name);
44                TenantStorageLayer::S3(layer)
45            }
46        }
47    }
48}
49
/// Storage layer bound to a single tenant, as produced by
/// [`StorageLayerFactory::create_storage_layer`]. Each method dispatches to the
/// concrete backend variant.
#[derive(Clone)]
pub enum TenantStorageLayer {
    /// Storage layer backed by S3
    S3(s3::S3StorageLayer),
}
55
56impl TenantStorageLayer {
57    /// Creates the tenant S3 bucket
58    pub async fn create_bucket(&self) -> anyhow::Result<()> {
59        match self {
60            TenantStorageLayer::S3(layer) => layer.create_bucket().await,
61        }
62    }
63
64    /// Deletes the tenant S3 bucket
65    pub async fn delete_bucket(&self) -> anyhow::Result<()> {
66        match self {
67            TenantStorageLayer::S3(layer) => layer.delete_bucket().await,
68        }
69    }
70
71    /// Create a presigned file upload URL
72    pub async fn create_presigned(
73        &self,
74        key: &str,
75        size: i64,
76    ) -> anyhow::Result<(PresignedRequest, DateTime<Utc>)> {
77        match self {
78            TenantStorageLayer::S3(layer) => layer.create_presigned(key, size).await,
79        }
80    }
81
82    /// Create a presigned file download URL
83    pub async fn create_presigned_download(
84        &self,
85        key: &str,
86        expires_in: Duration,
87    ) -> anyhow::Result<(PresignedRequest, DateTime<Utc>)> {
88        match self {
89            TenantStorageLayer::S3(layer) => layer.create_presigned_download(key, expires_in).await,
90        }
91    }
92
93    /// Uploads a file to the S3 bucket for the tenant
94    pub async fn upload_file(
95        &self,
96        key: &str,
97        content_type: String,
98        body: Bytes,
99    ) -> anyhow::Result<()> {
100        match self {
101            TenantStorageLayer::S3(layer) => layer.upload_file(key, content_type, body).await,
102        }
103    }
104
105    /// Add the SNS notification to a bucket
106    pub async fn add_bucket_notifications(&self, sns_arn: &str) -> anyhow::Result<()> {
107        match self {
108            TenantStorageLayer::S3(layer) => layer.add_bucket_notifications(sns_arn).await,
109        }
110    }
111
112    /// Applies CORS rules for a bucket
113    pub async fn add_bucket_cors(&self, origins: Vec<String>) -> anyhow::Result<()> {
114        match self {
115            TenantStorageLayer::S3(layer) => layer.add_bucket_cors(origins).await,
116        }
117    }
118
119    /// Deletes the S3 file
120    pub async fn delete_file(&self, key: &str) -> anyhow::Result<()> {
121        match self {
122            TenantStorageLayer::S3(layer) => layer.delete_file(key).await,
123        }
124    }
125
126    /// Gets a byte stream for a file from S3
127    pub async fn get_file(&self, key: &str) -> anyhow::Result<FileStream> {
128        match self {
129            TenantStorageLayer::S3(layer) => layer.get_file(key).await,
130        }
131    }
132}
133
/// Internal trait defining required async implementations for a storage backend
///
/// Crate-private (`pub(crate)`). The public [`TenantStorageLayer`] exposes
/// methods with the same names and signatures, dispatching to the backend.
pub(crate) trait StorageLayer {
    /// Creates the tenant's storage bucket
    async fn create_bucket(&self) -> anyhow::Result<()>;

    /// Deletes the tenant's storage bucket
    async fn delete_bucket(&self) -> anyhow::Result<()>;

    /// Create a presigned file upload URL for `key`
    ///
    /// The returned timestamp presumably marks the request's expiry
    /// (NOTE(review): confirm in implementations).
    async fn create_presigned(
        &self,
        key: &str,
        size: i64,
    ) -> anyhow::Result<(PresignedRequest, DateTime<Utc>)>;

    /// Create a presigned file download URL for `key`, requested for `expires_in`
    async fn create_presigned_download(
        &self,
        key: &str,
        expires_in: Duration,
    ) -> anyhow::Result<(PresignedRequest, DateTime<Utc>)>;

    /// Uploads `body` to `key` in the tenant's bucket with the given `content_type`
    async fn upload_file(&self, key: &str, content_type: String, body: Bytes)
    -> anyhow::Result<()>;

    /// Add the SNS notification identified by `sns_arn` to the bucket
    async fn add_bucket_notifications(&self, sns_arn: &str) -> anyhow::Result<()>;

    /// Applies CORS rules allowing `origins` on the bucket
    async fn add_bucket_cors(&self, origins: Vec<String>) -> anyhow::Result<()>;

    /// Deletes the file stored at `key`
    async fn delete_file(&self, key: &str) -> anyhow::Result<()>;

    /// Gets a byte stream for the file stored at `key`
    async fn get_file(&self, key: &str) -> anyhow::Result<FileStream>;
}
172
/// Stream of bytes from a file
///
/// Yields the file contents as a sequence of [Bytes] chunks, where any item may
/// instead be an I/O error. The stream is boxed and `Send`. Because `Pin<Box<_>>`
/// is `Unpin`, `FileStream` is `Unpin` too.
pub struct FileStream {
    /// The underlying chunk stream
    pub stream: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>,
}
177
178impl Stream for FileStream {
179    type Item = std::io::Result<Bytes>;
180
181    fn poll_next(
182        mut self: std::pin::Pin<&mut Self>,
183        cx: &mut std::task::Context<'_>,
184    ) -> std::task::Poll<Option<Self::Item>> {
185        // Pin projection to the underlying stream
186        let stream = unsafe { self.as_mut().map_unchecked_mut(|s| &mut s.stream) };
187        stream.poll_next(cx)
188    }
189}
190
191impl FileStream {
192    /// Collect the stream to completion as a single [Bytes] buffer
193    pub async fn collect_bytes(mut self) -> anyhow::Result<Bytes> {
194        let mut output = SegmentedBuf::new();
195
196        while let Some(result) = self.next().await {
197            let chunk = result?;
198            output.push(chunk);
199        }
200
201        Ok(output.copy_to_bytes(output.remaining()))
202    }
203}