Skip to main content

lance_io/object_store/providers/
aws.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration};
5
6#[cfg(test)]
7use mock_instant::thread_local::{SystemTime, UNIX_EPOCH};
8
9#[cfg(not(test))]
10use std::time::{SystemTime, UNIX_EPOCH};
11
12use object_store::ObjectStore as OSObjectStore;
13use object_store_opendal::OpendalStore;
14use opendal::{Operator, services::S3};
15
16use aws_config::default_provider::credentials::DefaultCredentialsChain;
17use aws_credential_types::provider::ProvideCredentials;
18use object_store::{
19    ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
20    StaticCredentialProvider,
21    aws::{
22        AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
23        AwsCredentialProvider,
24    },
25};
26use tokio::sync::RwLock;
27use url::Url;
28
29use crate::object_store::{
30    DEFAULT_CLOUD_BLOCK_SIZE, DEFAULT_CLOUD_IO_PARALLELISM, DEFAULT_MAX_IOP_SIZE, ObjectStore,
31    ObjectStoreParams, ObjectStoreProvider, StorageOptions, StorageOptionsAccessor,
32    dynamic_credentials::{NamespaceCredentialsProvider, build_dynamic_credential_provider},
33    throttle::{AimdThrottleConfig, AimdThrottledStore},
34};
35use lance_core::error::{Error, Result};
36
/// [`ObjectStoreProvider`] implementation that builds S3-backed object stores.
///
/// The provider carries no state of its own; each store is configured from the URL and
/// the [`ObjectStoreParams`] handed to `new_store`.
#[derive(Default, Debug)]
pub struct AwsStoreProvider;
39
40impl AwsStoreProvider {
41    async fn build_amazon_s3_store(
42        &self,
43        base_path: &mut Url,
44        params: &ObjectStoreParams,
45        storage_options: &StorageOptions,
46        is_s3_express: bool,
47    ) -> Result<Arc<dyn OSObjectStore>> {
48        // Use a low retry count since the AIMD throttle layer handles
49        // throttle recovery with its own retry loop.
50        let retry_config = RetryConfig {
51            backoff: Default::default(),
52            max_retries: storage_options.client_max_retries(),
53            retry_timeout: Duration::from_secs(storage_options.client_retry_timeout()),
54        };
55
56        let mut s3_storage_options = storage_options.as_s3_options();
57        let region = resolve_s3_region(base_path, &s3_storage_options).await?;
58
59        // Get accessor from params
60        let accessor = params.get_accessor();
61
62        let (aws_creds, region) = build_aws_credential(
63            params.s3_credentials_refresh_offset,
64            params.aws_credentials.clone(),
65            Some(&s3_storage_options),
66            region,
67            accessor,
68        )
69        .await?;
70
71        // Set S3Express flag if detected
72        if is_s3_express {
73            s3_storage_options.insert(AmazonS3ConfigKey::S3Express, true.to_string());
74        }
75
76        // before creating the OSObjectStore we need to rewrite the url to drop ddb related parts
77        base_path.set_scheme("s3").unwrap();
78        base_path.set_query(None);
79
80        // we can't use parse_url_opts here because we need to manually set the credentials provider
81        let mut builder =
82            AmazonS3Builder::new().with_client_options(storage_options.client_options()?);
83        for (key, value) in s3_storage_options {
84            builder = builder.with_config(key, value);
85        }
86        builder = builder
87            .with_url(base_path.as_ref())
88            .with_credentials(aws_creds)
89            .with_retry(retry_config)
90            .with_region(region);
91
92        Ok(Arc::new(builder.build()?) as Arc<dyn OSObjectStore>)
93    }
94
95    async fn build_opendal_s3_store(
96        &self,
97        base_path: &Url,
98        storage_options: &StorageOptions,
99    ) -> Result<Arc<dyn OSObjectStore>> {
100        let bucket = base_path
101            .host_str()
102            .ok_or_else(|| Error::invalid_input("S3 URL must contain bucket name"))?
103            .to_string();
104
105        let prefix = base_path.path().trim_start_matches('/').to_string();
106
107        // Start with all storage options as the config map
108        // OpenDAL will handle environment variables through its default credentials chain
109        let mut config_map: HashMap<String, String> = storage_options.0.clone();
110
111        // Set required OpenDAL configuration
112        config_map.insert("bucket".to_string(), bucket);
113
114        if !prefix.is_empty() {
115            config_map.insert("root".to_string(), "/".to_string());
116        }
117
118        let operator = Operator::from_iter::<S3>(config_map)
119            .map_err(|e| Error::invalid_input(format!("Failed to create S3 operator: {:?}", e)))?
120            .finish();
121
122        Ok(Arc::new(OpendalStore::new(operator)) as Arc<dyn OSObjectStore>)
123    }
124}
125
126#[async_trait::async_trait]
127impl ObjectStoreProvider for AwsStoreProvider {
128    async fn new_store(
129        &self,
130        mut base_path: Url,
131        params: &ObjectStoreParams,
132    ) -> Result<ObjectStore> {
133        let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
134        let mut storage_options =
135            StorageOptions::new(params.storage_options().cloned().unwrap_or_default());
136        storage_options.with_env_s3();
137        let download_retry_count = storage_options.download_retry_count();
138
139        let use_opendal = storage_options
140            .0
141            .get("use_opendal")
142            .map(|v| v == "true")
143            .unwrap_or(false);
144
145        // Determine S3 Express and constant size upload parts before building the store
146        let is_s3_express = check_s3_express(&base_path, &storage_options);
147
148        let use_constant_size_upload_parts = storage_options
149            .0
150            .get("aws_endpoint")
151            .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
152            .unwrap_or(false);
153
154        let inner = if use_opendal {
155            // Use OpenDAL implementation
156            self.build_opendal_s3_store(&base_path, &storage_options)
157                .await?
158        } else {
159            // Use default Amazon S3 implementation
160            self.build_amazon_s3_store(&mut base_path, params, &storage_options, is_s3_express)
161                .await?
162        };
163        let throttle_config = AimdThrottleConfig::from_storage_options(params.storage_options())?;
164        let inner = if throttle_config.is_disabled() {
165            inner
166        } else if storage_options.client_max_retries() == 0 {
167            log::warn!(
168                "AIMD throttle disabled: the current implementation relies on the object store \
169                 client surfacing retry errors, which requires client_max_retries > 0. \
170                 No throttle or retry layer will be applied."
171            );
172            inner
173        } else {
174            Arc::new(AimdThrottledStore::new(inner, throttle_config)?) as Arc<dyn OSObjectStore>
175        };
176
177        Ok(ObjectStore {
178            inner,
179            scheme: String::from(base_path.scheme()),
180            block_size,
181            max_iop_size: *DEFAULT_MAX_IOP_SIZE,
182            use_constant_size_upload_parts,
183            list_is_lexically_ordered: !is_s3_express,
184            io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
185            download_retry_count,
186            io_tracker: Default::default(),
187            store_prefix: self
188                .calculate_object_store_prefix(&base_path, params.storage_options())?,
189        })
190    }
191}
192
193/// Check if the storage is S3 Express
194fn check_s3_express(url: &Url, storage_options: &StorageOptions) -> bool {
195    storage_options
196        .0
197        .get("s3_express")
198        .map(|v| v == "true")
199        .unwrap_or(false)
200        || url.authority().ends_with("--x-s3")
201}
202
203/// Figure out the S3 region of the bucket.
204///
205/// This resolves in order of precedence:
206/// 1. The region provided in the storage options
207/// 2. (If endpoint is not set), the region returned by the S3 API for the bucket
208///
209/// It can return None if no region is provided and the endpoint is set.
210async fn resolve_s3_region(
211    url: &Url,
212    storage_options: &HashMap<AmazonS3ConfigKey, String>,
213) -> Result<Option<String>> {
214    if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
215        Ok(Some(region.clone()))
216    } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
217        // If no endpoint is set, we can assume this is AWS S3 and the region
218        // can be resolved from the bucket.
219        let bucket = url.host_str().ok_or_else(|| {
220            Error::invalid_input(format!("Could not parse bucket from url: {}", url))
221        })?;
222
223        let mut client_options = ClientOptions::default();
224        for (key, value) in storage_options {
225            if let AmazonS3ConfigKey::Client(client_key) = key {
226                client_options = client_options.with_config(*client_key, value.clone());
227            }
228        }
229
230        let bucket_region =
231            object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
232        Ok(Some(bucket_region))
233    } else {
234        Ok(None)
235    }
236}
237
238/// Build AWS credentials
239///
240/// This resolves credentials from the following sources in order:
241/// 1. An explicit `storage_options_accessor` with a provider
242/// 2. An explicit `credentials` provider
243/// 3. Explicit credentials in storage_options (as in `aws_access_key_id`,
244///    `aws_secret_access_key`, `aws_session_token`)
245/// 4. The default credential provider chain from AWS SDK.
246///
247/// # Storage Options Accessor
248///
249/// When `storage_options_accessor` is provided and has a dynamic provider,
250/// credentials are fetched and cached by the accessor with automatic refresh
251/// before expiration.
252///
253/// `credentials_refresh_offset` is the amount of time before expiry to refresh credentials.
254pub async fn build_aws_credential(
255    credentials_refresh_offset: Duration,
256    credentials: Option<AwsCredentialProvider>,
257    storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
258    region: Option<String>,
259    storage_options_accessor: Option<Arc<StorageOptionsAccessor>>,
260) -> Result<(AwsCredentialProvider, String)> {
261    use aws_config::meta::region::RegionProviderChain;
262    const DEFAULT_REGION: &str = "us-west-2";
263
264    let region = if let Some(region) = region {
265        region
266    } else {
267        RegionProviderChain::default_provider()
268            .or_else(DEFAULT_REGION)
269            .region()
270            .await
271            .map(|r| r.as_ref().to_string())
272            .unwrap_or(DEFAULT_REGION.to_string())
273    };
274
275    let storage_options_credentials = storage_options.and_then(extract_static_s3_credentials);
276
277    // Explicit aws_credentials takes precedence over dynamic credentials.
278    if credentials.is_none()
279        && let Some(dynamic_creds) = build_dynamic_credential_provider::<ObjectStoreAwsCredential>(
280            storage_options_accessor.clone(),
281        )
282        .await?
283    {
284        return Ok((dynamic_creds, region));
285    }
286
287    if storage_options_accessor
288        .as_ref()
289        .is_some_and(|a| a.has_provider())
290    {
291        log::debug!(
292            "Storage options from provider do not contain explicit AWS credentials, \
293             falling back to default AWS credentials chain."
294        );
295    }
296
297    // Fall back to existing logic for static credentials
298    if let Some(creds) = credentials {
299        Ok((creds, region))
300    } else if let Some(creds) = storage_options_credentials {
301        Ok((Arc::new(creds), region))
302    } else {
303        let credentials_provider = DefaultCredentialsChain::builder().build().await;
304
305        Ok((
306            Arc::new(AwsCredentialAdapter::new(
307                Arc::new(credentials_provider),
308                credentials_refresh_offset,
309            )),
310            region,
311        ))
312    }
313}
314
315fn extract_static_s3_credentials(
316    options: &HashMap<AmazonS3ConfigKey, String>,
317) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
318    let key_id = options.get(&AmazonS3ConfigKey::AccessKeyId).cloned();
319    let secret_key = options.get(&AmazonS3ConfigKey::SecretAccessKey).cloned();
320    let token = options.get(&AmazonS3ConfigKey::Token).cloned();
321    match (key_id, secret_key, token) {
322        (Some(key_id), Some(secret_key), token) => {
323            Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
324                key_id,
325                secret_key,
326                token,
327            }))
328        }
329        _ => None,
330    }
331}
332
/// Adapt an AWS SDK credentials provider into an `object_store` credential provider.
///
/// Credentials obtained from `inner` are cached and handed out again until they are
/// within `credentials_refresh_offset` of their expiry.
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    /// The AWS SDK provider that fresh credentials are requested from.
    pub inner: Arc<dyn ProvideCredentials>,

    // Most recently fetched credentials, stored under a single fixed key.
    // RefCell can't be shared across threads, so the map sits behind an
    // `Arc<RwLock<..>>` instead.
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    // The amount of time before expiry to refresh credentials
    credentials_refresh_offset: Duration,
}
344
345impl AwsCredentialAdapter {
346    pub fn new(
347        provider: Arc<dyn ProvideCredentials>,
348        credentials_refresh_offset: Duration,
349    ) -> Self {
350        Self {
351            inner: provider,
352            cache: Arc::new(RwLock::new(HashMap::new())),
353            credentials_refresh_offset,
354        }
355    }
356}
357
358const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";
359
/// Re-express a `std::time::SystemTime` from the AWS SDK as this module's `SystemTime`,
/// which is swapped for a mockable clock in test builds.
///
/// # Panics
///
/// Panics if `time` lies before the Unix epoch.
fn to_system_time(time: std::time::SystemTime) -> SystemTime {
    UNIX_EPOCH
        + time
            .duration_since(std::time::UNIX_EPOCH)
            .expect("time should be after UNIX_EPOCH")
}
367
368#[async_trait::async_trait]
369impl CredentialProvider for AwsCredentialAdapter {
370    type Credential = ObjectStoreAwsCredential;
371
372    async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
373        let cached_creds = {
374            let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
375            let expired = cache_value
376                .clone()
377                .map(|cred| {
378                    cred.expiry()
379                        .map(|exp| {
380                            to_system_time(exp)
381                                .checked_sub(self.credentials_refresh_offset)
382                                .expect("this time should always be valid")
383                                < SystemTime::now()
384                        })
385                        // no expiry is never expire
386                        .unwrap_or(false)
387                })
388                .unwrap_or(true); // no cred is the same as expired;
389            if expired { None } else { cache_value.clone() }
390        };
391
392        if let Some(creds) = cached_creds {
393            Ok(Arc::new(Self::Credential {
394                key_id: creds.access_key_id().to_string(),
395                secret_key: creds.secret_access_key().to_string(),
396                token: creds.session_token().map(|s| s.to_string()),
397            }))
398        } else {
399            let refreshed_creds =
400                Arc::new(self.inner.provide_credentials().await.map_err(|e| {
401                    Error::internal(format!("Failed to get AWS credentials: {:?}", e))
402                })?);
403
404            self.cache
405                .write()
406                .await
407                .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
408
409            Ok(Arc::new(Self::Credential {
410                key_id: refreshed_creds.access_key_id().to_string(),
411                secret_key: refreshed_creds.secret_access_key().to_string(),
412                token: refreshed_creds.session_token().map(|s| s.to_string()),
413            }))
414        }
415    }
416}
417
418impl StorageOptions {
419    /// Add values from the environment to storage options
420    pub fn with_env_s3(&mut self) {
421        for (os_key, os_value) in std::env::vars_os() {
422            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str())
423                && let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase())
424                && !self.0.contains_key(config_key.as_ref())
425            {
426                self.0
427                    .insert(config_key.as_ref().to_string(), value.to_string());
428            }
429        }
430    }
431
432    /// Subset of options relevant for s3 storage
433    pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
434        self.0
435            .iter()
436            .filter_map(|(key, value)| {
437                let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
438                Some((s3_key, value.clone()))
439            })
440            .collect()
441    }
442}
443
444impl ObjectStoreParams {
445    /// Create a new instance of [`ObjectStoreParams`] based on the AWS credentials.
446    pub fn with_aws_credentials(
447        aws_credentials: Option<AwsCredentialProvider>,
448        region: Option<String>,
449    ) -> Self {
450        let storage_options_accessor = region.map(|region| {
451            let opts: HashMap<String, String> =
452                [("region".into(), region)].iter().cloned().collect();
453            Arc::new(StorageOptionsAccessor::with_static_options(opts))
454        });
455        Self {
456            aws_credentials,
457            storage_options_accessor,
458            ..Default::default()
459        }
460    }
461}
462
/// [`NamespaceCredentialsProvider`] specialised to produce AWS credentials
/// ([`ObjectStoreAwsCredential`]).
pub type DynamicStorageOptionsCredentialProvider =
    NamespaceCredentialsProvider<ObjectStoreAwsCredential>;
