Skip to main content

lance_io/object_store/providers/
aws.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration};
5
6#[cfg(test)]
7use mock_instant::thread_local::{SystemTime, UNIX_EPOCH};
8
9#[cfg(not(test))]
10use std::time::{SystemTime, UNIX_EPOCH};
11
12use object_store::ObjectStore as OSObjectStore;
13use object_store_opendal::OpendalStore;
14use opendal::{Operator, services::S3};
15
16use aws_config::default_provider::credentials::DefaultCredentialsChain;
17use aws_credential_types::provider::ProvideCredentials;
18use object_store::{
19    ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
20    StaticCredentialProvider,
21    aws::{
22        AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
23        AwsCredentialProvider,
24    },
25};
26use tokio::sync::RwLock;
27use url::Url;
28
29use crate::object_store::{
30    DEFAULT_CLOUD_BLOCK_SIZE, DEFAULT_CLOUD_IO_PARALLELISM, DEFAULT_MAX_IOP_SIZE, ObjectStore,
31    ObjectStoreParams, ObjectStoreProvider, StorageOptions, StorageOptionsAccessor,
32    dynamic_credentials::{NamespaceCredentialsProvider, build_dynamic_credential_provider},
33    throttle::{AimdThrottleConfig, AimdThrottledStore},
34};
35use lance_core::error::{Error, Result};
36
/// Stateless [`ObjectStoreProvider`] that builds S3-backed object stores.
#[derive(Debug, Default)]
pub struct AwsStoreProvider;
39
40impl AwsStoreProvider {
41    async fn build_amazon_s3_store(
42        &self,
43        base_path: &mut Url,
44        params: &ObjectStoreParams,
45        storage_options: &StorageOptions,
46        is_s3_express: bool,
47    ) -> Result<Arc<dyn OSObjectStore>> {
48        // Use a low retry count since the AIMD throttle layer handles
49        // throttle recovery with its own retry loop.
50        let retry_config = RetryConfig {
51            backoff: Default::default(),
52            max_retries: storage_options.client_max_retries(),
53            retry_timeout: Duration::from_secs(storage_options.client_retry_timeout()),
54        };
55
56        let mut s3_storage_options = storage_options.as_s3_options();
57        let region = resolve_s3_region(base_path, &s3_storage_options).await?;
58
59        // Get accessor from params
60        let accessor = params.get_accessor();
61
62        let (aws_creds, region) = build_aws_credential(
63            params.s3_credentials_refresh_offset,
64            params.aws_credentials.clone(),
65            Some(&s3_storage_options),
66            region,
67            accessor,
68        )
69        .await?;
70
71        // Set S3Express flag if detected
72        if is_s3_express {
73            s3_storage_options.insert(AmazonS3ConfigKey::S3Express, true.to_string());
74        }
75
76        // before creating the OSObjectStore we need to rewrite the url to drop ddb related parts
77        base_path.set_scheme("s3").unwrap();
78        base_path.set_query(None);
79
80        // we can't use parse_url_opts here because we need to manually set the credentials provider
81        let mut builder =
82            AmazonS3Builder::new().with_client_options(storage_options.client_options()?);
83        for (key, value) in s3_storage_options {
84            builder = builder.with_config(key, value);
85        }
86        builder = builder
87            .with_url(base_path.as_ref())
88            .with_credentials(aws_creds)
89            .with_retry(retry_config)
90            .with_region(region);
91
92        Ok(Arc::new(builder.build()?) as Arc<dyn OSObjectStore>)
93    }
94
95    async fn build_opendal_s3_store(
96        &self,
97        base_path: &Url,
98        storage_options: &StorageOptions,
99    ) -> Result<Arc<dyn OSObjectStore>> {
100        let bucket = base_path
101            .host_str()
102            .ok_or_else(|| Error::invalid_input("S3 URL must contain bucket name"))?
103            .to_string();
104
105        let prefix = base_path.path().trim_start_matches('/').to_string();
106
107        // Start with all storage options as the config map
108        // OpenDAL will handle environment variables through its default credentials chain
109        let mut config_map: HashMap<String, String> = storage_options.0.clone();
110
111        // Set required OpenDAL configuration
112        config_map.insert("bucket".to_string(), bucket);
113
114        if !prefix.is_empty() {
115            config_map.insert("root".to_string(), "/".to_string());
116        }
117
118        let operator = Operator::from_iter::<S3>(config_map)
119            .map_err(|e| Error::invalid_input(format!("Failed to create S3 operator: {:?}", e)))?
120            .finish();
121
122        Ok(Arc::new(OpendalStore::new(operator)) as Arc<dyn OSObjectStore>)
123    }
124}
125
126#[async_trait::async_trait]
127impl ObjectStoreProvider for AwsStoreProvider {
128    async fn new_store(
129        &self,
130        mut base_path: Url,
131        params: &ObjectStoreParams,
132    ) -> Result<ObjectStore> {
133        let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
134        let mut storage_options =
135            StorageOptions::new(params.storage_options().cloned().unwrap_or_default());
136        storage_options.with_env_s3();
137        let download_retry_count = storage_options.download_retry_count();
138
139        let use_opendal = storage_options
140            .0
141            .get("use_opendal")
142            .map(|v| v == "true")
143            .unwrap_or(false);
144
145        // Determine S3 Express and constant size upload parts before building the store
146        let is_s3_express = check_s3_express(&base_path, &storage_options);
147
148        let use_constant_size_upload_parts = storage_options
149            .0
150            .get("aws_endpoint")
151            .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
152            .unwrap_or(false);
153
154        let inner = if use_opendal {
155            // Use OpenDAL implementation
156            self.build_opendal_s3_store(&base_path, &storage_options)
157                .await?
158        } else {
159            // Use default Amazon S3 implementation
160            self.build_amazon_s3_store(&mut base_path, params, &storage_options, is_s3_express)
161                .await?
162        };
163        let throttle_config = AimdThrottleConfig::from_storage_options(params.storage_options())?;
164        let inner = if throttle_config.is_disabled() {
165            inner
166        } else {
167            Arc::new(AimdThrottledStore::new(inner, throttle_config)?) as Arc<dyn OSObjectStore>
168        };
169
170        Ok(ObjectStore {
171            inner,
172            scheme: String::from(base_path.scheme()),
173            block_size,
174            max_iop_size: *DEFAULT_MAX_IOP_SIZE,
175            use_constant_size_upload_parts,
176            list_is_lexically_ordered: !is_s3_express,
177            io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
178            download_retry_count,
179            io_tracker: Default::default(),
180            store_prefix: self
181                .calculate_object_store_prefix(&base_path, params.storage_options())?,
182        })
183    }
184}
185
186/// Check if the storage is S3 Express
187fn check_s3_express(url: &Url, storage_options: &StorageOptions) -> bool {
188    storage_options
189        .0
190        .get("s3_express")
191        .map(|v| v == "true")
192        .unwrap_or(false)
193        || url.authority().ends_with("--x-s3")
194}
195
196/// Figure out the S3 region of the bucket.
197///
198/// This resolves in order of precedence:
199/// 1. The region provided in the storage options
200/// 2. (If endpoint is not set), the region returned by the S3 API for the bucket
201///
202/// It can return None if no region is provided and the endpoint is set.
203async fn resolve_s3_region(
204    url: &Url,
205    storage_options: &HashMap<AmazonS3ConfigKey, String>,
206) -> Result<Option<String>> {
207    if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
208        Ok(Some(region.clone()))
209    } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
210        // If no endpoint is set, we can assume this is AWS S3 and the region
211        // can be resolved from the bucket.
212        let bucket = url.host_str().ok_or_else(|| {
213            Error::invalid_input(format!("Could not parse bucket from url: {}", url))
214        })?;
215
216        let mut client_options = ClientOptions::default();
217        for (key, value) in storage_options {
218            if let AmazonS3ConfigKey::Client(client_key) = key {
219                client_options = client_options.with_config(*client_key, value.clone());
220            }
221        }
222
223        let bucket_region =
224            object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
225        Ok(Some(bucket_region))
226    } else {
227        Ok(None)
228    }
229}
230
231/// Build AWS credentials
232///
233/// This resolves credentials from the following sources in order:
234/// 1. An explicit `storage_options_accessor` with a provider
235/// 2. An explicit `credentials` provider
236/// 3. Explicit credentials in storage_options (as in `aws_access_key_id`,
237///    `aws_secret_access_key`, `aws_session_token`)
238/// 4. The default credential provider chain from AWS SDK.
239///
240/// # Storage Options Accessor
241///
242/// When `storage_options_accessor` is provided and has a dynamic provider,
243/// credentials are fetched and cached by the accessor with automatic refresh
244/// before expiration.
245///
246/// `credentials_refresh_offset` is the amount of time before expiry to refresh credentials.
247pub async fn build_aws_credential(
248    credentials_refresh_offset: Duration,
249    credentials: Option<AwsCredentialProvider>,
250    storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
251    region: Option<String>,
252    storage_options_accessor: Option<Arc<StorageOptionsAccessor>>,
253) -> Result<(AwsCredentialProvider, String)> {
254    use aws_config::meta::region::RegionProviderChain;
255    const DEFAULT_REGION: &str = "us-west-2";
256
257    let region = if let Some(region) = region {
258        region
259    } else {
260        RegionProviderChain::default_provider()
261            .or_else(DEFAULT_REGION)
262            .region()
263            .await
264            .map(|r| r.as_ref().to_string())
265            .unwrap_or(DEFAULT_REGION.to_string())
266    };
267
268    let storage_options_credentials = storage_options.and_then(extract_static_s3_credentials);
269
270    // Explicit aws_credentials takes precedence over dynamic credentials.
271    if credentials.is_none()
272        && let Some(dynamic_creds) = build_dynamic_credential_provider::<ObjectStoreAwsCredential>(
273            storage_options_accessor.clone(),
274        )
275        .await?
276    {
277        return Ok((dynamic_creds, region));
278    }
279
280    if storage_options_accessor
281        .as_ref()
282        .is_some_and(|a| a.has_provider())
283    {
284        log::debug!(
285            "Storage options from provider do not contain explicit AWS credentials, \
286             falling back to default AWS credentials chain."
287        );
288    }
289
290    // Fall back to existing logic for static credentials
291    if let Some(creds) = credentials {
292        Ok((creds, region))
293    } else if let Some(creds) = storage_options_credentials {
294        Ok((Arc::new(creds), region))
295    } else {
296        let credentials_provider = DefaultCredentialsChain::builder().build().await;
297
298        Ok((
299            Arc::new(AwsCredentialAdapter::new(
300                Arc::new(credentials_provider),
301                credentials_refresh_offset,
302            )),
303            region,
304        ))
305    }
306}
307
308fn extract_static_s3_credentials(
309    options: &HashMap<AmazonS3ConfigKey, String>,
310) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
311    let key_id = options.get(&AmazonS3ConfigKey::AccessKeyId).cloned();
312    let secret_key = options.get(&AmazonS3ConfigKey::SecretAccessKey).cloned();
313    let token = options.get(&AmazonS3ConfigKey::Token).cloned();
314    match (key_id, secret_key, token) {
315        (Some(key_id), Some(secret_key), token) => {
316            Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
317                key_id,
318                secret_key,
319                token,
320            }))
321        }
322        _ => None,
323    }
324}
325
326/// Adapt an AWS SDK cred into object_store credentials
327#[derive(Debug)]
328pub struct AwsCredentialAdapter {
329    pub inner: Arc<dyn ProvideCredentials>,
330
331    // RefCell can't be shared across threads, so we use HashMap
332    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,
333
334    // The amount of time before expiry to refresh credentials
335    credentials_refresh_offset: Duration,
336}
337
338impl AwsCredentialAdapter {
339    pub fn new(
340        provider: Arc<dyn ProvideCredentials>,
341        credentials_refresh_offset: Duration,
342    ) -> Self {
343        Self {
344            inner: provider,
345            cache: Arc::new(RwLock::new(HashMap::new())),
346            credentials_refresh_offset,
347        }
348    }
349}
350
/// Map key under which `AwsCredentialAdapter` caches its credentials.
const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";
352
/// Re-express a [`std::time::SystemTime`] (as returned by the AWS SDK) as this module's
/// `SystemTime`, which is the mockable clock under `cfg(test)`.
///
/// # Panics
/// Panics if `time` is earlier than the Unix epoch.
fn to_system_time(time: std::time::SystemTime) -> SystemTime {
    let since_epoch = time
        .duration_since(std::time::UNIX_EPOCH)
        .expect("time should be after UNIX_EPOCH");
    UNIX_EPOCH + since_epoch
}
360
361#[async_trait::async_trait]
362impl CredentialProvider for AwsCredentialAdapter {
363    type Credential = ObjectStoreAwsCredential;
364
365    async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
366        let cached_creds = {
367            let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
368            let expired = cache_value
369                .clone()
370                .map(|cred| {
371                    cred.expiry()
372                        .map(|exp| {
373                            to_system_time(exp)
374                                .checked_sub(self.credentials_refresh_offset)
375                                .expect("this time should always be valid")
376                                < SystemTime::now()
377                        })
378                        // no expiry is never expire
379                        .unwrap_or(false)
380                })
381                .unwrap_or(true); // no cred is the same as expired;
382            if expired { None } else { cache_value.clone() }
383        };
384
385        if let Some(creds) = cached_creds {
386            Ok(Arc::new(Self::Credential {
387                key_id: creds.access_key_id().to_string(),
388                secret_key: creds.secret_access_key().to_string(),
389                token: creds.session_token().map(|s| s.to_string()),
390            }))
391        } else {
392            let refreshed_creds =
393                Arc::new(self.inner.provide_credentials().await.map_err(|e| {
394                    Error::internal(format!("Failed to get AWS credentials: {:?}", e))
395                })?);
396
397            self.cache
398                .write()
399                .await
400                .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
401
402            Ok(Arc::new(Self::Credential {
403                key_id: refreshed_creds.access_key_id().to_string(),
404                secret_key: refreshed_creds.secret_access_key().to_string(),
405                token: refreshed_creds.session_token().map(|s| s.to_string()),
406            }))
407        }
408    }
409}
410
411impl StorageOptions {
412    /// Add values from the environment to storage options
413    pub fn with_env_s3(&mut self) {
414        for (os_key, os_value) in std::env::vars_os() {
415            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str())
416                && let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase())
417                && !self.0.contains_key(config_key.as_ref())
418            {
419                self.0
420                    .insert(config_key.as_ref().to_string(), value.to_string());
421            }
422        }
423    }
424
425    /// Subset of options relevant for s3 storage
426    pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
427        self.0
428            .iter()
429            .filter_map(|(key, value)| {
430                let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
431                Some((s3_key, value.clone()))
432            })
433            .collect()
434    }
435}
436
437impl ObjectStoreParams {
438    /// Create a new instance of [`ObjectStoreParams`] based on the AWS credentials.
439    pub fn with_aws_credentials(
440        aws_credentials: Option<AwsCredentialProvider>,
441        region: Option<String>,
442    ) -> Self {
443        let storage_options_accessor = region.map(|region| {
444            let opts: HashMap<String, String> =
445                [("region".into(), region)].iter().cloned().collect();
446            Arc::new(StorageOptionsAccessor::with_static_options(opts))
447        });
448        Self {
449            aws_credentials,
450            storage_options_accessor,
451            ..Default::default()
452        }
453    }
454}
455
456pub type DynamicStorageOptionsCredentialProvider =
457    NamespaceCredentialsProvider<ObjectStoreAwsCredential>;
458
459#[cfg(test)]
460mod tests {
461    use crate::object_store::ObjectStoreRegistry;
462    use crate::object_store::StorageOptionsProvider;
463    use mock_instant::thread_local::MockClock;
464    use object_store::path::Path;
465    use std::sync::atomic::{AtomicBool, Ordering};
466
467    use super::*;
468
/// Credential provider stub that records whether credentials were requested.
#[derive(Debug, Default)]
struct MockAwsCredentialsProvider {
    /// Set to `true` whenever `get_credential` is called.
    called: AtomicBool,
}
473
474    #[async_trait::async_trait]
475    impl CredentialProvider for MockAwsCredentialsProvider {
476        type Credential = ObjectStoreAwsCredential;
477
478        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
479            self.called.store(true, Ordering::Relaxed);
480            Ok(Arc::new(Self::Credential {
481                key_id: "".to_string(),
482                secret_key: "".to_string(),
483                token: None,
484            }))
485        }
486    }
487
488    #[tokio::test]
489    async fn test_injected_aws_creds_option_is_used() {
490        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
491        let registry = Arc::new(ObjectStoreRegistry::default());
492
493        let params = ObjectStoreParams {
494            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
495            ..ObjectStoreParams::default()
496        };
497
498        // Not called yet
499        assert!(!mock_provider.called.load(Ordering::Relaxed));
500
501        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
502            .await
503            .unwrap();
504
505        // fails, but we don't care
506        let _ = store
507            .open(&Path::parse("/").unwrap())
508            .await
509            .unwrap()
510            .get_range(0..1)
511            .await;
512
513        // Not called yet
514        assert!(mock_provider.called.load(Ordering::Relaxed));
515    }
516
#[test]
fn test_s3_path_parsing() {
    let provider = AwsStoreProvider;

    // (uri, expected object path). Covers non-ASCII and URL-special characters and an
    // `s3+ddb` URI whose query string must not leak into the path.
    let cases = [
        ("s3://bucket/path/to/file", "path/to/file"),
        ("s3://bucket/测试path/to/file", "测试path/to/file"),
        ("s3://bucket/path/&to/file", "path/&to/file"),
        ("s3://bucket/path/=to/file", "path/=to/file"),
        (
            "s3+ddb://bucket/path/to/file?ddbTableName=test",
            "path/to/file",
        ),
    ];

    for (uri, expected) in cases {
        let parsed = Url::parse(uri).unwrap();
        let actual = provider.extract_path(&parsed).unwrap();
        assert_eq!(actual, Path::from(expected))
    }
}
540
#[test]
fn test_is_s3_express() {
    // Storage options containing only an optional `s3_express` flag.
    let express_flag = |value: Option<&str>| -> HashMap<String, String> {
        value
            .map(|v| HashMap::from([("s3_express".to_string(), v.to_string())]))
            .unwrap_or_default()
    };

    let cases = [
        ("s3://bucket/path/to/file", express_flag(Some("true")), true),
        ("s3://bucket/path/to/file", express_flag(Some("false")), false),
        ("s3://bucket/path/to/file", express_flag(None), false),
        ("s3://bucket--x-s3/path/to/file", express_flag(Some("true")), true),
        // URL takes precedence over an explicit `false`.
        ("s3://bucket--x-s3/path/to/file", express_flag(Some("false")), true),
        ("s3://bucket--x-s3/path/to/file", express_flag(None), true),
    ];

    for (uri, options, expected) in cases {
        let url = Url::parse(uri).unwrap();
        let detected = check_s3_express(&url, &StorageOptions(options));
        assert_eq!(detected, expected, "uri = {uri}");
    }
}
575
576    #[tokio::test]
577    async fn test_use_opendal_flag() {
578        use crate::object_store::StorageOptionsAccessor;
579        let provider = AwsStoreProvider;
580        let url = Url::parse("s3://test-bucket/path").unwrap();
581        let params_with_flag = ObjectStoreParams {
582            storage_options_accessor: Some(Arc::new(StorageOptionsAccessor::with_static_options(
583                HashMap::from([
584                    ("use_opendal".to_string(), "true".to_string()),
585                    ("region".to_string(), "us-west-2".to_string()),
586                ]),
587            ))),
588            ..Default::default()
589        };
590
591        let store = provider
592            .new_store(url.clone(), &params_with_flag)
593            .await
594            .unwrap();
595        assert_eq!(store.scheme, "s3");
596    }
597
598    #[derive(Debug)]
599    struct MockStorageOptionsProvider {
600        call_count: Arc<RwLock<usize>>,
601        expires_in_millis: Option<u64>,
602    }
603
604    impl MockStorageOptionsProvider {
605        fn new(expires_in_millis: Option<u64>) -> Self {
606            Self {
607                call_count: Arc::new(RwLock::new(0)),
608                expires_in_millis,
609            }
610        }
611
612        async fn get_call_count(&self) -> usize {
613            *self.call_count.read().await
614        }
615    }
616
617    #[async_trait::async_trait]
618    impl StorageOptionsProvider for MockStorageOptionsProvider {
619        async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
620            let count = {
621                let mut c = self.call_count.write().await;
622                *c += 1;
623                *c
624            };
625
626            let mut options = HashMap::from([
627                ("aws_access_key_id".to_string(), format!("AKID_{}", count)),
628                (
629                    "aws_secret_access_key".to_string(),
630                    format!("SECRET_{}", count),
631                ),
632                ("aws_session_token".to_string(), format!("TOKEN_{}", count)),
633            ]);
634
635            if let Some(expires_in) = self.expires_in_millis {
636                let now_ms = SystemTime::now()
637                    .duration_since(UNIX_EPOCH)
638                    .unwrap()
639                    .as_millis() as u64;
640                let expires_at = now_ms + expires_in;
641                options.insert("expires_at_millis".to_string(), expires_at.to_string());
642            }
643
644            Ok(Some(options))
645        }
646
647        fn provider_id(&self) -> String {
648            let ptr = Arc::as_ptr(&self.call_count) as usize;
649            format!("MockStorageOptionsProvider {{ id: {} }}", ptr)
650        }
651    }
652
653    #[tokio::test]
654    async fn test_dynamic_credential_provider_with_initial_cache() {
655        MockClock::set_system_time(Duration::from_secs(100_000));
656
657        let now_ms = MockClock::system_time().as_millis() as u64;
658
659        // Create a mock provider that returns credentials expiring in 10 minutes
660        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
661            600_000, // Expires in 10 minutes
662        )));
663
664        // Create initial options with cached credentials that expire in 10 minutes
665        let expires_at = now_ms + 600_000; // 10 minutes from now
666        let initial_options = HashMap::from([
667            ("aws_access_key_id".to_string(), "AKID_CACHED".to_string()),
668            (
669                "aws_secret_access_key".to_string(),
670                "SECRET_CACHED".to_string(),
671            ),
672            ("aws_session_token".to_string(), "TOKEN_CACHED".to_string()),
673            ("expires_at_millis".to_string(), expires_at.to_string()),
674            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minute refresh offset
675        ]);
676
677        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
678            mock.clone(),
679            initial_options,
680        );
681
682        // First call should use cached credentials (not expired yet)
683        let cred = provider.get_credential().await.unwrap();
684        assert_eq!(cred.key_id, "AKID_CACHED");
685        assert_eq!(cred.secret_key, "SECRET_CACHED");
686        assert_eq!(cred.token, Some("TOKEN_CACHED".to_string()));
687
688        // Should not have called the provider yet
689        assert_eq!(mock.get_call_count().await, 0);
690    }
691
692    #[tokio::test]
693    async fn test_dynamic_credential_provider_with_expired_cache() {
694        MockClock::set_system_time(Duration::from_secs(100_000));
695
696        let now_ms = MockClock::system_time().as_millis() as u64;
697
698        // Create a mock provider that returns credentials expiring in 10 minutes
699        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
700            600_000, // Expires in 10 minutes
701        )));
702
703        // Create initial options with credentials that expired 1 second ago
704        let expired_time = now_ms - 1_000; // 1 second ago
705        let initial_options = HashMap::from([
706            ("aws_access_key_id".to_string(), "AKID_EXPIRED".to_string()),
707            (
708                "aws_secret_access_key".to_string(),
709                "SECRET_EXPIRED".to_string(),
710            ),
711            ("expires_at_millis".to_string(), expired_time.to_string()),
712            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minute refresh offset
713        ]);
714
715        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
716            mock.clone(),
717            initial_options,
718        );
719
720        // First call should fetch new credentials because cached ones are expired
721        let cred = provider.get_credential().await.unwrap();
722        assert_eq!(cred.key_id, "AKID_1");
723        assert_eq!(cred.secret_key, "SECRET_1");
724        assert_eq!(cred.token, Some("TOKEN_1".to_string()));
725
726        // Should have called the provider once
727        assert_eq!(mock.get_call_count().await, 1);
728    }
729
730    #[tokio::test]
731    async fn test_dynamic_credential_provider_refresh_lead_time() {
732        MockClock::set_system_time(Duration::from_secs(100_000));
733
734        // Create a mock provider that returns credentials expiring in 30 seconds
735        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
736            30_000, // Expires in 30 seconds
737        )));
738
739        // Create credential provider with default 60 second refresh offset
740        // This means credentials should be refreshed when they have less than 60 seconds left
741        let provider = DynamicStorageOptionsCredentialProvider::from_provider(mock.clone());
742
743        // First call should fetch credentials from provider (no initial cache)
744        // Credentials expire in 30 seconds, which is less than our 60 second refresh offset,
745        // so they should be considered "needs refresh" immediately
746        let cred = provider.get_credential().await.unwrap();
747        assert_eq!(cred.key_id, "AKID_1");
748        assert_eq!(mock.get_call_count().await, 1);
749
750        // Second call should trigger refresh because credentials expire in 30 seconds
751        // but our refresh lead time is 60 seconds (now + 60sec > expires_at)
752        // The mock will return new credentials (AKID_2) with the same expiration
753        let cred = provider.get_credential().await.unwrap();
754        assert_eq!(cred.key_id, "AKID_2");
755        assert_eq!(mock.get_call_count().await, 2);
756    }
757
758    #[tokio::test]
759    async fn test_dynamic_credential_provider_no_initial_cache() {
760        MockClock::set_system_time(Duration::from_secs(100_000));
761
762        // Create a mock provider that returns credentials expiring in 2 minutes
763        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
764            120_000, // Expires in 2 minutes
765        )));
766
767        // Create credential provider without initial cache, using default 60 second refresh offset
768        let provider = DynamicStorageOptionsCredentialProvider::from_provider(mock.clone());
769
770        // First call should fetch from provider (call count = 1)
771        let cred = provider.get_credential().await.unwrap();
772        assert_eq!(cred.key_id, "AKID_1");
773        assert_eq!(cred.secret_key, "SECRET_1");
774        assert_eq!(cred.token, Some("TOKEN_1".to_string()));
775        assert_eq!(mock.get_call_count().await, 1);
776
777        // Second call should use cached credentials (not expired yet, still > 60 seconds remaining)
778        let cred = provider.get_credential().await.unwrap();
779        assert_eq!(cred.key_id, "AKID_1");
780        assert_eq!(mock.get_call_count().await, 1); // Still 1, didn't fetch again
781
782        // Advance time to 90 seconds - should trigger refresh (within 60 sec refresh offset)
783        // At this point, credentials expire in 30 seconds (< 60 sec offset)
784        MockClock::set_system_time(Duration::from_secs(100_000 + 90));
785        let cred = provider.get_credential().await.unwrap();
786        assert_eq!(cred.key_id, "AKID_2");
787        assert_eq!(cred.secret_key, "SECRET_2");
788        assert_eq!(cred.token, Some("TOKEN_2".to_string()));
789        assert_eq!(mock.get_call_count().await, 2);
790
791        // Advance time to 210 seconds total (90 + 120) - should trigger another refresh
792        MockClock::set_system_time(Duration::from_secs(100_000 + 210));
793        let cred = provider.get_credential().await.unwrap();
794        assert_eq!(cred.key_id, "AKID_3");
795        assert_eq!(cred.secret_key, "SECRET_3");
796        assert_eq!(mock.get_call_count().await, 3);
797    }
798
799    #[tokio::test]
800    async fn test_dynamic_credential_provider_with_initial_options() {
801        MockClock::set_system_time(Duration::from_secs(100_000));
802
803        let now_ms = MockClock::system_time().as_millis() as u64;
804
805        // Create a mock provider that returns credentials expiring in 10 minutes
806        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
807            600_000, // Expires in 10 minutes
808        )));
809
810        // Create initial options with expiration in 10 minutes
811        let expires_at = now_ms + 600_000; // 10 minutes from now
812        let initial_options = HashMap::from([
813            ("aws_access_key_id".to_string(), "AKID_INITIAL".to_string()),
814            (
815                "aws_secret_access_key".to_string(),
816                "SECRET_INITIAL".to_string(),
817            ),
818            ("aws_session_token".to_string(), "TOKEN_INITIAL".to_string()),
819            ("expires_at_millis".to_string(), expires_at.to_string()),
820            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minute refresh offset
821        ]);
822
823        // Create credential provider with initial options
824        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
825            mock.clone(),
826            initial_options,
827        );
828
829        // First call should use the initial credential (not expired yet)
830        let cred = provider.get_credential().await.unwrap();
831        assert_eq!(cred.key_id, "AKID_INITIAL");
832        assert_eq!(cred.secret_key, "SECRET_INITIAL");
833        assert_eq!(cred.token, Some("TOKEN_INITIAL".to_string()));
834
835        // Should not have called the provider yet
836        assert_eq!(mock.get_call_count().await, 0);
837
838        // Advance time to 6 minutes - this should trigger a refresh
839        // (5 minute refresh offset means we refresh 5 minutes before expiration)
840        MockClock::set_system_time(Duration::from_secs(100_000 + 360));
841        let cred = provider.get_credential().await.unwrap();
842        assert_eq!(cred.key_id, "AKID_1");
843        assert_eq!(cred.secret_key, "SECRET_1");
844        assert_eq!(cred.token, Some("TOKEN_1".to_string()));
845
846        // Should have called the provider once
847        assert_eq!(mock.get_call_count().await, 1);
848
849        // Advance time to 11 minutes total - this should trigger another refresh
850        MockClock::set_system_time(Duration::from_secs(100_000 + 660));
851        let cred = provider.get_credential().await.unwrap();
852        assert_eq!(cred.key_id, "AKID_2");
853        assert_eq!(cred.secret_key, "SECRET_2");
854        assert_eq!(cred.token, Some("TOKEN_2".to_string()));
855
856        // Should have called the provider twice
857        assert_eq!(mock.get_call_count().await, 2);
858
859        // Advance time to 16 minutes total - this should trigger yet another refresh
860        MockClock::set_system_time(Duration::from_secs(100_000 + 960));
861        let cred = provider.get_credential().await.unwrap();
862        assert_eq!(cred.key_id, "AKID_3");
863        assert_eq!(cred.secret_key, "SECRET_3");
864        assert_eq!(cred.token, Some("TOKEN_3".to_string()));
865
866        // Should have called the provider three times
867        assert_eq!(mock.get_call_count().await, 3);
868    }
869
870    #[tokio::test]
871    async fn test_dynamic_credential_provider_concurrent_access() {
872        // Create a mock provider with far future expiration
873        let mock = Arc::new(MockStorageOptionsProvider::new(Some(9999999999999)));
874
875        let provider = Arc::new(DynamicStorageOptionsCredentialProvider::from_provider(
876            mock.clone(),
877        ));
878
879        // Spawn 10 concurrent tasks that all try to get credentials at the same time
880        let mut handles = vec![];
881        for i in 0..10 {
882            let provider = provider.clone();
883            let handle = tokio::spawn(async move {
884                let cred = provider.get_credential().await.unwrap();
885                // Verify we got the correct credentials (should all be AKID_1 from first fetch)
886                assert_eq!(cred.key_id, "AKID_1");
887                assert_eq!(cred.secret_key, "SECRET_1");
888                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
889                i // Return task number for verification
890            });
891            handles.push(handle);
892        }
893
894        // Wait for all tasks to complete
895        let results: Vec<_> = futures::future::join_all(handles)
896            .await
897            .into_iter()
898            .map(|r| r.unwrap())
899            .collect();
900
901        // Verify all 10 tasks completed successfully
902        assert_eq!(results.len(), 10);
903        for i in 0..10 {
904            assert!(results.contains(&i));
905        }
906
907        // The provider should have been called exactly once (first request triggers fetch,
908        // subsequent requests use cache)
909        let call_count = mock.get_call_count().await;
910        assert_eq!(
911            call_count, 1,
912            "Provider should be called exactly once despite concurrent access"
913        );
914    }
915
916    #[tokio::test]
917    async fn test_dynamic_credential_provider_concurrent_refresh() {
918        MockClock::set_system_time(Duration::from_secs(100_000));
919
920        let now_ms = MockClock::system_time().as_millis() as u64;
921
922        // Create initial options with credentials that expired in the past (1000 seconds ago)
923        let expires_at = now_ms - 1_000_000;
924        let initial_options = HashMap::from([
925            ("aws_access_key_id".to_string(), "AKID_OLD".to_string()),
926            (
927                "aws_secret_access_key".to_string(),
928                "SECRET_OLD".to_string(),
929            ),
930            ("aws_session_token".to_string(), "TOKEN_OLD".to_string()),
931            ("expires_at_millis".to_string(), expires_at.to_string()),
932            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minute refresh offset
933        ]);
934
935        // Mock will return credentials expiring in 1 hour
936        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
937            3_600_000, // Expires in 1 hour
938        )));
939
940        let provider = Arc::new(
941            DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
942                mock.clone(),
943                initial_options,
944            ),
945        );
946
947        // Spawn 20 concurrent tasks that all try to get credentials at the same time
948        // Since the initial credential is expired, they'll all try to refresh
949        let mut handles = vec![];
950        for i in 0..20 {
951            let provider = provider.clone();
952            let handle = tokio::spawn(async move {
953                let cred = provider.get_credential().await.unwrap();
954                // All should get the new credentials (AKID_1 from first fetch)
955                assert_eq!(cred.key_id, "AKID_1");
956                assert_eq!(cred.secret_key, "SECRET_1");
957                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
958                i
959            });
960            handles.push(handle);
961        }
962
963        // Wait for all tasks to complete
964        let results: Vec<_> = futures::future::join_all(handles)
965            .await
966            .into_iter()
967            .map(|r| r.unwrap())
968            .collect();
969
970        // Verify all 20 tasks completed successfully
971        assert_eq!(results.len(), 20);
972
973        // The provider should have been called at least once, but possibly more times
974        // due to the try_write mechanism and race conditions
975        let call_count = mock.get_call_count().await;
976        assert!(
977            call_count >= 1,
978            "Provider should be called at least once, was called {} times",
979            call_count
980        );
981
982        // It shouldn't be called 20 times though - the lock should prevent most concurrent fetches
983        assert!(
984            call_count < 10,
985            "Provider should not be called too many times due to lock contention, was called {} times",
986            call_count
987        );
988    }
989
990    #[tokio::test]
991    async fn test_explicit_aws_credentials_takes_precedence_over_accessor() {
992        // Create a mock storage options provider that should NOT be called
993        let mock_storage_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));
994
995        // Create an accessor with the mock provider
996        let accessor = Arc::new(StorageOptionsAccessor::with_provider(
997            mock_storage_provider.clone(),
998        ));
999
1000        // Create an explicit AWS credentials provider
1001        let explicit_cred_provider = Arc::new(MockAwsCredentialsProvider::default());
1002
1003        // Build credentials with both aws_credentials AND accessor
1004        // The explicit aws_credentials should take precedence
1005        let (result, _region) = build_aws_credential(
1006            Duration::from_secs(300),
1007            Some(explicit_cred_provider.clone() as AwsCredentialProvider),
1008            None, // no storage_options
1009            Some("us-west-2".to_string()),
1010            Some(accessor),
1011        )
1012        .await
1013        .unwrap();
1014
1015        // Get credential from the result
1016        let cred = result.get_credential().await.unwrap();
1017
1018        // The explicit provider should have been called (it returns empty strings)
1019        assert!(explicit_cred_provider.called.load(Ordering::Relaxed));
1020
1021        // The storage options provider should NOT have been called
1022        assert_eq!(
1023            mock_storage_provider.get_call_count().await,
1024            0,
1025            "Storage options provider should not be called when explicit aws_credentials is provided"
1026        );
1027
1028        // Verify we got credentials from the explicit provider (empty strings)
1029        assert_eq!(cred.key_id, "");
1030        assert_eq!(cred.secret_key, "");
1031    }
1032
1033    #[tokio::test]
1034    async fn test_accessor_used_when_no_explicit_aws_credentials() {
1035        MockClock::set_system_time(Duration::from_secs(100_000));
1036
1037        let now_ms = MockClock::system_time().as_millis() as u64;
1038
1039        // Create a mock storage options provider
1040        let mock_storage_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));
1041
1042        // Create initial options
1043        let expires_at = now_ms + 600_000; // 10 minutes from now
1044        let initial_options = HashMap::from([
1045            (
1046                "aws_access_key_id".to_string(),
1047                "AKID_FROM_ACCESSOR".to_string(),
1048            ),
1049            (
1050                "aws_secret_access_key".to_string(),
1051                "SECRET_FROM_ACCESSOR".to_string(),
1052            ),
1053            (
1054                "aws_session_token".to_string(),
1055                "TOKEN_FROM_ACCESSOR".to_string(),
1056            ),
1057            ("expires_at_millis".to_string(), expires_at.to_string()),
1058            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minute refresh offset
1059        ]);
1060
1061        // Create an accessor with initial options and provider
1062        let accessor = Arc::new(StorageOptionsAccessor::with_initial_and_provider(
1063            initial_options,
1064            mock_storage_provider.clone(),
1065        ));
1066
1067        // Build credentials with accessor but NO explicit aws_credentials
1068        let (result, _region) = build_aws_credential(
1069            Duration::from_secs(300),
1070            None, // no explicit aws_credentials
1071            None, // no storage_options
1072            Some("us-west-2".to_string()),
1073            Some(accessor),
1074        )
1075        .await
1076        .unwrap();
1077
1078        // Get credential - should use the initial accessor credentials
1079        let cred = result.get_credential().await.unwrap();
1080        assert_eq!(cred.key_id, "AKID_FROM_ACCESSOR");
1081        assert_eq!(cred.secret_key, "SECRET_FROM_ACCESSOR");
1082
1083        // Storage options provider should NOT have been called yet (using cached initial creds)
1084        assert_eq!(mock_storage_provider.get_call_count().await, 0);
1085
1086        // Advance time to trigger refresh (past the 5 minute refresh offset)
1087        MockClock::set_system_time(Duration::from_secs(100_000 + 360));
1088
1089        // Get credential again - should now fetch from provider
1090        let cred = result.get_credential().await.unwrap();
1091        assert_eq!(cred.key_id, "AKID_1");
1092        assert_eq!(cred.secret_key, "SECRET_1");
1093
1094        // Storage options provider should have been called once
1095        assert_eq!(mock_storage_provider.get_call_count().await, 1);
1096    }
1097}