lance_io/object_store/
providers.rs1use std::{
5 collections::HashMap,
6 sync::{Arc, RwLock, Weak},
7};
8
9use object_store::path::Path;
10use snafu::location;
11use url::Url;
12
13use super::{tracing::ObjectStoreTracingExt, ObjectStore, ObjectStoreParams};
14use lance_core::error::{Error, LanceOptionExt, Result};
15
16#[cfg(feature = "aws")]
17pub mod aws;
18#[cfg(feature = "azure")]
19pub mod azure;
20#[cfg(feature = "gcp")]
21pub mod gcp;
22pub mod local;
23pub mod memory;
24#[cfg(feature = "oss")]
25pub mod oss;
26
27#[async_trait::async_trait]
28pub trait ObjectStoreProvider: std::fmt::Debug + Sync + Send {
29 async fn new_store(&self, base_path: Url, params: &ObjectStoreParams) -> Result<ObjectStore>;
30
31 fn extract_path(&self, url: &Url) -> Result<Path> {
39 Path::parse(url.path()).map_err(|_| {
40 Error::invalid_input(format!("Invalid path in URL: {}", url.path()), location!())
41 })
42 }
43
44 fn cache_url(&self, url: &Url) -> String {
50 if ["file", "file-object-store", "memory"].contains(&url.scheme()) {
51 format!("{}://", url.scheme())
55 } else {
56 let mut url = url.clone();
58 url.set_path("");
59 url.to_string()
60 }
61 }
62}
63
64#[derive(Debug)]
84pub struct ObjectStoreRegistry {
85 providers: RwLock<HashMap<String, Arc<dyn ObjectStoreProvider>>>,
86 active_stores: RwLock<HashMap<(String, ObjectStoreParams), Weak<ObjectStore>>>,
90}
91
92impl ObjectStoreRegistry {
93 pub fn empty() -> Self {
98 Self {
99 providers: RwLock::new(HashMap::new()),
100 active_stores: RwLock::new(HashMap::new()),
101 }
102 }
103
104 pub fn get_provider(&self, scheme: &str) -> Option<Arc<dyn ObjectStoreProvider>> {
106 self.providers
107 .read()
108 .expect("ObjectStoreRegistry lock poisoned")
109 .get(scheme)
110 .cloned()
111 }
112
113 pub fn active_stores(&self) -> Vec<Arc<ObjectStore>> {
118 let mut found_inactive = false;
119 let output = self
120 .active_stores
121 .read()
122 .expect("ObjectStoreRegistry lock poisoned")
123 .values()
124 .filter_map(|weak| match weak.upgrade() {
125 Some(store) => Some(store),
126 None => {
127 found_inactive = true;
128 None
129 }
130 })
131 .collect();
132
133 if found_inactive {
134 let mut cache_lock = self
136 .active_stores
137 .write()
138 .expect("ObjectStoreRegistry lock poisoned");
139 cache_lock.retain(|_, weak| weak.upgrade().is_some());
140 }
141 output
142 }
143
144 pub async fn get_store(
150 &self,
151 base_path: Url,
152 params: &ObjectStoreParams,
153 ) -> Result<Arc<ObjectStore>> {
154 let scheme = base_path.scheme();
155 let Some(provider) = self.get_provider(scheme) else {
156 let mut message = format!("No object store provider found for scheme: '{}'", scheme);
157 if let Ok(providers) = self.providers.read() {
158 let valid_schemes = providers.keys().cloned().collect::<Vec<_>>().join(", ");
159 message.push_str(&format!("\nValid schemes: {}", valid_schemes));
160 }
161 return Err(Error::invalid_input(message, location!()));
162 };
163
164 let cache_path = provider.cache_url(&base_path);
165 let cache_key = (cache_path, params.clone());
166
167 {
169 let maybe_store = self
170 .active_stores
171 .read()
172 .ok()
173 .expect_ok()?
174 .get(&cache_key)
175 .cloned();
176 if let Some(store) = maybe_store {
177 if let Some(store) = store.upgrade() {
178 return Ok(store);
179 } else {
180 let mut cache_lock = self
182 .active_stores
183 .write()
184 .expect("ObjectStoreRegistry lock poisoned");
185 if let Some(store) = cache_lock.get(&cache_key) {
186 if store.upgrade().is_none() {
187 cache_lock.remove(&cache_key);
189 }
190 }
191 }
192 }
193 }
194
195 let mut store = provider.new_store(base_path, params).await?;
196
197 store.inner = store.inner.traced();
198
199 if let Some(wrapper) = ¶ms.object_store_wrapper {
200 store.inner = wrapper.wrap(store.inner, params.storage_options.as_ref());
201 }
202
203 let store = Arc::new(store);
204
205 {
206 let mut cache_lock = self.active_stores.write().ok().expect_ok()?;
208 cache_lock.insert(cache_key, Arc::downgrade(&store));
209 }
210
211 Ok(store)
212 }
213}
214
215impl Default for ObjectStoreRegistry {
216 fn default() -> Self {
217 let mut providers: HashMap<String, Arc<dyn ObjectStoreProvider>> = HashMap::new();
218
219 providers.insert("memory".into(), Arc::new(memory::MemoryStoreProvider));
220 providers.insert("file".into(), Arc::new(local::FileStoreProvider));
221 providers.insert(
227 "file-object-store".into(),
228 Arc::new(local::FileStoreProvider),
229 );
230
231 #[cfg(feature = "aws")]
232 {
233 let aws = Arc::new(aws::AwsStoreProvider);
234 providers.insert("s3".into(), aws.clone());
235 providers.insert("s3+ddb".into(), aws);
236 }
237 #[cfg(feature = "azure")]
238 providers.insert("az".into(), Arc::new(azure::AzureBlobStoreProvider));
239 #[cfg(feature = "gcp")]
240 providers.insert("gs".into(), Arc::new(gcp::GcsStoreProvider));
241 #[cfg(feature = "oss")]
242 providers.insert("oss".into(), Arc::new(oss::OssStoreProvider));
243 Self {
244 providers: RwLock::new(providers),
245 active_stores: RwLock::new(HashMap::new()),
246 }
247 }
248}
249
250impl ObjectStoreRegistry {
251 pub fn insert(&self, scheme: &str, provider: Arc<dyn ObjectStoreProvider>) {
254 self.providers
255 .write()
256 .expect("ObjectStoreRegistry lock poisoned")
257 .insert(scheme.into(), provider);
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider used only to reach the trait's default method
    /// implementations; it is never asked to build a store.
    #[derive(Debug)]
    struct DummyProvider;

    #[async_trait::async_trait]
    impl ObjectStoreProvider for DummyProvider {
        async fn new_store(
            &self,
            _base_path: Url,
            _params: &ObjectStoreParams,
        ) -> Result<ObjectStore> {
            unreachable!("This test doesn't create stores")
        }
    }

    /// Checks the default `cache_url`: local and in-memory schemes collapse to
    /// `<scheme>://`, while other URLs only lose their path.
    #[test]
    fn test_cache_url() {
        // (input URL, expected cache URL)
        let cases = [
            ("s3://bucket/path?param=value", "s3://bucket?param=value"),
            ("file:///path/to/file", "file://"),
            ("file-object-store:///path/to/file", "file-object-store://"),
            ("memory:///", "memory://"),
            (
                "http://example.com/path?param=value",
                "http://example.com/?param=value",
            ),
        ];

        let provider = DummyProvider;
        for (input, expected_cache_url) in cases {
            let parsed = Url::parse(input).unwrap();
            assert_eq!(provider.cache_url(&parsed), expected_cache_url);
        }
    }
}