lance_io/object_store/providers/aws.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::{collections::HashMap, fmt, str::FromStr, sync::Arc, time::Duration};

#[cfg(test)]
use mock_instant::thread_local::{SystemTime, UNIX_EPOCH};

#[cfg(not(test))]
use std::time::{SystemTime, UNIX_EPOCH};

use object_store::ObjectStore as OSObjectStore;
use object_store_opendal::OpendalStore;
use opendal::{services::S3, Operator};

use aws_config::default_provider::credentials::DefaultCredentialsChain;
use aws_credential_types::provider::ProvideCredentials;
use object_store::{
    aws::{
        AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
        AwsCredentialProvider,
    },
    ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
    StaticCredentialProvider,
};
use snafu::location;
use tokio::sync::RwLock;
use url::Url;

use crate::object_store::{
    ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, StorageOptionsAccessor,
    StorageOptionsProvider, DEFAULT_CLOUD_BLOCK_SIZE, DEFAULT_CLOUD_IO_PARALLELISM,
    DEFAULT_MAX_IOP_SIZE,
};
use lance_core::error::{Error, Result};

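/// [`ObjectStoreProvider`] implementation for Amazon S3 and S3-compatible object stores.
///
/// Depending on the `use_opendal` storage option, this provider builds either the default
/// `object_store` Amazon S3 backend or an OpenDAL-backed S3 store.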
#[derive(Default, Debug)]
pub struct AwsStoreProvider;

impl AwsStoreProvider {
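    /// Build the default `object_store` Amazon S3 backend.
    ///
    /// Credentials and the region are resolved via [`build_aws_credential`], and the URL is
    /// rewritten to a plain `s3://` URL (dropping any DynamoDB-related query parts) before
    /// the store is constructed.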
    async fn build_amazon_s3_store(
        &self,
        base_path: &mut Url,
        params: &ObjectStoreParams,
        storage_options: &StorageOptions,
        is_s3_express: bool,
    ) -> Result<Arc<dyn OSObjectStore>> {
        let max_retries = storage_options.client_max_retries();
        let retry_timeout = storage_options.client_retry_timeout();
        let retry_config = RetryConfig {
            backoff: Default::default(),
            max_retries,
            retry_timeout: Duration::from_secs(retry_timeout),
        };

        let mut s3_storage_options = storage_options.as_s3_options();
        let region = resolve_s3_region(base_path, &s3_storage_options).await?;

        // Get accessor from params
        let accessor = params.get_accessor();

        let (aws_creds, region) = build_aws_credential(
            params.s3_credentials_refresh_offset,
            params.aws_credentials.clone(),
            Some(&s3_storage_options),
            region,
            accessor,
        )
        .await?;

        // Set the S3 Express flag if detected
        if is_s3_express {
            s3_storage_options.insert(AmazonS3ConfigKey::S3Express, true.to_string());
        }

        // Before creating the OSObjectStore we need to rewrite the URL to drop the DynamoDB-related parts
        base_path.set_scheme("s3").unwrap();
        base_path.set_query(None);

        // We can't use parse_url_opts here because we need to set the credentials provider manually
        let mut builder = AmazonS3Builder::new();
        for (key, value) in s3_storage_options {
            builder = builder.with_config(key, value);
        }
        builder = builder
            .with_url(base_path.as_ref())
            .with_credentials(aws_creds)
            .with_retry(retry_config)
            .with_region(region);

        Ok(Arc::new(builder.build()?) as Arc<dyn OSObjectStore>)
    }

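    /// Build an OpenDAL-backed S3 store.
    ///
    /// All storage options are forwarded to the OpenDAL S3 service configuration; credentials
    /// that are not set explicitly are resolved by OpenDAL's default credentials chain.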
    async fn build_opendal_s3_store(
        &self,
        base_path: &Url,
        storage_options: &StorageOptions,
    ) -> Result<Arc<dyn OSObjectStore>> {
        let bucket = base_path
            .host_str()
            .ok_or_else(|| Error::invalid_input("S3 URL must contain bucket name", location!()))?
            .to_string();

        let prefix = base_path.path().trim_start_matches('/').to_string();

        // Start with all storage options as the config map
        // OpenDAL will handle environment variables through its default credentials chain
        let mut config_map: HashMap<String, String> = storage_options.0.clone();

        // Set required OpenDAL configuration
        config_map.insert("bucket".to_string(), bucket);

        if !prefix.is_empty() {
            config_map.insert("root".to_string(), "/".to_string());
        }

        let operator = Operator::from_iter::<S3>(config_map)
            .map_err(|e| {
                Error::invalid_input(
                    format!("Failed to create S3 operator: {:?}", e),
                    location!(),
                )
            })?
            .finish();

        Ok(Arc::new(OpendalStore::new(operator)) as Arc<dyn OSObjectStore>)
    }
}

#[async_trait::async_trait]
impl ObjectStoreProvider for AwsStoreProvider {
    async fn new_store(
        &self,
        mut base_path: Url,
        params: &ObjectStoreParams,
    ) -> Result<ObjectStore> {
        let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
        let mut storage_options =
            StorageOptions(params.storage_options().cloned().unwrap_or_default());
        storage_options.with_env_s3();
        let download_retry_count = storage_options.download_retry_count();

        let use_opendal = storage_options
            .0
            .get("use_opendal")
            .map(|v| v == "true")
            .unwrap_or(false);

        // Determine S3 Express and constant size upload parts before building the store
        let is_s3_express = check_s3_express(&base_path, &storage_options);

        let use_constant_size_upload_parts = storage_options
            .0
            .get("aws_endpoint")
            .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
            .unwrap_or(false);

        let inner = if use_opendal {
            // Use OpenDAL implementation
            self.build_opendal_s3_store(&base_path, &storage_options)
                .await?
        } else {
            // Use default Amazon S3 implementation
            self.build_amazon_s3_store(&mut base_path, params, &storage_options, is_s3_express)
                .await?
        };

        Ok(ObjectStore {
            inner,
            scheme: String::from(base_path.scheme()),
            block_size,
            max_iop_size: *DEFAULT_MAX_IOP_SIZE,
            use_constant_size_upload_parts,
            list_is_lexically_ordered: !is_s3_express,
            io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
            download_retry_count,
            io_tracker: Default::default(),
            store_prefix: self
                .calculate_object_store_prefix(&base_path, params.storage_options())?,
        })
    }
}

/// Check if the storage is S3 Express
fn check_s3_express(url: &Url, storage_options: &StorageOptions) -> bool {
    storage_options
        .0
        .get("s3_express")
        .map(|v| v == "true")
        .unwrap_or(false)
        || url.authority().ends_with("--x-s3")
}

/// Figure out the S3 region of the bucket.
///
/// This resolves in order of precedence:
/// 1. The region provided in the storage options
/// 2. If no endpoint is set, the region returned by the S3 API for the bucket
///
/// It returns `None` if no region is provided and an endpoint is set.
async fn resolve_s3_region(
    url: &Url,
    storage_options: &HashMap<AmazonS3ConfigKey, String>,
) -> Result<Option<String>> {
    if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
        Ok(Some(region.clone()))
    } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
        // If no endpoint is set, we can assume this is AWS S3 and the region
        // can be resolved from the bucket.
        let bucket = url.host_str().ok_or_else(|| {
            Error::invalid_input(
                format!("Could not parse bucket from url: {}", url),
                location!(),
            )
        })?;

        let mut client_options = ClientOptions::default();
        for (key, value) in storage_options {
            if let AmazonS3ConfigKey::Client(client_key) = key {
                client_options = client_options.with_config(*client_key, value.clone());
            }
        }

        let bucket_region =
            object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
        Ok(Some(bucket_region))
    } else {
        Ok(None)
    }
}

/// Build AWS credentials
///
/// This resolves credentials from the following sources in order:
/// 1. An explicit `credentials` provider
/// 2. A `storage_options_accessor` with a dynamic provider
/// 3. Explicit credentials in storage_options (`aws_access_key_id`,
///    `aws_secret_access_key`, `aws_session_token`)
/// 4. The default credential provider chain from the AWS SDK.
///
/// # Storage Options Accessor
///
/// When `storage_options_accessor` is provided and has a dynamic provider,
/// credentials are fetched and cached by the accessor with automatic refresh
/// before expiration.
///
/// `credentials_refresh_offset` is the amount of time before expiry to refresh credentials.
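///
/// # Example
///
/// A minimal sketch that falls back to the default AWS credential chain (no explicit
/// provider, no storage options, no accessor); the region value and refresh offset are
/// only illustrative:
///
/// ```ignore
/// let (provider, region) = build_aws_credential(
///     Duration::from_secs(60), // refresh 60 seconds before expiry
///     None,                    // no explicit credential provider
///     None,                    // no storage options
///     Some("us-east-1".to_string()),
///     None,                    // no storage options accessor
/// )
/// .await?;
/// ```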
pub async fn build_aws_credential(
    credentials_refresh_offset: Duration,
    credentials: Option<AwsCredentialProvider>,
    storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
    region: Option<String>,
    storage_options_accessor: Option<Arc<StorageOptionsAccessor>>,
) -> Result<(AwsCredentialProvider, String)> {
    use aws_config::meta::region::RegionProviderChain;
    const DEFAULT_REGION: &str = "us-west-2";

    let region = if let Some(region) = region {
        region
    } else {
        RegionProviderChain::default_provider()
            .or_else(DEFAULT_REGION)
            .region()
            .await
            .map(|r| r.as_ref().to_string())
            .unwrap_or(DEFAULT_REGION.to_string())
    };

    let storage_options_credentials = storage_options.and_then(extract_static_s3_credentials);

    // If accessor has a provider, use DynamicStorageOptionsCredentialProvider
    if let Some(accessor) = storage_options_accessor {
        if accessor.has_provider() {
            // Explicit aws_credentials takes precedence
            if let Some(creds) = credentials {
                return Ok((creds, region));
            }
            // Use accessor for dynamic credential refresh
            return Ok((
                Arc::new(DynamicStorageOptionsCredentialProvider::new(accessor)),
                region,
            ));
        }
    }

    // Fall back to existing logic for static credentials
    if let Some(creds) = credentials {
        Ok((creds, region))
    } else if let Some(creds) = storage_options_credentials {
        Ok((Arc::new(creds), region))
    } else {
        let credentials_provider = DefaultCredentialsChain::builder().build().await;

        Ok((
            Arc::new(AwsCredentialAdapter::new(
                Arc::new(credentials_provider),
                credentials_refresh_offset,
            )),
            region,
        ))
    }
}

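/// Extract static credentials from S3 storage options.
///
/// Returns a provider only if both `aws_access_key_id` and `aws_secret_access_key` are
/// present; `aws_session_token` is optional.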
fn extract_static_s3_credentials(
    options: &HashMap<AmazonS3ConfigKey, String>,
) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
    let key_id = options.get(&AmazonS3ConfigKey::AccessKeyId).cloned();
    let secret_key = options.get(&AmazonS3ConfigKey::SecretAccessKey).cloned();
    let token = options.get(&AmazonS3ConfigKey::Token).cloned();
    match (key_id, secret_key, token) {
        (Some(key_id), Some(secret_key), token) => {
            Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
                key_id,
                secret_key,
                token,
            }))
        }
        _ => None,
    }
}

/// Adapts AWS SDK credentials into `object_store` credentials.
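///
/// # Example
///
/// A minimal sketch that wraps the default AWS SDK credentials chain; the 60-second
/// refresh offset is an assumption for illustration:
///
/// ```ignore
/// let chain = DefaultCredentialsChain::builder().build().await;
/// let adapter = AwsCredentialAdapter::new(Arc::new(chain), Duration::from_secs(60));
/// let creds = adapter.get_credential().await?;
/// ```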
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    pub inner: Arc<dyn ProvideCredentials>,

    // A RefCell can't be shared across threads, so the cache lives behind an async RwLock
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    // The amount of time before expiry to refresh credentials
    credentials_refresh_offset: Duration,
}

impl AwsCredentialAdapter {
    pub fn new(
        provider: Arc<dyn ProvideCredentials>,
        credentials_refresh_offset: Duration,
    ) -> Self {
        Self {
            inner: provider,
            cache: Arc::new(RwLock::new(HashMap::new())),
            credentials_refresh_offset,
        }
    }
}

const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";

/// Convert a `std::time::SystemTime` from the AWS SDK to our mockable `SystemTime`
fn to_system_time(time: std::time::SystemTime) -> SystemTime {
    let duration_since_epoch = time
        .duration_since(std::time::UNIX_EPOCH)
        .expect("time should be after UNIX_EPOCH");
    UNIX_EPOCH + duration_since_epoch
}

#[async_trait::async_trait]
impl CredentialProvider for AwsCredentialAdapter {
    type Credential = ObjectStoreAwsCredential;

    async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
        let cached_creds = {
            let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
            let expired = cache_value
                .clone()
                .map(|cred| {
                    cred.expiry()
                        .map(|exp| {
                            to_system_time(exp)
                                .checked_sub(self.credentials_refresh_offset)
                                .expect("this time should always be valid")
                                < SystemTime::now()
                        })
                        // credentials without an expiry never expire
                        .unwrap_or(false)
                })
                .unwrap_or(true); // no cached credential is treated as expired
            if expired {
                None
            } else {
                cache_value.clone()
            }
        };

        if let Some(creds) = cached_creds {
            Ok(Arc::new(Self::Credential {
                key_id: creds.access_key_id().to_string(),
                secret_key: creds.secret_access_key().to_string(),
                token: creds.session_token().map(|s| s.to_string()),
            }))
        } else {
            let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
                |e| Error::Internal {
                    message: format!("Failed to get AWS credentials: {:?}", e),
                    location: location!(),
                },
            )?);

            self.cache
                .write()
                .await
                .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());

            Ok(Arc::new(Self::Credential {
                key_id: refreshed_creds.access_key_id().to_string(),
                secret_key: refreshed_creds.secret_access_key().to_string(),
                token: refreshed_creds.session_token().map(|s| s.to_string()),
            }))
        }
    }
}

impl StorageOptions {
    /// Add values from the environment to storage options
    pub fn with_env_s3(&mut self) {
        for (os_key, os_value) in std::env::vars_os() {
            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
                if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
                    if !self.0.contains_key(config_key.as_ref()) {
                        self.0
                            .insert(config_key.as_ref().to_string(), value.to_string());
                    }
                }
            }
        }
    }

    /// Subset of options relevant for S3 storage
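    ///
    /// # Example
    ///
    /// A minimal sketch; the map contents are assumptions for illustration, and keys that
    /// do not map to an [`AmazonS3ConfigKey`] are dropped:
    ///
    /// ```ignore
    /// let opts = StorageOptions(HashMap::from([
    ///     ("region".to_string(), "us-east-1".to_string()),
    ///     ("not_an_s3_key".to_string(), "dropped".to_string()),
    /// ]));
    /// let s3_opts = opts.as_s3_options();
    /// assert_eq!(s3_opts.get(&AmazonS3ConfigKey::Region), Some(&"us-east-1".to_string()));
    /// ```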
    pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
        self.0
            .iter()
            .filter_map(|(key, value)| {
                let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
                Some((s3_key, value.clone()))
            })
            .collect()
    }
}

impl ObjectStoreParams {
    /// Create a new instance of [`ObjectStoreParams`] based on the AWS credentials.
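    ///
    /// A minimal sketch, assuming an existing [`AwsCredentialProvider`] value `creds`:
    ///
    /// ```ignore
    /// let params = ObjectStoreParams::with_aws_credentials(
    ///     Some(creds),
    ///     Some("us-east-1".to_string()), // region shown is illustrative
    /// );
    /// ```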
    pub fn with_aws_credentials(
        aws_credentials: Option<AwsCredentialProvider>,
        region: Option<String>,
    ) -> Self {
        let storage_options_accessor = region.map(|region| {
            let opts: HashMap<String, String> =
                [("region".into(), region)].iter().cloned().collect();
            Arc::new(StorageOptionsAccessor::with_static_options(opts))
        });
        Self {
            aws_credentials,
            storage_options_accessor,
            ..Default::default()
        }
    }
}

/// AWS Credential Provider that delegates to StorageOptionsAccessor
///
/// This adapter converts storage options from a [`StorageOptionsAccessor`] into
/// AWS-specific credentials that can be used with S3. All caching and refresh logic
/// is handled by the accessor.
///
/// # Future Work
///
/// TODO: Support AWS/GCP/Azure together in a unified credential provider.
/// Currently this is AWS-specific. Needs investigation of how GCP and Azure credential
/// refresh mechanisms work and whether they can be unified with AWS's approach.
///
/// See: <https://github.com/lance-format/lance/pull/4905#discussion_r2474605265>
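///
/// # Example
///
/// A minimal sketch, assuming a hypothetical `MyOptionsProvider` that implements
/// [`StorageOptionsProvider`] and returns `aws_access_key_id`, `aws_secret_access_key`,
/// and optionally `aws_session_token`:
///
/// ```ignore
/// let provider = DynamicStorageOptionsCredentialProvider::from_provider(
///     Arc::new(MyOptionsProvider::default()),
/// );
/// let creds = provider.get_credential().await?;
/// ```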
pub struct DynamicStorageOptionsCredentialProvider {
    accessor: Arc<StorageOptionsAccessor>,
}

impl fmt::Debug for DynamicStorageOptionsCredentialProvider {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("DynamicStorageOptionsCredentialProvider")
            .field("accessor", &self.accessor)
            .finish()
    }
}

impl DynamicStorageOptionsCredentialProvider {
    /// Create a new credential provider from a storage options accessor
    pub fn new(accessor: Arc<StorageOptionsAccessor>) -> Self {
        Self { accessor }
    }

    /// Create a new credential provider from a storage options provider
    ///
    /// This is a convenience constructor for backward compatibility.
    /// The refresh offset will be extracted from storage options using
    /// the `refresh_offset_millis` key, defaulting to 60 seconds.
    ///
    /// # Arguments
    /// * `provider` - The storage options provider
    pub fn from_provider(provider: Arc<dyn StorageOptionsProvider>) -> Self {
        Self {
            accessor: Arc::new(StorageOptionsAccessor::with_provider(provider)),
        }
    }

    /// Create a new credential provider with initial credentials
    ///
    /// This is a convenience constructor for backward compatibility.
    /// The refresh offset will be extracted from initial_options using
    /// the `refresh_offset_millis` key, defaulting to 60 seconds.
    ///
    /// # Arguments
    /// * `provider` - The storage options provider
    /// * `initial_options` - Initial storage options to cache
    pub fn from_provider_with_initial(
        provider: Arc<dyn StorageOptionsProvider>,
        initial_options: HashMap<String, String>,
    ) -> Self {
        Self {
            accessor: Arc::new(StorageOptionsAccessor::with_initial_and_provider(
                initial_options,
                provider,
            )),
        }
    }
}

#[async_trait::async_trait]
impl CredentialProvider for DynamicStorageOptionsCredentialProvider {
    type Credential = ObjectStoreAwsCredential;

    async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
        let storage_options = self.accessor.get_storage_options().await.map_err(|e| {
            object_store::Error::Generic {
                store: "DynamicStorageOptionsCredentialProvider",
                source: Box::new(e),
            }
        })?;

        let s3_options = storage_options.as_s3_options();
        let static_creds = extract_static_s3_credentials(&s3_options).ok_or_else(|| {
            object_store::Error::Generic {
                store: "DynamicStorageOptionsCredentialProvider",
                source: "Missing required credentials in storage options".into(),
            }
        })?;

        static_creds
            .get_credential()
            .await
            .map_err(|e| object_store::Error::Generic {
                store: "DynamicStorageOptionsCredentialProvider",
                source: Box::new(e),
            })
    }
}

#[cfg(test)]
mod tests {
    use crate::object_store::ObjectStoreRegistry;
    use mock_instant::thread_local::MockClock;
    use object_store::path::Path;
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::*;

    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait::async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }

    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // Not called yet
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // The read fails, but we only care that the credential provider was invoked
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        // Now the provider has been called
        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    #[test]
    fn test_s3_path_parsing() {
        let provider = AwsStoreProvider;

        let cases = [
            ("s3://bucket/path/to/file", "path/to/file"),
            // non-ASCII path segments
            ("s3://bucket/测试path/to/file", "测试path/to/file"),
            ("s3://bucket/path/&to/file", "path/&to/file"),
            ("s3://bucket/path/=to/file", "path/=to/file"),
            (
                "s3+ddb://bucket/path/to/file?ddbTableName=test",
                "path/to/file",
            ),
        ];

        for (uri, expected_path) in cases {
            let url = Url::parse(uri).unwrap();
            let path = provider.extract_path(&url).unwrap();
            let expected_path = Path::from(expected_path);
            assert_eq!(path, expected_path)
        }
    }

    #[test]
    fn test_is_s3_express() {
        let cases = [
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                false,
            ),
            ("s3://bucket/path/to/file", HashMap::from([]), false),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                true, // URL takes precedence
            ),
            ("s3://bucket--x-s3/path/to/file", HashMap::from([]), true),
        ];

        for (uri, storage_map, expected) in cases {
            let url = Url::parse(uri).unwrap();
            let storage_options = StorageOptions(storage_map);
            let is_s3_express = check_s3_express(&url, &storage_options);
            assert_eq!(is_s3_express, expected);
        }
    }

    #[tokio::test]
    async fn test_use_opendal_flag() {
        use crate::object_store::StorageOptionsAccessor;
        let provider = AwsStoreProvider;
        let url = Url::parse("s3://test-bucket/path").unwrap();
        let params_with_flag = ObjectStoreParams {
            storage_options_accessor: Some(Arc::new(StorageOptionsAccessor::with_static_options(
                HashMap::from([
                    ("use_opendal".to_string(), "true".to_string()),
                    ("region".to_string(), "us-west-2".to_string()),
                ]),
            ))),
            ..Default::default()
        };

        let store = provider
            .new_store(url.clone(), &params_with_flag)
            .await
            .unwrap();
        assert_eq!(store.scheme, "s3");
    }

    #[derive(Debug)]
    struct MockStorageOptionsProvider {
        call_count: Arc<RwLock<usize>>,
        expires_in_millis: Option<u64>,
    }

    impl MockStorageOptionsProvider {
        fn new(expires_in_millis: Option<u64>) -> Self {
            Self {
                call_count: Arc::new(RwLock::new(0)),
                expires_in_millis,
            }
        }

        async fn get_call_count(&self) -> usize {
            *self.call_count.read().await
        }
    }

    #[async_trait::async_trait]
    impl StorageOptionsProvider for MockStorageOptionsProvider {
        async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
            let count = {
                let mut c = self.call_count.write().await;
                *c += 1;
                *c
            };

            let mut options = HashMap::from([
                ("aws_access_key_id".to_string(), format!("AKID_{}", count)),
                (
                    "aws_secret_access_key".to_string(),
                    format!("SECRET_{}", count),
                ),
                ("aws_session_token".to_string(), format!("TOKEN_{}", count)),
            ]);

            if let Some(expires_in) = self.expires_in_millis {
                let now_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .unwrap()
                    .as_millis() as u64;
                let expires_at = now_ms + expires_in;
                options.insert("expires_at_millis".to_string(), expires_at.to_string());
            }

            Ok(Some(options))
        }

        fn provider_id(&self) -> String {
            let ptr = Arc::as_ptr(&self.call_count) as usize;
            format!("MockStorageOptionsProvider {{ id: {} }}", ptr)
        }
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_with_initial_cache() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        // Create a mock provider that returns credentials expiring in 10 minutes
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            600_000, // Expires in 10 minutes
        )));

        // Create initial options with cached credentials that expire in 10 minutes
        let expires_at = now_ms + 600_000; // 10 minutes from now
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_CACHED".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_CACHED".to_string(),
            ),
            ("aws_session_token".to_string(), "TOKEN_CACHED".to_string()),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minute refresh offset
        ]);

        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
            mock.clone(),
            initial_options,
        );

        // First call should use cached credentials (not expired yet)
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_CACHED");
        assert_eq!(cred.secret_key, "SECRET_CACHED");
        assert_eq!(cred.token, Some("TOKEN_CACHED".to_string()));

        // Should not have called the provider yet
        assert_eq!(mock.get_call_count().await, 0);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_with_expired_cache() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        // Create a mock provider that returns credentials expiring in 10 minutes
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            600_000, // Expires in 10 minutes
        )));

        // Create initial options with credentials that expired 1 second ago
        let expired_time = now_ms - 1_000; // 1 second ago
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_EXPIRED".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_EXPIRED".to_string(),
            ),
            ("expires_at_millis".to_string(), expired_time.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minute refresh offset
        ]);

        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
            mock.clone(),
            initial_options,
        );

        // First call should fetch new credentials because cached ones are expired
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));

        // Should have called the provider once
        assert_eq!(mock.get_call_count().await, 1);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_refresh_lead_time() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        // Create a mock provider that returns credentials expiring in 30 seconds
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            30_000, // Expires in 30 seconds
        )));

        // Create credential provider with default 60 second refresh offset
        // This means credentials should be refreshed when they have less than 60 seconds left
        let provider = DynamicStorageOptionsCredentialProvider::from_provider(mock.clone());

        // First call should fetch credentials from provider (no initial cache)
        // Credentials expire in 30 seconds, which is less than our 60 second refresh offset,
        // so they should be considered "needs refresh" immediately
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(mock.get_call_count().await, 1);

        // Second call should trigger refresh because credentials expire in 30 seconds
        // but our refresh lead time is 60 seconds (now + 60sec > expires_at)
        // The mock will return new credentials (AKID_2) with the same expiration
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(mock.get_call_count().await, 2);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_no_initial_cache() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        // Create a mock provider that returns credentials expiring in 2 minutes
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            120_000, // Expires in 2 minutes
        )));

        // Create credential provider without initial cache, using default 60 second refresh offset
        let provider = DynamicStorageOptionsCredentialProvider::from_provider(mock.clone());

        // First call should fetch from provider (call count = 1)
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));
        assert_eq!(mock.get_call_count().await, 1);

        // Second call should use cached credentials (not expired yet, still > 60 seconds remaining)
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(mock.get_call_count().await, 1); // Still 1, didn't fetch again

        // Advance time to 90 seconds - should trigger refresh (within 60 sec refresh offset)
        // At this point, credentials expire in 30 seconds (< 60 sec offset)
        MockClock::set_system_time(Duration::from_secs(100_000 + 90));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(cred.secret_key, "SECRET_2");
        assert_eq!(cred.token, Some("TOKEN_2".to_string()));
        assert_eq!(mock.get_call_count().await, 2);

        // Advance time to 210 seconds total (90 + 120) - should trigger another refresh
        MockClock::set_system_time(Duration::from_secs(100_000 + 210));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_3");
        assert_eq!(cred.secret_key, "SECRET_3");
        assert_eq!(mock.get_call_count().await, 3);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_with_initial_options() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        // Create a mock provider that returns credentials expiring in 10 minutes
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            600_000, // Expires in 10 minutes
        )));

        // Create initial options with expiration in 10 minutes
        let expires_at = now_ms + 600_000; // 10 minutes from now
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_INITIAL".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_INITIAL".to_string(),
            ),
            ("aws_session_token".to_string(), "TOKEN_INITIAL".to_string()),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minute refresh offset
        ]);

        // Create credential provider with initial options
        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
            mock.clone(),
            initial_options,
        );

        // First call should use the initial credential (not expired yet)
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_INITIAL");
        assert_eq!(cred.secret_key, "SECRET_INITIAL");
        assert_eq!(cred.token, Some("TOKEN_INITIAL".to_string()));

        // Should not have called the provider yet
        assert_eq!(mock.get_call_count().await, 0);

        // Advance time to 6 minutes - this should trigger a refresh
        // (5 minute refresh offset means we refresh 5 minutes before expiration)
        MockClock::set_system_time(Duration::from_secs(100_000 + 360));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));

        // Should have called the provider once
        assert_eq!(mock.get_call_count().await, 1);

        // Advance time to 11 minutes total - this should trigger another refresh
        MockClock::set_system_time(Duration::from_secs(100_000 + 660));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(cred.secret_key, "SECRET_2");
        assert_eq!(cred.token, Some("TOKEN_2".to_string()));

        // Should have called the provider twice
        assert_eq!(mock.get_call_count().await, 2);

        // Advance time to 16 minutes total - this should trigger yet another refresh
        MockClock::set_system_time(Duration::from_secs(100_000 + 960));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_3");
        assert_eq!(cred.secret_key, "SECRET_3");
        assert_eq!(cred.token, Some("TOKEN_3".to_string()));

        // Should have called the provider three times
        assert_eq!(mock.get_call_count().await, 3);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_concurrent_access() {
        // Create a mock provider with far future expiration
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(9999999999999)));

        let provider = Arc::new(DynamicStorageOptionsCredentialProvider::from_provider(
            mock.clone(),
        ));

        // Spawn 10 concurrent tasks that all try to get credentials at the same time
        let mut handles = vec![];
        for i in 0..10 {
            let provider = provider.clone();
            let handle = tokio::spawn(async move {
                let cred = provider.get_credential().await.unwrap();
                // Verify we got the correct credentials (should all be AKID_1 from first fetch)
                assert_eq!(cred.key_id, "AKID_1");
                assert_eq!(cred.secret_key, "SECRET_1");
                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
                i // Return task number for verification
            });
            handles.push(handle);
        }

        // Wait for all tasks to complete
        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        // Verify all 10 tasks completed successfully
        assert_eq!(results.len(), 10);
        for i in 0..10 {
            assert!(results.contains(&i));
        }

        // The provider should have been called exactly once (first request triggers fetch,
        // subsequent requests use cache)
        let call_count = mock.get_call_count().await;
        assert_eq!(
            call_count, 1,
            "Provider should be called exactly once despite concurrent access"
        );
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_concurrent_refresh() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        // Create initial options with credentials that expired in the past (1000 seconds ago)
        let expires_at = now_ms - 1_000_000;
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_OLD".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_OLD".to_string(),
            ),
            ("aws_session_token".to_string(), "TOKEN_OLD".to_string()),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minute refresh offset
        ]);

        // Mock will return credentials expiring in 1 hour
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            3_600_000, // Expires in 1 hour
        )));

        let provider = Arc::new(
            DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
                mock.clone(),
                initial_options,
            ),
        );

        // Spawn 20 concurrent tasks that all try to get credentials at the same time
        // Since the initial credential is expired, they'll all try to refresh
        let mut handles = vec![];
        for i in 0..20 {
            let provider = provider.clone();
            let handle = tokio::spawn(async move {
                let cred = provider.get_credential().await.unwrap();
                // All should get the new credentials (AKID_1 from first fetch)
                assert_eq!(cred.key_id, "AKID_1");
                assert_eq!(cred.secret_key, "SECRET_1");
                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
                i
            });
            handles.push(handle);
        }

        // Wait for all tasks to complete
        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        // Verify all 20 tasks completed successfully
        assert_eq!(results.len(), 20);

        // The provider should have been called at least once, but possibly more times
        // due to the try_write mechanism and race conditions
        let call_count = mock.get_call_count().await;
        assert!(
            call_count >= 1,
            "Provider should be called at least once, was called {} times",
            call_count
        );

        // It shouldn't be called 20 times though - the lock should prevent most concurrent fetches
        assert!(
            call_count < 10,
            "Provider should not be called too many times due to lock contention, was called {} times",
            call_count
        );
    }

    #[tokio::test]
    async fn test_explicit_aws_credentials_takes_precedence_over_accessor() {
        // Create a mock storage options provider that should NOT be called
        let mock_storage_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        // Create an accessor with the mock provider
        let accessor = Arc::new(StorageOptionsAccessor::with_provider(
            mock_storage_provider.clone(),
        ));

        // Create an explicit AWS credentials provider
        let explicit_cred_provider = Arc::new(MockAwsCredentialsProvider::default());

        // Build credentials with both aws_credentials AND accessor
        // The explicit aws_credentials should take precedence
        let (result, _region) = build_aws_credential(
            Duration::from_secs(300),
            Some(explicit_cred_provider.clone() as AwsCredentialProvider),
            None, // no storage_options
            Some("us-west-2".to_string()),
            Some(accessor),
        )
        .await
        .unwrap();

        // Get credential from the result
        let cred = result.get_credential().await.unwrap();

        // The explicit provider should have been called (it returns empty strings)
        assert!(explicit_cred_provider.called.load(Ordering::Relaxed));

        // The storage options provider should NOT have been called
        assert_eq!(
            mock_storage_provider.get_call_count().await,
            0,
            "Storage options provider should not be called when explicit aws_credentials is provided"
        );

        // Verify we got credentials from the explicit provider (empty strings)
        assert_eq!(cred.key_id, "");
        assert_eq!(cred.secret_key, "");
    }

    #[tokio::test]
    async fn test_accessor_used_when_no_explicit_aws_credentials() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        // Create a mock storage options provider
        let mock_storage_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        // Create initial options
        let expires_at = now_ms + 600_000; // 10 minutes from now
        let initial_options = HashMap::from([
            (
                "aws_access_key_id".to_string(),
                "AKID_FROM_ACCESSOR".to_string(),
            ),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_FROM_ACCESSOR".to_string(),
            ),
            (
                "aws_session_token".to_string(),
                "TOKEN_FROM_ACCESSOR".to_string(),
            ),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minute refresh offset
        ]);

        // Create an accessor with initial options and provider
        let accessor = Arc::new(StorageOptionsAccessor::with_initial_and_provider(
            initial_options,
            mock_storage_provider.clone(),
        ));

        // Build credentials with accessor but NO explicit aws_credentials
        let (result, _region) = build_aws_credential(
            Duration::from_secs(300),
            None, // no explicit aws_credentials
            None, // no storage_options
            Some("us-west-2".to_string()),
            Some(accessor),
        )
        .await
        .unwrap();

        // Get credential - should use the initial accessor credentials
        let cred = result.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_FROM_ACCESSOR");
        assert_eq!(cred.secret_key, "SECRET_FROM_ACCESSOR");

        // Storage options provider should NOT have been called yet (using cached initial creds)
        assert_eq!(mock_storage_provider.get_call_count().await, 0);

        // Advance time to trigger refresh (past the 5 minute refresh offset)
        MockClock::set_system_time(Duration::from_secs(100_000 + 360));

        // Get credential again - should now fetch from provider
        let cred = result.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");

        // Storage options provider should have been called once
        assert_eq!(mock_storage_provider.get_call_count().await, 1);
    }
}