lance_io/object_store/providers/
aws.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{collections::HashMap, fmt, str::FromStr, sync::Arc, time::Duration};
5
6#[cfg(test)]
7use mock_instant::thread_local::{SystemTime, UNIX_EPOCH};
8
9#[cfg(not(test))]
10use std::time::{SystemTime, UNIX_EPOCH};
11
12use object_store::ObjectStore as OSObjectStore;
13use object_store_opendal::OpendalStore;
14use opendal::{services::S3, Operator};
15
16use aws_config::default_provider::credentials::DefaultCredentialsChain;
17use aws_credential_types::provider::ProvideCredentials;
18use object_store::{
19    aws::{
20        AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
21        AwsCredentialProvider,
22    },
23    ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
24    StaticCredentialProvider,
25};
26use snafu::location;
27use tokio::sync::RwLock;
28use url::Url;
29
30use crate::object_store::{
31    ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, StorageOptionsProvider,
32    DEFAULT_CLOUD_BLOCK_SIZE, DEFAULT_CLOUD_IO_PARALLELISM, DEFAULT_MAX_IOP_SIZE,
33};
34use lance_core::error::{Error, Result};
35
/// Object store provider for Amazon S3 and S3-compatible stores
/// (registered for `s3://` / `s3+ddb://` URLs).
#[derive(Default, Debug)]
pub struct AwsStoreProvider;
38
impl AwsStoreProvider {
    /// Build an Amazon S3 store using the native `object_store` backend.
    ///
    /// Resolves the region and credentials from `params`/`storage_options`,
    /// rewrites `base_path` in place to a plain `s3://` URL (dropping the
    /// query string, e.g. DynamoDB commit-store parameters), and applies the
    /// retry policy configured in the storage options.
    async fn build_amazon_s3_store(
        &self,
        base_path: &mut Url,
        params: &ObjectStoreParams,
        storage_options: &StorageOptions,
        is_s3_express: bool,
    ) -> Result<Arc<dyn OSObjectStore>> {
        // Retry policy comes from user-supplied storage options.
        let max_retries = storage_options.client_max_retries();
        let retry_timeout = storage_options.client_retry_timeout();
        let retry_config = RetryConfig {
            backoff: Default::default(),
            max_retries,
            retry_timeout: Duration::from_secs(retry_timeout),
        };

        let mut s3_storage_options = storage_options.as_s3_options();
        // Region must be known before building credentials; may hit the S3
        // API when no explicit region/endpoint is configured.
        let region = resolve_s3_region(base_path, &s3_storage_options).await?;
        let (aws_creds, region) = build_aws_credential(
            params.s3_credentials_refresh_offset,
            params.aws_credentials.clone(),
            Some(&s3_storage_options),
            region,
            params.storage_options_provider.clone(),
            storage_options.expires_at_millis(),
        )
        .await?;

        // Set S3Express flag if detected
        if is_s3_express {
            s3_storage_options.insert(AmazonS3ConfigKey::S3Express, true.to_string());
        }

        // before creating the OSObjectStore we need to rewrite the url to drop ddb related parts
        base_path.set_scheme("s3").unwrap();
        base_path.set_query(None);

        // we can't use parse_url_opts here because we need to manually set the credentials provider
        let mut builder = AmazonS3Builder::new();
        for (key, value) in s3_storage_options {
            builder = builder.with_config(key, value);
        }
        builder = builder
            .with_url(base_path.as_ref())
            .with_credentials(aws_creds)
            .with_retry(retry_config)
            .with_region(region);

        Ok(Arc::new(builder.build()?) as Arc<dyn OSObjectStore>)
    }

