// lance_io/object_store/providers/aws.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{collections::HashMap, fmt, str::FromStr, sync::Arc, time::Duration};
5
6#[cfg(test)]
7use mock_instant::thread_local::{SystemTime, UNIX_EPOCH};
8
9#[cfg(not(test))]
10use std::time::{SystemTime, UNIX_EPOCH};
11
12use object_store::ObjectStore as OSObjectStore;
13use object_store_opendal::OpendalStore;
14use opendal::{Operator, services::S3};
15
16use aws_config::default_provider::credentials::DefaultCredentialsChain;
17use aws_credential_types::provider::ProvideCredentials;
18use object_store::{
19    ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
20    StaticCredentialProvider,
21    aws::{
22        AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
23        AwsCredentialProvider,
24    },
25};
26use tokio::sync::RwLock;
27use url::Url;
28
29use crate::object_store::{
30    DEFAULT_CLOUD_BLOCK_SIZE, DEFAULT_CLOUD_IO_PARALLELISM, DEFAULT_MAX_IOP_SIZE, ObjectStore,
31    ObjectStoreParams, ObjectStoreProvider, StorageOptions, StorageOptionsAccessor,
32    StorageOptionsProvider,
33};
34use lance_core::error::{Error, Result};
35
/// [`ObjectStoreProvider`] implementation for AWS S3 and S3-compatible stores.
///
/// Builds either the default `object_store` Amazon S3 backend or, when the
/// `use_opendal` storage option is `"true"`, an OpenDAL-backed S3 store.
#[derive(Default, Debug)]
pub struct AwsStoreProvider;
38
impl AwsStoreProvider {
    /// Build the default `object_store`-backed Amazon S3 store.
    ///
    /// Note: `base_path` is normalized in place (scheme forced to `s3`, query
    /// string dropped), so the caller observes the rewritten URL afterwards.
    async fn build_amazon_s3_store(
        &self,
        base_path: &mut Url,
        params: &ObjectStoreParams,
        storage_options: &StorageOptions,
        is_s3_express: bool,
    ) -> Result<Arc<dyn OSObjectStore>> {
        // Retry policy comes from storage options: max attempts + total timeout.
        let max_retries = storage_options.client_max_retries();
        let retry_timeout = storage_options.client_retry_timeout();
        let retry_config = RetryConfig {
            backoff: Default::default(),
            max_retries,
            retry_timeout: Duration::from_secs(retry_timeout),
        };

        let mut s3_storage_options = storage_options.as_s3_options();
        // May issue a network request when neither region nor endpoint is set.
        let region = resolve_s3_region(base_path, &s3_storage_options).await?;

        // Get accessor from params
        let accessor = params.get_accessor();

        // Resolve credential provider and final region; see
        // `build_aws_credential` for the precedence rules.
        let (aws_creds, region) = build_aws_credential(
            params.s3_credentials_refresh_offset,
            params.aws_credentials.clone(),
            Some(&s3_storage_options),
            region,
            accessor,
        )
        .await?;

        // Set S3Express flag if detected
        if is_s3_express {
            s3_storage_options.insert(AmazonS3ConfigKey::S3Express, true.to_string());
        }

        // before creating the OSObjectStore we need to rewrite the url to drop ddb related parts
        // (e.g. `s3+ddb://bucket/path?ddbTableName=...` becomes plain `s3://bucket/path`)
        base_path.set_scheme("s3").unwrap();
        base_path.set_query(None);

        // we can't use parse_url_opts here because we need to manually set the credentials provider
        let mut builder =
            AmazonS3Builder::new().with_client_options(storage_options.client_options()?);
        for (key, value) in s3_storage_options {
            builder = builder.with_config(key, value);
        }
        builder = builder
            .with_url(base_path.as_ref())
            .with_credentials(aws_creds)
            .with_retry(retry_config)
            .with_region(region);

        Ok(Arc::new(builder.build()?) as Arc<dyn OSObjectStore>)
    }

