// monoio_http_client/client/connector.rs

1use std::{
2    fmt::{Debug, Display},
3    future::Future,
4    hash::Hash,
5    io,
6    net::ToSocketAddrs,
7    path::Path,
8};
9
10use http::Version;
11use monoio::{
12    io::{AsyncReadRent, AsyncWriteRent, Split},
13    net::{TcpStream, UnixStream},
14};
15use monoio_http::h1::codec::ClientCodec;
16
17use super::{
18    connection::HttpConnection,
19    key::HttpVersion,
20    pool::{ConnectionPool, PooledConnection},
21    ClientGlobalConfig, ConnectionConfig, Proto,
22};
23
24#[cfg(not(feature = "native-tls"))]
25pub type TlsStream<C> = monoio_rustls::ClientTlsStream<C>;
26
27#[cfg(feature = "native-tls")]
28pub type TlsStream<C> = monoio_native_tls::TlsStream<C>;
29
/// Something that can asynchronously open a connection for a key `K`.
///
/// Implementors decide what the key means (a socket address, a filesystem
/// path, a pool key, ...) and what kind of connection they hand back.
pub trait Connector<K> {
    /// Error reported when a connection cannot be established.
    type Error;
    /// Connection value yielded on success.
    type Connection;

    /// Opens a connection for `key`, consuming it.
    fn connect(&self, key: K) -> impl Future<Output = Result<Self::Connection, Self::Error>>;
}
36
37#[derive(Default, Clone, Debug)]
38pub struct TcpConnector;
39
40impl<T> Connector<T> for TcpConnector
41where
42    T: ToSocketAddrs,
43{
44    type Connection = TcpStream;
45    type Error = io::Error;
46
47    async fn connect(&self, key: T) -> Result<Self::Connection, Self::Error> {
48        TcpStream::connect(key).await.map(|io| {
49            // we will ignore the set nodelay error
50            let _ = io.set_nodelay(true);
51            io
52        })
53    }
54}
55
56#[derive(Default, Clone, Debug)]
57pub struct UnixConnector;
58
59impl<P> Connector<P> for UnixConnector
60where
61    P: AsRef<Path>,
62{
63    type Connection = UnixStream;
64    type Error = io::Error;
65
66    async fn connect(&self, key: P) -> Result<Self::Connection, Self::Error> {
67        UnixStream::connect(key).await
68    }
69}
70
71#[derive(Clone)]
72pub struct TlsConnector<C> {
73    inner_connector: C,
74    #[cfg(not(feature = "native-tls"))]
75    tls_connector: monoio_rustls::TlsConnector,
76    #[cfg(feature = "native-tls")]
77    tls_connector: monoio_native_tls::TlsConnector,
78}
79
80impl<C: Debug> std::fmt::Debug for TlsConnector<C> {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        write!(f, "TlsConnector, inner: {:?}", self.inner_connector)
83    }
84}
85
86impl<C: Default> Default for TlsConnector<C> {
87    #[cfg(not(feature = "native-tls"))]
88    fn default() -> Self {
89        let mut root_store = rustls::RootCertStore::empty();
90        root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
91            rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
92                ta.subject,
93                ta.spki,
94                ta.name_constraints,
95            )
96        }));
97
98        let cfg = rustls::ClientConfig::builder()
99            .with_safe_defaults()
100            .with_root_certificates(root_store)
101            .with_no_client_auth();
102
103        Self {
104            inner_connector: Default::default(),
105            tls_connector: cfg.into(),
106        }
107    }
108
109    #[cfg(feature = "native-tls")]
110    fn default() -> Self {
111        Self {
112            inner_connector: Default::default(),
113            tls_connector: native_tls::TlsConnector::builder().build().unwrap().into(),
114        }
115    }
116}
117
118#[cfg(not(feature = "native-tls"))]
119impl<C, T> Connector<T> for TlsConnector<C>
120where
121    T: service_async::Param<super::key::ServerName>,
122    C: Connector<T, Error = std::io::Error>,
123    C::Connection: AsyncReadRent + AsyncWriteRent,
124{
125    type Connection = TlsStream<C::Connection>;
126    type Error = monoio_rustls::TlsError;
127
128    async fn connect(&self, key: T) -> Result<Self::Connection, Self::Error> {
129        let server_name = key.param();
130
131        let stream = self.inner_connector.connect(key).await?;
132        let tls_stream = self.tls_connector.connect(server_name, stream).await?;
133        Ok(tls_stream)
134    }
135}
136
/// native-tls backend: open the inner stream, then run a TLS client handshake over it.
#[cfg(feature = "native-tls")]
impl<C, T> Connector<T> for TlsConnector<C>
where
    T: service_async::Param<super::key::ServerName>,
    C: Connector<T, Error = std::io::Error>,
    C::Connection: AsyncReadRent + AsyncWriteRent,
{
    type Connection = TlsStream<C::Connection>;
    type Error = monoio_native_tls::TlsError;

    async fn connect(&self, key: T) -> Result<Self::Connection, Self::Error> {
        // Capture the server name before `key` is consumed by the inner connector.
        let name = key.param();
        let transport = self.inner_connector.connect(key).await?;
        // The handshake's error type already matches `Self::Error`, so it is returned as-is.
        self.tls_connector.connect(&name.0, transport).await
    }
}
154
155#[derive(Clone)]
156pub struct HttpConnector {
157    conn_config: ConnectionConfig,
158}
159
160impl HttpConnector {
161    pub fn new(conn_config: ConnectionConfig) -> Self {
162        Self { conn_config }
163    }
164
165    pub async fn connect<IO>(&self, io: IO, version: Version) -> crate::Result<HttpConnection<IO>>
166    where
167        IO: AsyncReadRent + AsyncWriteRent + Split + Unpin + 'static,
168    {
169        let proto = if self.conn_config.proto == Proto::Auto {
170            version // Use version from the header
171        } else {
172            match self.conn_config.proto {
173                Proto::Http1 => Version::HTTP_11,
174                Proto::Http2 => Version::HTTP_2,
175                Proto::Auto => unreachable!(),
176            }
177        };
178
179        match proto {
180            Version::HTTP_11 => Ok(HttpConnection::H1(ClientCodec::new(io))),
181            Version::HTTP_2 => {
182                let (send_request, h2_conn) = self.conn_config.h2_builder.handshake(io).await?;
183                monoio::spawn(async move {
184                    if let Err(e) = h2_conn.await {
185                        println!("H2 CONN ERR={:?}", e);
186                    }
187                });
188                Ok(HttpConnection::H2(send_request))
189            }
190            _ => {
191                unreachable!()
192            }
193        }
194    }
195}
196
197/// PooledConnector does 2 things:
198/// 1. pool
199/// 2. combine connection with codec(of cause with buffer)
200pub struct PooledConnector<TC, K, IO: AsyncWriteRent> {
201    global_config: ClientGlobalConfig,
202    transport_connector: TC,
203    http_connector: HttpConnector,
204    pool: ConnectionPool<K, IO>,
205}
206
207impl<TC: Clone, K, IO: AsyncWriteRent> Clone for PooledConnector<TC, K, IO> {
208    fn clone(&self) -> Self {
209        Self {
210            global_config: self.global_config.clone(),
211            transport_connector: self.transport_connector.clone(),
212            http_connector: self.http_connector.clone(),
213            pool: self.pool.clone(),
214        }
215    }
216}
217
218impl<TC, K, IO: AsyncWriteRent> std::fmt::Debug for PooledConnector<TC, K, IO> {
219    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220        write!(f, "PooledConnector")
221    }
222}
223
224impl<TC, K: 'static, IO: AsyncWriteRent + 'static> PooledConnector<TC, K, IO>
225where
226    TC: Default,
227{
228    pub fn new_default(global_config: ClientGlobalConfig, c_config: ConnectionConfig) -> Self {
229        Self {
230            global_config,
231            transport_connector: Default::default(),
232            http_connector: HttpConnector::new(c_config),
233            pool: ConnectionPool::default(),
234        }
235    }
236}
237
238impl<TC, K: 'static, IO: AsyncWriteRent + 'static> PooledConnector<TC, K, IO> {
239    pub fn new(
240        global_config: ClientGlobalConfig,
241        c_config: ConnectionConfig,
242        connector: TC,
243    ) -> Self {
244        Self {
245            global_config,
246            transport_connector: connector,
247            http_connector: HttpConnector::new(c_config),
248            pool: ConnectionPool::default(),
249        }
250    }
251}
252
253impl<TC, K, IO> Connector<K> for PooledConnector<TC, K, IO>
254where
255    K: ToSocketAddrs + Hash + Eq + ToOwned<Owned = K> + Display + HttpVersion + 'static,
256    TC: Connector<K, Connection = IO>,
257    IO: AsyncReadRent + AsyncWriteRent + Split + Unpin + 'static,
258    crate::Error: From<<TC as Connector<K>>::Error>,
259{
260    type Connection = PooledConnection<K, IO>;
261    type Error = crate::Error;
262
263    async fn connect(&self, key: K) -> Result<Self::Connection, Self::Error> {
264        if let Some(conn) = self.pool.get(&key) {
265            return Ok(conn);
266        }
267        let key_owned = key.to_owned();
268        let io = self.transport_connector.connect(key).await?;
269
270        let pipe = self
271            .http_connector
272            .connect(io, key_owned.get_version())
273            .await?;
274        Ok(self.pool.link(key_owned, pipe))
275    }
276}