    /// Build an S3 store backed by OpenDAL.
    ///
    /// Storage options are forwarded verbatim to the OpenDAL S3 service,
    /// which resolves credentials through its own default chain.
    async fn build_opendal_s3_store(
        &self,
        base_path: &Url,
        storage_options: &StorageOptions,
    ) -> Result<Arc<dyn OSObjectStore>> {
        let bucket = base_path
            .host_str()
            .ok_or_else(|| Error::invalid_input("S3 URL must contain bucket name", location!()))?
            .to_string();

        let prefix = base_path.path().trim_start_matches('/').to_string();

        // Start with all storage options as the config map
        // OpenDAL will handle environment variables through its default credentials chain
        let mut config_map: HashMap<String, String> = storage_options.0.clone();

        // Set required OpenDAL configuration
        config_map.insert("bucket".to_string(), bucket);

        // NOTE(review): `prefix` only gates setting root to "/", which is
        // already OpenDAL's default, so this branch is effectively a no-op.
        // If the intent was to scope the operator to the URL path it should
        // presumably be `format!("/{}", prefix)` — confirm against how
        // callers pass object paths before changing.
        if !prefix.is_empty() {
            config_map.insert("root".to_string(), "/".to_string());
        }

        let operator = Operator::from_iter::<S3>(config_map)
            .map_err(|e| {
                Error::invalid_input(
                    format!("Failed to create S3 operator: {:?}", e),
                    location!(),
                )
            })?
            .finish();

        Ok(Arc::new(OpendalStore::new(operator)) as Arc<dyn OSObjectStore>)
    }
}
125
#[async_trait::async_trait]
impl ObjectStoreProvider for AwsStoreProvider {
    /// Create an [`ObjectStore`] for an `s3://` (or `s3+ddb://`) URL.
    ///
    /// Selects between the OpenDAL and native `object_store` S3 backends via
    /// the `use_opendal` storage option (default: native).
    async fn new_store(
        &self,
        mut base_path: Url,
        params: &ObjectStoreParams,
    ) -> Result<ObjectStore> {
        let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
        let mut storage_options =
            StorageOptions(params.storage_options.clone().unwrap_or_default());
        // Fill in S3 settings from the environment that the caller did not
        // set explicitly.
        storage_options.with_env_s3();
        let download_retry_count = storage_options.download_retry_count();

        let use_opendal = storage_options
            .0
            .get("use_opendal")
            .map(|v| v == "true")
            .unwrap_or(false);

        // Determine S3 Express and constant size upload parts before building the store
        let is_s3_express = check_s3_express(&base_path, &storage_options);

        // Cloudflare R2 requires all multipart upload parts (except the last)
        // to have the same size.
        let use_constant_size_upload_parts = storage_options
            .0
            .get("aws_endpoint")
            .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
            .unwrap_or(false);

        let inner = if use_opendal {
            // Use OpenDAL implementation
            self.build_opendal_s3_store(&base_path, &storage_options)
                .await?
        } else {
            // Use default Amazon S3 implementation
            // (also rewrites `base_path` to a plain `s3://` URL in place)
            self.build_amazon_s3_store(&mut base_path, params, &storage_options, is_s3_express)
                .await?
        };

        Ok(ObjectStore {
            inner,
            scheme: String::from(base_path.scheme()),
            block_size,
            max_iop_size: *DEFAULT_MAX_IOP_SIZE,
            use_constant_size_upload_parts,
            // S3 Express directory buckets do not list in lexicographic order.
            list_is_lexically_ordered: !is_s3_express,
            io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
            download_retry_count,
        })
    }
}
176
177/// Check if the storage is S3 Express
178fn check_s3_express(url: &Url, storage_options: &StorageOptions) -> bool {
179    storage_options
180        .0
181        .get("s3_express")
182        .map(|v| v == "true")
183        .unwrap_or(false)
184        || url.authority().ends_with("--x-s3")
185}
186
187/// Figure out the S3 region of the bucket.
188///
189/// This resolves in order of precedence:
190/// 1. The region provided in the storage options
191/// 2. (If endpoint is not set), the region returned by the S3 API for the bucket
192///
193/// It can return None if no region is provided and the endpoint is set.
194async fn resolve_s3_region(
195    url: &Url,
196    storage_options: &HashMap<AmazonS3ConfigKey, String>,
197) -> Result<Option<String>> {
198    if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
199        Ok(Some(region.clone()))
200    } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
201        // If no endpoint is set, we can assume this is AWS S3 and the region
202        // can be resolved from the bucket.
203        let bucket = url.host_str().ok_or_else(|| {
204            Error::invalid_input(
205                format!("Could not parse bucket from url: {}", url),
206                location!(),
207            )
208        })?;
209
210        let mut client_options = ClientOptions::default();
211        for (key, value) in storage_options {
212            if let AmazonS3ConfigKey::Client(client_key) = key {
213                client_options = client_options.with_config(*client_key, value.clone());
214            }
215        }
216
217        let bucket_region =
218            object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
219        Ok(Some(bucket_region))
220    } else {
221        Ok(None)
222    }
223}
224
225/// Build AWS credentials
226///
227/// This resolves credentials from the following sources in order:
228/// 1. An explicit `storage_options_provider`
229/// 2. An explicit `credentials` provider
230/// 3. Explicit credentials in storage_options (as in `aws_access_key_id`,
231///    `aws_secret_access_key`, `aws_session_token`)
232/// 4. The default credential provider chain from AWS SDK.
233///
234/// # Initial Credentials with Storage Options Provider
235///
236/// When `storage_options_provider` is provided along with `storage_options` and
237/// `expires_at_millis`, these serve as **initial values** to avoid redundant calls to
238/// fetch new storage options. The provider will use these initial credentials until they
239/// expire (based on `expires_at_millis`), then automatically fetch fresh credentials from
240/// the provider. Once the initial credentials expire, the passed-in values are no longer
241/// used - all future credentials come from the provider's `fetch_storage_options()` method.
242///
243/// `credentials_refresh_offset` is the amount of time before expiry to refresh credentials.
244pub async fn build_aws_credential(
245    credentials_refresh_offset: Duration,
246    credentials: Option<AwsCredentialProvider>,
247    storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
248    region: Option<String>,
249    storage_options_provider: Option<Arc<dyn StorageOptionsProvider>>,
250    expires_at_millis: Option<u64>,
251) -> Result<(AwsCredentialProvider, String)> {
252    // TODO: make this return no credential provider not using AWS
253    use aws_config::meta::region::RegionProviderChain;
254    const DEFAULT_REGION: &str = "us-west-2";
255
256    let region = if let Some(region) = region {
257        region
258    } else {
259        RegionProviderChain::default_provider()
260            .or_else(DEFAULT_REGION)
261            .region()
262            .await
263            .map(|r| r.as_ref().to_string())
264            .unwrap_or(DEFAULT_REGION.to_string())
265    };
266
267    let storage_options_credentials = storage_options.and_then(extract_static_s3_credentials);
268    if let Some(storage_options_provider) = storage_options_provider {
269        let creds = build_aws_credential_with_storage_options_provider(
270            storage_options_provider,
271            credentials_refresh_offset,
272            credentials,
273            storage_options_credentials,
274            expires_at_millis,
275        )
276        .await?;
277        Ok((creds, region))
278    } else if let Some(creds) = credentials {
279        Ok((creds, region))
280    } else if let Some(creds) = storage_options_credentials {
281        Ok((Arc::new(creds), region))
282    } else {
283        let credentials_provider = DefaultCredentialsChain::builder().build().await;
284
285        Ok((
286            Arc::new(AwsCredentialAdapter::new(
287                Arc::new(credentials_provider),
288                credentials_refresh_offset,
289            )),
290            region,
291        ))
292    }
293}
294
295async fn build_aws_credential_with_storage_options_provider(
296    storage_options_provider: Arc<dyn StorageOptionsProvider>,
297    credentials_refresh_offset: Duration,
298    credentials: Option<AwsCredentialProvider>,
299    storage_options_credentials: Option<StaticCredentialProvider<ObjectStoreAwsCredential>>,
300    expires_at_millis: Option<u64>,
301) -> Result<AwsCredentialProvider> {
302    match (expires_at_millis, credentials, storage_options_credentials) {
303        // Case 1: provider + credentials + expiration time
304        (Some(expires_at), Some(cred), _) => {
305            Ok(Arc::new(
306                DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
307                    storage_options_provider,
308                    credentials_refresh_offset,
309                    cred.get_credential().await?,
310                    expires_at,
311                ),
312            ))
313        }
314        // Case 2: provider + storage_options (with valid credentials) + expiration time
315        (Some(expires_at), None, Some(cred)) => {
316            Ok(Arc::new(
317                DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
318                    storage_options_provider,
319                    credentials_refresh_offset,
320                    cred.get_credential().await?,
321                    expires_at,
322                ),
323            ))
324        }
325        // Case 3: provider + storage_options without expiration - FAIL
326        (None, None, Some(_)) => Err(Error::IO {
327            source: Box::new(std::io::Error::other(
328                "expires_at_millis is required when using storage_options_provider with storage_options",
329            )),
330            location: location!(),
331        }),
332        // Case 4: provider + credentials without expiration - FAIL
333        (None, Some(_), _) => Err(Error::IO {
334            source: Box::new(std::io::Error::other(
335                "expires_at_millis is required when using storage_options_provider with credentials",
336            )),
337            location: location!(),
338        }),
339        // Case 5: provider without credentials/storage_options, or with expiration but no creds/opts
340        (_, None, None) => Ok(Arc::new(DynamicStorageOptionsCredentialProvider::new(
341            storage_options_provider,
342            credentials_refresh_offset,
343        ))),
344    }
345}
346
347fn extract_static_s3_credentials(
348    options: &HashMap<AmazonS3ConfigKey, String>,
349) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
350    let key_id = options.get(&AmazonS3ConfigKey::AccessKeyId).cloned();
351    let secret_key = options.get(&AmazonS3ConfigKey::SecretAccessKey).cloned();
352    let token = options.get(&AmazonS3ConfigKey::Token).cloned();
353    match (key_id, secret_key, token) {
354        (Some(key_id), Some(secret_key), token) => {
355            Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
356                key_id,
357                secret_key,
358                token,
359            }))
360        }
361        _ => None,
362    }
363}
364
/// Adapt an AWS SDK cred into object_store credentials
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    // The underlying AWS SDK credentials provider being adapted.
    pub inner: Arc<dyn ProvideCredentials>,

    // Cache of fetched SDK credentials (single entry, keyed by
    // AWS_CREDS_CACHE_KEY); the async RwLock makes it shareable and
    // updatable across tasks/threads.
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    // The amount of time before expiry to refresh credentials
    credentials_refresh_offset: Duration,
}
376
377impl AwsCredentialAdapter {
378    pub fn new(
379        provider: Arc<dyn ProvideCredentials>,
380        credentials_refresh_offset: Duration,
381    ) -> Self {
382        Self {
383            inner: provider,
384            cache: Arc::new(RwLock::new(HashMap::new())),
385            credentials_refresh_offset,
386        }
387    }
388}
389
/// Key under which [`AwsCredentialAdapter`] stores its single cached credential.
const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";
391
/// Convert a `std::time::SystemTime` coming from the AWS SDK into this
/// module's (possibly mocked, under `cfg(test)`) `SystemTime`.
///
/// # Panics
/// Panics if `time` is earlier than the UNIX epoch.
fn to_system_time(time: std::time::SystemTime) -> SystemTime {
    UNIX_EPOCH
        + time
            .duration_since(std::time::UNIX_EPOCH)
            .expect("time should be after UNIX_EPOCH")
}
399
400#[async_trait::async_trait]
401impl CredentialProvider for AwsCredentialAdapter {
402    type Credential = ObjectStoreAwsCredential;
403
404    async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
405        let cached_creds = {
406            let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
407            let expired = cache_value
408                .clone()
409                .map(|cred| {
410                    cred.expiry()
411                        .map(|exp| {
412                            to_system_time(exp)
413                                .checked_sub(self.credentials_refresh_offset)
414                                .expect("this time should always be valid")
415                                < SystemTime::now()
416                        })
417                        // no expiry is never expire
418                        .unwrap_or(false)
419                })
420                .unwrap_or(true); // no cred is the same as expired;
421            if expired {
422                None
423            } else {
424                cache_value.clone()
425            }
426        };
427
428        if let Some(creds) = cached_creds {
429            Ok(Arc::new(Self::Credential {
430                key_id: creds.access_key_id().to_string(),
431                secret_key: creds.secret_access_key().to_string(),
432                token: creds.session_token().map(|s| s.to_string()),
433            }))
434        } else {
435            let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
436                |e| Error::Internal {
437                    message: format!("Failed to get AWS credentials: {:?}", e),
438                    location: location!(),
439                },
440            )?);
441
442            self.cache
443                .write()
444                .await
445                .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
446
447            Ok(Arc::new(Self::Credential {
448                key_id: refreshed_creds.access_key_id().to_string(),
449                secret_key: refreshed_creds.secret_access_key().to_string(),
450                token: refreshed_creds.session_token().map(|s| s.to_string()),
451            }))
452        }
453    }
454}
455
456impl StorageOptions {
457    /// Add values from the environment to storage options
458    pub fn with_env_s3(&mut self) {
459        for (os_key, os_value) in std::env::vars_os() {
460            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
461                if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
462                    if !self.0.contains_key(config_key.as_ref()) {
463                        self.0
464                            .insert(config_key.as_ref().to_string(), value.to_string());
465                    }
466                }
467            }
468        }
469    }
470
471    /// Subset of options relevant for s3 storage
472    pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
473        self.0
474            .iter()
475            .filter_map(|(key, value)| {
476                let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
477                Some((s3_key, value.clone()))
478            })
479            .collect()
480    }
481}
482
483impl ObjectStoreParams {
484    /// Create a new instance of [`ObjectStoreParams`] based on the AWS credentials.
485    pub fn with_aws_credentials(
486        aws_credentials: Option<AwsCredentialProvider>,
487        region: Option<String>,
488    ) -> Self {
489        Self {
490            aws_credentials,
491            storage_options: region
492                .map(|region| [("region".into(), region)].iter().cloned().collect()),
493            ..Default::default()
494        }
495    }
496}
497
/// AWS Credential Provider that uses StorageOptionsProvider
///
/// This adapter converts our generic StorageOptionsProvider trait into
/// AWS-specific credentials that can be used with S3. It caches credentials
/// and automatically refreshes them before they expire.
///
/// # Future Work
///
/// TODO: Support AWS/GCP/Azure together in a unified credential provider.
/// Currently this is AWS-specific. Needs investigation of how GCP and Azure credential
/// refresh mechanisms work and whether they can be unified with AWS's approach.
///
/// See: <https://github.com/lancedb/lance/pull/4905#discussion_r2474605265>
pub struct DynamicStorageOptionsCredentialProvider {
    // Source of fresh storage options (and hence fresh credentials).
    provider: Arc<dyn StorageOptionsProvider>,
    // Most recently fetched credential plus its expiry, if any.
    cache: Arc<RwLock<Option<CachedCredential>>>,
    // How long before expiry a refresh is triggered.
    refresh_offset: Duration,
}
516
517impl fmt::Debug for DynamicStorageOptionsCredentialProvider {
518    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
519        f.debug_struct("DynamicStorageOptionsCredentialProvider")
520            .field("provider", &self.provider)
521            .field("refresh_offset", &self.refresh_offset)
522            .finish()
523    }
524}
525
/// A credential together with its optional expiration time.
#[derive(Debug, Clone)]
struct CachedCredential {
    // The cached object_store credential.
    credential: Arc<ObjectStoreAwsCredential>,
    // Milliseconds since the UNIX epoch; `None` means the credential never
    // expires.
    expires_at_millis: Option<u64>,
}
531
impl DynamicStorageOptionsCredentialProvider {
    /// Create a new credential provider without initial credentials
    ///
    /// # Arguments
    /// * `provider` - The storage options provider
    /// * `refresh_offset` - Duration before expiry to refresh credentials
    pub fn new(provider: Arc<dyn StorageOptionsProvider>, refresh_offset: Duration) -> Self {
        Self {
            provider,
            cache: Arc::new(RwLock::new(None)),
            refresh_offset,
        }
    }

