lance_io/object_store/
providers.rs1use std::{
5 collections::HashMap,
6 sync::{Arc, RwLock, Weak},
7};
8
9use object_store::path::Path;
10use snafu::location;
11use url::Url;
12
13use super::{tracing::ObjectStoreTracingExt, ObjectStore, ObjectStoreParams};
14use lance_core::error::{Error, LanceOptionExt, Result};
15
16#[cfg(feature = "aws")]
17pub mod aws;
18#[cfg(feature = "azure")]
19pub mod azure;
20#[cfg(feature = "gcp")]
21pub mod gcp;
22pub mod local;
23pub mod memory;
24
25#[async_trait::async_trait]
26pub trait ObjectStoreProvider: std::fmt::Debug + Sync + Send {
27 async fn new_store(&self, base_path: Url, params: &ObjectStoreParams) -> Result<ObjectStore>;
28
29 fn extract_path(&self, url: &Url) -> Path {
37 Path::from(url.path())
38 }
39}
40
41#[derive(Debug)]
61pub struct ObjectStoreRegistry {
62 providers: RwLock<HashMap<String, Arc<dyn ObjectStoreProvider>>>,
63 active_stores: RwLock<HashMap<(String, ObjectStoreParams), Weak<ObjectStore>>>,
67}
68
69fn cache_url(url: &Url) -> String {
78 if ["file", "file-object-store", "memory"].contains(&url.scheme()) {
79 format!("{}://", url.scheme())
83 } else {
84 let mut url = url.clone();
86 url.set_path("");
87 url.to_string()
88 }
89}
90
91impl ObjectStoreRegistry {
92 pub fn empty() -> Self {
97 Self {
98 providers: RwLock::new(HashMap::new()),
99 active_stores: RwLock::new(HashMap::new()),
100 }
101 }
102
103 pub fn get_provider(&self, scheme: &str) -> Option<Arc<dyn ObjectStoreProvider>> {
105 self.providers
106 .read()
107 .expect("ObjectStoreRegistry lock poisoned")
108 .get(scheme)
109 .cloned()
110 }
111
112 pub fn active_stores(&self) -> Vec<Arc<ObjectStore>> {
117 let mut found_inactive = false;
118 let output = self
119 .active_stores
120 .read()
121 .expect("ObjectStoreRegistry lock poisoned")
122 .values()
123 .filter_map(|weak| match weak.upgrade() {
124 Some(store) => Some(store),
125 None => {
126 found_inactive = true;
127 None
128 }
129 })
130 .collect();
131
132 if found_inactive {
133 let mut cache_lock = self
135 .active_stores
136 .write()
137 .expect("ObjectStoreRegistry lock poisoned");
138 cache_lock.retain(|_, weak| weak.upgrade().is_some());
139 }
140 output
141 }
142
143 pub async fn get_store(
149 &self,
150 base_path: Url,
151 params: &ObjectStoreParams,
152 ) -> Result<Arc<ObjectStore>> {
153 let cache_path = cache_url(&base_path);
154 let cache_key = (cache_path, params.clone());
155
156 {
158 let maybe_store = self
159 .active_stores
160 .read()
161 .ok()
162 .expect_ok()?
163 .get(&cache_key)
164 .cloned();
165 if let Some(store) = maybe_store {
166 if let Some(store) = store.upgrade() {
167 return Ok(store);
168 } else {
169 let mut cache_lock = self
171 .active_stores
172 .write()
173 .expect("ObjectStoreRegistry lock poisoned");
174 if let Some(store) = cache_lock.get(&cache_key) {
175 if store.upgrade().is_none() {
176 cache_lock.remove(&cache_key);
178 }
179 }
180 }
181 }
182 }
183
184 let scheme = base_path.scheme();
185 let Some(provider) = self.get_provider(scheme) else {
186 let mut message = format!("No object store provider found for scheme: '{}'", scheme);
187 if let Ok(providers) = self.providers.read() {
188 let valid_schemes = providers.keys().cloned().collect::<Vec<_>>().join(", ");
189 message.push_str(&format!("\nValid schemes: {}", valid_schemes));
190 }
191
192 return Err(Error::invalid_input(message, location!()));
193 };
194 let mut store = provider.new_store(base_path, params).await?;
195
196 store.inner = store.inner.traced();
197
198 if let Some(wrapper) = ¶ms.object_store_wrapper {
199 store.inner = wrapper.wrap(store.inner);
200 }
201
202 let store = Arc::new(store);
203
204 {
205 let mut cache_lock = self.active_stores.write().ok().expect_ok()?;
207 cache_lock.insert(cache_key, Arc::downgrade(&store));
208 }
209
210 Ok(store)
211 }
212}
213
214impl Default for ObjectStoreRegistry {
215 fn default() -> Self {
216 let mut providers: HashMap<String, Arc<dyn ObjectStoreProvider>> = HashMap::new();
217
218 providers.insert("memory".into(), Arc::new(memory::MemoryStoreProvider));
219 providers.insert("file".into(), Arc::new(local::FileStoreProvider));
220 providers.insert(
226 "file-object-store".into(),
227 Arc::new(local::FileStoreProvider),
228 );
229
230 #[cfg(feature = "aws")]
231 {
232 let aws = Arc::new(aws::AwsStoreProvider);
233 providers.insert("s3".into(), aws.clone());
234 providers.insert("s3+ddb".into(), aws);
235 }
236 #[cfg(feature = "azure")]
237 providers.insert("az".into(), Arc::new(azure::AzureBlobStoreProvider));
238 #[cfg(feature = "gcp")]
239 providers.insert("gs".into(), Arc::new(gcp::GcsStoreProvider));
240 Self {
241 providers: RwLock::new(providers),
242 active_stores: RwLock::new(HashMap::new()),
243 }
244 }
245}
246
247impl ObjectStoreRegistry {
248 pub fn insert(&self, scheme: &str, provider: Arc<dyn ObjectStoreProvider>) {
251 self.providers
252 .write()
253 .expect("ObjectStoreRegistry lock poisoned")
254 .insert(scheme.into(), provider);
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// `cache_url` collapses local/in-memory schemes to `<scheme>://` and only
    /// clears the path for every other scheme.
    #[test]
    fn test_cache_url() {
        let cases = [
            ("s3://bucket/path?param=value", "s3://bucket?param=value"),
            ("file:///path/to/file", "file://"),
            ("file-object-store:///path/to/file", "file-object-store://"),
            ("memory:///", "memory://"),
            (
                "http://example.com/path?param=value",
                "http://example.com/?param=value",
            ),
        ];

        for &(raw, expected) in &cases {
            let parsed = Url::parse(raw).unwrap();
            assert_eq!(cache_url(&parsed), expected);
        }
    }
}