1use std::task::{Context, Poll};
2use std::{collections::VecDeque, fmt, future::Future, io, net::SocketAddr, pin::Pin};
3
4use ntex_bytes::{PoolId, PoolRef};
5use ntex_io::{types, Io};
6use ntex_service::{Service, ServiceCtx, ServiceFactory};
7use ntex_util::future::{BoxFuture, Either};
8
9use super::{Address, Connect, ConnectError, Resolver};
10use crate::tcp_connect_in;
11
12pub struct Connector<T> {
14 resolver: Resolver<T>,
15 pool: PoolRef,
16 tag: &'static str,
17}
18
19impl<T> Copy for Connector<T> {}
20
21impl<T> Connector<T> {
22 pub fn new() -> Self {
24 Connector {
25 resolver: Resolver::new(),
26 pool: PoolId::P0.pool_ref(),
27 tag: "TCP-CLIENT",
28 }
29 }
30
31 pub fn memory_pool(mut self, id: PoolId) -> Self {
36 self.pool = id.pool_ref();
37 self
38 }
39
40 pub fn tag(mut self, tag: &'static str) -> Self {
44 self.tag = tag;
45 self
46 }
47}
48
49impl<T: Address> Connector<T> {
50 pub async fn connect<U>(&self, message: U) -> Result<Io, ConnectError>
52 where
53 Connect<T>: From<U>,
54 {
55 let address = self
57 .resolver
58 .lookup_with_tag(message.into(), self.tag)
59 .await?;
60
61 let port = address.port();
62 let Connect { req, addr, .. } = address;
63
64 if let Some(addr) = addr {
65 TcpConnectorResponse::new(req, port, addr, self.tag, self.pool).await
66 } else if let Some(addr) = req.addr() {
67 TcpConnectorResponse::new(
68 req,
69 addr.port(),
70 Either::Left(addr),
71 self.tag,
72 self.pool,
73 )
74 .await
75 } else {
76 log::error!("{}: TCP connector: got unresolved address", self.tag);
77 Err(ConnectError::Unresolved)
78 }
79 }
80}
81
82impl<T> Default for Connector<T> {
83 fn default() -> Self {
84 Connector::new()
85 }
86}
87
// `Connector<T>` is `Copy` for every `T`, so cloning is simply a copy of
// `*self`. This is the canonical `Clone` body for a `Copy` type.
impl<T> Clone for Connector<T> {
    fn clone(&self) -> Self {
        *self
    }
}
93
94impl<T> fmt::Debug for Connector<T> {
95 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96 f.debug_struct("Connector")
97 .field("tag", &self.tag)
98 .field("resolver", &self.resolver)
99 .field("memory_pool", &self.pool)
100 .finish()
101 }
102}
103
104impl<T: Address, C> ServiceFactory<Connect<T>, C> for Connector<T> {
105 type Response = Io;
106 type Error = ConnectError;
107 type Service = Connector<T>;
108 type InitError = ();
109
110 async fn create(&self, _: C) -> Result<Self::Service, Self::InitError> {
111 Ok(*self)
112 }
113}
114
115impl<T: Address> Service<Connect<T>> for Connector<T> {
116 type Response = Io;
117 type Error = ConnectError;
118
119 async fn call(
120 &self,
121 req: Connect<T>,
122 _: ServiceCtx<'_, Self>,
123 ) -> Result<Self::Response, Self::Error> {
124 self.connect(req).await
125 }
126}
127
128struct TcpConnectorResponse<T> {
130 req: Option<T>,
131 port: u16,
132 addrs: Option<VecDeque<SocketAddr>>,
133 #[allow(clippy::type_complexity)]
134 stream: Option<BoxFuture<'static, Result<Io, io::Error>>>,
135 tag: &'static str,
136 pool: PoolRef,
137}
138
139impl<T: Address> TcpConnectorResponse<T> {
140 fn new(
141 req: T,
142 port: u16,
143 addr: Either<SocketAddr, VecDeque<SocketAddr>>,
144 tag: &'static str,
145 pool: PoolRef,
146 ) -> TcpConnectorResponse<T> {
147 log::trace!(
148 "{}: TCP connector - connecting to {:?} addr:{:?} port:{}",
149 tag,
150 req.host(),
151 addr,
152 port
153 );
154
155 match addr {
156 Either::Left(addr) => TcpConnectorResponse {
157 req: Some(req),
158 addrs: None,
159 stream: Some(Box::pin(tcp_connect_in(addr, pool))),
160 tag,
161 pool,
162 port,
163 },
164 Either::Right(addrs) => TcpConnectorResponse {
165 tag,
166 port,
167 pool,
168 req: Some(req),
169 addrs: Some(addrs),
170 stream: None,
171 },
172 }
173 }
174
175 fn can_continue(&self, err: &io::Error) -> bool {
176 log::trace!(
177 "{}: TCP connector - failed to connect to {:?} port: {} err: {:?}",
178 self.tag,
179 self.req.as_ref().unwrap().host(),
180 self.port,
181 err
182 );
183 !(self.addrs.is_none() || self.addrs.as_ref().unwrap().is_empty())
184 }
185}
186
187impl<T: Address> Future for TcpConnectorResponse<T> {
188 type Output = Result<Io, ConnectError>;
189
190 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
191 let this = self.get_mut();
192
193 loop {
195 if let Some(new) = this.stream.as_mut() {
196 match new.as_mut().poll(cx) {
197 Poll::Ready(Ok(sock)) => {
198 let req = this.req.take().unwrap();
199 log::trace!(
200 "{}: TCP connector - successfully connected to {:?} - {:?}",
201 this.tag,
202 req.host(),
203 sock.query::<types::PeerAddr>().get()
204 );
205 sock.set_tag(this.tag);
206 return Poll::Ready(Ok(sock));
207 }
208 Poll::Pending => return Poll::Pending,
209 Poll::Ready(Err(err)) => {
210 if !this.can_continue(&err) {
211 return Poll::Ready(Err(err.into()));
212 }
213 }
214 }
215 }
216
217 let addr = this.addrs.as_mut().unwrap().pop_front().unwrap();
219 this.stream = Some(Box::pin(tcp_connect_in(addr, this.pool)));
220 }
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    #[ntex::test]
    async fn test_connect() {
        let server = ntex::server::test_server(|| {
            ntex_service::fn_service(|_| async { Ok::<_, ()>(()) })
        });

        // A custom tag and memory pool; the empty address and the out-of-range
        // port must both fail, and Debug output must name the type.
        let tagged = Connector::default().tag("T").memory_pool(PoolId::P5);
        assert!(tagged.connect("").await.is_err());
        assert!(tagged.connect("localhost:99999").await.is_err());
        assert!(format!("{:?}", tagged).contains("Connector"));

        // A plain "host:port" string pointing at the live test server.
        let plain = Connector::default();
        assert!(plain.connect(server.addr().to_string()).await.is_ok());

        // Several candidate addresses: the first one (port - 1) is presumably
        // not listening, so the connector must fall back to the server address.
        let fallback_target: SocketAddr = format!("127.0.0.1:{}", server.addr().port() - 1)
            .parse()
            .unwrap();
        let multi = Connect::new(server.addr().to_string())
            .set_addrs(vec![fallback_target, server.addr()]);
        assert!(crate::connect::connect(multi).await.is_ok());

        // A bare socket address as the request.
        let direct = Connect::new(server.addr());
        assert!(crate::connect::connect(direct).await.is_ok());
    }
}