use std::sync::Arc;
use std::time::Duration;
use aws_smithy_async::future::timeout::Timeout;
use aws_smithy_async::rt::sleep::AsyncSleep;
use tracing::{trace_span, Instrument};
use aws_types::credentials::{future, CredentialsError, ProvideCredentials};
use aws_types::os_shim_internal::TimeSource;
use crate::cache::ExpiringCache;
/// Default upper bound on how long a single credentials load may take (5 seconds).
const DEFAULT_LOAD_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Default lifetime (15 minutes) given to loaded credentials that carry no expiry of their own.
const DEFAULT_CREDENTIAL_EXPIRATION: Duration = Duration::from_secs(900);
/// Default buffer time (10 seconds) handed to the provider's expiring cache.
const DEFAULT_BUFFER_TIME: Duration = Duration::from_millis(10_000);
/// Credentials provider that caches the credentials returned by a wrapped
/// [`ProvideCredentials`] implementation.
///
/// Each request consults the cache first; the wrapped loader is only asked for credentials
/// when the cache yields nothing for the current time. Construct one via
/// [`LazyCachingCredentialsProvider::builder`].
#[derive(Debug)]
pub struct LazyCachingCredentialsProvider {
    /// Clock used to read "now" for the cache check and for the fallback expiry.
    time: TimeSource,
    /// Sleep implementation used to build the load-timeout future.
    sleeper: Arc<dyn AsyncSleep>,
    /// Expiring cache holding loaded credentials (error type `CredentialsError`).
    cache: ExpiringCache<Credentials, CredentialsError>,
    /// The wrapped provider that performs the actual load.
    loader: Arc<dyn ProvideCredentials>,
    /// Maximum time to wait for `loader`; exceeding it yields
    /// `CredentialsError::provider_timed_out`.
    load_timeout: Duration,
    /// Expiry, relative to the request time, applied to loaded credentials that have none.
    default_credential_expiration: Duration,
}
impl LazyCachingCredentialsProvider {
    /// Assembles a provider from fully-resolved settings.
    ///
    /// `buffer_time` configures the `ExpiringCache`; every other argument is stored as-is.
    /// This constructor is private; code outside this module goes through
    /// [`LazyCachingCredentialsProvider::builder`].
    fn new(
        time: TimeSource,
        sleeper: Arc<dyn AsyncSleep>,
        loader: Arc<dyn ProvideCredentials>,
        load_timeout: Duration,
        default_credential_expiration: Duration,
        buffer_time: Duration,
    ) -> Self {
        let cache = ExpiringCache::new(buffer_time);
        Self {
            time,
            sleeper,
            cache,
            loader,
            load_timeout,
            default_credential_expiration,
        }
    }

    /// Returns an empty [`Builder`](builder::Builder) for configuring a new provider.
    pub fn builder() -> builder::Builder {
        builder::Builder::default()
    }
}
impl ProvideCredentials for LazyCachingCredentialsProvider {
    /// Serves credentials from the cache when it yields an entry for the current time;
    /// otherwise loads them from the wrapped provider through the cache's `get_or_load`.
    ///
    /// * "Now" is read once, when this method is called. It is used both for the cache
    ///   check and for the fallback expiry of credentials that carry none.
    /// * The load is bounded by `load_timeout`; hitting the bound yields
    ///   `CredentialsError::provider_timed_out(load_timeout)`.
    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials
    where
        Self: 'a,
    {
        // Snapshot everything the returned future needs so it owns its state.
        let now = self.time.now();
        let loader = self.loader.clone();
        // The timeout future is created here, at call time, not inside the async block.
        let timeout_future = self.sleeper.sleep(self.load_timeout);
        let load_timeout = self.load_timeout;
        let cache = self.cache.clone();
        let default_credential_expiration = self.default_credential_expiration;

        future::ProvideCredentials::new(async move {
            // Fast path: the cache still has something usable.
            if let Some(cached) = cache.yield_or_clear_if_expired(now).await {
                return Ok(cached);
            }

            let span = trace_span!("lazy_load_credentials");
            let bounded_load = Timeout::new(loader.provide_credentials(), timeout_future);
            cache
                .get_or_load(|| {
                    async move {
                        // First `?`: the timeout fired; second `?`: the loader itself failed.
                        let fresh = bounded_load.await.map_err(|_elapsed| {
                            CredentialsError::provider_timed_out(load_timeout)
                        })??;
                        let expires_at = fresh
                            .expiry()
                            .unwrap_or(now + default_credential_expiration);
                        Ok((fresh, expires_at))
                    }
                    // Only the load future is instrumented, so the span is entered only if
                    // the cache actually runs this load.
                    .instrument(span)
                })
                .await
        })
    }
}
use aws_types::Credentials;
pub use builder::Builder;
mod builder {
    use std::sync::Arc;
    use std::time::Duration;

