lance_io/object_store/providers/
aws.rs1use std::{
7 collections::HashMap,
8 str::FromStr,
9 sync::Arc,
10 time::{Duration, SystemTime},
11};
12
13use aws_config::default_provider::credentials::DefaultCredentialsChain;
14use aws_credential_types::provider::ProvideCredentials;
15use object_store::{
16 aws::{
17 AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential,
18 AwsCredentialProvider,
19 },
20 ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig,
21 StaticCredentialProvider,
22};
23use snafu::location;
24use tokio::sync::RwLock;
25use url::Url;
26
27use crate::object_store::{
28 ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, DEFAULT_CLOUD_BLOCK_SIZE,
29 DEFAULT_CLOUD_IO_PARALLELISM,
30};
31use lance_core::error::{Error, Result};
32
/// Object store provider for Amazon S3 (and S3-compatible endpoints).
///
/// The type carries no state; all configuration is supplied per call through
/// the `ObjectStoreProvider` implementation.
#[derive(Debug, Default)]
pub struct AwsStoreProvider;
35
36#[async_trait::async_trait]
37impl ObjectStoreProvider for AwsStoreProvider {
38 async fn new_store(
39 &self,
40 mut base_path: Url,
41 params: &ObjectStoreParams,
42 ) -> Result<ObjectStore> {
43 let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE);
44 let mut storage_options =
45 StorageOptions(params.storage_options.clone().unwrap_or_default());
46 let download_retry_count = storage_options.download_retry_count();
47
48 let max_retries = storage_options.client_max_retries();
49 let retry_timeout = storage_options.client_retry_timeout();
50 let retry_config = RetryConfig {
51 backoff: Default::default(),
52 max_retries,
53 retry_timeout: Duration::from_secs(retry_timeout),
54 };
55
56 storage_options.with_env_s3();
57
58 let mut storage_options = storage_options.as_s3_options();
59 let region = resolve_s3_region(&base_path, &storage_options).await?;
60 let (aws_creds, region) = build_aws_credential(
61 params.s3_credentials_refresh_offset,
62 params.aws_credentials.clone(),
63 Some(&storage_options),
64 region,
65 )
66 .await?;
67
68 storage_options
72 .entry(AmazonS3ConfigKey::ConditionalPut)
73 .or_insert_with(|| "etag".to_string());
74
75 let use_constant_size_upload_parts = storage_options
77 .get(&AmazonS3ConfigKey::Endpoint)
78 .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com"))
79 .unwrap_or(false);
80
81 base_path.set_scheme("s3").unwrap();
83 base_path.set_query(None);
84
85 let mut builder = AmazonS3Builder::new();
87 for (key, value) in storage_options {
88 builder = builder.with_config(key, value);
89 }
90 builder = builder
91 .with_url(base_path.as_ref())
92 .with_credentials(aws_creds)
93 .with_retry(retry_config)
94 .with_region(region);
95 let inner = Arc::new(builder.build()?);
96
97 Ok(ObjectStore {
98 inner,
99 scheme: String::from(base_path.scheme()),
100 block_size,
101 use_constant_size_upload_parts,
102 list_is_lexically_ordered: true,
103 io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM,
104 download_retry_count,
105 })
106 }
107}
108
109async fn resolve_s3_region(
117 url: &Url,
118 storage_options: &HashMap<AmazonS3ConfigKey, String>,
119) -> Result<Option<String>> {
120 if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) {
121 Ok(Some(region.clone()))
122 } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() {
123 let bucket = url.host_str().ok_or_else(|| {
126 Error::invalid_input(
127 format!("Could not parse bucket from url: {}", url),
128 location!(),
129 )
130 })?;
131
132 let mut client_options = ClientOptions::default();
133 for (key, value) in storage_options {
134 if let AmazonS3ConfigKey::Client(client_key) = key {
135 client_options = client_options.with_config(*client_key, value.clone());
136 }
137 }
138
139 let bucket_region =
140 object_store::aws::resolve_bucket_region(bucket, &client_options).await?;
141 Ok(Some(bucket_region))
142 } else {
143 Ok(None)
144 }
145}
146
147pub async fn build_aws_credential(
157 credentials_refresh_offset: Duration,
158 credentials: Option<AwsCredentialProvider>,
159 storage_options: Option<&HashMap<AmazonS3ConfigKey, String>>,
160 region: Option<String>,
161) -> Result<(AwsCredentialProvider, String)> {
162 use aws_config::meta::region::RegionProviderChain;
164 const DEFAULT_REGION: &str = "us-west-2";
165
166 let region = if let Some(region) = region {
167 region
168 } else {
169 RegionProviderChain::default_provider()
170 .or_else(DEFAULT_REGION)
171 .region()
172 .await
173 .map(|r| r.as_ref().to_string())
174 .unwrap_or(DEFAULT_REGION.to_string())
175 };
176
177 if let Some(creds) = credentials {
178 Ok((creds, region))
179 } else if let Some(creds) = storage_options.and_then(extract_static_s3_credentials) {
180 Ok((Arc::new(creds), region))
181 } else {
182 let credentials_provider = DefaultCredentialsChain::builder().build().await;
183
184 Ok((
185 Arc::new(AwsCredentialAdapter::new(
186 Arc::new(credentials_provider),
187 credentials_refresh_offset,
188 )),
189 region,
190 ))
191 }
192}
193
194fn extract_static_s3_credentials(
195 options: &HashMap<AmazonS3ConfigKey, String>,
196) -> Option<StaticCredentialProvider<ObjectStoreAwsCredential>> {
197 let key_id = options
198 .get(&AmazonS3ConfigKey::AccessKeyId)
199 .map(|s| s.to_string());
200 let secret_key = options
201 .get(&AmazonS3ConfigKey::SecretAccessKey)
202 .map(|s| s.to_string());
203 let token = options
204 .get(&AmazonS3ConfigKey::Token)
205 .map(|s| s.to_string());
206 match (key_id, secret_key, token) {
207 (Some(key_id), Some(secret_key), token) => {
208 Some(StaticCredentialProvider::new(ObjectStoreAwsCredential {
209 key_id,
210 secret_key,
211 token,
212 }))
213 }
214 _ => None,
215 }
216}
217
/// Adapts an AWS SDK credentials provider (`ProvideCredentials`) so it can be
/// used as an `object_store` `CredentialProvider`, caching the most recently
/// fetched credentials.
#[derive(Debug)]
pub struct AwsCredentialAdapter {
    /// The wrapped AWS SDK credentials provider that is queried on a cache miss.
    pub inner: Arc<dyn ProvideCredentials>,

