1use std::collections::{vec_deque, VecDeque};
2use std::{fmt, iter::FusedIterator, net::SocketAddr};
3
4use ntex_bytes::ByteString;
5use ntex_util::future::Either;
6
/// A connection target that can describe where to connect.
///
/// Implementors report a host string and, optionally, a port and an
/// already-known socket address.
pub trait Address: Unpin + 'static {
    /// The host string of this target.
    fn host(&self) -> &str;

    /// The port carried by this target itself, if it has one.
    fn port(&self) -> Option<u16>;

    /// An already-known socket address for this target.
    ///
    /// The provided implementation returns `None`.
    fn addr(&self) -> Option<SocketAddr> {
        None
    }
}
20
21impl Address for String {
22 fn host(&self) -> &str {
23 self
24 }
25
26 fn port(&self) -> Option<u16> {
27 None
28 }
29}
30
31impl Address for ByteString {
32 fn host(&self) -> &str {
33 self
34 }
35
36 fn port(&self) -> Option<u16> {
37 None
38 }
39}
40
41impl Address for &'static str {
42 fn host(&self) -> &str {
43 self
44 }
45
46 fn port(&self) -> Option<u16> {
47 None
48 }
49}
50
51impl Address for SocketAddr {
52 fn host(&self) -> &str {
53 ""
54 }
55
56 fn port(&self) -> Option<u16> {
57 None
58 }
59
60 fn addr(&self) -> Option<SocketAddr> {
61 Some(*self)
62 }
63}
64
65#[derive(Eq, PartialEq, Debug, Hash)]
67pub struct Connect<T> {
68 pub(super) req: T,
69 pub(super) port: u16,
70 pub(super) addr: Option<Either<SocketAddr, VecDeque<SocketAddr>>>,
71}
72
73impl<T: Address> Connect<T> {
74 pub fn new(req: T) -> Connect<T> {
76 let (_, port) = parse(req.host());
77 Connect {
78 req,
79 port: port.unwrap_or(0),
80 addr: None,
81 }
82 }
83
84 pub fn with(req: T, addr: SocketAddr) -> Connect<T> {
86 Connect {
87 req,
88 port: 0,
89 addr: Some(Either::Left(addr)),
90 }
91 }
92
93 pub fn set_port(mut self, port: u16) -> Self {
97 self.port = port;
98 self
99 }
100
101 pub fn set_addr(mut self, addr: Option<SocketAddr>) -> Self {
103 if let Some(addr) = addr {
104 self.addr = Some(Either::Left(addr));
105 }
106 self
107 }
108
109 pub fn set_addrs<I>(mut self, addrs: I) -> Self
111 where
112 I: IntoIterator<Item = SocketAddr>,
113 {
114 let mut addrs = VecDeque::from_iter(addrs);
115 self.addr = if addrs.len() < 2 {
116 addrs.pop_front().map(Either::Left)
117 } else {
118 Some(Either::Right(addrs))
119 };
120 self
121 }
122
123 pub fn host(&self) -> &str {
125 self.req.host()
126 }
127
128 pub fn port(&self) -> u16 {
130 self.req.port().unwrap_or(self.port)
131 }
132
133 pub fn addrs(&self) -> ConnectAddrsIter<'_> {
135 if let Some(addr) = self.req.addr() {
136 ConnectAddrsIter {
137 inner: Either::Left(Some(addr)),
138 }
139 } else {
140 let inner = match self.addr {
141 None => Either::Left(None),
142 Some(Either::Left(addr)) => Either::Left(Some(addr)),
143 Some(Either::Right(ref addrs)) => Either::Right(addrs.iter()),
144 };
145
146 ConnectAddrsIter { inner }
147 }
148 }
149
150 pub fn take_addrs(&mut self) -> ConnectTakeAddrsIter {
152 if let Some(addr) = self.req.addr() {
153 ConnectTakeAddrsIter {
154 inner: Either::Left(Some(addr)),
155 }
156 } else {
157 let inner = match self.addr.take() {
158 None => Either::Left(None),
159 Some(Either::Left(addr)) => Either::Left(Some(addr)),
160 Some(Either::Right(addrs)) => Either::Right(addrs.into_iter()),
161 };
162
163 ConnectTakeAddrsIter { inner }
164 }
165 }
166
167 pub fn get_ref(&self) -> &T {
169 &self.req
170 }
171
172 pub fn map_addr<F, R>(self, f: F) -> Connect<R>
174 where
175 F: FnOnce(T) -> R,
176 {
177 let req = f(self.req);
178
179 Connect {
180 req,
181 port: self.port,
182 addr: self.addr,
183 }
184 }
185}
186
187impl<T: Clone> Clone for Connect<T> {
188 fn clone(&self) -> Self {
189 Connect {
190 req: self.req.clone(),
191 port: self.port,
192 addr: self.addr.clone(),
193 }
194 }
195}
196
197impl<T: Address> From<T> for Connect<T> {
198 fn from(addr: T) -> Self {
199 Connect::new(addr)
200 }
201}
202
203impl<T: Address> fmt::Display for Connect<T> {
204 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
205 write!(f, "{}:{}", self.host(), self.port())
206 }
207}
208
209#[derive(Clone)]
211pub struct ConnectAddrsIter<'a> {
212 inner: Either<Option<SocketAddr>, vec_deque::Iter<'a, SocketAddr>>,
213}
214
215impl Iterator for ConnectAddrsIter<'_> {
216 type Item = SocketAddr;
217
218 fn next(&mut self) -> Option<Self::Item> {
219 match self.inner {
220 Either::Left(ref mut opt) => opt.take(),
221 Either::Right(ref mut iter) => iter.next().copied(),
222 }
223 }
224
225 fn size_hint(&self) -> (usize, Option<usize>) {
226 match self.inner {
227 Either::Left(Some(_)) => (1, Some(1)),
228 Either::Left(None) => (0, Some(0)),
229 Either::Right(ref iter) => iter.size_hint(),
230 }
231 }
232}
233
234impl fmt::Debug for ConnectAddrsIter<'_> {
235 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
236 f.debug_list().entries(self.clone()).finish()
237 }
238}
239
240impl ExactSizeIterator for ConnectAddrsIter<'_> {}
241
242impl FusedIterator for ConnectAddrsIter<'_> {}
243
244#[derive(Debug)]
246pub struct ConnectTakeAddrsIter {
247 inner: Either<Option<SocketAddr>, vec_deque::IntoIter<SocketAddr>>,
248}
249
250impl Iterator for ConnectTakeAddrsIter {
251 type Item = SocketAddr;
252
253 fn next(&mut self) -> Option<Self::Item> {
254 match self.inner {
255 Either::Left(ref mut opt) => opt.take(),
256 Either::Right(ref mut iter) => iter.next(),
257 }
258 }
259
260 fn size_hint(&self) -> (usize, Option<usize>) {
261 match self.inner {
262 Either::Left(Some(_)) => (1, Some(1)),
263 Either::Left(None) => (0, Some(0)),
264 Either::Right(ref iter) => iter.size_hint(),
265 }
266 }
267}
268
269impl ExactSizeIterator for ConnectTakeAddrsIter {}
270
271impl FusedIterator for ConnectTakeAddrsIter {}
272
/// Splits a `host[:port]` string into its host part and optional port.
///
/// Accepted forms:
/// - `name` or `name:port`: a single `:` separates the port.
/// - `[v6addr]` or `[v6addr]:port`: a bracketed IPv6 literal. The returned
///   host is the text between the brackets.
/// - Any other string containing more than one `:` (for example a bare IPv6
///   literal such as `::1`) is returned whole, with no port, because the
///   port boundary would be ambiguous.
///
/// The port is `None` when it is missing or does not parse as a `u16`.
fn parse(host: &str) -> (&str, Option<u16>) {
    fn port_of(s: &str) -> Option<u16> {
        s.parse::<u16>().ok()
    }

    // Bracketed IPv6 literal: any port follows the closing `]`. Splitting on
    // the first `:` here would lose the port entirely.
    if let Some((addr, tail)) = host
        .strip_prefix('[')
        .and_then(|rest| rest.split_once(']'))
    {
        return (addr, tail.strip_prefix(':').and_then(port_of));
    }

    match host.split_once(':') {
        // Several unbracketed colons: a bare IPv6 literal, which carries no port.
        Some((_, rest)) if rest.contains(':') => (host, None),
        Some((name, port)) => (name, port_of(port)),
        None => (host, None),
    }
}
286
#[cfg(test)]
mod tests {
    use super::*;

