lance_io/
object_store.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Extend [object_store::ObjectStore] functionalities
5
6use std::collections::HashMap;
7use std::ops::Range;
8use std::path::PathBuf;
9use std::str::FromStr;
10use std::sync::Arc;
11use std::time::{Duration, SystemTime};
12
13use async_trait::async_trait;
14use aws_config::default_provider::credentials::DefaultCredentialsChain;
15use aws_credential_types::provider::ProvideCredentials;
16use bytes::Bytes;
17use chrono::{DateTime, Utc};
18use deepsize::DeepSizeOf;
19use futures::{future, stream::BoxStream, StreamExt, TryStreamExt};
20use lance_core::utils::parse::str_is_truthy;
21use lance_core::utils::tokio::get_num_compute_intensive_cpus;
22use object_store::aws::{
23    AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential, AwsCredentialProvider,
24};
25use object_store::azure::MicrosoftAzureBuilder;
26use object_store::gcp::{GcpCredential, GoogleCloudStorageBuilder};
27use object_store::{
28    aws::AmazonS3Builder, azure::AzureConfigKey, gcp::GoogleConfigKey, local::LocalFileSystem,
29    memory::InMemory, CredentialProvider, Error as ObjectStoreError, Result as ObjectStoreResult,
30};
31use object_store::{path::Path, ObjectMeta, ObjectStore as OSObjectStore};
32use object_store::{ClientOptions, DynObjectStore, RetryConfig, StaticCredentialProvider};
33use shellexpand::tilde;
34use snafu::location;
35use tokio::io::AsyncWriteExt;
36use tokio::sync::RwLock;
37use url::Url;
38
39use super::local::LocalObjectReader;
40mod tracing;
41use self::tracing::ObjectStoreTracingExt;
42use crate::object_writer::WriteResult;
43use crate::{object_reader::CloudObjectReader, object_writer::ObjectWriter, traits::Reader};
44use lance_core::{Error, Result};
45
// Local disks tend to do fine with a few threads
// Note: the number of threads here also impacts the number of files
// we need to read in some situations.  So keeping this at 8 keeps the
// RAM on our scanner down.
pub const DEFAULT_LOCAL_IO_PARALLELISM: usize = 8;
// Cloud disks often need many many threads to saturate the network
pub const DEFAULT_CLOUD_IO_PARALLELISM: usize = 64;

/// Default number of times to retry a failed download before giving up.
pub const DEFAULT_DOWNLOAD_RETRY_COUNT: usize = 3;
55
/// Convenience extensions available on any [object_store::ObjectStore]
/// (see the blanket impl below).
#[async_trait]
pub trait ObjectStoreExt {
    /// Returns true if the file exists.
    async fn exists(&self, path: &Path) -> Result<bool>;

