1use std::{collections::HashMap, fmt, str::FromStr, sync::Arc, time::Duration};
5
6#[cfg(test)]
7use mock_instant::thread_local::{SystemTime, UNIX_EPOCH};
8
9#[cfg(not(test))]
10use std::time::{SystemTime, UNIX_EPOCH};
11
12use object_store::ObjectStore as OSObjectStore;
13use object_store_opendal::OpendalStore;
14use opendal::{Operator, services::S3};
15
16use aws_config::default_provider::credentials::DefaultCredentialsChain;
17use aws_credential_types::provider::ProvideCredentials;
18use object_store::{
19 ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
20 StaticCredentialProvider,
21 aws::{
22 AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
23 AwsCredentialProvider,
24 },
25};
26use tokio::sync::RwLock;
27use url::Url;
28
29use crate::object_store::{
30 DEFAULT_CLOUD_BLOCK_SIZE, DEFAULT_CLOUD_IO_PARALLELISM, DEFAULT_MAX_IOP_SIZE, ObjectStore,
31 ObjectStoreParams, ObjectStoreProvider, StorageOptions, StorageOptionsAccessor,
32 StorageOptionsProvider,
33};
34use lance_core::error::{Error, Result};
35
/// Object store provider for `s3://`-style URLs (including `s3+ddb://`).
///
/// It can build either the native `object_store` Amazon S3 client or an
/// OpenDAL-backed store, depending on the `use_opendal` storage option.
#[derive(Debug, Default)]
pub struct AwsStoreProvider;
38
39impl AwsStoreProvider {
40 async fn build_amazon_s3_store(
41 &self,
42 base_path: &mut Url,
43 params: &ObjectStoreParams,
44 storage_options: &StorageOptions,
45 is_s3_express: bool,
46 ) -> Result<Arc<dyn OSObjectStore>> {
47 let max_retries = storage_options.client_max_retries();
48 let retry_timeout = storage_options.client_retry_timeout();
49 let retry_config = RetryConfig {
50 backoff: Default::default(),
51 max_retries,
52 retry_timeout: Duration::from_secs(retry_timeout),
53 };
54
55 let mut s3_storage_options = storage_options.as_s3_options();
56 let region = resolve_s3_region(base_path, &s3_storage_options).await?;
57
58 let accessor = params.get_accessor();
60
61 let (aws_creds, region) = build_aws_credential(
62 params.s3_credentials_refresh_offset,
63 params.aws_credentials.clone(),
64 Some(&s3_storage_options),
65 region,
66 accessor,
67 )
68 .await?;
69
70 if is_s3_express {
72 s3_storage_options.insert(AmazonS3ConfigKey::S3Express, true.to_string());
73 }
74
75 base_path.set_scheme("s3").unwrap();
77 base_path.set_query(None);
78
79 let mut builder =
81 AmazonS3Builder::new().with_client_options(storage_options.client_options()?);
82 for (key, value) in s3_storage_options {
83 builder = builder.with_config(key, value);
84 }
85 builder = builder
86 .with_url(base_path.as_ref())
87 .with_credentials(aws_creds)
88 .with_retry(retry_config)
89 .with_region(region);
90
91 Ok(Arc::new(builder.build()?) as Arc<dyn OSObjectStore>)
92 }
93
94 async fn build_opendal_s3_store(
95 &self,
96 base_path: &Url,
97 storage_options: &StorageOptions,
98 ) -> Result<Arc<dyn OSObjectStore>> {
99 let bucket = base_path
100 .host_str()
101 .ok_or_else(|| Error::invalid_input("S3 URL must contain bucket name"))?
102 .to_string();
103
104 let prefix = base_path.path().trim_start_matches('/').to_string();
105
106 let mut config_map: HashMap<String, String> = storage_options.0.clone();
109
110 config_map.insert("bucket".to_string(), bucket);
112
113 if !prefix.is_empty() {
114 config_map.insert("root".to_string(), "/".to_string());
115 }
116
117 let operator = Operator::from_iter::<S3>(config_map)
118 .map_err(|e| Error::invalid_input(format!("Failed to create S3 operator: {:?}", e)))?
119 .finish();
120
121 Ok(Arc::new(OpendalStore::new(operator)) as Arc<dyn OSObjectStore>)
122 }
123}
124
125#[async_trait::async_trait]
126impl ObjectStoreProvider for AwsStoreProvider {
127 async fn new_store(
128 &self,
129 mut base_path: Url,
130 params: &ObjectStoreParams,
131 ) -> Result<ObjectStore> {
132 let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
133 let mut storage_options =
134 StorageOptions(params.storage_options().cloned().unwrap_or_default());
135 storage_options.with_env_s3();
136 let download_retry_count = storage_options.download_retry_count();
137
138 let use_opendal = storage_options
139 .0
140 .get("use_opendal")
141 .map(|v| v == "true")
142 .unwrap_or(false);
143
144 let is_s3_express = check_s3_express(&base_path, &storage_options);
146
147 let use_constant_size_upload_parts = storage_options
148 .0
149 .get("aws_endpoint")
150 .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
151 .unwrap_or(false);
152
153 let inner = if use_opendal {
154 self.build_opendal_s3_store(&base_path, &storage_options)
156 .await?
157 } else {
158 self.build_amazon_s3_store(&mut base_path, params, &storage_options, is_s3_express)
160 .await?
161 };
162
163 Ok(ObjectStore {
164 inner,
165 scheme: String::from(base_path.scheme()),
166 block_size,
167 max_iop_size: *DEFAULT_MAX_IOP_SIZE,
168 use_constant_size_upload_parts,
169 list_is_lexically_ordered: !is_s3_express,
170 io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
171 download_retry_count,
172 io_tracker: Default::default(),
173 store_prefix: self
174 .calculate_object_store_prefix(&base_path, params.storage_options())?,
175 })
176 }
177}
178
179fn check_s3_express(url: &Url, storage_options: &StorageOptions) -> bool {
181 storage_options
182 .0
183 .get("s3_express")
184 .map(|v| v == "true")
185 .unwrap_or(false)
186 || url.authority().ends_with("--x-s3")
187}
188
189async fn resolve_s3_region(
197 url: &Url,
198 storage_options: &HashMap<AmazonS3ConfigKey, String>,
199) -> Result<Option<String>> {
200 if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
201 Ok(Some(region.clone()))
202 } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
203 let bucket = url.host_str().ok_or_else(|| {
206 Error::invalid_input(format!("Could not parse bucket from url: {}", url))
207 })?;
208
209 let mut client_options = ClientOptions::default();
210 for (key, value) in storage_options {
211 if let AmazonS3ConfigKey::Client(client_key) = key {
212 client_options = client_options.with_config(*client_key, value.clone());
213 }
214 }
215
216 let bucket_region =
217 object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
218 Ok(Some(bucket_region))
219 } else {
220 Ok(None)
221 }
222}
223
224pub async fn build_aws_credential(
241 credentials_refresh_offset: Duration,
242 credentials: Option<AwsCredentialProvider>,
243 storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
244 region: Option<String>,
245 storage_options_accessor: Option<Arc<StorageOptionsAccessor>>,
246) -> Result<(AwsCredentialProvider, String)> {
247 use aws_config::meta::region::RegionProviderChain;
248 const DEFAULT_REGION: &str = "us-west-2";
249
250 let region = if let Some(region) = region {
251 region
252 } else {
253 RegionProviderChain::default_provider()
254 .or_else(DEFAULT_REGION)
255 .region()
256 .await
257 .map(|r| r.as_ref().to_string())
258 .unwrap_or(DEFAULT_REGION.to_string())
259 };
260
261 let storage_options_credentials = storage_options.and_then(extract_static_s3_credentials);
262
263 if let Some(accessor) = storage_options_accessor
265 && accessor.has_provider()
266 {
267 if let Some(creds) = credentials {
269 return Ok((creds, region));
270 }
271 return Ok((
273 Arc::new(DynamicStorageOptionsCredentialProvider::new(accessor)),
274 region,
275 ));
276 }
277
278 if let Some(creds) = credentials {
280 Ok((creds, region))
281 } else if let Some(creds) = storage_options_credentials {
282 Ok((Arc::new(creds), region))
283 } else {
284 let credentials_provider = DefaultCredentialsChain::builder().build().await;
285
286 Ok((
287 Arc::new(AwsCredentialAdapter::new(
288 Arc::new(credentials_provider),
289 credentials_refresh_offset,
290 )),
291 region,
292 ))
293 }
294}
295
296fn extract_static_s3_credentials(
297 options: &HashMap<AmazonS3ConfigKey, String>,
298) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
299 let key_id = options.get(&AmazonS3ConfigKey::AccessKeyId).cloned();
300 let secret_key = options.get(&AmazonS3ConfigKey::SecretAccessKey).cloned();
301 let token = options.get(&AmazonS3ConfigKey::Token).cloned();
302 match (key_id, secret_key, token) {
303 (Some(key_id), Some(secret_key), token) => {
304 Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
305 key_id,
306 secret_key,
307 token,
308 }))
309 }
310 _ => None,
311 }
312}
313
314#[derive(Debug)]
316pub struct AwsCredentialAdapter {
317 pub inner: Arc<dyn ProvideCredentials>,
318
319 cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,
321
322 credentials_refresh_offset: Duration,
324}
325
326impl AwsCredentialAdapter {
327 pub fn new(
328 provider: Arc<dyn ProvideCredentials>,
329 credentials_refresh_offset: Duration,
330 ) -> Self {
331 Self {
332 inner: provider,
333 cache: Arc::new(RwLock::new(HashMap::new())),
334 credentials_refresh_offset,
335 }
336 }
337}
338
/// Cache key under which [`AwsCredentialAdapter`] stores its credentials.
const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";