    /// Create a new credential provider with initial credentials from an explicit credential
    ///
    /// # Arguments
    /// * `provider` - The storage options provider
    /// * `refresh_offset` - Duration before expiry to refresh credentials
    /// * `credential` - Initial credential to cache
    /// * `expires_at_millis` - Expiration time in milliseconds since epoch (required for refresh)
    pub fn new_with_initial_credential(
        provider: Arc<dyn StorageOptionsProvider>,
        refresh_offset: Duration,
        credential: Arc<ObjectStoreAwsCredential>,
        expires_at_millis: u64,
    ) -> Self {
        Self {
            provider,
            cache: Arc::new(RwLock::new(Some(CachedCredential {
                credential,
                expires_at_millis: Some(expires_at_millis),
            }))),
            refresh_offset,
        }
    }

    /// Whether the cached credential is absent or close enough to expiry
    /// (within `refresh_offset`) that it should be refreshed.
    fn needs_refresh(&self, cached: &Option<CachedCredential>) -> bool {
        match cached {
            None => true,
            Some(cached_cred) => {
                if let Some(expires_at_millis) = cached_cred.expires_at_millis {
                    let now_ms = SystemTime::now()
                        .duration_since(UNIX_EPOCH)
                        .unwrap_or(Duration::from_secs(0))
                        .as_millis() as u64;

                    // Refresh if we're within the refresh offset of expiration
                    let refresh_offset_millis = self.refresh_offset.as_millis() as u64;
                    now_ms + refresh_offset_millis >= expires_at_millis
                } else {
                    // No expiration means credentials never expire
                    false
                }
            }
        }
    }

