lance_io/object_store/
providers.rs1use std::{
5 collections::HashMap,
6 sync::{Arc, RwLock, Weak},
7};
8
9use object_store::path::Path;
10use snafu::location;
11use url::Url;
12
13use super::{tracing::ObjectStoreTracingExt, ObjectStore, ObjectStoreParams};
14use lance_core::error::{Error, LanceOptionExt, Result};
15
16#[cfg(feature = "aws")]
17pub mod aws;
18#[cfg(feature = "azure")]
19pub mod azure;
20#[cfg(feature = "gcp")]
21pub mod gcp;
22pub mod local;
23pub mod memory;
24#[cfg(feature = "oss")]
25pub mod oss;
26
27#[async_trait::async_trait]
28pub trait ObjectStoreProvider: std::fmt::Debug + Sync + Send {
29 async fn new_store(&self, base_path: Url, params: &ObjectStoreParams) -> Result<ObjectStore>;
30
31 fn extract_path(&self, url: &Url) -> Result<Path> {
39 Path::parse(url.path()).map_err(|_| {
40 Error::invalid_input(format!("Invalid path in URL: {}", url.path()), location!())
41 })
42 }
43}
44
/// Registry mapping URL schemes to [`ObjectStoreProvider`]s, together with a
/// cache of the stores it has handed out.
#[derive(Debug)]
pub struct ObjectStoreRegistry {
    /// Providers keyed by URL scheme (e.g. `"s3"`, `"file"`, `"memory"`).
    providers: RwLock<HashMap<String, Arc<dyn ObjectStoreProvider>>>,
    /// Stores created by `get_store`, keyed by `(cache_url(&base_path), params)`.
    /// Entries are held as `Weak` so the registry alone never keeps a store
    /// alive; entries whose store was dropped are pruned lazily.
    active_stores: RwLock<HashMap<(String, ObjectStoreParams), Weak<ObjectStore>>>,
}
72
73fn cache_url(url: &Url) -> String {
82 if ["file", "file-object-store", "memory"].contains(&url.scheme()) {
83 format!("{}://", url.scheme())
87 } else {
88 let mut url = url.clone();
90 url.set_path("");
91 url.to_string()
92 }
93}
94
95impl ObjectStoreRegistry {
96 pub fn empty() -> Self {
101 Self {
102 providers: RwLock::new(HashMap::new()),
103 active_stores: RwLock::new(HashMap::new()),
104 }
105 }
106
107 pub fn get_provider(&self, scheme: &str) -> Option<Arc<dyn ObjectStoreProvider>> {
109 self.providers
110 .read()
111 .expect("ObjectStoreRegistry lock poisoned")
112 .get(scheme)
113 .cloned()
114 }
115
116 pub fn active_stores(&self) -> Vec<Arc<ObjectStore>> {
121 let mut found_inactive = false;
122 let output = self
123 .active_stores
124 .read()
125 .expect("ObjectStoreRegistry lock poisoned")
126 .values()
127 .filter_map(|weak| match weak.upgrade() {
128 Some(store) => Some(store),
129 None => {
130 found_inactive = true;
131 None
132 }
133 })
134 .collect();
135
136 if found_inactive {
137 let mut cache_lock = self
139 .active_stores
140 .write()
141 .expect("ObjectStoreRegistry lock poisoned");
142 cache_lock.retain(|_, weak| weak.upgrade().is_some());
143 }
144 output
145 }
146
147 pub async fn get_store(
153 &self,
154 base_path: Url,
155 params: &ObjectStoreParams,
156 ) -> Result<Arc<ObjectStore>> {
157 let cache_path = cache_url(&base_path);
158 let cache_key = (cache_path, params.clone());
159
160 {
162 let maybe_store = self
163 .active_stores
164 .read()
165 .ok()
166 .expect_ok()?
167 .get(&cache_key)
168 .cloned();
169 if let Some(store) = maybe_store {
170 if let Some(store) = store.upgrade() {
171 return Ok(store);
172 } else {
173 let mut cache_lock = self
175 .active_stores
176 .write()
177 .expect("ObjectStoreRegistry lock poisoned");
178 if let Some(store) = cache_lock.get(&cache_key) {
179 if store.upgrade().is_none() {
180 cache_lock.remove(&cache_key);
182 }
183 }
184 }
185 }
186 }
187
188 let scheme = base_path.scheme();
189 let Some(provider) = self.get_provider(scheme) else {
190 let mut message = format!("No object store provider found for scheme: '{}'", scheme);
191 if let Ok(providers) = self.providers.read() {
192 let valid_schemes = providers.keys().cloned().collect::<Vec<_>>().join(", ");
193 message.push_str(&format!("\nValid schemes: {}", valid_schemes));
194 }
195
196 return Err(Error::invalid_input(message, location!()));
197 };
198 let mut store = provider.new_store(base_path, params).await?;
199
200 store.inner = store.inner.traced();
201
202 if let Some(wrapper) = ¶ms.object_store_wrapper {
203 store.inner = wrapper.wrap(store.inner);
204 }
205
206 let store = Arc::new(store);
207
208 {
209 let mut cache_lock = self.active_stores.write().ok().expect_ok()?;
211 cache_lock.insert(cache_key, Arc::downgrade(&store));
212 }
213
214 Ok(store)
215 }
216}
217
218impl Default for ObjectStoreRegistry {
219 fn default() -> Self {
220 let mut providers: HashMap<String, Arc<dyn ObjectStoreProvider>> = HashMap::new();
221
222 providers.insert("memory".into(), Arc::new(memory::MemoryStoreProvider));
223 providers.insert("file".into(), Arc::new(local::FileStoreProvider));
224 providers.insert(
230 "file-object-store".into(),
231 Arc::new(local::FileStoreProvider),
232 );
233
234 #[cfg(feature = "aws")]
235 {
236 let aws = Arc::new(aws::AwsStoreProvider);
237 providers.insert("s3".into(), aws.clone());
238 providers.insert("s3+ddb".into(), aws);
239 }
240 #[cfg(feature = "azure")]
241 providers.insert("az".into(), Arc::new(azure::AzureBlobStoreProvider));
242 #[cfg(feature = "gcp")]
243 providers.insert("gs".into(), Arc::new(gcp::GcsStoreProvider));
244 #[cfg(feature = "oss")]
245 providers.insert("oss".into(), Arc::new(oss::OssStoreProvider));
246 Self {
247 providers: RwLock::new(providers),
248 active_stores: RwLock::new(HashMap::new()),
249 }
250 }
251}
252
253impl ObjectStoreRegistry {
254 pub fn insert(&self, scheme: &str, provider: Arc<dyn ObjectStoreProvider>) {
257 self.providers
258 .write()
259 .expect("ObjectStoreRegistry lock poisoned")
260 .insert(scheme.into(), provider);
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// `cache_url` collapses local/memory schemes and strips only the path otherwise.
    #[test]
    fn test_cache_url() {
        let cases = [
            ("s3://bucket/path?param=value", "s3://bucket?param=value"),
            ("file:///path/to/file", "file://"),
            ("file-object-store:///path/to/file", "file-object-store://"),
            ("memory:///", "memory://"),
            (
                "http://example.com/path?param=value",
                "http://example.com/?param=value",
            ),
        ];

        for &(input, expected) in cases.iter() {
            let parsed = Url::parse(input).unwrap();
            assert_eq!(cache_url(&parsed), expected, "cache_url({})", input);
        }
    }
}
287}