1use aws_smithy_async::future::BoxFuture;
10use aws_smithy_async::rt::sleep::{AsyncSleep, SharedAsyncSleep};
11use aws_smithy_async::time::SharedTimeSource;
12use aws_smithy_runtime_api::box_error::BoxError;
13use aws_smithy_runtime_api::client::endpoint::{EndpointFuture, EndpointResolverParams, ResolveEndpoint};
14use aws_smithy_types::endpoint::Endpoint;
15use std::fmt::{Debug, Formatter};
16use std::future::Future;
17use std::sync::{Arc, Mutex};
18use std::time::{Duration, SystemTime};
19use tokio::sync::oneshot::error::TryRecvError;
20use tokio::sync::oneshot::{Receiver, Sender};
21
22#[must_use]
24pub struct ReloadEndpoint {
25 loader: Box<dyn Fn() -> BoxFuture<'static, (Endpoint, SystemTime), BoxError> + Send + Sync>,
26 endpoint: Arc<Mutex<Option<ExpiringEndpoint>>>,
27 error: Arc<Mutex<Option<BoxError>>>,
28 rx: Receiver<()>,
29 sleep: SharedAsyncSleep,
30 time: SharedTimeSource,
31}
32
33impl Debug for ReloadEndpoint {
34 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("ReloadEndpoint").finish()
36 }
37}
38
39impl ReloadEndpoint {
40 pub async fn reload_once(&self) {
42 match (self.loader)().await {
43 Ok((endpoint, expiry)) => {
44 tracing::debug!("caching resolved endpoint: {:?}", (&endpoint, &expiry));
45 *self.endpoint.lock().unwrap() = Some(ExpiringEndpoint { endpoint, expiry })
46 }
47 Err(err) => *self.error.lock().unwrap() = Some(err),
48 }
49 }
50
51 pub async fn reload_task(mut self) {
55 loop {
56 match self.rx.try_recv() {
57 Ok(_) | Err(TryRecvError::Closed) => break,
58 _ => {}
59 }
60 self.reload_increment(self.time.now()).await;
61 self.sleep.sleep(Duration::from_secs(60)).await;
62 }
63 }
64
65 async fn reload_increment(&self, now: SystemTime) {
66 let should_reload = self.endpoint.lock().unwrap().as_ref().map(|e| e.is_expired(now)).unwrap_or(true);
67 if should_reload {
68 tracing::debug!("reloading endpoint, previous endpoint was expired");
69 self.reload_once().await;
70 }
71 }
72}
73
74#[derive(Debug, Clone)]
75pub(crate) struct EndpointCache {
76 error: Arc<Mutex<Option<BoxError>>>,
77 endpoint: Arc<Mutex<Option<ExpiringEndpoint>>>,
78 _drop_guard: Arc<Sender<()>>,
80}
81
82impl ResolveEndpoint for EndpointCache {
83 fn resolve_endpoint<'a>(&'a self, _params: &'a EndpointResolverParams) -> EndpointFuture<'a> {
84 self.resolve_endpoint()
85 }
86}
87
88#[derive(Debug)]
89struct ExpiringEndpoint {
90 endpoint: Endpoint,
91 expiry: SystemTime,
92}
93
94impl ExpiringEndpoint {
95 fn is_expired(&self, now: SystemTime) -> bool {
96 tracing::debug!(expiry = ?self.expiry, now = ?now, delta = ?self.expiry.duration_since(now), "checking expiry status of endpoint");
97 match self.expiry.duration_since(now) {
98 Err(_) => true,
99 Ok(t) => t < Duration::from_secs(120),
100 }
101 }
102}
103
104pub(crate) async fn create_cache<F>(
105 loader_fn: impl Fn() -> F + Send + Sync + 'static,
106 sleep: SharedAsyncSleep,
107 time: SharedTimeSource,
108) -> Result<(EndpointCache, ReloadEndpoint), BoxError>
109where
110 F: Future<Output = Result<(Endpoint, SystemTime), BoxError>> + Send + 'static,
111{
112 let error_holder = Arc::new(Mutex::new(None));
113 let endpoint_holder = Arc::new(Mutex::new(None));
114 let (tx, rx) = tokio::sync::oneshot::channel();
115 let cache = EndpointCache {
116 error: error_holder.clone(),
117 endpoint: endpoint_holder.clone(),
118 _drop_guard: Arc::new(tx),
119 };
120 let reloader = ReloadEndpoint {
121 loader: Box::new(move || Box::pin((loader_fn)()) as _),
122 endpoint: endpoint_holder,
123 error: error_holder,
124 rx,
125 sleep,
126 time,
127 };
128 tracing::debug!("populating initial endpoint discovery cache");
129 reloader.reload_once().await;
130 cache.resolve_endpoint().await?;
133 Ok((cache, reloader))
134}
135
136impl EndpointCache {
137 fn resolve_endpoint(&self) -> EndpointFuture<'_> {
138 tracing::trace!("resolving endpoint from endpoint discovery cache");
139 let ep = self.endpoint.lock().unwrap().as_ref().map(|e| e.endpoint.clone()).ok_or_else(|| {
140 let error: Option<BoxError> = self.error.lock().unwrap().take();
141 error.unwrap_or_else(|| "Failed to resolve endpoint".into())
142 });
143 EndpointFuture::ready(ep)
144 }
145}
146
#[cfg(test)]
mod test {
    use crate::endpoint_discovery::create_cache;
    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
    use aws_smithy_async::test_util::controlled_time_and_sleep;
    use aws_smithy_async::time::{SharedTimeSource, SystemTimeSource, TimeSource};
    use aws_smithy_types::endpoint::Endpoint;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, UNIX_EPOCH};
    use tokio::time::timeout;

    /// Identity function that only compiles when `T: Send`; used to assert `Send` bounds.
    fn check_send_v<T: Send>(t: T) -> T {
        t
    }

    /// Asserts that the reload-task future and the cache handle are both `Send`
    /// (e.g. so the task can be handed to `tokio::spawn`, which requires `Send`).
    #[tokio::test]
    // The reload-task future is built only to check its type and is never awaited.
    #[allow(unused_must_use)]
    async fn check_traits() {
        let (cache, reloader) = create_cache(
            || async { Ok((Endpoint::builder().url("http://foo.com").build(), SystemTimeSource::new().now())) },
            SharedAsyncSleep::new(TokioSleep::new()),
            SharedTimeSource::new(SystemTimeSource::new()),
        )
        .await
        .unwrap();
        check_send_v(reloader.reload_task());
        check_send_v(cache);
    }

    /// Drives `reload_increment` by hand with explicit timestamps. The cached endpoint is kept
    /// while more than the 120 s refresh margin remains, and is reloaded once that margin is
    /// reached.
    ///
    /// NOTE(review): despite the name, this loader never returns an error; the test exercises
    /// expiry-based reloading. Consider renaming.
    #[tokio::test]
    async fn erroring_endpoint_always_reloaded() {
        let expiry = UNIX_EPOCH + Duration::from_secs(123456789);
        // Counts loader invocations; the count is embedded in the URL so each reload is visible.
        let ct = Arc::new(AtomicUsize::new(0));
        let (cache, reloader) = create_cache(
            move || {
                let shared_ct = ct.clone();
                shared_ct.fetch_add(1, Ordering::AcqRel);
                async move { Ok((Endpoint::builder().url(format!("http://foo.com/{shared_ct:?}")).build(), expiry)) }
            },
            SharedAsyncSleep::new(TokioSleep::new()),
            SharedTimeSource::new(SystemTimeSource::new()),
        )
        .await
        .expect("returns an endpoint");
        // Initial load performed inside `create_cache`.
        assert_eq!(cache.resolve_endpoint().await.expect("ok").url(), "http://foo.com/1");
        // 240 s before expiry: outside the 120 s margin, so no reload happens.
        reloader.reload_increment(expiry - Duration::from_secs(240)).await;
        assert_eq!(cache.resolve_endpoint().await.expect("ok").url(), "http://foo.com/1");

        // At expiry: 0 s remain, which is inside the margin, so the loader runs again.
        reloader.reload_increment(expiry).await;
        assert_eq!(cache.resolve_endpoint().await.expect("ok").url(), "http://foo.com/2");
    }

    /// Runs the real `reload_task` against a controlled clock and sleep gate.
    ///
    /// Each intercepted sleep must be 60 s, the endpoint must be reloaded once the clock enters
    /// the 120 s refresh margin, and the task must exit after the cache is dropped.
    ///
    /// NOTE(review): for the expected URLs below to hold, releasing a captured sleep (by
    /// `allow_progress` or by dropping it) must advance the controlled clock by the sleep
    /// duration. Confirm against `aws_smithy_async::test_util`.
    #[tokio::test]
    async fn test_advance_of_task() {
        let expiry = UNIX_EPOCH + Duration::from_secs(123456789);
        // Start 239 s before expiry: just outside the 120 s margin after two 60 s sleeps,
        // inside it after three.
        let (time, sleep, mut gate) = controlled_time_and_sleep(expiry - Duration::from_secs(239));
        let ct = Arc::new(AtomicUsize::new(0));
        let (cache, reloader) = create_cache(
            move || {
                let shared_ct = ct.clone();
                shared_ct.fetch_add(1, Ordering::AcqRel);
                async move { Ok((Endpoint::builder().url(format!("http://foo.com/{shared_ct:?}")).build(), expiry)) }
            },
            SharedAsyncSleep::new(sleep.clone()),
            SharedTimeSource::new(time.clone()),
        )
        .await
        .expect("first load success");
        let reload_task = tokio::spawn(reloader.reload_task());
        assert!(!reload_task.is_finished());
        // First iteration (239 s left): no reload, then sleep. The captured sleep is a
        // temporary here and is dropped at the end of this statement.
        assert_eq!(gate.expect_sleep().await.duration(), Duration::from_secs(60));
        assert_eq!(cache.resolve_endpoint().await.unwrap().url(), "http://foo.com/1");
        // Second iteration (expected 179 s left): still outside the margin, no reload.
        let sleep = gate.expect_sleep().await;
        assert_eq!(cache.resolve_endpoint().await.unwrap().url(), "http://foo.com/1");
        assert_eq!(sleep.duration(), Duration::from_secs(60));
        sleep.allow_progress();
        // Third iteration (expected 119 s left): inside the margin, so the loader ran again.
        let sleep = gate.expect_sleep().await;
        assert_eq!(cache.resolve_endpoint().await.unwrap().url(), "http://foo.com/2");
        sleep.allow_progress();

        // Drop the cache (and with it the drop guard) while the task is parked in a sleep. On
        // the next iteration `try_recv` reports `Closed` and the loop exits.
        let sleep = gate.expect_sleep().await;
        drop(cache);
        sleep.allow_progress();

        timeout(Duration::from_secs(1), reload_task)
            .await
            .expect("task finishes successfully")
            .expect("finishes");
    }
}