use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
use async_trait::async_trait;
pub use async_trait::async_trait as async_trait_reexport;
pub type ByteStream =
std::pin::Pin<Box<dyn futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> + Send>>;
pub const CAP_EXCEEDED_KIND: std::io::ErrorKind = std::io::ErrorKind::Other;
pub const CAP_EXCEEDED_MARKER: &str = "umbral-storage-cap-exceeded";
/// Wrap `body` so that at most `max` bytes in total are passed through.
///
/// Chunks are forwarded unchanged while the running total stays `<= max`.
/// The chunk that would push the total past `max` is *not* forwarded.
/// Instead the stream yields one `Err` built from [`CAP_EXCEEDED_KIND`] and
/// [`CAP_EXCEEDED_MARKER`] (test it with [`is_cap_exceeded`]) and then ends.
/// An upstream `Err` is forwarded as-is and likewise ends the stream.
///
/// After the terminal error the wrapped `body` is dropped and never polled
/// again. An oversized or broken upload is therefore not drained to
/// completion, even if the consumer keeps polling.
pub fn cap_stream(body: ByteStream, max: u64) -> ByteStream {
    use futures_util::StreamExt;

    // State: `Some((inner, bytes_seen_so_far))` while live, `None` once the
    // terminal error has been emitted (which also drops `inner`).
    let capped = futures_util::stream::unfold(Some((body, 0u64)), move |state| async move {
        let Some((mut body, seen)) = state else {
            return None;
        };
        match body.next().await {
            None => None,
            Some(Ok(chunk)) => {
                let seen = seen.saturating_add(chunk.len() as u64);
                if seen > max {
                    let err = std::io::Error::new(CAP_EXCEEDED_KIND, CAP_EXCEEDED_MARKER);
                    Some((Err(err), None))
                } else {
                    Some((Ok(chunk), Some((body, seen))))
                }
            }
            Some(Err(e)) => Some((Err(e), None)),
        }
    });
    Box::pin(capped)
}
/// Return `true` if `e` is the cap-exceeded error produced by [`cap_stream`].
///
/// The kind is checked first, which is cheap. The message is rendered and
/// searched for [`CAP_EXCEEDED_MARKER`] only when the kind matches.
pub fn is_cap_exceeded(e: &std::io::Error) -> bool {
    if e.kind() != CAP_EXCEEDED_KIND {
        return false;
    }
    let rendered = e.to_string();
    rendered.contains(CAP_EXCEEDED_MARKER)
}
#[async_trait]
pub trait Storage: Send + Sync {
async fn store(
&self,
filename: &str,
content_type: &str,
bytes: &[u8],
) -> Result<StoredFile, StorageError>;
async fn retrieve(&self, key: &str) -> Result<Vec<u8>, StorageError>;
async fn store_stream(
&self,
filename: &str,
content_type: &str,
body: ByteStream,
) -> Result<StoredFile, StorageError> {
let mut bytes: Vec<u8> = Vec::new();
let mut body = body;
while let Some(chunk) = futures_util::StreamExt::next(&mut body).await {
let chunk = chunk.map_err(StorageError::Io)?;
bytes.extend_from_slice(&chunk);
}
self.store(filename, content_type, &bytes).await
}
async fn retrieve_stream(&self, key: &str) -> Result<ByteStream, StorageError> {
let bytes = self.retrieve(key).await?;
let chunk: Result<bytes::Bytes, std::io::Error> = Ok(bytes::Bytes::from(bytes));
Ok(Box::pin(futures_util::stream::once(async move { chunk })))
}
async fn put(
&self,
key: &str,
content_type: &str,
bytes: &[u8],
) -> Result<StoredFile, StorageError> {
let _ = (key, content_type, bytes);
Err(StorageError::Unsupported(
"this Storage backend does not implement put(); override it to write at an exact key"
.to_string(),
))
}
async fn put_stream(
&self,
key: &str,
content_type: &str,
body: ByteStream,
) -> Result<StoredFile, StorageError> {
let mut bytes: Vec<u8> = Vec::new();
let mut body = body;
while let Some(chunk) = futures_util::StreamExt::next(&mut body).await {
let chunk = chunk.map_err(StorageError::Io)?;
bytes.extend_from_slice(&chunk);
}
self.put(key, content_type, &bytes).await
}
async fn exists(&self, key: &str) -> Result<bool, StorageError> {
Ok(self.retrieve(key).await.is_ok())
}
async fn delete(&self, key: &str) -> Result<(), StorageError>;
fn url(&self, key: &str) -> String;
}
/// Description of an object after a successful write.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare results directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StoredFile {
    /// Backend key the object was written under. Pass it back to
    /// `retrieve`, `exists` or `delete`.
    pub key: String,
    /// URL for the object, as produced by the backend's `url` method.
    pub url: String,
    /// Size of the stored object in bytes.
    pub size: u64,
}
/// Errors returned by [`Storage`] backends and by the storage registry.
#[derive(Debug)]
pub enum StorageError {
    /// No backend is registered under the requested name.
    NoBackend,
    /// The requested object does not exist.
    NotFound,
    /// An object exceeded a configured size cap.
    TooLarge {
        /// The configured cap, in bytes.
        limit: u64,
        /// The observed size, in bytes.
        actual: u64,
    },
    /// An underlying I/O error, for example one yielded by a request body stream.
    Io(std::io::Error),
    /// A backend-specific failure, described as text.
    Backend(String),
    /// The backend does not support the requested operation.
    Unsupported(String),
}

