1use std::{collections::HashMap, fmt, str::FromStr, sync::Arc, time::Duration};
5
6#[cfg(test)]
7use mock_instant::thread_local::{SystemTime, UNIX_EPOCH};
8
9#[cfg(not(test))]
10use std::time::{SystemTime, UNIX_EPOCH};
11
12use object_store::ObjectStore as OSObjectStore;
13use object_store_opendal::OpendalStore;
14use opendal::{Operator, services::S3};
15
16use aws_config::default_provider::credentials::DefaultCredentialsChain;
17use aws_credential_types::provider::ProvideCredentials;
18use object_store::{
19 ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
20 StaticCredentialProvider,
21 aws::{
22 AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
23 AwsCredentialProvider,
24 },
25};
26use tokio::sync::RwLock;
27use url::Url;
28
29use crate::object_store::{
30 DEFAULT_CLOUD_BLOCK_SIZE, DEFAULT_CLOUD_IO_PARALLELISM, DEFAULT_MAX_IOP_SIZE, ObjectStore,
31 ObjectStoreParams, ObjectStoreProvider, StorageOptions, StorageOptionsAccessor,
32 StorageOptionsProvider,
33};
34use lance_core::error::{Error, Result};
35
/// Object store provider for Amazon S3 URLs.
///
/// Depending on the `use_opendal` storage option, stores are built either with
/// the `object_store` crate's [`AmazonS3Builder`] or with an OpenDAL operator.
#[derive(Default, Debug)]
pub struct AwsStoreProvider;
38
39impl AwsStoreProvider {
40 async fn build_amazon_s3_store(
41 &self,
42 base_path: &mut Url,
43 params: &ObjectStoreParams,
44 storage_options: &StorageOptions,
45 is_s3_express: bool,
46 ) -> Result<Arc<dyn OSObjectStore>> {
47 let max_retries = storage_options.client_max_retries();
48 let retry_timeout = storage_options.client_retry_timeout();
49 let retry_config = RetryConfig {
50 backoff: Default::default(),
51 max_retries,
52 retry_timeout: Duration::from_secs(retry_timeout),
53 };
54
55 let mut s3_storage_options = storage_options.as_s3_options();
56 let region = resolve_s3_region(base_path, &s3_storage_options).await?;
57
58 let accessor = params.get_accessor();
60
61 let (aws_creds, region) = build_aws_credential(
62 params.s3_credentials_refresh_offset,
63 params.aws_credentials.clone(),
64 Some(&s3_storage_options),
65 region,
66 accessor,
67 )
68 .await?;
69
70 if is_s3_express {
72 s3_storage_options.insert(AmazonS3ConfigKey::S3Express, true.to_string());
73 }
74
75 base_path.set_scheme("s3").unwrap();
77 base_path.set_query(None);
78
79 let mut builder =
81 AmazonS3Builder::new().with_client_options(storage_options.client_options()?);
82 for (key, value) in s3_storage_options {
83 builder = builder.with_config(key, value);
84 }
85 builder = builder
86 .with_url(base_path.as_ref())
87 .with_credentials(aws_creds)
88 .with_retry(retry_config)
89 .with_region(region);
90
91 Ok(Arc::new(builder.build()?) as Arc<dyn OSObjectStore>)
92 }
93
94 async fn build_opendal_s3_store(
95 &self,
96 base_path: &Url,
97 storage_options: &StorageOptions,
98 ) -> Result<Arc<dyn OSObjectStore>> {
99 let bucket = base_path
100 .host_str()
101 .ok_or_else(|| Error::invalid_input("S3 URL must contain bucket name"))?
102 .to_string();
103
104 let prefix = base_path.path().trim_start_matches('/').to_string();
105
106 let mut config_map: HashMap<String, String> = storage_options.0.clone();
109
110 config_map.insert("bucket".to_string(), bucket);
112
113 if !prefix.is_empty() {
114 config_map.insert("root".to_string(), "/".to_string());
115 }
116
117 let operator = Operator::from_iter::<S3>(config_map)
118 .map_err(|e| Error::invalid_input(format!("Failed to create S3 operator: {:?}", e)))?
119 .finish();
120
121 Ok(Arc::new(OpendalStore::new(operator)) as Arc<dyn OSObjectStore>)
122 }
123}
124
125#[async_trait::async_trait]
126impl ObjectStoreProvider for AwsStoreProvider {
127 async fn new_store(
128 &self,
129 mut base_path: Url,
130 params: &ObjectStoreParams,
131 ) -> Result<ObjectStore> {
132 let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
133 let mut storage_options =
134 StorageOptions(params.storage_options().cloned().unwrap_or_default());
135 storage_options.with_env_s3();
136 let download_retry_count = storage_options.download_retry_count();
137
138 let use_opendal = storage_options
139 .0
140 .get("use_opendal")
141 .map(|v| v == "true")
142 .unwrap_or(false);
143
144 let is_s3_express = check_s3_express(&base_path, &storage_options);
146
147 let use_constant_size_upload_parts = storage_options
148 .0
149 .get("aws_endpoint")
150 .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
151 .unwrap_or(false);
152
153 let inner = if use_opendal {
154 self.build_opendal_s3_store(&base_path, &storage_options)
156 .await?
157 } else {
158 self.build_amazon_s3_store(&mut base_path, params, &storage_options, is_s3_express)
160 .await?
161 };
162
163 Ok(ObjectStore {
164 inner,
165 scheme: String::from(base_path.scheme()),
166 block_size,
167 max_iop_size: *DEFAULT_MAX_IOP_SIZE,
168 use_constant_size_upload_parts,
169 list_is_lexically_ordered: !is_s3_express,
170 io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
171 download_retry_count,
172 io_tracker: Default::default(),
173 store_prefix: self
174 .calculate_object_store_prefix(&base_path, params.storage_options())?,
175 })
176 }
177}
178
179fn check_s3_express(url: &Url, storage_options: &StorageOptions) -> bool {
181 storage_options
182 .0
183 .get("s3_express")
184 .map(|v| v == "true")
185 .unwrap_or(false)
186 || url.authority().ends_with("--x-s3")
187}
188
189async fn resolve_s3_region(
197 url: &Url,
198 storage_options: &HashMap<AmazonS3ConfigKey, String>,
199) -> Result<Option<String>> {
200 if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
201 Ok(Some(region.clone()))
202 } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
203 let bucket = url.host_str().ok_or_else(|| {
206 Error::invalid_input(format!("Could not parse bucket from url: {}", url))
207 })?;
208
209 let mut client_options = ClientOptions::default();
210 for (key, value) in storage_options {
211 if let AmazonS3ConfigKey::Client(client_key) = key {
212 client_options = client_options.with_config(*client_key, value.clone());
213 }
214 }
215
216 let bucket_region =
217 object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
218 Ok(Some(bucket_region))
219 } else {
220 Ok(None)
221 }
222}
223
224pub async fn build_aws_credential(
241 credentials_refresh_offset: Duration,
242 credentials: Option<AwsCredentialProvider>,
243 storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
244 region: Option<String>,
245 storage_options_accessor: Option<Arc<StorageOptionsAccessor>>,
246) -> Result<(AwsCredentialProvider, String)> {
247 use aws_config::meta::region::RegionProviderChain;
248 const DEFAULT_REGION: &str = "us-west-2";
249
250 let region = if let Some(region) = region {
251 region
252 } else {
253 RegionProviderChain::default_provider()
254 .or_else(DEFAULT_REGION)
255 .region()
256 .await
257 .map(|r| r.as_ref().to_string())
258 .unwrap_or(DEFAULT_REGION.to_string())
259 };
260
261 let storage_options_credentials = storage_options.and_then(extract_static_s3_credentials);
262
263 if let Some(accessor) = storage_options_accessor
267 && accessor.has_provider()
268 {
269 if let Some(creds) = credentials {
271 return Ok((creds, region));
272 }
273
274 let opts = accessor.get_storage_options().await?;
276 let s3_options = opts.as_s3_options();
277 if extract_static_s3_credentials(&s3_options).is_some() {
278 return Ok((
279 Arc::new(DynamicStorageOptionsCredentialProvider::new(accessor)),
280 region,
281 ));
282 }
283
284 log::debug!(
285 "Storage options from provider do not contain explicit AWS credentials, \
286 falling back to default AWS credentials chain."
287 );
288 }
289
290 if let Some(creds) = credentials {
292 Ok((creds, region))
293 } else if let Some(creds) = storage_options_credentials {
294 Ok((Arc::new(creds), region))
295 } else {
296 let credentials_provider = DefaultCredentialsChain::builder().build().await;
297
298 Ok((
299 Arc::new(AwsCredentialAdapter::new(
300 Arc::new(credentials_provider),
301 credentials_refresh_offset,
302 )),
303 region,
304 ))
305 }
306}
307
308fn extract_static_s3_credentials(
309 options: &HashMap<AmazonS3ConfigKey, String>,
310) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
311 let key_id = options.get(&AmazonS3ConfigKey::AccessKeyId).cloned();
312 let secret_key = options.get(&AmazonS3ConfigKey::SecretAccessKey).cloned();
313 let token = options.get(&AmazonS3ConfigKey::Token).cloned();
314 match (key_id, secret_key, token) {
315 (Some(key_id), Some(secret_key), token) => {
316 Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
317 key_id,
318 secret_key,
319 token,
320 }))
321 }
322 _ => None,
323 }
324}
325
/// Adapts an AWS SDK credentials provider to the `object_store`
/// [`CredentialProvider`] interface, caching the credentials it returns.
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    /// Underlying AWS SDK provider used to fetch fresh credentials.
    pub inner: Arc<dyn ProvideCredentials>,

    /// Most recently fetched credentials, stored under a single fixed key.
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    /// How long before their expiry cached credentials are treated as expired.
    credentials_refresh_offset: Duration,
}
337
338impl AwsCredentialAdapter {
339 pub fn new(
340 provider: Arc<dyn ProvideCredentials>,
341 credentials_refresh_offset: Duration,
342 ) -> Self {
343 Self {
344 inner: provider,
345 cache: Arc::new(RwLock::new(HashMap::new())),
346 credentials_refresh_offset,
347 }
348 }
349}
350
/// Key under which `AwsCredentialAdapter` stores its cached credentials.
const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";
352
/// Converts a `std::time::SystemTime` into the `SystemTime` type in scope for
/// this module (the mockable clock under `cfg(test)`, the std type otherwise).
///
/// # Panics
///
/// Panics if `time` is earlier than the Unix epoch.
fn to_system_time(time: std::time::SystemTime) -> SystemTime {
    // Re-anchor the offset from the std epoch onto this module's epoch.
    UNIX_EPOCH
        + time
            .duration_since(std::time::UNIX_EPOCH)
            .expect("time should be after UNIX_EPOCH")
}
360
361#[async_trait::async_trait]
362impl CredentialProvider for AwsCredentialAdapter {
363 type Credential = ObjectStoreAwsCredential;
364
365 async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
366 let cached_creds = {
367 let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
368 let expired = cache_value
369 .clone()
370 .map(|cred| {
371 cred.expiry()
372 .map(|exp| {
373 to_system_time(exp)
374 .checked_sub(self.credentials_refresh_offset)
375 .expect("this time should always be valid")
376 < SystemTime::now()
377 })
378 .unwrap_or(false)
380 })
381 .unwrap_or(true); if expired { None } else { cache_value.clone() }
383 };
384
385 if let Some(creds) = cached_creds {
386 Ok(Arc::new(Self::Credential {
387 key_id: creds.access_key_id().to_string(),
388 secret_key: creds.secret_access_key().to_string(),
389 token: creds.session_token().map(|s| s.to_string()),
390 }))
391 } else {
392 let refreshed_creds =
393 Arc::new(self.inner.provide_credentials().await.map_err(|e| {
394 Error::internal(format!("Failed to get AWS credentials: {:?}", e))
395 })?);
396
397 self.cache
398 .write()
399 .await
400 .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
401
402 Ok(Arc::new(Self::Credential {
403 key_id: refreshed_creds.access_key_id().to_string(),
404 secret_key: refreshed_creds.secret_access_key().to_string(),
405 token: refreshed_creds.session_token().map(|s| s.to_string()),
406 }))
407 }
408 }
409}
410
411impl StorageOptions {
412 pub fn with_env_s3(&mut self) {
414 for (os_key, os_value) in std::env::vars_os() {
415 if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str())
416 && let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase())
417 && !self.0.contains_key(config_key.as_ref())
418 {
419 self.0
420 .insert(config_key.as_ref().to_string(), value.to_string());
421 }
422 }
423 }
424
425 pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
427 self.0
428 .iter()
429 .filter_map(|(key, value)| {
430 let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
431 Some((s3_key, value.clone()))
432 })
433 .collect()
434 }
435}
436
437impl ObjectStoreParams {
438 pub fn with_aws_credentials(
440 aws_credentials: Option<AwsCredentialProvider>,
441 region: Option<String>,
442 ) -> Self {
443 let storage_options_accessor = region.map(|region| {
444 let opts: HashMap<String, String> =
445 [("region".into(), region)].iter().cloned().collect();
446 Arc::new(StorageOptionsAccessor::with_static_options(opts))
447 });
448 Self {
449 aws_credentials,
450 storage_options_accessor,
451 ..Default::default()
452 }
453 }
454}
455
/// AWS credential provider that reads credentials from storage options
/// obtained through a [`StorageOptionsAccessor`] on every request.
pub struct DynamicStorageOptionsCredentialProvider {
    /// Source of the storage options that the credentials are extracted from.
    accessor: Arc<StorageOptionsAccessor>,
}
472
473impl fmt::Debug for DynamicStorageOptionsCredentialProvider {
474 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
475 f.debug_struct("DynamicStorageOptionsCredentialProvider")
476 .field("accessor", &self.accessor)
477 .finish()
478 }
479}
480
481impl DynamicStorageOptionsCredentialProvider {
482 pub fn new(accessor: Arc<StorageOptionsAccessor>) -> Self {
484 Self { accessor }
485 }
486
487 pub fn from_provider(provider: Arc<dyn StorageOptionsProvider>) -> Self {
496 Self {
497 accessor: Arc::new(StorageOptionsAccessor::with_provider(provider)),
498 }
499 }
500
501 pub fn from_provider_with_initial(
511 provider: Arc<dyn StorageOptionsProvider>,
512 initial_options: HashMap<String, String>,
513 ) -> Self {
514 Self {
515 accessor: Arc::new(StorageOptionsAccessor::with_initial_and_provider(
516 initial_options,
517 provider,
518 )),
519 }
520 }
521}
522
523#[async_trait::async_trait]
524impl CredentialProvider for DynamicStorageOptionsCredentialProvider {
525 type Credential = ObjectStoreAwsCredential;
526
527 async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
528 let storage_options = self.accessor.get_storage_options().await.map_err(|e| {
529 object_store::Error::Generic {
530 store: "DynamicStorageOptionsCredentialProvider",
531 source: Box::new(e),
532 }
533 })?;
534
535 let s3_options = storage_options.as_s3_options();
536 let static_creds = extract_static_s3_credentials(&s3_options).ok_or_else(|| {
537 object_store::Error::Generic {
538 store: "DynamicStorageOptionsCredentialProvider",
539 source: "Missing required credentials in storage options".into(),
540 }
541 })?;
542
543 static_creds
544 .get_credential()
545 .await
546 .map_err(|e| object_store::Error::Generic {
547 store: "DynamicStorageOptionsCredentialProvider",
548 source: Box::new(e),
549 })
550 }
551}
552
553#[cfg(test)]
554mod tests {
555 use crate::object_store::ObjectStoreRegistry;
556 use mock_instant::thread_local::MockClock;
557 use object_store::path::Path;
558 use std::sync::atomic::{AtomicBool, Ordering};
559
560 use super::*;
561
    /// Test credential provider that records whether `get_credential` was
    /// invoked.
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        /// Set to `true` once `get_credential` has been called.
        called: AtomicBool,
    }
