lance_io/object_store/providers/aws.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::{
    collections::HashMap,
    str::FromStr,
    sync::Arc,
    time::{Duration, SystemTime},
};

use aws_config::default_provider::credentials::DefaultCredentialsChain;
use aws_credential_types::provider::ProvideCredentials;
use object_store::{
    aws::{
        AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
        AwsCredentialProvider,
    },
    ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
    StaticCredentialProvider,
};
use snafu::location;
use tokio::sync::RwLock;
use url::Url;

use crate::object_store::{
    ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, DEFAULT_CLOUD_BLOCK_SIZE,
    DEFAULT_CLOUD_IO_PARALLELISM, DEFAULT_MAX_IOP_SIZE,
};
use lance_core::error::{Error, Result};

#[derive(Default, Debug)]
pub struct AwsStoreProvider;

#[async_trait::async_trait]
impl ObjectStoreProvider for AwsStoreProvider {
    async fn new_store(
        &self,
        mut base_path: Url,
        params: &ObjectStoreParams,
    ) -> Result<ObjectStore> {
        let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
        let mut storage_options =
            StorageOptions(params.storage_options.clone().unwrap_or_default());
        let download_retry_count = storage_options.download_retry_count();

        let max_retries = storage_options.client_max_retries();
        let retry_timeout = storage_options.client_retry_timeout();
        let retry_config = RetryConfig {
            backoff: Default::default(),
            max_retries,
            retry_timeout: Duration::from_secs(retry_timeout),
        };

        storage_options.with_env_s3();

        let mut storage_options = storage_options.as_s3_options();
        let region = resolve_s3_region(&base_path, &storage_options).await?;
        let (aws_creds, region) = build_aws_credential(
            params.s3_credentials_refresh_offset,
            params.aws_credentials.clone(),
            Some(&storage_options),
            region,
        )
        .await?;

        // This will be the default in the next version of object_store.
        // https://github.com/apache/arrow-rs/pull/7181
        // We can drop this explicit setting once we upgrade to 0.12.
        storage_options
            .entry(AmazonS3ConfigKey::ConditionalPut)
            .or_insert_with(|| "etag".to_string());

        // Cloudflare does not support varying part sizes.
        let use_constant_size_upload_parts = storage_options
            .get(&AmazonS3ConfigKey::Endpoint)
            .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
            .unwrap_or(false);

        let is_s3_express = check_s3_express(&base_path, &mut storage_options);

        // Before creating the OSObjectStore, we need to rewrite the URL to drop the DynamoDB-related parts.
        base_path.set_scheme("s3").unwrap();
        base_path.set_query(None);

        // we can't use parse_url_opts here because we need to manually set the credentials provider
        let mut builder = AmazonS3Builder::new();
        for (key, value) in storage_options {
            builder = builder.with_config(key, value);
        }
        builder = builder
            .with_url(base_path.as_ref())
            .with_credentials(aws_creds)
            .with_retry(retry_config)
            .with_region(region);
        let inner = Arc::new(builder.build()?);

        Ok(ObjectStore {
            inner,
            scheme: String::from(base_path.scheme()),
            block_size,
            max_iop_size: *DEFAULT_MAX_IOP_SIZE,
            use_constant_size_upload_parts,
            list_is_lexically_ordered: !is_s3_express,
            io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
            download_retry_count,
        })
    }
}

/// Check if the storage is S3 Express, updating the storage options along the way
fn check_s3_express(url: &Url, storage_options: &mut HashMap<AmazonS3ConfigKey, String>) -> bool {
    if matches!(storage_options.get(&AmazonS3ConfigKey::S3Express), Some(val) if val == "true") {
        return true;
    }

    if url.authority().ends_with("--x-s3") {
        storage_options.insert(AmazonS3ConfigKey::S3Express, true.to_string());
        return true;
    }

    false
}

/// Figure out the S3 region of the bucket.
///
/// This resolves in order of precedence:
/// 1. The region provided in the storage options
/// 2. (If no endpoint is set) the region returned by the S3 API for the bucket
///
/// Returns `None` if no region is provided and a custom endpoint is set.
async fn resolve_s3_region(
    url: &Url,
    storage_options: &HashMap<AmazonS3ConfigKey, String>,
) -> Result<Option<String>> {
    if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
        Ok(Some(region.clone()))
    } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
        // If no endpoint is set, we can assume this is AWS S3 and the region
        // can be resolved from the bucket.
        let bucket = url.host_str().ok_or_else(|| {
            Error::invalid_input(
                format!("Could not parse bucket from url: {}", url),
                location!(),
            )
        })?;

        let mut client_options = ClientOptions::default();
        for (key, value) in storage_options {
            if let AmazonS3ConfigKey::Client(client_key) = key {
                client_options = client_options.with_config(*client_key, value.clone());
            }
        }

        let bucket_region =
            object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
        Ok(Some(bucket_region))
    } else {
        Ok(None)
    }
}

/// Build AWS credentials
///
/// This resolves credentials from the following sources in order:
/// 1. An explicit `credentials` provider
/// 2. Explicit credentials in storage_options (as in `aws_access_key_id`,
///    `aws_secret_access_key`, `aws_session_token`)
/// 3. The default credential provider chain from the AWS SDK.
///
/// `credentials_refresh_offset` is the amount of time before expiry to refresh credentials.
pub async fn build_aws_credential(
    credentials_refresh_offset: Duration,
    credentials: Option<AwsCredentialProvider>,
    storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
    region: Option<String>,
) -> Result<(AwsCredentialProvider, String)> {
    // TODO: make this return no credential provider when not using AWS
    use aws_config::meta::region::RegionProviderChain;
    const DEFAULT_REGION: &str = "us-west-2";

    let region = if let Some(region) = region {
        region
    } else {
        RegionProviderChain::default_provider()
            .or_else(DEFAULT_REGION)
            .region()
            .await
            .map(|r| r.as_ref().to_string())
            .unwrap_or(DEFAULT_REGION.to_string())
    };

    if let Some(creds) = credentials {
        Ok((creds, region))
    } else if let Some(creds) = storage_options.and_then(extract_static_s3_credentials) {
        Ok((Arc::new(creds), region))
    } else {
        let credentials_provider = DefaultCredentialsChain::builder().build().await;

        Ok((
            Arc::new(AwsCredentialAdapter::new(
                Arc::new(credentials_provider),
                credentials_refresh_offset,
            )),
            region,
        ))
    }
}

fn extract_static_s3_credentials(
    options: &HashMap<AmazonS3ConfigKey, String>,
) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
    let key_id = options
        .get(&AmazonS3ConfigKey::AccessKeyId)
        .map(|s| s.to_string());
    let secret_key = options
        .get(&AmazonS3ConfigKey::SecretAccessKey)
        .map(|s| s.to_string());
    let token = options
        .get(&AmazonS3ConfigKey::Token)
        .map(|s| s.to_string());
    match (key_id, secret_key, token) {
        (Some(key_id), Some(secret_key), token) => {
            Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
                key_id,
                secret_key,
                token,
            }))
        }
        _ => None,
    }
}

/// Adapt an AWS SDK credentials provider into an object_store credential provider
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    pub inner: Arc<dyn ProvideCredentials>,

    // A RefCell can't be shared across threads, so the cache is kept behind an async RwLock
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    // The amount of time before expiry to refresh credentials
    credentials_refresh_offset: Duration,
}

impl AwsCredentialAdapter {
    pub fn new(
        provider: Arc<dyn ProvideCredentials>,
        credentials_refresh_offset: Duration,
    ) -> Self {
        Self {
            inner: provider,
            cache: Arc::new(RwLock::new(HashMap::new())),
            credentials_refresh_offset,
        }
    }
}

const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";

#[async_trait::async_trait]
impl CredentialProvider for AwsCredentialAdapter {
    type Credential = ObjectStoreAwsCredential;

    async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
        let cached_creds = {
            let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
            let expired = cache_value
                .clone()
                .map(|cred| {
                    cred.expiry()
                        .map(|exp| {
                            exp.checked_sub(self.credentials_refresh_offset)
                                .expect("this time should always be valid")
                                < SystemTime::now()
                        })
                        // no expiry means the credentials never expire
                        .unwrap_or(false)
                })
                .unwrap_or(true); // no cached credential is treated as expired
            if expired {
                None
            } else {
                cache_value.clone()
            }
        };

        if let Some(creds) = cached_creds {
            Ok(Arc::new(Self::Credential {
                key_id: creds.access_key_id().to_string(),
                secret_key: creds.secret_access_key().to_string(),
                token: creds.session_token().map(|s| s.to_string()),
            }))
        } else {
            let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
                |e| Error::Internal {
                    message: format!("Failed to get AWS credentials: {}", e),
                    location: location!(),
                },
            )?);

            self.cache
                .write()
                .await
                .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());

            Ok(Arc::new(Self::Credential {
                key_id: refreshed_creds.access_key_id().to_string(),
                secret_key: refreshed_creds.secret_access_key().to_string(),
                token: refreshed_creds.session_token().map(|s| s.to_string()),
            }))
        }
    }
}

impl StorageOptions {
    /// Add values from the environment to storage options
    pub fn with_env_s3(&mut self) {
        for (os_key, os_value) in std::env::vars_os() {
            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
                if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
                    if !self.0.contains_key(config_key.as_ref()) {
                        self.0
                            .insert(config_key.as_ref().to_string(), value.to_string());
                    }
                }
            }
        }
    }

    /// Subset of options relevant for S3 storage
    pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
        self.0
            .iter()
            .filter_map(|(key, value)| {
                let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
                Some((s3_key, value.clone()))
            })
            .collect()
    }
}

impl ObjectStoreParams {
    /// Create a new instance of [`ObjectStoreParams`] based on the AWS credentials.
    pub fn with_aws_credentials(
        aws_credentials: Option<AwsCredentialProvider>,
        region: Option<String>,
    ) -> Self {
        Self {
            aws_credentials,
            storage_options: region
                .map(|region| [("region".into(), region)].iter().cloned().collect()),
            ..Default::default()
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use object_store::path::Path;

    use crate::object_store::ObjectStoreRegistry;

    use super::*;

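    // Illustrative test added as a sketch (not part of the original file): shows how
    // `StorageOptions::as_s3_options` keeps only keys that parse as `AmazonS3ConfigKey`
    // and silently drops everything else. The option names used here are assumed to be
    // recognized by `AmazonS3ConfigKey::from_str`.
    #[test]
    fn test_as_s3_options_filters_unknown_keys() {
        let options = StorageOptions(HashMap::from([
            ("aws_access_key_id".to_string(), "my-key".to_string()),
            ("region".to_string(), "us-east-1".to_string()),
            ("not_an_s3_option".to_string(), "ignored".to_string()),
        ]));

        let s3_options = options.as_s3_options();

        assert_eq!(
            s3_options.get(&AmazonS3ConfigKey::AccessKeyId),
            Some(&"my-key".to_string())
        );
        assert_eq!(
            s3_options.get(&AmazonS3ConfigKey::Region),
            Some(&"us-east-1".to_string())
        );
        // The unrecognized key is dropped rather than passed through.
        assert_eq!(s3_options.len(), 2);
    }
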
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait::async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }
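
    // Illustrative test added as a sketch (not part of the original file): shows that
    // `ObjectStoreParams::with_aws_credentials` stores the provider and maps the region
    // into the generic string storage options under the "region" key.
    #[test]
    fn test_with_aws_credentials_sets_region_option() {
        let creds: AwsCredentialProvider = Arc::new(MockAwsCredentialsProvider::default());
        let params = ObjectStoreParams::with_aws_credentials(
            Some(creds),
            Some("ap-southeast-2".to_string()),
        );

        assert!(params.aws_credentials.is_some());
        assert_eq!(
            params
                .storage_options
                .as_ref()
                .and_then(|opts| opts.get("region")),
            Some(&"ap-southeast-2".to_string())
        );
    }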

    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // Not called yet
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // This fails, but we don't care; we only need the credential provider to be invoked
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        // The injected provider should have been called by now
        assert!(mock_provider.called.load(Ordering::Relaxed));
    }
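
    // Illustrative test added as a sketch (not part of the original file): exercises the
    // precedence documented on `build_aws_credential` by supplying static credentials via
    // storage options and an explicit region, so no AWS SDK lookup is needed.
    #[tokio::test]
    async fn test_build_aws_credential_from_static_options() {
        let options = HashMap::from([
            (AmazonS3ConfigKey::AccessKeyId, "my-key".to_string()),
            (AmazonS3ConfigKey::SecretAccessKey, "my-secret".to_string()),
        ]);

        let (provider, region) = build_aws_credential(
            Duration::from_secs(60),
            None,
            Some(&options),
            Some("us-east-1".to_string()),
        )
        .await
        .unwrap();

        assert_eq!(region, "us-east-1");
        let creds = provider.get_credential().await.unwrap();
        assert_eq!(creds.key_id, "my-key");
        assert_eq!(creds.secret_key, "my-secret");
        assert!(creds.token.is_none());
    }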

    #[test]
    fn test_s3_path_parsing() {
        let provider = AwsStoreProvider;

        let cases = [
            ("s3://bucket/path/to/file", "path/to/file"),
            (
                "s3+ddb://bucket/path/to/file?ddbTableName=test",
                "path/to/file",
            ),
        ];

        for (uri, expected_path) in cases {
            let url = Url::parse(uri).unwrap();
            let path = provider.extract_path(&url);
            let expected_path = Path::from(expected_path);
            assert_eq!(path, expected_path);
        }
    }
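
    // Illustrative test added as a sketch (not part of the original file): shows the
    // precedence documented on `resolve_s3_region`. With a region in the options it is
    // returned directly; with only an endpoint, no lookup is attempted and None is returned.
    #[tokio::test]
    async fn test_resolve_s3_region_prefers_explicit_region() {
        let url = Url::parse("s3://bucket/path/to/file").unwrap();

        let options = HashMap::from([
            (AmazonS3ConfigKey::Region, "eu-central-1".to_string()),
            (AmazonS3ConfigKey::Endpoint, "http://localhost:9000".to_string()),
        ]);
        let region = resolve_s3_region(&url, &options).await.unwrap();
        assert_eq!(region, Some("eu-central-1".to_string()));

        let options = HashMap::from([(
            AmazonS3ConfigKey::Endpoint,
            "http://localhost:9000".to_string(),
        )]);
        let region = resolve_s3_region(&url, &options).await.unwrap();
        assert_eq!(region, None);
    }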

    #[test]
    fn test_is_s3_express() {
        let cases = [
            (
                "s3://bucket/path/to/file",
                HashMap::from([(AmazonS3ConfigKey::S3Express, "true".into())]),
                true,
            ),
            (
                "s3://bucket/path/to/file",
                HashMap::from([(AmazonS3ConfigKey::S3Express, "false".into())]),
                false,
            ),
            ("s3://bucket/path/to/file", HashMap::from([]), false),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([(AmazonS3ConfigKey::S3Express, "true".into())]),
                true,
            ),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([(AmazonS3ConfigKey::S3Express, "false".into())]),
                true,
            ),
            ("s3://bucket--x-s3/path/to/file", HashMap::from([]), true),
        ];

        for (uri, mut configs, expected) in cases {
            let url = Url::parse(uri).unwrap();
            let is_s3_express = check_s3_express(&url, &mut configs);
            assert_eq!(is_s3_express, expected);
            if is_s3_express {
                assert!(configs
                    .get(&AmazonS3ConfigKey::S3Express)
                    .is_some_and(|opt| opt == "true"));
            }
        }
    }
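
    // Illustrative test added as a sketch (not part of the original file): shows
    // `AwsCredentialAdapter` translating AWS SDK credentials into object_store
    // credentials. It assumes `aws_credential_types::Credentials` implements
    // `ProvideCredentials` (true in current SDK versions) so no extra mock is needed.
    #[tokio::test]
    async fn test_aws_credential_adapter_maps_sdk_credentials() {
        let sdk_creds = aws_credential_types::Credentials::new(
            "my-key",
            "my-secret",
            None,
            None,
            "static-test-provider",
        );
        let adapter = AwsCredentialAdapter::new(Arc::new(sdk_creds), Duration::from_secs(60));

        let creds = adapter.get_credential().await.unwrap();
        assert_eq!(creds.key_id, "my-key");
        assert_eq!(creds.secret_key, "my-secret");
        assert!(creds.token.is_none());
    }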
474    }
475}