lance_io/object_store/
providers.rs1use std::{
5 collections::HashMap,
6 sync::{Arc, RwLock, Weak},
7};
8
9use object_store::path::Path;
10use snafu::location;
11use url::Url;
12
13use super::{tracing::ObjectStoreTracingExt, ObjectStore, ObjectStoreParams};
14use lance_core::error::{Error, LanceOptionExt, Result};
15
16#[cfg(feature = "aws")]
17pub mod aws;
18#[cfg(feature = "azure")]
19pub mod azure;
20#[cfg(feature = "gcp")]
21pub mod gcp;
22pub mod local;
23pub mod memory;
24#[cfg(feature = "oss")]
25pub mod oss;
26
27#[async_trait::async_trait]
28pub trait ObjectStoreProvider: std::fmt::Debug + Sync + Send {
29 async fn new_store(&self, base_path: Url, params: &ObjectStoreParams) -> Result<ObjectStore>;
30
31 fn extract_path(&self, url: &Url) -> Path {
39 Path::from(url.path())
40 }
41}
42
43#[derive(Debug)]
63pub struct ObjectStoreRegistry {
64 providers: RwLock<HashMap<String, Arc<dyn ObjectStoreProvider>>>,
65 active_stores: RwLock<HashMap<(String, ObjectStoreParams), Weak<ObjectStore>>>,
69}
70
71fn cache_url(url: &Url) -> String {
80 if ["file", "file-object-store", "memory"].contains(&url.scheme()) {
81 format!("{}://", url.scheme())
85 } else {
86 let mut url = url.clone();
88 url.set_path("");
89 url.to_string()
90 }
91}
92
93impl ObjectStoreRegistry {
94 pub fn empty() -> Self {
99 Self {
100 providers: RwLock::new(HashMap::new()),
101 active_stores: RwLock::new(HashMap::new()),
102 }
103 }
104
105 pub fn get_provider(&self, scheme: &str) -> Option<Arc<dyn ObjectStoreProvider>> {
107 self.providers
108 .read()
109 .expect("ObjectStoreRegistry lock poisoned")
110 .get(scheme)
111 .cloned()
112 }
113
114 pub fn active_stores(&self) -> Vec<Arc<ObjectStore>> {
119 let mut found_inactive = false;
120 let output = self
121 .active_stores
122 .read()
123 .expect("ObjectStoreRegistry lock poisoned")
124 .values()
125 .filter_map(|weak| match weak.upgrade() {
126 Some(store) => Some(store),
127 None => {
128 found_inactive = true;
129 None
130 }
131 })
132 .collect();
133
134 if found_inactive {
135 let mut cache_lock = self
137 .active_stores
138 .write()
139 .expect("ObjectStoreRegistry lock poisoned");
140 cache_lock.retain(|_, weak| weak.upgrade().is_some());
141 }
142 output
143 }
144
145 pub async fn get_store(
151 &self,
152 base_path: Url,
153 params: &ObjectStoreParams,
154 ) -> Result<Arc<ObjectStore>> {
155 let cache_path = cache_url(&base_path);
156 let cache_key = (cache_path, params.clone());
157
158 {
160 let maybe_store = self
161 .active_stores
162 .read()
163 .ok()
164 .expect_ok()?
165 .get(&cache_key)
166 .cloned();
167 if let Some(store) = maybe_store {
168 if let Some(store) = store.upgrade() {
169 return Ok(store);
170 } else {
171 let mut cache_lock = self
173 .active_stores
174 .write()
175 .expect("ObjectStoreRegistry lock poisoned");
176 if let Some(store) = cache_lock.get(&cache_key) {
177 if store.upgrade().is_none() {
178 cache_lock.remove(&cache_key);
180 }
181 }
182 }
183 }
184 }
185
186 let scheme = base_path.scheme();
187 let Some(provider) = self.get_provider(scheme) else {
188 let mut message = format!("No object store provider found for scheme: '{}'", scheme);
189 if let Ok(providers) = self.providers.read() {
190 let valid_schemes = providers.keys().cloned().collect::<Vec<_>>().join(", ");
191 message.push_str(&format!("\nValid schemes: {}", valid_schemes));
192 }
193
194 return Err(Error::invalid_input(message, location!()));
195 };
196 let mut store = provider.new_store(base_path, params).await?;
197
198 store.inner = store.inner.traced();
199
200 if let Some(wrapper) = ¶ms.object_store_wrapper {
201 store.inner = wrapper.wrap(store.inner);
202 }
203
204 let store = Arc::new(store);
205
206 {
207 let mut cache_lock = self.active_stores.write().ok().expect_ok()?;
209 cache_lock.insert(cache_key, Arc::downgrade(&store));
210 }
211
212 Ok(store)
213 }
214}
215
216impl Default for ObjectStoreRegistry {
217 fn default() -> Self {
218 let mut providers: HashMap<String, Arc<dyn ObjectStoreProvider>> = HashMap::new();
219
220 providers.insert("memory".into(), Arc::new(memory::MemoryStoreProvider));
221 providers.insert("file".into(), Arc::new(local::FileStoreProvider));
222 providers.insert(
228 "file-object-store".into(),
229 Arc::new(local::FileStoreProvider),
230 );
231
232 #[cfg(feature = "aws")]
233 {
234 let aws = Arc::new(aws::AwsStoreProvider);
235 providers.insert("s3".into(), aws.clone());
236 providers.insert("s3+ddb".into(), aws);
237 }
238 #[cfg(feature = "azure")]
239 providers.insert("az".into(), Arc::new(azure::AzureBlobStoreProvider));
240 #[cfg(feature = "gcp")]
241 providers.insert("gs".into(), Arc::new(gcp::GcsStoreProvider));
242 #[cfg(feature = "oss")]
243 providers.insert("oss".into(), Arc::new(oss::OssStoreProvider));
244 Self {
245 providers: RwLock::new(providers),
246 active_stores: RwLock::new(HashMap::new()),
247 }
248 }
249}
250
251impl ObjectStoreRegistry {
252 pub fn insert(&self, scheme: &str, provider: Arc<dyn ObjectStoreProvider>) {
255 self.providers
256 .write()
257 .expect("ObjectStoreRegistry lock poisoned")
258 .insert(scheme.into(), provider);
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// `cache_url` collapses local and in-memory URLs to their bare scheme and
    /// strips only the path from every other URL.
    #[test]
    fn test_cache_url() {
        // (input URL, expected cache key)
        const CASES: [(&str, &str); 5] = [
            ("s3://bucket/path?param=value", "s3://bucket?param=value"),
            ("file:///path/to/file", "file://"),
            ("file-object-store:///path/to/file", "file-object-store://"),
            ("memory:///", "memory://"),
            (
                "http://example.com/path?param=value",
                "http://example.com/?param=value",
            ),
        ];

        for (input, expected) in CASES {
            let parsed = Url::parse(input).unwrap();
            assert_eq!(cache_url(&parsed), expected);
        }
    }
}