/// Converts a `std::time::SystemTime` into this module's `SystemTime`.
///
/// Under `cfg(test)` the target is the mock clock's type; otherwise it is the
/// standard type. The conversion goes through the offset from the Unix epoch.
///
/// # Panics
/// Panics if `time` is earlier than the Unix epoch.
fn to_system_time(time: std::time::SystemTime) -> SystemTime {
    UNIX_EPOCH
        + time
            .duration_since(std::time::UNIX_EPOCH)
            .expect("time should be after UNIX_EPOCH")
}
348
349#[async_trait::async_trait]
350impl CredentialProvider for AwsCredentialAdapter {
351 type Credential = ObjectStoreAwsCredential;
352
353 async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
354 let cached_creds = {
355 let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
356 let expired = cache_value
357 .clone()
358 .map(|cred| {
359 cred.expiry()
360 .map(|exp| {
361 to_system_time(exp)
362 .checked_sub(self.credentials_refresh_offset)
363 .expect("this time should always be valid")
364 < SystemTime::now()
365 })
366 .unwrap_or(false)
368 })
369 .unwrap_or(true); if expired { None } else { cache_value.clone() }
371 };
372
373 if let Some(creds) = cached_creds {
374 Ok(Arc::new(Self::Credential {
375 key_id: creds.access_key_id().to_string(),
376 secret_key: creds.secret_access_key().to_string(),
377 token: creds.session_token().map(|s| s.to_string()),
378 }))
379 } else {
380 let refreshed_creds =
381 Arc::new(self.inner.provide_credentials().await.map_err(|e| {
382 Error::internal(format!("Failed to get AWS credentials: {:?}", e))
383 })?);
384
385 self.cache
386 .write()
387 .await
388 .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
389
390 Ok(Arc::new(Self::Credential {
391 key_id: refreshed_creds.access_key_id().to_string(),
392 secret_key: refreshed_creds.secret_access_key().to_string(),
393 token: refreshed_creds.session_token().map(|s| s.to_string()),
394 }))
395 }
396 }
397}
398
399impl StorageOptions {
400 pub fn with_env_s3(&mut self) {
402 for (os_key, os_value) in std::env::vars_os() {
403 if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str())
404 && let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase())
405 && !self.0.contains_key(config_key.as_ref())
406 {
407 self.0
408 .insert(config_key.as_ref().to_string(), value.to_string());
409 }
410 }
411 }
412
413 pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
415 self.0
416 .iter()
417 .filter_map(|(key, value)| {
418 let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
419 Some((s3_key, value.clone()))
420 })
421 .collect()
422 }
423}
424
425impl ObjectStoreParams {
426 pub fn with_aws_credentials(
428 aws_credentials: Option<AwsCredentialProvider>,
429 region: Option<String>,
430 ) -> Self {
431 let storage_options_accessor = region.map(|region| {
432 let opts: HashMap<String, String> =
433 [("region".into(), region)].iter().cloned().collect();
434 Arc::new(StorageOptionsAccessor::with_static_options(opts))
435 });
436 Self {
437 aws_credentials,
438 storage_options_accessor,
439 ..Default::default()
440 }
441 }
442}
443
444pub struct DynamicStorageOptionsCredentialProvider {
458 accessor: Arc<StorageOptionsAccessor>,
459}
460
461impl fmt::Debug for DynamicStorageOptionsCredentialProvider {
462 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
463 f.debug_struct("DynamicStorageOptionsCredentialProvider")
464 .field("accessor", &self.accessor)
465 .finish()
466 }
467}
468
469impl DynamicStorageOptionsCredentialProvider {
470 pub fn new(accessor: Arc<StorageOptionsAccessor>) -> Self {
472 Self { accessor }
473 }
474
475 pub fn from_provider(provider: Arc<dyn StorageOptionsProvider>) -> Self {
484 Self {
485 accessor: Arc::new(StorageOptionsAccessor::with_provider(provider)),
486 }
487 }
488
489 pub fn from_provider_with_initial(
499 provider: Arc<dyn StorageOptionsProvider>,
500 initial_options: HashMap<String, String>,
501 ) -> Self {
502 Self {
503 accessor: Arc::new(StorageOptionsAccessor::with_initial_and_provider(
504 initial_options,
505 provider,
506 )),
507 }
508 }
509}
510
511#[async_trait::async_trait]
512impl CredentialProvider for DynamicStorageOptionsCredentialProvider {
513 type Credential = ObjectStoreAwsCredential;
514
515 async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
516 let storage_options = self.accessor.get_storage_options().await.map_err(|e| {
517 object_store::Error::Generic {
518 store: "DynamicStorageOptionsCredentialProvider",
519 source: Box::new(e),
520 }
521 })?;
522
523 let s3_options = storage_options.as_s3_options();
524 let static_creds = extract_static_s3_credentials(&s3_options).ok_or_else(|| {
525 object_store::Error::Generic {
526 store: "DynamicStorageOptionsCredentialProvider",
527 source: "Missing required credentials in storage options".into(),
528 }
529 })?;
530
531 static_creds
532 .get_credential()
533 .await
534 .map_err(|e| object_store::Error::Generic {
535 store: "DynamicStorageOptionsCredentialProvider",
536 source: Box::new(e),
537 })
538 }
539}
540
#[cfg(test)]
mod tests {
    use crate::object_store::ObjectStoreRegistry;
    use mock_instant::thread_local::MockClock;
    use object_store::path::Path;
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::*;