    // Most recently fetched credentials, stored under `AWS_CREDS_CACHE_KEY`.
    cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,

    // How long before their expiry cached credentials are considered stale.
    credentials_refresh_offset: Duration,
}
229
230impl AwsCredentialAdapter {
231 pub fn new(
232 provider: Arc<dyn ProvideCredentials>,
233 credentials_refresh_offset: Duration,
234 ) -> Self {
235 Self {
236 inner: provider,
237 cache: Arc::new(RwLock::new(HashMap::new())),
238 credentials_refresh_offset,
239 }
240 }
241}
242
/// Key under which `AwsCredentialAdapter` stores its credentials in the cache map.
const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";
244
245#[async_trait::async_trait]
246impl CredentialProvider for AwsCredentialAdapter {
247 type Credential = ObjectStoreAwsCredential;
248
249 async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
250 let cached_creds = {
251 let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
252 let expired = cache_value
253 .clone()
254 .map(|cred| {
255 cred.expiry()
256 .map(|exp| {
257 exp.checked_sub(self.credentials_refresh_offset)
258 .expect("this time should always be valid")
259 < SystemTime::now()
260 })
261 .unwrap_or(false)
263 })
264 .unwrap_or(true); if expired {
266 None
267 } else {
268 cache_value.clone()
269 }
270 };
271
272 if let Some(creds) = cached_creds {
273 Ok(Arc::new(Self::Credential {
274 key_id: creds.access_key_id().to_string(),
275 secret_key: creds.secret_access_key().to_string(),
276 token: creds.session_token().map(|s| s.to_string()),
277 }))
278 } else {
279 let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err(
280 |e| Error::Internal {
281 message: format!("Failed to get AWS credentials: {}", e),
282 location: location!(),
283 },
284 )?);
285
286 self.cache
287 .write()
288 .await
289 .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
290
291 Ok(Arc::new(Self::Credential {
292 key_id: refreshed_creds.access_key_id().to_string(),
293 secret_key: refreshed_creds.secret_access_key().to_string(),
294 token: refreshed_creds.session_token().map(|s| s.to_string()),
295 }))
296 }
297 }
298}
299
300impl StorageOptions {
301 pub fn with_env_s3(&mut self) {
303 for (os_key, os_value) in std::env::vars_os() {
304 if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) {
305 if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) {
306 if !self.0.contains_key(config_key.as_ref()) {
307 self.0
308 .insert(config_key.as_ref().to_string(), value.to_string());
309 }
310 }
311 }
312 }
313 }
314
315 pub fn as_s3_options(&self) -> HashMap<AmazonS3ConfigKey, String> {
317 self.0
318 .iter()
319 .filter_map(|(key, value)| {
320 let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?;
321 Some((s3_key, value.clone()))
322 })
323 .collect()
324 }
325}
326
327impl ObjectStoreParams {
328 pub fn with_aws_credentials(
330 aws_credentials: Option<AwsCredentialProvider>,
331 region: Option<String>,
332 ) -> Self {
333 Self {
334 aws_credentials,
335 storage_options: region
336 .map(|region| [("region".into(), region)].iter().cloned().collect()),
337 ..Default::default()
338 }
339 }
340}
341
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use object_store::path::Path;

    use crate::object_store::ObjectStoreRegistry;

    use super::*;

    /// Credential provider that records whether it has been asked for credentials.
    #[derive(Debug, Default)]
    struct MockAwsCredentialsProvider {
        called: AtomicBool,
    }

    #[async_trait::async_trait]
    impl CredentialProvider for MockAwsCredentialsProvider {
        type Credential = ObjectStoreAwsCredential;

        async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
            self.called.store(true, Ordering::Relaxed);
            Ok(Arc::new(Self::Credential {
                key_id: "".to_string(),
                secret_key: "".to_string(),
                token: None,
            }))
        }
    }

    /// A provider injected through `ObjectStoreParams::aws_credentials` must be
    /// the one consulted when the store performs a request.
    #[tokio::test]
    async fn test_injected_aws_creds_option_is_used() {
        let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
        let registry = Arc::new(ObjectStoreRegistry::default());

        let params = ObjectStoreParams {
            aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider),
            ..ObjectStoreParams::default()
        };

        // Nothing has requested credentials yet.
        assert!(!mock_provider.called.load(Ordering::Relaxed));

        let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", &params)
            .await
            .unwrap();

        // The read result is deliberately ignored; only the credential lookup
        // triggered by the request matters here.
        let _ = store
            .open(&Path::parse("/").unwrap())
            .await
            .unwrap()
            .get_range(0..1)
            .await;

        // The injected provider must have been asked for credentials.
        assert!(mock_provider.called.load(Ordering::Relaxed));
    }

    /// Path extraction ignores the bucket and any query string, for both the
    /// `s3` and `s3+ddb` schemes.
    #[test]
    fn test_s3_path_parsing() {
        let provider = AwsStoreProvider;

        let cases = [
            ("s3://bucket/path/to/file", "path/to/file"),
            (
                "s3+ddb://bucket/path/to/file?ddbTableName=test",
                "path/to/file",
            ),
        ];

        for (uri, expected_path) in cases {
            let url = Url::parse(uri).unwrap();
            let path = provider.extract_path(&url);
            let expected_path = Path::from(expected_path);
            assert_eq!(path, expected_path);
        }
    }
}