aws_sdk_timestreamquery/
endpoint_discovery.rs

1// Code generated by software.amazon.smithy.rust.codegen.smithy-rs. DO NOT EDIT.
2/*
3 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4 * SPDX-License-Identifier: Apache-2.0
5 */
6
7//! Maintain a cache of discovered endpoints
8
9use aws_smithy_async::future::BoxFuture;
10use aws_smithy_async::rt::sleep::{AsyncSleep, SharedAsyncSleep};
11use aws_smithy_async::time::SharedTimeSource;
12use aws_smithy_runtime_api::box_error::BoxError;
13use aws_smithy_runtime_api::client::endpoint::{EndpointFuture, EndpointResolverParams, ResolveEndpoint};
14use aws_smithy_types::endpoint::Endpoint;
15use std::fmt::{Debug, Formatter};
16use std::future::Future;
17use std::sync::{Arc, Mutex};
18use std::time::{Duration, SystemTime};
19use tokio::sync::oneshot::error::TryRecvError;
20use tokio::sync::oneshot::{Receiver, Sender};
21
22/// Endpoint reloader
23#[must_use]
24pub struct ReloadEndpoint {
25    loader: Box<dyn Fn() -> BoxFuture<'static, (Endpoint, SystemTime), BoxError> + Send + Sync>,
26    endpoint: Arc<Mutex<Option<ExpiringEndpoint>>>,
27    error: Arc<Mutex<Option<BoxError>>>,
28    rx: Receiver<()>,
29    sleep: SharedAsyncSleep,
30    time: SharedTimeSource,
31}
32
33impl Debug for ReloadEndpoint {
34    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("ReloadEndpoint").finish()
36    }
37}
38
39impl ReloadEndpoint {
40    /// Reload the endpoint once
41    pub async fn reload_once(&self) {
42        match (self.loader)().await {
43            Ok((endpoint, expiry)) => {
44                tracing::debug!("caching resolved endpoint: {:?}", (&endpoint, &expiry));
45                *self.endpoint.lock().unwrap() = Some(ExpiringEndpoint { endpoint, expiry })
46            }
47            Err(err) => *self.error.lock().unwrap() = Some(err),
48        }
49    }
50
51    /// An infinite loop task that will reload the endpoint
52    ///
53    /// This task will terminate when the corresponding [`Client`](crate::Client) is dropped.
54    pub async fn reload_task(mut self) {
55        loop {
56            match self.rx.try_recv() {
57                Ok(_) | Err(TryRecvError::Closed) => break,
58                _ => {}
59            }
60            self.reload_increment(self.time.now()).await;
61            self.sleep.sleep(Duration::from_secs(60)).await;
62        }
63    }
64
65    async fn reload_increment(&self, now: SystemTime) {
66        let should_reload = self.endpoint.lock().unwrap().as_ref().map(|e| e.is_expired(now)).unwrap_or(true);
67        if should_reload {
68            tracing::debug!("reloading endpoint, previous endpoint was expired");
69            self.reload_once().await;
70        }
71    }
72}
73
74#[derive(Debug, Clone)]
75pub(crate) struct EndpointCache {
76    error: Arc<Mutex<Option<BoxError>>>,
77    endpoint: Arc<Mutex<Option<ExpiringEndpoint>>>,
78    // When the sender is dropped, this allows the reload loop to stop
79    _drop_guard: Arc<Sender<()>>,
80}
81
82impl ResolveEndpoint for EndpointCache {
83    fn resolve_endpoint<'a>(&'a self, _params: &'a EndpointResolverParams) -> EndpointFuture<'a> {
84        self.resolve_endpoint()
85    }
86}
87
88#[derive(Debug)]
89struct ExpiringEndpoint {
90    endpoint: Endpoint,
91    expiry: SystemTime,
92}
93
94impl ExpiringEndpoint {
95    fn is_expired(&self, now: SystemTime) -> bool {
96        tracing::debug!(expiry = ?self.expiry, now = ?now, delta = ?self.expiry.duration_since(now), "checking expiry status of endpoint");
97        match self.expiry.duration_since(now) {
98            Err(_) => true,
99            Ok(t) => t < Duration::from_secs(120),
100        }
101    }
102}
103
104pub(crate) async fn create_cache<F>(
105    loader_fn: impl Fn() -> F + Send + Sync + 'static,
106    sleep: SharedAsyncSleep,
107    time: SharedTimeSource,
108) -> Result<(EndpointCache, ReloadEndpoint), BoxError>
109where
110    F: Future<Output = Result<(Endpoint, SystemTime), BoxError>> + Send + 'static,
111{
112    let error_holder = Arc::new(Mutex::new(None));
113    let endpoint_holder = Arc::new(Mutex::new(None));
114    let (tx, rx) = tokio::sync::oneshot::channel();
115    let cache = EndpointCache {
116        error: error_holder.clone(),
117        endpoint: endpoint_holder.clone(),
118        _drop_guard: Arc::new(tx),
119    };
120    let reloader = ReloadEndpoint {
121        loader: Box::new(move || Box::pin((loader_fn)()) as _),
122        endpoint: endpoint_holder,
123        error: error_holder,
124        rx,
125        sleep,
126        time,
127    };
128    tracing::debug!("populating initial endpoint discovery cache");
129    reloader.reload_once().await;
130    // if we didn't successfully get an endpoint, bail out so the client knows
131    // configuration failed to work
132    cache.resolve_endpoint().await?;
133    Ok((cache, reloader))
134}
135
136impl EndpointCache {
137    fn resolve_endpoint(&self) -> EndpointFuture<'_> {
138        tracing::trace!("resolving endpoint from endpoint discovery cache");
139        let ep = self.endpoint.lock().unwrap().as_ref().map(|e| e.endpoint.clone()).ok_or_else(|| {
140            let error: Option<BoxError> = self.error.lock().unwrap().take();
141            error.unwrap_or_else(|| "Failed to resolve endpoint".into())
142        });
143        EndpointFuture::ready(ep)
144    }
145}
146
#[cfg(test)]
mod test {
    use crate::endpoint_discovery::create_cache;
    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
    use aws_smithy_async::test_util::controlled_time_and_sleep;
    use aws_smithy_async::time::{SharedTimeSource, SystemTimeSource, TimeSource};
    use aws_smithy_types::endpoint::Endpoint;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, UNIX_EPOCH};
    use tokio::time::timeout;

    /// Compiles only if `T` is `Send`; hands the value back unchanged.
    fn check_send_v<T>(t: T) -> T
    where
        T: Send,
    {
        t
    }

    /// Both the reload task future and the cache must be `Send`.
    #[tokio::test]
    #[allow(unused_must_use)]
    async fn check_traits() {
        let loader = || async { Ok((Endpoint::builder().url("http://foo.com").build(), SystemTimeSource::new().now())) };
        let sleep_impl = SharedAsyncSleep::new(TokioSleep::new());
        let clock = SharedTimeSource::new(SystemTimeSource::new());
        let (cache, reloader) = create_cache(loader, sleep_impl, clock).await.unwrap();
        check_send_v(reloader.reload_task());
        check_send_v(cache);
    }

    /// The cached endpoint is kept while more than the 120 second buffer
    /// remains and is replaced once the expiry has been reached.
    #[tokio::test]
    async fn erroring_endpoint_always_reloaded() {
        let expiry = UNIX_EPOCH + Duration::from_secs(123456789);
        let load_count = Arc::new(AtomicUsize::new(0));
        let (cache, reloader) = create_cache(
            move || {
                // Every loader call bumps the counter; the URL embeds its value.
                let shared_ct = load_count.clone();
                shared_ct.fetch_add(1, Ordering::AcqRel);
                async move { Ok((Endpoint::builder().url(format!("http://foo.com/{shared_ct:?}")).build(), expiry)) }
            },
            SharedAsyncSleep::new(TokioSleep::new()),
            SharedTimeSource::new(SystemTimeSource::new()),
        )
        .await
        .expect("returns an endpoint");
        assert_eq!(cache.resolve_endpoint().await.expect("ok").url(), "http://foo.com/1");

        // 240 seconds before expiry is outside the 120 second buffer: no reload
        let well_before_expiry = expiry - Duration::from_secs(240);
        reloader.reload_increment(well_before_expiry).await;
        assert_eq!(cache.resolve_endpoint().await.expect("ok").url(), "http://foo.com/1");

        // at the expiry itself the endpoint must be loaded again
        reloader.reload_increment(expiry).await;
        assert_eq!(cache.resolve_endpoint().await.expect("ok").url(), "http://foo.com/2");
    }

    /// Drives the background task with a controlled clock and checks that it
    /// reloads once the buffer is reached and stops after the cache is dropped.
    #[tokio::test]
    async fn test_advance_of_task() {
        let expiry = UNIX_EPOCH + Duration::from_secs(123456789);
        let one_minute = Duration::from_secs(60);
        // the clock starts 239 seconds (just under 4 minutes) before expiry
        let (time, sleep, mut gate) = controlled_time_and_sleep(expiry - Duration::from_secs(239));
        let load_count = Arc::new(AtomicUsize::new(0));
        let (cache, reloader) = create_cache(
            move || {
                let shared_ct = load_count.clone();
                shared_ct.fetch_add(1, Ordering::AcqRel);
                async move { Ok((Endpoint::builder().url(format!("http://foo.com/{shared_ct:?}")).build(), expiry)) }
            },
            SharedAsyncSleep::new(sleep.clone()),
            SharedTimeSource::new(time.clone()),
        )
        .await
        .expect("first load success");
        let task_handle = tokio::spawn(reloader.reload_task());
        assert!(!task_handle.is_finished());
        // Two 60 second sleeps bring the remaining lifetime from 239 to 119
        // seconds, which is inside the 120 second buffer (this assumes the
        // controlled clock advances by each sleep's duration).
        // t = 0
        assert_eq!(gate.expect_sleep().await.duration(), one_minute);
        assert_eq!(cache.resolve_endpoint().await.unwrap().url(), "http://foo.com/1");
        // t = 60

        let pending = gate.expect_sleep().await;
        // 179 seconds remain, so the first endpoint is still being served.
        assert_eq!(cache.resolve_endpoint().await.unwrap().url(), "http://foo.com/1");
        assert_eq!(pending.duration(), one_minute);
        pending.allow_progress();
        // t = 120

        // by the time the task sleeps again it has reloaded the endpoint
        let pending = gate.expect_sleep().await;
        assert_eq!(cache.resolve_endpoint().await.unwrap().url(), "http://foo.com/2");
        pending.allow_progress();

        // dropping the cache while the task sleeps must end the task once it wakes
        let pending = gate.expect_sleep().await;
        drop(cache);
        pending.allow_progress();

        timeout(Duration::from_secs(1), task_handle)
            .await
            .expect("task finishes successfully")
            .expect("finishes");
    }
}
244            .expect("finishes");
245    }
246}