1use std::fmt;
2use std::marker::PhantomData;
3use std::time::Duration;
4
5use requiem_codec::{AsyncRead, AsyncWrite};
6use requiem_connect::{
7 default_connector, Connect as TcpConnect, Connection as TcpConnection,
8};
9use requiem_rt::net::TcpStream;
10use requiem_service::{apply_fn, Service};
11use requiem_utils::timeout::{TimeoutError, TimeoutService};
12use http::Uri;
13
14use super::connection::Connection;
15use super::error::ConnectError;
16use super::pool::{ConnectionPool, Protocol};
17use super::Connect;
18
19#[cfg(feature = "openssl")]
20use requiem_connect::ssl::openssl::SslConnector as OpensslConnector;
21
22#[cfg(feature = "rustls")]
23use requiem_connect::ssl::rustls::ClientConfig;
24#[cfg(feature = "rustls")]
25use std::sync::Arc;
26
/// TLS backend used for `https`/`wss` connections.
///
/// `Connector::new` builds the OpenSSL variant when the `openssl` feature is enabled and only
/// falls back to rustls when `openssl` is disabled; the builder methods can replace it.
#[cfg(any(feature = "openssl", feature = "rustls"))]
enum SslConnector {
    /// OpenSSL-backed connector.
    #[cfg(feature = "openssl")]
    Openssl(OpensslConnector),

    /// Shared rustls client configuration.
    #[cfg(feature = "rustls")]
    Rustls(Arc<ClientConfig>),
}