566
567 #[async_trait::async_trait]
568 impl CredentialProvider for MockAwsCredentialsProvider {
569 type Credential = ObjectStoreAwsCredential;
570
571 async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
572 self.called.store(true, Ordering::Relaxed);
573 Ok(Arc::new(Self::Credential {
574 key_id: "".to_string(),
575 secret_key: "".to_string(),
576 token: None,
577 }))
578 }
579 }
580
581 #[tokio::test]
582 async fn test_injected_aws_creds_option_is_used() {
583 let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
584 let registry = Arc::new(ObjectStoreRegistry::default());
585
586 let params = ObjectStoreParams {
587 aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
588 ..ObjectStoreParams::default()
589 };
590
591 assert!(!mock_provider.called.load(Ordering::Relaxed));
593
594 let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", ¶ms)
595 .await
596 .unwrap();
597
598 let _ = store
600 .open(&Path::parse("/").unwrap())
601 .await
602 .unwrap()
603 .get_range(0..1)
604 .await;
605
606 assert!(mock_provider.called.load(Ordering::Relaxed));
608 }
609
610 #[test]
611 fn test_s3_path_parsing() {
612 let provider = AwsStoreProvider;
613
614 let cases = [
615 ("s3://bucket/path/to/file", "path/to/file"),
616 ("s3://bucket/测试path/to/file", "测试path/to/file"),
618 ("s3://bucket/path/&to/file", "path/&to/file"),
619 ("s3://bucket/path/=to/file", "path/=to/file"),
620 (
621 "s3+ddb://bucket/path/to/file?ddbTableName=test",
622 "path/to/file",
623 ),
624 ];
625
626 for (uri, expected_path) in cases {
627 let url = Url::parse(uri).unwrap();
628 let path = provider.extract_path(&url).unwrap();
629 let expected_path = Path::from(expected_path);
630 assert_eq!(path, expected_path)
631 }
632 }
633
634 #[test]
635 fn test_is_s3_express() {
636 let cases = [
637 (
638 "s3://bucket/path/to/file",
639 HashMap::from([("s3_express".to_string(), "true".to_string())]),
640 true,
641 ),
642 (
643 "s3://bucket/path/to/file",
644 HashMap::from([("s3_express".to_string(), "false".to_string())]),
645 false,
646 ),
647 ("s3://bucket/path/to/file", HashMap::from([]), false),
648 (
649 "s3://bucket--x-s3/path/to/file",
650 HashMap::from([("s3_express".to_string(), "true".to_string())]),
651 true,
652 ),
653 (
654 "s3://bucket--x-s3/path/to/file",
655 HashMap::from([("s3_express".to_string(), "false".to_string())]),
656 true, ),
658 ("s3://bucket--x-s3/path/to/file", HashMap::from([]), true),
659 ];
660
661 for (uri, storage_map, expected) in cases {
662 let url = Url::parse(uri).unwrap();
663 let storage_options = StorageOptions(storage_map);
664 let is_s3_express = check_s3_express(&url, &storage_options);
665 assert_eq!(is_s3_express, expected);
666 }
667 }
668
669 #[tokio::test]
670 async fn test_use_opendal_flag() {
671 use crate::object_store::StorageOptionsAccessor;
672 let provider = AwsStoreProvider;
673 let url = Url::parse("s3://test-bucket/path").unwrap();
674 let params_with_flag = ObjectStoreParams {
675 storage_options_accessor: Some(Arc::new(StorageOptionsAccessor::with_static_options(
676 HashMap::from([
677 ("use_opendal".to_string(), "true".to_string()),
678 ("region".to_string(), "us-west-2".to_string()),
679 ]),
680 ))),
681 ..Default::default()
682 };
683
684 let store = provider
685 .new_store(url.clone(), ¶ms_with_flag)
686 .await
687 .unwrap();
688 assert_eq!(store.scheme, "s3");
689 }
690
    /// Test storage options provider that hands out numbered credentials and
    /// counts how many times it has been asked for options.
    #[derive(Debug)]
    struct MockStorageOptionsProvider {
        /// Number of `fetch_storage_options` calls made so far.
        call_count: Arc<RwLock<usize>>,
        /// When set, every response carries an `expires_at_millis` entry this
        /// many milliseconds after the current (mock) system time.
        expires_in_millis: Option<u64>,
    }
