docbox_core/storage/
mod.rs

1use aws_config::SdkConfig;
2use aws_sdk_s3::presigning::PresignedRequest;
3use bytes::{Buf, Bytes};
4use bytes_utils::SegmentedBuf;
5use chrono::{DateTime, Utc};
6use docbox_database::models::tenant::Tenant;
7use futures::{Stream, StreamExt};
8use s3::{S3StorageLayer, S3StorageLayerFactory};
9use serde::{Deserialize, Serialize};
10use std::{pin::Pin, time::Duration};
11
12use crate::storage::s3::S3StorageLayerFactoryConfig;
13
14pub mod s3;
15
16#[derive(Debug, Clone, Deserialize, Serialize)]
17#[serde(tag = "provider", rename_all = "snake_case")]
18pub enum StorageLayerFactoryConfig {
19    S3(S3StorageLayerFactoryConfig),
20}
21
22impl StorageLayerFactoryConfig {
23    pub fn from_env() -> anyhow::Result<Self> {
24        S3StorageLayerFactoryConfig::from_env().map(Self::S3)
25    }
26}
27
28#[derive(Clone)]
29pub enum StorageLayerFactory {
30    S3(S3StorageLayerFactory),
31}
32
33impl StorageLayerFactory {
34    pub fn from_config(aws_config: &SdkConfig, config: StorageLayerFactoryConfig) -> Self {
35        match config {
36            StorageLayerFactoryConfig::S3(config) => {
37                Self::S3(S3StorageLayerFactory::from_config(aws_config, config))
38            }
39        }
40    }
41
42    pub fn create_storage_layer(&self, tenant: &Tenant) -> TenantStorageLayer {
43        match self {
44            StorageLayerFactory::S3(s3) => {
45                let bucket_name = tenant.s3_name.clone();
46                let layer = s3.create_storage_layer(bucket_name);
47                TenantStorageLayer::S3(layer)
48            }
49        }
50    }
51}
52
53#[derive(Clone)]
54pub enum TenantStorageLayer {
55    /// Storage layer backed by S3
56    S3(S3StorageLayer),
57}
58
59impl TenantStorageLayer {
60    /// Creates the tenant S3 bucket
61    pub async fn create_bucket(&self) -> anyhow::Result<()> {
62        match self {
63            TenantStorageLayer::S3(layer) => layer.create_bucket().await,
64        }
65    }
66
67    /// Deletes the tenant S3 bucket
68    pub async fn delete_bucket(&self) -> anyhow::Result<()> {
69        match self {
70            TenantStorageLayer::S3(layer) => layer.delete_bucket().await,
71        }
72    }
73
74    /// Create a presigned file upload URL
75    pub async fn create_presigned(
76        &self,
77        key: &str,
78        size: i64,
79    ) -> anyhow::Result<(PresignedRequest, DateTime<Utc>)> {
80        match self {
81            TenantStorageLayer::S3(layer) => layer.create_presigned(key, size).await,
82        }
83    }
84
85    /// Create a presigned file download URL
86    pub async fn create_presigned_download(
87        &self,
88        key: &str,
89        expires_in: Duration,
90    ) -> anyhow::Result<(PresignedRequest, DateTime<Utc>)> {
91        match self {
92            TenantStorageLayer::S3(layer) => layer.create_presigned_download(key, expires_in).await,
93        }
94    }
95
96    /// Uploads a file to the S3 bucket for the tenant
97    pub async fn upload_file(
98        &self,
99        key: &str,
100        content_type: String,
101        body: Bytes,
102    ) -> anyhow::Result<()> {
103        match self {
104            TenantStorageLayer::S3(layer) => layer.upload_file(key, content_type, body).await,
105        }
106    }
107
108    /// Add the SNS notification to a bucket
109    pub async fn add_bucket_notifications(&self, sns_arn: &str) -> anyhow::Result<()> {
110        match self {
111            TenantStorageLayer::S3(layer) => layer.add_bucket_notifications(sns_arn).await,
112        }
113    }
114
115    /// Applies CORS rules for a bucket
116    pub async fn add_bucket_cors(&self, origins: Vec<String>) -> anyhow::Result<()> {
117        match self {
118            TenantStorageLayer::S3(layer) => layer.add_bucket_cors(origins).await,
119        }
120    }
121
122    /// Deletes the S3 file
123    pub async fn delete_file(&self, key: &str) -> anyhow::Result<()> {
124        match self {
125            TenantStorageLayer::S3(layer) => layer.delete_file(key).await,
126        }
127    }
128
129    /// Gets a byte stream for a file from S3
130    pub async fn get_file(&self, key: &str) -> anyhow::Result<FileStream> {
131        match self {
132            TenantStorageLayer::S3(layer) => layer.get_file(key).await,
133        }
134    }
135}
136
137pub(crate) trait StorageLayer {
138    /// Creates the tenant S3 bucket
139    async fn create_bucket(&self) -> anyhow::Result<()>;
140
141    /// Deletes the tenant S3 bucket
142    async fn delete_bucket(&self) -> anyhow::Result<()>;
143
144    /// Create a presigned file upload URL
145    async fn create_presigned(
146        &self,
147        key: &str,
148        size: i64,
149    ) -> anyhow::Result<(PresignedRequest, DateTime<Utc>)>;
150
151    /// Create a presigned file download URL
152    async fn create_presigned_download(
153        &self,
154        key: &str,
155        expires_in: Duration,
156    ) -> anyhow::Result<(PresignedRequest, DateTime<Utc>)>;
157
158    /// Uploads a file to the S3 bucket for the tenant
159    async fn upload_file(&self, key: &str, content_type: String, body: Bytes)
160    -> anyhow::Result<()>;
161
162    /// Add the SNS notification to a bucket
163    async fn add_bucket_notifications(&self, sns_arn: &str) -> anyhow::Result<()>;
164
165    /// Applies CORS rules for a bucket
166    async fn add_bucket_cors(&self, origins: Vec<String>) -> anyhow::Result<()>;
167
168    /// Deletes the S3 file
169    async fn delete_file(&self, key: &str) -> anyhow::Result<()>;
170
171    /// Gets a byte stream for a file from S3
172    async fn get_file(&self, key: &str) -> anyhow::Result<FileStream>;
173}
174
175/// Stream of bytes from a file
176pub struct FileStream {
177    pub stream: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>,
178}
179
180impl Stream for FileStream {
181    type Item = std::io::Result<Bytes>;
182
183    fn poll_next(
184        mut self: std::pin::Pin<&mut Self>,
185        cx: &mut std::task::Context<'_>,
186    ) -> std::task::Poll<Option<Self::Item>> {
187        // Pin projection to the underlying stream
188        let stream = unsafe { self.as_mut().map_unchecked_mut(|s| &mut s.stream) };
189        stream.poll_next(cx)
190    }
191}
192
193impl FileStream {
194    /// Collect the stream to completion as a single [Bytes] buffer
195    pub async fn collect_bytes(mut self) -> anyhow::Result<Bytes> {
196        let mut output = SegmentedBuf::new();
197
198        while let Some(result) = self.next().await {
199            let chunk = result?;
200            output.push(chunk);
201        }
202
203        Ok(output.copy_to_bytes(output.remaining()))
204    }
205}