1use std::{collections::HashMap, fmt, str::FromStr, sync::Arc, time::Duration};
5
6#[cfg(test)]
7use mock_instant::thread_local::{SystemTime, UNIX_EPOCH};
8
9#[cfg(not(test))]
10use std::time::{SystemTime, UNIX_EPOCH};
11
12use object_store::ObjectStore as OSObjectStore;
13use object_store_opendal::OpendalStore;
14use opendal::{services::S3, Operator};
15
16use aws_config::default_provider::credentials::DefaultCredentialsChain;
17use aws_credential_types::provider::ProvideCredentials;
18use object_store::{
19 aws::{
20 AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
21 AwsCredentialProvider,
22 },
23 ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
24 StaticCredentialProvider,
25};
26use snafu::location;
27use tokio::sync::RwLock;
28use url::Url;
29
30use crate::object_store::{
31 ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, StorageOptionsProvider,
32 DEFAULT_CLOUD_BLOCK_SIZE, DEFAULT_CLOUD_IO_PARALLELISM, DEFAULT_MAX_IOP_SIZE,
33};
34use lance_core::error::{Error, Result};
35
/// Object store provider for S3 URLs.
///
/// This is a stateless unit type; all of the work happens in its
/// [`ObjectStoreProvider`] implementation and the private builder helpers.
#[derive(Default, Debug)]
pub struct AwsStoreProvider;
38
39impl AwsStoreProvider {
40 async fn build_amazon_s3_store(
41 &self,
42 base_path: &mut Url,
43 params: &ObjectStoreParams,
44 storage_options: &StorageOptions,
45 is_s3_express: bool,
46 ) -> Result<Arc<dyn OSObjectStore>> {
47 let max_retries = storage_options.client_max_retries();
48 let retry_timeout = storage_options.client_retry_timeout();
49 let retry_config = RetryConfig {
50 backoff: Default::default(),
51 max_retries,
52 retry_timeout: Duration::from_secs(retry_timeout),
53 };
54
55 let mut s3_storage_options = storage_options.as_s3_options();
56 let region = resolve_s3_region(base_path, &s3_storage_options).await?;
57 let (aws_creds, region) = build_aws_credential(
58 params.s3_credentials_refresh_offset,
59 params.aws_credentials.clone(),
60 Some(&s3_storage_options),
61 region,
62 params.storage_options_provider.clone(),
63 storage_options.expires_at_millis(),
64 )
65 .await?;
66
67 if is_s3_express {
69 s3_storage_options.insert(AmazonS3ConfigKey::S3Express, true.to_string());
70 }
71
72 base_path.set_scheme("s3").unwrap();
74 base_path.set_query(None);
75
76 let mut builder = AmazonS3Builder::new();
78 for (key, value) in s3_storage_options {
79 builder = builder.with_config(key, value);
80 }
81 builder = builder
82 .with_url(base_path.as_ref())
83 .with_credentials(aws_creds)
84 .with_retry(retry_config)
85 .with_region(region);
86
87 Ok(Arc::new(builder.build()?) as Arc<dyn OSObjectStore>)
88 }
89
90 async fn build_opendal_s3_store(
91 &self,
92 base_path: &Url,
93 storage_options: &StorageOptions,
94 ) -> Result<Arc<dyn OSObjectStore>> {
95 let bucket = base_path
96 .host_str()
97 .ok_or_else(|| Error::invalid_input("S3 URL must contain bucket name", location!()))?
98 .to_string();
99
100 let prefix = base_path.path().trim_start_matches('/').to_string();
101
102 let mut config_map: HashMap<String, String> = storage_options.0.clone();
105
106 config_map.insert("bucket".to_string(), bucket);
108
109 if !prefix.is_empty() {
110 config_map.insert("root".to_string(), "/".to_string());
111 }
112
113 let operator = Operator::from_iter::<S3>(config_map)
114 .map_err(|e| {
115 Error::invalid_input(
116 format!("Failed to create S3 operator: {:?}", e),
117 location!(),
118 )
119 })?
120 .finish();
121
122 Ok(Arc::new(OpendalStore::new(operator)) as Arc<dyn OSObjectStore>)
123 }
124}
125
126#[async_trait::async_trait]
127impl ObjectStoreProvider for AwsStoreProvider {
128 async fn new_store(
129 &self,
130 mut base_path: Url,
131 params: &ObjectStoreParams,
132 ) -> Result<ObjectStore> {
133 let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
134 let mut storage_options =
135 StorageOptions(params.storage_options.clone().unwrap_or_default());
136 storage_options.with_env_s3();
137 let download_retry_count = storage_options.download_retry_count();
138
139 let use_opendal = storage_options
140 .0
141 .get("use_opendal")
142 .map(|v| v == "true")
143 .unwrap_or(false);
144
145 let is_s3_express = check_s3_express(&base_path, &storage_options);
147
148 let use_constant_size_upload_parts = storage_options
149 .0
150 .get("aws_endpoint")
151 .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
152 .unwrap_or(false);
153
154 let inner = if use_opendal {
155 self.build_opendal_s3_store(&base_path, &storage_options)
157 .await?
158 } else {
159 self.build_amazon_s3_store(&mut base_path, params, &storage_options, is_s3_express)
161 .await?
162 };
163
164 Ok(ObjectStore {
165 inner,
166 scheme: String::from(base_path.scheme()),
167 block_size,
168 max_iop_size: *DEFAULT_MAX_IOP_SIZE,
169 use_constant_size_upload_parts,
170 list_is_lexically_ordered: !is_s3_express,
171 io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
172 download_retry_count,
173 io_tracker: Default::default(),
174 })
175 }
176}
177
178fn check_s3_express(url: &Url, storage_options: &StorageOptions) -> bool {
180 storage_options
181 .0
182 .get("s3_express")
183 .map(|v| v == "true")
184 .unwrap_or(false)
185 || url.authority().ends_with("--x-s3")
186}
187
188async fn resolve_s3_region(
196 url: &Url,
197 storage_options: &HashMap<AmazonS3ConfigKey, String>,
198) -> Result<Option<String>> {
199 if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
200 Ok(Some(region.clone()))
201 } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
202 let bucket = url.host_str().ok_or_else(|| {
205 Error::invalid_input(
206 format!("Could not parse bucket from url: {}", url),
207 location!(),
208 )
209 })?;
210
211 let mut client_options = ClientOptions::default();
212 for (key, value) in storage_options {
213 if let AmazonS3ConfigKey::Client(client_key) = key {
214 client_options = client_options.with_config(*client_key, value.clone());
215 }
216 }
217
218 let bucket_region =
219 object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
220 Ok(Some(bucket_region))
221 } else {
222 Ok(None)
223 }
224}
225
226pub async fn build_aws_credential(
246 credentials_refresh_offset: Duration,
247 credentials: Option<AwsCredentialProvider>,
248 storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
249 region: Option<String>,
250 storage_options_provider: Option<Arc<dyn StorageOptionsProvider>>,
251 expires_at_millis: Option<u64>,
252) -> Result<(AwsCredentialProvider, String)> {
253 use aws_config::meta::region::RegionProviderChain;
255 const DEFAULT_REGION: &str = "us-west-2";
256
257 let region = if let Some(region) = region {
258 region
259 } else {
260 RegionProviderChain::default_provider()
261 .or_else(DEFAULT_REGION)
262 .region()
263 .await
264 .map(|r| r.as_ref().to_string())
265 .unwrap_or(DEFAULT_REGION.to_string())
266 };
267
268 let storage_options_credentials = storage_options.and_then(extract_static_s3_credentials);
269 if let Some(storage_options_provider) = storage_options_provider {
270 let creds = build_aws_credential_with_storage_options_provider(
271 storage_options_provider,
272 credentials_refresh_offset,
273 credentials,
274 storage_options_credentials,
275 expires_at_millis,
276 )
277 .await?;
278 Ok((creds, region))
279 } else if let Some(creds) = credentials {
280 Ok((creds, region))
281 } else if let Some(creds) = storage_options_credentials {
282 Ok((Arc::new(creds), region))
283 } else {
284 let credentials_provider = DefaultCredentialsChain::builder().build().await;
285
286 Ok((
287 Arc::new(AwsCredentialAdapter::new(
288 Arc::new(credentials_provider),
289 credentials_refresh_offset,
290 )),
291 region,
292 ))
293 }
294}
295
296async fn build_aws_credential_with_storage_options_provider(
297 storage_options_provider: Arc<dyn StorageOptionsProvider>,
298 credentials_refresh_offset: Duration,
299 credentials: Option<AwsCredentialProvider>,
300 storage_options_credentials: Option<StaticCredentialProvider<ObjectStoreAwsCredential>>,
301 expires_at_millis: Option<u64>,
302) -> Result<AwsCredentialProvider> {
303 match (expires_at_millis, credentials, storage_options_credentials) {
304 (Some(expires_at), Some(cred), _) => {
306 Ok(Arc::new(
307 DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
308 storage_options_provider,
309 credentials_refresh_offset,
310 cred.get_credential().await?,
311 expires_at,
312 ),
313 ))
314 }
315 (Some(expires_at), None, Some(cred)) => {
317 Ok(Arc::new(
318 DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
319 storage_options_provider,
320 credentials_refresh_offset,
321 cred.get_credential().await?,
322 expires_at,
323 ),
324 ))
325 }
326 (None, None, Some(_)) => Err(Error::IO {
328 source: Box::new(std::io::Error::other(
329 "expires_at_millis is required when using storage_options_provider with storage_options",
330 )),
331 location: location!(),
332 }),
333 (None, Some(_), _) => Err(Error::IO {
335 source: Box::new(std::io::Error::other(
336 "expires_at_millis is required when using storage_options_provider with credentials",
337 )),
338 location: location!(),
339 }),
340 (_, None, None) => Ok(Arc::new(DynamicStorageOptionsCredentialProvider::new(
342 storage_options_provider,
343 credentials_refresh_offset,
344 ))),
345 }
346}
347
348fn extract_static_s3_credentials(
349 options: &HashMap<AmazonS3ConfigKey, String>,
350) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
351 let key_id = options.get(&AmazonS3ConfigKey::AccessKeyId).cloned();
352 let secret_key = options.get(&AmazonS3ConfigKey::SecretAccessKey).cloned();
353 let token = options.get(&AmazonS3ConfigKey::Token).cloned();
354 match (key_id, secret_key, token) {
355 (Some(key_id), Some(secret_key), token) => {
356 Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
357 key_id,
358 secret_key,
359 token,
360 }))
361 }
362 _ => None,
363 }
364}
365
/// Adapts an AWS SDK [`ProvideCredentials`] implementation to the
/// `object_store` [`CredentialProvider`] interface, caching the last
/// credentials it obtained.
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    /// The wrapped AWS SDK credentials provider.
    pub inner: Arc<dyn ProvideCredentials>,

    /// Cache of previously fetched credentials, keyed by a string cache key.
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    /// How long before their expiry cached credentials are considered stale.
    credentials_refresh_offset: Duration,
}
377
378impl AwsCredentialAdapter {
379 pub fn new(
380 provider: Arc<dyn ProvideCredentials>,
381 credentials_refresh_offset: Duration,
382 ) -> Self {
383 Self {
384 inner: provider,
385 cache: Arc::new(RwLock::new(HashMap::new())),
386 credentials_refresh_offset,
387 }
388 }
389}
390
/// Key under which `AwsCredentialAdapter` stores its cached credentials.
const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";
392
/// Converts a `std::time::SystemTime` into this module's `SystemTime` type
/// (which is a different type in test builds) by re-basing it on `UNIX_EPOCH`.
///
/// # Panics
/// Panics if `time` is earlier than the Unix epoch.
fn to_system_time(time: std::time::SystemTime) -> SystemTime {
    let since_epoch = time.duration_since(std::time::UNIX_EPOCH);
    UNIX_EPOCH + since_epoch.expect("time should be after UNIX_EPOCH")
}
400
401#[async_trait::async_trait]
402impl CredentialProvider for AwsCredentialAdapter {
403 type Credential = ObjectStoreAwsCredential;
404
405 async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
406 let cached_creds = {
407 let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
408 let expired = cache_value
409 .clone()
410 .map(|cred| {
411 cred.expiry()
412 .map(|exp| {
413 to_system_time(exp)
414 .checked_sub(self.credentials_refresh_offset)
415 .expect("this time should always be valid")
416 < SystemTime::now()
417 })
418 .unwrap_or(false)
420 })
421 .unwrap_or(true); if expired {
423 None
424 } else {
425 cache_value.clone()
426 }
427 };
428
429 if let Some(creds) = cached_creds {
430 Ok(Arc::new(Self::Credential {
431 key_id: creds.access_key_id().to_string(),
432 secret_key: creds.secret_access_key().to_string(),
433 token: creds.session_token().map(|s| s.to_string()),
434 }))
435 } else {
436 let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
437 |e| Error::Internal {
438 message: format!("Failed to get AWS credentials: {:?}", e),
439 location: location!(),
440 },
441 )?);
442
443 self.cache
444 .write()
445 .await
446 .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
447
448 Ok(Arc::new(Self::Credential {
449 key_id: refreshed_creds.access_key_id().to_string(),
450 secret_key: refreshed_creds.secret_access_key().to_string(),
451 token: refreshed_creds.session_token().map(|s| s.to_string()),
452 }))
453 }
454 }
455}
456
457impl StorageOptions {
458 pub fn with_env_s3(&mut self) {
460 for (os_key, os_value) in std::env::vars_os() {
461 if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
462 if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
463 if !self.0.contains_key(config_key.as_ref()) {
464 self.0
465 .insert(config_key.as_ref().to_string(), value.to_string());
466 }
467 }
468 }
469 }
470 }
471
472 pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
474 self.0
475 .iter()
476 .filter_map(|(key, value)| {
477 let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
478 Some((s3_key, value.clone()))
479 })
480 .collect()
481 }
482}
483
484impl ObjectStoreParams {
485 pub fn with_aws_credentials(
487 aws_credentials: Option<AwsCredentialProvider>,
488 region: Option<String>,
489 ) -> Self {
490 Self {
491 aws_credentials,
492 storage_options: region
493 .map(|region| [("region".into(), region)].iter().cloned().collect()),
494 ..Default::default()
495 }
496 }
497}
498
/// Credential provider that obtains S3 credentials from a
/// [`StorageOptionsProvider`] and caches them until they are close to expiry.
pub struct DynamicStorageOptionsCredentialProvider {
    /// Source of fresh storage options (which carry the credentials).
    provider: Arc<dyn StorageOptionsProvider>,
    /// Most recently obtained credential and its expiry, if any.
    cache: Arc<RwLock<Option<CachedCredential>>>,
    /// Lead time before expiry at which the cached credential is refreshed.
    refresh_offset: Duration,
}
517
518impl fmt::Debug for DynamicStorageOptionsCredentialProvider {
519 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
520 f.debug_struct("DynamicStorageOptionsCredentialProvider")
521 .field("provider", &self.provider)
522 .field("refresh_offset", &self.refresh_offset)
523 .finish()
524 }
525}
526
/// A credential together with the time at which it expires.
#[derive(Debug, Clone)]
struct CachedCredential {
    /// The cached credential.
    credential: Arc<ObjectStoreAwsCredential>,
    /// Expiry as milliseconds since the Unix epoch; `None` means the
    /// credential is never considered due for refresh.
    expires_at_millis: Option<u64>,
}
532
533impl DynamicStorageOptionsCredentialProvider {
534 pub fn new(provider: Arc<dyn StorageOptionsProvider>, refresh_offset: Duration) -> Self {
540 Self {
541 provider,
542 cache: Arc::new(RwLock::new(None)),
543 refresh_offset,
544 }
545 }
546
547 pub fn new_with_initial_credential(
555 provider: Arc<dyn StorageOptionsProvider>,
556 refresh_offset: Duration,
557 credential: Arc<ObjectStoreAwsCredential>,
558 expires_at_millis: u64,
559 ) -> Self {
560 Self {
561 provider,
562 cache: Arc::new(RwLock::new(Some(CachedCredential {
563 credential,
564 expires_at_millis: Some(expires_at_millis),
565 }))),
566 refresh_offset,
567 }
568 }
569
570 fn needs_refresh(&self, cached: &Option<CachedCredential>) -> bool {
571 match cached {
572 None => true,
573 Some(cached_cred) => {
574 if let Some(expires_at_millis) = cached_cred.expires_at_millis {
575 let now_ms = SystemTime::now()
576 .duration_since(UNIX_EPOCH)
577 .unwrap_or(Duration::from_secs(0))
578 .as_millis() as u64;
579
580 let refresh_offset_millis = self.refresh_offset.as_millis() as u64;
582 now_ms + refresh_offset_millis >= expires_at_millis
583 } else {
584 false
586 }
587 }
588 }
589 }
590
591 async fn do_get_credential(&self) -> ObjectStoreResult<Option<Arc<ObjectStoreAwsCredential>>> {
592 {
594 let cached = self.cache.read().await;
595 if !self.needs_refresh(&cached) {
596 if let Some(cached_cred) = &*cached {
597 return Ok(Some(cached_cred.credential.clone()));
598 }
599 }
600 }
601
602 let Ok(mut cache) = self.cache.try_write() else {
604 return Ok(None);
605 };
606
607 if !self.needs_refresh(&cache) {
610 if let Some(cached_cred) = &*cache {
611 return Ok(Some(cached_cred.credential.clone()));
612 }
613 }
614
615 log::debug!(
616 "Refreshing S3 credentials from storage options provider: {}",
617 self.provider.provider_id()
618 );
619
620 let storage_options_map = self
621 .provider
622 .fetch_storage_options()
623 .await
624 .map_err(|e| object_store::Error::Generic {
625 store: "DynamicStorageOptionsCredentialProvider",
626 source: Box::new(e),
627 })?
628 .ok_or_else(|| object_store::Error::Generic {
629 store: "DynamicStorageOptionsCredentialProvider",
630 source: "No storage options available".into(),
631 })?;
632
633 let storage_options = StorageOptions(storage_options_map);
634 let expires_at_millis = storage_options.expires_at_millis();
635 let s3_options = storage_options.as_s3_options();
636 let static_creds = extract_static_s3_credentials(&s3_options).ok_or_else(|| {
637 object_store::Error::Generic {
638 store: "DynamicStorageOptionsCredentialProvider",
639 source: "Missing required credentials in storage options".into(),
640 }
641 })?;
642
643 let credential =
644 static_creds
645 .get_credential()
646 .await
647 .map_err(|e| object_store::Error::Generic {
648 store: "DynamicStorageOptionsCredentialProvider",
649 source: Box::new(e),
650 })?;
651
652 if let Some(expires_at) = expires_at_millis {
653 let now_ms = SystemTime::now()
654 .duration_since(UNIX_EPOCH)
655 .unwrap_or(Duration::from_secs(0))
656 .as_millis() as u64;
657 let expires_in_secs = (expires_at.saturating_sub(now_ms)) / 1000;
658 log::debug!(
659 "Successfully refreshed S3 credentials from provider: {}, credentials expire in {} seconds",
660 self.provider.provider_id(),
661 expires_in_secs
662 );
663 } else {
664 log::debug!(
665 "Successfully refreshed S3 credentials from provider: {} (no expiration)",
666 self.provider.provider_id()
667 );
668 }
669
670 *cache = Some(CachedCredential {
671 credential: credential.clone(),
672 expires_at_millis,
673 });
674
675 Ok(Some(credential))
676 }
677}
678
679#[async_trait::async_trait]
680impl CredentialProvider for DynamicStorageOptionsCredentialProvider {
681 type Credential = ObjectStoreAwsCredential;
682
683 async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
684 loop {
686 match self.do_get_credential().await? {
687 Some(cred) => return Ok(cred),
688 None => {
689 tokio::time::sleep(Duration::from_millis(10)).await;
691 continue;
692 }
693 }
694 }
695 }
696}
697
698#[cfg(test)]
699mod tests {
700 use crate::object_store::ObjectStoreRegistry;
701 use mock_instant::thread_local::MockClock;
702 use object_store::path::Path;
703 use std::sync::atomic::{AtomicBool, Ordering};
704
705 use super::*;
706
707 #[derive(Debug, Default)]
708 struct MockAwsCredentialsProvider {
709 called: AtomicBool,
710 }
711
712 #[async_trait::async_trait]
713 impl CredentialProvider for MockAwsCredentialsProvider {
714 type Credential = ObjectStoreAwsCredential;
715
716 async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
717 self.called.store(true, Ordering::Relaxed);
718 Ok(Arc::new(Self::Credential {
719 key_id: "".to_string(),
720 secret_key: "".to_string(),
721 token: None,
722 }))
723 }
724 }
725
726 #[tokio::test]
727 async fn test_injected_aws_creds_option_is_used() {
728 let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
729 let registry = Arc::new(ObjectStoreRegistry::default());
730
731 let params = ObjectStoreParams {
732 aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
733 ..ObjectStoreParams::default()
734 };
735
736 assert!(!mock_provider.called.load(Ordering::Relaxed));
738
739 let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", ¶ms)
740 .await
741 .unwrap();
742
743 let _ = store
745 .open(&Path::parse("/").unwrap())
746 .await
747 .unwrap()
748 .get_range(0..1)
749 .await;
750
751 assert!(mock_provider.called.load(Ordering::Relaxed));
753 }
754
755 #[test]
756 fn test_s3_path_parsing() {
757 let provider = AwsStoreProvider;
758
759 let cases = [
760 ("s3://bucket/path/to/file", "path/to/file"),
761 ("s3://bucket/测试path/to/file", "测试path/to/file"),
763 ("s3://bucket/path/&to/file", "path/&to/file"),
764 ("s3://bucket/path/=to/file", "path/=to/file"),
765 (
766 "s3+ddb://bucket/path/to/file?ddbTableName=test",
767 "path/to/file",
768 ),
769 ];
770
771 for (uri, expected_path) in cases {
772 let url = Url::parse(uri).unwrap();
773 let path = provider.extract_path(&url).unwrap();
774 let expected_path = Path::from(expected_path);
775 assert_eq!(path, expected_path)
776 }
777 }
778
779 #[test]
780 fn test_is_s3_express() {
781 let cases = [
782 (
783 "s3://bucket/path/to/file",
784 HashMap::from([("s3_express".to_string(), "true".to_string())]),
785 true,
786 ),
787 (
788 "s3://bucket/path/to/file",
789 HashMap::from([("s3_express".to_string(), "false".to_string())]),
790 false,
791 ),
792 ("s3://bucket/path/to/file", HashMap::from([]), false),
793 (
794 "s3://bucket--x-s3/path/to/file",
795 HashMap::from([("s3_express".to_string(), "true".to_string())]),
796 true,
797 ),
798 (
799 "s3://bucket--x-s3/path/to/file",
800 HashMap::from([("s3_express".to_string(), "false".to_string())]),
801 true, ),
803 ("s3://bucket--x-s3/path/to/file", HashMap::from([]), true),
804 ];
805
806 for (uri, storage_map, expected) in cases {
807 let url = Url::parse(uri).unwrap();
808 let storage_options = StorageOptions(storage_map);
809 let is_s3_express = check_s3_express(&url, &storage_options);
810 assert_eq!(is_s3_express, expected);
811 }
812 }
813
814 #[tokio::test]
815 async fn test_use_opendal_flag() {
816 let provider = AwsStoreProvider;
817 let url = Url::parse("s3://test-bucket/path").unwrap();
818 let params_with_flag = ObjectStoreParams {
819 storage_options: Some(HashMap::from([
820 ("use_opendal".to_string(), "true".to_string()),
821 ("region".to_string(), "us-west-2".to_string()),
822 ])),
823 ..Default::default()
824 };
825
826 let store = provider
827 .new_store(url.clone(), ¶ms_with_flag)
828 .await
829 .unwrap();
830 assert_eq!(store.scheme, "s3");
831 }
832
833 #[derive(Debug)]
834 struct MockStorageOptionsProvider {
835 call_count: Arc<RwLock<usize>>,
836 expires_in_millis: Option<u64>,
837 }
838
839 impl MockStorageOptionsProvider {
840 fn new(expires_in_millis: Option<u64>) -> Self {
841 Self {
842 call_count: Arc::new(RwLock::new(0)),
843 expires_in_millis,
844 }
845 }
846
847 async fn get_call_count(&self) -> usize {
848 *self.call_count.read().await
849 }
850 }
851
852 #[async_trait::async_trait]
853 impl StorageOptionsProvider for MockStorageOptionsProvider {
854 async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
855 let count = {
856 let mut c = self.call_count.write().await;
857 *c += 1;
858 *c
859 };
860
861 let mut options = HashMap::from([
862 ("aws_access_key_id".to_string(), format!("AKID_{}", count)),
863 (
864 "aws_secret_access_key".to_string(),
865 format!("SECRET_{}", count),
866 ),
867 ("aws_session_token".to_string(), format!("TOKEN_{}", count)),
868 ]);
869
870 if let Some(expires_in) = self.expires_in_millis {
871 let now_ms = SystemTime::now()
872 .duration_since(UNIX_EPOCH)
873 .unwrap()
874 .as_millis() as u64;
875 let expires_at = now_ms + expires_in;
876 options.insert("expires_at_millis".to_string(), expires_at.to_string());
877 }
878
879 Ok(Some(options))
880 }
881
882 fn provider_id(&self) -> String {
883 let ptr = Arc::as_ptr(&self.call_count) as usize;
884 format!("MockStorageOptionsProvider {{ id: {} }}", ptr)
885 }
886 }
887
888 #[tokio::test]
889 async fn test_dynamic_credential_provider_with_initial_cache() {
890 MockClock::set_system_time(Duration::from_secs(100_000));
891
892 let now_ms = MockClock::system_time().as_millis() as u64;
893
894 let mock = Arc::new(MockStorageOptionsProvider::new(Some(
896 600_000, )));
898
899 let expires_at = now_ms + 600_000; let initial_cred = Arc::new(ObjectStoreAwsCredential {
902 key_id: "AKID_CACHED".to_string(),
903 secret_key: "SECRET_CACHED".to_string(),
904 token: Some("TOKEN_CACHED".to_string()),
905 });
906
907 let provider = DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
908 mock.clone(),
909 Duration::from_secs(300), initial_cred,
911 expires_at,
912 );
913
914 let cred = provider.get_credential().await.unwrap();
916 assert_eq!(cred.key_id, "AKID_CACHED");
917 assert_eq!(cred.secret_key, "SECRET_CACHED");
918 assert_eq!(cred.token, Some("TOKEN_CACHED".to_string()));
919
920 assert_eq!(mock.get_call_count().await, 0);
922 }
923
924 #[tokio::test]
925 async fn test_dynamic_credential_provider_with_expired_cache() {
926 MockClock::set_system_time(Duration::from_secs(100_000));
927
928 let now_ms = MockClock::system_time().as_millis() as u64;
929
930 let mock = Arc::new(MockStorageOptionsProvider::new(Some(
932 600_000, )));
934
935 let expired_time = now_ms - 1_000; let initial_cred = Arc::new(ObjectStoreAwsCredential {
938 key_id: "AKID_EXPIRED".to_string(),
939 secret_key: "SECRET_EXPIRED".to_string(),
940 token: None,
941 });
942
943 let provider = DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
944 mock.clone(),
945 Duration::from_secs(300), initial_cred,
947 expired_time,
948 );
949
950 let cred = provider.get_credential().await.unwrap();
952 assert_eq!(cred.key_id, "AKID_1");
953 assert_eq!(cred.secret_key, "SECRET_1");
954 assert_eq!(cred.token, Some("TOKEN_1".to_string()));
955
956 assert_eq!(mock.get_call_count().await, 1);
958 }
959
960 #[tokio::test]
961 async fn test_dynamic_credential_provider_refresh_lead_time() {
962 MockClock::set_system_time(Duration::from_secs(100_000));
963
964 let mock = Arc::new(MockStorageOptionsProvider::new(Some(
966 240_000, )));
968
969 let provider = DynamicStorageOptionsCredentialProvider::new(
972 mock.clone(),
973 Duration::from_secs(300), );
975
976 let cred = provider.get_credential().await.unwrap();
980 assert_eq!(cred.key_id, "AKID_1");
981 assert_eq!(mock.get_call_count().await, 1);
982
983 let cred = provider.get_credential().await.unwrap();
987 assert_eq!(cred.key_id, "AKID_2");
988 assert_eq!(mock.get_call_count().await, 2);
989 }
990
991 #[tokio::test]
992 async fn test_dynamic_credential_provider_no_initial_cache() {
993 MockClock::set_system_time(Duration::from_secs(100_000));
994
995 let mock = Arc::new(MockStorageOptionsProvider::new(Some(
997 600_000, )));
999
1000 let provider = DynamicStorageOptionsCredentialProvider::new(
1002 mock.clone(),
1003 Duration::from_secs(300), );
1005
1006 let cred = provider.get_credential().await.unwrap();
1008 assert_eq!(cred.key_id, "AKID_1");
1009 assert_eq!(cred.secret_key, "SECRET_1");
1010 assert_eq!(cred.token, Some("TOKEN_1".to_string()));
1011 assert_eq!(mock.get_call_count().await, 1);
1012
1013 let cred = provider.get_credential().await.unwrap();
1015 assert_eq!(cred.key_id, "AKID_1");
1016 assert_eq!(mock.get_call_count().await, 1); MockClock::set_system_time(Duration::from_secs(100_000 + 360));
1020 let cred = provider.get_credential().await.unwrap();
1021 assert_eq!(cred.key_id, "AKID_2");
1022 assert_eq!(cred.secret_key, "SECRET_2");
1023 assert_eq!(cred.token, Some("TOKEN_2".to_string()));
1024 assert_eq!(mock.get_call_count().await, 2);
1025
1026 MockClock::set_system_time(Duration::from_secs(100_000 + 660));
1028 let cred = provider.get_credential().await.unwrap();
1029 assert_eq!(cred.key_id, "AKID_3");
1030 assert_eq!(cred.secret_key, "SECRET_3");
1031 assert_eq!(mock.get_call_count().await, 3);
1032 }
1033
1034 #[tokio::test]
1035 async fn test_dynamic_credential_provider_with_initial_credential() {
1036 MockClock::set_system_time(Duration::from_secs(100_000));
1037
1038 let now_ms = MockClock::system_time().as_millis() as u64;
1039
1040 let mock = Arc::new(MockStorageOptionsProvider::new(Some(
1042 600_000, )));
1044
1045 let expires_at = now_ms + 600_000; let initial_cred = Arc::new(ObjectStoreAwsCredential {
1048 key_id: "AKID_INITIAL".to_string(),
1049 secret_key: "SECRET_INITIAL".to_string(),
1050 token: Some("TOKEN_INITIAL".to_string()),
1051 });
1052
1053 let provider = DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
1055 mock.clone(),
1056 Duration::from_secs(300), initial_cred,
1058 expires_at,
1059 );
1060
1061 let cred = provider.get_credential().await.unwrap();
1063 assert_eq!(cred.key_id, "AKID_INITIAL");
1064 assert_eq!(cred.secret_key, "SECRET_INITIAL");
1065 assert_eq!(cred.token, Some("TOKEN_INITIAL".to_string()));
1066
1067 assert_eq!(mock.get_call_count().await, 0);
1069
1070 MockClock::set_system_time(Duration::from_secs(100_000 + 360));
1073 let cred = provider.get_credential().await.unwrap();
1074 assert_eq!(cred.key_id, "AKID_1");
1075 assert_eq!(cred.secret_key, "SECRET_1");
1076 assert_eq!(cred.token, Some("TOKEN_1".to_string()));
1077
1078 assert_eq!(mock.get_call_count().await, 1);
1080
1081 MockClock::set_system_time(Duration::from_secs(100_000 + 660));
1083 let cred = provider.get_credential().await.unwrap();
1084 assert_eq!(cred.key_id, "AKID_2");
1085 assert_eq!(cred.secret_key, "SECRET_2");
1086 assert_eq!(cred.token, Some("TOKEN_2".to_string()));
1087
1088 assert_eq!(mock.get_call_count().await, 2);
1090
1091 MockClock::set_system_time(Duration::from_secs(100_000 + 960));
1093 let cred = provider.get_credential().await.unwrap();
1094 assert_eq!(cred.key_id, "AKID_3");
1095 assert_eq!(cred.secret_key, "SECRET_3");
1096 assert_eq!(cred.token, Some("TOKEN_3".to_string()));
1097
1098 assert_eq!(mock.get_call_count().await, 3);
1100 }
1101
1102 #[tokio::test]
1103 async fn test_dynamic_credential_provider_concurrent_access() {
1104 let mock = Arc::new(MockStorageOptionsProvider::new(Some(9999999999999)));
1106
1107 let provider = Arc::new(DynamicStorageOptionsCredentialProvider::new(
1108 mock.clone(),
1109 Duration::from_secs(300),
1110 ));
1111
1112 let mut handles = vec![];
1114 for i in 0..10 {
1115 let provider = provider.clone();
1116 let handle = tokio::spawn(async move {
1117 let cred = provider.get_credential().await.unwrap();
1118 assert_eq!(cred.key_id, "AKID_1");
1120 assert_eq!(cred.secret_key, "SECRET_1");
1121 assert_eq!(cred.token, Some("TOKEN_1".to_string()));
1122 i });
1124 handles.push(handle);
1125 }
1126
1127 let results: Vec<_> = futures::future::join_all(handles)
1129 .await
1130 .into_iter()
1131 .map(|r| r.unwrap())
1132 .collect();
1133
1134 assert_eq!(results.len(), 10);
1136 for i in 0..10 {
1137 assert!(results.contains(&i));
1138 }
1139
1140 let call_count = mock.get_call_count().await;
1143 assert_eq!(
1144 call_count, 1,
1145 "Provider should be called exactly once despite concurrent access"
1146 );
1147 }
1148
1149 #[tokio::test]
1150 async fn test_dynamic_credential_provider_concurrent_refresh() {
1151 MockClock::set_system_time(Duration::from_secs(100_000));
1152
1153 let now_ms = MockClock::system_time().as_millis() as u64;
1154
1155 let expires_at = now_ms - 1_000_000;
1157
1158 let initial_cred = Arc::new(ObjectStoreAwsCredential {
1159 key_id: "AKID_OLD".to_string(),
1160 secret_key: "SECRET_OLD".to_string(),
1161 token: Some("TOKEN_OLD".to_string()),
1162 });
1163
1164 let mock = Arc::new(MockStorageOptionsProvider::new(Some(
1166 3_600_000, )));
1168
1169 let provider = Arc::new(
1170 DynamicStorageOptionsCredentialProvider::new_with_initial_credential(
1171 mock.clone(),
1172 Duration::from_secs(300),
1173 initial_cred,
1174 expires_at,
1175 ),
1176 );
1177
1178 let mut handles = vec![];
1181 for i in 0..20 {
1182 let provider = provider.clone();
1183 let handle = tokio::spawn(async move {
1184 let cred = provider.get_credential().await.unwrap();
1185 assert_eq!(cred.key_id, "AKID_1");
1187 assert_eq!(cred.secret_key, "SECRET_1");
1188 assert_eq!(cred.token, Some("TOKEN_1".to_string()));
1189 i
1190 });
1191 handles.push(handle);
1192 }
1193
1194 let results: Vec<_> = futures::future::join_all(handles)
1196 .await
1197 .into_iter()
1198 .map(|r| r.unwrap())
1199 .collect();
1200
1201 assert_eq!(results.len(), 20);
1203
1204 let call_count = mock.get_call_count().await;
1207 assert!(
1208 call_count >= 1,
1209 "Provider should be called at least once, was called {} times",
1210 call_count
1211 );
1212
1213 assert!(
1215 call_count < 10,
1216 "Provider should not be called too many times due to lock contention, was called {} times",
1217 call_count
1218 );
1219 }
1220}