ntex_net/connect/service.rs

1use std::{collections::VecDeque, fmt, io, net::SocketAddr};
2
3use ntex_io::{Io, IoConfig, types};
4use ntex_service::cfg::{Cfg, SharedCfg};
5use ntex_service::{Service, ServiceCtx, ServiceFactory};
6use ntex_util::{future::Either, time::timeout_checked};
7
8use super::{Address, Connect, ConnectError, ConnectServiceError, Resolver};
9use crate::tcp_connect;
10
11/// Basic tcp stream connector
12pub struct Connector<T> {
13    cfg: Cfg<IoConfig>,
14    shared: SharedCfg,
15    resolver: Resolver<T>,
16}
17
18impl<T> Copy for Connector<T> {}
19
20impl<T> Connector<T> {
21    /// Construct new connect service with default configuration
22    pub fn new() -> Self {
23        Connector {
24            cfg: Default::default(),
25            shared: Default::default(),
26            resolver: Resolver::new(),
27        }
28    }
29
30    /// Construct new connect service with custom configuration
31    pub fn with(cfg: SharedCfg) -> Self {
32        Connector {
33            cfg: cfg.get(),
34            shared: cfg,
35            resolver: Resolver::new(),
36        }
37    }
38}
39
40impl<T: Address> Connector<T> {
41    /// Resolve and connect to remote host
42    pub async fn connect<U>(&self, message: U) -> Result<Io, ConnectError>
43    where
44        Connect<T>: From<U>,
45    {
46        timeout_checked(self.cfg.connect_timeout(), async {
47            // resolve first
48            let address = self
49                .resolver
50                .lookup_with_tag(message.into(), self.shared.tag())
51                .await?;
52
53            let port = address.port();
54            let Connect { req, addr, .. } = address;
55
56            if let Some(addr) = addr {
57                connect(req, port, addr, self.shared).await
58            } else if let Some(addr) = req.addr() {
59                connect(req, addr.port(), Either::Left(addr), self.shared).await
60            } else {
61                log::error!("{}: TCP connector: got unresolved address", self.cfg.tag());
62                Err(ConnectError::Unresolved)
63            }
64        })
65        .await
66        .map_err(|_| {
67            ConnectError::Io(io::Error::new(io::ErrorKind::TimedOut, "Connect timeout"))
68        })
69        .and_then(|item| item)
70    }
71}
72
73impl<T> Default for Connector<T> {
74    fn default() -> Self {
75        Connector::new()
76    }
77}
78
79impl<T> Clone for Connector<T> {
80    fn clone(&self) -> Self {
81        *self
82    }
83}
84
85impl<T> fmt::Debug for Connector<T> {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        f.debug_struct("Connector")
88            .field("cfg", &self.cfg)
89            .field("resolver", &self.resolver)
90            .finish()
91    }
92}
93
94impl<T: Address> ServiceFactory<Connect<T>, SharedCfg> for Connector<T> {
95    type Response = Io;
96    type Error = ConnectError;
97    type Service = Connector<T>;
98    type InitError = ConnectServiceError;
99
100    async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
101        Ok(Connector::with(cfg))
102    }
103}
104
105impl<T: Address> Service<Connect<T>> for Connector<T> {
106    type Response = Io;
107    type Error = ConnectError;
108
109    async fn call(
110        &self,
111        req: Connect<T>,
112        _: ServiceCtx<'_, Self>,
113    ) -> Result<Self::Response, Self::Error> {
114        self.connect(req).await
115    }
116}
117
118/// Tcp stream connector
119async fn connect<T: Address>(
120    req: T,
121    port: u16,
122    addr: Either<SocketAddr, VecDeque<SocketAddr>>,
123    cfg: SharedCfg,
124) -> Result<Io, ConnectError> {
125    log::trace!(
126        "{}: TCP connector - connecting to {:?} addr:{addr:?} port:{port}",
127        cfg.tag(),
128        req.host(),
129    );
130
131    let io = match addr {
132        Either::Left(addr) => tcp_connect(addr, cfg).await?,
133        Either::Right(mut addrs) => loop {
134            let addr = addrs.pop_front().unwrap();
135
136            match tcp_connect(addr, cfg).await {
137                Ok(io) => break io,
138                Err(err) => {
139                    log::trace!(
140                        "{}: TCP connector - failed to connect to {:?} port: {port} err: {err:?}",
141                        cfg.tag(),
142                        req.host(),
143                    );
144                    if addrs.is_empty() {
145                        return Err(err.into());
146                    }
147                }
148            }
149        },
150    };
151
152    log::trace!(
153        "{}: TCP connector - successfully connected to {:?} - {:?}",
154        cfg.tag(),
155        req.host(),
156        io.query::<types::PeerAddr>().get()
157    );
158    Ok(io)
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    use ntex_util::time::Millis;

    #[ntex::test]
    async fn test_connect() {
        let server = ntex::server::test_server(|| {
            ntex_service::fn_service(|_| async { Ok::<_, ()>(()) })
        });

        // Connector built through the service factory with a 5000 ms connect timeout.
        let factory_cfg: SharedCfg = SharedCfg::new("T")
            .add(IoConfig::new().set_connect_timeout(Millis(5000)))
            .into();
        let srv = Connector::default().create(factory_cfg).await.unwrap();

        // An empty host and an out-of-range port must both fail.
        let empty_host = srv.connect("").await;
        assert!(empty_host.is_err());
        let bad_port = srv.connect("localhost:99999").await;
        assert!(bad_port.is_err());
        assert!(format!("{srv:?}").contains("Connector"));

        // Default connector reaching the test server via its textual address.
        let srv = Connector::default();
        let direct = srv.connect(format!("{}", server.addr())).await;
        assert!(direct.is_ok());

        // Explicit address list: presumably nothing listens on (server port - 1),
        // so success here relies on falling back to the real server address.
        let fallback_addrs = vec![
            format!("127.0.0.1:{}", server.addr().port() - 1)
                .parse()
                .unwrap(),
            server.addr(),
        ];
        let msg = Connect::new(format!("{}", server.addr())).set_addrs(fallback_addrs);
        let with_fallback = crate::connect::connect(msg).await;
        assert!(with_fallback.is_ok());

        // Request built directly from a socket address.
        let msg = Connect::new(server.addr());
        let from_socket = crate::connect::connect(msg).await;
        assert!(from_socket.is_ok());
    }
}