1use crate::config::{Config, Strategy};
8use crate::error::{Error, ProviderError};
9use crate::types::ProviderResult;
10use std::collections::HashMap;
11use std::net::IpAddr;
12use std::time::Instant;
13use tokio::time::timeout;
14use tracing::{debug, warn};
15
16pub struct Resolver {
21 config: Config,
22}
23
24impl Resolver {
25 pub fn new(config: Config) -> Self {
27 Self { config }
28 }
29
30 pub async fn resolve(&self) -> Result<ProviderResult, Error> {
38 let has_matching = self
39 .config
40 .providers
41 .iter()
42 .any(|p| p.supports_version(self.config.version));
43
44 if !has_matching {
45 return Err(Error::NoProvidersForVersion);
46 }
47
48 match self.config.strategy {
49 Strategy::First => self.resolve_first().await,
50 Strategy::Race => self.resolve_race().await,
51 Strategy::Consensus { min_agree } => {
52 let min = min_agree.max(2);
53 self.resolve_consensus(min).await
54 }
55 }
56 }
57
58 async fn resolve_first(&self) -> Result<ProviderResult, Error> {
60 let mut errors = Vec::new();
61
62 for provider in self
63 .config
64 .providers
65 .iter()
66 .filter(|p| p.supports_version(self.config.version))
67 {
68 let start = Instant::now();
69 debug!(provider = provider.name(), "trying provider");
70
71 match timeout(self.config.timeout, provider.get_ip(self.config.version)).await {
72 Ok(Ok(ip)) => {
73 let latency = start.elapsed();
74 debug!(
75 provider = provider.name(),
76 ip = %ip,
77 latency = ?latency,
78 "got IP from provider"
79 );
80 return Ok(ProviderResult {
81 ip,
82 provider: provider.name().to_string(),
83 protocol: provider.protocol(),
84 latency,
85 });
86 }
87 Ok(Err(e)) => {
88 warn!(provider = provider.name(), error = %e, "provider failed");
89 errors.push(e);
90 }
91 Err(_) => {
92 warn!(provider = provider.name(), "provider timed out");
93 errors.push(ProviderError::message(provider.name(), "timeout"));
94 }
95 }
96 }
97
98 Err(Error::AllProvidersFailed(errors))
99 }
100
101 async fn resolve_race(&self) -> Result<ProviderResult, Error> {
103 use futures_util::future::select_all;
104
105 let version = self.config.version;
106 let timeout_duration = self.config.timeout;
107
108 let futures: Vec<_> = self
109 .config
110 .providers
111 .iter()
112 .filter(|p| p.supports_version(version))
113 .map(|provider| {
114 let provider_name = provider.name().to_string();
115 let protocol = provider.protocol();
116 let start = Instant::now();
117 let fut = provider.get_ip(version);
118
119 Box::pin(async move {
120 match timeout(timeout_duration, fut).await {
121 Ok(Ok(ip)) => {
122 let latency = start.elapsed();
123 Ok(ProviderResult {
124 ip,
125 provider: provider_name,
126 protocol,
127 latency,
128 })
129 }
130 Ok(Err(e)) => Err(e),
131 Err(_) => Err(ProviderError::message(provider_name, "timeout")),
132 }
133 })
134 })
135 .collect();
136
137 if futures.is_empty() {
138 return Err(Error::NoProvidersForVersion);
139 }
140
141 let mut futures = futures;
142 let mut errors = Vec::new();
143
144 while !futures.is_empty() {
145 let (result, _index, remaining) = select_all(futures).await;
146 futures = remaining;
147
148 match result {
149 Ok(provider_result) => {
150 debug!(
151 provider = %provider_result.provider,
152 ip = %provider_result.ip,
153 latency = ?provider_result.latency,
154 "race won"
155 );
156 return Ok(provider_result);
157 }
158 Err(e) => {
159 errors.push(e);
160 }
161 }
162 }
163
164 Err(Error::AllProvidersFailed(errors))
165 }
166
167 async fn resolve_consensus(&self, min_agree: usize) -> Result<ProviderResult, Error> {
169 use futures_util::future::join_all;
170
171 let version = self.config.version;
172 let timeout_duration = self.config.timeout;
173
174 let futures: Vec<_> = self
175 .config
176 .providers
177 .iter()
178 .filter(|p| p.supports_version(version))
179 .map(|provider| {
180 let provider_name = provider.name().to_string();
181 let protocol = provider.protocol();
182 let start = Instant::now();
183 let fut = provider.get_ip(version);
184
185 async move {
186 match timeout(timeout_duration, fut).await {
187 Ok(Ok(ip)) => {
188 let latency = start.elapsed();
189 Some(ProviderResult {
190 ip,
191 provider: provider_name,
192 protocol,
193 latency,
194 })
195 }
196 _ => None,
197 }
198 }
199 })
200 .collect();
201
202 if futures.is_empty() {
203 return Err(Error::NoProvidersForVersion);
204 }
205
206 let results: Vec<Option<ProviderResult>> = join_all(futures).await;
207
208 let mut ip_results: HashMap<IpAddr, Vec<ProviderResult>> = HashMap::new();
209 for result in results.into_iter().flatten() {
210 ip_results.entry(result.ip).or_default().push(result);
211 }
212
213 let mut best: Option<(IpAddr, usize)> = None;
214 for (ip, providers) in &ip_results {
215 if providers.len() >= min_agree {
216 match &best {
217 None => best = Some((*ip, providers.len())),
218 Some((_, current_len)) if providers.len() > *current_len => {
219 best = Some((*ip, providers.len()))
220 }
221 _ => {}
222 }
223 }
224 }
225
226 match best {
227 Some((ip, _)) => {
228 let providers = ip_results.remove(&ip).unwrap_or_default();
230 let Some(fastest) = providers.into_iter().min_by_key(|p| p.latency) else {
231 return Err(Error::ConsensusNotReached {
232 required: min_agree,
233 got: 0,
234 });
235 };
236 debug!(
237 ip = %ip,
238 provider = %fastest.provider,
239 "consensus reached"
240 );
241 Ok(fastest)
242 }
243 None => {
244 let max_agreement = ip_results.values().map(|v| v.len()).max().unwrap_or(0);
245 Err(Error::ConsensusNotReached {
246 required: min_agree,
247 got: max_agreement,
248 })
249 }
250 }
251 }
252}