/// Data-less stand-in used when no TLS feature is compiled in.
#[cfg(not(any(feature = "openssl", feature = "rustls")))]
type SslConnector = ();
36
37pub struct Connector<T, U> {
50 connector: T,
51 timeout: Duration,
52 conn_lifetime: Duration,
53 conn_keep_alive: Duration,
54 disconnect_timeout: Duration,
55 limit: usize,
56 #[allow(dead_code)]
57 ssl: SslConnector,
58 _t: PhantomData<U>,
59}
60
61trait Io: AsyncRead + AsyncWrite + Unpin {}
62impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
63
64impl Connector<(), ()> {
65 #[allow(clippy::new_ret_no_self, clippy::let_unit_value)]
66 pub fn new() -> Connector<
67 impl Service<
68 Request = TcpConnect<Uri>,
69 Response = TcpConnection<Uri, TcpStream>,
70 Error = requiem_connect::ConnectError,
71 > + Clone,
72 TcpStream,
73 > {
74 let ssl = {
75 #[cfg(feature = "openssl")]
76 {
77 use requiem_connect::ssl::openssl::SslMethod;
78
79 let mut ssl = OpensslConnector::builder(SslMethod::tls()).unwrap();
80 let _ = ssl
81 .set_alpn_protos(b"\x02h2\x08http/1.1")
82 .map_err(|e| error!("Can not set alpn protocol: {:?}", e));
83 SslConnector::Openssl(ssl.build())
84 }
85 #[cfg(all(not(feature = "openssl"), feature = "rustls"))]
86 {
87 let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
88 let mut config = ClientConfig::new();
89 config.set_protocols(&protos);
90 config
91 .root_store
92 .add_server_trust_anchors(&requiem_tls::rustls::TLS_SERVER_ROOTS);
93 SslConnector::Rustls(Arc::new(config))
94 }
95 #[cfg(not(any(feature = "openssl", feature = "rustls")))]
96 {}
97 };
98
99 Connector {
100 ssl,
101 connector: default_connector(),
102 timeout: Duration::from_secs(1),
103 conn_lifetime: Duration::from_secs(75),
104 conn_keep_alive: Duration::from_secs(15),
105 disconnect_timeout: Duration::from_millis(3000),
106 limit: 100,
107 _t: PhantomData,
108 }
109 }
110}
111
112impl<T, U> Connector<T, U> {
113 pub fn connector<T1, U1>(self, connector: T1) -> Connector<T1, U1>
115 where
116 U1: AsyncRead + AsyncWrite + Unpin + fmt::Debug,
117 T1: Service<
118 Request = TcpConnect<Uri>,
119 Response = TcpConnection<Uri, U1>,
120 Error = requiem_connect::ConnectError,
121 > + Clone,
122 {
123 Connector {
124 connector,
125 timeout: self.timeout,
126 conn_lifetime: self.conn_lifetime,
127 conn_keep_alive: self.conn_keep_alive,
128 disconnect_timeout: self.disconnect_timeout,
129 limit: self.limit,
130 ssl: self.ssl,
131 _t: PhantomData,
132 }
133 }
134}
135
136impl<T, U> Connector<T, U>
137where
138 U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
139 T: Service<
140 Request = TcpConnect<Uri>,
141 Response = TcpConnection<Uri, U>,
142 Error = requiem_connect::ConnectError,
143 > + Clone
144 + 'static,
145{
146 pub fn timeout(mut self, timeout: Duration) -> Self {
149 self.timeout = timeout;
150 self
151 }
152
153 #[cfg(feature = "openssl")]
154 pub fn ssl(mut self, connector: OpensslConnector) -> Self {
156 self.ssl = SslConnector::Openssl(connector);
157 self
158 }
159
160 #[cfg(feature = "rustls")]
161 pub fn rustls(mut self, connector: Arc<ClientConfig>) -> Self {
162 self.ssl = SslConnector::Rustls(connector);
163 self
164 }
165
166 pub fn limit(mut self, limit: usize) -> Self {
171 self.limit = limit;
172 self
173 }
174
175 pub fn conn_keep_alive(mut self, dur: Duration) -> Self {
182 self.conn_keep_alive = dur;
183 self
184 }
185
186 pub fn conn_lifetime(mut self, dur: Duration) -> Self {
192 self.conn_lifetime = dur;
193 self
194 }
195
196 pub fn disconnect_timeout(mut self, dur: Duration) -> Self {
205 self.disconnect_timeout = dur;
206 self
207 }
208
209 pub fn finish(
213 self,
214 ) -> impl Service<Request = Connect, Response = impl Connection, Error = ConnectError>
215 + Clone {
216 #[cfg(not(any(feature = "openssl", feature = "rustls")))]
217 {
218 let connector = TimeoutService::new(
219 self.timeout,
220 apply_fn(self.connector, |msg: Connect, srv| {
221 srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr))
222 })
223 .map_err(ConnectError::from)
224 .map(|stream| (stream.into_parts().0, Protocol::Http1)),
225 )
226 .map_err(|e| match e {
227 TimeoutError::Service(e) => e,
228 TimeoutError::Timeout => ConnectError::Timeout,
229 });
230
231 connect_impl::InnerConnector {
232 tcp_pool: ConnectionPool::new(
233 connector,
234 self.conn_lifetime,
235 self.conn_keep_alive,
236 None,
237 self.limit,
238 ),
239 }
240 }
241 #[cfg(any(feature = "openssl", feature = "rustls"))]
242 {
243 const H2: &[u8] = b"h2";
244 #[cfg(feature = "openssl")]
245 use requiem_connect::ssl::openssl::OpensslConnector;
246 #[cfg(feature = "rustls")]
247 use requiem_connect::ssl::rustls::{RustlsConnector, Session};
248 use requiem_service::{boxed::service, pipeline};
249
250 let ssl_service = TimeoutService::new(
251 self.timeout,
252 pipeline(
253 apply_fn(self.connector.clone(), |msg: Connect, srv| {
254 srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr))
255 })
256 .map_err(ConnectError::from),
257 )
258 .and_then(match self.ssl {
259 #[cfg(feature = "openssl")]
260 SslConnector::Openssl(ssl) => service(
261 OpensslConnector::service(ssl)
262 .map(|stream| {
263 let sock = stream.into_parts().0;
264 let h2 = sock
265 .ssl()
266 .selected_alpn_protocol()
267 .map(|protos| protos.windows(2).any(|w| w == H2))
268 .unwrap_or(false);
269 if h2 {
270 (Box::new(sock) as Box<dyn Io>, Protocol::Http2)
271 } else {
272 (Box::new(sock) as Box<dyn Io>, Protocol::Http1)
273 }
274 })
275 .map_err(ConnectError::from),
276 ),
277 #[cfg(feature = "rustls")]
278 SslConnector::Rustls(ssl) => service(
279 RustlsConnector::service(ssl)
280 .map_err(ConnectError::from)
281 .map(|stream| {
282 let sock = stream.into_parts().0;
283 let h2 = sock
284 .get_ref()
285 .1
286 .get_alpn_protocol()
287 .map(|protos| protos.windows(2).any(|w| w == H2))
288 .unwrap_or(false);
289 if h2 {
290 (Box::new(sock) as Box<dyn Io>, Protocol::Http2)
291 } else {
292 (Box::new(sock) as Box<dyn Io>, Protocol::Http1)
293 }
294 }),
295 ),
296 }),
297 )
298 .map_err(|e| match e {
299 TimeoutError::Service(e) => e,
300 TimeoutError::Timeout => ConnectError::Timeout,
301 });
302
303 let tcp_service = TimeoutService::new(
304 self.timeout,
305 apply_fn(self.connector, |msg: Connect, srv| {
306 srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr))
307 })
308 .map_err(ConnectError::from)
309 .map(|stream| (stream.into_parts().0, Protocol::Http1)),
310 )
311 .map_err(|e| match e {
312 TimeoutError::Service(e) => e,
313 TimeoutError::Timeout => ConnectError::Timeout,
314 });
315
316 connect_impl::InnerConnector {
317 tcp_pool: ConnectionPool::new(
318 tcp_service,
319 self.conn_lifetime,
320 self.conn_keep_alive,
321 None,
322 self.limit,
323 ),
324 ssl_pool: ConnectionPool::new(
325 ssl_service,
326 self.conn_lifetime,
327 self.conn_keep_alive,
328 Some(self.disconnect_timeout),
329 self.limit,
330 ),
331 }
332 }
333 }
334}
335
336#[cfg(not(any(feature = "openssl", feature = "rustls")))]
337mod connect_impl {
338 use std::task::{Context, Poll};
339
340 use futures_util::future::{err, Either, Ready};
341
342 use super::*;
343 use crate::client::connection::IoConnection;
344
345 pub(crate) struct InnerConnector<T, Io>
346 where
347 Io: AsyncRead + AsyncWrite + Unpin + 'static,
348 T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
349 + 'static,
350 {
351 pub(crate) tcp_pool: ConnectionPool<T, Io>,
352 }
353
354 impl<T, Io> Clone for InnerConnector<T, Io>
355 where
356 Io: AsyncRead + AsyncWrite + Unpin + 'static,
357 T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
358 + 'static,
359 {
360 fn clone(&self) -> Self {
361 InnerConnector {
362 tcp_pool: self.tcp_pool.clone(),
363 }
364 }
365 }
366
367 impl<T, Io> Service for InnerConnector<T, Io>
368 where
369 Io: AsyncRead + AsyncWrite + Unpin + 'static,
370 T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
371 + 'static,
372 {
373 type Request = Connect;
374 type Response = IoConnection<Io>;
375 type Error = ConnectError;
376 type Future = Either<
377 <ConnectionPool<T, Io> as Service>::Future,
378 Ready<Result<IoConnection<Io>, ConnectError>>,
379 >;
380
381 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
382 self.tcp_pool.poll_ready(cx)
383 }
384
385 fn call(&mut self, req: Connect) -> Self::Future {
386 match req.uri.scheme_str() {
387 Some("https") | Some("wss") => {
388 Either::Right(err(ConnectError::SslIsNotSupported))
389 }
390 _ => Either::Left(self.tcp_pool.call(req)),
391 }
392 }
393 }
394}
395
#[cfg(any(feature = "openssl", feature = "rustls"))]
mod connect_impl {
    //! Connection service used when a TLS backend is compiled in: dispatches each request to a
    //! plain TCP pool or a TLS pool based on the URI scheme.