    /// Build an OpenDAL-backed S3 store (selected via the `use_opendal`
    /// storage option).
    ///
    /// All storage options are forwarded verbatim as OpenDAL S3 config;
    /// OpenDAL resolves anything missing (credentials, region) through its
    /// own default chain.
    async fn build_opendal_s3_store(
        &self,
        base_path: &Url,
        storage_options: &StorageOptions,
    ) -> Result<Arc<dyn OSObjectStore>> {
        let bucket = base_path
            .host_str()
            .ok_or_else(|| Error::invalid_input("S3 URL must contain bucket name"))?
            .to_string();

        let prefix = base_path.path().trim_start_matches('/').to_string();

        // Start with all storage options as the config map
        // OpenDAL will handle environment variables through its default credentials chain
        let mut config_map: HashMap<String, String> = storage_options.0.clone();

        // Set required OpenDAL configuration
        config_map.insert("bucket".to_string(), bucket);

        // NOTE(review): `prefix` is only used to decide whether to set `root`,
        // and `root` is always "/" — the prefix value itself is never passed
        // to OpenDAL. Presumably per-object paths already carry the prefix;
        // confirm this is intentional.
        if !prefix.is_empty() {
            config_map.insert("root".to_string(), "/".to_string());
        }

        let operator = Operator::from_iter::<S3>(config_map)
            .map_err(|e| Error::invalid_input(format!("Failed to create S3 operator: {:?}", e)))?
            .finish();

        Ok(Arc::new(OpendalStore::new(operator)) as Arc<dyn OSObjectStore>)
    }
}
124
125#[async_trait::async_trait]
126impl ObjectStoreProvider for AwsStoreProvider {
127    async fn new_store(
128        &self,
129        mut base_path: Url,
130        params: &ObjectStoreParams,
131    ) -> Result<ObjectStore> {
132        let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
133        let mut storage_options =
134            StorageOptions(params.storage_options().cloned().unwrap_or_default());
135        storage_options.with_env_s3();
136        let download_retry_count = storage_options.download_retry_count();
137
138        let use_opendal = storage_options
139            .0
140            .get("use_opendal")
141            .map(|v| v == "true")
142            .unwrap_or(false);
143
144        // Determine S3 Express and constant size upload parts before building the store
145        let is_s3_express = check_s3_express(&base_path, &storage_options);
146
147        let use_constant_size_upload_parts = storage_options
148            .0
149            .get("aws_endpoint")
150            .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
151            .unwrap_or(false);
152
153        let inner = if use_opendal {
154            // Use OpenDAL implementation
155            self.build_opendal_s3_store(&base_path, &storage_options)
156                .await?
157        } else {
158            // Use default Amazon S3 implementation
159            self.build_amazon_s3_store(&mut base_path, params, &storage_options, is_s3_express)
160                .await?
161        };
162
163        Ok(ObjectStore {
164            inner,
165            scheme: String::from(base_path.scheme()),
166            block_size,
167            max_iop_size: *DEFAULT_MAX_IOP_SIZE,
168            use_constant_size_upload_parts,
169            list_is_lexically_ordered: !is_s3_express,
170            io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
171            download_retry_count,
172            io_tracker: Default::default(),
173            store_prefix: self
174                .calculate_object_store_prefix(&base_path, params.storage_options())?,
175        })
176    }
177}
178
179/// Check if the storage is S3 Express
180fn check_s3_express(url: &Url, storage_options: &StorageOptions) -> bool {
181    storage_options
182        .0
183        .get("s3_express")
184        .map(|v| v == "true")
185        .unwrap_or(false)
186        || url.authority().ends_with("--x-s3")
187}
188
/// Figure out the S3 region of the bucket.
///
/// This resolves in order of precedence:
/// 1. The region provided in the storage options
/// 2. (If endpoint is not set), the region returned by the S3 API for the bucket
///
/// It can return None if no region is provided and the endpoint is set.
async fn resolve_s3_region(
    url: &Url,
    storage_options: &HashMap<AmazonS3ConfigKey, String>,
) -> Result<Option<String>> {
    if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
        Ok(Some(region.clone()))
    } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
        // If no endpoint is set, we can assume this is AWS S3 and the region
        // can be resolved from the bucket.
        let bucket = url.host_str().ok_or_else(|| {
            Error::invalid_input(format!("Could not parse bucket from url: {}", url))
        })?;

        // Carry over any client-level options (proxy, timeouts, ...) so the
        // region-lookup request uses the same HTTP configuration as the store.
        let mut client_options = ClientOptions::default();
        for (key, value) in storage_options {
            if let AmazonS3ConfigKey::Client(client_key) = key {
                client_options = client_options.with_config(*client_key, value.clone());
            }
        }

        // Asks the S3 service for the bucket's region (network request).
        let bucket_region =
            object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
        Ok(Some(bucket_region))
    } else {
        Ok(None)
    }
}
223
/// Build AWS credentials
///
/// This resolves credentials from the following sources in order:
/// 1. An explicit `storage_options_accessor` with a provider
/// 2. An explicit `credentials` provider
/// 3. Explicit credentials in storage_options (as in `aws_access_key_id`,
///    `aws_secret_access_key`, `aws_session_token`)
/// 4. The default credential provider chain from AWS SDK.
///
/// # Storage Options Accessor
///
/// When `storage_options_accessor` is provided and has a dynamic provider,
/// credentials are fetched and cached by the accessor with automatic refresh
/// before expiration.
///
/// `credentials_refresh_offset` is the amount of time before expiry to refresh credentials.
pub async fn build_aws_credential(
    credentials_refresh_offset: Duration,
    credentials: Option<AwsCredentialProvider>,
    storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
    region: Option<String>,
    storage_options_accessor: Option<Arc<StorageOptionsAccessor>>,
) -> Result<(AwsCredentialProvider, String)> {
    use aws_config::meta::region::RegionProviderChain;
    const DEFAULT_REGION: &str = "us-west-2";

    // Resolve the region: explicit argument first, then the AWS SDK default
    // region chain (env / profile), finally the hard-coded fallback.
    let region = if let Some(region) = region {
        region
    } else {
        RegionProviderChain::default_provider()
            .or_else(DEFAULT_REGION)
            .region()
            .await
            .map(|r| r.as_ref().to_string())
            .unwrap_or(DEFAULT_REGION.to_string())
    };

    // Static credentials embedded directly in the storage options, if any.
    let storage_options_credentials = storage_options.and_then(extract_static_s3_credentials);

    // If accessor has a provider, check whether it vends credentials.
    // If it does, use DynamicStorageOptionsCredentialProvider for ongoing
    // refresh. If not, fall through to the default credentials chain.
    if let Some(accessor) = storage_options_accessor
        && accessor.has_provider()
    {
        // Explicit aws_credentials takes precedence
        if let Some(creds) = credentials {
            return Ok((creds, region));
        }

        // Check if the accessor's storage options contain credentials
        let opts = accessor.get_storage_options().await?;
        let s3_options = opts.as_s3_options();
        if extract_static_s3_credentials(&s3_options).is_some() {
            return Ok((
                Arc::new(DynamicStorageOptionsCredentialProvider::new(accessor)),
                region,
            ));
        }

        log::debug!(
            "Storage options from provider do not contain explicit AWS credentials, \
             falling back to default AWS credentials chain."
        );
    }

    // Fall back to existing logic for static credentials
    if let Some(creds) = credentials {
        Ok((creds, region))
    } else if let Some(creds) = storage_options_credentials {
        Ok((Arc::new(creds), region))
    } else {
        // No explicit credentials anywhere: defer to the AWS SDK default
        // chain, wrapped in an adapter that caches and refreshes ahead of
        // expiry by `credentials_refresh_offset`.
        let credentials_provider = DefaultCredentialsChain::builder().build().await;

        Ok((
            Arc::new(AwsCredentialAdapter::new(
                Arc::new(credentials_provider),
                credentials_refresh_offset,
            )),
            region,
        ))
    }
}
307
308fn extract_static_s3_credentials(
309    options: &HashMap<AmazonS3ConfigKey, String>,
310) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
311    let key_id = options.get(&AmazonS3ConfigKey::AccessKeyId).cloned();
312    let secret_key = options.get(&AmazonS3ConfigKey::SecretAccessKey).cloned();
313    let token = options.get(&AmazonS3ConfigKey::Token).cloned();
314    match (key_id, secret_key, token) {
315        (Some(key_id), Some(secret_key), token) => {
316            Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
317                key_id,
318                secret_key,
319                token,
320            }))
321        }
322        _ => None,
323    }
324}
325
/// Adapt an AWS SDK cred into object_store credentials
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    /// The wrapped AWS SDK credentials provider (e.g. the default chain).
    pub inner: Arc<dyn ProvideCredentials>,

    // Credentials cache, keyed by `AWS_CREDS_CACHE_KEY`. An `RwLock`-guarded
    // map is used because this adapter is shared across threads (a `RefCell`
    // would not be `Sync`).
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    // The amount of time before expiry to refresh credentials
    credentials_refresh_offset: Duration,
}
337
338impl AwsCredentialAdapter {
339    pub fn new(
340        provider: Arc<dyn ProvideCredentials>,
341        credentials_refresh_offset: Duration,
342    ) -> Self {
343        Self {
344            inner: provider,
345            cache: Arc::new(RwLock::new(HashMap::new())),
346            credentials_refresh_offset,
347        }
348    }
349}
350
/// Cache key under which [`AwsCredentialAdapter`] stores its credentials.
const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";

