lance_io/object_store/providers/
aws.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{collections::HashMap, fmt, str::FromStr, sync::Arc, time::Duration};
5
6#[cfg(test)]
7use mock_instant::thread_local::{SystemTime, UNIX_EPOCH};
8
9#[cfg(not(test))]
10use std::time::{SystemTime, UNIX_EPOCH};
11
12use object_store::ObjectStore as OSObjectStore;
13use object_store_opendal::OpendalStore;
14use opendal::{services::S3, Operator};
15
16use aws_config::default_provider::credentials::DefaultCredentialsChain;
17use aws_credential_types::provider::ProvideCredentials;
18use object_store::{
19    aws::{
20        AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
21        AwsCredentialProvider,
22    },
23    ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
24    StaticCredentialProvider,
25};
26use snafu::location;
27use tokio::sync::RwLock;
28use url::Url;
29
30use crate::object_store::{
31    ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, StorageOptionsProvider,
32    DEFAULT_CLOUD_BLOCK_SIZE, DEFAULT_CLOUD_IO_PARALLELISM, DEFAULT_MAX_IOP_SIZE,
33};
34use lance_core::error::{Error, Result};
35
/// [`ObjectStoreProvider`] for Amazon S3 and S3-compatible stores
/// (S3 Express, Cloudflare R2, OpenDAL-backed S3).
#[derive(Default, Debug)]
pub struct AwsStoreProvider;
38
impl AwsStoreProvider {
    /// Build an S3-backed store via the `object_store` crate.
    ///
    /// Resolves the bucket region and AWS credentials first, then rewrites
    /// `base_path` in place (scheme forced to `s3`, query string dropped) so
    /// URLs such as `s3+ddb://...?ddbTableName=x` become plain S3 URLs
    /// before being handed to [`AmazonS3Builder`].
    ///
    /// Note: the region lookup must happen *before* the URL is mutated.
    async fn build_amazon_s3_store(
        &self,
        base_path: &mut Url,
        params: &ObjectStoreParams,
        storage_options: &StorageOptions,
        is_s3_express: bool,
    ) -> Result<Arc<dyn OSObjectStore>> {
        // Retry policy is driven entirely by storage options.
        let max_retries = storage_options.client_max_retries();
        let retry_timeout = storage_options.client_retry_timeout();
        let retry_config = RetryConfig {
            backoff: Default::default(),
            max_retries,
            retry_timeout: Duration::from_secs(retry_timeout),
        };

        let mut s3_storage_options = storage_options.as_s3_options();
        let region = resolve_s3_region(base_path, &s3_storage_options).await?;
        let (aws_creds, region) = build_aws_credential(
            params.s3_credentials_refresh_offset,
            params.aws_credentials.clone(),
            Some(&s3_storage_options),
            region,
            params.storage_options_provider.clone(),
            storage_options.expires_at_millis(),
        )
        .await?;

        // Set S3Express flag if detected
        if is_s3_express {
            s3_storage_options.insert(AmazonS3ConfigKey::S3Express, true.to_string());
        }

        // before creating the OSObjectStore we need to rewrite the url to drop ddb related parts
        base_path.set_scheme("s3").unwrap();
        base_path.set_query(None);

        // we can't use parse_url_opts here because we need to manually set the credentials provider
        let mut builder = AmazonS3Builder::new();
        for (key, value) in s3_storage_options {
            builder = builder.with_config(key, value);
        }
        builder = builder
            .with_url(base_path.as_ref())
            .with_credentials(aws_creds)
            .with_retry(retry_config)
            .with_region(region);

        Ok(Arc::new(builder.build()?) as Arc<dyn OSObjectStore>)
    }

