chateau/client/conn/
dns.rs

1//! DNS resolution utilities.
2
3use std::collections::VecDeque;
4use std::convert::Infallible;
5use std::future::{Future, Ready, ready};
6use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7use std::task::{Context, Poll};
8
9/// A collection of socket addresses.
10#[derive(Debug, Clone, Default)]
11pub struct SocketAddrs(VecDeque<SocketAddr>);
12
13impl SocketAddrs {
14    #[allow(dead_code)]
15    pub(crate) fn set_port(&mut self, port: u16) {
16        for addr in &mut self.0 {
17            addr.set_port(port)
18        }
19    }
20
21    #[allow(dead_code)]
22    pub(crate) fn peek(&self) -> Option<SocketAddr> {
23        self.0.front().copied()
24    }
25
26    pub(crate) fn pop(&mut self) -> Option<SocketAddr> {
27        self.0.pop_front()
28    }
29
30    #[allow(dead_code)]
31    pub(crate) fn is_empty(&self) -> bool {
32        self.0.is_empty()
33    }
34
35    pub(crate) fn len(&self) -> usize {
36        self.0.len()
37    }
38
39    pub(crate) fn sort_preferred(&mut self, prefer: Option<IpVersion>) {
40        let mut v4_idx = None;
41        let mut v6_idx = None;
42
43        for (idx, addr) in self.0.iter().enumerate() {
44            match (addr.version(), v4_idx, v6_idx) {
45                (IpVersion::V4, None, _) => {
46                    v4_idx = Some(idx);
47                }
48                (IpVersion::V6, _, None) => {
49                    v6_idx = Some(idx);
50                }
51                (_, Some(_), Some(_)) => break,
52                _ => {}
53            }
54        }
55
56        let v4: Option<SocketAddr>;
57        let v6: Option<SocketAddr>;
58        if v4_idx.zip(v6_idx).is_some_and(|(v4, v6)| v4 > v6) {
59            v4 = v4_idx.and_then(|idx| self.0.remove(idx));
60            v6 = v6_idx.and_then(|idx| self.0.remove(idx));
61        } else {
62            v6 = v6_idx.and_then(|idx| self.0.remove(idx));
63            v4 = v4_idx.and_then(|idx| self.0.remove(idx));
64        }
65
66        match (prefer, v4, v6) {
67            (Some(IpVersion::V4), Some(addr_v4), Some(addr_v6)) => {
68                self.0.push_front(addr_v6);
69                self.0.push_front(addr_v4);
70            }
71            (Some(IpVersion::V6), Some(addr_v4), Some(addr_v6)) => {
72                self.0.push_front(addr_v4);
73                self.0.push_front(addr_v6);
74            }
75
76            (_, Some(addr_v4), Some(addr_v6)) => {
77                self.0.push_front(addr_v4);
78                self.0.push_front(addr_v6);
79            }
80            (_, Some(addr_v4), None) => {
81                self.0.push_front(addr_v4);
82            }
83            (_, None, Some(addr_v6)) => {
84                self.0.push_front(addr_v6);
85            }
86            _ => {}
87        }
88    }
89}
90
91impl From<SocketAddr> for SocketAddrs {
92    fn from(value: SocketAddr) -> Self {
93        let mut addrs = VecDeque::with_capacity(1);
94        addrs.push_front(value);
95        SocketAddrs(addrs)
96    }
97}
98
99impl FromIterator<SocketAddr> for SocketAddrs {
100    fn from_iter<T: IntoIterator<Item = SocketAddr>>(iter: T) -> Self {
101        Self(iter.into_iter().collect())
102    }
103}
104
105impl IntoIterator for SocketAddrs {
106    type Item = SocketAddr;
107    type IntoIter = std::collections::vec_deque::IntoIter<SocketAddr>;
108
109    fn into_iter(self) -> Self::IntoIter {
110        self.0.into_iter()
111    }
112}
113
114impl<'a> IntoIterator for &'a SocketAddrs {
115    type Item = &'a SocketAddr;
116    type IntoIter = std::collections::vec_deque::Iter<'a, SocketAddr>;
117
118    fn into_iter(self) -> Self::IntoIter {
119        self.0.iter()
120    }
121}
122
/// The IP family an address belongs to.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum IpVersion {
    /// IPv4
    V4,

    /// IPv6
    V6,
}