    use aws_smithy_async::rt::sleep::{default_async_sleep, AsyncSleep};
    use aws_types::credentials::ProvideCredentials;

    use super::{
        LazyCachingCredentialsProvider, DEFAULT_BUFFER_TIME, DEFAULT_CREDENTIAL_EXPIRATION,
        DEFAULT_LOAD_TIMEOUT,
    };
    use crate::provider_config::ProviderConfig;
    use aws_types::os_shim_internal::TimeSource;

    /// Builder for [`LazyCachingCredentialsProvider`].
    ///
    /// Only the loader ([`Builder::load`]) is required. Every other setting falls back to a
    /// default when [`Builder::build`] is called.
    #[derive(Default)]
    pub struct Builder {
        sleep: Option<Arc<dyn AsyncSleep>>,
        time_source: Option<TimeSource>,
        load: Option<Arc<dyn ProvideCredentials>>,
        load_timeout: Option<Duration>,
        buffer_time: Option<Duration>,
        default_credential_expiration: Option<Duration>,
    }

    impl Builder {
        /// Creates a builder with nothing configured.
        pub fn new() -> Self {
            Default::default()
        }

        /// Takes the sleep implementation and time source from `config`.
        ///
        /// This overwrites any sleep implementation set earlier via [`Builder::sleep`] with
        /// whatever `config.sleep()` returns.
        pub fn configure(mut self, config: &ProviderConfig) -> Self {
            self.sleep = config.sleep();
            self.time_source = Some(config.time_source());
            self
        }

        /// Sets the provider that performs the actual credentials load (required).
        pub fn load(mut self, loader: impl ProvideCredentials + 'static) -> Self {
            self.load = Some(Arc::new(loader));
            self
        }

        /// Sets the sleep implementation used to enforce the load timeout.
        ///
        /// If unset, [`Builder::build`] falls back to `default_async_sleep()`.
        pub fn sleep(mut self, sleep: impl AsyncSleep + 'static) -> Self {
            self.sleep = Some(Arc::new(sleep));
            self
        }

        /// Sets the maximum time to wait for the loader (default: 5 seconds).
        pub fn load_timeout(mut self, timeout: Duration) -> Self {
            self.load_timeout = Some(timeout);
            self
        }

        /// Sets the buffer time handed to the provider's expiring cache (default: 10 seconds).
        pub fn buffer_time(mut self, buffer_time: Duration) -> Self {
            self.buffer_time = Some(buffer_time);
            self
        }

        /// Sets the expiry, relative to the request time, applied to loaded credentials that
        /// carry none (default: 15 minutes).
        pub fn default_credential_expiration(mut self, duration: Duration) -> Self {
            self.default_credential_expiration = Some(duration);
            self
        }

        /// Builds the provider, filling unset options with their defaults.
        ///
        /// # Panics
        ///
        /// - if `default_credential_expiration` is shorter than 15 minutes;
        /// - if no sleep implementation was set and `default_async_sleep()` returns `None`;
        /// - if no loader was set via [`Builder::load`].
        pub fn build(self) -> LazyCachingCredentialsProvider {
            let default_credential_expiration = self
                .default_credential_expiration
                .unwrap_or(DEFAULT_CREDENTIAL_EXPIRATION);
            assert!(
                default_credential_expiration >= DEFAULT_CREDENTIAL_EXPIRATION,
                "default_credential_expiration must be at least 15 minutes"
            );
            let time_source = self.time_source.unwrap_or_default();
            let sleeper = self.sleep.unwrap_or_else(|| {
                default_async_sleep().expect("no default sleep implementation available")
            });
            let loader = self.load.expect("load implementation is required");
            let load_timeout = self.load_timeout.unwrap_or(DEFAULT_LOAD_TIMEOUT);
            let buffer_time = self.buffer_time.unwrap_or(DEFAULT_BUFFER_TIME);
            // `new` takes three adjacent `Duration`s, so the compiler cannot catch a swap
            // between them. Each local is named after the parameter it feeds, in `new`'s
            // order: load timeout, default expiration, buffer time.
            LazyCachingCredentialsProvider::new(
                time_source,
                sleeper,
                loader,
                load_timeout,
                default_credential_expiration,
                buffer_time,
            )
        }
    }

    #[cfg(test)]
    mod tests {
        use std::time::{Duration, SystemTime};

