lance_io/object_store.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Extend [object_store::ObjectStore] functionalities
5
6use std::collections::HashMap;
7use std::ops::Range;
8use std::path::PathBuf;
9use std::str::FromStr;
10use std::sync::Arc;
11use std::time::{Duration, SystemTime};
12
13use async_trait::async_trait;
14use aws_config::default_provider::credentials::DefaultCredentialsChain;
15use aws_credential_types::provider::ProvideCredentials;
16use bytes::Bytes;
17use chrono::{DateTime, Utc};
18use deepsize::DeepSizeOf;
19use futures::{future, stream::BoxStream, StreamExt, TryStreamExt};
20use lance_core::utils::parse::str_is_truthy;
21use lance_core::utils::tokio::get_num_compute_intensive_cpus;
22use object_store::aws::{
23    AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential, AwsCredentialProvider,
24};
25use object_store::azure::MicrosoftAzureBuilder;
26use object_store::gcp::{GcpCredential, GoogleCloudStorageBuilder};
27use object_store::{
28    aws::AmazonS3Builder, azure::AzureConfigKey, gcp::GoogleConfigKey, local::LocalFileSystem,
29    memory::InMemory, CredentialProvider, Error as ObjectStoreError, Result as ObjectStoreResult,
30};
31use object_store::{path::Path, ObjectMeta, ObjectStore as OSObjectStore};
32use object_store::{ClientOptions, DynObjectStore, RetryConfig, StaticCredentialProvider};
33use shellexpand::tilde;
34use snafu::location;
35use tokio::io::AsyncWriteExt;
36use tokio::sync::RwLock;
37use url::Url;
38
39use super::local::LocalObjectReader;
40mod tracing;
41use self::tracing::ObjectStoreTracingExt;
42use crate::{object_reader::CloudObjectReader, object_writer::ObjectWriter, traits::Reader};
43use lance_core::{Error, Result};
44
// Local disks tend to do fine with a few threads
// Note: the number of threads here also impacts the number of files
// we need to read in some situations.  So keeping this at 8 keeps the
// RAM on our scanner down.
pub const DEFAULT_LOCAL_IO_PARALLELISM: usize = 8;
// Cloud disks often need many many threads to saturate the network
pub const DEFAULT_CLOUD_IO_PARALLELISM: usize = 64;

/// Default number of times a failed download is retried before giving up.
pub const DEFAULT_DOWNLOAD_RETRY_COUNT: usize = 3;
54
/// Convenience extension methods available on any [object_store::ObjectStore]
/// (a blanket implementation is provided below).
#[async_trait]
pub trait ObjectStoreExt {
    /// Returns true if the file exists.
    async fn exists(&self, path: &Path) -> Result<bool>;

