Skip to main content

ip_discovery/
resolver.rs

1//! Resolution engine that coordinates providers and applies strategies.
2//!
3//! This module is the core orchestrator: it takes a [`Config`](crate::Config),
4//! queries providers according to the chosen [`Strategy`](crate::Strategy),
5//! and returns a [`ProviderResult`].
6//!
7//! The [`select_first`] and [`join_all_vec`] helper functions replace
8//! `futures::select_all` / `futures::join_all` to avoid pulling in the
9//! `futures-util` crate as a dependency.
10
11use crate::config::{Config, Strategy};
12use crate::error::{Error, ProviderError};
13use crate::provider::BoxedProvider;
14use crate::types::ProviderResult;
15use std::collections::HashMap;
16use std::future::Future;
17use std::net::IpAddr;
18use std::pin::Pin;
19use std::task::{Context, Poll};
20use std::time::Instant;
21use tokio::time::timeout;
22
23/// Boxed future returning a fallible provider result.
24type BoxFut<'a> = Pin<Box<dyn Future<Output = Result<ProviderResult, ProviderError>> + Send + 'a>>;
25
26/// Coordinates IP detection across configured providers.
27///
28/// Created via [`Resolver::new()`] with a [`Config`](crate::Config).
29/// Call [`resolve()`](Resolver::resolve) to perform the lookup.
30pub struct Resolver {
31    config: Config,
32}
33
34impl Resolver {
35    /// Create a new resolver with the given configuration
36    pub fn new(config: Config) -> Self {
37        Self { config }
38    }
39
40    /// Return an iterator over providers that support the configured IP version.
41    #[inline]
42    fn matching_providers(&self) -> impl Iterator<Item = &BoxedProvider> {
43        self.config
44            .providers
45            .iter()
46            .filter(|p| p.supports_version(self.config.version))
47    }
48
49    /// Wrap a single provider call in a timeout, returning a boxed future that
50    /// produces either a [`ProviderResult`] or a [`ProviderError`].
51    fn make_provider_future<'a>(&'a self, provider: &'a BoxedProvider) -> BoxFut<'a> {
52        let provider_name = provider.name().to_string();
53        let protocol = provider.protocol();
54        let start = Instant::now();
55        let fut = provider.get_ip(self.config.version);
56        let timeout_duration = self.config.timeout;
57
58        Box::pin(async move {
59            match timeout(timeout_duration, fut).await {
60                Ok(Ok(ip)) => {
61                    let latency = start.elapsed();
62                    Ok(ProviderResult {
63                        ip,
64                        provider: provider_name,
65                        protocol,
66                        latency,
67                    })
68                }
69                Ok(Err(e)) => Err(e),
70                Err(_) => Err(ProviderError::message(provider_name, "timeout")),
71            }
72        })
73    }
74
75    /// Resolve the public IP address using the configured strategy.
76    ///
77    /// # Errors
78    ///
79    /// - [`Error::NoProvidersForVersion`] — no provider supports the requested IP version.
80    /// - [`Error::AllProvidersFailed`] — every provider either failed or timed out.
81    /// - [`Error::ConsensusNotReached`] — (consensus strategy) too few providers agreed.
82    pub async fn resolve(&self) -> Result<ProviderResult, Error> {
83        if self.matching_providers().next().is_none() {
84            return Err(Error::NoProvidersForVersion);
85        }
86
87        match self.config.strategy {
88            Strategy::First => self.resolve_first().await,
89            Strategy::Race => self.resolve_race().await,
90            Strategy::Consensus { min_agree } => self.resolve_consensus(min_agree).await,
91        }
92    }
93
94    /// Try providers in order, return first success.
95    async fn resolve_first(&self) -> Result<ProviderResult, Error> {
96        let mut errors = Vec::new();
97
98        for provider in self.matching_providers() {
99            match self.make_provider_future(provider).await {
100                Ok(result) => return Ok(result),
101                Err(e) => errors.push(e),
102            }
103        }
104
105        Err(Error::AllProvidersFailed(errors))
106    }
107
108    /// Race all providers concurrently, return fastest success.
109    async fn resolve_race(&self) -> Result<ProviderResult, Error> {
110        let mut futures: Vec<BoxFut<'_>> = self
111            .matching_providers()
112            .map(|p| self.make_provider_future(p))
113            .collect();
114
115        // Defensive: matching_providers() was already checked in resolve(),
116        // but guard against direct calls to this method.
117        if futures.is_empty() {
118            return Err(Error::NoProvidersForVersion);
119        }
120
121        let mut errors = Vec::new();
122
123        while !futures.is_empty() {
124            let (result, _index, remaining) = select_first(futures).await;
125            futures = remaining;
126
127            match result {
128                Ok(provider_result) => return Ok(provider_result),
129                Err(e) => errors.push(e),
130            }
131        }
132
133        Err(Error::AllProvidersFailed(errors))
134    }
135
136    /// Query all providers and require consensus.
137    async fn resolve_consensus(&self, min_agree: usize) -> Result<ProviderResult, Error> {
138        let futures: Vec<BoxFut<'_>> = self
139            .matching_providers()
140            .map(|p| self.make_provider_future(p))
141            .collect();
142
143        if futures.is_empty() {
144            return Err(Error::NoProvidersForVersion);
145        }
146
147        let all_results = join_all_vec(futures).await;
148
149        let mut ip_results: HashMap<IpAddr, Vec<ProviderResult>> = HashMap::new();
150        let mut errors = Vec::new();
151
152        for result in all_results {
153            match result {
154                Ok(pr) => ip_results.entry(pr.ip).or_default().push(pr),
155                Err(e) => errors.push(e),
156            }
157        }
158
159        let mut best: Option<(IpAddr, usize)> = None;
160        for (ip, providers) in &ip_results {
161            if providers.len() >= min_agree {
162                match &best {
163                    None => best = Some((*ip, providers.len())),
164                    Some((_, current_len)) if providers.len() > *current_len => {
165                        best = Some((*ip, providers.len()))
166                    }
167                    _ => {}
168                }
169            }
170        }
171
172        match best {
173            Some((ip, _)) => {
174                if let Some(providers) = ip_results.remove(&ip) {
175                    if let Some(fastest) = providers.into_iter().min_by_key(|p| p.latency) {
176                        return Ok(fastest);
177                    }
178                }
179                Err(Error::ConsensusNotReached {
180                    required: min_agree,
181                    got: 0,
182                    errors,
183                })
184            }
185            None => {
186                let max_agreement = ip_results.values().map(|v| v.len()).max().unwrap_or(0);
187                Err(Error::ConsensusNotReached {
188                    required: min_agree,
189                    got: max_agreement,
190                    errors,
191                })
192            }
193        }
194    }
195}
196
/// Wait until any future in `futures` completes.
///
/// Returns a tuple of:
/// 1. the output of the future that finished,
/// 2. that future's index in the vec **as it was passed in**, and
/// 3. the futures that are still pending.
///
/// The pending futures come back **unordered**: the finished one is taken out
/// with `swap_remove`, so the former last element now sits at the returned
/// index.
///
/// Futures are polled front to back on every wake-up, so when several are
/// ready during the same poll the one with the lowest index is returned.
///
/// If `futures` is empty the returned future never resolves (nothing is
/// polled, so no waker is ever registered); callers must check for emptiness
/// first.
///
/// Stands in for `futures::select_all` so that crate is not needed.
///
/// # Polling safety
///
/// `F: Unpin`, so pinning each element with `Pin::new` requires no `unsafe`.
/// Every sub-future is polled with the same task context, so a wake-up from
/// any of them causes the whole closure to run again.
async fn select_first<F: Future + Unpin>(mut futures: Vec<F>) -> (F::Output, usize, Vec<F>) {
    std::future::poll_fn(|cx: &mut Context<'_>| {
        // Stop at the first future that reports completion.
        let finished = futures
            .iter_mut()
            .enumerate()
            .find_map(|(idx, fut)| match Pin::new(fut).poll(cx) {
                Poll::Ready(output) => Some((idx, output)),
                Poll::Pending => None,
            });

        match finished {
            Some((idx, output)) => {
                // Discard the completed future so it is never polled again.
                drop(futures.swap_remove(idx));
                let remaining = std::mem::take(&mut futures);
                Poll::Ready((output, idx, remaining))
            }
            None => Poll::Pending,
        }
    })
    .await
}
223
/// Drive every future in `futures` to completion and collect the outputs.
///
/// The returned vec has one entry per input future, in the same order as the
/// input, regardless of the order in which the futures finished. An empty
/// input resolves immediately to an empty vec.
///
/// Stands in for `futures::join_all` so that crate is not needed.
///
/// # Polling safety
///
/// `F: Unpin`, so pinning each element with `Pin::new` requires no `unsafe`.
/// A future whose slot is already filled is skipped, so no future is polled
/// after it has completed and the outstanding count is decremented exactly
/// once per future.
async fn join_all_vec<T, F: Future<Output = T> + Unpin>(mut futures: Vec<F>) -> Vec<T> {
    // One output slot per future, filled in as each one completes.
    let mut slots: Vec<Option<T>> = std::iter::repeat_with(|| None)
        .take(futures.len())
        .collect();
    let mut outstanding = futures.len();

    std::future::poll_fn(|cx: &mut Context<'_>| {
        for (fut, slot) in futures.iter_mut().zip(slots.iter_mut()) {
            if slot.is_none() {
                if let Poll::Ready(output) = Pin::new(fut).poll(cx) {
                    *slot = Some(output);
                    outstanding -= 1;
                }
            }
        }

        if outstanding == 0 {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    })
    .await;

    let mut outputs = Vec::with_capacity(slots.len());
    for slot in slots {
        outputs.push(slot.expect("bug: future completed but result slot is empty"));
    }
    outputs
}
259}