696
697 impl MockStorageOptionsProvider {
698 fn new(expires_in_millis: Option<u64>) -> Self {
699 Self {
700 call_count: Arc::new(RwLock::new(0)),
701 expires_in_millis,
702 }
703 }
704
705 async fn get_call_count(&self) -> usize {
706 *self.call_count.read().await
707 }
708 }
709
710 #[async_trait::async_trait]
711 impl StorageOptionsProvider for MockStorageOptionsProvider {
712 async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
713 let count = {
714 let mut c = self.call_count.write().await;
715 *c += 1;
716 *c
717 };
718
719 let mut options = HashMap::from([
720 ("aws_access_key_id".to_string(), format!("AKID_{}", count)),
721 (
722 "aws_secret_access_key".to_string(),
723 format!("SECRET_{}", count),
724 ),
725 ("aws_session_token".to_string(), format!("TOKEN_{}", count)),
726 ]);
727
728 if let Some(expires_in) = self.expires_in_millis {
729 let now_ms = SystemTime::now()
730 .duration_since(UNIX_EPOCH)
731 .unwrap()
732 .as_millis() as u64;
733 let expires_at = now_ms + expires_in;
734 options.insert("expires_at_millis".to_string(), expires_at.to_string());
735 }
736
737 Ok(Some(options))
738 }
739
740 fn provider_id(&self) -> String {
741 let ptr = Arc::as_ptr(&self.call_count) as usize;
742 format!("MockStorageOptionsProvider {{ id: {} }}", ptr)
743 }
744 }
745
746 #[tokio::test]
747 async fn test_dynamic_credential_provider_with_initial_cache() {
748 MockClock::set_system_time(Duration::from_secs(100_000));
749
750 let now_ms = MockClock::system_time().as_millis() as u64;
751
752 let mock = Arc::new(MockStorageOptionsProvider::new(Some(
754 600_000, )));
756
757 let expires_at = now_ms + 600_000; let initial_options = HashMap::from([
760 ("aws_access_key_id".to_string(), "AKID_CACHED".to_string()),
761 (
762 "aws_secret_access_key".to_string(),
763 "SECRET_CACHED".to_string(),
764 ),
765 ("aws_session_token".to_string(), "TOKEN_CACHED".to_string()),
766 ("expires_at_millis".to_string(), expires_at.to_string()),
767 ("refresh_offset_millis".to_string(), "300000".to_string()), ]);
769
770 let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
771 mock.clone(),
772 initial_options,
773 );
774
775 let cred = provider.get_credential().await.unwrap();
777 assert_eq!(cred.key_id, "AKID_CACHED");
778 assert_eq!(cred.secret_key, "SECRET_CACHED");
779 assert_eq!(cred.token, Some("TOKEN_CACHED".to_string()));
780
781 assert_eq!(mock.get_call_count().await, 0);
783 }
784
785 #[tokio::test]
786 async fn test_dynamic_credential_provider_with_expired_cache() {
787 MockClock::set_system_time(Duration::from_secs(100_000));
788
789 let now_ms = MockClock::system_time().as_millis() as u64;
790
791 let mock = Arc::new(MockStorageOptionsProvider::new(Some(
793 600_000, )));
795
796 let expired_time = now_ms - 1_000; let initial_options = HashMap::from([
799 ("aws_access_key_id".to_string(), "AKID_EXPIRED".to_string()),
800 (
801 "aws_secret_access_key".to_string(),
802 "SECRET_EXPIRED".to_string(),
803 ),
804 ("expires_at_millis".to_string(), expired_time.to_string()),
805 ("refresh_offset_millis".to_string(), "300000".to_string()), ]);
807
808 let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
809 mock.clone(),
810 initial_options,
811 );
812
813 let cred = provider.get_credential().await.unwrap();
815 assert_eq!(cred.key_id, "AKID_1");
816 assert_eq!(cred.secret_key, "SECRET_1");
817 assert_eq!(cred.token, Some("TOKEN_1".to_string()));
818
819 assert_eq!(mock.get_call_count().await, 1);
821 }
822
    /// Credentials that are already close to expiry when fetched are expected
    /// to be fetched again on the very next request.
    #[tokio::test]
    async fn test_dynamic_credential_provider_refresh_lead_time() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            30_000, // each response expires 30 seconds after it is issued
        )));

        // No initial options are supplied, so the first request has to reach
        // the provider.
        let provider = DynamicStorageOptionsCredentialProvider::from_provider(mock.clone());

        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(mock.get_call_count().await, 1);

        // The clock has not moved, yet a second fetch is expected.
        // NOTE(review): presumably 30 s lies inside the accessor's default
        // refresh offset; confirm against `StorageOptionsAccessor`.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(mock.get_call_count().await, 2);
    }