465
466#[cfg(test)]
467mod tests {
468    use crate::object_store::ObjectStoreRegistry;
469    use crate::object_store::StorageOptionsProvider;
470    use mock_instant::thread_local::MockClock;
471    use object_store::path::Path;
472    use std::sync::atomic::{AtomicBool, Ordering};
473
474    use super::*;
475
    /// Credential provider stub that records whether `get_credential` was ever invoked.
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        // Set to `true` on the first `get_credential` call.
        called: AtomicBool,
    }
480
481    #[async_trait::async_trait]
482    impl CredentialProvider for MockAwsCredentialsProvider {
483        type Credential = ObjectStoreAwsCredential;
484
485        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
486            self.called.store(true, Ordering::Relaxed);
487            Ok(Arc::new(Self::Credential {
488                key_id: "".to_string(),
489                secret_key: "".to_string(),
490                token: None,
491            }))
492        }
493    }
494
    /// A credential provider injected through `ObjectStoreParams::aws_credentials` must be
    /// the one consulted when the resulting store issues a request.
    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // Not called yet
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // The read itself fails (the bucket does not exist), but we don't care:
        // only the credential lookup it triggers matters here.
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        // The request above must have consulted the injected provider
        assert!(mock_provider.called.load(Ordering::Relaxed));
    }
