lance_io/object_store/providers/
aws.rs1use std::{
7 collections::HashMap,
8 str::FromStr,
9 sync::Arc,
10 time::{Duration, SystemTime},
11};
12
13use aws_config::default_provider::credentials::DefaultCredentialsChain;
14use aws_credential_types::provider::ProvideCredentials;
15use object_store::{
16 aws::{
17 AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
18 AwsCredentialProvider,
19 },
20 ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
21 StaticCredentialProvider,
22};
23use snafu::location;
24use tokio::sync::RwLock;
25use url::Url;
26
27use crate::object_store::{
28 ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, DEFAULT_CLOUD_BLOCK_SIZE,
29 DEFAULT_CLOUD_IO_PARALLELISM, DEFAULT_MAX_IOP_SIZE,
30};
31use lance_core::error::{Error, Result};
32
/// Object store provider for Amazon S3 (and S3-compatible) storage.
///
/// The type holds no state of its own; everything it needs is supplied through
/// the URL and parameters handed to its `new_store` implementation.
#[derive(Debug, Default)]
pub struct AwsStoreProvider;
35
36#[async_trait::async_trait]
37impl ObjectStoreProvider for AwsStoreProvider {
38 async fn new_store(
39 &self,
40 mut base_path: Url,
41 params: &ObjectStoreParams,
42 ) -> Result<ObjectStore> {
43 let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
44 let mut storage_options =
45 StorageOptions(params.storage_options.clone().unwrap_or_default());
46 let download_retry_count = storage_options.download_retry_count();
47
48 let max_retries = storage_options.client_max_retries();
49 let retry_timeout = storage_options.client_retry_timeout();
50 let retry_config = RetryConfig {
51 backoff: Default::default(),
52 max_retries,
53 retry_timeout: Duration::from_secs(retry_timeout),
54 };
55
56 storage_options.with_env_s3();
57
58 let mut storage_options = storage_options.as_s3_options();
59 let region = resolve_s3_region(&base_path, &storage_options).await?;
60 let (aws_creds, region) = build_aws_credential(
61 params.s3_credentials_refresh_offset,
62 params.aws_credentials.clone(),
63 Some(&storage_options),
64 region,
65 )
66 .await?;
67
68 storage_options
72 .entry(AmazonS3ConfigKey::ConditionalPut)
73 .or_insert_with(|| "etag".to_string());
74
75 let use_constant_size_upload_parts = storage_options
77 .get(&AmazonS3ConfigKey::Endpoint)
78 .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
79 .unwrap_or(false);
80
81 base_path.set_scheme("s3").unwrap();
83 base_path.set_query(None);
84
85 let mut builder = AmazonS3Builder::new();
87 for (key, value) in storage_options {
88 builder = builder.with_config(key, value);
89 }
90 builder = builder
91 .with_url(base_path.as_ref())
92 .with_credentials(aws_creds)
93 .with_retry(retry_config)
94 .with_region(region);
95 let inner = Arc::new(builder.build()?);
96
97 Ok(ObjectStore {
98 inner,
99 scheme: String::from(base_path.scheme()),
100 block_size,
101 max_iop_size: *DEFAULT_MAX_IOP_SIZE,
102 use_constant_size_upload_parts,
103 list_is_lexically_ordered: true,
104 io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
105 download_retry_count,
106 })
107 }
108}
109
110async fn resolve_s3_region(
118 url: &Url,
119 storage_options: &HashMap<AmazonS3ConfigKey, String>,
120) -> Result<Option<String>> {
121 if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
122 Ok(Some(region.clone()))
123 } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
124 let bucket = url.host_str().ok_or_else(|| {
127 Error::invalid_input(
128 format!("Could not parse bucket from url: {}", url),
129 location!(),
130 )
131 })?;
132
133 let mut client_options = ClientOptions::default();
134 for (key, value) in storage_options {
135 if let AmazonS3ConfigKey::Client(client_key) = key {
136 client_options = client_options.with_config(*client_key, value.clone());
137 }
138 }
139
140 let bucket_region =
141 object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
142 Ok(Some(bucket_region))
143 } else {
144 Ok(None)
145 }
146}
147
148pub async fn build_aws_credential(
158 credentials_refresh_offset: Duration,
159 credentials: Option<AwsCredentialProvider>,
160 storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
161 region: Option<String>,
162) -> Result<(AwsCredentialProvider, String)> {
163 use aws_config::meta::region::RegionProviderChain;
165 const DEFAULT_REGION: &str = "us-west-2";
166
167 let region = if let Some(region) = region {
168 region
169 } else {
170 RegionProviderChain::default_provider()
171 .or_else(DEFAULT_REGION)
172 .region()
173 .await
174 .map(|r| r.as_ref().to_string())
175 .unwrap_or(DEFAULT_REGION.to_string())
176 };
177
178 if let Some(creds) = credentials {
179 Ok((creds, region))
180 } else if let Some(creds) = storage_options.and_then(extract_static_s3_credentials) {
181 Ok((Arc::new(creds), region))
182 } else {
183 let credentials_provider = DefaultCredentialsChain::builder().build().await;
184
185 Ok((
186 Arc::new(AwsCredentialAdapter::new(
187 Arc::new(credentials_provider),
188 credentials_refresh_offset,
189 )),
190 region,
191 ))
192 }
193}
194
195fn extract_static_s3_credentials(
196 options: &HashMap<AmazonS3ConfigKey, String>,
197) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
198 let key_id = options
199 .get(&AmazonS3ConfigKey::AccessKeyId)
200 .map(|s| s.to_string());
201 let secret_key = options
202 .get(&AmazonS3ConfigKey::SecretAccessKey)
203 .map(|s| s.to_string());
204 let token = options
205 .get(&AmazonS3ConfigKey::Token)
206 .map(|s| s.to_string());
207 match (key_id, secret_key, token) {
208 (Some(key_id), Some(secret_key), token) => {
209 Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
210 key_id,
211 secret_key,
212 token,
213 }))
214 }
215 _ => None,
216 }
217}
218
219#[derive(Debug)]
221pub struct AwsCredentialAdapter {
222 pub inner: Arc<dyn ProvideCredentials>,
223
224 cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,
226
227 credentials_refresh_offset: Duration,
229}
230
231impl AwsCredentialAdapter {
232 pub fn new(
233 provider: Arc<dyn ProvideCredentials>,
234 credentials_refresh_offset: Duration,
235 ) -> Self {
236 Self {
237 inner: provider,
238 cache: Arc::new(RwLock::new(HashMap::new())),
239 credentials_refresh_offset,
240 }
241 }
242}
243
244const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";
245
246#[async_trait::async_trait]
247impl CredentialProvider for AwsCredentialAdapter {
248 type Credential = ObjectStoreAwsCredential;
249
250 async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
251 let cached_creds = {
252 let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
253 let expired = cache_value
254 .clone()
255 .map(|cred| {
256 cred.expiry()
257 .map(|exp| {
258 exp.checked_sub(self.credentials_refresh_offset)
259 .expect("this time should always be valid")
260 < SystemTime::now()
261 })
262 .unwrap_or(false)
264 })
265 .unwrap_or(true); if expired {
267 None
268 } else {
269 cache_value.clone()
270 }
271 };
272
273 if let Some(creds) = cached_creds {
274 Ok(Arc::new(Self::Credential {
275 key_id: creds.access_key_id().to_string(),
276 secret_key: creds.secret_access_key().to_string(),
277 token: creds.session_token().map(|s| s.to_string()),
278 }))
279 } else {
280 let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
281 |e| Error::Internal {
282 message: format!("Failed to get AWS credentials: {}", e),
283 location: location!(),
284 },
285 )?);
286
287 self.cache
288 .write()
289 .await
290 .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
291
292 Ok(Arc::new(Self::Credential {
293 key_id: refreshed_creds.access_key_id().to_string(),
294 secret_key: refreshed_creds.secret_access_key().to_string(),
295 token: refreshed_creds.session_token().map(|s| s.to_string()),
296 }))
297 }
298 }
299}
300
301impl StorageOptions {
302 pub fn with_env_s3(&mut self) {
304 for (os_key, os_value) in std::env::vars_os() {
305 if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
306 if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
307 if !self.0.contains_key(config_key.as_ref()) {
308 self.0
309 .insert(config_key.as_ref().to_string(), value.to_string());
310 }
311 }
312 }
313 }
314 }
315
316 pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
318 self.0
319 .iter()
320 .filter_map(|(key, value)| {
321 let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
322 Some((s3_key, value.clone()))
323 })
324 .collect()
325 }
326}
327
328impl ObjectStoreParams {
329 pub fn with_aws_credentials(
331 aws_credentials: Option<AwsCredentialProvider>,
332 region: Option<String>,
333 ) -> Self {
334 Self {
335 aws_credentials,
336 storage_options: region
337 .map(|region| [("region".into(), region)].iter().cloned().collect()),
338 ..Default::default()
339 }
340 }
341}
342
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use object_store::path::Path;

    use crate::object_store::ObjectStoreRegistry;

    use super::*;

    /// Credential provider that records whether it was ever asked for
    /// credentials and always returns empty ones.
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait::async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }

    /// A provider passed through `ObjectStoreParams::aws_credentials` must be
    /// the one the resulting store consults.
    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // Nothing has asked the mock for credentials yet.
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // The outcome of the read is irrelevant; it is only issued so that the
        // store has to ask for credentials.
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        // The injected provider must have been consulted by now.
        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    /// `extract_path` yields the object path for both plain `s3://` URLs and
    /// `s3+ddb://` URLs that carry a query string.
    #[test]
    fn test_s3_path_parsing() {
        let provider = AwsStoreProvider;

        let cases = [
            ("s3://bucket/path/to/file", "path/to/file"),
            (
                "s3+ddb://bucket/path/to/file?ddbTableName=test",
                "path/to/file",
            ),
        ];

        for (uri, expected_path) in cases {
            let url = Url::parse(uri).unwrap();
            let path = provider.extract_path(&url);
            let expected_path = Path::from(expected_path);
            assert_eq!(path, expected_path);
        }
    }
}