    /// Return the cached credential, refreshing it from the provider when
    /// stale.
    ///
    /// Returns `Ok(None)` when another task currently holds the write lock
    /// (i.e. a refresh is already in flight); the caller is expected to
    /// retry after a short delay.
    async fn do_get_credential(&self) -> ObjectStoreResult<Option<Arc<ObjectStoreAwsCredential>>> {
        // Check if we have valid cached credentials with read lock
        {
            let cached = self.cache.read().await;
            if !self.needs_refresh(&cached) {
                if let Some(cached_cred) = &*cached {
                    return Ok(Some(cached_cred.credential.clone()));
                }
            }
        }

        // Try to acquire write lock - if it fails, return None and let caller retry
        let Ok(mut cache) = self.cache.try_write() else {
            return Ok(None);
        };

        // Double-check if credentials are still stale after acquiring write lock
        // (another thread might have refreshed them)
        if !self.needs_refresh(&cache) {
            if let Some(cached_cred) = &*cache {
                return Ok(Some(cached_cred.credential.clone()));
            }
        }

        let storage_options_map = self
            .provider
            .fetch_storage_options()
            .await
            .map_err(|e| object_store::Error::Generic {
                store: "DynamicStorageOptionsCredentialProvider",
                source: Box::new(e),
            })?
            .ok_or_else(|| object_store::Error::Generic {
                store: "DynamicStorageOptionsCredentialProvider",
                source: "No storage options available".into(),
            })?;

        // Parse the freshly fetched options into static credentials plus an
        // optional expiry used for the next refresh cycle.
        let storage_options = StorageOptions(storage_options_map);
        let expires_at_millis = storage_options.expires_at_millis();
        let s3_options = storage_options.as_s3_options();
        let static_creds = extract_static_s3_credentials(&s3_options).ok_or_else(|| {
            object_store::Error::Generic {
                store: "DynamicStorageOptionsCredentialProvider",
                source: "Missing required credentials in storage options".into(),
            }
        })?;

        let credential =
            static_creds
                .get_credential()
                .await
                .map_err(|e| object_store::Error::Generic {
                    store: "DynamicStorageOptionsCredentialProvider",
                    source: Box::new(e),
                })?;

        *cache = Some(CachedCredential {
            credential: credential.clone(),
            expires_at_millis,
        });

        Ok(Some(credential))
    }
}
654
655#[async_trait::async_trait]
656impl CredentialProvider for DynamicStorageOptionsCredentialProvider {
657    type Credential = ObjectStoreAwsCredential;
658
659    async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
660        // Retry loop - if do_get_credential returns None (lock busy), retry from the beginning
661        loop {
662            match self.do_get_credential().await? {
663                Some(cred) => return Ok(cred),
664                None => {
665                    // Lock was busy, wait 10ms before retrying
666                    tokio::time::sleep(Duration::from_millis(10)).await;
667                    continue;
668                }
669            }
670        }
671    }
672}
673
674#[cfg(test)]
675mod tests {
676    use crate::object_store::ObjectStoreRegistry;
677    use mock_instant::thread_local::MockClock;
678    use object_store::path::Path;
679    use std::sync::atomic::{AtomicBool, Ordering};
680
681    use super::*;
682
    /// Credential provider that records whether it was ever queried.
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        // Set to true the first time `get_credential` is invoked.
        called: AtomicBool,
    }