523
524    #[test]
525    fn test_s3_path_parsing() {
526        let provider = AwsStoreProvider;
527
528        let cases = [
529            ("s3://bucket/path/to/file", "path/to/file"),
530            // for non ASCII string tests
531            ("s3://bucket/测试path/to/file", "测试path/to/file"),
532            ("s3://bucket/path/&to/file", "path/&to/file"),
533            ("s3://bucket/path/=to/file", "path/=to/file"),
534            (
535                "s3+ddb://bucket/path/to/file?ddbTableName=test",
536                "path/to/file",
537            ),
538        ];
539
540        for (uri, expected_path) in cases {
541            let url = Url::parse(uri).unwrap();
542            let path = provider.extract_path(&url).unwrap();
543            let expected_path = Path::from(expected_path);
544            assert_eq!(path, expected_path)
545        }
546    }
547
548    #[test]
549    fn test_is_s3_express() {
550        let cases = [
551            (
552                "s3://bucket/path/to/file",
553                HashMap::from([("s3_express".to_string(), "true".to_string())]),
554                true,
555            ),
556            (
557                "s3://bucket/path/to/file",
558                HashMap::from([("s3_express".to_string(), "false".to_string())]),
559                false,
560            ),
561            ("s3://bucket/path/to/file", HashMap::from([]), false),
562            (
563                "s3://bucket--x-s3/path/to/file",
564                HashMap::from([("s3_express".to_string(), "true".to_string())]),
565                true,
566            ),
567            (
568                "s3://bucket--x-s3/path/to/file",
569                HashMap::from([("s3_express".to_string(), "false".to_string())]),
570                true, // URL takes precedence
571            ),
572            ("s3://bucket--x-s3/path/to/file", HashMap::from([]), true),
573        ];
574
575        for (uri, storage_map, expected) in cases {
576            let url = Url::parse(uri).unwrap();
577            let storage_options = StorageOptions(storage_map);
578            let is_s3_express = check_s3_express(&url, &storage_options);
579            assert_eq!(is_s3_express, expected);
580        }
581    }
582
583    #[tokio::test]
584    async fn test_use_opendal_flag() {
585        use crate::object_store::StorageOptionsAccessor;
586        let provider = AwsStoreProvider;
587        let url = Url::parse("s3://test-bucket/path").unwrap();
588        let params_with_flag = ObjectStoreParams {
589            storage_options_accessor: Some(Arc::new(StorageOptionsAccessor::with_static_options(
590                HashMap::from([
591                    ("use_opendal".to_string(), "true".to_string()),
592                    ("region".to_string(), "us-west-2".to_string()),
593                ]),
594            ))),
595            ..Default::default()
596        };
597
598        let store = provider
599            .new_store(url.clone(), &params_with_flag)
600            .await
601            .unwrap();
602        assert_eq!(store.scheme, "s3");
603    }
604
    /// Storage options provider stub that serves numbered AWS credentials and counts
    /// how many times it has been asked for them.
    #[derive(Debug)]
    struct MockStorageOptionsProvider {
        // Number of `fetch_storage_options` calls made so far.
        call_count: Arc<RwLock<usize>>,
        // When set, each fetch reports an `expires_at_millis` this many milliseconds
        // after the current (mockable) time.
        expires_in_millis: Option<u64>,
    }