    /// Read all files (start from base directory) recursively
    ///
    /// unmodified_since can be specified to only return files that have not been modified since the given time.
    async fn read_dir_all<'a>(
        &'a self,
        dir_path: impl Into<&Path> + Send,
        unmodified_since: Option<DateTime<Utc>>,
    ) -> Result<BoxStream<'a, Result<ObjectMeta>>>;
}
69
70#[async_trait]
71impl<O: OSObjectStore + ?Sized> ObjectStoreExt for O {
72    async fn read_dir_all<'a>(
73        &'a self,
74        dir_path: impl Into<&Path> + Send,
75        unmodified_since: Option<DateTime<Utc>>,
76    ) -> Result<BoxStream<'a, Result<ObjectMeta>>> {
77        let mut output = self.list(Some(dir_path.into()));
78        if let Some(unmodified_since_val) = unmodified_since {
79            output = output
80                .try_filter(move |file| future::ready(file.last_modified < unmodified_since_val))
81                .boxed();
82        }
83        Ok(output.map_err(|e| e.into()).boxed())
84    }
85
86    async fn exists(&self, path: &Path) -> Result<bool> {
87        match self.head(path).await {
88            Ok(_) => Ok(true),
89            Err(object_store::Error::NotFound { path: _, source: _ }) => Ok(false),
90            Err(e) => Err(e.into()),
91        }
92    }
93}
94
/// Wraps [ObjectStore](object_store::ObjectStore)
#[derive(Debug, Clone)]
pub struct ObjectStore {
    // Inner object store
    pub inner: Arc<dyn OSObjectStore>,
    // URI scheme this store was created for, e.g. "file", "memory", or a cloud scheme.
    scheme: String,
    // Preferred IO block size in bytes (constructors use 4KB locally, 64KB for cloud).
    block_size: usize,
    /// Whether to use constant size upload parts for multipart uploads. This
    /// is only necessary for Cloudflare R2.
    pub use_constant_size_upload_parts: bool,
    /// Whether we can assume that the list of files is lexically ordered. This
    /// is true for object stores, but not for local filesystems.
    pub list_is_lexically_ordered: bool,
    // Number of parallel IO operations; defaults differ for local vs cloud
    // (see DEFAULT_LOCAL_IO_PARALLELISM / DEFAULT_CLOUD_IO_PARALLELISM).
    io_parallelism: usize,
    /// Number of times to retry a failed download
    download_retry_count: usize,
}
112
113impl DeepSizeOf for ObjectStore {
114    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
115        // We aren't counting `inner` here which is problematic but an ObjectStore
116        // shouldn't be too big.  The only exception might be the write cache but, if
117        // the writer cache has data, it means we're using it somewhere else that isn't
118        // a cache and so that doesn't really count.
119        self.scheme.deep_size_of_children(context) + self.block_size.deep_size_of_children(context)
120    }
121}
122
123impl std::fmt::Display for ObjectStore {
124    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125        write!(f, "ObjectStore({})", self.scheme)
126    }
127}
128
/// Factory for creating an [ObjectStore] for a given base URL.
///
/// Implementations are registered per URL scheme in an [ObjectStoreRegistry].
pub trait ObjectStoreProvider: std::fmt::Debug + Sync + Send {
    /// Build a new store rooted at `base_path`, configured by `params`.
    fn new_store(&self, base_path: Url, params: &ObjectStoreParams) -> Result<ObjectStore>;
}
132
/// Maps URL schemes (e.g. "s3") to the [ObjectStoreProvider] that handles them.
#[derive(Default, Debug)]
pub struct ObjectStoreRegistry {
    providers: HashMap<String, Arc<dyn ObjectStoreProvider>>,
}
137
138impl ObjectStoreRegistry {
139    pub fn insert(&mut self, scheme: &str, provider: Arc<dyn ObjectStoreProvider>) {
140        self.providers.insert(scheme.into(), provider);
141    }
142}
143
144const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";
145
/// Adapt an AWS SDK cred into object_store credentials
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    /// The underlying AWS SDK credential provider being adapted.
    pub inner: Arc<dyn ProvideCredentials>,

    // RefCell can't be shared across threads, so we use HashMap
    // (keyed by AWS_CREDS_CACHE_KEY) behind an async RwLock instead.
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    // The amount of time before expiry to refresh credentials
    credentials_refresh_offset: Duration,
}
157
158impl AwsCredentialAdapter {
159    pub fn new(
160        provider: Arc<dyn ProvideCredentials>,
161        credentials_refresh_offset: Duration,
162    ) -> Self {
163        Self {
164            inner: provider,
165            cache: Arc::new(RwLock::new(HashMap::new())),
166            credentials_refresh_offset,
167        }
168    }
169}
170
171#[async_trait]
172impl CredentialProvider for AwsCredentialAdapter {
173    type Credential = ObjectStoreAwsCredential;
174
175    async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
176        let cached_creds = {
177            let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
178            let expired = cache_value
179                .clone()
180                .map(|cred| {
181                    cred.expiry()
182                        .map(|exp| {
183                            exp.checked_sub(self.credentials_refresh_offset)
184                                .expect("this time should always be valid")
185                                < SystemTime::now()
186                        })
187                        // no expiry is never expire
188                        .unwrap_or(false)
189                })
190                .unwrap_or(true); // no cred is the same as expired;
191            if expired {
192                None
193            } else {
194                cache_value.clone()
195            }
196        };
197
198        if let Some(creds) = cached_creds {
199            Ok(Arc::new(Self::Credential {
200                key_id: creds.access_key_id().to_string(),
201                secret_key: creds.secret_access_key().to_string(),
202                token: creds.session_token().map(|s| s.to_string()),
203            }))
204        } else {
205            let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
206                |e| Error::Internal {
207                    message: format!("Failed to get AWS credentials: {}", e),
208                    location: location!(),
209                },
210            )?);
211
212            self.cache
213                .write()
214                .await
215                .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
216
217            Ok(Arc::new(Self::Credential {
218                key_id: refreshed_creds.access_key_id().to_string(),
219                secret_key: refreshed_creds.secret_access_key().to_string(),
220                token: refreshed_creds.session_token().map(|s| s.to_string()),
221            }))
222        }
223    }
224}
225
226/// Figure out the S3 region of the bucket.
227///
228/// This resolves in order of precedence:
229/// 1. The region provided in the storage options
230/// 2. (If endpoint is not set), the region returned by the S3 API for the bucket
231///
232/// It can return None if no region is provided and the endpoint is set.
233async fn resolve_s3_region(
234    url: &Url,
235    storage_options: &HashMap<AmazonS3ConfigKey, String>,
236) -> Result<Option<String>> {
237    if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
238        Ok(Some(region.clone()))
239    } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
240        // If no endpoint is set, we can assume this is AWS S3 and the region
241        // can be resolved from the bucket.
242        let bucket = url.host_str().ok_or_else(|| {
243            Error::invalid_input(
244                format!("Could not parse bucket from url: {}", url),
245                location!(),
246            )
247        })?;
248
249        let mut client_options = ClientOptions::default();
250        for (key, value) in storage_options {
251            if let AmazonS3ConfigKey::Client(client_key) = key {
252                client_options = client_options.with_config(*client_key, value.clone());
253            }
254        }
255
256        let bucket_region =
257            object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
258        Ok(Some(bucket_region))
259    } else {
260        Ok(None)
261    }
262}
263
264/// Build AWS credentials
265///
266/// This resolves credentials from the following sources in order:
267/// 1. An explicit `credentials` provider
268/// 2. Explicit credentials in storage_options (as in `aws_access_key_id`,
269///    `aws_secret_access_key`, `aws_session_token`)
270/// 3. The default credential provider chain from AWS SDK.
271///
272/// `credentials_refresh_offset` is the amount of time before expiry to refresh credentials.
273pub async fn build_aws_credential(
274    credentials_refresh_offset: Duration,
275    credentials: Option<AwsCredentialProvider>,
276    storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
277    region: Option<String>,
278) -> Result<(AwsCredentialProvider, String)> {
279    // TODO: make this return no credential provider not using AWS
280    use aws_config::meta::region::RegionProviderChain;
281    const DEFAULT_REGION: &str = "us-west-2";
282
283    let region = if let Some(region) = region {
284        region
285    } else {
286        RegionProviderChain::default_provider()
287            .or_else(DEFAULT_REGION)
288            .region()
289            .await
290            .map(|r| r.as_ref().to_string())
291            .unwrap_or(DEFAULT_REGION.to_string())
292    };
293
294    if let Some(creds) = credentials {
295        Ok((creds, region))
296    } else if let Some(creds) = storage_options.and_then(extract_static_s3_credentials) {
297        Ok((Arc::new(creds), region))
298    } else {
299        let credentials_provider = DefaultCredentialsChain::builder().build().await;
300
301        Ok((
302            Arc::new(AwsCredentialAdapter::new(
303                Arc::new(credentials_provider),
304                credentials_refresh_offset,
305            )),
306            region,
307        ))
308    }
309}
310
311fn extract_static_s3_credentials(
312    options: &HashMap<AmazonS3ConfigKey, String>,
313) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
314    let key_id = options
315        .get(&AmazonS3ConfigKey::AccessKeyId)
316        .map(|s| s.to_string());
317    let secret_key = options
318        .get(&AmazonS3ConfigKey::SecretAccessKey)
319        .map(|s| s.to_string());
320    let token = options
321        .get(&AmazonS3ConfigKey::Token)
322        .map(|s| s.to_string());
323    match (key_id, secret_key, token) {
324        (Some(key_id), Some(secret_key), token) => {
325            Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
326                key_id,
327                secret_key,
328                token,
329            }))
330        }
331        _ => None,
332    }
333}
334
/// Hook for wrapping the inner [object_store::ObjectStore] with additional
/// behavior; used via [ObjectStoreParams::object_store_wrapper].
pub trait WrappingObjectStore: std::fmt::Debug + Send + Sync {
    /// Return a (possibly wrapped) replacement for `original`.
    fn wrap(&self, original: Arc<dyn OSObjectStore>) -> Arc<dyn OSObjectStore>;
}
338
/// Parameters to create an [ObjectStore]
///
#[derive(Debug, Clone)]
pub struct ObjectStoreParams {
    // Preferred IO block size in bytes; if None a scheme-dependent default is used.
    pub block_size: Option<usize>,
    // A pre-built store (and its base URL) to use instead of creating one from a URI.
    pub object_store: Option<(Arc<DynObjectStore>, Url)>,
    // How far ahead of expiry AWS credentials are refreshed.
    pub s3_credentials_refresh_offset: Duration,
    // Explicit AWS credential provider; overrides other credential sources.
    pub aws_credentials: Option<AwsCredentialProvider>,
    // Optional hook to wrap the created inner store (e.g. for instrumentation).
    pub object_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
    // Raw string key/value options forwarded to the store builders.
    pub storage_options: Option<HashMap<String, String>>,
    /// Use constant size upload parts for multipart uploads. Only necessary
    /// for Cloudflare R2, which doesn't support variable size parts. When this
    /// is false, max upload size is 2.5TB. When this is true, the max size is
    /// 50GB.
    pub use_constant_size_upload_parts: bool,
    // Whether listing results can be assumed lexically ordered; None lets the
    // store constructor decide.
    pub list_is_lexically_ordered: Option<bool>,
}
356
357impl Default for ObjectStoreParams {
358    fn default() -> Self {
359        Self {
360            object_store: None,
361            block_size: None,
362            s3_credentials_refresh_offset: Duration::from_secs(60),
363            aws_credentials: None,
364            object_store_wrapper: None,
365            storage_options: None,
366            use_constant_size_upload_parts: false,
367            list_is_lexically_ordered: None,
368        }
369    }
370}
371
372impl ObjectStoreParams {
373    /// Create a new instance of [`ObjectStoreParams`] based on the AWS credentials.
374    pub fn with_aws_credentials(
375        aws_credentials: Option<AwsCredentialProvider>,
376        region: Option<String>,
377    ) -> Self {
378        Self {
379            aws_credentials,
380            storage_options: region
381                .map(|region| [("region".into(), region)].iter().cloned().collect()),
382            ..Default::default()
383        }
384    }
385}
386
387impl ObjectStore {
388    /// Parse from a string URI.
389    ///
390    /// Returns the ObjectStore instance and the absolute path to the object.
391    pub async fn from_uri(uri: &str) -> Result<(Self, Path)> {
392        let registry = Arc::new(ObjectStoreRegistry::default());
393
394        Self::from_uri_and_params(registry, uri, &ObjectStoreParams::default()).await
395    }
396
397    /// Parse from a string URI.
398    ///
399    /// Returns the ObjectStore instance and the absolute path to the object.
400    pub async fn from_uri_and_params(
401        registry: Arc<ObjectStoreRegistry>,
402        uri: &str,
403        params: &ObjectStoreParams,
404    ) -> Result<(Self, Path)> {
405        if let Some((store, path)) = params.object_store.as_ref() {
406            let mut inner = store.clone();
407            if let Some(wrapper) = params.object_store_wrapper.as_ref() {
408                inner = wrapper.wrap(inner);
409            }
410            let store = Self {
411                inner,
412                scheme: path.scheme().to_string(),
413                block_size: params.block_size.unwrap_or(64 * 1024),
414                use_constant_size_upload_parts: params.use_constant_size_upload_parts,
415                list_is_lexically_ordered: params.list_is_lexically_ordered.unwrap_or_default(),
416                io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
417                download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT,
418            };
419            let path = Path::from(path.path());
420            return Ok((store, path));
421        }
422        let (object_store, path) = match Url::parse(uri) {
423            Ok(url) if url.scheme().len() == 1 && cfg!(windows) => {
424                // On Windows, the drive is parsed as a scheme
425                Self::from_path(uri)
426            }
427            Ok(url) => {
428                let store = Self::new_from_url(registry, url.clone(), params.clone()).await?;
429                Ok((store, Path::from(url.path())))
430            }
431            Err(_) => Self::from_path(uri),
432        }?;
433
434        Ok((
435            Self {
436                inner: params
437                    .object_store_wrapper
438                    .as_ref()
439                    .map(|w| w.wrap(object_store.inner.clone()))
440                    .unwrap_or(object_store.inner),
441                ..object_store
442            },
443            path,
444        ))
445    }
446
447    pub fn from_path_with_scheme(str_path: &str, scheme: &str) -> Result<(Self, Path)> {
448        let expanded = tilde(str_path).to_string();
449
450        let mut expanded_path = path_abs::PathAbs::new(expanded)
451            .unwrap()
452            .as_path()
453            .to_path_buf();
454        // path_abs::PathAbs::new(".") returns an empty string.
455        if let Some(s) = expanded_path.as_path().to_str() {
456            if s.is_empty() {
457                expanded_path = std::env::current_dir()?;
458            }
459        }
460        Ok((
461            Self {
462                inner: Arc::new(LocalFileSystem::new()).traced(),
463                scheme: String::from(scheme),
464                block_size: 4 * 1024, // 4KB block size
465                use_constant_size_upload_parts: false,
466                list_is_lexically_ordered: false,
467                io_parallelism: DEFAULT_LOCAL_IO_PARALLELISM,
468                download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT,
469            },
470            Path::from_absolute_path(expanded_path.as_path())?,
471        ))
472    }
473
474    pub fn from_path(str_path: &str) -> Result<(Self, Path)> {
475        Self::from_path_with_scheme(str_path, "file")
476    }
477
478    async fn new_from_url(
479        registry: Arc<ObjectStoreRegistry>,
480        url: Url,
481        params: ObjectStoreParams,
482    ) -> Result<Self> {
483        configure_store(registry, url.as_str(), params).await
484    }
485
486    /// Local object store.
487    pub fn local() -> Self {
488        Self {
489            inner: Arc::new(LocalFileSystem::new()).traced(),
490            scheme: String::from("file"),
491            block_size: 4 * 1024, // 4KB block size
492            use_constant_size_upload_parts: false,
493            list_is_lexically_ordered: false,
494            io_parallelism: DEFAULT_LOCAL_IO_PARALLELISM,
495            download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT,
496        }
497    }
498
499    /// Create a in-memory object store directly for testing.
500    pub fn memory() -> Self {
501        Self {
502            inner: Arc::new(InMemory::new()).traced(),
503            scheme: String::from("memory"),
504            block_size: 4 * 1024,
505            use_constant_size_upload_parts: false,
506            list_is_lexically_ordered: true,
507            io_parallelism: get_num_compute_intensive_cpus(),
508            download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT,
509        }
510    }
511
512    /// Returns true if the object store pointed to a local file system.
513    pub fn is_local(&self) -> bool {
514        self.scheme == "file"
515    }
516
517    pub fn is_cloud(&self) -> bool {
518        self.scheme != "file" && self.scheme != "memory"
519    }
520
521    pub fn block_size(&self) -> usize {
522        self.block_size
523    }
524
525    pub fn set_block_size(&mut self, new_size: usize) {
526        self.block_size = new_size;
527    }
528
529    pub fn set_io_parallelism(&mut self, io_parallelism: usize) {
530        self.io_parallelism = io_parallelism;
531    }
532
533    pub fn io_parallelism(&self) -> usize {
534        std::env::var("LANCE_IO_THREADS")
535            .map(|val| val.parse::<usize>().unwrap())
536            .unwrap_or(self.io_parallelism)
537    }
538
539    /// Open a file for path.
540    ///
541    /// Parameters
542    /// - ``path``: Absolute path to the file.
543    pub async fn open(&self, path: &Path) -> Result<Box<dyn Reader>> {
544        match self.scheme.as_str() {
545            "file" => LocalObjectReader::open(path, self.block_size, None).await,
546            _ => Ok(Box::new(CloudObjectReader::new(
547                self.inner.clone(),
548                path.clone(),
549                self.block_size,
550                None,
551                self.download_retry_count,
552            )?)),
553        }
554    }
555
556    /// Open a reader for a file with known size.
557    ///
558    /// This size may either have been retrieved from a list operation or
559    /// cached metadata. By passing in the known size, we can skip a HEAD / metadata
560    /// call.
561    pub async fn open_with_size(&self, path: &Path, known_size: usize) -> Result<Box<dyn Reader>> {
562        match self.scheme.as_str() {
563            "file" => LocalObjectReader::open(path, self.block_size, Some(known_size)).await,
564            _ => Ok(Box::new(CloudObjectReader::new(
565                self.inner.clone(),
566                path.clone(),
567                self.block_size,
568                Some(known_size),
569                self.download_retry_count,
570            )?)),
571        }
572    }
573
574    /// Create an [ObjectWriter] from local [std::path::Path]
575    pub async fn create_local_writer(path: &std::path::Path) -> Result<ObjectWriter> {
576        let object_store = Self::local();
577        let os_path = Path::from(path.to_str().unwrap());
578        object_store.create(&os_path).await
579    }
580
581    /// Open an [Reader] from local [std::path::Path]
582    pub async fn open_local(path: &std::path::Path) -> Result<Box<dyn Reader>> {
583        let object_store = Self::local();
584        let os_path = Path::from(path.to_str().unwrap());
585        object_store.open(&os_path).await
586    }
587
588    /// Create a new file.
589    pub async fn create(&self, path: &Path) -> Result<ObjectWriter> {
590        ObjectWriter::new(self, path).await
591    }
592
593    /// A helper function to create a file and write content to it.
594    pub async fn put(&self, path: &Path, content: &[u8]) -> Result<()> {
595        let mut writer = self.create(path).await?;
596        writer.write_all(content).await?;
597        writer.shutdown().await
598    }
599
600    pub async fn delete(&self, path: &Path) -> Result<()> {
601        self.inner.delete(path).await?;
602        Ok(())
603    }
604
605    pub async fn copy(&self, from: &Path, to: &Path) -> Result<()> {
606        Ok(self.inner.copy(from, to).await?)
607    }
608
609    /// Read a directory (start from base directory) and returns all sub-paths in the directory.
610    pub async fn read_dir(&self, dir_path: impl Into<Path>) -> Result<Vec<String>> {
611        let path = dir_path.into();
612        let path = Path::parse(&path)?;
613        let output = self.inner.list_with_delimiter(Some(&path)).await?;
614        Ok(output
615            .common_prefixes
616            .iter()
617            .chain(output.objects.iter().map(|o| &o.location))
618            .map(|s| s.filename().unwrap().to_string())
619            .collect())
620    }
621
622    /// Read all files (start from base directory) recursively
623    ///
624    /// unmodified_since can be specified to only return files that have not been modified since the given time.
625    pub async fn read_dir_all(
626        &self,
627        dir_path: impl Into<&Path> + Send,
628        unmodified_since: Option<DateTime<Utc>>,
629    ) -> Result<BoxStream<Result<ObjectMeta>>> {
630        self.inner.read_dir_all(dir_path, unmodified_since).await
631    }
632
633    /// Remove a directory recursively.
634    pub async fn remove_dir_all(&self, dir_path: impl Into<Path>) -> Result<()> {
635        let path = dir_path.into();
636        let path = Path::parse(&path)?;
637
638        if self.is_local() {
639            // Local file system needs to delete directories as well.
640            return super::local::remove_dir_all(&path);
641        }
642        let sub_entries = self
643            .inner
644            .list(Some(&path))
645            .map(|m| m.map(|meta| meta.location))
646            .boxed();
647        self.inner
648            .delete_stream(sub_entries)
649            .try_collect::<Vec<_>>()
650            .await?;
651        Ok(())
652    }
653
654    pub fn remove_stream<'a>(
655        &'a self,
656        locations: BoxStream<'a, Result<Path>>,
657    ) -> BoxStream<'a, Result<Path>> {
658        self.inner
659            .delete_stream(locations.err_into::<ObjectStoreError>().boxed())
660            .err_into::<Error>()
661            .boxed()
662    }
663
664    /// Check a file exists.
665    pub async fn exists(&self, path: &Path) -> Result<bool> {
666        match self.inner.head(path).await {
667            Ok(_) => Ok(true),
668            Err(object_store::Error::NotFound { path: _, source: _ }) => Ok(false),
669            Err(e) => Err(e.into()),
670        }
671    }
672
673    /// Get file size.
674    pub async fn size(&self, path: &Path) -> Result<usize> {
675        Ok(self.inner.head(path).await?.size)
676    }
677
678    /// Convenience function to open a reader and read all the bytes
679    pub async fn read_one_all(&self, path: &Path) -> Result<Bytes> {
680        let reader = self.open(path).await?;
681        Ok(reader.get_all().await?)
682    }
683
684    /// Convenience function open a reader and make a single request
685    ///
686    /// If you will be making multiple requests to the path it is more efficient to call [`Self::open`]
687    /// and then call [`Reader::get_range`] multiple times.
688    pub async fn read_one_range(&self, path: &Path, range: Range<usize>) -> Result<Bytes> {
689        let reader = self.open(path).await?;
690        Ok(reader.get_range(range).await?)
691    }
692}
693
/// Options that can be set for multiple object stores
#[derive(PartialEq, Eq, Hash, Clone, Debug, Copy)]
pub enum LanceConfigKey {
    /// Number of times to retry a download that fails
    /// (parsed from the "download_retry_count" storage option).
    DownloadRetryCount,
}
700
701impl FromStr for LanceConfigKey {
702    type Err = Error;
703
704    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
705        match s.to_ascii_lowercase().as_str() {
706            "download_retry_count" => Ok(Self::DownloadRetryCount),
707            _ => Err(Error::InvalidInput {
708                source: format!("Invalid LanceConfigKey: {}", s).into(),
709                location: location!(),
710            }),
711        }
712    }
713}
714
/// Raw string key/value storage options, as supplied by the user and/or
/// collected from the environment.
#[derive(Clone, Debug, Default)]
pub struct StorageOptions(pub HashMap<String, String>);
717
718impl StorageOptions {
719    /// Create a new instance of [`StorageOptions`]
720    pub fn new(options: HashMap<String, String>) -> Self {
721        let mut options = options;
722        if let Ok(value) = std::env::var("AZURE_STORAGE_ALLOW_HTTP") {
723            options.insert("allow_http".into(), value);
724        }
725        if let Ok(value) = std::env::var("AZURE_STORAGE_USE_HTTP") {
726            options.insert("allow_http".into(), value);
727        }
728        if let Ok(value) = std::env::var("AWS_ALLOW_HTTP") {
729            options.insert("allow_http".into(), value);
730        }
731        if let Ok(value) = std::env::var("OBJECT_STORE_CLIENT_MAX_RETRIES") {
732            options.insert("client_max_retries".into(), value);
733        }
734        if let Ok(value) = std::env::var("OBJECT_STORE_CLIENT_RETRY_TIMEOUT") {
735            options.insert("client_retry_timeout".into(), value);
736        }
737        Self(options)
738    }
739
740    /// Add values from the environment to storage options
741    pub fn with_env_azure(&mut self) {
742        for (os_key, os_value) in std::env::vars_os() {
743            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
744                if let Ok(config_key) = AzureConfigKey::from_str(&key.to_ascii_lowercase()) {
745                    if !self.0.contains_key(config_key.as_ref()) {
746                        self.0
747                            .insert(config_key.as_ref().to_string(), value.to_string());
748                    }
749                }
750            }
751        }
752    }
753
754    /// Add values from the environment to storage options
755    pub fn with_env_gcs(&mut self) {
756        for (os_key, os_value) in std::env::vars_os() {
757            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
758                let lowercase_key = key.to_ascii_lowercase();
759                let token_key = "google_storage_token";
760
761                if let Ok(config_key) = GoogleConfigKey::from_str(&lowercase_key) {
762                    if !self.0.contains_key(config_key.as_ref()) {
763                        self.0
764                            .insert(config_key.as_ref().to_string(), value.to_string());
765                    }
766                }
767                // Check for GOOGLE_STORAGE_TOKEN until GoogleConfigKey supports storage token
768                else if lowercase_key == token_key && !self.0.contains_key(token_key) {
769                    self.0.insert(token_key.to_string(), value.to_string());
770                }
771            }
772        }
773    }
774
775    /// Add values from the environment to storage options
776    pub fn with_env_s3(&mut self) {
777        for (os_key, os_value) in std::env::vars_os() {
778            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
779                if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
780                    if !self.0.contains_key(config_key.as_ref()) {
781                        self.0
782                            .insert(config_key.as_ref().to_string(), value.to_string());
783                    }
784                }
785            }
786        }
787    }
788
789    /// Denotes if unsecure connections via http are allowed
790    pub fn allow_http(&self) -> bool {
791        self.0.iter().any(|(key, value)| {
792            key.to_ascii_lowercase().contains("allow_http") & str_is_truthy(value)
793        })
794    }
795
796    /// Number of times to retry a download that fails
797    pub fn download_retry_count(&self) -> usize {
798        self.0
799            .iter()
800            .find(|(key, _)| key.eq_ignore_ascii_case("download_retry_count"))
801            .map(|(_, value)| value.parse::<usize>().unwrap_or(3))
802            .unwrap_or(3)
803    }
804
805    /// Max retry times to set in RetryConfig for object store client
806    pub fn client_max_retries(&self) -> usize {
807        self.0
808            .iter()
809            .find(|(key, _)| key.eq_ignore_ascii_case("client_max_retries"))
810            .and_then(|(_, value)| value.parse::<usize>().ok())
811            .unwrap_or(10)
812    }
813
814    /// Seconds of timeout to set in RetryConfig for object store client
815    pub fn client_retry_timeout(&self) -> u64 {
816        self.0
817            .iter()
818            .find(|(key, _)| key.eq_ignore_ascii_case("client_retry_timeout"))
819            .and_then(|(_, value)| value.parse::<u64>().ok())
820            .unwrap_or(180)
821    }
822
823    /// Subset of options relevant for azure storage
824    pub fn as_azure_options(&self) -> HashMap<AzureConfigKey, String> {
825        self.0
826            .iter()
827            .filter_map(|(key, value)| {
828                let az_key = AzureConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
829                Some((az_key, value.clone()))
830            })
831            .collect()
832    }
833
834    /// Subset of options relevant for s3 storage
835    pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
836        self.0
837            .iter()
838            .filter_map(|(key, value)| {
839                let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
840                Some((s3_key, value.clone()))
841            })
842            .collect()
843    }
844
845    /// Subset of options relevant for gcs storage
846    pub fn as_gcs_options(&self) -> HashMap<GoogleConfigKey, String> {
847        self.0
848            .iter()
849            .filter_map(|(key, value)| {
850                let gcs_key = GoogleConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
851                Some((gcs_key, value.clone()))
852            })
853            .collect()
854    }
855
    /// Look up a raw storage option by its exact (case-sensitive) key.
    pub fn get(&self, key: &str) -> Option<&String> {
        self.0.get(key)
    }
