1use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration};
5
6#[cfg(test)]
7use mock_instant::thread_local::{SystemTime, UNIX_EPOCH};
8
9#[cfg(not(test))]
10use std::time::{SystemTime, UNIX_EPOCH};
11
12use object_store::ObjectStore as OSObjectStore;
13use object_store_opendal::OpendalStore;
14use opendal::{Operator, services::S3};
15
16use aws_config::default_provider::credentials::DefaultCredentialsChain;
17use aws_credential_types::provider::ProvideCredentials;
18use object_store::{
19 ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
20 StaticCredentialProvider,
21 aws::{
22 AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
23 AwsCredentialProvider,
24 },
25};
26use tokio::sync::RwLock;
27use url::Url;
28
29use crate::object_store::{
30 DEFAULT_CLOUD_BLOCK_SIZE, DEFAULT_CLOUD_IO_PARALLELISM, DEFAULT_MAX_IOP_SIZE, ObjectStore,
31 ObjectStoreParams, ObjectStoreProvider, StorageOptions, StorageOptionsAccessor,
32 dynamic_credentials::{NamespaceCredentialsProvider, build_dynamic_credential_provider},
33 throttle::{AimdThrottleConfig, AimdThrottledStore},
34};
35use lance_core::error::{Error, Result};
36
/// Object store provider for `s3://`-style URLs, backed either by the native
/// `object_store` Amazon S3 client or by OpenDAL.
#[derive(Debug, Default)]
pub struct AwsStoreProvider;
39
40impl AwsStoreProvider {
41 async fn build_amazon_s3_store(
42 &self,
43 base_path: &mut Url,
44 params: &ObjectStoreParams,
45 storage_options: &StorageOptions,
46 is_s3_express: bool,
47 ) -> Result<Arc<dyn OSObjectStore>> {
48 let retry_config = RetryConfig {
51 backoff: Default::default(),
52 max_retries: storage_options.client_max_retries(),
53 retry_timeout: Duration::from_secs(storage_options.client_retry_timeout()),
54 };
55
56 let mut s3_storage_options = storage_options.as_s3_options();
57 let region = resolve_s3_region(base_path, &s3_storage_options).await?;
58
59 let accessor = params.get_accessor();
61
62 let (aws_creds, region) = build_aws_credential(
63 params.s3_credentials_refresh_offset,
64 params.aws_credentials.clone(),
65 Some(&s3_storage_options),
66 region,
67 accessor,
68 )
69 .await?;
70
71 if is_s3_express {
73 s3_storage_options.insert(AmazonS3ConfigKey::S3Express, true.to_string());
74 }
75
76 base_path.set_scheme("s3").unwrap();
78 base_path.set_query(None);
79
80 let mut builder =
82 AmazonS3Builder::new().with_client_options(storage_options.client_options()?);
83 for (key, value) in s3_storage_options {
84 builder = builder.with_config(key, value);
85 }
86 builder = builder
87 .with_url(base_path.as_ref())
88 .with_credentials(aws_creds)
89 .with_retry(retry_config)
90 .with_region(region);
91
92 Ok(Arc::new(builder.build()?) as Arc<dyn OSObjectStore>)
93 }
94
95 async fn build_opendal_s3_store(
96 &self,
97 base_path: &Url,
98 storage_options: &StorageOptions,
99 ) -> Result<Arc<dyn OSObjectStore>> {
100 let bucket = base_path
101 .host_str()
102 .ok_or_else(|| Error::invalid_input("S3 URL must contain bucket name"))?
103 .to_string();
104
105 let prefix = base_path.path().trim_start_matches('/').to_string();
106
107 let mut config_map: HashMap<String, String> = storage_options.0.clone();
110
111 config_map.insert("bucket".to_string(), bucket);
113
114 if !prefix.is_empty() {
115 config_map.insert("root".to_string(), "/".to_string());
116 }
117
118 let operator = Operator::from_iter::<S3>(config_map)
119 .map_err(|e| Error::invalid_input(format!("Failed to create S3 operator: {:?}", e)))?
120 .finish();
121
122 Ok(Arc::new(OpendalStore::new(operator)) as Arc<dyn OSObjectStore>)
123 }
124}
125
126#[async_trait::async_trait]
127impl ObjectStoreProvider for AwsStoreProvider {
128 async fn new_store(
129 &self,
130 mut base_path: Url,
131 params: &ObjectStoreParams,
132 ) -> Result<ObjectStore> {
133 let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
134 let mut storage_options =
135 StorageOptions::new(params.storage_options().cloned().unwrap_or_default());
136 storage_options.with_env_s3();
137 let download_retry_count = storage_options.download_retry_count();
138
139 let use_opendal = storage_options
140 .0
141 .get("use_opendal")
142 .map(|v| v == "true")
143 .unwrap_or(false);
144
145 let is_s3_express = check_s3_express(&base_path, &storage_options);
147
148 let use_constant_size_upload_parts = storage_options
149 .0
150 .get("aws_endpoint")
151 .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
152 .unwrap_or(false);
153
154 let inner = if use_opendal {
155 self.build_opendal_s3_store(&base_path, &storage_options)
157 .await?
158 } else {
159 self.build_amazon_s3_store(&mut base_path, params, &storage_options, is_s3_express)
161 .await?
162 };
163 let throttle_config = AimdThrottleConfig::from_storage_options(params.storage_options())?;
164 let inner = if throttle_config.is_disabled() {
165 inner
166 } else {
167 Arc::new(AimdThrottledStore::new(inner, throttle_config)?) as Arc<dyn OSObjectStore>
168 };
169
170 Ok(ObjectStore {
171 inner,
172 scheme: String::from(base_path.scheme()),
173 block_size,
174 max_iop_size: *DEFAULT_MAX_IOP_SIZE,
175 use_constant_size_upload_parts,
176 list_is_lexically_ordered: !is_s3_express,
177 io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
178 download_retry_count,
179 io_tracker: Default::default(),
180 store_prefix: self
181 .calculate_object_store_prefix(&base_path, params.storage_options())?,
182 })
183 }
184}
185
186fn check_s3_express(url: &Url, storage_options: &StorageOptions) -> bool {
188 storage_options
189 .0
190 .get("s3_express")
191 .map(|v| v == "true")
192 .unwrap_or(false)
193 || url.authority().ends_with("--x-s3")
194}
195
196async fn resolve_s3_region(
204 url: &Url,
205 storage_options: &HashMap<AmazonS3ConfigKey, String>,
206) -> Result<Option<String>> {
207 if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
208 Ok(Some(region.clone()))
209 } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
210 let bucket = url.host_str().ok_or_else(|| {
213 Error::invalid_input(format!("Could not parse bucket from url: {}", url))
214 })?;
215
216 let mut client_options = ClientOptions::default();
217 for (key, value) in storage_options {
218 if let AmazonS3ConfigKey::Client(client_key) = key {
219 client_options = client_options.with_config(*client_key, value.clone());
220 }
221 }
222
223 let bucket_region =
224 object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
225 Ok(Some(bucket_region))
226 } else {
227 Ok(None)
228 }
229}
230
231pub async fn build_aws_credential(
248 credentials_refresh_offset: Duration,
249 credentials: Option<AwsCredentialProvider>,
250 storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
251 region: Option<String>,
252 storage_options_accessor: Option<Arc<StorageOptionsAccessor>>,
253) -> Result<(AwsCredentialProvider, String)> {
254 use aws_config::meta::region::RegionProviderChain;
255 const DEFAULT_REGION: &str = "us-west-2";
256
257 let region = if let Some(region) = region {
258 region
259 } else {
260 RegionProviderChain::default_provider()
261 .or_else(DEFAULT_REGION)
262 .region()
263 .await
264 .map(|r| r.as_ref().to_string())
265 .unwrap_or(DEFAULT_REGION.to_string())
266 };
267
268 let storage_options_credentials = storage_options.and_then(extract_static_s3_credentials);
269
270 if credentials.is_none()
272 && let Some(dynamic_creds) = build_dynamic_credential_provider::<ObjectStoreAwsCredential>(
273 storage_options_accessor.clone(),
274 )
275 .await?
276 {
277 return Ok((dynamic_creds, region));
278 }
279
280 if storage_options_accessor
281 .as_ref()
282 .is_some_and(|a| a.has_provider())
283 {
284 log::debug!(
285 "Storage options from provider do not contain explicit AWS credentials, \
286 falling back to default AWS credentials chain."
287 );
288 }
289
290 if let Some(creds) = credentials {
292 Ok((creds, region))
293 } else if let Some(creds) = storage_options_credentials {
294 Ok((Arc::new(creds), region))
295 } else {
296 let credentials_provider = DefaultCredentialsChain::builder().build().await;
297
298 Ok((
299 Arc::new(AwsCredentialAdapter::new(
300 Arc::new(credentials_provider),
301 credentials_refresh_offset,
302 )),
303 region,
304 ))
305 }
306}
307
308fn extract_static_s3_credentials(
309 options: &HashMap<AmazonS3ConfigKey, String>,
310) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
311 let key_id = options.get(&AmazonS3ConfigKey::AccessKeyId).cloned();
312 let secret_key = options.get(&AmazonS3ConfigKey::SecretAccessKey).cloned();
313 let token = options.get(&AmazonS3ConfigKey::Token).cloned();
314 match (key_id, secret_key, token) {
315 (Some(key_id), Some(secret_key), token) => {
316 Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
317 key_id,
318 secret_key,
319 token,
320 }))
321 }
322 _ => None,
323 }
324}
325
/// Adapts an AWS SDK [`ProvideCredentials`] implementation into an `object_store`
/// [`CredentialProvider`], caching the most recently fetched credentials.
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    /// The wrapped AWS SDK credentials provider, queried when the cache is empty or stale.
    pub inner: Arc<dyn ProvideCredentials>,

    /// Cached credentials; only the entry under `AWS_CREDS_CACHE_KEY` is used.
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    /// How long before the credentials' expiry they are treated as stale and re-fetched.
    credentials_refresh_offset: Duration,
}
337
338impl AwsCredentialAdapter {
339 pub fn new(
340 provider: Arc<dyn ProvideCredentials>,
341 credentials_refresh_offset: Duration,
342 ) -> Self {
343 Self {
344 inner: provider,
345 cache: Arc::new(RwLock::new(HashMap::new())),
346 credentials_refresh_offset,
347 }
348 }
349}
350
/// Key of the single entry in `AwsCredentialAdapter`'s credential cache.
const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";