850
    /// Without initial options the first request fetches credentials, later
    /// requests reuse them, and advancing the mock clock forces new fetches.
    #[tokio::test]
    async fn test_dynamic_credential_provider_no_initial_cache() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            120_000, // each response expires 2 minutes after it is issued
        )));

        let provider = DynamicStorageOptionsCredentialProvider::from_provider(mock.clone());

        // First request: nothing cached, so the provider is called.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));
        assert_eq!(mock.get_call_count().await, 1);

        // Second request at the same instant: served without another fetch.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(mock.get_call_count().await, 1);

        // 90 s later only 30 s of validity remain; a refresh is expected.
        // NOTE(review): presumably because of the accessor's default refresh
        // offset; confirm.
        MockClock::set_system_time(Duration::from_secs(100_000 + 90));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(cred.secret_key, "SECRET_2");
        assert_eq!(cred.token, Some("TOKEN_2".to_string()));
        assert_eq!(mock.get_call_count().await, 2);

        // At +210 s the second set (issued at +90 s, valid for 120 s) has
        // reached its expiry, so a third fetch is expected.
        MockClock::set_system_time(Duration::from_secs(100_000 + 210));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_3");
        assert_eq!(cred.secret_key, "SECRET_3");
        assert_eq!(mock.get_call_count().await, 3);
    }
