lance_io/object_store/providers/
aws.rs1use std::{
7 collections::HashMap,
8 str::FromStr,
9 sync::Arc,
10 time::{Duration, SystemTime},
11};
12
13use aws_config::default_provider::credentials::DefaultCredentialsChain;
14use aws_credential_types::provider::ProvideCredentials;
15use object_store::{
16 aws::{
17 AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
18 AwsCredentialProvider,
19 },
20 ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
21 StaticCredentialProvider,
22};
23use snafu::location;
24use tokio::sync::RwLock;
25use url::Url;
26
27use crate::object_store::{
28 ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, DEFAULT_CLOUD_BLOCK_SIZE,
29 DEFAULT_CLOUD_IO_PARALLELISM, DEFAULT_MAX_IOP_SIZE,
30};
31use lance_core::error::{Error, Result};
32
/// Object store provider for Amazon S3 and S3-compatible endpoints.
///
/// The type carries no state; all configuration comes from the `ObjectStoreParams`
/// passed to `new_store`.
#[derive(Debug, Default)]
pub struct AwsStoreProvider;
35
36#[async_trait::async_trait]
37impl ObjectStoreProvider for AwsStoreProvider {
38 async fn new_store(
39 &self,
40 mut base_path: Url,
41 params: &ObjectStoreParams,
42 ) -> Result<ObjectStore> {
43 let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
44 let mut storage_options =
45 StorageOptions(params.storage_options.clone().unwrap_or_default());
46 let download_retry_count = storage_options.download_retry_count();
47
48 let max_retries = storage_options.client_max_retries();
49 let retry_timeout = storage_options.client_retry_timeout();
50 let retry_config = RetryConfig {
51 backoff: Default::default(),
52 max_retries,
53 retry_timeout: Duration::from_secs(retry_timeout),
54 };
55
56 storage_options.with_env_s3();
57
58 let mut storage_options = storage_options.as_s3_options();
59 let region = resolve_s3_region(&base_path, &storage_options).await?;
60 let (aws_creds, region) = build_aws_credential(
61 params.s3_credentials_refresh_offset,
62 params.aws_credentials.clone(),
63 Some(&storage_options),
64 region,
65 )
66 .await?;
67
68 storage_options
72 .entry(AmazonS3ConfigKey::ConditionalPut)
73 .or_insert_with(|| "etag".to_string());
74
75 let use_constant_size_upload_parts = storage_options
77 .get(&AmazonS3ConfigKey::Endpoint)
78 .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
79 .unwrap_or(false);
80
81 let is_s3_express = check_s3_express(&base_path, &mut storage_options);
82
83 base_path.set_scheme("s3").unwrap();
85 base_path.set_query(None);
86
87 let mut builder = AmazonS3Builder::new();
89 for (key, value) in storage_options {
90 builder = builder.with_config(key, value);
91 }
92 builder = builder
93 .with_url(base_path.as_ref())
94 .with_credentials(aws_creds)
95 .with_retry(retry_config)
96 .with_region(region);
97 let inner = Arc::new(builder.build()?);
98
99 Ok(ObjectStore {
100 inner,
101 scheme: String::from(base_path.scheme()),
102 block_size,
103 max_iop_size: *DEFAULT_MAX_IOP_SIZE,
104 use_constant_size_upload_parts,
105 list_is_lexically_ordered: !is_s3_express,
106 io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
107 download_retry_count,
108 })
109 }
110}
111
112fn check_s3_express(url: &Url, storage_options: &mut HashMap<AmazonS3ConfigKey, String>) -> bool {
114 if matches!(storage_options.get(&AmazonS3ConfigKey::S3Express), Some(val) if val == "true") {
115 return true;
116 }
117
118 if url.authority().ends_with("--x-s3") {
119 storage_options.insert(AmazonS3ConfigKey::S3Express, true.to_string());
120 return true;
121 }
122
123 false
124}
125
126async fn resolve_s3_region(
134 url: &Url,
135 storage_options: &HashMap<AmazonS3ConfigKey, String>,
136) -> Result<Option<String>> {
137 if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
138 Ok(Some(region.clone()))
139 } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
140 let bucket = url.host_str().ok_or_else(|| {
143 Error::invalid_input(
144 format!("Could not parse bucket from url: {}", url),
145 location!(),
146 )
147 })?;
148
149 let mut client_options = ClientOptions::default();
150 for (key, value) in storage_options {
151 if let AmazonS3ConfigKey::Client(client_key) = key {
152 client_options = client_options.with_config(*client_key, value.clone());
153 }
154 }
155
156 let bucket_region =
157 object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
158 Ok(Some(bucket_region))
159 } else {
160 Ok(None)
161 }
162}
163
164pub async fn build_aws_credential(
174 credentials_refresh_offset: Duration,
175 credentials: Option<AwsCredentialProvider>,
176 storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
177 region: Option<String>,
178) -> Result<(AwsCredentialProvider, String)> {
179 use aws_config::meta::region::RegionProviderChain;
181 const DEFAULT_REGION: &str = "us-west-2";
182
183 let region = if let Some(region) = region {
184 region
185 } else {
186 RegionProviderChain::default_provider()
187 .or_else(DEFAULT_REGION)
188 .region()
189 .await
190 .map(|r| r.as_ref().to_string())
191 .unwrap_or(DEFAULT_REGION.to_string())
192 };
193
194 if let Some(creds) = credentials {
195 Ok((creds, region))
196 } else if let Some(creds) = storage_options.and_then(extract_static_s3_credentials) {
197 Ok((Arc::new(creds), region))
198 } else {
199 let credentials_provider = DefaultCredentialsChain::builder().build().await;
200
201 Ok((
202 Arc::new(AwsCredentialAdapter::new(
203 Arc::new(credentials_provider),
204 credentials_refresh_offset,
205 )),
206 region,
207 ))
208 }
209}
210
211fn extract_static_s3_credentials(
212 options: &HashMap<AmazonS3ConfigKey, String>,
213) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
214 let key_id = options
215 .get(&AmazonS3ConfigKey::AccessKeyId)
216 .map(|s| s.to_string());
217 let secret_key = options
218 .get(&AmazonS3ConfigKey::SecretAccessKey)
219 .map(|s| s.to_string());
220 let token = options
221 .get(&AmazonS3ConfigKey::Token)
222 .map(|s| s.to_string());
223 match (key_id, secret_key, token) {
224 (Some(key_id), Some(secret_key), token) => {
225 Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
226 key_id,
227 secret_key,
228 token,
229 }))
230 }
231 _ => None,
232 }
233}
234
235#[derive(Debug)]
237pub struct AwsCredentialAdapter {
238 pub inner: Arc<dyn ProvideCredentials>,
239
240 cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,
242
243 credentials_refresh_offset: Duration,
245}
246
247impl AwsCredentialAdapter {
248 pub fn new(
249 provider: Arc<dyn ProvideCredentials>,
250 credentials_refresh_offset: Duration,
251 ) -> Self {
252 Self {
253 inner: provider,
254 cache: Arc::new(RwLock::new(HashMap::new())),
255 credentials_refresh_offset,
256 }
257 }
258}
259
260const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";
261
262#[async_trait::async_trait]
263impl CredentialProvider for AwsCredentialAdapter {
264 type Credential = ObjectStoreAwsCredential;
265
266 async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
267 let cached_creds = {
268 let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
269 let expired = cache_value
270 .clone()
271 .map(|cred| {
272 cred.expiry()
273 .map(|exp| {
274 exp.checked_sub(self.credentials_refresh_offset)
275 .expect("this time should always be valid")
276 < SystemTime::now()
277 })
278 .unwrap_or(false)
280 })
281 .unwrap_or(true); if expired {
283 None
284 } else {
285 cache_value.clone()
286 }
287 };
288
289 if let Some(creds) = cached_creds {
290 Ok(Arc::new(Self::Credential {
291 key_id: creds.access_key_id().to_string(),
292 secret_key: creds.secret_access_key().to_string(),
293 token: creds.session_token().map(|s| s.to_string()),
294 }))
295 } else {
296 let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
297 |e| Error::Internal {
298 message: format!("Failed to get AWS credentials: {}", e),
299 location: location!(),
300 },
301 )?);
302
303 self.cache
304 .write()
305 .await
306 .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
307
308 Ok(Arc::new(Self::Credential {
309 key_id: refreshed_creds.access_key_id().to_string(),
310 secret_key: refreshed_creds.secret_access_key().to_string(),
311 token: refreshed_creds.session_token().map(|s| s.to_string()),
312 }))
313 }
314 }
315}
316
317impl StorageOptions {
318 pub fn with_env_s3(&mut self) {
320 for (os_key, os_value) in std::env::vars_os() {
321 if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
322 if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
323 if !self.0.contains_key(config_key.as_ref()) {
324 self.0
325 .insert(config_key.as_ref().to_string(), value.to_string());
326 }
327 }
328 }
329 }
330 }
331
332 pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
334 self.0
335 .iter()
336 .filter_map(|(key, value)| {
337 let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
338 Some((s3_key, value.clone()))
339 })
340 .collect()
341 }
342}
343
344impl ObjectStoreParams {
345 pub fn with_aws_credentials(
347 aws_credentials: Option<AwsCredentialProvider>,
348 region: Option<String>,
349 ) -> Self {
350 Self {
351 aws_credentials,
352 storage_options: region
353 .map(|region| [("region".into(), region)].iter().cloned().collect()),
354 ..Default::default()
355 }
356 }
357}
358
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use object_store::path::Path;

    use crate::object_store::ObjectStoreRegistry;

    use super::*;

    /// Credential provider that records whether it was asked for credentials.
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait::async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }

    /// An injected `aws_credentials` provider must be consulted when the store issues a request.
    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // The request itself is expected to fail (the bucket does not exist). Only whether
        // credentials were requested matters here, so the result is deliberately ignored.
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    /// Both `s3://` and `s3+ddb://` URLs map to the same object path.
    #[test]
    fn test_s3_path_parsing() {
        let provider = AwsStoreProvider;

        let cases = [
            ("s3://bucket/path/to/file", "path/to/file"),
            (
                "s3+ddb://bucket/path/to/file?ddbTableName=test",
                "path/to/file",
            ),
        ];

        for (uri, expected_path) in cases {
            let url = Url::parse(uri).unwrap();
            let path = provider.extract_path(&url);
            let expected_path = Path::from(expected_path);
            assert_eq!(path, expected_path);
        }
    }

    /// S3 Express is detected from the explicit option or from a `--x-s3` bucket suffix; a
    /// detected bucket always ends up with `S3Express = "true"` in the options.
    #[test]
    fn test_is_s3_express() {
        let cases = [
            (
                "s3://bucket/path/to/file",
                HashMap::from([(AmazonS3ConfigKey::S3Express, "true".into())]),
                true,
            ),
            (
                "s3://bucket/path/to/file",
                HashMap::from([(AmazonS3ConfigKey::S3Express, "false".into())]),
                false,
            ),
            ("s3://bucket/path/to/file", HashMap::from([]), false),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([(AmazonS3ConfigKey::S3Express, "true".into())]),
                true,
            ),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([(AmazonS3ConfigKey::S3Express, "false".into())]),
                true,
            ),
            ("s3://bucket--x-s3/path/to/file", HashMap::from([]), true),
        ];

        for (uri, mut configs, expected) in cases {
            let url = Url::parse(uri).unwrap();
            let is_s3_express = check_s3_express(&url, &mut configs);
            assert_eq!(is_s3_express, expected);
            if is_s3_express {
                assert!(configs
                    .get(&AmazonS3ConfigKey::S3Express)
                    .is_some_and(|opt| opt == "true"));
            }
        }
    }
}