Skip to main content

ip_discovery/
resolver.rs

1//! Resolution engine that coordinates providers and applies strategies.
2//!
3//! This module is the core orchestrator: it takes a [`Config`](crate::Config),
4//! queries providers according to the chosen [`Strategy`](crate::Strategy),
5//! and returns a [`ProviderResult`].
6
7use crate::config::{Config, Strategy};
8use crate::error::{Error, ProviderError};
9use crate::types::ProviderResult;
10use std::collections::HashMap;
11use std::net::IpAddr;
12use std::time::Instant;
13use tokio::time::timeout;
14use tracing::{debug, warn};
15
16/// Coordinates IP detection across configured providers.
17///
18/// Created via [`Resolver::new()`] with a [`Config`](crate::Config).
19/// Call [`resolve()`](Resolver::resolve) to perform the lookup.
20pub struct Resolver {
21    config: Config,
22}
23
24impl Resolver {
25    /// Create a new resolver with the given configuration
26    pub fn new(config: Config) -> Self {
27        Self { config }
28    }
29
30    /// Resolve the public IP address using the configured strategy.
31    ///
32    /// # Errors
33    ///
34    /// - [`Error::NoProvidersForVersion`] — no provider supports the requested IP version.
35    /// - [`Error::AllProvidersFailed`] — every provider either failed or timed out.
36    /// - [`Error::ConsensusNotReached`] — (consensus strategy) too few providers agreed.
37    pub async fn resolve(&self) -> Result<ProviderResult, Error> {
38        let has_matching = self
39            .config
40            .providers
41            .iter()
42            .any(|p| p.supports_version(self.config.version));
43
44        if !has_matching {
45            return Err(Error::NoProvidersForVersion);
46        }
47
48        match self.config.strategy {
49            Strategy::First => self.resolve_first().await,
50            Strategy::Race => self.resolve_race().await,
51            Strategy::Consensus { min_agree } => {
52                let min = min_agree.max(2);
53                self.resolve_consensus(min).await
54            }
55        }
56    }
57
58    /// Try providers in order, return first success
59    async fn resolve_first(&self) -> Result<ProviderResult, Error> {
60        let mut errors = Vec::new();
61
62        for provider in self
63            .config
64            .providers
65            .iter()
66            .filter(|p| p.supports_version(self.config.version))
67        {
68            let start = Instant::now();
69            debug!(provider = provider.name(), "trying provider");
70
71            match timeout(self.config.timeout, provider.get_ip(self.config.version)).await {
72                Ok(Ok(ip)) => {
73                    let latency = start.elapsed();
74                    debug!(
75                        provider = provider.name(),
76                        ip = %ip,
77                        latency = ?latency,
78                        "got IP from provider"
79                    );
80                    return Ok(ProviderResult {
81                        ip,
82                        provider: provider.name().to_string(),
83                        protocol: provider.protocol(),
84                        latency,
85                    });
86                }
87                Ok(Err(e)) => {
88                    warn!(provider = provider.name(), error = %e, "provider failed");
89                    errors.push(e);
90                }
91                Err(_) => {
92                    warn!(provider = provider.name(), "provider timed out");
93                    errors.push(ProviderError::message(provider.name(), "timeout"));
94                }
95            }
96        }
97
98        Err(Error::AllProvidersFailed(errors))
99    }
100
101    /// Race all providers, return fastest result
102    async fn resolve_race(&self) -> Result<ProviderResult, Error> {
103        use futures_util::future::select_all;
104
105        let version = self.config.version;
106        let timeout_duration = self.config.timeout;
107
108        let futures: Vec<_> = self
109            .config
110            .providers
111            .iter()
112            .filter(|p| p.supports_version(version))
113            .map(|provider| {
114                let provider_name = provider.name().to_string();
115                let protocol = provider.protocol();
116                let start = Instant::now();
117                let fut = provider.get_ip(version);
118
119                Box::pin(async move {
120                    match timeout(timeout_duration, fut).await {
121                        Ok(Ok(ip)) => {
122                            let latency = start.elapsed();
123                            Ok(ProviderResult {
124                                ip,
125                                provider: provider_name,
126                                protocol,
127                                latency,
128                            })
129                        }
130                        Ok(Err(e)) => Err(e),
131                        Err(_) => Err(ProviderError::message(provider_name, "timeout")),
132                    }
133                })
134            })
135            .collect();
136
137        if futures.is_empty() {
138            return Err(Error::NoProvidersForVersion);
139        }
140
141        let mut futures = futures;
142        let mut errors = Vec::new();
143
144        while !futures.is_empty() {
145            let (result, _index, remaining) = select_all(futures).await;
146            futures = remaining;
147
148            match result {
149                Ok(provider_result) => {
150                    debug!(
151                        provider = %provider_result.provider,
152                        ip = %provider_result.ip,
153                        latency = ?provider_result.latency,
154                        "race won"
155                    );
156                    return Ok(provider_result);
157                }
158                Err(e) => {
159                    errors.push(e);
160                }
161            }
162        }
163
164        Err(Error::AllProvidersFailed(errors))
165    }
166
167    /// Query all providers and require consensus
168    async fn resolve_consensus(&self, min_agree: usize) -> Result<ProviderResult, Error> {
169        use futures_util::future::join_all;
170
171        let version = self.config.version;
172        let timeout_duration = self.config.timeout;
173
174        let futures: Vec<_> = self
175            .config
176            .providers
177            .iter()
178            .filter(|p| p.supports_version(version))
179            .map(|provider| {
180                let provider_name = provider.name().to_string();
181                let protocol = provider.protocol();
182                let start = Instant::now();
183                let fut = provider.get_ip(version);
184
185                async move {
186                    match timeout(timeout_duration, fut).await {
187                        Ok(Ok(ip)) => {
188                            let latency = start.elapsed();
189                            Some(ProviderResult {
190                                ip,
191                                provider: provider_name,
192                                protocol,
193                                latency,
194                            })
195                        }
196                        _ => None,
197                    }
198                }
199            })
200            .collect();
201
202        if futures.is_empty() {
203            return Err(Error::NoProvidersForVersion);
204        }
205
206        let results: Vec<Option<ProviderResult>> = join_all(futures).await;
207
208        let mut ip_results: HashMap<IpAddr, Vec<ProviderResult>> = HashMap::new();
209        for result in results.into_iter().flatten() {
210            ip_results.entry(result.ip).or_default().push(result);
211        }
212
213        let mut best: Option<(IpAddr, usize)> = None;
214        for (ip, providers) in &ip_results {
215            if providers.len() >= min_agree {
216                match &best {
217                    None => best = Some((*ip, providers.len())),
218                    Some((_, current_len)) if providers.len() > *current_len => {
219                        best = Some((*ip, providers.len()))
220                    }
221                    _ => {}
222                }
223            }
224        }
225
226        match best {
227            Some((ip, _)) => {
228                // Safety: `best` was chosen from `ip_results`, so the key always exists
229                let providers = ip_results.remove(&ip).unwrap_or_default();
230                let Some(fastest) = providers.into_iter().min_by_key(|p| p.latency) else {
231                    return Err(Error::ConsensusNotReached {
232                        required: min_agree,
233                        got: 0,
234                    });
235                };
236                debug!(
237                    ip = %ip,
238                    provider = %fastest.provider,
239                    "consensus reached"
240                );
241                Ok(fastest)
242            }
243            None => {
244                let max_agreement = ip_results.values().map(|v| v.len()).max().unwrap_or(0);
245                Err(Error::ConsensusNotReached {
246                    required: min_agree,
247                    got: max_agreement,
248                })
249            }
250        }
251    }
252}