1use std::{cell::UnsafeCell, collections::HashMap, rc::Rc, time::Duration};
2
3use monoio::io::{AsyncReadRent, AsyncWriteRent, Split};
4use monoio_http::{h1::codec::ClientCodec, h2::client::Builder as MonoioH2Builder};
5
6use super::connection::{Http1Connection, Http2Connection, HttpConnection};
7use crate::{
8 connectors::{Connector, TcpConnector, TlsConnector, TransportConnMeta, TransportConnMetadata},
9 pool::{ConnectionPool, Key, Pooled},
10};
11
/// Which HTTP version(s) an `HttpConnector` is allowed to use.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
enum Protocol {
    /// Every new connection performs an HTTP/2 handshake; the HTTP/1.1 pool is
    /// never consulted.
    HTTP2,
    /// New connections use HTTP/1.1 unless ALPN negotiated `h2`; the HTTP/2 pool
    /// is not consulted before connecting.
    HTTP11,
    /// Reuse pooled connections of either version; new connections use HTTP/2
    /// when ALPN negotiated `h2` and HTTP/1.1 otherwise.
    #[default]
    Auto,
}
19
20pub struct HttpConnector<C, K, IO: AsyncWriteRent> {
44 connector: C,
45 protocol: Protocol, h1_pool: Option<ConnectionPool<K, Http1Connection<IO>>>,
47 h2_pool: ConnectionPool<K, Http2Connection>,
48 connecting: UnsafeCell<HashMap<K, Rc<local_sync::semaphore::Semaphore>>>,
49 h2_builder: MonoioH2Builder,
50 pub read_timeout: Option<Duration>,
51}
52
53impl<C: Clone, K, IO: AsyncWriteRent> Clone for HttpConnector<C, K, IO> {
54 fn clone(&self) -> Self {
55 Self {
56 connector: self.connector.clone(),
57 h1_pool: self.h1_pool.clone(),
58 h2_pool: self.h2_pool.clone(),
59 protocol: self.protocol,
60 connecting: UnsafeCell::new(HashMap::new()),
61 read_timeout: self.read_timeout,
62 h2_builder: self.h2_builder.clone(),
63 }
64 }
65}
66
67impl<C, K: 'static, IO: AsyncWriteRent + 'static> HttpConnector<C, K, IO> {
68 #[inline]
69 pub fn new(connector: C) -> Self {
70 Self {
71 connector,
72 protocol: Protocol::default(),
73 h1_pool: Some(ConnectionPool::default()),
74 h2_pool: ConnectionPool::new(None),
75 connecting: UnsafeCell::new(HashMap::new()),
76 h2_builder: MonoioH2Builder::default(),
77 read_timeout: None,
78 }
79 }
80
81 #[inline]
85 #[allow(unused)]
86 pub fn set_read_timeout(&mut self, timeout: Option<Duration>) {
87 self.read_timeout = timeout;
88 }
89
90 pub fn set_http1_only(&mut self) {
102 self.protocol = Protocol::HTTP11
103 }
104
105 pub fn set_http2_only(&mut self) {
117 self.protocol = Protocol::HTTP2
118 }
119
120 #[inline]
121 pub fn h2_builder(&mut self) -> &mut MonoioH2Builder {
122 &mut self.h2_builder
123 }
124
125 fn is_config_h2(&self) -> bool {
126 matches!(self.protocol, Protocol::HTTP2)
127 }
128
129 fn is_config_h1(&self) -> bool {
130 matches!(self.protocol, Protocol::HTTP11)
131 }
132
133 fn is_config_auto(&self) -> bool {
134 matches!(self.protocol, Protocol::Auto)
135 }
136
137 pub fn transfer_pool(old: &Self, new: &mut Self) -> Result<(), &'static str> {
159 if old.protocol != new.protocol {
160 return Err("Protocols do not match");
161 }
162 if old.read_timeout != new.read_timeout {
163 return Err("Read timeouts do not match");
164 }
165
166 new.h1_pool = old.h1_pool.clone();
167 new.h2_pool = old.h2_pool.clone();
168
169 Ok(())
170 }
171}
172
173impl<K: 'static, IO: AsyncWriteRent + 'static> HttpConnector<TcpConnector, K, IO> {
174 pub fn build_tcp_http1_only() -> Self {
184 Self {
185 connector: TcpConnector::default(),
186 protocol: Protocol::HTTP11,
187 h1_pool: Some(ConnectionPool::default()),
188 h2_pool: ConnectionPool::new(None),
189 connecting: UnsafeCell::new(HashMap::new()),
190 h2_builder: MonoioH2Builder::default(),
191 read_timeout: None,
192 }
193 }
194
195 pub fn build_tcp_http2_only() -> Self {
205 Self {
206 connector: TcpConnector::default(),
207 protocol: Protocol::HTTP2,
208 h1_pool: Some(ConnectionPool::default()),
209 h2_pool: ConnectionPool::new(None),
210 connecting: UnsafeCell::new(HashMap::new()),
211 h2_builder: MonoioH2Builder::default(),
212 read_timeout: None,
213 }
214 }
215}
216
217impl<C: Default, K: 'static, IO: AsyncWriteRent + 'static> HttpConnector<TlsConnector<C>, K, IO> {
218 pub fn build_tls_http1_only() -> Self {
228 let alpn = vec!["http/1.1"];
229 let tls_connector = TlsConnector::new_with_tls_default(C::default(), Some(alpn));
230 Self {
231 connector: tls_connector,
232 protocol: Protocol::default(),
233 h1_pool: Some(ConnectionPool::default()),
234 h2_pool: ConnectionPool::new(None),
235 connecting: UnsafeCell::new(HashMap::new()),
236 h2_builder: MonoioH2Builder::default(),
237 read_timeout: None,
238 }
239 }
240
241 pub fn build_tls_http2_only() -> Self {
251 let alpn = vec!["h2"];
252 let tls_connector = TlsConnector::new_with_tls_default(C::default(), Some(alpn));
253 Self {
254 connector: tls_connector,
255 protocol: Protocol::default(),
256 h1_pool: Some(ConnectionPool::default()),
257 h2_pool: ConnectionPool::new(None),
258 connecting: UnsafeCell::new(HashMap::new()),
259 h2_builder: MonoioH2Builder::default(),
260 read_timeout: None,
261 }
262 }
263}
264
265impl<C: Default, K: 'static, IO: AsyncWriteRent + 'static> Default for HttpConnector<C, K, IO> {
266 #[inline]
268 fn default() -> Self {
269 HttpConnector::new(C::default())
270 }
271}
272
/// Looks up a reusable connection for `$key` in the pool field `$self.$pool`.
///
/// Idle entries whose connection reports an error through `conn_error()` are
/// first evicted from that key's entries. The macro then evaluates to a clone
/// of the first remaining connection, or `None` if there is none.
macro_rules! try_get {
    ($self:ident, $pool:ident, $key:ident) => {
        $self.$pool.and_then_mut(&$key, |mut conns| {
            // Evict broken connections so they are never handed out again.
            // Diagnostics go through `tracing` only; library code must not
            // write to stdout.
            conns.retain(|idle| match idle.conn.conn_error() {
                Some(_e) => {
                    #[cfg(feature = "logging")]
                    tracing::debug!("Removing invalid connection: {:?}", _e);
                    false
                }
                None => true,
            });

            conns.front().map(|idle| idle.conn.to_owned())
        })
    };
}
293
294impl<C, K: Key, IO> Connector<K> for HttpConnector<C, K, IO>
295where
296 C: Connector<K, Connection = IO>,
297 C::Connection: TransportConnMetadata<Metadata = TransportConnMeta>,
298 crate::TransportError: From<C::Error>,
299 IO: AsyncReadRent + AsyncWriteRent + Split + Unpin + 'static,
300{
301 type Connection = HttpConnection<K, IO>;
302 type Error = crate::TransportError;
303
304 async fn connect(&self, key: K) -> Result<Self::Connection, Self::Error> {
305 if self.is_config_auto() || self.is_config_h2() {
306 if let Some(conn) = try_get!(self, h2_pool, key) {
307 return Ok(conn.into());
308 }
309 }
310
311 if self.is_config_auto() || self.is_config_h1() {
312 if let Some(h1_pool) = &self.h1_pool {
313 if let Some(h1_pooled) = h1_pool.get(&key) {
314 return Ok(h1_pooled.into());
315 }
316 }
317 }
318
319 let transport_conn = self.connector.connect(key.clone()).await?;
321 let conn_meta = transport_conn.get_conn_metadata();
322
323 let connect_to_h2 = self.is_config_h2() || conn_meta.is_alpn_h2();
324
325 if connect_to_h2 {
326 let lock = {
327 let connecting = unsafe { &mut *self.connecting.get() };
328 let lock = connecting
329 .entry(key.clone())
330 .or_insert_with(|| Rc::new(local_sync::semaphore::Semaphore::new(1)));
331 lock.clone()
332 };
333
334 let _guard = lock.acquire().await?;
336 if let Some(conn) = try_get!(self, h2_pool, key) {
337 return Ok(conn.into());
338 }
339
340 let (tx, conn) = self.h2_builder.handshake(transport_conn).await?;
341 monoio::spawn(conn);
342 self.h2_pool.put(key, Http2Connection::new(tx.clone()));
343 Ok(Http2Connection::new(tx.clone()).into())
344 } else {
345 let client_codec = if let Some(timeout) = self.read_timeout {
346 ClientCodec::new_with_timeout(transport_conn, timeout)
347 } else {
348 ClientCodec::new(transport_conn)
349 };
350 let http_conn = Http1Connection::new(client_codec);
351 let pooled = if let Some(pool) = &self.h1_pool {
352 pool.link(key, http_conn)
353 } else {
354 Pooled::unpooled(http_conn)
355 };
356 Ok(pooled.into())
357 }
358 }
359}
360
361pub struct H1Connector<C, K, IO: AsyncWriteRent> {
364 inner_connector: C,
365 pool: Option<ConnectionPool<K, Http1Connection<IO>>>,
366 pub read_timeout: Option<Duration>,
367}
368
369impl<C: Clone, K, IO: AsyncWriteRent> Clone for H1Connector<C, K, IO> {
370 fn clone(&self) -> Self {
371 Self {
372 inner_connector: self.inner_connector.clone(),
373 pool: self.pool.clone(),
374 read_timeout: self.read_timeout,
375 }
376 }
377}
378
379impl<C, K, IO: AsyncWriteRent> H1Connector<C, K, IO> {
380 #[inline]
381 pub const fn new(inner_connector: C) -> Self {
382 Self {
383 inner_connector,
384 pool: None,
385 read_timeout: None,
386 }
387 }
388
389 #[inline]
390 #[allow(unused)]
391 pub const fn new_with_timeout(inner_connector: C, timeout: Duration) -> Self {
392 Self {
393 inner_connector,
394 pool: None,
395 read_timeout: Some(timeout),
396 }
397 }
398
399 #[inline]
400 #[allow(unused)]
401 pub fn pool(&mut self) -> &mut Option<ConnectionPool<K, Http1Connection<IO>>> {
402 &mut self.pool
403 }
404
405 #[inline]
406 #[allow(unused)]
407 pub fn read_timeout(&mut self) -> &mut Option<Duration> {
408 &mut self.read_timeout
409 }
410}
411
412impl<C, K: 'static, IO: AsyncWriteRent + 'static> H1Connector<C, K, IO> {
413 #[inline]
414 #[allow(unused)]
415 pub fn with_default_pool(self) -> Self {
416 #[cfg(not(feature = "time"))]
417 let pool = ConnectionPool::new(None);
418 #[cfg(feature = "time")]
419 const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(60);
420 #[cfg(feature = "time")]
421 let pool = ConnectionPool::new_with_idle_interval(Some(DEFAULT_IDLE_TIMEOUT), None);
422 Self {
423 pool: Some(pool),
424 ..self
425 }
426 }
427}
428
429impl<C: Default, K, IO: AsyncWriteRent> Default for H1Connector<C, K, IO> {
430 #[inline]
431 fn default() -> Self {
432 H1Connector::new(C::default())
433 }
434}
435
436impl<C, K: Key, IO: AsyncWriteRent> Connector<K> for H1Connector<C, K, IO>
437where
438 C: Connector<K, Connection = IO>,
439 IO: AsyncReadRent + AsyncWriteRent + Split,
441{
442 type Connection = Pooled<K, Http1Connection<IO>>;
443 type Error = C::Error;
444
445 #[inline]
446 async fn connect(&self, key: K) -> Result<Self::Connection, Self::Error> {
447 if let Some(pool) = &self.pool {
448 if let Some(conn) = pool.get(&key) {
449 return Ok(conn);
450 }
451 }
452 let io: IO = self.inner_connector.connect(key.clone()).await?;
453 let client_codec = match self.read_timeout {
454 Some(timeout) => ClientCodec::new_with_timeout(io, timeout),
455 None => ClientCodec::new(io),
456 };
457 let http_conn = Http1Connection::new(client_codec);
458 let pooled = if let Some(pool) = &self.pool {
459 pool.link(key, http_conn)
460 } else {
461 Pooled::unpooled(http_conn)
462 };
463 Ok(pooled)
464 }
465}
466
#[cfg(test)]
mod tests {
    //! Connector round-trip tests against the public `httpbin.org` service.
    //!
    //! NOTE(review): these tests need outbound network access and DNS.