610
611    impl MockStorageOptionsProvider {
612        fn new(expires_in_millis: Option<u64>) -> Self {
613            Self {
614                call_count: Arc::new(RwLock::new(0)),
615                expires_in_millis,
616            }
617        }
618
619        async fn get_call_count(&self) -> usize {
620            *self.call_count.read().await
621        }
622    }
623
624    #[async_trait::async_trait]
625    impl StorageOptionsProvider for MockStorageOptionsProvider {
626        async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
627            let count = {
628                let mut c = self.call_count.write().await;
629                *c += 1;
630                *c
631            };
632
633            let mut options = HashMap::from([
634                ("aws_access_key_id".to_string(), format!("AKID_{}", count)),
635                (
636                    "aws_secret_access_key".to_string(),
637                    format!("SECRET_{}", count),
638                ),
639                ("aws_session_token".to_string(), format!("TOKEN_{}", count)),
640            ]);
641
642            if let Some(expires_in) = self.expires_in_millis {
643                let now_ms = SystemTime::now()
644                    .duration_since(UNIX_EPOCH)
645                    .unwrap()
646                    .as_millis() as u64;
647                let expires_at = now_ms + expires_in;
648                options.insert("expires_at_millis".to_string(), expires_at.to_string());
649            }
650
651            Ok(Some(options))
652        }
653
654        fn provider_id(&self) -> String {
655            let ptr = Arc::as_ptr(&self.call_count) as usize;
656            format!("MockStorageOptionsProvider {{ id: {} }}", ptr)
657        }
658    }
659
    /// Initial options that are still valid must be served as-is, without contacting
    /// the underlying provider.
    #[tokio::test]
    async fn test_dynamic_credential_provider_with_initial_cache() {
        // Pin the mock clock so the expiry arithmetic below is deterministic.
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        // Create a mock provider that returns credentials expiring in 10 minutes
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            600_000, // Expires in 10 minutes
        )));

        // Create initial options with cached credentials that expire in 10 minutes
        let expires_at = now_ms + 600_000; // 10 minutes from now
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_CACHED".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_CACHED".to_string(),
            ),
            ("aws_session_token".to_string(), "TOKEN_CACHED".to_string()),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minute refresh offset
        ]);

        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
            mock.clone(),
            initial_options,
        );

        // First call should use cached credentials (not expired yet)
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_CACHED");
        assert_eq!(cred.secret_key, "SECRET_CACHED");
        assert_eq!(cred.token, Some("TOKEN_CACHED".to_string()));

        // Should not have called the provider yet
        assert_eq!(mock.get_call_count().await, 0);
    }
