// lance_io/object_store.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! Extend [object_store::ObjectStore] functionality

use std::collections::HashMap;
use std::ops::Range;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, SystemTime};

use async_trait::async_trait;
use aws_config::default_provider::credentials::DefaultCredentialsChain;
use aws_credential_types::provider::ProvideCredentials;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use deepsize::DeepSizeOf;
use futures::{future, stream::BoxStream, StreamExt, TryStreamExt};
use lance_core::utils::parse::str_is_truthy;
use lance_core::utils::tokio::get_num_compute_intensive_cpus;
use object_store::aws::{
    AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential, AwsCredentialProvider,
};
use object_store::azure::MicrosoftAzureBuilder;
use object_store::gcp::{GcpCredential, GoogleCloudStorageBuilder};
use object_store::{
    aws::AmazonS3Builder, azure::AzureConfigKey, gcp::GoogleConfigKey, local::LocalFileSystem,
    memory::InMemory, CredentialProvider, Error as ObjectStoreError, Result as ObjectStoreResult,
};
use object_store::{path::Path, ObjectMeta, ObjectStore as OSObjectStore};
use object_store::{ClientOptions, DynObjectStore, RetryConfig, StaticCredentialProvider};
use shellexpand::tilde;
use snafu::location;
use tokio::io::AsyncWriteExt;
use tokio::sync::RwLock;
use url::Url;

use super::local::LocalObjectReader;
mod tracing;
use self::tracing::ObjectStoreTracingExt;
use crate::{object_reader::CloudObjectReader, object_writer::ObjectWriter, traits::Reader};
use lance_core::{Error, Result};

// Local disks tend to do fine with a few threads.
// Note: the number of threads here also impacts the number of files
// we need to read in some situations, so keeping this at 8 keeps the
// RAM usage of our scanner down.
pub const DEFAULT_LOCAL_IO_PARALLELISM: usize = 8;
// Cloud object stores often need many, many threads to saturate the network.
pub const DEFAULT_CLOUD_IO_PARALLELISM: usize = 64;

pub const DEFAULT_DOWNLOAD_RETRY_COUNT: usize = 3;

#[async_trait]
pub trait ObjectStoreExt {
    /// Returns true if the file exists.
    async fn exists(&self, path: &Path) -> Result<bool>;

    /// Read all files recursively, starting from the base directory.
    ///
    /// `unmodified_since` can be specified to only return files that have not been modified since the given time.
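    ///
    /// A minimal usage sketch (not compiled as a doctest; assumes an async
    /// context with `futures::TryStreamExt` in scope):
    ///
    /// ```ignore
    /// let store = ObjectStore::memory();
    /// let base = Path::from("base");
    /// let mut files = store.inner.read_dir_all(&base, None).await?;
    /// while let Some(meta) = files.try_next().await? {
    ///     println!("{} ({} bytes)", meta.location, meta.size);
    /// }
    /// ```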
    async fn read_dir_all<'a>(
        &'a self,
        dir_path: impl Into<&Path> + Send,
        unmodified_since: Option<DateTime<Utc>>,
    ) -> Result<BoxStream<'a, Result<ObjectMeta>>>;
}

#[async_trait]
impl<O: OSObjectStore + ?Sized> ObjectStoreExt for O {
    async fn read_dir_all<'a>(
        &'a self,
        dir_path: impl Into<&Path> + Send,
        unmodified_since: Option<DateTime<Utc>>,
    ) -> Result<BoxStream<'a, Result<ObjectMeta>>> {
        let mut output = self.list(Some(dir_path.into()));
        if let Some(unmodified_since_val) = unmodified_since {
            output = output
                .try_filter(move |file| future::ready(file.last_modified < unmodified_since_val))
                .boxed();
        }
        Ok(output.map_err(|e| e.into()).boxed())
    }

    async fn exists(&self, path: &Path) -> Result<bool> {
        match self.head(path).await {
            Ok(_) => Ok(true),
            Err(object_store::Error::NotFound { path: _, source: _ }) => Ok(false),
            Err(e) => Err(e.into()),
        }
    }
}

/// Wraps [ObjectStore](object_store::ObjectStore)
#[derive(Debug, Clone)]
pub struct ObjectStore {
    // Inner object store
    pub inner: Arc<dyn OSObjectStore>,
    scheme: String,
    block_size: usize,
    /// Whether to use constant size upload parts for multipart uploads. This
    /// is only necessary for Cloudflare R2.
    pub use_constant_size_upload_parts: bool,
    /// Whether we can assume that the list of files is lexically ordered. This
    /// is true for object stores, but not for local filesystems.
    pub list_is_lexically_ordered: bool,
    io_parallelism: usize,
    /// Number of times to retry a failed download
    download_retry_count: usize,
}

impl DeepSizeOf for ObjectStore {
    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
        // We aren't counting `inner` here, which is problematic, but an ObjectStore
        // shouldn't be too big.  The only exception might be the write cache, but if
        // the write cache has data, it means we're using it somewhere else that isn't
        // a cache, so that doesn't really count.
        self.scheme.deep_size_of_children(context) + self.block_size.deep_size_of_children(context)
    }
}

impl std::fmt::Display for ObjectStore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "ObjectStore({})", self.scheme)
    }
}

pub trait ObjectStoreProvider: std::fmt::Debug + Sync + Send {
    fn new_store(&self, base_path: Url, params: &ObjectStoreParams) -> Result<ObjectStore>;
}

#[derive(Default, Debug)]
pub struct ObjectStoreRegistry {
    providers: HashMap<String, Arc<dyn ObjectStoreProvider>>,
}

impl ObjectStoreRegistry {
    pub fn insert(&mut self, scheme: &str, provider: Arc<dyn ObjectStoreProvider>) {
        self.providers.insert(scheme.into(), provider);
    }
}

const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";

/// Adapt an AWS SDK credential provider into an object_store credential provider
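///
/// A hedged construction sketch (assumes an async context; the refresh offset
/// mirrors the default used by [ObjectStoreParams]):
///
/// ```ignore
/// let chain = DefaultCredentialsChain::builder().build().await;
/// let adapter = AwsCredentialAdapter::new(Arc::new(chain), Duration::from_secs(60));
/// let creds = adapter.get_credential().await?;
/// ```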
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    pub inner: Arc<dyn ProvideCredentials>,

    // A RefCell can't be shared across threads, so we use an RwLock-guarded HashMap
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    // The amount of time before expiry to refresh credentials
    credentials_refresh_offset: Duration,
}

impl AwsCredentialAdapter {
    pub fn new(
        provider: Arc<dyn ProvideCredentials>,
        credentials_refresh_offset: Duration,
    ) -> Self {
        Self {
            inner: provider,
            cache: Arc::new(RwLock::new(HashMap::new())),
            credentials_refresh_offset,
        }
    }
}

#[async_trait]
impl CredentialProvider for AwsCredentialAdapter {
    type Credential = ObjectStoreAwsCredential;

    async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
        let cached_creds = {
            let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
            let expired = cache_value
                .clone()
                .map(|cred| {
                    cred.expiry()
                        .map(|exp| {
                            exp.checked_sub(self.credentials_refresh_offset)
                                .expect("this time should always be valid")
                                < SystemTime::now()
                        })
                        // No expiry means the credential never expires.
                        .unwrap_or(false)
                })
                .unwrap_or(true); // No cached credential is treated as expired.
            if expired {
                None
            } else {
                cache_value.clone()
            }
        };

        if let Some(creds) = cached_creds {
            Ok(Arc::new(Self::Credential {
                key_id: creds.access_key_id().to_string(),
                secret_key: creds.secret_access_key().to_string(),
                token: creds.session_token().map(|s| s.to_string()),
            }))
        } else {
            let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
                |e| Error::Internal {
                    message: format!("Failed to get AWS credentials: {}", e),
                    location: location!(),
                },
            )?);

            self.cache
                .write()
                .await
                .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());

            Ok(Arc::new(Self::Credential {
                key_id: refreshed_creds.access_key_id().to_string(),
                secret_key: refreshed_creds.secret_access_key().to_string(),
                token: refreshed_creds.session_token().map(|s| s.to_string()),
            }))
        }
    }
}

/// Figure out the S3 region of the bucket.
///
/// This resolves in order of precedence:
/// 1. The region provided in the storage options
/// 2. If no endpoint is set, the region returned by the S3 API for the bucket
///
/// It returns None if no region is provided and an endpoint is set.
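///
/// A minimal sketch of the first precedence rule (hypothetical bucket name;
/// no network call is made because the region is set explicitly):
///
/// ```ignore
/// let url = Url::parse("s3://my-bucket/my-table")?;
/// let mut opts = HashMap::new();
/// opts.insert(AmazonS3ConfigKey::Region, "us-east-1".to_string());
/// assert_eq!(resolve_s3_region(&url, &opts).await?, Some("us-east-1".into()));
/// ```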
async fn resolve_s3_region(
    url: &Url,
    storage_options: &HashMap<AmazonS3ConfigKey, String>,
) -> Result<Option<String>> {
    if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
        Ok(Some(region.clone()))
    } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
        // If no endpoint is set, we can assume this is AWS S3 and the region
        // can be resolved from the bucket.
        let bucket = url.host_str().ok_or_else(|| {
            Error::invalid_input(
                format!("Could not parse bucket from url: {}", url),
                location!(),
            )
        })?;

        let mut client_options = ClientOptions::default();
        for (key, value) in storage_options {
            if let AmazonS3ConfigKey::Client(client_key) = key {
                client_options = client_options.with_config(*client_key, value.clone());
            }
        }

        let bucket_region =
            object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
        Ok(Some(bucket_region))
    } else {
        Ok(None)
    }
}

/// Build AWS credentials
///
/// This resolves credentials from the following sources, in order:
/// 1. An explicit `credentials` provider
/// 2. Explicit credentials in storage_options (as in `aws_access_key_id`,
///    `aws_secret_access_key`, `aws_session_token`)
/// 3. The default credential provider chain from the AWS SDK.
///
/// `credentials_refresh_offset` is the amount of time before expiry to refresh credentials.
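///
/// A hedged usage sketch (no explicit credentials, so this falls through to
/// the AWS default provider chain):
///
/// ```ignore
/// let (provider, region) =
///     build_aws_credential(Duration::from_secs(60), None, None, None).await?;
/// ```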
pub async fn build_aws_credential(
    credentials_refresh_offset: Duration,
    credentials: Option<AwsCredentialProvider>,
    storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
    region: Option<String>,
) -> Result<(AwsCredentialProvider, String)> {
    // TODO: make this return no credential provider when not using AWS
    use aws_config::meta::region::RegionProviderChain;
    const DEFAULT_REGION: &str = "us-west-2";

    let region = if let Some(region) = region {
        region
    } else {
        RegionProviderChain::default_provider()
            .or_else(DEFAULT_REGION)
            .region()
            .await
            .map(|r| r.as_ref().to_string())
            .unwrap_or(DEFAULT_REGION.to_string())
    };

    if let Some(creds) = credentials {
        Ok((creds, region))
    } else if let Some(creds) = storage_options.and_then(extract_static_s3_credentials) {
        Ok((Arc::new(creds), region))
    } else {
        let credentials_provider = DefaultCredentialsChain::builder().build().await;

        Ok((
            Arc::new(AwsCredentialAdapter::new(
                Arc::new(credentials_provider),
                credentials_refresh_offset,
            )),
            region,
        ))
    }
}

fn extract_static_s3_credentials(
    options: &HashMap<AmazonS3ConfigKey, String>,
) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
    let key_id = options
        .get(&AmazonS3ConfigKey::AccessKeyId)
        .map(|s| s.to_string());
    let secret_key = options
        .get(&AmazonS3ConfigKey::SecretAccessKey)
        .map(|s| s.to_string());
    let token = options
        .get(&AmazonS3ConfigKey::Token)
        .map(|s| s.to_string());
    match (key_id, secret_key, token) {
        (Some(key_id), Some(secret_key), token) => {
            Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
                key_id,
                secret_key,
                token,
            }))
        }
        _ => None,
    }
}

pub trait WrappingObjectStore: std::fmt::Debug + Send + Sync {
    fn wrap(&self, original: Arc<dyn OSObjectStore>) -> Arc<dyn OSObjectStore>;
}

/// Parameters to create an [ObjectStore]
#[derive(Debug, Clone)]
pub struct ObjectStoreParams {
    pub block_size: Option<usize>,
    pub object_store: Option<(Arc<DynObjectStore>, Url)>,
    pub s3_credentials_refresh_offset: Duration,
    pub aws_credentials: Option<AwsCredentialProvider>,
    pub object_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
    pub storage_options: Option<HashMap<String, String>>,
    /// Use constant size upload parts for multipart uploads. Only necessary
    /// for Cloudflare R2, which doesn't support variable size parts. When this
    /// is false, the max upload size is 2.5TB. When this is true, the max size
    /// is 50GB.
    pub use_constant_size_upload_parts: bool,
    pub list_is_lexically_ordered: Option<bool>,
}

impl Default for ObjectStoreParams {
    fn default() -> Self {
        Self {
            object_store: None,
            block_size: None,
            s3_credentials_refresh_offset: Duration::from_secs(60),
            aws_credentials: None,
            object_store_wrapper: None,
            storage_options: None,
            use_constant_size_upload_parts: false,
            list_is_lexically_ordered: None,
        }
    }
}

impl ObjectStoreParams {
    /// Create a new instance of [`ObjectStoreParams`] based on the AWS credentials.
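    ///
    /// A hedged sketch (`my_provider` is a hypothetical [AwsCredentialProvider]):
    ///
    /// ```ignore
    /// let params = ObjectStoreParams::with_aws_credentials(
    ///     Some(my_provider),
    ///     Some("us-east-1".to_string()),
    /// );
    /// ```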
    pub fn with_aws_credentials(
        aws_credentials: Option<AwsCredentialProvider>,
        region: Option<String>,
    ) -> Self {
        Self {
            aws_credentials,
            storage_options: region
                .map(|region| [("region".into(), region)].iter().cloned().collect()),
            ..Default::default()
        }
    }
}

impl ObjectStore {
    /// Parse from a string URI.
    ///
    /// Returns the ObjectStore instance and the absolute path to the object.
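    ///
    /// A minimal sketch (assumes an async context; `memory://` avoids any
    /// network or filesystem access):
    ///
    /// ```ignore
    /// let (store, path) = ObjectStore::from_uri("memory:///my/table.lance").await?;
    /// assert_eq!(path.to_string(), "my/table.lance");
    /// ```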
    pub async fn from_uri(uri: &str) -> Result<(Self, Path)> {
        let registry = Arc::new(ObjectStoreRegistry::default());

        Self::from_uri_and_params(registry, uri, &ObjectStoreParams::default()).await
    }

    /// Parse from a string URI.
    ///
    /// Returns the ObjectStore instance and the absolute path to the object.
    pub async fn from_uri_and_params(
        registry: Arc<ObjectStoreRegistry>,
        uri: &str,
        params: &ObjectStoreParams,
    ) -> Result<(Self, Path)> {
        if let Some((store, path)) = params.object_store.as_ref() {
            let mut inner = store.clone();
            if let Some(wrapper) = params.object_store_wrapper.as_ref() {
                inner = wrapper.wrap(inner);
            }
            let store = Self {
                inner,
                scheme: path.scheme().to_string(),
                block_size: params.block_size.unwrap_or(64 * 1024),
                use_constant_size_upload_parts: params.use_constant_size_upload_parts,
                list_is_lexically_ordered: params.list_is_lexically_ordered.unwrap_or_default(),
                io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
                download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT,
            };
            let path = Path::from(path.path());
            return Ok((store, path));
        }
        let (object_store, path) = match Url::parse(uri) {
            Ok(url) if url.scheme().len() == 1 && cfg!(windows) => {
                // On Windows, the drive is parsed as a scheme
                Self::from_path(uri)
            }
            Ok(url) => {
                let store = Self::new_from_url(registry, url.clone(), params.clone()).await?;
                Ok((store, Path::from(url.path())))
            }
            Err(_) => Self::from_path(uri),
        }?;

        Ok((
            Self {
                inner: params
                    .object_store_wrapper
                    .as_ref()
                    .map(|w| w.wrap(object_store.inner.clone()))
                    .unwrap_or(object_store.inner),
                ..object_store
            },
            path,
        ))
    }

    pub fn from_path_with_scheme(str_path: &str, scheme: &str) -> Result<(Self, Path)> {
        let expanded = tilde(str_path).to_string();

        let mut expanded_path = path_abs::PathAbs::new(expanded)
            .unwrap()
            .as_path()
            .to_path_buf();
        // path_abs::PathAbs::new(".") returns an empty string.
        if let Some(s) = expanded_path.as_path().to_str() {
            if s.is_empty() {
                expanded_path = std::env::current_dir()?;
            }
        }
        Ok((
            Self {
                inner: Arc::new(LocalFileSystem::new()).traced(),
                scheme: String::from(scheme),
                block_size: 4 * 1024, // 4KB block size
                use_constant_size_upload_parts: false,
                list_is_lexically_ordered: false,
                io_parallelism: DEFAULT_LOCAL_IO_PARALLELISM,
                download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT,
            },
            Path::from_absolute_path(expanded_path.as_path())?,
        ))
    }

    pub fn from_path(str_path: &str) -> Result<(Self, Path)> {
        Self::from_path_with_scheme(str_path, "file")
    }

    async fn new_from_url(
        registry: Arc<ObjectStoreRegistry>,
        url: Url,
        params: ObjectStoreParams,
    ) -> Result<Self> {
        configure_store(registry, url.as_str(), params).await
    }

    /// Local object store.
    pub fn local() -> Self {
        Self {
            inner: Arc::new(LocalFileSystem::new()).traced(),
            scheme: String::from("file"),
            block_size: 4 * 1024, // 4KB block size
            use_constant_size_upload_parts: false,
            list_is_lexically_ordered: false,
            io_parallelism: DEFAULT_LOCAL_IO_PARALLELISM,
            download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT,
        }
    }

    /// Create an in-memory object store, mainly for testing.
    pub fn memory() -> Self {
        Self {
            inner: Arc::new(InMemory::new()).traced(),
            scheme: String::from("memory"),
            block_size: 4 * 1024,
            use_constant_size_upload_parts: false,
            list_is_lexically_ordered: true,
            io_parallelism: get_num_compute_intensive_cpus(),
            download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT,
        }
    }

    /// Returns true if the object store points to a local file system.
    pub fn is_local(&self) -> bool {
        self.scheme == "file"
    }

    pub fn is_cloud(&self) -> bool {
        self.scheme != "file" && self.scheme != "memory"
    }

    pub fn block_size(&self) -> usize {
        self.block_size
    }

    pub fn set_block_size(&mut self, new_size: usize) {
        self.block_size = new_size;
    }

    pub fn set_io_parallelism(&mut self, io_parallelism: usize) {
        self.io_parallelism = io_parallelism;
    }

    pub fn io_parallelism(&self) -> usize {
        // The LANCE_IO_THREADS environment variable overrides the configured
        // value; note that a non-integer value will panic here.
        std::env::var("LANCE_IO_THREADS")
            .map(|val| val.parse::<usize>().unwrap())
            .unwrap_or(self.io_parallelism)
    }

    /// Open a file at the given path.
    ///
    /// Parameters
    /// - `path`: Absolute path to the file.
    pub async fn open(&self, path: &Path) -> Result<Box<dyn Reader>> {
        match self.scheme.as_str() {
            "file" => LocalObjectReader::open(path, self.block_size, None).await,
            _ => Ok(Box::new(CloudObjectReader::new(
                self.inner.clone(),
                path.clone(),
                self.block_size,
                None,
                self.download_retry_count,
            )?)),
        }
    }

    /// Open a reader for a file with known size.
    ///
    /// This size may either have been retrieved from a list operation or
    /// cached metadata. By passing in the known size, we can skip a HEAD / metadata
    /// call.
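    ///
    /// A hedged sketch (assumes `store: ObjectStore` and `path: Path` are in
    /// scope; here the size comes from an explicit `size` call, though it would
    /// normally come from a list operation or cached metadata):
    ///
    /// ```ignore
    /// let size = store.size(&path).await?;
    /// let reader = store.open_with_size(&path, size).await?; // no extra HEAD request
    /// ```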
    pub async fn open_with_size(&self, path: &Path, known_size: usize) -> Result<Box<dyn Reader>> {
        match self.scheme.as_str() {
            "file" => LocalObjectReader::open(path, self.block_size, Some(known_size)).await,
            _ => Ok(Box::new(CloudObjectReader::new(
                self.inner.clone(),
                path.clone(),
                self.block_size,
                Some(known_size),
                self.download_retry_count,
            )?)),
        }
    }

    /// Create an [ObjectWriter] from a local [std::path::Path]
    pub async fn create_local_writer(path: &std::path::Path) -> Result<ObjectWriter> {
        let object_store = Self::local();
        let os_path = Path::from(path.to_str().unwrap());
        object_store.create(&os_path).await
    }

    /// Open a [Reader] from a local [std::path::Path]
    pub async fn open_local(path: &std::path::Path) -> Result<Box<dyn Reader>> {
        let object_store = Self::local();
        let os_path = Path::from(path.to_str().unwrap());
        object_store.open(&os_path).await
    }

    /// Create a new file.
    pub async fn create(&self, path: &Path) -> Result<ObjectWriter> {
        ObjectWriter::new(self, path).await
    }

    /// A helper function to create a file and write content to it.
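    ///
    /// A minimal sketch (assumes an async context; uses the in-memory store):
    ///
    /// ```ignore
    /// let store = ObjectStore::memory();
    /// let path = Path::from("hello.txt");
    /// store.put(&path, b"hello").await?;
    /// assert!(store.exists(&path).await?);
    /// ```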
    pub async fn put(&self, path: &Path, content: &[u8]) -> Result<()> {
        let mut writer = self.create(path).await?;
        writer.write_all(content).await?;
        writer.shutdown().await
    }

    pub async fn delete(&self, path: &Path) -> Result<()> {
        self.inner.delete(path).await?;
        Ok(())
    }

    pub async fn copy(&self, from: &Path, to: &Path) -> Result<()> {
        Ok(self.inner.copy(from, to).await?)
    }

    /// Read a directory (starting from the base directory) and return all sub-paths in the directory.
    pub async fn read_dir(&self, dir_path: impl Into<Path>) -> Result<Vec<String>> {
        let path = dir_path.into();
        let path = Path::parse(&path)?;
        let output = self.inner.list_with_delimiter(Some(&path)).await?;
        Ok(output
            .common_prefixes
            .iter()
            .chain(output.objects.iter().map(|o| &o.location))
            .map(|s| s.filename().unwrap().to_string())
            .collect())
    }

    /// Read all files recursively, starting from the base directory.
    ///
    /// `unmodified_since` can be specified to only return files that have not been modified since the given time.
    pub async fn read_dir_all(
        &self,
        dir_path: impl Into<&Path> + Send,
        unmodified_since: Option<DateTime<Utc>>,
    ) -> Result<BoxStream<Result<ObjectMeta>>> {
        self.inner.read_dir_all(dir_path, unmodified_since).await
    }

    /// Remove a directory recursively.
    pub async fn remove_dir_all(&self, dir_path: impl Into<Path>) -> Result<()> {
        let path = dir_path.into();
        let path = Path::parse(&path)?;

        if self.is_local() {
            // The local file system needs to delete directories as well.
            return super::local::remove_dir_all(&path);
        }
        let sub_entries = self
            .inner
            .list(Some(&path))
            .map(|m| m.map(|meta| meta.location))
            .boxed();
        self.inner
            .delete_stream(sub_entries)
            .try_collect::<Vec<_>>()
            .await?;
        Ok(())
    }

    pub fn remove_stream<'a>(
        &'a self,
        locations: BoxStream<'a, Result<Path>>,
    ) -> BoxStream<'a, Result<Path>> {
        self.inner
            .delete_stream(locations.err_into::<ObjectStoreError>().boxed())
            .err_into::<Error>()
            .boxed()
    }

    /// Check if a file exists.
    pub async fn exists(&self, path: &Path) -> Result<bool> {
        match self.inner.head(path).await {
            Ok(_) => Ok(true),
            Err(object_store::Error::NotFound { path: _, source: _ }) => Ok(false),
            Err(e) => Err(e.into()),
        }
    }

    /// Get the file size.
    pub async fn size(&self, path: &Path) -> Result<usize> {
        Ok(self.inner.head(path).await?.size)
    }

    /// Convenience function to open a reader and read all the bytes
    pub async fn read_one_all(&self, path: &Path) -> Result<Bytes> {
        let reader = self.open(path).await?;
        Ok(reader.get_all().await?)
    }

    /// Convenience function to open a reader and make a single request
    ///
    /// If you will be making multiple requests to the path, it is more efficient to call [`Self::open`]
    /// and then call [`Reader::get_range`] multiple times.
    pub async fn read_one_range(&self, path: &Path, range: Range<usize>) -> Result<Bytes> {
        let reader = self.open(path).await?;
        Ok(reader.get_range(range).await?)
    }
}

/// Options that can be set for multiple object stores
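///
/// A minimal parsing sketch (matching is case-insensitive):
///
/// ```ignore
/// let key: LanceConfigKey = "DOWNLOAD_RETRY_COUNT".parse()?;
/// assert_eq!(key, LanceConfigKey::DownloadRetryCount);
/// ```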
#[derive(PartialEq, Eq, Hash, Clone, Debug, Copy)]
pub enum LanceConfigKey {
    /// Number of times to retry a download that fails
    DownloadRetryCount,
}

impl FromStr for LanceConfigKey {
    type Err = Error;

    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s.to_ascii_lowercase().as_str() {
            "download_retry_count" => Ok(Self::DownloadRetryCount),
            _ => Err(Error::InvalidInput {
                source: format!("Invalid LanceConfigKey: {}", s).into(),
                location: location!(),
            }),
        }
    }
}

#[derive(Clone, Debug, Default)]
pub struct StorageOptions(pub HashMap<String, String>);

impl StorageOptions {
    /// Create a new instance of [`StorageOptions`]
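    ///
    /// A minimal sketch (keys use the provider config names; values behave
    /// like the matching environment variables):
    ///
    /// ```ignore
    /// let options = StorageOptions::new(HashMap::from([
    ///     ("allow_http".to_string(), "true".to_string()),
    ///     ("client_max_retries".to_string(), "5".to_string()),
    /// ]));
    /// assert!(options.allow_http());
    /// assert_eq!(options.client_max_retries(), 5);
    /// ```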
    pub fn new(options: HashMap<String, String>) -> Self {
        let mut options = options;
        if let Ok(value) = std::env::var("AZURE_STORAGE_ALLOW_HTTP") {
            options.insert("allow_http".into(), value);
        }
        if let Ok(value) = std::env::var("AZURE_STORAGE_USE_HTTP") {
            options.insert("allow_http".into(), value);
        }
        if let Ok(value) = std::env::var("AWS_ALLOW_HTTP") {
            options.insert("allow_http".into(), value);
        }
        if let Ok(value) = std::env::var("OBJECT_STORE_CLIENT_MAX_RETRIES") {
            options.insert("client_max_retries".into(), value);
        }
        if let Ok(value) = std::env::var("OBJECT_STORE_CLIENT_RETRY_TIMEOUT") {
            options.insert("client_retry_timeout".into(), value);
        }
        Self(options)
    }

    /// Add values from the environment to storage options
    pub fn with_env_azure(&mut self) {
        for (os_key, os_value) in std::env::vars_os() {
            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
                if let Ok(config_key) = AzureConfigKey::from_str(&key.to_ascii_lowercase()) {
                    if !self.0.contains_key(config_key.as_ref()) {
                        self.0
                            .insert(config_key.as_ref().to_string(), value.to_string());
                    }
                }
            }
        }
    }

    /// Add values from the environment to storage options
    pub fn with_env_gcs(&mut self) {
        for (os_key, os_value) in std::env::vars_os() {
            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
                if let Ok(config_key) = GoogleConfigKey::from_str(&key.to_ascii_lowercase()) {
                    if !self.0.contains_key(config_key.as_ref()) {
                        self.0
                            .insert(config_key.as_ref().to_string(), value.to_string());
                    }
                }
            }
        }
    }

    /// Add values from the environment to storage options
    pub fn with_env_s3(&mut self) {
        for (os_key, os_value) in std::env::vars_os() {
            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
                if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
                    if !self.0.contains_key(config_key.as_ref()) {
                        self.0
                            .insert(config_key.as_ref().to_string(), value.to_string());
                    }
                }
            }
        }
    }

    /// Denotes if insecure connections via http are allowed
    pub fn allow_http(&self) -> bool {
        self.0.iter().any(|(key, value)| {
            key.to_ascii_lowercase().contains("allow_http") && str_is_truthy(value)
        })
    }

    /// Number of times to retry a download that fails
    pub fn download_retry_count(&self) -> usize {
        self.0
            .iter()
            .find(|(key, _)| key.eq_ignore_ascii_case("download_retry_count"))
            .map(|(_, value)| value.parse::<usize>().unwrap_or(3))
            .unwrap_or(3)
    }

    /// Max retry times to set in RetryConfig for the object store client
    pub fn client_max_retries(&self) -> usize {
        self.0
            .iter()
            .find(|(key, _)| key.eq_ignore_ascii_case("client_max_retries"))
            .and_then(|(_, value)| value.parse::<usize>().ok())
            .unwrap_or(10)
    }

    /// Retry timeout in seconds to set in RetryConfig for the object store client
    pub fn client_retry_timeout(&self) -> u64 {
        self.0
            .iter()
            .find(|(key, _)| key.eq_ignore_ascii_case("client_retry_timeout"))
            .and_then(|(_, value)| value.parse::<u64>().ok())
            .unwrap_or(180)
    }

    /// Subset of options relevant for azure storage
    pub fn as_azure_options(&self) -> HashMap<AzureConfigKey, String> {
        self.0
            .iter()
            .filter_map(|(key, value)| {
                let az_key = AzureConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
                Some((az_key, value.clone()))
            })
            .collect()
    }

    /// Subset of options relevant for s3 storage
    pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
        self.0
            .iter()
            .filter_map(|(key, value)| {
                let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
                Some((s3_key, value.clone()))
            })
            .collect()
    }

    /// Subset of options relevant for gcs storage
    pub fn as_gcs_options(&self) -> HashMap<GoogleConfigKey, String> {
        self.0
            .iter()
            .filter_map(|(key, value)| {
                let gcs_key = GoogleConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
                Some((gcs_key, value.clone()))
            })
            .collect()
    }

    pub fn get(&self, key: &str) -> Option<&String> {
        self.0.get(key)
    }
}

impl From<HashMap<String, String>> for StorageOptions {
    fn from(value: HashMap<String, String>) -> Self {
        Self::new(value)
    }
}

async fn configure_store(
    registry: Arc<ObjectStoreRegistry>,
    url: &str,
    options: ObjectStoreParams,
) -> Result<ObjectStore> {
    let mut storage_options = StorageOptions(options.storage_options.clone().unwrap_or_default());
    let download_retry_count = storage_options.download_retry_count();
    let mut url = ensure_table_uri(url)?;
    // Block size: On local file systems, we use 4KB block size. On cloud
    // object stores, we use 64KB block size. This is generally the largest
    // block size where we don't see a latency penalty.
    let file_block_size = options.block_size.unwrap_or(4 * 1024);
    let cloud_block_size = options.block_size.unwrap_or(64 * 1024);
    let max_retries = storage_options.client_max_retries();
    let retry_timeout = storage_options.client_retry_timeout();
    let retry_config = RetryConfig {
        backoff: Default::default(),
        max_retries,
        retry_timeout: Duration::from_secs(retry_timeout),
    };
    match url.scheme() {
        "s3" | "s3+ddb" => {
            storage_options.with_env_s3();

            // if url.scheme() == "s3+ddb" && options.commit_handler.is_some() {
            //     return Err(Error::InvalidInput {
            //         source: "`s3+ddb://` scheme and custom commit handler are mutually exclusive"
            //             .into(),
            //         location: location!(),
            //     });
            // }

            let mut storage_options = storage_options.as_s3_options();
            let region = resolve_s3_region(&url, &storage_options).await?;
            let (aws_creds, region) = build_aws_credential(
                options.s3_credentials_refresh_offset,
                options.aws_credentials.clone(),
                Some(&storage_options),
                region,
            )
            .await?;

            // This will be the default in the next version of object_store.
            // https://github.com/apache/arrow-rs/pull/7181
            storage_options
                .entry(AmazonS3ConfigKey::ConditionalPut)
                .or_insert_with(|| "etag".to_string());

            // Cloudflare does not support varying part sizes.
            let use_constant_size_upload_parts = storage_options
                .get(&AmazonS3ConfigKey::Endpoint)
                .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
                .unwrap_or(false);

            // Before creating the OSObjectStore we need to rewrite the url to drop the ddb-related parts.
            url.set_scheme("s3").map_err(|()| Error::Internal {
                message: "could not set scheme".into(),
                location: location!(),
            })?;

            url.set_query(None);

            // We can't use parse_url_opts here because we need to manually set the credentials provider.
            let mut builder = AmazonS3Builder::new();
            for (key, value) in storage_options {
                builder = builder.with_config(key, value);
            }
            builder = builder
                .with_url(url.as_ref())
                .with_credentials(aws_creds)
                .with_retry(retry_config)
                .with_region(region);
            let store = builder.build()?;

            Ok(ObjectStore {
                inner: Arc::new(store).traced(),
                scheme: String::from(url.scheme()),
                block_size: cloud_block_size,
                use_constant_size_upload_parts,
                list_is_lexically_ordered: true,
                io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
                download_retry_count,
            })
        }
        "gs" => {
            storage_options.with_env_gcs();
            let mut builder = GoogleCloudStorageBuilder::new()
                .with_url(url.as_ref())
                .with_retry(retry_config);
            for (key, value) in storage_options.as_gcs_options() {
                builder = builder.with_config(key, value);
            }
            let token_key = "google_storage_token";
            if let Some(storage_token) = storage_options.get(token_key) {
                let credential = GcpCredential {
                    bearer: storage_token.to_string(),
                };
                let credential_provider = Arc::new(StaticCredentialProvider::new(credential)) as _;
                builder = builder.with_credentials(credential_provider);
            }
            let store = builder.build()?;
            let store = Arc::new(store).traced();

            Ok(ObjectStore {
                inner: store,
                scheme: String::from("gs"),
                block_size: cloud_block_size,
                use_constant_size_upload_parts: false,
                list_is_lexically_ordered: true,
                io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
                download_retry_count,
            })
        }
        "az" => {
            storage_options.with_env_azure();
            let mut builder = MicrosoftAzureBuilder::new()
                .with_url(url.as_ref())
                .with_retry(retry_config);
            for (key, value) in storage_options.as_azure_options() {
                builder = builder.with_config(key, value);
            }
            let store = builder.build()?;
            let store = Arc::new(store).traced();

            Ok(ObjectStore {
                inner: store,
                scheme: String::from("az"),
                block_size: cloud_block_size,
                use_constant_size_upload_parts: false,
                list_is_lexically_ordered: true,
                io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
                download_retry_count,
            })
        }
        // For "file" we have bypass logic that uses `tokio::fs` directly to lower
        // overhead. However, this makes testing harder because we can't exercise the
        // same code path. "file-object-store" forces a local file system dataset to
        // use the same code path as cloud object stores.
        "file" => {
            let mut object_store = ObjectStore::from_path(url.path())?.0;
            object_store.set_block_size(file_block_size);
            Ok(object_store)
        }
        "file-object-store" => {
            let mut object_store =
                ObjectStore::from_path_with_scheme(url.path(), "file-object-store")?.0;
            object_store.set_block_size(file_block_size);
            Ok(object_store)
        }
        "memory" => Ok(ObjectStore {
            inner: Arc::new(InMemory::new()).traced(),
            scheme: String::from("memory"),
            block_size: file_block_size,
            use_constant_size_upload_parts: false,
            list_is_lexically_ordered: true,
            io_parallelism: get_num_compute_intensive_cpus(),
            download_retry_count,
        }),
        unknown_scheme => {
            if let Some(provider) = registry.providers.get(unknown_scheme) {
                provider.new_store(url, &options)
            } else {
                let err = lance_core::Error::from(object_store::Error::NotSupported {
                    source: format!("Unsupported URI scheme: {} in url {}", unknown_scheme, url)
                        .into(),
                });
                Err(err)
            }
        }
    }
}

impl ObjectStore {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        store: Arc<DynObjectStore>,
        location: Url,
        block_size: Option<usize>,
        wrapper: Option<Arc<dyn WrappingObjectStore>>,
        use_constant_size_upload_parts: bool,
        list_is_lexically_ordered: bool,
        io_parallelism: usize,
        download_retry_count: usize,
    ) -> Self {
        let scheme = location.scheme();
        let block_size = block_size.unwrap_or_else(|| infer_block_size(scheme));

        let store = match wrapper {
            Some(wrapper) => wrapper.wrap(store),
            None => store,
        };

        Self {
            inner: store,
            scheme: scheme.into(),
            block_size,
            use_constant_size_upload_parts,
            list_is_lexically_ordered,
            io_parallelism,
            download_retry_count,
        }
    }
}

fn infer_block_size(scheme: &str) -> usize {
    // Block size: On local file systems, we use 4KB block size. On cloud
    // object stores, we use 64KB block size. This is generally the largest
    // block size where we don't see a latency penalty.
    match scheme {
        "file" => 4 * 1024,
        _ => 64 * 1024,
    }
}

/// Attempt to create a Url from the given table location.
///
/// The location could be:
///  * A valid URL, which will be parsed and returned
///  * A path to an existing local directory, which will be canonicalized and
///    converted to a URL.
///
/// Extra slashes will be removed from the end of the path as well.
///
/// Will return an error if the location is not valid, for example if it is a
/// local path that does not exist.
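///
/// A minimal sketch (trailing slashes are trimmed from URL paths):
///
/// ```ignore
/// let url = ensure_table_uri("s3://bucket/path/")?;
/// assert_eq!(url.as_str(), "s3://bucket/path");
/// ```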
pub fn ensure_table_uri(table_uri: impl AsRef<str>) -> Result<Url> {
    let table_uri = table_uri.as_ref();

    enum UriType {
        LocalPath(PathBuf),
        Url(Url),
    }
    let uri_type: UriType = if let Ok(url) = Url::parse(table_uri) {
        if url.scheme() == "file" {
            UriType::LocalPath(url.to_file_path().map_err(|err| {
                let msg = format!("Invalid table location: {}\nError: {:?}", table_uri, err);
                Error::InvalidTableLocation { message: msg }
            })?)
        // NOTE this check is required to support absolute windows paths which may properly parse as url
        } else {
            UriType::Url(url)
        }
    } else {
        UriType::LocalPath(PathBuf::from(table_uri))
    };

    // If it is a local path, it must already exist so that we can canonicalize it.
    let mut url = match uri_type {
        UriType::LocalPath(path) => {
            let path = std::fs::canonicalize(path).map_err(|err| Error::DatasetNotFound {
                path: table_uri.to_string(),
                source: Box::new(err),
                location: location!(),
            })?;
            Url::from_directory_path(path).map_err(|_| {
                let msg = format!(
                    "Could not construct a URL from canonicalized path: {}.\n\
                  Something must be very wrong with the table path.",
                    table_uri
                );
                Error::InvalidTableLocation { message: msg }
            })?
        }
        UriType::Url(url) => url,
    };

    let trimmed_path = url.path().trim_end_matches('/').to_owned();
    url.set_path(&trimmed_path);
    Ok(url)
}

lazy_static::lazy_static! {
  static ref KNOWN_SCHEMES: Vec<&'static str> =
      Vec::from([
        "s3",
        "s3+ddb",
        "gs",
        "az",
        "file",
        "file-object-store",
        "memory"
      ]);
}

#[cfg(test)]
mod tests {
    use super::*;
    use parquet::data_type::AsBytes;
    use rstest::rstest;
    use std::env::set_current_dir;
    use std::fs::{create_dir_all, write};
    use std::path::Path as StdPath;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Write test content to file.
    fn write_to_file(path_str: &str, contents: &str) -> std::io::Result<()> {
        let expanded = tilde(path_str).to_string();
        let path = StdPath::new(&expanded);
        std::fs::create_dir_all(path.parent().unwrap())?;
        write(path, contents)
    }

    async fn read_from_store(store: ObjectStore, path: &Path) -> Result<String> {
        let test_file_store = store.open(path).await.unwrap();
        let size = test_file_store.size().await.unwrap();
        let bytes = test_file_store.get_range(0..size).await.unwrap();
        let contents = String::from_utf8(bytes.to_vec()).unwrap();
        Ok(contents)
    }

    #[tokio::test]
    async fn test_absolute_paths() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path().to_str().unwrap().to_owned();
        write_to_file(
            &format!("{tmp_path}/bar/foo.lance/test_file"),
            "TEST_CONTENT",
        )
        .unwrap();

        // test a few variations of the same path
        for uri in &[
            format!("{tmp_path}/bar/foo.lance"),
            format!("{tmp_path}/./bar/foo.lance"),
            format!("{tmp_path}/bar/foo.lance/../foo.lance"),
        ] {
            let (store, path) = ObjectStore::from_uri(uri).await.unwrap();
            let contents = read_from_store(store, &path.child("test_file"))
                .await
                .unwrap();
            assert_eq!(contents, "TEST_CONTENT");
        }
    }

    #[tokio::test]
    async fn test_cloud_paths() {
        let uri = "s3://bucket/foo.lance";
        let (store, path) = ObjectStore::from_uri(uri).await.unwrap();
        assert_eq!(store.scheme, "s3");
        assert_eq!(path.to_string(), "foo.lance");

        let (store, path) = ObjectStore::from_uri("s3+ddb://bucket/foo.lance")
            .await
            .unwrap();
        assert_eq!(store.scheme, "s3");
        assert_eq!(path.to_string(), "foo.lance");

        let (store, path) = ObjectStore::from_uri("gs://bucket/foo.lance")
            .await
            .unwrap();
        assert_eq!(store.scheme, "gs");
        assert_eq!(path.to_string(), "foo.lance");
    }

    async fn test_block_size_used_test_helper(
        uri: &str,
        storage_options: Option<HashMap<String, String>>,
        default_expected_block_size: usize,
    ) {
        // Test the default
        let registry = Arc::new(ObjectStoreRegistry::default());
        let params = ObjectStoreParams {
            storage_options: storage_options.clone(),
            ..ObjectStoreParams::default()
        };
        let (store, _) = ObjectStore::from_uri_and_params(registry, uri, &params)
            .await
            .unwrap();
        assert_eq!(store.block_size, default_expected_block_size);

        // Ensure param is used
        let registry = Arc::new(ObjectStoreRegistry::default());
        let params = ObjectStoreParams {
            block_size: Some(1024),
            storage_options: storage_options.clone(),
            ..ObjectStoreParams::default()
        };
        let (store, _) = ObjectStore::from_uri_and_params(registry, uri, &params)
            .await
            .unwrap();
        assert_eq!(store.block_size, 1024);
    }

    #[rstest]
    #[case("s3://bucket/foo.lance", None)]
    #[case("gs://bucket/foo.lance", None)]
    #[case("az://account/bucket/foo.lance",
      Some(HashMap::from([
            (String::from("account_name"), String::from("account")),
            (String::from("container_name"), String::from("container"))
           ])))]
    #[tokio::test]
    async fn test_block_size_used_cloud(
        #[case] uri: &str,
        #[case] storage_options: Option<HashMap<String, String>>,
    ) {
        test_block_size_used_test_helper(uri, storage_options, 64 * 1024).await;
    }

    #[rstest]
    #[case("file")]
    #[case("file-object-store")]
    #[case("memory:///bucket/foo.lance")]
    #[tokio::test]
    async fn test_block_size_used_file(#[case] prefix: &str) {
        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path().to_str().unwrap().to_owned();
        let path = format!("{tmp_path}/bar/foo.lance/test_file");
        write_to_file(&path, "URL").unwrap();
        let uri = format!("{prefix}:///{path}");
        test_block_size_used_test_helper(&uri, None, 4 * 1024).await;
    }

    #[tokio::test]
    async fn test_relative_paths() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path().to_str().unwrap().to_owned();
        write_to_file(
            &format!("{tmp_path}/bar/foo.lance/test_file"),
            "RELATIVE_URL",
        )
        .unwrap();

        set_current_dir(StdPath::new(&tmp_path)).expect("Error changing current dir");
        let (store, path) = ObjectStore::from_uri("./bar/foo.lance").await.unwrap();

        let contents = read_from_store(store, &path.child("test_file"))
            .await
            .unwrap();
        assert_eq!(contents, "RELATIVE_URL");
    }

    #[tokio::test]
    async fn test_tilde_expansion() {
        let uri = "~/foo.lance";
        write_to_file(&format!("{uri}/test_file"), "TILDE").unwrap();
        let (store, path) = ObjectStore::from_uri(uri).await.unwrap();
        let contents = read_from_store(store, &path.child("test_file"))
            .await
            .unwrap();
        assert_eq!(contents, "TILDE");
    }

    #[tokio::test]
    async fn test_read_directory() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let path = tmp_dir.path();
        create_dir_all(path.join("foo").join("bar")).unwrap();
        create_dir_all(path.join("foo").join("zoo")).unwrap();
        create_dir_all(path.join("foo").join("zoo").join("abc")).unwrap();
        write_to_file(
            path.join("foo").join("test_file").to_str().unwrap(),
            "read_dir",
        )
        .unwrap();
        let (store, base) = ObjectStore::from_uri(path.to_str().unwrap()).await.unwrap();

        let sub_dirs = store.read_dir(base.child("foo")).await.unwrap();
        assert_eq!(sub_dirs, vec!["bar", "zoo", "test_file"]);
    }

    #[tokio::test]
    async fn test_delete_directory() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let path = tmp_dir.path();
        create_dir_all(path.join("foo").join("bar")).unwrap();
        create_dir_all(path.join("foo").join("zoo")).unwrap();
        create_dir_all(path.join("foo").join("zoo").join("abc")).unwrap();
        write_to_file(
            path.join("foo")
                .join("bar")
                .join("test_file")
                .to_str()
                .unwrap(),
            "delete",
        )
        .unwrap();
        write_to_file(path.join("foo").join("top").to_str().unwrap(), "delete_top").unwrap();
        let (store, base) = ObjectStore::from_uri(path.to_str().unwrap()).await.unwrap();
        store.remove_dir_all(base.child("foo")).await.unwrap();

        assert!(!path.join("foo").exists());
    }

    #[derive(Debug)]
    struct TestWrapper {
        called: AtomicBool,

        return_value: Arc<dyn OSObjectStore>,
    }

    impl WrappingObjectStore for TestWrapper {
        fn wrap(&self, _original: Arc<dyn OSObjectStore>) -> Arc<dyn OSObjectStore> {
            self.called.store(true, Ordering::Relaxed);

            // Return a mocked value so we can check if the final store is the one we expect.
            self.return_value.clone()
        }
    }

    impl TestWrapper {
        fn called(&self) -> bool {
            self.called.load(Ordering::Relaxed)
        }
    }

    #[tokio::test]
    async fn test_wrapping_object_store_option_is_used() {
        // Make a store to use as the inner store first
        let mock_inner_store: Arc<dyn OSObjectStore> = Arc::new(InMemory::new());
        let registry = Arc::new(ObjectStoreRegistry::default());

        assert_eq!(Arc::strong_count(&mock_inner_store), 1);

        let wrapper = Arc::new(TestWrapper {
            called: AtomicBool::new(false),
            return_value: mock_inner_store.clone(),
        });

        let params = ObjectStoreParams {
            object_store_wrapper: Some(wrapper.clone()),
            ..ObjectStoreParams::default()
        };

        // Not called yet
        assert!(!wrapper.called());

        let _ = ObjectStore::from_uri_and_params(registry, "memory:///", &params)
            .await
            .unwrap();

        // Called after construction
        assert!(wrapper.called());

        // It is hard to compare two trait pointers, as they point to vtables.
        // Use the ref count as a proxy to make sure that the store is correctly kept.
        assert_eq!(Arc::strong_count(&mock_inner_store), 2);
    }

    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }

    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // Not called yet
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // fails, but we don't care
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        // The injected provider should have been called by now
        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    #[tokio::test]
    async fn test_local_paths() {
        let temp_dir = tempfile::tempdir().unwrap();

        let file_path = temp_dir.path().join("test_file");
        let mut writer = ObjectStore::create_local_writer(file_path.as_path())
            .await
            .unwrap();
        writer.write_all(b"LOCAL").await.unwrap();
        writer.shutdown().await.unwrap();

        let reader = ObjectStore::open_local(file_path.as_path()).await.unwrap();
        let buf = reader.get_range(0..5).await.unwrap();
        assert_eq!(buf.as_bytes(), b"LOCAL");
    }

    #[tokio::test]
    async fn test_read_one() {
        let temp_dir = tempfile::tempdir().unwrap();

        let file_path = temp_dir.path().join("test_file");
        let mut writer = ObjectStore::create_local_writer(file_path.as_path())
            .await
            .unwrap();
        writer.write_all(b"LOCAL").await.unwrap();
        writer.shutdown().await.unwrap();

        let file_path_os = object_store::path::Path::parse(file_path.to_str().unwrap()).unwrap();
        let obj_store = ObjectStore::local();
        let buf = obj_store.read_one_all(&file_path_os).await.unwrap();
        assert_eq!(buf.as_bytes(), b"LOCAL");

        let buf = obj_store.read_one_range(&file_path_os, 0..5).await.unwrap();
        assert_eq!(buf.as_bytes(), b"LOCAL");
    }

    #[tokio::test]
    #[cfg(windows)]
    async fn test_windows_paths() {
        use std::path::Component;
        use std::path::Prefix;
        use std::path::Prefix::*;

        fn get_path_prefix(path: &StdPath) -> Prefix {
            match path.components().next().unwrap() {
                Component::Prefix(prefix_component) => prefix_component.kind(),
                _ => panic!(),
            }
        }

        fn get_drive_letter(prefix: Prefix) -> String {
            match prefix {
                Disk(bytes) => String::from_utf8(vec![bytes]).unwrap(),
                _ => panic!(),
            }
        }

        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path();
        let prefix = get_path_prefix(tmp_path);
        let drive_letter = get_drive_letter(prefix);

        write_to_file(
            &(format!("{drive_letter}:/test_folder/test.lance") + "/test_file"),
            "WINDOWS",
        )
        .unwrap();

        for uri in &[
            format!("{drive_letter}:/test_folder/test.lance"),
            format!("{drive_letter}:\\test_folder\\test.lance"),
        ] {
            let (store, base) = ObjectStore::from_uri(uri).await.unwrap();
            let contents = read_from_store(store, &base.child("test_file"))
                .await
                .unwrap();
            assert_eq!(contents, "WINDOWS");
        }
    }
}