monoio_http_client/client/
connector.rs1use std::{
2 fmt::{Debug, Display},
3 future::Future,
4 hash::Hash,
5 io,
6 net::ToSocketAddrs,
7 path::Path,
8};
9
10use http::Version;
11use monoio::{
12 io::{AsyncReadRent, AsyncWriteRent, Split},
13 net::{TcpStream, UnixStream},
14};
15use monoio_http::h1::codec::ClientCodec;
16
17use super::{
18 connection::HttpConnection,
19 key::HttpVersion,
20 pool::{ConnectionPool, PooledConnection},
21 ClientGlobalConfig, ConnectionConfig, Proto,
22};
23
24#[cfg(not(feature = "native-tls"))]
25pub type TlsStream<C> = monoio_rustls::ClientTlsStream<C>;
26
27#[cfg(feature = "native-tls")]
28pub type TlsStream<C> = monoio_native_tls::TlsStream<C>;
29
/// Asynchronously establishes a connection identified by a key of type `K`.
///
/// Implemented in this module by plain transports (TCP, Unix sockets), by a
/// TLS layer wrapping another connector, and by a pooled HTTP connector.
pub trait Connector<K> {
    /// Value yielded when the connection succeeds.
    type Connection;
    /// Value yielded when the connection fails.
    type Error;

