1use std::collections::{VecDeque, vec_deque};
2use std::{fmt, iter::FusedIterator, net::SocketAddr};
3
4use ntex_bytes::ByteString;
5use ntex_util::future::Either;
6
/// A connection target that can describe its host, an optional explicit port
/// and, optionally, a socket address it already knows.
pub trait Address: Unpin + 'static {
    /// Host text of the target. Implementations in this module return the
    /// whole string, so it may still carry a `:port` suffix.
    fn host(&self) -> &str;

    /// Explicit port carried by the target, if any. When this is `None`,
    /// `Connect::port` falls back to the port stored on the `Connect` message.
    fn port(&self) -> Option<u16>;

    /// Socket address already known for this target.
    ///
    /// Defaults to `None`. When it returns `Some`, `Connect::addrs` and
    /// `Connect::take_addrs` yield only this address and ignore any addresses
    /// stored on the `Connect` message.
    fn addr(&self) -> Option<SocketAddr> {
        None
    }
}
20
21impl Address for String {
22 fn host(&self) -> &str {
23 self
24 }
25
26 fn port(&self) -> Option<u16> {
27 None
28 }
29}
30
31impl Address for ByteString {
32 fn host(&self) -> &str {
33 self
34 }
35
36 fn port(&self) -> Option<u16> {
37 None
38 }
39}
40
41impl Address for &'static str {
42 fn host(&self) -> &str {
43 self
44 }
45
46 fn port(&self) -> Option<u16> {
47 None
48 }
49}
50
51impl Address for SocketAddr {
52 fn host(&self) -> &'static str {
53 ""
54 }
55
56 fn port(&self) -> Option<u16> {
57 None
58 }
59
60 fn addr(&self) -> Option<SocketAddr> {
61 Some(*self)
62 }
63}
64
65#[derive(Eq, PartialEq, Debug, Hash)]
67pub struct Connect<T> {
68 pub(super) req: T,
69 pub(super) port: u16,
70 pub(super) addr: Option<Either<SocketAddr, VecDeque<SocketAddr>>>,
71}
72
73impl<T: Address> Connect<T> {
74 #[must_use]
76 pub fn new(req: T) -> Connect<T> {
77 let (_, port) = parse(req.host());
78 Connect {
79 req,
80 port: port.unwrap_or(0),
81 addr: None,
82 }
83 }
84
85 #[must_use]
87 pub fn with(req: T, addr: SocketAddr) -> Connect<T> {
88 Connect {
89 req,
90 port: 0,
91 addr: Some(Either::Left(addr)),
92 }
93 }
94
95 #[must_use]
99 pub fn set_port(mut self, port: u16) -> Self {
100 self.port = port;
101 self
102 }
103
104 #[must_use]
106 pub fn set_addr(mut self, addr: Option<SocketAddr>) -> Self {
107 if let Some(addr) = addr {
108 self.addr = Some(Either::Left(addr));
109 }
110 self
111 }
112
113 #[must_use]
115 pub fn set_addrs<I>(mut self, addrs: I) -> Self
116 where
117 I: IntoIterator<Item = SocketAddr>,
118 {
119 let mut addrs = VecDeque::from_iter(addrs);
120 self.addr = if addrs.len() < 2 {
121 addrs.pop_front().map(Either::Left)
122 } else {
123 Some(Either::Right(addrs))
124 };
125 self
126 }
127
128 pub fn host(&self) -> &str {
130 self.req.host()
131 }
132
133 pub fn port(&self) -> u16 {
135 self.req.port().unwrap_or(self.port)
136 }
137
138 pub fn addrs(&self) -> ConnectAddrsIter<'_> {
140 if let Some(addr) = self.req.addr() {
141 ConnectAddrsIter {
142 inner: Either::Left(Some(addr)),
143 }
144 } else {
145 let inner = match self.addr {
146 None => Either::Left(None),
147 Some(Either::Left(addr)) => Either::Left(Some(addr)),
148 Some(Either::Right(ref addrs)) => Either::Right(addrs.iter()),
149 };
150
151 ConnectAddrsIter { inner }
152 }
153 }
154
155 pub fn take_addrs(&mut self) -> ConnectTakeAddrsIter {
157 if let Some(addr) = self.req.addr() {
158 ConnectTakeAddrsIter {
159 inner: Either::Left(Some(addr)),
160 }
161 } else {
162 let inner = match self.addr.take() {
163 None => Either::Left(None),
164 Some(Either::Left(addr)) => Either::Left(Some(addr)),
165 Some(Either::Right(addrs)) => Either::Right(addrs.into_iter()),
166 };
167
168 ConnectTakeAddrsIter { inner }
169 }
170 }
171
172 pub fn get_ref(&self) -> &T {
174 &self.req
175 }
176
177 pub fn map_addr<F, R>(self, f: F) -> Connect<R>
179 where
180 F: FnOnce(T) -> R,
181 {
182 let req = f(self.req);
183
184 Connect {
185 req,
186 port: self.port,
187 addr: self.addr,
188 }
189 }
190}
191
192impl<T: Clone> Clone for Connect<T> {
193 fn clone(&self) -> Self {
194 Connect {
195 req: self.req.clone(),
196 port: self.port,
197 addr: self.addr.clone(),
198 }
199 }
200}
201
202impl<T: Address> From<T> for Connect<T> {
203 fn from(addr: T) -> Self {
204 Connect::new(addr)
205 }
206}
207
208impl<T: Address> fmt::Display for Connect<T> {
209 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210 write!(f, "{}:{}", self.host(), self.port())
211 }
212}
213
214#[derive(Clone)]
216pub struct ConnectAddrsIter<'a> {
217 inner: Either<Option<SocketAddr>, vec_deque::Iter<'a, SocketAddr>>,
218}
219
220impl Iterator for ConnectAddrsIter<'_> {
221 type Item = SocketAddr;
222
223 fn next(&mut self) -> Option<Self::Item> {
224 match self.inner {
225 Either::Left(ref mut opt) => opt.take(),
226 Either::Right(ref mut iter) => iter.next().copied(),
227 }
228 }
229
230 fn size_hint(&self) -> (usize, Option<usize>) {
231 match self.inner {
232 Either::Left(Some(_)) => (1, Some(1)),
233 Either::Left(None) => (0, Some(0)),
234 Either::Right(ref iter) => iter.size_hint(),
235 }
236 }
237}
238
239impl fmt::Debug for ConnectAddrsIter<'_> {
240 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
241 f.debug_list().entries(self.clone()).finish()
242 }
243}
244
245impl ExactSizeIterator for ConnectAddrsIter<'_> {}
246
247impl FusedIterator for ConnectAddrsIter<'_> {}
248
249#[derive(Debug)]
251pub struct ConnectTakeAddrsIter {
252 inner: Either<Option<SocketAddr>, vec_deque::IntoIter<SocketAddr>>,
253}
254
255impl Iterator for ConnectTakeAddrsIter {
256 type Item = SocketAddr;
257
258 fn next(&mut self) -> Option<Self::Item> {
259 match self.inner {
260 Either::Left(ref mut opt) => opt.take(),
261 Either::Right(ref mut iter) => iter.next(),
262 }
263 }
264
265 fn size_hint(&self) -> (usize, Option<usize>) {
266 match self.inner {
267 Either::Left(Some(_)) => (1, Some(1)),
268 Either::Left(None) => (0, Some(0)),
269 Either::Right(ref iter) => iter.size_hint(),
270 }
271 }
272}
273
274impl ExactSizeIterator for ConnectTakeAddrsIter {}
275
276impl FusedIterator for ConnectTakeAddrsIter {}
277
/// Split a `host[:port]` string into its host part and optional port.
///
/// * `name` or `name:port` is split at the first `:`. If the text after it
///   does not parse as a `u16`, the port is `None`.
/// * `[v6]` or `[v6]:port` (a bracketed IPv6 literal) yields the address
///   without brackets plus the optional port.
///
/// A bare, unbracketed IPv6 literal such as `::1` cannot have its colons told
/// apart from a port separator, so it yields `None` for the port.
fn parse(host: &str) -> (&str, Option<u16>) {
    // Bracketed IPv6 literal: only a `:` after the closing `]` separates the
    // port. A plain split at the first `:` would cut into the address itself.
    if let Some((addr, tail)) = host
        .strip_prefix('[')
        .and_then(|rest| rest.split_once(']'))
    {
        let port = tail.strip_prefix(':').and_then(|p| p.parse::<u16>().ok());
        return (addr, port);
    }

    // Same halves as `splitn(2, ':')`. If the port half still contains a `:`,
    // it fails to parse and the port becomes `None`.
    match host.split_once(':') {
        Some((name, port)) => (name, port.parse::<u16>().ok()),
        None => (host, None),
    }
}
291
/// Unit tests for the `Address` implementations and the `Connect` builder/iterators.
#[cfg(test)]
mod tests {
    use super::*;