    use std::net::ToSocketAddrs;

    use http::{request, Uri};
    use monoio_http::{common::body::HttpBody, h1::payload::Payload};

    use super::*;
    use crate::connectors::{TcpConnector, TcpTlsAddr};

    /// Builds the `GET /get` request that every test sends.
    fn get_request() -> http::Request<HttpBody> {
        request::Builder::new()
            .uri("/get")
            .header("Host", "httpbin.org")
            .body(HttpBody::H1(Payload::None))
            .unwrap()
    }

    /// Converts the httpbin HTTPS endpoint into a TLS transport address.
    fn httpbin_tls_addr() -> TcpTlsAddr {
        let uri = "https://httpbin.org/get".parse::<Uri>().unwrap();
        uri.try_into().unwrap()
    }

    #[monoio::test(enable_timer = true)]
    async fn test_default_https_connector() -> Result<(), crate::TransportError> {
        let connector: HttpConnector<TlsConnector<TcpConnector>, _, _> = HttpConnector::default();
        let mut conn = connector.connect(httpbin_tls_addr()).await.unwrap();

        for _ in 0..10 {
            let (res, _) = conn.send_request(get_request()).await;
            let resp = res?;
            assert_eq!(200, resp.status());
            assert_eq!(
                "application/json".as_bytes(),
                resp.headers().get("content-type").unwrap().as_bytes()
            );
            assert_eq!(resp.version(), http::Version::HTTP_2);
        }
        Ok(())
    }

