use std::{
collections::HashMap,
sync::{
Arc, RwLock, Weak,
atomic::{AtomicU64, Ordering},
},
};
use object_store::path::Path;
use url::Url;
use crate::object_store::WrappingObjectStore;
use crate::object_store::uri_to_url;
use super::{ObjectStore, ObjectStoreParams, tracing::ObjectStoreTracingExt};
use lance_core::error::{Error, LanceOptionExt, Result};
#[cfg(feature = "aws")]
pub mod aws;
#[cfg(feature = "azure")]
pub mod azure;
#[cfg(feature = "gcp")]
pub mod gcp;
#[cfg(feature = "huggingface")]
pub mod huggingface;
pub mod local;
pub mod memory;
#[cfg(feature = "oss")]
pub mod oss;
#[cfg(feature = "tencent")]
pub mod tencent;
#[async_trait::async_trait]
pub trait ObjectStoreProvider: std::fmt::Debug + Sync + Send {
async fn new_store(&self, base_path: Url, params: &ObjectStoreParams) -> Result<ObjectStore>;
fn extract_path(&self, url: &Url) -> Result<Path> {
Path::parse(url.path())
.map_err(|_| Error::invalid_input(format!("Invalid path in URL: {}", url.path())))
}
fn calculate_object_store_prefix(
&self,
url: &Url,
_storage_options: Option<&HashMap<String, String>>,
) -> Result<String> {
Ok(format!("{}${}", url.scheme(), url.authority()))
}
}
/// Snapshot of the store-cache counters, as returned by
/// [`ObjectStoreRegistry::stats`].
#[derive(Debug, Clone, Default)]
pub struct ObjectStoreRegistryStats {
    /// Number of `get_store` calls answered with an already-cached, still-alive store.
    pub hits: u64,
    /// Number of `get_store` calls that found no live cached store and went on
    /// to ask the provider for a new one (whether or not that succeeded).
    pub misses: u64,
    /// Number of cache entries whose store was still alive when the snapshot was taken.
    pub active_stores: usize,
}
/// Registry that maps URL schemes to [`ObjectStoreProvider`]s and caches the
/// stores it hands out.
#[derive(Debug)]
pub struct ObjectStoreRegistry {
    /// Providers keyed by URL scheme (e.g. `"file"`, `"memory"`).
    providers: RwLock<HashMap<String, Arc<dyn ObjectStoreProvider>>>,
    /// Stores created by `get_store`, keyed by (provider-computed prefix, params).
    /// The entries are weak references, so the registry never keeps a store
    /// alive by itself. Dead entries are removed lazily.
    active_stores: RwLock<HashMap<(String, ObjectStoreParams), Weak<ObjectStore>>>,
    /// Count of `get_store` cache hits.
    hits: AtomicU64,
    /// Count of `get_store` cache misses.
    misses: AtomicU64,
}
impl ObjectStoreRegistry {
pub fn empty() -> Self {
Self {
providers: RwLock::new(HashMap::new()),
active_stores: RwLock::new(HashMap::new()),
hits: AtomicU64::new(0),
misses: AtomicU64::new(0),
}
}
pub fn get_provider(&self, scheme: &str) -> Option<Arc<dyn ObjectStoreProvider>> {
self.providers
.read()
.expect("ObjectStoreRegistry lock poisoned")
.get(scheme)
.cloned()
}
pub fn active_stores(&self) -> Vec<Arc<ObjectStore>> {
let mut found_inactive = false;
let output = self
.active_stores
.read()
.expect("ObjectStoreRegistry lock poisoned")
.values()
.filter_map(|weak| match weak.upgrade() {
Some(store) => Some(store),
None => {
found_inactive = true;
None
}
})
.collect();
if found_inactive {
let mut cache_lock = self
.active_stores
.write()
.expect("ObjectStoreRegistry lock poisoned");
cache_lock.retain(|_, weak| weak.upgrade().is_some());
}
output
}
pub fn stats(&self) -> ObjectStoreRegistryStats {
let active_stores = self
.active_stores
.read()
.map(|s| s.values().filter(|w| w.strong_count() > 0).count())
.unwrap_or(0);
ObjectStoreRegistryStats {
hits: self.hits.load(Ordering::Relaxed),
misses: self.misses.load(Ordering::Relaxed),
active_stores,
}
}
fn scheme_not_found_error(&self, scheme: &str) -> Error {
let mut message = format!("No object store provider found for scheme: '{}'", scheme);
if let Ok(providers) = self.providers.read() {
let valid_schemes = providers.keys().cloned().collect::<Vec<_>>().join(", ");
message.push_str(&format!("\nValid schemes: {}", valid_schemes));
}
Error::invalid_input(message)
}
pub async fn get_store(
&self,
base_path: Url,
params: &ObjectStoreParams,
) -> Result<Arc<ObjectStore>> {
let scheme = base_path.scheme();
let Some(provider) = self.get_provider(scheme) else {
return Err(self.scheme_not_found_error(scheme));
};
let cache_path =
provider.calculate_object_store_prefix(&base_path, params.storage_options())?;
let cache_key = (cache_path.clone(), params.clone());
{
let maybe_store = self
.active_stores
.read()
.ok()
.expect_ok()?
.get(&cache_key)
.cloned();
if let Some(store) = maybe_store {
if let Some(store) = store.upgrade() {
self.hits.fetch_add(1, Ordering::Relaxed);
return Ok(store);
} else {
let mut cache_lock = self
.active_stores
.write()
.expect("ObjectStoreRegistry lock poisoned");
if let Some(store) = cache_lock.get(&cache_key)
&& store.upgrade().is_none()
{
cache_lock.remove(&cache_key);
}
}
}
}
self.misses.fetch_add(1, Ordering::Relaxed);
let mut store = provider.new_store(base_path, params).await?;
store.inner = store.inner.traced();
if let Some(wrapper) = ¶ms.object_store_wrapper {
store.inner = wrapper.wrap(&cache_path, store.inner);
}
store.inner = store.io_tracker.wrap("", store.inner);
let store = Arc::new(store);
{
let mut cache_lock = self.active_stores.write().ok().expect_ok()?;
cache_lock.insert(cache_key, Arc::downgrade(&store));
}
Ok(store)
}
pub fn calculate_object_store_prefix(
&self,
uri: &str,
storage_options: Option<&HashMap<String, String>>,
) -> Result<String> {
let url = uri_to_url(uri)?;
match self.get_provider(url.scheme()) {
None => {
if url.scheme() == "file" || url.scheme().len() == 1 {
Ok("file".to_string())
} else {
Err(self.scheme_not_found_error(url.scheme()))
}
}
Some(provider) => provider.calculate_object_store_prefix(&url, storage_options),
}
}
}
impl Default for ObjectStoreRegistry {
fn default() -> Self {
let mut providers: HashMap<String, Arc<dyn ObjectStoreProvider>> = HashMap::new();
providers.insert("memory".into(), Arc::new(memory::MemoryStoreProvider));
providers.insert("file".into(), Arc::new(local::FileStoreProvider));
providers.insert(
"file-object-store".into(),
Arc::new(local::FileStoreProvider),
);
#[cfg(feature = "aws")]
{
let aws = Arc::new(aws::AwsStoreProvider);
providers.insert("s3".into(), aws.clone());
providers.insert("s3+ddb".into(), aws);
}
#[cfg(feature = "azure")]
{
let azure = Arc::new(azure::AzureBlobStoreProvider);
providers.insert("az".into(), azure.clone());
providers.insert("abfss".into(), azure);
}
#[cfg(feature = "gcp")]
providers.insert("gs".into(), Arc::new(gcp::GcsStoreProvider));
#[cfg(feature = "oss")]
providers.insert("oss".into(), Arc::new(oss::OssStoreProvider));
#[cfg(feature = "tencent")]
providers.insert("cos".into(), Arc::new(tencent::TencentStoreProvider));
#[cfg(feature = "huggingface")]
providers.insert("hf".into(), Arc::new(huggingface::HuggingfaceStoreProvider));
Self {
providers: RwLock::new(providers),
active_stores: RwLock::new(HashMap::new()),
hits: AtomicU64::new(0),
misses: AtomicU64::new(0),
}
}
}
impl ObjectStoreRegistry {
    /// Register `provider` for the URL scheme `scheme`.
    ///
    /// Any provider previously registered under the same scheme is replaced.
    ///
    /// # Panics
    /// Panics if the provider lock is poisoned.
    pub fn insert(&self, scheme: &str, provider: Arc<dyn ObjectStoreProvider>) {
        let mut registered = self
            .providers
            .write()
            .expect("ObjectStoreRegistry lock poisoned");
        registered.insert(scheme.to_owned(), provider);
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Provider used only to exercise prefix calculation; it never builds stores.
    #[derive(Debug)]
    struct DummyProvider;

    #[async_trait::async_trait]
    impl ObjectStoreProvider for DummyProvider {
        async fn new_store(
            &self,
            _base_path: Url,
            _params: &ObjectStoreParams,
        ) -> Result<ObjectStore> {
            unreachable!("This test doesn't create stores")
        }
    }

    #[test]
    fn test_calculate_object_store_prefix() {
        let provider = DummyProvider;
        let url = Url::parse("dummy://blah/path").unwrap();
        assert_eq!(
            "dummy$blah",
            provider.calculate_object_store_prefix(&url, None).unwrap()
        );
    }

    #[test]
    fn test_calculate_object_store_scheme_not_found() {
        let registry = ObjectStoreRegistry::empty();
        registry.insert("dummy", Arc::new(DummyProvider));
        let expected_prefix = "Invalid user input: No object store provider found for scheme: 'dummy2'\nValid schemes: dummy";
        let result = registry
            .calculate_object_store_prefix("dummy2://mybucket/my/long/path", None)
            .expect_err("expected error")
            .to_string();
        // Only the prefix is checked: the formatted error may carry more text after it.
        assert!(
            result.starts_with(expected_prefix),
            "unexpected error message: {result}"
        );
    }

    #[test]
    fn test_calculate_object_store_prefix_for_local() {
        let registry = ObjectStoreRegistry::empty();
        assert_eq!(
            "file",
            registry
                .calculate_object_store_prefix("/tmp/foobar", None)
                .unwrap()
        );
    }

    #[test]
    fn test_calculate_object_store_prefix_for_local_windows_path() {
        let registry = ObjectStoreRegistry::empty();
        assert_eq!(
            "file",
            registry
                .calculate_object_store_prefix("c://dos/path", None)
                .unwrap()
        );
    }

    #[test]
    fn test_calculate_object_store_prefix_for_dummy_path() {
        let registry = ObjectStoreRegistry::empty();
        registry.insert("dummy", Arc::new(DummyProvider));
        assert_eq!(
            "dummy$mybucket",
            registry
                .calculate_object_store_prefix("dummy://mybucket/my/long/path", None)
                .unwrap()
        );
    }

    #[tokio::test]
    async fn test_stats_hit_miss_tracking() {
        use crate::object_store::StorageOptionsAccessor;

        let registry = ObjectStoreRegistry::default();
        let url = Url::parse("memory://test").unwrap();
        let params1 = ObjectStoreParams::default();
        let params2 = ObjectStoreParams {
            storage_options_accessor: Some(Arc::new(StorageOptionsAccessor::with_static_options(
                HashMap::from([("k".into(), "v".into())]),
            ))),
            ..Default::default()
        };

        // Each case: (params, expected (hits, misses, active_stores) after the call).
        let cases: &[(&ObjectStoreParams, (u64, u64, usize))] = &[
            // First lookup with params1 builds a new store.
            (&params1, (0, 1, 1)),
            // Same URL and params reuse the live store.
            (&params1, (1, 1, 1)),
            // Different params form a different cache key, so a new store is built.
            (&params2, (1, 2, 2)),
        ];

        // Keep every returned store alive so the cache's weak references stay upgradable.
        let mut stores = vec![];
        for (params, (hits, misses, active)) in cases {
            stores.push(registry.get_store(url.clone(), params).await.unwrap());
            let s = registry.stats();
            assert_eq!(
                (s.hits, s.misses, s.active_stores),
                (*hits, *misses, *active)
            );
        }
        assert!(Arc::ptr_eq(&stores[0], &stores[1]));
    }
}