859}
860
/// Convert a plain options map into [`StorageOptions`].
impl From<HashMap<String, String>> for StorageOptions {
    fn from(value: HashMap<String, String>) -> Self {
        // Defer to `new` so any normalization it performs applies here too.
        Self::new(value)
    }
}
866
/// Build an [`ObjectStore`] for `url`, dispatching on the URL scheme.
///
/// Natively handled schemes are `s3`/`s3+ddb`, `gs`, `az`, `file`,
/// `file-object-store` and `memory`; any other scheme is looked up in
/// `registry`. For cloud schemes, storage options from `options` are
/// merged with values from the process environment, with explicitly
/// supplied options taking precedence.
async fn configure_store(
    registry: Arc<ObjectStoreRegistry>,
    url: &str,
    options: ObjectStoreParams,
) -> Result<ObjectStore> {
    let mut storage_options = StorageOptions(options.storage_options.clone().unwrap_or_default());
    let download_retry_count = storage_options.download_retry_count();
    let mut url = ensure_table_uri(url)?;
    // Block size: On local file systems, we use 4KB block size. On cloud
    // object stores, we use 64KB block size. This is generally the largest
    // block size where we don't see a latency penalty.
    let file_block_size = options.block_size.unwrap_or(4 * 1024);
    let cloud_block_size = options.block_size.unwrap_or(64 * 1024);
    // Retry behaviour for the underlying object_store client, configurable
    // via the `client_max_retries` / `client_retry_timeout` storage options.
    let max_retries = storage_options.client_max_retries();
    let retry_timeout = storage_options.client_retry_timeout();
    let retry_config = RetryConfig {
        backoff: Default::default(),
        max_retries,
        retry_timeout: Duration::from_secs(retry_timeout),
    };
    match url.scheme() {
        "s3" | "s3+ddb" => {
            storage_options.with_env_s3();

            // NOTE(review): dead commented-out validation below — confirm
            // whether the s3+ddb / commit-handler exclusivity check should be
            // restored or the comment deleted.
            // if url.scheme() == "s3+ddb" && options.commit_handler.is_some() {
            //     return Err(Error::InvalidInput {
            //         source: "`s3+ddb://` scheme and custom commit handler are mutually exclusive"
            //             .into(),
            //         location: location!(),
            //     });
            // }

            // Shadow: from here on we only work with typed S3 config keys.
            let mut storage_options = storage_options.as_s3_options();
            let region = resolve_s3_region(&url, &storage_options).await?;
            let (aws_creds, region) = build_aws_credential(
                options.s3_credentials_refresh_offset,
                options.aws_credentials.clone(),
                Some(&storage_options),
                region,
            )
            .await?;

            // This will be default in next version of object store.
            // https://github.com/apache/arrow-rs/pull/7181
            storage_options
                .entry(AmazonS3ConfigKey::ConditionalPut)
                .or_insert_with(|| "etag".to_string());

            // Cloudflare does not support varying part sizes.
            let use_constant_size_upload_parts = storage_options
                .get(&AmazonS3ConfigKey::Endpoint)
                .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
                .unwrap_or(false);

            // before creating the OSObjectStore we need to rewrite the url to drop ddb related parts
            url.set_scheme("s3").map_err(|()| Error::Internal {
                message: "could not set scheme".into(),
                location: location!(),
            })?;

            url.set_query(None);

            // we can't use parse_url_opts here because we need to manually set the credentials provider
            let mut builder = AmazonS3Builder::new();
            for (key, value) in storage_options {
                builder = builder.with_config(key, value);
            }
            builder = builder
                .with_url(url.as_ref())
                .with_credentials(aws_creds)
                .with_retry(retry_config)
                .with_region(region);
            let store = builder.build()?;

            Ok(ObjectStore {
                inner: Arc::new(store).traced(),
                scheme: String::from(url.scheme()),
                block_size: cloud_block_size,
                use_constant_size_upload_parts,
                list_is_lexically_ordered: true,
                io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
                download_retry_count,
            })
        }
        "gs" => {
            storage_options.with_env_gcs();
            let mut builder = GoogleCloudStorageBuilder::new()
                .with_url(url.as_ref())
                .with_retry(retry_config);
            for (key, value) in storage_options.as_gcs_options() {
                builder = builder.with_config(key, value);
            }
            // An explicit bearer token takes precedence over other GCS
            // credential sources.
            let token_key = "google_storage_token";
            if let Some(storage_token) = storage_options.get(token_key) {
                let credential = GcpCredential {
                    bearer: storage_token.to_string(),
                };
                let credential_provider = Arc::new(StaticCredentialProvider::new(credential)) as _;
                builder = builder.with_credentials(credential_provider);
            }
            let store = builder.build()?;
            let store = Arc::new(store).traced();

            Ok(ObjectStore {
                inner: store,
                scheme: String::from("gs"),
                block_size: cloud_block_size,
                use_constant_size_upload_parts: false,
                list_is_lexically_ordered: true,
                io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
                download_retry_count,
            })
        }
        "az" => {
            storage_options.with_env_azure();
            let mut builder = MicrosoftAzureBuilder::new()
                .with_url(url.as_ref())
                .with_retry(retry_config);
            for (key, value) in storage_options.as_azure_options() {
                builder = builder.with_config(key, value);
            }
            let store = builder.build()?;
            let store = Arc::new(store).traced();

            Ok(ObjectStore {
                inner: store,
                scheme: String::from("az"),
                block_size: cloud_block_size,
                use_constant_size_upload_parts: false,
                list_is_lexically_ordered: true,
                io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
                download_retry_count,
            })
        }
        // we have a bypass logic to use `tokio::fs` directly to lower overhead
        // however this makes testing harder as we can't use the same code path
        // "file-object-store" forces local file system dataset to use the same
        // code path as cloud object stores
        "file" => {
            let mut object_store = ObjectStore::from_path(url.path())?.0;
            object_store.set_block_size(file_block_size);
            Ok(object_store)
        }
        "file-object-store" => {
            let mut object_store =
                ObjectStore::from_path_with_scheme(url.path(), "file-object-store")?.0;
            object_store.set_block_size(file_block_size);
            Ok(object_store)
        }
        "memory" => Ok(ObjectStore {
            inner: Arc::new(InMemory::new()).traced(),
            scheme: String::from("memory"),
            block_size: file_block_size,
            use_constant_size_upload_parts: false,
            list_is_lexically_ordered: true,
            io_parallelism: get_num_compute_intensive_cpus(),
            download_retry_count,
        }),
        // Any other scheme must come from a user-registered provider.
        unknown_scheme => {
            if let Some(provider) = registry.providers.get(unknown_scheme) {
                provider.new_store(url, &options)
            } else {
                let err = lance_core::Error::from(object_store::Error::NotSupported {
                    source: format!("Unsupported URI scheme: {} in url {}", unknown_scheme, url)
                        .into(),
                });
                Err(err)
            }
        }
    }
}
1038
1039impl ObjectStore {
1040    #[allow(clippy::too_many_arguments)]
1041    pub fn new(
1042        store: Arc<DynObjectStore>,
1043        location: Url,
1044        block_size: Option<usize>,
1045        wrapper: Option<Arc<dyn WrappingObjectStore>>,
1046        use_constant_size_upload_parts: bool,
1047        list_is_lexically_ordered: bool,
1048        io_parallelism: usize,
1049        download_retry_count: usize,
1050    ) -> Self {
1051        let scheme = location.scheme();
1052        let block_size = block_size.unwrap_or_else(|| infer_block_size(scheme));
1053
1054        let store = match wrapper {
1055            Some(wrapper) => wrapper.wrap(store),
1056            None => store,
1057        };
1058
1059        Self {
1060            inner: store,
1061            scheme: scheme.into(),
1062            block_size,
1063            use_constant_size_upload_parts,
1064            list_is_lexically_ordered,
1065            io_parallelism,
1066            download_retry_count,
1067        }
1068    }
1069}
1070
/// Pick a default I/O block size for the given URL scheme.
///
/// Local file systems use a 4KB block size; everything else is assumed to
/// be a cloud object store, where 64KB is generally the largest block size
/// that does not incur a latency penalty.
fn infer_block_size(scheme: &str) -> usize {
    if scheme == "file" {
        4 * 1024
    } else {
        64 * 1024
    }
}
1080
1081/// Attempt to create a Url from given table location.
1082///
1083/// The location could be:
1084///  * A valid URL, which will be parsed and returned
1085///  * A path to a directory, which will be created and then converted to a URL.
1086///
1087/// If it is a local path, it will be created if it doesn't exist.
1088///
1089/// Extra slashes will be removed from the end path as well.
1090///
1091/// Will return an error if the location is not valid. For example,
1092pub fn ensure_table_uri(table_uri: impl AsRef<str>) -> Result<Url> {
1093    let table_uri = table_uri.as_ref();
1094
1095    enum UriType {
1096        LocalPath(PathBuf),
1097        Url(Url),
1098    }
1099    let uri_type: UriType = if let Ok(url) = Url::parse(table_uri) {
1100        if url.scheme() == "file" {
1101            UriType::LocalPath(url.to_file_path().map_err(|err| {
1102                let msg = format!("Invalid table location: {}\nError: {:?}", table_uri, err);
1103                Error::InvalidTableLocation { message: msg }
1104            })?)
1105        // NOTE this check is required to support absolute windows paths which may properly parse as url
1106        } else {
1107            UriType::Url(url)
1108        }
1109    } else {
1110        UriType::LocalPath(PathBuf::from(table_uri))
1111    };
1112
1113    // If it is a local path, we need to create it if it does not exist.
1114    let mut url = match uri_type {
1115        UriType::LocalPath(path) => {
1116            let path = std::fs::canonicalize(path).map_err(|err| Error::DatasetNotFound {
1117                path: table_uri.to_string(),
1118                source: Box::new(err),
1119                location: location!(),
1120            })?;
1121            Url::from_directory_path(path).map_err(|_| {
1122                let msg = format!(
1123                    "Could not construct a URL from canonicalized path: {}.\n\
1124                  Something must be very wrong with the table path.",
1125                    table_uri
1126                );
1127                Error::InvalidTableLocation { message: msg }
1128            })?
1129        }
1130        UriType::Url(url) => url,
1131    };
1132
1133    let trimmed_path = url.path().trim_end_matches('/').to_owned();
1134    url.set_path(&trimmed_path);
1135    Ok(url)
1136}
1137
1138lazy_static::lazy_static! {
1139  static ref KNOWN_SCHEMES: Vec<&'static str> =
1140      Vec::from([
1141        "s3",
1142        "s3+ddb",
1143        "gs",
1144        "az",
1145        "file",
1146        "file-object-store",
1147        "memory"
1148      ]);
1149}
1150
#[cfg(test)]
mod tests {
    use super::*;
    use parquet::data_type::AsBytes;
    use rstest::rstest;
    use std::env::set_current_dir;
    use std::fs::{create_dir_all, write};
    use std::path::Path as StdPath;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Write test content to file.
    fn write_to_file(path_str: &str, contents: &str) -> std::io::Result<()> {
        let expanded = tilde(path_str).to_string();
        let path = StdPath::new(&expanded);
        std::fs::create_dir_all(path.parent().unwrap())?;
        write(path, contents)
    }

