use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use super::builder::{build_object_store, DynObjectStore};
use super::config::CloudConfig;
use super::error::StorageError;
use super::url::{Scheme, StorageUrl};
/// Cache key used by `StorageRegistry`: one driver is kept per (scheme, root) pair.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct DriverKey {
    /// Scheme of the URL the driver serves.
    scheme: Scheme,
    /// Empty for `File`/`Memory`; otherwise the first `/`-separated segment of
    /// `StorageUrl::path()` (presumably the bucket/container name — TODO confirm).
    root: String,
}

impl DriverKey {
    /// Derives the cache key for `url`.
    ///
    /// `File` and `Memory` URLs collapse to a single key per scheme, so every path under
    /// them shares one driver. All other schemes are keyed by their first path segment.
    // NOTE(review): a path beginning with `/` yields an empty root; confirm `path()` never
    // carries a leading slash for remote schemes.
    fn from_url(url: &StorageUrl) -> Self {
        let scheme = url.scheme();
        let root = if matches!(scheme, Scheme::File | Scheme::Memory) {
            String::new()
        } else {
            let path = url.path();
            let first_segment = match path.split_once('/') {
                Some((head, _rest)) => head,
                None => &*path,
            };
            first_segment.to_owned()
        };
        Self { scheme, root }
    }
}
/// Shared cache of object-store drivers, keyed by URL scheme and root.
///
/// Cloning is cheap and yields a handle to the *same* cache: the map lives behind an
/// `Arc<Mutex<_>>`, so a driver built through one clone is returned by all others.
#[derive(Clone, Default)]
pub struct StorageRegistry {
    /// Built drivers, one per `DriverKey` (scheme + root).
    inner: Arc<Mutex<HashMap<DriverKey, DynObjectStore>>>,
    /// Configuration used by `driver_for` when the caller passes `None`.
    default_cloud: Option<Arc<CloudConfig>>,
}

impl StorageRegistry {
    /// Creates an empty registry with no default cloud configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty registry that falls back to `default_cloud` whenever
    /// [`driver_for`](Self::driver_for) is called without an explicit config.
    pub fn with_default_cloud(default_cloud: Option<CloudConfig>) -> Self {
        Self {
            inner: Arc::default(),
            default_cloud: default_cloud.map(Arc::new),
        }
    }

    /// Returns the fallback cloud configuration, if one was supplied.
    pub fn default_cloud(&self) -> Option<&CloudConfig> {
        self.default_cloud.as_deref()
    }

    /// Returns the cached driver for `url`'s scheme and root, building it on first use.
    ///
    /// On a cache miss the driver is built with `config`, or with the registry's default
    /// cloud configuration when `config` is `None`. The cache key does **not** include the
    /// configuration: once a driver exists for a root, later calls for that root return it
    /// unchanged and their `config` is ignored. Call [`evict`](Self::evict) first to force
    /// a rebuild with a different configuration.
    ///
    /// The cache lock is held while the driver is built, so concurrent callers never build
    /// two drivers for the same key; they wait for the first build instead.
    ///
    /// # Errors
    ///
    /// Returns an error if `build_object_store` fails. Nothing is cached in that case, so
    /// a later call retries the build.
    pub fn driver_for(
        &self,
        url: &StorageUrl,
        config: Option<&CloudConfig>,
    ) -> Result<DynObjectStore, StorageError> {
        use std::collections::hash_map::Entry;

        let key = DriverKey::from_url(url);
        let mut cache = self.lock_cache();
        // Single hash lookup: either reuse the cached driver or fill the vacant slot.
        match cache.entry(key) {
            Entry::Occupied(slot) => Ok(Arc::clone(slot.get())),
            Entry::Vacant(slot) => {
                let effective = config.or(self.default_cloud.as_deref());
                let driver = build_object_store(url, effective)?;
                slot.insert(Arc::clone(&driver));
                Ok(driver)
            }
        }
    }

    /// Drops the cached driver for `url`'s scheme and root, if any.
    ///
    /// The next [`driver_for`](Self::driver_for) call for that root builds a fresh driver.
    /// Driver handles already returned to callers are unaffected.
    pub fn evict(&self, url: &StorageUrl) {
        let key = DriverKey::from_url(url);
        self.lock_cache().remove(&key);
    }

    /// Locks the driver cache, recovering the guard if a previous holder panicked.
    ///
    /// Recovery is sound because the map is only changed by single `insert`/`remove`
    /// calls of fully-built values, so a panic while the lock is held (e.g. inside
    /// `build_object_store`) cannot leave it half-updated. Without this, one such panic
    /// would make every later call on this registry — and on all its clones — panic.
    fn lock_cache(&self) -> std::sync::MutexGuard<'_, HashMap<DriverKey, DynObjectStore>> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
#[cfg(test)]
mod tests {
    use super::super::config::{CloudConfig, R2Config};
    use super::*;

    /// Memory URLs share one cache key, so two different paths get the same driver.
    #[test]
    fn caches_drivers_per_root() {
        let registry = StorageRegistry::new();
        let first = registry
            .driver_for(&StorageUrl::memory("benchmarks/2026.parquet"), None)
            .unwrap();
        let second = registry
            .driver_for(&StorageUrl::memory("benchmarks/2027.parquet"), None)
            .unwrap();
        assert!(Arc::ptr_eq(&first, &second));
    }

    /// After eviction the registry builds a new driver instead of returning the old one.
    #[test]
    fn evict_clears_cache() {
        let registry = StorageRegistry::new();
        let url = StorageUrl::memory("snapshots/x");
        let before = registry.driver_for(&url, None).unwrap();
        registry.evict(&url);
        let after = registry.driver_for(&url, None).unwrap();
        assert!(!Arc::ptr_eq(&before, &after));
    }

    /// The accessor reflects whatever default was supplied at construction.
    #[test]
    fn default_cloud_is_exposed() {
        let cloud = CloudConfig::R2(R2Config {
            account_id: Some("acct".into()),
            ..Default::default()
        });
        let configured = StorageRegistry::with_default_cloud(Some(cloud));
        assert!(matches!(configured.default_cloud(), Some(CloudConfig::R2(_))));

        let bare = StorageRegistry::new();
        assert!(bare.default_cloud().is_none());
    }

    /// With no per-call config, an R2 URL is built from the registry's default cloud.
    #[cfg(feature = "storage-r2")]
    #[test]
    fn driver_for_falls_back_to_default_cloud() {
        let cloud = CloudConfig::R2(R2Config {
            account_id: Some("abc123".into()),
            access_key_id: Some("k".into()),
            secret_access_key: Some("s".into()),
            ..Default::default()
        });
        let registry = StorageRegistry::with_default_cloud(Some(cloud));
        let url = StorageUrl::parse("r2://archives/x").unwrap();
        let store = registry
            .driver_for(&url, None)
            .expect("r2 driver builds via default cloud");
        // Exercise the driver's Display impl to make sure it is a usable store.
        let _ = format!("{store}");
    }
}