1#![forbid(unsafe_code)]
2#![warn(missing_docs)]
3
4use aws_config::SdkConfig;
14use aws_sdk_s3::presigning::PresignedRequest;
15use bytes::{Buf, Bytes};
16use bytes_utils::SegmentedBuf;
17use chrono::{DateTime, Utc};
18use futures::{Stream, StreamExt};
19use serde::{Deserialize, Serialize};
20use std::{fmt::Debug, pin::Pin, time::Duration};
21use thiserror::Error;
22
23pub mod s3;
24
25#[derive(Debug, Clone, Deserialize, Serialize)]
27#[serde(tag = "provider", rename_all = "snake_case")]
28pub enum StorageLayerFactoryConfig {
29 S3(s3::S3StorageLayerFactoryConfig),
31}
32
33impl Default for StorageLayerFactoryConfig {
34 fn default() -> Self {
35 Self::S3(Default::default())
36 }
37}
38
39#[derive(Debug, Error)]
42pub enum StorageLayerFactoryConfigError {
43 #[error(transparent)]
45 S3(#[from] s3::S3StorageLayerFactoryConfigError),
46}
47
48impl StorageLayerFactoryConfig {
49 pub fn from_env() -> Result<Self, StorageLayerFactoryConfigError> {
51 s3::S3StorageLayerFactoryConfig::from_env()
52 .map(Self::S3)
53 .map_err(StorageLayerFactoryConfigError::S3)
54 }
55}
56
57#[derive(Clone)]
60pub enum StorageLayerFactory {
61 S3(s3::S3StorageLayerFactory),
63}
64
65#[derive(Debug, Error)]
67pub enum StorageLayerError {
68 #[error(transparent)]
70 S3(Box<s3::S3StorageError>),
71
72 #[error("failed to collect file contents")]
74 CollectBytes,
75}
76
77impl From<s3::S3StorageError> for StorageLayerError {
78 fn from(value: s3::S3StorageError) -> Self {
79 Self::S3(Box::new(value))
80 }
81}
82
/// Options used by [`StorageLayerFactory::create_layer`] to create a layer.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct StorageLayerOptions {
    /// Name of the bucket the created layer operates on.
    pub bucket_name: String,
}
89
90impl StorageLayerFactory {
91 pub fn from_config(aws_config: &SdkConfig, config: StorageLayerFactoryConfig) -> Self {
93 match config {
94 StorageLayerFactoryConfig::S3(config) => {
95 Self::S3(s3::S3StorageLayerFactory::from_config(aws_config, config))
96 }
97 }
98 }
99
100 #[cfg(debug_assertions)]
102 pub fn create_test_layer(&self) -> StorageLayer {
103 self.create_layer(StorageLayerOptions {
104 bucket_name: "test".to_string(),
105 })
106 }
107
108 pub fn create_layer(&self, options: StorageLayerOptions) -> StorageLayer {
110 match self {
111 StorageLayerFactory::S3(s3) => {
112 let layer = s3.create_storage_layer(options.bucket_name);
113 StorageLayer::S3(layer)
114 }
115 }
116 }
117}
118
119#[derive(Clone)]
121pub enum StorageLayer {
122 S3(s3::S3StorageLayer),
124}
125
/// Result reported by a successful [`StorageLayer::create_bucket`] call.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CreateBucketOutcome {
    /// The bucket was newly created (per the variant name — the exact
    /// condition is decided by the backend).
    New,
    /// The bucket already existed (per the variant name — the exact
    /// condition is decided by the backend).
    Existing,
}
134
135impl StorageLayer {
136 pub fn bucket_name(&self) -> String {
138 match self {
139 StorageLayer::S3(layer) => layer.bucket_name(),
140 }
141 }
142
143 #[tracing::instrument(skip(self))]
148 pub async fn create_bucket(&self) -> Result<CreateBucketOutcome, StorageLayerError> {
149 match self {
150 StorageLayer::S3(layer) => layer.create_bucket().await,
151 }
152 }
153
154 #[tracing::instrument(skip(self))]
156 pub async fn bucket_exists(&self) -> Result<bool, StorageLayerError> {
157 match self {
158 StorageLayer::S3(layer) => layer.bucket_exists().await,
159 }
160 }
161
162 #[tracing::instrument(skip(self))]
167 pub async fn delete_bucket(&self) -> Result<(), StorageLayerError> {
168 match self {
169 StorageLayer::S3(layer) => layer.delete_bucket().await,
170 }
171 }
172
173 #[tracing::instrument(skip(self))]
175 pub async fn create_presigned(
176 &self,
177 key: &str,
178 size: i64,
179 ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
180 match self {
181 StorageLayer::S3(layer) => layer.create_presigned(key, size).await,
182 }
183 }
184
185 #[tracing::instrument(skip(self))]
190 pub async fn create_presigned_download(
191 &self,
192 key: &str,
193 expires_in: Duration,
194 ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError> {
195 match self {
196 StorageLayer::S3(layer) => layer.create_presigned_download(key, expires_in).await,
197 }
198 }
199
200 #[tracing::instrument(skip(self, body), fields(body_length = body.len()))]
202 pub async fn upload_file(
203 &self,
204 key: &str,
205 content_type: String,
206 body: Bytes,
207 ) -> Result<(), StorageLayerError> {
208 match self {
209 StorageLayer::S3(layer) => layer.upload_file(key, content_type, body).await,
210 }
211 }
212
213 #[tracing::instrument(skip(self))]
215 pub async fn add_bucket_notifications(&self, sns_arn: &str) -> Result<(), StorageLayerError> {
216 match self {
217 StorageLayer::S3(layer) => layer.add_bucket_notifications(sns_arn).await,
218 }
219 }
220
221 #[tracing::instrument(skip(self))]
223 pub async fn set_bucket_cors_origins(
224 &self,
225 origins: Vec<String>,
226 ) -> Result<(), StorageLayerError> {
227 match self {
228 StorageLayer::S3(layer) => layer.set_bucket_cors_origins(origins).await,
229 }
230 }
231
232 #[tracing::instrument(skip(self))]
237 pub async fn delete_file(&self, key: &str) -> Result<(), StorageLayerError> {
238 match self {
239 StorageLayer::S3(layer) => layer.delete_file(key).await,
240 }
241 }
242
243 #[tracing::instrument(skip(self))]
245 pub async fn get_file(&self, key: &str) -> Result<FileStream, StorageLayerError> {
246 match self {
247 StorageLayer::S3(layer) => layer.get_file(key).await,
248 }
249 }
250}
251
252pub(crate) trait StorageLayerImpl {
254 fn bucket_name(&self) -> String;
255
256 async fn create_bucket(&self) -> Result<CreateBucketOutcome, StorageLayerError>;
257
258 async fn bucket_exists(&self) -> Result<bool, StorageLayerError>;
259
260 async fn delete_bucket(&self) -> Result<(), StorageLayerError>;
261
262 async fn create_presigned(
263 &self,
264 key: &str,
265 size: i64,
266 ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError>;
267
268 async fn create_presigned_download(
269 &self,
270 key: &str,
271 expires_in: Duration,
272 ) -> Result<(PresignedRequest, DateTime<Utc>), StorageLayerError>;
273
274 async fn upload_file(
275 &self,
276 key: &str,
277 content_type: String,
278 body: Bytes,
279 ) -> Result<(), StorageLayerError>;
280
281 async fn add_bucket_notifications(&self, sns_arn: &str) -> Result<(), StorageLayerError>;
282
283 async fn set_bucket_cors_origins(&self, origins: Vec<String>) -> Result<(), StorageLayerError>;
284
285 async fn delete_file(&self, key: &str) -> Result<(), StorageLayerError>;
286
287 async fn get_file(&self, key: &str) -> Result<FileStream, StorageLayerError>;
288}
289
290pub struct FileStream {
292 pub stream: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>,
294}
295
296impl Debug for FileStream {
297 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298 f.debug_struct("FileStream").finish()
299 }
300}
301
302impl Stream for FileStream {
303 type Item = std::io::Result<Bytes>;
304
305 fn poll_next(
306 mut self: std::pin::Pin<&mut Self>,
307 cx: &mut std::task::Context<'_>,
308 ) -> std::task::Poll<Option<Self::Item>> {
309 self.stream.as_mut().poll_next(cx)
310 }
311}
312
313impl FileStream {
314 pub async fn collect_bytes(mut self) -> Result<Bytes, StorageLayerError> {
316 let mut output = SegmentedBuf::new();
317
318 while let Some(result) = self.next().await {
319 let chunk = result.map_err(|error| {
320 tracing::error!(?error, "failed to collect file stream bytes");
321 StorageLayerError::CollectBytes
322 })?;
323
324 output.push(chunk);
325 }
326
327 Ok(output.copy_to_bytes(output.remaining()))
328 }
329}