    /// Build an S3 store via OpenDAL, wrapped in [`OpendalStore`] so it
    /// presents the `object_store` interface.
    ///
    /// All storage options are forwarded verbatim to the OpenDAL S3 service
    /// config; credentials not present there are resolved by OpenDAL's own
    /// default chain.
    async fn build_opendal_s3_store(
        &self,
        base_path: &Url,
        storage_options: &StorageOptions,
    ) -> Result<Arc<dyn OSObjectStore>> {
        let bucket = base_path
            .host_str()
            .ok_or_else(|| Error::invalid_input("S3 URL must contain bucket name", location!()))?
            .to_string();

        let prefix = base_path.path().trim_start_matches('/').to_string();

        // Start with all storage options as the config map
        // OpenDAL will handle environment variables through its default credentials chain
        let mut config_map: HashMap<String, String> = storage_options.0.clone();

        // Set required OpenDAL configuration
        config_map.insert("bucket".to_string(), bucket);

        if !prefix.is_empty() {
            // NOTE(review): this sets the operator root to "/" (OpenDAL's
            // default) rather than to `prefix`, and `prefix` is otherwise
            // unused. Presumably object paths handed to the store already
            // carry the prefix -- confirm, or set root to the prefix.
            config_map.insert("root".to_string(), "/".to_string());
        }

        let operator = Operator::from_iter::<S3>(config_map)
            .map_err(|e| {
                Error::invalid_input(
                    format!("Failed to create S3 operator: {:?}", e),
                    location!(),
                )
            })?
            .finish();

        Ok(Arc::new(OpendalStore::new(operator)) as Arc<dyn OSObjectStore>)
    }
}
125
126#[async_trait::async_trait]
127impl ObjectStoreProvider for AwsStoreProvider {
128    async fn new_store(
129        &self,
130        mut base_path: Url,
131        params: &ObjectStoreParams,
132    ) -> Result<ObjectStore> {
133        let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
134        let mut storage_options =
135            StorageOptions(params.storage_options.clone().unwrap_or_default());
136        storage_options.with_env_s3();
137        let download_retry_count = storage_options.download_retry_count();
138
139        let use_opendal = storage_options
140            .0
141            .get("use_opendal")
142            .map(|v| v == "true")
143            .unwrap_or(false);
144
145        // Determine S3 Express and constant size upload parts before building the store
146        let is_s3_express = check_s3_express(&base_path, &storage_options);
147
148        let use_constant_size_upload_parts = storage_options
149            .0
150            .get("aws_endpoint")
151            .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
152            .unwrap_or(false);
153
154        let inner = if use_opendal {
155            // Use OpenDAL implementation
156            self.build_opendal_s3_store(&base_path, &storage_options)
157                .await?
158        } else {
159            // Use default Amazon S3 implementation
160            self.build_amazon_s3_store(&mut base_path, params, &storage_options, is_s3_express)
161                .await?
162        };
163
164        Ok(ObjectStore {
165            inner,
166            scheme: String::from(base_path.scheme()),
167            block_size,
168            max_iop_size: *DEFAULT_MAX_IOP_SIZE,
169            use_constant_size_upload_parts,
170            list_is_lexically_ordered: !is_s3_express,
171            io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
172            download_retry_count,
173            io_tracker: Default::default(),
174        })
175    }
176}
177
178/// Check if the storage is S3 Express
179fn check_s3_express(url: &Url, storage_options: &StorageOptions) -> bool {
180    storage_options
181        .0
182        .get("s3_express")
183        .map(|v| v == "true")
184        .unwrap_or(false)
185        || url.authority().ends_with("--x-s3")
186}
187
188/// Figure out the S3 region of the bucket.
189///
190/// This resolves in order of precedence:
191/// 1. The region provided in the storage options
192/// 2. (If endpoint is not set), the region returned by the S3 API for the bucket
193///
194/// It can return None if no region is provided and the endpoint is set.
195async fn resolve_s3_region(
196    url: &Url,
197    storage_options: &HashMap<AmazonS3ConfigKey, String>,
198) -> Result<Option<String>> {
199    if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
200        Ok(Some(region.clone()))
201    } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
202        // If no endpoint is set, we can assume this is AWS S3 and the region
203        // can be resolved from the bucket.
204        let bucket = url.host_str().ok_or_else(|| {
205            Error::invalid_input(
206                format!("Could not parse bucket from url: {}", url),
207                location!(),
208            )
209        })?;
210
211        let mut client_options = ClientOptions::default();
212        for (key, value) in storage_options {
213            if let AmazonS3ConfigKey::Client(client_key) = key {
214                client_options = client_options.with_config(*client_key, value.clone());
215            }
216        }
217
218        let bucket_region =
219            object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
220        Ok(Some(bucket_region))
221    } else {
222        Ok(None)
223    }
224}
225
226/// Build AWS credentials
227///
228/// This resolves credentials from the following sources in order:
229/// 1. An explicit `storage_options_provider`
230/// 2. An explicit `credentials` provider
231/// 3. Explicit credentials in storage_options (as in `aws_access_key_id`,
232///    `aws_secret_access_key`, `aws_session_token`)
233/// 4. The default credential provider chain from AWS SDK.
234///
235/// # Initial Credentials with Storage Options Provider
236///
237/// When `storage_options_provider` is provided along with `storage_options` and
238/// `expires_at_millis`, these serve as **initial values** to avoid redundant calls to
239/// fetch new storage options. The provider will use these initial credentials until they
240/// expire (based on `expires_at_millis`), then automatically fetch fresh credentials from
241/// the provider. Once the initial credentials expire, the passed-in values are no longer
242/// used - all future credentials come from the provider's `fetch_storage_options()` method.
243///
244/// `credentials_refresh_offset` is the amount of time before expiry to refresh credentials.
245pub async fn build_aws_credential(
246    credentials_refresh_offset: Duration,
247    credentials: Option<AwsCredentialProvider>,
248    storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
249    region: Option<String>,
250    storage_options_provider: Option<Arc<dyn StorageOptionsProvider>>,
251    expires_at_millis: Option<u64>,
252) -> Result<(AwsCredentialProvider, String)> {
253    // TODO: make this return no credential provider not using AWS
254    use aws_config::meta::region::RegionProviderChain;
255    const DEFAULT_REGION: &str = "us-west-2";
256
257    let region = if let Some(region) = region {
258        region
259    } else {
260        RegionProviderChain::default_provider()
261            .or_else(DEFAULT_REGION)
262            .region()
263            .await
264            .map(|r| r.as_ref().to_string())
265            .unwrap_or(DEFAULT_REGION.to_string())
266    };
267
268    let storage_options_credentials = storage_options.and_then(extract_static_s3_credentials);
269    if let Some(storage_options_provider) = storage_options_provider {
270        let creds = build_aws_credential_with_storage_options_provider(
271            storage_options_provider,
272            credentials_refresh_offset,
273            credentials,
274            storage_options_credentials,
275            expires_at_millis,
276        )
277        .await?;
278        Ok((creds, region))
279    } else if let Some(creds) = credentials {
280        Ok((creds, region))
281    } else if let Some(creds) = storage_options_credentials {
282        Ok((Arc::new(creds), region))
283    } else {
284        let credentials_provider = DefaultCredentialsChain::builder().build().await;
285
286        Ok((
287            Arc::new(AwsCredentialAdapter::new(
288                Arc::new(credentials_provider),
289                credentials_refresh_offset,
290            )),
291            region,
292        ))
293    }
294}
295
296async fn build_aws_credential_with_storage_options_provider(
297    storage_options_provider: Arc<dyn StorageOptionsProvider>,
298    credentials_refresh_offset: Duration,
299    credentials: Option<AwsCredentialProvider>,
300    storage_options_credentials: Option<StaticCredentialProvider<ObjectStoreAwsCredential>>,
301    expires_at_millis: Option<u64>,
302) -> Result<AwsCredentialProvider> {
303    match (expires_at_millis, credentials, storage_options_credentials) {
304        // Case 1: provider + credentials + expiration time
305        (Some(expires_at), Some(cred), _) => {
306            Ok(Arc::new(
307                DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
308                    storage_options_provider,
309                    credentials_refresh_offset,
310                    cred.get_credential().await?,
311                    expires_at,
312                ),
313            ))
314        }
315        // Case 2: provider + storage_options (with valid credentials) + expiration time
316        (Some(expires_at), None, Some(cred)) => {
317            Ok(Arc::new(
318                DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
319                    storage_options_provider,
320                    credentials_refresh_offset,
321                    cred.get_credential().await?,
322                    expires_at,
323                ),
324            ))
325        }
326        // Case 3: provider + storage_options without expiration - FAIL
327        (None, None, Some(_)) => Err(Error::IO {
328            source: Box::new(std::io::Error::other(
329                "expires_at_millis is required when using storage_options_provider with storage_options",
330            )),
331            location: location!(),
332        }),
333        // Case 4: provider + credentials without expiration - FAIL
334        (None, Some(_), _) => Err(Error::IO {
335            source: Box::new(std::io::Error::other(
336                "expires_at_millis is required when using storage_options_provider with credentials",
337            )),
338            location: location!(),
339        }),
340        // Case 5: provider without credentials/storage_options, or with expiration but no creds/opts
341        (_, None, None) => Ok(Arc::new(DynamicStorageOptionsCredentialProvider::new(
342            storage_options_provider,
343            credentials_refresh_offset,
344        ))),
345    }
346}
347
348fn extract_static_s3_credentials(
349    options: &HashMap<AmazonS3ConfigKey, String>,
350) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
351    let key_id = options.get(&AmazonS3ConfigKey::AccessKeyId).cloned();
352    let secret_key = options.get(&AmazonS3ConfigKey::SecretAccessKey).cloned();
353    let token = options.get(&AmazonS3ConfigKey::Token).cloned();
354    match (key_id, secret_key, token) {
355        (Some(key_id), Some(secret_key), token) => {
356            Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
357                key_id,
358                secret_key,
359                token,
360            }))
361        }
362        _ => None,
363    }
364}
365
/// Adapt an AWS SDK cred into object_store credentials
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    /// Underlying AWS SDK credentials provider (e.g. the default chain).
    pub inner: Arc<dyn ProvideCredentials>,

    // Keyed cache guarded by an async RwLock so the adapter can be shared
    // across threads (a RefCell could not be).
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    // The amount of time before expiry to refresh credentials
    credentials_refresh_offset: Duration,
}
377
impl AwsCredentialAdapter {
    /// Create a new adapter around an AWS SDK credentials provider.
    ///
    /// `credentials_refresh_offset` is how long before expiry a cached
    /// credential is considered stale and refreshed.
    pub fn new(
        provider: Arc<dyn ProvideCredentials>,
        credentials_refresh_offset: Duration,
    ) -> Self {
        Self {
            inner: provider,
            // Starts empty; the first `get_credential` call populates it.
            cache: Arc::new(RwLock::new(HashMap::new())),
            credentials_refresh_offset,
        }
    }
}
390
/// Key under which [`AwsCredentialAdapter`] stores its single cached entry.
const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";