891
    /// Initial options are used first; as the mock clock advances, each
    /// subsequent request is expected to fetch a new set from the provider.
    #[tokio::test]
    async fn test_dynamic_credential_provider_with_initial_options() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            600_000, // each response expires 10 minutes after it is issued
        )));

        let expires_at = now_ms + 600_000; // initial options expire in 10 minutes
        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_INITIAL".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_INITIAL".to_string(),
            ),
            ("aws_session_token".to_string(), "TOKEN_INITIAL".to_string()),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minutes
        ]);

        let provider = DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
            mock.clone(),
            initial_options,
        );

        // At t+0 the initial options are still usable.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_INITIAL");
        assert_eq!(cred.secret_key, "SECRET_INITIAL");
        assert_eq!(cred.token, Some("TOKEN_INITIAL".to_string()));

        // The provider has not been consulted yet.
        assert_eq!(mock.get_call_count().await, 0);

        // At t+360 s fewer than 5 minutes remain on the initial options, so
        // the first fetch is expected.
        MockClock::set_system_time(Duration::from_secs(100_000 + 360));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));

        assert_eq!(mock.get_call_count().await, 1);

        // At t+660 s a second fetch is expected.
        MockClock::set_system_time(Duration::from_secs(100_000 + 660));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(cred.secret_key, "SECRET_2");
        assert_eq!(cred.token, Some("TOKEN_2".to_string()));

        assert_eq!(mock.get_call_count().await, 2);

        // At t+960 s a third fetch is expected.
        MockClock::set_system_time(Duration::from_secs(100_000 + 960));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_3");
        assert_eq!(cred.secret_key, "SECRET_3");
        assert_eq!(cred.token, Some("TOKEN_3".to_string()));

        assert_eq!(mock.get_call_count().await, 3);
    }