    // Every built-in string `Address` impl reports the string itself as the
    // host and no port.
    #[test]
    fn address() {
        assert_eq!("test".host(), "test");
        assert_eq!("test".port(), None);

        let s = "test".to_string();
        assert_eq!(s.host(), "test");
        assert_eq!(s.port(), None);

        let s = ByteString::from("test");
        assert_eq!(s.host(), "test");
        assert_eq!(s.port(), None);
    }

    // Walks `Connect` through construction, the fallback port, `map_addr`,
    // address storage and both address iterators. The steps share state and
    // must run in this order.
    #[test]
    fn connect() {
        // No `:port` suffix in the host, so the fallback port starts at 0.
        let mut connect = Connect::new("www.rust-lang.org");
        assert_eq!(connect.host(), "www.rust-lang.org");
        assert_eq!(connect.port(), 0);
        assert_eq!(*connect.get_ref(), "www.rust-lang.org");
        connect = connect.set_port(80);
        assert_eq!(connect.port(), 80);
        // Nothing is stored yet: the iterator is empty and debug-prints as `[]`.
        let addrs = connect.addrs().clone();
        assert_eq!(format!("{addrs:?}"), "[]");
        assert!(connect.addrs().next().is_none());
        assert!(format!("{:?}", connect.clone()).contains("Connect"));

        // `map_addr` replaces the request but keeps the fallback port (80).
        // `&str` reports no port of its own, so `port()` still returns 80.
        let c = connect.clone().map_addr(|_| "www.rust-lang.org:80");
        assert_eq!(c.host(), "www.rust-lang.org:80");
        assert_eq!(c.port(), 80);
        let addrs = c.addrs().clone();
        assert_eq!(format!("{addrs:?}"), "[]");
        assert!(c.addrs().next().is_none());

        // A single stored address is visible via `addrs` and moved out by `take_addrs`.
        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        connect = connect.set_addrs(vec![addr]);
        let addrs = connect.addrs().clone();
        assert_eq!(format!("{addrs:?}"), "[127.0.0.1:8080]");
        let addrs: Vec<_> = connect.take_addrs().collect();
        assert_eq!(addrs.len(), 1);
        assert!(addrs.contains(&addr));

        // With two stored addresses, `addrs` only borrows them...
        let addr2: SocketAddr = "127.0.0.1:8081".parse().unwrap();
        connect = connect.set_addrs(vec![addr, addr2]);
        let addrs: Vec<_> = connect.addrs().collect();
        assert_eq!(addrs.len(), 2);
        assert!(addrs.contains(&addr));
        assert!(addrs.contains(&addr2));

        // ...while `take_addrs` moves them out and leaves none behind.
        let addrs: Vec<_> = connect.take_addrs().collect();
        assert_eq!(addrs.len(), 2);
        assert!(addrs.contains(&addr));
        assert!(addrs.contains(&addr2));
        assert!(connect.addrs().next().is_none());

        // `Display` renders `host:port`; stored addresses do not appear.
        connect = connect.set_addrs(vec![addr]);
        assert_eq!(format!("{connect}"), "www.rust-lang.org:80");

        // A `SocketAddr` request supplies its own address through
        // `Address::addr`, and both iterators yield it. Its host is empty and
        // it reports no port, so the fallback 0 is used.
        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let mut connect = Connect::new(addr);
        assert_eq!(connect.host(), "");
        assert_eq!(connect.port(), 0);
        let addrs: Vec<_> = connect.addrs().collect();
        assert_eq!(addrs.len(), 1);
        assert!(addrs.contains(&addr));
        let addrs: Vec<_> = connect.take_addrs().collect();
        assert_eq!(addrs.len(), 1);
        assert!(addrs.contains(&addr));
    }
}