lance_io/object_store/providers/
aws.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors
5
6use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration};
7
8#[cfg(test)]
9use mock_instant::thread_local::{SystemTime, UNIX_EPOCH};
10
11#[cfg(not(test))]
12use std::time::{SystemTime, UNIX_EPOCH};
13
14use object_store::ObjectStore as OSObjectStore;
15use object_store_opendal::OpendalStore;
16use opendal::{services::S3, Operator};
17
18use aws_config::default_provider::credentials::DefaultCredentialsChain;
19use aws_credential_types::provider::ProvideCredentials;
20use object_store::{
21    aws::{
22        AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
23        AwsCredentialProvider,
24    },
25    ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
26    StaticCredentialProvider,
27};
28use snafu::location;
29use tokio::sync::RwLock;
30use url::Url;
31
32use crate::object_store::{
33    ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, DEFAULT_CLOUD_BLOCK_SIZE,
34    DEFAULT_CLOUD_IO_PARALLELISM, DEFAULT_MAX_IOP_SIZE,
35};
36use lance_core::error::{Error, Result};
37
/// Object store provider for `s3://` (and `s3+ddb://`) URLs, backed either by the
/// `object_store` Amazon S3 client or, when `use_opendal` is set, by OpenDAL.
#[derive(Debug, Default)]
pub struct AwsStoreProvider;
40
41impl AwsStoreProvider {
42    async fn build_amazon_s3_store(
43        &self,
44        base_path: &mut Url,
45        params: &ObjectStoreParams,
46        storage_options: &StorageOptions,
47        is_s3_express: bool,
48    ) -> Result<Arc<dyn OSObjectStore>> {
49        let max_retries = storage_options.client_max_retries();
50        let retry_timeout = storage_options.client_retry_timeout();
51        let retry_config = RetryConfig {
52            backoff: Default::default(),
53            max_retries,
54            retry_timeout: Duration::from_secs(retry_timeout),
55        };
56
57        let mut s3_storage_options = storage_options.as_s3_options();
58        let region = resolve_s3_region(base_path, &s3_storage_options).await?;
59        let (aws_creds, region) = build_aws_credential(
60            params.s3_credentials_refresh_offset,
61            params.aws_credentials.clone(),
62            Some(&s3_storage_options),
63            region,
64        )
65        .await?;
66
67        // Set S3Express flag if detected
68        if is_s3_express {
69            s3_storage_options.insert(AmazonS3ConfigKey::S3Express, true.to_string());
70        }
71
72        // before creating the OSObjectStore we need to rewrite the url to drop ddb related parts
73        base_path.set_scheme("s3").unwrap();
74        base_path.set_query(None);
75
76        // we can't use parse_url_opts here because we need to manually set the credentials provider
77        let mut builder = AmazonS3Builder::new();
78        for (key, value) in s3_storage_options {
79            builder = builder.with_config(key, value);
80        }
81        builder = builder
82            .with_url(base_path.as_ref())
83            .with_credentials(aws_creds)
84            .with_retry(retry_config)
85            .with_region(region);
86
87        Ok(Arc::new(builder.build()?) as Arc<dyn OSObjectStore>)
88    }
89
90    async fn build_opendal_s3_store(
91        &self,
92        base_path: &Url,
93        storage_options: &StorageOptions,
94    ) -> Result<Arc<dyn OSObjectStore>> {
95        let bucket = base_path
96            .host_str()
97            .ok_or_else(|| Error::invalid_input("S3 URL must contain bucket name", location!()))?
98            .to_string();
99
100        let prefix = base_path.path().trim_start_matches('/').to_string();
101
102        // Start with all storage options as the config map
103        // OpenDAL will handle environment variables through its default credentials chain
104        let mut config_map: HashMap<String, String> = storage_options.0.clone();
105
106        // Set required OpenDAL configuration
107        config_map.insert("bucket".to_string(), bucket);
108
109        if !prefix.is_empty() {
110            config_map.insert("root".to_string(), "/".to_string());
111        }
112
113        let operator = Operator::from_iter::<S3>(config_map)
114            .map_err(|e| {
115                Error::invalid_input(
116                    format!("Failed to create S3 operator: {:?}", e),
117                    location!(),
118                )
119            })?
120            .finish();
121
122        Ok(Arc::new(OpendalStore::new(operator)) as Arc<dyn OSObjectStore>)
123    }
124}
125
126#[async_trait::async_trait]
127impl ObjectStoreProvider for AwsStoreProvider {
128    async fn new_store(
129        &self,
130        mut base_path: Url,
131        params: &ObjectStoreParams,
132    ) -> Result<ObjectStore> {
133        let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
134        let mut storage_options =
135            StorageOptions(params.storage_options.clone().unwrap_or_default());
136        storage_options.with_env_s3();
137        let download_retry_count = storage_options.download_retry_count();
138
139        let use_opendal = storage_options
140            .0
141            .get("use_opendal")
142            .map(|v| v == "true")
143            .unwrap_or(false);
144
145        // Determine S3 Express and constant size upload parts before building the store
146        let is_s3_express = check_s3_express(&base_path, &storage_options);
147
148        let use_constant_size_upload_parts = storage_options
149            .0
150            .get("aws_endpoint")
151            .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
152            .unwrap_or(false);
153
154        let inner = if use_opendal {
155            // Use OpenDAL implementation
156            self.build_opendal_s3_store(&base_path, &storage_options)
157                .await?
158        } else {
159            // Use default Amazon S3 implementation
160            self.build_amazon_s3_store(&mut base_path, params, &storage_options, is_s3_express)
161                .await?
162        };
163
164        Ok(ObjectStore {
165            inner,
166            scheme: String::from(base_path.scheme()),
167            block_size,
168            max_iop_size: *DEFAULT_MAX_IOP_SIZE,
169            use_constant_size_upload_parts,
170            list_is_lexically_ordered: !is_s3_express,
171            io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
172            download_retry_count,
173        })
174    }
175}
176
177/// Check if the storage is S3 Express
178fn check_s3_express(url: &Url, storage_options: &StorageOptions) -> bool {
179    storage_options
180        .0
181        .get("s3_express")
182        .map(|v| v == "true")
183        .unwrap_or(false)
184        || url.authority().ends_with("--x-s3")
185}
186
187/// Figure out the S3 region of the bucket.
188///
189/// This resolves in order of precedence:
190/// 1. The region provided in the storage options
191/// 2. (If endpoint is not set), the region returned by the S3 API for the bucket
192///
193/// It can return None if no region is provided and the endpoint is set.
194async fn resolve_s3_region(
195    url: &Url,
196    storage_options: &HashMap<AmazonS3ConfigKey, String>,
197) -> Result<Option<String>> {
198    if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
199        Ok(Some(region.clone()))
200    } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
201        // If no endpoint is set, we can assume this is AWS S3 and the region
202        // can be resolved from the bucket.
203        let bucket = url.host_str().ok_or_else(|| {
204            Error::invalid_input(
205                format!("Could not parse bucket from url: {}", url),
206                location!(),
207            )
208        })?;
209
210        let mut client_options = ClientOptions::default();
211        for (key, value) in storage_options {
212            if let AmazonS3ConfigKey::Client(client_key) = key {
213                client_options = client_options.with_config(*client_key, value.clone());
214            }
215        }
216
217        let bucket_region =
218            object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
219        Ok(Some(bucket_region))
220    } else {
221        Ok(None)
222    }
223}
224
225/// Build AWS credentials
226///
227/// This resolves credentials from the following sources in order:
228/// 1. An explicit `credentials` provider
229/// 2. Explicit credentials in storage_options (as in `aws_access_key_id`,
230///    `aws_secret_access_key`, `aws_session_token`)
231/// 3. The default credential provider chain from AWS SDK.
232///
233/// `credentials_refresh_offset` is the amount of time before expiry to refresh credentials.
234pub async fn build_aws_credential(
235    credentials_refresh_offset: Duration,
236    credentials: Option<AwsCredentialProvider>,
237    storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
238    region: Option<String>,
239) -> Result<(AwsCredentialProvider, String)> {
240    // TODO: make this return no credential provider not using AWS
241    use aws_config::meta::region::RegionProviderChain;
242    const DEFAULT_REGION: &str = "us-west-2";
243
244    let region = if let Some(region) = region {
245        region
246    } else {
247        RegionProviderChain::default_provider()
248            .or_else(DEFAULT_REGION)
249            .region()
250            .await
251            .map(|r| r.as_ref().to_string())
252            .unwrap_or(DEFAULT_REGION.to_string())
253    };
254
255    if let Some(creds) = credentials {
256        Ok((creds, region))
257    } else if let Some(creds) = storage_options.and_then(extract_static_s3_credentials) {
258        Ok((Arc::new(creds), region))
259    } else {
260        let credentials_provider = DefaultCredentialsChain::builder().build().await;
261
262        Ok((
263            Arc::new(AwsCredentialAdapter::new(
264                Arc::new(credentials_provider),
265                credentials_refresh_offset,
266            )),
267            region,
268        ))
269    }
270}
271
272fn extract_static_s3_credentials(
273    options: &HashMap<AmazonS3ConfigKey, String>,
274) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
275    let key_id = options.get(&AmazonS3ConfigKey::AccessKeyId).cloned();
276    let secret_key = options.get(&AmazonS3ConfigKey::SecretAccessKey).cloned();
277    let token = options.get(&AmazonS3ConfigKey::Token).cloned();
278    match (key_id, secret_key, token) {
279        (Some(key_id), Some(secret_key), token) => {
280            Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
281                key_id,
282                secret_key,
283                token,
284            }))
285        }
286        _ => None,
287    }
288}
289
/// Adapt an AWS SDK cred into object_store credentials
///
/// Credentials fetched from `inner` are cached and served from the cache until the current
/// time passes their expiry minus `credentials_refresh_offset`; credentials without an
/// expiry are reused indefinitely.
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    /// The AWS SDK credential provider queried when the cache is empty or stale.
    pub inner: Arc<dyn ProvideCredentials>,

    // `get_credential` takes `&self`, so the cache needs interior mutability. A `RefCell`
    // is not `Sync` and cannot be shared across threads, so an async `RwLock` is used.
    // Within this file only the single key `AWS_CREDS_CACHE_KEY` is ever stored.
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    // The amount of time before expiry to refresh credentials
    credentials_refresh_offset: Duration,
}
301
302impl AwsCredentialAdapter {
303    pub fn new(
304        provider: Arc<dyn ProvideCredentials>,
305        credentials_refresh_offset: Duration,
306    ) -> Self {
307        Self {
308            inner: provider,
309            cache: Arc::new(RwLock::new(HashMap::new())),
310            credentials_refresh_offset,
311        }
312    }
313}
314
/// Key under which `AwsCredentialAdapter` stores its single cached credential.
const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";
316
/// Convert a `std::time::SystemTime` (as returned by the AWS SDK) into this module's
/// `SystemTime`: the mockable clock under `cfg(test)`, the std clock otherwise.
///
/// The conversion goes through the offset from the Unix epoch.
///
/// # Panics
/// Panics if `time` is earlier than the Unix epoch.
fn to_system_time(time: std::time::SystemTime) -> SystemTime {
    UNIX_EPOCH
        + time
            .duration_since(std::time::UNIX_EPOCH)
            .expect("time should be after UNIX_EPOCH")
}
324
325#[async_trait::async_trait]
326impl CredentialProvider for AwsCredentialAdapter {
327    type Credential = ObjectStoreAwsCredential;
328
329    async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
330        let cached_creds = {
331            let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
332            let expired = cache_value
333                .clone()
334                .map(|cred| {
335                    cred.expiry()
336                        .map(|exp| {
337                            to_system_time(exp)
338                                .checked_sub(self.credentials_refresh_offset)
339                                .expect("this time should always be valid")
340                                < SystemTime::now()
341                        })
342                        // no expiry is never expire
343                        .unwrap_or(false)
344                })
345                .unwrap_or(true); // no cred is the same as expired;
346            if expired {
347                None
348            } else {
349                cache_value.clone()
350            }
351        };
352
353        if let Some(creds) = cached_creds {
354            Ok(Arc::new(Self::Credential {
355                key_id: creds.access_key_id().to_string(),
356                secret_key: creds.secret_access_key().to_string(),
357                token: creds.session_token().map(|s| s.to_string()),
358            }))
359        } else {
360            let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
361                |e| Error::Internal {
362                    message: format!("Failed to get AWS credentials: {:?}", e),
363                    location: location!(),
364                },
365            )?);
366
367            self.cache
368                .write()
369                .await
370                .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
371
372            Ok(Arc::new(Self::Credential {
373                key_id: refreshed_creds.access_key_id().to_string(),
374                secret_key: refreshed_creds.secret_access_key().to_string(),
375                token: refreshed_creds.session_token().map(|s| s.to_string()),
376            }))
377        }
378    }
379}
380
381impl StorageOptions {
382    /// Add values from the environment to storage options
383    pub fn with_env_s3(&mut self) {
384        for (os_key, os_value) in std::env::vars_os() {
385            if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
386                if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
387                    if !self.0.contains_key(config_key.as_ref()) {
388                        self.0
389                            .insert(config_key.as_ref().to_string(), value.to_string());
390                    }
391                }
392            }
393        }
394    }
395
396    /// Subset of options relevant for s3 storage
397    pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
398        self.0
399            .iter()
400            .filter_map(|(key, value)| {
401                let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
402                Some((s3_key, value.clone()))
403            })
404            .collect()
405    }
406}
407
408impl ObjectStoreParams {
409    /// Create a new instance of [`ObjectStoreParams`] based on the AWS credentials.
410    pub fn with_aws_credentials(
411        aws_credentials: Option<AwsCredentialProvider>,
412        region: Option<String>,
413    ) -> Self {
414        Self {
415            aws_credentials,
416            storage_options: region
417                .map(|region| [("region".into(), region)].iter().cloned().collect()),
418            ..Default::default()
419        }
420    }
421}
422
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use object_store::path::Path;

    use crate::object_store::ObjectStoreRegistry;

    use super::*;

    /// Credential provider that records whether it was ever asked for credentials.
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait::async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }

    /// An `aws_credentials` provider injected through `ObjectStoreParams` is the one
    /// consulted when the store issues a request.
    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // Not called yet
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // fails, but we don't care
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        // The request above must have consulted the injected provider.
        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    #[test]
    fn test_s3_path_parsing() {
        let provider = AwsStoreProvider;

        let cases = [
            ("s3://bucket/path/to/file", "path/to/file"),
            // for non ASCII string tests
            ("s3://bucket/测试path/to/file", "测试path/to/file"),
            ("s3://bucket/path/&to/file", "path/&to/file"),
            ("s3://bucket/path/=to/file", "path/=to/file"),
            (
                "s3+ddb://bucket/path/to/file?ddbTableName=test",
                "path/to/file",
            ),
        ];

        for (uri, expected_path) in cases {
            let url = Url::parse(uri).unwrap();
            let path = provider.extract_path(&url).unwrap();
            let expected_path = Path::from(expected_path);
            assert_eq!(path, expected_path)
        }
    }

    #[test]
    fn test_is_s3_express() {
        let cases = [
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                false,
            ),
            ("s3://bucket/path/to/file", HashMap::from([]), false),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "true".to_string())]),
                true,
            ),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([("s3_express".to_string(), "false".to_string())]),
                true, // URL takes precedence
            ),
            ("s3://bucket--x-s3/path/to/file", HashMap::from([]), true),
        ];

        for (uri, storage_map, expected) in cases {
            let url = Url::parse(uri).unwrap();
            let storage_options = StorageOptions(storage_map);
            let is_s3_express = check_s3_express(&url, &storage_options);
            assert_eq!(is_s3_express, expected);
        }
    }

    #[tokio::test]
    async fn test_use_opendal_flag() {
        let provider = AwsStoreProvider;
        let url = Url::parse("s3://test-bucket/path").unwrap();
        let params_with_flag = ObjectStoreParams {
            storage_options: Some(HashMap::from([
                ("use_opendal".to_string(), "true".to_string()),
                ("region".to_string(), "us-west-2".to_string()),
            ])),
            ..Default::default()
        };

        let store = provider
            .new_store(url.clone(), &params_with_flag)
            .await
            .unwrap();
        assert_eq!(store.scheme, "s3");
    }
}