/// Convert a `std::time::SystemTime` (as returned by the AWS SDK) into the
/// mockable `SystemTime` alias used by this module's expiry checks.
fn to_system_time(time: std::time::SystemTime) -> SystemTime {
    UNIX_EPOCH
        + time
            .duration_since(std::time::UNIX_EPOCH)
            .expect("time should be after UNIX_EPOCH")
}
400
401#[async_trait::async_trait]
402impl CredentialProvider for AwsCredentialAdapter {
403    type Credential = ObjectStoreAwsCredential;
404
405    async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
406        let cached_creds = {
407            let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
408            let expired = cache_value
409                .clone()
410                .map(|cred| {
411                    cred.expiry()
412                        .map(|exp| {
413                            to_system_time(exp)
414                                .checked_sub(self.credentials_refresh_offset)
415                                .expect("this time should always be valid")
416                                < SystemTime::now()
417                        })
418                        // no expiry is never expire
419                        .unwrap_or(false)
420                })
421                .unwrap_or(true); // no cred is the same as expired;
422            if expired {
423                None
424            } else {
425                cache_value.clone()
426            }
427        };
428
429        if let Some(creds) = cached_creds {
430            Ok(Arc::new(Self::Credential {
431                key_id: creds.access_key_id().to_string(),
432                secret_key: creds.secret_access_key().to_string(),
433                token: creds.session_token().map(|s| s.to_string()),
434            }))
435        } else {
436            let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
437                |e| Error::Internal {
438                    message: format!("Failed to get AWS credentials: {:?}", e),
439                    location: location!(),
440                },
441            )?);
442
443            self.cache
444                .write()
445                .await
446                .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
447
448            Ok(Arc::new(Self::Credential {
449                key_id: refreshed_creds.access_key_id().to_string(),
450                secret_key: refreshed_creds.secret_access_key().to_string(),
451                token: refreshed_creds.session_token().map(|s| s.to_string()),
452            }))
453        }
454    }
455}
456
457impl StorageOptions {
458    /// Add values from the environment to storage options
459    pub fn with_env_s3(&mut self) {
460        for (os_key, os_value) in std::env::vars_os() {
461            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
462                if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
463                    if !self.0.contains_key(config_key.as_ref()) {
464                        self.0
465                            .insert(config_key.as_ref().to_string(), value.to_string());
466                    }
467                }
468            }
469        }
470    }
471
472    /// Subset of options relevant for s3 storage
473    pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
474        self.0
475            .iter()
476            .filter_map(|(key, value)| {
477                let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
478                Some((s3_key, value.clone()))
479            })
480            .collect()
481    }
482}
483
484impl ObjectStoreParams {
485    /// Create a new instance of [`ObjectStoreParams`] based on the AWS credentials.
486    pub fn with_aws_credentials(
487        aws_credentials: Option<AwsCredentialProvider>,
488        region: Option<String>,
489    ) -> Self {
490        Self {
491            aws_credentials,
492            storage_options: region
493                .map(|region| [("region".into(), region)].iter().cloned().collect()),
494            ..Default::default()
495        }
496    }
497}
498
499/// AWS Credential Provider that uses StorageOptionsProvider
500///
501/// This adapter converts our generic StorageOptionsProvider trait into
502/// AWS-specific credentials that can be used with S3. It caches credentials
503/// and automatically refreshes them before they expire.
504///
505/// # Future Work
506///
507/// TODO: Support AWS/GCP/Azure together in a unified credential provider.
508/// Currently this is AWS-specific. Needs investigation of how GCP and Azure credential
509/// refresh mechanisms work and whether they can be unified with AWS's approach.
510///
511/// See: <https://github.com/lance-format/lance/pull/4905#discussion_r2474605265>
pub struct DynamicStorageOptionsCredentialProvider {
    /// Source of fresh storage options (and thus credentials).
    provider: Arc<dyn StorageOptionsProvider>,
    /// Most recently fetched credential plus its expiry, shared across clones.
    cache: Arc<RwLock<Option<CachedCredential>>>,
    /// How long before expiry a cached credential is considered stale.
    refresh_offset: Duration,
}
517
impl fmt::Debug for DynamicStorageOptionsCredentialProvider {
    // Manual impl: the credential cache is left out of the output
    // (NOTE(review): presumably to keep secrets out of logs -- confirm).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("DynamicStorageOptionsCredentialProvider")
            .field("provider", &self.provider)
            .field("refresh_offset", &self.refresh_offset)
            .finish()
    }
}
526
/// A credential plus its optional expiry, as stored in the refresh cache.
#[derive(Debug, Clone)]
struct CachedCredential {
    /// The materialized object_store credential.
    credential: Arc<ObjectStoreAwsCredential>,
    /// Expiry in milliseconds since the Unix epoch; `None` never expires.
    expires_at_millis: Option<u64>,
}
532
impl DynamicStorageOptionsCredentialProvider {
    /// Create a new credential provider without initial credentials
    ///
    /// # Arguments
    /// * `provider` - The storage options provider
    /// * `refresh_offset` - Duration before expiry to refresh credentials
    pub fn new(provider: Arc<dyn StorageOptionsProvider>, refresh_offset: Duration) -> Self {
        Self {
            provider,
            cache: Arc::new(RwLock::new(None)),
            refresh_offset,
        }
    }

