use std::fmt::{Debug, Formatter, Result as FmtResult};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
use aws_credential_types::Credentials;
use object_store::aws::AwsCredential;
use object_store::client::CredentialProvider;
use object_store::Error as ObjectStoreError;
use tokio::sync::RwLock;
/// Early-refresh margin for expiring credentials.
///
/// A cached credential is treated as stale once its expiry is no longer more than this far in
/// the future, so it is replaced before it actually expires.
///
/// Written as `from_secs(60)` rather than `Duration::from_mins(1)`. `from_mins` only became
/// stable in a recent Rust release, so using it would needlessly raise the minimum supported
/// toolchain.
const BUFFER: Duration = Duration::from_secs(60);
/// Adapts an AWS SDK `SharedCredentialsProvider` to `object_store`'s `CredentialProvider`
/// trait.
///
/// The most recently fetched credentials are kept in memory. The provider is only asked
/// again once the cached value is missing or close to expiry.
pub struct AwsCredentialAdapter {
// Source of fresh credentials; called only when the cache holds nothing usable.
provider: SharedCredentialsProvider,
// Last credentials obtained from `provider`; `None` until the first successful fetch.
cache: RwLock<Option<Credentials>>,
}
impl AwsCredentialAdapter {
    /// Wraps `provider` in an adapter whose credential cache starts out empty.
    pub fn new(provider: SharedCredentialsProvider) -> Self {
        let cache = RwLock::new(None);
        Self { provider, cache }
    }

    /// Returns cached credentials while they are still usable, otherwise fetches new ones.
    ///
    /// The fast path takes a shared read lock and clones out a still-valid cached value.
    ///
    /// The slow path works as follows:
    /// - It takes the exclusive write lock and re-checks the cache, because another task may
    ///   have refreshed it while this one was waiting.
    /// - Only then does it call the provider.
    /// - The write lock stays held across that call, so concurrent callers that found the
    ///   cache stale wait for the one refresh instead of each calling the provider.
    ///
    /// # Errors
    /// If a refresh is needed and the provider fails, its error is returned boxed and the
    /// cache is left untouched.
    async fn get(&self) -> Result<Credentials, Box<dyn std::error::Error + Send + Sync>> {
        // The read guard is a temporary of this statement, so it is released before the
        // write lock is requested below.
        let from_reader = Self::clone_if_valid(self.cache.read().await.as_ref());
        if let Some(fresh) = from_reader {
            return Ok(fresh);
        }

        let mut slot = self.cache.write().await;
        if let Some(fresh) = Self::clone_if_valid(slot.as_ref()) {
            return Ok(fresh);
        }

        let refreshed = self.provider.provide_credentials().await?;
        *slot = Some(refreshed.clone());
        Ok(refreshed)
    }

    /// Clones `cached` out if it is present and still passes [`Self::is_valid`].
    fn clone_if_valid(cached: Option<&Credentials>) -> Option<Credentials> {
        cached.filter(|creds| Self::is_valid(creds)).cloned()
    }

    /// Reports whether `creds` may still be handed out.
    ///
    /// Credentials without an expiry never go stale. Expiring ones are accepted only while
    /// their expiry lies more than `BUFFER` beyond the current time.
    fn is_valid(creds: &Credentials) -> bool {
        match creds.expiry() {
            Some(expires_at) => expires_at > SystemTime::now() + BUFFER,
            None => true,
        }
    }
}
impl Debug for AwsCredentialAdapter {
    /// Prints only the type name; neither the wrapped provider nor the cached credentials
    /// appear in the output.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        // A field-less `debug_struct(..).finish()` emits exactly the name, in both `{:?}` and
        // `{:#?}` forms, so writing the name directly is equivalent.
        f.write_str("AwsCredentialAdapter")
    }
}
#[async_trait::async_trait]
impl CredentialProvider for AwsCredentialAdapter {
    type Credential = AwsCredential;

    /// Resolves credentials through the adapter's cache and converts them into
    /// `object_store`'s `AwsCredential`.
    ///
    /// # Errors
    /// A failure while obtaining credentials is surfaced as `object_store::Error::Generic`
    /// with `store` set to `"S3"` and the original error as its `source`.
    async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
        let resolved = self
            .get()
            .await
            .map_err(|source| ObjectStoreError::Generic { store: "S3", source })?;

        let credential = AwsCredential {
            key_id: String::from(resolved.access_key_id()),
            secret_key: String::from(resolved.secret_access_key()),
            token: resolved.session_token().map(String::from),
        };
        Ok(Arc::new(credential))
    }
}