    /// Credential provider stub that records whether credentials were requested.
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait::async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }

    /// An injected `aws_credentials` provider must be the one consulted on I/O.
    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // Nothing has asked for credentials yet.
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // The read result is ignored; only the credential lookup it triggers matters.
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    /// Object paths are taken from the URL path, without the bucket or query string.
    #[test]
    fn test_s3_path_parsing() {
        let provider = AwsStoreProvider;

        let cases = [
            ("s3://bucket/path/to/file", "path/to/file"),
            // Non-ASCII and reserved characters must survive unchanged.
            ("s3://bucket/测试path/to/file", "测试path/to/file"),
            ("s3://bucket/path/&to/file", "path/&to/file"),
            ("s3://bucket/path/=to/file", "path/=to/file"),
            (
                "s3+ddb://bucket/path/to/file?ddbTableName=test",
                "path/to/file",
            ),
        ];

        for (uri, expected_path) in cases {
            let url = Url::parse(uri).unwrap();
            let path = provider.extract_path(&url).unwrap();
            let expected_path = Path::from(expected_path);
            assert_eq!(path, expected_path)
        }
    }

    /// S3 Express is enabled by the `s3_express` option or by the `--x-s3` bucket suffix.
    #[test]
    fn test_is_s3_express() {
        let cases = [
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                false,
            ),
            ("s3://bucket/path/to/file", HashMap::from([]), false),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                // The bucket suffix wins over an explicit "false".
                true,
            ),
            ("s3://bucket--x-s3/path/to/file", HashMap::from([]), true),
        ];

        for (uri, storage_map, expected) in cases {
            let url = Url::parse(uri).unwrap();
            let storage_options = StorageOptions(storage_map);
            let is_s3_express = check_s3_express(&url, &storage_options);
            assert_eq!(is_s3_express, expected);
        }
    }

    /// `use_opendal = "true"` builds a store through the OpenDAL path.
    #[tokio::test]
    async fn test_use_opendal_flag() {
        use crate::object_store::StorageOptionsAccessor;
        let provider = AwsStoreProvider;
        let url = Url::parse("s3://test-bucket/path").unwrap();
        let params_with_flag = ObjectStoreParams {
            storage_options_accessor: Some(Arc::new(StorageOptionsAccessor::with_static_options(
                HashMap::from([
                    ("use_opendal".to_string(), "true".to_string()),
                    ("region".to_string(), "us-west-2".to_string()),
                ]),
            ))),
            ..Default::default()
        };

        let store = provider
            .new_store(url.clone(), &params_with_flag)
            .await
            .unwrap();
        assert_eq!(store.scheme, "s3");
    }

    /// Storage options provider that vends numbered credentials (`AKID_<n>`, ...)
    /// and counts how often it is called.
    #[derive(Debug)]
    struct MockStorageOptionsProvider {
        call_count: Arc<RwLock<usize>>,
        expires_in_millis: Option<u64>,
    }

    impl MockStorageOptionsProvider {
        fn new(expires_in_millis: Option<u64>) -> Self {
            Self {
                call_count: Arc::new(RwLock::new(0)),
                expires_in_millis,
            }
        }

        async fn get_call_count(&self) -> usize {
            *self.call_count.read().await
        }
    }

    #[async_trait::async_trait]
    impl StorageOptionsProvider for MockStorageOptionsProvider {
        async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
            let count = {
                let mut c = self.call_count.write().await;
                *c += 1;
                *c
            };

            let mut options = HashMap::from([
                ("aws_access_key_id".to_string(), format!("AKID_{}", count)),
                (
                    "aws_secret_access_key".to_string(),
                    format!("SECRET_{}", count),
                ),
                ("aws_session_token".to_string(), format!("TOKEN_{}", count)),
            ]);

            // Expiry is measured on the (mock, under cfg(test)) system clock.
            if let Some(expires_in) = self.expires_in_millis {
                let now_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .unwrap()
                    .as_millis() as u64;
                let expires_at = now_ms + expires_in;
                options.insert("expires_at_millis".to_string(), expires_at.to_string());
            }

            Ok(Some(options))
        }

        fn provider_id(&self) -> String {
            let ptr = Arc::as_ptr(&self.call_count) as usize;
            format!("MockStorageOptionsProvider {{ id: {} }}", ptr)
        }
    }

    /// Fresh initial credentials are served without calling the provider.
    #[tokio::test]
    async fn test_dynamic_credential_provider_with_initial_cache() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            600_000, // vended credentials expire after 10 minutes
        )));

        // Expires in 10 minutes with a 5-minute refresh offset: still fresh.
        let expires_at = now_ms + 600_000;
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_CACHED".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_CACHED".to_string(),
            ),
            ("aws_session_token".to_string(), "TOKEN_CACHED".to_string()),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()),
        ]);

        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
            mock.clone(),
            initial_options,
        );

        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_CACHED");
        assert_eq!(cred.secret_key, "SECRET_CACHED");
        assert_eq!(cred.token, Some("TOKEN_CACHED".to_string()));

        assert_eq!(mock.get_call_count().await, 0);
    }

    /// Already-expired initial credentials trigger one provider fetch.
    #[tokio::test]
    async fn test_dynamic_credential_provider_with_expired_cache() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            600_000, // vended credentials expire after 10 minutes
        )));

        // Expired one second ago.
        let expired_time = now_ms - 1_000;
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_EXPIRED".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_EXPIRED".to_string(),
            ),
            ("expires_at_millis".to_string(), expired_time.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()),
        ]);

        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
            mock.clone(),
            initial_options,
        );

        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));

        assert_eq!(mock.get_call_count().await, 1);
    }

    /// Short-lived credentials are refetched on every call.
    #[tokio::test]
    async fn test_dynamic_credential_provider_refresh_lead_time() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            30_000, // expires in 30 seconds
        )));

        // NOTE(review): 30s presumably falls inside the accessor's default
        // refresh offset, so each call is expected to fetch again.
        let provider = DynamicStorageOptionsCredentialProvider::from_provider(mock.clone());

        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(mock.get_call_count().await, 1);

        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(mock.get_call_count().await, 2);
    }

    /// Without initial options, credentials are fetched lazily and refreshed as time advances.
    #[tokio::test]
    async fn test_dynamic_credential_provider_no_initial_cache() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            120_000, // expires in 2 minutes
        )));

        let provider = DynamicStorageOptionsCredentialProvider::from_provider(mock.clone());

        // First call fetches.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));
        assert_eq!(mock.get_call_count().await, 1);

        // An immediate second call is served from cache.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(mock.get_call_count().await, 1);

        // After 90s the credentials are expected to need a refresh.
        MockClock::set_system_time(Duration::from_secs(100_000 + 90));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(cred.secret_key, "SECRET_2");
        assert_eq!(cred.token, Some("TOKEN_2".to_string()));
        assert_eq!(mock.get_call_count().await, 2);

        MockClock::set_system_time(Duration::from_secs(100_000 + 210));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_3");
        assert_eq!(cred.secret_key, "SECRET_3");
        assert_eq!(mock.get_call_count().await, 3);
    }

    /// Initial options are used until their refresh window, then the provider takes over.
    #[tokio::test]
    async fn test_dynamic_credential_provider_with_initial_options() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            600_000, // vended credentials expire after 10 minutes
        )));

        // Expires at +600s with a 300s refresh offset, so it is stale from +300s.
        let expires_at = now_ms + 600_000;
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_INITIAL".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_INITIAL".to_string(),
            ),
            ("aws_session_token".to_string(), "TOKEN_INITIAL".to_string()),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()),
        ]);

        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
            mock.clone(),
            initial_options,
        );

        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_INITIAL");
        assert_eq!(cred.secret_key, "SECRET_INITIAL");
        assert_eq!(cred.token, Some("TOKEN_INITIAL".to_string()));

        assert_eq!(mock.get_call_count().await, 0);

        // +360s is inside the initial credentials' refresh window.
        MockClock::set_system_time(Duration::from_secs(100_000 + 360));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));

        assert_eq!(mock.get_call_count().await, 1);

        MockClock::set_system_time(Duration::from_secs(100_000 + 660));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(cred.secret_key, "SECRET_2");
        assert_eq!(cred.token, Some("TOKEN_2".to_string()));

        assert_eq!(mock.get_call_count().await, 2);

        MockClock::set_system_time(Duration::from_secs(100_000 + 960));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_3");
        assert_eq!(cred.secret_key, "SECRET_3");
        assert_eq!(cred.token, Some("TOKEN_3".to_string()));

        assert_eq!(mock.get_call_count().await, 3);
    }

    /// Concurrent first-time callers share a single provider fetch.
    #[tokio::test]
    async fn test_dynamic_credential_provider_concurrent_access() {
        // Effectively never expires.
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(9999999999999)));

        let provider = Arc::new(DynamicStorageOptionsCredentialProvider::from_provider(
            mock.clone(),
        ));

        let mut handles = vec![];
        for i in 0..10 {
            let provider = provider.clone();
            let handle = tokio::spawn(async move {
                let cred = provider.get_credential().await.unwrap();
                assert_eq!(cred.key_id, "AKID_1");
                assert_eq!(cred.secret_key, "SECRET_1");
                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
                i
            });
            handles.push(handle);
        }

        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        assert_eq!(results.len(), 10);
        for i in 0..10 {
            assert!(results.contains(&i));
        }

        let call_count = mock.get_call_count().await;
        assert_eq!(
            call_count, 1,
            "Provider should be called exactly once despite concurrent access"
        );
    }

    /// Concurrent callers racing on expired credentials must not stampede the provider.
    #[tokio::test]
    async fn test_dynamic_credential_provider_concurrent_refresh() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        // Long expired.
        let expires_at = now_ms - 1_000_000;
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_OLD".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_OLD".to_string(),
            ),
            ("aws_session_token".to_string(), "TOKEN_OLD".to_string()),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()),
        ]);

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            3_600_000, // vended credentials expire after 1 hour
        )));

        let provider = Arc::new(
            DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
                mock.clone(),
                initial_options,
            ),
        );

        let mut handles = vec![];
        for i in 0..20 {
            let provider = provider.clone();
            let handle = tokio::spawn(async move {
                let cred = provider.get_credential().await.unwrap();
                assert_eq!(cred.key_id, "AKID_1");
                assert_eq!(cred.secret_key, "SECRET_1");
                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
                i
            });
            handles.push(handle);
        }

        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        assert_eq!(results.len(), 20);

        let call_count = mock.get_call_count().await;
        assert!(
            call_count >= 1,
            "Provider should be called at least once, was called {} times",
            call_count
        );

        assert!(
            call_count < 10,
            "Provider should not be called too many times due to lock contention, was called {} times",
            call_count
        );
    }

    /// Explicit `aws_credentials` win over an accessor that has a provider.
    #[tokio::test]
    async fn test_explicit_aws_credentials_takes_precedence_over_accessor() {
        let mock_storage_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        let accessor = Arc::new(StorageOptionsAccessor::with_provider(
            mock_storage_provider.clone(),
        ));

        let explicit_cred_provider = Arc::new(MockAwsCredentialsProvider::default());

        let (result, _region) = build_aws_credential(
            Duration::from_secs(300),
            Some(explicit_cred_provider.clone() as AwsCredentialProvider),
            None, // no static storage options
            Some("us-west-2".to_string()),
            Some(accessor),
        )
        .await
        .unwrap();

        let cred = result.get_credential().await.unwrap();

        assert!(explicit_cred_provider.called.load(Ordering::Relaxed));

        assert_eq!(
            mock_storage_provider.get_call_count().await,
            0,
            "Storage options provider should not be called when explicit aws_credentials is provided"
        );

        // The mock explicit provider returns empty keys.
        assert_eq!(cred.key_id, "");
        assert_eq!(cred.secret_key, "");
    }

    /// Without explicit credentials, the accessor's options (and later its provider) are used.
    #[tokio::test]
    async fn test_accessor_used_when_no_explicit_aws_credentials() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        let mock_storage_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        // Expires at +600s with a 300s refresh offset.
        let expires_at = now_ms + 600_000;
        let initial_options = HashMap::from([
            (
                "aws_access_key_id".to_string(),
                "AKID_FROM_ACCESSOR".to_string(),
            ),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_FROM_ACCESSOR".to_string(),
            ),
            (
                "aws_session_token".to_string(),
                "TOKEN_FROM_ACCESSOR".to_string(),
            ),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()),
        ]);

        let accessor = Arc::new(StorageOptionsAccessor::with_initial_and_provider(
            initial_options,
            mock_storage_provider.clone(),
        ));

        let (result, _region) = build_aws_credential(
            Duration::from_secs(300),
            None, // no explicit aws_credentials
            None, // no static storage options
            Some("us-west-2".to_string()),
            Some(accessor),
        )
        .await
        .unwrap();

        let cred = result.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_FROM_ACCESSOR");
        assert_eq!(cred.secret_key, "SECRET_FROM_ACCESSOR");

        assert_eq!(mock_storage_provider.get_call_count().await, 0);

        // +360s is inside the refresh window, so the provider is consulted.
        MockClock::set_system_time(Duration::from_secs(100_000 + 360));

        let cred = result.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");

        assert_eq!(mock_storage_provider.get_call_count().await, 1);
    }
}