/// Converts a `std::time::SystemTime` into the `SystemTime` type used by this module.
///
/// Outside tests the two are the same type. Under `cfg(test)` the target is the
/// mock clock's `SystemTime`, so the value is rebuilt from its offset to the Unix epoch.
///
/// # Panics
/// Panics if `time` is earlier than the Unix epoch.
fn to_system_time(time: std::time::SystemTime) -> SystemTime {
    let offset_from_epoch = time
        .duration_since(std::time::UNIX_EPOCH)
        .expect("time should be after UNIX_EPOCH");
    let converted = UNIX_EPOCH + offset_from_epoch;
    converted
}
360
361#[async_trait::async_trait]
362impl CredentialProvider for AwsCredentialAdapter {
363 type Credential = ObjectStoreAwsCredential;
364
365 async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
366 let cached_creds = {
367 let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
368 let expired = cache_value
369 .clone()
370 .map(|cred| {
371 cred.expiry()
372 .map(|exp| {
373 to_system_time(exp)
374 .checked_sub(self.credentials_refresh_offset)
375 .expect("this time should always be valid")
376 < SystemTime::now()
377 })
378 .unwrap_or(false)
380 })
381 .unwrap_or(true); if expired { None } else { cache_value.clone() }
383 };
384
385 if let Some(creds) = cached_creds {
386 Ok(Arc::new(Self::Credential {
387 key_id: creds.access_key_id().to_string(),
388 secret_key: creds.secret_access_key().to_string(),
389 token: creds.session_token().map(|s| s.to_string()),
390 }))
391 } else {
392 let refreshed_creds =
393 Arc::new(self.inner.provide_credentials().await.map_err(|e| {
394 Error::internal(format!("Failed to get AWS credentials: {:?}", e))
395 })?);
396
397 self.cache
398 .write()
399 .await
400 .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
401
402 Ok(Arc::new(Self::Credential {
403 key_id: refreshed_creds.access_key_id().to_string(),
404 secret_key: refreshed_creds.secret_access_key().to_string(),
405 token: refreshed_creds.session_token().map(|s| s.to_string()),
406 }))
407 }
408 }
409}
410
411impl StorageOptions {
412 pub fn with_env_s3(&mut self) {
414 for (os_key, os_value) in std::env::vars_os() {
415 if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str())
416 && let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase())
417 && !self.0.contains_key(config_key.as_ref())
418 {
419 self.0
420 .insert(config_key.as_ref().to_string(), value.to_string());
421 }
422 }
423 }
424
425 pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
427 self.0
428 .iter()
429 .filter_map(|(key, value)| {
430 let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
431 Some((s3_key, value.clone()))
432 })
433 .collect()
434 }
435}
436
437impl ObjectStoreParams {
438 pub fn with_aws_credentials(
440 aws_credentials: Option<AwsCredentialProvider>,
441 region: Option<String>,
442 ) -> Self {
443 let storage_options_accessor = region.map(|region| {
444 let opts: HashMap<String, String> =
445 [("region".into(), region)].iter().cloned().collect();
446 Arc::new(StorageOptionsAccessor::with_static_options(opts))
447 });
448 Self {
449 aws_credentials,
450 storage_options_accessor,
451 ..Default::default()
452 }
453 }
454}
455
/// AWS flavour of [`NamespaceCredentialsProvider`]: yields `object_store`
/// `AwsCredential`s built from storage options supplied by a storage options
/// provider (optionally seeded with initial options).
pub type DynamicStorageOptionsCredentialProvider =
    NamespaceCredentialsProvider<ObjectStoreAwsCredential>;