impl std::fmt::Display for StorageError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NoBackend => f.write_str(
                "storage: no backend registered; add StoragePlugin or call set_storage",
            ),
            Self::NotFound => f.write_str("storage: object not found"),
            Self::TooLarge { limit, actual } => write!(
                f,
                "storage: object {actual}B exceeds configured cap of {limit}B"
            ),
            Self::Io(source) => write!(f, "storage: io: {source}"),
            Self::Backend(detail) => write!(f, "storage: backend: {detail}"),
            Self::Unsupported(detail) => write!(f, "storage: unsupported: {detail}"),
        }
    }
}

impl std::error::Error for StorageError {
    /// Only the `Io` variant wraps another error. All other variants are leaves.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(inner) => Some(inner),
            Self::NoBackend
            | Self::NotFound
            | Self::TooLarge { .. }
            | Self::Backend(_)
            | Self::Unsupported(_) => None,
        }
    }
}

/// Lets `?` lift a bare `std::io::Error` into [`StorageError::Io`].
impl From<std::io::Error> for StorageError {
    fn from(err: std::io::Error) -> Self {
        Self::Io(err)
    }
}
/// Registry name used by the unnamed helpers (`set_storage`, `storage`, ...).
pub const DEFAULT: &str = "default";

/// Registry name presumably reserved for the static-files backend. Its owning
/// plugin is not in this file, so confirm against the caller.
pub const STATICFILES: &str = "staticfiles";

/// Process-wide table mapping a registry name to its backend, created lazily.
static STORAGES: OnceLock<Mutex<HashMap<&'static str, Arc<dyn Storage>>>> = OnceLock::new();

