1use std::{collections::HashMap, fmt, str::FromStr, sync::Arc, time::Duration};
5
6#[cfg(test)]
7use mock_instant::thread_local::{SystemTime, UNIX_EPOCH};
8
9#[cfg(not(test))]
10use std::time::{SystemTime, UNIX_EPOCH};
11
12use object_store::ObjectStore as OSObjectStore;
13use object_store_opendal::OpendalStore;
14use opendal::{services::S3, Operator};
15
16use aws_config::default_provider::credentials::DefaultCredentialsChain;
17use aws_credential_types::provider::ProvideCredentials;
18use object_store::{
19 aws::{
20 AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
21 AwsCredentialProvider,
22 },
23 ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
24 StaticCredentialProvider,
25};
26use snafu::location;
27use tokio::sync::RwLock;
28use url::Url;
29
30use crate::object_store::{
31 ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, StorageOptionsProvider,
32 DEFAULT_CLOUD_BLOCK_SIZE, DEFAULT_CLOUD_IO_PARALLELISM, DEFAULT_MAX_IOP_SIZE,
33};
34use lance_core::error::{Error, Result};
35
36#[derive(Default, Debug)]
37pub struct AwsStoreProvider;
38
39impl AwsStoreProvider {
40 async fn build_amazon_s3_store(
41 &self,
42 base_path: &mut Url,
43 params: &ObjectStoreParams,
44 storage_options: &StorageOptions,
45 is_s3_express: bool,
46 ) -> Result<Arc<dyn OSObjectStore>> {
47 let max_retries = storage_options.client_max_retries();
48 let retry_timeout = storage_options.client_retry_timeout();
49 let retry_config = RetryConfig {
50 backoff: Default::default(),
51 max_retries,
52 retry_timeout: Duration::from_secs(retry_timeout),
53 };
54
55 let mut s3_storage_options = storage_options.as_s3_options();
56 let region = resolve_s3_region(base_path, &s3_storage_options).await?;
57 let (aws_creds, region) = build_aws_credential(
58 params.s3_credentials_refresh_offset,
59 params.aws_credentials.clone(),
60 Some(&s3_storage_options),
61 region,
62 params.storage_options_provider.clone(),
63 storage_options.expires_at_millis(),
64 )
65 .await?;
66
67 if is_s3_express {
69 s3_storage_options.insert(AmazonS3ConfigKey::S3Express, true.to_string());
70 }
71
72 base_path.set_scheme("s3").unwrap();
74 base_path.set_query(None);
75
76 let mut builder = AmazonS3Builder::new();
78 for (key, value) in s3_storage_options {
79 builder = builder.with_config(key, value);
80 }
81 builder = builder
82 .with_url(base_path.as_ref())
83 .with_credentials(aws_creds)
84 .with_retry(retry_config)
85 .with_region(region);
86
87 Ok(Arc::new(builder.build()?) as Arc<dyn OSObjectStore>)
88 }
89
90 async fn build_opendal_s3_store(
91 &self,
92 base_path: &Url,
93 storage_options: &StorageOptions,
94 ) -> Result<Arc<dyn OSObjectStore>> {
95 let bucket = base_path
96 .host_str()
97 .ok_or_else(|| Error::invalid_input("S3 URL must contain bucket name", location!()))?
98 .to_string();
99
100 let prefix = base_path.path().trim_start_matches('/').to_string();
101
102 let mut config_map: HashMap<String, String> = storage_options.0.clone();
105
106 config_map.insert("bucket".to_string(), bucket);
108
109 if !prefix.is_empty() {
110 config_map.insert("root".to_string(), "/".to_string());
111 }
112
113 let operator = Operator::from_iter::<S3>(config_map)
114 .map_err(|e| {
115 Error::invalid_input(
116 format!("Failed to create S3 operator: {:?}", e),
117 location!(),
118 )
119 })?
120 .finish();
121
122 Ok(Arc::new(OpendalStore::new(operator)) as Arc<dyn OSObjectStore>)
123 }
124}
125
126#[async_trait::async_trait]
127impl ObjectStoreProvider for AwsStoreProvider {
128 async fn new_store(
129 &self,
130 mut base_path: Url,
131 params: &ObjectStoreParams,
132 ) -> Result<ObjectStore> {
133 let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
134 let mut storage_options =
135 StorageOptions(params.storage_options.clone().unwrap_or_default());
136 storage_options.with_env_s3();
137 let download_retry_count = storage_options.download_retry_count();
138
139 let use_opendal = storage_options
140 .0
141 .get("use_opendal")
142 .map(|v| v == "true")
143 .unwrap_or(false);
144
145 let is_s3_express = check_s3_express(&base_path, &storage_options);
147
148 let use_constant_size_upload_parts = storage_options
149 .0
150 .get("aws_endpoint")
151 .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
152 .unwrap_or(false);
153
154 let inner = if use_opendal {
155 self.build_opendal_s3_store(&base_path, &storage_options)
157 .await?
158 } else {
159 self.build_amazon_s3_store(&mut base_path, params, &storage_options, is_s3_express)
161 .await?
162 };
163
164 Ok(ObjectStore {
165 inner,
166 scheme: String::from(base_path.scheme()),
167 block_size,
168 max_iop_size: *DEFAULT_MAX_IOP_SIZE,
169 use_constant_size_upload_parts,
170 list_is_lexically_ordered: !is_s3_express,
171 io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
172 download_retry_count,
173 })
174 }
175}
176
177fn check_s3_express(url: &Url, storage_options: &StorageOptions) -> bool {
179 storage_options
180 .0
181 .get("s3_express")
182 .map(|v| v == "true")
183 .unwrap_or(false)
184 || url.authority().ends_with("--x-s3")
185}
186
187async fn resolve_s3_region(
195 url: &Url,
196 storage_options: &HashMap<AmazonS3ConfigKey, String>,
197) -> Result<Option<String>> {
198 if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
199 Ok(Some(region.clone()))
200 } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
201 let bucket = url.host_str().ok_or_else(|| {
204 Error::invalid_input(
205 format!("Could not parse bucket from url: {}", url),
206 location!(),
207 )
208 })?;
209
210 let mut client_options = ClientOptions::default();
211 for (key, value) in storage_options {
212 if let AmazonS3ConfigKey::Client(client_key) = key {
213 client_options = client_options.with_config(*client_key, value.clone());
214 }
215 }
216
217 let bucket_region =
218 object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
219 Ok(Some(bucket_region))
220 } else {
221 Ok(None)
222 }
223}
224
225pub async fn build_aws_credential(
245 credentials_refresh_offset: Duration,
246 credentials: Option<AwsCredentialProvider>,
247 storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
248 region: Option<String>,
249 storage_options_provider: Option<Arc<dyn StorageOptionsProvider>>,
250 expires_at_millis: Option<u64>,
251) -> Result<(AwsCredentialProvider, String)> {
252 use aws_config::meta::region::RegionProviderChain;
254 const DEFAULT_REGION: &str = "us-west-2";
255
256 let region = if let Some(region) = region {
257 region
258 } else {
259 RegionProviderChain::default_provider()
260 .or_else(DEFAULT_REGION)
261 .region()
262 .await
263 .map(|r| r.as_ref().to_string())
264 .unwrap_or(DEFAULT_REGION.to_string())
265 };
266
267 let storage_options_credentials = storage_options.and_then(extract_static_s3_credentials);
268 if let Some(storage_options_provider) = storage_options_provider {
269 let creds = build_aws_credential_with_storage_options_provider(
270 storage_options_provider,
271 credentials_refresh_offset,
272 credentials,
273 storage_options_credentials,
274 expires_at_millis,
275 )
276 .await?;
277 Ok((creds, region))
278 } else if let Some(creds) = credentials {
279 Ok((creds, region))
280 } else if let Some(creds) = storage_options_credentials {
281 Ok((Arc::new(creds), region))
282 } else {
283 let credentials_provider = DefaultCredentialsChain::builder().build().await;
284
285 Ok((
286 Arc::new(AwsCredentialAdapter::new(
287 Arc::new(credentials_provider),
288 credentials_refresh_offset,
289 )),
290 region,
291 ))
292 }
293}
294
295async fn build_aws_credential_with_storage_options_provider(
296 storage_options_provider: Arc<dyn StorageOptionsProvider>,
297 credentials_refresh_offset: Duration,
298 credentials: Option<AwsCredentialProvider>,
299 storage_options_credentials: Option<StaticCredentialProvider<ObjectStoreAwsCredential>>,
300 expires_at_millis: Option<u64>,
301) -> Result<AwsCredentialProvider> {
302 match (expires_at_millis, credentials, storage_options_credentials) {
303 (Some(expires_at), Some(cred), _) => {
305 Ok(Arc::new(
306 DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
307 storage_options_provider,
308 credentials_refresh_offset,
309 cred.get_credential().await?,
310 expires_at,
311 ),
312 ))
313 }
314 (Some(expires_at), None, Some(cred)) => {
316 Ok(Arc::new(
317 DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
318 storage_options_provider,
319 credentials_refresh_offset,
320 cred.get_credential().await?,
321 expires_at,
322 ),
323 ))
324 }
325 (None, None, Some(_)) => Err(Error::IO {
327 source: Box::new(std::io::Error::other(
328 "expires_at_millis is required when using storage_options_provider with storage_options",
329 )),
330 location: location!(),
331 }),
332 (None, Some(_), _) => Err(Error::IO {
334 source: Box::new(std::io::Error::other(
335 "expires_at_millis is required when using storage_options_provider with credentials",
336 )),
337 location: location!(),
338 }),
339 (_, None, None) => Ok(Arc::new(DynamicStorageOptionsCredentialProvider::new(
341 storage_options_provider,
342 credentials_refresh_offset,
343 ))),
344 }
345}
346
347fn extract_static_s3_credentials(
348 options: &HashMap<AmazonS3ConfigKey, String>,
349) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
350 let key_id = options.get(&AmazonS3ConfigKey::AccessKeyId).cloned();
351 let secret_key = options.get(&AmazonS3ConfigKey::SecretAccessKey).cloned();
352 let token = options.get(&AmazonS3ConfigKey::Token).cloned();
353 match (key_id, secret_key, token) {
354 (Some(key_id), Some(secret_key), token) => {
355 Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
356 key_id,
357 secret_key,
358 token,
359 }))
360 }
361 _ => None,
362 }
363}
364
365#[derive(Debug)]
367pub struct AwsCredentialAdapter {
368 pub inner: Arc<dyn ProvideCredentials>,
369
370 cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,
372
373 credentials_refresh_offset: Duration,
375}
376
377impl AwsCredentialAdapter {
378 pub fn new(
379 provider: Arc<dyn ProvideCredentials>,
380 credentials_refresh_offset: Duration,
381 ) -> Self {
382 Self {
383 inner: provider,
384 cache: Arc::new(RwLock::new(HashMap::new())),
385 credentials_refresh_offset,
386 }
387 }
388}
389
/// Cache key under which `AwsCredentialAdapter` stores its credentials.
const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";

