// lance_io/object_store/providers/aws.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors
5
6use std::{
7    collections::HashMap,
8    str::FromStr,
9    sync::Arc,
10    time::{Duration, SystemTime},
11};
12
13use aws_config::default_provider::credentials::DefaultCredentialsChain;
14use aws_credential_types::provider::ProvideCredentials;
15use object_store::{
16    aws::{
17        AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
18        AwsCredentialProvider,
19    },
20    ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
21    StaticCredentialProvider,
22};
23use snafu::location;
24use tokio::sync::RwLock;
25use url::Url;
26
27use crate::object_store::{
28    ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, DEFAULT_CLOUD_BLOCK_SIZE,
29    DEFAULT_CLOUD_IO_PARALLELISM, DEFAULT_MAX_IOP_SIZE,
30};
31use lance_core::error::{Error, Result};
32
/// [`ObjectStoreProvider`] for Amazon S3 and S3-compatible object stores
/// (this file special-cases Cloudflare R2 endpoints and S3 Express buckets).
#[derive(Default, Debug)]
pub struct AwsStoreProvider;
35
#[async_trait::async_trait]
impl ObjectStoreProvider for AwsStoreProvider {
    /// Build an S3-backed [`ObjectStore`] for `base_path`.
    ///
    /// The ordering below matters: AWS_* environment variables are merged
    /// into the storage options first (explicit options win), the bucket
    /// region is resolved next, and credentials are built last so they can
    /// see the fully merged options.
    async fn new_store(
        &self,
        mut base_path: Url,
        params: &ObjectStoreParams,
    ) -> Result<ObjectStore> {
        let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
        let mut storage_options =
            StorageOptions(params.storage_options.clone().unwrap_or_default());
        let download_retry_count = storage_options.download_retry_count();

        // Retry policy for the underlying object_store HTTP client.
        let max_retries = storage_options.client_max_retries();
        let retry_timeout = storage_options.client_retry_timeout();
        let retry_config = RetryConfig {
            backoff: Default::default(),
            max_retries,
            retry_timeout: Duration::from_secs(retry_timeout),
        };

        // Merge environment variables in; options already present are kept.
        storage_options.with_env_s3();

        // From here on, work with the typed S3 subset of the options.
        let mut storage_options = storage_options.as_s3_options();
        let region = resolve_s3_region(&base_path, &storage_options).await?;
        let (aws_creds, region) = build_aws_credential(
            params.s3_credentials_refresh_offset,
            params.aws_credentials.clone(),
            Some(&storage_options),
            region,
        )
        .await?;

        // This will be default in next version of object store.
        // https://github.com/apache/arrow-rs/pull/7181
        // We can do this when we upgrade to 0.12.
        storage_options
            .entry(AmazonS3ConfigKey::ConditionalPut)
            .or_insert_with(|| "etag".to_string());

        // Cloudflare does not support varying part sizes.
        let use_constant_size_upload_parts = storage_options
            .get(&AmazonS3ConfigKey::Endpoint)
            .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
            .unwrap_or(false);

        // May also insert S3Express="true" into the options as a side effect.
        let is_s3_express = check_s3_express(&base_path, &mut storage_options);

        // before creating the OSObjectStore we need to rewrite the url to drop ddb related parts
        base_path.set_scheme("s3").unwrap();
        base_path.set_query(None);

        // we can't use parse_url_opts here because we need to manually set the credentials provider
        let mut builder = AmazonS3Builder::new();
        for (key, value) in storage_options {
            builder = builder.with_config(key, value);
        }
        builder = builder
            .with_url(base_path.as_ref())
            .with_credentials(aws_creds)
            .with_retry(retry_config)
            .with_region(region);
        let inner = Arc::new(builder.build()?);

        Ok(ObjectStore {
            inner,
            scheme: String::from(base_path.scheme()),
            block_size,
            max_iop_size: *DEFAULT_MAX_IOP_SIZE,
            use_constant_size_upload_parts,
            // S3 Express buckets are treated as not lexically ordered for
            // listing; standard S3 is.
            list_is_lexically_ordered: !is_s3_express,
            io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
            download_retry_count,
        })
    }
}
111
112/// Check if the storage is S3 Express, update object storage options along the way
113fn check_s3_express(url: &Url, storage_options: &mut HashMap<AmazonS3ConfigKey, String>) -> bool {
114    if matches!(storage_options.get(&AmazonS3ConfigKey::S3Express), Some(val) if val == "true") {
115        return true;
116    }
117
118    if url.authority().ends_with("--x-s3") {
119        storage_options.insert(AmazonS3ConfigKey::S3Express, true.to_string());
120        return true;
121    }
122
123    false
124}
125
126/// Figure out the S3 region of the bucket.
127///
128/// This resolves in order of precedence:
129/// 1. The region provided in the storage options
130/// 2. (If endpoint is not set), the region returned by the S3 API for the bucket
131///
132/// It can return None if no region is provided and the endpoint is set.
133async fn resolve_s3_region(
134    url: &Url,
135    storage_options: &HashMap<AmazonS3ConfigKey, String>,
136) -> Result<Option<String>> {
137    if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
138        Ok(Some(region.clone()))
139    } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
140        // If no endpoint is set, we can assume this is AWS S3 and the region
141        // can be resolved from the bucket.
142        let bucket = url.host_str().ok_or_else(|| {
143            Error::invalid_input(
144                format!("Could not parse bucket from url: {}", url),
145                location!(),
146            )
147        })?;
148
149        let mut client_options = ClientOptions::default();
150        for (key, value) in storage_options {
151            if let AmazonS3ConfigKey::Client(client_key) = key {
152                client_options = client_options.with_config(*client_key, value.clone());
153            }
154        }
155
156        let bucket_region =
157            object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
158        Ok(Some(bucket_region))
159    } else {
160        Ok(None)
161    }
162}
163
164/// Build AWS credentials
165///
166/// This resolves credentials from the following sources in order:
167/// 1. An explicit `credentials` provider
168/// 2. Explicit credentials in storage_options (as in `aws_access_key_id`,
169///    `aws_secret_access_key`, `aws_session_token`)
170/// 3. The default credential provider chain from AWS SDK.
171///
172/// `credentials_refresh_offset` is the amount of time before expiry to refresh credentials.
173pub async fn build_aws_credential(
174    credentials_refresh_offset: Duration,
175    credentials: Option<AwsCredentialProvider>,
176    storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
177    region: Option<String>,
178) -> Result<(AwsCredentialProvider, String)> {
179    // TODO: make this return no credential provider not using AWS
180    use aws_config::meta::region::RegionProviderChain;
181    const DEFAULT_REGION: &str = "us-west-2";
182
183    let region = if let Some(region) = region {
184        region
185    } else {
186        RegionProviderChain::default_provider()
187            .or_else(DEFAULT_REGION)
188            .region()
189            .await
190            .map(|r| r.as_ref().to_string())
191            .unwrap_or(DEFAULT_REGION.to_string())
192    };
193
194    if let Some(creds) = credentials {
195        Ok((creds, region))
196    } else if let Some(creds) = storage_options.and_then(extract_static_s3_credentials) {
197        Ok((Arc::new(creds), region))
198    } else {
199        let credentials_provider = DefaultCredentialsChain::builder().build().await;
200
201        Ok((
202            Arc::new(AwsCredentialAdapter::new(
203                Arc::new(credentials_provider),
204                credentials_refresh_offset,
205            )),
206            region,
207        ))
208    }
209}
210
211fn extract_static_s3_credentials(
212    options: &HashMap<AmazonS3ConfigKey, String>,
213) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
214    let key_id = options.get(&AmazonS3ConfigKey::AccessKeyId).cloned();
215    let secret_key = options.get(&AmazonS3ConfigKey::SecretAccessKey).cloned();
216    let token = options.get(&AmazonS3ConfigKey::Token).cloned();
217    match (key_id, secret_key, token) {
218        (Some(key_id), Some(secret_key), token) => {
219            Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
220                key_id,
221                secret_key,
222                token,
223            }))
224        }
225        _ => None,
226    }
227}
228
/// Adapt an AWS SDK cred into object_store credentials
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    // The wrapped AWS SDK credentials provider.
    pub inner: Arc<dyn ProvideCredentials>,

    // Cached credentials, keyed by AWS_CREDS_CACHE_KEY. The RwLock (not the
    // HashMap) is what makes the cache shareable across threads; a RefCell
    // would not be Sync.
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    // The amount of time before expiry to refresh credentials
    credentials_refresh_offset: Duration,
}
240
impl AwsCredentialAdapter {
    /// Wrap the given AWS SDK credentials `provider` with an empty cache.
    ///
    /// `credentials_refresh_offset` is how long before their expiry cached
    /// credentials are considered stale and refreshed.
    pub fn new(
        provider: Arc<dyn ProvideCredentials>,
        credentials_refresh_offset: Duration,
    ) -> Self {
        Self {
            inner: provider,
            cache: Arc::new(RwLock::new(HashMap::new())),
            credentials_refresh_offset,
        }
    }
}
253
254const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";
255
256#[async_trait::async_trait]
257impl CredentialProvider for AwsCredentialAdapter {
258    type Credential = ObjectStoreAwsCredential;
259
260    async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
261        let cached_creds = {
262            let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
263            let expired = cache_value
264                .clone()
265                .map(|cred| {
266                    cred.expiry()
267                        .map(|exp| {
268                            exp.checked_sub(self.credentials_refresh_offset)
269                                .expect("this time should always be valid")
270                                < SystemTime::now()
271                        })
272                        // no expiry is never expire
273                        .unwrap_or(false)
274                })
275                .unwrap_or(true); // no cred is the same as expired;
276            if expired {
277                None
278            } else {
279                cache_value.clone()
280            }
281        };
282
283        if let Some(creds) = cached_creds {
284            Ok(Arc::new(Self::Credential {
285                key_id: creds.access_key_id().to_string(),
286                secret_key: creds.secret_access_key().to_string(),
287                token: creds.session_token().map(|s| s.to_string()),
288            }))
289        } else {
290            let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
291                |e| Error::Internal {
292                    message: format!("Failed to get AWS credentials: {:?}", e),
293                    location: location!(),
294                },
295            )?);
296
297            self.cache
298                .write()
299                .await
300                .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
301
302            Ok(Arc::new(Self::Credential {
303                key_id: refreshed_creds.access_key_id().to_string(),
304                secret_key: refreshed_creds.secret_access_key().to_string(),
305                token: refreshed_creds.session_token().map(|s| s.to_string()),
306            }))
307        }
308    }
309}
310
311impl StorageOptions {
312    /// Add values from the environment to storage options
313    pub fn with_env_s3(&mut self) {
314        for (os_key, os_value) in std::env::vars_os() {
315            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
316                if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
317                    if !self.0.contains_key(config_key.as_ref()) {
318                        self.0
319                            .insert(config_key.as_ref().to_string(), value.to_string());
320                    }
321                }
322            }
323        }
324    }
325
326    /// Subset of options relevant for s3 storage
327    pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
328        self.0
329            .iter()
330            .filter_map(|(key, value)| {
331                let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
332                Some((s3_key, value.clone()))
333            })
334            .collect()
335    }
336}
337
338impl ObjectStoreParams {
339    /// Create a new instance of [`ObjectStoreParams`] based on the AWS credentials.
340    pub fn with_aws_credentials(
341        aws_credentials: Option<AwsCredentialProvider>,
342        region: Option<String>,
343    ) -> Self {
344        Self {
345            aws_credentials,
346            storage_options: region
347                .map(|region| [("region".into(), region)].iter().cloned().collect()),
348            ..Default::default()
349        }
350    }
351}
352
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use object_store::path::Path;

    use crate::object_store::ObjectStoreRegistry;

    use super::*;

    // Credential provider that records whether it was ever asked for
    // credentials, so tests can verify the injected provider is used.
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait::async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }

    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // Not called yet
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // fails, but we don't care
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        // The injected provider must have been consulted by the read above
        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    #[test]
    fn test_s3_path_parsing() {
        let provider = AwsStoreProvider;

        let cases = [
            ("s3://bucket/path/to/file", "path/to/file"),
            // for non ASCII string tests
            ("s3://bucket/测试path/to/file", "测试path/to/file"),
            ("s3://bucket/path/&to/file", "path/&to/file"),
            ("s3://bucket/path/=to/file", "path/=to/file"),
            // query parameters (ddb table name) must not leak into the path
            (
                "s3+ddb://bucket/path/to/file?ddbTableName=test",
                "path/to/file",
            ),
        ];

        for (uri, expected_path) in cases {
            let url = Url::parse(uri).unwrap();
            let path = provider.extract_path(&url).unwrap();
            let expected_path = Path::from(expected_path);
            assert_eq!(path, expected_path)
        }
    }

    #[test]
    fn test_is_s3_express() {
        // (uri, initial options, expected result); a "--x-s3" bucket suffix
        // wins over an explicit S3Express="false" option.
        let cases = [
            (
                "s3://bucket/path/to/file",
                HashMap::from([(AmazonS3ConfigKey::S3Express, "true".into())]),
                true,
            ),
            (
                "s3://bucket/path/to/file",
                HashMap::from([(AmazonS3ConfigKey::S3Express, "false".into())]),
                false,
            ),
            ("s3://bucket/path/to/file", HashMap::from([]), false),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([(AmazonS3ConfigKey::S3Express, "true".into())]),
                true,
            ),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([(AmazonS3ConfigKey::S3Express, "false".into())]),
                true,
            ),
            ("s3://bucket--x-s3/path/to/file", HashMap::from([]), true),
        ];

        for (uri, mut configs, expected) in cases {
            let url = Url::parse(uri).unwrap();
            let is_s3_express = check_s3_express(&url, &mut configs);
            assert_eq!(is_s3_express, expected);
            // When detected, the option must have been normalized to "true".
            if is_s3_express {
                assert!(configs
                    .get(&AmazonS3ConfigKey::S3Express)
                    .is_some_and(|opt| opt == "true"));
            }
        }
    }
}