687
    #[async_trait::async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        /// Record the call and hand back empty (but well-formed) credentials.
        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }
701
    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        // Verify that a credential provider injected via
        // `ObjectStoreParams::aws_credentials` is the one actually consulted.
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // Not called yet
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // fails, but we don't care
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        // By now the injected provider must have been consulted
        assert!(mock_provider.called.load(Ordering::Relaxed));
    }
730
    #[test]
    fn test_s3_path_parsing() {
        // `extract_path` should strip scheme/bucket/query and keep only the
        // object key, including non-ASCII and special characters.
        let provider = AwsStoreProvider;

        let cases = [
            ("s3://bucket/path/to/file", "path/to/file"),
            // for non ASCII string tests
            ("s3://bucket/测试path/to/file", "测试path/to/file"),
            ("s3://bucket/path/&to/file", "path/&to/file"),
            ("s3://bucket/path/=to/file", "path/=to/file"),
            (
                "s3+ddb://bucket/path/to/file?ddbTableName=test",
                "path/to/file",
            ),
        ];

        for (uri, expected_path) in cases {
            let url = Url::parse(uri).unwrap();
            let path = provider.extract_path(&url).unwrap();
            let expected_path = Path::from(expected_path);
            assert_eq!(path, expected_path)
        }
    }
754
    #[test]
    fn test_is_s3_express() {
        // Cases are (uri, storage options, expected). The `--x-s3` bucket
        // suffix must win even over an explicit `s3_express = "false"`.
        let cases = [
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                false,
            ),
            ("s3://bucket/path/to/file", HashMap::from([]), false),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                true, // URL takes precedence
            ),
            ("s3://bucket--x-s3/path/to/file", HashMap::from([]), true),
        ];

        for (uri, storage_map, expected) in cases {
            let url = Url::parse(uri).unwrap();
            let storage_options = StorageOptions(storage_map);
            let is_s3_express = check_s3_express(&url, &storage_options);
            assert_eq!(is_s3_express, expected);
        }
    }
