// soda_pool/pool.rs

use std::{collections::BinaryHeap, net::IpAddr, sync::Arc, time::Duration};

use chrono::{DateTime, TimeDelta, Utc};
use futures::{stream::FuturesUnordered, FutureExt, StreamExt};
use tokio::{
    sync::RwLock,
    task::{AbortHandle, JoinHandle},
    time::interval,
};
use tonic::transport::Channel;
use tracing::{debug, trace};

use crate::{
    broken_endpoints::{BrokenEndpoints, DelayedAddress},
    dns::resolve_domain,
    endpoint_template::EndpointTemplate,
    ready_channels::ReadyChannels,
};

/// Builder for creating a [`ChannelPool`].
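///
/// # Example
///
/// A minimal usage sketch, marked `ignore` because it assumes
/// `ChannelPoolBuilder` and `EndpointTemplate` are importable from the crate
/// root and that an `EndpointTemplate` value is already in hand:
///
/// ```ignore
/// let template: EndpointTemplate = /* obtained elsewhere */;
/// let pool = ChannelPoolBuilder::new(template).build();
/// ```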
#[derive(Debug, Clone)]
pub struct ChannelPoolBuilder {
    endpoint: EndpointTemplate,
    dns_interval: Duration,
}

impl ChannelPoolBuilder {
    /// Create a new `ChannelPoolBuilder` from the given endpoint template.
    #[must_use]
    pub fn new(endpoint: impl Into<EndpointTemplate>) -> Self {
        Self {
            endpoint: endpoint.into(),
            // Note: Is this a good default?
            dns_interval: Duration::from_secs(5),
        }
    }

    /// Set the DNS check interval.
    ///
    /// Controls how often the resulting pool re-resolves the domain to pick
    /// up new IP addresses. The default is 5 seconds.
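    ///
    /// # Example
    ///
    /// A sketch of overriding the default (marked `ignore` since it assumes a
    /// ready `EndpointTemplate` value). This setter takes `&mut self`, so the
    /// builder is configured first and [`build`](Self::build) is called on it
    /// afterwards:
    ///
    /// ```ignore
    /// let mut builder = ChannelPoolBuilder::new(template);
    /// builder.dns_interval(Duration::from_secs(30));
    /// let pool = builder.build();
    /// ```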
    #[must_use]
    pub fn dns_interval(&mut self, dns_interval: impl Into<Duration>) -> &mut Self {
        self.dns_interval = dns_interval.into();
        self
    }

    /// Build the [`ChannelPool`].
    ///
    /// This function will create a new channel pool from the given endpoint
    /// template and settings. This includes starting the channel pool's
    /// background tasks.
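    ///
    /// # Example
    ///
    /// A sketch, marked `ignore` since it assumes a ready `EndpointTemplate`.
    /// The background tasks run only while the returned pool (or a clone of
    /// it) is alive:
    ///
    /// ```ignore
    /// let pool = ChannelPoolBuilder::new(template).build();
    /// // DNS lookups and broken-endpoint rechecks now happen in the background.
    /// drop(pool); // the background tasks are aborted here
    /// ```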
    #[must_use]
    pub fn build(self) -> ChannelPool {
        let ready_clients = Arc::new(ReadyChannels::default());
        let broken_endpoints = Arc::new(BrokenEndpoints::default());

        let dns_lookup_task = {
            // Get shared ownership of the resources.
            let ready_clients = ready_clients.clone();
            let broken_endpoints = broken_endpoints.clone();
            let endpoint = self.endpoint.clone();
            // Copy the interval out so the task does not capture `self`.
            let dns_interval = self.dns_interval;

            tokio::spawn(async move {
                let mut interval = interval(dns_interval);
                loop {
                    // The first tick completes immediately, so the initial
                    // DNS check runs right away and then repeats once per
                    // interval.
                    interval.tick().await;

                    check_dns(&endpoint, &ready_clients, &broken_endpoints).await;
                }
            })
        };

        let doctor_task = {
            // Get shared ownership of the resources.
            let ready_clients = ready_clients.clone();
            let broken_endpoints = broken_endpoints.clone();
            let endpoint = self.endpoint.clone();

            tokio::spawn(async move {
                loop {
                    // There is an asynchronous wait inside this function, so
                    // it is safe to run it in a tight loop here.
                    recheck_broken_endpoint(
                        broken_endpoints.next_broken_ip_address().await,
                        &endpoint,
                        &ready_clients,
                        &broken_endpoints,
                    )
                    .await;
                }
            })
        };

        ChannelPool {
            template: Arc::new(self.endpoint),
            ready_clients,
            broken_endpoints,
            _dns_lookup_task: Arc::new(dns_lookup_task.into()),
            _doctor_task: Arc::new(doctor_task.into()),
        }
    }
}

/// Resolve the template's domain and rebuild the ready/broken endpoint sets.
async fn check_dns(
    endpoint_template: &EndpointTemplate,
    ready_clients: &ReadyChannels,
    broken_endpoints: &BrokenEndpoints,
) {
    // Resolve domain to IP addresses.
    let Ok(addresses) = resolve_domain(endpoint_template.domain()) else {
        // todo-interface: DNS resolution would mainly fail if the domain does
        // not resolve to any IP address, but it could also fail for other
        // reasons. In a future version, we should record this error and allow
        // the user to see it.
        return;
    };

    let mut ready = Vec::new();
    let mut broken = BinaryHeap::new();

    for address in addresses {
        // Skip if the address is already in ready_clients.
        if let Some(channel) = ready_clients.find(address).await {
            trace!("Skipping {:?} as already ready", address);
            ready.push((address, channel));
            continue;
        }

        // Skip if the address is already in broken_endpoints.
        if let Some(entry) = broken_endpoints.get_address(address).await {
            trace!("Skipping {:?} as already broken", address);
            broken.push(entry);
            continue;
        }

        debug!("Connecting to: {:?}", address);
        let channel = endpoint_template.build(address).connect().await;
        if let Ok(channel) = channel {
            ready.push((address, channel));
        } else {
            broken.push(address.into());
        }
    }

    // Replace the clients stored in `ready_clients` with the new ones
    // collected in `ready`, and likewise for the broken endpoints.
    ready_clients.replace_with(ready).await;
    broken_endpoints.replace_with(broken).await;
}

/// Try to reconnect to a previously broken address, moving it back into the
/// ready set on success or re-queueing it on failure.
async fn recheck_broken_endpoint(
    address: DelayedAddress,
    endpoint: &EndpointTemplate,
    ready_clients: &ReadyChannels,
    broken_endpoints: &BrokenEndpoints,
) {
    let connection_test_result = endpoint.build(*address).connect().await;

    if let Ok(channel) = connection_test_result {
        debug!("Connection established to {:?}", *address);
        ready_clients.add(*address, channel).await;
    } else {
        debug!("Can't connect to {:?}", *address);
        broken_endpoints.re_add_address(address).await;
    }
}

/// Guard that aborts the associated task when dropped.
///
/// The pool stores these behind `Arc`s, so the background tasks keep running
/// while any clone of the pool is alive and are aborted when the last clone
/// is dropped.
#[derive(Debug, Default)]
struct AbortOnDrop(Option<AbortHandle>);

impl<T> From<JoinHandle<T>> for AbortOnDrop {
    fn from(handle: JoinHandle<T>) -> Self {
        Self(Some(handle.abort_handle()))
    }
}

impl Drop for AbortOnDrop {
    fn drop(&mut self) {
        if let Some(handle) = self.0.take() {
            handle.abort();
        }
    }
}

/// Self-managed pool of tonic's [`Channel`]s.
// todo-performance: Probably better to change to INNER pattern to avoid cloning multiple Arcs.
#[derive(Debug)]
pub struct ChannelPool {
    template: Arc<EndpointTemplate>,
    ready_clients: Arc<ReadyChannels>,
    broken_endpoints: Arc<BrokenEndpoints>,

    _dns_lookup_task: Arc<AbortOnDrop>,
    _doctor_task: Arc<AbortOnDrop>,
}

impl ChannelPool {
    /// Get a channel from the pool.
    ///
    /// This function will return a channel if one is available, or `None` if
    /// none are.
    ///
    /// ## Selection algorithm
    ///
    /// Currently, the channel is selected randomly from the pool of available
    /// channels. However, this behavior may change in the future.
    ///
    /// ## Additional DNS and broken connection checks
    ///
    /// If no channels are available, the function will re-resolve DNS and
    /// recheck connections to all servers currently marked as dead. To avoid
    /// spamming the DNS server and the backends, this is performed no more
    /// than once every 500ms.
    ///
    /// If such a check is already running when this function is called, the
    /// function will wait for the check to finish and return its result.
    ///
    /// If the check is not running, but the last check was performed less
    /// than 500ms ago, the function will return `None` immediately.
    ///
    /// The specifics of this behavior are not set in stone and may change in
    /// the future.
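    ///
    /// # Example
    ///
    /// A sketch of a typical call site (marked `ignore` since it assumes an
    /// already-built `pool`); the returned address can later be handed to
    /// [`report_broken`](Self::report_broken) if the channel turns out to be
    /// dead:
    ///
    /// ```ignore
    /// if let Some((address, channel)) = pool.get_channel().await {
    ///     // Use `channel` with a generated tonic client; the client type
    ///     // here is illustrative.
    ///     let mut client = MyServiceClient::new(channel);
    /// } else {
    ///     // No server is reachable right now.
    /// }
    /// ```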
    pub async fn get_channel(&self) -> Option<(IpAddr, Channel)> {
        // Shared by every pool in the process; tracks when broken endpoints
        // were last force-rechecked.
        static RECHECK_BROKEN_ENDPOINTS: RwLock<DateTime<Utc>> =
            RwLock::const_new(DateTime::<Utc>::MIN_UTC);
        const MIN_INTERVAL: TimeDelta = TimeDelta::milliseconds(500);

        if let Some(entry) = self.ready_clients.get_any().await {
            return Some(entry);
        }

        // todo: This entire function is a bit of a mess, but this part absolutely needs to be cleaned up.
        let _guard = match RECHECK_BROKEN_ENDPOINTS.try_read() {
            Ok(last_recheck_time)
                if Utc::now().signed_duration_since(*last_recheck_time) < MIN_INTERVAL =>
            {
                return None;
            }
            Ok(guard) => {
                // Upgrade to a write lock so that concurrent callers block
                // until the recheck below finishes.
                drop(guard);
                let mut guard = RECHECK_BROKEN_ENDPOINTS.write().await;
                // Another task may have repopulated the pool while we waited
                // for the write lock.
                if let Some(entry) = self.ready_clients.get_any().await {
                    return Some(entry);
                }
                *guard = Utc::now();
                guard
            }
            Err(_) => {
                // RECHECK_BROKEN_ENDPOINTS is used here to wait until the
                // ready channels and broken endpoints have been checked, so
                // there is no need to hold the lock after acquiring it.
                // (Some other implementation might be worth considering, but
                // this is a good start.)
                let _ = RECHECK_BROKEN_ENDPOINTS.write().await;
                return self.ready_clients.get_any().await;
            }
        };

        trace!("Force recheck of broken endpoints");

        let mut fut = FuturesUnordered::new();
        fut.push(
            async {
                check_dns(&self.template, &self.ready_clients, &self.broken_endpoints).await;
                self.ready_clients.get_any().await
            }
            .boxed(),
        );

        for address in self.broken_endpoints.addresses().await.iter().copied() {
            fut.push(
                async move {
                    recheck_broken_endpoint(
                        address,
                        &self.template,
                        &self.ready_clients,
                        &self.broken_endpoints,
                    )
                    .await;
                    self.ready_clients.get_any().await
                }
                .boxed(),
            );
        }

        // Return the result of whichever check completes first.
        fut.select_next_some().await
    }

    /// Report a broken endpoint to the pool.
    ///
    /// This function will remove the endpoint from the pool and add it to the
    /// list of currently dead servers, which the doctor task retries in the
    /// background.
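    ///
    /// # Example
    ///
    /// A sketch of the intended feedback loop, marked `ignore` since it
    /// assumes an already-built `pool` and a hypothetical `do_rpc` helper:
    ///
    /// ```ignore
    /// if let Some((address, channel)) = pool.get_channel().await {
    ///     if do_rpc(channel).await.is_err() {
    ///         // Let the pool know, so the doctor task can retry the server
    ///         // in the background.
    ///         pool.report_broken(address).await;
    ///     }
    /// }
    /// ```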
    pub async fn report_broken(&self, ip_address: impl Into<IpAddr>) {
        let ip_address = ip_address.into();
        self.ready_clients.remove(ip_address).await;
        self.broken_endpoints.add_address(ip_address).await;
    }
}

/// This is a shallow clone, meaning that the new pool will reference the same
/// resources (ready channels, broken endpoints, and background tasks) as the
/// original pool.
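///
/// For example (a sketch, assuming an already-built `pool`), a broken
/// endpoint reported through one handle is observed by every clone, because
/// all handles point at the same shared state:
///
/// ```ignore
/// let pool_clone = pool.clone();
/// pool.report_broken(address).await;
/// // `pool_clone` will no longer hand out a channel for `address` either.
/// ```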
impl Clone for ChannelPool {
    fn clone(&self) -> Self {
        #[allow(clippy::used_underscore_binding)]
        Self {
            template: self.template.clone(),
            ready_clients: self.ready_clients.clone(),
            broken_endpoints: self.broken_endpoints.clone(),
            _dns_lookup_task: self._dns_lookup_task.clone(),
            _doctor_task: self._doctor_task.clone(),
        }
    }
}