lance_io/object_store/providers/
aws.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3// SPDX-License-Identifier: Apache-2.0
4// SPDX-FileCopyrightText: Copyright The Lance Authors
5
6use std::{
7    collections::HashMap,
8    str::FromStr,
9    sync::Arc,
10    time::{Duration, SystemTime},
11};
12
13use aws_config::default_provider::credentials::DefaultCredentialsChain;
14use aws_credential_types::provider::ProvideCredentials;
15use object_store::{
16    aws::{
17        AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
18        AwsCredentialProvider,
19    },
20    ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
21    StaticCredentialProvider,
22};
23use snafu::location;
24use tokio::sync::RwLock;
25use url::Url;
26
27use crate::object_store::{
28    ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, DEFAULT_CLOUD_BLOCK_SIZE,
29    DEFAULT_CLOUD_IO_PARALLELISM,
30};
31use lance_core::error::{Error, Result};
32
/// Object store provider for AWS S3 and S3-compatible endpoints.
///
/// The provider carries no state of its own; everything it needs is taken
/// from the URL and parameters handed to `new_store`.
#[derive(Debug, Default)]
pub struct AwsStoreProvider;
35
36#[async_trait::async_trait]
37impl ObjectStoreProvider for AwsStoreProvider {
38    async fn new_store(
39        &self,
40        mut base_path: Url,
41        params: &ObjectStoreParams,
42    ) -> Result<ObjectStore> {
43        let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
44        let mut storage_options =
45            StorageOptions(params.storage_options.clone().unwrap_or_default());
46        let download_retry_count = storage_options.download_retry_count();
47
48        let max_retries = storage_options.client_max_retries();
49        let retry_timeout = storage_options.client_retry_timeout();
50        let retry_config = RetryConfig {
51            backoff: Default::default(),
52            max_retries,
53            retry_timeout: Duration::from_secs(retry_timeout),
54        };
55
56        storage_options.with_env_s3();
57
58        let mut storage_options = storage_options.as_s3_options();
59        let region = resolve_s3_region(&base_path, &storage_options).await?;
60        let (aws_creds, region) = build_aws_credential(
61            params.s3_credentials_refresh_offset,
62            params.aws_credentials.clone(),
63            Some(&storage_options),
64            region,
65        )
66        .await?;
67
68        // This will be default in next version of object store.
69        // https://github.com/apache/arrow-rs/pull/7181
70        // We can do this when we upgrade to 0.12.
71        storage_options
72            .entry(AmazonS3ConfigKey::ConditionalPut)
73            .or_insert_with(|| "etag".to_string());
74
75        // Cloudflare does not support varying part sizes.
76        let use_constant_size_upload_parts = storage_options
77            .get(&AmazonS3ConfigKey::Endpoint)
78            .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
79            .unwrap_or(false);
80
81        // before creating the OSObjectStore we need to rewrite the url to drop ddb related parts
82        base_path.set_scheme("s3").unwrap();
83        base_path.set_query(None);
84
85        // we can't use parse_url_opts here because we need to manually set the credentials provider
86        let mut builder = AmazonS3Builder::new();
87        for (key, value) in storage_options {
88            builder = builder.with_config(key, value);
89        }
90        builder = builder
91            .with_url(base_path.as_ref())
92            .with_credentials(aws_creds)
93            .with_retry(retry_config)
94            .with_region(region);
95        let inner = Arc::new(builder.build()?);
96
97        Ok(ObjectStore {
98            inner,
99            scheme: String::from(base_path.scheme()),
100            block_size,
101            use_constant_size_upload_parts,
102            list_is_lexically_ordered: true,
103            io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
104            download_retry_count,
105        })
106    }
107}
108
109/// Figure out the S3 region of the bucket.
110///
111/// This resolves in order of precedence:
112/// 1. The region provided in the storage options
113/// 2. (If endpoint is not set), the region returned by the S3 API for the bucket
114///
115/// It can return None if no region is provided and the endpoint is set.
116async fn resolve_s3_region(
117    url: &Url,
118    storage_options: &HashMap<AmazonS3ConfigKey, String>,
119) -> Result<Option<String>> {
120    if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
121        Ok(Some(region.clone()))
122    } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
123        // If no endpoint is set, we can assume this is AWS S3 and the region
124        // can be resolved from the bucket.
125        let bucket = url.host_str().ok_or_else(|| {
126            Error::invalid_input(
127                format!("Could not parse bucket from url: {}", url),
128                location!(),
129            )
130        })?;
131
132        let mut client_options = ClientOptions::default();
133        for (key, value) in storage_options {
134            if let AmazonS3ConfigKey::Client(client_key) = key {
135                client_options = client_options.with_config(*client_key, value.clone());
136            }
137        }
138
139        let bucket_region =
140            object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
141        Ok(Some(bucket_region))
142    } else {
143        Ok(None)
144    }
145}
146
147/// Build AWS credentials
148///
149/// This resolves credentials from the following sources in order:
150/// 1. An explicit `credentials` provider
151/// 2. Explicit credentials in storage_options (as in `aws_access_key_id`,
152///    `aws_secret_access_key`, `aws_session_token`)
153/// 3. The default credential provider chain from AWS SDK.
154///
155/// `credentials_refresh_offset` is the amount of time before expiry to refresh credentials.
156pub async fn build_aws_credential(
157    credentials_refresh_offset: Duration,
158    credentials: Option<AwsCredentialProvider>,
159    storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
160    region: Option<String>,
161) -> Result<(AwsCredentialProvider, String)> {
162    // TODO: make this return no credential provider not using AWS
163    use aws_config::meta::region::RegionProviderChain;
164    const DEFAULT_REGION: &str = "us-west-2";
165
166    let region = if let Some(region) = region {
167        region
168    } else {
169        RegionProviderChain::default_provider()
170            .or_else(DEFAULT_REGION)
171            .region()
172            .await
173            .map(|r| r.as_ref().to_string())
174            .unwrap_or(DEFAULT_REGION.to_string())
175    };
176
177    if let Some(creds) = credentials {
178        Ok((creds, region))
179    } else if let Some(creds) = storage_options.and_then(extract_static_s3_credentials) {
180        Ok((Arc::new(creds), region))
181    } else {
182        let credentials_provider = DefaultCredentialsChain::builder().build().await;
183
184        Ok((
185            Arc::new(AwsCredentialAdapter::new(
186                Arc::new(credentials_provider),
187                credentials_refresh_offset,
188            )),
189            region,
190        ))
191    }
192}
193
194fn extract_static_s3_credentials(
195    options: &HashMap<AmazonS3ConfigKey, String>,
196) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
197    let key_id = options
198        .get(&AmazonS3ConfigKey::AccessKeyId)
199        .map(|s| s.to_string());
200    let secret_key = options
201        .get(&AmazonS3ConfigKey::SecretAccessKey)
202        .map(|s| s.to_string());
203    let token = options
204        .get(&AmazonS3ConfigKey::Token)
205        .map(|s| s.to_string());
206    match (key_id, secret_key, token) {
207        (Some(key_id), Some(secret_key), token) => {
208            Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
209                key_id,
210                secret_key,
211                token,
212            }))
213        }
214        _ => None,
215    }
216}
217
218/// Adapt an AWS SDK cred into object_store credentials
219#[derive(Debug)]
220pub struct AwsCredentialAdapter {
221    pub inner: Arc<dyn ProvideCredentials>,
222
223    // RefCell can't be shared across threads, so we use HashMap
224    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,
225
226    // The amount of time before expiry to refresh credentials
227    credentials_refresh_offset: Duration,
228}
229
230impl AwsCredentialAdapter {
231    pub fn new(
232        provider: Arc<dyn ProvideCredentials>,
233        credentials_refresh_offset: Duration,
234    ) -> Self {
235        Self {
236            inner: provider,
237            cache: Arc::new(RwLock::new(HashMap::new())),
238            credentials_refresh_offset,
239        }
240    }
241}
242
243const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";
244
245#[async_trait::async_trait]
246impl CredentialProvider for AwsCredentialAdapter {
247    type Credential = ObjectStoreAwsCredential;
248
249    async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
250        let cached_creds = {
251            let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
252            let expired = cache_value
253                .clone()
254                .map(|cred| {
255                    cred.expiry()
256                        .map(|exp| {
257                            exp.checked_sub(self.credentials_refresh_offset)
258                                .expect("this time should always be valid")
259                                < SystemTime::now()
260                        })
261                        // no expiry is never expire
262                        .unwrap_or(false)
263                })
264                .unwrap_or(true); // no cred is the same as expired;
265            if expired {
266                None
267            } else {
268                cache_value.clone()
269            }
270        };
271
272        if let Some(creds) = cached_creds {
273            Ok(Arc::new(Self::Credential {
274                key_id: creds.access_key_id().to_string(),
275                secret_key: creds.secret_access_key().to_string(),
276                token: creds.session_token().map(|s| s.to_string()),
277            }))
278        } else {
279            let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
280                |e| Error::Internal {
281                    message: format!("Failed to get AWS credentials: {}", e),
282                    location: location!(),
283                },
284            )?);
285
286            self.cache
287                .write()
288                .await
289                .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
290
291            Ok(Arc::new(Self::Credential {
292                key_id: refreshed_creds.access_key_id().to_string(),
293                secret_key: refreshed_creds.secret_access_key().to_string(),
294                token: refreshed_creds.session_token().map(|s| s.to_string()),
295            }))
296        }
297    }
298}
299
300impl StorageOptions {
301    /// Add values from the environment to storage options
302    pub fn with_env_s3(&mut self) {
303        for (os_key, os_value) in std::env::vars_os() {
304            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
305                if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
306                    if !self.0.contains_key(config_key.as_ref()) {
307                        self.0
308                            .insert(config_key.as_ref().to_string(), value.to_string());
309                    }
310                }
311            }
312        }
313    }
314
315    /// Subset of options relevant for s3 storage
316    pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
317        self.0
318            .iter()
319            .filter_map(|(key, value)| {
320                let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
321                Some((s3_key, value.clone()))
322            })
323            .collect()
324    }
325}
326
327impl ObjectStoreParams {
328    /// Create a new instance of [`ObjectStoreParams`] based on the AWS credentials.
329    pub fn with_aws_credentials(
330        aws_credentials: Option<AwsCredentialProvider>,
331        region: Option<String>,
332    ) -> Self {
333        Self {
334            aws_credentials,
335            storage_options: region
336                .map(|region| [("region".into(), region)].iter().cloned().collect()),
337            ..Default::default()
338        }
339    }
340}
341
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use object_store::path::Path;

    use crate::object_store::ObjectStoreRegistry;

    use super::*;

    /// Credential provider that records whether it was ever asked for credentials.
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait::async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            let empty = Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            };
            Ok(Arc::new(empty))
        }
    }

    /// A provider injected through `ObjectStoreParams::aws_credentials` must be
    /// the one consulted when the store issues a request.
    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // Building the params alone must not touch the provider.
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // The read itself fails, but we only care that it asked for credentials.
        let root = Path::parse("/").unwrap();
        let reader = store.open(&root).await.unwrap();
        let _ = reader.get_range(0..1).await;

        // The request must have gone through the injected provider.
        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    /// Both plain `s3` and `s3+ddb` urls must yield the same object path.
    #[test]
    fn test_s3_path_parsing() {
        let provider = AwsStoreProvider;

        let cases = [
            ("s3://bucket/path/to/file", "path/to/file"),
            (
                "s3+ddb://bucket/path/to/file?ddbTableName=test",
                "path/to/file",
            ),
        ];

        for (uri, expected) in cases {
            let parsed = Url::parse(uri).unwrap();
            let actual = provider.extract_path(&parsed);
            assert_eq!(actual, Path::from(expected));
        }
    }
}