        use aws_smithy_async::rt::sleep::TokioSleep;
        use aws_types::Credentials;

        use super::super::{DEFAULT_CREDENTIAL_EXPIRATION, DEFAULT_LOAD_TIMEOUT};
        use super::Builder;
        use crate::meta::credentials::credential_fn::provide_credentials_fn;

        // Builder with the two settings that `build` cannot default in every environment.
        fn builder_with_loader() -> Builder {
            Builder::new()
                .sleep(TokioSleep::new())
                .load(provide_credentials_fn(|| async {
                    Ok(Credentials::new(
                        "test",
                        "test",
                        None,
                        Some(SystemTime::UNIX_EPOCH),
                        "test",
                    ))
                }))
        }

        // Regression: the defaults must land in the matching fields (buffer time and
        // default expiration were once passed to `new` in swapped positions).
        #[test]
        fn build_applies_defaults_to_matching_fields() {
            let provider = builder_with_loader().build();
            assert_eq!(DEFAULT_LOAD_TIMEOUT, provider.load_timeout);
            assert_eq!(
                DEFAULT_CREDENTIAL_EXPIRATION,
                provider.default_credential_expiration
            );
        }

        #[test]
        fn build_forwards_explicit_durations_to_matching_fields() {
            let provider = builder_with_loader()
                .load_timeout(Duration::from_secs(2))
                .buffer_time(Duration::from_secs(30))
                .default_credential_expiration(Duration::from_secs(60 * 60))
                .build();
            assert_eq!(Duration::from_secs(2), provider.load_timeout);
            assert_eq!(
                Duration::from_secs(60 * 60),
                provider.default_credential_expiration
            );
        }
    }
}
// Unit tests for `LazyCachingCredentialsProvider`, driven by a manual time source.
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};
    use std::time::{Duration, SystemTime, UNIX_EPOCH};
    use aws_smithy_async::rt::sleep::TokioSleep;
    use aws_types::credentials::{self, CredentialsError, ProvideCredentials};
    use aws_types::Credentials;
    use tracing::info;
    use tracing_test::traced_test;
    use crate::meta::credentials::credential_fn::provide_credentials_fn;
    use super::{
        LazyCachingCredentialsProvider, TimeSource, DEFAULT_BUFFER_TIME,
        DEFAULT_CREDENTIAL_EXPIRATION, DEFAULT_LOAD_TIMEOUT,
    };
    use aws_types::os_shim_internal::ManualTimeSource;

    // Builds a provider with the default load timeout, default expiration and default buffer
    // time. Its loader hands out the entries of `load_list` in order, one per load.
    // `Vec::remove(0)` panics once the list is exhausted, so triggering more loads than were
    // listed makes the test fail loudly.
    fn test_provider(
        time: TimeSource,
        load_list: Vec<credentials::Result>,
    ) -> LazyCachingCredentialsProvider {
        let load_list = Arc::new(Mutex::new(load_list));
        LazyCachingCredentialsProvider::new(
            time,
            Arc::new(TokioSleep::new()),
            Arc::new(provide_credentials_fn(move || {
                let list = load_list.clone();
                async move {
                    let next = list.lock().unwrap().remove(0);
                    info!("refreshing the credentials to {:?}", next);
                    next
                }
            })),
            DEFAULT_LOAD_TIMEOUT,
            DEFAULT_CREDENTIAL_EXPIRATION,
            DEFAULT_BUFFER_TIME,
        )
    }

    // `SystemTime` that is `secs` seconds after the Unix epoch.
    fn epoch_secs(secs: u64) -> SystemTime {
        SystemTime::UNIX_EPOCH + Duration::from_secs(secs)
    }

    // Dummy credentials expiring `expired_secs` seconds after the Unix epoch. The expiry
    // doubles as a tag identifying which loaded entry a request received.
    fn credentials(expired_secs: u64) -> Credentials {
        Credentials::new("test", "test", None, Some(epoch_secs(expired_secs)), "test")
    }

    // Requests credentials from `provider` and asserts they expire at `expired_secs`.
    async fn expect_creds(expired_secs: u64, provider: &LazyCachingCredentialsProvider) {
        let creds = provider
            .provide_credentials()
            .await
            .expect("expected credentials");
        assert_eq!(Some(epoch_secs(expired_secs)), creds.expiry());
    }

    // The first request on an empty cache goes to the loader and returns its credentials.
    #[traced_test]
    #[tokio::test]
    async fn initial_populate_credentials() {
        let time = ManualTimeSource::new(UNIX_EPOCH);
        let loader = Arc::new(provide_credentials_fn(|| async {
            info!("refreshing the credentials");
            Ok(credentials(1000))
        }));
        let provider = LazyCachingCredentialsProvider::new(
            TimeSource::manual(&time),
            Arc::new(TokioSleep::new()),
            loader,
            DEFAULT_LOAD_TIMEOUT,
            DEFAULT_CREDENTIAL_EXPIRATION,
            DEFAULT_BUFFER_TIME,
        );
        assert_eq!(
            epoch_secs(1000),
            provider
                .provide_credentials()
                .await
                .unwrap()
                .expiry()
                .unwrap()
        );
    }

    // Cached credentials are served repeatedly until the clock passes their expiry. The
    // next request then loads the next entry (1000 -> 2000 -> 3000). This relies on
    // `TimeSource::manual(&time)` observing later `set_time` calls on `time`.
    #[traced_test]
    #[tokio::test]
    async fn reload_expired_credentials() {
        let mut time = ManualTimeSource::new(epoch_secs(100));
        let provider = test_provider(
            TimeSource::manual(&time),
            vec![
                Ok(credentials(1000)),
                Ok(credentials(2000)),
                Ok(credentials(3000)),
            ],
        );
        expect_creds(1000, &provider).await;
        expect_creds(1000, &provider).await;
        time.set_time(epoch_secs(1500));
        expect_creds(2000, &provider).await;
        expect_creds(2000, &provider).await;
        time.set_time(epoch_secs(2500));
        expect_creds(3000, &provider).await;
        expect_creds(3000, &provider).await;
    }

    // A failed reload after the cached credentials expire surfaces as an error.
    #[traced_test]
    #[tokio::test]
    async fn load_failed_error() {
        let mut time = ManualTimeSource::new(epoch_secs(100));
        let provider = test_provider(
            TimeSource::manual(&time),
            vec![
                Ok(credentials(1000)),
                Err(CredentialsError::not_loaded("failed")),
            ],
        );
        expect_creds(1000, &provider).await;
        time.set_time(epoch_secs(1500));
        assert!(provider.provide_credentials().await.is_err());
    }

    // Stress test: 4 sequential rounds of 50 concurrently spawned tasks on a 16-worker
    // runtime. Each task moves the shared manual clock to its own `now`. It then asserts
    // that the credentials it receives do not expire before that `now`.
    // NOTE(review): only 5 loader entries exist (see `test_provider`), so this presumably
    // also checks that concurrent misses share one load rather than each loading separately.
    #[traced_test]
    #[test]
    fn load_contention() {
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_time()
            .worker_threads(16)
            .build()
            .unwrap();
        let time = ManualTimeSource::new(epoch_secs(0));
        let provider = Arc::new(test_provider(
            TimeSource::manual(&time),
            vec![
                Ok(credentials(500)),
                Ok(credentials(1500)),
                Ok(credentials(2500)),
                Ok(credentials(3500)),
                Ok(credentials(4500)),
            ],
        ));
        let locked_time = Arc::new(Mutex::new(time));
        for i in 0..4 {
            let mut tasks = Vec::new();
            for j in 0..50 {
                let provider = provider.clone();
                let time = locked_time.clone();
                tasks.push(rt.spawn(async move {
                    let now = epoch_secs(i * 1000 + (4 * j));
                    time.lock().unwrap().set_time(now);
                    let creds = provider.provide_credentials().await.unwrap();
                    assert!(
                        creds.expiry().unwrap() >= now,
                        "{:?} >= {:?}",
                        creds.expiry(),
                        now
                    );
                }));
            }
            // Finish every task of this round before starting the next one.
            for task in tasks {
                rt.block_on(task).unwrap();
            }
        }
    }

    // The loader sleeps 10 ms but the load timeout is 5 ms. The request must therefore fail
    // with `CredentialsError::ProviderTimedOut`.
    #[tokio::test]
    #[traced_test]
    async fn load_timeout() {
        let time = ManualTimeSource::new(epoch_secs(100));
        let provider = LazyCachingCredentialsProvider::new(
            TimeSource::manual(&time),
            Arc::new(TokioSleep::new()),
            Arc::new(provide_credentials_fn(|| async {
                tokio::time::sleep(Duration::from_millis(10)).await;
                Ok(credentials(1000))
            })),
            Duration::from_millis(5),
            DEFAULT_CREDENTIAL_EXPIRATION,
            DEFAULT_BUFFER_TIME,
        );
        assert!(matches!(
            provider.provide_credentials().await,
            Err(CredentialsError::ProviderTimedOut { .. })
        ));
    }
}