lance_io/object_store/providers/
aws.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors
5
6use std::{
7    collections::HashMap,
8    str::FromStr,
9    sync::Arc,
10    time::{Duration, SystemTime},
11};
12
13use object_store::ObjectStore as OSObjectStore;
14use object_store_opendal::OpendalStore;
15use opendal::{services::S3, Operator};
16
17use aws_config::default_provider::credentials::DefaultCredentialsChain;
18use aws_credential_types::provider::ProvideCredentials;
19use object_store::{
20    aws::{
21        AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
22        AwsCredentialProvider,
23    },
24    ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
25    StaticCredentialProvider,
26};
27use snafu::location;
28use tokio::sync::RwLock;
29use url::Url;
30
31use crate::object_store::{
32    ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, DEFAULT_CLOUD_BLOCK_SIZE,
33    DEFAULT_CLOUD_IO_PARALLELISM, DEFAULT_MAX_IOP_SIZE,
34};
35use lance_core::error::{Error, Result};
36
37#[derive(Default, Debug)]
38pub struct AwsStoreProvider;
39
impl AwsStoreProvider {
    /// Build an S3 store backed by the `object_store` crate's native Amazon
    /// S3 implementation.
    ///
    /// Resolves the bucket region and AWS credentials first, then configures
    /// an [`AmazonS3Builder`] from the S3-relevant storage options.
    ///
    /// Note: `base_path` is rewritten in place — the scheme is forced to `s3`
    /// and any query string is dropped (e.g. the DynamoDB parts of an
    /// `s3+ddb://...?ddbTableName=x` URL).
    async fn build_amazon_s3_store(
        &self,
        base_path: &mut Url,
        params: &ObjectStoreParams,
        storage_options: &StorageOptions,
        is_s3_express: bool,
    ) -> Result<Arc<dyn OSObjectStore>> {
        // Retry policy comes from user-tunable storage options.
        let max_retries = storage_options.client_max_retries();
        let retry_timeout = storage_options.client_retry_timeout();
        let retry_config = RetryConfig {
            backoff: Default::default(),
            max_retries,
            retry_timeout: Duration::from_secs(retry_timeout),
        };

        let mut s3_storage_options = storage_options.as_s3_options();
        // May query the S3 API when neither region nor endpoint is configured.
        let region = resolve_s3_region(base_path, &s3_storage_options).await?;
        let (aws_creds, region) = build_aws_credential(
            params.s3_credentials_refresh_offset,
            params.aws_credentials.clone(),
            Some(&s3_storage_options),
            region,
        )
        .await?;

        // Set S3Express flag if detected
        if is_s3_express {
            s3_storage_options.insert(AmazonS3ConfigKey::S3Express, true.to_string());
        }

        // before creating the OSObjectStore we need to rewrite the url to drop ddb related parts
        base_path.set_scheme("s3").unwrap();
        base_path.set_query(None);

        // we can't use parse_url_opts here because we need to manually set the credentials provider
        let mut builder = AmazonS3Builder::new();
        for (key, value) in s3_storage_options {
            builder = builder.with_config(key, value);
        }
        builder = builder
            .with_url(base_path.as_ref())
            .with_credentials(aws_creds)
            .with_retry(retry_config)
            .with_region(region);

        Ok(Arc::new(builder.build()?) as Arc<dyn OSObjectStore>)
    }