    /// Opens a connection for `key`.
    fn connect(&self, key: K) -> impl Future<Output = Result<Self::Connection, Self::Error>>;
}
36
37#[derive(Default, Clone, Debug)]
38pub struct TcpConnector;
39
40impl<T> Connector<T> for TcpConnector
41where
42 T: ToSocketAddrs,
43{
44 type Connection = TcpStream;
45 type Error = io::Error;
46
47 async fn connect(&self, key: T) -> Result<Self::Connection, Self::Error> {
48 TcpStream::connect(key).await.map(|io| {
49 let _ = io.set_nodelay(true);
51 io
52 })
53 }
54}
55
56#[derive(Default, Clone, Debug)]
57pub struct UnixConnector;
58
59impl<P> Connector<P> for UnixConnector
60where
61 P: AsRef<Path>,
62{
63 type Connection = UnixStream;
64 type Error = io::Error;
65
66 async fn connect(&self, key: P) -> Result<Self::Connection, Self::Error> {
67 UnixStream::connect(key).await
68 }
69}
70
71#[derive(Clone)]
72pub struct TlsConnector<C> {
73 inner_connector: C,
74 #[cfg(not(feature = "native-tls"))]
75 tls_connector: monoio_rustls::TlsConnector,
76 #[cfg(feature = "native-tls")]
77 tls_connector: monoio_native_tls::TlsConnector,
78}
79
80impl<C: Debug> std::fmt::Debug for TlsConnector<C> {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 write!(f, "TlsConnector, inner: {:?}", self.inner_connector)
83 }
84}
85
86impl<C: Default> Default for TlsConnector<C> {
87 #[cfg(not(feature = "native-tls"))]
88 fn default() -> Self {
89 let mut root_store = rustls::RootCertStore::empty();
90 root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
91 rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
92 ta.subject,
93 ta.spki,
94 ta.name_constraints,
95 )
96 }));
97
98 let cfg = rustls::ClientConfig::builder()
99 .with_safe_defaults()
100 .with_root_certificates(root_store)
101 .with_no_client_auth();
102
103 Self {
104 inner_connector: Default::default(),
105 tls_connector: cfg.into(),
106 }
107 }
108
109 #[cfg(feature = "native-tls")]
110 fn default() -> Self {
111 Self {
112 inner_connector: Default::default(),
113 tls_connector: native_tls::TlsConnector::builder().build().unwrap().into(),
114 }
115 }
116}
117
118#[cfg(not(feature = "native-tls"))]
119impl<C, T> Connector<T> for TlsConnector<C>
120where
121 T: service_async::Param<super::key::ServerName>,
122 C: Connector<T, Error = std::io::Error>,
123 C::Connection: AsyncReadRent + AsyncWriteRent,
124{
125 type Connection = TlsStream<C::Connection>;
126 type Error = monoio_rustls::TlsError;
127
128 async fn connect(&self, key: T) -> Result<Self::Connection, Self::Error> {
129 let server_name = key.param();
130
131 let stream = self.inner_connector.connect(key).await?;
132 let tls_stream = self.tls_connector.connect(server_name, stream).await?;
133 Ok(tls_stream)
134 }
135}
136
#[cfg(feature = "native-tls")]
impl<C, T> Connector<T> for TlsConnector<C>
where
    T: service_async::Param<super::key::ServerName>,
    C: Connector<T, Error = std::io::Error>,
    C::Connection: AsyncReadRent + AsyncWriteRent,
{
    type Connection = TlsStream<C::Connection>;
    type Error = monoio_native_tls::TlsError;

    /// Opens the inner transport for `key`, then runs the native-tls client
    /// handshake using the server name extracted from `key`.
    ///
    /// A transport `io::Error` is converted into `TlsError` by `?`.
    async fn connect(&self, key: T) -> Result<Self::Connection, Self::Error> {
        // Extract the server name first: `key` is moved into the inner connector.
        let name = key.param();
        let transport = self.inner_connector.connect(key).await?;
        self.tls_connector.connect(&name.0, transport).await
    }
}
154
155#[derive(Clone)]
156pub struct HttpConnector {
157 conn_config: ConnectionConfig,
158}
159
160impl HttpConnector {
161 pub fn new(conn_config: ConnectionConfig) -> Self {
162 Self { conn_config }
163 }
164
165 pub async fn connect<IO>(&self, io: IO, version: Version) -> crate::Result<HttpConnection<IO>>
166 where
167 IO: AsyncReadRent + AsyncWriteRent + Split + Unpin + 'static,
168 {
169 let proto = if self.conn_config.proto == Proto::Auto {
170 version } else {
172 match self.conn_config.proto {
173 Proto::Http1 => Version::HTTP_11,
174 Proto::Http2 => Version::HTTP_2,
175 Proto::Auto => unreachable!(),
176 }
177 };
178
179 match proto {
180 Version::HTTP_11 => Ok(HttpConnection::H1(ClientCodec::new(io))),
181 Version::HTTP_2 => {
182 let (send_request, h2_conn) = self.conn_config.h2_builder.handshake(io).await?;
183 monoio::spawn(async move {
184 if let Err(e) = h2_conn.await {
185 println!("H2 CONN ERR={:?}", e);
186 }
187 });
188 Ok(HttpConnection::H2(send_request))
189 }
190 _ => {
191 unreachable!()
192 }
193 }
194 }
195}
196
197pub struct PooledConnector<TC, K, IO: AsyncWriteRent> {
201 global_config: ClientGlobalConfig,
202 transport_connector: TC,
203 http_connector: HttpConnector,
204 pool: ConnectionPool<K, IO>,
205}
206
207impl<TC: Clone, K, IO: AsyncWriteRent> Clone for PooledConnector<TC, K, IO> {
208 fn clone(&self) -> Self {
209 Self {
210 global_config: self.global_config.clone(),
211 transport_connector: self.transport_connector.clone(),
212 http_connector: self.http_connector.clone(),
213 pool: self.pool.clone(),
214 }
215 }
216}
217
218impl<TC, K, IO: AsyncWriteRent> std::fmt::Debug for PooledConnector<TC, K, IO> {
219 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220 write!(f, "PooledConnector")
221 }
222}
223
224impl<TC, K: 'static, IO: AsyncWriteRent + 'static> PooledConnector<TC, K, IO>
225where
226 TC: Default,
227{
228 pub fn new_default(global_config: ClientGlobalConfig, c_config: ConnectionConfig) -> Self {
229 Self {
230 global_config,
231 transport_connector: Default::default(),
232 http_connector: HttpConnector::new(c_config),
233 pool: ConnectionPool::default(),
234 }
235 }
236}
237
238impl<TC, K: 'static, IO: AsyncWriteRent + 'static> PooledConnector<TC, K, IO> {
239 pub fn new(
240 global_config: ClientGlobalConfig,
241 c_config: ConnectionConfig,
242 connector: TC,
243 ) -> Self {
244 Self {
245 global_config,
246 transport_connector: connector,
247 http_connector: HttpConnector::new(c_config),
248 pool: ConnectionPool::default(),
249 }
250 }
251}
252
253impl<TC, K, IO> Connector<K> for PooledConnector<TC, K, IO>
254where
255 K: ToSocketAddrs + Hash + Eq + ToOwned<Owned = K> + Display + HttpVersion + 'static,
256 TC: Connector<K, Connection = IO>,
257 IO: AsyncReadRent + AsyncWriteRent + Split + Unpin + 'static,
258 crate::Error: From<<TC as Connector<K>>::Error>,
259{
260 type Connection = PooledConnection<K, IO>;
261 type Error = crate::Error;
262
263 async fn connect(&self, key: K) -> Result<Self::Connection, Self::Error> {
264 if let Some(conn) = self.pool.get(&key) {
265 return Ok(conn);
266 }
267 let key_owned = key.to_owned();
268 let io = self.transport_connector.connect(key).await?;
269
270 let pipe = self
271 .http_connector
272 .connect(io, key_owned.get_version())
273 .await?;
274 Ok(self.pool.link(key_owned, pipe))
275 }
276}