1use std::{collections::HashMap, fmt, str::FromStr, sync::Arc, time::Duration};
5
6#[cfg(test)]
7use mock_instant::thread_local::{SystemTime, UNIX_EPOCH};
8
9#[cfg(not(test))]
10use std::time::{SystemTime, UNIX_EPOCH};
11
12use object_store::ObjectStore as OSObjectStore;
13use object_store_opendal::OpendalStore;
14use opendal::{services::S3, Operator};
15
16use aws_config::default_provider::credentials::DefaultCredentialsChain;
17use aws_credential_types::provider::ProvideCredentials;
18use object_store::{
19 aws::{
20 AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
21 AwsCredentialProvider,
22 },
23 ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
24 StaticCredentialProvider,
25};
26use snafu::location;
27use tokio::sync::RwLock;
28use url::Url;
29
30use crate::object_store::{
31 ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, StorageOptionsAccessor,
32 StorageOptionsProvider, DEFAULT_CLOUD_BLOCK_SIZE, DEFAULT_CLOUD_IO_PARALLELISM,
33 DEFAULT_MAX_IOP_SIZE,
34};
35use lance_core::error::{Error, Result};
36
/// [`ObjectStoreProvider`] for `s3://`-style URLs.
///
/// Builds either an `object_store` Amazon S3 client or, when the `use_opendal`
/// storage option is exactly `"true"`, an OpenDAL-backed S3 store.
#[derive(Default, Debug)]
pub struct AwsStoreProvider;
39
40impl AwsStoreProvider {
41 async fn build_amazon_s3_store(
42 &self,
43 base_path: &mut Url,
44 params: &ObjectStoreParams,
45 storage_options: &StorageOptions,
46 is_s3_express: bool,
47 ) -> Result<Arc<dyn OSObjectStore>> {
48 let max_retries = storage_options.client_max_retries();
49 let retry_timeout = storage_options.client_retry_timeout();
50 let retry_config = RetryConfig {
51 backoff: Default::default(),
52 max_retries,
53 retry_timeout: Duration::from_secs(retry_timeout),
54 };
55
56 let mut s3_storage_options = storage_options.as_s3_options();
57 let region = resolve_s3_region(base_path, &s3_storage_options).await?;
58
59 let accessor = params.get_accessor();
61
62 let (aws_creds, region) = build_aws_credential(
63 params.s3_credentials_refresh_offset,
64 params.aws_credentials.clone(),
65 Some(&s3_storage_options),
66 region,
67 accessor,
68 )
69 .await?;
70
71 if is_s3_express {
73 s3_storage_options.insert(AmazonS3ConfigKey::S3Express, true.to_string());
74 }
75
76 base_path.set_scheme("s3").unwrap();
78 base_path.set_query(None);
79
80 let mut builder = AmazonS3Builder::new();
82 for (key, value) in s3_storage_options {
83 builder = builder.with_config(key, value);
84 }
85 builder = builder
86 .with_url(base_path.as_ref())
87 .with_credentials(aws_creds)
88 .with_retry(retry_config)
89 .with_region(region);
90
91 Ok(Arc::new(builder.build()?) as Arc<dyn OSObjectStore>)
92 }
93
94 async fn build_opendal_s3_store(
95 &self,
96 base_path: &Url,
97 storage_options: &StorageOptions,
98 ) -> Result<Arc<dyn OSObjectStore>> {
99 let bucket = base_path
100 .host_str()
101 .ok_or_else(|| Error::invalid_input("S3 URL must contain bucket name", location!()))?
102 .to_string();
103
104 let prefix = base_path.path().trim_start_matches('/').to_string();
105
106 let mut config_map: HashMap<String, String> = storage_options.0.clone();
109
110 config_map.insert("bucket".to_string(), bucket);
112
113 if !prefix.is_empty() {
114 config_map.insert("root".to_string(), "/".to_string());
115 }
116
117 let operator = Operator::from_iter::<S3>(config_map)
118 .map_err(|e| {
119 Error::invalid_input(
120 format!("Failed to create S3 operator: {:?}", e),
121 location!(),
122 )
123 })?
124 .finish();
125
126 Ok(Arc::new(OpendalStore::new(operator)) as Arc<dyn OSObjectStore>)
127 }
128}
129
130#[async_trait::async_trait]
131impl ObjectStoreProvider for AwsStoreProvider {
132 async fn new_store(
133 &self,
134 mut base_path: Url,
135 params: &ObjectStoreParams,
136 ) -> Result<ObjectStore> {
137 let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
138 let mut storage_options =
139 StorageOptions(params.storage_options().cloned().unwrap_or_default());
140 storage_options.with_env_s3();
141 let download_retry_count = storage_options.download_retry_count();
142
143 let use_opendal = storage_options
144 .0
145 .get("use_opendal")
146 .map(|v| v == "true")
147 .unwrap_or(false);
148
149 let is_s3_express = check_s3_express(&base_path, &storage_options);
151
152 let use_constant_size_upload_parts = storage_options
153 .0
154 .get("aws_endpoint")
155 .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
156 .unwrap_or(false);
157
158 let inner = if use_opendal {
159 self.build_opendal_s3_store(&base_path, &storage_options)
161 .await?
162 } else {
163 self.build_amazon_s3_store(&mut base_path, params, &storage_options, is_s3_express)
165 .await?
166 };
167
168 Ok(ObjectStore {
169 inner,
170 scheme: String::from(base_path.scheme()),
171 block_size,
172 max_iop_size: *DEFAULT_MAX_IOP_SIZE,
173 use_constant_size_upload_parts,
174 list_is_lexically_ordered: !is_s3_express,
175 io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
176 download_retry_count,
177 io_tracker: Default::default(),
178 store_prefix: self
179 .calculate_object_store_prefix(&base_path, params.storage_options())?,
180 })
181 }
182}
183
184fn check_s3_express(url: &Url, storage_options: &StorageOptions) -> bool {
186 storage_options
187 .0
188 .get("s3_express")
189 .map(|v| v == "true")
190 .unwrap_or(false)
191 || url.authority().ends_with("--x-s3")
192}
193
194async fn resolve_s3_region(
202 url: &Url,
203 storage_options: &HashMap<AmazonS3ConfigKey, String>,
204) -> Result<Option<String>> {
205 if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
206 Ok(Some(region.clone()))
207 } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
208 let bucket = url.host_str().ok_or_else(|| {
211 Error::invalid_input(
212 format!("Could not parse bucket from url: {}", url),
213 location!(),
214 )
215 })?;
216
217 let mut client_options = ClientOptions::default();
218 for (key, value) in storage_options {
219 if let AmazonS3ConfigKey::Client(client_key) = key {
220 client_options = client_options.with_config(*client_key, value.clone());
221 }
222 }
223
224 let bucket_region =
225 object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
226 Ok(Some(bucket_region))
227 } else {
228 Ok(None)
229 }
230}
231
232pub async fn build_aws_credential(
249 credentials_refresh_offset: Duration,
250 credentials: Option<AwsCredentialProvider>,
251 storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
252 region: Option<String>,
253 storage_options_accessor: Option<Arc<StorageOptionsAccessor>>,
254) -> Result<(AwsCredentialProvider, String)> {
255 use aws_config::meta::region::RegionProviderChain;
256 const DEFAULT_REGION: &str = "us-west-2";
257
258 let region = if let Some(region) = region {
259 region
260 } else {
261 RegionProviderChain::default_provider()
262 .or_else(DEFAULT_REGION)
263 .region()
264 .await
265 .map(|r| r.as_ref().to_string())
266 .unwrap_or(DEFAULT_REGION.to_string())
267 };
268
269 let storage_options_credentials = storage_options.and_then(extract_static_s3_credentials);
270
271 if let Some(accessor) = storage_options_accessor {
273 if accessor.has_provider() {
274 if let Some(creds) = credentials {
276 return Ok((creds, region));
277 }
278 return Ok((
280 Arc::new(DynamicStorageOptionsCredentialProvider::new(accessor)),
281 region,
282 ));
283 }
284 }
285
286 if let Some(creds) = credentials {
288 Ok((creds, region))
289 } else if let Some(creds) = storage_options_credentials {
290 Ok((Arc::new(creds), region))
291 } else {
292 let credentials_provider = DefaultCredentialsChain::builder().build().await;
293
294 Ok((
295 Arc::new(AwsCredentialAdapter::new(
296 Arc::new(credentials_provider),
297 credentials_refresh_offset,
298 )),
299 region,
300 ))
301 }
302}
303
304fn extract_static_s3_credentials(
305 options: &HashMap<AmazonS3ConfigKey, String>,
306) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
307 let key_id = options.get(&AmazonS3ConfigKey::AccessKeyId).cloned();
308 let secret_key = options.get(&AmazonS3ConfigKey::SecretAccessKey).cloned();
309 let token = options.get(&AmazonS3ConfigKey::Token).cloned();
310 match (key_id, secret_key, token) {
311 (Some(key_id), Some(secret_key), token) => {
312 Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
313 key_id,
314 secret_key,
315 token,
316 }))
317 }
318 _ => None,
319 }
320}
321
322#[derive(Debug)]
324pub struct AwsCredentialAdapter {
325 pub inner: Arc<dyn ProvideCredentials>,
326
327 cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,
329
330 credentials_refresh_offset: Duration,
332}
333
334impl AwsCredentialAdapter {
335 pub fn new(
336 provider: Arc<dyn ProvideCredentials>,
337 credentials_refresh_offset: Duration,
338 ) -> Self {
339 Self {
340 inner: provider,
341 cache: Arc::new(RwLock::new(HashMap::new())),
342 credentials_refresh_offset,
343 }
344 }
345}
346
/// Key under which `AwsCredentialAdapter` stores its single cached credential entry.
const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";
348
/// Converts a `std::time::SystemTime` into this module's `SystemTime`, keeping
/// the same offset from the Unix epoch.
///
/// Under `cfg(test)` the target type is `mock_instant`'s `SystemTime`, so
/// real timestamps (e.g. credential expiry) can be compared with the mocked clock.
///
/// # Panics
/// Panics if `time` is earlier than the Unix epoch.
fn to_system_time(time: std::time::SystemTime) -> SystemTime {
    UNIX_EPOCH
        + time
            .duration_since(std::time::UNIX_EPOCH)
            .expect("time should be after UNIX_EPOCH")
}
356
357#[async_trait::async_trait]
358impl CredentialProvider for AwsCredentialAdapter {
359 type Credential = ObjectStoreAwsCredential;
360
361 async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
362 let cached_creds = {
363 let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
364 let expired = cache_value
365 .clone()
366 .map(|cred| {
367 cred.expiry()
368 .map(|exp| {
369 to_system_time(exp)
370 .checked_sub(self.credentials_refresh_offset)
371 .expect("this time should always be valid")
372 < SystemTime::now()
373 })
374 .unwrap_or(false)
376 })
377 .unwrap_or(true); if expired {
379 None
380 } else {
381 cache_value.clone()
382 }
383 };
384
385 if let Some(creds) = cached_creds {
386 Ok(Arc::new(Self::Credential {
387 key_id: creds.access_key_id().to_string(),
388 secret_key: creds.secret_access_key().to_string(),
389 token: creds.session_token().map(|s| s.to_string()),
390 }))
391 } else {
392 let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
393 |e| Error::Internal {
394 message: format!("Failed to get AWS credentials: {:?}", e),
395 location: location!(),
396 },
397 )?);
398
399 self.cache
400 .write()
401 .await
402 .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
403
404 Ok(Arc::new(Self::Credential {
405 key_id: refreshed_creds.access_key_id().to_string(),
406 secret_key: refreshed_creds.secret_access_key().to_string(),
407 token: refreshed_creds.session_token().map(|s| s.to_string()),
408 }))
409 }
410 }
411}
412
413impl StorageOptions {
414 pub fn with_env_s3(&mut self) {
416 for (os_key, os_value) in std::env::vars_os() {
417 if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
418 if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
419 if !self.0.contains_key(config_key.as_ref()) {
420 self.0
421 .insert(config_key.as_ref().to_string(), value.to_string());
422 }
423 }
424 }
425 }
426 }
427
428 pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
430 self.0
431 .iter()
432 .filter_map(|(key, value)| {
433 let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
434 Some((s3_key, value.clone()))
435 })
436 .collect()
437 }
438}
439
440impl ObjectStoreParams {
441 pub fn with_aws_credentials(
443 aws_credentials: Option<AwsCredentialProvider>,
444 region: Option<String>,
445 ) -> Self {
446 let storage_options_accessor = region.map(|region| {
447 let opts: HashMap<String, String> =
448 [("region".into(), region)].iter().cloned().collect();
449 Arc::new(StorageOptionsAccessor::with_static_options(opts))
450 });
451 Self {
452 aws_credentials,
453 storage_options_accessor,
454 ..Default::default()
455 }
456 }
457}
458
459pub struct DynamicStorageOptionsCredentialProvider {
473 accessor: Arc<StorageOptionsAccessor>,
474}
475
476impl fmt::Debug for DynamicStorageOptionsCredentialProvider {
477 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
478 f.debug_struct("DynamicStorageOptionsCredentialProvider")
479 .field("accessor", &self.accessor)
480 .finish()
481 }
482}
483
484impl DynamicStorageOptionsCredentialProvider {
485 pub fn new(accessor: Arc<StorageOptionsAccessor>) -> Self {
487 Self { accessor }
488 }
489
490 pub fn from_provider(provider: Arc<dyn StorageOptionsProvider>) -> Self {
499 Self {
500 accessor: Arc::new(StorageOptionsAccessor::with_provider(provider)),
501 }
502 }
503
504 pub fn from_provider_with_initial(
514 provider: Arc<dyn StorageOptionsProvider>,
515 initial_options: HashMap<String, String>,
516 ) -> Self {
517 Self {
518 accessor: Arc::new(StorageOptionsAccessor::with_initial_and_provider(
519 initial_options,
520 provider,
521 )),
522 }
523 }
524}
525
526#[async_trait::async_trait]
527impl CredentialProvider for DynamicStorageOptionsCredentialProvider {
528 type Credential = ObjectStoreAwsCredential;
529
530 async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
531 let storage_options = self.accessor.get_storage_options().await.map_err(|e| {
532 object_store::Error::Generic {
533 store: "DynamicStorageOptionsCredentialProvider",
534 source: Box::new(e),
535 }
536 })?;
537
538 let s3_options = storage_options.as_s3_options();
539 let static_creds = extract_static_s3_credentials(&s3_options).ok_or_else(|| {
540 object_store::Error::Generic {
541 store: "DynamicStorageOptionsCredentialProvider",
542 source: "Missing required credentials in storage options".into(),
543 }
544 })?;
545
546 static_creds
547 .get_credential()
548 .await
549 .map_err(|e| object_store::Error::Generic {
550 store: "DynamicStorageOptionsCredentialProvider",
551 source: Box::new(e),
552 })
553 }
554}
555
#[cfg(test)]
mod tests {
    use crate::object_store::ObjectStoreRegistry;
    use mock_instant::thread_local::MockClock;
    use object_store::path::Path;
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::*;