789
    #[tokio::test]
    async fn test_use_opendal_flag() {
        // With `use_opendal = "true"` the store should still build and report
        // the `s3` scheme (the OpenDAL backend is selected internally).
        let provider = AwsStoreProvider;
        let url = Url::parse("s3://test-bucket/path").unwrap();
        let params_with_flag = ObjectStoreParams {
            storage_options: Some(HashMap::from([
                ("use_opendal".to_string(), "true".to_string()),
                ("region".to_string(), "us-west-2".to_string()),
            ])),
            ..Default::default()
        };

        let store = provider
            .new_store(url.clone(), &params_with_flag)
            .await
            .unwrap();
        assert_eq!(store.scheme, "s3");
    }
808
    /// Storage-options provider that mints a fresh, numbered set of
    /// credentials on every call and counts its invocations.
    #[derive(Debug)]
    struct MockStorageOptionsProvider {
        // Number of times `fetch_storage_options` has been called.
        call_count: Arc<RwLock<usize>>,
        // If set, each returned option map carries an `expires_at_millis`
        // this far in the future; `None` means the credentials never expire.
        expires_in_millis: Option<u64>,
    }
814
    impl MockStorageOptionsProvider {
        /// Create a mock whose credentials expire `expires_in_millis` ms after
        /// each fetch (or never, when `None`).
        fn new(expires_in_millis: Option<u64>) -> Self {
            Self {
                call_count: Arc::new(RwLock::new(0)),
                expires_in_millis,
            }
        }

        /// How many times `fetch_storage_options` has been called so far.
        async fn get_call_count(&self) -> usize {
            *self.call_count.read().await
        }
    }