/// Convert a `std::time::SystemTime` (as produced by the AWS SDK) into the
/// `SystemTime` type in scope here — the mockable clock under `cfg(test)`,
/// plain std time otherwise.
fn to_system_time(time: std::time::SystemTime) -> SystemTime {
    time.duration_since(std::time::UNIX_EPOCH)
        .map(|since_epoch| UNIX_EPOCH + since_epoch)
        .expect("time should be after UNIX_EPOCH")
}
360
361#[async_trait::async_trait]
362impl CredentialProvider for AwsCredentialAdapter {
363    type Credential = ObjectStoreAwsCredential;
364
365    async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
366        let cached_creds = {
367            let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
368            let expired = cache_value
369                .clone()
370                .map(|cred| {
371                    cred.expiry()
372                        .map(|exp| {
373                            to_system_time(exp)
374                                .checked_sub(self.credentials_refresh_offset)
375                                .expect("this time should always be valid")
376                                < SystemTime::now()
377                        })
378                        // no expiry is never expire
379                        .unwrap_or(false)
380                })
381                .unwrap_or(true); // no cred is the same as expired;
382            if expired { None } else { cache_value.clone() }
383        };
384
385        if let Some(creds) = cached_creds {
386            Ok(Arc::new(Self::Credential {
387                key_id: creds.access_key_id().to_string(),
388                secret_key: creds.secret_access_key().to_string(),
389                token: creds.session_token().map(|s| s.to_string()),
390            }))
391        } else {
392            let refreshed_creds =
393                Arc::new(self.inner.provide_credentials().await.map_err(|e| {
394                    Error::internal(format!("Failed to get AWS credentials: {:?}", e))
395                })?);
396
397            self.cache
398                .write()
399                .await
400                .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
401
402            Ok(Arc::new(Self::Credential {
403                key_id: refreshed_creds.access_key_id().to_string(),
404                secret_key: refreshed_creds.secret_access_key().to_string(),
405                token: refreshed_creds.session_token().map(|s| s.to_string()),
406            }))
407        }
408    }
409}
410
411impl StorageOptions {
412    /// Add values from the environment to storage options
413    pub fn with_env_s3(&mut self) {
414        for (os_key, os_value) in std::env::vars_os() {
415            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str())
416                && let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase())
417                && !self.0.contains_key(config_key.as_ref())
418            {
419                self.0
420                    .insert(config_key.as_ref().to_string(), value.to_string());
421            }
422        }
423    }
424
425    /// Subset of options relevant for s3 storage
426    pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
427        self.0
428            .iter()
429            .filter_map(|(key, value)| {
430                let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
431                Some((s3_key, value.clone()))
432            })
433            .collect()
434    }
435}
436
437impl ObjectStoreParams {
438    /// Create a new instance of [`ObjectStoreParams`] based on the AWS credentials.
439    pub fn with_aws_credentials(
440        aws_credentials: Option<AwsCredentialProvider>,
441        region: Option<String>,
442    ) -> Self {
443        let storage_options_accessor = region.map(|region| {
444            let opts: HashMap<String, String> =
445                [("region".into(), region)].iter().cloned().collect();
446            Arc::new(StorageOptionsAccessor::with_static_options(opts))
447        });
448        Self {
449            aws_credentials,
450            storage_options_accessor,
451            ..Default::default()
452        }
453    }
454}
455
/// AWS Credential Provider that delegates to StorageOptionsAccessor
///
/// This adapter converts storage options from a [`StorageOptionsAccessor`] into
/// AWS-specific credentials that can be used with S3. All caching and refresh logic
/// is handled by the accessor.
///
/// # Future Work
///
/// TODO: Support AWS/GCP/Azure together in a unified credential provider.
/// Currently this is AWS-specific. Needs investigation of how GCP and Azure credential
/// refresh mechanisms work and whether they can be unified with AWS's approach.
///
/// See: <https://github.com/lance-format/lance/pull/4905#discussion_r2474605265>
pub struct DynamicStorageOptionsCredentialProvider {
    // Source of truth for storage options; owns caching and refresh.
    accessor: Arc<StorageOptionsAccessor>,
}
472
// Manual `Debug` impl (the struct does not derive it); delegates to the
// accessor's own `Debug` output.
impl fmt::Debug for DynamicStorageOptionsCredentialProvider {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("DynamicStorageOptionsCredentialProvider")
            .field("accessor", &self.accessor)
            .finish()
    }
}
480
481impl DynamicStorageOptionsCredentialProvider {
482    /// Create a new credential provider from a storage options accessor
483    pub fn new(accessor: Arc<StorageOptionsAccessor>) -> Self {
484        Self { accessor }
485    }
486
487    /// Create a new credential provider from a storage options provider
488    ///
489    /// This is a convenience constructor for backward compatibility.
490    /// The refresh offset will be extracted from storage options using
491    /// the `refresh_offset_millis` key, defaulting to 60 seconds.
492    ///
493    /// # Arguments
494    /// * `provider` - The storage options provider
495    pub fn from_provider(provider: Arc<dyn StorageOptionsProvider>) -> Self {
496        Self {
497            accessor: Arc::new(StorageOptionsAccessor::with_provider(provider)),
498        }
499    }
500
501    /// Create a new credential provider with initial credentials
502    ///
503    /// This is a convenience constructor for backward compatibility.
504    /// The refresh offset will be extracted from initial_options using
505    /// the `refresh_offset_millis` key, defaulting to 60 seconds.
506    ///
507    /// # Arguments
508    /// * `provider` - The storage options provider
509    /// * `initial_options` - Initial storage options to cache
510    pub fn from_provider_with_initial(
511        provider: Arc<dyn StorageOptionsProvider>,
512        initial_options: HashMap<String, String>,
513    ) -> Self {
514        Self {
515            accessor: Arc::new(StorageOptionsAccessor::with_initial_and_provider(
516                initial_options,
517                provider,
518            )),
519        }
520    }
521}
522
523#[async_trait::async_trait]
524impl CredentialProvider for DynamicStorageOptionsCredentialProvider {
525    type Credential = ObjectStoreAwsCredential;
526
527    async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
528        let storage_options = self.accessor.get_storage_options().await.map_err(|e| {
529            object_store::Error::Generic {
530                store: "DynamicStorageOptionsCredentialProvider",
531                source: Box::new(e),
532            }
533        })?;
534
535        let s3_options = storage_options.as_s3_options();
536        let static_creds = extract_static_s3_credentials(&s3_options).ok_or_else(|| {
537            object_store::Error::Generic {
538                store: "DynamicStorageOptionsCredentialProvider",
539                source: "Missing required credentials in storage options".into(),
540            }
541        })?;
542
543        static_creds
544            .get_credential()
545            .await
546            .map_err(|e| object_store::Error::Generic {
547                store: "DynamicStorageOptionsCredentialProvider",
548                source: Box::new(e),
549            })
550    }
551}
552
553#[cfg(test)]
554mod tests {
555    use crate::object_store::ObjectStoreRegistry;
556    use mock_instant::thread_local::MockClock;
557    use object_store::path::Path;
558    use std::sync::atomic::{AtomicBool, Ordering};
559
560    use super::*;
561
562    #[derive(Debug, Default)]
563    struct MockAwsCredentialsProvider {
564        called: AtomicBool,
565    }
566
567    #[async_trait::async_trait]
568    impl CredentialProvider for MockAwsCredentialsProvider {
569        type Credential = ObjectStoreAwsCredential;
570
571        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
572            self.called.store(true, Ordering::Relaxed);
573            Ok(Arc::new(Self::Credential {
574                key_id: "".to_string(),
575                secret_key: "".to_string(),
576                token: None,
577            }))
578        }
579    }
580
581    #[tokio::test]
582    async fn test_injected_aws_creds_option_is_used() {
583        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
584        let registry = Arc::new(ObjectStoreRegistry::default());
585
586        let params = ObjectStoreParams {
587            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
588            ..ObjectStoreParams::default()
589        };
590
591        // Not called yet
592        assert!(!mock_provider.called.load(Ordering::Relaxed));
593
594        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
595            .await
596            .unwrap();
597
598        // fails, but we don't care
599        let _ = store
600            .open(&Path::parse("/").unwrap())
601            .await
602            .unwrap()
603            .get_range(0..1)
604            .await;
605
606        // Not called yet
607        assert!(mock_provider.called.load(Ordering::Relaxed));
608    }
609
610    #[test]
611    fn test_s3_path_parsing() {
612        let provider = AwsStoreProvider;
613
614        let cases = [
615            ("s3://bucket/path/to/file", "path/to/file"),
616            // for non ASCII string tests
617            ("s3://bucket/测试path/to/file", "测试path/to/file"),
618            ("s3://bucket/path/&to/file", "path/&to/file"),
619            ("s3://bucket/path/=to/file", "path/=to/file"),
620            (
621                "s3+ddb://bucket/path/to/file?ddbTableName=test",
622                "path/to/file",
623            ),
624        ];
625
626        for (uri, expected_path) in cases {
627            let url = Url::parse(uri).unwrap();
628            let path = provider.extract_path(&url).unwrap();
629            let expected_path = Path::from(expected_path);
630            assert_eq!(path, expected_path)
631        }
632    }
633
634    #[test]
635    fn test_is_s3_express() {
636        let cases = [
637            (
638                "s3://bucket/path/to/file",
639                HashMap::from([("s3_express".to_string(), "true".to_string())]),
640                true,
641            ),
642            (
643                "s3://bucket/path/to/file",
644                HashMap::from([("s3_express".to_string(), "false".to_string())]),
645                false,
646            ),
647            ("s3://bucket/path/to/file", HashMap::from([]), false),
648            (
649                "s3://bucket--x-s3/path/to/file",
650                HashMap::from([("s3_express".to_string(), "true".to_string())]),
651                true,
652            ),
653            (
654                "s3://bucket--x-s3/path/to/file",
655                HashMap::from([("s3_express".to_string(), "false".to_string())]),
656                true, // URL takes precedence
657            ),
658            ("s3://bucket--x-s3/path/to/file", HashMap::from([]), true),
659        ];
660
661        for (uri, storage_map, expected) in cases {
662            let url = Url::parse(uri).unwrap();
663            let storage_options = StorageOptions(storage_map);
664            let is_s3_express = check_s3_express(&url, &storage_options);
665            assert_eq!(is_s3_express, expected);
666        }
667    }
668
669    #[tokio::test]
670    async fn test_use_opendal_flag() {
671        use crate::object_store::StorageOptionsAccessor;
672        let provider = AwsStoreProvider;
673        let url = Url::parse("s3://test-bucket/path").unwrap();
674        let params_with_flag = ObjectStoreParams {
675            storage_options_accessor: Some(Arc::new(StorageOptionsAccessor::with_static_options(
676                HashMap::from([
677                    ("use_opendal".to_string(), "true".to_string()),
678                    ("region".to_string(), "us-west-2".to_string()),
679                ]),
680            ))),
681            ..Default::default()
682        };
683
684        let store = provider
685            .new_store(url.clone(), &params_with_flag)
686            .await
687            .unwrap();
688        assert_eq!(store.scheme, "s3");
689    }
690
691    #[derive(Debug)]
692    struct MockStorageOptionsProvider {
693        call_count: Arc<RwLock<usize>>,
694        expires_in_millis: Option<u64>,
695    }
696
697    impl MockStorageOptionsProvider {
698        fn new(expires_in_millis: Option<u64>) -> Self {
699            Self {
700                call_count: Arc::new(RwLock::new(0)),
701                expires_in_millis,
702            }
703        }
704
705        async fn get_call_count(&self) -> usize {
706            *self.call_count.read().await
707        }
708    }
709
710    #[async_trait::async_trait]
711    impl StorageOptionsProvider for MockStorageOptionsProvider {
712        async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
713            let count = {
714                let mut c = self.call_count.write().await;
715                *c += 1;
716                *c
717            };
718
719            let mut options = HashMap::from([
720                ("aws_access_key_id".to_string(), format!("AKID_{}", count)),
721                (
722                    "aws_secret_access_key".to_string(),
723                    format!("SECRET_{}", count),
724                ),
725                ("aws_session_token".to_string(), format!("TOKEN_{}", count)),
726            ]);
727
728            if let Some(expires_in) = self.expires_in_millis {
729                let now_ms = SystemTime::now()
730                    .duration_since(UNIX_EPOCH)
731                    .unwrap()
732                    .as_millis() as u64;
733                let expires_at = now_ms + expires_in;
734                options.insert("expires_at_millis".to_string(), expires_at.to_string());
735            }
736
737            Ok(Some(options))
738        }
739
740        fn provider_id(&self) -> String {
741            let ptr = Arc::as_ptr(&self.call_count) as usize;
742            format!("MockStorageOptionsProvider {{ id: {} }}", ptr)
743        }
744    }
745
746    #[tokio::test]
747    async fn test_dynamic_credential_provider_with_initial_cache() {
748        MockClock::set_system_time(Duration::from_secs(100_000));
749
750        let now_ms = MockClock::system_time().as_millis() as u64;
751
752        // Create a mock provider that returns credentials expiring in 10 minutes
753        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
754            600_000, // Expires in 10 minutes
755        )));
756
757        // Create initial options with cached credentials that expire in 10 minutes
758        let expires_at = now_ms + 600_000; // 10 minutes from now
759        let initial_options = HashMap::from([
760            ("aws_access_key_id".to_string(), "AKID_CACHED".to_string()),
761            (
762                "aws_secret_access_key".to_string(),
763                "SECRET_CACHED".to_string(),
764            ),
765            ("aws_session_token".to_string(), "TOKEN_CACHED".to_string()),
766            ("expires_at_millis".to_string(), expires_at.to_string()),
767            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minute refresh offset
768        ]);
769
770        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
771            mock.clone(),
772            initial_options,
773        );
774
775        // First call should use cached credentials (not expired yet)
776        let cred = provider.get_credential().await.unwrap();
777        assert_eq!(cred.key_id, "AKID_CACHED");
778        assert_eq!(cred.secret_key, "SECRET_CACHED");
779        assert_eq!(cred.token, Some("TOKEN_CACHED".to_string()));
780
781        // Should not have called the provider yet
782        assert_eq!(mock.get_call_count().await, 0);
783    }
784
785    #[tokio::test]
786    async fn test_dynamic_credential_provider_with_expired_cache() {
787        MockClock::set_system_time(Duration::from_secs(100_000));
788
789        let now_ms = MockClock::system_time().as_millis() as u64;
790
791        // Create a mock provider that returns credentials expiring in 10 minutes
792        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
793            600_000, // Expires in 10 minutes
794        )));
795
796        // Create initial options with credentials that expired 1 second ago
797        let expired_time = now_ms - 1_000; // 1 second ago
798        let initial_options = HashMap::from([
799            ("aws_access_key_id".to_string(), "AKID_EXPIRED".to_string()),
800            (
801                "aws_secret_access_key".to_string(),
802                "SECRET_EXPIRED".to_string(),
803            ),
804            ("expires_at_millis".to_string(), expired_time.to_string()),
805            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minute refresh offset
806        ]);
807
808        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
809            mock.clone(),
810            initial_options,
811        );
812
813        // First call should fetch new credentials because cached ones are expired
814        let cred = provider.get_credential().await.unwrap();
815        assert_eq!(cred.key_id, "AKID_1");
816        assert_eq!(cred.secret_key, "SECRET_1");
817        assert_eq!(cred.token, Some("TOKEN_1".to_string()));
818
819        // Should have called the provider once
820        assert_eq!(mock.get_call_count().await, 1);
821    }
822
823    #[tokio::test]
824    async fn test_dynamic_credential_provider_refresh_lead_time() {
825        MockClock::set_system_time(Duration::from_secs(100_000));
826
827        // Create a mock provider that returns credentials expiring in 30 seconds
828        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
829            30_000, // Expires in 30 seconds
830        )));
831
832        // Create credential provider with default 60 second refresh offset
833        // This means credentials should be refreshed when they have less than 60 seconds left
834        let provider = DynamicStorageOptionsCredentialProvider::from_provider(mock.clone());
835
836        // First call should fetch credentials from provider (no initial cache)
837        // Credentials expire in 30 seconds, which is less than our 60 second refresh offset,
838        // so they should be considered "needs refresh" immediately
839        let cred = provider.get_credential().await.unwrap();
840        assert_eq!(cred.key_id, "AKID_1");
841        assert_eq!(mock.get_call_count().await, 1);
842
843        // Second call should trigger refresh because credentials expire in 30 seconds
844        // but our refresh lead time is 60 seconds (now + 60sec > expires_at)
845        // The mock will return new credentials (AKID_2) with the same expiration
846        let cred = provider.get_credential().await.unwrap();
847        assert_eq!(cred.key_id, "AKID_2");
848        assert_eq!(mock.get_call_count().await, 2);
849    }
850
851    #[tokio::test]
852    async fn test_dynamic_credential_provider_no_initial_cache() {
853        MockClock::set_system_time(Duration::from_secs(100_000));
854
855        // Create a mock provider that returns credentials expiring in 2 minutes
856        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
857            120_000, // Expires in 2 minutes
858        )));
859
860        // Create credential provider without initial cache, using default 60 second refresh offset
861        let provider = DynamicStorageOptionsCredentialProvider::from_provider(mock.clone());
862
863        // First call should fetch from provider (call count = 1)
864        let cred = provider.get_credential().await.unwrap();
865        assert_eq!(cred.key_id, "AKID_1");
866        assert_eq!(cred.secret_key, "SECRET_1");
867        assert_eq!(cred.token, Some("TOKEN_1".to_string()));
868        assert_eq!(mock.get_call_count().await, 1);
869
870        // Second call should use cached credentials (not expired yet, still > 60 seconds remaining)
871        let cred = provider.get_credential().await.unwrap();
872        assert_eq!(cred.key_id, "AKID_1");
873        assert_eq!(mock.get_call_count().await, 1); // Still 1, didn't fetch again
874
875        // Advance time to 90 seconds - should trigger refresh (within 60 sec refresh offset)
876        // At this point, credentials expire in 30 seconds (< 60 sec offset)
877        MockClock::set_system_time(Duration::from_secs(100_000 + 90));
878        let cred = provider.get_credential().await.unwrap();
879        assert_eq!(cred.key_id, "AKID_2");
880        assert_eq!(cred.secret_key, "SECRET_2");
881        assert_eq!(cred.token, Some("TOKEN_2".to_string()));
882        assert_eq!(mock.get_call_count().await, 2);
883
884        // Advance time to 210 seconds total (90 + 120) - should trigger another refresh
885        MockClock::set_system_time(Duration::from_secs(100_000 + 210));
886        let cred = provider.get_credential().await.unwrap();
887        assert_eq!(cred.key_id, "AKID_3");
888        assert_eq!(cred.secret_key, "SECRET_3");
889        assert_eq!(mock.get_call_count().await, 3);
890    }
891
892    #[tokio::test]
893    async fn test_dynamic_credential_provider_with_initial_options() {
894        MockClock::set_system_time(Duration::from_secs(100_000));
895
896        let now_ms = MockClock::system_time().as_millis() as u64;
897
898        // Create a mock provider that returns credentials expiring in 10 minutes
899        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
900            600_000, // Expires in 10 minutes
901        )));
902
903        // Create initial options with expiration in 10 minutes
904        let expires_at = now_ms + 600_000; // 10 minutes from now
905        let initial_options = HashMap::from([
906            ("aws_access_key_id".to_string(), "AKID_INITIAL".to_string()),
907            (
908                "aws_secret_access_key".to_string(),
909                "SECRET_INITIAL".to_string(),
910            ),
911            ("aws_session_token".to_string(), "TOKEN_INITIAL".to_string()),
912            ("expires_at_millis".to_string(), expires_at.to_string()),
913            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minute refresh offset
914        ]);
915
916        // Create credential provider with initial options
917        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
918            mock.clone(),
919            initial_options,
920        );
921
922        // First call should use the initial credential (not expired yet)
923        let cred = provider.get_credential().await.unwrap();
924        assert_eq!(cred.key_id, "AKID_INITIAL");
925        assert_eq!(cred.secret_key, "SECRET_INITIAL");
926        assert_eq!(cred.token, Some("TOKEN_INITIAL".to_string()));
927
928        // Should not have called the provider yet
929        assert_eq!(mock.get_call_count().await, 0);
930
931        // Advance time to 6 minutes - this should trigger a refresh
932        // (5 minute refresh offset means we refresh 5 minutes before expiration)
933        MockClock::set_system_time(Duration::from_secs(100_000 + 360));
934        let cred = provider.get_credential().await.unwrap();
935        assert_eq!(cred.key_id, "AKID_1");
936        assert_eq!(cred.secret_key, "SECRET_1");
937        assert_eq!(cred.token, Some("TOKEN_1".to_string()));
938
939        // Should have called the provider once
940        assert_eq!(mock.get_call_count().await, 1);
941
942        // Advance time to 11 minutes total - this should trigger another refresh
943        MockClock::set_system_time(Duration::from_secs(100_000 + 660));
944        let cred = provider.get_credential().await.unwrap();
945        assert_eq!(cred.key_id, "AKID_2");
946        assert_eq!(cred.secret_key, "SECRET_2");
947        assert_eq!(cred.token, Some("TOKEN_2".to_string()));
948
949        // Should have called the provider twice
950        assert_eq!(mock.get_call_count().await, 2);
951
952        // Advance time to 16 minutes total - this should trigger yet another refresh
953        MockClock::set_system_time(Duration::from_secs(100_000 + 960));
954        let cred = provider.get_credential().await.unwrap();
955        assert_eq!(cred.key_id, "AKID_3");
956        assert_eq!(cred.secret_key, "SECRET_3");
957        assert_eq!(cred.token, Some("TOKEN_3".to_string()));
958
959        // Should have called the provider three times
960        assert_eq!(mock.get_call_count().await, 3);
961    }
962
963    #[tokio::test]
964    async fn test_dynamic_credential_provider_concurrent_access() {
965        // Create a mock provider with far future expiration
966        let mock = Arc::new(MockStorageOptionsProvider::new(Some(9999999999999)));
967
968        let provider = Arc::new(DynamicStorageOptionsCredentialProvider::from_provider(
969            mock.clone(),
970        ));
971
972        // Spawn 10 concurrent tasks that all try to get credentials at the same time
973        let mut handles = vec![];
974        for i in 0..10 {
975            let provider = provider.clone();
976            let handle = tokio::spawn(async move {
977                let cred = provider.get_credential().await.unwrap();
978                // Verify we got the correct credentials (should all be AKID_1 from first fetch)
979                assert_eq!(cred.key_id, "AKID_1");
980                assert_eq!(cred.secret_key, "SECRET_1");
981                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
982                i // Return task number for verification
983            });
984            handles.push(handle);
985        }
986
987        // Wait for all tasks to complete
988        let results: Vec<_> = futures::future::join_all(handles)
989            .await
990            .into_iter()
991            .map(|r| r.unwrap())
992            .collect();
993
994        // Verify all 10 tasks completed successfully
995        assert_eq!(results.len(), 10);
996        for i in 0..10 {
997            assert!(results.contains(&i));
998        }
999
1000        // The provider should have been called exactly once (first request triggers fetch,
1001        // subsequent requests use cache)
1002        let call_count = mock.get_call_count().await;
1003        assert_eq!(
1004            call_count, 1,
1005            "Provider should be called exactly once despite concurrent access"
1006        );
1007    }
1008
1009    #[tokio::test]
1010    async fn test_dynamic_credential_provider_concurrent_refresh() {
1011        MockClock::set_system_time(Duration::from_secs(100_000));
1012
1013        let now_ms = MockClock::system_time().as_millis() as u64;
1014
1015        // Create initial options with credentials that expired in the past (1000 seconds ago)
1016        let expires_at = now_ms - 1_000_000;
1017        let initial_options = HashMap::from([
1018            ("aws_access_key_id".to_string(), "AKID_OLD".to_string()),
1019            (
1020                "aws_secret_access_key".to_string(),
1021                "SECRET_OLD".to_string(),
1022            ),
1023            ("aws_session_token".to_string(), "TOKEN_OLD".to_string()),
1024            ("expires_at_millis".to_string(), expires_at.to_string()),
1025            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minute refresh offset
1026        ]);
1027
1028        // Mock will return credentials expiring in 1 hour
1029        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
1030            3_600_000, // Expires in 1 hour
1031        )));
1032
1033        let provider = Arc::new(
1034            DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
1035                mock.clone(),
1036                initial_options,
1037            ),
1038        );
1039
1040        // Spawn 20 concurrent tasks that all try to get credentials at the same time
1041        // Since the initial credential is expired, they'll all try to refresh
1042        let mut handles = vec![];
1043        for i in 0..20 {
1044            let provider = provider.clone();
1045            let handle = tokio::spawn(async move {
1046                let cred = provider.get_credential().await.unwrap();
1047                // All should get the new credentials (AKID_1 from first fetch)
1048                assert_eq!(cred.key_id, "AKID_1");
1049                assert_eq!(cred.secret_key, "SECRET_1");
1050                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
1051                i
1052            });
1053            handles.push(handle);
1054        }
1055
1056        // Wait for all tasks to complete
1057        let results: Vec<_> = futures::future::join_all(handles)
1058            .await
1059            .into_iter()
1060            .map(|r| r.unwrap())
1061            .collect();
1062
1063        // Verify all 20 tasks completed successfully
1064        assert_eq!(results.len(), 20);
1065
1066        // The provider should have been called at least once, but possibly more times
1067        // due to the try_write mechanism and race conditions
1068        let call_count = mock.get_call_count().await;
1069        assert!(
1070            call_count >= 1,
1071            "Provider should be called at least once, was called {} times",
1072            call_count
1073        );
1074
1075        // It shouldn't be called 20 times though - the lock should prevent most concurrent fetches
1076        assert!(
1077            call_count < 10,
1078            "Provider should not be called too many times due to lock contention, was called {} times",
1079            call_count
1080        );
1081    }
1082
1083    #[tokio::test]
1084    async fn test_explicit_aws_credentials_takes_precedence_over_accessor() {
1085        // Create a mock storage options provider that should NOT be called
1086        let mock_storage_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));
1087
1088        // Create an accessor with the mock provider
1089        let accessor = Arc::new(StorageOptionsAccessor::with_provider(
1090            mock_storage_provider.clone(),
1091        ));
1092
1093        // Create an explicit AWS credentials provider
1094        let explicit_cred_provider = Arc::new(MockAwsCredentialsProvider::default());
1095
1096        // Build credentials with both aws_credentials AND accessor
1097        // The explicit aws_credentials should take precedence
1098        let (result, _region) = build_aws_credential(
1099            Duration::from_secs(300),
1100            Some(explicit_cred_provider.clone() as AwsCredentialProvider),
1101            None, // no storage_options
1102            Some("us-west-2".to_string()),
1103            Some(accessor),
1104        )
1105        .await
1106        .unwrap();
1107
1108        // Get credential from the result
1109        let cred = result.get_credential().await.unwrap();
1110
1111        // The explicit provider should have been called (it returns empty strings)
1112        assert!(explicit_cred_provider.called.load(Ordering::Relaxed));
1113
1114        // The storage options provider should NOT have been called
1115        assert_eq!(
1116            mock_storage_provider.get_call_count().await,
1117            0,
1118            "Storage options provider should not be called when explicit aws_credentials is provided"
1119        );
1120
1121        // Verify we got credentials from the explicit provider (empty strings)
1122        assert_eq!(cred.key_id, "");
1123        assert_eq!(cred.secret_key, "");
1124    }
1125
1126    #[tokio::test]
1127    async fn test_accessor_used_when_no_explicit_aws_credentials() {
1128        MockClock::set_system_time(Duration::from_secs(100_000));
1129
1130        let now_ms = MockClock::system_time().as_millis() as u64;
1131
1132        // Create a mock storage options provider
1133        let mock_storage_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));
1134
1135        // Create initial options
1136        let expires_at = now_ms + 600_000; // 10 minutes from now
1137        let initial_options = HashMap::from([
1138            (
1139                "aws_access_key_id".to_string(),
1140                "AKID_FROM_ACCESSOR".to_string(),
1141            ),
1142            (
1143                "aws_secret_access_key".to_string(),
1144                "SECRET_FROM_ACCESSOR".to_string(),
1145            ),
1146            (
1147                "aws_session_token".to_string(),
1148                "TOKEN_FROM_ACCESSOR".to_string(),
1149            ),
1150            ("expires_at_millis".to_string(), expires_at.to_string()),
1151            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minute refresh offset
1152        ]);
1153
1154        // Create an accessor with initial options and provider
1155        let accessor = Arc::new(StorageOptionsAccessor::with_initial_and_provider(
1156            initial_options,
1157            mock_storage_provider.clone(),
1158        ));
1159
1160        // Build credentials with accessor but NO explicit aws_credentials
1161        let (result, _region) = build_aws_credential(
1162            Duration::from_secs(300),
1163            None, // no explicit aws_credentials
1164            None, // no storage_options
1165            Some("us-west-2".to_string()),
1166            Some(accessor),
1167        )
1168        .await
1169        .unwrap();
1170
1171        // Get credential - should use the initial accessor credentials
1172        let cred = result.get_credential().await.unwrap();
1173        assert_eq!(cred.key_id, "AKID_FROM_ACCESSOR");
1174        assert_eq!(cred.secret_key, "SECRET_FROM_ACCESSOR");
1175
1176        // Storage options provider should NOT have been called yet (using cached initial creds)
1177        assert_eq!(mock_storage_provider.get_call_count().await, 0);
1178
1179        // Advance time to trigger refresh (past the 5 minute refresh offset)
1180        MockClock::set_system_time(Duration::from_secs(100_000 + 360));
1181
1182        // Get credential again - should now fetch from provider
1183        let cred = result.get_credential().await.unwrap();
1184        assert_eq!(cred.key_id, "AKID_1");
1185        assert_eq!(cred.secret_key, "SECRET_1");
1186
1187        // Storage options provider should have been called once
1188        assert_eq!(mock_storage_provider.get_call_count().await, 1);
1189    }
1190}