1use std::{collections::BinaryHeap, net::IpAddr, sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use chrono::{DateTime, TimeDelta, Utc};
5use futures::{stream::FuturesUnordered, FutureExt, StreamExt};
6use tokio::{
7 sync::RwLock,
8 task::{AbortHandle, JoinHandle},
9 time::interval,
10};
11use tonic::transport::Channel;
12use tracing::{debug, trace};
13
14use crate::{
15 broken_endpoints::{BrokenEndpoints, DelayedAddress},
16 dns::resolve_domain,
17 endpoint_template::EndpointTemplate,
18 ready_channels::ReadyChannels,
19};
20
21#[derive(Debug, Clone)]
23pub struct ManagedChannelPoolBuilder {
24 endpoint: EndpointTemplate,
25 dns_interval: Duration,
26}
27
28impl ManagedChannelPoolBuilder {
29 #[must_use]
31 pub fn new(endpoint: impl Into<EndpointTemplate>) -> Self {
32 Self {
33 endpoint: endpoint.into(),
34 dns_interval: Duration::from_secs(5),
36 }
37 }
38
39 #[must_use]
44 pub fn dns_interval(&mut self, dns_interval: impl Into<Duration>) -> &mut Self {
45 self.dns_interval = dns_interval.into();
46 self
47 }
48
49 #[must_use]
55 pub fn build(self) -> ManagedChannelPool {
56 let ready_clients = Arc::new(ReadyChannels::default());
57 let broken_endpoints = Arc::new(BrokenEndpoints::default());
58
59 let dns_lookup_task = {
60 let ready_clients = ready_clients.clone();
62 let broken_endpoints = broken_endpoints.clone();
63 let endpoint = self.endpoint.clone();
64
65 tokio::spawn(async move {
66 let mut interval = interval(self.dns_interval);
67 loop {
68 check_dns(&endpoint, &ready_clients, &broken_endpoints).await;
69
70 interval.tick().await;
71 }
72 })
73 };
74
75 let doctor_task = {
76 let ready_clients = ready_clients.clone();
78 let broken_endpoints = broken_endpoints.clone();
79 let endpoint = self.endpoint.clone();
80
81 tokio::spawn(async move {
82 loop {
83 recheck_broken_endpoint(
85 broken_endpoints.next_broken_ip_address().await,
86 &endpoint,
87 &ready_clients,
88 &broken_endpoints,
89 )
90 .await;
91 }
92 })
93 };
94
95 ManagedChannelPool {
96 template: Arc::new(self.endpoint),
97 ready_clients,
98 broken_endpoints,
99 _dns_lookup_task: Arc::new(dns_lookup_task.into()),
100 _doctor_task: Arc::new(doctor_task.into()),
101 }
102 }
103}
104
105async fn check_dns(
106 endpoint_template: &EndpointTemplate,
107 ready_clients: &ReadyChannels,
108 broken_endpoints: &BrokenEndpoints,
109) {
110 let Ok(addresses) = resolve_domain(endpoint_template.domain()) else {
112 return;
117 };
118
119 let mut ready = Vec::new();
120 let mut broken = BinaryHeap::new();
121
122 for address in addresses {
123 if let Some(channel) = ready_clients.find(address).await {
125 trace!("Skipping {:?} as already ready", address);
126 ready.push((address, channel));
127 continue;
128 }
129
130 if let Some(entry) = broken_endpoints.get_address(address).await {
132 trace!("Skipping {:?} as already broken", address);
133 broken.push(entry);
134 continue;
135 }
136
137 debug!("Connecting to: {:?}", address);
138 let channel = endpoint_template.build(address).connect().await;
139 if let Ok(channel) = channel {
140 ready.push((address, channel));
141 } else {
142 broken.push(address.into());
143 }
144 }
145
146 ready_clients.replace_with(ready).await;
148 broken_endpoints.replace_with(broken).await;
149}
150
151async fn recheck_broken_endpoint(
152 address: DelayedAddress,
153 endpoint: &EndpointTemplate,
154 ready_clients: &ReadyChannels,
155 broken_endpoints: &BrokenEndpoints,
156) {
157 let connection_test_result = endpoint.build(*address).connect().await;
158
159 if let Ok(channel) = connection_test_result {
160 debug!("Connection established to {:?}", *address);
161 ready_clients.add(*address, channel).await;
162 } else {
163 debug!("Can't connect to {:?}", *address);
164 broken_endpoints.re_add_address(address).await;
165 }
166}
167
168#[derive(Debug, Default)]
169struct AbortOnDrop(Option<AbortHandle>);
170
171impl<T> From<JoinHandle<T>> for AbortOnDrop {
172 fn from(handle: JoinHandle<T>) -> Self {
173 Self(Some(handle.abort_handle()))
174 }
175}
176
177impl Drop for AbortOnDrop {
178 fn drop(&mut self) {
179 if let Some(handle) = self.0.take() {
180 handle.abort();
181 }
182 }
183}
184
/// A pool that hands out tonic [`Channel`]s, each paired with the IP address it
/// connects to.
#[async_trait]
pub trait ChannelPool {
    /// Returns a channel and its IP address, or `None` if the pool has none to offer
    /// right now.
    async fn get_channel(&self) -> Option<(IpAddr, Channel)>;

    /// Tells the pool that the channel for `ip_address` is not working, so the pool
    /// can stop handing it out.
    async fn report_broken(&self, ip_address: IpAddr);
}
201
/// A [`ChannelPool`] that keeps channels to the addresses an [`EndpointTemplate`]'s
/// domain resolves to, refreshing DNS and retrying broken addresses in the background.
///
/// Created by [`ManagedChannelPoolBuilder::build`]. Every field is an `Arc`, so clones
/// are cheap and share the same state and background tasks.
#[derive(Debug)]
pub struct ManagedChannelPool {
    // Template used to build an endpoint for each resolved address.
    template: Arc<EndpointTemplate>,
    // Addresses that currently have an established channel.
    ready_clients: Arc<ReadyChannels>,
    // Addresses whose connection failed or that were reported broken.
    broken_endpoints: Arc<BrokenEndpoints>,

    // Never read: held only so the background tasks are aborted when the last clone
    // of the pool is dropped.
    _dns_lookup_task: Arc<AbortOnDrop>,
    _doctor_task: Arc<AbortOnDrop>,
}
213
214#[async_trait]
215impl ChannelPool for ManagedChannelPool {
216 async fn get_channel(&self) -> Option<(IpAddr, Channel)> {
240 static RECHECK_BROKEN_ENDPOINTS: RwLock<DateTime<Utc>> =
241 RwLock::const_new(DateTime::<Utc>::MIN_UTC);
242 const MIN_INTERVAL: TimeDelta = TimeDelta::milliseconds(500);
243
244 if let Some(entry) = self.ready_clients.get_any().await {
245 return Some(entry);
246 }
247
248 let _guard = match RECHECK_BROKEN_ENDPOINTS.try_read() {
250 Ok(last_recheck_time)
251 if Utc::now().signed_duration_since(*last_recheck_time) < MIN_INTERVAL =>
252 {
253 return None;
254 }
255 Ok(guard) => {
256 drop(guard);
257 let mut guard = RECHECK_BROKEN_ENDPOINTS.write().await;
258 if let Some(entry) = self.ready_clients.get_any().await {
259 return Some(entry);
260 }
261 *guard = Utc::now();
262 guard
263 }
264 Err(_) => {
265 let _ = RECHECK_BROKEN_ENDPOINTS.write().await;
269 return self.ready_clients.get_any().await;
270 }
271 };
272
273 trace!("Force recheck of broken endpoints");
274
275 let mut fut = FuturesUnordered::new();
276 fut.push(
277 async {
278 check_dns(&self.template, &self.ready_clients, &self.broken_endpoints).await;
279 self.ready_clients.get_any().await
280 }
281 .boxed(),
282 );
283
284 for address in self.broken_endpoints.addresses().await.iter().copied() {
285 fut.push(
286 async move {
287 recheck_broken_endpoint(
288 address,
289 &self.template,
290 &self.ready_clients,
291 &self.broken_endpoints,
292 )
293 .await;
294 self.ready_clients.get_any().await
295 }
296 .boxed(),
297 );
298 }
299
300 fut.select_next_some().await
301 }
302
303 async fn report_broken(&self, ip_address: IpAddr) {
307 self.ready_clients.remove(ip_address).await;
308 self.broken_endpoints.add_address(ip_address).await;
309 }
310}
311
312impl ManagedChannelPool {
313 #[must_use]
315 pub fn endpoint(&self) -> &EndpointTemplate {
316 &self.template
317 }
318}
319
320impl Clone for ManagedChannelPool {
323 fn clone(&self) -> Self {
324 #[allow(clippy::used_underscore_binding)]
325 Self {
326 template: self.template.clone(),
327 ready_clients: self.ready_clients.clone(),
328 broken_endpoints: self.broken_endpoints.clone(),
329 _dns_lookup_task: self._dns_lookup_task.clone(),
330 _doctor_task: self._doctor_task.clone(),
331 }
332 }
333}