1use std::collections::VecDeque;
4use std::convert::Infallible;
5use std::future::{Future, Ready, ready};
6use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7use std::task::{Context, Poll};
8
9#[derive(Debug, Clone, Default)]
11pub struct SocketAddrs(VecDeque<SocketAddr>);
12
13impl SocketAddrs {
14 #[allow(dead_code)]
15 pub(crate) fn set_port(&mut self, port: u16) {
16 for addr in &mut self.0 {
17 addr.set_port(port)
18 }
19 }
20
21 #[allow(dead_code)]
22 pub(crate) fn peek(&self) -> Option<SocketAddr> {
23 self.0.front().copied()
24 }
25
26 pub(crate) fn pop(&mut self) -> Option<SocketAddr> {
27 self.0.pop_front()
28 }
29
30 #[allow(dead_code)]
31 pub(crate) fn is_empty(&self) -> bool {
32 self.0.is_empty()
33 }
34
35 pub(crate) fn len(&self) -> usize {
36 self.0.len()
37 }
38
39 pub(crate) fn sort_preferred(&mut self, prefer: Option<IpVersion>) {
40 let mut v4_idx = None;
41 let mut v6_idx = None;
42
43 for (idx, addr) in self.0.iter().enumerate() {
44 match (addr.version(), v4_idx, v6_idx) {
45 (IpVersion::V4, None, _) => {
46 v4_idx = Some(idx);
47 }
48 (IpVersion::V6, _, None) => {
49 v6_idx = Some(idx);
50 }
51 (_, Some(_), Some(_)) => break,
52 _ => {}
53 }
54 }
55
56 let v4: Option<SocketAddr>;
57 let v6: Option<SocketAddr>;
58 if v4_idx.zip(v6_idx).is_some_and(|(v4, v6)| v4 > v6) {
59 v4 = v4_idx.and_then(|idx| self.0.remove(idx));
60 v6 = v6_idx.and_then(|idx| self.0.remove(idx));
61 } else {
62 v6 = v6_idx.and_then(|idx| self.0.remove(idx));
63 v4 = v4_idx.and_then(|idx| self.0.remove(idx));
64 }
65
66 match (prefer, v4, v6) {
67 (Some(IpVersion::V4), Some(addr_v4), Some(addr_v6)) => {
68 self.0.push_front(addr_v6);
69 self.0.push_front(addr_v4);
70 }
71 (Some(IpVersion::V6), Some(addr_v4), Some(addr_v6)) => {
72 self.0.push_front(addr_v4);
73 self.0.push_front(addr_v6);
74 }
75
76 (_, Some(addr_v4), Some(addr_v6)) => {
77 self.0.push_front(addr_v4);
78 self.0.push_front(addr_v6);
79 }
80 (_, Some(addr_v4), None) => {
81 self.0.push_front(addr_v4);
82 }
83 (_, None, Some(addr_v6)) => {
84 self.0.push_front(addr_v6);
85 }
86 _ => {}
87 }
88 }
89}
90
91impl From<SocketAddr> for SocketAddrs {
92 fn from(value: SocketAddr) -> Self {
93 let mut addrs = VecDeque::with_capacity(1);
94 addrs.push_front(value);
95 SocketAddrs(addrs)
96 }
97}
98
99impl FromIterator<SocketAddr> for SocketAddrs {
100 fn from_iter<T: IntoIterator<Item = SocketAddr>>(iter: T) -> Self {
101 Self(iter.into_iter().collect())
102 }
103}
104
105impl IntoIterator for SocketAddrs {
106 type Item = SocketAddr;
107 type IntoIter = std::collections::vec_deque::IntoIter<SocketAddr>;
108
109 fn into_iter(self) -> Self::IntoIter {
110 self.0.into_iter()
111 }
112}
113
114impl<'a> IntoIterator for &'a SocketAddrs {
115 type Item = &'a SocketAddr;
116 type IntoIter = std::collections::vec_deque::Iter<'a, SocketAddr>;
117
118 fn into_iter(self) -> Self::IntoIter {
119 self.0.iter()
120 }
121}
122
123pub trait IpVersionExt {
125 fn version(&self) -> IpVersion;
127}
128
129#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
131pub enum IpVersion {
132 V4,
134
135 V6,
137}
138
139impl IpVersion {
140 pub(super) fn from_binding(
141 ip_v4_address: Option<Ipv4Addr>,
142 ip_v6_address: Option<Ipv6Addr>,
143 ) -> Option<Self> {
144 match (ip_v4_address, ip_v6_address) {
145 (Some(_), Some(_)) => Some(Self::V6),
147 (Some(_), None) => Some(Self::V4),
148 (None, Some(_)) => Some(Self::V6),
149 (None, None) => None,
150 }
151 }
152
153 #[allow(dead_code)]
155 pub fn is_v4(&self) -> bool {
156 matches!(self, Self::V4)
157 }
158
159 #[allow(dead_code)]
161 pub fn is_v6(&self) -> bool {
162 matches!(self, Self::V6)
163 }
164}
165
166impl IpVersionExt for SocketAddr {
167 fn version(&self) -> IpVersion {
168 match self {
169 SocketAddr::V4(_) => IpVersion::V4,
170 SocketAddr::V6(_) => IpVersion::V6,
171 }
172 }
173}
174
175impl IpVersionExt for IpAddr {
176 fn version(&self) -> IpVersion {
177 match self {
178 IpAddr::V4(_) => IpVersion::V4,
179 IpAddr::V6(_) => IpVersion::V6,
180 }
181 }
182}
183
/// Asynchronously turns a borrowed `Request` into an address.
///
/// The shape mirrors a readiness-then-call service: callers poll
/// [`Resolver::poll_ready`] and then invoke [`Resolver::resolve`].
pub trait Resolver<Request> {
    /// The value produced by a successful resolution.
    type Address;