    #[monoio::test(enable_timer = true)]
    async fn test_http2_tls_connector() -> Result<(), crate::TransportError> {
        let connector: HttpConnector<TlsConnector<TcpConnector>, _, _> =
            HttpConnector::build_tls_http2_only();
        let mut conn = connector.connect(httpbin_tls_addr()).await.unwrap();

        for _ in 0..10 {
            let (res, _) = conn.send_request(get_request()).await;
            let resp = res?;
            assert_eq!(200, resp.status());
            assert_eq!(
                "application/json".as_bytes(),
                resp.headers().get("content-type").unwrap().as_bytes()
            );
            assert_eq!(resp.version(), http::Version::HTTP_2);
        }
        Ok(())
    }

    #[monoio::test(enable_timer = true)]
    async fn test_http1_tls_connector() -> Result<(), crate::TransportError> {
        let connector: HttpConnector<TlsConnector<TcpConnector>, _, _> =
            HttpConnector::build_tls_http1_only();
        let mut conn = connector.connect(httpbin_tls_addr()).await.unwrap();

        for _ in 0..10 {
            let (res, _) = conn.send_request(get_request()).await;
            let resp = res?;
            assert_eq!(200, resp.status());
            assert_eq!(
                "application/json".as_bytes(),
                resp.headers().get("content-type").unwrap().as_bytes()
            );
            assert_eq!(resp.version(), http::Version::HTTP_11);
        }
        Ok(())
    }

    #[monoio::test(enable_timer = true)]
    async fn test_http1_tcp_connector() -> Result<(), crate::TransportError> {
        let connector: HttpConnector<TcpConnector, _, _> = HttpConnector::default();

        /// Minimal host/port pool key that resolves through `ToSocketAddrs`.
        #[derive(Debug, Clone, Eq, PartialEq, Hash)]
        struct Key {
            host: String,
            port: u16,
        }

        impl ToSocketAddrs for Key {
            type Iter = std::vec::IntoIter<std::net::SocketAddr>;
            fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
                (self.host.as_str(), self.port).to_socket_addrs()
            }
        }

        for _i in 0..10 {
            let uri = "http://httpbin.org/get".parse::<Uri>().unwrap();
            let key = Key {
                host: uri.host().unwrap().to_string(),
                port: uri.port_u16().unwrap_or(80),
            };
            let mut conn = connector.connect(key).await.unwrap();
            let (res, _) = conn.send_request(get_request()).await;
            let resp = res?;
            assert_eq!(200, resp.status());
            assert_eq!(
                "application/json".as_bytes(),
                resp.headers().get("content-type").unwrap().as_bytes()
            );
            assert_eq!(resp.version(), http::Version::HTTP_11);
        }
        Ok(())
    }
}