use std::collections::HashMap;
use std::error::Error;
use std::sync::Arc;
use once_cell::sync::OnceCell;
use crate::local_file_storage::LocalFileStorageFactory;
use crate::ram_storage::RamStorageFactory;
use crate::{S3CompatibleObjectStorageFactory, Storage, StorageResolverError};
/// Returns the process-wide [`StorageUriResolver`].
///
/// The resolver is built at most once, on the first call, and knows about the RAM, local file
/// and S3-compatible storage factories. Every later call returns the same instance.
pub fn quickwit_storage_uri_resolver() -> &'static StorageUriResolver {
    static STORAGE_URI_RESOLVER: OnceCell<StorageUriResolver> = OnceCell::new();

    /// Assembles the resolver with every storage factory shipped by default.
    fn default_resolver() -> StorageUriResolver {
        let builder = StorageUriResolver::builder()
            .register(RamStorageFactory::default())
            .register(LocalFileStorageFactory::default());
        builder
            .register(S3CompatibleObjectStorageFactory::default())
            .build()
    }

    STORAGE_URI_RESOLVER.get_or_init(default_resolver)
}
/// Opens [`Storage`] instances for a single URI protocol.
///
/// Factories are registered on a [`StorageUriResolverBuilder`]. They are selected by
/// [`StorageUriResolver::resolve`] when the part of a URI before `://` equals the string
/// returned by [`StorageFactory::protocol`].
#[cfg_attr(any(test, feature = "testsuite"), mockall::automock)]
pub trait StorageFactory: Send + Sync + 'static {
    /// The protocol served by this factory: the text that precedes `://` in the URIs it handles.
    fn protocol(&self) -> String;

    /// Opens the storage designated by `uri`.
    ///
    /// `uri` is the complete URI, protocol prefix included, exactly as given to the resolver.
    fn resolve(&self, uri: &str) -> crate::StorageResult<Arc<dyn Storage>>;
}
/// Routes storage URIs to the [`StorageFactory`] registered for their protocol.
///
/// Build one with [`StorageUriResolver::builder`]. Cloning is cheap: every clone shares the
/// same protocol-to-factory table through an [`Arc`].
#[derive(Clone)]
pub struct StorageUriResolver {
    /// Registered factories, keyed by the protocol each one reported when it was registered.
    per_protocol_resolver: Arc<HashMap<String, Arc<dyn StorageFactory>>>,
}
/// Collects [`StorageFactory`] registrations and turns them into a [`StorageUriResolver`].
///
/// Factories are keyed by the string returned from [`StorageFactory::protocol`]. Registering a
/// second factory for a protocol that is already taken replaces the earlier one.
#[derive(Default)]
pub struct StorageUriResolverBuilder {
    per_protocol_resolver: HashMap<String, Arc<dyn StorageFactory>>,
}

impl StorageUriResolverBuilder {
    /// Registers `resolver` under the protocol it reports via [`StorageFactory::protocol`].
    ///
    /// If a factory is already registered for that protocol, it is replaced, so the last
    /// registration wins. The builder is consumed and returned, so the result must be used.
    /// Otherwise the registration is lost.
    #[must_use]
    pub fn register<S: StorageFactory>(mut self, resolver: S) -> Self {
        // Query the protocol before `resolver` is moved into the `Arc`.
        let protocol = resolver.protocol();
        self.per_protocol_resolver
            .insert(protocol, Arc::new(resolver));
        self
    }

    /// Finalizes the registrations into a [`StorageUriResolver`].
    ///
    /// The table is placed behind an `Arc`, so clones of the returned resolver share it.
    #[must_use]
    pub fn build(self) -> StorageUriResolver {
        StorageUriResolver {
            per_protocol_resolver: Arc::new(self.per_protocol_resolver),
        }
    }
}
impl StorageUriResolver {
    /// Returns an empty [`StorageUriResolverBuilder`].
    pub fn builder() -> StorageUriResolverBuilder {
        StorageUriResolverBuilder::default()
    }

    /// Builds a resolver with the RAM, local file and S3-compatible factories registered.
    ///
    /// Intended for tests; hidden from the generated documentation.
    #[doc(hidden)]
    pub fn for_test() -> Self {
        StorageUriResolver::builder()
            .register(RamStorageFactory::default())
            .register(LocalFileStorageFactory::default())
            .register(S3CompatibleObjectStorageFactory::default())
            .build()
    }