    use std::future::Future;
    use std::marker::PhantomData;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use futures_core::ready;
    use futures_util::future::Either;

    use super::*;
    use crate::client::connection::EitherConnection;

    /// Routes `https`/`wss` requests to `ssl_pool` and all others to `tcp_pool`.
    pub(crate) struct InnerConnector<T1, T2, Io1, Io2>
    where
        Io1: AsyncRead + AsyncWrite + Unpin + 'static,
        Io2: AsyncRead + AsyncWrite + Unpin + 'static,
        T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>,
        T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>,
    {
        /// Pool for requests whose scheme is neither `https` nor `wss`.
        pub(crate) tcp_pool: ConnectionPool<T1, Io1>,
        /// Pool for `https`/`wss` requests.
        pub(crate) ssl_pool: ConnectionPool<T2, Io2>,
    }

    impl<T1, T2, Io1, Io2> Clone for InnerConnector<T1, T2, Io1, Io2>
    where
        Io1: AsyncRead + AsyncWrite + Unpin + 'static,
        Io2: AsyncRead + AsyncWrite + Unpin + 'static,
        T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
            + 'static,
        T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
            + 'static,
    {
        fn clone(&self) -> Self {
            InnerConnector {
                tcp_pool: self.tcp_pool.clone(),
                ssl_pool: self.ssl_pool.clone(),
            }
        }
    }

