1use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration};
7
8#[cfg(test)]
9use mock_instant::thread_local::{SystemTime, UNIX_EPOCH};
10
11#[cfg(not(test))]
12use std::time::{SystemTime, UNIX_EPOCH};
13
14use object_store::ObjectStore as OSObjectStore;
15use object_store_opendal::OpendalStore;
16use opendal::{services::S3, Operator};
17
18use aws_config::default_provider::credentials::DefaultCredentialsChain;
19use aws_credential_types::provider::ProvideCredentials;
20use object_store::{
21 aws::{
22 AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
23 AwsCredentialProvider,
24 },
25 ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
26 StaticCredentialProvider,
27};
28use snafu::location;
29use tokio::sync::RwLock;
30use url::Url;
31
32use crate::object_store::{
33 ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, DEFAULT_CLOUD_BLOCK_SIZE,
34 DEFAULT_CLOUD_IO_PARALLELISM, DEFAULT_MAX_IOP_SIZE,
35};
36use lance_core::error::{Error, Result};
37
/// Stateless [`ObjectStoreProvider`] that builds object stores for S3 URLs.
///
/// The provider holds no configuration of its own; everything it needs is
/// taken from the URL and the `ObjectStoreParams` passed to `new_store`.
#[derive(Debug, Default)]
pub struct AwsStoreProvider;
40
41impl AwsStoreProvider {
42 async fn build_amazon_s3_store(
43 &self,
44 base_path: &mut Url,
45 params: &ObjectStoreParams,
46 storage_options: &StorageOptions,
47 is_s3_express: bool,
48 ) -> Result<Arc<dyn OSObjectStore>> {
49 let max_retries = storage_options.client_max_retries();
50 let retry_timeout = storage_options.client_retry_timeout();
51 let retry_config = RetryConfig {
52 backoff: Default::default(),
53 max_retries,
54 retry_timeout: Duration::from_secs(retry_timeout),
55 };
56
57 let mut s3_storage_options = storage_options.as_s3_options();
58 let region = resolve_s3_region(base_path, &s3_storage_options).await?;
59 let (aws_creds, region) = build_aws_credential(
60 params.s3_credentials_refresh_offset,
61 params.aws_credentials.clone(),
62 Some(&s3_storage_options),
63 region,
64 )
65 .await?;
66
67 if is_s3_express {
69 s3_storage_options.insert(AmazonS3ConfigKey::S3Express, true.to_string());
70 }
71
72 base_path.set_scheme("s3").unwrap();
74 base_path.set_query(None);
75
76 let mut builder = AmazonS3Builder::new();
78 for (key, value) in s3_storage_options {
79 builder = builder.with_config(key, value);
80 }
81 builder = builder
82 .with_url(base_path.as_ref())
83 .with_credentials(aws_creds)
84 .with_retry(retry_config)
85 .with_region(region);
86
87 Ok(Arc::new(builder.build()?) as Arc<dyn OSObjectStore>)
88 }
89
90 async fn build_opendal_s3_store(
91 &self,
92 base_path: &Url,
93 storage_options: &StorageOptions,
94 ) -> Result<Arc<dyn OSObjectStore>> {
95 let bucket = base_path
96 .host_str()
97 .ok_or_else(|| Error::invalid_input("S3 URL must contain bucket name", location!()))?
98 .to_string();
99
100 let prefix = base_path.path().trim_start_matches('/').to_string();
101
102 let mut config_map: HashMap<String, String> = storage_options.0.clone();
105
106 config_map.insert("bucket".to_string(), bucket);
108
109 if !prefix.is_empty() {
110 config_map.insert("root".to_string(), "/".to_string());
111 }
112
113 let operator = Operator::from_iter::<S3>(config_map)
114 .map_err(|e| {
115 Error::invalid_input(
116 format!("Failed to create S3 operator: {:?}", e),
117 location!(),
118 )
119 })?
120 .finish();
121
122 Ok(Arc::new(OpendalStore::new(operator)) as Arc<dyn OSObjectStore>)
123 }
124}
125
126#[async_trait::async_trait]
127impl ObjectStoreProvider for AwsStoreProvider {
128 async fn new_store(
129 &self,
130 mut base_path: Url,
131 params: &ObjectStoreParams,
132 ) -> Result<ObjectStore> {
133 let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
134 let mut storage_options =
135 StorageOptions(params.storage_options.clone().unwrap_or_default());
136 storage_options.with_env_s3();
137 let download_retry_count = storage_options.download_retry_count();
138
139 let use_opendal = storage_options
140 .0
141 .get("use_opendal")
142 .map(|v| v == "true")
143 .unwrap_or(false);
144
145 let is_s3_express = check_s3_express(&base_path, &storage_options);
147
148 let use_constant_size_upload_parts = storage_options
149 .0
150 .get("aws_endpoint")
151 .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
152 .unwrap_or(false);
153
154 let inner = if use_opendal {
155 self.build_opendal_s3_store(&base_path, &storage_options)
157 .await?
158 } else {
159 self.build_amazon_s3_store(&mut base_path, params, &storage_options, is_s3_express)
161 .await?
162 };
163
164 Ok(ObjectStore {
165 inner,
166 scheme: String::from(base_path.scheme()),
167 block_size,
168 max_iop_size: *DEFAULT_MAX_IOP_SIZE,
169 use_constant_size_upload_parts,
170 list_is_lexically_ordered: !is_s3_express,
171 io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
172 download_retry_count,
173 })
174 }
175}
176
177fn check_s3_express(url: &Url, storage_options: &StorageOptions) -> bool {
179 storage_options
180 .0
181 .get("s3_express")
182 .map(|v| v == "true")
183 .unwrap_or(false)
184 || url.authority().ends_with("--x-s3")
185}
186
187async fn resolve_s3_region(
195 url: &Url,
196 storage_options: &HashMap<AmazonS3ConfigKey, String>,
197) -> Result<Option<String>> {
198 if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
199 Ok(Some(region.clone()))
200 } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
201 let bucket = url.host_str().ok_or_else(|| {
204 Error::invalid_input(
205 format!("Could not parse bucket from url: {}", url),
206 location!(),
207 )
208 })?;
209
210 let mut client_options = ClientOptions::default();
211 for (key, value) in storage_options {
212 if let AmazonS3ConfigKey::Client(client_key) = key {
213 client_options = client_options.with_config(*client_key, value.clone());
214 }
215 }
216
217 let bucket_region =
218 object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
219 Ok(Some(bucket_region))
220 } else {
221 Ok(None)
222 }
223}
224
225pub async fn build_aws_credential(
235 credentials_refresh_offset: Duration,
236 credentials: Option<AwsCredentialProvider>,
237 storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
238 region: Option<String>,
239) -> Result<(AwsCredentialProvider, String)> {
240 use aws_config::meta::region::RegionProviderChain;
242 const DEFAULT_REGION: &str = "us-west-2";
243
244 let region = if let Some(region) = region {
245 region
246 } else {
247 RegionProviderChain::default_provider()
248 .or_else(DEFAULT_REGION)
249 .region()
250 .await
251 .map(|r| r.as_ref().to_string())
252 .unwrap_or(DEFAULT_REGION.to_string())
253 };
254
255 if let Some(creds) = credentials {
256 Ok((creds, region))
257 } else if let Some(creds) = storage_options.and_then(extract_static_s3_credentials) {
258 Ok((Arc::new(creds), region))
259 } else {
260 let credentials_provider = DefaultCredentialsChain::builder().build().await;
261
262 Ok((
263 Arc::new(AwsCredentialAdapter::new(
264 Arc::new(credentials_provider),
265 credentials_refresh_offset,
266 )),
267 region,
268 ))
269 }
270}
271
272fn extract_static_s3_credentials(
273 options: &HashMap<AmazonS3ConfigKey, String>,
274) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
275 let key_id = options.get(&AmazonS3ConfigKey::AccessKeyId).cloned();
276 let secret_key = options.get(&AmazonS3ConfigKey::SecretAccessKey).cloned();
277 let token = options.get(&AmazonS3ConfigKey::Token).cloned();
278 match (key_id, secret_key, token) {
279 (Some(key_id), Some(secret_key), token) => {
280 Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
281 key_id,
282 secret_key,
283 token,
284 }))
285 }
286 _ => None,
287 }
288}
289
/// Adapts an AWS SDK credentials provider to `object_store`'s
/// [`CredentialProvider`] interface, caching the credentials it fetches.
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    /// The wrapped provider that is asked for credentials whenever the cache
    /// has nothing usable.
    pub inner: Arc<dyn ProvideCredentials>,

    /// Cache of fetched credentials. `get_credential` only ever reads and
    /// writes the single entry stored under `AWS_CREDS_CACHE_KEY`.
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    /// How long before their expiry time cached credentials are already
    /// treated as expired and fetched again.
    credentials_refresh_offset: Duration,
}
301
302impl AwsCredentialAdapter {
303 pub fn new(
304 provider: Arc<dyn ProvideCredentials>,
305 credentials_refresh_offset: Duration,
306 ) -> Self {
307 Self {
308 inner: provider,
309 cache: Arc::new(RwLock::new(HashMap::new())),
310 credentials_refresh_offset,
311 }
312 }
313}
314
/// Key under which `AwsCredentialAdapter` stores the most recently fetched
/// credentials in its cache.
const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";
316
/// Converts a `std::time::SystemTime` into this module's `SystemTime` type.
///
/// The module imports a different `SystemTime` under `cfg(test)` than in
/// normal builds, so the value is rebuilt from its offset to the Unix epoch
/// rather than passed through directly.
///
/// # Panics
///
/// Panics if `time` lies before the Unix epoch.
fn to_system_time(time: std::time::SystemTime) -> SystemTime {
    UNIX_EPOCH
        + time
            .duration_since(std::time::UNIX_EPOCH)
            .expect("time should be after UNIX_EPOCH")
}
324
325#[async_trait::async_trait]
326impl CredentialProvider for AwsCredentialAdapter {
327 type Credential = ObjectStoreAwsCredential;
328
329 async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
330 let cached_creds = {
331 let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
332 let expired = cache_value
333 .clone()
334 .map(|cred| {
335 cred.expiry()
336 .map(|exp| {
337 to_system_time(exp)
338 .checked_sub(self.credentials_refresh_offset)
339 .expect("this time should always be valid")
340 < SystemTime::now()
341 })
342 .unwrap_or(false)
344 })
345 .unwrap_or(true); if expired {
347 None
348 } else {
349 cache_value.clone()
350 }
351 };
352
353 if let Some(creds) = cached_creds {
354 Ok(Arc::new(Self::Credential {
355 key_id: creds.access_key_id().to_string(),
356 secret_key: creds.secret_access_key().to_string(),
357 token: creds.session_token().map(|s| s.to_string()),
358 }))
359 } else {
360 let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
361 |e| Error::Internal {
362 message: format!("Failed to get AWS credentials: {:?}", e),
363 location: location!(),
364 },
365 )?);
366
367 self.cache
368 .write()
369 .await
370 .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
371
372 Ok(Arc::new(Self::Credential {
373 key_id: refreshed_creds.access_key_id().to_string(),
374 secret_key: refreshed_creds.secret_access_key().to_string(),
375 token: refreshed_creds.session_token().map(|s| s.to_string()),
376 }))
377 }
378 }
379}
380
381impl StorageOptions {
382 pub fn with_env_s3(&mut self) {
384 for (os_key, os_value) in std::env::vars_os() {
385 if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
386 if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
387 if !self.0.contains_key(config_key.as_ref()) {
388 self.0
389 .insert(config_key.as_ref().to_string(), value.to_string());
390 }
391 }
392 }
393 }
394 }
395
396 pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
398 self.0
399 .iter()
400 .filter_map(|(key, value)| {
401 let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
402 Some((s3_key, value.clone()))
403 })
404 .collect()
405 }
406}
407
408impl ObjectStoreParams {
409 pub fn with_aws_credentials(
411 aws_credentials: Option<AwsCredentialProvider>,
412 region: Option<String>,
413 ) -> Self {
414 Self {
415 aws_credentials,
416 storage_options: region
417 .map(|region| [("region".into(), region)].iter().cloned().collect()),
418 ..Default::default()
419 }
420 }
421}
422
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use object_store::path::Path;

    use crate::object_store::ObjectStoreRegistry;

    use super::*;

    /// Credential provider that records whether it was ever asked for
    /// credentials and always hands back empty ones.
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait::async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }

    /// A provider injected through `ObjectStoreParams::aws_credentials` must be
    /// the one consulted when the store issues a request.
    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // Nothing has asked the provider for credentials yet.
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // The outcome of the read is irrelevant; it only needs to be attempted
        // so that the store asks for credentials.
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    /// `extract_path` must return the object path for plain, non-ASCII and
    /// special-character URLs, and for URLs with a scheme suffix and query.
    #[test]
    fn test_s3_path_parsing() {
        let provider = AwsStoreProvider;

        let cases = [
            ("s3://bucket/path/to/file", "path/to/file"),
            ("s3://bucket/测试path/to/file", "测试path/to/file"),
            ("s3://bucket/path/&to/file", "path/&to/file"),
            ("s3://bucket/path/=to/file", "path/=to/file"),
            (
                "s3+ddb://bucket/path/to/file?ddbTableName=test",
                "path/to/file",
            ),
        ];

        for (uri, expected_path) in cases {
            let url = Url::parse(uri).unwrap();
            let path = provider.extract_path(&url).unwrap();
            let expected_path = Path::from(expected_path);
            assert_eq!(path, expected_path)
        }
    }

    /// S3 Express is detected from the `s3_express` option or from a bucket
    /// name ending in `--x-s3`.
    #[test]
    fn test_is_s3_express() {
        let cases = [
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                false,
            ),
            ("s3://bucket/path/to/file", HashMap::from([]), false),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                // The bucket suffix alone is sufficient, even with the option off.
                true,
            ),
            ("s3://bucket--x-s3/path/to/file", HashMap::from([]), true),
        ];

        for (uri, storage_map, expected) in cases {
            let url = Url::parse(uri).unwrap();
            let storage_options = StorageOptions(storage_map);
            let is_s3_express = check_s3_express(&url, &storage_options);
            assert_eq!(is_s3_express, expected);
        }
    }

    /// With `use_opendal` set to `"true"` a store is still created and keeps
    /// the `s3` scheme.
    #[tokio::test]
    async fn test_use_opendal_flag() {
        let provider = AwsStoreProvider;
        let url = Url::parse("s3://test-bucket/path").unwrap();
        let params_with_flag = ObjectStoreParams {
            storage_options: Some(HashMap::from([
                ("use_opendal".to_string(), "true".to_string()),
                ("region".to_string(), "us-west-2".to_string()),
            ])),
            ..Default::default()
        };

        let store = provider
            .new_store(url.clone(), &params_with_flag)
            .await
            .unwrap();
        assert_eq!(store.scheme, "s3");
    }
}