    /// Resolves `uri` into a [`Storage`] using the factory registered for its protocol.
    ///
    /// The protocol is the part of `uri` before the first `://`. The full, unmodified `uri` is
    /// then handed to the matching [`StorageFactory::resolve`].
    ///
    /// # Errors
    ///
    /// - [`StorageResolverError::InvalidUri`] if `uri` has no `://` separator, or if the
    ///   protocol before it is empty.
    /// - [`StorageResolverError::ProtocolUnsupported`] if no factory is registered for the
    ///   protocol.
    /// - [`StorageResolverError::FailedToOpenStorage`] if the factory fails. The message is the
    ///   underlying error's source when it has one, otherwise the error's own description.
    pub fn resolve(&self, uri: &str) -> Result<Arc<dyn Storage>, StorageResolverError> {
        // `split_once` yields `None` when the separator is missing. `split("://").next()`
        // always returned the whole string, which made this error branch unreachable.
        let protocol = uri
            .split_once("://")
            .map(|(protocol, _)| protocol)
            .filter(|protocol| !protocol.is_empty())
            .ok_or_else(|| StorageResolverError::InvalidUri {
                message: format!("Protocol not found in storage uri: {}", uri),
            })?;
        let resolver = self.per_protocol_resolver.get(protocol).ok_or_else(|| {
            StorageResolverError::ProtocolUnsupported {
                protocol: protocol.to_string(),
            }
        })?;
        resolver.resolve(uri).map_err(|storage_error| {
            // Prefer the root cause, but never report an empty message when there is none.
            let message = storage_error
                .source()
                .map(|err| err.to_string())
                .unwrap_or_else(|| storage_error.to_string());
            StorageResolverError::FailedToOpenStorage {
                kind: storage_error.kind(),
                message,
            }
        })
    }
}
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::*;
    use crate::RamStorage;

    /// RAM storage holding a single `hello` object; returned by the mocked factories below.
    fn storage_with_second_content() -> Arc<dyn Storage> {
        let storage = RamStorage::builder()
            .put("hello", b"hello_content_second")
            .build();
        Arc::new(storage)
    }

    /// Two factories on distinct protocols: the URI must be routed to the matching one.
    #[tokio::test]
    async fn test_storage_resolver_simple() -> anyhow::Result<()> {
        let mut first_factory = MockStorageFactory::new();
        first_factory
            .expect_protocol()
            .returning(|| "first".to_string());

        let mut second_factory = MockStorageFactory::new();
        second_factory
            .expect_protocol()
            .returning(|| "second".to_string());
        second_factory
            .expect_resolve()
            .returning(|_uri| Ok(storage_with_second_content()));

        let resolver = StorageUriResolver::builder()
            .register(first_factory)
            .register(second_factory)
            .build();

        let storage = resolver.resolve("second://")?;
        let content = storage.get_all(Path::new("hello")).await?;
        assert_eq!(&content[..], b"hello_content_second");
        Ok(())
    }

    /// Two factories claim the same protocol: the one registered last must be used.
    #[tokio::test]
    async fn test_storage_resolver_override() -> anyhow::Result<()> {
        let mut shadowed_factory = MockStorageFactory::new();
        shadowed_factory
            .expect_protocol()
            .returning(|| "protocol".to_string());

        let mut winning_factory = MockStorageFactory::new();
        winning_factory
            .expect_protocol()
            .returning(|| "protocol".to_string());
        winning_factory.expect_resolve().returning(|uri| {
            // The factory receives the full URI, protocol prefix included.
            assert_eq!(uri, "protocol://mystorage");
            Ok(storage_with_second_content())
        });

        let resolver = StorageUriResolver::builder()
            .register(shadowed_factory)
            .register(winning_factory)
            .build();

        let storage = resolver.resolve("protocol://mystorage")?;
        let content = storage.get_all(Path::new("hello")).await?;
        assert_eq!(&content[..], b"hello_content_second");
        Ok(())
    }

    /// A protocol with no registered factory is reported as unsupported, by name.
    #[test]
    fn test_storage_resolver_unsupported_protocol() {
        let resolver = StorageUriResolver::for_test();
        let outcome = resolver.resolve("protocol://hello");
        assert!(matches!(
            outcome,
            Err(crate::StorageResolverError::ProtocolUnsupported { protocol }) if protocol == "protocol"
        ));
    }
}