use aws_smithy_async::future::BoxFuture;
use aws_smithy_async::rt::sleep::{AsyncSleep, SharedAsyncSleep};
use aws_smithy_async::time::SharedTimeSource;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::endpoint::{EndpointFuture, EndpointResolverParams, ResolveEndpoint};
use aws_smithy_types::endpoint::Endpoint;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime};
use tokio::sync::oneshot::error::TryRecvError;
use tokio::sync::oneshot::{Receiver, Sender};
/// Refresher for the endpoint held by an `EndpointCache`.
///
/// Produced by `create_cache` after the initial load. Nothing refreshes the cache after that
/// unless this value is used, normally by driving `reload_task` (e.g. spawning it onto a
/// runtime). That is why the type is `#[must_use]`.
#[must_use]
pub struct ReloadEndpoint {
    /// Async loader yielding a fresh endpoint together with its expiry time.
    loader: Box<dyn Fn() -> BoxFuture<'static, (Endpoint, SystemTime), BoxError> + Send + Sync>,
    /// Cached endpoint slot, shared with `EndpointCache`.
    endpoint: Arc<Mutex<Option<ExpiringEndpoint>>>,
    /// Slot for the most recent load error, shared with `EndpointCache`.
    error: Arc<Mutex<Option<BoxError>>>,
    /// Clock consulted by `reload_task` when deciding whether the cached endpoint has expired.
    time: SharedTimeSource,
    /// Sleeper used by `reload_task` to wait between refresh checks.
    sleep: SharedAsyncSleep,
    /// Receiver half of the shutdown channel. It closes once every `EndpointCache` clone
    /// (which jointly own the sender) has been dropped, and `reload_task` then exits.
    rx: Receiver<()>,
}
impl Debug for ReloadEndpoint {
    /// Prints only the type name. The boxed loader closure does not implement `Debug`, so no
    /// fields are rendered.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Same output as an empty `debug_struct("ReloadEndpoint").finish()`: just the name.
        f.write_str("ReloadEndpoint")
    }
}
impl ReloadEndpoint {
pub async fn reload_once(&self) {
match (self.loader)().await {
Ok((endpoint, expiry)) => {
tracing::debug!("caching resolved endpoint: {:?}", (&endpoint, &expiry));
*self.endpoint.lock().unwrap() = Some(ExpiringEndpoint { endpoint, expiry })
}
Err(err) => *self.error.lock().unwrap() = Some(err),
}
}
pub async fn reload_task(mut self) {
loop {
match self.rx.try_recv() {
Ok(_) | Err(TryRecvError::Closed) => break,
_ => {}
}
self.reload_increment(self.time.now()).await;
self.sleep.sleep(Duration::from_secs(60)).await;
}
}
async fn reload_increment(&self, now: SystemTime) {
let should_reload = self.endpoint.lock().unwrap().as_ref().map(|e| e.is_expired(now)).unwrap_or(true);
if should_reload {
tracing::debug!("reloading endpoint, previous endpoint was expired");
self.reload_once().await;
}
}
}
/// Endpoint discovery cache: serves the most recently loaded endpoint from memory.
///
/// Created by `create_cache` together with a `ReloadEndpoint`, which shares the `error` and
/// `endpoint` slots with this cache and refreshes them. Cloning shares the same underlying
/// state (all fields are `Arc`s).
#[derive(Debug, Clone)]
pub(crate) struct EndpointCache {
    /// Error from a failed load. It is taken (and thereby cleared) by `resolve_endpoint` when no
    /// endpoint is cached.
    error: Arc<Mutex<Option<BoxError>>>,
    /// Most recently loaded endpoint and its expiry; `None` until a load succeeds.
    endpoint: Arc<Mutex<Option<ExpiringEndpoint>>>,
    /// Sender half of the reloader's shutdown channel. In this file it is only held, never sent
    /// on. Once the last clone of the cache is dropped the channel closes and
    /// `ReloadEndpoint::reload_task` stops.
    _drop_guard: Arc<Sender<()>>,
}
impl ResolveEndpoint for EndpointCache {
fn resolve_endpoint<'a>(&'a self, _params: &'a EndpointResolverParams) -> EndpointFuture<'a> {
self.resolve_endpoint()
}
}
/// A discovered endpoint paired with the expiry time the loader reported for it.
#[derive(Debug)]
struct ExpiringEndpoint {
    /// The endpoint handed out by `EndpointCache`.
    endpoint: Endpoint,
    /// Expiry time returned by the loader alongside `endpoint`.
    expiry: SystemTime,
}
impl ExpiringEndpoint {
    /// Remaining lifetime below which an endpoint counts as expired. The endpoint is therefore
    /// refreshed somewhat before its reported expiry rather than after it.
    const REFRESH_BUFFER: Duration = Duration::from_secs(120);

