lance_io/
object_store.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! Extend [object_store::ObjectStore] functionality

use std::collections::HashMap;
use std::ops::Range;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, SystemTime};

use async_trait::async_trait;
use aws_config::default_provider::credentials::DefaultCredentialsChain;
use aws_credential_types::provider::ProvideCredentials;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use deepsize::DeepSizeOf;
use futures::{future, stream::BoxStream, StreamExt, TryStreamExt};
use lance_core::utils::parse::str_is_truthy;
use lance_core::utils::tokio::get_num_compute_intensive_cpus;
use object_store::aws::{
    AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential, AwsCredentialProvider,
};
use object_store::gcp::GoogleCloudStorageBuilder;
use object_store::{
    aws::AmazonS3Builder, azure::AzureConfigKey, gcp::GoogleConfigKey, local::LocalFileSystem,
    memory::InMemory, CredentialProvider, Error as ObjectStoreError, Result as ObjectStoreResult,
};
use object_store::{
    parse_url_opts, ClientOptions, DynObjectStore, RetryConfig, StaticCredentialProvider,
};
use object_store::{path::Path, ObjectMeta, ObjectStore as OSObjectStore};
use shellexpand::tilde;
use snafu::location;
use tokio::io::AsyncWriteExt;
use tokio::sync::RwLock;
use url::Url;

use super::local::LocalObjectReader;
mod tracing;
use self::tracing::ObjectStoreTracingExt;
use crate::{object_reader::CloudObjectReader, object_writer::ObjectWriter, traits::Reader};
use lance_core::{Error, Result};

// Local disks tend to do fine with a few threads
// Note: the number of threads here also impacts the number of files
// we need to read in some situations.  So keeping this at 8 keeps the
// RAM on our scanner down.
pub const DEFAULT_LOCAL_IO_PARALLELISM: usize = 8;
// Cloud disks often need many many threads to saturate the network
pub const DEFAULT_CLOUD_IO_PARALLELISM: usize = 64;

pub const DEFAULT_DOWNLOAD_RETRY_COUNT: usize = 3;

#[async_trait]
pub trait ObjectStoreExt {
    /// Returns true if the file exists.
    async fn exists(&self, path: &Path) -> Result<bool>;

    /// Read all files (start from base directory) recursively
    ///
    /// unmodified_since can be specified to only return files that have not been modified since the given time.
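    ///
    /// A minimal usage sketch (hedged: assumes this trait is in scope, e.g. as
    /// `lance_io::object_store::ObjectStoreExt`, and that `store` is any
    /// [object_store::ObjectStore]):
    ///
    /// ```ignore
    /// use futures::TryStreamExt;
    ///
    /// let dir = Path::from("dataset");
    /// let files: Vec<ObjectMeta> = store.read_dir_all(&dir, None).await?.try_collect().await?;
    /// ```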
    async fn read_dir_all<'a>(
        &'a self,
        dir_path: impl Into<&Path> + Send,
        unmodified_since: Option<DateTime<Utc>>,
    ) -> Result<BoxStream<'a, Result<ObjectMeta>>>;
}

#[async_trait]
impl<O: OSObjectStore + ?Sized> ObjectStoreExt for O {
    async fn read_dir_all<'a>(
        &'a self,
        dir_path: impl Into<&Path> + Send,
        unmodified_since: Option<DateTime<Utc>>,
    ) -> Result<BoxStream<'a, Result<ObjectMeta>>> {
        let mut output = self.list(Some(dir_path.into()));
        if let Some(unmodified_since_val) = unmodified_since {
            output = output
                .try_filter(move |file| future::ready(file.last_modified < unmodified_since_val))
                .boxed();
        }
        Ok(output.map_err(|e| e.into()).boxed())
    }

    async fn exists(&self, path: &Path) -> Result<bool> {
        match self.head(path).await {
            Ok(_) => Ok(true),
            Err(object_store::Error::NotFound { path: _, source: _ }) => Ok(false),
            Err(e) => Err(e.into()),
        }
    }
}

/// Wraps [ObjectStore](object_store::ObjectStore)
#[derive(Debug, Clone)]
pub struct ObjectStore {
    // Inner object store
    pub inner: Arc<dyn OSObjectStore>,
    scheme: String,
    block_size: usize,
    /// Whether to use constant size upload parts for multipart uploads. This
    /// is only necessary for Cloudflare R2.
    pub use_constant_size_upload_parts: bool,
    /// Whether we can assume that the list of files is lexically ordered. This
    /// is true for object stores, but not for local filesystems.
    pub list_is_lexically_ordered: bool,
    io_parallelism: usize,
    /// Number of times to retry a failed download
    download_retry_count: usize,
}

impl DeepSizeOf for ObjectStore {
    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
        // We aren't counting `inner` here, which is problematic, but an ObjectStore
        // shouldn't be too big.  The only exception might be the write cache but, if
        // the write cache has data, it means we're using it somewhere else that isn't
        // a cache and so that doesn't really count.
        self.scheme.deep_size_of_children(context) + self.block_size.deep_size_of_children(context)
    }
}

impl std::fmt::Display for ObjectStore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "ObjectStore({})", self.scheme)
    }
}

pub trait ObjectStoreProvider: std::fmt::Debug + Sync + Send {
    fn new_store(&self, base_path: Url, params: &ObjectStoreParams) -> Result<ObjectStore>;
}

#[derive(Default, Debug)]
pub struct ObjectStoreRegistry {
    providers: HashMap<String, Arc<dyn ObjectStoreProvider>>,
}

impl ObjectStoreRegistry {
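    /// Register a provider for a custom URI scheme.
    ///
    /// A minimal sketch; `MyProvider` is a hypothetical [ObjectStoreProvider]
    /// implementation, not part of this crate:
    ///
    /// ```ignore
    /// let mut registry = ObjectStoreRegistry::default();
    /// registry.insert("my-scheme", Arc::new(MyProvider::default()));
    /// ```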
    pub fn insert(&mut self, scheme: &str, provider: Arc<dyn ObjectStoreProvider>) {
        self.providers.insert(scheme.into(), provider);
    }
}

const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";

/// Adapt AWS SDK credentials into object_store credentials
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    pub inner: Arc<dyn ProvideCredentials>,

    // RefCell can't be shared across threads, so we use an RwLock-guarded HashMap
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    // The amount of time before expiry to refresh credentials
    credentials_refresh_offset: Duration,
}

impl AwsCredentialAdapter {
    pub fn new(
        provider: Arc<dyn ProvideCredentials>,
        credentials_refresh_offset: Duration,
    ) -> Self {
        Self {
            inner: provider,
            cache: Arc::new(RwLock::new(HashMap::new())),
            credentials_refresh_offset,
        }
    }
}

#[async_trait]
impl CredentialProvider for AwsCredentialAdapter {
    type Credential = ObjectStoreAwsCredential;

    async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
        let cached_creds = {
            let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
            let expired = cache_value
                .clone()
                .map(|cred| {
                    cred.expiry()
                        .map(|exp| {
                            exp.checked_sub(self.credentials_refresh_offset)
                                .expect("this time should always be valid")
                                < SystemTime::now()
                        })
                        // credentials with no expiry never expire
                        .unwrap_or(false)
                })
                .unwrap_or(true); // a missing credential is treated as expired
            if expired {
                None
            } else {
                cache_value.clone()
            }
        };

        if let Some(creds) = cached_creds {
            Ok(Arc::new(Self::Credential {
                key_id: creds.access_key_id().to_string(),
                secret_key: creds.secret_access_key().to_string(),
                token: creds.session_token().map(|s| s.to_string()),
            }))
        } else {
            let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
                |e| Error::Internal {
                    message: format!("Failed to get AWS credentials: {}", e),
                    location: location!(),
                },
            )?);

            self.cache
                .write()
                .await
                .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());

            Ok(Arc::new(Self::Credential {
                key_id: refreshed_creds.access_key_id().to_string(),
                secret_key: refreshed_creds.secret_access_key().to_string(),
                token: refreshed_creds.session_token().map(|s| s.to_string()),
            }))
        }
    }
}

/// Figure out the S3 region of the bucket.
///
/// This resolves in order of precedence:
/// 1. The region provided in the storage options
/// 2. (If endpoint is not set), the region returned by the S3 API for the bucket
///
/// It can return None if no region is provided and the endpoint is set.
async fn resolve_s3_region(
    url: &Url,
    storage_options: &HashMap<AmazonS3ConfigKey, String>,
) -> Result<Option<String>> {
    if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
        Ok(Some(region.clone()))
    } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
        // If no endpoint is set, we can assume this is AWS S3 and the region
        // can be resolved from the bucket.
        let bucket = url.host_str().ok_or_else(|| {
            Error::invalid_input(
                format!("Could not parse bucket from url: {}", url),
                location!(),
            )
        })?;

        let mut client_options = ClientOptions::default();
        for (key, value) in storage_options {
            if let AmazonS3ConfigKey::Client(client_key) = key {
                client_options = client_options.with_config(*client_key, value.clone());
            }
        }

        let bucket_region =
            object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
        Ok(Some(bucket_region))
    } else {
        Ok(None)
    }
}

/// Build AWS credentials
///
/// This resolves credentials from the following sources in order:
/// 1. An explicit `credentials` provider
/// 2. Explicit credentials in storage_options (as in `aws_access_key_id`,
///    `aws_secret_access_key`, `aws_session_token`)
/// 3. The default credential provider chain from the AWS SDK.
///
/// `credentials_refresh_offset` is the amount of time before expiry to refresh credentials.
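///
/// A minimal sketch of the default path (no explicit credentials; the region
/// falls back to the SDK's provider chain):
///
/// ```ignore
/// let (provider, region) =
///     build_aws_credential(Duration::from_secs(60), None, None, None).await?;
/// ```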
pub async fn build_aws_credential(
    credentials_refresh_offset: Duration,
    credentials: Option<AwsCredentialProvider>,
    storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
    region: Option<String>,
) -> Result<(AwsCredentialProvider, String)> {
    // TODO: make this return no credential provider when not using AWS
    use aws_config::meta::region::RegionProviderChain;
    const DEFAULT_REGION: &str = "us-west-2";

    let region = if let Some(region) = region {
        region
    } else {
        RegionProviderChain::default_provider()
            .or_else(DEFAULT_REGION)
            .region()
            .await
            .map(|r| r.as_ref().to_string())
            .unwrap_or(DEFAULT_REGION.to_string())
    };

    if let Some(creds) = credentials {
        Ok((creds, region))
    } else if let Some(creds) = storage_options.and_then(extract_static_s3_credentials) {
        Ok((Arc::new(creds), region))
    } else {
        let credentials_provider = DefaultCredentialsChain::builder().build().await;

        Ok((
            Arc::new(AwsCredentialAdapter::new(
                Arc::new(credentials_provider),
                credentials_refresh_offset,
            )),
            region,
        ))
    }
}

fn extract_static_s3_credentials(
    options: &HashMap<AmazonS3ConfigKey, String>,
) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
    let key_id = options
        .get(&AmazonS3ConfigKey::AccessKeyId)
        .map(|s| s.to_string());
    let secret_key = options
        .get(&AmazonS3ConfigKey::SecretAccessKey)
        .map(|s| s.to_string());
    let token = options
        .get(&AmazonS3ConfigKey::Token)
        .map(|s| s.to_string());
    match (key_id, secret_key, token) {
        (Some(key_id), Some(secret_key), token) => {
            Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
                key_id,
                secret_key,
                token,
            }))
        }
        _ => None,
    }
}

pub trait WrappingObjectStore: std::fmt::Debug + Send + Sync {
    fn wrap(&self, original: Arc<dyn OSObjectStore>) -> Arc<dyn OSObjectStore>;
}

/// Parameters to create an [ObjectStore]
#[derive(Debug, Clone)]
pub struct ObjectStoreParams {
    pub block_size: Option<usize>,
    pub object_store: Option<(Arc<DynObjectStore>, Url)>,
    pub s3_credentials_refresh_offset: Duration,
    pub aws_credentials: Option<AwsCredentialProvider>,
    pub object_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
    pub storage_options: Option<HashMap<String, String>>,
    /// Use constant size upload parts for multipart uploads. Only necessary
    /// for Cloudflare R2, which doesn't support variable size parts. When this
    /// is false, max upload size is 2.5TB. When this is true, the max size is
    /// 50GB.
    pub use_constant_size_upload_parts: bool,
    pub list_is_lexically_ordered: Option<bool>,
}

impl Default for ObjectStoreParams {
    fn default() -> Self {
        Self {
            object_store: None,
            block_size: None,
            s3_credentials_refresh_offset: Duration::from_secs(60),
            aws_credentials: None,
            object_store_wrapper: None,
            storage_options: None,
            use_constant_size_upload_parts: false,
            list_is_lexically_ordered: None,
        }
    }
}

impl ObjectStoreParams {
    /// Create a new instance of [`ObjectStoreParams`] based on the AWS credentials.
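    ///
    /// A minimal sketch (the region string is illustrative):
    ///
    /// ```ignore
    /// let params = ObjectStoreParams::with_aws_credentials(None, Some("us-east-1".into()));
    /// ```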
    pub fn with_aws_credentials(
        aws_credentials: Option<AwsCredentialProvider>,
        region: Option<String>,
    ) -> Self {
        Self {
            aws_credentials,
            storage_options: region
                .map(|region| [("region".into(), region)].iter().cloned().collect()),
            ..Default::default()
        }
    }
}

impl ObjectStore {
    /// Parse from a string URI.
    ///
    /// Returns the ObjectStore instance and the absolute path to the object.
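    ///
    /// A minimal sketch (hedged: assumes this type is reachable as
    /// `lance_io::object_store::ObjectStore`):
    ///
    /// ```ignore
    /// let (store, path) = ObjectStore::from_uri("memory:///my/table.lance").await?;
    /// assert_eq!(path.to_string(), "my/table.lance");
    /// ```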
    pub async fn from_uri(uri: &str) -> Result<(Self, Path)> {
        let registry = Arc::new(ObjectStoreRegistry::default());

        Self::from_uri_and_params(registry, uri, &ObjectStoreParams::default()).await
    }

    /// Parse from a string URI.
    ///
    /// Returns the ObjectStore instance and the absolute path to the object.
    pub async fn from_uri_and_params(
        registry: Arc<ObjectStoreRegistry>,
        uri: &str,
        params: &ObjectStoreParams,
    ) -> Result<(Self, Path)> {
        if let Some((store, path)) = params.object_store.as_ref() {
            let mut inner = store.clone();
            if let Some(wrapper) = params.object_store_wrapper.as_ref() {
                inner = wrapper.wrap(inner);
            }
            let store = Self {
                inner,
                scheme: path.scheme().to_string(),
                block_size: params.block_size.unwrap_or(64 * 1024),
                use_constant_size_upload_parts: params.use_constant_size_upload_parts,
                list_is_lexically_ordered: params.list_is_lexically_ordered.unwrap_or_default(),
                io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
                download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT,
            };
            let path = Path::from(path.path());
            return Ok((store, path));
        }
        let (object_store, path) = match Url::parse(uri) {
            Ok(url) if url.scheme().len() == 1 && cfg!(windows) => {
                // On Windows, the drive is parsed as a scheme
                Self::from_path(uri)
            }
            Ok(url) => {
                let store = Self::new_from_url(registry, url.clone(), params.clone()).await?;
                Ok((store, Path::from(url.path())))
            }
            Err(_) => Self::from_path(uri),
        }?;

        Ok((
            Self {
                inner: params
                    .object_store_wrapper
                    .as_ref()
                    .map(|w| w.wrap(object_store.inner.clone()))
                    .unwrap_or(object_store.inner),
                ..object_store
            },
            path,
        ))
    }

    pub fn from_path_with_scheme(str_path: &str, scheme: &str) -> Result<(Self, Path)> {
        let expanded = tilde(str_path).to_string();

        let mut expanded_path = path_abs::PathAbs::new(expanded)
            .unwrap()
            .as_path()
            .to_path_buf();
        // path_abs::PathAbs::new(".") returns an empty string.
        if let Some(s) = expanded_path.as_path().to_str() {
            if s.is_empty() {
                expanded_path = std::env::current_dir()?;
            }
        }
        Ok((
            Self {
                inner: Arc::new(LocalFileSystem::new()).traced(),
                scheme: String::from(scheme),
                block_size: 4 * 1024, // 4KB block size
                use_constant_size_upload_parts: false,
                list_is_lexically_ordered: false,
                io_parallelism: DEFAULT_LOCAL_IO_PARALLELISM,
                download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT,
            },
            Path::from_absolute_path(expanded_path.as_path())?,
        ))
    }

    pub fn from_path(str_path: &str) -> Result<(Self, Path)> {
        Self::from_path_with_scheme(str_path, "file")
    }

    async fn new_from_url(
        registry: Arc<ObjectStoreRegistry>,
        url: Url,
        params: ObjectStoreParams,
    ) -> Result<Self> {
        configure_store(registry, url.as_str(), params).await
    }

    /// Local object store.
    pub fn local() -> Self {
        Self {
            inner: Arc::new(LocalFileSystem::new()).traced(),
            scheme: String::from("file"),
            block_size: 4 * 1024, // 4KB block size
            use_constant_size_upload_parts: false,
            list_is_lexically_ordered: false,
            io_parallelism: DEFAULT_LOCAL_IO_PARALLELISM,
            download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT,
        }
    }

    /// Create an in-memory object store, directly for testing.
    pub fn memory() -> Self {
        Self {
            inner: Arc::new(InMemory::new()).traced(),
            scheme: String::from("memory"),
            block_size: 4 * 1024,
            use_constant_size_upload_parts: false,
            list_is_lexically_ordered: true,
            io_parallelism: get_num_compute_intensive_cpus(),
            download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT,
        }
    }

    /// Returns true if the object store points to a local file system.
    pub fn is_local(&self) -> bool {
        self.scheme == "file"
    }

    pub fn is_cloud(&self) -> bool {
        self.scheme != "file" && self.scheme != "memory"
    }

    pub fn block_size(&self) -> usize {
        self.block_size
    }

    pub fn set_block_size(&mut self, new_size: usize) {
        self.block_size = new_size;
    }

    pub fn set_io_parallelism(&mut self, io_parallelism: usize) {
        self.io_parallelism = io_parallelism;
    }

    pub fn io_parallelism(&self) -> usize {
        std::env::var("LANCE_IO_THREADS")
            .map(|val| val.parse::<usize>().unwrap())
            .unwrap_or(self.io_parallelism)
    }

    /// Open a file for path.
    ///
    /// Parameters
    /// - ``path``: Absolute path to the file.
    pub async fn open(&self, path: &Path) -> Result<Box<dyn Reader>> {
        match self.scheme.as_str() {
            "file" => LocalObjectReader::open(path, self.block_size, None).await,
            _ => Ok(Box::new(CloudObjectReader::new(
                self.inner.clone(),
                path.clone(),
                self.block_size,
                None,
                self.download_retry_count,
            )?)),
        }
    }

    /// Open a reader for a file with known size.
    ///
    /// This size may either have been retrieved from a list operation or
    /// cached metadata. By passing in the known size, we can skip a HEAD / metadata
    /// call.
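    ///
    /// A hedged sketch (`meta` is assumed to be an [ObjectMeta] obtained from
    /// an earlier list operation):
    ///
    /// ```ignore
    /// let reader = store.open_with_size(&meta.location, meta.size).await?;
    /// ```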
    pub async fn open_with_size(&self, path: &Path, known_size: usize) -> Result<Box<dyn Reader>> {
        match self.scheme.as_str() {
            "file" => LocalObjectReader::open(path, self.block_size, Some(known_size)).await,
            _ => Ok(Box::new(CloudObjectReader::new(
                self.inner.clone(),
                path.clone(),
                self.block_size,
                Some(known_size),
                self.download_retry_count,
            )?)),
        }
    }

    /// Create an [ObjectWriter] from local [std::path::Path]
    pub async fn create_local_writer(path: &std::path::Path) -> Result<ObjectWriter> {
        let object_store = Self::local();
        let os_path = Path::from(path.to_str().unwrap());
        object_store.create(&os_path).await
    }

    /// Open a [Reader] from local [std::path::Path]
    pub async fn open_local(path: &std::path::Path) -> Result<Box<dyn Reader>> {
        let object_store = Self::local();
        let os_path = Path::from(path.to_str().unwrap());
        object_store.open(&os_path).await
    }

    /// Create a new file.
    pub async fn create(&self, path: &Path) -> Result<ObjectWriter> {
        ObjectWriter::new(self, path).await
    }

    /// A helper function to create a file and write content to it.
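    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// store.put(&path, b"hello").await?;
    /// let bytes = store.read_one_all(&path).await?;
    /// assert_eq!(bytes.as_ref(), b"hello");
    /// ```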
    pub async fn put(&self, path: &Path, content: &[u8]) -> Result<()> {
        let mut writer = self.create(path).await?;
        writer.write_all(content).await?;
        writer.shutdown().await
    }

    pub async fn delete(&self, path: &Path) -> Result<()> {
        self.inner.delete(path).await?;
        Ok(())
    }

    pub async fn copy(&self, from: &Path, to: &Path) -> Result<()> {
        Ok(self.inner.copy(from, to).await?)
    }

    /// Read a directory (start from base directory) and return all sub-paths in the directory.
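    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let children = store.read_dir(base.child("foo")).await?;
    /// ```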
    pub async fn read_dir(&self, dir_path: impl Into<Path>) -> Result<Vec<String>> {
        let path = dir_path.into();
        let path = Path::parse(&path)?;
        let output = self.inner.list_with_delimiter(Some(&path)).await?;
        Ok(output
            .common_prefixes
            .iter()
            .chain(output.objects.iter().map(|o| &o.location))
            .map(|s| s.filename().unwrap().to_string())
            .collect())
    }

    /// Read all files (start from base directory) recursively
    ///
    /// unmodified_since can be specified to only return files that have not been modified since the given time.
    pub async fn read_dir_all(
        &self,
        dir_path: impl Into<&Path> + Send,
        unmodified_since: Option<DateTime<Utc>>,
    ) -> Result<BoxStream<Result<ObjectMeta>>> {
        self.inner.read_dir_all(dir_path, unmodified_since).await
    }

    /// Remove a directory recursively.
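    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// store.remove_dir_all(base.child("stale_dir")).await?;
    /// ```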
    pub async fn remove_dir_all(&self, dir_path: impl Into<Path>) -> Result<()> {
        let path = dir_path.into();
        let path = Path::parse(&path)?;

        if self.is_local() {
            // Local file system needs to delete directories as well.
            return super::local::remove_dir_all(&path);
        }
        let sub_entries = self
            .inner
            .list(Some(&path))
            .map(|m| m.map(|meta| meta.location))
            .boxed();
        self.inner
            .delete_stream(sub_entries)
            .try_collect::<Vec<_>>()
            .await?;
        Ok(())
    }

    pub fn remove_stream<'a>(
        &'a self,
        locations: BoxStream<'a, Result<Path>>,
    ) -> BoxStream<'a, Result<Path>> {
        self.inner
            .delete_stream(locations.err_into::<ObjectStoreError>().boxed())
            .err_into::<Error>()
            .boxed()
    }

    /// Check if a file exists.
    pub async fn exists(&self, path: &Path) -> Result<bool> {
        match self.inner.head(path).await {
            Ok(_) => Ok(true),
            Err(object_store::Error::NotFound { path: _, source: _ }) => Ok(false),
            Err(e) => Err(e.into()),
        }
    }

    /// Get file size.
    pub async fn size(&self, path: &Path) -> Result<usize> {
        Ok(self.inner.head(path).await?.size)
    }

    /// Convenience function to open a reader and read all the bytes
    pub async fn read_one_all(&self, path: &Path) -> Result<Bytes> {
        let reader = self.open(path).await?;
        Ok(reader.get_all().await?)
    }

    /// Convenience function to open a reader and make a single request
    ///
    /// If you will be making multiple requests to the path it is more efficient to call [`Self::open`]
    /// and then call [`Reader::get_range`] multiple times.
    pub async fn read_one_range(&self, path: &Path, range: Range<usize>) -> Result<Bytes> {
        let reader = self.open(path).await?;
        Ok(reader.get_range(range).await?)
    }
}

/// Options that can be set for multiple object stores
#[derive(PartialEq, Eq, Hash, Clone, Debug, Copy)]
pub enum LanceConfigKey {
    /// Number of times to retry a download that fails
    DownloadRetryCount,
}

impl FromStr for LanceConfigKey {
    type Err = Error;

    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s.to_ascii_lowercase().as_str() {
            "download_retry_count" => Ok(Self::DownloadRetryCount),
            _ => Err(Error::InvalidInput {
                source: format!("Invalid LanceConfigKey: {}", s).into(),
                location: location!(),
            }),
        }
    }
}

#[derive(Clone, Debug, Default)]
pub struct StorageOptions(pub HashMap<String, String>);

impl StorageOptions {
    /// Create a new instance of [`StorageOptions`]
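    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let options = StorageOptions::new(HashMap::from([
    ///     ("allow_http".to_string(), "true".to_string()),
    /// ]));
    /// assert!(options.allow_http());
    /// ```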
    pub fn new(options: HashMap<String, String>) -> Self {
        let mut options = options;
        if let Ok(value) = std::env::var("AZURE_STORAGE_ALLOW_HTTP") {
            options.insert("allow_http".into(), value);
        }
        if let Ok(value) = std::env::var("AZURE_STORAGE_USE_HTTP") {
            options.insert("allow_http".into(), value);
        }
        if let Ok(value) = std::env::var("AWS_ALLOW_HTTP") {
            options.insert("allow_http".into(), value);
        }
        Self(options)
    }

    /// Add values from the environment to storage options
    pub fn with_env_azure(&mut self) {
        for (os_key, os_value) in std::env::vars_os() {
            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
                if let Ok(config_key) = AzureConfigKey::from_str(&key.to_ascii_lowercase()) {
                    if !self.0.contains_key(config_key.as_ref()) {
                        self.0
                            .insert(config_key.as_ref().to_string(), value.to_string());
                    }
                }
            }
        }
    }

    /// Add values from the environment to storage options
    pub fn with_env_gcs(&mut self) {
        for (os_key, os_value) in std::env::vars_os() {
            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
                if let Ok(config_key) = GoogleConfigKey::from_str(&key.to_ascii_lowercase()) {
                    if !self.0.contains_key(config_key.as_ref()) {
                        self.0
                            .insert(config_key.as_ref().to_string(), value.to_string());
                    }
                }
            }
        }
    }

    /// Add values from the environment to storage options
    pub fn with_env_s3(&mut self) {
        for (os_key, os_value) in std::env::vars_os() {
            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
                if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
                    if !self.0.contains_key(config_key.as_ref()) {
                        self.0
                            .insert(config_key.as_ref().to_string(), value.to_string());
                    }
                }
            }
        }
    }

    /// Denotes whether insecure connections via HTTP are allowed
    pub fn allow_http(&self) -> bool {
        self.0.iter().any(|(key, value)| {
            key.to_ascii_lowercase().contains("allow_http") && str_is_truthy(value)
        })
    }

    /// Number of times to retry a download that fails
    pub fn download_retry_count(&self) -> usize {
        self.0
            .iter()
            .find(|(key, _)| key.eq_ignore_ascii_case("download_retry_count"))
            .map(|(_, value)| value.parse::<usize>().unwrap_or(3))
            .unwrap_or(3)
    }

    /// Max retry times to set in RetryConfig for s3 client
    pub fn client_max_retries(&self) -> usize {
        self.0
            .iter()
            .find(|(key, _)| key.eq_ignore_ascii_case("client_max_retries"))
            .and_then(|(_, value)| value.parse::<usize>().ok())
            .unwrap_or(10)
    }

    /// Seconds of timeout to set in RetryConfig for s3 client
    pub fn client_retry_timeout(&self) -> u64 {
        self.0
            .iter()
            .find(|(key, _)| key.eq_ignore_ascii_case("client_retry_timeout"))
            .and_then(|(_, value)| value.parse::<u64>().ok())
            .unwrap_or(180)
    }

    /// Subset of options relevant for azure storage
    pub fn as_azure_options(&self) -> HashMap<AzureConfigKey, String> {
        self.0
            .iter()
            .filter_map(|(key, value)| {
                let az_key = AzureConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
                Some((az_key, value.clone()))
            })
            .collect()
    }

    /// Subset of options relevant for s3 storage
    pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
        self.0
            .iter()
            .filter_map(|(key, value)| {
                let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
                Some((s3_key, value.clone()))
            })
            .collect()
    }

    /// Subset of options relevant for gcs storage
    pub fn as_gcs_options(&self) -> HashMap<GoogleConfigKey, String> {
        self.0
            .iter()
            .filter_map(|(key, value)| {
                let gcs_key = GoogleConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
                Some((gcs_key, value.clone()))
            })
            .collect()
    }
}

impl From<HashMap<String, String>> for StorageOptions {
    fn from(value: HashMap<String, String>) -> Self {
        Self::new(value)
    }
}

async fn configure_store(
    registry: Arc<ObjectStoreRegistry>,
    url: &str,
    options: ObjectStoreParams,
) -> Result<ObjectStore> {
    let mut storage_options = StorageOptions(options.storage_options.clone().unwrap_or_default());
    let download_retry_count = storage_options.download_retry_count();
    let mut url = ensure_table_uri(url)?;
    // Block size: On local file systems, we use 4KB block size. On cloud
    // object stores, we use 64KB block size. This is generally the largest
    // block size where we don't see a latency penalty.
    let file_block_size = options.block_size.unwrap_or(4 * 1024);
    let cloud_block_size = options.block_size.unwrap_or(64 * 1024);
    match url.scheme() {
        "s3" | "s3+ddb" => {
            storage_options.with_env_s3();

            // if url.scheme() == "s3+ddb" && options.commit_handler.is_some() {
            //     return Err(Error::InvalidInput {
            //         source: "`s3+ddb://` scheme and custom commit handler are mutually exclusive"
            //             .into(),
            //         location: location!(),
            //     });
            // }

            let max_retries = storage_options.client_max_retries();
            let retry_timeout = storage_options.client_retry_timeout();
            let retry_config = RetryConfig {
                backoff: Default::default(),
                max_retries,
                retry_timeout: Duration::from_secs(retry_timeout),
            };
            let storage_options = storage_options.as_s3_options();
            let region = resolve_s3_region(&url, &storage_options).await?;
            let (aws_creds, region) = build_aws_credential(
                options.s3_credentials_refresh_offset,
                options.aws_credentials.clone(),
                Some(&storage_options),
                region,
            )
            .await?;

            // Cloudflare does not support varying part sizes.
            let use_constant_size_upload_parts = storage_options
                .get(&AmazonS3ConfigKey::Endpoint)
                .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
                .unwrap_or(false);

            // before creating the OSObjectStore we need to rewrite the url to drop ddb related parts
            url.set_scheme("s3").map_err(|()| Error::Internal {
                message: "could not set scheme".into(),
                location: location!(),
            })?;

            url.set_query(None);

            // we can't use parse_url_opts here because we need to manually set the credentials provider
            let mut builder = AmazonS3Builder::new();
            for (key, value) in storage_options {
                builder = builder.with_config(key, value);
            }
            builder = builder
                .with_url(url.as_ref())
                .with_credentials(aws_creds)
                .with_retry(retry_config)
                .with_region(region);
            let store = builder.build()?;

            Ok(ObjectStore {
                inner: Arc::new(store).traced(),
                scheme: String::from(url.scheme()),
                block_size: cloud_block_size,
                use_constant_size_upload_parts,
                list_is_lexically_ordered: true,
                io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
                download_retry_count,
            })
        }
        "gs" => {
            storage_options.with_env_gcs();
            let mut builder = GoogleCloudStorageBuilder::new().with_url(url.as_ref());
            for (key, value) in storage_options.as_gcs_options() {
                builder = builder.with_config(key, value);
            }
            let store = builder.build()?;
            let store = Arc::new(store).traced();

            Ok(ObjectStore {
                inner: store,
                scheme: String::from("gs"),
                block_size: cloud_block_size,
                use_constant_size_upload_parts: false,
                list_is_lexically_ordered: true,
                io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
                download_retry_count,
            })
        }
        "az" => {
            storage_options.with_env_azure();
            let (store, _) = parse_url_opts(&url, storage_options.as_azure_options())?;
            let store = Arc::new(store).traced();

            Ok(ObjectStore {
                inner: store,
                scheme: String::from("az"),
                block_size: cloud_block_size,
                use_constant_size_upload_parts: false,
                list_is_lexically_ordered: true,
                io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
                download_retry_count,
            })
        }
        // We have bypass logic that uses `tokio::fs` directly to lower overhead;
        // however, this makes testing harder since we can't exercise the same
        // code path. The "file-object-store" scheme forces a local file system
        // dataset through the same code path as cloud object stores.
967        "file" => {
968            let mut object_store = ObjectStore::from_path(url.path())?.0;
969            object_store.set_block_size(file_block_size);
970            Ok(object_store)
971        }
972        "file-object-store" => {
973            let mut object_store =
974                ObjectStore::from_path_with_scheme(url.path(), "file-object-store")?.0;
975            object_store.set_block_size(file_block_size);
976            Ok(object_store)
977        }
978        "memory" => Ok(ObjectStore {
979            inner: Arc::new(InMemory::new()).traced(),
980            scheme: String::from("memory"),
981            block_size: file_block_size,
982            use_constant_size_upload_parts: false,
983            list_is_lexically_ordered: true,
984            io_parallelism: get_num_compute_intensive_cpus(),
985            download_retry_count,
986        }),
987        unknown_scheme => {
988            if let Some(provider) = registry.providers.get(unknown_scheme) {
989                provider.new_store(url, &options)
990            } else {
991                let err = lance_core::Error::from(object_store::Error::NotSupported {
992                    source: format!("Unsupported URI scheme: {} in url {}", unknown_scheme, url)
993                        .into(),
994                });
995                Err(err)
996            }
997        }
998    }
999}

impl ObjectStore {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        store: Arc<DynObjectStore>,
        location: Url,
        block_size: Option<usize>,
        wrapper: Option<Arc<dyn WrappingObjectStore>>,
        use_constant_size_upload_parts: bool,
        list_is_lexically_ordered: bool,
        io_parallelism: usize,
        download_retry_count: usize,
    ) -> Self {
        let scheme = location.scheme();
        let block_size = block_size.unwrap_or_else(|| infer_block_size(scheme));

        let store = match wrapper {
            Some(wrapper) => wrapper.wrap(store),
            None => store,
        };

        Self {
            inner: store,
            scheme: scheme.into(),
            block_size,
            use_constant_size_upload_parts,
            list_is_lexically_ordered,
            io_parallelism,
            download_retry_count,
        }
    }
}

fn infer_block_size(scheme: &str) -> usize {
    // Block size: On local file systems, we use 4KB block size. On cloud
    // object stores, we use 64KB block size. This is generally the largest
    // block size where we don't see a latency penalty.
    match scheme {
        "file" => 4 * 1024,
        _ => 64 * 1024,
    }
}

/// Attempt to create a Url from the given table location.
///
/// The location could be:
///  * A valid URL, which will be parsed and returned
///  * A path to a local directory, which will be canonicalized and converted to a URL
///
/// If it is a local path, it must already exist; otherwise an error is returned.
///
/// Trailing slashes will be trimmed from the end of the path.
///
/// Will return an error if the location is not a valid URL or an existing local path.
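///
/// A minimal sketch:
///
/// ```ignore
/// let url = ensure_table_uri("s3://bucket/table.lance/")?;
/// assert_eq!(url.as_str(), "s3://bucket/table.lance");
/// ```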
pub fn ensure_table_uri(table_uri: impl AsRef<str>) -> Result<Url> {
    let table_uri = table_uri.as_ref();

    enum UriType {
        LocalPath(PathBuf),
        Url(Url),
    }
    let uri_type: UriType = if let Ok(url) = Url::parse(table_uri) {
        if url.scheme() == "file" {
            UriType::LocalPath(url.to_file_path().map_err(|err| {
                let msg = format!("Invalid table location: {}\nError: {:?}", table_uri, err);
                Error::InvalidTableLocation { message: msg }
            })?)
        // NOTE this check is required to support absolute windows paths which may properly parse as url
        } else {
            UriType::Url(url)
        }
    } else {
        UriType::LocalPath(PathBuf::from(table_uri))
    };

    // If it is a local path, it must already exist; canonicalization fails otherwise.
    let mut url = match uri_type {
        UriType::LocalPath(path) => {
            let path = std::fs::canonicalize(path).map_err(|err| Error::DatasetNotFound {
                path: table_uri.to_string(),
                source: Box::new(err),
                location: location!(),
            })?;
            Url::from_directory_path(path).map_err(|_| {
                let msg = format!(
                    "Could not construct a URL from canonicalized path: {}.\n\
                  Something must be very wrong with the table path.",
                    table_uri
                );
                Error::InvalidTableLocation { message: msg }
            })?
        }
        UriType::Url(url) => url,
    };

    let trimmed_path = url.path().trim_end_matches('/').to_owned();
    url.set_path(&trimmed_path);
    Ok(url)
}

lazy_static::lazy_static! {
  static ref KNOWN_SCHEMES: Vec<&'static str> =
      Vec::from([
        "s3",
        "s3+ddb",
        "gs",
        "az",
        "file",
        "file-object-store",
        "memory"
      ]);
}

#[cfg(test)]
mod tests {
    use super::*;
    use parquet::data_type::AsBytes;
    use rstest::rstest;
    use std::env::set_current_dir;
    use std::fs::{create_dir_all, write};
    use std::path::Path as StdPath;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Write test content to file.
    fn write_to_file(path_str: &str, contents: &str) -> std::io::Result<()> {
        let expanded = tilde(path_str).to_string();
        let path = StdPath::new(&expanded);
        std::fs::create_dir_all(path.parent().unwrap())?;
        write(path, contents)
    }

    async fn read_from_store(store: ObjectStore, path: &Path) -> Result<String> {
        let test_file_store = store.open(path).await.unwrap();
        let size = test_file_store.size().await.unwrap();
        let bytes = test_file_store.get_range(0..size).await.unwrap();
        let contents = String::from_utf8(bytes.to_vec()).unwrap();
        Ok(contents)
    }

    #[tokio::test]
    async fn test_absolute_paths() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path().to_str().unwrap().to_owned();
        write_to_file(
            &format!("{tmp_path}/bar/foo.lance/test_file"),
            "TEST_CONTENT",
        )
        .unwrap();

        // test a few variations of the same path
        for uri in &[
            format!("{tmp_path}/bar/foo.lance"),
            format!("{tmp_path}/./bar/foo.lance"),
            format!("{tmp_path}/bar/foo.lance/../foo.lance"),
        ] {
            let (store, path) = ObjectStore::from_uri(uri).await.unwrap();
            let contents = read_from_store(store, &path.child("test_file"))
                .await
                .unwrap();
            assert_eq!(contents, "TEST_CONTENT");
        }
    }

    #[tokio::test]
    async fn test_cloud_paths() {
        let uri = "s3://bucket/foo.lance";
        let (store, path) = ObjectStore::from_uri(uri).await.unwrap();
        assert_eq!(store.scheme, "s3");
        assert_eq!(path.to_string(), "foo.lance");

        let (store, path) = ObjectStore::from_uri("s3+ddb://bucket/foo.lance")
            .await
            .unwrap();
        assert_eq!(store.scheme, "s3");
        assert_eq!(path.to_string(), "foo.lance");

        let (store, path) = ObjectStore::from_uri("gs://bucket/foo.lance")
            .await
            .unwrap();
        assert_eq!(store.scheme, "gs");
        assert_eq!(path.to_string(), "foo.lance");
    }

    async fn test_block_size_used_test_helper(
        uri: &str,
        storage_options: Option<HashMap<String, String>>,
        default_expected_block_size: usize,
    ) {
        // Test the default
        let registry = Arc::new(ObjectStoreRegistry::default());
        let params = ObjectStoreParams {
            storage_options: storage_options.clone(),
            ..ObjectStoreParams::default()
        };
        let (store, _) = ObjectStore::from_uri_and_params(registry, uri, &params)
            .await
            .unwrap();
        assert_eq!(store.block_size, default_expected_block_size);

        // Ensure param is used
        let registry = Arc::new(ObjectStoreRegistry::default());
        let params = ObjectStoreParams {
            block_size: Some(1024),
            storage_options: storage_options.clone(),
            ..ObjectStoreParams::default()
        };
        let (store, _) = ObjectStore::from_uri_and_params(registry, uri, &params)
            .await
            .unwrap();
        assert_eq!(store.block_size, 1024);
    }

    #[rstest]
    #[case("s3://bucket/foo.lance", None)]
    #[case("gs://bucket/foo.lance", None)]
    #[case("az://account/bucket/foo.lance",
      Some(HashMap::from([
            (String::from("account_name"), String::from("account")),
            (String::from("container_name"), String::from("container"))
           ])))]
    #[tokio::test]
    async fn test_block_size_used_cloud(
        #[case] uri: &str,
        #[case] storage_options: Option<HashMap<String, String>>,
    ) {
        test_block_size_used_test_helper(uri, storage_options, 64 * 1024).await;
    }

    #[rstest]
    #[case("file")]
    #[case("file-object-store")]
    #[case("memory:///bucket/foo.lance")]
    #[tokio::test]
    async fn test_block_size_used_file(#[case] prefix: &str) {
        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path().to_str().unwrap().to_owned();
        let path = format!("{tmp_path}/bar/foo.lance/test_file");
        write_to_file(&path, "URL").unwrap();
        let uri = format!("{prefix}:///{path}");
        test_block_size_used_test_helper(&uri, None, 4 * 1024).await;
    }

    #[tokio::test]
    async fn test_relative_paths() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path().to_str().unwrap().to_owned();
        write_to_file(
            &format!("{tmp_path}/bar/foo.lance/test_file"),
            "RELATIVE_URL",
        )
        .unwrap();

        set_current_dir(StdPath::new(&tmp_path)).expect("Error changing current dir");
        let (store, path) = ObjectStore::from_uri("./bar/foo.lance").await.unwrap();

        let contents = read_from_store(store, &path.child("test_file"))
            .await
            .unwrap();
        assert_eq!(contents, "RELATIVE_URL");
    }

    #[tokio::test]
    async fn test_tilde_expansion() {
        let uri = "~/foo.lance";
        write_to_file(&format!("{uri}/test_file"), "TILDE").unwrap();
        let (store, path) = ObjectStore::from_uri(uri).await.unwrap();
        let contents = read_from_store(store, &path.child("test_file"))
            .await
            .unwrap();
        assert_eq!(contents, "TILDE");
    }

    #[tokio::test]
    async fn test_read_directory() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let path = tmp_dir.path();
        create_dir_all(path.join("foo").join("bar")).unwrap();
        create_dir_all(path.join("foo").join("zoo")).unwrap();
        create_dir_all(path.join("foo").join("zoo").join("abc")).unwrap();
        write_to_file(
            path.join("foo").join("test_file").to_str().unwrap(),
            "read_dir",
        )
        .unwrap();
        let (store, base) = ObjectStore::from_uri(path.to_str().unwrap()).await.unwrap();

        let sub_dirs = store.read_dir(base.child("foo")).await.unwrap();
        assert_eq!(sub_dirs, vec!["bar", "zoo", "test_file"]);
    }

    #[tokio::test]
    async fn test_delete_directory() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let path = tmp_dir.path();
        create_dir_all(path.join("foo").join("bar")).unwrap();
        create_dir_all(path.join("foo").join("zoo")).unwrap();
        create_dir_all(path.join("foo").join("zoo").join("abc")).unwrap();
        write_to_file(
            path.join("foo")
                .join("bar")
                .join("test_file")
                .to_str()
                .unwrap(),
            "delete",
        )
        .unwrap();
        write_to_file(path.join("foo").join("top").to_str().unwrap(), "delete_top").unwrap();
        let (store, base) = ObjectStore::from_uri(path.to_str().unwrap()).await.unwrap();
        store.remove_dir_all(base.child("foo")).await.unwrap();

        assert!(!path.join("foo").exists());
    }

    #[derive(Debug)]
    struct TestWrapper {
        called: AtomicBool,

        return_value: Arc<dyn OSObjectStore>,
    }

    impl WrappingObjectStore for TestWrapper {
        fn wrap(&self, _original: Arc<dyn OSObjectStore>) -> Arc<dyn OSObjectStore> {
            self.called.store(true, Ordering::Relaxed);

            // return a mocked value so we can check if the final store is the one we expect
            self.return_value.clone()
        }
    }

    impl TestWrapper {
        fn called(&self) -> bool {
            self.called.load(Ordering::Relaxed)
        }
    }

    #[tokio::test]
    async fn test_wrapping_object_store_option_is_used() {
        // Make a store for the inner store first
        let mock_inner_store: Arc<dyn OSObjectStore> = Arc::new(InMemory::new());
        let registry = Arc::new(ObjectStoreRegistry::default());

        assert_eq!(Arc::strong_count(&mock_inner_store), 1);

        let wrapper = Arc::new(TestWrapper {
            called: AtomicBool::new(false),
            return_value: mock_inner_store.clone(),
        });

        let params = ObjectStoreParams {
            object_store_wrapper: Some(wrapper.clone()),
            ..ObjectStoreParams::default()
        };

        // not called yet
        assert!(!wrapper.called());

        let _ = ObjectStore::from_uri_and_params(registry, "memory:///", &params)
            .await
            .unwrap();

        // called after construction
        assert!(wrapper.called());

        // It's hard to compare two trait pointers, as they point to vtables;
        // use the ref count as a proxy to make sure the store is correctly kept.
        assert_eq!(Arc::strong_count(&mock_inner_store), 2);
    }

    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }

    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // Not called yet
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // fails, but we don't care
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        // Should have been called by now
        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    #[tokio::test]
    async fn test_local_paths() {
        let temp_dir = tempfile::tempdir().unwrap();

        let file_path = temp_dir.path().join("test_file");
        let mut writer = ObjectStore::create_local_writer(file_path.as_path())
            .await
            .unwrap();
        writer.write_all(b"LOCAL").await.unwrap();
        writer.shutdown().await.unwrap();

        let reader = ObjectStore::open_local(file_path.as_path()).await.unwrap();
        let buf = reader.get_range(0..5).await.unwrap();
        assert_eq!(buf.as_bytes(), b"LOCAL");
    }

    #[tokio::test]
    async fn test_read_one() {
        let temp_dir = tempfile::tempdir().unwrap();

        let file_path = temp_dir.path().join("test_file");
        let mut writer = ObjectStore::create_local_writer(file_path.as_path())
            .await
            .unwrap();
        writer.write_all(b"LOCAL").await.unwrap();
        writer.shutdown().await.unwrap();

        let file_path_os = object_store::path::Path::parse(file_path.to_str().unwrap()).unwrap();
        let obj_store = ObjectStore::local();
        let buf = obj_store.read_one_all(&file_path_os).await.unwrap();
        assert_eq!(buf.as_bytes(), b"LOCAL");

        let buf = obj_store.read_one_range(&file_path_os, 0..5).await.unwrap();
        assert_eq!(buf.as_bytes(), b"LOCAL");
    }

    #[tokio::test]
    #[cfg(windows)]
    async fn test_windows_paths() {
        use std::path::Component;
        use std::path::Prefix;
        use std::path::Prefix::*;

        fn get_path_prefix(path: &StdPath) -> Prefix {
            match path.components().next().unwrap() {
                Component::Prefix(prefix_component) => prefix_component.kind(),
                _ => panic!(),
            }
        }

        fn get_drive_letter(prefix: Prefix) -> String {
            match prefix {
                Disk(bytes) => String::from_utf8(vec![bytes]).unwrap(),
                _ => panic!(),
            }
        }

        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path();
        let prefix = get_path_prefix(tmp_path);
        let drive_letter = get_drive_letter(prefix);

        write_to_file(
            &(format!("{drive_letter}:/test_folder/test.lance") + "/test_file"),
            "WINDOWS",
        )
        .unwrap();

        for uri in &[
            format!("{drive_letter}:/test_folder/test.lance"),
            format!("{drive_letter}:\\test_folder\\test.lance"),
        ] {
            let (store, base) = ObjectStore::from_uri(uri).await.unwrap();
            let contents = read_from_store(store, &base.child("test_file"))
                .await
                .unwrap();
            assert_eq!(contents, "WINDOWS");
        }
    }
}