    /// Create a new credential provider with initial credentials from an explicit credential
    ///
    /// # Arguments
    /// * `provider` - The storage options provider
    /// * `refresh_offset` - Duration before expiry to refresh credentials
    /// * `credential` - Initial credential to cache
    /// * `expires_at_millis` - Expiration time in milliseconds since epoch (required for refresh)
    pub fn new_with_initial_credential(
        provider: Arc<dyn StorageOptionsProvider>,
        refresh_offset: Duration,
        credential: Arc<ObjectStoreAwsCredential>,
        expires_at_millis: u64,
    ) -> Self {
        Self {
            provider,
            // Seed the cache so the first request needs no provider call.
            cache: Arc::new(RwLock::new(Some(CachedCredential {
                credential,
                expires_at_millis: Some(expires_at_millis),
            }))),
            refresh_offset,
        }
    }

    /// Whether the cached credential (if any) must be refreshed: true when
    /// the cache is empty or the credential is within `refresh_offset` of
    /// its expiry. Credentials without an expiry never need refresh.
    fn needs_refresh(&self, cached: &Option<CachedCredential>) -> bool {
        match cached {
            None => true,
            Some(cached_cred) => {
                if let Some(expires_at_millis) = cached_cred.expires_at_millis {
                    let now_ms = SystemTime::now()
                        .duration_since(UNIX_EPOCH)
                        .unwrap_or(Duration::from_secs(0))
                        .as_millis() as u64;

                    // Refresh if we're within the refresh offset of expiration
                    let refresh_offset_millis = self.refresh_offset.as_millis() as u64;
                    now_ms + refresh_offset_millis >= expires_at_millis
                } else {
                    // No expiration means credentials never expire
                    false
                }
            }
        }
    }