    impl<T1, T2, Io1, Io2> Service for InnerConnector<T1, T2, Io1, Io2>
    where
        Io1: AsyncRead + AsyncWrite + Unpin + 'static,
        Io2: AsyncRead + AsyncWrite + Unpin + 'static,
        T1: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
            + 'static,
        T2: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
            + 'static,
    {
        type Request = Connect;
        type Response = EitherConnection<Io1, Io2>;
        type Error = ConnectError;
        type Future = Either<
            InnerConnectorResponseA<T1, Io1, Io2>,
            InnerConnectorResponseB<T2, Io1, Io2>,
        >;

        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            // `call` may dispatch to either pool, so both must report ready before a request
            // is accepted. A pending or failed pool short-circuits.
            ready!(self.tcp_pool.poll_ready(cx))?;
            ready!(self.ssl_pool.poll_ready(cx))?;
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: Connect) -> Self::Future {
            match req.uri.scheme_str() {
                Some("https") | Some("wss") => Either::Right(InnerConnectorResponseB {
                    fut: self.ssl_pool.call(req),
                    _t: PhantomData,
                }),
                _ => Either::Left(InnerConnectorResponseA {
                    fut: self.tcp_pool.call(req),
                    _t: PhantomData,
                }),
            }
        }
    }

    /// Future for a plain TCP connection, wrapped as `EitherConnection::A`.
    #[pin_project::pin_project]
    pub(crate) struct InnerConnectorResponseA<T, Io1, Io2>
    where
        Io1: AsyncRead + AsyncWrite + Unpin + 'static,
        T: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
            + 'static,
    {
        #[pin]
        fut: <ConnectionPool<T, Io1> as Service>::Future,
        _t: PhantomData<Io2>,
    }

    impl<T, Io1, Io2> Future for InnerConnectorResponseA<T, Io1, Io2>
    where
        T: Service<Request = Connect, Response = (Io1, Protocol), Error = ConnectError>
            + 'static,
        Io1: AsyncRead + AsyncWrite + Unpin + 'static,
        Io2: AsyncRead + AsyncWrite + Unpin + 'static,
    {
        type Output = Result<EitherConnection<Io1, Io2>, ConnectError>;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // Use the pin projection instead of `get_mut`, so `Self: Unpin` is not required.
            self.project().fut.poll(cx).map_ok(EitherConnection::A)
        }
    }

    /// Future for a TLS connection, wrapped as `EitherConnection::B`.
    #[pin_project::pin_project]
    pub(crate) struct InnerConnectorResponseB<T, Io1, Io2>
    where
        Io2: AsyncRead + AsyncWrite + Unpin + 'static,
        T: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
            + 'static,
    {
        #[pin]
        fut: <ConnectionPool<T, Io2> as Service>::Future,
        _t: PhantomData<Io1>,
    }

    impl<T, Io1, Io2> Future for InnerConnectorResponseB<T, Io1, Io2>
    where
        T: Service<Request = Connect, Response = (Io2, Protocol), Error = ConnectError>
            + 'static,
        Io1: AsyncRead + AsyncWrite + Unpin + 'static,
        Io2: AsyncRead + AsyncWrite + Unpin + 'static,
    {
        type Output = Result<EitherConnection<Io1, Io2>, ConnectError>;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // Use the pin projection instead of `get_mut`, so `Self: Unpin` is not required.
            self.project().fut.poll(cx).map_ok(EitherConnection::B)
        }
    }
}