1use std::{
5 collections::HashMap,
6 sync::{
7 Arc, RwLock, Weak,
8 atomic::{AtomicU64, Ordering},
9 },
10};
11
12use object_store::path::Path;
13use url::Url;
14
15use crate::object_store::WrappingObjectStore;
16use crate::object_store::uri_to_url;
17
18use super::{ObjectStore, ObjectStoreParams, tracing::ObjectStoreTracingExt};
19use lance_core::error::{Error, LanceOptionExt, Result};
20
21#[cfg(feature = "aws")]
22pub mod aws;
23#[cfg(feature = "azure")]
24pub mod azure;
25#[cfg(feature = "gcp")]
26pub mod gcp;
27#[cfg(feature = "huggingface")]
28pub mod huggingface;
29pub mod local;
30pub mod memory;
31#[cfg(feature = "oss")]
32pub mod oss;
33#[cfg(feature = "tencent")]
34pub mod tencent;
35
36#[async_trait::async_trait]
37pub trait ObjectStoreProvider: std::fmt::Debug + Sync + Send {
38 async fn new_store(&self, base_path: Url, params: &ObjectStoreParams) -> Result<ObjectStore>;
39
40 fn extract_path(&self, url: &Url) -> Result<Path> {
48 Path::parse(url.path())
49 .map_err(|_| Error::invalid_input(format!("Invalid path in URL: {}", url.path())))
50 }
51
52 fn calculate_object_store_prefix(
65 &self,
66 url: &Url,
67 _storage_options: Option<&HashMap<String, String>>,
68 ) -> Result<String> {
69 Ok(format!("{}${}", url.scheme(), url.authority()))
70 }
71}
72
/// Snapshot of the cache counters of an [`ObjectStoreRegistry`], as returned
/// by [`ObjectStoreRegistry::stats`].
#[derive(Debug, Clone, Default)]
pub struct ObjectStoreRegistryStats {
    /// Number of `get_store` calls that returned an already-cached store.
    pub hits: u64,
    /// Number of `get_store` calls that found no live cached store and went on
    /// to create one. Counted before creation, so failed creations are included.
    pub misses: u64,
    /// Number of cache entries whose store still had at least one strong
    /// reference when the snapshot was taken.
    pub active_stores: usize,
}
83
/// A registry of [`ObjectStoreProvider`]s keyed by URL scheme, plus a cache of
/// the stores created through it.
///
/// The cache holds only [`Weak`] references, so the registry does not keep a
/// store alive by itself; an entry whose store has been dropped is treated as
/// absent and pruned on later access.
#[derive(Debug)]
pub struct ObjectStoreRegistry {
    /// Providers, keyed by URL scheme.
    providers: RwLock<HashMap<String, Arc<dyn ObjectStoreProvider>>>,
    /// Cached stores, keyed by the provider-computed store prefix together
    /// with the parameters the store was created with.
    active_stores: RwLock<HashMap<(String, ObjectStoreParams), Weak<ObjectStore>>>,
    /// Count of `get_store` calls served from the cache.
    hits: AtomicU64,
    /// Count of `get_store` calls that had to create a new store.
    misses: AtomicU64,
}
115
116impl ObjectStoreRegistry {
117 pub fn empty() -> Self {
122 Self {
123 providers: RwLock::new(HashMap::new()),
124 active_stores: RwLock::new(HashMap::new()),
125 hits: AtomicU64::new(0),
126 misses: AtomicU64::new(0),
127 }
128 }
129
130 pub fn get_provider(&self, scheme: &str) -> Option<Arc<dyn ObjectStoreProvider>> {
132 self.providers
133 .read()
134 .expect("ObjectStoreRegistry lock poisoned")
135 .get(scheme)
136 .cloned()
137 }
138
139 pub fn active_stores(&self) -> Vec<Arc<ObjectStore>> {
144 let mut found_inactive = false;
145 let output = self
146 .active_stores
147 .read()
148 .expect("ObjectStoreRegistry lock poisoned")
149 .values()
150 .filter_map(|weak| match weak.upgrade() {
151 Some(store) => Some(store),
152 None => {
153 found_inactive = true;
154 None
155 }
156 })
157 .collect();
158
159 if found_inactive {
160 let mut cache_lock = self
162 .active_stores
163 .write()
164 .expect("ObjectStoreRegistry lock poisoned");
165 cache_lock.retain(|_, weak| weak.upgrade().is_some());
166 }
167 output
168 }
169
170 pub fn stats(&self) -> ObjectStoreRegistryStats {
176 let active_stores = self
177 .active_stores
178 .read()
179 .map(|s| s.values().filter(|w| w.strong_count() > 0).count())
180 .unwrap_or(0);
181 ObjectStoreRegistryStats {
182 hits: self.hits.load(Ordering::Relaxed),
183 misses: self.misses.load(Ordering::Relaxed),
184 active_stores,
185 }
186 }
187
188 fn scheme_not_found_error(&self, scheme: &str) -> Error {
189 let mut message = format!("No object store provider found for scheme: '{}'", scheme);
190 if let Ok(providers) = self.providers.read() {
191 let valid_schemes = providers.keys().cloned().collect::<Vec<_>>().join(", ");
192 message.push_str(&format!("\nValid schemes: {}", valid_schemes));
193 }
194 Error::invalid_input(message)
195 }
196
197 pub async fn get_store(
203 &self,
204 base_path: Url,
205 params: &ObjectStoreParams,
206 ) -> Result<Arc<ObjectStore>> {
207 let scheme = base_path.scheme();
208 let Some(provider) = self.get_provider(scheme) else {
209 return Err(self.scheme_not_found_error(scheme));
210 };
211
212 let cache_path =
213 provider.calculate_object_store_prefix(&base_path, params.storage_options())?;
214 let cache_key = (cache_path.clone(), params.clone());
215
216 {
218 let maybe_store = self
219 .active_stores
220 .read()
221 .ok()
222 .expect_ok()?
223 .get(&cache_key)
224 .cloned();
225 if let Some(store) = maybe_store {
226 if let Some(store) = store.upgrade() {
227 self.hits.fetch_add(1, Ordering::Relaxed);
228 return Ok(store);
229 } else {
230 let mut cache_lock = self
232 .active_stores
233 .write()
234 .expect("ObjectStoreRegistry lock poisoned");
235 if let Some(store) = cache_lock.get(&cache_key)
236 && store.upgrade().is_none()
237 {
238 cache_lock.remove(&cache_key);
240 }
241 }
242 }
243 }
244
245 self.misses.fetch_add(1, Ordering::Relaxed);
246
247 let mut store = provider.new_store(base_path, params).await?;
248
249 store.inner = store.inner.traced();
250
251 if let Some(wrapper) = ¶ms.object_store_wrapper {
252 store.inner = wrapper.wrap(&cache_path, store.inner);
253 }
254
255 store.inner = store.io_tracker.wrap("", store.inner);
257
258 let store = Arc::new(store);
259
260 {
261 let mut cache_lock = self.active_stores.write().ok().expect_ok()?;
263 cache_lock.insert(cache_key, Arc::downgrade(&store));
264 }
265
266 Ok(store)
267 }
268
269 pub fn calculate_object_store_prefix(
272 &self,
273 uri: &str,
274 storage_options: Option<&HashMap<String, String>>,
275 ) -> Result<String> {
276 let url = uri_to_url(uri)?;
277 match self.get_provider(url.scheme()) {
278 None => {
279 if url.scheme() == "file" || url.scheme().len() == 1 {
280 Ok("file".to_string())
281 } else {
282 Err(self.scheme_not_found_error(url.scheme()))
283 }
284 }
285 Some(provider) => provider.calculate_object_store_prefix(&url, storage_options),
286 }
287 }
288}
289
290impl Default for ObjectStoreRegistry {
291 fn default() -> Self {
292 let mut providers: HashMap<String, Arc<dyn ObjectStoreProvider>> = HashMap::new();
293
294 providers.insert("memory".into(), Arc::new(memory::MemoryStoreProvider));
295 providers.insert("file".into(), Arc::new(local::FileStoreProvider));
296 providers.insert(
302 "file-object-store".into(),
303 Arc::new(local::FileStoreProvider),
304 );
305 #[cfg(target_os = "linux")]
306 providers.insert("file+uring".into(), Arc::new(local::FileStoreProvider));
307
308 #[cfg(feature = "aws")]
309 {
310 let aws = Arc::new(aws::AwsStoreProvider);
311 providers.insert("s3".into(), aws.clone());
312 providers.insert("s3+ddb".into(), aws);
313 }
314 #[cfg(feature = "azure")]
315 {
316 let azure = Arc::new(azure::AzureBlobStoreProvider);
317 providers.insert("az".into(), azure.clone());
318 providers.insert("abfss".into(), azure);
319 }
320 #[cfg(feature = "gcp")]
321 providers.insert("gs".into(), Arc::new(gcp::GcsStoreProvider));
322 #[cfg(feature = "oss")]
323 providers.insert("oss".into(), Arc::new(oss::OssStoreProvider));
324 #[cfg(feature = "tencent")]
325 providers.insert("cos".into(), Arc::new(tencent::TencentStoreProvider));
326 #[cfg(feature = "huggingface")]
327 providers.insert("hf".into(), Arc::new(huggingface::HuggingfaceStoreProvider));
328 Self {
329 providers: RwLock::new(providers),
330 active_stores: RwLock::new(HashMap::new()),
331 hits: AtomicU64::new(0),
332 misses: AtomicU64::new(0),
333 }
334 }
335}
336
337impl ObjectStoreRegistry {
338 pub fn insert(&self, scheme: &str, provider: Arc<dyn ObjectStoreProvider>) {
341 self.providers
342 .write()
343 .expect("ObjectStoreRegistry lock poisoned")
344 .insert(scheme.into(), provider);
345 }
346}
347
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Provider that relies on the trait's default methods and must never be
    /// asked to create a store.
    #[derive(Debug)]
    struct DummyProvider;

    #[async_trait::async_trait]
    impl ObjectStoreProvider for DummyProvider {
        async fn new_store(
            &self,
            _base_path: Url,
            _params: &ObjectStoreParams,
        ) -> Result<ObjectStore> {
            unreachable!("This test doesn't create stores")
        }
    }

    /// The default prefix is `<scheme>$<authority>`.
    #[test]
    fn test_calculate_object_store_prefix() {
        let provider = DummyProvider;
        let url = Url::parse("dummy://blah/path").unwrap();
        assert_eq!(
            "dummy$blah",
            provider.calculate_object_store_prefix(&url, None).unwrap()
        );
    }

    /// An unknown scheme produces an error that lists the registered schemes.
    #[test]
    fn test_calculate_object_store_scheme_not_found() {
        let registry = ObjectStoreRegistry::empty();
        registry.insert("dummy", Arc::new(DummyProvider));
        let s = "Invalid user input: No object store provider found for scheme: 'dummy2'\nValid schemes: dummy";
        let result = registry
            .calculate_object_store_prefix("dummy2://mybucket/my/long/path", None)
            .expect_err("expected error")
            .to_string();
        // `starts_with` cannot panic on a short message (unlike slicing by
        // length) and the failure output shows the actual text.
        assert!(
            result.starts_with(s),
            "unexpected error message: {result:?}"
        );
    }

    /// A plain filesystem path maps to the `file` prefix even without a
    /// registered provider.
    #[test]
    fn test_calculate_object_store_prefix_for_local() {
        let registry = ObjectStoreRegistry::empty();
        assert_eq!(
            "file",
            registry
                .calculate_object_store_prefix("/tmp/foobar", None)
                .unwrap()
        );
    }

    /// A single-letter scheme (drive letter) is treated as a local path.
    #[test]
    fn test_calculate_object_store_prefix_for_local_windows_path() {
        let registry = ObjectStoreRegistry::empty();
        assert_eq!(
            "file",
            registry
                .calculate_object_store_prefix("c://dos/path", None)
                .unwrap()
        );
    }

    /// A registered provider is used to compute the prefix for its scheme.
    #[test]
    fn test_calculate_object_store_prefix_for_dummy_path() {
        let registry = ObjectStoreRegistry::empty();
        registry.insert("dummy", Arc::new(DummyProvider));
        assert_eq!(
            "dummy$mybucket",
            registry
                .calculate_object_store_prefix("dummy://mybucket/my/long/path", None)
                .unwrap()
        );
    }

    /// Hit/miss counters and the live-store count follow `get_store` calls.
    #[tokio::test]
    async fn test_stats_hit_miss_tracking() {
        use crate::object_store::StorageOptionsAccessor;
        let registry = ObjectStoreRegistry::default();
        let url = Url::parse("memory://test").unwrap();

        let params1 = ObjectStoreParams::default();
        let params2 = ObjectStoreParams {
            storage_options_accessor: Some(Arc::new(StorageOptionsAccessor::with_static_options(
                HashMap::from([("k".into(), "v".into())]),
            ))),
            ..Default::default()
        };

        // (params, expected (hits, misses, active_stores) after the call)
        let cases: &[(&ObjectStoreParams, (u64, u64, usize))] = &[
            (&params1, (0, 1, 1)), // first request: miss, store created
            (&params1, (1, 1, 1)), // same key again: hit
            (&params2, (1, 2, 2)), // different params: miss, second store
        ];

        // Hold strong references so the weakly cached stores stay alive.
        let mut stores = vec![];
        for (params, (hits, misses, active)) in cases {
            stores.push(registry.get_store(url.clone(), params).await.unwrap());
            let s = registry.stats();
            assert_eq!(
                (s.hits, s.misses, s.active_stores),
                (*hits, *misses, *active)
            );
        }

        // The cache hit must hand back the very same store instance.
        assert!(Arc::ptr_eq(&stores[0], &stores[1]));
    }
}