827
828    #[async_trait::async_trait]
829    impl StorageOptionsProvider for MockStorageOptionsProvider {
830        async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
831            let count = {
832                let mut c = self.call_count.write().await;
833                *c += 1;
834                *c
835            };
836
837            let mut options = HashMap::from([
838                ("aws_access_key_id".to_string(), format!("AKID_{}", count)),
839                (
840                    "aws_secret_access_key".to_string(),
841                    format!("SECRET_{}", count),
842                ),
843                ("aws_session_token".to_string(), format!("TOKEN_{}", count)),
844            ]);
845
846            if let Some(expires_in) = self.expires_in_millis {
847                let now_ms = SystemTime::now()
848                    .duration_since(UNIX_EPOCH)
849                    .unwrap()
850                    .as_millis() as u64;
851                let expires_at = now_ms + expires_in;
852                options.insert("expires_at_millis".to_string(), expires_at.to_string());
853            }
854
855            Ok(Some(options))
856        }
857
858        fn provider_id(&self) -> String {
859            let ptr = Arc::as_ptr(&self.call_count) as usize;
860            format!("MockStorageOptionsProvider {{ id: {} }}", ptr)
861        }
862    }
863
864    #[tokio::test]
865    async fn test_dynamic_credential_provider_with_initial_cache() {
866        MockClock::set_system_time(Duration::from_secs(100_000));
867
868        let now_ms = MockClock::system_time().as_millis() as u64;
869
870        // Create a mock provider that returns credentials expiring in 10 minutes
871        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
872            600_000, // Expires in 10 minutes
873        )));
874
875        // Create credential provider with initial cached credentials that expire in 10 minutes
876        let expires_at = now_ms + 600_000; // 10 minutes from now
877        let initial_cred = Arc::new(ObjectStoreAwsCredential {
878            key_id: "AKID_CACHED".to_string(),
879            secret_key: "SECRET_CACHED".to_string(),
880            token: Some("TOKEN_CACHED".to_string()),
881        });
882
883        let provider = DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
884            mock.clone(),
885            Duration::from_secs(300), // 5 minute refresh offset
886            initial_cred,
887            expires_at,
888        );
889
890        // First call should use cached credentials (not expired yet)
891        let cred = provider.get_credential().await.unwrap();
892        assert_eq!(cred.key_id, "AKID_CACHED");
893        assert_eq!(cred.secret_key, "SECRET_CACHED");
894        assert_eq!(cred.token, Some("TOKEN_CACHED".to_string()));
895
896        // Should not have called the provider yet
897        assert_eq!(mock.get_call_count().await, 0);
898    }
899
900    #[tokio::test]
901    async fn test_dynamic_credential_provider_with_expired_cache() {
902        MockClock::set_system_time(Duration::from_secs(100_000));
903
904        let now_ms = MockClock::system_time().as_millis() as u64;
905
906        // Create a mock provider that returns credentials expiring in 10 minutes
907        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
908            600_000, // Expires in 10 minutes
909        )));
910
911        // Create credential provider with initial cached credentials that expired 1 second ago
912        let expired_time = now_ms - 1_000; // 1 second ago
913        let initial_cred = Arc::new(ObjectStoreAwsCredential {
914            key_id: "AKID_EXPIRED".to_string(),
915            secret_key: "SECRET_EXPIRED".to_string(),
916            token: None,
917        });
918
919        let provider = DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
920            mock.clone(),
921            Duration::from_secs(300), // 5 minute refresh offset
922            initial_cred,
923            expired_time,
924        );
925
926        // First call should fetch new credentials because cached ones are expired
927        let cred = provider.get_credential().await.unwrap();
928        assert_eq!(cred.key_id, "AKID_1");
929        assert_eq!(cred.secret_key, "SECRET_1");
930        assert_eq!(cred.token, Some("TOKEN_1".to_string()));
931
932        // Should have called the provider once
933        assert_eq!(mock.get_call_count().await, 1);
934    }
935
936    #[tokio::test]
937    async fn test_dynamic_credential_provider_refresh_lead_time() {
938        MockClock::set_system_time(Duration::from_secs(100_000));
939
940        // Create a mock provider that returns credentials expiring in 4 minutes
941        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
942            240_000, // Expires in 4 minutes
943        )));
944
945        // Create credential provider with 5 minute refresh offset
946        // This means credentials should be refreshed when they have less than 5 minutes left
947        let provider = DynamicStorageOptionsCredentialProvider::new(
948            mock.clone(),
949            Duration::from_secs(300), // 5 minute refresh offset
950        );
951
952        // First call should fetch credentials from provider (no initial cache)
953        // Credentials expire in 4 minutes, which is less than our 5 minute refresh offset,
954        // so they should be considered "needs refresh" immediately
955        let cred = provider.get_credential().await.unwrap();
956        assert_eq!(cred.key_id, "AKID_1");
957        assert_eq!(mock.get_call_count().await, 1);
958
959        // Second call should trigger refresh because credentials expire in 4 minutes
960        // but our refresh lead time is 5 minutes (now + 5min > expires_at)
961        // The mock will return new credentials (AKID_2) with the same expiration
962        let cred = provider.get_credential().await.unwrap();
963        assert_eq!(cred.key_id, "AKID_2");
964        assert_eq!(mock.get_call_count().await, 2);
965    }
966
967    #[tokio::test]
968    async fn test_dynamic_credential_provider_no_initial_cache() {
969        MockClock::set_system_time(Duration::from_secs(100_000));
970
971        // Create a mock provider that returns credentials expiring in 10 minutes
972        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
973            600_000, // Expires in 10 minutes
974        )));
975
976        // Create credential provider without initial cache
977        let provider = DynamicStorageOptionsCredentialProvider::new(
978            mock.clone(),
979            Duration::from_secs(300), // 5 minute refresh offset
980        );
981
982        // First call should fetch from provider (call count = 1)
983        let cred = provider.get_credential().await.unwrap();
984        assert_eq!(cred.key_id, "AKID_1");
985        assert_eq!(cred.secret_key, "SECRET_1");
986        assert_eq!(cred.token, Some("TOKEN_1".to_string()));
987        assert_eq!(mock.get_call_count().await, 1);
988
989        // Second call should use cached credentials (not expired yet)
990        let cred = provider.get_credential().await.unwrap();
991        assert_eq!(cred.key_id, "AKID_1");
992        assert_eq!(mock.get_call_count().await, 1); // Still 1, didn't fetch again
993
994        // Advance time to 6 minutes - should trigger refresh (within 5 min refresh offset)
995        MockClock::set_system_time(Duration::from_secs(100_000 + 360));
996        let cred = provider.get_credential().await.unwrap();
997        assert_eq!(cred.key_id, "AKID_2");
998        assert_eq!(cred.secret_key, "SECRET_2");
999        assert_eq!(cred.token, Some("TOKEN_2".to_string()));
1000        assert_eq!(mock.get_call_count().await, 2);
1001
1002        // Advance time to 11 minutes total - should trigger another refresh
1003        MockClock::set_system_time(Duration::from_secs(100_000 + 660));
1004        let cred = provider.get_credential().await.unwrap();
1005        assert_eq!(cred.key_id, "AKID_3");
1006        assert_eq!(cred.secret_key, "SECRET_3");
1007        assert_eq!(mock.get_call_count().await, 3);
1008    }
1009
1010    #[tokio::test]
1011    async fn test_dynamic_credential_provider_with_initial_credential() {
1012        MockClock::set_system_time(Duration::from_secs(100_000));
1013
1014        let now_ms = MockClock::system_time().as_millis() as u64;
1015
1016        // Create a mock provider that returns credentials expiring in 10 minutes
1017        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
1018            600_000, // Expires in 10 minutes
1019        )));
1020
1021        // Create an initial credential with expiration in 10 minutes
1022        let expires_at = now_ms + 600_000; // 10 minutes from now
1023        let initial_cred = Arc::new(ObjectStoreAwsCredential {
1024            key_id: "AKID_INITIAL".to_string(),
1025            secret_key: "SECRET_INITIAL".to_string(),
1026            token: Some("TOKEN_INITIAL".to_string()),
1027        });
1028
1029        // Create credential provider with initial credential and expiration
1030        let provider = DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
1031            mock.clone(),
1032            Duration::from_secs(300), // 5 minute refresh offset
1033            initial_cred,
1034            expires_at,
1035        );
1036
1037        // First call should use the initial credential (not expired yet)
1038        let cred = provider.get_credential().await.unwrap();
1039        assert_eq!(cred.key_id, "AKID_INITIAL");
1040        assert_eq!(cred.secret_key, "SECRET_INITIAL");
1041        assert_eq!(cred.token, Some("TOKEN_INITIAL".to_string()));
1042
1043        // Should not have called the provider yet
1044        assert_eq!(mock.get_call_count().await, 0);
1045
1046        // Advance time to 6 minutes - this should trigger a refresh
1047        // (5 minute refresh offset means we refresh 5 minutes before expiration)
1048        MockClock::set_system_time(Duration::from_secs(100_000 + 360));
1049        let cred = provider.get_credential().await.unwrap();
1050        assert_eq!(cred.key_id, "AKID_1");
1051        assert_eq!(cred.secret_key, "SECRET_1");
1052        assert_eq!(cred.token, Some("TOKEN_1".to_string()));
1053
1054        // Should have called the provider once
1055        assert_eq!(mock.get_call_count().await, 1);
1056
1057        // Advance time to 11 minutes total - this should trigger another refresh
1058        MockClock::set_system_time(Duration::from_secs(100_000 + 660));
1059        let cred = provider.get_credential().await.unwrap();
1060        assert_eq!(cred.key_id, "AKID_2");
1061        assert_eq!(cred.secret_key, "SECRET_2");
1062        assert_eq!(cred.token, Some("TOKEN_2".to_string()));
1063
1064        // Should have called the provider twice
1065        assert_eq!(mock.get_call_count().await, 2);
1066
1067        // Advance time to 16 minutes total - this should trigger yet another refresh
1068        MockClock::set_system_time(Duration::from_secs(100_000 + 960));
1069        let cred = provider.get_credential().await.unwrap();
1070        assert_eq!(cred.key_id, "AKID_3");
1071        assert_eq!(cred.secret_key, "SECRET_3");
1072        assert_eq!(cred.token, Some("TOKEN_3".to_string()));
1073
1074        // Should have called the provider three times
1075        assert_eq!(mock.get_call_count().await, 3);
1076    }
1077
1078    #[tokio::test]
1079    async fn test_dynamic_credential_provider_concurrent_access() {
1080        // Create a mock provider with far future expiration
1081        let mock = Arc::new(MockStorageOptionsProvider::new(Some(9999999999999)));
1082
1083        let provider = Arc::new(DynamicStorageOptionsCredentialProvider::new(
1084            mock.clone(),
1085            Duration::from_secs(300),
1086        ));
1087
1088        // Spawn 10 concurrent tasks that all try to get credentials at the same time
1089        let mut handles = vec![];
1090        for i in 0..10 {
1091            let provider = provider.clone();
1092            let handle = tokio::spawn(async move {
1093                let cred = provider.get_credential().await.unwrap();
1094                // Verify we got the correct credentials (should all be AKID_1 from first fetch)
1095                assert_eq!(cred.key_id, "AKID_1");
1096                assert_eq!(cred.secret_key, "SECRET_1");
1097                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
1098                i // Return task number for verification
1099            });
1100            handles.push(handle);
1101        }
1102
1103        // Wait for all tasks to complete
1104        let results: Vec<_> = futures::future::join_all(handles)
1105            .await
1106            .into_iter()
1107            .map(|r| r.unwrap())
1108            .collect();
1109
1110        // Verify all 10 tasks completed successfully
1111        assert_eq!(results.len(), 10);
1112        for i in 0..10 {
1113            assert!(results.contains(&i));
1114        }
1115
1116        // The provider should have been called exactly once (first request triggers fetch,
1117        // subsequent requests use cache)
1118        let call_count = mock.get_call_count().await;
1119        assert_eq!(
1120            call_count, 1,
1121            "Provider should be called exactly once despite concurrent access"
1122        );
1123    }
1124
    /// Many tasks racing against an already-expired cached credential: all of
    /// them should end up with the freshly fetched generation, and the lock
    /// inside the provider should keep the number of upstream fetches small.
    /// The call-count bounds below are deliberately loose because the exact
    /// count is race-dependent (see the try_write note near the asserts).
    #[tokio::test]
    async fn test_dynamic_credential_provider_concurrent_refresh() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        // Create initial credentials that expired in the past (1000 seconds ago)
        let expires_at = now_ms - 1_000_000;

        let initial_cred = Arc::new(ObjectStoreAwsCredential {
            key_id: "AKID_OLD".to_string(),
            secret_key: "SECRET_OLD".to_string(),
            token: Some("TOKEN_OLD".to_string()),
        });

        // Mock will return credentials expiring in 1 hour
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            3_600_000, // Expires in 1 hour
        )));

        let provider = Arc::new(
            DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
                mock.clone(),
                Duration::from_secs(300),
                initial_cred,
                expires_at,
            ),
        );

        // Spawn 20 concurrent tasks that all try to get credentials at the same time
        // Since the initial credential is expired, they'll all try to refresh
        let mut handles = vec![];
        for i in 0..20 {
            let provider = provider.clone();
            let handle = tokio::spawn(async move {
                let cred = provider.get_credential().await.unwrap();
                // All should get the new credentials (AKID_1 from first fetch)
                // (none should ever observe the expired AKID_OLD seed)
                assert_eq!(cred.key_id, "AKID_1");
                assert_eq!(cred.secret_key, "SECRET_1");
                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
                i
            });
            handles.push(handle);
        }

        // Wait for all tasks to complete
        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        // Verify all 20 tasks completed successfully
        assert_eq!(results.len(), 20);

        // The provider should have been called at least once, but possibly more times
        // due to the try_write mechanism and race conditions
        // NOTE(review): the provider implementation is outside this chunk; the
        // try_write behavior asserted here is taken from this comment — confirm
        // against DynamicStorageOptionsCredentialProvider if the bounds flake.
        let call_count = mock.get_call_count().await;
        assert!(
            call_count >= 1,
            "Provider should be called at least once, was called {} times",
            call_count
        );

        // It shouldn't be called 20 times though - the lock should prevent most concurrent fetches
        assert!(
            call_count < 10,
            "Provider should not be called too many times due to lock contention, was called {} times",
            call_count
        );
    }
1196}