    /// Read all files (start from base directory) recursively
    ///
    /// unmodified_since can be specified to only return files that have not been modified since the given time.
    async fn read_dir_all<'a>(
        &'a self,
        dir_path: impl Into<&Path> + Send,
        unmodified_since: Option<DateTime<Utc>>,
    ) -> Result<BoxStream<'a, Result<ObjectMeta>>>;
}
70
71#[async_trait]
72impl<O: OSObjectStore + ?Sized> ObjectStoreExt for O {
73    async fn read_dir_all<'a>(
74        &'a self,
75        dir_path: impl Into<&Path> + Send,
76        unmodified_since: Option<DateTime<Utc>>,
77    ) -> Result<BoxStream<'a, Result<ObjectMeta>>> {
78        let mut output = self.list(Some(dir_path.into()));
79        if let Some(unmodified_since_val) = unmodified_since {
80            output = output
81                .try_filter(move |file| future::ready(file.last_modified < unmodified_since_val))
82                .boxed();
83        }
84        Ok(output.map_err(|e| e.into()).boxed())
85    }
86
87    async fn exists(&self, path: &Path) -> Result<bool> {
88        match self.head(path).await {
89            Ok(_) => Ok(true),
90            Err(object_store::Error::NotFound { path: _, source: _ }) => Ok(false),
91            Err(e) => Err(e.into()),
92        }
93    }
94}
95
/// Wraps [ObjectStore](object_store::ObjectStore)
#[derive(Debug, Clone)]
pub struct ObjectStore {
    // Inner object store
    pub inner: Arc<dyn OSObjectStore>,
    // URI scheme this store was built for (e.g. "file", "memory").
    scheme: String,
    // Preferred I/O block size, in bytes, for readers created by this store.
    block_size: usize,
    /// Whether to use constant size upload parts for multipart uploads. This
    /// is only necessary for Cloudflare R2.
    pub use_constant_size_upload_parts: bool,
    /// Whether we can assume that the list of files is lexically ordered. This
    /// is true for object stores, but not for local filesystems.
    pub list_is_lexically_ordered: bool,
    // Number of concurrent I/O operations to issue; may be overridden at
    // read time by the LANCE_IO_THREADS environment variable.
    io_parallelism: usize,
    /// Number of times to retry a failed download
    download_retry_count: usize,
}
113
114impl DeepSizeOf for ObjectStore {
115    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
116        // We aren't counting `inner` here which is problematic but an ObjectStore
117        // shouldn't be too big.  The only exception might be the write cache but, if
118        // the writer cache has data, it means we're using it somewhere else that isn't
119        // a cache and so that doesn't really count.
120        self.scheme.deep_size_of_children(context) + self.block_size.deep_size_of_children(context)
121    }
122}
123
124impl std::fmt::Display for ObjectStore {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        write!(f, "ObjectStore({})", self.scheme)
127    }
128}
129
/// Factory for building [ObjectStore] instances for a URI scheme.
///
/// Providers are registered by scheme in an [ObjectStoreRegistry].
pub trait ObjectStoreProvider: std::fmt::Debug + Sync + Send {
    /// Build a store rooted at `base_path`, configured by `params`.
    fn new_store(&self, base_path: Url, params: &ObjectStoreParams) -> Result<ObjectStore>;
}
133
/// Maps URI schemes to the [ObjectStoreProvider] that handles them.
#[derive(Default, Debug)]
pub struct ObjectStoreRegistry {
    // scheme (e.g. "s3") -> provider used to build stores for that scheme
    providers: HashMap<String, Arc<dyn ObjectStoreProvider>>,
}
138
139impl ObjectStoreRegistry {
140    pub fn insert(&mut self, scheme: &str, provider: Arc<dyn ObjectStoreProvider>) {
141        self.providers.insert(scheme.into(), provider);
142    }
143}
144
145const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";
146
/// Adapt an AWS SDK cred into object_store credentials
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    /// Underlying AWS SDK credentials provider.
    pub inner: Arc<dyn ProvideCredentials>,

    // RefCell can't be shared across threads, so we use HashMap
    // (keyed by AWS_CREDS_CACHE_KEY) behind an async RwLock.
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    // The amount of time before expiry to refresh credentials
    credentials_refresh_offset: Duration,
}
158
159impl AwsCredentialAdapter {
160    pub fn new(
161        provider: Arc<dyn ProvideCredentials>,
162        credentials_refresh_offset: Duration,
163    ) -> Self {
164        Self {
165            inner: provider,
166            cache: Arc::new(RwLock::new(HashMap::new())),
167            credentials_refresh_offset,
168        }
169    }
170}
171
172#[async_trait]
173impl CredentialProvider for AwsCredentialAdapter {
174    type Credential = ObjectStoreAwsCredential;
175
176    async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
177        let cached_creds = {
178            let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
179            let expired = cache_value
180                .clone()
181                .map(|cred| {
182                    cred.expiry()
183                        .map(|exp| {
184                            exp.checked_sub(self.credentials_refresh_offset)
185                                .expect("this time should always be valid")
186                                < SystemTime::now()
187                        })
188                        // no expiry is never expire
189                        .unwrap_or(false)
190                })
191                .unwrap_or(true); // no cred is the same as expired;
192            if expired {
193                None
194            } else {
195                cache_value.clone()
196            }
197        };
198
199        if let Some(creds) = cached_creds {
200            Ok(Arc::new(Self::Credential {
201                key_id: creds.access_key_id().to_string(),
202                secret_key: creds.secret_access_key().to_string(),
203                token: creds.session_token().map(|s| s.to_string()),
204            }))
205        } else {
206            let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
207                |e| Error::Internal {
208                    message: format!("Failed to get AWS credentials: {}", e),
209                    location: location!(),
210                },
211            )?);
212
213            self.cache
214                .write()
215                .await
216                .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
217
218            Ok(Arc::new(Self::Credential {
219                key_id: refreshed_creds.access_key_id().to_string(),
220                secret_key: refreshed_creds.secret_access_key().to_string(),
221                token: refreshed_creds.session_token().map(|s| s.to_string()),
222            }))
223        }
224    }
225}
226
227/// Figure out the S3 region of the bucket.
228///
229/// This resolves in order of precedence:
230/// 1. The region provided in the storage options
231/// 2. (If endpoint is not set), the region returned by the S3 API for the bucket
232///
233/// It can return None if no region is provided and the endpoint is set.
234async fn resolve_s3_region(
235    url: &Url,
236    storage_options: &HashMap<AmazonS3ConfigKey, String>,
237) -> Result<Option<String>> {
238    if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
239        Ok(Some(region.clone()))
240    } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
241        // If no endpoint is set, we can assume this is AWS S3 and the region
242        // can be resolved from the bucket.
243        let bucket = url.host_str().ok_or_else(|| {
244            Error::invalid_input(
245                format!("Could not parse bucket from url: {}", url),
246                location!(),
247            )
248        })?;
249
250        let mut client_options = ClientOptions::default();
251        for (key, value) in storage_options {
252            if let AmazonS3ConfigKey::Client(client_key) = key {
253                client_options = client_options.with_config(*client_key, value.clone());
254            }
255        }
256
257        let bucket_region =
258            object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
259        Ok(Some(bucket_region))
260    } else {
261        Ok(None)
262    }
263}
264
265/// Build AWS credentials
266///
267/// This resolves credentials from the following sources in order:
268/// 1. An explicit `credentials` provider
269/// 2. Explicit credentials in storage_options (as in `aws_access_key_id`,
270///    `aws_secret_access_key`, `aws_session_token`)
271/// 3. The default credential provider chain from AWS SDK.
272///
273/// `credentials_refresh_offset` is the amount of time before expiry to refresh credentials.
274pub async fn build_aws_credential(
275    credentials_refresh_offset: Duration,
276    credentials: Option<AwsCredentialProvider>,
277    storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
278    region: Option<String>,
279) -> Result<(AwsCredentialProvider, String)> {
280    // TODO: make this return no credential provider not using AWS
281    use aws_config::meta::region::RegionProviderChain;
282    const DEFAULT_REGION: &str = "us-west-2";
283
284    let region = if let Some(region) = region {
285        region
286    } else {
287        RegionProviderChain::default_provider()
288            .or_else(DEFAULT_REGION)
289            .region()
290            .await
291            .map(|r| r.as_ref().to_string())
292            .unwrap_or(DEFAULT_REGION.to_string())
293    };
294
295    if let Some(creds) = credentials {
296        Ok((creds, region))
297    } else if let Some(creds) = storage_options.and_then(extract_static_s3_credentials) {
298        Ok((Arc::new(creds), region))
299    } else {
300        let credentials_provider = DefaultCredentialsChain::builder().build().await;
301
302        Ok((
303            Arc::new(AwsCredentialAdapter::new(
304                Arc::new(credentials_provider),
305                credentials_refresh_offset,
306            )),
307            region,
308        ))
309    }
310}
311
312fn extract_static_s3_credentials(
313    options: &HashMap<AmazonS3ConfigKey, String>,
314) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
315    let key_id = options
316        .get(&AmazonS3ConfigKey::AccessKeyId)
317        .map(|s| s.to_string());
318    let secret_key = options
319        .get(&AmazonS3ConfigKey::SecretAccessKey)
320        .map(|s| s.to_string());
321    let token = options
322        .get(&AmazonS3ConfigKey::Token)
323        .map(|s| s.to_string());
324    match (key_id, secret_key, token) {
325        (Some(key_id), Some(secret_key), token) => {
326            Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
327                key_id,
328                secret_key,
329                token,
330            }))
331        }
332        _ => None,
333    }
334}
335
/// A hook for decorating the inner [object_store::ObjectStore] before it is
/// used by an [ObjectStore] (applied in [ObjectStore::from_uri_and_params]).
pub trait WrappingObjectStore: std::fmt::Debug + Send + Sync {
    /// Return the store to use in place of `original`.
    fn wrap(&self, original: Arc<dyn OSObjectStore>) -> Arc<dyn OSObjectStore>;
}
339
/// Parameters to create an [ObjectStore]
///
#[derive(Debug, Clone)]
pub struct ObjectStoreParams {
    /// Block size (bytes) for readers; when unset, a per-scheme default is used.
    pub block_size: Option<usize>,
    /// Pre-built store (with its base URL) to use instead of parsing a URI.
    pub object_store: Option<(Arc<DynObjectStore>, Url)>,
    /// Amount of time before credential expiry at which a refresh is triggered.
    pub s3_credentials_refresh_offset: Duration,
    /// Explicit AWS credential provider; takes precedence over storage options
    /// and the SDK default chain (see [build_aws_credential]).
    pub aws_credentials: Option<AwsCredentialProvider>,
    /// Optional wrapper applied to the inner store after construction.
    pub object_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
    /// Provider-specific key/value configuration options.
    pub storage_options: Option<HashMap<String, String>>,
    /// Use constant size upload parts for multipart uploads. Only necessary
    /// for Cloudflare R2, which doesn't support variable size parts. When this
    /// is false, max upload size is 2.5TB. When this is true, the max size is
    /// 50GB.
    pub use_constant_size_upload_parts: bool,
    /// Override for whether listing results are lexically ordered.
    pub list_is_lexically_ordered: Option<bool>,
}
357
358impl Default for ObjectStoreParams {
359    fn default() -> Self {
360        Self {
361            object_store: None,
362            block_size: None,
363            s3_credentials_refresh_offset: Duration::from_secs(60),
364            aws_credentials: None,
365            object_store_wrapper: None,
366            storage_options: None,
367            use_constant_size_upload_parts: false,
368            list_is_lexically_ordered: None,
369        }
370    }
371}
372
373impl ObjectStoreParams {
374    /// Create a new instance of [`ObjectStoreParams`] based on the AWS credentials.
375    pub fn with_aws_credentials(
376        aws_credentials: Option<AwsCredentialProvider>,
377        region: Option<String>,
378    ) -> Self {
379        Self {
380            aws_credentials,
381            storage_options: region
382                .map(|region| [("region".into(), region)].iter().cloned().collect()),
383            ..Default::default()
384        }
385    }
386}
387
impl ObjectStore {
    /// Parse from a string URI.
    ///
    /// Returns the ObjectStore instance and the absolute path to the object.
    pub async fn from_uri(uri: &str) -> Result<(Self, Path)> {
        // Fresh, empty registry: only the built-in schemes are available here.
        let registry = Arc::new(ObjectStoreRegistry::default());

        Self::from_uri_and_params(registry, uri, &ObjectStoreParams::default()).await
    }
397
    /// Parse from a string URI.
    ///
    /// Returns the ObjectStore instance and the absolute path to the object.
    pub async fn from_uri_and_params(
        registry: Arc<ObjectStoreRegistry>,
        uri: &str,
        params: &ObjectStoreParams,
    ) -> Result<(Self, Path)> {
        // If the caller supplied a ready-made store, use it directly
        // (optionally wrapped) and skip URI parsing entirely.
        if let Some((store, path)) = params.object_store.as_ref() {
            let mut inner = store.clone();
            if let Some(wrapper) = params.object_store_wrapper.as_ref() {
                inner = wrapper.wrap(inner);
            }
            let store = Self {
                inner,
                scheme: path.scheme().to_string(),
                block_size: params.block_size.unwrap_or(64 * 1024),
                use_constant_size_upload_parts: params.use_constant_size_upload_parts,
                list_is_lexically_ordered: params.list_is_lexically_ordered.unwrap_or_default(),
                io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
                download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT,
            };
            let path = Path::from(path.path());
            return Ok((store, path));
        }
        let (object_store, path) = match Url::parse(uri) {
            Ok(url) if url.scheme().len() == 1 && cfg!(windows) => {
                // On Windows, the drive is parsed as a scheme
                Self::from_path(uri)
            }
            Ok(url) => {
                let store = Self::new_from_url(registry, url.clone(), params.clone()).await?;
                Ok((store, Path::from(url.path())))
            }
            // Not a URL at all: treat the string as a local filesystem path.
            Err(_) => Self::from_path(uri),
        }?;

        // Apply the optional wrapper to whichever store was resolved above.
        Ok((
            Self {
                inner: params
                    .object_store_wrapper
                    .as_ref()
                    .map(|w| w.wrap(object_store.inner.clone()))
                    .unwrap_or(object_store.inner),
                ..object_store
            },
            path,
        ))
    }
447
448    pub fn from_path_with_scheme(str_path: &str, scheme: &str) -> Result<(Self, Path)> {
449        let expanded = tilde(str_path).to_string();
450
451        let mut expanded_path = path_abs::PathAbs::new(expanded)
452            .unwrap()
453            .as_path()
454            .to_path_buf();
455        // path_abs::PathAbs::new(".") returns an empty string.
456        if let Some(s) = expanded_path.as_path().to_str() {
457            if s.is_empty() {
458                expanded_path = std::env::current_dir()?;
459            }
460        }
461        Ok((
462            Self {
463                inner: Arc::new(LocalFileSystem::new()).traced(),
464                scheme: String::from(scheme),
465                block_size: 4 * 1024, // 4KB block size
466                use_constant_size_upload_parts: false,
467                list_is_lexically_ordered: false,
468                io_parallelism: DEFAULT_LOCAL_IO_PARALLELISM,
469                download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT,
470            },
471            Path::from_absolute_path(expanded_path.as_path())?,
472        ))
473    }
474
    /// Create a local-filesystem store for `str_path` with the `file` scheme.
    pub fn from_path(str_path: &str) -> Result<(Self, Path)> {
        Self::from_path_with_scheme(str_path, "file")
    }
478
    /// Build a store for `url` by dispatching to `configure_store`.
    async fn new_from_url(
        registry: Arc<ObjectStoreRegistry>,
        url: Url,
        params: ObjectStoreParams,
    ) -> Result<Self> {
        configure_store(registry, url.as_str(), params).await
    }
486
487    /// Local object store.
488    pub fn local() -> Self {
489        Self {
490            inner: Arc::new(LocalFileSystem::new()).traced(),
491            scheme: String::from("file"),
492            block_size: 4 * 1024, // 4KB block size
493            use_constant_size_upload_parts: false,
494            list_is_lexically_ordered: false,
495            io_parallelism: DEFAULT_LOCAL_IO_PARALLELISM,
496            download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT,
497        }
498    }
499
500    /// Create a in-memory object store directly for testing.
501    pub fn memory() -> Self {
502        Self {
503            inner: Arc::new(InMemory::new()).traced(),
504            scheme: String::from("memory"),
505            block_size: 4 * 1024,
506            use_constant_size_upload_parts: false,
507            list_is_lexically_ordered: true,
508            io_parallelism: get_num_compute_intensive_cpus(),
509            download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT,
510        }
511    }
512
513    /// Returns true if the object store pointed to a local file system.
514    pub fn is_local(&self) -> bool {
515        self.scheme == "file"
516    }
517
518    pub fn is_cloud(&self) -> bool {
519        self.scheme != "file" && self.scheme != "memory"
520    }
521
    /// Preferred I/O block size, in bytes.
    pub fn block_size(&self) -> usize {
        self.block_size
    }

