lance_io/object_store/providers/
aws.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3// SPDX-License-Identifier: Apache-2.0
4// SPDX-FileCopyrightText: Copyright The Lance Authors
5
6use std::{
7    collections::HashMap,
8    str::FromStr,
9    sync::Arc,
10    time::{Duration, SystemTime},
11};
12
13use aws_config::default_provider::credentials::DefaultCredentialsChain;
14use aws_credential_types::provider::ProvideCredentials;
15use object_store::{
16    aws::{
17        AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
18        AwsCredentialProvider,
19    },
20    ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
21    StaticCredentialProvider,
22};
23use snafu::location;
24use tokio::sync::RwLock;
25use url::Url;
26
27use crate::object_store::{
28    ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, DEFAULT_CLOUD_BLOCK_SIZE,
29    DEFAULT_CLOUD_IO_PARALLELISM, DEFAULT_MAX_IOP_SIZE,
30};
31use lance_core::error::{Error, Result};
32
/// [`ObjectStoreProvider`] implementation that builds S3-backed [`ObjectStore`]s.
///
/// This is a stateless unit struct: everything needed to build a store comes
/// from the URL and the [`ObjectStoreParams`] handed to `new_store`.
#[derive(Default, Debug)]
pub struct AwsStoreProvider;
35
36#[async_trait::async_trait]
37impl ObjectStoreProvider for AwsStoreProvider {
38    async fn new_store(
39        &self,
40        mut base_path: Url,
41        params: &ObjectStoreParams,
42    ) -> Result<ObjectStore> {
43        let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
44        let mut storage_options =
45            StorageOptions(params.storage_options.clone().unwrap_or_default());
46        let download_retry_count = storage_options.download_retry_count();
47
48        let max_retries = storage_options.client_max_retries();
49        let retry_timeout = storage_options.client_retry_timeout();
50        let retry_config = RetryConfig {
51            backoff: Default::default(),
52            max_retries,
53            retry_timeout: Duration::from_secs(retry_timeout),
54        };
55
56        storage_options.with_env_s3();
57
58        let mut storage_options = storage_options.as_s3_options();
59        let region = resolve_s3_region(&base_path, &storage_options).await?;
60        let (aws_creds, region) = build_aws_credential(
61            params.s3_credentials_refresh_offset,
62            params.aws_credentials.clone(),
63            Some(&storage_options),
64            region,
65        )
66        .await?;
67
68        // This will be default in next version of object store.
69        // https://github.com/apache/arrow-rs/pull/7181
70        // We can do this when we upgrade to 0.12.
71        storage_options
72            .entry(AmazonS3ConfigKey::ConditionalPut)
73            .or_insert_with(|| "etag".to_string());
74
75        // Cloudflare does not support varying part sizes.
76        let use_constant_size_upload_parts = storage_options
77            .get(&AmazonS3ConfigKey::Endpoint)
78            .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
79            .unwrap_or(false);
80
81        // before creating the OSObjectStore we need to rewrite the url to drop ddb related parts
82        base_path.set_scheme("s3").unwrap();
83        base_path.set_query(None);
84
85        // we can't use parse_url_opts here because we need to manually set the credentials provider
86        let mut builder = AmazonS3Builder::new();
87        for (key, value) in storage_options {
88            builder = builder.with_config(key, value);
89        }
90        builder = builder
91            .with_url(base_path.as_ref())
92            .with_credentials(aws_creds)
93            .with_retry(retry_config)
94            .with_region(region);
95        let inner = Arc::new(builder.build()?);
96
97        Ok(ObjectStore {
98            inner,
99            scheme: String::from(base_path.scheme()),
100            block_size,
101            max_iop_size: *DEFAULT_MAX_IOP_SIZE,
102            use_constant_size_upload_parts,
103            list_is_lexically_ordered: true,
104            io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
105            download_retry_count,
106        })
107    }
108}
109
110/// Figure out the S3 region of the bucket.
111///
112/// This resolves in order of precedence:
113/// 1. The region provided in the storage options
114/// 2. (If endpoint is not set), the region returned by the S3 API for the bucket
115///
116/// It can return None if no region is provided and the endpoint is set.
117async fn resolve_s3_region(
118    url: &Url,
119    storage_options: &HashMap<AmazonS3ConfigKey, String>,
120) -> Result<Option<String>> {
121    if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
122        Ok(Some(region.clone()))
123    } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
124        // If no endpoint is set, we can assume this is AWS S3 and the region
125        // can be resolved from the bucket.
126        let bucket = url.host_str().ok_or_else(|| {
127            Error::invalid_input(
128                format!("Could not parse bucket from url: {}", url),
129                location!(),
130            )
131        })?;
132
133        let mut client_options = ClientOptions::default();
134        for (key, value) in storage_options {
135            if let AmazonS3ConfigKey::Client(client_key) = key {
136                client_options = client_options.with_config(*client_key, value.clone());
137            }
138        }
139
140        let bucket_region =
141            object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
142        Ok(Some(bucket_region))
143    } else {
144        Ok(None)
145    }
146}
147
148/// Build AWS credentials
149///
150/// This resolves credentials from the following sources in order:
151/// 1. An explicit `credentials` provider
152/// 2. Explicit credentials in storage_options (as in `aws_access_key_id`,
153///    `aws_secret_access_key`, `aws_session_token`)
154/// 3. The default credential provider chain from AWS SDK.
155///
156/// `credentials_refresh_offset` is the amount of time before expiry to refresh credentials.
157pub async fn build_aws_credential(
158    credentials_refresh_offset: Duration,
159    credentials: Option<AwsCredentialProvider>,
160    storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
161    region: Option<String>,
162) -> Result<(AwsCredentialProvider, String)> {
163    // TODO: make this return no credential provider not using AWS
164    use aws_config::meta::region::RegionProviderChain;
165    const DEFAULT_REGION: &str = "us-west-2";
166
167    let region = if let Some(region) = region {
168        region
169    } else {
170        RegionProviderChain::default_provider()
171            .or_else(DEFAULT_REGION)
172            .region()
173            .await
174            .map(|r| r.as_ref().to_string())
175            .unwrap_or(DEFAULT_REGION.to_string())
176    };
177
178    if let Some(creds) = credentials {
179        Ok((creds, region))
180    } else if let Some(creds) = storage_options.and_then(extract_static_s3_credentials) {
181        Ok((Arc::new(creds), region))
182    } else {
183        let credentials_provider = DefaultCredentialsChain::builder().build().await;
184
185        Ok((
186            Arc::new(AwsCredentialAdapter::new(
187                Arc::new(credentials_provider),
188                credentials_refresh_offset,
189            )),
190            region,
191        ))
192    }
193}
194
195fn extract_static_s3_credentials(
196    options: &HashMap<AmazonS3ConfigKey, String>,
197) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
198    let key_id = options
199        .get(&AmazonS3ConfigKey::AccessKeyId)
200        .map(|s| s.to_string());
201    let secret_key = options
202        .get(&AmazonS3ConfigKey::SecretAccessKey)
203        .map(|s| s.to_string());
204    let token = options
205        .get(&AmazonS3ConfigKey::Token)
206        .map(|s| s.to_string());
207    match (key_id, secret_key, token) {
208        (Some(key_id), Some(secret_key), token) => {
209            Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
210                key_id,
211                secret_key,
212                token,
213            }))
214        }
215        _ => None,
216    }
217}
218
/// Adapt an AWS SDK cred into object_store credentials
///
/// Credentials obtained from `inner` are cached and handed out again until
/// they are within `credentials_refresh_offset` of their expiry.
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    /// The AWS SDK provider that is asked for fresh credentials.
    pub inner: Arc<dyn ProvideCredentials>,

    // RefCell can't be shared across threads, so the cache sits behind an
    // async RwLock. The map holds the most recently fetched credentials under
    // the single key `AWS_CREDS_CACHE_KEY`.
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    // The amount of time before expiry to refresh credentials
    credentials_refresh_offset: Duration,
}
230
231impl AwsCredentialAdapter {
232    pub fn new(
233        provider: Arc<dyn ProvideCredentials>,
234        credentials_refresh_offset: Duration,
235    ) -> Self {
236        Self {
237            inner: provider,
238            cache: Arc::new(RwLock::new(HashMap::new())),
239            credentials_refresh_offset,
240        }
241    }
242}
243
/// Key under which `AwsCredentialAdapter` stores its cached credentials.
const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";
245
246#[async_trait::async_trait]
247impl CredentialProvider for AwsCredentialAdapter {
248    type Credential = ObjectStoreAwsCredential;
249
250    async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
251        let cached_creds = {
252            let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
253            let expired = cache_value
254                .clone()
255                .map(|cred| {
256                    cred.expiry()
257                        .map(|exp| {
258                            exp.checked_sub(self.credentials_refresh_offset)
259                                .expect("this time should always be valid")
260                                < SystemTime::now()
261                        })
262                        // no expiry is never expire
263                        .unwrap_or(false)
264                })
265                .unwrap_or(true); // no cred is the same as expired;
266            if expired {
267                None
268            } else {
269                cache_value.clone()
270            }
271        };
272
273        if let Some(creds) = cached_creds {
274            Ok(Arc::new(Self::Credential {
275                key_id: creds.access_key_id().to_string(),
276                secret_key: creds.secret_access_key().to_string(),
277                token: creds.session_token().map(|s| s.to_string()),
278            }))
279        } else {
280            let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
281                |e| Error::Internal {
282                    message: format!("Failed to get AWS credentials: {}", e),
283                    location: location!(),
284                },
285            )?);
286
287            self.cache
288                .write()
289                .await
290                .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
291
292            Ok(Arc::new(Self::Credential {
293                key_id: refreshed_creds.access_key_id().to_string(),
294                secret_key: refreshed_creds.secret_access_key().to_string(),
295                token: refreshed_creds.session_token().map(|s| s.to_string()),
296            }))
297        }
298    }
299}
300
301impl StorageOptions {
302    /// Add values from the environment to storage options
303    pub fn with_env_s3(&mut self) {
304        for (os_key, os_value) in std::env::vars_os() {
305            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
306                if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
307                    if !self.0.contains_key(config_key.as_ref()) {
308                        self.0
309                            .insert(config_key.as_ref().to_string(), value.to_string());
310                    }
311                }
312            }
313        }
314    }
315
316    /// Subset of options relevant for s3 storage
317    pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
318        self.0
319            .iter()
320            .filter_map(|(key, value)| {
321                let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
322                Some((s3_key, value.clone()))
323            })
324            .collect()
325    }
326}
327
328impl ObjectStoreParams {
329    /// Create a new instance of [`ObjectStoreParams`] based on the AWS credentials.
330    pub fn with_aws_credentials(
331        aws_credentials: Option<AwsCredentialProvider>,
332        region: Option<String>,
333    ) -> Self {
334        Self {
335            aws_credentials,
336            storage_options: region
337                .map(|region| [("region".into(), region)].iter().cloned().collect()),
338            ..Default::default()
339        }
340    }
341}
342
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use object_store::path::Path;

    use crate::object_store::ObjectStoreRegistry;

    use super::*;

    /// Credential provider that records whether it was ever asked for
    /// credentials and always returns empty ones.
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait::async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }

    /// A provider injected through `ObjectStoreParams::aws_credentials` must be
    /// the one consulted when the store issues a request.
    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // Not called yet
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // The read itself fails, but we don't care: we only need the store to
        // attempt a request.
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        // The injected provider must have been called by now
        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    /// Both plain `s3` and `s3+ddb` URLs must yield the same object path; the
    /// query string is not part of it.
    #[test]
    fn test_s3_path_parsing() {
        let provider = AwsStoreProvider;

        let cases = [
            ("s3://bucket/path/to/file", "path/to/file"),
            (
                "s3+ddb://bucket/path/to/file?ddbTableName=test",
                "path/to/file",
            ),
        ];

        for (uri, expected_path) in cases {
            let url = Url::parse(uri).unwrap();
            let path = provider.extract_path(&url);
            let expected_path = Path::from(expected_path);
            assert_eq!(path, expected_path);
        }
    }
}