    /// AWS credential provider that records whether it has been queried.
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait::async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }

    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // The provider has not been queried yet.
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // The read result is ignored; only the credential lookup matters.
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    #[test]
    fn test_s3_path_parsing() {
        let provider = AwsStoreProvider;

        let cases = [
            ("s3://bucket/path/to/file", "path/to/file"),
            // Non-ASCII path segments are preserved.
            ("s3://bucket/测试path/to/file", "测试path/to/file"),
            ("s3://bucket/path/&to/file", "path/&to/file"),
            ("s3://bucket/path/=to/file", "path/=to/file"),
            (
                "s3+ddb://bucket/path/to/file?ddbTableName=test",
                "path/to/file",
            ),
        ];

        for (uri, expected_path) in cases {
            let url = Url::parse(uri).unwrap();
            let path = provider.extract_path(&url).unwrap();
            let expected_path = Path::from(expected_path);
            assert_eq!(path, expected_path)
        }
    }

    #[test]
    fn test_is_s3_express() {
        let cases = [
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                false,
            ),
            ("s3://bucket/path/to/file", HashMap::from([]), false),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                // The `--x-s3` suffix wins over an explicit "false".
                true,
            ),
            ("s3://bucket--x-s3/path/to/file", HashMap::from([]), true),
        ];

        for (uri, storage_map, expected) in cases {
            let url = Url::parse(uri).unwrap();
            let storage_options = StorageOptions(storage_map);
            let is_s3_express = check_s3_express(&url, &storage_options);
            assert_eq!(is_s3_express, expected);
        }
    }

    #[tokio::test]
    async fn test_use_opendal_flag() {
        use crate::object_store::StorageOptionsAccessor;
        let provider = AwsStoreProvider;
        let url = Url::parse("s3://test-bucket/path").unwrap();
        let params_with_flag = ObjectStoreParams {
            storage_options_accessor: Some(Arc::new(StorageOptionsAccessor::with_static_options(
                HashMap::from([
                    ("use_opendal".to_string(), "true".to_string()),
                    ("region".to_string(), "us-west-2".to_string()),
                ]),
            ))),
            ..Default::default()
        };

        let store = provider
            .new_store(url.clone(), &params_with_flag)
            .await
            .unwrap();
        assert_eq!(store.scheme, "s3");
    }

    /// Storage options provider that returns numbered credentials
    /// (`AKID_<n>`, `SECRET_<n>`, `TOKEN_<n>`) and counts its calls.
    #[derive(Debug)]
    struct MockStorageOptionsProvider {
        call_count: Arc<RwLock<usize>>,
        expires_in_millis: Option<u64>,
    }

    impl MockStorageOptionsProvider {
        fn new(expires_in_millis: Option<u64>) -> Self {
            Self {
                call_count: Arc::new(RwLock::new(0)),
                expires_in_millis,
            }
        }

        async fn get_call_count(&self) -> usize {
            *self.call_count.read().await
        }
    }

    #[async_trait::async_trait]
    impl StorageOptionsProvider for MockStorageOptionsProvider {
        async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
            let count = {
                let mut c = self.call_count.write().await;
                *c += 1;
                *c
            };

            let mut options = HashMap::from([
                ("aws_access_key_id".to_string(), format!("AKID_{}", count)),
                (
                    "aws_secret_access_key".to_string(),
                    format!("SECRET_{}", count),
                ),
                ("aws_session_token".to_string(), format!("TOKEN_{}", count)),
            ]);

            if let Some(expires_in) = self.expires_in_millis {
                let now_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .unwrap()
                    .as_millis() as u64;
                let expires_at = now_ms + expires_in;
                options.insert("expires_at_millis".to_string(), expires_at.to_string());
            }

            Ok(Some(options))
        }

        fn provider_id(&self) -> String {
            let ptr = Arc::as_ptr(&self.call_count) as usize;
            format!("MockStorageOptionsProvider {{ id: {} }}", ptr)
        }
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_with_initial_cache() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        // 10 minutes.
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        // Initial credentials valid for 10 minutes with a 5 minute refresh offset.
        let expires_at = now_ms + 600_000;
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_CACHED".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_CACHED".to_string(),
            ),
            ("aws_session_token".to_string(), "TOKEN_CACHED".to_string()),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()),
        ]);

        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
            mock.clone(),
            initial_options,
        );

        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_CACHED");
        assert_eq!(cred.secret_key, "SECRET_CACHED");
        assert_eq!(cred.token, Some("TOKEN_CACHED".to_string()));

        // The initial options were used; the provider was not called.
        assert_eq!(mock.get_call_count().await, 0);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_with_expired_cache() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        // 10 minutes.
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        // Initial credentials expired one second ago.
        let expired_time = now_ms - 1_000;
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_EXPIRED".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_EXPIRED".to_string(),
            ),
            ("expires_at_millis".to_string(), expired_time.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()),
        ]);

        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
            mock.clone(),
            initial_options,
        );

        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));

        assert_eq!(mock.get_call_count().await, 1);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_refresh_lead_time() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        // Credentials expire in 30 seconds, which this test expects to be within
        // the default refresh lead time, so every call refreshes.
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(30_000)));

        let provider = DynamicStorageOptionsCredentialProvider::from_provider(mock.clone());

        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(mock.get_call_count().await, 1);

        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(mock.get_call_count().await, 2);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_no_initial_cache() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        // 2 minutes.
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(120_000)));

        let provider = DynamicStorageOptionsCredentialProvider::from_provider(mock.clone());

        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));
        assert_eq!(mock.get_call_count().await, 1);

        // Still fresh: served from the cache.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(mock.get_call_count().await, 1);

        // 90s later (30s before expiry) a refresh is expected.
        MockClock::set_system_time(Duration::from_secs(100_000 + 90));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(cred.secret_key, "SECRET_2");
        assert_eq!(cred.token, Some("TOKEN_2".to_string()));
        assert_eq!(mock.get_call_count().await, 2);

        // At the second credential's expiry another refresh is expected.
        MockClock::set_system_time(Duration::from_secs(100_000 + 210));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_3");
        assert_eq!(cred.secret_key, "SECRET_3");
        assert_eq!(mock.get_call_count().await, 3);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_with_initial_options() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        // 10 minutes.
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        // Initial credentials valid for 10 minutes with a 5 minute refresh offset.
        let expires_at = now_ms + 600_000;
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_INITIAL".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_INITIAL".to_string(),
            ),
            ("aws_session_token".to_string(), "TOKEN_INITIAL".to_string()),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()),
        ]);

        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
            mock.clone(),
            initial_options,
        );

        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_INITIAL");
        assert_eq!(cred.secret_key, "SECRET_INITIAL");
        assert_eq!(cred.token, Some("TOKEN_INITIAL".to_string()));

        assert_eq!(mock.get_call_count().await, 0);

        // 6 minutes in: inside the 5 minute refresh window of the initial credentials.
        MockClock::set_system_time(Duration::from_secs(100_000 + 360));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));

        assert_eq!(mock.get_call_count().await, 1);

        // 11 minutes in: the first refreshed credentials are due for refresh.
        MockClock::set_system_time(Duration::from_secs(100_000 + 660));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(cred.secret_key, "SECRET_2");
        assert_eq!(cred.token, Some("TOKEN_2".to_string()));

        assert_eq!(mock.get_call_count().await, 2);

        // 16 minutes in: the second refreshed credentials are due for refresh.
        MockClock::set_system_time(Duration::from_secs(100_000 + 960));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_3");
        assert_eq!(cred.secret_key, "SECRET_3");
        assert_eq!(cred.token, Some("TOKEN_3".to_string()));

        assert_eq!(mock.get_call_count().await, 3);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_concurrent_access() {
        // Effectively never expires.
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(9999999999999)));

        let provider = Arc::new(DynamicStorageOptionsCredentialProvider::from_provider(
            mock.clone(),
        ));

        let mut handles = vec![];
        for i in 0..10 {
            let provider = provider.clone();
            let handle = tokio::spawn(async move {
                let cred = provider.get_credential().await.unwrap();
                // Every task sees the same, first-fetched credentials.
                assert_eq!(cred.key_id, "AKID_1");
                assert_eq!(cred.secret_key, "SECRET_1");
                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
                i
            });
            handles.push(handle);
        }

        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        assert_eq!(results.len(), 10);
        for i in 0..10 {
            assert!(results.contains(&i));
        }

        let call_count = mock.get_call_count().await;
        assert_eq!(
            call_count, 1,
            "Provider should be called exactly once despite concurrent access"
        );
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_concurrent_refresh() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        // Initial credentials that expired long ago.
        let expires_at = now_ms - 1_000_000;
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_OLD".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_OLD".to_string(),
            ),
            ("aws_session_token".to_string(), "TOKEN_OLD".to_string()),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()),
        ]);

        // 1 hour.
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(3_600_000)));

        let provider = Arc::new(
            DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
                mock.clone(),
                initial_options,
            ),
        );

        // Many tasks race to refresh the expired credentials.
        let mut handles = vec![];
        for i in 0..20 {
            let provider = provider.clone();
            let handle = tokio::spawn(async move {
                let cred = provider.get_credential().await.unwrap();
                assert_eq!(cred.key_id, "AKID_1");
                assert_eq!(cred.secret_key, "SECRET_1");
                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
                i
            });
            handles.push(handle);
        }

        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        assert_eq!(results.len(), 20);

        let call_count = mock.get_call_count().await;
        assert!(
            call_count >= 1,
            "Provider should be called at least once, was called {} times",
            call_count
        );

        assert!(
            call_count < 10,
            "Provider should not be called too many times due to lock contention, was called {} times",
            call_count
        );
    }

    #[tokio::test]
    async fn test_explicit_aws_credentials_takes_precedence_over_accessor() {
        let mock_storage_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        let accessor = Arc::new(StorageOptionsAccessor::with_provider(
            mock_storage_provider.clone(),
        ));

        let explicit_cred_provider = Arc::new(MockAwsCredentialsProvider::default());

        let (result, _region) = build_aws_credential(
            Duration::from_secs(300),
            Some(explicit_cred_provider.clone() as AwsCredentialProvider),
            None, // no storage options
            Some("us-west-2".to_string()),
            Some(accessor),
        )
        .await
        .unwrap();

        let cred = result.get_credential().await.unwrap();

        assert!(explicit_cred_provider.called.load(Ordering::Relaxed));

        assert_eq!(
            mock_storage_provider.get_call_count().await,
            0,
            "Storage options provider should not be called when explicit aws_credentials is provided"
        );

        // The mock explicit provider returns empty credentials.
        assert_eq!(cred.key_id, "");
        assert_eq!(cred.secret_key, "");
    }

    #[tokio::test]
    async fn test_accessor_used_when_no_explicit_aws_credentials() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        let mock_storage_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        // Initial credentials valid for 10 minutes with a 5 minute refresh offset.
        let expires_at = now_ms + 600_000;
        let initial_options = HashMap::from([
            (
                "aws_access_key_id".to_string(),
                "AKID_FROM_ACCESSOR".to_string(),
            ),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_FROM_ACCESSOR".to_string(),
            ),
            (
                "aws_session_token".to_string(),
                "TOKEN_FROM_ACCESSOR".to_string(),
            ),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()),
        ]);

        let accessor = Arc::new(StorageOptionsAccessor::with_initial_and_provider(
            initial_options,
            mock_storage_provider.clone(),
        ));

        let (result, _region) = build_aws_credential(
            Duration::from_secs(300),
            None, // no explicit credentials
            None, // no storage options
            Some("us-west-2".to_string()),
            Some(accessor),
        )
        .await
        .unwrap();

        let cred = result.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_FROM_ACCESSOR");
        assert_eq!(cred.secret_key, "SECRET_FROM_ACCESSOR");

        assert_eq!(mock_storage_provider.get_call_count().await, 0);

        // 6 minutes in: inside the refresh window, so the provider is consulted.
        MockClock::set_system_time(Duration::from_secs(100_000 + 360));

        let cred = result.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");

        assert_eq!(mock_storage_provider.get_call_count().await, 1);
    }
}