962
963 #[tokio::test]
964 async fn test_dynamic_credential_provider_concurrent_access() {
965 let mock = Arc::new(MockStorageOptionsProvider::new(Some(9999999999999)));
967
968 let provider = Arc::new(DynamicStorageOptionsCredentialProvider::from_provider(
969 mock.clone(),
970 ));
971
972 let mut handles = vec![];
974 for i in 0..10 {
975 let provider = provider.clone();
976 let handle = tokio::spawn(async move {
977 let cred = provider.get_credential().await.unwrap();
978 assert_eq!(cred.key_id, "AKID_1");
980 assert_eq!(cred.secret_key, "SECRET_1");
981 assert_eq!(cred.token, Some("TOKEN_1".to_string()));
982 i });
984 handles.push(handle);
985 }
986
987 let results: Vec<_> = futures::future::join_all(handles)
989 .await
990 .into_iter()
991 .map(|r| r.unwrap())
992 .collect();
993
994 assert_eq!(results.len(), 10);
996 for i in 0..10 {
997 assert!(results.contains(&i));
998 }
999
1000 let call_count = mock.get_call_count().await;
1003 assert_eq!(
1004 call_count, 1,
1005 "Provider should be called exactly once despite concurrent access"
1006 );
1007 }
1008
    /// Twenty concurrent requests against expired initial options must all see
    /// refreshed credentials without a fetch being issued per request.
    #[tokio::test]
    async fn test_dynamic_credential_provider_concurrent_refresh() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        // The initial options expired 1000 seconds before "now".
        let expires_at = now_ms - 1_000_000;

        let initial_options = HashMap::from([
            ("aws_access_key_id".to_string(), "AKID_OLD".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_OLD".to_string(),
            ),
            ("aws_session_token".to_string(), "TOKEN_OLD".to_string()),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minutes
        ]);

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            3_600_000, // each response expires 1 hour after it is issued
        )));

        let provider = Arc::new(
            DynamicStorageOptionsCredentialProvider::from_provider_with_initial(
                mock.clone(),
                initial_options,
            ),
        );

        // Spawn 20 tasks that all request credentials at once; every one of
        // them must see the first refreshed set, never the expired one.
        let mut handles = vec![];
        for i in 0..20 {
            let provider = provider.clone();
            let handle = tokio::spawn(async move {
                let cred = provider.get_credential().await.unwrap();
                assert_eq!(cred.key_id, "AKID_1");
                assert_eq!(cred.secret_key, "SECRET_1");
                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
                i
            });
            handles.push(handle);
        }

        // Wait for every task; a panicked task fails the test here.
        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        assert_eq!(results.len(), 20);

        // At least one fetch must have happened to replace the expired options.
        let call_count = mock.get_call_count().await;
        assert!(
            call_count >= 1,
            "Provider should be called at least once, was called {} times",
            call_count
        );

        // The upper bound is deliberately loose: it only guards against one
        // fetch per request.
        assert!(
            call_count < 10,
            "Provider should not be called too many times due to lock contention, was called {} times",
            call_count
        );
    }