698
    /// Initial options that have already expired must be discarded in favour of a fresh
    /// fetch from the underlying provider.
    #[tokio::test]
    async fn test_dynamic_credential_provider_with_expired_cache() {
        // Pin the mock clock so the expiry arithmetic below is deterministic.
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        // Create a mock provider that returns credentials expiring in 10 minutes
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            600_000, // Expires in 10 minutes
        )));

        // Create initial options with credentials that expired 1 second ago
        let expired_time = now_ms - 1_000; // 1 second ago
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_EXPIRED".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_EXPIRED".to_string(),
            ),
            ("expires_at_millis".to_string(), expired_time.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minute refresh offset
        ]);

        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
            mock.clone(),
            initial_options,
        );

        // First call should fetch new credentials because cached ones are expired
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));

        // Should have called the provider once
        assert_eq!(mock.get_call_count().await, 1);
    }
736
    /// Credentials whose remaining lifetime is shorter than the refresh offset must be
    /// re-fetched on every call.
    #[tokio::test]
    async fn test_dynamic_credential_provider_refresh_lead_time() {
        // Pin the mock clock so the expiry arithmetic below is deterministic.
        MockClock::set_system_time(Duration::from_secs(100_000));

        // Create a mock provider that returns credentials expiring in 30 seconds
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            30_000, // Expires in 30 seconds
        )));

        // Create credential provider with default 60 second refresh offset
        // This means credentials should be refreshed when they have less than 60 seconds left
        let provider = DynamicStorageOptionsCredentialProvider::from_provider(mock.clone());

        // First call should fetch credentials from provider (no initial cache)
        // Credentials expire in 30 seconds, which is less than our 60 second refresh offset,
        // so they should be considered "needs refresh" immediately
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(mock.get_call_count().await, 1);

        // Second call should trigger refresh because credentials expire in 30 seconds
        // but our refresh lead time is 60 seconds (now + 60sec > expires_at)
        // The mock will return new credentials (AKID_2) with the same expiration
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(mock.get_call_count().await, 2);
    }