    /// Build an S3 store backed by OpenDAL.
    ///
    /// The URL's host becomes the bucket and its path becomes the operator
    /// root. All storage options are forwarded verbatim to the OpenDAL S3
    /// service config; credential resolution is delegated to OpenDAL's own
    /// default chain.
    async fn build_opendal_s3_store(
        &self,
        base_path: &Url,
        storage_options: &StorageOptions,
    ) -> Result<Arc<dyn OSObjectStore>> {
        let bucket = base_path
            .host_str()
            .ok_or_else(|| Error::invalid_input("S3 URL must contain bucket name", location!()))?
            .to_string();

        let prefix = base_path.path().trim_start_matches('/').to_string();

        // Start with all storage options as the config map
        // OpenDAL will handle environment variables through its default credentials chain
        let mut config_map: HashMap<String, String> = storage_options.0.clone();

        // Set required OpenDAL configuration
        config_map.insert("bucket".to_string(), bucket);

        if !prefix.is_empty() {
            // OpenDAL roots are absolute paths within the bucket.
            config_map.insert("root".to_string(), format!("/{}", prefix));
        }

        let operator = Operator::from_iter::<S3>(config_map)
            .map_err(|e| {
                Error::invalid_input(
                    format!("Failed to create S3 operator: {:?}", e),
                    location!(),
                )
            })?
            .finish();

        Ok(Arc::new(OpendalStore::new(operator)) as Arc<dyn OSObjectStore>)
    }
}
124
125#[async_trait::async_trait]
126impl ObjectStoreProvider for AwsStoreProvider {
127    async fn new_store(
128        &self,
129        mut base_path: Url,
130        params: &ObjectStoreParams,
131    ) -> Result<ObjectStore> {
132        let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
133        let mut storage_options =
134            StorageOptions(params.storage_options.clone().unwrap_or_default());
135        storage_options.with_env_s3();
136        let download_retry_count = storage_options.download_retry_count();
137
138        let use_opendal = storage_options
139            .0
140            .get("use_opendal")
141            .map(|v| v == "true")
142            .unwrap_or(false);
143
144        // Determine S3 Express and constant size upload parts before building the store
145        let is_s3_express = check_s3_express(&base_path, &storage_options);
146
147        let use_constant_size_upload_parts = storage_options
148            .0
149            .get("aws_endpoint")
150            .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
151            .unwrap_or(false);
152
153        let inner = if use_opendal {
154            // Use OpenDAL implementation
155            self.build_opendal_s3_store(&base_path, &storage_options)
156                .await?
157        } else {
158            // Use default Amazon S3 implementation
159            self.build_amazon_s3_store(&mut base_path, params, &storage_options, is_s3_express)
160                .await?
161        };
162
163        Ok(ObjectStore {
164            inner,
165            scheme: String::from(base_path.scheme()),
166            block_size,
167            max_iop_size: *DEFAULT_MAX_IOP_SIZE,
168            use_constant_size_upload_parts,
169            list_is_lexically_ordered: !is_s3_express,
170            io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
171            download_retry_count,
172        })
173    }
174}
175
176/// Check if the storage is S3 Express
177fn check_s3_express(url: &Url, storage_options: &StorageOptions) -> bool {
178    storage_options
179        .0
180        .get("s3_express")
181        .map(|v| v == "true")
182        .unwrap_or(false)
183        || url.authority().ends_with("--x-s3")
184}
185
186/// Figure out the S3 region of the bucket.
187///
188/// This resolves in order of precedence:
189/// 1. The region provided in the storage options
190/// 2. (If endpoint is not set), the region returned by the S3 API for the bucket
191///
192/// It can return None if no region is provided and the endpoint is set.
193async fn resolve_s3_region(
194    url: &Url,
195    storage_options: &HashMap<AmazonS3ConfigKey, String>,
196) -> Result<Option<String>> {
197    if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
198        Ok(Some(region.clone()))
199    } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
200        // If no endpoint is set, we can assume this is AWS S3 and the region
201        // can be resolved from the bucket.
202        let bucket = url.host_str().ok_or_else(|| {
203            Error::invalid_input(
204                format!("Could not parse bucket from url: {}", url),
205                location!(),
206            )
207        })?;
208
209        let mut client_options = ClientOptions::default();
210        for (key, value) in storage_options {
211            if let AmazonS3ConfigKey::Client(client_key) = key {
212                client_options = client_options.with_config(*client_key, value.clone());
213            }
214        }
215
216        let bucket_region =
217            object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
218        Ok(Some(bucket_region))
219    } else {
220        Ok(None)
221    }
222}
223
224/// Build AWS credentials
225///
226/// This resolves credentials from the following sources in order:
227/// 1. An explicit `credentials` provider
228/// 2. Explicit credentials in storage_options (as in `aws_access_key_id`,
229///    `aws_secret_access_key`, `aws_session_token`)
230/// 3. The default credential provider chain from AWS SDK.
231///
232/// `credentials_refresh_offset` is the amount of time before expiry to refresh credentials.
233pub async fn build_aws_credential(
234    credentials_refresh_offset: Duration,
235    credentials: Option<AwsCredentialProvider>,
236    storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
237    region: Option<String>,
238) -> Result<(AwsCredentialProvider, String)> {
239    // TODO: make this return no credential provider not using AWS
240    use aws_config::meta::region::RegionProviderChain;
241    const DEFAULT_REGION: &str = "us-west-2";
242
243    let region = if let Some(region) = region {
244        region
245    } else {
246        RegionProviderChain::default_provider()
247            .or_else(DEFAULT_REGION)
248            .region()
249            .await
250            .map(|r| r.as_ref().to_string())
251            .unwrap_or(DEFAULT_REGION.to_string())
252    };
253
254    if let Some(creds) = credentials {
255        Ok((creds, region))
256    } else if let Some(creds) = storage_options.and_then(extract_static_s3_credentials) {
257        Ok((Arc::new(creds), region))
258    } else {
259        let credentials_provider = DefaultCredentialsChain::builder().build().await;
260
261        Ok((
262            Arc::new(AwsCredentialAdapter::new(
263                Arc::new(credentials_provider),
264                credentials_refresh_offset,
265            )),
266            region,
267        ))
268    }
269}
270
271fn extract_static_s3_credentials(
272    options: &HashMap<AmazonS3ConfigKey, String>,
273) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
274    let key_id = options.get(&AmazonS3ConfigKey::AccessKeyId).cloned();
275    let secret_key = options.get(&AmazonS3ConfigKey::SecretAccessKey).cloned();
276    let token = options.get(&AmazonS3ConfigKey::Token).cloned();
277    match (key_id, secret_key, token) {
278        (Some(key_id), Some(secret_key), token) => {
279            Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
280                key_id,
281                secret_key,
282                token,
283            }))
284        }
285        _ => None,
286    }
287}
288
/// Adapt an AWS SDK cred into object_store credentials
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    /// The wrapped AWS SDK credentials provider.
    pub inner: Arc<dyn ProvideCredentials>,

    // Thread-safe cache of the last fetched credentials; in practice this
    // holds a single entry keyed by `AWS_CREDS_CACHE_KEY`.
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    // The amount of time before expiry to refresh credentials
    credentials_refresh_offset: Duration,
}
300
301impl AwsCredentialAdapter {
302    pub fn new(
303        provider: Arc<dyn ProvideCredentials>,
304        credentials_refresh_offset: Duration,
305    ) -> Self {
306        Self {
307            inner: provider,
308            cache: Arc::new(RwLock::new(HashMap::new())),
309            credentials_refresh_offset,
310        }
311    }
312}
313
// Single cache-slot key used by `AwsCredentialAdapter`'s credential cache.
const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";
315
316#[async_trait::async_trait]
317impl CredentialProvider for AwsCredentialAdapter {
318    type Credential = ObjectStoreAwsCredential;
319
320    async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
321        let cached_creds = {
322            let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
323            let expired = cache_value
324                .clone()
325                .map(|cred| {
326                    cred.expiry()
327                        .map(|exp| {
328                            exp.checked_sub(self.credentials_refresh_offset)
329                                .expect("this time should always be valid")
330                                < SystemTime::now()
331                        })
332                        // no expiry is never expire
333                        .unwrap_or(false)
334                })
335                .unwrap_or(true); // no cred is the same as expired;
336            if expired {
337                None
338            } else {
339                cache_value.clone()
340            }
341        };
342
343        if let Some(creds) = cached_creds {
344            Ok(Arc::new(Self::Credential {
345                key_id: creds.access_key_id().to_string(),
346                secret_key: creds.secret_access_key().to_string(),
347                token: creds.session_token().map(|s| s.to_string()),
348            }))
349        } else {
350            let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
351                |e| Error::Internal {
352                    message: format!("Failed to get AWS credentials: {:?}", e),
353                    location: location!(),
354                },
355            )?);
356
357            self.cache
358                .write()
359                .await
360                .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
361
362            Ok(Arc::new(Self::Credential {
363                key_id: refreshed_creds.access_key_id().to_string(),
364                secret_key: refreshed_creds.secret_access_key().to_string(),
365                token: refreshed_creds.session_token().map(|s| s.to_string()),
366            }))
367        }
368    }
369}
370
371impl StorageOptions {
372    /// Add values from the environment to storage options
373    pub fn with_env_s3(&mut self) {
374        for (os_key, os_value) in std::env::vars_os() {
375            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
376                if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
377                    if !self.0.contains_key(config_key.as_ref()) {
378                        self.0
379                            .insert(config_key.as_ref().to_string(), value.to_string());
380                    }
381                }
382            }
383        }
384    }
385
386    /// Subset of options relevant for s3 storage
387    pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
388        self.0
389            .iter()
390            .filter_map(|(key, value)| {
391                let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
392                Some((s3_key, value.clone()))
393            })
394            .collect()
395    }
396}
397
398impl ObjectStoreParams {
399    /// Create a new instance of [`ObjectStoreParams`] based on the AWS credentials.
400    pub fn with_aws_credentials(
401        aws_credentials: Option<AwsCredentialProvider>,
402        region: Option<String>,
403    ) -> Self {
404        Self {
405            aws_credentials,
406            storage_options: region
407                .map(|region| [("region".into(), region)].iter().cloned().collect()),
408            ..Default::default()
409        }
410    }
411}
412
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use object_store::path::Path;

    use crate::object_store::ObjectStoreRegistry;

    use super::*;

    /// Credential provider that records whether it was ever consulted.
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        // Flipped to true the first time `get_credential` is invoked.
        called: AtomicBool,
    }

    #[async_trait::async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }

    /// An explicitly injected credentials provider must be the one the store
    /// consults, rather than the environment/default chain.
    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // Not called yet
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // fails, but we don't care
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        // The injected provider must have been consulted by now.
        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    /// Path extraction should handle plain, non-ASCII, special-character and
    /// s3+ddb-style URLs (dropping the query string).
    #[test]
    fn test_s3_path_parsing() {
        let provider = AwsStoreProvider;

        let cases = [
            ("s3://bucket/path/to/file", "path/to/file"),
            // for non ASCII string tests
            ("s3://bucket/测试path/to/file", "测试path/to/file"),
            ("s3://bucket/path/&to/file", "path/&to/file"),
            ("s3://bucket/path/=to/file", "path/=to/file"),
            (
                "s3+ddb://bucket/path/to/file?ddbTableName=test",
                "path/to/file",
            ),
        ];

        for (uri, expected_path) in cases {
            let url = Url::parse(uri).unwrap();
            let path = provider.extract_path(&url).unwrap();
            let expected_path = Path::from(expected_path);
            assert_eq!(path, expected_path)
        }
    }

    /// S3 Express detection: storage option, `--x-s3` URL suffix, and their
    /// precedence (the URL suffix wins over an explicit `false` option).
    #[test]
    fn test_is_s3_express() {
        let cases = [
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                false,
            ),
            ("s3://bucket/path/to/file", HashMap::from([]), false),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                true, // URL takes precedence
            ),
            ("s3://bucket--x-s3/path/to/file", HashMap::from([]), true),
        ];

        for (uri, storage_map, expected) in cases {
            let url = Url::parse(uri).unwrap();
            let storage_options = StorageOptions(storage_map);
            let is_s3_express = check_s3_express(&url, &storage_options);
            assert_eq!(is_s3_express, expected);
        }
    }

    /// `use_opendal=true` should route store construction through OpenDAL
    /// while keeping the `s3` scheme.
    #[tokio::test]
    async fn test_use_opendal_flag() {
        let provider = AwsStoreProvider;
        let url = Url::parse("s3://test-bucket/path").unwrap();
        let params_with_flag = ObjectStoreParams {
            storage_options: Some(HashMap::from([
                ("use_opendal".to_string(), "true".to_string()),
                ("region".to_string(), "us-west-2".to_string()),
            ])),
            ..Default::default()
        };

        let store = provider
            .new_store(url.clone(), &params_with_flag)
            .await
            .unwrap();
        assert_eq!(store.scheme, "s3");
    }
}