    /// Override the I/O block size (bytes).
    pub fn set_block_size(&mut self, new_size: usize) {
        self.block_size = new_size;
    }

    /// Override the number of concurrent I/O operations.
    pub fn set_io_parallelism(&mut self, io_parallelism: usize) {
        self.io_parallelism = io_parallelism;
    }
533
534    pub fn io_parallelism(&self) -> usize {
535        std::env::var("LANCE_IO_THREADS")
536            .map(|val| val.parse::<usize>().unwrap())
537            .unwrap_or(self.io_parallelism)
538    }
539
540    /// Open a file for path.
541    ///
542    /// Parameters
543    /// - ``path``: Absolute path to the file.
544    pub async fn open(&self, path: &Path) -> Result<Box<dyn Reader>> {
545        match self.scheme.as_str() {
546            "file" => LocalObjectReader::open(path, self.block_size, None).await,
547            _ => Ok(Box::new(CloudObjectReader::new(
548                self.inner.clone(),
549                path.clone(),
550                self.block_size,
551                None,
552                self.download_retry_count,
553            )?)),
554        }
555    }
556
557    /// Open a reader for a file with known size.
558    ///
559    /// This size may either have been retrieved from a list operation or
560    /// cached metadata. By passing in the known size, we can skip a HEAD / metadata
561    /// call.
562    pub async fn open_with_size(&self, path: &Path, known_size: usize) -> Result<Box<dyn Reader>> {
563        match self.scheme.as_str() {
564            "file" => LocalObjectReader::open(path, self.block_size, Some(known_size)).await,
565            _ => Ok(Box::new(CloudObjectReader::new(
566                self.inner.clone(),
567                path.clone(),
568                self.block_size,
569                Some(known_size),
570                self.download_retry_count,
571            )?)),
572        }
573    }
574
575    /// Create an [ObjectWriter] from local [std::path::Path]
576    pub async fn create_local_writer(path: &std::path::Path) -> Result<ObjectWriter> {
577        let object_store = Self::local();
578        let os_path = Path::from(path.to_str().unwrap());
579        object_store.create(&os_path).await
580    }
581
582    /// Open an [Reader] from local [std::path::Path]
583    pub async fn open_local(path: &std::path::Path) -> Result<Box<dyn Reader>> {
584        let object_store = Self::local();
585        let os_path = Path::from(path.to_str().unwrap());
586        object_store.open(&os_path).await
587    }
588
    /// Create a new file.
    pub async fn create(&self, path: &Path) -> Result<ObjectWriter> {
        ObjectWriter::new(self, path).await
    }