    /// Return the (possibly refreshed) credential, or `Ok(None)` when another
    /// task currently holds the refresh lock (caller should retry).
    ///
    /// Locking protocol: fast path under the read lock; refresh path takes
    /// `try_write` (non-blocking) and re-checks staleness before fetching,
    /// so at most one task refreshes at a time.
    async fn do_get_credential(&self) -> ObjectStoreResult<Option<Arc<ObjectStoreAwsCredential>>> {
        // Check if we have valid cached credentials with read lock
        {
            let cached = self.cache.read().await;
            if !self.needs_refresh(&cached) {
                if let Some(cached_cred) = &*cached {
                    return Ok(Some(cached_cred.credential.clone()));
                }
            }
        }

        // Try to acquire write lock - if it fails, return None and let caller retry
        let Ok(mut cache) = self.cache.try_write() else {
            return Ok(None);
        };

        // Double-check if credentials are still stale after acquiring write lock
        // (another thread might have refreshed them)
        if !self.needs_refresh(&cache) {
            if let Some(cached_cred) = &*cache {
                return Ok(Some(cached_cred.credential.clone()));
            }
        }

        log::debug!(
            "Refreshing S3 credentials from storage options provider: {}",
            self.provider.provider_id()
        );

        // Fetch fresh storage options; both a provider error and an absent
        // result are surfaced as object_store Generic errors.
        let storage_options_map = self
            .provider
            .fetch_storage_options()
            .await
            .map_err(|e| object_store::Error::Generic {
                store: "DynamicStorageOptionsCredentialProvider",
                source: Box::new(e),
            })?
            .ok_or_else(|| object_store::Error::Generic {
                store: "DynamicStorageOptionsCredentialProvider",
                source: "No storage options available".into(),
            })?;

        let storage_options = StorageOptions(storage_options_map);
        let expires_at_millis = storage_options.expires_at_millis();
        let s3_options = storage_options.as_s3_options();
        let static_creds = extract_static_s3_credentials(&s3_options).ok_or_else(|| {
            object_store::Error::Generic {
                store: "DynamicStorageOptionsCredentialProvider",
                source: "Missing required credentials in storage options".into(),
            }
        })?;

        let credential =
            static_creds
                .get_credential()
                .await
                .map_err(|e| object_store::Error::Generic {
                    store: "DynamicStorageOptionsCredentialProvider",
                    source: Box::new(e),
                })?;

        if let Some(expires_at) = expires_at_millis {
            let now_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or(Duration::from_secs(0))
                .as_millis() as u64;
            let expires_in_secs = (expires_at.saturating_sub(now_ms)) / 1000;
            log::debug!(
                "Successfully refreshed S3 credentials from provider: {}, credentials expire in {} seconds",
                self.provider.provider_id(),
                expires_in_secs
            );
        } else {
            log::debug!(
                "Successfully refreshed S3 credentials from provider: {} (no expiration)",
                self.provider.provider_id()
            );
        }

        // Publish while still holding the write lock.
        *cache = Some(CachedCredential {
            credential: credential.clone(),
            expires_at_millis,
        });

        Ok(Some(credential))
    }
}
678
679#[async_trait::async_trait]
680impl CredentialProvider for DynamicStorageOptionsCredentialProvider {
681    type Credential = ObjectStoreAwsCredential;
682
683    async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
684        // Retry loop - if do_get_credential returns None (lock busy), retry from the beginning
685        loop {
686            match self.do_get_credential().await? {
687                Some(cred) => return Ok(cred),
688                None => {
689                    // Lock was busy, wait 10ms before retrying
690                    tokio::time::sleep(Duration::from_millis(10)).await;
691                    continue;
692                }
693            }
694        }
695    }
696}
697
698#[cfg(test)]
699mod tests {
700    use crate::object_store::ObjectStoreRegistry;
701    use mock_instant::thread_local::MockClock;
702    use object_store::path::Path;
703    use std::sync::atomic::{AtomicBool, Ordering};
704
705    use super::*;
706
707    #[derive(Debug, Default)]
708    struct MockAwsCredentialsProvider {
709        called: AtomicBool,
710    }
711
712    #[async_trait::async_trait]
713    impl CredentialProvider for MockAwsCredentialsProvider {
714        type Credential = ObjectStoreAwsCredential;
715
716        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
717            self.called.store(true, Ordering::Relaxed);
718            Ok(Arc::new(Self::Credential {
719                key_id: "".to_string(),
720                secret_key: "".to_string(),
721                token: None,
722            }))
723        }
724    }
725
726    #[tokio::test]
727    async fn test_injected_aws_creds_option_is_used() {
728        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
729        let registry = Arc::new(ObjectStoreRegistry::default());
730
731        let params = ObjectStoreParams {
732            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
733            ..ObjectStoreParams::default()
734        };
735
736        // Not called yet
737        assert!(!mock_provider.called.load(Ordering::Relaxed));
738
739        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
740            .await
741            .unwrap();
742
743        // fails, but we don't care
744        let _ = store
745            .open(&Path::parse("/").unwrap())
746            .await
747            .unwrap()
748            .get_range(0..1)
749            .await;
750
751        // Not called yet
752        assert!(mock_provider.called.load(Ordering::Relaxed));
753    }
754
755    #[test]
756    fn test_s3_path_parsing() {
757        let provider = AwsStoreProvider;
758
759        let cases = [
760            ("s3://bucket/path/to/file", "path/to/file"),
761            // for non ASCII string tests
762            ("s3://bucket/测试path/to/file", "测试path/to/file"),
763            ("s3://bucket/path/&to/file", "path/&to/file"),
764            ("s3://bucket/path/=to/file", "path/=to/file"),
765            (
766                "s3+ddb://bucket/path/to/file?ddbTableName=test",
767                "path/to/file",
768            ),
769        ];
770
771        for (uri, expected_path) in cases {
772            let url = Url::parse(uri).unwrap();
773            let path = provider.extract_path(&url).unwrap();
774            let expected_path = Path::from(expected_path);
775            assert_eq!(path, expected_path)
776        }
777    }
778
779    #[test]
780    fn test_is_s3_express() {
781        let cases = [
782            (
783                "s3://bucket/path/to/file",
784                HashMap::from([("s3_express".to_string(), "true".to_string())]),
785                true,
786            ),
787            (
788                "s3://bucket/path/to/file",
789                HashMap::from([("s3_express".to_string(), "false".to_string())]),
790                false,
791            ),
792            ("s3://bucket/path/to/file", HashMap::from([]), false),
793            (
794                "s3://bucket--x-s3/path/to/file",
795                HashMap::from([("s3_express".to_string(), "true".to_string())]),
796                true,
797            ),
798            (
799                "s3://bucket--x-s3/path/to/file",
800                HashMap::from([("s3_express".to_string(), "false".to_string())]),
801                true, // URL takes precedence
802            ),
803            ("s3://bucket--x-s3/path/to/file", HashMap::from([]), true),
804        ];
805
806        for (uri, storage_map, expected) in cases {
807            let url = Url::parse(uri).unwrap();
808            let storage_options = StorageOptions(storage_map);
809            let is_s3_express = check_s3_express(&url, &storage_options);
810            assert_eq!(is_s3_express, expected);
811        }
812    }
813
814    #[tokio::test]
815    async fn test_use_opendal_flag() {
816        let provider = AwsStoreProvider;
817        let url = Url::parse("s3://test-bucket/path").unwrap();
818        let params_with_flag = ObjectStoreParams {
819            storage_options: Some(HashMap::from([
820                ("use_opendal".to_string(), "true".to_string()),
821                ("region".to_string(), "us-west-2".to_string()),
822            ])),
823            ..Default::default()
824        };
825
826        let store = provider
827            .new_store(url.clone(), &params_with_flag)
828            .await
829            .unwrap();
830        assert_eq!(store.scheme, "s3");
831    }
832
    // Test double for `StorageOptionsProvider` that mints a fresh credential
    // set on every fetch and counts how often it is consulted.
    #[derive(Debug)]
    struct MockStorageOptionsProvider {
        // Times `fetch_storage_options` has run; wrapped in Arc<RwLock> so
        // tests can observe it after handing the mock to the code under test.
        call_count: Arc<RwLock<usize>>,
        // If Some, every fetched option set carries an `expires_at_millis`
        // entry this many milliseconds past the (mocked) current time.
        expires_in_millis: Option<u64>,
    }
