Skip to main content

ntex_net/connect/
service.rs

1use std::{collections::VecDeque, io, marker::PhantomData, net::SocketAddr};
2
3use ntex_error::Error;
4use ntex_io::{Io, IoConfig, types};
5use ntex_service::cfg::{Cfg, SharedCfg};
6use ntex_service::{Service, ServiceCtx, ServiceFactory};
7use ntex_util::{future::Either, time::timeout_checked};
8
9use super::{Address, Connect, ConnectError, ConnectServiceError, resolve};
10
/// TCP stream connector factory.
///
/// Implements `ServiceFactory`, producing a [`ConnectorService`] bound to the
/// `SharedCfg` handed to `create`.
#[derive(Copy, Clone, Debug)]
pub struct Connector<T>(PhantomData<T>);
14
15#[derive(Clone, Debug)]
16/// Basic tcp stream connector
17pub struct ConnectorService<T> {
18    cfg: Cfg<IoConfig>,
19    shared: SharedCfg,
20    _t: PhantomData<T>,
21}
22
23impl<T> Connector<T> {
24    /// Construct new connect service with default configuration
25    pub fn new() -> Self {
26        Connector(PhantomData)
27    }
28}
29
30impl<T> Default for Connector<T> {
31    fn default() -> Self {
32        Self::new()
33    }
34}
35
36impl<T> ConnectorService<T> {
37    #[inline]
38    /// Construct new connect service with default configuration
39    pub fn new() -> Self {
40        ConnectorService::with(SharedCfg::default())
41    }
42
43    #[inline]
44    /// Construct new connect service with custom configuration
45    pub fn with(cfg: SharedCfg) -> Self {
46        ConnectorService {
47            cfg: cfg.get(),
48            shared: cfg,
49            _t: PhantomData,
50        }
51    }
52}
53
54impl<T> Default for ConnectorService<T> {
55    fn default() -> Self {
56        ConnectorService::new()
57    }
58}
59
60impl<T: Address> ConnectorService<T> {
61    /// Resolve and connect to remote host
62    pub async fn connect<U>(&self, message: U) -> Result<Io, ConnectError>
63    where
64        Connect<T>: From<U>,
65    {
66        timeout_checked(self.cfg.connect_timeout(), async {
67            // resolve first
68            let msg = resolve::lookup(message.into(), self.shared.tag())
69                .await
70                .map_err(Error::into_error)?;
71
72            let port = msg.port();
73            let Connect { req, addr, .. } = msg;
74
75            if let Some(addr) = addr {
76                connect(req, port, addr, self.shared.clone())
77                    .await
78                    .map_err(Error::into_error)
79            } else if let Some(addr) = req.addr() {
80                connect(req, addr.port(), Either::Left(addr), self.shared.clone())
81                    .await
82                    .map_err(Error::into_error)
83            } else {
84                log::error!("{}: TCP connector: got unresolved address", self.cfg.tag());
85                Err(ConnectError::Unresolved)
86            }
87        })
88        .await
89        .map_err(|()| {
90            ConnectError::Io(io::Error::new(io::ErrorKind::TimedOut, "Connect timeout"))
91        })
92        .and_then(|item| item)
93    }
94}
95
96impl<T: Address> ServiceFactory<Connect<T>, SharedCfg> for Connector<T> {
97    type Response = Io;
98    type Error = ConnectError;
99    type Service = ConnectorService<T>;
100    type InitError = ConnectServiceError;
101
102    async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
103        Ok(ConnectorService::with(cfg))
104    }
105}
106
107impl<T: Address> Service<Connect<T>> for ConnectorService<T> {
108    type Response = Io;
109    type Error = ConnectError;
110
111    async fn call(
112        &self,
113        req: Connect<T>,
114        _: ServiceCtx<'_, Self>,
115    ) -> Result<Self::Response, Self::Error> {
116        self.connect(req).await
117    }
118}
119
/// TCP stream connector factory with rich errors.
///
/// Implements `ServiceFactory`, producing a [`ConnectorService2`] bound to the
/// `SharedCfg` handed to `create`.
#[derive(Copy, Clone, Debug)]
pub struct Connector2<T>(PhantomData<T>);
123
124#[derive(Clone, Debug)]
125/// Basic tcp stream connector
126pub struct ConnectorService2<T> {
127    cfg: Cfg<IoConfig>,
128    shared: SharedCfg,
129    _t: PhantomData<T>,
130}
131
132impl<T> Connector2<T> {
133    /// Construct new connect service with default configuration
134    pub fn new() -> Self {
135        Connector2(PhantomData)
136    }
137}
138
139impl<T> Default for Connector2<T> {
140    fn default() -> Self {
141        Self::new()
142    }
143}
144
145impl<T> ConnectorService2<T> {
146    #[inline]
147    /// Construct new connect service with default configuration
148    pub fn new() -> Self {
149        ConnectorService2::with(SharedCfg::default())
150    }
151
152    #[inline]
153    /// Construct new connect service with custom configuration
154    pub fn with(cfg: SharedCfg) -> Self {
155        ConnectorService2 {
156            cfg: cfg.get(),
157            shared: cfg,
158            _t: PhantomData,
159        }
160    }
161}
162
163impl<T> Default for ConnectorService2<T> {
164    fn default() -> Self {
165        ConnectorService2::new()
166    }
167}
168
169impl<T: Address> ConnectorService2<T> {
170    /// Resolve and connect to remote host
171    pub async fn connect<U>(&self, message: U) -> Result<Io, Error<ConnectError>>
172    where
173        Connect<T>: From<U>,
174    {
175        timeout_checked(self.cfg.connect_timeout(), async {
176            // resolve first
177            let msg = resolve::lookup(message.into(), self.shared.tag()).await?;
178
179            let port = msg.port();
180            let Connect { req, addr, .. } = msg;
181
182            if let Some(addr) = addr {
183                connect(req, port, addr, self.shared.clone()).await
184            } else if let Some(addr) = req.addr() {
185                connect(req, addr.port(), Either::Left(addr), self.shared.clone()).await
186            } else {
187                Err(Error::from(ConnectError::Unresolved))
188            }
189        })
190        .await
191        .map_err(|()| {
192            Error::from(ConnectError::Io(io::Error::new(
193                io::ErrorKind::TimedOut,
194                "Connect timeout",
195            )))
196        })
197        .and_then(|item| item)
198        .map_err(|e| e.set_service(self.shared.service()))
199    }
200}
201
202impl<T: Address> ServiceFactory<Connect<T>, SharedCfg> for Connector2<T> {
203    type Response = Io;
204    type Error = Error<ConnectError>;
205    type Service = ConnectorService2<T>;
206    type InitError = ConnectServiceError;
207
208    async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
209        Ok(ConnectorService2::with(cfg))
210    }
211}
212
213impl<T: Address> Service<Connect<T>> for ConnectorService2<T> {
214    type Response = Io;
215    type Error = Error<ConnectError>;
216
217    async fn call(
218        &self,
219        req: Connect<T>,
220        _: ServiceCtx<'_, Self>,
221    ) -> Result<Self::Response, Self::Error> {
222        self.connect(req).await
223    }
224}
225
226/// Tcp stream connector
227async fn connect<T: Address>(
228    req: T,
229    port: u16,
230    addr: Either<SocketAddr, VecDeque<SocketAddr>>,
231    cfg: SharedCfg,
232) -> Result<Io, Error<ConnectError>> {
233    log::trace!(
234        "{}: TCP connector - connecting to {:?} addr:{addr:?} port:{port}",
235        cfg.tag(),
236        req.host(),
237    );
238
239    let io = match addr {
240        Either::Left(addr) => crate::tcp_connect(addr, cfg.clone())
241            .await
242            .map_err(ConnectError::from)?,
243        Either::Right(mut addrs) => loop {
244            let addr = addrs.pop_front().unwrap();
245
246            match crate::tcp_connect(addr, cfg.clone()).await {
247                Ok(io) => break io,
248                Err(err) => {
249                    log::trace!(
250                        "{}: TCP connector - failed to connect to {:?} port: {port} err: {err:?}",
251                        cfg.tag(),
252                        req.host(),
253                    );
254                    if addrs.is_empty() {
255                        return Err(ConnectError::from(err).into());
256                    }
257                }
258            }
259        },
260    };
261
262    log::trace!(
263        "{}: TCP connector - successfully connected to {:?} - {:?}",
264        cfg.tag(),
265        req.host(),
266        io.query::<types::PeerAddr>().get()
267    );
268    Ok(io)
269}