    /// A helper function to create a file and write content to it.
    ///
    /// Returns the result reported by the writer when it is shut down.
    pub async fn put(&self, path: &Path, content: &[u8]) -> Result<WriteResult> {
        let mut writer = self.create(path).await?;
        writer.write_all(content).await?;
        writer.shutdown().await
    }
600
601    pub async fn delete(&self, path: &Path) -> Result<()> {
602        self.inner.delete(path).await?;
603        Ok(())
604    }
605
606    pub async fn copy(&self, from: &Path, to: &Path) -> Result<()> {
607        Ok(self.inner.copy(from, to).await?)
608    }
609
610    /// Read a directory (start from base directory) and returns all sub-paths in the directory.
611    pub async fn read_dir(&self, dir_path: impl Into<Path>) -> Result<Vec<String>> {
612        let path = dir_path.into();
613        let path = Path::parse(&path)?;
614        let output = self.inner.list_with_delimiter(Some(&path)).await?;
615        Ok(output
616            .common_prefixes
617            .iter()
618            .chain(output.objects.iter().map(|o| &o.location))
619            .map(|s| s.filename().unwrap().to_string())
620            .collect())
621    }
622
    /// Read all files (start from base directory) recursively
    ///
    /// unmodified_since can be specified to only return files that have not been modified since the given time.
    pub async fn read_dir_all(
        &self,
        dir_path: impl Into<&Path> + Send,
        unmodified_since: Option<DateTime<Utc>>,
    ) -> Result<BoxStream<Result<ObjectMeta>>> {
        // Delegates to the blanket [ObjectStoreExt] implementation.
        self.inner.read_dir_all(dir_path, unmodified_since).await
    }
633
    /// Remove a directory recursively.
    pub async fn remove_dir_all(&self, dir_path: impl Into<Path>) -> Result<()> {
        let path = dir_path.into();
        let path = Path::parse(&path)?;

        if self.is_local() {
            // Local file system needs to delete directories as well.
            return super::local::remove_dir_all(&path);
        }
        // Object stores have no real directories: list every object under the
        // prefix and bulk-delete them all.
        let sub_entries = self
            .inner
            .list(Some(&path))
            .map(|m| m.map(|meta| meta.location))
            .boxed();
        self.inner
            .delete_stream(sub_entries)
            .try_collect::<Vec<_>>()
            .await?;
        Ok(())
    }
654
    /// Delete every path yielded by `locations`, streaming back the deleted
    /// paths (or errors) as the deletions complete.
    pub fn remove_stream<'a>(
        &'a self,
        locations: BoxStream<'a, Result<Path>>,
    ) -> BoxStream<'a, Result<Path>> {
        self.inner
            .delete_stream(locations.err_into::<ObjectStoreError>().boxed())
            .err_into::<Error>()
            .boxed()
    }
664
665    /// Check a file exists.
666    pub async fn exists(&self, path: &Path) -> Result<bool> {
667        match self.inner.head(path).await {
668            Ok(_) => Ok(true),
669            Err(object_store::Error::NotFound { path: _, source: _ }) => Ok(false),
670            Err(e) => Err(e.into()),
671        }
672    }
673
    /// Get file size.
    pub async fn size(&self, path: &Path) -> Result<usize> {
        // Only the size field of the HEAD metadata is needed here.
        Ok(self.inner.head(path).await?.size)
    }

    /// Convenience function to open a reader and read all the bytes
    pub async fn read_one_all(&self, path: &Path) -> Result<Bytes> {
        let reader = self.open(path).await?;
        Ok(reader.get_all().await?)
    }

    /// Convenience function open a reader and make a single request
    ///
    /// If you will be making multiple requests to the path it is more efficient to call [`Self::open`]
    /// and then call [`Reader::get_range`] multiple times.
    pub async fn read_one_range(&self, path: &Path, range: Range<usize>) -> Result<Bytes> {
        let reader = self.open(path).await?;
        Ok(reader.get_range(range).await?)
    }
}
694
/// Options that can be set for multiple object stores
///
/// Keys are parsed case-insensitively (see the [FromStr] impl below).
#[derive(PartialEq, Eq, Hash, Clone, Debug, Copy)]
pub enum LanceConfigKey {
    /// Number of times to retry a download that fails
    DownloadRetryCount,
}
701
702impl FromStr for LanceConfigKey {
703    type Err = Error;
704
705    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
706        match s.to_ascii_lowercase().as_str() {
707            "download_retry_count" => Ok(Self::DownloadRetryCount),
708            _ => Err(Error::InvalidInput {
709                source: format!("Invalid LanceConfigKey: {}", s).into(),
710                location: location!(),
711            }),
712        }
713    }
714}
715
/// String key/value configuration for building object stores, assembled from
/// caller-supplied options plus selected environment variables (see `new`).
#[derive(Clone, Debug, Default)]
pub struct StorageOptions(pub HashMap<String, String>);
718
impl StorageOptions {
    /// Create a new instance of [`StorageOptions`]
    ///
    /// A fixed set of environment variables overrides the corresponding
    /// option keys; later entries in the table win when several are set.
    pub fn new(options: HashMap<String, String>) -> Self {
        let mut options = options;
        for (env_key, option_key) in [
            ("AZURE_STORAGE_ALLOW_HTTP", "allow_http"),
            ("AZURE_STORAGE_USE_HTTP", "allow_http"),
            ("AWS_ALLOW_HTTP", "allow_http"),
            ("OBJECT_STORE_CLIENT_MAX_RETRIES", "client_max_retries"),
            ("OBJECT_STORE_CLIENT_RETRY_TIMEOUT", "client_retry_timeout"),
        ] {
            if let Ok(value) = std::env::var(env_key) {
                options.insert(option_key.into(), value);
            }
        }
        Self(options)
    }
740
741    /// Add values from the environment to storage options
742    pub fn with_env_azure(&mut self) {
743        for (os_key, os_value) in std::env::vars_os() {
744            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
745                if let Ok(config_key) = AzureConfigKey::from_str(&key.to_ascii_lowercase()) {
746                    if !self.0.contains_key(config_key.as_ref()) {
747                        self.0
748                            .insert(config_key.as_ref().to_string(), value.to_string());
749                    }
750                }
751            }
752        }
753    }
754
755    /// Add values from the environment to storage options
756    pub fn with_env_gcs(&mut self) {
757        for (os_key, os_value) in std::env::vars_os() {
758            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
759                let lowercase_key = key.to_ascii_lowercase();
760                let token_key = "google_storage_token";
761
762                if let Ok(config_key) = GoogleConfigKey::from_str(&lowercase_key) {
763                    if !self.0.contains_key(config_key.as_ref()) {
764                        self.0
765                            .insert(config_key.as_ref().to_string(), value.to_string());
766                    }
767                }
768                // Check for GOOGLE_STORAGE_TOKEN until GoogleConfigKey supports storage token
769                else if lowercase_key == token_key && !self.0.contains_key(token_key) {
770                    self.0.insert(token_key.to_string(), value.to_string());
771                }
772            }
773        }
774    }
775
776    /// Add values from the environment to storage options
777    pub fn with_env_s3(&mut self) {
778        for (os_key, os_value) in std::env::vars_os() {
779            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
780                if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
781                    if !self.0.contains_key(config_key.as_ref()) {
782                        self.0
783                            .insert(config_key.as_ref().to_string(), value.to_string());
784                    }
785                }
786            }
787        }
788    }
789
790    /// Denotes if unsecure connections via http are allowed
791    pub fn allow_http(&self) -> bool {
792        self.0.iter().any(|(key, value)| {
793            key.to_ascii_lowercase().contains("allow_http") & str_is_truthy(value)
794        })
795    }
796
797    /// Number of times to retry a download that fails
798    pub fn download_retry_count(&self) -> usize {
799        self.0
800            .iter()
801            .find(|(key, _)| key.eq_ignore_ascii_case("download_retry_count"))
802            .map(|(_, value)| value.parse::<usize>().unwrap_or(3))
803            .unwrap_or(3)
804    }
805
    /// Max retry times to set in RetryConfig for object store client
    ///
    /// Defaults to 10 when the option is missing or unparseable.
    pub fn client_max_retries(&self) -> usize {
        self.0
            .iter()
            .find(|(key, _)| key.eq_ignore_ascii_case("client_max_retries"))
            .and_then(|(_, value)| value.parse::<usize>().ok())
            .unwrap_or(10)
    }

    /// Seconds of timeout to set in RetryConfig for object store client
    ///
    /// Defaults to 180 seconds when the option is missing or unparseable.
    pub fn client_retry_timeout(&self) -> u64 {
        self.0
            .iter()
            .find(|(key, _)| key.eq_ignore_ascii_case("client_retry_timeout"))
            .and_then(|(_, value)| value.parse::<u64>().ok())
            .unwrap_or(180)
    }