458
#[cfg(test)]
mod tests {
    use crate::object_store::ObjectStoreRegistry;
    use crate::object_store::StorageOptionsProvider;
    use mock_instant::thread_local::MockClock;
    use object_store::path::Path;
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::*;

    /// Credential provider that records whether it was asked for credentials and
    /// always returns empty keys.
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait::async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }

    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // Nothing has requested credentials yet.
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // The result of the read is ignored; it only has to get far enough to
        // request credentials from the injected provider.
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    #[test]
    fn test_s3_path_parsing() {
        let provider = AwsStoreProvider;

        // The bucket (host) and any query string are not part of the object path.
        let cases = [
            ("s3://bucket/path/to/file", "path/to/file"),
            ("s3://bucket/测试path/to/file", "测试path/to/file"),
            ("s3://bucket/path/&to/file", "path/&to/file"),
            ("s3://bucket/path/=to/file", "path/=to/file"),
            (
                "s3+ddb://bucket/path/to/file?ddbTableName=test",
                "path/to/file",
            ),
        ];

        for (uri, expected_path) in cases {
            let url = Url::parse(uri).unwrap();
            let path = provider.extract_path(&url).unwrap();
            let expected_path = Path::from(expected_path);
            assert_eq!(path, expected_path)
        }
    }

    #[test]
    fn test_is_s3_express() {
        // (uri, storage options, expected). The `--x-s3` bucket suffix forces
        // S3 Express regardless of the `s3_express` option.
        let cases = [
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                false,
            ),
            ("s3://bucket/path/to/file", HashMap::from([]), false),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                true,
            ),
            ("s3://bucket--x-s3/path/to/file", HashMap::from([]), true),
        ];

        for (uri, storage_map, expected) in cases {
            let url = Url::parse(uri).unwrap();
            let storage_options = StorageOptions(storage_map);
            let is_s3_express = check_s3_express(&url, &storage_options);
            assert_eq!(is_s3_express, expected);
        }
    }

    #[tokio::test]
    async fn test_use_opendal_flag() {
        use crate::object_store::StorageOptionsAccessor;
        let provider = AwsStoreProvider;
        let url = Url::parse("s3://test-bucket/path").unwrap();
        let params_with_flag = ObjectStoreParams {
            storage_options_accessor: Some(Arc::new(StorageOptionsAccessor::with_static_options(
                HashMap::from([
                    ("use_opendal".to_string(), "true".to_string()),
                    ("region".to_string(), "us-west-2".to_string()),
                ]),
            ))),
            ..Default::default()
        };

        let store = provider
            .new_store(url.clone(), &params_with_flag)
            .await
            .unwrap();
        assert_eq!(store.scheme, "s3");
    }

    /// Storage options provider that hands out numbered credentials
    /// (`AKID_n`, `SECRET_n`, `TOKEN_n`) and counts how often it is called.
    #[derive(Debug)]
    struct MockStorageOptionsProvider {
        call_count: Arc<RwLock<usize>>,
        /// When set, each response includes `expires_at_millis = now + this`.
        expires_in_millis: Option<u64>,
    }

    impl MockStorageOptionsProvider {
        fn new(expires_in_millis: Option<u64>) -> Self {
            Self {
                call_count: Arc::new(RwLock::new(0)),
                expires_in_millis,
            }
        }

        async fn get_call_count(&self) -> usize {
            *self.call_count.read().await
        }
    }

    #[async_trait::async_trait]
    impl StorageOptionsProvider for MockStorageOptionsProvider {
        async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
            let count = {
                let mut c = self.call_count.write().await;
                *c += 1;
                *c
            };

            let mut options = HashMap::from([
                ("aws_access_key_id".to_string(), format!("AKID_{}", count)),
                (
                    "aws_secret_access_key".to_string(),
                    format!("SECRET_{}", count),
                ),
                ("aws_session_token".to_string(), format!("TOKEN_{}", count)),
            ]);

            // Uses this module's `SystemTime`, i.e. the mock clock under test.
            if let Some(expires_in) = self.expires_in_millis {
                let now_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .unwrap()
                    .as_millis() as u64;
                let expires_at = now_ms + expires_in;
                options.insert("expires_at_millis".to_string(), expires_at.to_string());
            }

            Ok(Some(options))
        }

        fn provider_id(&self) -> String {
            let ptr = Arc::as_ptr(&self.call_count) as usize;
            format!("MockStorageOptionsProvider {{ id: {} }}", ptr)
        }
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_with_initial_cache() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        // 10 minutes
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        // Initial credentials valid for 10 minutes with a 5 minute refresh offset.
        let expires_at = now_ms + 600_000;
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_CACHED".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_CACHED".to_string(),
            ),
            ("aws_session_token".to_string(), "TOKEN_CACHED".to_string()),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()),
        ]);

        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
            mock.clone(),
            initial_options,
        );

        // The still-fresh initial credentials are served without calling the provider.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_CACHED");
        assert_eq!(cred.secret_key, "SECRET_CACHED");
        assert_eq!(cred.token, Some("TOKEN_CACHED".to_string()));

        assert_eq!(mock.get_call_count().await, 0);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_with_expired_cache() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        // 10 minutes
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        // Initial credentials expired one second ago.
        let expired_time = now_ms - 1_000;
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_EXPIRED".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_EXPIRED".to_string(),
            ),
            ("expires_at_millis".to_string(), expired_time.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()),
        ]);

        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
            mock.clone(),
            initial_options,
        );

        // Expired initial credentials force a fetch from the provider.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));

        assert_eq!(mock.get_call_count().await, 1);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_refresh_lead_time() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        // Credentials from the mock expire 30 seconds after being fetched.
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(30_000)));

        let provider = DynamicStorageOptionsCredentialProvider::from_provider(mock.clone());

        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(mock.get_call_count().await, 1);

        // Even without advancing the clock, the next call refetches. Presumably a
        // 30 s lifetime falls inside the provider's default refresh lead time.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(mock.get_call_count().await, 2);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_no_initial_cache() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        // Credentials from the mock expire 2 minutes after being fetched.
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(120_000)));

        let provider = DynamicStorageOptionsCredentialProvider::from_provider(mock.clone());

        // First call has nothing cached and must fetch.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));
        assert_eq!(mock.get_call_count().await, 1);

        // Immediately afterwards the cached credentials are reused.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(mock.get_call_count().await, 1);

        // 90 s later (30 s before expiry) the credentials are due for refresh.
        MockClock::set_system_time(Duration::from_secs(100_000 + 90));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(cred.secret_key, "SECRET_2");
        assert_eq!(cred.token, Some("TOKEN_2".to_string()));
        assert_eq!(mock.get_call_count().await, 2);

        // At +210 s the second set (fetched at +90 s) has reached its expiry.
        MockClock::set_system_time(Duration::from_secs(100_000 + 210));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_3");
        assert_eq!(cred.secret_key, "SECRET_3");
        assert_eq!(mock.get_call_count().await, 3);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_with_initial_options() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        // 10 minutes
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        // Initial credentials valid for 10 minutes with a 5 minute refresh offset.
        let expires_at = now_ms + 600_000;
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_INITIAL".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_INITIAL".to_string(),
            ),
            ("aws_session_token".to_string(), "TOKEN_INITIAL".to_string()),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()),
        ]);

        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
            mock.clone(),
            initial_options,
        );

        // Initial credentials are served as-is.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_INITIAL");
        assert_eq!(cred.secret_key, "SECRET_INITIAL");
        assert_eq!(cred.token, Some("TOKEN_INITIAL".to_string()));

        assert_eq!(mock.get_call_count().await, 0);

        // +6 min: past `expires_at - refresh_offset`, so the first refresh happens.
        MockClock::set_system_time(Duration::from_secs(100_000 + 360));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));

        assert_eq!(mock.get_call_count().await, 1);

        // +11 min: the credentials fetched at +6 min are due for refresh again.
        MockClock::set_system_time(Duration::from_secs(100_000 + 660));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(cred.secret_key, "SECRET_2");
        assert_eq!(cred.token, Some("TOKEN_2".to_string()));

        assert_eq!(mock.get_call_count().await, 2);

        // +16 min: another refresh.
        MockClock::set_system_time(Duration::from_secs(100_000 + 960));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_3");
        assert_eq!(cred.secret_key, "SECRET_3");
        assert_eq!(cred.token, Some("TOKEN_3".to_string()));

        assert_eq!(mock.get_call_count().await, 3);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_concurrent_access() {
        // Effectively never-expiring credentials.
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(9999999999999)));

        let provider = Arc::new(DynamicStorageOptionsCredentialProvider::from_provider(
            mock.clone(),
        ));

        // Ten tasks race for credentials on an empty cache.
        let mut handles = vec![];
        for i in 0..10 {
            let provider = provider.clone();
            let handle = tokio::spawn(async move {
                let cred = provider.get_credential().await.unwrap();
                assert_eq!(cred.key_id, "AKID_1");
                assert_eq!(cred.secret_key, "SECRET_1");
                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
                i
            });
            handles.push(handle);
        }

        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        assert_eq!(results.len(), 10);
        for i in 0..10 {
            assert!(results.contains(&i));
        }

        let call_count = mock.get_call_count().await;
        assert_eq!(
            call_count, 1,
            "Provider should be called exactly once despite concurrent access"
        );
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_concurrent_refresh() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        // Initial credentials expired long ago, so every task sees them as stale.
        let expires_at = now_ms - 1_000_000;
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_OLD".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_OLD".to_string(),
            ),
            ("aws_session_token".to_string(), "TOKEN_OLD".to_string()),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()),
        ]);

        // 1 hour
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(3_600_000)));

        let provider = Arc::new(
            DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
                mock.clone(),
                initial_options,
            ),
        );

        // Twenty tasks race to refresh the stale credentials.
        let mut handles = vec![];
        for i in 0..20 {
            let provider = provider.clone();
            let handle = tokio::spawn(async move {
                let cred = provider.get_credential().await.unwrap();
                assert_eq!(cred.key_id, "AKID_1");
                assert_eq!(cred.secret_key, "SECRET_1");
                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
                i
            });
            handles.push(handle);
        }

        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        assert_eq!(results.len(), 20);

        let call_count = mock.get_call_count().await;
        assert!(
            call_count >= 1,
            "Provider should be called at least once, was called {} times",
            call_count
        );

        assert!(
            call_count < 10,
            "Provider should not be called too many times due to lock contention, was called {} times",
            call_count
        );
    }

    #[tokio::test]
    async fn test_explicit_aws_credentials_takes_precedence_over_accessor() {
        let mock_storage_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        let accessor = Arc::new(StorageOptionsAccessor::with_provider(
            mock_storage_provider.clone(),
        ));

        let explicit_cred_provider = Arc::new(MockAwsCredentialsProvider::default());

        let (result, _region) = build_aws_credential(
            Duration::from_secs(300),
            Some(explicit_cred_provider.clone() as AwsCredentialProvider),
            None,
            Some("us-west-2".to_string()),
            Some(accessor),
        )
        .await
        .unwrap();

        let cred = result.get_credential().await.unwrap();

        assert!(explicit_cred_provider.called.load(Ordering::Relaxed));

        assert_eq!(
            mock_storage_provider.get_call_count().await,
            0,
            "Storage options provider should not be called when explicit aws_credentials is provided"
        );

        // The mock explicit provider returns empty keys.
        assert_eq!(cred.key_id, "");
        assert_eq!(cred.secret_key, "");
    }

    #[tokio::test]
    async fn test_accessor_used_when_no_explicit_aws_credentials() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        let mock_storage_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        // Initial credentials valid for 10 minutes with a 5 minute refresh offset.
        let expires_at = now_ms + 600_000;
        let initial_options = HashMap::from([
            (
                "aws_access_key_id".to_string(),
                "AKID_FROM_ACCESSOR".to_string(),
            ),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_FROM_ACCESSOR".to_string(),
            ),
            (
                "aws_session_token".to_string(),
                "TOKEN_FROM_ACCESSOR".to_string(),
            ),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()),
        ]);

        let accessor = Arc::new(StorageOptionsAccessor::with_initial_and_provider(
            initial_options,
            mock_storage_provider.clone(),
        ));

        let (result, _region) = build_aws_credential(
            Duration::from_secs(300),
            None,
            None,
            Some("us-west-2".to_string()),
            Some(accessor),
        )
        .await
        .unwrap();

        // Initial accessor credentials are used without calling the provider.
        let cred = result.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_FROM_ACCESSOR");
        assert_eq!(cred.secret_key, "SECRET_FROM_ACCESSOR");

        assert_eq!(mock_storage_provider.get_call_count().await, 0);

        // +6 min: inside the refresh window, so the provider is consulted.
        MockClock::set_system_time(Duration::from_secs(100_000 + 360));

        let cred = result.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");

        assert_eq!(mock_storage_provider.get_call_count().await, 1);
    }
}