/// Convert a `std::time::SystemTime` into this module's `SystemTime` type.
///
/// Outside tests that type is `std::time::SystemTime` itself. Under `cfg(test)` it is the
/// mock clock's type. The conversion goes through the offset from the Unix epoch.
///
/// # Panics
/// Panics if `time` is earlier than the Unix epoch.
fn to_system_time(time: std::time::SystemTime) -> SystemTime {
    let offset = time
        .duration_since(std::time::UNIX_EPOCH)
        .expect("time should be after UNIX_EPOCH");
    UNIX_EPOCH + offset
}
399
400#[async_trait::async_trait]
401impl CredentialProvider for AwsCredentialAdapter {
402 type Credential = ObjectStoreAwsCredential;
403
404 async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
405 let cached_creds = {
406 let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
407 let expired = cache_value
408 .clone()
409 .map(|cred| {
410 cred.expiry()
411 .map(|exp| {
412 to_system_time(exp)
413 .checked_sub(self.credentials_refresh_offset)
414 .expect("this time should always be valid")
415 < SystemTime::now()
416 })
417 .unwrap_or(false)
419 })
420 .unwrap_or(true); if expired {
422 None
423 } else {
424 cache_value.clone()
425 }
426 };
427
428 if let Some(creds) = cached_creds {
429 Ok(Arc::new(Self::Credential {
430 key_id: creds.access_key_id().to_string(),
431 secret_key: creds.secret_access_key().to_string(),
432 token: creds.session_token().map(|s| s.to_string()),
433 }))
434 } else {
435 let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
436 |e| Error::Internal {
437 message: format!("Failed to get AWS credentials: {:?}", e),
438 location: location!(),
439 },
440 )?);
441
442 self.cache
443 .write()
444 .await
445 .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
446
447 Ok(Arc::new(Self::Credential {
448 key_id: refreshed_creds.access_key_id().to_string(),
449 secret_key: refreshed_creds.secret_access_key().to_string(),
450 token: refreshed_creds.session_token().map(|s| s.to_string()),
451 }))
452 }
453 }
454}
455
456impl StorageOptions {
457 pub fn with_env_s3(&mut self) {
459 for (os_key, os_value) in std::env::vars_os() {
460 if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
461 if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
462 if !self.0.contains_key(config_key.as_ref()) {
463 self.0
464 .insert(config_key.as_ref().to_string(), value.to_string());
465 }
466 }
467 }
468 }
469 }
470
471 pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
473 self.0
474 .iter()
475 .filter_map(|(key, value)| {
476 let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
477 Some((s3_key, value.clone()))
478 })
479 .collect()
480 }
481}
482
483impl ObjectStoreParams {
484 pub fn with_aws_credentials(
486 aws_credentials: Option<AwsCredentialProvider>,
487 region: Option<String>,
488 ) -> Self {
489 Self {
490 aws_credentials,
491 storage_options: region
492 .map(|region| [("region".into(), region)].iter().cloned().collect()),
493 ..Default::default()
494 }
495 }
496}
497
498pub struct DynamicStorageOptionsCredentialProvider {
512 provider: Arc<dyn StorageOptionsProvider>,
513 cache: Arc<RwLock<Option<CachedCredential>>>,
514 refresh_offset: Duration,
515}
516
517impl fmt::Debug for DynamicStorageOptionsCredentialProvider {
518 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
519 f.debug_struct("DynamicStorageOptionsCredentialProvider")
520 .field("provider", &self.provider)
521 .field("refresh_offset", &self.refresh_offset)
522 .finish()
523 }
524}
525
526#[derive(Debug, Clone)]
527struct CachedCredential {
528 credential: Arc<ObjectStoreAwsCredential>,
529 expires_at_millis: Option<u64>,
530}
531
532impl DynamicStorageOptionsCredentialProvider {
533 pub fn new(provider: Arc<dyn StorageOptionsProvider>, refresh_offset: Duration) -> Self {
539 Self {
540 provider,
541 cache: Arc::new(RwLock::new(None)),
542 refresh_offset,
543 }
544 }
545
546 pub fn new_with_initial_credential(
554 provider: Arc<dyn StorageOptionsProvider>,
555 refresh_offset: Duration,
556 credential: Arc<ObjectStoreAwsCredential>,
557 expires_at_millis: u64,
558 ) -> Self {
559 Self {
560 provider,
561 cache: Arc::new(RwLock::new(Some(CachedCredential {
562 credential,
563 expires_at_millis: Some(expires_at_millis),
564 }))),
565 refresh_offset,
566 }
567 }
568
569 fn needs_refresh(&self, cached: &Option<CachedCredential>) -> bool {
570 match cached {
571 None => true,
572 Some(cached_cred) => {
573 if let Some(expires_at_millis) = cached_cred.expires_at_millis {
574 let now_ms = SystemTime::now()
575 .duration_since(UNIX_EPOCH)
576 .unwrap_or(Duration::from_secs(0))
577 .as_millis() as u64;
578
579 let refresh_offset_millis = self.refresh_offset.as_millis() as u64;
581 now_ms + refresh_offset_millis >= expires_at_millis
582 } else {
583 false
585 }
586 }
587 }
588 }
589
590 async fn do_get_credential(&self) -> ObjectStoreResult<Option<Arc<ObjectStoreAwsCredential>>> {
591 {
593 let cached = self.cache.read().await;
594 if !self.needs_refresh(&cached) {
595 if let Some(cached_cred) = &*cached {
596 return Ok(Some(cached_cred.credential.clone()));
597 }
598 }
599 }
600
601 let Ok(mut cache) = self.cache.try_write() else {
603 return Ok(None);
604 };
605
606 if !self.needs_refresh(&cache) {
609 if let Some(cached_cred) = &*cache {
610 return Ok(Some(cached_cred.credential.clone()));
611 }
612 }
613
614 let storage_options_map = self
615 .provider
616 .fetch_storage_options()
617 .await
618 .map_err(|e| object_store::Error::Generic {
619 store: "DynamicStorageOptionsCredentialProvider",
620 source: Box::new(e),
621 })?
622 .ok_or_else(|| object_store::Error::Generic {
623 store: "DynamicStorageOptionsCredentialProvider",
624 source: "No storage options available".into(),
625 })?;
626
627 let storage_options = StorageOptions(storage_options_map);
628 let expires_at_millis = storage_options.expires_at_millis();
629 let s3_options = storage_options.as_s3_options();
630 let static_creds = extract_static_s3_credentials(&s3_options).ok_or_else(|| {
631 object_store::Error::Generic {
632 store: "DynamicStorageOptionsCredentialProvider",
633 source: "Missing required credentials in storage options".into(),
634 }
635 })?;
636
637 let credential =
638 static_creds
639 .get_credential()
640 .await
641 .map_err(|e| object_store::Error::Generic {
642 store: "DynamicStorageOptionsCredentialProvider",
643 source: Box::new(e),
644 })?;
645
646 *cache = Some(CachedCredential {
647 credential: credential.clone(),
648 expires_at_millis,
649 });
650
651 Ok(Some(credential))
652 }
653}
654
655#[async_trait::async_trait]
656impl CredentialProvider for DynamicStorageOptionsCredentialProvider {
657 type Credential = ObjectStoreAwsCredential;
658
659 async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
660 loop {
662 match self.do_get_credential().await? {
663 Some(cred) => return Ok(cred),
664 None => {
665 tokio::time::sleep(Duration::from_millis(10)).await;
667 continue;
668 }
669 }
670 }
671 }
672}
673
#[cfg(test)]
mod tests {
    use crate::object_store::ObjectStoreRegistry;
    use mock_instant::thread_local::MockClock;
    use object_store::path::Path;
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::*;