    /// The error produced when readiness or resolution fails.
    type Error;

    /// The future returned by [`Resolver::resolve`].
    type Future: Future<Output = Result<Self::Address, Self::Error>>;

    /// Polls whether the resolver is ready to accept a request.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>;

    /// Starts resolving `request`, returning a future for the result.
    fn resolve(&mut self, request: &Request) -> Self::Future;
}
203
204impl<T, F, R, A, E> Resolver<R> for T
205where
206 T: for<'a> tower::Service<&'a R, Response = A, Error = E, Future = F>,
207 F: Future<Output = Result<A, E>>,
208{
209 type Address = A;
210 type Error = E;
211 type Future = F;
212
213 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
214 tower::Service::poll_ready(self, cx)
215 }
216
217 fn resolve(&mut self, request: &R) -> Self::Future {
218 tower::Service::call(self, request)
219 }
220}
221
/// A resolver backed by a single fixed address.
///
/// Its `tower::Service` implementation ignores the request and returns a
/// clone of the stored address.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StaticResolver<A> {
    address: A,
}

impl<A> StaticResolver<A> {
    /// Creates a resolver that stores `address`.
    pub fn new(address: A) -> Self {
        StaticResolver { address }
    }
}
237
238impl<R, A> tower::Service<&R> for StaticResolver<A>
239where
240 A: Clone,
241{
242 type Response = A;
243 type Error = Infallible;
244 type Future = Ready<Result<Self::Response, Self::Error>>;
245
246 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
247 Poll::Ready(Ok(()))
248 }
249
250 fn call(&mut self, _: &R) -> Self::Future {
251 ready(Ok(self.address.clone()))
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};

    use static_assertions::assert_impl_all;

    // Compile-time checks that the public types keep their auto and derived traits.
    assert_impl_all!(SocketAddrs: Send, Sync, Clone, Default);
    assert_impl_all!(IpVersion: Send, Sync, Clone, Copy, PartialEq, Eq, std::hash::Hash);
    assert_impl_all!(StaticResolver<u32>: Send, Sync, Clone, Default, PartialEq, Eq);

    /// A default queue is empty.
    #[test]
    fn test_socket_addrs_default() {
        let addrs = SocketAddrs::default();
        assert!(addrs.is_empty());
        assert_eq!(addrs.len(), 0);
        assert_eq!(addrs.peek(), None);
    }

    /// Converting a single address yields a one-element queue.
    #[test]
    fn test_socket_addrs_from_socket_addr() {
        let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let addrs = SocketAddrs::from(addr);

        assert!(!addrs.is_empty());
        assert_eq!(addrs.len(), 1);
        assert_eq!(addrs.peek(), Some(addr));
    }

    /// Collecting from an iterator preserves order.
    #[test]
    fn test_socket_addrs_from_iterator() {
        let addr1 = SocketAddr::from(([127, 0, 0, 1], 8080));
        let addr2 = SocketAddr::from(([127, 0, 0, 1], 8081));
        let vec_addrs = vec![addr1, addr2];

        let addrs: SocketAddrs = vec_addrs.into_iter().collect();

        assert!(!addrs.is_empty());
        assert_eq!(addrs.len(), 2);
        assert_eq!(addrs.peek(), Some(addr1));
    }

    /// `set_port` rewrites the port of every address, whatever its family.
    #[test]
    fn test_socket_addrs_set_port() {
        let addr1 = SocketAddr::from(([127, 0, 0, 1], 8080));
        let addr2 = SocketAddr::from((std::net::Ipv6Addr::LOCALHOST, 8080));
        let mut addrs: SocketAddrs = vec![addr1, addr2].into_iter().collect();

        addrs.set_port(9090);

        let modified: Vec<SocketAddr> = addrs.into_iter().collect();
        assert_eq!(modified[0].port(), 9090);
        assert_eq!(modified[1].port(), 9090);
    }

    /// `pop` drains the queue front-to-back and then returns `None`.
    #[test]
    fn test_socket_addrs_pop() {
        let addr1 = SocketAddr::from(([127, 0, 0, 1], 8080));
        let addr2 = SocketAddr::from(([127, 0, 0, 1], 8081));
        let mut addrs: SocketAddrs = vec![addr1, addr2].into_iter().collect();

        assert_eq!(addrs.pop(), Some(addr1));
        assert_eq!(addrs.len(), 1);
        assert_eq!(addrs.peek(), Some(addr2));

        assert_eq!(addrs.pop(), Some(addr2));
        assert_eq!(addrs.len(), 0);
        assert!(addrs.is_empty());

        assert_eq!(addrs.pop(), None);
    }

    /// Preferring IPv4 puts the IPv4 address ahead of the IPv6 one.
    #[test]
    fn test_socket_addrs_sort_preferred_ipv4() {
        let v4_addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let v6_addr = SocketAddr::from((std::net::Ipv6Addr::LOCALHOST, 8080));
        let mut addrs: SocketAddrs = vec![v6_addr, v4_addr].into_iter().collect();

        addrs.sort_preferred(Some(IpVersion::V4));

        let sorted: Vec<SocketAddr> = addrs.into_iter().collect();
        assert_eq!(sorted[0], v4_addr);
        assert_eq!(sorted[1], v6_addr);
    }

    /// Preferring IPv6 puts the IPv6 address ahead of the IPv4 one.
    #[test]
    fn test_socket_addrs_sort_preferred_ipv6() {
        let v4_addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let v6_addr = SocketAddr::from((std::net::Ipv6Addr::LOCALHOST, 8080));
        let mut addrs: SocketAddrs = vec![v4_addr, v6_addr].into_iter().collect();

        addrs.sort_preferred(Some(IpVersion::V6));

        let sorted: Vec<SocketAddr> = addrs.into_iter().collect();
        assert_eq!(sorted[0], v6_addr);
        assert_eq!(sorted[1], v4_addr);
    }

    /// With no preference, IPv6 leads.
    #[test]
    fn test_socket_addrs_sort_preferred_none() {
        let v4_addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let v6_addr = SocketAddr::from((std::net::Ipv6Addr::LOCALHOST, 8080));
        let mut addrs: SocketAddrs = vec![v4_addr, v6_addr].into_iter().collect();

        addrs.sort_preferred(None);

        let sorted: Vec<SocketAddr> = addrs.into_iter().collect();
        assert_eq!(sorted[0], v6_addr);
        assert_eq!(sorted[1], v4_addr);
    }

    /// A lone IPv4 address stays in place even when IPv6 is preferred.
    #[test]
    fn test_socket_addrs_sort_preferred_single_v4() {
        let v4_addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let mut addrs: SocketAddrs = vec![v4_addr].into_iter().collect();

        addrs.sort_preferred(Some(IpVersion::V6));

        let sorted: Vec<SocketAddr> = addrs.into_iter().collect();
        assert_eq!(sorted[0], v4_addr);
    }

    /// Owned iteration yields the addresses in queue order.
    #[test]
    fn test_socket_addrs_into_iterator() {
        let addr1 = SocketAddr::from(([127, 0, 0, 1], 8080));
        let addr2 = SocketAddr::from(([127, 0, 0, 1], 8081));
        let addrs: SocketAddrs = vec![addr1, addr2].into_iter().collect();

        let collected: Vec<SocketAddr> = addrs.into_iter().collect();
        assert_eq!(collected, vec![addr1, addr2]);
    }

    /// Borrowed iteration yields references in queue order.
    #[test]
    fn test_socket_addrs_iter_ref() {
        let addr1 = SocketAddr::from(([127, 0, 0, 1], 8080));
        let addr2 = SocketAddr::from(([127, 0, 0, 1], 8081));
        let addrs: SocketAddrs = vec![addr1, addr2].into_iter().collect();

        let collected: Vec<&SocketAddr> = (&addrs).into_iter().collect();
        assert_eq!(collected, vec![&addr1, &addr2]);
    }

    /// `from_binding` picks IPv6 whenever an IPv6 address is present.
    #[test]
    fn test_ip_version_from_binding() {
        let v4_addr = Ipv4Addr::LOCALHOST;
        let v6_addr = Ipv6Addr::LOCALHOST;

        assert_eq!(
            IpVersion::from_binding(Some(v4_addr), None),
            Some(IpVersion::V4)
        );
        assert_eq!(
            IpVersion::from_binding(None, Some(v6_addr)),
            Some(IpVersion::V6)
        );
        assert_eq!(
            IpVersion::from_binding(Some(v4_addr), Some(v6_addr)),
            Some(IpVersion::V6)
        );
        assert_eq!(IpVersion::from_binding(None, None), None);
    }

    /// `is_v4` / `is_v6` match exactly their own variant.
    #[test]
    fn test_ip_version_methods() {
        let v4 = IpVersion::V4;
        let v6 = IpVersion::V6;

        assert!(v4.is_v4());
        assert!(!v4.is_v6());

        assert!(!v6.is_v4());
        assert!(v6.is_v6());
    }

    /// `SocketAddr::version` follows the address variant.
    #[test]
    fn test_ip_version_ext_socket_addr() {
        let v4_addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let v6_addr = SocketAddr::from((std::net::Ipv6Addr::LOCALHOST, 8080));

        assert_eq!(v4_addr.version(), IpVersion::V4);
        assert_eq!(v6_addr.version(), IpVersion::V6);
    }

    /// `IpAddr::version` follows the address variant.
    #[test]
    fn test_ip_version_ext_ip_addr() {
        let v4_ip = IpAddr::V4(Ipv4Addr::LOCALHOST);
        let v6_ip = IpAddr::V6(Ipv6Addr::LOCALHOST);

        assert_eq!(v4_ip.version(), IpVersion::V4);
        assert_eq!(v6_ip.version(), IpVersion::V6);
    }

    /// `new` stores the given address unchanged.
    #[test]
    fn test_static_resolver_new() {
        let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let resolver = StaticResolver::new(addr);

        assert_eq!(resolver.address, addr);
    }

    /// `default` stores the address type's default value.
    #[test]
    fn test_static_resolver_default() {
        let resolver = StaticResolver::<u32>::default();
        assert_eq!(resolver.address, 0u32);
    }

    /// Drives the real `tower::Service` implementation: the resolver must be
    /// immediately ready and `call` must resolve to the stored address.
    #[test]
    fn test_static_resolver_service() {
        use std::pin::Pin;
        use std::task::{Context, Poll};

        let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let mut resolver = StaticResolver::new(addr);

        let waker = std::task::Waker::noop();
        let mut cx = Context::from_waker(waker);

        // The request value is ignored by `StaticResolver`; its type only
        // selects which `Service<&R>` implementation is exercised.
        let request = "test_request";

        assert!(matches!(
            <StaticResolver<SocketAddr> as tower::Service<&&str>>::poll_ready(
                &mut resolver,
                &mut cx
            ),
            Poll::Ready(Ok(()))
        ));

        let mut future =
            <StaticResolver<SocketAddr> as tower::Service<&&str>>::call(&mut resolver, &request);

        // `Ready` is `Unpin`, so it can be pinned in place without boxing.
        match Pin::new(&mut future).poll(&mut cx) {
            Poll::Ready(Ok(result)) => assert_eq!(result, addr),
            Poll::Ready(Err(_)) => panic!("Expected Ok result"),
            Poll::Pending => panic!("Expected ready result"),
        }
    }

    /// Clones compare equal; resolvers with different addresses do not.
    #[test]
    fn test_static_resolver_clone_eq() {
        let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let resolver1 = StaticResolver::new(addr);
        let resolver2 = resolver1.clone();

        assert_eq!(resolver1, resolver2);

        let resolver3 = StaticResolver::new(SocketAddr::from(([127, 0, 0, 1], 9090)));
        assert_ne!(resolver1, resolver3);
    }
}