764
    /// Without initial options the first call fetches, later calls reuse the cached
    /// credentials, and a refresh happens each time the mock clock moves inside the
    /// refresh window.
    #[tokio::test]
    async fn test_dynamic_credential_provider_no_initial_cache() {
        // Pin the mock clock; it is advanced explicitly further down.
        MockClock::set_system_time(Duration::from_secs(100_000));

        // Create a mock provider that returns credentials expiring in 2 minutes
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            120_000, // Expires in 2 minutes
        )));

        // Create credential provider without initial cache, using default 60 second refresh offset
        let provider = DynamicStorageOptionsCredentialProvider::from_provider(mock.clone());

        // First call should fetch from provider (call count = 1)
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));
        assert_eq!(mock.get_call_count().await, 1);

        // Second call should use cached credentials (not expired yet, still > 60 seconds remaining)
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(mock.get_call_count().await, 1); // Still 1, didn't fetch again

        // Advance time to 90 seconds - should trigger refresh (within 60 sec refresh offset)
        // At this point, credentials expire in 30 seconds (< 60 sec offset)
        MockClock::set_system_time(Duration::from_secs(100_000 + 90));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(cred.secret_key, "SECRET_2");
        assert_eq!(cred.token, Some("TOKEN_2".to_string()));
        assert_eq!(mock.get_call_count().await, 2);

        // Advance time to 210 seconds total (90 + 120) - should trigger another refresh
        MockClock::set_system_time(Duration::from_secs(100_000 + 210));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_3");
        assert_eq!(cred.secret_key, "SECRET_3");
        assert_eq!(mock.get_call_count().await, 3);
    }
