soda_pool/
pool.rs

use std::{collections::BinaryHeap, net::IpAddr, sync::Arc, time::Duration};

use async_trait::async_trait;
use chrono::{DateTime, TimeDelta, Utc};
use futures::{stream::FuturesUnordered, FutureExt, StreamExt};
use tokio::{
    sync::RwLock,
    task::{AbortHandle, JoinHandle},
    time::interval,
};
use tonic::transport::Channel;
use tracing::{debug, trace};

use crate::{
    broken_endpoints::{BrokenEndpoints, DelayedAddress},
    dns::resolve_domain,
    endpoint_template::EndpointTemplate,
    ready_channels::ReadyChannels,
};

/// Builder for creating a [`ManagedChannelPool`].
#[derive(Debug, Clone)]
pub struct ManagedChannelPoolBuilder {
    endpoint: EndpointTemplate,
    dns_interval: Duration,
}

impl ManagedChannelPoolBuilder {
    /// Create a new `ManagedChannelPoolBuilder` from the given endpoint template.
    #[must_use]
    pub fn new(endpoint: impl Into<EndpointTemplate>) -> Self {
        Self {
            endpoint: endpoint.into(),
            // Note: Is this a good default?
            dns_interval: Duration::from_secs(5),
        }
    }

    /// Set the DNS check interval.
    ///
    /// Sets how often the resulting pool will query DNS for new IP
    /// addresses. Defaults to 5 seconds.
    #[must_use]
    pub fn dns_interval(mut self, dns_interval: impl Into<Duration>) -> Self {
        self.dns_interval = dns_interval.into();
        self
    }

    /// Build the [`ManagedChannelPool`].
    ///
    /// This function creates a new channel pool from the given endpoint
    /// template and settings, and starts the channel pool's background
    /// tasks.
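    ///
    /// # Example
    ///
    /// A minimal sketch, assuming an `EndpointTemplate` value named
    /// `template` is already at hand:
    ///
    /// ```ignore
    /// let pool = ManagedChannelPoolBuilder::new(template)
    ///     .dns_interval(Duration::from_secs(10))
    ///     .build();
    /// ```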
    #[must_use]
    pub fn build(self) -> ManagedChannelPool {
        let ready_clients = Arc::new(ReadyChannels::default());
        let broken_endpoints = Arc::new(BrokenEndpoints::default());

        let dns_lookup_task = {
            // Get shared ownership of the resources.
            let ready_clients = ready_clients.clone();
            let broken_endpoints = broken_endpoints.clone();
            let endpoint = self.endpoint.clone();

            tokio::spawn(async move {
                let mut interval = interval(self.dns_interval);
                loop {
                    // The first tick completes immediately, so the initial
                    // DNS check runs right away; later ticks wait out the
                    // configured interval.
                    interval.tick().await;

                    check_dns(&endpoint, &ready_clients, &broken_endpoints).await;
                }
            })
        };

        let doctor_task = {
            // Get shared ownership of the resources.
            let ready_clients = ready_clients.clone();
            let broken_endpoints = broken_endpoints.clone();
            let endpoint = self.endpoint.clone();

            tokio::spawn(async move {
                loop {
                    // `next_broken_ip_address` awaits until an endpoint is due
                    // for a recheck, so this tight loop does not busy-spin.
                    recheck_broken_endpoint(
                        broken_endpoints.next_broken_ip_address().await,
                        &endpoint,
                        &ready_clients,
                        &broken_endpoints,
                    )
                    .await;
                }
            })
        };

        ManagedChannelPool {
            template: Arc::new(self.endpoint),
            ready_clients,
            broken_endpoints,
            _dns_lookup_task: Arc::new(dns_lookup_task.into()),
            _doctor_task: Arc::new(doctor_task.into()),
        }
    }
}

/// Resolve the pool's domain and rebuild the ready/broken endpoint sets from
/// the result.
async fn check_dns(
    endpoint_template: &EndpointTemplate,
    ready_clients: &ReadyChannels,
    broken_endpoints: &BrokenEndpoints,
) {
    // Resolve domain to IP addresses.
    let Ok(addresses) = resolve_domain(endpoint_template.domain()) else {
        // todo-interface: DNS resolution mainly fails when the domain does not
        // resolve to any IP address, but it can also fail for other reasons.
        // A future version should record this error and allow the user to
        // inspect it.
        return;
    };

    let mut ready = Vec::new();
    let mut broken = BinaryHeap::new();

    for address in addresses {
        // Skip if the address is already in ready_clients.
        if let Some(channel) = ready_clients.find(address).await {
            trace!("Skipping {:?} as already ready", address);
            ready.push((address, channel));
            continue;
        }

        // Skip if the address is already in broken_endpoints.
        if let Some(entry) = broken_endpoints.get_address(address).await {
            trace!("Skipping {:?} as already broken", address);
            broken.push(entry);
            continue;
        }

        debug!("Connecting to: {:?}", address);
        let channel = endpoint_template.build(address).connect().await;
        if let Ok(channel) = channel {
            ready.push((address, channel));
        } else {
            broken.push(address.into());
        }
    }

    // Replace the list of clients stored in `ready_clients` with the new ones constructed in `ready`.
    ready_clients.replace_with(ready).await;
    broken_endpoints.replace_with(broken).await;
}

/// Try to reconnect to a previously broken endpoint, moving it back to the
/// ready set on success or re-queueing it for a later recheck on failure.
async fn recheck_broken_endpoint(
    address: DelayedAddress,
    endpoint: &EndpointTemplate,
    ready_clients: &ReadyChannels,
    broken_endpoints: &BrokenEndpoints,
) {
    let connection_test_result = endpoint.build(*address).connect().await;

    if let Ok(channel) = connection_test_result {
        debug!("Connection established to {:?}", *address);
        ready_clients.add(*address, channel).await;
    } else {
        debug!("Can't connect to {:?}", *address);
        broken_endpoints.re_add_address(address).await;
    }
}

/// Aborts the wrapped task when dropped, so background tasks do not outlive
/// the last pool handle.
#[derive(Debug, Default)]
struct AbortOnDrop(Option<AbortHandle>);

impl<T> From<JoinHandle<T>> for AbortOnDrop {
    fn from(handle: JoinHandle<T>) -> Self {
        Self(Some(handle.abort_handle()))
    }
}

impl Drop for AbortOnDrop {
    fn drop(&mut self) {
        if let Some(handle) = self.0.take() {
            handle.abort();
        }
    }
}

/// Trait implemented by channel pools.
///
/// Note: The trait is defined with [`mod@async_trait`], so it is possible
/// (and even suggested) for implementations to use it as well.
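///
/// # Example
///
/// A sketch of the intended request loop; `pool` is any `ChannelPool`
/// implementation and `do_request` is a hypothetical stand-in for an RPC
/// call made over the channel:
///
/// ```ignore
/// if let Some((addr, channel)) = pool.get_channel().await {
///     if do_request(channel).await.is_err() {
///         // Hand the endpoint back as dead so it is rechecked in the
///         // background instead of being handed out again.
///         pool.report_broken(addr).await;
///     }
/// }
/// ```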
#[async_trait]
pub trait ChannelPool {
    /// Get a channel from the pool.
    ///
    /// Returns a channel if one is available, or `None` if no channels are available.
    async fn get_channel(&self) -> Option<(IpAddr, Channel)>;

    /// Report a broken endpoint to the pool.
    ///
    /// Removes the endpoint from the pool and adds it to the list of currently dead servers.
    async fn report_broken(&self, ip_address: IpAddr);
}

/// Self-managed pool of tonic's [`Channel`]s.
// todo-performance: Probably better to change to an INNER pattern to avoid cloning multiple Arcs.
#[derive(Debug)]
pub struct ManagedChannelPool {
    template: Arc<EndpointTemplate>,
    ready_clients: Arc<ReadyChannels>,
    broken_endpoints: Arc<BrokenEndpoints>,

    _dns_lookup_task: Arc<AbortOnDrop>,
    _doctor_task: Arc<AbortOnDrop>,
}

#[async_trait]
impl ChannelPool for ManagedChannelPool {
    /// Get a channel from the pool.
    ///
    /// Returns a channel if one is available, or `None` if no channels are
    /// available.
    ///
    /// ## Selection algorithm
    ///
    /// Currently, the channel is selected randomly from the pool of available
    /// channels. However, this behavior may change in the future.
    ///
    /// ## Additional DNS and broken connection checks
    ///
    /// If no channels are available, the function checks DNS and rechecks
    /// connections to all servers currently marked as dead. To avoid spamming
    /// DNS and the other servers, this is performed no more than once every
    /// 500ms.
    ///
    /// If such a check is already running when this function is called, the
    /// function waits for the check to finish and returns its result.
    ///
    /// If no check is running, but the last one finished less than 500ms ago,
    /// the function returns `None` immediately.
    ///
    /// The specifics of this behavior are not set in stone and may change in the future.
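    ///
    /// # Example
    ///
    /// A sketch of handling the `None` case with a simple retry; the backoff
    /// policy is illustrative, not part of the crate:
    ///
    /// ```ignore
    /// let (addr, channel) = loop {
    ///     match pool.get_channel().await {
    ///         Some(entry) => break entry,
    ///         // No endpoint is healthy right now; back off before retrying.
    ///         None => tokio::time::sleep(Duration::from_millis(500)).await,
    ///     }
    /// };
    /// ```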
    async fn get_channel(&self) -> Option<(IpAddr, Channel)> {
        // Note: this state is a `static`, so the recheck throttle is shared
        // process-wide across all pools.
        static RECHECK_BROKEN_ENDPOINTS: RwLock<DateTime<Utc>> =
            RwLock::const_new(DateTime::<Utc>::MIN_UTC);
        const MIN_INTERVAL: TimeDelta = TimeDelta::milliseconds(500);

        if let Some(entry) = self.ready_clients.get_any().await {
            return Some(entry);
        }

        // todo: This entire function is a bit of a mess, but this part absolutely needs to be cleaned up.
        let _guard = match RECHECK_BROKEN_ENDPOINTS.try_read() {
            // The last recheck finished recently; don't hammer DNS or the
            // dead servers again.
            Ok(last_recheck_time)
                if Utc::now().signed_duration_since(*last_recheck_time) < MIN_INTERVAL =>
            {
                return None;
            }
            // It is this caller's turn to run a recheck: upgrade to a write
            // lock and re-verify that no channel appeared in the meantime.
            Ok(guard) => {
                drop(guard);
                let mut guard = RECHECK_BROKEN_ENDPOINTS.write().await;
                if let Some(entry) = self.ready_clients.get_any().await {
                    return Some(entry);
                }
                *guard = Utc::now();
                guard
            }
            Err(_) => {
                // RECHECK_BROKEN_ENDPOINTS is used here to wait until ready channels and broken endpoints have been checked.
                // Thus, there is no need to hold the lock after acquiring it.
                // (Some other implementation might be worth considering, but this is a good start.)
                let _ = RECHECK_BROKEN_ENDPOINTS.write().await;
                return self.ready_clients.get_any().await;
            }
        };

        trace!("Force recheck of broken endpoints");

        let mut fut = FuturesUnordered::new();
        fut.push(
            async {
                check_dns(&self.template, &self.ready_clients, &self.broken_endpoints).await;
                self.ready_clients.get_any().await
            }
            .boxed(),
        );

        for address in self.broken_endpoints.addresses().await.iter().copied() {
            fut.push(
                async move {
                    recheck_broken_endpoint(
                        address,
                        &self.template,
                        &self.ready_clients,
                        &self.broken_endpoints,
                    )
                    .await;
                    self.ready_clients.get_any().await
                }
                .boxed(),
            );
        }

        // Return the output of whichever check completes first. Note that this
        // may be `None` even if a check that finishes later would succeed.
        fut.select_next_some().await
    }

    /// Report a broken endpoint to the pool.
    ///
    /// This function removes the endpoint from the pool and adds it to the list of currently dead servers.
    async fn report_broken(&self, ip_address: IpAddr) {
        self.ready_clients.remove(ip_address).await;
        self.broken_endpoints.add_address(ip_address).await;
    }
}

impl ManagedChannelPool {
    /// Returns a reference to the endpoint template used by the pool.
    #[must_use]
    pub fn endpoint(&self) -> &EndpointTemplate {
        &self.template
    }
}

/// This is a shallow clone, meaning that the new pool will reference the same
/// resources as the original pool.
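///
/// # Example
///
/// A sketch, assuming a `pool` built earlier; both handles share state:
///
/// ```ignore
/// let pool2 = pool.clone();
/// // An endpoint reported broken via `pool2` will not be handed out by `pool`.
/// pool2.report_broken(addr).await;
/// ```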
impl Clone for ManagedChannelPool {
    fn clone(&self) -> Self {
        #[allow(clippy::used_underscore_binding)]
        Self {
            template: self.template.clone(),
            ready_clients: self.ready_clients.clone(),
            broken_endpoints: self.broken_endpoints.clone(),
            _dns_lookup_task: self._dns_lookup_task.clone(),
            _doctor_task: self._doctor_task.clone(),
        }
    }
}