1use std::{
7 collections::HashMap,
8 str::FromStr,
9 sync::Arc,
10 time::{Duration, SystemTime},
11};
12
13use aws_config::default_provider::credentials::DefaultCredentialsChain;
14use aws_credential_types::provider::ProvideCredentials;
15use object_store::{
16 aws::{
17 AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
18 AwsCredentialProvider,
19 },
20 ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
21 StaticCredentialProvider,
22};
23use snafu::location;
24use tokio::sync::RwLock;
25use url::Url;
26
27use crate::object_store::{
28 ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, DEFAULT_CLOUD_BLOCK_SIZE,
29 DEFAULT_CLOUD_IO_PARALLELISM, DEFAULT_MAX_IOP_SIZE,
30};
31use lance_core::error::{Error, Result};
32
/// [`ObjectStoreProvider`] that builds stores backed by Amazon S3 or an
/// S3-compatible endpoint (see `new_store` for how options are resolved).
#[derive(Debug, Default)]
pub struct AwsStoreProvider;
35
36#[async_trait::async_trait]
37impl ObjectStoreProvider for AwsStoreProvider {
38 async fn new_store(
39 &self,
40 mut base_path: Url,
41 params: &ObjectStoreParams,
42 ) -> Result<ObjectStore> {
43 let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
44 let mut storage_options =
45 StorageOptions(params.storage_options.clone().unwrap_or_default());
46 let download_retry_count = storage_options.download_retry_count();
47
48 let max_retries = storage_options.client_max_retries();
49 let retry_timeout = storage_options.client_retry_timeout();
50 let retry_config = RetryConfig {
51 backoff: Default::default(),
52 max_retries,
53 retry_timeout: Duration::from_secs(retry_timeout),
54 };
55
56 storage_options.with_env_s3();
57
58 let mut storage_options = storage_options.as_s3_options();
59 let region = resolve_s3_region(&base_path, &storage_options).await?;
60 let (aws_creds, region) = build_aws_credential(
61 params.s3_credentials_refresh_offset,
62 params.aws_credentials.clone(),
63 Some(&storage_options),
64 region,
65 )
66 .await?;
67
68 storage_options
72 .entry(AmazonS3ConfigKey::ConditionalPut)
73 .or_insert_with(|| "etag".to_string());
74
75 let use_constant_size_upload_parts = storage_options
77 .get(&AmazonS3ConfigKey::Endpoint)
78 .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
79 .unwrap_or(false);
80
81 let is_s3_express = check_s3_express(&base_path, &mut storage_options);
82
83 base_path.set_scheme("s3").unwrap();
85 base_path.set_query(None);
86
87 let mut builder = AmazonS3Builder::new();
89 for (key, value) in storage_options {
90 builder = builder.with_config(key, value);
91 }
92 builder = builder
93 .with_url(base_path.as_ref())
94 .with_credentials(aws_creds)
95 .with_retry(retry_config)
96 .with_region(region);
97 let inner = Arc::new(builder.build()?);
98
99 Ok(ObjectStore {
100 inner,
101 scheme: String::from(base_path.scheme()),
102 block_size,
103 max_iop_size: *DEFAULT_MAX_IOP_SIZE,
104 use_constant_size_upload_parts,
105 list_is_lexically_ordered: !is_s3_express,
106 io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
107 download_retry_count,
108 })
109 }
110}
111
112fn check_s3_express(url: &Url, storage_options: &mut HashMap<AmazonS3ConfigKey, String>) -> bool {
114 if matches!(storage_options.get(&AmazonS3ConfigKey::S3Express), Some(val) if val == "true") {
115 return true;
116 }
117
118 if url.authority().ends_with("--x-s3") {
119 storage_options.insert(AmazonS3ConfigKey::S3Express, true.to_string());
120 return true;
121 }
122
123 false
124}
125
126async fn resolve_s3_region(
134 url: &Url,
135 storage_options: &HashMap<AmazonS3ConfigKey, String>,
136) -> Result<Option<String>> {
137 if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
138 Ok(Some(region.clone()))
139 } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
140 let bucket = url.host_str().ok_or_else(|| {
143 Error::invalid_input(
144 format!("Could not parse bucket from url: {}", url),
145 location!(),
146 )
147 })?;
148
149 let mut client_options = ClientOptions::default();
150 for (key, value) in storage_options {
151 if let AmazonS3ConfigKey::Client(client_key) = key {
152 client_options = client_options.with_config(*client_key, value.clone());
153 }
154 }
155
156 let bucket_region =
157 object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
158 Ok(Some(bucket_region))
159 } else {
160 Ok(None)
161 }
162}
163
164pub async fn build_aws_credential(
174 credentials_refresh_offset: Duration,
175 credentials: Option<AwsCredentialProvider>,
176 storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
177 region: Option<String>,
178) -> Result<(AwsCredentialProvider, String)> {
179 use aws_config::meta::region::RegionProviderChain;
181 const DEFAULT_REGION: &str = "us-west-2";
182
183 let region = if let Some(region) = region {
184 region
185 } else {
186 RegionProviderChain::default_provider()
187 .or_else(DEFAULT_REGION)
188 .region()
189 .await
190 .map(|r| r.as_ref().to_string())
191 .unwrap_or(DEFAULT_REGION.to_string())
192 };
193
194 if let Some(creds) = credentials {
195 Ok((creds, region))
196 } else if let Some(creds) = storage_options.and_then(extract_static_s3_credentials) {
197 Ok((Arc::new(creds), region))
198 } else {
199 let credentials_provider = DefaultCredentialsChain::builder().build().await;
200
201 Ok((
202 Arc::new(AwsCredentialAdapter::new(
203 Arc::new(credentials_provider),
204 credentials_refresh_offset,
205 )),
206 region,
207 ))
208 }
209}
210
211fn extract_static_s3_credentials(
212 options: &HashMap<AmazonS3ConfigKey, String>,
213) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
214 let key_id = options.get(&AmazonS3ConfigKey::AccessKeyId).cloned();
215 let secret_key = options.get(&AmazonS3ConfigKey::SecretAccessKey).cloned();
216 let token = options.get(&AmazonS3ConfigKey::Token).cloned();
217 match (key_id, secret_key, token) {
218 (Some(key_id), Some(secret_key), token) => {
219 Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
220 key_id,
221 secret_key,
222 token,
223 }))
224 }
225 _ => None,
226 }
227}
228
229#[derive(Debug)]
231pub struct AwsCredentialAdapter {
232 pub inner: Arc<dyn ProvideCredentials>,
233
234 cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,
236
237 credentials_refresh_offset: Duration,
239}
240
241impl AwsCredentialAdapter {
242 pub fn new(
243 provider: Arc<dyn ProvideCredentials>,
244 credentials_refresh_offset: Duration,
245 ) -> Self {
246 Self {
247 inner: provider,
248 cache: Arc::new(RwLock::new(HashMap::new())),
249 credentials_refresh_offset,
250 }
251 }
252}
253
254const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";
255
256#[async_trait::async_trait]
257impl CredentialProvider for AwsCredentialAdapter {
258 type Credential = ObjectStoreAwsCredential;
259
260 async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
261 let cached_creds = {
262 let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
263 let expired = cache_value
264 .clone()
265 .map(|cred| {
266 cred.expiry()
267 .map(|exp| {
268 exp.checked_sub(self.credentials_refresh_offset)
269 .expect("this time should always be valid")
270 < SystemTime::now()
271 })
272 .unwrap_or(false)
274 })
275 .unwrap_or(true); if expired {
277 None
278 } else {
279 cache_value.clone()
280 }
281 };
282
283 if let Some(creds) = cached_creds {
284 Ok(Arc::new(Self::Credential {
285 key_id: creds.access_key_id().to_string(),
286 secret_key: creds.secret_access_key().to_string(),
287 token: creds.session_token().map(|s| s.to_string()),
288 }))
289 } else {
290 let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
291 |e| Error::Internal {
292 message: format!("Failed to get AWS credentials: {:?}", e),
293 location: location!(),
294 },
295 )?);
296
297 self.cache
298 .write()
299 .await
300 .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
301
302 Ok(Arc::new(Self::Credential {
303 key_id: refreshed_creds.access_key_id().to_string(),
304 secret_key: refreshed_creds.secret_access_key().to_string(),
305 token: refreshed_creds.session_token().map(|s| s.to_string()),
306 }))
307 }
308 }
309}
310
311impl StorageOptions {
312 pub fn with_env_s3(&mut self) {
314 for (os_key, os_value) in std::env::vars_os() {
315 if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
316 if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
317 if !self.0.contains_key(config_key.as_ref()) {
318 self.0
319 .insert(config_key.as_ref().to_string(), value.to_string());
320 }
321 }
322 }
323 }
324 }
325
326 pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
328 self.0
329 .iter()
330 .filter_map(|(key, value)| {
331 let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
332 Some((s3_key, value.clone()))
333 })
334 .collect()
335 }
336}
337
338impl ObjectStoreParams {
339 pub fn with_aws_credentials(
341 aws_credentials: Option<AwsCredentialProvider>,
342 region: Option<String>,
343 ) -> Self {
344 Self {
345 aws_credentials,
346 storage_options: region
347 .map(|region| [("region".into(), region)].iter().cloned().collect()),
348 ..Default::default()
349 }
350 }
351}
352
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use object_store::path::Path;

    use crate::object_store::ObjectStoreRegistry;

    use super::*;

    /// Credential provider that records whether it was ever asked for credentials.
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait::async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }

    /// A provider injected through `ObjectStoreParams::aws_credentials` must be
    /// the one consulted when the store issues a request.
    // NOTE(review): with no region configured this presumably performs a
    // network bucket-region lookup — confirm the CI environment allows it.
    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // Nothing has requested credentials yet.
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // The read's outcome is ignored; only the credential lookup matters.
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    /// Object paths are extracted verbatim (including non-ASCII and `&`/`=`),
    /// and the `s3+ddb` query string is not part of the path.
    #[test]
    fn test_s3_path_parsing() {
        let provider = AwsStoreProvider;

        let cases = [
            ("s3://bucket/path/to/file", "path/to/file"),
            ("s3://bucket/测试path/to/file", "测试path/to/file"),
            ("s3://bucket/path/&to/file", "path/&to/file"),
            ("s3://bucket/path/=to/file", "path/=to/file"),
            (
                "s3+ddb://bucket/path/to/file?ddbTableName=test",
                "path/to/file",
            ),
        ];

        for (uri, expected_path) in cases {
            let url = Url::parse(uri).unwrap();
            let path = provider.extract_path(&url).unwrap();
            let expected_path = Path::from(expected_path);
            assert_eq!(path, expected_path)
        }
    }

    /// S3 Express is detected from the option or a `--x-s3` bucket suffix, and
    /// a positive result always leaves `S3Express = "true"` in the options.
    #[test]
    fn test_is_s3_express() {
        let cases = [
            (
                "s3://bucket/path/to/file",
                HashMap::from([(AmazonS3ConfigKey::S3Express, "true".into())]),
                true,
            ),
            (
                "s3://bucket/path/to/file",
                HashMap::from([(AmazonS3ConfigKey::S3Express, "false".into())]),
                false,
            ),
            ("s3://bucket/path/to/file", HashMap::from([]), false),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([(AmazonS3ConfigKey::S3Express, "true".into())]),
                true,
            ),
            (
                "s3://bucket--x-s3/path/to/file",
                HashMap::from([(AmazonS3ConfigKey::S3Express, "false".into())]),
                true,
            ),
            ("s3://bucket--x-s3/path/to/file", HashMap::from([]), true),
        ];

        for (uri, mut configs, expected) in cases {
            let url = Url::parse(uri).unwrap();
            let is_s3_express = check_s3_express(&url, &mut configs);
            assert_eq!(is_s3_express, expected);
            if is_s3_express {
                assert!(configs
                    .get(&AmazonS3ConfigKey::S3Express)
                    .is_some_and(|opt| opt == "true"));
            }
        }
    }
}