805
806    #[tokio::test]
807    async fn test_dynamic_credential_provider_with_initial_options() {
808        MockClock::set_system_time(Duration::from_secs(100_000));
809
810        let now_ms = MockClock::system_time().as_millis() as u64;
811
812        // Create a mock provider that returns credentials expiring in 10 minutes
813        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
814            600_000, // Expires in 10 minutes
815        )));
816
817        // Create initial options with expiration in 10 minutes
818        let expires_at = now_ms + 600_000; // 10 minutes from now
819        let initial_options = HashMap::from([
820            ("aws_access_key_id".to_string(), "AKID_INITIAL".to_string()),
821            (
822                "aws_secret_access_key".to_string(),
823                "SECRET_INITIAL".to_string(),
824            ),
825            ("aws_session_token".to_string(), "TOKEN_INITIAL".to_string()),
826            ("expires_at_millis".to_string(), expires_at.to_string()),
827            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minute refresh offset
828        ]);
829
830        // Create credential provider with initial options
831        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
832            mock.clone(),
833            initial_options,
834        );
835
836        // First call should use the initial credential (not expired yet)
837        let cred = provider.get_credential().await.unwrap();
838        assert_eq!(cred.key_id, "AKID_INITIAL");
839        assert_eq!(cred.secret_key, "SECRET_INITIAL");
840        assert_eq!(cred.token, Some("TOKEN_INITIAL".to_string()));
841
842        // Should not have called the provider yet
843        assert_eq!(mock.get_call_count().await, 0);
844
845        // Advance time to 6 minutes - this should trigger a refresh
846        // (5 minute refresh offset means we refresh 5 minutes before expiration)
847        MockClock::set_system_time(Duration::from_secs(100_000 + 360));
848        let cred = provider.get_credential().await.unwrap();
849        assert_eq!(cred.key_id, "AKID_1");
850        assert_eq!(cred.secret_key, "SECRET_1");
851        assert_eq!(cred.token, Some("TOKEN_1".to_string()));
852
853        // Should have called the provider once
854        assert_eq!(mock.get_call_count().await, 1);
855
856        // Advance time to 11 minutes total - this should trigger another refresh
857        MockClock::set_system_time(Duration::from_secs(100_000 + 660));
858        let cred = provider.get_credential().await.unwrap();
859        assert_eq!(cred.key_id, "AKID_2");
860        assert_eq!(cred.secret_key, "SECRET_2");
861        assert_eq!(cred.token, Some("TOKEN_2".to_string()));
862
863        // Should have called the provider twice
864        assert_eq!(mock.get_call_count().await, 2);
865
866        // Advance time to 16 minutes total - this should trigger yet another refresh
867        MockClock::set_system_time(Duration::from_secs(100_000 + 960));
868        let cred = provider.get_credential().await.unwrap();
869        assert_eq!(cred.key_id, "AKID_3");
870        assert_eq!(cred.secret_key, "SECRET_3");
871        assert_eq!(cred.token, Some("TOKEN_3".to_string()));
872
873        // Should have called the provider three times
874        assert_eq!(mock.get_call_count().await, 3);
875    }
876
877    #[tokio::test]
878    async fn test_dynamic_credential_provider_concurrent_access() {
879        // Create a mock provider with far future expiration
880        let mock = Arc::new(MockStorageOptionsProvider::new(Some(9999999999999)));
881
882        let provider = Arc::new(DynamicStorageOptionsCredentialProvider::from_provider(
883            mock.clone(),
884        ));
885
886        // Spawn 10 concurrent tasks that all try to get credentials at the same time
887        let mut handles = vec![];
888        for i in 0..10 {
889            let provider = provider.clone();
890            let handle = tokio::spawn(async move {
891                let cred = provider.get_credential().await.unwrap();
892                // Verify we got the correct credentials (should all be AKID_1 from first fetch)
893                assert_eq!(cred.key_id, "AKID_1");
894                assert_eq!(cred.secret_key, "SECRET_1");
895                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
896                i // Return task number for verification
897            });
898            handles.push(handle);
899        }
900
901        // Wait for all tasks to complete
902        let results: Vec<_> = futures::future::join_all(handles)
903            .await
904            .into_iter()
905            .map(|r| r.unwrap())
906            .collect();
907
908        // Verify all 10 tasks completed successfully
909        assert_eq!(results.len(), 10);
910        for i in 0..10 {
911            assert!(results.contains(&i));
912        }
913
914        // The provider should have been called exactly once (first request triggers fetch,
915        // subsequent requests use cache)
916        let call_count = mock.get_call_count().await;
917        assert_eq!(
918            call_count, 1,
919            "Provider should be called exactly once despite concurrent access"
920        );
921    }
922
923    #[tokio::test]
924    async fn test_dynamic_credential_provider_concurrent_refresh() {
925        MockClock::set_system_time(Duration::from_secs(100_000));
926
927        let now_ms = MockClock::system_time().as_millis() as u64;
928
929        // Create initial options with credentials that expired in the past (1000 seconds ago)
930        let expires_at = now_ms - 1_000_000;
931        let initial_options = HashMap::from([
932            ("aws_access_key_id".to_string(), "AKID_OLD".to_string()),
933            (
934                "aws_secret_access_key".to_string(),
935                "SECRET_OLD".to_string(),
936            ),
937            ("aws_session_token".to_string(), "TOKEN_OLD".to_string()),
938            ("expires_at_millis".to_string(), expires_at.to_string()),
939            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minute refresh offset
940        ]);
941
942        // Mock will return credentials expiring in 1 hour
943        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
944            3_600_000, // Expires in 1 hour
945        )));
946
947        let provider = Arc::new(
948            DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
949                mock.clone(),
950                initial_options,
951            ),
952        );
953
954        // Spawn 20 concurrent tasks that all try to get credentials at the same time
955        // Since the initial credential is expired, they'll all try to refresh
956        let mut handles = vec![];
957        for i in 0..20 {
958            let provider = provider.clone();
959            let handle = tokio::spawn(async move {
960                let cred = provider.get_credential().await.unwrap();
961                // All should get the new credentials (AKID_1 from first fetch)
962                assert_eq!(cred.key_id, "AKID_1");
963                assert_eq!(cred.secret_key, "SECRET_1");
964                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
965                i
966            });
967            handles.push(handle);
968        }
969
970        // Wait for all tasks to complete
971        let results: Vec<_> = futures::future::join_all(handles)
972            .await
973            .into_iter()
974            .map(|r| r.unwrap())
975            .collect();
976
977        // Verify all 20 tasks completed successfully
978        assert_eq!(results.len(), 20);
979
980        // The provider should have been called at least once, but possibly more times
981        // due to the try_write mechanism and race conditions
982        let call_count = mock.get_call_count().await;
983        assert!(
984            call_count >= 1,
985            "Provider should be called at least once, was called {} times",
986            call_count
987        );
988
989        // It shouldn't be called 20 times though - the lock should prevent most concurrent fetches
990        assert!(
991            call_count < 10,
992            "Provider should not be called too many times due to lock contention, was called {} times",
993            call_count
994        );
995    }
996
997    #[tokio::test]
998    async fn test_explicit_aws_credentials_takes_precedence_over_accessor() {
999        // Create a mock storage options provider that should NOT be called
1000        let mock_storage_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));
1001
1002        // Create an accessor with the mock provider
1003        let accessor = Arc::new(StorageOptionsAccessor::with_provider(
1004            mock_storage_provider.clone(),
1005        ));
1006
1007        // Create an explicit AWS credentials provider
1008        let explicit_cred_provider = Arc::new(MockAwsCredentialsProvider::default());
1009
1010        // Build credentials with both aws_credentials AND accessor
1011        // The explicit aws_credentials should take precedence
1012        let (result, _region) = build_aws_credential(
1013            Duration::from_secs(300),
1014            Some(explicit_cred_provider.clone() as AwsCredentialProvider),
1015            None, // no storage_options
1016            Some("us-west-2".to_string()),
1017            Some(accessor),
1018        )
1019        .await
1020        .unwrap();
1021
1022        // Get credential from the result
1023        let cred = result.get_credential().await.unwrap();
1024
1025        // The explicit provider should have been called (it returns empty strings)
1026        assert!(explicit_cred_provider.called.load(Ordering::Relaxed));
1027
1028        // The storage options provider should NOT have been called
1029        assert_eq!(
1030            mock_storage_provider.get_call_count().await,
1031            0,
1032            "Storage options provider should not be called when explicit aws_credentials is provided"
1033        );
1034
1035        // Verify we got credentials from the explicit provider (empty strings)
1036        assert_eq!(cred.key_id, "");
1037        assert_eq!(cred.secret_key, "");
1038    }
1039
1040    #[tokio::test]
1041    async fn test_accessor_used_when_no_explicit_aws_credentials() {
1042        MockClock::set_system_time(Duration::from_secs(100_000));
1043
1044        let now_ms = MockClock::system_time().as_millis() as u64;
1045
1046        // Create a mock storage options provider
1047        let mock_storage_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));
1048
1049        // Create initial options
1050        let expires_at = now_ms + 600_000; // 10 minutes from now
1051        let initial_options = HashMap::from([
1052            (
1053                "aws_access_key_id".to_string(),
1054                "AKID_FROM_ACCESSOR".to_string(),
1055            ),
1056            (
1057                "aws_secret_access_key".to_string(),
1058                "SECRET_FROM_ACCESSOR".to_string(),
1059            ),
1060            (
1061                "aws_session_token".to_string(),
1062                "TOKEN_FROM_ACCESSOR".to_string(),
1063            ),
1064            ("expires_at_millis".to_string(), expires_at.to_string()),
1065            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minute refresh offset
1066        ]);
1067
1068        // Create an accessor with initial options and provider
1069        let accessor = Arc::new(StorageOptionsAccessor::with_initial_and_provider(
1070            initial_options,
1071            mock_storage_provider.clone(),
1072        ));
1073
1074        // Build credentials with accessor but NO explicit aws_credentials
1075        let (result, _region) = build_aws_credential(
1076            Duration::from_secs(300),
1077            None, // no explicit aws_credentials
1078            None, // no storage_options
1079            Some("us-west-2".to_string()),
1080            Some(accessor),
1081        )
1082        .await
1083        .unwrap();
1084
1085        // Get credential - should use the initial accessor credentials
1086        let cred = result.get_credential().await.unwrap();
1087        assert_eq!(cred.key_id, "AKID_FROM_ACCESSOR");
1088        assert_eq!(cred.secret_key, "SECRET_FROM_ACCESSOR");
1089
1090        // Storage options provider should NOT have been called yet (using cached initial creds)
1091        assert_eq!(mock_storage_provider.get_call_count().await, 0);
1092
1093        // Advance time to trigger refresh (past the 5 minute refresh offset)
1094        MockClock::set_system_time(Duration::from_secs(100_000 + 360));
1095
1096        // Get credential again - should now fetch from provider
1097        let cred = result.get_credential().await.unwrap();
1098        assert_eq!(cred.key_id, "AKID_1");
1099        assert_eq!(cred.secret_key, "SECRET_1");
1100
1101        // Storage options provider should have been called once
1102        assert_eq!(mock_storage_provider.get_call_count().await, 1);
1103    }
1104}