    /// String-like `Address` impls return the text unchanged as the host and
    /// report no explicit port.
    #[test]
    fn address() {
        assert_eq!("test".host(), "test");
        assert_eq!("test".port(), None);

        let s = "test".to_string();
        assert_eq!(s.host(), "test");
        assert_eq!(s.port(), None);

        let s = ByteString::from("test");
        assert_eq!(s.host(), "test");
        assert_eq!(s.port(), None);
    }

    /// Walks one `Connect` value through a sequence of builder calls. Each step
    /// depends on the state left by the previous one, so the order matters.
    #[test]
    // presumably silences clippy about `addr` / `addrs` / `addr2` — confirm
    #[allow(clippy::similar_names)]
    fn connect() {
        // Host without a `:port` suffix: the port defaults to 0 until set.
        let mut connect = Connect::new("www.rust-lang.org");
        assert_eq!(connect.host(), "www.rust-lang.org");
        assert_eq!(connect.port(), 0);
        assert_eq!(*connect.get_ref(), "www.rust-lang.org");
        connect = connect.set_port(80);
        assert_eq!(connect.port(), 80);
        // No addresses stored yet: the iterator is empty and debug-prints as `[]`.
        let addrs = connect.addrs().clone();
        assert_eq!(format!("{addrs:?}"), "[]");
        assert!(connect.addrs().next().is_none());
        assert!(format!("{:?}", connect.clone()).contains("Connect"));

        // `map_addr` swaps the request but keeps the previously set port (80).
        let c = connect.clone().map_addr(|_| "www.rust-lang.org:80");
        assert_eq!(c.host(), "www.rust-lang.org:80");
        assert_eq!(c.port(), 80);
        let addrs = c.addrs().clone();
        assert_eq!(format!("{addrs:?}"), "[]");
        assert!(c.addrs().next().is_none());

        // A single address is stored and yielded once; `take_addrs` moves it out.
        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        connect = connect.set_addrs(vec![addr]);
        let addrs = connect.addrs().clone();
        assert_eq!(format!("{addrs:?}"), "[127.0.0.1:8080]");
        let addrs: Vec<_> = connect.take_addrs().collect();
        assert_eq!(addrs.len(), 1);
        assert!(addrs.contains(&addr));

        // Two addresses are kept as a list; `addrs` borrows them without consuming.
        let addr2: SocketAddr = "127.0.0.1:8081".parse().unwrap();
        connect = connect.set_addrs(vec![addr, addr2]);
        let addrs: Vec<_> = connect.addrs().collect();
        assert_eq!(addrs.len(), 2);
        assert!(addrs.contains(&addr));
        assert!(addrs.contains(&addr2));

        // `take_addrs` drains the stored list, leaving nothing for `addrs`.
        let addrs: Vec<_> = connect.take_addrs().collect();
        assert_eq!(addrs.len(), 2);
        assert!(addrs.contains(&addr));
        assert!(addrs.contains(&addr2));
        assert!(connect.addrs().next().is_none());

        // `Display` renders `host:port`.
        connect = connect.set_addrs(vec![addr]);
        assert_eq!(format!("{connect}"), "www.rust-lang.org:80");

        // A `SocketAddr` request has an empty host and port 0, but
        // supplies itself as the address.
        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let mut connect = Connect::new(addr);
        assert_eq!(connect.host(), "");
        assert_eq!(connect.port(), 0);
        let addrs: Vec<_> = connect.addrs().collect();
        assert_eq!(addrs.len(), 1);
        assert!(addrs.contains(&addr));
        let addrs: Vec<_> = connect.take_addrs().collect();
        assert_eq!(addrs.len(), 1);
        assert!(addrs.contains(&addr));
    }
}