use async_trait::async_trait;
use aws_credential_types::provider::error::CredentialsError;
use aws_credential_types::{provider::ProvideCredentials, Credentials};
use std::{fmt::Debug, sync::Arc};
/// A set of credentials produced by a [`CredentialsUpdateCallback`].
///
/// The fields are passed straight to `Credentials::from_keys` by
/// `CredentialsProvider`.
///
/// `Debug` is implemented by hand rather than derived so that the secret key and
/// session token are never written to logs. Only the access key is printed; the
/// other two fields show a redaction marker.
#[derive(Clone)]
pub struct CredentialsUpdateOutput {
    /// Access key ID. This is shown in `Debug` output.
    pub access_key: String,
    /// Secret access key. This is redacted in `Debug` output.
    pub secret_key: String,
    /// Optional session token. It is redacted in `Debug` output; `None` is still
    /// shown as `None`.
    pub session_token: Option<String>,
}

impl Debug for CredentialsUpdateOutput {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the presence/absence of the token visible for diagnostics, but never its value.
        const REDACTED: &str = "** redacted **";
        f.debug_struct("CredentialsUpdateOutput")
            .field("access_key", &self.access_key)
            .field("secret_key", &REDACTED)
            .field("session_token", &self.session_token.as_ref().map(|_| REDACTED))
            .finish()
    }
}
/// User-supplied source of credentials that `CredentialsProvider` wraps.
///
/// The `Debug + Send + Sync` supertraits let an implementation be stored behind an
/// `Arc<dyn CredentialsUpdateCallback>` and shared across threads.
#[async_trait]
pub trait CredentialsUpdateCallback: Debug + Send + Sync {
    /// Produce a set of credentials.
    ///
    /// `CredentialsProvider` calls this on every `provide_credentials` request and
    /// keeps no cached copy of its own.
    ///
    /// # Errors
    ///
    /// Any error returned here is handed to the SDK wrapped in
    /// `CredentialsError::provider_error`.
    async fn update_credentials(
        &self,
    ) -> Result<CredentialsUpdateOutput, Box<dyn std::error::Error + Send + Sync + 'static>>;
}
/// Adapter that exposes a [`CredentialsUpdateCallback`] through the AWS SDK's
/// `ProvideCredentials` trait.
///
/// The callback lives behind an `Arc`, so cloning the provider is cheap and every
/// clone shares the same callback instance.
#[derive(Clone, Debug)]
pub struct CredentialsProvider(Arc<dyn CredentialsUpdateCallback>);
impl PartialEq for CredentialsProvider {
    /// Comparing providers is treated as a programming error.
    ///
    /// This impl exists only to satisfy a `PartialEq` bound. No comparison is ever
    /// expected to be performed.
    ///
    /// # Panics
    ///
    /// Always panics when called.
    fn eq(&self, _other: &Self) -> bool {
        // The message previously named `CustomCredentialsProvider`, a type that does not
        // exist here. Name the real type so the panic points at the right place.
        unreachable!("BUG: You should never compare `CredentialsProvider`");
    }
}
impl CredentialsProvider {
    /// Wrap `inner` so that it can be used as an AWS SDK credentials provider.
    ///
    /// `inner` is moved into a freshly allocated `Arc` and type-erased to
    /// `dyn CredentialsUpdateCallback`.
    pub fn new(inner: impl CredentialsUpdateCallback + 'static) -> Self {
        let shared: Arc<dyn CredentialsUpdateCallback> = Arc::new(inner);
        Self(shared)
    }
}
impl ProvideCredentials for CredentialsProvider {
fn provide_credentials<'a>(
&'a self,
) -> aws_credential_types::provider::future::ProvideCredentials<'a>
where
Self: 'a,
{
aws_credential_types::provider::future::ProvideCredentials::new(async {
self.0
.update_credentials()
.await
.map(|output| {
Credentials::from_keys(
output.access_key,
output.secret_key,
output.session_token,
)
})
.map_err(CredentialsError::provider_error)
})
}
}