ntex_net/connect/service.rs

1use std::{collections::VecDeque, fmt, io, net::SocketAddr};
2
3use ntex_bytes::PoolId;
4use ntex_io::{types, Io};
5use ntex_service::{Service, ServiceCtx, ServiceFactory};
6use ntex_util::{future::Either, time::timeout_checked, time::Millis};
7
8use super::{Address, Connect, ConnectError, Resolver};
9use crate::tcp_connect;
10
11/// Basic tcp stream connector
12pub struct Connector<T> {
13    tag: &'static str,
14    timeout: Millis,
15    resolver: Resolver<T>,
16}
17
18impl<T> Copy for Connector<T> {}
19
20impl<T> Connector<T> {
21    /// Construct new connect service with default dns resolver
22    pub fn new() -> Self {
23        Connector {
24            resolver: Resolver::new(),
25            tag: "TCP-CLIENT",
26            timeout: Millis::ZERO,
27        }
28    }
29
30    /// Set io tag
31    ///
32    /// Set tag to opened io object.
33    pub fn tag(mut self, tag: &'static str) -> Self {
34        self.tag = tag;
35        self
36    }
37
38    /// Connect timeout.
39    ///
40    /// i.e. max time to connect to remote host including dns name resolution.
41    /// Timeout is disabled by default
42    pub fn timeout<U: Into<Millis>>(mut self, timeout: U) -> Self {
43        self.timeout = timeout.into();
44        self
45    }
46
47    #[deprecated]
48    #[doc(hidden)]
49    /// Set memory pool
50    ///
51    /// Use specified memory pool for memory allocations. By default P0
52    /// memory pool is used.
53    pub fn memory_pool(self, _: PoolId) -> Self {
54        self
55    }
56}
57
58impl<T: Address> Connector<T> {
59    /// Resolve and connect to remote host
60    pub async fn connect<U>(&self, message: U) -> Result<Io, ConnectError>
61    where
62        Connect<T>: From<U>,
63    {
64        timeout_checked(self.timeout, async {
65            // resolve first
66            let address = self
67                .resolver
68                .lookup_with_tag(message.into(), self.tag)
69                .await?;
70
71            let port = address.port();
72            let Connect { req, addr, .. } = address;
73
74            if let Some(addr) = addr {
75                connect(req, port, addr, self.tag).await
76            } else if let Some(addr) = req.addr() {
77                connect(req, addr.port(), Either::Left(addr), self.tag).await
78            } else {
79                log::error!("{}: TCP connector: got unresolved address", self.tag);
80                Err(ConnectError::Unresolved)
81            }
82        })
83        .await
84        .map_err(|_| ConnectError::Io(io::Error::new(io::ErrorKind::TimedOut, "Timeout")))
85        .and_then(|item| item)
86    }
87}
88
89impl<T> Default for Connector<T> {
90    fn default() -> Self {
91        Connector::new()
92    }
93}
94
95impl<T> Clone for Connector<T> {
96    fn clone(&self) -> Self {
97        *self
98    }
99}
100
101impl<T> fmt::Debug for Connector<T> {
102    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
103        f.debug_struct("Connector")
104            .field("tag", &self.tag)
105            .field("timeout", &self.timeout)
106            .field("resolver", &self.resolver)
107            .finish()
108    }
109}
110
111impl<T: Address, C> ServiceFactory<Connect<T>, C> for Connector<T> {
112    type Response = Io;
113    type Error = ConnectError;
114    type Service = Connector<T>;
115    type InitError = ();
116
117    async fn create(&self, _: C) -> Result<Self::Service, Self::InitError> {
118        Ok(*self)
119    }
120}
121
122impl<T: Address> Service<Connect<T>> for Connector<T> {
123    type Response = Io;
124    type Error = ConnectError;
125
126    async fn call(
127        &self,
128        req: Connect<T>,
129        _: ServiceCtx<'_, Self>,
130    ) -> Result<Self::Response, Self::Error> {
131        self.connect(req).await
132    }
133}
134
135/// Tcp stream connector
136async fn connect<T: Address>(
137    req: T,
138    port: u16,
139    addr: Either<SocketAddr, VecDeque<SocketAddr>>,
140    tag: &'static str,
141) -> Result<Io, ConnectError> {
142    log::trace!(
143        "{tag}: TCP connector - connecting to {:?} addr:{addr:?} port:{port}",
144        req.host(),
145    );
146
147    let io = match addr {
148        Either::Left(addr) => tcp_connect(addr).await?,
149        Either::Right(mut addrs) => loop {
150            let addr = addrs.pop_front().unwrap();
151
152            match tcp_connect(addr).await {
153                Ok(io) => break io,
154                Err(err) => {
155                    log::trace!(
156                        "{tag}: TCP connector - failed to connect to {:?} port: {port} err: {err:?}",
157                        req.host(),
158                    );
159                    if addrs.is_empty() {
160                        return Err(err.into());
161                    }
162                }
163            }
164        },
165    };
166
167    log::trace!(
168        "{tag}: TCP connector - successfully connected to {:?} - {:?}",
169        req.host(),
170        io.query::<types::PeerAddr>().get()
171    );
172    io.set_tag(tag);
173    Ok(io)
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    #[ntex::test]
    async fn test_connect() {
        let server = ntex::server::test_server(|| {
            ntex_service::fn_service(|_| async { Ok::<_, ()>(()) })
        });
        let server_addr = server.addr();

        // invalid targets must fail: empty host and out-of-range port
        let connector = Connector::default().tag("T").timeout(Millis(5000));
        assert!(connector.connect("").await.is_err());
        assert!(connector.connect("localhost:99999").await.is_err());
        assert!(format!("{connector:?}").contains("Connector"));

        // plain "host:port" string pointing at the test server
        let connector = Connector::default();
        assert!(connector.connect(server_addr.to_string()).await.is_ok());

        // address list: the first entry is presumably refused, the second is the server
        let neighbour = format!("127.0.0.1:{}", server_addr.port() - 1)
            .parse()
            .unwrap();
        let msg = Connect::new(server_addr.to_string()).set_addrs([neighbour, server_addr]);
        assert!(crate::connect::connect(msg).await.is_ok());

        // request that already is a socket address
        let msg = Connect::new(server_addr);
        assert!(crate::connect::connect(msg).await.is_ok());
    }
}