    /// Credential provider that records whether it has been asked for credentials.
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait::async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }

    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // The injected provider has not been consulted yet.
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // The request itself is expected to fail; only the credential lookup matters.
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    #[test]
    fn test_s3_path_parsing() {
        let provider = AwsStoreProvider;

        let cases = [
            ("s3://bucket/path/to/file", "path/to/file"),
            // Non-ASCII characters are preserved.
            ("s3://bucket/测试path/to/file", "测试path/to/file"),
            ("s3://bucket/path/&to/file", "path/&to/file"),
            ("s3://bucket/path/=to/file", "path/=to/file"),
            (
                "s3+ddb://bucket/path/to/file?ddbTableName=test",
                "path/to/file",
            ),
        ];

        for (uri, expected_path) in cases {
            let url = Url::parse(uri).unwrap();
            let path = provider.extract_path(&url).unwrap();
            let expected_path = Path::from(expected_path);
            assert_eq!(path, expected_path)
        }
    }

    #[test]
    fn test_is_s3_express() {
        let cases = [
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                false,
            ),
            ("s3://bucket/path/to/file", HashMap::from([]), false),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                // The `--x-s3` bucket suffix wins over the option.
                true,
            ),
            ("s3://bucket--x-s3/path/to/file", HashMap::from([]), true),
        ];

        for (uri, storage_map, expected) in cases {
            let url = Url::parse(uri).unwrap();
            let storage_options = StorageOptions(storage_map);
            let is_s3_express = check_s3_express(&url, &storage_options);
            assert_eq!(is_s3_express, expected);
        }
    }

    #[tokio::test]
    async fn test_use_opendal_flag() {
        let provider = AwsStoreProvider;
        let url = Url::parse("s3://test-bucket/path").unwrap();
        let params_with_flag = ObjectStoreParams {
            storage_options: Some(HashMap::from([
                ("use_opendal".to_string(), "true".to_string()),
                ("region".to_string(), "us-west-2".to_string()),
            ])),
            ..Default::default()
        };

        let store = provider
            .new_store(url.clone(), &params_with_flag)
            .await
            .unwrap();
        assert_eq!(store.scheme, "s3");
    }

    /// Storage options provider that returns numbered credentials on each call.
    #[derive(Debug)]
    struct MockStorageOptionsProvider {
        call_count: Arc<RwLock<usize>>,
        expires_in_millis: Option<u64>,
    }

    impl MockStorageOptionsProvider {
        fn new(expires_in_millis: Option<u64>) -> Self {
            Self {
                call_count: Arc::new(RwLock::new(0)),
                expires_in_millis,
            }
        }

        async fn get_call_count(&self) -> usize {
            *self.call_count.read().await
        }
    }

    #[async_trait::async_trait]
    impl StorageOptionsProvider for MockStorageOptionsProvider {
        async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
            let count = {
                let mut c = self.call_count.write().await;
                *c += 1;
                *c
            };

            let mut options = HashMap::from([
                ("aws_access_key_id".to_string(), format!("AKID_{}", count)),
                (
                    "aws_secret_access_key".to_string(),
                    format!("SECRET_{}", count),
                ),
                ("aws_session_token".to_string(), format!("TOKEN_{}", count)),
            ]);

            if let Some(expires_in) = self.expires_in_millis {
                let now_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .unwrap()
                    .as_millis() as u64;
                let expires_at = now_ms + expires_in;
                options.insert("expires_at_millis".to_string(), expires_at.to_string());
            }

            Ok(Some(options))
        }

        fn provider_id(&self) -> String {
            let ptr = Arc::as_ptr(&self.call_count) as usize;
            format!("MockStorageOptionsProvider {{ id: {} }}", ptr)
        }
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_with_initial_cache() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            600_000, // 10 minutes
        )));

        // Expires in 10 minutes, outside the 5 minute refresh window.
        let expires_at = now_ms + 600_000;
        let initial_cred = Arc::new(ObjectStoreAwsCredential {
            key_id: "AKID_CACHED".to_string(),
            secret_key: "SECRET_CACHED".to_string(),
            token: Some("TOKEN_CACHED".to_string()),
        });

        let provider = DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
            mock.clone(),
            Duration::from_secs(300), // 5 minute refresh offset
            initial_cred,
            expires_at,
        );

        // The seeded credential is served without calling the provider.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_CACHED");
        assert_eq!(cred.secret_key, "SECRET_CACHED");
        assert_eq!(cred.token, Some("TOKEN_CACHED".to_string()));

        assert_eq!(mock.get_call_count().await, 0);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_with_expired_cache() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            600_000, // 10 minutes
        )));

        // Already expired one second ago.
        let expired_time = now_ms - 1_000;
        let initial_cred = Arc::new(ObjectStoreAwsCredential {
            key_id: "AKID_EXPIRED".to_string(),
            secret_key: "SECRET_EXPIRED".to_string(),
            token: None,
        });

        let provider = DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
            mock.clone(),
            Duration::from_secs(300), // 5 minute refresh offset
            initial_cred,
            expired_time,
        );

        // The expired credential triggers a fetch.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));

        assert_eq!(mock.get_call_count().await, 1);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_refresh_lead_time() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            240_000, // 4 minutes, shorter than the refresh offset
        )));

        let provider = DynamicStorageOptionsCredentialProvider::new(
            mock.clone(),
            Duration::from_secs(300), // 5 minute refresh offset
        );

        // Fetched credentials always fall inside the refresh window, so every call refetches.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(mock.get_call_count().await, 1);

        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(mock.get_call_count().await, 2);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_no_initial_cache() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            600_000, // 10 minutes
        )));

        let provider = DynamicStorageOptionsCredentialProvider::new(
            mock.clone(),
            Duration::from_secs(300), // 5 minute refresh offset
        );

        // First call fetches.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));
        assert_eq!(mock.get_call_count().await, 1);

        // Second call at the same time is served from cache.
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(mock.get_call_count().await, 1);

        // Six minutes later the credential is inside the refresh window.
        MockClock::set_system_time(Duration::from_secs(100_000 + 360));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(cred.secret_key, "SECRET_2");
        assert_eq!(cred.token, Some("TOKEN_2".to_string()));
        assert_eq!(mock.get_call_count().await, 2);

        MockClock::set_system_time(Duration::from_secs(100_000 + 660));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_3");
        assert_eq!(cred.secret_key, "SECRET_3");
        assert_eq!(mock.get_call_count().await, 3);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_with_initial_credential() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            600_000, // 10 minutes
        )));

        let expires_at = now_ms + 600_000;
        let initial_cred = Arc::new(ObjectStoreAwsCredential {
            key_id: "AKID_INITIAL".to_string(),
            secret_key: "SECRET_INITIAL".to_string(),
            token: Some("TOKEN_INITIAL".to_string()),
        });

        let provider = DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
            mock.clone(),
            Duration::from_secs(300), // 5 minute refresh offset
            initial_cred,
            expires_at,
        );

        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_INITIAL");
        assert_eq!(cred.secret_key, "SECRET_INITIAL");
        assert_eq!(cred.token, Some("TOKEN_INITIAL".to_string()));

        assert_eq!(mock.get_call_count().await, 0);

        // The initial credential is now inside the refresh window.
        MockClock::set_system_time(Duration::from_secs(100_000 + 360));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_1");
        assert_eq!(cred.secret_key, "SECRET_1");
        assert_eq!(cred.token, Some("TOKEN_1".to_string()));

        assert_eq!(mock.get_call_count().await, 1);

        MockClock::set_system_time(Duration::from_secs(100_000 + 660));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_2");
        assert_eq!(cred.secret_key, "SECRET_2");
        assert_eq!(cred.token, Some("TOKEN_2".to_string()));

        assert_eq!(mock.get_call_count().await, 2);

        MockClock::set_system_time(Duration::from_secs(100_000 + 960));
        let cred = provider.get_credential().await.unwrap();
        assert_eq!(cred.key_id, "AKID_3");
        assert_eq!(cred.secret_key, "SECRET_3");
        assert_eq!(cred.token, Some("TOKEN_3".to_string()));

        assert_eq!(mock.get_call_count().await, 3);
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_concurrent_access() {
        // Effectively never expires.
        let mock = Arc::new(MockStorageOptionsProvider::new(Some(9999999999999)));

        let provider = Arc::new(DynamicStorageOptionsCredentialProvider::new(
            mock.clone(),
            Duration::from_secs(300),
        ));

        let mut handles = vec![];
        for i in 0..10 {
            let provider = provider.clone();
            let handle = tokio::spawn(async move {
                let cred = provider.get_credential().await.unwrap();
                assert_eq!(cred.key_id, "AKID_1");
                assert_eq!(cred.secret_key, "SECRET_1");
                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
                i
            });
            handles.push(handle);
        }

        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        assert_eq!(results.len(), 10);
        for i in 0..10 {
            assert!(results.contains(&i));
        }

        let call_count = mock.get_call_count().await;
        assert_eq!(
            call_count, 1,
            "Provider should be called exactly once despite concurrent access"
        );
    }

    #[tokio::test]
    async fn test_dynamic_credential_provider_concurrent_refresh() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;

        // Initial credential expired long ago, so every task sees it as stale.
        let expires_at = now_ms - 1_000_000;

        let initial_cred = Arc::new(ObjectStoreAwsCredential {
            key_id: "AKID_OLD".to_string(),
            secret_key: "SECRET_OLD".to_string(),
            token: Some("TOKEN_OLD".to_string()),
        });

        let mock = Arc::new(MockStorageOptionsProvider::new(Some(
            3_600_000, // 1 hour
        )));

        let provider = Arc::new(
            DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
                mock.clone(),
                Duration::from_secs(300),
                initial_cred,
                expires_at,
            ),
        );

        let mut handles = vec![];
        for i in 0..20 {
            let provider = provider.clone();
            let handle = tokio::spawn(async move {
                let cred = provider.get_credential().await.unwrap();
                assert_eq!(cred.key_id, "AKID_1");
                assert_eq!(cred.secret_key, "SECRET_1");
                assert_eq!(cred.token, Some("TOKEN_1".to_string()));
                i
            });
            handles.push(handle);
        }

        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        assert_eq!(results.len(), 20);

        let call_count = mock.get_call_count().await;
        assert!(
            call_count >= 1,
            "Provider should be called at least once, was called {} times",
            call_count
        );

        assert!(
            call_count < 10,
            "Provider should not be called too many times due to lock contention, was called {} times",
            call_count
        );
    }
}