/// Return the global registry, initialising it to an empty map on first use.
fn registry() -> &'static Mutex<HashMap<&'static str, Arc<dyn Storage>>> {
    STORAGES.get_or_init(Mutex::default)
}
/// Register `s` as the backend for `name`.
///
/// Registration is first-wins. If a backend is already registered under
/// `name`, a warning is logged, the existing backend is kept, `s` is dropped
/// and `false` is returned. Returns `true` when `s` was newly registered.
///
/// A poisoned registry lock is recovered via `PoisonError::into_inner` rather
/// than turned into a panic.
pub fn set_storage_named(name: &'static str, s: Arc<dyn Storage>) -> bool {
    use std::collections::hash_map::Entry;

    let mut map = registry()
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    // Single hash lookup via the entry API instead of `contains_key` + `insert`.
    if let Entry::Vacant(slot) = map.entry(name) {
        slot.insert(s);
        return true;
    }
    tracing::warn!(
        name,
        "umbral::storage::set_storage_named called more than once for the same name; \
         keeping the first-registered backend and ignoring the new one"
    );
    false
}
/// Return the backend registered under `name`.
///
/// # Panics
/// Panics if no backend is registered under `name`. Use
/// [`try_storage_named`] or [`storage_opt_named`] for a non-panicking lookup.
pub fn storage_named(name: &str) -> Arc<dyn Storage> {
    match try_storage_named(name) {
        Ok(backend) => backend,
        Err(_) => panic!(
            "no Storage backend registered under `{name}`; add the owning plugin \
             (StoragePlugin for `default`) or call umbral::storage::set_storage_named"
        ),
    }
}
/// Return the backend registered under `name`, or
/// [`StorageError::NoBackend`] if there is none.
pub fn try_storage_named(name: &str) -> Result<Arc<dyn Storage>, StorageError> {
    match storage_opt_named(name) {
        Some(backend) => Ok(backend),
        None => Err(StorageError::NoBackend),
    }
}
/// Return a clone of the `Arc` for the backend registered under `name`, or
/// `None` if there is none. A poisoned registry lock is recovered rather than
/// turned into a panic.
pub fn storage_opt_named(name: &str) -> Option<Arc<dyn Storage>> {
    let guard = match registry().lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    guard.get(name).map(Arc::clone)
}
pub fn set_storage(s: Arc<dyn Storage>) -> bool {
set_storage_named(DEFAULT, s)
}
pub fn storage() -> Arc<dyn Storage> {
try_storage().expect(
"no Storage backend registered; add StoragePlugin or call umbral::storage::set_storage",
)
}
pub fn try_storage() -> Result<Arc<dyn Storage>, StorageError> {
try_storage_named(DEFAULT)
}
pub fn storage_opt() -> Option<Arc<dyn Storage>> {
storage_opt_named(DEFAULT)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap as Map;
    use std::sync::Mutex as StdMutex;

    /// In-memory backend implementing only the required methods, so the
    /// trait's default `put`, `exists` and `store_stream` are exercised.
    struct MemNoPut {
        objects: StdMutex<Map<String, Vec<u8>>>,
    }
    impl MemNoPut {
        fn new() -> Self {
            Self {
                objects: StdMutex::new(Map::new()),
            }
        }
    }
    #[async_trait]
    impl Storage for MemNoPut {
        async fn store(
            &self,
            filename: &str,
            _content_type: &str,
            bytes: &[u8],
        ) -> Result<StoredFile, StorageError> {
            let key = format!("k-{filename}");
            self.objects
                .lock()
                .unwrap()
                .insert(key.clone(), bytes.to_vec());
            Ok(StoredFile {
                url: self.url(&key),
                key,
                size: bytes.len() as u64,
            })
        }
        async fn retrieve(&self, key: &str) -> Result<Vec<u8>, StorageError> {
            self.objects
                .lock()
                .unwrap()
                .get(key)
                .cloned()
                .ok_or(StorageError::NotFound)
        }
        async fn delete(&self, key: &str) -> Result<(), StorageError> {
            self.objects.lock().unwrap().remove(key);
            Ok(())
        }
        fn url(&self, key: &str) -> String {
            format!("/mem/{key}")
        }
    }

    /// In-memory backend that overrides `put` and routes `store` through it.
    struct MemWithPut {
        objects: StdMutex<Map<String, Vec<u8>>>,
    }
    impl MemWithPut {
        fn new() -> Self {
            Self {
                objects: StdMutex::new(Map::new()),
            }
        }
    }
    #[async_trait]
    impl Storage for MemWithPut {
        async fn store(
            &self,
            filename: &str,
            ct: &str,
            bytes: &[u8],
        ) -> Result<StoredFile, StorageError> {
            self.put(&format!("k-{filename}"), ct, bytes).await
        }
        async fn retrieve(&self, key: &str) -> Result<Vec<u8>, StorageError> {
            self.objects
                .lock()
                .unwrap()
                .get(key)
                .cloned()
                .ok_or(StorageError::NotFound)
        }
        async fn put(
            &self,
            key: &str,
            _ct: &str,
            bytes: &[u8],
        ) -> Result<StoredFile, StorageError> {
            self.objects
                .lock()
                .unwrap()
                .insert(key.to_string(), bytes.to_vec());
            Ok(StoredFile {
                url: self.url(key),
                key: key.to_string(),
                size: bytes.len() as u64,
            })
        }
        async fn delete(&self, key: &str) -> Result<(), StorageError> {
            self.objects.lock().unwrap().remove(key);
            Ok(())
        }
        fn url(&self, key: &str) -> String {
            format!("/mem/{key}")
        }
    }

    /// Build a `ByteStream` that yields exactly `items`, in order.
    fn body_of(items: Vec<Result<bytes::Bytes, std::io::Error>>) -> ByteStream {
        Box::pin(futures_util::stream::iter(items))
    }

    #[tokio::test]
    async fn put_default_returns_unsupported() {
        let s = MemNoPut::new();
        let err = s.put("css/app.css", "text/css", b"x").await.unwrap_err();
        match err {
            StorageError::Unsupported(msg) => {
                assert!(msg.contains("does not implement put"), "msg = {msg}");
            }
            other => panic!("expected Unsupported, got {other:?}"),
        }
    }
    #[tokio::test]
    async fn put_override_writes_at_exact_key() {
        let s = MemWithPut::new();
        let stored = s
            .put("css/app.css", "text/css", b"body{}")
            .await
            .unwrap();
        assert_eq!(stored.key, "css/app.css");
        assert_eq!(stored.size, 6);
        assert_eq!(s.retrieve("css/app.css").await.unwrap(), b"body{}");
    }
    #[tokio::test]
    async fn exists_default_true_after_store_false_when_missing() {
        let s = MemNoPut::new();
        let stored = s.store("a.txt", "text/plain", b"hi").await.unwrap();
        assert!(s.exists(&stored.key).await.unwrap());
        assert!(!s.exists("nope").await.unwrap());
    }
    #[tokio::test]
    async fn put_stream_default_delegates_to_put() {
        let s = MemWithPut::new();
        let body: ByteStream = Box::pin(futures_util::stream::once(async {
            Ok(bytes::Bytes::from_static(b"streamed"))
        }));
        let stored = s.put_stream("js/app.js", "text/javascript", body).await.unwrap();
        assert_eq!(stored.key, "js/app.js");
        assert_eq!(s.retrieve("js/app.js").await.unwrap(), b"streamed");
    }

    #[tokio::test]
    async fn store_stream_default_buffers_all_chunks() {
        let s = MemNoPut::new();
        let body = body_of(vec![
            Ok(bytes::Bytes::from_static(b"he")),
            Ok(bytes::Bytes::from_static(b"llo")),
        ]);
        let stored = s.store_stream("a.txt", "text/plain", body).await.unwrap();
        assert_eq!(stored.size, 5);
        assert_eq!(s.retrieve(&stored.key).await.unwrap(), b"hello");
    }

    #[tokio::test]
    async fn store_stream_surfaces_cap_error_as_io() {
        let s = MemNoPut::new();
        let body = body_of(vec![Ok(bytes::Bytes::from_static(b"too big"))]);
        let err = s
            .store_stream("a.txt", "text/plain", cap_stream(body, 2))
            .await
            .unwrap_err();
        match err {
            StorageError::Io(e) => assert!(is_cap_exceeded(&e), "e = {e}"),
            other => panic!("expected Io, got {other:?}"),
        }
        // Nothing may be stored when the body failed.
        assert!(s.objects.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn cap_stream_passes_chunks_up_to_the_cap() {
        use futures_util::StreamExt;
        let body = body_of(vec![
            Ok(bytes::Bytes::from_static(b"abc")),
            Ok(bytes::Bytes::from_static(b"def")),
        ]);
        // Exactly `max` bytes is allowed.
        let items: Vec<_> = cap_stream(body, 6).collect().await;
        assert_eq!(items.len(), 2);
        assert!(items.iter().all(|r| r.is_ok()));
    }

    #[tokio::test]
    async fn cap_stream_errors_once_and_yields_nothing_after() {
        use futures_util::StreamExt;
        let body = body_of(vec![
            Ok(bytes::Bytes::from_static(b"abc")),
            Ok(bytes::Bytes::from_static(b"def")),
            Ok(bytes::Bytes::from_static(b"ghi")),
        ]);
        // "abc" (3 bytes) fits. "def" would make 6 > 5, so it is replaced by
        // the cap error and nothing follows.
        let items: Vec<_> = cap_stream(body, 5).collect().await;
        assert_eq!(items.len(), 2);
        assert_eq!(items[0].as_ref().unwrap().to_vec(), b"abc".to_vec());
        assert!(is_cap_exceeded(items[1].as_ref().unwrap_err()));
    }

    #[tokio::test]
    async fn cap_stream_stops_after_upstream_error() {
        use futures_util::StreamExt;
        let body = body_of(vec![
            Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "reset")),
            Ok(bytes::Bytes::from_static(b"late")),
        ]);
        let items: Vec<_> = cap_stream(body, 1024).collect().await;
        assert_eq!(items.len(), 1);
        let err = items[0].as_ref().unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe);
        assert!(!is_cap_exceeded(err));
    }
}