1use std::{
7 collections::HashMap,
8 str::FromStr,
9 sync::Arc,
10 time::{Duration, SystemTime},
11};
12
13use object_store::ObjectStore as OSObjectStore;
14use object_store_opendal::OpendalStore;
15use opendal::{services::S3, Operator};
16
17use aws_config::default_provider::credentials::DefaultCredentialsChain;
18use aws_credential_types::provider::ProvideCredentials;
19use object_store::{
20 aws::{
21 AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
22 AwsCredentialProvider,
23 },
24 ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
25 StaticCredentialProvider,
26};
27use snafu::location;
28use tokio::sync::RwLock;
29use url::Url;
30
31use crate::object_store::{
32 ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, DEFAULT_CLOUD_BLOCK_SIZE,
33 DEFAULT_CLOUD_IO_PARALLELISM, DEFAULT_MAX_IOP_SIZE,
34};
35use lance_core::error::{Error, Result};
36
/// Stateless provider that builds S3-backed [`ObjectStore`] instances.
#[derive(Debug, Default)]
pub struct AwsStoreProvider;
39
40impl AwsStoreProvider {
41 async fn build_amazon_s3_store(
42 &self,
43 base_path: &mut Url,
44 params: &ObjectStoreParams,
45 storage_options: &StorageOptions,
46 is_s3_express: bool,
47 ) -> Result<Arc<dyn OSObjectStore>> {
48 let max_retries = storage_options.client_max_retries();
49 let retry_timeout = storage_options.client_retry_timeout();
50 let retry_config = RetryConfig {
51 backoff: Default::default(),
52 max_retries,
53 retry_timeout: Duration::from_secs(retry_timeout),
54 };
55
56 let mut s3_storage_options = storage_options.as_s3_options();
57 let region = resolve_s3_region(base_path, &s3_storage_options).await?;
58 let (aws_creds, region) = build_aws_credential(
59 params.s3_credentials_refresh_offset,
60 params.aws_credentials.clone(),
61 Some(&s3_storage_options),
62 region,
63 )
64 .await?;
65
66 if is_s3_express {
68 s3_storage_options.insert(AmazonS3ConfigKey::S3Express, true.to_string());
69 }
70
71 base_path.set_scheme("s3").unwrap();
73 base_path.set_query(None);
74
75 let mut builder = AmazonS3Builder::new();
77 for (key, value) in s3_storage_options {
78 builder = builder.with_config(key, value);
79 }
80 builder = builder
81 .with_url(base_path.as_ref())
82 .with_credentials(aws_creds)
83 .with_retry(retry_config)
84 .with_region(region);
85
86 Ok(Arc::new(builder.build()?) as Arc<dyn OSObjectStore>)
87 }
88
89 async fn build_opendal_s3_store(
90 &self,
91 base_path: &Url,
92 storage_options: &StorageOptions,
93 ) -> Result<Arc<dyn OSObjectStore>> {
94 let bucket = base_path
95 .host_str()
96 .ok_or_else(|| Error::invalid_input("S3 URL must contain bucket name", location!()))?
97 .to_string();
98
99 let prefix = base_path.path().trim_start_matches('/').to_string();
100
101 let mut config_map: HashMap<String, String> = storage_options.0.clone();
104
105 config_map.insert("bucket".to_string(), bucket);
107
108 if !prefix.is_empty() {
109 config_map.insert("root".to_string(), format!("/{}", prefix));
110 }
111
112 let operator = Operator::from_iter::<S3>(config_map)
113 .map_err(|e| {
114 Error::invalid_input(
115 format!("Failed to create S3 operator: {:?}", e),
116 location!(),
117 )
118 })?
119 .finish();
120
121 Ok(Arc::new(OpendalStore::new(operator)) as Arc<dyn OSObjectStore>)
122 }
123}
124
125#[async_trait::async_trait]
126impl ObjectStoreProvider for AwsStoreProvider {
127 async fn new_store(
128 &self,
129 mut base_path: Url,
130 params: &ObjectStoreParams,
131 ) -> Result<ObjectStore> {
132 let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
133 let mut storage_options =
134 StorageOptions(params.storage_options.clone().unwrap_or_default());
135 storage_options.with_env_s3();
136 let download_retry_count = storage_options.download_retry_count();
137
138 let use_opendal = storage_options
139 .0
140 .get("use_opendal")
141 .map(|v| v == "true")
142 .unwrap_or(false);
143
144 let is_s3_express = check_s3_express(&base_path, &storage_options);
146
147 let use_constant_size_upload_parts = storage_options
148 .0
149 .get("aws_endpoint")
150 .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
151 .unwrap_or(false);
152
153 let inner = if use_opendal {
154 self.build_opendal_s3_store(&base_path, &storage_options)
156 .await?
157 } else {
158 self.build_amazon_s3_store(&mut base_path, params, &storage_options, is_s3_express)
160 .await?
161 };
162
163 Ok(ObjectStore {
164 inner,
165 scheme: String::from(base_path.scheme()),
166 block_size,
167 max_iop_size: *DEFAULT_MAX_IOP_SIZE,
168 use_constant_size_upload_parts,
169 list_is_lexically_ordered: !is_s3_express,
170 io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
171 download_retry_count,
172 })
173 }
174}
175
176fn check_s3_express(url: &Url, storage_options: &StorageOptions) -> bool {
178 storage_options
179 .0
180 .get("s3_express")
181 .map(|v| v == "true")
182 .unwrap_or(false)
183 || url.authority().ends_with("--x-s3")
184}
185
186async fn resolve_s3_region(
194 url: &Url,
195 storage_options: &HashMap<AmazonS3ConfigKey, String>,
196) -> Result<Option<String>> {
197 if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
198 Ok(Some(region.clone()))
199 } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
200 let bucket = url.host_str().ok_or_else(|| {
203 Error::invalid_input(
204 format!("Could not parse bucket from url: {}", url),
205 location!(),
206 )
207 })?;
208
209 let mut client_options = ClientOptions::default();
210 for (key, value) in storage_options {
211 if let AmazonS3ConfigKey::Client(client_key) = key {
212 client_options = client_options.with_config(*client_key, value.clone());
213 }
214 }
215
216 let bucket_region =
217 object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
218 Ok(Some(bucket_region))
219 } else {
220 Ok(None)
221 }
222}
223
224pub async fn build_aws_credential(
234 credentials_refresh_offset: Duration,
235 credentials: Option<AwsCredentialProvider>,
236 storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
237 region: Option<String>,
238) -> Result<(AwsCredentialProvider, String)> {
239 use aws_config::meta::region::RegionProviderChain;
241 const DEFAULT_REGION: &str = "us-west-2";
242
243 let region = if let Some(region) = region {
244 region
245 } else {
246 RegionProviderChain::default_provider()
247 .or_else(DEFAULT_REGION)
248 .region()
249 .await
250 .map(|r| r.as_ref().to_string())
251 .unwrap_or(DEFAULT_REGION.to_string())
252 };
253
254 if let Some(creds) = credentials {
255 Ok((creds, region))
256 } else if let Some(creds) = storage_options.and_then(extract_static_s3_credentials) {
257 Ok((Arc::new(creds), region))
258 } else {
259 let credentials_provider = DefaultCredentialsChain::builder().build().await;
260
261 Ok((
262 Arc::new(AwsCredentialAdapter::new(
263 Arc::new(credentials_provider),
264 credentials_refresh_offset,
265 )),
266 region,
267 ))
268 }
269}
270
271fn extract_static_s3_credentials(
272 options: &HashMap<AmazonS3ConfigKey, String>,
273) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
274 let key_id = options.get(&AmazonS3ConfigKey::AccessKeyId).cloned();
275 let secret_key = options.get(&AmazonS3ConfigKey::SecretAccessKey).cloned();
276 let token = options.get(&AmazonS3ConfigKey::Token).cloned();
277 match (key_id, secret_key, token) {
278 (Some(key_id), Some(secret_key), token) => {
279 Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
280 key_id,
281 secret_key,
282 token,
283 }))
284 }
285 _ => None,
286 }
287}
288
/// Adapts an AWS SDK credentials provider ([`ProvideCredentials`]) to the
/// `object_store` [`CredentialProvider`] trait, caching the fetched
/// credentials between calls.
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    /// Underlying AWS SDK provider, queried when there are no usable cached
    /// credentials.
    pub inner: Arc<dyn ProvideCredentials>,

    // Most recently fetched credentials, stored under a single fixed key.
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    // How long before their expiry cached credentials are considered stale
    // and refreshed.
    credentials_refresh_offset: Duration,
}
300
301impl AwsCredentialAdapter {
302 pub fn new(
303 provider: Arc<dyn ProvideCredentials>,
304 credentials_refresh_offset: Duration,
305 ) -> Self {
306 Self {
307 inner: provider,
308 cache: Arc::new(RwLock::new(HashMap::new())),
309 credentials_refresh_offset,
310 }
311 }
312}
313
/// Fixed key under which [`AwsCredentialAdapter`] stores its cached credentials.
const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";
315
316#[async_trait::async_trait]
317impl CredentialProvider for AwsCredentialAdapter {
318 type Credential = ObjectStoreAwsCredential;
319
320 async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
321 let cached_creds = {
322 let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
323 let expired = cache_value
324 .clone()
325 .map(|cred| {
326 cred.expiry()
327 .map(|exp| {
328 exp.checked_sub(self.credentials_refresh_offset)
329 .expect("this time should always be valid")
330 < SystemTime::now()
331 })
332 .unwrap_or(false)
334 })
335 .unwrap_or(true); if expired {
337 None
338 } else {
339 cache_value.clone()
340 }
341 };
342
343 if let Some(creds) = cached_creds {
344 Ok(Arc::new(Self::Credential {
345 key_id: creds.access_key_id().to_string(),
346 secret_key: creds.secret_access_key().to_string(),
347 token: creds.session_token().map(|s| s.to_string()),
348 }))
349 } else {
350 let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
351 |e| Error::Internal {
352 message: format!("Failed to get AWS credentials: {:?}", e),
353 location: location!(),
354 },
355 )?);
356
357 self.cache
358 .write()
359 .await
360 .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
361
362 Ok(Arc::new(Self::Credential {
363 key_id: refreshed_creds.access_key_id().to_string(),
364 secret_key: refreshed_creds.secret_access_key().to_string(),
365 token: refreshed_creds.session_token().map(|s| s.to_string()),
366 }))
367 }
368 }
369}
370
371impl StorageOptions {
372 pub fn with_env_s3(&mut self) {
374 for (os_key, os_value) in std::env::vars_os() {
375 if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
376 if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
377 if !self.0.contains_key(config_key.as_ref()) {
378 self.0
379 .insert(config_key.as_ref().to_string(), value.to_string());
380 }
381 }
382 }
383 }
384 }
385
386 pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
388 self.0
389 .iter()
390 .filter_map(|(key, value)| {
391 let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
392 Some((s3_key, value.clone()))
393 })
394 .collect()
395 }
396}
397
398impl ObjectStoreParams {
399 pub fn with_aws_credentials(
401 aws_credentials: Option<AwsCredentialProvider>,
402 region: Option<String>,
403 ) -> Self {
404 Self {
405 aws_credentials,
406 storage_options: region
407 .map(|region| [("region".into(), region)].iter().cloned().collect()),
408 ..Default::default()
409 }
410 }
411}
412
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use object_store::path::Path;

    use crate::object_store::ObjectStoreRegistry;

    use super::*;

    /// Credential provider that records whether it was ever asked for
    /// credentials.
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait::async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }

    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // The provider must not have been consulted before the store is used.
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // The outcome of the read is deliberately ignored; the test only
        // checks afterwards that the injected provider was consulted.
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    #[test]
    fn test_s3_path_parsing() {
        let provider = AwsStoreProvider;

        let cases = [
            ("s3://bucket/path/to/file", "path/to/file"),
            ("s3://bucket/测试path/to/file", "测试path/to/file"),
            ("s3://bucket/path/&to/file", "path/&to/file"),
            ("s3://bucket/path/=to/file", "path/=to/file"),
            (
                "s3+ddb://bucket/path/to/file?ddbTableName=test",
                "path/to/file",
            ),
        ];

        for (uri, expected_path) in cases {
            let url = Url::parse(uri).unwrap();
            let path = provider.extract_path(&url).unwrap();
            let expected_path = Path::from(expected_path);
            assert_eq!(path, expected_path)
        }
    }

    #[test]
    fn test_is_s3_express() {
        let cases = [
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                false,
            ),
            ("s3://bucket/path/to/file", HashMap::from([]), false),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                // The `--x-s3` bucket suffix wins over an explicit "false".
                true,
            ),
            ("s3://bucket--x-s3/path/to/file", HashMap::from([]), true),
        ];

        for (uri, storage_map, expected) in cases {
            let url = Url::parse(uri).unwrap();
            let storage_options = StorageOptions(storage_map);
            let is_s3_express = check_s3_express(&url, &storage_options);
            assert_eq!(is_s3_express, expected);
        }
    }

    #[tokio::test]
    async fn test_use_opendal_flag() {
        let provider = AwsStoreProvider;
        let url = Url::parse("s3://test-bucket/path").unwrap();
        let params_with_flag = ObjectStoreParams {
            storage_options: Some(HashMap::from([
                ("use_opendal".to_string(), "true".to_string()),
                ("region".to_string(), "us-west-2".to_string()),
            ])),
            ..Default::default()
        };

        let store = provider
            .new_store(url.clone(), &params_with_flag)
            .await
            .unwrap();
        assert_eq!(store.scheme, "s3");
    }
}