823
824    /// Subset of options relevant for azure storage
825    pub fn as_azure_options(&self) -> HashMap<AzureConfigKey, String> {
826        self.0
827            .iter()
828            .filter_map(|(key, value)| {
829                let az_key = AzureConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
830                Some((az_key, value.clone()))
831            })
832            .collect()
833    }
834
835    /// Subset of options relevant for s3 storage
836    pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
837        self.0
838            .iter()
839            .filter_map(|(key, value)| {
840                let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
841                Some((s3_key, value.clone()))
842            })
843            .collect()
844    }
845
846    /// Subset of options relevant for gcs storage
847    pub fn as_gcs_options(&self) -> HashMap<GoogleConfigKey, String> {
848        self.0
849            .iter()
850            .filter_map(|(key, value)| {
851                let gcs_key = GoogleConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
852                Some((gcs_key, value.clone()))
853            })
854            .collect()
855    }
856
857    pub fn get(&self, key: &str) -> Option<&String> {
858        self.0.get(key)
859    }
860}
861
862impl From<HashMap<String, String>> for StorageOptions {
863    fn from(value: HashMap<String, String>) -> Self {
864        Self::new(value)
865    }
866}
867
/// Build an [`ObjectStore`] for `url`, dispatching on the URL scheme.
///
/// Explicitly provided `storage_options` take precedence over values picked
/// up from the environment (each scheme branch calls its own `with_env_*`
/// merge). Schemes not handled here are delegated to any provider registered
/// under that scheme in `registry`; otherwise an "unsupported scheme" error
/// is returned.
async fn configure_store(
    registry: Arc<ObjectStoreRegistry>,
    url: &str,
    options: ObjectStoreParams,
) -> Result<ObjectStore> {
    let mut storage_options = StorageOptions(options.storage_options.clone().unwrap_or_default());
    let download_retry_count = storage_options.download_retry_count();
    // Normalize the location: parse as URL or canonicalize a local path,
    // and strip trailing slashes.
    let mut url = ensure_table_uri(url)?;
    // Block size: On local file systems, we use 4KB block size. On cloud
    // object stores, we use 64KB block size. This is generally the largest
    // block size where we don't see a latency penalty.
    let file_block_size = options.block_size.unwrap_or(4 * 1024);
    let cloud_block_size = options.block_size.unwrap_or(64 * 1024);
    let max_retries = storage_options.client_max_retries();
    let retry_timeout = storage_options.client_retry_timeout();
    let retry_config = RetryConfig {
        backoff: Default::default(),
        max_retries,
        retry_timeout: Duration::from_secs(retry_timeout),
    };
    match url.scheme() {
        "s3" | "s3+ddb" => {
            // Fill in any S3 options not set explicitly from the environment.
            storage_options.with_env_s3();

            // if url.scheme() == "s3+ddb" && options.commit_handler.is_some() {
            //     return Err(Error::InvalidInput {
            //         source: "`s3+ddb://` scheme and custom commit handler are mutually exclusive"
            //             .into(),
            //         location: location!(),
            //     });
            // }

            // Narrow to the S3-specific subset (shadows the generic
            // `StorageOptions` wrapper above).
            let mut storage_options = storage_options.as_s3_options();
            let region = resolve_s3_region(&url, &storage_options).await?;
            let (aws_creds, region) = build_aws_credential(
                options.s3_credentials_refresh_offset,
                options.aws_credentials.clone(),
                Some(&storage_options),
                region,
            )
            .await?;

            // This will be default in next version of object store.
            // https://github.com/apache/arrow-rs/pull/7181
            storage_options
                .entry(AmazonS3ConfigKey::ConditionalPut)
                .or_insert_with(|| "etag".to_string());

            // Cloudflare does not support varying part sizes.
            let use_constant_size_upload_parts = storage_options
                .get(&AmazonS3ConfigKey::Endpoint)
                .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
                .unwrap_or(false);

            // before creating the OSObjectStore we need to rewrite the url to drop ddb related parts
            url.set_scheme("s3").map_err(|()| Error::Internal {
                message: "could not set scheme".into(),
                location: location!(),
            })?;

            url.set_query(None);

            // we can't use parse_url_opts here because we need to manually set the credentials provider
            let mut builder = AmazonS3Builder::new();
            for (key, value) in storage_options {
                builder = builder.with_config(key, value);
            }
            builder = builder
                .with_url(url.as_ref())
                .with_credentials(aws_creds)
                .with_retry(retry_config)
                .with_region(region);
            let store = builder.build()?;

            Ok(ObjectStore {
                inner: Arc::new(store).traced(),
                scheme: String::from(url.scheme()),
                block_size: cloud_block_size,
                use_constant_size_upload_parts,
                list_is_lexically_ordered: true,
                io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
                download_retry_count,
            })
        }
        "gs" => {
            storage_options.with_env_gcs();
            let mut builder = GoogleCloudStorageBuilder::new()
                .with_url(url.as_ref())
                .with_retry(retry_config);
            for (key, value) in storage_options.as_gcs_options() {
                builder = builder.with_config(key, value);
            }
            // A bare bearer token is not a `GoogleConfigKey`, so it is wired
            // up through a static credential provider instead of the builder
            // config map. NOTE(review): this lookup is exact-key; it relies
            // on `with_env_gcs` inserting the token under this lowercase key.
            let token_key = "google_storage_token";
            if let Some(storage_token) = storage_options.get(token_key) {
                let credential = GcpCredential {
                    bearer: storage_token.to_string(),
                };
                let credential_provider = Arc::new(StaticCredentialProvider::new(credential)) as _;
                builder = builder.with_credentials(credential_provider);
            }
            let store = builder.build()?;
            let store = Arc::new(store).traced();

            Ok(ObjectStore {
                inner: store,
                scheme: String::from("gs"),
                block_size: cloud_block_size,
                use_constant_size_upload_parts: false,
                list_is_lexically_ordered: true,
                io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
                download_retry_count,
            })
        }
        "az" => {
            storage_options.with_env_azure();
            let mut builder = MicrosoftAzureBuilder::new()
                .with_url(url.as_ref())
                .with_retry(retry_config);
            for (key, value) in storage_options.as_azure_options() {
                builder = builder.with_config(key, value);
            }
            let store = builder.build()?;
            let store = Arc::new(store).traced();

            Ok(ObjectStore {
                inner: store,
                scheme: String::from("az"),
                block_size: cloud_block_size,
                use_constant_size_upload_parts: false,
                list_is_lexically_ordered: true,
                io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
                download_retry_count,
            })
        }
        // we have a bypass logic to use `tokio::fs` directly to lower overhead
        // however this makes testing harder as we can't use the same code path
        // "file-object-store" forces local file system dataset to use the same
        // code path as cloud object stores
        "file" => {
            let mut object_store = ObjectStore::from_path(url.path())?.0;
            object_store.set_block_size(file_block_size);
            Ok(object_store)
        }
        "file-object-store" => {
            let mut object_store =
                ObjectStore::from_path_with_scheme(url.path(), "file-object-store")?.0;
            object_store.set_block_size(file_block_size);
            Ok(object_store)
        }
        "memory" => Ok(ObjectStore {
            inner: Arc::new(InMemory::new()).traced(),
            scheme: String::from("memory"),
            // In-memory stores use the smaller local block size.
            block_size: file_block_size,
            use_constant_size_upload_parts: false,
            list_is_lexically_ordered: true,
            io_parallelism: get_num_compute_intensive_cpus(),
            download_retry_count,
        }),
        unknown_scheme => {
            // Fall back to user-registered providers before rejecting.
            if let Some(provider) = registry.providers.get(unknown_scheme) {
                provider.new_store(url, &options)
            } else {
                let err = lance_core::Error::from(object_store::Error::NotSupported {
                    source: format!("Unsupported URI scheme: {} in url {}", unknown_scheme, url)
                        .into(),
                });
                Err(err)
            }
        }
    }
}
1039
1040impl ObjectStore {
1041    #[allow(clippy::too_many_arguments)]
1042    pub fn new(
1043        store: Arc<DynObjectStore>,
1044        location: Url,
1045        block_size: Option<usize>,
1046        wrapper: Option<Arc<dyn WrappingObjectStore>>,
1047        use_constant_size_upload_parts: bool,
1048        list_is_lexically_ordered: bool,
1049        io_parallelism: usize,
1050        download_retry_count: usize,
1051    ) -> Self {
1052        let scheme = location.scheme();
1053        let block_size = block_size.unwrap_or_else(|| infer_block_size(scheme));
1054
1055        let store = match wrapper {
1056            Some(wrapper) => wrapper.wrap(store),
1057            None => store,
1058        };
1059
1060        Self {
1061            inner: store,
1062            scheme: scheme.into(),
1063            block_size,
1064            use_constant_size_upload_parts,
1065            list_is_lexically_ordered,
1066            io_parallelism,
1067            download_retry_count,
1068        }
1069    }
1070}
1071
/// Pick a default block size for a storage scheme.
///
/// Local file systems get 4KB blocks; any other scheme is assumed to be a
/// cloud object store and gets 64KB — generally the largest block size
/// where we don't see a latency penalty.
fn infer_block_size(scheme: &str) -> usize {
    if scheme == "file" {
        4 * 1024
    } else {
        64 * 1024
    }
}
1081
/// Attempt to create a Url from given table location.
///
/// The location could be:
///  * A valid URL, which will be parsed and returned
///  * A path to a local directory, which will be canonicalized and converted
///    to a `file://` URL.
///
/// Note that a local path must already exist: it is resolved with
/// [`std::fs::canonicalize`], which fails (surfaced as
/// [`Error::DatasetNotFound`]) when the path is missing. Nothing is created
/// on disk by this function.
///
/// Extra slashes will be removed from the end path as well.
///
/// Will return an error if the location is neither a parseable URL nor an
/// existing local path.
pub fn ensure_table_uri(table_uri: impl AsRef<str>) -> Result<Url> {
    let table_uri = table_uri.as_ref();

    enum UriType {
        LocalPath(PathBuf),
        Url(Url),
    }
    let uri_type: UriType = if let Ok(url) = Url::parse(table_uri) {
        if url.scheme() == "file" {
            UriType::LocalPath(url.to_file_path().map_err(|err| {
                let msg = format!("Invalid table location: {}\nError: {:?}", table_uri, err);
                Error::InvalidTableLocation { message: msg }
            })?)
        // NOTE this check is required to support absolute windows paths which may properly parse as url
        } else {
            UriType::Url(url)
        }
    } else {
        // Anything that fails URL parsing is treated as a local path.
        UriType::LocalPath(PathBuf::from(table_uri))
    };

    // Local paths are canonicalized (this is where a missing path errors out).
    let mut url = match uri_type {
        UriType::LocalPath(path) => {
            let path = std::fs::canonicalize(path).map_err(|err| Error::DatasetNotFound {
                path: table_uri.to_string(),
                source: Box::new(err),
                location: location!(),
            })?;
            Url::from_directory_path(path).map_err(|_| {
                let msg = format!(
                    "Could not construct a URL from canonicalized path: {}.\n\
                  Something must be very wrong with the table path.",
                    table_uri
                );
                Error::InvalidTableLocation { message: msg }
            })?
        }
        UriType::Url(url) => url,
    };

    // Normalize away trailing slashes so equal locations compare equal.
    let trimmed_path = url.path().trim_end_matches('/').to_owned();
    url.set_path(&trimmed_path);
    Ok(url)
}
1138
1139lazy_static::lazy_static! {
1140  static ref KNOWN_SCHEMES: Vec<&'static str> =
1141      Vec::from([
1142        "s3",
1143        "s3+ddb",
1144        "gs",
1145        "az",
1146        "file",
1147        "file-object-store",
1148        "memory"
1149      ]);
1150}
1151
#[cfg(test)]
mod tests {
    use super::*;
    use parquet::data_type::AsBytes;
    use rstest::rstest;
    use std::env::set_current_dir;
    use std::fs::{create_dir_all, write};
    use std::path::Path as StdPath;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Write test content to file.
    ///
    /// Tilde-expands `path_str`, creates any missing parent directories,
    /// then writes `contents` to the resulting path.
    fn write_to_file(path_str: &str, contents: &str) -> std::io::Result<()> {
        let expanded = tilde(path_str).to_string();
        let path = StdPath::new(&expanded);
        std::fs::create_dir_all(path.parent().unwrap())?;
        write(path, contents)
    }