    /// Returns `true` if, as of `now`, the endpoint is already past its expiry or has less than
    /// [`Self::REFRESH_BUFFER`] of lifetime remaining.
    fn is_expired(&self, now: SystemTime) -> bool {
        // Compute the remaining lifetime once and use it for both logging and the decision.
        // `Err` means `expiry` is earlier than `now`.
        let remaining = self.expiry.duration_since(now);
        tracing::debug!(expiry = ?self.expiry, now = ?now, delta = ?remaining, "checking expiry status of endpoint");
        match remaining {
            Err(_) => true,
            Ok(t) => t < Self::REFRESH_BUFFER,
        }
    }
}
pub(crate) async fn create_cache<F>(
loader_fn: impl Fn() -> F + Send + Sync + 'static,
sleep: SharedAsyncSleep,
time: SharedTimeSource,
) -> Result<(EndpointCache, ReloadEndpoint), BoxError>
where
F: Future<Output = Result<(Endpoint, SystemTime), BoxError>> + Send + 'static,
{
let error_holder = Arc::new(Mutex::new(None));
let endpoint_holder = Arc::new(Mutex::new(None));
let (tx, rx) = tokio::sync::oneshot::channel();
let cache = EndpointCache {
error: error_holder.clone(),
endpoint: endpoint_holder.clone(),
_drop_guard: Arc::new(tx),
};
let reloader = ReloadEndpoint {
loader: Box::new(move || Box::pin((loader_fn)()) as _),
endpoint: endpoint_holder,
error: error_holder,
rx,
sleep,
time,
};
tracing::debug!("populating initial endpoint discovery cache");
reloader.reload_once().await;
cache.resolve_endpoint().await?;
Ok((cache, reloader))
}
impl EndpointCache {
fn resolve_endpoint(&self) -> EndpointFuture<'_> {
tracing::trace!("resolving endpoint from endpoint discovery cache");
let ep = self.endpoint.lock().unwrap().as_ref().map(|e| e.endpoint.clone()).ok_or_else(|| {
let error: Option<BoxError> = self.error.lock().unwrap().take();
error.unwrap_or_else(|| "Failed to resolve endpoint".into())
});
EndpointFuture::ready(ep)
}
}
#[cfg(test)]
mod test {
    use crate::endpoint_discovery::create_cache;
    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
    use aws_smithy_async::test_util::controlled_time_and_sleep;
    use aws_smithy_async::time::{SharedTimeSource, SystemTimeSource, TimeSource};
    use aws_smithy_types::endpoint::Endpoint;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, UNIX_EPOCH};
    use tokio::time::timeout;

    /// Identity function whose `T: Send` bound turns "is this value `Send`?" into a
    /// compile-time check.
    fn check_send_v<T: Send>(t: T) -> T {
        t
    }

    /// Compile-time check that the reload task future and the cache are `Send`. This is
    /// required, for example, to pass the task to `tokio::spawn` as `test_advance_of_task` does.
    #[tokio::test]
    // The reload task future returned by `check_send_v` is discarded without being awaited.
    #[allow(unused_must_use)]
    async fn check_traits() {
        let (cache, reloader) = create_cache(
            || async { Ok((Endpoint::builder().url("http://foo.com").build(), SystemTimeSource::new().now())) },
            SharedAsyncSleep::new(TokioSleep::new()),
            SharedTimeSource::new(SystemTimeSource::new()),
        )
        .await
        .unwrap();
        check_send_v(reloader.reload_task());
        check_send_v(cache);
    }

    /// Checks that `reload_increment` only reloads once the cached endpoint is close enough to
    /// its expiry.
    ///
    /// NOTE(review): despite its name, the loader in this test never returns an error; it
    /// exercises expiry-driven reloading.
    #[tokio::test]
    async fn erroring_endpoint_always_reloaded() {
        let expiry = UNIX_EPOCH + Duration::from_secs(123456789);
        // Counts loader invocations; the count is embedded in the URL so each reload is visible.
        let ct = Arc::new(AtomicUsize::new(0));
        let (cache, reloader) = create_cache(
            move || {
                let shared_ct = ct.clone();
                shared_ct.fetch_add(1, Ordering::AcqRel);
                async move { Ok((Endpoint::builder().url(format!("http://foo.com/{shared_ct:?}")).build(), expiry)) }
            },
            SharedAsyncSleep::new(TokioSleep::new()),
            SharedTimeSource::new(SystemTimeSource::new()),
        )
        .await
        .expect("returns an endpoint");
        // The initial load performed inside `create_cache`.
        assert_eq!(cache.resolve_endpoint().await.expect("ok").url(), "http://foo.com/1");
        // 240s before expiry is outside the 120s window used by `is_expired`, so no reload.
        reloader.reload_increment(expiry - Duration::from_secs(240)).await;
        assert_eq!(cache.resolve_endpoint().await.expect("ok").url(), "http://foo.com/1");
        // At the expiry instant the endpoint counts as expired, so the loader runs again.
        reloader.reload_increment(expiry).await;
        assert_eq!(cache.resolve_endpoint().await.expect("ok").url(), "http://foo.com/2");
    }

    /// Drives `reload_task` with controlled time and sleep. It checks that the task reloads only
    /// after the endpoint enters the refresh window, and that it exits once the cache is dropped.
    #[tokio::test]
    async fn test_advance_of_task() {
        let expiry = UNIX_EPOCH + Duration::from_secs(123456789);
        // Start 239s before expiry. Presumably each released 60s sleep advances the controlled
        // clock, so fewer than 120s remain after two sleeps (239 -> 179 -> 119).
        let (time, sleep, mut gate) = controlled_time_and_sleep(expiry - Duration::from_secs(239));
        let ct = Arc::new(AtomicUsize::new(0));
        let (cache, reloader) = create_cache(
            move || {
                let shared_ct = ct.clone();
                shared_ct.fetch_add(1, Ordering::AcqRel);
                async move { Ok((Endpoint::builder().url(format!("http://foo.com/{shared_ct:?}")).build(), expiry)) }
            },
            SharedAsyncSleep::new(sleep.clone()),
            SharedTimeSource::new(time.clone()),
        )
        .await
        .expect("first load success");
        let reload_task = tokio::spawn(reloader.reload_task());
        assert!(!reload_task.is_finished());
        // First iteration (239s left): no reload, then a 60s sleep.
        // NOTE(review): the captured sleep is dropped immediately here. Presumably that releases
        // it just like `allow_progress` below; confirm against `controlled_time_and_sleep`.
        assert_eq!(gate.expect_sleep().await.duration(), Duration::from_secs(60));
        assert_eq!(cache.resolve_endpoint().await.unwrap().url(), "http://foo.com/1");
        // Second iteration (~179s left): still outside the window, so no reload.
        let sleep = gate.expect_sleep().await;
        assert_eq!(cache.resolve_endpoint().await.unwrap().url(), "http://foo.com/1");
        assert_eq!(sleep.duration(), Duration::from_secs(60));
        sleep.allow_progress();
        // Third iteration (~119s left): inside the window, so the loader runs a second time.
        let sleep = gate.expect_sleep().await;
        assert_eq!(cache.resolve_endpoint().await.unwrap().url(), "http://foo.com/2");
        sleep.allow_progress();
        let sleep = gate.expect_sleep().await;
        // Dropping the only cache drops its `Sender`, closing the channel. The task sees this at
        // the top of its next iteration and exits.
        drop(cache);
        sleep.allow_progress();
        timeout(Duration::from_secs(1), reload_task)
            .await
            .expect("task finishes successfully")
            .expect("finishes");
    }
}