    /// Read the full contents of `path` from `store` as a UTF-8 string.
    async fn read_from_store(store: ObjectStore, path: &Path) -> Result<String> {
        let test_file_store = store.open(path).await.unwrap();
        let size = test_file_store.size().await.unwrap();
        let bytes = test_file_store.get_range(0..size).await.unwrap();
        let contents = String::from_utf8(bytes.to_vec()).unwrap();
        Ok(contents)
    }

    #[tokio::test]
    async fn test_absolute_paths() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path().to_str().unwrap().to_owned();
        write_to_file(
            &format!("{tmp_path}/bar/foo.lance/test_file"),
            "TEST_CONTENT",
        )
        .unwrap();

        // test a few variations of the same path
        for uri in &[
            format!("{tmp_path}/bar/foo.lance"),
            format!("{tmp_path}/./bar/foo.lance"),
            format!("{tmp_path}/bar/foo.lance/../foo.lance"),
        ] {
            let (store, path) = ObjectStore::from_uri(uri).await.unwrap();
            let contents = read_from_store(store, &path.child("test_file"))
                .await
                .unwrap();
            assert_eq!(contents, "TEST_CONTENT");
        }
    }

    #[tokio::test]
    async fn test_cloud_paths() {
        let uri = "s3://bucket/foo.lance";
        let (store, path) = ObjectStore::from_uri(uri).await.unwrap();
        assert_eq!(store.scheme, "s3");
        assert_eq!(path.to_string(), "foo.lance");

        // `s3+ddb` is rewritten to plain `s3` during store construction.
        let (store, path) = ObjectStore::from_uri("s3+ddb://bucket/foo.lance")
            .await
            .unwrap();
        assert_eq!(store.scheme, "s3");
        assert_eq!(path.to_string(), "foo.lance");

        let (store, path) = ObjectStore::from_uri("gs://bucket/foo.lance")
            .await
            .unwrap();
        assert_eq!(store.scheme, "gs");
        assert_eq!(path.to_string(), "foo.lance");
    }

    /// Verify that the default block size for `uri` matches expectations and
    /// that an explicit `block_size` param overrides it.
    async fn test_block_size_used_test_helper(
        uri: &str,
        storage_options: Option<HashMap<String, String>>,
        default_expected_block_size: usize,
    ) {
        // Test the default
        let registry = Arc::new(ObjectStoreRegistry::default());
        let params = ObjectStoreParams {
            storage_options: storage_options.clone(),
            ..ObjectStoreParams::default()
        };
        let (store, _) = ObjectStore::from_uri_and_params(registry, uri, &params)
            .await
            .unwrap();
        assert_eq!(store.block_size, default_expected_block_size);

        // Ensure param is used
        let registry = Arc::new(ObjectStoreRegistry::default());
        let params = ObjectStoreParams {
            block_size: Some(1024),
            storage_options: storage_options.clone(),
            ..ObjectStoreParams::default()
        };
        let (store, _) = ObjectStore::from_uri_and_params(registry, uri, &params)
            .await
            .unwrap();
        assert_eq!(store.block_size, 1024);
    }

    #[rstest]
    #[case("s3://bucket/foo.lance", None)]
    #[case("gs://bucket/foo.lance", None)]
    #[case("az://account/bucket/foo.lance",
      Some(HashMap::from([
            (String::from("account_name"), String::from("account")),
            (String::from("container_name"), String::from("container"))
           ])))]
    #[tokio::test]
    async fn test_block_size_used_cloud(
        #[case] uri: &str,
        #[case] storage_options: Option<HashMap<String, String>>,
    ) {
        test_block_size_used_test_helper(uri, storage_options, 64 * 1024).await;
    }

    #[rstest]
    #[case("file")]
    #[case("file-object-store")]
    // NOTE(review): with this case the formatted uri starts with
    // "memory:///bucket/foo.lance:///<tmp path>" — presumably harmless for an
    // in-memory store, but confirm the double-prefix is intended.
    #[case("memory:///bucket/foo.lance")]
    #[tokio::test]
    async fn test_block_size_used_file(#[case] prefix: &str) {
        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path().to_str().unwrap().to_owned();
        let path = format!("{tmp_path}/bar/foo.lance/test_file");
        write_to_file(&path, "URL").unwrap();
        let uri = format!("{prefix}:///{path}");
        test_block_size_used_test_helper(&uri, None, 4 * 1024).await;
    }

    #[tokio::test]
    async fn test_relative_paths() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path().to_str().unwrap().to_owned();
        write_to_file(
            &format!("{tmp_path}/bar/foo.lance/test_file"),
            "RELATIVE_URL",
        )
        .unwrap();

        // NOTE(review): changing the process-wide CWD can race with other
        // tests running in parallel; consider serializing if this flakes.
        set_current_dir(StdPath::new(&tmp_path)).expect("Error changing current dir");
        let (store, path) = ObjectStore::from_uri("./bar/foo.lance").await.unwrap();

        let contents = read_from_store(store, &path.child("test_file"))
            .await
            .unwrap();
        assert_eq!(contents, "RELATIVE_URL");
    }

    #[tokio::test]
    async fn test_tilde_expansion() {
        let uri = "~/foo.lance";
        write_to_file(&format!("{uri}/test_file"), "TILDE").unwrap();
        let (store, path) = ObjectStore::from_uri(uri).await.unwrap();
        let contents = read_from_store(store, &path.child("test_file"))
            .await
            .unwrap();
        assert_eq!(contents, "TILDE");
    }

    #[tokio::test]
    async fn test_read_directory() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let path = tmp_dir.path();
        create_dir_all(path.join("foo").join("bar")).unwrap();
        create_dir_all(path.join("foo").join("zoo")).unwrap();
        create_dir_all(path.join("foo").join("zoo").join("abc")).unwrap();
        write_to_file(
            path.join("foo").join("test_file").to_str().unwrap(),
            "read_dir",
        )
        .unwrap();
        let (store, base) = ObjectStore::from_uri(path.to_str().unwrap()).await.unwrap();

        // read_dir is expected to be non-recursive: "abc" should not appear.
        let sub_dirs = store.read_dir(base.child("foo")).await.unwrap();
        assert_eq!(sub_dirs, vec!["bar", "zoo", "test_file"]);
    }

    #[tokio::test]
    async fn test_delete_directory() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let path = tmp_dir.path();
        create_dir_all(path.join("foo").join("bar")).unwrap();
        create_dir_all(path.join("foo").join("zoo")).unwrap();
        create_dir_all(path.join("foo").join("zoo").join("abc")).unwrap();
        write_to_file(
            path.join("foo")
                .join("bar")
                .join("test_file")
                .to_str()
                .unwrap(),
            "delete",
        )
        .unwrap();
        write_to_file(path.join("foo").join("top").to_str().unwrap(), "delete_top").unwrap();
        let (store, base) = ObjectStore::from_uri(path.to_str().unwrap()).await.unwrap();
        store.remove_dir_all(base.child("foo")).await.unwrap();

        // The whole subtree (nested dirs and files) must be gone.
        assert!(!path.join("foo").exists());
    }

    /// Wrapper that records whether `wrap` was invoked and swaps in a
    /// pre-built store so the test can identify the final store by identity.
    #[derive(Debug)]
    struct TestWrapper {
        called: AtomicBool,

        return_value: Arc<dyn OSObjectStore>,
    }

    impl WrappingObjectStore for TestWrapper {
        fn wrap(&self, _original: Arc<dyn OSObjectStore>) -> Arc<dyn OSObjectStore> {
            self.called.store(true, Ordering::Relaxed);

            // return a mocked value so we can check if the final store is the one we expect
            self.return_value.clone()
        }
    }

    impl TestWrapper {
        // Whether `wrap` has been invoked at least once.
        fn called(&self) -> bool {
            self.called.load(Ordering::Relaxed)
        }
    }

    #[tokio::test]
    async fn test_wrapping_object_store_option_is_used() {
        // Make a store for the inner store first
        let mock_inner_store: Arc<dyn OSObjectStore> = Arc::new(InMemory::new());
        let registry = Arc::new(ObjectStoreRegistry::default());

        assert_eq!(Arc::strong_count(&mock_inner_store), 1);

        let wrapper = Arc::new(TestWrapper {
            called: AtomicBool::new(false),
            return_value: mock_inner_store.clone(),
        });

        let params = ObjectStoreParams {
            object_store_wrapper: Some(wrapper.clone()),
            ..ObjectStoreParams::default()
        };

        // not called yet
        assert!(!wrapper.called());

        let _ = ObjectStore::from_uri_and_params(registry, "memory:///", &params)
            .await
            .unwrap();

        // called after construction
        assert!(wrapper.called());

        // hard to compare two trait pointers as the point to vtables
        // using the ref count as a proxy to make sure that the store is correctly kept
        assert_eq!(Arc::strong_count(&mock_inner_store), 2);
    }

    /// Credential provider that records whether it was consulted and hands
    /// out empty (dummy) credentials.
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }

    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // Not called yet
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // fails, but we don't care
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        // The request above should have consulted the injected provider.
        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    #[tokio::test]
    async fn test_local_paths() {
        let temp_dir = tempfile::tempdir().unwrap();

        let file_path = temp_dir.path().join("test_file");
        let mut writer = ObjectStore::create_local_writer(file_path.as_path())
            .await
            .unwrap();
        writer.write_all(b"LOCAL").await.unwrap();
        writer.shutdown().await.unwrap();

        let reader = ObjectStore::open_local(file_path.as_path()).await.unwrap();
        let buf = reader.get_range(0..5).await.unwrap();
        assert_eq!(buf.as_bytes(), b"LOCAL");
    }

    #[tokio::test]
    async fn test_read_one() {
        let temp_dir = tempfile::tempdir().unwrap();

        let file_path = temp_dir.path().join("test_file");
        let mut writer = ObjectStore::create_local_writer(file_path.as_path())
            .await
            .unwrap();
        writer.write_all(b"LOCAL").await.unwrap();
        writer.shutdown().await.unwrap();

        let file_path_os = object_store::path::Path::parse(file_path.to_str().unwrap()).unwrap();
        let obj_store = ObjectStore::local();
        let buf = obj_store.read_one_all(&file_path_os).await.unwrap();
        assert_eq!(buf.as_bytes(), b"LOCAL");

        let buf = obj_store.read_one_range(&file_path_os, 0..5).await.unwrap();
        assert_eq!(buf.as_bytes(), b"LOCAL");
    }

    #[tokio::test]
    #[cfg(windows)]
    async fn test_windows_paths() {
        use std::path::Component;
        use std::path::Prefix;
        use std::path::Prefix::*;

        // Extract the path prefix component (e.g. the `C:` disk prefix).
        fn get_path_prefix(path: &StdPath) -> Prefix {
            match path.components().next().unwrap() {
                Component::Prefix(prefix_component) => prefix_component.kind(),
                _ => panic!(),
            }
        }

        // Turn a Disk prefix into its drive letter string.
        fn get_drive_letter(prefix: Prefix) -> String {
            match prefix {
                Disk(bytes) => String::from_utf8(vec![bytes]).unwrap(),
                _ => panic!(),
            }
        }

        let tmp_dir = tempfile::tempdir().unwrap();
        let tmp_path = tmp_dir.path();
        let prefix = get_path_prefix(tmp_path);
        let drive_letter = get_drive_letter(prefix);

        write_to_file(
            &(format!("{drive_letter}:/test_folder/test.lance") + "/test_file"),
            "WINDOWS",
        )
        .unwrap();

        // Both forward- and back-slash spellings of the same path must work.
        for uri in &[
            format!("{drive_letter}:/test_folder/test.lance"),
            format!("{drive_letter}:\\test_folder\\test.lance"),
        ] {
            let (store, base) = ObjectStore::from_uri(uri).await.unwrap();
            let contents = read_from_store(store, &base.child("test_file"))
                .await
                .unwrap();
            assert_eq!(contents, "WINDOWS");
        }
    }
}