/// Reports the [`IpVersion`] of an address.
///
/// Implemented in this module for [`IpAddr`] and [`SocketAddr`].
pub trait IpVersionExt {
    /// The IP family of this address.
    fn version(&self) -> IpVersion;
}
138
139impl IpVersion {
140    pub(super) fn from_binding(
141        ip_v4_address: Option<Ipv4Addr>,
142        ip_v6_address: Option<Ipv6Addr>,
143    ) -> Option<Self> {
144        match (ip_v4_address, ip_v6_address) {
145            // Prefer IPv6 if both are available.
146            (Some(_), Some(_)) => Some(Self::V6),
147            (Some(_), None) => Some(Self::V4),
148            (None, Some(_)) => Some(Self::V6),
149            (None, None) => None,
150        }
151    }
152
153    /// Is this IP version IPv4?
154    #[allow(dead_code)]
155    pub fn is_v4(&self) -> bool {
156        matches!(self, Self::V4)
157    }
158
159    /// Is this IP version IPv6?
160    #[allow(dead_code)]
161    pub fn is_v6(&self) -> bool {
162        matches!(self, Self::V6)
163    }
164}
165
166impl IpVersionExt for SocketAddr {
167    fn version(&self) -> IpVersion {
168        match self {
169            SocketAddr::V4(_) => IpVersion::V4,
170            SocketAddr::V6(_) => IpVersion::V6,
171        }
172    }
173}
174
175impl IpVersionExt for IpAddr {
176    fn version(&self) -> IpVersion {
177        match self {
178            IpAddr::V4(_) => IpVersion::V4,
179            IpAddr::V6(_) => IpVersion::V6,
180        }
181    }
182}
183
/// Turns a borrowed request into a destination address.
///
/// A DNS lookup is the typical implementation, but any mapping from a
/// request to an address fits this shape.
pub trait Resolver<Request> {
    /// Address produced on success.
    type Address;

    /// Error reported either by `poll_ready` or by the resolution future.
    type Error;

    /// Future returned by [`Resolver::resolve`].
    type Future: Future<Output = Result<Self::Address, Self::Error>>;

    /// Poll whether the resolver can accept a new request.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>;

    /// Begin resolving `request`, returning a future for the outcome.
    fn resolve(&mut self, request: &Request) -> Self::Future;
}
203
204impl<T, F, R, A, E> Resolver<R> for T
205where
206    T: for<'a> tower::Service<&'a R, Response = A, Error = E, Future = F>,
207    F: Future<Output = Result<A, E>>,
208{
209    type Address = A;
210    type Error = E;
211    type Future = F;
212
213    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
214        tower::Service::poll_ready(self, cx)
215    }
216
217    fn resolve(&mut self, request: &R) -> Self::Future {
218        tower::Service::call(self, request)
219    }
220}
221
/// A resolver that disregards its request and always yields one fixed address.
///
/// Handy when every connection should go to the same place.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StaticResolver<A> {
    address: A,
}

