1use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration};
5
6#[cfg(test)]
7use mock_instant::thread_local::{SystemTime, UNIX_EPOCH};
8
9#[cfg(not(test))]
10use std::time::{SystemTime, UNIX_EPOCH};
11
12use object_store::ObjectStore as OSObjectStore;
13use object_store_opendal::OpendalStore;
14use opendal::{Operator, services::S3};
15
16use aws_config::default_provider::credentials::DefaultCredentialsChain;
17use aws_credential_types::provider::ProvideCredentials;
18use object_store::{
19 ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
20 StaticCredentialProvider,
21 aws::{
22 AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
23 AwsCredentialProvider,
24 },
25};
26use tokio::sync::RwLock;
27use url::Url;
28
29use crate::object_store::{
30 DEFAULT_CLOUD_BLOCK_SIZE, DEFAULT_CLOUD_IO_PARALLELISM, DEFAULT_MAX_IOP_SIZE, ObjectStore,
31 ObjectStoreParams, ObjectStoreProvider, StorageOptions, StorageOptionsAccessor,
32 dynamic_credentials::{NamespaceCredentialsProvider, build_dynamic_credential_provider},
33 throttle::{AimdThrottleConfig, AimdThrottledStore},
34};
35use lance_core::error::{Error, Result};
36
/// [`ObjectStoreProvider`] for Amazon S3 (and S3-compatible) URLs.
///
/// The provider is stateless; all configuration comes from the URL and the
/// [`ObjectStoreParams`] passed to `new_store`.
#[derive(Debug, Default)]
pub struct AwsStoreProvider;
39
40impl AwsStoreProvider {
41 async fn build_amazon_s3_store(
42 &self,
43 base_path: &mut Url,
44 params: &ObjectStoreParams,
45 storage_options: &StorageOptions,
46 is_s3_express: bool,
47 ) -> Result<Arc<dyn OSObjectStore>> {
48 let retry_config = RetryConfig {
51 backoff: Default::default(),
52 max_retries: storage_options.client_max_retries(),
53 retry_timeout: Duration::from_secs(storage_options.client_retry_timeout()),
54 };
55
56 let mut s3_storage_options = storage_options.as_s3_options();
57 let region = resolve_s3_region(base_path, &s3_storage_options).await?;
58
59 let accessor = params.get_accessor();
61
62 let (aws_creds, region) = build_aws_credential(
63 params.s3_credentials_refresh_offset,
64 params.aws_credentials.clone(),
65 Some(&s3_storage_options),
66 region,
67 accessor,
68 )
69 .await?;
70
71 if is_s3_express {
73 s3_storage_options.insert(AmazonS3ConfigKey::S3Express, true.to_string());
74 }
75
76 base_path.set_scheme("s3").unwrap();
78 base_path.set_query(None);
79
80 let mut builder =
82 AmazonS3Builder::new().with_client_options(storage_options.client_options()?);
83 for (key, value) in s3_storage_options {
84 builder = builder.with_config(key, value);
85 }
86 builder = builder
87 .with_url(base_path.as_ref())
88 .with_credentials(aws_creds)
89 .with_retry(retry_config)
90 .with_region(region);
91
92 Ok(Arc::new(builder.build()?) as Arc<dyn OSObjectStore>)
93 }
94
95 async fn build_opendal_s3_store(
96 &self,
97 base_path: &Url,
98 storage_options: &StorageOptions,
99 ) -> Result<Arc<dyn OSObjectStore>> {
100 let bucket = base_path
101 .host_str()
102 .ok_or_else(|| Error::invalid_input("S3 URL must contain bucket name"))?
103 .to_string();
104
105 let prefix = base_path.path().trim_start_matches('/').to_string();
106
107 let mut config_map: HashMap<String, String> = storage_options.0.clone();
110
111 config_map.insert("bucket".to_string(), bucket);
113
114 if !prefix.is_empty() {
115 config_map.insert("root".to_string(), "/".to_string());
116 }
117
118 let operator = Operator::from_iter::<S3>(config_map)
119 .map_err(|e| Error::invalid_input(format!("Failed to create S3 operator: {:?}", e)))?
120 .finish();
121
122 Ok(Arc::new(OpendalStore::new(operator)) as Arc<dyn OSObjectStore>)
123 }
124}
125
126#[async_trait::async_trait]
127impl ObjectStoreProvider for AwsStoreProvider {
128 async fn new_store(
129 &self,
130 mut base_path: Url,
131 params: &ObjectStoreParams,
132 ) -> Result<ObjectStore> {
133 let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
134 let mut storage_options =
135 StorageOptions::new(params.storage_options().cloned().unwrap_or_default());
136 storage_options.with_env_s3();
137 let download_retry_count = storage_options.download_retry_count();
138
139 let use_opendal = storage_options
140 .0
141 .get("use_opendal")
142 .map(|v| v == "true")
143 .unwrap_or(false);
144
145 let is_s3_express = check_s3_express(&base_path, &storage_options);
147
148 let use_constant_size_upload_parts = storage_options
149 .0
150 .get("aws_endpoint")
151 .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
152 .unwrap_or(false);
153
154 let inner = if use_opendal {
155 self.build_opendal_s3_store(&base_path, &storage_options)
157 .await?
158 } else {
159 self.build_amazon_s3_store(&mut base_path, params, &storage_options, is_s3_express)
161 .await?
162 };
163 let throttle_config = AimdThrottleConfig::from_storage_options(params.storage_options())?;
164 let inner = if throttle_config.is_disabled() {
165 inner
166 } else if storage_options.client_max_retries() == 0 {
167 log::warn!(
168 "AIMD throttle disabled: the current implementation relies on the object store \
169 client surfacing retry errors, which requires client_max_retries > 0. \
170 No throttle or retry layer will be applied."
171 );
172 inner
173 } else {
174 Arc::new(AimdThrottledStore::new(inner, throttle_config)?) as Arc<dyn OSObjectStore>
175 };
176
177 Ok(ObjectStore {
178 inner,
179 scheme: String::from(base_path.scheme()),
180 block_size,
181 max_iop_size: *DEFAULT_MAX_IOP_SIZE,
182 use_constant_size_upload_parts,
183 list_is_lexically_ordered: !is_s3_express,
184 io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
185 download_retry_count,
186 io_tracker: Default::default(),
187 store_prefix: self
188 .calculate_object_store_prefix(&base_path, params.storage_options())?,
189 })
190 }
191}
192
193fn check_s3_express(url: &Url, storage_options: &StorageOptions) -> bool {
195 storage_options
196 .0
197 .get("s3_express")
198 .map(|v| v == "true")
199 .unwrap_or(false)
200 || url.authority().ends_with("--x-s3")
201}
202
203async fn resolve_s3_region(
211 url: &Url,
212 storage_options: &HashMap<AmazonS3ConfigKey, String>,
213) -> Result<Option<String>> {
214 if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
215 Ok(Some(region.clone()))
216 } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
217 let bucket = url.host_str().ok_or_else(|| {
220 Error::invalid_input(format!("Could not parse bucket from url: {}", url))
221 })?;
222
223 let mut client_options = ClientOptions::default();
224 for (key, value) in storage_options {
225 if let AmazonS3ConfigKey::Client(client_key) = key {
226 client_options = client_options.with_config(*client_key, value.clone());
227 }
228 }
229
230 let bucket_region =
231 object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
232 Ok(Some(bucket_region))
233 } else {
234 Ok(None)
235 }
236}
237
238pub async fn build_aws_credential(
255 credentials_refresh_offset: Duration,
256 credentials: Option<AwsCredentialProvider>,
257 storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
258 region: Option<String>,
259 storage_options_accessor: Option<Arc<StorageOptionsAccessor>>,
260) -> Result<(AwsCredentialProvider, String)> {
261 use aws_config::meta::region::RegionProviderChain;
262 const DEFAULT_REGION: &str = "us-west-2";
263
264 let region = if let Some(region) = region {
265 region
266 } else {
267 RegionProviderChain::default_provider()
268 .or_else(DEFAULT_REGION)
269 .region()
270 .await
271 .map(|r| r.as_ref().to_string())
272 .unwrap_or(DEFAULT_REGION.to_string())
273 };
274
275 let storage_options_credentials = storage_options.and_then(extract_static_s3_credentials);
276
277 if credentials.is_none()
279 && let Some(dynamic_creds) = build_dynamic_credential_provider::<ObjectStoreAwsCredential>(
280 storage_options_accessor.clone(),
281 )
282 .await?
283 {
284 return Ok((dynamic_creds, region));
285 }
286
287 if storage_options_accessor
288 .as_ref()
289 .is_some_and(|a| a.has_provider())
290 {
291 log::debug!(
292 "Storage options from provider do not contain explicit AWS credentials, \
293 falling back to default AWS credentials chain."
294 );
295 }
296
297 if let Some(creds) = credentials {
299 Ok((creds, region))
300 } else if let Some(creds) = storage_options_credentials {
301 Ok((Arc::new(creds), region))
302 } else {
303 let credentials_provider = DefaultCredentialsChain::builder().build().await;
304
305 Ok((
306 Arc::new(AwsCredentialAdapter::new(
307 Arc::new(credentials_provider),
308 credentials_refresh_offset,
309 )),
310 region,
311 ))
312 }
313}
314
315fn extract_static_s3_credentials(
316 options: &HashMap<AmazonS3ConfigKey, String>,
317) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
318 let key_id = options.get(&AmazonS3ConfigKey::AccessKeyId).cloned();
319 let secret_key = options.get(&AmazonS3ConfigKey::SecretAccessKey).cloned();
320 let token = options.get(&AmazonS3ConfigKey::Token).cloned();
321 match (key_id, secret_key, token) {
322 (Some(key_id), Some(secret_key), token) => {
323 Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
324 key_id,
325 secret_key,
326 token,
327 }))
328 }
329 _ => None,
330 }
331}
332
/// Adapts an AWS SDK [`ProvideCredentials`] implementation to `object_store`'s
/// [`CredentialProvider`] trait, caching the credentials it fetches.
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    /// The AWS SDK provider queried when no fresh cached credentials are available.
    pub inner: Arc<dyn ProvideCredentials>,

    // Cached credentials; in this file only the `AWS_CREDS_CACHE_KEY` entry is read/written.
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    // How long before their expiry cached credentials are treated as stale and re-fetched.
    credentials_refresh_offset: Duration,
}
344
345impl AwsCredentialAdapter {
346 pub fn new(
347 provider: Arc<dyn ProvideCredentials>,
348 credentials_refresh_offset: Duration,
349 ) -> Self {
350 Self {
351 inner: provider,
352 cache: Arc::new(RwLock::new(HashMap::new())),
353 credentials_refresh_offset,
354 }
355 }
356}
357
/// Cache key under which [`AwsCredentialAdapter`] stores the most recently fetched credentials.
const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";
359
/// Converts a `std::time::SystemTime` into this module's `SystemTime`.
///
/// The module's `SystemTime` is the mock clock under `cfg(test)` and std otherwise. The
/// conversion goes through the duration since the Unix epoch.
///
/// # Panics
/// Panics if `time` is earlier than the Unix epoch.
fn to_system_time(time: std::time::SystemTime) -> SystemTime {
    UNIX_EPOCH
        + time
            .duration_since(std::time::UNIX_EPOCH)
            .expect("time should be after UNIX_EPOCH")
}
367
368#[async_trait::async_trait]
369impl CredentialProvider for AwsCredentialAdapter {
370 type Credential = ObjectStoreAwsCredential;
371
372 async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
373 let cached_creds = {
374 let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
375 let expired = cache_value
376 .clone()
377 .map(|cred| {
378 cred.expiry()
379 .map(|exp| {
380 to_system_time(exp)
381 .checked_sub(self.credentials_refresh_offset)
382 .expect("this time should always be valid")
383 < SystemTime::now()
384 })
385 .unwrap_or(false)
387 })
388 .unwrap_or(true); if expired { None } else { cache_value.clone() }
390 };
391
392 if let Some(creds) = cached_creds {
393 Ok(Arc::new(Self::Credential {
394 key_id: creds.access_key_id().to_string(),
395 secret_key: creds.secret_access_key().to_string(),
396 token: creds.session_token().map(|s| s.to_string()),
397 }))
398 } else {
399 let refreshed_creds =
400 Arc::new(self.inner.provide_credentials().await.map_err(|e| {
401 Error::internal(format!("Failed to get AWS credentials: {:?}", e))
402 })?);
403
404 self.cache
405 .write()
406 .await
407 .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
408
409 Ok(Arc::new(Self::Credential {
410 key_id: refreshed_creds.access_key_id().to_string(),
411 secret_key: refreshed_creds.secret_access_key().to_string(),
412 token: refreshed_creds.session_token().map(|s| s.to_string()),
413 }))
414 }
415 }
416}
417
418impl StorageOptions {
419 pub fn with_env_s3(&mut self) {
421 for (os_key, os_value) in std::env::vars_os() {
422 if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str())
423 && let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase())
424 && !self.0.contains_key(config_key.as_ref())
425 {
426 self.0
427 .insert(config_key.as_ref().to_string(), value.to_string());
428 }
429 }
430 }
431
432 pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
434 self.0
435 .iter()
436 .filter_map(|(key, value)| {
437 let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
438 Some((s3_key, value.clone()))
439 })
440 .collect()
441 }
442}
443
444impl ObjectStoreParams {
445 pub fn with_aws_credentials(
447 aws_credentials: Option<AwsCredentialProvider>,
448 region: Option<String>,
449 ) -> Self {
450 let storage_options_accessor = region.map(|region| {
451 let opts: HashMap<String, String> =
452 [("region".into(), region)].iter().cloned().collect();
453 Arc::new(StorageOptionsAccessor::with_static_options(opts))
454 });
455 Self {
456 aws_credentials,
457 storage_options_accessor,
458 ..Default::default()
459 }
460 }
461}
462
/// AWS credential provider whose credentials come from storage options
/// (a [`NamespaceCredentialsProvider`] specialized to `object_store`'s AWS credential type).
pub type DynamicStorageOptionsCredentialProvider =
    NamespaceCredentialsProvider<ObjectStoreAwsCredential>;