1082
1083 #[tokio::test]
1084 async fn test_explicit_aws_credentials_takes_precedence_over_accessor() {
1085 let mock_storage_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));
1087
1088 let accessor = Arc::new(StorageOptionsAccessor::with_provider(
1090 mock_storage_provider.clone(),
1091 ));
1092
1093 let explicit_cred_provider = Arc::new(MockAwsCredentialsProvider::default());
1095
1096 let (result, _region) = build_aws_credential(
1099 Duration::from_secs(300),
1100 Some(explicit_cred_provider.clone() as AwsCredentialProvider),
1101 None, Some("us-west-2".to_string()),
1103 Some(accessor),
1104 )
1105 .await
1106 .unwrap();
1107
1108 let cred = result.get_credential().await.unwrap();
1110
1111 assert!(explicit_cred_provider.called.load(Ordering::Relaxed));
1113
1114 assert_eq!(
1116 mock_storage_provider.get_call_count().await,
1117 0,
1118 "Storage options provider should not be called when explicit aws_credentials is provided"
1119 );
1120
1121 assert_eq!(cred.key_id, "");
1123 assert_eq!(cred.secret_key, "");
1124 }
1125
    /// Without explicit credentials, `build_aws_credential` must return a
    /// provider that reads from the accessor: first its initial options, then
    /// refreshed ones once the mock clock has advanced.
    #[tokio::test]
    async fn test_accessor_used_when_no_explicit_aws_credentials() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        let mock_storage_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        let expires_at = now_ms + 600_000; // initial options expire in 10 minutes
        let initial_options = HashMap::from([
            (
                "aws_access_key_id".to_string(),
                "AKID_FROM_ACCESSOR".to_string(),
            ),
            (
                "aws_secret_access_key".to_string(),
                "SECRET_FROM_ACCESSOR".to_string(),
            ),
            (
                "aws_session_token".to_string(),
                "TOKEN_FROM_ACCESSOR".to_string(),
            ),
            ("expires_at_millis".to_string(), expires_at.to_string()),
            ("refresh_offset_millis".to_string(), "300000".to_string()), // 5 minutes
        ]);

        // The accessor starts from the initial options and can fall back to
        // the mock provider.
        let accessor = Arc::new(StorageOptionsAccessor::with_initial_and_provider(
            initial_options,
            mock_storage_provider.clone(),
        ));

        let (result, _region) = build_aws_credential(
            Duration::from_secs(300),
            None, // no explicit aws_credentials
            None, // no static storage options
            Some("us-west-2".to_string()),
            Some(accessor),
        )
        .await
        .unwrap();

        // The first request is answered from the accessor's initial options.
        let cred = result.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_FROM_ACCESSOR");
        assert_eq!(cred.secret_key, "SECRET_FROM_ACCESSOR");

        // No fetch was needed so far.
        assert_eq!(mock_storage_provider.get_call_count().await, 0);

        // At t+360 s fewer than 5 minutes remain on the initial options.
        MockClock::set_system_time(Duration::from_secs(100_000 + 360));

        // The next request is expected to fetch from the mock provider.
        let cred = result.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");

        assert_eq!(mock_storage_provider.get_call_count().await, 1);
    }
1190}