ntex_net/connect/service.rs

1use std::{collections::VecDeque, io, marker::PhantomData, net::SocketAddr};
2
3use ntex_io::{Io, IoConfig, types};
4use ntex_service::cfg::{Cfg, SharedCfg};
5use ntex_service::{Service, ServiceCtx, ServiceFactory};
6use ntex_util::{future::Either, time::timeout_checked};
7
8use super::{Address, Connect, ConnectError, ConnectServiceError, Resolver};
9use crate::tcp_connect;
10
/// TCP connector factory.
///
/// Holds no state of its own; the type parameter `T` only records which
/// address type the produced [`ConnectorService`] accepts.
#[derive(Copy, Clone, Debug)]
pub struct Connector<T>(PhantomData<T>);
14
15#[derive(Copy, Clone, Debug)]
16/// Basic tcp stream connector
17pub struct ConnectorService<T> {
18    cfg: Cfg<IoConfig>,
19    shared: SharedCfg,
20    resolver: Resolver<T>,
21}
22
23impl<T> Connector<T> {
24    /// Construct new connect service with default configuration
25    pub fn new() -> Self {
26        Connector(PhantomData)
27    }
28}
29
30impl<T> Default for Connector<T> {
31    fn default() -> Self {
32        Self::new()
33    }
34}
35
36impl<T> ConnectorService<T> {
37    /// Construct new connect service with default configuration
38    pub fn new() -> Self {
39        ConnectorService {
40            cfg: Default::default(),
41            shared: Default::default(),
42            resolver: Resolver::new(),
43        }
44    }
45
46    /// Construct new connect service with custom configuration
47    pub fn with(cfg: SharedCfg) -> Self {
48        ConnectorService {
49            cfg: cfg.get(),
50            shared: cfg,
51            resolver: Resolver::new(),
52        }
53    }
54}
55
56impl<T> Default for ConnectorService<T> {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62impl<T: Address> ConnectorService<T> {
63    /// Resolve and connect to remote host
64    pub async fn connect<U>(&self, message: U) -> Result<Io, ConnectError>
65    where
66        Connect<T>: From<U>,
67    {
68        timeout_checked(self.cfg.connect_timeout(), async {
69            // resolve first
70            let address = self
71                .resolver
72                .lookup_with_tag(message.into(), self.shared.tag())
73                .await?;
74
75            let port = address.port();
76            let Connect { req, addr, .. } = address;
77
78            if let Some(addr) = addr {
79                connect(req, port, addr, self.shared).await
80            } else if let Some(addr) = req.addr() {
81                connect(req, addr.port(), Either::Left(addr), self.shared).await
82            } else {
83                log::error!("{}: TCP connector: got unresolved address", self.cfg.tag());
84                Err(ConnectError::Unresolved)
85            }
86        })
87        .await
88        .map_err(|_| {
89            ConnectError::Io(io::Error::new(io::ErrorKind::TimedOut, "Connect timeout"))
90        })
91        .and_then(|item| item)
92    }
93}
94
95impl<T: Address> ServiceFactory<Connect<T>, SharedCfg> for Connector<T> {
96    type Response = Io;
97    type Error = ConnectError;
98    type Service = ConnectorService<T>;
99    type InitError = ConnectServiceError;
100
101    async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
102        Ok(ConnectorService::with(cfg))
103    }
104}
105
106impl<T: Address> Service<Connect<T>> for ConnectorService<T> {
107    type Response = Io;
108    type Error = ConnectError;
109
110    async fn call(
111        &self,
112        req: Connect<T>,
113        _: ServiceCtx<'_, Self>,
114    ) -> Result<Self::Response, Self::Error> {
115        self.connect(req).await
116    }
117}
118
119/// Tcp stream connector
120async fn connect<T: Address>(
121    req: T,
122    port: u16,
123    addr: Either<SocketAddr, VecDeque<SocketAddr>>,
124    cfg: SharedCfg,
125) -> Result<Io, ConnectError> {
126    log::trace!(
127        "{}: TCP connector - connecting to {:?} addr:{addr:?} port:{port}",
128        cfg.tag(),
129        req.host(),
130    );
131
132    let io = match addr {
133        Either::Left(addr) => tcp_connect(addr, cfg).await?,
134        Either::Right(mut addrs) => loop {
135            let addr = addrs.pop_front().unwrap();
136
137            match tcp_connect(addr, cfg).await {
138                Ok(io) => break io,
139                Err(err) => {
140                    log::trace!(
141                        "{}: TCP connector - failed to connect to {:?} port: {port} err: {err:?}",
142                        cfg.tag(),
143                        req.host(),
144                    );
145                    if addrs.is_empty() {
146                        return Err(err.into());
147                    }
148                }
149            }
150        },
151    };
152
153    log::trace!(
154        "{}: TCP connector - successfully connected to {:?} - {:?}",
155        cfg.tag(),
156        req.host(),
157        io.query::<types::PeerAddr>().get()
158    );
159    Ok(io)
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    use ntex_util::time::Millis;

    /// End-to-end checks against a local test server. Covers the failure
    /// cases, a plain connect, fallback across several candidate addresses,
    /// and connecting by socket address.
    #[ntex::test]
    async fn test_connect() {
        let server = ntex::server::test_server(async || {
            ntex_service::fn_service(|_| async { Ok::<_, ()>(()) })
        });
        let server_addr = server.addr();

        // Service built through the factory with an explicit connect timeout.
        let configured = Connector::default()
            .create(
                SharedCfg::new("T")
                    .add(IoConfig::new().set_connect_timeout(Millis(5000)))
                    .into(),
            )
            .await
            .unwrap();
        // An empty host and an out-of-range port must both be rejected.
        assert!(configured.connect("").await.is_err());
        assert!(configured.connect("localhost:99999").await.is_err());
        assert!(format!("{configured:?}").contains("Connector"));

        // Default-configured service connecting by a "host:port" string.
        let plain = ConnectorService::default();
        assert!(plain.connect(server_addr.to_string()).await.is_ok());

        // The first candidate (port - 1) is expected to fail. The connector
        // must then fall back to the second candidate, the real server.
        let fallback: SocketAddr = format!("127.0.0.1:{}", server_addr.port() - 1)
            .parse()
            .unwrap();
        let msg = Connect::new(server_addr.to_string()).set_addrs(vec![fallback, server_addr]);
        assert!(crate::connect::connect(msg).await.is_ok());

        // Connecting directly by socket address.
        let msg = Connect::new(server_addr);
        assert!(crate::connect::connect(msg).await.is_ok());
    }
}