1use crate::config::{Config, Strategy};
12use crate::error::{Error, ProviderError};
13use crate::provider::BoxedProvider;
14use crate::types::ProviderResult;
15use std::collections::HashMap;
16use std::future::Future;
17use std::net::IpAddr;
18use std::pin::Pin;
19use std::task::{Context, Poll};
20use std::time::Instant;
21use tokio::time::timeout;
22
/// Heap-allocated, pinned, `Send` future borrowing for `'a` that resolves to either a
/// [`ProviderResult`] or a [`ProviderError`]; the common type for per-provider lookups.
type BoxFut<'a> = Pin<Box<dyn Future<Output = Result<ProviderResult, ProviderError>> + Send + 'a>>;
25
/// Resolves an IP address by querying the providers listed in its [`Config`],
/// combining their answers according to the configured [`Strategy`].
pub struct Resolver {
    /// Providers, IP version, per-provider timeout and strategy used by every lookup.
    config: Config,
}
33
34impl Resolver {
35 pub fn new(config: Config) -> Self {
37 Self { config }
38 }
39
40 #[inline]
42 fn matching_providers(&self) -> impl Iterator<Item = &BoxedProvider> {
43 self.config
44 .providers
45 .iter()
46 .filter(|p| p.supports_version(self.config.version))
47 }
48
49 fn make_provider_future<'a>(&'a self, provider: &'a BoxedProvider) -> BoxFut<'a> {
52 let provider_name = provider.name().to_string();
53 let protocol = provider.protocol();
54 let start = Instant::now();
55 let fut = provider.get_ip(self.config.version);
56 let timeout_duration = self.config.timeout;
57
58 Box::pin(async move {
59 match timeout(timeout_duration, fut).await {
60 Ok(Ok(ip)) => {
61 let latency = start.elapsed();
62 Ok(ProviderResult {
63 ip,
64 provider: provider_name,
65 protocol,
66 latency,
67 })
68 }
69 Ok(Err(e)) => Err(e),
70 Err(_) => Err(ProviderError::message(provider_name, "timeout")),
71 }
72 })
73 }
74
75 pub async fn resolve(&self) -> Result<ProviderResult, Error> {
83 if self.matching_providers().next().is_none() {
84 return Err(Error::NoProvidersForVersion);
85 }
86
87 match self.config.strategy {
88 Strategy::First => self.resolve_first().await,
89 Strategy::Race => self.resolve_race().await,
90 Strategy::Consensus { min_agree } => self.resolve_consensus(min_agree).await,
91 }
92 }
93
94 async fn resolve_first(&self) -> Result<ProviderResult, Error> {
96 let mut errors = Vec::new();
97
98 for provider in self.matching_providers() {
99 match self.make_provider_future(provider).await {
100 Ok(result) => return Ok(result),
101 Err(e) => errors.push(e),
102 }
103 }
104
105 Err(Error::AllProvidersFailed(errors))
106 }
107
108 async fn resolve_race(&self) -> Result<ProviderResult, Error> {
110 let mut futures: Vec<BoxFut<'_>> = self
111 .matching_providers()
112 .map(|p| self.make_provider_future(p))
113 .collect();
114
115 if futures.is_empty() {
118 return Err(Error::NoProvidersForVersion);
119 }
120
121 let mut errors = Vec::new();
122
123 while !futures.is_empty() {
124 let (result, _index, remaining) = select_first(futures).await;
125 futures = remaining;
126
127 match result {
128 Ok(provider_result) => return Ok(provider_result),
129 Err(e) => errors.push(e),
130 }
131 }
132
133 Err(Error::AllProvidersFailed(errors))
134 }
135
136 async fn resolve_consensus(&self, min_agree: usize) -> Result<ProviderResult, Error> {
138 let futures: Vec<BoxFut<'_>> = self
139 .matching_providers()
140 .map(|p| self.make_provider_future(p))
141 .collect();
142
143 if futures.is_empty() {
144 return Err(Error::NoProvidersForVersion);
145 }
146
147 let all_results = join_all_vec(futures).await;
148
149 let mut ip_results: HashMap<IpAddr, Vec<ProviderResult>> = HashMap::new();
150 let mut errors = Vec::new();
151
152 for result in all_results {
153 match result {
154 Ok(pr) => ip_results.entry(pr.ip).or_default().push(pr),
155 Err(e) => errors.push(e),
156 }
157 }
158
159 let mut best: Option<(IpAddr, usize)> = None;
160 for (ip, providers) in &ip_results {
161 if providers.len() >= min_agree {
162 match &best {
163 None => best = Some((*ip, providers.len())),
164 Some((_, current_len)) if providers.len() > *current_len => {
165 best = Some((*ip, providers.len()))
166 }
167 _ => {}
168 }
169 }
170 }
171
172 match best {
173 Some((ip, _)) => {
174 if let Some(providers) = ip_results.remove(&ip) {
175 if let Some(fastest) = providers.into_iter().min_by_key(|p| p.latency) {
176 return Ok(fastest);
177 }
178 }
179 Err(Error::ConsensusNotReached {
180 required: min_agree,
181 got: 0,
182 errors,
183 })
184 }
185 None => {
186 let max_agreement = ip_results.values().map(|v| v.len()).max().unwrap_or(0);
187 Err(Error::ConsensusNotReached {
188 required: min_agree,
189 got: max_agreement,
190 errors,
191 })
192 }
193 }
194 }
195}
196
/// Waits until any future in `futures` completes.
///
/// Futures are polled in index order on every wake-up, so when several are ready at
/// once the one with the lowest index wins. Returns the winner's output, the index it
/// occupied in the input vector, and the remaining futures. The winner is removed with
/// `swap_remove`, so the former last element takes its slot in the returned vector.
///
/// With an empty `futures` there is nothing to complete and the returned future never
/// resolves; callers must check for emptiness first.
async fn select_first<F: Future + Unpin>(mut futures: Vec<F>) -> (F::Output, usize, Vec<F>) {
    std::future::poll_fn(|cx: &mut Context<'_>| {
        // `find_map` stops at the first ready future, leaving later ones unpolled.
        let winner = futures
            .iter_mut()
            .enumerate()
            .find_map(|(index, candidate)| match Pin::new(candidate).poll(cx) {
                Poll::Ready(output) => Some((index, output)),
                Poll::Pending => None,
            });

        match winner {
            Some((index, output)) => {
                futures.swap_remove(index);
                Poll::Ready((output, index, std::mem::take(&mut futures)))
            }
            None => Poll::Pending,
        }
    })
    .await
}
223
/// Drives every future in `futures` to completion and returns their outputs in the
/// same order as the input.
///
/// Each future is dropped as soon as it completes, so whatever it owns is released
/// without waiting for the slowest one. A completed future is never polled again.
/// An empty input resolves immediately to an empty vector.
async fn join_all_vec<T, F: Future<Output = T> + Unpin>(futures: Vec<F>) -> Vec<T> {
    /// One entry per input future: still running, or replaced by its output.
    enum Slot<Fut, Out> {
        Running(Fut),
        Finished(Out),
    }

    let mut slots: Vec<Slot<F, T>> = futures.into_iter().map(Slot::Running).collect();

    std::future::poll_fn(|cx: &mut Context<'_>| {
        let mut all_done = true;
        for slot in slots.iter_mut() {
            if let Slot::Running(fut) = slot {
                match Pin::new(fut).poll(cx) {
                    // Overwriting the slot drops the finished future right away.
                    Poll::Ready(output) => *slot = Slot::Finished(output),
                    Poll::Pending => all_done = false,
                }
            }
        }
        if all_done {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    })
    .await;

    slots
        .into_iter()
        .map(|slot| match slot {
            Slot::Finished(output) => output,
            // The poll loop above only finishes once no slot is still running.
            Slot::Running(_) => unreachable!("bug: join finished with a future still running"),
        })
        .collect()
}