1use std::{
5 collections::HashMap,
6 sync::{
7 Arc, RwLock, Weak,
8 atomic::{AtomicU64, Ordering},
9 },
10};
11
12use object_store::path::Path;
13use url::Url;
14
15use crate::object_store::WrappingObjectStore;
16use crate::object_store::uri_to_url;
17
18use super::{ObjectStore, ObjectStoreParams, tracing::ObjectStoreTracingExt};
19use lance_core::error::{Error, LanceOptionExt, Result};
20
21#[cfg(feature = "aws")]
22pub mod aws;
23#[cfg(feature = "azure")]
24pub mod azure;
25#[cfg(feature = "gcp")]
26pub mod gcp;
27#[cfg(feature = "huggingface")]
28pub mod huggingface;
29pub mod local;
30pub mod memory;
31#[cfg(feature = "oss")]
32pub mod oss;
33#[cfg(feature = "tencent")]
34pub mod tencent;
35
36#[async_trait::async_trait]
37pub trait ObjectStoreProvider: std::fmt::Debug + Sync + Send {
38 async fn new_store(&self, base_path: Url, params: &ObjectStoreParams) -> Result<ObjectStore>;
39
40 fn extract_path(&self, url: &Url) -> Result<Path> {
48 Path::parse(url.path())
49 .map_err(|_| Error::invalid_input(format!("Invalid path in URL: {}", url.path())))
50 }
51
52 fn calculate_object_store_prefix(
65 &self,
66 url: &Url,
67 _storage_options: Option<&HashMap<String, String>>,
68 ) -> Result<String> {
69 Ok(format!("{}${}", url.scheme(), url.authority()))
70 }
71}
72
/// Point-in-time snapshot of an `ObjectStoreRegistry`'s cache counters, as
/// returned by `ObjectStoreRegistry::stats`.
///
/// Derives `PartialEq`/`Eq` so that two snapshots can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ObjectStoreRegistryStats {
    /// Number of `get_store` calls that returned an already cached store.
    pub hits: u64,
    /// Number of `get_store` calls that went on to ask a provider for a new
    /// store (counted before the provider is invoked).
    pub misses: u64,
    /// Number of cache entries whose store still had at least one strong
    /// reference when the snapshot was taken.
    pub active_stores: usize,
}
83
/// Registry of [`ObjectStoreProvider`]s keyed by URL scheme, plus a cache of
/// the stores those providers have created.
#[derive(Debug)]
pub struct ObjectStoreRegistry {
    /// Registered providers, keyed by URL scheme (e.g. `"file"`, `"memory"`).
    providers: RwLock<HashMap<String, Arc<dyn ObjectStoreProvider>>>,
    /// Cache of created stores, keyed by (object store prefix, parameters).
    ///
    /// Only weak references are held, so the cache itself never keeps a store
    /// alive; dead entries are pruned when they are encountered.
    active_stores: RwLock<HashMap<(String, ObjectStoreParams), Weak<ObjectStore>>>,
    /// Count of `get_store` calls served from `active_stores`.
    hits: AtomicU64,
    /// Count of `get_store` calls that had to request a new store.
    misses: AtomicU64,
}
114
115impl ObjectStoreRegistry {
116 pub fn empty() -> Self {
121 Self {
122 providers: RwLock::new(HashMap::new()),
123 active_stores: RwLock::new(HashMap::new()),
124 hits: AtomicU64::new(0),
125 misses: AtomicU64::new(0),
126 }
127 }
128
129 pub fn get_provider(&self, scheme: &str) -> Option<Arc<dyn ObjectStoreProvider>> {
131 self.providers
132 .read()
133 .expect("ObjectStoreRegistry lock poisoned")
134 .get(scheme)
135 .cloned()
136 }
137
138 pub fn active_stores(&self) -> Vec<Arc<ObjectStore>> {
143 let mut found_inactive = false;
144 let output = self
145 .active_stores
146 .read()
147 .expect("ObjectStoreRegistry lock poisoned")
148 .values()
149 .filter_map(|weak| match weak.upgrade() {
150 Some(store) => Some(store),
151 None => {
152 found_inactive = true;
153 None
154 }
155 })
156 .collect();
157
158 if found_inactive {
159 let mut cache_lock = self
161 .active_stores
162 .write()
163 .expect("ObjectStoreRegistry lock poisoned");
164 cache_lock.retain(|_, weak| weak.upgrade().is_some());
165 }
166 output
167 }
168
169 pub fn stats(&self) -> ObjectStoreRegistryStats {
175 let active_stores = self
176 .active_stores
177 .read()
178 .map(|s| s.values().filter(|w| w.strong_count() > 0).count())
179 .unwrap_or(0);
180 ObjectStoreRegistryStats {
181 hits: self.hits.load(Ordering::Relaxed),
182 misses: self.misses.load(Ordering::Relaxed),
183 active_stores,
184 }
185 }
186
187 fn scheme_not_found_error(&self, scheme: &str) -> Error {
188 let mut message = format!("No object store provider found for scheme: '{}'", scheme);
189 if let Ok(providers) = self.providers.read() {
190 let valid_schemes = providers.keys().cloned().collect::<Vec<_>>().join(", ");
191 message.push_str(&format!("\nValid schemes: {}", valid_schemes));
192 }
193 Error::invalid_input(message)
194 }
195
196 pub async fn get_store(
202 &self,
203 base_path: Url,
204 params: &ObjectStoreParams,
205 ) -> Result<Arc<ObjectStore>> {
206 let scheme = base_path.scheme();
207 let Some(provider) = self.get_provider(scheme) else {
208 return Err(self.scheme_not_found_error(scheme));
209 };
210
211 let cache_path =
212 provider.calculate_object_store_prefix(&base_path, params.storage_options())?;
213 let cache_key = (cache_path.clone(), params.clone());
214
215 {
217 let maybe_store = self
218 .active_stores
219 .read()
220 .ok()
221 .expect_ok()?
222 .get(&cache_key)
223 .cloned();
224 if let Some(store) = maybe_store {
225 if let Some(store) = store.upgrade() {
226 self.hits.fetch_add(1, Ordering::Relaxed);
227 return Ok(store);
228 } else {
229 let mut cache_lock = self
231 .active_stores
232 .write()
233 .expect("ObjectStoreRegistry lock poisoned");
234 if let Some(store) = cache_lock.get(&cache_key)
235 && store.upgrade().is_none()
236 {
237 cache_lock.remove(&cache_key);
239 }
240 }
241 }
242 }
243
244 self.misses.fetch_add(1, Ordering::Relaxed);
245
246 let mut store = provider.new_store(base_path, params).await?;
247
248 store.inner = store.inner.traced();
249
250 if let Some(wrapper) = ¶ms.object_store_wrapper {
251 store.inner = wrapper.wrap(&cache_path, store.inner);
252 }
253
254 store.inner = store.io_tracker.wrap("", store.inner);
256
257 let store = Arc::new(store);
258
259 {
260 let mut cache_lock = self.active_stores.write().ok().expect_ok()?;
262 cache_lock.insert(cache_key, Arc::downgrade(&store));
263 }
264
265 Ok(store)
266 }
267
268 pub fn calculate_object_store_prefix(
271 &self,
272 uri: &str,
273 storage_options: Option<&HashMap<String, String>>,
274 ) -> Result<String> {
275 let url = uri_to_url(uri)?;
276 match self.get_provider(url.scheme()) {
277 None => {
278 if url.scheme() == "file" || url.scheme().len() == 1 {
279 Ok("file".to_string())
280 } else {
281 Err(self.scheme_not_found_error(url.scheme()))
282 }
283 }
284 Some(provider) => provider.calculate_object_store_prefix(&url, storage_options),
285 }
286 }
287}
288
289impl Default for ObjectStoreRegistry {
290 fn default() -> Self {
291 let mut providers: HashMap<String, Arc<dyn ObjectStoreProvider>> = HashMap::new();
292
293 providers.insert("memory".into(), Arc::new(memory::MemoryStoreProvider));
294 providers.insert("file".into(), Arc::new(local::FileStoreProvider));
295 providers.insert(
301 "file-object-store".into(),
302 Arc::new(local::FileStoreProvider),
303 );
304
305 #[cfg(feature = "aws")]
306 {
307 let aws = Arc::new(aws::AwsStoreProvider);
308 providers.insert("s3".into(), aws.clone());
309 providers.insert("s3+ddb".into(), aws);
310 }
311 #[cfg(feature = "azure")]
312 {
313 let azure = Arc::new(azure::AzureBlobStoreProvider);
314 providers.insert("az".into(), azure.clone());
315 providers.insert("abfss".into(), azure);
316 }
317 #[cfg(feature = "gcp")]
318 providers.insert("gs".into(), Arc::new(gcp::GcsStoreProvider));
319 #[cfg(feature = "oss")]
320 providers.insert("oss".into(), Arc::new(oss::OssStoreProvider));
321 #[cfg(feature = "tencent")]
322 providers.insert("cos".into(), Arc::new(tencent::TencentStoreProvider));
323 #[cfg(feature = "huggingface")]
324 providers.insert("hf".into(), Arc::new(huggingface::HuggingfaceStoreProvider));
325 Self {
326 providers: RwLock::new(providers),
327 active_stores: RwLock::new(HashMap::new()),
328 hits: AtomicU64::new(0),
329 misses: AtomicU64::new(0),
330 }
331 }
332}
333
334impl ObjectStoreRegistry {
335 pub fn insert(&self, scheme: &str, provider: Arc<dyn ObjectStoreProvider>) {
338 self.providers
339 .write()
340 .expect("ObjectStoreRegistry lock poisoned")
341 .insert(scheme.into(), provider);
342 }
343}
344
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Provider that relies on the trait's default methods and never builds
    /// a store.
    #[derive(Debug)]
    struct DummyProvider;

    #[async_trait::async_trait]
    impl ObjectStoreProvider for DummyProvider {
        async fn new_store(
            &self,
            _base_path: Url,
            _params: &ObjectStoreParams,
        ) -> Result<ObjectStore> {
            unreachable!("This test doesn't create stores")
        }
    }

    /// The default prefix is `<scheme>$<authority>`.
    #[test]
    fn test_calculate_object_store_prefix() {
        let provider = DummyProvider;
        let url = Url::parse("dummy://blah/path").unwrap();
        assert_eq!(
            "dummy$blah",
            provider.calculate_object_store_prefix(&url, None).unwrap()
        );
    }

    /// An unknown scheme reports the error along with the registered schemes.
    #[test]
    fn test_calculate_object_store_scheme_not_found() {
        let registry = ObjectStoreRegistry::empty();
        registry.insert("dummy", Arc::new(DummyProvider));
        let s = "Invalid user input: No object store provider found for scheme: 'dummy2'\nValid schemes: dummy";
        let result = registry
            .calculate_object_store_prefix("dummy2://mybucket/my/long/path", None)
            .expect_err("expected error")
            .to_string();
        // Only the prefix is compared. `starts_with` cannot panic on a short
        // message, and the failure output shows the actual message.
        assert!(
            result.starts_with(s),
            "unexpected error message: {result}"
        );
    }

    /// A plain filesystem path maps to the `file` prefix even when no
    /// provider is registered.
    #[test]
    fn test_calculate_object_store_prefix_for_local() {
        let registry = ObjectStoreRegistry::empty();
        assert_eq!(
            "file",
            registry
                .calculate_object_store_prefix("/tmp/foobar", None)
                .unwrap()
        );
    }

    /// A single-letter scheme is treated as a local path.
    #[test]
    fn test_calculate_object_store_prefix_for_local_windows_path() {
        let registry = ObjectStoreRegistry::empty();
        assert_eq!(
            "file",
            registry
                .calculate_object_store_prefix("c://dos/path", None)
                .unwrap()
        );
    }

    /// A registered provider's prefix calculation is used for its scheme.
    #[test]
    fn test_calculate_object_store_prefix_for_dummy_path() {
        let registry = ObjectStoreRegistry::empty();
        registry.insert("dummy", Arc::new(DummyProvider));
        assert_eq!(
            "dummy$mybucket",
            registry
                .calculate_object_store_prefix("dummy://mybucket/my/long/path", None)
                .unwrap()
        );
    }

    /// Hits, misses and active-store counts track `get_store` calls.
    #[tokio::test]
    async fn test_stats_hit_miss_tracking() {
        use crate::object_store::StorageOptionsAccessor;
        let registry = ObjectStoreRegistry::default();
        let url = Url::parse("memory://test").unwrap();

        let params1 = ObjectStoreParams::default();
        let params2 = ObjectStoreParams {
            storage_options_accessor: Some(Arc::new(StorageOptionsAccessor::with_static_options(
                HashMap::from([("k".into(), "v".into())]),
            ))),
            ..Default::default()
        };

        // (params, expected (hits, misses, active_stores) after the call)
        let cases: &[(&ObjectStoreParams, (u64, u64, usize))] = &[
            (&params1, (0, 1, 1)), // first request: miss
            (&params1, (1, 1, 1)), // same params again: hit
            (&params2, (1, 2, 2)), // different params: miss, second cache entry
        ];

        // Keep strong references so the weakly cached stores stay alive.
        let mut stores = vec![];
        for (params, (hits, misses, active)) in cases {
            stores.push(registry.get_store(url.clone(), params).await.unwrap());
            let s = registry.stats();
            assert_eq!(
                (s.hits, s.misses, s.active_stores),
                (*hits, *misses, *active)
            );
        }

        // The cache hit must return the very same store instance.
        assert!(Arc::ptr_eq(&stores[0], &stores[1]));
    }
}