838
839    impl MockStorageOptionsProvider {
840        fn new(expires_in_millis: Option<u64>) -> Self {
841            Self {
842                call_count: Arc::new(RwLock::new(0)),
843                expires_in_millis,
844            }
845        }
846
847        async fn get_call_count(&self) -> usize {
848            *self.call_count.read().await
849        }
850    }
851
852    #[async_trait::async_trait]
853    impl StorageOptionsProvider for MockStorageOptionsProvider {
854        async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
855            let count = {
856                let mut c = self.call_count.write().await;
857                *c += 1;
858                *c
859            };
860
861            let mut options = HashMap::from([
862                ("aws_access_key_id".to_string(), format!("AKID_{}", count)),
863                (
864                    "aws_secret_access_key".to_string(),
865                    format!("SECRET_{}", count),
866                ),
867                ("aws_session_token".to_string(), format!("TOKEN_{}", count)),
868            ]);
869
870            if let Some(expires_in) = self.expires_in_millis {
871                let now_ms = SystemTime::now()
872                    .duration_since(UNIX_EPOCH)
873                    .unwrap()
874                    .as_millis() as u64;
875                let expires_at = now_ms + expires_in;
876                options.insert("expires_at_millis".to_string(), expires_at.to_string());
877            }
878
879            Ok(Some(options))
880        }
881
882        fn provider_id(&self) -> String {
883            let ptr = Arc::as_ptr(&self.call_count) as usize;
884            format!("MockStorageOptionsProvider {{ id: {} }}", ptr)
885        }
886    }
887
888    #[tokio::test]
889    async fn test_dynamic_credential_provider_with_initial_cache() {
890        MockClock::set_system_time(Duration::from_secs(100_000));
891
892        let now_ms = MockClock::system_time().as_millis() as u64;
893
894        // Create a mock provider that returns credentials expiring in 10 minutes
895        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
896            600_000, // Expires in 10 minutes
897        )));
898
899        // Create credential provider with initial cached credentials that expire in 10 minutes
900        let expires_at = now_ms + 600_000; // 10 minutes from now
901        let initial_cred = Arc::new(ObjectStoreAwsCredential {
902            key_id: "AKID_CACHED".to_string(),
903            secret_key: "SECRET_CACHED".to_string(),
904            token: Some("TOKEN_CACHED".to_string()),
905        });
906
907        let provider = DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
908            mock.clone(),
909            Duration::from_secs(300), // 5 minute refresh offset
910            initial_cred,
911            expires_at,
912        );
913
914        // First call should use cached credentials (not expired yet)
915        let cred = provider.get_credential().await.unwrap();
916        assert_eq!(cred.key_id, "AKID_CACHED");
917        assert_eq!(cred.secret_key, "SECRET_CACHED");
918        assert_eq!(cred.token, Some("TOKEN_CACHED".to_string()));
919
920        // Should not have called the provider yet
921        assert_eq!(mock.get_call_count().await, 0);
922    }
923
924    #[tokio::test]
925    async fn test_dynamic_credential_provider_with_expired_cache() {
926        MockClock::set_system_time(Duration::from_secs(100_000));
927
928        let now_ms = MockClock::system_time().as_millis() as u64;
929
930        // Create a mock provider that returns credentials expiring in 10 minutes
931        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
932            600_000, // Expires in 10 minutes
933        )));
934
935        // Create credential provider with initial cached credentials that expired 1 second ago
936        let expired_time = now_ms - 1_000; // 1 second ago
937        let initial_cred = Arc::new(ObjectStoreAwsCredential {
938            key_id: "AKID_EXPIRED".to_string(),
939            secret_key: "SECRET_EXPIRED".to_string(),
940            token: None,
941        });
942
943        let provider = DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
944            mock.clone(),
945            Duration::from_secs(300), // 5 minute refresh offset
946            initial_cred,
947            expired_time,
948        );
949
950        // First call should fetch new credentials because cached ones are expired
951        let cred = provider.get_credential().await.unwrap();
952        assert_eq!(cred.key_id, "AKID_1");
953        assert_eq!(cred.secret_key, "SECRET_1");
954        assert_eq!(cred.token, Some("TOKEN_1".to_string()));
955
956        // Should have called the provider once
957        assert_eq!(mock.get_call_count().await, 1);
958    }
959
960    #[tokio::test]
961    async fn test_dynamic_credential_provider_refresh_lead_time() {
962        MockClock::set_system_time(Duration::from_secs(100_000));
963
964        // Create a mock provider that returns credentials expiring in 4 minutes
965        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
966            240_000, // Expires in 4 minutes
967        )));
968
969        // Create credential provider with 5 minute refresh offset
970        // This means credentials should be refreshed when they have less than 5 minutes left
971        let provider = DynamicStorageOptionsCredentialProvider::new(
972            mock.clone(),
973            Duration::from_secs(300), // 5 minute refresh offset
974        );
975
976        // First call should fetch credentials from provider (no initial cache)
977        // Credentials expire in 4 minutes, which is less than our 5 minute refresh offset,
978        // so they should be considered "needs refresh" immediately
979        let cred = provider.get_credential().await.unwrap();
980        assert_eq!(cred.key_id, "AKID_1");
981        assert_eq!(mock.get_call_count().await, 1);
982
983        // Second call should trigger refresh because credentials expire in 4 minutes
984        // but our refresh lead time is 5 minutes (now + 5min > expires_at)
985        // The mock will return new credentials (AKID_2) with the same expiration
986        let cred = provider.get_credential().await.unwrap();
987        assert_eq!(cred.key_id, "AKID_2");
988        assert_eq!(mock.get_call_count().await, 2);
989    }
990
991    #[tokio::test]
992    async fn test_dynamic_credential_provider_no_initial_cache() {
993        MockClock::set_system_time(Duration::from_secs(100_000));
994
995        // Create a mock provider that returns credentials expiring in 10 minutes
996        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
997            600_000, // Expires in 10 minutes
998        )));
999
1000        // Create credential provider without initial cache
1001        let provider = DynamicStorageOptionsCredentialProvider::new(
1002            mock.clone(),
1003            Duration::from_secs(300), // 5 minute refresh offset
1004        );
1005
1006        // First call should fetch from provider (call count = 1)
1007        let cred = provider.get_credential().await.unwrap();
1008        assert_eq!(cred.key_id, "AKID_1");
1009        assert_eq!(cred.secret_key, "SECRET_1");
1010        assert_eq!(cred.token, Some("TOKEN_1".to_string()));
1011        assert_eq!(mock.get_call_count().await, 1);
1012
1013        // Second call should use cached credentials (not expired yet)
1014        let cred = provider.get_credential().await.unwrap();
1015        assert_eq!(cred.key_id, "AKID_1");
1016        assert_eq!(mock.get_call_count().await, 1); // Still 1, didn't fetch again
1017
1018        // Advance time to 6 minutes - should trigger refresh (within 5 min refresh offset)
1019        MockClock::set_system_time(Duration::from_secs(100_000 + 360));
1020        let cred = provider.get_credential().await.unwrap();
1021        assert_eq!(cred.key_id, "AKID_2");
1022        assert_eq!(cred.secret_key, "SECRET_2");
1023        assert_eq!(cred.token, Some("TOKEN_2".to_string()));
1024        assert_eq!(mock.get_call_count().await, 2);
1025
1026        // Advance time to 11 minutes total - should trigger another refresh
1027        MockClock::set_system_time(Duration::from_secs(100_000 + 660));
1028        let cred = provider.get_credential().await.unwrap();
1029        assert_eq!(cred.key_id, "AKID_3");
1030        assert_eq!(cred.secret_key, "SECRET_3");
1031        assert_eq!(mock.get_call_count().await, 3);
1032    }
1033
1034    #[tokio::test]
1035    async fn test_dynamic_credential_provider_with_initial_credential() {
1036        MockClock::set_system_time(Duration::from_secs(100_000));
1037
1038        let now_ms = MockClock::system_time().as_millis() as u64;
1039
1040        // Create a mock provider that returns credentials expiring in 10 minutes
1041        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
1042            600_000, // Expires in 10 minutes
1043        )));
1044
1045        // Create an initial credential with expiration in 10 minutes
1046        let expires_at = now_ms + 600_000; // 10 minutes from now
1047        let initial_cred = Arc::new(ObjectStoreAwsCredential {
1048            key_id: "AKID_INITIAL".to_string(),
1049            secret_key: "SECRET_INITIAL".to_string(),
1050            token: Some("TOKEN_INITIAL".to_string()),
1051        });
1052
1053        // Create credential provider with initial credential and expiration
1054        let provider = DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
1055            mock.clone(),
1056            Duration::from_secs(300), // 5 minute refresh offset
1057            initial_cred,
1058            expires_at,
1059        );
1060
1061        // First call should use the initial credential (not expired yet)
1062        let cred = provider.get_credential().await.unwrap();
1063        assert_eq!(cred.key_id, "AKID_INITIAL");
1064        assert_eq!(cred.secret_key, "SECRET_INITIAL");
1065        assert_eq!(cred.token, Some("TOKEN_INITIAL".to_string()));
1066
1067        // Should not have called the provider yet
1068        assert_eq!(mock.get_call_count().await, 0);
1069
1070        // Advance time to 6 minutes - this should trigger a refresh
1071        // (5 minute refresh offset means we refresh 5 minutes before expiration)
1072        MockClock::set_system_time(Duration::from_secs(100_000 + 360));
1073        let cred = provider.get_credential().await.unwrap();
1074        assert_eq!(cred.key_id, "AKID_1");
1075        assert_eq!(cred.secret_key, "SECRET_1");
1076        assert_eq!(cred.token, Some("TOKEN_1".to_string()));
1077
1078        // Should have called the provider once
1079        assert_eq!(mock.get_call_count().await, 1);
1080
1081        // Advance time to 11 minutes total - this should trigger another refresh
1082        MockClock::set_system_time(Duration::from_secs(100_000 + 660));
1083        let cred = provider.get_credential().await.unwrap();
1084        assert_eq!(cred.key_id, "AKID_2");
1085        assert_eq!(cred.secret_key, "SECRET_2");
1086        assert_eq!(cred.token, Some("TOKEN_2".to_string()));
1087
1088        // Should have called the provider twice
1089        assert_eq!(mock.get_call_count().await, 2);
1090
1091        // Advance time to 16 minutes total - this should trigger yet another refresh
1092        MockClock::set_system_time(Duration::from_secs(100_000 + 960));
1093        let cred = provider.get_credential().await.unwrap();
1094        assert_eq!(cred.key_id, "AKID_3");
1095        assert_eq!(cred.secret_key, "SECRET_3");
1096        assert_eq!(cred.token, Some("TOKEN_3".to_string()));
1097
1098        // Should have called the provider three times
1099        assert_eq!(mock.get_call_count().await, 3);
1100    }
1101
1102    #[tokio::test]
1103    async fn test_dynamic_credential_provider_concurrent_access() {
1104        // Create a mock provider with far future expiration
1105        let mock = Arc::new(MockStorageOptionsProvider::new(Some(9999999999999)));
1106
1107        let provider = Arc::new(DynamicStorageOptionsCredentialProvider::new(
1108            mock.clone(),
1109            Duration::from_secs(300),
1110        ));
1111
1112        // Spawn 10 concurrent tasks that all try to get credentials at the same time
1113        let mut handles = vec![];
1114        for i in 0..10 {
1115            let provider = provider.clone();
1116            let handle = tokio::spawn(async move {
1117                let cred = provider.get_credential().await.unwrap();
1118                // Verify we got the correct credentials (should all be AKID_1 from first fetch)
1119                assert_eq!(cred.key_id, "AKID_1");
1120                assert_eq!(cred.secret_key, "SECRET_1");
1121                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
1122                i // Return task number for verification
1123            });
1124            handles.push(handle);
1125        }
1126
1127        // Wait for all tasks to complete
1128        let results: Vec<_> = futures::future::join_all(handles)
1129            .await
1130            .into_iter()
1131            .map(|r| r.unwrap())
1132            .collect();
1133
1134        // Verify all 10 tasks completed successfully
1135        assert_eq!(results.len(), 10);
1136        for i in 0..10 {
1137            assert!(results.contains(&i));
1138        }
1139
1140        // The provider should have been called exactly once (first request triggers fetch,
1141        // subsequent requests use cache)
1142        let call_count = mock.get_call_count().await;
1143        assert_eq!(
1144            call_count, 1,
1145            "Provider should be called exactly once despite concurrent access"
1146        );
1147    }
1148
1149    #[tokio::test]
1150    async fn test_dynamic_credential_provider_concurrent_refresh() {
1151        MockClock::set_system_time(Duration::from_secs(100_000));
1152
1153        let now_ms = MockClock::system_time().as_millis() as u64;
1154
1155        // Create initial credentials that expired in the past (1000 seconds ago)
1156        let expires_at = now_ms - 1_000_000;
1157
1158        let initial_cred = Arc::new(ObjectStoreAwsCredential {
1159            key_id: "AKID_OLD".to_string(),
1160            secret_key: "SECRET_OLD".to_string(),
1161            token: Some("TOKEN_OLD".to_string()),
1162        });
1163
1164        // Mock will return credentials expiring in 1 hour
1165        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
1166            3_600_000, // Expires in 1 hour
1167        )));
1168
1169        let provider = Arc::new(
1170            DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
1171                mock.clone(),
1172                Duration::from_secs(300),
1173                initial_cred,
1174                expires_at,
1175            ),
1176        );
1177
1178        // Spawn 20 concurrent tasks that all try to get credentials at the same time
1179        // Since the initial credential is expired, they'll all try to refresh
1180        let mut handles = vec![];
1181        for i in 0..20 {
1182            let provider = provider.clone();
1183            let handle = tokio::spawn(async move {
1184                let cred = provider.get_credential().await.unwrap();
1185                // All should get the new credentials (AKID_1 from first fetch)
1186                assert_eq!(cred.key_id, "AKID_1");
1187                assert_eq!(cred.secret_key, "SECRET_1");
1188                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
1189                i
1190            });
1191            handles.push(handle);
1192        }
1193
1194        // Wait for all tasks to complete
1195        let results: Vec<_> = futures::future::join_all(handles)
1196            .await
1197            .into_iter()
1198            .map(|r| r.unwrap())
1199            .collect();
1200
1201        // Verify all 20 tasks completed successfully
1202        assert_eq!(results.len(), 20);
1203
1204        // The provider should have been called at least once, but possibly more times
1205        // due to the try_write mechanism and race conditions
1206        let call_count = mock.get_call_count().await;
1207        assert!(
1208            call_count >= 1,
1209            "Provider should be called at least once, was called {} times",
1210            call_count
1211        );
1212
1213        // It shouldn't be called 20 times though - the lock should prevent most concurrent fetches
1214        assert!(
1215            call_count < 10,
1216            "Provider should not be called too many times due to lock contention, was called {} times",
1217            call_count
1218        );
1219    }
1220}