    /// Read the entire object at `path` from `store` and decode as UTF-8.
    async fn read_from_store(store: ObjectStore, path: &Path) -> Result<String> {
        let test_file_store = store.open(path).await.unwrap();
        let size = test_file_store.size().await.unwrap();
        let bytes = test_file_store.get_range(0..size).await.unwrap();
        let contents = String::from_utf8(bytes.to_vec()).unwrap();
        Ok(contents)
    }

    /// Absolute local paths — including `.` and `..` segments — resolve to
    /// the same canonical dataset location.
    #[tokio::test]
    async fn test_absolute_paths() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path().to_str().unwrap().to_owned();
        write_to_file(
            &format!("{tmp_path}/bar/foo.lance/test_file"),
            "TEST_CONTENT",
        )
        .unwrap();

        // test a few variations of the same path
        for uri in &[
            format!("{tmp_path}/bar/foo.lance"),
            format!("{tmp_path}/./bar/foo.lance"),
            format!("{tmp_path}/bar/foo.lance/../foo.lance"),
        ] {
            let (store, path) = ObjectStore::from_uri(uri).await.unwrap();
            let contents = read_from_store(store, &path.child("test_file"))
                .await
                .unwrap();
            assert_eq!(contents, "TEST_CONTENT");
        }
    }

    /// Cloud URIs keep only the object key as the path, and `s3+ddb` is
    /// normalized down to the plain `s3` scheme.
    #[tokio::test]
    async fn test_cloud_paths() {
        let uri = "s3://bucket/foo.lance";
        let (store, path) = ObjectStore::from_uri(uri).await.unwrap();
        assert_eq!(store.scheme, "s3");
        assert_eq!(path.to_string(), "foo.lance");

        let (store, path) = ObjectStore::from_uri("s3+ddb://bucket/foo.lance")
            .await
            .unwrap();
        assert_eq!(store.scheme, "s3");
        assert_eq!(path.to_string(), "foo.lance");

        let (store, path) = ObjectStore::from_uri("gs://bucket/foo.lance")
            .await
            .unwrap();
        assert_eq!(store.scheme, "gs");
        assert_eq!(path.to_string(), "foo.lance");
    }

    /// Assert the default block size chosen for `uri`, then assert that an
    /// explicit `block_size` param overrides the default.
    async fn test_block_size_used_test_helper(
        uri: &str,
        storage_options: Option<HashMap<String, String>>,
        default_expected_block_size: usize,
    ) {
        // Test the default
        let registry = Arc::new(ObjectStoreRegistry::default());
        let params = ObjectStoreParams {
            storage_options: storage_options.clone(),
            ..ObjectStoreParams::default()
        };
        let (store, _) = ObjectStore::from_uri_and_params(registry, uri, &params)
            .await
            .unwrap();
        assert_eq!(store.block_size, default_expected_block_size);

        // Ensure param is used
        let registry = Arc::new(ObjectStoreRegistry::default());
        let params = ObjectStoreParams {
            block_size: Some(1024),
            storage_options: storage_options.clone(),
            ..ObjectStoreParams::default()
        };
        let (store, _) = ObjectStore::from_uri_and_params(registry, uri, &params)
            .await
            .unwrap();
        assert_eq!(store.block_size, 1024);
    }

    /// Cloud schemes default to the 64KB block size.
    #[rstest]
    #[case("s3://bucket/foo.lance", None)]
    #[case("gs://bucket/foo.lance", None)]
    #[case("az://account/bucket/foo.lance",
      Some(HashMap::from([
            (String::from("account_name"), String::from("account")),
            (String::from("container_name"), String::from("container"))
           ])))]
    #[tokio::test]
    async fn test_block_size_used_cloud(
        #[case] uri: &str,
        #[case] storage_options: Option<HashMap<String, String>>,
    ) {
        test_block_size_used_test_helper(uri, storage_options, 64 * 1024).await;
    }

    /// Local-ish schemes default to the 4KB block size.
    ///
    /// NOTE(review): for the "memory:///bucket/foo.lance" case the
    /// `format!("{prefix}:///{path}")` below yields a URI with the whole
    /// case string as its prefix; only the leading scheme matters for
    /// block-size selection here — confirm this is intentional.
    #[rstest]
    #[case("file")]
    #[case("file-object-store")]
    #[case("memory:///bucket/foo.lance")]
    #[tokio::test]
    async fn test_block_size_used_file(#[case] prefix: &str) {
        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path().to_str().unwrap().to_owned();
        let path = format!("{tmp_path}/bar/foo.lance/test_file");
        write_to_file(&path, "URL").unwrap();
        let uri = format!("{prefix}:///{path}");
        test_block_size_used_test_helper(&uri, None, 4 * 1024).await;
    }

    /// Relative paths are resolved against the current working directory.
    ///
    /// NOTE(review): `set_current_dir` mutates process-global state, which
    /// can interfere with other tests running in the same process.
    #[tokio::test]
    async fn test_relative_paths() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path().to_str().unwrap().to_owned();
        write_to_file(
            &format!("{tmp_path}/bar/foo.lance/test_file"),
            "RELATIVE_URL",
        )
        .unwrap();

        set_current_dir(StdPath::new(&tmp_path)).expect("Error changing current dir");
        let (store, path) = ObjectStore::from_uri("./bar/foo.lance").await.unwrap();

        let contents = read_from_store(store, &path.child("test_file"))
            .await
            .unwrap();
        assert_eq!(contents, "RELATIVE_URL");
    }

    /// `~` in a dataset URI is expanded to the user's home directory.
    #[tokio::test]
    async fn test_tilde_expansion() {
        let uri = "~/foo.lance";
        write_to_file(&format!("{uri}/test_file"), "TILDE").unwrap();
        let (store, path) = ObjectStore::from_uri(uri).await.unwrap();
        let contents = read_from_store(store, &path.child("test_file"))
            .await
            .unwrap();
        assert_eq!(contents, "TILDE");
    }

    /// `read_dir` lists immediate children (dirs and files) non-recursively.
    #[tokio::test]
    async fn test_read_directory() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let path = tmp_dir.path();
        create_dir_all(path.join("foo").join("bar")).unwrap();
        create_dir_all(path.join("foo").join("zoo")).unwrap();
        create_dir_all(path.join("foo").join("zoo").join("abc")).unwrap();
        write_to_file(
            path.join("foo").join("test_file").to_str().unwrap(),
            "read_dir",
        )
        .unwrap();
        let (store, base) = ObjectStore::from_uri(path.to_str().unwrap()).await.unwrap();

        let sub_dirs = store.read_dir(base.child("foo")).await.unwrap();
        // Immediate children only — "abc" (nested under zoo) is not listed.
        assert_eq!(sub_dirs, vec!["bar", "zoo", "test_file"]);
    }

    /// `remove_dir_all` deletes a directory tree including nested files.
    #[tokio::test]
    async fn test_delete_directory() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let path = tmp_dir.path();
        create_dir_all(path.join("foo").join("bar")).unwrap();
        create_dir_all(path.join("foo").join("zoo")).unwrap();
        create_dir_all(path.join("foo").join("zoo").join("abc")).unwrap();
        write_to_file(
            path.join("foo")
                .join("bar")
                .join("test_file")
                .to_str()
                .unwrap(),
            "delete",
        )
        .unwrap();
        write_to_file(path.join("foo").join("top").to_str().unwrap(), "delete_top").unwrap();
        let (store, base) = ObjectStore::from_uri(path.to_str().unwrap()).await.unwrap();
        store.remove_dir_all(base.child("foo")).await.unwrap();

        assert!(!path.join("foo").exists());
    }

    /// Test double: records whether `wrap` was called and substitutes a
    /// pre-made store as the wrapped result.
    #[derive(Debug)]
    struct TestWrapper {
        called: AtomicBool,

        return_value: Arc<dyn OSObjectStore>,
    }

    impl WrappingObjectStore for TestWrapper {
        fn wrap(&self, _original: Arc<dyn OSObjectStore>) -> Arc<dyn OSObjectStore> {
            self.called.store(true, Ordering::Relaxed);

            // return a mocked value so we can check if the final store is the one we expect
            self.return_value.clone()
        }
    }

    impl TestWrapper {
        // True once `wrap` has been invoked at least once.
        fn called(&self) -> bool {
            self.called.load(Ordering::Relaxed)
        }
    }

    /// The `object_store_wrapper` param is applied during store construction
    /// and the wrapped store is the one actually retained.
    #[tokio::test]
    async fn test_wrapping_object_store_option_is_used() {
        // Make a store for the inner store first
        let mock_inner_store: Arc<dyn OSObjectStore> = Arc::new(InMemory::new());
        let registry = Arc::new(ObjectStoreRegistry::default());

        assert_eq!(Arc::strong_count(&mock_inner_store), 1);

        let wrapper = Arc::new(TestWrapper {
            called: AtomicBool::new(false),
            return_value: mock_inner_store.clone(),
        });

        let params = ObjectStoreParams {
            object_store_wrapper: Some(wrapper.clone()),
            ..ObjectStoreParams::default()
        };

        // not called yet
        assert!(!wrapper.called());

        let _ = ObjectStore::from_uri_and_params(registry, "memory:///", &params)
            .await
            .unwrap();

        // called after construction
        assert!(wrapper.called());

        // hard to compare two trait pointers as the point to vtables
        // using the ref count as a proxy to make sure that the store is correctly kept
        assert_eq!(Arc::strong_count(&mock_inner_store), 2);
    }

    /// Test double: AWS credential provider that records whether it was
    /// consulted and returns empty credentials.
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }

    /// Credentials injected via `aws_credentials` are the ones the S3 store
    /// actually uses when a request is made.
    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // Not called yet
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // fails, but we don't care
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        // The request above must have consulted the injected provider.
        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    /// Round-trip: local writer output is readable through a local reader.
    #[tokio::test]
    async fn test_local_paths() {
        let temp_dir = tempfile::tempdir().unwrap();

        let file_path = temp_dir.path().join("test_file");
        let mut writer = ObjectStore::create_local_writer(file_path.as_path())
            .await
            .unwrap();
        writer.write_all(b"LOCAL").await.unwrap();
        writer.shutdown().await.unwrap();

        let reader = ObjectStore::open_local(file_path.as_path()).await.unwrap();
        let buf = reader.get_range(0..5).await.unwrap();
        assert_eq!(buf.as_bytes(), b"LOCAL");
    }

    /// `read_one_all` and `read_one_range` return identical bytes for a
    /// whole-file read.
    #[tokio::test]
    async fn test_read_one() {
        let temp_dir = tempfile::tempdir().unwrap();

        let file_path = temp_dir.path().join("test_file");
        let mut writer = ObjectStore::create_local_writer(file_path.as_path())
            .await
            .unwrap();
        writer.write_all(b"LOCAL").await.unwrap();
        writer.shutdown().await.unwrap();

        let file_path_os = object_store::path::Path::parse(file_path.to_str().unwrap()).unwrap();
        let obj_store = ObjectStore::local();
        let buf = obj_store.read_one_all(&file_path_os).await.unwrap();
        assert_eq!(buf.as_bytes(), b"LOCAL");

        let buf = obj_store.read_one_range(&file_path_os, 0..5).await.unwrap();
        assert_eq!(buf.as_bytes(), b"LOCAL");
    }

    /// Windows drive-letter paths work in both `/` and `\` forms.
    #[tokio::test]
    #[cfg(windows)]
    async fn test_windows_paths() {
        use std::path::Component;
        use std::path::Prefix;
        use std::path::Prefix::*;

        // Extract the path prefix component (e.g. the `C:` disk prefix).
        fn get_path_prefix(path: &StdPath) -> Prefix {
            match path.components().next().unwrap() {
                Component::Prefix(prefix_component) => prefix_component.kind(),
                _ => panic!(),
            }
        }

        // Pull the drive letter out of a `Disk` prefix.
        fn get_drive_letter(prefix: Prefix) -> String {
            match prefix {
                Disk(bytes) => String::from_utf8(vec![bytes]).unwrap(),
                _ => panic!(),
            }
        }

        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path();
        let prefix = get_path_prefix(tmp_path);
        let drive_letter = get_drive_letter(prefix);

        write_to_file(
            &(format!("{drive_letter}:/test_folder/test.lance") + "/test_file"),
            "WINDOWS",
        )
        .unwrap();

        for uri in &[
            format!("{drive_letter}:/test_folder/test.lance"),
            format!("{drive_letter}:\\test_folder\\test.lance"),
        ] {
            let (store, base) = ObjectStore::from_uri(uri).await.unwrap();
            let contents = read_from_store(store, &base.child("test_file"))
                .await
                .unwrap();
            assert_eq!(contents, "WINDOWS");
        }
    }
}