465
#[cfg(test)]
mod tests {
    use crate::object_store::ObjectStoreRegistry;
    use crate::object_store::StorageOptionsProvider;
    use mock_instant::thread_local::MockClock;
    use object_store::path::Path;
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::*;

    /// Credential provider that records whether it has been asked for credentials.
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait::async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }

    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // Credentials are not requested until a request is made.
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // The request result is irrelevant; it only has to trigger a credential lookup.
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    #[test]
    fn test_s3_path_parsing() {
        let provider = AwsStoreProvider;

        let cases = [
            ("s3://bucket/path/to/file", "path/to/file"),
            ("s3://bucket/测试path/to/file", "测试path/to/file"),
            ("s3://bucket/path/&to/file", "path/&to/file"),
            ("s3://bucket/path/=to/file", "path/=to/file"),
            (
                "s3+ddb://bucket/path/to/file?ddbTableName=test",
                "path/to/file",
            ),
        ];

        for (uri, expected_path) in cases {
            let url = Url::parse(uri).unwrap();
            let path = provider.extract_path(&url).unwrap();
            let expected_path = Path::from(expected_path);
            assert_eq!(path, expected_path)
        }
    }

    #[test]
    fn test_is_s3_express() {
        let cases = [
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                false,
            ),
            ("s3://bucket/path/to/file", HashMap::from([]), false),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                // The bucket-name suffix wins over an explicit "false".
                true,
            ),
            ("s3://bucket--x-s3/path/to/file", HashMap::from([]), true),
        ];

        for (uri, storage_map, expected) in cases {
            let url = Url::parse(uri).unwrap();
            let storage_options = StorageOptions(storage_map);
            let is_s3_express = check_s3_express(&url, &storage_options);
            assert_eq!(is_s3_express, expected);
        }
    }

    #[tokio::test]
    async fn test_use_opendal_flag() {
        use crate::object_store::StorageOptionsAccessor;
        let provider = AwsStoreProvider;
        let url = Url::parse("s3://test-bucket/path").unwrap();
        let params_with_flag = ObjectStoreParams {
            storage_options_accessor: Some(Arc::new(StorageOptionsAccessor::with_static_options(
                HashMap::from([
                    ("use_opendal".to_string(), "true".to_string()),
                    ("region".to_string(), "us-west-2".to_string()),
                ]),
            ))),
            ..Default::default()
        };

        let store = provider
            .new_store(url.clone(), &params_with_flag)
            .await
            .unwrap();
        assert_eq!(store.scheme, "s3");
    }

    /// Storage options provider that hands out numbered credentials and counts its calls.
    #[derive(Debug)]
    struct MockStorageOptionsProvider {
        call_count: Arc<RwLock<usize>>,
        expires_in_millis: Option<u64>,
    }

    impl MockStorageOptionsProvider {
        fn new(expires_in_millis: Option<u64>) -> Self {
            Self {
                call_count: Arc::new(RwLock::new(0)),
                expires_in_millis,
            }
        }

        async fn get_call_count(&self) -> usize {
            *self.call_count.read().await
        }
    }

    #[async_trait::async_trait]
    impl StorageOptionsProvider for MockStorageOptionsProvider {
        async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
            let count = {
                let mut c = self.call_count.write().await;
                *c += 1;
                *c
            };

            let mut options = HashMap::from([
                ("aws_access_key_id".to_string(), format!("AKID_{}", count)),
                (
                    "aws_secret_access_key".to_string(),
                    format!("SECRET_{}", count),
                ),
                ("aws_session_token".to_string(), format!("TOKEN_{}", count)),
            ]);

            if let Some(expires_in) = self.expires_in_millis {
                let now_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .unwrap()
                    .as_millis() as u64;
                let expires_at = now_ms + expires_in;
                options.insert("expires_at_millis".to_string(), expires_at.to_string());
            }

            Ok(Some(options))
        }

        fn provider_id(&self) -> String {
            let ptr = Arc::as_ptr(&self.call_count) as usize;
            format!("MockStorageOptionsProvider {{ id: {} }}", ptr)
        }
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_with_initial_cache() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            600_000, // 10 minutes
        )));

        let expires_at = now_ms + 600_000; // 10 minutes from now
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_CACHED".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_CACHED".to_string(),
            ),
            ("aws_session_token".to_string(), "TOKEN_CACHED".to_string()),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minutes
        ]);

        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
            mock.clone(),
            initial_options,
        );

        // Still fresh: served from the initial options without calling the provider.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_CACHED");
        assert_eq!(cred.secret_key, "SECRET_CACHED");
        assert_eq!(cred.token, Some("TOKEN_CACHED".to_string()));

        assert_eq!(mock.get_call_count().await, 0);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_with_expired_cache() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            600_000, // 10 minutes
        )));

        let expired_time = now_ms - 1_000; // 1 second ago
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_EXPIRED".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_EXPIRED".to_string(),
            ),
            ("expires_at_millis".to_string(), expired_time.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minutes
        ]);

        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
            mock.clone(),
            initial_options,
        );

        // Expired initial options force a fetch from the provider.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));

        assert_eq!(mock.get_call_count().await, 1);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_refresh_lead_time() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            30_000, // 30 seconds
        )));

        let provider = DynamicStorageOptionsCredentialProvider::from_provider(mock.clone());

        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(mock.get_call_count().await, 1);

        // Credentials this short-lived are already within the refresh window, so the next
        // call fetches again even though the clock has not moved.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(mock.get_call_count().await, 2);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_no_initial_cache() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            120_000, // 2 minutes
        )));

        let provider = DynamicStorageOptionsCredentialProvider::from_provider(mock.clone());

        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));
        assert_eq!(mock.get_call_count().await, 1);

        // Same instant: served from cache.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(mock.get_call_count().await, 1);

        // 90s later only 30s of validity remain, which triggers a refresh.
        MockClock::set_system_time(Duration::from_secs(100_000 + 90));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(cred.secret_key, "SECRET_2");
        assert_eq!(cred.token, Some("TOKEN_2".to_string()));
        assert_eq!(mock.get_call_count().await, 2);

        // The second credentials expire at +210s; at that point they are refreshed again.
        MockClock::set_system_time(Duration::from_secs(100_000 + 210));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_3");
        assert_eq!(cred.secret_key, "SECRET_3");
        assert_eq!(mock.get_call_count().await, 3);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_with_initial_options() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            600_000, // 10 minutes
        )));

        let expires_at = now_ms + 600_000; // 10 minutes from now
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_INITIAL".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_INITIAL".to_string(),
            ),
            ("aws_session_token".to_string(), "TOKEN_INITIAL".to_string()),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minutes
        ]);

        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
            mock.clone(),
            initial_options,
        );

        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_INITIAL");
        assert_eq!(cred.secret_key, "SECRET_INITIAL");
        assert_eq!(cred.token, Some("TOKEN_INITIAL".to_string()));

        assert_eq!(mock.get_call_count().await, 0);

        // +6 min: inside the 5-minute refresh window of the initial credentials.
        MockClock::set_system_time(Duration::from_secs(100_000 + 360));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));

        assert_eq!(mock.get_call_count().await, 1);

        // +11 min: inside the refresh window of AKID_1 (expires at +16 min).
        MockClock::set_system_time(Duration::from_secs(100_000 + 660));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(cred.secret_key, "SECRET_2");
        assert_eq!(cred.token, Some("TOKEN_2".to_string()));

        assert_eq!(mock.get_call_count().await, 2);

        // +16 min: inside the refresh window of AKID_2 (expires at +21 min).
        MockClock::set_system_time(Duration::from_secs(100_000 + 960));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_3");
        assert_eq!(cred.secret_key, "SECRET_3");
        assert_eq!(cred.token, Some("TOKEN_3".to_string()));

        assert_eq!(mock.get_call_count().await, 3);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_concurrent_access() {
        // Effectively never expires.
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(9999999999999)));

        let provider = Arc::new(DynamicStorageOptionsCredentialProvider::from_provider(
            mock.clone(),
        ));

        let mut handles = vec![];
        for i in 0..10 {
            let provider = provider.clone();
            let handle = tokio::spawn(async move {
                let cred = provider.get_credential().await.unwrap();
                assert_eq!(cred.key_id, "AKID_1");
                assert_eq!(cred.secret_key, "SECRET_1");
                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
                i
            });
            handles.push(handle);
        }

        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        assert_eq!(results.len(), 10);
        for i in 0..10 {
            assert!(results.contains(&i));
        }

        let call_count = mock.get_call_count().await;
        assert_eq!(
            call_count, 1,
            "Provider should be called exactly once despite concurrent access"
        );
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_concurrent_refresh() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        // Initial credentials expired long ago, so every task needs a refresh.
        let expires_at = now_ms - 1_000_000;
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_OLD".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_OLD".to_string(),
            ),
            ("aws_session_token".to_string(), "TOKEN_OLD".to_string()),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minutes
        ]);

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            3_600_000, // 1 hour
        )));

        let provider = Arc::new(
            DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
                mock.clone(),
                initial_options,
            ),
        );

        let mut handles = vec![];
        for i in 0..20 {
            let provider = provider.clone();
            let handle = tokio::spawn(async move {
                let cred = provider.get_credential().await.unwrap();
                assert_eq!(cred.key_id, "AKID_1");
                assert_eq!(cred.secret_key, "SECRET_1");
                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
                i
            });
            handles.push(handle);
        }

        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        assert_eq!(results.len(), 20);

        let call_count = mock.get_call_count().await;
        assert!(
            call_count >= 1,
            "Provider should be called at least once, was called {} times",
            call_count
        );

        assert!(
            call_count < 10,
            "Provider should not be called too many times due to lock contention, was called {} times",
            call_count
        );
    }

    #[tokio::test]
    async fn test_explicit_aws_credentials_takes_precedence_over_accessor() {
        let mock_storage_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        let accessor = Arc::new(StorageOptionsAccessor::with_provider(
            mock_storage_provider.clone(),
        ));

        let explicit_cred_provider = Arc::new(MockAwsCredentialsProvider::default());

        let (result, _region) = build_aws_credential(
            Duration::from_secs(300),
            Some(explicit_cred_provider.clone() as AwsCredentialProvider),
            None, // storage_options
            Some("us-west-2".to_string()),
            Some(accessor),
        )
        .await
        .unwrap();

        let cred = result.get_credential().await.unwrap();

        assert!(explicit_cred_provider.called.load(Ordering::Relaxed));

        assert_eq!(
            mock_storage_provider.get_call_count().await,
            0,
            "Storage options provider should not be called when explicit aws_credentials is provided"
        );

        // The mock explicit provider returns empty credentials.
        assert_eq!(cred.key_id, "");
        assert_eq!(cred.secret_key, "");
    }

    #[tokio::test]
    async fn test_accessor_used_when_no_explicit_aws_credentials() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        let mock_storage_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        let expires_at = now_ms + 600_000; // 10 minutes from now
        let initial_options = HashMap::from([
            (
                "aws_access_key_id".to_string(),
                "AKID_FROM_ACCESSOR".to_string(),
            ),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_FROM_ACCESSOR".to_string(),
            ),
            (
                "aws_session_token".to_string(),
                "TOKEN_FROM_ACCESSOR".to_string(),
            ),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minutes
        ]);

        let accessor = Arc::new(StorageOptionsAccessor::with_initial_and_provider(
            initial_options,
            mock_storage_provider.clone(),
        ));

        let (result, _region) = build_aws_credential(
            Duration::from_secs(300),
            None, // aws_credentials
            None, // storage_options
            Some("us-west-2".to_string()),
            Some(accessor),
        )
        .await
        .unwrap();

        let cred = result.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_FROM_ACCESSOR");
        assert_eq!(cred.secret_key, "SECRET_FROM_ACCESSOR");

        assert_eq!(mock_storage_provider.get_call_count().await, 0);

        // +6 min: inside the refresh window, so the accessor's provider is consulted.
        MockClock::set_system_time(Duration::from_secs(100_000 + 360));

        let cred = result.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");

        assert_eq!(mock_storage_provider.get_call_count().await, 1);
    }
}