impl<A> StaticResolver<A> {
    /// Build a resolver that always answers with `address`.
    pub fn new(address: A) -> Self {
        StaticResolver { address }
    }
}
237
238impl<R, A> tower::Service<&R> for StaticResolver<A>
239where
240    A: Clone,
241{
242    type Response = A;
243    type Error = Infallible;
244    type Future = Ready<Result<Self::Response, Self::Error>>;
245
246    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
247        Poll::Ready(Ok(()))
248    }
249
250    fn call(&mut self, _: &R) -> Self::Future {
251        ready(Ok(self.address.clone()))
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};

    use static_assertions::assert_impl_all;

    assert_impl_all!(SocketAddrs: Send, Sync, Clone, Default);
    assert_impl_all!(IpVersion: Send, Sync, Clone, Copy, PartialEq, Eq, std::hash::Hash);
    assert_impl_all!(StaticResolver<u32>: Send, Sync, Clone, Default, PartialEq, Eq);

    /// Poll `future` exactly once with a no-op waker and return the result.
    ///
    /// Every future exercised here is expected to be immediately ready.
    fn poll_once<F: Future>(future: F) -> Poll<F::Output> {
        let mut cx = Context::from_waker(std::task::Waker::noop());
        let mut future = Box::pin(future);
        future.as_mut().poll(&mut cx)
    }

    #[test]
    fn test_socket_addrs_default() {
        let addrs = SocketAddrs::default();
        assert!(addrs.is_empty());
        assert_eq!(addrs.len(), 0);
        assert_eq!(addrs.peek(), None);
    }

    #[test]
    fn test_socket_addrs_from_socket_addr() {
        let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let addrs = SocketAddrs::from(addr);

        assert!(!addrs.is_empty());
        assert_eq!(addrs.len(), 1);
        assert_eq!(addrs.peek(), Some(addr));
    }

    #[test]
    fn test_socket_addrs_from_iterator() {
        let addr1 = SocketAddr::from(([127, 0, 0, 1], 8080));
        let addr2 = SocketAddr::from(([127, 0, 0, 1], 8081));
        let vec_addrs = vec![addr1, addr2];

        let addrs: SocketAddrs = vec_addrs.into_iter().collect();

        assert!(!addrs.is_empty());
        assert_eq!(addrs.len(), 2);
        assert_eq!(addrs.peek(), Some(addr1));
    }

    #[test]
    fn test_socket_addrs_set_port() {
        let addr1 = SocketAddr::from(([127, 0, 0, 1], 8080));
        let addr2 = SocketAddr::from((std::net::Ipv6Addr::LOCALHOST, 8080));
        let mut addrs: SocketAddrs = vec![addr1, addr2].into_iter().collect();

        addrs.set_port(9090);

        let modified: Vec<SocketAddr> = addrs.into_iter().collect();
        assert_eq!(modified[0].port(), 9090);
        assert_eq!(modified[1].port(), 9090);
    }

    #[test]
    fn test_socket_addrs_pop() {
        let addr1 = SocketAddr::from(([127, 0, 0, 1], 8080));
        let addr2 = SocketAddr::from(([127, 0, 0, 1], 8081));
        let mut addrs: SocketAddrs = vec![addr1, addr2].into_iter().collect();

        assert_eq!(addrs.pop(), Some(addr1));
        assert_eq!(addrs.len(), 1);
        assert_eq!(addrs.peek(), Some(addr2));

        assert_eq!(addrs.pop(), Some(addr2));
        assert_eq!(addrs.len(), 0);
        assert!(addrs.is_empty());

        assert_eq!(addrs.pop(), None);
    }

    #[test]
    fn test_socket_addrs_sort_preferred_ipv4() {
        let v4_addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let v6_addr = SocketAddr::from((std::net::Ipv6Addr::LOCALHOST, 8080));
        let mut addrs: SocketAddrs = vec![v6_addr, v4_addr].into_iter().collect();

        addrs.sort_preferred(Some(IpVersion::V4));

        let sorted: Vec<SocketAddr> = addrs.into_iter().collect();
        assert_eq!(sorted[0], v4_addr);
        assert_eq!(sorted[1], v6_addr);
    }

    #[test]
    fn test_socket_addrs_sort_preferred_ipv6() {
        let v4_addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let v6_addr = SocketAddr::from((std::net::Ipv6Addr::LOCALHOST, 8080));
        let mut addrs: SocketAddrs = vec![v4_addr, v6_addr].into_iter().collect();

        addrs.sort_preferred(Some(IpVersion::V6));

        let sorted: Vec<SocketAddr> = addrs.into_iter().collect();
        assert_eq!(sorted[0], v6_addr);
        assert_eq!(sorted[1], v4_addr);
    }

    #[test]
    fn test_socket_addrs_sort_preferred_none() {
        let v4_addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let v6_addr = SocketAddr::from((std::net::Ipv6Addr::LOCALHOST, 8080));
        let mut addrs: SocketAddrs = vec![v4_addr, v6_addr].into_iter().collect();

        addrs.sort_preferred(None);

        let sorted: Vec<SocketAddr> = addrs.into_iter().collect();
        assert_eq!(sorted[0], v6_addr);
        assert_eq!(sorted[1], v4_addr);
    }

    #[test]
    fn test_socket_addrs_sort_preferred_single_v4() {
        let v4_addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let mut addrs: SocketAddrs = vec![v4_addr].into_iter().collect();

        addrs.sort_preferred(Some(IpVersion::V6));

        let sorted: Vec<SocketAddr> = addrs.into_iter().collect();
        assert_eq!(sorted[0], v4_addr);
    }

    #[test]
    fn test_socket_addrs_sort_preferred_keeps_remaining_order() {
        // Only the first address of each family moves; the rest keep their order.
        let v4_a = SocketAddr::from(([10, 0, 0, 1], 80));
        let v4_b = SocketAddr::from(([10, 0, 0, 2], 80));
        let v6_a = SocketAddr::from((Ipv6Addr::LOCALHOST, 80));
        let v6_b = SocketAddr::from((Ipv6Addr::UNSPECIFIED, 80));
        let mut addrs: SocketAddrs = vec![v4_a, v4_b, v6_a, v6_b].into_iter().collect();

        addrs.sort_preferred(Some(IpVersion::V6));

        let sorted: Vec<SocketAddr> = addrs.into_iter().collect();
        assert_eq!(sorted, vec![v6_a, v4_a, v4_b, v6_b]);
    }

    #[test]
    fn test_socket_addrs_sort_preferred_v4_found_after_v6() {
        let v4_a = SocketAddr::from(([10, 0, 0, 1], 80));
        let v6_a = SocketAddr::from((Ipv6Addr::LOCALHOST, 80));
        let v6_b = SocketAddr::from((Ipv6Addr::UNSPECIFIED, 80));
        let mut addrs: SocketAddrs = vec![v6_a, v6_b, v4_a].into_iter().collect();

        addrs.sort_preferred(Some(IpVersion::V4));

        let sorted: Vec<SocketAddr> = addrs.into_iter().collect();
        assert_eq!(sorted, vec![v4_a, v6_a, v6_b]);
    }

    #[test]
    fn test_socket_addrs_sort_preferred_empty() {
        let mut addrs = SocketAddrs::default();
        addrs.sort_preferred(Some(IpVersion::V4));
        assert!(addrs.is_empty());
    }

    #[test]
    fn test_socket_addrs_into_iterator() {
        let addr1 = SocketAddr::from(([127, 0, 0, 1], 8080));
        let addr2 = SocketAddr::from(([127, 0, 0, 1], 8081));
        let addrs: SocketAddrs = vec![addr1, addr2].into_iter().collect();

        let collected: Vec<SocketAddr> = addrs.into_iter().collect();
        assert_eq!(collected, vec![addr1, addr2]);
    }

    #[test]
    fn test_socket_addrs_iter_ref() {
        let addr1 = SocketAddr::from(([127, 0, 0, 1], 8080));
        let addr2 = SocketAddr::from(([127, 0, 0, 1], 8081));
        let addrs: SocketAddrs = vec![addr1, addr2].into_iter().collect();

        let collected: Vec<&SocketAddr> = (&addrs).into_iter().collect();
        assert_eq!(collected, vec![&addr1, &addr2]);
    }

    #[test]
    fn test_ip_version_from_binding() {
        let v4_addr = Ipv4Addr::LOCALHOST;
        let v6_addr = Ipv6Addr::LOCALHOST;

        assert_eq!(
            IpVersion::from_binding(Some(v4_addr), None),
            Some(IpVersion::V4)
        );
        assert_eq!(
            IpVersion::from_binding(None, Some(v6_addr)),
            Some(IpVersion::V6)
        );
        assert_eq!(
            IpVersion::from_binding(Some(v4_addr), Some(v6_addr)),
            Some(IpVersion::V6)
        );
        assert_eq!(IpVersion::from_binding(None, None), None);
    }

    #[test]
    fn test_ip_version_methods() {
        let v4 = IpVersion::V4;
        let v6 = IpVersion::V6;

        assert!(v4.is_v4());
        assert!(!v4.is_v6());

        assert!(!v6.is_v4());
        assert!(v6.is_v6());
    }

    #[test]
    fn test_ip_version_ext_socket_addr() {
        let v4_addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let v6_addr = SocketAddr::from((std::net::Ipv6Addr::LOCALHOST, 8080));

        assert_eq!(v4_addr.version(), IpVersion::V4);
        assert_eq!(v6_addr.version(), IpVersion::V6);
    }

    #[test]
    fn test_ip_version_ext_ip_addr() {
        let v4_ip = IpAddr::V4(Ipv4Addr::LOCALHOST);
        let v6_ip = IpAddr::V6(Ipv6Addr::LOCALHOST);

        assert_eq!(v4_ip.version(), IpVersion::V4);
        assert_eq!(v6_ip.version(), IpVersion::V6);
    }

    #[test]
    fn test_static_resolver_new() {
        let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let resolver = StaticResolver::new(addr);

        assert_eq!(resolver.address, addr);
    }

    #[test]
    fn test_static_resolver_default() {
        let resolver = StaticResolver::<u32>::default();
        assert_eq!(resolver.address, 0u32);
    }

    #[test]
    fn test_static_resolver_service() {
        let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let mut resolver = StaticResolver::new(addr);
        let mut cx = Context::from_waker(std::task::Waker::noop());

        // `StaticResolver` ignores the request value; any type will do.
        let request = String::from("test_request");

        // Exercise the real `tower::Service` impl rather than a hand-built future.
        assert!(matches!(
            tower::Service::<&String>::poll_ready(&mut resolver, &mut cx),
            Poll::Ready(Ok(()))
        ));

        match poll_once(tower::Service::call(&mut resolver, &request)) {
            Poll::Ready(Ok(result)) => assert_eq!(result, addr),
            Poll::Ready(Err(_)) => panic!("Expected Ok result"),
            Poll::Pending => panic!("Expected ready result"),
        }
    }

    #[test]
    fn test_static_resolver_as_resolver() {
        // Goes through the blanket `Resolver` impl for `tower::Service`.
        let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let mut resolver = StaticResolver::new(addr);
        let mut cx = Context::from_waker(std::task::Waker::noop());
        let request = String::from("test_request");

        assert!(matches!(
            Resolver::<String>::poll_ready(&mut resolver, &mut cx),
            Poll::Ready(Ok(()))
        ));

        match poll_once(Resolver::<String>::resolve(&mut resolver, &request)) {
            Poll::Ready(Ok(result)) => assert_eq!(result, addr),
            Poll::Ready(Err(_)) => panic!("Expected Ok result"),
            Poll::Pending => panic!("Expected ready result"),
        }
    }

    #[test]
    fn test_static_resolver_clone_eq() {
        let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let resolver1 = StaticResolver::new(addr);
        let resolver2 = resolver1.clone();

        assert_eq!(resolver1, resolver2);

        let resolver3 = StaticResolver::new(SocketAddr::from(([127, 0, 0, 1], 9090)));
        assert_ne!(resolver1, resolver3);
    }
}